Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions gateway/src/routes/internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ pub fn build_internal_non_otel_enabled_routes() -> Router<AppStateData> {
"/internal/feedback/{target_id}/latest-id-by-metric",
get(endpoints::feedback::internal::get_latest_feedback_id_by_metric_handler),
)
.route(
"/internal/feedback/{target_id}/count",
get(endpoints::feedback::internal::count_feedback_by_target_id_handler),
)
.route(
"/internal/model_inferences/{inference_id}",
get(endpoints::internal::model_inferences::get_model_inferences_handler),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.

export type CountFeedbackByTargetIdParams = { target_id: string };
export type CountFeedbackByTargetIdResponse = { count: bigint };
2 changes: 1 addition & 1 deletion internal/tensorzero-node/lib/bindings/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ export * from "./CommentFeedbackRow";
export * from "./CommentTargetType";
export * from "./ContentBlockChatOutput";
export * from "./ContentBlockOutput";
export * from "./CountFeedbackByTargetIdParams";
export * from "./CountFeedbackByTargetIdResponse";
export * from "./CountInferencesRequest";
export * from "./CountInferencesResponse";
export * from "./CountModelsResponse";
Expand Down
10 changes: 0 additions & 10 deletions internal/tensorzero-node/lib/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import type {
OptimizationJobHandle,
OptimizationJobInfo,
StaleDatasetResponse,
CountFeedbackByTargetIdParams,
QueryDemonstrationFeedbackByInferenceIdParams,
DemonstrationFeedbackRow,
GetCumulativeFeedbackTimeseriesParams,
Expand Down Expand Up @@ -232,15 +231,6 @@ export class DatabaseClient {
) as CumulativeFeedbackTimeSeriesPoint[];
}

async countFeedbackByTargetId(
params: CountFeedbackByTargetIdParams,
): Promise<number> {
const paramsString = safeStringify(params);
const countString =
await this.nativeDatabaseClient.countFeedbackByTargetId(paramsString);
return JSON.parse(countString) as number;
}

async getFeedbackByVariant(
params: GetFeedbackByVariantParams,
): Promise<FeedbackByVariant[]> {
Expand Down
16 changes: 0 additions & 16 deletions internal/tensorzero-node/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,6 @@ impl DatabaseClient {
)
}

#[napi]
pub async fn count_feedback_by_target_id(&self, params: String) -> Result<String, napi::Error> {
napi_call!(
&self,
count_feedback_by_target_id,
params,
CountFeedbackByTargetIdParams { target_id }
)
}

#[napi]
pub async fn query_demonstration_feedback_by_inference_id(
&self,
Expand Down Expand Up @@ -100,12 +90,6 @@ struct QueryDemonstrationFeedbackByInferenceIdParams {
limit: Option<u32>,
}

#[derive(Deserialize, ts_rs::TS)]
#[ts(export, optional_fields)]
struct CountFeedbackByTargetIdParams {
target_id: Uuid,
}

#[derive(Deserialize, ts_rs::TS)]
#[ts(export, optional_fields)]
struct GetFeedbackByVariantParams {
Expand Down
93 changes: 93 additions & 0 deletions tensorzero-core/src/endpoints/feedback/internal/count_feedback.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
//! Feedback endpoint for counting feedback by target ID

use axum::extract::{Path, State};
use axum::{Json, debug_handler};
use serde::{Deserialize, Serialize};
use tracing::instrument;
use uuid::Uuid;

use crate::db::feedback::FeedbackQueries;
use crate::error::Error;
use crate::utils::gateway::{AppState, AppStateData};

#[derive(Debug, Serialize, Deserialize, ts_rs::TS)]
#[ts(export)]
pub struct CountFeedbackByTargetIdResponse {
pub count: u64,
}

/// HTTP handler for counting feedback by target ID
#[debug_handler(state = AppStateData)]
#[instrument(
name = "count_feedback_by_target_id_handler",
skip_all,
fields(
target_id = %target_id,
)
)]
pub async fn count_feedback_by_target_id_handler(
State(app_state): AppState,
Path(target_id): Path<Uuid>,
) -> Result<Json<CountFeedbackByTargetIdResponse>, Error> {
let response =
count_feedback_by_target_id(&app_state.clickhouse_connection_info, target_id).await?;
Ok(Json(response))
}

/// Core business logic for counting feedback by target ID
pub async fn count_feedback_by_target_id(
clickhouse: &impl FeedbackQueries,
target_id: Uuid,
) -> Result<CountFeedbackByTargetIdResponse, Error> {
let count = clickhouse.count_feedback_by_target_id(target_id).await?;

Ok(CountFeedbackByTargetIdResponse { count })
}

#[cfg(test)]
mod tests {
use super::*;
use crate::db::feedback::MockFeedbackQueries;

#[tokio::test]
async fn test_count_feedback_by_target_id_returns_count() {
let mut mock_clickhouse = MockFeedbackQueries::new();
let target_id = Uuid::now_v7();

mock_clickhouse
.expect_count_feedback_by_target_id()
.withf(move |id| *id == target_id)
.times(1)
.returning(|_| Box::pin(async move { Ok(42) }));

let response = count_feedback_by_target_id(&mock_clickhouse, target_id)
.await
.expect("Expected count to be returned");

assert_eq!(
response.count, 42,
"Expected count to be 42 from ClickHouse"
);
}

#[tokio::test]
async fn test_count_feedback_by_target_id_returns_zero_for_no_feedback() {
let mut mock_clickhouse = MockFeedbackQueries::new();
let target_id = Uuid::now_v7();

mock_clickhouse
.expect_count_feedback_by_target_id()
.withf(move |id| *id == target_id)
.times(1)
.returning(|_| Box::pin(async move { Ok(0) }));

let response = count_feedback_by_target_id(&mock_clickhouse, target_id)
.await
.expect("Expected zero count response");

assert_eq!(
response.count, 0,
"Expected count to be 0 when no feedback exists"
);
}
}
2 changes: 2 additions & 0 deletions tensorzero-core/src/endpoints/feedback/internal/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
mod count_feedback;
mod get_feedback_bounds;
mod get_feedback_by_target_id;
mod latest_feedback_by_metric;

pub use count_feedback::*;
pub use get_feedback_bounds::*;
pub use get_feedback_by_target_id::*;
pub use latest_feedback_by_metric::*;
89 changes: 88 additions & 1 deletion tensorzero-core/tests/e2e/endpoints/internal/feedback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use reqwest::Client;
use serde_json::json;
use std::collections::HashMap;
use tensorzero_core::endpoints::feedback::internal::{
GetFeedbackBoundsResponse, GetFeedbackByTargetIdResponse, LatestFeedbackIdByMetricResponse,
CountFeedbackByTargetIdResponse, GetFeedbackBoundsResponse, GetFeedbackByTargetIdResponse,
LatestFeedbackIdByMetricResponse,
};
use uuid::Uuid;

Expand Down Expand Up @@ -393,3 +394,89 @@ async fn test_get_feedback_by_target_id_rejects_both_before_and_after() {
"Expected 400 when both before and after are specified"
);
}

