Skip to content

Commit 851ac28

Browse files
committed
Implement channel-based packet read for virtual connections.
1 parent 53b1e96 commit 851ac28

File tree

1 file changed

+44
-29
lines changed

1 file changed

+44
-29
lines changed

service/listeners.go

Lines changed: 44 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package service
1616

1717
import (
18+
"context"
1819
"errors"
1920
"fmt"
2021
"io"
@@ -98,14 +99,13 @@ func (sl *virtualStreamListener) AcceptStream() (transport.StreamConn, error) {
9899

99100
func (sl *virtualStreamListener) Close() error {
100101
sl.mu.Lock()
102+
defer sl.mu.Unlock()
103+
101104
if sl.acceptCh == nil {
102-
sl.mu.Unlock()
103105
return nil
104106
}
105107
sl.acceptCh = nil
106108
close(sl.closeCh)
107-
sl.mu.Unlock()
108-
109109
if sl.onCloseFunc != nil {
110110
return sl.onCloseFunc()
111111
}
@@ -116,47 +116,54 @@ func (sl *virtualStreamListener) Addr() net.Addr {
116116
return sl.addr
117117
}
118118

119-
type packetResponse struct {
120-
n int
121-
addr net.Addr
122-
err error
123-
data []byte
119+
type readRequest struct {
120+
buffer []byte
121+
respCh chan struct { // Use a buffered channel for respCh
122+
n int
123+
addr net.Addr
124+
err error
125+
}
124126
}
125127

126128
type virtualPacketConn struct {
127129
net.PacketConn
128130
mu sync.Mutex // Mutex to protect access to the channels
129-
readCh <-chan packetResponse
131+
readCh chan readRequest
130132
closeCh chan struct{}
131133
onCloseFunc OnCloseFunc
132134
}
133135

134136
func (pc *virtualPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
135-
pc.mu.Lock()
136-
readCh := pc.readCh
137-
pc.mu.Unlock()
137+
respCh := make(chan struct {
138+
n int
139+
addr net.Addr
140+
err error
141+
}, 1)
138142

139-
select {
140-
case packetResponse, ok := <-readCh:
141-
if !ok {
142-
return 0, nil, net.ErrClosed
143-
}
144-
copy(p, packetResponse.data)
145-
return packetResponse.n, packetResponse.addr, packetResponse.err
146-
case <-pc.closeCh:
143+
pc.mu.Lock()
144+
if pc.readCh == nil {
145+
pc.mu.Unlock()
147146
return 0, nil, net.ErrClosed
148147
}
148+
pc.readCh <- readRequest{
149+
buffer: p,
150+
respCh: respCh,
151+
}
152+
pc.mu.Unlock()
153+
154+
resp := <-respCh
155+
return resp.n, resp.addr, resp.err
149156
}
150157

151158
func (pc *virtualPacketConn) Close() error {
152159
pc.mu.Lock()
160+
defer pc.mu.Unlock()
161+
153162
if pc.readCh == nil {
154-
pc.mu.Unlock()
155163
return nil
156164
}
157165
pc.readCh = nil
158166
close(pc.closeCh)
159-
pc.mu.Unlock()
160167

161168
if pc.onCloseFunc != nil {
162169
return pc.onCloseFunc()
@@ -242,7 +249,8 @@ type multiPacketListener struct {
242249
mu sync.Mutex
243250
addr string
244251
pc net.PacketConn
245-
readCh chan packetResponse
252+
readCh chan readRequest
253+
cancel context.CancelFunc
246254
count uint32
247255
onCloseFunc OnCloseFunc
248256
}
@@ -265,16 +273,22 @@ func (m *multiPacketListener) Acquire() (net.PacketConn, error) {
265273
return nil, err
266274
}
267275
m.pc = pc
268-
m.readCh = make(chan packetResponse)
276+
m.readCh = make(chan readRequest)
277+
ctx, cancel := context.WithCancel(context.Background())
278+
m.cancel = cancel
269279
go func() {
270280
for {
271-
buffer := make([]byte, serverUDPBufferSize)
272-
n, addr, err := pc.ReadFrom(buffer)
273-
if err != nil {
274-
close(m.readCh)
281+
select {
282+
case req := <-m.readCh:
283+
n, addr, err := pc.ReadFrom(req.buffer)
284+
req.respCh <- struct { // Send the response to the buffered channel
285+
n int
286+
addr net.Addr
287+
err error
288+
}{n, addr, err}
289+
case <-ctx.Done():
275290
return
276291
}
277-
m.readCh <- packetResponse{n: n, addr: addr, err: err, data: buffer[:n]}
278292
}
279293
}()
280294
}
@@ -289,6 +303,7 @@ func (m *multiPacketListener) Acquire() (net.PacketConn, error) {
289303
defer m.mu.Unlock()
290304
m.count--
291305
if m.count == 0 {
306+
m.cancel()
292307
m.pc.Close()
293308
if m.onCloseFunc != nil {
294309
return m.onCloseFunc()

0 commit comments

Comments
 (0)