Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 9f5e500

Browse files
committed
Remote forward
1 parent ece557e commit 9f5e500

File tree

2 files changed

+107
-84
lines changed

2 files changed

+107
-84
lines changed

cli/remoteforward.go

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
package cli
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"io"
7+
"net"
8+
"regexp"
9+
"strconv"
10+
11+
gossh "golang.org/x/crypto/ssh"
12+
"golang.org/x/xerrors"
13+
14+
"github.com/coder/coder/agent/agentssh"
15+
)
16+
17+
// cookieAddr is a special net.Addr accepted by sshRemoteForward() which includes a
18+
// cookie which is written to the connection before forwarding.
19+
type cookieAddr struct {
20+
net.Addr
21+
cookie []byte
22+
}
23+
24+
var remoteForwardRegex = regexp.MustCompile(`^(\d+):(.+):(\d+)$`)
25+
26+
func validateRemoteForward(flag string) bool {
27+
return remoteForwardRegex.MatchString(flag)
28+
}
29+
30+
func parseRemoteForward(flag string) (net.Addr, net.Addr, error) {
31+
matches := remoteForwardRegex.FindStringSubmatch(flag)
32+
33+
// Format:
34+
// remote_port:local_address:local_port
35+
remotePort, err := strconv.Atoi(matches[1])
36+
if err != nil {
37+
return nil, nil, xerrors.Errorf("remote port is invalid: %w", err)
38+
}
39+
localAddress, err := net.ResolveIPAddr("ip", matches[2])
40+
if err != nil {
41+
return nil, nil, xerrors.Errorf("local address is invalid: %w", err)
42+
}
43+
localPort, err := strconv.Atoi(matches[3])
44+
if err != nil {
45+
return nil, nil, xerrors.Errorf("local port is invalid: %w", err)
46+
}
47+
48+
localAddr := &net.TCPAddr{
49+
IP: localAddress.IP,
50+
Port: localPort,
51+
}
52+
53+
remoteAddr := &net.TCPAddr{
54+
IP: net.ParseIP("127.0.0.1"),
55+
Port: remotePort,
56+
}
57+
return localAddr, remoteAddr, nil
58+
}
59+
60+
// sshRemoteForward starts forwarding connections from a remote listener to a
61+
// local address via SSH in a goroutine.
62+
//
63+
// Accepts a `cookieAddr` as the local address.
64+
func sshRemoteForward(ctx context.Context, stderr io.Writer, sshClient *gossh.Client, localAddr, remoteAddr net.Addr) (io.Closer, error) {
65+
listener, err := sshClient.Listen(remoteAddr.Network(), remoteAddr.String())
66+
if err != nil {
67+
return nil, xerrors.Errorf("listen on remote SSH address %s: %w", remoteAddr.String(), err)
68+
}
69+
70+
go func() {
71+
for {
72+
remoteConn, err := listener.Accept()
73+
if err != nil {
74+
if ctx.Err() == nil {
75+
_, _ = fmt.Fprintf(stderr, "Accept SSH listener connection: %+v\n", err)
76+
}
77+
return
78+
}
79+
80+
go func() {
81+
defer remoteConn.Close()
82+
83+
localConn, err := net.Dial(localAddr.Network(), localAddr.String())
84+
if err != nil {
85+
_, _ = fmt.Fprintf(stderr, "Dial local address %s: %+v\n", localAddr.String(), err)
86+
return
87+
}
88+
defer localConn.Close()
89+
90+
if c, ok := localAddr.(cookieAddr); ok {
91+
_, err = localConn.Write(c.cookie)
92+
if err != nil {
93+
_, _ = fmt.Fprintf(stderr, "Write cookie to local connection: %+v\n", err)
94+
return
95+
}
96+
}
97+
98+
agentssh.Bicopy(ctx, localConn, remoteConn)
99+
}()
100+
}
101+
}()
102+
103+
return listener, nil
104+
}

cli/ssh.go

