@@ -1025,16 +1025,32 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
10251025 }()
10261026
10271027 var rpty * reconnectingPTY
1028- rawRPTY , ok := a .reconnectingPTYs .Load (msg .ID )
1028+ sendConnected := make (chan * reconnectingPTY , 1 )
1029+ // On store, reserve this ID to prevent multiple concurrent new connections.
1030+ waitReady , ok := a .reconnectingPTYs .LoadOrStore (msg .ID , sendConnected )
10291031 if ok {
1032+ close (sendConnected ) // Unused.
10301033 logger .Debug (ctx , "connecting to existing session" )
1031- rpty , ok = rawRPTY .( * reconnectingPTY )
1034+ c , ok := waitReady .( chan * reconnectingPTY )
10321035 if ! ok {
1033- return xerrors .Errorf ("found invalid type in reconnecting pty map: %T" , rawRPTY )
1036+ return xerrors .Errorf ("found invalid type in reconnecting pty map: %T" , waitReady )
10341037 }
1038+ rpty , ok = <- c
1039+ if ! ok || rpty == nil {
1040+ return xerrors .Errorf ("reconnecting pty closed before connection" )
1041+ }
1042+ c <- rpty // Put it back for the next reconnect.
10351043 } else {
10361044 logger .Debug (ctx , "creating new session" )
10371045
1046+ connected := false
1047+ defer func () {
1048+ if ! connected && retErr != nil {
1049+ a .reconnectingPTYs .Delete (msg .ID )
1050+ close (sendConnected )
1051+ }
1052+ }()
1053+
10381054 // Empty command will default to the users shell!
10391055 cmd , err := a .sshServer .CreateCommand (ctx , msg .Command , nil )
10401056 if err != nil {
@@ -1055,7 +1071,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
10551071 return xerrors .Errorf ("start command: %w" , err )
10561072 }
10571073
1058- ctx , cancelFunc := context .WithCancel (ctx )
1074+ ctx , cancel := context .WithCancel (ctx )
10591075 rpty = & reconnectingPTY {
10601076 activeConns : map [string ]net.Conn {
10611077 // We have to put the connection in the map instantly otherwise
@@ -1064,10 +1080,9 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
10641080 },
10651081 ptty : ptty ,
10661082 // Timeouts created with an after func can be reset!
1067- timeout : time .AfterFunc (a .reconnectingPTYTimeout , cancelFunc ),
1083+ timeout : time .AfterFunc (a .reconnectingPTYTimeout , cancel ),
10681084 circularBuffer : circularBuffer ,
10691085 }
1070- a .reconnectingPTYs .Store (msg .ID , rpty )
10711086 // We don't need to separately monitor for the process exiting.
10721087 // When it exits, our ptty.OutputReader() will return EOF after
10731088 // reading all process output.
@@ -1115,8 +1130,12 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
11151130 rpty .Close ()
11161131 a .reconnectingPTYs .Delete (msg .ID )
11171132 }); err != nil {
1133+ _ = process .Kill ()
1134+ _ = ptty .Close ()
11181135 return xerrors .Errorf ("start routine: %w" , err )
11191136 }
1137+ connected = true
1138+ sendConnected <- rpty
11201139 }
11211140 // Resize the PTY to initial height + width.
11221141 err := rpty .ptty .Resize (msg .Height , msg .Width )
0 commit comments