Skip to content

Commit

Permalink
Improve read waiter interface
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed May 12, 2023
1 parent ab3e469 commit b671451
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 24 deletions.
70 changes: 48 additions & 22 deletions common/bufio/copy_direct_posix.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"io"
"net/netip"
"os"
"syscall"

"github.com/sagernet/sing/common/buf"
Expand All @@ -25,10 +26,11 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter,
bufferSize = buf.BufferSize
}
var (
buffer *buf.Buffer
readBuffer *buf.Buffer
buffer *buf.Buffer
readBuffer *buf.Buffer
notFirstTime bool
)
newBuffer := func() *buf.Buffer {
source.InitializeReadWaiter(func() *buf.Buffer {
if buffer != nil {
buffer.Release()
}
Expand All @@ -37,10 +39,10 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter,
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
return readBuffer
}
var notFirstTime bool
})
defer source.InitializeReadWaiter(nil)
for {
err = source.WaitReadBuffer(newBuffer)
err = source.WaitReadBuffer()
if err != nil {
buffer.Release()
if errors.Is(err, io.EOF) {
Expand All @@ -55,10 +57,8 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter,
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
err = destination.WriteBuffer(buffer)
buffer.Release()
if err != nil {
if buffer != nil {
buffer.Release()
}
return
}
n += int64(dataLen)
Expand All @@ -83,10 +83,12 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
bufferSize = buf.UDPBufferSize
}
var (
buffer *buf.Buffer
readBuffer *buf.Buffer
buffer *buf.Buffer
readBuffer *buf.Buffer
destination M.Socksaddr
notFirstTime bool
)
newBuffer := func() *buf.Buffer {
source.InitializeReadWaiter(func() *buf.Buffer {
if buffer != nil {
buffer.Release()
}
Expand All @@ -95,11 +97,10 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
return readBuffer
}
var destination M.Socksaddr
var notFirstTime bool
})
defer source.InitializeReadWaiter(nil)
for {
destination, err = source.WaitReadPacket(newBuffer)
destination, err = source.WaitReadPacket()
if err != nil {
buffer.Release()
if !notFirstTime {
Expand All @@ -113,9 +114,8 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
if err != nil {
buffer.Release()
return
} else {
buffer = nil
}
buffer = nil
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
Expand All @@ -127,6 +127,8 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
}
}

var _ N.ReadWaiter = (*syscallReadWaiter)(nil)

type syscallReadWaiter struct {
rawConn syscall.RawConn
readErr error
Expand All @@ -143,8 +145,11 @@ func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
return nil, false
}

func (w *syscallReadWaiter) WaitReadBuffer(newBuffer func() *buf.Buffer) error {
if w.readFunc == nil {
func (w *syscallReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
w.readErr = nil
if newBuffer == nil {
w.readFunc = nil
} else {
w.readFunc = func(fd uintptr) (done bool) {
buffer := newBuffer()
var readN int
Expand All @@ -164,16 +169,27 @@ func (w *syscallReadWaiter) WaitReadBuffer(newBuffer func() *buf.Buffer) error {
return true
}
}
}

func (w *syscallReadWaiter) WaitReadBuffer() error {
if w.readFunc == nil {
return os.ErrInvalid
}
err := w.rawConn.Read(w.readFunc)
if err != nil {
return err
}
if w.readErr != nil {
if w.readErr == io.EOF {
return io.EOF
}
return E.Cause(w.readErr, "raw read")
}
return nil
}

var _ N.PacketReadWaiter = (*syscallPacketReadWaiter)(nil)

type syscallPacketReadWaiter struct {
rawConn syscall.RawConn
readErr error
Expand All @@ -191,8 +207,12 @@ func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool)
return nil, false
}

func (w *syscallPacketReadWaiter) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) {
if w.readFunc == nil {
func (w *syscallPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
w.readErr = nil
w.readFrom = M.Socksaddr{}
if newBuffer == nil {
w.readFunc = nil
} else {
w.readFunc = func(fd uintptr) (done bool) {
buffer := newBuffer()
var readN int
Expand Down Expand Up @@ -221,6 +241,12 @@ func (w *syscallPacketReadWaiter) WaitReadPacket(newBuffer func() *buf.Buffer) (
return true
}
}
}

func (w *syscallPacketReadWaiter) WaitReadPacket() (destination M.Socksaddr, err error) {
if w.readFunc == nil {
return M.Socksaddr{}, os.ErrInvalid
}
err = w.rawConn.Read(w.readFunc)
if err != nil {
return
Expand Down
6 changes: 4 additions & 2 deletions common/network/direct.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,17 @@ import (
)

type ReadWaiter interface {
WaitReadBuffer(newBuffer func() *buf.Buffer) error
InitializeReadWaiter(newBuffer func() *buf.Buffer)
WaitReadBuffer() error
}

type ReadWaitCreator interface {
CreateReadWaiter() (ReadWaiter, bool)
}

type PacketReadWaiter interface {
WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error)
InitializeReadWaiter(newBuffer func() *buf.Buffer)
WaitReadPacket() (destination M.Socksaddr, err error)
}

type PacketReadWaitCreator interface {
Expand Down

0 comments on commit b671451

Please sign in to comment.