Lines changed: 3 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,10 @@ import (
66
"errors"
77
"fmt"
88
"io"
9-
"net"
109
"net/url"
1110
"os"
1211
"os/exec"
1312
"path/filepath"
14-
"regexp"
15-
"strconv"
1613
"strings"
1714
"sync"
1815
"time"
@@ -29,7 +26,6 @@ import (
2926
"cdr.dev/slog"
3027
"cdr.dev/slog/sloggers/sloghuman"
3128

32-
"github.com/coder/coder/agent/agentssh"
3329
"github.com/coder/coder/cli/clibase"
3430
"github.com/coder/coder/cli/cliui"
3531
"github.com/coder/coder/coderd/autobuild/notify"
@@ -42,8 +38,6 @@ import (
4238
var (
4339
workspacePollInterval = time.Minute
4440
autostopNotifyCountdown = []time.Duration{30 * time.Minute}
45-
46-
remoteForwardRegex = regexp.MustCompile(`^(\d+):(.+):(\d+)$`)
4741
)
4842

4943
//nolint:gocyclo
@@ -128,7 +122,7 @@ func (r *RootCmd) ssh() *clibase.Cmd {
128122
}
129123

130124
if remoteForward != "" {
131-
isValid := remoteForwardRegex.MatchString(remoteForward)
125+
isValid := validateRemoteForward(remoteForward)
132126
if !isValid {
133127
return xerrors.Errorf(`invalid format of remote-forward, expected: remote_port:local_address:local_port`)
134128
}
@@ -317,31 +311,9 @@ func (r *RootCmd) ssh() *clibase.Cmd {
317311
}
318312

319313
if remoteForward != "" {
320-
matches := remoteForwardRegex.FindStringSubmatch(remoteForward)
321-
322-
// Format:
323-
// remote_port:local_address:local_port
324-
remotePort, err := strconv.Atoi(matches[1])
325-
if err != nil {
326-
return xerrors.Errorf("remote port is invalid: %w", err)
327-
}
328-
localAddress, err := net.ResolveIPAddr("ip", matches[2])
329-
if err != nil {
330-
return xerrors.Errorf("local address is invalid: %w", err)
331-
}
332-
localPort, err := strconv.Atoi(matches[3])
314+
localAddr, remoteAddr, err := parseRemoteForward(remoteForward)
333315
if err != nil {
334-
return xerrors.Errorf("local port is invalid: %w", err)
335-
}
336-
337-
localAddr := &net.TCPAddr{
338-
IP: localAddress.IP,
339-
Port: localPort,
340-
}
341-
342-
remoteAddr := &net.TCPAddr{
343-
IP: net.ParseIP("127.0.0.1"),
344-
Port: remotePort,
316+
return err
345317
}
346318

347319
closer, err := sshRemoteForward(ctx, inv.Stderr, sshClient, localAddr, remoteAddr)
@@ -817,56 +789,3 @@ func remoteGPGAgentSocket(sshClient *gossh.Client) (string, error) {
817789

818790
return string(bytes.TrimSpace(remoteSocket)), nil
819791
}
820-
821-
// cookieAddr is a special net.Addr accepted by sshRemoteForward() which includes a
822-
// cookie which is written to the connection before forwarding.
823-
type cookieAddr struct {
824-
net.Addr
825-
cookie []byte
826-
}
827-
828-
// sshRemoteForward starts forwarding connections from a remote listener to a
829-
// local address via SSH in a goroutine.
830-
//
831-
// Accepts a `cookieAddr` as the local address.
832-
func sshRemoteForward(ctx context.Context, stderr io.Writer, sshClient *gossh.Client, localAddr, remoteAddr net.Addr) (io.Closer, error) {
833-
listener, err := sshClient.Listen(remoteAddr.Network(), remoteAddr.String())
834-
if err != nil {
835-
return nil, xerrors.Errorf("listen on remote SSH address %s: %w", remoteAddr.String(), err)
836-
}
837-
838-
go func() {
839-
for {
840-
remoteConn, err := listener.Accept()
841-
if err != nil {
842-
if ctx.Err() == nil {
843-
_, _ = fmt.Fprintf(stderr, "Accept SSH listener connection: %+v\n", err)
844-
}
845-
return
846-
}
847-
848-
go func() {
849-
defer remoteConn.Close()
850-
851-
localConn, err := net.Dial(localAddr.Network(), localAddr.String())
852-
if err != nil {
853-
_, _ = fmt.Fprintf(stderr, "Dial local address %s: %+v\n", localAddr.String(), err)
854-
return
855-
}
856-
defer localConn.Close()
857-
858-
if c, ok := localAddr.(cookieAddr); ok {
859-
_, err = localConn.Write(c.cookie)
860-
if err != nil {
861-
_, _ = fmt.Fprintf(stderr, "Write cookie to local connection: %+v\n", err)
862-
return
863-
}
864-
}
865-
866-
agentssh.Bicopy(ctx, localConn, remoteConn)
867-
}()
868-
}
869-
}()
870-
871-
return listener, nil
872-
}

0 commit comments

Comments
 (0)