diff --git a/wsnet/dial.go b/wsnet/dial.go index d6f45461..3880b12c 100644 --- a/wsnet/dial.go +++ b/wsnet/dial.go @@ -147,7 +147,6 @@ func (d *Dialer) negotiate(ctx context.Context) (err error) { // so it's better to buffer and process than fail. pendingCandidates = []webrtc.ICECandidateInit{} ) - go func() { defer close(errCh) defer func() { @@ -155,6 +154,9 @@ func (d *Dialer) negotiate(ctx context.Context) (err error) { }() err := waitForConnectionOpen(ctx, d.rtc) if err != nil { + if errors.Is(err, context.DeadlineExceeded) { + _ = d.conn.Close() + } errCh <- err return } diff --git a/wsnet/dial_test.go b/wsnet/dial_test.go index 5dd11b58..5d75cfd4 100644 --- a/wsnet/dial_test.go +++ b/wsnet/dial_test.go @@ -9,6 +9,7 @@ import ( "net" "strconv" "testing" + "time" "cdr.dev/slog/sloggers/slogtest" "github.com/pion/ice/v2" @@ -51,6 +52,18 @@ func ExampleDial_basic() { } func TestDial(t *testing.T) { + t.Run("Timeout", func(t *testing.T) { + t.Parallel() + + connectAddr, _ := createDumbBroker(t) + + ctx, cancelFunc := context.WithTimeout(context.Background(), time.Millisecond*50) + defer cancelFunc() + dialer, err := DialWebsocket(ctx, connectAddr, nil, nil) + require.True(t, errors.Is(err, context.DeadlineExceeded)) + require.Error(t, dialer.conn.Close(), "already wrote close") + }) + t.Run("Ping", func(t *testing.T) { t.Parallel() diff --git a/wsnet/rtc.go b/wsnet/rtc.go index 79702743..32a089a2 100644 --- a/wsnet/rtc.go +++ b/wsnet/rtc.go @@ -256,7 +256,7 @@ func waitForConnectionOpen(ctx context.Context, conn *webrtc.PeerConnection) err }) <-ctx.Done() if ctx.Err() == context.DeadlineExceeded { - return ctx.Err() + return context.DeadlineExceeded } return nil } diff --git a/wsnet/wsnet_test.go b/wsnet/wsnet_test.go index ad9ac381..20aa7699 100644 --- a/wsnet/wsnet_test.go +++ b/wsnet/wsnet_test.go @@ -68,7 +68,9 @@ func createDumbBroker(t testing.TB) (connectAddr string, listenAddr string) { mut.Lock() defer mut.Unlock() if sess == nil { - t.Error("listen not called") + // We discard inbound to emulate a pubsub where we don't know if anyone + // is listening on the other side. + _, _ = io.Copy(io.Discard, nc) return } oc, err := sess.Open()