Skip to content

Commit a87eae1

Browse files
committed
Add hooks to support RFC 7692 (per-message compression extension)
Add newCompressionWriter and newDecompressionReader fields to Conn. When not nil, these functions are used to create a compression/decompression wrapper around an underlying message writer/reader. Add code to set and check for RSV1 frame header bit. Add functions compressNoContextTakeover and decompressNoContextTakeover for creating no context takeover wrappers around an underlying message writer/reader. Work remaining: - Add fields to Dialer and Upgrader for specifying compression options. - Add compression negotiation to Dialer and Upgrader. - Add function to enable/disable write compression: // EnableWriteCompression enables and disables write compression of // subsequent text and binary messages. This function is a noop if // compression was not negotiated with the peer. func (c *Conn) EnableWriteCompression(enable bool) { c.enableWriteCompression = enable }
1 parent b5389d0 commit a87eae1

File tree

3 files changed

+191
-32
lines changed

3 files changed

+191
-32
lines changed

compression.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package websocket
6+
7+
import (
8+
"compress/flate"
9+
"errors"
10+
"io"
11+
"strings"
12+
)
13+
14+
func decompressNoContextTakeover(r io.Reader) io.Reader {
15+
const tail =
16+
// Add four bytes as specified in RFC
17+
"\x00\x00\xff\xff" +
18+
// Add final block to squelch unexpected EOF error from flate reader.
19+
"\x01\x00\x00\xff\xff"
20+
21+
return flate.NewReader(io.MultiReader(r, strings.NewReader(tail)))
22+
}
23+
24+
func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) {
25+
tw := &truncWriter{w: w}
26+
fw, err := flate.NewWriter(tw, 3)
27+
return &flateWrapper{fw: fw, tw: tw}, err
28+
}
29+
30+
// truncWriter is an io.Writer that writes all but the last four bytes of the
31+
// stream to another io.Writer.
32+
type truncWriter struct {
33+
w io.WriteCloser
34+
n int
35+
p [4]byte
36+
}
37+
38+
func (w *truncWriter) Write(p []byte) (int, error) {
39+
n := 0
40+
41+
// fill buffer first for simplicity.
42+
if w.n < len(w.p) {
43+
n = copy(w.p[w.n:], p)
44+
p = p[n:]
45+
w.n += n
46+
if len(p) == 0 {
47+
return n, nil
48+
}
49+
}
50+
51+
m := len(p)
52+
if m > len(w.p) {
53+
m = len(w.p)
54+
}
55+
56+
if nn, err := w.w.Write(w.p[:m]); err != nil {
57+
return n + nn, err
58+
}
59+
60+
copy(w.p[:], w.p[m:])
61+
copy(w.p[len(w.p)-m:], p[len(p)-m:])
62+
nn, err := w.w.Write(p[:len(p)-m])
63+
return n + nn, err
64+
}
65+
66+
type flateWrapper struct {
67+
fw *flate.Writer
68+
tw *truncWriter
69+
}
70+
71+
func (w *flateWrapper) Write(p []byte) (int, error) {
72+
return w.fw.Write(p)
73+
}
74+
75+
func (w *flateWrapper) Close() error {
76+
err1 := w.fw.Flush()
77+
if w.tw.p != [4]byte{0, 0, 0xff, 0xff} {
78+
return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
79+
}
80+
err2 := w.tw.w.Close()
81+
if err1 != nil {
82+
return err1
83+
}
84+
return err2
85+
}

compression_test.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package websocket
2+
3+
import (
4+
"bytes"
5+
"io"
6+
"testing"
7+
)
8+
9+
type nopCloser struct{ io.Writer }
10+
11+
func (nopCloser) Close() error { return nil }
12+
13+
func TestTruncWriter(t *testing.T) {
14+
const data = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijlkmnopqrstuvwxyz987654321"
15+
for n := 1; n <= 10; n++ {
16+
var b bytes.Buffer
17+
w := &truncWriter{w: nopCloser{&b}}
18+
p := []byte(data)
19+
for len(p) > 0 {
20+
m := len(p)
21+
if m > n {
22+
m = n
23+
}
24+
w.Write(p[:m])
25+
p = p[m:]
26+
}
27+
if b.String() != data[:len(data)-len(w.p)] {
28+
t.Errorf("%d: %q", n, b.String())
29+
}
30+
}
31+
}

