diff --git a/cli/ssh.go b/cli/ssh.go
index 7a1d5940bfd01..ea03916e3c293 100644
--- a/cli/ssh.go
+++ b/cli/ssh.go
@@ -3,6 +3,7 @@ package cli
import (
"bytes"
"context"
+ "encoding/json"
"errors"
"fmt"
"io"
@@ -13,6 +14,7 @@ import (
"os/exec"
"path/filepath"
"slices"
+ "strconv"
"strings"
"sync"
"time"
@@ -21,11 +23,14 @@ import (
"github.com/gofrs/flock"
"github.com/google/uuid"
"github.com/mattn/go-isatty"
+ "github.com/spf13/afero"
gossh "golang.org/x/crypto/ssh"
gosshagent "golang.org/x/crypto/ssh/agent"
"golang.org/x/term"
"golang.org/x/xerrors"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
+ "tailscale.com/tailcfg"
+ "tailscale.com/types/netlogtype"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
@@ -55,19 +60,21 @@ var (
func (r *RootCmd) ssh() *serpent.Command {
var (
- stdio bool
- forwardAgent bool
- forwardGPG bool
- identityAgent string
- wsPollInterval time.Duration
- waitEnum string
- noWait bool
- logDirPath string
- remoteForwards []string
- env []string
- usageApp string
- disableAutostart bool
- appearanceConfig codersdk.AppearanceConfig
+ stdio bool
+ forwardAgent bool
+ forwardGPG bool
+ identityAgent string
+ wsPollInterval time.Duration
+ waitEnum string
+ noWait bool
+ logDirPath string
+ remoteForwards []string
+ env []string
+ usageApp string
+ disableAutostart bool
+ appearanceConfig codersdk.AppearanceConfig
+ networkInfoDir string
+ networkInfoInterval time.Duration
)
client := new(codersdk.Client)
cmd := &serpent.Command{
@@ -284,13 +291,21 @@ func (r *RootCmd) ssh() *serpent.Command {
return err
}
+ var errCh <-chan error
+ if networkInfoDir != "" {
+ errCh, err = setStatsCallback(ctx, conn, logger, networkInfoDir, networkInfoInterval)
+ if err != nil {
+ return err
+ }
+ }
+
wg.Add(1)
go func() {
defer wg.Done()
watchAndClose(ctx, func() error {
stack.close(xerrors.New("watchAndClose"))
return nil
- }, logger, client, workspace)
+ }, logger, client, workspace, errCh)
}()
copier.copy(&wg)
return nil
@@ -312,6 +327,14 @@ func (r *RootCmd) ssh() *serpent.Command {
return err
}
+ var errCh <-chan error
+ if networkInfoDir != "" {
+ errCh, err = setStatsCallback(ctx, conn, logger, networkInfoDir, networkInfoInterval)
+ if err != nil {
+ return err
+ }
+ }
+
wg.Add(1)
go func() {
defer wg.Done()
@@ -324,6 +347,7 @@ func (r *RootCmd) ssh() *serpent.Command {
logger,
client,
workspace,
+ errCh,
)
}()
@@ -540,6 +564,17 @@ func (r *RootCmd) ssh() *serpent.Command {
Value: serpent.StringOf(&usageApp),
Hidden: true,
},
+ {
+ Flag: "network-info-dir",
+ Description: "Specifies a directory to write network information periodically.",
+ Value: serpent.StringOf(&networkInfoDir),
+ },
+ {
+ Flag: "network-info-interval",
+ Description: "Specifies the interval to update network information.",
+ Default: "5s",
+ Value: serpent.DurationOf(&networkInfoInterval),
+ },
sshDisableAutostartOption(serpent.BoolOf(&disableAutostart)),
}
return cmd
@@ -555,7 +590,7 @@ func (r *RootCmd) ssh() *serpent.Command {
// will usually not propagate.
//
// See: https://github.com/coder/coder/issues/6180
-func watchAndClose(ctx context.Context, closer func() error, logger slog.Logger, client *codersdk.Client, workspace codersdk.Workspace) {
+func watchAndClose(ctx context.Context, closer func() error, logger slog.Logger, client *codersdk.Client, workspace codersdk.Workspace, errCh <-chan error) {
// Ensure session is ended on both context cancellation
// and workspace stop.
defer func() {
@@ -606,6 +641,9 @@ startWatchLoop:
logger.Info(ctx, "workspace stopped")
return
}
+ case err := <-errCh:
+ logger.Error(ctx, "failed to collect network stats", slog.Error(err))
+ return
}
}
}
@@ -1144,3 +1182,159 @@ func getUsageAppName(usageApp string) codersdk.UsageAppName {
return codersdk.UsageAppNameSSH
}
+
+func setStatsCallback(
+ ctx context.Context,
+ agentConn *workspacesdk.AgentConn,
+ logger slog.Logger,
+ networkInfoDir string,
+ networkInfoInterval time.Duration,
+) (<-chan error, error) {
+ fs, ok := ctx.Value("fs").(afero.Fs)
+ if !ok {
+ fs = afero.NewOsFs()
+ }
+ if err := fs.MkdirAll(networkInfoDir, 0o700); err != nil {
+ return nil, xerrors.Errorf("mkdir: %w", err)
+ }
+
+ // The VS Code extension obtains the PID of the SSH process to
+ // read files to display logs and network info.
+ //
+ // We get the parent PID because it's assumed `ssh` is calling this
+ // command via the ProxyCommand SSH option.
+ pid := os.Getppid()
+
+ // The VS Code extension obtains the PID of the SSH process to
+ // read the file below which contains network information to display.
+ //
+ // We get the parent PID because it's assumed `ssh` is calling this
+ // command via the ProxyCommand SSH option.
+ networkInfoFilePath := filepath.Join(networkInfoDir, fmt.Sprintf("%d.json", pid))
+
+ var (
+ firstErrTime time.Time
+ errCh = make(chan error, 1)
+ )
+ cb := func(start, end time.Time, virtual, _ map[netlogtype.Connection]netlogtype.Counts) {
+ sendErr := func(tolerate bool, err error) {
+ logger.Error(ctx, "collect network stats", slog.Error(err))
+ // Tolerate up to 1 minute of errors.
+ if tolerate {
+ if firstErrTime.IsZero() {
+ logger.Info(ctx, "tolerating network stats errors for up to 1 minute")
+ firstErrTime = time.Now()
+ }
+ if time.Since(firstErrTime) < time.Minute {
+ return
+ }
+ }
+
+ select {
+ case errCh <- err:
+ default:
+ }
+ }
+
+ stats, err := collectNetworkStats(ctx, agentConn, start, end, virtual)
+ if err != nil {
+ sendErr(true, err)
+ return
+ }
+
+ rawStats, err := json.Marshal(stats)
+ if err != nil {
+ sendErr(false, err)
+ return
+ }
+ err = afero.WriteFile(fs, networkInfoFilePath, rawStats, 0o600)
+ if err != nil {
+ sendErr(false, err)
+ return
+ }
+
+ firstErrTime = time.Time{}
+ }
+
+ now := time.Now()
+ cb(now, now.Add(time.Nanosecond), map[netlogtype.Connection]netlogtype.Counts{}, map[netlogtype.Connection]netlogtype.Counts{})
+ agentConn.SetConnStatsCallback(networkInfoInterval, 2048, cb)
+ return errCh, nil
+}
+
+type sshNetworkStats struct {
+ P2P bool `json:"p2p"`
+ Latency float64 `json:"latency"`
+ PreferredDERP string `json:"preferred_derp"`
+ DERPLatency map[string]float64 `json:"derp_latency"`
+ UploadBytesSec int64 `json:"upload_bytes_sec"`
+ DownloadBytesSec int64 `json:"download_bytes_sec"`
+}
+
+func collectNetworkStats(ctx context.Context, agentConn *workspacesdk.AgentConn, start, end time.Time, counts map[netlogtype.Connection]netlogtype.Counts) (*sshNetworkStats, error) {
+ latency, p2p, pingResult, err := agentConn.Ping(ctx)
+ if err != nil {
+ return nil, err
+ }
+ node := agentConn.Node()
+ derpMap := agentConn.DERPMap()
+ derpLatency := map[string]float64{}
+
+ // Convert DERP region IDs to friendly names for display in the UI.
+ for rawRegion, latency := range node.DERPLatency {
+ regionParts := strings.SplitN(rawRegion, "-", 2)
+ regionID, err := strconv.Atoi(regionParts[0])
+ if err != nil {
+ continue
+ }
+ region, found := derpMap.Regions[regionID]
+ if !found {
+ // It's possible that a workspace agent is using an old DERPMap
+ // and reports regions that do not exist. If that's the case,
+ // report the region as unknown!
+ region = &tailcfg.DERPRegion{
+ RegionID: regionID,
+ RegionName: fmt.Sprintf("Unnamed %d", regionID),
+ }
+ }
+ // Convert the microseconds to milliseconds.
+ derpLatency[region.RegionName] = latency * 1000
+ }
+
+ totalRx := uint64(0)
+ totalTx := uint64(0)
+ for _, stat := range counts {
+ totalRx += stat.RxBytes
+ totalTx += stat.TxBytes
+ }
+ // Tracking the time since last request is required because
+ // ExtractTrafficStats() resets its counters after each call.
+ dur := end.Sub(start)
+ uploadSecs := float64(totalTx) / dur.Seconds()
+ downloadSecs := float64(totalRx) / dur.Seconds()
+
+ // Sometimes the preferred DERP doesn't match the one we're actually
+ // connected with. Perhaps because the agent prefers a different DERP and
+ // we're using that server instead.
+ preferredDerpID := node.PreferredDERP
+ if pingResult.DERPRegionID != 0 {
+ preferredDerpID = pingResult.DERPRegionID
+ }
+ preferredDerp, ok := derpMap.Regions[preferredDerpID]
+ preferredDerpName := fmt.Sprintf("Unnamed %d", preferredDerpID)
+ if ok {
+ preferredDerpName = preferredDerp.RegionName
+ }
+ if _, ok := derpLatency[preferredDerpName]; !ok {
+ derpLatency[preferredDerpName] = 0
+ }
+
+ return &sshNetworkStats{
+ P2P: p2p,
+ Latency: float64(latency.Microseconds()) / 1000,
+ PreferredDERP: preferredDerpName,
+ DERPLatency: derpLatency,
+ UploadBytesSec: int64(uploadSecs),
+ DownloadBytesSec: int64(downloadSecs),
+ }, nil
+}
diff --git a/cli/ssh_test.go b/cli/ssh_test.go
index bd107852251f7..9a16460ea5fe4 100644
--- a/cli/ssh_test.go
+++ b/cli/ssh_test.go
@@ -24,6 +24,7 @@ import (
"time"
"github.com/google/uuid"
+ "github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
@@ -438,6 +439,78 @@ func TestSSH(t *testing.T) {
<-cmdDone
})
+ t.Run("NetworkInfo", func(t *testing.T) {
+ t.Parallel()
+ client, workspace, agentToken := setupWorkspaceForAgent(t)
+ _, _ = tGoContext(t, func(ctx context.Context) {
+ // Run this async so the SSH command has to wait for
+ // the build and agent to connect!
+ _ = agenttest.New(t, client.URL, agentToken)
+ <-ctx.Done()
+ })
+
+ clientOutput, clientInput := io.Pipe()
+ serverOutput, serverInput := io.Pipe()
+ defer func() {
+ for _, c := range []io.Closer{clientOutput, clientInput, serverOutput, serverInput} {
+ _ = c.Close()
+ }
+ }()
+
+ ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
+ defer cancel()
+
+ fs := afero.NewMemMapFs()
+ //nolint:revive,staticcheck
+ ctx = context.WithValue(ctx, "fs", fs)
+
+ inv, root := clitest.New(t, "ssh", "--stdio", workspace.Name, "--network-info-dir", "/net", "--network-info-interval", "25ms")
+ clitest.SetupConfig(t, client, root)
+ inv.Stdin = clientOutput
+ inv.Stdout = serverInput
+ inv.Stderr = io.Discard
+
+ cmdDone := tGo(t, func() {
+ err := inv.WithContext(ctx).Run()
+ assert.NoError(t, err)
+ })
+
+ conn, channels, requests, err := ssh.NewClientConn(&stdioConn{
+ Reader: serverOutput,
+ Writer: clientInput,
+ }, "", &ssh.ClientConfig{
+ // #nosec
+ HostKeyCallback: ssh.InsecureIgnoreHostKey(),
+ })
+ require.NoError(t, err)
+ defer conn.Close()
+
+ sshClient := ssh.NewClient(conn, channels, requests)
+ session, err := sshClient.NewSession()
+ require.NoError(t, err)
+ defer session.Close()
+
+ command := "sh -c exit"
+ if runtime.GOOS == "windows" {
+ command = "cmd.exe /c exit"
+ }
+ err = session.Run(command)
+ require.NoError(t, err)
+ err = sshClient.Close()
+ require.NoError(t, err)
+ _ = clientOutput.Close()
+
+ assert.Eventually(t, func() bool {
+ entries, err := afero.ReadDir(fs, "/net")
+ if err != nil {
+ return false
+ }
+ return len(entries) > 0
+ }, testutil.WaitLong, testutil.IntervalFast)
+
+ <-cmdDone
+ })
+
t.Run("Stdio_StartStoppedWorkspace_CleanStdout", func(t *testing.T) {
t.Parallel()
diff --git a/cli/testdata/coder_ssh_--help.golden b/cli/testdata/coder_ssh_--help.golden
index 80aaa3c204fda..d847e9d7abb03 100644
--- a/cli/testdata/coder_ssh_--help.golden
+++ b/cli/testdata/coder_ssh_--help.golden
@@ -30,6 +30,12 @@ OPTIONS:
-l, --log-dir string, $CODER_SSH_LOG_DIR
Specify the directory containing SSH diagnostic log files.
+ --network-info-dir string
+ Specifies a directory to write network information periodically.
+
+ --network-info-interval duration (default: 5s)
+ Specifies the interval to update network information.
+
--no-wait bool, $CODER_SSH_NO_WAIT
Enter workspace immediately after the agent has connected. This is the
default if the template has configured the agent startup script
diff --git a/cli/vscodessh.go b/cli/vscodessh.go
index d64e49c674a01..630c405241d17 100644
--- a/cli/vscodessh.go
+++ b/cli/vscodessh.go
@@ -2,21 +2,17 @@ package cli
import (
"context"
- "encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
- "strconv"
"strings"
"time"
"github.com/spf13/afero"
"golang.org/x/xerrors"
- "tailscale.com/tailcfg"
- "tailscale.com/types/netlogtype"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
@@ -83,11 +79,6 @@ func (r *RootCmd) vscodeSSH() *serpent.Command {
ctx, cancel := context.WithCancel(inv.Context())
defer cancel()
- err = fs.MkdirAll(networkInfoDir, 0o700)
- if err != nil {
- return xerrors.Errorf("mkdir: %w", err)
- }
-
client := codersdk.New(serverURL)
client.SetSessionToken(string(sessionToken))
@@ -155,20 +146,13 @@ func (r *RootCmd) vscodeSSH() *serpent.Command {
}
}
- // The VS Code extension obtains the PID of the SSH process to
- // read files to display logs and network info.
- //
- // We get the parent PID because it's assumed `ssh` is calling this
- // command via the ProxyCommand SSH option.
- pid := os.Getppid()
-
// Use a stripped down writer that doesn't sync, otherwise you get
// "failed to sync sloghuman: sync /dev/stderr: The handle is
// invalid" on Windows. Syncing isn't required for stdout/stderr
// anyways.
logger := inv.Logger.AppendSinks(sloghuman.Sink(slogWriter{w: inv.Stderr})).Leveled(slog.LevelDebug)
if logDir != "" {
- logFilePath := filepath.Join(logDir, fmt.Sprintf("%d.log", pid))
+ logFilePath := filepath.Join(logDir, fmt.Sprintf("%d.log", os.Getppid()))
logFile, err := fs.OpenFile(logFilePath, os.O_CREATE|os.O_WRONLY, 0o600)
if err != nil {
return xerrors.Errorf("open log file %q: %w", logFilePath, err)
@@ -212,61 +196,10 @@ func (r *RootCmd) vscodeSSH() *serpent.Command {
_, _ = io.Copy(rawSSH, inv.Stdin)
}()
- // The VS Code extension obtains the PID of the SSH process to
- // read the file below which contains network information to display.
- //
- // We get the parent PID because it's assumed `ssh` is calling this
- // command via the ProxyCommand SSH option.
- networkInfoFilePath := filepath.Join(networkInfoDir, fmt.Sprintf("%d.json", pid))
-
- var (
- firstErrTime time.Time
- errCh = make(chan error, 1)
- )
- cb := func(start, end time.Time, virtual, _ map[netlogtype.Connection]netlogtype.Counts) {
- sendErr := func(tolerate bool, err error) {
- logger.Error(ctx, "collect network stats", slog.Error(err))
- // Tolerate up to 1 minute of errors.
- if tolerate {
- if firstErrTime.IsZero() {
- logger.Info(ctx, "tolerating network stats errors for up to 1 minute")
- firstErrTime = time.Now()
- }
- if time.Since(firstErrTime) < time.Minute {
- return
- }
- }
-
- select {
- case errCh <- err:
- default:
- }
- }
-
- stats, err := collectNetworkStats(ctx, agentConn, start, end, virtual)
- if err != nil {
- sendErr(true, err)
- return
- }
-
- rawStats, err := json.Marshal(stats)
- if err != nil {
- sendErr(false, err)
- return
- }
- err = afero.WriteFile(fs, networkInfoFilePath, rawStats, 0o600)
- if err != nil {
- sendErr(false, err)
- return
- }
-
- firstErrTime = time.Time{}
+ errCh, err := setStatsCallback(ctx, agentConn, logger, networkInfoDir, networkInfoInterval)
+ if err != nil {
+ return err
}
-
- now := time.Now()
- cb(now, now.Add(time.Nanosecond), map[netlogtype.Connection]netlogtype.Counts{}, map[netlogtype.Connection]netlogtype.Counts{})
- agentConn.SetConnStatsCallback(networkInfoInterval, 2048, cb)
-
select {
case <-ctx.Done():
return nil
@@ -323,80 +256,3 @@ var _ io.Writer = slogWriter{}
func (s slogWriter) Write(p []byte) (n int, err error) {
return s.w.Write(p)
}
-
-type sshNetworkStats struct {
- P2P bool `json:"p2p"`
- Latency float64 `json:"latency"`
- PreferredDERP string `json:"preferred_derp"`
- DERPLatency map[string]float64 `json:"derp_latency"`
- UploadBytesSec int64 `json:"upload_bytes_sec"`
- DownloadBytesSec int64 `json:"download_bytes_sec"`
-}
-
-func collectNetworkStats(ctx context.Context, agentConn *workspacesdk.AgentConn, start, end time.Time, counts map[netlogtype.Connection]netlogtype.Counts) (*sshNetworkStats, error) {
- latency, p2p, pingResult, err := agentConn.Ping(ctx)
- if err != nil {
- return nil, err
- }
- node := agentConn.Node()
- derpMap := agentConn.DERPMap()
- derpLatency := map[string]float64{}
-
- // Convert DERP region IDs to friendly names for display in the UI.
- for rawRegion, latency := range node.DERPLatency {
- regionParts := strings.SplitN(rawRegion, "-", 2)
- regionID, err := strconv.Atoi(regionParts[0])
- if err != nil {
- continue
- }
- region, found := derpMap.Regions[regionID]
- if !found {
- // It's possible that a workspace agent is using an old DERPMap
- // and reports regions that do not exist. If that's the case,
- // report the region as unknown!
- region = &tailcfg.DERPRegion{
- RegionID: regionID,
- RegionName: fmt.Sprintf("Unnamed %d", regionID),
- }
- }
- // Convert the microseconds to milliseconds.
- derpLatency[region.RegionName] = latency * 1000
- }
-
- totalRx := uint64(0)
- totalTx := uint64(0)
- for _, stat := range counts {
- totalRx += stat.RxBytes
- totalTx += stat.TxBytes
- }
- // Tracking the time since last request is required because
- // ExtractTrafficStats() resets its counters after each call.
- dur := end.Sub(start)
- uploadSecs := float64(totalTx) / dur.Seconds()
- downloadSecs := float64(totalRx) / dur.Seconds()
-
- // Sometimes the preferred DERP doesn't match the one we're actually
- // connected with. Perhaps because the agent prefers a different DERP and
- // we're using that server instead.
- preferredDerpID := node.PreferredDERP
- if pingResult.DERPRegionID != 0 {
- preferredDerpID = pingResult.DERPRegionID
- }
- preferredDerp, ok := derpMap.Regions[preferredDerpID]
- preferredDerpName := fmt.Sprintf("Unnamed %d", preferredDerpID)
- if ok {
- preferredDerpName = preferredDerp.RegionName
- }
- if _, ok := derpLatency[preferredDerpName]; !ok {
- derpLatency[preferredDerpName] = 0
- }
-
- return &sshNetworkStats{
- P2P: p2p,
- Latency: float64(latency.Microseconds()) / 1000,
- PreferredDERP: preferredDerpName,
- DERPLatency: derpLatency,
- UploadBytesSec: int64(uploadSecs),
- DownloadBytesSec: int64(downloadSecs),
- }, nil
-}
diff --git a/docs/reference/cli/ssh.md b/docs/reference/cli/ssh.md
index 72513e0c9ecdc..74e28837ad7e4 100644
--- a/docs/reference/cli/ssh.md
+++ b/docs/reference/cli/ssh.md
@@ -103,6 +103,23 @@ Enable remote port forwarding (remote_port:local_address:local_port).
Set environment variable(s) for session (key1=value1,key2=value2,...).
+### --network-info-dir
+
+| | |
+|------|---------------------|
+| Type | string
|
+
+Specifies a directory to write network information periodically.
+
+### --network-info-interval
+
+| | |
+|---------|-----------------------|
+| Type | duration
|
+| Default | 5s
|
+
+Specifies the interval to update network information.
+
### --disable-autostart
| | |