@@ -18,11 +18,19 @@ import (
1818)
1919
2020const (
21+ // Frame header byte 0 bits from Section 5.2 of RFC 6455
22+ finalBit = 1 << 7
23+ rsv1Bit = 1 << 6
24+ rsv2Bit = 1 << 5
25+ rsv3Bit = 1 << 4
26+
27+ // Frame header byte 1 bits from Section 5.2 of RFC 6455
28+ maskBit = 1 << 7
29+
2130maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask
2231maxControlFramePayloadSize = 125
23- finalBit = 1 << 7
24- maskBit = 1 << 7
25- writeWait = time .Second
32+
33+ writeWait = time .Second
2634
2735defaultReadBufferSize = 4096
2836defaultWriteBufferSize = 4096
@@ -230,17 +238,20 @@ type Conn struct {
230238subprotocol string
231239
232240// Write fields
233- mu chan bool // used as mutex to protect write to conn and closeSent
234- closeSent bool // true if close message was sent
235-
236- // Message writer fields.
241+ mu chan bool // used as mutex to protect write to conn and closeSent
242+ closeSent bool // whether close message was sent
237243writeErr error
238244writeBuf []byte // frame is constructed in this buffer.
239245writePos int // end of data in writeBuf.
240246writeFrameType int // type of the current frame.
241247writeDeadline time.Time
248+ messageWriter * messageWriter // the current low-level message writer
249+ writer io.WriteCloser // the current writer returned to the application
242250isWriting bool // for best-effort concurrent write detection
243- messageWriter * messageWriter // the current writer
251+
252+ enableWriteCompression bool
253+ writeCompress bool // whether next call to flushFrame should set RSV1
254+ newCompressionWriter func (io.WriteCloser ) (io.WriteCloser , error )
244255
245256// Read fields
246257readErr error
@@ -254,7 +265,10 @@ type Conn struct {
254265handlePong func (string ) error
255266handlePing func (string ) error
256267readErrCount int
257- messageReader * messageReader // the current reader
268+ messageReader * messageReader // the current low-level reader
269+
270+ readDecompress bool // whether last read frame had RSV1 set
271+ newDecompressionReader func (io.Reader ) io.Reader
258272}
259273
260274func newConn (conn net.Conn , isServer bool , readBufferSize , writeBufferSize int ) * Conn {
@@ -272,14 +286,15 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int)
272286}
273287
274288c := & Conn {
275- isServer : isServer ,
276- br : bufio .NewReaderSize (conn , readBufferSize ),
277- conn : conn ,
278- mu : mu ,
279- readFinal : true ,
280- writeBuf : make ([]byte , writeBufferSize + maxFrameHeaderSize ),
281- writeFrameType : noFrame ,
282- writePos : maxFrameHeaderSize ,
289+ isServer : isServer ,
290+ br : bufio .NewReaderSize (conn , readBufferSize ),
291+ conn : conn ,
292+ mu : mu ,
293+ readFinal : true ,
294+ writeBuf : make ([]byte , writeBufferSize + maxFrameHeaderSize ),
295+ writeFrameType : noFrame ,
296+ writePos : maxFrameHeaderSize ,
297+ enableWriteCompression : true ,
283298}
284299c .SetPingHandler (nil )
285300c .SetPongHandler (nil )
@@ -403,8 +418,12 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
403418return nil , c .writeErr
404419}
405420
406- if c .writeFrameType != noFrame {
407- if err := c .flushFrame (true , nil ); err != nil {
421+ // Close previous writer if not already closed by the application. It's
422+ // probably better to return an error in this situation, but we cannot
423+ // change this without breaking existing applications.
424+ if c .writer != nil {
425+ err := c .writer .Close ()
426+ if err != nil {
408427return nil , err
409428}
410429}
@@ -414,18 +433,32 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
414433}
415434
416435c .writeFrameType = messageType
417- w := & messageWriter {c }
418- c .messageWriter = w
436+ c .messageWriter = & messageWriter {c }
437+
438+ var w io.WriteCloser = c .messageWriter
439+ if c .newCompressionWriter != nil && c .enableWriteCompression && isData (messageType ) {
440+ c .writeCompress = true
441+ var err error
442+ w , err = c .newCompressionWriter (w )
443+ if err != nil {
444+ c .writer .Close ()
445+ return nil , err
446+ }
447+ }
448+
419449return w , nil
420450}
421451
452+ // flushFrame writes buffered data and extra as a frame to the network. The
453+ // final argument indicates that this is the last frame in the message.
422454func (c * Conn ) flushFrame (final bool , extra []byte ) error {
423455length := c .writePos - maxFrameHeaderSize + len (extra )
424456
425457// Check for invalid control frames.
426458if isControl (c .writeFrameType ) &&
427459(! final || length > maxControlFramePayloadSize ) {
428460c .messageWriter = nil
461+ c .writer = nil
429462c .writeFrameType = noFrame
430463c .writePos = maxFrameHeaderSize
431464return errInvalidControlFrame
@@ -435,6 +468,11 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
435468if final {
436469b0 |= finalBit
437470}
471+ if c .writeCompress {
472+ b0 |= rsv1Bit
473+ }
474+ c .writeCompress = false
475+
438476b1 := byte (0 )
439477if ! c .isServer {
440478b1 |= maskBit
@@ -494,6 +532,7 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
494532c .writeFrameType = continuationFrame
495533if final {
496534c .messageWriter = nil
535+ c .writer = nil
497536c .writeFrameType = noFrame
498537}
499538return c .writeErr
@@ -526,14 +565,14 @@ func (w *messageWriter) ncopy(max int) (int, error) {
526565return n , nil
527566}
528567
529- func (w * messageWriter ) write ( final bool , p []byte ) (int , error ) {
568+ func (w * messageWriter ) Write ( p []byte ) (int , error ) {
530569if err := w .err (); err != nil {
531570return 0 , err
532571}
533572
534573if len (p ) > 2 * len (w .c .writeBuf ) && w .c .isServer {
535574// Don't buffer large messages.
536- err := w .c .flushFrame (final , p )
575+ err := w .c .flushFrame (false , p )
537576if err != nil {
538577return 0 , err
539578}
@@ -553,10 +592,6 @@ func (w *messageWriter) write(final bool, p []byte) (int, error) {
553592return nn , nil
554593}
555594
556- func (w * messageWriter ) Write (p []byte ) (int , error ) {
557- return w .write (false , p )
558- }
559-
560595func (w * messageWriter ) WriteString (p string ) (int , error ) {
561596if err := w .err (); err != nil {
562597return 0 , err
@@ -658,12 +693,17 @@ func (c *Conn) advanceFrame() (int, error) {
658693
659694final := p [0 ]& finalBit != 0
660695frameType := int (p [0 ] & 0xf )
661- reserved := int ((p [0 ] >> 4 ) & 0x7 )
662696mask := p [1 ]& maskBit != 0
663697c .readRemaining = int64 (p [1 ] & 0x7f )
664698
665- if reserved != 0 {
666- return noFrame , c .handleProtocolError ("unexpected reserved bits " + strconv .Itoa (reserved ))
699+ c .readDecompress = false
700+ if c .newDecompressionReader != nil && (p [0 ]& rsv1Bit ) != 0 {
701+ c .readDecompress = true
702+ p [0 ] &^= rsv1Bit
703+ }
704+
705+ if rsv := p [0 ] & (rsv1Bit | rsv2Bit | rsv3Bit ); rsv != 0 {
706+ return noFrame , c .handleProtocolError ("unexpected reserved bits 0x" + strconv .FormatInt (int64 (rsv ), 16 ))
667707}
668708
669709switch frameType {
@@ -807,8 +847,11 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
807847break
808848}
809849if frameType == TextMessage || frameType == BinaryMessage {
810- r := & messageReader {c }
811- c .messageReader = r
850+ c .messageReader = & messageReader {c }
851+ var r io.Reader = c .messageReader
852+ if c .readDecompress {
853+ r = c .newDecompressionReader (r )
854+ }
812855return frameType , r , nil
813856}
814857}
0 commit comments