Thanks to visit codestin.com
Credit goes to github.com

Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit a8443d0

Browse files
authored
fix: Close Ping DataChannel when connection ends (#382)
Previously, Ping() would hang forever due to the DataChannel never closing when the RTC connection ended.
1 parent 6cc203a commit a8443d0

File tree

3 files changed

+48
-2
lines changed

3 files changed

+48
-2
lines changed

wsnet/dial.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ func Dial(conn net.Conn, iceServers []webrtc.ICEServer) (*Dialer, error) {
8686
ctrl: ctrl,
8787
rtc: rtc,
8888
closedChan: make(chan struct{}),
89-
connClosers: make([]io.Closer, 0),
89+
connClosers: []io.Closer{ctrl},
9090
}
9191

9292
return dialer, dialer.negotiate()

wsnet/dial_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ func ExampleDial_basic() {
5050
// nolint:gocognit,gocyclo
5151
func TestDial(t *testing.T) {
5252
t.Run("Ping", func(t *testing.T) {
53+
t.Parallel()
54+
5355
connectAddr, listenAddr := createDumbBroker(t)
5456
_, err := Listen(context.Background(), listenAddr)
5557
if err != nil {
@@ -67,7 +69,38 @@ func TestDial(t *testing.T) {
6769
}
6870
})
6971

72+
t.Run("Ping Close", func(t *testing.T) {
73+
t.Parallel()
74+
75+
connectAddr, listenAddr := createDumbBroker(t)
76+
_, err := Listen(context.Background(), listenAddr)
77+
if err != nil {
78+
t.Error(err)
79+
return
80+
}
81+
turnAddr, closeTurn := createTURNServer(t, ice.SchemeTypeTURN)
82+
dialer, err := DialWebsocket(context.Background(), connectAddr, []webrtc.ICEServer{{
83+
URLs: []string{fmt.Sprintf("turn:%s", turnAddr)},
84+
Username: "example",
85+
Credential: testPass,
86+
CredentialType: webrtc.ICECredentialTypePassword,
87+
}})
88+
if err != nil {
89+
t.Error(err)
90+
return
91+
}
92+
_ = dialer.Ping(context.Background())
93+
closeTurn()
94+
err = dialer.Ping(context.Background())
95+
if err != io.EOF {
96+
t.Error(err)
97+
return
98+
}
99+
})
100+
70101
t.Run("OPError", func(t *testing.T) {
102+
t.Parallel()
103+
71104
connectAddr, listenAddr := createDumbBroker(t)
72105
_, err := Listen(context.Background(), listenAddr)
73106
if err != nil {
@@ -91,6 +124,8 @@ func TestDial(t *testing.T) {
91124
})
92125

93126
t.Run("Proxy", func(t *testing.T) {
127+
t.Parallel()
128+
94129
listener, err := net.Listen("tcp", "0.0.0.0:0")
95130
if err != nil {
96131
t.Error(err)
@@ -134,6 +169,8 @@ func TestDial(t *testing.T) {
134169

135170
// Expect that we'd get an EOF on the server closing.
136171
t.Run("EOF on Close", func(t *testing.T) {
172+
t.Parallel()
173+
137174
listener, err := net.Listen("tcp", "0.0.0.0:0")
138175
if err != nil {
139176
t.Error(err)
@@ -167,6 +204,8 @@ func TestDial(t *testing.T) {
167204
})
168205

169206
t.Run("Disconnect", func(t *testing.T) {
207+
t.Parallel()
208+
170209
connectAddr, listenAddr := createDumbBroker(t)
171210
_, err := Listen(context.Background(), listenAddr)
172211
if err != nil {
@@ -190,6 +229,8 @@ func TestDial(t *testing.T) {
190229
})
191230

192231
t.Run("Disconnect DialContext", func(t *testing.T) {
232+
t.Parallel()
233+
193234
tcpListener, err := net.Listen("tcp", "0.0.0.0:0")
194235
if err != nil {
195236
t.Error(err)
@@ -232,6 +273,8 @@ func TestDial(t *testing.T) {
232273
})
233274

234275
t.Run("Closed", func(t *testing.T) {
276+
t.Parallel()
277+
235278
connectAddr, listenAddr := createDumbBroker(t)
236279
_, err := Listen(context.Background(), listenAddr)
237280
if err != nil {

wsnet/rtc.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ func newPeerConnection(servers []webrtc.ICEServer) (*webrtc.PeerConnection, erro
159159
se.SetNetworkTypes([]webrtc.NetworkType{webrtc.NetworkTypeUDP4})
160160
se.SetSrflxAcceptanceMinWait(0)
161161
se.DetachDataChannels()
162-
se.SetICETimeouts(time.Second*5, time.Second*5, time.Second*2)
162+
se.SetICETimeouts(time.Second*3, time.Second*3, time.Second*2)
163163
lf := logging.NewDefaultLoggerFactory()
164164
lf.DefaultLogLevel = logging.LogLevelDisabled
165165
se.LoggerFactory = lf
@@ -252,6 +252,9 @@ func waitForDataChannelOpen(ctx context.Context, channel *webrtc.DataChannel) er
252252
if channel.ReadyState() == webrtc.DataChannelStateOpen {
253253
return nil
254254
}
255+
if channel.ReadyState() != webrtc.DataChannelStateConnecting {
256+
return fmt.Errorf("channel closed")
257+
}
255258
ctx, cancelFunc := context.WithTimeout(ctx, time.Second*15)
256259
defer cancelFunc()
257260
channel.OnOpen(func() {

0 commit comments

Comments
 (0)