Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 920239f

Browse files
authored
fix: codex delegate cancellation (openai#7092)
1 parent 99bcb90 commit 920239f

2 files changed

Lines changed: 170 additions & 45 deletions

File tree

codex-rs/core/src/codex.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2436,6 +2436,9 @@ use crate::features::Features;
24362436
#[cfg(test)]
24372437
pub(crate) use tests::make_session_and_context;
24382438

2439+
#[cfg(test)]
2440+
pub(crate) use tests::make_session_and_context_with_rx;
2441+
24392442
#[cfg(test)]
24402443
mod tests {
24412444
use super::*;
@@ -2712,7 +2715,7 @@ mod tests {
27122715

27132716
// Like make_session_and_context, but returns Arc<Session> and the event receiver
27142717
// so tests can assert on emitted events.
2715-
fn make_session_and_context_with_rx() -> (
2718+
pub(crate) fn make_session_and_context_with_rx() -> (
27162719
Arc<Session>,
27172720
Arc<TurnContext>,
27182721
async_channel::Receiver<Event>,

codex-rs/core/src/codex_delegate.rs

Lines changed: 166 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ use codex_protocol::protocol::SessionSource;
1313
use codex_protocol::protocol::SubAgentSource;
1414
use codex_protocol::protocol::Submission;
1515
use codex_protocol::user_input::UserInput;
16+
use std::time::Duration;
17+
use tokio::time::timeout;
1618
use tokio_util::sync::CancellationToken;
1719

1820
use crate::AuthManager;
@@ -60,14 +62,13 @@ pub(crate) async fn run_codex_conversation_interactive(
6062
let parent_ctx_clone = Arc::clone(&parent_ctx);
6163
let codex_for_events = Arc::clone(&codex);
6264
tokio::spawn(async move {
63-
let _ = forward_events(
65+
forward_events(
6466
codex_for_events,
6567
tx_sub,
6668
parent_session_clone,
6769
parent_ctx_clone,
68-
cancel_token_events.clone(),
70+
cancel_token_events,
6971
)
70-
.or_cancel(&cancel_token_events)
7172
.await;
7273
});
7374

@@ -156,53 +157,92 @@ async fn forward_events(
156157
parent_ctx: Arc<TurnContext>,
157158
cancel_token: CancellationToken,
158159
) {
159-
while let Ok(event) = codex.next_event().await {
160-
match event {
161-
// ignore all legacy delta events
162-
Event {
163-
id: _,
164-
msg: EventMsg::AgentMessageDelta(_) | EventMsg::AgentReasoningDelta(_),
165-
} => continue,
166-
Event {
167-
id: _,
168-
msg: EventMsg::SessionConfigured(_),
169-
} => continue,
170-
Event {
171-
id,
172-
msg: EventMsg::ExecApprovalRequest(event),
173-
} => {
174-
// Initiate approval via parent session; do not surface to consumer.
175-
handle_exec_approval(
176-
&codex,
177-
id,
178-
&parent_session,
179-
&parent_ctx,
180-
event,
181-
&cancel_token,
182-
)
183-
.await;
184-
}
185-
Event {
186-
id,
187-
msg: EventMsg::ApplyPatchApprovalRequest(event),
188-
} => {
189-
handle_patch_approval(
190-
&codex,
191-
id,
192-
&parent_session,
193-
&parent_ctx,
194-
event,
195-
&cancel_token,
196-
)
197-
.await;
160+
let cancelled = cancel_token.cancelled();
161+
tokio::pin!(cancelled);
162+
163+
loop {
164+
tokio::select! {
165+
_ = &mut cancelled => {
166+
shutdown_delegate(&codex).await;
167+
break;
198168
}
199-
other => {
200-
let _ = tx_sub.send(other).await;
169+
event = codex.next_event() => {
170+
let event = match event {
171+
Ok(event) => event,
172+
Err(_) => break,
173+
};
174+
match event {
175+
// ignore all legacy delta events
176+
Event {
177+
id: _,
178+
msg: EventMsg::AgentMessageDelta(_) | EventMsg::AgentReasoningDelta(_),
179+
} => {}
180+
Event {
181+
id: _,
182+
msg: EventMsg::SessionConfigured(_),
183+
} => {}
184+
Event {
185+
id,
186+
msg: EventMsg::ExecApprovalRequest(event),
187+
} => {
188+
// Initiate approval via parent session; do not surface to consumer.
189+
handle_exec_approval(
190+
&codex,
191+
id,
192+
&parent_session,
193+
&parent_ctx,
194+
event,
195+
&cancel_token,
196+
)
197+
.await;
198+
}
199+
Event {
200+
id,
201+
msg: EventMsg::ApplyPatchApprovalRequest(event),
202+
} => {
203+
handle_patch_approval(
204+
&codex,
205+
id,
206+
&parent_session,
207+
&parent_ctx,
208+
event,
209+
&cancel_token,
210+
)
211+
.await;
212+
}
213+
other => {
214+
match tx_sub.send(other).or_cancel(&cancel_token).await {
215+
Ok(Ok(())) => {}
216+
_ => {
217+
shutdown_delegate(&codex).await;
218+
break;
219+
}
220+
}
221+
}
222+
}
201223
}
202224
}
203225
}
204226
}
205227

228+
/// Ask the delegate to stop and drain its events so background sends do not hit a closed channel.
229+
async fn shutdown_delegate(codex: &Codex) {
230+
let _ = codex.submit(Op::Interrupt).await;
231+
let _ = codex.submit(Op::Shutdown {}).await;
232+
233+
let _ = timeout(Duration::from_millis(500), async {
234+
while let Ok(event) = codex.next_event().await {
235+
if matches!(
236+
event.msg,
237+
EventMsg::TurnAborted(_) | EventMsg::TaskComplete(_)
238+
) {
239+
break;
240+
}
241+
}
242+
})
243+
.await;
244+
}
245+
206246
/// Forward ops from a caller to a sub-agent, respecting cancellation.
207247
async fn forward_ops(
208248
codex: Arc<Codex>,
@@ -298,3 +338,85 @@ where
298338
}
299339
}
300340
}
341+
342+
#[cfg(test)]
343+
mod tests {
344+
use super::*;
345+
use async_channel::bounded;
346+
use codex_protocol::models::ResponseItem;
347+
use codex_protocol::protocol::RawResponseItemEvent;
348+
use codex_protocol::protocol::TurnAbortReason;
349+
use codex_protocol::protocol::TurnAbortedEvent;
350+
use pretty_assertions::assert_eq;
351+
352+
#[tokio::test]
353+
async fn forward_events_cancelled_while_send_blocked_shuts_down_delegate() {
354+
let (tx_events, rx_events) = bounded(1);
355+
let (tx_sub, rx_sub) = bounded(SUBMISSION_CHANNEL_CAPACITY);
356+
let codex = Arc::new(Codex {
357+
next_id: AtomicU64::new(0),
358+
tx_sub,
359+
rx_event: rx_events,
360+
});
361+
362+
let (session, ctx, _rx_evt) = crate::codex::make_session_and_context_with_rx();
363+
364+
let (tx_out, rx_out) = bounded(1);
365+
tx_out
366+
.send(Event {
367+
id: "full".to_string(),
368+
msg: EventMsg::TurnAborted(TurnAbortedEvent {
369+
reason: TurnAbortReason::Interrupted,
370+
}),
371+
})
372+
.await
373+
.unwrap();
374+
375+
let cancel = CancellationToken::new();
376+
let forward = tokio::spawn(forward_events(
377+
Arc::clone(&codex),
378+
tx_out.clone(),
379+
session,
380+
ctx,
381+
cancel.clone(),
382+
));
383+
384+
tx_events
385+
.send(Event {
386+
id: "evt".to_string(),
387+
msg: EventMsg::RawResponseItem(RawResponseItemEvent {
388+
item: ResponseItem::CustomToolCall {
389+
id: None,
390+
status: None,
391+
call_id: "call-1".to_string(),
392+
name: "tool".to_string(),
393+
input: "{}".to_string(),
394+
},
395+
}),
396+
})
397+
.await
398+
.unwrap();
399+
400+
drop(tx_events);
401+
cancel.cancel();
402+
timeout(std::time::Duration::from_millis(1000), forward)
403+
.await
404+
.expect("forward_events hung")
405+
.expect("forward_events join error");
406+
407+
let received = rx_out.recv().await.expect("prefilled event missing");
408+
assert_eq!("full", received.id);
409+
let mut ops = Vec::new();
410+
while let Ok(sub) = rx_sub.try_recv() {
411+
ops.push(sub.op);
412+
}
413+
assert!(
414+
ops.iter().any(|op| matches!(op, Op::Interrupt)),
415+
"expected Interrupt op after cancellation"
416+
);
417+
assert!(
418+
ops.iter().any(|op| matches!(op, Op::Shutdown)),
419+
"expected Shutdown op after cancellation"
420+
);
421+
}
422+
}

0 commit comments

Comments
 (0)