From 03b09175eb90a810f3c3bfe9316c77fed5514c6b Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Tue, 21 Feb 2023 12:48:59 +0000 Subject: [PATCH 1/6] fix(tailnet): Improve start and close to detect connection races --- tailnet/conn.go | 62 ++++++++++++++++++++++++++++++++++++------------- 1 file changed, 46 insertions(+), 16 deletions(-) diff --git a/tailnet/conn.go b/tailnet/conn.go index fcdce9a72f930..57f64b75010be 100644 --- a/tailnet/conn.go +++ b/tailnet/conn.go @@ -60,7 +60,7 @@ type Options struct { } // NewConn constructs a new Wireguard server that will accept connections from the addresses provided. -func NewConn(options *Options) (*Conn, error) { +func NewConn(options *Options) (conn *Conn, err error) { if options == nil { options = &Options{} } @@ -123,6 +123,11 @@ func NewConn(options *Options) (*Conn, error) { if err != nil { return nil, xerrors.Errorf("create wireguard link monitor: %w", err) } + defer func() { + if err != nil { + wireguardMonitor.Close() + } + }() dialer := &tsdial.Dialer{ Logf: Logger(options.Logger), @@ -134,6 +139,11 @@ func NewConn(options *Options) (*Conn, error) { if err != nil { return nil, xerrors.Errorf("create wgengine: %w", err) } + defer func() { + if err != nil { + wireguardEngine.Close() + } + }() dialer.UseNetstackForIP = func(ip netip.Addr) bool { _, ok := wireguardEngine.PeerForIP(ip) return ok @@ -166,10 +176,6 @@ func NewConn(options *Options) (*Conn, error) { return netStack.DialContextTCP(ctx, dst) } netStack.ProcessLocalIPs = true - err = netStack.Start(nil) - if err != nil { - return nil, xerrors.Errorf("start netstack: %w", err) - } wireguardEngine = wgengine.NewWatchdog(wireguardEngine) wireguardEngine.SetDERPMap(options.DERPMap) netMapCopy := *netMap @@ -203,6 +209,11 @@ func NewConn(options *Options) (*Conn, error) { }, wireguardEngine: wireguardEngine, } + defer func() { + if err != nil { + _ = server.Close() + } + }() wireguardEngine.SetStatusCallback(func(s *wgengine.Status, err error) { server.logger.Debug(context.Background(), "wireguard status", slog.F("status", s), slog.F("err", err)) if err != nil { @@ -236,6 +247,12 @@ func NewConn(options *Options) (*Conn, error) { server.sendNode() }) netStack.ForwardTCPIn = server.forwardTCP + + err = netStack.Start(nil) + if err != nil { + return nil, xerrors.Errorf("start netstack: %w", err) + } + return server, nil } @@ -519,22 +536,35 @@ func (c *Conn) Close() error { default: } close(c.closed) - for _, l := range c.listeners { - _ = l.closeNoLock() - } c.mutex.Unlock() - c.dialCancel() - _ = c.dialer.Close() - _ = c.magicConn.Close() + + var wg sync.WaitGroup + defer wg.Wait() + + if c.trafficStats != nil { + wg.Add(1) + go func() { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = c.trafficStats.Shutdown(ctx) + }() + } + _ = c.netStack.Close() + c.dialCancel() _ = c.wireguardMonitor.Close() - _ = c.tunDevice.Close() + _ = c.dialer.Close() + // Stops internals, e.g. tunDevice, magicConn and dnsManager. c.wireguardEngine.Close() - if c.trafficStats != nil { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - _ = c.trafficStats.Shutdown(ctx) + + c.mutex.Lock() + for _, l := range c.listeners { + _ = l.closeNoLock() } + c.listeners = nil + c.mutex.Unlock() + return nil } From 18aabcaf8b5cf22a75e2fa7e5d1a6d4c21773f70 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Tue, 21 Feb 2023 12:50:32 +0000 Subject: [PATCH 2/6] fix: Prevent agentConn use before ready via AwaitReachable --- coderd/coderd_test.go | 6 ++++++ coderd/workspaceagents_test.go | 2 ++ coderd/wsconncache/wsconncache_test.go | 11 +++++++++++ codersdk/workspaceagentconn.go | 21 ++++++++++++++++++--- 4 files changed, 37 insertions(+), 3 deletions(-) diff --git a/coderd/coderd_test.go b/coderd/coderd_test.go index e62240aeda56e..4771914e991ac 100644 --- a/coderd/coderd_test.go +++ b/coderd/coderd_test.go @@ -78,8 +78,13 @@ func TestDERP(t *testing.T) { DERPMap: derpMap, }) require.NoError(t, err) + w2Ready := make(chan struct{}, 1) w1.SetNodeCallback(func(node *tailnet.Node) { w2.UpdateNodes([]*tailnet.Node{node}) + select { + case w2Ready <- struct{}{}: + default: + } }) w2.SetNodeCallback(func(node *tailnet.Node) { w1.UpdateNodes([]*tailnet.Node{node}) @@ -98,6 +103,7 @@ func TestDERP(t *testing.T) { }() <-conn + <-w2Ready nc, err := w2.DialContextTCP(context.Background(), netip.AddrPortFrom(w1IP, 35565)) require.NoError(t, err) _ = nc.Close() diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index f367307e22e59..b4ac803e08c19 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -469,6 +469,8 @@ func TestWorkspaceAgentListeningPorts(t *testing.T) { t.Parallel() setup := func(t *testing.T, apps []*proto.App) (*codersdk.Client, uint16, uuid.UUID) { + t.Helper() + client := coderdtest.New(t, &coderdtest.Options{ IncludeProvisionerDaemon: true, }) diff --git a/coderd/wsconncache/wsconncache_test.go b/coderd/wsconncache/wsconncache_test.go index cfe432d56f6ec..fd1b25a836608 100644 --- a/coderd/wsconncache/wsconncache_test.go +++ b/coderd/wsconncache/wsconncache_test.go @@ -29,6 +29,7 @@ import ( "github.com/coder/coder/codersdk/agentsdk" "github.com/coder/coder/tailnet" "github.com/coder/coder/tailnet/tailnettest" + "github.com/coder/coder/testutil" ) func TestMain(m *testing.M) { @@ -131,6 +132,14 @@ func TestCache(t *testing.T) { return } defer release() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() + if !conn.AwaitReachable(ctx) { + t.Error("agent not reachable") + return + } + transport := conn.HTTPTransport() defer transport.CloseIdleConnections() proxy.Transport = transport @@ -146,6 +155,8 @@ func TestCache(t *testing.T) { } func setupAgent(t *testing.T, metadata agentsdk.Metadata, ptyTimeout time.Duration) *codersdk.WorkspaceAgentConn { + t.Helper() + metadata.DERPMap = tailnettest.RunDERPAndSTUN(t) coordinator := tailnet.NewCoordinator() diff --git a/codersdk/workspaceagentconn.go b/codersdk/workspaceagentconn.go index a9898592d4bb1..7fb86c1d3a098 100644 --- a/codersdk/workspaceagentconn.go +++ b/codersdk/workspaceagentconn.go @@ -176,7 +176,9 @@ type ReconnectingPTYRequest struct { func (c *WorkspaceAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, width uint16, command string) (net.Conn, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() - + if !c.AwaitReachable(ctx) { + return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err()) + } conn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentReconnectingPTYPort)) if err != nil { return nil, err @@ -207,6 +209,9 @@ func (c *WorkspaceAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, func (c *WorkspaceAgentConn) SSH(ctx context.Context) (net.Conn, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() + if !c.AwaitReachable(ctx) { + return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err()) + } return c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentSSHPort)) } @@ -235,6 +240,9 @@ func (c *WorkspaceAgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) func (c *WorkspaceAgentConn) Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() + if !c.AwaitReachable(ctx) { + return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err()) + } speedConn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentSpeedtestPort)) if err != nil { return nil, xerrors.Errorf("dial speedtest: %w", err) @@ -257,6 +265,9 @@ func (c *WorkspaceAgentConn) DialContext(ctx context.Context, network string, ad _, rawPort, _ := net.SplitHostPort(addr) port, _ := strconv.ParseUint(rawPort, 10, 16) ipp := netip.AddrPortFrom(WorkspaceAgentIP, uint16(port)) + if !c.AwaitReachable(ctx) { + return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err()) + } if network == "udp" { return c.Conn.DialContextUDP(ctx, ipp) } @@ -317,7 +328,7 @@ func (c *WorkspaceAgentConn) apiClient() *http.Client { // Disable keep alives as we're usually only making a single // request, and this triggers goleak in tests DisableKeepAlives: true, - DialContext: func(_ context.Context, network, addr string) (net.Conn, error) { + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { if network != "tcp" { return nil, xerrors.Errorf("network must be tcp") } @@ -331,7 +342,11 @@ func (c *WorkspaceAgentConn) apiClient() *http.Client { return nil, xerrors.Errorf("request %q does not appear to be for http api", addr) } - conn, err := c.DialContextTCP(context.Background(), netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentHTTPAPIServerPort)) + if !c.AwaitReachable(ctx) { + return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err()) + } + + conn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentHTTPAPIServerPort)) if err != nil { return nil, xerrors.Errorf("dial http api: %w", err) } From 1a905998c94291b58399ed9b5cda76f7fa2b0fc6 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Tue, 21 Feb 2023 13:55:34 +0000 Subject: [PATCH 3/6] fix(tailnet): Ensure connstats are closed on conn close --- tailnet/conn.go | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/tailnet/conn.go b/tailnet/conn.go index 57f64b75010be..b65b6d630bde5 100644 --- a/tailnet/conn.go +++ b/tailnet/conn.go @@ -744,16 +744,25 @@ func (c *Conn) forwardTCPToLocal(conn net.Conn, port uint16) { func (c *Conn) SetConnStatsCallback(maxPeriod time.Duration, maxConns int, dump func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts)) { connStats := connstats.NewStatistics(maxPeriod, maxConns, dump) + shutdown := func(s *connstats.Statistics) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = s.Shutdown(ctx) + } + c.mutex.Lock() + if c.isClosed() { + c.mutex.Unlock() + shutdown(connStats) + return + } old := c.trafficStats c.trafficStats = connStats c.mutex.Unlock() // Make sure to shutdown the old callback. if old != nil { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - _ = old.Shutdown(ctx) + shutdown(old) } c.tunDevice.SetStatistics(connStats) From 19023e75b3263f5f2da7e5220efa45d1c4a727c0 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Tue, 21 Feb 2023 18:27:50 +0000 Subject: [PATCH 4/6] fix(codersdk): Use AwaitReachable in DialWorkspaceAgent --- codersdk/workspaceagents.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index 9fbb9eb9200c6..b3940e154abc5 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -199,13 +199,19 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti return nil, err } - return &WorkspaceAgentConn{ + agentConn := &WorkspaceAgentConn{ Conn: conn, CloseFunc: func() { cancelFunc() <-closed }, - }, nil + } + if !agentConn.AwaitReachable(ctx) { + _ = agentConn.Close() + return nil, xerrors.Errorf("timed out waiting for agent to become reachable: %w", ctx.Err()) + } + + return agentConn, nil } // WorkspaceAgent returns an agent by ID. From cdf01f8a5f4cfcf6bbcb6a75404c12c3194e0a67 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Wed, 22 Feb 2023 13:29:11 +0000 Subject: [PATCH 5/6] Improve logging via slog.Helper() --- tailnet/conn.go | 1 + 1 file changed, 1 insertion(+) diff --git a/tailnet/conn.go b/tailnet/conn.go index b65b6d630bde5..9846d4096cde9 100644 --- a/tailnet/conn.go +++ b/tailnet/conn.go @@ -815,6 +815,7 @@ func (a addr) String() string { return a.ln.addr } // Logger converts the Tailscale logging function to use slog. func Logger(logger slog.Logger) tslogger.Logf { return tslogger.Logf(func(format string, args ...any) { + slog.Helper() logger.Debug(context.Background(), fmt.Sprintf(format, args...)) }) } From ac9dd2b009847d74d50d961c98cc99991e2985fc Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Thu, 23 Feb 2023 08:34:16 +0000 Subject: [PATCH 6/6] Use sync.Once in test --- coderd/coderd_test.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/coderd/coderd_test.go b/coderd/coderd_test.go index 4771914e991ac..f9501d384a390 100644 --- a/coderd/coderd_test.go +++ b/coderd/coderd_test.go @@ -6,6 +6,7 @@ import ( "net/http" "net/netip" "strconv" + "sync" "testing" "github.com/stretchr/testify/assert" @@ -78,13 +79,14 @@ func TestDERP(t *testing.T) { DERPMap: derpMap, }) require.NoError(t, err) + w2Ready := make(chan struct{}, 1) + w2ReadyOnce := sync.Once{} w1.SetNodeCallback(func(node *tailnet.Node) { w2.UpdateNodes([]*tailnet.Node{node}) - select { - case w2Ready <- struct{}{}: - default: - } + w2ReadyOnce.Do(func() { + close(w2Ready) + }) }) w2.SetNodeCallback(func(node *tailnet.Node) { w1.UpdateNodes([]*tailnet.Node{node})