Skip to content

Commit 3f89621

Browse files
committed
some refactoring
1 parent 700db22 commit 3f89621

File tree

3 files changed

+36
-41
lines changed

3 files changed

+36
-41
lines changed

compress.go

+21-28
Original file line numberDiff line numberDiff line change
@@ -98,24 +98,25 @@ func newCompIO(mc *mysqlConn) *compIO {
9898

9999
func (c *compIO) readNext(need int) ([]byte, error) {
100100
for c.buff.Len() < need {
101-
if err := c.uncompressPacket(); err != nil {
101+
if err := c.readCompressedPacket(); err != nil {
102102
return nil, err
103103
}
104104
}
105105
data := c.buff.Next(need)
106106
return data[:need:need], nil // prevent caller writes into c.buff
107107
}
108108

109-
func (c *compIO) uncompressPacket() error {
109+
func (c *compIO) readCompressedPacket() error {
110110
header, err := c.mc.buf.readNext(7) // size of compressed header
111111
if err != nil {
112112
return err
113113
}
114+
_ = header[6] // bounds check hint to compiler; guaranteed by readNext
114115

115116
// compressed header structure
116-
comprLength := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
117-
uncompressedLength := int(uint32(header[4]) | uint32(header[5])<<8 | uint32(header[6])<<16)
117+
comprLength := getUint24(header[0:3])
118118
compressionSequence := uint8(header[3])
119+
uncompressedLength := getUint24(header[4:7])
119120
if debugTrace {
120121
fmt.Printf("uncompress cmplen=%v uncomplen=%v pkt_cmp_seq=%v expected_cmp_seq=%v\n",
121122
comprLength, uncompressedLength, compressionSequence, c.mc.sequence)
@@ -171,43 +172,42 @@ func (c *compIO) writePackets(packets []byte) (int, error) {
171172
buf := &c.buff
172173

173174
for dataLen > 0 {
174-
buf.Reset()
175-
payloadLen := dataLen
176-
if payloadLen > maxPayloadLen {
177-
payloadLen = maxPayloadLen
178-
}
175+
payloadLen := min(maxPayloadLen, dataLen)
179176
payload := packets[:payloadLen]
180177
uncompressedLen := payloadLen
181178

182-
if _, err := buf.Write(blankHeader); err != nil {
183-
return 0, err
184-
}
179+
buf.Reset()
180+
buf.Write(blankHeader) // Buffer.Write() never returns error
185181

186182
// If payload is less than minCompressLength, don't compress.
187183
if uncompressedLen < minCompressLength {
188-
if _, err := buf.Write(payload); err != nil {
189-
return 0, err
190-
}
184+
buf.Write(payload)
191185
uncompressedLen = 0
192186
} else {
193187
zCompress(payload, buf)
188+
// do not compress if compressed data is larger than uncompressed data
189+
// I intentionally miss 7 byte header in the buf; compress should compress more than 7 bytes.
190+
if buf.Len() > uncompressedLen {
191+
buf.Reset()
192+
buf.Write(blankHeader)
193+
buf.Write(payload)
194+
uncompressedLen = 0
195+
}
194196
}
195197

196-
if err := c.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil {
198+
if err := c.mc.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil {
197199
return 0, err
198200
}
199201
dataLen -= payloadLen
200202
packets = packets[payloadLen:]
201-
buf.Reset()
202203
}
203204

204205
return totalBytes, nil
205206
}
206207

207208
// writeCompressedPacket writes a compressed packet with header.
208209
// data should start with 7 size space for header followed by payload.
209-
func (c *compIO) writeCompressedPacket(data []byte, uncompressedLen int) error {
210-
mc := c.mc
210+
func (mc *mysqlConn) writeCompressedPacket(data []byte, uncompressedLen int) error {
211211
comprLength := len(data) - 7
212212
if debugTrace {
213213
fmt.Printf(
@@ -216,16 +216,9 @@ func (c *compIO) writeCompressedPacket(data []byte, uncompressedLen int) error {
216216
}
217217

218218
// compression header
219-
data[0] = byte(0xff & comprLength)
220-
data[1] = byte(0xff & (comprLength >> 8))
221-
data[2] = byte(0xff & (comprLength >> 16))
222-
219+
putUint24(data[0:3], comprLength)
223220
data[3] = mc.compressSequence
224-
225-
// this value is never greater than maxPayloadLength
226-
data[4] = byte(0xff & uncompressedLen)
227-
data[5] = byte(0xff & (uncompressedLen >> 8))
228-
data[6] = byte(0xff & (uncompressedLen >> 16))
221+
putUint24(data[4:7], uncompressedLen)
229222

230223
if _, err := mc.netConn.Write(data); err != nil {
231224
mc.log("writing compressed packet:", err)

packets.go

+3-13
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
4343
}
4444

4545
// packet length [24 bit]
46-
pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16)
46+
pktLen := getUint24(data[:3])
4747
seqNr := data[3]
4848

4949
if mc.compress {
@@ -117,18 +117,8 @@ func (mc *mysqlConn) writePacket(data []byte) error {
117117
}
118118

119119
for {
120-
var size int
121-
if pktLen >= maxPacketSize {
122-
data[0] = 0xff
123-
data[1] = 0xff
124-
data[2] = 0xff
125-
size = maxPacketSize
126-
} else {
127-
data[0] = byte(pktLen)
128-
data[1] = byte(pktLen >> 8)
129-
data[2] = byte(pktLen >> 16)
130-
size = pktLen
131-
}
120+
size := min(maxPacketSize, pktLen)
121+
putUint24(data[:3], size)
132122
data[3] = mc.sequence
133123

134124
// Write packet

utils.go

+12
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,18 @@ func formatBinaryTime(src []byte, length uint8) (driver.Value, error) {
490490
* Convert from and to bytes *
491491
******************************************************************************/
492492

493+
// 24bit integer: used for packet headers.
494+
495+
func putUint24(data []byte, n int) {
496+
data[2] = byte(n >> 16)
497+
data[1] = byte(n >> 8)
498+
data[0] = byte(n)
499+
}
500+
501+
func getUint24(data []byte) int {
502+
return int(data[2])<<16 | int(data[1])<<8 | int(data[0])
503+
}
504+
493505
func uint64ToBytes(n uint64) []byte {
494506
return []byte{
495507
byte(n),

0 commit comments

Comments
 (0)