Skip to content

Commit 5707a9c

Browse files
committed
refactor: modularize UDP connection handling
1 parent c4d9214 commit 5707a9c

File tree

2 files changed

+100
-103
lines changed

2 files changed

+100
-103
lines changed

service/udp.go

Lines changed: 89 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -65,28 +65,29 @@ func debugUDPAddr(addr net.Addr, template string, val interface{}) {
6565

6666
// Decrypts src into dst. It tries each cipher until it finds one that authenticates
6767
// correctly. dst and src must not overlap.
68-
func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherList) ([]byte, string, *shadowsocks.EncryptionKey, error) {
68+
func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherList) (*CipherEntry, []byte, error) {
6969
// Try each cipher until we find one that authenticates successfully. This assumes that all ciphers are AEAD.
7070
// We snapshot the list because it may be modified while we use it.
7171
snapshot := cipherList.SnapshotForClientIP(clientIP)
72-
for ci, entry := range snapshot {
73-
id, cryptoKey := entry.Value.(*CipherEntry).ID, entry.Value.(*CipherEntry).CryptoKey
74-
buf, err := shadowsocks.Unpack(dst, src, cryptoKey)
72+
for ci, elt := range snapshot {
73+
entry := elt.Value.(*CipherEntry)
74+
buf, err := shadowsocks.Unpack(dst, src, entry.CryptoKey)
7575
if err != nil {
76-
debugUDP(id, "Failed to unpack: %v", err)
76+
debugUDP(entry.ID, "Failed to unpack: %v", err)
7777
continue
7878
}
79-
debugUDP(id, "Found cipher at index %d", ci)
79+
debugUDP(entry.ID, "Found cipher at index %d", ci)
8080
// Move the active cipher to the front, so that the search is quicker next time.
81-
cipherList.MarkUsedByClientIP(entry, clientIP)
82-
return buf, id, cryptoKey, nil
81+
cipherList.MarkUsedByClientIP(elt, clientIP)
82+
return entry, buf, nil
8383
}
84-
return nil, "", nil, errors.New("could not find valid UDP cipher")
84+
return nil, nil, errors.New("could not find valid UDP cipher")
8585
}
8686

8787
type packetHandler struct {
8888
natTimeout time.Duration
8989
ciphers CipherList
90+
nm *natmap
9091
m UDPMetrics
9192
targetIPValidator onet.TargetIPValidator
9293
}
@@ -113,108 +114,94 @@ func (h *packetHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPVali
113114
func (h *packetHandler) Handle(clientConn net.PacketConn) {
114115
var running sync.WaitGroup
115116

116-
nm := newNATmap(h.natTimeout, h.m, &running)
117-
defer nm.Close()
118-
cipherBuf := make([]byte, serverUDPBufferSize)
119-
textBuf := make([]byte, serverUDPBufferSize)
117+
h.nm = newNATmap(h.natTimeout, h.m, &running)
118+
defer h.nm.Close()
120119

121120
for {
122-
clientProxyBytes, clientAddr, err := clientConn.ReadFrom(cipherBuf)
123-
if errors.Is(err, net.ErrClosed) {
124-
break
121+
status := "OK"
122+
keyID, clientInfo, clientProxyBytes, proxyTargetBytes, connErr := h.handleConnection(clientConn)
123+
if connErr != nil {
124+
if errors.Is(connErr.Cause, net.ErrClosed) {
125+
break
126+
}
127+
logger.Debugf("UDP Error: %v: %v", connErr.Message, connErr.Cause)
128+
status = connErr.Status
125129
}
130+
h.m.AddUDPPacketFromClient(clientInfo, keyID, status, clientProxyBytes, proxyTargetBytes)
131+
}
132+
}
126133

127-
var clientInfo ipinfo.IPInfo
128-
keyID := ""
129-
var proxyTargetBytes int
134+
func (h *packetHandler) authenticate(clientConn net.PacketConn) (*natconn, []byte, int, *onet.ConnectionError) {
135+
cipherBuf := make([]byte, serverUDPBufferSize)
136+
textBuf := make([]byte, serverUDPBufferSize)
137+
clientProxyBytes, clientAddr, err := clientConn.ReadFrom(cipherBuf)
138+
if err != nil {
139+
return nil, nil, 0, onet.NewConnectionError("ERR_READ", "Failed to read from client", err)
140+
}
130141

131-
connError := func() (connError *onet.ConnectionError) {
132-
defer func() {
133-
if r := recover(); r != nil {
134-
logger.Errorf("Panic in UDP loop: %v. Continuing to listen.", r)
135-
debug.PrintStack()
136-
}
137-
}()
142+
if logger.IsEnabledFor(logging.DEBUG) {
143+
defer logger.Debugf("UDP(%v): done", clientAddr)
144+
logger.Debugf("UDP(%v): Outbound packet has %d bytes", clientAddr, clientProxyBytes)
145+
}
138146

139-
// Error from ReadFrom
140-
if err != nil {
141-
return onet.NewConnectionError("ERR_READ", "Failed to read from client", err)
142-
}
143-
if logger.IsEnabledFor(logging.DEBUG) {
144-
defer logger.Debugf("UDP(%v): done", clientAddr)
145-
logger.Debugf("UDP(%v): Outbound packet has %d bytes", clientAddr, clientProxyBytes)
146-
}
147+
targetConn := h.nm.Get(clientAddr.String())
148+
remoteIP := clientAddr.(*net.UDPAddr).AddrPort().Addr()
147149

148-
cipherData := cipherBuf[:clientProxyBytes]
149-
var payload []byte
150-
var tgtUDPAddr *net.UDPAddr
151-
targetConn := nm.Get(clientAddr.String())
152-
if targetConn == nil {
153-
var locErr error
154-
clientInfo, locErr = ipinfo.GetIPInfoFromAddr(h.m, clientAddr)
155-
if locErr != nil {
156-
logger.Warningf("Failed client info lookup: %v", locErr)
157-
}
158-
debugUDPAddr(clientAddr, "Got info \"%#v\"", clientInfo)
159-
160-
ip := clientAddr.(*net.UDPAddr).AddrPort().Addr()
161-
var textData []byte
162-
var cryptoKey *shadowsocks.EncryptionKey
163-
unpackStart := time.Now()
164-
textData, keyID, cryptoKey, err = findAccessKeyUDP(ip, textBuf, cipherData, h.ciphers)
165-
timeToCipher := time.Since(unpackStart)
166-
h.m.AddUDPCipherSearch(err == nil, timeToCipher)
167-
168-
if err != nil {
169-
return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack initial packet", err)
170-
}
150+
unpackStart := time.Now()
151+
cipherEntry, textData, keyErr := findAccessKeyUDP(remoteIP, textBuf, cipherBuf[:clientProxyBytes], h.ciphers)
152+
timeToCipher := time.Since(unpackStart)
153+
h.m.AddUDPCipherSearch(err == nil, timeToCipher)
154+
if keyErr != nil {
155+
return targetConn, nil, 0, onet.NewConnectionError("ERR_CIPHER", "Failed to find a valid cipher", keyErr)
156+
}
171157

172-
var onetErr *onet.ConnectionError
173-
if payload, tgtUDPAddr, onetErr = h.validatePacket(textData); onetErr != nil {
174-
return onetErr
175-
}
158+
if targetConn != nil {
159+
return targetConn, textData, clientProxyBytes, nil
160+
}
176161

177-
udpConn, err := net.ListenPacket("udp", "")
178-
if err != nil {
179-
return onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create UDP socket", err)
180-
}
181-
targetConn = nm.Add(clientAddr, clientConn, cryptoKey, udpConn, clientInfo, keyID)
182-
} else {
183-
clientInfo = targetConn.clientInfo
162+
udpConn, err := net.ListenPacket("udp", "")
163+
if err != nil {
164+
return targetConn, textData, clientProxyBytes, onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create UDP socket", err)
165+
}
184166

185-
unpackStart := time.Now()
186-
textData, err := shadowsocks.Unpack(nil, cipherData, targetConn.cryptoKey)
187-
timeToCipher := time.Since(unpackStart)
188-
h.m.AddUDPCipherSearch(err == nil, timeToCipher)
167+
clientInfo, locErr := ipinfo.GetIPInfoFromAddr(h.m, clientAddr)
168+
if locErr != nil {
169+
logger.Warningf("Failed client info lookup: %v", locErr)
170+
}
171+
debugUDPAddr(clientAddr, "Got info \"%#v\"", clientInfo)
189172

190-
if err != nil {
191-
return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack data from client", err)
192-
}
173+
targetConn = h.nm.Add(clientAddr, clientConn, cipherEntry.CryptoKey, udpConn, clientInfo, cipherEntry.ID)
174+
return targetConn, textData, clientProxyBytes, nil
175+
}
193176

194-
// The key ID is known with confidence once decryption succeeds.
195-
keyID = targetConn.keyID
177+
func (h *packetHandler) handleConnection(clientConn net.PacketConn) (string, ipinfo.IPInfo, int, int, *onet.ConnectionError) {
178+
defer func() {
179+
if r := recover(); r != nil {
180+
logger.Errorf("Panic in UDP loop: %v. Continuing to listen.", r)
181+
debug.PrintStack()
182+
}
183+
}()
196184

197-
var onetErr *onet.ConnectionError
198-
if payload, tgtUDPAddr, onetErr = h.validatePacket(textData); onetErr != nil {
199-
return onetErr
200-
}
201-
}
185+
targetConn, textData, clientProxyBytes, authErr := h.authenticate(clientConn)
186+
if authErr != nil {
187+
var clientInfo ipinfo.IPInfo
188+
if targetConn != nil {
189+
clientInfo = targetConn.clientInfo
190+
}
191+
return "", clientInfo, clientProxyBytes, 0, authErr
192+
}
202193

203-
debugUDPAddr(clientAddr, "Proxy exit %v", targetConn.LocalAddr())
204-
proxyTargetBytes, err = targetConn.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature
205-
if err != nil {
206-
return onet.NewConnectionError("ERR_WRITE", "Failed to write to target", err)
207-
}
208-
return nil
209-
}()
194+
payload, tgtUDPAddr, onetErr := h.validatePacket(textData)
195+
if onetErr != nil {
196+
return targetConn.keyID, targetConn.clientInfo, clientProxyBytes, 0, onetErr
197+
}
210198

211-
status := "OK"
212-
if connError != nil {
213-
logger.Debugf("UDP Error: %v: %v", connError.Message, connError.Cause)
214-
status = connError.Status
215-
}
216-
h.m.AddUDPPacketFromClient(clientInfo, keyID, status, clientProxyBytes, proxyTargetBytes)
199+
debugUDPAddr(targetConn.clientAddr, "Proxy exit %v", targetConn.LocalAddr())
200+
proxyTargetBytes, err := targetConn.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature
201+
if err != nil {
202+
return targetConn.keyID, targetConn.clientInfo, clientProxyBytes, proxyTargetBytes, onet.NewConnectionError("ERR_WRITE", "Failed to write to target", err)
217203
}
204+
return targetConn.keyID, targetConn.clientInfo, clientProxyBytes, proxyTargetBytes, nil
218205
}
219206

220207
// Given the decrypted contents of a UDP packet, return
@@ -245,8 +232,9 @@ func isDNS(addr net.Addr) bool {
245232

246233
type natconn struct {
247234
net.PacketConn
248-
cryptoKey *shadowsocks.EncryptionKey
249-
keyID string
235+
cryptoKey *shadowsocks.EncryptionKey
236+
keyID string
237+
clientAddr net.Addr
250238
// We store the client information in the NAT map to avoid recomputing it
251239
// for every downstream packet in a UDP-based connection.
252240
clientInfo ipinfo.IPInfo
@@ -327,19 +315,20 @@ func (m *natmap) Get(key string) *natconn {
327315
return m.keyConn[key]
328316
}
329317

330-
func (m *natmap) set(key string, pc net.PacketConn, cryptoKey *shadowsocks.EncryptionKey, keyID string, clientInfo ipinfo.IPInfo) *natconn {
318+
func (m *natmap) set(clientAddr net.Addr, pc net.PacketConn, cryptoKey *shadowsocks.EncryptionKey, keyID string, clientInfo ipinfo.IPInfo) *natconn {
331319
entry := &natconn{
332320
PacketConn: pc,
333321
cryptoKey: cryptoKey,
334322
keyID: keyID,
323+
clientAddr: clientAddr,
335324
clientInfo: clientInfo,
336325
defaultTimeout: m.timeout,
337326
}
338327

339328
m.Lock()
340329
defer m.Unlock()
341330

342-
m.keyConn[key] = entry
331+
m.keyConn[clientAddr.String()] = entry
343332
return entry
344333
}
345334

@@ -356,7 +345,7 @@ func (m *natmap) del(key string) net.PacketConn {
356345
}
357346

358347
func (m *natmap) Add(clientAddr net.Addr, clientConn net.PacketConn, cryptoKey *shadowsocks.EncryptionKey, targetConn net.PacketConn, clientInfo ipinfo.IPInfo, keyID string) *natconn {
359-
entry := m.set(clientAddr.String(), targetConn, cryptoKey, keyID, clientInfo)
348+
entry := m.set(clientAddr, targetConn, cryptoKey, keyID, clientInfo)
360349

361350
m.metrics.AddUDPNatEntry(clientAddr, keyID)
362351
m.running.Add(1)

service/udp_test.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,12 +162,20 @@ func TestIPFilter(t *testing.T) {
162162

163163
t.Run("Localhost allowed", func(t *testing.T) {
164164
metrics := sendToDiscard(payloads, allowAll)
165+
165166
assert.Equal(t, metrics.natEntriesAdded, 1, "Expected 1 NAT entry, not %d", metrics.natEntriesAdded)
167+
assert.Equal(t, 2, len(metrics.upstreamPackets), "Expected 2 reports, not %v", metrics.upstreamPackets)
168+
for _, report := range metrics.upstreamPackets {
169+
assert.Greater(t, report.clientProxyBytes, 0, "Expected nonzero input packet size")
170+
assert.Greater(t, report.proxyTargetBytes, 0, "Expected nonzero bytes to be sent for allowed packet")
171+
assert.Equal(t, report.accessKey, "id-0", "Unexpected access key: %s", report.accessKey)
172+
}
166173
})
167174

168175
t.Run("Localhost not allowed", func(t *testing.T) {
169176
metrics := sendToDiscard(payloads, onet.RequirePublicIP)
170-
assert.Equal(t, 0, metrics.natEntriesAdded, "Unexpected NAT entry on rejected packet")
177+
178+
assert.Equal(t, metrics.natEntriesAdded, 1, "Expected 1 NAT entry, not %d", metrics.natEntriesAdded)
171179
assert.Equal(t, 2, len(metrics.upstreamPackets), "Expected 2 reports, not %v", metrics.upstreamPackets)
172180
for _, report := range metrics.upstreamPackets {
173181
assert.Greater(t, report.clientProxyBytes, 0, "Expected nonzero input packet size")
@@ -437,7 +445,7 @@ func BenchmarkUDPUnpackRepeat(b *testing.B) {
437445
cipherNumber := n % numCiphers
438446
ip := ips[cipherNumber]
439447
packet := packets[cipherNumber]
440-
_, _, _, err := findAccessKeyUDP(ip, testBuf, packet, cipherList)
448+
_, _, err := findAccessKeyUDP(ip, testBuf, packet, cipherList)
441449
if err != nil {
442450
b.Error(err)
443451
}
@@ -466,7 +474,7 @@ func BenchmarkUDPUnpackSharedKey(b *testing.B) {
466474
b.ResetTimer()
467475
for n := 0; n < b.N; n++ {
468476
ip := ips[n%numIPs]
469-
_, _, _, err := findAccessKeyUDP(ip, testBuf, packet, cipherList)
477+
_, _, err := findAccessKeyUDP(ip, testBuf, packet, cipherList)
470478
if err != nil {
471479
b.Error(err)
472480
}

0 commit comments

Comments
 (0)