Thanks to visit codestin.com
Credit goes to github.com

Skip to content

chore: update tailscale #6091

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Feb 10, 2023
30 changes: 20 additions & 10 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ type Options struct {
type Client interface {
Metadata(ctx context.Context) (agentsdk.Metadata, error)
Listen(ctx context.Context) (net.Conn, error)
ReportStats(ctx context.Context, log slog.Logger, stats func() *agentsdk.Stats) (io.Closer, error)
ReportStats(ctx context.Context, log slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error)
PostLifecycle(ctx context.Context, state agentsdk.PostLifecycleRequest) error
PostAppHealth(ctx context.Context, req agentsdk.PostAppHealthsRequest) error
PostStartup(ctx context.Context, req agentsdk.PostStartupRequest) error
Expand Down Expand Up @@ -112,6 +112,7 @@ func New(options Options) io.Closer {
logDir: options.LogDir,
tempDir: options.TempDir,
lifecycleUpdate: make(chan struct{}, 1),
connStatsChan: make(chan *agentsdk.Stats, 1),
}
a.init(ctx)
return a
Expand Down Expand Up @@ -143,7 +144,8 @@ type agent struct {
lifecycleMu sync.Mutex // Protects following.
lifecycleState codersdk.WorkspaceAgentLifecycle

network *tailnet.Conn
network *tailnet.Conn
connStatsChan chan *agentsdk.Stats
}

// runLoop attempts to start the agent in a retry loop.
Expand Down Expand Up @@ -351,11 +353,20 @@ func (a *agent) run(ctx context.Context) error {
return xerrors.New("agent is closed")
}

setStatInterval := func(d time.Duration) {
network.SetConnStatsCallback(d, 2048,
func(_, _ time.Time, virtual, _ map[netlogtype.Connection]netlogtype.Counts) {
select {
case a.connStatsChan <- convertAgentStats(virtual):
default:
a.logger.Warn(ctx, "network stat dropped")
}
},
)
}

// Report statistics from the created network.
cl, err := a.client.ReportStats(ctx, a.logger, func() *agentsdk.Stats {
stats := network.ExtractTrafficStats()
return convertAgentStats(stats)
})
cl, err := a.client.ReportStats(ctx, a.logger, a.connStatsChan, setStatInterval)
if err != nil {
a.logger.Error(ctx, "report stats", slog.Error(err))
} else {
Expand Down Expand Up @@ -399,10 +410,9 @@ func (a *agent) trackConnGoroutine(fn func()) error {

func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_ *tailnet.Conn, err error) {
network, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)},
DERPMap: derpMap,
Logger: a.logger.Named("tailnet"),
EnableTrafficStats: true,
Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)},
DERPMap: derpMap,
Logger: a.logger.Named("tailnet"),
})
if err != nil {
return nil, xerrors.Errorf("create tailnet: %w", err)
Expand Down
41 changes: 20 additions & 21 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@ import (
"testing"
"time"

"golang.org/x/xerrors"
"tailscale.com/net/speedtest"
"tailscale.com/tailcfg"

scp "github.com/bramvdbogaerde/go-scp"
"github.com/google/uuid"
"github.com/pion/udp"
Expand All @@ -37,6 +33,9 @@ import (
"golang.org/x/crypto/ssh"
"golang.org/x/text/encoding/unicode"
"golang.org/x/text/transform"
"golang.org/x/xerrors"
"tailscale.com/net/speedtest"
"tailscale.com/tailcfg"

"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
Expand All @@ -53,6 +52,8 @@ func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}

// NOTE: These tests only work when your default shell is bash for some reason.

func TestAgent_Stats_SSH(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
Expand Down Expand Up @@ -1153,17 +1154,16 @@ func setupAgent(t *testing.T, metadata agentsdk.Metadata, ptyTimeout time.Durati
closer := agent.New(agent.Options{
Client: c,
Filesystem: fs,
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
ReconnectingPTYTimeout: ptyTimeout,
})
t.Cleanup(func() {
_ = closer.Close()
})
conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
DERPMap: metadata.DERPMap,
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
EnableTrafficStats: true,
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
DERPMap: metadata.DERPMap,
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
})
require.NoError(t, err)
clientConn, serverConn := net.Pipe()
Expand Down Expand Up @@ -1251,28 +1251,27 @@ func (c *client) Listen(_ context.Context) (net.Conn, error) {
return clientConn, nil
}

func (c *client) ReportStats(ctx context.Context, _ slog.Logger, stats func() *agentsdk.Stats) (io.Closer, error) {
func (c *client) ReportStats(ctx context.Context, _ slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error) {
doneCh := make(chan struct{})
ctx, cancel := context.WithCancel(ctx)

go func() {
defer close(doneCh)

t := time.NewTicker(500 * time.Millisecond)
defer t.Stop()
setInterval(500 * time.Millisecond)
for {
select {
case <-ctx.Done():
return
case <-t.C:
}
select {
case c.statsChan <- stats():
case <-ctx.Done():
return
default:
// We don't want to send old stats.
continue
case stat := <-statsChan:
select {
case c.statsChan <- stat:
case <-ctx.Done():
return
default:
// We don't want to send old stats.
continue
}
}
}
}()
Expand Down
59 changes: 39 additions & 20 deletions cli/vscodessh.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/spf13/cobra"
"golang.org/x/xerrors"
"tailscale.com/tailcfg"
"tailscale.com/types/netlogtype"

"github.com/coder/coder/codersdk"
)
Expand Down Expand Up @@ -92,6 +93,7 @@ func vscodeSSH() *cobra.Command {
if err != nil {
return xerrors.Errorf("find workspace: %w", err)
}

var agent codersdk.WorkspaceAgent
var found bool
for _, resource := range workspace.LatestBuild.Resources {
Expand All @@ -117,61 +119,78 @@ func vscodeSSH() *cobra.Command {
break
}
}
agentConn, err := client.DialWorkspaceAgent(ctx, agent.ID, &codersdk.DialWorkspaceAgentOptions{
EnableTrafficStats: true,
})

agentConn, err := client.DialWorkspaceAgent(ctx, agent.ID, &codersdk.DialWorkspaceAgentOptions{})
if err != nil {
return xerrors.Errorf("dial workspace agent: %w", err)
}
defer agentConn.Close()

agentConn.AwaitReachable(ctx)
rawSSH, err := agentConn.SSH(ctx)
if err != nil {
return err
}
defer rawSSH.Close()

// Copy SSH traffic over stdio.
go func() {
_, _ = io.Copy(cmd.OutOrStdout(), rawSSH)
}()
go func() {
_, _ = io.Copy(rawSSH, cmd.InOrStdin())
}()

// The VS Code extension obtains the PID of the SSH process to
// read the file below which contains network information to display.
//
// We get the parent PID because it's assumed `ssh` is calling this
// command via the ProxyCommand SSH option.
networkInfoFilePath := filepath.Join(networkInfoDir, fmt.Sprintf("%d.json", os.Getppid()))
ticker := time.NewTicker(networkInfoInterval)
defer ticker.Stop()
lastCollected := time.Now()
for {
select {
case <-ctx.Done():
return nil
case <-ticker.C:

statsErrChan := make(chan error, 1)
cb := func(start, end time.Time, virtual, _ map[netlogtype.Connection]netlogtype.Counts) {
sendErr := func(err error) {
select {
case statsErrChan <- err:
default:
}
}
stats, err := collectNetworkStats(ctx, agentConn, lastCollected)

stats, err := collectNetworkStats(ctx, agentConn, start, end, virtual)
if err != nil {
return err
sendErr(err)
return
}

rawStats, err := json.Marshal(stats)
if err != nil {
return err
sendErr(err)
return
}
err = afero.WriteFile(fs, networkInfoFilePath, rawStats, 0600)
if err != nil {
return err
sendErr(err)
return
}
lastCollected = time.Now()
}

now := time.Now()
cb(now, now.Add(time.Nanosecond), map[netlogtype.Connection]netlogtype.Counts{}, map[netlogtype.Connection]netlogtype.Counts{})
agentConn.SetConnStatsCallback(networkInfoInterval, 2048, cb)

select {
case <-ctx.Done():
return nil
case err := <-statsErrChan:
return err
}
},
}
cmd.Flags().StringVarP(&networkInfoDir, "network-info-dir", "", "", "Specifies a directory to write network information periodically.")
cmd.Flags().StringVarP(&sessionTokenFile, "session-token-file", "", "", "Specifies a file that contains a session token.")
cmd.Flags().StringVarP(&urlFile, "url-file", "", "", "Specifies a file that contains the Coder URL.")
cmd.Flags().DurationVarP(&networkInfoInterval, "network-info-interval", "", 3*time.Second, "Specifies the interval to update network information.")
cmd.Flags().DurationVarP(&networkInfoInterval, "network-info-interval", "", 5*time.Second, "Specifies the interval to update network information.")
return cmd
}

Expand All @@ -184,7 +203,7 @@ type sshNetworkStats struct {
DownloadBytesSec int64 `json:"download_bytes_sec"`
}

func collectNetworkStats(ctx context.Context, agentConn *codersdk.WorkspaceAgentConn, lastCollected time.Time) (*sshNetworkStats, error) {
func collectNetworkStats(ctx context.Context, agentConn *codersdk.WorkspaceAgentConn, start, end time.Time, counts map[netlogtype.Connection]netlogtype.Counts) (*sshNetworkStats, error) {
latency, p2p, err := agentConn.Ping(ctx)
if err != nil {
return nil, err
Expand Down Expand Up @@ -216,13 +235,13 @@ func collectNetworkStats(ctx context.Context, agentConn *codersdk.WorkspaceAgent

totalRx := uint64(0)
totalTx := uint64(0)
for _, stat := range agentConn.ExtractTrafficStats() {
for _, stat := range counts {
totalRx += stat.RxBytes
totalTx += stat.TxBytes
}
// Tracking the time since last request is required because
// ExtractTrafficStats() resets its counters after each call.
dur := time.Since(lastCollected)
dur := end.Sub(start)
uploadSecs := float64(totalTx) / dur.Seconds()
downloadSecs := float64(totalRx) / dur.Seconds()

Expand Down
2 changes: 1 addition & 1 deletion coderd/wsconncache/wsconncache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ func (c *client) Listen(_ context.Context) (net.Conn, error) {
return clientConn, nil
}

func (*client) ReportStats(_ context.Context, _ slog.Logger, _ func() *agentsdk.Stats) (io.Closer, error) {
func (*client) ReportStats(_ context.Context, _ slog.Logger, _ <-chan *agentsdk.Stats, _ func(time.Duration)) (io.Closer, error) {
return io.NopCloser(strings.NewReader("")), nil
}

Expand Down
55 changes: 33 additions & 22 deletions codersdk/agentsdk/agentsdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -368,44 +368,55 @@ func (c *Client) AuthAzureInstanceIdentity(ctx context.Context) (AuthenticateRes

// ReportStats begins a stat streaming connection with the Coder server.
// It is resilient to network failures and intermittent coderd issues.
func (c *Client) ReportStats(
ctx context.Context,
log slog.Logger,
getStats func() *Stats,
) (io.Closer, error) {
func (c *Client) ReportStats(ctx context.Context, log slog.Logger, statsChan <-chan *Stats, setInterval func(time.Duration)) (io.Closer, error) {
var interval time.Duration
ctx, cancel := context.WithCancel(ctx)
exited := make(chan struct{})

postStat := func(stat *Stats) {
var nextInterval time.Duration
for r := retry.New(100*time.Millisecond, time.Minute); r.Wait(ctx); {
resp, err := c.PostStats(ctx, stat)
if err != nil {
if !xerrors.Is(err, context.Canceled) {
log.Error(ctx, "report stats", slog.Error(err))
}
continue
}

nextInterval = resp.ReportInterval
break
}

if nextInterval != 0 && interval != nextInterval {
setInterval(nextInterval)
}
interval = nextInterval
}

// Send an empty stat to get the interval.
postStat(&Stats{ConnsByProto: map[string]int64{}})

go func() {
// Immediately trigger a stats push to get the correct interval.
timer := time.NewTimer(time.Nanosecond)
defer timer.Stop()
defer close(exited)

for {
select {
case <-ctx.Done():
return
case <-timer.C:
}

var nextInterval time.Duration
for r := retry.New(100*time.Millisecond, time.Minute); r.Wait(ctx); {
resp, err := c.PostStats(ctx, getStats())
if err != nil {
if !xerrors.Is(err, context.Canceled) {
log.Error(ctx, "report stats", slog.Error(err))
}
continue
case stat, ok := <-statsChan:
if !ok {
return
}

nextInterval = resp.ReportInterval
break
postStat(stat)
}
timer.Reset(nextInterval)
}
}()

return closeFunc(func() error {
cancel()
<-exited
return nil
}), nil
}
Expand Down
Loading