3
3
package agentssh_test
4
4
5
5
import (
6
+ "bufio"
6
7
"bytes"
7
8
"context"
9
+ "fmt"
8
10
"net"
9
11
"runtime"
10
12
"strings"
@@ -24,6 +26,7 @@ import (
24
26
"github.com/coder/coder/v2/agent/agentssh"
25
27
"github.com/coder/coder/v2/codersdk/agentsdk"
26
28
"github.com/coder/coder/v2/pty/ptytest"
29
+ "github.com/coder/coder/v2/testutil"
27
30
)
28
31
29
32
func TestMain (m * testing.M ) {
@@ -57,8 +60,8 @@ func TestNewServer_ServeClient(t *testing.T) {
57
60
58
61
var b bytes.Buffer
59
62
sess , err := c .NewSession ()
60
- sess .Stdout = & b
61
63
require .NoError (t , err )
64
+ sess .Stdout = & b
62
65
err = sess .Start ("echo hello" )
63
66
require .NoError (t , err )
64
67
@@ -139,6 +142,7 @@ func TestNewServer_CloseActiveConnections(t *testing.T) {
139
142
defer wg .Done ()
140
143
c := sshClient (t , ln .Addr ().String ())
141
144
sess , err := c .NewSession ()
145
+ assert .NoError (t , err )
142
146
sess .Stdin = pty .Input ()
143
147
sess .Stdout = pty .Output ()
144
148
sess .Stderr = pty .Output ()
@@ -159,6 +163,147 @@ func TestNewServer_CloseActiveConnections(t *testing.T) {
159
163
wg .Wait ()
160
164
}
161
165
166
+ func TestNewServer_Signal (t * testing.T ) {
167
+ t .Parallel ()
168
+
169
+ t .Run ("Stdout" , func (t * testing.T ) {
170
+ t .Parallel ()
171
+
172
+ ctx := context .Background ()
173
+ logger := slogtest .Make (t , nil )
174
+ s , err := agentssh .NewServer (ctx , logger , prometheus .NewRegistry (), afero .NewMemMapFs (), 0 , "" )
175
+ require .NoError (t , err )
176
+ defer s .Close ()
177
+
178
+ // The assumption is that these are set before serving SSH connections.
179
+ s .AgentToken = func () string { return "" }
180
+ s .Manifest = atomic .NewPointer (& agentsdk.Manifest {})
181
+
182
+ ln , err := net .Listen ("tcp" , "127.0.0.1:0" )
183
+ require .NoError (t , err )
184
+
185
+ done := make (chan struct {})
186
+ go func () {
187
+ defer close (done )
188
+ err := s .Serve (ln )
189
+ assert .Error (t , err ) // Server is closed.
190
+ }()
191
+ defer func () {
192
+ err := s .Close ()
193
+ require .NoError (t , err )
194
+ <- done
195
+ }()
196
+
197
+ c := sshClient (t , ln .Addr ().String ())
198
+
199
+ sess , err := c .NewSession ()
200
+ require .NoError (t , err )
201
+ r , err := sess .StdoutPipe ()
202
+ require .NoError (t , err )
203
+
204
+ // Perform multiple sleeps since the interrupt signal doesn't propagate to
205
+ // the process group, this lets us exit early.
206
+ sleeps := strings .Repeat ("sleep 1 && " , int (testutil .WaitMedium .Seconds ()))
207
+ err = sess .Start (fmt .Sprintf ("echo hello && %s echo bye" , sleeps ))
208
+ require .NoError (t , err )
209
+
210
+ sc := bufio .NewScanner (r )
211
+ for sc .Scan () {
212
+ t .Log (sc .Text ())
213
+ if strings .Contains (sc .Text (), "hello" ) {
214
+ break
215
+ }
216
+ }
217
+ require .NoError (t , sc .Err ())
218
+
219
+ err = sess .Signal (ssh .SIGINT )
220
+ require .NoError (t , err )
221
+
222
+ // Assumption, signal propagates and the command exists, closing stdout.
223
+ for sc .Scan () {
224
+ t .Log (sc .Text ())
225
+ require .NotContains (t , sc .Text (), "bye" )
226
+ }
227
+ require .NoError (t , sc .Err ())
228
+
229
+ err = sess .Wait ()
230
+ require .Error (t , err )
231
+ })
232
+ t .Run ("PTY" , func (t * testing.T ) {
233
+ t .Parallel ()
234
+
235
+ ctx := context .Background ()
236
+ logger := slogtest .Make (t , nil )
237
+ s , err := agentssh .NewServer (ctx , logger , prometheus .NewRegistry (), afero .NewMemMapFs (), 0 , "" )
238
+ require .NoError (t , err )
239
+ defer s .Close ()
240
+
241
+ // The assumption is that these are set before serving SSH connections.
242
+ s .AgentToken = func () string { return "" }
243
+ s .Manifest = atomic .NewPointer (& agentsdk.Manifest {})
244
+
245
+ ln , err := net .Listen ("tcp" , "127.0.0.1:0" )
246
+ require .NoError (t , err )
247
+
248
+ done := make (chan struct {})
249
+ go func () {
250
+ defer close (done )
251
+ err := s .Serve (ln )
252
+ assert .Error (t , err ) // Server is closed.
253
+ }()
254
+ defer func () {
255
+ err := s .Close ()
256
+ require .NoError (t , err )
257
+ <- done
258
+ }()
259
+
260
+ c := sshClient (t , ln .Addr ().String ())
261
+
262
+ pty := ptytest .New (t )
263
+
264
+ sess , err := c .NewSession ()
265
+ require .NoError (t , err )
266
+ r , err := sess .StdoutPipe ()
267
+ require .NoError (t , err )
268
+
269
+ // Note, we request pty but don't use ptytest here because we can't
270
+ // easily test for no text before EOF.
271
+ sess .Stdin = pty .Input ()
272
+ sess .Stderr = pty .Output ()
273
+
274
+ err = sess .RequestPty ("xterm" , 80 , 80 , nil )
275
+ require .NoError (t , err )
276
+
277
+ // Perform multiple sleeps since the interrupt signal doesn't propagate to
278
+ // the process group, this lets us exit early.
279
+ sleeps := strings .Repeat ("sleep 1 && " , int (testutil .WaitMedium .Seconds ()))
280
+ err = sess .Start (fmt .Sprintf ("echo hello && %s echo bye" , sleeps ))
281
+ require .NoError (t , err )
282
+
283
+ sc := bufio .NewScanner (r )
284
+ for sc .Scan () {
285
+ t .Log (sc .Text ())
286
+ if strings .Contains (sc .Text (), "hello" ) {
287
+ break
288
+ }
289
+ }
290
+ require .NoError (t , sc .Err ())
291
+
292
+ err = sess .Signal (ssh .SIGINT )
293
+ require .NoError (t , err )
294
+
295
+ // Assumption, signal propagates and the command exists, closing stdout.
296
+ for sc .Scan () {
297
+ t .Log (sc .Text ())
298
+ require .NotContains (t , sc .Text (), "bye" )
299
+ }
300
+ require .NoError (t , sc .Err ())
301
+
302
+ err = sess .Wait ()
303
+ require .Error (t , err )
304
+ })
305
+ }
306
+
162
307
func sshClient (t * testing.T , addr string ) * ssh.Client {
163
308
conn , err := net .Dial ("tcp" , addr )
164
309
require .NoError (t , err )
0 commit comments