Skip to content

Commit bf09c37

Browse files
authored
Merge pull request #86 from lxzan/dev
Fix: ReadMaxPayloadSize Limit
2 parents be5b1fd + e743e93 commit bf09c37

6 files changed

Lines changed: 62 additions & 9 deletions

File tree

benchmark_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ func BenchmarkConn_ReadMessage(b *testing.B) {
9797
config: config,
9898
deflater: new(deflater),
9999
}
100-
conn1.deflater.initialize(false, conn1.pd)
100+
conn1.deflater.initialize(false, conn1.pd, config.ReadMaxPayloadSize)
101101
var buf, _ = conn1.genFrame(OpcodeText, internal.Bytes(githubData), false)
102102

103103
var reader = bytes.NewBuffer(buf.Bytes())

client.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ func (c *connector) handshake() (*Conn, *http.Response, error) {
175175
readQueue: make(channel, c.option.ParallelGolimit),
176176
}
177177
if pd.Enabled {
178-
socket.deflater.initialize(false, pd)
178+
socket.deflater.initialize(false, pd, c.option.ReadMaxPayloadSize)
179179
if pd.ServerContextTakeover {
180180
socket.dpsWindow.initialize(nil, pd.ServerMaxWindowBits)
181181
}

compress.go

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ type deflaterPool struct {
2424
pool []*deflater
2525
}
2626

27-
func (c *deflaterPool) initialize(options PermessageDeflate) *deflaterPool {
27+
func (c *deflaterPool) initialize(options PermessageDeflate, limit int) *deflaterPool {
2828
c.num = uint64(options.PoolSize)
2929
for i := uint64(0); i < c.num; i++ {
30-
c.pool = append(c.pool, new(deflater).initialize(true, options))
30+
c.pool = append(c.pool, new(deflater).initialize(true, options, limit))
3131
}
3232
return c
3333
}
@@ -39,15 +39,19 @@ func (c *deflaterPool) Select() *deflater {
3939

4040
type deflater struct {
4141
dpsLocker sync.Mutex
42+
buf []byte
43+
limit int
4244
dpsBuffer *bytes.Buffer
4345
dpsReader io.ReadCloser
4446
cpsLocker sync.Mutex
4547
cpsWriter *flate.Writer
4648
}
4749

48-
func (c *deflater) initialize(isServer bool, options PermessageDeflate) *deflater {
50+
func (c *deflater) initialize(isServer bool, options PermessageDeflate, limit int) *deflater {
4951
c.dpsReader = flate.NewReader(nil)
5052
c.dpsBuffer = bytes.NewBuffer(nil)
53+
c.buf = make([]byte, 32*1024)
54+
c.limit = limit
5155
windowBits := internal.SelectValue(isServer, options.ServerMaxWindowBits, options.ClientMaxWindowBits)
5256
if windowBits == 15 {
5357
c.cpsWriter, _ = flate.NewWriter(nil, options.Level)
@@ -73,7 +77,8 @@ func (c *deflater) Decompress(src *bytes.Buffer, dict []byte) (*bytes.Buffer, er
7377

7478
_, _ = src.Write(flateTail)
7579
c.resetFR(src, dict)
76-
if _, err := c.dpsReader.(io.WriterTo).WriteTo(c.dpsBuffer); err != nil {
80+
reader := limitReader(c.dpsReader, c.limit)
81+
if _, err := io.CopyBuffer(c.dpsBuffer, reader, c.buf); err != nil {
7782
return nil, err
7883
}
7984
var dst = binaryPool.Get(c.dpsBuffer.Len())
@@ -223,3 +228,20 @@ func permessageNegotiation(str string) PermessageDeflate {
223228
options.ServerMaxWindowBits = internal.SelectValue(options.ServerMaxWindowBits < 8, 8, options.ServerMaxWindowBits)
224229
return options
225230
}
231+
232+
func limitReader(r io.Reader, limit int) io.Reader { return &limitedReader{R: r, M: limit} }
233+
234+
type limitedReader struct {
235+
R io.Reader
236+
N int
237+
M int
238+
}
239+
240+
func (c *limitedReader) Read(p []byte) (n int, err error) {
241+
n, err = c.R.Read(p)
242+
c.N += n
243+
if c.N > c.M {
244+
return n, internal.CloseMessageTooLarge
245+
}
246+
return
247+
}

task_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,15 @@ func serveWebSocket(
4040
}
4141
if compressEnabled {
4242
if isServer {
43-
socket.deflater = new(deflaterPool).initialize(pd).Select()
43+
socket.deflater = new(deflaterPool).initialize(pd, config.ReadMaxPayloadSize).Select()
4444
if pd.ServerContextTakeover {
4545
socket.cpsWindow.initialize(config.cswPool, pd.ServerMaxWindowBits)
4646
}
4747
if pd.ClientContextTakeover {
4848
socket.dpsWindow.initialize(config.dswPool, pd.ClientMaxWindowBits)
4949
}
5050
} else {
51-
socket.deflater = new(deflater).initialize(false, pd)
51+
socket.deflater = new(deflater).initialize(false, pd, config.ReadMaxPayloadSize)
5252
}
5353
}
5454
return socket

upgrader.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ func NewUpgrader(eventHandler Event, option *ServerOption) *Upgrader {
8484
deflaterPool: new(deflaterPool),
8585
}
8686
if u.option.PermessageDeflate.Enabled {
87-
u.deflaterPool.initialize(u.option.PermessageDeflate)
87+
u.deflaterPool.initialize(u.option.PermessageDeflate, option.ReadMaxPayloadSize)
8888
}
8989
return u
9090
}

writer_test.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package gws
33
import (
44
"bufio"
55
"bytes"
6+
"errors"
67
"io"
78
"net"
89
"net/http"
@@ -79,6 +80,36 @@ func TestWriteBigMessage(t *testing.T) {
7980
var err = server.WriteMessage(OpcodeText, internal.AlphabetNumeric.Generate(128))
8081
assert.Error(t, err)
8182
})
83+
84+
t.Run("", func(t *testing.T) {
85+
var wg = &sync.WaitGroup{}
86+
wg.Add(1)
87+
var serverHandler = new(webSocketMocker)
88+
var clientHandler = new(webSocketMocker)
89+
serverHandler.onClose = func(socket *Conn, err error) {
90+
assert.True(t, errors.Is(err, internal.CloseMessageTooLarge))
91+
wg.Done()
92+
}
93+
var serverOption = &ServerOption{
94+
ReadMaxPayloadSize: 128,
95+
PermessageDeflate: PermessageDeflate{Enabled: true, Threshold: 1},
96+
}
97+
var clientOption = &ClientOption{
98+
ReadMaxPayloadSize: 128 * 1024,
99+
PermessageDeflate: PermessageDeflate{Enabled: true, Threshold: 1},
100+
}
101+
server, client := newPeer(serverHandler, serverOption, clientHandler, clientOption)
102+
go server.ReadLoop()
103+
go client.ReadLoop()
104+
105+
var buf = bytes.NewBufferString("")
106+
for i := 0; i < 64*1024; i++ {
107+
buf.WriteString("a")
108+
}
109+
var err = client.WriteMessage(OpcodeText, buf.Bytes())
110+
assert.NoError(t, err)
111+
wg.Wait()
112+
})
82113
}
83114

84115
func TestWriteClose(t *testing.T) {

0 commit comments

Comments
 (0)