Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 527f1f3

Browse files
mafredrijohnstcn
andauthored
feat: Add SSH agent forwarding support to coder agent (#1548)
* feat: Add SSH agent forwarding support to coder agent * feat: Add forward agent flag to `coder ssh` * refactor: Share setup between SSH tests, sync goroutines * feat: Add test for `coder ssh --forward-agent` * fix: Fix test flakes and implement Deans suggestion for helpers * fix: Add example to config-ssh * fix: Allow forwarding agent via -A Co-authored-by: Cian Johnston <[email protected]>
1 parent 22ef456 commit 527f1f3

File tree

4 files changed

+211
-69
lines changed

4 files changed

+211
-69
lines changed

agent/agent.go

+10
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,16 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
391391
return err
392392
}
393393

394+
if ssh.AgentRequested(session) {
395+
l, err := ssh.NewAgentListener()
396+
if err != nil {
397+
return xerrors.Errorf("new agent listener: %w", err)
398+
}
399+
defer l.Close()
400+
go ssh.ForwardAgentConnections(l, session)
401+
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", "SSH_AUTH_SOCK", l.Addr().String()))
402+
}
403+
394404
sshPty, windowSize, isPty := session.Pty()
395405
if isPty {
396406
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term))

cli/configssh.go

+5
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ func configSSH() *cobra.Command {
3838
Annotations: workspaceCommand,
3939
Use: "config-ssh",
4040
Short: "Populate your SSH config with Host entries for all of your workspaces",
41+
Example: `
42+
- You can use -o (or --ssh-option) so set SSH options to be used for all your
43+
workspaces.
44+
45+
` + cliui.Styles.Code.Render("$ coder config-ssh -o ForwardAgent=yes"),
4146
RunE: func(cmd *cobra.Command, args []string) error {
4247
client, err := createClient(cmd)
4348
if err != nil {

cli/ssh.go

+15-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"github.com/mattn/go-isatty"
1616
"github.com/spf13/cobra"
1717
gossh "golang.org/x/crypto/ssh"
18+
gosshagent "golang.org/x/crypto/ssh/agent"
1819
"golang.org/x/term"
1920
"golang.org/x/xerrors"
2021

@@ -32,6 +33,7 @@ func ssh() *cobra.Command {
3233
var (
3334
stdio bool
3435
shuffle bool
36+
forwardAgent bool
3537
wsPollInterval time.Duration
3638
)
3739
cmd := &cobra.Command{
@@ -108,6 +110,17 @@ func ssh() *cobra.Command {
108110
return err
109111
}
110112

113+
if forwardAgent && os.Getenv("SSH_AUTH_SOCK") != "" {
114+
err = gosshagent.ForwardToRemote(sshClient, os.Getenv("SSH_AUTH_SOCK"))
115+
if err != nil {
116+
return xerrors.Errorf("forward agent failed: %w", err)
117+
}
118+
err = gosshagent.RequestAgentForwarding(sshSession)
119+
if err != nil {
120+
return xerrors.Errorf("request agent forwarding failed: %w", err)
121+
}
122+
}
123+
111124
stdoutFile, valid := cmd.OutOrStdout().(*os.File)
112125
if valid && isatty.IsTerminal(stdoutFile.Fd()) {
113126
state, err := term.MakeRaw(int(os.Stdin.Fd()))
@@ -156,8 +169,9 @@ func ssh() *cobra.Command {
156169
}
157170
cliflag.BoolVarP(cmd.Flags(), &stdio, "stdio", "", "CODER_SSH_STDIO", false, "Specifies whether to emit SSH output over stdin/stdout.")
158171
cliflag.BoolVarP(cmd.Flags(), &shuffle, "shuffle", "", "CODER_SSH_SHUFFLE", false, "Specifies whether to choose a random workspace")
159-
cliflag.DurationVarP(cmd.Flags(), &wsPollInterval, "workspace-poll-interval", "", "CODER_WORKSPACE_POLL_INTERVAL", workspacePollInterval, "Specifies how often to poll for workspace automated shutdown.")
160172
_ = cmd.Flags().MarkHidden("shuffle")
173+
cliflag.BoolVarP(cmd.Flags(), &forwardAgent, "forward-agent", "A", "CODER_SSH_FORWARD_AGENT", false, "Specifies whether to forward the SSH agent specified in $SSH_AUTH_SOCK")
174+
cliflag.DurationVarP(cmd.Flags(), &wsPollInterval, "workspace-poll-interval", "", "CODER_WORKSPACE_POLL_INTERVAL", workspacePollInterval, "Specifies how often to poll for workspace automated shutdown.")
161175

162176
return cmd
163177
}

cli/ssh_test.go

+181-68
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
package cli_test
22

33
import (
4+
"context"
5+
"crypto/ecdsa"
6+
"crypto/elliptic"
7+
"crypto/rand"
8+
"errors"
49
"io"
510
"net"
11+
"path/filepath"
612
"runtime"
713
"testing"
814
"time"
@@ -11,9 +17,11 @@ import (
1117
"github.com/stretchr/testify/assert"
1218
"github.com/stretchr/testify/require"
1319
"golang.org/x/crypto/ssh"
20+
gosshagent "golang.org/x/crypto/ssh/agent"
1421

1522
"cdr.dev/slog"
1623
"cdr.dev/slog/sloggers/slogtest"
24+
1725
"github.com/coder/coder/agent"
1826
"github.com/coder/coder/cli/clitest"
1927
"github.com/coder/coder/coderd/coderdtest"
@@ -23,49 +31,53 @@ import (
2331
"github.com/coder/coder/pty/ptytest"
2432
)
2533

34+
func setupWorkspaceForSSH(t *testing.T) (*codersdk.Client, codersdk.Workspace, string) {
35+
t.Helper()
36+
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true})
37+
user := coderdtest.CreateFirstUser(t, client)
38+
agentToken := uuid.NewString()
39+
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
40+
Parse: echo.ParseComplete,
41+
ProvisionDryRun: echo.ProvisionComplete,
42+
Provision: []*proto.Provision_Response{{
43+
Type: &proto.Provision_Response_Complete{
44+
Complete: &proto.Provision_Complete{
45+
Resources: []*proto.Resource{{
46+
Name: "dev",
47+
Type: "google_compute_instance",
48+
Agents: []*proto.Agent{{
49+
Id: uuid.NewString(),
50+
Auth: &proto.Agent_Token{
51+
Token: agentToken,
52+
},
53+
}},
54+
}},
55+
},
56+
},
57+
}},
58+
})
59+
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
60+
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
61+
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
62+
63+
return client, workspace, agentToken
64+
}
65+
2666
func TestSSH(t *testing.T) {
27-
t.Skip("This is causing test flakes. TODO @cian fix this")
2867
t.Parallel()
2968
t.Run("ImmediateExit", func(t *testing.T) {
3069
t.Parallel()
31-
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true})
32-
user := coderdtest.CreateFirstUser(t, client)
33-
agentToken := uuid.NewString()
34-
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
35-
Parse: echo.ParseComplete,
36-
ProvisionDryRun: echo.ProvisionComplete,
37-
Provision: []*proto.Provision_Response{{
38-
Type: &proto.Provision_Response_Complete{
39-
Complete: &proto.Provision_Complete{
40-
Resources: []*proto.Resource{{
41-
Name: "dev",
42-
Type: "google_compute_instance",
43-
Agents: []*proto.Agent{{
44-
Id: uuid.NewString(),
45-
Auth: &proto.Agent_Token{
46-
Token: agentToken,
47-
},
48-
}},
49-
}},
50-
},
51-
},
52-
}},
53-
})
54-
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
55-
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
56-
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
70+
client, workspace, agentToken := setupWorkspaceForSSH(t)
5771
cmd, root := clitest.New(t, "ssh", workspace.Name)
5872
clitest.SetupConfig(t, client, root)
59-
doneChan := make(chan struct{})
6073
pty := ptytest.New(t)
6174
cmd.SetIn(pty.Input())
6275
cmd.SetErr(pty.Output())
6376
cmd.SetOut(pty.Output())
64-
go func() {
65-
defer close(doneChan)
77+
cmdDone := tGo(t, func() {
6678
err := cmd.Execute()
6779
assert.NoError(t, err)
68-
}()
80+
})
6981
pty.ExpectMatch("Waiting")
7082
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
7183
agentClient := codersdk.New(client.URL)
@@ -76,39 +88,16 @@ func TestSSH(t *testing.T) {
7688
t.Cleanup(func() {
7789
_ = agentCloser.Close()
7890
})
91+
7992
// Shells on Mac, Windows, and Linux all exit shells with the "exit" command.
8093
pty.WriteLine("exit")
81-
<-doneChan
94+
<-cmdDone
8295
})
8396
t.Run("Stdio", func(t *testing.T) {
8497
t.Parallel()
85-
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true})
86-
user := coderdtest.CreateFirstUser(t, client)
87-
agentToken := uuid.NewString()
88-
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
89-
Parse: echo.ParseComplete,
90-
ProvisionDryRun: echo.ProvisionComplete,
91-
Provision: []*proto.Provision_Response{{
92-
Type: &proto.Provision_Response_Complete{
93-
Complete: &proto.Provision_Complete{
94-
Resources: []*proto.Resource{{
95-
Name: "dev",
96-
Type: "google_compute_instance",
97-
Agents: []*proto.Agent{{
98-
Id: uuid.NewString(),
99-
Auth: &proto.Agent_Token{
100-
Token: agentToken,
101-
},
102-
}},
103-
}},
104-
},
105-
},
106-
}},
107-
})
108-
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
109-
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
110-
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
111-
go func() {
98+
client, workspace, agentToken := setupWorkspaceForSSH(t)
99+
100+
_, _ = tGoContext(t, func(ctx context.Context) {
112101
// Run this async so the SSH command has to wait for
113102
// the build and agent to connect!
114103
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
@@ -117,25 +106,22 @@ func TestSSH(t *testing.T) {
117106
agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{
118107
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
119108
})
120-
t.Cleanup(func() {
121-
_ = agentCloser.Close()
122-
})
123-
}()
109+
<-ctx.Done()
110+
_ = agentCloser.Close()
111+
})
124112

