diff --git a/cmd/outline-ss-server/metrics.go b/cmd/outline-ss-server/metrics.go index 531c16ba..414935ae 100644 --- a/cmd/outline-ss-server/metrics.go +++ b/cmd/outline-ss-server/metrics.go @@ -324,12 +324,12 @@ func (m *outlineMetrics) AddClosedTCPConnection(clientInfo ipinfo.IPInfo, client } } -func (m *outlineMetrics) AddUDPPacketFromClient(clientInfo ipinfo.IPInfo, accessKey, status string, clientProxyBytes, proxyTargetBytes int) { +func (m *outlineMetrics) AddUDPPacketFromClient(clientInfo ipinfo.IPInfo, accessKey, status string, data metrics.ProxyMetrics) { m.udpPacketsFromClientPerLocation.WithLabelValues(clientInfo.CountryCode.String(), asnLabel(clientInfo.ASN), status).Inc() - addIfNonZero(int64(clientProxyBytes), m.dataBytes, "c>p", "udp", accessKey) - addIfNonZero(int64(clientProxyBytes), m.dataBytesPerLocation, "c>p", "udp", clientInfo.CountryCode.String(), asnLabel(clientInfo.ASN)) - addIfNonZero(int64(proxyTargetBytes), m.dataBytes, "p>t", "udp", accessKey) - addIfNonZero(int64(proxyTargetBytes), m.dataBytesPerLocation, "p>t", "udp", clientInfo.CountryCode.String(), asnLabel(clientInfo.ASN)) + addIfNonZero(data.ClientProxy, m.dataBytes, "c>p", "udp", accessKey) + addIfNonZero(data.ClientProxy, m.dataBytesPerLocation, "c>p", "udp", clientInfo.CountryCode.String(), asnLabel(clientInfo.ASN)) + addIfNonZero(data.ProxyTarget, m.dataBytes, "p>t", "udp", accessKey) + addIfNonZero(data.ProxyTarget, m.dataBytesPerLocation, "p>t", "udp", clientInfo.CountryCode.String(), asnLabel(clientInfo.ASN)) } func (m *outlineMetrics) AddUDPPacketFromTarget(clientInfo ipinfo.IPInfo, accessKey, status string, targetProxyBytes, proxyClientBytes int) { diff --git a/cmd/outline-ss-server/metrics_test.go b/cmd/outline-ss-server/metrics_test.go index 353520e4..085169d8 100644 --- a/cmd/outline-ss-server/metrics_test.go +++ b/cmd/outline-ss-server/metrics_test.go @@ -52,23 +52,29 @@ func init() { func TestMethodsDontPanic(t *testing.T) { ssMetrics := newPrometheusOutlineMetrics(nil, prometheus.NewPedanticRegistry()) - proxyMetrics := metrics.ProxyMetrics{ + tcpProxyMetrics := metrics.ProxyMetrics{ ClientProxy: 1, ProxyTarget: 2, TargetProxy: 3, ProxyClient: 4, } + udpProxyMetrics := metrics.ProxyMetrics{ + ClientProxy: 10, + ProxyTarget: 20, + TargetProxy: 30, + ProxyClient: 40, + } ipInfo := ipinfo.IPInfo{CountryCode: "US", ASN: 100} ssMetrics.SetBuildInfo("0.0.0-test") ssMetrics.SetNumAccessKeys(20, 2) ssMetrics.AddOpenTCPConnection(ipInfo) ssMetrics.AddAuthenticatedTCPConnection(fakeAddr("127.0.0.1:9"), "0") - ssMetrics.AddClosedTCPConnection(ipInfo, fakeAddr("127.0.0.1:9"), "1", "OK", proxyMetrics, 10*time.Millisecond) - ssMetrics.AddUDPPacketFromClient(ipInfo, "2", "OK", 10, 20) + ssMetrics.AddClosedTCPConnection(ipInfo, fakeAddr("127.0.0.1:9"), "1", "OK", tcpProxyMetrics, 10*time.Millisecond) + ssMetrics.AddUDPPacketFromClient(ipInfo, "2", "OK", udpProxyMetrics) ssMetrics.AddUDPPacketFromTarget(ipInfo, "3", "OK", 10, 20) ssMetrics.AddUDPNatEntry(fakeAddr("127.0.0.1:9"), "key-1") ssMetrics.RemoveUDPNatEntry(fakeAddr("127.0.0.1:9"), "key-1") - ssMetrics.AddTCPProbe("ERR_CIPHER", "eof", 443, proxyMetrics.ClientProxy) + ssMetrics.AddTCPProbe("ERR_CIPHER", "eof", 443, tcpProxyMetrics.ClientProxy) ssMetrics.AddTCPCipherSearch(true, 10*time.Millisecond) ssMetrics.AddUDPCipherSearch(true, 10*time.Millisecond) } @@ -174,14 +180,19 @@ func BenchmarkProbe(b *testing.B) { func BenchmarkClientUDP(b *testing.B) { ssMetrics := newPrometheusOutlineMetrics(nil, prometheus.NewRegistry()) + proxyMetrics := metrics.ProxyMetrics{ + ClientProxy: 1000, + ProxyTarget: 2000, + TargetProxy: 3000, + ProxyClient: 4000, + } clientInfo := ipinfo.IPInfo{CountryCode: "ZZ", ASN: 100} accessKey := "key 1" status := "OK" - size := 1000 timeToCipher := time.Microsecond b.ResetTimer() for i := 0; i < b.N; i++ { - ssMetrics.AddUDPPacketFromClient(clientInfo, accessKey, status, size, size) + ssMetrics.AddUDPPacketFromClient(clientInfo, accessKey, status, proxyMetrics) ssMetrics.AddUDPCipherSearch(true, timeToCipher) } } diff --git a/internal/integration_test/integration_test.go b/internal/integration_test/integration_test.go index 4ca2f120..17585fbe 100644 --- a/internal/integration_test/integration_test.go +++ b/internal/integration_test/integration_test.go @@ -262,8 +262,8 @@ var _ service.UDPMetrics = (*fakeUDPMetrics)(nil) func (m *fakeUDPMetrics) GetIPInfo(ip net.IP) (ipinfo.IPInfo, error) { return ipinfo.IPInfo{CountryCode: "QQ"}, nil } -func (m *fakeUDPMetrics) AddUDPPacketFromClient(clientInfo ipinfo.IPInfo, accessKey, status string, clientProxyBytes, proxyTargetBytes int) { - m.up = append(m.up, udpRecord{clientInfo, accessKey, status, clientProxyBytes, proxyTargetBytes}) +func (m *fakeUDPMetrics) AddUDPPacketFromClient(clientInfo ipinfo.IPInfo, accessKey, status string, data metrics.ProxyMetrics) { + m.up = append(m.up, udpRecord{clientInfo, accessKey, status, int(data.ClientProxy), int(data.ProxyTarget)}) } func (m *fakeUDPMetrics) AddUDPPacketFromTarget(clientInfo ipinfo.IPInfo, accessKey, status string, targetProxyBytes, proxyClientBytes int) { m.down = append(m.down, udpRecord{clientInfo, accessKey, status, targetProxyBytes, proxyClientBytes}) diff --git a/service/udp.go b/service/udp.go index 4830e302..dbf02e59 100644 --- a/service/udp.go +++ b/service/udp.go @@ -26,6 +26,7 @@ import ( "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks" "github.com/Jigsaw-Code/outline-ss-server/ipinfo" onet "github.com/Jigsaw-Code/outline-ss-server/net" + "github.com/Jigsaw-Code/outline-ss-server/service/metrics" logging "github.com/op/go-logging" "github.com/shadowsocks/go-shadowsocks2/socks" ) @@ -35,7 +36,7 @@ type UDPMetrics interface { ipinfo.IPInfoMap // UDP metrics - AddUDPPacketFromClient(clientInfo ipinfo.IPInfo, accessKey, status string, clientProxyBytes, proxyTargetBytes int) + AddUDPPacketFromClient(clientInfo ipinfo.IPInfo, accessKey, status string, data metrics.ProxyMetrics) AddUDPPacketFromTarget(clientInfo ipinfo.IPInfo, accessKey, status string, targetProxyBytes, proxyClientBytes int) AddUDPNatEntry(clientAddr net.Addr, accessKey string) RemoveUDPNatEntry(clientAddr net.Addr, accessKey string) @@ -65,28 +66,29 @@ func debugUDPAddr(addr net.Addr, template string, val interface{}) { // Decrypts src into dst. It tries each cipher until it finds one that authenticates // correctly. dst and src must not overlap. -func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherList) ([]byte, string, *shadowsocks.EncryptionKey, error) { +func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherList) (*CipherEntry, []byte, error) { // Try each cipher until we find one that authenticates successfully. This assumes that all ciphers are AEAD. // We snapshot the list because it may be modified while we use it. snapshot := cipherList.SnapshotForClientIP(clientIP) - for ci, entry := range snapshot { - id, cryptoKey := entry.Value.(*CipherEntry).ID, entry.Value.(*CipherEntry).CryptoKey - buf, err := shadowsocks.Unpack(dst, src, cryptoKey) + for ci, elt := range snapshot { + entry := elt.Value.(*CipherEntry) + buf, err := shadowsocks.Unpack(dst, src, entry.CryptoKey) if err != nil { - debugUDP(id, "Failed to unpack: %v", err) + debugUDP(entry.ID, "Failed to unpack: %v", err) continue } - debugUDP(id, "Found cipher at index %d", ci) + debugUDP(entry.ID, "Found cipher at index %d", ci) // Move the active cipher to the front, so that the search is quicker next time. - cipherList.MarkUsedByClientIP(entry, clientIP) - return buf, id, cryptoKey, nil + cipherList.MarkUsedByClientIP(elt, clientIP) + return entry, buf, nil } - return nil, "", nil, errors.New("could not find valid UDP cipher") + return nil, nil, errors.New("could not find valid UDP cipher") } type packetHandler struct { natTimeout time.Duration ciphers CipherList + nm *natmap m UDPMetrics targetIPValidator onet.TargetIPValidator } @@ -113,114 +115,102 @@ func (h *packetHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPVali func (h *packetHandler) Handle(clientConn net.PacketConn) { var running sync.WaitGroup - nm := newNATmap(h.natTimeout, h.m, &running) - defer nm.Close() - cipherBuf := make([]byte, serverUDPBufferSize) - textBuf := make([]byte, serverUDPBufferSize) + h.nm = newNATmap(h.natTimeout, h.m, &running) + defer h.nm.Close() for { - clientProxyBytes, clientAddr, err := clientConn.ReadFrom(cipherBuf) - if errors.Is(err, net.ErrClosed) { - break + status := "OK" + keyID, clientInfo, proxyMetrics, connErr := h.handleConnection(clientConn) + if connErr != nil { + if errors.Is(connErr.Cause, net.ErrClosed) { + break + } + logger.Debugf("UDP Error: %v: %v", connErr.Message, connErr.Cause) + status = connErr.Status } + h.m.AddUDPPacketFromClient(clientInfo, keyID, status, proxyMetrics) + } +} - var clientInfo ipinfo.IPInfo - keyID := "" - var proxyTargetBytes int - - connError := func() (connError *onet.ConnectionError) { - defer func() { - if r := recover(); r != nil { - logger.Errorf("Panic in UDP loop: %v. Continuing to listen.", r) - debug.PrintStack() - } - }() +func (h *packetHandler) authenticate(clientConn net.PacketConn) (net.Addr, *CipherEntry, []byte, int, *onet.ConnectionError) { + cipherBuf := make([]byte, serverUDPBufferSize) + textBuf := make([]byte, serverUDPBufferSize) + clientProxyBytes, clientAddr, err := clientConn.ReadFrom(cipherBuf) + if err != nil { + return nil, nil, nil, 0, onet.NewConnectionError("ERR_READ", "Failed to read from client", err) + } - // Error from ReadFrom - if err != nil { - return onet.NewConnectionError("ERR_READ", "Failed to read from client", err) - } - if logger.IsEnabledFor(logging.DEBUG) { - defer logger.Debugf("UDP(%v): done", clientAddr) - logger.Debugf("UDP(%v): Outbound packet has %d bytes", clientAddr, clientProxyBytes) - } + if logger.IsEnabledFor(logging.DEBUG) { + defer logger.Debugf("UDP(%v): done", clientAddr) + logger.Debugf("UDP(%v): Outbound packet has %d bytes", clientAddr, clientProxyBytes) + } - cipherData := cipherBuf[:clientProxyBytes] - var payload []byte - var tgtUDPAddr *net.UDPAddr - targetConn := nm.Get(clientAddr.String()) - if targetConn == nil { - var locErr error - clientInfo, locErr = ipinfo.GetIPInfoFromAddr(h.m, clientAddr) - if locErr != nil { - logger.Warningf("Failed client info lookup: %v", locErr) - } - debugUDPAddr(clientAddr, "Got info \"%#v\"", clientInfo) - - ip := clientAddr.(*net.UDPAddr).AddrPort().Addr() - var textData []byte - var cryptoKey *shadowsocks.EncryptionKey - unpackStart := time.Now() - textData, keyID, cryptoKey, err = findAccessKeyUDP(ip, textBuf, cipherData, h.ciphers) - timeToCipher := time.Since(unpackStart) - h.m.AddUDPCipherSearch(err == nil, timeToCipher) - - if err != nil { - return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack initial packet", err) - } + remoteIP := clientAddr.(*net.UDPAddr).AddrPort().Addr() - var onetErr *onet.ConnectionError - if payload, tgtUDPAddr, onetErr = h.validatePacket(textData); onetErr != nil { - return onetErr - } + unpackStart := time.Now() + cipherEntry, textData, keyErr := findAccessKeyUDP(remoteIP, textBuf, cipherBuf[:clientProxyBytes], h.ciphers) + timeToCipher := time.Since(unpackStart) + h.m.AddUDPCipherSearch(err == nil, timeToCipher) + if keyErr != nil { + return nil, nil, nil, 0, onet.NewConnectionError("ERR_CIPHER", "Failed to find a valid cipher", keyErr) + } - udpConn, err := net.ListenPacket("udp", "") - if err != nil { - return onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create UDP socket", err) - } - targetConn = nm.Add(clientAddr, clientConn, cryptoKey, udpConn, clientInfo, keyID) - } else { - clientInfo = targetConn.clientInfo + return clientAddr, cipherEntry, textData, clientProxyBytes, nil +} - unpackStart := time.Now() - textData, err := shadowsocks.Unpack(nil, cipherData, targetConn.cryptoKey) - timeToCipher := time.Since(unpackStart) - h.m.AddUDPCipherSearch(err == nil, timeToCipher) +func (h *packetHandler) proxyConnection(clientAddr net.Addr, tgtAddr net.Addr, clientConn net.PacketConn, cipherEntry CipherEntry, payload []byte, proxyMetrics *metrics.ProxyMetrics) (ipinfo.IPInfo, *onet.ConnectionError) { + tgtConn := h.nm.Get(clientAddr.String()) + if tgtConn == nil { + clientInfo, locErr := ipinfo.GetIPInfoFromAddr(h.m, clientAddr) + if locErr != nil { + logger.Warningf("Failed client info lookup: %v", locErr) + } + debugUDPAddr(clientAddr, "Got info \"%#v\"", clientInfo) - if err != nil { - return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack data from client", err) - } + udpConn, err := net.ListenPacket("udp", "") + if err != nil { + return ipinfo.IPInfo{}, nil + } + tgtConn = h.nm.Add(clientAddr, clientConn, cipherEntry.CryptoKey, udpConn, clientInfo, cipherEntry.ID) + } - // The key ID is known with confidence once decryption succeeds. - keyID = targetConn.keyID + proxyTargetBytes, err := tgtConn.WriteTo(payload, tgtAddr) + proxyMetrics.ProxyTarget += int64(proxyTargetBytes) + if err != nil { + return tgtConn.clientInfo, onet.NewConnectionError("ERR_WRITE", "Failed to write to target", err) + } + return tgtConn.clientInfo, nil +} - var onetErr *onet.ConnectionError - if payload, tgtUDPAddr, onetErr = h.validatePacket(textData); onetErr != nil { - return onetErr - } - } +func (h *packetHandler) handleConnection(clientConn net.PacketConn) (string, ipinfo.IPInfo, metrics.ProxyMetrics, *onet.ConnectionError) { + defer func() { + if r := recover(); r != nil { + logger.Errorf("Panic in UDP loop: %v. Continuing to listen.", r) + debug.PrintStack() + } + }() - debugUDPAddr(clientAddr, "Proxy exit %v", targetConn.LocalAddr()) - proxyTargetBytes, err = targetConn.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature - if err != nil { - return onet.NewConnectionError("ERR_WRITE", "Failed to write to target", err) - } - return nil - }() + var proxyMetrics metrics.ProxyMetrics + clientAddr, cipherEntry, textData, clientProxyBytes, authErr := h.authenticate(clientConn) + proxyMetrics.ClientProxy += int64(clientProxyBytes) + if authErr != nil { + return "", ipinfo.IPInfo{}, proxyMetrics, authErr + } - status := "OK" - if connError != nil { - logger.Debugf("UDP Error: %v: %v", connError.Message, connError.Cause) - status = connError.Status - } - h.m.AddUDPPacketFromClient(clientInfo, keyID, status, clientProxyBytes, proxyTargetBytes) + payload, tgtAddr, onetErr := h.getProxyRequest(textData) + if onetErr != nil { + return cipherEntry.ID, ipinfo.IPInfo{}, proxyMetrics, onetErr } + debugUDPAddr(clientAddr, "Proxy exit %s", tgtAddr.String()) + + clientInfo, err := h.proxyConnection(clientAddr, tgtAddr, clientConn, *cipherEntry, payload, &proxyMetrics) + return cipherEntry.ID, clientInfo, proxyMetrics, err } // Given the decrypted contents of a UDP packet, return // the payload and the destination address, or an error if // this packet cannot or should not be forwarded. -func (h *packetHandler) validatePacket(textData []byte) ([]byte, *net.UDPAddr, *onet.ConnectionError) { +func (h *packetHandler) getProxyRequest(textData []byte) ([]byte, *net.UDPAddr, *onet.ConnectionError) { tgtAddr := socks.SplitAddr(textData) if tgtAddr == nil { return nil, nil, onet.NewConnectionError("ERR_READ_ADDRESS", "Failed to get target address", nil) @@ -246,7 +236,6 @@ func isDNS(addr net.Addr) bool { type natconn struct { net.PacketConn cryptoKey *shadowsocks.EncryptionKey - keyID string // We store the client information in the NAT map to avoid recomputing it // for every downstream packet in a UDP-based connection. clientInfo ipinfo.IPInfo @@ -327,11 +316,9 @@ func (m *natmap) Get(key string) *natconn { return m.keyConn[key] } -func (m *natmap) set(key string, pc net.PacketConn, cryptoKey *shadowsocks.EncryptionKey, keyID string, clientInfo ipinfo.IPInfo) *natconn { +func (m *natmap) set(key string, pc net.PacketConn, clientInfo ipinfo.IPInfo) *natconn { entry := &natconn{ PacketConn: pc, - cryptoKey: cryptoKey, - keyID: keyID, clientInfo: clientInfo, defaultTimeout: m.timeout, } @@ -356,12 +343,12 @@ func (m *natmap) del(key string) net.PacketConn { } func (m *natmap) Add(clientAddr net.Addr, clientConn net.PacketConn, cryptoKey *shadowsocks.EncryptionKey, targetConn net.PacketConn, clientInfo ipinfo.IPInfo, keyID string) *natconn { - entry := m.set(clientAddr.String(), targetConn, cryptoKey, keyID, clientInfo) + entry := m.set(clientAddr.String(), targetConn, clientInfo) m.metrics.AddUDPNatEntry(clientAddr, keyID) m.running.Add(1) go func() { - timedCopy(clientAddr, clientConn, entry, keyID, m.metrics) + timedCopy(clientAddr, clientConn, entry, cryptoKey, keyID, m.metrics) m.metrics.RemoveUDPNatEntry(clientAddr, keyID) if pc := m.del(clientAddr.String()); pc != nil { pc.Close() @@ -391,13 +378,13 @@ var maxAddrLen int = len(socks.ParseAddr("[2001:db8::1]:12345")) // copy from target to client until read timeout func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natconn, - keyID string, sm UDPMetrics) { + cryptoKey *shadowsocks.EncryptionKey, keyID string, sm UDPMetrics) { // pkt is used for in-place encryption of downstream UDP packets, with the layout // [padding?][salt][address][body][tag][extra] // Padding is only used if the address is IPv4. pkt := make([]byte, serverUDPBufferSize) - saltSize := targetConn.cryptoKey.SaltSize() + saltSize := cryptoKey.SaltSize() // Leave enough room at the beginning of the packet for a max-length header (i.e. IPv6). bodyStart := saltSize + maxAddrLen @@ -441,7 +428,7 @@ func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natco // [ packBuf ] // [ buf ] packBuf := pkt[saltStart:] - buf, err := shadowsocks.Pack(packBuf, plaintextBuf, targetConn.cryptoKey) // Encrypt in-place + buf, err := shadowsocks.Pack(packBuf, plaintextBuf, cryptoKey) // Encrypt in-place if err != nil { return onet.NewConnectionError("ERR_PACK", "Failed to pack data to client", err) } @@ -472,7 +459,7 @@ var _ UDPMetrics = (*NoOpUDPMetrics)(nil) func (m *NoOpUDPMetrics) GetIPInfo(net.IP) (ipinfo.IPInfo, error) { return ipinfo.IPInfo{}, nil } -func (m *NoOpUDPMetrics) AddUDPPacketFromClient(clientInfo ipinfo.IPInfo, accessKey, status string, clientProxyBytes, proxyTargetBytes int) { +func (m *NoOpUDPMetrics) AddUDPPacketFromClient(clientInfo ipinfo.IPInfo, accessKey, status string, data metrics.ProxyMetrics) { } func (m *NoOpUDPMetrics) AddUDPPacketFromTarget(clientInfo ipinfo.IPInfo, accessKey, status string, targetProxyBytes, proxyClientBytes int) { } diff --git a/service/udp_test.go b/service/udp_test.go index f94238c5..e052d90a 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -26,6 +26,7 @@ import ( "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks" "github.com/Jigsaw-Code/outline-ss-server/ipinfo" onet "github.com/Jigsaw-Code/outline-ss-server/net" + "github.com/Jigsaw-Code/outline-ss-server/service/metrics" logging "github.com/op/go-logging" "github.com/shadowsocks/go-shadowsocks2/socks" "github.com/stretchr/testify/assert" @@ -93,9 +94,9 @@ func (conn *fakePacketConn) Close() error { } type udpReport struct { - clientInfo ipinfo.IPInfo - accessKey, status string - clientProxyBytes, proxyTargetBytes int + clientInfo ipinfo.IPInfo + accessKey, status string + data metrics.ProxyMetrics } // Stub metrics implementation for testing NAT behaviors. @@ -109,8 +110,8 @@ var _ UDPMetrics = (*natTestMetrics)(nil) func (m *natTestMetrics) GetIPInfo(net.IP) (ipinfo.IPInfo, error) { return ipinfo.IPInfo{}, nil } -func (m *natTestMetrics) AddUDPPacketFromClient(clientInfo ipinfo.IPInfo, accessKey, status string, clientProxyBytes, proxyTargetBytes int) { - m.upstreamPackets = append(m.upstreamPackets, udpReport{clientInfo, accessKey, status, clientProxyBytes, proxyTargetBytes}) +func (m *natTestMetrics) AddUDPPacketFromClient(clientInfo ipinfo.IPInfo, accessKey, status string, data metrics.ProxyMetrics) { + m.upstreamPackets = append(m.upstreamPackets, udpReport{clientInfo, accessKey, status, data}) } func (m *natTestMetrics) AddUDPPacketFromTarget(clientInfo ipinfo.IPInfo, accessKey, status string, targetProxyBytes, proxyClientBytes int) { } @@ -123,9 +124,7 @@ func (m *natTestMetrics) AddUDPCipherSearch(accessKeyFound bool, timeToCipher ti // Takes a validation policy, and returns the metrics it // generates when localhost access is attempted -func sendToDiscard(payloads [][]byte, validator onet.TargetIPValidator) *natTestMetrics { - ciphers, _ := MakeTestCiphers([]string{"asdf"}) - cipher := ciphers.SnapshotForClientIP(netip.Addr{})[0].Value.(*CipherEntry).CryptoKey +func sendToDiscard(ciphers CipherList, payloads [][]byte, cipher *shadowsocks.EncryptionKey, validator onet.TargetIPValidator) *natTestMetrics { clientConn := makePacketConn() metrics := &natTestMetrics{} handler := NewPacketHandler(timeout, ciphers, metrics) @@ -156,27 +155,58 @@ func sendToDiscard(payloads [][]byte, validator onet.TargetIPValidator) *natTest return metrics } +func sendToDiscardWithValidCipher(payloads [][]byte, validator onet.TargetIPValidator) *natTestMetrics { + ciphers, _ := MakeTestCiphers([]string{"asdf"}) + cipher := ciphers.SnapshotForClientIP(netip.Addr{})[0].Value.(*CipherEntry).CryptoKey + return sendToDiscard(ciphers, payloads, cipher, validator) +} + +func sendToDiscardWithInValidCipher(payloads [][]byte, validator onet.TargetIPValidator) *natTestMetrics { + ciphers, _ := MakeTestCiphers([]string{"asdf"}) + cipher, _ := shadowsocks.NewEncryptionKey(shadowsocks.CHACHA20IETFPOLY1305, "invalid cipher") + return sendToDiscard(ciphers, payloads, cipher, validator) +} + func TestIPFilter(t *testing.T) { // Test both the first-packet and subsequent-packet cases. payloads := [][]byte{[]byte("payload1"), []byte("payload2")} t.Run("Localhost allowed", func(t *testing.T) { - metrics := sendToDiscard(payloads, allowAll) - assert.Equal(t, metrics.natEntriesAdded, 1, "Expected 1 NAT entry, not %d", metrics.natEntriesAdded) + metrics := sendToDiscardWithValidCipher(payloads, allowAll) + + assert.Equal(t, 2, len(metrics.upstreamPackets), "Expected 2 reports, not %v", metrics.upstreamPackets) + for _, report := range metrics.upstreamPackets { + assert.Greater(t, int(report.data.ProxyTarget), 0, "Expected nonzero bytes to be sent for allowed packet") + } }) t.Run("Localhost not allowed", func(t *testing.T) { - metrics := sendToDiscard(payloads, onet.RequirePublicIP) - assert.Equal(t, 0, metrics.natEntriesAdded, "Unexpected NAT entry on rejected packet") + metrics := sendToDiscardWithValidCipher(payloads, onet.RequirePublicIP) + assert.Equal(t, 2, len(metrics.upstreamPackets), "Expected 2 reports, not %v", metrics.upstreamPackets) for _, report := range metrics.upstreamPackets { - assert.Greater(t, report.clientProxyBytes, 0, "Expected nonzero input packet size") - assert.Equal(t, 0, report.proxyTargetBytes, "No bytes should be sent due to a disallowed packet") - assert.Equal(t, report.accessKey, "id-0", "Unexpected access key: %s", report.accessKey) + assert.EqualValues(t, 0, report.data.ProxyTarget, "No bytes should be sent due to a disallowed packet") } }) } +func TestNATEntries(t *testing.T) { + // Test both the first-packet and subsequent-packet cases. + payloads := [][]byte{[]byte("payload1"), []byte("payload2")} + + t.Run("Valid cipher", func(t *testing.T) { + metrics := sendToDiscardWithValidCipher(payloads, allowAll) + + assert.Equal(t, 1, metrics.natEntriesAdded, "Expected 1 NAT entry, not %d", metrics.natEntriesAdded) + }) + + t.Run("Invalid cipher", func(t *testing.T) { + metrics := sendToDiscardWithInValidCipher(payloads, allowAll) + + assert.Equal(t, 0, metrics.natEntriesAdded, "Unexpected NAT entry on rejected packet") + }) +} + func TestUpstreamMetrics(t *testing.T) { // Test both the first-packet and subsequent-packet cases. const N = 10 @@ -185,12 +215,12 @@ func TestUpstreamMetrics(t *testing.T) { payloads = append(payloads, make([]byte, i)) } - metrics := sendToDiscard(payloads, allowAll) + metrics := sendToDiscardWithValidCipher(payloads, allowAll) assert.Equal(t, N, len(metrics.upstreamPackets), "Expected %d reports, not %v", N, metrics.upstreamPackets) for i, report := range metrics.upstreamPackets { - assert.Equal(t, i+1, report.proxyTargetBytes, "Expected %d payload bytes, not %d", i+1, report.proxyTargetBytes) - assert.Greater(t, report.clientProxyBytes, report.proxyTargetBytes, "Expected nonzero input overhead (%d > %d)", report.clientProxyBytes, report.proxyTargetBytes) + assert.EqualValues(t, i+1, report.data.ProxyTarget, "Expected %d payload bytes, not %d", i+1, report.data.ProxyTarget) + assert.Greater(t, report.data.ClientProxy, report.data.ProxyTarget, "Expected nonzero input overhead (%d > %d)", report.data.ClientProxy, report.data.ProxyTarget) assert.Equal(t, "id-0", report.accessKey, "Unexpected access key name: %s", report.accessKey) assert.Equal(t, "OK", report.status, "Wrong status: %s", report.status) } @@ -437,7 +467,7 @@ func BenchmarkUDPUnpackRepeat(b *testing.B) { cipherNumber := n % numCiphers ip := ips[cipherNumber] packet := packets[cipherNumber] - _, _, _, err := findAccessKeyUDP(ip, testBuf, packet, cipherList) + _, _, err := findAccessKeyUDP(ip, testBuf, packet, cipherList) if err != nil { b.Error(err) } @@ -466,7 +496,7 @@ func BenchmarkUDPUnpackSharedKey(b *testing.B) { b.ResetTimer() for n := 0; n < b.N; n++ { ip := ips[n%numIPs] - _, _, _, err := findAccessKeyUDP(ip, testBuf, packet, cipherList) + _, _, err := findAccessKeyUDP(ip, testBuf, packet, cipherList) if err != nil { b.Error(err) }