From 59a6bdc1fa6f40ebfcf2942675b39c687a11033c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 26 Nov 2024 11:37:44 +0800 Subject: [PATCH] Update udpant usages --- go.mod | 2 +- go.sum | 4 ++-- stack_gvisor.go | 6 +++++- stack_gvisor_udp.go | 8 +++++++- stack_mixed.go | 34 +++++++++++++++++++--------------- 5 files changed, 34 insertions(+), 20 deletions(-) diff --git a/go.mod b/go.mod index c74a07d..5eccc7f 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a github.com/sagernet/nftables v0.3.0-beta.4 - github.com/sagernet/sing v0.6.0-alpha.18 + github.com/sagernet/sing v0.6.0-alpha.24 go4.org/netipx v0.0.0-20231129151722-fdeea329fbba golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 golang.org/x/net v0.26.0 diff --git a/go.sum b/go.sum index 6fd8e46..75e6c72 100644 --- a/go.sum +++ b/go.sum @@ -22,8 +22,8 @@ github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a h1:ObwtHN2VpqE0ZN github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM= github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNenDW2zaXr8I= github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8= -github.com/sagernet/sing v0.6.0-alpha.18 h1:ih4CurU8KvbhfagYjSqVrE2LR0oBSXSZTNH2sAGPGiM= -github.com/sagernet/sing v0.6.0-alpha.18/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= +github.com/sagernet/sing v0.6.0-alpha.24 h1:qPc9i0mHADIFNYlWMg7fWWZZ0kBxWHEs8npsAG6KqAo= +github.com/sagernet/sing v0.6.0-alpha.24/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= diff --git a/stack_gvisor.go b/stack_gvisor.go index 65bb7bd..ccf9a8b 100644 --- a/stack_gvisor.go +++ b/stack_gvisor.go @@ -34,6 +34,7 @@ type GVisor struct { logger logger.Logger stack *stack.Stack endpoint stack.LinkEndpoint + udpForwarder *UDPForwarder } type GVisorTun interface { @@ -71,9 +72,11 @@ func (t *GVisor) Start() error { return err } ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, NewTCPForwarder(t.ctx, ipStack, t.handler).HandlePacket) - ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket) + udpForwarder := NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout) + ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) t.stack = ipStack t.endpoint = linkEndpoint + t.udpForwarder = udpForwarder return nil } @@ -86,6 +89,7 @@ func (t *GVisor) Close() error { for _, endpoint := range t.stack.CleanupEndpoints() { endpoint.Abort() } + t.udpForwarder.Close() return nil } diff --git a/stack_gvisor_udp.go b/stack_gvisor_udp.go index 3027798..284fedf 100644 --- a/stack_gvisor_udp.go +++ b/stack_gvisor_udp.go @@ -37,10 +37,16 @@ func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, t stack: stack, handler: handler, } - forwarder.udpNat = udpnat.New(handler, forwarder.PreparePacketConnection, timeout, true) + udpNat := udpnat.New(handler, forwarder.PreparePacketConnection, timeout, true) + udpNat.Start() + forwarder.udpNat = udpNat return forwarder } +func (f *UDPForwarder) Close() { + f.udpNat.Close() +} + func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { source := M.SocksaddrFrom(AddrFromAddress(id.RemoteAddress), id.RemotePort) destination := M.SocksaddrFrom(AddrFromAddress(id.LocalAddress), id.LocalPort) diff --git a/stack_mixed.go b/stack_mixed.go index 0cd8819..c2eca96 100644 --- a/stack_mixed.go +++ b/stack_mixed.go @@ -16,8 +16,9 @@ import ( type Mixed struct { *System - stack *stack.Stack - endpoint *channel.Endpoint + stack *stack.Stack + endpoint *channel.Endpoint + udpForwarder *UDPForwarder } func NewMixed( @@ -42,14 +43,29 @@ func (m *Mixed) Start() error { if err != nil { return err } - ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(m.ctx, ipStack, m.handler, m.udpTimeout).HandlePacket) + udpForwarder := NewUDPForwarder(m.ctx, ipStack, m.handler, m.udpTimeout) + ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) m.stack = ipStack m.endpoint = endpoint + m.udpForwarder = udpForwarder go m.tunLoop() go m.packetLoop() return nil } +func (m *Mixed) Close() error { + if m.stack == nil { + return nil + } + m.endpoint.Attach(nil) + m.stack.Close() + for _, endpoint := range m.stack.CleanupEndpoints() { + endpoint.Abort() + } + m.udpNat.Close() + return m.System.Close() +} + func (m *Mixed) tunLoop() { if winTun, isWinTun := m.tun.(WinTun); isWinTun { m.wintunLoop(winTun) @@ -222,15 +238,3 @@ func (m *Mixed) packetLoop() { packet.DecRef() } } - -func (m *Mixed) Close() error { - if m.stack == nil { - return nil - } - m.endpoint.Attach(nil) - m.stack.Close() - for _, endpoint := range m.stack.CleanupEndpoints() { - endpoint.Abort() - } - return m.System.Close() -}