Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 9689bca

Browse files
authored
feat(cli): implement ssh remote forward (#8515)
1 parent c68e809 commit 9689bca

13 files changed

+255
-98
lines changed

cli/portforward.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ func (r *RootCmd) portForward() *clibase.Cmd {
3232
client := new(codersdk.Client)
3333
cmd := &clibase.Cmd{
3434
Use: "port-forward <workspace>",
35-
Short: "Forward ports from machine to a workspace",
35+
Short: `Forward ports from a workspace to the local machine. Forward ports from a workspace to the local machine. For reverse port forwarding, use "coder ssh -R".`,
3636
Aliases: []string{"tunnel"},
3737
Long: formatExamples(
3838
example{

cli/remoteforward.go

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
package cli
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"io"
7+
"net"
8+
"regexp"
9+
"strconv"
10+
11+
gossh "golang.org/x/crypto/ssh"
12+
"golang.org/x/xerrors"
13+
14+
"github.com/coder/coder/agent/agentssh"
15+
)
16+
17+
// cookieAddr is a special net.Addr accepted by sshRemoteForward() which includes a
18+
// cookie which is written to the connection before forwarding.
19+
type cookieAddr struct {
20+
net.Addr
21+
cookie []byte
22+
}
23+
24+
// Format:
25+
// remote_port:local_address:local_port
26+
var remoteForwardRegex = regexp.MustCompile(`^(\d+):(.+):(\d+)$`)
27+
28+
func validateRemoteForward(flag string) bool {
29+
return remoteForwardRegex.MatchString(flag)
30+
}
31+
32+
func parseRemoteForward(flag string) (net.Addr, net.Addr, error) {
33+
matches := remoteForwardRegex.FindStringSubmatch(flag)
34+
35+
remotePort, err := strconv.Atoi(matches[1])
36+
if err != nil {
37+
return nil, nil, xerrors.Errorf("remote port is invalid: %w", err)
38+
}
39+
localAddress, err := net.ResolveIPAddr("ip", matches[2])
40+
if err != nil {
41+
return nil, nil, xerrors.Errorf("local address is invalid: %w", err)
42+
}
43+
localPort, err := strconv.Atoi(matches[3])
44+
if err != nil {
45+
return nil, nil, xerrors.Errorf("local port is invalid: %w", err)
46+
}
47+
48+
localAddr := &net.TCPAddr{
49+
IP: localAddress.IP,
50+
Port: localPort,
51+
}
52+
53+
remoteAddr := &net.TCPAddr{
54+
IP: net.ParseIP("127.0.0.1"),
55+
Port: remotePort,
56+
}
57+
return localAddr, remoteAddr, nil
58+
}
59+
60+
// sshRemoteForward starts forwarding connections from a remote listener to a
61+
// local address via SSH in a goroutine.
62+
//
63+
// Accepts a `cookieAddr` as the local address.
64+
func sshRemoteForward(ctx context.Context, stderr io.Writer, sshClient *gossh.Client, localAddr, remoteAddr net.Addr) (io.Closer, error) {
65+
listener, err := sshClient.Listen(remoteAddr.Network(), remoteAddr.String())
66+
if err != nil {
67+
return nil, xerrors.Errorf("listen on remote SSH address %s: %w", remoteAddr.String(), err)
68+
}
69+
70+
go func() {
71+
for {
72+
remoteConn, err := listener.Accept()
73+
if err != nil {
74+
if ctx.Err() == nil {
75+
_, _ = fmt.Fprintf(stderr, "Accept SSH listener connection: %+v\n", err)
76+
}
77+
return
78+
}
79+
80+
go func() {
81+
defer remoteConn.Close()
82+
83+
localConn, err := net.Dial(localAddr.Network(), localAddr.String())
84+
if err != nil {
85+
_, _ = fmt.Fprintf(stderr, "Dial local address %s: %+v\n", localAddr.String(), err)
86+
return
87+
}
88+
defer localConn.Close()
89+
90+
if c, ok := localAddr.(cookieAddr); ok {
91+
_, err = localConn.Write(c.cookie)
92+
if err != nil {
93+
_, _ = fmt.Fprintf(stderr, "Write cookie to local connection: %+v\n", err)
94+
return
95+
}
96+
}
97+
98+
agentssh.Bicopy(ctx, localConn, remoteConn)
99+
}()
100+
}
101+
}()
102+
103+
return listener, nil
104+
}

cli/ssh.go

Lines changed: 41 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"errors"
77
"fmt"
88
"io"
9-
"net"
109
"net/url"
1110
"os"
1211
"os/exec"
@@ -27,7 +26,6 @@ import (
2726
"cdr.dev/slog"
2827
"cdr.dev/slog/sloggers/sloghuman"
2928

30-
"github.com/coder/coder/agent/agentssh"
3129
"github.com/coder/coder/cli/clibase"
3230
"github.com/coder/coder/cli/cliui"
3331
"github.com/coder/coder/coderd/autobuild/notify"
@@ -53,6 +51,7 @@ func (r *RootCmd) ssh() *clibase.Cmd {
5351
waitEnum string
5452
noWait bool
5553
logDirPath string
54+
remoteForward string
5655
)
5756
client := new(codersdk.Client)
5857
cmd := &clibase.Cmd{
@@ -122,6 +121,16 @@ func (r *RootCmd) ssh() *clibase.Cmd {
122121
client.SetLogger(logger)
123122
}
124123

124+
if remoteForward != "" {
125+
isValid := validateRemoteForward(remoteForward)
126+
if !isValid {
127+
return xerrors.Errorf(`invalid format of remote-forward, expected: remote_port:local_address:local_port`)
128+
}
129+
if isValid && stdio {
130+
return xerrors.Errorf(`remote-forward can't be enabled in the stdio mode`)
131+
}
132+
}
133+
125134
workspace, workspaceAgent, err := getWorkspaceAndAgent(ctx, inv, client, codersdk.Me, inv.Args[0])
126135
if err != nil {
127136
return err
@@ -198,6 +207,7 @@ func (r *RootCmd) ssh() *clibase.Cmd {
198207
}
199208
defer conn.Close()
200209
conn.AwaitReachable(ctx)
210+
201211
stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace)
202212
defer stopPolling()
203213

@@ -300,6 +310,19 @@ func (r *RootCmd) ssh() *clibase.Cmd {
300310
defer closer.Close()
301311
}
302312

313+
if remoteForward != "" {
314+
localAddr, remoteAddr, err := parseRemoteForward(remoteForward)
315+
if err != nil {
316+
return err
317+
}
318+
319+
closer, err := sshRemoteForward(ctx, inv.Stderr, sshClient, localAddr, remoteAddr)
320+
if err != nil {
321+
return xerrors.Errorf("ssh remote forward: %w", err)
322+
}
323+
defer closer.Close()
324+
}
325+
303326
stdoutFile, validOut := inv.Stdout.(*os.File)
304327
stdinFile, validIn := inv.Stdin.(*os.File)
305328
if validOut && validIn && isatty.IsTerminal(stdoutFile.Fd()) {
@@ -424,6 +447,13 @@ func (r *RootCmd) ssh() *clibase.Cmd {
424447
FlagShorthand: "l",
425448
Value: clibase.StringOf(&logDirPath),
426449
},
450+
{
451+
Flag: "remote-forward",
452+
Description: "Enable remote port forwarding (remote_port:local_address:local_port).",
453+
Env: "CODER_SSH_REMOTE_FORWARD",
454+
FlagShorthand: "R",
455+
Value: clibase.StringOf(&remoteForward),
456+
},
427457
}
428458
return cmd
429459
}
@@ -568,8 +598,15 @@ func getWorkspaceAndAgent(ctx context.Context, inv *clibase.Invocation, client *
568598
// of the CLI running simultaneously.
569599
func tryPollWorkspaceAutostop(ctx context.Context, client *codersdk.Client, workspace codersdk.Workspace) (stop func()) {
570600
lock := flock.New(filepath.Join(os.TempDir(), "coder-autostop-notify-"+workspace.ID.String()))
571-
condition := notifyCondition(ctx, client, workspace.ID, lock)
572-
return notify.Notify(condition, workspacePollInterval, autostopNotifyCountdown...)
601+
conditionCtx, cancelCondition := context.WithCancel(ctx)
602+
condition := notifyCondition(conditionCtx, client, workspace.ID, lock)
603+
stopFunc := notify.Notify(condition, workspacePollInterval, autostopNotifyCountdown...)
604+
return func() {
605+
// With many "ssh" processes running, `lock.TryLockContext` can be hanging until the context canceled.
606+
// Without this cancellation, a CLI process with failed remote-forward could be hanging indefinitely.
607+
cancelCondition()
608+
stopFunc()
609+
}
573610
}
574611

575612
// Notify the user if the workspace is due to shutdown.
@@ -752,56 +789,3 @@ func remoteGPGAgentSocket(sshClient *gossh.Client) (string, error) {
752789

753790
return string(bytes.TrimSpace(remoteSocket)), nil
754791
}
755-
756-
// cookieAddr is a special net.Addr accepted by sshForward() which includes a
757-
// cookie which is written to the connection before forwarding.
758-
type cookieAddr struct {
759-
net.Addr
760-
cookie []byte
761-
}
762-
763-
// sshForwardRemote starts forwarding connections from a remote listener to a
764-
// local address via SSH in a goroutine.
765-
//
766-
// Accepts a `cookieAddr` as the local address.
767-
func sshForwardRemote(ctx context.Context, stderr io.Writer, sshClient *gossh.Client, localAddr, remoteAddr net.Addr) (io.Closer, error) {
768-
listener, err := sshClient.Listen(remoteAddr.Network(), remoteAddr.String())
769-
if err != nil {
770-
return nil, xerrors.Errorf("listen on remote SSH address %s: %w", remoteAddr.String(), err)
771-
}
772-
773-
go func() {
774-
for {
775-
remoteConn, err := listener.Accept()
776-
if err != nil {
777-
if ctx.Err() == nil {
778-
_, _ = fmt.Fprintf(stderr, "Accept SSH listener connection: %+v\n", err)
779-
}
780-
return
781-
}
782-
783-
go func() {
784-
defer remoteConn.Close()
785-
786-
localConn, err := net.Dial(localAddr.Network(), localAddr.String())
787-
if err != nil {
788-
_, _ = fmt.Fprintf(stderr, "Dial local address %s: %+v\n", localAddr.String(), err)
789-
return
790-
}
791-
defer localConn.Close()
792-
793-
if c, ok := localAddr.(cookieAddr); ok {
794-
_, err = localConn.Write(c.cookie)
795-
if err != nil {
796-
_, _ = fmt.Fprintf(stderr, "Write cookie to local connection: %+v\n", err)
797-
return
798-
}
799-
}
800-
801-
agentssh.Bicopy(ctx, localConn, remoteConn)
802-
}()
803-
}
804-
}()
805-
806-
return listener, nil
807-
}

cli/ssh_other.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,5 +44,5 @@ func forwardGPGAgent(ctx context.Context, stderr io.Writer, sshClient *gossh.Cli
4444
Net: "unix",
4545
}
4646

47-
return sshForwardRemote(ctx, stderr, sshClient, localAddr, remoteAddr)
47+
return sshRemoteForward(ctx, stderr, sshClient, localAddr, remoteAddr)
4848
}

cli/ssh_test.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import (
1010
"fmt"
1111
"io"
1212
"net"
13+
"net/http"
14+
"net/http/httptest"
1315
"os"
1416
"os/exec"
1517
"path/filepath"
@@ -408,6 +410,58 @@ func TestSSH(t *testing.T) {
408410
<-cmdDone
409411
})
410412

413+
t.Run("RemoteForward", func(t *testing.T) {
414+
if runtime.GOOS == "windows" {
415+
t.Skip("Test not supported on windows")
416+
}
417+
418+
t.Parallel()
419+
420+
httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
421+
w.Write([]byte("hello world"))
422+
}))
423+
defer httpServer.Close()
424+
425+
client, workspace, agentToken := setupWorkspaceForAgent(t, nil)
426+
427+
agentClient := agentsdk.New(client.URL)
428+
agentClient.SetSessionToken(agentToken)
429+
agentCloser := agent.New(agent.Options{
430+
Client: agentClient,
431+
Logger: slogtest.Make(t, nil).Named("agent"),
432+
})
433+
defer agentCloser.Close()
434+
435+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
436+
defer cancel()
437+
438+
inv, root := clitest.New(t,
439+
"ssh",
440+
workspace.Name,
441+
"--remote-forward",
442+
"8222:"+httpServer.Listener.Addr().String(),
443+
)
444+
clitest.SetupConfig(t, client, root)
445+
pty := ptytest.New(t).Attach(inv)
446+
inv.Stderr = pty.Output()
447+
cmdDone := tGo(t, func() {
448+
err := inv.WithContext(ctx).Run()
449+
assert.NoError(t, err, "ssh command failed")
450+
})
451+
452+
// Wait for the prompt or any output really to indicate the command has
453+
// started and accepting input on stdin.
454+
_ = pty.Peek(ctx, 1)
455+
456+
// Download the test page
457+
pty.WriteLine("curl localhost:8222")
458+
pty.ExpectMatch("hello world")
459+
460+
// And we're done.
461+
pty.WriteLine("exit")
462+
<-cmdDone
463+
})
464+
411465
t.Run("FileLogging", func(t *testing.T) {
412466
t.Parallel()
413467

cli/ssh_windows.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,5 +101,5 @@ func forwardGPGAgent(ctx context.Context, stderr io.Writer, sshClient *gossh.Cli
101101
Net: "unix",
102102
}
103103

104-
return sshForwardRemote(ctx, stderr, sshClient, localAddr, remoteAddr)
104+
return sshRemoteForward(ctx, stderr, sshClient, localAddr, remoteAddr)
105105
}

cli/testdata/coder_--help.golden

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ Coder v0.0.0-devel — A tool for provisioning self-hosted development environme
2121
logout Unauthenticate your local session
2222
netcheck Print network debug information for DERP and STUN
2323
ping Ping a workspace
24-
port-forward Forward ports from machine to a workspace
24+
port-forward Forward ports from a workspace to the local machine.
25+
Forward ports from a workspace to the local machine. For
26+
reverse port forwarding, use "coder ssh -R".
2527
publickey Output your Coder public key used for Git operations
2628
rename Rename a workspace
2729
reset-password Directly connect to the database to reset a user's

cli/testdata/coder_port-forward_--help.golden

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
Usage: coder port-forward [flags] <workspace>
22

3-
Forward ports from machine to a workspace
3+
Forward ports from a workspace to the local machine. Forward ports from a
4+
workspace to the local machine. For reverse port forwarding, use "coder ssh -R".
45

56
Aliases: tunnel
67

cli/testdata/coder_ssh_--help.golden

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ Start a shell into a workspace
2727
behavior as non-blocking.
2828
DEPRECATED: Use --wait instead.
2929

30+
-R, --remote-forward string, $CODER_SSH_REMOTE_FORWARD
31+
Enable remote port forwarding (remote_port:local_address:local_port).
32+
3033
--stdio bool, $CODER_SSH_STDIO
3134
Specifies whether to emit SSH output over stdin/stdout.
3235

0 commit comments

Comments
 (0)