@@ -3,6 +3,7 @@ package packet
33import (
44"bufio"
55"bytes"
6+ "compress/zlib"
67"crypto/rand"
78"crypto/rsa"
89"crypto/sha1"
@@ -12,6 +13,7 @@ import (
1213"net"
1314"sync"
1415
16+ "github.com/DataDog/zstd"
1517. "github.com/go-mysql-org/go-mysql/mysql"
1618"github.com/go-mysql-org/go-mysql/utils"
1719"github.com/pingcap/errors"
@@ -56,6 +58,16 @@ type Conn struct {
5658header [4 ]byte
5759
5860Sequence uint8
61+
62+ Compression uint8
63+
64+ CompressedSequence uint8
65+
66+ compressedHeader [7 ]byte
67+
68+ compressedReaderActive bool
69+
70+ compressedReader io.Reader
5971}
6072
6173func NewConn (conn net.Conn ) * Conn {
@@ -94,8 +106,43 @@ func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) {
94106utils .BytesBufferPut (buf )
95107}()
96108
97- if err := c .ReadPacketTo (buf ); err != nil {
98- return nil , errors .Trace (err )
109+ if c .Compression != MYSQL_COMPRESS_NONE {
110+ if ! c .compressedReaderActive {
111+ if _ , err := io .ReadFull (c .reader , c .compressedHeader [:7 ]); err != nil {
112+ return nil , errors .Wrapf (ErrBadConn , "io.ReadFull(compressedHeader) failed. err %v" , err )
113+ }
114+
115+ compressedSequence := c .compressedHeader [3 ]
116+ uncompressedLength := int (uint32 (c .compressedHeader [4 ]) | uint32 (c .compressedHeader [5 ])<< 8 | uint32 (c .compressedHeader [6 ])<< 16 )
117+ if compressedSequence != c .CompressedSequence {
118+ return nil , errors .Errorf ("invalid compressed sequence %d != %d" ,
119+ compressedSequence , c .CompressedSequence )
120+ }
121+
122+ if uncompressedLength > 0 {
123+ var err error
124+ switch c .Compression {
125+ case MYSQL_COMPRESS_ZLIB :
126+ c .compressedReader , err = zlib .NewReader (c .reader )
127+ case MYSQL_COMPRESS_ZSTD :
128+ c .compressedReader = zstd .NewReader (c .reader )
129+ }
130+ if err != nil {
131+ return nil , err
132+ }
133+ }
134+ c .compressedReaderActive = true
135+ }
136+ }
137+
138+ if c .compressedReader != nil {
139+ if err := c .ReadPacketTo (buf , c .compressedReader ); err != nil {
140+ return nil , errors .Trace (err )
141+ }
142+ } else {
143+ if err := c .ReadPacketTo (buf , c .reader ); err != nil {
144+ return nil , errors .Trace (err )
145+ }
99146}
100147
101148readBytes := buf .Bytes ()
@@ -145,8 +192,8 @@ func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err
145192return written , nil
146193}
147194
148- func (c * Conn ) ReadPacketTo (w io.Writer ) error {
149- if _ , err := io .ReadFull (c . reader , c .header [:4 ]); err != nil {
195+ func (c * Conn ) ReadPacketTo (w io.Writer , r io. Reader ) error {
196+ if _ , err := io .ReadFull (r , c .header [:4 ]); err != nil {
150197return errors .Wrapf (ErrBadConn , "io.ReadFull(header) failed. err %v" , err )
151198}
152199
@@ -164,7 +211,7 @@ func (c *Conn) ReadPacketTo(w io.Writer) error {
164211buf .Grow (length )
165212}
166213
167- if n , err := c .copyN (w , c . reader , int64 (length )); err != nil {
214+ if n , err := c .copyN (w , r , int64 (length )); err != nil {
168215return errors .Wrapf (ErrBadConn , "io.CopyN failed. err %v, copied %v, expected %v" , err , n , length )
169216} else if n != int64 (length ) {
170217return errors .Wrapf (ErrBadConn , "io.CopyN failed(n != int64(length)). %v bytes copied, while %v expected" , n , length )
@@ -173,7 +220,7 @@ func (c *Conn) ReadPacketTo(w io.Writer) error {
173220return nil
174221}
175222
176- if err : = c .ReadPacketTo (w ); err != nil {
223+ if err = c .ReadPacketTo (w , r ); err != nil {
177224return errors .Wrap (err , "ReadPacketTo failed" )
178225}
179226}
@@ -209,14 +256,93 @@ func (c *Conn) WritePacket(data []byte) error {
209256data [2 ] = byte (length >> 16 )
210257data [3 ] = c .Sequence
211258
212- if n , err := c .Write (data ); err != nil {
213- return errors .Wrapf (ErrBadConn , "Write failed. err %v" , err )
214- } else if n != len (data ) {
215- return errors .Wrapf (ErrBadConn , "Write failed. only %v bytes written, while %v expected" , n , len (data ))
259+ switch c .Compression {
260+ case MYSQL_COMPRESS_NONE :
261+ if n , err := c .Write (data ); err != nil {
262+ return errors .Wrapf (ErrBadConn , "Write failed. err %v" , err )
263+ } else if n != len (data ) {
264+ return errors .Wrapf (ErrBadConn , "Write failed. only %v bytes written, while %v expected" , n , len (data ))
265+ }
266+ case MYSQL_COMPRESS_ZLIB , MYSQL_COMPRESS_ZSTD :
267+ if n , err := c .writeCompressed (data ); err != nil {
268+ return errors .Wrapf (ErrBadConn , "Write failed. err %v" , err )
269+ } else if n != len (data ) {
270+ return errors .Wrapf (ErrBadConn , "Write failed. only %v bytes written, while %v expected" , n , len (data ))
271+ }
272+ c .compressedReader = nil
273+ c .compressedReaderActive = false
274+ default :
275+ return errors .Wrapf (ErrBadConn , "Write failed. Unsuppored compression algorithm set" )
276+ }
277+
278+ c .Sequence ++
279+ return nil
280+ }
281+
282+ func (c * Conn ) writeCompressed (data []byte ) (n int , err error ) {
283+ var compressedLength , uncompressedLength int
284+ var payload , compressedPacket bytes.Buffer
285+ var w io.WriteCloser
286+ minCompressLength := 50
287+ compressedHeader := make ([]byte , 7 )
288+
289+ switch c .Compression {
290+ case MYSQL_COMPRESS_ZLIB :
291+ w , err = zlib .NewWriterLevel (& payload , zlib .HuffmanOnly )
292+ case MYSQL_COMPRESS_ZSTD :
293+ w = zstd .NewWriter (& payload )
294+ }
295+ if err != nil {
296+ return 0 , err
297+ }
298+
299+ if len (data ) > minCompressLength {
300+ uncompressedLength = len (data )
301+ n , err = w .Write (data )
302+ if err != nil {
303+ return 0 , err
304+ }
305+ err = w .Close ()
306+ if err != nil {
307+ return 0 , err
308+ }
309+ }
310+
311+ if len (data ) > minCompressLength {
312+ compressedLength = len (payload .Bytes ())
313+ } else {
314+ compressedLength = len (data )
315+ }
316+
317+ c .CompressedSequence = 0
318+ compressedHeader [0 ] = byte (compressedLength )
319+ compressedHeader [1 ] = byte (compressedLength >> 8 )
320+ compressedHeader [2 ] = byte (compressedLength >> 16 )
321+ compressedHeader [3 ] = c .CompressedSequence
322+ compressedHeader [4 ] = byte (uncompressedLength )
323+ compressedHeader [5 ] = byte (uncompressedLength >> 8 )
324+ compressedHeader [6 ] = byte (uncompressedLength >> 16 )
325+ _ , err = compressedPacket .Write (compressedHeader )
326+ if err != nil {
327+ return 0 , err
328+ }
329+ c .CompressedSequence ++
330+
331+ if len (data ) > minCompressLength {
332+ _ , err = compressedPacket .Write (payload .Bytes ())
216333} else {
217- c .Sequence ++
218- return nil
334+ n , err = compressedPacket .Write (data )
335+ }
336+ if err != nil {
337+ return 0 , err
219338}
339+
340+ _ , err = c .Write (compressedPacket .Bytes ())
341+ if err != nil {
342+ return 0 , err
343+ }
344+
345+ return n , nil
220346}
221347
222348// WriteClearAuthPacket: Client clear text authentication packet
0 commit comments