diff --git a/agent/agent.go b/agent/agent.go index dd900700c8913..47d9c394a86b9 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -304,7 +304,18 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_ if err != nil { return } - go a.sshServer.HandleConn(conn) + closed := make(chan struct{}) + _ = a.trackConnGoroutine(func() { + select { + case <-network.Closed(): + case <-closed: + } + _ = conn.Close() + }) + _ = a.trackConnGoroutine(func() { + defer close(closed) + a.sshServer.HandleConn(conn) + }) } }); err != nil { return nil, err diff --git a/cli/root.go b/cli/root.go index eb1bc98cf19e2..190bb35ac8a63 100644 --- a/cli/root.go +++ b/cli/root.go @@ -97,7 +97,7 @@ func Core() []*cobra.Command { update(), users(), versionCmd(), - vscodeipcCmd(), + vscodeSSH(), workspaceAgent(), } } diff --git a/cli/ssh.go b/cli/ssh.go index 3c1671c7849b1..0d642754aed59 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -256,7 +256,7 @@ func getWorkspaceAndAgent(ctx context.Context, cmd *cobra.Command, client *coder ) if shuffle { res, err := client.Workspaces(ctx, codersdk.WorkspaceFilter{ - Owner: codersdk.Me, + Owner: userID, }) if err != nil { return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err diff --git a/cli/vscodeipc.go b/cli/vscodeipc.go deleted file mode 100644 index 262fbf69aae8e..0000000000000 --- a/cli/vscodeipc.go +++ /dev/null @@ -1,88 +0,0 @@ -package cli - -import ( - "fmt" - "net" - "net/http" - "net/url" - - "github.com/google/uuid" - "github.com/spf13/cobra" - "golang.org/x/xerrors" - - "github.com/coder/coder/cli/cliflag" - "github.com/coder/coder/cli/vscodeipc" - "github.com/coder/coder/codersdk" -) - -// vscodeipcCmd spawns a local HTTP server on the provided port that listens to messages. -// It's made for use by the Coder VS Code extension. See: https://github.com/coder/vscode-coder -func vscodeipcCmd() *cobra.Command { - var ( - rawURL string - token string - port uint16 - ) - cmd := &cobra.Command{ - Use: "vscodeipc ", - Args: cobra.ExactArgs(1), - SilenceUsage: true, - Hidden: true, - RunE: func(cmd *cobra.Command, args []string) error { - if rawURL == "" { - return xerrors.New("CODER_URL must be set!") - } - // token is validated in a header on each request to prevent - // unauthenticated clients from connecting. - if token == "" { - return xerrors.New("CODER_TOKEN must be set!") - } - listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port)) - if err != nil { - return xerrors.Errorf("listen: %w", err) - } - defer listener.Close() - addr, ok := listener.Addr().(*net.TCPAddr) - if !ok { - return xerrors.Errorf("listener.Addr() is not a *net.TCPAddr: %T", listener.Addr()) - } - url, err := url.Parse(rawURL) - if err != nil { - return err - } - agentID, err := uuid.Parse(args[0]) - if err != nil { - return err - } - client := codersdk.New(url) - client.SetSessionToken(token) - - handler, closer, err := vscodeipc.New(cmd.Context(), client, agentID, nil) - if err != nil { - return err - } - defer closer.Close() - // nolint:gosec - server := http.Server{ - Handler: handler, - } - defer server.Close() - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s\n", addr.String()) - errChan := make(chan error, 1) - go func() { - err := server.Serve(listener) - errChan <- err - }() - select { - case <-cmd.Context().Done(): - return cmd.Context().Err() - case err := <-errChan: - return err - } - }, - } - cliflag.StringVarP(cmd.Flags(), &rawURL, "url", "u", "CODER_URL", "", "The URL of the Coder instance!") - cliflag.StringVarP(cmd.Flags(), &token, "token", "t", "CODER_TOKEN", "", "The session token of the user!") - cmd.Flags().Uint16VarP(&port, "port", "p", 0, "The port to listen on!") - return cmd -} diff --git a/cli/vscodeipc/vscodeipc.go b/cli/vscodeipc/vscodeipc.go deleted file mode 100644 index 9d4e094564da2..0000000000000 --- a/cli/vscodeipc/vscodeipc.go +++ /dev/null @@ -1,313 +0,0 @@ -package vscodeipc - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "strconv" - "strings" - "sync" - "time" - - "github.com/go-chi/chi/v5" - "github.com/google/uuid" - "golang.org/x/crypto/ssh" - "golang.org/x/xerrors" - "tailscale.com/tailcfg" - - "github.com/coder/coder/agent" - "github.com/coder/coder/coderd/httpapi" - "github.com/coder/coder/codersdk" -) - -const AuthHeader = "Coder-IPC-Token" - -// New creates a VS Code IPC client that can be used to communicate with workspaces. -// -// Creating this IPC was required instead of using SSH, because we're unable to get -// connection information to display in the bottom-bar when using SSH. It's possible -// we could jank around this (maybe by using a temporary SSH host), but that's not -// ideal. -// -// This persists a single workspace connection, and lets you execute commands, check -// for network information, and forward ports. -// -// The VS Code extension is located at https://github.com/coder/vscode-coder. The -// extension downloads the slim binary from `/bin/*` and executes `coder vscodeipc` -// which calls this function. This API must maintain backward compatibility with -// the extension to support prior versions of Coder. -func New(ctx context.Context, client *codersdk.Client, agentID uuid.UUID, options *codersdk.DialWorkspaceAgentOptions) (http.Handler, io.Closer, error) { - if options == nil { - options = &codersdk.DialWorkspaceAgentOptions{} - } - // We need this to track upload and download! - options.EnableTrafficStats = true - - agentConn, err := client.DialWorkspaceAgent(ctx, agentID, options) - if err != nil { - return nil, nil, err - } - api := &api{ - agentConn: agentConn, - } - r := chi.NewRouter() - // This is to prevent unauthorized clients on the same machine from executing - // requests on behalf of the workspace. - r.Use(sessionTokenMiddleware(client.SessionToken())) - r.Route("/v1", func(r chi.Router) { - r.Get("/port/{port}", api.port) - r.Get("/network", api.network) - r.Post("/execute", api.execute) - }) - return r, api, nil -} - -type api struct { - agentConn *codersdk.AgentConn - sshClient *ssh.Client - sshClientErr error - sshClientOnce sync.Once - - lastNetwork time.Time -} - -func (api *api) Close() error { - if api.sshClient != nil { - api.sshClient.Close() - } - return api.agentConn.Close() -} - -type NetworkResponse struct { - P2P bool `json:"p2p"` - Latency float64 `json:"latency"` - PreferredDERP string `json:"preferred_derp"` - DERPLatency map[string]float64 `json:"derp_latency"` - UploadBytesSec int64 `json:"upload_bytes_sec"` - DownloadBytesSec int64 `json:"download_bytes_sec"` -} - -// port accepts an HTTP request to dial a port on the workspace agent. -// It uses an HTTP connection upgrade to transfer the connection to TCP. -func (api *api) port(w http.ResponseWriter, r *http.Request) { - port, err := strconv.Atoi(chi.URLParam(r, "port")) - if err != nil { - httpapi.Write(r.Context(), w, http.StatusBadRequest, codersdk.Response{ - Message: "Port must be an integer!", - }) - return - } - remoteConn, err := api.agentConn.DialContext(r.Context(), "tcp", fmt.Sprintf("127.0.0.1:%d", port)) - if err != nil { - httpapi.InternalServerError(w, err) - return - } - defer remoteConn.Close() - - // Upgrade an switch to TCP! - w.Header().Set("Connection", "Upgrade") - w.Header().Set("Upgrade", "tcp") - w.WriteHeader(http.StatusSwitchingProtocols) - - hijacker, ok := w.(http.Hijacker) - if !ok { - httpapi.InternalServerError(w, xerrors.Errorf("unable to hijack connection: %T", w)) - return - } - - localConn, brw, err := hijacker.Hijack() - if err != nil { - httpapi.InternalServerError(w, err) - return - } - defer localConn.Close() - - _ = brw.Flush() - agent.Bicopy(r.Context(), localConn, remoteConn) -} - -// network returns network information about the workspace. -func (api *api) network(w http.ResponseWriter, r *http.Request) { - // Ping the workspace agent to get the latency. - latency, p2p, err := api.agentConn.Ping(r.Context()) - if err != nil { - httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to ping the workspace agent.", - Detail: err.Error(), - }) - return - } - - node := api.agentConn.Node() - derpMap := api.agentConn.DERPMap() - derpLatency := map[string]float64{} - - // Convert DERP region IDs to friendly names for display in the UI. - for rawRegion, latency := range node.DERPLatency { - regionParts := strings.SplitN(rawRegion, "-", 2) - regionID, err := strconv.Atoi(regionParts[0]) - if err != nil { - continue - } - region, found := derpMap.Regions[regionID] - if !found { - // It's possible that a workspace agent is using an old DERPMap - // and reports regions that do not exist. If that's the case, - // report the region as unknown! - region = &tailcfg.DERPRegion{ - RegionID: regionID, - RegionName: fmt.Sprintf("Unnamed %d", regionID), - } - } - // Convert the microseconds to milliseconds. - derpLatency[region.RegionName] = latency * 1000 - } - - totalRx := uint64(0) - totalTx := uint64(0) - for _, stat := range api.agentConn.ExtractTrafficStats() { - totalRx += stat.RxBytes - totalTx += stat.TxBytes - } - // Tracking the time since last request is required because - // ExtractTrafficStats() resets its counters after each call. - dur := time.Since(api.lastNetwork) - uploadSecs := float64(totalTx) / dur.Seconds() - downloadSecs := float64(totalRx) / dur.Seconds() - - api.lastNetwork = time.Now() - - httpapi.Write(r.Context(), w, http.StatusOK, NetworkResponse{ - P2P: p2p, - Latency: float64(latency.Microseconds()) / 1000, - PreferredDERP: derpMap.Regions[node.PreferredDERP].RegionName, - DERPLatency: derpLatency, - UploadBytesSec: int64(uploadSecs), - DownloadBytesSec: int64(downloadSecs), - }) -} - -type ExecuteRequest struct { - Command string `json:"command"` - Stdin string `json:"stdin"` -} - -type ExecuteResponse struct { - Data string `json:"data"` - ExitCode *int `json:"exit_code"` -} - -// execute runs the command provided, streams the output back, and returns the exit code. -func (api *api) execute(w http.ResponseWriter, r *http.Request) { - var req ExecuteRequest - if !httpapi.Read(r.Context(), w, r, &req) { - return - } - api.sshClientOnce.Do(func() { - // The SSH client is lazily created because it's not needed for - // all requests. It's only needed for the execute endpoint. - // - // It's alright if this fails on the first execution, because - // a new instance of this API is created for each remote SSH request. - api.sshClient, api.sshClientErr = api.agentConn.SSHClient(context.Background()) - }) - if api.sshClientErr != nil { - httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to create SSH client.", - Detail: api.sshClientErr.Error(), - }) - return - } - session, err := api.sshClient.NewSession() - if err != nil { - httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to create SSH session.", - Detail: err.Error(), - }) - return - } - defer session.Close() - f, ok := w.(http.Flusher) - if !ok { - httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{ - Message: fmt.Sprintf("http.ResponseWriter is not http.Flusher: %T", w), - }) - return - } - - execWriter := &execWriter{w, f} - session.Stdout = execWriter - session.Stderr = execWriter - session.Stdin = strings.NewReader(req.Stdin) - err = session.Start(req.Command) - if err != nil { - httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to start SSH session.", - Detail: err.Error(), - }) - return - } - err = session.Wait() - - writeExit := func(exitCode int) { - data, _ := json.Marshal(&ExecuteResponse{ - ExitCode: &exitCode, - }) - _, _ = w.Write(data) - f.Flush() - } - - if err != nil { - var exitError *ssh.ExitError - if errors.As(err, &exitError) { - writeExit(exitError.ExitStatus()) - return - } - } - writeExit(0) -} - -type execWriter struct { - w http.ResponseWriter - f http.Flusher -} - -func (e *execWriter) Write(data []byte) (int, error) { - js, err := json.Marshal(&ExecuteResponse{ - Data: string(data), - }) - if err != nil { - return 0, err - } - _, err = e.w.Write(js) - if err != nil { - return 0, err - } - e.f.Flush() - return len(data), nil -} - -func sessionTokenMiddleware(sessionToken string) func(h http.Handler) http.Handler { - return func(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - token := r.Header.Get(AuthHeader) - if token == "" { - httpapi.Write(r.Context(), w, http.StatusUnauthorized, codersdk.Response{ - Message: fmt.Sprintf("A session token must be provided in the `%s` header.", AuthHeader), - }) - return - } - if token != sessionToken { - httpapi.Write(r.Context(), w, http.StatusUnauthorized, codersdk.Response{ - Message: "The session token provided doesn't match the one used to create the client.", - }) - return - } - w.Header().Set("Access-Control-Allow-Origin", "*") - h.ServeHTTP(w, r) - }) - } -} diff --git a/cli/vscodeipc/vscodeipc_test.go b/cli/vscodeipc/vscodeipc_test.go deleted file mode 100644 index 5213c2422be6a..0000000000000 --- a/cli/vscodeipc/vscodeipc_test.go +++ /dev/null @@ -1,202 +0,0 @@ -package vscodeipc_test - -import ( - "bufio" - "bytes" - "context" - "encoding/json" - "fmt" - "net" - "net/http" - "net/http/httptest" - "net/url" - "runtime" - "testing" - - "github.com/google/uuid" - "github.com/spf13/afero" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.uber.org/goleak" - "nhooyr.io/websocket" - - "github.com/coder/coder/agent" - "github.com/coder/coder/cli/vscodeipc" - "github.com/coder/coder/coderd/httpapi" - "github.com/coder/coder/codersdk" - "github.com/coder/coder/tailnet" - "github.com/coder/coder/tailnet/tailnettest" - "github.com/coder/coder/testutil" -) - -func TestMain(m *testing.M) { - goleak.VerifyTestMain(m) -} - -func TestVSCodeIPC(t *testing.T) { - t.Parallel() - ctx := context.Background() - - id := uuid.New() - derpMap := tailnettest.RunDERPAndSTUN(t) - coordinator := tailnet.NewCoordinator() - t.Cleanup(func() { - _ = coordinator.Close() - }) - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.URL.Path { - case fmt.Sprintf("/api/v2/workspaceagents/%s/connection", id): - assert.Equal(t, r.Method, http.MethodGet) - httpapi.Write(ctx, w, http.StatusOK, codersdk.WorkspaceAgentConnectionInfo{ - DERPMap: derpMap, - }) - return - case fmt.Sprintf("/api/v2/workspaceagents/%s/coordinate", id): - assert.Equal(t, r.Method, http.MethodGet) - ws, err := websocket.Accept(w, r, nil) - require.NoError(t, err) - conn := websocket.NetConn(ctx, ws, websocket.MessageBinary) - _ = coordinator.ServeClient(conn, uuid.New(), id) - return - case "/api/v2/workspaceagents/me/version": - assert.Equal(t, r.Method, http.MethodPost) - w.WriteHeader(http.StatusOK) - return - case "/api/v2/workspaceagents/me/metadata": - assert.Equal(t, r.Method, http.MethodGet) - httpapi.Write(ctx, w, http.StatusOK, codersdk.WorkspaceAgentMetadata{ - DERPMap: derpMap, - }) - return - case "/api/v2/workspaceagents/me/coordinate": - assert.Equal(t, r.Method, http.MethodGet) - ws, err := websocket.Accept(w, r, nil) - require.NoError(t, err) - conn := websocket.NetConn(ctx, ws, websocket.MessageBinary) - _ = coordinator.ServeAgent(conn, id) - return - case "/api/v2/workspaceagents/me/report-stats": - assert.Equal(t, r.Method, http.MethodPost) - w.WriteHeader(http.StatusOK) - return - case "/": - w.WriteHeader(http.StatusOK) - return - default: - t.Fatalf("unexpected request %s", r.URL.Path) - } - })) - t.Cleanup(srv.Close) - srvURL, _ := url.Parse(srv.URL) - - client := codersdk.New(srvURL) - token := uuid.New().String() - client.SetSessionToken(token) - agentConn := agent.New(agent.Options{ - Client: client, - Filesystem: afero.NewMemMapFs(), - TempDir: t.TempDir(), - }) - t.Cleanup(func() { - _ = agentConn.Close() - }) - - handler, closer, err := vscodeipc.New(ctx, client, id, nil) - require.NoError(t, err) - t.Cleanup(func() { - _ = closer.Close() - }) - - // Ensure that we're actually connected! - require.Eventually(t, func() bool { - res := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/v1/network", nil) - req.Header.Set(vscodeipc.AuthHeader, token) - handler.ServeHTTP(res, req) - network := &vscodeipc.NetworkResponse{} - err = json.NewDecoder(res.Body).Decode(&network) - assert.NoError(t, err) - return network.Latency != 0 - }, testutil.WaitLong, testutil.IntervalFast) - - _, port, err := net.SplitHostPort(srvURL.Host) - require.NoError(t, err) - - t.Run("NoSessionToken", func(t *testing.T) { - t.Parallel() - res := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/v1/port/%s", port), nil) - handler.ServeHTTP(res, req) - require.Equal(t, http.StatusUnauthorized, res.Code) - }) - - t.Run("MismatchedSessionToken", func(t *testing.T) { - t.Parallel() - res := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/v1/port/%s", port), nil) - req.Header.Set(vscodeipc.AuthHeader, uuid.NewString()) - handler.ServeHTTP(res, req) - require.Equal(t, http.StatusUnauthorized, res.Code) - }) - - t.Run("Port", func(t *testing.T) { - // Tests that the port endpoint can be used for forward traffic. - // For this test, we simply use the already listening httptest server. - t.Parallel() - input, output := net.Pipe() - defer input.Close() - defer output.Close() - res := &hijackable{httptest.NewRecorder(), output} - req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/v1/port/%s", port), nil) - req.Header.Set(vscodeipc.AuthHeader, token) - go handler.ServeHTTP(res, req) - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://127.0.0.1/", nil) - require.NoError(t, err) - client := http.Client{ - Transport: &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return input, nil - }, - }, - } - resp, err := client.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - require.Equal(t, http.StatusOK, resp.StatusCode) - }) - - t.Run("Execute", func(t *testing.T) { - t.Parallel() - if runtime.GOOS == "windows" { - t.Skip("Execute isn't supported on Windows yet!") - return - } - - res := httptest.NewRecorder() - data, _ := json.Marshal(vscodeipc.ExecuteRequest{ - Command: "echo test", - }) - req := httptest.NewRequest(http.MethodPost, "/v1/execute", bytes.NewReader(data)) - req.Header.Set(vscodeipc.AuthHeader, token) - handler.ServeHTTP(res, req) - - decoder := json.NewDecoder(res.Body) - var msg vscodeipc.ExecuteResponse - err = decoder.Decode(&msg) - require.NoError(t, err) - require.Equal(t, "test\n", msg.Data) - err = decoder.Decode(&msg) - require.NoError(t, err) - require.Equal(t, 0, *msg.ExitCode) - }) -} - -type hijackable struct { - *httptest.ResponseRecorder - conn net.Conn -} - -func (h *hijackable) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return h.conn, bufio.NewReadWriter(bufio.NewReader(h.conn), bufio.NewWriter(h.conn)), nil -} diff --git a/cli/vscodeipc_test.go b/cli/vscodeipc_test.go deleted file mode 100644 index 1edb52102841c..0000000000000 --- a/cli/vscodeipc_test.go +++ /dev/null @@ -1,44 +0,0 @@ -package cli_test - -import ( - "io" - "testing" - - "github.com/stretchr/testify/require" - - "github.com/coder/coder/cli/clitest" - "github.com/coder/coder/testutil" -) - -func TestVSCodeIPC(t *testing.T) { - t.Parallel() - // Ensures the vscodeipc command outputs it's running port! - // This signifies to the caller that it's ready to accept requests. - t.Run("PortOutputs", func(t *testing.T) { - t.Parallel() - client, workspace, _ := setupWorkspaceForAgent(t, nil) - cmd, _ := clitest.New(t, "vscodeipc", workspace.LatestBuild.Resources[0].Agents[0].ID.String(), - "--token", client.SessionToken(), "--url", client.URL.String()) - rdr, wtr := io.Pipe() - cmd.SetOut(wtr) - ctx, cancelFunc := testutil.Context(t) - defer cancelFunc() - done := make(chan error, 1) - go func() { - err := cmd.ExecuteContext(ctx) - done <- err - }() - - buf := make([]byte, 64) - require.Eventually(t, func() bool { - t.Log("Looking for address!") - var err error - _, err = rdr.Read(buf) - return err == nil - }, testutil.WaitMedium, testutil.IntervalFast) - t.Logf("Address: %s\n", buf) - - cancelFunc() - <-done - }) -} diff --git a/cli/vscodessh.go b/cli/vscodessh.go new file mode 100644 index 0000000000000..f4f27b832d084 --- /dev/null +++ b/cli/vscodessh.go @@ -0,0 +1,237 @@ +package cli + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/url" + "os" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/spf13/afero" + "github.com/spf13/cobra" + "golang.org/x/xerrors" + "tailscale.com/tailcfg" + + "github.com/coder/coder/codersdk" +) + +// vscodeSSH is used by the Coder VS Code extension to establish +// a connection to a workspace. +// +// This command needs to remain stable for compatibility with +// various VS Code versions, so it's kept separate from our +// standard SSH command. +func vscodeSSH() *cobra.Command { + var ( + sessionTokenFile string + urlFile string + networkInfoDir string + networkInfoInterval time.Duration + ) + cmd := &cobra.Command{ + // A SSH config entry is added by the VS Code extension that + // passes %h to ProxyCommand. The prefix of `coder-vscode--` + // is a magical string represented in our VS Cod extension. + // It's not important here, only the delimiter `--` is. + Use: "vscodessh -->", + Hidden: true, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + if networkInfoDir == "" { + return xerrors.New("network-info-dir must be specified") + } + if sessionTokenFile == "" { + return xerrors.New("session-token-file must be specified") + } + if urlFile == "" { + return xerrors.New("url-file must be specified") + } + + fs, ok := cmd.Context().Value("fs").(afero.Fs) + if !ok { + fs = afero.NewOsFs() + } + + sessionToken, err := afero.ReadFile(fs, sessionTokenFile) + if err != nil { + return xerrors.Errorf("read session token: %w", err) + } + rawURL, err := afero.ReadFile(fs, urlFile) + if err != nil { + return xerrors.Errorf("read url: %w", err) + } + serverURL, err := url.Parse(string(rawURL)) + if err != nil { + return xerrors.Errorf("parse url: %w", err) + } + + ctx, cancel := context.WithCancel(cmd.Context()) + defer cancel() + + err = fs.MkdirAll(networkInfoDir, 0700) + if err != nil { + return xerrors.Errorf("mkdir: %w", err) + } + + client := codersdk.New(serverURL) + client.SetSessionToken(string(sessionToken)) + + parts := strings.Split(args[0], "--") + if len(parts) < 3 { + return xerrors.Errorf("invalid argument format. must be: coder-vscode----") + } + owner := parts[1] + name := parts[2] + + workspace, err := client.WorkspaceByOwnerAndName(ctx, owner, name, codersdk.WorkspaceOptions{}) + if err != nil { + return xerrors.Errorf("find workspace: %w", err) + } + var agent codersdk.WorkspaceAgent + var found bool + for _, resource := range workspace.LatestBuild.Resources { + if len(resource.Agents) == 0 { + continue + } + for _, resourceAgent := range resource.Agents { + // If an agent name isn't included we default to + // the first agent! + if len(parts) != 4 { + agent = resourceAgent + found = true + break + } + if resourceAgent.Name != parts[3] { + continue + } + agent = resourceAgent + found = true + break + } + if found { + break + } + } + agentConn, err := client.DialWorkspaceAgent(ctx, agent.ID, &codersdk.DialWorkspaceAgentOptions{ + EnableTrafficStats: true, + }) + if err != nil { + return xerrors.Errorf("dial workspace agent: %w", err) + } + defer agentConn.Close() + agentConn.AwaitReachable(ctx) + rawSSH, err := agentConn.SSH(ctx) + if err != nil { + return err + } + defer rawSSH.Close() + // Copy SSH traffic over stdio. + go func() { + _, _ = io.Copy(cmd.OutOrStdout(), rawSSH) + }() + go func() { + _, _ = io.Copy(rawSSH, cmd.InOrStdin()) + }() + // The VS Code extension obtains the PID of the SSH process to + // read the file below which contains network information to display. + // + // We get the parent PID because it's assumed `ssh` is calling this + // command via the ProxyCommand SSH option. + networkInfoFilePath := filepath.Join(networkInfoDir, fmt.Sprintf("%d.json", os.Getppid())) + ticker := time.NewTicker(networkInfoInterval) + defer ticker.Stop() + lastCollected := time.Now() + for { + select { + case <-ctx.Done(): + return nil + case <-ticker.C: + } + stats, err := collectNetworkStats(ctx, agentConn, lastCollected) + if err != nil { + return err + } + rawStats, err := json.Marshal(stats) + if err != nil { + return err + } + err = afero.WriteFile(fs, networkInfoFilePath, rawStats, 0600) + if err != nil { + return err + } + lastCollected = time.Now() + } + }, + } + cmd.Flags().StringVarP(&networkInfoDir, "network-info-dir", "", "", "Specifies a directory to write network information periodically.") + cmd.Flags().StringVarP(&sessionTokenFile, "session-token-file", "", "", "Specifies a file that contains a session token.") + cmd.Flags().StringVarP(&urlFile, "url-file", "", "", "Specifies a file that contains the Coder URL.") + cmd.Flags().DurationVarP(&networkInfoInterval, "network-info-interval", "", 3*time.Second, "Specifies the interval to update network information.") + return cmd +} + +type sshNetworkStats struct { + P2P bool `json:"p2p"` + Latency float64 `json:"latency"` + PreferredDERP string `json:"preferred_derp"` + DERPLatency map[string]float64 `json:"derp_latency"` + UploadBytesSec int64 `json:"upload_bytes_sec"` + DownloadBytesSec int64 `json:"download_bytes_sec"` +} + +func collectNetworkStats(ctx context.Context, agentConn *codersdk.AgentConn, lastCollected time.Time) (*sshNetworkStats, error) { + latency, p2p, err := agentConn.Ping(ctx) + if err != nil { + return nil, err + } + node := agentConn.Node() + derpMap := agentConn.DERPMap() + derpLatency := map[string]float64{} + + // Convert DERP region IDs to friendly names for display in the UI. + for rawRegion, latency := range node.DERPLatency { + regionParts := strings.SplitN(rawRegion, "-", 2) + regionID, err := strconv.Atoi(regionParts[0]) + if err != nil { + continue + } + region, found := derpMap.Regions[regionID] + if !found { + // It's possible that a workspace agent is using an old DERPMap + // and reports regions that do not exist. If that's the case, + // report the region as unknown! + region = &tailcfg.DERPRegion{ + RegionID: regionID, + RegionName: fmt.Sprintf("Unnamed %d", regionID), + } + } + // Convert the microseconds to milliseconds. + derpLatency[region.RegionName] = latency * 1000 + } + + totalRx := uint64(0) + totalTx := uint64(0) + for _, stat := range agentConn.ExtractTrafficStats() { + totalRx += stat.RxBytes + totalTx += stat.TxBytes + } + // Tracking the time since last request is required because + // ExtractTrafficStats() resets its counters after each call. + dur := time.Since(lastCollected) + uploadSecs := float64(totalTx) / dur.Seconds() + downloadSecs := float64(totalRx) / dur.Seconds() + + return &sshNetworkStats{ + P2P: p2p, + Latency: float64(latency.Microseconds()) / 1000, + PreferredDERP: derpMap.Regions[node.PreferredDERP].RegionName, + DERPLatency: derpLatency, + UploadBytesSec: int64(uploadSecs), + DownloadBytesSec: int64(downloadSecs), + }, nil +} diff --git a/cli/vscodessh_test.go b/cli/vscodessh_test.go new file mode 100644 index 0000000000000..6af34913c97c9 --- /dev/null +++ b/cli/vscodessh_test.go @@ -0,0 +1,73 @@ +package cli_test + +import ( + "context" + "fmt" + "testing" + + "github.com/spf13/afero" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "cdr.dev/slog/sloggers/slogtest" + + "github.com/coder/coder/agent" + "github.com/coder/coder/cli/clitest" + "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/testutil" +) + +// TestVSCodeSSH ensures the agent connects properly with SSH +// and that network information is properly written to the FS. +func TestVSCodeSSH(t *testing.T) { + t.Parallel() + ctx, cancel := testutil.Context(t) + defer cancel() + client, workspace, agentToken := setupWorkspaceForAgent(t, nil) + user, err := client.User(ctx, codersdk.Me) + require.NoError(t, err) + + agentClient := codersdk.New(client.URL) + agentClient.SetSessionToken(agentToken) + agentCloser := agent.New(agent.Options{ + Client: agentClient, + Logger: slogtest.Make(t, nil).Named("agent"), + }) + defer func() { + _ = agentCloser.Close() + }() + coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) + + fs := afero.NewMemMapFs() + err = afero.WriteFile(fs, "/url", []byte(client.URL.String()), 0600) + require.NoError(t, err) + err = afero.WriteFile(fs, "/token", []byte(client.SessionToken()), 0600) + require.NoError(t, err) + + cmd, _ := clitest.New(t, + "vscodessh", + "--url-file", "/url", + "--session-token-file", "/token", + "--network-info-dir", "/net", + "--network-info-interval", "25ms", + fmt.Sprintf("coder-vscode--%s--%s", user.Username, workspace.Name)) + done := make(chan struct{}) + go func() { + //nolint // The above seems reasonable for a one-off test. + err := cmd.ExecuteContext(context.WithValue(ctx, "fs", fs)) + if err != nil { + assert.ErrorIs(t, err, context.Canceled) + } + close(done) + }() + require.Eventually(t, func() bool { + entries, err := afero.ReadDir(fs, "/net") + if err != nil { + return false + } + return len(entries) > 0 + }, testutil.WaitLong, testutil.IntervalFast) + cancel() + <-done +}