From 9ff3342ef707c6fdc8b633942207aa8162a0d20d Mon Sep 17 00:00:00 2001 From: Aaron Lehmann Date: Wed, 8 Jan 2025 18:55:41 -0800 Subject: [PATCH] feat: add --network-info-dir and --network-info-interval flags to coder ssh This is the first in a series of PRs to enable "coder ssh" to replace "coder vscodessh". This change adds --network-info-dir and --network-info-interval flags to the ssh subcommand. These were formerly only available with the vscodessh subcommand. Subsequent PRs will add a --ssh-host-prefix flag to the ssh subcommand, and adjust the log file naming to contain the parent PID. --- cli/ssh.go | 224 +++++++++++++++++++++++++-- cli/ssh_test.go | 73 +++++++++ cli/testdata/coder_ssh_--help.golden | 6 + cli/vscodessh.go | 152 +----------------- docs/reference/cli/ssh.md | 17 ++ 5 files changed, 309 insertions(+), 163 deletions(-) diff --git a/cli/ssh.go b/cli/ssh.go index 7a1d5940bfd01..ea03916e3c293 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -3,6 +3,7 @@ package cli import ( "bytes" "context" + "encoding/json" "errors" "fmt" "io" @@ -13,6 +14,7 @@ import ( "os/exec" "path/filepath" "slices" + "strconv" "strings" "sync" "time" @@ -21,11 +23,14 @@ import ( "github.com/gofrs/flock" "github.com/google/uuid" "github.com/mattn/go-isatty" + "github.com/spf13/afero" gossh "golang.org/x/crypto/ssh" gosshagent "golang.org/x/crypto/ssh/agent" "golang.org/x/term" "golang.org/x/xerrors" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "tailscale.com/tailcfg" + "tailscale.com/types/netlogtype" "cdr.dev/slog" "cdr.dev/slog/sloggers/sloghuman" @@ -55,19 +60,21 @@ var ( func (r *RootCmd) ssh() *serpent.Command { var ( - stdio bool - forwardAgent bool - forwardGPG bool - identityAgent string - wsPollInterval time.Duration - waitEnum string - noWait bool - logDirPath string - remoteForwards []string - env []string - usageApp string - disableAutostart bool - appearanceConfig codersdk.AppearanceConfig + stdio bool + forwardAgent bool + forwardGPG bool + identityAgent string + wsPollInterval time.Duration + waitEnum string + noWait bool + logDirPath string + remoteForwards []string + env []string + usageApp string + disableAutostart bool + appearanceConfig codersdk.AppearanceConfig + networkInfoDir string + networkInfoInterval time.Duration ) client := new(codersdk.Client) cmd := &serpent.Command{ @@ -284,13 +291,21 @@ func (r *RootCmd) ssh() *serpent.Command { return err } + var errCh <-chan error + if networkInfoDir != "" { + errCh, err = setStatsCallback(ctx, conn, logger, networkInfoDir, networkInfoInterval) + if err != nil { + return err + } + } + wg.Add(1) go func() { defer wg.Done() watchAndClose(ctx, func() error { stack.close(xerrors.New("watchAndClose")) return nil - }, logger, client, workspace) + }, logger, client, workspace, errCh) }() copier.copy(&wg) return nil @@ -312,6 +327,14 @@ func (r *RootCmd) ssh() *serpent.Command { return err } + var errCh <-chan error + if networkInfoDir != "" { + errCh, err = setStatsCallback(ctx, conn, logger, networkInfoDir, networkInfoInterval) + if err != nil { + return err + } + } + wg.Add(1) go func() { defer wg.Done() @@ -324,6 +347,7 @@ func (r *RootCmd) ssh() *serpent.Command { logger, client, workspace, + errCh, ) }() @@ -540,6 +564,17 @@ func (r *RootCmd) ssh() *serpent.Command { Value: serpent.StringOf(&usageApp), Hidden: true, }, + { + Flag: "network-info-dir", + Description: "Specifies a directory to write network information periodically.", + Value: serpent.StringOf(&networkInfoDir), + }, + { + Flag: "network-info-interval", + Description: "Specifies the interval to update network information.", + Default: "5s", + Value: serpent.DurationOf(&networkInfoInterval), + }, sshDisableAutostartOption(serpent.BoolOf(&disableAutostart)), } return cmd @@ -555,7 +590,7 @@ func (r *RootCmd) ssh() *serpent.Command { // will usually not propagate. // // See: https://github.com/coder/coder/issues/6180 -func watchAndClose(ctx context.Context, closer func() error, logger slog.Logger, client *codersdk.Client, workspace codersdk.Workspace) { +func watchAndClose(ctx context.Context, closer func() error, logger slog.Logger, client *codersdk.Client, workspace codersdk.Workspace, errCh <-chan error) { // Ensure session is ended on both context cancellation // and workspace stop. defer func() { @@ -606,6 +641,9 @@ startWatchLoop: logger.Info(ctx, "workspace stopped") return } + case err := <-errCh: + logger.Error(ctx, "failed to collect network stats", slog.Error(err)) + return } } } @@ -1144,3 +1182,159 @@ func getUsageAppName(usageApp string) codersdk.UsageAppName { return codersdk.UsageAppNameSSH } + +func setStatsCallback( + ctx context.Context, + agentConn *workspacesdk.AgentConn, + logger slog.Logger, + networkInfoDir string, + networkInfoInterval time.Duration, +) (<-chan error, error) { + fs, ok := ctx.Value("fs").(afero.Fs) + if !ok { + fs = afero.NewOsFs() + } + if err := fs.MkdirAll(networkInfoDir, 0o700); err != nil { + return nil, xerrors.Errorf("mkdir: %w", err) + } + + // The VS Code extension obtains the PID of the SSH process to + // read files to display logs and network info. + // + // We get the parent PID because it's assumed `ssh` is calling this + // command via the ProxyCommand SSH option. + pid := os.Getppid() + + // The VS Code extension obtains the PID of the SSH process to + // read the file below which contains network information to display. + // + // We get the parent PID because it's assumed `ssh` is calling this + // command via the ProxyCommand SSH option. + networkInfoFilePath := filepath.Join(networkInfoDir, fmt.Sprintf("%d.json", pid)) + + var ( + firstErrTime time.Time + errCh = make(chan error, 1) + ) + cb := func(start, end time.Time, virtual, _ map[netlogtype.Connection]netlogtype.Counts) { + sendErr := func(tolerate bool, err error) { + logger.Error(ctx, "collect network stats", slog.Error(err)) + // Tolerate up to 1 minute of errors. + if tolerate { + if firstErrTime.IsZero() { + logger.Info(ctx, "tolerating network stats errors for up to 1 minute") + firstErrTime = time.Now() + } + if time.Since(firstErrTime) < time.Minute { + return + } + } + + select { + case errCh <- err: + default: + } + } + + stats, err := collectNetworkStats(ctx, agentConn, start, end, virtual) + if err != nil { + sendErr(true, err) + return + } + + rawStats, err := json.Marshal(stats) + if err != nil { + sendErr(false, err) + return + } + err = afero.WriteFile(fs, networkInfoFilePath, rawStats, 0o600) + if err != nil { + sendErr(false, err) + return + } + + firstErrTime = time.Time{} + } + + now := time.Now() + cb(now, now.Add(time.Nanosecond), map[netlogtype.Connection]netlogtype.Counts{}, map[netlogtype.Connection]netlogtype.Counts{}) + agentConn.SetConnStatsCallback(networkInfoInterval, 2048, cb) + return errCh, nil +} + +type sshNetworkStats struct { + P2P bool `json:"p2p"` + Latency float64 `json:"latency"` + PreferredDERP string `json:"preferred_derp"` + DERPLatency map[string]float64 `json:"derp_latency"` + UploadBytesSec int64 `json:"upload_bytes_sec"` + DownloadBytesSec int64 `json:"download_bytes_sec"` +} + +func collectNetworkStats(ctx context.Context, agentConn *workspacesdk.AgentConn, start, end time.Time, counts map[netlogtype.Connection]netlogtype.Counts) (*sshNetworkStats, error) { + latency, p2p, pingResult, err := agentConn.Ping(ctx) + if err != nil { + return nil, err + } + node := agentConn.Node() + derpMap := agentConn.DERPMap() + derpLatency := map[string]float64{} + + // Convert DERP region IDs to friendly names for display in the UI. + for rawRegion, latency := range node.DERPLatency { + regionParts := strings.SplitN(rawRegion, "-", 2) + regionID, err := strconv.Atoi(regionParts[0]) + if err != nil { + continue + } + region, found := derpMap.Regions[regionID] + if !found { + // It's possible that a workspace agent is using an old DERPMap + // and reports regions that do not exist. If that's the case, + // report the region as unknown! + region = &tailcfg.DERPRegion{ + RegionID: regionID, + RegionName: fmt.Sprintf("Unnamed %d", regionID), + } + } + // Convert the microseconds to milliseconds. + derpLatency[region.RegionName] = latency * 1000 + } + + totalRx := uint64(0) + totalTx := uint64(0) + for _, stat := range counts { + totalRx += stat.RxBytes + totalTx += stat.TxBytes + } + // Tracking the time since last request is required because + // ExtractTrafficStats() resets its counters after each call. + dur := end.Sub(start) + uploadSecs := float64(totalTx) / dur.Seconds() + downloadSecs := float64(totalRx) / dur.Seconds() + + // Sometimes the preferred DERP doesn't match the one we're actually + // connected with. Perhaps because the agent prefers a different DERP and + // we're using that server instead. + preferredDerpID := node.PreferredDERP + if pingResult.DERPRegionID != 0 { + preferredDerpID = pingResult.DERPRegionID + } + preferredDerp, ok := derpMap.Regions[preferredDerpID] + preferredDerpName := fmt.Sprintf("Unnamed %d", preferredDerpID) + if ok { + preferredDerpName = preferredDerp.RegionName + } + if _, ok := derpLatency[preferredDerpName]; !ok { + derpLatency[preferredDerpName] = 0 + } + + return &sshNetworkStats{ + P2P: p2p, + Latency: float64(latency.Microseconds()) / 1000, + PreferredDERP: preferredDerpName, + DERPLatency: derpLatency, + UploadBytesSec: int64(uploadSecs), + DownloadBytesSec: int64(downloadSecs), + }, nil +} diff --git a/cli/ssh_test.go b/cli/ssh_test.go index bd107852251f7..9a16460ea5fe4 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -24,6 +24,7 @@ import ( "time" "github.com/google/uuid" + "github.com/spf13/afero" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" @@ -438,6 +439,78 @@ func TestSSH(t *testing.T) { <-cmdDone }) + t.Run("NetworkInfo", func(t *testing.T) { + t.Parallel() + client, workspace, agentToken := setupWorkspaceForAgent(t) + _, _ = tGoContext(t, func(ctx context.Context) { + // Run this async so the SSH command has to wait for + // the build and agent to connect! + _ = agenttest.New(t, client.URL, agentToken) + <-ctx.Done() + }) + + clientOutput, clientInput := io.Pipe() + serverOutput, serverInput := io.Pipe() + defer func() { + for _, c := range []io.Closer{clientOutput, clientInput, serverOutput, serverInput} { + _ = c.Close() + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + fs := afero.NewMemMapFs() + //nolint:revive,staticcheck + ctx = context.WithValue(ctx, "fs", fs) + + inv, root := clitest.New(t, "ssh", "--stdio", workspace.Name, "--network-info-dir", "/net", "--network-info-interval", "25ms") + clitest.SetupConfig(t, client, root) + inv.Stdin = clientOutput + inv.Stdout = serverInput + inv.Stderr = io.Discard + + cmdDone := tGo(t, func() { + err := inv.WithContext(ctx).Run() + assert.NoError(t, err) + }) + + conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + Reader: serverOutput, + Writer: clientInput, + }, "", &ssh.ClientConfig{ + // #nosec + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + }) + require.NoError(t, err) + defer conn.Close() + + sshClient := ssh.NewClient(conn, channels, requests) + session, err := sshClient.NewSession() + require.NoError(t, err) + defer session.Close() + + command := "sh -c exit" + if runtime.GOOS == "windows" { + command = "cmd.exe /c exit" + } + err = session.Run(command) + require.NoError(t, err) + err = sshClient.Close() + require.NoError(t, err) + _ = clientOutput.Close() + + assert.Eventually(t, func() bool { + entries, err := afero.ReadDir(fs, "/net") + if err != nil { + return false + } + return len(entries) > 0 + }, testutil.WaitLong, testutil.IntervalFast) + + <-cmdDone + }) + t.Run("Stdio_StartStoppedWorkspace_CleanStdout", func(t *testing.T) { t.Parallel() diff --git a/cli/testdata/coder_ssh_--help.golden b/cli/testdata/coder_ssh_--help.golden index 80aaa3c204fda..d847e9d7abb03 100644 --- a/cli/testdata/coder_ssh_--help.golden +++ b/cli/testdata/coder_ssh_--help.golden @@ -30,6 +30,12 @@ OPTIONS: -l, --log-dir string, $CODER_SSH_LOG_DIR Specify the directory containing SSH diagnostic log files. + --network-info-dir string + Specifies a directory to write network information periodically. + + --network-info-interval duration (default: 5s) + Specifies the interval to update network information. + --no-wait bool, $CODER_SSH_NO_WAIT Enter workspace immediately after the agent has connected. This is the default if the template has configured the agent startup script diff --git a/cli/vscodessh.go b/cli/vscodessh.go index d64e49c674a01..630c405241d17 100644 --- a/cli/vscodessh.go +++ b/cli/vscodessh.go @@ -2,21 +2,17 @@ package cli import ( "context" - "encoding/json" "fmt" "io" "net/http" "net/url" "os" "path/filepath" - "strconv" "strings" "time" "github.com/spf13/afero" "golang.org/x/xerrors" - "tailscale.com/tailcfg" - "tailscale.com/types/netlogtype" "cdr.dev/slog" "cdr.dev/slog/sloggers/sloghuman" @@ -83,11 +79,6 @@ func (r *RootCmd) vscodeSSH() *serpent.Command { ctx, cancel := context.WithCancel(inv.Context()) defer cancel() - err = fs.MkdirAll(networkInfoDir, 0o700) - if err != nil { - return xerrors.Errorf("mkdir: %w", err) - } - client := codersdk.New(serverURL) client.SetSessionToken(string(sessionToken)) @@ -155,20 +146,13 @@ func (r *RootCmd) vscodeSSH() *serpent.Command { } } - // The VS Code extension obtains the PID of the SSH process to - // read files to display logs and network info. - // - // We get the parent PID because it's assumed `ssh` is calling this - // command via the ProxyCommand SSH option. - pid := os.Getppid() - // Use a stripped down writer that doesn't sync, otherwise you get // "failed to sync sloghuman: sync /dev/stderr: The handle is // invalid" on Windows. Syncing isn't required for stdout/stderr // anyways. logger := inv.Logger.AppendSinks(sloghuman.Sink(slogWriter{w: inv.Stderr})).Leveled(slog.LevelDebug) if logDir != "" { - logFilePath := filepath.Join(logDir, fmt.Sprintf("%d.log", pid)) + logFilePath := filepath.Join(logDir, fmt.Sprintf("%d.log", os.Getppid())) logFile, err := fs.OpenFile(logFilePath, os.O_CREATE|os.O_WRONLY, 0o600) if err != nil { return xerrors.Errorf("open log file %q: %w", logFilePath, err) @@ -212,61 +196,10 @@ func (r *RootCmd) vscodeSSH() *serpent.Command { _, _ = io.Copy(rawSSH, inv.Stdin) }() - // The VS Code extension obtains the PID of the SSH process to - // read the file below which contains network information to display. - // - // We get the parent PID because it's assumed `ssh` is calling this - // command via the ProxyCommand SSH option. - networkInfoFilePath := filepath.Join(networkInfoDir, fmt.Sprintf("%d.json", pid)) - - var ( - firstErrTime time.Time - errCh = make(chan error, 1) - ) - cb := func(start, end time.Time, virtual, _ map[netlogtype.Connection]netlogtype.Counts) { - sendErr := func(tolerate bool, err error) { - logger.Error(ctx, "collect network stats", slog.Error(err)) - // Tolerate up to 1 minute of errors. - if tolerate { - if firstErrTime.IsZero() { - logger.Info(ctx, "tolerating network stats errors for up to 1 minute") - firstErrTime = time.Now() - } - if time.Since(firstErrTime) < time.Minute { - return - } - } - - select { - case errCh <- err: - default: - } - } - - stats, err := collectNetworkStats(ctx, agentConn, start, end, virtual) - if err != nil { - sendErr(true, err) - return - } - - rawStats, err := json.Marshal(stats) - if err != nil { - sendErr(false, err) - return - } - err = afero.WriteFile(fs, networkInfoFilePath, rawStats, 0o600) - if err != nil { - sendErr(false, err) - return - } - - firstErrTime = time.Time{} + errCh, err := setStatsCallback(ctx, agentConn, logger, networkInfoDir, networkInfoInterval) + if err != nil { + return err } - - now := time.Now() - cb(now, now.Add(time.Nanosecond), map[netlogtype.Connection]netlogtype.Counts{}, map[netlogtype.Connection]netlogtype.Counts{}) - agentConn.SetConnStatsCallback(networkInfoInterval, 2048, cb) - select { case <-ctx.Done(): return nil @@ -323,80 +256,3 @@ var _ io.Writer = slogWriter{} func (s slogWriter) Write(p []byte) (n int, err error) { return s.w.Write(p) } - -type sshNetworkStats struct { - P2P bool `json:"p2p"` - Latency float64 `json:"latency"` - PreferredDERP string `json:"preferred_derp"` - DERPLatency map[string]float64 `json:"derp_latency"` - UploadBytesSec int64 `json:"upload_bytes_sec"` - DownloadBytesSec int64 `json:"download_bytes_sec"` -} - -func collectNetworkStats(ctx context.Context, agentConn *workspacesdk.AgentConn, start, end time.Time, counts map[netlogtype.Connection]netlogtype.Counts) (*sshNetworkStats, error) { - latency, p2p, pingResult, err := agentConn.Ping(ctx) - if err != nil { - return nil, err - } - node := agentConn.Node() - derpMap := agentConn.DERPMap() - derpLatency := map[string]float64{} - - // Convert DERP region IDs to friendly names for display in the UI. - for rawRegion, latency := range node.DERPLatency { - regionParts := strings.SplitN(rawRegion, "-", 2) - regionID, err := strconv.Atoi(regionParts[0]) - if err != nil { - continue - } - region, found := derpMap.Regions[regionID] - if !found { - // It's possible that a workspace agent is using an old DERPMap - // and reports regions that do not exist. If that's the case, - // report the region as unknown! - region = &tailcfg.DERPRegion{ - RegionID: regionID, - RegionName: fmt.Sprintf("Unnamed %d", regionID), - } - } - // Convert the microseconds to milliseconds. - derpLatency[region.RegionName] = latency * 1000 - } - - totalRx := uint64(0) - totalTx := uint64(0) - for _, stat := range counts { - totalRx += stat.RxBytes - totalTx += stat.TxBytes - } - // Tracking the time since last request is required because - // ExtractTrafficStats() resets its counters after each call. - dur := end.Sub(start) - uploadSecs := float64(totalTx) / dur.Seconds() - downloadSecs := float64(totalRx) / dur.Seconds() - - // Sometimes the preferred DERP doesn't match the one we're actually - // connected with. Perhaps because the agent prefers a different DERP and - // we're using that server instead. - preferredDerpID := node.PreferredDERP - if pingResult.DERPRegionID != 0 { - preferredDerpID = pingResult.DERPRegionID - } - preferredDerp, ok := derpMap.Regions[preferredDerpID] - preferredDerpName := fmt.Sprintf("Unnamed %d", preferredDerpID) - if ok { - preferredDerpName = preferredDerp.RegionName - } - if _, ok := derpLatency[preferredDerpName]; !ok { - derpLatency[preferredDerpName] = 0 - } - - return &sshNetworkStats{ - P2P: p2p, - Latency: float64(latency.Microseconds()) / 1000, - PreferredDERP: preferredDerpName, - DERPLatency: derpLatency, - UploadBytesSec: int64(uploadSecs), - DownloadBytesSec: int64(downloadSecs), - }, nil -} diff --git a/docs/reference/cli/ssh.md b/docs/reference/cli/ssh.md index 72513e0c9ecdc..74e28837ad7e4 100644 --- a/docs/reference/cli/ssh.md +++ b/docs/reference/cli/ssh.md @@ -103,6 +103,23 @@ Enable remote port forwarding (remote_port:local_address:local_port). Set environment variable(s) for session (key1=value1,key2=value2,...). +### --network-info-dir + +| | | +|------|---------------------| +| Type | string | + +Specifies a directory to write network information periodically. + +### --network-info-interval + +| | | +|---------|-----------------------| +| Type | duration | +| Default | 5s | + +Specifies the interval to update network information. + ### --disable-autostart | | |