Thanks to visit codestin.com
Credit goes to github.com

Skip to content

feat(cli): use coder connect in coder ssh --stdio, if available #17572

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 123 additions & 9 deletions cli/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"io"
"log"
"net"
"net/http"
"net/url"
"os"
Expand Down Expand Up @@ -66,6 +67,7 @@ func (r *RootCmd) ssh() *serpent.Command {
stdio bool
hostPrefix string
hostnameSuffix string
forceNewTunnel bool
forwardAgent bool
forwardGPG bool
identityAgent string
Expand All @@ -85,6 +87,7 @@ func (r *RootCmd) ssh() *serpent.Command {
containerUser string
)
client := new(codersdk.Client)
wsClient := workspacesdk.New(client)
cmd := &serpent.Command{
Annotations: workspaceCommand,
Use: "ssh <workspace>",
Expand Down Expand Up @@ -203,14 +206,14 @@ func (r *RootCmd) ssh() *serpent.Command {
parsedEnv = append(parsedEnv, [2]string{k, v})
}

deploymentSSHConfig := codersdk.SSHConfigResponse{
cliConfig := codersdk.SSHConfigResponse{
Copy link
Member Author

@ethanndickson ethanndickson Apr 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Despite the name, this was always populated from CLI arguments, which in half of the cases are not the deployment SSH config (i.e. for the VS Code extension it's something like vscode-coder)

HostnamePrefix: hostPrefix,
HostnameSuffix: hostnameSuffix,
}

workspace, workspaceAgent, err := findWorkspaceAndAgentByHostname(
ctx, inv, client,
inv.Args[0], deploymentSSHConfig, disableAutostart)
inv.Args[0], cliConfig, disableAutostart)
if err != nil {
return err
}
Expand Down Expand Up @@ -275,10 +278,44 @@ func (r *RootCmd) ssh() *serpent.Command {
return err
}

// If we're in stdio mode, check to see if we can use Coder Connect.
// We don't support Coder Connect over non-stdio coder ssh yet.
if stdio && !forceNewTunnel {
connInfo, err := wsClient.AgentConnectionInfoGeneric(ctx)
if err != nil {
return xerrors.Errorf("get agent connection info: %w", err)
}
coderConnectHost := fmt.Sprintf("%s.%s.%s.%s",
workspaceAgent.Name, workspace.Name, workspace.OwnerName, connInfo.HostnameSuffix)
exists, _ := workspacesdk.ExistsViaCoderConnect(ctx, coderConnectHost)
if exists {
defer cancel()

if networkInfoDir != "" {
if err := writeCoderConnectNetInfo(ctx, networkInfoDir); err != nil {
logger.Error(ctx, "failed to write coder connect net info file", slog.Error(err))
}
}

stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace)
defer stopPolling()

usageAppName := getUsageAppName(usageApp)
if usageAppName != "" {
closeUsage := client.UpdateWorkspaceUsageWithBodyContext(ctx, workspace.ID, codersdk.PostWorkspaceUsageRequest{
AgentID: workspaceAgent.ID,
AppName: usageAppName,
})
defer closeUsage()
}
return runCoderConnectStdio(ctx, fmt.Sprintf("%s:22", coderConnectHost), stdioReader, stdioWriter, stack)
}
}

if r.disableDirect {
_, _ = fmt.Fprintln(inv.Stderr, "Direct connections disabled.")
}
conn, err := workspacesdk.New(client).
conn, err := wsClient.
DialAgent(ctx, workspaceAgent.ID, &workspacesdk.DialAgentOptions{
Logger: logger,
BlockEndpoints: r.disableDirect,
Expand Down Expand Up @@ -662,6 +699,12 @@ func (r *RootCmd) ssh() *serpent.Command {
Value: serpent.StringOf(&containerUser),
Hidden: true, // Hidden until this features is at least in beta.
},
{
Flag: "force-new-tunnel",
Description: "Force the creation of a new tunnel to the workspace, even if the Coder Connect tunnel is available.",
Value: serpent.BoolOf(&forceNewTunnel),
Hidden: true,
},
sshDisableAutostartOption(serpent.BoolOf(&disableAutostart)),
}
return cmd
Expand Down Expand Up @@ -1374,12 +1417,13 @@ func setStatsCallback(
}

type sshNetworkStats struct {
P2P bool `json:"p2p"`
Latency float64 `json:"latency"`
PreferredDERP string `json:"preferred_derp"`
DERPLatency map[string]float64 `json:"derp_latency"`
UploadBytesSec int64 `json:"upload_bytes_sec"`
DownloadBytesSec int64 `json:"download_bytes_sec"`
P2P bool `json:"p2p"`
Latency float64 `json:"latency"`
PreferredDERP string `json:"preferred_derp"`
DERPLatency map[string]float64 `json:"derp_latency"`
UploadBytesSec int64 `json:"upload_bytes_sec"`
DownloadBytesSec int64 `json:"download_bytes_sec"`
UsingCoderConnect bool `json:"using_coder_connect"`
}

func collectNetworkStats(ctx context.Context, agentConn *workspacesdk.AgentConn, start, end time.Time, counts map[netlogtype.Connection]netlogtype.Counts) (*sshNetworkStats, error) {
Expand Down Expand Up @@ -1450,6 +1494,76 @@ func collectNetworkStats(ctx context.Context, agentConn *workspacesdk.AgentConn,
}, nil
}

type coderConnectDialerContextKey struct{}

type coderConnectDialer interface {
DialContext(ctx context.Context, network, addr string) (net.Conn, error)
}

func WithTestOnlyCoderConnectDialer(ctx context.Context, dialer coderConnectDialer) context.Context {
return context.WithValue(ctx, coderConnectDialerContextKey{}, dialer)
}

func testOrDefaultDialer(ctx context.Context) coderConnectDialer {
dialer, ok := ctx.Value(coderConnectDialerContextKey{}).(coderConnectDialer)
if !ok || dialer == nil {
return &net.Dialer{}
}
return dialer
}

func runCoderConnectStdio(ctx context.Context, addr string, stdin io.Reader, stdout io.Writer, stack *closerStack) error {
dialer := testOrDefaultDialer(ctx)
conn, err := dialer.DialContext(ctx, "tcp", addr)
if err != nil {
return xerrors.Errorf("dial coder connect host: %w", err)
}
if err := stack.push("tcp conn", conn); err != nil {
return err
}

agentssh.Bicopy(ctx, conn, &StdioRwc{
Reader: stdin,
Writer: stdout,
})

return nil
}

type StdioRwc struct {
io.Reader
io.Writer
}

func (*StdioRwc) Close() error {
return nil
}

func writeCoderConnectNetInfo(ctx context.Context, networkInfoDir string) error {
fs, ok := ctx.Value("fs").(afero.Fs)
if !ok {
fs = afero.NewOsFs()
}
// The VS Code extension obtains the PID of the SSH process to
// find the log file associated with a SSH session.
//
// We get the parent PID because it's assumed `ssh` is calling this
// command via the ProxyCommand SSH option.
networkInfoFilePath := filepath.Join(networkInfoDir, fmt.Sprintf("%d.json", os.Getppid()))
stats := &sshNetworkStats{
UsingCoderConnect: true,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll put up a PR for the vscode extension to read this.

}
rawStats, err := json.Marshal(stats)
if err != nil {
return xerrors.Errorf("marshal network stats: %w", err)
}
err = afero.WriteFile(fs, networkInfoFilePath, rawStats, 0o600)
if err != nil {
return xerrors.Errorf("write network stats: %w", err)
}
return nil
}

// Converts workspace name input to owner/workspace.agent format
// Possible valid input formats:
// workspace
Expand Down
85 changes: 85 additions & 0 deletions cli/ssh_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@ package cli
import (
"context"
"fmt"
"io"
"net"
"net/url"
"sync"
"testing"
"time"

gliderssh "github.com/gliderlabs/ssh"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
"golang.org/x/xerrors"

"cdr.dev/slog"
Expand Down Expand Up @@ -220,6 +224,87 @@ func TestCloserStack_Timeout(t *testing.T) {
testutil.TryReceive(ctx, t, closed)
}

func TestCoderConnectStdio(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
stack := newCloserStack(ctx, logger, quartz.NewMock(t))

clientOutput, clientInput := io.Pipe()
serverOutput, serverInput := io.Pipe()
defer func() {
for _, c := range []io.Closer{clientOutput, clientInput, serverOutput, serverInput} {
_ = c.Close()
}
}()

server := newSSHServer("127.0.0.1:0")
ln, err := net.Listen("tcp", server.server.Addr)
require.NoError(t, err)

go func() {
_ = server.Serve(ln)
}()
t.Cleanup(func() {
_ = server.Close()
})

stdioDone := make(chan struct{})
go func() {
err = runCoderConnectStdio(ctx, ln.Addr().String(), clientOutput, serverInput, stack)
assert.NoError(t, err)
close(stdioDone)
}()

conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{
Reader: serverOutput,
Writer: clientInput,
}, "", &ssh.ClientConfig{
// #nosec
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
})
require.NoError(t, err)
defer conn.Close()

sshClient := ssh.NewClient(conn, channels, requests)
session, err := sshClient.NewSession()
require.NoError(t, err)
defer session.Close()

// We're not connected to a real shell
err = session.Run("")
require.NoError(t, err)
err = sshClient.Close()
require.NoError(t, err)
_ = clientOutput.Close()

<-stdioDone
}

type sshServer struct {
server *gliderssh.Server
}

func newSSHServer(addr string) *sshServer {
return &sshServer{
server: &gliderssh.Server{
Addr: addr,
Handler: func(s gliderssh.Session) {
_, _ = io.WriteString(s.Stderr(), "Connected!")
},
},
}
}

func (s *sshServer) Serve(ln net.Listener) error {
return s.server.Serve(ln)
}

func (s *sshServer) Close() error {
return s.server.Close()
}

type fakeCloser struct {
closes *[]*fakeCloser
err error
Expand Down
Loading
Loading