diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index f8f0c5968b..95655e4455 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -75,46 +75,46 @@ jobs: - name: Run tests run: cargo test --locked --workspace --lib --bins --test '*' --exclude fig_desktop-fuzz - cargo-clippy-windows-chat-cli: - name: Clippy Windows (chat_cli) - runs-on: windows-latest - timeout-minutes: 60 - steps: - - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@1.84.0 - id: toolchain - with: - components: clippy - - uses: actions/cache@v4 - with: - path: | - ~/.cargo/registry/index/ - ~/.cargo/registry/cache/ - ~/.cargo/git/db/ - target/ - key: cargo-clippy-windows-chat-cli-${{ hashFiles('**/Cargo.lock') }}-${{ steps.toolchain.outputs.cachekey }} - - run: cargo clippy --locked -p chat_cli --color always -- -D warnings + # cargo-clippy-windows-chat-cli: + # name: Clippy Windows (chat_cli) + # runs-on: windows-latest + # timeout-minutes: 60 + # steps: + # - uses: actions/checkout@v4 + # - uses: dtolnay/rust-toolchain@1.84.0 + # id: toolchain + # with: + # components: clippy + # - uses: actions/cache@v4 + # with: + # path: | + # ~/.cargo/registry/index/ + # ~/.cargo/registry/cache/ + # ~/.cargo/git/db/ + # target/ + # key: cargo-clippy-windows-chat-cli-${{ hashFiles('**/Cargo.lock') }}-${{ steps.toolchain.outputs.cachekey }} + # - run: cargo clippy --locked -p chat_cli --color always -- -D warnings - cargo-test-windows-chat-cli: - name: Test Windows (chat_cli) - runs-on: windows-latest - timeout-minutes: 60 - steps: - - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@nightly - id: toolchain - with: - components: llvm-tools-preview - - uses: actions/cache@v4 - with: - path: | - ~/.cargo/registry/index/ - ~/.cargo/registry/cache/ - ~/.cargo/git/db/ - target/ - key: cargo-test-windows-chat-cli-${{ hashFiles('**/Cargo.lock') }}-${{ steps.toolchain.outputs.cachekey }} - - name: Run tests - run: cargo test --locked -p chat_cli + # cargo-test-windows-chat-cli: + # name: Test Windows (chat_cli) + # runs-on: windows-latest + # timeout-minutes: 60 + # steps: + # - uses: actions/checkout@v4 + # - uses: dtolnay/rust-toolchain@nightly + # id: toolchain + # with: + # components: llvm-tools-preview + # - uses: actions/cache@v4 + # with: + # path: | + # ~/.cargo/registry/index/ + # ~/.cargo/registry/cache/ + # ~/.cargo/git/db/ + # target/ + # key: cargo-test-windows-chat-cli-${{ hashFiles('**/Cargo.lock') }}-${{ steps.toolchain.outputs.cachekey }} + # - name: Run tests + # run: cargo test --locked -p chat_cli cargo-fmt: name: Fmt diff --git a/Cargo.lock b/Cargo.lock index 73a6f2ea8f..41980c82f1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1447,6 +1447,7 @@ dependencies = [ name = "chat_cli" version = "1.19.3" dependencies = [ + "agent", "amzn-codewhisperer-client", "amzn-codewhisperer-streaming-client", "amzn-consolas-client", @@ -1549,6 +1550,7 @@ dependencies = [ "thiserror 2.0.17", "time", "tokio", + "tokio-stream", "tokio-tungstenite", "tokio-util", "toml", diff --git a/Cargo.toml b/Cargo.toml index 7120e92c82..d71772ecf3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ version = "1.19.3" license = "MIT OR Apache-2.0" [workspace.dependencies] +agent = { path = "crates/agent" } amzn-codewhisperer-client = { path = "crates/amzn-codewhisperer-client" } amzn-codewhisperer-streaming-client = { path = "crates/amzn-codewhisperer-streaming-client" } amzn-consolas-client = { path = "crates/amzn-consolas-client" } @@ -112,6 +113,7 @@ tempfile = "3.18.0" thiserror = 
"2.0.12" time = { version = "0.3.39", features = ["parsing", "formatting", "local-offset", "macros", "serde"] } tokio = { version = "1.45.0", features = ["full"] } +tokio-stream = { version = "0.1.17", features = ["io-util"] } tokio-tungstenite = "0.26.2" tokio-util = { version = "0.7.15", features = ["codec", "compat"] } toml = "0.8.12" diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml index 4e568a62dc..08f219024e 100644 --- a/crates/agent/Cargo.toml +++ b/crates/agent/Cargo.toml @@ -69,7 +69,7 @@ tempfile.workspace = true thiserror.workspace = true time.workspace = true tokio.workspace = true -tokio-stream = { version = "0.1.17", features = ["io-util"] } +tokio-stream.workspace = true tokio-util.workspace = true tracing.workspace = true tracing-appender = "0.2.3" diff --git a/crates/chat-cli/Cargo.toml b/crates/chat-cli/Cargo.toml index 52725c72ef..68370bc3d0 100644 --- a/crates/chat-cli/Cargo.toml +++ b/crates/chat-cli/Cargo.toml @@ -16,6 +16,7 @@ default = [] wayland = ["arboard/wayland-data-control"] [dependencies] +agent.workspace = true amzn-codewhisperer-client.workspace = true amzn-codewhisperer-streaming-client.workspace = true amzn-consolas-client.workspace = true @@ -102,6 +103,7 @@ tempfile.workspace = true thiserror.workspace = true time.workspace = true tokio.workspace = true +tokio-stream.workspace = true tokio-tungstenite.workspace = true tokio-util.workspace = true toml.workspace = true diff --git a/crates/chat-cli/src/agent/mod.rs b/crates/chat-cli/src/agent/mod.rs new file mode 100644 index 0000000000..7633acd8e9 --- /dev/null +++ b/crates/chat-cli/src/agent/mod.rs @@ -0,0 +1 @@ +mod rts; diff --git a/crates/chat-cli/src/agent/rts/mod.rs b/crates/chat-cli/src/agent/rts/mod.rs new file mode 100644 index 0000000000..2648658386 --- /dev/null +++ b/crates/chat-cli/src/agent/rts/mod.rs @@ -0,0 +1,773 @@ +#![allow(dead_code)] + +pub mod types; + +use std::pin::Pin; +use std::sync::Arc; +use std::time::{ + Duration, + Instant, +}; + +use agent::agent_loop::model::Model; +use agent::agent_loop::protocol::StreamResult; +use agent::agent_loop::types::{ + ContentBlock, + ContentBlockDelta, + ContentBlockDeltaEvent, + ContentBlockStart, + ContentBlockStartEvent, + ContentBlockStopEvent, + Message, + MessageStartEvent, + MessageStopEvent, + MetadataEvent, + MetadataMetrics, + MetadataService, + Role, + StopReason, + StreamError, + StreamErrorKind, + StreamErrorSource, + StreamEvent, + ToolResultContentBlock, + ToolSpec, + ToolUseBlockDelta, + ToolUseBlockStart, +}; +use chrono::{ + DateTime, + Utc, +}; +use eyre::Result; +use futures::Stream; +use serde::{ + Deserialize, + Serialize, +}; +use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; +use tokio_util::sync::CancellationToken; +use tracing::{ + debug, + error, + info, + trace, + warn, +}; +use uuid::Uuid; + +use crate::api_client::error::{ + ApiClientError, + ConverseStreamError, + ConverseStreamErrorKind, +}; +use crate::api_client::model::{ + ChatResponseStream, + ConversationState, + ToolSpecification, + UserInputMessage, + UserInputMessageContext, +}; +use crate::api_client::send_message_output::SendMessageOutput; +use crate::api_client::{ + ApiClient, + model as rts, +}; +use crate::cli::chat::util::serde_value_to_document; + +/// A [Model] implementation using the RTS backend. 
+#[derive(Debug, Clone)] +pub struct RtsModel { + client: ApiClient, + conversation_id: Uuid, + model_id: Option, +} + +impl RtsModel { + pub fn new(client: ApiClient, conversation_id: Uuid, model_id: Option) -> Self { + Self { + client, + conversation_id, + model_id, + } + } + + pub fn conversation_id(&self) -> &Uuid { + &self.conversation_id + } + + pub fn model_id(&self) -> Option<&str> { + self.model_id.as_deref() + } + + async fn converse_stream_rts( + self, + tx: mpsc::Sender, + cancel_token: CancellationToken, + messages: Vec, + tool_specs: Option>, + system_prompt: Option, + ) { + let state = match self.make_conversation_state(messages, tool_specs, system_prompt) { + Ok(s) => s, + Err(msg) => { + error!(?msg, "failed to create conversation state"); + tx.send(StreamResult::Err(StreamError::new(StreamErrorKind::Validation { + message: Some(msg), + }))) + .await + .map_err(|err| error!(?err, "failed to send model event")) + .ok(); + return; + }, + }; + + let request_start_time = Instant::now(); + let request_start_time_sys = Utc::now(); + let token_clone = cancel_token.clone(); + let result = tokio::select! { + _ = token_clone.cancelled() => { + warn!("rts request cancelled during send"); + tx.send(StreamResult::Err(StreamError::new(StreamErrorKind::Interrupted))) + .await + .map_err(|err| (error!(?err, "failed to send event"))) + .ok(); + return; + }, + result = self.client.send_message(state) => { + result + } + }; + self.handle_send_message_output( + result, + request_start_time.elapsed(), + tx, + cancel_token, + request_start_time, + request_start_time_sys, + ) + .await; + } + + async fn handle_send_message_output( + &self, + res: Result, + request_duration: Duration, + tx: mpsc::Sender, + token: CancellationToken, + request_start_time: Instant, + request_start_time_sys: DateTime, + ) { + match res { + Ok(output) => { + info!(?request_duration, "rts request sent successfully"); + let request_id = output.request_id().map(String::from); + ResponseParser::new( + output, + tx, + token, + request_id, + request_start_time, + request_start_time_sys, + ) + .consume_stream() + .await; + }, + Err(err) => { + error!(?err, ?request_duration, "failed to send rts request"); + let kind = match err.kind { + ConverseStreamErrorKind::Throttling => StreamErrorKind::Throttling, + ConverseStreamErrorKind::MonthlyLimitReached => StreamErrorKind::Other(err.to_string()), + ConverseStreamErrorKind::ContextWindowOverflow => StreamErrorKind::ContextWindowOverflow, + ConverseStreamErrorKind::ModelOverloadedError => StreamErrorKind::Throttling, + ConverseStreamErrorKind::Unknown { .. } => StreamErrorKind::Other(err.to_string()), + }; + let request_id = err.request_id.clone(); + tx.send(StreamResult::Err( + StreamError::new(kind) + .set_original_request_id(request_id) + .set_original_status_code(err.status_code) + .with_source(Arc::new(err)), + )) + .await + .map_err(|err| error!(?err, "failed to send stream event")) + .ok(); + }, + } + } + + fn make_conversation_state( + &self, + mut messages: Vec, + tool_specs: Option>, + _system_prompt: Option, + ) -> Result { + debug!(?messages, ?tool_specs, "creating conversation state"); + let tools = tool_specs.map(|v| { + v.into_iter() + .map(Into::::into) + .map(Into::into) + .collect() + }); + + // Creates the next user message to send. 
+ let user_input_message = match messages.pop() { + Some(m) if m.role == Role::User => { + let content = m.text(); + let (tool_results, images) = extract_tool_results_and_images(&m); + let user_input_message_context = Some(UserInputMessageContext { + env_state: None, + git_state: None, + tool_results, + tools, + }); + + UserInputMessage { + content, + user_input_message_context, + user_intent: None, + images, + model_id: self.model_id.clone(), + } + }, + Some(m) => return Err(format!("Next message must be from the user, instead found: {}", m.role)), + None => return Err("Empty conversation".to_string()), + }; + + let history = messages + .into_iter() + .map(|m| match m.role { + Role::User => { + let content = m.text(); + let (tool_results, _) = extract_tool_results_and_images(&m); + let ctx = if tool_results.is_some() { + Some(UserInputMessageContext { + env_state: None, + git_state: None, + tool_results, + tools: None, + }) + } else { + None + }; + let msg = UserInputMessage { + content, + user_input_message_context: ctx, + user_intent: None, + images: None, + model_id: None, + }; + rts::ChatMessage::UserInputMessage(msg) + }, + Role::Assistant => { + let msg = rts::AssistantResponseMessage { + message_id: m.id.clone(), + content: m.text(), + tool_uses: m.tool_uses().map(|v| v.into_iter().map(Into::into).collect()), + }; + rts::ChatMessage::AssistantResponseMessage(msg) + }, + }) + .collect(); + + Ok(ConversationState { + conversation_id: Some(self.conversation_id.to_string()), + user_input_message, + history: Some(history), + }) + } +} + +impl StreamErrorSource for ConverseStreamError { + fn as_any(&self) -> &dyn std::any::Any { + self + } +} + +impl StreamErrorSource for ApiClientError { + fn as_any(&self) -> &dyn std::any::Any { + self + } +} + +/// Annoyingly, the RTS API doesn't allow images as tool use results, so we have to extract tool +/// results and image content separately. 
+fn extract_tool_results_and_images(message: &Message) -> (Option>, Option>) { + let mut images = Vec::new(); + let mut tool_results = Vec::new(); + for item in &message.content { + match item { + ContentBlock::ToolResult(block) => { + let tool_use_id = block.tool_use_id.clone(); + let status = block.status.into(); + let mut content = Vec::new(); + for c in &block.content { + match c { + ToolResultContentBlock::Text(t) => content.push(rts::ToolResultContentBlock::Text(t.clone())), + ToolResultContentBlock::Json(v) => { + content.push(rts::ToolResultContentBlock::Json(serde_value_to_document(v.clone()))); + }, + ToolResultContentBlock::Image(img) => images.push(rts::ImageBlock { + format: img.format.into(), + source: img.source.clone().into(), + }), + } + } + tool_results.push(rts::ToolResult { + tool_use_id, + content, + status, + }); + }, + ContentBlock::Image(img) => images.push(rts::ImageBlock { + format: img.format.into(), + source: img.source.clone().into(), + }), + _ => (), + } + } + + ( + if tool_results.is_empty() { + None + } else { + Some(tool_results) + }, + if images.is_empty() { None } else { Some(images) }, + ) +} + +impl Model for RtsModel { + fn stream( + &self, + messages: Vec, + tool_specs: Option>, + system_prompt: Option, + cancel_token: CancellationToken, + ) -> Pin + Send + 'static>> { + let (tx, rx) = mpsc::channel(16); + + let self_clone = self.clone(); + let cancel_token_clone = cancel_token.clone(); + + tokio::spawn(async move { + self_clone + .converse_stream_rts(tx, cancel_token_clone, messages, tool_specs, system_prompt) + .await; + }); + + Box::pin(ReceiverStream::new(rx)) + } +} + +/// Contains only the serializable data associated with [RtsModel]. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RtsModelState { + pub conversation_id: Uuid, + pub model_id: Option, +} + +impl RtsModelState { + pub fn new() -> Self { + Self { + conversation_id: Uuid::new_v4(), + model_id: None, + } + } +} + +impl Default for RtsModelState { + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug)] +struct ResponseParser { + /// The response to consume and parse into a sequence of [StreamEvent]. + response: SendMessageOutput, + event_tx: mpsc::Sender, + cancel_token: CancellationToken, + + /// Buffer that is continually written to during stream parsing. + buf: Vec, + + // parse state + /// Whether or not the stream has completed. + ended: bool, + /// Buffer to hold the next event in [SendMessageOutput]. + /// + /// Required since the RTS stream needs 1 look-ahead token to ensure we don't emit assistant + /// response events that are immediately followed by a code reference event. + peek: Option, + /// Whether or not we have sent a [MessageStartEvent]. + message_start_pushed: bool, + /// Whether or not we are currently receiving tool use delta events. Tuple of + /// `Some((tool_use_id, name))` if true, [None] otherwise. + parsing_tool_use: Option<(String, String)>, + /// Whether or not the response stream contained at least one tool use. + tool_use_seen: bool, + + // metadata fields + request_id: Option, + /// Time immediately before sending the request. + request_start_time: Instant, + /// Time immediately before sending the request, as a [SystemTime]. + request_start_time_sys: DateTime, + time_to_first_chunk: Option, + time_between_chunks: Vec, + /// Total size (in bytes) of the response received so far. 
+ received_response_size: usize, +} + +impl ResponseParser { + fn new( + response: SendMessageOutput, + event_tx: mpsc::Sender, + cancel_token: CancellationToken, + request_id: Option, + request_start_time: Instant, + request_start_time_sys: DateTime, + ) -> Self { + Self { + response, + event_tx, + cancel_token, + ended: false, + peek: None, + message_start_pushed: false, + parsing_tool_use: None, + tool_use_seen: false, + buf: vec![], + time_to_first_chunk: None, + time_between_chunks: vec![], + request_id, + request_start_time, + request_start_time_sys, + received_response_size: 0, + } + } + + /// Consumes the entire response stream, emitting [StreamEvent] and [StreamError], or exiting + /// early if [Self::cancel_token] is cancelled. + /// + /// In either case, metadata regarding the stream is emitted with a [StreamEvent::Metadata]. + async fn consume_stream(mut self) { + loop { + if self.ended { + debug!("rts response stream has ended"); + return; + } + + let token = self.cancel_token.clone(); + tokio::select! { + _ = token.cancelled() => { + debug!("rts response parser was cancelled"); + self.buf.push(StreamResult::Ok(self.make_metadata())); + self.buf.push(StreamResult::Err(StreamError::new(StreamErrorKind::Interrupted))); + self.drain_buf_events().await; + return; + }, + res = self.fill_streamevent_buf() => { + match res { + Ok(_) => { + self.drain_buf_events().await; + }, + Err(err) => { + self.buf.push(StreamResult::Ok(self.make_metadata())); + self.buf.push(StreamResult::Err(self.recv_error_to_stream_error(err))); + self.drain_buf_events().await; + return; + }, + } + } + } + } + } + + async fn drain_buf_events(&mut self) { + for ev in self.buf.drain(..) { + self.event_tx + .send(ev) + .await + .map_err(|err| error!(?err, "failed to send event to channel")) + .ok(); + } + } + + /// Consumes the next token(s) in the response stream, filling [Self::buf] with the stream + /// events to be emitted, sequentially. + /// + /// We only consume the stream in parts in order to ensure we exit in a timely manner if + /// [Self::cancel_token] is cancelled. + async fn fill_streamevent_buf(&mut self) -> Result<(), RecvError> { + // First, handle discarding AssistantResponseEvent's that immediately precede a + // CodeReferenceEvent. + let peek = self.peek().await?; + if let Some(ChatResponseStream::AssistantResponseEvent { content }) = peek { + // Cloning to bypass borrowchecker stuff. + let content = content.clone(); + self.next().await?; + match self.peek().await? { + Some(ChatResponseStream::CodeReferenceEvent(_)) => (), + _ => { + self.buf.push(StreamResult::Ok(StreamEvent::ContentBlockDelta( + ContentBlockDeltaEvent { + delta: ContentBlockDelta::Text(content), + content_block_index: None, + }, + ))); + }, + } + } + + loop { + match self.next().await? 
{ + Some(ev) => match ev { + ChatResponseStream::AssistantResponseEvent { content } => { + self.buf.push(StreamResult::Ok(StreamEvent::ContentBlockDelta( + ContentBlockDeltaEvent { + delta: ContentBlockDelta::Text(content), + content_block_index: None, + }, + ))); + return Ok(()); + }, + ChatResponseStream::ToolUseEvent { + tool_use_id, + name, + input, + stop, + } => { + self.tool_use_seen = true; + if self.parsing_tool_use.is_none() { + self.parsing_tool_use = Some((tool_use_id.clone(), name.clone())); + self.buf.push(StreamResult::Ok(StreamEvent::ContentBlockStart( + ContentBlockStartEvent { + content_block_start: Some(ContentBlockStart::ToolUse(ToolUseBlockStart { + tool_use_id, + name, + })), + content_block_index: None, + }, + ))); + } + if let Some(input) = input { + self.buf.push(StreamResult::Ok(StreamEvent::ContentBlockDelta( + ContentBlockDeltaEvent { + delta: ContentBlockDelta::ToolUse(ToolUseBlockDelta { input }), + content_block_index: None, + }, + ))); + } + if let Some(true) = stop { + self.buf + .push(StreamResult::Ok(StreamEvent::ContentBlockStop(ContentBlockStopEvent { + content_block_index: None, + }))); + self.parsing_tool_use = None; + } + return Ok(()); + }, + other => { + warn!(?other, "received unexpected rts event"); + }, + }, + None => { + self.ended = true; + self.buf + .push(StreamResult::Ok(StreamEvent::MessageStop(MessageStopEvent { + stop_reason: if self.tool_use_seen { + StopReason::ToolUse + } else { + StopReason::EndTurn + }, + }))); + self.buf.push(StreamResult::Ok(self.make_metadata())); + return Ok(()); + }, + } + } + } + + async fn peek(&mut self) -> Result, RecvError> { + if self.peek.is_some() { + return Ok(self.peek.as_ref()); + } + match self.next().await? { + Some(v) => { + self.peek = Some(v); + Ok(self.peek.as_ref()) + }, + None => Ok(None), + } + } + + async fn next(&mut self) -> Result, RecvError> { + if let Some(ev) = self.peek.take() { + return Ok(Some(ev)); + } + + trace!("Attempting to recv next event"); + let start = Instant::now(); + let result = self.response.recv().await; + let duration = Instant::now().duration_since(start); + match result { + Ok(ev) => { + trace!(?ev, "Received new event"); + + if !self.message_start_pushed { + self.buf + .push(StreamResult::Ok(StreamEvent::MessageStart(MessageStartEvent { + role: Role::Assistant, + }))); + self.message_start_pushed = true; + } + + // Track metadata about the chunk. 
+ self.time_to_first_chunk + .get_or_insert_with(|| self.request_start_time.elapsed()); + self.time_between_chunks.push(duration); + self.received_response_size += ev.as_ref().map(|e| e.len()).unwrap_or_default(); + + Ok(ev) + }, + Err(err) => { + error!(?err, "failed to receive the next event"); + if duration.as_secs() >= 59 { + Err(RecvError::Timeout { source: err, duration }) + } else { + Err(RecvError::Other { source: err }) + } + }, + } + } + + fn recv_error_to_stream_error(&self, err: RecvError) -> StreamError { + match err { + RecvError::Timeout { source, duration } => StreamError::new(StreamErrorKind::StreamTimeout { duration }) + .set_original_request_id(self.request_id.clone()) + .with_source(Arc::new(source)), + RecvError::Other { source } => StreamError::new(StreamErrorKind::Other(format!( + "An unexpected error occurred during the response stream: {:?}", + source + ))) + .set_original_request_id(self.request_id.clone()) + .with_source(Arc::new(source)), + } + } + + fn make_metadata(&self) -> StreamEvent { + StreamEvent::Metadata(MetadataEvent { + metrics: Some(MetadataMetrics { + request_start_time: self.request_start_time_sys, + request_end_time: Utc::now(), + time_to_first_chunk: self.time_to_first_chunk, + time_between_chunks: if self.time_between_chunks.is_empty() { + None + } else { + Some(self.time_between_chunks.clone()) + }, + response_stream_len: self.received_response_size as u32, + }), + // if only rts gave usage metrics... + usage: None, + service: Some(MetadataService { + request_id: self.response.request_id().map(String::from), + status_code: None, + }), + }) + } +} + +#[derive(Debug)] +enum RecvError { + Timeout { source: ApiClientError, duration: Duration }, + Other { source: ApiClientError }, +} + +#[cfg(test)] +mod tests { + use tokio_stream::StreamExt as _; + + use super::*; + use crate::database::Database; + use crate::os::{ + Env, + Fs, + }; + use crate::util::env_var::is_integ_test; + + /// Manual test to verify cancellation succeeds in a timely manner. + #[tokio::test] + async fn integ_test_rts_cancel() { + if !is_integ_test() { + return; + } + + let rts = RtsModel::new( + ApiClient::new(&Env::new(), &Fs::new(), &mut Database::new().await.unwrap(), None) + .await + .unwrap(), + Uuid::new_v4(), + None, + ); + let cancel_token = CancellationToken::new(); + let token_clone = cancel_token.clone(); + let (tx, mut rx) = mpsc::channel(8); + tokio::spawn(async move { + let mut stream = rts.stream( + vec![Message::new( + Role::User, + vec![ContentBlock::Text( + "Hello, can you explain how to write hello world in c, python, and rust?".to_string(), + )], + None, + )], + None, + None, + token_clone, + ); + while let Some(ev) = stream.next().await { + let _ = tx.send(ev).await; + } + }); + + // Assertion logic here is: + // 1. Loop until we start receiving content + // 2. Once content is received, cancel the stream + // 3. Assert that we receive a metadata stream event, and then immediately followed by an + // Interrupted error. These events should be received almost immediately after cancelling. + let mut was_cancelled = false; + let mut cancelled_time = None; + loop { + let ev = rx.recv().await.expect("should not fail"); + if let StreamResult::Ok(StreamEvent::ContentBlockDelta(_)) = ev { + if was_cancelled { + continue; + } + // We received content, so time to interrupt the stream. 
+ cancel_token.cancel(); + was_cancelled = true; + cancelled_time = Some(Instant::now()); + } + if let StreamResult::Ok(StreamEvent::Metadata(_)) = ev { + // Next event should be an interrupted error. + let ev = rx.recv().await.expect("should have another event after metadata"); + let err = ev.unwrap_err(); + assert!(matches!(err.kind, StreamErrorKind::Interrupted)); + let elapsed = cancelled_time.unwrap().elapsed(); + assert!( + elapsed.as_millis() < 25, + "stream should have been interrupted in a timely manner, instead took: {}ms", + elapsed.as_millis() + ); + break; + } + } + if !was_cancelled { + panic!("stream was never cancelled"); + } + } +} diff --git a/crates/chat-cli/src/agent/rts/types.rs b/crates/chat-cli/src/agent/rts/types.rs new file mode 100644 index 0000000000..20abd9825b --- /dev/null +++ b/crates/chat-cli/src/agent/rts/types.rs @@ -0,0 +1,69 @@ +use agent::agent_loop::types::*; + +use crate::api_client::model; +use crate::cli::chat::util::serde_value_to_document; + +impl From for model::ImageBlock { + fn from(v: ImageBlock) -> Self { + Self { + format: v.format.into(), + source: v.source.into(), + } + } +} + +impl From for model::ImageFormat { + fn from(value: ImageFormat) -> Self { + match value { + ImageFormat::Gif => Self::Gif, + ImageFormat::Jpeg => Self::Jpeg, + ImageFormat::Png => Self::Png, + ImageFormat::Webp => Self::Webp, + } + } +} + +impl From for model::ImageSource { + fn from(value: ImageSource) -> Self { + match value { + ImageSource::Bytes(items) => Self::Bytes(items), + } + } +} + +impl From for model::ToolUse { + fn from(v: ToolUseBlock) -> Self { + Self { + tool_use_id: v.tool_use_id, + name: v.name, + input: serde_value_to_document(v.input).into(), + } + } +} + +impl From for model::ToolResultStatus { + fn from(value: ToolResultStatus) -> Self { + match value { + ToolResultStatus::Error => Self::Error, + ToolResultStatus::Success => Self::Success, + } + } +} + +impl From for model::ToolSpecification { + fn from(v: ToolSpec) -> Self { + Self { + name: v.name, + description: v.description, + input_schema: v.input_schema.into(), + } + } +} + +impl From> for model::ToolInputSchema { + fn from(v: serde_json::Map) -> Self { + Self { + json: Some(serde_value_to_document(v.into()).into()), + } + } +} diff --git a/crates/chat-cli/src/api_client/error.rs b/crates/chat-cli/src/api_client/error.rs index 4ac80f329c..35aa92472c 100644 --- a/crates/chat-cli/src/api_client/error.rs +++ b/crates/chat-cli/src/api_client/error.rs @@ -25,6 +25,10 @@ use crate::telemetry::ReasonCode; #[derive(Debug, Error)] pub enum ApiClientError { + /// The converse stream operation + #[error("{}", .0)] + ConverseStream(#[from] ConverseStreamError), + // Generate completions errors #[error("{}", SdkErrorDisplay(.0))] GenerateCompletions(#[from] SdkError), @@ -41,40 +45,15 @@ pub enum ApiClientError { #[error("{}", SdkErrorDisplay(.0))] SendTelemetryEvent(#[from] SdkError), - // Send message errors - #[error("{}", SdkErrorDisplay(.0))] - CodewhispererGenerateAssistantResponse(#[from] SdkError), - #[error("{}", SdkErrorDisplay(.0))] - QDeveloperSendMessage(#[from] SdkError), - // chat stream errors #[error("{}", SdkErrorDisplay(.0))] CodewhispererChatResponseStream(#[from] SdkError), #[error("{}", SdkErrorDisplay(.0))] QDeveloperChatResponseStream(#[from] SdkError), - // quota breach - #[error("quota has reached its limit")] - QuotaBreach { - message: &'static str, - status_code: Option, - }, - - // Separate from quota breach (somehow) - #[error("monthly query limit reached")] - 
MonthlyLimitReached { status_code: Option }, - #[error("{}", SdkErrorDisplay(.0))] CreateSubscriptionToken(#[from] SdkError), - /// Returned from the backend when the user input is too large to fit within the model context - /// window. - /// - /// Note that we currently do not receive token usage information regarding how large the - /// context window is. - #[error("the context window has overflowed")] - ContextWindowOverflow { status_code: Option }, - #[error(transparent)] SmithyBuild(#[from] aws_smithy_types::error::operation::BuildError), @@ -84,14 +63,6 @@ pub enum ApiClientError { #[error(transparent)] AuthError(#[from] AuthError), - #[error( - "The model you've selected is temporarily unavailable. Please use '/model' to select a different model and try again." - )] - ModelOverloadedError { - request_id: Option, - status_code: Option, - }, - // Credential errors #[error("failed to load credentials: {}", .0)] Credentials(CredentialsError), @@ -109,23 +80,18 @@ pub enum ApiClientError { impl ApiClientError { pub fn status_code(&self) -> Option { match self { + Self::ConverseStream(e) => e.status_code, Self::GenerateCompletions(e) => sdk_status_code(e), Self::GenerateRecommendations(e) => sdk_status_code(e), Self::ListAvailableCustomizations(e) => sdk_status_code(e), Self::ListAvailableServices(e) => sdk_status_code(e), - Self::CodewhispererGenerateAssistantResponse(e) => sdk_status_code(e), - Self::QDeveloperSendMessage(e) => sdk_status_code(e), Self::CodewhispererChatResponseStream(_) => None, Self::QDeveloperChatResponseStream(_) => None, Self::ListAvailableProfilesError(e) => sdk_status_code(e), Self::SendTelemetryEvent(e) => sdk_status_code(e), Self::CreateSubscriptionToken(e) => sdk_status_code(e), - Self::QuotaBreach { status_code, .. } => *status_code, - Self::ContextWindowOverflow { status_code } => *status_code, Self::SmithyBuild(_) => None, Self::AuthError(_) => None, - Self::ModelOverloadedError { status_code, .. } => *status_code, - Self::MonthlyLimitReached { status_code } => *status_code, Self::Credentials(_e) => None, Self::ListAvailableModelsError(e) => sdk_status_code(e), Self::DefaultModelNotFound => None, @@ -137,23 +103,18 @@ impl ApiClientError { impl ReasonCode for ApiClientError { fn reason_code(&self) -> String { match self { + Self::ConverseStream(e) => e.reason_code(), Self::GenerateCompletions(e) => sdk_error_code(e), Self::GenerateRecommendations(e) => sdk_error_code(e), Self::ListAvailableCustomizations(e) => sdk_error_code(e), Self::ListAvailableServices(e) => sdk_error_code(e), - Self::CodewhispererGenerateAssistantResponse(e) => sdk_error_code(e), - Self::QDeveloperSendMessage(e) => sdk_error_code(e), Self::CodewhispererChatResponseStream(e) => sdk_error_code(e), Self::QDeveloperChatResponseStream(e) => sdk_error_code(e), Self::ListAvailableProfilesError(e) => sdk_error_code(e), Self::SendTelemetryEvent(e) => sdk_error_code(e), Self::CreateSubscriptionToken(e) => sdk_error_code(e), - Self::QuotaBreach { .. } => "QuotaBreachError".to_string(), - Self::ContextWindowOverflow { .. } => "ContextWindowOverflow".to_string(), Self::SmithyBuild(_) => "SmithyBuildError".to_string(), Self::AuthError(_) => "AuthError".to_string(), - Self::ModelOverloadedError { .. } => "ModelOverloadedError".to_string(), - Self::MonthlyLimitReached { .. 
} => "MonthlyLimitReached".to_string(), Self::Credentials(_) => "CredentialsError".to_string(), Self::ListAvailableModelsError(e) => sdk_error_code(e), Self::DefaultModelNotFound => "DefaultModelNotFound".to_string(), @@ -162,7 +123,96 @@ impl ReasonCode for ApiClientError { } } -fn sdk_error_code(e: &SdkError) -> String { +#[derive(Debug, Error)] +#[error("{}", .kind)] +pub struct ConverseStreamError { + pub request_id: Option, + pub status_code: Option, + pub kind: ConverseStreamErrorKind, + #[source] + pub source: Option, +} + +impl ConverseStreamError { + pub fn new(kind: ConverseStreamErrorKind, source: Option>) -> Self { + Self { + kind, + source: source.map(Into::into), + request_id: None, + status_code: None, + } + } + + pub fn set_request_id(mut self, request_id: Option) -> Self { + self.request_id = request_id; + self + } + + pub fn set_status_code(mut self, status_code: Option) -> Self { + self.status_code = status_code; + self + } +} + +impl ReasonCode for ConverseStreamError { + fn reason_code(&self) -> String { + match &self.kind { + // maintaining backwards compatibility with the previous throttling error code. + ConverseStreamErrorKind::Throttling => "QuotaBreachError".to_string(), + ConverseStreamErrorKind::MonthlyLimitReached => "MonthlyLimitReached".to_string(), + ConverseStreamErrorKind::ContextWindowOverflow => "ContextWindowOverflow".to_string(), + ConverseStreamErrorKind::ModelOverloadedError => "ModelOverloadedError".to_string(), + ConverseStreamErrorKind::Unknown { reason_code } => reason_code.clone(), + } + } +} + +impl From for ConverseStreamError { + fn from(value: aws_smithy_types::error::operation::BuildError) -> Self { + Self { + request_id: None, + status_code: None, + kind: ConverseStreamErrorKind::Unknown { + reason_code: value.to_string(), + }, + source: Some(value.into()), + } + } +} + +#[derive(Debug, Error)] +#[non_exhaustive] +pub enum ConverseStreamErrorKind { + #[error("Too many requests have been sent recently, please wait and try again later")] + Throttling, + #[error("The monthly usage limit has been reached")] + MonthlyLimitReached, + /// Returned from the backend when the user input is too large to fit within the model context + /// window. + /// + /// Note that we currently do not receive token usage information regarding how large the + /// context window is. + #[error("The context window has overflowed")] + ContextWindowOverflow, + #[error( + "The model you've selected is temporarily unavailable. Please use '/model' to select a different model and try again." 
+ )] + ModelOverloadedError, + #[error("An unknown error occurred: {}", .reason_code)] + Unknown { reason_code: String }, +} + +#[derive(Debug, Error)] +pub enum ConverseStreamSdkError { + #[error("{}", SdkErrorDisplay(.0))] + CodewhispererGenerateAssistantResponse(#[from] SdkError), + #[error("{}", SdkErrorDisplay(.0))] + QDeveloperSendMessage(#[from] SdkError), + #[error(transparent)] + SmithyBuild(#[from] aws_smithy_types::error::operation::BuildError), +} + +pub fn sdk_error_code(e: &SdkError) -> String { e.as_service_error() .and_then(|se| se.meta().code().map(str::to_string)) .unwrap_or_else(|| e.to_string()) @@ -192,6 +242,23 @@ mod tests { fn all_errors() -> Vec { vec![ + ApiClientError::ConverseStream(ConverseStreamError { + request_id: None, + status_code: None, + kind: ConverseStreamErrorKind::Throttling, + source: Some(ConverseStreamSdkError::CodewhispererGenerateAssistantResponse( + SdkError::service_error(GenerateAssistantResponseError::unhandled(""), response()), + )), + }), + ApiClientError::ConverseStream(ConverseStreamError { + request_id: None, + status_code: None, + kind: ConverseStreamErrorKind::Throttling, + source: Some(ConverseStreamSdkError::QDeveloperSendMessage(SdkError::service_error( + QDeveloperSendMessageError::unhandled(""), + response(), + ))), + }), ApiClientError::Credentials(CredentialsError::unhandled("")), ApiClientError::GenerateCompletions(SdkError::service_error( GenerateCompletionsError::unhandled(""), @@ -217,14 +284,6 @@ mod tests { ListCustomizationsError::unhandled(""), response(), )), - ApiClientError::CodewhispererGenerateAssistantResponse(SdkError::service_error( - GenerateAssistantResponseError::unhandled(""), - response(), - )), - ApiClientError::QDeveloperSendMessage(SdkError::service_error( - QDeveloperSendMessageError::unhandled(""), - response(), - )), ApiClientError::CreateSubscriptionToken(SdkError::service_error( CreateSubscriptionTokenError::unhandled(""), response(), diff --git a/crates/chat-cli/src/api_client/mod.rs b/crates/chat-cli/src/api_client/mod.rs index 8363d377a2..898352b01f 100644 --- a/crates/chat-cli/src/api_client/mod.rs +++ b/crates/chat-cli/src/api_client/mod.rs @@ -2,7 +2,7 @@ mod credentials; pub mod customization; mod delay_interceptor; mod endpoints; -mod error; +pub mod error; pub mod model; mod opt_out; pub mod profile; @@ -29,10 +29,15 @@ use aws_config::retry::RetryConfig; use aws_config::timeout::TimeoutConfig; use aws_credential_types::Credentials; use aws_credential_types::provider::ProvideCredentials; +use aws_sdk_ssooidc::error::ProvideErrorMetadata; use aws_types::request_id::RequestId; use aws_types::sdk_config::StalledStreamProtectionConfig; pub use endpoints::Endpoint; pub use error::ApiClientError; +use error::{ + ConverseStreamError, + ConverseStreamErrorKind, +}; use parking_lot::Mutex; pub use profile::list_available_profiles; use serde_json::Map; @@ -65,6 +70,7 @@ use crate::os::{ Env, Fs, }; +use crate::util::env_var::is_integ_test; // Opt out constants pub const X_AMZN_CODEWHISPERER_OPT_OUT_HEADER: &str = "x-amzn-codewhisperer-optout"; @@ -128,7 +134,7 @@ impl ApiClient { .build(), ); - if cfg!(test) { + if cfg!(test) && !is_integ_test() { let mut this = Self { client, streaming_client: None, @@ -375,7 +381,10 @@ impl ApiClient { .map_err(ApiClientError::CreateSubscriptionToken) } - pub async fn send_message(&self, conversation: ConversationState) -> Result { + pub async fn send_message( + &self, + conversation: ConversationState, + ) -> Result { debug!("Sending conversation: {:#?}", 
conversation); let ConversationState { @@ -412,72 +421,22 @@ impl ApiClient { { Ok(response) => Ok(SendMessageOutput::Codewhisperer(response)), Err(err) => { + let request_id = err + .as_service_error() + .and_then(|err| err.meta().request_id()) + .map(|s| s.to_string()); let status_code = err.raw_response().map(|res| res.status().as_u16()); - let is_quota_breach = status_code.is_some_and(|status| status == 429); - let is_context_window_overflow = err.as_service_error().is_some_and(|err| { - matches!(err, err if err.meta().code() == Some("ValidationException") && err.meta().message() == Some("Input is too long.")) - }); - - let is_model_unavailable = { - // check if ThrottlingException - let is_throttling_exception = err - .as_service_error() - .is_some_and(|service_err| service_err.meta().code() == Some("ThrottlingException")); - - // check if the response contains INSUFFICIENT_MODEL_CAPACITY - let has_insufficient_capacity = err - .raw_response() - .and_then(|resp| resp.body().bytes()) - .and_then(|bytes| String::from_utf8(bytes.to_vec()).ok()) - .is_some_and(|body| body.contains("INSUFFICIENT_MODEL_CAPACITY")); - - (is_throttling_exception && has_insufficient_capacity) - // Legacy error response fallback - || (model_id_opt.is_some() - && status_code.is_some_and(|status| status == 500) - && err.as_service_error().is_some_and(|err| { - err.meta().message() == Some( - "Encountered unexpectedly high load when processing the request, please try again.", - )})) - }; - - let is_monthly_limit_err = err + + let body = err .raw_response() .and_then(|resp| resp.body().bytes()) - .and_then(|bytes| match String::from_utf8(bytes.to_vec()) { - Ok(s) => Some(s.contains("MONTHLY_REQUEST_COUNT")), - Err(_) => None, - }) - .unwrap_or(false); - - if is_context_window_overflow { - return Err(ApiClientError::ContextWindowOverflow { status_code }); - } - - // Both ModelOverloadedError and QuotaBreach return 429, - // so check is_model_unavailable first. 
- if is_model_unavailable { - return Err(ApiClientError::ModelOverloadedError { - request_id: err - .as_service_error() - .and_then(|err| err.meta().request_id()) - .map(|s| s.to_string()), - status_code, - }); - } - - if is_quota_breach { - return Err(ApiClientError::QuotaBreach { - message: "quota has reached its limit", - status_code, - }); - } - - if is_monthly_limit_err { - return Err(ApiClientError::MonthlyLimitReached { status_code }); - } - - Err(err.into()) + .unwrap_or_default(); + Err(ConverseStreamError::new( + classify_error_kind(status_code, body, model_id_opt.as_deref(), &err), + Some(err), + ) + .set_request_id(request_id) + .set_status_code(status_code)) }, } } else if let Some(client) = &self.sigv4_streaming_client { @@ -504,72 +463,22 @@ impl ApiClient { { Ok(response) => Ok(SendMessageOutput::QDeveloper(response)), Err(err) => { + let request_id = err + .as_service_error() + .and_then(|err| err.meta().request_id()) + .map(|s| s.to_string()); let status_code = err.raw_response().map(|res| res.status().as_u16()); - let is_quota_breach = status_code.is_some_and(|status| status == 429); - let is_context_window_overflow = err.as_service_error().is_some_and(|err| { - matches!(err, err if err.meta().code() == Some("ValidationException") && err.meta().message() == Some("Input is too long.")) - }); - - let is_model_unavailable = { - // check if ThrottlingException - let is_throttling_exception = err - .as_service_error() - .is_some_and(|service_err| service_err.meta().code() == Some("ThrottlingException")); - - // check if the response contains INSUFFICIENT_MODEL_CAPACITY - let has_insufficient_capacity = err - .raw_response() - .and_then(|resp| resp.body().bytes()) - .and_then(|bytes| String::from_utf8(bytes.to_vec()).ok()) - .is_some_and(|body| body.contains("INSUFFICIENT_MODEL_CAPACITY")); - - (is_throttling_exception && has_insufficient_capacity) - // Legacy error response fallback - || (model_id_opt.is_some() - && status_code.is_some_and(|status| status == 500) - && err.as_service_error().is_some_and(|err| { - err.meta().message() == Some( - "Encountered unexpectedly high load when processing the request, please try again.", - )})) - }; - - let is_monthly_limit_err = err + + let body = err .raw_response() .and_then(|resp| resp.body().bytes()) - .and_then(|bytes| match String::from_utf8(bytes.to_vec()) { - Ok(s) => Some(s.contains("MONTHLY_REQUEST_COUNT")), - Err(_) => None, - }) - .unwrap_or(false); - - // Both ModelOverloadedError and QuotaBreach return 429, - // so check is_model_unavailable first. 
- if is_model_unavailable { - return Err(ApiClientError::ModelOverloadedError { - request_id: err - .as_service_error() - .and_then(|err| err.meta().request_id()) - .map(|s| s.to_string()), - status_code, - }); - } - - if is_quota_breach { - return Err(ApiClientError::QuotaBreach { - message: "quota has reached its limit", - status_code, - }); - } - - if is_context_window_overflow { - return Err(ApiClientError::ContextWindowOverflow { status_code }); - } - - if is_monthly_limit_err { - return Err(ApiClientError::MonthlyLimitReached { status_code }); - } - - Err(err.into()) + .unwrap_or_default(); + Err(ConverseStreamError::new( + classify_error_kind(status_code, body, model_id_opt.as_deref(), &err), + Some(err), + ) + .set_request_id(request_id) + .set_status_code(status_code)) }, } } else if let Some(client) = &self.mock_client { @@ -612,6 +521,51 @@ impl ApiClient { } } +fn classify_error_kind( + status_code: Option, + body: &[u8], + model_id_opt: Option<&str>, + sdk_error: &error::SdkError, +) -> ConverseStreamErrorKind { + let contains = |haystack: &[u8], needle: &[u8]| haystack.windows(needle.len()).any(|v| v == needle); + + let is_throttling = status_code.is_some_and(|status| status == 429); + let is_context_window_overflow = contains(body, b"Input is too long."); + let is_model_unavailable = contains(body, b"INSUFFICIENT_MODEL_CAPACITY") + // Legacy error response fallback + || (model_id_opt.is_some() + && status_code.is_some_and(|status| status == 500) + && contains( + body, + b"Encountered unexpectedly high load when processing the request, please try again.", + )); + let is_monthly_limit_err = contains(body, b"MONTHLY_REQUEST_COUNT"); + + if is_context_window_overflow { + return ConverseStreamErrorKind::ContextWindowOverflow; + } + + // Both ModelOverloadedError and Throttling return 429, + // so check is_model_unavailable first. + if is_model_unavailable { + return ConverseStreamErrorKind::ModelOverloadedError; + } + + if is_throttling { + return ConverseStreamErrorKind::Throttling; + } + + if is_monthly_limit_err { + return ConverseStreamErrorKind::MonthlyLimitReached; + } + + ConverseStreamErrorKind::Unknown { + // do not change - we currently use sdk_error_code for mapping from an arbitrary sdk error + // to a reason code. + reason_code: error::sdk_error_code(sdk_error), + } +} + fn timeout_config(database: &Database) -> TimeoutConfig { let timeout = database .settings @@ -679,6 +633,7 @@ mod tests { IdeCategory, OperatingSystem, }; + use bstr::ByteSlice; use super::*; use crate::api_client::model::UserInputMessage; @@ -754,4 +709,100 @@ mod tests { } assert_eq!(output_content, "Hello! 
How can I assist you today?"); } + + #[test] + fn test_classify_error_kind() { + use aws_smithy_runtime_api::http::Response; + use aws_smithy_types::body::SdkBody; + + use crate::api_client::error::{ + GenerateAssistantResponseError, + SdkError, + }; + + let mock_sdk_error = || { + SdkError::service_error( + GenerateAssistantResponseError::unhandled("test"), + Response::new(500.try_into().unwrap(), SdkBody::empty()), + ) + }; + + let test_cases: Vec<(Option, &[u8], Option<&str>, ConverseStreamErrorKind)> = vec![ + ( + Some(400), + b"Input is too long.", + None, + ConverseStreamErrorKind::ContextWindowOverflow, + ), + ( + Some(500), + b"INSUFFICIENT_MODEL_CAPACITY", + Some("model-1"), + ConverseStreamErrorKind::ModelOverloadedError, + ), + ( + Some(500), + b"Encountered unexpectedly high load when processing the request, please try again.", + Some("model-1"), + ConverseStreamErrorKind::ModelOverloadedError, + ), + ( + Some(429), + b"Rate limit exceeded", + None, + ConverseStreamErrorKind::Throttling, + ), + ( + Some(400), + b"MONTHLY_REQUEST_COUNT exceeded", + None, + ConverseStreamErrorKind::MonthlyLimitReached, + ), + ( + Some(429), + b"Input is too long.", + None, + ConverseStreamErrorKind::ContextWindowOverflow, + ), + ( + Some(429), + b"INSUFFICIENT_MODEL_CAPACITY", + Some("model-1"), + ConverseStreamErrorKind::ModelOverloadedError, + ), + ( + Some(500), + b"Encountered unexpectedly high load when processing the request, please try again.", + None, + ConverseStreamErrorKind::Unknown { + reason_code: "test".to_string(), + }, + ), + ( + Some(400), + b"Encountered unexpectedly high load when processing the request, please try again.", + Some("model-1"), + ConverseStreamErrorKind::Unknown { + reason_code: "test".to_string(), + }, + ), + (Some(500), b"Some other error", None, ConverseStreamErrorKind::Unknown { + reason_code: "test".to_string(), + }), + ]; + + for (status_code, body, model_id, expected) in test_cases { + let result = classify_error_kind(status_code, body, model_id, &mock_sdk_error()); + assert_eq!( + std::mem::discriminant(&result), + std::mem::discriminant(&expected), + "expected '{}', got '{}' | status_code: {:?}, body: '{}', model_id: '{:?}'", + expected, + result, + status_code, + body.to_str_lossy(), + model_id + ); + } + } } diff --git a/crates/chat-cli/src/api_client/model.rs b/crates/chat-cli/src/api_client/model.rs index 808081ec63..1a72023f60 100644 --- a/crates/chat-cli/src/api_client/model.rs +++ b/crates/chat-cli/src/api_client/model.rs @@ -322,6 +322,12 @@ pub enum Tool { ToolSpecification(ToolSpecification), } +impl From for Tool { + fn from(value: ToolSpecification) -> Self { + Self::ToolSpecification(value) + } +} + impl From for amzn_codewhisperer_streaming_client::types::Tool { fn from(value: Tool) -> Self { match value { @@ -575,6 +581,33 @@ pub enum ChatResponseStream { Unknown, } +impl ChatResponseStream { + /// Returns the length of the content of the message event - ie, the number of bytes of content + /// contained within the message. + /// + /// This doesn't reflect the actual number of bytes the message took up being serialized over + /// the network. + pub fn len(&self) -> usize { + match self { + ChatResponseStream::AssistantResponseEvent { content } => content.len(), + ChatResponseStream::CodeEvent { content } => content.len(), + ChatResponseStream::CodeReferenceEvent(_) => 0, + ChatResponseStream::FollowupPromptEvent(_) => 0, + ChatResponseStream::IntentsEvent(_) => 0, + ChatResponseStream::InvalidStateEvent { .. 
} => 0, + ChatResponseStream::MessageMetadataEvent { .. } => 0, + ChatResponseStream::SupplementaryWebLinksEvent(_) => 0, + ChatResponseStream::ToolUseEvent { input, .. } => input.as_ref().map(|s| s.len()).unwrap_or_default(), + ChatResponseStream::Unknown => 0, + } + } + + #[allow(dead_code)] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + impl From for ChatResponseStream { fn from(value: amzn_codewhisperer_streaming_client::types::ChatResponseStream) -> Self { match value { diff --git a/crates/chat-cli/src/auth/builder_id.rs b/crates/chat-cli/src/auth/builder_id.rs index ec8df09713..ebc899df11 100644 --- a/crates/chat-cli/src/auth/builder_id.rs +++ b/crates/chat-cli/src/auth/builder_id.rs @@ -64,7 +64,10 @@ use crate::database::{ Secret, }; use crate::os::Env; -use crate::util::env_var::is_sigv4_enabled; +use crate::util::env_var::{ + is_integ_test, + is_sigv4_enabled, +}; #[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub enum OAuthFlow { @@ -320,7 +323,7 @@ impl BuilderIdToken { ) -> Result, AuthError> { // Can't use #[cfg(test)] without breaking lints, and we don't want to require // authentication in order to run ChatSession tests. Hence, adding this here with cfg!(test) - if cfg!(test) { + if cfg!(test) && !is_integ_test() { return Ok(Some(Self { access_token: Secret("test_access_token".to_string()), expires_at: time::OffsetDateTime::now_utc() + time::Duration::minutes(60), diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index 6803e56df1..081957937e 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -5,6 +5,7 @@ use spinners::{ Spinners, }; +use crate::api_client::error::ConverseStreamErrorKind; use crate::theme::StyledText; use crate::util::ui::should_send_structured_message; pub mod cli; @@ -915,10 +916,10 @@ impl ChatSession { )?; ("Unable to compact the conversation history", eyre!(err), true) }, - ChatError::SendMessage(err) => match err.source { + ChatError::SendMessage(err) => match &err.source.kind { // Errors from attempting to send too large of a conversation history. In // this case, attempt to automatically compact the history for the user. - ApiClientError::ContextWindowOverflow { .. } => { + ConverseStreamErrorKind::ContextWindowOverflow => { if os .database .settings @@ -961,10 +962,7 @@ impl ChatSession { return Ok(()); } }, - ApiClientError::QuotaBreach { - message: _, - status_code: _, - } => { + ConverseStreamErrorKind::Throttling => { let err = "Request quota exceeded. Please wait a moment and try again.".to_string(); self.conversation.append_transcript(err.clone()); execute!( @@ -979,7 +977,7 @@ impl ChatSession { )?; (error_messages::TROUBLE_RESPONDING, eyre!(err), false) }, - ApiClientError::ModelOverloadedError { request_id, .. } => { + ConverseStreamErrorKind::ModelOverloadedError => { if self.interactive { execute!( self.stderr, @@ -992,7 +990,7 @@ impl ChatSession { StyledText::reset(), )?; - if let Some(id) = request_id { + if let Some(id) = err.source.request_id { self.conversation .append_transcript(format!("Model unavailable (Request ID: {})", id)); } @@ -1007,7 +1005,7 @@ impl ChatSession { let err = format!( "The model you've selected is temporarily unavailable. 
{}{}\n\n", model_instruction, - match request_id { + match err.source.request_id { Some(id) => format!("\n Request ID: {}", id), None => "".to_owned(), } @@ -1024,7 +1022,7 @@ impl ChatSession { )?; (error_messages::TROUBLE_RESPONDING, eyre!(err), false) }, - ApiClientError::MonthlyLimitReached { .. } => { + ConverseStreamErrorKind::MonthlyLimitReached => { let subscription_status = get_subscription_status(os).await; if subscription_status.is_err() { execute!( @@ -1484,7 +1482,7 @@ impl ChatSession { let history_len = self.conversation.history().len(); match err { ChatError::SendMessage(err) - if matches!(err.source, ApiClientError::ContextWindowOverflow { .. }) => + if matches!(err.source.kind, ConverseStreamErrorKind::ContextWindowOverflow) => { error!(?strategy, "failed to send compaction request"); // If there's only two messages in the history, we have no choice but to diff --git a/crates/chat-cli/src/cli/chat/parser.rs b/crates/chat-cli/src/cli/chat/parser.rs index 2e0cdfb03c..eea45d891d 100644 --- a/crates/chat-cli/src/cli/chat/parser.rs +++ b/crates/chat-cli/src/cli/chat/parser.rs @@ -29,15 +29,13 @@ use super::message::{ AssistantMessage, AssistantToolUse, }; +use crate::api_client::ApiClient; +use crate::api_client::error::ConverseStreamError; use crate::api_client::model::{ ChatResponseStream, ConversationState, }; use crate::api_client::send_message_output::SendMessageOutput; -use crate::api_client::{ - ApiClient, - ApiClientError, -}; use crate::telemetry::ReasonCode; use crate::telemetry::core::{ ChatConversationType, @@ -48,13 +46,13 @@ use crate::telemetry::core::{ #[derive(Debug, Error)] pub struct SendMessageError { #[source] - pub source: ApiClientError, + pub source: ConverseStreamError, pub request_metadata: RequestMetadata, } impl SendMessageError { pub fn status_code(&self) -> Option { - self.source.status_code() + self.source.status_code } } diff --git a/crates/chat-cli/src/database/mod.rs b/crates/chat-cli/src/database/mod.rs index 0fec09d552..880cacfc4a 100644 --- a/crates/chat-cli/src/database/mod.rs +++ b/crates/chat-cli/src/database/mod.rs @@ -35,6 +35,7 @@ use tracing::{ use uuid::Uuid; use crate::cli::ConversationState; +use crate::util::env_var::is_integ_test; use crate::util::paths::{ DirectoryError, GlobalPaths, @@ -187,7 +188,7 @@ pub struct Database { impl Database { pub async fn new() -> Result { - let path = match cfg!(test) { + let path = match cfg!(test) && !is_integ_test() { true => { return Self { pool: Pool::builder().build(SqliteConnectionManager::memory()).unwrap(), diff --git a/crates/chat-cli/src/main.rs b/crates/chat-cli/src/main.rs index c4f0317551..6608be567f 100644 --- a/crates/chat-cli/src/main.rs +++ b/crates/chat-cli/src/main.rs @@ -1,3 +1,4 @@ +mod agent; mod api_client; mod auth; mod aws_common; diff --git a/crates/chat-cli/src/os/env.rs b/crates/chat-cli/src/os/env.rs index 63e5449419..ec579eacba 100644 --- a/crates/chat-cli/src/os/env.rs +++ b/crates/chat-cli/src/os/env.rs @@ -14,6 +14,8 @@ use std::sync::{ Mutex, }; +use agent::util::is_integ_test; + use crate::os::ACTIVE_USER_HOME; #[derive(Debug, Clone)] @@ -43,7 +45,7 @@ mod inner { impl Env { pub fn new() -> Self { - if cfg!(test) { + if cfg!(test) && !is_integ_test() { match cfg!(windows) { true => Env::from_slice(&[ ("USERPROFILE", ACTIVE_USER_HOME), diff --git a/crates/chat-cli/src/os/fs/mod.rs b/crates/chat-cli/src/os/fs/mod.rs index 755966acaf..e8dec21ec3 100644 --- a/crates/chat-cli/src/os/fs/mod.rs +++ b/crates/chat-cli/src/os/fs/mod.rs @@ -40,6 +40,8 @@ use windows::{ 
symlink_sync, }; +use crate::util::env_var::is_integ_test; + /// Rust path handling is hard coded to work specific ways depending on the /// OS that is being executed on. Because of this, if Unix paths are provided, /// they aren't recognized. For example a leading prefix of '/' isn't considered @@ -97,7 +99,7 @@ pub enum Fs { impl Fs { pub fn new() -> Self { - match cfg!(test) { + match cfg!(test) && !is_integ_test() { true => { let tempdir = tempfile::tempdir().expect("failed creating temporary directory"); let fs = Self::Chroot(tempdir.into()); diff --git a/crates/chat-cli/src/util/consts.rs b/crates/chat-cli/src/util/consts.rs index 14a0cea686..a22ce973a2 100644 --- a/crates/chat-cli/src/util/consts.rs +++ b/crates/chat-cli/src/util/consts.rs @@ -69,6 +69,9 @@ pub mod env_var { /// Identifier for the client application or service using the chat-cli Q_CLI_CLIENT_APPLICATION = "Q_CLI_CLIENT_APPLICATION", + /// Flag for running integration tests + CLI_IS_INTEG_TEST = "Q_CLI_IS_INTEG_TEST", + /// Enable logging to stdout Q_LOG_STDOUT = "Q_LOG_STDOUT", diff --git a/crates/chat-cli/src/util/env_var.rs b/crates/chat-cli/src/util/env_var.rs index 850bd5831b..8f855833d1 100644 --- a/crates/chat-cli/src/util/env_var.rs +++ b/crates/chat-cli/src/util/env_var.rs @@ -51,6 +51,10 @@ pub fn in_ci() -> bool { env.get_os(CI).is_some() || env.get_os(Q_CI).is_some() } +pub fn is_integ_test() -> bool { + std::env::var_os(CLI_IS_INTEG_TEST).is_some_and(|s| !s.is_empty()) +} + /// Get CLI client application pub fn get_cli_client_application() -> Option { Env::new().get(Q_CLI_CLIENT_APPLICATION).ok()
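
Usage note (the invocation below is an assumption; only the flag semantics come from the patch itself): the new integration test integ_test_rts_cancel in crates/chat-cli/src/agent/rts/mod.rs is skipped unless is_integ_test() returns true, which happens when the Q_CLI_IS_INTEG_TEST environment variable introduced in crates/chat-cli/src/util/consts.rs is set to any non-empty value. Setting the flag also disables the cfg!(test) mocks for Database, Env, Fs, BuilderIdToken, and the mock ApiClient, so the test expects a real, authenticated environment. A plausible invocation:

    Q_CLI_IS_INTEG_TEST=1 cargo test --locked -p chat_cli integ_test_rts_cancel

When run this way, the test streams a request through RtsModel, cancels the token after the first content delta, and asserts that a Metadata event followed by an Interrupted error arrives within 25ms of cancellation.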