diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 3a04af0b..23bd143c 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -238,7 +238,7 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) { return err } slog.Info("UDP service started.", "address", pc.LocalAddr().String()) - go ssService.HandlePacket(pc) + go service.PacketServe(pc, ssService.HandlePacket) } for _, serviceConfig := range config.Services { @@ -271,7 +271,7 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) { return err } slog.Info("UDP service started.", "address", pc.LocalAddr().String()) - go ssService.HandlePacket(pc) + go service.PacketServe(pc, ssService.HandlePacket) } } totalCipherCount += len(serviceConfig.Keys) diff --git a/internal/integration_test/integration_test.go b/internal/integration_test/integration_test.go index 0994b90f..ca359db7 100644 --- a/internal/integration_test/integration_test.go +++ b/internal/integration_test/integration_test.go @@ -315,7 +315,7 @@ func TestUDPEcho(t *testing.T) { proxy.SetTargetIPValidator(allowAll) done := make(chan struct{}) go func() { - proxy.Handle(proxyConn) + service.PacketServe(proxyConn, proxy.Handle) done <- struct{}{} }() @@ -548,7 +548,7 @@ func BenchmarkUDPEcho(b *testing.B) { proxy.SetTargetIPValidator(allowAll) done := make(chan struct{}) go func() { - proxy.Handle(server) + service.PacketServe(server, proxy.Handle) done <- struct{}{} }() @@ -592,7 +592,7 @@ func BenchmarkUDPManyKeys(b *testing.B) { proxy.SetTargetIPValidator(allowAll) done := make(chan struct{}) go func() { - proxy.Handle(proxyConn) + service.PacketServe(proxyConn, proxy.Handle) done <- struct{}{} }() diff --git a/service/shadowsocks.go b/service/shadowsocks.go index 636fa94e..7b8221d3 100644 --- a/service/shadowsocks.go +++ b/service/shadowsocks.go @@ -44,7 +44,7 @@ type ServiceMetrics interface { type Service interface { HandleStream(ctx context.Context, conn transport.StreamConn) - HandlePacket(conn net.PacketConn) + HandlePacket(conn net.Conn, pkt []byte) } // Option is a Shadowsocks service constructor option. @@ -137,8 +137,8 @@ func (s *ssService) HandleStream(ctx context.Context, conn transport.StreamConn) } // HandlePacket handles a Shadowsocks packet connection. -func (s *ssService) HandlePacket(conn net.PacketConn) { - s.ph.Handle(conn) +func (s *ssService) HandlePacket(conn net.Conn, pkt []byte) { + s.ph.Handle(conn, pkt) } type ssConnMetrics struct { diff --git a/service/udp.go b/service/udp.go index 5616fb8a..7dd06b47 100644 --- a/service/udp.go +++ b/service/udp.go @@ -43,6 +43,12 @@ type UDPMetrics interface { // Max UDP buffer size for the server code. const serverUDPBufferSize = 64 * 1024 +var bufferPool = sync.Pool{ + New: func() interface{} { + return make([]byte, serverUDPBufferSize) + }, +} + // Wrapper for slog.Debug during UDP proxying. func debugUDP(l *slog.Logger, template string, cipherID string, attr slog.Attr) { // This is an optimization to reduce unnecessary allocations due to an interaction @@ -83,7 +89,7 @@ type packetHandler struct { logger *slog.Logger natTimeout time.Duration ciphers CipherList - m UDPMetrics + nm *natmap ssm ShadowsocksConnMetrics targetIPValidator onet.TargetIPValidator } @@ -96,11 +102,11 @@ func NewPacketHandler(natTimeout time.Duration, cipherList CipherList, m UDPMetr if ssMetrics == nil { ssMetrics = &NoOpShadowsocksConnMetrics{} } + nm := newNATmap(natTimeout, m, noopLogger()) return &packetHandler{ logger: noopLogger(), - natTimeout: natTimeout, ciphers: cipherList, - m: m, + nm: nm, ssm: ssMetrics, targetIPValidator: onet.RequirePublicIP, } @@ -108,12 +114,11 @@ func NewPacketHandler(natTimeout time.Duration, cipherList CipherList, m UDPMetr // PacketHandler is a running UDP shadowsocks proxy that can be stopped. type PacketHandler interface { + Handle(conn net.Conn, pkt []byte) // SetLogger sets the logger used to log messages. Uses a no-op logger if nil. SetLogger(l *slog.Logger) // SetTargetIPValidator sets the function to be used to validate the target IP addresses. SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) - // Handle returns after clientConn closes and all the sub goroutines return. - Handle(clientConn net.PacketConn) } func (h *packetHandler) SetLogger(l *slog.Logger) { @@ -127,101 +132,124 @@ func (h *packetHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPVali h.targetIPValidator = targetIPValidator } -// Listen on addr for encrypted packets and basically do UDP NAT. -// We take the ciphers as a pointer because it gets replaced on config updates. -func (h *packetHandler) Handle(clientConn net.PacketConn) { - nm := newNATmap(h.natTimeout, h.m, h.logger) - defer nm.Close() - cipherBuf := make([]byte, serverUDPBufferSize) - textBuf := make([]byte, serverUDPBufferSize) +type PacketHandleFunc func(conn net.Conn, pkt []byte) + +// PacketServe listens for packets and calls `handle` to handle them until the connection +// returns [ErrClosed]. +func PacketServe(clientConn net.PacketConn, handle PacketHandleFunc) { + buffer := bufferPool.Get().([]byte) + defer bufferPool.Put(buffer) for { - clientProxyBytes, clientAddr, err := clientConn.ReadFrom(cipherBuf) - if errors.Is(err, net.ErrClosed) { - break + n, addr, err := clientConn.ReadFrom(buffer) + if err != nil { + if errors.Is(err, net.ErrClosed) { + break + } + slog.Warn("Failed to read from client. Continuing to listen.", "err", err) + continue } + pkt := buffer[:n] - keyID := "" - var proxyTargetBytes int - var targetConn *natconn - - connError := func() (connError *onet.ConnectionError) { + func() { defer func() { if r := recover(); r != nil { - slog.Error("Panic in UDP loop: %v. Continuing to listen.", r) + slog.Error("Panic in UDP loop. Continuing to listen.", "err", r) debug.PrintStack() } }() + handle(&wrappedPacketConn{PacketConn: clientConn, raddr: addr}, pkt) + }() + } +} - // Error from ReadFrom - if err != nil { - return onet.NewConnectionError("ERR_READ", "Failed to read from client", err) - } - defer slog.LogAttrs(nil, slog.LevelDebug, "UDP: Done", slog.String("address", clientAddr.String())) - debugUDPAddr(h.logger, "Outbound packet.", clientAddr, slog.Int("bytes", clientProxyBytes)) - - cipherData := cipherBuf[:clientProxyBytes] - var payload []byte - var tgtUDPAddr *net.UDPAddr - targetConn = nm.Get(clientAddr.String()) - if targetConn == nil { - ip := clientAddr.(*net.UDPAddr).AddrPort().Addr() - var textData []byte - var cryptoKey *shadowsocks.EncryptionKey - unpackStart := time.Now() - textData, keyID, cryptoKey, err = findAccessKeyUDP(ip, textBuf, cipherData, h.ciphers, h.logger) - timeToCipher := time.Since(unpackStart) - h.ssm.AddCipherSearch(err == nil, timeToCipher) - - if err != nil { - return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack initial packet", err) - } +type wrappedPacketConn struct { + net.PacketConn + raddr net.Addr +} - var onetErr *onet.ConnectionError - if payload, tgtUDPAddr, onetErr = h.validatePacket(textData); onetErr != nil { - return onetErr - } +func (pc *wrappedPacketConn) Read(p []byte) (int, error) { + n, _, err := pc.PacketConn.ReadFrom(p) + return n, err +} - udpConn, err := net.ListenPacket("udp", "") - if err != nil { - return onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create UDP socket", err) - } - targetConn = nm.Add(clientAddr, clientConn, cryptoKey, udpConn, keyID) - } else { - unpackStart := time.Now() - textData, err := shadowsocks.Unpack(nil, cipherData, targetConn.cryptoKey) - timeToCipher := time.Since(unpackStart) - h.ssm.AddCipherSearch(err == nil, timeToCipher) - - if err != nil { - return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack data from client", err) - } +func (pc *wrappedPacketConn) RemoteAddr() net.Addr { + return pc.raddr +} - // The key ID is known with confidence once decryption succeeds. - keyID = targetConn.keyID +func (pc *wrappedPacketConn) Write(b []byte) (n int, err error) { + return pc.PacketConn.WriteTo(b, pc.raddr) +} - var onetErr *onet.ConnectionError - if payload, tgtUDPAddr, onetErr = h.validatePacket(textData); onetErr != nil { - return onetErr - } +func (h *packetHandler) Handle(clientConn net.Conn, pkt []byte) { + buffer := bufferPool.Get().([]byte) + defer bufferPool.Put(buffer) + + var err error + var proxyTargetBytes int + var targetConn *natconn + + connError := func() (connError *onet.ConnectionError) { + defer slog.LogAttrs(nil, slog.LevelDebug, "UDP: Done", slog.String("address", clientConn.RemoteAddr().String())) + debugUDPAddr(h.logger, "Outbound packet.", clientConn.RemoteAddr(), slog.Int("bytes", len(pkt))) + + var payload []byte + var tgtUDPAddr *net.UDPAddr + targetConn = h.nm.Get(clientConn.RemoteAddr().String()) + if targetConn == nil { + ip := clientConn.RemoteAddr().(*net.UDPAddr).AddrPort().Addr() + var textData []byte + var cryptoKey *shadowsocks.EncryptionKey + unpackStart := time.Now() + textData, keyID, cryptoKey, err := findAccessKeyUDP(ip, buffer, pkt, h.ciphers, h.logger) + timeToCipher := time.Since(unpackStart) + h.ssm.AddCipherSearch(err == nil, timeToCipher) + + if err != nil { + return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack initial packet", err) } - debugUDPAddr(h.logger, "Proxy exit.", clientAddr, slog.Any("target", targetConn.LocalAddr())) - proxyTargetBytes, err = targetConn.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature + var onetErr *onet.ConnectionError + if payload, tgtUDPAddr, onetErr = h.validatePacket(textData); onetErr != nil { + return onetErr + } + + udpConn, err := net.ListenPacket("udp", "") if err != nil { - return onet.NewConnectionError("ERR_WRITE", "Failed to write to target", err) + return onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create UDP socket", err) } - return nil - }() + targetConn = h.nm.Add(clientConn, udpConn, cryptoKey, keyID) + } else { + unpackStart := time.Now() + textData, err := shadowsocks.Unpack(nil, pkt, targetConn.cryptoKey) + timeToCipher := time.Since(unpackStart) + h.ssm.AddCipherSearch(err == nil, timeToCipher) - status := "OK" - if connError != nil { - slog.LogAttrs(nil, slog.LevelDebug, "UDP: Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause)) - status = connError.Status + if err != nil { + return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack data from client", err) + } + + var onetErr *onet.ConnectionError + if payload, tgtUDPAddr, onetErr = h.validatePacket(textData); onetErr != nil { + return onetErr + } } - if targetConn != nil { - targetConn.metrics.AddPacketFromClient(status, int64(clientProxyBytes), int64(proxyTargetBytes)) + + debugUDPAddr(h.logger, "Proxy exit.", clientConn.RemoteAddr(), slog.Any("target", targetConn.LocalAddr())) + proxyTargetBytes, err = targetConn.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature + if err != nil { + return onet.NewConnectionError("ERR_WRITE", "Failed to write to target", err) } + return nil + }() + + status := "OK" + if connError != nil { + slog.LogAttrs(nil, slog.LevelDebug, "UDP: Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause)) + status = connError.Status + } + if targetConn != nil { + targetConn.metrics.AddPacketFromClient(status, int64(len(pkt)), int64(proxyTargetBytes)) } } @@ -333,11 +361,10 @@ func (m *natmap) Get(key string) *natconn { return m.keyConn[key] } -func (m *natmap) set(key string, pc net.PacketConn, cryptoKey *shadowsocks.EncryptionKey, keyID string, connMetrics UDPConnMetrics) *natconn { +func (m *natmap) set(key string, pc net.PacketConn, cryptoKey *shadowsocks.EncryptionKey, connMetrics UDPConnMetrics) *natconn { entry := &natconn{ PacketConn: pc, cryptoKey: cryptoKey, - keyID: keyID, metrics: connMetrics, defaultTimeout: m.timeout, } @@ -361,14 +388,14 @@ func (m *natmap) del(key string) net.PacketConn { return nil } -func (m *natmap) Add(clientAddr net.Addr, clientConn net.PacketConn, cryptoKey *shadowsocks.EncryptionKey, targetConn net.PacketConn, keyID string) *natconn { - connMetrics := m.metrics.AddUDPNatEntry(clientAddr, keyID) - entry := m.set(clientAddr.String(), targetConn, cryptoKey, keyID, connMetrics) +func (m *natmap) Add(clientConn net.Conn, targetConn net.PacketConn, cryptoKey *shadowsocks.EncryptionKey, keyID string) *natconn { + connMetrics := m.metrics.AddUDPNatEntry(clientConn.RemoteAddr(), keyID) + entry := m.set(clientConn.RemoteAddr().String(), targetConn, cryptoKey, connMetrics) go func() { - timedCopy(clientAddr, clientConn, entry, keyID, m.logger) + timedCopy(clientConn, entry, m.logger) connMetrics.RemoveNatEntry() - if pc := m.del(clientAddr.String()); pc != nil { + if pc := m.del(clientConn.RemoteAddr().String()); pc != nil { pc.Close() } }() @@ -394,7 +421,7 @@ func (m *natmap) Close() error { var maxAddrLen int = len(socks.ParseAddr("[2001:db8::1]:12345")) // copy from target to client until read timeout -func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natconn, keyID string, l *slog.Logger) { +func timedCopy(clientConn net.Conn, targetConn *natconn, l *slog.Logger) { // pkt is used for in-place encryption of downstream UDP packets, with the layout // [padding?][salt][address][body][tag][extra] // Padding is only used if the address is IPv4. @@ -427,7 +454,7 @@ func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natco return onet.NewConnectionError("ERR_READ", "Failed to read from target", err) } - debugUDPAddr(l, "Got response.", clientAddr, slog.Any("target", raddr)) + debugUDPAddr(l, "Got response.", clientConn.RemoteAddr(), slog.Any("target", raddr)) srcAddr := socks.ParseAddr(raddr.String()) addrStart := bodyStart - len(srcAddr) // `plainTextBuf` concatenates the SOCKS address and body: @@ -448,7 +475,7 @@ func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natco if err != nil { return onet.NewConnectionError("ERR_PACK", "Failed to pack data to client", err) } - proxyClientBytes, err = clientConn.WriteTo(buf, clientAddr) + proxyClientBytes, err = clientConn.Write(buf) if err != nil { return onet.NewConnectionError("ERR_WRITE", "Failed to write to client", err) } diff --git a/service/udp_test.go b/service/udp_test.go index 6f620316..9bf657ce 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -139,7 +139,7 @@ func sendToDiscard(payloads [][]byte, validator onet.TargetIPValidator) *natTest handler.SetTargetIPValidator(validator) done := make(chan struct{}) go func() { - handler.Handle(clientConn) + PacketServe(clientConn, handler.Handle) done <- struct{}{} }() @@ -216,7 +216,7 @@ func setupNAT() (*fakePacketConn, *fakePacketConn, *natconn) { nat := newNATmap(timeout, &natTestMetrics{}, noopLogger()) clientConn := makePacketConn() targetConn := makePacketConn() - nat.Add(&clientAddr, clientConn, natCryptoKey, targetConn, "key id") + nat.Add(&wrappedPacketConn{PacketConn: clientConn, raddr: &clientAddr}, targetConn, natCryptoKey, "key id") entry := nat.Get(clientAddr.String()) return clientConn, targetConn, entry } @@ -481,7 +481,7 @@ func TestUDPEarlyClose(t *testing.T) { } testMetrics := &natTestMetrics{} const testTimeout = 200 * time.Millisecond - s := NewPacketHandler(testTimeout, cipherList, testMetrics, &fakeShadowsocksMetrics{}) + ph := NewPacketHandler(testTimeout, cipherList, testMetrics, &fakeShadowsocksMetrics{}) clientConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) if err != nil { @@ -489,7 +489,7 @@ func TestUDPEarlyClose(t *testing.T) { } require.Nil(t, clientConn.Close()) // This should return quickly without timing out. - s.Handle(clientConn) + PacketServe(clientConn, ph.Handle) } // Makes sure the UDP listener returns [io.ErrClosed] on reads and writes after Close().