Skip to content

Commit fede4d8

Browse files
committed
Use channels to ensure virtualPacketConns get closed.
1 parent e658b90 commit fede4d8

File tree

1 file changed

+58
-5
lines changed

1 file changed

+58
-5
lines changed

service/listeners.go

+58-5
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,13 @@ func (t *TCPListener) Addr() net.Addr {
6464
return t.ln.Addr()
6565
}
6666

67+
type OnCloseFunc func() error
68+
6769
type acceptResponse struct {
6870
conn transport.StreamConn
6971
err error
7072
}
7173

72-
type OnCloseFunc func() error
73-
7474
type virtualStreamListener struct {
7575
mu sync.Mutex // Mutex to protect access to the channels
7676
addr net.Addr
@@ -119,14 +119,52 @@ func (sl *virtualStreamListener) Addr() net.Addr {
119119
return sl.addr
120120
}
121121

122+
type packetResponse struct {
123+
n int
124+
addr net.Addr
125+
err error
126+
data []byte
127+
}
128+
122129
type virtualPacketConn struct {
123130
net.PacketConn
131+
mu sync.Mutex // Mutex to protect access to the channels
132+
readCh <-chan packetResponse
133+
closeCh chan struct{}
134+
closed bool
124135
onCloseFunc OnCloseFunc
125136
}
126137

127-
func (spc *virtualPacketConn) Close() error {
128-
if spc.onCloseFunc != nil {
129-
return spc.onCloseFunc()
138+
func (pc *virtualPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
139+
pc.mu.Lock()
140+
readCh := pc.readCh
141+
pc.mu.Unlock()
142+
143+
select {
144+
case packetResponse, ok := <-readCh:
145+
if !ok {
146+
return 0, nil, net.ErrClosed
147+
}
148+
copy(p, packetResponse.data)
149+
return packetResponse.n, packetResponse.addr, packetResponse.err
150+
case <-pc.closeCh:
151+
return 0, nil, net.ErrClosed
152+
}
153+
}
154+
155+
func (pc *virtualPacketConn) Close() error {
156+
pc.mu.Lock()
157+
if pc.closed {
158+
pc.mu.Unlock()
159+
return nil
160+
}
161+
pc.closed = true
162+
pc.readCh = nil
163+
close(pc.closeCh)
164+
pc.mu.Unlock()
165+
166+
if pc.onCloseFunc != nil {
167+
return pc.onCloseFunc()
130168
}
131169
return nil
132170
}
@@ -204,6 +242,7 @@ type multiPacketListener struct {
204242
mu sync.Mutex
205243
addr string
206244
pc RefCount[net.PacketConn]
245+
readCh chan packetResponse
207246
onCloseFunc OnCloseFunc
208247
}
209248

@@ -226,6 +265,18 @@ func (m *multiPacketListener) Acquire() (net.PacketConn, error) {
226265
return nil, err
227266
}
228267
m.pc = NewRefCount(pc, m.onCloseFunc)
268+
m.readCh = make(chan packetResponse)
269+
go func() {
270+
for {
271+
buffer := make([]byte, serverUDPBufferSize)
272+
n, addr, err := pc.ReadFrom(buffer)
273+
if err != nil {
274+
close(m.readCh)
275+
return
276+
}
277+
m.readCh <- packetResponse{n: n, addr: addr, err: err, data: buffer[:n]}
278+
}
279+
}()
229280
}
230281
return m.pc, nil
231282
}()
@@ -236,6 +287,8 @@ func (m *multiPacketListener) Acquire() (net.PacketConn, error) {
236287
pc := refCount.Acquire()
237288
return &virtualPacketConn{
238289
PacketConn: pc,
290+
readCh: m.readCh,
291+
closeCh: make(chan struct{}),
239292
onCloseFunc: refCount.Close,
240293
}, nil
241294
}

0 commit comments

Comments
 (0)