diff --git a/coderd/coderd_test.go b/coderd/coderd_test.go index e62240aeda56e..f9501d384a390 100644 --- a/coderd/coderd_test.go +++ b/coderd/coderd_test.go @@ -6,6 +6,7 @@ import ( "net/http" "net/netip" "strconv" + "sync" "testing" "github.com/stretchr/testify/assert" @@ -78,8 +79,14 @@ func TestDERP(t *testing.T) { DERPMap: derpMap, }) require.NoError(t, err) + + w2Ready := make(chan struct{}, 1) + w2ReadyOnce := sync.Once{} w1.SetNodeCallback(func(node *tailnet.Node) { w2.UpdateNodes([]*tailnet.Node{node}) + w2ReadyOnce.Do(func() { + close(w2Ready) + }) }) w2.SetNodeCallback(func(node *tailnet.Node) { w1.UpdateNodes([]*tailnet.Node{node}) @@ -98,6 +105,7 @@ func TestDERP(t *testing.T) { }() <-conn + <-w2Ready nc, err := w2.DialContextTCP(context.Background(), netip.AddrPortFrom(w1IP, 35565)) require.NoError(t, err) _ = nc.Close() diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index f367307e22e59..b4ac803e08c19 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -469,6 +469,8 @@ func TestWorkspaceAgentListeningPorts(t *testing.T) { t.Parallel() setup := func(t *testing.T, apps []*proto.App) (*codersdk.Client, uint16, uuid.UUID) { + t.Helper() + client := coderdtest.New(t, &coderdtest.Options{ IncludeProvisionerDaemon: true, }) diff --git a/coderd/wsconncache/wsconncache_test.go b/coderd/wsconncache/wsconncache_test.go index cfe432d56f6ec..fd1b25a836608 100644 --- a/coderd/wsconncache/wsconncache_test.go +++ b/coderd/wsconncache/wsconncache_test.go @@ -29,6 +29,7 @@ import ( "github.com/coder/coder/codersdk/agentsdk" "github.com/coder/coder/tailnet" "github.com/coder/coder/tailnet/tailnettest" + "github.com/coder/coder/testutil" ) func TestMain(m *testing.M) { @@ -131,6 +132,14 @@ func TestCache(t *testing.T) { return } defer release() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() + if !conn.AwaitReachable(ctx) { + t.Error("agent not reachable") + return + } + transport := conn.HTTPTransport() defer transport.CloseIdleConnections() proxy.Transport = transport @@ -146,6 +155,8 @@ func TestCache(t *testing.T) { } func setupAgent(t *testing.T, metadata agentsdk.Metadata, ptyTimeout time.Duration) *codersdk.WorkspaceAgentConn { + t.Helper() + metadata.DERPMap = tailnettest.RunDERPAndSTUN(t) coordinator := tailnet.NewCoordinator() diff --git a/codersdk/workspaceagentconn.go b/codersdk/workspaceagentconn.go index a9898592d4bb1..7fb86c1d3a098 100644 --- a/codersdk/workspaceagentconn.go +++ b/codersdk/workspaceagentconn.go @@ -176,7 +176,9 @@ type ReconnectingPTYRequest struct { func (c *WorkspaceAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, width uint16, command string) (net.Conn, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() - + if !c.AwaitReachable(ctx) { + return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err()) + } conn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentReconnectingPTYPort)) if err != nil { return nil, err @@ -207,6 +209,9 @@ func (c *WorkspaceAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, func (c *WorkspaceAgentConn) SSH(ctx context.Context) (net.Conn, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() + if !c.AwaitReachable(ctx) { + return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err()) + } return c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentSSHPort)) } @@ -235,6 +240,9 @@ func (c *WorkspaceAgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) func (c *WorkspaceAgentConn) Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() + if !c.AwaitReachable(ctx) { + return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err()) + } speedConn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentSpeedtestPort)) if err != nil { return nil, xerrors.Errorf("dial speedtest: %w", err) @@ -257,6 +265,9 @@ func (c *WorkspaceAgentConn) DialContext(ctx context.Context, network string, ad _, rawPort, _ := net.SplitHostPort(addr) port, _ := strconv.ParseUint(rawPort, 10, 16) ipp := netip.AddrPortFrom(WorkspaceAgentIP, uint16(port)) + if !c.AwaitReachable(ctx) { + return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err()) + } if network == "udp" { return c.Conn.DialContextUDP(ctx, ipp) } @@ -317,7 +328,7 @@ func (c *WorkspaceAgentConn) apiClient() *http.Client { // Disable keep alives as we're usually only making a single // request, and this triggers goleak in tests DisableKeepAlives: true, - DialContext: func(_ context.Context, network, addr string) (net.Conn, error) { + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { if network != "tcp" { return nil, xerrors.Errorf("network must be tcp") } @@ -331,7 +342,11 @@ func (c *WorkspaceAgentConn) apiClient() *http.Client { return nil, xerrors.Errorf("request %q does not appear to be for http api", addr) } - conn, err := c.DialContextTCP(context.Background(), netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentHTTPAPIServerPort)) + if !c.AwaitReachable(ctx) { + return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err()) + } + + conn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentHTTPAPIServerPort)) if err != nil { return nil, xerrors.Errorf("dial http api: %w", err) } diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index 9fbb9eb9200c6..b3940e154abc5 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -199,13 +199,19 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti return nil, err } - return &WorkspaceAgentConn{ + agentConn := &WorkspaceAgentConn{ Conn: conn, CloseFunc: func() { cancelFunc() <-closed }, - }, nil + } + if !agentConn.AwaitReachable(ctx) { + _ = agentConn.Close() + return nil, xerrors.Errorf("timed out waiting for agent to become reachable: %w", ctx.Err()) + } + + return agentConn, nil } // WorkspaceAgent returns an agent by ID. diff --git a/tailnet/conn.go b/tailnet/conn.go index fcdce9a72f930..9846d4096cde9 100644 --- a/tailnet/conn.go +++ b/tailnet/conn.go @@ -60,7 +60,7 @@ type Options struct { } // NewConn constructs a new Wireguard server that will accept connections from the addresses provided. -func NewConn(options *Options) (*Conn, error) { +func NewConn(options *Options) (conn *Conn, err error) { if options == nil { options = &Options{} } @@ -123,6 +123,11 @@ func NewConn(options *Options) (*Conn, error) { if err != nil { return nil, xerrors.Errorf("create wireguard link monitor: %w", err) } + defer func() { + if err != nil { + wireguardMonitor.Close() + } + }() dialer := &tsdial.Dialer{ Logf: Logger(options.Logger), @@ -134,6 +139,11 @@ func NewConn(options *Options) (*Conn, error) { if err != nil { return nil, xerrors.Errorf("create wgengine: %w", err) } + defer func() { + if err != nil { + wireguardEngine.Close() + } + }() dialer.UseNetstackForIP = func(ip netip.Addr) bool { _, ok := wireguardEngine.PeerForIP(ip) return ok @@ -166,10 +176,6 @@ func NewConn(options *Options) (*Conn, error) { return netStack.DialContextTCP(ctx, dst) } netStack.ProcessLocalIPs = true - err = netStack.Start(nil) - if err != nil { - return nil, xerrors.Errorf("start netstack: %w", err) - } wireguardEngine = wgengine.NewWatchdog(wireguardEngine) wireguardEngine.SetDERPMap(options.DERPMap) netMapCopy := *netMap @@ -203,6 +209,11 @@ func NewConn(options *Options) (*Conn, error) { }, wireguardEngine: wireguardEngine, } + defer func() { + if err != nil { + _ = server.Close() + } + }() wireguardEngine.SetStatusCallback(func(s *wgengine.Status, err error) { server.logger.Debug(context.Background(), "wireguard status", slog.F("status", s), slog.F("err", err)) if err != nil { @@ -236,6 +247,12 @@ func NewConn(options *Options) (*Conn, error) { server.sendNode() }) netStack.ForwardTCPIn = server.forwardTCP + + err = netStack.Start(nil) + if err != nil { + return nil, xerrors.Errorf("start netstack: %w", err) + } + return server, nil } @@ -519,22 +536,35 @@ func (c *Conn) Close() error { default: } close(c.closed) - for _, l := range c.listeners { - _ = l.closeNoLock() - } c.mutex.Unlock() - c.dialCancel() - _ = c.dialer.Close() - _ = c.magicConn.Close() + + var wg sync.WaitGroup + defer wg.Wait() + + if c.trafficStats != nil { + wg.Add(1) + go func() { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = c.trafficStats.Shutdown(ctx) + }() + } + _ = c.netStack.Close() + c.dialCancel() _ = c.wireguardMonitor.Close() - _ = c.tunDevice.Close() + _ = c.dialer.Close() + // Stops internals, e.g. tunDevice, magicConn and dnsManager. c.wireguardEngine.Close() - if c.trafficStats != nil { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - _ = c.trafficStats.Shutdown(ctx) + + c.mutex.Lock() + for _, l := range c.listeners { + _ = l.closeNoLock() } + c.listeners = nil + c.mutex.Unlock() + return nil } @@ -714,16 +744,25 @@ func (c *Conn) forwardTCPToLocal(conn net.Conn, port uint16) { func (c *Conn) SetConnStatsCallback(maxPeriod time.Duration, maxConns int, dump func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts)) { connStats := connstats.NewStatistics(maxPeriod, maxConns, dump) + shutdown := func(s *connstats.Statistics) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = s.Shutdown(ctx) + } + c.mutex.Lock() + if c.isClosed() { + c.mutex.Unlock() + shutdown(connStats) + return + } old := c.trafficStats c.trafficStats = connStats c.mutex.Unlock() // Make sure to shutdown the old callback. if old != nil { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - _ = old.Shutdown(ctx) + shutdown(old) } c.tunDevice.SetStatistics(connStats) @@ -776,6 +815,7 @@ func (a addr) String() string { return a.ln.addr } // Logger converts the Tailscale logging function to use slog. func Logger(logger slog.Logger) tslogger.Logf { return tslogger.Logf(func(format string, args ...any) { + slog.Helper() logger.Debug(context.Background(), fmt.Sprintf(format, args...)) }) }