Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit cab5f15

Browse files
committed
feat: Add workspace agent for SSH
This adds the initial agent that supports TTY and execution over SSH. It functions across MacOS, Windows, and Linux. This does not handle the coderd interaction yet, but does setup a simple path forward.
1 parent e5db936 commit cab5f15

11 files changed

+501
-28
lines changed

agent/agent.go

Lines changed: 293 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,293 @@
1+
package agent
2+
3+
import (
4+
"context"
5+
"crypto/rand"
6+
"crypto/rsa"
7+
"errors"
8+
"fmt"
9+
"io"
10+
"net"
11+
"os/exec"
12+
"time"
13+
14+
"cdr.dev/slog"
15+
"github.com/coder/coder/agent/usershell"
16+
"github.com/coder/coder/peer"
17+
"github.com/coder/coder/peerbroker"
18+
"github.com/coder/coder/pty"
19+
"github.com/coder/retry"
20+
21+
"github.com/gliderlabs/ssh"
22+
gossh "golang.org/x/crypto/ssh"
23+
"golang.org/x/xerrors"
24+
)
25+
26+
func DialSSH(conn *peer.Conn) (net.Conn, error) {
27+
channel, err := conn.Dial(context.Background(), "ssh", &peer.ChannelOptions{
28+
Protocol: "ssh",
29+
})
30+
if err != nil {
31+
return nil, err
32+
}
33+
return channel.NetConn(), nil
34+
}
35+
36+
func DialSSHClient(conn *peer.Conn) (*gossh.Client, error) {
37+
netConn, err := DialSSH(conn)
38+
if err != nil {
39+
return nil, err
40+
}
41+
sshConn, channels, requests, err := gossh.NewClientConn(netConn, "localhost:22", &gossh.ClientConfig{
42+
User: "kyle",
43+
Config: gossh.Config{
44+
Ciphers: []string{"arcfour"},
45+
},
46+
// SSH host validation isn't helpful, because obtaining a peer
47+
// connection already signifies user-intent to dial a workspace.
48+
// #nosec
49+
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
50+
})
51+
if err != nil {
52+
return nil, err
53+
}
54+
return gossh.NewClient(sshConn, channels, requests), nil
55+
}
56+
57+
type Options struct {
58+
Logger slog.Logger
59+
}
60+
61+
type Dialer func(ctx context.Context) (*peerbroker.Listener, error)
62+
63+
func New(dialer Dialer, options *Options) io.Closer {
64+
ctx, cancelFunc := context.WithCancel(context.Background())
65+
server := &server{
66+
clientDialer: dialer,
67+
options: options,
68+
closeCancel: cancelFunc,
69+
}
70+
server.init(ctx)
71+
return server
72+
}
73+
74+
type server struct {
75+
clientDialer Dialer
76+
options *Options
77+
78+
closeCancel context.CancelFunc
79+
closed chan struct{}
80+
81+
sshServer *ssh.Server
82+
}
83+
84+
func (s *server) init(ctx context.Context) {
85+
// Clients' should ignore the host key when connecting.
86+
// The agent needs to authenticate with coderd to SSH,
87+
// so SSH authentication doesn't improve security.
88+
randomHostKey, err := rsa.GenerateKey(rand.Reader, 2048)
89+
if err != nil {
90+
panic(err)
91+
}
92+
randomSigner, err := gossh.NewSignerFromKey(randomHostKey)
93+
if err != nil {
94+
panic(err)
95+
}
96+
sshLogger := s.options.Logger.Named("ssh-server")
97+
forwardHandler := &ssh.ForwardedTCPHandler{}
98+
s.sshServer = &ssh.Server{
99+
ChannelHandlers: ssh.DefaultChannelHandlers,
100+
ConnectionFailedCallback: func(conn net.Conn, err error) {
101+
sshLogger.Info(ctx, "ssh connection ended", slog.Error(err))
102+
},
103+
Handler: func(session ssh.Session) {
104+
err := s.handleSSHSession(session)
105+
if err != nil {
106+
s.options.Logger.Debug(ctx, "ssh session failed", slog.Error(err))
107+
_ = session.Exit(1)
108+
return
109+
}
110+
},
111+
HostSigners: []ssh.Signer{randomSigner},
112+
LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
113+
// Allow local port forwarding all!
114+
sshLogger.Debug(ctx, "local port forward",
115+
slog.F("destination-host", destinationHost),
116+
slog.F("destination-port", destinationPort))
117+
return true
118+
},
119+
PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool {
120+
return true
121+
},
122+
ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
123+
// Allow reverse port forwarding all!
124+
sshLogger.Debug(ctx, "local port forward",
125+
slog.F("bind-host", bindHost),
126+
slog.F("bind-port", bindPort))
127+
return true
128+
},
129+
RequestHandlers: map[string]ssh.RequestHandler{
130+
"tcpip-forward": forwardHandler.HandleSSHRequest,
131+
"cancel-tcpip-forward": forwardHandler.HandleSSHRequest,
132+
},
133+
ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig {
134+
return &gossh.ServerConfig{
135+
Config: gossh.Config{
136+
// "arcfour" is the fastest SSH cipher. We prioritize throughput
137+
// over encryption here, because the WebRTC connection is already
138+
// encrypted. If possible, we'd disable encryption entirely here.
139+
Ciphers: []string{"arcfour"},
140+
},
141+
NoClientAuth: true,
142+
}
143+
},
144+
}
145+
146+
go s.run(ctx)
147+
}
148+
149+
func (*server) handleSSHSession(session ssh.Session) error {
150+
var (
151+
command string
152+
args = []string{}
153+
err error
154+
)
155+
156+
// gliderlabs/ssh returns a command slice of zero
157+
// when a shell is requested.
158+
if len(session.Command()) == 0 {
159+
command, err = usershell.Get(session.User())
160+
if err != nil {
161+
return xerrors.Errorf("get user shell: %w", err)
162+
}
163+
} else {
164+
command = session.Command()[0]
165+
if len(session.Command()) > 1 {
166+
args = session.Command()[1:]
167+
}
168+
}
169+
170+
signals := make(chan ssh.Signal)
171+
breaks := make(chan bool)
172+
defer close(signals)
173+
defer close(breaks)
174+
go func() {
175+
for {
176+
select {
177+
case <-session.Context().Done():
178+
return
179+
// Ignore signals and breaks for now!
180+
case <-signals:
181+
case <-breaks:
182+
}
183+
}
184+
}()
185+
186+
cmd := exec.CommandContext(session.Context(), command, args...)
187+
cmd.Env = session.Environ()
188+
189+
sshPty, windowSize, isPty := session.Pty()
190+
if isPty {
191+
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term))
192+
ptty, process, err := pty.Start(cmd)
193+
if err != nil {
194+
return xerrors.Errorf("start command: %w", err)
195+
}
196+
go func() {
197+
for win := range windowSize {
198+
err := ptty.Resize(uint16(win.Width), uint16(win.Height))
199+
if err != nil {
200+
panic(err)
201+
}
202+
}
203+
}()
204+
go func() {
205+
_, _ = io.Copy(ptty.Input(), session)
206+
}()
207+
go func() {
208+
_, _ = io.Copy(session, ptty.Output())
209+
}()
210+
_, err = process.Wait()
211+
return err
212+
}
213+
214+
cmd.Stdout = session
215+
cmd.Stderr = session
216+
// This blocks forever until stdin is received if we don't
217+
// use StdinPipe. It's unknown what causes this.
218+
stdinPipe, err := cmd.StdinPipe()
219+
if err != nil {
220+
return xerrors.Errorf("create stdin pipe: %w", err)
221+
}
222+
go func() {
223+
_, _ = io.Copy(stdinPipe, session)
224+
}()
225+
err = cmd.Start()
226+
if err != nil {
227+
return xerrors.Errorf("start: %w", err)
228+
}
229+
return cmd.Wait()
230+
}
231+
232+
func (s *server) run(ctx context.Context) {
233+
var peerListener *peerbroker.Listener
234+
var err error
235+
// An exponential back-off occurs when the connection is failing to dial.
236+
// This is to prevent server spam in case of a coderd outage.
237+
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
238+
peerListener, err = s.clientDialer(ctx)
239+
if err != nil {
240+
if errors.Is(err, context.Canceled) {
241+
return
242+
}
243+
if s.isClosed() {
244+
return
245+
}
246+
s.options.Logger.Warn(context.Background(), "failed to dial", slog.Error(err))
247+
continue
248+
}
249+
s.options.Logger.Debug(context.Background(), "connected")
250+
break
251+
}
252+
253+
for {
254+
conn, err := peerListener.Accept()
255+
if err != nil {
256+
// This is closed!
257+
return
258+
}
259+
go s.handlePeerConn(ctx, conn)
260+
}
261+
}
262+
263+
func (s *server) handlePeerConn(ctx context.Context, conn *peer.Conn) {
264+
for {
265+
channel, err := conn.Accept(ctx)
266+
if err != nil {
267+
// TODO: Log here!
268+
return
269+
}
270+
271+
switch channel.Protocol() {
272+
case "ssh":
273+
s.sshServer.HandleConn(channel.NetConn())
274+
case "proxy":
275+
// Proxy the port provided.
276+
}
277+
}
278+
}
279+
280+
// isClosed returns whether the API is closed or not.
281+
func (s *server) isClosed() bool {
282+
select {
283+
case <-s.closed:
284+
return true
285+
default:
286+
return false
287+
}
288+
}
289+
290+
func (s *server) Close() error {
291+
s.sshServer.Close()
292+
return nil
293+
}

