Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 2c6e0f7

Browse files
authored
feat(agent/agentssh): handle session signals (coder#10842)
1 parent a7c27ca commit 2c6e0f7

10 files changed

+312
-9
lines changed

agent/agentssh/agentssh.go

+55-8
Original file line numberDiff line numberDiff line change
@@ -311,10 +311,10 @@ func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, extraEnv
311311
if isPty {
312312
return s.startPTYSession(logger, session, magicTypeLabel, cmd, sshPty, windowSize)
313313
}
314-
return s.startNonPTYSession(session, magicTypeLabel, cmd.AsExec())
314+
return s.startNonPTYSession(logger, session, magicTypeLabel, cmd.AsExec())
315315
}
316316

317-
func (s *Server) startNonPTYSession(session ssh.Session, magicTypeLabel string, cmd *exec.Cmd) error {
317+
func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, magicTypeLabel string, cmd *exec.Cmd) error {
318318
s.metrics.sessionsTotal.WithLabelValues(magicTypeLabel, "no").Add(1)
319319

320320
cmd.Stdout = session
@@ -338,6 +338,17 @@ func (s *Server) startNonPTYSession(session ssh.Session, magicTypeLabel string,
338338
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "no", "start_command").Add(1)
339339
return xerrors.Errorf("start: %w", err)
340340
}
341+
sigs := make(chan ssh.Signal, 1)
342+
session.Signals(sigs)
343+
defer func() {
344+
session.Signals(nil)
345+
close(sigs)
346+
}()
347+
go func() {
348+
for sig := range sigs {
349+
s.handleSignal(logger, sig, cmd.Process, magicTypeLabel)
350+
}
351+
}()
341352
return cmd.Wait()
342353
}
343354

@@ -348,6 +359,7 @@ type ptySession interface {
348359
Context() ssh.Context
349360
DisablePTYEmulation()
350361
RawCommand() string
362+
Signals(chan<- ssh.Signal)
351363
}
352364

353365
func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTypeLabel string, cmd *pty.Cmd, sshPty ssh.Pty, windowSize <-chan ssh.Window) (retErr error) {
@@ -403,13 +415,36 @@ func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTy
403415
}
404416
}
405417
}()
418+
sigs := make(chan ssh.Signal, 1)
419+
session.Signals(sigs)
420+
defer func() {
421+
session.Signals(nil)
422+
close(sigs)
423+
}()
406424
go func() {
407-
for win := range windowSize {
408-
resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width))
409-
// If the pty is closed, then command has exited, no need to log.
410-
if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) {
411-
logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr))
412-
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "resize").Add(1)
425+
for {
426+
if sigs == nil && windowSize == nil {
427+
return
428+
}
429+
430+
select {
431+
case sig, ok := <-sigs:
432+
if !ok {
433+
sigs = nil
434+
continue
435+
}
436+
s.handleSignal(logger, sig, process, magicTypeLabel)
437+
case win, ok := <-windowSize:
438+
if !ok {
439+
windowSize = nil
440+
continue
441+
}
442+
resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width))
443+
// If the pty is closed, then command has exited, no need to log.
444+
if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) {
445+
logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr))
446+
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "resize").Add(1)
447+
}
413448
}
414449
}
415450
}()
@@ -452,6 +487,18 @@ func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTy
452487
return nil
453488
}
454489

