@@ -3,6 +3,7 @@ package cli
3
3
import (
4
4
"bytes"
5
5
"context"
6
+ "encoding/json"
6
7
"errors"
7
8
"fmt"
8
9
"io"
@@ -13,6 +14,7 @@ import (
13
14
"os/exec"
14
15
"path/filepath"
15
16
"slices"
17
+ "strconv"
16
18
"strings"
17
19
"sync"
18
20
"time"
@@ -21,11 +23,14 @@ import (
21
23
"github.com/gofrs/flock"
22
24
"github.com/google/uuid"
23
25
"github.com/mattn/go-isatty"
26
+ "github.com/spf13/afero"
24
27
gossh "golang.org/x/crypto/ssh"
25
28
gosshagent "golang.org/x/crypto/ssh/agent"
26
29
"golang.org/x/term"
27
30
"golang.org/x/xerrors"
28
31
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
32
+ "tailscale.com/tailcfg"
33
+ "tailscale.com/types/netlogtype"
29
34
30
35
"cdr.dev/slog"
31
36
"cdr.dev/slog/sloggers/sloghuman"
@@ -55,19 +60,21 @@ var (
55
60
56
61
func (r * RootCmd ) ssh () * serpent.Command {
57
62
var (
58
- stdio bool
59
- forwardAgent bool
60
- forwardGPG bool
61
- identityAgent string
62
- wsPollInterval time.Duration
63
- waitEnum string
64
- noWait bool
65
- logDirPath string
66
- remoteForwards []string
67
- env []string
68
- usageApp string
69
- disableAutostart bool
70
- appearanceConfig codersdk.AppearanceConfig
63
+ stdio bool
64
+ forwardAgent bool
65
+ forwardGPG bool
66
+ identityAgent string
67
+ wsPollInterval time.Duration
68
+ waitEnum string
69
+ noWait bool
70
+ logDirPath string
71
+ remoteForwards []string
72
+ env []string
73
+ usageApp string
74
+ disableAutostart bool
75
+ appearanceConfig codersdk.AppearanceConfig
76
+ networkInfoDir string
77
+ networkInfoInterval time.Duration
71
78
)
72
79
client := new (codersdk.Client )
73
80
cmd := & serpent.Command {
@@ -284,13 +291,21 @@ func (r *RootCmd) ssh() *serpent.Command {
284
291
return err
285
292
}
286
293
294
+ var errCh <- chan error
295
+ if networkInfoDir != "" {
296
+ errCh , err = setStatsCallback (ctx , conn , logger , networkInfoDir , networkInfoInterval )
297
+ if err != nil {
298
+ return err
299
+ }
300
+ }
301
+
287
302
wg .Add (1 )
288
303
go func () {
289
304
defer wg .Done ()
290
305
watchAndClose (ctx , func () error {
291
306
stack .close (xerrors .New ("watchAndClose" ))
292
307
return nil
293
- }, logger , client , workspace )
308
+ }, logger , client , workspace , errCh )
294
309
}()
295
310
copier .copy (& wg )
296
311
return nil
@@ -312,6 +327,14 @@ func (r *RootCmd) ssh() *serpent.Command {
312
327
return err
313
328
}
314
329
330
+ var errCh <- chan error
331
+ if networkInfoDir != "" {
332
+ errCh , err = setStatsCallback (ctx , conn , logger , networkInfoDir , networkInfoInterval )
333
+ if err != nil {
334
+ return err
335
+ }
336
+ }
337
+
315
338
wg .Add (1 )
316
339
go func () {
317
340
defer wg .Done ()
@@ -324,6 +347,7 @@ func (r *RootCmd) ssh() *serpent.Command {
324
347
logger ,
325
348
client ,
326
349
workspace ,
350
+ errCh ,
327
351
)
328
352
}()
329
353
@@ -540,6 +564,17 @@ func (r *RootCmd) ssh() *serpent.Command {
540
564
Value : serpent .StringOf (& usageApp ),
541
565
Hidden : true ,
542
566
},
567
+ {
568
+ Flag : "network-info-dir" ,
569
+ Description : "Specifies a directory to write network information periodically." ,
570
+ Value : serpent .StringOf (& networkInfoDir ),
571
+ },
572
+ {
573
+ Flag : "network-info-interval" ,
574
+ Description : "Specifies the interval to update network information." ,
575
+ Default : "5s" ,
576
+ Value : serpent .DurationOf (& networkInfoInterval ),
577
+ },
543
578
sshDisableAutostartOption (serpent .BoolOf (& disableAutostart )),
544
579
}
545
580
return cmd
@@ -555,7 +590,7 @@ func (r *RootCmd) ssh() *serpent.Command {
555
590
// will usually not propagate.
556
591
//
557
592
// See: https://github.com/coder/coder/issues/6180
558
- func watchAndClose (ctx context.Context , closer func () error , logger slog.Logger , client * codersdk.Client , workspace codersdk.Workspace ) {
593
+ func watchAndClose (ctx context.Context , closer func () error , logger slog.Logger , client * codersdk.Client , workspace codersdk.Workspace , errCh <- chan error ) {
559
594
// Ensure session is ended on both context cancellation
560
595
// and workspace stop.
561
596
defer func () {
@@ -606,6 +641,9 @@ startWatchLoop:
606
641
logger .Info (ctx , "workspace stopped" )
607
642
return
608
643
}
644
+ case err := <- errCh :
645
+ logger .Error (ctx , "failed to collect network stats" , slog .Error (err ))
646
+ return
609
647
}
610
648
}
611
649
}
@@ -1144,3 +1182,159 @@ func getUsageAppName(usageApp string) codersdk.UsageAppName {
1144
1182
1145
1183
return codersdk .UsageAppNameSSH
1146
1184
}
1185
+
1186
+ func setStatsCallback (
1187
+ ctx context.Context ,
1188
+ agentConn * workspacesdk.AgentConn ,
1189
+ logger slog.Logger ,
1190
+ networkInfoDir string ,
1191
+ networkInfoInterval time.Duration ,
1192
+ ) (<- chan error , error ) {
1193
+ fs , ok := ctx .Value ("fs" ).(afero.Fs )
1194
+ if ! ok {
1195
+ fs = afero .NewOsFs ()
1196
+ }
1197
+ if err := fs .MkdirAll (networkInfoDir , 0o700 ); err != nil {
1198
+ return nil , xerrors .Errorf ("mkdir: %w" , err )
1199
+ }
1200
+
1201
+ // The VS Code extension obtains the PID of the SSH process to
1202
+ // read files to display logs and network info.
1203
+ //
1204
+ // We get the parent PID because it's assumed `ssh` is calling this
1205
+ // command via the ProxyCommand SSH option.
1206
+ pid := os .Getppid ()
1207
+
1208
+ // The VS Code extension obtains the PID of the SSH process to
1209
+ // read the file below which contains network information to display.
1210
+ //
1211
+ // We get the parent PID because it's assumed `ssh` is calling this
1212
+ // command via the ProxyCommand SSH option.
1213
+ networkInfoFilePath := filepath .Join (networkInfoDir , fmt .Sprintf ("%d.json" , pid ))
1214
+
1215
+ var (
1216
+ firstErrTime time.Time
1217
+ errCh = make (chan error , 1 )
1218
+ )
1219
+ cb := func (start , end time.Time , virtual , _ map [netlogtype.Connection ]netlogtype.Counts ) {
1220
+ sendErr := func (tolerate bool , err error ) {
1221
+ logger .Error (ctx , "collect network stats" , slog .Error (err ))
1222
+ // Tolerate up to 1 minute of errors.
1223
+ if tolerate {
1224
+ if firstErrTime .IsZero () {
1225
+ logger .Info (ctx , "tolerating network stats errors for up to 1 minute" )
1226
+ firstErrTime = time .Now ()
1227
+ }
1228
+ if time .Since (firstErrTime ) < time .Minute {
1229
+ return
1230
+ }
1231
+ }
1232
+
1233
+ select {
1234
+ case errCh <- err :
1235
+ default :
1236
+ }
1237
+ }
1238
+
1239
+ stats , err := collectNetworkStats (ctx , agentConn , start , end , virtual )
1240
+ if err != nil {
1241
+ sendErr (true , err )
1242
+ return
1243
+ }
1244
+
1245
+ rawStats , err := json .Marshal (stats )
1246
+ if err != nil {
1247
+ sendErr (false , err )
1248
+ return
1249
+ }
1250
+ err = afero .WriteFile (fs , networkInfoFilePath , rawStats , 0o600 )
1251
+ if err != nil {
1252
+ sendErr (false , err )
1253
+ return
1254
+ }
1255
+
1256
+ firstErrTime = time.Time {}
1257
+ }
1258
+
1259
+ now := time .Now ()
1260
+ cb (now , now .Add (time .Nanosecond ), map [netlogtype.Connection ]netlogtype.Counts {}, map [netlogtype.Connection ]netlogtype.Counts {})
1261
+ agentConn .SetConnStatsCallback (networkInfoInterval , 2048 , cb )
1262
+ return errCh , nil
1263
+ }
1264
+
1265
+ type sshNetworkStats struct {
1266
+ P2P bool `json:"p2p"`
1267
+ Latency float64 `json:"latency"`
1268
+ PreferredDERP string `json:"preferred_derp"`
1269
+ DERPLatency map [string ]float64 `json:"derp_latency"`
1270
+ UploadBytesSec int64 `json:"upload_bytes_sec"`
1271
+ DownloadBytesSec int64 `json:"download_bytes_sec"`
1272
+ }
1273
+
1274
+ func collectNetworkStats (ctx context.Context , agentConn * workspacesdk.AgentConn , start , end time.Time , counts map [netlogtype.Connection ]netlogtype.Counts ) (* sshNetworkStats , error ) {
1275
+ latency , p2p , pingResult , err := agentConn .Ping (ctx )
1276
+ if err != nil {
1277
+ return nil , err
1278
+ }
1279
+ node := agentConn .Node ()
1280
+ derpMap := agentConn .DERPMap ()
1281
+ derpLatency := map [string ]float64 {}
1282
+
1283
+ // Convert DERP region IDs to friendly names for display in the UI.
1284
+ for rawRegion , latency := range node .DERPLatency {
1285
+ regionParts := strings .SplitN (rawRegion , "-" , 2 )
1286
+ regionID , err := strconv .Atoi (regionParts [0 ])
1287
+ if err != nil {
1288
+ continue
1289
+ }
1290
+ region , found := derpMap .Regions [regionID ]
1291
+ if ! found {
1292
+ // It's possible that a workspace agent is using an old DERPMap
1293
+ // and reports regions that do not exist. If that's the case,
1294
+ // report the region as unknown!
1295
+ region = & tailcfg.DERPRegion {
1296
+ RegionID : regionID ,
1297
+ RegionName : fmt .Sprintf ("Unnamed %d" , regionID ),
1298
+ }
1299
+ }
1300
+ // Convert the microseconds to milliseconds.
1301
+ derpLatency [region .RegionName ] = latency * 1000
1302
+ }
1303
+
1304
+ totalRx := uint64 (0 )
1305
+ totalTx := uint64 (0 )
1306
+ for _ , stat := range counts {
1307
+ totalRx += stat .RxBytes
1308
+ totalTx += stat .TxBytes
1309
+ }
1310
+ // Tracking the time since last request is required because
1311
+ // ExtractTrafficStats() resets its counters after each call.
1312
+ dur := end .Sub (start )
1313
+ uploadSecs := float64 (totalTx ) / dur .Seconds ()
1314
+ downloadSecs := float64 (totalRx ) / dur .Seconds ()
1315
+
1316
+ // Sometimes the preferred DERP doesn't match the one we're actually
1317
+ // connected with. Perhaps because the agent prefers a different DERP and
1318
+ // we're using that server instead.
1319
+ preferredDerpID := node .PreferredDERP
1320
+ if pingResult .DERPRegionID != 0 {
1321
+ preferredDerpID = pingResult .DERPRegionID
1322
+ }
1323
+ preferredDerp , ok := derpMap .Regions [preferredDerpID ]
1324
+ preferredDerpName := fmt .Sprintf ("Unnamed %d" , preferredDerpID )
1325
+ if ok {
1326
+ preferredDerpName = preferredDerp .RegionName
1327
+ }
1328
+ if _ , ok := derpLatency [preferredDerpName ]; ! ok {
1329
+ derpLatency [preferredDerpName ] = 0
1330
+ }
1331
+
1332
+ return & sshNetworkStats {
1333
+ P2P : p2p ,
1334
+ Latency : float64 (latency .Microseconds ()) / 1000 ,
1335
+ PreferredDERP : preferredDerpName ,
1336
+ DERPLatency : derpLatency ,
1337
+ UploadBytesSec : int64 (uploadSecs ),
1338
+ DownloadBytesSec : int64 (downloadSecs ),
1339
+ }, nil
1340
+ }
0 commit comments