Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit d0e2060

Browse files
authored
feat(agent): add second SSH listener on port 22 (#16627)
Fixes: coder/internal#377 Added an additional SSH listener on port 22, so the agent now listens on both, port one and port 22. --- Change-Id: Ifd986b260f8ac317e37d65111cd4e0bd1dc38af8 Signed-off-by: Thomas Kosiewski <[email protected]>
1 parent ca23abe commit d0e2060

File tree

6 files changed

+153
-95
lines changed

6 files changed

+153
-95
lines changed

agent/agent.go

+14-11
Original file line numberDiff line numberDiff line change
@@ -1362,19 +1362,22 @@ func (a *agent) createTailnet(
13621362
return nil, xerrors.Errorf("update host signer: %w", err)
13631363
}
13641364

1365-
sshListener, err := network.Listen("tcp", ":"+strconv.Itoa(workspacesdk.AgentSSHPort))
1366-
if err != nil {
1367-
return nil, xerrors.Errorf("listen on the ssh port: %w", err)
1368-
}
1369-
defer func() {
1365+
for _, port := range []int{workspacesdk.AgentSSHPort, workspacesdk.AgentStandardSSHPort} {
1366+
sshListener, err := network.Listen("tcp", ":"+strconv.Itoa(port))
13701367
if err != nil {
1371-
_ = sshListener.Close()
1368+
return nil, xerrors.Errorf("listen on the ssh port (%v): %w", port, err)
1369+
}
1370+
// nolint:revive // We do want to run the deferred functions when createTailnet returns.
1371+
defer func() {
1372+
if err != nil {
1373+
_ = sshListener.Close()
1374+
}
1375+
}()
1376+
if err = a.trackGoroutine(func() {
1377+
_ = a.sshServer.Serve(sshListener)
1378+
}); err != nil {
1379+
return nil, err
13721380
}
1373-
}()
1374-
if err = a.trackGoroutine(func() {
1375-
_ = a.sshServer.Serve(sshListener)
1376-
}); err != nil {
1377-
return nil, err
13781381
}
13791382

13801383
reconnectingPTYListener, err := network.Listen("tcp", ":"+strconv.Itoa(workspacesdk.AgentReconnectingPTYPort))

agent/agent_test.go

+120-79
Original file line numberDiff line numberDiff line change
@@ -65,38 +65,48 @@ func TestMain(m *testing.M) {
6565
goleak.VerifyTestMain(m, testutil.GoleakOptions...)
6666
}
6767

68+
var sshPorts = []uint16{workspacesdk.AgentSSHPort, workspacesdk.AgentStandardSSHPort}
69+
6870
// NOTE: These tests only work when your default shell is bash for some reason.
6971

7072
func TestAgent_Stats_SSH(t *testing.T) {
7173
t.Parallel()
72-
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
73-
defer cancel()
7474

75-
//nolint:dogsled
76-
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
75+
for _, port := range sshPorts {
76+
port := port
77+
t.Run(fmt.Sprintf("(:%d)", port), func(t *testing.T) {
78+
t.Parallel()
7779

78-
sshClient, err := conn.SSHClient(ctx)
79-
require.NoError(t, err)
80-
defer sshClient.Close()
81-
session, err := sshClient.NewSession()
82-
require.NoError(t, err)
83-
defer session.Close()
84-
stdin, err := session.StdinPipe()
85-
require.NoError(t, err)
86-
err = session.Shell()
87-
require.NoError(t, err)
80+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
81+
defer cancel()
8882

89-
var s *proto.Stats
90-
require.Eventuallyf(t, func() bool {
91-
var ok bool
92-
s, ok = <-stats
93-
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 && s.SessionCountSsh == 1
94-
}, testutil.WaitLong, testutil.IntervalFast,
95-
"never saw stats: %+v", s,
96-
)
97-
_ = stdin.Close()
98-
err = session.Wait()
99-
require.NoError(t, err)
83+
//nolint:dogsled
84+
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
85+
86+
sshClient, err := conn.SSHClientOnPort(ctx, port)
87+
require.NoError(t, err)
88+
defer sshClient.Close()
89+
session, err := sshClient.NewSession()
90+
require.NoError(t, err)
91+
defer session.Close()
92+
stdin, err := session.StdinPipe()
93+
require.NoError(t, err)
94+
err = session.Shell()
95+
require.NoError(t, err)
96+
97+
var s *proto.Stats
98+
require.Eventuallyf(t, func() bool {
99+
var ok bool
100+
s, ok = <-stats
101+
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 && s.SessionCountSsh == 1
102+
}, testutil.WaitLong, testutil.IntervalFast,
103+
"never saw stats: %+v", s,
104+
)
105+
_ = stdin.Close()
106+
err = session.Wait()
107+
require.NoError(t, err)
108+
})
109+
}
100110
}
101111

102112
func TestAgent_Stats_ReconnectingPTY(t *testing.T) {
@@ -278,15 +288,23 @@ func TestAgent_Stats_Magic(t *testing.T) {
278288

279289
func TestAgent_SessionExec(t *testing.T) {
280290
t.Parallel()
281-
session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil)
282291

283-
command := "echo test"
284-
if runtime.GOOS == "windows" {
285-
command = "cmd.exe /c echo test"
292+
for _, port := range sshPorts {
293+
port := port
294+
t.Run(fmt.Sprintf("(:%d)", port), func(t *testing.T) {
295+
t.Parallel()
296+
297+
session := setupSSHSessionOnPort(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil, port)
298+
299+
command := "echo test"
300+
if runtime.GOOS == "windows" {
301+
command = "cmd.exe /c echo test"
302+
}
303+
output, err := session.Output(command)
304+
require.NoError(t, err)
305+
require.Equal(t, "test", strings.TrimSpace(string(output)))
306+
})
286307
}
287-
output, err := session.Output(command)
288-
require.NoError(t, err)
289-
require.Equal(t, "test", strings.TrimSpace(string(output)))
290308
}
291309

292310
//nolint:tparallel // Sub tests need to run sequentially.
@@ -396,25 +414,33 @@ func TestAgent_SessionTTYShell(t *testing.T) {
396414
// it seems like it could be either.
397415
t.Skip("ConPTY appears to be inconsistent on Windows.")
398416
}
399-
session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil)
400-
command := "sh"
401-
if runtime.GOOS == "windows" {
402-
command = "cmd.exe"
417+
418+
for _, port := range sshPorts {
419+
port := port
420+
t.Run(fmt.Sprintf("(%d)", port), func(t *testing.T) {
421+
t.Parallel()
422+
423+
session := setupSSHSessionOnPort(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil, port)
424+
command := "sh"
425+
if runtime.GOOS == "windows" {
426+
command = "cmd.exe"
427+
}
428+
err := session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
429+
require.NoError(t, err)
430+
ptty := ptytest.New(t)
431+
session.Stdout = ptty.Output()
432+
session.Stderr = ptty.Output()
433+
session.Stdin = ptty.Input()
434+
err = session.Start(command)
435+
require.NoError(t, err)
436+
_ = ptty.Peek(ctx, 1) // wait for the prompt
437+
ptty.WriteLine("echo test")
438+
ptty.ExpectMatch("test")
439+
ptty.WriteLine("exit")
440+
err = session.Wait()
441+
require.NoError(t, err)
442+
})
403443
}
404-
err := session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
405-
require.NoError(t, err)
406-
ptty := ptytest.New(t)
407-
session.Stdout = ptty.Output()
408-
session.Stderr = ptty.Output()
409-
session.Stdin = ptty.Input()
410-
err = session.Start(command)
411-
require.NoError(t, err)
412-
_ = ptty.Peek(ctx, 1) // wait for the prompt
413-
ptty.WriteLine("echo test")
414-
ptty.ExpectMatch("test")
415-
ptty.WriteLine("exit")
416-
err = session.Wait()
417-
require.NoError(t, err)
418444
}
419445

420446
func TestAgent_SessionTTYExitCode(t *testing.T) {
@@ -608,37 +634,41 @@ func TestAgent_Session_TTY_MOTD_Update(t *testing.T) {
608634
//nolint:dogsled // Allow the blank identifiers.
609635
conn, client, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, setSBInterval)
610636

611-
sshClient, err := conn.SSHClient(ctx)
612-
require.NoError(t, err)
613-
t.Cleanup(func() {
614-
_ = sshClient.Close()
615-
})
616-
617637
//nolint:paralleltest // These tests need to swap the banner func.
618-
for i, test := range tests {
619-
test := test
620-
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
621-
// Set new banner func and wait for the agent to call it to update the
622-
// banner.
623-
ready := make(chan struct{}, 2)
624-
client.SetAnnouncementBannersFunc(func() ([]codersdk.BannerConfig, error) {
625-
select {
626-
case ready <- struct{}{}:
627-
default:
628-
}
629-
return []codersdk.BannerConfig{test.banner}, nil
630-
})
631-
<-ready
632-
<-ready // Wait for two updates to ensure the value has propagated.
633-
634-
session, err := sshClient.NewSession()
635-
require.NoError(t, err)
636-
t.Cleanup(func() {
637-
_ = session.Close()
638-
})
638+
for _, port := range sshPorts {
639+
port := port
639640

640-
testSessionOutput(t, session, test.expected, test.unexpected, nil)
641+
sshClient, err := conn.SSHClientOnPort(ctx, port)
642+
require.NoError(t, err)
643+
t.Cleanup(func() {
644+
_ = sshClient.Close()
641645
})
646+
647+
for i, test := range tests {
648+
test := test
649+
t.Run(fmt.Sprintf("(:%d)/%d", port, i), func(t *testing.T) {
650+
// Set new banner func and wait for the agent to call it to update the
651+
// banner.
652+
ready := make(chan struct{}, 2)
653+
client.SetAnnouncementBannersFunc(func() ([]codersdk.BannerConfig, error) {
654+
select {
655+
case ready <- struct{}{}:
656+
default:
657+
}
658+
return []codersdk.BannerConfig{test.banner}, nil
659+
})
660+
<-ready
661+
<-ready // Wait for two updates to ensure the value has propagated.
662+
663+
session, err := sshClient.NewSession()
664+
require.NoError(t, err)
665+
t.Cleanup(func() {
666+
_ = session.Close()
667+
})
668+
669+
testSessionOutput(t, session, test.expected, test.unexpected, nil)
670+
})
671+
}
642672
}
643673
}
644674

@@ -2424,6 +2454,17 @@ func setupSSHSession(
24242454
banner codersdk.BannerConfig,
24252455
prepareFS func(fs afero.Fs),
24262456
opts ...func(*agenttest.Client, *agent.Options),
2457+
) *ssh.Session {
2458+
return setupSSHSessionOnPort(t, manifest, banner, prepareFS, workspacesdk.AgentSSHPort, opts...)
2459+
}
2460+
2461+
func setupSSHSessionOnPort(
2462+
t *testing.T,
2463+
manifest agentsdk.Manifest,
2464+
banner codersdk.BannerConfig,
2465+
prepareFS func(fs afero.Fs),
2466+
port uint16,
2467+
opts ...func(*agenttest.Client, *agent.Options),
24272468
) *ssh.Session {
24282469
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
24292470
defer cancel()
@@ -2437,7 +2478,7 @@ func setupSSHSession(
24372478
if prepareFS != nil {
24382479
prepareFS(fs)
24392480
}
2440-
sshClient, err := conn.SSHClient(ctx)
2481+
sshClient, err := conn.SSHClientOnPort(ctx, port)
24412482
require.NoError(t, err)
24422483
t.Cleanup(func() {
24432484
_ = sshClient.Close()

agent/usershell/usershell_darwin.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ func Get(username string) (string, error) {
1818
return "", xerrors.Errorf("username is nonlocal path: %s", username)
1919
}
2020
//nolint: gosec // input checked above
21-
out, _ := exec.Command("dscl", ".", "-read", filepath.Join("/Users", username), "UserShell").Output()
21+
out, _ := exec.Command("dscl", ".", "-read", filepath.Join("/Users", username), "UserShell").Output() //nolint:gocritic
2222
s, ok := strings.CutPrefix(string(out), "UserShell: ")
2323
if ok {
2424
return strings.TrimSpace(s), nil

codersdk/workspacesdk/agentconn.go

+15-3
Original file line numberDiff line numberDiff line change
@@ -165,24 +165,36 @@ func (c *AgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, w
165165
// SSH pipes the SSH protocol over the returned net.Conn.
166166
// This connects to the built-in SSH server in the workspace agent.
167167
func (c *AgentConn) SSH(ctx context.Context) (*gonet.TCPConn, error) {
168+
return c.SSHOnPort(ctx, AgentSSHPort)
169+
}
170+
171+
// SSHOnPort pipes the SSH protocol over the returned net.Conn.
172+
// This connects to the built-in SSH server in the workspace agent on the specified port.
173+
func (c *AgentConn) SSHOnPort(ctx context.Context, port uint16) (*gonet.TCPConn, error) {
168174
ctx, span := tracing.StartSpan(ctx)
169175
defer span.End()
170176

171177
if !c.AwaitReachable(ctx) {
172178
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
173179
}
174180

175-
c.Conn.SendConnectedTelemetry(c.agentAddress(), tailnet.TelemetryApplicationSSH)
176-
return c.Conn.DialContextTCP(ctx, netip.AddrPortFrom(c.agentAddress(), AgentSSHPort))
181+
c.SendConnectedTelemetry(c.agentAddress(), tailnet.TelemetryApplicationSSH)
182+
return c.DialContextTCP(ctx, netip.AddrPortFrom(c.agentAddress(), port))
177183
}
178184

179185
// SSHClient calls SSH to create a client that uses a weak cipher
180186
// to improve throughput.
181187
func (c *AgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) {
188+
return c.SSHClientOnPort(ctx, AgentSSHPort)
189+
}
190+
191+
// SSHClientOnPort calls SSH to create a client on a specific port
192+
// that uses a weak cipher to improve throughput.
193+
func (c *AgentConn) SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Client, error) {
182194
ctx, span := tracing.StartSpan(ctx)
183195
defer span.End()
184196

185-
netConn, err := c.SSH(ctx)
197+
netConn, err := c.SSHOnPort(ctx, port)
186198
if err != nil {
187199
return nil, xerrors.Errorf("ssh: %w", err)
188200
}

codersdk/workspacesdk/workspacesdk.go

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ var ErrSkipClose = xerrors.New("skip tailnet close")
3131

3232
const (
3333
AgentSSHPort = tailnet.WorkspaceAgentSSHPort
34+
AgentStandardSSHPort = tailnet.WorkspaceAgentStandardSSHPort
3435
AgentReconnectingPTYPort = tailnet.WorkspaceAgentReconnectingPTYPort
3536
AgentSpeedtestPort = tailnet.WorkspaceAgentSpeedtestPort
3637
// AgentHTTPAPIServerPort serves a HTTP server with endpoints for e.g.

tailnet/conn.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ const (
5252
WorkspaceAgentSSHPort = 1
5353
WorkspaceAgentReconnectingPTYPort = 2
5454
WorkspaceAgentSpeedtestPort = 3
55+
WorkspaceAgentStandardSSHPort = 22
5556
)
5657

5758
// EnvMagicsockDebugLogging enables super-verbose logging for the magicsock
@@ -745,7 +746,7 @@ func (c *Conn) forwardTCP(src, dst netip.AddrPort) (handler func(net.Conn), opts
745746
return nil, nil, false
746747
}
747748
// See: https://github.com/tailscale/tailscale/blob/c7cea825aea39a00aca71ea02bab7266afc03e7c/wgengine/netstack/netstack.go#L888
748-
if dst.Port() == WorkspaceAgentSSHPort || dst.Port() == 22 {
749+
if dst.Port() == WorkspaceAgentSSHPort || dst.Port() == WorkspaceAgentStandardSSHPort {
749750
opt := tcpip.KeepaliveIdleOption(72 * time.Hour)
750751
opts = append(opts, &opt)
751752
}

0 commit comments

Comments
 (0)