diff --git a/internal/cmd/agent.go b/internal/cmd/agent.go index c19bdfee..7563ee4b 100644 --- a/internal/cmd/agent.go +++ b/internal/cmd/agent.go @@ -73,7 +73,7 @@ coder agent start --coder-url https://my-coder.com --token xxxx-xxxx } } - listener, err := wsnet.Listen(context.Background(), wsnet.ListenEndpoint(u, token), wsnet.TURNProxyWebSocket(u, token)) + listener, err := wsnet.Listen(context.Background(), wsnet.ListenEndpoint(u, token), token) if err != nil { return xerrors.Errorf("listen: %w", err) } diff --git a/internal/cmd/tunnel.go b/internal/cmd/tunnel.go index 7b14cf33..9c12dd37 100644 --- a/internal/cmd/tunnel.go +++ b/internal/cmd/tunnel.go @@ -11,6 +11,7 @@ import ( "cdr.dev/slog" "cdr.dev/slog/sloggers/sloghuman" + "github.com/pion/webrtc/v3" "github.com/spf13/cobra" "golang.org/x/xerrors" @@ -107,7 +108,9 @@ func (c *tunnneler) start(ctx context.Context) error { ctx, wsnet.ConnectEndpoint(c.brokerAddr, c.workspaceID, c.token), &wsnet.DialOptions{ - TURNProxy: wsnet.TURNProxyWebSocket(c.brokerAddr, c.token), + TURNProxyAuthToken: c.token, + TURNProxyURL: c.brokerAddr, + ICEServers: []webrtc.ICEServer{wsnet.TURNProxyICECandidate()}, }, ) if err != nil { diff --git a/wsnet/conn.go b/wsnet/conn.go index 608c5c70..5b863f04 100644 --- a/wsnet/conn.go +++ b/wsnet/conn.go @@ -11,14 +11,14 @@ import ( "github.com/pion/datachannel" "github.com/pion/webrtc/v3" - "golang.org/x/net/proxy" "nhooyr.io/websocket" "cdr.dev/coder-cli/coder-sdk" ) const ( - httpScheme = "http" + httpScheme = "http" + turnProxyMagicUsername = "~magicalusername~" bufferedAmountLowThreshold uint64 = 512 * 1024 // 512 KB maxBufferedAmount uint64 = 1024 * 1024 // 1 MB @@ -46,25 +46,17 @@ func ConnectEndpoint(baseURL *url.URL, workspace, token string) string { return fmt.Sprintf("%s://%s%s%s%s%s", wsScheme, baseURL.Host, "/api/private/envagent/", workspace, "/connect?session_token=", token) } -// TURNWebSocketICECandidate returns a valid relay ICEServer that can be used to -// trigger a TURNWebSocketDialer. +// TURNWebSocketICECandidate returns a fake TCP relay ICEServer. +// It's used to trigger the ICEProxyDialer. func TURNProxyICECandidate() webrtc.ICEServer { return webrtc.ICEServer{ URLs: []string{"turn:127.0.0.1:3478?transport=tcp"}, - Username: "~magicalusername~", - Credential: "~magicalpassword~", + Username: turnProxyMagicUsername, + Credential: turnProxyMagicUsername, CredentialType: webrtc.ICECredentialTypePassword, } } -// TURNWebSocketDialer proxies all TURN traffic through a WebSocket. -func TURNProxyWebSocket(baseURL *url.URL, token string) proxy.Dialer { - return &turnProxyDialer{ - baseURL: baseURL, - token: token, - } -} - // Proxies all TURN ICEServer traffic through this dialer. // References Coder APIs with a specific token. type turnProxyDialer struct { diff --git a/wsnet/dial.go b/wsnet/dial.go index 362bbab9..97d49827 100644 --- a/wsnet/dial.go +++ b/wsnet/dial.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net" + "net/url" "sync" "time" @@ -24,10 +25,13 @@ type DialOptions struct { // See: https://developer.mozilla.org/en-US/docs/Web/API/RTCConfiguration/iceServers ICEServers []webrtc.ICEServer - // TURNProxy is a function used to proxy all TURN traffic. - // If specified without ICEServers, `TURNProxyICECandidate` - // will be used. - TURNProxy proxy.Dialer + // TURNProxyAuthToken is used to authenticate a TURN proxy request. + TURNProxyAuthToken string + + // TURNProxyURL is the URL to proxy all TURN data through. + // This URL is sent to the listener during handshake so both + // ends connect to the same TURN endpoint. + TURNProxyURL *url.URL } // DialWebsocket dials the broker with a WebSocket and negotiates a connection. @@ -59,13 +63,15 @@ func Dial(conn net.Conn, options *DialOptions) (*Dialer, error) { if options.ICEServers == nil { options.ICEServers = []webrtc.ICEServer{} } - // If the TURNProxy is specified and ICEServers aren't, - // it's safe to assume we can inject the default proxy candidate. - if len(options.ICEServers) == 0 && options.TURNProxy != nil { - options.ICEServers = []webrtc.ICEServer{TURNProxyICECandidate()} - } - rtc, err := newPeerConnection(options.ICEServers, options.TURNProxy) + var turnProxy proxy.Dialer + if options.TURNProxyURL != nil { + turnProxy = &turnProxyDialer{ + baseURL: options.TURNProxyURL, + token: options.TURNProxyAuthToken, + } + } + rtc, err := newPeerConnection(options.ICEServers, turnProxy) if err != nil { return nil, fmt.Errorf("create peer connection: %w", err) } @@ -89,9 +95,15 @@ func Dial(conn net.Conn, options *DialOptions) (*Dialer, error) { return nil, fmt.Errorf("set local offer: %w", err) } + var turnProxyURL string + if options.TURNProxyURL != nil { + turnProxyURL = options.TURNProxyURL.String() + } + offerMessage, err := json.Marshal(&BrokerMessage{ - Offer: &offer, - Servers: options.ICEServers, + Offer: &offer, + Servers: options.ICEServers, + TURNProxyURL: turnProxyURL, }) if err != nil { return nil, fmt.Errorf("marshal offer message: %w", err) diff --git a/wsnet/dial_test.go b/wsnet/dial_test.go index 6ad27866..91b7d0a2 100644 --- a/wsnet/dial_test.go +++ b/wsnet/dial_test.go @@ -55,7 +55,7 @@ func TestDial(t *testing.T) { t.Parallel() connectAddr, listenAddr := createDumbBroker(t) - _, err := Listen(context.Background(), listenAddr, nil) + _, err := Listen(context.Background(), listenAddr, "") if err != nil { t.Error(err) return @@ -75,7 +75,7 @@ func TestDial(t *testing.T) { t.Parallel() connectAddr, listenAddr := createDumbBroker(t) - _, err := Listen(context.Background(), listenAddr, nil) + _, err := Listen(context.Background(), listenAddr, "") if err != nil { t.Error(err) return @@ -106,7 +106,7 @@ func TestDial(t *testing.T) { t.Parallel() connectAddr, listenAddr := createDumbBroker(t) - _, err := Listen(context.Background(), listenAddr, nil) + _, err := Listen(context.Background(), listenAddr, "") if err != nil { t.Error(err) return @@ -145,7 +145,7 @@ func TestDial(t *testing.T) { }() connectAddr, listenAddr := createDumbBroker(t) - _, err = Listen(context.Background(), listenAddr, nil) + _, err = Listen(context.Background(), listenAddr, "") if err != nil { t.Error(err) return @@ -184,7 +184,7 @@ func TestDial(t *testing.T) { _, _ = listener.Accept() }() connectAddr, listenAddr := createDumbBroker(t) - srv, err := Listen(context.Background(), listenAddr, nil) + srv, err := Listen(context.Background(), listenAddr, "") if err != nil { t.Error(err) return @@ -211,7 +211,7 @@ func TestDial(t *testing.T) { t.Parallel() connectAddr, listenAddr := createDumbBroker(t) - _, err := Listen(context.Background(), listenAddr, nil) + _, err := Listen(context.Background(), listenAddr, "") if err != nil { t.Error(err) return @@ -245,7 +245,7 @@ func TestDial(t *testing.T) { }() connectAddr, listenAddr := createDumbBroker(t) - _, err = Listen(context.Background(), listenAddr, nil) + _, err = Listen(context.Background(), listenAddr, "") if err != nil { t.Error(err) return @@ -282,7 +282,7 @@ func TestDial(t *testing.T) { t.Parallel() connectAddr, listenAddr := createDumbBroker(t) - _, err := Listen(context.Background(), listenAddr, nil) + _, err := Listen(context.Background(), listenAddr, "") if err != nil { t.Error(err) return @@ -333,7 +333,7 @@ func BenchmarkThroughput(b *testing.B) { } }() connectAddr, listenAddr := createDumbBroker(b) - _, err = Listen(context.Background(), listenAddr, nil) + _, err = Listen(context.Background(), listenAddr, "") if err != nil { b.Error(err) return diff --git a/wsnet/listen.go b/wsnet/listen.go index b29bbdb3..e159b6e3 100644 --- a/wsnet/listen.go +++ b/wsnet/listen.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net" + "net/url" "sync" "time" @@ -40,11 +41,11 @@ type DialChannelResponse struct { // Listen connects to the broker proxies connections to the local net. // Close will end all RTC connections. -func Listen(ctx context.Context, broker string, tcpProxy proxy.Dialer) (io.Closer, error) { +func Listen(ctx context.Context, broker string, turnProxyAuthToken string) (io.Closer, error) { l := &listener{ - broker: broker, - connClosers: make([]io.Closer, 0), - tcpProxy: tcpProxy, + broker: broker, + connClosers: make([]io.Closer, 0), + turnProxyAuthToken: turnProxyAuthToken, } // We do a one-off dial outside of the loop to ensure the initial // connection is successful. If not, there's likely an error the @@ -85,8 +86,8 @@ func Listen(ctx context.Context, broker string, tcpProxy proxy.Dialer) (io.Close } type listener struct { - broker string - tcpProxy proxy.Dialer + broker string + turnProxyAuthToken string acceptError error ws *websocket.Conn @@ -189,7 +190,7 @@ func (l *listener) negotiate(conn net.Conn) { return } for _, server := range msg.Servers { - if server.Username == TURNProxyICECandidate().Username { + if server.Username == turnProxyMagicUsername { // This candidate is only used when proxying, // so it will not validate. continue @@ -200,7 +201,19 @@ func (l *listener) negotiate(conn net.Conn) { return } } - rtc, err = newPeerConnection(msg.Servers, l.tcpProxy) + var turnProxy proxy.Dialer + if msg.TURNProxyURL != "" { + u, err := url.Parse(msg.TURNProxyURL) + if err != nil { + closeError(fmt.Errorf("parse turn proxy url: %w", err)) + return + } + turnProxy = &turnProxyDialer{ + baseURL: u, + token: l.turnProxyAuthToken, + } + } + rtc, err = newPeerConnection(msg.Servers, turnProxy) if err != nil { closeError(err) return diff --git a/wsnet/listen_test.go b/wsnet/listen_test.go index 47b856c3..2c5ba35f 100644 --- a/wsnet/listen_test.go +++ b/wsnet/listen_test.go @@ -45,7 +45,7 @@ func TestListen(t *testing.T) { addr := listener.Addr() broker := fmt.Sprintf("http://%s/", addr.String()) - _, err = Listen(context.Background(), broker, nil) + _, err = Listen(context.Background(), broker, "") if err != nil { t.Error(err) return diff --git a/wsnet/proto.go b/wsnet/proto.go index 754fffac..feb4d126 100644 --- a/wsnet/proto.go +++ b/wsnet/proto.go @@ -49,8 +49,10 @@ func (p DialPolicy) permits(network, host string, port uint16) bool { // sides can begin exchanging candidates. type BrokerMessage struct { // Dialer -> Listener - Offer *webrtc.SessionDescription `json:"offer"` - Servers []webrtc.ICEServer `json:"servers"` + Offer *webrtc.SessionDescription `json:"offer"` + Servers []webrtc.ICEServer `json:"servers"` + TURNProxyURL string `json:"turn_proxy_url"` + // Policies denote which addresses the client can dial. If empty or nil, all // addresses are permitted. Policies []DialPolicy `json:"ports"`