Skip to content

Commit

Permalink
refactor: split UDP serving from handling of packets
Browse files Browse the repository at this point in the history
This will allow us to re-use the handling of packets in the Caddy
server where serving is handled separately.
  • Loading branch information
sbruens committed Nov 18, 2024
1 parent e1fd23b commit ffb6c09
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 101 deletions.
4 changes: 2 additions & 2 deletions cmd/outline-ss-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) {
return err
}
slog.Info("UDP service started.", "address", pc.LocalAddr().String())
go ssService.HandlePacket(pc)
go service.PacketServe(pc, ssService.HandlePacket)
}

for _, serviceConfig := range config.Services {
Expand Down Expand Up @@ -271,7 +271,7 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) {
return err
}
slog.Info("UDP service started.", "address", pc.LocalAddr().String())
go ssService.HandlePacket(pc)
go service.PacketServe(pc, ssService.HandlePacket)
}
}
totalCipherCount += len(serviceConfig.Keys)
Expand Down
6 changes: 3 additions & 3 deletions internal/integration_test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ func TestUDPEcho(t *testing.T) {
proxy.SetTargetIPValidator(allowAll)
done := make(chan struct{})
go func() {
proxy.Handle(proxyConn)
service.PacketServe(proxyConn, proxy.Handle)
done <- struct{}{}
}()

Expand Down Expand Up @@ -548,7 +548,7 @@ func BenchmarkUDPEcho(b *testing.B) {
proxy.SetTargetIPValidator(allowAll)
done := make(chan struct{})
go func() {
proxy.Handle(server)
service.PacketServe(server, proxy.Handle)
done <- struct{}{}
}()

Expand Down Expand Up @@ -592,7 +592,7 @@ func BenchmarkUDPManyKeys(b *testing.B) {
proxy.SetTargetIPValidator(allowAll)
done := make(chan struct{})
go func() {
proxy.Handle(proxyConn)
service.PacketServe(proxyConn, proxy.Handle)
done <- struct{}{}
}()

Expand Down
6 changes: 3 additions & 3 deletions service/shadowsocks.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ type ServiceMetrics interface {

type Service interface {
HandleStream(ctx context.Context, conn transport.StreamConn)
HandlePacket(conn net.PacketConn)
HandlePacket(conn net.Conn, pkt []byte)
}

// Option is a Shadowsocks service constructor option.
Expand Down Expand Up @@ -137,8 +137,8 @@ func (s *ssService) HandleStream(ctx context.Context, conn transport.StreamConn)
}

// HandlePacket handles a Shadowsocks packet connection.
func (s *ssService) HandlePacket(conn net.PacketConn) {
s.ph.Handle(conn)
func (s *ssService) HandlePacket(conn net.Conn, pkt []byte) {
s.ph.Handle(conn, pkt)
}

type ssConnMetrics struct {
Expand Down
205 changes: 116 additions & 89 deletions service/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ type UDPMetrics interface {
// Max UDP buffer size for the server code.
const serverUDPBufferSize = 64 * 1024

var bufferPool = sync.Pool{
New: func() interface{} {
return make([]byte, serverUDPBufferSize)
},
}

// Wrapper for slog.Debug during UDP proxying.
func debugUDP(l *slog.Logger, template string, cipherID string, attr slog.Attr) {
// This is an optimization to reduce unnecessary allocations due to an interaction
Expand Down Expand Up @@ -83,7 +89,7 @@ type packetHandler struct {
logger *slog.Logger
natTimeout time.Duration
ciphers CipherList
m UDPMetrics
nm *natmap
ssm ShadowsocksConnMetrics
targetIPValidator onet.TargetIPValidator
}
Expand All @@ -96,24 +102,23 @@ func NewPacketHandler(natTimeout time.Duration, cipherList CipherList, m UDPMetr
if ssMetrics == nil {
ssMetrics = &NoOpShadowsocksConnMetrics{}
}
nm := newNATmap(natTimeout, m, noopLogger())
return &packetHandler{
logger: noopLogger(),
natTimeout: natTimeout,
ciphers: cipherList,
m: m,
nm: nm,
ssm: ssMetrics,
targetIPValidator: onet.RequirePublicIP,
}
}

// PacketHandler is a running UDP shadowsocks proxy that can be stopped.
type PacketHandler interface {
Handle(conn net.Conn, pkt []byte)
// SetLogger sets the logger used to log messages. Uses a no-op logger if nil.
SetLogger(l *slog.Logger)
// SetTargetIPValidator sets the function to be used to validate the target IP addresses.
SetTargetIPValidator(targetIPValidator onet.TargetIPValidator)
// Handle returns after clientConn closes and all the sub goroutines return.
Handle(clientConn net.PacketConn)
}

func (h *packetHandler) SetLogger(l *slog.Logger) {
Expand All @@ -127,101 +132,124 @@ func (h *packetHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPVali
h.targetIPValidator = targetIPValidator
}

// Listen on addr for encrypted packets and basically do UDP NAT.
// We take the ciphers as a pointer because it gets replaced on config updates.
func (h *packetHandler) Handle(clientConn net.PacketConn) {
nm := newNATmap(h.natTimeout, h.m, h.logger)
defer nm.Close()
cipherBuf := make([]byte, serverUDPBufferSize)
textBuf := make([]byte, serverUDPBufferSize)
type PacketHandleFunc func(conn net.Conn, pkt []byte)

// PacketServe listens for packets and calls `handle` to handle them until the connection
// returns [ErrClosed].
func PacketServe(clientConn net.PacketConn, handle PacketHandleFunc) {
buffer := bufferPool.Get().([]byte)
defer bufferPool.Put(buffer)

for {
clientProxyBytes, clientAddr, err := clientConn.ReadFrom(cipherBuf)
if errors.Is(err, net.ErrClosed) {
break
n, addr, err := clientConn.ReadFrom(buffer)
if err != nil {
if errors.Is(err, net.ErrClosed) {
break
}
slog.Warn("Failed to read from client. Continuing to listen.", "err", err)
continue
}
pkt := buffer[:n]

keyID := ""
var proxyTargetBytes int
var targetConn *natconn

connError := func() (connError *onet.ConnectionError) {
func() {
defer func() {
if r := recover(); r != nil {
slog.Error("Panic in UDP loop: %v. Continuing to listen.", r)
slog.Error("Panic in UDP loop. Continuing to listen.", "err", r)
debug.PrintStack()
}
}()
handle(&wrappedPacketConn{PacketConn: clientConn, raddr: addr}, pkt)
}()
}
}

// Error from ReadFrom
if err != nil {
return onet.NewConnectionError("ERR_READ", "Failed to read from client", err)
}
defer slog.LogAttrs(nil, slog.LevelDebug, "UDP: Done", slog.String("address", clientAddr.String()))
debugUDPAddr(h.logger, "Outbound packet.", clientAddr, slog.Int("bytes", clientProxyBytes))

cipherData := cipherBuf[:clientProxyBytes]
var payload []byte
var tgtUDPAddr *net.UDPAddr
targetConn = nm.Get(clientAddr.String())
if targetConn == nil {
ip := clientAddr.(*net.UDPAddr).AddrPort().Addr()
var textData []byte
var cryptoKey *shadowsocks.EncryptionKey
unpackStart := time.Now()
textData, keyID, cryptoKey, err = findAccessKeyUDP(ip, textBuf, cipherData, h.ciphers, h.logger)
timeToCipher := time.Since(unpackStart)
h.ssm.AddCipherSearch(err == nil, timeToCipher)

if err != nil {
return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack initial packet", err)
}
type wrappedPacketConn struct {
net.PacketConn
raddr net.Addr
}

var onetErr *onet.ConnectionError
if payload, tgtUDPAddr, onetErr = h.validatePacket(textData); onetErr != nil {
return onetErr
}
func (pc *wrappedPacketConn) Read(p []byte) (int, error) {
n, _, err := pc.PacketConn.ReadFrom(p)
return n, err
}

udpConn, err := net.ListenPacket("udp", "")
if err != nil {
return onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create UDP socket", err)
}
targetConn = nm.Add(clientAddr, clientConn, cryptoKey, udpConn, keyID)
} else {
unpackStart := time.Now()
textData, err := shadowsocks.Unpack(nil, cipherData, targetConn.cryptoKey)
timeToCipher := time.Since(unpackStart)
h.ssm.AddCipherSearch(err == nil, timeToCipher)

if err != nil {
return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack data from client", err)
}
func (pc *wrappedPacketConn) RemoteAddr() net.Addr {
return pc.raddr
}

// The key ID is known with confidence once decryption succeeds.
keyID = targetConn.keyID
func (pc *wrappedPacketConn) Write(b []byte) (n int, err error) {
return pc.PacketConn.WriteTo(b, pc.raddr)
}

var onetErr *onet.ConnectionError
if payload, tgtUDPAddr, onetErr = h.validatePacket(textData); onetErr != nil {
return onetErr
}
func (h *packetHandler) Handle(clientConn net.Conn, pkt []byte) {
buffer := bufferPool.Get().([]byte)
defer bufferPool.Put(buffer)

var err error
var proxyTargetBytes int
var targetConn *natconn

connError := func() (connError *onet.ConnectionError) {
defer slog.LogAttrs(nil, slog.LevelDebug, "UDP: Done", slog.String("address", clientConn.RemoteAddr().String()))
debugUDPAddr(h.logger, "Outbound packet.", clientConn.RemoteAddr(), slog.Int("bytes", len(pkt)))

var payload []byte
var tgtUDPAddr *net.UDPAddr
targetConn = h.nm.Get(clientConn.RemoteAddr().String())
if targetConn == nil {
ip := clientConn.RemoteAddr().(*net.UDPAddr).AddrPort().Addr()
var textData []byte
var cryptoKey *shadowsocks.EncryptionKey
unpackStart := time.Now()
textData, keyID, cryptoKey, err := findAccessKeyUDP(ip, buffer, pkt, h.ciphers, h.logger)
timeToCipher := time.Since(unpackStart)
h.ssm.AddCipherSearch(err == nil, timeToCipher)

if err != nil {
return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack initial packet", err)
}

debugUDPAddr(h.logger, "Proxy exit.", clientAddr, slog.Any("target", targetConn.LocalAddr()))
proxyTargetBytes, err = targetConn.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature
var onetErr *onet.ConnectionError
if payload, tgtUDPAddr, onetErr = h.validatePacket(textData); onetErr != nil {
return onetErr
}

udpConn, err := net.ListenPacket("udp", "")
if err != nil {
return onet.NewConnectionError("ERR_WRITE", "Failed to write to target", err)
return onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create UDP socket", err)
}
return nil
}()
targetConn = h.nm.Add(clientConn, udpConn, cryptoKey, keyID)
} else {
unpackStart := time.Now()
textData, err := shadowsocks.Unpack(nil, pkt, targetConn.cryptoKey)
timeToCipher := time.Since(unpackStart)
h.ssm.AddCipherSearch(err == nil, timeToCipher)

status := "OK"
if connError != nil {
slog.LogAttrs(nil, slog.LevelDebug, "UDP: Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause))
status = connError.Status
if err != nil {
return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack data from client", err)
}

var onetErr *onet.ConnectionError
if payload, tgtUDPAddr, onetErr = h.validatePacket(textData); onetErr != nil {
return onetErr
}
}
if targetConn != nil {
targetConn.metrics.AddPacketFromClient(status, int64(clientProxyBytes), int64(proxyTargetBytes))

debugUDPAddr(h.logger, "Proxy exit.", clientConn.RemoteAddr(), slog.Any("target", targetConn.LocalAddr()))
proxyTargetBytes, err = targetConn.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature
if err != nil {
return onet.NewConnectionError("ERR_WRITE", "Failed to write to target", err)
}
return nil
}()

status := "OK"
if connError != nil {
slog.LogAttrs(nil, slog.LevelDebug, "UDP: Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause))
status = connError.Status
}
if targetConn != nil {
targetConn.metrics.AddPacketFromClient(status, int64(len(pkt)), int64(proxyTargetBytes))
}
}

Expand Down Expand Up @@ -333,11 +361,10 @@ func (m *natmap) Get(key string) *natconn {
return m.keyConn[key]
}

func (m *natmap) set(key string, pc net.PacketConn, cryptoKey *shadowsocks.EncryptionKey, keyID string, connMetrics UDPConnMetrics) *natconn {
func (m *natmap) set(key string, pc net.PacketConn, cryptoKey *shadowsocks.EncryptionKey, connMetrics UDPConnMetrics) *natconn {
entry := &natconn{
PacketConn: pc,
cryptoKey: cryptoKey,
keyID: keyID,
metrics: connMetrics,
defaultTimeout: m.timeout,
}
Expand All @@ -361,14 +388,14 @@ func (m *natmap) del(key string) net.PacketConn {
return nil
}

func (m *natmap) Add(clientAddr net.Addr, clientConn net.PacketConn, cryptoKey *shadowsocks.EncryptionKey, targetConn net.PacketConn, keyID string) *natconn {
connMetrics := m.metrics.AddUDPNatEntry(clientAddr, keyID)
entry := m.set(clientAddr.String(), targetConn, cryptoKey, keyID, connMetrics)
func (m *natmap) Add(clientConn net.Conn, targetConn net.PacketConn, cryptoKey *shadowsocks.EncryptionKey, keyID string) *natconn {
connMetrics := m.metrics.AddUDPNatEntry(clientConn.RemoteAddr(), keyID)
entry := m.set(clientConn.RemoteAddr().String(), targetConn, cryptoKey, connMetrics)

go func() {
timedCopy(clientAddr, clientConn, entry, keyID, m.logger)
timedCopy(clientConn, entry, m.logger)
connMetrics.RemoveNatEntry()
if pc := m.del(clientAddr.String()); pc != nil {
if pc := m.del(clientConn.RemoteAddr().String()); pc != nil {
pc.Close()
}
}()
Expand All @@ -394,7 +421,7 @@ func (m *natmap) Close() error {
var maxAddrLen int = len(socks.ParseAddr("[2001:db8::1]:12345"))

// copy from target to client until read timeout
func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natconn, keyID string, l *slog.Logger) {
func timedCopy(clientConn net.Conn, targetConn *natconn, l *slog.Logger) {
// pkt is used for in-place encryption of downstream UDP packets, with the layout
// [padding?][salt][address][body][tag][extra]
// Padding is only used if the address is IPv4.
Expand Down Expand Up @@ -427,7 +454,7 @@ func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natco
return onet.NewConnectionError("ERR_READ", "Failed to read from target", err)
}

debugUDPAddr(l, "Got response.", clientAddr, slog.Any("target", raddr))
debugUDPAddr(l, "Got response.", clientConn.RemoteAddr(), slog.Any("target", raddr))
srcAddr := socks.ParseAddr(raddr.String())
addrStart := bodyStart - len(srcAddr)
// `plainTextBuf` concatenates the SOCKS address and body:
Expand All @@ -448,7 +475,7 @@ func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natco
if err != nil {
return onet.NewConnectionError("ERR_PACK", "Failed to pack data to client", err)
}
proxyClientBytes, err = clientConn.WriteTo(buf, clientAddr)
proxyClientBytes, err = clientConn.Write(buf)
if err != nil {
return onet.NewConnectionError("ERR_WRITE", "Failed to write to client", err)
}
Expand Down
Loading

0 comments on commit ffb6c09

Please sign in to comment.