Skip to content

Commit e7d30a0

Browse files
authored
refactor: pass a dialer to TCP serving (#150)
1 parent 6fc944e commit e7d30a0

File tree

8 files changed

+149
-66
lines changed

8 files changed

+149
-66
lines changed

internal/integration_test/integration_test.go

+4-5
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ import (
2626
"github.com/Jigsaw-Code/outline-sdk/transport"
2727
"github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks"
2828
"github.com/Jigsaw-Code/outline-ss-server/ipinfo"
29-
onet "github.com/Jigsaw-Code/outline-ss-server/net"
3029
"github.com/Jigsaw-Code/outline-ss-server/service"
3130
"github.com/Jigsaw-Code/outline-ss-server/service/metrics"
3231
sstest "github.com/Jigsaw-Code/outline-ss-server/shadowsocks"
@@ -41,7 +40,7 @@ func init() {
4140
logging.SetLevel(logging.INFO, "")
4241
}
4342

44-
func allowAll(ip net.IP) *onet.ConnectionError {
43+
func allowAll(ip net.IP) error {
4544
// Allow access to localhost so that we can run integration tests with
4645
// an actual destination server.
4746
return nil
@@ -114,7 +113,7 @@ func TestTCPEcho(t *testing.T) {
114113
replayCache := service.NewReplayCache(5)
115114
const testTimeout = 200 * time.Millisecond
116115
handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, cipherList, &replayCache, &service.NoOpTCPMetrics{}, testTimeout)
117-
handler.SetTargetIPValidator(allowAll)
116+
handler.SetTargetDialer(&transport.TCPStreamDialer{})
118117
done := make(chan struct{})
119118
go func() {
120119
service.StreamServe(func() (transport.StreamConn, error) { return proxyListener.AcceptTCP() }, handler.Handle)
@@ -362,7 +361,7 @@ func BenchmarkTCPThroughput(b *testing.B) {
362361
}
363362
const testTimeout = 200 * time.Millisecond
364363
handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, cipherList, nil, &service.NoOpTCPMetrics{}, testTimeout)
365-
handler.SetTargetIPValidator(allowAll)
364+
handler.SetTargetDialer(&transport.TCPStreamDialer{})
366365
done := make(chan struct{})
367366
go func() {
368367
service.StreamServe(service.WrapStreamListener(proxyListener.AcceptTCP), handler.Handle)
@@ -424,7 +423,7 @@ func BenchmarkTCPMultiplexing(b *testing.B) {
424423
replayCache := service.NewReplayCache(service.MaxCapacity)
425424
const testTimeout = 200 * time.Millisecond
426425
handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, cipherList, &replayCache, &service.NoOpTCPMetrics{}, testTimeout)
427-
handler.SetTargetIPValidator(allowAll)
426+
handler.SetTargetDialer(&transport.TCPStreamDialer{})
428427
done := make(chan struct{})
429428
go func() {
430429
service.StreamServe(service.WrapStreamListener(proxyListener.AcceptTCP), handler.Handle)

net/error.go

+20
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,23 @@ type ConnectionError struct {
2424
func NewConnectionError(status, message string, cause error) *ConnectionError {
2525
return &ConnectionError{Status: status, Message: message, Cause: cause}
2626
}
27+
28+
func (e *ConnectionError) Error() string {
29+
if e == nil {
30+
return "<nil>"
31+
}
32+
msg := e.Message
33+
if len(e.Status) > 0 {
34+
msg += " [" + e.Status + "]"
35+
}
36+
if e.Cause != nil {
37+
msg += ": " + e.Cause.Error()
38+
}
39+
return msg
40+
}
41+
42+
func (e *ConnectionError) Unwrap() error {
43+
return e.Cause
44+
}
45+
46+
var _ error = (*ConnectionError)(nil)

net/error_test.go

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// Copyright 2019 Jigsaw Operations LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package net
16+
17+
import (
18+
"errors"
19+
"fmt"
20+
"testing"
21+
22+
"github.com/stretchr/testify/require"
23+
)
24+
25+
func TestConnectionErrorUnwrapCause(t *testing.T) {
26+
cause := errors.New("cause")
27+
err := &ConnectionError{Cause: cause}
28+
require.Equal(t, cause, err.Unwrap())
29+
require.ErrorIs(t, err, cause)
30+
}
31+
32+
func TestConnectionErrorString(t *testing.T) {
33+
require.Equal(t, "example message", (&ConnectionError{Message: "example message"}).Error())
34+
require.Equal(t, "example message [ERR_EXAMPLE]", (&ConnectionError{Message: "example message", Status: "ERR_EXAMPLE"}).Error())
35+
36+
cause := errors.New("cause")
37+
err := &ConnectionError{Status: "ERR_EXAMPLE", Message: "example message", Cause: cause}
38+
require.Equal(t, "example message [ERR_EXAMPLE]: cause", err.Error())
39+
}
40+
41+
func TestConnectionErrorFromUnwrap(t *testing.T) {
42+
connErr := &ConnectionError{Message: "connection error"}
43+
topErr := fmt.Errorf("top error: %w", connErr)
44+
require.NotEqual(t, topErr, connErr)
45+
require.ErrorIs(t, topErr, connErr)
46+
var unwrapped *ConnectionError
47+
require.True(t, errors.As(topErr, &unwrapped))
48+
require.Equal(t, connErr, unwrapped)
49+
}

net/private_net.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,11 @@ func IsPrivateAddress(ip net.IP) bool {
4848
}
4949

5050
// TargetIPValidator is a type alias for checking if an IP is allowed.
51-
type TargetIPValidator = func(net.IP) *ConnectionError
51+
type TargetIPValidator = func(net.IP) error
5252

5353
// RequirePublicIP returns an error if the destination IP is not a
5454
// standard public IP.
55-
func RequirePublicIP(ip net.IP) *ConnectionError {
55+
func RequirePublicIP(ip net.IP) error {
5656
if !ip.IsGlobalUnicast() {
5757
return NewConnectionError("ERR_ADDRESS_INVALID", fmt.Sprintf("Address is not global unicast: %s", ip.String()), nil)
5858
}

net/private_net_test.go

+34-20
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,11 @@
1515
package net
1616

1717
import (
18+
"errors"
1819
"net"
1920
"testing"
21+
22+
"github.com/stretchr/testify/assert"
2023
)
2124

2225
var privateAddressTests = []struct {
@@ -46,39 +49,50 @@ func TestIsLanAddress(t *testing.T) {
4649
}
4750

4851
func TestRequirePublicIP(t *testing.T) {
49-
if err := RequirePublicIP(net.ParseIP("8.8.8.8")); err != nil {
50-
t.Error(err)
51-
}
52+
var err error
53+
54+
assert.Nil(t, RequirePublicIP(net.ParseIP("8.8.8.8")))
5255

5356
if err := RequirePublicIP(net.ParseIP("2001:4860:4860::8888")); err != nil {
5457
t.Error(err)
5558
}
5659

57-
err := RequirePublicIP(net.ParseIP("192.168.0.23"))
58-
if err == nil {
59-
t.Error("Expected error")
60-
} else if err.Status != "ERR_ADDRESS_PRIVATE" {
61-
t.Errorf("Wrong status %s", err.Status)
60+
err = RequirePublicIP(net.ParseIP("192.168.0.23"))
61+
if assert.NotNil(t, err) {
62+
var connErr *ConnectionError
63+
if assert.IsType(t, connErr, err) && assert.True(t, errors.As(err, &connErr)) {
64+
assert.Equal(t, "ERR_ADDRESS_PRIVATE", connErr.Status)
65+
}
6266
}
6367

6468
err = RequirePublicIP(net.ParseIP("::1"))
65-
if err == nil {
66-
t.Error("Expected error")
67-
} else if err.Status != "ERR_ADDRESS_INVALID" {
68-
t.Errorf("Wrong status %s", err.Status)
69+
if assert.NotNil(t, err) {
70+
var connErr *ConnectionError
71+
if assert.IsType(t, connErr, err) && assert.True(t, errors.As(err, &connErr)) {
72+
assert.Equal(t, "ERR_ADDRESS_INVALID", connErr.Status)
73+
}
6974
}
7075

7176
err = RequirePublicIP(net.ParseIP("224.0.0.251"))
72-
if err == nil {
73-
t.Error("Expected error")
74-
} else if err.Status != "ERR_ADDRESS_INVALID" {
75-
t.Errorf("Wrong status %s", err.Status)
77+
if assert.NotNil(t, err) {
78+
var connErr *ConnectionError
79+
if assert.IsType(t, connErr, err) && assert.True(t, errors.As(err, &connErr)) {
80+
assert.Equal(t, "ERR_ADDRESS_INVALID", connErr.Status)
81+
}
7682
}
7783

7884
err = RequirePublicIP(net.ParseIP("ff02::fb"))
79-
if err == nil {
80-
t.Error("Expected error")
81-
} else if err.Status != "ERR_ADDRESS_INVALID" {
82-
t.Errorf("Wrong status %s", err.Status)
85+
if assert.NotNil(t, err) {
86+
var connErr *ConnectionError
87+
if assert.IsType(t, connErr, err) && assert.True(t, errors.As(err, &connErr)) {
88+
assert.Equal(t, "ERR_ADDRESS_INVALID", connErr.Status)
89+
}
8390
}
8491
}
92+
93+
func TestRequirePublicIPInterface(t *testing.T) {
94+
var err error
95+
err = RequirePublicIP(net.ParseIP("8.8.8.8"))
96+
assert.True(t, err == nil)
97+
assert.Equal(t, nil, err)
98+
}

service/tcp.go

+35-33
Original file line numberDiff line numberDiff line change
@@ -124,53 +124,53 @@ type tcpHandler struct {
124124
m TCPMetrics
125125
readTimeout time.Duration
126126
// `replayCache` is a pointer to SSServer.replayCache, to share the cache among all ports.
127-
replayCache *ReplayCache
128-
targetIPValidator onet.TargetIPValidator
127+
replayCache *ReplayCache
128+
dialer transport.StreamDialer
129129
}
130130

131131
// NewTCPService creates a TCPService
132132
// `replayCache` is a pointer to SSServer.replayCache, to share the cache among all ports.
133133
func NewTCPHandler(port int, ciphers CipherList, replayCache *ReplayCache, m TCPMetrics, timeout time.Duration) TCPHandler {
134134
return &tcpHandler{
135-
port: port,
136-
ciphers: ciphers,
137-
m: m,
138-
readTimeout: timeout,
139-
replayCache: replayCache,
140-
targetIPValidator: onet.RequirePublicIP,
135+
port: port,
136+
ciphers: ciphers,
137+
m: m,
138+
readTimeout: timeout,
139+
replayCache: replayCache,
140+
dialer: defaultDialer,
141141
}
142142
}
143143

144+
var defaultDialer = makeValidatingTCPStreamDialer(onet.RequirePublicIP)
145+
146+
func makeValidatingTCPStreamDialer(targetIPValidator onet.TargetIPValidator) transport.StreamDialer {
147+
return &transport.TCPStreamDialer{Dialer: net.Dialer{Control: func(network, address string, c syscall.RawConn) error {
148+
ip, _, _ := net.SplitHostPort(address)
149+
return targetIPValidator(net.ParseIP(ip))
150+
}}}
151+
}
152+
144153
// TCPService is a Shadowsocks TCP service that can be started and stopped.
145154
type TCPHandler interface {
146155
Handle(ctx context.Context, conn transport.StreamConn)
147-
// SetTargetIPValidator sets the function to be used to validate the target IP addresses.
148-
SetTargetIPValidator(targetIPValidator onet.TargetIPValidator)
156+
// SetTargetDialer sets the [transport.StreamDialer] to be used to connect to target addresses.
157+
SetTargetDialer(dialer transport.StreamDialer)
149158
}
150159

151-
func (s *tcpHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) {
152-
s.targetIPValidator = targetIPValidator
160+
func (s *tcpHandler) SetTargetDialer(dialer transport.StreamDialer) {
161+
s.dialer = dialer
153162
}
154163

155-
func dialTarget(tgtAddr socks.Addr, proxyMetrics *metrics.ProxyMetrics, targetIPValidator onet.TargetIPValidator) (transport.StreamConn, *onet.ConnectionError) {
156-
var ipError *onet.ConnectionError
157-
dialer := net.Dialer{Control: func(network, address string, c syscall.RawConn) error {
158-
ip, _, _ := net.SplitHostPort(address)
159-
ipError = targetIPValidator(net.ParseIP(ip))
160-
if ipError != nil {
161-
return errors.New(ipError.Message)
162-
}
164+
func ensureConnectionError(err error, fallbackStatus string, fallbackMsg string) *onet.ConnectionError {
165+
if err == nil {
163166
return nil
164-
}}
165-
tgtConn, err := dialer.Dial("tcp", tgtAddr.String())
166-
if ipError != nil {
167-
return nil, ipError
168-
} else if err != nil {
169-
return nil, onet.NewConnectionError("ERR_CONNECT", "Failed to connect to target", err)
170167
}
171-
tgtTCPConn := tgtConn.(*net.TCPConn)
172-
tgtTCPConn.SetKeepAlive(true)
173-
return metrics.MeasureConn(tgtTCPConn, &proxyMetrics.ProxyTarget, &proxyMetrics.TargetProxy), nil
168+
var connErr *onet.ConnectionError
169+
if errors.As(err, &connErr) {
170+
return connErr
171+
} else {
172+
return onet.NewConnectionError(fallbackStatus, fallbackMsg, err)
173+
}
174174
}
175175

176176
type StreamListener func() (transport.StreamConn, error)
@@ -226,7 +226,7 @@ func (h *tcpHandler) Handle(ctx context.Context, clientConn transport.StreamConn
226226
measuredClientConn := metrics.MeasureConn(clientConn, &proxyMetrics.ProxyClient, &proxyMetrics.ClientProxy)
227227
connStart := time.Now()
228228

229-
id, connError := h.handleConnection(h.port, measuredClientConn, &proxyMetrics)
229+
id, connError := h.handleConnection(ctx, h.port, measuredClientConn, &proxyMetrics)
230230

231231
connDuration := time.Since(connStart)
232232
status := "OK"
@@ -239,7 +239,7 @@ func (h *tcpHandler) Handle(ctx context.Context, clientConn transport.StreamConn
239239
logger.Debugf("Done with status %v, duration %v", status, connDuration)
240240
}
241241

242-
func (h *tcpHandler) handleConnection(listenerPort int, clientConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, *onet.ConnectionError) {
242+
func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, clientConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, *onet.ConnectionError) {
243243
// Set a deadline to receive the address to the target.
244244
clientConn.SetReadDeadline(time.Now().Add(h.readTimeout))
245245

@@ -275,18 +275,20 @@ func (h *tcpHandler) handleConnection(listenerPort int, clientConn transport.Str
275275
// 3. Read target address and dial it.
276276
ssr := shadowsocks.NewReader(clientReader, cipherEntry.CryptoKey)
277277
tgtAddr, err := socks.ReadAddr(ssr)
278+
278279
// Clear the deadline for the target address
279280
clientConn.SetReadDeadline(time.Time{})
280281
if err != nil {
281282
// Drain to prevent a close on cipher error.
282283
io.Copy(io.Discard, clientConn)
283284
return id, onet.NewConnectionError("ERR_READ_ADDRESS", "Failed to get target address", err)
284285
}
285-
tgtConn, dialErr := dialTarget(tgtAddr, proxyMetrics, h.targetIPValidator)
286+
tgtConn, dialErr := h.dialer.Dial(ctx, tgtAddr.String())
286287
if dialErr != nil {
287288
// We don't drain so dial errors and invalid addresses are communicated quickly.
288-
return id, dialErr
289+
return id, ensureConnectionError(dialErr, "ERR_CONNECT", "Failed to connect to target")
289290
}
291+
tgtConn = metrics.MeasureConn(tgtConn, &proxyMetrics.ProxyTarget, &proxyMetrics.TargetProxy)
290292
defer tgtConn.Close()
291293

292294
// 4. Bridge the client and target connections

service/tcp_test.go

+4-5
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ import (
2828
"github.com/Jigsaw-Code/outline-sdk/transport"
2929
"github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks"
3030
"github.com/Jigsaw-Code/outline-ss-server/ipinfo"
31-
onet "github.com/Jigsaw-Code/outline-ss-server/net"
3231
"github.com/Jigsaw-Code/outline-ss-server/service/metrics"
3332
logging "github.com/op/go-logging"
3433
"github.com/shadowsocks/go-shadowsocks2/socks"
@@ -39,7 +38,7 @@ func init() {
3938
logging.SetLevel(logging.INFO, "")
4039
}
4140

42-
func allowAll(ip net.IP) *onet.ConnectionError {
41+
func allowAll(ip net.IP) error {
4342
// Allow access to localhost so that we can run integration tests with
4443
// an actual destination server.
4544
return nil
@@ -353,7 +352,7 @@ func TestProbeClientBytesBasicTruncated(t *testing.T) {
353352
cipher := firstCipher(cipherList)
354353
testMetrics := &probeTestMetrics{}
355354
handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, cipherList, nil, testMetrics, 200*time.Millisecond)
356-
handler.SetTargetIPValidator(allowAll)
355+
handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll))
357356
done := make(chan struct{})
358357
go func() {
359358
StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle)
@@ -388,7 +387,7 @@ func TestProbeClientBytesBasicModified(t *testing.T) {
388387
cipher := firstCipher(cipherList)
389388
testMetrics := &probeTestMetrics{}
390389
handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, cipherList, nil, testMetrics, 200*time.Millisecond)
391-
handler.SetTargetIPValidator(allowAll)
390+
handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll))
392391
done := make(chan struct{})
393392
go func() {
394393
StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle)
@@ -424,7 +423,7 @@ func TestProbeClientBytesCoalescedModified(t *testing.T) {
424423
cipher := firstCipher(cipherList)
425424
testMetrics := &probeTestMetrics{}
426425
handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, cipherList, nil, testMetrics, 200*time.Millisecond)
427-
handler.SetTargetIPValidator(allowAll)
426+
handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll))
428427
done := make(chan struct{})
429428
go func() {
430429
StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle)

service/udp.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ func (h *packetHandler) validatePacket(textData []byte) ([]byte, *net.UDPAddr, *
230230
return nil, nil, onet.NewConnectionError("ERR_RESOLVE_ADDRESS", fmt.Sprintf("Failed to resolve target address %v", tgtAddr), err)
231231
}
232232
if err := h.targetIPValidator(tgtUDPAddr.IP); err != nil {
233-
return nil, nil, err
233+
return nil, nil, ensureConnectionError(err, "ERR_ADDRESS_INVALID", "invalid address")
234234
}
235235

236236
payload := textData[len(tgtAddr):]

0 commit comments

Comments
 (0)