@@ -238,23 +238,23 @@ type Conn struct {
238238writeBuf []byte // frame is constructed in this buffer.
239239writePos int // end of data in writeBuf.
240240writeFrameType int // type of the current frame.
241- writeSeq int // incremented to invalidate message writers.
242241writeDeadline time.Time
243- isWriting bool // for best-effort concurrent write detection
242+ isWriting bool // for best-effort concurrent write detection
243+ messageWriter * messageWriter // the current writer
244244
245245// Read fields
246246readErr error
247247br * bufio.Reader
248248readRemaining int64 // bytes remaining in current frame.
249249readFinal bool // true the current message has more frames.
250- readSeq int // incremented to invalidate message readers.
251250readLength int64 // Message size.
252251readLimit int64 // Maximum message size.
253252readMaskPos int
254253readMaskKey [4 ]byte
255254handlePong func (string ) error
256255handlePing func (string ) error
257256readErrCount int
257+ messageReader * messageReader // the current reader
258258}
259259
260260func newConn (conn net.Conn , isServer bool , readBufferSize , writeBufferSize int ) * Conn {
@@ -264,6 +264,9 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int)
264264if readBufferSize == 0 {
265265readBufferSize = defaultReadBufferSize
266266}
267+ if readBufferSize < maxControlFramePayloadSize {
268+ readBufferSize = maxControlFramePayloadSize
269+ }
267270if writeBufferSize == 0 {
268271writeBufferSize = defaultWriteBufferSize
269272}
@@ -390,8 +393,8 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
390393return hideTempErr (err )
391394}
392395
393- // NextWriter returns a writer for the next message to send. The writer's
394- // Close method flushes the complete message to the network.
396+ // NextWriter returns a writer for the next message to send. The writer's Close
397+ // method flushes the complete message to the network.
395398//
396399// There can be at most one open writer on a connection. NextWriter closes the
397400// previous writer if the application has not already done so.
@@ -411,7 +414,9 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
411414}
412415
413416c .writeFrameType = messageType
414- return messageWriter {c , c .writeSeq }, nil
417+ w := & messageWriter {c }
418+ c .messageWriter = w
419+ return w , nil
415420}
416421
417422func (c * Conn ) flushFrame (final bool , extra []byte ) error {
@@ -420,7 +425,7 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
420425// Check for invalid control frames.
421426if isControl (c .writeFrameType ) &&
422427(! final || length > maxControlFramePayloadSize ) {
423- c .writeSeq ++
428+ c .messageWriter = nil
424429c .writeFrameType = noFrame
425430c .writePos = maxFrameHeaderSize
426431return errInvalidControlFrame
@@ -488,20 +493,17 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
488493c .writePos = maxFrameHeaderSize
489494c .writeFrameType = continuationFrame
490495if final {
491- c .writeSeq ++
496+ c .messageWriter = nil
492497c .writeFrameType = noFrame
493498}
494499return c .writeErr
495500}
496501
497- type messageWriter struct {
498- c * Conn
499- seq int
500- }
502+ type messageWriter struct { c * Conn }
501503
502- func (w messageWriter ) err () error {
504+ func (w * messageWriter ) err () error {
503505c := w .c
504- if c .writeSeq != w . seq {
506+ if c .messageWriter != w {
505507return errWriteClosed
506508}
507509if c .writeErr != nil {
@@ -510,7 +512,7 @@ func (w messageWriter) err() error {
510512return nil
511513}
512514
513- func (w messageWriter ) ncopy (max int ) (int , error ) {
515+ func (w * messageWriter ) ncopy (max int ) (int , error ) {
514516n := len (w .c .writeBuf ) - w .c .writePos
515517if n <= 0 {
516518if err := w .c .flushFrame (false , nil ); err != nil {
@@ -524,7 +526,7 @@ func (w messageWriter) ncopy(max int) (int, error) {
524526return n , nil
525527}
526528
527- func (w messageWriter ) write (final bool , p []byte ) (int , error ) {
529+ func (w * messageWriter ) write (final bool , p []byte ) (int , error ) {
528530if err := w .err (); err != nil {
529531return 0 , err
530532}
@@ -551,11 +553,11 @@ func (w messageWriter) write(final bool, p []byte) (int, error) {
551553return nn , nil
552554}
553555
554- func (w messageWriter ) Write (p []byte ) (int , error ) {
556+ func (w * messageWriter ) Write (p []byte ) (int , error ) {
555557return w .write (false , p )
556558}
557559
558- func (w messageWriter ) WriteString (p string ) (int , error ) {
560+ func (w * messageWriter ) WriteString (p string ) (int , error ) {
559561if err := w .err (); err != nil {
560562return 0 , err
561563}
@@ -573,7 +575,7 @@ func (w messageWriter) WriteString(p string) (int, error) {
573575return nn , nil
574576}
575577
576- func (w messageWriter ) ReadFrom (r io.Reader ) (nn int64 , err error ) {
578+ func (w * messageWriter ) ReadFrom (r io.Reader ) (nn int64 , err error ) {
577579if err := w .err (); err != nil {
578580return 0 , err
579581}
@@ -598,7 +600,7 @@ func (w messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
598600return nn , err
599601}
600602
601- func (w messageWriter ) Close () error {
603+ func (w * messageWriter ) Close () error {
602604if err := w .err (); err != nil {
603605return err
604606}
@@ -608,20 +610,22 @@ func (w messageWriter) Close() error {
608610// WriteMessage is a helper method for getting a writer using NextWriter,
609611// writing the message and closing the writer.
610612func (c * Conn ) WriteMessage (messageType int , data []byte ) error {
611- wr , err := c .NextWriter (messageType )
613+ w , err := c .NextWriter (messageType )
612614if err != nil {
613615return err
614616}
615- w := wr .(messageWriter )
616- if _ , err := w .write (true , data ); err != nil {
617+ if _ , ok := w .(* messageWriter ); ok && c .isServer {
618+ // Optimize write as a single frame.
619+ n := copy (c .writeBuf [c .writePos :], data )
620+ c .writePos += n
621+ data = data [n :]
622+ err = c .flushFrame (true , data )
617623return err
618624}
619- if c .writeSeq == w .seq {
620- if err := c .flushFrame (true , nil ); err != nil {
621- return err
622- }
625+ if _ , err = w .Write (data ); err != nil {
626+ return err
623627}
624- return nil
628+ return w . Close ()
625629}
626630
627631// SetWriteDeadline sets the write deadline on the underlying network
@@ -635,22 +639,6 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
635639
636640// Read methods
637641
638- // readFull is like io.ReadFull except that io.EOF is never returned.
639- func (c * Conn ) readFull (p []byte ) (err error ) {
640- var n int
641- for n < len (p ) && err == nil {
642- var nn int
643- nn , err = c .br .Read (p [n :])
644- n += nn
645- }
646- if n == len (p ) {
647- err = nil
648- } else if err == io .EOF {
649- err = errUnexpectedEOF
650- }
651- return
652- }
653-
654642func (c * Conn ) advanceFrame () (int , error ) {
655643
656644// 1. Skip remainder of previous frame.
@@ -663,16 +651,16 @@ func (c *Conn) advanceFrame() (int, error) {
663651
664652// 2. Read and parse first two bytes of frame header.
665653
666- var b [ 8 ] byte
667- if err := c . readFull ( b [: 2 ]); err != nil {
654+ p , err := c . read ( 2 )
655+ if err != nil {
668656return noFrame , err
669657}
670658
671- final := b [0 ]& finalBit != 0
672- frameType := int (b [0 ] & 0xf )
673- reserved := int ((b [0 ] >> 4 ) & 0x7 )
674- mask := b [1 ]& maskBit != 0
675- c .readRemaining = int64 (b [1 ] & 0x7f )
659+ final := p [0 ]& finalBit != 0
660+ frameType := int (p [0 ] & 0xf )
661+ reserved := int ((p [0 ] >> 4 ) & 0x7 )
662+ mask := p [1 ]& maskBit != 0
663+ c .readRemaining = int64 (p [1 ] & 0x7f )
676664
677665if reserved != 0 {
678666return noFrame , c .handleProtocolError ("unexpected reserved bits " + strconv .Itoa (reserved ))
@@ -704,15 +692,17 @@ func (c *Conn) advanceFrame() (int, error) {
704692
705693switch c .readRemaining {
706694case 126 :
707- if err := c .readFull (b [:2 ]); err != nil {
695+ p , err := c .read (2 )
696+ if err != nil {
708697return noFrame , err
709698}
710- c .readRemaining = int64 (binary .BigEndian .Uint16 (b [: 2 ] ))
699+ c .readRemaining = int64 (binary .BigEndian .Uint16 (p ))
711700case 127 :
712- if err := c .readFull (b [:8 ]); err != nil {
701+ p , err := c .read (8 )
702+ if err != nil {
713703return noFrame , err
714704}
715- c .readRemaining = int64 (binary .BigEndian .Uint64 (b [: 8 ] ))
705+ c .readRemaining = int64 (binary .BigEndian .Uint64 (p ))
716706}
717707
718708// 4. Handle frame masking.
@@ -723,9 +713,11 @@ func (c *Conn) advanceFrame() (int, error) {
723713
724714if mask {
725715c .readMaskPos = 0
726- if err := c .readFull (c .readMaskKey [:]); err != nil {
716+ p , err := c .read (len (c .readMaskKey ))
717+ if err != nil {
727718return noFrame , err
728719}
720+ copy (c .readMaskKey [:], p )
729721}
730722
731723// 5. For text and binary messages, enforce read limit and return.
@@ -745,9 +737,9 @@ func (c *Conn) advanceFrame() (int, error) {
745737
746738var payload []byte
747739if c .readRemaining > 0 {
748- payload = make ([] byte , c .readRemaining )
740+ payload , err = c . read ( int ( c .readRemaining ) )
749741c .readRemaining = 0
750- if err := c . readFull ( payload ); err != nil {
742+ if err != nil {
751743return noFrame , err
752744}
753745if c .isServer {
@@ -805,7 +797,7 @@ func (c *Conn) handleProtocolError(message string) error {
805797// this method return the same error.
806798func (c * Conn ) NextReader () (messageType int , r io.Reader , err error ) {
807799
808- c .readSeq ++
800+ c .messageReader = nil
809801c .readLength = 0
810802
811803for c .readErr == nil {
@@ -815,7 +807,9 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
815807break
816808}
817809if frameType == TextMessage || frameType == BinaryMessage {
818- return frameType , messageReader {c , c .readSeq }, nil
810+ r := & messageReader {c }
811+ c .messageReader = r
812+ return frameType , r , nil
819813}
820814}
821815
@@ -830,51 +824,48 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
830824return noFrame , nil , c .readErr
831825}
832826
833- type messageReader struct {
834- c * Conn
835- seq int
836- }
837-
838- func (r messageReader ) Read (b []byte ) (int , error ) {
827+ type messageReader struct { c * Conn }
839828
840- if r .seq != r .c .readSeq {
829+ func (r * messageReader ) Read (b []byte ) (int , error ) {
830+ c := r .c
831+ if c .messageReader != r {
841832return 0 , io .EOF
842833}
843834
844- for r . c .readErr == nil {
835+ for c .readErr == nil {
845836
846- if r . c .readRemaining > 0 {
847- if int64 (len (b )) > r . c .readRemaining {
848- b = b [:r . c .readRemaining ]
837+ if c .readRemaining > 0 {
838+ if int64 (len (b )) > c .readRemaining {
839+ b = b [:c .readRemaining ]
849840}
850- n , err := r . c .br .Read (b )
851- r . c .readErr = hideTempErr (err )
852- if r . c .isServer {
853- r . c .readMaskPos = maskBytes (r . c .readMaskKey , r . c .readMaskPos , b [:n ])
841+ n , err := c .br .Read (b )
842+ c .readErr = hideTempErr (err )
843+ if c .isServer {
844+ c .readMaskPos = maskBytes (c .readMaskKey , c .readMaskPos , b [:n ])
854845}
855- r . c .readRemaining -= int64 (n )
856- if r . c .readRemaining > 0 && r . c .readErr == io .EOF {
857- r . c .readErr = errUnexpectedEOF
846+ c .readRemaining -= int64 (n )
847+ if c .readRemaining > 0 && c .readErr == io .EOF {
848+ c .readErr = errUnexpectedEOF
858849}
859- return n , r . c .readErr
850+ return n , c .readErr
860851}
861852
862- if r . c .readFinal {
863- r . c . readSeq ++
853+ if c .readFinal {
854+ c . messageReader = nil
864855return 0 , io .EOF
865856}
866857
867- frameType , err := r . c .advanceFrame ()
858+ frameType , err := c .advanceFrame ()
868859switch {
869860case err != nil :
870- r . c .readErr = hideTempErr (err )
861+ c .readErr = hideTempErr (err )
871862case frameType == TextMessage || frameType == BinaryMessage :
872- r . c .readErr = errors .New ("websocket: internal error, unexpected text or binary in Reader" )
863+ c .readErr = errors .New ("websocket: internal error, unexpected text or binary in Reader" )
873864}
874865}
875866
876- err := r . c .readErr
877- if err == io .EOF && r . seq == r . c . readSeq {
867+ err := c .readErr
868+ if err == io .EOF && c . messageReader == r {
878869err = errUnexpectedEOF
879870}
880871return 0 , err
0 commit comments