Skip to content

Commit 9af3061

Browse files
committed
Clean up IPs
1 parent 4c35a51 commit 9af3061

File tree

7 files changed

+51
-46
lines changed

7 files changed

+51
-46
lines changed

internal/integration_test/integration_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"fmt"
2121
"io"
2222
"net"
23+
"net/netip"
2324
"sync"
2425
"testing"
2526
"time"
@@ -107,7 +108,7 @@ func startUDPEchoServer(t testing.TB) (*net.UDPConn, *sync.WaitGroup) {
107108
t.Logf("Failed to read from UDP conn: %v", err)
108109
return
109110
}
110-
conn.WriteTo(buf[:n], clientAddr)
111+
_, err = conn.WriteTo(buf[:n], clientAddr)
111112
if err != nil {
112113
t.Fatalf("Failed to write: %v", err)
113114
}
@@ -335,7 +336,7 @@ func TestUDPEcho(t *testing.T) {
335336
proxyConn.Close()
336337
<-done
337338
// Verify that the expected metrics were reported.
338-
snapshot := cipherList.SnapshotForClientIP(nil)
339+
snapshot := cipherList.SnapshotForClientIP(netip.Addr{})
339340
keyID := snapshot[0].Value.(*service.CipherEntry).ID
340341

341342
if testMetrics.natAdded != 1 {

service/cipher_list.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ package service
1616

1717
import (
1818
"container/list"
19-
"net"
19+
"net/netip"
2020
"sync"
2121

2222
"github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks"
@@ -31,7 +31,7 @@ type CipherEntry struct {
3131
ID string
3232
CryptoKey *shadowsocks.EncryptionKey
3333
SaltGenerator ServerSaltGenerator
34-
lastClientIP net.IP
34+
lastClientIP netip.Addr
3535
}
3636

3737
// MakeCipherEntry constructs a CipherEntry.
@@ -56,8 +56,8 @@ func MakeCipherEntry(id string, cryptoKey *shadowsocks.EncryptionKey, secret str
5656
// snapshotting and moving to front.
5757
type CipherList interface {
5858
// Returns a snapshot of the cipher list optimized for this client IP
59-
SnapshotForClientIP(clientIP net.IP) []*list.Element
60-
MarkUsedByClientIP(e *list.Element, clientIP net.IP)
59+
SnapshotForClientIP(clientIP netip.Addr) []*list.Element
60+
MarkUsedByClientIP(e *list.Element, clientIP netip.Addr)
6161
// Update replaces the current contents of the CipherList with `contents`,
6262
// which is a List of *CipherEntry. Update takes ownership of `contents`,
6363
// which must not be read or written after this call.
@@ -75,12 +75,12 @@ func NewCipherList() CipherList {
7575
return &cipherList{list: list.New()}
7676
}
7777

78-
func matchesIP(e *list.Element, clientIP net.IP) bool {
78+
func matchesIP(e *list.Element, clientIP netip.Addr) bool {
7979
c := e.Value.(*CipherEntry)
80-
return clientIP != nil && clientIP.Equal(c.lastClientIP)
80+
return clientIP != netip.Addr{} && clientIP == c.lastClientIP
8181
}
8282

83-
func (cl *cipherList) SnapshotForClientIP(clientIP net.IP) []*list.Element {
83+
func (cl *cipherList) SnapshotForClientIP(clientIP netip.Addr) []*list.Element {
8484
cl.mu.RLock()
8585
defer cl.mu.RUnlock()
8686
cipherArray := make([]*list.Element, cl.list.Len())
@@ -102,7 +102,7 @@ func (cl *cipherList) SnapshotForClientIP(clientIP net.IP) []*list.Element {
102102
return cipherArray
103103
}
104104

105-
func (cl *cipherList) MarkUsedByClientIP(e *list.Element, clientIP net.IP) {
105+
func (cl *cipherList) MarkUsedByClientIP(e *list.Element, clientIP netip.Addr) {
106106
cl.mu.Lock()
107107
defer cl.mu.Unlock()
108108
cl.list.MoveToFront(e)

service/cipher_list_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,18 @@ package service
1616

1717
import (
1818
"math/rand"
19-
"net"
19+
"net/netip"
2020
"testing"
2121
)
2222

2323
func BenchmarkLocking(b *testing.B) {
24-
var ip net.IP
24+
var ip netip.Addr
2525

2626
ciphers, _ := MakeTestCiphers([]string{"secret"})
2727
b.ResetTimer()
2828
b.RunParallel(func(pb *testing.PB) {
2929
for pb.Next() {
30-
entries := ciphers.SnapshotForClientIP(nil)
30+
entries := ciphers.SnapshotForClientIP(netip.Addr{})
3131
ciphers.MarkUsedByClientIP(entries[0], ip)
3232
}
3333
})
@@ -43,20 +43,20 @@ func BenchmarkSnapshot(b *testing.B) {
4343

4444
// Shuffling simulates the behavior of a real server, where successive
4545
// ciphers are not expected to be nearby in memory.
46-
entries := ciphers.SnapshotForClientIP(nil)
46+
entries := ciphers.SnapshotForClientIP(netip.Addr{})
4747
rand.Shuffle(N, func(i, j int) {
4848
entries[i], entries[j] = entries[j], entries[i]
4949
})
5050
for _, entry := range entries {
5151
// Reorder the list to match the shuffle
5252
// (actually in reverse, but it doesn't matter).
53-
ciphers.MarkUsedByClientIP(entry, nil)
53+
ciphers.MarkUsedByClientIP(entry, netip.Addr{})
5454
}
5555

5656
b.ResetTimer()
5757
b.RunParallel(func(pb *testing.PB) {
5858
for pb.Next() {
59-
ciphers.SnapshotForClientIP(nil)
59+
ciphers.SnapshotForClientIP(netip.Addr{})
6060
}
6161
})
6262
}

service/tcp.go

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"fmt"
2323
"io"
2424
"net"
25+
"net/netip"
2526
"sync"
2627
"syscall"
2728
"time"
@@ -46,19 +47,19 @@ type TCPMetrics interface {
4647
AddTCPProbe(status, drainResult string, port int, clientProxyBytes int64)
4748
}
4849

49-
func remoteIP(conn net.Conn) net.IP {
50+
func remoteIP(conn net.Conn) netip.Addr {
5051
addr := conn.RemoteAddr()
5152
if addr == nil {
52-
return nil
53+
return netip.Addr{}
5354
}
5455
if tcpaddr, ok := addr.(*net.TCPAddr); ok {
55-
return tcpaddr.IP
56+
return tcpaddr.AddrPort().Addr()
5657
}
57-
ipstr, _, err := net.SplitHostPort(addr.String())
58+
addrPort, err := netip.ParseAddrPort(addr.String())
5859
if err == nil {
59-
return net.ParseIP(ipstr)
60+
return addrPort.Addr()
6061
}
61-
return nil
62+
return netip.Addr{}
6263
}
6364

6465
// Wrapper for logger.Debugf during TCP access key searches.
@@ -76,7 +77,7 @@ func debugTCP(cipherID, template string, val interface{}) {
7677
// required = saltSize + 2 + cipher.TagSize, the number of bytes needed to authenticate the connection.
7778
const bytesForKeyFinding = 50
7879

79-
func findAccessKey(clientReader io.Reader, clientIP net.IP, cipherList CipherList) (*CipherEntry, io.Reader, []byte, time.Duration, error) {
80+
func findAccessKey(clientReader io.Reader, clientIP netip.Addr, cipherList CipherList) (*CipherEntry, io.Reader, []byte, time.Duration, error) {
8081
// We snapshot the list because it may be modified while we use it.
8182
ciphers := cipherList.SnapshotForClientIP(clientIP)
8283
firstBytes := make([]byte, bytesForKeyFinding)
@@ -264,7 +265,7 @@ func (h *tcpHandler) Handle(ctx context.Context, clientConn transport.StreamConn
264265
measuredClientConn := metrics.MeasureConn(clientConn, &proxyMetrics.ProxyClient, &proxyMetrics.ClientProxy)
265266
connStart := time.Now()
266267

267-
id, connError := h.handleConnection(ctx, h.port, clientInfo, measuredClientConn, &proxyMetrics)
268+
id, connError := h.handleConnection(ctx, measuredClientConn, &proxyMetrics)
268269

269270
connDuration := time.Since(connStart)
270271
status := "OK"
@@ -327,7 +328,7 @@ func proxyConnection(ctx context.Context, dialer transport.StreamDialer, tgtAddr
327328
return nil
328329
}
329330

330-
func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, clientInfo ipinfo.IPInfo, outerConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, *onet.ConnectionError) {
331+
func (h *tcpHandler) handleConnection(ctx context.Context, outerConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, *onet.ConnectionError) {
331332
// Set a deadline to receive the address to the target.
332333
readDeadline := time.Now().Add(h.readTimeout)
333334
if deadline, ok := ctx.Deadline(); ok {
@@ -341,7 +342,7 @@ func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, cli
341342
id, innerConn, authErr := h.authenticate(outerConn)
342343
if authErr != nil {
343344
// Drain to protect against probing attacks.
344-
h.absorbProbe(listenerPort, outerConn, authErr.Status, proxyMetrics)
345+
h.absorbProbe(outerConn, authErr.Status, proxyMetrics)
345346
return id, authErr
346347
}
347348
h.m.AddAuthenticatedTCPConnection(outerConn.RemoteAddr(), id)
@@ -369,12 +370,12 @@ func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, cli
369370

370371
// Keep the connection open until we hit the authentication deadline to protect against probing attacks
371372
// `proxyMetrics` is a pointer because its value is being mutated by `clientConn`.
372-
func (h *tcpHandler) absorbProbe(listenerPort int, clientConn io.ReadCloser, status string, proxyMetrics *metrics.ProxyMetrics) {
373+
func (h *tcpHandler) absorbProbe(clientConn io.ReadCloser, status string, proxyMetrics *metrics.ProxyMetrics) {
373374
// This line updates proxyMetrics.ClientProxy before it's used in AddTCPProbe.
374375
_, drainErr := io.Copy(io.Discard, clientConn) // drain socket
375376
drainResult := drainErrToString(drainErr)
376377
logger.Debugf("Drain error: %v, drain result: %v", drainErr, drainResult)
377-
h.m.AddTCPProbe(status, drainResult, listenerPort, proxyMetrics.ClientProxy)
378+
h.m.AddTCPProbe(status, drainResult, h.port, proxyMetrics.ClientProxy)
378379
}
379380

380381
func drainErrToString(drainErr error) string {

service/tcp_test.go

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"io"
2222
"math/rand"
2323
"net"
24+
"net/netip"
2425
"sync"
2526
"testing"
2627
"time"
@@ -99,7 +100,7 @@ func BenchmarkTCPFindCipherFail(b *testing.B) {
99100
if err != nil {
100101
b.Fatalf("AcceptTCP failed: %v", err)
101102
}
102-
clientIP := clientConn.RemoteAddr().(*net.TCPAddr).IP
103+
clientIP := clientConn.RemoteAddr().(*net.TCPAddr).AddrPort().Addr()
103104
b.StartTimer()
104105
findAccessKey(clientConn, clientIP, cipherList)
105106
b.StopTimer()
@@ -191,16 +192,16 @@ func BenchmarkTCPFindCipherRepeat(b *testing.B) {
191192
b.Fatal(err)
192193
}
193194
cipherEntries := [numCiphers]*CipherEntry{}
194-
snapshot := cipherList.SnapshotForClientIP(nil)
195+
snapshot := cipherList.SnapshotForClientIP(netip.Addr{})
195196
for cipherNumber, element := range snapshot {
196197
cipherEntries[cipherNumber] = element.Value.(*CipherEntry)
197198
}
198199
for n := 0; n < b.N; n++ {
199200
cipherNumber := byte(n % numCiphers)
200201
reader, writer := io.Pipe()
201-
clientIP := net.IPv4(192, 0, 2, cipherNumber)
202-
addr := &net.TCPAddr{IP: clientIP, Port: 54321}
203-
c := conn{clientAddr: addr, reader: reader, writer: writer}
202+
clientIP := netip.AddrFrom4([4]byte{192, 0, 2, cipherNumber})
203+
addr := netip.AddrPortFrom(clientIP, 54321)
204+
c := conn{clientAddr: net.TCPAddrFromAddrPort(addr), reader: reader, writer: writer}
204205
cipher := cipherEntries[cipherNumber].CryptoKey
205206
go shadowsocks.NewWriter(writer, cipher).Write(makeTestPayload(50))
206207
b.StartTimer()
@@ -345,7 +346,7 @@ func makeClientBytesCoalesced(t *testing.T, cryptoKey *shadowsocks.EncryptionKey
345346
}
346347

347348
func firstCipher(cipherList CipherList) *shadowsocks.EncryptionKey {
348-
snapshot := cipherList.SnapshotForClientIP(nil)
349+
snapshot := cipherList.SnapshotForClientIP(netip.Addr{})
349350
cipherEntry := snapshot[0].Value.(*CipherEntry)
350351
return cipherEntry.CryptoKey
351352
}
@@ -506,7 +507,7 @@ func TestReplayDefense(t *testing.T) {
506507
const testTimeout = 200 * time.Millisecond
507508
authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics)
508509
handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout)
509-
snapshot := cipherList.SnapshotForClientIP(nil)
510+
snapshot := cipherList.SnapshotForClientIP(netip.Addr{})
510511
cipherEntry := snapshot[0].Value.(*CipherEntry)
511512
cipher := cipherEntry.CryptoKey
512513
reader, writer := io.Pipe()
@@ -585,7 +586,7 @@ func TestReverseReplayDefense(t *testing.T) {
585586
const testTimeout = 200 * time.Millisecond
586587
authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics)
587588
handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout)
588-
snapshot := cipherList.SnapshotForClientIP(nil)
589+
snapshot := cipherList.SnapshotForClientIP(netip.Addr{})
589590
cipherEntry := snapshot[0].Value.(*CipherEntry)
590591
cipher := cipherEntry.CryptoKey
591592
reader, writer := io.Pipe()

service/udp.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"errors"
1919
"fmt"
2020
"net"
21+
"net/netip"
2122
"runtime/debug"
2223
"sync"
2324
"time"
@@ -64,7 +65,7 @@ func debugUDPAddr(addr net.Addr, template string, val interface{}) {
6465

6566
// Decrypts src into dst. It tries each cipher until it finds one that authenticates
6667
// correctly. dst and src must not overlap.
67-
func findAccessKeyUDP(clientIP net.IP, dst, src []byte, cipherList CipherList) ([]byte, string, *shadowsocks.EncryptionKey, error) {
68+
func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherList) ([]byte, string, *shadowsocks.EncryptionKey, error) {
6869
// Try each cipher until we find one that authenticates successfully. This assumes that all ciphers are AEAD.
6970
// We snapshot the list because it may be modified while we use it.
7071
snapshot := cipherList.SnapshotForClientIP(clientIP)
@@ -156,7 +157,7 @@ func (h *packetHandler) Handle(clientConn net.PacketConn) {
156157
}
157158
debugUDPAddr(clientAddr, "Got info \"%#v\"", clientInfo)
158159

159-
ip := clientAddr.(*net.UDPAddr).IP
160+
ip := clientAddr.(*net.UDPAddr).AddrPort().Addr()
160161
var textData []byte
161162
var cryptoKey *shadowsocks.EncryptionKey
162163
unpackStart := time.Now()

service/udp_test.go

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"bytes"
1919
"errors"
2020
"net"
21+
"net/netip"
2122
"sync"
2223
"testing"
2324
"time"
@@ -124,7 +125,7 @@ func (m *natTestMetrics) AddUDPCipherSearch(accessKeyFound bool, timeToCipher ti
124125
// generates when localhost access is attempted
125126
func sendToDiscard(payloads [][]byte, validator onet.TargetIPValidator) *natTestMetrics {
126127
ciphers, _ := MakeTestCiphers([]string{"asdf"})
127-
cipher := ciphers.SnapshotForClientIP(nil)[0].Value.(*CipherEntry).CryptoKey
128+
cipher := ciphers.SnapshotForClientIP(netip.Addr{})[0].Value.(*CipherEntry).CryptoKey
128129
clientConn := makePacketConn()
129130
metrics := &natTestMetrics{}
130131
handler := NewPacketHandler(timeout, ciphers, metrics)
@@ -403,7 +404,7 @@ func BenchmarkUDPUnpackFail(b *testing.B) {
403404
}
404405
testPayload := makeTestPayload(50)
405406
textBuf := make([]byte, serverUDPBufferSize)
406-
testIP := net.ParseIP("192.0.2.1")
407+
testIP := netip.MustParseAddr("192.0.2.1")
407408
b.ResetTimer()
408409
for n := 0; n < b.N; n++ {
409410
findAccessKeyUDP(testIP, textBuf, testPayload, cipherList)
@@ -420,16 +421,16 @@ func BenchmarkUDPUnpackRepeat(b *testing.B) {
420421
}
421422
testBuf := make([]byte, serverUDPBufferSize)
422423
packets := [numCiphers][]byte{}
423-
ips := [numCiphers]net.IP{}
424-
snapshot := cipherList.SnapshotForClientIP(nil)
424+
ips := [numCiphers]netip.Addr{}
425+
snapshot := cipherList.SnapshotForClientIP(netip.Addr{})
425426
for i, element := range snapshot {
426427
packets[i] = make([]byte, 0, serverUDPBufferSize)
427428
plaintext := makeTestPayload(50)
428429
packets[i], err = shadowsocks.Pack(make([]byte, serverUDPBufferSize), plaintext, element.Value.(*CipherEntry).CryptoKey)
429430
if err != nil {
430431
b.Error(err)
431432
}
432-
ips[i] = net.IPv4(192, 0, 2, byte(i))
433+
ips[i] = netip.AddrFrom4([4]byte{192, 0, 2, byte(i)})
433434
}
434435
b.ResetTimer()
435436
for n := 0; n < b.N; n++ {
@@ -452,15 +453,15 @@ func BenchmarkUDPUnpackSharedKey(b *testing.B) {
452453
}
453454
testBuf := make([]byte, serverUDPBufferSize)
454455
plaintext := makeTestPayload(50)
455-
snapshot := cipherList.SnapshotForClientIP(nil)
456+
snapshot := cipherList.SnapshotForClientIP(netip.Addr{})
456457
cryptoKey := snapshot[0].Value.(*CipherEntry).CryptoKey
457458
packet, err := shadowsocks.Pack(make([]byte, serverUDPBufferSize), plaintext, cryptoKey)
458459
require.Nil(b, err)
459460

460461
const numIPs = 100 // Must be <256
461-
ips := [numIPs]net.IP{}
462+
ips := [numIPs]netip.Addr{}
462463
for i := 0; i < numIPs; i++ {
463-
ips[i] = net.IPv4(192, 0, 2, byte(i))
464+
ips[i] = netip.AddrFrom4([4]byte{192, 0, 2, byte(i)})
464465
}
465466
b.ResetTimer()
466467
for n := 0; n < b.N; n++ {

0 commit comments

Comments
 (0)