Thanks to visit codestin.com
Credit goes to github.com

Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 175c42e

Browse files
authored
fix: Only close connection scoped data channels on RTC close (#403)
* fix: Only close connection scoped data channels on RTC close * Fix array pointer * Fix test * Remove old test
1 parent d673079 commit 175c42e

File tree

2 files changed

+53
-32
lines changed

2 files changed

+53
-32
lines changed

wsnet/dial_test.go

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -293,39 +293,55 @@ func TestDial(t *testing.T) {
293293

294294
t.Run("Close Listeners on Disconnect", func(t *testing.T) {
295295
t.Parallel()
296+
log := slogtest.Make(t, nil)
296297

297-
tcpListener, err := net.Listen("tcp", "0.0.0.0:0")
298+
listener, err := net.Listen("tcp", "0.0.0.0:0")
298299
require.NoError(t, err)
299300
go func() {
300-
_, _ = tcpListener.Accept()
301+
for {
302+
c, _ := listener.Accept()
303+
304+
go func() {
305+
b := make([]byte, 5)
306+
_, err := c.Read(b)
307+
if err != nil {
308+
return
309+
}
310+
_, err = c.Write(b)
311+
require.NoError(t, err)
312+
}()
313+
}
301314
}()
302-
303315
connectAddr, listenAddr := createDumbBroker(t)
304-
l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "")
316+
_, err = Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "")
305317
require.NoError(t, err)
306318

307-
turnAddr, closeTurn := createTURNServer(t, ice.SchemeTypeTURN)
308-
dialer, err := DialWebsocket(context.Background(), connectAddr, &DialOptions{
309-
ICEServers: []webrtc.ICEServer{{
310-
URLs: []string{fmt.Sprintf("turn:%s", turnAddr)},
311-
Username: "example",
312-
Credential: testPass,
313-
CredentialType: webrtc.ICECredentialTypePassword,
314-
}},
319+
d1, err := DialWebsocket(context.Background(), connectAddr, &DialOptions{
320+
Log: &log,
315321
}, nil)
316322
require.NoError(t, err)
323+
_, err = d1.DialContext(context.Background(), listener.Addr().Network(), listener.Addr().String())
324+
require.NoError(t, err)
317325

318-
_, err = dialer.DialContext(context.Background(), "tcp", tcpListener.Addr().String())
326+
d2, err := DialWebsocket(context.Background(), connectAddr, &DialOptions{
327+
Log: &log,
328+
}, nil)
329+
require.NoError(t, err)
330+
conn, err := d2.DialContext(context.Background(), listener.Addr().Network(), listener.Addr().String())
331+
require.NoError(t, err)
332+
err = d1.Close()
319333
require.NoError(t, err)
320334

321-
closeTurn()
335+
// TODO: This needs to be longer than the KeepAlive timeout for the RTC connection.
336+
// Once the listener stores RTC connections instead of io.Closer we can directly
337+
// reference the RTC connection to ensure it's properly closed.
338+
time.Sleep(time.Second * 10)
322339

323-
list := l.(*listener)
324-
assert.Eventually(t, func() bool {
325-
list.connClosersMut.Lock()
326-
defer list.connClosersMut.Unlock()
327-
return len(list.connClosers) == 0
328-
}, time.Second*15, time.Millisecond*100)
340+
b := []byte("hello")
341+
_, err = conn.Write(b)
342+
require.NoError(t, err)
343+
_, err = conn.Read(b)
344+
require.NoError(t, err)
329345
})
330346
}
331347

wsnet/listen.go

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,11 @@ func (l *listener) negotiate(ctx context.Context, conn net.Conn) {
163163
ctx = slog.With(ctx, slog.F("conn_id", id))
164164

165165
var (
166-
err error
167-
decoder = json.NewDecoder(conn)
168-
rtc *webrtc.PeerConnection
166+
err error
167+
decoder = json.NewDecoder(conn)
168+
rtc *webrtc.PeerConnection
169+
connClosers = make([]io.Closer, 0)
170+
connClosersMut sync.Mutex
169171
// If candidates are sent before an offer, we place them here.
170172
// We currently have no assurances to ensure this can't happen,
171173
// so it's better to buffer and process than fail.
@@ -255,6 +257,9 @@ func (l *listener) negotiate(ctx context.Context, conn net.Conn) {
255257
closeError(err)
256258
return
257259
}
260+
l.connClosersMut.Lock()
261+
l.connClosers = append(l.connClosers, rtc)
262+
l.connClosersMut.Unlock()
258263
rtc.OnConnectionStateChange(func(pcs webrtc.PeerConnectionState) {
259264
l.log.Info(ctx, "connection state change", slog.F("state", pcs.String()))
260265
switch pcs {
@@ -267,16 +272,16 @@ func (l *listener) negotiate(ctx context.Context, conn net.Conn) {
267272
}
268273

269274
// Close connections opened when RTC was alive.
270-
l.connClosersMut.Lock()
271-
defer l.connClosersMut.Unlock()
272-
for _, connCloser := range l.connClosers {
275+
connClosersMut.Lock()
276+
defer connClosersMut.Unlock()
277+
for _, connCloser := range connClosers {
273278
_ = connCloser.Close()
274279
}
275-
l.connClosers = make([]io.Closer, 0)
280+
connClosers = make([]io.Closer, 0)
276281
})
277282

278283
flushCandidates := proxyICECandidates(rtc, conn)
279-
rtc.OnDataChannel(l.handle(ctx, msg))
284+
rtc.OnDataChannel(l.handle(ctx, msg, &connClosers, &connClosersMut))
280285

281286
l.log.Debug(ctx, "set remote description", slog.F("offer", *msg.Offer))
282287
err = rtc.SetRemoteDescription(*msg.Offer)
@@ -329,7 +334,7 @@ func (l *listener) negotiate(ctx context.Context, conn net.Conn) {
329334
}
330335

331336
// nolint:gocognit
332-
func (l *listener) handle(ctx context.Context, msg BrokerMessage) func(dc *webrtc.DataChannel) {
337+
func (l *listener) handle(ctx context.Context, msg BrokerMessage, connClosers *[]io.Closer, connClosersMut *sync.Mutex) func(dc *webrtc.DataChannel) {
333338
return func(dc *webrtc.DataChannel) {
334339
if dc.Protocol() == controlChannel {
335340
// The control channel handles pings.
@@ -430,9 +435,9 @@ func (l *listener) handle(ctx context.Context, msg BrokerMessage) func(dc *webrt
430435
dc: dc,
431436
rw: rw,
432437
}
433-
l.connClosersMut.Lock()
434-
l.connClosers = append(l.connClosers, co)
435-
l.connClosersMut.Unlock()
438+
connClosersMut.Lock()
439+
*connClosers = append(*connClosers, co)
440+
connClosersMut.Unlock()
436441
co.init()
437442
defer nc.Close()
438443
defer co.Close()

0 commit comments

Comments
 (0)