Thanks to visit codestin.com
Credit goes to github.com

Skip to content

chore: update tailscale #6091

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Feb 10, 2023
30 changes: 20 additions & 10 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ type Options struct {
type Client interface {
Metadata(ctx context.Context) (agentsdk.Metadata, error)
Listen(ctx context.Context) (net.Conn, error)
ReportStats(ctx context.Context, log slog.Logger, stats func() *agentsdk.Stats) (io.Closer, error)
ReportStats(ctx context.Context, log slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error)
PostLifecycle(ctx context.Context, state agentsdk.PostLifecycleRequest) error
PostAppHealth(ctx context.Context, req agentsdk.PostAppHealthsRequest) error
PostStartup(ctx context.Context, req agentsdk.PostStartupRequest) error
Expand Down Expand Up @@ -112,6 +112,7 @@ func New(options Options) io.Closer {
logDir: options.LogDir,
tempDir: options.TempDir,
lifecycleUpdate: make(chan struct{}, 1),
connStatsChan: make(chan *agentsdk.Stats, 1),
}
a.init(ctx)
return a
Expand Down Expand Up @@ -143,7 +144,8 @@ type agent struct {
lifecycleMu sync.Mutex // Protects following.
lifecycleState codersdk.WorkspaceAgentLifecycle

network *tailnet.Conn
network *tailnet.Conn
connStatsChan chan *agentsdk.Stats
}

// runLoop attempts to start the agent in a retry loop.
Expand Down Expand Up @@ -351,11 +353,20 @@ func (a *agent) run(ctx context.Context) error {
return xerrors.New("agent is closed")
}

setStatInterval := func(d time.Duration) {
network.SetConnStatsCallback(d, 2048,
func(_, _ time.Time, virtual, _ map[netlogtype.Connection]netlogtype.Counts) {
select {
case a.connStatsChan <- convertAgentStats(virtual):
default:
a.logger.Warn(ctx, "network stat dropped")
}
},
)
}

// Report statistics from the created network.
cl, err := a.client.ReportStats(ctx, a.logger, func() *agentsdk.Stats {
stats := network.ExtractTrafficStats()
return convertAgentStats(stats)
})
cl, err := a.client.ReportStats(ctx, a.logger, a.connStatsChan, setStatInterval)
if err != nil {
a.logger.Error(ctx, "report stats", slog.Error(err))
} else {
Expand Down Expand Up @@ -399,10 +410,9 @@ func (a *agent) trackConnGoroutine(fn func()) error {

func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_ *tailnet.Conn, err error) {
network, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)},
DERPMap: derpMap,
Logger: a.logger.Named("tailnet"),
EnableTrafficStats: true,
Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)},
DERPMap: derpMap,
Logger: a.logger.Named("tailnet"),
})
if err != nil {
return nil, xerrors.Errorf("create tailnet: %w", err)
Expand Down
41 changes: 20 additions & 21 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@ import (
"testing"
"time"

"golang.org/x/xerrors"
"tailscale.com/net/speedtest"
"tailscale.com/tailcfg"

scp "github.com/bramvdbogaerde/go-scp"
"github.com/google/uuid"
"github.com/pion/udp"
Expand All @@ -37,6 +33,9 @@ import (
"golang.org/x/crypto/ssh"
"golang.org/x/text/encoding/unicode"
"golang.org/x/text/transform"
"golang.org/x/xerrors"
"tailscale.com/net/speedtest"
"tailscale.com/tailcfg"

"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
Expand All @@ -53,6 +52,8 @@ func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}

// NOTE: These tests only work when your default shell is bash for some reason.

func TestAgent_Stats_SSH(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
Expand Down Expand Up @@ -1153,17 +1154,16 @@ func setupAgent(t *testing.T, metadata agentsdk.Metadata, ptyTimeout time.Durati
closer := agent.New(agent.Options{
Client: c,
Filesystem: fs,
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
ReconnectingPTYTimeout: ptyTimeout,
})
t.Cleanup(func() {
_ = closer.Close()
})
conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
DERPMap: metadata.DERPMap,
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
EnableTrafficStats: true,
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
DERPMap: metadata.DERPMap,
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
})
require.NoError(t, err)
clientConn, serverConn := net.Pipe()
Expand Down Expand Up @@ -1251,28 +1251,27 @@ func (c *client) Listen(_ context.Context) (net.Conn, error) {
return clientConn, nil
}

func (c *client) ReportStats(ctx context.Context, _ slog.Logger, stats func() *agentsdk.Stats) (io.Closer, error) {
func (c *client) ReportStats(ctx context.Context, _ slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error) {
doneCh := make(chan struct{})
ctx, cancel := context.WithCancel(ctx)

go func() {
defer close(doneCh)

t := time.NewTicker(500 * time.Millisecond)
defer t.Stop()
setInterval(500 * time.Millisecond)
for {
select {
case <-ctx.Done():
return
case <-t.C:
}
select {
case c.statsChan <- stats():
case <-ctx.Done():
return
default:
// We don't want to send old stats.
continue
case stat := <-statsChan:
select {
case c.statsChan <- stat:
case <-ctx.Done():
return
default:
// We don't want to send old stats.
continue
}
}
}
}()
Expand Down
50 changes: 29 additions & 21 deletions cli/vscodessh.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/spf13/cobra"
"golang.org/x/xerrors"
"tailscale.com/tailcfg"
"tailscale.com/types/netlogtype"

"github.com/coder/coder/codersdk"
)
Expand Down Expand Up @@ -92,6 +93,7 @@ func vscodeSSH() *cobra.Command {
if err != nil {
return xerrors.Errorf("find workspace: %w", err)
}

var agent codersdk.WorkspaceAgent
var found bool
for _, resource := range workspace.LatestBuild.Resources {
Expand All @@ -117,61 +119,67 @@ func vscodeSSH() *cobra.Command {
break
}
}
agentConn, err := client.DialWorkspaceAgent(ctx, agent.ID, &codersdk.DialWorkspaceAgentOptions{
EnableTrafficStats: true,
})

agentConn, err := client.DialWorkspaceAgent(ctx, agent.ID, &codersdk.DialWorkspaceAgentOptions{})
if err != nil {
return xerrors.Errorf("dial workspace agent: %w", err)
}
defer agentConn.Close()

agentConn.AwaitReachable(ctx)
rawSSH, err := agentConn.SSH(ctx)
if err != nil {
return err
}
defer rawSSH.Close()

// Copy SSH traffic over stdio.
go func() {
_, _ = io.Copy(cmd.OutOrStdout(), rawSSH)
}()
go func() {
_, _ = io.Copy(rawSSH, cmd.InOrStdin())
}()

// The VS Code extension obtains the PID of the SSH process to
// read the file below which contains network information to display.
//
// We get the parent PID because it's assumed `ssh` is calling this
// command via the ProxyCommand SSH option.
networkInfoFilePath := filepath.Join(networkInfoDir, fmt.Sprintf("%d.json", os.Getppid()))
ticker := time.NewTicker(networkInfoInterval)
defer ticker.Stop()
lastCollected := time.Now()
for {
select {
case <-ctx.Done():
return nil
case <-ticker.C:
}
stats, err := collectNetworkStats(ctx, agentConn, lastCollected)

statsErrChan := make(chan error, 1)
agentConn.SetConnStatsCallback(networkInfoInterval, 2048, func(start, end time.Time, virtual, _ map[netlogtype.Connection]netlogtype.Counts) {
stats, err := collectNetworkStats(ctx, agentConn, start, end, virtual)
if err != nil {
return err
statsErrChan <- err
return
}

rawStats, err := json.Marshal(stats)
if err != nil {
return err
statsErrChan <- err
return
}
err = afero.WriteFile(fs, networkInfoFilePath, rawStats, 0600)
if err != nil {
return err
statsErrChan <- err
return
}
lastCollected = time.Now()
})

select {
case <-ctx.Done():
return nil
case err := <-statsErrChan:
return err
}
},
}
cmd.Flags().StringVarP(&networkInfoDir, "network-info-dir", "", "", "Specifies a directory to write network information periodically.")
cmd.Flags().StringVarP(&sessionTokenFile, "session-token-file", "", "", "Specifies a file that contains a session token.")
cmd.Flags().StringVarP(&urlFile, "url-file", "", "", "Specifies a file that contains the Coder URL.")
cmd.Flags().DurationVarP(&networkInfoInterval, "network-info-interval", "", 3*time.Second, "Specifies the interval to update network information.")
cmd.Flags().DurationVarP(&networkInfoInterval, "network-info-interval", "", 5*time.Second, "Specifies the interval to update network information.")
return cmd
}

Expand All @@ -184,7 +192,7 @@ type sshNetworkStats struct {
DownloadBytesSec int64 `json:"download_bytes_sec"`
}

func collectNetworkStats(ctx context.Context, agentConn *codersdk.WorkspaceAgentConn, lastCollected time.Time) (*sshNetworkStats, error) {
func collectNetworkStats(ctx context.Context, agentConn *codersdk.WorkspaceAgentConn, start, end time.Time, counts map[netlogtype.Connection]netlogtype.Counts) (*sshNetworkStats, error) {
latency, p2p, err := agentConn.Ping(ctx)
if err != nil {
return nil, err
Expand Down Expand Up @@ -216,13 +224,13 @@ func collectNetworkStats(ctx context.Context, agentConn *codersdk.WorkspaceAgent

totalRx := uint64(0)
totalTx := uint64(0)
for _, stat := range agentConn.ExtractTrafficStats() {
for _, stat := range counts {
totalRx += stat.RxBytes
totalTx += stat.TxBytes
}
// Tracking the time since last request is required because
// ExtractTrafficStats() resets its counters after each call.
dur := time.Since(lastCollected)
dur := end.Sub(start)
uploadSecs := float64(totalTx) / dur.Seconds()
downloadSecs := float64(totalRx) / dur.Seconds()

Expand Down
2 changes: 1 addition & 1 deletion coderd/wsconncache/wsconncache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ func (c *client) Listen(_ context.Context) (net.Conn, error) {
return clientConn, nil
}

func (*client) ReportStats(_ context.Context, _ slog.Logger, _ func() *agentsdk.Stats) (io.Closer, error) {
func (*client) ReportStats(_ context.Context, _ slog.Logger, _ <-chan *agentsdk.Stats, _ func(time.Duration)) (io.Closer, error) {
return io.NopCloser(strings.NewReader("")), nil
}

Expand Down
53 changes: 30 additions & 23 deletions codersdk/agentsdk/agentsdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -368,39 +368,46 @@ func (c *Client) AuthAzureInstanceIdentity(ctx context.Context) (AuthenticateRes

// ReportStats begins a stat streaming connection with the Coder server.
// It is resilient to network failures and intermittent coderd issues.
func (c *Client) ReportStats(
ctx context.Context,
log slog.Logger,
getStats func() *Stats,
) (io.Closer, error) {
func (c *Client) ReportStats(ctx context.Context, log slog.Logger, statsChan <-chan *Stats, setInterval func(time.Duration)) (io.Closer, error) {
var interval time.Duration
ctx, cancel := context.WithCancel(ctx)

go func() {
// Immediately trigger a stats push to get the correct interval.
timer := time.NewTimer(time.Nanosecond)
defer timer.Stop()
postStat := func(stat *Stats) {
var nextInterval time.Duration
for r := retry.New(100*time.Millisecond, time.Minute); r.Wait(ctx); {
resp, err := c.PostStats(ctx, stat)
if err != nil {
if !xerrors.Is(err, context.Canceled) {
log.Error(ctx, "report stats", slog.Error(err))
}
continue
}

nextInterval = resp.ReportInterval
break
}

if interval != nextInterval {
setInterval(nextInterval)
}
interval = nextInterval
}

// Send an empty stat to get the interval.
postStat(&Stats{ConnsByProto: map[string]int64{}})

go func() {
for {
select {
case <-ctx.Done():
return
case <-timer.C:
}

var nextInterval time.Duration
for r := retry.New(100*time.Millisecond, time.Minute); r.Wait(ctx); {
resp, err := c.PostStats(ctx, getStats())
if err != nil {
if !xerrors.Is(err, context.Canceled) {
log.Error(ctx, "report stats", slog.Error(err))
}
continue
case stat, ok := <-statsChan:
if !ok {
return
}

nextInterval = resp.ReportInterval
break
postStat(stat)
}
timer.Reset(nextInterval)
}
}()

Expand Down
12 changes: 5 additions & 7 deletions codersdk/workspaceagents.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ type WorkspaceAgentConnectionInfo struct {
type DialWorkspaceAgentOptions struct {
Logger slog.Logger
// BlockEndpoints forced a direct connection through DERP.
BlockEndpoints bool
EnableTrafficStats bool
BlockEndpoints bool
}

func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *DialWorkspaceAgentOptions) (*WorkspaceAgentConn, error) {
Expand All @@ -121,11 +120,10 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti

ip := tailnet.IP()
conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(ip, 128)},
DERPMap: connInfo.DERPMap,
Logger: options.Logger,
BlockEndpoints: options.BlockEndpoints,
EnableTrafficStats: options.EnableTrafficStats,
Addresses: []netip.Prefix{netip.PrefixFrom(ip, 128)},
DERPMap: connInfo.DERPMap,
Logger: options.Logger,
BlockEndpoints: options.BlockEndpoints,
})
if err != nil {
return nil, xerrors.Errorf("create tailnet: %w", err)
Expand Down
Loading