diff --git a/caddy/shadowsocks_handler.go b/caddy/shadowsocks_handler.go index 0167895c..adc85d2a 100644 --- a/caddy/shadowsocks_handler.go +++ b/caddy/shadowsocks_handler.go @@ -51,9 +51,8 @@ type ShadowsocksHandler struct { Keys []KeyConfig `json:"keys,omitempty"` streamHandler outline.StreamHandler - packetHandler outline.PacketHandler + associationHandler outline.AssociationHandler metrics outline.ServiceMetrics - tgtListener transport.PacketListener logger *slog.Logger } @@ -106,13 +105,12 @@ func (h *ShadowsocksHandler) Provision(ctx caddy.Context) error { ciphers := outline.NewCipherList() ciphers.Update(cipherList) - h.streamHandler, h.packetHandler = outline.NewShadowsocksHandlers( + h.streamHandler, h.associationHandler = outline.NewShadowsocksHandlers( outline.WithLogger(h.logger), outline.WithCiphers(ciphers), outline.WithMetrics(h.metrics), outline.WithReplayCache(&app.ReplayCache), ) - h.tgtListener = outline.MakeTargetUDPListener(defaultNatTimeout, 0) return nil } @@ -122,11 +120,7 @@ func (h *ShadowsocksHandler) Handle(cx *layer4.Connection, _ layer4.Handler) err case transport.StreamConn: h.streamHandler.HandleStream(cx.Context, conn, h.metrics.AddOpenTCPConnection(conn)) case net.Conn: - assoc, err := outline.NewPacketAssociation(conn, h.tgtListener, h.metrics.AddOpenUDPAssociation(conn)) - if err != nil { - return fmt.Errorf("failed to handle association: %v", err) - } - outline.HandleAssociation(assoc, h.packetHandler.HandlePacket) + h.associationHandler.HandleAssociation(cx.Context, conn, h.metrics.AddOpenUDPAssociation(conn)) default: return fmt.Errorf("failed to handle unknown connection type: %t", conn) } diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index eed8f876..e8203d42 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -225,10 +225,11 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) { ciphers := service.NewCipherList() ciphers.Update(cipherList) - streamHandler, packetHandler := service.NewShadowsocksHandlers( + streamHandler, associationHandler := service.NewShadowsocksHandlers( service.WithCiphers(ciphers), service.WithMetrics(s.serviceMetrics), service.WithReplayCache(&s.replayCache), + service.WithPacketListener(service.MakeTargetUDPListener(s.natTimeout, 0)), service.WithLogger(slog.Default()), ) ln, err := lnSet.ListenStream(addr) @@ -245,15 +246,9 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) { return err } slog.Info("UDP service started.", "address", pc.LocalAddr().String()) - tgtListener := service.MakeTargetUDPListener(s.natTimeout, 0) - go service.PacketServe(pc, func(conn net.Conn) (service.PacketAssociation, error) { - m := s.serviceMetrics.AddOpenUDPAssociation(conn) - assoc, err := service.NewPacketAssociation(conn, tgtListener, m) - if err != nil { - return nil, fmt.Errorf("failed to handle association: %v", err) - } - return assoc, nil - }, packetHandler.HandlePacket, s.serverMetrics) + go service.PacketServe(pc, func(ctx context.Context, conn net.Conn) { + associationHandler.HandleAssociation(ctx, conn, s.serviceMetrics.AddOpenUDPAssociation(conn)) + }, s.serverMetrics) } for _, serviceConfig := range config.Services { @@ -261,11 +256,12 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) { if err != nil { return fmt.Errorf("failed to create cipher list from config: %v", err) } - streamHandler, packetHandler := service.NewShadowsocksHandlers( + streamHandler, associationHandler := service.NewShadowsocksHandlers( service.WithCiphers(ciphers), service.WithMetrics(s.serviceMetrics), service.WithReplayCache(&s.replayCache), service.WithStreamDialer(service.MakeValidatingTCPStreamDialer(onet.RequirePublicIP, serviceConfig.Dialer.Fwmark)), + service.WithPacketListener(service.MakeTargetUDPListener(s.natTimeout, serviceConfig.Dialer.Fwmark)), service.WithLogger(slog.Default()), ) if err != nil { @@ -298,15 +294,9 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) { } return serviceConfig.Dialer.Fwmark }()) - tgtListener := service.MakeTargetUDPListener(s.natTimeout, serviceConfig.Dialer.Fwmark) - go service.PacketServe(pc, func(conn net.Conn) (service.PacketAssociation, error) { - m := s.serviceMetrics.AddOpenUDPAssociation(conn) - assoc, err := service.NewPacketAssociation(conn, tgtListener, m) - if err != nil { - return nil, fmt.Errorf("failed to handle association: %v", err) - } - return assoc, nil - }, packetHandler.HandlePacket, s.serverMetrics) + go service.PacketServe(pc, func(ctx context.Context, conn net.Conn) { + associationHandler.HandleAssociation(ctx, conn, s.serviceMetrics.AddOpenUDPAssociation(conn)) + }, s.serverMetrics) } } totalCipherCount += len(serviceConfig.Keys) diff --git a/internal/integration_test/integration_test.go b/internal/integration_test/integration_test.go index 82fb370f..10ac5018 100644 --- a/internal/integration_test/integration_test.go +++ b/internal/integration_test/integration_test.go @@ -317,15 +317,14 @@ func TestUDPEcho(t *testing.T) { if err != nil { t.Fatal(err) } - proxy := service.NewPacketHandler(cipherList, &fakeShadowsocksMetrics{}) + proxy := service.NewAssociationHandler(cipherList, &fakeShadowsocksMetrics{}) proxy.SetTargetIPValidator(allowAll) natMetrics := &natTestMetrics{} associationMetrics := &fakeUDPAssociationMetrics{} - go service.PacketServe(proxyConn, func(conn net.Conn) (service.PacketAssociation, error) { - assoc, _ := service.NewPacketAssociation(conn, &transport.UDPListener{Address: ""}, associationMetrics) - return assoc, nil - }, proxy.Handle, natMetrics) + go service.PacketServe(proxyConn, func(ctx context.Context, conn net.Conn) { + proxy.HandleAssociation(ctx, conn, associationMetrics) + }, natMetrics) cryptoKey, err := shadowsocks.NewEncryptionKey(shadowsocks.CHACHA20IETFPOLY1305, secrets[0]) require.NoError(t, err) @@ -546,14 +545,13 @@ func BenchmarkUDPEcho(b *testing.B) { if err != nil { b.Fatal(err) } - proxy := service.NewPacketHandler(cipherList, &fakeShadowsocksMetrics{}) + proxy := service.NewAssociationHandler(cipherList, &fakeShadowsocksMetrics{}) proxy.SetTargetIPValidator(allowAll) done := make(chan struct{}) go func() { - service.PacketServe(server, func(conn net.Conn) (service.PacketAssociation, error) { - assoc, _ := service.NewPacketAssociation(conn, &transport.UDPListener{Address: ""}, nil) - return assoc, nil - }, proxy.Handle, &natTestMetrics{}) + service.PacketServe(server, func(ctx context.Context, conn net.Conn) { + proxy.HandleAssociation(ctx, conn, &fakeUDPAssociationMetrics{}) + }, &natTestMetrics{}) done <- struct{}{} }() @@ -593,14 +591,13 @@ func BenchmarkUDPManyKeys(b *testing.B) { if err != nil { b.Fatal(err) } - proxy := service.NewPacketHandler(cipherList, &fakeShadowsocksMetrics{}) + proxy := service.NewAssociationHandler(cipherList, &fakeShadowsocksMetrics{}) proxy.SetTargetIPValidator(allowAll) done := make(chan struct{}) go func() { - service.PacketServe(proxyConn, func(conn net.Conn) (service.PacketAssociation, error) { - assoc, _ := service.NewPacketAssociation(conn, &transport.UDPListener{Address: ""}, nil) - return assoc, nil - }, proxy.Handle, &natTestMetrics{}) + service.PacketServe(proxyConn, func(ctx context.Context, conn net.Conn) { + proxy.HandleAssociation(ctx, conn, &fakeUDPAssociationMetrics{}) + }, &natTestMetrics{}) done <- struct{}{} }() diff --git a/service/shadowsocks.go b/service/shadowsocks.go index dd4ac517..e63ba835 100644 --- a/service/shadowsocks.go +++ b/service/shadowsocks.go @@ -24,10 +24,8 @@ import ( onet "github.com/Jigsaw-Code/outline-ss-server/net" ) -const ( - // 59 seconds is most common timeout for servers that do not respond to invalid requests - tcpReadTimeout time.Duration = 59 * time.Second -) +// 59 seconds is most common timeout for servers that do not respond to invalid requests +const tcpReadTimeout time.Duration = 59 * time.Second // ShadowsocksConnMetrics is used to report Shadowsocks related metrics on connections. type ShadowsocksConnMetrics interface { @@ -51,11 +49,12 @@ type ssService struct { targetIPValidator onet.TargetIPValidator replayCache *ReplayCache - streamDialer transport.StreamDialer + streamDialer transport.StreamDialer + packetListener transport.PacketListener } // NewShadowsocksHandlers creates new Shadowsocks stream and packet handlers. -func NewShadowsocksHandlers(opts ...Option) (StreamHandler, PacketHandler) { +func NewShadowsocksHandlers(opts ...Option) (StreamHandler, AssociationHandler) { s := &ssService{ logger: noopLogger(), } @@ -74,10 +73,13 @@ func NewShadowsocksHandlers(opts ...Option) (StreamHandler, PacketHandler) { } sh.SetLogger(s.logger) - ph := NewPacketHandler(s.ciphers, &ssConnMetrics{s.metrics.AddUDPCipherSearch}) - ph.SetLogger(s.logger) + ah := NewAssociationHandler(s.ciphers, &ssConnMetrics{s.metrics.AddUDPCipherSearch}) + if s.packetListener != nil { + ah.SetTargetPacketListener(s.packetListener) + } + ah.SetLogger(s.logger) - return sh, ph + return sh, ah } // WithLogger can be used to provide a custom log target. If not provided, @@ -115,6 +117,13 @@ func WithStreamDialer(dialer transport.StreamDialer) Option { } } +// WithPacketListener option function. +func WithPacketListener(listener transport.PacketListener) Option { + return func(s *ssService) { + s.packetListener = listener + } +} + type ssConnMetrics struct { metricFunc func(accessKeyFound bool, timeToCipher time.Duration) } diff --git a/service/tcp_test.go b/service/tcp_test.go index 946ed1d4..9f5ecb60 100644 --- a/service/tcp_test.go +++ b/service/tcp_test.go @@ -368,7 +368,7 @@ func TestProbeClientBytesBasicTruncated(t *testing.T) { testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}, nil) handler := NewStreamHandler(authFunc, 200*time.Millisecond) - handler.SetTargetDialerStream(MakeValidatingTCPStreamDialer(allowAll, 0)) + handler.SetTargetDialer(MakeValidatingTCPStreamDialer(allowAll, 0)) done := make(chan struct{}) go func() { StreamServe( @@ -406,7 +406,7 @@ func TestProbeClientBytesBasicModified(t *testing.T) { testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}, nil) handler := NewStreamHandler(authFunc, 200*time.Millisecond) - handler.SetTargetDialerStream(MakeValidatingTCPStreamDialer(allowAll, 0)) + handler.SetTargetDialer(MakeValidatingTCPStreamDialer(allowAll, 0)) done := make(chan struct{}) go func() { StreamServe( @@ -445,7 +445,7 @@ func TestProbeClientBytesCoalescedModified(t *testing.T) { testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}, nil) handler := NewStreamHandler(authFunc, 200*time.Millisecond) - handler.SetTargetDialerStream(MakeValidatingTCPStreamDialer(allowAll, 0)) + handler.SetTargetDialer(MakeValidatingTCPStreamDialer(allowAll, 0)) done := make(chan struct{}) go func() { StreamServe( @@ -747,7 +747,7 @@ func TestStreamServeEarlyClose(t *testing.T) { err = tcpListener.Close() require.NoError(t, err) // This should return quickly, without timing out or calling the handler. - StreamServeStream(WrapStreamAcceptFunc(tcpListener.AcceptTCP), nil) + StreamServe(WrapStreamAcceptFunc(tcpListener.AcceptTCP), nil) } // Makes sure the TCP listener returns [io.ErrClosed] on Close(). diff --git a/service/udp.go b/service/udp.go index 70363d29..03d20bd8 100644 --- a/service/udp.go +++ b/service/udp.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "io" "log/slog" "net" "net/netip" @@ -47,8 +48,13 @@ type UDPAssociationMetrics interface { AddClose() } -// Max UDP buffer size for the server code. -const serverUDPBufferSize = 64 * 1024 +const ( + // Max UDP buffer size for the server code. + serverUDPBufferSize = 64 * 1024 + + // A UDP NAT timeout of at least 5 minutes is recommended in RFC 4787 Section 4.3. + defaultNatTimeout time.Duration = 5 * time.Minute +) // Buffer pool used for reading UDP packets. var readBufPool = slicepool.MakePool(serverUDPBufferSize) @@ -83,135 +89,143 @@ func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherLis return nil, "", nil, errors.New("could not find valid UDP cipher") } -type packetHandler struct { +type associationHandler struct { logger *slog.Logger ciphers CipherList ssm ShadowsocksConnMetrics targetIPValidator onet.TargetIPValidator + targetListener transport.PacketListener } -var _ PacketHandler = (*packetHandler)(nil) +var _ AssociationHandler = (*associationHandler)(nil) -// NewPacketHandler creates a PacketHandler -func NewPacketHandler(cipherList CipherList, ssMetrics ShadowsocksConnMetrics) PacketHandler { +// NewAssociationHandler creates a AssociationHandler +func NewAssociationHandler(cipherList CipherList, ssMetrics ShadowsocksConnMetrics) AssociationHandler { if ssMetrics == nil { ssMetrics = &NoOpShadowsocksConnMetrics{} } - return &packetHandler{ + return &associationHandler{ logger: noopLogger(), ciphers: cipherList, ssm: ssMetrics, targetIPValidator: onet.RequirePublicIP, + targetListener: MakeTargetUDPListener(defaultNatTimeout, 0), } } -// PacketHandler is a handler that handles UDP assocations. -type PacketHandler interface { - HandlePacket(pkt []byte, assoc PacketAssociation, lazySlice slicepool.LazySlice) +// AssociationHandler is a handler that handles UDP assocations. +type AssociationHandler interface { + HandleAssociation(ctx context.Context, conn net.Conn, assocMetrics UDPAssociationMetrics) // SetLogger sets the logger used to log messages. Uses a no-op logger if nil. SetLogger(l *slog.Logger) // SetTargetIPValidator sets the function to be used to validate the target IP addresses. SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) + // SetTargetPacketListener sets the packet listener to use for target connections. + SetTargetPacketListener(targetListener transport.PacketListener) } -func (h *packetHandler) SetLogger(l *slog.Logger) { +func (h *associationHandler) SetLogger(l *slog.Logger) { if l == nil { l = noopLogger() } h.logger = l } -func (h *packetHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) { +func (h *associationHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) { h.targetIPValidator = targetIPValidator } -func (h *packetHandler) authenticate(pkt []byte, assoc PacketAssociation) ([]byte, error) { - var textData []byte - keyResult, err := assoc.DoOnce(func() (any, error) { - var ( - keyID string - key *shadowsocks.EncryptionKey - keyErr error - ) - ip := assoc.ClientAddr().AddrPort().Addr() - textLazySlice := readBufPool.LazySlice() - textBuf := textLazySlice.Acquire() - unpackStart := time.Now() - textData, keyID, key, keyErr = findAccessKeyUDP(ip, textBuf, pkt, h.ciphers, h.logger) - timeToCipher := time.Since(unpackStart) - textLazySlice.Release() - h.ssm.AddCipherSearch(keyErr == nil, timeToCipher) - - assoc.AddAuthentication(keyID) - if keyErr != nil { - return nil, keyErr +func (h *associationHandler) SetTargetPacketListener(targetListener transport.PacketListener) { + h.targetListener = targetListener +} + +func (h *associationHandler) HandleAssociation(ctx context.Context, clientConn net.Conn, assocMetrics UDPAssociationMetrics) { + l := h.logger.With(slog.Any("client", clientConn.RemoteAddr())) + + var targetConn net.PacketConn + defer func() { + debugUDP(l, "Done") + if targetConn != nil { + targetConn.Close() } - go relayTargetToClient(assoc, func(pkt []byte, assoc PacketAssociation) error { - return h.handlePacketFromTarget(pkt, assoc, key) - }) - return key, nil - }) - if err != nil { - return nil, err - } - cryptoKey, ok := keyResult.(*shadowsocks.EncryptionKey) - if !ok { - // This should never happen in practice. We return a `shadowsocks.EncrypTionKey` - // in the `authenticate` anonymous function above. - return nil, errors.New("authentication result is not an encryption key") - } + assocMetrics.AddClose() + }() - if textData == nil { - // This is a subsequent packet. First packets are already decrypted as part of the - // initial access key search. - unpackStart := time.Now() - textData, err = shadowsocks.Unpack(nil, pkt, cryptoKey) - timeToCipher := time.Since(unpackStart) - h.ssm.AddCipherSearch(err == nil, timeToCipher) - } + var cryptoKey *shadowsocks.EncryptionKey - return textData, nil -} + readBufLazySlice := readBufPool.LazySlice() + readBuf := readBufLazySlice.Acquire() + defer readBufLazySlice.Release() + for { + n, err := clientConn.Read(readBuf) + if errors.Is(err, net.ErrClosed) { + break + } + pkt := readBuf[:n] + debugUDP(l, "Outbound packet.", slog.Int("bytes", n)) + + var proxyTargetBytes int + connError := func() *onet.ConnectionError { + var textData []byte + if targetConn == nil { + ip := clientConn.RemoteAddr().(*net.UDPAddr).AddrPort().Addr() + textLazySlice := readBufPool.LazySlice() + unpackStart := time.Now() + var keyID string + textData, keyID, cryptoKey, err = findAccessKeyUDP(ip, textLazySlice.Acquire(), pkt, h.ciphers, h.logger) + timeToCipher := time.Since(unpackStart) + textLazySlice.Release() + h.ssm.AddCipherSearch(err == nil, timeToCipher) -func (h *packetHandler) HandlePacket(pkt []byte, assoc PacketAssociation, lazySlice slicepool.LazySlice) { - l := h.logger.With(slog.Any("association", assoc)) - defer debugUDP(l, "Done") + if err != nil { + return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack initial packet", err) + } + assocMetrics.AddAuthentication(keyID) - debugUDP(l, "Outbound packet.", slog.Int("bytes", len(pkt))) + // Create the target connection. + targetConn, err = h.targetListener.ListenPacket(ctx) + if err != nil { + return onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create a `PacketConn`", err) + } + l = l.With(slog.Any("ltarget", targetConn.LocalAddr())) + go relayTargetToClient(targetConn, clientConn, cryptoKey, assocMetrics, l) + } else { + unpackStart := time.Now() + textData, err = shadowsocks.Unpack(nil, pkt, cryptoKey) + timeToCipher := time.Since(unpackStart) + h.ssm.AddCipherSearch(err == nil, timeToCipher) - var proxyTargetBytes int - connError := func() *onet.ConnectionError { - textData, err := h.authenticate(pkt, assoc) - lazySlice.Release() - if err != nil { - return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack data from client", err) - } + if err != nil { + return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack data from client", err) + } + } - payload, tgtUDPAddr, onetErr := h.validatePacket(textData) - if onetErr != nil { - return onetErr - } + payload, tgtUDPAddr, onetErr := h.validatePacket(textData) + if onetErr != nil { + return onetErr + } - debugUDP(l, "Proxy exit.") - proxyTargetBytes, err = assoc.WriteToTarget(payload, tgtUDPAddr) // accept only UDPAddr despite the signature - if err != nil { - return onet.NewConnectionError("ERR_WRITE", "Failed to write to target", err) - } - return nil - }() + debugUDP(l, "Proxy exit.") + proxyTargetBytes, err = targetConn.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature + if err != nil { + return onet.NewConnectionError("ERR_WRITE", "Failed to write to target", err) + } + return nil + }() - status := "OK" - if connError != nil { - debugUDP(l, "Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause)) - status = connError.Status + status := "OK" + if connError != nil { + debugUDP(l, "Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause)) + status = connError.Status + } + assocMetrics.AddPacketFromClient(status, int64(len(pkt)), int64(proxyTargetBytes)) } - assoc.AddPacketFromClient(status, int64(len(pkt)), int64(proxyTargetBytes)) } // Given the decrypted contents of a UDP packet, return // the payload and the destination address, or an error if // this packet cannot or should not be forwarded. -func (h *packetHandler) validatePacket(textData []byte) ([]byte, *net.UDPAddr, *onet.ConnectionError) { +func (h *associationHandler) validatePacket(textData []byte) ([]byte, *net.UDPAddr, *onet.ConnectionError) { tgtAddr := socks.SplitAddr(textData) if tgtAddr == nil { return nil, nil, onet.NewConnectionError("ERR_READ_ADDRESS", "Failed to get target address", nil) @@ -229,93 +243,23 @@ func (h *packetHandler) validatePacket(textData []byte) ([]byte, *net.UDPAddr, * return payload, tgtUDPAddr, nil } -// Get the maximum length of the shadowsocks address header by parsing -// and serializing an IPv6 address from the example range. -var maxAddrLen int = len(socks.ParseAddr("[2001:db8::1]:12345")) - -func (h *packetHandler) handlePacketFromTarget(pkt []byte, assoc PacketAssociation, cryptoKey *shadowsocks.EncryptionKey) error { - l := h.logger.With(slog.Any("association", assoc)) - - expired := false - var bodyLen, proxyClientBytes int - connError := func() *onet.ConnectionError { - saltSize := cryptoKey.SaltSize() - // Leave enough room at the beginning of the packet for a max-length header (i.e. IPv6). - bodyStart := saltSize + maxAddrLen - - var ( - raddr net.Addr - err error - ) - // `readBuf` receives the plaintext body in `pkt`: - // [padding?][salt][address][body][tag][unused] - // |-- bodyStart --|[ readBuf ] - readBuf := pkt[bodyStart:] - bodyLen, raddr, err = assoc.ReadFromTarget(readBuf) - if err != nil { - if netErr, ok := err.(net.Error); ok { - if netErr.Timeout() { - expired = true - return nil - } - } - return onet.NewConnectionError("ERR_READ", "Failed to read from target", err) - } - - debugUDP(l, "Got response.", slog.Any("rtarget", raddr)) - srcAddr := socks.ParseAddr(raddr.String()) - addrStart := bodyStart - len(srcAddr) - // `plainTextBuf` concatenates the SOCKS address and body: - // [padding?][salt][address][body][tag][unused] - // |-- addrStart -|[plaintextBuf ] - plaintextBuf := pkt[addrStart : bodyStart+bodyLen] - copy(plaintextBuf, srcAddr) - - // saltStart is 0 if raddr is IPv6. - saltStart := addrStart - saltSize - // `packBuf` adds space for the salt and tag. - // `buf` shows the space that was used. - // [padding?][salt][address][body][tag][unused] - // [ packBuf ] - // [ buf ] - packBuf := pkt[saltStart:] - buf, err := shadowsocks.Pack(packBuf, plaintextBuf, cryptoKey) // Encrypt in-place - if err != nil { - return onet.NewConnectionError("ERR_PACK", "Failed to pack data to client", err) - } - proxyClientBytes, err = assoc.WriteToClient(buf) - if err != nil { - return onet.NewConnectionError("ERR_WRITE", "Failed to write to client", err) - } - return nil - }() - status := "OK" - if connError != nil { - debugUDP(l, "Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause)) - status = connError.Status - } - if expired { - return errors.New("target connection has expired") - } - assoc.AddPacketFromTarget(status, int64(bodyLen), int64(proxyClientBytes)) - return nil -} - -type NewAssociationFunc func(conn net.Conn) (PacketAssociation, error) +type AssociationHandleFunc func(ctx context.Context, conn net.Conn) // PacketServe listens for UDP packets on the provided [net.PacketConn], creates // and manages NAT associations, and invokes the provided `handlePacket` // function for each packet. It uses a NAT map to track active associations and // handles their lifecycle. -func PacketServe(clientConn net.PacketConn, newAssociation NewAssociationFunc, handlePacket PacketHandleFuncWithLazySlice, metrics NATMetrics) { +func PacketServe(clientConn net.PacketConn, assocHandle AssociationHandleFunc, metrics NATMetrics) { nm := newNATmap() - defer nm.Close() + ctx, contextCancel := context.WithCancel(context.Background()) + defer contextCancel() for { lazySlice := readBufPool.LazySlice() buffer := lazySlice.Acquire() - isClosed := func() bool { + expired := false + func() { defer func() { if r := recover(); r != nil { slog.Error("Panic in UDP loop. Continuing to listen.", "err", r) @@ -327,63 +271,104 @@ func PacketServe(clientConn net.PacketConn, newAssociation NewAssociationFunc, h if err != nil { lazySlice.Release() if errors.Is(err, net.ErrClosed) { - return true + expired = true + return } slog.Warn("Failed to read from client. Continuing to listen.", "err", err) - return false + return } - pkt := buffer[:n] + pkt := &packet{payload: buffer[:n], lazySlice: lazySlice} // TODO(#19): Include server address in the NAT key as well. assoc := nm.Get(addr.String()) if assoc == nil { - conn := &natconn{PacketConn: clientConn, raddr: addr} - assoc, err = newAssociation(conn) + assoc = &association{ + pc: clientConn, + raddr: addr, + readCh: make(chan *packet, 5), + } if err != nil { slog.Error("Failed to handle association", slog.Any("err", err)) - return false + return } metrics.AddNATEntry() nm.Add(addr.String(), assoc) + go func() { + assocHandle(ctx, assoc) + metrics.RemoveNATEntry() + nm.Del(addr.String()) + }() } select { - case <-assoc.Done(): - lazySlice.Release() - metrics.RemoveNATEntry() - nm.Del(addr.String()) + case assoc.readCh <- pkt: default: - go handlePacket(pkt, assoc, lazySlice) + slog.Debug("Dropping packet due to full read queue") + // TODO: Add a metric to track number of dropped packets. } - return false }() - if isClosed { - return + if expired { + break } } } -// natconn wraps a [net.PacketConn] with an address into a [net.Conn]. -type natconn struct { - net.PacketConn - raddr net.Addr +type packet struct { + payload []byte + lazySlice slicepool.LazySlice +} + +// association wraps a [net.PacketConn] with an address into a [net.Conn]. +type association struct { + pc net.PacketConn + raddr net.Addr + readCh chan *packet } -var _ net.Conn = (*natconn)(nil) +var _ net.Conn = (*association)(nil) -func (c *natconn) Read(p []byte) (int, error) { - n, _, err := c.PacketConn.ReadFrom(p) - return n, err +func (c *association) Read(p []byte) (int, error) { + pkt, ok := <-c.readCh + if !ok { + return 0, net.ErrClosed + } + n := copy(p, pkt.payload) + pkt.lazySlice.Release() + if n < len(pkt.payload) { + return n, io.ErrShortBuffer + } + return n, nil } -func (c *natconn) Write(b []byte) (n int, err error) { - return c.PacketConn.WriteTo(b, c.raddr) +func (c *association) Write(b []byte) (n int, err error) { + return c.pc.WriteTo(b, c.raddr) } -func (c *natconn) RemoteAddr() net.Addr { +func (c *association) Close() error { + close(c.readCh) + return c.pc.Close() +} + +func (c *association) LocalAddr() net.Addr { + return c.pc.LocalAddr() +} + +func (c *association) RemoteAddr() net.Addr { return c.raddr } +func (c *association) SetDeadline(t time.Time) error { + return c.pc.SetDeadline(t) +} + +func (c *association) SetReadDeadline(t time.Time) error { + return c.pc.SetReadDeadline(t) +} + +func (c *association) SetWriteDeadline(t time.Time) error { + return c.pc.SetWriteDeadline(t) +} + func isDNS(addr net.Addr) bool { _, port, _ := net.SplitHostPort(addr.String()) return port == "53" @@ -449,15 +434,15 @@ func (c *timedPacketConn) ReadFrom(buf []byte) (int, net.Addr, error) { // Packet NAT table type natmap struct { sync.RWMutex - associations map[string]PacketAssociation + associations map[string]*association } func newNATmap() *natmap { - return &natmap{associations: make(map[string]PacketAssociation)} + return &natmap{associations: make(map[string]*association)} } // Get returns a UDP NAT entry from the natmap. -func (m *natmap) Get(clientAddr string) PacketAssociation { +func (m *natmap) Get(clientAddr string) *association { m.RLock() defer m.RUnlock() return m.associations[clientAddr] @@ -474,193 +459,94 @@ func (m *natmap) Del(clientAddr string) { } // Add adds a new UDP NAT entry to the natmap. -func (m *natmap) Add(clientAddr string, assoc PacketAssociation) { +func (m *natmap) Add(clientAddr string, assoc *association) { m.Lock() defer m.Unlock() m.associations[clientAddr] = assoc } -func (m *natmap) Close() error { - m.Lock() - defer m.Unlock() - - var err error - for _, assoc := range m.associations { - if e := assoc.Close(); e != nil { - err = e - } - } - return err -} - -// PacketHandleFunc processes a single incoming packet. -type PacketHandleFunc func(pkt []byte, assoc PacketAssociation) error +// PacketHandleFunc processes a single packet. +type PacketHandleFunc func(pkt []byte) error -// PacketHandleFuncWithLazySlice processes a single incoming packet. -// -// lazySlice is the LazySlice that holds the pkt buffer, which should be -// released as soon as the packet is processed. -type PacketHandleFuncWithLazySlice func(pkt []byte, assoc PacketAssociation, lazySlice slicepool.LazySlice) - -func HandleAssociation(assoc PacketAssociation, handlePacket PacketHandleFuncWithLazySlice) { - for { - lazySlice := readBufPool.LazySlice() - buf := lazySlice.Acquire() - n, err := assoc.ReadFromClient(buf) - if errors.Is(err, net.ErrClosed) { - lazySlice.Release() - return - } - pkt := buf[:n] - select { - case <-assoc.Done(): - lazySlice.Release() - return - default: - go handlePacket(pkt, assoc, lazySlice) - } - } -} +// Get the maximum length of the shadowsocks address header by parsing +// and serializing an IPv6 address from the example range. +var maxAddrLen int = len(socks.ParseAddr("[2001:db8::1]:12345")) // relayTargetToClient handles the target-side of the association by // copying from target to client until read timeout. -func relayTargetToClient(assoc PacketAssociation, handlePacket PacketHandleFunc) { - defer assoc.CloseTarget() +func relayTargetToClient(targetConn net.PacketConn, clientConn net.Conn, cryptoKey *shadowsocks.EncryptionKey, m UDPAssociationMetrics, l *slog.Logger) { + defer targetConn.Close() // pkt is used for in-place encryption of downstream UDP packets. // Padding is only used if the address is IPv4. pkt := make([]byte, serverUDPBufferSize) - for { - if err := handlePacket(pkt, assoc); err != nil { - break - } - } -} - -// PacketAssociation represents a UDP association. -type PacketAssociation interface { - // TODO(sbruens): Decouple the metrics from the association. - UDPAssociationMetrics - - // ReadFromClient reads data from the client side of the association. - ReadFromClient(b []byte) (n int, err error) - - // WriteToClient writes data to the client side of the association. - WriteToClient(b []byte) (n int, err error) - - // ReadFromTarget reads data from the target side of the association. - ReadFromTarget(p []byte) (n int, addr net.Addr, err error) - - // WriteToTarget writes data to the target side of the association. - WriteToTarget(b []byte, addr net.Addr) (int, error) - - // ClientAddr returns the remote network address of the client connection, if known. - ClientAddr() *net.UDPAddr + saltSize := cryptoKey.SaltSize() + // Leave enough room at the beginning of the packet for a max-length header (i.e. IPv6). + bodyStart := saltSize + maxAddrLen - // DoOnce executes the provided function only once and caches the result. - DoOnce(f func() (any, error)) (any, error) - - // Done returns a channel that is closed when the association is closed. - Done() <-chan struct{} - - // Close closes the association and releases any associated resources. - Close() error - - // Closes the target side of the association. - CloseTarget() error -} - -type association struct { - clientConn net.Conn - targetConn net.PacketConn - - once sync.Once - cachedResult any - - UDPAssociationMetrics - doneCh chan struct{} -} - -var _ PacketAssociation = (*association)(nil) -var _ UDPAssociationMetrics = (*association)(nil) -var _ slog.LogValuer = (*association)(nil) - -// NewPacketAssociation creates a new packet-based association. -func NewPacketAssociation(conn net.Conn, listener transport.PacketListener, m UDPAssociationMetrics) (PacketAssociation, error) { - if m == nil { - m = &NoOpUDPAssociationMetrics{} - } - // Create the target connection - targetConn, err := listener.ListenPacket(context.Background()) - if err != nil { - return nil, fmt.Errorf("failed to create target connection: %w", err) - } - - return &association{ - clientConn: conn, - targetConn: targetConn, - UDPAssociationMetrics: m, - doneCh: make(chan struct{}), - }, nil -} - -func (a *association) ReadFromClient(b []byte) (n int, err error) { - return a.clientConn.Read(b) -} - -func (a *association) WriteToClient(b []byte) (n int, err error) { - return a.clientConn.Write(b) -} - -func (a *association) ReadFromTarget(p []byte) (n int, addr net.Addr, err error) { - return a.targetConn.ReadFrom(p) -} - -func (a *association) WriteToTarget(b []byte, addr net.Addr) (int, error) { - return a.targetConn.WriteTo(b, addr) -} - -func (a *association) ClientAddr() *net.UDPAddr { - return a.clientConn.RemoteAddr().(*net.UDPAddr) -} + expired := false + for { + var targetProxyBytes, proxyClientBytes int + connError := func() *onet.ConnectionError { + var ( + raddr net.Addr + err error + ) + // `readBuf` receives the plaintext body in `pkt`: + // [padding?][salt][address][body][tag][unused] + // |-- bodyStart --|[ readBuf ] + readBuf := pkt[bodyStart:] + targetProxyBytes, raddr, err = targetConn.ReadFrom(readBuf) + if err != nil { + if netErr, ok := err.(net.Error); ok { + if netErr.Timeout() { + expired = true + return nil + } + } + return onet.NewConnectionError("ERR_READ", "Failed to read from target", err) + } -func (a *association) DoOnce(f func() (any, error)) (any, error) { - var err error - a.once.Do(func() { - result, err := f() - if err == nil { - a.cachedResult = result + debugUDP(l, "Got response.", slog.Any("rtarget", raddr)) + srcAddr := socks.ParseAddr(raddr.String()) + addrStart := bodyStart - len(srcAddr) + // `plainTextBuf` concatenates the SOCKS address and body: + // [padding?][salt][address][body][tag][unused] + // |-- addrStart -|[plaintextBuf ] + plaintextBuf := pkt[addrStart : bodyStart+targetProxyBytes] + copy(plaintextBuf, srcAddr) + + // saltStart is 0 if raddr is IPv6. + saltStart := addrStart - saltSize + // `packBuf` adds space for the salt and tag. + // `buf` shows the space that was used. + // [padding?][salt][address][body][tag][unused] + // [ packBuf ] + // [ buf ] + packBuf := pkt[saltStart:] + buf, err := shadowsocks.Pack(packBuf, plaintextBuf, cryptoKey) // Encrypt in-place + if err != nil { + return onet.NewConnectionError("ERR_PACK", "Failed to pack data to client", err) + } + proxyClientBytes, err = clientConn.Write(buf) + if err != nil { + return onet.NewConnectionError("ERR_WRITE", "Failed to write to client", err) + } + return nil + }() + status := "OK" + if connError != nil { + debugUDP(l, "Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause)) + status = connError.Status } - }) - return a.cachedResult, err -} -func (a *association) Done() <-chan struct{} { - return a.doneCh -} - -func (a *association) Close() error { - now := time.Now() - return a.clientConn.SetReadDeadline(now) -} - -func (a *association) CloseTarget() error { - a.UDPAssociationMetrics.AddClose() - err := a.targetConn.Close() - if err != nil { - return err + if expired { + break + } + m.AddPacketFromTarget(status, int64(targetProxyBytes), int64(proxyClientBytes)) } - close(a.doneCh) - return nil -} - -func (a *association) LogValue() slog.Value { - return slog.GroupValue( - slog.Any("client", a.clientConn.RemoteAddr()), - slog.Any("ltarget", a.targetConn.LocalAddr()), - ) } // NoOpUDPAssociationMetrics is a [UDPAssociationMetrics] that doesn't do anything. Useful in tests diff --git a/service/udp_test.go b/service/udp_test.go index 51363c35..bf1090f6 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -17,8 +17,8 @@ package service import ( "bytes" "context" - "errors" "fmt" + "io" "net" "net/netip" "sync" @@ -49,7 +49,7 @@ func init() { natCryptoKey, _ = shadowsocks.NewEncryptionKey(shadowsocks.CHACHA20IETFPOLY1305, "test password") } -type packet struct { +type fakePacket struct { addr net.Addr payload []byte err error @@ -65,16 +65,16 @@ func (ln *packetListener) ListenPacket(ctx context.Context) (net.PacketConn, err type fakePacketConn struct { net.PacketConn - send chan packet - recv chan packet + send chan fakePacket + recv chan fakePacket deadline time.Time mu sync.Mutex } func makePacketConn() *fakePacketConn { return &fakePacketConn{ - send: make(chan packet, 1), - recv: make(chan packet), + send: make(chan fakePacket, 1), + recv: make(chan fakePacket), } } @@ -102,7 +102,7 @@ func (conn *fakePacketConn) WriteTo(payload []byte, addr net.Addr) (int, error) } }() - conn.send <- packet{addr, payload, nil} + conn.send <- fakePacket{addr, payload, nil} return len(payload), err } @@ -113,7 +113,7 @@ func (conn *fakePacketConn) ReadFrom(buffer []byte) (int, net.Addr, error) { } n := copy(buffer, pkt.payload) if n < len(pkt.payload) { - return n, pkt.addr, errors.New("buffer was too short") + return n, pkt.addr, io.ErrShortBuffer } return n, pkt.addr, pkt.err } @@ -182,7 +182,7 @@ func sendSSPayload(conn *fakePacketConn, addr net.Addr, cipher *shadowsocks.Encr plaintext := append(socksAddr, payload...) ciphertext := make([]byte, cipher.SaltSize()+len(plaintext)+cipher.TagSize()) shadowsocks.Pack(ciphertext, plaintext, cipher) - conn.recv <- packet{ + conn.recv <- fakePacket{ addr: &clientAddr, payload: ciphertext, } @@ -191,37 +191,38 @@ func sendSSPayload(conn *fakePacketConn, addr net.Addr, cipher *shadowsocks.Encr // startTestHandler creates a new association handler with a fake // client and target connection for testing purposes. It also starts a // PacketServe goroutine to handle incoming packets on the client connection. -func startTestHandler() (PacketHandler, func(target net.Addr, payload []byte), *fakePacketConn) { +func startTestHandler() (AssociationHandler, func(target net.Addr, payload []byte), *fakePacketConn) { ciphers, _ := MakeTestCiphers([]string{"asdf"}) cipher := ciphers.SnapshotForClientIP(netip.Addr{})[0].Value.(*CipherEntry).CryptoKey - handler := NewPacketHandler(ciphers, nil) + handler := NewAssociationHandler(ciphers, nil) clientConn := makePacketConn() targetConn := makePacketConn() - go PacketServe(clientConn, func(conn net.Conn) (PacketAssociation, error) { - assoc, _ := NewPacketAssociation(conn, &packetListener{targetConn}, nil) - return assoc, nil - }, handler.Handle, &natTestMetrics{}) + handler.SetTargetPacketListener(&packetListener{targetConn}) + go PacketServe(clientConn, func(ctx context.Context, conn net.Conn) { + handler.HandleAssociation(ctx, conn, &fakeUDPAssociationMetrics{}) + }, &natTestMetrics{}) return handler, func(target net.Addr, payload []byte) { sendSSPayload(clientConn, target, cipher, payload) }, targetConn } -func TestNatconnCloseWhileReading(t *testing.T) { - nc := &natconn{ - PacketConn: makePacketConn(), - raddr: &clientAddr, +func TestAssociationCloseWhileReading(t *testing.T) { + assoc := &association{ + pc: makePacketConn(), + raddr: &clientAddr, + readCh: make(chan *packet), } go func() { buf := make([]byte, 1024) - nc.Read(buf) + assoc.Read(buf) }() - err := nc.Close() + err := assoc.Close() assert.NoError(t, err, "Close should not panic or return an error") } -func TestPacketHandler_Handle_IPFilter(t *testing.T) { +func TestAssociationHandler_Handle_IPFilter(t *testing.T) { t.Run("RequirePublicIP blocks localhost", func(t *testing.T) { handler, sendPayload, targetConn := startTestHandler() handler.SetTargetIPValidator(onet.RequirePublicIP) @@ -252,14 +253,14 @@ func TestPacketHandler_Handle_IPFilter(t *testing.T) { func TestUpstreamMetrics(t *testing.T) { ciphers, _ := MakeTestCiphers([]string{"asdf"}) cipher := ciphers.SnapshotForClientIP(netip.Addr{})[0].Value.(*CipherEntry).CryptoKey - handler := NewPacketHandler(ciphers, nil) + handler := NewAssociationHandler(ciphers, nil) clientConn := makePacketConn() targetConn := makePacketConn() + handler.SetTargetPacketListener(&packetListener{targetConn}) metrics := &fakeUDPAssociationMetrics{} - go PacketServe(clientConn, func(conn net.Conn) (PacketAssociation, error) { - assoc, _ := NewPacketAssociation(conn, &packetListener{targetConn}, metrics) - return assoc, nil - }, handler.Handle, &natTestMetrics{}) + go PacketServe(clientConn, func(ctx context.Context, conn net.Conn) { + handler.HandleAssociation(ctx, conn, metrics) + }, &natTestMetrics{}) // Test both the first-packet and subsequent-packet cases. const N = 10 @@ -371,13 +372,13 @@ func TestTimedPacketConn(t *testing.T) { t.Run("FastClose", func(t *testing.T) { ciphers, _ := MakeTestCiphers([]string{"asdf"}) cipher := ciphers.SnapshotForClientIP(netip.Addr{})[0].Value.(*CipherEntry).CryptoKey - handler := NewPacketHandler(ciphers, nil) + handler := NewAssociationHandler(ciphers, nil) clientConn := makePacketConn() targetConn := makePacketConn() - go PacketServe(clientConn, func(conn net.Conn) (PacketAssociation, error) { - assoc, _ := NewPacketAssociation(conn, &packetListener{targetConn}, nil) - return assoc, nil - }, handler.Handle, &natTestMetrics{}) + handler.SetTargetPacketListener(&packetListener{targetConn}) + go PacketServe(clientConn, func(ctx context.Context, conn net.Conn) { + handler.HandleAssociation(ctx, conn, &fakeUDPAssociationMetrics{}) + }, &natTestMetrics{}) // Send one DNS query. sendSSPayload(clientConn, &dnsAddr, cipher, []byte{1}) @@ -385,7 +386,7 @@ func TestTimedPacketConn(t *testing.T) { require.Len(t, sent.payload, 1) // Send the response. response := []byte{1, 2, 3, 4, 5} - received := packet{addr: &dnsAddr, payload: response} + received := fakePacket{addr: &dnsAddr, payload: response} targetConn.recv <- received sent, ok := <-clientConn.send if !ok { @@ -399,13 +400,13 @@ func TestTimedPacketConn(t *testing.T) { t.Run("NoFastClose_NotDNS", func(t *testing.T) { ciphers, _ := MakeTestCiphers([]string{"asdf"}) cipher := ciphers.SnapshotForClientIP(netip.Addr{})[0].Value.(*CipherEntry).CryptoKey - handler := NewPacketHandler(ciphers, nil) + handler := NewAssociationHandler(ciphers, nil) clientConn := makePacketConn() targetConn := makePacketConn() - go PacketServe(clientConn, func(conn net.Conn) (PacketAssociation, error) { - assoc, _ := NewPacketAssociation(conn, &packetListener{targetConn}, nil) - return assoc, nil - }, handler.Handle, &natTestMetrics{}) + handler.SetTargetPacketListener(&packetListener{targetConn}) + go PacketServe(clientConn, func(ctx context.Context, conn net.Conn) { + handler.HandleAssociation(ctx, conn, &fakeUDPAssociationMetrics{}) + }, &natTestMetrics{}) // Send one non-DNS packet. sendSSPayload(clientConn, &targetAddr, cipher, []byte{1}) @@ -413,7 +414,7 @@ func TestTimedPacketConn(t *testing.T) { require.Len(t, sent.payload, 1) // Send the response. response := []byte{1, 2, 3, 4, 5} - received := packet{addr: &targetAddr, payload: response} + received := fakePacket{addr: &targetAddr, payload: response} targetConn.recv <- received sent, ok := <-clientConn.send if !ok { @@ -427,13 +428,13 @@ func TestTimedPacketConn(t *testing.T) { t.Run("NoFastClose_MultipleDNS", func(t *testing.T) { ciphers, _ := MakeTestCiphers([]string{"asdf"}) cipher := ciphers.SnapshotForClientIP(netip.Addr{})[0].Value.(*CipherEntry).CryptoKey - handler := NewPacketHandler(ciphers, nil) + handler := NewAssociationHandler(ciphers, nil) clientConn := makePacketConn() targetConn := makePacketConn() - go PacketServe(clientConn, func(conn net.Conn) (PacketAssociation, error) { - assoc, _ := NewPacketAssociation(conn, &packetListener{targetConn}, nil) - return assoc, nil - }, handler.Handle, &natTestMetrics{}) + handler.SetTargetPacketListener(&packetListener{targetConn}) + go PacketServe(clientConn, func(ctx context.Context, conn net.Conn) { + handler.HandleAssociation(ctx, conn, &fakeUDPAssociationMetrics{}) + }, &natTestMetrics{}) // Send two DNS packets. sendSSPayload(clientConn, &dnsAddr, cipher, []byte{1}) @@ -443,7 +444,7 @@ func TestTimedPacketConn(t *testing.T) { // Send a response. response := []byte{1, 2, 3, 4, 5} - received := packet{addr: &dnsAddr, payload: response} + received := fakePacket{addr: &dnsAddr, payload: response} targetConn.recv <- received <-clientConn.send @@ -458,7 +459,7 @@ func TestTimedPacketConn(t *testing.T) { sendPayload(&targetAddr, []byte{1}) <-targetConn.send // Simulate a read timeout. - received := packet{err: &fakeTimeoutError{}} + received := fakePacket{err: &fakeTimeoutError{}} before := time.Now() targetConn.recv <- received // Wait for targetConn to close. @@ -514,20 +515,6 @@ func TestNATMap(t *testing.T) { assert.Nil(t, nm.Get(addr.String()), "Get should return nil after deleting the entry") }) - - t.Run("Close", func(t *testing.T) { - nm := newNATmap() - addr := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} - pc := makePacketConn() - assoc := &association{clientConn: &natconn{PacketConn: pc, raddr: addr}} - nm.Add(addr.String(), assoc) - - err := nm.Close() - assert.NoError(t, err, "Close should not return an error") - - // The underlying connection should be scheduled to close immediately. - assertAlmostEqual(t, pc.deadline, time.Now()) - }) } // Simulates receiving invalid UDP packets on a server with 100 ciphers. @@ -613,7 +600,8 @@ func TestUDPEarlyClose(t *testing.T) { t.Fatal(err) } const testTimeout = 200 * time.Millisecond - ph := NewPacketHandler(cipherList, &fakeShadowsocksMetrics{}) + handler := NewAssociationHandler(cipherList, &fakeShadowsocksMetrics{}) + handler.SetTargetPacketListener(&packetListener{makePacketConn()}) clientConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) if err != nil { @@ -621,10 +609,9 @@ func TestUDPEarlyClose(t *testing.T) { } require.Nil(t, clientConn.Close()) // This should return quickly without timing out. - go PacketServe(clientConn, func(conn net.Conn) (PacketAssociation, error) { - assoc, _ := NewPacketAssociation(conn, &packetListener{makePacketConn()}, nil) - return assoc, nil - }, ph.Handle, &natTestMetrics{}) + go PacketServe(clientConn, func(ctx context.Context, conn net.Conn) { + handler.HandleAssociation(ctx, conn, &fakeUDPAssociationMetrics{}) + }, &natTestMetrics{}) } // Makes sure the UDP listener returns [io.ErrClosed] on reads and writes after Close().