490+
func (s *Server) handleSignal(logger slog.Logger, ssig ssh.Signal, signaler interface{ Signal(os.Signal) error }, magicTypeLabel string) {
491+
ctx := context.Background()
492+
sig := osSignalFrom(ssig)
493+
logger = logger.With(slog.F("ssh_signal", ssig), slog.F("signal", sig.String()))
494+
logger.Info(ctx, "received signal from client")
495+
err := signaler.Signal(sig)
496+
if err != nil {
497+
logger.Warn(ctx, "signaling the process failed", slog.Error(err))
498+
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "signal").Add(1)
499+
}
500+
}
501+
455502
func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) {
456503
s.metrics.sftpConnectionsTotal.Add(1)
457504

agent/agentssh/agentssh_internal_test.go

+9
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,11 @@ type testSSHContext struct {
114114
context.Context
115115
}
116116

117+
var (
118+
_ gliderssh.Context = testSSHContext{}
119+
_ ptySession = &testSession{}
120+
)
121+
117122
func newTestSession(ctx context.Context) (toClient *io.PipeReader, fromClient *io.PipeWriter, s ptySession) {
118123
toClient, fromPty := io.Pipe()
119124
toPty, fromClient := io.Pipe()
@@ -144,6 +149,10 @@ func (s *testSession) Write(p []byte) (n int, err error) {
144149
return s.fromPty.Write(p)
145150
}
146151

152+
func (*testSession) Signals(_ chan<- gliderssh.Signal) {
153+
// Not implemented, but will be called.
154+
}
155+
147156
func (testSSHContext) Lock() {
148157
panic("not implemented")
149158
}

agent/agentssh/agentssh_test.go

+146-1
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
package agentssh_test
44

55
import (
6+
"bufio"
67
"bytes"
78
"context"
9+
"fmt"
810
"net"
911
"runtime"
1012
"strings"
@@ -24,6 +26,7 @@ import (
2426
"github.com/coder/coder/v2/agent/agentssh"
2527
"github.com/coder/coder/v2/codersdk/agentsdk"
2628
"github.com/coder/coder/v2/pty/ptytest"
29+
"github.com/coder/coder/v2/testutil"
2730
)
2831

2932
func TestMain(m *testing.M) {
@@ -57,8 +60,8 @@ func TestNewServer_ServeClient(t *testing.T) {
5760

5861
var b bytes.Buffer
5962
sess, err := c.NewSession()
60-
sess.Stdout = &b
6163
require.NoError(t, err)
64+
sess.Stdout = &b
6265
err = sess.Start("echo hello")
6366
require.NoError(t, err)
6467

@@ -139,6 +142,7 @@ func TestNewServer_CloseActiveConnections(t *testing.T) {
139142
defer wg.Done()
140143
c := sshClient(t, ln.Addr().String())
141144
sess, err := c.NewSession()
145+
assert.NoError(t, err)
142146
sess.Stdin = pty.Input()
143147
sess.Stdout = pty.Output()
144148
sess.Stderr = pty.Output()
@@ -159,6 +163,147 @@ func TestNewServer_CloseActiveConnections(t *testing.T) {
159163
wg.Wait()
160164
}
161165

166+
func TestNewServer_Signal(t *testing.T) {
167+
t.Parallel()
168+
169+
t.Run("Stdout", func(t *testing.T) {
170+
t.Parallel()
171+
172+
ctx := context.Background()
173+
logger := slogtest.Make(t, nil)
174+
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), 0, "")
175+
require.NoError(t, err)
176+
defer s.Close()
177+
178+
// The assumption is that these are set before serving SSH connections.
179+
s.AgentToken = func() string { return "" }
180+
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})
181+
182+
ln, err := net.Listen("tcp", "127.0.0.1:0")
183+
require.NoError(t, err)
184+
185+
done := make(chan struct{})
186+
go func() {
187+
defer close(done)
188+
err := s.Serve(ln)
189+
assert.Error(t, err) // Server is closed.
190+
}()
191+
defer func() {
192+
err := s.Close()
193+
require.NoError(t, err)
194+
<-done
195+
}()
196+
197+
c := sshClient(t, ln.Addr().String())
198+
199+
sess, err := c.NewSession()
200+
require.NoError(t, err)
201+
r, err := sess.StdoutPipe()
202+
require.NoError(t, err)
203+
204+
// Perform multiple sleeps since the interrupt signal doesn't propagate to
205+
// the process group, this lets us exit early.
206+
sleeps := strings.Repeat("sleep 1 && ", int(testutil.WaitMedium.Seconds()))
207+
err = sess.Start(fmt.Sprintf("echo hello && %s echo bye", sleeps))
208+
require.NoError(t, err)
209+
210+
sc := bufio.NewScanner(r)
211+
for sc.Scan() {
212+
t.Log(sc.Text())
213+
if strings.Contains(sc.Text(), "hello") {
214+
break
215+
}
216+
}
217+
require.NoError(t, sc.Err())
218+
219+
err = sess.Signal(ssh.SIGINT)
220+
require.NoError(t, err)
221+
222+
// Assumption, signal propagates and the command exists, closing stdout.
223+
for sc.Scan() {
224+
t.Log(sc.Text())
225+
require.NotContains(t, sc.Text(), "bye")
226+
}
227+
require.NoError(t, sc.Err())
228+
229+
err = sess.Wait()
230+
require.Error(t, err)
231+
})
232+
t.Run("PTY", func(t *testing.T) {
233+
t.Parallel()
234+
235+
ctx := context.Background()
236+
logger := slogtest.Make(t, nil)
237+
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), 0, "")
238+
require.NoError(t, err)
239+
defer s.Close()
240+
241+
// The assumption is that these are set before serving SSH connections.
242+
s.AgentToken = func() string { return "" }
243+
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})
244+
245+
ln, err := net.Listen("tcp", "127.0.0.1:0")
246+
require.NoError(t, err)
247+
248+
done := make(chan struct{})
249+
go func() {
250+
defer close(done)
251+
err := s.Serve(ln)
252+
assert.Error(t, err) // Server is closed.
253+
}()
254+
defer func() {
255+
err := s.Close()
256+
require.NoError(t, err)
257+
<-done
258+
}()
259+
260+
c := sshClient(t, ln.Addr().String())
261+
262+
pty := ptytest.New(t)
263+
264+
sess, err := c.NewSession()
265+
require.NoError(t, err)
266+
r, err := sess.StdoutPipe()
267+
require.NoError(t, err)
268+
269+
// Note, we request pty but don't use ptytest here because we can't
270+
// easily test for no text before EOF.
271+
sess.Stdin = pty.Input()
272+
sess.Stderr = pty.Output()
273+
274+
err = sess.RequestPty("xterm", 80, 80, nil)
275+
require.NoError(t, err)
276+
277+
// Perform multiple sleeps since the interrupt signal doesn't propagate to
278+
// the process group, this lets us exit early.
279+
sleeps := strings.Repeat("sleep 1 && ", int(testutil.WaitMedium.Seconds()))
280+
err = sess.Start(fmt.Sprintf("echo hello && %s echo bye", sleeps))
281+
require.NoError(t, err)
282+
283+
sc := bufio.NewScanner(r)
284+
for sc.Scan() {
285+
t.Log(sc.Text())
286+
if strings.Contains(sc.Text(), "hello") {
287+
break
288+
}
289+
}
290+
require.NoError(t, sc.Err())
291+
292+
err = sess.Signal(ssh.SIGINT)
293+
require.NoError(t, err)
294+
295+
// Assumption, signal propagates and the command exists, closing stdout.
296+
for sc.Scan() {
297+
t.Log(sc.Text())
298+
require.NotContains(t, sc.Text(), "bye")
299+
}
300+
require.NoError(t, sc.Err())
301+
302+
err = sess.Wait()
303+
require.Error(t, err)
304+
})
305+
}
306+
162307
func sshClient(t *testing.T, addr string) *ssh.Client {
163308
conn, err := net.Dial("tcp", addr)
164309
require.NoError(t, err)

agent/agentssh/signal_other.go

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
//go:build !windows
2+
3+
package agentssh
4+
5+
import (
6+
"os"
7+
8+
"github.com/gliderlabs/ssh"
9+
"golang.org/x/sys/unix"
10+
)
11+
12+
func osSignalFrom(sig ssh.Signal) os.Signal {
13+
switch sig {
14+
case ssh.SIGABRT:
15+
return unix.SIGABRT
16+
case ssh.SIGALRM:
17+
return unix.SIGALRM
18+
case ssh.SIGFPE:
19+
return unix.SIGFPE
20+
case ssh.SIGHUP:
21+
return unix.SIGHUP
22+
case ssh.SIGILL:
23+
return unix.SIGILL
24+
case ssh.SIGINT:
25+
return unix.SIGINT
26+
case ssh.SIGKILL:
27+
return unix.SIGKILL
28+
case ssh.SIGPIPE:
29+
return unix.SIGPIPE
30+
case ssh.SIGQUIT:
31+
return unix.SIGQUIT
32+
case ssh.SIGSEGV:
33+
return unix.SIGSEGV
34+
case ssh.SIGTERM:
35+
return unix.SIGTERM
36+
case ssh.SIGUSR1:
37+
return unix.SIGUSR1
38+
case ssh.SIGUSR2:
39+
return unix.SIGUSR2
40+
41+
// Unhandled, use sane fallback.
42+
default:
43+
return unix.SIGKILL
44+
}
45+
}

agent/agentssh/signal_windows.go

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package agentssh
2+
3+
import (
4+
"os"
5+
6+
"github.com/gliderlabs/ssh"
7+
)
8+
9+
func osSignalFrom(sig ssh.Signal) os.Signal {
10+
switch sig {
11+
// Signals are not supported on Windows.
12+
default:
13+
return os.Kill
14+
}
15+
}

pty/pty.go

+6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package pty
33
import (
44
"io"
55
"log"
6+
"os"
67

78
"github.com/gliderlabs/ssh"
89
"golang.org/x/xerrors"
@@ -69,6 +70,11 @@ type Process interface {
6970

7071
// Kill the command process. Returned error is as for os.Process.Kill()
7172
Kill() error
73+
74+
// Signal sends a signal to the command process. On non-windows systems, the
75+
// returned error is as for os.Process.Signal(), on Windows it's
76+
// as for os.Process.Kill().
77+
Signal(sig os.Signal) error
7278
}
7379

7480
// WithFlags represents a PTY whose flags can be inspected, in particular

0 commit comments

Comments
 (0)