@@ -124,6 +124,7 @@ type Server struct {
124
124
listeners map [net.Listener ]struct {}
125
125
conns map [net.Conn ]struct {}
126
126
sessions map [ssh.Session ]struct {}
127
+ processes map [* os.Process ]struct {}
127
128
closing chan struct {}
128
129
// Wait for goroutines to exit, waited without
129
130
// a lock on mu but protected by closing.
@@ -182,6 +183,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
182
183
fs : fs ,
183
184
conns : make (map [net.Conn ]struct {}),
184
185
sessions : make (map [ssh.Session ]struct {}),
186
+ processes : make (map [* os.Process ]struct {}),
185
187
logger : logger ,
186
188
187
189
config : config ,
@@ -586,7 +588,10 @@ func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, mag
586
588
// otherwise context cancellation will not propagate properly
587
589
// and SSH server close may be delayed.
588
590
cmd .SysProcAttr = cmdSysProcAttr ()
589
- cmd .Cancel = cmdCancel (session .Context (), logger , cmd )
591
+
592
+ // to match OpenSSH, we don't actually tear a non-TTY command down, even if the session ends.
593
+ // c.f. https://github.com/coder/coder/issues/18519#issuecomment-3019118271
594
+ cmd .Cancel = nil
590
595
591
596
cmd .Stdout = session
592
597
cmd .Stderr = session .Stderr ()
@@ -609,6 +614,16 @@ func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, mag
609
614
s .metrics .sessionErrors .WithLabelValues (magicTypeLabel , "no" , "start_command" ).Add (1 )
610
615
return xerrors .Errorf ("start: %w" , err )
611
616
}
617
+
618
+ // Since we don't cancel the process when the session stops, we still need to tear it down if we are closing. So
619
+ // track it here.
620
+ if ! s .trackProcess (cmd .Process , true ) {
621
+ // must be closing
622
+ err = cmdCancel (logger , cmd .Process )
623
+ return xerrors .Errorf ("failed to track process: %w" , err )
624
+ }
625
+ defer s .trackProcess (cmd .Process , false )
626
+
612
627
sigs := make (chan ssh.Signal , 1 )
613
628
session .Signals (sigs )
614
629
defer func () {
@@ -1052,6 +1067,27 @@ func (s *Server) trackSession(ss ssh.Session, add bool) (ok bool) {
1052
1067
return true
1053
1068
}
1054
1069
1070
+ // trackProcess registers the process with the server when add is true and unregisters it when add is false. If the server is
1071
+ // closing, the process is not registered, ok is false, and the caller should terminate the process itself.
1072
+ //
1073
+ //nolint:revive
1074
+ func (s * Server ) trackProcess (p * os.Process , add bool ) (ok bool ) {
1075
+ s .mu .Lock ()
1076
+ defer s .mu .Unlock ()
1077
+ if add {
1078
+ if s .closing != nil {
1079
+ // Server is closing: refuse to register so the wait group does not grow.
1080
+ return false
1081
+ }
1082
+ s .wg .Add (1 )
1083
+ s .processes [p ] = struct {}{}
1084
+ return true
1085
+ }
1086
+ s .wg .Done ()
1087
+ delete (s .processes , p )
1088
+ return true
1089
+ }
1090
+
1055
1091
// Close the server and all active connections. Server can be re-used
1056
1092
// after Close is done.
1057
1093
func (s * Server ) Close () error {
@@ -1091,6 +1127,10 @@ func (s *Server) Close() error {
1091
1127
_ = c .Close ()
1092
1128
}
1093
1129
1130
+ for p := range s .processes {
1131
+ _ = cmdCancel (s .logger , p )
1132
+ }
1133
+
1094
1134
s .logger .Debug (ctx , "closing SSH server" )
1095
1135
err := s .srv .Close ()
1096
1136
0 commit comments