Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit c916a9e

Browse files
authored
fix(agent): guard against multiple rpty race for same id (coder#7998)
* fix(agent): guard against multiple rpty race for same id * fix(agent): ensure pty is closed on error
1 parent 9440b3d commit c916a9e

File tree

1 file changed

+25
-6
lines changed

1 file changed

+25
-6
lines changed

agent/agent.go

+25-6
Original file line numberDiff line numberDiff line change
@@ -1025,16 +1025,32 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
10251025
}()
10261026

10271027
var rpty *reconnectingPTY
1028-
rawRPTY, ok := a.reconnectingPTYs.Load(msg.ID)
1028+
sendConnected := make(chan *reconnectingPTY, 1)
1029+
// On store, reserve this ID to prevent multiple concurrent new connections.
1030+
waitReady, ok := a.reconnectingPTYs.LoadOrStore(msg.ID, sendConnected)
10291031
if ok {
1032+
close(sendConnected) // Unused.
10301033
logger.Debug(ctx, "connecting to existing session")
1031-
rpty, ok = rawRPTY.(*reconnectingPTY)
1034+
c, ok := waitReady.(chan *reconnectingPTY)
10321035
if !ok {
1033-
return xerrors.Errorf("found invalid type in reconnecting pty map: %T", rawRPTY)
1036+
return xerrors.Errorf("found invalid type in reconnecting pty map: %T", waitReady)
10341037
}
1038+
rpty, ok = <-c
1039+
if !ok || rpty == nil {
1040+
return xerrors.Errorf("reconnecting pty closed before connection")
1041+
}
1042+
c <- rpty // Put it back for the next reconnect.
10351043
} else {
10361044
logger.Debug(ctx, "creating new session")
10371045

1046+
connected := false
1047+
defer func() {
1048+
if !connected && retErr != nil {
1049+
a.reconnectingPTYs.Delete(msg.ID)
1050+
close(sendConnected)
1051+
}
1052+
}()
1053+
10381054
// Empty command will default to the users shell!
10391055
cmd, err := a.sshServer.CreateCommand(ctx, msg.Command, nil)
10401056
if err != nil {
@@ -1055,7 +1071,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
10551071
return xerrors.Errorf("start command: %w", err)
10561072
}
10571073

1058-
ctx, cancelFunc := context.WithCancel(ctx)
1074+
ctx, cancel := context.WithCancel(ctx)
10591075
rpty = &reconnectingPTY{
10601076
activeConns: map[string]net.Conn{
10611077
// We have to put the connection in the map instantly otherwise
@@ -1064,10 +1080,9 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
10641080
},
10651081
ptty: ptty,
10661082
// Timeouts created with an after func can be reset!
1067-
timeout: time.AfterFunc(a.reconnectingPTYTimeout, cancelFunc),
1083+
timeout: time.AfterFunc(a.reconnectingPTYTimeout, cancel),
10681084
circularBuffer: circularBuffer,
10691085
}
1070-
a.reconnectingPTYs.Store(msg.ID, rpty)
10711086
// We don't need to separately monitor for the process exiting.
10721087
// When it exits, our ptty.OutputReader() will return EOF after
10731088
// reading all process output.
@@ -1115,8 +1130,12 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
11151130
rpty.Close()
11161131
a.reconnectingPTYs.Delete(msg.ID)
11171132
}); err != nil {
1133+
_ = process.Kill()
1134+
_ = ptty.Close()
11181135
return xerrors.Errorf("start routine: %w", err)
11191136
}
1137+
connected = true
1138+
sendConnected <- rpty
11201139
}
11211140
// Resize the PTY to initial height + width.
11221141
err := rpty.ptty.Resize(msg.Height, msg.Width)

0 commit comments

Comments
 (0)