125113
clientOutput, clientInput := io.Pipe()
126114
serverOutput, serverInput := io.Pipe()
127115

128116
cmd, root := clitest.New(t, "ssh", "--stdio", workspace.Name)
129117
clitest.SetupConfig(t, client, root)
130-
doneChan := make(chan struct{})
131118
cmd.SetIn(clientOutput)
132119
cmd.SetOut(serverInput)
133120
cmd.SetErr(io.Discard)
134-
go func() {
135-
defer close(doneChan)
121+
cmdDone := tGo(t, func() {
136122
err := cmd.Execute()
137123
assert.NoError(t, err)
138-
}()
124+
})
139125

140126
conn, channels, requests, err := ssh.NewClientConn(&stdioConn{
141127
Reader: serverOutput,
@@ -157,8 +143,135 @@ func TestSSH(t *testing.T) {
157143
err = sshClient.Close()
158144
require.NoError(t, err)
159145
_ = clientOutput.Close()
160-
<-doneChan
146+
147+
<-cmdDone
148+
})
149+
//nolint:paralleltest // Disabled due to use of t.Setenv.
150+
t.Run("ForwardAgent", func(t *testing.T) {
151+
if runtime.GOOS == "windows" {
152+
t.Skip("Test not supported on windows")
153+
}
154+
155+
client, workspace, agentToken := setupWorkspaceForSSH(t)
156+
157+
_, _ = tGoContext(t, func(ctx context.Context) {
158+
// Run this async so the SSH command has to wait for
159+
// the build and agent to connect!
160+
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
161+
agentClient := codersdk.New(client.URL)
162+
agentClient.SessionToken = agentToken
163+
agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{
164+
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
165+
})
166+
<-ctx.Done()
167+
_ = agentCloser.Close()
168+
})
169+
170+
// Generate private key.
171+
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
172+
require.NoError(t, err)
173+
kr := gosshagent.NewKeyring()
174+
kr.Add(gosshagent.AddedKey{
175+
PrivateKey: privateKey,
176+
})
177+
178+
// Start up ssh agent listening on unix socket.
179+
tmpdir := t.TempDir()
180+
agentSock := filepath.Join(tmpdir, "agent.sock")
181+
l, err := net.Listen("unix", agentSock)
182+
require.NoError(t, err)
183+
defer l.Close()
184+
_ = tGo(t, func() {
185+
for {
186+
fd, err := l.Accept()
187+
if err != nil {
188+
if !errors.Is(err, net.ErrClosed) {
189+
t.Logf("accept error: %v", err)
190+
}
191+
return
192+
}
193+
194+
err = gosshagent.ServeAgent(kr, fd)
195+
if !errors.Is(err, io.EOF) {
196+
assert.NoError(t, err)
197+
}
198+
}
199+
})
200+
201+
t.Setenv("SSH_AUTH_SOCK", agentSock)
202+
cmd, root := clitest.New(t,
203+
"ssh",
204+
workspace.Name,
205+
"--forward-agent",
206+
)
207+
clitest.SetupConfig(t, client, root)
208+
pty := ptytest.New(t)
209+
cmd.SetIn(pty.Input())
210+
cmd.SetOut(pty.Output())
211+
cmd.SetErr(io.Discard)
212+
cmdDone := tGo(t, func() {
213+
err := cmd.Execute()
214+
assert.NoError(t, err)
215+
})
216+
217+
// Ensure that SSH_AUTH_SOCK is set.
218+
// Linux: /tmp/auth-agent3167016167/listener.sock
219+
// macOS: /var/folders/ng/m1q0wft14hj0t3rtjxrdnzsr0000gn/T/auth-agent3245553419/listener.sock
220+
pty.WriteLine("env")
221+
pty.ExpectMatch("SSH_AUTH_SOCK=")
222+
// Ensure that ssh-add lists our key.
223+
pty.WriteLine("ssh-add -L")
224+
keys, err := kr.List()
225+
require.NoError(t, err)
226+
pty.ExpectMatch(keys[0].String())
227+
228+
// And we're done.
229+
pty.WriteLine("exit")
230+
<-cmdDone
231+
})
232+
}
233+
234+
// tGoContext runs fn in a goroutine passing a context that will be
235+
// canceled on test completion and wait until fn has finished executing.
236+
// Done and cancel are returned for optionally waiting until completion
237+
// or early cancellation.
238+
//
239+
// NOTE(mafredri): This could be moved to a helper library.
240+
func tGoContext(t *testing.T, fn func(context.Context)) (done <-chan struct{}, cancel context.CancelFunc) {
241+
t.Helper()
242+
243+
ctx, cancel := context.WithCancel(context.Background())
244+
doneC := make(chan struct{})
245+
t.Cleanup(func() {
246+
cancel()
247+
<-done
248+
})
249+
go func() {
250+
fn(ctx)
251+
close(doneC)
252+
}()
253+
254+
return doneC, cancel
255+
}
256+
257+
// tGo runs fn in a goroutine and waits until fn has completed before
258+
// test completion. Done is returned for optionally waiting for fn to
259+
// exit.
260+
//
261+
// NOTE(mafredri): This could be moved to a helper library.
262+
func tGo(t *testing.T, fn func()) (done <-chan struct{}) {
263+
t.Helper()
264+
265+
doneC := make(chan struct{})
266+
t.Cleanup(func() {
267+
<-doneC
161268
})
269+
go func() {
270+
fn()
271+
close(doneC)
272+
}()
273+
274+
return doneC
162275
}
163276

164277
type stdioConn struct {

0 commit comments

Comments
 (0)