diff --git a/ci/image/Dockerfile b/ci/image/Dockerfile index 08dde84b..b77bf8aa 100644 --- a/ci/image/Dockerfile +++ b/ci/image/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.16.3 +FROM golang:1.16.5 ENV GOFLAGS="-mod=readonly" ENV CI=true diff --git a/wsnet/conn.go b/wsnet/conn.go index de67c3c4..40fa50ae 100644 --- a/wsnet/conn.go +++ b/wsnet/conn.go @@ -129,6 +129,8 @@ type dataChannelConn struct { } func (c *dataChannelConn) init() { + c.closedMutex.Lock() + defer c.closedMutex.Unlock() c.sendMore = make(chan struct{}, 1) c.dc.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold) c.dc.OnBufferedAmountLow(func() { diff --git a/wsnet/dial_test.go b/wsnet/dial_test.go index 5d75cfd4..79e7ae39 100644 --- a/wsnet/dial_test.go +++ b/wsnet/dial_test.go @@ -266,6 +266,43 @@ func TestDial(t *testing.T) { _ = conn.Close() assert.Equal(t, 1, dialer.activeConnections()) }) + + t.Run("Close Listeners on Disconnect", func(t *testing.T) { + t.Parallel() + + tcpListener, err := net.Listen("tcp", "0.0.0.0:0") + require.NoError(t, err) + go func() { + _, _ = tcpListener.Accept() + }() + + connectAddr, listenAddr := createDumbBroker(t) + l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") + require.NoError(t, err) + + turnAddr, closeTurn := createTURNServer(t, ice.SchemeTypeTURN) + dialer, err := DialWebsocket(context.Background(), connectAddr, &DialOptions{ + ICEServers: []webrtc.ICEServer{{ + URLs: []string{fmt.Sprintf("turn:%s", turnAddr)}, + Username: "example", + Credential: testPass, + CredentialType: webrtc.ICECredentialTypePassword, + }}, + }, nil) + require.NoError(t, err) + + _, err = dialer.DialContext(context.Background(), "tcp", tcpListener.Addr().String()) + require.NoError(t, err) + + closeTurn() + + list := l.(*listener) + assert.Eventually(t, func() bool { + list.connClosersMut.Lock() + defer list.connClosersMut.Unlock() + return len(list.connClosers) == 0 + }, time.Second*15, time.Millisecond*100) + }) } func BenchmarkThroughput(b *testing.B) { diff --git a/wsnet/listen.go b/wsnet/listen.go index 02f13f41..803e140b 100644 --- a/wsnet/listen.go +++ b/wsnet/listen.go @@ -159,9 +159,11 @@ func (l *listener) dial(ctx context.Context) (<-chan error, error) { // so the cognitive overload linter has been disabled. // nolint:gocognit,nestif func (l *listener) negotiate(ctx context.Context, conn net.Conn) { + id := atomic.AddInt64(&l.nextConnNumber, 1) + ctx = slog.With(ctx, slog.F("conn_id", id)) + var ( err error - id = atomic.AddInt64(&l.nextConnNumber, 1) decoder = json.NewDecoder(conn) rtc *webrtc.PeerConnection // If candidates are sent before an offer, we place them here. @@ -171,7 +173,7 @@ func (l *listener) negotiate(ctx context.Context, conn net.Conn) { // Sends the error provided then closes the connection. // If RTC isn't connected, we'll close it. closeError = func(err error) { - l.log.Warn(ctx, "negotiation error, closing connection", slog.Error(err)) + // l.log.Warn(ctx, "negotiation error, closing connection", slog.Error(err)) d, _ := json.Marshal(&BrokerMessage{ Error: err.Error(), @@ -187,7 +189,6 @@ func (l *listener) negotiate(ctx context.Context, conn net.Conn) { } ) - ctx = slog.With(ctx, slog.F("conn_id", id)) l.log.Info(ctx, "accepted new session from broker connection, negotiating") for { @@ -255,17 +256,26 @@ func (l *listener) negotiate(ctx context.Context, conn net.Conn) { return } rtc.OnConnectionStateChange(func(pcs webrtc.PeerConnectionState) { - l.log.Debug(ctx, "connection state change", slog.F("state", pcs.String())) - if pcs == webrtc.PeerConnectionStateConnecting { + l.log.Info(ctx, "connection state change", slog.F("state", pcs.String())) + switch pcs { + case webrtc.PeerConnectionStateConnected: return + case webrtc.PeerConnectionStateConnecting: + // Safe to close the negotiating WebSocket. + _ = conn.Close() + return + } + + // Close connections opened when RTC was alive. + l.connClosersMut.Lock() + defer l.connClosersMut.Unlock() + for _, connCloser := range l.connClosers { + _ = connCloser.Close() } - _ = conn.Close() + l.connClosers = make([]io.Closer, 0) }) flushCandidates := proxyICECandidates(rtc, conn) - l.connClosersMut.Lock() - l.connClosers = append(l.connClosers, rtc) - l.connClosersMut.Unlock() rtc.OnDataChannel(l.handle(ctx, msg)) l.log.Debug(ctx, "set remote description", slog.F("offer", *msg.Offer)) @@ -420,6 +430,9 @@ func (l *listener) handle(ctx context.Context, msg BrokerMessage) func(dc *webrt dc: dc, rw: rw, } + l.connClosersMut.Lock() + l.connClosers = append(l.connClosers, co) + l.connClosersMut.Unlock() co.init() defer nc.Close() defer co.Close()