diff --git a/cli/ssh.go b/cli/ssh.go index e1ebbcd04cfd2..5b89ca9fc56ad 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -30,6 +30,7 @@ import ( "github.com/coder/coder/coderd/util/ptr" "github.com/coder/coder/codersdk" "github.com/coder/coder/cryptorand" + "github.com/coder/retry" ) var ( @@ -100,17 +101,82 @@ func (r *RootCmd) ssh() *clibase.Cmd { stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace) defer stopPolling() + // Enure connection is closed if the context is canceled or + // the workspace reaches the stopped state. + // + // Watching the stopped state is a work-around for cases + // where the agent is not gracefully shut down and the + // connection is left open. If, for instance, the networking + // is stopped before the agent is shut down, the disconnect + // will usually not propagate. + // + // See: https://github.com/coder/coder/issues/6180 + watchAndClose := func(closer func() error) { + // Ensure session is ended on both context cancellation + // and workspace stop. + defer func() { + _ = closer() + }() + + startWatchLoop: + for { + // (Re)connect to the coder server and watch workspace events. + var wsWatch <-chan codersdk.Workspace + var err error + for r := retry.New(time.Second, 15*time.Second); r.Wait(ctx); { + wsWatch, err = client.WatchWorkspace(ctx, workspace.ID) + if err == nil { + break + } + if ctx.Err() != nil { + return + } + } + + for { + select { + case <-ctx.Done(): + return + case w, ok := <-wsWatch: + if !ok { + continue startWatchLoop + } + + // Transitioning to stop or delete could mean that + // the agent will still gracefully stop. If a new + // build is starting, there's no reason to wait for + // the agent, it should be long gone. + if workspace.LatestBuild.ID != w.LatestBuild.ID && w.LatestBuild.Transition == codersdk.WorkspaceTransitionStart { + return + } + // Note, we only react to the stopped state here because we + // want to give the agent a chance to gracefully shut down + // during "stopping". + if w.LatestBuild.Status == codersdk.WorkspaceStatusStopped { + return + } + } + } + } + } + if stdio { rawSSH, err := conn.SSH(ctx) if err != nil { return err } defer rawSSH.Close() + go watchAndClose(rawSSH.Close) go func() { - _, _ = io.Copy(inv.Stdout, rawSSH) + // Ensure stdout copy closes incase stdin is closed + // unexpectedly. Typically we wouldn't worry about + // this since OpenSSH should kill the proxy command. + defer rawSSH.Close() + + _, _ = io.Copy(rawSSH, inv.Stdin) }() - _, _ = io.Copy(rawSSH, inv.Stdin) + _, _ = io.Copy(inv.Stdout, rawSSH) return nil } @@ -125,13 +191,11 @@ func (r *RootCmd) ssh() *clibase.Cmd { return err } defer sshSession.Close() - - // Ensure context cancellation is propagated to the - // SSH session, e.g. to cancel `Wait()` at the end. - go func() { - <-ctx.Done() + go watchAndClose(func() error { _ = sshSession.Close() - }() + _ = sshClient.Close() + return nil + }) if identityAgent == "" { identityAgent = os.Getenv("SSH_AUTH_SOCK") diff --git a/cli/ssh_test.go b/cli/ssh_test.go index ec1dc1cb46b74..ee544a328e2ea 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -31,6 +31,7 @@ import ( "github.com/coder/coder/cli/clitest" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/coderd/database" "github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk/agentsdk" "github.com/coder/coder/provisioner/echo" @@ -143,6 +144,50 @@ func TestSSH(t *testing.T) { cancel() <-cmdDone }) + + t.Run("ExitOnStop", func(t *testing.T) { + t.Parallel() + if runtime.GOOS == "windows" { + t.Skip("Windows doesn't seem to clean up the process, maybe #7100 will fix it") + } + + client, workspace, agentToken := setupWorkspaceForAgent(t, nil) + inv, root := clitest.New(t, "ssh", workspace.Name) + clitest.SetupConfig(t, client, root) + pty := ptytest.New(t).Attach(inv) + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + cmdDone := tGo(t, func() { + err := inv.WithContext(ctx).Run() + assert.Error(t, err) + }) + pty.ExpectMatch("Waiting") + + agentClient := agentsdk.New(client.URL) + agentClient.SetSessionToken(agentToken) + agentCloser := agent.New(agent.Options{ + Client: agentClient, + Logger: slogtest.Make(t, nil).Named("agent"), + }) + defer func() { + _ = agentCloser.Close() + }() + + // Ensure the agent is connected. + pty.WriteLine("echo hell'o'") + pty.ExpectMatchContext(ctx, "hello") + + workspace = coderdtest.MustTransitionWorkspace(t, client, workspace.ID, database.WorkspaceTransitionStart, database.WorkspaceTransitionStop) + + select { + case <-cmdDone: + case <-ctx.Done(): + require.Fail(t, "command did not exit in time") + } + }) + t.Run("Stdio", func(t *testing.T) { t.Parallel() client, workspace, agentToken := setupWorkspaceForAgent(t, nil) @@ -207,6 +252,76 @@ func TestSSH(t *testing.T) { <-cmdDone }) + + t.Run("StdioExitOnStop", func(t *testing.T) { + t.Parallel() + if runtime.GOOS == "windows" { + t.Skip("Windows doesn't seem to clean up the process, maybe #7100 will fix it") + } + client, workspace, agentToken := setupWorkspaceForAgent(t, nil) + _, _ = tGoContext(t, func(ctx context.Context) { + // Run this async so the SSH command has to wait for + // the build and agent to connect! + agentClient := agentsdk.New(client.URL) + agentClient.SetSessionToken(agentToken) + agentCloser := agent.New(agent.Options{ + Client: agentClient, + Logger: slogtest.Make(t, nil).Named("agent"), + }) + <-ctx.Done() + _ = agentCloser.Close() + }) + + clientOutput, clientInput := io.Pipe() + serverOutput, serverInput := io.Pipe() + defer func() { + for _, c := range []io.Closer{clientOutput, clientInput, serverOutput, serverInput} { + _ = c.Close() + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + inv, root := clitest.New(t, "ssh", "--stdio", workspace.Name) + clitest.SetupConfig(t, client, root) + inv.Stdin = clientOutput + inv.Stdout = serverInput + inv.Stderr = io.Discard + cmdDone := tGo(t, func() { + err := inv.WithContext(ctx).Run() + assert.NoError(t, err) + }) + + conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + Reader: serverOutput, + Writer: clientInput, + }, "", &ssh.ClientConfig{ + // #nosec + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + }) + require.NoError(t, err) + defer conn.Close() + + sshClient := ssh.NewClient(conn, channels, requests) + defer sshClient.Close() + + session, err := sshClient.NewSession() + require.NoError(t, err) + defer session.Close() + + err = session.Shell() + require.NoError(t, err) + + workspace = coderdtest.MustTransitionWorkspace(t, client, workspace.ID, database.WorkspaceTransitionStart, database.WorkspaceTransitionStop) + + select { + case <-cmdDone: + case <-ctx.Done(): + require.Fail(t, "command did not exit in time") + } + }) + t.Run("ForwardAgent", func(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("Test not supported on windows")