Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit a414de9

Browse files
authored
fix(tailnet): Improve tailnet setup and agentconn stability (#6292)
* fix(tailnet): Improve start and close to detect connection races * fix: Prevent agentConn use before ready via AwaitReachable * fix(tailnet): Ensure connstats are closed on conn close * fix(codersdk): Use AwaitReachable in DialWorkspaceAgent * fix(tailnet): Improve logging via slog.Helper()
1 parent 473ab20 commit a414de9

File tree

6 files changed

+106
-24
lines changed

6 files changed

+106
-24
lines changed

coderd/coderd_test.go

+8
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"net/http"
77
"net/netip"
88
"strconv"
9+
"sync"
910
"testing"
1011

1112
"github.com/stretchr/testify/assert"
@@ -78,8 +79,14 @@ func TestDERP(t *testing.T) {
7879
DERPMap: derpMap,
7980
})
8081
require.NoError(t, err)
82+
83+
w2Ready := make(chan struct{}, 1)
84+
w2ReadyOnce := sync.Once{}
8185
w1.SetNodeCallback(func(node *tailnet.Node) {
8286
w2.UpdateNodes([]*tailnet.Node{node})
87+
w2ReadyOnce.Do(func() {
88+
close(w2Ready)
89+
})
8390
})
8491
w2.SetNodeCallback(func(node *tailnet.Node) {
8592
w1.UpdateNodes([]*tailnet.Node{node})
@@ -98,6 +105,7 @@ func TestDERP(t *testing.T) {
98105
}()
99106

100107
<-conn
108+
<-w2Ready
101109
nc, err := w2.DialContextTCP(context.Background(), netip.AddrPortFrom(w1IP, 35565))
102110
require.NoError(t, err)
103111
_ = nc.Close()

coderd/workspaceagents_test.go

+2
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,8 @@ func TestWorkspaceAgentListeningPorts(t *testing.T) {
469469
t.Parallel()
470470

471471
setup := func(t *testing.T, apps []*proto.App) (*codersdk.Client, uint16, uuid.UUID) {
472+
t.Helper()
473+
472474
client := coderdtest.New(t, &coderdtest.Options{
473475
IncludeProvisionerDaemon: true,
474476
})

coderd/wsconncache/wsconncache_test.go

+11
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import (
2929
"github.com/coder/coder/codersdk/agentsdk"
3030
"github.com/coder/coder/tailnet"
3131
"github.com/coder/coder/tailnet/tailnettest"
32+
"github.com/coder/coder/testutil"
3233
)
3334

3435
func TestMain(m *testing.M) {
@@ -131,6 +132,14 @@ func TestCache(t *testing.T) {
131132
return
132133
}
133134
defer release()
135+
136+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
137+
defer cancel()
138+
if !conn.AwaitReachable(ctx) {
139+
t.Error("agent not reachable")
140+
return
141+
}
142+
134143
transport := conn.HTTPTransport()
135144
defer transport.CloseIdleConnections()
136145
proxy.Transport = transport
@@ -146,6 +155,8 @@ func TestCache(t *testing.T) {
146155
}
147156

148157
func setupAgent(t *testing.T, metadata agentsdk.Metadata, ptyTimeout time.Duration) *codersdk.WorkspaceAgentConn {
158+
t.Helper()
159+
149160
metadata.DERPMap = tailnettest.RunDERPAndSTUN(t)
150161

151162
coordinator := tailnet.NewCoordinator()

codersdk/workspaceagentconn.go

+18-3
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,9 @@ type ReconnectingPTYRequest struct {
176176
func (c *WorkspaceAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, width uint16, command string) (net.Conn, error) {
177177
ctx, span := tracing.StartSpan(ctx)
178178
defer span.End()
179-
179+
if !c.AwaitReachable(ctx) {
180+
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
181+
}
180182
conn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentReconnectingPTYPort))
181183
if err != nil {
182184
return nil, err
@@ -207,6 +209,9 @@ func (c *WorkspaceAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID,
207209
func (c *WorkspaceAgentConn) SSH(ctx context.Context) (net.Conn, error) {
208210
ctx, span := tracing.StartSpan(ctx)
209211
defer span.End()
212+
if !c.AwaitReachable(ctx) {
213+
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
214+
}
210215
return c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentSSHPort))
211216
}
212217

@@ -235,6 +240,9 @@ func (c *WorkspaceAgentConn) SSHClient(ctx context.Context) (*ssh.Client, error)
235240
func (c *WorkspaceAgentConn) Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) {
236241
ctx, span := tracing.StartSpan(ctx)
237242
defer span.End()
243+
if !c.AwaitReachable(ctx) {
244+
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
245+
}
238246
speedConn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentSpeedtestPort))
239247
if err != nil {
240248
return nil, xerrors.Errorf("dial speedtest: %w", err)
@@ -257,6 +265,9 @@ func (c *WorkspaceAgentConn) DialContext(ctx context.Context, network string, ad
257265
_, rawPort, _ := net.SplitHostPort(addr)
258266
port, _ := strconv.ParseUint(rawPort, 10, 16)
259267
ipp := netip.AddrPortFrom(WorkspaceAgentIP, uint16(port))
268+
if !c.AwaitReachable(ctx) {
269+
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
270+
}
260271
if network == "udp" {
261272
return c.Conn.DialContextUDP(ctx, ipp)
262273
}
@@ -317,7 +328,7 @@ func (c *WorkspaceAgentConn) apiClient() *http.Client {
317328
// Disable keep alives as we're usually only making a single
318329
// request, and this triggers goleak in tests
319330
DisableKeepAlives: true,
320-
DialContext: func(_ context.Context, network, addr string) (net.Conn, error) {
331+
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
321332
if network != "tcp" {
322333
return nil, xerrors.Errorf("network must be tcp")
323334
}
@@ -331,7 +342,11 @@ func (c *WorkspaceAgentConn) apiClient() *http.Client {
331342
return nil, xerrors.Errorf("request %q does not appear to be for http api", addr)
332343
}
333344

334-
conn, err := c.DialContextTCP(context.Background(), netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentHTTPAPIServerPort))
345+
if !c.AwaitReachable(ctx) {
346+
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
347+
}
348+
349+
conn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentHTTPAPIServerPort))
335350
if err != nil {
336351
return nil, xerrors.Errorf("dial http api: %w", err)
337352
}

codersdk/workspaceagents.go

+8-2
Original file line numberDiff line numberDiff line change
@@ -199,13 +199,19 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti
199199
return nil, err
200200
}
201201

202-
return &WorkspaceAgentConn{
202+
agentConn := &WorkspaceAgentConn{
203203
Conn: conn,
204204
CloseFunc: func() {
205205
cancelFunc()
206206
<-closed
207207
},
208-
}, nil
208+
}
209+
if !agentConn.AwaitReachable(ctx) {
210+
_ = agentConn.Close()
211+
return nil, xerrors.Errorf("timed out waiting for agent to become reachable: %w", ctx.Err())
212+
}
213+
214+
return agentConn, nil
209215
}
210216

211217
// WorkspaceAgent returns an agent by ID.

tailnet/conn.go

+59-19
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ type Options struct {
6060
}
6161

6262
// NewConn constructs a new Wireguard server that will accept connections from the addresses provided.
63-
func NewConn(options *Options) (*Conn, error) {
63+
func NewConn(options *Options) (conn *Conn, err error) {
6464
if options == nil {
6565
options = &Options{}
6666
}
@@ -123,6 +123,11 @@ func NewConn(options *Options) (*Conn, error) {
123123
if err != nil {
124124
return nil, xerrors.Errorf("create wireguard link monitor: %w", err)
125125
}
126+
defer func() {
127+
if err != nil {
128+
wireguardMonitor.Close()
129+
}
130+
}()
126131

127132
dialer := &tsdial.Dialer{
128133
Logf: Logger(options.Logger),
@@ -134,6 +139,11 @@ func NewConn(options *Options) (*Conn, error) {
134139
if err != nil {
135140
return nil, xerrors.Errorf("create wgengine: %w", err)
136141
}
142+
defer func() {
143+
if err != nil {
144+
wireguardEngine.Close()
145+
}
146+
}()
137147
dialer.UseNetstackForIP = func(ip netip.Addr) bool {
138148
_, ok := wireguardEngine.PeerForIP(ip)
139149
return ok
@@ -166,10 +176,6 @@ func NewConn(options *Options) (*Conn, error) {
166176
return netStack.DialContextTCP(ctx, dst)
167177
}
168178
netStack.ProcessLocalIPs = true
169-
err = netStack.Start(nil)
170-
if err != nil {
171-
return nil, xerrors.Errorf("start netstack: %w", err)
172-
}
173179
wireguardEngine = wgengine.NewWatchdog(wireguardEngine)
174180
wireguardEngine.SetDERPMap(options.DERPMap)
175181
netMapCopy := *netMap
@@ -203,6 +209,11 @@ func NewConn(options *Options) (*Conn, error) {
203209
},
204210
wireguardEngine: wireguardEngine,
205211
}
212+
defer func() {
213+
if err != nil {
214+
_ = server.Close()
215+
}
216+
}()
206217
wireguardEngine.SetStatusCallback(func(s *wgengine.Status, err error) {
207218
server.logger.Debug(context.Background(), "wireguard status", slog.F("status", s), slog.F("err", err))
208219
if err != nil {
@@ -236,6 +247,12 @@ func NewConn(options *Options) (*Conn, error) {
236247
server.sendNode()
237248
})
238249
netStack.ForwardTCPIn = server.forwardTCP
250+
251+
err = netStack.Start(nil)
252+
if err != nil {
253+
return nil, xerrors.Errorf("start netstack: %w", err)
254+
}
255+
239256
return server, nil
240257
}
241258

@@ -519,22 +536,35 @@ func (c *Conn) Close() error {
519536
default:
520537
}
521538
close(c.closed)
522-
for _, l := range c.listeners {
523-
_ = l.closeNoLock()
524-
}
525539
c.mutex.Unlock()
526-
c.dialCancel()
527-
_ = c.dialer.Close()
528-
_ = c.magicConn.Close()
540+
541+
var wg sync.WaitGroup
542+
defer wg.Wait()
543+
544+
if c.trafficStats != nil {
545+
wg.Add(1)
546+
go func() {
547+
defer wg.Done()
548+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
549+
defer cancel()
550+
_ = c.trafficStats.Shutdown(ctx)
551+
}()
552+
}
553+
529554
_ = c.netStack.Close()
555+
c.dialCancel()
530556
_ = c.wireguardMonitor.Close()
531-
_ = c.tunDevice.Close()
557+
_ = c.dialer.Close()
558+
// Stops internals, e.g. tunDevice, magicConn and dnsManager.
532559
c.wireguardEngine.Close()
533-
if c.trafficStats != nil {
534-
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
535-
defer cancel()
536-
_ = c.trafficStats.Shutdown(ctx)
560+
561+
c.mutex.Lock()
562+
for _, l := range c.listeners {
563+
_ = l.closeNoLock()
537564
}
565+
c.listeners = nil
566+
c.mutex.Unlock()
567+
538568
return nil
539569
}
540570

@@ -714,16 +744,25 @@ func (c *Conn) forwardTCPToLocal(conn net.Conn, port uint16) {
714744
func (c *Conn) SetConnStatsCallback(maxPeriod time.Duration, maxConns int, dump func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts)) {
715745
connStats := connstats.NewStatistics(maxPeriod, maxConns, dump)
716746

747+
shutdown := func(s *connstats.Statistics) {
748+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
749+
defer cancel()
750+
_ = s.Shutdown(ctx)
751+
}
752+
717753
c.mutex.Lock()
754+
if c.isClosed() {
755+
c.mutex.Unlock()
756+
shutdown(connStats)
757+
return
758+
}
718759
old := c.trafficStats
719760
c.trafficStats = connStats
720761
c.mutex.Unlock()
721762

722763
// Make sure to shutdown the old callback.
723764
if old != nil {
724-
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
725-
defer cancel()
726-
_ = old.Shutdown(ctx)
765+
shutdown(old)
727766
}
728767

729768
c.tunDevice.SetStatistics(connStats)
@@ -776,6 +815,7 @@ func (a addr) String() string { return a.ln.addr }
776815
// Logger converts the Tailscale logging function to use slog.
777816
func Logger(logger slog.Logger) tslogger.Logf {
778817
return tslogger.Logf(func(format string, args ...any) {
818+
slog.Helper()
779819
logger.Debug(context.Background(), fmt.Sprintf(format, args...))
780820
})
781821
}

0 commit comments

Comments
 (0)