From 95e4aa91b396efac0eea9d009d05ff0ff4e261eb Mon Sep 17 00:00:00 2001 From: Ben Schwartz Date: Thu, 5 Nov 2020 17:30:58 -0500 Subject: [PATCH 1/2] Bugfix: track the keyID for UDP connections This is important for applying usage limits to UDP streams. It also fixes a bug in metrics reporting, because the shared metrics are not tolerant of proxy<->target traffic that lacks a keyID. --- service/udp.go | 15 ++++++-- service/udp_test.go | 94 +++++++++++++++++++++++++++++++++++---------- 2 files changed, 85 insertions(+), 24 deletions(-) diff --git a/service/udp.go b/service/udp.go index 6434ec38..42160497 100644 --- a/service/udp.go +++ b/service/udp.go @@ -173,7 +173,8 @@ func (s *udpService) Serve(clientConn net.PacketConn) error { var tgtUDPAddr *net.UDPAddr targetConn := nm.Get(clientAddr.String()) if targetConn == nil { - clientLocation, locErr := s.m.GetLocation(clientAddr) + var locErr error + clientLocation, locErr = s.m.GetLocation(clientAddr) if locErr != nil { logger.Warningf("Failed location lookup: %v", locErr) } @@ -201,6 +202,8 @@ func (s *udpService) Serve(clientConn net.PacketConn) error { } targetConn = nm.Add(clientAddr, clientConn, cipher, udpConn, clientLocation, keyID) } else { + clientLocation = targetConn.clientLocation + unpackStart := time.Now() textData, err := ss.Unpack(nil, cipherData, targetConn.cipher) timeToCipher = time.Now().Sub(unpackStart) @@ -208,12 +211,14 @@ func (s *udpService) Serve(clientConn net.PacketConn) error { return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack data from client", err) } + // The key ID is known with confidence once decryption succeeds. + keyID = targetConn.keyID + var onetErr *onet.ConnectionError if payload, tgtUDPAddr, onetErr = s.validatePacket(textData); onetErr != nil { return onetErr } } - clientLocation = targetConn.clientLocation debugUDPAddr(clientAddr, "Proxy exit %v", targetConn.LocalAddr()) proxyTargetBytes, err = targetConn.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature @@ -271,6 +276,7 @@ func isDNS(addr net.Addr) bool { type natconn struct { net.PacketConn cipher *ss.Cipher + keyID string // We store the client location in the NAT map to avoid recomputing it // for every downstream packet in a UDP-based connection. clientLocation string @@ -351,10 +357,11 @@ func (m *natmap) Get(key string) *natconn { return m.keyConn[key] } -func (m *natmap) set(key string, pc net.PacketConn, cipher *ss.Cipher, clientLocation string) *natconn { +func (m *natmap) set(key string, pc net.PacketConn, cipher *ss.Cipher, keyID, clientLocation string) *natconn { entry := &natconn{ PacketConn: pc, cipher: cipher, + keyID: keyID, clientLocation: clientLocation, defaultTimeout: m.timeout, } @@ -379,7 +386,7 @@ func (m *natmap) del(key string) net.PacketConn { } func (m *natmap) Add(clientAddr net.Addr, clientConn net.PacketConn, cipher *ss.Cipher, targetConn net.PacketConn, clientLocation, keyID string) *natconn { - entry := m.set(clientAddr.String(), targetConn, cipher, clientLocation) + entry := m.set(clientAddr.String(), targetConn, cipher, keyID, clientLocation) m.metrics.AddUDPNatEntry() m.running.Add(1) diff --git a/service/udp_test.go b/service/udp_test.go index c30469ef..f397a850 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -89,10 +89,16 @@ func (conn *fakePacketConn) Close() error { return nil } +type udpReport struct { + clientLocation, accessKey, status string + clientProxyBytes, proxyTargetBytes int +} + // Stub metrics implementation for testing NAT behaviors. type natTestMetrics struct { metrics.ShadowsocksMetrics natEntriesAdded int + upstreamPackets []udpReport } func (m *natTestMetrics) AddTCPProbe(clientLocation, status, drainResult string, port int, data metrics.ProxyMetrics) { @@ -107,6 +113,7 @@ func (m *natTestMetrics) SetNumAccessKeys(numKeys int, numPorts int) { func (m *natTestMetrics) AddOpenTCPConnection(clientLocation string) { } func (m *natTestMetrics) AddUDPPacketFromClient(clientLocation, accessKey, status string, clientProxyBytes, proxyTargetBytes int, timeToCipher time.Duration) { + m.upstreamPackets = append(m.upstreamPackets, udpReport{clientLocation, accessKey, status, clientProxyBytes, proxyTargetBytes}) } func (m *natTestMetrics) AddUDPPacketFromTarget(clientLocation, accessKey, status string, targetProxyBytes, proxyClientBytes int) { } @@ -115,21 +122,20 @@ func (m *natTestMetrics) AddUDPNatEntry() { } func (m *natTestMetrics) RemoveUDPNatEntry() {} -func TestIPFilter(t *testing.T) { - // Takes a validation policy, and returns the metrics it - // generates when localhost access is attempted - checkLocalhost := func(validator onet.TargetIPValidator) *natTestMetrics { - ciphers, _ := MakeTestCiphers([]string{"asdf"}) - cipher := ciphers.SnapshotForClientIP(nil)[0].Value.(*CipherEntry).Cipher - clientConn := makePacketConn() - metrics := &natTestMetrics{} - service := NewUDPService(timeout, ciphers, metrics) - service.SetTargetIPValidator(validator) - go service.Serve(clientConn) - - // Send one packet to the "discard" port on localhost - targetAddr := socks.ParseAddr("127.0.0.1:9") - payload := []byte("payload") +// Takes a validation policy, and returns the metrics it +// generates when localhost access is attempted +func sendToDiscard(payloads [][]byte, validator onet.TargetIPValidator) *natTestMetrics { + ciphers, _ := MakeTestCiphers([]string{"asdf"}) + cipher := ciphers.SnapshotForClientIP(nil)[0].Value.(*CipherEntry).Cipher + clientConn := makePacketConn() + metrics := &natTestMetrics{} + service := NewUDPService(timeout, ciphers, metrics) + service.SetTargetIPValidator(validator) + go service.Serve(clientConn) + + // Send one packet to the "discard" port on localhost + targetAddr := socks.ParseAddr("127.0.0.1:9") + for _, payload := range payloads { plaintext := append(targetAddr, payload...) ciphertext := make([]byte, cipher.SaltSize()+len(plaintext)+cipher.TagSize()) ss.Pack(ciphertext, plaintext, cipher) @@ -140,26 +146,74 @@ func TestIPFilter(t *testing.T) { }, payload: ciphertext, } - - service.GracefulStop() - return metrics } + service.GracefulStop() + return metrics +} + +func TestIPFilter(t *testing.T) { + // Test both the first-packet and subsequent-packet cases. + payloads := [][]byte{[]byte("payload1"), []byte("payload2")} + t.Run("Localhost allowed", func(t *testing.T) { - metrics := checkLocalhost(allowAll) + metrics := sendToDiscard(payloads, allowAll) if metrics.natEntriesAdded != 1 { t.Errorf("Expected 1 NAT entry, not %d", metrics.natEntriesAdded) } }) t.Run("Localhost not allowed", func(t *testing.T) { - metrics := checkLocalhost(onet.RequirePublicIP) + metrics := sendToDiscard(payloads, onet.RequirePublicIP) if metrics.natEntriesAdded != 0 { t.Error("Unexpected NAT entry on rejected packet") } + if len(metrics.upstreamPackets) != 2 { + t.Errorf("Expected 2 reports, not %v", metrics.upstreamPackets) + } + for _, report := range metrics.upstreamPackets { + if report.clientProxyBytes == 0 { + t.Error("Expected nonzero input packet size") + } + if report.proxyTargetBytes != 0 { + t.Error("No bytes should be sent due to a disallowed packet") + } + if report.accessKey != "id-0" { + t.Errorf("Unexpected access key: %s", report.accessKey) + } + } }) } +func TestUpstreamMetrics(t *testing.T) { + // Test both the first-packet and subsequent-packet cases. + const N = 10 + payloads := make([][]byte, 0) + for i := 1; i <= N; i++ { + payloads = append(payloads, make([]byte, i)) + } + + metrics := sendToDiscard(payloads, allowAll) + + if len(metrics.upstreamPackets) != N { + t.Errorf("Expected %d reports, not %v", N, metrics.upstreamPackets) + } + for i, report := range metrics.upstreamPackets { + if report.proxyTargetBytes != i+1 { + t.Errorf("Expected %d payload bytes, not %d", i, report.proxyTargetBytes) + } + if report.clientProxyBytes <= report.proxyTargetBytes { + t.Errorf("Expected nonzero input overhead (%d > %d)", report.clientProxyBytes, report.proxyTargetBytes) + } + if report.accessKey != "id-0" { + t.Errorf("Unexpected access key name: %s", report.accessKey) + } + if report.status != "OK" { + t.Errorf("Wrong status: %s", report.status) + } + } +} + func assertAlmostEqual(t *testing.T, a, b time.Time) { delta := a.Sub(b) limit := 100 * time.Millisecond From 4a91989d4676bd12e142592972ab9498a2b8f40e Mon Sep 17 00:00:00 2001 From: Ben Schwartz Date: Thu, 5 Nov 2020 18:01:35 -0500 Subject: [PATCH 2/2] Convert to testify --- service/udp_test.go | 45 ++++++++++++--------------------------------- 1 file changed, 12 insertions(+), 33 deletions(-) diff --git a/service/udp_test.go b/service/udp_test.go index f397a850..12139fdc 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -27,6 +27,7 @@ import ( ss "github.com/Jigsaw-Code/outline-ss-server/shadowsocks" logging "github.com/op/go-logging" "github.com/shadowsocks/go-shadowsocks2/socks" + "github.com/stretchr/testify/assert" ) const timeout = 5 * time.Minute @@ -158,29 +159,17 @@ func TestIPFilter(t *testing.T) { t.Run("Localhost allowed", func(t *testing.T) { metrics := sendToDiscard(payloads, allowAll) - if metrics.natEntriesAdded != 1 { - t.Errorf("Expected 1 NAT entry, not %d", metrics.natEntriesAdded) - } + assert.Equal(t, metrics.natEntriesAdded, 1, "Expected 1 NAT entry, not %d", metrics.natEntriesAdded) }) t.Run("Localhost not allowed", func(t *testing.T) { metrics := sendToDiscard(payloads, onet.RequirePublicIP) - if metrics.natEntriesAdded != 0 { - t.Error("Unexpected NAT entry on rejected packet") - } - if len(metrics.upstreamPackets) != 2 { - t.Errorf("Expected 2 reports, not %v", metrics.upstreamPackets) - } + assert.Equal(t, 0, metrics.natEntriesAdded, "Unexpected NAT entry on rejected packet") + assert.Equal(t, 2, len(metrics.upstreamPackets), "Expected 2 reports, not %v", metrics.upstreamPackets) for _, report := range metrics.upstreamPackets { - if report.clientProxyBytes == 0 { - t.Error("Expected nonzero input packet size") - } - if report.proxyTargetBytes != 0 { - t.Error("No bytes should be sent due to a disallowed packet") - } - if report.accessKey != "id-0" { - t.Errorf("Unexpected access key: %s", report.accessKey) - } + assert.Greater(t, report.clientProxyBytes, 0, "Expected nonzero input packet size") + assert.Equal(t, 0, report.proxyTargetBytes, "No bytes should be sent due to a disallowed packet") + assert.Equal(t, report.accessKey, "id-0", "Unexpected access key: %s", report.accessKey) } }) } @@ -195,22 +184,12 @@ func TestUpstreamMetrics(t *testing.T) { metrics := sendToDiscard(payloads, allowAll) - if len(metrics.upstreamPackets) != N { - t.Errorf("Expected %d reports, not %v", N, metrics.upstreamPackets) - } + assert.Equal(t, N, len(metrics.upstreamPackets), "Expected %d reports, not %v", N, metrics.upstreamPackets) for i, report := range metrics.upstreamPackets { - if report.proxyTargetBytes != i+1 { - t.Errorf("Expected %d payload bytes, not %d", i, report.proxyTargetBytes) - } - if report.clientProxyBytes <= report.proxyTargetBytes { - t.Errorf("Expected nonzero input overhead (%d > %d)", report.clientProxyBytes, report.proxyTargetBytes) - } - if report.accessKey != "id-0" { - t.Errorf("Unexpected access key name: %s", report.accessKey) - } - if report.status != "OK" { - t.Errorf("Wrong status: %s", report.status) - } + assert.Equal(t, i+1, report.proxyTargetBytes, "Expected %d payload bytes, not %d", i+1, report.proxyTargetBytes) + assert.Greater(t, report.clientProxyBytes, report.proxyTargetBytes, "Expected nonzero input overhead (%d > %d)", report.clientProxyBytes, report.proxyTargetBytes) + assert.Equal(t, "id-0", report.accessKey, "Unexpected access key name: %s", report.accessKey) + assert.Equal(t, "OK", report.status, "Wrong status: %s", report.status) } }