From 0353283ebf75a9f4731f3cc9eb31a97d3374820d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 21 Oct 2024 21:55:07 +0800 Subject: [PATCH] Migrate to udpnat2 / Add PrepareConnection --- go.mod | 4 +- go.sum | 8 ++-- stack.go | 3 +- stack_gvisor.go | 22 ++++++++--- stack_gvisor_udp.go | 75 ++++++++++++++++++++------------------ stack_system.go | 89 ++++++++++++++++++++++++++------------------- stack_system_nat.go | 12 ++++-- tun.go | 2 + 8 files changed, 126 insertions(+), 89 deletions(-) diff --git a/go.mod b/go.mod index 153b0b9..0d3ae2b 100644 --- a/go.mod +++ b/go.mod @@ -5,10 +5,10 @@ go 1.20 require ( github.com/go-ole/go-ole v1.3.0 github.com/sagernet/fswatch v0.1.1 - github.com/sagernet/gvisor v0.0.0-20241019061641-46bad1ee6ecc + github.com/sagernet/gvisor v0.0.0-20241021032506-a4324256e4a3 github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a github.com/sagernet/nftables v0.3.0-beta.4 - github.com/sagernet/sing v0.5.0-rc.4.0.20241020060022-1270938dd44a + github.com/sagernet/sing v0.5.0-rc.4.0.20241021134838-8f165de804ce go4.org/netipx v0.0.0-20231129151722-fdeea329fbba golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 golang.org/x/net v0.26.0 diff --git a/go.sum b/go.sum index 0a710c4..b9ad12a 100644 --- a/go.sum +++ b/go.sum @@ -16,14 +16,14 @@ github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8Ku github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/sagernet/fswatch v0.1.1 h1:YqID+93B7VRfqIH3PArW/XpJv5H4OLEVWDfProGoRQs= github.com/sagernet/fswatch v0.1.1/go.mod h1:nz85laH0mkQqJfaOrqPpkwtU1znMFNVTpT/5oRsVz/o= -github.com/sagernet/gvisor v0.0.0-20241019061641-46bad1ee6ecc h1:IvmeRstYX63O0QpLGJgVOaaM21ZIG0frJi6MT29Irtk= -github.com/sagernet/gvisor v0.0.0-20241019061641-46bad1ee6ecc/go.mod h1:ehZwnT2UpmOWAHFL48XdBhnd4Qu4hN2O3Ji0us3ZHMw= +github.com/sagernet/gvisor v0.0.0-20241021032506-a4324256e4a3 h1:RxEz7LhPNiF/gX/Hg+OXr5lqsM9iVAgmaK1L1vzlDRM= +github.com/sagernet/gvisor v0.0.0-20241021032506-a4324256e4a3/go.mod h1:ehZwnT2UpmOWAHFL48XdBhnd4Qu4hN2O3Ji0us3ZHMw= github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a h1:ObwtHN2VpqE0ZNjr6sGeT00J8uU7JF4cNUdb44/Duis= github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM= github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNenDW2zaXr8I= github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8= -github.com/sagernet/sing v0.5.0-rc.4.0.20241020060022-1270938dd44a h1:6qlFfBvLZT/MhDpUr4cKY6RxYTnaCcFgOrJEnf/0+io= -github.com/sagernet/sing v0.5.0-rc.4.0.20241020060022-1270938dd44a/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= +github.com/sagernet/sing v0.5.0-rc.4.0.20241021134838-8f165de804ce h1:5qVxlM/CSW1pTBiiD2ZOIi2ziE6EXdRlnT4H+enjbEk= +github.com/sagernet/sing v0.5.0-rc.4.0.20241021134838-8f165de804ce/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= diff --git a/stack.go b/stack.go index 88cd9ee..5ba18e4 100644 --- a/stack.go +++ b/stack.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "net" "net/netip" + "time" "github.com/sagernet/sing/common/control" E "github.com/sagernet/sing/common/exceptions" @@ -23,7 +24,7 @@ type StackOptions struct { Tun Tun TunOptions Options EndpointIndependentNat bool - UDPTimeout int64 + UDPTimeout time.Duration Handler Handler Logger logger.Logger ForwarderBindInterface bool diff --git a/stack_gvisor.go b/stack_gvisor.go index 5044729..0ee5b5b 100644 --- a/stack_gvisor.go +++ b/stack_gvisor.go @@ -5,6 +5,7 @@ package tun import ( "context" "net/netip" + "os" "time" "github.com/sagernet/gvisor/pkg/tcpip" @@ -32,7 +33,7 @@ type GVisor struct { ctx context.Context tun GVisorTun endpointIndependentNat bool - udpTimeout int64 + udpTimeout time.Duration broadcastAddr netip.Addr handler Handler logger logger.Logger @@ -85,13 +86,18 @@ func (t *GVisor) Start() error { localAddr: source.TCPAddr(), remoteAddr: destination.TCPAddr(), } + pErr := t.handler.PrepareConnection(source, destination) + if pErr != nil { + r.Complete(gWriteUnreachable(t.stack, r.Packet(), err) == os.ErrInvalid) + return + } go t.handler.NewConnectionEx(t.ctx, conn, source, destination, nil) }) ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) if !t.endpointIndependentNat { - udpForwarder := udp.NewForwarder(ipStack, func(request *udp.ForwarderRequest) { + udpForwarder := udp.NewForwarder(ipStack, func(r *udp.ForwarderRequest) { var wq waiter.Queue - endpoint, err := request.CreateEndpoint(&wq) + endpoint, err := r.CreateEndpoint(&wq) if err != nil { return } @@ -102,9 +108,15 @@ func (t *GVisor) Start() error { endpoint.Abort() return } + source := M.SocksaddrFromNet(lAddr) + destination := M.SocksaddrFromNet(rAddr) + pErr := t.handler.PrepareConnection(source, destination) + if pErr != nil { + gWriteUnreachable(t.stack, r.Packet(), pErr) + r.Packet().DecRef() + return + } go func() { - source := M.SocksaddrFromNet(lAddr) - destination := M.SocksaddrFromNet(rAddr) ctx, conn := canceler.NewPacketConn(t.ctx, bufio.NewUnbindPacketConnWithAddr(udpConn, destination), time.Duration(t.udpTimeout)*time.Second) t.handler.NewPacketConnectionEx(ctx, conn, source, destination, nil) }() diff --git a/stack_gvisor_udp.go b/stack_gvisor_udp.go index 99662c4..0b0fedc 100644 --- a/stack_gvisor_udp.go +++ b/stack_gvisor_udp.go @@ -8,6 +8,8 @@ import ( "net/netip" "os" "sync" + "time" + _ "unsafe" "github.com/sagernet/gvisor/pkg/buffer" "github.com/sagernet/gvisor/pkg/tcpip" @@ -19,59 +21,60 @@ import ( E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/udpnat" + "github.com/sagernet/sing/common/udpnat2" ) type UDPForwarder struct { - ctx context.Context - stack *stack.Stack - udpNat *udpnat.Service[netip.AddrPort] - - // cache - cacheProto tcpip.NetworkProtocolNumber - cacheID stack.TransportEndpointID + ctx context.Context + stack *stack.Stack + handler Handler + udpNat *udpnat.Service } -func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, udpTimeout int64) *UDPForwarder { - return &UDPForwarder{ - ctx: ctx, - stack: stack, - udpNat: udpnat.NewEx[netip.AddrPort](udpTimeout, handler), +func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, timeout time.Duration) *UDPForwarder { + forwarder := &UDPForwarder{ + ctx: ctx, + stack: stack, + handler: handler, } + forwarder.udpNat = udpnat.New(handler, forwarder.PreparePacketConnection, timeout) + return forwarder } func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { source := M.SocksaddrFrom(AddrFromAddress(id.RemoteAddress), id.RemotePort) destination := M.SocksaddrFrom(AddrFromAddress(id.LocalAddress), id.LocalPort) - if source.IsIPv4() { - f.cacheProto = header.IPv4ProtocolNumber - } else { - f.cacheProto = header.IPv6ProtocolNumber - } - gBuffer := pkt.Data().ToBuffer() - sBuffer := buf.NewSize(int(gBuffer.Size())) - gBuffer.Apply(func(view *buffer.View) { - sBuffer.Write(view.AsSlice()) + bufferRange := pkt.Data().AsRange() + bufferSlices := make([][]byte, bufferRange.Size()) + rangeIterate(bufferRange, func(view *buffer.View) { + bufferSlices = append(bufferSlices, view.AsSlice()) }) - f.cacheID = id - f.udpNat.NewPacketEx( - f.ctx, - source.AddrPort(), - sBuffer, - source, - destination, - f.newUDPConn, - ) + f.udpNat.NewPacket(bufferSlices, source, destination, pkt) return true } -func (f *UDPForwarder) newUDPConn(natConn N.PacketConn) N.PacketWriter { - return &UDPBackWriter{ +//go:linkname rangeIterate github.com/sagernet/gvisor/pkg/tcpip/stack.Range.iterate +func rangeIterate(r stack.Range, fn func(*buffer.View)) + +func (f *UDPForwarder) PreparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) { + pErr := f.handler.PrepareConnection(source, destination) + if pErr != nil { + gWriteUnreachable(f.stack, userData.(*stack.PacketBuffer), pErr) + return false, nil, nil, nil + } + var sourceNetwork tcpip.NetworkProtocolNumber + if source.Addr.Is4() { + sourceNetwork = header.IPv4ProtocolNumber + } else { + sourceNetwork = header.IPv6ProtocolNumber + } + writer := &UDPBackWriter{ stack: f.stack, - source: f.cacheID.RemoteAddress, - sourcePort: f.cacheID.RemotePort, - sourceNetwork: f.cacheProto, + source: AddressFromAddr(source.Addr), + sourcePort: source.Port, + sourceNetwork: sourceNetwork, } + return true, f.ctx, writer, nil } type UDPBackWriter struct { diff --git a/stack_system.go b/stack_system.go index c7f93fc..f9915ea 100644 --- a/stack_system.go +++ b/stack_system.go @@ -15,7 +15,7 @@ import ( "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/udpnat" + "github.com/sagernet/sing/common/udpnat2" ) var ErrIncludeAllNetworks = E.New("`system` and `mixed` stack are not available when `includeAllNetworks` is enabled. See https://github.com/SagerNet/sing-tun/issues/25") @@ -34,13 +34,13 @@ type System struct { inet6ServerAddress netip.Addr inet6Address netip.Addr broadcastAddr netip.Addr - udpTimeout int64 + udpTimeout time.Duration tcpListener net.Listener tcpListener6 net.Listener tcpPort uint16 tcpPort6 uint16 tcpNat *TCPNat - udpNat *udpnat.Service[netip.AddrPort] + udpNat *udpnat.Service bindInterface bool interfaceFinder control.InterfaceFinder frontHeadroom int @@ -151,8 +151,8 @@ func (s *System) start() error { s.tcpPort6 = M.SocksaddrFromNet(tcpListener.Addr()).Port go s.acceptLoop(tcpListener) } - s.tcpNat = NewNat(s.ctx, time.Second*time.Duration(s.udpTimeout)) - s.udpNat = udpnat.NewEx[netip.AddrPort](s.udpTimeout, s.handler) + s.tcpNat = NewNat(s.ctx, s.udpTimeout) + s.udpNat = udpnat.New(s.handler, s.preparePacketConnection, s.udpTimeout) return nil } @@ -354,7 +354,11 @@ func (s *System) processIPv4TCP(packet clashtcpip.IPv4Packet, header clashtcpip. packet.SetDestinationIP(session.Source.Addr()) header.SetDestinationPort(session.Source.Port()) } else { - natPort := s.tcpNat.Lookup(source, destination) + natPort, err := s.tcpNat.Lookup(source, destination, s.handler) + if err != nil { + // TODO: implement rejects + return nil + } packet.SetSourceIP(s.inet4Address) header.SetSourcePort(natPort) packet.SetDestinationIP(s.inet4ServerAddress) @@ -385,7 +389,11 @@ func (s *System) processIPv6TCP(packet clashtcpip.IPv6Packet, header clashtcpip. packet.SetDestinationIP(session.Source.Addr()) header.SetDestinationPort(session.Source.Port()) } else { - natPort := s.tcpNat.Lookup(source, destination) + natPort, err := s.tcpNat.Lookup(source, destination, s.handler) + if err != nil { + // TODO: implement rejects + return nil + } packet.SetSourceIP(s.inet6Address) header.SetSourcePort(natPort) packet.SetDestinationIP(s.inet6ServerAddress) @@ -409,56 +417,61 @@ func (s *System) processIPv4UDP(packet clashtcpip.IPv4Packet, header clashtcpip. if !header.Valid() { return E.New("ipv4: udp: invalid packet") } - source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) - destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) - if !destination.Addr().IsGlobalUnicast() { + source := M.SocksaddrFrom(packet.SourceIP(), header.SourcePort()) + destination := M.SocksaddrFrom(packet.DestinationIP(), header.DestinationPort()) + if !destination.Addr.IsGlobalUnicast() { return nil } - data := buf.As(header.Payload()) - if data.Len() == 0 { + s.udpNat.NewPacket([][]byte{header.Payload()}, source, destination, packet) + return nil +} + +func (s *System) processIPv6UDP(packet clashtcpip.IPv6Packet, header clashtcpip.UDPPacket) error { + if !header.Valid() { + return E.New("ipv6: udp: invalid packet") + } + source := M.SocksaddrFrom(packet.SourceIP(), header.SourcePort()) + destination := M.SocksaddrFrom(packet.DestinationIP(), header.DestinationPort()) + if !destination.Addr.IsGlobalUnicast() { return nil } - s.udpNat.NewPacketEx(s.ctx, source, data.ToOwned(), M.SocksaddrFromNetIP(source), M.SocksaddrFromNetIP(destination), func(natConn N.PacketConn) N.PacketWriter { + s.udpNat.NewPacket([][]byte{header.Payload()}, source, destination, packet) + return nil +} + +func (s *System) preparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) { + pErr := s.handler.PrepareConnection(source, destination) + if pErr != nil { + // TODO: implement ICMP port unreachable + return false, nil, nil, nil + } + var writer N.PacketWriter + if source.IsIPv4() { + packet := userData.(clashtcpip.IPv4Packet) headerLen := packet.HeaderLen() + clashtcpip.UDPHeaderSize headerCopy := make([]byte, headerLen) copy(headerCopy, packet[:headerLen]) - return &systemUDPPacketWriter4{ + writer = &systemUDPPacketWriter4{ s.tun, s.frontHeadroom + PacketOffset, headerCopy, - source, + source.AddrPort(), s.txChecksumOffload, } - }) - return nil -} - -func (s *System) processIPv6UDP(packet clashtcpip.IPv6Packet, header clashtcpip.UDPPacket) error { - if !header.Valid() { - return E.New("ipv6: udp: invalid packet") - } - source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) - destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) - if !destination.Addr().IsGlobalUnicast() { - return nil - } - data := buf.As(header.Payload()) - if data.Len() == 0 { - return nil - } - s.udpNat.NewPacketEx(s.ctx, source, data.ToOwned(), M.SocksaddrFromNetIP(source), M.SocksaddrFromNetIP(destination), func(natConn N.PacketConn) N.PacketWriter { - headerLen := len(packet) - int(header.Length()) + clashtcpip.UDPHeaderSize + } else { + packet := userData.(clashtcpip.IPv6Packet) + headerLen := len(packet) - int(packet.PayloadLength()) + clashtcpip.UDPHeaderSize headerCopy := make([]byte, headerLen) copy(headerCopy, packet[:headerLen]) - return &systemUDPPacketWriter6{ + writer = &systemUDPPacketWriter6{ s.tun, s.frontHeadroom + PacketOffset, headerCopy, - source, + source.AddrPort(), s.txChecksumOffload, } - }) - return nil + } + return true, s.ctx, writer, nil } func (s *System) processIPv4ICMP(packet clashtcpip.IPv4Packet, header clashtcpip.ICMPPacket) error { diff --git a/stack_system_nat.go b/stack_system_nat.go index ff80413..6e7e7ef 100644 --- a/stack_system_nat.go +++ b/stack_system_nat.go @@ -5,6 +5,8 @@ import ( "net/netip" "sync" "time" + + M "github.com/sagernet/sing/common/metadata" ) type TCPNat struct { @@ -68,12 +70,16 @@ func (n *TCPNat) LookupBack(port uint16) *TCPSession { return session } -func (n *TCPNat) Lookup(source netip.AddrPort, destination netip.AddrPort) uint16 { +func (n *TCPNat) Lookup(source netip.AddrPort, destination netip.AddrPort, handler Handler) (uint16, error) { n.addrAccess.RLock() port, loaded := n.addrMap[source] n.addrAccess.RUnlock() if loaded { - return port + return port, nil + } + pErr := handler.PrepareConnection(M.SocksaddrFromNetIP(source), M.SocksaddrFromNetIP(destination)) + if pErr != nil { + return 0, pErr } n.addrAccess.Lock() nextPort := n.portIndex @@ -92,5 +98,5 @@ func (n *TCPNat) Lookup(source netip.AddrPort, destination netip.AddrPort) uint1 LastActive: time.Now(), } n.portAccess.Unlock() - return nextPort + return nextPort, nil } diff --git a/tun.go b/tun.go index ea85bb3..68ba7c1 100644 --- a/tun.go +++ b/tun.go @@ -10,11 +10,13 @@ import ( F "github.com/sagernet/sing/common/format" "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/ranges" ) type Handler interface { + PrepareConnection(source M.Socksaddr, destination M.Socksaddr) error N.TCPConnectionHandlerEx N.UDPConnectionHandlerEx }