@@ -1025,16 +1025,32 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
1025
1025
}()
1026
1026
1027
1027
var rpty * reconnectingPTY
1028
- rawRPTY , ok := a .reconnectingPTYs .Load (msg .ID )
1028
+ sendConnected := make (chan * reconnectingPTY , 1 )
1029
+ // On store, reserve this ID to prevent multiple concurrent new connections.
1030
+ waitReady , ok := a .reconnectingPTYs .LoadOrStore (msg .ID , sendConnected )
1029
1031
if ok {
1032
+ close (sendConnected ) // Unused.
1030
1033
logger .Debug (ctx , "connecting to existing session" )
1031
- rpty , ok = rawRPTY .( * reconnectingPTY )
1034
+ c , ok := waitReady .( chan * reconnectingPTY )
1032
1035
if ! ok {
1033
- return xerrors .Errorf ("found invalid type in reconnecting pty map: %T" , rawRPTY )
1036
+ return xerrors .Errorf ("found invalid type in reconnecting pty map: %T" , waitReady )
1034
1037
}
1038
+ rpty , ok = <- c
1039
+ if ! ok || rpty == nil {
1040
+ return xerrors .Errorf ("reconnecting pty closed before connection" )
1041
+ }
1042
+ c <- rpty // Put it back for the next reconnect.
1035
1043
} else {
1036
1044
logger .Debug (ctx , "creating new session" )
1037
1045
1046
+ connected := false
1047
+ defer func () {
1048
+ if ! connected && retErr != nil {
1049
+ a .reconnectingPTYs .Delete (msg .ID )
1050
+ close (sendConnected )
1051
+ }
1052
+ }()
1053
+
1038
1054
// Empty command will default to the users shell!
1039
1055
cmd , err := a .sshServer .CreateCommand (ctx , msg .Command , nil )
1040
1056
if err != nil {
@@ -1055,7 +1071,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
1055
1071
return xerrors .Errorf ("start command: %w" , err )
1056
1072
}
1057
1073
1058
- ctx , cancelFunc := context .WithCancel (ctx )
1074
+ ctx , cancel := context .WithCancel (ctx )
1059
1075
rpty = & reconnectingPTY {
1060
1076
activeConns : map [string ]net.Conn {
1061
1077
// We have to put the connection in the map instantly otherwise
@@ -1064,10 +1080,9 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
1064
1080
},
1065
1081
ptty : ptty ,
1066
1082
// Timeouts created with an after func can be reset!
1067
- timeout : time .AfterFunc (a .reconnectingPTYTimeout , cancelFunc ),
1083
+ timeout : time .AfterFunc (a .reconnectingPTYTimeout , cancel ),
1068
1084
circularBuffer : circularBuffer ,
1069
1085
}
1070
- a .reconnectingPTYs .Store (msg .ID , rpty )
1071
1086
// We don't need to separately monitor for the process exiting.
1072
1087
// When it exits, our ptty.OutputReader() will return EOF after
1073
1088
// reading all process output.
@@ -1115,8 +1130,12 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
1115
1130
rpty .Close ()
1116
1131
a .reconnectingPTYs .Delete (msg .ID )
1117
1132
}); err != nil {
1133
+ _ = process .Kill ()
1134
+ _ = ptty .Close ()
1118
1135
return xerrors .Errorf ("start routine: %w" , err )
1119
1136
}
1137
+ connected = true
1138
+ sendConnected <- rpty
1120
1139
}
1121
1140
// Resize the PTY to initial height + width.
1122
1141
err := rpty .ptty .Resize (msg .Height , msg .Width )
0 commit comments