// ==================== Count Feedback By Target ID Tests ====================

#[tokio::test(flavor = "multi_thread")]
async fn test_count_feedback_by_target_id_with_feedback() {
let http_client = Client::new();

let inference_id = create_inference(&http_client, "basic_test").await;
let _feedback_id =
submit_inference_feedback(&http_client, inference_id, "task_success", json!(true)).await;

tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await;

let url = get_gateway_endpoint(&format!("/internal/feedback/{inference_id}/count"));
let resp = http_client.get(url).send().await.unwrap();

assert!(
resp.status().is_success(),
"Expected success when counting feedback"
);

let response: CountFeedbackByTargetIdResponse = resp.json().await.unwrap();
assert_eq!(
response.count, 1,
"Expected count to be 1 for single feedback entry"
);
}

#[tokio::test(flavor = "multi_thread")]
async fn test_count_feedback_by_target_id_multiple_feedback() {
let http_client = Client::new();

let inference_id = create_inference(&http_client, "basic_test").await;

// Submit multiple feedback entries
submit_inference_feedback(&http_client, inference_id, "task_success", json!(true)).await;
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
submit_inference_feedback(&http_client, inference_id, "task_success", json!(false)).await;

tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await;

let url = get_gateway_endpoint(&format!("/internal/feedback/{inference_id}/count"));
let resp = http_client.get(url).send().await.unwrap();

assert!(resp.status().is_success());
let response: CountFeedbackByTargetIdResponse = resp.json().await.unwrap();

assert_eq!(
response.count, 2,
"Expected count to be 2 for multiple feedback entries"
);
}

