Thanks to visit codestin.com
Credit goes to github.com

Skip to content

fix(cli/ssh): Avoid connection hang when workspace is stopped #7201

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 72 additions & 8 deletions cli/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/coder/coder/coderd/util/ptr"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/cryptorand"
"github.com/coder/retry"
)

var (
Expand Down Expand Up @@ -100,17 +101,82 @@ func (r *RootCmd) ssh() *clibase.Cmd {
stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace)
defer stopPolling()

// Ensure connection is closed if the context is canceled or
// the workspace reaches the stopped state.
//
// Watching the stopped state is a work-around for cases
// where the agent is not gracefully shut down and the
// connection is left open. If, for instance, the networking
// is stopped before the agent is shut down, the disconnect
// will usually not propagate.
//
// See: https://github.com/coder/coder/issues/6180
watchAndClose := func(closer func() error) {
	// However this function returns (context cancellation or a
	// workspace event handled below), the session gets closed.
	defer func() { _ = closer() }()

	for {
		// (Re)subscribe to workspace events, retrying with a backoff
		// starting at one second and capped at 15 seconds until the
		// subscription succeeds or the context ends.
		var (
			events <-chan codersdk.Workspace
			err    error
		)
		for backoff := retry.New(time.Second, 15*time.Second); backoff.Wait(ctx); {
			events, err = client.WatchWorkspace(ctx, workspace.ID)
			if err == nil {
				break
			}
			if ctx.Err() != nil {
				return
			}
		}

		// Consume events until the channel is closed, at which point
		// the outer loop re-establishes the subscription.
		subscribed := true
		for subscribed {
			select {
			case <-ctx.Done():
				return
			case update, ok := <-events:
				switch {
				case !ok:
					// Event stream ended; go back and reconnect.
					subscribed = false
				case workspace.LatestBuild.ID != update.LatestBuild.ID && update.LatestBuild.Transition == codersdk.WorkspaceTransitionStart:
					// A stop or delete transition may still let the
					// agent shut down gracefully, but once a different
					// build is starting there is no point in waiting
					// for the agent any longer.
					return
				case update.LatestBuild.Status == codersdk.WorkspaceStatusStopped:
					// Only the final "stopped" status is acted upon so
					// the agent has a chance to shut down gracefully
					// while the workspace is still "stopping".
					return
				}
			}
		}
	}
}

if stdio {
rawSSH, err := conn.SSH(ctx)
if err != nil {
return err
}
defer rawSSH.Close()
go watchAndClose(rawSSH.Close)

go func() {
_, _ = io.Copy(inv.Stdout, rawSSH)
// Ensure stdout copy closes in case stdin is closed
// unexpectedly. Typically we wouldn't worry about
// this since OpenSSH should kill the proxy command.
defer rawSSH.Close()

_, _ = io.Copy(rawSSH, inv.Stdin)
}()
_, _ = io.Copy(rawSSH, inv.Stdin)
_, _ = io.Copy(inv.Stdout, rawSSH)
return nil
}

Expand All @@ -125,13 +191,11 @@ func (r *RootCmd) ssh() *clibase.Cmd {
return err
}
defer sshSession.Close()

// Ensure context cancellation is propagated to the
// SSH session, e.g. to cancel `Wait()` at the end.
go func() {
<-ctx.Done()
go watchAndClose(func() error {
_ = sshSession.Close()
}()
_ = sshClient.Close()
return nil
})

if identityAgent == "" {
identityAgent = os.Getenv("SSH_AUTH_SOCK")
Expand Down
115 changes: 115 additions & 0 deletions cli/ssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/coder/coder/cli/clitest"
"github.com/coder/coder/cli/cliui"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/codersdk/agentsdk"
"github.com/coder/coder/provisioner/echo"
Expand Down Expand Up @@ -143,6 +144,50 @@ func TestSSH(t *testing.T) {
cancel()
<-cmdDone
})

// ExitOnStop checks that an interactive `ssh` command terminates once the
// workspace is transitioned from start to stop.
t.Run("ExitOnStop", func(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("Windows doesn't seem to clean up the process, maybe #7100 will fix it")
}

client, workspace, agentToken := setupWorkspaceForAgent(t, nil)
inv, root := clitest.New(t, "ssh", workspace.Name)
clitest.SetupConfig(t, client, root)
pty := ptytest.New(t).Attach(inv)

// Upper bound for the whole test; the final select fails the test if
// this context ends before the command does.
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()

// Run the command in the background. It is expected to return an error
// here (presumably because the session is closed from underneath it
// when the workspace stops; confirm against the command implementation).
cmdDone := tGo(t, func() {
err := inv.WithContext(ctx).Run()
assert.Error(t, err)
})
// The agent is only started after the command has printed "Waiting".
pty.ExpectMatch("Waiting")

agentClient := agentsdk.New(client.URL)
agentClient.SetSessionToken(agentToken)
agentCloser := agent.New(agent.Options{
Client: agentClient,
Logger: slogtest.Make(t, nil).Named("agent"),
})
defer func() {
_ = agentCloser.Close()
}()

// Ensure the agent is connected.
// NOTE(review): the quoting in hell'o' makes the typed text differ from
// the expected output "hello", presumably so the match cannot be
// satisfied by the terminal merely echoing the typed command.
pty.WriteLine("echo hell'o'")
pty.ExpectMatchContext(ctx, "hello")

// Stop the workspace; this is the event the command must react to.
workspace = coderdtest.MustTransitionWorkspace(t, client, workspace.ID, database.WorkspaceTransitionStart, database.WorkspaceTransitionStop)

// The command must finish before the test context expires.
select {
case <-cmdDone:
case <-ctx.Done():
require.Fail(t, "command did not exit in time")
}
})

