From 6305d31f4951ed235406e2a0b4d55f4772dd6467 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Fri, 9 Jul 2021 00:16:16 -0500 Subject: [PATCH] fix: Close Ping DataChannel when connection ends Previously, Ping() would hang forever due to the DataChannel never closing when the RTC connection ended. --- wsnet/dial.go | 2 +- wsnet/dial_test.go | 43 +++++++++++++++++++++++++++++++++++++++++++ wsnet/rtc.go | 5 ++++- 3 files changed, 48 insertions(+), 2 deletions(-) diff --git a/wsnet/dial.go b/wsnet/dial.go index c72dc513..0beb2232 100644 --- a/wsnet/dial.go +++ b/wsnet/dial.go @@ -86,7 +86,7 @@ func Dial(conn net.Conn, iceServers []webrtc.ICEServer) (*Dialer, error) { ctrl: ctrl, rtc: rtc, closedChan: make(chan struct{}), - connClosers: make([]io.Closer, 0), + connClosers: []io.Closer{ctrl}, } return dialer, dialer.negotiate() diff --git a/wsnet/dial_test.go b/wsnet/dial_test.go index 9b412a3e..5d2e3884 100644 --- a/wsnet/dial_test.go +++ b/wsnet/dial_test.go @@ -50,6 +50,8 @@ func ExampleDial_basic() { // nolint:gocognit,gocyclo func TestDial(t *testing.T) { t.Run("Ping", func(t *testing.T) { + t.Parallel() + connectAddr, listenAddr := createDumbBroker(t) _, err := Listen(context.Background(), listenAddr) if err != nil { @@ -67,7 +69,38 @@ func TestDial(t *testing.T) { } }) + t.Run("Ping Close", func(t *testing.T) { + t.Parallel() + + connectAddr, listenAddr := createDumbBroker(t) + _, err := Listen(context.Background(), listenAddr) + if err != nil { + t.Error(err) + return + } + turnAddr, closeTurn := createTURNServer(t, ice.SchemeTypeTURN) + dialer, err := DialWebsocket(context.Background(), connectAddr, []webrtc.ICEServer{{ + URLs: []string{fmt.Sprintf("turn:%s", turnAddr)}, + Username: "example", + Credential: testPass, + CredentialType: webrtc.ICECredentialTypePassword, + }}) + if err != nil { + t.Error(err) + return + } + _ = dialer.Ping(context.Background()) + closeTurn() + err = dialer.Ping(context.Background()) + if err != io.EOF { + t.Error(err) + return + } + }) + t.Run("OPError", func(t *testing.T) { + t.Parallel() + connectAddr, listenAddr := createDumbBroker(t) _, err := Listen(context.Background(), listenAddr) if err != nil { @@ -91,6 +124,8 @@ func TestDial(t *testing.T) { }) t.Run("Proxy", func(t *testing.T) { + t.Parallel() + listener, err := net.Listen("tcp", "0.0.0.0:0") if err != nil { t.Error(err) @@ -134,6 +169,8 @@ func TestDial(t *testing.T) { // Expect that we'd get an EOF on the server closing. t.Run("EOF on Close", func(t *testing.T) { + t.Parallel() + listener, err := net.Listen("tcp", "0.0.0.0:0") if err != nil { t.Error(err) @@ -167,6 +204,8 @@ func TestDial(t *testing.T) { }) t.Run("Disconnect", func(t *testing.T) { + t.Parallel() + connectAddr, listenAddr := createDumbBroker(t) _, err := Listen(context.Background(), listenAddr) if err != nil { @@ -190,6 +229,8 @@ func TestDial(t *testing.T) { }) t.Run("Disconnect DialContext", func(t *testing.T) { + t.Parallel() + tcpListener, err := net.Listen("tcp", "0.0.0.0:0") if err != nil { t.Error(err) @@ -232,6 +273,8 @@ func TestDial(t *testing.T) { }) t.Run("Closed", func(t *testing.T) { + t.Parallel() + connectAddr, listenAddr := createDumbBroker(t) _, err := Listen(context.Background(), listenAddr) if err != nil { diff --git a/wsnet/rtc.go b/wsnet/rtc.go index 4d454311..e8b5eab3 100644 --- a/wsnet/rtc.go +++ b/wsnet/rtc.go @@ -159,7 +159,7 @@ func newPeerConnection(servers []webrtc.ICEServer) (*webrtc.PeerConnection, erro se.SetNetworkTypes([]webrtc.NetworkType{webrtc.NetworkTypeUDP4}) se.SetSrflxAcceptanceMinWait(0) se.DetachDataChannels() - se.SetICETimeouts(time.Second*5, time.Second*5, time.Second*2) + se.SetICETimeouts(time.Second*3, time.Second*3, time.Second*2) lf := logging.NewDefaultLoggerFactory() lf.DefaultLogLevel = logging.LogLevelDisabled se.LoggerFactory = lf @@ -252,6 +252,9 @@ func waitForDataChannelOpen(ctx context.Context, channel *webrtc.DataChannel) er if channel.ReadyState() == webrtc.DataChannelStateOpen { return nil } + if channel.ReadyState() != webrtc.DataChannelStateConnecting { + return fmt.Errorf("channel closed") + } ctx, cancelFunc := context.WithTimeout(ctx, time.Second*15) defer cancelFunc() channel.OnOpen(func() {