diff --git a/gateway/src/routes/internal.rs b/gateway/src/routes/internal.rs index 0f6c8440e1..f28e10ba3b 100644 --- a/gateway/src/routes/internal.rs +++ b/gateway/src/routes/internal.rs @@ -45,6 +45,10 @@ pub fn build_internal_non_otel_enabled_routes() -> Router { "/internal/feedback/{target_id}/latest-id-by-metric", get(endpoints::feedback::internal::get_latest_feedback_id_by_metric_handler), ) + .route( + "/internal/feedback/{target_id}/count", + get(endpoints::feedback::internal::count_feedback_by_target_id_handler), + ) .route( "/internal/model_inferences/{inference_id}", get(endpoints::internal::model_inferences::get_model_inferences_handler), diff --git a/internal/tensorzero-node/lib/bindings/CountFeedbackByTargetIdParams.ts b/internal/tensorzero-node/lib/bindings/CountFeedbackByTargetIdResponse.ts similarity index 62% rename from internal/tensorzero-node/lib/bindings/CountFeedbackByTargetIdParams.ts rename to internal/tensorzero-node/lib/bindings/CountFeedbackByTargetIdResponse.ts index 323fe78400..11a8fbf6f6 100644 --- a/internal/tensorzero-node/lib/bindings/CountFeedbackByTargetIdParams.ts +++ b/internal/tensorzero-node/lib/bindings/CountFeedbackByTargetIdResponse.ts @@ -1,3 +1,3 @@ // This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually. -export type CountFeedbackByTargetIdParams = { target_id: string }; +export type CountFeedbackByTargetIdResponse = { count: bigint }; diff --git a/internal/tensorzero-node/lib/bindings/index.ts b/internal/tensorzero-node/lib/bindings/index.ts index 3fa6e5147f..fc323bcd0a 100644 --- a/internal/tensorzero-node/lib/bindings/index.ts +++ b/internal/tensorzero-node/lib/bindings/index.ts @@ -30,7 +30,7 @@ export * from "./CommentFeedbackRow"; export * from "./CommentTargetType"; export * from "./ContentBlockChatOutput"; export * from "./ContentBlockOutput"; -export * from "./CountFeedbackByTargetIdParams"; +export * from "./CountFeedbackByTargetIdResponse"; export * from "./CountInferencesRequest"; export * from "./CountInferencesResponse"; export * from "./CountModelsResponse"; diff --git a/internal/tensorzero-node/lib/index.ts b/internal/tensorzero-node/lib/index.ts index b53a239d73..55314ff658 100644 --- a/internal/tensorzero-node/lib/index.ts +++ b/internal/tensorzero-node/lib/index.ts @@ -11,7 +11,6 @@ import type { OptimizationJobHandle, OptimizationJobInfo, StaleDatasetResponse, - CountFeedbackByTargetIdParams, QueryDemonstrationFeedbackByInferenceIdParams, DemonstrationFeedbackRow, GetCumulativeFeedbackTimeseriesParams, @@ -232,15 +231,6 @@ export class DatabaseClient { ) as CumulativeFeedbackTimeSeriesPoint[]; } - async countFeedbackByTargetId( - params: CountFeedbackByTargetIdParams, - ): Promise { - const paramsString = safeStringify(params); - const countString = - await this.nativeDatabaseClient.countFeedbackByTargetId(paramsString); - return JSON.parse(countString) as number; - } - async getFeedbackByVariant( params: GetFeedbackByVariantParams, ): Promise { diff --git a/internal/tensorzero-node/src/database.rs b/internal/tensorzero-node/src/database.rs index 29dcce97b1..6b7e26bb1a 100644 --- a/internal/tensorzero-node/src/database.rs +++ b/internal/tensorzero-node/src/database.rs @@ -34,16 +34,6 @@ impl DatabaseClient { ) } - #[napi] - pub async fn count_feedback_by_target_id(&self, params: String) -> Result { - napi_call!( - &self, - count_feedback_by_target_id, - params, - CountFeedbackByTargetIdParams { target_id } - ) - } - #[napi] pub async fn query_demonstration_feedback_by_inference_id( &self, @@ -100,12 +90,6 @@ struct QueryDemonstrationFeedbackByInferenceIdParams { limit: Option, } -#[derive(Deserialize, ts_rs::TS)] -#[ts(export, optional_fields)] -struct CountFeedbackByTargetIdParams { - target_id: Uuid, -} - #[derive(Deserialize, ts_rs::TS)] #[ts(export, optional_fields)] struct GetFeedbackByVariantParams { diff --git a/tensorzero-core/src/endpoints/feedback/internal/count_feedback.rs b/tensorzero-core/src/endpoints/feedback/internal/count_feedback.rs new file mode 100644 index 0000000000..043b240e47 --- /dev/null +++ b/tensorzero-core/src/endpoints/feedback/internal/count_feedback.rs @@ -0,0 +1,93 @@ +//! Feedback endpoint for counting feedback by target ID + +use axum::extract::{Path, State}; +use axum::{Json, debug_handler}; +use serde::{Deserialize, Serialize}; +use tracing::instrument; +use uuid::Uuid; + +use crate::db::feedback::FeedbackQueries; +use crate::error::Error; +use crate::utils::gateway::{AppState, AppStateData}; + +#[derive(Debug, Serialize, Deserialize, ts_rs::TS)] +#[ts(export)] +pub struct CountFeedbackByTargetIdResponse { + pub count: u64, +} + +/// HTTP handler for counting feedback by target ID +#[debug_handler(state = AppStateData)] +#[instrument( + name = "count_feedback_by_target_id_handler", + skip_all, + fields( + target_id = %target_id, + ) +)] +pub async fn count_feedback_by_target_id_handler( + State(app_state): AppState, + Path(target_id): Path, +) -> Result, Error> { + let response = + count_feedback_by_target_id(&app_state.clickhouse_connection_info, target_id).await?; + Ok(Json(response)) +} + +/// Core business logic for counting feedback by target ID +pub async fn count_feedback_by_target_id( + clickhouse: &impl FeedbackQueries, + target_id: Uuid, +) -> Result { + let count = clickhouse.count_feedback_by_target_id(target_id).await?; + + Ok(CountFeedbackByTargetIdResponse { count }) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::db::feedback::MockFeedbackQueries; + + #[tokio::test] + async fn test_count_feedback_by_target_id_returns_count() { + let mut mock_clickhouse = MockFeedbackQueries::new(); + let target_id = Uuid::now_v7(); + + mock_clickhouse + .expect_count_feedback_by_target_id() + .withf(move |id| *id == target_id) + .times(1) + .returning(|_| Box::pin(async move { Ok(42) })); + + let response = count_feedback_by_target_id(&mock_clickhouse, target_id) + .await + .expect("Expected count to be returned"); + + assert_eq!( + response.count, 42, + "Expected count to be 42 from ClickHouse" + ); + } + + #[tokio::test] + async fn test_count_feedback_by_target_id_returns_zero_for_no_feedback() { + let mut mock_clickhouse = MockFeedbackQueries::new(); + let target_id = Uuid::now_v7(); + + mock_clickhouse + .expect_count_feedback_by_target_id() + .withf(move |id| *id == target_id) + .times(1) + .returning(|_| Box::pin(async move { Ok(0) })); + + let response = count_feedback_by_target_id(&mock_clickhouse, target_id) + .await + .expect("Expected zero count response"); + + assert_eq!( + response.count, 0, + "Expected count to be 0 when no feedback exists" + ); + } +} diff --git a/tensorzero-core/src/endpoints/feedback/internal/mod.rs b/tensorzero-core/src/endpoints/feedback/internal/mod.rs index 268242d4fb..47056b5a52 100644 --- a/tensorzero-core/src/endpoints/feedback/internal/mod.rs +++ b/tensorzero-core/src/endpoints/feedback/internal/mod.rs @@ -1,7 +1,9 @@ +mod count_feedback; mod get_feedback_bounds; mod get_feedback_by_target_id; mod latest_feedback_by_metric; +pub use count_feedback::*; pub use get_feedback_bounds::*; pub use get_feedback_by_target_id::*; pub use latest_feedback_by_metric::*; diff --git a/tensorzero-core/tests/e2e/endpoints/internal/feedback.rs b/tensorzero-core/tests/e2e/endpoints/internal/feedback.rs index 2e2b8775a8..611924cd5d 100644 --- a/tensorzero-core/tests/e2e/endpoints/internal/feedback.rs +++ b/tensorzero-core/tests/e2e/endpoints/internal/feedback.rs @@ -4,7 +4,8 @@ use reqwest::Client; use serde_json::json; use std::collections::HashMap; use tensorzero_core::endpoints::feedback::internal::{ - GetFeedbackBoundsResponse, GetFeedbackByTargetIdResponse, LatestFeedbackIdByMetricResponse, + CountFeedbackByTargetIdResponse, GetFeedbackBoundsResponse, GetFeedbackByTargetIdResponse, + LatestFeedbackIdByMetricResponse, }; use uuid::Uuid; @@ -393,3 +394,89 @@ async fn test_get_feedback_by_target_id_rejects_both_before_and_after() { "Expected 400 when both before and after are specified" ); } + +// ==================== Count Feedback By Target ID Tests ==================== + +#[tokio::test(flavor = "multi_thread")] +async fn test_count_feedback_by_target_id_with_feedback() { + let http_client = Client::new(); + + let inference_id = create_inference(&http_client, "basic_test").await; + let _feedback_id = + submit_inference_feedback(&http_client, inference_id, "task_success", json!(true)).await; + + tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await; + + let url = get_gateway_endpoint(&format!("/internal/feedback/{inference_id}/count")); + let resp = http_client.get(url).send().await.unwrap(); + + assert!( + resp.status().is_success(), + "Expected success when counting feedback" + ); + + let response: CountFeedbackByTargetIdResponse = resp.json().await.unwrap(); + assert_eq!( + response.count, 1, + "Expected count to be 1 for single feedback entry" + ); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_count_feedback_by_target_id_multiple_feedback() { + let http_client = Client::new(); + + let inference_id = create_inference(&http_client, "basic_test").await; + + // Submit multiple feedback entries + submit_inference_feedback(&http_client, inference_id, "task_success", json!(true)).await; + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + submit_inference_feedback(&http_client, inference_id, "task_success", json!(false)).await; + + tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await; + + let url = get_gateway_endpoint(&format!("/internal/feedback/{inference_id}/count")); + let resp = http_client.get(url).send().await.unwrap(); + + assert!(resp.status().is_success()); + let response: CountFeedbackByTargetIdResponse = resp.json().await.unwrap(); + + assert_eq!( + response.count, 2, + "Expected count to be 2 for multiple feedback entries" + ); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_count_feedback_by_target_id_nonexistent_target() { + let http_client = Client::new(); + let nonexistent_id = Uuid::now_v7(); + + let url = get_gateway_endpoint(&format!("/internal/feedback/{nonexistent_id}/count")); + let resp = http_client.get(url).send().await.unwrap(); + + assert!( + resp.status().is_success(), + "Expected success for nonexistent target" + ); + + let response: CountFeedbackByTargetIdResponse = resp.json().await.unwrap(); + assert_eq!( + response.count, 0, + "Expected count to be 0 for nonexistent target" + ); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_count_feedback_by_target_id_invalid_uuid() { + let http_client = Client::new(); + + let url = get_gateway_endpoint("/internal/feedback/not-a-valid-uuid/count"); + let resp = http_client.get(url).send().await.unwrap(); + + assert_eq!( + resp.status(), + reqwest::StatusCode::BAD_REQUEST, + "Expected 400 for invalid UUID" + ); +} diff --git a/ui/app/routes/observability/episodes/$episode_id/route.tsx b/ui/app/routes/observability/episodes/$episode_id/route.tsx index 822136301f..b277f8d5dc 100644 --- a/ui/app/routes/observability/episodes/$episode_id/route.tsx +++ b/ui/app/routes/observability/episodes/$episode_id/route.tsx @@ -1,6 +1,5 @@ import { listInferencesWithPagination } from "~/utils/clickhouse/inference.server"; import { pollForFeedbackItem } from "~/utils/clickhouse/feedback"; -import { getNativeDatabaseClient } from "~/utils/tensorzero/native_client.server"; import { getTensorZeroClient } from "~/utils/tensorzero.server"; import type { Route } from "./+types/route"; import { @@ -94,16 +93,14 @@ export async function loader({ request, params }: Route.LoaderArgs) { throw data("Limit cannot exceed 100", { status: 400 }); } - const dbClient = await getNativeDatabaseClient(); const tensorZeroClient = getTensorZeroClient(); // Start count queries early - these will be streamed to section headers const numInferencesPromise = tensorZeroClient .getEpisodeInferenceCount(episode_id) .then((response) => response.inference_count); - const numFeedbacksPromise = dbClient.countFeedbackByTargetId({ - target_id: episode_id, - }); + const numFeedbacksPromise = + tensorZeroClient.countFeedbackByTargetId(episode_id); // Stream inferences data - will be resolved in the component // Throws error if no inferences found (episode doesn't exist) diff --git a/ui/app/utils/tensorzero/tensorzero.ts b/ui/app/utils/tensorzero/tensorzero.ts index e1c6f4da8d..97b533338d 100644 --- a/ui/app/utils/tensorzero/tensorzero.ts +++ b/ui/app/utils/tensorzero/tensorzero.ts @@ -14,6 +14,7 @@ import { import { GatewayConnectionError, TensorZeroServerError } from "./errors"; import type { CloneDatapointsResponse, + CountFeedbackByTargetIdResponse, CountInferencesRequest, CountInferencesResponse, CountModelsResponse, @@ -1332,6 +1333,23 @@ export class TensorZeroClient { ); } + /** + * Queries the count of feedback for a given target ID. + * @param targetId - The target ID (inference_id or episode_id) to count feedback for + * @returns A promise that resolves with the feedback count + * @throws Error if the request fails + */ + async countFeedbackByTargetId(targetId: string): Promise { + const endpoint = `/internal/feedback/${encodeURIComponent(targetId)}/count`; + const response = await this.fetch(endpoint, { method: "GET" }); + if (!response.ok) { + const message = await this.getErrorText(response); + this.handleHttpError({ message, response }); + } + const body = (await response.json()) as CountFeedbackByTargetIdResponse; + return Number(body.count); + } + /** * Gets information about specific evaluation runs. * @param evaluationRunIds - Array of evaluation run UUIDs to query