@@ -14,7 +14,8 @@ import (
1414)
1515
1616// Conn represents a WebSocket connection.
17- // All methods except Reader can be used concurrently.
17+ // All methods may be called concurrently.
18+ //
1819// Please be sure to call Close on the connection when you
1920// are finished with it to release resources.
2021type Conn struct {
@@ -31,8 +32,10 @@ type Conn struct {
3132 writeDataLock chan struct {}
3233 writeFrameLock chan struct {}
3334
34- readData chan header
35- readDone chan struct {}
35+ readDataLock chan struct {}
36+ readData chan header
37+ readDone chan struct {}
38+ readLoopDone chan struct {}
3639
3740 setReadTimeout chan context.Context
3841 setWriteTimeout chan context.Context
@@ -44,7 +47,7 @@ type Conn struct {
4447// when the connection is closed.
4548// If the parent context is cancelled, the connection will be closed.
4649//
47- // This is an experimental API meaning it may be remove in the future.
50+ // This is an experimental API that may be remove in the future.
4851// Please let me know how you feel about it.
4952func (c * Conn ) Context (parent context.Context ) context.Context {
5053 select {
@@ -77,6 +80,18 @@ func (c *Conn) close(err error) {
7780 c .closeErr = xerrors .Errorf ("websocket closed: %w" , cerr )
7881
7982 close (c .closed )
83+
84+ // See comment in dial.go
85+ if c .client {
86+ go func () {
87+ <- c .readLoopDone
88+ c .readDataLock <- struct {}{}
89+ c .writeFrameLock <- struct {}{}
90+
91+ returnBufioReader (c .br )
92+ returnBufioWriter (c .bw )
93+ }()
94+ }
8095 })
8196}
8297
@@ -94,6 +109,8 @@ func (c *Conn) init() {
94109
95110 c .readData = make (chan header )
96111 c .readDone = make (chan struct {})
112+ c .readDataLock = make (chan struct {}, 1 )
113+ c .readLoopDone = make (chan struct {})
97114
98115 c .setReadTimeout = make (chan context.Context )
99116 c .setWriteTimeout = make (chan context.Context )
@@ -174,8 +191,8 @@ func (c *Conn) timeoutLoop() {
174191 select {
175192 case <- c .closed :
176193 return
177- case readCtx = <- c .setWriteTimeout :
178- case writeCtx = <- c .setReadTimeout :
194+ case writeCtx = <- c .setWriteTimeout :
195+ case readCtx = <- c .setReadTimeout :
179196 case <- readCtx .Done ():
180197 c .close (xerrors .Errorf ("data read timed out: %w" , readCtx .Err ()))
181198 case <- writeCtx .Done ():
@@ -276,6 +293,8 @@ func (c *Conn) readTillData() (header, error) {
276293}
277294
278295func (c * Conn ) readLoop () {
296+ defer close (c .readLoopDone )
297+
279298 for {
280299 h , err := c .readTillData ()
281300 if err != nil {
@@ -487,8 +506,7 @@ func (w *messageWriter) close() error {
487506//
488507// Your application must keep reading messages for the Conn to automatically respond to ping
489508// and close frames and not become stuck waiting for a data message to be read.
490- // Please ensure to read the full message from io.Reader. If you do not read till
491- // io.EOF, the connection will break unless the next read would have yielded io.EOF.
509+ // Please ensure to read the full message from io.Reader.
492510//
493511// You can only read a single message at a time so do not call this method
494512// concurrently.
@@ -500,30 +518,10 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
500518 return typ , r , nil
501519}
502520
503- func (c * Conn ) reader (ctx context.Context ) (MessageType , io.Reader , error ) {
504- // if !atomic.CompareAndSwapInt64(&c.activeReader, 0, 1) {
505- // // If the next read yields io.EOF we are good to go.
506- // r := messageReader{
507- // ctx: ctx,
508- // c: c,
509- // }
510- // _, err := r.Read(nil)
511- // if err == nil {
512- // return 0, nil, xerrors.New("previous message not fully read")
513- // }
514- // if !xerrors.Is(err, io.EOF) {
515- // return 0, nil, xerrors.Errorf("failed to check if last message at io.EOF: %w", err)
516- // }
517- //
518- // atomic.StoreInt64(&c.activeReader, 1)
519- // }
520-
521- select {
522- case <- c .closed :
523- return 0 , nil , c .closeErr
524- case <- ctx .Done ():
525- return 0 , nil , ctx .Err ()
526- case c .setReadTimeout <- ctx :
521+ func (c * Conn ) reader (ctx context.Context ) (_ MessageType , _ io.Reader , err error ) {
522+ err = c .acquireLock (ctx , c .readDataLock )
523+ if err != nil {
524+ return 0 , nil , err
527525 }
528526
529527 select {
@@ -533,25 +531,24 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
533531 return 0 , nil , ctx .Err ()
534532 case h := <- c .readData :
535533 if h .opcode == opContinuation {
536- if h .fin && h .payloadLength == 0 {
537- select {
538- case <- c .closed :
539- return 0 , nil , c .closeErr
540- case c .readDone <- struct {}{}:
541- return c .reader (ctx )
542- }
534+ ce := CloseError {
535+ Code : StatusProtocolError ,
536+ Reason : "continuation frame not after data or text frame" ,
543537 }
544- return 0 , nil , xerrors .Errorf ("previous reader was not read to EOF" )
538+ c .Close (ce .Code , ce .Reason )
539+ return 0 , nil , ce
545540 }
546541 return MessageType (h .opcode ), & messageReader {
547- h : & h ,
548- c : c ,
542+ ctx : ctx ,
543+ h : & h ,
544+ c : c ,
549545 }, nil
550546 }
551547}
552548
553549// messageReader enables reading a data frame from the WebSocket connection.
554550type messageReader struct {
551+ ctx context.Context
555552 maskPos int
556553 h * header
557554 c * Conn
@@ -598,8 +595,20 @@ func (r *messageReader) read(p []byte) (int, error) {
598595 p = p [:r .h .payloadLength ]
599596 }
600597
598+ select {
599+ case <- r .c .closed :
600+ return 0 , r .c .closeErr
601+ case r .c .setReadTimeout <- r .ctx :
602+ }
603+
601604 n , err := io .ReadFull (r .c .br , p )
602605
606+ select {
607+ case <- r .c .closed :
608+ return 0 , r .c .closeErr
609+ case r .c .setReadTimeout <- context .Background ():
610+ }
611+
603612 r .h .payloadLength -= int64 (n )
604613 if r .h .masked {
605614 r .maskPos = fastXOR (r .h .maskKey , r .maskPos , p )
@@ -618,12 +627,8 @@ func (r *messageReader) read(p []byte) (int, error) {
618627 }
619628 if r .h .fin {
620629 r .eofed = true
621- select {
622- case <- r .c .closed :
623- return n , r .c .closeErr
624- case r .c .setReadTimeout <- context .Background ():
625- return n , io .EOF
626- }
630+ r .c .releaseLock (r .c .readDataLock )
631+ return n , io .EOF
627632 }
628633 r .maskPos = 0
629634 r .h = nil
0 commit comments