diff --git a/close.go b/close.go index ff2e878a..adf42c52 100644 --- a/close.go +++ b/close.go @@ -214,7 +214,7 @@ func (c *Conn) waitCloseHandshake() error { } for { - h, err := c.readLoop(ctx) + h, err := c.readLoop(ctx, true) if err != nil { return err } diff --git a/read.go b/read.go index e2699da5..f84b0fec 100644 --- a/read.go +++ b/read.go @@ -180,7 +180,8 @@ func (c *Conn) readRSV1Illegal(h header) bool { return false } -func (c *Conn) readLoop(ctx context.Context) (header, error) { +func (c *Conn) readLoop(ctx context.Context, ignoreControl bool) (header, error) { + for { h, err := c.readFrameHeader(ctx) if err != nil { @@ -199,6 +200,9 @@ func (c *Conn) readLoop(ctx context.Context) (header, error) { switch h.opcode { case opClose, opPing, opPong: + if ignoreControl { + return h, nil + } err = c.handleControl(ctx, h) if err != nil { // Pass through CloseErrors when receiving a close frame. @@ -344,7 +348,7 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro return 0, nil, errors.New("previous message not read to completion") } - h, err := c.readLoop(ctx) + h, err := c.readLoop(ctx, false) if err != nil { return 0, nil, err } @@ -429,7 +433,7 @@ func (mr *msgReader) read(p []byte) (int, error) { return 0, io.EOF } - h, err := mr.c.readLoop(mr.ctx) + h, err := mr.c.readLoop(mr.ctx, false) if err != nil { return 0, err }