diff --git a/Makefile b/Makefile index cb7f9a4..f84e2fd 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -VERSION := v0.2.3 +VERSION := v0.2.4 BIN_NAME = grpcr CONTAINER = grpcr diff --git a/http2/processor.go b/http2/processor.go index 445840d..4a9caef 100644 --- a/http2/processor.go +++ b/http2/processor.go @@ -12,6 +12,7 @@ import ( "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/types/descriptorpb" "google.golang.org/protobuf/types/dynamicpb" + "math" "github.com/vearne/grpcreplay/protocol" slog "github.com/vearne/simplelog" @@ -49,6 +50,8 @@ func (p *Processor) ProcessTCPPkg() { } hc := p.ConnRepository[dc] + payloadSize := uint32(len(payload)) + // SYN/ACK/FIN if len(payload) <= 0 { if pkg.TCP.FIN { @@ -56,16 +59,14 @@ func (p *Processor) ProcessTCPPkg() { hc.TCPBuffer.Close() delete(p.ConnRepository, dc) } else { - hc.TCPBuffer.expectedSeq = int64(pkg.TCP.Seq) + int64(len(pkg.TCP.Payload)) - hc.TCPBuffer.leftPointer = hc.TCPBuffer.expectedSeq + hc.TCPBuffer.expectedSeq = (pkg.TCP.Seq + payloadSize) % math.MaxUint32 } continue } // connection preface if IsConnPreface(payload) { - hc.TCPBuffer.expectedSeq = int64(pkg.TCP.Seq) + int64(len(pkg.TCP.Payload)) - hc.TCPBuffer.leftPointer = hc.TCPBuffer.expectedSeq + hc.TCPBuffer.expectedSeq = (pkg.TCP.Seq + payloadSize) % math.MaxUint32 continue } diff --git a/http2/tcp_buffer.go b/http2/tcp_buffer.go index 24cb644..9dda887 100644 --- a/http2/tcp_buffer.go +++ b/http2/tcp_buffer.go @@ -1,23 +1,20 @@ package http2 import ( - "bytes" "github.com/google/gopacket/layers" "github.com/huandu/skiplist" slog "github.com/vearne/simplelog" "math" "net" + "sync/atomic" ) type TCPBuffer struct { //The number of bytes of data currently cached - size uint32 - actualCanReadSize uint32 + size atomic.Int64 + actualCanReadSize atomic.Int64 List *skiplist.SkipList - expectedSeq int64 - // The sliding window contains the leftPointer - leftPointer int64 - + expectedSeq uint32 //There is at most one reader to read dataChannel chan []byte closeChan chan struct{} @@ -26,11 +23,10 @@ type TCPBuffer struct { func NewTCPBuffer() *TCPBuffer { var sb TCPBuffer sb.List = skiplist.New(skiplist.Uint32) - sb.size = 0 - sb.actualCanReadSize = 0 - sb.expectedSeq = -1 - sb.leftPointer = -1 - sb.dataChannel = make(chan []byte, 10) + sb.size.Store(0) + sb.actualCanReadSize.Store(0) + sb.expectedSeq = 0 + sb.dataChannel = make(chan []byte, 100) sb.closeChan = make(chan struct{}) return &sb } @@ -47,72 +43,41 @@ func (sb *TCPBuffer) Read(p []byte) (n int, err error) { err = net.ErrClosed case data = <-sb.dataChannel: n = copy(p, data) + dataSize := int64(len(data)) + sb.size.Add(dataSize * -1) + sb.actualCanReadSize.Add(dataSize * -1) } slog.Debug("SocketBuffer.Read, got:%v bytes", n) return n, err } func (sb *TCPBuffer) AddTCP(tcpPkg *layers.TCP) { - sb.addTCP(tcpPkg) - - if sb.actualCanReadSize > 0 { - slog.Debug("SocketBuffer.AddTCP, satisfy the conditions, size:%v, actualCanReadSize:%v, expectedSeq:%v", - sb.size, sb.actualCanReadSize, sb.expectedSeq) - data := sb.getData() - slog.Debug("push to channel: %v bytes", len(data)) - sb.dataChannel <- data - } -} - -func (sb *TCPBuffer) addTCP(tcpPkg *layers.TCP) { slog.Debug("[start]SocketBuffer.addTCP, size:%v, actualCanReadSize:%v, expectedSeq:%v", - sb.size, sb.actualCanReadSize, sb.expectedSeq) + sb.size.Load(), sb.actualCanReadSize.Load(), sb.expectedSeq) // duplicate package - if int64(tcpPkg.Seq) < sb.leftPointer || sb.List.Get(tcpPkg.Seq) != nil { + if sb.List.Get(tcpPkg.Seq) != nil { slog.Debug("[end]SocketBuffer.addTCP-duplicate package, size:%v, actualCanReadSize:%v, expectedSeq:%v", - sb.size, sb.actualCanReadSize, sb.expectedSeq) + sb.size.Load(), sb.actualCanReadSize.Load(), sb.expectedSeq) return } ele := sb.List.Set(tcpPkg.Seq, tcpPkg) - sb.size += uint32(len(tcpPkg.Payload)) + sb.size.Add(int64(len(tcpPkg.Payload))) + needRemoveList := make([]*skiplist.Element, 0) - for ele != nil && sb.expectedSeq == int64(tcpPkg.Seq) { + for ele != nil && sb.expectedSeq == tcpPkg.Seq { // expect next sequence number - sb.expectedSeq = int64((tcpPkg.Seq + uint32(len(tcpPkg.Payload))) % math.MaxUint32) - sb.actualCanReadSize += uint32(len(tcpPkg.Payload)) + // sequence numbers may wrap around + payloadSize := uint32(len(tcpPkg.Payload)) + sb.actualCanReadSize.Add(int64(payloadSize)) + sb.expectedSeq = (tcpPkg.Seq + payloadSize) % math.MaxUint32 - ele = ele.Next() - if ele != nil { - tcpPkg = ele.Value.(*layers.TCP) - } - } - slog.Debug("[end]SocketBuffer.addTCP, size:%v, actualCanReadSize:%v, expectedSeq:%v", - sb.size, sb.actualCanReadSize, sb.expectedSeq) -} - -func (sb *TCPBuffer) getData() []byte { - slog.Debug("[start]SocketBuffer.getData, size:%v, actualCanReadSize:%v, expectedSeq:%v", - sb.size, sb.actualCanReadSize, sb.expectedSeq) - - var tcpPkg *layers.TCP - buf := bytes.NewBuffer([]byte{}) - ele := sb.List.Front() - if ele != nil { - tcpPkg = ele.Value.(*layers.TCP) - } - - needRemoveList := make([]*skiplist.Element, 0) - for ele != nil && int64(tcpPkg.Seq) <= sb.expectedSeq { - sb.actualCanReadSize -= uint32(len(tcpPkg.Payload)) - sb.size -= uint32(len(tcpPkg.Payload)) - sb.leftPointer += int64(len(tcpPkg.Payload)) - - buf.Write(tcpPkg.Payload) + // push to channel + sb.dataChannel <- tcpPkg.Payload needRemoveList = append(needRemoveList, ele) - ele = ele.Next() + ele = sb.List.Get(sb.expectedSeq) if ele != nil { tcpPkg = ele.Value.(*layers.TCP) } @@ -123,7 +88,6 @@ func (sb *TCPBuffer) getData() []byte { sb.List.RemoveElement(element) } - slog.Debug("[end]SocketBuffer.getData, size:%v, actualCanReadSize:%v, expectedSeq:%v, data: %v bytes", - sb.size, sb.actualCanReadSize, sb.expectedSeq, buf.Len()) - return buf.Bytes() + slog.Debug("[end]SocketBuffer.addTCP, size:%v, actualCanReadSize:%v, expectedSeq:%v", + sb.size.Load(), sb.actualCanReadSize.Load(), sb.expectedSeq) } diff --git a/http2/tcp_buffer_test.go b/http2/tcp_buffer_test.go index d7f9cd1..f8896c7 100644 --- a/http2/tcp_buffer_test.go +++ b/http2/tcp_buffer_test.go @@ -12,7 +12,6 @@ func TestSocketBufferSequence1(t *testing.T) { slog.SetLevel(slog.DebugLevel) buffer := NewTCPBuffer() buffer.expectedSeq = 1000 - buffer.leftPointer = 1000 var tcpPkgA layers.TCP tcpPkgA.Seq = 1000 @@ -44,7 +43,6 @@ func TestSocketBufferSequence2(t *testing.T) { slog.SetLevel(slog.DebugLevel) buffer := NewTCPBuffer() buffer.expectedSeq = 1000 - buffer.leftPointer = 1000 var tcpPkgA layers.TCP tcpPkgA.Seq = 1000 @@ -80,7 +78,6 @@ func TestSocketBufferSequence3(t *testing.T) { slog.SetLevel(slog.DebugLevel) buffer := NewTCPBuffer() buffer.expectedSeq = 1000 - buffer.leftPointer = 1000 var tcpPkgA layers.TCP tcpPkgA.Seq = 1000 @@ -112,22 +109,21 @@ func TestSocketBufferSequence3(t *testing.T) { assert.Nil(t, err) } -func TestSocketBufferDuplicate(t *testing.T) { +func TestSocketBufferWrapAround1(t *testing.T) { slog.SetLevel(slog.DebugLevel) buffer := NewTCPBuffer() - buffer.expectedSeq = 1000 - buffer.leftPointer = 1000 + buffer.expectedSeq = 4294967290 var tcpPkgA layers.TCP - tcpPkgA.Seq = 1000 + tcpPkgA.Seq = 4294967290 tcpPkgA.Payload = []byte("aaaaaaaaaa") var tcpPkgB layers.TCP - tcpPkgB.Seq = 1010 + tcpPkgB.Seq = 4 tcpPkgB.Payload = []byte("bbbbbbbbbb") var tcpPkgC layers.TCP - tcpPkgC.Seq = 1020 + tcpPkgC.Seq = 14 tcpPkgC.Payload = []byte("cccccccccc") buffer.AddTCP(&tcpPkgA) @@ -143,3 +139,64 @@ func TestSocketBufferDuplicate(t *testing.T) { // assert for nil (good for errors) assert.Nil(t, err) } + +func TestSocketBufferWrapAround2(t *testing.T) { + slog.SetLevel(slog.DebugLevel) + buffer := NewTCPBuffer() + buffer.expectedSeq = 4294967290 + + var tcpPkgA layers.TCP + tcpPkgA.Seq = 4294967290 + tcpPkgA.Payload = []byte("aaaaaaaaaa") + + var tcpPkgB layers.TCP + tcpPkgB.Seq = 4 + tcpPkgB.Payload = []byte("bbbbbbbbbb") + + var tcpPkgC layers.TCP + tcpPkgC.Seq = 14 + tcpPkgC.Payload = []byte("cccccccccc") + + buffer.AddTCP(&tcpPkgB) + buffer.AddTCP(&tcpPkgA) + buffer.AddTCP(&tcpPkgC) + buffer.AddTCP(&tcpPkgA) + + buf := make([]byte, 1024) + n, err := io.ReadAtLeast(buffer, buf, 30) + // assert equality + assert.Equal(t, 30, n, "read data") + assert.Equal(t, "aaaaaaaaaabbbbbbbbbbcccccccccc", string(buf[0:n]), "read data") + // assert for nil (good for errors) + assert.Nil(t, err) +} + +func TestSocketBufferWrapAround3(t *testing.T) { + slog.SetLevel(slog.DebugLevel) + buffer := NewTCPBuffer() + buffer.expectedSeq = 4294967290 + + var tcpPkgA layers.TCP + tcpPkgA.Seq = 4294967290 + tcpPkgA.Payload = []byte("aaaaaaaaaa") + + var tcpPkgB layers.TCP + tcpPkgB.Seq = 4 + tcpPkgB.Payload = []byte("bbbbbbbbbb") + + var tcpPkgC layers.TCP + tcpPkgC.Seq = 14 + tcpPkgC.Payload = []byte("cccccccccc") + + buffer.AddTCP(&tcpPkgA) + buffer.AddTCP(&tcpPkgB) + buffer.AddTCP(&tcpPkgC) + + buf := make([]byte, 1024) + n, err := io.ReadAtLeast(buffer, buf, 30) + // assert equality + assert.Equal(t, 30, n, "read data") + assert.Equal(t, "aaaaaaaaaabbbbbbbbbbcccccccccc", string(buf[0:n]), "read data") + // assert for nil (good for errors) + assert.Nil(t, err) +}