diff --git a/service/udp.go b/service/udp.go index cb48b123..6434ec38 100644 --- a/service/udp.go +++ b/service/udp.go @@ -169,7 +169,8 @@ func (s *udpService) Serve(clientConn net.PacketConn) error { } cipherData := cipherBuf[:clientProxyBytes] - var textData []byte + var payload []byte + var tgtUDPAddr *net.UDPAddr targetConn := nm.Get(clientAddr.String()) if targetConn == nil { clientLocation, locErr := s.m.GetLocation(clientAddr) @@ -179,6 +180,7 @@ func (s *udpService) Serve(clientConn net.PacketConn) error { debugUDPAddr(clientAddr, "Got location \"%s\"", clientLocation) ip := clientAddr.(*net.UDPAddr).IP + var textData []byte var cipher *ss.Cipher unpackStart := time.Now() textData, keyID, cipher, err = findAccessKeyUDP(ip, textBuf, cipherData, s.ciphers) @@ -188,6 +190,11 @@ func (s *udpService) Serve(clientConn net.PacketConn) error { return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack initial packet", err) } + var onetErr *onet.ConnectionError + if payload, tgtUDPAddr, onetErr = s.validatePacket(textData); onetErr != nil { + return onetErr + } + udpConn, err := net.ListenPacket("udp", "") if err != nil { return onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create UDP socket", err) @@ -195,29 +202,20 @@ func (s *udpService) Serve(clientConn net.PacketConn) error { targetConn = nm.Add(clientAddr, clientConn, cipher, udpConn, clientLocation, keyID) } else { unpackStart := time.Now() - textData, err = ss.Unpack(nil, cipherData, targetConn.cipher) + textData, err := ss.Unpack(nil, cipherData, targetConn.cipher) timeToCipher = time.Now().Sub(unpackStart) if err != nil { return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack data from client", err) } - } - clientLocation = targetConn.clientLocation - tgtAddr := socks.SplitAddr(textData) - if tgtAddr == nil { - return onet.NewConnectionError("ERR_READ_ADDRESS", "Failed to get target address", nil) - } - - tgtUDPAddr, err := net.ResolveUDPAddr("udp", tgtAddr.String()) - if err != nil { - return onet.NewConnectionError("ERR_RESOLVE_ADDRESS", fmt.Sprintf("Failed to resolve target address %v", tgtAddr), err) - } - if err := s.targetIPValidator(tgtUDPAddr.IP); err != nil { - return err + var onetErr *onet.ConnectionError + if payload, tgtUDPAddr, onetErr = s.validatePacket(textData); onetErr != nil { + return onetErr + } } + clientLocation = targetConn.clientLocation debugUDPAddr(clientAddr, "Proxy exit %v", targetConn.LocalAddr()) - payload := textData[len(tgtAddr):] proxyTargetBytes, err = targetConn.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature if err != nil { return onet.NewConnectionError("ERR_WRITE", "Failed to write to target", err) @@ -228,6 +226,27 @@ func (s *udpService) Serve(clientConn net.PacketConn) error { return nil } +// Given the decrypted contents of a UDP packet, return +// the payload and the destination address, or an error if +// this packet cannot or should not be forwarded. +func (s *udpService) validatePacket(textData []byte) ([]byte, *net.UDPAddr, *onet.ConnectionError) { + tgtAddr := socks.SplitAddr(textData) + if tgtAddr == nil { + return nil, nil, onet.NewConnectionError("ERR_READ_ADDRESS", "Failed to get target address", nil) + } + + tgtUDPAddr, err := net.ResolveUDPAddr("udp", tgtAddr.String()) + if err != nil { + return nil, nil, onet.NewConnectionError("ERR_RESOLVE_ADDRESS", fmt.Sprintf("Failed to resolve target address %v", tgtAddr), err) + } + if err := s.targetIPValidator(tgtUDPAddr.IP); err != nil { + return nil, nil, err + } + + payload := textData[len(tgtAddr):] + return payload, tgtUDPAddr, nil +} + func (s *udpService) Stop() error { s.mu.Lock() defer s.mu.Unlock() diff --git a/service/udp_test.go b/service/udp_test.go index 1c047555..c30469ef 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -22,8 +22,11 @@ import ( "testing" "time" + onet "github.com/Jigsaw-Code/outline-ss-server/net" + "github.com/Jigsaw-Code/outline-ss-server/service/metrics" ss "github.com/Jigsaw-Code/outline-ss-server/shadowsocks" logging "github.com/op/go-logging" + "github.com/shadowsocks/go-shadowsocks2/socks" ) const timeout = 5 * time.Minute @@ -86,6 +89,77 @@ func (conn *fakePacketConn) Close() error { return nil } +// Stub metrics implementation for testing NAT behaviors. +type natTestMetrics struct { + metrics.ShadowsocksMetrics + natEntriesAdded int +} + +func (m *natTestMetrics) AddTCPProbe(clientLocation, status, drainResult string, port int, data metrics.ProxyMetrics) { +} +func (m *natTestMetrics) AddClosedTCPConnection(clientLocation, accessKey, status string, data metrics.ProxyMetrics, timeToCipher, duration time.Duration) { +} +func (m *natTestMetrics) GetLocation(net.Addr) (string, error) { + return "", nil +} +func (m *natTestMetrics) SetNumAccessKeys(numKeys int, numPorts int) { +} +func (m *natTestMetrics) AddOpenTCPConnection(clientLocation string) { +} +func (m *natTestMetrics) AddUDPPacketFromClient(clientLocation, accessKey, status string, clientProxyBytes, proxyTargetBytes int, timeToCipher time.Duration) { +} +func (m *natTestMetrics) AddUDPPacketFromTarget(clientLocation, accessKey, status string, targetProxyBytes, proxyClientBytes int) { +} +func (m *natTestMetrics) AddUDPNatEntry() { + m.natEntriesAdded++ +} +func (m *natTestMetrics) RemoveUDPNatEntry() {} + +func TestIPFilter(t *testing.T) { + // Takes a validation policy, and returns the metrics it + // generates when localhost access is attempted + checkLocalhost := func(validator onet.TargetIPValidator) *natTestMetrics { + ciphers, _ := MakeTestCiphers([]string{"asdf"}) + cipher := ciphers.SnapshotForClientIP(nil)[0].Value.(*CipherEntry).Cipher + clientConn := makePacketConn() + metrics := &natTestMetrics{} + service := NewUDPService(timeout, ciphers, metrics) + service.SetTargetIPValidator(validator) + go service.Serve(clientConn) + + // Send one packet to the "discard" port on localhost + targetAddr := socks.ParseAddr("127.0.0.1:9") + payload := []byte("payload") + plaintext := append(targetAddr, payload...) + ciphertext := make([]byte, cipher.SaltSize()+len(plaintext)+cipher.TagSize()) + ss.Pack(ciphertext, plaintext, cipher) + clientConn.recv <- packet{ + addr: &net.UDPAddr{ + IP: net.ParseIP("192.0.2.1"), + Port: 54321, + }, + payload: ciphertext, + } + + service.GracefulStop() + return metrics + } + + t.Run("Localhost allowed", func(t *testing.T) { + metrics := checkLocalhost(allowAll) + if metrics.natEntriesAdded != 1 { + t.Errorf("Expected 1 NAT entry, not %d", metrics.natEntriesAdded) + } + }) + + t.Run("Localhost not allowed", func(t *testing.T) { + metrics := checkLocalhost(onet.RequirePublicIP) + if metrics.natEntriesAdded != 0 { + t.Error("Unexpected NAT entry on rejected packet") + } + }) +} + func assertAlmostEqual(t *testing.T, a, b time.Time) { delta := a.Sub(b) limit := 100 * time.Millisecond @@ -95,14 +169,14 @@ func assertAlmostEqual(t *testing.T, a, b time.Time) { } func TestNATEmpty(t *testing.T) { - nat := newNATmap(timeout, &probeTestMetrics{}, &sync.WaitGroup{}) + nat := newNATmap(timeout, &natTestMetrics{}, &sync.WaitGroup{}) if nat.Get("foo") != nil { t.Error("Expected nil value from empty NAT map") } } -func setup() (*fakePacketConn, *fakePacketConn, *natconn) { - nat := newNATmap(timeout, &probeTestMetrics{}, &sync.WaitGroup{}) +func setupNAT() (*fakePacketConn, *fakePacketConn, *natconn) { + nat := newNATmap(timeout, &natTestMetrics{}, &sync.WaitGroup{}) clientConn := makePacketConn() targetConn := makePacketConn() nat.Add(&clientAddr, clientConn, natCipher, targetConn, "ZZ", "key id") @@ -111,7 +185,7 @@ func setup() (*fakePacketConn, *fakePacketConn, *natconn) { } func TestNATGet(t *testing.T) { - _, targetConn, entry := setup() + _, targetConn, entry := setupNAT() if entry == nil { t.Fatal("Failed to find target conn") } @@ -121,7 +195,7 @@ func TestNATGet(t *testing.T) { } func TestNATWrite(t *testing.T) { - _, targetConn, entry := setup() + _, targetConn, entry := setupNAT() // Simulate one generic packet being sent buf := []byte{1} @@ -137,7 +211,7 @@ func TestNATWrite(t *testing.T) { } func TestNATWriteDNS(t *testing.T) { - _, targetConn, entry := setup() + _, targetConn, entry := setupNAT() // Simulate one DNS query being sent. buf := []byte{1} @@ -154,7 +228,7 @@ func TestNATWriteDNS(t *testing.T) { } func TestNATWriteDNSMultiple(t *testing.T) { - _, targetConn, entry := setup() + _, targetConn, entry := setupNAT() // Simulate three DNS queries being sent. buf := []byte{1} @@ -169,7 +243,7 @@ func TestNATWriteDNSMultiple(t *testing.T) { } func TestNATWriteMixed(t *testing.T) { - _, targetConn, entry := setup() + _, targetConn, entry := setupNAT() // Simulate both non-DNS and DNS packets being sent. buf := []byte{1} @@ -182,7 +256,7 @@ func TestNATWriteMixed(t *testing.T) { } func TestNATFastClose(t *testing.T) { - clientConn, targetConn, entry := setup() + clientConn, targetConn, entry := setupNAT() // Send one DNS query. query := []byte{1} @@ -208,7 +282,7 @@ func TestNATFastClose(t *testing.T) { } func TestNATNoFastClose_NotDNS(t *testing.T) { - clientConn, targetConn, entry := setup() + clientConn, targetConn, entry := setupNAT() // Send one non-DNS packet. query := []byte{1} @@ -233,7 +307,7 @@ func TestNATNoFastClose_NotDNS(t *testing.T) { } func TestNATNoFastClose_MultipleDNS(t *testing.T) { - clientConn, targetConn, entry := setup() + clientConn, targetConn, entry := setupNAT() // Send two DNS packets. query1 := []byte{1} @@ -267,7 +341,7 @@ func (e *fakeTimeoutError) Temporary() bool { } func TestNATTimeout(t *testing.T) { - _, targetConn, entry := setup() + _, targetConn, entry := setupNAT() // Simulate a non-DNS initial packet. entry.WriteTo([]byte{1}, &targetAddr) @@ -365,7 +439,7 @@ func TestUDPDoubleServe(t *testing.T) { if err != nil { t.Fatal(err) } - testMetrics := &probeTestMetrics{} + testMetrics := &natTestMetrics{} const testTimeout = 200 * time.Millisecond s := NewUDPService(testTimeout, cipherList, testMetrics) @@ -399,7 +473,7 @@ func TestUDPEarlyStop(t *testing.T) { if err != nil { t.Fatal(err) } - testMetrics := &probeTestMetrics{} + testMetrics := &natTestMetrics{} const testTimeout = 200 * time.Millisecond s := NewUDPService(testTimeout, cipherList, testMetrics)