diff --git a/accept.go b/accept.go index 774ea285..f45fdd0b 100644 --- a/accept.go +++ b/accept.go @@ -5,6 +5,7 @@ package websocket import ( "bytes" + "context" "crypto/sha1" "encoding/base64" "errors" @@ -62,6 +63,22 @@ type AcceptOptions struct { // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes // for CompressionContextTakeover. CompressionThreshold int + + // OnPingReceived is an optional callback invoked synchronously when a ping frame is received. + // + // The payload contains the application data of the ping frame. + // If the callback returns false, the subsequent pong frame will not be sent. + // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine. + OnPingReceived func(ctx context.Context, payload []byte) bool + + // OnPongReceived is an optional callback invoked synchronously when a pong frame is received. + // + // The payload contains the application data of the pong frame. + // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine. + // + // Unlike OnPingReceived, this callback does not return a value because a pong frame + // is a response to a ping and does not trigger any further frame transmission. + OnPongReceived func(ctx context.Context, payload []byte) } func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions { @@ -156,6 +173,8 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con client: false, copts: copts, flateThreshold: opts.CompressionThreshold, + onPingReceived: opts.OnPingReceived, + onPongReceived: opts.OnPongReceived, br: brw.Reader, bw: brw.Writer, diff --git a/conn.go b/conn.go index 76b057dd..42fe89fe 100644 --- a/conn.go +++ b/conn.go @@ -83,9 +83,11 @@ type Conn struct { closeMu sync.Mutex // Protects following. closed chan struct{} - pingCounter atomic.Int64 - activePingsMu sync.Mutex - activePings map[string]chan<- struct{} + pingCounter atomic.Int64 + activePingsMu sync.Mutex + activePings map[string]chan<- struct{} + onPingReceived func(context.Context, []byte) bool + onPongReceived func(context.Context, []byte) } type connConfig struct { @@ -94,6 +96,8 @@ type connConfig struct { client bool copts *compressionOptions flateThreshold int + onPingReceived func(context.Context, []byte) bool + onPongReceived func(context.Context, []byte) br *bufio.Reader bw *bufio.Writer @@ -114,8 +118,10 @@ func newConn(cfg connConfig) *Conn { writeTimeout: make(chan context.Context), timeoutLoopDone: make(chan struct{}), - closed: make(chan struct{}), - activePings: make(map[string]chan<- struct{}), + closed: make(chan struct{}), + activePings: make(map[string]chan<- struct{}), + onPingReceived: cfg.onPingReceived, + onPongReceived: cfg.onPongReceived, } c.readMu = newMu(c) diff --git a/conn_test.go b/conn_test.go index 9ed8c7ea..45bb75be 100644 --- a/conn_test.go +++ b/conn_test.go @@ -97,6 +97,85 @@ func TestConn(t *testing.T) { assert.Contains(t, err, "failed to wait for pong") }) + t.Run("pingReceivedPongReceived", func(t *testing.T) { + var pingReceived1, pongReceived1 bool + var pingReceived2, pongReceived2 bool + tt, c1, c2 := newConnTest(t, + &websocket.DialOptions{ + OnPingReceived: func(ctx context.Context, payload []byte) bool { + pingReceived1 = true + return true + }, + OnPongReceived: func(ctx context.Context, payload []byte) { + pongReceived1 = true + }, + }, &websocket.AcceptOptions{ + OnPingReceived: func(ctx context.Context, payload []byte) bool { + pingReceived2 = true + return true + }, + OnPongReceived: func(ctx context.Context, payload []byte) { + pongReceived2 = true + }, + }, + ) + + c1.CloseRead(tt.ctx) + c2.CloseRead(tt.ctx) + + ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100) + defer cancel() + + err := c1.Ping(ctx) + assert.Success(t, err) + + c1.CloseNow() + c2.CloseNow() + + assert.Equal(t, "only one side receives the ping", false, pingReceived1 && pingReceived2) + assert.Equal(t, "only one side receives the pong", false, pongReceived1 && pongReceived2) + assert.Equal(t, "ping and pong received", true, (pingReceived1 && pongReceived2) || (pingReceived2 && pongReceived1)) + }) + + t.Run("pingReceivedPongNotReceived", func(t *testing.T) { + var pingReceived1, pongReceived1 bool + var pingReceived2, pongReceived2 bool + tt, c1, c2 := newConnTest(t, + &websocket.DialOptions{ + OnPingReceived: func(ctx context.Context, payload []byte) bool { + pingReceived1 = true + return false + }, + OnPongReceived: func(ctx context.Context, payload []byte) { + pongReceived1 = true + }, + }, &websocket.AcceptOptions{ + OnPingReceived: func(ctx context.Context, payload []byte) bool { + pingReceived2 = true + return false + }, + OnPongReceived: func(ctx context.Context, payload []byte) { + pongReceived2 = true + }, + }, + ) + + c1.CloseRead(tt.ctx) + c2.CloseRead(tt.ctx) + + ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100) + defer cancel() + + err := c1.Ping(ctx) + assert.Contains(t, err, "failed to wait for pong") + + c1.CloseNow() + c2.CloseNow() + + assert.Equal(t, "only one side receives the ping", false, pingReceived1 && pingReceived2) + assert.Equal(t, "ping received and pong not received", true, (pingReceived1 && !pongReceived2) || (pingReceived2 && !pongReceived1)) + }) + t.Run("concurrentWrite", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) diff --git a/dial.go b/dial.go index ad61a35d..0b11ecbb 100644 --- a/dial.go +++ b/dial.go @@ -48,6 +48,22 @@ type DialOptions struct { // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes // for CompressionContextTakeover. CompressionThreshold int + + // OnPingReceived is an optional callback invoked synchronously when a ping frame is received. + // + // The payload contains the application data of the ping frame. + // If the callback returns false, the subsequent pong frame will not be sent. + // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine. + OnPingReceived func(ctx context.Context, payload []byte) bool + + // OnPongReceived is an optional callback invoked synchronously when a pong frame is received. + // + // The payload contains the application data of the pong frame. + // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine. + // + // Unlike OnPingReceived, this callback does not return a value because a pong frame + // is a response to a ping and does not trigger any further frame transmission. + OnPongReceived func(ctx context.Context, payload []byte) } func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context, context.CancelFunc, *DialOptions) { @@ -163,6 +179,8 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( client: true, copts: copts, flateThreshold: opts.CompressionThreshold, + onPingReceived: opts.OnPingReceived, + onPongReceived: opts.OnPongReceived, br: getBufioReader(rwc), bw: getBufioWriter(rwc), }), resp, nil diff --git a/read.go b/read.go index 1267b5b9..2db22435 100644 --- a/read.go +++ b/read.go @@ -312,8 +312,16 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { switch h.opcode { case opPing: + if c.onPingReceived != nil { + if !c.onPingReceived(ctx, b) { + return nil + } + } return c.writeControl(ctx, opPong, b) case opPong: + if c.onPongReceived != nil { + c.onPongReceived(ctx, b) + } c.activePingsMu.Lock() pong, ok := c.activePings[string(b)] c.activePingsMu.Unlock()