@@ -124,53 +124,53 @@ type tcpHandler struct {
124
124
m TCPMetrics
125
125
readTimeout time.Duration
126
126
// `replayCache` is a pointer to SSServer.replayCache, to share the cache among all ports.
127
- replayCache * ReplayCache
128
- targetIPValidator onet. TargetIPValidator
127
+ replayCache * ReplayCache
128
+ dialer transport. StreamDialer
129
129
}
130
130
131
131
// NewTCPService creates a TCPService
132
132
// `replayCache` is a pointer to SSServer.replayCache, to share the cache among all ports.
133
133
func NewTCPHandler (port int , ciphers CipherList , replayCache * ReplayCache , m TCPMetrics , timeout time.Duration ) TCPHandler {
134
134
return & tcpHandler {
135
- port : port ,
136
- ciphers : ciphers ,
137
- m : m ,
138
- readTimeout : timeout ,
139
- replayCache : replayCache ,
140
- targetIPValidator : onet . RequirePublicIP ,
135
+ port : port ,
136
+ ciphers : ciphers ,
137
+ m : m ,
138
+ readTimeout : timeout ,
139
+ replayCache : replayCache ,
140
+ dialer : defaultDialer ,
141
141
}
142
142
}
143
143
144
+ var defaultDialer = makeValidatingTCPStreamDialer (onet .RequirePublicIP )
145
+
146
+ func makeValidatingTCPStreamDialer (targetIPValidator onet.TargetIPValidator ) transport.StreamDialer {
147
+ return & transport.TCPStreamDialer {Dialer : net.Dialer {Control : func (network , address string , c syscall.RawConn ) error {
148
+ ip , _ , _ := net .SplitHostPort (address )
149
+ return targetIPValidator (net .ParseIP (ip ))
150
+ }}}
151
+ }
152
+
144
153
// TCPService is a Shadowsocks TCP service that can be started and stopped.
145
154
type TCPHandler interface {
146
155
Handle (ctx context.Context , conn transport.StreamConn )
147
- // SetTargetIPValidator sets the function to be used to validate the target IP addresses.
148
- SetTargetIPValidator ( targetIPValidator onet. TargetIPValidator )
156
+ // SetTargetDialer sets the [transport.StreamDialer] to be used to connect to target addresses.
157
+ SetTargetDialer ( dialer transport. StreamDialer )
149
158
}
150
159
151
- func (s * tcpHandler ) SetTargetIPValidator ( targetIPValidator onet. TargetIPValidator ) {
152
- s .targetIPValidator = targetIPValidator
160
+ func (s * tcpHandler ) SetTargetDialer ( dialer transport. StreamDialer ) {
161
+ s .dialer = dialer
153
162
}
154
163
155
- func dialTarget (tgtAddr socks.Addr , proxyMetrics * metrics.ProxyMetrics , targetIPValidator onet.TargetIPValidator ) (transport.StreamConn , * onet.ConnectionError ) {
156
- var ipError * onet.ConnectionError
157
- dialer := net.Dialer {Control : func (network , address string , c syscall.RawConn ) error {
158
- ip , _ , _ := net .SplitHostPort (address )
159
- ipError = targetIPValidator (net .ParseIP (ip ))
160
- if ipError != nil {
161
- return errors .New (ipError .Message )
162
- }
164
+ func ensureConnectionError (err error , fallbackStatus string , fallbackMsg string ) * onet.ConnectionError {
165
+ if err == nil {
163
166
return nil
164
- }}
165
- tgtConn , err := dialer .Dial ("tcp" , tgtAddr .String ())
166
- if ipError != nil {
167
- return nil , ipError
168
- } else if err != nil {
169
- return nil , onet .NewConnectionError ("ERR_CONNECT" , "Failed to connect to target" , err )
170
167
}
171
- tgtTCPConn := tgtConn .(* net.TCPConn )
172
- tgtTCPConn .SetKeepAlive (true )
173
- return metrics .MeasureConn (tgtTCPConn , & proxyMetrics .ProxyTarget , & proxyMetrics .TargetProxy ), nil
168
+ var connErr * onet.ConnectionError
169
+ if errors .As (err , & connErr ) {
170
+ return connErr
171
+ } else {
172
+ return onet .NewConnectionError (fallbackStatus , fallbackMsg , err )
173
+ }
174
174
}
175
175
176
176
type StreamListener func () (transport.StreamConn , error )
@@ -226,7 +226,7 @@ func (h *tcpHandler) Handle(ctx context.Context, clientConn transport.StreamConn
226
226
measuredClientConn := metrics .MeasureConn (clientConn , & proxyMetrics .ProxyClient , & proxyMetrics .ClientProxy )
227
227
connStart := time .Now ()
228
228
229
- id , connError := h .handleConnection (h .port , measuredClientConn , & proxyMetrics )
229
+ id , connError := h .handleConnection (ctx , h .port , measuredClientConn , & proxyMetrics )
230
230
231
231
connDuration := time .Since (connStart )
232
232
status := "OK"
@@ -239,7 +239,7 @@ func (h *tcpHandler) Handle(ctx context.Context, clientConn transport.StreamConn
239
239
logger .Debugf ("Done with status %v, duration %v" , status , connDuration )
240
240
}
241
241
242
- func (h * tcpHandler ) handleConnection (listenerPort int , clientConn transport.StreamConn , proxyMetrics * metrics.ProxyMetrics ) (string , * onet.ConnectionError ) {
242
+ func (h * tcpHandler ) handleConnection (ctx context. Context , listenerPort int , clientConn transport.StreamConn , proxyMetrics * metrics.ProxyMetrics ) (string , * onet.ConnectionError ) {
243
243
// Set a deadline to receive the address to the target.
244
244
clientConn .SetReadDeadline (time .Now ().Add (h .readTimeout ))
245
245
@@ -275,18 +275,20 @@ func (h *tcpHandler) handleConnection(listenerPort int, clientConn transport.Str
275
275
// 3. Read target address and dial it.
276
276
ssr := shadowsocks .NewReader (clientReader , cipherEntry .CryptoKey )
277
277
tgtAddr , err := socks .ReadAddr (ssr )
278
+
278
279
// Clear the deadline for the target address
279
280
clientConn .SetReadDeadline (time.Time {})
280
281
if err != nil {
281
282
// Drain to prevent a close on cipher error.
282
283
io .Copy (io .Discard , clientConn )
283
284
return id , onet .NewConnectionError ("ERR_READ_ADDRESS" , "Failed to get target address" , err )
284
285
}
285
- tgtConn , dialErr := dialTarget ( tgtAddr , proxyMetrics , h . targetIPValidator )
286
+ tgtConn , dialErr := h . dialer . Dial ( ctx , tgtAddr . String () )
286
287
if dialErr != nil {
287
288
// We don't drain so dial errors and invalid addresses are communicated quickly.
288
- return id , dialErr
289
+ return id , ensureConnectionError ( dialErr , "ERR_CONNECT" , "Failed to connect to target" )
289
290
}
291
+ tgtConn = metrics .MeasureConn (tgtConn , & proxyMetrics .ProxyTarget , & proxyMetrics .TargetProxy )
290
292
defer tgtConn .Close ()
291
293
292
294
// 4. Bridge the client and target connections
0 commit comments