From 5736f49a5692b86fb93e435c4e1240730b77140b Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Mon, 9 Jan 2023 19:40:19 -0600 Subject: [PATCH 1/5] feat: replace vscodeipc with vscodessh The VS Code extension has been refactored to use VS Code Remote SSH instead of using the private API. This changes the structure to continue using SSH, but output network information periodically to a file. --- cli/root.go | 2 +- cli/ssh.go | 2 +- cli/vscodeipc.go | 88 --------- cli/vscodeipc/vscodeipc.go | 313 -------------------------------- cli/vscodeipc/vscodeipc_test.go | 202 --------------------- cli/vscodeipc_test.go | 44 ----- cli/vscodessh.go | 237 ++++++++++++++++++++++++ cli/vscodessh_test.go | 71 ++++++++ 8 files changed, 310 insertions(+), 649 deletions(-) delete mode 100644 cli/vscodeipc.go delete mode 100644 cli/vscodeipc/vscodeipc.go delete mode 100644 cli/vscodeipc/vscodeipc_test.go delete mode 100644 cli/vscodeipc_test.go create mode 100644 cli/vscodessh.go create mode 100644 cli/vscodessh_test.go diff --git a/cli/root.go b/cli/root.go index eb1bc98cf19e2..190bb35ac8a63 100644 --- a/cli/root.go +++ b/cli/root.go @@ -97,7 +97,7 @@ func Core() []*cobra.Command { update(), users(), versionCmd(), - vscodeipcCmd(), + vscodeSSH(), workspaceAgent(), } } diff --git a/cli/ssh.go b/cli/ssh.go index 3c1671c7849b1..0d642754aed59 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -256,7 +256,7 @@ func getWorkspaceAndAgent(ctx context.Context, cmd *cobra.Command, client *coder ) if shuffle { res, err := client.Workspaces(ctx, codersdk.WorkspaceFilter{ - Owner: codersdk.Me, + Owner: userID, }) if err != nil { return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err diff --git a/cli/vscodeipc.go b/cli/vscodeipc.go deleted file mode 100644 index 262fbf69aae8e..0000000000000 --- a/cli/vscodeipc.go +++ /dev/null @@ -1,88 +0,0 @@ -package cli - -import ( - "fmt" - "net" - "net/http" - "net/url" - - "github.com/google/uuid" - "github.com/spf13/cobra" - "golang.org/x/xerrors" - - "github.com/coder/coder/cli/cliflag" - "github.com/coder/coder/cli/vscodeipc" - "github.com/coder/coder/codersdk" -) - -// vscodeipcCmd spawns a local HTTP server on the provided port that listens to messages. -// It's made for use by the Coder VS Code extension. See: https://github.com/coder/vscode-coder -func vscodeipcCmd() *cobra.Command { - var ( - rawURL string - token string - port uint16 - ) - cmd := &cobra.Command{ - Use: "vscodeipc ", - Args: cobra.ExactArgs(1), - SilenceUsage: true, - Hidden: true, - RunE: func(cmd *cobra.Command, args []string) error { - if rawURL == "" { - return xerrors.New("CODER_URL must be set!") - } - // token is validated in a header on each request to prevent - // unauthenticated clients from connecting. - if token == "" { - return xerrors.New("CODER_TOKEN must be set!") - } - listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port)) - if err != nil { - return xerrors.Errorf("listen: %w", err) - } - defer listener.Close() - addr, ok := listener.Addr().(*net.TCPAddr) - if !ok { - return xerrors.Errorf("listener.Addr() is not a *net.TCPAddr: %T", listener.Addr()) - } - url, err := url.Parse(rawURL) - if err != nil { - return err - } - agentID, err := uuid.Parse(args[0]) - if err != nil { - return err - } - client := codersdk.New(url) - client.SetSessionToken(token) - - handler, closer, err := vscodeipc.New(cmd.Context(), client, agentID, nil) - if err != nil { - return err - } - defer closer.Close() - // nolint:gosec - server := http.Server{ - Handler: handler, - } - defer server.Close() - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s\n", addr.String()) - errChan := make(chan error, 1) - go func() { - err := server.Serve(listener) - errChan <- err - }() - select { - case <-cmd.Context().Done(): - return cmd.Context().Err() - case err := <-errChan: - return err - } - }, - } - cliflag.StringVarP(cmd.Flags(), &rawURL, "url", "u", "CODER_URL", "", "The URL of the Coder instance!") - cliflag.StringVarP(cmd.Flags(), &token, "token", "t", "CODER_TOKEN", "", "The session token of the user!") - cmd.Flags().Uint16VarP(&port, "port", "p", 0, "The port to listen on!") - return cmd -} diff --git a/cli/vscodeipc/vscodeipc.go b/cli/vscodeipc/vscodeipc.go deleted file mode 100644 index 9d4e094564da2..0000000000000 --- a/cli/vscodeipc/vscodeipc.go +++ /dev/null @@ -1,313 +0,0 @@ -package vscodeipc - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "strconv" - "strings" - "sync" - "time" - - "github.com/go-chi/chi/v5" - "github.com/google/uuid" - "golang.org/x/crypto/ssh" - "golang.org/x/xerrors" - "tailscale.com/tailcfg" - - "github.com/coder/coder/agent" - "github.com/coder/coder/coderd/httpapi" - "github.com/coder/coder/codersdk" -) - -const AuthHeader = "Coder-IPC-Token" - -// New creates a VS Code IPC client that can be used to communicate with workspaces. -// -// Creating this IPC was required instead of using SSH, because we're unable to get -// connection information to display in the bottom-bar when using SSH. It's possible -// we could jank around this (maybe by using a temporary SSH host), but that's not -// ideal. -// -// This persists a single workspace connection, and lets you execute commands, check -// for network information, and forward ports. -// -// The VS Code extension is located at https://github.com/coder/vscode-coder. The -// extension downloads the slim binary from `/bin/*` and executes `coder vscodeipc` -// which calls this function. This API must maintain backward compatibility with -// the extension to support prior versions of Coder. -func New(ctx context.Context, client *codersdk.Client, agentID uuid.UUID, options *codersdk.DialWorkspaceAgentOptions) (http.Handler, io.Closer, error) { - if options == nil { - options = &codersdk.DialWorkspaceAgentOptions{} - } - // We need this to track upload and download! - options.EnableTrafficStats = true - - agentConn, err := client.DialWorkspaceAgent(ctx, agentID, options) - if err != nil { - return nil, nil, err - } - api := &api{ - agentConn: agentConn, - } - r := chi.NewRouter() - // This is to prevent unauthorized clients on the same machine from executing - // requests on behalf of the workspace. - r.Use(sessionTokenMiddleware(client.SessionToken())) - r.Route("/v1", func(r chi.Router) { - r.Get("/port/{port}", api.port) - r.Get("/network", api.network) - r.Post("/execute", api.execute) - }) - return r, api, nil -} - -type api struct { - agentConn *codersdk.AgentConn - sshClient *ssh.Client - sshClientErr error - sshClientOnce sync.Once - - lastNetwork time.Time -} - -func (api *api) Close() error { - if api.sshClient != nil { - api.sshClient.Close() - } - return api.agentConn.Close() -} - -type NetworkResponse struct { - P2P bool `json:"p2p"` - Latency float64 `json:"latency"` - PreferredDERP string `json:"preferred_derp"` - DERPLatency map[string]float64 `json:"derp_latency"` - UploadBytesSec int64 `json:"upload_bytes_sec"` - DownloadBytesSec int64 `json:"download_bytes_sec"` -} - -// port accepts an HTTP request to dial a port on the workspace agent. -// It uses an HTTP connection upgrade to transfer the connection to TCP. -func (api *api) port(w http.ResponseWriter, r *http.Request) { - port, err := strconv.Atoi(chi.URLParam(r, "port")) - if err != nil { - httpapi.Write(r.Context(), w, http.StatusBadRequest, codersdk.Response{ - Message: "Port must be an integer!", - }) - return - } - remoteConn, err := api.agentConn.DialContext(r.Context(), "tcp", fmt.Sprintf("127.0.0.1:%d", port)) - if err != nil { - httpapi.InternalServerError(w, err) - return - } - defer remoteConn.Close() - - // Upgrade an switch to TCP! - w.Header().Set("Connection", "Upgrade") - w.Header().Set("Upgrade", "tcp") - w.WriteHeader(http.StatusSwitchingProtocols) - - hijacker, ok := w.(http.Hijacker) - if !ok { - httpapi.InternalServerError(w, xerrors.Errorf("unable to hijack connection: %T", w)) - return - } - - localConn, brw, err := hijacker.Hijack() - if err != nil { - httpapi.InternalServerError(w, err) - return - } - defer localConn.Close() - - _ = brw.Flush() - agent.Bicopy(r.Context(), localConn, remoteConn) -} - -// network returns network information about the workspace. -func (api *api) network(w http.ResponseWriter, r *http.Request) { - // Ping the workspace agent to get the latency. - latency, p2p, err := api.agentConn.Ping(r.Context()) - if err != nil { - httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to ping the workspace agent.", - Detail: err.Error(), - }) - return - } - - node := api.agentConn.Node() - derpMap := api.agentConn.DERPMap() - derpLatency := map[string]float64{} - - // Convert DERP region IDs to friendly names for display in the UI. - for rawRegion, latency := range node.DERPLatency { - regionParts := strings.SplitN(rawRegion, "-", 2) - regionID, err := strconv.Atoi(regionParts[0]) - if err != nil { - continue - } - region, found := derpMap.Regions[regionID] - if !found { - // It's possible that a workspace agent is using an old DERPMap - // and reports regions that do not exist. If that's the case, - // report the region as unknown! - region = &tailcfg.DERPRegion{ - RegionID: regionID, - RegionName: fmt.Sprintf("Unnamed %d", regionID), - } - } - // Convert the microseconds to milliseconds. - derpLatency[region.RegionName] = latency * 1000 - } - - totalRx := uint64(0) - totalTx := uint64(0) - for _, stat := range api.agentConn.ExtractTrafficStats() { - totalRx += stat.RxBytes - totalTx += stat.TxBytes - } - // Tracking the time since last request is required because - // ExtractTrafficStats() resets its counters after each call. - dur := time.Since(api.lastNetwork) - uploadSecs := float64(totalTx) / dur.Seconds() - downloadSecs := float64(totalRx) / dur.Seconds() - - api.lastNetwork = time.Now() - - httpapi.Write(r.Context(), w, http.StatusOK, NetworkResponse{ - P2P: p2p, - Latency: float64(latency.Microseconds()) / 1000, - PreferredDERP: derpMap.Regions[node.PreferredDERP].RegionName, - DERPLatency: derpLatency, - UploadBytesSec: int64(uploadSecs), - DownloadBytesSec: int64(downloadSecs), - }) -} - -type ExecuteRequest struct { - Command string `json:"command"` - Stdin string `json:"stdin"` -} - -type ExecuteResponse struct { - Data string `json:"data"` - ExitCode *int `json:"exit_code"` -} - -// execute runs the command provided, streams the output back, and returns the exit code. -func (api *api) execute(w http.ResponseWriter, r *http.Request) { - var req ExecuteRequest - if !httpapi.Read(r.Context(), w, r, &req) { - return - } - api.sshClientOnce.Do(func() { - // The SSH client is lazily created because it's not needed for - // all requests. It's only needed for the execute endpoint. - // - // It's alright if this fails on the first execution, because - // a new instance of this API is created for each remote SSH request. - api.sshClient, api.sshClientErr = api.agentConn.SSHClient(context.Background()) - }) - if api.sshClientErr != nil { - httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to create SSH client.", - Detail: api.sshClientErr.Error(), - }) - return - } - session, err := api.sshClient.NewSession() - if err != nil { - httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to create SSH session.", - Detail: err.Error(), - }) - return - } - defer session.Close() - f, ok := w.(http.Flusher) - if !ok { - httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{ - Message: fmt.Sprintf("http.ResponseWriter is not http.Flusher: %T", w), - }) - return - } - - execWriter := &execWriter{w, f} - session.Stdout = execWriter - session.Stderr = execWriter - session.Stdin = strings.NewReader(req.Stdin) - err = session.Start(req.Command) - if err != nil { - httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to start SSH session.", - Detail: err.Error(), - }) - return - } - err = session.Wait() - - writeExit := func(exitCode int) { - data, _ := json.Marshal(&ExecuteResponse{ - ExitCode: &exitCode, - }) - _, _ = w.Write(data) - f.Flush() - } - - if err != nil { - var exitError *ssh.ExitError - if errors.As(err, &exitError) { - writeExit(exitError.ExitStatus()) - return - } - } - writeExit(0) -} - -type execWriter struct { - w http.ResponseWriter - f http.Flusher -} - -func (e *execWriter) Write(data []byte) (int, error) { - js, err := json.Marshal(&ExecuteResponse{ - Data: string(data), - }) - if err != nil { - return 0, err - } - _, err = e.w.Write(js) - if err != nil { - return 0, err - } - e.f.Flush() - return len(data), nil -} - -func sessionTokenMiddleware(sessionToken string) func(h http.Handler) http.Handler { - return func(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - token := r.Header.Get(AuthHeader) - if token == "" { - httpapi.Write(r.Context(), w, http.StatusUnauthorized, codersdk.Response{ - Message: fmt.Sprintf("A session token must be provided in the `%s` header.", AuthHeader), - }) - return - } - if token != sessionToken { - httpapi.Write(r.Context(), w, http.StatusUnauthorized, codersdk.Response{ - Message: "The session token provided doesn't match the one used to create the client.", - }) - return - } - w.Header().Set("Access-Control-Allow-Origin", "*") - h.ServeHTTP(w, r) - }) - } -} diff --git a/cli/vscodeipc/vscodeipc_test.go b/cli/vscodeipc/vscodeipc_test.go deleted file mode 100644 index 5213c2422be6a..0000000000000 --- a/cli/vscodeipc/vscodeipc_test.go +++ /dev/null @@ -1,202 +0,0 @@ -package vscodeipc_test - -import ( - "bufio" - "bytes" - "context" - "encoding/json" - "fmt" - "net" - "net/http" - "net/http/httptest" - "net/url" - "runtime" - "testing" - - "github.com/google/uuid" - "github.com/spf13/afero" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.uber.org/goleak" - "nhooyr.io/websocket" - - "github.com/coder/coder/agent" - "github.com/coder/coder/cli/vscodeipc" - "github.com/coder/coder/coderd/httpapi" - "github.com/coder/coder/codersdk" - "github.com/coder/coder/tailnet" - "github.com/coder/coder/tailnet/tailnettest" - "github.com/coder/coder/testutil" -) - -func TestMain(m *testing.M) { - goleak.VerifyTestMain(m) -} - -func TestVSCodeIPC(t *testing.T) { - t.Parallel() - ctx := context.Background() - - id := uuid.New() - derpMap := tailnettest.RunDERPAndSTUN(t) - coordinator := tailnet.NewCoordinator() - t.Cleanup(func() { - _ = coordinator.Close() - }) - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.URL.Path { - case fmt.Sprintf("/api/v2/workspaceagents/%s/connection", id): - assert.Equal(t, r.Method, http.MethodGet) - httpapi.Write(ctx, w, http.StatusOK, codersdk.WorkspaceAgentConnectionInfo{ - DERPMap: derpMap, - }) - return - case fmt.Sprintf("/api/v2/workspaceagents/%s/coordinate", id): - assert.Equal(t, r.Method, http.MethodGet) - ws, err := websocket.Accept(w, r, nil) - require.NoError(t, err) - conn := websocket.NetConn(ctx, ws, websocket.MessageBinary) - _ = coordinator.ServeClient(conn, uuid.New(), id) - return - case "/api/v2/workspaceagents/me/version": - assert.Equal(t, r.Method, http.MethodPost) - w.WriteHeader(http.StatusOK) - return - case "/api/v2/workspaceagents/me/metadata": - assert.Equal(t, r.Method, http.MethodGet) - httpapi.Write(ctx, w, http.StatusOK, codersdk.WorkspaceAgentMetadata{ - DERPMap: derpMap, - }) - return - case "/api/v2/workspaceagents/me/coordinate": - assert.Equal(t, r.Method, http.MethodGet) - ws, err := websocket.Accept(w, r, nil) - require.NoError(t, err) - conn := websocket.NetConn(ctx, ws, websocket.MessageBinary) - _ = coordinator.ServeAgent(conn, id) - return - case "/api/v2/workspaceagents/me/report-stats": - assert.Equal(t, r.Method, http.MethodPost) - w.WriteHeader(http.StatusOK) - return - case "/": - w.WriteHeader(http.StatusOK) - return - default: - t.Fatalf("unexpected request %s", r.URL.Path) - } - })) - t.Cleanup(srv.Close) - srvURL, _ := url.Parse(srv.URL) - - client := codersdk.New(srvURL) - token := uuid.New().String() - client.SetSessionToken(token) - agentConn := agent.New(agent.Options{ - Client: client, - Filesystem: afero.NewMemMapFs(), - TempDir: t.TempDir(), - }) - t.Cleanup(func() { - _ = agentConn.Close() - }) - - handler, closer, err := vscodeipc.New(ctx, client, id, nil) - require.NoError(t, err) - t.Cleanup(func() { - _ = closer.Close() - }) - - // Ensure that we're actually connected! - require.Eventually(t, func() bool { - res := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/v1/network", nil) - req.Header.Set(vscodeipc.AuthHeader, token) - handler.ServeHTTP(res, req) - network := &vscodeipc.NetworkResponse{} - err = json.NewDecoder(res.Body).Decode(&network) - assert.NoError(t, err) - return network.Latency != 0 - }, testutil.WaitLong, testutil.IntervalFast) - - _, port, err := net.SplitHostPort(srvURL.Host) - require.NoError(t, err) - - t.Run("NoSessionToken", func(t *testing.T) { - t.Parallel() - res := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/v1/port/%s", port), nil) - handler.ServeHTTP(res, req) - require.Equal(t, http.StatusUnauthorized, res.Code) - }) - - t.Run("MismatchedSessionToken", func(t *testing.T) { - t.Parallel() - res := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/v1/port/%s", port), nil) - req.Header.Set(vscodeipc.AuthHeader, uuid.NewString()) - handler.ServeHTTP(res, req) - require.Equal(t, http.StatusUnauthorized, res.Code) - }) - - t.Run("Port", func(t *testing.T) { - // Tests that the port endpoint can be used for forward traffic. - // For this test, we simply use the already listening httptest server. - t.Parallel() - input, output := net.Pipe() - defer input.Close() - defer output.Close() - res := &hijackable{httptest.NewRecorder(), output} - req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/v1/port/%s", port), nil) - req.Header.Set(vscodeipc.AuthHeader, token) - go handler.ServeHTTP(res, req) - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://127.0.0.1/", nil) - require.NoError(t, err) - client := http.Client{ - Transport: &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return input, nil - }, - }, - } - resp, err := client.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - require.Equal(t, http.StatusOK, resp.StatusCode) - }) - - t.Run("Execute", func(t *testing.T) { - t.Parallel() - if runtime.GOOS == "windows" { - t.Skip("Execute isn't supported on Windows yet!") - return - } - - res := httptest.NewRecorder() - data, _ := json.Marshal(vscodeipc.ExecuteRequest{ - Command: "echo test", - }) - req := httptest.NewRequest(http.MethodPost, "/v1/execute", bytes.NewReader(data)) - req.Header.Set(vscodeipc.AuthHeader, token) - handler.ServeHTTP(res, req) - - decoder := json.NewDecoder(res.Body) - var msg vscodeipc.ExecuteResponse - err = decoder.Decode(&msg) - require.NoError(t, err) - require.Equal(t, "test\n", msg.Data) - err = decoder.Decode(&msg) - require.NoError(t, err) - require.Equal(t, 0, *msg.ExitCode) - }) -} - -type hijackable struct { - *httptest.ResponseRecorder - conn net.Conn -} - -func (h *hijackable) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return h.conn, bufio.NewReadWriter(bufio.NewReader(h.conn), bufio.NewWriter(h.conn)), nil -} diff --git a/cli/vscodeipc_test.go b/cli/vscodeipc_test.go deleted file mode 100644 index 1edb52102841c..0000000000000 --- a/cli/vscodeipc_test.go +++ /dev/null @@ -1,44 +0,0 @@ -package cli_test - -import ( - "io" - "testing" - - "github.com/stretchr/testify/require" - - "github.com/coder/coder/cli/clitest" - "github.com/coder/coder/testutil" -) - -func TestVSCodeIPC(t *testing.T) { - t.Parallel() - // Ensures the vscodeipc command outputs it's running port! - // This signifies to the caller that it's ready to accept requests. - t.Run("PortOutputs", func(t *testing.T) { - t.Parallel() - client, workspace, _ := setupWorkspaceForAgent(t, nil) - cmd, _ := clitest.New(t, "vscodeipc", workspace.LatestBuild.Resources[0].Agents[0].ID.String(), - "--token", client.SessionToken(), "--url", client.URL.String()) - rdr, wtr := io.Pipe() - cmd.SetOut(wtr) - ctx, cancelFunc := testutil.Context(t) - defer cancelFunc() - done := make(chan error, 1) - go func() { - err := cmd.ExecuteContext(ctx) - done <- err - }() - - buf := make([]byte, 64) - require.Eventually(t, func() bool { - t.Log("Looking for address!") - var err error - _, err = rdr.Read(buf) - return err == nil - }, testutil.WaitMedium, testutil.IntervalFast) - t.Logf("Address: %s\n", buf) - - cancelFunc() - <-done - }) -} diff --git a/cli/vscodessh.go b/cli/vscodessh.go new file mode 100644 index 0000000000000..f4f27b832d084 --- /dev/null +++ b/cli/vscodessh.go @@ -0,0 +1,237 @@ +package cli + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/url" + "os" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/spf13/afero" + "github.com/spf13/cobra" + "golang.org/x/xerrors" + "tailscale.com/tailcfg" + + "github.com/coder/coder/codersdk" +) + +// vscodeSSH is used by the Coder VS Code extension to establish +// a connection to a workspace. +// +// This command needs to remain stable for compatibility with +// various VS Code versions, so it's kept separate from our +// standard SSH command. +func vscodeSSH() *cobra.Command { + var ( + sessionTokenFile string + urlFile string + networkInfoDir string + networkInfoInterval time.Duration + ) + cmd := &cobra.Command{ + // A SSH config entry is added by the VS Code extension that + // passes %h to ProxyCommand. The prefix of `coder-vscode--` + // is a magical string represented in our VS Cod extension. + // It's not important here, only the delimiter `--` is. + Use: "vscodessh -->", + Hidden: true, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + if networkInfoDir == "" { + return xerrors.New("network-info-dir must be specified") + } + if sessionTokenFile == "" { + return xerrors.New("session-token-file must be specified") + } + if urlFile == "" { + return xerrors.New("url-file must be specified") + } + + fs, ok := cmd.Context().Value("fs").(afero.Fs) + if !ok { + fs = afero.NewOsFs() + } + + sessionToken, err := afero.ReadFile(fs, sessionTokenFile) + if err != nil { + return xerrors.Errorf("read session token: %w", err) + } + rawURL, err := afero.ReadFile(fs, urlFile) + if err != nil { + return xerrors.Errorf("read url: %w", err) + } + serverURL, err := url.Parse(string(rawURL)) + if err != nil { + return xerrors.Errorf("parse url: %w", err) + } + + ctx, cancel := context.WithCancel(cmd.Context()) + defer cancel() + + err = fs.MkdirAll(networkInfoDir, 0700) + if err != nil { + return xerrors.Errorf("mkdir: %w", err) + } + + client := codersdk.New(serverURL) + client.SetSessionToken(string(sessionToken)) + + parts := strings.Split(args[0], "--") + if len(parts) < 3 { + return xerrors.Errorf("invalid argument format. must be: coder-vscode----") + } + owner := parts[1] + name := parts[2] + + workspace, err := client.WorkspaceByOwnerAndName(ctx, owner, name, codersdk.WorkspaceOptions{}) + if err != nil { + return xerrors.Errorf("find workspace: %w", err) + } + var agent codersdk.WorkspaceAgent + var found bool + for _, resource := range workspace.LatestBuild.Resources { + if len(resource.Agents) == 0 { + continue + } + for _, resourceAgent := range resource.Agents { + // If an agent name isn't included we default to + // the first agent! + if len(parts) != 4 { + agent = resourceAgent + found = true + break + } + if resourceAgent.Name != parts[3] { + continue + } + agent = resourceAgent + found = true + break + } + if found { + break + } + } + agentConn, err := client.DialWorkspaceAgent(ctx, agent.ID, &codersdk.DialWorkspaceAgentOptions{ + EnableTrafficStats: true, + }) + if err != nil { + return xerrors.Errorf("dial workspace agent: %w", err) + } + defer agentConn.Close() + agentConn.AwaitReachable(ctx) + rawSSH, err := agentConn.SSH(ctx) + if err != nil { + return err + } + defer rawSSH.Close() + // Copy SSH traffic over stdio. + go func() { + _, _ = io.Copy(cmd.OutOrStdout(), rawSSH) + }() + go func() { + _, _ = io.Copy(rawSSH, cmd.InOrStdin()) + }() + // The VS Code extension obtains the PID of the SSH process to + // read the file below which contains network information to display. + // + // We get the parent PID because it's assumed `ssh` is calling this + // command via the ProxyCommand SSH option. + networkInfoFilePath := filepath.Join(networkInfoDir, fmt.Sprintf("%d.json", os.Getppid())) + ticker := time.NewTicker(networkInfoInterval) + defer ticker.Stop() + lastCollected := time.Now() + for { + select { + case <-ctx.Done(): + return nil + case <-ticker.C: + } + stats, err := collectNetworkStats(ctx, agentConn, lastCollected) + if err != nil { + return err + } + rawStats, err := json.Marshal(stats) + if err != nil { + return err + } + err = afero.WriteFile(fs, networkInfoFilePath, rawStats, 0600) + if err != nil { + return err + } + lastCollected = time.Now() + } + }, + } + cmd.Flags().StringVarP(&networkInfoDir, "network-info-dir", "", "", "Specifies a directory to write network information periodically.") + cmd.Flags().StringVarP(&sessionTokenFile, "session-token-file", "", "", "Specifies a file that contains a session token.") + cmd.Flags().StringVarP(&urlFile, "url-file", "", "", "Specifies a file that contains the Coder URL.") + cmd.Flags().DurationVarP(&networkInfoInterval, "network-info-interval", "", 3*time.Second, "Specifies the interval to update network information.") + return cmd +} + +type sshNetworkStats struct { + P2P bool `json:"p2p"` + Latency float64 `json:"latency"` + PreferredDERP string `json:"preferred_derp"` + DERPLatency map[string]float64 `json:"derp_latency"` + UploadBytesSec int64 `json:"upload_bytes_sec"` + DownloadBytesSec int64 `json:"download_bytes_sec"` +} + +func collectNetworkStats(ctx context.Context, agentConn *codersdk.AgentConn, lastCollected time.Time) (*sshNetworkStats, error) { + latency, p2p, err := agentConn.Ping(ctx) + if err != nil { + return nil, err + } + node := agentConn.Node() + derpMap := agentConn.DERPMap() + derpLatency := map[string]float64{} + + // Convert DERP region IDs to friendly names for display in the UI. + for rawRegion, latency := range node.DERPLatency { + regionParts := strings.SplitN(rawRegion, "-", 2) + regionID, err := strconv.Atoi(regionParts[0]) + if err != nil { + continue + } + region, found := derpMap.Regions[regionID] + if !found { + // It's possible that a workspace agent is using an old DERPMap + // and reports regions that do not exist. If that's the case, + // report the region as unknown! + region = &tailcfg.DERPRegion{ + RegionID: regionID, + RegionName: fmt.Sprintf("Unnamed %d", regionID), + } + } + // Convert the microseconds to milliseconds. + derpLatency[region.RegionName] = latency * 1000 + } + + totalRx := uint64(0) + totalTx := uint64(0) + for _, stat := range agentConn.ExtractTrafficStats() { + totalRx += stat.RxBytes + totalTx += stat.TxBytes + } + // Tracking the time since last request is required because + // ExtractTrafficStats() resets its counters after each call. + dur := time.Since(lastCollected) + uploadSecs := float64(totalTx) / dur.Seconds() + downloadSecs := float64(totalRx) / dur.Seconds() + + return &sshNetworkStats{ + P2P: p2p, + Latency: float64(latency.Microseconds()) / 1000, + PreferredDERP: derpMap.Regions[node.PreferredDERP].RegionName, + DERPLatency: derpLatency, + UploadBytesSec: int64(uploadSecs), + DownloadBytesSec: int64(downloadSecs), + }, nil +} diff --git a/cli/vscodessh_test.go b/cli/vscodessh_test.go new file mode 100644 index 0000000000000..55840ac9d7ded --- /dev/null +++ b/cli/vscodessh_test.go @@ -0,0 +1,71 @@ +package cli_test + +import ( + "context" + "fmt" + "testing" + + "github.com/spf13/afero" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "cdr.dev/slog/sloggers/slogtest" + + "github.com/coder/coder/agent" + "github.com/coder/coder/cli/clitest" + "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/testutil" +) + +// TestVSCodeSSH ensures the agent connects properly with SSH +// and that network information is properly written to the FS. +func TestVSCodeSSH(t *testing.T) { + t.Parallel() + ctx, cancel := testutil.Context(t) + defer cancel() + client, workspace, agentToken := setupWorkspaceForAgent(t, nil) + user, err := client.User(ctx, codersdk.Me) + require.NoError(t, err) + + agentClient := codersdk.New(client.URL) + agentClient.SetSessionToken(agentToken) + agentCloser := agent.New(agent.Options{ + Client: agentClient, + Logger: slogtest.Make(t, nil).Named("agent"), + }) + defer func() { + _ = agentCloser.Close() + }() + coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) + + fs := afero.NewMemMapFs() + err = afero.WriteFile(fs, "/url", []byte(client.URL.String()), 0600) + require.NoError(t, err) + err = afero.WriteFile(fs, "/token", []byte(client.SessionToken()), 0600) + require.NoError(t, err) + + cmd, _ := clitest.New(t, + "vscodessh", + "--url-file", "/url", + "--session-token-file", "/token", + "--network-info-dir", "/net", + "--network-info-interval", "25ms", + fmt.Sprintf("coder-vscode--%s--%s", user.Username, workspace.Name)) + done := make(chan struct{}) + go func() { + //nolint // The above seems reasonable for a one-off test. + err := cmd.ExecuteContext(context.WithValue(ctx, "fs", fs)) + assert.NoError(t, err) + close(done) + }() + require.Eventually(t, func() bool { + entries, err := afero.ReadDir(fs, "/net") + if err != nil { + return false + } + return len(entries) > 0 + }, testutil.WaitLong, testutil.IntervalFast) + cancel() + <-done +} From 29983eba0eddc3a11c2242f3b94966767e3715fa Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Mon, 9 Jan 2023 19:49:59 -0600 Subject: [PATCH 2/5] Track SSH connections --- agent/agent.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/agent/agent.go b/agent/agent.go index dd900700c8913..f012963f32e6c 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -304,7 +304,9 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_ if err != nil { return } - go a.sshServer.HandleConn(conn) + _ = a.trackConnGoroutine(func() { + a.sshServer.HandleConn(conn) + }) } }); err != nil { return nil, err From 6454d1501502cf7e46fca37ac89a6573a9d9e619 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Mon, 9 Jan 2023 19:59:38 -0600 Subject: [PATCH 3/5] Fix error type --- cli/vscodessh_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cli/vscodessh_test.go b/cli/vscodessh_test.go index 55840ac9d7ded..6af34913c97c9 100644 --- a/cli/vscodessh_test.go +++ b/cli/vscodessh_test.go @@ -56,7 +56,9 @@ func TestVSCodeSSH(t *testing.T) { go func() { //nolint // The above seems reasonable for a one-off test. err := cmd.ExecuteContext(context.WithValue(ctx, "fs", fs)) - assert.NoError(t, err) + if err != nil { + assert.ErrorIs(t, err, context.Canceled) + } close(done) }() require.Eventually(t, func() bool { From e9dedd3052b26f3a1dc55e33ee9fbe7c0f1d997d Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Mon, 9 Jan 2023 22:07:07 -0600 Subject: [PATCH 4/5] Fix closing of SSH listener --- agent/agent.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/agent/agent.go b/agent/agent.go index f012963f32e6c..a872689327cc0 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -298,6 +298,15 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_ _ = sshListener.Close() } }() + if err = a.trackConnGoroutine(func() { + select { + case <-network.Closed(): + case <-a.closed: + } + _ = sshListener.Close() + }); err != nil { + return nil, err + } if err = a.trackConnGoroutine(func() { for { conn, err := sshListener.Accept() From 10a5af2b3ec3acd2a13fb3ef3275c468fa89ef95 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Mon, 9 Jan 2023 22:16:30 -0600 Subject: [PATCH 5/5] Fix connection tracking --- agent/agent.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index a872689327cc0..47d9c394a86b9 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -298,22 +298,22 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_ _ = sshListener.Close() } }() - if err = a.trackConnGoroutine(func() { - select { - case <-network.Closed(): - case <-a.closed: - } - _ = sshListener.Close() - }); err != nil { - return nil, err - } if err = a.trackConnGoroutine(func() { for { conn, err := sshListener.Accept() if err != nil { return } + closed := make(chan struct{}) + _ = a.trackConnGoroutine(func() { + select { + case <-network.Closed(): + case <-closed: + } + _ = conn.Close() + }) _ = a.trackConnGoroutine(func() { + defer close(closed) a.sshServer.HandleConn(conn) }) }