Skip to content

Commit 95e4aa9

Browse files
author
Ben Schwartz
committed
Bugfix: track the keyID for UDP connections
This is important for applying usage limits to UDP streams. It also fixes a bug in metrics reporting, because the shared metrics are not tolerant of proxy<->target traffic that lacks a keyID.
1 parent 4759578 commit 95e4aa9

File tree

2 files changed

+85
-24
lines changed

2 files changed

+85
-24
lines changed

service/udp.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,8 @@ func (s *udpService) Serve(clientConn net.PacketConn) error {
173173
var tgtUDPAddr *net.UDPAddr
174174
targetConn := nm.Get(clientAddr.String())
175175
if targetConn == nil {
176-
clientLocation, locErr := s.m.GetLocation(clientAddr)
176+
var locErr error
177+
clientLocation, locErr = s.m.GetLocation(clientAddr)
177178
if locErr != nil {
178179
logger.Warningf("Failed location lookup: %v", locErr)
179180
}
@@ -201,19 +202,23 @@ func (s *udpService) Serve(clientConn net.PacketConn) error {
201202
}
202203
targetConn = nm.Add(clientAddr, clientConn, cipher, udpConn, clientLocation, keyID)
203204
} else {
205+
clientLocation = targetConn.clientLocation
206+
204207
unpackStart := time.Now()
205208
textData, err := ss.Unpack(nil, cipherData, targetConn.cipher)
206209
timeToCipher = time.Now().Sub(unpackStart)
207210
if err != nil {
208211
return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack data from client", err)
209212
}
210213

214+
// The key ID is known with confidence once decryption succeeds.
215+
keyID = targetConn.keyID
216+
211217
var onetErr *onet.ConnectionError
212218
if payload, tgtUDPAddr, onetErr = s.validatePacket(textData); onetErr != nil {
213219
return onetErr
214220
}
215221
}
216-
clientLocation = targetConn.clientLocation
217222

218223
debugUDPAddr(clientAddr, "Proxy exit %v", targetConn.LocalAddr())
219224
proxyTargetBytes, err = targetConn.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature
@@ -271,6 +276,7 @@ func isDNS(addr net.Addr) bool {
271276
type natconn struct {
272277
net.PacketConn
273278
cipher *ss.Cipher
279+
keyID string
274280
// We store the client location in the NAT map to avoid recomputing it
275281
// for every downstream packet in a UDP-based connection.
276282
clientLocation string
@@ -351,10 +357,11 @@ func (m *natmap) Get(key string) *natconn {
351357
return m.keyConn[key]
352358
}
353359

354-
func (m *natmap) set(key string, pc net.PacketConn, cipher *ss.Cipher, clientLocation string) *natconn {
360+
func (m *natmap) set(key string, pc net.PacketConn, cipher *ss.Cipher, keyID, clientLocation string) *natconn {
355361
entry := &natconn{
356362
PacketConn: pc,
357363
cipher: cipher,
364+
keyID: keyID,
358365
clientLocation: clientLocation,
359366
defaultTimeout: m.timeout,
360367
}
@@ -379,7 +386,7 @@ func (m *natmap) del(key string) net.PacketConn {
379386
}
380387

381388
func (m *natmap) Add(clientAddr net.Addr, clientConn net.PacketConn, cipher *ss.Cipher, targetConn net.PacketConn, clientLocation, keyID string) *natconn {
382-
entry := m.set(clientAddr.String(), targetConn, cipher, clientLocation)
389+
entry := m.set(clientAddr.String(), targetConn, cipher, keyID, clientLocation)
383390

384391
m.metrics.AddUDPNatEntry()
385392
m.running.Add(1)

service/udp_test.go

Lines changed: 74 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,16 @@ func (conn *fakePacketConn) Close() error {
8989
return nil
9090
}
9191

92+
type udpReport struct {
93+
clientLocation, accessKey, status string
94+
clientProxyBytes, proxyTargetBytes int
95+
}
96+
9297
// Stub metrics implementation for testing NAT behaviors.
9398
type natTestMetrics struct {
9499
metrics.ShadowsocksMetrics
95100
natEntriesAdded int
101+
upstreamPackets []udpReport
96102
}
97103

98104
func (m *natTestMetrics) AddTCPProbe(clientLocation, status, drainResult string, port int, data metrics.ProxyMetrics) {
@@ -107,6 +113,7 @@ func (m *natTestMetrics) SetNumAccessKeys(numKeys int, numPorts int) {
107113
func (m *natTestMetrics) AddOpenTCPConnection(clientLocation string) {
108114
}
109115
func (m *natTestMetrics) AddUDPPacketFromClient(clientLocation, accessKey, status string, clientProxyBytes, proxyTargetBytes int, timeToCipher time.Duration) {
116+
m.upstreamPackets = append(m.upstreamPackets, udpReport{clientLocation, accessKey, status, clientProxyBytes, proxyTargetBytes})
110117
}
111118
func (m *natTestMetrics) AddUDPPacketFromTarget(clientLocation, accessKey, status string, targetProxyBytes, proxyClientBytes int) {
112119
}
@@ -115,21 +122,20 @@ func (m *natTestMetrics) AddUDPNatEntry() {
115122
}
116123
func (m *natTestMetrics) RemoveUDPNatEntry() {}
117124

118-
func TestIPFilter(t *testing.T) {
119-
// Takes a validation policy, and returns the metrics it
120-
// generates when localhost access is attempted
121-
checkLocalhost := func(validator onet.TargetIPValidator) *natTestMetrics {
122-
ciphers, _ := MakeTestCiphers([]string{"asdf"})
123-
cipher := ciphers.SnapshotForClientIP(nil)[0].Value.(*CipherEntry).Cipher
124-
clientConn := makePacketConn()
125-
metrics := &natTestMetrics{}
126-
service := NewUDPService(timeout, ciphers, metrics)
127-
service.SetTargetIPValidator(validator)
128-
go service.Serve(clientConn)
129-
130-
// Send one packet to the "discard" port on localhost
131-
targetAddr := socks.ParseAddr("127.0.0.1:9")
132-
payload := []byte("payload")
125+
// Takes a validation policy, and returns the metrics it
126+
// generates when localhost access is attempted
127+
func sendToDiscard(payloads [][]byte, validator onet.TargetIPValidator) *natTestMetrics {
128+
ciphers, _ := MakeTestCiphers([]string{"asdf"})
129+
cipher := ciphers.SnapshotForClientIP(nil)[0].Value.(*CipherEntry).Cipher
130+
clientConn := makePacketConn()
131+
metrics := &natTestMetrics{}
132+
service := NewUDPService(timeout, ciphers, metrics)
133+
service.SetTargetIPValidator(validator)
134+
go service.Serve(clientConn)
135+
136+
// Send one packet to the "discard" port on localhost
137+
targetAddr := socks.ParseAddr("127.0.0.1:9")
138+
for _, payload := range payloads {
133139
plaintext := append(targetAddr, payload...)
134140
ciphertext := make([]byte, cipher.SaltSize()+len(plaintext)+cipher.TagSize())
135141
ss.Pack(ciphertext, plaintext, cipher)
@@ -140,26 +146,74 @@ func TestIPFilter(t *testing.T) {
140146
},
141147
payload: ciphertext,
142148
}
143-
144-
service.GracefulStop()
145-
return metrics
146149
}
147150

151+
service.GracefulStop()
152+
return metrics
153+
}
154+
155+
func TestIPFilter(t *testing.T) {
156+
// Test both the first-packet and subsequent-packet cases.
157+
payloads := [][]byte{[]byte("payload1"), []byte("payload2")}
158+
148159
t.Run("Localhost allowed", func(t *testing.T) {
149-
metrics := checkLocalhost(allowAll)
160+
metrics := sendToDiscard(payloads, allowAll)
150161
if metrics.natEntriesAdded != 1 {
151162
t.Errorf("Expected 1 NAT entry, not %d", metrics.natEntriesAdded)
152163
}
153164
})
154165

155166
t.Run("Localhost not allowed", func(t *testing.T) {
156-
metrics := checkLocalhost(onet.RequirePublicIP)
167+
metrics := sendToDiscard(payloads, onet.RequirePublicIP)
157168
if metrics.natEntriesAdded != 0 {
158169
t.Error("Unexpected NAT entry on rejected packet")
159170
}
171+
if len(metrics.upstreamPackets) != 2 {
172+
t.Errorf("Expected 2 reports, not %v", metrics.upstreamPackets)
173+
}
174+
for _, report := range metrics.upstreamPackets {
175+
if report.clientProxyBytes == 0 {
176+
t.Error("Expected nonzero input packet size")
177+
}
178+
if report.proxyTargetBytes != 0 {
179+
t.Error("No bytes should be sent due to a disallowed packet")
180+
}
181+
if report.accessKey != "id-0" {
182+
t.Errorf("Unexpected access key: %s", report.accessKey)
183+
}
184+
}
160185
})
161186
}
162187

188+
func TestUpstreamMetrics(t *testing.T) {
189+
// Test both the first-packet and subsequent-packet cases.
190+
const N = 10
191+
payloads := make([][]byte, 0)
192+
for i := 1; i <= N; i++ {
193+
payloads = append(payloads, make([]byte, i))
194+
}
195+
196+
metrics := sendToDiscard(payloads, allowAll)
197+
198+
if len(metrics.upstreamPackets) != N {
199+
t.Errorf("Expected %d reports, not %v", N, metrics.upstreamPackets)
200+
}
201+
for i, report := range metrics.upstreamPackets {
202+
if report.proxyTargetBytes != i+1 {
203+
t.Errorf("Expected %d payload bytes, not %d", i, report.proxyTargetBytes)
204+
}
205+
if report.clientProxyBytes <= report.proxyTargetBytes {
206+
t.Errorf("Expected nonzero input overhead (%d > %d)", report.clientProxyBytes, report.proxyTargetBytes)
207+
}
208+
if report.accessKey != "id-0" {
209+
t.Errorf("Unexpected access key name: %s", report.accessKey)
210+
}
211+
if report.status != "OK" {
212+
t.Errorf("Wrong status: %s", report.status)
213+
}
214+
}
215+
}
216+
163217
func assertAlmostEqual(t *testing.T, a, b time.Time) {
164218
delta := a.Sub(b)
165219
limit := 100 * time.Millisecond

0 commit comments

Comments
 (0)