Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit e33a749

Browse files
mafredriEmyrk
andauthored
fix: Deadlock and race in peer, test improvements (#3086)
* fix: Potential deadlock in peer.Channel dc.OnOpen * fix: Potential send on closed channel * fix: Improve robustness of waitOpened during close * chore: Simplify statements * fix: Improve teardown and timeout of peer tests * fix: Improve robustness of TestConn/Buffering test * Update peer/channel.go Co-authored-by: Steven Masley <[email protected]>
1 parent 62e6856 commit e33a749

File tree

3 files changed

+85
-40
lines changed

3 files changed

+85
-40
lines changed

peer/channel.go

+11-9
Original file line numberDiff line numberDiff line change
@@ -106,12 +106,15 @@ func (c *Channel) init() {
106106
// write operations to block once the threshold is set.
107107
c.dc.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold)
108108
c.dc.OnBufferedAmountLow(func() {
109+
// Grab the lock to protect the sendMore channel from being
110+
// closed in between the isClosed check and the send.
111+
c.closeMutex.Lock()
112+
defer c.closeMutex.Unlock()
109113
if c.isClosed() {
110114
return
111115
}
112116
select {
113117
case <-c.closed:
114-
return
115118
case c.sendMore <- struct{}{}:
116119
default:
117120
}
@@ -122,15 +125,16 @@ func (c *Channel) init() {
122125
})
123126
c.dc.OnOpen(func() {
124127
c.closeMutex.Lock()
125-
defer c.closeMutex.Unlock()
126-
127128
c.conn.logger().Debug(context.Background(), "datachannel opening", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label()))
128129
var err error
129130
c.rwc, err = c.dc.Detach()
130131
if err != nil {
132+
c.closeMutex.Unlock()
131133
_ = c.closeWithError(xerrors.Errorf("detach: %w", err))
132134
return
133135
}
136+
c.closeMutex.Unlock()
137+
134138
// pion/webrtc will return an io.ErrShortBuffer when a read
135139
// is triggerred with a buffer size less than the chunks written.
136140
//
@@ -189,9 +193,6 @@ func (c *Channel) init() {
189193
//
190194
// This will block until the underlying DataChannel has been opened.
191195
func (c *Channel) Read(bytes []byte) (int, error) {
192-
if c.isClosed() {
193-
return 0, c.closeError
194-
}
195196
err := c.waitOpened()
196197
if err != nil {
197198
return 0, err
@@ -228,9 +229,6 @@ func (c *Channel) Write(bytes []byte) (n int, err error) {
228229
c.writeMutex.Lock()
229230
defer c.writeMutex.Unlock()
230231

231-
if c.isClosed() {
232-
return 0, c.closeWithError(nil)
233-
}
234232
err = c.waitOpened()
235233
if err != nil {
236234
return 0, err
@@ -308,6 +306,10 @@ func (c *Channel) isClosed() bool {
308306
func (c *Channel) waitOpened() error {
309307
select {
310308
case <-c.opened:
309+
// Re-check the closed channel to prioritize closure.
310+
if c.isClosed() {
311+
return c.closeError
312+
}
311313
return nil
312314
case <-c.closed:
313315
return c.closeError

peer/conn.go

-6
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package peer
33
import (
44
"bytes"
55
"context"
6-
76
"crypto/rand"
87
"io"
98
"sync"
@@ -256,7 +255,6 @@ func (c *Conn) init() error {
256255
c.logger().Debug(context.Background(), "sending local candidate", slog.F("candidate", iceCandidate.ToJSON().Candidate))
257256
select {
258257
case <-c.closed:
259-
break
260258
case c.localCandidateChannel <- iceCandidate.ToJSON():
261259
}
262260
}()
@@ -265,7 +263,6 @@ func (c *Conn) init() error {
265263
go func() {
266264
select {
267265
case <-c.closed:
268-
return
269266
case c.dcOpenChannel <- dc:
270267
}
271268
}()
@@ -435,9 +432,6 @@ func (c *Conn) pingEchoChannel() (*Channel, error) {
435432
data := make([]byte, pingDataLength)
436433
bytesRead, err := c.pingEchoChan.Read(data)
437434
if err != nil {
438-
if c.isClosed() {
439-
return
440-
}
441435
_ = c.CloseWithError(xerrors.Errorf("read ping echo channel: %w", err))
442436
return
443437
}

peer/conn_test.go

+74-25
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ func TestConn(t *testing.T) {
9191
// Create a channel that closes on disconnect.
9292
channel, err := server.CreateChannel(context.Background(), "wow", nil)
9393
assert.NoError(t, err)
94+
defer channel.Close()
95+
9496
err = wan.Stop()
9597
require.NoError(t, err)
9698
// Once the connection is marked as disconnected, this
@@ -107,10 +109,13 @@ func TestConn(t *testing.T) {
107109
t.Parallel()
108110
client, server, _ := createPair(t)
109111
exchange(t, client, server)
110-
cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{})
112+
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
113+
defer cancel()
114+
cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{})
111115
require.NoError(t, err)
116+
defer cch.Close()
112117

113-
sch, err := server.Accept(context.Background())
118+
sch, err := server.Accept(ctx)
114119
require.NoError(t, err)
115120
defer sch.Close()
116121

@@ -123,9 +128,12 @@ func TestConn(t *testing.T) {
123128
t.Parallel()
124129
client, server, wan := createPair(t)
125130
exchange(t, client, server)
126-
cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{})
131+
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
132+
defer cancel()
133+
cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{})
127134
require.NoError(t, err)
128-
sch, err := server.Accept(context.Background())
135+
defer cch.Close()
136+
sch, err := server.Accept(ctx)
129137
require.NoError(t, err)
130138
defer sch.Close()
131139

@@ -140,26 +148,44 @@ func TestConn(t *testing.T) {
140148
t.Parallel()
141149
client, server, _ := createPair(t)
142150
exchange(t, client, server)
143-
cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{})
144-
require.NoError(t, err)
145-
sch, err := server.Accept(context.Background())
151+
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
152+
defer cancel()
153+
cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{})
146154
require.NoError(t, err)
147-
defer sch.Close()
155+
defer cch.Close()
156+
157+
readErr := make(chan error, 1)
148158
go func() {
159+
sch, err := server.Accept(ctx)
160+
if err != nil {
161+
readErr <- err
162+
_ = cch.Close()
163+
return
164+
}
165+
defer sch.Close()
166+
149167
bytes := make([]byte, 4096)
150-
for i := 0; i < 1024; i++ {
151-
_, err := cch.Write(bytes)
152-
require.NoError(t, err)
168+
for {
169+
_, err = sch.Read(bytes)
170+
if err != nil {
171+
readErr <- err
172+
return
173+
}
153174
}
154-
_ = cch.Close()
155175
}()
176+
156177
bytes := make([]byte, 4096)
157-
for {
158-
_, err = sch.Read(bytes)
159-
if err != nil {
160-
require.ErrorIs(t, err, peer.ErrClosed)
161-
break
162-
}
178+
for i := 0; i < 1024; i++ {
179+
_, err = cch.Write(bytes)
180+
require.NoError(t, err, "write i=%d", i)
181+
}
182+
_ = cch.Close()
183+
184+
select {
185+
case err = <-readErr:
186+
require.ErrorIs(t, err, peer.ErrClosed, "read error")
187+
case <-ctx.Done():
188+
require.Fail(t, "timeout waiting for read error")
163189
}
164190
})
165191

@@ -170,13 +196,29 @@ func TestConn(t *testing.T) {
170196
srv, err := net.Listen("tcp", "127.0.0.1:0")
171197
require.NoError(t, err)
172198
defer srv.Close()
199+
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
200+
defer cancel()
173201
go func() {
174-
sch, err := server.Accept(context.Background())
175-
assert.NoError(t, err)
202+
sch, err := server.Accept(ctx)
203+
if err != nil {
204+
assert.NoError(t, err)
205+
return
206+
}
207+
defer sch.Close()
208+
176209
nc2 := sch.NetConn()
210+
defer nc2.Close()
211+
177212
nc1, err := net.Dial("tcp", srv.Addr().String())
178-
assert.NoError(t, err)
213+
if err != nil {
214+
assert.NoError(t, err)
215+
return
216+
}
217+
defer nc1.Close()
218+
179219
go func() {
220+
defer nc1.Close()
221+
defer nc2.Close()
180222
_, _ = io.Copy(nc1, nc2)
181223
}()
182224
_, _ = io.Copy(nc2, nc1)
@@ -204,7 +246,7 @@ func TestConn(t *testing.T) {
204246
c := http.Client{
205247
Transport: defaultTransport,
206248
}
207-
req, err := http.NewRequestWithContext(context.Background(), "GET", "http://localhost/", nil)
249+
req, err := http.NewRequestWithContext(ctx, "GET", "http://localhost/", nil)
208250
require.NoError(t, err)
209251
resp, err := c.Do(req)
210252
require.NoError(t, err)
@@ -272,14 +314,21 @@ func TestConn(t *testing.T) {
272314
t.Parallel()
273315
client, server, _ := createPair(t)
274316
exchange(t, client, server)
317+
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
318+
defer cancel()
275319
go func() {
276-
channel, err := client.CreateChannel(context.Background(), "test", nil)
277-
assert.NoError(t, err)
320+
channel, err := client.CreateChannel(ctx, "test", nil)
321+
if err != nil {
322+
assert.NoError(t, err)
323+
return
324+
}
325+
defer channel.Close()
278326
_, err = channel.Write([]byte{1, 2})
279327
assert.NoError(t, err)
280328
}()
281-
channel, err := server.Accept(context.Background())
329+
channel, err := server.Accept(ctx)
282330
require.NoError(t, err)
331+
defer channel.Close()
283332
data := make([]byte, 1)
284333
_, err = channel.Read(data)
285334
require.NoError(t, err)

0 commit comments

Comments
 (0)