diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8b929d3..5ff609f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -38,7 +38,7 @@ jobs: cd ./example/hasura docker-compose up -d - name: Run Go unit tests - run: go test -v -race -coverprofile=coverage.out ./... + run: go test -v -race -timeout 3m -coverprofile=coverage.out ./... - name: Go coverage format run: | go get github.com/boumenot/gocover-cobertura diff --git a/README.md b/README.md index 66280e2..c67bf58 100644 --- a/README.md +++ b/README.md @@ -585,7 +585,12 @@ client. // max size of response message WithReadLimit(10*1024*1024). // these operation event logs won't be printed - WithoutLogTypes(graphql.GQLData, graphql.GQLConnectionKeepAlive) + WithoutLogTypes(graphql.GQLData, graphql.GQLConnectionKeepAlive). + // the client should exit when all subscriptions were closed, default true + WithExitWhenNoSubscription(false). + // WithRetryStatusCodes allow retry the subscription connection when receiving one of these codes + // the input parameter can be number string or range, e.g 4000-5000 + WithRetryStatusCodes("4000", "4000-4050") ``` #### Subscription Protocols @@ -629,6 +634,12 @@ client.OnDisconnected(fn func()) // If this function is empty, or returns nil, the error is ignored // If returns error, the websocket connection will be terminated client.OnError(onError func(sc *SubscriptionClient, err error) error) + +// OnConnectionAlive event is triggered when the websocket receive a connection alive message (differs per protocol) +client.OnConnectionAlive(fn func()) + +// OnSubscriptionComplete event is triggered when the subscription receives a terminated message from the server +client.OnSubscriptionComplete(fn func(sub Subscription)) ``` #### Custom HTTP Client diff --git a/query.go b/query.go index 37073ec..65748b0 100644 --- a/query.go +++ b/query.go @@ -88,22 +88,22 @@ func ConstructMutation(v interface{}, variables map[string]interface{}, options } // ConstructSubscription build GraphQL subscription string from struct and variables -func ConstructSubscription(v interface{}, variables map[string]interface{}, options ...Option) (string, error) { +func ConstructSubscription(v interface{}, variables map[string]interface{}, options ...Option) (string, string, error) { query, err := query(v) if err != nil { - return "", err + return "", "", err } optionsOutput, err := constructOptions(options) if err != nil { - return "", err + return "", "", err } if len(variables) > 0 { - return fmt.Sprintf("subscription %s(%s)%s%s", optionsOutput.operationName, queryArguments(variables), optionsOutput.OperationDirectivesString(), query), nil + return fmt.Sprintf("subscription %s(%s)%s%s", optionsOutput.operationName, queryArguments(variables), optionsOutput.OperationDirectivesString(), query), optionsOutput.operationName, nil } if optionsOutput.operationName == "" && len(optionsOutput.operationDirectives) == 0 { - return "subscription" + query, nil + return "subscription" + query, optionsOutput.operationName, nil } - return fmt.Sprintf("subscription %s%s%s", optionsOutput.operationName, optionsOutput.OperationDirectivesString(), query), nil + return fmt.Sprintf("subscription %s%s%s", optionsOutput.operationName, optionsOutput.OperationDirectivesString(), query), optionsOutput.operationName, nil } // queryArguments constructs a minified arguments string for variables. diff --git a/query_test.go b/query_test.go index fb8e7fd..e73de3a 100644 --- a/query_test.go +++ b/query_test.go @@ -635,11 +635,16 @@ func TestConstructSubscription(t *testing.T) { }, } for _, tc := range tests { - got, err := ConstructSubscription(tc.inV, tc.inVariables, OperationName(tc.name)) + got, gotName, err := ConstructSubscription(tc.inV, tc.inVariables, OperationName(tc.name)) if err != nil { t.Error(err) - } else if got != tc.want { - t.Errorf("\ngot: %q\nwant: %q\n", got, tc.want) + } else { + if got != tc.want { + t.Errorf("\ngot: %q\nwant: %q\n", got, tc.want) + } + if gotName != tc.name { + t.Errorf("\ninvalid operation name \ngot: %q\nwant: %q\n", gotName, tc.name) + } } } } diff --git a/subscription.go b/subscription.go index cf73de9..4c07052 100644 --- a/subscription.go +++ b/subscription.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "strconv" "strings" "sync" "sync/atomic" @@ -18,7 +19,10 @@ import ( ) // SubscriptionProtocolType represents the protocol specification enum of the subscription -type SubscriptionProtocolType String +type SubscriptionProtocolType string + +// internal subscription status +type SubscriptionStatus int32 const ( // internal state machine status @@ -26,8 +30,20 @@ const ( scStatusRunning int32 = 1 scStatusClosing int32 = 2 + // SubscriptionWaiting the subscription hasn't been registered to the server + SubscriptionWaiting SubscriptionStatus = 0 + // SubscriptionRunning the subscription is up and running + SubscriptionRunning SubscriptionStatus = 1 + // SubscriptionUnsubcribed the subscription was manually unsubscribed by the user + SubscriptionUnsubcribed SubscriptionStatus = 2 + + // SubscriptionsTransportWS the enum implements the subscription transport that follows Apollo's subscriptions-transport-ws protocol specification + // https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md SubscriptionsTransportWS SubscriptionProtocolType = "subscriptions-transport-ws" - GraphQLWS SubscriptionProtocolType = "graphql-ws" + + // GraphQLWS enum implements GraphQL over WebSocket Protocol (graphql-ws) + // https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md + GraphQLWS SubscriptionProtocolType = "graphql-ws" // Receiving a message of a type or format which is not specified in this document // The can be vaguely descriptive on why the received message is invalid. @@ -57,8 +73,12 @@ const ( GQL_INTERNAL = GQLInternal ) -// ErrSubscriptionStopped a special error which forces the subscription stop -var ErrSubscriptionStopped = errors.New("subscription stopped") +var ( + // ErrSubscriptionStopped a special error which forces the subscription stop + ErrSubscriptionStopped = errors.New("subscription stopped") + + errRetry = errors.New("retry subscription client") +) // OperationMessage represents a subscription operation message type OperationMessage struct { @@ -85,6 +105,10 @@ type WebsocketConn interface { // message exceeds the limit, the connection sends a close message to the peer // and returns ErrReadLimit to the application. SetReadLimit(limit int64) + // GetCloseStatus tries to get WebSocket close status from error + // return -1 if the error is unknown + // https://www.iana.org/assignments/websocket/websocket.xhtml + GetCloseStatus(error) int32 } // SubscriptionProtocol abstracts the life-cycle of subscription protocol implementation for a specific transport protocol @@ -95,11 +119,11 @@ type SubscriptionProtocol interface { // ConnectionInit sends a initial request to establish a connection within the existing socket ConnectionInit(ctx *SubscriptionContext, connectionParams map[string]interface{}) error // Subscribe requests an graphql operation specified in the payload message - Subscribe(ctx *SubscriptionContext, id string, sub Subscription) error + Subscribe(ctx *SubscriptionContext, sub Subscription) error // Unsubscribe sends a request to stop listening and complete the subscription - Unsubscribe(ctx *SubscriptionContext, id string) error + Unsubscribe(ctx *SubscriptionContext, sub Subscription) error // OnMessage listens ongoing messages from server - OnMessage(ctx *SubscriptionContext, subscription Subscription, message OperationMessage) + OnMessage(ctx *SubscriptionContext, subscription Subscription, message OperationMessage) error // Close terminates all subscriptions of the current websocket Close(ctx *SubscriptionContext) error } @@ -107,17 +131,20 @@ type SubscriptionProtocol interface { // SubscriptionContext represents a shared context for protocol implementations with the websocket connection inside type SubscriptionContext struct { context.Context - websocketConn WebsocketConn - OnConnected func() - onDisconnected func() - onConnectionAlive func() - cancel context.CancelFunc - subscriptions map[string]Subscription - disabledLogTypes []OperationMessageType - log func(args ...interface{}) - acknowledged int32 - exitStatusCodes []int - mutex sync.Mutex + websocketConn WebsocketConn + + OnConnected func() + OnDisconnected func() + OnConnectionAlive func() + OnSubscriptionComplete func(sub Subscription) + + cancel context.CancelFunc + subscriptions map[string]Subscription + disabledLogTypes []OperationMessageType + log func(args ...interface{}) + acknowledged int32 + retryStatusCodes [][]int32 + mutex sync.Mutex } // Log prints condition logging with message type filters @@ -182,17 +209,35 @@ func (sc *SubscriptionContext) GetSubscription(id string) *Subscription { return nil } sub, found := sc.subscriptions[id] - if !found { - return nil + if found { + return &sub } - return &sub + + for _, s := range sc.subscriptions { + if id == s.id { + return &s + } + } + return nil } -// GetSubscriptionsLength returns the length of subscriptions -func (sc *SubscriptionContext) GetSubscriptionsLength() int { +// GetSubscriptionsLength returns the length of subscriptions by status +func (sc *SubscriptionContext) GetSubscriptionsLength(status []SubscriptionStatus) int { sc.mutex.Lock() defer sc.mutex.Unlock() - return len(sc.subscriptions) + if len(status) == 0 { + return len(sc.subscriptions) + } + count := 0 + for _, sub := range sc.subscriptions { + for _, s := range status { + if sub.status == s { + count++ + break + } + } + } + return count } // GetSubscription get all available subscriptions in the context @@ -208,13 +253,13 @@ func (sc *SubscriptionContext) GetSubscriptions() map[string]Subscription { // SetSubscription set the input subscription state into the context // if subscription is nil, removes the subscription from the map -func (sc *SubscriptionContext) SetSubscription(id string, sub *Subscription) { +func (sc *SubscriptionContext) SetSubscription(key string, sub *Subscription) { sc.mutex.Lock() defer sc.mutex.Unlock() if sub == nil { - delete(sc.subscriptions, id) + delete(sc.subscriptions, key) } else { - sc.subscriptions[id] = *sub + sc.subscriptions[key] = *sub } } @@ -236,9 +281,13 @@ func (sc *SubscriptionContext) SetAcknowledge(value bool) { func (sc *SubscriptionContext) Close() error { var err error if conn := sc.GetWebsocketConn(); conn != nil { - err = conn.Close() sc.SetWebsocketConn(nil) + if sc.OnDisconnected != nil { + sc.OnDisconnected() + } + err = conn.Close() } + sc.Cancel() return err @@ -257,24 +306,28 @@ type handlerFunc func(data []byte, err error) error // Subscription stores the subscription declaration and its state type Subscription struct { + id string + key string payload GraphQLRequestPayload handler func(data []byte, err error) - started bool + status SubscriptionStatus } -// GetPayload returns the graphql request payload -func (s Subscription) GetPayload() GraphQLRequestPayload { - return s.payload +// GetID returns the subscription ID +func (s Subscription) GetID() string { + return s.id } -// GetStarted a public getter for the started status -func (s Subscription) GetStarted() bool { - return s.started +// GetKey returns the unique key of the subscription map +// Key is the immutable id of the subscription that is generated the first time +// It is used for searching because the subscription id is refreshed whenever the client reset +func (s Subscription) GetKey() string { + return s.key } -// SetStarted a public getter for the started status -func (s *Subscription) SetStarted(value bool) { - s.started = value +// GetPayload returns the graphql request payload +func (s Subscription) GetPayload() GraphQLRequestPayload { + return s.payload } // GetHandler a public getter for the subscription handler @@ -282,33 +335,46 @@ func (s Subscription) GetHandler() func(data []byte, err error) { return s.handler } +// GetStatus a public getter for the subscription status +func (s Subscription) GetStatus() SubscriptionStatus { + return s.status +} + +// SetStatus a public getter for the subscription status +func (s *Subscription) SetStatus(status SubscriptionStatus) { + s.status = status +} + // SubscriptionClient is a GraphQL subscription client. type SubscriptionClient struct { - url string - context *SubscriptionContext - connectionParams map[string]interface{} - connectionParamsFn func() map[string]interface{} - protocol SubscriptionProtocol - websocketOptions WebsocketOptions - timeout time.Duration - clientStatus int32 - readLimit int64 // max size of response message. Default 10 MB - createConn func(sc *SubscriptionClient) (WebsocketConn, error) - retryTimeout time.Duration - onError func(sc *SubscriptionClient, err error) error - errorChan chan error + url string + context *SubscriptionContext + connectionParams map[string]interface{} + connectionParamsFn func() map[string]interface{} + protocol SubscriptionProtocol + websocketOptions WebsocketOptions + timeout time.Duration + clientStatus int32 + readLimit int64 // max size of response message. Default 10 MB + createConn func(sc *SubscriptionClient) (WebsocketConn, error) + retryTimeout time.Duration + onError func(sc *SubscriptionClient, err error) error + errorChan chan error + exitWhenNoSubscription bool + mutex sync.Mutex } // NewSubscriptionClient constructs new subscription client func NewSubscriptionClient(url string) *SubscriptionClient { return &SubscriptionClient{ - url: url, - timeout: time.Minute, - readLimit: 10 * 1024 * 1024, // set default limit 10MB - createConn: newWebsocketConn, - retryTimeout: time.Minute, - errorChan: make(chan error), - protocol: &subscriptionsTransportWS{}, + url: url, + timeout: time.Minute, + readLimit: 10 * 1024 * 1024, // set default limit 10MB + createConn: newWebsocketConn, + retryTimeout: time.Minute, + errorChan: make(chan error), + protocol: &subscriptionsTransportWS{}, + exitWhenNoSubscription: true, context: &SubscriptionContext{ subscriptions: make(map[string]Subscription), }, @@ -327,7 +393,7 @@ func (sc *SubscriptionClient) GetTimeout() time.Duration { // GetContext returns current context of subscription client func (sc *SubscriptionClient) GetContext() context.Context { - return sc.context.GetContext() + return sc.getContext().GetContext() } // WithWebSocket replaces customized websocket client constructor @@ -372,7 +438,7 @@ func (sc *SubscriptionClient) WithConnectionParamsFn(fn func() map[string]interf return sc } -// WithTimeout updates write timeout of websocket client +// WithTimeout updates read and write timeout of websocket client func (sc *SubscriptionClient) WithTimeout(timeout time.Duration) *SubscriptionClient { sc.timeout = timeout return sc @@ -385,6 +451,12 @@ func (sc *SubscriptionClient) WithRetryTimeout(timeout time.Duration) *Subscript return sc } +// WithExitWhenNoSubscription the client should exit when all subscriptions were closed +func (sc *SubscriptionClient) WithExitWhenNoSubscription(value bool) *SubscriptionClient { + sc.exitWhenNoSubscription = value + return sc +} + // WithLog sets logging function to print out received messages. By default, nothing is printed func (sc *SubscriptionClient) WithLog(logger func(args ...interface{})) *SubscriptionClient { sc.context.log = logger @@ -403,8 +475,20 @@ func (sc *SubscriptionClient) WithReadLimit(limit int64) *SubscriptionClient { return sc } +// WithRetryStatusCodes allow retry the subscription connection when receiving one of these codes +// the input parameter can be number string or range, e.g 4000-5000 +func (sc *SubscriptionClient) WithRetryStatusCodes(codes ...string) *SubscriptionClient { + + statusCodes, err := parseInt32Ranges(codes) + if err != nil { + panic(err) + } + sc.context.retryStatusCodes = statusCodes + return sc +} + // OnError event is triggered when there is any connection error. This is bottom exception handler level -// If this function is empty, or returns nil, the error is ignored +// If this function is empty, or returns nil, the client restarts the connection // If returns error, the websocket connection will be terminated func (sc *SubscriptionClient) OnError(onError func(sc *SubscriptionClient, err error) error) *SubscriptionClient { sc.onError = onError @@ -419,16 +503,34 @@ func (sc *SubscriptionClient) OnConnected(fn func()) *SubscriptionClient { // OnDisconnected event is triggered when the websocket client was disconnected func (sc *SubscriptionClient) OnDisconnected(fn func()) *SubscriptionClient { - sc.context.onDisconnected = fn + sc.context.OnDisconnected = fn return sc } // OnConnectionAlive event is triggered when the websocket receive a connection alive message (differs per protocol) func (sc *SubscriptionClient) OnConnectionAlive(fn func()) *SubscriptionClient { - sc.context.onConnectionAlive = fn + sc.context.OnConnectionAlive = fn + return sc +} + +// OnSubscriptionComplete event is triggered when the subscription receives a terminated message from the server +func (sc *SubscriptionClient) OnSubscriptionComplete(fn func(sub Subscription)) *SubscriptionClient { + sc.context.OnSubscriptionComplete = fn return sc } +func (sc *SubscriptionClient) getContext() *SubscriptionContext { + sc.mutex.Lock() + defer sc.mutex.Unlock() + return sc.context +} + +func (sc *SubscriptionClient) setContext(value *SubscriptionContext) { + sc.mutex.Lock() + defer sc.mutex.Unlock() + sc.context = value +} + // get internal client status func (sc *SubscriptionClient) getClientStatus() int32 { return atomic.LoadInt32(&sc.clientStatus) @@ -443,26 +545,27 @@ func (sc *SubscriptionClient) setClientStatus(value int32) { func (sc *SubscriptionClient) init() error { now := time.Now() + ctx := sc.getContext() for { var err error var conn WebsocketConn // allow custom websocket client - if sc.context.GetWebsocketConn() == nil { - sc.context.NewContext() + if ctx.GetWebsocketConn() == nil { + ctx.NewContext() conn, err = sc.createConn(sc) if err == nil { - sc.context.SetWebsocketConn(conn) + ctx.SetWebsocketConn(conn) } } if err == nil { - sc.context.GetWebsocketConn().SetReadLimit(sc.readLimit) + ctx.GetWebsocketConn().SetReadLimit(sc.readLimit) // send connection init event to the server connectionParams := sc.connectionParams if sc.connectionParamsFn != nil { connectionParams = sc.connectionParamsFn() } - err = sc.protocol.ConnectionInit(sc.context, connectionParams) + err = sc.protocol.ConnectionInit(ctx, connectionParams) } if err == nil { @@ -470,12 +573,12 @@ func (sc *SubscriptionClient) init() error { } if sc.retryTimeout > 0 && now.Add(sc.retryTimeout).Before(time.Now()) { - if sc.context.onDisconnected != nil { - sc.context.onDisconnected() + if ctx.OnDisconnected != nil { + ctx.OnDisconnected() } return err } - sc.context.Log(fmt.Sprintf("%s. retry in second...", err.Error()), "client", GQLInternal) + ctx.Log(fmt.Sprintf("%s. retry in second...", err.Error()), "client", GQLInternal) time.Sleep(time.Second) } } @@ -497,44 +600,48 @@ func (sc *SubscriptionClient) NamedSubscribe(name string, v interface{}, variabl // SubscribeRaw sends start message to server and open a channel to receive data, with raw query // Deprecated: use Exec instead func (sc *SubscriptionClient) SubscribeRaw(query string, variables map[string]interface{}, handler func(message []byte, err error) error) (string, error) { - return sc.doRaw(query, variables, handler) + return sc.doRaw(query, variables, "", handler) } // Exec sends start message to server and open a channel to receive data, with raw query func (sc *SubscriptionClient) Exec(query string, variables map[string]interface{}, handler func(message []byte, err error) error) (string, error) { - return sc.doRaw(query, variables, handler) + return sc.doRaw(query, variables, "", handler) } func (sc *SubscriptionClient) do(v interface{}, variables map[string]interface{}, handler func(message []byte, err error) error, options ...Option) (string, error) { - query, err := ConstructSubscription(v, variables, options...) + query, operationName, err := ConstructSubscription(v, variables, options...) if err != nil { return "", err } - return sc.doRaw(query, variables, handler) + return sc.doRaw(query, variables, operationName, handler) } -func (sc *SubscriptionClient) doRaw(query string, variables map[string]interface{}, handler func(message []byte, err error) error) (string, error) { +func (sc *SubscriptionClient) doRaw(query string, variables map[string]interface{}, operationName string, handler func(message []byte, err error) error) (string, error) { id := uuid.New().String() sub := Subscription{ + id: id, + key: id, payload: GraphQLRequestPayload{ - Query: query, - Variables: variables, + Query: query, + Variables: variables, + OperationName: operationName, }, handler: sc.wrapHandler(handler), } // if the websocket client is running and acknowledged by the server // start subscription immediately - if sc.context != nil && sc.context.GetAcknowledge() { - if err := sc.protocol.Subscribe(sc.context, id, sub); err != nil { + ctx := sc.getContext() + if ctx != nil && sc.getClientStatus() == scStatusRunning && ctx.GetAcknowledge() { + if err := sc.protocol.Subscribe(ctx, sub); err != nil { return "", err } + } else { + ctx.SetSubscription(id, &sub) } - sc.context.SetSubscription(id, &sub) - return id, nil } @@ -549,7 +656,29 @@ func (sc *SubscriptionClient) wrapHandler(fn handlerFunc) func(data []byte, err // Unsubscribe sends stop message to server and close subscription channel // The input parameter is subscription ID that is returned from Subscribe function func (sc *SubscriptionClient) Unsubscribe(id string) error { - return sc.protocol.Unsubscribe(sc.context, id) + ctx := sc.getContext() + if ctx == nil || ctx.GetWebsocketConn() == nil { + return nil + } + sub := ctx.GetSubscription(id) + + if sub == nil { + return fmt.Errorf("subscription id %s doesn't not exist", id) + } + + if sub.status == SubscriptionUnsubcribed { + return nil + } + var err error + if sub.status == SubscriptionRunning { + err = sc.protocol.Unsubscribe(ctx, *sub) + } + sub.status = SubscriptionUnsubcribed + ctx.SetSubscription(sub.key, sub) + + sc.checkSubscriptionStatuses(ctx) + + return err } // Run start the WebSocket client and subscriptions. @@ -557,23 +686,29 @@ func (sc *SubscriptionClient) Unsubscribe(id string) error { // If this function is run with goroutine, it can be stopped after closed func (sc *SubscriptionClient) Run() error { - sc.reset() + if sc.getClientStatus() != scStatusInitializing { + sc.reset() + } + if err := sc.init(); err != nil { return fmt.Errorf("retry timeout. exiting...") } - sc.setClientStatus(scStatusRunning) - if sc.context == nil { + subContext := sc.getContext() + if subContext == nil { return fmt.Errorf("the subscription context is nil") } - conn := sc.context.GetWebsocketConn() + + conn := subContext.GetWebsocketConn() if conn == nil { return fmt.Errorf("the websocket connection hasn't been created") } - ctx := sc.context.GetContext() + + sc.setClientStatus(scStatusRunning) + ctx := subContext.GetContext() go func() { - for sc.getClientStatus() == scStatusRunning { + for { select { case <-ctx.Done(): return @@ -581,128 +716,168 @@ func (sc *SubscriptionClient) Run() error { var message OperationMessage if err := conn.ReadJSON(&message); err != nil { // manual EOF check - if err == io.EOF || strings.Contains(err.Error(), "EOF") { - sc.setClientStatus(scStatusInitializing) - sc.context.Cancel() + if err == io.EOF || strings.Contains(err.Error(), "EOF") || strings.Contains(err.Error(), "connection reset by peer") { + sc.errorChan <- errRetry return } - closeStatus := websocket.CloseStatus(err) - switch closeStatus { - case websocket.StatusNormalClosure, websocket.StatusAbnormalClosure: - // close event from websocket client, exiting... - return - case StatusConnectionInitialisationTimeout, StatusTooManyInitialisationRequests, StatusSubscriberAlreadyExists, StatusUnauthorized: - sc.context.Log(err, "server", GQLError) + if errors.Is(err, context.Canceled) { return } - if closeStatus != -1 && closeStatus < 3000 && closeStatus > 4999 { - sc.context.Log(fmt.Sprintf("%s. Retry connecting...", err), "client", GQLInternal) - if err = sc.Run(); err != nil { - sc.errorChan <- err + closeStatus := conn.GetCloseStatus(err) + + for _, retryCode := range subContext.retryStatusCodes { + if (len(retryCode) == 1 && retryCode[0] == closeStatus) || + (len(retryCode) >= 2 && retryCode[0] <= closeStatus && closeStatus <= retryCode[1]) { + sc.errorChan <- errRetry return } } - if isClosedSubscriptionError(err) { - _ = sc.Close() + switch websocket.StatusCode(closeStatus) { + case websocket.StatusBadGateway, websocket.StatusNoStatusRcvd: + sc.errorChan <- errRetry + return + case websocket.StatusNormalClosure, websocket.StatusAbnormalClosure: + // close event from websocket client, exiting... + subContext.Cancel() + return + case StatusInvalidMessage, StatusConnectionInitialisationTimeout, StatusTooManyInitialisationRequests, StatusSubscriberAlreadyExists, StatusUnauthorized: + subContext.Log(err, "server", GQL_CONNECTION_ERROR) + sc.errorChan <- err return } if sc.onError != nil { if err = sc.onError(sc, err); err != nil { // end the subscription if the callback return error - _ = sc.Close() + subContext.Cancel() return } } continue } - sub := sc.context.GetSubscription(message.ID) + sub := subContext.GetSubscription(message.ID) if sub == nil { sub = &Subscription{} } - go sc.protocol.OnMessage(sc.context, *sub, message) + go func() { + if err := sc.protocol.OnMessage(subContext, *sub, message); err != nil { + sc.errorChan <- err + } + + sc.checkSubscriptionStatuses(subContext) + }() } } }() - for sc.getClientStatus() == scStatusRunning { + for { select { case <-ctx.Done(): - if sc.context.GetSubscriptionsLength() == 0 { - sc.setClientStatus(scStatusClosing) - } - break + return sc.close(subContext) case e := <-sc.errorChan: + if sc.getClientStatus() == scStatusClosing { + return nil + } + // stop the subscription if the error has stop message if e == ErrSubscriptionStopped { - return sc.Close() + return sc.close(subContext) + } + if e == errRetry { + return sc.Run() } if sc.onError != nil { if err := sc.onError(sc, e); err != nil { - _ = sc.Close() + sc.close(subContext) return err + } else { + return sc.Run() } } } } - // if the client is closing, stop retrying - if sc.getClientStatus() == scStatusClosing { - return nil - } - - return sc.Run() } // close the running websocket connection and reset all subscription states func (sc *SubscriptionClient) reset() { - sc.context.SetAcknowledge(false) - isRunning := sc.getClientStatus() == scStatusRunning - - for id, sub := range sc.context.GetSubscriptions() { - if sub.GetStarted() { - sub.SetStarted(false) - if isRunning { - _ = sc.protocol.Unsubscribe(sc.context, id) - } - sc.context.SetSubscription(id, &sub) + subContext := sc.getContext() + // fork a new subscription context to start a new session + // avoid conflicting with the last running session what is shutting down + newContext := &SubscriptionContext{ + OnConnected: subContext.OnConnected, + OnDisconnected: subContext.OnDisconnected, + OnSubscriptionComplete: subContext.OnSubscriptionComplete, + disabledLogTypes: subContext.disabledLogTypes, + log: subContext.log, + retryStatusCodes: subContext.retryStatusCodes, + subscriptions: make(map[string]Subscription), + } + + for key, sub := range subContext.GetSubscriptions() { + // remove subscriptions that are manually unsubscribed by the user + if sub.status == SubscriptionUnsubcribed { + continue + } + if sub.status == SubscriptionRunning { + sc.protocol.Unsubscribe(subContext, sub) } + + // should restart subscriptions with new id + // to avoid subscription id conflict errors from the server + sub.id = uuid.NewString() + sub.status = SubscriptionWaiting + newContext.SetSubscription(key, &sub) } - sc.setClientStatus(scStatusClosing) - _ = sc.protocol.Close(sc.context) - _ = sc.context.Close() + sc.protocol.Close(subContext) + subContext.Close() + + sc.setClientStatus(scStatusInitializing) + sc.setContext(newContext) } // Close closes all subscription channel and websocket as well func (sc *SubscriptionClient) Close() (err error) { + return sc.close(sc.getContext()) +} + +func (sc *SubscriptionClient) close(ctx *SubscriptionContext) (err error) { if sc.getClientStatus() == scStatusClosing { return nil } sc.setClientStatus(scStatusClosing) - if sc.context == nil { + if ctx == nil { return } - unsubscribeErrors := make(map[string]error) - for id := range sc.context.GetSubscriptions() { - if err := sc.protocol.Unsubscribe(sc.context, id); err != nil && !isClosedSubscriptionError(err) { - unsubscribeErrors[id] = err + conn := ctx.GetWebsocketConn() + + for key, sub := range ctx.GetSubscriptions() { + ctx.SetSubscription(key, nil) + if conn == nil { + continue + } + if sub.status == SubscriptionRunning { + if err := sc.protocol.Unsubscribe(ctx, sub); err != nil { + unsubscribeErrors[key] = err + } } } - protocolCloseError := sc.protocol.Close(sc.context) - closeError := sc.context.Close() - if sc.context.onDisconnected != nil { - sc.context.onDisconnected() + var protocolCloseError error + if conn != nil { + protocolCloseError = sc.protocol.Close(ctx) } - if len(unsubscribeErrors) > 0 || protocolCloseError != nil || closeError != nil { + closeError := ctx.Close() + + if len(unsubscribeErrors) > 0 { return Error{ Message: "failed to close the subscription client", Extensions: map[string]interface{}{ @@ -715,6 +890,17 @@ func (sc *SubscriptionClient) Close() (err error) { return nil } +func (sc *SubscriptionClient) checkSubscriptionStatuses(ctx *SubscriptionContext) { + // close the client if there is no running subscription + if sc.exitWhenNoSubscription && ctx.GetSubscriptionsLength([]SubscriptionStatus{ + SubscriptionRunning, + SubscriptionWaiting, + }) == 0 { + ctx.Log("no running subscription. exiting...", "client", GQLInternal) + ctx.Cancel() + } +} + // the reusable function for sending connection init message. // The payload format of both subscriptions-transport-ws and graphql-ws are the same func connectionInit(conn *SubscriptionContext, connectionParams map[string]interface{}) error { @@ -736,20 +922,24 @@ func connectionInit(conn *SubscriptionContext, connectionParams map[string]inter return conn.Send(msg, GQLConnectionInit) } -// accept closed websocket errors due to data races -func isClosedSubscriptionError(err error) bool { - - expectedErrorMessages := []string{ - "context canceled", - "received header with unexpected rsv bits", - } - errMsg := err.Error() - for _, msg := range expectedErrorMessages { - if strings.Contains(errMsg, msg) { - return true +func parseInt32Ranges(codes []string) ([][]int32, error) { + statusCodes := make([][]int32, 0, len(codes)) + for _, c := range codes { + sRange := strings.Split(c, "-") + iRange := make([]int32, len(sRange)) + for j, sCode := range sRange { + i, err := strconv.ParseInt(sCode, 10, 32) + if err != nil { + return nil, fmt.Errorf("invalid status code; input: %s", sCode) + } + iRange[j] = int32(i) + } + if len(iRange) > 0 { + statusCodes = append(statusCodes, iRange) } } - return false + + return statusCodes, nil } // default websocket handler implementation using https://github.com/nhooyr/websocket @@ -779,6 +969,28 @@ func (wh *WebsocketHandler) Close() error { return wh.Conn.Close(websocket.StatusNormalClosure, "close websocket") } +// GetCloseStatus tries to get WebSocket close status from error +// https://www.iana.org/assignments/websocket/websocket.xhtml +func (wh *WebsocketHandler) GetCloseStatus(err error) int32 { + // context timeout error returned from ReadJSON or WriteJSON + // try to ping the server, if failed return abnormal closeure error + if errors.Is(err, context.DeadlineExceeded) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if pingErr := wh.Ping(ctx); pingErr != nil { + return int32(websocket.StatusNoStatusRcvd) + } + return -1 + } + + code := websocket.CloseStatus(err) + if code == -1 && strings.Contains(err.Error(), "received header with unexpected rsv bits") { + return int32(websocket.StatusNormalClosure) + } + + return int32(code) +} + // the default constructor function to create a websocket client // which uses https://github.com/nhooyr/websocket library func newWebsocketConn(sc *SubscriptionClient) (WebsocketConn, error) { diff --git a/subscription_graphql_ws.go b/subscription_graphql_ws.go index 475d9b3..bfa0527 100644 --- a/subscription_graphql_ws.go +++ b/subscription_graphql_ws.go @@ -43,8 +43,8 @@ func (gws *graphqlWS) ConnectionInit(ctx *SubscriptionContext, connectionParams } // Subscribe requests an graphql operation specified in the payload message -func (gws *graphqlWS) Subscribe(ctx *SubscriptionContext, id string, sub Subscription) error { - if sub.GetStarted() { +func (gws *graphqlWS) Subscribe(ctx *SubscriptionContext, sub Subscription) error { + if sub.GetStatus() == SubscriptionRunning { return nil } payload, err := json.Marshal(sub.GetPayload()) @@ -53,7 +53,7 @@ func (gws *graphqlWS) Subscribe(ctx *SubscriptionContext, id string, sub Subscri } // send start message to the server msg := OperationMessage{ - ID: id, + ID: sub.id, Type: GQLSubscribe, Payload: payload, } @@ -62,48 +62,40 @@ func (gws *graphqlWS) Subscribe(ctx *SubscriptionContext, id string, sub Subscri return err } - sub.SetStarted(true) - ctx.SetSubscription(id, &sub) + sub.SetStatus(SubscriptionRunning) + ctx.SetSubscription(sub.GetKey(), &sub) return nil } // Unsubscribe sends stop message to server and close subscription channel // The input parameter is subscription ID that is returned from Subscribe function -func (gws *graphqlWS) Unsubscribe(ctx *SubscriptionContext, id string) error { - if ctx == nil || ctx.GetWebsocketConn() == nil { - return nil - } - sub := ctx.GetSubscription(id) - - if sub == nil { - return fmt.Errorf("subscription id %s doesn't not exist", id) - } - - ctx.SetSubscription(id, nil) - +func (gws *graphqlWS) Unsubscribe(ctx *SubscriptionContext, sub Subscription) error { // send stop message to the server msg := OperationMessage{ - ID: id, + ID: sub.id, Type: GQLComplete, } - err := ctx.Send(msg, GQLComplete) - // close the client if there is no running subscription - if ctx.GetSubscriptionsLength() == 0 { - ctx.Log("no running subscription. exiting...", "client", GQLInternal) - return ctx.Close() - } - - return err + return ctx.Send(msg, GQLComplete) } // OnMessage listens ongoing messages from server -func (gws *graphqlWS) OnMessage(ctx *SubscriptionContext, subscription Subscription, message OperationMessage) { +func (gws *graphqlWS) OnMessage(ctx *SubscriptionContext, subscription Subscription, message OperationMessage) error { switch message.Type { case GQLError: ctx.Log(message, "server", message.Type) + var errs Errors + jsonErr := json.Unmarshal(message.Payload, &errs) + if jsonErr != nil { + subscription.handler(nil, fmt.Errorf("%s", string(message.Payload))) + return nil + } + if len(errs) > 0 { + subscription.handler(nil, errs) + return nil + } case GQLNext: ctx.Log(message, "server", message.Type) var out struct { @@ -111,17 +103,17 @@ func (gws *graphqlWS) OnMessage(ctx *SubscriptionContext, subscription Subscript Errors Errors } if subscription.handler == nil { - return + return nil } err := json.Unmarshal(message.Payload, &out) if err != nil { subscription.handler(nil, err) - return + return nil } if len(out.Errors) > 0 { subscription.handler(nil, out.Errors) - return + return nil } var outData []byte @@ -132,11 +124,23 @@ func (gws *graphqlWS) OnMessage(ctx *SubscriptionContext, subscription Subscript subscription.handler(outData, nil) case GQLComplete: ctx.Log(message, "server", message.Type) - _ = gws.Unsubscribe(ctx, message.ID) + sub := ctx.GetSubscription(message.ID) + if ctx.OnSubscriptionComplete != nil { + if sub == nil { + ctx.OnSubscriptionComplete(Subscription{ + id: message.ID, + }) + } else { + ctx.OnSubscriptionComplete(*sub) + } + } + if sub != nil { + ctx.SetSubscription(sub.GetKey(), nil) + } case GQLPing: ctx.Log(message, "server", GQLPing) - if ctx.onConnectionAlive != nil { - ctx.onConnectionAlive() + if ctx.OnConnectionAlive != nil { + ctx.OnConnectionAlive() } // send pong response message back to the server msg := OperationMessage{ @@ -153,9 +157,9 @@ func (gws *graphqlWS) OnMessage(ctx *SubscriptionContext, subscription Subscript ctx.Log(message, "server", GQLConnectionAck) ctx.SetAcknowledge(true) for id, sub := range ctx.GetSubscriptions() { - if err := gws.Subscribe(ctx, id, sub); err != nil { - gws.Unsubscribe(ctx, id) - return + if err := gws.Subscribe(ctx, sub); err != nil { + ctx.Log(fmt.Sprintf("failed to subscribe: %s; id: %s; query: %s", err, id, sub.payload.Query), "client", GQLInternal) + return nil } } if ctx.OnConnected != nil { @@ -164,6 +168,8 @@ func (gws *graphqlWS) OnMessage(ctx *SubscriptionContext, subscription Subscript default: ctx.Log(message, "server", GQLUnknown) } + + return nil } // Close terminates all subscriptions of the current websocket diff --git a/subscription_graphql_ws_test.go b/subscription_graphql_ws_test.go index 8de2973..051e8e6 100644 --- a/subscription_graphql_ws_test.go +++ b/subscription_graphql_ws_test.go @@ -10,6 +10,8 @@ import ( "net/http" "testing" "time" + + "nhooyr.io/websocket" ) const ( @@ -84,7 +86,12 @@ func TestGraphqlWS_Subscription(t *testing.T) { client, subscriptionClient := hasura_setupClients(GraphQLWS) msg := randomID() + hasKeepAlive := false + subscriptionClient = subscriptionClient. + OnConnectionAlive(func() { + hasKeepAlive = true + }). OnError(func(sc *SubscriptionClient, err error) error { return err }) @@ -173,6 +180,10 @@ func TestGraphqlWS_Subscription(t *testing.T) { } <-stop + + if !hasKeepAlive { + t.Fatalf("expected OnConnectionAlive event, got none") + } } func TestGraphqlWS_SubscriptionRerun(t *testing.T) { @@ -276,3 +287,81 @@ func TestGraphqlWS_SubscriptionRerun(t *testing.T) { (*t).Fatalf("got error: %v, want: nil", err) } } + +func TestGraphQLWS_OnError(t *testing.T) { + stop := make(chan bool) + + subscriptionClient := NewSubscriptionClient(fmt.Sprintf("%s/v1/graphql", hasuraTestHost)). + WithProtocol(GraphQLWS). + WithConnectionParams(map[string]interface{}{ + "headers": map[string]string{ + "x-hasura-admin-secret": "test", + }, + }).WithLog(log.Println) + + msg := randomID() + + subscriptionClient = subscriptionClient. + OnConnected(func() { + log.Println("client connected") + }). + OnError(func(sc *SubscriptionClient, err error) error { + log.Println("OnError: ", err) + return err + }) + + /* + subscription { + user { + id + name + } + } + */ + var sub struct { + Users []struct { + ID int `graphql:"id"` + Name string `graphql:"name"` + } `graphql:"user(order_by: { id: desc }, limit: 5)"` + } + + _, err := subscriptionClient.Subscribe(sub, nil, func(data []byte, e error) error { + if e != nil { + t.Fatalf("got error: %v, want: nil", e) + return nil + } + + log.Println("result", string(data)) + e = json.Unmarshal(data, &sub) + if e != nil { + t.Fatalf("got error: %v, want: nil", e) + return nil + } + + if len(sub.Users) > 0 && sub.Users[0].Name != msg { + t.Fatalf("subscription message does not match. got: %s, want: %s", sub.Users[0].Name, msg) + } + + return nil + }) + + if err != nil { + t.Fatalf("got error: %v, want: nil", err) + } + + go func() { + if err := subscriptionClient.Run(); err == nil || websocket.CloseStatus(err) != 4400 { + (*t).Fatalf("got error: %v, want: 4400", err) + } + stop <- true + }() + + defer subscriptionClient.Close() + + // wait until the subscription client connects to the server + if err := waitHasuraService(60); err != nil { + t.Fatalf("failed to start hasura service: %s", err) + } + + <-stop +} diff --git a/subscription_test.go b/subscription_test.go index dc00e79..18d84c2 100644 --- a/subscription_test.go +++ b/subscription_test.go @@ -4,273 +4,73 @@ import ( "context" "encoding/json" "errors" + "fmt" "log" - "math/rand" - "net/http" + "sync" "testing" "time" - "github.com/graph-gophers/graphql-go" - "github.com/graph-gophers/graphql-go/relay" - "github.com/graph-gophers/graphql-transport-ws/graphqlws" + "nhooyr.io/websocket" ) -const schema = ` -schema { - subscription: Subscription - mutation: Mutation - query: Query -} -type Query { - hello: String! -} -type Subscription { - helloSaid(): HelloSaidEvent! -} -type Mutation { - sayHello(msg: String!): HelloSaidEvent! -} -type HelloSaidEvent { - id: String! - msg: String! -} -` - -func subscription_setupClients() (*Client, *SubscriptionClient) { - endpoint := "http://localhost:8081/graphql" - - client := NewClient(endpoint, &http.Client{Transport: http.DefaultTransport}) - - subscriptionClient := NewSubscriptionClient(endpoint). - WithConnectionParams(map[string]interface{}{ - "headers": map[string]string{ - "foo": "bar", - }, - }).WithLog(log.Println) - - return client, subscriptionClient -} - -func subscription_setupServer() *http.Server { - - // init graphQL schema - s, err := graphql.ParseSchema(schema, newResolver()) - if err != nil { - panic(err) - } - - // graphQL handler - mux := http.NewServeMux() - graphQLHandler := graphqlws.NewHandlerFunc(s, &relay.Handler{Schema: s}) - mux.HandleFunc("/graphql", graphQLHandler) - server := &http.Server{Addr: ":8081", Handler: mux} - - return server -} - -type resolver struct { - helloSaidEvents chan *helloSaidEvent - helloSaidSubscriber chan *helloSaidSubscriber -} - -func newResolver() *resolver { - r := &resolver{ - helloSaidEvents: make(chan *helloSaidEvent), - helloSaidSubscriber: make(chan *helloSaidSubscriber), - } - - go r.broadcastHelloSaid() - - return r -} - -func (r *resolver) Hello() string { - return "Hello world!" -} - -func (r *resolver) SayHello(args struct{ Msg string }) *helloSaidEvent { - e := &helloSaidEvent{msg: args.Msg, id: randomID()} - go func() { - select { - case r.helloSaidEvents <- e: - case <-time.After(1 * time.Second): - } - }() - return e -} - -type helloSaidSubscriber struct { - stop <-chan struct{} - events chan<- *helloSaidEvent -} - -func (r *resolver) broadcastHelloSaid() { - subscribers := map[string]*helloSaidSubscriber{} - unsubscribe := make(chan string) - - // NOTE: subscribing and sending events are at odds. - for { - select { - case id := <-unsubscribe: - delete(subscribers, id) - case s := <-r.helloSaidSubscriber: - id := randomID() - log.Println("new client subscribed: ", id) - subscribers[id] = s - case e := <-r.helloSaidEvents: - for id, s := range subscribers { - go func(id string, s *helloSaidSubscriber) { - select { - case <-s.stop: - unsubscribe <- id - return - default: - } - - select { - case <-s.stop: - unsubscribe <- id - case s.events <- e: - case <-time.After(time.Second): - } - }(id, s) - } - } - } -} - -func (r *resolver) HelloSaid(ctx context.Context) <-chan *helloSaidEvent { - c := make(chan *helloSaidEvent) - // NOTE: this could take a while - r.helloSaidSubscriber <- &helloSaidSubscriber{events: c, stop: ctx.Done()} - - return c -} - -type helloSaidEvent struct { - id string - msg string -} - -func (r *helloSaidEvent) Msg() string { - return r.msg -} - -func (r *helloSaidEvent) ID() string { - return r.id -} - -func randomID() string { - var letter = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") - - b := make([]rune, 16) - for i := range b { - b[i] = letter[rand.Intn(len(letter))] - } - return string(b) -} - -func TestSubscriptionLifeCycle(t *testing.T) { - stop := make(chan bool) - server := subscription_setupServer() - client, subscriptionClient := subscription_setupClients() +func TestSubscription_LifeCycleEvents(t *testing.T) { + server := subscription_setupServer(8082) + client, subscriptionClient := subscription_setupClients(8082) msg := randomID() - go func() { - if err := server.ListenAndServe(); err != nil { - log.Println(err) - } - }() - - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer server.Shutdown(ctx) - defer cancel() - - subscriptionClient. - OnError(func(sc *SubscriptionClient, err error) error { - return err - }) - - /* - subscription { - helloSaid { - id - msg - } - } - */ - var sub struct { - HelloSaid struct { - ID String - Message String `graphql:"msg" json:"msg"` - } `graphql:"helloSaid" json:"helloSaid"` - } - - _, err := subscriptionClient.Subscribe(sub, nil, func(data []byte, e error) error { - if e != nil { - t.Fatalf("got error: %v, want: nil", e) - return nil - } - log.Println("result", string(data)) - e = json.Unmarshal(data, &sub) - if e != nil { - t.Fatalf("got error: %v, want: nil", e) - return nil - } - - if sub.HelloSaid.Message != String(msg) { - t.Fatalf("subscription message does not match. got: %s, want: %s", sub.HelloSaid.Message, msg) - } - - return errors.New("exit") - }) - - if err != nil { - t.Fatalf("got error: %v, want: nil", err) + var lock sync.Mutex + subscriptionResults := []Subscription{} + wasConnected := false + wasDisconnected := false + addResult := func(s Subscription) int { + lock.Lock() + defer lock.Unlock() + subscriptionResults = append(subscriptionResults, s) + return len(subscriptionResults) } - go func() { - if err := subscriptionClient.Run(); err == nil || err.Error() != "exit" { - (*t).Fatalf("got error: %v, want: exit", err) - } - stop <- true - }() - - defer subscriptionClient.Close() + fixtures := []struct { + Query interface{} + Variables map[string]interface{} + Subscription *Subscription + }{ + { + Query: func() interface{} { + var t struct { + HelloSaid struct { + ID String + Message String `graphql:"msg" json:"msg"` + } `graphql:"helloSaid" json:"helloSaid"` + } - // wait until the subscription client connects to the server - time.Sleep(2 * time.Second) + return t + }(), + Variables: nil, + Subscription: &Subscription{ + payload: GraphQLRequestPayload{ + Query: "subscription{helloSaid{id,msg}}", + }, + }, + }, + { + Query: func() interface{} { + var t struct { + HelloSaid struct { + Message String `graphql:"msg" json:"msg"` + } `graphql:"helloSaid" json:"helloSaid"` + } - // call a mutation request to send message to the subscription - /* - mutation ($msg: String!) { - sayHello(msg: $msg) { - id - msg - } - } - */ - var q struct { - SayHello struct { - ID String - Msg String - } `graphql:"sayHello(msg: $msg)"` - } - variables := map[string]interface{}{ - "msg": String(msg), - } - err = client.Mutate(context.Background(), &q, variables, OperationName("SayHello")) - if err != nil { - t.Fatalf("got error: %v, want: nil", err) + return t + }(), + Variables: nil, + Subscription: &Subscription{ + payload: GraphQLRequestPayload{ + Query: "subscription{helloSaid{msg}}", + }, + }, + }, } - <-stop -} - -func TestSubscriptionLifeCycle2(t *testing.T) { - server := subscription_setupServer() - client, subscriptionClient := subscription_setupClients() - msg := randomID() go func() { if err := server.ListenAndServe(); err != nil { log.Println(err) @@ -281,89 +81,58 @@ func TestSubscriptionLifeCycle2(t *testing.T) { defer server.Shutdown(ctx) defer cancel() - subscriptionClient. + subscriptionClient = subscriptionClient. + WithExitWhenNoSubscription(false). + WithTimeout(3 * time.Second). + OnConnected(func() { + lock.Lock() + defer lock.Unlock() + log.Println("connected") + wasConnected = true + }). OnError(func(sc *SubscriptionClient, err error) error { t.Fatalf("got error: %v, want: nil", err) return err }). OnDisconnected(func() { + lock.Lock() + defer lock.Unlock() log.Println("disconnected") - }) - /* - subscription { - helloSaid { - id - msg + wasDisconnected = true + }). + OnSubscriptionComplete(func(s Subscription) { + log.Println("OnSubscriptionComplete: ", s) + length := addResult(s) + if length == len(fixtures) { + log.Println("done, closing...") + subscriptionClient.Close() } - } - */ - var sub struct { - HelloSaid struct { - ID String - Message String `graphql:"msg" json:"msg"` - } `graphql:"helloSaid" json:"helloSaid"` - } - - subId1, err := subscriptionClient.Subscribe(sub, nil, func(data []byte, e error) error { - if e != nil { - t.Fatalf("got error: %v, want: nil", e) - return nil - } - - log.Println("result", string(data)) - e = json.Unmarshal(data, &sub) - if e != nil { - t.Fatalf("got error: %v, want: nil", e) - return nil - } - - if sub.HelloSaid.Message != String(msg) { - t.Fatalf("subscription message does not match. got: %s, want: %s", sub.HelloSaid.Message, msg) - } - - return nil - }) - - if err != nil { - t.Fatalf("got error: %v, want: nil", err) - } + }) - /* - subscription { - helloSaid { - id - msg + for _, f := range fixtures { + id, err := subscriptionClient.Subscribe(f.Query, f.Variables, func(data []byte, e error) error { + lock.Lock() + defer lock.Unlock() + if e != nil { + t.Fatalf("got error: %v, want: nil", e) + return nil } - } - */ - var sub2 struct { - HelloSaid struct { - Message String `graphql:"msg" json:"msg"` - } `graphql:"helloSaid" json:"helloSaid"` - } - _, err = subscriptionClient.Subscribe(sub2, nil, func(data []byte, e error) error { - if e != nil { - t.Fatalf("got error: %v, want: nil", e) - return nil - } + log.Println("result", string(data)) + e = json.Unmarshal(data, &f.Query) + if e != nil { + t.Fatalf("got error: %v, want: nil", e) + return nil + } - log.Println("result", string(data)) - e = json.Unmarshal(data, &sub2) - if e != nil { - t.Fatalf("got error: %v, want: nil", e) return nil - } + }) - if sub2.HelloSaid.Message != String(msg) { - t.Fatalf("subscription message does not match. got: %s, want: %s", sub2.HelloSaid.Message, msg) + if err != nil { + t.Fatalf("got error: %v, want: nil", err) } - - return ErrSubscriptionStopped - }) - - if err != nil { - t.Fatalf("got error: %v, want: nil", err) + f.Subscription.id = id + log.Printf("subscribed: %s; subscriptions %+v", id, subscriptionClient.context.subscriptions) } go func() { @@ -388,13 +157,21 @@ func TestSubscriptionLifeCycle2(t *testing.T) { variables := map[string]interface{}{ "msg": String(msg), } - err = client.Mutate(context.Background(), &q, variables, OperationName("SayHello")) + err := client.Mutate(context.Background(), &q, variables, OperationName("SayHello")) if err != nil { (*t).Fatalf("got error: %v, want: nil", err) } - time.Sleep(time.Second) - subscriptionClient.Unsubscribe(subId1) + time.Sleep(2 * time.Second) + for _, f := range fixtures { + log.Println("unsubscribing ", f.Subscription.id) + if err := subscriptionClient.Unsubscribe(f.Subscription.id); err != nil { + log.Printf("subscriptions: %+v", subscriptionClient.context.subscriptions) + panic(err) + + } + time.Sleep(time.Second) + } }() defer subscriptionClient.Close() @@ -402,21 +179,48 @@ func TestSubscriptionLifeCycle2(t *testing.T) { if err := subscriptionClient.Run(); err != nil { t.Fatalf("got error: %v, want: nil", err) } -} -func TestSubscription_ResetClient(t *testing.T) { + if len(subscriptionResults) != len(fixtures) { + t.Fatalf("failed to listen OnSubscriptionComplete event. got %+v, want: %+v", len(subscriptionResults), len(fixtures)) + } + for i, s := range subscriptionResults { + if s.id != fixtures[i].Subscription.id { + t.Fatalf("%d: subscription id not matched, got: %s, want: %s", i, s.GetPayload().Query, fixtures[i].Subscription.payload.Query) + } + if s.GetPayload().Query != fixtures[i].Subscription.payload.Query { + t.Fatalf("%d: query output not matched, got: %s, want: %s", i, s.GetPayload().Query, fixtures[i].Subscription.payload.Query) + } + } + + if !wasConnected { + t.Fatalf("expected OnConnected event, got none") + } + if !wasDisconnected { + t.Fatalf("expected OnDisonnected event, got none") + } +} +func TestSubscription_WithRetryStatusCodes(t *testing.T) { stop := make(chan bool) - client, subscriptionClient := hasura_setupClients(SubscriptionsTransportWS) msg := randomID() - - subscriptionClient. + disconnectedCount := 0 + subscriptionClient := NewSubscriptionClient(fmt.Sprintf("%s/v1/graphql", hasuraTestHost)). + WithProtocol(GraphQLWS). + WithRetryStatusCodes("4400"). + WithConnectionParams(map[string]interface{}{ + "headers": map[string]string{ + "x-hasura-admin-secret": "test", + }, + }).WithLog(log.Println). + OnDisconnected(func() { + disconnectedCount++ + if disconnectedCount > 5 { + stop <- true + } + }). OnError(func(sc *SubscriptionClient, err error) error { - t.Fatalf("got error: %v, want: nil", err) + t.Fatal("should not receive error") return err - }). - OnDisconnected(func() { - log.Println("disconnected") }) /* @@ -434,7 +238,7 @@ func TestSubscription_ResetClient(t *testing.T) { } `graphql:"user(order_by: { id: desc }, limit: 5)"` } - subId1, err := subscriptionClient.Subscribe(sub, nil, func(data []byte, e error) error { + _, err := subscriptionClient.Subscribe(sub, nil, func(data []byte, e error) error { if e != nil { t.Fatalf("got error: %v, want: nil", e) return nil @@ -458,6 +262,12 @@ func TestSubscription_ResetClient(t *testing.T) { t.Fatalf("got error: %v, want: nil", err) } + go func() { + if err := subscriptionClient.Run(); err != nil && websocket.CloseStatus(err) == 4400 { + (*t).Fatalf("should not get error 4400, got error: %v, want: nil", err) + } + }() + defer subscriptionClient.Close() // wait until the subscription client connects to the server @@ -465,97 +275,144 @@ func TestSubscription_ResetClient(t *testing.T) { t.Fatalf("failed to start hasura service: %s", err) } - /* - subscription { - user { - id - name - } - } - */ - var sub2 struct { - Users []struct { - ID int `graphql:"id"` - } `graphql:"user(order_by: { id: desc }, limit: 5)"` + <-stop +} + +func TestSubscription_parseInt32Ranges(t *testing.T) { + fixtures := []struct { + Input []string + Expected [][]int32 + Error error + }{ + { + Input: []string{"1", "2", "3-5"}, + Expected: [][]int32{{1}, {2}, {3, 5}}, + }, + { + Input: []string{"a", "2", "3-5"}, + Error: errors.New("invalid status code; input: a"), + }, } - subId2, err := subscriptionClient.Subscribe(sub2, nil, func(data []byte, e error) error { - if e != nil { - t.Fatalf("got error: %v, want: nil", e) - return nil + for i, f := range fixtures { + output, err := parseInt32Ranges(f.Input) + if f.Expected != nil && fmt.Sprintf("%v", output) != fmt.Sprintf("%v", f.Expected) { + t.Fatalf("%d: got: %+v, want: %+v", i, output, f.Expected) } - - log.Println("result", string(data)) - e = json.Unmarshal(data, &sub2) - if e != nil { - t.Fatalf("got error: %v, want: nil", e) - return nil + if f.Error != nil && f.Error.Error() != err.Error() { + t.Fatalf("%d: error should equal, got: %+v, want: %+v", i, err, f.Error) } + } +} - if len(sub.Users) > 0 && sub.Users[0].Name != msg { - t.Fatalf("subscription message does not match. got: %s, want: %s", sub.Users[0].Name, msg) - } +func TestSubscription_closeThenRun(t *testing.T) { + _, subscriptionClient := hasura_setupClients(GraphQLWS) + + fixtures := []struct { + Query interface{} + Variables map[string]interface{} + Subscription *Subscription + }{ + { + Query: func() interface{} { + var t struct { + Users []struct { + ID int `graphql:"id"` + Name string `graphql:"name"` + } `graphql:"user(order_by: { id: desc }, limit: 5)"` + } - return nil - }) + return t + }(), + Variables: nil, + Subscription: &Subscription{ + payload: GraphQLRequestPayload{ + Query: "subscription{helloSaid{id,msg}}", + }, + }, + }, + { + Query: func() interface{} { + var t struct { + Users []struct { + ID int `graphql:"id"` + } `graphql:"user(order_by: { id: desc }, limit: 5)"` + } - if err != nil { - t.Fatalf("got error: %v, want: nil", err) + return t + }(), + Variables: nil, + Subscription: &Subscription{ + payload: GraphQLRequestPayload{ + Query: "subscription{helloSaid{msg}}", + }, + }, + }, } - go func() { + subscriptionClient = subscriptionClient. + WithExitWhenNoSubscription(false). + WithTimeout(3 * time.Second). + OnError(func(sc *SubscriptionClient, err error) error { + t.Fatalf("got error: %v, want: nil", err) + return err + }) - // call a mutation request to send message to the subscription - /* - mutation InsertUser($objects: [user_insert_input!]!) { - insert_user(objects: $objects) { - id - name + bulkSubscribe := func() { + + for _, f := range fixtures { + id, err := subscriptionClient.Subscribe(f.Query, f.Variables, func(data []byte, e error) error { + if e != nil { + t.Fatalf("got error: %v, want: nil", e) + return nil } + return nil + }) + + if err != nil { + t.Fatalf("got error: %v, want: nil", err) } - */ - var q struct { - InsertUser struct { - Returning []struct { - ID int `graphql:"id"` - Name string `graphql:"name"` - } `graphql:"returning"` - } `graphql:"insert_user(objects: $objects)"` + log.Printf("subscribed: %s", id) } - variables := map[string]interface{}{ - "objects": []user_insert_input{ - { - "name": msg, - }, - }, - } - err = client.Mutate(context.Background(), &q, variables, OperationName("InsertUser")) + } - if err != nil { - (*t).Fatalf("got error: %v, want: nil", err) - } + bulkSubscribe() - time.Sleep(2 * time.Second) - // reset the subscription - log.Printf("resetting the subscription client...") + go func() { if err := subscriptionClient.Run(); err != nil { - (*t).Fatalf("failed to reset the subscription client. got error: %v, want: nil", err) + (*t).Fatalf("got error: %v, want: nil", err) } - log.Printf("the second run was stopped") - stop <- true }() + time.Sleep(3 * time.Second) + if err := subscriptionClient.Close(); err != nil { + (*t).Fatalf("got error: %v, want: nil", err) + } + + bulkSubscribe() + go func() { - time.Sleep(8 * time.Second) - subscriptionClient.Unsubscribe(subId1) - subscriptionClient.Unsubscribe(subId2) - }() + length := subscriptionClient.getContext().GetSubscriptionsLength(nil) + if length != 2 { + (*t).Fatalf("unexpected subscription client. got: %d, want: 2", length) + } - defer subscriptionClient.Close() + waitingLen := subscriptionClient.getContext().GetSubscriptionsLength([]SubscriptionStatus{SubscriptionWaiting}) + if waitingLen != 2 { + (*t).Fatalf("unexpected waiting subscription client. got: %d, want: 2", waitingLen) + } + if err := subscriptionClient.Run(); err != nil { + (*t).Fatalf("got error: %v, want: nil", err) + panic(err) + } + }() - if err := subscriptionClient.Run(); err != nil { + time.Sleep(3 * time.Second) + length := subscriptionClient.getContext().GetSubscriptionsLength(nil) + if length != 2 { + (*t).Fatalf("unexpected subscription client after restart. got: %d, want: 2, subscriptions: %+v", length, subscriptionClient.context.subscriptions) + } + if err := subscriptionClient.Close(); err != nil { t.Fatalf("got error: %v, want: nil", err) } - - <-stop } diff --git a/subscriptions_transport_ws.go b/subscriptions_transport_ws.go index abd3303..9caca07 100644 --- a/subscriptions_transport_ws.go +++ b/subscriptions_transport_ws.go @@ -2,6 +2,7 @@ package graphql import ( "encoding/json" + "errors" "fmt" ) @@ -70,8 +71,8 @@ func (stw *subscriptionsTransportWS) ConnectionInit(ctx *SubscriptionContext, co } // Subscribe requests an graphql operation specified in the payload message -func (stw *subscriptionsTransportWS) Subscribe(ctx *SubscriptionContext, id string, sub Subscription) error { - if sub.GetStarted() { +func (stw *subscriptionsTransportWS) Subscribe(ctx *SubscriptionContext, sub Subscription) error { + if sub.GetStatus() == SubscriptionRunning { return nil } payload, err := json.Marshal(sub.GetPayload()) @@ -80,7 +81,7 @@ func (stw *subscriptionsTransportWS) Subscribe(ctx *SubscriptionContext, id stri } // send start message to the server msg := OperationMessage{ - ID: id, + ID: sub.id, Type: GQLStart, Payload: payload, } @@ -89,45 +90,26 @@ func (stw *subscriptionsTransportWS) Subscribe(ctx *SubscriptionContext, id stri return err } - sub.SetStarted(true) - ctx.SetSubscription(id, &sub) + sub.SetStatus(SubscriptionRunning) + ctx.SetSubscription(sub.GetKey(), &sub) return nil } // Unsubscribe sends stop message to server and close subscription channel // The input parameter is subscription ID that is returned from Subscribe function -func (stw *subscriptionsTransportWS) Unsubscribe(ctx *SubscriptionContext, id string) error { - if ctx == nil || ctx.GetWebsocketConn() == nil { - return nil - } - sub := ctx.GetSubscription(id) - - if sub == nil { - return fmt.Errorf("subscription id %s doesn't not exist", id) - } - - ctx.SetSubscription(id, nil) - +func (stw *subscriptionsTransportWS) Unsubscribe(ctx *SubscriptionContext, sub Subscription) error { // send stop message to the server msg := OperationMessage{ - ID: id, + ID: sub.id, Type: GQLStop, } - err := ctx.Send(msg, GQLStop) - - // close the client if there is no running subscription - if ctx.GetSubscriptionsLength() == 0 { - ctx.Log("no running subscription. exiting...", "client", GQLInternal) - return ctx.Close() - } - - return err + return ctx.Send(msg, GQLStop) } // OnMessage listens ongoing messages from server -func (stw *subscriptionsTransportWS) OnMessage(ctx *SubscriptionContext, subscription Subscription, message OperationMessage) { +func (stw *subscriptionsTransportWS) OnMessage(ctx *SubscriptionContext, subscription Subscription, message OperationMessage) error { switch message.Type { case GQLError: @@ -139,17 +121,17 @@ func (stw *subscriptionsTransportWS) OnMessage(ctx *SubscriptionContext, subscri Errors Errors } if subscription.handler == nil { - return + return nil } err := json.Unmarshal(message.Payload, &out) if err != nil { subscription.handler(nil, err) - return + return nil } if len(out.Errors) > 0 { subscription.handler(nil, out.Errors) - return + return nil } var outData []byte @@ -160,15 +142,48 @@ func (stw *subscriptionsTransportWS) OnMessage(ctx *SubscriptionContext, subscri subscription.handler(outData, nil) case GQLConnectionError, "conn_err": ctx.Log(message, "server", GQLConnectionError) - _ = stw.Close(ctx) - _ = ctx.Close() + + // try to parse the error object + var payload interface{} + err := fmt.Errorf(string(message.Payload)) + jsonErr := json.Unmarshal(message.Payload, &payload) + if jsonErr == nil { + var errMsg string + if p, ok := payload.(map[string]interface{}); ok { + if msg, ok := p["error"]; ok { + errMsg = fmt.Sprint(msg) + } else if msg, ok := p["message"]; ok { + errMsg = fmt.Sprint(msg) + } + err = Error{ + Message: errMsg, + Extensions: p, + } + } else if s, ok := payload.(string); ok { + return errors.New(s) + } + } + return err case GQLComplete: ctx.Log(message, "server", GQLComplete) - _ = stw.Unsubscribe(ctx, message.ID) + sub := ctx.GetSubscription(message.ID) + if ctx.OnSubscriptionComplete != nil { + if sub == nil { + ctx.OnSubscriptionComplete(Subscription{ + id: message.ID, + }) + } else { + ctx.OnSubscriptionComplete(*sub) + ctx.SetSubscription(sub.GetKey(), nil) + } + } + if sub != nil { + ctx.SetSubscription(sub.GetKey(), nil) + } case GQLConnectionKeepAlive: ctx.Log(message, "server", GQLConnectionKeepAlive) - if ctx.onConnectionAlive != nil { - ctx.onConnectionAlive() + if ctx.OnConnectionAlive != nil { + ctx.OnConnectionAlive() } case GQLConnectionAck: // Expected response to the ConnectionInit message from the client acknowledging a successful connection with the server. @@ -177,9 +192,9 @@ func (stw *subscriptionsTransportWS) OnMessage(ctx *SubscriptionContext, subscri ctx.SetAcknowledge(true) subscriptions := ctx.GetSubscriptions() for id, sub := range subscriptions { - if err := stw.Subscribe(ctx, id, sub); err != nil { - _ = stw.Unsubscribe(ctx, id) - return + if err := stw.Subscribe(ctx, sub); err != nil { + ctx.Log(fmt.Sprintf("failed to subscribe: %s; id: %s; query: %s", err, id, sub.payload.Query), "client", GQLInternal) + return nil } } if ctx.OnConnected != nil { @@ -188,6 +203,8 @@ func (stw *subscriptionsTransportWS) OnMessage(ctx *SubscriptionContext, subscri default: ctx.Log(message, "server", GQLUnknown) } + + return nil } // Close terminates all subscriptions of the current websocket diff --git a/subscriptions_transport_ws_test.go b/subscriptions_transport_ws_test.go new file mode 100644 index 0000000..096baf0 --- /dev/null +++ b/subscriptions_transport_ws_test.go @@ -0,0 +1,770 @@ +package graphql + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log" + "math/rand" + "net/http" + "testing" + "time" + + "github.com/graph-gophers/graphql-go" + "github.com/graph-gophers/graphql-go/relay" + "github.com/graph-gophers/graphql-transport-ws/graphqlws" +) + +const schema = ` +schema { + subscription: Subscription + mutation: Mutation + query: Query +} +type Query { + hello: String! +} +type Subscription { + helloSaid(): HelloSaidEvent! +} +type Mutation { + sayHello(msg: String!): HelloSaidEvent! +} +type HelloSaidEvent { + id: String! + msg: String! +} +` + +func subscription_setupClients(port int) (*Client, *SubscriptionClient) { + endpoint := fmt.Sprintf("http://localhost:%d/graphql", port) + + client := NewClient(endpoint, &http.Client{Transport: http.DefaultTransport}) + + subscriptionClient := NewSubscriptionClient(endpoint). + WithConnectionParams(map[string]interface{}{ + "headers": map[string]string{ + "foo": "bar", + }, + }).WithLog(log.Println) + + return client, subscriptionClient +} + +func subscription_setupServer(port int) *http.Server { + + // init graphQL schema + s, err := graphql.ParseSchema(schema, newResolver()) + if err != nil { + panic(err) + } + + // graphQL handler + mux := http.NewServeMux() + graphQLHandler := graphqlws.NewHandlerFunc(s, &relay.Handler{Schema: s}) + mux.HandleFunc("/graphql", graphQLHandler) + server := &http.Server{Addr: fmt.Sprintf(":%d", port), Handler: mux} + + return server +} + +type resolver struct { + helloSaidEvents chan *helloSaidEvent + helloSaidSubscriber chan *helloSaidSubscriber +} + +func newResolver() *resolver { + r := &resolver{ + helloSaidEvents: make(chan *helloSaidEvent), + helloSaidSubscriber: make(chan *helloSaidSubscriber), + } + + go r.broadcastHelloSaid() + + return r +} + +func (r *resolver) Hello() string { + return "Hello world!" +} + +func (r *resolver) SayHello(args struct{ Msg string }) *helloSaidEvent { + e := &helloSaidEvent{msg: args.Msg, id: randomID()} + go func() { + select { + case r.helloSaidEvents <- e: + case <-time.After(1 * time.Second): + } + }() + return e +} + +type helloSaidSubscriber struct { + stop <-chan struct{} + events chan<- *helloSaidEvent +} + +func (r *resolver) broadcastHelloSaid() { + subscribers := map[string]*helloSaidSubscriber{} + unsubscribe := make(chan string) + + // NOTE: subscribing and sending events are at odds. + for { + select { + case id := <-unsubscribe: + delete(subscribers, id) + case s := <-r.helloSaidSubscriber: + id := randomID() + log.Println("new client subscribed: ", id) + subscribers[id] = s + case e := <-r.helloSaidEvents: + for id, s := range subscribers { + go func(id string, s *helloSaidSubscriber) { + select { + case <-s.stop: + unsubscribe <- id + return + default: + } + + select { + case <-s.stop: + unsubscribe <- id + case s.events <- e: + case <-time.After(time.Second): + } + }(id, s) + } + } + } +} + +func (r *resolver) HelloSaid(ctx context.Context) <-chan *helloSaidEvent { + c := make(chan *helloSaidEvent) + // NOTE: this could take a while + r.helloSaidSubscriber <- &helloSaidSubscriber{events: c, stop: ctx.Done()} + + return c +} + +type helloSaidEvent struct { + id string + msg string +} + +func (r *helloSaidEvent) Msg() string { + return r.msg +} + +func (r *helloSaidEvent) ID() string { + return r.id +} + +func randomID() string { + var letter = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") + + b := make([]rune, 16) + for i := range b { + b[i] = letter[rand.Intn(len(letter))] + } + return string(b) +} + +func TestTransportWS_basicTest(t *testing.T) { + stop := make(chan bool) + server := subscription_setupServer(8081) + client, subscriptionClient := subscription_setupClients(8081) + msg := randomID() + go func() { + if err := server.ListenAndServe(); err != nil { + log.Println(err) + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer server.Shutdown(ctx) + defer cancel() + + subscriptionClient. + OnError(func(sc *SubscriptionClient, err error) error { + return err + }) + + /* + subscription { + helloSaid { + id + msg + } + } + */ + var sub struct { + HelloSaid struct { + ID String + Message String `graphql:"msg" json:"msg"` + } `graphql:"helloSaid" json:"helloSaid"` + } + + _, err := subscriptionClient.Subscribe(sub, nil, func(data []byte, e error) error { + if e != nil { + t.Fatalf("got error: %v, want: nil", e) + return nil + } + + log.Println("result", string(data)) + e = json.Unmarshal(data, &sub) + if e != nil { + t.Fatalf("got error: %v, want: nil", e) + return nil + } + + if sub.HelloSaid.Message != String(msg) { + t.Fatalf("subscription message does not match. got: %s, want: %s", sub.HelloSaid.Message, msg) + } + + return errors.New("exit") + }) + + if err != nil { + t.Fatalf("got error: %v, want: nil", err) + } + + go func() { + if err := subscriptionClient.Run(); err == nil || err.Error() != "exit" { + (*t).Fatalf("got error: %v, want: exit", err) + } + stop <- true + }() + + defer subscriptionClient.Close() + + // wait until the subscription client connects to the server + time.Sleep(2 * time.Second) + + // call a mutation request to send message to the subscription + /* + mutation ($msg: String!) { + sayHello(msg: $msg) { + id + msg + } + } + */ + var q struct { + SayHello struct { + ID String + Msg String + } `graphql:"sayHello(msg: $msg)"` + } + variables := map[string]interface{}{ + "msg": String(msg), + } + err = client.Mutate(context.Background(), &q, variables, OperationName("SayHello")) + if err != nil { + t.Fatalf("got error: %v, want: nil", err) + } + + <-stop +} + +func TestTransportWS_exitWhenNoSubscription(t *testing.T) { + server := subscription_setupServer(8085) + client, subscriptionClient := subscription_setupClients(8085) + msg := randomID() + go func() { + if err := server.ListenAndServe(); err != nil { + log.Println(err) + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer server.Shutdown(ctx) + defer cancel() + + subscriptionClient = subscriptionClient. + WithTimeout(3 * time.Second). + OnError(func(sc *SubscriptionClient, err error) error { + t.Fatalf("got error: %v, want: nil", err) + return err + }). + OnDisconnected(func() { + log.Println("disconnected") + }) + /* + subscription { + helloSaid { + id + msg + } + } + */ + var sub struct { + HelloSaid struct { + ID String + Message String `graphql:"msg" json:"msg"` + } `graphql:"helloSaid" json:"helloSaid"` + } + + subId1, err := subscriptionClient.Subscribe(sub, nil, func(data []byte, e error) error { + if e != nil { + t.Fatalf("got error: %v, want: nil", e) + return nil + } + + log.Println("result", string(data)) + e = json.Unmarshal(data, &sub) + if e != nil { + t.Fatalf("got error: %v, want: nil", e) + return nil + } + + if sub.HelloSaid.Message != String(msg) { + t.Fatalf("subscription message does not match. got: %s, want: %s", sub.HelloSaid.Message, msg) + } + + return nil + }) + + if err != nil { + t.Fatalf("got error: %v, want: nil", err) + } + + /* + subscription { + helloSaid { + id + msg + } + } + */ + var sub2 struct { + HelloSaid struct { + Message String `graphql:"msg" json:"msg"` + } `graphql:"helloSaid" json:"helloSaid"` + } + + subId2, err := subscriptionClient.Subscribe(sub2, nil, func(data []byte, e error) error { + if e != nil { + t.Fatalf("got error: %v, want: nil", e) + return nil + } + + log.Println("result", string(data)) + e = json.Unmarshal(data, &sub2) + if e != nil { + t.Fatalf("got error: %v, want: nil", e) + return nil + } + + if sub2.HelloSaid.Message != String(msg) { + t.Fatalf("subscription message does not match. got: %s, want: %s", sub2.HelloSaid.Message, msg) + } + + return nil + }) + + if err != nil { + t.Fatalf("got error: %v, want: nil", err) + } + + go func() { + // wait until the subscription client connects to the server + time.Sleep(2 * time.Second) + + // call a mutation request to send message to the subscription + /* + mutation ($msg: String!) { + sayHello(msg: $msg) { + id + msg + } + } + */ + var q struct { + SayHello struct { + ID String + Msg String + } `graphql:"sayHello(msg: $msg)"` + } + variables := map[string]interface{}{ + "msg": String(msg), + } + err = client.Mutate(context.Background(), &q, variables, OperationName("SayHello")) + if err != nil { + (*t).Fatalf("got error: %v, want: nil", err) + } + + time.Sleep(2 * time.Second) + subscriptionClient.Unsubscribe(subId1) + subscriptionClient.Unsubscribe(subId2) + }() + + defer subscriptionClient.Close() + + if err := subscriptionClient.Run(); err != nil { + t.Fatalf("got error: %v, want: nil", err) + } +} + +func TestTransportWS_ResetClient(t *testing.T) { + + stop := make(chan bool) + client, subscriptionClient := hasura_setupClients(SubscriptionsTransportWS) + msg := randomID() + + subscriptionClient. + OnError(func(sc *SubscriptionClient, err error) error { + t.Fatalf("got error: %v, want: nil", err) + return err + }). + OnDisconnected(func() { + log.Println("disconnected") + }) + + /* + subscription { + user { + id + name + } + } + */ + var sub struct { + Users []struct { + ID int `graphql:"id"` + Name string `graphql:"name"` + } `graphql:"user(order_by: { id: desc }, limit: 5)"` + } + + subId1, err := subscriptionClient.Subscribe(sub, nil, func(data []byte, e error) error { + if e != nil { + t.Fatalf("got error: %v, want: nil", e) + return nil + } + + log.Println("result", string(data)) + e = json.Unmarshal(data, &sub) + if e != nil { + t.Fatalf("got error: %v, want: nil", e) + return nil + } + + if len(sub.Users) > 0 && sub.Users[0].Name != msg { + t.Fatalf("subscription message does not match. got: %s, want: %s", sub.Users[0].Name, msg) + } + + return nil + }) + + if err != nil { + t.Fatalf("got error: %v, want: nil", err) + } + + defer subscriptionClient.Close() + + // wait until the subscription client connects to the server + if err := waitHasuraService(60); err != nil { + t.Fatalf("failed to start hasura service: %s", err) + } + + /* + subscription { + user { + id + name + } + } + */ + var sub2 struct { + Users []struct { + ID int `graphql:"id"` + } `graphql:"user(order_by: { id: desc }, limit: 5)"` + } + + subId2, err := subscriptionClient.Subscribe(sub2, nil, func(data []byte, e error) error { + if e != nil { + t.Fatalf("got error: %v, want: nil", e) + return nil + } + + log.Println("result", string(data)) + e = json.Unmarshal(data, &sub2) + if e != nil { + t.Fatalf("got error: %v, want: nil", e) + return nil + } + + if len(sub.Users) > 0 && sub.Users[0].Name != msg { + t.Fatalf("subscription message does not match. got: %s, want: %s", sub.Users[0].Name, msg) + } + + return nil + }) + + if err != nil { + t.Fatalf("got error: %v, want: nil", err) + } + + go func() { + + // call a mutation request to send message to the subscription + /* + mutation InsertUser($objects: [user_insert_input!]!) { + insert_user(objects: $objects) { + id + name + } + } + */ + var q struct { + InsertUser struct { + Returning []struct { + ID int `graphql:"id"` + Name string `graphql:"name"` + } `graphql:"returning"` + } `graphql:"insert_user(objects: $objects)"` + } + variables := map[string]interface{}{ + "objects": []user_insert_input{ + { + "name": msg, + }, + }, + } + err = client.Mutate(context.Background(), &q, variables, OperationName("InsertUser")) + + if err != nil { + (*t).Fatalf("got error: %v, want: nil", err) + } + + time.Sleep(2 * time.Second) + + // test susbcription ids + sub1 := subscriptionClient.getContext().GetSubscription(subId1) + if sub1 == nil { + (*t).Fatalf("subscription 1 not found: %s", subId1) + } else { + if sub1.key != subId1 { + (*t).Fatalf("subscription key 1 not equal, got %s, want %s", subId1, sub1.key) + } + if sub1.id != subId1 { + (*t).Fatalf("subscription id 1 not equal, got %s, want %s", subId1, sub1.id) + } + } + sub2 := subscriptionClient.getContext().GetSubscription(subId2) + if sub2 == nil { + (*t).Fatalf("subscription 2 not found: %s", subId2) + } else { + if sub2.key != subId2 { + (*t).Fatalf("subscription id 2 not equal, got %s, want %s", subId2, sub2.key) + } + + if sub2.id != subId2 { + (*t).Fatalf("subscription id 2 not equal, got %s, want %s", subId2, sub2.id) + } + } + + // reset the subscription + log.Printf("resetting the subscription client...") + if err := subscriptionClient.Run(); err != nil { + (*t).Fatalf("failed to reset the subscription client. got error: %v, want: nil", err) + } + log.Printf("the second run was stopped") + stop <- true + }() + + go func() { + time.Sleep(8 * time.Second) + + // test subscription ids + sub1 := subscriptionClient.getContext().GetSubscription(subId1) + if sub1 == nil { + (*t).Fatalf("subscription 1 not found: %s", subId1) + } else { + if sub1.key != subId1 { + (*t).Fatalf("subscription key 1 not equal, got %s, want %s", subId1, sub1.key) + } + if sub1.id == subId1 { + (*t).Fatalf("subscription id 1 should equal, got %s, want %s", subId1, sub1.id) + } + } + sub2 := subscriptionClient.getContext().GetSubscription(subId2) + if sub2 == nil { + (*t).Fatalf("subscription 2 not found: %s", subId2) + } else { + if sub2.key != subId2 { + (*t).Fatalf("subscription id 2 not equal, got %s, want %s", subId2, sub2.key) + } + + if sub2.id == subId2 { + (*t).Fatalf("subscription id 2 should equal, got %s, want %s", subId2, sub2.id) + } + } + + subscriptionClient.Unsubscribe(subId1) + subscriptionClient.Unsubscribe(subId2) + }() + + defer subscriptionClient.Close() + + if err := subscriptionClient.Run(); err != nil { + t.Fatalf("got error: %v, want: nil", err) + } + + <-stop +} + +func TestTransportWS_onDisconnected(t *testing.T) { + port := 8083 + server := subscription_setupServer(port) + var wasConnected bool + disconnected := make(chan bool) + go func() { + if err := server.ListenAndServe(); err != nil { + log.Println(err) + } + }() + + // init client + _, subscriptionClient := subscription_setupClients(port) + subscriptionClient = subscriptionClient. + WithTimeout(5 * time.Second). + OnError(func(sc *SubscriptionClient, err error) error { + panic(err) + }). + OnConnected(func() { + log.Println("OnConnected") + wasConnected = true + }). + OnDisconnected(func() { + log.Println("OnDisconnected") + disconnected <- true + }) + + /* + subscription { + helloSaid { + id + msg + } + } + */ + var sub struct { + HelloSaid struct { + ID String + Message String `graphql:"msg" json:"msg"` + } `graphql:"helloSaid" json:"helloSaid"` + } + + _, err := subscriptionClient.Subscribe(sub, nil, func(data []byte, e error) error { + if e != nil { + t.Fatalf("got error: %v, want: nil", e) + } + return nil + }) + + if err != nil { + t.Fatalf("got error: %v, want: nil", err) + } + + // run client + go func() { + subscriptionClient.Run() + }() + defer subscriptionClient.Close() + + // wait until the subscription client connects to the server + time.Sleep(2 * time.Second) + if err := server.Close(); err != nil { + panic(err) + } + + <-disconnected + + if !wasConnected { + t.Fatal("the OnConnected event must be triggered") + } +} + +func TestTransportWS_OnError(t *testing.T) { + stop := make(chan bool) + + subscriptionClient := NewSubscriptionClient(fmt.Sprintf("%s/v1/graphql", hasuraTestHost)). + WithTimeout(3 * time.Second). + WithProtocol(SubscriptionsTransportWS). + WithConnectionParams(map[string]interface{}{ + "headers": map[string]string{ + "x-hasura-admin-secret": "test", + }, + }).WithLog(log.Println) + + msg := randomID() + + subscriptionClient = subscriptionClient. + OnConnected(func() { + log.Println("client connected") + }). + OnError(func(sc *SubscriptionClient, err error) error { + log.Println("OnError: ", err) + return err + }) + + /* + subscription { + user { + id + name + } + } + */ + var sub struct { + Users []struct { + ID int `graphql:"id"` + Name string `graphql:"name"` + } `graphql:"user(order_by: { id: desc }, limit: 5)"` + } + + _, err := subscriptionClient.Subscribe(sub, nil, func(data []byte, e error) error { + if e != nil { + t.Fatalf("got error: %v, want: nil", e) + return nil + } + + log.Println("result", string(data)) + e = json.Unmarshal(data, &sub) + if e != nil { + t.Fatalf("got error: %v, want: nil", e) + return nil + } + + if len(sub.Users) > 0 && sub.Users[0].Name != msg { + t.Fatalf("subscription message does not match. got: %s, want: %s", sub.Users[0].Name, msg) + } + + return nil + }) + + if err != nil { + t.Fatalf("got error: %v, want: nil", err) + } + + go func() { + unauthorizedErr := "invalid x-hasura-admin-secret/x-hasura-access-key" + err := subscriptionClient.Run() + + if err == nil || err.Error() != unauthorizedErr { + (*t).Errorf("got error: %v, want: %s", err, unauthorizedErr) + } + stop <- true + }() + + defer subscriptionClient.Close() + + // wait until the subscription client connects to the server + if err := waitHasuraService(60); err != nil { + t.Fatalf("failed to start hasura service: %s", err) + } + + <-stop +}