diff --git a/wsnet/dial.go b/wsnet/dial.go index 394e6a9f..d6f45461 100644 --- a/wsnet/dial.go +++ b/wsnet/dial.go @@ -52,11 +52,11 @@ func DialWebsocket(ctx context.Context, broker string, netOpts *DialOptions, wsO // We should close the socket intentionally. _ = conn.Close(websocket.StatusInternalError, "an error occurred") }() - return Dial(nconn, netOpts) + return Dial(ctx, nconn, netOpts) } // Dial negotiates a connection to a listener. -func Dial(conn net.Conn, options *DialOptions) (*Dialer, error) { +func Dial(ctx context.Context, conn net.Conn, options *DialOptions) (*Dialer, error) { if options == nil { options = &DialOptions{} } @@ -121,7 +121,7 @@ func Dial(conn net.Conn, options *DialOptions) (*Dialer, error) { connClosers: []io.Closer{ctrl}, } - return dialer, dialer.negotiate() + return dialer, dialer.negotiate(ctx) } // Dialer enables arbitrary dialing to any network and address @@ -138,7 +138,7 @@ type Dialer struct { pingMut sync.Mutex } -func (d *Dialer) negotiate() (err error) { +func (d *Dialer) negotiate(ctx context.Context) (err error) { var ( decoder = json.NewDecoder(d.conn) errCh = make(chan error) @@ -153,7 +153,7 @@ func (d *Dialer) negotiate() (err error) { defer func() { _ = d.conn.Close() }() - err := waitForConnectionOpen(context.Background(), d.rtc) + err := waitForConnectionOpen(ctx, d.rtc) if err != nil { errCh <- err return diff --git a/wsnet/rtc.go b/wsnet/rtc.go index 0b93c57c..79702743 100644 --- a/wsnet/rtc.go +++ b/wsnet/rtc.go @@ -242,11 +242,16 @@ func waitForConnectionOpen(ctx context.Context, conn *webrtc.PeerConnection) err if conn.ConnectionState() == webrtc.PeerConnectionStateConnected { return nil } - ctx, cancelFunc := context.WithTimeout(ctx, time.Second*15) - defer cancelFunc() + var cancel context.CancelFunc + if _, deadlineSet := ctx.Deadline(); deadlineSet { + ctx, cancel = context.WithCancel(ctx) + } else { + ctx, cancel = context.WithTimeout(ctx, time.Second*15) + } + defer cancel() conn.OnConnectionStateChange(func(pcs webrtc.PeerConnectionState) { if pcs == webrtc.PeerConnectionStateConnected { - cancelFunc() + cancel() } }) <-ctx.Done()