conn.go

Lines changed: 75 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,19 @@ import (
1818
)
1919

2020
const (
21+
// Frame header byte 0 bits from Section 5.2 of RFC 6455
22+
finalBit = 1 << 7
23+
rsv1Bit = 1 << 6
24+
rsv2Bit = 1 << 5
25+
rsv3Bit = 1 << 4
26+
27+
// Frame header byte 1 bits from Section 5.2 of RFC 6455
28+
maskBit = 1 << 7
29+
2130
maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask
2231
maxControlFramePayloadSize = 125
23-
finalBit = 1 << 7
24-
maskBit = 1 << 7
25-
writeWait = time.Second
32+
33+
writeWait = time.Second
2634

2735
defaultReadBufferSize = 4096
2836
defaultWriteBufferSize = 4096
@@ -230,17 +238,20 @@ type Conn struct {
230238
subprotocol string
231239

232240
// Write fields
233-
mu chan bool // used as mutex to protect write to conn and closeSent
234-
closeSent bool // true if close message was sent
235-
236-
// Message writer fields.
241+
mu chan bool // used as mutex to protect write to conn and closeSent
242+
closeSent bool // whether close message was sent
237243
writeErr error
238244
writeBuf []byte // frame is constructed in this buffer.
239245
writePos int // end of data in writeBuf.
240246
writeFrameType int // type of the current frame.
241247
writeDeadline time.Time
248+
messageWriter *messageWriter // the current low-level message writer
249+
writer io.WriteCloser // the current writer returned to the application
242250
isWriting bool // for best-effort concurrent write detection
243-
messageWriter *messageWriter // the current writer
251+
252+
enableWriteCompression bool
253+
writeCompress bool // whether next call to flushFrame should set RSV1
254+
newCompressionWriter func(io.WriteCloser) (io.WriteCloser, error)
244255

245256
// Read fields
246257
readErr error
@@ -254,7 +265,10 @@ type Conn struct {
254265
handlePong func(string) error
255266
handlePing func(string) error
256267
readErrCount int
257-
messageReader *messageReader // the current reader
268+
messageReader *messageReader // the current low-level reader
269+
270+
readDecompress bool // whether last read frame had RSV1 set
271+
newDecompressionReader func(io.Reader) io.Reader
258272
}
259273

260274
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
@@ -272,14 +286,15 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int)
272286
}
273287

274288
c := &Conn{
275-
isServer: isServer,
276-
br: bufio.NewReaderSize(conn, readBufferSize),
277-
conn: conn,
278-
mu: mu,
279-
readFinal: true,
280-
writeBuf: make([]byte, writeBufferSize+maxFrameHeaderSize),
281-
writeFrameType: noFrame,
282-
writePos: maxFrameHeaderSize,
289+
isServer: isServer,
290+
br: bufio.NewReaderSize(conn, readBufferSize),
291+
conn: conn,
292+
mu: mu,
293+
readFinal: true,
294+
writeBuf: make([]byte, writeBufferSize+maxFrameHeaderSize),
295+
writeFrameType: noFrame,
296+
writePos: maxFrameHeaderSize,
297+
enableWriteCompression: true,
283298
}
284299
c.SetPingHandler(nil)
285300
c.SetPongHandler(nil)
@@ -403,8 +418,12 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
403418
return nil, c.writeErr
404419
}
405420

