From bb546752936ec5b191df7823519e5d5a10671e52 Mon Sep 17 00:00:00 2001 From: Toan Nguyen Date: Tue, 28 Feb 2023 22:37:42 +0700 Subject: [PATCH] subscription improvement, merge reset subscription logic into Run (#76) * merge reset subscription logic into Run * fix data race issues and make subscription states immutable --- .github/workflows/test.yml | 3 + go.mod | 4 +- go.sum | 16 +- graphql.go | 2 +- graphql_test.go | 8 +- subscription.go | 262 +++++++++++++++++++++----------- subscription_graphql_ws.go | 19 +-- subscription_graphql_ws_test.go | 108 ++++++++++++- subscription_test.go | 160 ++++++++++++++++++- subscriptions_transport_ws.go | 27 ++-- 10 files changed, 476 insertions(+), 133 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 7f96bd7..8b929d3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -61,3 +61,6 @@ jobs: with: recreate: true path: code-coverage-results.md + - name: Dump docker logs on failure + if: failure() + uses: jwalton/gh-docker-logs@v2 diff --git a/go.mod b/go.mod index 6899379..9f9bde2 100644 --- a/go.mod +++ b/go.mod @@ -4,8 +4,10 @@ go 1.16 require ( github.com/google/uuid v1.3.0 - github.com/graph-gophers/graphql-go v1.4.0 + github.com/graph-gophers/graphql-go v1.5.0 github.com/graph-gophers/graphql-transport-ws v0.0.2 + golang.org/x/crypto v0.0.0-20220314234659-1baeb1ce4c0b // indirect + golang.org/x/sys v0.0.0-20220412211240-33da011f77ad // indirect nhooyr.io/websocket v1.8.7 ) diff --git a/go.sum b/go.sum index ddff347..bdf7594 100644 --- a/go.sum +++ b/go.sum @@ -32,8 +32,8 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/graph-gophers/graphql-go v1.4.0 h1:JE9wveRTSXwJyjdRd6bOQ7Ob5bewTUQ58Jv4OiVdpdE= -github.com/graph-gophers/graphql-go v1.4.0/go.mod h1:YtmJZDLbF1YYNrlNAuiO5zAStUWc3XZT07iGsVqe1Os= +github.com/graph-gophers/graphql-go v1.5.0 h1:fDqblo50TEpD0LY7RXk/LFVYEVqo3+tXMNMPSVXA1yc= +github.com/graph-gophers/graphql-go v1.5.0/go.mod h1:YtmJZDLbF1YYNrlNAuiO5zAStUWc3XZT07iGsVqe1Os= github.com/graph-gophers/graphql-transport-ws v0.0.2 h1:DbmSkbIGzj8SvHei6n8Mh9eLQin8PtA8xY9eCzjRpvo= github.com/graph-gophers/graphql-transport-ws v0.0.2/go.mod h1:5BVKvFzOd2BalVIBFfnfmHjpJi/MZ5rOj8G55mXvZ8g= github.com/json-iterator/go v1.1.9 h1:9yzud/Ht36ygwatGx56VwCZtlI/2AD15T1X2sjSuGns= @@ -64,15 +64,23 @@ github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLY go.opentelemetry.io/otel v1.6.3/go.mod h1:7BgNga5fNlF/iZjG06hM3yofffp0ofKCDwSXx1GC4dI= go.opentelemetry.io/otel/trace v1.6.3/go.mod h1:GNJQusJlUgZl9/TQBPKU/Y/ty+0iVB5fjhKeJGZPGFs= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20220314234659-1baeb1ce4c0b h1:Qwe1rC8PSniVfAFPFJeyUkB+zcysC3RgJBAGk7eqBEU= +golang.org/x/crypto v0.0.0-20220314234659-1baeb1ce4c0b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200116001909-b77594299b42 h1:vEOn+mP2zCOVzKckCZy6YsCtDblrpj/w7B9nxGNELpg= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220412211240-33da011f77ad h1:ntjMns5wyP/fN65tdBD4g8J5w8n015+iIIs9rtjXkY0= +golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= diff --git a/graphql.go b/graphql.go index ff378fe..1e07c09 100644 --- a/graphql.go +++ b/graphql.go @@ -321,7 +321,7 @@ type Error struct { // Error implements error interface. func (e Error) Error() string { - return fmt.Sprintf("Message: %s, Locations: %+v", e.Message, e.Locations) + return fmt.Sprintf("Message: %s, Locations: %+v, Extensions: %+v", e.Message, e.Locations, e.Extensions) } // Error implements error interface. diff --git a/graphql_test.go b/graphql_test.go index 44db7c2..05e7144 100644 --- a/graphql_test.go +++ b/graphql_test.go @@ -61,7 +61,7 @@ func TestClient_Query_partialDataWithErrorResponse(t *testing.T) { if err == nil { t.Fatal("got error: nil, want: non-nil") } - if got, want := err.Error(), "Message: Could not resolve to a node with the global id of 'NotExist', Locations: [{Line:10 Column:4}]"; got != want { + if got, want := err.Error(), "Message: Could not resolve to a node with the global id of 'NotExist', Locations: [{Line:10 Column:4}], Extensions: map[]"; got != want { t.Errorf("got error: %v, want: %v", got, want) } @@ -111,7 +111,7 @@ func TestClient_Query_partialDataRawQueryWithErrorResponse(t *testing.T) { if err == nil { t.Fatal("got error: nil, want: non-nil\n") } - if got, want := err.Error(), "Message: Could not resolve to a node with the global id of 'NotExist', Locations: [{Line:10 Column:4}]"; got != want { + if got, want := err.Error(), "Message: Could not resolve to a node with the global id of 'NotExist', Locations: [{Line:10 Column:4}], Extensions: map[]"; got != want { t.Errorf("got error: %v, want: %v\n", got, want) } if q.Node1 == nil || string(q.Node1) != `{"id":"MDEyOklzc3VlQ29tbWVudDE2OTQwNzk0Ng=="}` { @@ -166,7 +166,7 @@ func TestClient_Query_noDataWithErrorResponse(t *testing.T) { if err == nil { t.Fatal("got error: nil, want: non-nil") } - if got, want := err.Error(), "Message: Field 'user' is missing required arguments: login, Locations: [{Line:7 Column:3}]"; got != want { + if got, want := err.Error(), "Message: Field 'user' is missing required arguments: login, Locations: [{Line:7 Column:3}], Extensions: map[]"; got != want { t.Errorf("got error: %v, want: %v", got, want) } if q.User.Name != "" { @@ -216,7 +216,7 @@ func TestClient_Query_errorStatusCode(t *testing.T) { if err == nil { t.Fatal("got error: nil, want: non-nil") } - if got, want := err.Error(), `Message: 500 Internal Server Error; body: "important message\n", Locations: []`; got != want { + if got, want := err.Error(), `Message: 500 Internal Server Error; body: "important message\n", Locations: [], Extensions: map[code:request_error]`; got != want { t.Errorf("got error: %v, want: %v", got, want) } if q.User.Name != "" { diff --git a/subscription.go b/subscription.go index c91fc34..1d692bc 100644 --- a/subscription.go +++ b/subscription.go @@ -21,6 +21,11 @@ import ( type SubscriptionProtocolType String const ( + // internal state machine status + scStatusInitializing int32 = 0 + scStatusRunning int32 = 1 + scStatusClosing int32 = 2 + SubscriptionsTransportWS SubscriptionProtocolType = "subscriptions-transport-ws" GraphQLWS SubscriptionProtocolType = "graphql-ws" @@ -90,11 +95,11 @@ type SubscriptionProtocol interface { // ConnectionInit sends a initial request to establish a connection within the existing socket ConnectionInit(ctx *SubscriptionContext, connectionParams map[string]interface{}) error // Subscribe requests an graphql operation specified in the payload message - Subscribe(ctx *SubscriptionContext, id string, sub *Subscription) error + Subscribe(ctx *SubscriptionContext, id string, sub Subscription) error // Unsubscribe sends a request to stop listening and complete the subscription Unsubscribe(ctx *SubscriptionContext, id string) error // OnMessage listens ongoing messages from server - OnMessage(ctx *SubscriptionContext, subscription *Subscription, message OperationMessage) + OnMessage(ctx *SubscriptionContext, subscription Subscription, message OperationMessage) // Close terminates all subscriptions of the current websocket Close(ctx *SubscriptionContext) error } @@ -102,14 +107,14 @@ type SubscriptionProtocol interface { // SubscriptionContext represents a shared context for protocol implementations with the websocket connection inside type SubscriptionContext struct { context.Context - WebsocketConn + websocketConn WebsocketConn OnConnected func() onDisconnected func() cancel context.CancelFunc - subscriptions map[string]*Subscription + subscriptions map[string]Subscription disabledLogTypes []OperationMessageType log func(args ...interface{}) - acknowledged int64 + acknowledged int32 exitStatusCodes []int mutex sync.Mutex } @@ -128,16 +133,44 @@ func (sc *SubscriptionContext) Log(message interface{}, source string, opType Op sc.log(message, source) } +// GetContext get the inner context +func (sc *SubscriptionContext) GetContext() context.Context { + sc.mutex.Lock() + defer sc.mutex.Unlock() + return sc.Context +} + +// GetContext set the inner context +func (sc *SubscriptionContext) NewContext() { + sc.mutex.Lock() + defer sc.mutex.Unlock() + ctx, cancel := context.WithCancel(context.Background()) + sc.Context = ctx + sc.cancel = cancel +} + +// SetCancel set the cancel function of the inner context +func (sc *SubscriptionContext) Cancel() { + sc.mutex.Lock() + defer sc.mutex.Unlock() + if sc.cancel != nil { + sc.cancel() + sc.cancel = nil + } +} + // GetWebsocketConn get the current websocket connection func (sc *SubscriptionContext) GetWebsocketConn() WebsocketConn { - return sc.WebsocketConn + sc.mutex.Lock() + defer sc.mutex.Unlock() + return sc.websocketConn } // SetWebsocketConn set the current websocket connection func (sc *SubscriptionContext) SetWebsocketConn(conn WebsocketConn) { sc.mutex.Lock() defer sc.mutex.Unlock() - sc.WebsocketConn = conn + sc.websocketConn = conn } // GetSubscription get the subscription state by id @@ -148,12 +181,21 @@ func (sc *SubscriptionContext) GetSubscription(id string) *Subscription { return nil } sub, _ := sc.subscriptions[id] - return sub + return &sub +} + +// GetSubscriptionsLength returns the length of subscriptions +func (sc *SubscriptionContext) GetSubscriptionsLength() int { + sc.mutex.Lock() + defer sc.mutex.Unlock() + return len(sc.subscriptions) } // GetSubscription get all available subscriptions in the context -func (sc *SubscriptionContext) GetSubscriptions() map[string]*Subscription { - newMap := make(map[string]*Subscription) +func (sc *SubscriptionContext) GetSubscriptions() map[string]Subscription { + sc.mutex.Lock() + defer sc.mutex.Unlock() + newMap := make(map[string]Subscription) for k, v := range sc.subscriptions { newMap[k] = v } @@ -164,42 +206,38 @@ func (sc *SubscriptionContext) GetSubscriptions() map[string]*Subscription { // if subscription is nil, removes the subscription from the map func (sc *SubscriptionContext) SetSubscription(id string, sub *Subscription) { sc.mutex.Lock() + defer sc.mutex.Unlock() if sub == nil { delete(sc.subscriptions, id) } else { - sc.subscriptions[id] = sub + sc.subscriptions[id] = *sub } - sc.mutex.Unlock() } // GetAcknowledge get the acknowledge status func (sc *SubscriptionContext) GetAcknowledge() bool { - return atomic.LoadInt64(&sc.acknowledged) > 0 + return atomic.LoadInt32(&sc.acknowledged) > 0 } // SetAcknowledge set the acknowledge status func (sc *SubscriptionContext) SetAcknowledge(value bool) { if value { - atomic.StoreInt64(&sc.acknowledged, 1) + atomic.StoreInt32(&sc.acknowledged, 1) } else { - atomic.StoreInt64(&sc.acknowledged, 0) + atomic.StoreInt32(&sc.acknowledged, 0) } } // Close closes the context and the inner websocket connection if exists func (sc *SubscriptionContext) Close() error { + var err error if conn := sc.GetWebsocketConn(); conn != nil { - err := conn.Close() + err = conn.Close() sc.SetWebsocketConn(nil) - if err != nil { - return err - } - } - if sc.cancel != nil { - sc.cancel() } + sc.Cancel() - return nil + return err } // Send emits a message to the graphql server @@ -249,7 +287,7 @@ type SubscriptionClient struct { protocol SubscriptionProtocol websocketOptions WebsocketOptions timeout time.Duration - isRunning int64 + clientStatus int32 readLimit int64 // max size of response message. Default 10 MB createConn func(sc *SubscriptionClient) (WebsocketConn, error) retryTimeout time.Duration @@ -268,7 +306,7 @@ func NewSubscriptionClient(url string) *SubscriptionClient { errorChan: make(chan error), protocol: &subscriptionsTransportWS{}, context: &SubscriptionContext{ - subscriptions: make(map[string]*Subscription), + subscriptions: make(map[string]Subscription), }, } } @@ -285,7 +323,7 @@ func (sc *SubscriptionClient) GetTimeout() time.Duration { // GetContext returns current context of subscription client func (sc *SubscriptionClient) GetContext() context.Context { - return sc.context.Context + return sc.context.GetContext() } // WithWebSocket replaces customized websocket client constructor @@ -381,28 +419,26 @@ func (sc *SubscriptionClient) OnDisconnected(fn func()) *SubscriptionClient { return sc } +// get internal client status +func (sc *SubscriptionClient) getClientStatus() int32 { + return atomic.LoadInt32(&sc.clientStatus) +} + // set the running atomic lock status -func (sc *SubscriptionClient) setIsRunning(value bool) { - if value { - atomic.StoreInt64(&sc.isRunning, 1) - } else { - atomic.StoreInt64(&sc.isRunning, 0) - } +func (sc *SubscriptionClient) setClientStatus(value int32) { + atomic.StoreInt32(&sc.clientStatus, value) } // initializes the websocket connection func (sc *SubscriptionClient) init() error { now := time.Now() - ctx, cancel := context.WithCancel(context.Background()) - sc.context.Context = ctx - sc.context.cancel = cancel - for { var err error var conn WebsocketConn // allow custom websocket client if sc.context.GetWebsocketConn() == nil { + sc.context.NewContext() conn, err = sc.createConn(sc) if err == nil { sc.context.SetWebsocketConn(conn) @@ -410,7 +446,7 @@ func (sc *SubscriptionClient) init() error { } if err == nil { - sc.context.SetReadLimit(sc.readLimit) + sc.context.GetWebsocketConn().SetReadLimit(sc.readLimit) // send connection init event to the server connectionParams := sc.connectionParams if sc.connectionParamsFn != nil { @@ -479,9 +515,10 @@ func (sc *SubscriptionClient) doRaw(query string, variables map[string]interface handler: sc.wrapHandler(handler), } - // if the websocket client is running, start subscription immediately - if atomic.LoadInt64(&sc.isRunning) > 0 { - if err := sc.protocol.Subscribe(sc.context, id, &sub); err != nil { + // if the websocket client is running and acknowledged by the server + // start subscription immediately + if sc.context != nil && sc.context.GetAcknowledge() { + if err := sc.protocol.Subscribe(sc.context, id, sub); err != nil { return "", err } } @@ -505,31 +542,39 @@ func (sc *SubscriptionClient) Unsubscribe(id string) error { return sc.protocol.Unsubscribe(sc.context, id) } -// Run start websocket client and subscriptions. If this function is run with goroutine, it can be stopped after closed +// Run start the WebSocket client and subscriptions. +// If the client is running, recalling this function will restart all registered subscriptions +// If this function is run with goroutine, it can be stopped after closed func (sc *SubscriptionClient) Run() error { + + sc.reset() if err := sc.init(); err != nil { return fmt.Errorf("retry timeout. exiting...") } - sc.setIsRunning(true) + sc.setClientStatus(scStatusRunning) + if sc.context == nil { + return fmt.Errorf("the subscription context is nil") + } + conn := sc.context.GetWebsocketConn() + if conn == nil { + return fmt.Errorf("the websocket connection hasn't been created") + } + ctx := sc.context.GetContext() + go func() { - for atomic.LoadInt64(&sc.isRunning) > 0 { + for sc.getClientStatus() == scStatusRunning { select { - case <-sc.context.Done(): + case <-ctx.Done(): return default: - if sc.context == nil || sc.context.GetWebsocketConn() == nil { - return - } - var message OperationMessage - if err := sc.context.ReadJSON(&message); err != nil { + if err := conn.ReadJSON(&message); err != nil { // manual EOF check if err == io.EOF || strings.Contains(err.Error(), "EOF") { - if err = sc.Reset(); err != nil { - sc.errorChan <- err - return - } + sc.setClientStatus(scStatusInitializing) + sc.context.Cancel() + return } closeStatus := websocket.CloseStatus(err) switch closeStatus { @@ -543,16 +588,21 @@ func (sc *SubscriptionClient) Run() error { if closeStatus != -1 && closeStatus < 3000 && closeStatus > 4999 { sc.context.Log(fmt.Sprintf("%s. Retry connecting...", err), "client", GQLInternal) - if err = sc.Reset(); err != nil { + if err = sc.Run(); err != nil { sc.errorChan <- err return } } + if isClosedSubscriptionError(err) { + _ = sc.Close() + return + } + if sc.onError != nil { if err = sc.onError(sc, err); err != nil { // end the subscription if the callback return error - sc.Close() + _ = sc.Close() return } } @@ -560,78 +610,96 @@ func (sc *SubscriptionClient) Run() error { } sub := sc.context.GetSubscription(message.ID) - go sc.protocol.OnMessage(sc.context, sub, message) + go sc.protocol.OnMessage(sc.context, *sub, message) } } }() - for atomic.LoadInt64(&sc.isRunning) > 0 { + for sc.getClientStatus() == scStatusRunning { select { - case <-sc.context.Done(): - return nil + case <-ctx.Done(): + if sc.context.GetSubscriptionsLength() == 0 { + sc.setClientStatus(scStatusClosing) + } + break case e := <-sc.errorChan: // stop the subscription if the error has stop message if e == ErrSubscriptionStopped { - return nil + return sc.Close() } if sc.onError != nil { if err := sc.onError(sc, e); err != nil { + _ = sc.Close() return err } } } } - // if the running status is false, stop retrying - if atomic.LoadInt64(&sc.isRunning) == 0 { + // if the client is closing, stop retrying + if sc.getClientStatus() == scStatusClosing { return nil } - return sc.Reset() + return sc.Run() } -// Reset restart websocket connection and subscriptions -func (sc *SubscriptionClient) Reset() error { +// close the running websocket connection and reset all subscription states +func (sc *SubscriptionClient) reset() { sc.context.SetAcknowledge(false) - isRunning := atomic.LoadInt64(&sc.isRunning) == 0 + isRunning := sc.getClientStatus() == scStatusRunning for id, sub := range sc.context.GetSubscriptions() { - sub.SetStarted(false) - if isRunning { - _ = sc.protocol.Unsubscribe(sc.context, id) - sc.context.SetSubscription(id, sub) + if sub.GetStarted() { + sub.SetStarted(false) + if isRunning { + _ = sc.protocol.Unsubscribe(sc.context, id) + } + sc.context.SetSubscription(id, &sub) } } + sc.setClientStatus(scStatusClosing) - if sc.context.GetWebsocketConn() != nil { - _ = sc.protocol.Close(sc.context) - _ = sc.context.Close() - sc.context.SetWebsocketConn(nil) - } - - return sc.Run() + _ = sc.protocol.Close(sc.context) + _ = sc.context.Close() } // Close closes all subscription channel and websocket as well func (sc *SubscriptionClient) Close() (err error) { - sc.setIsRunning(false) + if sc.getClientStatus() == scStatusClosing { + return nil + } + + sc.setClientStatus(scStatusClosing) + if sc.context == nil { + return + } + + unsubscribeErrors := make(map[string]error) + for id := range sc.context.GetSubscriptions() { - if err = sc.protocol.Unsubscribe(sc.context, id); err != nil { - sc.context.cancel() - return + if err := sc.protocol.Unsubscribe(sc.context, id); err != nil && !isClosedSubscriptionError(err) { + unsubscribeErrors[id] = err } } + protocolCloseError := sc.protocol.Close(sc.context) + closeError := sc.context.Close() - if sc.context != nil { - _ = sc.protocol.Close(sc.context) - err = sc.context.Close() - sc.context.SetWebsocketConn(nil) - if sc.context.onDisconnected != nil { - sc.context.onDisconnected() - } + if sc.context.onDisconnected != nil { + sc.context.onDisconnected() } - return + if len(unsubscribeErrors) > 0 || protocolCloseError != nil || closeError != nil { + return Error{ + Message: "failed to close the subscription client", + Extensions: map[string]interface{}{ + "unsubscribe": unsubscribeErrors, + "protocol": protocolCloseError, + "close": closeError, + }, + } + } + return nil } // the reusable function for sending connection init message. @@ -655,6 +723,22 @@ func connectionInit(conn *SubscriptionContext, connectionParams map[string]inter return conn.Send(msg, GQLConnectionInit) } +// accept closed websocket errors due to data races +func isClosedSubscriptionError(err error) bool { + + expectedErrorMessages := []string{ + "context canceled", + "received header with unexpected rsv bits", + } + errMsg := err.Error() + for _, msg := range expectedErrorMessages { + if strings.Contains(errMsg, msg) { + return true + } + } + return false +} + // default websocket handler implementation using https://github.com/nhooyr/websocket type WebsocketHandler struct { ctx context.Context diff --git a/subscription_graphql_ws.go b/subscription_graphql_ws.go index adab049..e7bf3a7 100644 --- a/subscription_graphql_ws.go +++ b/subscription_graphql_ws.go @@ -43,7 +43,7 @@ func (gws *graphqlWS) ConnectionInit(ctx *SubscriptionContext, connectionParams } // Subscribe requests an graphql operation specified in the payload message -func (gws *graphqlWS) Subscribe(ctx *SubscriptionContext, id string, sub *Subscription) error { +func (gws *graphqlWS) Subscribe(ctx *SubscriptionContext, id string, sub Subscription) error { if sub.GetStarted() { return nil } @@ -63,13 +63,15 @@ func (gws *graphqlWS) Subscribe(ctx *SubscriptionContext, id string, sub *Subscr } sub.SetStarted(true) + ctx.SetSubscription(id, &sub) + return nil } // Unsubscribe sends stop message to server and close subscription channel // The input parameter is subscription ID that is returned from Subscribe function func (gws *graphqlWS) Unsubscribe(ctx *SubscriptionContext, id string) error { - if ctx == nil || ctx.WebsocketConn == nil { + if ctx == nil || ctx.GetWebsocketConn() == nil { return nil } sub := ctx.GetSubscription(id) @@ -87,30 +89,23 @@ func (gws *graphqlWS) Unsubscribe(ctx *SubscriptionContext, id string) error { } err := ctx.Send(msg, GQLComplete) - if err != nil { - return err - } - // close the client if there is no running subscription - if len(ctx.GetSubscriptions()) == 0 { + if ctx.GetSubscriptionsLength() == 0 { ctx.Log("no running subscription. exiting...", "client", GQLInternal) return ctx.Close() } - return nil + return err } // OnMessage listens ongoing messages from server -func (gws *graphqlWS) OnMessage(ctx *SubscriptionContext, subscription *Subscription, message OperationMessage) { +func (gws *graphqlWS) OnMessage(ctx *SubscriptionContext, subscription Subscription, message OperationMessage) { switch message.Type { case GQLError: ctx.Log(message, "server", message.Type) case GQLNext: ctx.Log(message, "server", message.Type) - if subscription == nil { - return - } var out struct { Data *json.RawMessage Errors Errors diff --git a/subscription_graphql_ws_test.go b/subscription_graphql_ws_test.go index b24b20d..8de2973 100644 --- a/subscription_graphql_ws_test.go +++ b/subscription_graphql_ws_test.go @@ -29,7 +29,7 @@ func (h headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) type user_insert_input map[string]interface{} -func graphqlWS_setupClients() (*Client, *SubscriptionClient) { +func hasura_setupClients(protocol SubscriptionProtocolType) (*Client, *SubscriptionClient) { endpoint := fmt.Sprintf("%s/v1/graphql", hasuraTestHost) client := NewClient(endpoint, &http.Client{Transport: headerRoundTripper{ setHeaders: func(req *http.Request) { @@ -39,7 +39,7 @@ func graphqlWS_setupClients() (*Client, *SubscriptionClient) { }}) subscriptionClient := NewSubscriptionClient(endpoint). - WithProtocol(GraphQLWS). + WithProtocol(protocol). WithConnectionParams(map[string]interface{}{ "headers": map[string]string{ "x-hasura-admin-secret": hasuraTestAdminSecret, @@ -81,7 +81,7 @@ func waitHasuraService(timeoutSecs int) error { func TestGraphqlWS_Subscription(t *testing.T) { stop := make(chan bool) - client, subscriptionClient := graphqlWS_setupClients() + client, subscriptionClient := hasura_setupClients(GraphQLWS) msg := randomID() subscriptionClient = subscriptionClient. @@ -174,3 +174,105 @@ func TestGraphqlWS_Subscription(t *testing.T) { <-stop } + +func TestGraphqlWS_SubscriptionRerun(t *testing.T) { + client, subscriptionClient := hasura_setupClients(GraphQLWS) + msg := randomID() + + subscriptionClient = subscriptionClient. + OnError(func(sc *SubscriptionClient, err error) error { + return err + }) + + /* + subscription { + user { + id + name + } + } + */ + var sub struct { + Users []struct { + ID int `graphql:"id"` + Name string `graphql:"name"` + } `graphql:"user(order_by: { id: desc }, limit: 5)"` + } + + subId1, err := subscriptionClient.Subscribe(sub, nil, func(data []byte, e error) error { + if e != nil { + t.Fatalf("got error: %v, want: nil", e) + return nil + } + + log.Println("result", string(data)) + e = json.Unmarshal(data, &sub) + if e != nil { + t.Fatalf("got error: %v, want: nil", e) + return nil + } + + if len(sub.Users) > 0 && sub.Users[0].Name != msg { + t.Fatalf("subscription message does not match. got: %s, want: %s", sub.Users[0].Name, msg) + } + + return nil + }) + + if err != nil { + t.Fatalf("got error: %v, want: nil", err) + } + + go func() { + if err := subscriptionClient.Run(); err != nil { + (*t).Fatalf("got error: %v, want: nil", err) + } + }() + + defer subscriptionClient.Close() + + // wait until the subscription client connects to the server + if err := waitHasuraService(60); err != nil { + t.Fatalf("failed to start hasura service: %s", err) + } + + // call a mutation request to send message to the subscription + /* + mutation InsertUser($objects: [user_insert_input!]!) { + insert_user(objects: $objects) { + id + name + } + } + */ + var q struct { + InsertUser struct { + Returning []struct { + ID int `graphql:"id"` + Name string `graphql:"name"` + } `graphql:"returning"` + } `graphql:"insert_user(objects: $objects)"` + } + variables := map[string]interface{}{ + "objects": []user_insert_input{ + { + "name": msg, + }, + }, + } + err = client.Mutate(context.Background(), &q, variables, OperationName("InsertUser")) + + if err != nil { + t.Fatalf("got error: %v, want: nil", err) + } + + time.Sleep(2 * time.Second) + go func() { + time.Sleep(2 * time.Second) + subscriptionClient.Unsubscribe(subId1) + }() + + if err := subscriptionClient.Run(); err != nil { + (*t).Fatalf("got error: %v, want: nil", err) + } +} diff --git a/subscription_test.go b/subscription_test.go index 64e60da..dc00e79 100644 --- a/subscription_test.go +++ b/subscription_test.go @@ -114,7 +114,9 @@ func (r *resolver) broadcastHelloSaid() { case id := <-unsubscribe: delete(subscribers, id) case s := <-r.helloSaidSubscriber: - subscribers[randomID()] = s + id := randomID() + log.Println("new client subscribed: ", id) + subscribers[id] = s case e := <-r.helloSaidEvents: for id, s := range subscribers { go func(id string, s *helloSaidSubscriber) { @@ -401,3 +403,159 @@ func TestSubscriptionLifeCycle2(t *testing.T) { t.Fatalf("got error: %v, want: nil", err) } } + +func TestSubscription_ResetClient(t *testing.T) { + + stop := make(chan bool) + client, subscriptionClient := hasura_setupClients(SubscriptionsTransportWS) + msg := randomID() + + subscriptionClient. + OnError(func(sc *SubscriptionClient, err error) error { + t.Fatalf("got error: %v, want: nil", err) + return err + }). + OnDisconnected(func() { + log.Println("disconnected") + }) + + /* + subscription { + user { + id + name + } + } + */ + var sub struct { + Users []struct { + ID int `graphql:"id"` + Name string `graphql:"name"` + } `graphql:"user(order_by: { id: desc }, limit: 5)"` + } + + subId1, err := subscriptionClient.Subscribe(sub, nil, func(data []byte, e error) error { + if e != nil { + t.Fatalf("got error: %v, want: nil", e) + return nil + } + + log.Println("result", string(data)) + e = json.Unmarshal(data, &sub) + if e != nil { + t.Fatalf("got error: %v, want: nil", e) + return nil + } + + if len(sub.Users) > 0 && sub.Users[0].Name != msg { + t.Fatalf("subscription message does not match. got: %s, want: %s", sub.Users[0].Name, msg) + } + + return nil + }) + + if err != nil { + t.Fatalf("got error: %v, want: nil", err) + } + + defer subscriptionClient.Close() + + // wait until the subscription client connects to the server + if err := waitHasuraService(60); err != nil { + t.Fatalf("failed to start hasura service: %s", err) + } + + /* + subscription { + user { + id + name + } + } + */ + var sub2 struct { + Users []struct { + ID int `graphql:"id"` + } `graphql:"user(order_by: { id: desc }, limit: 5)"` + } + + subId2, err := subscriptionClient.Subscribe(sub2, nil, func(data []byte, e error) error { + if e != nil { + t.Fatalf("got error: %v, want: nil", e) + return nil + } + + log.Println("result", string(data)) + e = json.Unmarshal(data, &sub2) + if e != nil { + t.Fatalf("got error: %v, want: nil", e) + return nil + } + + if len(sub.Users) > 0 && sub.Users[0].Name != msg { + t.Fatalf("subscription message does not match. got: %s, want: %s", sub.Users[0].Name, msg) + } + + return nil + }) + + if err != nil { + t.Fatalf("got error: %v, want: nil", err) + } + + go func() { + + // call a mutation request to send message to the subscription + /* + mutation InsertUser($objects: [user_insert_input!]!) { + insert_user(objects: $objects) { + id + name + } + } + */ + var q struct { + InsertUser struct { + Returning []struct { + ID int `graphql:"id"` + Name string `graphql:"name"` + } `graphql:"returning"` + } `graphql:"insert_user(objects: $objects)"` + } + variables := map[string]interface{}{ + "objects": []user_insert_input{ + { + "name": msg, + }, + }, + } + err = client.Mutate(context.Background(), &q, variables, OperationName("InsertUser")) + + if err != nil { + (*t).Fatalf("got error: %v, want: nil", err) + } + + time.Sleep(2 * time.Second) + // reset the subscription + log.Printf("resetting the subscription client...") + if err := subscriptionClient.Run(); err != nil { + (*t).Fatalf("failed to reset the subscription client. got error: %v, want: nil", err) + } + log.Printf("the second run was stopped") + stop <- true + }() + + go func() { + time.Sleep(8 * time.Second) + subscriptionClient.Unsubscribe(subId1) + subscriptionClient.Unsubscribe(subId2) + }() + + defer subscriptionClient.Close() + + if err := subscriptionClient.Run(); err != nil { + t.Fatalf("got error: %v, want: nil", err) + } + + <-stop +} diff --git a/subscriptions_transport_ws.go b/subscriptions_transport_ws.go index 7b5a0d6..a681cf2 100644 --- a/subscriptions_transport_ws.go +++ b/subscriptions_transport_ws.go @@ -70,7 +70,7 @@ func (stw *subscriptionsTransportWS) ConnectionInit(ctx *SubscriptionContext, co } // Subscribe requests an graphql operation specified in the payload message -func (stw *subscriptionsTransportWS) Subscribe(ctx *SubscriptionContext, id string, sub *Subscription) error { +func (stw *subscriptionsTransportWS) Subscribe(ctx *SubscriptionContext, id string, sub Subscription) error { if sub.GetStarted() { return nil } @@ -90,13 +90,15 @@ func (stw *subscriptionsTransportWS) Subscribe(ctx *SubscriptionContext, id stri } sub.SetStarted(true) + ctx.SetSubscription(id, &sub) + return nil } // Unsubscribe sends stop message to server and close subscription channel // The input parameter is subscription ID that is returned from Subscribe function func (stw *subscriptionsTransportWS) Unsubscribe(ctx *SubscriptionContext, id string) error { - if ctx == nil || ctx.WebsocketConn == nil { + if ctx == nil || ctx.GetWebsocketConn() == nil { return nil } sub := ctx.GetSubscription(id) @@ -114,30 +116,24 @@ func (stw *subscriptionsTransportWS) Unsubscribe(ctx *SubscriptionContext, id st } err := ctx.Send(msg, GQLStop) - if err != nil { - return err - } // close the client if there is no running subscription - if len(ctx.GetSubscriptions()) == 0 { + if ctx.GetSubscriptionsLength() == 0 { ctx.Log("no running subscription. exiting...", "client", GQLInternal) return ctx.Close() } - return nil + return err } // OnMessage listens ongoing messages from server -func (stw *subscriptionsTransportWS) OnMessage(ctx *SubscriptionContext, subscription *Subscription, message OperationMessage) { +func (stw *subscriptionsTransportWS) OnMessage(ctx *SubscriptionContext, subscription Subscription, message OperationMessage) { switch message.Type { case GQLError: ctx.Log(message, "server", GQLError) case GQLData: ctx.Log(message, "server", GQLData) - if subscription == nil { - return - } var out struct { Data *json.RawMessage Errors Errors @@ -163,7 +159,6 @@ func (stw *subscriptionsTransportWS) OnMessage(ctx *SubscriptionContext, subscri ctx.Log(message, "server", GQLConnectionError) _ = stw.Close(ctx) _ = ctx.Close() - ctx.cancel() case GQLComplete: ctx.Log(message, "server", GQLComplete) _ = stw.Unsubscribe(ctx, message.ID) @@ -177,7 +172,7 @@ func (stw *subscriptionsTransportWS) OnMessage(ctx *SubscriptionContext, subscri subscriptions := ctx.GetSubscriptions() for id, sub := range subscriptions { if err := stw.Subscribe(ctx, id, sub); err != nil { - stw.Unsubscribe(ctx, id) + _ = stw.Unsubscribe(ctx, id) return } } @@ -196,9 +191,5 @@ func (stw *subscriptionsTransportWS) Close(ctx *SubscriptionContext) error { Type: GQLConnectionTerminate, } - if ctx.WebsocketConn != nil { - return ctx.Send(msg, GQLConnectionTerminate) - } - - return nil + return ctx.Send(msg, GQLConnectionTerminate) }