diff --git a/common/bufio/copy_direct_posix.go b/common/bufio/copy_direct_posix.go index 870097c75..5dd345028 100644 --- a/common/bufio/copy_direct_posix.go +++ b/common/bufio/copy_direct_posix.go @@ -6,6 +6,7 @@ import ( "errors" "io" "net/netip" + "os" "syscall" "github.com/sagernet/sing/common/buf" @@ -25,10 +26,11 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter, bufferSize = buf.BufferSize } var ( - buffer *buf.Buffer - readBuffer *buf.Buffer + buffer *buf.Buffer + readBuffer *buf.Buffer + notFirstTime bool ) - newBuffer := func() *buf.Buffer { + source.InitializeReadWaiter(func() *buf.Buffer { if buffer != nil { buffer.Release() } @@ -37,10 +39,10 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter, readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) readBuffer.Resize(frontHeadroom, 0) return readBuffer - } - var notFirstTime bool + }) + defer source.InitializeReadWaiter(nil) for { - err = source.WaitReadBuffer(newBuffer) + err = source.WaitReadBuffer() if err != nil { buffer.Release() if errors.Is(err, io.EOF) { @@ -55,10 +57,8 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter, dataLen := readBuffer.Len() buffer.Resize(readBuffer.Start(), dataLen) err = destination.WriteBuffer(buffer) + buffer.Release() if err != nil { - if buffer != nil { - buffer.Release() - } return } n += int64(dataLen) @@ -83,10 +83,12 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW bufferSize = buf.UDPBufferSize } var ( - buffer *buf.Buffer - readBuffer *buf.Buffer + buffer *buf.Buffer + readBuffer *buf.Buffer + destination M.Socksaddr + notFirstTime bool ) - newBuffer := func() *buf.Buffer { + source.InitializeReadWaiter(func() *buf.Buffer { if buffer != nil { buffer.Release() } @@ -95,11 +97,10 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) readBuffer.Resize(frontHeadroom, 0) return readBuffer - } - var destination M.Socksaddr - var notFirstTime bool + }) + defer source.InitializeReadWaiter(nil) for { - destination, err = source.WaitReadPacket(newBuffer) + destination, err = source.WaitReadPacket() if err != nil { buffer.Release() if !notFirstTime { @@ -113,9 +114,8 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW if err != nil { buffer.Release() return - } else { - buffer = nil } + buffer = nil n += int64(dataLen) for _, counter := range readCounters { counter(int64(dataLen)) @@ -127,6 +127,8 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW } } +var _ N.ReadWaiter = (*syscallReadWaiter)(nil) + type syscallReadWaiter struct { rawConn syscall.RawConn readErr error @@ -143,8 +145,11 @@ func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) { return nil, false } -func (w *syscallReadWaiter) WaitReadBuffer(newBuffer func() *buf.Buffer) error { - if w.readFunc == nil { +func (w *syscallReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) { + w.readErr = nil + if newBuffer == nil { + w.readFunc = nil + } else { w.readFunc = func(fd uintptr) (done bool) { buffer := newBuffer() var readN int @@ -164,16 +169,27 @@ func (w *syscallReadWaiter) WaitReadBuffer(newBuffer func() *buf.Buffer) error { return true } } +} + +func (w *syscallReadWaiter) WaitReadBuffer() error { + if w.readFunc == nil { + return os.ErrInvalid + } err := w.rawConn.Read(w.readFunc) if err != nil { return err } if w.readErr != nil { + if w.readErr == io.EOF { + return io.EOF + } return E.Cause(w.readErr, "raw read") } return nil } +var _ N.PacketReadWaiter = (*syscallPacketReadWaiter)(nil) + type syscallPacketReadWaiter struct { rawConn syscall.RawConn readErr error @@ -191,8 +207,12 @@ func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) return nil, false } -func (w *syscallPacketReadWaiter) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) { - if w.readFunc == nil { +func (w *syscallPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) { + w.readErr = nil + w.readFrom = M.Socksaddr{} + if newBuffer == nil { + w.readFunc = nil + } else { w.readFunc = func(fd uintptr) (done bool) { buffer := newBuffer() var readN int @@ -221,6 +241,12 @@ func (w *syscallPacketReadWaiter) WaitReadPacket(newBuffer func() *buf.Buffer) ( return true } } +} + +func (w *syscallPacketReadWaiter) WaitReadPacket() (destination M.Socksaddr, err error) { + if w.readFunc == nil { + return M.Socksaddr{}, os.ErrInvalid + } err = w.rawConn.Read(w.readFunc) if err != nil { return diff --git a/common/network/direct.go b/common/network/direct.go index 0f09e0a4f..a40275c24 100644 --- a/common/network/direct.go +++ b/common/network/direct.go @@ -6,7 +6,8 @@ import ( ) type ReadWaiter interface { - WaitReadBuffer(newBuffer func() *buf.Buffer) error + InitializeReadWaiter(newBuffer func() *buf.Buffer) + WaitReadBuffer() error } type ReadWaitCreator interface { @@ -14,7 +15,8 @@ type ReadWaitCreator interface { } type PacketReadWaiter interface { - WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) + InitializeReadWaiter(newBuffer func() *buf.Buffer) + WaitReadPacket() (destination M.Socksaddr, err error) } type PacketReadWaitCreator interface {