diff --git a/conn_notjs.go b/conn_notjs.go index 0c85ab77..344610fe 100644 --- a/conn_notjs.go +++ b/conn_notjs.go @@ -59,6 +59,8 @@ type Conn struct { pingCounter int32 activePingsMu sync.Mutex activePings map[string]chan<- struct{} + pingHandler func(ctx context.Context, p []byte) error + pongHandler func(ctx context.Context, p []byte) error } type connConfig struct { @@ -89,6 +91,9 @@ func newConn(cfg connConfig) *Conn { closed: make(chan struct{}), activePings: make(map[string]chan<- struct{}), } + // set default ping, pong handler + c.SetPingHandler(nil) + c.SetPongHandler(nil) c.readMu = newMu(c) c.writeFrameMu = newMu(c) diff --git a/conn_test.go b/conn_test.go index c2c41292..83cfda89 100644 --- a/conn_test.go +++ b/conn_test.go @@ -103,6 +103,51 @@ func TestConn(t *testing.T) { assert.Contains(t, err, "failed to wait for pong") }) + t.Run("pingHandler", func(t *testing.T) { + tt, c1, c2 := newConnTest(t, nil, nil) + defer tt.cleanup() + + var count int + c2.SetPingHandler(func(context.Context, []byte) error { + count++ + return nil + }) + + c1.CloseRead(tt.ctx) + c2.CloseRead(tt.ctx) + + for i := 0; i < 10; i++ { + err := c1.Ping(tt.ctx) + assert.Success(t, err) + } + + err := c1.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) + assert.Equal(t, "count", 10, count) + }) + + t.Run("pongHandler", func(t *testing.T) { + tt, c1, c2 := newConnTest(t, nil, nil) + defer tt.cleanup() + + var count int + c1.SetPongHandler(func(context.Context, []byte) error { + count++ + return nil + }) + + c1.CloseRead(tt.ctx) + c2.CloseRead(tt.ctx) + for i := 0; i < 10; i++ { + err := c1.Ping(tt.ctx) + assert.Success(t, err) + } + + err := c1.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) + assert.Equal(t, "count", 10, count) + }) + t.Run("concurrentWrite", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) defer tt.cleanup() diff --git a/read.go b/read.go index 89a00988..1048c8e5 100644 --- a/read.go +++ b/read.go @@ -75,6 +75,30 @@ func (c *Conn) SetReadLimit(n int64) { c.msgReader.limitReader.limit.Store(n + 1) } +// SetPingHandler set the handler for pong handler +// From 5.5.2 of RFC 6455 +// "Upon receipt of a Ping frame, an endpoint MUST send a Pong frame in response" +func (c *Conn) SetPingHandler(f func(ctx context.Context, p []byte) error) { + c.pingHandler = func(ctx context.Context, p []byte) error { + if err := c.writeControl(ctx, opPong, p); err != nil { + return err + } + if f != nil { + return f(ctx, p) + } + return nil + } +} + +// SetPongHandler set the handler for ping message +// By default, do nothing +func (c *Conn) SetPongHandler(f func(ctx context.Context, p []byte) error) { + if f == nil { + f = func(context.Context, []byte) error { return nil } + } + c.pongHandler = f +} + const defaultReadLimit = 32768 func newMsgReader(c *Conn) *msgReader { @@ -265,7 +289,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { switch h.opcode { case opPing: - return c.writeControl(ctx, opPong, b) + return c.pingHandler(ctx, b) case opPong: c.activePingsMu.Lock() pong, ok := c.activePings[string(b)] @@ -276,7 +300,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { default: } } - return nil + return c.pongHandler(ctx, b) } defer func() {