406-
if c.writeFrameType != noFrame {
407-
if err := c.flushFrame(true, nil); err != nil {
421+
// Close previous writer if not already closed by the application. It's
422+
// probably better to return an error in this situation, but we cannot
423+
// change this without breaking existing applications.
424+
if c.writer != nil {
425+
err := c.writer.Close()
426+
if err != nil {
408427
return nil, err
409428
}
410429
}
@@ -414,18 +433,32 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
414433
}
415434

416435
c.writeFrameType = messageType
417-
w := &messageWriter{c}
418-
c.messageWriter = w
436+
c.messageWriter = &messageWriter{c}
437+
438+
var w io.WriteCloser = c.messageWriter
439+
if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
440+
c.writeCompress = true
441+
var err error
442+
w, err = c.newCompressionWriter(w)
443+
if err != nil {
444+
c.writer.Close()
445+
return nil, err
446+
}
447+
}
448+
419449
return w, nil
420450
}
421451

452+
// flushFrame writes buffered data and extra as a frame to the network. The
453+
// final argument indicates that this is the last frame in the message.
422454
func (c *Conn) flushFrame(final bool, extra []byte) error {
423455
length := c.writePos - maxFrameHeaderSize + len(extra)
424456

425457
// Check for invalid control frames.
426458
if isControl(c.writeFrameType) &&
427459
(!final || length > maxControlFramePayloadSize) {
428460
c.messageWriter = nil
461+
c.writer = nil
429462
c.writeFrameType = noFrame
430463
c.writePos = maxFrameHeaderSize
431464
return errInvalidControlFrame
@@ -435,6 +468,11 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
435468
if final {
436469
b0 |= finalBit
437470
}
471+
if c.writeCompress {
472+
b0 |= rsv1Bit
473+
}
474+
c.writeCompress = false
475+
438476
b1 := byte(0)
439477
if !c.isServer {
440478
b1 |= maskBit
@@ -494,6 +532,7 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
494532
c.writeFrameType = continuationFrame
495533
if final {
496534
c.messageWriter = nil
535+
c.writer = nil
497536
c.writeFrameType = noFrame
498537
}
499538
return c.writeErr
@@ -526,14 +565,14 @@ func (w *messageWriter) ncopy(max int) (int, error) {
526565
return n, nil
527566
}
528567

529-
func (w *messageWriter) write(final bool, p []byte) (int, error) {
568+
func (w *messageWriter) Write(p []byte) (int, error) {
530569
if err := w.err(); err != nil {
531570
return 0, err
532571
}
533572

534573
if len(p) > 2*len(w.c.writeBuf) && w.c.isServer {
535574
// Don't buffer large messages.
536-
err := w.c.flushFrame(final, p)
575+
err := w.c.flushFrame(false, p)
537576
if err != nil {
538577
return 0, err
539578
}
@@ -553,10 +592,6 @@ func (w *messageWriter) write(final bool, p []byte) (int, error) {
553592
return nn, nil
554593
}
555594

556-
func (w *messageWriter) Write(p []byte) (int, error) {
557-
return w.write(false, p)
558-
}
559-
560595
func (w *messageWriter) WriteString(p string) (int, error) {
561596
if err := w.err(); err != nil {
562597
return 0, err
@@ -658,12 +693,17 @@ func (c *Conn) advanceFrame() (int, error) {
658693

659694
final := p[0]&finalBit != 0
660695
frameType := int(p[0] & 0xf)
661-
reserved := int((p[0] >> 4) & 0x7)
662696
mask := p[1]&maskBit != 0
663697
c.readRemaining = int64(p[1] & 0x7f)
664698

665-
if reserved != 0 {
666-
return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved))
699+
c.readDecompress = false
700+
if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
701+
c.readDecompress = true
702+
p[0] &^= rsv1Bit
703+
}
704+
705+
if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 {
706+
return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16))
667707
}
668708

669709
switch frameType {
@@ -807,8 +847,11 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
807847
break
808848
}
809849
if frameType == TextMessage || frameType == BinaryMessage {
810-
r := &messageReader{c}
811-
c.messageReader = r
850+
c.messageReader = &messageReader{c}
851+
var r io.Reader = c.messageReader
852+
if c.readDecompress {
853+
r = c.newDecompressionReader(r)
854+
}
812855
return frameType, r, nil
813856
}
814857
}

0 commit comments

Comments
 (0)