diff --git a/wsnet/dial.go b/wsnet/dial.go index af4b422c..283bedf4 100644 --- a/wsnet/dial.go +++ b/wsnet/dial.go @@ -137,6 +137,7 @@ type Dialer struct { closedChan chan struct{} connClosers []io.Closer connClosersMut sync.Mutex + pingMut sync.Mutex } func (d *Dialer) negotiate() (err error) { @@ -160,7 +161,7 @@ func (d *Dialer) negotiate() (err error) { return } d.rtc.OnConnectionStateChange(func(pcs webrtc.PeerConnectionState) { - if pcs == webrtc.PeerConnectionStateConnected { + if pcs != webrtc.PeerConnectionStateDisconnected { return } @@ -178,6 +179,7 @@ func (d *Dialer) negotiate() (err error) { default: } close(d.closedChan) + _ = d.rtc.Close() }) }() @@ -263,7 +265,7 @@ func (d *Dialer) Ping(ctx context.Context) error { // Since we control the client and server we could open this // data channel with `Negotiated` true to reduce traffic being // sent when the RTC connection is opened. - err := waitForDataChannelOpen(context.Background(), d.ctrl) + err := waitForDataChannelOpen(ctx, d.ctrl) if err != nil { return err } @@ -273,13 +275,28 @@ func (d *Dialer) Ping(ctx context.Context) error { return err } } + d.pingMut.Lock() + defer d.pingMut.Unlock() _, err = d.ctrlrw.Write([]byte{'a'}) if err != nil { return fmt.Errorf("write: %w", err) } - b := make([]byte, 4) - _, err = d.ctrlrw.Read(b) - return err + errCh := make(chan error) + go func() { + // There's a race in which connections can get lost mid-ping, + // in which case this would block forever. + defer close(errCh) + _, err = d.ctrlrw.Read(make([]byte, 4)) + errCh <- err + }() + ctx, cancelFunc := context.WithTimeout(ctx, time.Second*15) + defer cancelFunc() + select { + case err := <-errCh: + return err + case <-ctx.Done(): + return ctx.Err() + } } // DialContext dials the network and address on the remote listener. 
diff --git a/wsnet/rtc.go b/wsnet/rtc.go index 05c04f1b..0b93c57c 100644 --- a/wsnet/rtc.go +++ b/wsnet/rtc.go @@ -160,7 +160,9 @@ func newPeerConnection(servers []webrtc.ICEServer, dialer proxy.Dialer) (*webrtc se.SetNetworkTypes([]webrtc.NetworkType{webrtc.NetworkTypeUDP4}) se.SetSrflxAcceptanceMinWait(0) se.DetachDataChannels() - se.SetICETimeouts(time.Second*3, time.Second*3, time.Second*2) + // If the disconnected timeout and the keep-alive interval are too close + // together, we'll experience "random" connection failures. + se.SetICETimeouts(time.Second*5, time.Second*25, time.Second*2) lf := logging.NewDefaultLoggerFactory() lf.DefaultLogLevel = logging.LogLevelDisabled se.LoggerFactory = lf