diff --git a/Makefile b/Makefile index abbc64e8a07cc..09a93a38733ff 100644 --- a/Makefile +++ b/Makefile @@ -386,7 +386,6 @@ lint/shellcheck: $(shell shfmt -f .) gen: \ coderd/database/dump.sql \ coderd/database/querier.go \ - peerbroker/proto/peerbroker.pb.go \ provisionersdk/proto/provisioner.pb.go \ provisionerd/proto/provisionerd.pb.go \ site/src/api/typesGenerated.ts @@ -395,7 +394,7 @@ gen: \ # Mark all generated files as fresh so make thinks they're up-to-date. This is # used during releases so we don't run generation scripts. gen/mark-fresh: - files="coderd/database/dump.sql coderd/database/querier.go peerbroker/proto/peerbroker.pb.go provisionersdk/proto/provisioner.pb.go provisionerd/proto/provisionerd.pb.go site/src/api/typesGenerated.ts" + files="coderd/database/dump.sql coderd/database/querier.go provisionersdk/proto/provisioner.pb.go provisionerd/proto/provisionerd.pb.go site/src/api/typesGenerated.ts" for file in $$files; do echo "$$file" if [ ! -f "$$file" ]; then @@ -417,14 +416,6 @@ coderd/database/dump.sql: coderd/database/gen/dump/main.go $(wildcard coderd/dat coderd/database/querier.go: coderd/database/sqlc.yaml coderd/database/dump.sql $(wildcard coderd/database/queries/*.sql) coderd/database/gen/enum/main.go ./coderd/database/generate.sh -peerbroker/proto/peerbroker.pb.go: peerbroker/proto/peerbroker.proto - protoc \ - --go_out=. \ - --go_opt=paths=source_relative \ - --go-drpc_out=. \ - --go-drpc_opt=paths=source_relative \ - ./peerbroker/proto/peerbroker.proto - provisionersdk/proto/provisioner.pb.go: provisionersdk/proto/provisioner.proto protoc \ --go_out=. 
\ diff --git a/agent/agent.go b/agent/agent.go index 6018062b37f7b..18243ee788789 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -11,7 +11,6 @@ import ( "io" "net" "net/netip" - "net/url" "os" "os/exec" "os/user" @@ -34,8 +33,6 @@ import ( "cdr.dev/slog" "github.com/coder/coder/agent/usershell" - "github.com/coder/coder/peer" - "github.com/coder/coder/peerbroker" "github.com/coder/coder/pty" "github.com/coder/coder/tailnet" "github.com/coder/retry" @@ -64,7 +61,6 @@ var ( type Options struct { CoordinatorDialer CoordinatorDialer - WebRTCDialer WebRTCDialer FetchMetadata FetchMetadata StatsReporter StatsReporter @@ -80,8 +76,6 @@ type Metadata struct { Directory string `json:"directory"` } -type WebRTCDialer func(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error) - // CoordinatorDialer is a function that constructs a new broker. // A dialer must be passed in to allow for reconnects. type CoordinatorDialer func(ctx context.Context) (net.Conn, error) @@ -95,7 +89,6 @@ func New(options Options) io.Closer { } ctx, cancelFunc := context.WithCancel(context.Background()) server := &agent{ - webrtcDialer: options.WebRTCDialer, reconnectingPTYTimeout: options.ReconnectingPTYTimeout, logger: options.Logger, closeCancel: cancelFunc, @@ -111,8 +104,7 @@ func New(options Options) io.Closer { } type agent struct { - webrtcDialer WebRTCDialer - logger slog.Logger + logger slog.Logger reconnectingPTYs sync.Map reconnectingPTYTimeout time.Duration @@ -173,9 +165,6 @@ func (a *agent) run(ctx context.Context) { } }() - if a.webrtcDialer != nil { - go a.runWebRTCNetworking(ctx) - } if metadata.DERPMap != nil { go a.runTailnet(ctx, metadata.DERPMap) } @@ -326,49 +315,6 @@ func (a *agent) runCoordinator(ctx context.Context) { } } -func (a *agent) runWebRTCNetworking(ctx context.Context) { - var peerListener *peerbroker.Listener - var err error - // An exponential back-off occurs when the connection is failing to dial. 
- // This is to prevent server spam in case of a coderd outage. - for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); { - peerListener, err = a.webrtcDialer(ctx, a.logger) - if err != nil { - if errors.Is(err, context.Canceled) { - return - } - if a.isClosed() { - return - } - a.logger.Warn(context.Background(), "failed to dial", slog.Error(err)) - continue - } - a.logger.Info(context.Background(), "connected to webrtc broker") - break - } - select { - case <-ctx.Done(): - return - default: - } - - for { - conn, err := peerListener.Accept() - if err != nil { - if a.isClosed() { - return - } - a.logger.Debug(ctx, "peer listener accept exited; restarting connection", slog.Error(err)) - a.runWebRTCNetworking(ctx) - return - } - a.closeMutex.Lock() - a.connCloseWait.Add(1) - a.closeMutex.Unlock() - go a.handlePeerConn(ctx, conn) - } -} - func (a *agent) runStartupScript(ctx context.Context, script string) error { if script == "" { return nil @@ -401,74 +347,6 @@ func (a *agent) runStartupScript(ctx context.Context, script string) error { return nil } -func (a *agent) handlePeerConn(ctx context.Context, peerConn *peer.Conn) { - go func() { - select { - case <-a.closed: - case <-peerConn.Closed(): - } - _ = peerConn.Close() - a.connCloseWait.Done() - }() - for { - channel, err := peerConn.Accept(ctx) - if err != nil { - if errors.Is(err, peer.ErrClosed) || a.isClosed() { - return - } - a.logger.Debug(ctx, "accept channel from peer connection", slog.Error(err)) - return - } - - conn := channel.NetConn() - - switch channel.Protocol() { - case ProtocolSSH: - go a.sshServer.HandleConn(a.stats.wrapConn(conn)) - case ProtocolReconnectingPTY: - rawID := channel.Label() - // The ID format is referenced in conn.go. - // :: - idParts := strings.SplitN(rawID, ":", 4) - if len(idParts) != 4 { - a.logger.Warn(ctx, "client sent invalid id format", slog.F("raw-id", rawID)) - continue - } - id := idParts[0] - // Enforce a consistent format for IDs. 
- _, err := uuid.Parse(id) - if err != nil { - a.logger.Warn(ctx, "client sent reconnection token that isn't a uuid", slog.F("id", id), slog.Error(err)) - continue - } - // Parse the initial terminal dimensions. - height, err := strconv.Atoi(idParts[1]) - if err != nil { - a.logger.Warn(ctx, "client sent invalid height", slog.F("id", id), slog.F("height", idParts[1])) - continue - } - width, err := strconv.Atoi(idParts[2]) - if err != nil { - a.logger.Warn(ctx, "client sent invalid width", slog.F("id", id), slog.F("width", idParts[2])) - continue - } - go a.handleReconnectingPTY(ctx, reconnectingPTYInit{ - ID: id, - Height: uint16(height), - Width: uint16(width), - Command: idParts[3], - }, a.stats.wrapConn(conn)) - case ProtocolDial: - go a.handleDial(ctx, channel.Label(), a.stats.wrapConn(conn)) - default: - a.logger.Warn(ctx, "unhandled protocol from channel", - slog.F("protocol", channel.Protocol()), - slog.F("label", channel.Label()), - ) - } - } -} - func (a *agent) init(ctx context.Context) { a.logger.Info(ctx, "generating host key") // Clients' should ignore the host key when connecting. @@ -915,70 +793,6 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg reconnectingPTYIn } } -// dialResponse is written to datachannels with protocol "dial" by the agent as -// the first packet to signify whether the dial succeeded or failed. 
-type dialResponse struct { - Error string `json:"error,omitempty"` -} - -func (a *agent) handleDial(ctx context.Context, label string, conn net.Conn) { - defer conn.Close() - - writeError := func(responseError error) error { - msg := "" - if responseError != nil { - msg = responseError.Error() - if !xerrors.Is(responseError, io.EOF) { - a.logger.Warn(ctx, "handle dial", slog.F("label", label), slog.Error(responseError)) - } - } - b, err := json.Marshal(dialResponse{ - Error: msg, - }) - if err != nil { - a.logger.Warn(ctx, "write dial response", slog.F("label", label), slog.Error(err)) - return xerrors.Errorf("marshal agent webrtc dial response: %w", err) - } - - _, err = conn.Write(b) - return err - } - - u, err := url.Parse(label) - if err != nil { - _ = writeError(xerrors.Errorf("parse URL %q: %w", label, err)) - return - } - - network := u.Scheme - addr := u.Host + u.Path - if strings.HasPrefix(network, "unix") { - if runtime.GOOS == "windows" { - _ = writeError(xerrors.New("Unix forwarding is not supported from Windows workspaces")) - return - } - addr, err = ExpandRelativeHomePath(addr) - if err != nil { - _ = writeError(xerrors.Errorf("expand path %q: %w", addr, err)) - return - } - } - - d := net.Dialer{Timeout: 3 * time.Second} - nconn, err := d.DialContext(ctx, network, addr) - if err != nil { - _ = writeError(xerrors.Errorf("dial '%v://%v': %w", network, addr, err)) - return - } - - err = writeError(nil) - if err != nil { - return - } - - Bicopy(ctx, conn, nconn) -} - // isClosed returns whether the API is closed or not. 
func (a *agent) isClosed() bool { select { diff --git a/agent/agent_test.go b/agent/agent_test.go index 9f75f42363ea6..08c7918765319 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -20,12 +20,10 @@ import ( "golang.org/x/xerrors" "tailscale.com/net/speedtest" - "tailscale.com/tailcfg" scp "github.com/bramvdbogaerde/go-scp" "github.com/google/uuid" "github.com/pion/udp" - "github.com/pion/webrtc/v3" "github.com/pkg/sftp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -37,10 +35,6 @@ import ( "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/agent" - "github.com/coder/coder/peer" - "github.com/coder/coder/peerbroker" - "github.com/coder/coder/peerbroker/proto" - "github.com/coder/coder/provisionersdk" "github.com/coder/coder/pty/ptytest" "github.com/coder/coder/tailnet" "github.com/coder/coder/tailnet/tailnettest" @@ -54,64 +48,49 @@ func TestMain(m *testing.M) { func TestAgent(t *testing.T) { t.Parallel() t.Run("Stats", func(t *testing.T) { - for _, tailscale := range []bool{true, false} { - t.Run(fmt.Sprintf("tailscale=%v", tailscale), func(t *testing.T) { - t.Parallel() + t.Parallel() - setupAgent := func(t *testing.T) (agent.Conn, <-chan *agent.Stats) { - var derpMap *tailcfg.DERPMap - if tailscale { - derpMap = tailnettest.RunDERPAndSTUN(t) - } - conn, stats := setupAgent(t, agent.Metadata{ - DERPMap: derpMap, - }, 0) - assert.Empty(t, <-stats) - return conn, stats - } + t.Run("SSH", func(t *testing.T) { + t.Parallel() + conn, stats := setupAgent(t, agent.Metadata{}, 0) + + sshClient, err := conn.SSHClient() + require.NoError(t, err) + defer sshClient.Close() + session, err := sshClient.NewSession() + require.NoError(t, err) + defer session.Close() + + assert.EqualValues(t, 1, (<-stats).NumConns) + assert.Greater(t, (<-stats).RxBytes, int64(0)) + assert.Greater(t, (<-stats).TxBytes, int64(0)) + }) + + t.Run("ReconnectingPTY", func(t *testing.T) { + t.Parallel() + + conn, stats := setupAgent(t, 
agent.Metadata{}, 0) - t.Run("SSH", func(t *testing.T) { - t.Parallel() - conn, stats := setupAgent(t) - - sshClient, err := conn.SSHClient() - require.NoError(t, err) - session, err := sshClient.NewSession() - require.NoError(t, err) - defer session.Close() - - assert.EqualValues(t, 1, (<-stats).NumConns) - assert.Greater(t, (<-stats).RxBytes, int64(0)) - assert.Greater(t, (<-stats).TxBytes, int64(0)) - }) - - t.Run("ReconnectingPTY", func(t *testing.T) { - t.Parallel() - - conn, stats := setupAgent(t) - - ptyConn, err := conn.ReconnectingPTY(uuid.NewString(), 128, 128, "/bin/bash") - require.NoError(t, err) - defer ptyConn.Close() - - data, err := json.Marshal(agent.ReconnectingPTYRequest{ - Data: "echo test\r\n", - }) - require.NoError(t, err) - _, err = ptyConn.Write(data) - require.NoError(t, err) - - var s *agent.Stats - require.Eventuallyf(t, func() bool { - var ok bool - s, ok = (<-stats) - return ok && s.NumConns > 0 && s.RxBytes > 0 && s.TxBytes > 0 - }, testutil.WaitLong, testutil.IntervalFast, - "never saw stats: %+v", s, - ) - }) + ptyConn, err := conn.ReconnectingPTY(uuid.NewString(), 128, 128, "/bin/bash") + require.NoError(t, err) + defer ptyConn.Close() + + data, err := json.Marshal(agent.ReconnectingPTYRequest{ + Data: "echo test\r\n", }) - } + require.NoError(t, err) + _, err = ptyConn.Write(data) + require.NoError(t, err) + + var s *agent.Stats + require.Eventuallyf(t, func() bool { + var ok bool + s, ok = (<-stats) + return ok && s.NumConns > 0 && s.RxBytes > 0 && s.TxBytes > 0 + }, testutil.WaitLong, testutil.IntervalFast, + "never saw stats: %+v", s, + ) + }) }) t.Run("SessionExec", func(t *testing.T) { @@ -235,6 +214,7 @@ func TestAgent(t *testing.T) { conn, _ := setupAgent(t, agent.Metadata{}, 0) sshClient, err := conn.SSHClient() require.NoError(t, err) + defer sshClient.Close() client, err := sftp.NewClient(sshClient) require.NoError(t, err) tempFile := filepath.Join(t.TempDir(), "sftp") @@ -252,6 +232,7 @@ func TestAgent(t *testing.T) { 
conn, _ := setupAgent(t, agent.Metadata{}, 0) sshClient, err := conn.SSHClient() require.NoError(t, err) + defer sshClient.Close() scpClient, err := scp.NewClientBySSH(sshClient) require.NoError(t, err) tempFile := filepath.Join(t.TempDir(), "scp") @@ -384,9 +365,7 @@ func TestAgent(t *testing.T) { t.Skip("ConPTY appears to be inconsistent on Windows.") } - conn, _ := setupAgent(t, agent.Metadata{ - DERPMap: tailnettest.RunDERPAndSTUN(t), - }, 0) + conn, _ := setupAgent(t, agent.Metadata{}, 0) id := uuid.NewString() netConn, err := conn.ReconnectingPTY(id, 100, 100, "/bin/bash") require.NoError(t, err) @@ -462,19 +441,6 @@ func TestAgent(t *testing.T) { return l }, }, - { - name: "Unix", - setup: func(t *testing.T) net.Listener { - if runtime.GOOS == "windows" { - t.Skip("Unix socket forwarding isn't supported on Windows") - } - - tmpDir := t.TempDir() - l, err := net.Listen("unix", filepath.Join(tmpDir, "test.sock")) - require.NoError(t, err, "create UDP listener") - return l - }, - }, } for _, c := range cases { @@ -496,8 +462,11 @@ func TestAgent(t *testing.T) { } }() - // Dial the listener over WebRTC twice and test out of order conn, _ := setupAgent(t, agent.Metadata{}, 0) + require.Eventually(t, func() bool { + _, err := conn.Ping() + return err == nil + }, testutil.WaitMedium, testutil.IntervalFast) conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String()) require.NoError(t, err) defer conn1.Close() @@ -506,36 +475,11 @@ func TestAgent(t *testing.T) { defer conn2.Close() testDial(t, conn2) testDial(t, conn1) + time.Sleep(150 * time.Millisecond) }) } }) - t.Run("DialError", func(t *testing.T) { - t.Parallel() - - if runtime.GOOS == "windows" { - // This test uses Unix listeners so we can very easily ensure that - // no other tests decide to listen on the same random port we - // picked. 
- t.Skip("this test is unsupported on Windows") - return - } - - tmpDir, err := os.MkdirTemp("", "coderd_agent_test_") - require.NoError(t, err, "create temp dir") - t.Cleanup(func() { - _ = os.RemoveAll(tmpDir) - }) - - // Try to dial the non-existent Unix socket over WebRTC - conn, _ := setupAgent(t, agent.Metadata{}, 0) - netConn, err := conn.DialContext(context.Background(), "unix", filepath.Join(tmpDir, "test.sock")) - require.Error(t, err) - require.ErrorContains(t, err, "remote dial error") - require.ErrorContains(t, err, "no such file") - require.Nil(t, netConn) - }) - t.Run("Tailnet", func(t *testing.T) { t.Parallel() derpMap := tailnettest.RunDERPAndSTUN(t) @@ -578,7 +522,7 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe return } ssh, err := agentConn.SSH() - if !assert.NoError(t, err) { + if err != nil { _ = conn.Close() return } @@ -622,11 +566,12 @@ func (c closeFunc) Close() error { } func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) ( - agent.Conn, + *agent.Conn, <-chan *agent.Stats, ) { - client, server := provisionersdk.TransportPipe() - tailscale := metadata.DERPMap != nil + if metadata.DERPMap == nil { + metadata.DERPMap = tailnettest.RunDERPAndSTUN(t) + } coordinator := tailnet.NewCoordinator() agentID := uuid.New() statsCh := make(chan *agent.Stats) @@ -634,17 +579,18 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) FetchMetadata: func(ctx context.Context) (agent.Metadata, error) { return metadata, nil }, - WebRTCDialer: func(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error) { - listener, err := peerbroker.Listen(server, nil) - return listener, err - }, CoordinatorDialer: func(ctx context.Context) (net.Conn, error) { clientConn, serverConn := net.Pipe() + closed := make(chan struct{}) t.Cleanup(func() { _ = serverConn.Close() _ = clientConn.Close() + <-closed }) - go coordinator.ServeAgent(serverConn, agentID) + go func() { 
+ _ = coordinator.ServeAgent(serverConn, agentID) + close(closed) + }() return clientConn, nil }, Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug), @@ -683,46 +629,27 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) }, }) t.Cleanup(func() { - _ = client.Close() - _ = server.Close() _ = closer.Close() }) - api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client)) - stream, err := api.NegotiateConnection(context.Background()) - assert.NoError(t, err) - if tailscale { - conn, err := tailnet.NewConn(&tailnet.Options{ - Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, - DERPMap: metadata.DERPMap, - Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug), - }) - require.NoError(t, err) - clientConn, serverConn := net.Pipe() - t.Cleanup(func() { - _ = clientConn.Close() - _ = serverConn.Close() - _ = conn.Close() - }) - go coordinator.ServeClient(serverConn, uuid.New(), agentID) - sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error { - return conn.UpdateNodes(node) - }) - conn.SetNodeCallback(sendNode) - return &agent.TailnetConn{ - Conn: conn, - }, statsCh - } - conn, err := peerbroker.Dial(stream, []webrtc.ICEServer{}, &peer.ConnOptions{ - Logger: slogtest.Make(t, nil), + conn, err := tailnet.NewConn(&tailnet.Options{ + Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, + DERPMap: metadata.DERPMap, + Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug), }) require.NoError(t, err) + clientConn, serverConn := net.Pipe() t.Cleanup(func() { + _ = clientConn.Close() + _ = serverConn.Close() _ = conn.Close() }) - - return &agent.WebRTCConn{ - Negotiator: api, - Conn: conn, + go coordinator.ServeClient(serverConn, uuid.New(), agentID) + sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error { + return conn.UpdateNodes(node) + }) + conn.SetNodeCallback(sendNode) + return &agent.Conn{ + Conn: conn, }, statsCh 
} diff --git a/agent/conn.go b/agent/conn.go index 2e1b45ea0b48a..b64e935af7ecc 100644 --- a/agent/conn.go +++ b/agent/conn.go @@ -4,13 +4,9 @@ import ( "context" "encoding/binary" "encoding/json" - "fmt" - "io" "net" "net/netip" - "net/url" "strconv" - "strings" "time" "golang.org/x/crypto/ssh" @@ -19,8 +15,6 @@ import ( "tailscale.com/net/speedtest" "tailscale.com/tailcfg" - "github.com/coder/coder/peer" - "github.com/coder/coder/peerbroker/proto" "github.com/coder/coder/tailnet" ) @@ -32,123 +26,12 @@ type ReconnectingPTYRequest struct { Width uint16 `json:"width"` } -// Conn is a temporary interface while we switch from WebRTC to Wireguard networking. -type Conn interface { - io.Closer - Closed() <-chan struct{} - Ping() (time.Duration, error) - CloseWithError(err error) error - ReconnectingPTY(id string, height, width uint16, command string) (net.Conn, error) - SSH() (net.Conn, error) - Speedtest(direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) - SSHClient() (*ssh.Client, error) - DialContext(ctx context.Context, network string, addr string) (net.Conn, error) -} - -// Conn wraps a peer connection with helper functions to -// communicate with the agent. -type WebRTCConn struct { - // Negotiator is responsible for exchanging messages. - Negotiator proto.DRPCPeerBrokerClient - - *peer.Conn -} - -// ReconnectingPTY returns a connection serving a TTY that can -// be reconnected to via ID. -// -// The command is optional and defaults to start a shell. -func (c *WebRTCConn) ReconnectingPTY(id string, height, width uint16, command string) (net.Conn, error) { - channel, err := c.CreateChannel(context.Background(), fmt.Sprintf("%s:%d:%d:%s", id, height, width, command), &peer.ChannelOptions{ - Protocol: ProtocolReconnectingPTY, - }) - if err != nil { - return nil, xerrors.Errorf("pty: %w", err) - } - return channel.NetConn(), nil -} - -// SSH dials the built-in SSH server. 
-func (c *WebRTCConn) SSH() (net.Conn, error) { - channel, err := c.CreateChannel(context.Background(), "ssh", &peer.ChannelOptions{ - Protocol: ProtocolSSH, - }) - if err != nil { - return nil, xerrors.Errorf("dial: %w", err) - } - return channel.NetConn(), nil -} - -func (*WebRTCConn) Speedtest(_ speedtest.Direction, _ time.Duration) ([]speedtest.Result, error) { - return nil, xerrors.New("not implemented") -} - -// SSHClient calls SSH to create a client that uses a weak cipher -// for high throughput. -func (c *WebRTCConn) SSHClient() (*ssh.Client, error) { - netConn, err := c.SSH() - if err != nil { - return nil, xerrors.Errorf("ssh: %w", err) - } - sshConn, channels, requests, err := ssh.NewClientConn(netConn, "localhost:22", &ssh.ClientConfig{ - // SSH host validation isn't helpful, because obtaining a peer - // connection already signifies user-intent to dial a workspace. - // #nosec - HostKeyCallback: ssh.InsecureIgnoreHostKey(), - }) - if err != nil { - return nil, xerrors.Errorf("ssh conn: %w", err) - } - return ssh.NewClient(sshConn, channels, requests), nil -} - -// DialContext dials an arbitrary protocol+address from inside the workspace and -// proxies it through the provided net.Conn. -func (c *WebRTCConn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { - u := &url.URL{ - Scheme: network, - } - if strings.HasPrefix(network, "unix") { - u.Path = addr - } else { - u.Host = addr - } - - channel, err := c.CreateChannel(ctx, u.String(), &peer.ChannelOptions{ - Protocol: ProtocolDial, - Unordered: strings.HasPrefix(network, "udp"), - }) - if err != nil { - return nil, xerrors.Errorf("create datachannel: %w", err) - } - - // The first message written from the other side is a JSON payload - // containing the dial error. 
- dec := json.NewDecoder(channel) - var res dialResponse - err = dec.Decode(&res) - if err != nil { - return nil, xerrors.Errorf("decode agent dial response: %w", err) - } - if res.Error != "" { - _ = channel.Close() - return nil, xerrors.Errorf("remote dial error: %v", res.Error) - } - - return channel.NetConn(), nil -} - -func (c *WebRTCConn) Close() error { - _ = c.Negotiator.DRPCConn().Close() - return c.Conn.Close() -} - -type TailnetConn struct { +type Conn struct { *tailnet.Conn CloseFunc func() } -func (c *TailnetConn) Ping() (time.Duration, error) { +func (c *Conn) Ping() (time.Duration, error) { errCh := make(chan error, 1) durCh := make(chan time.Duration, 1) c.Conn.Ping(tailnetIP, tailcfg.PingICMP, func(pr *ipnstate.PingResult) { @@ -166,11 +49,11 @@ func (c *TailnetConn) Ping() (time.Duration, error) { } } -func (c *TailnetConn) CloseWithError(_ error) error { +func (c *Conn) CloseWithError(_ error) error { return c.Close() } -func (c *TailnetConn) Close() error { +func (c *Conn) Close() error { if c.CloseFunc != nil { c.CloseFunc() } @@ -184,7 +67,7 @@ type reconnectingPTYInit struct { Command string } -func (c *TailnetConn) ReconnectingPTY(id string, height, width uint16, command string) (net.Conn, error) { +func (c *Conn) ReconnectingPTY(id string, height, width uint16, command string) (net.Conn, error) { conn, err := c.DialContextTCP(context.Background(), netip.AddrPortFrom(tailnetIP, uint16(tailnetReconnectingPTYPort))) if err != nil { return nil, err @@ -210,13 +93,13 @@ func (c *TailnetConn) ReconnectingPTY(id string, height, width uint16, command s return conn, nil } -func (c *TailnetConn) SSH() (net.Conn, error) { +func (c *Conn) SSH() (net.Conn, error) { return c.DialContextTCP(context.Background(), netip.AddrPortFrom(tailnetIP, uint16(tailnetSSHPort))) } // SSHClient calls SSH to create a client that uses a weak cipher // for high throughput. 
-func (c *TailnetConn) SSHClient() (*ssh.Client, error) { +func (c *Conn) SSHClient() (*ssh.Client, error) { netConn, err := c.SSH() if err != nil { return nil, xerrors.Errorf("ssh: %w", err) @@ -233,7 +116,7 @@ func (c *TailnetConn) SSHClient() (*ssh.Client, error) { return ssh.NewClient(sshConn, channels, requests), nil } -func (c *TailnetConn) Speedtest(direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) { +func (c *Conn) Speedtest(direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) { speedConn, err := c.DialContextTCP(context.Background(), netip.AddrPortFrom(tailnetIP, uint16(tailnetSpeedtestPort))) if err != nil { return nil, xerrors.Errorf("dial speedtest: %w", err) @@ -245,7 +128,10 @@ func (c *TailnetConn) Speedtest(direction speedtest.Direction, duration time.Dur return results, err } -func (c *TailnetConn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { +func (c *Conn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { + if network == "unix" { + return nil, xerrors.New("network must be tcp or udp") + } _, rawPort, _ := net.SplitHostPort(addr) port, _ := strconv.Atoi(rawPort) ipp := netip.AddrPortFrom(tailnetIP, uint16(port)) diff --git a/cli/agent.go b/cli/agent.go index 2c6fdef4a03ce..837d30eb37176 100644 --- a/cli/agent.go +++ b/cli/agent.go @@ -32,7 +32,6 @@ func workspaceAgent() *cobra.Command { pprofEnabled bool pprofAddress string noReap bool - wireguard bool ) cmd := &cobra.Command{ Use: "agent", @@ -184,7 +183,6 @@ func workspaceAgent() *cobra.Command { closer := agent.New(agent.Options{ FetchMetadata: client.WorkspaceAgentMetadata, - WebRTCDialer: client.ListenWorkspaceAgent, Logger: logger, EnvironmentVariables: map[string]string{ // Override the "CODER_AGENT_TOKEN" variable in all @@ -203,6 +201,5 @@ func workspaceAgent() *cobra.Command { cliflag.BoolVarP(cmd.Flags(), &pprofEnabled, "pprof-enable", "", 
"CODER_AGENT_PPROF_ENABLE", false, "Enable serving pprof metrics on the address defined by --pprof-address.") cliflag.BoolVarP(cmd.Flags(), &noReap, "no-reap", "", "", false, "Do not start a process reaper.") cliflag.StringVarP(cmd.Flags(), &pprofAddress, "pprof-address", "", "CODER_AGENT_PPROF_ADDRESS", "127.0.0.1:6060", "The address to serve pprof.") - cliflag.BoolVarP(cmd.Flags(), &wireguard, "wireguard", "", "CODER_AGENT_WIREGUARD", true, "Whether to start the Wireguard interface.") return cmd } diff --git a/cli/agent_test.go b/cli/agent_test.go index 82c199cd6268f..6dd8849b74d79 100644 --- a/cli/agent_test.go +++ b/cli/agent_test.go @@ -7,10 +7,13 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "cdr.dev/slog" + "github.com/coder/coder/cli/clitest" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" + "github.com/coder/coder/testutil" ) func TestWorkspaceAgent(t *testing.T) { @@ -63,11 +66,13 @@ func TestWorkspaceAgent(t *testing.T) { if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) { assert.NotEmpty(t, resources[0].Agents[0].Version) } - dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil) + dialer, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID) require.NoError(t, err) defer dialer.Close() - _, err = dialer.Ping() - require.NoError(t, err) + require.Eventually(t, func() bool { + _, err := dialer.Ping() + return err == nil + }, testutil.WaitMedium, testutil.IntervalFast) cancelFunc() err = <-errC require.NoError(t, err) @@ -121,11 +126,13 @@ func TestWorkspaceAgent(t *testing.T) { if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) { assert.NotEmpty(t, resources[0].Agents[0].Version) } - dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil) + dialer, err := client.DialWorkspaceAgentTailnet(ctx, 
slog.Logger{}, resources[0].Agents[0].ID) require.NoError(t, err) defer dialer.Close() - _, err = dialer.Ping() - require.NoError(t, err) + require.Eventually(t, func() bool { + _, err := dialer.Ping() + return err == nil + }, testutil.WaitMedium, testutil.IntervalFast) cancelFunc() err = <-errC require.NoError(t, err) @@ -179,11 +186,13 @@ func TestWorkspaceAgent(t *testing.T) { if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) { assert.NotEmpty(t, resources[0].Agents[0].Version) } - dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil) + dialer, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID) require.NoError(t, err) defer dialer.Close() - _, err = dialer.Ping() - require.NoError(t, err) + require.Eventually(t, func() bool { + _, err := dialer.Ping() + return err == nil + }, testutil.WaitMedium, testutil.IntervalFast) cancelFunc() err = <-errC require.NoError(t, err) diff --git a/cli/configssh.go b/cli/configssh.go index 0ff27ca29f16b..373257403110d 100644 --- a/cli/configssh.go +++ b/cli/configssh.go @@ -139,7 +139,6 @@ func configSSH() *cobra.Command { usePreviousOpts bool dryRun bool skipProxyCommand bool - wireguard bool ) cmd := &cobra.Command{ Annotations: workspaceCommand, @@ -289,15 +288,11 @@ func configSSH() *cobra.Command { "\tLogLevel ERROR", ) if !skipProxyCommand { - wgArg := "" - if wireguard { - wgArg = "--wireguard " - } configOptions = append( configOptions, fmt.Sprintf( - "\tProxyCommand %s --global-config %s ssh %s--stdio %s", - escapedCoderBinary, escapedGlobalConfig, wgArg, hostname, + "\tProxyCommand %s --global-config %s ssh --stdio %s", + escapedCoderBinary, escapedGlobalConfig, hostname, ), ) } @@ -374,9 +369,6 @@ func configSSH() *cobra.Command { cmd.Flags().BoolVarP(&skipProxyCommand, "skip-proxy-command", "", false, "Specifies whether the ProxyCommand option should be skipped. 
Useful for testing.") _ = cmd.Flags().MarkHidden("skip-proxy-command") cliflag.BoolVarP(cmd.Flags(), &usePreviousOpts, "use-previous-options", "", "CODER_SSH_USE_PREVIOUS_OPTIONS", false, "Specifies whether or not to keep options from previous run of config-ssh.") - cliflag.BoolVarP(cmd.Flags(), &wireguard, "wireguard", "", "CODER_CONFIG_SSH_WIREGUARD", true, "Whether to use Wireguard for SSH tunneling.") - _ = cmd.Flags().MarkHidden("wireguard") - cliui.AllowSkipPrompt(cmd) return cmd diff --git a/cli/configssh_test.go b/cli/configssh_test.go index e91d46c03aaeb..e1ae4054b5bea 100644 --- a/cli/configssh_test.go +++ b/cli/configssh_test.go @@ -12,12 +12,14 @@ import ( "path/filepath" "strconv" "strings" + "sync" "testing" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/agent" @@ -106,15 +108,14 @@ func TestConfigSSH(t *testing.T) { agentClient.SessionToken = authToken agentCloser := agent.New(agent.Options{ FetchMetadata: agentClient.WorkspaceAgentMetadata, - WebRTCDialer: agentClient.ListenWorkspaceAgent, - CoordinatorDialer: client.ListenWorkspaceAgentTailnet, + CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet, Logger: slogtest.Make(t, nil).Named("agent"), }) defer func() { _ = agentCloser.Close() }() resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID) - agentConn, err := client.DialWorkspaceAgent(context.Background(), resources[0].Agents[0].ID, nil) + agentConn, err := client.DialWorkspaceAgentTailnet(context.Background(), slog.Logger{}, resources[0].Agents[0].ID) require.NoError(t, err) defer agentConn.Close() @@ -123,17 +124,28 @@ func TestConfigSSH(t *testing.T) { defer func() { _ = listener.Close() }() + copyDone := make(chan struct{}) go func() { + defer close(copyDone) + var wg sync.WaitGroup for { conn, err := listener.Accept() if err != nil { - return + break } ssh, err := agentConn.SSH() 
assert.NoError(t, err) - go io.Copy(conn, ssh) - go io.Copy(ssh, conn) + wg.Add(2) + go func() { + defer wg.Done() + _, _ = io.Copy(conn, ssh) + }() + go func() { + defer wg.Done() + _, _ = io.Copy(ssh, conn) + }() } + wg.Wait() }() sshConfigFile := sshConfigFileName(t) @@ -178,6 +190,9 @@ func TestConfigSSH(t *testing.T) { data, err := sshCmd.Output() require.NoError(t, err) require.Equal(t, "test", strings.TrimSpace(string(data))) + + _ = listener.Close() + <-copyDone } func TestConfigSSH_FileWriteAndOptionsFlow(t *testing.T) { diff --git a/cli/gitssh_test.go b/cli/gitssh_test.go index a187566f61b11..78ebcb7d0ff52 100644 --- a/cli/gitssh_test.go +++ b/cli/gitssh_test.go @@ -20,6 +20,8 @@ import ( "github.com/stretchr/testify/require" gossh "golang.org/x/crypto/ssh" + "cdr.dev/slog" + "github.com/coder/coder/cli/clitest" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/codersdk" @@ -72,7 +74,7 @@ func prepareTestGitSSH(ctx context.Context, t *testing.T) (*codersdk.Client, str coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) // start workspace agent - cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String(), "--wireguard=false") + cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String()) agentClient := client clitest.SetupConfig(t, agentClient, root) @@ -85,11 +87,13 @@ func prepareTestGitSSH(ctx context.Context, t *testing.T) (*codersdk.Client, str coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID) resources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID) require.NoError(t, err) - dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil) + dialer, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID) require.NoError(t, err) defer dialer.Close() - _, err = dialer.Ping() - require.NoError(t, err) + require.Eventually(t, func() bool { + _, err = 
dialer.Ping() + return err == nil + }, testutil.WaitMedium, testutil.IntervalFast) return agentClient, agentToken, pubkey } diff --git a/cli/portforward.go b/cli/portforward.go index bdcceab723daa..7943291c042c0 100644 --- a/cli/portforward.go +++ b/cli/portforward.go @@ -17,9 +17,7 @@ import ( "golang.org/x/xerrors" "cdr.dev/slog" - "cdr.dev/slog/sloggers/sloghuman" "github.com/coder/coder/agent" - "github.com/coder/coder/cli/cliflag" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/codersdk" ) @@ -28,7 +26,6 @@ func portForward() *cobra.Command { var ( tcpForwards []string // : udpForwards []string // : - wireguard bool ) cmd := &cobra.Command{ Use: "port-forward ", @@ -94,16 +91,7 @@ func portForward() *cobra.Command { return xerrors.Errorf("await agent: %w", err) } - var conn agent.Conn - if !wireguard { - conn, err = client.DialWorkspaceAgent(ctx, workspaceAgent.ID, nil) - } else { - logger := slog.Logger{} - if cliflag.IsSetBool(cmd, varVerbose) { - logger = slog.Make(sloghuman.Sink(cmd.ErrOrStderr())).Named("tailnet").Leveled(slog.LevelDebug) - } - conn, err = client.DialWorkspaceAgentTailnet(ctx, logger, workspaceAgent.ID) - } + conn, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, workspaceAgent.ID) if err != nil { return err } @@ -178,12 +166,10 @@ func portForward() *cobra.Command { cmd.Flags().StringArrayVarP(&tcpForwards, "tcp", "p", []string{}, "Forward a TCP port from the workspace to the local machine") cmd.Flags().StringArrayVar(&udpForwards, "udp", []string{}, "Forward a UDP port from the workspace to the local machine. 
The UDP connection has TCP-like semantics to support stateful UDP protocols") - cmd.Flags().BoolVarP(&wireguard, "wireguard", "", true, "Specifies whether to use wireguard networking or not.") - _ = cmd.Flags().MarkHidden("wireguard") return cmd } -func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn agent.Conn, wg *sync.WaitGroup, spec portForwardSpec) (net.Listener, error) { +func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn *agent.Conn, wg *sync.WaitGroup, spec portForwardSpec) (net.Listener, error) { _, _ = fmt.Fprintf(cmd.OutOrStderr(), "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress) var ( diff --git a/cli/portforward_test.go b/cli/portforward_test.go index bb0081a6d4b8a..62385b999a5fb 100644 --- a/cli/portforward_test.go +++ b/cli/portforward_test.go @@ -377,7 +377,7 @@ func setupTestListener(t *testing.T, l net.Listener) string { addr := l.Addr().String() _, port, err := net.SplitHostPort(addr) - require.NoErrorf(t, err, "split listen path %q", addr) + require.NoErrorf(t, err, "split non-Unix listen path %q", addr) addr = port return addr diff --git a/cli/server.go b/cli/server.go index bfc04447259ee..57d0880eb3755 100644 --- a/cli/server.go +++ b/cli/server.go @@ -28,8 +28,6 @@ import ( embeddedpostgres "github.com/fergusstrange/embedded-postgres" "github.com/google/go-github/v43/github" "github.com/google/uuid" - "github.com/pion/turn/v2" - "github.com/pion/webrtc/v3" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/spf13/afero" @@ -59,7 +57,6 @@ import ( "github.com/coder/coder/coderd/prometheusmetrics" "github.com/coder/coder/coderd/telemetry" "github.com/coder/coder/coderd/tracing" - "github.com/coder/coder/coderd/turnconn" "github.com/coder/coder/codersdk" "github.com/coder/coder/cryptorand" "github.com/coder/coder/provisioner/echo" @@ -113,9 +110,7 @@ 
func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command { tlsEnable bool tlsKeyFile string tlsMinVersion string - turnRelayAddress string tunnel bool - stunServers []string traceEnable bool secureAuthCookie bool sshKeygenAlgorithmRaw string @@ -300,22 +295,6 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command { return xerrors.Errorf("parse ssh keygen algorithm %s: %w", sshKeygenAlgorithmRaw, err) } - turnServer, err := turnconn.New(&turn.RelayAddressGeneratorStatic{ - RelayAddress: net.ParseIP(turnRelayAddress), - Address: turnRelayAddress, - }) - if err != nil { - return xerrors.Errorf("create turn server: %w", err) - } - defer turnServer.Close() - - iceServers := make([]webrtc.ICEServer, 0) - for _, stunServer := range stunServers { - iceServers = append(iceServers, webrtc.ICEServer{ - URLs: []string{stunServer}, - }) - } - // Validate provided auto-import templates. var ( validatedAutoImportTemplates = make([]coderd.AutoImportTemplate, len(autoImportTemplates)) @@ -360,7 +339,6 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command { options := &coderd.Options{ AccessURL: accessURLParsed, - ICEServers: iceServers, Logger: logger.Named("coderd"), Database: databasefake.New(), DERPMap: derpMap, @@ -369,8 +347,6 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command { GoogleTokenValidator: googleTokenValidator, SecureAuthCookie: secureAuthCookie, SSHKeygenAlgorithm: sshKeygenAlgorithm, - TailscaleEnable: tailscaleEnable, - TURNServer: turnServer, TracerProvider: tracerProvider, Telemetry: telemetry.NewNoop(), AutoImportTemplates: validatedAutoImportTemplates, @@ -478,7 +454,7 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command { OIDCAuth: oidcClientID != "", OIDCIssuerURL: oidcIssuerURL, Prometheus: promEnabled, - STUN: len(stunServers) != 0, + STUN: len(derpServerSTUNAddrs) != 0, Tunnel: tunnel, }) if err != nil { @@ -850,13 +826,8 @@ func Server(newAPI func(*coderd.Options) 
*coderd.API) *cobra.Command { `Minimum supported version of TLS. Accepted values are "tls10", "tls11", "tls12" or "tls13"`) cliflag.BoolVarP(root.Flags(), &tunnel, "tunnel", "", "CODER_TUNNEL", false, "Workspaces must be able to reach the `access-url`. This overrides your access URL with a public access URL that tunnels your Coder deployment.") - cliflag.StringArrayVarP(root.Flags(), &stunServers, "stun-server", "", "CODER_STUN_SERVERS", []string{ - "stun:stun.l.google.com:19302", - }, "URLs for STUN servers to enable P2P connections.") cliflag.BoolVarP(root.Flags(), &traceEnable, "trace", "", "CODER_TRACE", false, "Whether application tracing data is collected.") - cliflag.StringVarP(root.Flags(), &turnRelayAddress, "turn-relay-address", "", "CODER_TURN_RELAY_ADDRESS", "127.0.0.1", - "The address to bind TURN connections.") cliflag.BoolVarP(root.Flags(), &secureAuthCookie, "secure-auth-cookie", "", "CODER_SECURE_AUTH_COOKIE", false, "Controls if the 'Secure' property is set on browser session cookies") cliflag.StringVarP(root.Flags(), &sshKeygenAlgorithmRaw, "ssh-keygen-algorithm", "", "CODER_SSH_KEYGEN_ALGORITHM", "ed25519", diff --git a/cli/speedtest.go b/cli/speedtest.go index f1e177d4eb22d..357048f63ea34 100644 --- a/cli/speedtest.go +++ b/cli/speedtest.go @@ -12,7 +12,6 @@ import ( "cdr.dev/slog" "cdr.dev/slog/sloggers/sloghuman" - "github.com/coder/coder/agent" "github.com/coder/coder/cli/cliflag" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/codersdk" @@ -73,8 +72,7 @@ func speedtest() *cobra.Command { if err != nil { continue } - tc, _ := conn.(*agent.TailnetConn) - status := tc.Status() + status := conn.Status() if len(status.Peers()) != 1 { continue } diff --git a/cli/speedtest_test.go b/cli/speedtest_test.go index 3431c5508ee40..fa432950d1db4 100644 --- a/cli/speedtest_test.go +++ b/cli/speedtest_test.go @@ -25,7 +25,6 @@ func TestSpeedtest(t *testing.T) { agentClient.SessionToken = agentToken agentCloser := agent.New(agent.Options{ 
FetchMetadata: agentClient.WorkspaceAgentMetadata, - WebRTCDialer: agentClient.ListenWorkspaceAgent, CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet, Logger: slogtest.Make(t, nil).Named("agent"), }) diff --git a/cli/ssh.go b/cli/ssh.go index 96b243bab828a..bb8922fb0ff71 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -22,7 +22,6 @@ import ( "cdr.dev/slog" - "github.com/coder/coder/agent" "github.com/coder/coder/cli/cliflag" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/coderd/autobuild/notify" @@ -43,7 +42,6 @@ func ssh() *cobra.Command { forwardAgent bool identityAgent string wsPollInterval time.Duration - wireguard bool ) cmd := &cobra.Command{ Annotations: workspaceCommand, @@ -88,12 +86,7 @@ func ssh() *cobra.Command { return xerrors.Errorf("await agent: %w", err) } - var conn agent.Conn - if !wireguard { - conn, err = client.DialWorkspaceAgent(ctx, workspaceAgent.ID, nil) - } else { - conn, err = client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, workspaceAgent.ID) - } + conn, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, workspaceAgent.ID) if err != nil { return err } @@ -221,9 +214,6 @@ func ssh() *cobra.Command { cliflag.BoolVarP(cmd.Flags(), &forwardAgent, "forward-agent", "A", "CODER_SSH_FORWARD_AGENT", false, "Specifies whether to forward the SSH agent specified in $SSH_AUTH_SOCK") cliflag.StringVarP(cmd.Flags(), &identityAgent, "identity-agent", "", "CODER_SSH_IDENTITY_AGENT", "", "Specifies which identity agent to use (overrides $SSH_AUTH_SOCK), forward agent must also be enabled") cliflag.DurationVarP(cmd.Flags(), &wsPollInterval, "workspace-poll-interval", "", "CODER_WORKSPACE_POLL_INTERVAL", workspacePollInterval, "Specifies how often to poll for workspace automated shutdown.") - cliflag.BoolVarP(cmd.Flags(), &wireguard, "wireguard", "", "CODER_SSH_WIREGUARD", true, "Whether to use Wireguard for SSH tunneling.") - _ = cmd.Flags().MarkHidden("wireguard") - return cmd } diff --git a/cli/ssh_test.go b/cli/ssh_test.go 
index be68ad6226f36..b3f148b01519f 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -90,7 +90,6 @@ func TestSSH(t *testing.T) { agentClient.SessionToken = agentToken agentCloser := agent.New(agent.Options{ FetchMetadata: agentClient.WorkspaceAgentMetadata, - WebRTCDialer: agentClient.ListenWorkspaceAgent, CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet, Logger: slogtest.Make(t, nil).Named("agent"), }) @@ -112,7 +111,6 @@ func TestSSH(t *testing.T) { agentClient.SessionToken = agentToken agentCloser := agent.New(agent.Options{ FetchMetadata: agentClient.WorkspaceAgentMetadata, - WebRTCDialer: agentClient.ListenWorkspaceAgent, CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet, Logger: slogtest.Make(t, nil).Named("agent"), }) @@ -181,7 +179,6 @@ func TestSSH(t *testing.T) { agentClient.SessionToken = agentToken agentCloser := agent.New(agent.Options{ FetchMetadata: agentClient.WorkspaceAgentMetadata, - WebRTCDialer: agentClient.ListenWorkspaceAgent, CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet, Logger: slogtest.Make(t, nil).Named("agent"), }) diff --git a/coderd/coderd.go b/coderd/coderd.go index be2faa04c17cf..0e91f0b4d9339 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -13,7 +13,6 @@ import ( "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/klauspost/compress/zstd" - "github.com/pion/webrtc/v3" "github.com/prometheus/client_golang/prometheus" "go.opentelemetry.io/otel/trace" "golang.org/x/xerrors" @@ -35,7 +34,6 @@ import ( "github.com/coder/coder/coderd/rbac" "github.com/coder/coder/coderd/telemetry" "github.com/coder/coder/coderd/tracing" - "github.com/coder/coder/coderd/turnconn" "github.com/coder/coder/coderd/wsconncache" "github.com/coder/coder/codersdk" "github.com/coder/coder/site" @@ -65,17 +63,14 @@ type Options struct { GithubOAuth2Config *GithubOAuth2Config OIDCConfig *OIDCConfig PrometheusRegistry *prometheus.Registry - ICEServers []webrtc.ICEServer SecureAuthCookie bool 
SSHKeygenAlgorithm gitsshkey.Algorithm Telemetry telemetry.Reporter - TURNServer *turnconn.Server TracerProvider trace.TracerProvider AutoImportTemplates []AutoImportTemplate LicenseHandler http.Handler FeaturesService features.Service - TailscaleEnable bool TailnetCoordinator *tailnet.Coordinator DERPMap *tailcfg.DERPMap @@ -92,6 +87,12 @@ func New(options *Options) *API { // Multiply the update by two to allow for some lag-time. options.AgentInactiveDisconnectTimeout = options.AgentConnectionUpdateFrequency * 2 } + if options.AgentStatsRefreshInterval == 0 { + options.AgentStatsRefreshInterval = 10 * time.Minute + } + if options.MetricsCacheRefreshInterval == 0 { + options.MetricsCacheRefreshInterval = time.Hour + } if options.APIRateLimit == 0 { options.APIRateLimit = 512 } @@ -149,11 +150,7 @@ func New(options *Options) *API { }, metricsCache: metricsCache, } - if options.TailscaleEnable { - api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0) - } else { - api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgent, 0) - } + api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0) api.derpServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger)) oauthConfigs := &httpmw.OAuth2Configs{ Github: options.GithubOAuth2Config, @@ -415,14 +412,8 @@ func New(options *Options) *API { r.Use(httpmw.ExtractWorkspaceAgent(options.Database)) r.Get("/metadata", api.workspaceAgentMetadata) r.Post("/version", api.postWorkspaceAgentVersion) - r.Get("/listen", api.workspaceAgentListen) - r.Get("/gitsshkey", api.agentGitSSHKey) - r.Get("/turn", api.workspaceAgentTurn) - r.Get("/iceservers", api.workspaceAgentICEServers) - r.Get("/coordinate", api.workspaceAgentCoordinate) - r.Get("/report-stats", api.workspaceAgentReportStats) }) r.Route("/{workspaceagent}", func(r chi.Router) { @@ -432,11 +423,7 @@ func New(options *Options) *API { httpmw.ExtractWorkspaceParam(options.Database), ) r.Get("/", api.workspaceAgent) - 
r.Get("/dial", api.workspaceAgentDial) - r.Get("/turn", api.userWorkspaceAgentTurn) r.Get("/pty", api.workspaceAgentPTY) - r.Get("/iceservers", api.workspaceAgentICEServers) - r.Get("/connection", api.workspaceAgentConnection) r.Get("/coordinate", api.workspaceAgentClientCoordinate) }) diff --git a/coderd/coderdtest/authtest.go b/coderd/coderdtest/authtest.go index 2ba404d7c4254..564bfbacf91cd 100644 --- a/coderd/coderdtest/authtest.go +++ b/coderd/coderdtest/authtest.go @@ -188,18 +188,14 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) { "GET:/api/v2/users/oidc/callback": {NoAuthorize: true}, // All workspaceagents endpoints do not use rbac - "POST:/api/v2/workspaceagents/aws-instance-identity": {NoAuthorize: true}, - "POST:/api/v2/workspaceagents/azure-instance-identity": {NoAuthorize: true}, - "POST:/api/v2/workspaceagents/google-instance-identity": {NoAuthorize: true}, - "GET:/api/v2/workspaceagents/me/gitsshkey": {NoAuthorize: true}, - "GET:/api/v2/workspaceagents/me/iceservers": {NoAuthorize: true}, - "GET:/api/v2/workspaceagents/me/listen": {NoAuthorize: true}, - "GET:/api/v2/workspaceagents/me/metadata": {NoAuthorize: true}, - "GET:/api/v2/workspaceagents/me/turn": {NoAuthorize: true}, - "GET:/api/v2/workspaceagents/me/coordinate": {NoAuthorize: true}, - "POST:/api/v2/workspaceagents/me/version": {NoAuthorize: true}, - "GET:/api/v2/workspaceagents/me/report-stats": {NoAuthorize: true}, - "GET:/api/v2/workspaceagents/{workspaceagent}/iceservers": {NoAuthorize: true}, + "POST:/api/v2/workspaceagents/aws-instance-identity": {NoAuthorize: true}, + "POST:/api/v2/workspaceagents/azure-instance-identity": {NoAuthorize: true}, + "POST:/api/v2/workspaceagents/google-instance-identity": {NoAuthorize: true}, + "GET:/api/v2/workspaceagents/me/gitsshkey": {NoAuthorize: true}, + "GET:/api/v2/workspaceagents/me/metadata": {NoAuthorize: true}, + "GET:/api/v2/workspaceagents/me/coordinate": {NoAuthorize: true}, + 
"POST:/api/v2/workspaceagents/me/version": {NoAuthorize: true}, + "GET:/api/v2/workspaceagents/me/report-stats": {NoAuthorize: true}, // These endpoints have more assertions. This is good, add more endpoints to assert if you can! "GET:/api/v2/organizations/{organization}": {AssertObject: rbac.ResourceOrganization.InOrg(a.Admin.OrganizationID)}, @@ -256,14 +252,6 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) { AssertAction: rbac.ActionRead, AssertObject: workspaceRBACObj, }, - "GET:/api/v2/workspaceagents/{workspaceagent}/dial": { - AssertAction: rbac.ActionCreate, - AssertObject: workspaceExecObj, - }, - "GET:/api/v2/workspaceagents/{workspaceagent}/turn": { - AssertAction: rbac.ActionCreate, - AssertObject: workspaceExecObj, - }, "GET:/api/v2/workspaceagents/{workspaceagent}/pty": { AssertAction: rbac.ActionCreate, AssertObject: workspaceExecObj, diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 48690ebb54cc4..709078b154171 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -54,7 +54,6 @@ import ( "github.com/coder/coder/coderd/gitsshkey" "github.com/coder/coder/coderd/rbac" "github.com/coder/coder/coderd/telemetry" - "github.com/coder/coder/coderd/turnconn" "github.com/coder/coder/coderd/util/ptr" "github.com/coder/coder/codersdk" "github.com/coder/coder/cryptorand" @@ -202,12 +201,6 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c options.SSHKeygenAlgorithm = gitsshkey.AlgorithmEd25519 } - turnServer, err := turnconn.New(nil) - require.NoError(t, err) - t.Cleanup(func() { - _ = turnServer.Close() - }) - features := coderd.DisabledImplementations if options.Auditor != nil { features.Auditor = options.Auditor @@ -231,7 +224,6 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c OIDCConfig: options.OIDCConfig, GoogleTokenValidator: options.GoogleTokenValidator, SSHKeygenAlgorithm: 
options.SSHKeygenAlgorithm, - TURNServer: turnServer, APIRateLimit: options.APIRateLimit, Authorizer: options.Authorizer, Telemetry: telemetry.NewNoop(), diff --git a/coderd/templates_test.go b/coderd/templates_test.go index d3bcbd47dc33a..5052e4f9ba467 100644 --- a/coderd/templates_test.go +++ b/coderd/templates_test.go @@ -604,7 +604,6 @@ func TestTemplateDAUs(t *testing.T) { agentCloser := agent.New(agent.Options{ Logger: slogtest.Make(t, nil), StatsReporter: agentClient.AgentReportStats, - WebRTCDialer: agentClient.ListenWorkspaceAgent, FetchMetadata: agentClient.WorkspaceAgentMetadata, CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet, }) diff --git a/coderd/turnconn/turnconn.go b/coderd/turnconn/turnconn.go deleted file mode 100644 index b8231146d3cba..0000000000000 --- a/coderd/turnconn/turnconn.go +++ /dev/null @@ -1,203 +0,0 @@ -package turnconn - -import ( - "io" - "net" - "sync" - - "github.com/pion/logging" - "github.com/pion/turn/v2" - "github.com/pion/webrtc/v3" - "golang.org/x/net/proxy" - "golang.org/x/xerrors" -) - -var ( - // reservedAddress is a magic address that's used exclusively - // for proxying via Coder. We don't proxy all TURN connections, - // because that'd exclude the possibility of a customer using - // their own TURN server. - reservedAddress = "127.0.0.1:12345" - credential = "coder" - localhost = &net.TCPAddr{ - IP: net.IPv4(127, 0, 0, 1), - } - - // Proxy is a an ICE Server that uses a special hostname - // to indicate traffic should be proxied. - Proxy = webrtc.ICEServer{ - URLs: []string{"turns:" + reservedAddress}, - Username: "coder", - Credential: credential, - } -) - -// New constructs a new TURN server binding to the relay address provided. -// The relay address is used to broadcast the location of an accepted connection. 
-func New(relayAddress *turn.RelayAddressGeneratorStatic) (*Server, error) { - if relayAddress == nil { - relayAddress = &turn.RelayAddressGeneratorStatic{ - RelayAddress: localhost.IP, - Address: "127.0.0.1", - } - } - logger := logging.NewDefaultLoggerFactory() - logger.DefaultLogLevel = logging.LogLevelDisabled - server := &Server{ - conns: make(chan net.Conn, 1), - closed: make(chan struct{}), - } - server.listener = &listener{ - srv: server, - } - var err error - server.turn, err = turn.NewServer(turn.ServerConfig{ - AuthHandler: func(username, realm string, srcAddr net.Addr) (key []byte, ok bool) { - // TURN connections require credentials. It's not important - // for our use-case, because our listener is entirely in-memory. - return turn.GenerateAuthKey(Proxy.Username, "", credential), true - }, - ListenerConfigs: []turn.ListenerConfig{{ - Listener: server.listener, - RelayAddressGenerator: relayAddress, - }}, - LoggerFactory: logger, - }) - if err != nil { - return nil, xerrors.Errorf("create server: %w", err) - } - - return server, nil -} - -// Server accepts and connects TURN allocations. -// -// This is a thin wrapper around pion/turn that pipes -// connections directly to the in-memory handler. -type Server struct { - listener *listener - turn *turn.Server - - closeMutex sync.Mutex - closed chan (struct{}) - conns chan (net.Conn) -} - -// Accept consumes a new connection into the TURN server. -// A unique remote address must exist per-connection. -// pion/turn indexes allocations based on the address. -func (s *Server) Accept(nc net.Conn, remoteAddress, localAddress *net.TCPAddr) *Conn { - if localAddress == nil { - localAddress = localhost - } - conn := &Conn{ - Conn: nc, - remoteAddress: remoteAddress, - localAddress: localAddress, - closed: make(chan struct{}), - } - s.conns <- conn - return conn -} - -// Close ends the TURN server. 
-func (s *Server) Close() error { - s.closeMutex.Lock() - defer s.closeMutex.Unlock() - if s.isClosed() { - return nil - } - err := s.turn.Close() - close(s.conns) - close(s.closed) - return err -} - -func (s *Server) isClosed() bool { - select { - case <-s.closed: - return true - default: - return false - } -} - -// listener implements net.Listener for the TURN -// server to consume. -type listener struct { - srv *Server -} - -func (l *listener) Accept() (net.Conn, error) { - conn, ok := <-l.srv.conns - if !ok { - return nil, io.EOF - } - return conn, nil -} - -func (*listener) Close() error { - return nil -} - -func (*listener) Addr() net.Addr { - return nil -} - -type Conn struct { - net.Conn - closed chan struct{} - localAddress *net.TCPAddr - remoteAddress *net.TCPAddr -} - -func (c *Conn) LocalAddr() net.Addr { - return c.localAddress -} - -func (c *Conn) RemoteAddr() net.Addr { - return c.remoteAddress -} - -// Closed returns a channel which is closed when -// the connection is. -func (c *Conn) Closed() <-chan struct{} { - return c.closed -} - -func (c *Conn) Close() error { - err := c.Conn.Close() - select { - case <-c.closed: - default: - close(c.closed) - } - return err -} - -type dialer func(network, addr string) (c net.Conn, err error) - -func (d dialer) Dial(network, addr string) (c net.Conn, err error) { - return d(network, addr) -} - -// ProxyDialer accepts a proxy function that's called when the connection -// address matches the reserved host in the "Proxy" ICE server. -// -// This should be passed to WebRTC connections as an ICE dialer. 
-func ProxyDialer(proxyFunc func() (c net.Conn, err error)) proxy.Dialer { - return dialer(func(network, addr string) (net.Conn, error) { - if addr != reservedAddress { - return proxy.Direct.Dial(network, addr) - } - netConn, err := proxyFunc() - if err != nil { - return nil, err - } - return &Conn{ - localAddress: localhost, - closed: make(chan struct{}), - Conn: netConn, - }, nil - }) -} diff --git a/coderd/turnconn/turnconn_test.go b/coderd/turnconn/turnconn_test.go deleted file mode 100644 index 6a8d0411cb7b1..0000000000000 --- a/coderd/turnconn/turnconn_test.go +++ /dev/null @@ -1,107 +0,0 @@ -package turnconn_test - -import ( - "net" - "sync" - "testing" - - "github.com/pion/webrtc/v3" - "github.com/stretchr/testify/require" - "go.uber.org/goleak" - - "cdr.dev/slog" - "cdr.dev/slog/sloggers/slogtest" - "github.com/coder/coder/coderd/turnconn" - "github.com/coder/coder/peer" -) - -func TestMain(m *testing.M) { - goleak.VerifyTestMain(m) -} - -func TestTURNConn(t *testing.T) { - t.Parallel() - turnServer, err := turnconn.New(nil) - require.NoError(t, err) - defer turnServer.Close() - - logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - - clientDialer, clientTURN := net.Pipe() - turnServer.Accept(clientTURN, &net.TCPAddr{ - IP: net.IPv4(127, 0, 0, 1), - Port: 16000, - }, nil) - require.NoError(t, err) - clientSettings := webrtc.SettingEngine{} - clientSettings.SetNetworkTypes([]webrtc.NetworkType{webrtc.NetworkTypeTCP4, webrtc.NetworkTypeTCP6}) - clientSettings.SetRelayAcceptanceMinWait(0) - clientSettings.SetICEProxyDialer(turnconn.ProxyDialer(func() (net.Conn, error) { - return clientDialer, nil - })) - client, err := peer.Client([]webrtc.ICEServer{turnconn.Proxy}, &peer.ConnOptions{ - SettingEngine: clientSettings, - Logger: logger.Named("client"), - }) - require.NoError(t, err) - defer func() { - _ = client.Close() - }() - - serverDialer, serverTURN := net.Pipe() - turnServer.Accept(serverTURN, &net.TCPAddr{ - IP: net.IPv4(127, 0, 0, 1), - Port: 
16001, - }, nil) - require.NoError(t, err) - serverSettings := webrtc.SettingEngine{} - serverSettings.SetNetworkTypes([]webrtc.NetworkType{webrtc.NetworkTypeTCP4, webrtc.NetworkTypeTCP6}) - serverSettings.SetRelayAcceptanceMinWait(0) - serverSettings.SetICEProxyDialer(turnconn.ProxyDialer(func() (net.Conn, error) { - return serverDialer, nil - })) - server, err := peer.Server([]webrtc.ICEServer{turnconn.Proxy}, &peer.ConnOptions{ - SettingEngine: serverSettings, - Logger: logger.Named("server"), - }) - require.NoError(t, err) - defer func() { - _ = server.Close() - }() - exchange(t, client, server) - - _, err = client.Ping() - require.NoError(t, err) -} - -func exchange(t *testing.T, client, server *peer.Conn) { - var wg sync.WaitGroup - wg.Add(2) - t.Cleanup(wg.Wait) - go func() { - defer wg.Done() - for { - select { - case c := <-server.LocalCandidate(): - client.AddRemoteCandidate(c) - case c := <-server.LocalSessionDescription(): - client.SetRemoteSessionDescription(c) - case <-server.Closed(): - return - } - } - }() - go func() { - defer wg.Done() - for { - select { - case c := <-client.LocalCandidate(): - server.AddRemoteCandidate(c) - case c := <-client.LocalSessionDescription(): - server.SetRemoteSessionDescription(c) - case <-client.Closed(): - return - } - } - }() -} diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index fa464e6bbfb0e..7780ae154fe05 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -15,7 +15,6 @@ import ( "time" "github.com/google/uuid" - "github.com/hashicorp/yamux" "go.opentelemetry.io/otel/trace" "golang.org/x/mod/semver" "golang.org/x/xerrors" @@ -30,12 +29,7 @@ import ( "github.com/coder/coder/coderd/httpmw" "github.com/coder/coder/coderd/rbac" "github.com/coder/coder/coderd/tracing" - "github.com/coder/coder/coderd/turnconn" "github.com/coder/coder/codersdk" - "github.com/coder/coder/peer" - "github.com/coder/coder/peerbroker" - "github.com/coder/coder/peerbroker/proto" - 
"github.com/coder/coder/provisionersdk" "github.com/coder/coder/tailnet" ) @@ -66,67 +60,6 @@ func (api *API) workspaceAgent(rw http.ResponseWriter, r *http.Request) { httpapi.Write(rw, http.StatusOK, apiAgent) } -func (api *API) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) { - api.websocketWaitMutex.Lock() - api.websocketWaitGroup.Add(1) - api.websocketWaitMutex.Unlock() - defer api.websocketWaitGroup.Done() - - workspaceAgent := httpmw.WorkspaceAgentParam(r) - workspace := httpmw.WorkspaceParam(r) - if !api.Authorize(r, rbac.ActionCreate, workspace.ExecutionRBAC()) { - httpapi.ResourceNotFound(rw) - return - } - apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Internal error reading workspace agent.", - Detail: err.Error(), - }) - return - } - if apiAgent.Status != codersdk.WorkspaceAgentConnected { - httpapi.Write(rw, http.StatusPreconditionFailed, codersdk.Response{ - Message: fmt.Sprintf("Agent isn't connected! Status: %s.", apiAgent.Status), - }) - return - } - - conn, err := websocket.Accept(rw, r, nil) - if err != nil { - httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{ - Message: "Failed to accept websocket.", - Detail: err.Error(), - }) - return - } - - ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary) - defer wsNetConn.Close() // Also closes conn. 
- - config := yamux.DefaultConfig() - config.LogOutput = io.Discard - session, err := yamux.Server(wsNetConn, config) - if err != nil { - _ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) - return - } - - // end span so we don't get long lived trace data - tracing.EndHTTPSpan(r, http.StatusOK, trace.SpanFromContext(ctx)) - - err = peerbroker.ProxyListen(ctx, session, peerbroker.ProxyOptions{ - ChannelID: workspaceAgent.ID.String(), - Logger: api.Logger.Named("peerbroker-proxy-dial"), - Pubsub: api.Pubsub, - }) - if err != nil { - _ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err)) - return - } -} - func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request) { workspaceAgent := httpmw.WorkspaceAgent(r) apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) @@ -186,231 +119,6 @@ func (api *API) postWorkspaceAgentVersion(rw http.ResponseWriter, r *http.Reques httpapi.Write(rw, http.StatusOK, nil) } -func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) { - api.websocketWaitMutex.Lock() - api.websocketWaitGroup.Add(1) - api.websocketWaitMutex.Unlock() - defer api.websocketWaitGroup.Done() - - workspaceAgent := httpmw.WorkspaceAgent(r) - resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), workspaceAgent.ResourceID) - if err != nil { - httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{ - Message: "Failed to accept websocket.", - Detail: err.Error(), - }) - return - } - - build, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID) - if err != nil { - httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{ - Message: "Internal error fetching workspace build job.", - Detail: err.Error(), - }) - return - } - // Ensure the resource is still valid! - // We only accept agents for resources on the latest build. 
- ensureLatestBuild := func() error { - latestBuild, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(r.Context(), build.WorkspaceID) - if err != nil { - return err - } - if build.ID != latestBuild.ID { - return xerrors.New("build is outdated") - } - return nil - } - - err = ensureLatestBuild() - if err != nil { - api.Logger.Debug(r.Context(), "agent tried to connect from non-latest built", - slog.F("resource", resource), - slog.F("agent", workspaceAgent), - ) - httpapi.Write(rw, http.StatusForbidden, codersdk.Response{ - Message: "Agent trying to connect from non-latest build.", - Detail: err.Error(), - }) - return - } - - conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{ - CompressionMode: websocket.CompressionDisabled, - }) - if err != nil { - httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{ - Message: "Failed to accept websocket.", - Detail: err.Error(), - }) - return - } - - ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary) - defer wsNetConn.Close() // Also closes conn. 
- - config := yamux.DefaultConfig() - config.LogOutput = io.Discard - session, err := yamux.Server(wsNetConn, config) - if err != nil { - _ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) - return - } - - closer, err := peerbroker.ProxyDial(proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(session)), peerbroker.ProxyOptions{ - ChannelID: workspaceAgent.ID.String(), - Pubsub: api.Pubsub, - Logger: api.Logger.Named("peerbroker-proxy-listen"), - }) - if err != nil { - _ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) - return - } - defer closer.Close() - - firstConnectedAt := workspaceAgent.FirstConnectedAt - if !firstConnectedAt.Valid { - firstConnectedAt = sql.NullTime{ - Time: database.Now(), - Valid: true, - } - } - lastConnectedAt := sql.NullTime{ - Time: database.Now(), - Valid: true, - } - disconnectedAt := workspaceAgent.DisconnectedAt - updateConnectionTimes := func() error { - err = api.Database.UpdateWorkspaceAgentConnectionByID(ctx, database.UpdateWorkspaceAgentConnectionByIDParams{ - ID: workspaceAgent.ID, - FirstConnectedAt: firstConnectedAt, - LastConnectedAt: lastConnectedAt, - DisconnectedAt: disconnectedAt, - UpdatedAt: database.Now(), - }) - if err != nil { - return err - } - return nil - } - - defer func() { - disconnectedAt = sql.NullTime{ - Time: database.Now(), - Valid: true, - } - _ = updateConnectionTimes() - }() - - err = updateConnectionTimes() - if err != nil { - _ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) - return - } - - // end span so we don't get long lived trace data - tracing.EndHTTPSpan(r, http.StatusOK, trace.SpanFromContext(ctx)) - - api.Logger.Info(ctx, "accepting agent", slog.F("resource", resource), slog.F("agent", workspaceAgent)) - - ticker := time.NewTicker(api.AgentConnectionUpdateFrequency) - defer ticker.Stop() - for { - select { - case <-session.CloseChan(): - return - case <-ticker.C: - lastConnectedAt = sql.NullTime{ - Time: database.Now(), - Valid: true, - } - err = 
updateConnectionTimes() - if err != nil { - _ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) - return - } - err = ensureLatestBuild() - if err != nil { - // Disconnect agents that are no longer valid. - _ = conn.Close(websocket.StatusGoingAway, "") - return - } - } - } -} - -func (api *API) workspaceAgentICEServers(rw http.ResponseWriter, _ *http.Request) { - httpapi.Write(rw, http.StatusOK, api.ICEServers) -} - -// userWorkspaceAgentTurn is a user connecting to a remote workspace agent -// through turn. -func (api *API) userWorkspaceAgentTurn(rw http.ResponseWriter, r *http.Request) { - workspace := httpmw.WorkspaceParam(r) - if !api.Authorize(r, rbac.ActionCreate, workspace.ExecutionRBAC()) { - httpapi.ResourceNotFound(rw) - return - } - - // Passed authorization - api.workspaceAgentTurn(rw, r) -} - -// workspaceAgentTurn proxies a WebSocket connection to the TURN server. -func (api *API) workspaceAgentTurn(rw http.ResponseWriter, r *http.Request) { - api.websocketWaitMutex.Lock() - api.websocketWaitGroup.Add(1) - api.websocketWaitMutex.Unlock() - defer api.websocketWaitGroup.Done() - - localAddress, _ := r.Context().Value(http.LocalAddrContextKey).(*net.TCPAddr) - remoteAddress := &net.TCPAddr{ - IP: net.ParseIP(r.RemoteAddr), - } - // By default requests have the remote address and port. 
- host, port, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{ - Message: "Invalid remote address.", - Detail: err.Error(), - }) - return - } - remoteAddress.IP = net.ParseIP(host) - remoteAddress.Port, err = strconv.Atoi(port) - if err != nil { - httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{ - Message: fmt.Sprintf("Port for remote address %q must be an integer.", r.RemoteAddr), - Detail: err.Error(), - }) - return - } - - wsConn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{ - CompressionMode: websocket.CompressionDisabled, - }) - if err != nil { - httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{ - Message: "Failed to accept websocket.", - Detail: err.Error(), - }) - return - } - - ctx, wsNetConn := websocketNetConn(r.Context(), wsConn, websocket.MessageBinary) - defer wsNetConn.Close() // Also closes conn. - // end span so we don't get long lived trace data - tracing.EndHTTPSpan(r, http.StatusOK, trace.SpanFromContext(ctx)) - - api.Logger.Debug(ctx, "accepting turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress)) - select { - case <-api.TURNServer.Accept(wsNetConn, remoteAddress, localAddress).Closed(): - case <-ctx.Done(): - } - api.Logger.Debug(ctx, "completed turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress)) -} - // workspaceAgentPTY spawns a PTY and pipes it over a WebSocket. // This is used for the web terminal. func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) { @@ -492,75 +200,7 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) { _, _ = io.Copy(ptNetConn, wsNetConn) } -// dialWorkspaceAgent connects to a workspace agent by ID. Only rely on -// r.Context() for cancellation if it's use is safe or r.Hijack() has -// not been performed. 
-func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (agent.Conn, error) { - client, server := provisionersdk.TransportPipe() - ctx, cancelFunc := context.WithCancel(context.Background()) - go func() { - _ = peerbroker.ProxyListen(ctx, server, peerbroker.ProxyOptions{ - ChannelID: agentID.String(), - Logger: api.Logger.Named("peerbroker-proxy-dial"), - Pubsub: api.Pubsub, - }) - _ = client.Close() - _ = server.Close() - }() - - peerClient := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client)) - stream, err := peerClient.NegotiateConnection(ctx) - if err != nil { - cancelFunc() - return nil, xerrors.Errorf("negotiate: %w", err) - } - options := &peer.ConnOptions{ - Logger: api.Logger.Named("agent-dialer"), - } - options.SettingEngine.SetSrflxAcceptanceMinWait(0) - options.SettingEngine.SetRelayAcceptanceMinWait(0) - // Use the ProxyDialer for the TURN server. - // This is required for connections where P2P is not enabled. - options.SettingEngine.SetICEProxyDialer(turnconn.ProxyDialer(func() (c net.Conn, err error) { - clientPipe, serverPipe := net.Pipe() - go func() { - <-ctx.Done() - _ = clientPipe.Close() - _ = serverPipe.Close() - }() - localAddress, _ := r.Context().Value(http.LocalAddrContextKey).(*net.TCPAddr) - remoteAddress := &net.TCPAddr{ - IP: net.ParseIP(r.RemoteAddr), - } - // By default requests have the remote address and port. 
- host, port, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - return nil, xerrors.Errorf("split remote address: %w", err) - } - remoteAddress.IP = net.ParseIP(host) - remoteAddress.Port, err = strconv.Atoi(port) - if err != nil { - return nil, xerrors.Errorf("convert remote port: %w", err) - } - api.TURNServer.Accept(clientPipe, remoteAddress, localAddress) - return serverPipe, nil - })) - peerConn, err := peerbroker.Dial(stream, append(api.ICEServers, turnconn.Proxy), options) - if err != nil { - cancelFunc() - return nil, xerrors.Errorf("dial: %w", err) - } - go func() { - <-peerConn.Closed() - cancelFunc() - }() - return &agent.WebRTCConn{ - Negotiator: peerClient, - Conn: peerConn, - }, nil -} - -func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (agent.Conn, error) { +func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*agent.Conn, error) { clientConn, serverConn := net.Pipe() go func() { <-r.Context().Done() @@ -587,7 +227,7 @@ func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (a _ = conn.Close() } }() - return &agent.TailnetConn{ + return &agent.Conn{ Conn: conn, }, nil } @@ -609,6 +249,48 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request api.websocketWaitMutex.Unlock() defer api.websocketWaitGroup.Done() workspaceAgent := httpmw.WorkspaceAgent(r) + resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), workspaceAgent.ResourceID) + if err != nil { + httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{ + Message: "Failed to accept websocket.", + Detail: err.Error(), + }) + return + } + + build, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID) + if err != nil { + httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{ + Message: "Internal error fetching workspace build job.", + Detail: err.Error(), + }) + return + } + // Ensure the resource is still valid! 
+ // We only accept agents for resources on the latest build. + ensureLatestBuild := func() error { + latestBuild, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(r.Context(), build.WorkspaceID) + if err != nil { + return err + } + if build.ID != latestBuild.ID { + return xerrors.New("build is outdated") + } + return nil + } + + err = ensureLatestBuild() + if err != nil { + api.Logger.Debug(r.Context(), "agent tried to connect from non-latest built", + slog.F("resource", resource), + slog.F("agent", workspaceAgent), + ) + httpapi.Write(rw, http.StatusForbidden, codersdk.Response{ + Message: "Agent trying to connect from non-latest build.", + Detail: err.Error(), + }) + return + } conn, err := websocket.Accept(rw, r, nil) if err != nil { @@ -618,12 +300,88 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request }) return } - defer conn.Close(websocket.StatusNormalClosure, "") - err = api.TailnetCoordinator.ServeAgent(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), workspaceAgent.ID) + ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary) + defer wsNetConn.Close() + + firstConnectedAt := workspaceAgent.FirstConnectedAt + if !firstConnectedAt.Valid { + firstConnectedAt = sql.NullTime{ + Time: database.Now(), + Valid: true, + } + } + lastConnectedAt := sql.NullTime{ + Time: database.Now(), + Valid: true, + } + disconnectedAt := workspaceAgent.DisconnectedAt + updateConnectionTimes := func() error { + err = api.Database.UpdateWorkspaceAgentConnectionByID(ctx, database.UpdateWorkspaceAgentConnectionByIDParams{ + ID: workspaceAgent.ID, + FirstConnectedAt: firstConnectedAt, + LastConnectedAt: lastConnectedAt, + DisconnectedAt: disconnectedAt, + UpdatedAt: database.Now(), + }) + if err != nil { + return err + } + return nil + } + + defer func() { + disconnectedAt = sql.NullTime{ + Time: database.Now(), + Valid: true, + } + _ = updateConnectionTimes() + }() + + err = updateConnectionTimes() if err 
!= nil { - _ = conn.Close(websocket.StatusInternalError, err.Error()) + _ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) return } + + // end span so we don't get long lived trace data + tracing.EndHTTPSpan(r, http.StatusOK, trace.SpanFromContext(ctx)) + api.Logger.Info(ctx, "accepting agent", slog.F("resource", resource), slog.F("agent", workspaceAgent)) + + defer conn.Close(websocket.StatusNormalClosure, "") + + closeChan := make(chan struct{}) + go func() { + defer close(closeChan) + err := api.TailnetCoordinator.ServeAgent(wsNetConn, workspaceAgent.ID) + if err != nil { + _ = conn.Close(websocket.StatusInternalError, err.Error()) + return + } + }() + ticker := time.NewTicker(api.AgentConnectionUpdateFrequency) + defer ticker.Stop() + for { + select { + case <-closeChan: + return + case <-ticker.C: + } + lastConnectedAt = sql.NullTime{ + Time: database.Now(), + Valid: true, + } + err = updateConnectionTimes() + if err != nil { + _ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) + return + } + err := ensureLatestBuild() + if err != nil { + // Disconnect agents that are no longer valid. + _ = conn.Close(websocket.StatusGoingAway, "") + return + } + } } // workspaceAgentClientCoordinate accepts a WebSocket that reads node network updates. 
diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index 38d27e26e799a..c4514c1134427 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -10,7 +10,6 @@ import ( "time" "github.com/google/uuid" - "github.com/pion/webrtc/v3" "github.com/stretchr/testify/require" "cdr.dev/slog" @@ -18,7 +17,6 @@ import ( "github.com/coder/coder/agent" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/codersdk" - "github.com/coder/coder/peer" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" "github.com/coder/coder/testutil" @@ -112,7 +110,6 @@ func TestWorkspaceAgentListen(t *testing.T) { agentCloser := agent.New(agent.Options{ FetchMetadata: agentClient.WorkspaceAgentMetadata, CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet, - WebRTCDialer: agentClient.ListenWorkspaceAgent, Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug), }) defer func() { @@ -123,13 +120,15 @@ func TestWorkspaceAgentListen(t *testing.T) { defer cancel() resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID) - conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil) + conn, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID) require.NoError(t, err) defer func() { _ = conn.Close() }() - _, err = conn.Ping() - require.NoError(t, err) + require.Eventually(t, func() bool { + _, err := conn.Ping() + return err == nil + }, testutil.WaitMedium, testutil.IntervalFast) }) t.Run("FailNonLatestBuild", func(t *testing.T) { @@ -202,75 +201,12 @@ func TestWorkspaceAgentListen(t *testing.T) { agentClient := codersdk.New(client.URL) agentClient.SessionToken = authToken - _, err = agentClient.ListenWorkspaceAgent(ctx, slogtest.Make(t, nil)) + _, err = agentClient.ListenWorkspaceAgentTailnet(ctx) require.Error(t, err) require.ErrorContains(t, err, "build is outdated") }) } -func TestWorkspaceAgentTURN(t 
*testing.T) { - t.Parallel() - client := coderdtest.New(t, &coderdtest.Options{ - IncludeProvisionerDaemon: true, - }) - - user := coderdtest.CreateFirstUser(t, client) - authToken := uuid.NewString() - version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ - Parse: echo.ParseComplete, - ProvisionDryRun: echo.ProvisionComplete, - Provision: []*proto.Provision_Response{{ - Type: &proto.Provision_Response_Complete{ - Complete: &proto.Provision_Complete{ - Resources: []*proto.Resource{{ - Name: "example", - Type: "aws_instance", - Agents: []*proto.Agent{{ - Id: uuid.NewString(), - Auth: &proto.Agent_Token{ - Token: authToken, - }, - }}, - }}, - }, - }, - }}, - }) - template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) - coderdtest.AwaitTemplateVersionJob(t, client, version.ID) - workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) - coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) - - agentClient := codersdk.New(client.URL) - agentClient.SessionToken = authToken - agentCloser := agent.New(agent.Options{ - FetchMetadata: agentClient.WorkspaceAgentMetadata, - CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet, - WebRTCDialer: agentClient.ListenWorkspaceAgent, - Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug), - }) - defer func() { - _ = agentCloser.Close() - }() - resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID) - - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - opts := &peer.ConnOptions{ - Logger: slogtest.Make(t, nil).Named("client"), - } - // Force a TURN connection! 
- opts.SettingEngine.SetNetworkTypes([]webrtc.NetworkType{webrtc.NetworkTypeTCP4}) - conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, opts) - require.NoError(t, err) - defer func() { - _ = conn.Close() - }() - _, err = conn.Ping() - require.NoError(t, err) -} - func TestWorkspaceAgentTailnet(t *testing.T) { t.Parallel() client, daemonCloser := coderdtest.NewWithProvisionerCloser(t, nil) @@ -306,7 +242,6 @@ func TestWorkspaceAgentTailnet(t *testing.T) { agentClient.SessionToken = authToken agentCloser := agent.New(agent.Options{ FetchMetadata: agentClient.WorkspaceAgentMetadata, - WebRTCDialer: agentClient.ListenWorkspaceAgent, CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet, Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug), }) @@ -373,7 +308,6 @@ func TestWorkspaceAgentPTY(t *testing.T) { agentCloser := agent.New(agent.Options{ FetchMetadata: agentClient.WorkspaceAgentMetadata, CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet, - WebRTCDialer: agentClient.ListenWorkspaceAgent, Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug), }) defer func() { diff --git a/coderd/workspaceapps_test.go b/coderd/workspaceapps_test.go index 11a75d4a967f3..831b3761693df 100644 --- a/coderd/workspaceapps_test.go +++ b/coderd/workspaceapps_test.go @@ -103,7 +103,6 @@ func setupProxyTest(t *testing.T) (*codersdk.Client, uuid.UUID, codersdk.Workspa agentCloser := agent.New(agent.Options{ FetchMetadata: agentClient.WorkspaceAgentMetadata, CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet, - WebRTCDialer: agentClient.ListenWorkspaceAgent, Logger: slogtest.Make(t, nil).Named("agent"), }) t.Cleanup(func() { diff --git a/coderd/wsconncache/wsconncache.go b/coderd/wsconncache/wsconncache.go index 698f467a40790..7d3b741a63b7e 100644 --- a/coderd/wsconncache/wsconncache.go +++ b/coderd/wsconncache/wsconncache.go @@ -32,11 +32,11 @@ func New(dialer Dialer, inactiveTimeout time.Duration) *Cache { } // Dialer 
creates a new agent connection by ID. -type Dialer func(r *http.Request, id uuid.UUID) (agent.Conn, error) +type Dialer func(r *http.Request, id uuid.UUID) (*agent.Conn, error) // Conn wraps an agent connection with a reusable HTTP transport. type Conn struct { - agent.Conn + *agent.Conn locks atomic.Uint64 timeoutMutex sync.Mutex diff --git a/coderd/wsconncache/wsconncache_test.go b/coderd/wsconncache/wsconncache_test.go index 80f187ba15ab7..a9ea85a2492ac 100644 --- a/coderd/wsconncache/wsconncache_test.go +++ b/coderd/wsconncache/wsconncache_test.go @@ -35,7 +35,7 @@ func TestCache(t *testing.T) { t.Parallel() t.Run("Same", func(t *testing.T) { t.Parallel() - cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (agent.Conn, error) { + cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) { return setupAgent(t, agent.Metadata{}, 0), nil }, 0) defer func() { @@ -50,7 +50,7 @@ func TestCache(t *testing.T) { t.Run("Expire", func(t *testing.T) { t.Parallel() called := atomic.NewInt32(0) - cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (agent.Conn, error) { + cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) { called.Add(1) return setupAgent(t, agent.Metadata{}, 0), nil }, time.Microsecond) @@ -69,7 +69,7 @@ func TestCache(t *testing.T) { }) t.Run("NoExpireWhenLocked", func(t *testing.T) { t.Parallel() - cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (agent.Conn, error) { + cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) { return setupAgent(t, agent.Metadata{}, 0), nil }, time.Microsecond) defer func() { @@ -102,7 +102,7 @@ func TestCache(t *testing.T) { }() go server.Serve(random) - cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (agent.Conn, error) { + cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) { return setupAgent(t, agent.Metadata{}, 0), nil }, time.Microsecond) defer func() { @@ -139,7 
+139,7 @@ func TestCache(t *testing.T) { }) } -func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) agent.Conn { +func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) *agent.Conn { metadata.DERPMap = tailnettest.RunDERPAndSTUN(t) coordinator := tailnet.NewCoordinator() @@ -180,7 +180,7 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) return conn.UpdateNodes(node) }) conn.SetNodeCallback(sendNode) - return &agent.TailnetConn{ + return &agent.Conn{ Conn: conn, } } diff --git a/codersdk/provisionerdaemons.go b/codersdk/provisionerdaemons.go index e71fbec7fbee5..296df9b5ac70d 100644 --- a/codersdk/provisionerdaemons.go +++ b/codersdk/provisionerdaemons.go @@ -135,6 +135,7 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after decoder := json.NewDecoder(websocket.NetConn(ctx, conn, websocket.MessageText)) go func() { defer close(logs) + defer conn.Close(websocket.StatusGoingAway, "") var log ProvisionerJobLog for { err = decoder.Decode(&log) diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index 021c4bcf77865..2117de03c6ce3 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -14,9 +14,6 @@ import ( "cloud.google.com/go/compute/metadata" "github.com/google/uuid" - "github.com/hashicorp/yamux" - "github.com/pion/webrtc/v3" - "golang.org/x/net/proxy" "golang.org/x/xerrors" "nhooyr.io/websocket" "nhooyr.io/websocket/wsjson" @@ -25,11 +22,6 @@ import ( "cdr.dev/slog" "github.com/coder/coder/agent" - "github.com/coder/coder/coderd/turnconn" - "github.com/coder/coder/peer" - "github.com/coder/coder/peerbroker" - "github.com/coder/coder/peerbroker/proto" - "github.com/coder/coder/provisionersdk" "github.com/coder/coder/tailnet" "github.com/coder/retry" ) @@ -206,69 +198,6 @@ func (c *Client) WorkspaceAgentMetadata(ctx context.Context) (agent.Metadata, er return agentMetadata, 
json.NewDecoder(res.Body).Decode(&agentMetadata) } -// ListenWorkspaceAgent connects as a workspace agent identifying with the session token. -// On each inbound connection request, connection info is fetched. -func (c *Client) ListenWorkspaceAgent(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error) { - serverURL, err := c.URL.Parse("/api/v2/workspaceagents/me/listen") - if err != nil { - return nil, xerrors.Errorf("parse url: %w", err) - } - jar, err := cookiejar.New(nil) - if err != nil { - return nil, xerrors.Errorf("create cookie jar: %w", err) - } - jar.SetCookies(serverURL, []*http.Cookie{{ - Name: SessionTokenKey, - Value: c.SessionToken, - }}) - httpClient := &http.Client{ - Jar: jar, - } - conn, res, err := websocket.Dial(ctx, serverURL.String(), &websocket.DialOptions{ - HTTPClient: httpClient, - // Need to disable compression to avoid a data-race. - CompressionMode: websocket.CompressionDisabled, - }) - if err != nil { - if res == nil { - return nil, err - } - return nil, readBodyAsError(res) - } - config := yamux.DefaultConfig() - config.LogOutput = io.Discard - session, err := yamux.Client(websocket.NetConn(ctx, conn, websocket.MessageBinary), config) - if err != nil { - return nil, xerrors.Errorf("multiplex client: %w", err) - } - return peerbroker.Listen(session, func(ctx context.Context) ([]webrtc.ICEServer, *peer.ConnOptions, error) { - // This can be cached if it adds to latency too much. 
- res, err := c.Request(ctx, http.MethodGet, "/api/v2/workspaceagents/me/iceservers", nil) - if err != nil { - return nil, nil, err - } - defer res.Body.Close() - if res.StatusCode != http.StatusOK { - return nil, nil, readBodyAsError(res) - } - var iceServers []webrtc.ICEServer - err = json.NewDecoder(res.Body).Decode(&iceServers) - if err != nil { - return nil, nil, err - } - - options := webrtc.SettingEngine{} - options.SetSrflxAcceptanceMinWait(0) - options.SetRelayAcceptanceMinWait(0) - options.SetICEProxyDialer(c.turnProxyDialer(ctx, httpClient, "/api/v2/workspaceagents/me/turn")) - iceServers = append(iceServers, turnconn.Proxy) - return iceServers, &peer.ConnOptions{ - SettingEngine: options, - Logger: logger, - }, nil - }) -} - func (c *Client) ListenWorkspaceAgentTailnet(ctx context.Context) (net.Conn, error) { coordinateURL, err := c.URL.Parse("/api/v2/workspaceagents/me/coordinate") if err != nil { @@ -286,17 +215,20 @@ func (c *Client) ListenWorkspaceAgentTailnet(ctx context.Context) (net.Conn, err Jar: jar, } // nolint:bodyclose - conn, _, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{ + conn, res, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{ HTTPClient: httpClient, }) if err != nil { - return nil, err + if res == nil { + return nil, err + } + return nil, readBodyAsError(res) } return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil } -func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logger, agentID uuid.UUID) (agent.Conn, error) { +func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logger, agentID uuid.UUID) (*agent.Conn, error) { res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/workspaceagents/%s/connection", agentID), nil) if err != nil { return nil, err @@ -349,10 +281,12 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logg CompressionMode: websocket.CompressionDisabled, }) if 
errors.Is(err, context.Canceled) { + _ = ws.Close(websocket.StatusAbnormalClosure, "") return } if err != nil { logger.Debug(ctx, "failed to dial", slog.Error(err)) + _ = ws.Close(websocket.StatusAbnormalClosure, "") continue } sendNode, errChan := tailnet.ServeCoordinator(websocket.NetConn(ctx, ws, websocket.MessageBinary), func(node []*tailnet.Node) error { @@ -362,15 +296,18 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logg logger.Debug(ctx, "serving coordinator") err = <-errChan if errors.Is(err, context.Canceled) { + _ = ws.Close(websocket.StatusAbnormalClosure, "") return } if err != nil { logger.Debug(ctx, "error serving coordinator", slog.Error(err)) + _ = ws.Close(websocket.StatusAbnormalClosure, "") continue } + _ = ws.Close(websocket.StatusAbnormalClosure, "") } }() - return &agent.TailnetConn{ + return &agent.Conn{ Conn: conn, CloseFunc: func() { cancelFunc() @@ -379,78 +316,6 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logg }, nil } -// DialWorkspaceAgent creates a connection to the specified resource. -func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *peer.ConnOptions) (agent.Conn, error) { - serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/workspaceagents/%s/dial", agentID.String())) - if err != nil { - return nil, xerrors.Errorf("parse url: %w", err) - } - jar, err := cookiejar.New(nil) - if err != nil { - return nil, xerrors.Errorf("create cookie jar: %w", err) - } - jar.SetCookies(serverURL, []*http.Cookie{{ - Name: SessionTokenKey, - Value: c.SessionToken, - }}) - httpClient := &http.Client{ - Jar: jar, - } - conn, res, err := websocket.Dial(ctx, serverURL.String(), &websocket.DialOptions{ - HTTPClient: httpClient, - // Need to disable compression to avoid a data-race. 
- CompressionMode: websocket.CompressionDisabled, - }) - if err != nil { - if res == nil { - return nil, err - } - return nil, readBodyAsError(res) - } - config := yamux.DefaultConfig() - config.LogOutput = io.Discard - session, err := yamux.Client(websocket.NetConn(ctx, conn, websocket.MessageBinary), config) - if err != nil { - return nil, xerrors.Errorf("multiplex client: %w", err) - } - client := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(session)) - stream, err := client.NegotiateConnection(ctx) - if err != nil { - return nil, xerrors.Errorf("negotiate connection: %w", err) - } - - res, err = c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/workspaceagents/%s/iceservers", agentID.String()), nil) - if err != nil { - return nil, err - } - defer res.Body.Close() - if res.StatusCode != http.StatusOK { - return nil, readBodyAsError(res) - } - var iceServers []webrtc.ICEServer - err = json.NewDecoder(res.Body).Decode(&iceServers) - if err != nil { - return nil, err - } - - if options == nil { - options = &peer.ConnOptions{} - } - options.SettingEngine.SetSrflxAcceptanceMinWait(0) - options.SettingEngine.SetRelayAcceptanceMinWait(0) - options.SettingEngine.SetICEProxyDialer(c.turnProxyDialer(ctx, httpClient, fmt.Sprintf("/api/v2/workspaceagents/%s/turn", agentID.String()))) - iceServers = append(iceServers, turnconn.Proxy) - - peerConn, err := peerbroker.Dial(stream, iceServers, options) - if err != nil { - return nil, xerrors.Errorf("dial peer: %w", err) - } - return &agent.WebRTCConn{ - Negotiator: client, - Conn: peerConn, - }, nil -} - // WorkspaceAgent returns an agent by ID. 
func (c *Client) WorkspaceAgent(ctx context.Context, id uuid.UUID) (WorkspaceAgent, error) { res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/workspaceagents/%s", id), nil) @@ -509,27 +374,6 @@ func (c *Client) WorkspaceAgentReconnectingPTY(ctx context.Context, agentID, rec return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil } -func (c *Client) turnProxyDialer(ctx context.Context, httpClient *http.Client, path string) proxy.Dialer { - return turnconn.ProxyDialer(func() (net.Conn, error) { - turnURL, err := c.URL.Parse(path) - if err != nil { - return nil, xerrors.Errorf("parse url: %w", err) - } - conn, res, err := websocket.Dial(ctx, turnURL.String(), &websocket.DialOptions{ - HTTPClient: httpClient, - // Need to disable compression to avoid a data-race. - CompressionMode: websocket.CompressionDisabled, - }) - if err != nil { - if res == nil { - return nil, err - } - return nil, readBodyAsError(res) - } - return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil - }) -} - // AgentReportStats begins a stat streaming connection with the Coder server. // It is resilient to network failures and intermittent coderd issues. 
func (c *Client) AgentReportStats( @@ -584,6 +428,7 @@ func (c *Client) AgentReportStats( var req AgentStatsReportRequest err := wsjson.Read(ctx, conn, &req) if err != nil { + _ = conn.Close(websocket.StatusAbnormalClosure, "") return err } @@ -597,6 +442,7 @@ func (c *Client) AgentReportStats( err = wsjson.Write(ctx, conn, resp) if err != nil { + _ = conn.Close(websocket.StatusAbnormalClosure, "") return err } } diff --git a/go.mod b/go.mod index 91cfff2d27514..716f2ef4c9cf2 100644 --- a/go.mod +++ b/go.mod @@ -39,7 +39,7 @@ replace github.com/fatedier/kcp-go => github.com/coder/kcp-go v2.0.4-0.202204091 // https://github.com/pion/udp/pull/73 replace github.com/pion/udp => github.com/mafredri/udp v0.1.2-0.20220805105907-b2872e92e98d -// https://github.com/hashicorp/hc-install/pull/68 +// https://github.com/hashicorp/hc-install/pull/68 replace github.com/hashicorp/hc-install => github.com/mafredri/hc-install v0.4.1-0.20220727132613-e91868e28445 // https://github.com/tcnksm/go-httpstat/pull/29 @@ -119,12 +119,7 @@ require ( github.com/nhatthm/otelsql v0.4.0 github.com/open-policy-agent/opa v0.41.0 github.com/ory/dockertest/v3 v3.9.1 - github.com/pion/datachannel v1.5.2 - github.com/pion/logging v0.2.2 - github.com/pion/transport v0.13.1 - github.com/pion/turn/v2 v2.0.8 github.com/pion/udp v0.1.1 - github.com/pion/webrtc/v3 v3.1.43 github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e github.com/pkg/sftp v1.13.5 @@ -150,7 +145,6 @@ require ( golang.org/x/crypto v0.0.0-20220517005047-85d78b3ac167 golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 - golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b golang.org/x/oauth2 v0.0.0-20220822191816-0ebed06d0094 golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 @@ -171,6 +165,11 @@ require ( tailscale.com v1.30.0 ) +require ( + github.com/pion/transport
v0.13.1 // indirect + golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b // indirect +) + require ( filippo.io/edwards25519 v1.0.0-rc.1 // indirect github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect @@ -258,17 +257,6 @@ require ( github.com/opencontainers/image-spec v1.0.3-0.20220114050600-8b9d41f48198 // indirect github.com/opencontainers/runc v1.1.2 // indirect github.com/pelletier/go-toml/v2 v2.0.2 // indirect - github.com/pion/dtls/v2 v2.1.5 // indirect - github.com/pion/ice/v2 v2.2.6 // indirect - github.com/pion/interceptor v0.1.11 // indirect - github.com/pion/mdns v0.0.5 // indirect - github.com/pion/randutil v0.1.0 // indirect - github.com/pion/rtcp v1.2.9 // indirect - github.com/pion/rtp v1.7.13 // indirect - github.com/pion/sctp v1.8.2 // indirect - github.com/pion/sdp/v3 v3.0.5 // indirect - github.com/pion/srtp/v2 v2.0.10 // indirect - github.com/pion/stun v0.3.5 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.2.0 // indirect diff --git a/go.sum b/go.sum index 155138c4c3d36..9337154197ab1 100644 --- a/go.sum +++ b/go.sum @@ -1451,7 +1451,6 @@ github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108 github.com/onsi/ginkgo v1.13.0/go.mod h1:+REjRxOmWfHCjfv9TTWB1jD1Frx4XydAD3zm1lskyM0= github.com/onsi/ginkgo v1.14.0/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= -github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= github.com/onsi/ginkgo/v2 v2.1.3/go.mod h1:vw5CSIxN1JObi/U8gcbwft7ZxR2dgaR70JSE3/PpL4c= github.com/onsi/gomega v0.0.0-20151007035656-2152b45fa28a/go.mod h1:C1qb7wdrVGGVU+Z6iS04AVkA3Q65CEZX59MT0QO5uiA= github.com/onsi/gomega v0.0.0-20170829124025-dcabb60a477c/go.mod h1:C1qb7wdrVGGVU+Z6iS04AVkA3Q65CEZX59MT0QO5uiA= @@ -1526,43 +1525,9 @@ github.com/phpdave11/gofpdf v1.4.2/go.mod 
h1:zpO6xFn9yxo3YLyMvW8HcKWVdbNqgIfOOp2 github.com/phpdave11/gofpdi v1.0.12/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI= github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pierrec/lz4/v4 v4.1.8/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= -github.com/pion/datachannel v1.5.2 h1:piB93s8LGmbECrpO84DnkIVWasRMk3IimbcXkTQLE6E= -github.com/pion/datachannel v1.5.2/go.mod h1:FTGQWaHrdCwIJ1rw6xBIfZVkslikjShim5yr05XFuCQ= -github.com/pion/dtls/v2 v2.1.3/go.mod h1:o6+WvyLDAlXF7YiPB/RlskRoeK+/JtuaZa5emwQcWus= -github.com/pion/dtls/v2 v2.1.5 h1:jlh2vtIyUBShchoTDqpCCqiYCyRFJ/lvf/gQ8TALs+c= -github.com/pion/dtls/v2 v2.1.5/go.mod h1:BqCE7xPZbPSubGasRoDFJeTsyJtdD1FanJYL0JGheqY= -github.com/pion/ice/v2 v2.2.6 h1:R/vaLlI1J2gCx141L5PEwtuGAGcyS6e7E0hDeJFq5Ig= -github.com/pion/ice/v2 v2.2.6/go.mod h1:SWuHiOGP17lGromHTFadUe1EuPgFh/oCU6FCMZHooVE= -github.com/pion/interceptor v0.1.11 h1:00U6OlqxA3FFB50HSg25J/8cWi7P6FbSzw4eFn24Bvs= -github.com/pion/interceptor v0.1.11/go.mod h1:tbtKjZY14awXd7Bq0mmWvgtHB5MDaRN7HV3OZ/uy7s8= -github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= -github.com/pion/mdns v0.0.5 h1:Q2oj/JB3NqfzY9xGZ1fPzZzK7sDSD8rZPOvcIQ10BCw= -github.com/pion/mdns v0.0.5/go.mod h1:UgssrvdD3mxpi8tMxAXbsppL3vJ4Jipw1mTCW+al01g= -github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= -github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= -github.com/pion/rtcp v1.2.9 h1:1ujStwg++IOLIEoOiIQ2s+qBuJ1VN81KW+9pMPsif+U= -github.com/pion/rtcp v1.2.9/go.mod h1:qVPhiCzAm4D/rxb6XzKeyZiQK69yJpbUDJSF7TgrqNo= -github.com/pion/rtp v1.7.13 h1:qcHwlmtiI50t1XivvoawdCGTP4Uiypzfrsap+bijcoA= -github.com/pion/rtp v1.7.13/go.mod h1:bDb5n+BFZxXx0Ea7E5qe+klMuqiBrP+w8XSjiWtCUko= -github.com/pion/sctp v1.8.0/go.mod h1:xFe9cLMZ5Vj6eOzpyiKjT9SwGM4KpK/8Jbw5//jc+0s= 
-github.com/pion/sctp v1.8.2 h1:yBBCIrUMJ4yFICL3RIvR4eh/H2BTTvlligmSTy+3kiA= -github.com/pion/sctp v1.8.2/go.mod h1:xFe9cLMZ5Vj6eOzpyiKjT9SwGM4KpK/8Jbw5//jc+0s= -github.com/pion/sdp/v3 v3.0.5 h1:ouvI7IgGl+V4CrqskVtr3AaTrPvPisEOxwgpdktctkU= -github.com/pion/sdp/v3 v3.0.5/go.mod h1:iiFWFpQO8Fy3S5ldclBkpXqmWy02ns78NOKoLLL0YQw= -github.com/pion/srtp/v2 v2.0.10 h1:b8ZvEuI+mrL8hbr/f1YiJFB34UMrOac3R3N1yq2UN0w= -github.com/pion/srtp/v2 v2.0.10/go.mod h1:XEeSWaK9PfuMs7zxXyiN252AHPbH12NX5q/CFDWtUuA= -github.com/pion/stun v0.3.5 h1:uLUCBCkQby4S1cf6CGuR9QrVOKcvUwFeemaC865QHDg= -github.com/pion/stun v0.3.5/go.mod h1:gDMim+47EeEtfWogA37n6qXZS88L5V6LqFcf+DZA2UA= -github.com/pion/transport v0.12.2/go.mod h1:N3+vZQD9HlDP5GWkZ85LohxNsDcNgofQmyL6ojX5d8Q= -github.com/pion/transport v0.12.3/go.mod h1:OViWW9SP2peE/HbwBvARicmAVnesphkNkCVZIWJ6q9A= -github.com/pion/transport v0.13.0/go.mod h1:yxm9uXpK9bpBBWkITk13cLo1y5/ur5VQpG22ny6EP7g= github.com/pion/transport v0.13.1 h1:/UH5yLeQtwm2VZIPjxwnNFxjS4DFhyLfS4GlfuKUzfA= github.com/pion/transport v0.13.1/go.mod h1:EBxbqzyv+ZrmDb82XswEE0BjfQFtuw1Nu6sjnjWCsGg= -github.com/pion/turn/v2 v2.0.8 h1:KEstL92OUN3k5k8qxsXHpr7WWfrdp7iJZHx99ud8muw= -github.com/pion/turn/v2 v2.0.8/go.mod h1:+y7xl719J8bAEVpSXBXvTxStjJv3hbz9YFflvkpcGPw= -github.com/pion/webrtc/v3 v3.1.43 h1:YT3ZTO94UT4kSBvZnRAH82+0jJPUruiKr9CEstdlQzk= -github.com/pion/webrtc/v3 v3.1.43/go.mod h1:G/J8k0+grVsjC/rjCZ24AKoCCxcFFODgh7zThNZGs0M= github.com/pkg/browser v0.0.0-20210706143420-7d21f8c997e2/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI= github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU= github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI= @@ -2042,9 +2007,6 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= 
golang.org/x/crypto v0.0.0-20211117183948-ae814b36b871/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.0.0-20220131195533-30dcbda58838/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.0.0-20220516162934-403b01795ae8/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220517005047-85d78b3ac167 h1:O8uGbHCqlTp2P6QJSLmCojM4mN6UemYv8K+dCnmHmu0= golang.org/x/crypto v0.0.0-20220517005047-85d78b3ac167/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -2155,7 +2117,6 @@ golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201031054903-ff519b6c9102/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.0.0-20201201195509-5d6afe98e0b7/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201209123823-ac852fbbde11/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= @@ -2178,7 +2139,6 @@ golang.org/x/net v0.0.0-20210825183410-e898025ed96a/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20210928044308-7d9f5e0b762b/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod 
h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20211201190559-0a0e4e1bb54c/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211209124913-491a49abca63/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211216030914-fe4d6282115f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220107192237-5cfca573fb4d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= @@ -2186,13 +2146,11 @@ golang.org/x/net v0.0.0-20220111093109-d55c255bac03/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220325170049-de3da57026de/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= -golang.org/x/net v0.0.0-20220401154927-543a649e0bdd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220412020605-290c469a71a5/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220531201128-c960675eff93/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20220607020251-c690dde0001d/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.0.0-20220630215102-69896b714898/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b h1:ZmngSVLe/wycRns9MKikG9OWIEjGcGAkacif7oYQaUY= golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/oauth2 
v0.0.0-20180227000427-d7d64896b5ff/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -2389,7 +2347,6 @@ golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220608164250-635b8c9b7f68/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220610221304-9f5ed59c137d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220624220833-87e55d714810/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdpZ4kha7wBWZrbVKCIZg= golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/peer/channel.go b/peer/channel.go deleted file mode 100644 index 4945e5d465683..0000000000000 --- a/peer/channel.go +++ /dev/null @@ -1,317 +0,0 @@ -package peer - -import ( - "bufio" - "context" - "io" - "net" - "sync" - - "github.com/pion/datachannel" - "github.com/pion/webrtc/v3" - "golang.org/x/xerrors" - - "cdr.dev/slog" -) - -const ( - bufferedAmountLowThreshold uint64 = 512 * 1024 // 512 KB - maxBufferedAmount uint64 = 1024 * 1024 // 1 MB - // For some reason messages larger just don't work... - // This shouldn't be a huge deal for real-world usage. - // See: https://github.com/pion/datachannel/issues/59 - maxMessageLength = 64 * 1024 // 64 KB -) - -// newChannel creates a new channel and initializes it. -// The initialization overrides listener handles, and detaches -// the channel on open. The datachannel should not be manually -// mutated after being passed to this function. 
-func newChannel(conn *Conn, dc *webrtc.DataChannel, opts *ChannelOptions) *Channel { - channel := &Channel{ - opts: opts, - conn: conn, - dc: dc, - - opened: make(chan struct{}), - closed: make(chan struct{}), - sendMore: make(chan struct{}, 1), - } - channel.init() - return channel -} - -type ChannelOptions struct { - // ID is a channel ID that should be used when `Negotiated` - // is true. - ID uint16 - - // Negotiated returns whether the data channel will already - // be active on the other end. Defaults to false. - Negotiated bool - - // Arbitrary string that can be parsed on `Accept`. - Protocol string - - // Unordered determines whether the channel acts like - // a UDP connection. Defaults to false. - Unordered bool - - // Whether the channel will be left open on disconnect or not. - // If true, data will be buffered on either end to be sent - // once reconnected. Defaults to false. - OpenOnDisconnect bool -} - -// Channel represents a WebRTC DataChannel. -// -// This struct wraps webrtc.DataChannel to add concurrent-safe usage, -// data bufferring, and standardized errors for connection state. -// -// It modifies the default behavior of a DataChannel by closing on -// WebRTC PeerConnection failure. This is done to emulate TCP connections. -// This option can be changed in the options when creating a Channel. -type Channel struct { - opts *ChannelOptions - - conn *Conn - dc *webrtc.DataChannel - // This field can be nil. It becomes set after the DataChannel - // has been opened and is detached. - rwc datachannel.ReadWriteCloser - reader io.Reader - - closed chan struct{} - closeMutex sync.Mutex - closeError error - - opened chan struct{} - - // sendMore is used to block Write operations on a full buffer. - // It's signaled when the buffer can accept more data. - sendMore chan struct{} - writeMutex sync.Mutex -} - -// init attaches listeners to the DataChannel to detect opening, -// closing, and when the channel is ready to transmit data. 
-// -// This should only be called once on creation. -func (c *Channel) init() { - // WebRTC connections maintain an internal buffer that can fill when: - // 1. Data is being sent faster than it can flush. - // 2. The connection is disconnected, but data is still being sent. - // - // This applies a maximum in-memory buffer for data, and will cause - // write operations to block once the threshold is set. - c.dc.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold) - c.dc.OnBufferedAmountLow(func() { - // Grab the lock to protect the sendMore channel from being - // closed in between the isClosed check and the send. - c.closeMutex.Lock() - defer c.closeMutex.Unlock() - if c.isClosed() { - return - } - select { - case <-c.closed: - case c.sendMore <- struct{}{}: - default: - } - }) - c.dc.OnClose(func() { - c.conn.logger().Debug(context.Background(), "datachannel closing from OnClose", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label())) - _ = c.closeWithError(ErrClosed) - }) - c.dc.OnOpen(func() { - c.closeMutex.Lock() - c.conn.logger().Debug(context.Background(), "datachannel opening", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label())) - var err error - c.rwc, err = c.dc.Detach() - if err != nil { - c.closeMutex.Unlock() - _ = c.closeWithError(xerrors.Errorf("detach: %w", err)) - return - } - c.closeMutex.Unlock() - - // pion/webrtc will return an io.ErrShortBuffer when a read - // is triggered with a buffer size less than the chunks written. - // - // This makes sense when considering UDP connections, because - // buffering of data that has no transmit guarantees is likely - // to cause unexpected behavior. - // - // When ordered, this adds a bufio.Reader. This ensures additional - // data on TCP-like connections can be read in parts, while still - // being buffered. - if c.opts.Unordered { - c.reader = c.rwc - } else { - // This must be the max message length otherwise a short - // buffer error can occur. 
- c.reader = bufio.NewReaderSize(c.rwc, maxMessageLength) - } - close(c.opened) - }) - - c.conn.dcDisconnectListeners.Add(1) - c.conn.dcFailedListeners.Add(1) - c.conn.dcClosedWaitGroup.Add(1) - go func() { - var err error - // A DataChannel can disconnect multiple times, so this needs to loop. - for { - select { - case <-c.conn.closedRTC: - // If this channel was closed, there's no need to close again. - err = c.conn.closeError - case <-c.conn.Closed(): - // If the RTC connection closed with an error, this channel - // should end with the same one. - err = c.conn.closeError - case <-c.conn.dcDisconnectChannel: - // If the RTC connection is disconnected, we need to check if - // the DataChannel is supposed to end on disconnect. - if c.opts.OpenOnDisconnect { - continue - } - err = xerrors.Errorf("rtc disconnected. closing: %w", ErrClosed) - case <-c.conn.dcFailedChannel: - // If the RTC connection failed, close the Channel. - err = ErrFailed - } - if err != nil { - break - } - } - _ = c.closeWithError(err) - }() -} - -// Read blocks until data is received. -// -// This will block until the underlying DataChannel has been opened. -func (c *Channel) Read(bytes []byte) (int, error) { - err := c.waitOpened() - if err != nil { - return 0, err - } - - bytesRead, err := c.reader.Read(bytes) - if err != nil { - if c.isClosed() { - return 0, c.closeError - } - // An EOF always occurs when the connection is closed. - // Alternative close errors will occur first if an unexpected - // close has occurred. - if xerrors.Is(err, io.EOF) { - err = c.closeWithError(ErrClosed) - } - } - return bytesRead, err -} - -// Write sends data to the underlying DataChannel. -// -// This function will block if too much data is being sent. -// Data will buffer if the connection is temporarily disconnected, -// and will be flushed upon reconnection. -// -// If the Channel is setup to close on disconnect, any buffered -// data will be lost. 
-func (c *Channel) Write(bytes []byte) (n int, err error) { - if len(bytes) > maxMessageLength { - return 0, xerrors.Errorf("outbound packet larger than maximum message size: %d", maxMessageLength) - } - - c.writeMutex.Lock() - defer c.writeMutex.Unlock() - - err = c.waitOpened() - if err != nil { - return 0, err - } - if c.dc.BufferedAmount()+uint64(len(bytes)) >= maxBufferedAmount { - <-c.sendMore - } - - return c.rwc.Write(bytes) -} - -// Close gracefully closes the DataChannel. -func (c *Channel) Close() error { - return c.closeWithError(nil) -} - -// Label returns the label of the underlying DataChannel. -func (c *Channel) Label() string { - return c.dc.Label() -} - -// Protocol returns the protocol of the underlying DataChannel. -func (c *Channel) Protocol() string { - return c.dc.Protocol() -} - -// NetConn wraps the DataChannel in a struct fulfilling net.Conn. -// Read, Write, and Close operations can still be used on the *Channel struct. -func (c *Channel) NetConn() net.Conn { - return &fakeNetConn{ - c: c, - addr: &peerAddr{}, - } -} - -// closeWithError closes the Channel with the error provided. -// If a graceful close occurs, the error will be nil. 
-func (c *Channel) closeWithError(err error) error { - c.closeMutex.Lock() - defer c.closeMutex.Unlock() - - if c.isClosed() { - return c.closeError - } - - c.conn.logger().Debug(context.Background(), "datachannel closing with error", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label()), slog.Error(err)) - if err == nil { - c.closeError = ErrClosed - } else { - c.closeError = err - } - if c.rwc != nil { - _ = c.rwc.Close() - } - _ = c.dc.Close() - - close(c.closed) - close(c.sendMore) - c.conn.dcDisconnectListeners.Sub(1) - c.conn.dcFailedListeners.Sub(1) - c.conn.dcClosedWaitGroup.Done() - - return err -} - -func (c *Channel) isClosed() bool { - select { - case <-c.closed: - return true - default: - return false - } -} - -func (c *Channel) waitOpened() error { - select { - case <-c.opened: - // Re-check the closed channel to prioritize closure. - if c.isClosed() { - return c.closeError - } - return nil - case <-c.closed: - return c.closeError - } -} diff --git a/peer/conn.go b/peer/conn.go deleted file mode 100644 index 2e67b500ee5fd..0000000000000 --- a/peer/conn.go +++ /dev/null @@ -1,616 +0,0 @@ -package peer - -import ( - "bytes" - "context" - "crypto/rand" - "io" - "sync" - "time" - - "github.com/pion/logging" - "github.com/pion/webrtc/v3" - "go.uber.org/atomic" - "golang.org/x/xerrors" - - "cdr.dev/slog" -) - -var ( - // ErrDisconnected occurs when the connection has disconnected. - // The connection will be attempting to reconnect at this point. - ErrDisconnected = xerrors.New("connection is disconnected") - // ErrFailed occurs when the connection has failed. - // The connection will not retry after this point. - ErrFailed = xerrors.New("connection has failed") - // ErrClosed occurs when the connection was closed. It wraps io.EOF - // to fulfill expected read errors from closed pipes. - ErrClosed = xerrors.Errorf("connection was closed: %w", io.EOF) - - // The amount of random bytes sent in a ping. 
- pingDataLength = 64 -) - -// Client creates a new client connection. -func Client(servers []webrtc.ICEServer, opts *ConnOptions) (*Conn, error) { - return newWithClientOrServer(servers, true, opts) -} - -// Server creates a new server connection. -func Server(servers []webrtc.ICEServer, opts *ConnOptions) (*Conn, error) { - return newWithClientOrServer(servers, false, opts) -} - -// newWithClientOrServer constructs a new connection with the client option. -// nolint:revive -func newWithClientOrServer(servers []webrtc.ICEServer, client bool, opts *ConnOptions) (*Conn, error) { - if opts == nil { - opts = &ConnOptions{} - } - - opts.SettingEngine.DetachDataChannels() - logger := logging.NewDefaultLoggerFactory() - logger.DefaultLogLevel = logging.LogLevelDisabled - opts.SettingEngine.LoggerFactory = logger - api := webrtc.NewAPI(webrtc.WithSettingEngine(opts.SettingEngine)) - rtc, err := api.NewPeerConnection(webrtc.Configuration{ - ICEServers: servers, - }) - if err != nil { - return nil, xerrors.Errorf("create peer connection: %w", err) - } - conn := &Conn{ - pingChannelID: 1, - pingEchoChannelID: 2, - rtc: rtc, - offerer: client, - closed: make(chan struct{}), - closedRTC: make(chan struct{}), - closedICE: make(chan struct{}), - dcOpenChannel: make(chan *webrtc.DataChannel, 8), - dcDisconnectChannel: make(chan struct{}), - dcFailedChannel: make(chan struct{}), - localCandidateChannel: make(chan webrtc.ICECandidateInit), - localSessionDescriptionChannel: make(chan webrtc.SessionDescription, 1), - negotiated: make(chan struct{}), - remoteSessionDescriptionChannel: make(chan webrtc.SessionDescription, 1), - settingEngine: opts.SettingEngine, - } - conn.loggerValue.Store(opts.Logger) - if client { - // If we're the client, we want to flip the echo and - // ping channel IDs so pings don't accidentally hit each other. 
- conn.pingChannelID, conn.pingEchoChannelID = conn.pingEchoChannelID, conn.pingChannelID - } - err = conn.init() - if err != nil { - return nil, xerrors.Errorf("init: %w", err) - } - return conn, nil -} - -type ConnOptions struct { - Logger slog.Logger - - // Enables customization on the underlying WebRTC connection. - SettingEngine webrtc.SettingEngine -} - -// Conn represents a WebRTC peer connection. -// -// This struct wraps webrtc.PeerConnection to add bidirectional pings, -// concurrent-safe webrtc.DataChannel, and standardized errors for connection state. -type Conn struct { - rtc *webrtc.PeerConnection - // Determines whether this connection will send the offer or the answer. - offerer bool - - closed chan struct{} - closedRTC chan struct{} - closedRTCMutex sync.Mutex - closedICE chan struct{} - closedICEMutex sync.Mutex - closeMutex sync.Mutex - closeError error - - dcCreateMutex sync.Mutex - dcOpenChannel chan *webrtc.DataChannel - dcDisconnectChannel chan struct{} - dcDisconnectListeners atomic.Uint32 - dcFailedChannel chan struct{} - dcFailedListeners atomic.Uint32 - dcClosedWaitGroup sync.WaitGroup - - localCandidateChannel chan webrtc.ICECandidateInit - localSessionDescriptionChannel chan webrtc.SessionDescription - remoteSessionDescriptionChannel chan webrtc.SessionDescription - - negotiated chan struct{} - - loggerValue atomic.Value - settingEngine webrtc.SettingEngine - - pingChannelID uint16 - pingEchoChannelID uint16 - - pingEchoChan *Channel - pingEchoOnce sync.Once - pingEchoError error - pingMutex sync.Mutex - pingOnce sync.Once - pingChan *Channel - pingError error -} - -func (c *Conn) logger() slog.Logger { - log, valid := c.loggerValue.Load().(slog.Logger) - if !valid { - return slog.Logger{} - } - - return log -} - -func (c *Conn) init() error { - c.rtc.OnNegotiationNeeded(c.negotiate) - c.rtc.OnICEConnectionStateChange(func(iceConnectionState webrtc.ICEConnectionState) { - c.closedICEMutex.Lock() - defer c.closedICEMutex.Unlock() - 
select { - case <-c.closedICE: - // Don't log more state changes if we've already closed. - return - default: - c.logger().Debug(context.Background(), "ice connection state updated", - slog.F("state", iceConnectionState)) - - if iceConnectionState == webrtc.ICEConnectionStateClosed { - // pion/webrtc can update this state multiple times. - // A connection can never become un-closed, so we - // close the channel if it isn't already. - close(c.closedICE) - } - } - }) - c.rtc.OnICEGatheringStateChange(func(iceGatherState webrtc.ICEGathererState) { - c.closedICEMutex.Lock() - defer c.closedICEMutex.Unlock() - select { - case <-c.closedICE: - // Don't log more state changes if we've already closed. - return - default: - c.logger().Debug(context.Background(), "ice gathering state updated", - slog.F("state", iceGatherState)) - - if iceGatherState == webrtc.ICEGathererStateClosed { - // pion/webrtc can update this state multiple times. - // A connection can never become un-closed, so we - // close the channel if it isn't already. - close(c.closedICE) - } - } - }) - c.rtc.OnConnectionStateChange(func(peerConnectionState webrtc.PeerConnectionState) { - go func() { - c.closeMutex.Lock() - defer c.closeMutex.Unlock() - if c.isClosed() { - return - } - c.logger().Debug(context.Background(), "rtc connection updated", - slog.F("state", peerConnectionState)) - }() - - switch peerConnectionState { - case webrtc.PeerConnectionStateDisconnected: - for i := 0; i < int(c.dcDisconnectListeners.Load()); i++ { - select { - case c.dcDisconnectChannel <- struct{}{}: - default: - } - } - case webrtc.PeerConnectionStateFailed: - for i := 0; i < int(c.dcFailedListeners.Load()); i++ { - select { - case c.dcFailedChannel <- struct{}{}: - default: - } - } - case webrtc.PeerConnectionStateClosed: - // pion/webrtc can update this state multiple times. - // A connection can never become un-closed, so we - // close the channel if it isn't already. 
- c.closedRTCMutex.Lock() - defer c.closedRTCMutex.Unlock() - select { - case <-c.closedRTC: - default: - close(c.closedRTC) - } - } - }) - - // These functions need to check if the conn is closed, because they can be - // called after being closed. - c.rtc.OnSignalingStateChange(func(signalState webrtc.SignalingState) { - c.logger().Debug(context.Background(), "signaling state updated", - slog.F("state", signalState)) - }) - c.rtc.SCTP().Transport().OnStateChange(func(dtlsTransportState webrtc.DTLSTransportState) { - c.logger().Debug(context.Background(), "dtls transport state updated", - slog.F("state", dtlsTransportState)) - }) - c.rtc.SCTP().Transport().ICETransport().OnSelectedCandidatePairChange(func(candidatePair *webrtc.ICECandidatePair) { - c.logger().Debug(context.Background(), "selected candidate pair changed", - slog.F("local", candidatePair.Local), slog.F("remote", candidatePair.Remote)) - }) - c.rtc.OnICECandidate(func(iceCandidate *webrtc.ICECandidate) { - if iceCandidate == nil { - return - } - // Run this in a goroutine so we don't block pion/webrtc - // from continuing. - go func() { - c.logger().Debug(context.Background(), "sending local candidate", slog.F("candidate", iceCandidate.ToJSON().Candidate)) - select { - case <-c.closed: - case c.localCandidateChannel <- iceCandidate.ToJSON(): - } - }() - }) - c.rtc.OnDataChannel(func(dc *webrtc.DataChannel) { - go func() { - select { - case <-c.closed: - case c.dcOpenChannel <- dc: - } - }() - }) - _, err := c.pingChannel() - if err != nil { - return err - } - _, err = c.pingEchoChannel() - if err != nil { - return err - } - - return nil -} - -// negotiate is triggered when a connection is ready to be established. -// See trickle ICE for the expected exchange: https://webrtchacks.com/trickle-ice/ -func (c *Conn) negotiate() { - c.logger().Debug(context.Background(), "negotiating") - // ICE candidates cannot be added until SessionDescriptions have been - // exchanged between peers. 
- defer func() { - select { - case <-c.negotiated: - default: - close(c.negotiated) - } - }() - - if c.offerer { - offer, err := c.rtc.CreateOffer(&webrtc.OfferOptions{}) - if err != nil { - _ = c.CloseWithError(xerrors.Errorf("create offer: %w", err)) - return - } - // pion/webrtc will panic if Close is called while this - // function is being executed. - c.closeMutex.Lock() - err = c.rtc.SetLocalDescription(offer) - c.closeMutex.Unlock() - if err != nil { - _ = c.CloseWithError(xerrors.Errorf("set local description: %w", err)) - return - } - c.logger().Debug(context.Background(), "sending offer", slog.F("offer", offer)) - select { - case <-c.closed: - return - case c.localSessionDescriptionChannel <- offer: - } - c.logger().Debug(context.Background(), "sent offer") - } - - var sessionDescription webrtc.SessionDescription - c.logger().Debug(context.Background(), "awaiting remote description...") - select { - case <-c.closed: - return - case sessionDescription = <-c.remoteSessionDescriptionChannel: - } - c.logger().Debug(context.Background(), "setting remote description") - - err := c.rtc.SetRemoteDescription(sessionDescription) - if err != nil { - _ = c.CloseWithError(xerrors.Errorf("set remote description (closed %v): %w", c.isClosed(), err)) - return - } - - if !c.offerer { - answer, err := c.rtc.CreateAnswer(&webrtc.AnswerOptions{}) - if err != nil { - _ = c.CloseWithError(xerrors.Errorf("create answer: %w", err)) - return - } - // pion/webrtc will panic if Close is called while this - // function is being executed. 
- c.closeMutex.Lock() - err = c.rtc.SetLocalDescription(answer) - c.closeMutex.Unlock() - if err != nil { - _ = c.CloseWithError(xerrors.Errorf("set local description: %w", err)) - return - } - c.logger().Debug(context.Background(), "sending answer", slog.F("answer", answer)) - select { - case <-c.closed: - return - case c.localSessionDescriptionChannel <- answer: - } - c.logger().Debug(context.Background(), "sent answer") - } -} - -// AddRemoteCandidate adds a remote candidate to the RTC connection. -func (c *Conn) AddRemoteCandidate(i webrtc.ICECandidateInit) { - if c.isClosed() { - return - } - // This must occur in a goroutine to allow the SessionDescriptions - // to be exchanged first. - go func() { - select { - case <-c.closed: - case <-c.negotiated: - } - if c.isClosed() { - return - } - c.logger().Debug(context.Background(), "accepting candidate", slog.F("candidate", i.Candidate)) - err := c.rtc.AddICECandidate(i) - if err != nil { - if c.rtc.ConnectionState() == webrtc.PeerConnectionStateClosed { - return - } - _ = c.CloseWithError(xerrors.Errorf("accept candidate: %w", err)) - } - }() -} - -// SetRemoteSessionDescription sets the remote description for the WebRTC connection. -func (c *Conn) SetRemoteSessionDescription(sessionDescription webrtc.SessionDescription) { - select { - case <-c.closed: - case c.remoteSessionDescriptionChannel <- sessionDescription: - } -} - -// LocalSessionDescription returns a channel that emits a session description -// when one is required to be exchanged. -func (c *Conn) LocalSessionDescription() <-chan webrtc.SessionDescription { - return c.localSessionDescriptionChannel -} - -// LocalCandidate returns a channel that emits when a local candidate -// needs to be exchanged with a remote connection. 
-func (c *Conn) LocalCandidate() <-chan webrtc.ICECandidateInit { - return c.localCandidateChannel -} - -func (c *Conn) pingChannel() (*Channel, error) { - c.pingOnce.Do(func() { - c.pingChan, c.pingError = c.dialChannel(context.Background(), "ping", &ChannelOptions{ - ID: c.pingChannelID, - Negotiated: true, - OpenOnDisconnect: true, - }) - if c.pingError != nil { - return - } - }) - return c.pingChan, c.pingError -} - -func (c *Conn) pingEchoChannel() (*Channel, error) { - c.pingEchoOnce.Do(func() { - c.pingEchoChan, c.pingEchoError = c.dialChannel(context.Background(), "echo", &ChannelOptions{ - ID: c.pingEchoChannelID, - Negotiated: true, - OpenOnDisconnect: true, - }) - if c.pingEchoError != nil { - return - } - go func() { - for { - data := make([]byte, pingDataLength) - bytesRead, err := c.pingEchoChan.Read(data) - if err != nil { - _ = c.CloseWithError(xerrors.Errorf("read ping echo channel: %w", err)) - return - } - _, err = c.pingEchoChan.Write(data[:bytesRead]) - if err != nil { - _ = c.CloseWithError(xerrors.Errorf("write ping echo channel: %w", err)) - return - } - } - }() - }) - return c.pingEchoChan, c.pingEchoError -} - -// SetConfiguration applies options to the WebRTC connection. -// Generally used for updating transport options, like ICE servers. -func (c *Conn) SetConfiguration(configuration webrtc.Configuration) error { - return c.rtc.SetConfiguration(configuration) -} - -// Accept blocks waiting for a channel to be opened. -func (c *Conn) Accept(ctx context.Context) (*Channel, error) { - var dataChannel *webrtc.DataChannel - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-c.closed: - return nil, c.closeError - case dataChannel = <-c.dcOpenChannel: - } - - return newChannel(c, dataChannel, &ChannelOptions{}), nil -} - -// CreateChannel creates a new DataChannel. 
-func (c *Conn) CreateChannel(ctx context.Context, label string, opts *ChannelOptions) (*Channel, error) { - if opts == nil { - opts = &ChannelOptions{} - } - if opts.ID == c.pingChannelID || opts.ID == c.pingEchoChannelID { - return nil, xerrors.Errorf("datachannel id %d and %d are reserved for ping", c.pingChannelID, c.pingEchoChannelID) - } - return c.dialChannel(ctx, label, opts) -} - -func (c *Conn) dialChannel(ctx context.Context, label string, opts *ChannelOptions) (*Channel, error) { - // pion/webrtc is slower when opening multiple channels - // in parallel than it is sequentially. - c.dcCreateMutex.Lock() - defer c.dcCreateMutex.Unlock() - - c.logger().Debug(ctx, "creating data channel", slog.F("label", label), slog.F("opts", opts)) - var id *uint16 - if opts.ID != 0 { - id = &opts.ID - } - ordered := true - if opts.Unordered { - ordered = false - } - if opts.OpenOnDisconnect && !opts.Negotiated { - return nil, xerrors.New("OpenOnDisconnect is only allowed for Negotiated channels") - } - if c.isClosed() { - return nil, xerrors.Errorf("closed: %w", c.closeError) - } - - dataChannel, err := c.rtc.CreateDataChannel(label, &webrtc.DataChannelInit{ - ID: id, - Negotiated: &opts.Negotiated, - Ordered: &ordered, - Protocol: &opts.Protocol, - }) - if err != nil { - return nil, xerrors.Errorf("create data channel: %w", err) - } - return newChannel(c, dataChannel, opts), nil -} - -// Ping returns the duration it took to round-trip data. -// Multiple pings cannot occur at the same time, so this function will block. -func (c *Conn) Ping() (time.Duration, error) { - // Pings are not async, so we need a mutex. 
- c.pingMutex.Lock() - defer c.pingMutex.Unlock() - - ping, err := c.pingChannel() - if err != nil { - return 0, xerrors.Errorf("get ping channel: %w", err) - } - pingDataSent := make([]byte, pingDataLength) - _, err = rand.Read(pingDataSent) - if err != nil { - return 0, xerrors.Errorf("read random ping data: %w", err) - } - start := time.Now() - _, err = ping.Write(pingDataSent) - if err != nil { - return 0, xerrors.Errorf("send ping: %w", err) - } - c.logger().Debug(context.Background(), "wrote ping", - slog.F("connection_state", c.rtc.ConnectionState())) - - pingDataReceived := make([]byte, pingDataLength) - _, err = ping.Read(pingDataReceived) - if err != nil { - return 0, xerrors.Errorf("read ping: %w", err) - } - end := time.Now() - if !bytes.Equal(pingDataSent, pingDataReceived) { - return 0, xerrors.Errorf("ping data inconsistency sent != received") - } - return end.Sub(start), nil -} - -func (c *Conn) Closed() <-chan struct{} { - return c.closed -} - -// Close closes the connection and frees all associated resources. -func (c *Conn) Close() error { - return c.CloseWithError(nil) -} - -func (c *Conn) isClosed() bool { - select { - case <-c.closed: - return true - default: - return false - } -} - -// CloseWithError closes the connection; subsequent reads/writes will return the error err. -func (c *Conn) CloseWithError(err error) error { - c.closeMutex.Lock() - defer c.closeMutex.Unlock() - - if c.isClosed() { - return c.closeError - } - - logger := c.logger() - - logger.Debug(context.Background(), "closing conn with error", slog.Error(err)) - if err == nil { - c.closeError = ErrClosed - } else { - c.closeError = err - } - - if ch, _ := c.pingChannel(); ch != nil { - _ = ch.closeWithError(c.closeError) - } - // If the WebRTC connection has already been closed (due to failure or disconnect), - // this call will return an error that isn't typed. We don't check the error because - // closing an already closed connection isn't an issue for us. 
- _ = c.rtc.Close() - - // Waiting for pion/webrtc to report closed state on both of these - // ensures no goroutine leaks. - if c.rtc.ConnectionState() != webrtc.PeerConnectionStateNew { - logger.Debug(context.Background(), "waiting for rtc connection close...") - <-c.closedRTC - } - if c.rtc.ICEConnectionState() != webrtc.ICEConnectionStateNew { - logger.Debug(context.Background(), "waiting for ice connection close...") - <-c.closedICE - } - - // Waits for all DataChannels to exit before officially labeling as closed. - // All logging, goroutines, and async functionality is cleaned up after this. - c.dcClosedWaitGroup.Wait() - - // Disable logging! - c.loggerValue.Store(slog.Logger{}) - logger.Sync() - - logger.Debug(context.Background(), "closed") - close(c.closed) - return err -} diff --git a/peer/conn_test.go b/peer/conn_test.go deleted file mode 100644 index 992765b940c74..0000000000000 --- a/peer/conn_test.go +++ /dev/null @@ -1,434 +0,0 @@ -package peer_test - -import ( - "context" - "io" - "net" - "net/http" - "os" - "sync" - "testing" - "time" - - "github.com/pion/logging" - "github.com/pion/transport/vnet" - "github.com/pion/webrtc/v3" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.uber.org/goleak" - "golang.org/x/xerrors" - - "cdr.dev/slog" - "cdr.dev/slog/sloggers/slogtest" - "github.com/coder/coder/peer" - "github.com/coder/coder/testutil" -) - -var ( - disconnectedTimeout = func() time.Duration { - // Connection state is unfortunately time-based. When resources are - // contended, a connection can take greater than this timeout to - // handshake, which results in a test flake. - // - // During local testing resources are rarely contended. Reducing this - // timeout leads to faster local development. - // - // In CI resources are frequently contended, so increasing this value - // results in less flakes. 
- if os.Getenv("CI") == "true" { - return time.Second - } - return 100 * time.Millisecond - }() - failedTimeout = disconnectedTimeout * 3 - keepAliveInterval = time.Millisecond * 2 - - // There's a global race in the vnet library allocation code. - // This mutex locks around the creation of the vnet. - vnetMutex = sync.Mutex{} -) - -func TestMain(m *testing.M) { - // pion/ice doesn't properly close immediately. The solution for this isn't yet known. See: - // https://github.com/pion/ice/pull/413 - goleak.VerifyTestMain(m, - goleak.IgnoreTopFunction("github.com/pion/ice/v2.(*Agent).startOnConnectionStateChangeRoutine.func1"), - goleak.IgnoreTopFunction("github.com/pion/ice/v2.(*Agent).startOnConnectionStateChangeRoutine.func2"), - goleak.IgnoreTopFunction("github.com/pion/ice/v2.(*Agent).taskLoop"), - goleak.IgnoreTopFunction("internal/poll.runtime_pollWait"), - ) -} - -func TestConn(t *testing.T) { - t.Parallel() - t.Run("Ping", func(t *testing.T) { - t.Parallel() - client, server, _ := createPair(t) - exchange(t, client, server) - _, err := client.Ping() - require.NoError(t, err) - _, err = server.Ping() - require.NoError(t, err) - }) - - t.Run("PingNetworkOffline", func(t *testing.T) { - t.Parallel() - client, server, wan := createPair(t) - exchange(t, client, server) - _, err := server.Ping() - require.NoError(t, err) - err = wan.Stop() - require.NoError(t, err) - _, err = server.Ping() - require.ErrorIs(t, err, peer.ErrFailed) - }) - - t.Run("PingReconnect", func(t *testing.T) { - t.Parallel() - client, server, wan := createPair(t) - exchange(t, client, server) - _, err := server.Ping() - require.NoError(t, err) - // Create a channel that closes on disconnect. - channel, err := server.CreateChannel(context.Background(), "wow", nil) - assert.NoError(t, err) - defer channel.Close() - - err = wan.Stop() - require.NoError(t, err) - // Once the connection is marked as disconnected, this - // channel will be closed. 
- _, err = channel.Read(make([]byte, 4)) - assert.ErrorIs(t, err, peer.ErrClosed) - err = wan.Start() - require.NoError(t, err) - _, err = server.Ping() - require.NoError(t, err) - }) - - t.Run("Accept", func(t *testing.T) { - t.Parallel() - client, server, _ := createPair(t) - exchange(t, client, server) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{}) - require.NoError(t, err) - defer cch.Close() - - sch, err := server.Accept(ctx) - require.NoError(t, err) - defer sch.Close() - - _ = cch.Close() - _, err = sch.Read(make([]byte, 4)) - require.ErrorIs(t, err, peer.ErrClosed) - }) - - t.Run("AcceptNetworkOffline", func(t *testing.T) { - t.Parallel() - client, server, wan := createPair(t) - exchange(t, client, server) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{}) - require.NoError(t, err) - defer cch.Close() - sch, err := server.Accept(ctx) - require.NoError(t, err) - defer sch.Close() - - err = wan.Stop() - require.NoError(t, err) - _ = cch.Close() - _, err = sch.Read(make([]byte, 4)) - require.ErrorIs(t, err, peer.ErrClosed) - }) - - t.Run("Buffering", func(t *testing.T) { - t.Parallel() - client, server, _ := createPair(t) - exchange(t, client, server) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{}) - require.NoError(t, err) - defer cch.Close() - - readErr := make(chan error, 1) - go func() { - sch, err := server.Accept(ctx) - if err != nil { - readErr <- err - _ = cch.Close() - return - } - defer sch.Close() - - bytes := make([]byte, 4096) - for { - _, err = sch.Read(bytes) - if err != nil { - readErr <- err - return - } - } - }() - - bytes := make([]byte, 4096) - for i := 0; i < 1024; i++ { - _, err = 
cch.Write(bytes) - require.NoError(t, err, "write i=%d", i) - } - _ = cch.Close() - - select { - case err = <-readErr: - require.ErrorIs(t, err, peer.ErrClosed, "read error") - case <-ctx.Done(): - require.Fail(t, "timeout waiting for read error") - } - }) - - t.Run("NetConn", func(t *testing.T) { - t.Parallel() - client, server, _ := createPair(t) - exchange(t, client, server) - srv, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - defer srv.Close() - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - go func() { - sch, err := server.Accept(ctx) - if err != nil { - assert.NoError(t, err) - return - } - defer sch.Close() - - nc2 := sch.NetConn() - defer nc2.Close() - - nc1, err := net.Dial("tcp", srv.Addr().String()) - if err != nil { - assert.NoError(t, err) - return - } - defer nc1.Close() - - go func() { - defer nc1.Close() - defer nc2.Close() - _, _ = io.Copy(nc1, nc2) - }() - _, _ = io.Copy(nc2, nc1) - }() - go func() { - server := http.Server{ - ReadHeaderTimeout: time.Minute, - Handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - rw.WriteHeader(200) - }), - } - defer server.Close() - _ = server.Serve(srv) - }() - - //nolint:forcetypeassert - defaultTransport := http.DefaultTransport.(*http.Transport).Clone() - var cch *peer.Channel - defaultTransport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - cch, err = client.CreateChannel(ctx, "hello", &peer.ChannelOptions{}) - if err != nil { - return nil, err - } - return cch.NetConn(), nil - } - c := http.Client{ - Transport: defaultTransport, - } - req, err := http.NewRequestWithContext(ctx, "GET", "http://localhost/", nil) - require.NoError(t, err) - resp, err := c.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - require.Equal(t, resp.StatusCode, 200) - // Triggers any connections to close. - // This test below ensures the DataChannel actually closes. 
- defaultTransport.CloseIdleConnections() - err = cch.Close() - require.ErrorIs(t, err, peer.ErrClosed) - }) - - t.Run("CloseBeforeNegotiate", func(t *testing.T) { - t.Parallel() - client, server, _ := createPair(t) - exchange(t, client, server) - err := client.Close() - require.NoError(t, err) - err = server.Close() - require.NoError(t, err) - }) - - t.Run("CloseWithError", func(t *testing.T) { - t.Parallel() - conn, err := peer.Client([]webrtc.ICEServer{}, nil) - require.NoError(t, err) - expectedErr := xerrors.New("wow") - _ = conn.CloseWithError(expectedErr) - _, err = conn.CreateChannel(context.Background(), "", nil) - require.ErrorIs(t, err, expectedErr) - }) - - t.Run("PingConcurrent", func(t *testing.T) { - t.Parallel() - client, server, _ := createPair(t) - exchange(t, client, server) - var wg sync.WaitGroup - wg.Add(2) - go func() { - defer wg.Done() - _, err := client.Ping() - assert.NoError(t, err) - }() - go func() { - defer wg.Done() - _, err := server.Ping() - assert.NoError(t, err) - }() - wg.Wait() - }) - - t.Run("CandidateBeforeSessionDescription", func(t *testing.T) { - t.Parallel() - client, server, _ := createPair(t) - server.SetRemoteSessionDescription(<-client.LocalSessionDescription()) - sdp := <-server.LocalSessionDescription() - client.AddRemoteCandidate(<-server.LocalCandidate()) - client.SetRemoteSessionDescription(sdp) - server.AddRemoteCandidate(<-client.LocalCandidate()) - _, err := client.Ping() - require.NoError(t, err) - }) - - t.Run("ShortBuffer", func(t *testing.T) { - t.Parallel() - client, server, _ := createPair(t) - exchange(t, client, server) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - go func() { - channel, err := client.CreateChannel(ctx, "test", nil) - if err != nil { - assert.NoError(t, err) - return - } - defer channel.Close() - _, err = channel.Write([]byte{1, 2}) - assert.NoError(t, err) - }() - channel, err := server.Accept(ctx) - require.NoError(t, err) - defer 
channel.Close() - data := make([]byte, 1) - _, err = channel.Read(data) - require.NoError(t, err) - require.Equal(t, uint8(0x1), data[0]) - _, err = channel.Read(data) - require.NoError(t, err) - require.Equal(t, uint8(0x2), data[0]) - }) -} - -func createPair(t *testing.T) (client *peer.Conn, server *peer.Conn, wan *vnet.Router) { - loggingFactory := logging.NewDefaultLoggerFactory() - loggingFactory.DefaultLogLevel = logging.LogLevelDisabled - vnetMutex.Lock() - defer vnetMutex.Unlock() - wan, err := vnet.NewRouter(&vnet.RouterConfig{ - CIDR: "1.2.3.0/24", - LoggerFactory: loggingFactory, - }) - require.NoError(t, err) - c1Net := vnet.NewNet(&vnet.NetConfig{ - StaticIPs: []string{"1.2.3.4"}, - }) - err = wan.AddNet(c1Net) - require.NoError(t, err) - c2Net := vnet.NewNet(&vnet.NetConfig{ - StaticIPs: []string{"1.2.3.5"}, - }) - err = wan.AddNet(c2Net) - require.NoError(t, err) - - c1SettingEngine := webrtc.SettingEngine{} - c1SettingEngine.SetVNet(c1Net) - c1SettingEngine.SetPrflxAcceptanceMinWait(0) - c1SettingEngine.SetICETimeouts(disconnectedTimeout, failedTimeout, keepAliveInterval) - channel1, err := peer.Client([]webrtc.ICEServer{{}}, &peer.ConnOptions{ - SettingEngine: c1SettingEngine, - Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug), - }) - require.NoError(t, err) - t.Cleanup(func() { - channel1.Close() - }) - c2SettingEngine := webrtc.SettingEngine{} - c2SettingEngine.SetVNet(c2Net) - c2SettingEngine.SetPrflxAcceptanceMinWait(0) - c2SettingEngine.SetICETimeouts(disconnectedTimeout, failedTimeout, keepAliveInterval) - channel2, err := peer.Server([]webrtc.ICEServer{{}}, &peer.ConnOptions{ - SettingEngine: c2SettingEngine, - Logger: slogtest.Make(t, nil).Named("server").Leveled(slog.LevelDebug), - }) - require.NoError(t, err) - t.Cleanup(func() { - channel2.Close() - }) - - err = wan.Start() - require.NoError(t, err) - t.Cleanup(func() { - _ = wan.Stop() - }) - - return channel1, channel2, wan -} - -func exchange(t *testing.T, 
client, server *peer.Conn) { - var wg sync.WaitGroup - wg.Add(2) - t.Cleanup(func() { - _ = client.Close() - _ = server.Close() - - wg.Wait() - }) - go func() { - defer wg.Done() - for { - select { - case c := <-server.LocalCandidate(): - client.AddRemoteCandidate(c) - case c := <-server.LocalSessionDescription(): - client.SetRemoteSessionDescription(c) - case <-server.Closed(): - return - } - } - }() - go func() { - defer wg.Done() - for { - select { - case c := <-client.LocalCandidate(): - server.AddRemoteCandidate(c) - case c := <-client.LocalSessionDescription(): - server.SetRemoteSessionDescription(c) - case <-client.Closed(): - return - } - } - }() -} diff --git a/peer/netconn.go b/peer/netconn.go deleted file mode 100644 index e564c0ecc209c..0000000000000 --- a/peer/netconn.go +++ /dev/null @@ -1,59 +0,0 @@ -package peer - -import ( - "net" - "time" -) - -type peerAddr struct{} - -// Statically checks if we properly implement net.Addr. -var _ net.Addr = &peerAddr{} - -func (*peerAddr) Network() string { - return "peer" -} - -func (*peerAddr) String() string { - return "peer/unknown-addr" -} - -type fakeNetConn struct { - c *Channel - addr *peerAddr -} - -// Statically checks if we properly implement net.Conn. 
-var _ net.Conn = &fakeNetConn{} - -func (c *fakeNetConn) Read(b []byte) (n int, err error) { - return c.c.Read(b) -} - -func (c *fakeNetConn) Write(b []byte) (n int, err error) { - return c.c.Write(b) -} - -func (c *fakeNetConn) Close() error { - return c.c.Close() -} - -func (c *fakeNetConn) LocalAddr() net.Addr { - return c.addr -} - -func (c *fakeNetConn) RemoteAddr() net.Addr { - return c.addr -} - -func (*fakeNetConn) SetDeadline(_ time.Time) error { - return nil -} - -func (*fakeNetConn) SetReadDeadline(_ time.Time) error { - return nil -} - -func (*fakeNetConn) SetWriteDeadline(_ time.Time) error { - return nil -} diff --git a/peerbroker/dial.go b/peerbroker/dial.go deleted file mode 100644 index 61ef7b409a597..0000000000000 --- a/peerbroker/dial.go +++ /dev/null @@ -1,87 +0,0 @@ -package peerbroker - -import ( - "context" - "errors" - "io" - "reflect" - - "github.com/pion/webrtc/v3" - "golang.org/x/xerrors" - - "github.com/coder/coder/peer" - "github.com/coder/coder/peerbroker/proto" -) - -// Dial consumes the PeerBroker gRPC connection negotiation stream to produce a WebRTC peered connection. -func Dial(stream proto.DRPCPeerBroker_NegotiateConnectionClient, iceServers []webrtc.ICEServer, opts *peer.ConnOptions) (*peer.Conn, error) { - peerConn, err := peer.Client(iceServers, opts) - if err != nil { - return nil, xerrors.Errorf("create peer connection: %w", err) - } - go func() { - defer stream.Close() - // Exchanging messages from the peer connection to negotiate a connection. 
- for { - select { - case <-peerConn.Closed(): - return - case sessionDescription := <-peerConn.LocalSessionDescription(): - err = stream.Send(&proto.Exchange{ - Message: &proto.Exchange_Sdp{ - Sdp: &proto.WebRTCSessionDescription{ - SdpType: int32(sessionDescription.Type), - Sdp: sessionDescription.SDP, - }, - }, - }) - if err != nil { - _ = peerConn.CloseWithError(xerrors.Errorf("send local session description: %w", err)) - return - } - case iceCandidate := <-peerConn.LocalCandidate(): - err = stream.Send(&proto.Exchange{ - Message: &proto.Exchange_IceCandidate{ - IceCandidate: iceCandidate.Candidate, - }, - }) - if err != nil { - _ = peerConn.CloseWithError(xerrors.Errorf("send local candidate: %w", err)) - return - } - } - } - }() - go func() { - // Exchanging messages from the server to negotiate a connection. - for { - serverToClientMessage, err := stream.Recv() - if err != nil { - // p2p connections should never die if this stream does due - // to proper closure or context cancellation! 
- if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) { - return - } - _ = peerConn.CloseWithError(xerrors.Errorf("recv: %w", err)) - return - } - - switch { - case serverToClientMessage.GetSdp() != nil: - peerConn.SetRemoteSessionDescription(webrtc.SessionDescription{ - Type: webrtc.SDPType(serverToClientMessage.GetSdp().SdpType), - SDP: serverToClientMessage.GetSdp().Sdp, - }) - case serverToClientMessage.GetIceCandidate() != "": - peerConn.AddRemoteCandidate(webrtc.ICECandidateInit{ - Candidate: serverToClientMessage.GetIceCandidate(), - }) - default: - _ = peerConn.CloseWithError(xerrors.Errorf("unhandled message: %s", reflect.TypeOf(serverToClientMessage).String())) - return - } - } - }() - - return peerConn, nil -} diff --git a/peerbroker/dial_test.go b/peerbroker/dial_test.go deleted file mode 100644 index efd4e6917ac41..0000000000000 --- a/peerbroker/dial_test.go +++ /dev/null @@ -1,67 +0,0 @@ -package peerbroker_test - -import ( - "context" - "testing" - - "github.com/pion/webrtc/v3" - "github.com/stretchr/testify/require" - "go.uber.org/goleak" - - "cdr.dev/slog" - "cdr.dev/slog/sloggers/slogtest" - - "github.com/coder/coder/peer" - "github.com/coder/coder/peerbroker" - "github.com/coder/coder/peerbroker/proto" - "github.com/coder/coder/provisionersdk" -) - -func TestMain(m *testing.M) { - goleak.VerifyTestMain(m) -} - -func TestDial(t *testing.T) { - t.Parallel() - - t.Run("Connect", func(t *testing.T) { - t.Parallel() - ctx := context.Background() - client, server := provisionersdk.TransportPipe() - defer client.Close() - defer server.Close() - - settingEngine := webrtc.SettingEngine{} - listener, err := peerbroker.Listen(server, func(ctx context.Context) ([]webrtc.ICEServer, *peer.ConnOptions, error) { - return []webrtc.ICEServer{{ - URLs: []string{"stun:stun.l.google.com:19302"}, - }}, &peer.ConnOptions{ - Logger: slogtest.Make(t, nil).Named("server").Leveled(slog.LevelDebug), - SettingEngine: settingEngine, - }, nil - }) - 
require.NoError(t, err) - - api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client)) - stream, err := api.NegotiateConnection(ctx) - require.NoError(t, err) - - clientConn, err := peerbroker.Dial(stream, []webrtc.ICEServer{{ - URLs: []string{"stun:stun.l.google.com:19302"}, - }}, &peer.ConnOptions{ - Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug), - SettingEngine: settingEngine, - }) - require.NoError(t, err) - defer clientConn.Close() - - serverConn, err := listener.Accept() - require.NoError(t, err) - defer serverConn.Close() - _, err = serverConn.Ping() - require.NoError(t, err) - - _, err = clientConn.Ping() - require.NoError(t, err) - }) -} diff --git a/peerbroker/listen.go b/peerbroker/listen.go deleted file mode 100644 index 34c91ea6e51a4..0000000000000 --- a/peerbroker/listen.go +++ /dev/null @@ -1,188 +0,0 @@ -package peerbroker - -import ( - "context" - "errors" - "io" - "net" - "reflect" - "sync" - - "github.com/pion/webrtc/v3" - "golang.org/x/xerrors" - "storj.io/drpc/drpcmux" - "storj.io/drpc/drpcserver" - - "github.com/coder/coder/peer" - "github.com/coder/coder/peerbroker/proto" -) - -// ConnSettingsFunc returns initialization options for a connection -type ConnSettingsFunc func(ctx context.Context) ([]webrtc.ICEServer, *peer.ConnOptions, error) - -// Listen consumes the transport as the server-side of the PeerBroker dRPC service. -// The Accept function must be serviced, or new connections will hang. 
-func Listen(connListener net.Listener, connSettingsFunc ConnSettingsFunc) (*Listener, error) { - if connSettingsFunc == nil { - connSettingsFunc = func(ctx context.Context) ([]webrtc.ICEServer, *peer.ConnOptions, error) { - return []webrtc.ICEServer{}, nil, nil - } - } - ctx, cancelFunc := context.WithCancel(context.Background()) - listener := &Listener{ - connectionChannel: make(chan *peer.Conn), - connectionListener: connListener, - - closeFunc: cancelFunc, - closed: make(chan struct{}), - } - - mux := drpcmux.New() - err := proto.DRPCRegisterPeerBroker(mux, &peerBrokerService{ - connSettingsFunc: connSettingsFunc, - - listener: listener, - }) - if err != nil { - return nil, xerrors.Errorf("register peer broker: %w", err) - } - srv := drpcserver.New(mux) - go func() { - err := srv.Serve(ctx, connListener) - _ = listener.closeWithError(err) - }() - - return listener, nil -} - -type Listener struct { - connectionChannel chan *peer.Conn - connectionListener net.Listener - - closeFunc context.CancelFunc - closed chan struct{} - closeMutex sync.Mutex - closeError error -} - -// Accept blocks until a connection arrives or the listener is closed. -func (l *Listener) Accept() (*peer.Conn, error) { - select { - case <-l.closed: - return nil, l.closeError - case conn := <-l.connectionChannel: - return conn, nil - } -} - -// Close ends the listener. This will block all new WebRTC connections -// from establishing, but will not close active connections. -func (l *Listener) Close() error { - return l.closeWithError(io.EOF) -} - -func (l *Listener) closeWithError(err error) error { - l.closeMutex.Lock() - defer l.closeMutex.Unlock() - - if l.isClosed() { - return l.closeError - } - - _ = l.connectionListener.Close() - l.closeError = err - l.closeFunc() - close(l.closed) - - return nil -} - -func (l *Listener) isClosed() bool { - select { - case <-l.closed: - return true - default: - return false - } -} - -// Implements the PeerBroker service protobuf definition. 
-type peerBrokerService struct { - listener *Listener - - connSettingsFunc ConnSettingsFunc -} - -// NegotiateConnection negotiates a WebRTC connection. -func (b *peerBrokerService) NegotiateConnection(stream proto.DRPCPeerBroker_NegotiateConnectionStream) error { - iceServers, connOptions, err := b.connSettingsFunc(stream.Context()) - if err != nil { - return xerrors.Errorf("get connection settings: %w", err) - } - peerConn, err := peer.Server(iceServers, connOptions) - if err != nil { - return xerrors.Errorf("create peer connection: %w", err) - } - select { - case <-b.listener.closed: - return peerConn.CloseWithError(b.listener.closeError) - case b.listener.connectionChannel <- peerConn: - } - go func() { - defer stream.Close() - for { - select { - case <-peerConn.Closed(): - return - case sessionDescription := <-peerConn.LocalSessionDescription(): - err = stream.Send(&proto.Exchange{ - Message: &proto.Exchange_Sdp{ - Sdp: &proto.WebRTCSessionDescription{ - SdpType: int32(sessionDescription.Type), - Sdp: sessionDescription.SDP, - }, - }, - }) - if err != nil { - _ = peerConn.CloseWithError(xerrors.Errorf("send local session description: %w", err)) - return - } - case iceCandidate := <-peerConn.LocalCandidate(): - err = stream.Send(&proto.Exchange{ - Message: &proto.Exchange_IceCandidate{ - IceCandidate: iceCandidate.Candidate, - }, - }) - if err != nil { - _ = peerConn.CloseWithError(xerrors.Errorf("send local candidate: %w", err)) - return - } - } - } - }() - for { - clientToServerMessage, err := stream.Recv() - if err != nil { - // p2p connections should never die if this stream does due - // to proper closure or context cancellation! 
- if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) { - return nil - } - return peerConn.CloseWithError(xerrors.Errorf("recv: %w", err)) - } - - switch { - case clientToServerMessage.GetSdp() != nil: - peerConn.SetRemoteSessionDescription(webrtc.SessionDescription{ - Type: webrtc.SDPType(clientToServerMessage.GetSdp().SdpType), - SDP: clientToServerMessage.GetSdp().Sdp, - }) - case clientToServerMessage.GetIceCandidate() != "": - peerConn.AddRemoteCandidate(webrtc.ICECandidateInit{ - Candidate: clientToServerMessage.GetIceCandidate(), - }) - default: - return peerConn.CloseWithError(xerrors.Errorf("unhandled message: %s", reflect.TypeOf(clientToServerMessage).String())) - } - } -} diff --git a/peerbroker/listen_test.go b/peerbroker/listen_test.go deleted file mode 100644 index 81582a91d4b84..0000000000000 --- a/peerbroker/listen_test.go +++ /dev/null @@ -1,52 +0,0 @@ -package peerbroker_test - -import ( - "context" - "io" - "testing" - - "github.com/stretchr/testify/require" - - "github.com/coder/coder/peerbroker" - "github.com/coder/coder/peerbroker/proto" - "github.com/coder/coder/provisionersdk" -) - -func TestListen(t *testing.T) { - t.Parallel() - // Ensures connections blocked on Accept() are - // closed if the listener is. - t.Run("NoAcceptClosed", func(t *testing.T) { - t.Parallel() - ctx := context.Background() - client, server := provisionersdk.TransportPipe() - defer client.Close() - defer server.Close() - - listener, err := peerbroker.Listen(server, nil) - require.NoError(t, err) - - api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client)) - stream, err := api.NegotiateConnection(ctx) - require.NoError(t, err) - clientConn, err := peerbroker.Dial(stream, nil, nil) - require.NoError(t, err) - defer clientConn.Close() - - _ = listener.Close() - }) - - // Ensures Accept() properly exits when Close() is called. 
- t.Run("AcceptClosed", func(t *testing.T) { - t.Parallel() - client, server := provisionersdk.TransportPipe() - defer client.Close() - defer server.Close() - - listener, err := peerbroker.Listen(server, nil) - require.NoError(t, err) - go listener.Close() - _, err = listener.Accept() - require.ErrorIs(t, err, io.EOF) - }) -} diff --git a/peerbroker/proto/peerbroker.pb.go b/peerbroker/proto/peerbroker.pb.go deleted file mode 100644 index d4e09f44be118..0000000000000 --- a/peerbroker/proto/peerbroker.pb.go +++ /dev/null @@ -1,269 +0,0 @@ -// Code generated by protoc-gen-go. DO NOT EDIT. -// versions: -// protoc-gen-go v1.26.0 -// protoc v3.21.5 -// source: peerbroker/proto/peerbroker.proto - -package proto - -import ( - protoreflect "google.golang.org/protobuf/reflect/protoreflect" - protoimpl "google.golang.org/protobuf/runtime/protoimpl" - reflect "reflect" - sync "sync" -) - -const ( - // Verify that this generated code is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) - // Verify that runtime/protoimpl is sufficiently up-to-date. 
- _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) -) - -type WebRTCSessionDescription struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - SdpType int32 `protobuf:"varint,1,opt,name=sdp_type,json=sdpType,proto3" json:"sdp_type,omitempty"` - Sdp string `protobuf:"bytes,2,opt,name=sdp,proto3" json:"sdp,omitempty"` -} - -func (x *WebRTCSessionDescription) Reset() { - *x = WebRTCSessionDescription{} - if protoimpl.UnsafeEnabled { - mi := &file_peerbroker_proto_peerbroker_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *WebRTCSessionDescription) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*WebRTCSessionDescription) ProtoMessage() {} - -func (x *WebRTCSessionDescription) ProtoReflect() protoreflect.Message { - mi := &file_peerbroker_proto_peerbroker_proto_msgTypes[0] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use WebRTCSessionDescription.ProtoReflect.Descriptor instead. 
-func (*WebRTCSessionDescription) Descriptor() ([]byte, []int) { - return file_peerbroker_proto_peerbroker_proto_rawDescGZIP(), []int{0} -} - -func (x *WebRTCSessionDescription) GetSdpType() int32 { - if x != nil { - return x.SdpType - } - return 0 -} - -func (x *WebRTCSessionDescription) GetSdp() string { - if x != nil { - return x.Sdp - } - return "" -} - -type Exchange struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - // Types that are assignable to Message: - // - // *Exchange_Sdp - // *Exchange_IceCandidate - Message isExchange_Message `protobuf_oneof:"message"` -} - -func (x *Exchange) Reset() { - *x = Exchange{} - if protoimpl.UnsafeEnabled { - mi := &file_peerbroker_proto_peerbroker_proto_msgTypes[1] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *Exchange) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*Exchange) ProtoMessage() {} - -func (x *Exchange) ProtoReflect() protoreflect.Message { - mi := &file_peerbroker_proto_peerbroker_proto_msgTypes[1] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use Exchange.ProtoReflect.Descriptor instead. 
-func (*Exchange) Descriptor() ([]byte, []int) { - return file_peerbroker_proto_peerbroker_proto_rawDescGZIP(), []int{1} -} - -func (m *Exchange) GetMessage() isExchange_Message { - if m != nil { - return m.Message - } - return nil -} - -func (x *Exchange) GetSdp() *WebRTCSessionDescription { - if x, ok := x.GetMessage().(*Exchange_Sdp); ok { - return x.Sdp - } - return nil -} - -func (x *Exchange) GetIceCandidate() string { - if x, ok := x.GetMessage().(*Exchange_IceCandidate); ok { - return x.IceCandidate - } - return "" -} - -type isExchange_Message interface { - isExchange_Message() -} - -type Exchange_Sdp struct { - Sdp *WebRTCSessionDescription `protobuf:"bytes,1,opt,name=sdp,proto3,oneof"` -} - -type Exchange_IceCandidate struct { - IceCandidate string `protobuf:"bytes,2,opt,name=ice_candidate,json=iceCandidate,proto3,oneof"` -} - -func (*Exchange_Sdp) isExchange_Message() {} - -func (*Exchange_IceCandidate) isExchange_Message() {} - -var File_peerbroker_proto_peerbroker_proto protoreflect.FileDescriptor - -var file_peerbroker_proto_peerbroker_proto_rawDesc = []byte{ - 0x0a, 0x21, 0x70, 0x65, 0x65, 0x72, 0x62, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x2f, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x2f, 0x70, 0x65, 0x65, 0x72, 0x62, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x2e, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x12, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x62, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x22, - 0x47, 0x0a, 0x18, 0x57, 0x65, 0x62, 0x52, 0x54, 0x43, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, - 0x44, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x19, 0x0a, 0x08, 0x73, - 0x64, 0x70, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x73, - 0x64, 0x70, 0x54, 0x79, 0x70, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x73, 0x64, 0x70, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x03, 0x73, 0x64, 0x70, 0x22, 0x76, 0x0a, 0x08, 0x45, 0x78, 0x63, 0x68, - 0x61, 0x6e, 0x67, 0x65, 0x12, 0x38, 0x0a, 0x03, 0x73, 0x64, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x24, 0x2e, 
0x70, 0x65, 0x65, 0x72, 0x62, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x2e, 0x57, - 0x65, 0x62, 0x52, 0x54, 0x43, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x44, 0x65, 0x73, 0x63, - 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x48, 0x00, 0x52, 0x03, 0x73, 0x64, 0x70, 0x12, 0x25, - 0x0a, 0x0d, 0x69, 0x63, 0x65, 0x5f, 0x63, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x0c, 0x69, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, - 0x69, 0x64, 0x61, 0x74, 0x65, 0x42, 0x09, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, - 0x32, 0x53, 0x0a, 0x0a, 0x50, 0x65, 0x65, 0x72, 0x42, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x12, 0x45, - 0x0a, 0x13, 0x4e, 0x65, 0x67, 0x6f, 0x74, 0x69, 0x61, 0x74, 0x65, 0x43, 0x6f, 0x6e, 0x6e, 0x65, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x14, 0x2e, 0x70, 0x65, 0x65, 0x72, 0x62, 0x72, 0x6f, 0x6b, - 0x65, 0x72, 0x2e, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x1a, 0x14, 0x2e, 0x70, 0x65, - 0x65, 0x72, 0x62, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x2e, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, - 0x65, 0x28, 0x01, 0x30, 0x01, 0x42, 0x29, 0x5a, 0x27, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, - 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2f, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2f, - 0x70, 0x65, 0x65, 0x72, 0x62, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, -} - -var ( - file_peerbroker_proto_peerbroker_proto_rawDescOnce sync.Once - file_peerbroker_proto_peerbroker_proto_rawDescData = file_peerbroker_proto_peerbroker_proto_rawDesc -) - -func file_peerbroker_proto_peerbroker_proto_rawDescGZIP() []byte { - file_peerbroker_proto_peerbroker_proto_rawDescOnce.Do(func() { - file_peerbroker_proto_peerbroker_proto_rawDescData = protoimpl.X.CompressGZIP(file_peerbroker_proto_peerbroker_proto_rawDescData) - }) - return file_peerbroker_proto_peerbroker_proto_rawDescData -} - -var file_peerbroker_proto_peerbroker_proto_msgTypes = make([]protoimpl.MessageInfo, 
2) -var file_peerbroker_proto_peerbroker_proto_goTypes = []interface{}{ - (*WebRTCSessionDescription)(nil), // 0: peerbroker.WebRTCSessionDescription - (*Exchange)(nil), // 1: peerbroker.Exchange -} -var file_peerbroker_proto_peerbroker_proto_depIdxs = []int32{ - 0, // 0: peerbroker.Exchange.sdp:type_name -> peerbroker.WebRTCSessionDescription - 1, // 1: peerbroker.PeerBroker.NegotiateConnection:input_type -> peerbroker.Exchange - 1, // 2: peerbroker.PeerBroker.NegotiateConnection:output_type -> peerbroker.Exchange - 2, // [2:3] is the sub-list for method output_type - 1, // [1:2] is the sub-list for method input_type - 1, // [1:1] is the sub-list for extension type_name - 1, // [1:1] is the sub-list for extension extendee - 0, // [0:1] is the sub-list for field type_name -} - -func init() { file_peerbroker_proto_peerbroker_proto_init() } -func file_peerbroker_proto_peerbroker_proto_init() { - if File_peerbroker_proto_peerbroker_proto != nil { - return - } - if !protoimpl.UnsafeEnabled { - file_peerbroker_proto_peerbroker_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*WebRTCSessionDescription); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_peerbroker_proto_peerbroker_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Exchange); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - } - file_peerbroker_proto_peerbroker_proto_msgTypes[1].OneofWrappers = []interface{}{ - (*Exchange_Sdp)(nil), - (*Exchange_IceCandidate)(nil), - } - type x struct{} - out := protoimpl.TypeBuilder{ - File: protoimpl.DescBuilder{ - GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_peerbroker_proto_peerbroker_proto_rawDesc, - NumEnums: 0, - NumMessages: 2, - NumExtensions: 0, - NumServices: 1, - }, - GoTypes: 
file_peerbroker_proto_peerbroker_proto_goTypes, - DependencyIndexes: file_peerbroker_proto_peerbroker_proto_depIdxs, - MessageInfos: file_peerbroker_proto_peerbroker_proto_msgTypes, - }.Build() - File_peerbroker_proto_peerbroker_proto = out.File - file_peerbroker_proto_peerbroker_proto_rawDesc = nil - file_peerbroker_proto_peerbroker_proto_goTypes = nil - file_peerbroker_proto_peerbroker_proto_depIdxs = nil -} diff --git a/peerbroker/proto/peerbroker.proto b/peerbroker/proto/peerbroker.proto deleted file mode 100644 index f67b338ed3372..0000000000000 --- a/peerbroker/proto/peerbroker.proto +++ /dev/null @@ -1,28 +0,0 @@ - -syntax = "proto3"; -option go_package = "github.com/coder/coder/peerbroker/proto"; - -package peerbroker; - -message WebRTCSessionDescription { - int32 sdp_type = 1; - string sdp = 2; -} - -message Exchange { - oneof message { - WebRTCSessionDescription sdp = 1; - string ice_candidate = 2; - } -} - -// PeerBroker mediates WebRTC connection signaling. -service PeerBroker { - // NegotiateConnection establishes a bidirectional stream to negotiate a new WebRTC connection. - // 1. Client sends WebRTCSessionDescription to the server. - // 2. Server sends WebRTCSessionDescription to the client, exchanging encryption keys. - // 3. Client<->Server exchange ICE Candidates to establish a peered connection. - // - // See: https://davekilian.com/webrtc-the-hard-way.html - rpc NegotiateConnection(stream Exchange) returns (stream Exchange); -} \ No newline at end of file diff --git a/peerbroker/proto/peerbroker_drpc.pb.go b/peerbroker/proto/peerbroker_drpc.pb.go deleted file mode 100644 index ae06f79a01371..0000000000000 --- a/peerbroker/proto/peerbroker_drpc.pb.go +++ /dev/null @@ -1,146 +0,0 @@ -// Code generated by protoc-gen-go-drpc. DO NOT EDIT. 
-// protoc-gen-go-drpc version: v0.0.26 -// source: peerbroker/proto/peerbroker.proto - -package proto - -import ( - context "context" - errors "errors" - protojson "google.golang.org/protobuf/encoding/protojson" - proto "google.golang.org/protobuf/proto" - drpc "storj.io/drpc" - drpcerr "storj.io/drpc/drpcerr" -) - -type drpcEncoding_File_peerbroker_proto_peerbroker_proto struct{} - -func (drpcEncoding_File_peerbroker_proto_peerbroker_proto) Marshal(msg drpc.Message) ([]byte, error) { - return proto.Marshal(msg.(proto.Message)) -} - -func (drpcEncoding_File_peerbroker_proto_peerbroker_proto) MarshalAppend(buf []byte, msg drpc.Message) ([]byte, error) { - return proto.MarshalOptions{}.MarshalAppend(buf, msg.(proto.Message)) -} - -func (drpcEncoding_File_peerbroker_proto_peerbroker_proto) Unmarshal(buf []byte, msg drpc.Message) error { - return proto.Unmarshal(buf, msg.(proto.Message)) -} - -func (drpcEncoding_File_peerbroker_proto_peerbroker_proto) JSONMarshal(msg drpc.Message) ([]byte, error) { - return protojson.Marshal(msg.(proto.Message)) -} - -func (drpcEncoding_File_peerbroker_proto_peerbroker_proto) JSONUnmarshal(buf []byte, msg drpc.Message) error { - return protojson.Unmarshal(buf, msg.(proto.Message)) -} - -type DRPCPeerBrokerClient interface { - DRPCConn() drpc.Conn - - NegotiateConnection(ctx context.Context) (DRPCPeerBroker_NegotiateConnectionClient, error) -} - -type drpcPeerBrokerClient struct { - cc drpc.Conn -} - -func NewDRPCPeerBrokerClient(cc drpc.Conn) DRPCPeerBrokerClient { - return &drpcPeerBrokerClient{cc} -} - -func (c *drpcPeerBrokerClient) DRPCConn() drpc.Conn { return c.cc } - -func (c *drpcPeerBrokerClient) NegotiateConnection(ctx context.Context) (DRPCPeerBroker_NegotiateConnectionClient, error) { - stream, err := c.cc.NewStream(ctx, "/peerbroker.PeerBroker/NegotiateConnection", drpcEncoding_File_peerbroker_proto_peerbroker_proto{}) - if err != nil { - return nil, err - } - x := &drpcPeerBroker_NegotiateConnectionClient{stream} - 
return x, nil -} - -type DRPCPeerBroker_NegotiateConnectionClient interface { - drpc.Stream - Send(*Exchange) error - Recv() (*Exchange, error) -} - -type drpcPeerBroker_NegotiateConnectionClient struct { - drpc.Stream -} - -func (x *drpcPeerBroker_NegotiateConnectionClient) Send(m *Exchange) error { - return x.MsgSend(m, drpcEncoding_File_peerbroker_proto_peerbroker_proto{}) -} - -func (x *drpcPeerBroker_NegotiateConnectionClient) Recv() (*Exchange, error) { - m := new(Exchange) - if err := x.MsgRecv(m, drpcEncoding_File_peerbroker_proto_peerbroker_proto{}); err != nil { - return nil, err - } - return m, nil -} - -func (x *drpcPeerBroker_NegotiateConnectionClient) RecvMsg(m *Exchange) error { - return x.MsgRecv(m, drpcEncoding_File_peerbroker_proto_peerbroker_proto{}) -} - -type DRPCPeerBrokerServer interface { - NegotiateConnection(DRPCPeerBroker_NegotiateConnectionStream) error -} - -type DRPCPeerBrokerUnimplementedServer struct{} - -func (s *DRPCPeerBrokerUnimplementedServer) NegotiateConnection(DRPCPeerBroker_NegotiateConnectionStream) error { - return drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented) -} - -type DRPCPeerBrokerDescription struct{} - -func (DRPCPeerBrokerDescription) NumMethods() int { return 1 } - -func (DRPCPeerBrokerDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) { - switch n { - case 0: - return "/peerbroker.PeerBroker/NegotiateConnection", drpcEncoding_File_peerbroker_proto_peerbroker_proto{}, - func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) { - return nil, srv.(DRPCPeerBrokerServer). 
- NegotiateConnection( - &drpcPeerBroker_NegotiateConnectionStream{in1.(drpc.Stream)}, - ) - }, DRPCPeerBrokerServer.NegotiateConnection, true - default: - return "", nil, nil, nil, false - } -} - -func DRPCRegisterPeerBroker(mux drpc.Mux, impl DRPCPeerBrokerServer) error { - return mux.Register(impl, DRPCPeerBrokerDescription{}) -} - -type DRPCPeerBroker_NegotiateConnectionStream interface { - drpc.Stream - Send(*Exchange) error - Recv() (*Exchange, error) -} - -type drpcPeerBroker_NegotiateConnectionStream struct { - drpc.Stream -} - -func (x *drpcPeerBroker_NegotiateConnectionStream) Send(m *Exchange) error { - return x.MsgSend(m, drpcEncoding_File_peerbroker_proto_peerbroker_proto{}) -} - -func (x *drpcPeerBroker_NegotiateConnectionStream) Recv() (*Exchange, error) { - m := new(Exchange) - if err := x.MsgRecv(m, drpcEncoding_File_peerbroker_proto_peerbroker_proto{}); err != nil { - return nil, err - } - return m, nil -} - -func (x *drpcPeerBroker_NegotiateConnectionStream) RecvMsg(m *Exchange) error { - return x.MsgRecv(m, drpcEncoding_File_peerbroker_proto_peerbroker_proto{}) -} diff --git a/peerbroker/proxy.go b/peerbroker/proxy.go deleted file mode 100644 index 3e3ccb441776b..0000000000000 --- a/peerbroker/proxy.go +++ /dev/null @@ -1,283 +0,0 @@ -package peerbroker - -import ( - "context" - "encoding/base64" - "errors" - "fmt" - "io" - "net" - "sync" - - "github.com/google/uuid" - "github.com/hashicorp/yamux" - "golang.org/x/xerrors" - protobuf "google.golang.org/protobuf/proto" - "storj.io/drpc/drpcmux" - "storj.io/drpc/drpcserver" - - "cdr.dev/slog" - "github.com/coder/coder/coderd/database" - "github.com/coder/coder/peerbroker/proto" -) - -var ( - // Each NegotiateConnection() function call spawns a new stream. - streamIDLength = len(uuid.NewString()) - // We shouldn't PubSub anything larger than this! - maxPayloadSizeBytes = 8192 -) - -// ProxyOptions provides values to configure a proxy. 
-type ProxyOptions struct { - ChannelID string - Logger slog.Logger - Pubsub database.Pubsub -} - -// ProxyDial writes client negotiation streams over PubSub. -// -// PubSub is used to geodistribute WebRTC handshakes. All negotiation -// messages are small in size (<=8KB), and we don't require delivery -// guarantees because connections can always be renegotiated. -// -// ┌────────────────────┐ ┌─────────────────────────────┐ -// │ coderd │ │ coderd │ -// -// ┌─────────────────────┐ │//connect │ │ //listen │ -// │ client │ │ │ │ │ ┌─────┐ -// │ ├──►│Creates a stream ID │◄─►│Subscribe() to the │◄──┤agent│ -// │NegotiateConnection()│ │and Publish() to the│ │channel. Parse the stream ID │ └─────┘ -// └─────────────────────┘ │ channel: │ │from payloads to create new │ -// -// │ │ │NegotiateConnection() streams│ -// ││ │or write to existing ones. │ -// └────────────────────┘ └─────────────────────────────┘ -func ProxyDial(client proto.DRPCPeerBrokerClient, options ProxyOptions) (io.Closer, error) { - proxyDial := &proxyDial{ - channelID: options.ChannelID, - logger: options.Logger, - pubsub: options.Pubsub, - connection: client, - streams: make(map[string]proto.DRPCPeerBroker_NegotiateConnectionClient), - } - return proxyDial, proxyDial.listen() -} - -// ProxyListen accepts client negotiation streams over PubSub and writes them to the listener -// as new NegotiateConnection() streams. 
-func ProxyListen(ctx context.Context, connListener net.Listener, options ProxyOptions) error { - mux := drpcmux.New() - err := proto.DRPCRegisterPeerBroker(mux, &proxyListen{ - channelID: options.ChannelID, - pubsub: options.Pubsub, - logger: options.Logger, - }) - if err != nil { - return xerrors.Errorf("register peer broker: %w", err) - } - server := drpcserver.New(mux) - err = server.Serve(ctx, connListener) - if err != nil { - if errors.Is(err, yamux.ErrSessionShutdown) { - return nil - } - return xerrors.Errorf("serve: %w", err) - } - return nil -} - -type proxyListen struct { - channelID string - pubsub database.Pubsub - logger slog.Logger -} - -func (p *proxyListen) NegotiateConnection(stream proto.DRPCPeerBroker_NegotiateConnectionStream) error { - streamID := uuid.NewString() - var err error - closeSubscribe, err := p.pubsub.Subscribe(proxyInID(p.channelID), func(ctx context.Context, message []byte) { - err := p.onServerToClientMessage(streamID, stream, message) - if err != nil { - p.logger.Debug(ctx, "failed to accept server message", slog.Error(err)) - } - }) - if err != nil { - return xerrors.Errorf("subscribe: %w", err) - } - defer closeSubscribe() - for { - clientToServerMessage, err := stream.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - break - } - return xerrors.Errorf("recv: %w", err) - } - data, err := protobuf.Marshal(clientToServerMessage) - if err != nil { - return xerrors.Errorf("marshal: %w", err) - } - if len(data) > maxPayloadSizeBytes { - return xerrors.Errorf("maximum payload size %d exceeded", maxPayloadSizeBytes) - } - data = append([]byte(streamID), data...) 
- err = p.pubsub.Publish(proxyOutID(p.channelID), marshal(data)) - if err != nil { - return xerrors.Errorf("publish: %w", err) - } - } - return nil -} - -func (*proxyListen) onServerToClientMessage(streamID string, stream proto.DRPCPeerBroker_NegotiateConnectionStream, message []byte) error { - var err error - message, err = unmarshal(message) - if err != nil { - return xerrors.Errorf("decode: %w", err) - } - if len(message) < streamIDLength { - return xerrors.Errorf("got message length %d < %d", len(message), streamIDLength) - } - serverStreamID := string(message[0:streamIDLength]) - if serverStreamID != streamID { - // It's not trying to communicate with this stream! - return nil - } - var msg proto.Exchange - err = protobuf.Unmarshal(message[streamIDLength:], &msg) - if err != nil { - return xerrors.Errorf("unmarshal message: %w", err) - } - err = stream.Send(&msg) - if err != nil { - return xerrors.Errorf("send message: %w", err) - } - return nil -} - -type proxyDial struct { - channelID string - pubsub database.Pubsub - logger slog.Logger - - connection proto.DRPCPeerBrokerClient - closeSubscribe func() - streamMutex sync.Mutex - streams map[string]proto.DRPCPeerBroker_NegotiateConnectionClient -} - -func (p *proxyDial) listen() error { - var err error - p.closeSubscribe, err = p.pubsub.Subscribe(proxyOutID(p.channelID), func(ctx context.Context, message []byte) { - err := p.onClientToServerMessage(ctx, message) - if err != nil { - p.logger.Debug(ctx, "failed to accept client message", slog.Error(err)) - } - }) - if err != nil { - return err - } - return nil -} - -func (p *proxyDial) onClientToServerMessage(ctx context.Context, message []byte) error { - var err error - message, err = unmarshal(message) - if err != nil { - return xerrors.Errorf("decode: %w", err) - } - if len(message) < streamIDLength { - return xerrors.Errorf("got message length %d < %d", len(message), streamIDLength) - } - streamID := string(message[0:streamIDLength]) - p.streamMutex.Lock() - 
stream, ok := p.streams[streamID] - if !ok { - stream, err = p.connection.NegotiateConnection(ctx) - if err != nil { - p.streamMutex.Unlock() - return xerrors.Errorf("negotiate connection: %w", err) - } - p.streams[streamID] = stream - go func() { - defer stream.Close() - - err := p.onServerToClientMessage(streamID, stream) - if err != nil { - p.logger.Debug(ctx, "failed to accept server message", slog.Error(err)) - } - }() - go func() { - <-stream.Context().Done() - p.streamMutex.Lock() - delete(p.streams, streamID) - p.streamMutex.Unlock() - }() - } - p.streamMutex.Unlock() - - var msg proto.Exchange - err = protobuf.Unmarshal(message[streamIDLength:], &msg) - if err != nil { - return xerrors.Errorf("unmarshal message: %w", err) - } - err = stream.Send(&msg) - if err != nil { - return xerrors.Errorf("write message: %w", err) - } - return nil -} - -func (p *proxyDial) onServerToClientMessage(streamID string, stream proto.DRPCPeerBroker_NegotiateConnectionClient) error { - for { - serverToClientMessage, err := stream.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - break - } - if errors.Is(err, context.Canceled) { - break - } - return xerrors.Errorf("recv: %w", err) - } - data, err := protobuf.Marshal(serverToClientMessage) - if err != nil { - return xerrors.Errorf("marshal: %w", err) - } - if len(data) > maxPayloadSizeBytes { - return xerrors.Errorf("maximum payload size %d exceeded", maxPayloadSizeBytes) - } - data = append([]byte(streamID), data...) - err = p.pubsub.Publish(proxyInID(p.channelID), marshal(data)) - if err != nil { - return xerrors.Errorf("publish: %w", err) - } - } - return nil -} - -func (p *proxyDial) Close() error { - p.streamMutex.Lock() - defer p.streamMutex.Unlock() - p.closeSubscribe() - return nil -} - -// base64 needs to be used here to keep the pubsub messages in UTF-8 range. -// PostgreSQL cannot handle non UTF-8 messages over pubsub. 
-func marshal(data []byte) []byte { - return []byte(base64.StdEncoding.EncodeToString(data)) -} - -func unmarshal(data []byte) ([]byte, error) { - return base64.StdEncoding.DecodeString(string(data)) -} - -func proxyOutID(channelID string) string { - return fmt.Sprintf("%s-out", channelID) -} - -func proxyInID(channelID string) string { - return fmt.Sprintf("%s-in", channelID) -} diff --git a/peerbroker/proxy_test.go b/peerbroker/proxy_test.go deleted file mode 100644 index 80fe405c24fcf..0000000000000 --- a/peerbroker/proxy_test.go +++ /dev/null @@ -1,84 +0,0 @@ -package peerbroker_test - -import ( - "context" - "sync" - "testing" - - "github.com/pion/webrtc/v3" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "cdr.dev/slog" - "cdr.dev/slog/sloggers/slogtest" - "github.com/coder/coder/coderd/database" - "github.com/coder/coder/peer" - "github.com/coder/coder/peerbroker" - "github.com/coder/coder/peerbroker/proto" - "github.com/coder/coder/provisionersdk" -) - -func TestProxy(t *testing.T) { - t.Parallel() - ctx := context.Background() - channelID := "hello" - pubsub := database.NewPubsubInMemory() - dialerClient, dialerServer := provisionersdk.TransportPipe() - defer dialerClient.Close() - defer dialerServer.Close() - listenerClient, listenerServer := provisionersdk.TransportPipe() - defer listenerClient.Close() - defer listenerServer.Close() - - listener, err := peerbroker.Listen(listenerServer, func(ctx context.Context) ([]webrtc.ICEServer, *peer.ConnOptions, error) { - return nil, &peer.ConnOptions{ - Logger: slogtest.Make(t, nil).Named("server").Leveled(slog.LevelDebug), - }, nil - }) - require.NoError(t, err) - - proxyCloser, err := peerbroker.ProxyDial(proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(listenerClient)), peerbroker.ProxyOptions{ - ChannelID: channelID, - Logger: slogtest.Make(t, nil).Named("proxy-listen").Leveled(slog.LevelDebug), - Pubsub: pubsub, - }) - require.NoError(t, err) - defer func() { - _ = 
proxyCloser.Close() - }() - - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - err = peerbroker.ProxyListen(ctx, dialerServer, peerbroker.ProxyOptions{ - ChannelID: channelID, - Logger: slogtest.Make(t, nil).Named("proxy-dial").Leveled(slog.LevelDebug), - Pubsub: pubsub, - }) - assert.NoError(t, err) - }() - - api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(dialerClient)) - stream, err := api.NegotiateConnection(ctx) - require.NoError(t, err) - clientConn, err := peerbroker.Dial(stream, []webrtc.ICEServer{{ - URLs: []string{"stun:stun.l.google.com:19302"}, - }}, &peer.ConnOptions{ - Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug), - }) - require.NoError(t, err) - defer clientConn.Close() - - serverConn, err := listener.Accept() - require.NoError(t, err) - defer serverConn.Close() - _, err = serverConn.Ping() - require.NoError(t, err) - - _, err = clientConn.Ping() - require.NoError(t, err) - - _ = dialerServer.Close() - wg.Wait() -}