fix: Deadlock and race in peer, test improvements #3086

Merged 8 commits, Jul 21, 2022
20 changes: 11 additions & 9 deletions peer/channel.go
@@ -106,12 +106,15 @@ func (c *Channel) init() {
	// write operations to block once the threshold is set.
	c.dc.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold)
	c.dc.OnBufferedAmountLow(func() {
		// Grab the lock to protect the sendMore channel from being
		// closed in between the isClosed check and the send.
		c.closeMutex.Lock()
		defer c.closeMutex.Unlock()
		if c.isClosed() {
			return
		}
		select {
		case <-c.closed:
Comment on lines 113 to 117
Member

Why do we need this check? Is it a problem if sendMore gets a value in the buffered channel? I don't think it matters which one is prioritized. Either one will exit the select statement and "return". If sendMore is sent to, a write will happen. But if you close the connection, a write will also happen because the <-c.sendMore gets unblocked.

if c.isClosed() {
	return
}

Idk, just seems redundant.

Grabbing the closed lock doesn't seem necessary because of the same argument.

Member Author

Both (check and mutex) are needed because c.sendMore is closed in c.closeWithError, so we want to guard against potentially sending on a closed channel (which would panic).

By holding the mutex, we ensure that closure doesn't happen between the isClosed() check and send on c.sendMore.
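
To make the race concrete, here is a minimal, self-contained sketch of the guard pattern; guardedSender, signal, and close are hypothetical stand-ins for the corresponding pieces of peer.Channel:

package main

import "sync"

// guardedSender models the relevant fields of peer.Channel. close()
// stands in for closeWithError, which closes sendMore, so signal()
// must never race with it.
type guardedSender struct {
	closeMutex sync.Mutex
	closed     chan struct{}
	sendMore   chan struct{}
}

func newGuardedSender() *guardedSender {
	return &guardedSender{
		closed:   make(chan struct{}),
		sendMore: make(chan struct{}, 1),
	}
}

func (g *guardedSender) isClosed() bool {
	select {
	case <-g.closed:
		return true
	default:
		return false
	}
}

func (g *guardedSender) close() {
	g.closeMutex.Lock()
	defer g.closeMutex.Unlock()
	close(g.closed)
	close(g.sendMore) // an unguarded concurrent send would panic here
}

// signal plays the role of the OnBufferedAmountLow callback. Holding
// closeMutex ensures close() cannot close sendMore between the
// isClosed check and the send, so the send never hits a closed channel.
func (g *guardedSender) signal() {
	g.closeMutex.Lock()
	defer g.closeMutex.Unlock()
	if g.isClosed() {
		return
	}
	select {
	case g.sendMore <- struct{}{}:
	default: // a wakeup is already pending; dropping this one is fine
	}
}

func main() {
	g := newGuardedSender()
	var wg sync.WaitGroup
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			g.signal()
		}()
	}
	g.close()
	wg.Wait() // never panics: every signal either sends safely or bails out
}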

Member Author

Oh btw, I was also at first under the impression that the check might not be needed, but without it the pion/sctp library may try to unlock an unlocked mutex:

fatal error: sync: Unlock of unlocked RWMutex

goroutine 447 [running]:
runtime.throw({0xe70b28?, 0x0?})
	/usr/local/go/src/runtime/panic.go:992 +0x71 fp=0xc0001d26c8 sp=0xc0001d2698 pc=0x469d71
sync.throw({0xe70b28?, 0x418f30?})
	/usr/local/go/src/runtime/panic.go:978 +0x1e fp=0xc0001d26e8 sp=0xc0001d26c8 pc=0x499abe
sync.(*RWMutex).Unlock(0xc000158550)
	/usr/local/go/src/sync/rwmutex.go:201 +0x7e fp=0xc0001d2728 sp=0xc0001d26e8 pc=0x4afbfe
github.com/pion/sctp.(*Association).handleChunk.func1()
	/home/maf/go/pkg/mod/github.com/pion/[email protected]/association.go:2244 +0x3a fp=0xc0001d2748 sp=0xc0001d2728 pc=0x77e4da
panic({0xdb16e0, 0xfd2780})
	/usr/local/go/src/runtime/panic.go:844 +0x258 fp=0xc0001d2808 sp=0xc0001d2748 pc=0x4697d8
runtime.selectgo(0xc0001d29c8, 0xc0001d2998, 0x47ad94?, 0x1, 0x7fc6abe90008?, 0x0)
	/usr/local/go/src/runtime/select.go:516 +0xf3c fp=0xc0001d2968 sp=0xc0001d2808 pc=0x47dfbc
github.com/coder/coder/peer.(*Channel).init.func1()
	/home/maf/src/coder/peer/channel.go:109 +0xd0 fp=0xc0001d29f8 sp=0xc0001d2968 pc=0xb13710
github.com/pion/sctp.(*Stream).onBufferReleased(0xc00020e500, 0x16a0)
	/home/maf/go/pkg/mod/github.com/pion/[email protected]/stream.go:357 +0x4af fp=0xc0001d2a80 sp=0xc0001d29f8 pc=0x79faef
github.com/pion/sctp.(*Association).handleSack(0xc000158540, 0xc000c32000)
	/home/maf/go/pkg/mod/github.com/pion/[email protected]/association.go:1623 +0x9bb fp=0xc0001d2c28 sp=0xc0001d2a80 pc=0x776a1b
github.com/pion/sctp.(*Association).handleChunk(0xc000158540, 0xc000e02d60?, {0xfd88e8?, 0xc000c32000?})
	/home/maf/go/pkg/mod/github.com/pion/[email protected]/association.go:2288 +0x30d fp=0xc0001d2df0 sp=0xc0001d2c28 pc=0x77d88d
github.com/pion/sctp.(*Association).handleInbound(0xc000158540, {0xc000e02d60, 0x1c, 0x1c})
	/home/maf/go/pkg/mod/github.com/pion/[email protected]/association.go:603 +0x505 fp=0xc0001d2ec0 sp=0xc0001d2df0 pc=0x769585
github.com/pion/sctp.(*Association).readLoop(0xc000158540)
	/home/maf/go/pkg/mod/github.com/pion/[email protected]/association.go:521 +0x29c fp=0xc0001d2fc0 sp=0xc0001d2ec0 pc=0x76783c
github.com/pion/sctp.(*Association).init.func2()
	/home/maf/go/pkg/mod/github.com/pion/[email protected]/association.go:339 +0x3a fp=0xc0001d2fe0 sp=0xc0001d2fc0 pc=0x765a1a
runtime.goexit()
	/usr/local/go/src/runtime/asm_amd64.s:1571 +0x1 fp=0xc0001d2fe8 sp=0xc0001d2fe0 pc=0x49edc1
created by github.com/pion/sctp.(*Association).init
	/home/maf/go/pkg/mod/github.com/pion/[email protected]/association.go:339 +0x12a

Member

Will a select send something on a closed channel? If so TIL.

👍

Member Author

Yup, if you try to run this playground a couple of times, you will notice that sometimes it exits 0, and sometimes panics: https://go.dev/play/p/c35kE0948kl
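
For reference, a minimal program with the same behavior as described (an approximation of the playground, not necessarily its exact code):

package main

func main() {
	ch := make(chan struct{})
	close(ch)

	// On a closed channel both cases are selectable: the receive is
	// ready and yields the zero value, while the send proceeds only to
	// panic with "send on closed channel". select picks uniformly at
	// random among ready cases, so this sometimes exits 0 and
	// sometimes panics.
	select {
	case <-ch:
	case ch <- struct{}{}:
	}
}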

			return
		case c.sendMore <- struct{}{}:
		default:
		}
@@ -122,15 +125,16 @@ func (c *Channel) init() {
	})
	c.dc.OnOpen(func() {
		c.closeMutex.Lock()
		defer c.closeMutex.Unlock()

		c.conn.logger().Debug(context.Background(), "datachannel opening", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label()))
		var err error
		c.rwc, err = c.dc.Detach()
		if err != nil {
			c.closeMutex.Unlock()
			_ = c.closeWithError(xerrors.Errorf("detach: %w", err))
			return
		}
		c.closeMutex.Unlock()

		// pion/webrtc will return an io.ErrShortBuffer when a read
		// is triggered with a buffer size less than the chunks written.
		//
@@ -189,9 +193,6 @@ func (c *Channel) init() {
//
// This will block until the underlying DataChannel has been opened.
func (c *Channel) Read(bytes []byte) (int, error) {
	if c.isClosed() {
		return 0, c.closeError
	}
	err := c.waitOpened()
	if err != nil {
		return 0, err
@@ -228,9 +229,6 @@ func (c *Channel) Write(bytes []byte) (n int, err error) {
	c.writeMutex.Lock()
	defer c.writeMutex.Unlock()

	if c.isClosed() {
		return 0, c.closeWithError(nil)
	}
	err = c.waitOpened()
	if err != nil {
		return 0, err
@@ -308,6 +306,10 @@ func (c *Channel) isClosed() bool {
func (c *Channel) waitOpened() error {
	select {
	case <-c.opened:
		// Re-check the closed channel to prioritize closure.
		if c.isClosed() {
			return c.closeError
		}
		return nil
	case <-c.closed:
		return c.closeError
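The re-check added to waitOpened exists because a Go select chooses pseudo-randomly when several cases are ready: if the channel is opened and then closed before waitOpened runs, both c.opened and c.closed are ready and either case may win. A sketch of the prioritization pattern, with a hypothetical wait function standing in for Channel.waitOpened:

// wait mirrors the shape of Channel.waitOpened: both arguments are
// signal channels, and closure must win even when the outer select
// happens to pick the opened case.
func wait(opened, closed <-chan struct{}, closeErr error) error {
	select {
	case <-opened:
		// opened and closed may be ready at the same time; re-check
		// closed so a caller never treats a closed channel as open.
		select {
		case <-closed:
			return closeErr
		default:
			return nil
		}
	case <-closed:
		return closeErr
	}
}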
6 changes: 0 additions & 6 deletions peer/conn.go
@@ -3,7 +3,6 @@ package peer
import (
	"bytes"
	"context"

	"crypto/rand"
	"io"
	"sync"
@@ -256,7 +255,6 @@ func (c *Conn) init() error {
			c.logger().Debug(context.Background(), "sending local candidate", slog.F("candidate", iceCandidate.ToJSON().Candidate))
			select {
			case <-c.closed:
				break
			case c.localCandidateChannel <- iceCandidate.ToJSON():
			}
		}()
@@ -265,7 +263,6 @@ func (c *Conn) init() error {
		go func() {
			select {
			case <-c.closed:
				return
			case c.dcOpenChannel <- dc:
			}
		}()
@@ -435,9 +432,6 @@ func (c *Conn) pingEchoChannel() (*Channel, error) {
		data := make([]byte, pingDataLength)
		bytesRead, err := c.pingEchoChan.Read(data)
		if err != nil {
			if c.isClosed() {
				return
			}
			_ = c.CloseWithError(xerrors.Errorf("read ping echo channel: %w", err))
			return
		}
99 changes: 74 additions & 25 deletions peer/conn_test.go
@@ -91,6 +91,8 @@ func TestConn(t *testing.T) {
		// Create a channel that closes on disconnect.
		channel, err := server.CreateChannel(context.Background(), "wow", nil)
		assert.NoError(t, err)
		defer channel.Close()

		err = wan.Stop()
		require.NoError(t, err)
		// Once the connection is marked as disconnected, this
@@ -107,10 +109,13 @@ func TestConn(t *testing.T) {
		t.Parallel()
		client, server, _ := createPair(t)
		exchange(t, client, server)
		cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{})
		ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
		defer cancel()
		cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{})
		require.NoError(t, err)
		defer cch.Close()

		sch, err := server.Accept(context.Background())
		sch, err := server.Accept(ctx)
		require.NoError(t, err)
		defer sch.Close()

@@ -123,9 +128,12 @@ func TestConn(t *testing.T) {
		t.Parallel()
		client, server, wan := createPair(t)
		exchange(t, client, server)
		cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{})
		ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
		defer cancel()
		cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{})
		require.NoError(t, err)
		sch, err := server.Accept(context.Background())
		defer cch.Close()
		sch, err := server.Accept(ctx)
		require.NoError(t, err)
		defer sch.Close()

@@ -140,26 +148,44 @@ func TestConn(t *testing.T) {
		t.Parallel()
		client, server, _ := createPair(t)
		exchange(t, client, server)
		cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{})
		require.NoError(t, err)
		sch, err := server.Accept(context.Background())
		ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
		defer cancel()
		cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{})
		require.NoError(t, err)
		defer sch.Close()
		defer cch.Close()

		readErr := make(chan error, 1)
		go func() {
			sch, err := server.Accept(ctx)
			if err != nil {
				readErr <- err
				_ = cch.Close()
				return
			}
			defer sch.Close()

			bytes := make([]byte, 4096)
			for i := 0; i < 1024; i++ {
				_, err := cch.Write(bytes)
				require.NoError(t, err)
			for {
				_, err = sch.Read(bytes)
				if err != nil {
					readErr <- err
					return
				}
			}
			_ = cch.Close()
		}()

		bytes := make([]byte, 4096)
		for {
			_, err = sch.Read(bytes)
			if err != nil {
				require.ErrorIs(t, err, peer.ErrClosed)
				break
			}
		for i := 0; i < 1024; i++ {
			_, err = cch.Write(bytes)
			require.NoError(t, err, "write i=%d", i)
		}
		_ = cch.Close()

		select {
		case err = <-readErr:
			require.ErrorIs(t, err, peer.ErrClosed, "read error")
		case <-ctx.Done():
			require.Fail(t, "timeout waiting for read error")
		}
	})

@@ -170,13 +196,29 @@ func TestConn(t *testing.T) {
		srv, err := net.Listen("tcp", "127.0.0.1:0")
		require.NoError(t, err)
		defer srv.Close()
		ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
		defer cancel()
		go func() {
			sch, err := server.Accept(context.Background())
			assert.NoError(t, err)
			sch, err := server.Accept(ctx)
			if err != nil {
				assert.NoError(t, err)
				return
			}
			defer sch.Close()

			nc2 := sch.NetConn()
			defer nc2.Close()

			nc1, err := net.Dial("tcp", srv.Addr().String())
			assert.NoError(t, err)
			if err != nil {
				assert.NoError(t, err)
				return
			}
			defer nc1.Close()

			go func() {
				defer nc1.Close()
				defer nc2.Close()
				_, _ = io.Copy(nc1, nc2)
			}()
			_, _ = io.Copy(nc2, nc1)
@@ -204,7 +246,7 @@ func TestConn(t *testing.T) {
		c := http.Client{
			Transport: defaultTransport,
		}
		req, err := http.NewRequestWithContext(context.Background(), "GET", "http://localhost/", nil)
		req, err := http.NewRequestWithContext(ctx, "GET", "http://localhost/", nil)
		require.NoError(t, err)
		resp, err := c.Do(req)
		require.NoError(t, err)
@@ -272,14 +314,21 @@ func TestConn(t *testing.T) {
		t.Parallel()
		client, server, _ := createPair(t)
		exchange(t, client, server)
		ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
		defer cancel()
		go func() {
			channel, err := client.CreateChannel(context.Background(), "test", nil)
			assert.NoError(t, err)
			channel, err := client.CreateChannel(ctx, "test", nil)
			if err != nil {
				assert.NoError(t, err)
				return
			}
			defer channel.Close()
			_, err = channel.Write([]byte{1, 2})
			assert.NoError(t, err)
		}()
		channel, err := server.Accept(context.Background())
		channel, err := server.Accept(ctx)
		require.NoError(t, err)
		defer channel.Close()
		data := make([]byte, 1)
		_, err = channel.Read(data)
		require.NoError(t, err)