Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 40 additions & 17 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,10 @@ type Conn struct {
readLock chan struct{}

// messageReader state.
readerMsgCtx context.Context
readerMsgHeader header
readerFrameEOF bool
readerMaskPos int
readerShouldLock bool
readerMsgCtx context.Context
readerMsgHeader header
readerFrameEOF bool
readerMaskPos int

setReadTimeout chan context.Context
setWriteTimeout chan context.Context
Expand Down Expand Up @@ -237,6 +236,10 @@ func (c *Conn) readTillMsg(ctx context.Context) (header, error) {
if h.opcode.controlOp() {
err = c.handleControl(ctx, h)
if err != nil {
// Pass through CloseErrors when receiving a close frame.
if h.opcode == opClose && CloseStatus(err) != -1 {
return header{}, err
}
return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err)
}
continue
Expand Down Expand Up @@ -445,7 +448,6 @@ func (c *Conn) reader(ctx context.Context, lock bool) (MessageType, io.Reader, e
c.readerFrameEOF = false
c.readerMaskPos = 0
c.readMsgLeft = c.msgReadLimit.Load()
c.readerShouldLock = lock

r := &messageReader{
c: c,
Expand All @@ -465,7 +467,11 @@ func (r *messageReader) eof() bool {

// Read reads as many bytes as possible into p.
func (r *messageReader) Read(p []byte) (int, error) {
n, err := r.read(p)
return r.exportedRead(p, true)
}

func (r *messageReader) exportedRead(p []byte, lock bool) (int, error) {
n, err := r.read(p, lock)
if err != nil {
// Have to return io.EOF directly for now, we cannot wrap as errors.Is
// isn't used widely yet.
Expand All @@ -477,17 +483,29 @@ func (r *messageReader) Read(p []byte) (int, error) {
return n, nil
}

func (r *messageReader) read(p []byte) (int, error) {
if r.c.readerShouldLock {
err := r.c.acquireLock(r.c.readerMsgCtx, r.c.readLock)
if err != nil {
return 0, err
func (r *messageReader) readUnlocked(p []byte) (int, error) {
return r.exportedRead(p, false)
}

func (r *messageReader) read(p []byte, lock bool) (int, error) {
if lock {
// If we cannot acquire the read lock, then
// there is either a concurrent read or the close handshake
// is proceeding.
select {
case r.c.readLock <- struct{}{}:
defer r.c.releaseLock(r.c.readLock)
default:
if r.c.closing.Load() == 1 {
<-r.c.closed
return 0, r.c.closeErr
}
return 0, errors.New("concurrent read detected")
}
defer r.c.releaseLock(r.c.readLock)
}

if r.eof() {
return 0, fmt.Errorf("cannot use EOFed reader")
return 0, errors.New("cannot use EOFed reader")
}

if r.c.readMsgLeft <= 0 {
Expand Down Expand Up @@ -950,8 +968,6 @@ func (c *Conn) waitClose() error {
return c.closeReceived
}

c.readerShouldLock = false

b := bpool.Get()
buf := b.Bytes()
buf = buf[:cap(buf)]
Expand All @@ -965,7 +981,8 @@ func (c *Conn) waitClose() error {
}
}

_, err = io.CopyBuffer(ioutil.Discard, c.activeReader, buf)
r := readerFunc(c.activeReader.readUnlocked)
_, err = io.CopyBuffer(ioutil.Discard, r, buf)
if err != nil {
return err
}
Expand Down Expand Up @@ -1019,6 +1036,12 @@ func (c *Conn) ping(ctx context.Context, p string) error {
}
}

type readerFunc func(p []byte) (int, error)

func (f readerFunc) Read(p []byte) (int, error) {
return f(p)
}

type writerFunc func(p []byte) (int, error)

func (f writerFunc) Write(p []byte) (int, error) {
Expand Down