15
15
package service
16
16
17
17
import (
18
+ "context"
18
19
"errors"
19
20
"fmt"
20
21
"io"
@@ -98,14 +99,13 @@ func (sl *virtualStreamListener) AcceptStream() (transport.StreamConn, error) {
98
99
99
100
func (sl * virtualStreamListener ) Close () error {
100
101
sl .mu .Lock ()
102
+ defer sl .mu .Unlock ()
103
+
101
104
if sl .acceptCh == nil {
102
- sl .mu .Unlock ()
103
105
return nil
104
106
}
105
107
sl .acceptCh = nil
106
108
close (sl .closeCh )
107
- sl .mu .Unlock ()
108
-
109
109
if sl .onCloseFunc != nil {
110
110
return sl .onCloseFunc ()
111
111
}
@@ -116,47 +116,54 @@ func (sl *virtualStreamListener) Addr() net.Addr {
116
116
return sl .addr
117
117
}
118
118
119
- type packetResponse struct {
120
- n int
121
- addr net.Addr
122
- err error
123
- data []byte
119
+ type readRequest struct {
120
+ buffer []byte
121
+ respCh chan struct { // Use a buffered channel for respCh
122
+ n int
123
+ addr net.Addr
124
+ err error
125
+ }
124
126
}
125
127
126
128
type virtualPacketConn struct {
127
129
net.PacketConn
128
130
mu sync.Mutex // Mutex to protect access to the channels
129
- readCh <- chan packetResponse
131
+ readCh chan readRequest
130
132
closeCh chan struct {}
131
133
onCloseFunc OnCloseFunc
132
134
}
133
135
134
136
func (pc * virtualPacketConn ) ReadFrom (p []byte ) (n int , addr net.Addr , err error ) {
135
- pc .mu .Lock ()
136
- readCh := pc .readCh
137
- pc .mu .Unlock ()
137
+ respCh := make (chan struct {
138
+ n int
139
+ addr net.Addr
140
+ err error
141
+ }, 1 )
138
142
139
- select {
140
- case packetResponse , ok := <- readCh :
141
- if ! ok {
142
- return 0 , nil , net .ErrClosed
143
- }
144
- copy (p , packetResponse .data )
145
- return packetResponse .n , packetResponse .addr , packetResponse .err
146
- case <- pc .closeCh :
143
+ pc .mu .Lock ()
144
+ if pc .readCh == nil {
145
+ pc .mu .Unlock ()
147
146
return 0 , nil , net .ErrClosed
148
147
}
148
+ pc .readCh <- readRequest {
149
+ buffer : p ,
150
+ respCh : respCh ,
151
+ }
152
+ pc .mu .Unlock ()
153
+
154
+ resp := <- respCh
155
+ return resp .n , resp .addr , resp .err
149
156
}
150
157
151
158
func (pc * virtualPacketConn ) Close () error {
152
159
pc .mu .Lock ()
160
+ defer pc .mu .Unlock ()
161
+
153
162
if pc .readCh == nil {
154
- pc .mu .Unlock ()
155
163
return nil
156
164
}
157
165
pc .readCh = nil
158
166
close (pc .closeCh )
159
- pc .mu .Unlock ()
160
167
161
168
if pc .onCloseFunc != nil {
162
169
return pc .onCloseFunc ()
@@ -242,7 +249,8 @@ type multiPacketListener struct {
242
249
mu sync.Mutex
243
250
addr string
244
251
pc net.PacketConn
245
- readCh chan packetResponse
252
+ readCh chan readRequest
253
+ cancel context.CancelFunc
246
254
count uint32
247
255
onCloseFunc OnCloseFunc
248
256
}
@@ -265,16 +273,22 @@ func (m *multiPacketListener) Acquire() (net.PacketConn, error) {
265
273
return nil , err
266
274
}
267
275
m .pc = pc
268
- m .readCh = make (chan packetResponse )
276
+ m .readCh = make (chan readRequest )
277
+ ctx , cancel := context .WithCancel (context .Background ())
278
+ m .cancel = cancel
269
279
go func () {
270
280
for {
271
- buffer := make ([]byte , serverUDPBufferSize )
272
- n , addr , err := pc .ReadFrom (buffer )
273
- if err != nil {
274
- close (m .readCh )
281
+ select {
282
+ case req := <- m .readCh :
283
+ n , addr , err := pc .ReadFrom (req .buffer )
284
+ req .respCh <- struct { // Send the response to the buffered channel
285
+ n int
286
+ addr net.Addr
287
+ err error
288
+ }{n , addr , err }
289
+ case <- ctx .Done ():
275
290
return
276
291
}
277
- m .readCh <- packetResponse {n : n , addr : addr , err : err , data : buffer [:n ]}
278
292
}
279
293
}()
280
294
}
@@ -289,6 +303,7 @@ func (m *multiPacketListener) Acquire() (net.PacketConn, error) {
289
303
defer m .mu .Unlock ()
290
304
m .count --
291
305
if m .count == 0 {
306
+ m .cancel ()
292
307
m .pc .Close ()
293
308
if m .onCloseFunc != nil {
294
309
return m .onCloseFunc ()
0 commit comments