Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit f441ad6

Browse files
authored
fix(codersdk): keep workspace agent connection open after dial context (coder#10863)
1 parent 3a0a4dd commit f441ad6

File tree

2 files changed

+53
-23
lines changed

2 files changed

+53
-23
lines changed

coderd/workspaceagents_test.go

+24-9
Original file line number | Diff line number | Diff line change
@@ -416,13 +416,20 @@ func TestWorkspaceAgentTailnet(t *testing.T) {
416416
_ = agenttest.New(t, client.URL, r.AgentToken)
417417
resources := coderdtest.AwaitWorkspaceAgents(t, client, r.Workspace.ID)
418418

419-
ctx, cancelFunc := context.WithCancel(context.Background())
420-
defer cancelFunc()
421-
conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, &codersdk.DialWorkspaceAgentOptions{
422-
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
423-
})
419+
conn, err := func() (*codersdk.WorkspaceAgentConn, error) {
420+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
421+
defer cancel() // Connection should remain open even if the dial context is canceled.
422+
423+
return client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, &codersdk.DialWorkspaceAgentOptions{
424+
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
425+
})
426+
}()
424427
require.NoError(t, err)
425428
defer conn.Close()
429+
430+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
431+
defer cancel()
432+
426433
sshClient, err := conn.SSHClient(ctx)
427434
require.NoError(t, err)
428435
session, err := sshClient.NewSession()
@@ -1355,12 +1362,20 @@ func TestWorkspaceAgent_UpdatedDERP(t *testing.T) {
13551362
agentID := resources[0].Agents[0].ID
13561363

13571364
// Connect from a client.
1358-
ctx := testutil.Context(t, testutil.WaitLong)
1359-
conn1, err := client.DialWorkspaceAgent(ctx, agentID, &codersdk.DialWorkspaceAgentOptions{
1360-
Logger: logger.Named("client1"),
1361-
})
1365+
conn1, err := func() (*codersdk.WorkspaceAgentConn, error) {
1366+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
1367+
defer cancel() // Connection should remain open even if the dial context is canceled.
1368+
1369+
return client.DialWorkspaceAgent(ctx, agentID, &codersdk.DialWorkspaceAgentOptions{
1370+
Logger: logger.Named("client1"),
1371+
})
1372+
}()
13621373
require.NoError(t, err)
13631374
defer conn1.Close()
1375+
1376+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
1377+
defer cancel()
1378+
13641379
ok := conn1.AwaitReachable(ctx)
13651380
require.True(t, ok)
13661381

codersdk/workspaceagents.go

+29-14
Original file line number | Diff line number | Diff line change
@@ -258,12 +258,12 @@ type DialWorkspaceAgentOptions struct {
258258
BlockEndpoints bool
259259
}
260260

261-
func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *DialWorkspaceAgentOptions) (agentConn *WorkspaceAgentConn, err error) {
261+
func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID, options *DialWorkspaceAgentOptions) (agentConn *WorkspaceAgentConn, err error) {
262262
if options == nil {
263263
options = &DialWorkspaceAgentOptions{}
264264
}
265265

266-
connInfo, err := c.WorkspaceAgentConnectionInfo(ctx, agentID)
266+
connInfo, err := c.WorkspaceAgentConnectionInfo(dialCtx, agentID)
267267
if err != nil {
268268
return nil, xerrors.Errorf("get connection info: %w", err)
269269
}
@@ -302,7 +302,10 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti
302302
tokenHeader = c.SessionTokenHeader
303303
}
304304
headers.Set(tokenHeader, c.SessionToken())
305-
ctx, cancel := context.WithCancel(ctx)
305+
306+
// New context, separate from dialCtx. We don't want to cancel the
307+
// connection if dialCtx is canceled.
308+
ctx, cancel := context.WithCancel(context.Background())
306309
defer func() {
307310
if err != nil {
308311
cancel()
@@ -314,7 +317,9 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti
314317
return nil, xerrors.Errorf("parse url: %w", err)
315318
}
316319
closedCoordinator := make(chan struct{})
317-
firstCoordinator := make(chan error)
320+
// Must only ever be used once, send error OR close to avoid
321+
// reassignment race. Buffered so we don't hang in goroutine.
322+
firstCoordinator := make(chan error, 1)
318323
go func() {
319324
defer close(closedCoordinator)
320325
isFirst := true
@@ -366,7 +371,9 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti
366371
return nil, xerrors.Errorf("parse url: %w", err)
367372
}
368373
closedDerpMap := make(chan struct{})
369-
firstDerpMap := make(chan error)
374+
// Must only ever be used once, send error OR close to avoid
375+
// reassignment race. Buffered so we don't hang in goroutine.
376+
firstDerpMap := make(chan error, 1)
370377
go func() {
371378
defer close(closedDerpMap)
372379
isFirst := true
@@ -420,13 +427,21 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti
420427
}
421428
}()
422429

423-
err = <-firstCoordinator
424-
if err != nil {
425-
return nil, err
426-
}
427-
err = <-firstDerpMap
428-
if err != nil {
429-
return nil, err
430+
for firstCoordinator != nil || firstDerpMap != nil {
431+
select {
432+
case <-dialCtx.Done():
433+
return nil, xerrors.Errorf("timed out waiting for coordinator and derp map: %w", dialCtx.Err())
434+
case err = <-firstCoordinator:
435+
if err != nil {
436+
return nil, xerrors.Errorf("start coordinator: %w", err)
437+
}
438+
firstCoordinator = nil
439+
case err = <-firstDerpMap:
440+
if err != nil {
441+
return nil, xerrors.Errorf("receive derp map: %w", err)
442+
}
443+
firstDerpMap = nil
444+
}
430445
}
431446

432447
agentConn = NewWorkspaceAgentConn(conn, WorkspaceAgentConnOptions{
@@ -444,9 +459,9 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti
444459
},
445460
})
446461

447-
if !agentConn.AwaitReachable(ctx) {
462+
if !agentConn.AwaitReachable(dialCtx) {
448463
_ = agentConn.Close()
449-
return nil, xerrors.Errorf("timed out waiting for agent to become reachable: %w", ctx.Err())
464+
return nil, xerrors.Errorf("timed out waiting for agent to become reachable: %w", dialCtx.Err())
450465
}
451466

452467
return agentConn, nil

0 commit comments

Comments
 (0)