diff --git a/wsnet/dial.go b/wsnet/dial.go index 637bc5fd..c72dc513 100644 --- a/wsnet/dial.go +++ b/wsnet/dial.go @@ -85,6 +85,7 @@ func Dial(conn net.Conn, iceServers []webrtc.ICEServer) (*Dialer, error) { conn: conn, ctrl: ctrl, rtc: rtc, + closedChan: make(chan struct{}), connClosers: make([]io.Closer, 0), } @@ -100,6 +101,7 @@ type Dialer struct { ctrlrw datachannel.ReadWriteCloser rtc *webrtc.PeerConnection + closedChan chan struct{} connClosers []io.Closer connClosersMut sync.Mutex } @@ -136,6 +138,13 @@ func (d *Dialer) negotiate() (err error) { _ = connCloser.Close() } d.connClosers = make([]io.Closer, 0) + + select { + case <-d.closedChan: + return + default: + } + close(d.closedChan) }) }() @@ -184,6 +193,12 @@ func (d *Dialer) negotiate() (err error) { return <-errCh } +// Closed returns a channel that closes when +// the connection is closed. +func (d *Dialer) Closed() <-chan struct{} { + return d.closedChan +} + // Close closes the RTC connection. // All data channels dialed will be closed. func (d *Dialer) Close() error { diff --git a/wsnet/dial_test.go b/wsnet/dial_test.go index 71fdc8c7..9b412a3e 100644 --- a/wsnet/dial_test.go +++ b/wsnet/dial_test.go @@ -10,6 +10,7 @@ import ( "net" "strconv" "testing" + "time" "github.com/pion/ice/v2" "github.com/pion/webrtc/v3" @@ -46,7 +47,7 @@ func ExampleDial_basic() { // You now have access to the proxied remote port in `conn`. } -// nolint:gocognit +// nolint:gocognit,gocyclo func TestDial(t *testing.T) { t.Run("Ping", func(t *testing.T) { connectAddr, listenAddr := createDumbBroker(t) @@ -229,6 +230,28 @@ func TestDial(t *testing.T) { return } }) + + t.Run("Closed", func(t *testing.T) { + connectAddr, listenAddr := createDumbBroker(t) + _, err := Listen(context.Background(), listenAddr) + if err != nil { + t.Error(err) + return + } + dialer, err := DialWebsocket(context.Background(), connectAddr, nil) + if err != nil { + t.Error(err) + return + } + go func() { + _ = dialer.Close() + }() + select { + case <-dialer.Closed(): + case <-time.NewTimer(time.Second).C: + t.Error("didn't close in time") + } + }) } func BenchmarkThroughput(b *testing.B) {