#[tokio::test(flavor = "multi_thread")]
async fn test_count_feedback_by_target_id_nonexistent_target() {
let http_client = Client::new();
let nonexistent_id = Uuid::now_v7();

let url = get_gateway_endpoint(&format!("/internal/feedback/{nonexistent_id}/count"));
let resp = http_client.get(url).send().await.unwrap();

assert!(
resp.status().is_success(),
"Expected success for nonexistent target"
);

let response: CountFeedbackByTargetIdResponse = resp.json().await.unwrap();
assert_eq!(
response.count, 0,
"Expected count to be 0 for nonexistent target"
);
}

#[tokio::test(flavor = "multi_thread")]
async fn test_count_feedback_by_target_id_invalid_uuid() {
let http_client = Client::new();

let url = get_gateway_endpoint("/internal/feedback/not-a-valid-uuid/count");
let resp = http_client.get(url).send().await.unwrap();

assert_eq!(
resp.status(),
reqwest::StatusCode::BAD_REQUEST,
"Expected 400 for invalid UUID"
);
}
7 changes: 2 additions & 5 deletions ui/app/routes/observability/episodes/$episode_id/route.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import { listInferencesWithPagination } from "~/utils/clickhouse/inference.server";
import { pollForFeedbackItem } from "~/utils/clickhouse/feedback";
import { getNativeDatabaseClient } from "~/utils/tensorzero/native_client.server";
import { getTensorZeroClient } from "~/utils/tensorzero.server";
import type { Route } from "./+types/route";
import {
Expand Down Expand Up @@ -94,16 +93,14 @@ export async function loader({ request, params }: Route.LoaderArgs) {
throw data("Limit cannot exceed 100", { status: 400 });
}

const dbClient = await getNativeDatabaseClient();
const tensorZeroClient = getTensorZeroClient();

// Start count queries early - these will be streamed to section headers
const numInferencesPromise = tensorZeroClient
.getEpisodeInferenceCount(episode_id)
.then((response) => response.inference_count);
const numFeedbacksPromise = dbClient.countFeedbackByTargetId({
target_id: episode_id,
});
const numFeedbacksPromise =
tensorZeroClient.countFeedbackByTargetId(episode_id);

// Stream inferences data - will be resolved in the component
// Throws error if no inferences found (episode doesn't exist)
Expand Down
18 changes: 18 additions & 0 deletions ui/app/utils/tensorzero/tensorzero.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import {
import { GatewayConnectionError, TensorZeroServerError } from "./errors";
import type {
CloneDatapointsResponse,
CountFeedbackByTargetIdResponse,
CountInferencesRequest,
CountInferencesResponse,
CountModelsResponse,
Expand Down Expand Up @@ -1332,6 +1333,23 @@ export class TensorZeroClient {
);
}

/**
* Queries the count of feedback for a given target ID.
* @param targetId - The target ID (inference_id or episode_id) to count feedback for
* @returns A promise that resolves with the feedback count
* @throws Error if the request fails
*/
async countFeedbackByTargetId(targetId: string): Promise<number> {
const endpoint = `/internal/feedback/${encodeURIComponent(targetId)}/count`;
const response = await this.fetch(endpoint, { method: "GET" });
if (!response.ok) {
const message = await this.getErrorText(response);
this.handleHttpError({ message, response });
}
const body = (await response.json()) as CountFeedbackByTargetIdResponse;
return Number(body.count);
}

/**
* Gets information about specific evaluation runs.
* @param evaluationRunIds - Array of evaluation run UUIDs to query
Expand Down
Loading