Skip to content

Commit 2db2f66

Browse files
committed
pool flate readers
1 parent 3ab3a8b commit 2db2f66

File tree

2 files changed

+55
-10
lines changed

2 files changed

+55
-10
lines changed

compression.go

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,22 @@ import (
1414

1515
var (
1616
flateWriterPool = sync.Pool{}
17+
flateReaderPool = sync.Pool{}
1718
)
1819

19-
func decompressNoContextTakeover(r io.Reader) io.Reader {
20+
func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
2021
const tail =
2122
// Add four bytes as specified in RFC
2223
"\x00\x00\xff\xff" +
2324
// Add final block to squelch unexpected EOF error from flate reader.
2425
"\x01\x00\x00\xff\xff"
25-
return flate.NewReader(io.MultiReader(r, strings.NewReader(tail)))
26+
27+
i := flateReaderPool.Get()
28+
if i == nil {
29+
i = flate.NewReader(nil)
30+
}
31+
i.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
32+
return &flateReadWrapper{i.(io.ReadCloser)}
2633
}
2734

2835
func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) {
@@ -36,7 +43,7 @@ func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) {
3643
fw = i.(*flate.Writer)
3744
fw.Reset(tw)
3845
}
39-
return &flateWrapper{fw: fw, tw: tw}, err
46+
return &flateWriteWrapper{fw: fw, tw: tw}, err
4047
}
4148

4249
// truncWriter is an io.Writer that writes all but the last four bytes of the
@@ -75,19 +82,19 @@ func (w *truncWriter) Write(p []byte) (int, error) {
7582
return n + nn, err
7683
}
7784

78-
type flateWrapper struct {
85+
type flateWriteWrapper struct {
7986
fw *flate.Writer
8087
tw *truncWriter
8188
}
8289

83-
func (w *flateWrapper) Write(p []byte) (int, error) {
90+
func (w *flateWriteWrapper) Write(p []byte) (int, error) {
8491
if w.fw == nil {
8592
return 0, errWriteClosed
8693
}
8794
return w.fw.Write(p)
8895
}
8996

90-
func (w *flateWrapper) Close() error {
97+
func (w *flateWriteWrapper) Close() error {
9198
if w.fw == nil {
9299
return errWriteClosed
93100
}
@@ -103,3 +110,31 @@ func (w *flateWrapper) Close() error {
103110
}
104111
return err2
105112
}
113+
114+
type flateReadWrapper struct {
115+
fr io.ReadCloser
116+
}
117+
118+
func (r *flateReadWrapper) Read(p []byte) (int, error) {
119+
if r.fr == nil {
120+
return 0, io.ErrClosedPipe
121+
}
122+
n, err := r.fr.Read(p)
123+
if err == io.EOF {
124+
// Preemptively place the reader back in the pool. This helps with
125+
// scenarios where the application does not call NextReader() soon after
126+
// this final read.
127+
r.Close()
128+
}
129+
return n, err
130+
}
131+
132+
func (r *flateReadWrapper) Close() error {
133+
if r.fr == nil {
134+
return io.ErrClosedPipe
135+
}
136+
err := r.fr.Close()
137+
flateReaderPool.Put(r.fr)
138+
r.fr = nil
139+
return err
140+
}

conn.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ type Conn struct {
238238
newCompressionWriter func(io.WriteCloser) (io.WriteCloser, error)
239239

240240
// Read fields
241+
reader io.ReadCloser // the current reader returned to the application
241242
readErr error
242243
br *bufio.Reader
243244
readRemaining int64 // bytes remaining in current frame.
@@ -253,7 +254,7 @@ type Conn struct {
253254
messageReader *messageReader // the current low-level reader
254255

255256
readDecompress bool // whether last read frame had RSV1 set
256-
newDecompressionReader func(io.Reader) io.Reader
257+
newDecompressionReader func(io.Reader) io.ReadCloser
257258
}
258259

259260
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
@@ -855,6 +856,11 @@ func (c *Conn) handleProtocolError(message string) error {
855856
// permanent. Once this method returns a non-nil error, all subsequent calls to
856857
// this method return the same error.
857858
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
859+
// Close previous reader, only relevant for decompression.
860+
if c.reader != nil {
861+
c.reader.Close()
862+
c.reader = nil
863+
}
858864

859865
c.messageReader = nil
860866
c.readLength = 0
@@ -867,11 +873,11 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
867873
}
868874
if frameType == TextMessage || frameType == BinaryMessage {
869875
c.messageReader = &messageReader{c}
870-
var r io.Reader = c.messageReader
876+
c.reader = c.messageReader
871877
if c.readDecompress {
872-
r = c.newDecompressionReader(r)
878+
c.reader = c.newDecompressionReader(c.reader)
873879
}
874-
return frameType, r, nil
880+
return frameType, c.reader, nil
875881
}
876882
}
877883

@@ -933,6 +939,10 @@ func (r *messageReader) Read(b []byte) (int, error) {
933939
return 0, err
934940
}
935941

942+
func (r *messageReader) Close() error {
943+
return nil
944+
}
945+
936946
// ReadMessage is a helper method for getting a reader using NextReader and
937947
// reading from that reader to a buffer.
938948
func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {

0 commit comments

Comments
 (0)