Skip to content

Commit be01041

Browse files
committed
Reduce memory allocations in NextReader, NextWriter
Redo 8b209f6 with support for old versions of Go.
1 parent 50d660d commit be01041

File tree

3 files changed

+116
-86
lines changed

3 files changed

+116
-86
lines changed

conn.go

Lines changed: 77 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -238,23 +238,23 @@ type Conn struct {
238238
writeBuf []byte // frame is constructed in this buffer.
239239
writePos int // end of data in writeBuf.
240240
writeFrameType int // type of the current frame.
241-
writeSeq int // incremented to invalidate message writers.
242241
writeDeadline time.Time
243-
isWriting bool // for best-effort concurrent write detection
242+
isWriting bool // for best-effort concurrent write detection
243+
messageWriter *messageWriter // the current writer
244244

245245
// Read fields
246246
readErr error
247247
br *bufio.Reader
248248
readRemaining int64 // bytes remaining in current frame.
249249
readFinal bool // true the current message has more frames.
250-
readSeq int // incremented to invalidate message readers.
251250
readLength int64 // Message size.
252251
readLimit int64 // Maximum message size.
253252
readMaskPos int
254253
readMaskKey [4]byte
255254
handlePong func(string) error
256255
handlePing func(string) error
257256
readErrCount int
257+
messageReader *messageReader // the current reader
258258
}
259259

260260
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
@@ -264,6 +264,9 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int)
264264
if readBufferSize == 0 {
265265
readBufferSize = defaultReadBufferSize
266266
}
267+
if readBufferSize < maxControlFramePayloadSize {
268+
readBufferSize = maxControlFramePayloadSize
269+
}
267270
if writeBufferSize == 0 {
268271
writeBufferSize = defaultWriteBufferSize
269272
}
@@ -390,8 +393,8 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
390393
return hideTempErr(err)
391394
}
392395

393-
// NextWriter returns a writer for the next message to send. The writer's
394-
// Close method flushes the complete message to the network.
396+
// NextWriter returns a writer for the next message to send. The writer's Close
397+
// method flushes the complete message to the network.
395398
//
396399
// There can be at most one open writer on a connection. NextWriter closes the
397400
// previous writer if the application has not already done so.
@@ -411,7 +414,9 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
411414
}
412415

413416
c.writeFrameType = messageType
414-
return messageWriter{c, c.writeSeq}, nil
417+
w := &messageWriter{c}
418+
c.messageWriter = w
419+
return w, nil
415420
}
416421

