diff --git a/.vscode/settings.json b/.vscode/settings.json index 81965c42613bd..68d823b8cff81 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -13,6 +13,9 @@ "codersdk", "cronstrue", "databasefake", + "DERP", + "derphttp", + "derpmap", "devel", "drpc", "drpcconn", @@ -25,6 +28,7 @@ "goarch", "gographviz", "goleak", + "gonet", "gossh", "gsyslog", "hashicorp", @@ -34,20 +38,30 @@ "idtoken", "Iflag", "incpatch", + "ipnstate", "isatty", "Jobf", "Keygen", "kirsle", "Kubernetes", "ldflags", + "magicsock", "manifoldco", "mapstructure", "mattn", "mitchellh", "moby", "namesgenerator", + "namespacing", + "netaddr", + "netip", + "netmap", + "netns", + "netstack", + "nettype", "nfpms", "nhooyr", + "nmcfg", "nolint", "nosec", "ntqry", @@ -63,14 +77,23 @@ "provisionersdk", "ptty", "ptytest", + "reconfig", "retrier", "rpty", "sdkproto", "sdktrace", "Signup", + "slogtest", "sourcemapped", "Srcs", "stretchr", + "stuntest", + "tailbroker", + "tailcfg", + "tailexchange", + "tailnet", + "tailnettest", + "Tailscale", "TCGETS", "tcpip", "TCSETS", @@ -84,19 +107,30 @@ "tfstate", "tparallel", "trimprefix", + "tsdial", + "tslogger", + "tstun", "turnconn", "typegen", "unconvert", "Untar", + "Userspace", "VMID", "weblinks", "webrtc", + "wgcfg", + "wgconfig", + "wgengine", + "wgmonitor", + "wgnet", "workspaceagent", + "workspaceagents", "workspaceapp", "workspaceapps", "workspacebuilds", "workspacename", "wsconncache", + "wsjson", "xerrors", "xstate", "yamux" diff --git a/agent/agent.go b/agent/agent.go index 29dae43162e98..da19c4ac61de7 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -4,11 +4,13 @@ import ( "context" "crypto/rand" "crypto/rsa" + "encoding/binary" "encoding/json" "errors" "fmt" "io" "net" + "net/netip" "net/url" "os" "os/exec" @@ -27,15 +29,14 @@ import ( "go.uber.org/atomic" gossh "golang.org/x/crypto/ssh" "golang.org/x/xerrors" - "inet.af/netaddr" - "tailscale.com/types/key" + "tailscale.com/tailcfg" "cdr.dev/slog" "github.com/coder/coder/agent/usershell" "github.com/coder/coder/peer" - "github.com/coder/coder/peer/peerwg" "github.com/coder/coder/peerbroker" "github.com/coder/coder/pty" + "github.com/coder/coder/tailnet" "github.com/coder/retry" ) @@ -50,57 +51,63 @@ const ( MagicSessionErrorCode = 229 ) +var ( + // tailnetIP is a static IPv6 address with the Tailscale prefix that is used to route + // connections from clients to this node. A dynamic address is not required because a Tailnet + // client only dials a single agent at a time. + tailnetIP = netip.MustParseAddr("fd7a:115c:a1e0:49d6:b259:b7ac:b1b2:48f4") + tailnetSSHPort = 1 + tailnetReconnectingPTYPort = 2 +) + type Options struct { - EnableWireguard bool - UploadWireguardKeys UploadWireguardKeys - ListenWireguardPeers ListenWireguardPeers + CoordinatorDialer CoordinatorDialer + WebRTCDialer WebRTCDialer + FetchMetadata FetchMetadata + ReconnectingPTYTimeout time.Duration EnvironmentVariables map[string]string Logger slog.Logger } type Metadata struct { - WireguardAddresses []netaddr.IPPrefix `json:"addresses"` - EnvironmentVariables map[string]string `json:"environment_variables"` - StartupScript string `json:"startup_script"` - Directory string `json:"directory"` + DERPMap *tailcfg.DERPMap `json:"derpmap"` + EnvironmentVariables map[string]string `json:"environment_variables"` + StartupScript string `json:"startup_script"` + Directory string `json:"directory"` } -type WireguardPublicKeys struct { - Public key.NodePublic `json:"public"` - Disco key.DiscoPublic `json:"disco"` -} +type WebRTCDialer func(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error) -type Dialer func(ctx context.Context, logger slog.Logger) (Metadata, *peerbroker.Listener, error) -type UploadWireguardKeys func(ctx context.Context, keys WireguardPublicKeys) error -type ListenWireguardPeers func(ctx context.Context, logger slog.Logger) (<-chan peerwg.Handshake, func(), error) +// CoordinatorDialer is a function that constructs a new broker. +// A dialer must be passed in to allow for reconnects. +type CoordinatorDialer func(ctx context.Context) (net.Conn, error) -func New(dialer Dialer, options *Options) io.Closer { - if options == nil { - options = &Options{} - } +// FetchMetadata is a function to obtain metadata for the agent. +type FetchMetadata func(ctx context.Context) (Metadata, error) + +func New(options Options) io.Closer { if options.ReconnectingPTYTimeout == 0 { options.ReconnectingPTYTimeout = 5 * time.Minute } ctx, cancelFunc := context.WithCancel(context.Background()) server := &agent{ - dialer: dialer, + webrtcDialer: options.WebRTCDialer, reconnectingPTYTimeout: options.ReconnectingPTYTimeout, logger: options.Logger, closeCancel: cancelFunc, closed: make(chan struct{}), envVars: options.EnvironmentVariables, - enableWireguard: options.EnableWireguard, - postKeys: options.UploadWireguardKeys, - listenWireguardPeers: options.ListenWireguardPeers, + coordinatorDialer: options.CoordinatorDialer, + fetchMetadata: options.FetchMetadata, } server.init(ctx) return server } type agent struct { - dialer Dialer - logger slog.Logger + webrtcDialer WebRTCDialer + logger slog.Logger reconnectingPTYs sync.Map reconnectingPTYTimeout time.Duration @@ -113,24 +120,21 @@ type agent struct { envVars map[string]string // metadata is atomic because values can change after reconnection. metadata atomic.Value - startupScript atomic.Bool + fetchMetadata FetchMetadata sshServer *ssh.Server - enableWireguard bool - network *peerwg.Network - postKeys UploadWireguardKeys - listenWireguardPeers ListenWireguardPeers + network *tailnet.Conn + coordinatorDialer CoordinatorDialer } func (a *agent) run(ctx context.Context) { var metadata Metadata - var peerListener *peerbroker.Listener var err error // An exponential back-off occurs when the connection is failing to dial. // This is to prevent server spam in case of a coderd outage. for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); { a.logger.Info(ctx, "connecting") - metadata, peerListener, err = a.dialer(ctx, a.logger) + metadata, err = a.fetchMetadata(ctx) if err != nil { if errors.Is(err, context.Canceled) { return @@ -141,7 +145,7 @@ func (a *agent) run(ctx context.Context) { a.logger.Warn(context.Background(), "failed to dial", slog.Error(err)) continue } - a.logger.Info(context.Background(), "connected") + a.logger.Info(context.Background(), "fetched metadata") break } select { @@ -151,24 +155,164 @@ func (a *agent) run(ctx context.Context) { } a.metadata.Store(metadata) - if a.startupScript.CAS(false, true) { - // The startup script has not ran yet! - go func() { - err := a.runStartupScript(ctx, metadata.StartupScript) - if errors.Is(err, context.Canceled) { + // The startup script has not ran yet! + go func() { + err := a.runStartupScript(ctx, metadata.StartupScript) + if errors.Is(err, context.Canceled) { + return + } + if err != nil { + a.logger.Warn(ctx, "agent script failed", slog.Error(err)) + } + }() + + if a.webrtcDialer != nil { + go a.runWebRTCNetworking(ctx) + } + if metadata.DERPMap != nil { + go a.runTailnet(ctx, metadata.DERPMap) + } +} + +func (a *agent) runTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) { + a.closeMutex.Lock() + defer a.closeMutex.Unlock() + if a.isClosed() { + return + } + if a.network != nil { + a.network.SetDERPMap(derpMap) + return + } + var err error + a.network, err = tailnet.NewConn(&tailnet.Options{ + Addresses: []netip.Prefix{netip.PrefixFrom(tailnetIP, 128)}, + DERPMap: derpMap, + Logger: a.logger.Named("tailnet"), + }) + if err != nil { + a.logger.Critical(ctx, "create tailnet", slog.Error(err)) + return + } + go a.runCoordinator(ctx) + + sshListener, err := a.network.Listen("tcp", ":"+strconv.Itoa(tailnetSSHPort)) + if err != nil { + a.logger.Critical(ctx, "listen for ssh", slog.Error(err)) + return + } + go func() { + for { + conn, err := sshListener.Accept() + if err != nil { return } + go a.sshServer.HandleConn(conn) + } + }() + reconnectingPTYListener, err := a.network.Listen("tcp", ":"+strconv.Itoa(tailnetReconnectingPTYPort)) + if err != nil { + a.logger.Critical(ctx, "listen for reconnecting pty", slog.Error(err)) + return + } + go func() { + for { + conn, err := reconnectingPTYListener.Accept() if err != nil { - a.logger.Warn(ctx, "agent script failed", slog.Error(err)) + return } - }() + // This cannot use a JSON decoder, since that can + // buffer additional data that is required for the PTY. + rawLen := make([]byte, 2) + _, err = conn.Read(rawLen) + if err != nil { + continue + } + length := binary.LittleEndian.Uint16(rawLen) + data := make([]byte, length) + _, err = conn.Read(data) + if err != nil { + continue + } + var msg reconnectingPTYInit + err = json.Unmarshal(data, &msg) + if err != nil { + continue + } + go a.handleReconnectingPTY(ctx, msg, conn) + } + }() +} + +// runCoordinator listens for nodes and updates the self-node as it changes. +func (a *agent) runCoordinator(ctx context.Context) { + var coordinator net.Conn + var err error + // An exponential back-off occurs when the connection is failing to dial. + // This is to prevent server spam in case of a coderd outage. + for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); { + coordinator, err = a.coordinatorDialer(ctx) + if err != nil { + if errors.Is(err, context.Canceled) { + return + } + if a.isClosed() { + return + } + a.logger.Warn(context.Background(), "failed to dial", slog.Error(err)) + continue + } + a.logger.Info(context.Background(), "connected to coordination server") + break + } + select { + case <-ctx.Done(): + return + default: + } + defer coordinator.Close() + sendNodes, errChan := tailnet.ServeCoordinator(coordinator, a.network.UpdateNodes) + a.network.SetNodeCallback(sendNodes) + select { + case <-ctx.Done(): + return + case err := <-errChan: + if a.isClosed() { + return + } + if errors.Is(err, context.Canceled) { + return + } + a.logger.Debug(ctx, "node broker accept exited; restarting connection", slog.Error(err)) + a.runCoordinator(ctx) + return } +} - if a.enableWireguard { - err = a.startWireguard(ctx, metadata.WireguardAddresses) +func (a *agent) runWebRTCNetworking(ctx context.Context) { + var peerListener *peerbroker.Listener + var err error + // An exponential back-off occurs when the connection is failing to dial. + // This is to prevent server spam in case of a coderd outage. + for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); { + peerListener, err = a.webrtcDialer(ctx, a.logger) if err != nil { - a.logger.Error(ctx, "start wireguard", slog.Error(err)) + if errors.Is(err, context.Canceled) { + return + } + if a.isClosed() { + return + } + a.logger.Warn(context.Background(), "failed to dial", slog.Error(err)) + continue } + a.logger.Info(context.Background(), "connected to webrtc broker") + break + } + select { + case <-ctx.Done(): + return + default: } for { @@ -178,7 +322,7 @@ func (a *agent) run(ctx context.Context) { return } a.logger.Debug(ctx, "peer listener accept exited; restarting connection", slog.Error(err)) - a.run(ctx) + a.runWebRTCNetworking(ctx) return } a.closeMutex.Lock() @@ -243,7 +387,38 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) { case ProtocolSSH: go a.sshServer.HandleConn(channel.NetConn()) case ProtocolReconnectingPTY: - go a.handleReconnectingPTY(ctx, channel.Label(), channel.NetConn()) + rawID := channel.Label() + // The ID format is referenced in conn.go. + // :: + idParts := strings.SplitN(rawID, ":", 4) + if len(idParts) != 4 { + a.logger.Warn(ctx, "client sent invalid id format", slog.F("raw-id", rawID)) + continue + } + id := idParts[0] + // Enforce a consistent format for IDs. + _, err := uuid.Parse(id) + if err != nil { + a.logger.Warn(ctx, "client sent reconnection token that isn't a uuid", slog.F("id", id), slog.Error(err)) + continue + } + // Parse the initial terminal dimensions. + height, err := strconv.Atoi(idParts[1]) + if err != nil { + a.logger.Warn(ctx, "client sent invalid height", slog.F("id", id), slog.F("height", idParts[1])) + continue + } + width, err := strconv.Atoi(idParts[2]) + if err != nil { + a.logger.Warn(ctx, "client sent invalid width", slog.F("id", id), slog.F("width", idParts[2])) + continue + } + go a.handleReconnectingPTY(ctx, reconnectingPTYInit{ + ID: id, + Height: uint16(height), + Width: uint16(width), + Command: idParts[3], + }, channel.NetConn()) case ProtocolDial: go a.handleDial(ctx, channel.Label(), channel.NetConn()) default: @@ -514,45 +689,19 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) { return cmd.Wait() } -func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn net.Conn) { +func (a *agent) handleReconnectingPTY(ctx context.Context, msg reconnectingPTYInit, conn net.Conn) { defer conn.Close() - // The ID format is referenced in conn.go. - // :: - idParts := strings.SplitN(rawID, ":", 4) - if len(idParts) != 4 { - a.logger.Warn(ctx, "client sent invalid id format", slog.F("raw-id", rawID)) - return - } - id := idParts[0] - // Enforce a consistent format for IDs. - _, err := uuid.Parse(id) - if err != nil { - a.logger.Warn(ctx, "client sent reconnection token that isn't a uuid", slog.F("id", id), slog.Error(err)) - return - } - // Parse the initial terminal dimensions. - height, err := strconv.Atoi(idParts[1]) - if err != nil { - a.logger.Warn(ctx, "client sent invalid height", slog.F("id", id), slog.F("height", idParts[1])) - return - } - width, err := strconv.Atoi(idParts[2]) - if err != nil { - a.logger.Warn(ctx, "client sent invalid width", slog.F("id", id), slog.F("width", idParts[2])) - return - } - var rpty *reconnectingPTY - rawRPTY, ok := a.reconnectingPTYs.Load(id) + rawRPTY, ok := a.reconnectingPTYs.Load(msg.ID) if ok { rpty, ok = rawRPTY.(*reconnectingPTY) if !ok { - a.logger.Warn(ctx, "found invalid type in reconnecting pty map", slog.F("id", id)) + a.logger.Warn(ctx, "found invalid type in reconnecting pty map", slog.F("id", msg.ID)) } } else { // Empty command will default to the users shell! - cmd, err := a.createCommand(ctx, idParts[3], nil) + cmd, err := a.createCommand(ctx, msg.Command, nil) if err != nil { a.logger.Warn(ctx, "create reconnecting pty command", slog.Error(err)) return @@ -561,7 +710,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn ne ptty, process, err := pty.Start(cmd) if err != nil { - a.logger.Warn(ctx, "start reconnecting pty command", slog.F("id", id)) + a.logger.Warn(ctx, "start reconnecting pty command", slog.F("id", msg.ID)) } // Default to buffer 64KiB. @@ -582,7 +731,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn ne timeout: time.AfterFunc(a.reconnectingPTYTimeout, cancelFunc), circularBuffer: circularBuffer, } - a.reconnectingPTYs.Store(id, rpty) + a.reconnectingPTYs.Store(msg.ID, rpty) go func() { // CommandContext isn't respected for Windows PTYs right now, // so we need to manually track the lifecycle. @@ -611,7 +760,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn ne _, err = rpty.circularBuffer.Write(part) rpty.circularBufferMutex.Unlock() if err != nil { - a.logger.Error(ctx, "reconnecting pty write buffer", slog.Error(err), slog.F("id", id)) + a.logger.Error(ctx, "reconnecting pty write buffer", slog.Error(err), slog.F("id", msg.ID)) break } rpty.activeConnsMutex.Lock() @@ -625,22 +774,22 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn ne // ID from memory. _ = process.Kill() rpty.Close() - a.reconnectingPTYs.Delete(id) + a.reconnectingPTYs.Delete(msg.ID) a.connCloseWait.Done() }() } // Resize the PTY to initial height + width. - err = rpty.ptty.Resize(uint16(height), uint16(width)) + err := rpty.ptty.Resize(msg.Height, msg.Width) if err != nil { // We can continue after this, it's not fatal! - a.logger.Error(ctx, "resize reconnecting pty", slog.F("id", id), slog.Error(err)) + a.logger.Error(ctx, "resize reconnecting pty", slog.F("id", msg.ID), slog.Error(err)) } // Write any previously stored data for the TTY. rpty.circularBufferMutex.RLock() _, err = conn.Write(rpty.circularBuffer.Bytes()) rpty.circularBufferMutex.RUnlock() if err != nil { - a.logger.Warn(ctx, "write reconnecting pty buffer", slog.F("id", id), slog.Error(err)) + a.logger.Warn(ctx, "write reconnecting pty buffer", slog.F("id", msg.ID), slog.Error(err)) return } connectionID := uuid.NewString() @@ -686,12 +835,12 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn ne return } if err != nil { - a.logger.Warn(ctx, "reconnecting pty buffer read error", slog.F("id", id), slog.Error(err)) + a.logger.Warn(ctx, "reconnecting pty buffer read error", slog.F("id", msg.ID), slog.Error(err)) return } _, err = rpty.ptty.Input().Write([]byte(req.Data)) if err != nil { - a.logger.Warn(ctx, "write to reconnecting pty", slog.F("id", id), slog.Error(err)) + a.logger.Warn(ctx, "write to reconnecting pty", slog.F("id", msg.ID), slog.Error(err)) return } // Check if a resize needs to happen! @@ -701,7 +850,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn ne err = rpty.ptty.Resize(req.Height, req.Width) if err != nil { // We can continue after this, it's not fatal! - a.logger.Error(ctx, "resize reconnecting pty", slog.F("id", id), slog.Error(err)) + a.logger.Error(ctx, "resize reconnecting pty", slog.F("id", msg.ID), slog.Error(err)) } } } @@ -788,6 +937,9 @@ func (a *agent) Close() error { } close(a.closed) a.closeCancel() + if a.network != nil { + _ = a.network.Close() + } _ = a.sshServer.Close() a.connCloseWait.Wait() return nil diff --git a/agent/agent_test.go b/agent/agent_test.go index b30db6ae25e43..fa671cd01723b 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -7,12 +7,14 @@ import ( "fmt" "io" "net" + "net/netip" "os" "os/exec" "path/filepath" "runtime" "strconv" "strings" + "sync" "testing" "time" @@ -38,6 +40,8 @@ import ( "github.com/coder/coder/peerbroker/proto" "github.com/coder/coder/provisionersdk" "github.com/coder/coder/pty/ptytest" + "github.com/coder/coder/tailnet" + "github.com/coder/coder/tailnet/tailnettest" "github.com/coder/coder/testutil" ) @@ -314,7 +318,9 @@ func TestAgent(t *testing.T) { t.Skip("ConPTY appears to be inconsistent on Windows.") } - conn := setupAgent(t, agent.Metadata{}, 0) + conn := setupAgent(t, agent.Metadata{ + DERPMap: tailnettest.RunDERPAndSTUN(t), + }, 0) id := uuid.NewString() netConn, err := conn.ReconnectingPTY(id, 100, 100, "/bin/bash") require.NoError(t, err) @@ -463,12 +469,26 @@ func TestAgent(t *testing.T) { require.ErrorContains(t, err, "no such file") require.Nil(t, netConn) }) + + t.Run("Tailnet", func(t *testing.T) { + t.Parallel() + derpMap := tailnettest.RunDERPAndSTUN(t) + conn := setupAgent(t, agent.Metadata{ + DERPMap: derpMap, + }, 0) + defer conn.Close() + require.Eventually(t, func() bool { + _, err := conn.Ping() + return err == nil + }, testutil.WaitMedium, testutil.IntervalFast) + }) } func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exec.Cmd { agentConn := setupAgent(t, agent.Metadata{}, 0) listener, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) + waitGroup := sync.WaitGroup{} go func() { defer listener.Close() for { @@ -481,11 +501,16 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe _ = conn.Close() return } - go agent.Bicopy(context.Background(), conn, ssh) + waitGroup.Add(1) + go func() { + agent.Bicopy(context.Background(), conn, ssh) + waitGroup.Done() + }() } }() t.Cleanup(func() { _ = listener.Close() + waitGroup.Wait() }) tcpAddr, valid := listener.Addr().(*net.TCPAddr) require.True(t, valid) @@ -500,17 +525,36 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe func setupSSHSession(t *testing.T, options agent.Metadata) *ssh.Session { sshClient, err := setupAgent(t, options, 0).SSHClient() require.NoError(t, err) + t.Cleanup(func() { + _ = sshClient.Close() + }) session, err := sshClient.NewSession() require.NoError(t, err) return session } -func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) *agent.Conn { +func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) agent.Conn { client, server := provisionersdk.TransportPipe() - closer := agent.New(func(ctx context.Context, logger slog.Logger) (agent.Metadata, *peerbroker.Listener, error) { - listener, err := peerbroker.Listen(server, nil) - return metadata, listener, err - }, &agent.Options{ + tailscale := metadata.DERPMap != nil + coordinator := tailnet.NewCoordinator() + agentID := uuid.New() + closer := agent.New(agent.Options{ + FetchMetadata: func(ctx context.Context) (agent.Metadata, error) { + return metadata, nil + }, + WebRTCDialer: func(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error) { + listener, err := peerbroker.Listen(server, nil) + return listener, err + }, + CoordinatorDialer: func(ctx context.Context) (net.Conn, error) { + clientConn, serverConn := net.Pipe() + t.Cleanup(func() { + _ = serverConn.Close() + _ = clientConn.Close() + }) + go coordinator.ServeAgent(serverConn, agentID) + return clientConn, nil + }, Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug), ReconnectingPTYTimeout: ptyTimeout, }) @@ -522,6 +566,28 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client)) stream, err := api.NegotiateConnection(context.Background()) assert.NoError(t, err) + if tailscale { + conn, err := tailnet.NewConn(&tailnet.Options{ + Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, + DERPMap: metadata.DERPMap, + Logger: slogtest.Make(t, nil).Named("tailnet"), + }) + require.NoError(t, err) + clientConn, serverConn := net.Pipe() + t.Cleanup(func() { + _ = clientConn.Close() + _ = serverConn.Close() + _ = conn.Close() + }) + go coordinator.ServeClient(serverConn, uuid.New(), agentID) + sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error { + return conn.UpdateNodes(node) + }) + conn.SetNodeCallback(sendNode) + return &agent.TailnetConn{ + Conn: conn, + } + } conn, err := peerbroker.Dial(stream, []webrtc.ICEServer{}, &peer.ConnOptions{ Logger: slogtest.Make(t, nil), }) @@ -530,7 +596,7 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) _ = conn.Close() }) - return &agent.Conn{ + return &agent.WebRTCConn{ Negotiator: api, Conn: conn, } diff --git a/agent/conn.go b/agent/conn.go index 0be45bc05c33e..0e95e97e21254 100644 --- a/agent/conn.go +++ b/agent/conn.go @@ -2,17 +2,25 @@ package agent import ( "context" + "encoding/binary" "encoding/json" "fmt" + "io" "net" + "net/netip" "net/url" + "strconv" "strings" + "time" "golang.org/x/crypto/ssh" "golang.org/x/xerrors" + "tailscale.com/ipn/ipnstate" + "tailscale.com/tailcfg" "github.com/coder/coder/peer" "github.com/coder/coder/peerbroker/proto" + "github.com/coder/coder/tailnet" ) // ReconnectingPTYRequest is sent from the client to the server @@ -23,9 +31,21 @@ type ReconnectingPTYRequest struct { Width uint16 `json:"width"` } +// Conn is a temporary interface while we switch from WebRTC to Wireguard networking. +type Conn interface { + io.Closer + Closed() <-chan struct{} + Ping() (time.Duration, error) + CloseWithError(err error) error + ReconnectingPTY(id string, height, width uint16, command string) (net.Conn, error) + SSH() (net.Conn, error) + SSHClient() (*ssh.Client, error) + DialContext(ctx context.Context, network string, addr string) (net.Conn, error) +} + // Conn wraps a peer connection with helper functions to // communicate with the agent. -type Conn struct { +type WebRTCConn struct { // Negotiator is responsible for exchanging messages. Negotiator proto.DRPCPeerBrokerClient @@ -36,7 +56,7 @@ type Conn struct { // be reconnected to via ID. // // The command is optional and defaults to start a shell. -func (c *Conn) ReconnectingPTY(id string, height, width uint16, command string) (net.Conn, error) { +func (c *WebRTCConn) ReconnectingPTY(id string, height, width uint16, command string) (net.Conn, error) { channel, err := c.CreateChannel(context.Background(), fmt.Sprintf("%s:%d:%d:%s", id, height, width, command), &peer.ChannelOptions{ Protocol: ProtocolReconnectingPTY, }) @@ -47,7 +67,7 @@ func (c *Conn) ReconnectingPTY(id string, height, width uint16, command string) } // SSH dials the built-in SSH server. -func (c *Conn) SSH() (net.Conn, error) { +func (c *WebRTCConn) SSH() (net.Conn, error) { channel, err := c.CreateChannel(context.Background(), "ssh", &peer.ChannelOptions{ Protocol: ProtocolSSH, }) @@ -59,7 +79,7 @@ func (c *Conn) SSH() (net.Conn, error) { // SSHClient calls SSH to create a client that uses a weak cipher // for high throughput. -func (c *Conn) SSHClient() (*ssh.Client, error) { +func (c *WebRTCConn) SSHClient() (*ssh.Client, error) { netConn, err := c.SSH() if err != nil { return nil, xerrors.Errorf("ssh: %w", err) @@ -78,7 +98,7 @@ func (c *Conn) SSHClient() (*ssh.Client, error) { // DialContext dials an arbitrary protocol+address from inside the workspace and // proxies it through the provided net.Conn. -func (c *Conn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { +func (c *WebRTCConn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { u := &url.URL{ Scheme: network, } @@ -112,7 +132,107 @@ func (c *Conn) DialContext(ctx context.Context, network string, addr string) (ne return channel.NetConn(), nil } -func (c *Conn) Close() error { +func (c *WebRTCConn) Close() error { _ = c.Negotiator.DRPCConn().Close() return c.Conn.Close() } + +type TailnetConn struct { + *tailnet.Conn + CloseFunc func() +} + +func (c *TailnetConn) Ping() (time.Duration, error) { + errCh := make(chan error, 1) + durCh := make(chan time.Duration, 1) + c.Conn.Ping(tailnetIP, tailcfg.PingICMP, func(pr *ipnstate.PingResult) { + if pr.Err != "" { + errCh <- xerrors.New(pr.Err) + return + } + durCh <- time.Duration(pr.LatencySeconds * float64(time.Second)) + }) + select { + case err := <-errCh: + return 0, err + case dur := <-durCh: + return dur, nil + } +} + +func (c *TailnetConn) CloseWithError(_ error) error { + return c.Close() +} + +func (c *TailnetConn) Close() error { + if c.CloseFunc != nil { + c.CloseFunc() + } + return c.Conn.Close() +} + +type reconnectingPTYInit struct { + ID string + Height uint16 + Width uint16 + Command string +} + +func (c *TailnetConn) ReconnectingPTY(id string, height, width uint16, command string) (net.Conn, error) { + conn, err := c.DialContextTCP(context.Background(), netip.AddrPortFrom(tailnetIP, uint16(tailnetReconnectingPTYPort))) + if err != nil { + return nil, err + } + data, err := json.Marshal(reconnectingPTYInit{ + ID: id, + Height: height, + Width: width, + Command: command, + }) + if err != nil { + _ = conn.Close() + return nil, err + } + data = append(make([]byte, 2), data...) + binary.LittleEndian.PutUint16(data, uint16(len(data)-2)) + + _, err = conn.Write(data) + if err != nil { + _ = conn.Close() + return nil, err + } + return conn, nil +} + +func (c *TailnetConn) SSH() (net.Conn, error) { + return c.DialContextTCP(context.Background(), netip.AddrPortFrom(tailnetIP, uint16(tailnetSSHPort))) +} + +// SSHClient calls SSH to create a client that uses a weak cipher +// for high throughput. +func (c *TailnetConn) SSHClient() (*ssh.Client, error) { + netConn, err := c.SSH() + if err != nil { + return nil, xerrors.Errorf("ssh: %w", err) + } + sshConn, channels, requests, err := ssh.NewClientConn(netConn, "localhost:22", &ssh.ClientConfig{ + // SSH host validation isn't helpful, because obtaining a peer + // connection already signifies user-intent to dial a workspace. + // #nosec + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + }) + if err != nil { + return nil, xerrors.Errorf("ssh conn: %w", err) + } + return ssh.NewClient(sshConn, channels, requests), nil +} + +func (c *TailnetConn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { + _, rawPort, _ := net.SplitHostPort(addr) + port, _ := strconv.Atoi(rawPort) + ipp := netip.AddrPortFrom(tailnetIP, uint16(port)) + if network == "udp" { + return c.Conn.DialContextUDP(ctx, ipp) + } + return c.Conn.DialContextTCP(ctx, ipp) +} diff --git a/agent/wireguard.go b/agent/wireguard.go deleted file mode 100644 index 603b5616e4740..0000000000000 --- a/agent/wireguard.go +++ /dev/null @@ -1,97 +0,0 @@ -package agent - -import ( - "context" - "net" - "strconv" - - "golang.org/x/xerrors" - "inet.af/netaddr" - - "cdr.dev/slog" - "github.com/coder/coder/peer/peerwg" -) - -func (a *agent) startWireguard(ctx context.Context, addrs []netaddr.IPPrefix) error { - if a.network != nil { - _ = a.network.Close() - a.network = nil - } - - // We can't create a wireguard network without these. - if len(addrs) == 0 || a.listenWireguardPeers == nil || a.postKeys == nil { - return xerrors.New("wireguard is enabled, but no addresses were provided or necessary functions were not provided") - } - - wg, err := peerwg.New(a.logger.Named("wireguard"), addrs) - if err != nil { - return xerrors.Errorf("create wireguard network: %w", err) - } - - // A new keypair is generated on each agent start. - // This keypair must be sent to Coder to allow for incoming connections. - err = a.postKeys(ctx, WireguardPublicKeys{ - Public: wg.NodePrivateKey.Public(), - Disco: wg.DiscoPublicKey, - }) - if err != nil { - a.logger.Warn(ctx, "post keys", slog.Error(err)) - } - - go func() { - for { - ch, listenClose, err := a.listenWireguardPeers(ctx, a.logger) - if err != nil { - a.logger.Warn(ctx, "listen wireguard peers", slog.Error(err)) - return - } - - for { - peer, ok := <-ch - if !ok { - break - } - - err := wg.AddPeer(peer) - a.logger.Info(ctx, "added wireguard peer", slog.F("peer", peer.NodePublicKey.ShortString()), slog.Error(err)) - } - - listenClose() - } - }() - - a.startWireguardListeners(ctx, wg, []handlerPort{ - {port: 12212, handler: a.sshServer.HandleConn}, - }) - - a.network = wg - return nil -} - -type handlerPort struct { - handler func(conn net.Conn) - port uint16 -} - -func (a *agent) startWireguardListeners(ctx context.Context, network *peerwg.Network, handlers []handlerPort) { - for _, h := range handlers { - go func(h handlerPort) { - a.logger.Debug(ctx, "starting wireguard listener", slog.F("port", h.port)) - - listener, err := network.Listen("tcp", net.JoinHostPort("", strconv.Itoa(int(h.port)))) - if err != nil { - a.logger.Warn(ctx, "listen wireguard", slog.F("port", h.port), slog.Error(err)) - return - } - - for { - conn, err := listener.Accept() - if err != nil { - return - } - - go h.handler(conn) - } - }(h) - } -} diff --git a/cli/agent.go b/cli/agent.go index 24571e9fbc497..eb6d2287af998 100644 --- a/cli/agent.go +++ b/cli/agent.go @@ -182,16 +182,16 @@ func workspaceAgent() *cobra.Command { logger.Error(cmd.Context(), "post agent version: %w", slog.Error(err), slog.F("version", version)) } - closer := agent.New(client.ListenWorkspaceAgent, &agent.Options{ - Logger: logger, + closer := agent.New(agent.Options{ + FetchMetadata: client.WorkspaceAgentMetadata, + WebRTCDialer: client.ListenWorkspaceAgent, + Logger: logger, EnvironmentVariables: map[string]string{ // Override the "CODER_AGENT_TOKEN" variable in all // shells so "gitssh" works! "CODER_AGENT_TOKEN": client.SessionToken, }, - EnableWireguard: wireguard, - UploadWireguardKeys: client.UploadWorkspaceAgentKeys, - ListenWireguardPeers: client.WireguardPeerListener, + CoordinatorDialer: client.ListenWorkspaceAgentTailnet, }) <-cmd.Context().Done() return closer.Close() diff --git a/cli/cliflag/cliflag.go b/cli/cliflag/cliflag.go index 843416c3ff3ea..42722ecc1cfb3 100644 --- a/cli/cliflag/cliflag.go +++ b/cli/cliflag/cliflag.go @@ -90,6 +90,23 @@ func Uint8VarP(flagset *pflag.FlagSet, ptr *uint8, name string, shorthand string flagset.Uint8VarP(ptr, name, shorthand, uint8(vi64), fmtUsage(usage, env)) } +// IntVarP sets a uint8 flag on the given flag set. +func IntVarP(flagset *pflag.FlagSet, ptr *int, name string, shorthand string, env string, def int, usage string) { + val, ok := os.LookupEnv(env) + if !ok || val == "" { + flagset.IntVarP(ptr, name, shorthand, def, fmtUsage(usage, env)) + return + } + + vi64, err := strconv.ParseUint(val, 10, 8) + if err != nil { + flagset.IntVarP(ptr, name, shorthand, def, fmtUsage(usage, env)) + return + } + + flagset.IntVarP(ptr, name, shorthand, int(vi64), fmtUsage(usage, env)) +} + func Bool(flagset *pflag.FlagSet, name, shorthand, env string, def bool, usage string) { val, ok := os.LookupEnv(env) if !ok || val == "" { diff --git a/cli/cliflag/cliflag_test.go b/cli/cliflag/cliflag_test.go index acdf7d6765fb5..5d826166307a5 100644 --- a/cli/cliflag/cliflag_test.go +++ b/cli/cliflag/cliflag_test.go @@ -108,7 +108,7 @@ func TestCliflag(t *testing.T) { require.Equal(t, []string{}, got) }) - t.Run("IntDefault", func(t *testing.T) { + t.Run("UInt8Default", func(t *testing.T) { var ptr uint8 flagset, name, shorthand, env, usage := randomFlag() def, _ := cryptorand.Int63n(10) @@ -121,7 +121,7 @@ func TestCliflag(t *testing.T) { require.Contains(t, flagset.FlagUsages(), fmt.Sprintf("Consumes $%s", env)) }) - t.Run("IntEnvVar", func(t *testing.T) { + t.Run("UInt8EnvVar", func(t *testing.T) { var ptr uint8 flagset, name, shorthand, env, usage := randomFlag() envValue, _ := cryptorand.Int63n(10) @@ -134,7 +134,7 @@ func TestCliflag(t *testing.T) { require.Equal(t, uint8(envValue), got) }) - t.Run("IntFailParse", func(t *testing.T) { + t.Run("UInt8FailParse", func(t *testing.T) { var ptr uint8 flagset, name, shorthand, env, usage := randomFlag() envValue, _ := cryptorand.String(10) @@ -147,6 +147,45 @@ func TestCliflag(t *testing.T) { require.Equal(t, uint8(def), got) }) + t.Run("IntDefault", func(t *testing.T) { + var ptr int + flagset, name, shorthand, env, usage := randomFlag() + def, _ := cryptorand.Int63n(10) + + cliflag.IntVarP(flagset, &ptr, name, shorthand, env, int(def), usage) + got, err := flagset.GetInt(name) + require.NoError(t, err) + require.Equal(t, int(def), got) + require.Contains(t, flagset.FlagUsages(), usage) + require.Contains(t, flagset.FlagUsages(), fmt.Sprintf("Consumes $%s", env)) + }) + + t.Run("IntEnvVar", func(t *testing.T) { + var ptr int + flagset, name, shorthand, env, usage := randomFlag() + envValue, _ := cryptorand.Int63n(10) + t.Setenv(env, strconv.FormatUint(uint64(envValue), 10)) + def, _ := cryptorand.Int() + + cliflag.IntVarP(flagset, &ptr, name, shorthand, env, def, usage) + got, err := flagset.GetInt(name) + require.NoError(t, err) + require.Equal(t, int(envValue), got) + }) + + t.Run("IntFailParse", func(t *testing.T) { + var ptr int + flagset, name, shorthand, env, usage := randomFlag() + envValue, _ := cryptorand.String(10) + t.Setenv(env, envValue) + def, _ := cryptorand.Int63n(10) + + cliflag.IntVarP(flagset, &ptr, name, shorthand, env, int(def), usage) + got, err := flagset.GetInt(name) + require.NoError(t, err) + require.Equal(t, int(def), got) + }) + t.Run("BoolDefault", func(t *testing.T) { var ptr bool flagset, name, shorthand, env, usage := randomFlag() diff --git a/cli/configssh_test.go b/cli/configssh_test.go index 29e59a7dfc8ec..318d71627ec13 100644 --- a/cli/configssh_test.go +++ b/cli/configssh_test.go @@ -104,8 +104,11 @@ func TestConfigSSH(t *testing.T) { coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) agentClient := codersdk.New(client.URL) agentClient.SessionToken = authToken - agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{ - Logger: slogtest.Make(t, nil), + agentCloser := agent.New(agent.Options{ + FetchMetadata: agentClient.WorkspaceAgentMetadata, + WebRTCDialer: agentClient.ListenWorkspaceAgent, + CoordinatorDialer: client.ListenWorkspaceAgentTailnet, + Logger: slogtest.Make(t, nil).Named("agent"), }) defer func() { _ = agentCloser.Close() diff --git a/cli/features.go b/cli/features.go index 1995153275eaf..f430534330816 100644 --- a/cli/features.go +++ b/cli/features.go @@ -13,8 +13,6 @@ import ( "github.com/coder/coder/codersdk" ) -var featureColumns = []string{"Name", "Entitlement", "Enabled", "Limit", "Actual"} - func features() *cobra.Command { cmd := &cobra.Command{ Short: "List features", @@ -29,8 +27,9 @@ func features() *cobra.Command { func featuresList() *cobra.Command { var ( - columns []string - outputFormat string + featureColumns = []string{"Name", "Entitlement", "Enabled", "Limit", "Actual"} + columns []string + outputFormat string ) cmd := &cobra.Command{ diff --git a/cli/portforward.go b/cli/portforward.go index b9db1ffaa4c22..66753ac92f7d1 100644 --- a/cli/portforward.go +++ b/cli/portforward.go @@ -163,7 +163,7 @@ func portForward() *cobra.Command { return cmd } -func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn *coderagent.Conn, wg *sync.WaitGroup, spec portForwardSpec) (net.Listener, error) { +func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn coderagent.Conn, wg *sync.WaitGroup, spec portForwardSpec) (net.Listener, error) { _, _ = fmt.Fprintf(cmd.OutOrStderr(), "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress) var ( diff --git a/cli/root.go b/cli/root.go index 6094ddfe24d3d..f90c3ad1ec5f1 100644 --- a/cli/root.go +++ b/cli/root.go @@ -86,7 +86,6 @@ func Core() []*cobra.Command { update(), users(), versionCmd(), - wireguardPortForward(), workspaceAgent(), features(), } diff --git a/cli/server.go b/cli/server.go index b66ce7f690d6c..f6f5534f3604c 100644 --- a/cli/server.go +++ b/cli/server.go @@ -41,6 +41,7 @@ import ( "golang.org/x/xerrors" "google.golang.org/api/idtoken" "google.golang.org/api/option" + "tailscale.com/tailcfg" "cdr.dev/slog" "cdr.dev/slog/sloggers/sloghuman" @@ -65,6 +66,7 @@ import ( "github.com/coder/coder/provisionerd" "github.com/coder/coder/provisionersdk" "github.com/coder/coder/provisionersdk/proto" + "github.com/coder/coder/tailnet" ) // nolint:gocyclo @@ -73,6 +75,12 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command { accessURL string address string autobuildPollInterval time.Duration + derpServerEnabled bool + derpServerRegionID int + derpServerRegionCode string + derpServerRegionName string + derpServerSTUNAddrs []string + derpConfigURL string promEnabled bool promAddress string pprofEnabled bool @@ -94,6 +102,7 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command { oidcEmailDomain string oidcIssuerURL string oidcScopes []string + tailscaleEnable bool telemetryEnable bool telemetryURL string tlsCertFile string @@ -245,6 +254,17 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command { if err != nil { return xerrors.Errorf("parse URL: %w", err) } + accessURLPortRaw := accessURLParsed.Port() + if accessURLPortRaw == "" { + accessURLPortRaw = "80" + if accessURLParsed.Scheme == "https" { + accessURLPortRaw = "443" + } + } + accessURLPort, err := strconv.Atoi(accessURLPortRaw) + if err != nil { + return xerrors.Errorf("parse access URL port: %w", err) + } // Warn the user if the access URL appears to be a loopback address. isLocal, err := isLocalURL(ctx, accessURLParsed) @@ -307,16 +327,35 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command { validatedAutoImportTemplates[i] = v } + derpMap, err := tailnet.NewDERPMap(ctx, &tailcfg.DERPRegion{ + RegionID: derpServerRegionID, + RegionCode: derpServerRegionCode, + RegionName: derpServerRegionName, + Nodes: []*tailcfg.DERPNode{{ + Name: fmt.Sprintf("%db", derpServerRegionID), + RegionID: derpServerRegionID, + HostName: accessURLParsed.Hostname(), + DERPPort: accessURLPort, + STUNPort: -1, + ForceHTTP: accessURLParsed.Scheme == "http", + }}, + }, derpServerSTUNAddrs, derpConfigURL) + if err != nil { + return xerrors.Errorf("create derp map: %w", err) + } + options := &coderd.Options{ AccessURL: accessURLParsed, ICEServers: iceServers, Logger: logger.Named("coderd"), Database: databasefake.New(), + DERPMap: derpMap, Pubsub: database.NewPubsubInMemory(), CacheDir: cacheDir, GoogleTokenValidator: googleTokenValidator, SecureAuthCookie: secureAuthCookie, SSHKeygenAlgorithm: sshKeygenAlgorithm, + TailscaleEnable: tailscaleEnable, TURNServer: turnServer, TracerProvider: tracerProvider, Telemetry: telemetry.NewNoop(), @@ -704,6 +743,24 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command { cliflag.DurationVarP(root.Flags(), &autobuildPollInterval, "autobuild-poll-interval", "", "CODER_AUTOBUILD_POLL_INTERVAL", time.Minute, "Specifies the interval at which to poll for and execute automated workspace build operations.") cliflag.StringVarP(root.Flags(), &accessURL, "access-url", "", "CODER_ACCESS_URL", "", "Specifies the external URL to access Coder.") cliflag.StringVarP(root.Flags(), &address, "address", "a", "CODER_ADDRESS", "127.0.0.1:3000", "The address to serve the API and dashboard.") + cliflag.StringVarP(root.Flags(), &derpConfigURL, "derp-config-url", "", "CODER_DERP_CONFIG_URL", "", + "Specifies a URL to periodically fetch a DERP map. See: https://tailscale.com/kb/1118/custom-derp-servers/") + cliflag.BoolVarP(root.Flags(), &derpServerEnabled, "derp-server-enable", "", "CODER_DERP_SERVER_ENABLE", true, "Specifies whether to enable or disable the embedded DERP server.") + cliflag.IntVarP(root.Flags(), &derpServerRegionID, "derp-server-region-id", "", "CODER_DERP_SERVER_REGION_ID", 999, "Specifies the region ID to use for the embedded DERP server.") + cliflag.StringVarP(root.Flags(), &derpServerRegionCode, "derp-server-region-code", "", "CODER_DERP_SERVER_REGION_CODE", "coder", "Specifies the region code that is displayed in the Coder UI for the embedded DERP server.") + cliflag.StringVarP(root.Flags(), &derpServerRegionName, "derp-server-region-name", "", "CODER_DERP_SERVER_REGION_NAME", "Coder Embedded DERP", "Specifies the region name that is displayed in the Coder UI for the embedded DERP server.") + cliflag.StringArrayVarP(root.Flags(), &derpServerSTUNAddrs, "derp-server-stun-addresses", "", "CODER_DERP_SERVER_STUN_ADDRESSES", []string{ + "stun.l.google.com:19302", + }, "Specify addresses for STUN servers to establish P2P connections. Set empty to disable P2P connections entirely.") + + // Mark hidden while this feature is in testing! + _ = root.Flags().MarkHidden("derp-config-url") + _ = root.Flags().MarkHidden("derp-server-enable") + _ = root.Flags().MarkHidden("derp-server-region-id") + _ = root.Flags().MarkHidden("derp-server-region-code") + _ = root.Flags().MarkHidden("derp-server-region-name") + _ = root.Flags().MarkHidden("derp-server-stun-addresses") + cliflag.BoolVarP(root.Flags(), &promEnabled, "prometheus-enable", "", "CODER_PROMETHEUS_ENABLE", false, "Enable serving prometheus metrics on the addressdefined by --prometheus-address.") cliflag.StringVarP(root.Flags(), &promAddress, "prometheus-address", "", "CODER_PROMETHEUS_ADDRESS", "127.0.0.1:2112", "The address to serve prometheus metrics.") cliflag.BoolVarP(root.Flags(), &pprofEnabled, "pprof-enable", "", "CODER_PPROF_ENABLE", false, "Enable serving pprof metrics on the address defined by --pprof-address.") @@ -743,6 +800,9 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command { "Specifies an issuer URL to use for OIDC.") cliflag.StringArrayVarP(root.Flags(), &oidcScopes, "oidc-scopes", "", "CODER_OIDC_SCOPES", []string{oidc.ScopeOpenID, "profile", "email"}, "Specifies scopes to grant when authenticating with OIDC.") + cliflag.BoolVarP(root.Flags(), &tailscaleEnable, "tailscale", "", "CODER_TAILSCALE", false, + "Specifies whether Tailscale networking is used for web applications and terminals.") + _ = root.Flags().MarkHidden("tailscale") enableTelemetryByDefault := !isTest() cliflag.BoolVarP(root.Flags(), &telemetryEnable, "telemetry", "", "CODER_TELEMETRY", enableTelemetryByDefault, "Specifies whether telemetry is enabled or not. Coder collects anonymized usage data to help improve our product.") cliflag.StringVarP(root.Flags(), &telemetryURL, "telemetry-url", "", "CODER_TELEMETRY_URL", "https://telemetry.coder.com", "Specifies a URL to send telemetry to.") diff --git a/cli/ssh.go b/cli/ssh.go index 6295e8e025431..7f23cce706c20 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -19,18 +19,16 @@ import ( gosshagent "golang.org/x/crypto/ssh/agent" "golang.org/x/term" "golang.org/x/xerrors" - "inet.af/netaddr" - tslogger "tailscale.com/types/logger" "cdr.dev/slog" - "cdr.dev/slog/sloggers/sloghuman" + + "github.com/coder/coder/agent" "github.com/coder/coder/cli/cliflag" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/coderd/autobuild/notify" "github.com/coder/coder/coderd/util/ptr" "github.com/coder/coder/codersdk" "github.com/coder/coder/cryptorand" - "github.com/coder/coder/peer/peerwg" ) var ( @@ -90,87 +88,35 @@ func ssh() *cobra.Command { return xerrors.Errorf("await agent: %w", err) } - var newSSHClient func() (*gossh.Client, error) - + var conn agent.Conn if !wireguard { - conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, nil) - if err != nil { - return err - } - defer conn.Close() - - stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace) - defer stopPolling() - - if stdio { - rawSSH, err := conn.SSH() - if err != nil { - return err - } - defer rawSSH.Close() - - go func() { - _, _ = io.Copy(cmd.OutOrStdout(), rawSSH) - }() - _, _ = io.Copy(rawSSH, cmd.InOrStdin()) - return nil - } - - newSSHClient = conn.SSHClient + conn, err = client.DialWorkspaceAgent(ctx, workspaceAgent.ID, nil) } else { - // TODO: more granual control of Tailscale logging. - peerwg.Logf = tslogger.Discard - - ipv6 := peerwg.UUIDToNetaddr(uuid.New()) - wgn, err := peerwg.New( - slog.Make(sloghuman.Sink(cmd.ErrOrStderr())), - []netaddr.IPPrefix{netaddr.IPPrefixFrom(ipv6, 128)}, - ) - if err != nil { - return xerrors.Errorf("create wireguard network: %w", err) - } - defer wgn.Close() - - err = client.PostWireguardPeer(ctx, workspace.ID, peerwg.Handshake{ - Recipient: workspaceAgent.ID, - NodePublicKey: wgn.NodePrivateKey.Public(), - DiscoPublicKey: wgn.DiscoPublicKey, - IPv6: ipv6, - }) - if err != nil { - return xerrors.Errorf("post wireguard peer: %w", err) - } - - err = wgn.AddPeer(peerwg.Handshake{ - Recipient: workspaceAgent.ID, - DiscoPublicKey: workspaceAgent.DiscoPublicKey, - NodePublicKey: workspaceAgent.WireguardPublicKey, - IPv6: workspaceAgent.IPv6.IP(), - }) - if err != nil { - return xerrors.Errorf("add workspace agent as peer: %w", err) - } + conn, err = client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, workspaceAgent.ID) + } + if err != nil { + return err + } + defer conn.Close() - if stdio { - rawSSH, err := wgn.SSH(ctx, workspaceAgent.IPv6.IP()) - if err != nil { - return err - } - defer rawSSH.Close() + stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace) + defer stopPolling() - go func() { - _, _ = io.Copy(cmd.OutOrStdout(), rawSSH) - }() - _, _ = io.Copy(rawSSH, cmd.InOrStdin()) - return nil + if stdio { + rawSSH, err := conn.SSH() + if err != nil { + return err } + defer rawSSH.Close() - newSSHClient = func() (*gossh.Client, error) { - return wgn.SSHClient(ctx, workspaceAgent.IPv6.IP()) - } + go func() { + _, _ = io.Copy(cmd.OutOrStdout(), rawSSH) + }() + _, _ = io.Copy(rawSSH, cmd.InOrStdin()) + return nil } - sshClient, err := newSSHClient() + sshClient, err := conn.SSHClient() if err != nil { return err } @@ -330,34 +276,34 @@ func getWorkspaceAndAgent(ctx context.Context, cmd *cobra.Command, client *coder if len(agents) == 0 { return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("workspace %q has no agents", workspace.Name) } - var agent codersdk.WorkspaceAgent + var workspaceAgent codersdk.WorkspaceAgent if len(workspaceParts) >= 2 { for _, otherAgent := range agents { if otherAgent.Name != workspaceParts[1] { continue } - agent = otherAgent + workspaceAgent = otherAgent break } - if agent.ID == uuid.Nil { + if workspaceAgent.ID == uuid.Nil { return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("agent not found by name %q", workspaceParts[1]) } } - if agent.ID == uuid.Nil { + if workspaceAgent.ID == uuid.Nil { if len(agents) > 1 { if !shuffle { return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.New("you must specify the name of an agent") } - agent, err = cryptorand.Element(agents) + workspaceAgent, err = cryptorand.Element(agents) if err != nil { return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err } } else { - agent = agents[0] + workspaceAgent = agents[0] } } - return workspace, agent, nil + return workspace, workspaceAgent, nil } // Attempt to poll workspace autostop. We write a per-workspace lockfile to diff --git a/cli/ssh_test.go b/cli/ssh_test.go index e9743180642b7..6f2965174ac77 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -19,7 +19,6 @@ import ( "golang.org/x/crypto/ssh" gosshagent "golang.org/x/crypto/ssh/agent" - "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/agent" @@ -89,8 +88,11 @@ func TestSSH(t *testing.T) { agentClient := codersdk.New(client.URL) agentClient.SessionToken = agentToken - agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{ - Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug), + agentCloser := agent.New(agent.Options{ + FetchMetadata: agentClient.WorkspaceAgentMetadata, + WebRTCDialer: agentClient.ListenWorkspaceAgent, + CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet, + Logger: slogtest.Make(t, nil).Named("agent"), }) defer func() { _ = agentCloser.Close() @@ -108,8 +110,11 @@ func TestSSH(t *testing.T) { // the build and agent to connect! agentClient := codersdk.New(client.URL) agentClient.SessionToken = agentToken - agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{ - Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug), + agentCloser := agent.New(agent.Options{ + FetchMetadata: agentClient.WorkspaceAgentMetadata, + WebRTCDialer: agentClient.ListenWorkspaceAgent, + CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet, + Logger: slogtest.Make(t, nil).Named("agent"), }) <-ctx.Done() _ = agentCloser.Close() @@ -174,8 +179,11 @@ func TestSSH(t *testing.T) { agentClient := codersdk.New(client.URL) agentClient.SessionToken = agentToken - agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{ - Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug), + agentCloser := agent.New(agent.Options{ + FetchMetadata: agentClient.WorkspaceAgentMetadata, + WebRTCDialer: agentClient.ListenWorkspaceAgent, + CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet, + Logger: slogtest.Make(t, nil).Named("agent"), }) defer agentCloser.Close() diff --git a/cli/wireguardtunnel.go b/cli/wireguardtunnel.go deleted file mode 100644 index 488f7ad342074..0000000000000 --- a/cli/wireguardtunnel.go +++ /dev/null @@ -1,263 +0,0 @@ -package cli - -import ( - "context" - "fmt" - "net" - "os" - "os/signal" - "strconv" - "sync" - "syscall" - - "github.com/google/uuid" - "github.com/pion/udp" - "github.com/spf13/cobra" - "golang.org/x/xerrors" - "inet.af/netaddr" - - "cdr.dev/slog" - "cdr.dev/slog/sloggers/sloghuman" - coderagent "github.com/coder/coder/agent" - "github.com/coder/coder/cli/cliui" - "github.com/coder/coder/codersdk" - "github.com/coder/coder/peer/peerwg" -) - -func wireguardPortForward() *cobra.Command { - var ( - tcpForwards []string // : - udpForwards []string // : - // TODO: unix support - // unixForwards []string // : OR : - ) - cmd := &cobra.Command{ - Use: "wireguard-port-forward ", - Aliases: []string{"wireguard-tunnel"}, - Args: cobra.ExactArgs(1), - // Hide all wireguard commands for now while we test! - Hidden: true, - Example: formatExamples( - example{ - Description: "Port forward a single TCP port from 1234 in the workspace to port 5678 on your local machine", - Command: "coder wireguard-port-forward --tcp 5678:1234", - }, - example{ - Description: "Port forward a single UDP port from port 9000 to port 9000 on your local machine", - Command: "coder wireguard-port-forward --udp 9000", - }, - example{ - Description: "Port forward multiple TCP ports and a UDP port", - Command: "coder wireguard-port-forward --tcp 8080:8080 --tcp 9000:3000 --udp 5353:53", - }, - ), - RunE: func(cmd *cobra.Command, args []string) error { - ctx, cancel := context.WithCancel(cmd.Context()) - defer cancel() - - specs, err := parsePortForwards(tcpForwards, nil, nil) - if err != nil { - return xerrors.Errorf("parse port-forward specs: %w", err) - } - if len(specs) == 0 { - err = cmd.Help() - if err != nil { - return xerrors.Errorf("generate help output: %w", err) - } - return xerrors.New("no port-forwards requested") - } - - client, err := CreateClient(cmd) - if err != nil { - return err - } - - workspace, workspaceAgent, err := getWorkspaceAndAgent(ctx, cmd, client, codersdk.Me, args[0], false) - if err != nil { - return err - } - if workspace.LatestBuild.Transition != codersdk.WorkspaceTransitionStart { - return xerrors.New("workspace must be in start transition to port-forward") - } - if workspace.LatestBuild.Job.CompletedAt == nil { - err = cliui.WorkspaceBuild(ctx, cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt) - if err != nil { - return err - } - } - - err = cliui.Agent(ctx, cmd.ErrOrStderr(), cliui.AgentOptions{ - WorkspaceName: workspace.Name, - Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) { - return client.WorkspaceAgent(ctx, workspaceAgent.ID) - }, - }) - if err != nil { - return xerrors.Errorf("await agent: %w", err) - } - - ipv6 := peerwg.UUIDToNetaddr(uuid.New()) - wgn, err := peerwg.New( - slog.Make(sloghuman.Sink(cmd.ErrOrStderr())), - []netaddr.IPPrefix{netaddr.IPPrefixFrom(ipv6, 128)}, - ) - if err != nil { - return xerrors.Errorf("create wireguard network: %w", err) - } - defer wgn.Close() - - err = client.PostWireguardPeer(ctx, workspace.ID, peerwg.Handshake{ - Recipient: workspaceAgent.ID, - NodePublicKey: wgn.NodePrivateKey.Public(), - DiscoPublicKey: wgn.DiscoPublicKey, - IPv6: ipv6, - }) - if err != nil { - return xerrors.Errorf("post wireguard peer: %w", err) - } - - err = wgn.AddPeer(peerwg.Handshake{ - Recipient: workspaceAgent.ID, - DiscoPublicKey: workspaceAgent.DiscoPublicKey, - NodePublicKey: workspaceAgent.WireguardPublicKey, - IPv6: workspaceAgent.IPv6.IP(), - }) - if err != nil { - return xerrors.Errorf("add workspace agent as peer: %w", err) - } - - // Start all listeners. - var ( - wg = new(sync.WaitGroup) - listeners = make([]net.Listener, len(specs)) - closeAllListeners = func() { - for _, l := range listeners { - if l == nil { - continue - } - _ = l.Close() - } - } - ) - defer closeAllListeners() - - for i, spec := range specs { - l, err := listenAndPortForwardWireguard(ctx, cmd, wgn, wg, spec, workspaceAgent.IPv6.IP()) - if err != nil { - return err - } - listeners[i] = l - } - - // Wait for the context to be canceled or for a signal and close - // all listeners. - var closeErr error - wg.Add(1) - go func() { - defer wg.Done() - - sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) - - select { - case <-ctx.Done(): - closeErr = ctx.Err() - case <-sigs: - _, _ = fmt.Fprintln(cmd.OutOrStderr(), "Received signal, closing all listeners and active connections") - closeErr = xerrors.New("signal received") - } - - cancel() - closeAllListeners() - }() - - _, _ = fmt.Fprintln(cmd.OutOrStderr(), "Ready!") - wg.Wait() - return closeErr - }, - } - - cmd.Flags().StringArrayVarP(&tcpForwards, "tcp", "p", []string{}, "Forward a TCP port from the workspace to the local machine") - cmd.Flags().StringArrayVar(&udpForwards, "udp", []string{}, "Forward a UDP port from the workspace to the local machine. The UDP connection has TCP-like semantics to support stateful UDP protocols") - // cmd.Flags().StringArrayVar(&unixForwards, "unix", []string{}, "Forward a Unix socket in the workspace to a local Unix socket or TCP port") - - return cmd -} - -func listenAndPortForwardWireguard(ctx context.Context, cmd *cobra.Command, - wgn *peerwg.Network, - wg *sync.WaitGroup, - spec portForwardSpec, - agentIP netaddr.IP, -) (net.Listener, error) { - _, _ = fmt.Fprintf(cmd.OutOrStderr(), "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress) - - var ( - l net.Listener - err error - ) - switch spec.listenNetwork { - case "tcp": - l, err = net.Listen(spec.listenNetwork, spec.listenAddress) - case "udp": - var host, port string - host, port, err = net.SplitHostPort(spec.listenAddress) - if err != nil { - return nil, xerrors.Errorf("split %q: %w", spec.listenAddress, err) - } - - var portInt int - portInt, err = strconv.Atoi(port) - if err != nil { - return nil, xerrors.Errorf("parse port %v from %q as int: %w", port, spec.listenAddress, err) - } - - l, err = udp.Listen(spec.listenNetwork, &net.UDPAddr{ - IP: net.ParseIP(host), - Port: portInt, - }) - // case "unix": - // l, err = net.Listen(spec.listenNetwork, spec.listenAddress) - default: - return nil, xerrors.Errorf("unknown listen network %q", spec.listenNetwork) - } - if err != nil { - return nil, xerrors.Errorf("listen '%v://%v': %w", spec.listenNetwork, spec.listenAddress, err) - } - - wg.Add(1) - go func(spec portForwardSpec) { - defer wg.Done() - for { - netConn, err := l.Accept() - if err != nil { - _, _ = fmt.Fprintf(cmd.OutOrStderr(), "Error accepting connection from '%v://%v': %+v\n", spec.listenNetwork, spec.listenAddress, err) - _, _ = fmt.Fprintln(cmd.OutOrStderr(), "Killing listener") - return - } - - go func(netConn net.Conn) { - defer netConn.Close() - - ipPort := netaddr.MustParseIPPort(spec.dialAddress).WithIP(agentIP) - - var remoteConn net.Conn - switch spec.dialNetwork { - case "tcp": - remoteConn, err = wgn.Netstack.DialContextTCP(ctx, ipPort) - case "udp": - remoteConn, err = wgn.Netstack.DialContextUDP(ctx, ipPort) - } - if err != nil { - _, _ = fmt.Fprintf(cmd.OutOrStderr(), "Failed to dial '%v://%v' in workspace: %s\n", spec.dialNetwork, spec.dialAddress, err) - return - } - defer remoteConn.Close() - - coderagent.Bicopy(ctx, netConn, remoteConn) - }(netConn) - } - }(spec) - - return l, nil -} diff --git a/coderd/coderd.go b/coderd/coderd.go index 58a4386c90005..52bdc54208d11 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -18,6 +18,10 @@ import ( sdktrace "go.opentelemetry.io/otel/sdk/trace" "golang.org/x/xerrors" "google.golang.org/api/idtoken" + "tailscale.com/derp" + "tailscale.com/derp/derphttp" + "tailscale.com/tailcfg" + "tailscale.com/types/key" "cdr.dev/slog" "github.com/coder/coder/buildinfo" @@ -33,6 +37,7 @@ import ( "github.com/coder/coder/coderd/wsconncache" "github.com/coder/coder/codersdk" "github.com/coder/coder/site" + "github.com/coder/coder/tailnet" ) // Options are requires parameters for Coder to start. @@ -67,6 +72,10 @@ type Options struct { AutoImportTemplates []AutoImportTemplate LicenseHandler http.Handler FeaturesService FeaturesService + + TailscaleEnable bool + TailnetCoordinator *tailnet.Coordinator + DERPMap *tailcfg.DERPMap } // New constructs a Coder API handler. @@ -93,6 +102,9 @@ func New(options *Options) *API { if options.PrometheusRegistry == nil { options.PrometheusRegistry = prometheus.NewRegistry() } + if options.TailnetCoordinator == nil { + options.TailnetCoordinator = tailnet.NewCoordinator() + } if options.LicenseHandler == nil { options.LicenseHandler = licenses() } @@ -119,7 +131,12 @@ func New(options *Options) *API { Logger: options.Logger, }, } - api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgent, 0) + if options.TailscaleEnable { + api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0) + } else { + api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgent, 0) + } + api.derpServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger)) oauthConfigs := &httpmw.OAuth2Configs{ Github: options.GithubOAuth2Config, OIDC: options.OIDCConfig, @@ -148,6 +165,7 @@ func New(options *Options) *API { // other applications might not as well. r.Route("/%40{user}/{workspace_and_agent}/apps/{workspaceapp}", apps) r.Route("/@{user}/{workspace_and_agent}/apps/{workspaceapp}", apps) + r.Get("/derp", derphttp.Handler(api.derpServer).ServeHTTP) r.Route("/api/v2", func(r chi.Router) { r.NotFound(func(rw http.ResponseWriter, r *http.Request) { @@ -338,9 +356,8 @@ func New(options *Options) *API { r.Get("/gitsshkey", api.agentGitSSHKey) r.Get("/turn", api.workspaceAgentTurn) r.Get("/iceservers", api.workspaceAgentICEServers) - r.Get("/wireguardlisten", api.workspaceAgentWireguardListener) - r.Post("/keys", api.postWorkspaceAgentKeys) - r.Get("/derp", api.derpMap) + + r.Get("/coordinate", api.workspaceAgentCoordinate) }) r.Route("/{workspaceagent}", func(r chi.Router) { r.Use( @@ -349,12 +366,13 @@ func New(options *Options) *API { httpmw.ExtractWorkspaceParam(options.Database), ) r.Get("/", api.workspaceAgent) - r.Post("/peer", api.postWorkspaceAgentWireguardPeer) r.Get("/dial", api.workspaceAgentDial) r.Get("/turn", api.userWorkspaceAgentTurn) r.Get("/pty", api.workspaceAgentPTY) r.Get("/iceservers", api.workspaceAgentICEServers) - r.Get("/derp", api.derpMap) + + r.Get("/connection", api.workspaceAgentConnection) + r.Get("/coordinate", api.workspaceAgentClientCoordinate) }) }) r.Route("/workspaceresources/{workspaceresource}", func(r chi.Router) { @@ -420,6 +438,8 @@ func New(options *Options) *API { type API struct { *Options + derpServer *derp.Server + Handler chi.Router siteHandler http.Handler websocketWaitMutex sync.Mutex diff --git a/coderd/coderd_test.go b/coderd/coderd_test.go index 615e032f00eb2..7e9175e73392b 100644 --- a/coderd/coderd_test.go +++ b/coderd/coderd_test.go @@ -2,13 +2,21 @@ package coderd_test import ( "context" + "net/netip" + "strconv" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/goleak" + "tailscale.com/tailcfg" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/buildinfo" "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/tailnet" "github.com/coder/coder/testutil" ) @@ -38,3 +46,71 @@ func TestAuthorizeAllEndpoints(t *testing.T) { skipRoutes, assertRoute := coderdtest.AGPLRoutes(a) a.Test(ctx, assertRoute, skipRoutes) } + +func TestDERP(t *testing.T) { + t.Parallel() + client := coderdtest.New(t, nil) + + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + + derpPort, err := strconv.Atoi(client.URL.Port()) + require.NoError(t, err) + derpMap := &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 1: { + RegionID: 1, + RegionCode: "cdr", + RegionName: "Coder", + Nodes: []*tailcfg.DERPNode{{ + Name: "1a", + RegionID: 1, + HostName: client.URL.Hostname(), + DERPPort: derpPort, + STUNPort: -1, + ForceHTTP: true, + }}, + }, + }, + } + w1IP := tailnet.IP() + w1, err := tailnet.NewConn(&tailnet.Options{ + Addresses: []netip.Prefix{netip.PrefixFrom(w1IP, 128)}, + Logger: logger.Named("w1"), + DERPMap: derpMap, + }) + require.NoError(t, err) + + w2, err := tailnet.NewConn(&tailnet.Options{ + Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, + Logger: logger.Named("w2"), + DERPMap: derpMap, + }) + require.NoError(t, err) + w1.SetNodeCallback(func(node *tailnet.Node) { + w2.UpdateNodes([]*tailnet.Node{node}) + }) + w2.SetNodeCallback(func(node *tailnet.Node) { + w1.UpdateNodes([]*tailnet.Node{node}) + }) + + conn := make(chan struct{}) + go func() { + listener, err := w1.Listen("tcp", ":35565") + assert.NoError(t, err) + defer listener.Close() + conn <- struct{}{} + nc, err := listener.Accept() + assert.NoError(t, err) + _ = nc.Close() + conn <- struct{}{} + }() + + <-conn + nc, err := w2.DialContextTCP(context.Background(), netip.AddrPortFrom(w1IP, 35565)) + require.NoError(t, err) + _ = nc.Close() + <-conn + + w1.Close() + w2.Close() +} diff --git a/coderd/coderdtest/authtest.go b/coderd/coderdtest/authtest.go index 0ca2d1a0884d5..7bfa855da8f7a 100644 --- a/coderd/coderdtest/authtest.go +++ b/coderd/coderdtest/authtest.go @@ -167,6 +167,7 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) { // skipRoutes allows skipping routes from being checked. skipRoutes := map[string]string{ "POST:/api/v2/users/logout": "Logging out deletes the API Key for other routes", + "GET:/derp": "This requires a WebSocket upgrade!", } assertRoute := map[string]RouteCheck{ @@ -193,12 +194,9 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) { "GET:/api/v2/workspaceagents/me/listen": {NoAuthorize: true}, "GET:/api/v2/workspaceagents/me/metadata": {NoAuthorize: true}, "GET:/api/v2/workspaceagents/me/turn": {NoAuthorize: true}, - "GET:/api/v2/workspaceagents/me/derp": {NoAuthorize: true}, + "GET:/api/v2/workspaceagents/me/coordinate": {NoAuthorize: true}, "POST:/api/v2/workspaceagents/me/version": {NoAuthorize: true}, - "GET:/api/v2/workspaceagents/me/wireguardlisten": {NoAuthorize: true}, - "POST:/api/v2/workspaceagents/me/keys": {NoAuthorize: true}, "GET:/api/v2/workspaceagents/{workspaceagent}/iceservers": {NoAuthorize: true}, - "GET:/api/v2/workspaceagents/{workspaceagent}/derp": {NoAuthorize: true}, // These endpoints have more assertions. This is good, add more endpoints to assert if you can! "GET:/api/v2/organizations/{organization}": {AssertObject: rbac.ResourceOrganization.InOrg(a.Admin.OrganizationID)}, @@ -271,6 +269,10 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) { AssertAction: rbac.ActionCreate, AssertObject: workspaceExecObj, }, + "GET:/api/v2/workspaceagents/{workspaceagent}/coordinate": { + AssertAction: rbac.ActionCreate, + AssertObject: workspaceExecObj, + }, "GET:/api/v2/workspaces/": { StatusCode: http.StatusOK, AssertAction: rbac.ActionRead, diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index f7c7c04288ea6..8970058ffb74d 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -21,6 +21,7 @@ import ( "net/http/httptest" "net/url" "os" + "strconv" "strings" "testing" "time" @@ -36,6 +37,7 @@ import ( "golang.org/x/xerrors" "google.golang.org/api/idtoken" "google.golang.org/api/option" + "tailscale.com/tailcfg" "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" @@ -178,6 +180,9 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c serverURL, err := url.Parse(srv.URL) require.NoError(t, err) + derpPort, err := strconv.Atoi(serverURL.Port()) + require.NoError(t, err) + // match default with cli default if options.SSHKeygenAlgorithm == "" { options.SSHKeygenAlgorithm = gitsshkey.AlgorithmEd25519 @@ -211,7 +216,25 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c APIRateLimit: options.APIRateLimit, Authorizer: options.Authorizer, Telemetry: telemetry.NewNoop(), - AutoImportTemplates: options.AutoImportTemplates, + DERPMap: &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 1: { + RegionID: 1, + RegionCode: "coder", + RegionName: "Coder", + Nodes: []*tailcfg.DERPNode{{ + Name: "1a", + RegionID: 1, + IPv4: "127.0.0.1", + DERPPort: derpPort, + STUNPort: -1, + InsecureForTests: true, + ForceHTTP: true, + }}, + }, + }, + }, + AutoImportTemplates: options.AutoImportTemplates, }) t.Cleanup(func() { _ = coderAPI.Close() diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index 34c81fc046b29..6ebf15613e290 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -1701,23 +1701,20 @@ func (q *fakeQuerier) InsertWorkspaceAgent(_ context.Context, arg database.Inser defer q.mutex.Unlock() agent := database.WorkspaceAgent{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - ResourceID: arg.ResourceID, - AuthToken: arg.AuthToken, - AuthInstanceID: arg.AuthInstanceID, - EnvironmentVariables: arg.EnvironmentVariables, - Name: arg.Name, - Architecture: arg.Architecture, - OperatingSystem: arg.OperatingSystem, - Directory: arg.Directory, - StartupScript: arg.StartupScript, - InstanceMetadata: arg.InstanceMetadata, - ResourceMetadata: arg.ResourceMetadata, - WireguardNodeIPv6: arg.WireguardNodeIPv6, - WireguardNodePublicKey: arg.WireguardNodePublicKey, - WireguardDiscoPublicKey: arg.WireguardDiscoPublicKey, + ID: arg.ID, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + ResourceID: arg.ResourceID, + AuthToken: arg.AuthToken, + AuthInstanceID: arg.AuthInstanceID, + EnvironmentVariables: arg.EnvironmentVariables, + Name: arg.Name, + Architecture: arg.Architecture, + OperatingSystem: arg.OperatingSystem, + Directory: arg.Directory, + StartupScript: arg.StartupScript, + InstanceMetadata: arg.InstanceMetadata, + ResourceMetadata: arg.ResourceMetadata, } q.provisionerJobAgents = append(q.provisionerJobAgents, agent) @@ -2029,24 +2026,6 @@ func (q *fakeQuerier) UpdateWorkspaceAgentConnectionByID(_ context.Context, arg return sql.ErrNoRows } -func (q *fakeQuerier) UpdateWorkspaceAgentKeysByID(_ context.Context, arg database.UpdateWorkspaceAgentKeysByIDParams) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, agent := range q.provisionerJobAgents { - if agent.ID != arg.ID { - continue - } - - agent.WireguardNodePublicKey = arg.WireguardNodePublicKey - agent.WireguardDiscoPublicKey = arg.WireguardDiscoPublicKey - agent.UpdatedAt = arg.UpdatedAt - q.provisionerJobAgents[index] = agent - return nil - } - return sql.ErrNoRows -} - func (q *fakeQuerier) UpdateWorkspaceAgentVersionByID(_ context.Context, arg database.UpdateWorkspaceAgentVersionByIDParams) error { q.mutex.Lock() defer q.mutex.Unlock() diff --git a/coderd/database/dbtypes/dbtypes.go b/coderd/database/dbtypes/dbtypes.go deleted file mode 100644 index 3653f4f37cb62..0000000000000 --- a/coderd/database/dbtypes/dbtypes.go +++ /dev/null @@ -1,74 +0,0 @@ -package dbtypes - -import ( - "database/sql/driver" - - "golang.org/x/xerrors" - "tailscale.com/types/key" -) - -// NodePublic is a wrapper around a key.NodePublic which represents the -// Wireguard public key for an agent.. -type NodePublic key.NodePublic - -func (n NodePublic) String() string { - return key.NodePublic(n).String() -} - -// This is necessary so NodePublic can be serialized in JSON loggers. -func (n NodePublic) MarshalJSON() ([]byte, error) { - j, err := key.NodePublic(n).MarshalText() - // surround in quotes to make it a JSON string - j = append([]byte{'"'}, append(j, '"')...) - return j, err -} - -// Value is so NodePublic can be inserted into the database. -func (n NodePublic) Value() (driver.Value, error) { - return key.NodePublic(n).MarshalText() -} - -// Scan is so NodePublic can be read from the database. -func (n *NodePublic) Scan(value interface{}) error { - switch v := value.(type) { - case []byte: - return (*key.NodePublic)(n).UnmarshalText(v) - case string: - return (*key.NodePublic)(n).UnmarshalText([]byte(v)) - default: - return xerrors.Errorf("unexpected type: %T", v) - } -} - -// NodePublic is a wrapper around a key.NodePublic which represents the -// Tailscale disco key for an agent. -type DiscoPublic key.DiscoPublic - -func (n DiscoPublic) String() string { - return key.DiscoPublic(n).String() -} - -// This is necessary so DiscoPublic can be serialized in JSON loggers. -func (n DiscoPublic) MarshalJSON() ([]byte, error) { - j, err := key.DiscoPublic(n).MarshalText() - // surround in quotes to make it a JSON string - j = append([]byte{'"'}, append(j, '"')...) - return j, err -} - -// Value is so DiscoPublic can be inserted into the database. -func (n DiscoPublic) Value() (driver.Value, error) { - return key.DiscoPublic(n).MarshalText() -} - -// Scan is so DiscoPublic can be read from the database. -func (n *DiscoPublic) Scan(value interface{}) error { - switch v := value.(type) { - case []byte: - return (*key.DiscoPublic)(n).UnmarshalText(v) - case string: - return (*key.DiscoPublic)(n).UnmarshalText([]byte(v)) - default: - return xerrors.Errorf("unexpected type: %T", v) - } -} diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 25c2b45d4fa79..133136ad07689 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -309,9 +309,6 @@ CREATE TABLE workspace_agents ( instance_metadata jsonb, resource_metadata jsonb, directory character varying(4096) DEFAULT ''::character varying NOT NULL, - wireguard_node_ipv6 inet DEFAULT '::'::inet NOT NULL, - wireguard_node_public_key character varying(128) DEFAULT 'nodekey:0000000000000000000000000000000000000000000000000000000000000000'::character varying NOT NULL, - wireguard_disco_public_key character varying(128) DEFAULT 'discokey:0000000000000000000000000000000000000000000000000000000000000000'::character varying NOT NULL, version text DEFAULT ''::text NOT NULL ); diff --git a/coderd/database/migrations/000041_tailnet.down.sql b/coderd/database/migrations/000041_tailnet.down.sql new file mode 100644 index 0000000000000..e889d94f0934b --- /dev/null +++ b/coderd/database/migrations/000041_tailnet.down.sql @@ -0,0 +1,4 @@ +ALTER TABLE workspace_agents + ADD COLUMN wireguard_node_ipv6 inet NOT NULL DEFAULT '::/128', + ADD COLUMN wireguard_node_public_key varchar(128) NOT NULL DEFAULT 'nodekey:0000000000000000000000000000000000000000000000000000000000000000', + ADD COLUMN wireguard_disco_public_key varchar(128) NOT NULL DEFAULT 'discokey:0000000000000000000000000000000000000000000000000000000000000000'; diff --git a/coderd/database/migrations/000041_tailnet.up.sql b/coderd/database/migrations/000041_tailnet.up.sql new file mode 100644 index 0000000000000..bcda153ee80f1 --- /dev/null +++ b/coderd/database/migrations/000041_tailnet.up.sql @@ -0,0 +1,3 @@ +ALTER TABLE workspace_agents DROP COLUMN wireguard_node_ipv6; +ALTER TABLE workspace_agents DROP COLUMN wireguard_node_public_key; +ALTER TABLE workspace_agents DROP COLUMN wireguard_disco_public_key; diff --git a/coderd/database/models.go b/coderd/database/models.go index be6fe1bbfbbb5..4a58c476ce1d9 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -10,7 +10,6 @@ import ( "fmt" "time" - "github.com/coder/coder/coderd/database/dbtypes" "github.com/google/uuid" "github.com/tabbed/pqtype" ) @@ -519,26 +518,23 @@ type Workspace struct { } type WorkspaceAgent struct { - ID uuid.UUID `db:"id" json:"id"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` - Name string `db:"name" json:"name"` - FirstConnectedAt sql.NullTime `db:"first_connected_at" json:"first_connected_at"` - LastConnectedAt sql.NullTime `db:"last_connected_at" json:"last_connected_at"` - DisconnectedAt sql.NullTime `db:"disconnected_at" json:"disconnected_at"` - ResourceID uuid.UUID `db:"resource_id" json:"resource_id"` - AuthToken uuid.UUID `db:"auth_token" json:"auth_token"` - AuthInstanceID sql.NullString `db:"auth_instance_id" json:"auth_instance_id"` - Architecture string `db:"architecture" json:"architecture"` - EnvironmentVariables pqtype.NullRawMessage `db:"environment_variables" json:"environment_variables"` - OperatingSystem string `db:"operating_system" json:"operating_system"` - StartupScript sql.NullString `db:"startup_script" json:"startup_script"` - InstanceMetadata pqtype.NullRawMessage `db:"instance_metadata" json:"instance_metadata"` - ResourceMetadata pqtype.NullRawMessage `db:"resource_metadata" json:"resource_metadata"` - Directory string `db:"directory" json:"directory"` - WireguardNodeIPv6 pqtype.Inet `db:"wireguard_node_ipv6" json:"wireguard_node_ipv6"` - WireguardNodePublicKey dbtypes.NodePublic `db:"wireguard_node_public_key" json:"wireguard_node_public_key"` - WireguardDiscoPublicKey dbtypes.DiscoPublic `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"` + ID uuid.UUID `db:"id" json:"id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + Name string `db:"name" json:"name"` + FirstConnectedAt sql.NullTime `db:"first_connected_at" json:"first_connected_at"` + LastConnectedAt sql.NullTime `db:"last_connected_at" json:"last_connected_at"` + DisconnectedAt sql.NullTime `db:"disconnected_at" json:"disconnected_at"` + ResourceID uuid.UUID `db:"resource_id" json:"resource_id"` + AuthToken uuid.UUID `db:"auth_token" json:"auth_token"` + AuthInstanceID sql.NullString `db:"auth_instance_id" json:"auth_instance_id"` + Architecture string `db:"architecture" json:"architecture"` + EnvironmentVariables pqtype.NullRawMessage `db:"environment_variables" json:"environment_variables"` + OperatingSystem string `db:"operating_system" json:"operating_system"` + StartupScript sql.NullString `db:"startup_script" json:"startup_script"` + InstanceMetadata pqtype.NullRawMessage `db:"instance_metadata" json:"instance_metadata"` + ResourceMetadata pqtype.NullRawMessage `db:"resource_metadata" json:"resource_metadata"` + Directory string `db:"directory" json:"directory"` // Version tracks the version of the currently running workspace agent. Workspace agents register their version upon start. Version string `db:"version" json:"version"` } diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 3644d2d969d37..1d817ee2af5b3 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -143,7 +143,6 @@ type querier interface { UpdateUserStatus(ctx context.Context, arg UpdateUserStatusParams) (User, error) UpdateWorkspace(ctx context.Context, arg UpdateWorkspaceParams) (Workspace, error) UpdateWorkspaceAgentConnectionByID(ctx context.Context, arg UpdateWorkspaceAgentConnectionByIDParams) error - UpdateWorkspaceAgentKeysByID(ctx context.Context, arg UpdateWorkspaceAgentKeysByIDParams) error UpdateWorkspaceAgentVersionByID(ctx context.Context, arg UpdateWorkspaceAgentVersionByIDParams) error UpdateWorkspaceAutostart(ctx context.Context, arg UpdateWorkspaceAutostartParams) error UpdateWorkspaceBuildByID(ctx context.Context, arg UpdateWorkspaceBuildByIDParams) error diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index d3f15006c834b..662327a0722ff 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -10,7 +10,6 @@ import ( "encoding/json" "time" - "github.com/coder/coder/coderd/database/dbtypes" "github.com/google/uuid" "github.com/lib/pq" "github.com/tabbed/pqtype" @@ -3173,7 +3172,7 @@ func (q *sqlQuerier) UpdateUserStatus(ctx context.Context, arg UpdateUserStatusP const getWorkspaceAgentByAuthToken = `-- name: GetWorkspaceAgentByAuthToken :one SELECT - id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, wireguard_node_ipv6, wireguard_node_public_key, wireguard_disco_public_key, version + id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, version FROM workspace_agents WHERE @@ -3203,9 +3202,6 @@ func (q *sqlQuerier) GetWorkspaceAgentByAuthToken(ctx context.Context, authToken &i.InstanceMetadata, &i.ResourceMetadata, &i.Directory, - &i.WireguardNodeIPv6, - &i.WireguardNodePublicKey, - &i.WireguardDiscoPublicKey, &i.Version, ) return i, err @@ -3213,7 +3209,7 @@ func (q *sqlQuerier) GetWorkspaceAgentByAuthToken(ctx context.Context, authToken const getWorkspaceAgentByID = `-- name: GetWorkspaceAgentByID :one SELECT - id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, wireguard_node_ipv6, wireguard_node_public_key, wireguard_disco_public_key, version + id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, version FROM workspace_agents WHERE @@ -3241,9 +3237,6 @@ func (q *sqlQuerier) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (W &i.InstanceMetadata, &i.ResourceMetadata, &i.Directory, - &i.WireguardNodeIPv6, - &i.WireguardNodePublicKey, - &i.WireguardDiscoPublicKey, &i.Version, ) return i, err @@ -3251,7 +3244,7 @@ func (q *sqlQuerier) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (W const getWorkspaceAgentByInstanceID = `-- name: GetWorkspaceAgentByInstanceID :one SELECT - id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, wireguard_node_ipv6, wireguard_node_public_key, wireguard_disco_public_key, version + id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, version FROM workspace_agents WHERE @@ -3281,9 +3274,6 @@ func (q *sqlQuerier) GetWorkspaceAgentByInstanceID(ctx context.Context, authInst &i.InstanceMetadata, &i.ResourceMetadata, &i.Directory, - &i.WireguardNodeIPv6, - &i.WireguardNodePublicKey, - &i.WireguardDiscoPublicKey, &i.Version, ) return i, err @@ -3291,7 +3281,7 @@ func (q *sqlQuerier) GetWorkspaceAgentByInstanceID(ctx context.Context, authInst const getWorkspaceAgentsByResourceIDs = `-- name: GetWorkspaceAgentsByResourceIDs :many SELECT - id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, wireguard_node_ipv6, wireguard_node_public_key, wireguard_disco_public_key, version + id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, version FROM workspace_agents WHERE @@ -3325,9 +3315,6 @@ func (q *sqlQuerier) GetWorkspaceAgentsByResourceIDs(ctx context.Context, ids [] &i.InstanceMetadata, &i.ResourceMetadata, &i.Directory, - &i.WireguardNodeIPv6, - &i.WireguardNodePublicKey, - &i.WireguardDiscoPublicKey, &i.Version, ); err != nil { return nil, err @@ -3344,7 +3331,7 @@ func (q *sqlQuerier) GetWorkspaceAgentsByResourceIDs(ctx context.Context, ids [] } const getWorkspaceAgentsCreatedAfter = `-- name: GetWorkspaceAgentsCreatedAfter :many -SELECT id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, wireguard_node_ipv6, wireguard_node_public_key, wireguard_disco_public_key, version FROM workspace_agents WHERE created_at > $1 +SELECT id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, version FROM workspace_agents WHERE created_at > $1 ` func (q *sqlQuerier) GetWorkspaceAgentsCreatedAfter(ctx context.Context, createdAt time.Time) ([]WorkspaceAgent, error) { @@ -3374,9 +3361,6 @@ func (q *sqlQuerier) GetWorkspaceAgentsCreatedAfter(ctx context.Context, created &i.InstanceMetadata, &i.ResourceMetadata, &i.Directory, - &i.WireguardNodeIPv6, - &i.WireguardNodePublicKey, - &i.WireguardDiscoPublicKey, &i.Version, ); err != nil { return nil, err @@ -3408,33 +3392,27 @@ INSERT INTO startup_script, directory, instance_metadata, - resource_metadata, - wireguard_node_ipv6, - wireguard_node_public_key, - wireguard_disco_public_key + resource_metadata ) VALUES - ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17) RETURNING id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, wireguard_node_ipv6, wireguard_node_public_key, wireguard_disco_public_key, version + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) RETURNING id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, version ` type InsertWorkspaceAgentParams struct { - ID uuid.UUID `db:"id" json:"id"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` - Name string `db:"name" json:"name"` - ResourceID uuid.UUID `db:"resource_id" json:"resource_id"` - AuthToken uuid.UUID `db:"auth_token" json:"auth_token"` - AuthInstanceID sql.NullString `db:"auth_instance_id" json:"auth_instance_id"` - Architecture string `db:"architecture" json:"architecture"` - EnvironmentVariables pqtype.NullRawMessage `db:"environment_variables" json:"environment_variables"` - OperatingSystem string `db:"operating_system" json:"operating_system"` - StartupScript sql.NullString `db:"startup_script" json:"startup_script"` - Directory string `db:"directory" json:"directory"` - InstanceMetadata pqtype.NullRawMessage `db:"instance_metadata" json:"instance_metadata"` - ResourceMetadata pqtype.NullRawMessage `db:"resource_metadata" json:"resource_metadata"` - WireguardNodeIPv6 pqtype.Inet `db:"wireguard_node_ipv6" json:"wireguard_node_ipv6"` - WireguardNodePublicKey dbtypes.NodePublic `db:"wireguard_node_public_key" json:"wireguard_node_public_key"` - WireguardDiscoPublicKey dbtypes.DiscoPublic `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"` + ID uuid.UUID `db:"id" json:"id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + Name string `db:"name" json:"name"` + ResourceID uuid.UUID `db:"resource_id" json:"resource_id"` + AuthToken uuid.UUID `db:"auth_token" json:"auth_token"` + AuthInstanceID sql.NullString `db:"auth_instance_id" json:"auth_instance_id"` + Architecture string `db:"architecture" json:"architecture"` + EnvironmentVariables pqtype.NullRawMessage `db:"environment_variables" json:"environment_variables"` + OperatingSystem string `db:"operating_system" json:"operating_system"` + StartupScript sql.NullString `db:"startup_script" json:"startup_script"` + Directory string `db:"directory" json:"directory"` + InstanceMetadata pqtype.NullRawMessage `db:"instance_metadata" json:"instance_metadata"` + ResourceMetadata pqtype.NullRawMessage `db:"resource_metadata" json:"resource_metadata"` } func (q *sqlQuerier) InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspaceAgentParams) (WorkspaceAgent, error) { @@ -3453,9 +3431,6 @@ func (q *sqlQuerier) InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspa arg.Directory, arg.InstanceMetadata, arg.ResourceMetadata, - arg.WireguardNodeIPv6, - arg.WireguardNodePublicKey, - arg.WireguardDiscoPublicKey, ) var i WorkspaceAgent err := row.Scan( @@ -3476,9 +3451,6 @@ func (q *sqlQuerier) InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspa &i.InstanceMetadata, &i.ResourceMetadata, &i.Directory, - &i.WireguardNodeIPv6, - &i.WireguardNodePublicKey, - &i.WireguardDiscoPublicKey, &i.Version, ) return i, err @@ -3515,34 +3487,6 @@ func (q *sqlQuerier) UpdateWorkspaceAgentConnectionByID(ctx context.Context, arg return err } -const updateWorkspaceAgentKeysByID = `-- name: UpdateWorkspaceAgentKeysByID :exec -UPDATE - workspace_agents -SET - wireguard_node_public_key = $2, - wireguard_disco_public_key = $3, - updated_at = $4 -WHERE - id = $1 -` - -type UpdateWorkspaceAgentKeysByIDParams struct { - ID uuid.UUID `db:"id" json:"id"` - WireguardNodePublicKey dbtypes.NodePublic `db:"wireguard_node_public_key" json:"wireguard_node_public_key"` - WireguardDiscoPublicKey dbtypes.DiscoPublic `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` -} - -func (q *sqlQuerier) UpdateWorkspaceAgentKeysByID(ctx context.Context, arg UpdateWorkspaceAgentKeysByIDParams) error { - _, err := q.db.ExecContext(ctx, updateWorkspaceAgentKeysByID, - arg.ID, - arg.WireguardNodePublicKey, - arg.WireguardDiscoPublicKey, - arg.UpdatedAt, - ) - return err -} - const updateWorkspaceAgentVersionByID = `-- name: UpdateWorkspaceAgentVersionByID :exec UPDATE workspace_agents diff --git a/coderd/database/queries/workspaceagents.sql b/coderd/database/queries/workspaceagents.sql index ca136a759ed2c..541c14372d707 100644 --- a/coderd/database/queries/workspaceagents.sql +++ b/coderd/database/queries/workspaceagents.sql @@ -53,13 +53,10 @@ INSERT INTO startup_script, directory, instance_metadata, - resource_metadata, - wireguard_node_ipv6, - wireguard_node_public_key, - wireguard_disco_public_key + resource_metadata ) VALUES - ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17) RETURNING *; + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) RETURNING *; -- name: UpdateWorkspaceAgentConnectionByID :exec UPDATE @@ -72,16 +69,6 @@ SET WHERE id = $1; --- name: UpdateWorkspaceAgentKeysByID :exec -UPDATE - workspace_agents -SET - wireguard_node_public_key = $2, - wireguard_disco_public_key = $3, - updated_at = $4 -WHERE - id = $1; - -- name: UpdateWorkspaceAgentVersionByID :exec UPDATE workspace_agents diff --git a/coderd/database/sqlc.yaml b/coderd/database/sqlc.yaml index 53bae1099b974..fec2bccc16c0a 100644 --- a/coderd/database/sqlc.yaml +++ b/coderd/database/sqlc.yaml @@ -16,12 +16,6 @@ packages: # deleted after generation. output_db_file_name: db_tmp.go -overrides: - - column: workspace_agents.wireguard_node_public_key - go_type: github.com/coder/coder/coderd/database/dbtypes.NodePublic - - column: workspace_agents.wireguard_disco_public_key - go_type: github.com/coder/coder/coderd/database/dbtypes.DiscoPublic - rename: api_key: APIKey login_type_oidc: LoginTypeOIDC @@ -34,5 +28,5 @@ rename: gitsshkey: GitSSHKey rbac_roles: RBACRoles ip_address: IPAddress - wireguard_node_ipv6: WireguardNodeIPv6 + ip_addresses: IPAddresses jwt: JWT diff --git a/coderd/provisionerdaemons.go b/coderd/provisionerdaemons.go index 67a0219ec149e..6a109587676dc 100644 --- a/coderd/provisionerdaemons.go +++ b/coderd/provisionerdaemons.go @@ -23,13 +23,11 @@ import ( "cdr.dev/slog" "github.com/coder/coder/coderd/database" - "github.com/coder/coder/coderd/database/dbtypes" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/parameter" "github.com/coder/coder/coderd/rbac" "github.com/coder/coder/coderd/telemetry" "github.com/coder/coder/codersdk" - "github.com/coder/coder/peer/peerwg" "github.com/coder/coder/provisionerd/proto" "github.com/coder/coder/provisionersdk" sdkproto "github.com/coder/coder/provisionersdk/proto" @@ -804,9 +802,6 @@ func insertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid. String: prAgent.StartupScript, Valid: prAgent.StartupScript != "", }, - WireguardNodeIPv6: peerwg.UUIDToInet(agentID), - WireguardNodePublicKey: dbtypes.NodePublic{}, - WireguardDiscoPublicKey: dbtypes.DiscoPublic{}, }) if err != nil { return xerrors.Errorf("insert agent: %w", err) diff --git a/coderd/provisionerjobs.go b/coderd/provisionerjobs.go index 04cbbd4d821df..19e802fc35440 100644 --- a/coderd/provisionerjobs.go +++ b/coderd/provisionerjobs.go @@ -264,7 +264,7 @@ func (api *API) provisionerJobResources(rw http.ResponseWriter, r *http.Request, } } - apiAgent, err := convertWorkspaceAgent(agent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout) + apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, agent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error reading job agent.", diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index e4e57488c9cbf..5ed39af446de1 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -8,22 +8,21 @@ import ( "io" "net" "net/http" + "net/netip" "strconv" + "strings" "time" "github.com/google/uuid" "github.com/hashicorp/yamux" - "github.com/tabbed/pqtype" "golang.org/x/mod/semver" "golang.org/x/xerrors" - "inet.af/netaddr" "nhooyr.io/websocket" - "tailscale.com/types/key" + "tailscale.com/tailcfg" "cdr.dev/slog" "github.com/coder/coder/agent" "github.com/coder/coder/coderd/database" - "github.com/coder/coder/coderd/database/dbtypes" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/httpmw" "github.com/coder/coder/coderd/rbac" @@ -31,10 +30,10 @@ import ( "github.com/coder/coder/coderd/turnconn" "github.com/coder/coder/codersdk" "github.com/coder/coder/peer" - "github.com/coder/coder/peer/peerwg" "github.com/coder/coder/peerbroker" "github.com/coder/coder/peerbroker/proto" "github.com/coder/coder/provisionersdk" + "github.com/coder/coder/tailnet" ) func (api *API) workspaceAgent(rw http.ResponseWriter, r *http.Request) { @@ -52,7 +51,7 @@ func (api *API) workspaceAgent(rw http.ResponseWriter, r *http.Request) { }) return } - apiAgent, err := convertWorkspaceAgent(workspaceAgent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout) + apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error reading workspace agent.", @@ -76,7 +75,7 @@ func (api *API) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) { httpapi.ResourceNotFound(rw) return } - apiAgent, err := convertWorkspaceAgent(workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) + apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error reading workspace agent.", @@ -127,7 +126,7 @@ func (api *API) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) { func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request) { workspaceAgent := httpmw.WorkspaceAgent(r) - apiAgent, err := convertWorkspaceAgent(workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) + apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error reading workspace agent.", @@ -136,17 +135,8 @@ func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request) return } - ipp, ok := netaddr.FromStdIPNet(&workspaceAgent.WireguardNodeIPv6.IPNet) - if !ok { - httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Workspace agent has an invalid ipv6 address.", - Detail: workspaceAgent.WireguardNodeIPv6.IPNet.String(), - }) - return - } - httpapi.Write(rw, http.StatusOK, agent.Metadata{ - WireguardAddresses: []netaddr.IPPrefix{ipp}, + DERPMap: api.DERPMap, EnvironmentVariables: apiAgent.EnvironmentVariables, StartupScript: apiAgent.StartupScript, Directory: apiAgent.Directory, @@ -155,7 +145,7 @@ func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request) func (api *API) postWorkspaceAgentVersion(rw http.ResponseWriter, r *http.Request) { workspaceAgent := httpmw.WorkspaceAgent(r) - apiAgent, err := convertWorkspaceAgent(workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) + apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error reading workspace agent.", @@ -431,7 +421,7 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) { httpapi.ResourceNotFound(rw) return } - apiAgent, err := convertWorkspaceAgent(workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) + apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error reading workspace agent.", @@ -498,141 +488,10 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) { _, _ = io.Copy(ptNetConn, wsNetConn) } -func (*API) derpMap(rw http.ResponseWriter, _ *http.Request) { - httpapi.Write(rw, http.StatusOK, peerwg.DerpMap) -} - -type WorkspaceKeysRequest struct { - Public key.NodePublic `json:"public"` - Disco key.DiscoPublic `json:"disco"` -} - -func (api *API) postWorkspaceAgentKeys(rw http.ResponseWriter, r *http.Request) { - var ( - ctx = r.Context() - workspaceAgent = httpmw.WorkspaceAgent(r) - keys WorkspaceKeysRequest - ) - if !httpapi.Read(rw, r, &keys) { - return - } - - err := api.Database.UpdateWorkspaceAgentKeysByID(ctx, database.UpdateWorkspaceAgentKeysByIDParams{ - ID: workspaceAgent.ID, - WireguardNodePublicKey: dbtypes.NodePublic(keys.Public), - WireguardDiscoPublicKey: dbtypes.DiscoPublic(keys.Disco), - UpdatedAt: database.Now(), - }) - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Internal error setting agent keys.", - Detail: err.Error(), - }) - return - } - - rw.WriteHeader(http.StatusNoContent) -} - -func (api *API) postWorkspaceAgentWireguardPeer(rw http.ResponseWriter, r *http.Request) { - var ( - req peerwg.Handshake - workspaceAgent = httpmw.WorkspaceAgentParam(r) - workspace = httpmw.WorkspaceParam(r) - ) - - if !api.Authorize(r, rbac.ActionCreate, workspace.ExecutionRBAC()) { - httpapi.ResourceNotFound(rw) - return - } - - if !httpapi.Read(rw, r, &req) { - return - } - - if req.Recipient != workspaceAgent.ID { - httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{ - Message: "Invalid recipient.", - }) - return - } - - raw, err := req.MarshalText() - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Internal error marshaling wireguard peer message.", - Detail: err.Error(), - }) - return - } - - err = api.Pubsub.Publish("wireguard_peers", raw) - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Internal error publishing wireguard peer message.", - Detail: err.Error(), - }) - return - } - - rw.WriteHeader(http.StatusNoContent) -} - -func (api *API) workspaceAgentWireguardListener(rw http.ResponseWriter, r *http.Request) { - api.websocketWaitMutex.Lock() - api.websocketWaitGroup.Add(1) - api.websocketWaitMutex.Unlock() - defer api.websocketWaitGroup.Done() - - ctx := r.Context() - workspaceAgent := httpmw.WorkspaceAgent(r) - - conn, err := websocket.Accept(rw, r, nil) - if err != nil { - httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{ - Message: "Failed to accept websocket.", - Detail: err.Error(), - }) - return - } - defer conn.Close(websocket.StatusNormalClosure, "") - - agentIDBytes, _ := workspaceAgent.ID.MarshalText() - subCancel, err := api.Pubsub.Subscribe("wireguard_peers", func(ctx context.Context, message []byte) { - // Since we subscribe to all peer broadcasts, we do a light check to - // make sure we're the intended recipient without fully decoding the - // message. - hint, err := peerwg.HandshakeRecipientHint(agentIDBytes, message) - if err != nil { - api.Logger.Error(ctx, "invalid wireguard peer message", slog.Error(err)) - return - } - - // We aren't the intended recipient. - if !hint { - return - } - - _ = conn.Write(ctx, websocket.MessageBinary, message) - }) - if err != nil { - api.Logger.Error(ctx, "pubsub listen", slog.Error(err)) - return - } - defer subCancel() - - // end span so we don't get long lived trace data - tracing.EndHTTPSpan(r, 200) - - // Wait for the connection to close or the client to send a message. - //nolint:dogsled - _, _, _ = conn.Reader(ctx) -} - // dialWorkspaceAgent connects to a workspace agent by ID. Only rely on // r.Context() for cancellation if it's use is safe or r.Hijack() has // not been performed. -func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.Conn, error) { +func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (agent.Conn, error) { client, server := provisionersdk.TransportPipe() ctx, cancelFunc := context.WithCancel(context.Background()) go func() { @@ -691,12 +550,110 @@ func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.C <-peerConn.Closed() cancelFunc() }() - return &agent.Conn{ + return &agent.WebRTCConn{ Negotiator: peerClient, Conn: peerConn, }, nil } +func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (agent.Conn, error) { + clientConn, serverConn := net.Pipe() + go func() { + <-r.Context().Done() + _ = clientConn.Close() + _ = serverConn.Close() + }() + + conn, err := tailnet.NewConn(&tailnet.Options{ + Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, + DERPMap: api.DERPMap, + Logger: api.Logger.Named("tailnet").Leveled(slog.LevelDebug), + }) + if err != nil { + return nil, xerrors.Errorf("create tailnet conn: %w", err) + } + + sendNodes, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error { + return conn.UpdateNodes(node) + }) + conn.SetNodeCallback(sendNodes) + go func() { + err := api.TailnetCoordinator.ServeClient(serverConn, uuid.New(), agentID) + if err != nil { + _ = conn.Close() + } + }() + return &agent.TailnetConn{ + Conn: conn, + }, nil +} + +func (api *API) workspaceAgentConnection(rw http.ResponseWriter, r *http.Request) { + workspace := httpmw.WorkspaceParam(r) + if !api.Authorize(r, rbac.ActionRead, workspace) { + httpapi.ResourceNotFound(rw) + return + } + httpapi.Write(rw, http.StatusOK, codersdk.WorkspaceAgentConnectionInfo{ + DERPMap: api.DERPMap, + }) +} + +func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request) { + api.websocketWaitMutex.Lock() + api.websocketWaitGroup.Add(1) + api.websocketWaitMutex.Unlock() + defer api.websocketWaitGroup.Done() + workspaceAgent := httpmw.WorkspaceAgent(r) + + conn, err := websocket.Accept(rw, r, nil) + if err != nil { + httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{ + Message: "Failed to accept websocket.", + Detail: err.Error(), + }) + return + } + defer conn.Close(websocket.StatusNormalClosure, "") + err = api.TailnetCoordinator.ServeAgent(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), workspaceAgent.ID) + if err != nil { + _ = conn.Close(websocket.StatusInternalError, err.Error()) + return + } +} + +// workspaceAgentClientCoordinate accepts a WebSocket that reads node network updates. +// After accept a PubSub starts listening for new connection node updates +// which are written to the WebSocket. +func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.Request) { + workspace := httpmw.WorkspaceParam(r) + if !api.Authorize(r, rbac.ActionCreate, workspace.ExecutionRBAC()) { + httpapi.ResourceNotFound(rw) + return + } + + api.websocketWaitMutex.Lock() + api.websocketWaitGroup.Add(1) + api.websocketWaitMutex.Unlock() + defer api.websocketWaitGroup.Done() + workspaceAgent := httpmw.WorkspaceAgentParam(r) + + conn, err := websocket.Accept(rw, r, nil) + if err != nil { + httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{ + Message: "Failed to accept websocket.", + Detail: err.Error(), + }) + return + } + defer conn.Close(websocket.StatusNormalClosure, "") + err = api.TailnetCoordinator.ServeClient(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), uuid.New(), workspaceAgent.ID) + if err != nil { + _ = conn.Close(websocket.StatusInternalError, err.Error()) + return + } +} + func convertApps(dbApps []database.WorkspaceApp) []codersdk.WorkspaceApp { apps := make([]codersdk.WorkspaceApp, 0) for _, dbApp := range dbApps { @@ -710,28 +667,14 @@ func convertApps(dbApps []database.WorkspaceApp) []codersdk.WorkspaceApp { return apps } -func inetToNetaddr(inet pqtype.Inet) netaddr.IPPrefix { - if !inet.Valid { - return netaddr.IPPrefixFrom(netaddr.IPv6Unspecified(), 128) - } - - ipp, ok := netaddr.FromStdIPNet(&inet.IPNet) - if !ok { - return netaddr.IPPrefixFrom(netaddr.IPv6Unspecified(), 128) - } - - return ipp -} - -func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, apps []codersdk.WorkspaceApp, agentInactiveDisconnectTimeout time.Duration) (codersdk.WorkspaceAgent, error) { +func convertWorkspaceAgent(derpMap *tailcfg.DERPMap, coordinator *tailnet.Coordinator, dbAgent database.WorkspaceAgent, apps []codersdk.WorkspaceApp, agentInactiveDisconnectTimeout time.Duration) (codersdk.WorkspaceAgent, error) { var envs map[string]string if dbAgent.EnvironmentVariables.Valid { err := json.Unmarshal(dbAgent.EnvironmentVariables.RawMessage, &envs) if err != nil { - return codersdk.WorkspaceAgent{}, xerrors.Errorf("unmarshal: %w", err) + return codersdk.WorkspaceAgent{}, xerrors.Errorf("unmarshal env vars: %w", err) } } - workspaceAgent := codersdk.WorkspaceAgent{ ID: dbAgent.ID, CreatedAt: dbAgent.CreatedAt, @@ -742,13 +685,29 @@ func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, apps []codersdk.Work Architecture: dbAgent.Architecture, OperatingSystem: dbAgent.OperatingSystem, StartupScript: dbAgent.StartupScript.String, + Version: dbAgent.Version, EnvironmentVariables: envs, Directory: dbAgent.Directory, Apps: apps, - IPv6: inetToNetaddr(dbAgent.WireguardNodeIPv6), - WireguardPublicKey: key.NodePublic(dbAgent.WireguardNodePublicKey), - DiscoPublicKey: key.DiscoPublic(dbAgent.WireguardDiscoPublicKey), - Version: dbAgent.Version, + } + node := coordinator.Node(dbAgent.ID) + if node != nil { + workspaceAgent.DERPLatency = map[string]codersdk.DERPRegion{} + for rawRegion, latency := range node.DERPLatency { + regionParts := strings.SplitN(rawRegion, "-", 2) + regionID, err := strconv.Atoi(regionParts[0]) + if err != nil { + return codersdk.WorkspaceAgent{}, xerrors.Errorf("convert derp region id %q: %w", rawRegion, err) + } + region, found := derpMap.Regions[regionID] + if !found { + return codersdk.WorkspaceAgent{}, xerrors.Errorf("region %d not found in derpmap", regionID) + } + workspaceAgent.DERPLatency[region.RegionName] = codersdk.DERPRegion{ + Preferred: node.PreferredDERP == regionID, + LatencyMilliseconds: latency * 1000, + } + } } if dbAgent.FirstConnectedAt.Valid { diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index 5ec8274fecfa9..afe1411247313 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -109,8 +109,11 @@ func TestWorkspaceAgentListen(t *testing.T) { agentClient := codersdk.New(client.URL) agentClient.SessionToken = authToken - agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{ - Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug), + agentCloser := agent.New(agent.Options{ + FetchMetadata: agentClient.WorkspaceAgentMetadata, + CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet, + WebRTCDialer: agentClient.ListenWorkspaceAgent, + Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug), }) defer func() { _ = agentCloser.Close() @@ -199,7 +202,7 @@ func TestWorkspaceAgentListen(t *testing.T) { agentClient := codersdk.New(client.URL) agentClient.SessionToken = authToken - _, _, err = agentClient.ListenWorkspaceAgent(ctx, slogtest.Make(t, nil)) + _, err = agentClient.ListenWorkspaceAgent(ctx, slogtest.Make(t, nil)) require.Error(t, err) require.ErrorContains(t, err, "build is outdated") }) @@ -240,8 +243,11 @@ func TestWorkspaceAgentTURN(t *testing.T) { agentClient := codersdk.New(client.URL) agentClient.SessionToken = authToken - agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{ - Logger: slogtest.Make(t, nil), + agentCloser := agent.New(agent.Options{ + FetchMetadata: agentClient.WorkspaceAgentMetadata, + CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet, + WebRTCDialer: agentClient.ListenWorkspaceAgent, + Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug), }) defer func() { _ = agentCloser.Close() @@ -265,6 +271,65 @@ func TestWorkspaceAgentTURN(t *testing.T) { require.NoError(t, err) } +func TestWorkspaceAgentTailnet(t *testing.T) { + t.Parallel() + client, daemonCloser := coderdtest.NewWithProvisionerCloser(t, nil) + user := coderdtest.CreateFirstUser(t, client) + authToken := uuid.NewString() + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionDryRun: echo.ProvisionComplete, + Provision: []*proto.Provision_Response{{ + Type: &proto.Provision_Response_Complete{ + Complete: &proto.Provision_Complete{ + Resources: []*proto.Resource{{ + Name: "example", + Type: "aws_instance", + Agents: []*proto.Agent{{ + Id: uuid.NewString(), + Auth: &proto.Agent_Token{ + Token: authToken, + }, + }}, + }}, + }, + }, + }}, + }) + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + coderdtest.AwaitTemplateVersionJob(t, client, version.ID) + workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) + coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) + daemonCloser.Close() + + agentClient := codersdk.New(client.URL) + agentClient.SessionToken = authToken + agentCloser := agent.New(agent.Options{ + FetchMetadata: agentClient.WorkspaceAgentMetadata, + WebRTCDialer: agentClient.ListenWorkspaceAgent, + CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet, + Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug), + }) + defer agentCloser.Close() + resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + conn, err := client.DialWorkspaceAgentTailnet(ctx, slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug), resources[0].Agents[0].ID) + require.NoError(t, err) + defer conn.Close() + sshClient, err := conn.SSHClient() + require.NoError(t, err) + session, err := sshClient.NewSession() + require.NoError(t, err) + output, err := session.CombinedOutput("echo test") + require.NoError(t, err) + _ = session.Close() + _ = sshClient.Close() + _ = conn.Close() + require.Equal(t, "test", strings.TrimSpace(string(output))) +} + func TestWorkspaceAgentPTY(t *testing.T) { t.Parallel() if runtime.GOOS == "windows" { @@ -305,8 +370,11 @@ func TestWorkspaceAgentPTY(t *testing.T) { agentClient := codersdk.New(client.URL) agentClient.SessionToken = authToken - agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{ - Logger: slogtest.Make(t, nil), + agentCloser := agent.New(agent.Options{ + FetchMetadata: agentClient.WorkspaceAgentMetadata, + CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet, + WebRTCDialer: agentClient.ListenWorkspaceAgent, + Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug), }) defer func() { _ = agentCloser.Close() diff --git a/coderd/workspaceapps_test.go b/coderd/workspaceapps_test.go index d17a5e95ccbbb..4a75e9fc78e99 100644 --- a/coderd/workspaceapps_test.go +++ b/coderd/workspaceapps_test.go @@ -81,8 +81,11 @@ func TestWorkspaceAppsProxyPath(t *testing.T) { agentClient := codersdk.New(client.URL) agentClient.SessionToken = authToken - agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{ - Logger: slogtest.Make(t, nil), + agentCloser := agent.New(agent.Options{ + FetchMetadata: agentClient.WorkspaceAgentMetadata, + CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet, + WebRTCDialer: agentClient.ListenWorkspaceAgent, + Logger: slogtest.Make(t, nil).Named("agent"), }) t.Cleanup(func() { _ = agentCloser.Close() diff --git a/coderd/workspaceresources.go b/coderd/workspaceresources.go index 61f5ff2f4101e..416390132f48a 100644 --- a/coderd/workspaceresources.go +++ b/coderd/workspaceresources.go @@ -70,7 +70,7 @@ func (api *API) workspaceResource(rw http.ResponseWriter, r *http.Request) { } } - convertedAgent, err := convertWorkspaceAgent(agent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout) + convertedAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, agent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error reading workspace agent.", diff --git a/coderd/wsconncache/wsconncache.go b/coderd/wsconncache/wsconncache.go index 7d3b741a63b7e..698f467a40790 100644 --- a/coderd/wsconncache/wsconncache.go +++ b/coderd/wsconncache/wsconncache.go @@ -32,11 +32,11 @@ func New(dialer Dialer, inactiveTimeout time.Duration) *Cache { } // Dialer creates a new agent connection by ID. -type Dialer func(r *http.Request, id uuid.UUID) (*agent.Conn, error) +type Dialer func(r *http.Request, id uuid.UUID) (agent.Conn, error) // Conn wraps an agent connection with a reusable HTTP transport. type Conn struct { - *agent.Conn + agent.Conn locks atomic.Uint64 timeoutMutex sync.Mutex diff --git a/coderd/wsconncache/wsconncache_test.go b/coderd/wsconncache/wsconncache_test.go index 452c516c8a342..80f187ba15ab7 100644 --- a/coderd/wsconncache/wsconncache_test.go +++ b/coderd/wsconncache/wsconncache_test.go @@ -7,13 +7,13 @@ import ( "net/http" "net/http/httptest" "net/http/httputil" + "net/netip" "net/url" "sync" "testing" "time" "github.com/google/uuid" - "github.com/pion/webrtc/v3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/atomic" @@ -23,10 +23,8 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/agent" "github.com/coder/coder/coderd/wsconncache" - "github.com/coder/coder/peer" - "github.com/coder/coder/peerbroker" - "github.com/coder/coder/peerbroker/proto" - "github.com/coder/coder/provisionersdk" + "github.com/coder/coder/tailnet" + "github.com/coder/coder/tailnet/tailnettest" ) func TestMain(m *testing.M) { @@ -37,7 +35,7 @@ func TestCache(t *testing.T) { t.Parallel() t.Run("Same", func(t *testing.T) { t.Parallel() - cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) { + cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (agent.Conn, error) { return setupAgent(t, agent.Metadata{}, 0), nil }, 0) defer func() { @@ -52,7 +50,7 @@ func TestCache(t *testing.T) { t.Run("Expire", func(t *testing.T) { t.Parallel() called := atomic.NewInt32(0) - cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) { + cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (agent.Conn, error) { called.Add(1) return setupAgent(t, agent.Metadata{}, 0), nil }, time.Microsecond) @@ -71,7 +69,7 @@ func TestCache(t *testing.T) { }) t.Run("NoExpireWhenLocked", func(t *testing.T) { t.Parallel() - cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) { + cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (agent.Conn, error) { return setupAgent(t, agent.Metadata{}, 0), nil }, time.Microsecond) defer func() { @@ -104,7 +102,7 @@ func TestCache(t *testing.T) { }() go server.Serve(random) - cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) { + cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (agent.Conn, error) { return setupAgent(t, agent.Metadata{}, 0), nil }, time.Microsecond) defer func() { @@ -141,37 +139,48 @@ func TestCache(t *testing.T) { }) } -func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) *agent.Conn { - client, server := provisionersdk.TransportPipe() - closer := agent.New(func(ctx context.Context, logger slog.Logger) (agent.Metadata, *peerbroker.Listener, error) { - listener, err := peerbroker.Listen(server, func(ctx context.Context) ([]webrtc.ICEServer, *peer.ConnOptions, error) { - return nil, &peer.ConnOptions{ - Logger: slogtest.Make(t, nil).Named("server").Leveled(slog.LevelDebug), - }, nil - }) - return metadata, listener, err - }, &agent.Options{ - Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug), +func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) agent.Conn { + metadata.DERPMap = tailnettest.RunDERPAndSTUN(t) + + coordinator := tailnet.NewCoordinator() + agentID := uuid.New() + closer := agent.New(agent.Options{ + FetchMetadata: func(ctx context.Context) (agent.Metadata, error) { + return metadata, nil + }, + CoordinatorDialer: func(ctx context.Context) (net.Conn, error) { + clientConn, serverConn := net.Pipe() + t.Cleanup(func() { + _ = serverConn.Close() + _ = clientConn.Close() + }) + go coordinator.ServeAgent(serverConn, agentID) + return clientConn, nil + }, + Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelInfo), ReconnectingPTYTimeout: ptyTimeout, }) t.Cleanup(func() { - _ = client.Close() - _ = server.Close() _ = closer.Close() }) - api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client)) - stream, err := api.NegotiateConnection(context.Background()) - assert.NoError(t, err) - conn, err := peerbroker.Dial(stream, []webrtc.ICEServer{}, &peer.ConnOptions{ - Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug), + conn, err := tailnet.NewConn(&tailnet.Options{ + Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, + DERPMap: metadata.DERPMap, + Logger: slogtest.Make(t, nil).Named("tailnet").Leveled(slog.LevelDebug), }) require.NoError(t, err) + clientConn, serverConn := net.Pipe() t.Cleanup(func() { + _ = clientConn.Close() + _ = serverConn.Close() _ = conn.Close() }) - - return &agent.Conn{ - Negotiator: api, - Conn: conn, + go coordinator.ServeClient(serverConn, uuid.New(), agentID) + sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error { + return conn.UpdateNodes(node) + }) + conn.SetNodeCallback(sendNode) + return &agent.TailnetConn{ + Conn: conn, } } diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index b2aeff66aec02..816d4ea4a01e9 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -3,11 +3,14 @@ package codersdk import ( "context" "encoding/json" + "errors" "fmt" "io" "net" "net/http" "net/http/cookiejar" + "net/netip" + "time" "cloud.google.com/go/compute/metadata" "github.com/google/uuid" @@ -16,16 +19,18 @@ import ( "golang.org/x/net/proxy" "golang.org/x/xerrors" "nhooyr.io/websocket" + "tailscale.com/tailcfg" "cdr.dev/slog" "github.com/coder/coder/agent" "github.com/coder/coder/coderd/turnconn" "github.com/coder/coder/peer" - "github.com/coder/coder/peer/peerwg" "github.com/coder/coder/peerbroker" "github.com/coder/coder/peerbroker/proto" "github.com/coder/coder/provisionersdk" + "github.com/coder/coder/tailnet" + "github.com/coder/retry" ) type GoogleInstanceIdentityToken struct { @@ -48,6 +53,12 @@ type WorkspaceAgentAuthenticateResponse struct { SessionToken string `json:"session_token"` } +// WorkspaceAgentConnectionInfo returns required information for establishing +// a connection with a workspace. +type WorkspaceAgentConnectionInfo struct { + DERPMap *tailcfg.DERPMap `json:"derp_map"` +} + type PostWorkspaceAgentVersionRequest struct { Version string `json:"version"` } @@ -180,16 +191,30 @@ func (c *Client) AuthWorkspaceAzureInstanceIdentity(ctx context.Context) (Worksp return resp, json.NewDecoder(res.Body).Decode(&resp) } +// WorkspaceAgentMetadata fetches metadata for the currently authenticated workspace agent. +func (c *Client) WorkspaceAgentMetadata(ctx context.Context) (agent.Metadata, error) { + res, err := c.Request(ctx, http.MethodGet, "/api/v2/workspaceagents/me/metadata", nil) + if err != nil { + return agent.Metadata{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return agent.Metadata{}, readBodyAsError(res) + } + var agentMetadata agent.Metadata + return agentMetadata, json.NewDecoder(res.Body).Decode(&agentMetadata) +} + // ListenWorkspaceAgent connects as a workspace agent identifying with the session token. // On each inbound connection request, connection info is fetched. -func (c *Client) ListenWorkspaceAgent(ctx context.Context, logger slog.Logger) (agent.Metadata, *peerbroker.Listener, error) { +func (c *Client) ListenWorkspaceAgent(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error) { serverURL, err := c.URL.Parse("/api/v2/workspaceagents/me/listen") if err != nil { - return agent.Metadata{}, nil, xerrors.Errorf("parse url: %w", err) + return nil, xerrors.Errorf("parse url: %w", err) } jar, err := cookiejar.New(nil) if err != nil { - return agent.Metadata{}, nil, xerrors.Errorf("create cookie jar: %w", err) + return nil, xerrors.Errorf("create cookie jar: %w", err) } jar.SetCookies(serverURL, []*http.Cookie{{ Name: SessionTokenKey, @@ -205,17 +230,17 @@ func (c *Client) ListenWorkspaceAgent(ctx context.Context, logger slog.Logger) ( }) if err != nil { if res == nil { - return agent.Metadata{}, nil, err + return nil, err } - return agent.Metadata{}, nil, readBodyAsError(res) + return nil, readBodyAsError(res) } config := yamux.DefaultConfig() config.LogOutput = io.Discard session, err := yamux.Client(websocket.NetConn(ctx, conn, websocket.MessageBinary), config) if err != nil { - return agent.Metadata{}, nil, xerrors.Errorf("multiplex client: %w", err) + return nil, xerrors.Errorf("multiplex client: %w", err) } - listener, err := peerbroker.Listen(session, func(ctx context.Context) ([]webrtc.ICEServer, *peer.ConnOptions, error) { + return peerbroker.Listen(session, func(ctx context.Context) ([]webrtc.ICEServer, *peer.ConnOptions, error) { // This can be cached if it adds to latency too much. res, err := c.Request(ctx, http.MethodGet, "/api/v2/workspaceagents/me/iceservers", nil) if err != nil { @@ -241,116 +266,120 @@ func (c *Client) ListenWorkspaceAgent(ctx context.Context, logger slog.Logger) ( Logger: logger, }, nil }) +} + +func (c *Client) ListenWorkspaceAgentTailnet(ctx context.Context) (net.Conn, error) { + coordinateURL, err := c.URL.Parse("/api/v2/workspaceagents/me/coordinate") if err != nil { - return agent.Metadata{}, nil, xerrors.Errorf("listen peerbroker: %w", err) + return nil, xerrors.Errorf("parse url: %w", err) } - - // Fetch updated agent metadata - res, err = c.Request(ctx, http.MethodGet, "/api/v2/workspaceagents/me/metadata", nil) + jar, err := cookiejar.New(nil) if err != nil { - return agent.Metadata{}, nil, err + return nil, xerrors.Errorf("create cookie jar: %w", err) } - defer res.Body.Close() - if res.StatusCode != http.StatusOK { - return agent.Metadata{}, nil, readBodyAsError(res) + jar.SetCookies(coordinateURL, []*http.Cookie{{ + Name: SessionTokenKey, + Value: c.SessionToken, + }}) + httpClient := &http.Client{ + Jar: jar, } - var agentMetadata agent.Metadata - return agentMetadata, listener, json.NewDecoder(res.Body).Decode(&agentMetadata) + // nolint:bodyclose + conn, _, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{ + HTTPClient: httpClient, + }) + if err != nil { + return nil, err + } + + return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil } -// PostWireguardPeer announces your public keys and IPv6 address to the -// specified recipient. -func (c *Client) PostWireguardPeer(ctx context.Context, workspaceID uuid.UUID, peerMsg peerwg.Handshake) error { - res, err := c.Request(ctx, http.MethodPost, fmt.Sprintf("/api/v2/workspaceagents/%s/peer?workspace=%s", - peerMsg.Recipient, - workspaceID.String(), - ), peerMsg) +func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logger, agentID uuid.UUID) (agent.Conn, error) { + res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/workspaceagents/%s/connection", agentID), nil) if err != nil { - return err + return nil, err } defer res.Body.Close() - if res.StatusCode != http.StatusNoContent { - return readBodyAsError(res) + if res.StatusCode != http.StatusOK { + return nil, readBodyAsError(res) + } + var connInfo WorkspaceAgentConnectionInfo + err = json.NewDecoder(res.Body).Decode(&connInfo) + if err != nil { + return nil, xerrors.Errorf("decode conn info: %w", err) } - _, _ = io.Copy(io.Discard, res.Body) - return nil -} + ip := tailnet.IP() + conn, err := tailnet.NewConn(&tailnet.Options{ + Addresses: []netip.Prefix{netip.PrefixFrom(ip, 128)}, + DERPMap: connInfo.DERPMap, + Logger: logger, + }) + if err != nil { + return nil, xerrors.Errorf("create tailnet: %w", err) + } -// WireguardPeerListener listens for wireguard peer messages. Peer messages are -// sent when a new client wants to connect. Once receiving a peer message, the -// peer should be added to the NetworkMap of the wireguard interface. -func (c *Client) WireguardPeerListener(ctx context.Context, logger slog.Logger) (<-chan peerwg.Handshake, func(), error) { - serverURL, err := c.URL.Parse("/api/v2/workspaceagents/me/wireguardlisten") + coordinateURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/workspaceagents/%s/coordinate", agentID)) if err != nil { - return nil, nil, xerrors.Errorf("parse url: %w", err) + return nil, xerrors.Errorf("parse url: %w", err) } jar, err := cookiejar.New(nil) if err != nil { - return nil, nil, xerrors.Errorf("create cookie jar: %w", err) + return nil, xerrors.Errorf("create cookie jar: %w", err) } - jar.SetCookies(serverURL, []*http.Cookie{{ + jar.SetCookies(coordinateURL, []*http.Cookie{{ Name: SessionTokenKey, Value: c.SessionToken, }}) httpClient := &http.Client{ Jar: jar, } - - conn, res, err := websocket.Dial(ctx, serverURL.String(), &websocket.DialOptions{ - HTTPClient: httpClient, - // Need to disable compression to avoid a data-race. - CompressionMode: websocket.CompressionDisabled, - }) - if err != nil { - if res == nil { - return nil, nil, xerrors.Errorf("websocket dial: %w", err) - } - return nil, nil, readBodyAsError(res) - } - - ch := make(chan peerwg.Handshake, 1) + ctx, cancelFunc := context.WithCancel(ctx) + closed := make(chan struct{}) go func() { - defer conn.Close(websocket.StatusGoingAway, "") - defer close(ch) - - for { - _, message, err := conn.Read(ctx) + defer close(closed) + for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); { + logger.Debug(ctx, "connecting") + // nolint:bodyclose + ws, _, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{ + HTTPClient: httpClient, + // Need to disable compression to avoid a data-race. + CompressionMode: websocket.CompressionDisabled, + }) + if errors.Is(err, context.Canceled) { + return + } if err != nil { - break + logger.Debug(ctx, "failed to dial", slog.Error(err)) + continue + } + sendNode, errChan := tailnet.ServeCoordinator(websocket.NetConn(ctx, ws, websocket.MessageBinary), func(node []*tailnet.Node) error { + return conn.UpdateNodes(node) + }) + conn.SetNodeCallback(sendNode) + logger.Debug(ctx, "serving coordinator") + err = <-errChan + if errors.Is(err, context.Canceled) { + return } - - var msg peerwg.Handshake - err = msg.UnmarshalText(message) if err != nil { - logger.Error(ctx, "unmarshal wireguard peer message", slog.Error(err)) + logger.Debug(ctx, "error serving coordinator", slog.Error(err)) continue } - - ch <- msg } }() - - return ch, func() { _ = conn.Close(websocket.StatusGoingAway, "") }, nil -} - -// UploadWorkspaceAgentKeys uploads the public keys of the workspace agent that -// were generated on startup. These keys are used by clients to communicate with -// the workspace agent over the wireguard interface. -func (c *Client) UploadWorkspaceAgentKeys(ctx context.Context, keys agent.WireguardPublicKeys) error { - res, err := c.Request(ctx, http.MethodPost, "/api/v2/workspaceagents/me/keys", keys) - if err != nil { - return xerrors.Errorf("do request: %w", err) - } - defer res.Body.Close() - if res.StatusCode != http.StatusNoContent { - return readBodyAsError(res) - } - return nil + return &agent.TailnetConn{ + Conn: conn, + CloseFunc: func() { + cancelFunc() + <-closed + }, + }, nil } // DialWorkspaceAgent creates a connection to the specified resource. -func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *peer.ConnOptions) (*agent.Conn, error) { +func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *peer.ConnOptions) (agent.Conn, error) { serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/workspaceagents/%s/dial", agentID.String())) if err != nil { return nil, xerrors.Errorf("parse url: %w", err) @@ -415,7 +444,7 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti if err != nil { return nil, xerrors.Errorf("dial peer: %w", err) } - return &agent.Conn{ + return &agent.WebRTCConn{ Negotiator: client, Conn: peerConn, }, nil diff --git a/codersdk/workspaceresources.go b/codersdk/workspaceresources.go index cd9b7cff1c9e6..03096c84c4d0f 100644 --- a/codersdk/workspaceresources.go +++ b/codersdk/workspaceresources.go @@ -8,8 +8,6 @@ import ( "time" "github.com/google/uuid" - "inet.af/netaddr" - "tailscale.com/types/key" ) type WorkspaceAgentStatus string @@ -37,6 +35,11 @@ type WorkspaceResourceMetadata struct { Sensitive bool `json:"sensitive"` } +type DERPRegion struct { + Preferred bool `json:"preferred"` + LatencyMilliseconds float64 `json:"latency_ms"` +} + type WorkspaceAgent struct { ID uuid.UUID `json:"id"` CreatedAt time.Time `json:"created_at"` @@ -53,11 +56,10 @@ type WorkspaceAgent struct { OperatingSystem string `json:"operating_system"` StartupScript string `json:"startup_script,omitempty"` Directory string `json:"directory,omitempty"` - Apps []WorkspaceApp `json:"apps"` - WireguardPublicKey key.NodePublic `json:"wireguard_public_key"` - DiscoPublicKey key.DiscoPublic `json:"disco_public_key"` - IPv6 netaddr.IPPrefix `json:"ipv6"` Version string `json:"version"` + Apps []WorkspaceApp `json:"apps"` + // DERPLatency is mapped by region name (e.g. "New York City", "Seattle"). + DERPLatency map[string]DERPRegion `json:"latency"` } type WorkspaceAgentResourceMetadata struct { diff --git a/go.mod b/go.mod index 7b8df9f5613aa..2b8480f592d66 100644 --- a/go.mod +++ b/go.mod @@ -36,7 +36,7 @@ replace github.com/golang/glog => github.com/coder/glog v1.0.1-0.20220322161911- // https://github.com/coder/kcp-go/commit/83c0904cec69dcf21ec10c54ea666bda18ada831 replace github.com/fatedier/kcp-go => github.com/coder/kcp-go v2.0.4-0.20220409183554-83c0904cec69+incompatible -replace golang.zx2c4.com/wireguard/tun/netstack => github.com/coder/wireguard-go/tun/netstack v0.0.0-20220614153727-d82b4ba8619f +replace golang.zx2c4.com/wireguard/tun/netstack => github.com/coder/wireguard-go/tun/netstack v0.0.0-20220823170024-a78136eb0cab // https://github.com/pion/udp/pull/73 replace github.com/pion/udp => github.com/mafredri/udp v0.1.2-0.20220805105907-b2872e92e98d @@ -44,6 +44,13 @@ replace github.com/pion/udp => github.com/mafredri/udp v0.1.2-0.20220805105907-b // https://github.com/hashicorp/hc-install/pull/68 replace github.com/hashicorp/hc-install => github.com/mafredri/hc-install v0.4.1-0.20220727132613-e91868e28445 +// https://github.com/tcnksm/go-httpstat/pull/29 +replace github.com/tcnksm/go-httpstat => github.com/kylecarbs/go-httpstat v0.0.0-20220831233600-c91452099472 + +// There are a few minor changes we make to Tailscale that we're slowly upstreaming. Compare here: +// https://github.com/tailscale/tailscale/compare/main...coder:tailscale:main +replace tailscale.com => github.com/coder/tailscale v1.1.1-0.20220831012541-a77bda274fd6 + require ( cdr.dev/slog v1.4.2-0.20220525200111-18dce5c2cd5f cloud.google.com/go/compute v1.7.0 @@ -129,12 +136,12 @@ require ( go.uber.org/atomic v1.9.0 go.uber.org/goleak v1.1.12 golang.org/x/crypto v0.0.0-20220517005047-85d78b3ac167 - golang.org/x/exp v0.0.0-20220414153411-bcd21879b8fd + golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 golang.org/x/net v0.0.0-20220630215102-69896b714898 golang.org/x/oauth2 v0.0.0-20220622183110-fd043fe589d2 golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f - golang.org/x/sys v0.0.0-20220708085239-5a0f0661e09d + golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 golang.org/x/text v0.3.7 golang.org/x/tools v0.1.11 @@ -146,7 +153,7 @@ require ( google.golang.org/protobuf v1.28.0 gopkg.in/natefinch/lumberjack.v2 v2.0.0 gopkg.in/yaml.v3 v3.0.1 - inet.af/netaddr v0.0.0-20220617031823-097006376321 + gvisor.dev/gvisor v0.0.0-20220801230058-850e42eb4444 k8s.io/utils v0.0.0-20220210201930-3a6ce19ff2f9 nhooyr.io/websocket v1.8.7 storj.io/drpc v0.0.33-0.20220622181519-9206537a4db7 @@ -154,6 +161,7 @@ require ( ) require ( + filippo.io/edwards25519 v1.0.0-rc.1 // indirect github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect github.com/Microsoft/go-winio v0.5.2 // indirect github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect @@ -183,6 +191,7 @@ require ( github.com/docker/go-connections v0.4.0 // indirect github.com/docker/go-units v0.4.0 // indirect github.com/elastic/go-windows v1.0.0 // indirect + github.com/fxamacker/cbor/v2 v2.4.0 // indirect github.com/ghodss/yaml v1.0.0 // indirect github.com/gin-gonic/gin v1.7.0 // indirect github.com/go-logr/logr v1.2.3 // indirect @@ -206,6 +215,7 @@ require ( github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/hashicorp/hcl v1.0.0 // indirect + github.com/hdevalence/ed25519consensus v0.0.0-20220222234857-c00d1f31bab3 // indirect github.com/imdario/mergo v0.3.12 // indirect github.com/inconshreveable/mousetrap v1.0.0 // indirect github.com/insomniacslk/dhcp v0.0.0-20211209223715-7d93572ebe8e // indirect @@ -265,10 +275,11 @@ require ( github.com/tailscale/netlink v1.1.1-0.20211101221916-cabfb018fe85 // indirect github.com/tcnksm/go-httpstat v0.2.0 // indirect github.com/tdewolff/parse/v2 v2.6.0 // indirect - github.com/u-root/uio v0.0.0-20210528151154-e40b768296a7 // indirect + github.com/u-root/uio v0.0.0-20220204230159-dac05f7d2cb4 // indirect github.com/vektah/gqlparser/v2 v2.4.4 // indirect github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54 // indirect github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect + github.com/x448/float16 v0.8.4 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect github.com/xeipuuv/gojsonschema v1.2.0 // indirect @@ -281,9 +292,8 @@ require ( go.opentelemetry.io/otel/exporters/otlp/internal/retry v1.8.0 // indirect go.opentelemetry.io/otel/metric v0.31.0 // indirect go.opentelemetry.io/proto/otlp v0.18.0 // indirect - go4.org/intern v0.0.0-20211027215823-ae77deb06f29 // indirect go4.org/mem v0.0.0-20210711025021-927187094b94 // indirect - go4.org/unsafe/assume-no-moving-gc v0.0.0-20220617031537-928513b29760 // indirect + go4.org/netipx v0.0.0-20220725152314-7e7bdc8411bf golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 // indirect golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 // indirect golang.zx2c4.com/wireguard/windows v0.4.10 // indirect @@ -292,6 +302,5 @@ require ( google.golang.org/grpc v1.47.0 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect - gvisor.dev/gvisor v0.0.0-20220407223209-21871174d445 // indirect howett.net/plist v1.0.0 // indirect ) diff --git a/go.sum b/go.sum index b40a80417b716..f479632bf4fb2 100644 --- a/go.sum +++ b/go.sum @@ -72,6 +72,8 @@ cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3f cloud.google.com/go/storage v1.22.1/go.mod h1:S8N1cAStu7BOeFfE8KAQzmyyLkK8p/vmRq6kuBTW58Y= contrib.go.opencensus.io/exporter/stackdriver v0.13.4/go.mod h1:aXENhDJ1Y4lIg4EUaVTwzvYETVNZk10Pu26tevFKLUc= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= +filippo.io/edwards25519 v1.0.0-rc.1 h1:m0VOOB23frXZvAOK44usCgLWvtsxIoMCTBGJZlpmGfU= +filippo.io/edwards25519 v1.0.0-rc.1/go.mod h1:N1IkdkCkiLB6tki+MYJoSx2JTY9NUlxZE7eHn5EwJns= filippo.io/mkcert v1.4.3 h1:axpnmtrZMM8u5Hf4N3UXxboGemMOV+Tn+e+pkHM6E3o= gioui.org v0.0.0-20210308172011-57750fc8a0a6/go.mod h1:RSH6KIUZ0p2xy5zHDxgAM4zumjgTw83q2ge/PI+yyw8= github.com/AdaLogics/go-fuzz-headers v0.0.0-20210715213245-6c3934b029d8/go.mod h1:CzsSbkDixRphAF5hS6wbMKq0eI6ccJRb7/A0M6JBnwg= @@ -350,8 +352,10 @@ github.com/coder/glog v1.0.1-0.20220322161911-7365fe7f2cd1 h1:UqBrPWSYvRI2s5RtOu github.com/coder/glog v1.0.1-0.20220322161911-7365fe7f2cd1/go.mod h1:EWib/APOK0SL3dFbYqvxE3UYd8E6s1ouQ7iEp/0LWV4= github.com/coder/retry v1.3.0 h1:5lAAwt/2Cm6lVmnfBY7sOMXcBOwcwJhmV5QGSELIVWY= github.com/coder/retry v1.3.0/go.mod h1:tXuRgZgWjUnU5LZPT4lJh4ew2elUhexhlnXzrJWdyFY= -github.com/coder/wireguard-go/tun/netstack v0.0.0-20220614153727-d82b4ba8619f h1:wsrm7hB9cvvw8ybX41YjzXDMbpo3gjlesw7oHYhtZW4= -github.com/coder/wireguard-go/tun/netstack v0.0.0-20220614153727-d82b4ba8619f/go.mod h1:PerNzwKlnUUbKSRrSghbyhE9wEl3xakvPY9muprxlv8= +github.com/coder/tailscale v1.1.1-0.20220831012541-a77bda274fd6 h1://ApBDDh58hFwMe0AzlgqJrGhzu6Rjk8fQXrR+mbhYE= +github.com/coder/tailscale v1.1.1-0.20220831012541-a77bda274fd6/go.mod h1:MO+tWkQp2YIF3KBnnej/mQvgYccRS5Xk/IrEpZ4Z3BU= +github.com/coder/wireguard-go/tun/netstack v0.0.0-20220823170024-a78136eb0cab h1:9yEvRWXXfyKzXu8AqywCi+tFZAoqCy4wVcsXwuvZNMc= +github.com/coder/wireguard-go/tun/netstack v0.0.0-20220823170024-a78136eb0cab/go.mod h1:TCJ66NtXh3urJotTdoYQOHHkyE899vOQl5TuF+WLSes= github.com/containerd/aufs v0.0.0-20200908144142-dab0cbea06f4/go.mod h1:nukgQABAEopAHvB6j7cnP5zJ+/3aVcE7hCYqvIwAHyE= github.com/containerd/aufs v0.0.0-20201003224125-76a6863f2989/go.mod h1:AkGGQs9NM2vtYHaUen+NljV0/baGCAPELGm2q9ZXpWU= github.com/containerd/aufs v0.0.0-20210316121734-20793ff83c97/go.mod h1:kL5kd6KM5TzQjR79jljyi4olc1Vrx6XBlcyj3gNv2PU= @@ -554,7 +558,6 @@ github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3 github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= -github.com/dvyukov/go-fuzz v0.0.0-20210103155950-6a8e9d1f2415/go.mod h1:11Gm+ccJnvAhCNLlf5+cS9KjtbaD5I5zaZpFMsTHWTw= github.com/edsrzf/mmap-go v0.0.0-20170320065105-0bce6a688712/go.mod h1:YO35OhQPt3KJa3ryjFM5Bs14WD66h8eGKpfaBNrHW5M= github.com/elastic/go-sysinfo v1.8.1 h1:4Yhj+HdV6WjbCRgGdZpPJ8lZQlXZLKDAeIkmQ/VRvi4= github.com/elastic/go-sysinfo v1.8.1/go.mod h1:JfllUnzoQV/JRYymbH3dO1yggI3mV2oTKSXsDHM+uIM= @@ -616,6 +619,8 @@ github.com/fsouza/fake-gcs-server v1.17.0/go.mod h1:D1rTE4YCyHFNa99oyJJ5HyclvN/0 github.com/fullsailor/pkcs7 v0.0.0-20190404230743-d7302db945fa h1:RDBNVkRviHZtvDvId8XSGPu3rmpmSe+wKRcEWNgsfWU= github.com/fullsailor/pkcs7 v0.0.0-20190404230743-d7302db945fa/go.mod h1:KnogPXtdwXqoenmZCw6S+25EAm2MkxbG0deNDu4cbSA= github.com/fullstorydev/grpcurl v1.6.0/go.mod h1:ZQ+ayqbKMJNhzLmbpCiurTVlaK2M/3nqZCxaQ2Ze/sM= +github.com/fxamacker/cbor/v2 v2.4.0 h1:ri0ArlOR+5XunOP8CRUowT0pSJOwhW098ZCUyskZD88= +github.com/fxamacker/cbor/v2 v2.4.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= github.com/fzipp/gocyclo v0.3.1/go.mod h1:DJHO6AUmbdqj2ET4Z9iArSuwWgYDRryYt2wASxc7x3E= github.com/gabriel-vasile/mimetype v1.3.1/go.mod h1:fA8fi6KUiG7MgQQ+mEWotXoEOvmxRtOJlERCzSmRvr8= github.com/gabriel-vasile/mimetype v1.4.0/go.mod h1:fA8fi6KUiG7MgQQ+mEWotXoEOvmxRtOJlERCzSmRvr8= @@ -1029,6 +1034,8 @@ github.com/hashicorp/terraform-json v0.14.0 h1:sh9iZ1Y8IFJLx+xQiKHGud6/TSUCM0N8e github.com/hashicorp/terraform-json v0.14.0/go.mod h1:5A9HIWPkk4e5aeeXIBbkcOvaZbIYnAIkEyqP2pNSckM= github.com/hashicorp/yamux v0.0.0-20220718163420-dd80a7ee44ce h1:7FO+LmZwiG/eDsBWo50ZeqV5PoH0gwiM1mxFajXAkas= github.com/hashicorp/yamux v0.0.0-20220718163420-dd80a7ee44ce/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ= +github.com/hdevalence/ed25519consensus v0.0.0-20220222234857-c00d1f31bab3 h1:aSVUgRRRtOrZOC1fYmY9gV0e9z/Iu+xNVSASWjsuyGU= +github.com/hdevalence/ed25519consensus v0.0.0-20220222234857-c00d1f31bab3/go.mod h1:5PC6ZNPde8bBqU/ewGZig35+UIZtw9Ytxez8/q5ZyFE= github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u7lxST/RaJw+cv273q79D81Xbog= github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= @@ -1203,6 +1210,8 @@ github.com/kulti/thelper v0.4.0/go.mod h1:vMu2Cizjy/grP+jmsvOFDx1kYP6+PD1lqg4Yu5 github.com/kunwardeep/paralleltest v1.0.3/go.mod h1:vLydzomDFpk7yu5UX02RmP0H8QfRPOV/oFhWN85Mjb4= github.com/kylecarbs/embedded-postgres v1.17.1-0.20220615202325-461532cecd3a h1:uOnis+HNE6e6eR17YlqzKk51GDahd7E/FacnZxS8h8w= github.com/kylecarbs/embedded-postgres v1.17.1-0.20220615202325-461532cecd3a/go.mod h1:0B+3bPsMvcNgR9nN+bdM2x9YaNYDnf3ksUqYp1OAub0= +github.com/kylecarbs/go-httpstat v0.0.0-20220831233600-c91452099472 h1:KXbxoQY9tOxgacpw0vbHWfIb56Xuzgi0Oql5yr6RYaA= +github.com/kylecarbs/go-httpstat v0.0.0-20220831233600-c91452099472/go.mod h1:MdOqT7wdglCBuU45KzMIvO+xdKlCGHPUWwdTxytqHBU= github.com/kylecarbs/opencensus-go v0.23.1-0.20220307014935-4d0325a68f8b h1:1Y1X6aR78kMEQE1iCjQodB3lA7VO4jB88Wf8ZrzXSsA= github.com/kylecarbs/opencensus-go v0.23.1-0.20220307014935-4d0325a68f8b/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E= github.com/kylecarbs/readline v0.0.0-20220211054233-0d62993714c8/go.mod h1:n/KX1BZoN1m9EwoXkn/xAV4fd3k8c++gGBsgLONaPOY= @@ -1781,8 +1790,6 @@ github.com/tailscale/goupnp v1.0.1-0.20210804011211-c64d0f06ea05/go.mod h1:PdCqy github.com/tailscale/netlink v1.1.1-0.20211101221916-cabfb018fe85 h1:zrsUcqrG2uQSPhaUPjUQwozcRdDdSxxqhNgNZ3drZFk= github.com/tailscale/netlink v1.1.1-0.20211101221916-cabfb018fe85/go.mod h1:NzVQi3Mleb+qzq8VmcWpSkcSYxXIg0DkI6XDzpVkhJ0= github.com/tchap/go-patricia v2.2.6+incompatible/go.mod h1:bmLyhP68RS6kStMGxByiQ23RP/odRBOTVjwp2cDyi6I= -github.com/tcnksm/go-httpstat v0.2.0 h1:rP7T5e5U2HfmOBmZzGgGZjBQ5/GluWUylujl0tJ04I0= -github.com/tcnksm/go-httpstat v0.2.0/go.mod h1:s3JVJFtQxtBEBC9dwcdTTXS9xFnM3SXAZwPG41aurT8= github.com/tdakkota/asciicheck v0.0.0-20200416200610-e657995f937b/go.mod h1:yHp0ai0Z9gUljN3o0xMhYJnH/IcvkdTBOX2fmJ93JEM= github.com/tdewolff/parse/v2 v2.6.0 h1:f2D7w32JtqjCv6SczWkfwK+m15et42qEtDnZXHoNY70= github.com/tdewolff/parse/v2 v2.6.0/go.mod h1:WzaJpRSbwq++EIQHYIRTpbYKNA3gn9it1Ik++q4zyho= @@ -1804,8 +1811,8 @@ github.com/tomasen/realip v0.0.0-20180522021738-f0c99a92ddce/go.mod h1:o8v6yHRoi github.com/tommy-muehle/go-mnd/v2 v2.4.0/go.mod h1:WsUAkMJMYww6l/ufffCD3m+P7LEvr8TnZn9lwVDlgzw= github.com/tv42/httpunix v0.0.0-20191220191345-2ba4b9c3382c/go.mod h1:hzIxponao9Kjc7aWznkXaL4U4TWaDSs8zcsY4Ka08nM= github.com/u-root/uio v0.0.0-20210528114334-82958018845c/go.mod h1:LpEX5FO/cB+WF4TYGY1V5qktpaZLkKkSegbr0V4eYXA= -github.com/u-root/uio v0.0.0-20210528151154-e40b768296a7 h1:XMAtQHwKjWHIRwg+8Nj/rzUomQY1q6cM3ncA0wP8GU4= -github.com/u-root/uio v0.0.0-20210528151154-e40b768296a7/go.mod h1:LpEX5FO/cB+WF4TYGY1V5qktpaZLkKkSegbr0V4eYXA= +github.com/u-root/uio v0.0.0-20220204230159-dac05f7d2cb4 h1:hl6sK6aFgTLISijk6xIzeqnPzQcsLqqvL6vEfTPinME= +github.com/u-root/uio v0.0.0-20220204230159-dac05f7d2cb4/go.mod h1:LpEX5FO/cB+WF4TYGY1V5qktpaZLkKkSegbr0V4eYXA= github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc= github.com/ugorji/go v1.1.7 h1:/68gy2h+1mWMrwZFeD1kQialdSzAb432dtpeJ42ovdo= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= @@ -1845,6 +1852,8 @@ github.com/vmihailenco/msgpack/v4 v4.3.12/go.mod h1:gborTTJjAo/GWTqqRjrLCn9pgNN+ github.com/vmihailenco/tagparser v0.1.1/go.mod h1:OeAg3pn3UbLjkWt+rN9oFYB6u/cQgqMEUPoW2WPyhdI= github.com/willf/bitset v1.1.11-0.20200630133818-d5bec3311243/go.mod h1:RjeCKbqT1RxIR/KWY6phxZiaY1IyutSBfGjNPySAYV4= github.com/willf/bitset v1.1.11/go.mod h1:83CECat5yLh5zVOf4P1ErAgKA5UDvKtgyUABdr3+MjI= +github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= +github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/xanzy/go-gitlab v0.15.0/go.mod h1:8zdQa/ri1dfn8eS3Ir1SyfvOKlw7WBJ8DVThkpGiXrs= github.com/xanzy/ssh-agent v0.3.0/go.mod h1:3s9xbODqPuuhK9JV1R321M/FlMZSBvE5aY6eAcqrDh0= github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= @@ -1982,12 +1991,11 @@ go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= go.uber.org/zap v1.17.0/go.mod h1:MXVU+bhUf/A7Xi2HNOnopQOrmycQ5Ih87HtOu4q5SSo= go4.org/intern v0.0.0-20211027215823-ae77deb06f29 h1:UXLjNohABv4S58tHmeuIZDO6e3mHpW2Dx33gaNt03LE= -go4.org/intern v0.0.0-20211027215823-ae77deb06f29/go.mod h1:cS2ma+47FKrLPdXFpr7CuxiTW3eyJbWew4qx0qtQWDA= go4.org/mem v0.0.0-20210711025021-927187094b94 h1:OAAkygi2Js191AJP1Ds42MhJRgeofeKGjuoUqNp1QC4= go4.org/mem v0.0.0-20210711025021-927187094b94/go.mod h1:reUoABIJ9ikfM5sgtSF3Wushcza7+WeD01VB9Lirh3g= -go4.org/unsafe/assume-no-moving-gc v0.0.0-20211027215541-db492cf91b37/go.mod h1:FftLjUGFEDu5k8lt0ddY+HcrH/qU/0qk+H8j9/nTl3E= +go4.org/netipx v0.0.0-20220725152314-7e7bdc8411bf h1:IdwJUzqoIo5lkr2EOyKoe5qipUaEjbOKKY5+fzPBZ3A= +go4.org/netipx v0.0.0-20220725152314-7e7bdc8411bf/go.mod h1:+QXzaoURFd0rGDIjDNpyIkv+F9R7EmeKorvlKRnhqgA= go4.org/unsafe/assume-no-moving-gc v0.0.0-20220617031537-928513b29760 h1:FyBZqvoA/jbNzuAWLQE2kG820zMAkcilx6BMjGbL/E4= -go4.org/unsafe/assume-no-moving-gc v0.0.0-20220617031537-928513b29760/go.mod h1:FftLjUGFEDu5k8lt0ddY+HcrH/qU/0qk+H8j9/nTl3E= golang.org/x/crypto v0.0.0-20171113213409-9f005a07e0d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180501155221-613d6eafa307/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= @@ -2046,8 +2054,8 @@ golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= golang.org/x/exp v0.0.0-20200331195152-e8c3332aa8e5/go.mod h1:4M0jN8W1tt0AVLNr8HDosyJCDCDuyL9N9+3m7wDWgKw= -golang.org/x/exp v0.0.0-20220414153411-bcd21879b8fd h1:zVFyTKZN/Q7mNRWSs1GOYnHM9NiFSJ54YVRsD0rNWT4= -golang.org/x/exp v0.0.0-20220414153411-bcd21879b8fd/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE= +golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e h1:+WEEuIdZHnUeJJmEUjyYC2gfUMj69yZXw17EnHg/otA= +golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e/go.mod h1:Kr81I6Kryrl9sr8s2FK3vxD90NdsKWRuOIl2O4CvYbA= golang.org/x/exp/typeparams v0.0.0-20220328175248-053ad81199eb h1:fP6C8Xutcp5AlakmT/SkQot0pMicROAsEX7OfNPuG10= golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= @@ -2377,8 +2385,8 @@ golang.org/x/sys v0.0.0-20220608164250-635b8c9b7f68/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220610221304-9f5ed59c137d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220624220833-87e55d714810/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220708085239-5a0f0661e09d h1:/m5NbqQelATgoSPVC2Z23sR4kVNokFwDDyWh/3rGY+I= -golang.org/x/sys v0.0.0-20220708085239-5a0f0661e09d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 h1:0A+M6Uqn+Eje4kHMK80dtF3JCXC4ykBgQG4Fe06QRhQ= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= @@ -2822,8 +2830,8 @@ gotest.tools/v3 v3.0.2/go.mod h1:3SzNCllyD9/Y+b5r9JIKQ474KzkZyqLqEfYqMsX94Bk= gotest.tools/v3 v3.0.3/go.mod h1:Z7Lb0S5l+klDB31fvDQX8ss/FlKDxtlFlw3Oa8Ymbl8= gotest.tools/v3 v3.1.0/go.mod h1:fHy7eyTmJFO5bQbUsEGQ1v4m2J3Jz9eWL54TP2/ZuYQ= gotest.tools/v3 v3.2.0 h1:I0DwBVMGAx26dttAj1BtJLAkVGncrkkUXfJLC4Flt/I= -gvisor.dev/gvisor v0.0.0-20220407223209-21871174d445 h1:pLNQCtMzh4O6rdhoUeWHuutt4yMft+B9Cgw/bezWchE= -gvisor.dev/gvisor v0.0.0-20220407223209-21871174d445/go.mod h1:tWwEcFvJavs154OdjFCw78axNrsDlz4Zh8jvPqwcpGI= +gvisor.dev/gvisor v0.0.0-20220801230058-850e42eb4444 h1:0d3ygmOM5RgQB8rmsZNeAY/7Q98fKt1HrGO2XIp4pDI= +gvisor.dev/gvisor v0.0.0-20220801230058-850e42eb4444/go.mod h1:TIvkJD0sxe8pIob3p6T8IzxXunlp6yfgktvTNp+DGNM= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= @@ -2835,8 +2843,6 @@ honnef.co/go/tools v0.2.1/go.mod h1:lPVVZ2BS5TfnjLyizF7o7hv7j9/L+8cZY2hLyjP9cGY= honnef.co/go/tools v0.4.0-0.dev.0.20220404092545-59d7a2877f83 h1:lZ9GIYaU+o5+X6ST702I/Ntyq9Y2oIMZ42rBQpem64A= howett.net/plist v1.0.0 h1:7CrbWYbPPO/PyNy38b2EB/+gYbjCe2DXBxgtOOZbSQM= howett.net/plist v1.0.0/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g= -inet.af/netaddr v0.0.0-20220617031823-097006376321 h1:B4dC8ySKTQXasnjDTMsoCMf1sQG4WsMej0WXaHxunmU= -inet.af/netaddr v0.0.0-20220617031823-097006376321/go.mod h1:OIezDfdzOgFhuw4HuWapWq2e9l0H9tK4F1j+ETRtF3k= k8s.io/api v0.20.1/go.mod h1:KqwcCVogGxQY3nBlRpwt+wpAMF/KjaCc7RpywacvqUo= k8s.io/api v0.20.4/go.mod h1:++lNL1AJMkDymriNniQsWRkMDzRaX2Y/POTUi8yvqYQ= k8s.io/api v0.20.6/go.mod h1:X9e8Qag6JV/bL5G6bU8sdVRltWKmdHsFUGS3eVndqE8= @@ -2931,5 +2937,3 @@ sigs.k8s.io/yaml v1.2.0/go.mod h1:yfXDCHCao9+ENCvLSE62v9VSji2MKu5jeNfTrofGhJc= software.sslmate.com/src/go-pkcs12 v0.0.0-20210415151418-c5206de65a78 h1:SqYE5+A2qvRhErbsXFfUEUmpWEKxxRSMgGLkvRAFOV4= storj.io/drpc v0.0.33-0.20220622181519-9206537a4db7 h1:6jIp39oQGZMjfrG3kiafK2tcL0Fbprh2kvaoJNfhvuM= storj.io/drpc v0.0.33-0.20220622181519-9206537a4db7/go.mod h1:6rcOyR/QQkSTX/9L5ZGtlZaE2PtXTTZl8d+ulSeeYEg= -tailscale.com v1.26.2 h1:EBR0DXblI2Rx3mPe/YU29oZbQLnC8BtJYUTufmEygUY= -tailscale.com v1.26.2/go.mod h1:KM47Ct0eNTFJqoazXV5XRdMnnWtD2HHDciY9RwyqweE= diff --git a/peer/peerwg/derp.go b/peer/peerwg/derp.go deleted file mode 100644 index 25c677fdfb2bd..0000000000000 --- a/peer/peerwg/derp.go +++ /dev/null @@ -1,67 +0,0 @@ -package peerwg - -import ( - "net" - - "tailscale.com/tailcfg" - "tailscale.com/wgengine/magicsock" -) - -// This is currently set to use Tailscale's DERP server in DFW while we build in -// our own support for DERP servers. -var DerpMap = &tailcfg.DERPMap{ - Regions: map[int]*tailcfg.DERPRegion{ - 9: { - RegionID: 9, - RegionCode: "dfw", - RegionName: "Dallas", - Avoid: false, - Nodes: []*tailcfg.DERPNode{ - { - Name: "9a", - RegionID: 9, - HostName: "derp9.tailscale.com", - CertName: "", - IPv4: "207.148.3.137", - IPv6: "2001:19f0:6401:1d9c:5400:2ff:feef:bb82", - STUNPort: 0, - STUNOnly: false, - DERPPort: 0, - InsecureForTests: false, - STUNTestIP: "", - }, - { - Name: "9c", - RegionID: 9, - HostName: "derp9c.tailscale.com", - CertName: "", - IPv4: "155.138.243.219", - IPv6: "2001:19f0:6401:fe7:5400:3ff:fe8d:6d9c", - STUNPort: 0, - STUNOnly: false, - DERPPort: 0, - InsecureForTests: false, - STUNTestIP: "", - }, - { - Name: "9b", - RegionID: 9, - HostName: "derp9b.tailscale.com", - CertName: "", - IPv4: "144.202.67.195", - IPv6: "2001:19f0:6401:eb5:5400:3ff:fe8d:6d9b", - STUNPort: 0, - STUNOnly: false, - DERPPort: 0, - InsecureForTests: false, - STUNTestIP: "", - }, - }, - }, - }, - OmitDefaultRegions: true, -} - -// DefaultDerpHome is the ipv4 representation of a DERP server. The port is the -// DERP id. We only support using DERP 9 for now. -var DefaultDerpHome = net.JoinHostPort(magicsock.DerpMagicIP, "9") diff --git a/peer/peerwg/handshake.go b/peer/peerwg/handshake.go deleted file mode 100644 index 08fdc234052ad..0000000000000 --- a/peer/peerwg/handshake.go +++ /dev/null @@ -1,94 +0,0 @@ -package peerwg - -import ( - "bytes" - "strconv" - - "github.com/google/uuid" - "golang.org/x/xerrors" - "inet.af/netaddr" - "tailscale.com/types/key" -) - -const handshakeSeparator byte = '|' - -// Handshake is a message received from a wireguard peer, indicating -// it would like to connect. -type Handshake struct { - // Recipient is the uuid of the agent that the message was intended for. - Recipient uuid.UUID `json:"recipient"` - // DiscoPublicKey is the disco public key of the peer. - DiscoPublicKey key.DiscoPublic `json:"disco"` - // NodePublicKey is the public key of the peer. - NodePublicKey key.NodePublic `json:"public"` - // IPv6 is the IPv6 address of the peer. - IPv6 netaddr.IP `json:"ipv6"` -} - -// HandshakeRecipientHint parses the first part of a serialized -// Handshake to quickly determine if the message is meant for the -// provided recipient. -func HandshakeRecipientHint(agentID []byte, msg []byte) (bool, error) { - idx := bytes.Index(msg, []byte{handshakeSeparator}) - if idx == -1 { - return false, xerrors.Errorf("invalid peer message, no separator") - } - - return bytes.Equal(agentID, msg[:idx]), nil -} - -func (h *Handshake) UnmarshalText(text []byte) error { - sp := bytes.Split(text, []byte{handshakeSeparator}) - if len(sp) != 4 { - return xerrors.Errorf("expected 4 parts, got %d", len(sp)) - } - - err := h.Recipient.UnmarshalText(sp[0]) - if err != nil { - return xerrors.Errorf("parse recipient: %w", err) - } - - err = h.DiscoPublicKey.UnmarshalText(sp[1]) - if err != nil { - return xerrors.Errorf("parse disco: %w", err) - } - - err = h.NodePublicKey.UnmarshalText(sp[2]) - if err != nil { - return xerrors.Errorf("parse public: %w", err) - } - - h.IPv6, err = netaddr.ParseIP(string(sp[3])) - if err != nil { - return xerrors.Errorf("parse ipv6: %w", err) - } - - return nil -} - -func (h Handshake) MarshalText() ([]byte, error) { - const expectedLen = 223 - var buf bytes.Buffer - buf.Grow(expectedLen) - - recp, _ := h.Recipient.MarshalText() - _, _ = buf.Write(recp) - _ = buf.WriteByte(handshakeSeparator) - - disco, _ := h.DiscoPublicKey.MarshalText() - _, _ = buf.Write(disco) - _ = buf.WriteByte(handshakeSeparator) - - pub, _ := h.NodePublicKey.MarshalText() - _, _ = buf.Write(pub) - _ = buf.WriteByte(handshakeSeparator) - - ipv6 := h.IPv6.StringExpanded() - _, _ = buf.WriteString(ipv6) - - // Ensure we're always allocating exactly enough. - if buf.Len() != expectedLen { - panic("buffer length mismatch: want 223, got " + strconv.Itoa(buf.Len())) - } - return buf.Bytes(), nil -} diff --git a/peer/peerwg/ssh.go b/peer/peerwg/ssh.go deleted file mode 100644 index 9ffe8cc92c816..0000000000000 --- a/peer/peerwg/ssh.go +++ /dev/null @@ -1,38 +0,0 @@ -package peerwg - -import ( - "context" - "net" - - "golang.org/x/crypto/ssh" - "golang.org/x/xerrors" - "inet.af/netaddr" -) - -func (n *Network) SSH(ctx context.Context, ip netaddr.IP) (net.Conn, error) { - netConn, err := n.Netstack.DialContextTCP(ctx, netaddr.IPPortFrom(ip, 12212)) - if err != nil { - return nil, xerrors.Errorf("dial agent ssh: %w", err) - } - - return netConn, nil -} - -func (n *Network) SSHClient(ctx context.Context, ip netaddr.IP) (*ssh.Client, error) { - netConn, err := n.SSH(ctx, ip) - if err != nil { - return nil, xerrors.Errorf("ssh: %w", err) - } - - sshConn, channels, requests, err := ssh.NewClientConn(netConn, "localhost:22", &ssh.ClientConfig{ - // SSH host validation isn't helpful, because obtaining a peer - // connection already signifies user-intent to dial a workspace. - // #nosec - HostKeyCallback: ssh.InsecureIgnoreHostKey(), - }) - if err != nil { - return nil, xerrors.Errorf("new ssh client conn: %w", err) - } - - return ssh.NewClient(sshConn, channels, requests), nil -} diff --git a/peer/peerwg/wireguard.go b/peer/peerwg/wireguard.go deleted file mode 100644 index b210b2b70dadc..0000000000000 --- a/peer/peerwg/wireguard.go +++ /dev/null @@ -1,441 +0,0 @@ -package peerwg - -import ( - "context" - "fmt" - "hash/fnv" - "io" - "log" - "net" - "strconv" - "sync" - "time" - - "github.com/google/uuid" - "github.com/tabbed/pqtype" - "golang.org/x/xerrors" - "inet.af/netaddr" - "tailscale.com/ipn/ipnstate" - "tailscale.com/net/dns" - "tailscale.com/net/netns" - "tailscale.com/net/tsdial" - "tailscale.com/tailcfg" - "tailscale.com/types/ipproto" - "tailscale.com/types/key" - tslogger "tailscale.com/types/logger" - "tailscale.com/types/netmap" - "tailscale.com/wgengine" - "tailscale.com/wgengine/filter" - "tailscale.com/wgengine/magicsock" - "tailscale.com/wgengine/monitor" - "tailscale.com/wgengine/netstack" - "tailscale.com/wgengine/router" - "tailscale.com/wgengine/wgcfg/nmcfg" - - "cdr.dev/slog" -) - -var Logf tslogger.Logf = log.Printf - -func init() { - // Globally disable network namespacing. - // All networking happens in userspace. - netns.SetEnabled(false) -} - -func UUIDToInet(uid uuid.UUID) pqtype.Inet { - uid = privateUUID(uid) - - return pqtype.Inet{ - Valid: true, - IPNet: net.IPNet{ - IP: uid[:], - Mask: net.CIDRMask(128, 128), - }, - } -} - -func UUIDToNetaddr(uid uuid.UUID) netaddr.IP { - return netaddr.IPFrom16(privateUUID(uid)) -} - -// privateUUID sets the uid to have the tailscale private ipv6 prefix. -func privateUUID(uid uuid.UUID) uuid.UUID { - // fd7a:115c:a1e0 - uid[0] = 0xfd - uid[1] = 0x7a - uid[2] = 0x11 - uid[3] = 0x5c - uid[4] = 0xa1 - uid[5] = 0xe0 - return uid -} - -type Network struct { - mu sync.Mutex - logger slog.Logger - - Netstack *netstack.Impl - magicSock *magicsock.Conn - netMap *netmap.NetworkMap - router *router.Config - wgEngine wgengine.Engine - - // listeners is a map of listening sockets that will be forwarded traffic - // from the wireguard interface. - listeners map[listenKey]*listener - - DiscoPublicKey key.DiscoPublic - NodePrivateKey key.NodePrivate -} - -// New constructs a Wireguard network that filters traffic -// to destinations matching the addresses provided. -func New(logger slog.Logger, addresses []netaddr.IPPrefix) (*Network, error) { - nodePrivateKey := key.NewNode() - nodePublicKey := nodePrivateKey.Public() - id, stableID := nodeIDs(nodePublicKey) - - netMap := &netmap.NetworkMap{ - NodeKey: nodePublicKey, - PrivateKey: nodePrivateKey, - Addresses: addresses, - PacketFilter: []filter.Match{{ - // Allow any protocol! - IPProto: []ipproto.Proto{ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6, ipproto.SCTP}, - // Allow traffic sourced from anywhere. - Srcs: []netaddr.IPPrefix{ - netaddr.IPPrefixFrom(netaddr.IPv4(0, 0, 0, 0), 0), - netaddr.IPPrefixFrom(netaddr.IPv6Unspecified(), 0), - }, - // Allow traffic to route anywhere. - Dsts: []filter.NetPortRange{ - { - Net: netaddr.IPPrefixFrom(netaddr.IPv4(0, 0, 0, 0), 0), - Ports: filter.PortRange{ - First: 0, - Last: 65535, - }, - }, - { - Net: netaddr.IPPrefixFrom(netaddr.IPv6Unspecified(), 0), - Ports: filter.PortRange{ - First: 0, - Last: 65535, - }, - }, - }, - Caps: []filter.CapMatch{}, - }}, - } - // Identify itself as a node on the network with the addresses provided. - netMap.SelfNode = &tailcfg.Node{ - ID: id, - StableID: stableID, - Key: nodePublicKey, - Addresses: netMap.Addresses, - AllowedIPs: append(netMap.Addresses, netaddr.MustParseIPPrefix("::/0")), - Endpoints: []string{}, - DERP: DefaultDerpHome, - } - - wgMonitor, err := monitor.New(Logf) - if err != nil { - return nil, xerrors.Errorf("create link monitor: %w", err) - } - - dialer := new(tsdial.Dialer) - dialer.Logf = Logf - // Create a wireguard engine in userspace. - engine, err := wgengine.NewUserspaceEngine(Logf, wgengine.Config{ - LinkMonitor: wgMonitor, - Dialer: dialer, - }) - if err != nil { - return nil, xerrors.Errorf("create wgengine: %w", err) - } - - // This is taken from Tailscale: - // https://github.com/tailscale/tailscale/blob/0f05b2c13ff0c305aa7a1655fa9c17ed969d65be/tsnet/tsnet.go#L247-L255 - // nolint - tunDev, magicConn, dnsManager, ok := engine.(wgengine.InternalsGetter).GetInternals() - if !ok { - return nil, xerrors.New("could not get wgengine internals") - } - - // Update the keys for the magic connection! - err = magicConn.SetPrivateKey(nodePrivateKey) - if err != nil { - return nil, xerrors.Errorf("set node private key: %w", err) - } - netMap.SelfNode.DiscoKey = magicConn.DiscoPublicKey() - - // Create the networking stack. - // This is called to route connections. - netStack, err := netstack.Create(Logf, tunDev, engine, magicConn, dialer, dnsManager) - if err != nil { - return nil, xerrors.Errorf("create netstack: %w", err) - } - netStack.ProcessLocalIPs = true - netStack.ProcessSubnets = true - dialer.UseNetstackForIP = func(ip netaddr.IP) bool { - _, ok := engine.PeerForIP(ip) - return ok - } - dialer.NetstackDialTCP = func(ctx context.Context, dst netaddr.IPPort) (net.Conn, error) { - return netStack.DialContextTCP(ctx, dst) - } - err = netStack.Start() - if err != nil { - return nil, xerrors.Errorf("start netstack: %w", err) - } - engine = wgengine.NewWatchdog(engine) - - // Update the wireguard configuration to allow traffic to flow. - cfg, err := nmcfg.WGCfg(netMap, Logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, netMap.SelfNode.StableID) - if err != nil { - return nil, xerrors.Errorf("create wgcfg: %w", err) - } - - rtr := &router.Config{ - LocalAddrs: cfg.Addresses, - } - err = engine.Reconfig(cfg, rtr, &dns.Config{}, &tailcfg.Debug{}) - if err != nil { - return nil, xerrors.Errorf("reconfig: %w", err) - } - - engine.SetDERPMap(DerpMap) - engine.SetNetworkMap(copyNetMap(netMap)) - - ipb := netaddr.IPSetBuilder{} - for _, addr := range netMap.Addresses { - ipb.AddPrefix(addr) - } - ips, _ := ipb.IPSet() - - iplb := netaddr.IPSetBuilder{} - ipl, _ := iplb.IPSet() - engine.SetFilter(filter.New(netMap.PacketFilter, ips, ipl, nil, Logf)) - - wn := &Network{ - logger: logger, - NodePrivateKey: nodePrivateKey, - DiscoPublicKey: magicConn.DiscoPublicKey(), - wgEngine: engine, - Netstack: netStack, - magicSock: magicConn, - netMap: netMap, - router: rtr, - listeners: map[listenKey]*listener{}, - } - netStack.ForwardTCPIn = wn.forwardTCP - - return wn, nil -} - -// forwardTCP handles incoming connections from Wireguard in userspace. -func (n *Network) forwardTCP(conn net.Conn, port uint16) { - n.mu.Lock() - listener, ok := n.listeners[listenKey{"tcp", "", fmt.Sprint(port)}] - n.mu.Unlock() - if !ok { - // No in-memory listener exists, forward to host. - n.forwardTCPToLocalHandler(conn, port) - return - } - - timer := time.NewTimer(time.Second) - defer timer.Stop() - select { - case listener.conn <- conn: - case <-timer.C: - _ = conn.Close() - } -} - -// forwardTCPToLocalHandler forwards the provided net.Conn to the -// matching port bound to localhost. -func (n *Network) forwardTCPToLocalHandler(c net.Conn, port uint16) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - defer c.Close() - - dialAddrStr := net.JoinHostPort("127.0.0.1", strconv.Itoa(int(port))) - var stdDialer net.Dialer - server, err := stdDialer.DialContext(ctx, "tcp", dialAddrStr) - if err != nil { - n.logger.Debug(ctx, "dial local port", slog.F("port", port), slog.Error(err)) - return - } - defer server.Close() - - connClosed := make(chan error, 2) - go func() { - _, err := io.Copy(server, c) - connClosed <- err - }() - go func() { - _, err := io.Copy(c, server) - connClosed <- err - }() - err = <-connClosed - if err != nil { - n.logger.Debug(ctx, "proxy connection closed with error", slog.Error(err)) - } - n.logger.Debug(ctx, "forwarded connection closed", slog.F("local_addr", dialAddrStr)) -} - -// AddPeer allows connections from another Wireguard instance with the -// handshake credentials. -func (n *Network) AddPeer(handshake Handshake) error { - n.mu.Lock() - defer n.mu.Unlock() - - // If the peer already exists in the network map, do nothing. - for _, p := range n.netMap.Peers { - if p.Key == handshake.NodePublicKey { - n.logger.Debug(context.Background(), "peer already in netmap", slog.F("peer", handshake.NodePublicKey.ShortString())) - return nil - } - } - - // The Tailscale engine owns this slice, so we need to copy to make - // modifications. - peers := append(([]*tailcfg.Node)(nil), n.netMap.Peers...) - - id, stableID := nodeIDs(handshake.NodePublicKey) - peers = append(peers, &tailcfg.Node{ - ID: id, - StableID: stableID, - Name: handshake.NodePublicKey.String() + ".com", - Key: handshake.NodePublicKey, - DiscoKey: handshake.DiscoPublicKey, - Addresses: []netaddr.IPPrefix{netaddr.IPPrefixFrom(handshake.IPv6, 128)}, - AllowedIPs: []netaddr.IPPrefix{netaddr.IPPrefixFrom(handshake.IPv6, 128)}, - DERP: DefaultDerpHome, - Endpoints: []string{DefaultDerpHome}, - }) - - n.netMap.Peers = peers - - cfg, err := nmcfg.WGCfg(n.netMap, Logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, tailcfg.StableNodeID("nBBoJZ5CNTRL")) - if err != nil { - return xerrors.Errorf("create wgcfg: %w", err) - } - - err = n.wgEngine.Reconfig(cfg, n.router, &dns.Config{}, &tailcfg.Debug{}) - if err != nil { - return xerrors.Errorf("reconfig: %w", err) - } - - // Always give the Tailscale engine a copy of our network map. - n.wgEngine.SetNetworkMap(copyNetMap(n.netMap)) - return nil -} - -// Ping sends a discovery ping to the provided peer. -// The peer address must be connected before a successful ping will work. -func (n *Network) Ping(ip netaddr.IP) *ipnstate.PingResult { - ch := make(chan *ipnstate.PingResult) - n.wgEngine.Ping(ip, tailcfg.PingDisco, func(pr *ipnstate.PingResult) { - ch <- pr - }) - return <-ch -} - -// Listener returns a net.Listener in userspace that can be used to accept -// connections from the Wireguard network to the specified address. If a -// listener exists for a given address, all connections will be forwarded to the -// listener instead of being routed to the host. -func (n *Network) Listen(network, addr string) (net.Listener, error) { - host, port, err := net.SplitHostPort(addr) - if err != nil { - return nil, xerrors.Errorf("split addr host port: %w", err) - } - - lkey := listenKey{network, host, port} - ln := &listener{ - wn: n, - key: lkey, - addr: addr, - - conn: make(chan net.Conn, 1), - } - - n.mu.Lock() - defer n.mu.Unlock() - - if _, ok := n.listeners[lkey]; ok { - return nil, xerrors.Errorf("listener already open for %s, %s", network, addr) - } - n.listeners[lkey] = ln - - return ln, nil -} - -func (n *Network) Close() error { - // Close all listeners. - for _, l := range n.listeners { - _ = l.Close() - } - - // Close the Wireguard netstack and engine. - _ = n.Netstack.Close() - n.wgEngine.Close() - - return nil -} - -type listenKey struct { - network string - host string - port string -} - -type listener struct { - wn *Network - key listenKey - addr string - conn chan net.Conn -} - -func (ln *listener) Accept() (net.Conn, error) { - c, ok := <-ln.conn - if !ok { - return nil, xerrors.Errorf("tsnet: %w", net.ErrClosed) - } - return c, nil -} - -func (ln *listener) Addr() net.Addr { return addr{ln} } -func (ln *listener) Close() error { - ln.wn.mu.Lock() - defer ln.wn.mu.Unlock() - - if v, ok := ln.wn.listeners[ln.key]; ok && v == ln { - delete(ln.wn.listeners, ln.key) - close(ln.conn) - } - - return nil -} - -type addr struct{ ln *listener } - -func (a addr) Network() string { return a.ln.key.network } -func (a addr) String() string { return a.ln.addr } - -// nodeIDs generates Tailscale node IDs for the provided public key. -func nodeIDs(public key.NodePublic) (tailcfg.NodeID, tailcfg.StableNodeID) { - idhash := fnv.New64() - pub, _ := public.MarshalText() - _, _ = idhash.Write(pub) - - return tailcfg.NodeID(idhash.Sum64()), tailcfg.StableNodeID(pub) -} - -func copyNetMap(nm *netmap.NetworkMap) *netmap.NetworkMap { - nmCopy := *nm - return &nmCopy -} diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 655787e0abb72..fe1d809f77ea7 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -175,6 +175,12 @@ export interface CreateWorkspaceRequest { readonly parameter_values?: CreateParameterRequest[] } +// From codersdk/workspaceresources.go +export interface DERPRegion { + readonly preferred: boolean + readonly latency_ms: number +} + // From codersdk/features.go export interface Entitlements { readonly features: Record @@ -509,17 +515,9 @@ export interface WorkspaceAgent { readonly operating_system: string readonly startup_script?: string readonly directory?: string - readonly apps: WorkspaceApp[] - // Named type "tailscale.com/types/key.NodePublic" unknown, using "any" - // eslint-disable-next-line @typescript-eslint/no-explicit-any - readonly wireguard_public_key: any - // Named type "tailscale.com/types/key.DiscoPublic" unknown, using "any" - // eslint-disable-next-line @typescript-eslint/no-explicit-any - readonly disco_public_key: any - // Named type "inet.af/netaddr.IPPrefix" unknown, using "any" - // eslint-disable-next-line @typescript-eslint/no-explicit-any - readonly ipv6: any readonly version: string + readonly apps: WorkspaceApp[] + readonly latency: Record } // From codersdk/workspaceagents.go @@ -527,6 +525,13 @@ export interface WorkspaceAgentAuthenticateResponse { readonly session_token: string } +// From codersdk/workspaceagents.go +export interface WorkspaceAgentConnectionInfo { + // Named type "tailscale.com/tailcfg.DERPMap" unknown, using "any" + // eslint-disable-next-line @typescript-eslint/no-explicit-any + readonly derp_map?: any +} + // From codersdk/workspaceresources.go export interface WorkspaceAgentInstanceMetadata { readonly jail_orchestrator: string diff --git a/site/src/testHelpers/entities.ts b/site/src/testHelpers/entities.ts index d9d61b67ce995..b11e43131c2a2 100644 --- a/site/src/testHelpers/entities.ts +++ b/site/src/testHelpers/entities.ts @@ -302,10 +302,8 @@ export const MockWorkspaceAgent: TypesGen.WorkspaceAgent = { resource_id: "", status: "connected", updated_at: "", - wireguard_public_key: "", - disco_public_key: "", - ipv6: "", version: MockBuildInfo.version, + latency: {}, } export const MockWorkspaceAgentDisconnected: TypesGen.WorkspaceAgent = { diff --git a/tailnet/conn.go b/tailnet/conn.go new file mode 100644 index 0000000000000..84aed3fc1ae81 --- /dev/null +++ b/tailnet/conn.go @@ -0,0 +1,516 @@ +package tailnet + +import ( + "context" + "fmt" + "io" + "net" + "net/netip" + "strconv" + "sync" + "time" + + "github.com/google/uuid" + "go4.org/netipx" + "golang.org/x/xerrors" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "tailscale.com/hostinfo" + "tailscale.com/ipn/ipnstate" + "tailscale.com/net/dns" + "tailscale.com/net/netns" + "tailscale.com/net/tsdial" + "tailscale.com/net/tstun" + "tailscale.com/tailcfg" + "tailscale.com/types/ipproto" + "tailscale.com/types/key" + tslogger "tailscale.com/types/logger" + "tailscale.com/types/netmap" + "tailscale.com/wgengine" + "tailscale.com/wgengine/filter" + "tailscale.com/wgengine/magicsock" + "tailscale.com/wgengine/monitor" + "tailscale.com/wgengine/netstack" + "tailscale.com/wgengine/router" + "tailscale.com/wgengine/wgcfg/nmcfg" + + "github.com/coder/coder/cryptorand" + + "cdr.dev/slog" +) + +func init() { + // Globally disable network namespacing. + // All networking happens in userspace. + netns.SetEnabled(false) +} + +type Options struct { + Addresses []netip.Prefix + DERPMap *tailcfg.DERPMap + + Logger slog.Logger +} + +// NewConn constructs a new Wireguard server that will accept connections from the addresses provided. +func NewConn(options *Options) (*Conn, error) { + if options == nil { + options = &Options{} + } + if len(options.Addresses) == 0 { + return nil, xerrors.New("At least one IP range must be provided") + } + if options.DERPMap == nil { + return nil, xerrors.New("DERPMap must be provided") + } + nodePrivateKey := key.NewNode() + nodePublicKey := nodePrivateKey.Public() + + netMap := &netmap.NetworkMap{ + NodeKey: nodePublicKey, + PrivateKey: nodePrivateKey, + Addresses: options.Addresses, + PacketFilter: []filter.Match{{ + // Allow any protocol! + IPProto: []ipproto.Proto{ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6, ipproto.SCTP}, + // Allow traffic sourced from anywhere. + Srcs: []netip.Prefix{ + netip.PrefixFrom(netip.AddrFrom4([4]byte{}), 0), + netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 0), + }, + // Allow traffic to route anywhere. + Dsts: []filter.NetPortRange{ + { + Net: netip.PrefixFrom(netip.AddrFrom4([4]byte{}), 0), + Ports: filter.PortRange{ + First: 0, + Last: 65535, + }, + }, + { + Net: netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 0), + Ports: filter.PortRange{ + First: 0, + Last: 65535, + }, + }, + }, + Caps: []filter.CapMatch{}, + }}, + } + nodeID, err := cryptorand.Int63() + if err != nil { + return nil, xerrors.Errorf("generate node id: %w", err) + } + // This is used by functions below to identify the node via key + netMap.SelfNode = &tailcfg.Node{ + ID: tailcfg.NodeID(nodeID), + Key: nodePublicKey, + Addresses: options.Addresses, + AllowedIPs: options.Addresses, + } + + wireguardMonitor, err := monitor.New(Logger(options.Logger.Named("wgmonitor"))) + if err != nil { + return nil, xerrors.Errorf("create wireguard link monitor: %w", err) + } + + dialer := &tsdial.Dialer{ + Logf: Logger(options.Logger), + } + wireguardEngine, err := wgengine.NewUserspaceEngine(Logger(options.Logger.Named("wgengine")), wgengine.Config{ + LinkMonitor: wireguardMonitor, + Dialer: dialer, + }) + if err != nil { + return nil, xerrors.Errorf("create wgengine: %w", err) + } + dialer.UseNetstackForIP = func(ip netip.Addr) bool { + _, ok := wireguardEngine.PeerForIP(ip) + return ok + } + + // This is taken from Tailscale: + // https://github.com/tailscale/tailscale/blob/0f05b2c13ff0c305aa7a1655fa9c17ed969d65be/tsnet/tsnet.go#L247-L255 + wireguardInternals, ok := wireguardEngine.(wgengine.InternalsGetter) + if !ok { + return nil, xerrors.Errorf("wireguard engine isn't the correct type %T", wireguardEngine) + } + tunDevice, magicConn, dnsManager, ok := wireguardInternals.GetInternals() + if !ok { + return nil, xerrors.New("failed to get wireguard internals") + } + + // Update the keys for the magic connection! + err = magicConn.SetPrivateKey(nodePrivateKey) + if err != nil { + return nil, xerrors.Errorf("set node private key: %w", err) + } + netMap.SelfNode.DiscoKey = magicConn.DiscoPublicKey() + + netStack, err := netstack.Create( + Logger(options.Logger.Named("netstack")), tunDevice, wireguardEngine, magicConn, dialer, dnsManager) + if err != nil { + return nil, xerrors.Errorf("create netstack: %w", err) + } + dialer.NetstackDialTCP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) { + return netStack.DialContextTCP(ctx, dst) + } + netStack.ProcessLocalIPs = true + err = netStack.Start() + if err != nil { + return nil, xerrors.Errorf("start netstack: %w", err) + } + wireguardEngine = wgengine.NewWatchdog(wireguardEngine) + + // Update the wireguard configuration to allow traffic to flow. + wireguardConfig, err := nmcfg.WGCfg(netMap, Logger(options.Logger.Named("wgconfig")), netmap.AllowSingleHosts, "") + if err != nil { + return nil, xerrors.Errorf("create wgcfg: %w", err) + } + + wireguardRouter := &router.Config{ + LocalAddrs: wireguardConfig.Addresses, + } + err = wireguardEngine.Reconfig(wireguardConfig, wireguardRouter, &dns.Config{}, &tailcfg.Debug{}) + if err != nil { + return nil, xerrors.Errorf("reconfig: %w", err) + } + + wireguardEngine.SetDERPMap(options.DERPMap) + netMapCopy := *netMap + wireguardEngine.SetNetworkMap(&netMapCopy) + + localIPSet := netipx.IPSetBuilder{} + for _, addr := range netMap.Addresses { + localIPSet.AddPrefix(addr) + } + localIPs, _ := localIPSet.IPSet() + logIPSet := netipx.IPSetBuilder{} + logIPs, _ := logIPSet.IPSet() + wireguardEngine.SetFilter(filter.New(netMap.PacketFilter, localIPs, logIPs, nil, Logger(options.Logger.Named("packet-filter")))) + server := &Conn{ + closed: make(chan struct{}), + logger: options.Logger, + magicConn: magicConn, + dialer: dialer, + listeners: map[listenKey]*listener{}, + tunDevice: tunDevice, + netMap: netMap, + netStack: netStack, + wireguardMonitor: wireguardMonitor, + wireguardRouter: wireguardRouter, + wireguardEngine: wireguardEngine, + } + netStack.ForwardTCPIn = server.forwardTCP + return server, nil +} + +// IP generates a new IP with a static service prefix. +func IP() netip.Addr { + // This is Tailscale's ephemeral service prefix. + // This can be changed easily later-on, because + // all of our nodes are ephemeral. + // fd7a:115c:a1e0 + uid := uuid.New() + uid[0] = 0xfd + uid[1] = 0x7a + uid[2] = 0x11 + uid[3] = 0x5c + uid[4] = 0xa1 + uid[5] = 0xe0 + return netip.AddrFrom16(uid) +} + +// Conn is an actively listening Wireguard connection. +type Conn struct { + mutex sync.Mutex + closed chan struct{} + logger slog.Logger + + dialer *tsdial.Dialer + tunDevice *tstun.Wrapper + netMap *netmap.NetworkMap + netStack *netstack.Impl + magicConn *magicsock.Conn + wireguardMonitor *monitor.Mon + wireguardRouter *router.Config + wireguardEngine wgengine.Engine + listeners map[listenKey]*listener + + lastMutex sync.Mutex + // It's only possible to store these values via status functions, + // so the values must be stored for retrieval later on. + lastEndpoints []string + lastPreferredDERP int + lastDERPLatency map[string]float64 +} + +// SetNodeCallback is triggered when a network change occurs and peer +// renegotiation may be required. Clients should constantly be emitting +// node changes. +func (c *Conn) SetNodeCallback(callback func(node *Node)) { + makeNode := func() *Node { + return &Node{ + ID: c.netMap.SelfNode.ID, + Key: c.netMap.SelfNode.Key, + Addresses: c.netMap.SelfNode.Addresses, + AllowedIPs: c.netMap.SelfNode.AllowedIPs, + DiscoKey: c.magicConn.DiscoPublicKey(), + Endpoints: c.lastEndpoints, + PreferredDERP: c.lastPreferredDERP, + DERPLatency: c.lastDERPLatency, + } + } + c.magicConn.SetNetInfoCallback(func(ni *tailcfg.NetInfo) { + c.lastMutex.Lock() + c.lastPreferredDERP = ni.PreferredDERP + c.lastDERPLatency = ni.DERPLatency + node := makeNode() + c.lastMutex.Unlock() + callback(node) + }) + c.wireguardEngine.SetStatusCallback(func(s *wgengine.Status, err error) { + if err != nil { + return + } + endpoints := make([]string, 0, len(s.LocalAddrs)) + for _, addr := range s.LocalAddrs { + endpoints = append(endpoints, addr.Addr.String()) + } + c.lastMutex.Lock() + c.lastEndpoints = endpoints + node := makeNode() + c.lastMutex.Unlock() + callback(node) + }) +} + +// SetDERPMap updates the DERPMap of a connection. +func (c *Conn) SetDERPMap(derpMap *tailcfg.DERPMap) { + c.mutex.Lock() + defer c.mutex.Unlock() + c.wireguardEngine.SetDERPMap(derpMap) +} + +// UpdateNodes connects with a set of peers. This can be constantly updated, +// and peers will continually be reconnected as necessary. +func (c *Conn) UpdateNodes(nodes []*Node) error { + c.mutex.Lock() + defer c.mutex.Unlock() + peerMap := map[tailcfg.NodeID]*tailcfg.Node{} + status := c.Status() + for _, peer := range c.netMap.Peers { + if peerStatus, ok := status.Peer[peer.Key]; ok { + // Clear out inactive connections! + if !peerStatus.Active { + continue + } + } + peerMap[peer.ID] = peer + } + for _, node := range nodes { + peerMap[node.ID] = &tailcfg.Node{ + ID: node.ID, + Key: node.Key, + DiscoKey: node.DiscoKey, + Addresses: node.Addresses, + AllowedIPs: node.AllowedIPs, + Endpoints: node.Endpoints, + DERP: fmt.Sprintf("%s:%d", tailcfg.DerpMagicIP, node.PreferredDERP), + Hostinfo: hostinfo.New().View(), + } + } + c.netMap.Peers = make([]*tailcfg.Node, 0, len(peerMap)) + for _, peer := range peerMap { + c.netMap.Peers = append(c.netMap.Peers, peer) + } + cfg, err := nmcfg.WGCfg(c.netMap, Logger(c.logger.Named("wgconfig")), netmap.AllowSingleHosts, "") + if err != nil { + return xerrors.Errorf("update wireguard config: %w", err) + } + err = c.wireguardEngine.Reconfig(cfg, c.wireguardRouter, &dns.Config{}, &tailcfg.Debug{}) + if err != nil { + return xerrors.Errorf("reconfig: %w", err) + } + netMapCopy := *c.netMap + c.wireguardEngine.SetNetworkMap(&netMapCopy) + return nil +} + +// Status returns the current ipnstate of a connection. +func (c *Conn) Status() *ipnstate.Status { + sb := &ipnstate.StatusBuilder{} + c.magicConn.UpdateStatus(sb) + return sb.Status() +} + +// Ping sends a ping to the Wireguard engine. +func (c *Conn) Ping(ip netip.Addr, pingType tailcfg.PingType, cb func(*ipnstate.PingResult)) { + c.wireguardEngine.Ping(ip, pingType, cb) +} + +// Closed is a channel that ends when the connection has +// been closed. +func (c *Conn) Closed() <-chan struct{} { + return c.closed +} + +// Close shuts down the Wireguard connection. +func (c *Conn) Close() error { + c.mutex.Lock() + defer c.mutex.Unlock() + select { + case <-c.closed: + return nil + default: + } + for _, l := range c.listeners { + _ = l.closeNoLock() + } + close(c.closed) + _ = c.dialer.Close() + _ = c.magicConn.Close() + _ = c.netStack.Close() + _ = c.wireguardMonitor.Close() + _ = c.tunDevice.Close() + c.wireguardEngine.Close() + return nil +} + +// This and below is taken _mostly_ verbatim from Tailscale: +// https://github.com/tailscale/tailscale/blob/c88bd53b1b7b2fcf7ba302f2e53dd1ce8c32dad4/tsnet/tsnet.go#L459-L494 + +// Listen announces only on the Tailscale network. +// It will start the server if it has not been started yet. +func (c *Conn) Listen(network, addr string) (net.Listener, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, xerrors.Errorf("wgnet: %w", err) + } + lk := listenKey{network, host, port} + ln := &listener{ + s: c, + key: lk, + addr: addr, + + conn: make(chan net.Conn), + } + c.mutex.Lock() + if c.listeners == nil { + c.listeners = map[listenKey]*listener{} + } + if _, ok := c.listeners[lk]; ok { + c.mutex.Unlock() + return nil, xerrors.Errorf("wgnet: listener already open for %s, %s", network, addr) + } + c.listeners[lk] = ln + c.mutex.Unlock() + return ln, nil +} + +func (c *Conn) DialContextTCP(ctx context.Context, ipp netip.AddrPort) (*gonet.TCPConn, error) { + return c.netStack.DialContextTCP(ctx, ipp) +} + +func (c *Conn) DialContextUDP(ctx context.Context, ipp netip.AddrPort) (*gonet.UDPConn, error) { + return c.netStack.DialContextUDP(ctx, ipp) +} + +func (c *Conn) forwardTCP(conn net.Conn, port uint16) { + c.mutex.Lock() + ln, ok := c.listeners[listenKey{"tcp", "", fmt.Sprint(port)}] + c.mutex.Unlock() + if !ok { + c.forwardTCPToLocal(conn, port) + return + } + t := time.NewTimer(time.Second) + defer t.Stop() + select { + case ln.conn <- conn: + case <-t.C: + _ = conn.Close() + } +} + +func (c *Conn) forwardTCPToLocal(conn net.Conn, port uint16) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + defer conn.Close() + + dialAddrStr := net.JoinHostPort("127.0.0.1", strconv.Itoa(int(port))) + var stdDialer net.Dialer + server, err := stdDialer.DialContext(ctx, "tcp", dialAddrStr) + if err != nil { + c.logger.Debug(ctx, "dial local port", slog.F("port", port), slog.Error(err)) + return + } + defer server.Close() + + connClosed := make(chan error, 2) + go func() { + _, err := io.Copy(server, conn) + connClosed <- err + }() + go func() { + _, err := io.Copy(conn, server) + connClosed <- err + }() + select { + case err = <-connClosed: + case <-c.closed: + return + } + if err != nil { + c.logger.Debug(ctx, "proxy connection closed with error", slog.Error(err)) + } + c.logger.Debug(ctx, "forwarded connection closed", slog.F("local_addr", dialAddrStr)) +} + +type listenKey struct { + network string + host string + port string +} + +type listener struct { + s *Conn + key listenKey + addr string + conn chan net.Conn +} + +func (ln *listener) Accept() (net.Conn, error) { + c, ok := <-ln.conn + if !ok { + return nil, xerrors.Errorf("wgnet: %w", net.ErrClosed) + } + return c, nil +} + +func (ln *listener) Addr() net.Addr { return addr{ln} } +func (ln *listener) Close() error { + ln.s.mutex.Lock() + defer ln.s.mutex.Unlock() + return ln.closeNoLock() +} + +func (ln *listener) closeNoLock() error { + if v, ok := ln.s.listeners[ln.key]; ok && v == ln { + delete(ln.s.listeners, ln.key) + close(ln.conn) + } + return nil +} + +type addr struct{ ln *listener } + +func (a addr) Network() string { return a.ln.key.network } +func (a addr) String() string { return a.ln.addr } + +// Logger converts the Tailscale logging function to use slog. +func Logger(logger slog.Logger) tslogger.Logf { + return tslogger.Logf(func(format string, args ...any) { + logger.Debug(context.Background(), fmt.Sprintf(format, args...)) + }) +} diff --git a/tailnet/conn_test.go b/tailnet/conn_test.go new file mode 100644 index 0000000000000..32c3083604ae5 --- /dev/null +++ b/tailnet/conn_test.go @@ -0,0 +1,83 @@ +package tailnet_test + +import ( + "context" + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/tailnet" + "github.com/coder/coder/tailnet/tailnettest" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func TestTailnet(t *testing.T) { + t.Parallel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + derpMap := tailnettest.RunDERPAndSTUN(t) + t.Run("InstantClose", func(t *testing.T) { + t.Parallel() + conn, err := tailnet.NewConn(&tailnet.Options{ + Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, + Logger: logger.Named("w1"), + DERPMap: derpMap, + }) + require.NoError(t, err) + err = conn.Close() + require.NoError(t, err) + }) + t.Run("Connect", func(t *testing.T) { + t.Parallel() + w1IP := tailnet.IP() + w1, err := tailnet.NewConn(&tailnet.Options{ + Addresses: []netip.Prefix{netip.PrefixFrom(w1IP, 128)}, + Logger: logger.Named("w1"), + DERPMap: derpMap, + }) + require.NoError(t, err) + + w2, err := tailnet.NewConn(&tailnet.Options{ + Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, + Logger: logger.Named("w2"), + DERPMap: derpMap, + }) + require.NoError(t, err) + t.Cleanup(func() { + _ = w1.Close() + _ = w2.Close() + }) + w1.SetNodeCallback(func(node *tailnet.Node) { + w2.UpdateNodes([]*tailnet.Node{node}) + }) + w2.SetNodeCallback(func(node *tailnet.Node) { + w1.UpdateNodes([]*tailnet.Node{node}) + }) + + conn := make(chan struct{}) + go func() { + listener, err := w1.Listen("tcp", ":35565") + assert.NoError(t, err) + defer listener.Close() + nc, err := listener.Accept() + assert.NoError(t, err) + _ = nc.Close() + conn <- struct{}{} + }() + + nc, err := w2.DialContextTCP(context.Background(), netip.AddrPortFrom(w1IP, 35565)) + require.NoError(t, err) + _ = nc.Close() + <-conn + + w1.Close() + w2.Close() + }) +} diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go new file mode 100644 index 0000000000000..db8dffa0ebaef --- /dev/null +++ b/tailnet/coordinator.go @@ -0,0 +1,259 @@ +package tailnet + +import ( + "encoding/json" + "errors" + "io" + "net" + "net/netip" + "sync" + + "github.com/google/uuid" + "golang.org/x/xerrors" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +// Node represents a node in the network. +type Node struct { + ID tailcfg.NodeID `json:"id"` + Key key.NodePublic `json:"key"` + DiscoKey key.DiscoPublic `json:"disco"` + PreferredDERP int `json:"preferred_derp"` + DERPLatency map[string]float64 `json:"derp_latency"` + Addresses []netip.Prefix `json:"addresses"` + AllowedIPs []netip.Prefix `json:"allowed_ips"` + Endpoints []string `json:"endpoints"` +} + +// ServeCoordinator matches the RW structure of a coordinator to exchange node messages. +func ServeCoordinator(conn net.Conn, updateNodes func(node []*Node) error) (func(node *Node), <-chan error) { + errChan := make(chan error, 3) + go func() { + decoder := json.NewDecoder(conn) + for { + var nodes []*Node + err := decoder.Decode(&nodes) + if err != nil { + errChan <- xerrors.Errorf("read: %w", err) + return + } + err = updateNodes(nodes) + if err != nil { + errChan <- xerrors.Errorf("update nodes: %w", err) + } + } + }() + + return func(node *Node) { + data, err := json.Marshal(node) + if err != nil { + errChan <- xerrors.Errorf("marshal node: %w", err) + return + } + _, err = conn.Write(data) + if err != nil { + errChan <- xerrors.Errorf("write: %w", err) + } + }, errChan +} + +// NewCoordinator constructs a new in-memory connection coordinator. +func NewCoordinator() *Coordinator { + return &Coordinator{ + nodes: map[uuid.UUID]*Node{}, + agentSockets: map[uuid.UUID]net.Conn{}, + agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]net.Conn{}, + } +} + +// Coordinator exchanges nodes with agents to establish connections. +// ┌──────────────────┐ ┌────────────────────┐ ┌───────────────────┐ ┌──────────────────┐ +// │tailnet.Coordinate├──►│tailnet.AcceptClient│◄─►│tailnet.AcceptAgent│◄──┤tailnet.Coordinate│ +// └──────────────────┘ └────────────────────┘ └───────────────────┘ └──────────────────┘ +// This coordinator is incompatible with multiple Coder +// replicas as all node data is in-memory. +type Coordinator struct { + mutex sync.Mutex + + // Maps agent and connection IDs to a node. + nodes map[uuid.UUID]*Node + // Maps agent ID to an open socket. + agentSockets map[uuid.UUID]net.Conn + // Maps agent ID to connection ID for sending + // new node data as it comes in! + agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]net.Conn +} + +// Node returns an in-memory node by ID. +func (c *Coordinator) Node(id uuid.UUID) *Node { + c.mutex.Lock() + defer c.mutex.Unlock() + node := c.nodes[id] + return node +} + +// ServeClient accepts a WebSocket connection that wants to +// connect to an agent with the specified ID. +func (c *Coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error { + c.mutex.Lock() + // When a new connection is requested, we update it with the latest + // node of the agent. This allows the connection to establish. + node, ok := c.nodes[agent] + if ok { + data, err := json.Marshal([]*Node{node}) + if err != nil { + c.mutex.Unlock() + return xerrors.Errorf("marshal node: %w", err) + } + _, err = conn.Write(data) + if err != nil { + c.mutex.Unlock() + return xerrors.Errorf("write nodes: %w", err) + } + } + connectionSockets, ok := c.agentToConnectionSockets[agent] + if !ok { + connectionSockets = map[uuid.UUID]net.Conn{} + c.agentToConnectionSockets[agent] = connectionSockets + } + // Insert this connection into a map so the agent + // can publish node updates. + connectionSockets[id] = conn + c.mutex.Unlock() + defer func() { + c.mutex.Lock() + defer c.mutex.Unlock() + // Clean all traces of this connection from the map. + delete(c.nodes, id) + connectionSockets, ok := c.agentToConnectionSockets[agent] + if !ok { + return + } + delete(connectionSockets, id) + if len(connectionSockets) != 0 { + return + } + delete(c.agentToConnectionSockets, agent) + }() + + decoder := json.NewDecoder(conn) + for { + var node Node + err := decoder.Decode(&node) + if errors.Is(err, io.EOF) { + return nil + } + if err != nil { + return xerrors.Errorf("read json: %w", err) + } + c.mutex.Lock() + // Update the node of this client in our in-memory map. + // If an agent entirely shuts down and reconnects, it + // needs to be aware of all clients attempting to + // establish connections. + c.nodes[id] = &node + agentSocket, ok := c.agentSockets[agent] + if !ok { + c.mutex.Unlock() + continue + } + // Write the new node from this client to the actively + // connected agent. + data, err := json.Marshal([]*Node{&node}) + if err != nil { + c.mutex.Unlock() + return xerrors.Errorf("marshal nodes: %w", err) + } + _, err = agentSocket.Write(data) + if errors.Is(err, io.EOF) { + c.mutex.Unlock() + return nil + } + if err != nil { + c.mutex.Unlock() + return xerrors.Errorf("write json: %w", err) + } + c.mutex.Unlock() + } +} + +// ServeAgent accepts a WebSocket connection to an agent that +// listens to incoming connections and publishes node updates. +func (c *Coordinator) ServeAgent(conn net.Conn, id uuid.UUID) error { + c.mutex.Lock() + sockets, ok := c.agentToConnectionSockets[id] + if ok { + // Publish all nodes that want to connect to the + // desired agent ID. + nodes := make([]*Node, 0, len(sockets)) + for targetID := range sockets { + node, ok := c.nodes[targetID] + if !ok { + continue + } + nodes = append(nodes, node) + } + data, err := json.Marshal(nodes) + if err != nil { + c.mutex.Unlock() + return xerrors.Errorf("marshal json: %w", err) + } + _, err = conn.Write(data) + if err != nil { + c.mutex.Unlock() + return xerrors.Errorf("write nodes: %w", err) + } + } + + // If an old agent socket is connected, we close it + // to avoid any leaks. This shouldn't ever occur because + // we expect one agent to be running. + oldAgentSocket, ok := c.agentSockets[id] + if ok { + _ = oldAgentSocket.Close() + } + c.agentSockets[id] = conn + c.mutex.Unlock() + defer func() { + c.mutex.Lock() + defer c.mutex.Unlock() + delete(c.agentSockets, id) + delete(c.nodes, id) + }() + + decoder := json.NewDecoder(conn) + for { + var node Node + err := decoder.Decode(&node) + if errors.Is(err, io.EOF) { + return nil + } + if err != nil { + return xerrors.Errorf("read json: %w", err) + } + c.mutex.Lock() + c.nodes[id] = &node + connectionSockets, ok := c.agentToConnectionSockets[id] + if !ok { + c.mutex.Unlock() + continue + } + data, err := json.Marshal([]*Node{&node}) + if err != nil { + return xerrors.Errorf("marshal nodes: %w", err) + } + // Publish the new node to every listening socket. + var wg sync.WaitGroup + wg.Add(len(connectionSockets)) + for _, connectionSocket := range connectionSockets { + connectionSocket := connectionSocket + go func() { + _, _ = connectionSocket.Write(data) + wg.Done() + }() + } + wg.Wait() + c.mutex.Unlock() + } +} diff --git a/tailnet/coordinator_test.go b/tailnet/coordinator_test.go new file mode 100644 index 0000000000000..f3fdab88d5ef8 --- /dev/null +++ b/tailnet/coordinator_test.go @@ -0,0 +1,148 @@ +package tailnet_test + +import ( + "net" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/tailnet" + "github.com/coder/coder/testutil" +) + +func TestCoordinator(t *testing.T) { + t.Parallel() + t.Run("ClientWithoutAgent", func(t *testing.T) { + t.Parallel() + coordinator := tailnet.NewCoordinator() + client, server := net.Pipe() + sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error { + return nil + }) + id := uuid.New() + closeChan := make(chan struct{}) + go func() { + err := coordinator.ServeClient(server, id, uuid.New()) + assert.NoError(t, err) + close(closeChan) + }() + sendNode(&tailnet.Node{}) + require.Eventually(t, func() bool { + return coordinator.Node(id) != nil + }, testutil.WaitShort, testutil.IntervalFast) + err := client.Close() + require.NoError(t, err) + <-errChan + <-closeChan + }) + + t.Run("AgentWithoutClients", func(t *testing.T) { + t.Parallel() + coordinator := tailnet.NewCoordinator() + client, server := net.Pipe() + sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error { + return nil + }) + id := uuid.New() + closeChan := make(chan struct{}) + go func() { + err := coordinator.ServeAgent(server, id) + assert.NoError(t, err) + close(closeChan) + }() + sendNode(&tailnet.Node{}) + require.Eventually(t, func() bool { + return coordinator.Node(id) != nil + }, testutil.WaitShort, testutil.IntervalFast) + err := client.Close() + require.NoError(t, err) + <-errChan + <-closeChan + }) + + t.Run("AgentWithClient", func(t *testing.T) { + t.Parallel() + coordinator := tailnet.NewCoordinator() + + agentWS, agentServerWS := net.Pipe() + defer agentWS.Close() + agentNodeChan := make(chan []*tailnet.Node) + sendAgentNode, agentErrChan := tailnet.ServeCoordinator(agentWS, func(nodes []*tailnet.Node) error { + agentNodeChan <- nodes + return nil + }) + agentID := uuid.New() + closeAgentChan := make(chan struct{}) + go func() { + err := coordinator.ServeAgent(agentServerWS, agentID) + assert.NoError(t, err) + close(closeAgentChan) + }() + sendAgentNode(&tailnet.Node{}) + require.Eventually(t, func() bool { + return coordinator.Node(agentID) != nil + }, testutil.WaitShort, testutil.IntervalFast) + + clientWS, clientServerWS := net.Pipe() + defer clientWS.Close() + defer clientServerWS.Close() + clientNodeChan := make(chan []*tailnet.Node) + sendClientNode, clientErrChan := tailnet.ServeCoordinator(clientWS, func(nodes []*tailnet.Node) error { + clientNodeChan <- nodes + return nil + }) + clientID := uuid.New() + closeClientChan := make(chan struct{}) + go func() { + err := coordinator.ServeClient(clientServerWS, clientID, agentID) + assert.NoError(t, err) + close(closeClientChan) + }() + agentNodes := <-clientNodeChan + require.Len(t, agentNodes, 1) + sendClientNode(&tailnet.Node{}) + clientNodes := <-agentNodeChan + require.Len(t, clientNodes, 1) + + // Ensure an update to the agent node reaches the client! + sendAgentNode(&tailnet.Node{}) + agentNodes = <-clientNodeChan + require.Len(t, agentNodes, 1) + + // Close the agent WebSocket so a new one can connect. + err := agentWS.Close() + require.NoError(t, err) + <-agentErrChan + <-closeAgentChan + + // Create a new agent connection. This is to simulate a reconnect! + agentWS, agentServerWS = net.Pipe() + defer agentWS.Close() + agentNodeChan = make(chan []*tailnet.Node) + _, agentErrChan = tailnet.ServeCoordinator(agentWS, func(nodes []*tailnet.Node) error { + agentNodeChan <- nodes + return nil + }) + closeAgentChan = make(chan struct{}) + go func() { + err := coordinator.ServeAgent(agentServerWS, agentID) + assert.NoError(t, err) + close(closeAgentChan) + }() + // Ensure the existing listening client sends it's node immediately! + clientNodes = <-agentNodeChan + require.Len(t, clientNodes, 1) + + err = agentWS.Close() + require.NoError(t, err) + <-agentErrChan + <-closeAgentChan + + err = clientWS.Close() + require.NoError(t, err) + <-clientErrChan + <-closeClientChan + }) +} diff --git a/tailnet/derpmap.go b/tailnet/derpmap.go new file mode 100644 index 0000000000000..36ff02b89f9dc --- /dev/null +++ b/tailnet/derpmap.go @@ -0,0 +1,60 @@ +package tailnet + +import ( + "context" + "encoding/json" + "fmt" + "net" + "net/http" + "strconv" + + "golang.org/x/xerrors" + "tailscale.com/tailcfg" +) + +// NewDERPMap constructs a DERPMap from a set of STUN addresses and optionally a remote +// URL to fetch a mapping from e.g. https://controlplane.tailscale.com/derpmap/default. +func NewDERPMap(ctx context.Context, region *tailcfg.DERPRegion, stunAddrs []string, remoteURL string) (*tailcfg.DERPMap, error) { + for index, stunAddr := range stunAddrs { + host, rawPort, err := net.SplitHostPort(stunAddr) + if err != nil { + return nil, xerrors.Errorf("split host port for %q: %w", stunAddr, err) + } + port, err := strconv.Atoi(rawPort) + if err != nil { + return nil, xerrors.Errorf("parse port for %q: %w", stunAddr, err) + } + region.Nodes = append(region.Nodes, &tailcfg.DERPNode{ + Name: fmt.Sprintf("%dstun%d", region.RegionID, index), + RegionID: region.RegionID, + HostName: host, + STUNOnly: true, + STUNPort: port, + }) + } + + derpMap := &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{}, + } + if remoteURL != "" { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, remoteURL, nil) + if err != nil { + return nil, xerrors.Errorf("create request: %w", err) + } + res, err := http.DefaultClient.Do(req) + if err != nil { + return nil, xerrors.Errorf("get derpmap: %w", err) + } + defer res.Body.Close() + err = json.NewDecoder(res.Body).Decode(&derpMap) + if err != nil { + return nil, xerrors.Errorf("fetch derpmap: %w", err) + } + } + _, conflicts := derpMap.Regions[region.RegionID] + if conflicts { + return nil, xerrors.Errorf("the default region ID conflicts with a remote region from %q", remoteURL) + } + derpMap.Regions[region.RegionID] = region + return derpMap, nil +} diff --git a/tailnet/derpmap_test.go b/tailnet/derpmap_test.go new file mode 100644 index 0000000000000..252ccef907c20 --- /dev/null +++ b/tailnet/derpmap_test.go @@ -0,0 +1,60 @@ +package tailnet_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + "tailscale.com/tailcfg" + + "github.com/coder/coder/tailnet" +) + +func TestNewDERPMap(t *testing.T) { + t.Parallel() + t.Run("WithoutRemoteURL", func(t *testing.T) { + t.Parallel() + derpMap, err := tailnet.NewDERPMap(context.Background(), &tailcfg.DERPRegion{ + RegionID: 1, + Nodes: []*tailcfg.DERPNode{{}}, + }, []string{"stun.google.com:2345"}, "") + require.NoError(t, err) + require.Len(t, derpMap.Regions[1].Nodes, 2) + }) + t.Run("RemoteURL", func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + data, _ := json.Marshal(&tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 1: {}, + }, + }) + _, _ = w.Write(data) + })) + t.Cleanup(server.Close) + derpMap, err := tailnet.NewDERPMap(context.Background(), &tailcfg.DERPRegion{ + RegionID: 2, + }, []string{}, server.URL) + require.NoError(t, err) + require.Len(t, derpMap.Regions, 2) + }) + t.Run("RemoteConflicts", func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + data, _ := json.Marshal(&tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 1: {}, + }, + }) + _, _ = w.Write(data) + })) + t.Cleanup(server.Close) + _, err := tailnet.NewDERPMap(context.Background(), &tailcfg.DERPRegion{ + RegionID: 1, + }, []string{}, server.URL) + require.Error(t, err) + }) +} diff --git a/tailnet/tailnettest/tailnettest.go b/tailnet/tailnettest/tailnettest.go new file mode 100644 index 0000000000000..ce020bc0f42bc --- /dev/null +++ b/tailnet/tailnettest/tailnettest.go @@ -0,0 +1,63 @@ +package tailnettest + +import ( + "crypto/tls" + "net" + "net/http" + "net/http/httptest" + "testing" + + "tailscale.com/derp" + "tailscale.com/derp/derphttp" + "tailscale.com/net/stun/stuntest" + "tailscale.com/tailcfg" + "tailscale.com/types/key" + tslogger "tailscale.com/types/logger" + "tailscale.com/types/nettype" + + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/tailnet" +) + +// RunDERPAndSTUN creates a DERP mapping for tests. +func RunDERPAndSTUN(t *testing.T) *tailcfg.DERPMap { + logf := tailnet.Logger(slogtest.Make(t, nil)) + d := derp.NewServer(key.NewNode(), logf) + server := httptest.NewUnstartedServer(derphttp.Handler(d)) + server.Config.ErrorLog = tslogger.StdLogger(logf) + server.Config.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) + server.StartTLS() + + stunAddr, stunCleanup := stuntest.ServeWithPacketListener(t, nettype.Std{}) + t.Cleanup(func() { + server.CloseClientConnections() + server.Close() + d.Close() + stunCleanup() + }) + tcpAddr, ok := server.Listener.Addr().(*net.TCPAddr) + if !ok { + t.FailNow() + } + + return &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 1: { + RegionID: 1, + RegionCode: "test", + RegionName: "Test", + Nodes: []*tailcfg.DERPNode{ + { + Name: "t2", + RegionID: 1, + IPv4: "127.0.0.1", + IPv6: "none", + STUNPort: stunAddr.Port, + DERPPort: tcpAddr.Port, + InsecureForTests: true, + }, + }, + }, + }, + } +} diff --git a/tailnet/tailnettest/tailnettest_test.go b/tailnet/tailnettest/tailnettest_test.go new file mode 100644 index 0000000000000..aebb018a9bcb2 --- /dev/null +++ b/tailnet/tailnettest/tailnettest_test.go @@ -0,0 +1,18 @@ +package tailnettest_test + +import ( + "testing" + + "go.uber.org/goleak" + + "github.com/coder/coder/tailnet/tailnettest" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func TestRunDERPAndSTUN(t *testing.T) { + t.Parallel() + _ = tailnettest.RunDERPAndSTUN(t) +}