Skip to content

Commit c4d9214

Browse files
authored
cleanup: clean up TCP calls and use netip (#179)
1 parent 4c35a51 commit c4d9214

File tree

9 files changed

+58
-50
lines changed

9 files changed

+58
-50
lines changed

.github/workflows/go.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,4 @@ jobs:
3333
run: go build -v ./...
3434

3535
- name: Test
36-
run: go test -v -race -benchmem -bench=. ./... -benchtime=100ms
36+
run: go test -race -benchmem -bench=. ./... -benchtime=100ms

cmd/outline-ss-server/metrics_test.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222

2323
"github.com/Jigsaw-Code/outline-ss-server/ipinfo"
2424
"github.com/Jigsaw-Code/outline-ss-server/service/metrics"
25+
"github.com/op/go-logging"
2526
"github.com/prometheus/client_golang/prometheus"
2627
promtest "github.com/prometheus/client_golang/prometheus/testutil"
2728
"github.com/stretchr/testify/require"
@@ -45,6 +46,10 @@ func setNow(t time.Time) {
4546
}
4647
}
4748

49+
func init() {
50+
logging.SetLevel(logging.INFO, "")
51+
}
52+
4853
func TestMethodsDontPanic(t *testing.T) {
4954
ssMetrics := newPrometheusOutlineMetrics(nil, prometheus.NewPedanticRegistry())
5055
proxyMetrics := metrics.ProxyMetrics{
@@ -149,6 +154,7 @@ func BenchmarkCloseTCP(b *testing.B) {
149154
duration := time.Minute
150155
b.ResetTimer()
151156
for i := 0; i < b.N; i++ {
157+
ssMetrics.AddAuthenticatedTCPConnection(addr, accessKey)
152158
ssMetrics.AddClosedTCPConnection(ipinfo, addr, accessKey, status, data, duration)
153159
ssMetrics.AddTCPCipherSearch(true, timeToCipher)
154160
}

internal/integration_test/integration_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"fmt"
2121
"io"
2222
"net"
23+
"net/netip"
2324
"sync"
2425
"testing"
2526
"time"
@@ -107,7 +108,7 @@ func startUDPEchoServer(t testing.TB) (*net.UDPConn, *sync.WaitGroup) {
107108
t.Logf("Failed to read from UDP conn: %v", err)
108109
return
109110
}
110-
conn.WriteTo(buf[:n], clientAddr)
111+
_, err = conn.WriteTo(buf[:n], clientAddr)
111112
if err != nil {
112113
t.Fatalf("Failed to write: %v", err)
113114
}
@@ -335,7 +336,7 @@ func TestUDPEcho(t *testing.T) {
335336
proxyConn.Close()
336337
<-done
337338
// Verify that the expected metrics were reported.
338-
snapshot := cipherList.SnapshotForClientIP(nil)
339+
snapshot := cipherList.SnapshotForClientIP(netip.Addr{})
339340
keyID := snapshot[0].Value.(*service.CipherEntry).ID
340341

341342
if testMetrics.natAdded != 1 {

service/cipher_list.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ package service
1616

1717
import (
1818
"container/list"
19-
"net"
19+
"net/netip"
2020
"sync"
2121

2222
"github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks"
@@ -31,7 +31,7 @@ type CipherEntry struct {
3131
ID string
3232
CryptoKey *shadowsocks.EncryptionKey
3333
SaltGenerator ServerSaltGenerator
34-
lastClientIP net.IP
34+
lastClientIP netip.Addr
3535
}
3636

3737
// MakeCipherEntry constructs a CipherEntry.
@@ -56,8 +56,8 @@ func MakeCipherEntry(id string, cryptoKey *shadowsocks.EncryptionKey, secret str
5656
// snapshotting and moving to front.
5757
type CipherList interface {
5858
// Returns a snapshot of the cipher list optimized for this client IP
59-
SnapshotForClientIP(clientIP net.IP) []*list.Element
60-
MarkUsedByClientIP(e *list.Element, clientIP net.IP)
59+
SnapshotForClientIP(clientIP netip.Addr) []*list.Element
60+
MarkUsedByClientIP(e *list.Element, clientIP netip.Addr)
6161
// Update replaces the current contents of the CipherList with `contents`,
6262
// which is a List of *CipherEntry. Update takes ownership of `contents`,
6363
// which must not be read or written after this call.
@@ -75,12 +75,12 @@ func NewCipherList() CipherList {
7575
return &cipherList{list: list.New()}
7676
}
7777

78-
func matchesIP(e *list.Element, clientIP net.IP) bool {
78+
func matchesIP(e *list.Element, clientIP netip.Addr) bool {
7979
c := e.Value.(*CipherEntry)
80-
return clientIP != nil && clientIP.Equal(c.lastClientIP)
80+
return clientIP != netip.Addr{} && clientIP == c.lastClientIP
8181
}
8282

83-
func (cl *cipherList) SnapshotForClientIP(clientIP net.IP) []*list.Element {
83+
func (cl *cipherList) SnapshotForClientIP(clientIP netip.Addr) []*list.Element {
8484
cl.mu.RLock()
8585
defer cl.mu.RUnlock()
8686
cipherArray := make([]*list.Element, cl.list.Len())
@@ -102,7 +102,7 @@ func (cl *cipherList) SnapshotForClientIP(clientIP net.IP) []*list.Element {
102102
return cipherArray
103103
}
104104

105-
func (cl *cipherList) MarkUsedByClientIP(e *list.Element, clientIP net.IP) {
105+
func (cl *cipherList) MarkUsedByClientIP(e *list.Element, clientIP netip.Addr) {
106106
cl.mu.Lock()
107107
defer cl.mu.Unlock()
108108
cl.list.MoveToFront(e)

service/cipher_list_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,18 @@ package service
1616

1717
import (
1818
"math/rand"
19-
"net"
19+
"net/netip"
2020
"testing"
2121
)
2222

2323
func BenchmarkLocking(b *testing.B) {
24-
var ip net.IP
24+
var ip netip.Addr
2525

2626
ciphers, _ := MakeTestCiphers([]string{"secret"})
2727
b.ResetTimer()
2828
b.RunParallel(func(pb *testing.PB) {
2929
for pb.Next() {
30-
entries := ciphers.SnapshotForClientIP(nil)
30+
entries := ciphers.SnapshotForClientIP(netip.Addr{})
3131
ciphers.MarkUsedByClientIP(entries[0], ip)
3232
}
3333
})
@@ -43,20 +43,20 @@ func BenchmarkSnapshot(b *testing.B) {
4343

4444
// Shuffling simulates the behavior of a real server, where successive
4545
// ciphers are not expected to be nearby in memory.
46-
entries := ciphers.SnapshotForClientIP(nil)
46+
entries := ciphers.SnapshotForClientIP(netip.Addr{})
4747
rand.Shuffle(N, func(i, j int) {
4848
entries[i], entries[j] = entries[j], entries[i]
4949
})
5050
for _, entry := range entries {
5151
// Reorder the list to match the shuffle
5252
// (actually in reverse, but it doesn't matter).
53-
ciphers.MarkUsedByClientIP(entry, nil)
53+
ciphers.MarkUsedByClientIP(entry, netip.Addr{})
5454
}
5555

5656
b.ResetTimer()
5757
b.RunParallel(func(pb *testing.PB) {
5858
for pb.Next() {
59-
ciphers.SnapshotForClientIP(nil)
59+
ciphers.SnapshotForClientIP(netip.Addr{})
6060
}
6161
})
6262
}

service/tcp.go

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"fmt"
2323
"io"
2424
"net"
25+
"net/netip"
2526
"sync"
2627
"syscall"
2728
"time"
@@ -46,19 +47,19 @@ type TCPMetrics interface {
4647
AddTCPProbe(status, drainResult string, port int, clientProxyBytes int64)
4748
}
4849

49-
func remoteIP(conn net.Conn) net.IP {
50+
func remoteIP(conn net.Conn) netip.Addr {
5051
addr := conn.RemoteAddr()
5152
if addr == nil {
52-
return nil
53+
return netip.Addr{}
5354
}
5455
if tcpaddr, ok := addr.(*net.TCPAddr); ok {
55-
return tcpaddr.IP
56+
return tcpaddr.AddrPort().Addr()
5657
}
57-
ipstr, _, err := net.SplitHostPort(addr.String())
58+
addrPort, err := netip.ParseAddrPort(addr.String())
5859
if err == nil {
59-
return net.ParseIP(ipstr)
60+
return addrPort.Addr()
6061
}
61-
return nil
62+
return netip.Addr{}
6263
}
6364

6465
// Wrapper for logger.Debugf during TCP access key searches.
@@ -76,7 +77,7 @@ func debugTCP(cipherID, template string, val interface{}) {
7677
// required = saltSize + 2 + cipher.TagSize, the number of bytes needed to authenticate the connection.
7778
const bytesForKeyFinding = 50
7879

79-
func findAccessKey(clientReader io.Reader, clientIP net.IP, cipherList CipherList) (*CipherEntry, io.Reader, []byte, time.Duration, error) {
80+
func findAccessKey(clientReader io.Reader, clientIP netip.Addr, cipherList CipherList) (*CipherEntry, io.Reader, []byte, time.Duration, error) {
8081
// We snapshot the list because it may be modified while we use it.
8182
ciphers := cipherList.SnapshotForClientIP(clientIP)
8283
firstBytes := make([]byte, bytesForKeyFinding)
@@ -264,7 +265,7 @@ func (h *tcpHandler) Handle(ctx context.Context, clientConn transport.StreamConn
264265
measuredClientConn := metrics.MeasureConn(clientConn, &proxyMetrics.ProxyClient, &proxyMetrics.ClientProxy)
265266
connStart := time.Now()
266267

267-
id, connError := h.handleConnection(ctx, h.port, clientInfo, measuredClientConn, &proxyMetrics)
268+
id, connError := h.handleConnection(ctx, measuredClientConn, &proxyMetrics)
268269

269270
connDuration := time.Since(connStart)
270271
status := "OK"
@@ -327,7 +328,7 @@ func proxyConnection(ctx context.Context, dialer transport.StreamDialer, tgtAddr
327328
return nil
328329
}
329330

330-
func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, clientInfo ipinfo.IPInfo, outerConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, *onet.ConnectionError) {
331+
func (h *tcpHandler) handleConnection(ctx context.Context, outerConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, *onet.ConnectionError) {
331332
// Set a deadline to receive the address to the target.
332333
readDeadline := time.Now().Add(h.readTimeout)
333334
if deadline, ok := ctx.Deadline(); ok {
@@ -341,7 +342,7 @@ func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, cli
341342
id, innerConn, authErr := h.authenticate(outerConn)
342343
if authErr != nil {
343344
// Drain to protect against probing attacks.
344-
h.absorbProbe(listenerPort, outerConn, authErr.Status, proxyMetrics)
345+
h.absorbProbe(outerConn, authErr.Status, proxyMetrics)
345346
return id, authErr
346347
}
347348
h.m.AddAuthenticatedTCPConnection(outerConn.RemoteAddr(), id)
@@ -369,12 +370,12 @@ func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, cli
369370

370371
// Keep the connection open until we hit the authentication deadline to protect against probing attacks
371372
// `proxyMetrics` is a pointer because its value is being mutated by `clientConn`.
372-
func (h *tcpHandler) absorbProbe(listenerPort int, clientConn io.ReadCloser, status string, proxyMetrics *metrics.ProxyMetrics) {
373+
func (h *tcpHandler) absorbProbe(clientConn io.ReadCloser, status string, proxyMetrics *metrics.ProxyMetrics) {
373374
// This line updates proxyMetrics.ClientProxy before it's used in AddTCPProbe.
374375
_, drainErr := io.Copy(io.Discard, clientConn) // drain socket
375376
drainResult := drainErrToString(drainErr)
376377
logger.Debugf("Drain error: %v, drain result: %v", drainErr, drainResult)
377-
h.m.AddTCPProbe(status, drainResult, listenerPort, proxyMetrics.ClientProxy)
378+
h.m.AddTCPProbe(status, drainResult, h.port, proxyMetrics.ClientProxy)
378379
}
379380

380381
func drainErrToString(drainErr error) string {

service/tcp_test.go

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"io"
2222
"math/rand"
2323
"net"
24+
"net/netip"
2425
"sync"
2526
"testing"
2627
"time"
@@ -99,7 +100,7 @@ func BenchmarkTCPFindCipherFail(b *testing.B) {
99100
if err != nil {
100101
b.Fatalf("AcceptTCP failed: %v", err)
101102
}
102-
clientIP := clientConn.RemoteAddr().(*net.TCPAddr).IP
103+
clientIP := clientConn.RemoteAddr().(*net.TCPAddr).AddrPort().Addr()
103104
b.StartTimer()
104105
findAccessKey(clientConn, clientIP, cipherList)
105106
b.StopTimer()
@@ -191,16 +192,16 @@ func BenchmarkTCPFindCipherRepeat(b *testing.B) {
191192
b.Fatal(err)
192193
}
193194
cipherEntries := [numCiphers]*CipherEntry{}
194-
snapshot := cipherList.SnapshotForClientIP(nil)
195+
snapshot := cipherList.SnapshotForClientIP(netip.Addr{})
195196
for cipherNumber, element := range snapshot {
196197
cipherEntries[cipherNumber] = element.Value.(*CipherEntry)
197198
}
198199
for n := 0; n < b.N; n++ {
199200
cipherNumber := byte(n % numCiphers)
200201
reader, writer := io.Pipe()
201-
clientIP := net.IPv4(192, 0, 2, cipherNumber)
202-
addr := &net.TCPAddr{IP: clientIP, Port: 54321}
203-
c := conn{clientAddr: addr, reader: reader, writer: writer}
202+
clientIP := netip.AddrFrom4([4]byte{192, 0, 2, cipherNumber})
203+
addr := netip.AddrPortFrom(clientIP, 54321)
204+
c := conn{clientAddr: net.TCPAddrFromAddrPort(addr), reader: reader, writer: writer}
204205
cipher := cipherEntries[cipherNumber].CryptoKey
205206
go shadowsocks.NewWriter(writer, cipher).Write(makeTestPayload(50))
206207
b.StartTimer()
@@ -345,7 +346,7 @@ func makeClientBytesCoalesced(t *testing.T, cryptoKey *shadowsocks.EncryptionKey
345346
}
346347

347348
func firstCipher(cipherList CipherList) *shadowsocks.EncryptionKey {
348-
snapshot := cipherList.SnapshotForClientIP(nil)
349+
snapshot := cipherList.SnapshotForClientIP(netip.Addr{})
349350
cipherEntry := snapshot[0].Value.(*CipherEntry)
350351
return cipherEntry.CryptoKey
351352
}
@@ -368,7 +369,6 @@ func TestProbeClientBytesBasicTruncated(t *testing.T) {
368369
discardListener, discardWait := startDiscardServer(t)
369370
initialBytes := makeClientBytesBasic(t, cipher, discardListener.Addr().String())
370371
for numBytesToSend := 0; numBytesToSend < len(initialBytes); numBytesToSend++ {
371-
t.Logf("Sending %v bytes", numBytesToSend)
372372
bytesToSend := initialBytes[:numBytesToSend]
373373
err := probe(listener.Addr().(*net.TCPAddr), bytesToSend)
374374
require.NoError(t, err, "Failed for %v bytes sent: %v", numBytesToSend, err)
@@ -405,7 +405,6 @@ func TestProbeClientBytesBasicModified(t *testing.T) {
405405
initialBytes := makeClientBytesBasic(t, cipher, discardListener.Addr().String())
406406
bytesToSend := make([]byte, len(initialBytes))
407407
for byteToModify := 0; byteToModify < len(initialBytes); byteToModify++ {
408-
t.Logf("Modifying byte %v", byteToModify)
409408
copy(bytesToSend, initialBytes)
410409
bytesToSend[byteToModify] = 255 - bytesToSend[byteToModify]
411410
err := probe(listener.Addr().(*net.TCPAddr), bytesToSend)
@@ -442,7 +441,6 @@ func TestProbeClientBytesCoalescedModified(t *testing.T) {
442441
initialBytes := makeClientBytesCoalesced(t, cipher, discardListener.Addr().String())
443442
bytesToSend := make([]byte, len(initialBytes))
444443
for byteToModify := 0; byteToModify < len(initialBytes); byteToModify++ {
445-
t.Logf("Modifying byte %v", byteToModify)
446444
copy(bytesToSend, initialBytes)
447445
bytesToSend[byteToModify] = 255 - bytesToSend[byteToModify]
448446
err := probe(listener.Addr().(*net.TCPAddr), bytesToSend)
@@ -506,7 +504,7 @@ func TestReplayDefense(t *testing.T) {
506504
const testTimeout = 200 * time.Millisecond
507505
authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics)
508506
handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout)
509-
snapshot := cipherList.SnapshotForClientIP(nil)
507+
snapshot := cipherList.SnapshotForClientIP(netip.Addr{})
510508
cipherEntry := snapshot[0].Value.(*CipherEntry)
511509
cipher := cipherEntry.CryptoKey
512510
reader, writer := io.Pipe()
@@ -585,7 +583,7 @@ func TestReverseReplayDefense(t *testing.T) {
585583
const testTimeout = 200 * time.Millisecond
586584
authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics)
587585
handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout)
588-
snapshot := cipherList.SnapshotForClientIP(nil)
586+
snapshot := cipherList.SnapshotForClientIP(netip.Addr{})
589587
cipherEntry := snapshot[0].Value.(*CipherEntry)
590588
cipher := cipherEntry.CryptoKey
591589
reader, writer := io.Pipe()

service/udp.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"errors"
1919
"fmt"
2020
"net"
21+
"net/netip"
2122
"runtime/debug"
2223
"sync"
2324
"time"
@@ -64,7 +65,7 @@ func debugUDPAddr(addr net.Addr, template string, val interface{}) {
6465

6566
// Decrypts src into dst. It tries each cipher until it finds one that authenticates
6667
// correctly. dst and src must not overlap.
67-
func findAccessKeyUDP(clientIP net.IP, dst, src []byte, cipherList CipherList) ([]byte, string, *shadowsocks.EncryptionKey, error) {
68+
func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherList) ([]byte, string, *shadowsocks.EncryptionKey, error) {
6869
// Try each cipher until we find one that authenticates successfully. This assumes that all ciphers are AEAD.
6970
// We snapshot the list because it may be modified while we use it.
7071
snapshot := cipherList.SnapshotForClientIP(clientIP)
@@ -156,7 +157,7 @@ func (h *packetHandler) Handle(clientConn net.PacketConn) {
156157
}
157158
debugUDPAddr(clientAddr, "Got info \"%#v\"", clientInfo)
158159

159-
ip := clientAddr.(*net.UDPAddr).IP
160+
ip := clientAddr.(*net.UDPAddr).AddrPort().Addr()
160161
var textData []byte
161162
var cryptoKey *shadowsocks.EncryptionKey
162163
unpackStart := time.Now()

0 commit comments

Comments
 (0)