Skip to content

Commit 4f3ce4d

Browse files
author
Benjamin M. Schwartz
authored
Merge pull request #97 from Jigsaw-Code/bemasc-keyid
Bugfix: track the keyID for UDP connections
2 parents 4759578 + 4a91989 commit 4f3ce4d

File tree

2 files changed

+69
-29
lines changed

2 files changed

+69
-29
lines changed

service/udp.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,8 @@ func (s *udpService) Serve(clientConn net.PacketConn) error {
173173
var tgtUDPAddr *net.UDPAddr
174174
targetConn := nm.Get(clientAddr.String())
175175
if targetConn == nil {
176-
clientLocation, locErr := s.m.GetLocation(clientAddr)
176+
var locErr error
177+
clientLocation, locErr = s.m.GetLocation(clientAddr)
177178
if locErr != nil {
178179
logger.Warningf("Failed location lookup: %v", locErr)
179180
}
@@ -201,19 +202,23 @@ func (s *udpService) Serve(clientConn net.PacketConn) error {
201202
}
202203
targetConn = nm.Add(clientAddr, clientConn, cipher, udpConn, clientLocation, keyID)
203204
} else {
205+
clientLocation = targetConn.clientLocation
206+
204207
unpackStart := time.Now()
205208
textData, err := ss.Unpack(nil, cipherData, targetConn.cipher)
206209
timeToCipher = time.Now().Sub(unpackStart)
207210
if err != nil {
208211
return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack data from client", err)
209212
}
210213

214+
// The key ID is known with confidence once decryption succeeds.
215+
keyID = targetConn.keyID
216+
211217
var onetErr *onet.ConnectionError
212218
if payload, tgtUDPAddr, onetErr = s.validatePacket(textData); onetErr != nil {
213219
return onetErr
214220
}
215221
}
216-
clientLocation = targetConn.clientLocation
217222

218223
debugUDPAddr(clientAddr, "Proxy exit %v", targetConn.LocalAddr())
219224
proxyTargetBytes, err = targetConn.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature
@@ -271,6 +276,7 @@ func isDNS(addr net.Addr) bool {
271276
type natconn struct {
272277
net.PacketConn
273278
cipher *ss.Cipher
279+
keyID string
274280
// We store the client location in the NAT map to avoid recomputing it
275281
// for every downstream packet in a UDP-based connection.
276282
clientLocation string
@@ -351,10 +357,11 @@ func (m *natmap) Get(key string) *natconn {
351357
return m.keyConn[key]
352358
}
353359

354-
func (m *natmap) set(key string, pc net.PacketConn, cipher *ss.Cipher, clientLocation string) *natconn {
360+
func (m *natmap) set(key string, pc net.PacketConn, cipher *ss.Cipher, keyID, clientLocation string) *natconn {
355361
entry := &natconn{
356362
PacketConn: pc,
357363
cipher: cipher,
364+
keyID: keyID,
358365
clientLocation: clientLocation,
359366
defaultTimeout: m.timeout,
360367
}
@@ -379,7 +386,7 @@ func (m *natmap) del(key string) net.PacketConn {
379386
}
380387

381388
func (m *natmap) Add(clientAddr net.Addr, clientConn net.PacketConn, cipher *ss.Cipher, targetConn net.PacketConn, clientLocation, keyID string) *natconn {
382-
entry := m.set(clientAddr.String(), targetConn, cipher, clientLocation)
389+
entry := m.set(clientAddr.String(), targetConn, cipher, keyID, clientLocation)
383390

384391
m.metrics.AddUDPNatEntry()
385392
m.running.Add(1)

service/udp_test.go

Lines changed: 58 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727
ss "github.com/Jigsaw-Code/outline-ss-server/shadowsocks"
2828
logging "github.com/op/go-logging"
2929
"github.com/shadowsocks/go-shadowsocks2/socks"
30+
"github.com/stretchr/testify/assert"
3031
)
3132

3233
const timeout = 5 * time.Minute
@@ -89,10 +90,16 @@ func (conn *fakePacketConn) Close() error {
8990
return nil
9091
}
9192

93+
type udpReport struct {
94+
clientLocation, accessKey, status string
95+
clientProxyBytes, proxyTargetBytes int
96+
}
97+
9298
// Stub metrics implementation for testing NAT behaviors.
9399
type natTestMetrics struct {
94100
metrics.ShadowsocksMetrics
95101
natEntriesAdded int
102+
upstreamPackets []udpReport
96103
}
97104

98105
func (m *natTestMetrics) AddTCPProbe(clientLocation, status, drainResult string, port int, data metrics.ProxyMetrics) {
@@ -107,6 +114,7 @@ func (m *natTestMetrics) SetNumAccessKeys(numKeys int, numPorts int) {
107114
func (m *natTestMetrics) AddOpenTCPConnection(clientLocation string) {
108115
}
109116
func (m *natTestMetrics) AddUDPPacketFromClient(clientLocation, accessKey, status string, clientProxyBytes, proxyTargetBytes int, timeToCipher time.Duration) {
117+
m.upstreamPackets = append(m.upstreamPackets, udpReport{clientLocation, accessKey, status, clientProxyBytes, proxyTargetBytes})
110118
}
111119
func (m *natTestMetrics) AddUDPPacketFromTarget(clientLocation, accessKey, status string, targetProxyBytes, proxyClientBytes int) {
112120
}
@@ -115,21 +123,20 @@ func (m *natTestMetrics) AddUDPNatEntry() {
115123
}
116124
func (m *natTestMetrics) RemoveUDPNatEntry() {}
117125

118-
func TestIPFilter(t *testing.T) {
119-
// Takes a validation policy, and returns the metrics it
120-
// generates when localhost access is attempted
121-
checkLocalhost := func(validator onet.TargetIPValidator) *natTestMetrics {
122-
ciphers, _ := MakeTestCiphers([]string{"asdf"})
123-
cipher := ciphers.SnapshotForClientIP(nil)[0].Value.(*CipherEntry).Cipher
124-
clientConn := makePacketConn()
125-
metrics := &natTestMetrics{}
126-
service := NewUDPService(timeout, ciphers, metrics)
127-
service.SetTargetIPValidator(validator)
128-
go service.Serve(clientConn)
129-
130-
// Send one packet to the "discard" port on localhost
131-
targetAddr := socks.ParseAddr("127.0.0.1:9")
132-
payload := []byte("payload")
126+
// Takes a validation policy, and returns the metrics it
127+
// generates when localhost access is attempted
128+
func sendToDiscard(payloads [][]byte, validator onet.TargetIPValidator) *natTestMetrics {
129+
ciphers, _ := MakeTestCiphers([]string{"asdf"})
130+
cipher := ciphers.SnapshotForClientIP(nil)[0].Value.(*CipherEntry).Cipher
131+
clientConn := makePacketConn()
132+
metrics := &natTestMetrics{}
133+
service := NewUDPService(timeout, ciphers, metrics)
134+
service.SetTargetIPValidator(validator)
135+
go service.Serve(clientConn)
136+
137+
// Send one packet to the "discard" port on localhost
138+
targetAddr := socks.ParseAddr("127.0.0.1:9")
139+
for _, payload := range payloads {
133140
plaintext := append(targetAddr, payload...)
134141
ciphertext := make([]byte, cipher.SaltSize()+len(plaintext)+cipher.TagSize())
135142
ss.Pack(ciphertext, plaintext, cipher)
@@ -140,26 +147,52 @@ func TestIPFilter(t *testing.T) {
140147
},
141148
payload: ciphertext,
142149
}
143-
144-
service.GracefulStop()
145-
return metrics
146150
}
147151

152+
service.GracefulStop()
153+
return metrics
154+
}
155+
156+
func TestIPFilter(t *testing.T) {
157+
// Test both the first-packet and subsequent-packet cases.
158+
payloads := [][]byte{[]byte("payload1"), []byte("payload2")}
159+
148160
t.Run("Localhost allowed", func(t *testing.T) {
149-
metrics := checkLocalhost(allowAll)
150-
if metrics.natEntriesAdded != 1 {
151-
t.Errorf("Expected 1 NAT entry, not %d", metrics.natEntriesAdded)
152-
}
161+
metrics := sendToDiscard(payloads, allowAll)
162+
assert.Equal(t, metrics.natEntriesAdded, 1, "Expected 1 NAT entry, not %d", metrics.natEntriesAdded)
153163
})
154164

155165
t.Run("Localhost not allowed", func(t *testing.T) {
156-
metrics := checkLocalhost(onet.RequirePublicIP)
157-
if metrics.natEntriesAdded != 0 {
158-
t.Error("Unexpected NAT entry on rejected packet")
166+
metrics := sendToDiscard(payloads, onet.RequirePublicIP)
167+
assert.Equal(t, 0, metrics.natEntriesAdded, "Unexpected NAT entry on rejected packet")
168+
assert.Equal(t, 2, len(metrics.upstreamPackets), "Expected 2 reports, not %v", metrics.upstreamPackets)
169+
for _, report := range metrics.upstreamPackets {
170+
assert.Greater(t, report.clientProxyBytes, 0, "Expected nonzero input packet size")
171+
assert.Equal(t, 0, report.proxyTargetBytes, "No bytes should be sent due to a disallowed packet")
172+
assert.Equal(t, report.accessKey, "id-0", "Unexpected access key: %s", report.accessKey)
159173
}
160174
})
161175
}
162176

177+
func TestUpstreamMetrics(t *testing.T) {
178+
// Test both the first-packet and subsequent-packet cases.
179+
const N = 10
180+
payloads := make([][]byte, 0)
181+
for i := 1; i <= N; i++ {
182+
payloads = append(payloads, make([]byte, i))
183+
}
184+
185+
metrics := sendToDiscard(payloads, allowAll)
186+
187+
assert.Equal(t, N, len(metrics.upstreamPackets), "Expected %d reports, not %v", N, metrics.upstreamPackets)
188+
for i, report := range metrics.upstreamPackets {
189+
assert.Equal(t, i+1, report.proxyTargetBytes, "Expected %d payload bytes, not %d", i+1, report.proxyTargetBytes)
190+
assert.Greater(t, report.clientProxyBytes, report.proxyTargetBytes, "Expected nonzero input overhead (%d > %d)", report.clientProxyBytes, report.proxyTargetBytes)
191+
assert.Equal(t, "id-0", report.accessKey, "Unexpected access key name: %s", report.accessKey)
192+
assert.Equal(t, "OK", report.status, "Wrong status: %s", report.status)
193+
}
194+
}
195+
163196
func assertAlmostEqual(t *testing.T, a, b time.Time) {
164197
delta := a.Sub(b)
165198
limit := 100 * time.Millisecond

0 commit comments

Comments
 (0)