From bcf7e53bb7466544246da20f8e8c858ffbf53e7a Mon Sep 17 00:00:00 2001 From: sbruens Date: Wed, 26 Jun 2024 12:02:40 -0400 Subject: [PATCH 1/7] refactor: modularize UDP connection handling --- service/udp.go | 171 +++++++++++++++++++++----------------------- service/udp_test.go | 10 ++- 2 files changed, 89 insertions(+), 92 deletions(-) diff --git a/service/udp.go b/service/udp.go index 4830e302..57a0efb7 100644 --- a/service/udp.go +++ b/service/udp.go @@ -87,6 +87,7 @@ func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherLis type packetHandler struct { natTimeout time.Duration ciphers CipherList + nm *natmap m UDPMetrics targetIPValidator onet.TargetIPValidator } @@ -113,108 +114,94 @@ func (h *packetHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPVali func (h *packetHandler) Handle(clientConn net.PacketConn) { var running sync.WaitGroup - nm := newNATmap(h.natTimeout, h.m, &running) - defer nm.Close() - cipherBuf := make([]byte, serverUDPBufferSize) - textBuf := make([]byte, serverUDPBufferSize) + h.nm = newNATmap(h.natTimeout, h.m, &running) + defer h.nm.Close() for { - clientProxyBytes, clientAddr, err := clientConn.ReadFrom(cipherBuf) - if errors.Is(err, net.ErrClosed) { - break + status := "OK" + keyID, clientInfo, clientProxyBytes, proxyTargetBytes, connErr := h.handleConnection(clientConn) + if connErr != nil { + if errors.Is(connErr.Cause, net.ErrClosed) { + break + } + logger.Debugf("UDP Error: %v: %v", connErr.Message, connErr.Cause) + status = connErr.Status } + h.m.AddUDPPacketFromClient(clientInfo, keyID, status, clientProxyBytes, proxyTargetBytes) + } +} - var clientInfo ipinfo.IPInfo - keyID := "" - var proxyTargetBytes int +func (h *packetHandler) authenticate(clientConn net.PacketConn) (*natconn, []byte, int, *onet.ConnectionError) { + cipherBuf := make([]byte, serverUDPBufferSize) + textBuf := make([]byte, serverUDPBufferSize) + clientProxyBytes, clientAddr, err := clientConn.ReadFrom(cipherBuf) + if err != nil { + return nil, nil, 0, onet.NewConnectionError("ERR_READ", "Failed to read from client", err) + } - connError := func() (connError *onet.ConnectionError) { - defer func() { - if r := recover(); r != nil { - logger.Errorf("Panic in UDP loop: %v. Continuing to listen.", r) - debug.PrintStack() - } - }() + if logger.IsEnabledFor(logging.DEBUG) { + defer logger.Debugf("UDP(%v): done", clientAddr) + logger.Debugf("UDP(%v): Outbound packet has %d bytes", clientAddr, clientProxyBytes) + } - // Error from ReadFrom - if err != nil { - return onet.NewConnectionError("ERR_READ", "Failed to read from client", err) - } - if logger.IsEnabledFor(logging.DEBUG) { - defer logger.Debugf("UDP(%v): done", clientAddr) - logger.Debugf("UDP(%v): Outbound packet has %d bytes", clientAddr, clientProxyBytes) - } + targetConn := h.nm.Get(clientAddr.String()) + remoteIP := clientAddr.(*net.UDPAddr).AddrPort().Addr() - cipherData := cipherBuf[:clientProxyBytes] - var payload []byte - var tgtUDPAddr *net.UDPAddr - targetConn := nm.Get(clientAddr.String()) - if targetConn == nil { - var locErr error - clientInfo, locErr = ipinfo.GetIPInfoFromAddr(h.m, clientAddr) - if locErr != nil { - logger.Warningf("Failed client info lookup: %v", locErr) - } - debugUDPAddr(clientAddr, "Got info \"%#v\"", clientInfo) - - ip := clientAddr.(*net.UDPAddr).AddrPort().Addr() - var textData []byte - var cryptoKey *shadowsocks.EncryptionKey - unpackStart := time.Now() - textData, keyID, cryptoKey, err = findAccessKeyUDP(ip, textBuf, cipherData, h.ciphers) - timeToCipher := time.Since(unpackStart) - h.m.AddUDPCipherSearch(err == nil, timeToCipher) - - if err != nil { - return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack initial packet", err) - } + unpackStart := time.Now() + textData, keyID, cryptoKey, keyErr := findAccessKeyUDP(remoteIP, textBuf, cipherBuf[:clientProxyBytes], h.ciphers) + timeToCipher := time.Since(unpackStart) + h.m.AddUDPCipherSearch(err == nil, timeToCipher) + if keyErr != nil { + return targetConn, nil, 0, onet.NewConnectionError("ERR_CIPHER", "Failed to find a valid cipher", keyErr) + } - var onetErr *onet.ConnectionError - if payload, tgtUDPAddr, onetErr = h.validatePacket(textData); onetErr != nil { - return onetErr - } + if targetConn != nil { + return targetConn, textData, clientProxyBytes, nil + } - udpConn, err := net.ListenPacket("udp", "") - if err != nil { - return onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create UDP socket", err) - } - targetConn = nm.Add(clientAddr, clientConn, cryptoKey, udpConn, clientInfo, keyID) - } else { - clientInfo = targetConn.clientInfo + udpConn, err := net.ListenPacket("udp", "") + if err != nil { + return targetConn, textData, clientProxyBytes, onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create UDP socket", err) + } - unpackStart := time.Now() - textData, err := shadowsocks.Unpack(nil, cipherData, targetConn.cryptoKey) - timeToCipher := time.Since(unpackStart) - h.m.AddUDPCipherSearch(err == nil, timeToCipher) + clientInfo, locErr := ipinfo.GetIPInfoFromAddr(h.m, clientAddr) + if locErr != nil { + logger.Warningf("Failed client info lookup: %v", locErr) + } + debugUDPAddr(clientAddr, "Got info \"%#v\"", clientInfo) - if err != nil { - return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack data from client", err) - } + targetConn = h.nm.Add(clientAddr, clientConn, cryptoKey, udpConn, clientInfo, keyID) + return targetConn, textData, clientProxyBytes, nil +} - // The key ID is known with confidence once decryption succeeds. - keyID = targetConn.keyID +func (h *packetHandler) handleConnection(clientConn net.PacketConn) (string, ipinfo.IPInfo, int, int, *onet.ConnectionError) { + defer func() { + if r := recover(); r != nil { + logger.Errorf("Panic in UDP loop: %v. Continuing to listen.", r) + debug.PrintStack() + } + }() - var onetErr *onet.ConnectionError - if payload, tgtUDPAddr, onetErr = h.validatePacket(textData); onetErr != nil { - return onetErr - } - } + targetConn, textData, clientProxyBytes, authErr := h.authenticate(clientConn) + if authErr != nil { + var clientInfo ipinfo.IPInfo + if targetConn != nil { + clientInfo = targetConn.clientInfo + } + return "", clientInfo, clientProxyBytes, 0, authErr + } - debugUDPAddr(clientAddr, "Proxy exit %v", targetConn.LocalAddr()) - proxyTargetBytes, err = targetConn.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature - if err != nil { - return onet.NewConnectionError("ERR_WRITE", "Failed to write to target", err) - } - return nil - }() + payload, tgtUDPAddr, onetErr := h.validatePacket(textData) + if onetErr != nil { + return targetConn.keyID, targetConn.clientInfo, clientProxyBytes, 0, onetErr + } - status := "OK" - if connError != nil { - logger.Debugf("UDP Error: %v: %v", connError.Message, connError.Cause) - status = connError.Status - } - h.m.AddUDPPacketFromClient(clientInfo, keyID, status, clientProxyBytes, proxyTargetBytes) + debugUDPAddr(targetConn.clientAddr, "Proxy exit %v", targetConn.LocalAddr()) + proxyTargetBytes, err := targetConn.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature + if err != nil { + return targetConn.keyID, targetConn.clientInfo, clientProxyBytes, proxyTargetBytes, onet.NewConnectionError("ERR_WRITE", "Failed to write to target", err) } + return targetConn.keyID, targetConn.clientInfo, clientProxyBytes, proxyTargetBytes, nil } // Given the decrypted contents of a UDP packet, return @@ -245,8 +232,9 @@ func isDNS(addr net.Addr) bool { type natconn struct { net.PacketConn - cryptoKey *shadowsocks.EncryptionKey - keyID string + cryptoKey *shadowsocks.EncryptionKey + keyID string + clientAddr net.Addr // We store the client information in the NAT map to avoid recomputing it // for every downstream packet in a UDP-based connection. clientInfo ipinfo.IPInfo @@ -327,11 +315,12 @@ func (m *natmap) Get(key string) *natconn { return m.keyConn[key] } -func (m *natmap) set(key string, pc net.PacketConn, cryptoKey *shadowsocks.EncryptionKey, keyID string, clientInfo ipinfo.IPInfo) *natconn { +func (m *natmap) set(clientAddr net.Addr, pc net.PacketConn, cryptoKey *shadowsocks.EncryptionKey, keyID string, clientInfo ipinfo.IPInfo) *natconn { entry := &natconn{ PacketConn: pc, cryptoKey: cryptoKey, keyID: keyID, + clientAddr: clientAddr, clientInfo: clientInfo, defaultTimeout: m.timeout, } @@ -339,7 +328,7 @@ func (m *natmap) set(key string, pc net.PacketConn, cryptoKey *shadowsocks.Encry m.Lock() defer m.Unlock() - m.keyConn[key] = entry + m.keyConn[clientAddr.String()] = entry return entry } @@ -356,7 +345,7 @@ func (m *natmap) del(key string) net.PacketConn { } func (m *natmap) Add(clientAddr net.Addr, clientConn net.PacketConn, cryptoKey *shadowsocks.EncryptionKey, targetConn net.PacketConn, clientInfo ipinfo.IPInfo, keyID string) *natconn { - entry := m.set(clientAddr.String(), targetConn, cryptoKey, keyID, clientInfo) + entry := m.set(clientAddr, targetConn, cryptoKey, keyID, clientInfo) m.metrics.AddUDPNatEntry(clientAddr, keyID) m.running.Add(1) diff --git a/service/udp_test.go b/service/udp_test.go index f94238c5..e461280e 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -162,12 +162,20 @@ func TestIPFilter(t *testing.T) { t.Run("Localhost allowed", func(t *testing.T) { metrics := sendToDiscard(payloads, allowAll) + assert.Equal(t, metrics.natEntriesAdded, 1, "Expected 1 NAT entry, not %d", metrics.natEntriesAdded) + assert.Equal(t, 2, len(metrics.upstreamPackets), "Expected 2 reports, not %v", metrics.upstreamPackets) + for _, report := range metrics.upstreamPackets { + assert.Greater(t, report.clientProxyBytes, 0, "Expected nonzero input packet size") + assert.Greater(t, report.proxyTargetBytes, 0, "Expected nonzero bytes to be sent for allowed packet") + assert.Equal(t, report.accessKey, "id-0", "Unexpected access key: %s", report.accessKey) + } }) t.Run("Localhost not allowed", func(t *testing.T) { metrics := sendToDiscard(payloads, onet.RequirePublicIP) - assert.Equal(t, 0, metrics.natEntriesAdded, "Unexpected NAT entry on rejected packet") + + assert.Equal(t, metrics.natEntriesAdded, 1, "Expected 1 NAT entry, not %d", metrics.natEntriesAdded) assert.Equal(t, 2, len(metrics.upstreamPackets), "Expected 2 reports, not %v", metrics.upstreamPackets) for _, report := range metrics.upstreamPackets { assert.Greater(t, report.clientProxyBytes, 0, "Expected nonzero input packet size") From 459583ecc81d6bf55240decd947af474f12609f0 Mon Sep 17 00:00:00 2001 From: sbruens Date: Thu, 27 Jun 2024 10:49:17 -0400 Subject: [PATCH 2/7] Update tests to separate the IP filtering from NAT entries. --- service/udp_test.go | 46 ++++++++++++++++++++++++++++++++++----------- 1 file changed, 35 insertions(+), 11 deletions(-) diff --git a/service/udp_test.go b/service/udp_test.go index e461280e..d68905bd 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -123,7 +123,7 @@ func (m *natTestMetrics) AddUDPCipherSearch(accessKeyFound bool, timeToCipher ti // Takes a validation policy, and returns the metrics it // generates when localhost access is attempted -func sendToDiscard(payloads [][]byte, validator onet.TargetIPValidator) *natTestMetrics { +func sendToDiscard(payloads [][]byte, validator onet.TargetIPValidator, useValidCipher bool) *natTestMetrics { ciphers, _ := MakeTestCiphers([]string{"asdf"}) cipher := ciphers.SnapshotForClientIP(netip.Addr{})[0].Value.(*CipherEntry).CryptoKey clientConn := makePacketConn() @@ -140,7 +140,12 @@ func sendToDiscard(payloads [][]byte, validator onet.TargetIPValidator) *natTest targetAddr := socks.ParseAddr("127.0.0.1:9") for _, payload := range payloads { plaintext := append(targetAddr, payload...) - ciphertext := make([]byte, cipher.SaltSize()+len(plaintext)+cipher.TagSize()) + var ciphertext []byte + if useValidCipher { + ciphertext = make([]byte, cipher.SaltSize()+len(plaintext)+cipher.TagSize()) + } else { + ciphertext = []byte("invalid cipher") + } shadowsocks.Pack(ciphertext, plaintext, cipher) clientConn.recv <- packet{ addr: &net.UDPAddr{ @@ -156,35 +161,54 @@ func sendToDiscard(payloads [][]byte, validator onet.TargetIPValidator) *natTest return metrics } +func sendToDiscardWithValidCipher(payloads [][]byte, validator onet.TargetIPValidator) *natTestMetrics { + return sendToDiscard(payloads, validator, true) +} + +func sendToDiscardWithInValidCipher(payloads [][]byte, validator onet.TargetIPValidator) *natTestMetrics { + return sendToDiscard(payloads, validator, false) +} + func TestIPFilter(t *testing.T) { // Test both the first-packet and subsequent-packet cases. payloads := [][]byte{[]byte("payload1"), []byte("payload2")} t.Run("Localhost allowed", func(t *testing.T) { - metrics := sendToDiscard(payloads, allowAll) + metrics := sendToDiscardWithValidCipher(payloads, allowAll) - assert.Equal(t, metrics.natEntriesAdded, 1, "Expected 1 NAT entry, not %d", metrics.natEntriesAdded) assert.Equal(t, 2, len(metrics.upstreamPackets), "Expected 2 reports, not %v", metrics.upstreamPackets) for _, report := range metrics.upstreamPackets { - assert.Greater(t, report.clientProxyBytes, 0, "Expected nonzero input packet size") assert.Greater(t, report.proxyTargetBytes, 0, "Expected nonzero bytes to be sent for allowed packet") - assert.Equal(t, report.accessKey, "id-0", "Unexpected access key: %s", report.accessKey) } }) t.Run("Localhost not allowed", func(t *testing.T) { - metrics := sendToDiscard(payloads, onet.RequirePublicIP) + metrics := sendToDiscardWithValidCipher(payloads, onet.RequirePublicIP) - assert.Equal(t, metrics.natEntriesAdded, 1, "Expected 1 NAT entry, not %d", metrics.natEntriesAdded) assert.Equal(t, 2, len(metrics.upstreamPackets), "Expected 2 reports, not %v", metrics.upstreamPackets) for _, report := range metrics.upstreamPackets { - assert.Greater(t, report.clientProxyBytes, 0, "Expected nonzero input packet size") assert.Equal(t, 0, report.proxyTargetBytes, "No bytes should be sent due to a disallowed packet") - assert.Equal(t, report.accessKey, "id-0", "Unexpected access key: %s", report.accessKey) } }) } +func TestNATEntries(t *testing.T) { + // Test both the first-packet and subsequent-packet cases. + payloads := [][]byte{[]byte("payload1"), []byte("payload2")} + + t.Run("Valid cipher", func(t *testing.T) { + metrics := sendToDiscardWithValidCipher(payloads, onet.RequirePublicIP) + + assert.Equal(t, 1, metrics.natEntriesAdded, "Expected 1 NAT entry, not %d", metrics.natEntriesAdded) + }) + + t.Run("Invalid cipher", func(t *testing.T) { + metrics := sendToDiscardWithInValidCipher(payloads, onet.RequirePublicIP) + + assert.Equal(t, 0, metrics.natEntriesAdded, "Unexpected NAT entry on rejected packet") + }) +} + func TestUpstreamMetrics(t *testing.T) { // Test both the first-packet and subsequent-packet cases. const N = 10 @@ -193,7 +217,7 @@ func TestUpstreamMetrics(t *testing.T) { payloads = append(payloads, make([]byte, i)) } - metrics := sendToDiscard(payloads, allowAll) + metrics := sendToDiscardWithValidCipher(payloads, allowAll) assert.Equal(t, N, len(metrics.upstreamPackets), "Expected %d reports, not %v", N, metrics.upstreamPackets) for i, report := range metrics.upstreamPackets { From 0680221ed980e5fe1acac355b321a21335e23f94 Mon Sep 17 00:00:00 2001 From: sbruens Date: Thu, 27 Jun 2024 18:26:44 -0400 Subject: [PATCH 3/7] Move socket creation out of `authenticate()` method. --- service/udp.go | 67 +++++++++++++++++++++------------------------ service/udp_test.go | 4 +-- 2 files changed, 33 insertions(+), 38 deletions(-) diff --git a/service/udp.go b/service/udp.go index 57a0efb7..728cf8df 100644 --- a/service/udp.go +++ b/service/udp.go @@ -65,23 +65,23 @@ func debugUDPAddr(addr net.Addr, template string, val interface{}) { // Decrypts src into dst. It tries each cipher until it finds one that authenticates // correctly. dst and src must not overlap. -func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherList) ([]byte, string, *shadowsocks.EncryptionKey, error) { +func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherList) (*CipherEntry, []byte, error) { // Try each cipher until we find one that authenticates successfully. This assumes that all ciphers are AEAD. // We snapshot the list because it may be modified while we use it. snapshot := cipherList.SnapshotForClientIP(clientIP) - for ci, entry := range snapshot { - id, cryptoKey := entry.Value.(*CipherEntry).ID, entry.Value.(*CipherEntry).CryptoKey - buf, err := shadowsocks.Unpack(dst, src, cryptoKey) + for ci, elt := range snapshot { + entry := elt.Value.(*CipherEntry) + buf, err := shadowsocks.Unpack(dst, src, entry.CryptoKey) if err != nil { - debugUDP(id, "Failed to unpack: %v", err) + debugUDP(entry.ID, "Failed to unpack: %v", err) continue } - debugUDP(id, "Found cipher at index %d", ci) + debugUDP(entry.ID, "Found cipher at index %d", ci) // Move the active cipher to the front, so that the search is quicker next time. - cipherList.MarkUsedByClientIP(entry, clientIP) - return buf, id, cryptoKey, nil + cipherList.MarkUsedByClientIP(elt, clientIP) + return entry, buf, nil } - return nil, "", nil, errors.New("could not find valid UDP cipher") + return nil, nil, errors.New("could not find valid UDP cipher") } type packetHandler struct { @@ -131,12 +131,12 @@ func (h *packetHandler) Handle(clientConn net.PacketConn) { } } -func (h *packetHandler) authenticate(clientConn net.PacketConn) (*natconn, []byte, int, *onet.ConnectionError) { +func (h *packetHandler) authenticate(clientConn net.PacketConn) (net.Addr, *CipherEntry, []byte, int, *onet.ConnectionError) { cipherBuf := make([]byte, serverUDPBufferSize) textBuf := make([]byte, serverUDPBufferSize) clientProxyBytes, clientAddr, err := clientConn.ReadFrom(cipherBuf) if err != nil { - return nil, nil, 0, onet.NewConnectionError("ERR_READ", "Failed to read from client", err) + return nil, nil, nil, 0, onet.NewConnectionError("ERR_READ", "Failed to read from client", err) } if logger.IsEnabledFor(logging.DEBUG) { @@ -144,34 +144,17 @@ func (h *packetHandler) authenticate(clientConn net.PacketConn) (*natconn, []byt logger.Debugf("UDP(%v): Outbound packet has %d bytes", clientAddr, clientProxyBytes) } - targetConn := h.nm.Get(clientAddr.String()) remoteIP := clientAddr.(*net.UDPAddr).AddrPort().Addr() unpackStart := time.Now() - textData, keyID, cryptoKey, keyErr := findAccessKeyUDP(remoteIP, textBuf, cipherBuf[:clientProxyBytes], h.ciphers) + cipherEntry, textData, keyErr := findAccessKeyUDP(remoteIP, textBuf, cipherBuf[:clientProxyBytes], h.ciphers) timeToCipher := time.Since(unpackStart) h.m.AddUDPCipherSearch(err == nil, timeToCipher) if keyErr != nil { - return targetConn, nil, 0, onet.NewConnectionError("ERR_CIPHER", "Failed to find a valid cipher", keyErr) - } - - if targetConn != nil { - return targetConn, textData, clientProxyBytes, nil - } - - udpConn, err := net.ListenPacket("udp", "") - if err != nil { - return targetConn, textData, clientProxyBytes, onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create UDP socket", err) - } - - clientInfo, locErr := ipinfo.GetIPInfoFromAddr(h.m, clientAddr) - if locErr != nil { - logger.Warningf("Failed client info lookup: %v", locErr) + return nil, nil, nil, 0, onet.NewConnectionError("ERR_CIPHER", "Failed to find a valid cipher", keyErr) } - debugUDPAddr(clientAddr, "Got info \"%#v\"", clientInfo) - targetConn = h.nm.Add(clientAddr, clientConn, cryptoKey, udpConn, clientInfo, keyID) - return targetConn, textData, clientProxyBytes, nil + return clientAddr, cipherEntry, textData, clientProxyBytes, nil } func (h *packetHandler) handleConnection(clientConn net.PacketConn) (string, ipinfo.IPInfo, int, int, *onet.ConnectionError) { @@ -182,13 +165,25 @@ func (h *packetHandler) handleConnection(clientConn net.PacketConn) (string, ipi } }() - targetConn, textData, clientProxyBytes, authErr := h.authenticate(clientConn) + clientAddr, cipherEntry, textData, clientProxyBytes, authErr := h.authenticate(clientConn) if authErr != nil { - var clientInfo ipinfo.IPInfo - if targetConn != nil { - clientInfo = targetConn.clientInfo + return "", ipinfo.IPInfo{}, clientProxyBytes, 0, authErr + } + + targetConn := h.nm.Get(clientAddr.String()) + if targetConn == nil { + udpConn, err := net.ListenPacket("udp", "") + if err != nil { + return "", ipinfo.IPInfo{}, clientProxyBytes, 0, onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create UDP socket", err) } - return "", clientInfo, clientProxyBytes, 0, authErr + + clientInfo, locErr := ipinfo.GetIPInfoFromAddr(h.m, clientAddr) + if locErr != nil { + logger.Warningf("Failed client info lookup: %v", locErr) + } + debugUDPAddr(clientAddr, "Got info \"%#v\"", clientInfo) + + targetConn = h.nm.Add(clientAddr, clientConn, cipherEntry.CryptoKey, udpConn, clientInfo, cipherEntry.ID) } payload, tgtUDPAddr, onetErr := h.validatePacket(textData) diff --git a/service/udp_test.go b/service/udp_test.go index d68905bd..579fe522 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -469,7 +469,7 @@ func BenchmarkUDPUnpackRepeat(b *testing.B) { cipherNumber := n % numCiphers ip := ips[cipherNumber] packet := packets[cipherNumber] - _, _, _, err := findAccessKeyUDP(ip, testBuf, packet, cipherList) + _, _, err := findAccessKeyUDP(ip, testBuf, packet, cipherList) if err != nil { b.Error(err) } @@ -498,7 +498,7 @@ func BenchmarkUDPUnpackSharedKey(b *testing.B) { b.ResetTimer() for n := 0; n < b.N; n++ { ip := ips[n%numIPs] - _, _, _, err := findAccessKeyUDP(ip, testBuf, packet, cipherList) + _, _, err := findAccessKeyUDP(ip, testBuf, packet, cipherList) if err != nil { b.Error(err) } From 0e69c45d98982c6985fc4c44b236c7c92841d56a Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 28 Jun 2024 14:21:58 -0400 Subject: [PATCH 4/7] Split out proxying into `proxyConnection()`. --- cmd/outline-ss-server/metrics.go | 10 +-- cmd/outline-ss-server/metrics_test.go | 23 +++-- internal/integration_test/integration_test.go | 4 +- service/udp.go | 90 ++++++++++--------- service/udp_test.go | 23 ++--- 5 files changed, 83 insertions(+), 67 deletions(-) diff --git a/cmd/outline-ss-server/metrics.go b/cmd/outline-ss-server/metrics.go index 531c16ba..414935ae 100644 --- a/cmd/outline-ss-server/metrics.go +++ b/cmd/outline-ss-server/metrics.go @@ -324,12 +324,12 @@ func (m *outlineMetrics) AddClosedTCPConnection(clientInfo ipinfo.IPInfo, client } } -func (m *outlineMetrics) AddUDPPacketFromClient(clientInfo ipinfo.IPInfo, accessKey, status string, clientProxyBytes, proxyTargetBytes int) { +func (m *outlineMetrics) AddUDPPacketFromClient(clientInfo ipinfo.IPInfo, accessKey, status string, data metrics.ProxyMetrics) { m.udpPacketsFromClientPerLocation.WithLabelValues(clientInfo.CountryCode.String(), asnLabel(clientInfo.ASN), status).Inc() - addIfNonZero(int64(clientProxyBytes), m.dataBytes, "c>p", "udp", accessKey) - addIfNonZero(int64(clientProxyBytes), m.dataBytesPerLocation, "c>p", "udp", clientInfo.CountryCode.String(), asnLabel(clientInfo.ASN)) - addIfNonZero(int64(proxyTargetBytes), m.dataBytes, "p>t", "udp", accessKey) - addIfNonZero(int64(proxyTargetBytes), m.dataBytesPerLocation, "p>t", "udp", clientInfo.CountryCode.String(), asnLabel(clientInfo.ASN)) + addIfNonZero(data.ClientProxy, m.dataBytes, "c>p", "udp", accessKey) + addIfNonZero(data.ClientProxy, m.dataBytesPerLocation, "c>p", "udp", clientInfo.CountryCode.String(), asnLabel(clientInfo.ASN)) + addIfNonZero(data.ProxyTarget, m.dataBytes, "p>t", "udp", accessKey) + addIfNonZero(data.ProxyTarget, m.dataBytesPerLocation, "p>t", "udp", clientInfo.CountryCode.String(), asnLabel(clientInfo.ASN)) } func (m *outlineMetrics) AddUDPPacketFromTarget(clientInfo ipinfo.IPInfo, accessKey, status string, targetProxyBytes, proxyClientBytes int) { diff --git a/cmd/outline-ss-server/metrics_test.go b/cmd/outline-ss-server/metrics_test.go index 353520e4..085169d8 100644 --- a/cmd/outline-ss-server/metrics_test.go +++ b/cmd/outline-ss-server/metrics_test.go @@ -52,23 +52,29 @@ func init() { func TestMethodsDontPanic(t *testing.T) { ssMetrics := newPrometheusOutlineMetrics(nil, prometheus.NewPedanticRegistry()) - proxyMetrics := metrics.ProxyMetrics{ + tcpProxyMetrics := metrics.ProxyMetrics{ ClientProxy: 1, ProxyTarget: 2, TargetProxy: 3, ProxyClient: 4, } + udpProxyMetrics := metrics.ProxyMetrics{ + ClientProxy: 10, + ProxyTarget: 20, + TargetProxy: 30, + ProxyClient: 40, + } ipInfo := ipinfo.IPInfo{CountryCode: "US", ASN: 100} ssMetrics.SetBuildInfo("0.0.0-test") ssMetrics.SetNumAccessKeys(20, 2) ssMetrics.AddOpenTCPConnection(ipInfo) ssMetrics.AddAuthenticatedTCPConnection(fakeAddr("127.0.0.1:9"), "0") - ssMetrics.AddClosedTCPConnection(ipInfo, fakeAddr("127.0.0.1:9"), "1", "OK", proxyMetrics, 10*time.Millisecond) - ssMetrics.AddUDPPacketFromClient(ipInfo, "2", "OK", 10, 20) + ssMetrics.AddClosedTCPConnection(ipInfo, fakeAddr("127.0.0.1:9"), "1", "OK", tcpProxyMetrics, 10*time.Millisecond) + ssMetrics.AddUDPPacketFromClient(ipInfo, "2", "OK", udpProxyMetrics) ssMetrics.AddUDPPacketFromTarget(ipInfo, "3", "OK", 10, 20) ssMetrics.AddUDPNatEntry(fakeAddr("127.0.0.1:9"), "key-1") ssMetrics.RemoveUDPNatEntry(fakeAddr("127.0.0.1:9"), "key-1") - ssMetrics.AddTCPProbe("ERR_CIPHER", "eof", 443, proxyMetrics.ClientProxy) + ssMetrics.AddTCPProbe("ERR_CIPHER", "eof", 443, tcpProxyMetrics.ClientProxy) ssMetrics.AddTCPCipherSearch(true, 10*time.Millisecond) ssMetrics.AddUDPCipherSearch(true, 10*time.Millisecond) } @@ -174,14 +180,19 @@ func BenchmarkProbe(b *testing.B) { func BenchmarkClientUDP(b *testing.B) { ssMetrics := newPrometheusOutlineMetrics(nil, prometheus.NewRegistry()) + proxyMetrics := metrics.ProxyMetrics{ + ClientProxy: 1000, + ProxyTarget: 2000, + TargetProxy: 3000, + ProxyClient: 4000, + } clientInfo := ipinfo.IPInfo{CountryCode: "ZZ", ASN: 100} accessKey := "key 1" status := "OK" - size := 1000 timeToCipher := time.Microsecond b.ResetTimer() for i := 0; i < b.N; i++ { - ssMetrics.AddUDPPacketFromClient(clientInfo, accessKey, status, size, size) + ssMetrics.AddUDPPacketFromClient(clientInfo, accessKey, status, proxyMetrics) ssMetrics.AddUDPCipherSearch(true, timeToCipher) } } diff --git a/internal/integration_test/integration_test.go b/internal/integration_test/integration_test.go index 4ca2f120..17585fbe 100644 --- a/internal/integration_test/integration_test.go +++ b/internal/integration_test/integration_test.go @@ -262,8 +262,8 @@ var _ service.UDPMetrics = (*fakeUDPMetrics)(nil) func (m *fakeUDPMetrics) GetIPInfo(ip net.IP) (ipinfo.IPInfo, error) { return ipinfo.IPInfo{CountryCode: "QQ"}, nil } -func (m *fakeUDPMetrics) AddUDPPacketFromClient(clientInfo ipinfo.IPInfo, accessKey, status string, clientProxyBytes, proxyTargetBytes int) { - m.up = append(m.up, udpRecord{clientInfo, accessKey, status, clientProxyBytes, proxyTargetBytes}) +func (m *fakeUDPMetrics) AddUDPPacketFromClient(clientInfo ipinfo.IPInfo, accessKey, status string, data metrics.ProxyMetrics) { + m.up = append(m.up, udpRecord{clientInfo, accessKey, status, int(data.ClientProxy), int(data.ProxyTarget)}) } func (m *fakeUDPMetrics) AddUDPPacketFromTarget(clientInfo ipinfo.IPInfo, accessKey, status string, targetProxyBytes, proxyClientBytes int) { m.down = append(m.down, udpRecord{clientInfo, accessKey, status, targetProxyBytes, proxyClientBytes}) diff --git a/service/udp.go b/service/udp.go index 728cf8df..9e134264 100644 --- a/service/udp.go +++ b/service/udp.go @@ -15,6 +15,7 @@ package service import ( + "context" "errors" "fmt" "net" @@ -26,6 +27,7 @@ import ( "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks" "github.com/Jigsaw-Code/outline-ss-server/ipinfo" onet "github.com/Jigsaw-Code/outline-ss-server/net" + "github.com/Jigsaw-Code/outline-ss-server/service/metrics" logging "github.com/op/go-logging" "github.com/shadowsocks/go-shadowsocks2/socks" ) @@ -35,7 +37,7 @@ type UDPMetrics interface { ipinfo.IPInfoMap // UDP metrics - AddUDPPacketFromClient(clientInfo ipinfo.IPInfo, accessKey, status string, clientProxyBytes, proxyTargetBytes int) + AddUDPPacketFromClient(clientInfo ipinfo.IPInfo, accessKey, status string, data metrics.ProxyMetrics) AddUDPPacketFromTarget(clientInfo ipinfo.IPInfo, accessKey, status string, targetProxyBytes, proxyClientBytes int) AddUDPNatEntry(clientAddr net.Addr, accessKey string) RemoveUDPNatEntry(clientAddr net.Addr, accessKey string) @@ -119,7 +121,7 @@ func (h *packetHandler) Handle(clientConn net.PacketConn) { for { status := "OK" - keyID, clientInfo, clientProxyBytes, proxyTargetBytes, connErr := h.handleConnection(clientConn) + keyID, clientInfo, proxyMetrics, connErr := h.handleConnection(context.TODO(), clientConn) if connErr != nil { if errors.Is(connErr.Cause, net.ErrClosed) { break @@ -127,7 +129,7 @@ func (h *packetHandler) Handle(clientConn net.PacketConn) { logger.Debugf("UDP Error: %v: %v", connErr.Message, connErr.Cause) status = connErr.Status } - h.m.AddUDPPacketFromClient(clientInfo, keyID, status, clientProxyBytes, proxyTargetBytes) + h.m.AddUDPPacketFromClient(clientInfo, keyID, status, proxyMetrics) } } @@ -157,7 +159,31 @@ func (h *packetHandler) authenticate(clientConn net.PacketConn) (net.Addr, *Ciph return clientAddr, cipherEntry, textData, clientProxyBytes, nil } -func (h *packetHandler) handleConnection(clientConn net.PacketConn) (string, ipinfo.IPInfo, int, int, *onet.ConnectionError) { +func (h *packetHandler) proxyConnection(ctx context.Context, clientAddr net.Addr, tgtAddr net.Addr, clientConn net.PacketConn, cipherEntry CipherEntry, payload []byte, proxyMetrics *metrics.ProxyMetrics) (ipinfo.IPInfo, *onet.ConnectionError) { + tgtConn := h.nm.Get(clientAddr.String()) + if tgtConn == nil { + clientInfo, locErr := ipinfo.GetIPInfoFromAddr(h.m, clientAddr) + if locErr != nil { + logger.Warningf("Failed client info lookup: %v", locErr) + } + debugUDPAddr(clientAddr, "Got info \"%#v\"", clientInfo) + + udpConn, err := net.ListenPacket("udp", "") + if err != nil { + return ipinfo.IPInfo{}, nil + } + tgtConn = h.nm.Add(clientAddr, clientConn, cipherEntry.CryptoKey, udpConn, clientInfo, cipherEntry.ID) + } + + proxyTargetBytes, err := tgtConn.WriteTo(payload, tgtAddr) + proxyMetrics.ProxyTarget += int64(proxyTargetBytes) + if err != nil { + return tgtConn.clientInfo, onet.NewConnectionError("ERR_WRITE", "Failed to write to target", err) + } + return tgtConn.clientInfo, nil +} + +func (h *packetHandler) handleConnection(ctx context.Context, clientConn net.PacketConn) (string, ipinfo.IPInfo, metrics.ProxyMetrics, *onet.ConnectionError) { defer func() { if r := recover(); r != nil { logger.Errorf("Panic in UDP loop: %v. Continuing to listen.", r) @@ -165,44 +191,27 @@ func (h *packetHandler) handleConnection(clientConn net.PacketConn) (string, ipi } }() + var proxyMetrics metrics.ProxyMetrics clientAddr, cipherEntry, textData, clientProxyBytes, authErr := h.authenticate(clientConn) + proxyMetrics.ClientProxy += int64(clientProxyBytes) if authErr != nil { - return "", ipinfo.IPInfo{}, clientProxyBytes, 0, authErr + return "", ipinfo.IPInfo{}, proxyMetrics, authErr } - targetConn := h.nm.Get(clientAddr.String()) - if targetConn == nil { - udpConn, err := net.ListenPacket("udp", "") - if err != nil { - return "", ipinfo.IPInfo{}, clientProxyBytes, 0, onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create UDP socket", err) - } - - clientInfo, locErr := ipinfo.GetIPInfoFromAddr(h.m, clientAddr) - if locErr != nil { - logger.Warningf("Failed client info lookup: %v", locErr) - } - debugUDPAddr(clientAddr, "Got info \"%#v\"", clientInfo) - - targetConn = h.nm.Add(clientAddr, clientConn, cipherEntry.CryptoKey, udpConn, clientInfo, cipherEntry.ID) - } - - payload, tgtUDPAddr, onetErr := h.validatePacket(textData) + payload, tgtAddr, onetErr := h.getProxyRequest(textData) if onetErr != nil { - return targetConn.keyID, targetConn.clientInfo, clientProxyBytes, 0, onetErr + return cipherEntry.ID, ipinfo.IPInfo{}, proxyMetrics, onetErr } + debugUDPAddr(clientAddr, "Proxy exit %s", tgtAddr.String()) - debugUDPAddr(targetConn.clientAddr, "Proxy exit %v", targetConn.LocalAddr()) - proxyTargetBytes, err := targetConn.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature - if err != nil { - return targetConn.keyID, targetConn.clientInfo, clientProxyBytes, proxyTargetBytes, onet.NewConnectionError("ERR_WRITE", "Failed to write to target", err) - } - return targetConn.keyID, targetConn.clientInfo, clientProxyBytes, proxyTargetBytes, nil + clientInfo, err := h.proxyConnection(ctx, clientAddr, tgtAddr, clientConn, *cipherEntry, payload, &proxyMetrics) + return cipherEntry.ID, clientInfo, proxyMetrics, err } // Given the decrypted contents of a UDP packet, return // the payload and the destination address, or an error if // this packet cannot or should not be forwarded. -func (h *packetHandler) validatePacket(textData []byte) ([]byte, *net.UDPAddr, *onet.ConnectionError) { +func (h *packetHandler) getProxyRequest(textData []byte) ([]byte, *net.UDPAddr, *onet.ConnectionError) { tgtAddr := socks.SplitAddr(textData) if tgtAddr == nil { return nil, nil, onet.NewConnectionError("ERR_READ_ADDRESS", "Failed to get target address", nil) @@ -227,9 +236,7 @@ func isDNS(addr net.Addr) bool { type natconn struct { net.PacketConn - cryptoKey *shadowsocks.EncryptionKey - keyID string - clientAddr net.Addr + cryptoKey *shadowsocks.EncryptionKey // We store the client information in the NAT map to avoid recomputing it // for every downstream packet in a UDP-based connection. clientInfo ipinfo.IPInfo @@ -310,12 +317,9 @@ func (m *natmap) Get(key string) *natconn { return m.keyConn[key] } -func (m *natmap) set(clientAddr net.Addr, pc net.PacketConn, cryptoKey *shadowsocks.EncryptionKey, keyID string, clientInfo ipinfo.IPInfo) *natconn { +func (m *natmap) set(clientAddr net.Addr, pc net.PacketConn, clientInfo ipinfo.IPInfo) *natconn { entry := &natconn{ PacketConn: pc, - cryptoKey: cryptoKey, - keyID: keyID, - clientAddr: clientAddr, clientInfo: clientInfo, defaultTimeout: m.timeout, } @@ -340,12 +344,12 @@ func (m *natmap) del(key string) net.PacketConn { } func (m *natmap) Add(clientAddr net.Addr, clientConn net.PacketConn, cryptoKey *shadowsocks.EncryptionKey, targetConn net.PacketConn, clientInfo ipinfo.IPInfo, keyID string) *natconn { - entry := m.set(clientAddr, targetConn, cryptoKey, keyID, clientInfo) + entry := m.set(clientAddr, targetConn, clientInfo) m.metrics.AddUDPNatEntry(clientAddr, keyID) m.running.Add(1) go func() { - timedCopy(clientAddr, clientConn, entry, keyID, m.metrics) + timedCopy(clientAddr, clientConn, entry, cryptoKey, keyID, m.metrics) m.metrics.RemoveUDPNatEntry(clientAddr, keyID) if pc := m.del(clientAddr.String()); pc != nil { pc.Close() @@ -375,13 +379,13 @@ var maxAddrLen int = len(socks.ParseAddr("[2001:db8::1]:12345")) // copy from target to client until read timeout func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natconn, - keyID string, sm UDPMetrics) { + cryptoKey *shadowsocks.EncryptionKey, keyID string, sm UDPMetrics) { // pkt is used for in-place encryption of downstream UDP packets, with the layout // [padding?][salt][address][body][tag][extra] // Padding is only used if the address is IPv4. pkt := make([]byte, serverUDPBufferSize) - saltSize := targetConn.cryptoKey.SaltSize() + saltSize := cryptoKey.SaltSize() // Leave enough room at the beginning of the packet for a max-length header (i.e. IPv6). bodyStart := saltSize + maxAddrLen @@ -425,7 +429,7 @@ func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natco // [ packBuf ] // [ buf ] packBuf := pkt[saltStart:] - buf, err := shadowsocks.Pack(packBuf, plaintextBuf, targetConn.cryptoKey) // Encrypt in-place + buf, err := shadowsocks.Pack(packBuf, plaintextBuf, cryptoKey) // Encrypt in-place if err != nil { return onet.NewConnectionError("ERR_PACK", "Failed to pack data to client", err) } @@ -456,7 +460,7 @@ var _ UDPMetrics = (*NoOpUDPMetrics)(nil) func (m *NoOpUDPMetrics) GetIPInfo(net.IP) (ipinfo.IPInfo, error) { return ipinfo.IPInfo{}, nil } -func (m *NoOpUDPMetrics) AddUDPPacketFromClient(clientInfo ipinfo.IPInfo, accessKey, status string, clientProxyBytes, proxyTargetBytes int) { +func (m *NoOpUDPMetrics) AddUDPPacketFromClient(clientInfo ipinfo.IPInfo, accessKey, status string, data metrics.ProxyMetrics) { } func (m *NoOpUDPMetrics) AddUDPPacketFromTarget(clientInfo ipinfo.IPInfo, accessKey, status string, targetProxyBytes, proxyClientBytes int) { } diff --git a/service/udp_test.go b/service/udp_test.go index 579fe522..7cd59c95 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -26,6 +26,7 @@ import ( "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks" "github.com/Jigsaw-Code/outline-ss-server/ipinfo" onet "github.com/Jigsaw-Code/outline-ss-server/net" + "github.com/Jigsaw-Code/outline-ss-server/service/metrics" logging "github.com/op/go-logging" "github.com/shadowsocks/go-shadowsocks2/socks" "github.com/stretchr/testify/assert" @@ -93,9 +94,9 @@ func (conn *fakePacketConn) Close() error { } type udpReport struct { - clientInfo ipinfo.IPInfo - accessKey, status string - clientProxyBytes, proxyTargetBytes int + clientInfo ipinfo.IPInfo + accessKey, status string + data metrics.ProxyMetrics } // Stub metrics implementation for testing NAT behaviors. @@ -109,8 +110,8 @@ var _ UDPMetrics = (*natTestMetrics)(nil) func (m *natTestMetrics) GetIPInfo(net.IP) (ipinfo.IPInfo, error) { return ipinfo.IPInfo{}, nil } -func (m *natTestMetrics) AddUDPPacketFromClient(clientInfo ipinfo.IPInfo, accessKey, status string, clientProxyBytes, proxyTargetBytes int) { - m.upstreamPackets = append(m.upstreamPackets, udpReport{clientInfo, accessKey, status, clientProxyBytes, proxyTargetBytes}) +func (m *natTestMetrics) AddUDPPacketFromClient(clientInfo ipinfo.IPInfo, accessKey, status string, data metrics.ProxyMetrics) { + m.upstreamPackets = append(m.upstreamPackets, udpReport{clientInfo, accessKey, status, data}) } func (m *natTestMetrics) AddUDPPacketFromTarget(clientInfo ipinfo.IPInfo, accessKey, status string, targetProxyBytes, proxyClientBytes int) { } @@ -178,7 +179,7 @@ func TestIPFilter(t *testing.T) { assert.Equal(t, 2, len(metrics.upstreamPackets), "Expected 2 reports, not %v", metrics.upstreamPackets) for _, report := range metrics.upstreamPackets { - assert.Greater(t, report.proxyTargetBytes, 0, "Expected nonzero bytes to be sent for allowed packet") + assert.Greater(t, int(report.data.ProxyTarget), 0, "Expected nonzero bytes to be sent for allowed packet") } }) @@ -187,7 +188,7 @@ func TestIPFilter(t *testing.T) { assert.Equal(t, 2, len(metrics.upstreamPackets), "Expected 2 reports, not %v", metrics.upstreamPackets) for _, report := range metrics.upstreamPackets { - assert.Equal(t, 0, report.proxyTargetBytes, "No bytes should be sent due to a disallowed packet") + assert.EqualValues(t, 0, report.data.ProxyTarget, "No bytes should be sent due to a disallowed packet") } }) } @@ -197,13 +198,13 @@ func TestNATEntries(t *testing.T) { payloads := [][]byte{[]byte("payload1"), []byte("payload2")} t.Run("Valid cipher", func(t *testing.T) { - metrics := sendToDiscardWithValidCipher(payloads, onet.RequirePublicIP) + metrics := sendToDiscardWithValidCipher(payloads, allowAll) assert.Equal(t, 1, metrics.natEntriesAdded, "Expected 1 NAT entry, not %d", metrics.natEntriesAdded) }) t.Run("Invalid cipher", func(t *testing.T) { - metrics := sendToDiscardWithInValidCipher(payloads, onet.RequirePublicIP) + metrics := sendToDiscardWithInValidCipher(payloads, allowAll) assert.Equal(t, 0, metrics.natEntriesAdded, "Unexpected NAT entry on rejected packet") }) @@ -221,8 +222,8 @@ func TestUpstreamMetrics(t *testing.T) { assert.Equal(t, N, len(metrics.upstreamPackets), "Expected %d reports, not %v", N, metrics.upstreamPackets) for i, report := range metrics.upstreamPackets { - assert.Equal(t, i+1, report.proxyTargetBytes, "Expected %d payload bytes, not %d", i+1, report.proxyTargetBytes) - assert.Greater(t, report.clientProxyBytes, report.proxyTargetBytes, "Expected nonzero input overhead (%d > %d)", report.clientProxyBytes, report.proxyTargetBytes) + assert.EqualValues(t, i+1, report.data.ProxyTarget, "Expected %d payload bytes, not %d", i+1, report.data.ProxyTarget) + assert.Greater(t, report.data.ClientProxy, report.data.ProxyTarget, "Expected nonzero input overhead (%d > %d)", report.data.ClientProxy, report.data.ProxyTarget) assert.Equal(t, "id-0", report.accessKey, "Unexpected access key name: %s", report.accessKey) assert.Equal(t, "OK", report.status, "Wrong status: %s", report.status) } From 37204289ee1372a96cf55098ff0e201b64559e15 Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 28 Jun 2024 14:58:12 -0400 Subject: [PATCH 5/7] Remove unused context. --- service/udp.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/service/udp.go b/service/udp.go index 9e134264..a6a46b50 100644 --- a/service/udp.go +++ b/service/udp.go @@ -15,7 +15,6 @@ package service import ( - "context" "errors" "fmt" "net" @@ -121,7 +120,7 @@ func (h *packetHandler) Handle(clientConn net.PacketConn) { for { status := "OK" - keyID, clientInfo, proxyMetrics, connErr := h.handleConnection(context.TODO(), clientConn) + keyID, clientInfo, proxyMetrics, connErr := h.handleConnection(clientConn) if connErr != nil { if errors.Is(connErr.Cause, net.ErrClosed) { break @@ -159,7 +158,7 @@ func (h *packetHandler) authenticate(clientConn net.PacketConn) (net.Addr, *Ciph return clientAddr, cipherEntry, textData, clientProxyBytes, nil } -func (h *packetHandler) proxyConnection(ctx context.Context, clientAddr net.Addr, tgtAddr net.Addr, clientConn net.PacketConn, cipherEntry CipherEntry, payload []byte, proxyMetrics *metrics.ProxyMetrics) (ipinfo.IPInfo, *onet.ConnectionError) { +func (h *packetHandler) proxyConnection(clientAddr net.Addr, tgtAddr net.Addr, clientConn net.PacketConn, cipherEntry CipherEntry, payload []byte, proxyMetrics *metrics.ProxyMetrics) (ipinfo.IPInfo, *onet.ConnectionError) { tgtConn := h.nm.Get(clientAddr.String()) if tgtConn == nil { clientInfo, locErr := ipinfo.GetIPInfoFromAddr(h.m, clientAddr) @@ -183,7 +182,7 @@ func (h *packetHandler) proxyConnection(ctx context.Context, clientAddr net.Addr return tgtConn.clientInfo, nil } -func (h *packetHandler) handleConnection(ctx context.Context, clientConn net.PacketConn) (string, ipinfo.IPInfo, metrics.ProxyMetrics, *onet.ConnectionError) { +func (h *packetHandler) handleConnection(clientConn net.PacketConn) (string, ipinfo.IPInfo, metrics.ProxyMetrics, *onet.ConnectionError) { defer func() { if r := recover(); r != nil { logger.Errorf("Panic in UDP loop: %v. Continuing to listen.", r) @@ -204,7 +203,7 @@ func (h *packetHandler) handleConnection(ctx context.Context, clientConn net.Pac } debugUDPAddr(clientAddr, "Proxy exit %s", tgtAddr.String()) - clientInfo, err := h.proxyConnection(ctx, clientAddr, tgtAddr, clientConn, *cipherEntry, payload, &proxyMetrics) + clientInfo, err := h.proxyConnection(clientAddr, tgtAddr, clientConn, *cipherEntry, payload, &proxyMetrics) return cipherEntry.ID, clientInfo, proxyMetrics, err } From 5381cf96db71d92f537032990884cf5ce7a9f957 Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 28 Jun 2024 15:03:23 -0400 Subject: [PATCH 6/7] Revert key string change. --- service/udp.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/service/udp.go b/service/udp.go index a6a46b50..dbf02e59 100644 --- a/service/udp.go +++ b/service/udp.go @@ -316,7 +316,7 @@ func (m *natmap) Get(key string) *natconn { return m.keyConn[key] } -func (m *natmap) set(clientAddr net.Addr, pc net.PacketConn, clientInfo ipinfo.IPInfo) *natconn { +func (m *natmap) set(key string, pc net.PacketConn, clientInfo ipinfo.IPInfo) *natconn { entry := &natconn{ PacketConn: pc, clientInfo: clientInfo, @@ -326,7 +326,7 @@ func (m *natmap) set(clientAddr net.Addr, pc net.PacketConn, clientInfo ipinfo.I m.Lock() defer m.Unlock() - m.keyConn[clientAddr.String()] = entry + m.keyConn[key] = entry return entry } @@ -343,7 +343,7 @@ func (m *natmap) del(key string) net.PacketConn { } func (m *natmap) Add(clientAddr net.Addr, clientConn net.PacketConn, cryptoKey *shadowsocks.EncryptionKey, targetConn net.PacketConn, clientInfo ipinfo.IPInfo, keyID string) *natconn { - entry := m.set(clientAddr, targetConn, clientInfo) + entry := m.set(clientAddr.String(), targetConn, clientInfo) m.metrics.AddUDPNatEntry(clientAddr, keyID) m.running.Add(1) From 7ab36b282337e47167eadb17344db4585eca843b Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 28 Jun 2024 15:15:42 -0400 Subject: [PATCH 7/7] Refactor `sendToDiscard()`. --- service/udp_test.go | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/service/udp_test.go b/service/udp_test.go index 7cd59c95..e052d90a 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -124,9 +124,7 @@ func (m *natTestMetrics) AddUDPCipherSearch(accessKeyFound bool, timeToCipher ti // Takes a validation policy, and returns the metrics it // generates when localhost access is attempted -func sendToDiscard(payloads [][]byte, validator onet.TargetIPValidator, useValidCipher bool) *natTestMetrics { - ciphers, _ := MakeTestCiphers([]string{"asdf"}) - cipher := ciphers.SnapshotForClientIP(netip.Addr{})[0].Value.(*CipherEntry).CryptoKey +func sendToDiscard(ciphers CipherList, payloads [][]byte, cipher *shadowsocks.EncryptionKey, validator onet.TargetIPValidator) *natTestMetrics { clientConn := makePacketConn() metrics := &natTestMetrics{} handler := NewPacketHandler(timeout, ciphers, metrics) @@ -141,12 +139,7 @@ func sendToDiscard(payloads [][]byte, validator onet.TargetIPValidator, useValid targetAddr := socks.ParseAddr("127.0.0.1:9") for _, payload := range payloads { plaintext := append(targetAddr, payload...) - var ciphertext []byte - if useValidCipher { - ciphertext = make([]byte, cipher.SaltSize()+len(plaintext)+cipher.TagSize()) - } else { - ciphertext = []byte("invalid cipher") - } + ciphertext := make([]byte, cipher.SaltSize()+len(plaintext)+cipher.TagSize()) shadowsocks.Pack(ciphertext, plaintext, cipher) clientConn.recv <- packet{ addr: &net.UDPAddr{ @@ -163,11 +156,15 @@ func sendToDiscard(payloads [][]byte, validator onet.TargetIPValidator, useValid } func sendToDiscardWithValidCipher(payloads [][]byte, validator onet.TargetIPValidator) *natTestMetrics { - return sendToDiscard(payloads, validator, true) + ciphers, _ := MakeTestCiphers([]string{"asdf"}) + cipher := ciphers.SnapshotForClientIP(netip.Addr{})[0].Value.(*CipherEntry).CryptoKey + return sendToDiscard(ciphers, payloads, cipher, validator) } func sendToDiscardWithInValidCipher(payloads [][]byte, validator onet.TargetIPValidator) *natTestMetrics { - return sendToDiscard(payloads, validator, false) + ciphers, _ := MakeTestCiphers([]string{"asdf"}) + cipher, _ := shadowsocks.NewEncryptionKey(shadowsocks.CHACHA20IETFPOLY1305, "invalid cipher") + return sendToDiscard(ciphers, payloads, cipher, validator) } func TestIPFilter(t *testing.T) {