agent/agent_test.go

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
package agent_test
2+
3+
import (
4+
"context"
5+
"runtime"
6+
"strings"
7+
"testing"
8+
9+
"github.com/pion/webrtc/v3"
10+
"github.com/stretchr/testify/require"
11+
"go.uber.org/goleak"
12+
"golang.org/x/crypto/ssh"
13+
14+
"cdr.dev/slog/sloggers/slogtest"
15+
"github.com/coder/coder/agent"
16+
"github.com/coder/coder/peer"
17+
"github.com/coder/coder/peerbroker"
18+
"github.com/coder/coder/peerbroker/proto"
19+
"github.com/coder/coder/provisionersdk"
20+
"github.com/coder/coder/pty/ptytest"
21+
)
22+
23+
func TestMain(m *testing.M) {
24+
goleak.VerifyTestMain(m)
25+
}
26+
27+
func TestAgent(t *testing.T) {
28+
t.Parallel()
29+
t.Run("SessionExec", func(t *testing.T) {
30+
t.Parallel()
31+
api := setup(t)
32+
stream, err := api.NegotiateConnection(context.Background())
33+
require.NoError(t, err)
34+
conn, err := peerbroker.Dial(stream, []webrtc.ICEServer{}, &peer.ConnOptions{
35+
Logger: slogtest.Make(t, nil),
36+
})
37+
require.NoError(t, err)
38+
defer conn.Close()
39+
sshClient, err := agent.DialSSHClient(conn)
40+
require.NoError(t, err)
41+
session, err := sshClient.NewSession()
42+
require.NoError(t, err)
43+
command := "echo test"
44+
if runtime.GOOS == "windows" {
45+
command = "cmd.exe /c echo test"
46+
}
47+
output, err := session.Output(command)
48+
require.NoError(t, err)
49+
require.Equal(t, "test", strings.TrimSpace(string(output)))
50+
})
51+
52+
t.Run("SessionTTY", func(t *testing.T) {
53+
t.Parallel()
54+
api := setup(t)
55+
stream, err := api.NegotiateConnection(context.Background())
56+
require.NoError(t, err)
57+
conn, err := peerbroker.Dial(stream, []webrtc.ICEServer{}, &peer.ConnOptions{
58+
Logger: slogtest.Make(t, nil),
59+
})
60+
require.NoError(t, err)
61+
defer conn.Close()
62+
sshClient, err := agent.DialSSHClient(conn)
63+
require.NoError(t, err)
64+
session, err := sshClient.NewSession()
65+
require.NoError(t, err)
66+
prompt := "$"
67+
command := "bash"
68+
if runtime.GOOS == "windows" {
69+
command = "cmd.exe"
70+
prompt = ">"
71+
}
72+
err = session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
73+
require.NoError(t, err)
74+
ptty := ptytest.New(t)
75+
require.NoError(t, err)
76+
session.Stdout = ptty.Output()
77+
session.Stderr = ptty.Output()
78+
session.Stdin = ptty.Input()
79+
err = session.Start(command)
80+
require.NoError(t, err)
81+
ptty.ExpectMatch(prompt)
82+
ptty.WriteLine("echo test")
83+
ptty.ExpectMatch("test")
84+
ptty.WriteLine("exit")
85+
err = session.Wait()
86+
require.NoError(t, err)
87+
})
88+
}
89+
90+
func setup(t *testing.T) proto.DRPCPeerBrokerClient {
91+
client, server := provisionersdk.TransportPipe()
92+
closer := agent.New(func(ctx context.Context) (*peerbroker.Listener, error) {
93+
return peerbroker.Listen(server, &peer.ConnOptions{
94+
Logger: slogtest.Make(t, nil),
95+
})
96+
}, &agent.Options{
97+
Logger: slogtest.Make(t, nil),
98+
})
99+
t.Cleanup(func() {
100+
_ = client.Close()
101+
_ = server.Close()
102+
_ = closer.Close()
103+
})
104+
return proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client))
105+
}

0 commit comments

Comments
 (0)