diff --git a/server/router.go b/server/router.go index 44747e3..8a3c278 100644 --- a/server/router.go +++ b/server/router.go @@ -36,7 +36,7 @@ func NewRedisRouter(server *Identity, redisUrl string, password string, db int, return &r } -func (r *RedisRouter) Route(msg *meowlib.ToServerMessage) (*meowlib.FromServerMessage, error) { +func (r *RedisRouter) Route(ctx context.Context, msg *meowlib.ToServerMessage) (*meowlib.FromServerMessage, error) { var from_server *meowlib.FromServerMessage // update messages counter err := r.Client.Incr("statistics:messages:total").Err() @@ -60,7 +60,7 @@ func (r *RedisRouter) Route(msg *meowlib.ToServerMessage) (*meowlib.FromServerMe } if msg.Timeout > 0 && len(from_server.Chat) == 0 && from_server.Invitation == nil { logger.Info().Msg("long poll, subscribing for messages") - from_server, err = r.subscribe(msg, int(msg.Timeout)) + from_server, err = r.subscribe(ctx, msg, int(msg.Timeout)) if err != nil { return nil, err } @@ -205,7 +205,7 @@ func (r *RedisRouter) checkForMessage(msg *meowlib.ToServerMessage) (*meowlib.Fr } -func (r *RedisRouter) subscribe(msg *meowlib.ToServerMessage, timeout int) (*meowlib.FromServerMessage, error) { +func (r *RedisRouter) subscribe(reqCtx context.Context, msg *meowlib.ToServerMessage, timeout int) (*meowlib.FromServerMessage, error) { if err := r.Client.Incr("statistics:messages:messagessubscription").Err(); err != nil { return nil, err } @@ -237,7 +237,7 @@ func (r *RedisRouter) subscribe(msg *meowlib.ToServerMessage, timeout int) (*meo return fromServer, nil } - ctx, cancel := context.WithTimeout(r.Context, time.Duration(timeout)*time.Second) + ctx, cancel := context.WithTimeout(reqCtx, time.Duration(timeout)*time.Second) defer cancel() ch := pubsub.Channel() diff --git a/server/router_test.go b/server/router_test.go index eed4a6c..b5c7400 100644 --- a/server/router_test.go +++ b/server/router_test.go @@ -220,7 +220,7 @@ func TestRouteDispatchesStoreAndCheck(t *testing.T) { {Destination: dest, Payload: []byte("routed msg")}, }, } - resp, err := router.Route(storeReq) + resp, err := router.Route(context.Background(),storeReq) assert.NoError(t, err) assert.Equal(t, "route-store-uuid", resp.UuidAck) @@ -230,7 +230,7 @@ func TestRouteDispatchesStoreAndCheck(t *testing.T) { {LookupKey: dest}, }, } - resp, err = router.Route(pullReq) + resp, err = router.Route(context.Background(),pullReq) assert.NoError(t, err) assert.Len(t, resp.Chat, 1) assert.Equal(t, []byte("routed msg"), resp.Chat[0].Payload) @@ -241,7 +241,7 @@ func TestRouteEmptyMessage(t *testing.T) { router, mr := newTestRouter(t) defer mr.Close() - resp, err := router.Route(&meowlib.ToServerMessage{}) + resp, err := router.Route(context.Background(),&meowlib.ToServerMessage{}) assert.NoError(t, err) assert.Nil(t, resp) } @@ -251,9 +251,9 @@ func TestRouteIncrementsTotalCounter(t *testing.T) { router, mr := newTestRouter(t) defer mr.Close() - router.Route(&meowlib.ToServerMessage{}) - router.Route(&meowlib.ToServerMessage{}) - router.Route(&meowlib.ToServerMessage{}) + router.Route(context.Background(),&meowlib.ToServerMessage{}) + router.Route(context.Background(),&meowlib.ToServerMessage{}) + router.Route(context.Background(),&meowlib.ToServerMessage{}) val, err := router.Client.Get("statistics:messages:total").Int() assert.NoError(t, err) @@ -554,7 +554,7 @@ func TestRouteMatriochka(t *testing.T) { Data: []byte("wrapped"), }, } - resp, err := router.Route(msg) + resp, err := router.Route(context.Background(),msg) assert.NoError(t, err) assert.Equal(t, "route-mtk", resp.UuidAck) @@ -578,7 +578,7 @@ func TestRouteInvitation(t *testing.T) { ShortcodeLen: 6, }, } - resp, err := router.Route(msg) + resp, err := router.Route(context.Background(),msg) assert.NoError(t, err) assert.NotEmpty(t, resp.Invitation.Shortcode) assert.Len(t, resp.Invitation.Shortcode, 6) @@ -595,7 +595,7 @@ func TestStatisticsCountersIncrement(t *testing.T) { dest := "stats-dest" // one store increments usermessages - router.Route(&meowlib.ToServerMessage{ + router.Route(context.Background(),&meowlib.ToServerMessage{ Messages: []*meowlib.PackedUserMessage{ {Destination: dest, Payload: []byte("x")}, }, @@ -604,7 +604,7 @@ func TestStatisticsCountersIncrement(t *testing.T) { assert.Equal(t, 1, val) // one pull increments messagelookups - router.Route(&meowlib.ToServerMessage{ + router.Route(context.Background(),&meowlib.ToServerMessage{ PullRequest: []*meowlib.ConversationRequest{ {LookupKey: dest}, }, @@ -613,14 +613,14 @@ func TestStatisticsCountersIncrement(t *testing.T) { assert.Equal(t, 1, val) // one matriochka increments matriochka counter - router.Route(&meowlib.ToServerMessage{ + router.Route(context.Background(),&meowlib.ToServerMessage{ MatriochkaMessage: &meowlib.Matriochka{Data: []byte("m")}, }) val, _ = router.Client.Get("statistics:messages:matriochka").Int() assert.Equal(t, 1, val) // one invitation increments invitation counter - router.Route(&meowlib.ToServerMessage{ + router.Route(context.Background(),&meowlib.ToServerMessage{ Invitation: &meowlib.Invitation{ Step: 1, Payload: []byte("i"),