t.Run("Stdio", func(t *testing.T) {
t.Parallel()
client, workspace, agentToken := setupWorkspaceForAgent(t, nil)
Expand Down Expand Up @@ -207,6 +252,76 @@ func TestSSH(t *testing.T) {

<-cmdDone
})

// StdioExitOnStop checks that `ssh --stdio` terminates once the workspace
// is transitioned from start to stop while an SSH session is open over the
// command's stdin/stdout.
t.Run("StdioExitOnStop", func(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("Windows doesn't seem to clean up the process, maybe #7100 will fix it")
}
client, workspace, agentToken := setupWorkspaceForAgent(t, nil)
_, _ = tGoContext(t, func(ctx context.Context) {
// Run this async so the SSH command has to wait for
// the build and agent to connect!
agentClient := agentsdk.New(client.URL)
agentClient.SetSessionToken(agentToken)
agentCloser := agent.New(agent.Options{
Client: agentClient,
Logger: slogtest.Make(t, nil).Named("agent"),
})
// Keep the agent alive until the context handed in by tGoContext is
// done, then close it.
<-ctx.Done()
_ = agentCloser.Close()
})

// Two in-memory pipes stand in for the command's stdio:
// clientInput -> clientOutput feeds the command's stdin, and
// serverInput -> serverOutput carries the command's stdout.
clientOutput, clientInput := io.Pipe()
serverOutput, serverInput := io.Pipe()
defer func() {
for _, c := range []io.Closer{clientOutput, clientInput, serverOutput, serverInput} {
_ = c.Close()
}
}()

// Upper bound for the whole test; the final select fails the test if
// this context ends before the command does.
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()

inv, root := clitest.New(t, "ssh", "--stdio", workspace.Name)
clitest.SetupConfig(t, client, root)
inv.Stdin = clientOutput
inv.Stdout = serverInput
inv.Stderr = io.Discard
// Run the command in the background; in stdio mode it is expected to
// return without an error.
cmdDone := tGo(t, func() {
err := inv.WithContext(ctx).Run()
assert.NoError(t, err)
})

// Perform the SSH client handshake over the pipes: read what the
// command writes to stdout and write to what it reads from stdin.
conn, channels, requests, err := ssh.NewClientConn(&stdioConn{
Reader: serverOutput,
Writer: clientInput,
}, "", &ssh.ClientConfig{
// #nosec
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
})
require.NoError(t, err)
defer conn.Close()

sshClient := ssh.NewClient(conn, channels, requests)
defer sshClient.Close()

session, err := sshClient.NewSession()
require.NoError(t, err)
defer session.Close()

// Start a shell so a live session exists when the workspace is stopped.
err = session.Shell()
require.NoError(t, err)

// Stop the workspace; this is the event the command must react to.
workspace = coderdtest.MustTransitionWorkspace(t, client, workspace.ID, database.WorkspaceTransitionStart, database.WorkspaceTransitionStop)

// The command must finish before the test context expires.
select {
case <-cmdDone:
case <-ctx.Done():
require.Fail(t, "command did not exit in time")
}
})

t.Run("ForwardAgent", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Test not supported on windows")
Expand Down