@@ -3,6 +3,7 @@ package packet
33import (
44"bufio"
55"bytes"
6+ "compress/zlib"
67"crypto/rand"
78"crypto/rsa"
89"crypto/sha1"
@@ -12,6 +13,7 @@ import (
1213"net"
1314"sync"
1415
16+ "github.com/DataDog/zstd"
1517. "github.com/go-mysql-org/go-mysql/mysql"
1618"github.com/go-mysql-org/go-mysql/utils"
1719"github.com/pingcap/errors"
@@ -56,6 +58,16 @@ type Conn struct {
5658header [4 ]byte
5759
5860Sequence uint8
61+
62+ Compression uint8
63+
64+ CompressedSequence uint8
65+
66+ compressedHeader [7 ]byte
67+
68+ compressedReaderActive bool
69+
70+ compressedReader io.Reader
5971}
6072
6173func NewConn (conn net.Conn ) * Conn {
@@ -94,8 +106,43 @@ func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) {
94106utils .BytesBufferPut (buf )
95107}()
96108
97- if err := c .ReadPacketTo (buf ); err != nil {
98- return nil , errors .Trace (err )
109+ if c .Compression != MYSQL_COMPRESS_NONE {
110+ if ! c .compressedReaderActive {
111+ if _ , err := io .ReadFull (c .reader , c .compressedHeader [:7 ]); err != nil {
112+ return nil , errors .Wrapf (ErrBadConn , "io.ReadFull(compressedHeader) failed. err %v" , err )
113+ }
114+
115+ compressedSequence := c .compressedHeader [3 ]
116+ uncompressedLength := int (uint32 (c .compressedHeader [4 ]) | uint32 (c .compressedHeader [5 ])<< 8 | uint32 (c .compressedHeader [6 ])<< 16 )
117+ if compressedSequence != c .CompressedSequence {
118+ return nil , errors .Errorf ("invalid compressed sequence %d != %d" ,
119+ compressedSequence , c .CompressedSequence )
120+ }
121+
122+ if uncompressedLength > 0 {
123+ var err error
124+ switch c .Compression {
125+ case MYSQL_COMPRESS_ZLIB :
126+ c .compressedReader , err = zlib .NewReader (c .reader )
127+ case MYSQL_COMPRESS_ZSTD :
128+ c .compressedReader = zstd .NewReader (c .reader )
129+ }
130+ if err != nil {
131+ return nil , err
132+ }
133+ }
134+ c .compressedReaderActive = true
135+ }
136+ }
137+
138+ if c .compressedReader != nil {
139+ if err := c .ReadPacketTo (buf , c .compressedReader ); err != nil {
140+ return nil , errors .Trace (err )
141+ }
142+ } else {
143+ if err := c .ReadPacketTo (buf , c .reader ); err != nil {
144+ return nil , errors .Trace (err )
145+ }
99146}
100147
101148readBytes := buf .Bytes ()
@@ -145,8 +192,8 @@ func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err
145192return written , nil
146193}
147194
148- func (c * Conn ) ReadPacketTo (w io.Writer ) error {
149- if _ , err := io .ReadFull (c . reader , c .header [:4 ]); err != nil {
195+ func (c * Conn ) ReadPacketTo (w io.Writer , r io. Reader ) error {
196+ if _ , err := io .ReadFull (r , c .header [:4 ]); err != nil {
150197return errors .Wrapf (ErrBadConn , "io.ReadFull(header) failed. err %v" , err )
151198}
152199
@@ -164,7 +211,7 @@ func (c *Conn) ReadPacketTo(w io.Writer) error {
164211buf .Grow (length )
165212}
166213
167- if n , err := c .copyN (w , c . reader , int64 (length )); err != nil {
214+ if n , err := c .copyN (w , r , int64 (length )); err != nil {
168215return errors .Wrapf (ErrBadConn , "io.CopyN failed. err %v, copied %v, expected %v" , err , n , length )
169216} else if n != int64 (length ) {
170217return errors .Wrapf (ErrBadConn , "io.CopyN failed(n != int64(length)). %v bytes copied, while %v expected" , n , length )
@@ -173,7 +220,7 @@ func (c *Conn) ReadPacketTo(w io.Writer) error {
173220return nil
174221}
175222
176- if err : = c .ReadPacketTo (w ); err != nil {
223+ if err = c .ReadPacketTo (w , r ); err != nil {
177224return errors .Wrap (err , "ReadPacketTo failed" )
178225}
179226}
@@ -209,14 +256,95 @@ func (c *Conn) WritePacket(data []byte) error {
209256data [2 ] = byte (length >> 16 )
210257data [3 ] = c .Sequence
211258
212- if n , err := c .Write (data ); err != nil {
213- return errors .Wrapf (ErrBadConn , "Write failed. err %v" , err )
214- } else if n != len (data ) {
215- return errors .Wrapf (ErrBadConn , "Write failed. only %v bytes written, while %v expected" , n , len (data ))
259+ switch c .Compression {
260+ case MYSQL_COMPRESS_NONE :
261+ if n , err := c .Write (data ); err != nil {
262+ return errors .Wrapf (ErrBadConn , "Write failed. err %v" , err )
263+ } else if n != len (data ) {
264+ return errors .Wrapf (ErrBadConn , "Write failed. only %v bytes written, while %v expected" , n , len (data ))
265+ }
266+ case MYSQL_COMPRESS_ZLIB :
267+ fallthrough
268+ case MYSQL_COMPRESS_ZSTD :
269+ if n , err := c .writeCompressed (data ); err != nil {
270+ return errors .Wrapf (ErrBadConn , "Write failed. err %v" , err )
271+ } else if n != len (data ) {
272+ return errors .Wrapf (ErrBadConn , "Write failed. only %v bytes written, while %v expected" , n , len (data ))
273+ }
274+ c .compressedReader = nil
275+ c .compressedReaderActive = false
276+ default :
277+ return errors .Wrapf (ErrBadConn , "Write failed. Unsuppored compression algorithm set" )
278+ }
279+
280+ c .Sequence ++
281+ return nil
282+ }
283+
284+ func (c * Conn ) writeCompressed (data []byte ) (n int , err error ) {
285+ var compressedLength , uncompressedLength int
286+ var payload , compressedPacket bytes.Buffer
287+ var w io.WriteCloser
288+ minCompressLength := 50
289+ compressedHeader := make ([]byte , 7 )
290+
291+ switch c .Compression {
292+ case MYSQL_COMPRESS_ZLIB :
293+ w , err = zlib .NewWriterLevel (& payload , zlib .HuffmanOnly )
294+ case MYSQL_COMPRESS_ZSTD :
295+ w = zstd .NewWriter (& payload )
296+ }
297+ if err != nil {
298+ return 0 , err
299+ }
300+
301+ if len (data ) > minCompressLength {
302+ uncompressedLength = len (data )
303+ n , err = w .Write (data )
304+ if err != nil {
305+ return 0 , err
306+ }
307+ err = w .Close ()
308+ if err != nil {
309+ return 0 , err
310+ }
311+ }
312+
313+ if len (data ) > minCompressLength {
314+ compressedLength = len (payload .Bytes ())
315+ } else {
316+ compressedLength = len (data )
317+ }
318+
319+ c .CompressedSequence = 0
320+ compressedHeader [0 ] = byte (compressedLength )
321+ compressedHeader [1 ] = byte (compressedLength >> 8 )
322+ compressedHeader [2 ] = byte (compressedLength >> 16 )
323+ compressedHeader [3 ] = c .CompressedSequence
324+ compressedHeader [4 ] = byte (uncompressedLength )
325+ compressedHeader [5 ] = byte (uncompressedLength >> 8 )
326+ compressedHeader [6 ] = byte (uncompressedLength >> 16 )
327+ _ , err = compressedPacket .Write (compressedHeader )
328+ if err != nil {
329+ return 0 , err
330+ }
331+ c .CompressedSequence ++
332+
333+ if len (data ) > minCompressLength {
334+ _ , err = compressedPacket .Write (payload .Bytes ())
216335} else {
217- c .Sequence ++
218- return nil
336+ n , err = compressedPacket .Write (data )
337+ }
338+ if err != nil {
339+ return 0 , err
219340}
341+
342+ _ , err = c .Write (compressedPacket .Bytes ())
343+ if err != nil {
344+ return 0 , err
345+ }
346+
347+ return n , nil
220348}
221349
222350// WriteClearAuthPacket: Client clear text authentication packet
0 commit comments