417422
func (c *Conn) flushFrame(final bool, extra []byte) error {
@@ -420,7 +425,7 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
420425
// Check for invalid control frames.
421426
if isControl(c.writeFrameType) &&
422427
(!final || length > maxControlFramePayloadSize) {
423-
c.writeSeq++
428+
c.messageWriter = nil
424429
c.writeFrameType = noFrame
425430
c.writePos = maxFrameHeaderSize
426431
return errInvalidControlFrame
@@ -488,20 +493,17 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
488493
c.writePos = maxFrameHeaderSize
489494
c.writeFrameType = continuationFrame
490495
if final {
491-
c.writeSeq++
496+
c.messageWriter = nil
492497
c.writeFrameType = noFrame
493498
}
494499
return c.writeErr
495500
}
496501

497-
type messageWriter struct {
498-
c *Conn
499-
seq int
500-
}
502+
type messageWriter struct{ c *Conn }
501503

502-
func (w messageWriter) err() error {
504+
func (w *messageWriter) err() error {
503505
c := w.c
504-
if c.writeSeq != w.seq {
506+
if c.messageWriter != w {
505507
return errWriteClosed
506508
}
507509
if c.writeErr != nil {
@@ -510,7 +512,7 @@ func (w messageWriter) err() error {
510512
return nil
511513
}
512514

513-
func (w messageWriter) ncopy(max int) (int, error) {
515+
func (w *messageWriter) ncopy(max int) (int, error) {
514516
n := len(w.c.writeBuf) - w.c.writePos
515517
if n <= 0 {
516518
if err := w.c.flushFrame(false, nil); err != nil {
@@ -524,7 +526,7 @@ func (w messageWriter) ncopy(max int) (int, error) {
524526
return n, nil
525527
}
526528

527-
func (w messageWriter) write(final bool, p []byte) (int, error) {
529+
func (w *messageWriter) write(final bool, p []byte) (int, error) {
528530
if err := w.err(); err != nil {
529531
return 0, err
530532
}
@@ -551,11 +553,11 @@ func (w messageWriter) write(final bool, p []byte) (int, error) {
551553
return nn, nil
552554
}
553555

554-
func (w messageWriter) Write(p []byte) (int, error) {
556+
func (w *messageWriter) Write(p []byte) (int, error) {
555557
return w.write(false, p)
556558
}
557559

558-
func (w messageWriter) WriteString(p string) (int, error) {
560+
func (w *messageWriter) WriteString(p string) (int, error) {
559561
if err := w.err(); err != nil {
560562
return 0, err
561563
}
@@ -573,7 +575,7 @@ func (w messageWriter) WriteString(p string) (int, error) {
573575
return nn, nil
574576
}
575577

576-
func (w messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
578+
func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
577579
if err := w.err(); err != nil {
578580
return 0, err
579581
}
@@ -598,7 +600,7 @@ func (w messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
598600
return nn, err
599601
}
600602

601-
func (w messageWriter) Close() error {
603+
func (w *messageWriter) Close() error {
602604
if err := w.err(); err != nil {
603605
return err
604606
}
@@ -608,20 +610,22 @@ func (w messageWriter) Close() error {
608610
// WriteMessage is a helper method for getting a writer using NextWriter,
609611
// writing the message and closing the writer.
610612
func (c *Conn) WriteMessage(messageType int, data []byte) error {
611-
wr, err := c.NextWriter(messageType)
613+
w, err := c.NextWriter(messageType)
612614
if err != nil {
613615
return err
614616
}
615-
w := wr.(messageWriter)
616-
if _, err := w.write(true, data); err != nil {
617+
if _, ok := w.(*messageWriter); ok && c.isServer {
618+
// Optimize write as a single frame.
619+
n := copy(c.writeBuf[c.writePos:], data)
620+
c.writePos += n
621+
data = data[n:]
622+
err = c.flushFrame(true, data)
617623
return err
618624
}
619-
if c.writeSeq == w.seq {
620-
if err := c.flushFrame(true, nil); err != nil {
621-
return err
622-
}
625+
if _, err = w.Write(data); err != nil {
626+
return err
623627
}
624-
return nil
628+
return w.Close()
625629
}
626630

627631
// SetWriteDeadline sets the write deadline on the underlying network
@@ -635,22 +639,6 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
635639

636640
// Read methods
637641

638-
// readFull is like io.ReadFull except that io.EOF is never returned.
639-
func (c *Conn) readFull(p []byte) (err error) {
640-
var n int
641-
for n < len(p) && err == nil {
642-
var nn int
643-
nn, err = c.br.Read(p[n:])
644-
n += nn
645-
}
646-
if n == len(p) {
647-
err = nil
648-
} else if err == io.EOF {
649-
err = errUnexpectedEOF
650-
}
651-
return
652-
}
653-
654642
func (c *Conn) advanceFrame() (int, error) {
655643

656644
// 1. Skip remainder of previous frame.
@@ -663,16 +651,16 @@ func (c *Conn) advanceFrame() (int, error) {
663651

664652
// 2. Read and parse first two bytes of frame header.
665653

666-
var b [8]byte
667-
if err := c.readFull(b[:2]); err != nil {
654+
p, err := c.read(2)
655+
if err != nil {
668656
return noFrame, err
669657
}
670658

671-
final := b[0]&finalBit != 0
672-
frameType := int(b[0] & 0xf)
673-
reserved := int((b[0] >> 4) & 0x7)
674-
mask := b[1]&maskBit != 0
675-
c.readRemaining = int64(b[1] & 0x7f)
659+
final := p[0]&finalBit != 0
660+
frameType := int(p[0] & 0xf)
661+
reserved := int((p[0] >> 4) & 0x7)
662+
mask := p[1]&maskBit != 0
663+
c.readRemaining = int64(p[1] & 0x7f)
676664

677665
if reserved != 0 {
678666
return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved))
@@ -704,15 +692,17 @@ func (c *Conn) advanceFrame() (int, error) {
704692

705693
switch c.readRemaining {
706694
case 126:
707-
if err := c.readFull(b[:2]); err != nil {
695+
p, err := c.read(2)
696+
if err != nil {
708697
return noFrame, err
709698
}
710-
c.readRemaining = int64(binary.BigEndian.Uint16(b[:2]))
699+
c.readRemaining = int64(binary.BigEndian.Uint16(p))
711700
case 127:
712-
if err := c.readFull(b[:8]); err != nil {
701+
p, err := c.read(8)
702+
if err != nil {
713703
return noFrame, err
714704
}
715-
c.readRemaining = int64(binary.BigEndian.Uint64(b[:8]))
705+
c.readRemaining = int64(binary.BigEndian.Uint64(p))
716706
}
717707

718708
// 4. Handle frame masking.
@@ -723,9 +713,11 @@ func (c *Conn) advanceFrame() (int, error) {
723713

724714
if mask {
725715
c.readMaskPos = 0
726-
if err := c.readFull(c.readMaskKey[:]); err != nil {
716+
p, err := c.read(len(c.readMaskKey))
717+
if err != nil {
727718
return noFrame, err
728719
}
720+
copy(c.readMaskKey[:], p)
729721
}
730722

731723
// 5. For text and binary messages, enforce read limit and return.
@@ -745,9 +737,9 @@ func (c *Conn) advanceFrame() (int, error) {
745737

746738
var payload []byte
747739
if c.readRemaining > 0 {
748-
payload = make([]byte, c.readRemaining)
740+
payload, err = c.read(int(c.readRemaining))
749741
c.readRemaining = 0
750-
if err := c.readFull(payload); err != nil {
742+
if err != nil {
751743
return noFrame, err
752744
}
753745
if c.isServer {
@@ -805,7 +797,7 @@ func (c *Conn) handleProtocolError(message string) error {
805797
// this method return the same error.
806798
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
807799

808-
c.readSeq++
800+
c.messageReader = nil
809801
c.readLength = 0
810802

811803
for c.readErr == nil {
@@ -815,7 +807,9 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
815807
break
816808
}
817809
if frameType == TextMessage || frameType == BinaryMessage {
818-
return frameType, messageReader{c, c.readSeq}, nil
810+
r := &messageReader{c}
811+
c.messageReader = r
812+
return frameType, r, nil
819813
}
820814
}
821815

@@ -830,51 +824,48 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
830824
return noFrame, nil, c.readErr
831825
}
832826

833-
type messageReader struct {
834-
c *Conn
835-
seq int
836-
}
837-
838-
func (r messageReader) Read(b []byte) (int, error) {
827+
type messageReader struct{ c *Conn }
839828

840-
if r.seq != r.c.readSeq {
829+
func (r *messageReader) Read(b []byte) (int, error) {
830+
c := r.c
831+
if c.messageReader != r {
841832
return 0, io.EOF
842833
}
843834

844-
for r.c.readErr == nil {
835+
for c.readErr == nil {
845836

846-
if r.c.readRemaining > 0 {
847-
if int64(len(b)) > r.c.readRemaining {
848-
b = b[:r.c.readRemaining]
837+
if c.readRemaining > 0 {
838+
if int64(len(b)) > c.readRemaining {
839+
b = b[:c.readRemaining]
849840
}
850-
n, err := r.c.br.Read(b)
851-
r.c.readErr = hideTempErr(err)
852-
if r.c.isServer {
853-
r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b[:n])
841+
n, err := c.br.Read(b)
842+
c.readErr = hideTempErr(err)
843+
if c.isServer {
844+
c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
854845
}
855-
r.c.readRemaining -= int64(n)
856-
if r.c.readRemaining > 0 && r.c.readErr == io.EOF {
857-
r.c.readErr = errUnexpectedEOF
846+
c.readRemaining -= int64(n)
847+
if c.readRemaining > 0 && c.readErr == io.EOF {
848+
c.readErr = errUnexpectedEOF
858849
}
859-
return n, r.c.readErr
850+
return n, c.readErr
860851
}
861852

862-
if r.c.readFinal {
863-
r.c.readSeq++
853+
if c.readFinal {
854+
c.messageReader = nil
864855
return 0, io.EOF
865856
}
866857

867-
frameType, err := r.c.advanceFrame()
858+
frameType, err := c.advanceFrame()
868859
switch {
869860
case err != nil:
870-
r.c.readErr = hideTempErr(err)
861+
c.readErr = hideTempErr(err)
871862
case frameType == TextMessage || frameType == BinaryMessage:
872-
r.c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
863+
c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
873864
}
874865
}
875866

876-
err := r.c.readErr
877-
if err == io.EOF && r.seq == r.c.readSeq {
867+
err := c.readErr
868+
if err == io.EOF && c.messageReader == r {
878869
err = errUnexpectedEOF
879870
}
880871
return 0, err

conn_read.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
// +build go1.5
6+
7+
package websocket
8+
9+
import "io"
10+
11+
func (c *Conn) read(n int) ([]byte, error) {
12+
p, err := c.br.Peek(n)
13+
if err == io.EOF {
14+
err = errUnexpectedEOF
15+
}
16+
c.br.Discard(len(p))
17+
return p, err
18+
}

0 commit comments

Comments
 (0)