Skip to content

Commit

Permalink
Merge pull request #95 from Jigsaw-Code/bemasc-reject
Browse files Browse the repository at this point in the history
Avoid creating NAT mappings for rejected packets
  • Loading branch information
Benjamin M. Schwartz authored Nov 2, 2020
2 parents 264ddf9 + 262ac75 commit 6e8119b
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 30 deletions.
51 changes: 35 additions & 16 deletions service/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ func (s *udpService) Serve(clientConn net.PacketConn) error {
}

cipherData := cipherBuf[:clientProxyBytes]
var textData []byte
var payload []byte
var tgtUDPAddr *net.UDPAddr
targetConn := nm.Get(clientAddr.String())
if targetConn == nil {
clientLocation, locErr := s.m.GetLocation(clientAddr)
Expand All @@ -179,6 +180,7 @@ func (s *udpService) Serve(clientConn net.PacketConn) error {
debugUDPAddr(clientAddr, "Got location \"%s\"", clientLocation)

ip := clientAddr.(*net.UDPAddr).IP
var textData []byte
var cipher *ss.Cipher
unpackStart := time.Now()
textData, keyID, cipher, err = findAccessKeyUDP(ip, textBuf, cipherData, s.ciphers)
Expand All @@ -188,36 +190,32 @@ func (s *udpService) Serve(clientConn net.PacketConn) error {
return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack initial packet", err)
}

var onetErr *onet.ConnectionError
if payload, tgtUDPAddr, onetErr = s.validatePacket(textData); onetErr != nil {
return onetErr
}

udpConn, err := net.ListenPacket("udp", "")
if err != nil {
return onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create UDP socket", err)
}
targetConn = nm.Add(clientAddr, clientConn, cipher, udpConn, clientLocation, keyID)
} else {
unpackStart := time.Now()
textData, err = ss.Unpack(nil, cipherData, targetConn.cipher)
textData, err := ss.Unpack(nil, cipherData, targetConn.cipher)
timeToCipher = time.Now().Sub(unpackStart)
if err != nil {
return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack data from client", err)
}
}
clientLocation = targetConn.clientLocation

tgtAddr := socks.SplitAddr(textData)
if tgtAddr == nil {
return onet.NewConnectionError("ERR_READ_ADDRESS", "Failed to get target address", nil)
}

tgtUDPAddr, err := net.ResolveUDPAddr("udp", tgtAddr.String())
if err != nil {
return onet.NewConnectionError("ERR_RESOLVE_ADDRESS", fmt.Sprintf("Failed to resolve target address %v", tgtAddr), err)
}
if err := s.targetIPValidator(tgtUDPAddr.IP); err != nil {
return err
var onetErr *onet.ConnectionError
if payload, tgtUDPAddr, onetErr = s.validatePacket(textData); onetErr != nil {
return onetErr
}
}
clientLocation = targetConn.clientLocation

debugUDPAddr(clientAddr, "Proxy exit %v", targetConn.LocalAddr())
payload := textData[len(tgtAddr):]
proxyTargetBytes, err = targetConn.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature
if err != nil {
return onet.NewConnectionError("ERR_WRITE", "Failed to write to target", err)
Expand All @@ -228,6 +226,27 @@ func (s *udpService) Serve(clientConn net.PacketConn) error {
return nil
}

// Given the decrypted contents of a UDP packet, return
// the payload and the destination address, or an error if
// this packet cannot or should not be forwarded.
func (s *udpService) validatePacket(textData []byte) ([]byte, *net.UDPAddr, *onet.ConnectionError) {
tgtAddr := socks.SplitAddr(textData)
if tgtAddr == nil {
return nil, nil, onet.NewConnectionError("ERR_READ_ADDRESS", "Failed to get target address", nil)
}

tgtUDPAddr, err := net.ResolveUDPAddr("udp", tgtAddr.String())
if err != nil {
return nil, nil, onet.NewConnectionError("ERR_RESOLVE_ADDRESS", fmt.Sprintf("Failed to resolve target address %v", tgtAddr), err)
}
if err := s.targetIPValidator(tgtUDPAddr.IP); err != nil {
return nil, nil, err
}

payload := textData[len(tgtAddr):]
return payload, tgtUDPAddr, nil
}

func (s *udpService) Stop() error {
s.mu.Lock()
defer s.mu.Unlock()
Expand Down
102 changes: 88 additions & 14 deletions service/udp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@ import (
"testing"
"time"

onet "github.com/Jigsaw-Code/outline-ss-server/net"
"github.com/Jigsaw-Code/outline-ss-server/service/metrics"
ss "github.com/Jigsaw-Code/outline-ss-server/shadowsocks"
logging "github.com/op/go-logging"
"github.com/shadowsocks/go-shadowsocks2/socks"
)

const timeout = 5 * time.Minute
Expand Down Expand Up @@ -86,6 +89,77 @@ func (conn *fakePacketConn) Close() error {
return nil
}

// Stub metrics implementation for testing NAT behaviors.
type natTestMetrics struct {
metrics.ShadowsocksMetrics
natEntriesAdded int
}

func (m *natTestMetrics) AddTCPProbe(clientLocation, status, drainResult string, port int, data metrics.ProxyMetrics) {
}
func (m *natTestMetrics) AddClosedTCPConnection(clientLocation, accessKey, status string, data metrics.ProxyMetrics, timeToCipher, duration time.Duration) {
}
func (m *natTestMetrics) GetLocation(net.Addr) (string, error) {
return "", nil
}
func (m *natTestMetrics) SetNumAccessKeys(numKeys int, numPorts int) {
}
func (m *natTestMetrics) AddOpenTCPConnection(clientLocation string) {
}
func (m *natTestMetrics) AddUDPPacketFromClient(clientLocation, accessKey, status string, clientProxyBytes, proxyTargetBytes int, timeToCipher time.Duration) {
}
func (m *natTestMetrics) AddUDPPacketFromTarget(clientLocation, accessKey, status string, targetProxyBytes, proxyClientBytes int) {
}
func (m *natTestMetrics) AddUDPNatEntry() {
m.natEntriesAdded++
}
func (m *natTestMetrics) RemoveUDPNatEntry() {}

func TestIPFilter(t *testing.T) {
// Takes a validation policy, and returns the metrics it
// generates when localhost access is attempted
checkLocalhost := func(validator onet.TargetIPValidator) *natTestMetrics {
ciphers, _ := MakeTestCiphers([]string{"asdf"})
cipher := ciphers.SnapshotForClientIP(nil)[0].Value.(*CipherEntry).Cipher
clientConn := makePacketConn()
metrics := &natTestMetrics{}
service := NewUDPService(timeout, ciphers, metrics)
service.SetTargetIPValidator(validator)
go service.Serve(clientConn)

// Send one packet to the "discard" port on localhost
targetAddr := socks.ParseAddr("127.0.0.1:9")
payload := []byte("payload")
plaintext := append(targetAddr, payload...)
ciphertext := make([]byte, cipher.SaltSize()+len(plaintext)+cipher.TagSize())
ss.Pack(ciphertext, plaintext, cipher)
clientConn.recv <- packet{
addr: &net.UDPAddr{
IP: net.ParseIP("192.0.2.1"),
Port: 54321,
},
payload: ciphertext,
}

service.GracefulStop()
return metrics
}

t.Run("Localhost allowed", func(t *testing.T) {
metrics := checkLocalhost(allowAll)
if metrics.natEntriesAdded != 1 {
t.Errorf("Expected 1 NAT entry, not %d", metrics.natEntriesAdded)
}
})

t.Run("Localhost not allowed", func(t *testing.T) {
metrics := checkLocalhost(onet.RequirePublicIP)
if metrics.natEntriesAdded != 0 {
t.Error("Unexpected NAT entry on rejected packet")
}
})
}

func assertAlmostEqual(t *testing.T, a, b time.Time) {
delta := a.Sub(b)
limit := 100 * time.Millisecond
Expand All @@ -95,14 +169,14 @@ func assertAlmostEqual(t *testing.T, a, b time.Time) {
}

func TestNATEmpty(t *testing.T) {
nat := newNATmap(timeout, &probeTestMetrics{}, &sync.WaitGroup{})
nat := newNATmap(timeout, &natTestMetrics{}, &sync.WaitGroup{})
if nat.Get("foo") != nil {
t.Error("Expected nil value from empty NAT map")
}
}

func setup() (*fakePacketConn, *fakePacketConn, *natconn) {
nat := newNATmap(timeout, &probeTestMetrics{}, &sync.WaitGroup{})
func setupNAT() (*fakePacketConn, *fakePacketConn, *natconn) {
nat := newNATmap(timeout, &natTestMetrics{}, &sync.WaitGroup{})
clientConn := makePacketConn()
targetConn := makePacketConn()
nat.Add(&clientAddr, clientConn, natCipher, targetConn, "ZZ", "key id")
Expand All @@ -111,7 +185,7 @@ func setup() (*fakePacketConn, *fakePacketConn, *natconn) {
}

func TestNATGet(t *testing.T) {
_, targetConn, entry := setup()
_, targetConn, entry := setupNAT()
if entry == nil {
t.Fatal("Failed to find target conn")
}
Expand All @@ -121,7 +195,7 @@ func TestNATGet(t *testing.T) {
}

func TestNATWrite(t *testing.T) {
_, targetConn, entry := setup()
_, targetConn, entry := setupNAT()

// Simulate one generic packet being sent
buf := []byte{1}
Expand All @@ -137,7 +211,7 @@ func TestNATWrite(t *testing.T) {
}

func TestNATWriteDNS(t *testing.T) {
_, targetConn, entry := setup()
_, targetConn, entry := setupNAT()

// Simulate one DNS query being sent.
buf := []byte{1}
Expand All @@ -154,7 +228,7 @@ func TestNATWriteDNS(t *testing.T) {
}

func TestNATWriteDNSMultiple(t *testing.T) {
_, targetConn, entry := setup()
_, targetConn, entry := setupNAT()

// Simulate three DNS queries being sent.
buf := []byte{1}
Expand All @@ -169,7 +243,7 @@ func TestNATWriteDNSMultiple(t *testing.T) {
}

func TestNATWriteMixed(t *testing.T) {
_, targetConn, entry := setup()
_, targetConn, entry := setupNAT()

// Simulate both non-DNS and DNS packets being sent.
buf := []byte{1}
Expand All @@ -182,7 +256,7 @@ func TestNATWriteMixed(t *testing.T) {
}

func TestNATFastClose(t *testing.T) {
clientConn, targetConn, entry := setup()
clientConn, targetConn, entry := setupNAT()

// Send one DNS query.
query := []byte{1}
Expand All @@ -208,7 +282,7 @@ func TestNATFastClose(t *testing.T) {
}

func TestNATNoFastClose_NotDNS(t *testing.T) {
clientConn, targetConn, entry := setup()
clientConn, targetConn, entry := setupNAT()

// Send one non-DNS packet.
query := []byte{1}
Expand All @@ -233,7 +307,7 @@ func TestNATNoFastClose_NotDNS(t *testing.T) {
}

func TestNATNoFastClose_MultipleDNS(t *testing.T) {
clientConn, targetConn, entry := setup()
clientConn, targetConn, entry := setupNAT()

// Send two DNS packets.
query1 := []byte{1}
Expand Down Expand Up @@ -267,7 +341,7 @@ func (e *fakeTimeoutError) Temporary() bool {
}

func TestNATTimeout(t *testing.T) {
_, targetConn, entry := setup()
_, targetConn, entry := setupNAT()

// Simulate a non-DNS initial packet.
entry.WriteTo([]byte{1}, &targetAddr)
Expand Down Expand Up @@ -365,7 +439,7 @@ func TestUDPDoubleServe(t *testing.T) {
if err != nil {
t.Fatal(err)
}
testMetrics := &probeTestMetrics{}
testMetrics := &natTestMetrics{}
const testTimeout = 200 * time.Millisecond
s := NewUDPService(testTimeout, cipherList, testMetrics)

Expand Down Expand Up @@ -399,7 +473,7 @@ func TestUDPEarlyStop(t *testing.T) {
if err != nil {
t.Fatal(err)
}
testMetrics := &probeTestMetrics{}
testMetrics := &natTestMetrics{}
const testTimeout = 200 * time.Millisecond
s := NewUDPService(testTimeout, cipherList, testMetrics)

Expand Down

0 comments on commit 6e8119b

Please sign in to comment.