Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ping and pong received callbacks #509

Merged
merged 1 commit into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions accept.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package websocket

import (
"bytes"
"context"
"crypto/sha1"
"encoding/base64"
"errors"
Expand Down Expand Up @@ -62,6 +63,22 @@ type AcceptOptions struct {
// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
// for CompressionContextTakeover.
CompressionThreshold int

// OnPingReceived is an optional callback invoked synchronously when a ping frame is received.
//
// The payload contains the application data of the ping frame.
// If the callback returns false, the subsequent pong frame will not be sent.
// To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
OnPingReceived func(ctx context.Context, payload []byte) bool

// OnPongReceived is an optional callback invoked synchronously when a pong frame is received.
//
// The payload contains the application data of the pong frame.
// To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
//
// Unlike OnPingReceived, this callback does not return a value because a pong frame
// is a response to a ping and does not trigger any further frame transmission.
OnPongReceived func(ctx context.Context, payload []byte)
}

func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions {
Expand Down Expand Up @@ -156,6 +173,8 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
client: false,
copts: copts,
flateThreshold: opts.CompressionThreshold,
onPingReceived: opts.OnPingReceived,
onPongReceived: opts.OnPongReceived,

br: brw.Reader,
bw: brw.Writer,
Expand Down
16 changes: 11 additions & 5 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,11 @@ type Conn struct {
closeMu sync.Mutex // Protects following.
closed chan struct{}

pingCounter atomic.Int64
activePingsMu sync.Mutex
activePings map[string]chan<- struct{}
pingCounter atomic.Int64
activePingsMu sync.Mutex
activePings map[string]chan<- struct{}
onPingReceived func(context.Context, []byte) bool
onPongReceived func(context.Context, []byte)
}

type connConfig struct {
Expand All @@ -94,6 +96,8 @@ type connConfig struct {
client bool
copts *compressionOptions
flateThreshold int
onPingReceived func(context.Context, []byte) bool
onPongReceived func(context.Context, []byte)

br *bufio.Reader
bw *bufio.Writer
Expand All @@ -114,8 +118,10 @@ func newConn(cfg connConfig) *Conn {
writeTimeout: make(chan context.Context),
timeoutLoopDone: make(chan struct{}),

closed: make(chan struct{}),
activePings: make(map[string]chan<- struct{}),
closed: make(chan struct{}),
activePings: make(map[string]chan<- struct{}),
onPingReceived: cfg.onPingReceived,
onPongReceived: cfg.onPongReceived,
}

c.readMu = newMu(c)
Expand Down
79 changes: 79 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,85 @@ func TestConn(t *testing.T) {
assert.Contains(t, err, "failed to wait for pong")
})

t.Run("pingReceivedPongReceived", func(t *testing.T) {
var pingReceived1, pongReceived1 bool
var pingReceived2, pongReceived2 bool
tt, c1, c2 := newConnTest(t,
&websocket.DialOptions{
OnPingReceived: func(ctx context.Context, payload []byte) bool {
pingReceived1 = true
return true
},
OnPongReceived: func(ctx context.Context, payload []byte) {
pongReceived1 = true
},
}, &websocket.AcceptOptions{
OnPingReceived: func(ctx context.Context, payload []byte) bool {
pingReceived2 = true
return true
},
OnPongReceived: func(ctx context.Context, payload []byte) {
pongReceived2 = true
},
},
)

c1.CloseRead(tt.ctx)
c2.CloseRead(tt.ctx)

ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100)
defer cancel()

err := c1.Ping(ctx)
assert.Success(t, err)

c1.CloseNow()
c2.CloseNow()

assert.Equal(t, "only one side receives the ping", false, pingReceived1 && pingReceived2)
assert.Equal(t, "only one side receives the pong", false, pongReceived1 && pongReceived2)
assert.Equal(t, "ping and pong received", true, (pingReceived1 && pongReceived2) || (pingReceived2 && pongReceived1))
})

t.Run("pingReceivedPongNotReceived", func(t *testing.T) {
var pingReceived1, pongReceived1 bool
var pingReceived2, pongReceived2 bool
tt, c1, c2 := newConnTest(t,
&websocket.DialOptions{
OnPingReceived: func(ctx context.Context, payload []byte) bool {
pingReceived1 = true
return false
},
OnPongReceived: func(ctx context.Context, payload []byte) {
pongReceived1 = true
},
}, &websocket.AcceptOptions{
OnPingReceived: func(ctx context.Context, payload []byte) bool {
pingReceived2 = true
return false
},
OnPongReceived: func(ctx context.Context, payload []byte) {
pongReceived2 = true
},
},
)

c1.CloseRead(tt.ctx)
c2.CloseRead(tt.ctx)

ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100)
defer cancel()

err := c1.Ping(ctx)
assert.Contains(t, err, "failed to wait for pong")

c1.CloseNow()
c2.CloseNow()

assert.Equal(t, "only one side receives the ping", false, pingReceived1 && pingReceived2)
assert.Equal(t, "ping received and pong not received", true, (pingReceived1 && !pongReceived2) || (pingReceived2 && !pongReceived1))
})

t.Run("concurrentWrite", func(t *testing.T) {
tt, c1, c2 := newConnTest(t, nil, nil)

Expand Down
18 changes: 18 additions & 0 deletions dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,22 @@ type DialOptions struct {
// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
// for CompressionContextTakeover.
CompressionThreshold int

// OnPingReceived is an optional callback invoked synchronously when a ping frame is received.
//
// The payload contains the application data of the ping frame.
// If the callback returns false, the subsequent pong frame will not be sent.
// To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
OnPingReceived func(ctx context.Context, payload []byte) bool

// OnPongReceived is an optional callback invoked synchronously when a pong frame is received.
//
// The payload contains the application data of the pong frame.
// To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
//
// Unlike OnPingReceived, this callback does not return a value because a pong frame
// is a response to a ping and does not trigger any further frame transmission.
OnPongReceived func(ctx context.Context, payload []byte)
}

func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context, context.CancelFunc, *DialOptions) {
Expand Down Expand Up @@ -163,6 +179,8 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
client: true,
copts: copts,
flateThreshold: opts.CompressionThreshold,
onPingReceived: opts.OnPingReceived,
onPongReceived: opts.OnPongReceived,
br: getBufioReader(rwc),
bw: getBufioWriter(rwc),
}), resp, nil
Expand Down
8 changes: 8 additions & 0 deletions read.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,8 +312,16 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) {

switch h.opcode {
case opPing:
if c.onPingReceived != nil {
if !c.onPingReceived(ctx, b) {
return nil
}
}
return c.writeControl(ctx, opPong, b)
case opPong:
if c.onPongReceived != nil {
c.onPongReceived(ctx, b)
}
c.activePingsMu.Lock()
pong, ok := c.activePings[string(b)]
c.activePingsMu.Unlock()
Expand Down
Loading