Skip to content

Commit a154a33

Browse files
committed
Add ping and pong received callbacks
This change adds two optional callbacks to both `DialOptions` and `AcceptOptions`. These callbacks are invoked synchronously when a ping or pong frame is received, allowing advanced users to log or inspect payloads for metrics or debugging. If the callback needs to perform more complex work or reuse the payload outside the callback, it is recommended to clone the byte slice and/or perform processing in a separate goroutine. Tests confirm that the ping/pong callbacks are invoked as expected. References #246
1 parent 11bda98 commit a154a33

File tree

5 files changed

+100
-5
lines changed

5 files changed

+100
-5
lines changed

accept.go

+15
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package websocket
55

66
import (
77
"bytes"
8+
"context"
89
"crypto/sha1"
910
"encoding/base64"
1011
"errors"
@@ -62,6 +63,18 @@ type AcceptOptions struct {
6263
// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
6364
// for CompressionContextTakeover.
6465
CompressionThreshold int
66+
67+
// OnPingReceived is an optional callback invoked synchronously when a ping frame is received.
68+
//
69+
// To avoid blocking, process the callback asynchronously using a goroutine.
70+
// If you need to reuse the payload outside the callback, clone the byte slice.
71+
// Any modifications to the payload within the callback will be sent in the subsequent pong frame.
72+
OnPingReceived func(context.Context, []byte)
73+
74+
// OnPongReceived is an optional callback invoked synchronously when a pong frame is received.
75+
//
76+
// To avoid blocking, process the callback asynchronously using a goroutine.
77+
OnPongReceived func(context.Context, []byte)
6578
}
6679

6780
func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions {
@@ -156,6 +169,8 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
156169
client: false,
157170
copts: copts,
158171
flateThreshold: opts.CompressionThreshold,
172+
onPingReceived: opts.OnPingReceived,
173+
onPongReceived: opts.OnPongReceived,
159174

160175
br: brw.Reader,
161176
bw: brw.Writer,

conn.go

+11-5
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,11 @@ type Conn struct {
8383
closeMu sync.Mutex // Protects following.
8484
closed chan struct{}
8585

86-
pingCounter atomic.Int64
87-
activePingsMu sync.Mutex
88-
activePings map[string]chan<- struct{}
86+
pingCounter atomic.Int64
87+
activePingsMu sync.Mutex
88+
activePings map[string]chan<- struct{}
89+
onPingReceived func(context.Context, []byte)
90+
onPongReceived func(context.Context, []byte)
8991
}
9092

9193
type connConfig struct {
@@ -94,6 +96,8 @@ type connConfig struct {
9496
client bool
9597
copts *compressionOptions
9698
flateThreshold int
99+
onPingReceived func(context.Context, []byte)
100+
onPongReceived func(context.Context, []byte)
97101

98102
br *bufio.Reader
99103
bw *bufio.Writer
@@ -114,8 +118,10 @@ func newConn(cfg connConfig) *Conn {
114118
writeTimeout: make(chan context.Context),
115119
timeoutLoopDone: make(chan struct{}),
116120

117-
closed: make(chan struct{}),
118-
activePings: make(map[string]chan<- struct{}),
121+
closed: make(chan struct{}),
122+
activePings: make(map[string]chan<- struct{}),
123+
onPingReceived: cfg.onPingReceived,
124+
onPongReceived: cfg.onPongReceived,
119125
}
120126

121127
c.readMu = newMu(c)

conn_test.go

+54
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,60 @@ func TestConn(t *testing.T) {
9797
assert.Contains(t, err, "failed to wait for pong")
9898
})
9999

100+
t.Run("pingPongReceived1", func(t *testing.T) {
101+
var pingReceived, pongReceived bool
102+
tt, c1, c2 := newConnTest(t,
103+
&websocket.DialOptions{
104+
OnPingReceived: func(ctx context.Context, payload []byte) {
105+
pingReceived = true
106+
},
107+
}, &websocket.AcceptOptions{
108+
OnPongReceived: func(ctx context.Context, payload []byte) {
109+
pongReceived = true
110+
},
111+
},
112+
)
113+
114+
c1.CloseRead(tt.ctx)
115+
c2.CloseRead(tt.ctx)
116+
117+
ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100)
118+
defer cancel()
119+
120+
err := c1.Ping(ctx)
121+
assert.Success(t, err)
122+
123+
assert.Equal(t, "ping received", true, pingReceived)
124+
assert.Equal(t, "pong received", true, pongReceived)
125+
})
126+
127+
t.Run("pingPongReceived2", func(t *testing.T) {
128+
var pingReceived, pongReceived bool
129+
tt, c1, c2 := newConnTest(t,
130+
&websocket.DialOptions{
131+
OnPongReceived: func(ctx context.Context, payload []byte) {
132+
pongReceived = true
133+
},
134+
}, &websocket.AcceptOptions{
135+
OnPingReceived: func(ctx context.Context, payload []byte) {
136+
pingReceived = true
137+
},
138+
},
139+
)
140+
141+
c1.CloseRead(tt.ctx)
142+
c2.CloseRead(tt.ctx)
143+
144+
ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100)
145+
defer cancel()
146+
147+
err := c2.Ping(ctx)
148+
assert.Success(t, err)
149+
150+
assert.Equal(t, "ping received", true, pingReceived)
151+
assert.Equal(t, "pong received", true, pongReceived)
152+
})
153+
100154
t.Run("concurrentWrite", func(t *testing.T) {
101155
tt, c1, c2 := newConnTest(t, nil, nil)
102156

dial.go

+14
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,18 @@ type DialOptions struct {
4848
// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
4949
// for CompressionContextTakeover.
5050
CompressionThreshold int
51+
52+
// OnPingReceived is an optional callback invoked synchronously when a ping frame is received.
53+
//
54+
// To avoid blocking, process the callback asynchronously using a goroutine.
55+
// If you need to reuse the payload outside the callback, clone the byte slice.
56+
// Any modifications to the payload within the callback will be sent in the subsequent pong frame.
57+
OnPingReceived func(context.Context, []byte)
58+
59+
// OnPongReceived is an optional callback invoked synchronously when a pong frame is received.
60+
//
61+
// To avoid blocking, process the callback asynchronously using a goroutine.
62+
OnPongReceived func(context.Context, []byte)
5163
}
5264

5365
func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context, context.CancelFunc, *DialOptions) {
@@ -163,6 +175,8 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
163175
client: true,
164176
copts: copts,
165177
flateThreshold: opts.CompressionThreshold,
178+
onPingReceived: opts.OnPingReceived,
179+
onPongReceived: opts.OnPongReceived,
166180
br: getBufioReader(rwc),
167181
bw: getBufioWriter(rwc),
168182
}), resp, nil

read.go

+6
Original file line numberDiff line numberDiff line change
@@ -312,8 +312,14 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
312312

313313
switch h.opcode {
314314
case opPing:
315+
if c.onPingReceived != nil {
316+
c.onPingReceived(ctx, b)
317+
}
315318
return c.writeControl(ctx, opPong, b)
316319
case opPong:
320+
if c.onPongReceived != nil {
321+
c.onPongReceived(ctx, b)
322+
}
317323
c.activePingsMu.Lock()
318324
pong, ok := c.activePings[string(b)]
319325
c.activePingsMu.Unlock()

0 commit comments

Comments
 (0)