Skip to content

Commit

Permalink
Merge pull request #97 from Jigsaw-Code/bemasc-keyid
Browse files Browse the repository at this point in the history
Bugfix: track the keyID for UDP connections
  • Loading branch information
Benjamin M. Schwartz authored Nov 6, 2020
2 parents 4759578 + 4a91989 commit 4f3ce4d
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 29 deletions.
15 changes: 11 additions & 4 deletions service/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ func (s *udpService) Serve(clientConn net.PacketConn) error {
var tgtUDPAddr *net.UDPAddr
targetConn := nm.Get(clientAddr.String())
if targetConn == nil {
clientLocation, locErr := s.m.GetLocation(clientAddr)
var locErr error
clientLocation, locErr = s.m.GetLocation(clientAddr)
if locErr != nil {
logger.Warningf("Failed location lookup: %v", locErr)
}
Expand Down Expand Up @@ -201,19 +202,23 @@ func (s *udpService) Serve(clientConn net.PacketConn) error {
}
targetConn = nm.Add(clientAddr, clientConn, cipher, udpConn, clientLocation, keyID)
} else {
clientLocation = targetConn.clientLocation

unpackStart := time.Now()
textData, err := ss.Unpack(nil, cipherData, targetConn.cipher)
timeToCipher = time.Now().Sub(unpackStart)
if err != nil {
return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack data from client", err)
}

// The key ID is known with confidence once decryption succeeds.
keyID = targetConn.keyID

var onetErr *onet.ConnectionError
if payload, tgtUDPAddr, onetErr = s.validatePacket(textData); onetErr != nil {
return onetErr
}
}
clientLocation = targetConn.clientLocation

debugUDPAddr(clientAddr, "Proxy exit %v", targetConn.LocalAddr())
proxyTargetBytes, err = targetConn.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature
Expand Down Expand Up @@ -271,6 +276,7 @@ func isDNS(addr net.Addr) bool {
type natconn struct {
net.PacketConn
cipher *ss.Cipher
keyID string
// We store the client location in the NAT map to avoid recomputing it
// for every downstream packet in a UDP-based connection.
clientLocation string
Expand Down Expand Up @@ -351,10 +357,11 @@ func (m *natmap) Get(key string) *natconn {
return m.keyConn[key]
}

func (m *natmap) set(key string, pc net.PacketConn, cipher *ss.Cipher, clientLocation string) *natconn {
func (m *natmap) set(key string, pc net.PacketConn, cipher *ss.Cipher, keyID, clientLocation string) *natconn {
entry := &natconn{
PacketConn: pc,
cipher: cipher,
keyID: keyID,
clientLocation: clientLocation,
defaultTimeout: m.timeout,
}
Expand All @@ -379,7 +386,7 @@ func (m *natmap) del(key string) net.PacketConn {
}

func (m *natmap) Add(clientAddr net.Addr, clientConn net.PacketConn, cipher *ss.Cipher, targetConn net.PacketConn, clientLocation, keyID string) *natconn {
entry := m.set(clientAddr.String(), targetConn, cipher, clientLocation)
entry := m.set(clientAddr.String(), targetConn, cipher, keyID, clientLocation)

m.metrics.AddUDPNatEntry()
m.running.Add(1)
Expand Down
83 changes: 58 additions & 25 deletions service/udp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
ss "github.com/Jigsaw-Code/outline-ss-server/shadowsocks"
logging "github.com/op/go-logging"
"github.com/shadowsocks/go-shadowsocks2/socks"
"github.com/stretchr/testify/assert"
)

const timeout = 5 * time.Minute
Expand Down Expand Up @@ -89,10 +90,16 @@ func (conn *fakePacketConn) Close() error {
return nil
}

type udpReport struct {
clientLocation, accessKey, status string
clientProxyBytes, proxyTargetBytes int
}

// Stub metrics implementation for testing NAT behaviors.
type natTestMetrics struct {
metrics.ShadowsocksMetrics
natEntriesAdded int
upstreamPackets []udpReport
}

func (m *natTestMetrics) AddTCPProbe(clientLocation, status, drainResult string, port int, data metrics.ProxyMetrics) {
Expand All @@ -107,6 +114,7 @@ func (m *natTestMetrics) SetNumAccessKeys(numKeys int, numPorts int) {
func (m *natTestMetrics) AddOpenTCPConnection(clientLocation string) {
}
func (m *natTestMetrics) AddUDPPacketFromClient(clientLocation, accessKey, status string, clientProxyBytes, proxyTargetBytes int, timeToCipher time.Duration) {
m.upstreamPackets = append(m.upstreamPackets, udpReport{clientLocation, accessKey, status, clientProxyBytes, proxyTargetBytes})
}
func (m *natTestMetrics) AddUDPPacketFromTarget(clientLocation, accessKey, status string, targetProxyBytes, proxyClientBytes int) {
}
Expand All @@ -115,21 +123,20 @@ func (m *natTestMetrics) AddUDPNatEntry() {
}
func (m *natTestMetrics) RemoveUDPNatEntry() {}

func TestIPFilter(t *testing.T) {
// Takes a validation policy, and returns the metrics it
// generates when localhost access is attempted
checkLocalhost := func(validator onet.TargetIPValidator) *natTestMetrics {
ciphers, _ := MakeTestCiphers([]string{"asdf"})
cipher := ciphers.SnapshotForClientIP(nil)[0].Value.(*CipherEntry).Cipher
clientConn := makePacketConn()
metrics := &natTestMetrics{}
service := NewUDPService(timeout, ciphers, metrics)
service.SetTargetIPValidator(validator)
go service.Serve(clientConn)

// Send one packet to the "discard" port on localhost
targetAddr := socks.ParseAddr("127.0.0.1:9")
payload := []byte("payload")
// Takes a validation policy, and returns the metrics it
// generates when localhost access is attempted
func sendToDiscard(payloads [][]byte, validator onet.TargetIPValidator) *natTestMetrics {
ciphers, _ := MakeTestCiphers([]string{"asdf"})
cipher := ciphers.SnapshotForClientIP(nil)[0].Value.(*CipherEntry).Cipher
clientConn := makePacketConn()
metrics := &natTestMetrics{}
service := NewUDPService(timeout, ciphers, metrics)
service.SetTargetIPValidator(validator)
go service.Serve(clientConn)

// Send one packet to the "discard" port on localhost
targetAddr := socks.ParseAddr("127.0.0.1:9")
for _, payload := range payloads {
plaintext := append(targetAddr, payload...)
ciphertext := make([]byte, cipher.SaltSize()+len(plaintext)+cipher.TagSize())
ss.Pack(ciphertext, plaintext, cipher)
Expand All @@ -140,26 +147,52 @@ func TestIPFilter(t *testing.T) {
},
payload: ciphertext,
}

service.GracefulStop()
return metrics
}

service.GracefulStop()
return metrics
}

func TestIPFilter(t *testing.T) {
// Test both the first-packet and subsequent-packet cases.
payloads := [][]byte{[]byte("payload1"), []byte("payload2")}

t.Run("Localhost allowed", func(t *testing.T) {
metrics := checkLocalhost(allowAll)
if metrics.natEntriesAdded != 1 {
t.Errorf("Expected 1 NAT entry, not %d", metrics.natEntriesAdded)
}
metrics := sendToDiscard(payloads, allowAll)
assert.Equal(t, metrics.natEntriesAdded, 1, "Expected 1 NAT entry, not %d", metrics.natEntriesAdded)
})

t.Run("Localhost not allowed", func(t *testing.T) {
metrics := checkLocalhost(onet.RequirePublicIP)
if metrics.natEntriesAdded != 0 {
t.Error("Unexpected NAT entry on rejected packet")
metrics := sendToDiscard(payloads, onet.RequirePublicIP)
assert.Equal(t, 0, metrics.natEntriesAdded, "Unexpected NAT entry on rejected packet")
assert.Equal(t, 2, len(metrics.upstreamPackets), "Expected 2 reports, not %v", metrics.upstreamPackets)
for _, report := range metrics.upstreamPackets {
assert.Greater(t, report.clientProxyBytes, 0, "Expected nonzero input packet size")
assert.Equal(t, 0, report.proxyTargetBytes, "No bytes should be sent due to a disallowed packet")
assert.Equal(t, report.accessKey, "id-0", "Unexpected access key: %s", report.accessKey)
}
})
}

func TestUpstreamMetrics(t *testing.T) {
// Test both the first-packet and subsequent-packet cases.
const N = 10
payloads := make([][]byte, 0)
for i := 1; i <= N; i++ {
payloads = append(payloads, make([]byte, i))
}

metrics := sendToDiscard(payloads, allowAll)

assert.Equal(t, N, len(metrics.upstreamPackets), "Expected %d reports, not %v", N, metrics.upstreamPackets)
for i, report := range metrics.upstreamPackets {
assert.Equal(t, i+1, report.proxyTargetBytes, "Expected %d payload bytes, not %d", i+1, report.proxyTargetBytes)
assert.Greater(t, report.clientProxyBytes, report.proxyTargetBytes, "Expected nonzero input overhead (%d > %d)", report.clientProxyBytes, report.proxyTargetBytes)
assert.Equal(t, "id-0", report.accessKey, "Unexpected access key name: %s", report.accessKey)
assert.Equal(t, "OK", report.status, "Wrong status: %s", report.status)
}
}

func assertAlmostEqual(t *testing.T, a, b time.Time) {
delta := a.Sub(b)
limit := 100 * time.Millisecond
Expand Down

0 comments on commit 4f3ce4d

Please sign in to comment.