@@ -24,10 +24,10 @@ type deflaterPool struct {
2424 pool []* deflater
2525}
2626
27- func (c * deflaterPool ) initialize (options PermessageDeflate ) * deflaterPool {
27+ func (c * deflaterPool ) initialize (options PermessageDeflate , limit int ) * deflaterPool {
2828 c .num = uint64 (options .PoolSize )
2929 for i := uint64 (0 ); i < c .num ; i ++ {
30- c .pool = append (c .pool , new (deflater ).initialize (true , options ))
30+ c .pool = append (c .pool , new (deflater ).initialize (true , options , limit ))
3131 }
3232 return c
3333}
@@ -39,15 +39,19 @@ func (c *deflaterPool) Select() *deflater {
3939
4040type deflater struct {
4141 dpsLocker sync.Mutex
42+ buf []byte
43+ limit int
4244 dpsBuffer * bytes.Buffer
4345 dpsReader io.ReadCloser
4446 cpsLocker sync.Mutex
4547 cpsWriter * flate.Writer
4648}
4749
48- func (c * deflater ) initialize (isServer bool , options PermessageDeflate ) * deflater {
50+ func (c * deflater ) initialize (isServer bool , options PermessageDeflate , limit int ) * deflater {
4951 c .dpsReader = flate .NewReader (nil )
5052 c .dpsBuffer = bytes .NewBuffer (nil )
53+ c .buf = make ([]byte , 32 * 1024 )
54+ c .limit = limit
5155 windowBits := internal .SelectValue (isServer , options .ServerMaxWindowBits , options .ClientMaxWindowBits )
5256 if windowBits == 15 {
5357 c .cpsWriter , _ = flate .NewWriter (nil , options .Level )
@@ -73,7 +77,8 @@ func (c *deflater) Decompress(src *bytes.Buffer, dict []byte) (*bytes.Buffer, er
7377
7478 _ , _ = src .Write (flateTail )
7579 c .resetFR (src , dict )
76- if _ , err := c .dpsReader .(io.WriterTo ).WriteTo (c .dpsBuffer ); err != nil {
80+ reader := limitReader (c .dpsReader , c .limit )
81+ if _ , err := io .CopyBuffer (c .dpsBuffer , reader , c .buf ); err != nil {
7782 return nil , err
7883 }
7984 var dst = binaryPool .Get (c .dpsBuffer .Len ())
@@ -223,3 +228,20 @@ func permessageNegotiation(str string) PermessageDeflate {
223228 options .ServerMaxWindowBits = internal .SelectValue (options .ServerMaxWindowBits < 8 , 8 , options .ServerMaxWindowBits )
224229 return options
225230}
231+
232+ func limitReader (r io.Reader , limit int ) io.Reader { return & limitedReader {R : r , M : limit } }
233+
234+ type limitedReader struct {
235+ R io.Reader
236+ N int
237+ M int
238+ }
239+
240+ func (c * limitedReader ) Read (p []byte ) (n int , err error ) {
241+ n , err = c .R .Read (p )
242+ c .N += n
243+ if c .N > c .M {
244+ return n , internal .CloseMessageTooLarge
245+ }
246+ return
247+ }
0 commit comments