From cacaf7a69bb4f3a8e076f136902f69f7563a4bac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 21 Nov 2024 18:12:21 +0800 Subject: [PATCH] Export interface for WireGuard --- go.mod | 2 +- go.sum | 4 +- internal/gtcpip/header/interfaces.go | 136 +++++++++++++++++++++++++++ stack_gvisor.go | 37 ++------ stack_gvisor_tcp.go | 51 ++++++++++ stack_gvisor_udp.go | 2 +- stack_mixed.go | 8 +- stack_system.go | 10 +- stack_system_packet.go | 34 +++++++ tun.go | 3 + tun_darwin.go | 29 ++++-- 11 files changed, 268 insertions(+), 48 deletions(-) create mode 100644 internal/gtcpip/header/interfaces.go create mode 100644 stack_gvisor_tcp.go create mode 100644 stack_system_packet.go diff --git a/go.mod b/go.mod index e7644c0..b4bbd0e 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/sagernet/gvisor v0.0.0-20241021032506-a4324256e4a3 github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a github.com/sagernet/nftables v0.3.0-beta.4 - github.com/sagernet/sing v0.6.0-alpha.11 + github.com/sagernet/sing v0.6.0-alpha.18 go4.org/netipx v0.0.0-20231129151722-fdeea329fbba golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 golang.org/x/net v0.26.0 diff --git a/go.sum b/go.sum index 270fa12..1cdcf1b 100644 --- a/go.sum +++ b/go.sum @@ -22,8 +22,8 @@ github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a h1:ObwtHN2VpqE0ZN github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM= github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNenDW2zaXr8I= github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8= -github.com/sagernet/sing v0.6.0-alpha.11 h1:ZcZlA0/vdDeiipAbjK73x9VabGJ/RRcAJgWhOo/OoBk= -github.com/sagernet/sing v0.6.0-alpha.11/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= +github.com/sagernet/sing v0.6.0-alpha.18 h1:ih4CurU8KvbhfagYjSqVrE2LR0oBSXSZTNH2sAGPGiM= +github.com/sagernet/sing v0.6.0-alpha.18/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= diff --git a/internal/gtcpip/header/interfaces.go b/internal/gtcpip/header/interfaces.go new file mode 100644 index 0000000..b304532 --- /dev/null +++ b/internal/gtcpip/header/interfaces.go @@ -0,0 +1,136 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "net/netip" + + tcpip "github.com/sagernet/sing-tun/internal/gtcpip" +) + +const ( + // MaxIPPacketSize is the maximum supported IP packet size, excluding + // jumbograms. The maximum IPv4 packet size is 64k-1 (total size must fit + // in 16 bits). For IPv6, the payload max size (excluding jumbograms) is + // 64k-1 (also needs to fit in 16 bits). So we use 64k - 1 + 2 * m, where + // m is the minimum IPv6 header size; we leave room for some potential + // IP options. + MaxIPPacketSize = 0xffff + 2*IPv6MinimumSize +) + +// Transport offers generic methods to query and/or update the fields of the +// header of a transport protocol buffer. +type Transport interface { + // SourcePort returns the value of the "source port" field. + SourcePort() uint16 + + // Destination returns the value of the "destination port" field. + DestinationPort() uint16 + + // Checksum returns the value of the "checksum" field. + Checksum() uint16 + + // SetSourcePort sets the value of the "source port" field. + SetSourcePort(uint16) + + // SetDestinationPort sets the value of the "destination port" field. + SetDestinationPort(uint16) + + // SetChecksum sets the value of the "checksum" field. + SetChecksum(uint16) + + // Payload returns the data carried in the transport buffer. + Payload() []byte +} + +// ChecksummableTransport is a Transport that supports checksumming. +type ChecksummableTransport interface { + Transport + + // SetSourcePortWithChecksumUpdate sets the source port and updates + // the checksum. + // + // The receiver's checksum must be a fully calculated checksum. + SetSourcePortWithChecksumUpdate(port uint16) + + // SetDestinationPortWithChecksumUpdate sets the destination port and updates + // the checksum. + // + // The receiver's checksum must be a fully calculated checksum. + SetDestinationPortWithChecksumUpdate(port uint16) + + // UpdateChecksumPseudoHeaderAddress updates the checksum to reflect an + // updated address in the pseudo header. + // + // If fullChecksum is true, the receiver's checksum field is assumed to hold a + // fully calculated checksum. Otherwise, it is assumed to hold a partially + // calculated checksum which only reflects the pseudo header. + UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) +} + +// Network offers generic methods to query and/or update the fields of the +// header of a network protocol buffer. +type Network interface { + // SourceAddress returns the value of the "source address" field. + SourceAddress() tcpip.Address + + // DestinationAddress returns the value of the "destination address" + // field. + DestinationAddress() tcpip.Address + + DestinationAddr() netip.Addr + + // Checksum returns the value of the "checksum" field. + Checksum() uint16 + + // SetSourceAddress sets the value of the "source address" field. + SetSourceAddress(tcpip.Address) + + // SetDestinationAddress sets the value of the "destination address" + // field. + SetDestinationAddress(tcpip.Address) + + SetDestinationAddr(addr netip.Addr) + + // SetChecksum sets the value of the "checksum" field. + SetChecksum(uint16) + + // TransportProtocol returns the number of the transport protocol + // stored in the payload. + TransportProtocol() tcpip.TransportProtocolNumber + + // Payload returns a byte slice containing the payload of the network + // packet. + Payload() []byte + + // TOS returns the values of the "type of service" and "flow label" fields. + TOS() (uint8, uint32) + + // SetTOS sets the values of the "type of service" and "flow label" fields. + SetTOS(t uint8, l uint32) +} + +// ChecksummableNetwork is a Network that supports checksumming. +type ChecksummableNetwork interface { + Network + + // SetSourceAddressAndChecksum sets the source address and updates the + // checksum to reflect the new address. + SetSourceAddressWithChecksumUpdate(tcpip.Address) + + // SetDestinationAddressAndChecksum sets the destination address and + // updates the checksum to reflect the new address. + SetDestinationAddressWithChecksumUpdate(tcpip.Address) +} diff --git a/stack_gvisor.go b/stack_gvisor.go index 83ca9e6..65bb7bd 100644 --- a/stack_gvisor.go +++ b/stack_gvisor.go @@ -19,13 +19,11 @@ import ( "github.com/sagernet/gvisor/pkg/tcpip/transport/udp" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" ) const WithGVisor = true -const defaultNIC tcpip.NICID = 1 +const DefaultNIC tcpip.NICID = 1 type GVisor struct { ctx context.Context @@ -68,28 +66,11 @@ func (t *GVisor) Start() error { return err } linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.broadcastAddr, t.tun} - ipStack, err := newGVisorStack(linkEndpoint) + ipStack, err := NewGVisorStack(linkEndpoint) if err != nil { return err } - tcpForwarder := tcp.NewForwarder(ipStack, 0, 1024, func(r *tcp.ForwarderRequest) { - source := M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort) - destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort) - pErr := t.handler.PrepareConnection(N.NetworkTCP, source, destination) - if pErr != nil { - r.Complete(pErr != ErrDrop) - return - } - conn := &gLazyConn{ - parentCtx: t.ctx, - stack: t.stack, - request: r, - localAddr: source.TCPAddr(), - remoteAddr: destination.TCPAddr(), - } - go t.handler.NewConnectionEx(t.ctx, conn, source, destination, nil) - }) - ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) + ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, NewTCPForwarder(t.ctx, ipStack, t.handler).HandlePacket) ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket) t.stack = ipStack t.endpoint = linkEndpoint @@ -124,7 +105,7 @@ func AddrFromAddress(address tcpip.Address) netip.Addr { } } -func newGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) { +func NewGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) { ipStack := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ ipv4.NewProtocol, @@ -137,19 +118,19 @@ func newGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) { icmp.NewProtocol6, }, }) - err := ipStack.CreateNIC(defaultNIC, ep) + err := ipStack.CreateNIC(DefaultNIC, ep) if err != nil { return nil, gonet.TranslateNetstackError(err) } ipStack.SetRouteTable([]tcpip.Route{ - {Destination: header.IPv4EmptySubnet, NIC: defaultNIC}, - {Destination: header.IPv6EmptySubnet, NIC: defaultNIC}, + {Destination: header.IPv4EmptySubnet, NIC: DefaultNIC}, + {Destination: header.IPv6EmptySubnet, NIC: DefaultNIC}, }) - err = ipStack.SetSpoofing(defaultNIC, true) + err = ipStack.SetSpoofing(DefaultNIC, true) if err != nil { return nil, gonet.TranslateNetstackError(err) } - err = ipStack.SetPromiscuousMode(defaultNIC, true) + err = ipStack.SetPromiscuousMode(DefaultNIC, true) if err != nil { return nil, gonet.TranslateNetstackError(err) } diff --git a/stack_gvisor_tcp.go b/stack_gvisor_tcp.go new file mode 100644 index 0000000..33cf40e --- /dev/null +++ b/stack_gvisor_tcp.go @@ -0,0 +1,51 @@ +//go:build with_gvisor + +package tun + +import ( + "context" + + "github.com/sagernet/gvisor/pkg/tcpip/stack" + "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +type TCPForwarder struct { + ctx context.Context + stack *stack.Stack + handler Handler + forwarder *tcp.Forwarder +} + +func NewTCPForwarder(ctx context.Context, stack *stack.Stack, handler Handler) *TCPForwarder { + forwarder := &TCPForwarder{ + ctx: ctx, + stack: stack, + handler: handler, + } + forwarder.forwarder = tcp.NewForwarder(stack, 0, 1024, forwarder.Forward) + return forwarder +} + +func (f *TCPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + return f.forwarder.HandlePacket(id, pkt) +} + +func (f *TCPForwarder) Forward(r *tcp.ForwarderRequest) { + source := M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort) + destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort) + pErr := f.handler.PrepareConnection(N.NetworkTCP, source, destination) + if pErr != nil { + r.Complete(pErr != ErrDrop) + return + } + conn := &gLazyConn{ + parentCtx: f.ctx, + stack: f.stack, + request: r, + localAddr: source.TCPAddr(), + remoteAddr: destination.TCPAddr(), + } + go f.handler.NewConnectionEx(f.ctx, conn, source, destination, nil) +} diff --git a/stack_gvisor_udp.go b/stack_gvisor_udp.go index 150fd1a..3027798 100644 --- a/stack_gvisor_udp.go +++ b/stack_gvisor_udp.go @@ -123,7 +123,7 @@ func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Sock defer packetBuffer.Release() route, err := w.stack.FindRoute( - defaultNIC, + DefaultNIC, AddressFromAddr(destination.Addr), w.source, w.sourceNetwork, diff --git a/stack_mixed.go b/stack_mixed.go index be146e3..601f7f7 100644 --- a/stack_mixed.go +++ b/stack_mixed.go @@ -38,7 +38,7 @@ func (m *Mixed) Start() error { return err } endpoint := channel.New(1024, uint32(m.mtu), "") - ipStack, err := newGVisorStack(endpoint) + ipStack, err := NewGVisorStack(endpoint) if err != nil { return err } @@ -151,10 +151,10 @@ func (m *Mixed) processPacket(packet []byte) bool { writeBack bool err error ) - switch ipVersion := packet[0] >> 4; ipVersion { - case 4: + switch ipVersion := header.IPVersion(packet); ipVersion { + case header.IPv4Version: writeBack, err = m.processIPv4(packet) - case 6: + case header.IPv6Version: writeBack, err = m.processIPv6(packet) default: err = E.New("ip: unknown version: ", ipVersion) diff --git a/stack_system.go b/stack_system.go index 39aead0..dde0664 100644 --- a/stack_system.go +++ b/stack_system.go @@ -419,7 +419,7 @@ func (s *System) resetIPv4TCP(origIPHdr header.IPv4, origTCPHdr header.TCP) erro ipHdr.SetChecksum(0) ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) if PacketOffset > 0 { - newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET + PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version) } else { newPacket.Advance(-s.frontHeadroom) } @@ -502,7 +502,7 @@ func (s *System) resetIPv6TCP(origIPHdr header.IPv6, origTCPHdr header.TCP) erro tcpHdr.SetChecksum(^tcpHdr.CalculateChecksum(header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), header.TCPMinimumSize))) } if PacketOffset > 0 { - newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6 + PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv6Version) } else { newPacket.Advance(-s.frontHeadroom) } @@ -684,7 +684,7 @@ func (s *System) rejectIPv6WithICMP(ipHdr header.IPv6, code header.ICMPv6Code) e })) copy(icmpHdr.Payload(), payload) if PacketOffset > 0 { - newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6 + PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv6Version) } else { newPacket.Advance(-s.frontHeadroom) } @@ -724,7 +724,7 @@ func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.S ipHdr.SetChecksum(0) ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) if PacketOffset > 0 { - newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET + PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version) } else { newPacket.Advance(-w.frontHeadroom) } @@ -763,7 +763,7 @@ func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.S udpHdr.SetChecksum(0) } if PacketOffset > 0 { - newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6 + PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv6Version) } else { newPacket.Advance(-w.frontHeadroom) } diff --git a/stack_system_packet.go b/stack_system_packet.go new file mode 100644 index 0000000..b5060b0 --- /dev/null +++ b/stack_system_packet.go @@ -0,0 +1,34 @@ +package tun + +import ( + "net/netip" + "syscall" + + "github.com/sagernet/sing-tun/internal/gtcpip/header" +) + +func PacketIPVersion(packet []byte) int { + return header.IPVersion(packet) +} + +func PacketFillHeader(packet []byte, ipVersion int) { + if PacketOffset > 0 { + switch ipVersion { + case header.IPv4Version: + packet[3] = syscall.AF_INET + case header.IPv6Version: + packet[3] = syscall.AF_INET6 + } + } +} + +func PacketDestination(packet []byte) netip.Addr { + switch ipVersion := header.IPVersion(packet); ipVersion { + case header.IPv4Version: + return header.IPv4(packet).DestinationAddr() + case header.IPv6Version: + return header.IPv6(packet).DestinationAddr() + default: + return netip.Addr{} + } +} diff --git a/tun.go b/tun.go index d1738e8..6c7020b 100644 --- a/tun.go +++ b/tun.go @@ -1,6 +1,7 @@ package tun import ( + "github.com/sagernet/sing/common/control" "io" "net" "net/netip" @@ -54,6 +55,7 @@ type Options struct { MTU uint32 GSO bool AutoRoute bool + InterfaceScope bool Inet4Gateway netip.Addr Inet6Gateway netip.Addr DNSServers []netip.Addr @@ -74,6 +76,7 @@ type Options struct { IncludeAndroidUser []int IncludePackage []string ExcludePackage []string + InterfaceFinder control.InterfaceFinder InterfaceMonitor DefaultInterfaceMonitor FileDescriptor int Logger logger.Logger diff --git a/tun_darwin.go b/tun_darwin.go index 3b7a47e..e3f39e5 100644 --- a/tun_darwin.go +++ b/tun_darwin.go @@ -3,6 +3,7 @@ package tun import ( "errors" "fmt" + "github.com/sagernet/sing-tun/internal/gtcpip/header" "net" "net/netip" "os" @@ -96,9 +97,10 @@ var ( func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error { var packetHeader []byte - if buffers[0].Byte(0)>>4 == 4 { + switch header.IPVersion(buffers[0].Bytes()) { + case header.IPv4Version: packetHeader = packetHeader4[:] - } else { + case header.IPv6Version: packetHeader = packetHeader6[:] } return t.tunWriter.WriteVectorised(append([]*buf.Buffer{buf.As(packetHeader)}, buffers...)) @@ -250,6 +252,7 @@ func configure(tunFd int, ifIndex int, name string, options Options) error { func (t *NativeTun) setRoutes() error { if t.options.AutoRoute && t.options.FileDescriptor == 0 { + routeRanges, err := t.options.BuildAutoRouteRanges(false) if err != nil { return err @@ -262,14 +265,22 @@ func (t *NativeTun) setRoutes() error { } else { gateway = gateway6 } - err = execRoute(unix.RTM_ADD, destination, gateway) + var interfaceIndex int + if t.options.InterfaceScope { + iff, err := t.options.InterfaceFinder.ByName(t.options.Name) + if err != nil { + return err + } + interfaceIndex = iff.Index + } + err = execRoute(unix.RTM_ADD, t.options.InterfaceScope, interfaceIndex, destination, gateway) if err != nil { if errors.Is(err, unix.EEXIST) { - err = execRoute(unix.RTM_DELETE, destination, gateway) + err = execRoute(unix.RTM_DELETE, false, 0, destination, gateway) if err != nil { return E.Cause(err, "remove existing route: ", destination) } - err = execRoute(unix.RTM_ADD, destination, gateway) + err = execRoute(unix.RTM_ADD, t.options.InterfaceScope, interfaceIndex, destination, gateway) if err != nil { return E.Cause(err, "re-add route: ", destination) } @@ -300,7 +311,7 @@ func (t *NativeTun) unsetRoutes() error { } else { gateway = gateway6 } - err = execRoute(unix.RTM_DELETE, destination, gateway) + err = execRoute(unix.RTM_DELETE, false, 0, destination, gateway) if err != nil { err = E.Errors(err, E.Cause(err, "delete route: ", destination)) } @@ -317,7 +328,7 @@ func useSocket(domain, typ, proto int, block func(socketFd int) error) error { return block(socketFd) } -func execRoute(rtmType int, destination netip.Prefix, gateway netip.Addr) error { +func execRoute(rtmType int, interfaceScope bool, interfaceIndex int, destination netip.Prefix, gateway netip.Addr) error { routeMessage := route.RouteMessage{ Type: rtmType, Version: unix.RTM_VERSION, @@ -326,6 +337,10 @@ func execRoute(rtmType int, destination netip.Prefix, gateway netip.Addr) error } if rtmType == unix.RTM_ADD { routeMessage.Flags |= unix.RTF_UP + if interfaceScope { + routeMessage.Flags |= unix.RTF_IFSCOPE + routeMessage.Index = interfaceIndex + } } if gateway.Is4() { routeMessage.Addrs = []route.Addr{