Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 838ee3b

Browse files
authored
feat: add --network-info-dir and --network-info-interval flags to coder ssh (#16078)
This is the first in a series of PRs to enable `coder ssh` to replace `coder vscodessh`. This change adds `--network-info-dir` and `--network-info-interval` flags to the `ssh` subcommand. These were formerly only available with the `vscodessh` subcommand. Subsequent PRs will add a `--ssh-host-prefix` flag to the ssh subcommand, and adjust the log file naming to contain the parent PID.
1 parent c2b5534 commit 838ee3b

File tree

5 files changed

+309
-163
lines changed

5 files changed

+309
-163
lines changed

cli/ssh.go

+209-15
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package cli
33
import (
44
"bytes"
55
"context"
6+
"encoding/json"
67
"errors"
78
"fmt"
89
"io"
@@ -13,6 +14,7 @@ import (
1314
"os/exec"
1415
"path/filepath"
1516
"slices"
17+
"strconv"
1618
"strings"
1719
"sync"
1820
"time"
@@ -21,11 +23,14 @@ import (
2123
"github.com/gofrs/flock"
2224
"github.com/google/uuid"
2325
"github.com/mattn/go-isatty"
26+
"github.com/spf13/afero"
2427
gossh "golang.org/x/crypto/ssh"
2528
gosshagent "golang.org/x/crypto/ssh/agent"
2629
"golang.org/x/term"
2730
"golang.org/x/xerrors"
2831
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
32+
"tailscale.com/tailcfg"
33+
"tailscale.com/types/netlogtype"
2934

3035
"cdr.dev/slog"
3136
"cdr.dev/slog/sloggers/sloghuman"
@@ -55,19 +60,21 @@ var (
5560

5661
func (r *RootCmd) ssh() *serpent.Command {
5762
var (
58-
stdio bool
59-
forwardAgent bool
60-
forwardGPG bool
61-
identityAgent string
62-
wsPollInterval time.Duration
63-
waitEnum string
64-
noWait bool
65-
logDirPath string
66-
remoteForwards []string
67-
env []string
68-
usageApp string
69-
disableAutostart bool
70-
appearanceConfig codersdk.AppearanceConfig
63+
stdio bool
64+
forwardAgent bool
65+
forwardGPG bool
66+
identityAgent string
67+
wsPollInterval time.Duration
68+
waitEnum string
69+
noWait bool
70+
logDirPath string
71+
remoteForwards []string
72+
env []string
73+
usageApp string
74+
disableAutostart bool
75+
appearanceConfig codersdk.AppearanceConfig
76+
networkInfoDir string
77+
networkInfoInterval time.Duration
7178
)
7279
client := new(codersdk.Client)
7380
cmd := &serpent.Command{
@@ -284,13 +291,21 @@ func (r *RootCmd) ssh() *serpent.Command {
284291
return err
285292
}
286293

294+
var errCh <-chan error
295+
if networkInfoDir != "" {
296+
errCh, err = setStatsCallback(ctx, conn, logger, networkInfoDir, networkInfoInterval)
297+
if err != nil {
298+
return err
299+
}
300+
}
301+
287302
wg.Add(1)
288303
go func() {
289304
defer wg.Done()
290305
watchAndClose(ctx, func() error {
291306
stack.close(xerrors.New("watchAndClose"))
292307
return nil
293-
}, logger, client, workspace)
308+
}, logger, client, workspace, errCh)
294309
}()
295310
copier.copy(&wg)
296311
return nil
@@ -312,6 +327,14 @@ func (r *RootCmd) ssh() *serpent.Command {
312327
return err
313328
}
314329

330+
var errCh <-chan error
331+
if networkInfoDir != "" {
332+
errCh, err = setStatsCallback(ctx, conn, logger, networkInfoDir, networkInfoInterval)
333+
if err != nil {
334+
return err
335+
}
336+
}
337+
315338
wg.Add(1)
316339
go func() {
317340
defer wg.Done()
@@ -324,6 +347,7 @@ func (r *RootCmd) ssh() *serpent.Command {
324347
logger,
325348
client,
326349
workspace,
350+
errCh,
327351
)
328352
}()
329353

@@ -540,6 +564,17 @@ func (r *RootCmd) ssh() *serpent.Command {
540564
Value: serpent.StringOf(&usageApp),
541565
Hidden: true,
542566
},
567+
{
568+
Flag: "network-info-dir",
569+
Description: "Specifies a directory to write network information periodically.",
570+
Value: serpent.StringOf(&networkInfoDir),
571+
},
572+
{
573+
Flag: "network-info-interval",
574+
Description: "Specifies the interval to update network information.",
575+
Default: "5s",
576+
Value: serpent.DurationOf(&networkInfoInterval),
577+
},
543578
sshDisableAutostartOption(serpent.BoolOf(&disableAutostart)),
544579
}
545580
return cmd
@@ -555,7 +590,7 @@ func (r *RootCmd) ssh() *serpent.Command {
555590
// will usually not propagate.
556591
//
557592
// See: https://github.com/coder/coder/issues/6180
558-
func watchAndClose(ctx context.Context, closer func() error, logger slog.Logger, client *codersdk.Client, workspace codersdk.Workspace) {
593+
func watchAndClose(ctx context.Context, closer func() error, logger slog.Logger, client *codersdk.Client, workspace codersdk.Workspace, errCh <-chan error) {
559594
// Ensure session is ended on both context cancellation
560595
// and workspace stop.
561596
defer func() {
@@ -606,6 +641,9 @@ startWatchLoop:
606641
logger.Info(ctx, "workspace stopped")
607642
return
608643
}
644+
case err := <-errCh:
645+
logger.Error(ctx, "failed to collect network stats", slog.Error(err))
646+
return
609647
}
610648
}
611649
}
@@ -1144,3 +1182,159 @@ func getUsageAppName(usageApp string) codersdk.UsageAppName {
11441182

11451183
return codersdk.UsageAppNameSSH
11461184
}
1185+
1186+
func setStatsCallback(
1187+
ctx context.Context,
1188+
agentConn *workspacesdk.AgentConn,
1189+
logger slog.Logger,
1190+
networkInfoDir string,
1191+
networkInfoInterval time.Duration,
1192+
) (<-chan error, error) {
1193+
fs, ok := ctx.Value("fs").(afero.Fs)
1194+
if !ok {
1195+
fs = afero.NewOsFs()
1196+
}
1197+
if err := fs.MkdirAll(networkInfoDir, 0o700); err != nil {
1198+
return nil, xerrors.Errorf("mkdir: %w", err)
1199+
}
1200+
1201+
// The VS Code extension obtains the PID of the SSH process to
1202+
// read files to display logs and network info.
1203+
//
1204+
// We get the parent PID because it's assumed `ssh` is calling this
1205+
// command via the ProxyCommand SSH option.
1206+
pid := os.Getppid()
1207+
1208+
// The VS Code extension obtains the PID of the SSH process to
1209+
// read the file below which contains network information to display.
1210+
//
1211+
// We get the parent PID because it's assumed `ssh` is calling this
1212+
// command via the ProxyCommand SSH option.
1213+
networkInfoFilePath := filepath.Join(networkInfoDir, fmt.Sprintf("%d.json", pid))
1214+
1215+
var (
1216+
firstErrTime time.Time
1217+
errCh = make(chan error, 1)
1218+
)
1219+
cb := func(start, end time.Time, virtual, _ map[netlogtype.Connection]netlogtype.Counts) {
1220+
sendErr := func(tolerate bool, err error) {
1221+
logger.Error(ctx, "collect network stats", slog.Error(err))
1222+
// Tolerate up to 1 minute of errors.
1223+
if tolerate {
1224+
if firstErrTime.IsZero() {
1225+
logger.Info(ctx, "tolerating network stats errors for up to 1 minute")
1226+
firstErrTime = time.Now()
1227+
}
1228+
if time.Since(firstErrTime) < time.Minute {
1229+
return
1230+
}
1231+
}
1232+
1233+
select {
1234+
case errCh <- err:
1235+
default:
1236+
}
1237+
}
1238+
1239+
stats, err := collectNetworkStats(ctx, agentConn, start, end, virtual)
1240+
if err != nil {
1241+
sendErr(true, err)
1242+
return
1243+
}
1244+
1245+
rawStats, err := json.Marshal(stats)
1246+
if err != nil {
1247+
sendErr(false, err)
1248+
return
1249+
}
1250+
err = afero.WriteFile(fs, networkInfoFilePath, rawStats, 0o600)
1251+
if err != nil {
1252+
sendErr(false, err)
1253+
return
1254+
}
1255+
1256+
firstErrTime = time.Time{}
1257+
}
1258+
1259+
now := time.Now()
1260+
cb(now, now.Add(time.Nanosecond), map[netlogtype.Connection]netlogtype.Counts{}, map[netlogtype.Connection]netlogtype.Counts{})
1261+
agentConn.SetConnStatsCallback(networkInfoInterval, 2048, cb)
1262+
return errCh, nil
1263+
}
1264+
1265+
// sshNetworkStats is the JSON document written to the network info file. It
// is populated by collectNetworkStats.
type sshNetworkStats struct {
	// P2P is the peer-to-peer flag returned by the agent ping.
	P2P bool `json:"p2p"`
	// Latency is the agent ping round-trip time in milliseconds.
	Latency float64 `json:"latency"`
	// PreferredDERP is the name of the DERP region in use, or
	// "Unnamed <id>" when the region is missing from the DERP map.
	PreferredDERP string `json:"preferred_derp"`
	// DERPLatency maps DERP region names to the node's reported latency
	// multiplied by 1000.
	DERPLatency map[string]float64 `json:"derp_latency"`
	// UploadBytesSec is the transmitted bytes per second over the sample window.
	UploadBytesSec int64 `json:"upload_bytes_sec"`
	// DownloadBytesSec is the received bytes per second over the sample window.
	DownloadBytesSec int64 `json:"download_bytes_sec"`
}
1273+
1274+
func collectNetworkStats(ctx context.Context, agentConn *workspacesdk.AgentConn, start, end time.Time, counts map[netlogtype.Connection]netlogtype.Counts) (*sshNetworkStats, error) {
1275+
latency, p2p, pingResult, err := agentConn.Ping(ctx)
1276+
if err != nil {
1277+
return nil, err
1278+
}
1279+
node := agentConn.Node()
1280+
derpMap := agentConn.DERPMap()
1281+
derpLatency := map[string]float64{}
1282+
1283+
// Convert DERP region IDs to friendly names for display in the UI.
1284+
for rawRegion, latency := range node.DERPLatency {
1285+
regionParts := strings.SplitN(rawRegion, "-", 2)
1286+
regionID, err := strconv.Atoi(regionParts[0])
1287+
if err != nil {
1288+
continue
1289+
}
1290+
region, found := derpMap.Regions[regionID]
1291+
if !found {
1292+
// It's possible that a workspace agent is using an old DERPMap
1293+
// and reports regions that do not exist. If that's the case,
1294+
// report the region as unknown!
1295+
region = &tailcfg.DERPRegion{
1296+
RegionID: regionID,
1297+
RegionName: fmt.Sprintf("Unnamed %d", regionID),
1298+
}
1299+
}
1300+
// Convert the microseconds to milliseconds.
1301+
derpLatency[region.RegionName] = latency * 1000
1302+
}
1303+
1304+
totalRx := uint64(0)
1305+
totalTx := uint64(0)
1306+
for _, stat := range counts {
1307+
totalRx += stat.RxBytes
1308+
totalTx += stat.TxBytes
1309+
}
1310+
// Tracking the time since last request is required because
1311+
// ExtractTrafficStats() resets its counters after each call.
1312+
dur := end.Sub(start)
1313+
uploadSecs := float64(totalTx) / dur.Seconds()
1314+
downloadSecs := float64(totalRx) / dur.Seconds()
1315+
1316+
// Sometimes the preferred DERP doesn't match the one we're actually
1317+
// connected with. Perhaps because the agent prefers a different DERP and
1318+
// we're using that server instead.
1319+
preferredDerpID := node.PreferredDERP
1320+
if pingResult.DERPRegionID != 0 {
1321+
preferredDerpID = pingResult.DERPRegionID
1322+
}
1323+
preferredDerp, ok := derpMap.Regions[preferredDerpID]
1324+
preferredDerpName := fmt.Sprintf("Unnamed %d", preferredDerpID)
1325+
if ok {
1326+
preferredDerpName = preferredDerp.RegionName
1327+
}
1328+
if _, ok := derpLatency[preferredDerpName]; !ok {
1329+
derpLatency[preferredDerpName] = 0
1330+
}
1331+
1332+
return &sshNetworkStats{
1333+
P2P: p2p,
1334+
Latency: float64(latency.Microseconds()) / 1000,
1335+
PreferredDERP: preferredDerpName,
1336+
DERPLatency: derpLatency,
1337+
UploadBytesSec: int64(uploadSecs),
1338+
DownloadBytesSec: int64(downloadSecs),
1339+
}, nil
1340+
}

cli/ssh_test.go

+73
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"time"
2525

2626
"github.com/google/uuid"
27+
"github.com/spf13/afero"
2728
"github.com/stretchr/testify/assert"
2829
"github.com/stretchr/testify/require"
2930
"golang.org/x/crypto/ssh"
@@ -452,6 +453,78 @@ func TestSSH(t *testing.T) {
452453
<-cmdDone
453454
})
454455

456+
t.Run("NetworkInfo", func(t *testing.T) {
457+
t.Parallel()
458+
client, workspace, agentToken := setupWorkspaceForAgent(t)
459+
_, _ = tGoContext(t, func(ctx context.Context) {
460+
// Run this async so the SSH command has to wait for
461+
// the build and agent to connect!
462+
_ = agenttest.New(t, client.URL, agentToken)
463+
<-ctx.Done()
464+
})
465+
466+
clientOutput, clientInput := io.Pipe()
467+
serverOutput, serverInput := io.Pipe()
468+
defer func() {
469+
for _, c := range []io.Closer{clientOutput, clientInput, serverOutput, serverInput} {
470+
_ = c.Close()
471+
}
472+
}()
473+
474+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
475+
defer cancel()
476+
477+
fs := afero.NewMemMapFs()
478+
//nolint:revive,staticcheck
479+
ctx = context.WithValue(ctx, "fs", fs)
480+
481+
inv, root := clitest.New(t, "ssh", "--stdio", workspace.Name, "--network-info-dir", "/net", "--network-info-interval", "25ms")
482+
clitest.SetupConfig(t, client, root)
483+
inv.Stdin = clientOutput
484+
inv.Stdout = serverInput
485+
inv.Stderr = io.Discard
486+
487+
cmdDone := tGo(t, func() {
488+
err := inv.WithContext(ctx).Run()
489+
assert.NoError(t, err)
490+
})
491+
492+
conn, channels, requests, err := ssh.NewClientConn(&stdioConn{
493+
Reader: serverOutput,
494+
Writer: clientInput,
495+
}, "", &ssh.ClientConfig{
496+
// #nosec
497+
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
498+
})
499+
require.NoError(t, err)
500+
defer conn.Close()
501+
502+
sshClient := ssh.NewClient(conn, channels, requests)
503+
session, err := sshClient.NewSession()
504+
require.NoError(t, err)
505+
defer session.Close()
506+
507+
command := "sh -c exit"
508+
if runtime.GOOS == "windows" {
509+
command = "cmd.exe /c exit"
510+
}
511+
err = session.Run(command)
512+
require.NoError(t, err)
513+
err = sshClient.Close()
514+
require.NoError(t, err)
515+
_ = clientOutput.Close()
516+
517+
assert.Eventually(t, func() bool {
518+
entries, err := afero.ReadDir(fs, "/net")
519+
if err != nil {
520+
return false
521+
}
522+
return len(entries) > 0
523+
}, testutil.WaitLong, testutil.IntervalFast)
524+
525+
<-cmdDone
526+
})
527+
455528
t.Run("Stdio_StartStoppedWorkspace_CleanStdout", func(t *testing.T) {
456529
t.Parallel()
457530

0 commit comments

Comments
 (0)