-
Notifications
You must be signed in to change notification settings - Fork 888
feat: Add workspace agent for SSH #318
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
fa7489a
feat: Add workspace agent for SSH
kylecarbs 2606fda
Merge branch 'main' into workspaceagent
kylecarbs b54e815
Fix pty tests on Windows
kylecarbs c484d47
Fix log race
kylecarbs cc2bcde
Lock around dial error to fix log output
kylecarbs ae36c63
Fix context return early
kylecarbs 4932ba7
fix: Leaking yamux session after HTTP handler is closed
kylecarbs 0c84636
Merge branch 'main' into closeconn
kylecarbs 5601b4d
Lock around close return
kylecarbs 94cf442
Merge branch 'main' into closeconn
kylecarbs 055ce11
Force failure with log
kylecarbs 94e03a2
Merge branch 'main' into closeconn
kylecarbs f8566f3
Merge branch 'closeconn' into workspaceagent
kylecarbs 70d3723
Fix failed handler
kylecarbs 223540e
Upgrade dep
kylecarbs 1a48bea
Fix defer inside loops
kylecarbs ee0ee70
Fix context cancel for HTTP requests
kylecarbs 30189aa
Merge branch 'main' into workspaceagent
kylecarbs 516b605
Fix resize
kylecarbs File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,329 @@ | ||
package agent | ||
|
||
import ( | ||
"context" | ||
"crypto/rand" | ||
"crypto/rsa" | ||
"errors" | ||
"fmt" | ||
"io" | ||
"net" | ||
"os/exec" | ||
"os/user" | ||
"sync" | ||
"time" | ||
|
||
"cdr.dev/slog" | ||
"github.com/coder/coder/agent/usershell" | ||
"github.com/coder/coder/peer" | ||
"github.com/coder/coder/peerbroker" | ||
"github.com/coder/coder/pty" | ||
"github.com/coder/retry" | ||
|
||
"github.com/gliderlabs/ssh" | ||
gossh "golang.org/x/crypto/ssh" | ||
"golang.org/x/xerrors" | ||
) | ||
|
||
func DialSSH(conn *peer.Conn) (net.Conn, error) { | ||
channel, err := conn.Dial(context.Background(), "ssh", &peer.ChannelOptions{ | ||
Protocol: "ssh", | ||
}) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return channel.NetConn(), nil | ||
} | ||
|
||
func DialSSHClient(conn *peer.Conn) (*gossh.Client, error) { | ||
netConn, err := DialSSH(conn) | ||
if err != nil { | ||
return nil, err | ||
} | ||
sshConn, channels, requests, err := gossh.NewClientConn(netConn, "localhost:22", &gossh.ClientConfig{ | ||
Config: gossh.Config{ | ||
Ciphers: []string{"arcfour"}, | ||
}, | ||
// SSH host validation isn't helpful, because obtaining a peer | ||
// connection already signifies user-intent to dial a workspace. | ||
// #nosec | ||
HostKeyCallback: gossh.InsecureIgnoreHostKey(), | ||
}) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return gossh.NewClient(sshConn, channels, requests), nil | ||
} | ||
|
||
type Options struct { | ||
Logger slog.Logger | ||
} | ||
|
||
type Dialer func(ctx context.Context) (*peerbroker.Listener, error) | ||
|
||
func New(dialer Dialer, options *Options) io.Closer { | ||
ctx, cancelFunc := context.WithCancel(context.Background()) | ||
server := &server{ | ||
clientDialer: dialer, | ||
options: options, | ||
closeCancel: cancelFunc, | ||
closed: make(chan struct{}), | ||
} | ||
server.init(ctx) | ||
return server | ||
} | ||
|
||
type server struct { | ||
clientDialer Dialer | ||
options *Options | ||
|
||
closeCancel context.CancelFunc | ||
closeMutex sync.Mutex | ||
closed chan struct{} | ||
|
||
sshServer *ssh.Server | ||
} | ||
|
||
func (s *server) init(ctx context.Context) { | ||
// Clients' should ignore the host key when connecting. | ||
// The agent needs to authenticate with coderd to SSH, | ||
// so SSH authentication doesn't improve security. | ||
randomHostKey, err := rsa.GenerateKey(rand.Reader, 2048) | ||
if err != nil { | ||
panic(err) | ||
} | ||
randomSigner, err := gossh.NewSignerFromKey(randomHostKey) | ||
if err != nil { | ||
panic(err) | ||
} | ||
sshLogger := s.options.Logger.Named("ssh-server") | ||
forwardHandler := &ssh.ForwardedTCPHandler{} | ||
s.sshServer = &ssh.Server{ | ||
ChannelHandlers: ssh.DefaultChannelHandlers, | ||
ConnectionFailedCallback: func(conn net.Conn, err error) { | ||
sshLogger.Info(ctx, "ssh connection ended", slog.Error(err)) | ||
}, | ||
Handler: func(session ssh.Session) { | ||
err := s.handleSSHSession(session) | ||
if err != nil { | ||
s.options.Logger.Debug(ctx, "ssh session failed", slog.Error(err)) | ||
_ = session.Exit(1) | ||
return | ||
} | ||
}, | ||
HostSigners: []ssh.Signer{randomSigner}, | ||
LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool { | ||
// Allow local port forwarding all! | ||
sshLogger.Debug(ctx, "local port forward", | ||
slog.F("destination-host", destinationHost), | ||
slog.F("destination-port", destinationPort)) | ||
return true | ||
}, | ||
PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool { | ||
return true | ||
}, | ||
ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool { | ||
// Allow reverse port forwarding all! | ||
sshLogger.Debug(ctx, "local port forward", | ||
slog.F("bind-host", bindHost), | ||
slog.F("bind-port", bindPort)) | ||
return true | ||
}, | ||
RequestHandlers: map[string]ssh.RequestHandler{ | ||
"tcpip-forward": forwardHandler.HandleSSHRequest, | ||
"cancel-tcpip-forward": forwardHandler.HandleSSHRequest, | ||
}, | ||
ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig { | ||
return &gossh.ServerConfig{ | ||
Config: gossh.Config{ | ||
// "arcfour" is the fastest SSH cipher. We prioritize throughput | ||
// over encryption here, because the WebRTC connection is already | ||
// encrypted. If possible, we'd disable encryption entirely here. | ||
Ciphers: []string{"arcfour"}, | ||
}, | ||
NoClientAuth: true, | ||
} | ||
}, | ||
} | ||
|
||
go s.run(ctx) | ||
} | ||
|
||
func (*server) handleSSHSession(session ssh.Session) error { | ||
var ( | ||
command string | ||
args = []string{} | ||
err error | ||
) | ||
|
||
username := session.User() | ||
if username == "" { | ||
currentUser, err := user.Current() | ||
if err != nil { | ||
return xerrors.Errorf("get current user: %w", err) | ||
} | ||
username = currentUser.Username | ||
} | ||
|
||
// gliderlabs/ssh returns a command slice of zero | ||
// when a shell is requested. | ||
if len(session.Command()) == 0 { | ||
command, err = usershell.Get(username) | ||
if err != nil { | ||
return xerrors.Errorf("get user shell: %w", err) | ||
} | ||
} else { | ||
command = session.Command()[0] | ||
if len(session.Command()) > 1 { | ||
args = session.Command()[1:] | ||
} | ||
} | ||
|
||
signals := make(chan ssh.Signal) | ||
breaks := make(chan bool) | ||
defer close(signals) | ||
defer close(breaks) | ||
go func() { | ||
for { | ||
select { | ||
case <-session.Context().Done(): | ||
return | ||
// Ignore signals and breaks for now! | ||
case <-signals: | ||
case <-breaks: | ||
} | ||
} | ||
}() | ||
|
||
cmd := exec.CommandContext(session.Context(), command, args...) | ||
cmd.Env = session.Environ() | ||
|
||
sshPty, windowSize, isPty := session.Pty() | ||
if isPty { | ||
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term)) | ||
ptty, process, err := pty.Start(cmd) | ||
if err != nil { | ||
return xerrors.Errorf("start command: %w", err) | ||
} | ||
go func() { | ||
for win := range windowSize { | ||
err := ptty.Resize(uint16(win.Width), uint16(win.Height)) | ||
if err != nil { | ||
panic(err) | ||
} | ||
} | ||
}() | ||
go func() { | ||
_, _ = io.Copy(ptty.Input(), session) | ||
}() | ||
go func() { | ||
_, _ = io.Copy(session, ptty.Output()) | ||
}() | ||
_, _ = process.Wait() | ||
_ = ptty.Close() | ||
return nil | ||
} | ||
|
||
cmd.Stdout = session | ||
cmd.Stderr = session | ||
// This blocks forever until stdin is received if we don't | ||
// use StdinPipe. It's unknown what causes this. | ||
stdinPipe, err := cmd.StdinPipe() | ||
if err != nil { | ||
return xerrors.Errorf("create stdin pipe: %w", err) | ||
} | ||
go func() { | ||
_, _ = io.Copy(stdinPipe, session) | ||
}() | ||
err = cmd.Start() | ||
if err != nil { | ||
return xerrors.Errorf("start: %w", err) | ||
} | ||
_ = cmd.Wait() | ||
return nil | ||
} | ||
|
||
func (s *server) run(ctx context.Context) { | ||
var peerListener *peerbroker.Listener | ||
var err error | ||
// An exponential back-off occurs when the connection is failing to dial. | ||
// This is to prevent server spam in case of a coderd outage. | ||
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); { | ||
peerListener, err = s.clientDialer(ctx) | ||
if err != nil { | ||
if errors.Is(err, context.Canceled) { | ||
return | ||
} | ||
if s.isClosed() { | ||
return | ||
} | ||
s.options.Logger.Warn(context.Background(), "failed to dial", slog.Error(err)) | ||
continue | ||
} | ||
s.options.Logger.Debug(context.Background(), "connected") | ||
break | ||
} | ||
select { | ||
case <-ctx.Done(): | ||
return | ||
default: | ||
} | ||
|
||
for { | ||
conn, err := peerListener.Accept() | ||
if err != nil { | ||
if s.isClosed() { | ||
return | ||
} | ||
s.options.Logger.Debug(ctx, "peer listener accept exited; restarting connection", slog.Error(err)) | ||
s.run(ctx) | ||
return | ||
} | ||
go s.handlePeerConn(ctx, conn) | ||
} | ||
} | ||
|
||
func (s *server) handlePeerConn(ctx context.Context, conn *peer.Conn) { | ||
for { | ||
channel, err := conn.Accept(ctx) | ||
if err != nil { | ||
if errors.Is(err, peer.ErrClosed) || s.isClosed() { | ||
return | ||
} | ||
s.options.Logger.Debug(ctx, "accept channel from peer connection", slog.Error(err)) | ||
return | ||
} | ||
|
||
switch channel.Protocol() { | ||
case "ssh": | ||
s.sshServer.HandleConn(channel.NetConn()) | ||
default: | ||
s.options.Logger.Warn(ctx, "unhandled protocol from channel", | ||
slog.F("protocol", channel.Protocol()), | ||
slog.F("label", channel.Label()), | ||
) | ||
} | ||
} | ||
} | ||
|
||
// isClosed returns whether the API is closed or not. | ||
func (s *server) isClosed() bool { | ||
select { | ||
case <-s.closed: | ||
return true | ||
default: | ||
return false | ||
} | ||
} | ||
|
||
func (s *server) Close() error { | ||
s.closeMutex.Lock() | ||
defer s.closeMutex.Unlock() | ||
if s.isClosed() { | ||
return nil | ||
} | ||
close(s.closed) | ||
s.closeCancel() | ||
_ = s.sshServer.Close() | ||
return nil | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will this agent be run as a standalone executable, or as part of `coderd`? If it is the former, maybe it would make sense for it to be `agentd`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Part of the `coder` CLI! I was debating the name and structure internally. Its purpose is to be the listening agent inside of a workspace that enables access. It'll be added as `coder agent start`, or something of the like.
I initially had this in `cli/agent`, but felt that unnecessarily nested core business logic. Beyond that, it's likely this package will be used for tests that aren't relevant to the CLI at all, e.g. `agenttest.New` will likely exist.
I'd appreciate your thoughts here. It's more of an agent than a daemon, because it doesn't require system-level privileges, and is active rather than passive.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for all the helpful details @kylecarbs !
That sounds good!
Agreed, it seems like our convention here is to have packages at the top-level - don't see a compelling reason to switch that.
Makes sense to me 👍