Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Fix race with c.readerShouldLock
Closes #168
  • Loading branch information
nhooyr committed Nov 4, 2019
commit 780bda4159cd001ed4e1704327c1292a1d21336d
53 changes: 36 additions & 17 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,10 @@ type Conn struct {
readLock chan struct{}

// messageReader state.
readerMsgCtx context.Context
readerMsgHeader header
readerFrameEOF bool
readerMaskPos int
readerShouldLock bool
readerMsgCtx context.Context
readerMsgHeader header
readerFrameEOF bool
readerMaskPos int

setReadTimeout chan context.Context
setWriteTimeout chan context.Context
Expand Down Expand Up @@ -445,7 +444,6 @@ func (c *Conn) reader(ctx context.Context, lock bool) (MessageType, io.Reader, e
c.readerFrameEOF = false
c.readerMaskPos = 0
c.readMsgLeft = c.msgReadLimit.Load()
c.readerShouldLock = lock

r := &messageReader{
c: c,
Expand All @@ -465,7 +463,11 @@ func (r *messageReader) eof() bool {

// Read reads as many bytes as possible into p.
func (r *messageReader) Read(p []byte) (int, error) {
n, err := r.read(p)
return r.exportedRead(p, true)
}

func (r *messageReader) exportedRead(p []byte, lock bool) (int, error) {
n, err := r.read(p, lock)
if err != nil {
// Have to return io.EOF directly for now, we cannot wrap as errors.Is
// isn't used widely yet.
Expand All @@ -477,17 +479,29 @@ func (r *messageReader) Read(p []byte) (int, error) {
return n, nil
}

func (r *messageReader) read(p []byte) (int, error) {
if r.c.readerShouldLock {
err := r.c.acquireLock(r.c.readerMsgCtx, r.c.readLock)
if err != nil {
return 0, err
func (r *messageReader) readUnlocked(p []byte) (int, error) {
return r.exportedRead(p, false)
}

func (r *messageReader) read(p []byte, lock bool) (int, error) {
if lock {
// If we cannot acquire the read lock, then
// there is either a concurrent read or the close handshake
// is proceeding.
select {
case r.c.readLock <- struct{}{}:
defer r.c.releaseLock(r.c.readLock)
default:
if r.c.closing.Load() == 1 {
<-r.c.closed
return 0, r.c.closeErr
}
return 0, errors.New("concurrent read detected")
}
defer r.c.releaseLock(r.c.readLock)
}

if r.eof() {
return 0, fmt.Errorf("cannot use EOFed reader")
return 0, errors.New("cannot use EOFed reader")
}

if r.c.readMsgLeft <= 0 {
Expand Down Expand Up @@ -950,8 +964,6 @@ func (c *Conn) waitClose() error {
return c.closeReceived
}

c.readerShouldLock = false

b := bpool.Get()
buf := b.Bytes()
buf = buf[:cap(buf)]
Expand All @@ -965,7 +977,8 @@ func (c *Conn) waitClose() error {
}
}

_, err = io.CopyBuffer(ioutil.Discard, c.activeReader, buf)
r := readerFunc(c.activeReader.readUnlocked)
_, err = io.CopyBuffer(ioutil.Discard, r, buf)
if err != nil {
return err
}
Expand Down Expand Up @@ -1019,6 +1032,12 @@ func (c *Conn) ping(ctx context.Context, p string) error {
}
}

type readerFunc func(p []byte) (int, error)

func (f readerFunc) Read(p []byte) (int, error) {
return f(p)
}

type writerFunc func(p []byte) (int, error)

func (f writerFunc) Write(p []byte) (int, error) {
Expand Down