diff --git a/.cursor/commands/spdd-api-test.md b/.cursor/commands/spdd-api-test.md new file mode 100644 index 0000000..fc6269d --- /dev/null +++ b/.cursor/commands/spdd-api-test.md @@ -0,0 +1,657 @@ +--- +name: /spdd-api-test +id: spdd-api-test +category: Testing +description: Generate a self-contained shell script with cURL commands to test API endpoints based on generated code and acceptance criteria +--- + +Generate a comprehensive, self-contained shell script (`scripts/test-api.sh`) with cURL commands to test all API scenarios defined in the codebase or acceptance criteria. + +**Key Feature: Structured Test Case Tables** + +The generated script includes human-reviewable test case tables: +1. **At script top**: A structured table showing all test scenarios, inputs, and expected outputs +2. **After execution**: A results table showing expected vs actual results with pass/fail status + +This makes it easy for humans to: +- Review test coverage at a glance +- Verify expected values are correct +- Quickly identify which tests failed and why + +**Input**: The argument after `/spdd-api-test` is a reference to generated code, acceptance criteria document, or API specification. + +Input can be provided in several ways: + +1. **File/folder reference**: Using `@` to reference files containing API implementations or ACs +2. **Text description**: Direct text describing the API endpoints to test +3. **Combined**: Both file references and additional context + +**Examples**: + +``` +# Reference to generated code +/spdd-api-test @src/main/java/com/example/controller/BillingController.java + +# Reference to acceptance criteria +/spdd-api-test @spdd/prompt/GGQPA-XXX-202603131530-[Feat]-token-usage-billing.md + +# Reference to multiple files +/spdd-api-test @src/controllers/ @requirements/api-spec.md + +# With additional context +/spdd-api-test @src/api/ test the billing endpoints with edge cases for zero tokens +``` + +**Steps** + +1. **Validate and consolidate input context** + + a. **If no input provided**, use the **AskUserQuestion tool** (open-ended, no preset options) to ask: + > "Please provide the API implementation files, acceptance criteria, or API specification to generate tests for (you can use @file references or text description)." + + **IMPORTANT**: Do NOT proceed without input context. + + b. **If input contains `@` file/folder references**: + - Read ALL referenced files completely using the Read tool + - For folder references, read all relevant API-related files (controllers, routes, handlers) + - Consolidate all file contents into a unified API context + + c. **Extract API information**: + - Identify all API endpoints (HTTP method, path, request body schema) + - Extract acceptance criteria and expected behaviors + - Identify validation rules and error scenarios + - Note any seed data or test data requirements + +2. **Identify seed data and test fixtures** + + a. **Search for existing seed data**: + - Look for seed SQL files, fixtures, or test data in common locations: + - `src/main/resources/db/migration/` + - `src/test/resources/` + - `fixtures/` + - `seed/` + - `scripts/seed*.sql` + - Read relevant seed files to understand available test data + + b. **Document available test entities**: + - Customer IDs, names, and states + - Plan IDs, pricing configurations + - Quota configurations and limits + - Any other domain-specific test data + + c. **If no seed data found**, derive test data from: + - Entity definitions in the codebase + - Example values in acceptance criteria + - Common test data patterns + +3. **Design structured test case tables** + + Before generating the script, organize test cases into logical tables by test pattern. + + **Group tests by their input/output structure**: + + a. **Validation Error Tests** (expect HTTP 400/404): + ``` + | Test ID | Description | Customer | Model | Prompt | Completion | HTTP | Expected Error | + |---------|--------------------------|--------------|------------|--------|------------|------|--------------------------| + | AC1.1 | Missing modelId | CUST-001 | (missing) | 1000 | 500 | 400 | Model ID is required | + | AC1.2 | Missing customerId | (missing) | fast-model | 1000 | 500 | 400 | Customer ID is required | + | AC1.6 | Non-existent customer | NON-EXISTENT | fast-model | 1000 | 500 | 404 | Customer not found | + ``` + + b. **Standard Plan Tests** (quota-based billing): + ``` + | Test ID | Description | Customer | Model | Prompt | Completion | HTTP | IncludedUsed | Overage | TotalCharge | + |---------|----------------------|----------|------------|--------|------------|------|--------------|---------|-------------| + | AC2.1 | Within quota | CUST-001 | fast-model | 1000 | 500 | 201 | 1500 | 0 | 0.00 | + | AC2.3 | Exceeds small quota | CUST-002 | fast-model | 10000 | 5000 | 201 | 10000 | 5000 | 0.15 | + ``` + + c. **Premium Plan Tests** (split-rate billing): + ``` + | Test ID | Description | Customer | Model | Prompt | Completion | HTTP | PromptCharge | CompletionCharge | TotalCharge | + |---------|---------------------|--------------|-----------------|--------|------------|------|--------------|------------------|-------------| + | AC3.1 | fast-model billing | CUST-PREMIUM | fast-model | 10000 | 5000 | 201 | 0.10 | 0.10 | 0.20 | + | AC3.2 | reasoning-model | CUST-PREMIUM | reasoning-model | 10000 | 20000 | 201 | 0.30 | 1.20 | 1.50 | + ``` + + d. **Special/Structural Tests** (checks that don't fit tabular pattern): + - Keep these as descriptive test cases in the script + +4. **Generate test script structure** + + Create the shell script with the following structure: + + ```bash + #!/bin/bash + # ============================================================================= + # API Test Script + # Generated for: [API/Feature Name] + # ============================================================================= + # + # Usage: ./scripts/test-api.sh [BASE_URL] + # Default BASE_URL: http://localhost:8080 + # + # Requirements: + # - No external dependencies (no jq, only curl and bash) + # - Each request has -m 10 timeout to prevent hanging + # - HTTP status captured via: -o /tmp/response.txt -w "%{http_code}" + # + # ============================================================================= + # + # TEST CASE OVERVIEW (Human-Reviewable) + # ============================================================================= + # + # ┌─────────────────────────────────────────────────────────────────────────────┐ + # │ VALIDATION ERROR TESTS │ + # ├─────────┬──────────────────────┬──────────┬────────┬────────┬──────┬────────┤ + # │ Test ID │ Description │ Customer │ Model │ Prompt │ Comp │ HTTP │ + # ├─────────┼──────────────────────┼──────────┼────────┼────────┼──────┼────────┤ + # │ AC1.1 │ Missing modelId │ CUST-001 │ - │ 1000 │ 500 │ 400 │ + # │ AC1.2 │ Missing customerId │ - │ fast │ 1000 │ 500 │ 400 │ + # │ AC1.6 │ Non-existent customer│ INVALID │ fast │ 1000 │ 500 │ 404 │ + # └─────────┴──────────────────────┴──────────┴────────┴────────┴──────┴────────┘ + # + # ┌─────────────────────────────────────────────────────────────────────────────┐ + # │ STANDARD PLAN TESTS (Quota-based billing) │ + # ├─────────┬────────────────┬──────────┬────────┬────────┬──────┬──────┬───────┤ + # │ Test ID │ Description │ Customer │ Model │ Prompt │ Comp │ HTTP │ Charge│ + # ├─────────┼────────────────┼──────────┼────────┼────────┼──────┼──────┼───────┤ + # │ AC2.1 │ Within quota │ CUST-001 │ fast │ 1000 │ 500 │ 201 │ 0.00 │ + # │ AC2.3 │ Exceeds quota │ CUST-002 │ fast │ 10000 │ 5000 │ 201 │ 0.15 │ + # └─────────┴────────────────┴──────────┴────────┴────────┴──────┴──────┴───────┘ + # + # ┌─────────────────────────────────────────────────────────────────────────────┐ + # │ PREMIUM PLAN TESTS (Split prompt/completion billing) │ + # ├─────────┬────────────────┬──────────┬─────────┬────────┬──────┬──────┬──────┤ + # │ Test ID │ Description │ Customer │ Model │ Prompt │ Comp │ HTTP │ Total│ + # ├─────────┼────────────────┼──────────┼─────────┼────────┼──────┼──────┼──────┤ + # │ AC3.1 │ fast-model │ PREMIUM │ fast │ 10000 │ 5000 │ 201 │ 0.20 │ + # │ AC3.2 │ reasoning-model│ PREMIUM │ reason │ 10000 │ 20000│ 201 │ 1.50 │ + # └─────────┴────────────────┴──────────┴─────────┴────────┴──────┴──────┴──────┘ + # + # ============================================================================= + + # ----------------------------------------------------------------------------- + # CONFIGURATION + # ----------------------------------------------------------------------------- + BASE_URL="${1:-http://localhost:8080}" + + # Colors for output (disabled if not a terminal) + if [ -t 1 ]; then + RED='\033[0;31m' + GREEN='\033[0;32m' + YELLOW='\033[1;33m' + BLUE='\033[0;34m' + CYAN='\033[0;36m' + NC='\033[0m' # No Color + else + RED='' + GREEN='' + YELLOW='' + BLUE='' + CYAN='' + NC='' + fi + + # ----------------------------------------------------------------------------- + # SEED DATA REFERENCE + # ----------------------------------------------------------------------------- + # [Document available test data here] + # Customers: + # - ID: xxx, Name: xxx, Status: xxx + # Plans: + # - ID: xxx, Name: xxx, Rate: xxx + # Quotas: + # - ID: xxx, Limit: xxx, Reset: xxx + # ----------------------------------------------------------------------------- + + # ----------------------------------------------------------------------------- + # TEST COUNTERS AND RESULT TRACKING + # ----------------------------------------------------------------------------- + TESTS_PASSED=0 + TESTS_FAILED=0 + TESTS_TOTAL=0 + + # Arrays to track results for final summary table + declare -a TEST_IDS + declare -a TEST_DESCRIPTIONS + declare -a EXPECTED_STATUS + declare -a ACTUAL_STATUS + declare -a TEST_RESULTS + + # ----------------------------------------------------------------------------- + # HELPER FUNCTIONS + # ----------------------------------------------------------------------------- + print_test_header() { + echo "" + echo -e "${BLUE}═══════════════════════════════════════════════════════════════${NC}" + echo -e "${BLUE}TEST: $1${NC}" + echo -e "${BLUE}═══════════════════════════════════════════════════════════════${NC}" + } + + print_expected() { + echo -e "${YELLOW}Expected: $1${NC}" + } + + print_result() { + echo -e "${GREEN}Response:${NC}" + } + + # Record test result for final summary table + # Usage: record_result "Test ID" "Description" "Expected" "Actual" "PASS|FAIL" + record_result() { + TEST_IDS+=("$1") + TEST_DESCRIPTIONS+=("$2") + EXPECTED_STATUS+=("$3") + ACTUAL_STATUS+=("$4") + TEST_RESULTS+=("$5") + } + + # Check test result - called after each curl command + # Usage: check_result "Test ID" "Test Description" "Expected Status" "$HTTP_CODE" "$BODY" + check_result() { + local test_id="$1" + local test_desc="$2" + local expected_status="$3" + local actual_status="$4" + local body="$5" + + echo "$body" + echo "" + + if [ "$actual_status" = "$expected_status" ]; then + echo -e "${GREEN}✓ PASSED${NC} [HTTP Status: $actual_status]" + TESTS_PASSED=$((TESTS_PASSED + 1)) + record_result "$test_id" "$test_desc" "$expected_status" "$actual_status" "PASS" + else + echo -e "${RED}✗ FAILED${NC} [HTTP Status: $actual_status, Expected: $expected_status]" + TESTS_FAILED=$((TESTS_FAILED + 1)) + record_result "$test_id" "$test_desc" "$expected_status" "$actual_status" "FAIL" + fi + echo "" + } + + # Print final results table + print_results_table() { + echo "" + echo -e "${CYAN}┌─────────────────────────────────────────────────────────────────────────────┐${NC}" + echo -e "${CYAN}│ TEST RESULTS SUMMARY │${NC}" + echo -e "${CYAN}├──────────┬────────────────────────────────┬──────────┬──────────┬──────────┤${NC}" + echo -e "${CYAN}│ Test ID │ Description │ Expected │ Actual │ Result │${NC}" + echo -e "${CYAN}├──────────┼────────────────────────────────┼──────────┼──────────┼──────────┤${NC}" + + for i in "${!TEST_IDS[@]}"; do + local result_color="${GREEN}" + if [ "${TEST_RESULTS[$i]}" = "FAIL" ]; then + result_color="${RED}" + fi + printf "${CYAN}│${NC} %-8s ${CYAN}│${NC} %-30s ${CYAN}│${NC} %-8s ${CYAN}│${NC} %-8s ${CYAN}│${NC} ${result_color}%-8s${NC} ${CYAN}│${NC}\n" \ + "${TEST_IDS[$i]}" \ + "${TEST_DESCRIPTIONS[$i]:0:30}" \ + "${EXPECTED_STATUS[$i]}" \ + "${ACTUAL_STATUS[$i]}" \ + "${TEST_RESULTS[$i]}" + done + + echo -e "${CYAN}└──────────┴────────────────────────────────┴──────────┴──────────┴──────────┘${NC}" + } + + # ----------------------------------------------------------------------------- + # TEST CASES + # ----------------------------------------------------------------------------- + ``` + + **IMPORTANT**: Do NOT use `eval` or `run_test` wrapper function with complex quoting. + Instead, use direct curl calls for each test case (see test case format below). + +4. **Generate test cases for each acceptance criterion** + + For each identified acceptance criterion or API endpoint: + + a. **Happy path tests**: + - Valid requests with expected successful responses + - Test with available seed data + + b. **Validation error tests**: + - Missing required fields + - Invalid field formats + - Out-of-range values + + c. **Edge case tests**: + - Zero values (e.g., zero tokens, zero amount) + - Empty strings or null values + - Boundary conditions + + d. **Error scenario tests**: + - Not found resources (invalid IDs) + - Conflict scenarios + - Unauthorized access (if applicable) + + **Test case format** (use direct curl calls, NOT eval-based wrapper): + + ```bash + # ----------------------------------------------------------------------------- + # AC[N]: [Acceptance Criterion Description] + # ----------------------------------------------------------------------------- + TEST_ID="AC[N]" + TEST_DESC="[Short Description]" + EXPECTED="[Expected HTTP Status]" + TESTS_TOTAL=$((TESTS_TOTAL + 1)) + print_test_header "$TEST_ID: $TEST_DESC" + print_expected "HTTP $EXPECTED" + print_result + HTTP_CODE=$(curl -s -o /tmp/response.txt -w "%{http_code}" -X [METHOD] "${BASE_URL}/api/[endpoint]" \ + -H "Content-Type: application/json" \ + -m 10 \ + -d '{"[field]": "[value]"}') + BODY=$(cat /tmp/response.txt) + check_result "$TEST_ID" "$TEST_DESC" "$EXPECTED" "$HTTP_CODE" "$BODY" + ``` + +5. **Add edge case and negative tests** + + Include additional tests beyond acceptance criteria: + + ```bash + # ----------------------------------------------------------------------------- + # EDGE CASES + # ----------------------------------------------------------------------------- + + # Edge Case: Zero tokens + TEST_ID="EDGE1" + TEST_DESC="Zero Token Count" + EXPECTED="[Expected Status]" + TESTS_TOTAL=$((TESTS_TOTAL + 1)) + print_test_header "$TEST_ID: $TEST_DESC" + print_expected "HTTP $EXPECTED" + print_result + HTTP_CODE=$(curl -s -o /tmp/response.txt -w "%{http_code}" -X POST "${BASE_URL}/api/[endpoint]" \ + -H "Content-Type: application/json" \ + -m 10 \ + -d '{"tokenCount": 0}') + BODY=$(cat /tmp/response.txt) + check_result "$TEST_ID" "$TEST_DESC" "$EXPECTED" "$HTTP_CODE" "$BODY" + + # Edge Case: Missing required field + TEST_ID="EDGE2" + TEST_DESC="Missing Required Field" + EXPECTED="400" + TESTS_TOTAL=$((TESTS_TOTAL + 1)) + print_test_header "$TEST_ID: $TEST_DESC" + print_expected "HTTP $EXPECTED" + print_result + HTTP_CODE=$(curl -s -o /tmp/response.txt -w "%{http_code}" -X POST "${BASE_URL}/api/[endpoint]" \ + -H "Content-Type: application/json" \ + -m 10 \ + -d '{}') + BODY=$(cat /tmp/response.txt) + check_result "$TEST_ID" "$TEST_DESC" "$EXPECTED" "$HTTP_CODE" "$BODY" + + # Edge Case: Invalid ID format + TEST_ID="EDGE3" + TEST_DESC="Invalid ID Format" + EXPECTED="404" + TESTS_TOTAL=$((TESTS_TOTAL + 1)) + print_test_header "$TEST_ID: $TEST_DESC" + print_expected "HTTP $EXPECTED" + print_result + HTTP_CODE=$(curl -s -o /tmp/response.txt -w "%{http_code}" -X GET "${BASE_URL}/api/[endpoint]/invalid-id" \ + -H "Content-Type: application/json" \ + -m 10) + BODY=$(cat /tmp/response.txt) + check_result "$TEST_ID" "$TEST_DESC" "$EXPECTED" "$HTTP_CODE" "$BODY" + ``` + +6. **Add cleanup and test summary footer** + + ```bash + # ----------------------------------------------------------------------------- + # CLEANUP + # ----------------------------------------------------------------------------- + rm -f /tmp/response.txt + + # ----------------------------------------------------------------------------- + # TEST SUMMARY + # ----------------------------------------------------------------------------- + echo "" + echo -e "${BLUE}═══════════════════════════════════════════════════════════════${NC}" + echo -e "${BLUE}TEST EXECUTION COMPLETE${NC}" + echo -e "${BLUE}═══════════════════════════════════════════════════════════════${NC}" + echo "" + echo "Base URL: ${BASE_URL}" + echo "Finished at: $(date)" + echo "" + + # Print structured results table + print_results_table + + echo "" + echo -e "Tests Passed: ${GREEN}${TESTS_PASSED}${NC}" + echo -e "Tests Failed: ${RED}${TESTS_FAILED}${NC}" + echo -e "Total Tests: ${TESTS_TOTAL}" + echo "" + + # Calculate pass rate + if [ "$TESTS_TOTAL" -gt 0 ]; then + PASS_RATE=$((TESTS_PASSED * 100 / TESTS_TOTAL)) + if [ "$TESTS_FAILED" -eq 0 ]; then + echo -e "${GREEN}✓ All tests passed! (${PASS_RATE}%)${NC}" + else + echo -e "${RED}✗ Some tests failed (${PASS_RATE}% passed)${NC}" + fi + fi + echo "" + + # Exit with error code if any tests failed + if [ "$TESTS_FAILED" -gt 0 ]; then + exit 1 + fi + ``` + +7. **Write script to file and make executable** + + a. **Create scripts directory if needed**: + - Ensure `scripts/` directory exists under the project root + + b. **Write the complete script**: + - Save to `scripts/test-api.sh` + + c. **Make script executable**: + - Run `chmod +x scripts/test-api.sh` + +8. **Report generation summary** + + ``` + ✅ API test script generated and saved to `scripts/test-api.sh` + + 📋 Test coverage: + - Acceptance Criteria tests: [count] + - Edge case tests: [count] + - Negative tests: [count] + - Total test cases: [count] + + 📊 Seed data referenced: + - Customers: [count] + - Plans: [count] + - [Other entities]: [count] + + 👀 Human Review Features: + - Test case overview table at script top (view before running) + - Results summary table after execution (expected vs actual) + - Each test has unique ID for easy tracking + + 📈 Test result tracking: + - Pass/Fail counters with color-coded output + - Pass rate percentage calculation + - Exit code: 0 (all passed) or 1 (some failed) + + 🚀 Usage: + ./scripts/test-api.sh # Uses localhost:8080 + ./scripts/test-api.sh http://api:3000 # Custom base URL + + 🔄 CI/CD Integration: + The script returns exit code 1 if any tests fail, + making it suitable for use in CI/CD pipelines. + ``` + +**Output** + +A self-contained, executable shell script (`scripts/test-api.sh`) with: +- **Structured test case table at script top** - Human-reviewable overview of all scenarios +- **Structured results table after execution** - Expected vs actual with pass/fail status +- cURL commands for all acceptance criteria scenarios +- Edge case and negative test scenarios +- Clear test output formatting +- No external dependencies (bash + curl only) +- Timeout protection (`-m 10`) on all requests +- HTTP status capture via `-o /tmp/response.txt -w "%{http_code}"` +- Documented seed data reference +- **Success/Failure counting** with pass rate calculation +- **Exit code** reflecting test results (0 = all passed, 1 = some failed) +- **Cleanup** of temp files after execution + +**Script Requirements Checklist** + +| Requirement | Implementation | +|-------------|----------------| +| **Test case overview table** | Commented table at script top showing all scenarios, inputs, expected outputs | +| **Results summary table** | `print_results_table` function showing expected vs actual after execution | +| Timeout | `-m 10` on every curl command | +| No external dependencies | Pure bash + curl, no jq/yq | +| HTTP status capture | `-o /tmp/response.txt -w "%{http_code}"` (NOT eval-based) | +| Test coverage | All ACs + edge cases + negative tests | +| Clear formatting | Header/expected/result for each test | +| Seed data reference | Commented at script top | +| Executable | `chmod +x` applied | +| Success/Failure counting | `TESTS_PASSED`, `TESTS_FAILED` counters with summary | +| Exit code | Exit 1 if any tests failed, exit 0 otherwise | +| Cleanup | `rm -f /tmp/response.txt` at end of script | + +**Guardrails** + +- Do NOT proceed without input context (API files or ACs) +- Do NOT use external tools like `jq`, `yq`, or `python` in the generated script +- Do NOT generate tests without proper timeout (`-m 10`) +- Do NOT skip edge cases (zero values, missing fields, invalid IDs) +- Do NOT hardcode test data without documenting it in the seed reference section +- **Do NOT use `eval` with `-w "\n%{http_code}"` — this causes shell quoting issues** +- **Do NOT use special characters (like `|`) in `-w` format strings** +- Always use `-o /tmp/response.txt -w "%{http_code}"` for HTTP status capture +- Always use `"${BASE_URL}"` variable for endpoint URLs +- Always make the script executable with `chmod +x` +- Always create `scripts/` directory if it does not exist +- Always cleanup temp files (`rm -f /tmp/response.txt`) at end of script +- Error messages in expected results MUST match actual API error messages +- Test data IDs MUST reference actual seed data or clearly indicate synthetic data + +**Context Integrity Guardrails**: + +- **MUST read ALL `@` referenced files completely** — do NOT skip or partially read +- **MUST search for seed data** in common locations before generating tests +- **Verify API endpoint paths** match actual implementation +- **Verify request body schemas** match actual DTOs/request objects +- **Verify error messages** match actual exception handler responses + +**cURL Best Practices** + +1. **Always include these flags**: + - `-s` (silent mode, suppress progress) + - `-m 10` (10 second timeout) + - `-o /tmp/response.txt` (write body to temp file) + - `-w "%{http_code}"` (capture only HTTP status code) + - `-H "Content-Type: application/json"` (for JSON bodies) + +2. **Request body formatting**: + - Use `-d` with single-quoted JSON strings + - Properly escape special characters in JSON + - Use single quotes around JSON to avoid shell interpolation issues + +3. **Variable interpolation**: + - Use `"${BASE_URL}"` with quotes to handle special characters + - Use `"${VARIABLE}"` syntax for all shell variables in URLs + +**CRITICAL: Shell Quoting Pitfalls to Avoid** + +When generating shell scripts with curl commands, avoid these common pitfalls: + +1. **DO NOT use `eval` with `-w "\n%{http_code}"`**: + - The `\n` escape sequence does not work reliably through `eval` + - It often outputs literal `n` instead of a newline, breaking HTTP status extraction + +2. **DO NOT use special characters in `-w` format string**: + - Characters like `|` will be interpreted as shell pipe operators + - Example: `-w '|||%{http_code}'` will fail with "syntax error near unexpected token `|'" + +3. **CORRECT APPROACH - Use temp file for response body**: + ```bash + # Capture HTTP status code directly, write body to temp file + HTTP_CODE=$(curl -s -o /tmp/response.txt -w "%{http_code}" -X POST "${BASE_URL}/api/endpoint" \ + -H "Content-Type: application/json" \ + -m 10 \ + -d '{"field": "value"}') + BODY=$(cat /tmp/response.txt) + ``` + +4. **Why this works**: + - `-o /tmp/response.txt` writes the response body to a file (not stdout) + - `-w "%{http_code}"` outputs ONLY the HTTP status code to stdout + - No complex string parsing or newline handling needed + - No `eval` required, avoiding all shell quoting issues + +**Integration with SPDD Workflow** + +This command is the **validation phase** of the SPDD workflow: + +``` +┌─────────────────────────────────────────────────────────────────────────┐ +│ SPDD Workflow │ +├─────────────────────────────────────────────────────────────────────────┤ +│ │ +│ Phase 1: /spdd-analysis │ +│ ┌────────────────────────────────────────────────────────────────┐ │ +│ │ Business Requirement → Enriched Context │ │ +│ └────────────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ Phase 2: /spdd-reasons-canvas │ +│ ┌────────────────────────────────────────────────────────────────┐ │ +│ │ Enriched Context → REASONS Canvas Structured Prompt │ │ +│ └────────────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ Phase 3: /spdd-generate │ +│ ┌────────────────────────────────────────────────────────────────┐ │ +│ │ Structured Prompt → Implementation Code │ │ +│ └────────────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ Phase 4: /spdd-api-test ← YOU ARE HERE │ +│ ┌────────────────────────────────────────────────────────────────┐ │ +│ │ Generated Code + ACs → API Test Script │ │ +│ │ │ │ +│ │ Output: scripts/test-api.sh │ │ +│ └────────────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ Phase 5: Validate & Iterate │ +│ ┌────────────────────────────────────────────────────────────────┐ │ +│ │ Run tests → Identify issues → Update via /spdd-sync │ │ +│ └────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────┘ +``` + +**Why This Phase Matters** + +API testing validates that the generated implementation: +1. **Meets acceptance criteria**: Each AC is verified with actual HTTP requests +2. **Handles edge cases**: Zero values, missing fields, invalid data are tested +3. **Returns correct errors**: Error messages and status codes match specifications +4. **Works end-to-end**: Real HTTP calls verify the full request/response cycle + +The self-contained nature ensures: +- Tests can run in any environment with just `bash` and `curl` +- No installation of additional tools required +- CI/CD pipelines can execute without setup complexity +- Results are immediately visible without parsing \ No newline at end of file diff --git a/requirements/[User-story-1]Multi-Plan-Billing-Foundation-&-Model-Aware-Pricing.md b/requirements/[User-story-1]Multi-Plan-Billing-Foundation-&-Model-Aware-Pricing.md new file mode 100644 index 0000000..6135e21 --- /dev/null +++ b/requirements/[User-story-1]Multi-Plan-Billing-Foundation-&-Model-Aware-Pricing.md @@ -0,0 +1,39 @@ +## Background + +As our LLM API platform scales, a single pricing model is no longer sufficient. We need to refactor our existing billing engine to support different subscription strategies and variable pricing based on the AI model invoked, laying the foundation for future complex billing plans. + +## Business Value + +1. **Flexible Monetization**: Support diverse billing strategies (Standard, Premium) to capture different market segments. +2. **Model-Aware Pricing**: Charge different rates based on the specific AI model used. +3. **Architecture Scalability**: Implement an extensible design (e.g., Strategy Pattern) to isolate calculation logic and easily add future pricing models. + +## Scope In + +- Enhance the existing POST /api/usage endpoint. +- **New Request Field:** Add modelId (required, string, e.g., "fast-model", "reasoning-model"). +- Implement a routing mechanism (Strategy/Factory Pattern) to handle distinct calculation formulas. +- Implement two initial Plan Types: + 1. **Standard Plan (Legacy Refactor):** Has a monthly global quota. Overage rates now depend on the modelId. + 2. **Premium Plan (New):** No quota. Prompt and Completion tokens are billed separately, and rates vary by modelId. + +## Scope Out + +- Complex tiered/volume-based discount logic (Deferred to Phase 2). +- Subscription plan creation and assignment CRUD. +- Invoice generation. + +## Acceptance Criteria (ACs) + +1. **Base Validations (Regression & New)** + **Given** an invalid request (e.g., missing `modelId`, negative tokens) + **When** backend validates request + **Then** return HTTP 400 with appropriate error messages. +2. **Standard Plan with Model-Aware Overage** + **Given** a "Standard" customer with a 100,000 monthly quota, 90,000 used so far. Overage for "fast-model" is `$0.01/1K`. + **When** submitting 30,000 tokens for "fast-model" + **Then** bill shows: 10,000 from quota, 20,000 overage, $0.20 charge. +3. **Premium Plan with Split Rates** + **Given** a "Premium" customer. For "reasoning-model", Prompt is `$0.03/1K`, Completion is `$0.06/1K`. + **When** submitting 10,000 prompt and 20,000 completion tokens for "reasoning-model" + **Then** bill shows: 0 from quota, `$0.30` prompt charge, `$1.20` completion charge, total `$1.50`. diff --git a/requirements/[User-story-2]Enterprise-Plan-Volume-Based-Tiered-Billing.md b/requirements/[User-story-2]Enterprise-Plan-Volume-Based-Tiered-Billing.md new file mode 100644 index 0000000..77ae559 --- /dev/null +++ b/requirements/[User-story-2]Enterprise-Plan-Volume-Based-Tiered-Billing.md @@ -0,0 +1,26 @@ +## Background + +Building on our newly refactored multi-plan billing engine, we need to introduce complex, volume-based tiered discounts specifically for our high-volume Enterprise customers. + +## Business Value + +1. **Incentivize Usage**: Encourage high-volume clients through automated tiered discounts. +2. **Competitive Enterprise Offering**: Provide enterprise-grade billing flexibility. + +## Scope In + +- Enhance the existing billing Strategy/Factory to support a new Enterprise plan type. +- Implement logic to query the customer's accumulated usage for the current month prior to calculating the current request's cost. +- Calculate cross-tier dynamic rates based on specific usage thresholds. + +## Scope Out + +- Foundational multi-plan refactoring (Completed in Story 1). +- Invoice generation. + +## Acceptance Criteria (ACs) + +1. **Enterprise Plan with Cross-Tier Calculation (Complex)** + **Given** an "Enterprise" customer using "fast-model". Tier 1 (0 to 50,000 tokens) is `$0.02/1K`. Tier 2 (above 50,000 tokens) is `$0.01/1K`. The customer has already used 40,000 tokens this month. + **When** submitting 30,000 tokens (which pushes the total to 70,000) + **Then** bill splits the usage: 10,000 tokens calculated at Tier 1 rate (`$0.20`), and 20,000 tokens calculated at Tier 2 rate (`$0.20`), for a total charge of `$0.40`. diff --git a/scripts/test-api.sh b/scripts/test-api.sh new file mode 100755 index 0000000..1f6d43c --- /dev/null +++ b/scripts/test-api.sh @@ -0,0 +1,638 @@ +#!/bin/bash +# ============================================================================= +# API Test Script +# Generated for: Multi-Plan Billing Foundation & Model-Aware Pricing +# ============================================================================= +# +# Usage: ./scripts/test-api.sh [BASE_URL] +# Default BASE_URL: http://localhost:8080 +# +# Requirements: +# - No external dependencies (no jq, only curl and bash) +# - Each request has -m 10 timeout to prevent hanging +# - HTTP status captured via: -o /tmp/response.txt -w "%{http_code}" +# +# Prerequisites: +# - Application running with database migrated (V1 + V2) +# - Seed data loaded from migrations +# +# ============================================================================= +# +# TEST CASE OVERVIEW (Human-Reviewable) +# ============================================================================= +# +# ┌───────────────────────────────────────────────────────────────────────────────────────┐ +# │ VALIDATION ERROR TESTS (AC1) │ +# ├─────────┬────────────────────────────┬──────────────┬─────────────────┬───────┬───────┤ +# │ Test ID │ Description │ Customer │ Model │ HTTP │ Error │ +# ├─────────┼────────────────────────────┼──────────────┼─────────────────┼───────┼───────┤ +# │ AC1.1 │ Missing modelId │ CUST-001 │ (missing) │ 400 │ Model ID is required │ +# │ AC1.2 │ Missing customerId │ (missing) │ fast-model │ 400 │ Customer ID is required │ +# │ AC1.3 │ Negative promptTokens │ CUST-001 │ fast-model │ 400 │ Token count cannot be negative │ +# │ AC1.4 │ Negative completionTokens │ CUST-001 │ fast-model │ 400 │ Token count cannot be negative │ +# │ AC1.5 │ Missing promptTokens │ CUST-001 │ fast-model │ 400 │ (validation error) │ +# │ AC1.6 │ Non-existent customer │ NON-EXISTENT │ fast-model │ 404 │ Customer not found │ +# │ AC1.7 │ Unknown model │ CUST-001 │ unknown-model │ 400 │ Pricing not configured │ +# └─────────┴────────────────────────────┴──────────────┴─────────────────┴───────┴───────┘ +# +# ┌───────────────────────────────────────────────────────────────────────────────────────┐ +# │ STANDARD PLAN TESTS (AC2) - Quota-based billing │ +# ├─────────┬────────────────────────┬──────────┬──────────┬────────┬──────┬───────┬──────┤ +# │ Test ID │ Description │ Customer │ Model │ Prompt │ Comp │ HTTP │Charge│ +# ├─────────┼────────────────────────┼──────────┼──────────┼────────┼──────┼───────┼──────┤ +# │ AC2.1 │ Within quota │ CUST-001 │ fast │ 1000 │ 500 │ 201 │ 0.00 │ +# │ AC2.2 │ Response has modelId │ CUST-003 │ fast │ 5000 │ 5000 │ 201 │ 0.00 │ +# │ AC2.3 │ Exceeds small quota │ CUST-002 │ fast │ 10000 │ 5000 │ 201 │ 0.15 │ +# │ AC2.4 │ Using reasoning-model │ CUST-001 │ reason │ 2000 │ 3000 │ 201 │ 0.00 │ +# └─────────┴────────────────────────┴──────────┴──────────┴────────┴──────┴───────┴──────┘ +# +# ┌───────────────────────────────────────────────────────────────────────────────────────┐ +# │ PREMIUM PLAN TESTS (AC3) - Split prompt/completion billing │ +# ├─────────┬────────────────────────┬──────────────┬─────────┬────────┬──────┬───────────┤ +# │ Test ID │ Description │ Customer │ Model │ Prompt │ Comp │ Total │ +# ├─────────┼────────────────────────┼──────────────┼─────────┼────────┼──────┼───────────┤ +# │ AC3.1 │ fast-model billing │ CUST-PREMIUM │ fast │ 10000 │ 5000 │ 0.20 │ +# │ AC3.2 │ reasoning-model │ CUST-PREMIUM │ reason │ 10000 │ 20000│ 1.50 │ +# │ AC3.3 │ Response has modelId │ CUST-PREMIUM │ reason │ 1000 │ 1000 │ 0.09 │ +# │ AC3.4 │ Zero tokens │ CUST-PREMIUM │ fast │ 0 │ 0 │ 0.00 │ +# └─────────┴────────────────────────┴──────────────┴─────────┴────────┴──────┴───────────┘ +# +# ┌───────────────────────────────────────────────────────────────────────────────────────┐ +# │ EDGE CASE TESTS │ +# ├─────────┬────────────────────────┬──────────┬──────────┬────────┬──────┬──────────────┤ +# │ Test ID │ Description │ Customer │ Model │ Prompt │ Comp │ HTTP │ +# ├─────────┼────────────────────────┼──────────┼──────────┼────────┼──────┼──────────────┤ +# │ EDGE1 │ Zero tokens │ CUST-001 │ fast │ 0 │ 0 │ 201 │ +# │ EDGE2 │ Only prompt tokens │ CUST-001 │ fast │ 1000 │ 0 │ 201 │ +# │ EDGE3 │ Only completion tokens │ CUST-001 │ fast │ 0 │ 1000 │ 201 │ +# │ EDGE4 │ Empty JSON body │ - │ - │ - │ - │ 400 │ +# │ EDGE5 │ Invalid JSON format │ - │ - │ - │ - │ 400 │ +# │ EDGE6 │ Large token count │ CUST-003 │ fast │ 500000 │500000│ 201 │ +# └─────────┴────────────────────────┴──────────┴──────────┴────────┴──────┴──────────────┘ +# +# ============================================================================= + +# ----------------------------------------------------------------------------- +# CONFIGURATION +# ----------------------------------------------------------------------------- +BASE_URL="${1:-http://localhost:8080}" + +# Colors for output (disabled if not a terminal) +if [ -t 1 ]; then + RED='\033[0;31m' + GREEN='\033[0;32m' + YELLOW='\033[1;33m' + BLUE='\033[0;34m' + CYAN='\033[0;36m' + NC='\033[0m' # No Color +else + RED='' + GREEN='' + YELLOW='' + BLUE='' + CYAN='' + NC='' +fi + +# ----------------------------------------------------------------------------- +# SEED DATA REFERENCE +# ----------------------------------------------------------------------------- +# Customers (from V1 + V2 migrations): +# - CUST-001: Acme Corp → PLAN-STARTER (100,000 quota, STANDARD) +# - CUST-002: TechStart Inc → PLAN-FREE (10,000 quota, STANDARD) +# - CUST-003: Enterprise Ltd → PLAN-ENTERPRISE (2,000,000 quota, STANDARD) +# - CUST-PREMIUM: Premium Test Corp → PLAN-PREMIUM (no quota, PREMIUM) +# +# Plans (from V1 + V2 migrations): +# - PLAN-FREE: 10,000 quota, $0.03/1K overage, STANDARD +# - PLAN-STARTER: 100,000 quota, $0.02/1K overage, STANDARD +# - PLAN-PRO: 500,000 quota, $0.015/1K overage, STANDARD +# - PLAN-ENTERPRISE: 2,000,000 quota, $0.01/1K overage, STANDARD +# - PLAN-PREMIUM: 0 quota (no quota), PREMIUM +# +# Model Pricing (from V2 migration): +# - All STANDARD plans: fast-model & reasoning-model inherit plan's overage rate +# - PLAN-PREMIUM: fast-model ($0.01 prompt, $0.02 completion) +# - PLAN-PREMIUM: reasoning-model ($0.03 prompt, $0.06 completion) +# +# Models: +# - fast-model +# - reasoning-model +# ----------------------------------------------------------------------------- + +# ----------------------------------------------------------------------------- +# TEST COUNTERS AND RESULT TRACKING +# ----------------------------------------------------------------------------- +TESTS_PASSED=0 +TESTS_FAILED=0 +TESTS_TOTAL=0 + +# Arrays to track results for final summary table +declare -a TEST_IDS +declare -a TEST_DESCRIPTIONS +declare -a EXPECTED_STATUS +declare -a ACTUAL_STATUS +declare -a TEST_RESULTS + +# ----------------------------------------------------------------------------- +# HELPER FUNCTIONS +# ----------------------------------------------------------------------------- +print_test_header() { + echo "" + echo -e "${BLUE}═══════════════════════════════════════════════════════════════${NC}" + echo -e "${BLUE}TEST: $1${NC}" + echo -e "${BLUE}═══════════════════════════════════════════════════════════════${NC}" +} + +print_expected() { + echo -e "${YELLOW}Expected: $1${NC}" +} + +print_result() { + echo -e "${GREEN}Response:${NC}" +} + +# Record test result for final summary table +record_result() { + TEST_IDS+=("$1") + TEST_DESCRIPTIONS+=("$2") + EXPECTED_STATUS+=("$3") + ACTUAL_STATUS+=("$4") + TEST_RESULTS+=("$5") +} + +check_result() { + local test_id="$1" + local test_desc="$2" + local expected_status="$3" + local actual_status="$4" + local body="$5" + + echo "$body" + echo "" + + if [ "$actual_status" = "$expected_status" ]; then + echo -e "${GREEN}✓ PASSED${NC} [HTTP Status: $actual_status]" + TESTS_PASSED=$((TESTS_PASSED + 1)) + record_result "$test_id" "$test_desc" "$expected_status" "$actual_status" "PASS" + else + echo -e "${RED}✗ FAILED${NC} [HTTP Status: $actual_status, Expected: $expected_status]" + TESTS_FAILED=$((TESTS_FAILED + 1)) + record_result "$test_id" "$test_desc" "$expected_status" "$actual_status" "FAIL" + fi + echo "" +} + +# Print final results table +print_results_table() { + echo "" + echo -e "${CYAN}┌─────────────────────────────────────────────────────────────────────────────┐${NC}" + echo -e "${CYAN}│ TEST RESULTS SUMMARY │${NC}" + echo -e "${CYAN}├──────────┬────────────────────────────────┬──────────┬──────────┬──────────┤${NC}" + echo -e "${CYAN}│ Test ID │ Description │ Expected │ Actual │ Result │${NC}" + echo -e "${CYAN}├──────────┼────────────────────────────────┼──────────┼──────────┼──────────┤${NC}" + + for i in "${!TEST_IDS[@]}"; do + local result_color="${GREEN}" + if [ "${TEST_RESULTS[$i]}" = "FAIL" ]; then + result_color="${RED}" + fi + printf "${CYAN}│${NC} %-8s ${CYAN}│${NC} %-30s ${CYAN}│${NC} %-8s ${CYAN}│${NC} %-8s ${CYAN}│${NC} ${result_color}%-8s${NC} ${CYAN}│${NC}\n" \ + "${TEST_IDS[$i]}" \ + "${TEST_DESCRIPTIONS[$i]:0:30}" \ + "${EXPECTED_STATUS[$i]}" \ + "${ACTUAL_STATUS[$i]}" \ + "${TEST_RESULTS[$i]}" + done + + echo -e "${CYAN}└──────────┴────────────────────────────────┴──────────┴──────────┴──────────┘${NC}" +} + +# ============================================================================= +# ACCEPTANCE CRITERIA TESTS +# ============================================================================= + +# ----------------------------------------------------------------------------- +# AC1: Base Validations (Regression & New) +# Given an invalid request (e.g., missing modelId, negative tokens) +# When backend validates request +# Then return HTTP 400 with appropriate error messages +# ----------------------------------------------------------------------------- + +TEST_ID="AC1.1" +TEST_DESC="Missing modelId returns 400" +EXPECTED="400" +TESTS_TOTAL=$((TESTS_TOTAL + 1)) +print_test_header "$TEST_ID: $TEST_DESC" +print_expected "HTTP $EXPECTED with error message 'Model ID is required'" +print_result +HTTP_CODE=$(curl -s -o /tmp/response.txt -w "%{http_code}" -X POST "${BASE_URL}/api/usage" \ + -H "Content-Type: application/json" \ + -m 10 \ + -d '{"customerId": "CUST-001", "promptTokens": 1000, "completionTokens": 500}') +BODY=$(cat /tmp/response.txt) +check_result "$TEST_ID" "$TEST_DESC" "$EXPECTED" "$HTTP_CODE" "$BODY" + +# ----------------------------------------------------------------------------- +TEST_ID="AC1.2" +TEST_DESC="Missing customerId returns 400" +EXPECTED="400" +TESTS_TOTAL=$((TESTS_TOTAL + 1)) +print_test_header "$TEST_ID: $TEST_DESC" +print_expected "HTTP $EXPECTED with error message 'Customer ID is required'" +print_result +HTTP_CODE=$(curl -s -o /tmp/response.txt -w "%{http_code}" -X POST "${BASE_URL}/api/usage" \ + -H "Content-Type: application/json" \ + -m 10 \ + -d '{"modelId": "fast-model", "promptTokens": 1000, "completionTokens": 500}') +BODY=$(cat /tmp/response.txt) +check_result "$TEST_ID" "$TEST_DESC" "$EXPECTED" "$HTTP_CODE" "$BODY" + +# ----------------------------------------------------------------------------- +TEST_ID="AC1.3" +TEST_DESC="Negative promptTokens returns 400" +EXPECTED="400" +TESTS_TOTAL=$((TESTS_TOTAL + 1)) +print_test_header "$TEST_ID: $TEST_DESC" +print_expected "HTTP $EXPECTED with error message 'Token count cannot be negative'" +print_result +HTTP_CODE=$(curl -s -o /tmp/response.txt -w "%{http_code}" -X POST "${BASE_URL}/api/usage" \ + -H "Content-Type: application/json" \ + -m 10 \ + -d '{"customerId": "CUST-001", "modelId": "fast-model", "promptTokens": -100, "completionTokens": 500}') +BODY=$(cat /tmp/response.txt) +check_result "$TEST_ID" "$TEST_DESC" "$EXPECTED" "$HTTP_CODE" "$BODY" + +# ----------------------------------------------------------------------------- +TEST_ID="AC1.4" +TEST_DESC="Negative completionTokens returns 400" +EXPECTED="400" +TESTS_TOTAL=$((TESTS_TOTAL + 1)) +print_test_header "$TEST_ID: $TEST_DESC" +print_expected "HTTP $EXPECTED with error message 'Token count cannot be negative'" +print_result +HTTP_CODE=$(curl -s -o /tmp/response.txt -w "%{http_code}" -X POST "${BASE_URL}/api/usage" \ + -H "Content-Type: application/json" \ + -m 10 \ + -d '{"customerId": "CUST-001", "modelId": "fast-model", "promptTokens": 1000, "completionTokens": -500}') +BODY=$(cat /tmp/response.txt) +check_result "$TEST_ID" "$TEST_DESC" "$EXPECTED" "$HTTP_CODE" "$BODY" + +# ----------------------------------------------------------------------------- +TEST_ID="AC1.5" +TEST_DESC="Missing promptTokens returns 400" +EXPECTED="400" +TESTS_TOTAL=$((TESTS_TOTAL + 1)) +print_test_header "$TEST_ID: $TEST_DESC" +print_expected "HTTP $EXPECTED" +print_result +HTTP_CODE=$(curl -s -o /tmp/response.txt -w "%{http_code}" -X POST "${BASE_URL}/api/usage" \ + -H "Content-Type: application/json" \ + -m 10 \ + -d '{"customerId": "CUST-001", "modelId": "fast-model", "completionTokens": 500}') +BODY=$(cat /tmp/response.txt) +check_result "$TEST_ID" "$TEST_DESC" "$EXPECTED" "$HTTP_CODE" "$BODY" + +# ----------------------------------------------------------------------------- +TEST_ID="AC1.6" +TEST_DESC="Non-existent customer returns 404" +EXPECTED="404" +TESTS_TOTAL=$((TESTS_TOTAL + 1)) +print_test_header "$TEST_ID: $TEST_DESC" +print_expected "HTTP $EXPECTED with error message 'Customer not found'" +print_result +HTTP_CODE=$(curl -s -o /tmp/response.txt -w "%{http_code}" -X POST "${BASE_URL}/api/usage" \ + -H "Content-Type: application/json" \ + -m 10 \ + -d '{"customerId": "NON-EXISTENT", "modelId": "fast-model", "promptTokens": 1000, "completionTokens": 500}') +BODY=$(cat /tmp/response.txt) +check_result "$TEST_ID" "$TEST_DESC" "$EXPECTED" "$HTTP_CODE" "$BODY" + +# ----------------------------------------------------------------------------- +TEST_ID="AC1.7" +TEST_DESC="Unknown model returns 400" +EXPECTED="400" +TESTS_TOTAL=$((TESTS_TOTAL + 1)) +print_test_header "$TEST_ID: $TEST_DESC" +print_expected "HTTP $EXPECTED with error message about pricing not configured" +print_result +HTTP_CODE=$(curl -s -o /tmp/response.txt -w "%{http_code}" -X POST "${BASE_URL}/api/usage" \ + -H "Content-Type: application/json" \ + -m 10 \ + -d '{"customerId": "CUST-001", "modelId": "unknown-model", "promptTokens": 1000, "completionTokens": 500}') +BODY=$(cat /tmp/response.txt) +check_result "$TEST_ID" "$TEST_DESC" "$EXPECTED" "$HTTP_CODE" "$BODY" + +# ----------------------------------------------------------------------------- +# AC2: Standard Plan with Model-Aware Overage +# Given a "Standard" customer with a 100,000 monthly quota +# When submitting tokens for "fast-model" +# Then bill shows correct quota usage and overage calculation +# ----------------------------------------------------------------------------- + +TEST_ID="AC2.1" +TEST_DESC="Standard - Within quota (no overage)" +EXPECTED="201" +TESTS_TOTAL=$((TESTS_TOTAL + 1)) +print_test_header "$TEST_ID: $TEST_DESC" +print_expected "HTTP $EXPECTED with includedTokensUsed=1500, overageTokens=0, totalCharge=0" +echo -e "${YELLOW}Using: CUST-001 (PLAN-STARTER: 100,000 quota)${NC}" +print_result +HTTP_CODE=$(curl -s -o /tmp/response.txt -w "%{http_code}" -X POST "${BASE_URL}/api/usage" \ + -H "Content-Type: application/json" \ + -m 10 \ + -d '{"customerId": "CUST-001", "modelId": "fast-model", "promptTokens": 1000, "completionTokens": 500}') +BODY=$(cat /tmp/response.txt) +check_result "$TEST_ID" "$TEST_DESC" "$EXPECTED" "$HTTP_CODE" "$BODY" + +# ----------------------------------------------------------------------------- +TEST_ID="AC2.2" +TEST_DESC="Standard - Response has modelId" +EXPECTED="201" +TESTS_TOTAL=$((TESTS_TOTAL + 1)) +print_test_header "$TEST_ID: $TEST_DESC" +print_expected "HTTP $EXPECTED with modelId='fast-model' in response" +echo -e "${YELLOW}Using: CUST-003 (PLAN-ENTERPRISE: 2,000,000 quota)${NC}" +print_result +HTTP_CODE=$(curl -s -o /tmp/response.txt -w "%{http_code}" -X POST "${BASE_URL}/api/usage" \ + -H "Content-Type: application/json" \ + -m 10 \ + -d '{"customerId": "CUST-003", "modelId": "fast-model", "promptTokens": 5000, "completionTokens": 5000}') +BODY=$(cat /tmp/response.txt) +check_result "$TEST_ID" "$TEST_DESC" "$EXPECTED" "$HTTP_CODE" "$BODY" + +# ----------------------------------------------------------------------------- +TEST_ID="AC2.3" +TEST_DESC="Standard - Exceeds small quota" +EXPECTED="201" +TESTS_TOTAL=$((TESTS_TOTAL + 1)) +print_test_header "$TEST_ID: $TEST_DESC" +print_expected "HTTP $EXPECTED with overage charged (PLAN-FREE has only 10,000 quota)" +echo -e "${YELLOW}Using: CUST-002 (PLAN-FREE: 10,000 quota, \$0.03/1K overage)${NC}" +echo -e "${YELLOW}Submitting 15,000 tokens should result in 5,000 overage = \$0.15${NC}" +print_result +HTTP_CODE=$(curl -s -o /tmp/response.txt -w "%{http_code}" -X POST "${BASE_URL}/api/usage" \ + -H "Content-Type: application/json" \ + -m 10 \ + -d '{"customerId": "CUST-002", "modelId": "fast-model", "promptTokens": 10000, "completionTokens": 5000}') +BODY=$(cat /tmp/response.txt) +check_result "$TEST_ID" "$TEST_DESC" "$EXPECTED" "$HTTP_CODE" "$BODY" + +# ----------------------------------------------------------------------------- +TEST_ID="AC2.4" +TEST_DESC="Standard - Using reasoning-model" +EXPECTED="201" +TESTS_TOTAL=$((TESTS_TOTAL + 1)) +print_test_header "$TEST_ID: $TEST_DESC" +print_expected "HTTP $EXPECTED with modelId='reasoning-model'" +echo -e "${YELLOW}Using: CUST-001 with reasoning-model${NC}" +print_result +HTTP_CODE=$(curl -s -o /tmp/response.txt -w "%{http_code}" -X POST "${BASE_URL}/api/usage" \ + -H "Content-Type: application/json" \ + -m 10 \ + -d '{"customerId": "CUST-001", "modelId": "reasoning-model", "promptTokens": 2000, "completionTokens": 3000}') +BODY=$(cat /tmp/response.txt) +check_result "$TEST_ID" "$TEST_DESC" "$EXPECTED" "$HTTP_CODE" "$BODY" + +# ============================================================================= +# EDGE CASE TESTS +# ============================================================================= + +# ----------------------------------------------------------------------------- +TEST_ID="EDGE1" +TEST_DESC="Zero tokens submission" +EXPECTED="201" +TESTS_TOTAL=$((TESTS_TOTAL + 1)) +print_test_header "$TEST_ID: $TEST_DESC" +print_expected "HTTP $EXPECTED with totalTokens=0, totalCharge=0" +print_result +HTTP_CODE=$(curl -s -o /tmp/response.txt -w "%{http_code}" -X POST "${BASE_URL}/api/usage" \ + -H "Content-Type: application/json" \ + -m 10 \ + -d '{"customerId": "CUST-001", "modelId": "fast-model", "promptTokens": 0, "completionTokens": 0}') +BODY=$(cat /tmp/response.txt) +check_result "$TEST_ID" "$TEST_DESC" "$EXPECTED" "$HTTP_CODE" "$BODY" + +# ----------------------------------------------------------------------------- +TEST_ID="EDGE2" +TEST_DESC="Only prompt tokens (zero completion)" +EXPECTED="201" +TESTS_TOTAL=$((TESTS_TOTAL + 1)) +print_test_header "$TEST_ID: $TEST_DESC" +print_expected "HTTP $EXPECTED with completionTokens=0" +print_result +HTTP_CODE=$(curl -s -o /tmp/response.txt -w "%{http_code}" -X POST "${BASE_URL}/api/usage" \ + -H "Content-Type: application/json" \ + -m 10 \ + -d '{"customerId": "CUST-001", "modelId": "fast-model", "promptTokens": 1000, "completionTokens": 0}') +BODY=$(cat /tmp/response.txt) +check_result "$TEST_ID" "$TEST_DESC" "$EXPECTED" "$HTTP_CODE" "$BODY" + +# ----------------------------------------------------------------------------- +TEST_ID="EDGE3" +TEST_DESC="Only completion tokens (zero prompt)" +EXPECTED="201" +TESTS_TOTAL=$((TESTS_TOTAL + 1)) +print_test_header "$TEST_ID: $TEST_DESC" +print_expected "HTTP $EXPECTED with promptTokens=0" +print_result +HTTP_CODE=$(curl -s -o /tmp/response.txt -w "%{http_code}" -X POST "${BASE_URL}/api/usage" \ + -H "Content-Type: application/json" \ + -m 10 \ + -d '{"customerId": "CUST-001", "modelId": "fast-model", "promptTokens": 0, "completionTokens": 1000}') +BODY=$(cat /tmp/response.txt) +check_result "$TEST_ID" "$TEST_DESC" "$EXPECTED" "$HTTP_CODE" "$BODY" + +# ----------------------------------------------------------------------------- +TEST_ID="EDGE4" +TEST_DESC="Empty JSON body" +EXPECTED="400" +TESTS_TOTAL=$((TESTS_TOTAL + 1)) +print_test_header "$TEST_ID: $TEST_DESC" +print_expected "HTTP $EXPECTED" +print_result +HTTP_CODE=$(curl -s -o /tmp/response.txt -w "%{http_code}" -X POST "${BASE_URL}/api/usage" \ + -H "Content-Type: application/json" \ + -m 10 \ + -d '{}') +BODY=$(cat /tmp/response.txt) +check_result "$TEST_ID" "$TEST_DESC" "$EXPECTED" "$HTTP_CODE" "$BODY" + +# ----------------------------------------------------------------------------- +TEST_ID="EDGE5" +TEST_DESC="Invalid JSON format" +EXPECTED="400" +TESTS_TOTAL=$((TESTS_TOTAL + 1)) +print_test_header "$TEST_ID: $TEST_DESC" +print_expected "HTTP $EXPECTED" +print_result +HTTP_CODE=$(curl -s -o /tmp/response.txt -w "%{http_code}" -X POST "${BASE_URL}/api/usage" \ + -H "Content-Type: application/json" \ + -m 10 \ + -d '{ invalid json }') +BODY=$(cat /tmp/response.txt) +check_result "$TEST_ID" "$TEST_DESC" "$EXPECTED" "$HTTP_CODE" "$BODY" + +# ----------------------------------------------------------------------------- +TEST_ID="EDGE6" +TEST_DESC="Large token count" +EXPECTED="201" +TESTS_TOTAL=$((TESTS_TOTAL + 1)) +print_test_header "$TEST_ID: $TEST_DESC" +print_expected "HTTP $EXPECTED - handles large numbers correctly" +echo -e "${YELLOW}Using: CUST-003 (PLAN-ENTERPRISE: 2,000,000 quota)${NC}" +print_result +HTTP_CODE=$(curl -s -o /tmp/response.txt -w "%{http_code}" -X POST "${BASE_URL}/api/usage" \ + -H "Content-Type: application/json" \ + -m 10 \ + -d '{"customerId": "CUST-003", "modelId": "fast-model", "promptTokens": 500000, "completionTokens": 500000}') +BODY=$(cat /tmp/response.txt) +check_result "$TEST_ID" "$TEST_DESC" "$EXPECTED" "$HTTP_CODE" "$BODY" + +# ----------------------------------------------------------------------------- +# AC3: Premium Plan with Split Prompt/Completion Rates +# Given a "Premium" customer (no monthly quota) +# When submitting usage with prompt and completion tokens +# Then bill shows separate promptCharge and completionCharge (no quota deduction) +# ----------------------------------------------------------------------------- + +TEST_ID="AC3.1" +TEST_DESC="Premium - fast-model billing" +EXPECTED="201" +TESTS_TOTAL=$((TESTS_TOTAL + 1)) +print_test_header "$TEST_ID: $TEST_DESC" +print_expected "HTTP $EXPECTED with includedTokensUsed=0, overageTokens=0, split charges" +echo -e "${YELLOW}Using: CUST-PREMIUM (PLAN-PREMIUM: no quota)${NC}" +echo -e "${YELLOW}fast-model rates: \$0.01/1K prompt, \$0.02/1K completion${NC}" +echo -e "${YELLOW}10,000 prompt + 5,000 completion = \$0.10 + \$0.10 = \$0.20 total${NC}" +print_result +HTTP_CODE=$(curl -s -o /tmp/response.txt -w "%{http_code}" -X POST "${BASE_URL}/api/usage" \ + -H "Content-Type: application/json" \ + -m 10 \ + -d '{"customerId": "CUST-PREMIUM", "modelId": "fast-model", "promptTokens": 10000, "completionTokens": 5000}') +BODY=$(cat /tmp/response.txt) +check_result "$TEST_ID" "$TEST_DESC" "$EXPECTED" "$HTTP_CODE" "$BODY" + +# ----------------------------------------------------------------------------- +TEST_ID="AC3.2" +TEST_DESC="Premium - reasoning-model billing" +EXPECTED="201" +TESTS_TOTAL=$((TESTS_TOTAL + 1)) +print_test_header "$TEST_ID: $TEST_DESC" +print_expected "HTTP $EXPECTED with promptCharge and completionCharge in response" +echo -e "${YELLOW}Using: CUST-PREMIUM (PLAN-PREMIUM)${NC}" +echo -e "${YELLOW}reasoning-model rates: \$0.03/1K prompt, \$0.06/1K completion${NC}" +echo -e "${YELLOW}10,000 prompt + 20,000 completion = \$0.30 + \$1.20 = \$1.50 total${NC}" +print_result +HTTP_CODE=$(curl -s -o /tmp/response.txt -w "%{http_code}" -X POST "${BASE_URL}/api/usage" \ + -H "Content-Type: application/json" \ + -m 10 \ + -d '{"customerId": "CUST-PREMIUM", "modelId": "reasoning-model", "promptTokens": 10000, "completionTokens": 20000}') +BODY=$(cat /tmp/response.txt) +check_result "$TEST_ID" "$TEST_DESC" "$EXPECTED" "$HTTP_CODE" "$BODY" + +# ----------------------------------------------------------------------------- +TEST_ID="AC3.3" +TEST_DESC="Premium - Response has modelId" +EXPECTED="201" +TESTS_TOTAL=$((TESTS_TOTAL + 1)) +print_test_header "$TEST_ID: $TEST_DESC" +print_expected "HTTP $EXPECTED with modelId='reasoning-model' in response" +echo -e "${YELLOW}Verify modelId field is present in Premium plan response${NC}" +print_result +HTTP_CODE=$(curl -s -o /tmp/response.txt -w "%{http_code}" -X POST "${BASE_URL}/api/usage" \ + -H "Content-Type: application/json" \ + -m 10 \ + -d '{"customerId": "CUST-PREMIUM", "modelId": "reasoning-model", "promptTokens": 1000, "completionTokens": 1000}') +BODY=$(cat /tmp/response.txt) +check_result "$TEST_ID" "$TEST_DESC" "$EXPECTED" "$HTTP_CODE" "$BODY" + +# ----------------------------------------------------------------------------- +TEST_ID="AC3.4" +TEST_DESC="Premium - Zero tokens submission" +EXPECTED="201" +TESTS_TOTAL=$((TESTS_TOTAL + 1)) +print_test_header "$TEST_ID: $TEST_DESC" +print_expected "HTTP $EXPECTED with totalCharge=0, promptCharge=0, completionCharge=0" +print_result +HTTP_CODE=$(curl -s -o /tmp/response.txt -w "%{http_code}" -X POST "${BASE_URL}/api/usage" \ + -H "Content-Type: application/json" \ + -m 10 \ + -d '{"customerId": "CUST-PREMIUM", "modelId": "fast-model", "promptTokens": 0, "completionTokens": 0}') +BODY=$(cat /tmp/response.txt) +check_result "$TEST_ID" "$TEST_DESC" "$EXPECTED" "$HTTP_CODE" "$BODY" + +# ============================================================================= +# RESPONSE STRUCTURE VALIDATION +# ============================================================================= + +# ----------------------------------------------------------------------------- +TEST_ID="STRUCT1" +TEST_DESC="Standard - All required fields" +EXPECTED="201" +TESTS_TOTAL=$((TESTS_TOTAL + 1)) +print_test_header "$TEST_ID: $TEST_DESC" +print_expected "HTTP $EXPECTED with billId, customerId, modelId, totalTokens, includedTokensUsed, overageTokens, totalCharge, calculatedAt" +print_result +HTTP_CODE=$(curl -s -o /tmp/response.txt -w "%{http_code}" -X POST "${BASE_URL}/api/usage" \ + -H "Content-Type: application/json" \ + -m 10 \ + -d '{"customerId": "CUST-001", "modelId": "fast-model", "promptTokens": 100, "completionTokens": 100}') +BODY=$(cat /tmp/response.txt) +check_result "$TEST_ID" "$TEST_DESC" "$EXPECTED" "$HTTP_CODE" "$BODY" + +# ----------------------------------------------------------------------------- +TEST_ID="STRUCT2" +TEST_DESC="Premium - Contains charge breakdown" +EXPECTED="201" +TESTS_TOTAL=$((TESTS_TOTAL + 1)) +print_test_header "$TEST_ID: $TEST_DESC" +print_expected "HTTP $EXPECTED with promptCharge and completionCharge fields" +print_result +HTTP_CODE=$(curl -s -o /tmp/response.txt -w "%{http_code}" -X POST "${BASE_URL}/api/usage" \ + -H "Content-Type: application/json" \ + -m 10 \ + -d '{"customerId": "CUST-PREMIUM", "modelId": "fast-model", "promptTokens": 100, "completionTokens": 100}') +BODY=$(cat /tmp/response.txt) +check_result "$TEST_ID" "$TEST_DESC" "$EXPECTED" "$HTTP_CODE" "$BODY" + +# ----------------------------------------------------------------------------- +# CLEANUP +# ----------------------------------------------------------------------------- +rm -f /tmp/response.txt + +# ----------------------------------------------------------------------------- +# TEST SUMMARY +# ----------------------------------------------------------------------------- +echo "" +echo -e "${BLUE}═══════════════════════════════════════════════════════════════${NC}" +echo -e "${BLUE}TEST EXECUTION COMPLETE${NC}" +echo -e "${BLUE}═══════════════════════════════════════════════════════════════${NC}" +echo "" +echo "Base URL: ${BASE_URL}" +echo "Finished at: $(date)" + +# Print structured results table +print_results_table + +echo "" +echo -e "Tests Passed: ${GREEN}${TESTS_PASSED}${NC}" +echo -e "Tests Failed: ${RED}${TESTS_FAILED}${NC}" +echo -e "Total Tests: ${TESTS_TOTAL}" +echo "" + +# Calculate pass rate +if [ "$TESTS_TOTAL" -gt 0 ]; then + PASS_RATE=$((TESTS_PASSED * 100 / TESTS_TOTAL)) + if [ "$TESTS_FAILED" -eq 0 ]; then + echo -e "${GREEN}✓ All tests passed! (${PASS_RATE}%)${NC}" + else + echo -e "${RED}✗ Some tests failed (${PASS_RATE}% passed)${NC}" + fi +fi +echo "" + +# Exit with error code if any tests failed +if [ "$TESTS_FAILED" -gt 0 ]; then + exit 1 +fi diff --git a/spdd/analysis/GGQPA-001-202603191100-[Analysis]-multi-plan-billing-model-aware-pricing.md b/spdd/analysis/GGQPA-001-202603191100-[Analysis]-multi-plan-billing-model-aware-pricing.md new file mode 100644 index 0000000..3408d29 --- /dev/null +++ b/spdd/analysis/GGQPA-001-202603191100-[Analysis]-multi-plan-billing-model-aware-pricing.md @@ -0,0 +1,170 @@ +# SPDD Analysis: Multi-Plan Billing Foundation & Model-Aware Pricing + +## Original Business Requirement + +## Background + +As our LLM API platform scales, a single pricing model is no longer sufficient. We need to refactor our existing billing engine to support different subscription strategies and variable pricing based on the AI model invoked, laying the foundation for future complex billing plans. + +## Business Value + +1. **Flexible Monetization**: Support diverse billing strategies (Standard, Premium) to capture different market segments. +2. **Model-Aware Pricing**: Charge different rates based on the specific AI model used. +3. **Architecture Scalability**: Implement an extensible design (e.g., Strategy Pattern) to isolate calculation logic and easily add future pricing models. + +## Scope In + +- Enhance the existing POST /api/usage endpoint. +- **New Request Field:** Add modelId (required, string, e.g., "fast-model", "reasoning-model"). +- Implement a routing mechanism (Strategy/Factory Pattern) to handle distinct calculation formulas. +- Implement two initial Plan Types: + 1. **Standard Plan (Legacy Refactor):** Has a monthly global quota. Overage rates now depend on the modelId. + 2. **Premium Plan (New):** No quota. Prompt and Completion tokens are billed separately, and rates vary by modelId. + +## Scope Out + +- Complex tiered/volume-based discount logic (Deferred to Phase 2). +- Subscription plan creation and assignment CRUD. +- Invoice generation. + +## Acceptance Criteria (ACs) + +1. **Base Validations (Regression & New)** + **Given** an invalid request (e.g., missing `modelId`, negative tokens) + **When** backend validates request + **Then** return HTTP 400 with appropriate error messages. +2. **Standard Plan with Model-Aware Overage** + **Given** a "Standard" customer with a 100,000 monthly quota, 90,000 used so far. Overage for "fast-model" is `$0.01/1K`. + **When** submitting 30,000 tokens for "fast-model" + **Then** bill shows: 10,000 from quota, 20,000 overage, $0.20 charge. +3. **Premium Plan with Split Rates** + **Given** a "Premium" customer. For "reasoning-model", Prompt is `$0.03/1K`, Completion is `$0.06/1K`. + **When** submitting 10,000 prompt and 20,000 completion tokens for "reasoning-model" + **Then** bill shows: 0 from quota, `$0.30` prompt charge, `$1.20` completion charge, total `$1.50`. + +--- + +## Domain Concept Identification + +### Existing Concepts (from codebase) + +- **Customer**: Identity holder for billing; `customers` table with `id` (VARCHAR) as PK — unchanged, remains the billing anchor entity +- **PricingPlan**: Currently defines `monthly_quota` and single `overage_rate_per_1k` in `pricing_plans` table — **needs refactoring** to support plan types and remove model-agnostic rate +- **CustomerSubscription**: Links customer to pricing plan with temporal validity (`effective_from`, `effective_to`) — unchanged, continues to determine which plan governs billing +- **Bill**: Records usage and calculated charges in `bills` table — **needs extension** to store `model_id` and potentially split charge breakdown +- **UsageRequest**: Input DTO for usage submission — **needs extension** to include required `modelId` field +- **BillResponse**: Output DTO for calculated bill — **needs extension** to include model-specific charge breakdown + +### New Concepts Required + +- **PlanType**: Discriminator that identifies the billing strategy (e.g., "STANDARD", "PREMIUM") — attached to PricingPlan, determines which calculation strategy to apply +- **ModelPricing**: Rate configuration per AI model per plan — stores model-specific rates (overage rate for Standard, prompt/completion rates for Premium) +- **BillingStrategy**: Calculation logic encapsulation (Strategy Pattern) — polymorphic billing calculators selected based on PlanType + - **StandardBillingStrategy**: Quota-first consumption with model-aware overage rates + - **PremiumBillingStrategy**: No quota, separate prompt/completion charges with model-aware rates +- **BillingStrategyFactory**: Factory to resolve the appropriate BillingStrategy based on PlanType + +### Key Business Rules + +- **Model ID Required**: Every usage submission must specify which AI model was invoked — governs UsageRequest validation +- **Plan Type Determines Strategy**: The PlanType field on PricingPlan dictates which calculation formula to use — governs strategy selection +- **Standard Plan Rules**: + - Global monthly quota still applies (unchanged from legacy) + - `remaining_quota = monthly_quota - sum(included_tokens_used for current month)` + - `included_tokens_used = min(total_tokens, max(remaining_quota, 0))` + - `overage_tokens = total_tokens - included_tokens_used` + - **Model-aware overage**: `overage_charge = (overage_tokens / 1000) × model_specific_overage_rate` +- **Premium Plan Rules**: + - No quota concept; `monthly_quota = NULL` or 0, `included_tokens_used = 0` + - Split billing: prompt and completion tokens charged separately + - `prompt_charge = (prompt_tokens / 1000) × model_prompt_rate` + - `completion_charge = (completion_tokens / 1000) × model_completion_rate` + - `total_charge = prompt_charge + completion_charge` +- **Model Pricing Resolution**: Given a (plan, modelId) pair, resolve the applicable rates — fail if no pricing configured for the requested model +- **Backward Compatibility**: Existing plans must continue working; Standard plan refactors legacy behavior without breaking existing customers + +--- + +## Strategic Approach + +### Solution Direction + +Refactor the billing engine to support polymorphic calculation strategies, following the existing project conventions: + +1. **Schema Evolution**: Add `plan_type` to `pricing_plans`, create `model_pricing` table for model-specific rates, add `model_id` to `bills` and potentially charge breakdown fields +2. **Strategy Pattern Implementation**: Introduce `BillingStrategy` interface with `StandardBillingStrategy` and `PremiumBillingStrategy` implementations +3. **Factory Pattern**: Use a `BillingStrategyFactory` (or Spring bean lookup by PlanType) to resolve the appropriate strategy at runtime +4. **Service Refactor**: `BillingServiceImpl.calculateBill()` delegates to the resolved strategy instead of hardcoded logic +5. **Domain Model Update**: `Bill.create()` becomes strategy-specific; each strategy produces its own Bill variant or uses a unified Bill with optional fields +6. **DTO Extension**: Add `modelId` to `UsageRequest`; extend `BillResponse` with model info and charge breakdown + +### Key Design Decisions + +| Decision | Trade-offs | Recommendation | +|----------|------------|----------------| +| **Where to store PlanType** | In PricingPlan table (denormalized, simple lookup) vs. separate PlanType reference table (normalized, more flexible) | Add `plan_type VARCHAR` column to `pricing_plans` — simpler, sufficient for STANDARD/PREMIUM enum; normalize later if more metadata needed | +| **Model pricing storage** | Embed rates in PricingPlan (simple but rigid) vs. separate `model_pricing` table (normalized, model per plan flexibility) | Create `model_pricing` table with FK to `pricing_plans` and model_id — supports multiple models per plan with different rates | +| **Strategy resolution mechanism** | Enum-based switch (simple, compile-time safe) vs. Spring bean lookup by name (flexible, injectable) | Use Spring bean lookup with `@Qualifier` or bean name by PlanType — aligns with Spring idioms, allows strategies to have injected dependencies | +| **Premium Plan quota handling** | Set `monthly_quota = 0` (reuses existing column) vs. `monthly_quota = NULL` (semantically distinct) | Use `monthly_quota = 0` for Premium — simpler, avoids null handling complexity; Standard always has `monthly_quota > 0` | +| **Bill charge breakdown storage** | Single `total_charge` only (current) vs. add `prompt_charge`, `completion_charge` columns | Add `prompt_charge` and `completion_charge` nullable columns to `bills` — enables Premium billing breakdown and audit trail | +| **Model validation** | Fail if model not in pricing table vs. use default rates | Fail with HTTP 400 "Unknown model" — explicit failure prevents billing errors; models must be explicitly configured | + +### Alternatives Considered + +- **Inheritance-based Plan entities (StandardPlan, PremiumPlan)**: Rejected — JPA single-table inheritance adds complexity; discriminator column (`plan_type`) with strategy pattern achieves polymorphism without entity hierarchy +- **Configuration-driven formula evaluation (e.g., expression language)**: Rejected — over-engineered for two well-defined calculation types; Strategy pattern is cleaner and more maintainable +- **Separate endpoints per plan type**: Rejected — violates RESTful resource design; single `/api/usage` endpoint with internal routing is cleaner +- **Storing rates as JSON blob**: Rejected — loses query capability and type safety; normalized `model_pricing` table is more robust + +--- + +## Risk & Gap Analysis + +### Requirement Ambiguities + +| Ambiguity | What needs clarification | +|-----------|-------------------------| +| **Default model pricing** | What happens if a customer submits usage for a model without configured pricing? Recommendation: Return HTTP 400 "Pricing not configured for model: {modelId}". | +| **Standard Plan model not configured** | Should Standard plan have a fallback overage rate if specific model pricing is missing? Recommendation: No fallback; require explicit model pricing configuration. | +| **Premium Plan `included_tokens_used` semantics** | Should the response show `included_tokens_used = 0` explicitly, or omit the field? Recommendation: Always include with value 0 for Premium — consistent response schema. | +| **Charge breakdown in response** | AC3 mentions "$0.30 prompt charge, $1.20 completion charge" — should response have separate fields or just a breakdown description? Recommendation: Add `promptCharge`, `completionCharge` fields to BillResponse. | +| **modelId validation rules** | What format/length constraints apply to modelId? Recommendation: Non-empty string, max 50 characters, following slug format (lowercase, hyphens). | +| **Bill table backward compatibility** | Existing bills have no `model_id`. Should we add a NOT NULL constraint? Recommendation: Add `model_id` as nullable initially; existing bills remain valid with NULL. | + +### Edge Cases + +| Scenario | Why it matters | +|----------|----------------| +| **Model not configured for plan** | Customer's plan doesn't have pricing for the submitted modelId — verify clear error message | +| **Standard Plan with zero remaining quota** | All tokens become overage; verify full overage calculation with model-specific rate | +| **Premium Plan with zero tokens** | 0 prompt and 0 completion tokens — should produce $0.00 total, valid bill | +| **Premium Plan one token type only** | e.g., 0 prompt, 1000 completion — verify correct single-charge calculation | +| **Very high rates** | Verify no overflow; BigDecimal handles large charges correctly | +| **Multiple models in same plan** | Verify model_pricing lookup correctly isolates rates per model | +| **Legacy data migration** | Existing bills without model_id; existing plans without plan_type — verify migration strategy | + +### Technical Risks + +| Risk | Potential Impact | Mitigation Direction | +|------|------------------|---------------------| +| **Schema migration complexity** | Adding columns and new table requires careful migration; existing seed data needs plan_type | Create V2 migration that: (1) adds plan_type with default 'STANDARD', (2) creates model_pricing table, (3) adds model_id to bills as nullable, (4) migrates existing pricing to model_pricing | +| **Strategy selection performance** | Bean lookup per request could add overhead | Strategies are singleton beans; lookup is O(1) HashMap; negligible impact | +| **Breaking existing tests** | Existing tests assume single pricing model; many will fail | Update tests incrementally; ensure regression tests still pass with Standard plan behavior | +| **Model pricing cache staleness** | If pricing loaded once at startup, changes require restart | For MVP, accept restart requirement; add cache invalidation later if needed | +| **Concurrency on quota calculation** | Same risk as before; now with additional model dimension | Unchanged mitigation; quota is still global per customer, not per model | + +### Acceptance Criteria Coverage + +| AC# | Description | Addressable? | Gaps/Notes | +|-----|-------------|--------------|------------| +| AC1 | Validate invalid requests (missing modelId, negative tokens) → 400 | Yes | Add `@NotNull` for modelId; existing validation handles negative tokens | +| AC2 | Standard Plan: 100K quota, 90K used, 30K for "fast-model" @ $0.01/1K → 10K from quota, 20K overage, $0.20 | Yes | StandardBillingStrategy with model-specific overage rate | +| AC3 | Premium Plan: "reasoning-model" prompt $0.03/1K, completion $0.06/1K → 10K×$0.03 + 20K×$0.06 = $0.30 + $1.20 = $1.50 | Yes | PremiumBillingStrategy with split rates | + +**AC Coverage Summary**: All 3 ACs are addressable with the proposed Strategy Pattern approach. + +**Implicit Requirements Not in ACs**: +- Response structure needs `modelId` and potentially charge breakdown fields +- Error response for unknown model scenario +- Migration path for existing data (plan_type, model_pricing seed data) +- Existing functionality regression (Standard plan customers without explicit model pricing config) diff --git a/spdd/prompt/GGQPA-001-202603191105-[Feat]-multi-plan-billing-model-aware-pricing.md b/spdd/prompt/GGQPA-001-202603191105-[Feat]-multi-plan-billing-model-aware-pricing.md new file mode 100644 index 0000000..b148786 --- /dev/null +++ b/spdd/prompt/GGQPA-001-202603191105-[Feat]-multi-plan-billing-model-aware-pricing.md @@ -0,0 +1,703 @@ +# Multi-Plan Billing Foundation & Model-Aware Pricing Implementation + +## Requirements + +Refactor the existing billing engine to support multiple subscription strategies (Standard, Premium) and variable pricing based on the AI model invoked. The system must route billing calculations to appropriate strategies based on plan type, apply model-specific rates, and return itemized billing breakdowns. This lays the foundation for future complex billing plans by implementing an extensible Strategy Pattern architecture. + +**Key Capabilities:** +- Add required `modelId` field to usage submission +- Implement Standard Plan: quota-first consumption with model-aware overage rates +- Implement Premium Plan: no quota, separate prompt/completion billing with model-specific rates +- Return detailed charge breakdowns including prompt charge, completion charge, and model information + +## Entities + +```mermaid +classDiagram + direction TB + + class Customer { + +String id + +String name + +LocalDateTime createdAt + } + + class PricingPlan { + +String id + +String name + +PlanType planType + +Integer monthlyQuota + +LocalDateTime createdAt + } + + class PlanType { + <> + STANDARD + PREMIUM + } + + class ModelPricing { + +UUID id + +String planId + +String modelId + +BigDecimal overageRatePer1k + +BigDecimal promptRatePer1k + +BigDecimal completionRatePer1k + +LocalDateTime createdAt + } + + class CustomerSubscription { + +UUID id + +String customerId + +String planId + +LocalDate effectiveFrom + +LocalDate effectiveTo + +LocalDateTime createdAt + } + + class Bill { + +UUID id + +String customerId + +String modelId + +Integer promptTokens + +Integer completionTokens + +Integer totalTokens + +Integer includedTokensUsed + +Integer overageTokens + +BigDecimal promptCharge + +BigDecimal completionCharge + +BigDecimal totalCharge + +LocalDateTime calculatedAt + } + + class UsageRequest { + +String customerId + +String modelId + +Integer promptTokens + +Integer completionTokens + } + + class BillResponse { + +UUID billId + +String customerId + +String modelId + +Integer totalTokens + +Integer includedTokensUsed + +Integer overageTokens + +BigDecimal promptCharge + +BigDecimal completionCharge + +BigDecimal totalCharge + +LocalDateTime calculatedAt + } + + class BillingStrategy { + <> + +Bill calculate(BillingContext context) + +PlanType supportedPlanType() + } + + class StandardBillingStrategy { + +Bill calculate(BillingContext context) + +PlanType supportedPlanType() + } + + class PremiumBillingStrategy { + +Bill calculate(BillingContext context) + +PlanType supportedPlanType() + } + + class BillingContext { + +String customerId + +String modelId + +Integer promptTokens + +Integer completionTokens + +Integer remainingQuota + +ModelPricing modelPricing + } + + Customer "1" -- "0..*" CustomerSubscription : has + PricingPlan "1" -- "0..*" CustomerSubscription : defines + PricingPlan "1" -- "1" PlanType : has + PricingPlan "1" -- "0..*" ModelPricing : configures + Customer "1" -- "0..*" Bill : generates + UsageRequest --> Bill : creates via strategy + Bill --> BillResponse : maps to + BillingStrategy <|.. StandardBillingStrategy : implements + BillingStrategy <|.. PremiumBillingStrategy : implements + BillingContext --> BillingStrategy : input to +``` + +## Approach + +1. **Schema Evolution**: + - Add `plan_type VARCHAR(20)` column to `pricing_plans` table with default 'STANDARD' + - Create `model_pricing` table for model-specific rates per plan + - Add `model_id VARCHAR(50) NOT NULL DEFAULT 'fast-model'`, `prompt_charge DECIMAL(10,2)`, `completion_charge DECIMAL(10,2)` columns to `bills` table + - Migrate existing pricing plans to have explicit model pricing entries + +2. **Strategy Pattern Implementation**: + - Define `BillingStrategy` interface with `calculate(BillingContext)` and `supportedPlanType()` methods + - Implement `StandardBillingStrategy`: quota-first consumption, model-aware overage calculation + - Implement `PremiumBillingStrategy`: no quota, split prompt/completion billing + - Use Spring `@Component` with `@Qualifier` for strategy registration + +3. **Strategy Resolution**: + - Create `BillingStrategyFactory` that maps `PlanType` to `BillingStrategy` bean + - Inject all strategies via constructor, build lookup map on initialization + - Resolve strategy at runtime based on customer's plan type + +4. **Model Pricing Resolution**: + - Create `ModelPricingRepository` interface and JPA implementation + - Query by (planId, modelId) to get applicable rates + - Fail with HTTP 400 if model pricing not configured for the requested model + +5. **Domain Model Updates**: + - Add `PlanType` enum (STANDARD, PREMIUM) + - Update `PricingPlan` domain entity to include `planType` (remove `overageRatePer1k` - now in ModelPricing) + - Create `ModelPricing` domain entity for rate storage + - Update `Bill` domain entity with `modelId`, `promptCharge`, `completionCharge` + - Create `BillingContext` as value object for strategy input + +6. **Service Layer Refactoring**: + - `BillingServiceImpl.calculateBill()` orchestrates: + 1. Validate customer exists + 2. Resolve active subscription → pricing plan with plan type + 3. Resolve model pricing for (plan, modelId) + 4. Calculate remaining quota (only for Standard plans) + 5. Build `BillingContext` with all required data + 6. Delegate to appropriate `BillingStrategy` + 7. Persist and return bill + +7. **DTO Updates**: + - Add `@NotNull modelId` to `UsageRequest` + - Add `modelId`, `promptCharge`, `completionCharge` to `BillResponse` + +8. **Exception Handling**: + - Add `ModelPricingNotFoundException` for unknown model scenarios + - Return HTTP 400 with message "Pricing not configured for model: {modelId}" + +## Structure + +### Design Principles + +1. **Strategy Pattern for Billing Calculations**: Encapsulate different billing algorithms (Standard, Premium) as interchangeable strategies. The service layer delegates to the appropriate strategy based on plan type, enabling easy addition of new billing models without modifying existing code. + +2. **Three-Layer Architecture with Decoupled Models**: Maintain the existing clean separation: + - Controllers handle HTTP concerns + - Services orchestrate business logic and strategy selection + - Repositories abstract data access + - Domain entities remain framework-agnostic + +3. **Dependency Inversion Principle**: + - Service depends on `BillingStrategy` interface, not concrete implementations + - Strategy implementations depend on abstractions (ModelPricing, BillingContext) + - Factory pattern abstracts strategy resolution + +### Inheritance Relationships + +1. `StandardBillingStrategy` implements `BillingStrategy` for quota-based billing +2. `PremiumBillingStrategy` implements `BillingStrategy` for split-rate billing +3. `ModelPricingNotFoundException` extends `RuntimeException` for pricing lookup failures +4. `PlanType` is an enum (STANDARD, PREMIUM) - no inheritance needed +5. Domain entities remain pure Java objects (no JPA inheritance) +6. Persistence Objects use discriminator column (`plan_type`) - not JPA inheritance + +### Dependencies + +1. `UsageController` depends on `BillingService` interface (unchanged) +2. `BillingServiceImpl` depends on: + - `CustomerRepository`, `CustomerSubscriptionRepository`, `BillRepository` (existing) + - `ModelPricingRepository` (new) + - `BillingStrategyFactory` (new) +3. `BillingStrategyFactory` depends on `List` (Spring auto-injects all implementations) +4. `StandardBillingStrategy` and `PremiumBillingStrategy` are stateless Spring components +5. `GlobalExceptionHandler` handles `ModelPricingNotFoundException` (new handler) + +### Layered Architecture Updates + +1. **Domain Layer** (`domain`): + - Add `PlanType` enum + - Add `ModelPricing` entity + - Add `BillingContext` value object + - Update `PricingPlan` (add planType, remove overageRatePer1k) + - Update `Bill` (add modelId, promptCharge, completionCharge) + +2. **Service Layer** (`service`): + - Add `BillingStrategy` interface + - Add `strategy/StandardBillingStrategy` implementation + - Add `strategy/PremiumBillingStrategy` implementation + - Add `BillingStrategyFactory` + - Update `BillingServiceImpl` to use strategies + +3. **Repository Layer** (`repository`): + - Add `ModelPricingRepository` interface + +4. **Infrastructure Layer** (`infrastructure/persistence`): + - Add `ModelPricingPO` persistence object + - Add `SpringDataModelPricingRepository` + - Add `JpaModelPricingRepositoryAdapter` + - Add `ModelPricingMapper` + - Update `PricingPlanPO` (add planType column) + - Update `BillPO` (add modelId, promptCharge, completionCharge) + - Update `PricingPlanMapper`, `BillMapper` + +5. **Exception Layer** (`exception`): + - Add `ModelPricingNotFoundException` + - Update `GlobalExceptionHandler` + +## Operations + +### Create Database Migration V2 + +1. Responsibility: Schema evolution for multi-plan billing support +2. Location: `src/main/resources/db/migration/V2__Add_model_pricing.sql` +3. SQL Operations: + ```sql + -- Add plan_type to pricing_plans + ALTER TABLE pricing_plans ADD COLUMN plan_type VARCHAR(20) NOT NULL DEFAULT 'STANDARD'; + + -- Create model_pricing table + CREATE TABLE model_pricing ( + id UUID PRIMARY KEY, + plan_id VARCHAR(50) NOT NULL REFERENCES pricing_plans(id), + model_id VARCHAR(50) NOT NULL, + overage_rate_per_1k DECIMAL(10, 4), + prompt_rate_per_1k DECIMAL(10, 4), + completion_rate_per_1k DECIMAL(10, 4), + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + UNIQUE(plan_id, model_id) + ); + + -- Add model_id and charge breakdown to bills + ALTER TABLE bills ADD COLUMN model_id VARCHAR(50) NOT NULL DEFAULT 'fast-model'; + ALTER TABLE bills ADD COLUMN prompt_charge DECIMAL(10, 2); + ALTER TABLE bills ADD COLUMN completion_charge DECIMAL(10, 2); + + -- Create index for model pricing lookup + CREATE INDEX idx_model_pricing_plan_model ON model_pricing(plan_id, model_id); + + -- Migrate existing plan overage rates to model_pricing for common models + INSERT INTO model_pricing (id, plan_id, model_id, overage_rate_per_1k) + SELECT gen_random_uuid(), id, 'fast-model', overage_rate_per_1k FROM pricing_plans; + + INSERT INTO model_pricing (id, plan_id, model_id, overage_rate_per_1k) + SELECT gen_random_uuid(), id, 'reasoning-model', overage_rate_per_1k FROM pricing_plans; + + -- Add a Premium plan for testing + INSERT INTO pricing_plans (id, name, monthly_quota, overage_rate_per_1k, plan_type) VALUES + ('PLAN-PREMIUM', 'Premium', 0, 0, 'PREMIUM'); + + -- Add Premium plan model pricing (prompt/completion rates) + INSERT INTO model_pricing (id, plan_id, model_id, prompt_rate_per_1k, completion_rate_per_1k) VALUES + (gen_random_uuid(), 'PLAN-PREMIUM', 'fast-model', 0.01, 0.02), + (gen_random_uuid(), 'PLAN-PREMIUM', 'reasoning-model', 0.03, 0.06); + + -- Add a Premium customer for testing + INSERT INTO customers (id, name) VALUES + ('CUST-PREMIUM', 'Premium Test Corp'); + + -- Add Premium customer subscription + INSERT INTO customer_subscriptions (id, customer_id, plan_id, effective_from) VALUES + ('d4e5f6a7-b8c9-0123-def0-456789abcdef', 'CUST-PREMIUM', 'PLAN-PREMIUM', '2026-01-01'); + ``` +4. Notes: Existing `overage_rate_per_1k` column in `pricing_plans` is retained for backward compatibility but superseded by `model_pricing` + +### Create Enum - PlanType + +1. Responsibility: Discriminator for billing strategy selection +2. Location: `domain/PlanType.java` +3. Values: + - `STANDARD` - Quota-based billing with model-aware overage + - `PREMIUM` - No quota, split prompt/completion billing +4. Notes: Simple enum, no additional methods needed + +### Create Domain Entity - ModelPricing + +1. Responsibility: Rate configuration for a specific model within a plan +2. Location: `domain/ModelPricing.java` +3. Attributes: + - `id`: UUID - Pricing configuration identifier + - `planId`: String - Reference to pricing plan + - `modelId`: String - AI model identifier (e.g., "fast-model", "reasoning-model") + - `overageRatePer1k`: BigDecimal - Rate for overage tokens (Standard plans), nullable + - `promptRatePer1k`: BigDecimal - Rate for prompt tokens (Premium plans), nullable + - `completionRatePer1k`: BigDecimal - Rate for completion tokens (Premium plans), nullable + - `createdAt`: LocalDateTime - Creation timestamp +4. Notes: No JPA annotations - pure Java POJO + +### Create Value Object - BillingContext + +1. Responsibility: Encapsulate all inputs needed for billing calculation +2. Location: `domain/BillingContext.java` +3. Attributes: + - `customerId`: String - Customer identifier + - `modelId`: String - AI model used + - `promptTokens`: int - Prompt token count + - `completionTokens`: int - Completion token count + - `remainingQuota`: int - Remaining monthly quota (0 for Premium plans) + - `modelPricing`: ModelPricing - Applicable rates +4. Notes: Immutable value object using Lombok `@Value` or `@Builder` + +### Update Domain Entity - PricingPlan + +1. Responsibility: Add plan type discriminator +2. Location: `domain/PricingPlan.java` +3. Changes: + - Add attribute `planType`: PlanType - Determines billing strategy + - Keep `monthlyQuota` (0 for Premium plans) + - Keep `overageRatePer1k` for backward compatibility (deprecated, use ModelPricing) +4. Notes: Update builder pattern to include planType + +### Update Domain Entity - Bill + +1. Responsibility: Add model info and charge breakdown +2. Location: `domain/Bill.java` +3. Changes: + - Add attribute `modelId`: String - AI model used for this bill (required, default "fast-model") + - Add attribute `promptCharge`: BigDecimal - Charge for prompt tokens (Premium only), nullable + - Add attribute `completionCharge`: BigDecimal - Charge for completion tokens (Premium only), nullable + - Deprecate static `create()` method - billing logic moves to strategies +4. New Factory Method: + - `static createStandard(String customerId, String modelId, int promptTokens, int completionTokens, int remainingQuota, BigDecimal overageRatePer1k)`: Bill + - Logic: Existing quota-first calculation, set promptCharge/completionCharge to null + - `static createPremium(String customerId, String modelId, int promptTokens, int completionTokens, BigDecimal promptRatePer1k, BigDecimal completionRatePer1k)`: Bill + - Logic: Calculate prompt/completion charges separately, includedTokensUsed = 0, overageTokens = 0 +5. Notes: Keep both factory methods in Bill domain entity; strategies call appropriate factory + +### Create Interface - BillingStrategy + +1. Responsibility: Contract for billing calculation algorithms +2. Location: `service/strategy/BillingStrategy.java` +3. Type: Interface +4. Methods: + - `Bill calculate(BillingContext context)` - Calculate and return Bill + - `PlanType supportedPlanType()` - Return the PlanType this strategy handles +5. Notes: No Spring annotations on interface + +### Create Strategy - StandardBillingStrategy + +1. Responsibility: Billing calculation for Standard (quota-based) plans +2. Location: `service/strategy/StandardBillingStrategy.java` +3. Annotations: `@Component` +4. Implements: `BillingStrategy` +5. Methods: + - `calculate(BillingContext context)`: Bill + - Logic: + 1. Extract promptTokens, completionTokens, remainingQuota from context + 2. Calculate totalTokens = promptTokens + completionTokens + 3. Calculate includedTokensUsed = min(totalTokens, max(remainingQuota, 0)) + 4. Calculate overageTokens = totalTokens - includedTokensUsed + 5. Calculate totalCharge = (overageTokens / 1000) × context.modelPricing.overageRatePer1k + 6. Call `Bill.createStandard()` with calculated values + 7. Return Bill + - `supportedPlanType()`: PlanType + - Return `PlanType.STANDARD` +6. Notes: Stateless component, thread-safe + +### Create Strategy - PremiumBillingStrategy + +1. Responsibility: Billing calculation for Premium (split-rate) plans +2. Location: `service/strategy/PremiumBillingStrategy.java` +3. Annotations: `@Component` +4. Implements: `BillingStrategy` +5. Methods: + - `calculate(BillingContext context)`: Bill + - Logic: + 1. Extract promptTokens, completionTokens, modelPricing from context + 2. Calculate promptCharge = (promptTokens / 1000) × modelPricing.promptRatePer1k + 3. Calculate completionCharge = (completionTokens / 1000) × modelPricing.completionRatePer1k + 4. Calculate totalCharge = promptCharge + completionCharge + 5. Set includedTokensUsed = 0, overageTokens = 0 (no quota concept) + 6. Call `Bill.createPremium()` with calculated values + 7. Return Bill + - `supportedPlanType()`: PlanType + - Return `PlanType.PREMIUM` +6. Notes: Stateless component, thread-safe + +### Create Factory - BillingStrategyFactory + +1. Responsibility: Resolve appropriate BillingStrategy based on PlanType +2. Location: `service/strategy/BillingStrategyFactory.java` +3. Annotations: `@Component` +4. Dependencies: `List` (Spring injects all implementations) +5. Attributes: + - `strategyMap`: Map - Lookup map built on construction +6. Constructor: + - `BillingStrategyFactory(List strategies)` + - Logic: Build map from PlanType → Strategy using each strategy's `supportedPlanType()` +7. Methods: + - `getStrategy(PlanType planType)`: BillingStrategy + - Logic: Lookup in map, throw IllegalArgumentException if not found (defensive, should never happen) +8. Notes: Initialized once at startup, O(1) lookup + +### Create Persistence Object - ModelPricingPO + +1. Responsibility: JPA entity mapping to `model_pricing` table +2. Location: `infrastructure/persistence/entity/ModelPricingPO.java` +3. Attributes: + - `id`: UUID - Primary key + - `planId`: String - Foreign key to pricing_plans + - `modelId`: String - Model identifier + - `overageRatePer1k`: BigDecimal - Overage rate (nullable) + - `promptRatePer1k`: BigDecimal - Prompt rate (nullable) + - `completionRatePer1k`: BigDecimal - Completion rate (nullable) + - `createdAt`: LocalDateTime - Creation timestamp +4. Annotations: `@Entity`, `@Table(name = "model_pricing")`, `@Id`, `@Column` + +### Update Persistence Object - PricingPlanPO + +1. Responsibility: Add plan_type column mapping +2. Location: `infrastructure/persistence/entity/PricingPlanPO.java` +3. Changes: + - Add attribute `planType`: String - Maps to `plan_type` column +4. Annotations: `@Column(name = "plan_type", length = 20, nullable = false)` + +### Update Persistence Object - BillPO + +1. Responsibility: Add model_id and charge breakdown columns +2. Location: `infrastructure/persistence/entity/BillPO.java` +3. Changes: + - Add `modelId`: String - Maps to `model_id` column (NOT NULL, default "fast-model") + - Add `promptCharge`: BigDecimal - Maps to `prompt_charge` column (nullable) + - Add `completionCharge`: BigDecimal - Maps to `completion_charge` column (nullable) +4. Annotations: `@Column(name = "model_id", length = 50, nullable = false)` for modelId; `@Column` with appropriate naming and precision for charge fields + +### Create Mapper - ModelPricingMapper + +1. Responsibility: Convert between ModelPricingPO and ModelPricing domain entity +2. Location: `infrastructure/persistence/mapper/ModelPricingMapper.java` +3. Methods: + - `toDomain(ModelPricingPO po)`: ModelPricing + +### Update Mapper - PricingPlanMapper + +1. Responsibility: Handle planType conversion +2. Location: `infrastructure/persistence/mapper/PricingPlanMapper.java` +3. Changes: + - Update `toDomain()` to convert String planType to PlanType enum +4. Logic: `PlanType.valueOf(po.getPlanType())` + +### Update Mapper - BillMapper + +1. Responsibility: Handle new Bill fields +2. Location: `infrastructure/persistence/mapper/BillMapper.java` +3. Changes: + - Update `toDomain()` to include modelId, promptCharge, completionCharge + - Update `toPO()` to include modelId, promptCharge, completionCharge + +### Create Repository Interface - ModelPricingRepository + +1. Responsibility: Define data-access contract for ModelPricing entities +2. Location: `repository/ModelPricingRepository.java` +3. Type: Interface (no Spring annotations) +4. Methods: + - `findByPlanIdAndModelId(String planId, String modelId)`: Optional + +### Create Spring Data Interface - SpringDataModelPricingRepository + +1. Responsibility: Internal Spring Data JPA interface for ModelPricingPO +2. Location: `infrastructure/persistence/SpringDataModelPricingRepository.java` +3. Interface: `extends JpaRepository` +4. Methods: + - `Optional findByPlanIdAndModelId(String planId, String modelId)` + +### Create JPA Repository - JpaModelPricingRepositoryAdapter + +1. Responsibility: Spring Data JPA implementation of ModelPricingRepository +2. Location: `infrastructure/persistence/JpaModelPricingRepositoryAdapter.java` +3. Annotations: `@Repository` +4. Dependencies: `SpringDataModelPricingRepository`, `ModelPricingMapper` +5. Implements: `ModelPricingRepository` +6. Methods: + - `findByPlanIdAndModelId(String planId, String modelId)`: Optional + - Logic: Delegate to Spring Data, map result using ModelPricingMapper.toDomain() + +### Create Exception - ModelPricingNotFoundException + +1. Responsibility: Thrown when model pricing not configured for plan+model combination +2. Location: `exception/ModelPricingNotFoundException.java` +3. Inheritance: extends RuntimeException +4. Attributes: + - `planId`: String - The plan ID + - `modelId`: String - The model ID that was not found +5. Constructors: + - `ModelPricingNotFoundException(String planId, String modelId)`: Sets message "Pricing not configured for model: {modelId}" +6. HTTP Status: 400 Bad Request + +### Update Exception Handler - GlobalExceptionHandler + +1. Responsibility: Handle ModelPricingNotFoundException +2. Location: `exception/GlobalExceptionHandler.java` +3. Changes: + - Add handler method: + ```java + @ExceptionHandler(ModelPricingNotFoundException.class) + public ResponseEntity handleModelPricingNotFoundException(ModelPricingNotFoundException ex) { + log.error("Model pricing not found: planId={}, modelId={}", ex.getPlanId(), ex.getModelId()); + ErrorResponse errorResponse = ErrorResponse.of("BAD_REQUEST", ex.getMessage()); + return ResponseEntity.status(HttpStatus.BAD_REQUEST).body(errorResponse); + } + ``` + +### Update DTO - UsageRequest + +1. Responsibility: Add required modelId field +2. Location: `dto/UsageRequest.java` +3. Changes: + - Add attribute `modelId`: String - `@NotNull(message = "Model ID is required")` +4. Notes: Validation ensures modelId is present before service layer + +### Update DTO - BillResponse + +1. Responsibility: Add model info and charge breakdown +2. Location: `dto/BillResponse.java` +3. Changes: + - Add `modelId`: String + - Add `promptCharge`: BigDecimal (nullable, populated for Premium plans) + - Add `completionCharge`: BigDecimal (nullable, populated for Premium plans) +4. Update `fromBill(Bill bill)`: + - Map modelId, promptCharge, completionCharge from Bill + +### Update Service Implementation - BillingServiceImpl + +1. Responsibility: Orchestrate strategy-based billing +2. Location: `service/impl/BillingServiceImpl.java` +3. Constants: + - `NO_QUOTA = 0` - Represents no quota or zero quota for Premium plans + - `FIRST_DAY_OF_MONTH = 1` - First day for monthly billing period calculation + - `ONE_MONTH = 1` - Increment for calculating next month boundary +4. Changes: + - Add dependency: `ModelPricingRepository` (injected via constructor) + - Add dependency: `BillingStrategyFactory` (injected via constructor) +5. Updated `calculateBill(UsageRequest request)` method: + - Logic: + 1. Extract customerId, modelId, promptTokens, completionTokens from request + 2. Call `validateCustomerExists(customerId)` + 3. Call `resolveActivePricingPlan(customerId)` → get PricingPlan with planType + 4. Call `resolveModelPricing(plan.getId(), modelId)` → get ModelPricing + 5. Call `calculateRemainingQuota(customerId, plan)` → get remainingQuota (returns NO_QUOTA for Premium) + 6. Build `BillingContext` with all data + 7. Get strategy via `billingStrategyFactory.getStrategy(plan.getPlanType())` + 8. Call `strategy.calculate(context)` → get Bill + 9. Log billing result + 10. Save and return Bill +6. New private method `resolveModelPricing(String planId, String modelId)`: ModelPricing + - Logic: Query ModelPricingRepository, throw ModelPricingNotFoundException if not found +7. Update `calculateRemainingQuota(String customerId, PricingPlan plan)`: + - Logic: + 1. If plan.monthlyQuota is null or equals NO_QUOTA, return NO_QUOTA (Premium plans have no quota) + 2. Get current date in UTC + 3. Calculate monthStart using FIRST_DAY_OF_MONTH + 4. Calculate monthEnd using plusMonths(ONE_MONTH) and FIRST_DAY_OF_MONTH + 5. Query billRepository for current month usage + 6. Return plan.monthlyQuota - currentMonthUsage + +## Norms + +1. **Package Structure** (Extended for Strategy Pattern): + - `org.tw.token_billing.service.strategy` - Billing strategy interface and implementations + - All other packages remain unchanged from existing structure + +2. **Strategy Pattern Conventions**: + - Strategy interface in `service.strategy` package + - Strategy implementations as `@Component` beans + - Factory as `@Component` with constructor injection of all strategies + - Strategies are stateless and thread-safe + +3. **Enum Conventions**: + - Enums in `domain` package + - Simple enums without behavior (PlanType) + - Database stores enum name as VARCHAR + +4. **Nullable Fields**: + - Use Java `BigDecimal` (nullable) for optional charge fields + - Don't use `Optional` for entity fields (JPA compatibility) + - Document nullability in comments + +5. **Backward Compatibility**: + - New columns added as nullable + - Existing data continues to work with defaults + - Migration seeds data for common models + +6. **Annotation Standards** (Extended): + - Strategy implementations: `@Component` + - Factory: `@Component` + - Enum: No annotations + +7. **Naming Conventions** (Extended): + - Strategies: `{PlanType}BillingStrategy` (StandardBillingStrategy, PremiumBillingStrategy) + - Factory: `BillingStrategyFactory` + - Context/Value Objects: Descriptive noun (BillingContext) + - Enums: Singular noun (PlanType) + +8. **Calculation Precision**: + - All monetary calculations use BigDecimal + - Intermediate calculations: scale 10, RoundingMode.HALF_UP + - Final charges: scale 2, RoundingMode.HALF_UP + - Same precision rules apply to all strategies + +9. **Named Constants**: + - Use `private static final` constants for magic numbers + - Constants should have descriptive names reflecting their purpose (e.g., `NO_QUOTA`, `FIRST_DAY_OF_MONTH`) + - Prefer constants over inline literals for values used in business logic + +## Safeguards + +1. **Functional Constraints**: + - All existing functional constraints remain valid + - Model ID must be provided in every usage request + - Model pricing must be configured for the (plan, model) combination + - Premium plans must have both prompt and completion rates configured + - Standard plans must have overage rate configured for each model + +2. **Input Validation Constraints**: + - `modelId`: Required, non-null, non-empty + - Validation message for missing modelId: "Model ID is required" + - Existing validations for customerId, promptTokens, completionTokens remain + +3. **Business Rule Constraints**: + - Standard Plan: Same quota-first logic, but overage rate from model_pricing table + - Premium Plan: No quota (monthlyQuota = 0), split billing for prompt/completion + - Charge calculation formulas: + - Standard overage: `(overageTokens / 1000) × modelPricing.overageRatePer1k` + - Premium prompt: `(promptTokens / 1000) × modelPricing.promptRatePer1k` + - Premium completion: `(completionTokens / 1000) × modelPricing.completionRatePer1k` + +4. **Data Integrity Constraints**: + - model_pricing.plan_id must reference valid pricing_plans.id + - model_pricing (plan_id, model_id) combination must be unique + - bills.model_id is NOT NULL with default value 'fast-model' for backward compatibility with existing bills + +5. **Strategy Pattern Constraints**: + - Every PlanType must have exactly one BillingStrategy implementation + - Strategies must be stateless (no instance state between calls) + - BillingStrategyFactory must fail fast if strategy not found (defensive programming) + +6. **Response Constraints**: + - BillResponse must include modelId for all new bills + - promptCharge and completionCharge are null for Standard plans + - promptCharge and completionCharge are populated for Premium plans + - includedTokensUsed = 0 for Premium plans + - overageTokens = 0 for Premium plans + +7. **HTTP Response Constraints** (Extended): + - Model pricing not found: HTTP 400 with message "Pricing not configured for model: {modelId}" + - All other error responses unchanged + +8. **Migration Constraints**: + - V2 migration must be idempotent (safe to re-run) + - Existing pricing plans get plan_type = 'STANDARD' by default + - Existing overage rates are copied to model_pricing for backward compatibility + - Seed data includes 'fast-model' and 'reasoning-model' for all plans + +9. **Test Data Constraints** (for ACs): + - AC2 test setup: Standard plan customer with 100K quota, 90K used, fast-model @ $0.01/1K overage + - AC3 test setup: Premium plan customer, reasoning-model @ $0.03/1K prompt, $0.06/1K completion + - Seeded test customers: + - `CUST-001`, `CUST-002`, `CUST-003` - Standard plan customers (from V1) + - `CUST-PREMIUM` - Premium Test Corp with `PLAN-PREMIUM` subscription (from V2) + +10. **Architecture Constraints** (Extended): + - Strategies depend only on BillingContext value object (not on repositories) + - BillingServiceImpl is the only class that interacts with repositories + - Strategies contain only calculation logic, no I/O or side effects diff --git a/spdd/prompt/GGQPA-001-202603191105-[Test]-multi-plan-billing-model-aware-pricing.md b/spdd/prompt/GGQPA-001-202603191105-[Test]-multi-plan-billing-model-aware-pricing.md new file mode 100644 index 0000000..0250752 --- /dev/null +++ b/spdd/prompt/GGQPA-001-202603191105-[Test]-multi-plan-billing-model-aware-pricing.md @@ -0,0 +1,361 @@ +# Test Scenarios for Multi-Plan Billing Foundation & Model-Aware Pricing + +> **Note**: This file contains only NEW test scenarios that don't exist in the current codebase. +> Existing tests have been analyzed and excluded to avoid duplication. + +--- + +## 1. UsageController Test Scenarios (New Only) + +### Update `UsageControllerTest` class +Add the following new test scenarios to the existing `UsageControllerTest` class. + +#### should_return_201_with_charge_breakdown_when_submit_usage_given_valid_premium_plan_request +- Description: Submit valid usage for a Premium plan customer returns 201 with prompt/completion charge breakdown +- Input: POST `/api/usage` with `{"customerId": "CUST-PREMIUM", "modelId": "reasoning-model", "promptTokens": 10000, "completionTokens": 20000}` +- Expected Output: HTTP 201 with BillResponse including promptCharge and completionCharge +- Verification Points: + - HTTP status is 201 Created + - Response contains `promptCharge` (not null) + - Response contains `completionCharge` (not null) + - Response `includedTokensUsed` equals 0 + - Response `overageTokens` equals 0 + +#### should_return_400_when_submit_usage_given_missing_model_id +- Description: Missing modelId returns 400 Bad Request with validation error +- Input: POST `/api/usage` with `{"customerId": "CUST-001", "promptTokens": 1000, "completionTokens": 500}` +- Expected Output: HTTP 400 with error message "Model ID is required" +- Verification Points: + - HTTP status is 400 Bad Request + - Response contains error code "BAD_REQUEST" + - Response message contains "Model ID is required" + +#### should_return_400_when_submit_usage_given_unknown_model_id +- Description: Unknown modelId returns 400 Bad Request with pricing not configured message +- Input: POST `/api/usage` with `{"customerId": "CUST-001", "modelId": "unknown-model", "promptTokens": 1000, "completionTokens": 500}` +- Expected Output: HTTP 400 with error message "Pricing not configured for model: unknown-model" +- Verification Points: + - HTTP status is 400 Bad Request + - Response message contains "Pricing not configured for model" + - Mock `BillingService` to throw `ModelPricingNotFoundException` + +--- + +## 2. BillingServiceImpl Test Scenarios (New Only) + +### Update `BillingServiceImplTest` class +Add the following new test scenarios to the existing `BillingServiceImplTest` class. + +#### should_return_bill_with_split_charges_when_calculate_bill_given_premium_plan_usage +- Description: Premium plan usage returns bill with separate prompt and completion charges +- Input: UsageRequest with customerId="CUST-PREMIUM", modelId="reasoning-model", promptTokens=10000, completionTokens=20000 +- Expected Output: Bill with promptCharge=$0.30, completionCharge=$1.20, totalCharge=$1.50 +- Verification Points: + - Bill includedTokensUsed equals 0 + - Bill overageTokens equals 0 + - Bill promptCharge equals $0.30 (10000/1000 × $0.03) + - Bill completionCharge equals $1.20 (20000/1000 × $0.06) + - Bill totalCharge equals $1.50 + - BillingStrategyFactory.getStrategy() called with PlanType.PREMIUM +- Setup: + - Create PricingPlan with planType=PREMIUM, monthlyQuota=0 + - Create ModelPricing with promptRatePer1k=$0.03, completionRatePer1k=$0.06 + - Mock BillingStrategyFactory to return PremiumBillingStrategy + +#### should_throw_model_pricing_not_found_exception_when_calculate_bill_given_unknown_model +- Description: Unknown model throws ModelPricingNotFoundException +- Input: UsageRequest with modelId="unknown-model" for valid customer with subscription +- Expected Output: ModelPricingNotFoundException thrown +- Verification Points: + - ModelPricingNotFoundException is thrown + - Exception message contains "unknown-model" + - ModelPricingRepository.findByPlanIdAndModelId() returns empty + - BillRepository.save() NOT called + +#### should_return_zero_remaining_quota_when_calculate_bill_given_premium_plan +- Description: Premium plan always has zero remaining quota (no quota concept) +- Input: UsageRequest for Premium plan customer +- Expected Output: BillingContext built with remainingQuota=0 +- Verification Points: + - BillRepository.sumIncludedTokensUsedForMonth() NOT called for Premium plans (monthlyQuota=0) + - Bill is calculated using PremiumBillingStrategy +- Setup: + - Create PricingPlan with planType=PREMIUM, monthlyQuota=0 + +--- + +## 3. StandardBillingStrategy Test Scenarios (All New) + +### Create `StandardBillingStrategyTest` class +1. Create `StandardBillingStrategyTest` class in `service/strategy/` package +2. Instantiate `StandardBillingStrategy` directly (no mocks needed - stateless) +3. Create test scenarios based on the prompts below +4. Generate test code for each test scenario + +#### should_return_standard_when_supported_plan_type_given_strategy_instance +- Description: supportedPlanType() returns PlanType.STANDARD +- Input: StandardBillingStrategy instance +- Expected Output: PlanType.STANDARD +- Verification Points: + - supportedPlanType() returns PlanType.STANDARD + +#### should_return_bill_with_all_tokens_included_when_calculate_given_usage_within_quota +- Description: Usage within quota uses all tokens from quota, zero charge +- Input: BillingContext with promptTokens=1000, completionTokens=500, remainingQuota=10000, modelPricing with overageRatePer1k=$0.02 +- Expected Output: Bill with includedTokensUsed=1500, overageTokens=0, totalCharge=0 +- Verification Points: + - Bill.includedTokensUsed equals 1500 + - Bill.overageTokens equals 0 + - Bill.totalCharge equals BigDecimal.ZERO + - Bill.modelId equals context.modelId + +#### should_return_bill_with_overage_when_calculate_given_usage_exceeds_quota +- Description: Usage exceeding quota calculates correct overage charge +- Input: BillingContext with promptTokens=8000, completionTokens=5000, remainingQuota=10000, modelPricing with overageRatePer1k=$0.02 +- Expected Output: Bill with includedTokensUsed=10000, overageTokens=3000, totalCharge=$0.06 +- Verification Points: + - Bill.includedTokensUsed equals 10000 + - Bill.overageTokens equals 3000 + - Bill.totalCharge equals $0.06 (3000/1000 × $0.02) + +#### should_return_bill_with_all_overage_when_calculate_given_zero_remaining_quota +- Description: Zero remaining quota treats all tokens as overage +- Input: BillingContext with promptTokens=1000, completionTokens=500, remainingQuota=0, overageRatePer1k=$0.02 +- Expected Output: Bill with includedTokensUsed=0, overageTokens=1500, totalCharge=$0.03 +- Verification Points: + - Bill.includedTokensUsed equals 0 + - Bill.overageTokens equals 1500 + - Bill.totalCharge equals $0.03 + +#### should_return_bill_with_null_prompt_completion_charges_when_calculate_given_standard_plan +- Description: Standard plan bills have null promptCharge and completionCharge +- Input: Any valid BillingContext for Standard plan +- Expected Output: Bill with promptCharge=null, completionCharge=null +- Verification Points: + - Bill.promptCharge is null + - Bill.completionCharge is null + +--- + +## 4. PremiumBillingStrategy Test Scenarios (All New) + +### Create `PremiumBillingStrategyTest` class +1. Create `PremiumBillingStrategyTest` class in `service/strategy/` package +2. Instantiate `PremiumBillingStrategy` directly (no mocks needed - stateless) +3. Create test scenarios based on the prompts below +4. Generate test code for each test scenario + +#### should_return_premium_when_supported_plan_type_given_strategy_instance +- Description: supportedPlanType() returns PlanType.PREMIUM +- Input: PremiumBillingStrategy instance +- Expected Output: PlanType.PREMIUM +- Verification Points: + - supportedPlanType() returns PlanType.PREMIUM + +#### should_return_bill_with_split_charges_when_calculate_given_valid_context +- Description: Calculate separate prompt and completion charges +- Input: BillingContext with promptTokens=10000, completionTokens=20000, modelPricing with promptRatePer1k=$0.03, completionRatePer1k=$0.06 +- Expected Output: Bill with promptCharge=$0.30, completionCharge=$1.20, totalCharge=$1.50 +- Verification Points: + - Bill.promptCharge equals $0.30 + - Bill.completionCharge equals $1.20 + - Bill.totalCharge equals $1.50 + +#### should_return_bill_with_zero_included_tokens_when_calculate_given_premium_plan +- Description: Premium plan has no quota concept - includedTokensUsed always 0 +- Input: Any valid BillingContext for Premium plan +- Expected Output: Bill with includedTokensUsed=0, overageTokens=0 +- Verification Points: + - Bill.includedTokensUsed equals 0 + - Bill.overageTokens equals 0 + +#### should_return_bill_with_zero_charges_when_calculate_given_zero_tokens +- Description: Zero tokens results in zero charges +- Input: BillingContext with promptTokens=0, completionTokens=0 +- Expected Output: Bill with promptCharge=0, completionCharge=0, totalCharge=0 +- Verification Points: + - Bill.promptCharge equals BigDecimal.ZERO + - Bill.completionCharge equals BigDecimal.ZERO + - Bill.totalCharge equals BigDecimal.ZERO + +#### should_return_bill_with_only_prompt_charge_when_calculate_given_zero_completion_tokens +- Description: Only prompt tokens generates only prompt charge +- Input: BillingContext with promptTokens=10000, completionTokens=0, promptRatePer1k=$0.03 +- Expected Output: Bill with promptCharge=$0.30, completionCharge=0, totalCharge=$0.30 +- Verification Points: + - Bill.promptCharge equals $0.30 + - Bill.completionCharge equals BigDecimal.ZERO + - Bill.totalCharge equals $0.30 + +#### should_return_bill_with_only_completion_charge_when_calculate_given_zero_prompt_tokens +- Description: Only completion tokens generates only completion charge +- Input: BillingContext with promptTokens=0, completionTokens=20000, completionRatePer1k=$0.06 +- Expected Output: Bill with promptCharge=0, completionCharge=$1.20, totalCharge=$1.20 +- Verification Points: + - Bill.promptCharge equals BigDecimal.ZERO + - Bill.completionCharge equals $1.20 + - Bill.totalCharge equals $1.20 + +--- + +## 5. BillingStrategyFactory Test Scenarios (All New) + +### Create `BillingStrategyFactoryTest` class +1. Create `BillingStrategyFactoryTest` class in `service/strategy/` package +2. Create real strategy instances for testing +3. Create test scenarios based on the prompts below +4. Generate test code for each test scenario + +#### should_return_standard_strategy_when_get_strategy_given_standard_plan_type +- Description: STANDARD plan type returns StandardBillingStrategy +- Input: PlanType.STANDARD +- Expected Output: StandardBillingStrategy instance +- Verification Points: + - Returned strategy instanceof StandardBillingStrategy + - strategy.supportedPlanType() returns PlanType.STANDARD + +#### should_return_premium_strategy_when_get_strategy_given_premium_plan_type +- Description: PREMIUM plan type returns PremiumBillingStrategy +- Input: PlanType.PREMIUM +- Expected Output: PremiumBillingStrategy instance +- Verification Points: + - Returned strategy instanceof PremiumBillingStrategy + - strategy.supportedPlanType() returns PlanType.PREMIUM + +#### should_throw_illegal_argument_exception_when_get_strategy_given_null_plan_type +- Description: Null plan type throws IllegalArgumentException +- Input: null +- Expected Output: IllegalArgumentException thrown +- Verification Points: + - IllegalArgumentException is thrown + +#### should_build_strategy_map_correctly_when_construct_given_list_of_strategies +- Description: Factory correctly maps all provided strategies by their supported plan type +- Input: List containing StandardBillingStrategy and PremiumBillingStrategy +- Expected Output: Factory that can resolve both plan types +- Verification Points: + - getStrategy(STANDARD) returns StandardBillingStrategy + - getStrategy(PREMIUM) returns PremiumBillingStrategy + +--- + +## 6. JpaModelPricingRepositoryAdapter Test Scenarios (All New) + +### Create `JpaModelPricingRepositoryAdapterTest` class +1. Create `JpaModelPricingRepositoryAdapterTest` class in `infrastructure/persistence/` package +2. Use @Mock annotation to mock `SpringDataModelPricingRepository` and `ModelPricingMapper` +3. Use @InjectMocks annotation to inject the `JpaModelPricingRepositoryAdapter` instance +4. Create test scenarios based on the prompts below +5. Generate test code for each test scenario + +#### should_return_model_pricing_when_find_by_plan_id_and_model_id_given_existing_combination +- Description: Existing plan+model combination returns ModelPricing +- Input: planId="PLAN-STARTER", modelId="fast-model" +- Expected Output: Optional containing ModelPricing +- Verification Points: + - SpringDataModelPricingRepository.findByPlanIdAndModelId() called with correct params + - ModelPricingMapper.toDomain() called with returned PO + - Returned Optional is present + - Returned ModelPricing has correct planId and modelId + +#### should_return_empty_optional_when_find_by_plan_id_and_model_id_given_non_existent_combination +- Description: Non-existent plan+model combination returns empty Optional +- Input: planId="PLAN-STARTER", modelId="unknown-model" +- Expected Output: Optional.empty() +- Verification Points: + - SpringDataModelPricingRepository.findByPlanIdAndModelId() returns empty + - Returned Optional is empty + - ModelPricingMapper.toDomain() NOT called + +--- + +## 7. Mapper Test Scenarios (All New) + +### Create `ModelPricingMapperTest` class +1. Create `ModelPricingMapperTest` class in `infrastructure/persistence/mapper/` package +2. Instantiate `ModelPricingMapper` directly +3. Generate test code for each test scenario + +#### should_map_all_fields_when_to_domain_given_valid_po +- Description: toDomain() maps all fields from PO to domain entity +- Input: ModelPricingPO with id, planId, modelId, overageRatePer1k, promptRatePer1k, completionRatePer1k, createdAt +- Expected Output: ModelPricing domain entity with all fields mapped +- Verification Points: + - ModelPricing.id equals PO.id + - ModelPricing.planId equals PO.planId + - ModelPricing.modelId equals PO.modelId + - ModelPricing.overageRatePer1k equals PO.overageRatePer1k + - ModelPricing.promptRatePer1k equals PO.promptRatePer1k + - ModelPricing.completionRatePer1k equals PO.completionRatePer1k + - ModelPricing.createdAt equals PO.createdAt + +### Create `PricingPlanMapperTest` class +1. Create `PricingPlanMapperTest` class in `infrastructure/persistence/mapper/` package +2. Instantiate `PricingPlanMapper` directly +3. Generate test code for each test scenario + +#### should_convert_plan_type_string_to_enum_when_to_domain_given_standard_plan_type +- Description: toDomain() converts "STANDARD" string to PlanType.STANDARD enum +- Input: PricingPlanPO with planType="STANDARD" +- Expected Output: PricingPlan with planType=PlanType.STANDARD +- Verification Points: + - PricingPlan.planType equals PlanType.STANDARD + +#### should_convert_plan_type_string_to_enum_when_to_domain_given_premium_plan_type +- Description: toDomain() converts "PREMIUM" string to PlanType.PREMIUM enum +- Input: PricingPlanPO with planType="PREMIUM" +- Expected Output: PricingPlan with planType=PlanType.PREMIUM +- Verification Points: + - PricingPlan.planType equals PlanType.PREMIUM + +--- + +## 8. Integration Test Scenarios (New Only) + +### Update `UsageControllerIntegrationTest` class +Add the following new test scenario to the existing `UsageControllerIntegrationTest` class. + +#### should_return_201_with_premium_billing_when_submit_usage_given_premium_customer +- Description: End-to-end test for Premium plan billing with charge breakdown +- Input: POST `/api/usage` with `{"customerId": "CUST-PREMIUM", "modelId": "reasoning-model", "promptTokens": 10000, "completionTokens": 20000}` +- Expected Output: HTTP 201, bill with split charges persisted +- Verification Points: + - HTTP status is 201 Created + - Response contains `promptCharge` (not null, equals $0.30) + - Response contains `completionCharge` (not null, equals $1.20) + - Response `totalCharge` equals $1.50 + - Response `includedTokensUsed` equals 0 + - Response `overageTokens` equals 0 + - Response `modelId` equals "reasoning-model" + +--- + +## 9. Constraints + +- Test name should follow the format: `should_[expected result]_when_[action]_given_[condition]` +- All monetary assertions should use `isEqualByComparingTo()` for BigDecimal comparisons +- Use `@DisplayName` annotation for human-readable test descriptions +- Mock setup should use `when().thenReturn()` pattern consistently +- Integration tests should clean up test data between tests using `@Transactional` or `@DirtiesContext` +- Test data should match seeded migration data: + - CUST-001, CUST-002, CUST-003: Standard plan customers + - CUST-PREMIUM: Premium plan customer (with PLAN-PREMIUM subscription) + - fast-model, reasoning-model: Configured model IDs + - PLAN-PREMIUM rates: fast-model ($0.01/$0.02), reasoning-model ($0.03/$0.06) + +--- + +## Summary of New Tests + +| Test Class | New Tests | Status | +|------------|-----------|--------| +| UsageControllerTest | 3 | Add to existing | +| BillingServiceImplTest | 3 | Add to existing | +| StandardBillingStrategyTest | 5 | Create new class | +| PremiumBillingStrategyTest | 6 | Create new class | +| BillingStrategyFactoryTest | 4 | Create new class | +| JpaModelPricingRepositoryAdapterTest | 2 | Create new class | +| ModelPricingMapperTest | 1 | Create new class | +| PricingPlanMapperTest | 2 | Create new class | +| UsageControllerIntegrationTest | 1 | Add to existing | +| **Total** | **27** | | diff --git a/src/main/java/org/tw/token_billing/domain/Bill.java b/src/main/java/org/tw/token_billing/domain/Bill.java index ac0cda2..b2b0946 100644 --- a/src/main/java/org/tw/token_billing/domain/Bill.java +++ b/src/main/java/org/tw/token_billing/domain/Bill.java @@ -20,16 +20,25 @@ public class Bill { private final UUID id; private final String customerId; + private final String modelId; private final Integer promptTokens; private final Integer completionTokens; private final Integer totalTokens; private final Integer includedTokensUsed; private final Integer overageTokens; + private final BigDecimal promptCharge; + private final BigDecimal completionCharge; private final BigDecimal totalCharge; private final LocalDateTime calculatedAt; + @Deprecated public static Bill create(String customerId, int promptTokens, int completionTokens, int remainingQuota, BigDecimal overageRatePer1k) { + return createStandard(customerId, null, promptTokens, completionTokens, remainingQuota, overageRatePer1k); + } + + public static Bill createStandard(String customerId, String modelId, int promptTokens, int completionTokens, + int remainingQuota, BigDecimal overageRatePer1k) { int totalTokens = promptTokens + completionTokens; int includedTokensUsed = Math.min(totalTokens, Math.max(remainingQuota, 0)); int overageTokens = totalTokens - includedTokensUsed; @@ -42,11 +51,46 @@ public static Bill create(String customerId, int promptTokens, int completionTok return Bill.builder() .id(UUID.randomUUID()) .customerId(customerId) + .modelId(modelId) .promptTokens(promptTokens) .completionTokens(completionTokens) .totalTokens(totalTokens) .includedTokensUsed(includedTokensUsed) .overageTokens(overageTokens) + .promptCharge(null) + .completionCharge(null) + .totalCharge(totalCharge) + .calculatedAt(LocalDateTime.now(ZoneOffset.UTC)) + .build(); + } + + public static Bill createPremium(String customerId, String modelId, int promptTokens, int completionTokens, + BigDecimal promptRatePer1k, BigDecimal completionRatePer1k) { + int totalTokens = promptTokens + completionTokens; + + BigDecimal promptCharge = BigDecimal.valueOf(promptTokens) + .divide(BigDecimal.valueOf(TOKENS_PER_PRICING_UNIT), CALCULATION_PRECISION_SCALE, RoundingMode.HALF_UP) + .multiply(promptRatePer1k) + .setScale(CURRENCY_SCALE, RoundingMode.HALF_UP); + + BigDecimal completionCharge = BigDecimal.valueOf(completionTokens) + .divide(BigDecimal.valueOf(TOKENS_PER_PRICING_UNIT), CALCULATION_PRECISION_SCALE, RoundingMode.HALF_UP) + .multiply(completionRatePer1k) + .setScale(CURRENCY_SCALE, RoundingMode.HALF_UP); + + BigDecimal totalCharge = promptCharge.add(completionCharge); + + return Bill.builder() + .id(UUID.randomUUID()) + .customerId(customerId) + .modelId(modelId) + .promptTokens(promptTokens) + .completionTokens(completionTokens) + .totalTokens(totalTokens) + .includedTokensUsed(0) + .overageTokens(0) + .promptCharge(promptCharge) + .completionCharge(completionCharge) .totalCharge(totalCharge) .calculatedAt(LocalDateTime.now(ZoneOffset.UTC)) .build(); diff --git a/src/main/java/org/tw/token_billing/domain/BillingContext.java b/src/main/java/org/tw/token_billing/domain/BillingContext.java new file mode 100644 index 0000000..52b828f --- /dev/null +++ b/src/main/java/org/tw/token_billing/domain/BillingContext.java @@ -0,0 +1,15 @@ +package org.tw.token_billing.domain; + +import lombok.Builder; +import lombok.Getter; + +@Getter +@Builder +public class BillingContext { + private final String customerId; + private final String modelId; + private final int promptTokens; + private final int completionTokens; + private final int remainingQuota; + private final ModelPricing modelPricing; +} diff --git a/src/main/java/org/tw/token_billing/domain/ModelPricing.java b/src/main/java/org/tw/token_billing/domain/ModelPricing.java new file mode 100644 index 0000000..0ea02f2 --- /dev/null +++ b/src/main/java/org/tw/token_billing/domain/ModelPricing.java @@ -0,0 +1,22 @@ +package org.tw.token_billing.domain; + +import java.math.BigDecimal; +import java.time.LocalDateTime; +import java.util.UUID; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; + +@Getter +@Builder +@AllArgsConstructor +public class ModelPricing { + private final UUID id; + private final String planId; + private final String modelId; + private final BigDecimal overageRatePer1k; + private final BigDecimal promptRatePer1k; + private final BigDecimal completionRatePer1k; + private final LocalDateTime createdAt; +} diff --git a/src/main/java/org/tw/token_billing/domain/PlanType.java b/src/main/java/org/tw/token_billing/domain/PlanType.java new file mode 100644 index 0000000..34156df --- /dev/null +++ b/src/main/java/org/tw/token_billing/domain/PlanType.java @@ -0,0 +1,6 @@ +package org.tw.token_billing.domain; + +public enum PlanType { + STANDARD, + PREMIUM +} diff --git a/src/main/java/org/tw/token_billing/domain/PricingPlan.java b/src/main/java/org/tw/token_billing/domain/PricingPlan.java index 506dcc1..8caf5b1 100644 --- a/src/main/java/org/tw/token_billing/domain/PricingPlan.java +++ b/src/main/java/org/tw/token_billing/domain/PricingPlan.java @@ -13,7 +13,9 @@ public class PricingPlan { private final String id; private final String name; + private final PlanType planType; private final Integer monthlyQuota; + @Deprecated private final BigDecimal overageRatePer1k; private final LocalDateTime createdAt; } diff --git a/src/main/java/org/tw/token_billing/dto/BillResponse.java b/src/main/java/org/tw/token_billing/dto/BillResponse.java index 5c6ceb2..fca9801 100644 --- a/src/main/java/org/tw/token_billing/dto/BillResponse.java +++ b/src/main/java/org/tw/token_billing/dto/BillResponse.java @@ -20,9 +20,12 @@ public class BillResponse { private UUID billId; private String customerId; + private String modelId; private Integer totalTokens; private Integer includedTokensUsed; private Integer overageTokens; + private BigDecimal promptCharge; + private BigDecimal completionCharge; private BigDecimal totalCharge; private LocalDateTime calculatedAt; @@ -30,9 +33,12 @@ public static BillResponse fromBill(Bill bill) { return BillResponse.builder() .billId(bill.getId()) .customerId(bill.getCustomerId()) + .modelId(bill.getModelId()) .totalTokens(bill.getTotalTokens()) .includedTokensUsed(bill.getIncludedTokensUsed()) .overageTokens(bill.getOverageTokens()) + .promptCharge(bill.getPromptCharge()) + .completionCharge(bill.getCompletionCharge()) .totalCharge(bill.getTotalCharge()) .calculatedAt(bill.getCalculatedAt()) .build(); diff --git a/src/main/java/org/tw/token_billing/dto/UsageRequest.java b/src/main/java/org/tw/token_billing/dto/UsageRequest.java index 9225288..3e0715a 100644 --- a/src/main/java/org/tw/token_billing/dto/UsageRequest.java +++ b/src/main/java/org/tw/token_billing/dto/UsageRequest.java @@ -18,6 +18,9 @@ public class UsageRequest { @NotNull(message = "Customer ID is required") private String customerId; + @NotNull(message = "Model ID is required") + private String modelId; + @NotNull(message = "Token count cannot be negative") @Min(value = 0, message = "Token count cannot be negative") private Integer promptTokens; diff --git a/src/main/java/org/tw/token_billing/exception/GlobalExceptionHandler.java b/src/main/java/org/tw/token_billing/exception/GlobalExceptionHandler.java index c234d54..92700b3 100644 --- a/src/main/java/org/tw/token_billing/exception/GlobalExceptionHandler.java +++ b/src/main/java/org/tw/token_billing/exception/GlobalExceptionHandler.java @@ -49,4 +49,11 @@ public ResponseEntity handleConstraintViolationException(Constrai ErrorResponse errorResponse = ErrorResponse.of("BAD_REQUEST", message); return ResponseEntity.status(HttpStatus.BAD_REQUEST).body(errorResponse); } + + @ExceptionHandler(ModelPricingNotFoundException.class) + public ResponseEntity handleModelPricingNotFoundException(ModelPricingNotFoundException ex) { + log.error("Model pricing not found: planId={}, modelId={}", ex.getPlanId(), ex.getModelId()); + ErrorResponse errorResponse = ErrorResponse.of("BAD_REQUEST", ex.getMessage()); + return ResponseEntity.status(HttpStatus.BAD_REQUEST).body(errorResponse); + } } diff --git a/src/main/java/org/tw/token_billing/exception/ModelPricingNotFoundException.java b/src/main/java/org/tw/token_billing/exception/ModelPricingNotFoundException.java new file mode 100644 index 0000000..7c40613 --- /dev/null +++ b/src/main/java/org/tw/token_billing/exception/ModelPricingNotFoundException.java @@ -0,0 +1,15 @@ +package org.tw.token_billing.exception; + +import lombok.Getter; + +@Getter +public class ModelPricingNotFoundException extends RuntimeException { + private final String planId; + private final String modelId; + + public ModelPricingNotFoundException(String planId, String modelId) { + super("Pricing not configured for model: " + modelId); + this.planId = planId; + this.modelId = modelId; + } +} diff --git a/src/main/java/org/tw/token_billing/infrastructure/persistence/JpaModelPricingRepositoryAdapter.java b/src/main/java/org/tw/token_billing/infrastructure/persistence/JpaModelPricingRepositoryAdapter.java new file mode 100644 index 0000000..ea9111d --- /dev/null +++ b/src/main/java/org/tw/token_billing/infrastructure/persistence/JpaModelPricingRepositoryAdapter.java @@ -0,0 +1,24 @@ +package org.tw.token_billing.infrastructure.persistence; + +import java.util.Optional; + +import org.springframework.stereotype.Repository; +import org.tw.token_billing.domain.ModelPricing; +import org.tw.token_billing.infrastructure.persistence.mapper.ModelPricingMapper; +import org.tw.token_billing.repository.ModelPricingRepository; + +import lombok.RequiredArgsConstructor; + +@Repository +@RequiredArgsConstructor +public class JpaModelPricingRepositoryAdapter implements ModelPricingRepository { + + private final SpringDataModelPricingRepository springDataRepository; + private final ModelPricingMapper mapper; + + @Override + public Optional findByPlanIdAndModelId(String planId, String modelId) { + return springDataRepository.findByPlanIdAndModelId(planId, modelId) + .map(mapper::toDomain); + } +} diff --git a/src/main/java/org/tw/token_billing/infrastructure/persistence/SpringDataModelPricingRepository.java b/src/main/java/org/tw/token_billing/infrastructure/persistence/SpringDataModelPricingRepository.java new file mode 100644 index 0000000..aed4ee3 --- /dev/null +++ b/src/main/java/org/tw/token_billing/infrastructure/persistence/SpringDataModelPricingRepository.java @@ -0,0 +1,11 @@ +package org.tw.token_billing.infrastructure.persistence; + +import java.util.Optional; +import java.util.UUID; + +import org.springframework.data.jpa.repository.JpaRepository; +import org.tw.token_billing.infrastructure.persistence.entity.ModelPricingPO; + +public interface SpringDataModelPricingRepository extends JpaRepository { + Optional findByPlanIdAndModelId(String planId, String modelId); +} diff --git a/src/main/java/org/tw/token_billing/infrastructure/persistence/entity/BillPO.java b/src/main/java/org/tw/token_billing/infrastructure/persistence/entity/BillPO.java index 0c1dbcc..c8fcd5d 100644 --- a/src/main/java/org/tw/token_billing/infrastructure/persistence/entity/BillPO.java +++ b/src/main/java/org/tw/token_billing/infrastructure/persistence/entity/BillPO.java @@ -30,6 +30,9 @@ public class BillPO { @Column(name = "customer_id", length = 50, nullable = false) private String customerId; + @Column(name = "model_id", length = 50, nullable = false) + private String modelId; + @Column(name = "prompt_tokens", nullable = false) private Integer promptTokens; @@ -45,6 +48,12 @@ public class BillPO { @Column(name = "overage_tokens", nullable = false) private Integer overageTokens; + @Column(name = "prompt_charge", precision = 10, scale = 2) + private BigDecimal promptCharge; + + @Column(name = "completion_charge", precision = 10, scale = 2) + private BigDecimal completionCharge; + @Column(name = "total_charge", precision = 10, scale = 2, nullable = false) private BigDecimal totalCharge; diff --git a/src/main/java/org/tw/token_billing/infrastructure/persistence/entity/ModelPricingPO.java b/src/main/java/org/tw/token_billing/infrastructure/persistence/entity/ModelPricingPO.java new file mode 100644 index 0000000..2280deb --- /dev/null +++ b/src/main/java/org/tw/token_billing/infrastructure/persistence/entity/ModelPricingPO.java @@ -0,0 +1,47 @@ +package org.tw.token_billing.infrastructure.persistence.entity; + +import java.math.BigDecimal; +import java.time.LocalDateTime; +import java.util.UUID; + +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.Id; +import jakarta.persistence.Table; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + +@Entity +@Table(name = "model_pricing") +@Getter +@Setter +@Builder +@NoArgsConstructor +@AllArgsConstructor +public class ModelPricingPO { + + @Id + @Column(name = "id") + private UUID id; + + @Column(name = "plan_id", length = 50, nullable = false) + private String planId; + + @Column(name = "model_id", length = 50, nullable = false) + private String modelId; + + @Column(name = "overage_rate_per_1k", precision = 10, scale = 4) + private BigDecimal overageRatePer1k; + + @Column(name = "prompt_rate_per_1k", precision = 10, scale = 4) + private BigDecimal promptRatePer1k; + + @Column(name = "completion_rate_per_1k", precision = 10, scale = 4) + private BigDecimal completionRatePer1k; + + @Column(name = "created_at", nullable = false) + private LocalDateTime createdAt; +} diff --git a/src/main/java/org/tw/token_billing/infrastructure/persistence/entity/PricingPlanPO.java b/src/main/java/org/tw/token_billing/infrastructure/persistence/entity/PricingPlanPO.java index a8d9728..f4090e2 100644 --- a/src/main/java/org/tw/token_billing/infrastructure/persistence/entity/PricingPlanPO.java +++ b/src/main/java/org/tw/token_billing/infrastructure/persistence/entity/PricingPlanPO.java @@ -35,6 +35,9 @@ public class PricingPlanPO { @Column(name = "overage_rate_per_1k", precision = 10, scale = 4, nullable = false) private BigDecimal overageRatePer1k; + @Column(name = "plan_type", length = 20, nullable = false) + private String planType; + @Column(name = "created_at", nullable = false) private LocalDateTime createdAt; } diff --git a/src/main/java/org/tw/token_billing/infrastructure/persistence/mapper/BillMapper.java b/src/main/java/org/tw/token_billing/infrastructure/persistence/mapper/BillMapper.java index f66a755..ed1c597 100644 --- a/src/main/java/org/tw/token_billing/infrastructure/persistence/mapper/BillMapper.java +++ b/src/main/java/org/tw/token_billing/infrastructure/persistence/mapper/BillMapper.java @@ -14,11 +14,14 @@ public Bill toDomain(BillPO po) { return Bill.builder() .id(po.getId()) .customerId(po.getCustomerId()) + .modelId(po.getModelId()) .promptTokens(po.getPromptTokens()) .completionTokens(po.getCompletionTokens()) .totalTokens(po.getTotalTokens()) .includedTokensUsed(po.getIncludedTokensUsed()) .overageTokens(po.getOverageTokens()) + .promptCharge(po.getPromptCharge()) + .completionCharge(po.getCompletionCharge()) .totalCharge(po.getTotalCharge()) .calculatedAt(po.getCalculatedAt()) .build(); @@ -31,11 +34,14 @@ public BillPO toPO(Bill domain) { return BillPO.builder() .id(domain.getId()) .customerId(domain.getCustomerId()) + .modelId(domain.getModelId()) .promptTokens(domain.getPromptTokens()) .completionTokens(domain.getCompletionTokens()) .totalTokens(domain.getTotalTokens()) .includedTokensUsed(domain.getIncludedTokensUsed()) .overageTokens(domain.getOverageTokens()) + .promptCharge(domain.getPromptCharge()) + .completionCharge(domain.getCompletionCharge()) .totalCharge(domain.getTotalCharge()) .calculatedAt(domain.getCalculatedAt()) .build(); diff --git a/src/main/java/org/tw/token_billing/infrastructure/persistence/mapper/ModelPricingMapper.java b/src/main/java/org/tw/token_billing/infrastructure/persistence/mapper/ModelPricingMapper.java new file mode 100644 index 0000000..fc545d3 --- /dev/null +++ b/src/main/java/org/tw/token_billing/infrastructure/persistence/mapper/ModelPricingMapper.java @@ -0,0 +1,24 @@ +package org.tw.token_billing.infrastructure.persistence.mapper; + +import org.springframework.stereotype.Component; +import org.tw.token_billing.domain.ModelPricing; +import org.tw.token_billing.infrastructure.persistence.entity.ModelPricingPO; + +@Component +public class ModelPricingMapper { + + public ModelPricing toDomain(ModelPricingPO po) { + if (po == null) { + return null; + } + return ModelPricing.builder() + .id(po.getId()) + .planId(po.getPlanId()) + .modelId(po.getModelId()) + .overageRatePer1k(po.getOverageRatePer1k()) + .promptRatePer1k(po.getPromptRatePer1k()) + .completionRatePer1k(po.getCompletionRatePer1k()) + .createdAt(po.getCreatedAt()) + .build(); + } +} diff --git a/src/main/java/org/tw/token_billing/infrastructure/persistence/mapper/PricingPlanMapper.java b/src/main/java/org/tw/token_billing/infrastructure/persistence/mapper/PricingPlanMapper.java index 9ec5567..a14b09e 100644 --- a/src/main/java/org/tw/token_billing/infrastructure/persistence/mapper/PricingPlanMapper.java +++ b/src/main/java/org/tw/token_billing/infrastructure/persistence/mapper/PricingPlanMapper.java @@ -1,6 +1,7 @@ package org.tw.token_billing.infrastructure.persistence.mapper; import org.springframework.stereotype.Component; +import org.tw.token_billing.domain.PlanType; import org.tw.token_billing.domain.PricingPlan; import org.tw.token_billing.infrastructure.persistence.entity.PricingPlanPO; @@ -14,6 +15,7 @@ public PricingPlan toDomain(PricingPlanPO po) { return PricingPlan.builder() .id(po.getId()) .name(po.getName()) + .planType(PlanType.valueOf(po.getPlanType())) .monthlyQuota(po.getMonthlyQuota()) .overageRatePer1k(po.getOverageRatePer1k()) .createdAt(po.getCreatedAt()) diff --git a/src/main/java/org/tw/token_billing/repository/ModelPricingRepository.java b/src/main/java/org/tw/token_billing/repository/ModelPricingRepository.java new file mode 100644 index 0000000..ebe2fbe --- /dev/null +++ b/src/main/java/org/tw/token_billing/repository/ModelPricingRepository.java @@ -0,0 +1,9 @@ +package org.tw.token_billing.repository; + +import java.util.Optional; + +import org.tw.token_billing.domain.ModelPricing; + +public interface ModelPricingRepository { + Optional findByPlanIdAndModelId(String planId, String modelId); +} diff --git a/src/main/java/org/tw/token_billing/service/impl/BillingServiceImpl.java b/src/main/java/org/tw/token_billing/service/impl/BillingServiceImpl.java index aa4c2fa..1c35000 100644 --- a/src/main/java/org/tw/token_billing/service/impl/BillingServiceImpl.java +++ b/src/main/java/org/tw/token_billing/service/impl/BillingServiceImpl.java @@ -7,15 +7,21 @@ import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; import org.tw.token_billing.domain.Bill; +import org.tw.token_billing.domain.BillingContext; import org.tw.token_billing.domain.CustomerSubscription; +import org.tw.token_billing.domain.ModelPricing; import org.tw.token_billing.domain.PricingPlan; import org.tw.token_billing.dto.UsageRequest; import org.tw.token_billing.exception.CustomerNotFoundException; +import org.tw.token_billing.exception.ModelPricingNotFoundException; import org.tw.token_billing.exception.NoActiveSubscriptionException; import org.tw.token_billing.repository.BillRepository; import org.tw.token_billing.repository.CustomerRepository; import org.tw.token_billing.repository.CustomerSubscriptionRepository; +import org.tw.token_billing.repository.ModelPricingRepository; import org.tw.token_billing.service.BillingService; +import org.tw.token_billing.service.strategy.BillingStrategy; +import org.tw.token_billing.service.strategy.BillingStrategyFactory; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; @@ -26,28 +32,40 @@ @RequiredArgsConstructor public class BillingServiceImpl implements BillingService { + private static final int NO_QUOTA = 0; + private static final int FIRST_DAY_OF_MONTH = 1; + private static final int ONE_MONTH = 1; + private final CustomerRepository customerRepository; private final CustomerSubscriptionRepository customerSubscriptionRepository; private final BillRepository billRepository; + private final ModelPricingRepository modelPricingRepository; + private final BillingStrategyFactory billingStrategyFactory; @Override public Bill calculateBill(UsageRequest request) { String customerId = request.getCustomerId(); + String modelId = request.getModelId(); validateCustomerExists(customerId); PricingPlan plan = resolveActivePricingPlan(customerId); + ModelPricing modelPricing = resolveModelPricing(plan.getId(), modelId); int remainingQuota = calculateRemainingQuota(customerId, plan); - Bill bill = Bill.create( - customerId, - request.getPromptTokens(), - request.getCompletionTokens(), - remainingQuota, - plan.getOverageRatePer1k() - ); + BillingContext context = BillingContext.builder() + .customerId(customerId) + .modelId(modelId) + .promptTokens(request.getPromptTokens()) + .completionTokens(request.getCompletionTokens()) + .remainingQuota(remainingQuota) + .modelPricing(modelPricing) + .build(); + + BillingStrategy strategy = billingStrategyFactory.getStrategy(plan.getPlanType()); + Bill bill = strategy.calculate(context); - log.info("Calculated bill for customer {}: totalTokens={}, includedTokensUsed={}, overageTokens={}, totalCharge={}", - customerId, bill.getTotalTokens(), bill.getIncludedTokensUsed(), + log.info("Calculated bill for customer {}: model={}, planType={}, totalTokens={}, includedTokensUsed={}, overageTokens={}, totalCharge={}", + customerId, modelId, plan.getPlanType(), bill.getTotalTokens(), bill.getIncludedTokensUsed(), bill.getOverageTokens(), bill.getTotalCharge()); return billRepository.save(bill); @@ -68,10 +86,19 @@ private PricingPlan resolveActivePricingPlan(String customerId) { return subscription.getPlan(); } + private ModelPricing resolveModelPricing(String planId, String modelId) { + return modelPricingRepository.findByPlanIdAndModelId(planId, modelId) + .orElseThrow(() -> new ModelPricingNotFoundException(planId, modelId)); + } + private int calculateRemainingQuota(String customerId, PricingPlan plan) { + if (plan.getMonthlyQuota() == null || plan.getMonthlyQuota() == NO_QUOTA) { + return NO_QUOTA; + } + LocalDate currentDate = LocalDate.now(ZoneOffset.UTC); - LocalDateTime monthStart = currentDate.withDayOfMonth(1).atStartOfDay(); - LocalDateTime monthEnd = currentDate.plusMonths(1).withDayOfMonth(1).atStartOfDay(); + LocalDateTime monthStart = currentDate.withDayOfMonth(FIRST_DAY_OF_MONTH).atStartOfDay(); + LocalDateTime monthEnd = currentDate.plusMonths(ONE_MONTH).withDayOfMonth(FIRST_DAY_OF_MONTH).atStartOfDay(); Integer currentMonthUsage = billRepository.sumIncludedTokensUsedForMonth(customerId, monthStart, monthEnd); return plan.getMonthlyQuota() - currentMonthUsage; diff --git a/src/main/java/org/tw/token_billing/service/strategy/BillingStrategy.java b/src/main/java/org/tw/token_billing/service/strategy/BillingStrategy.java new file mode 100644 index 0000000..969bbd4 --- /dev/null +++ b/src/main/java/org/tw/token_billing/service/strategy/BillingStrategy.java @@ -0,0 +1,10 @@ +package org.tw.token_billing.service.strategy; + +import org.tw.token_billing.domain.Bill; +import org.tw.token_billing.domain.BillingContext; +import org.tw.token_billing.domain.PlanType; + +public interface BillingStrategy { + Bill calculate(BillingContext context); + PlanType supportedPlanType(); +} diff --git a/src/main/java/org/tw/token_billing/service/strategy/BillingStrategyFactory.java b/src/main/java/org/tw/token_billing/service/strategy/BillingStrategyFactory.java new file mode 100644 index 0000000..2e9ebd6 --- /dev/null +++ b/src/main/java/org/tw/token_billing/service/strategy/BillingStrategyFactory.java @@ -0,0 +1,31 @@ +package org.tw.token_billing.service.strategy; + +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; + +import org.springframework.stereotype.Component; +import org.tw.token_billing.domain.PlanType; + +@Component +public class BillingStrategyFactory { + + private final Map strategyMap; + + public BillingStrategyFactory(List strategies) { + this.strategyMap = strategies.stream() + .collect(Collectors.toMap( + BillingStrategy::supportedPlanType, + Function.identity() + )); + } + + public BillingStrategy getStrategy(PlanType planType) { + BillingStrategy strategy = strategyMap.get(planType); + if (strategy == null) { + throw new IllegalArgumentException("No billing strategy found for plan type: " + planType); + } + return strategy; + } +} diff --git a/src/main/java/org/tw/token_billing/service/strategy/PremiumBillingStrategy.java b/src/main/java/org/tw/token_billing/service/strategy/PremiumBillingStrategy.java new file mode 100644 index 0000000..8986273 --- /dev/null +++ b/src/main/java/org/tw/token_billing/service/strategy/PremiumBillingStrategy.java @@ -0,0 +1,27 @@ +package org.tw.token_billing.service.strategy; + +import org.springframework.stereotype.Component; +import org.tw.token_billing.domain.Bill; +import org.tw.token_billing.domain.BillingContext; +import org.tw.token_billing.domain.PlanType; + +@Component +public class PremiumBillingStrategy implements BillingStrategy { + + @Override + public Bill calculate(BillingContext context) { + return Bill.createPremium( + context.getCustomerId(), + context.getModelId(), + context.getPromptTokens(), + context.getCompletionTokens(), + context.getModelPricing().getPromptRatePer1k(), + context.getModelPricing().getCompletionRatePer1k() + ); + } + + @Override + public PlanType supportedPlanType() { + return PlanType.PREMIUM; + } +} diff --git a/src/main/java/org/tw/token_billing/service/strategy/StandardBillingStrategy.java b/src/main/java/org/tw/token_billing/service/strategy/StandardBillingStrategy.java new file mode 100644 index 0000000..da06494 --- /dev/null +++ b/src/main/java/org/tw/token_billing/service/strategy/StandardBillingStrategy.java @@ -0,0 +1,27 @@ +package org.tw.token_billing.service.strategy; + +import org.springframework.stereotype.Component; +import org.tw.token_billing.domain.Bill; +import org.tw.token_billing.domain.BillingContext; +import org.tw.token_billing.domain.PlanType; + +@Component +public class StandardBillingStrategy implements BillingStrategy { + + @Override + public Bill calculate(BillingContext context) { + return Bill.createStandard( + context.getCustomerId(), + context.getModelId(), + context.getPromptTokens(), + context.getCompletionTokens(), + context.getRemainingQuota(), + context.getModelPricing().getOverageRatePer1k() + ); + } + + @Override + public PlanType supportedPlanType() { + return PlanType.STANDARD; + } +} diff --git a/src/main/resources/db/migration/V2__Add_model_pricing.sql b/src/main/resources/db/migration/V2__Add_model_pricing.sql new file mode 100644 index 0000000..c438804 --- /dev/null +++ b/src/main/resources/db/migration/V2__Add_model_pricing.sql @@ -0,0 +1,46 @@ +-- Add plan_type to pricing_plans +ALTER TABLE pricing_plans ADD COLUMN plan_type VARCHAR(20) NOT NULL DEFAULT 'STANDARD'; + +-- Create model_pricing table +CREATE TABLE model_pricing ( + id UUID PRIMARY KEY, + plan_id VARCHAR(50) NOT NULL REFERENCES pricing_plans(id), + model_id VARCHAR(50) NOT NULL, + overage_rate_per_1k DECIMAL(10, 4), + prompt_rate_per_1k DECIMAL(10, 4), + completion_rate_per_1k DECIMAL(10, 4), + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + UNIQUE(plan_id, model_id) +); + +-- Add model_id and charge breakdown to bills +ALTER TABLE bills ADD COLUMN model_id VARCHAR(50) NOT NULL DEFAULT 'fast-model'; +ALTER TABLE bills ADD COLUMN prompt_charge DECIMAL(10, 2); +ALTER TABLE bills ADD COLUMN completion_charge DECIMAL(10, 2); + +-- Create index for model pricing lookup +CREATE INDEX idx_model_pricing_plan_model ON model_pricing(plan_id, model_id); + +-- Migrate existing plan overage rates to model_pricing for common models +INSERT INTO model_pricing (id, plan_id, model_id, overage_rate_per_1k) +SELECT gen_random_uuid(), id, 'fast-model', overage_rate_per_1k FROM pricing_plans; + +INSERT INTO model_pricing (id, plan_id, model_id, overage_rate_per_1k) +SELECT gen_random_uuid(), id, 'reasoning-model', overage_rate_per_1k FROM pricing_plans; + +-- Add a Premium plan for testing +INSERT INTO pricing_plans (id, name, monthly_quota, overage_rate_per_1k, plan_type) VALUES + ('PLAN-PREMIUM', 'Premium', 0, 0, 'PREMIUM'); + +-- Add Premium plan model pricing (prompt/completion rates) +INSERT INTO model_pricing (id, plan_id, model_id, prompt_rate_per_1k, completion_rate_per_1k) VALUES + (gen_random_uuid(), 'PLAN-PREMIUM', 'fast-model', 0.01, 0.02), + (gen_random_uuid(), 'PLAN-PREMIUM', 'reasoning-model', 0.03, 0.06); + +-- Add a Premium customer for testing +INSERT INTO customers (id, name) VALUES + ('CUST-PREMIUM', 'Premium Test Corp'); + +-- Add Premium customer subscription +INSERT INTO customer_subscriptions (id, customer_id, plan_id, effective_from) VALUES + ('d4e5f6a7-b8c9-0123-def0-456789abcdef', 'CUST-PREMIUM', 'PLAN-PREMIUM', '2026-01-01'); diff --git a/src/test/java/org/tw/token_billing/UsageControllerIntegrationTest.java b/src/test/java/org/tw/token_billing/UsageControllerIntegrationTest.java index 3e8b497..bbffbf3 100644 --- a/src/test/java/org/tw/token_billing/UsageControllerIntegrationTest.java +++ b/src/test/java/org/tw/token_billing/UsageControllerIntegrationTest.java @@ -39,6 +39,7 @@ void should_return_201_and_persist_bill_when_submit_usage_given_valid_customer_w .content(""" { "customerId": "CUST-001", + "modelId": "fast-model", "promptTokens": 1000, "completionTokens": 500 } @@ -46,6 +47,7 @@ void should_return_201_and_persist_bill_when_submit_usage_given_valid_customer_w .andExpect(status().isCreated()) .andExpect(jsonPath("$.billId").exists()) .andExpect(jsonPath("$.customerId").value("CUST-001")) + .andExpect(jsonPath("$.modelId").value("fast-model")) .andExpect(jsonPath("$.totalTokens").value(1500)) .andExpect(jsonPath("$.includedTokensUsed").exists()) .andExpect(jsonPath("$.overageTokens").exists()) @@ -65,6 +67,7 @@ void should_return_404_when_submit_usage_given_non_existent_customer() throws Ex .content(""" { "customerId": "NON-EXISTENT-CUSTOMER", + "modelId": "fast-model", "promptTokens": 1000, "completionTokens": 500 } @@ -90,6 +93,7 @@ void should_track_quota_correctly_when_submit_multiple_usages_given_same_custome .content(""" { "customerId": "CUST-002", + "modelId": "fast-model", "promptTokens": 5000, "completionTokens": 3000 } @@ -106,6 +110,7 @@ void should_track_quota_correctly_when_submit_multiple_usages_given_same_custome .content(""" { "customerId": "CUST-002", + "modelId": "fast-model", "promptTokens": 5000, "completionTokens": 3000 } @@ -130,6 +135,7 @@ void should_calculate_overage_when_submit_usage_given_quota_exhausted() throws E .content(""" { "customerId": "CUST-002", + "modelId": "fast-model", "promptTokens": 4000, "completionTokens": 2000 } @@ -142,6 +148,7 @@ void should_calculate_overage_when_submit_usage_given_quota_exhausted() throws E .content(""" { "customerId": "CUST-002", + "modelId": "fast-model", "promptTokens": 3000, "completionTokens": 2000 } @@ -157,4 +164,36 @@ void should_calculate_overage_when_submit_usage_given_quota_exhausted() throws E assertThat(totalCharge).isGreaterThan(BigDecimal.ZERO); } } + + @Test + @DisplayName("Should return 201 with premium billing when submitting usage for premium customer") + void should_return_201_with_premium_billing_when_submit_usage_given_premium_customer() throws Exception { + MvcResult result = mockMvc.perform(post("/api/usage") + .contentType(MediaType.APPLICATION_JSON) + .content(""" + { + "customerId": "CUST-PREMIUM", + "modelId": "reasoning-model", + "promptTokens": 10000, + "completionTokens": 20000 + } + """)) + .andExpect(status().isCreated()) + .andExpect(jsonPath("$.billId").exists()) + .andExpect(jsonPath("$.customerId").value("CUST-PREMIUM")) + .andExpect(jsonPath("$.modelId").value("reasoning-model")) + .andExpect(jsonPath("$.totalTokens").value(30000)) + .andExpect(jsonPath("$.includedTokensUsed").value(0)) + .andExpect(jsonPath("$.overageTokens").value(0)) + .andExpect(jsonPath("$.promptCharge").value(0.30)) + .andExpect(jsonPath("$.completionCharge").value(1.20)) + .andExpect(jsonPath("$.totalCharge").value(1.50)) + .andReturn(); + + JsonNode response = objectMapper.readTree(result.getResponse().getContentAsString()); + assertThat(response.get("billId").asText()).isNotEmpty(); + assertThat(new BigDecimal(response.get("promptCharge").asText())).isEqualByComparingTo(new BigDecimal("0.30")); + assertThat(new BigDecimal(response.get("completionCharge").asText())).isEqualByComparingTo(new BigDecimal("1.20")); + assertThat(new BigDecimal(response.get("totalCharge").asText())).isEqualByComparingTo(new BigDecimal("1.50")); + } } diff --git a/src/test/java/org/tw/token_billing/controller/UsageControllerTest.java b/src/test/java/org/tw/token_billing/controller/UsageControllerTest.java index 67ace07..f77ead8 100644 --- a/src/test/java/org/tw/token_billing/controller/UsageControllerTest.java +++ b/src/test/java/org/tw/token_billing/controller/UsageControllerTest.java @@ -22,6 +22,7 @@ import org.tw.token_billing.domain.Bill; import org.tw.token_billing.dto.UsageRequest; import org.tw.token_billing.exception.CustomerNotFoundException; +import org.tw.token_billing.exception.ModelPricingNotFoundException; import org.tw.token_billing.exception.NoActiveSubscriptionException; import org.tw.token_billing.service.BillingService; @@ -40,11 +41,14 @@ void should_return_201_created_with_bill_response_when_submit_usage_given_valid_ Bill bill = Bill.builder() .id(UUID.randomUUID()) .customerId("CUST-001") + .modelId("fast-model") .promptTokens(1000) .completionTokens(500) .totalTokens(1500) .includedTokensUsed(1500) .overageTokens(0) + .promptCharge(null) + .completionCharge(null) .totalCharge(BigDecimal.ZERO) .calculatedAt(LocalDateTime.now()) .build(); @@ -56,6 +60,7 @@ void should_return_201_created_with_bill_response_when_submit_usage_given_valid_ .content(""" { "customerId": "CUST-001", + "modelId": "fast-model", "promptTokens": 1000, "completionTokens": 500 } @@ -63,6 +68,7 @@ void should_return_201_created_with_bill_response_when_submit_usage_given_valid_ .andExpect(status().isCreated()) .andExpect(jsonPath("$.billId").exists()) .andExpect(jsonPath("$.customerId").value("CUST-001")) + .andExpect(jsonPath("$.modelId").value("fast-model")) .andExpect(jsonPath("$.totalTokens").value(1500)) .andExpect(jsonPath("$.includedTokensUsed").value(1500)) .andExpect(jsonPath("$.overageTokens").value(0)) @@ -79,6 +85,7 @@ void should_return_400_bad_request_when_submit_usage_given_missing_customer_id() .contentType(MediaType.APPLICATION_JSON) .content(""" { + "modelId": "fast-model", "promptTokens": 1000, "completionTokens": 500 } @@ -97,6 +104,7 @@ void should_return_400_bad_request_when_submit_usage_given_negative_prompt_token .content(""" { "customerId": "CUST-001", + "modelId": "fast-model", "promptTokens": -100, "completionTokens": 500 } @@ -115,6 +123,7 @@ void should_return_400_bad_request_when_submit_usage_given_negative_completion_t .content(""" { "customerId": "CUST-001", + "modelId": "fast-model", "promptTokens": 1000, "completionTokens": -500 } @@ -133,6 +142,7 @@ void should_return_400_bad_request_when_submit_usage_given_null_prompt_tokens() .content(""" { "customerId": "CUST-001", + "modelId": "fast-model", "completionTokens": 500 } """)) @@ -152,6 +162,7 @@ void should_return_404_not_found_when_submit_usage_given_non_existent_customer() .content(""" { "customerId": "INVALID-CUSTOMER", + "modelId": "fast-model", "promptTokens": 1000, "completionTokens": 500 } @@ -171,6 +182,7 @@ void should_return_422_unprocessable_entity_when_submit_usage_given_no_active_su .content(""" { "customerId": "CUST-NO-SUB", + "modelId": "fast-model", "promptTokens": 1000, "completionTokens": 500 } @@ -178,4 +190,81 @@ void should_return_422_unprocessable_entity_when_submit_usage_given_no_active_su .andExpect(status().isUnprocessableEntity()) .andExpect(jsonPath("$.message").value("No active subscription found")); } + + @Test + @DisplayName("Should return 201 with charge breakdown when submitting valid premium plan request") + void should_return_201_with_charge_breakdown_when_submit_usage_given_valid_premium_plan_request() throws Exception { + Bill bill = Bill.builder() + .id(UUID.randomUUID()) + .customerId("CUST-PREMIUM") + .modelId("reasoning-model") + .promptTokens(10000) + .completionTokens(20000) + .totalTokens(30000) + .includedTokensUsed(0) + .overageTokens(0) + .promptCharge(new BigDecimal("0.30")) + .completionCharge(new BigDecimal("1.20")) + .totalCharge(new BigDecimal("1.50")) + .calculatedAt(LocalDateTime.now()) + .build(); + + when(billingService.calculateBill(any(UsageRequest.class))).thenReturn(bill); + + mockMvc.perform(post("/api/usage") + .contentType(MediaType.APPLICATION_JSON) + .content(""" + { + "customerId": "CUST-PREMIUM", + "modelId": "reasoning-model", + "promptTokens": 10000, + "completionTokens": 20000 + } + """)) + .andExpect(status().isCreated()) + .andExpect(jsonPath("$.promptCharge").value(0.30)) + .andExpect(jsonPath("$.completionCharge").value(1.20)) + .andExpect(jsonPath("$.includedTokensUsed").value(0)) + .andExpect(jsonPath("$.overageTokens").value(0)); + + verify(billingService).calculateBill(any(UsageRequest.class)); + } + + @Test + @DisplayName("Should return 400 Bad Request when modelId is missing") + void should_return_400_bad_request_when_submit_usage_given_missing_model_id() throws Exception { + mockMvc.perform(post("/api/usage") + .contentType(MediaType.APPLICATION_JSON) + .content(""" + { + "customerId": "CUST-001", + "promptTokens": 1000, + "completionTokens": 500 + } + """)) + .andExpect(status().isBadRequest()) + .andExpect(jsonPath("$.message").value("Model ID is required")); + + verify(billingService, never()).calculateBill(any()); + } + + @Test + @DisplayName("Should return 400 Bad Request when model pricing is not configured") + void should_return_400_bad_request_when_submit_usage_given_unknown_model_id() throws Exception { + when(billingService.calculateBill(any(UsageRequest.class))) + .thenThrow(new ModelPricingNotFoundException("PLAN-STARTER", "unknown-model")); + + mockMvc.perform(post("/api/usage") + .contentType(MediaType.APPLICATION_JSON) + .content(""" + { + "customerId": "CUST-001", + "modelId": "unknown-model", + "promptTokens": 1000, + "completionTokens": 500 + } + """)) + .andExpect(status().isBadRequest()) + .andExpect(jsonPath("$.message").value("Pricing not configured for model: unknown-model")); + } } diff --git a/src/test/java/org/tw/token_billing/domain/BillTest.java b/src/test/java/org/tw/token_billing/domain/BillTest.java index 52be6bc..76dfca5 100644 --- a/src/test/java/org/tw/token_billing/domain/BillTest.java +++ b/src/test/java/org/tw/token_billing/domain/BillTest.java @@ -16,10 +16,11 @@ class BillTest { void should_create_bill_with_correct_totals_when_create_given_valid_inputs() { LocalDateTime beforeCreation = LocalDateTime.now(ZoneOffset.UTC).minusSeconds(1); - Bill bill = Bill.create("CUST-001", 1000, 500, 10000, new BigDecimal("0.02")); + Bill bill = Bill.createStandard("CUST-001", "fast-model", 1000, 500, 10000, new BigDecimal("0.02")); assertThat(bill.getId()).isNotNull(); assertThat(bill.getCustomerId()).isEqualTo("CUST-001"); + assertThat(bill.getModelId()).isEqualTo("fast-model"); assertThat(bill.getPromptTokens()).isEqualTo(1000); assertThat(bill.getCompletionTokens()).isEqualTo(500); assertThat(bill.getTotalTokens()).isEqualTo(1500); @@ -30,17 +31,19 @@ void should_create_bill_with_correct_totals_when_create_given_valid_inputs() { @Test @DisplayName("Should create bill with all included when usage is within quota") void should_create_bill_with_all_included_when_create_given_usage_within_quota() { - Bill bill = Bill.create("CUST-001", 1000, 500, 10000, new BigDecimal("0.02")); + Bill bill = Bill.createStandard("CUST-001", "fast-model", 1000, 500, 10000, new BigDecimal("0.02")); assertThat(bill.getIncludedTokensUsed()).isEqualTo(1500); assertThat(bill.getOverageTokens()).isEqualTo(0); assertThat(bill.getTotalCharge()).isEqualByComparingTo(BigDecimal.ZERO); + assertThat(bill.getPromptCharge()).isNull(); + assertThat(bill.getCompletionCharge()).isNull(); } @Test @DisplayName("Should create bill with partial included when usage exceeds quota") void should_create_bill_with_partial_included_when_create_given_usage_exceeds_quota() { - Bill bill = Bill.create("CUST-001", 8000, 5000, 10000, new BigDecimal("0.02")); + Bill bill = Bill.createStandard("CUST-001", "fast-model", 8000, 5000, 10000, new BigDecimal("0.02")); assertThat(bill.getTotalTokens()).isEqualTo(13000); assertThat(bill.getIncludedTokensUsed()).isEqualTo(10000); @@ -51,7 +54,7 @@ void should_create_bill_with_partial_included_when_create_given_usage_exceeds_qu @Test @DisplayName("Should create bill with all overage when remaining quota is zero") void should_create_bill_with_all_overage_when_create_given_zero_remaining_quota() { - Bill bill = Bill.create("CUST-001", 1000, 500, 0, new BigDecimal("0.02")); + Bill bill = Bill.createStandard("CUST-001", "fast-model", 1000, 500, 0, new BigDecimal("0.02")); assertThat(bill.getIncludedTokensUsed()).isEqualTo(0); assertThat(bill.getOverageTokens()).isEqualTo(1500); @@ -61,7 +64,7 @@ void should_create_bill_with_all_overage_when_create_given_zero_remaining_quota( @Test @DisplayName("Should create bill with all overage when remaining quota is negative") void should_create_bill_with_all_overage_when_create_given_negative_remaining_quota() { - Bill bill = Bill.create("CUST-001", 1000, 500, -500, new BigDecimal("0.02")); + Bill bill = Bill.createStandard("CUST-001", "fast-model", 1000, 500, -500, new BigDecimal("0.02")); assertThat(bill.getIncludedTokensUsed()).isEqualTo(0); assertThat(bill.getOverageTokens()).isEqualTo(1500); @@ -71,7 +74,7 @@ void should_create_bill_with_all_overage_when_create_given_negative_remaining_qu @Test @DisplayName("Should create bill with correct charge precision for fractional calculation") void should_create_bill_with_correct_charge_precision_when_create_given_fractional_calculation() { - Bill bill = Bill.create("CUST-001", 1234, 567, 0, new BigDecimal("0.0234")); + Bill bill = Bill.createStandard("CUST-001", "fast-model", 1234, 567, 0, new BigDecimal("0.0234")); assertThat(bill.getTotalTokens()).isEqualTo(1801); assertThat(bill.getOverageTokens()).isEqualTo(1801); @@ -82,4 +85,23 @@ void should_create_bill_with_correct_charge_precision_when_create_given_fraction .divide(new BigDecimal("1000"), 2, java.math.RoundingMode.HALF_UP); assertThat(bill.getTotalCharge()).isEqualByComparingTo(expectedCharge); } + + @Test + @DisplayName("Should create premium bill with split charges") + void should_create_premium_bill_with_split_charges_when_create_premium_given_valid_inputs() { + Bill bill = Bill.createPremium("CUST-001", "reasoning-model", 10000, 20000, + new BigDecimal("0.03"), new BigDecimal("0.06")); + + assertThat(bill.getId()).isNotNull(); + assertThat(bill.getCustomerId()).isEqualTo("CUST-001"); + assertThat(bill.getModelId()).isEqualTo("reasoning-model"); + assertThat(bill.getPromptTokens()).isEqualTo(10000); + assertThat(bill.getCompletionTokens()).isEqualTo(20000); + assertThat(bill.getTotalTokens()).isEqualTo(30000); + assertThat(bill.getIncludedTokensUsed()).isEqualTo(0); + assertThat(bill.getOverageTokens()).isEqualTo(0); + assertThat(bill.getPromptCharge()).isEqualByComparingTo(new BigDecimal("0.30")); + assertThat(bill.getCompletionCharge()).isEqualByComparingTo(new BigDecimal("1.20")); + assertThat(bill.getTotalCharge()).isEqualByComparingTo(new BigDecimal("1.50")); + } } diff --git a/src/test/java/org/tw/token_billing/dto/BillResponseTest.java b/src/test/java/org/tw/token_billing/dto/BillResponseTest.java index 8cac91b..2611a47 100644 --- a/src/test/java/org/tw/token_billing/dto/BillResponseTest.java +++ b/src/test/java/org/tw/token_billing/dto/BillResponseTest.java @@ -21,11 +21,14 @@ void should_map_all_fields_when_from_bill_given_complete_bill() { Bill bill = Bill.builder() .id(billId) .customerId("CUST-001") + .modelId("fast-model") .promptTokens(1000) .completionTokens(500) .totalTokens(1500) .includedTokensUsed(1000) .overageTokens(500) + .promptCharge(null) + .completionCharge(null) .totalCharge(new BigDecimal("0.01")) .calculatedAt(calculatedAt) .build(); @@ -34,10 +37,46 @@ void should_map_all_fields_when_from_bill_given_complete_bill() { assertThat(response.getBillId()).isEqualTo(billId); assertThat(response.getCustomerId()).isEqualTo("CUST-001"); + assertThat(response.getModelId()).isEqualTo("fast-model"); assertThat(response.getTotalTokens()).isEqualTo(1500); assertThat(response.getIncludedTokensUsed()).isEqualTo(1000); assertThat(response.getOverageTokens()).isEqualTo(500); + assertThat(response.getPromptCharge()).isNull(); + assertThat(response.getCompletionCharge()).isNull(); assertThat(response.getTotalCharge()).isEqualByComparingTo(new BigDecimal("0.01")); assertThat(response.getCalculatedAt()).isEqualTo(calculatedAt); } + + @Test + @DisplayName("Should map premium bill with charge breakdown") + void should_map_premium_bill_when_from_bill_given_premium_bill_with_charges() { + UUID billId = UUID.randomUUID(); + LocalDateTime calculatedAt = LocalDateTime.now(); + + Bill bill = Bill.builder() + .id(billId) + .customerId("CUST-002") + .modelId("reasoning-model") + .promptTokens(10000) + .completionTokens(20000) + .totalTokens(30000) + .includedTokensUsed(0) + .overageTokens(0) + .promptCharge(new BigDecimal("0.30")) + .completionCharge(new BigDecimal("1.20")) + .totalCharge(new BigDecimal("1.50")) + .calculatedAt(calculatedAt) + .build(); + + BillResponse response = BillResponse.fromBill(bill); + + assertThat(response.getBillId()).isEqualTo(billId); + assertThat(response.getCustomerId()).isEqualTo("CUST-002"); + assertThat(response.getModelId()).isEqualTo("reasoning-model"); + assertThat(response.getIncludedTokensUsed()).isEqualTo(0); + assertThat(response.getOverageTokens()).isEqualTo(0); + assertThat(response.getPromptCharge()).isEqualByComparingTo(new BigDecimal("0.30")); + assertThat(response.getCompletionCharge()).isEqualByComparingTo(new BigDecimal("1.20")); + assertThat(response.getTotalCharge()).isEqualByComparingTo(new BigDecimal("1.50")); + } } diff --git a/src/test/java/org/tw/token_billing/infrastructure/persistence/JpaBillRepositoryAdapterTest.java b/src/test/java/org/tw/token_billing/infrastructure/persistence/JpaBillRepositoryAdapterTest.java index 5fdb381..931cd14 100644 --- a/src/test/java/org/tw/token_billing/infrastructure/persistence/JpaBillRepositoryAdapterTest.java +++ b/src/test/java/org/tw/token_billing/infrastructure/persistence/JpaBillRepositoryAdapterTest.java @@ -40,11 +40,14 @@ void should_save_and_return_domain_bill_when_save_given_valid_bill() { Bill domainBill = Bill.builder() .id(billId) .customerId("CUST-001") + .modelId("fast-model") .promptTokens(1000) .completionTokens(500) .totalTokens(1500) .includedTokensUsed(1500) .overageTokens(0) + .promptCharge(null) + .completionCharge(null) .totalCharge(BigDecimal.ZERO) .calculatedAt(now) .build(); @@ -52,11 +55,14 @@ void should_save_and_return_domain_bill_when_save_given_valid_bill() { BillPO billPO = BillPO.builder() .id(billId) .customerId("CUST-001") + .modelId("fast-model") .promptTokens(1000) .completionTokens(500) .totalTokens(1500) .includedTokensUsed(1500) .overageTokens(0) + .promptCharge(null) + .completionCharge(null) .totalCharge(BigDecimal.ZERO) .calculatedAt(now) .build(); diff --git a/src/test/java/org/tw/token_billing/infrastructure/persistence/JpaModelPricingRepositoryAdapterTest.java b/src/test/java/org/tw/token_billing/infrastructure/persistence/JpaModelPricingRepositoryAdapterTest.java new file mode 100644 index 0000000..f43c1b0 --- /dev/null +++ b/src/test/java/org/tw/token_billing/infrastructure/persistence/JpaModelPricingRepositoryAdapterTest.java @@ -0,0 +1,87 @@ +package org.tw.token_billing.infrastructure.persistence; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.math.BigDecimal; +import java.time.LocalDateTime; +import java.util.Optional; +import java.util.UUID; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.tw.token_billing.domain.ModelPricing; +import org.tw.token_billing.infrastructure.persistence.entity.ModelPricingPO; +import org.tw.token_billing.infrastructure.persistence.mapper.ModelPricingMapper; + +@ExtendWith(MockitoExtension.class) +class JpaModelPricingRepositoryAdapterTest { + + @Mock + private SpringDataModelPricingRepository springDataRepository; + + @Mock + private ModelPricingMapper modelPricingMapper; + + @InjectMocks + private JpaModelPricingRepositoryAdapter adapter; + + @Test + @DisplayName("Should return ModelPricing when finding by existing plan and model combination") + void should_return_model_pricing_when_find_by_plan_id_and_model_id_given_existing_combination() { + String planId = "PLAN-STARTER"; + String modelId = "fast-model"; + UUID id = UUID.randomUUID(); + LocalDateTime now = LocalDateTime.now(); + + ModelPricingPO po = ModelPricingPO.builder() + .id(id) + .planId(planId) + .modelId(modelId) + .overageRatePer1k(new BigDecimal("0.02")) + .createdAt(now) + .build(); + + ModelPricing expectedDomain = ModelPricing.builder() + .id(id) + .planId(planId) + .modelId(modelId) + .overageRatePer1k(new BigDecimal("0.02")) + .createdAt(now) + .build(); + + when(springDataRepository.findByPlanIdAndModelId(planId, modelId)) + .thenReturn(Optional.of(po)); + when(modelPricingMapper.toDomain(po)).thenReturn(expectedDomain); + + Optional result = adapter.findByPlanIdAndModelId(planId, modelId); + + assertThat(result).isPresent(); + assertThat(result.get().getPlanId()).isEqualTo(planId); + assertThat(result.get().getModelId()).isEqualTo(modelId); + verify(springDataRepository).findByPlanIdAndModelId(planId, modelId); + verify(modelPricingMapper).toDomain(po); + } + + @Test + @DisplayName("Should return empty Optional when finding by non-existent plan and model combination") + void should_return_empty_optional_when_find_by_plan_id_and_model_id_given_non_existent_combination() { + String planId = "PLAN-STARTER"; + String modelId = "unknown-model"; + + when(springDataRepository.findByPlanIdAndModelId(planId, modelId)) + .thenReturn(Optional.empty()); + + Optional result = adapter.findByPlanIdAndModelId(planId, modelId); + + assertThat(result).isEmpty(); + verify(springDataRepository).findByPlanIdAndModelId(planId, modelId); + verify(modelPricingMapper, never()).toDomain(org.mockito.ArgumentMatchers.any()); + } +} diff --git a/src/test/java/org/tw/token_billing/infrastructure/persistence/mapper/ModelPricingMapperTest.java b/src/test/java/org/tw/token_billing/infrastructure/persistence/mapper/ModelPricingMapperTest.java new file mode 100644 index 0000000..25a438e --- /dev/null +++ b/src/test/java/org/tw/token_billing/infrastructure/persistence/mapper/ModelPricingMapperTest.java @@ -0,0 +1,50 @@ +package org.tw.token_billing.infrastructure.persistence.mapper; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.math.BigDecimal; +import java.time.LocalDateTime; +import java.util.UUID; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.tw.token_billing.domain.ModelPricing; +import org.tw.token_billing.infrastructure.persistence.entity.ModelPricingPO; + +class ModelPricingMapperTest { + + private ModelPricingMapper mapper; + + @BeforeEach + void setUp() { + mapper = new ModelPricingMapper(); + } + + @Test + @DisplayName("Should map all fields when converting PO to domain") + void should_map_all_fields_when_to_domain_given_valid_po() { + UUID id = UUID.randomUUID(); + LocalDateTime createdAt = LocalDateTime.now(); + + ModelPricingPO po = ModelPricingPO.builder() + .id(id) + .planId("PLAN-PREMIUM") + .modelId("reasoning-model") + .overageRatePer1k(new BigDecimal("0.02")) + .promptRatePer1k(new BigDecimal("0.03")) + .completionRatePer1k(new BigDecimal("0.06")) + .createdAt(createdAt) + .build(); + + ModelPricing result = mapper.toDomain(po); + + assertThat(result.getId()).isEqualTo(id); + assertThat(result.getPlanId()).isEqualTo("PLAN-PREMIUM"); + assertThat(result.getModelId()).isEqualTo("reasoning-model"); + assertThat(result.getOverageRatePer1k()).isEqualByComparingTo(new BigDecimal("0.02")); + assertThat(result.getPromptRatePer1k()).isEqualByComparingTo(new BigDecimal("0.03")); + assertThat(result.getCompletionRatePer1k()).isEqualByComparingTo(new BigDecimal("0.06")); + assertThat(result.getCreatedAt()).isEqualTo(createdAt); + } +} diff --git a/src/test/java/org/tw/token_billing/infrastructure/persistence/mapper/PricingPlanMapperTest.java b/src/test/java/org/tw/token_billing/infrastructure/persistence/mapper/PricingPlanMapperTest.java new file mode 100644 index 0000000..9c21842 --- /dev/null +++ b/src/test/java/org/tw/token_billing/infrastructure/persistence/mapper/PricingPlanMapperTest.java @@ -0,0 +1,57 @@ +package org.tw.token_billing.infrastructure.persistence.mapper; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.math.BigDecimal; +import java.time.LocalDateTime; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.tw.token_billing.domain.PlanType; +import org.tw.token_billing.domain.PricingPlan; +import org.tw.token_billing.infrastructure.persistence.entity.PricingPlanPO; + +class PricingPlanMapperTest { + + private PricingPlanMapper mapper; + + @BeforeEach + void setUp() { + mapper = new PricingPlanMapper(); + } + + @Test + @DisplayName("Should convert STANDARD string to PlanType.STANDARD enum when mapping to domain") + void should_convert_plan_type_string_to_enum_when_to_domain_given_standard_plan_type() { + PricingPlanPO po = PricingPlanPO.builder() + .id("PLAN-STARTER") + .name("Starter Plan") + .planType("STANDARD") + .monthlyQuota(100000) + .overageRatePer1k(new BigDecimal("0.02")) + .createdAt(LocalDateTime.now()) + .build(); + + PricingPlan result = mapper.toDomain(po); + + assertThat(result.getPlanType()).isEqualTo(PlanType.STANDARD); + } + + @Test + @DisplayName("Should convert PREMIUM string to PlanType.PREMIUM enum when mapping to domain") + void should_convert_plan_type_string_to_enum_when_to_domain_given_premium_plan_type() { + PricingPlanPO po = PricingPlanPO.builder() + .id("PLAN-PREMIUM") + .name("Premium Plan") + .planType("PREMIUM") + .monthlyQuota(0) + .overageRatePer1k(BigDecimal.ZERO) + .createdAt(LocalDateTime.now()) + .build(); + + PricingPlan result = mapper.toDomain(po); + + assertThat(result.getPlanType()).isEqualTo(PlanType.PREMIUM); + } +} diff --git a/src/test/java/org/tw/token_billing/service/impl/BillingServiceImplTest.java b/src/test/java/org/tw/token_billing/service/impl/BillingServiceImplTest.java index 4aca6cb..e08a8af 100644 --- a/src/test/java/org/tw/token_billing/service/impl/BillingServiceImplTest.java +++ b/src/test/java/org/tw/token_billing/service/impl/BillingServiceImplTest.java @@ -25,13 +25,20 @@ import org.tw.token_billing.domain.Bill; import org.tw.token_billing.domain.Customer; import org.tw.token_billing.domain.CustomerSubscription; +import org.tw.token_billing.domain.ModelPricing; +import org.tw.token_billing.domain.PlanType; import org.tw.token_billing.domain.PricingPlan; import org.tw.token_billing.dto.UsageRequest; import org.tw.token_billing.exception.CustomerNotFoundException; +import org.tw.token_billing.exception.ModelPricingNotFoundException; import org.tw.token_billing.exception.NoActiveSubscriptionException; import org.tw.token_billing.repository.BillRepository; import org.tw.token_billing.repository.CustomerRepository; import org.tw.token_billing.repository.CustomerSubscriptionRepository; +import org.tw.token_billing.repository.ModelPricingRepository; +import org.tw.token_billing.service.strategy.BillingStrategyFactory; +import org.tw.token_billing.service.strategy.PremiumBillingStrategy; +import org.tw.token_billing.service.strategy.StandardBillingStrategy; @ExtendWith(MockitoExtension.class) class BillingServiceImplTest { @@ -45,12 +52,20 @@ class BillingServiceImplTest { @Mock private BillRepository billRepository; + @Mock + private ModelPricingRepository modelPricingRepository; + + @Mock + private BillingStrategyFactory billingStrategyFactory; + @InjectMocks private BillingServiceImpl billingService; private Customer customer; private PricingPlan pricingPlan; private CustomerSubscription subscription; + private ModelPricing modelPricing; + private static final String MODEL_ID = "fast-model"; @BeforeEach void setUp() { @@ -63,6 +78,7 @@ void setUp() { pricingPlan = PricingPlan.builder() .id("PLAN-STARTER") .name("Starter Plan") + .planType(PlanType.STANDARD) .monthlyQuota(100000) .overageRatePer1k(new BigDecimal("0.02")) .createdAt(LocalDateTime.now()) @@ -76,6 +92,14 @@ void setUp() { .effectiveTo(null) .createdAt(LocalDateTime.now()) .build(); + + modelPricing = ModelPricing.builder() + .id(UUID.randomUUID()) + .planId("PLAN-STARTER") + .modelId(MODEL_ID) + .overageRatePer1k(new BigDecimal("0.02")) + .createdAt(LocalDateTime.now()) + .build(); } @Test @@ -83,6 +107,7 @@ void setUp() { void should_return_bill_with_zero_charge_when_calculate_bill_given_usage_within_quota() { UsageRequest request = UsageRequest.builder() .customerId("CUST-001") + .modelId(MODEL_ID) .promptTokens(1000) .completionTokens(500) .build(); @@ -90,6 +115,9 @@ void should_return_bill_with_zero_charge_when_calculate_bill_given_usage_within_ when(customerRepository.findById("CUST-001")).thenReturn(Optional.of(customer)); when(customerSubscriptionRepository.findActiveSubscription(eq("CUST-001"), any(LocalDate.class))) .thenReturn(Optional.of(subscription)); + when(modelPricingRepository.findByPlanIdAndModelId("PLAN-STARTER", MODEL_ID)) + .thenReturn(Optional.of(modelPricing)); + when(billingStrategyFactory.getStrategy(PlanType.STANDARD)).thenReturn(new StandardBillingStrategy()); when(billRepository.sumIncludedTokensUsedForMonth(eq("CUST-001"), any(), any())).thenReturn(0); when(billRepository.save(any(Bill.class))).thenAnswer(invocation -> invocation.getArgument(0)); @@ -112,6 +140,7 @@ void should_return_bill_with_overage_charge_when_calculate_bill_given_usage_exce PricingPlan smallQuotaPlan = PricingPlan.builder() .id("PLAN-FREE") .name("Free Plan") + .planType(PlanType.STANDARD) .monthlyQuota(10000) .overageRatePer1k(new BigDecimal("0.02")) .createdAt(LocalDateTime.now()) @@ -126,8 +155,17 @@ void should_return_bill_with_overage_charge_when_calculate_bill_given_usage_exce .createdAt(LocalDateTime.now()) .build(); + ModelPricing smallPlanModelPricing = ModelPricing.builder() + .id(UUID.randomUUID()) + .planId("PLAN-FREE") + .modelId(MODEL_ID) + .overageRatePer1k(new BigDecimal("0.02")) + .createdAt(LocalDateTime.now()) + .build(); + UsageRequest request = UsageRequest.builder() .customerId("CUST-001") + .modelId(MODEL_ID) .promptTokens(8000) .completionTokens(5000) .build(); @@ -135,6 +173,9 @@ void should_return_bill_with_overage_charge_when_calculate_bill_given_usage_exce when(customerRepository.findById("CUST-001")).thenReturn(Optional.of(customer)); when(customerSubscriptionRepository.findActiveSubscription(eq("CUST-001"), any(LocalDate.class))) .thenReturn(Optional.of(smallQuotaSubscription)); + when(modelPricingRepository.findByPlanIdAndModelId("PLAN-FREE", MODEL_ID)) + .thenReturn(Optional.of(smallPlanModelPricing)); + when(billingStrategyFactory.getStrategy(PlanType.STANDARD)).thenReturn(new StandardBillingStrategy()); when(billRepository.sumIncludedTokensUsedForMonth(eq("CUST-001"), any(), any())).thenReturn(0); when(billRepository.save(any(Bill.class))).thenAnswer(invocation -> invocation.getArgument(0)); @@ -151,6 +192,7 @@ void should_return_bill_with_overage_charge_when_calculate_bill_given_usage_exce void should_return_bill_with_full_overage_when_calculate_bill_given_zero_remaining_quota() { UsageRequest request = UsageRequest.builder() .customerId("CUST-001") + .modelId(MODEL_ID) .promptTokens(1000) .completionTokens(500) .build(); @@ -158,6 +200,9 @@ void should_return_bill_with_full_overage_when_calculate_bill_given_zero_remaini when(customerRepository.findById("CUST-001")).thenReturn(Optional.of(customer)); when(customerSubscriptionRepository.findActiveSubscription(eq("CUST-001"), any(LocalDate.class))) .thenReturn(Optional.of(subscription)); + when(modelPricingRepository.findByPlanIdAndModelId("PLAN-STARTER", MODEL_ID)) + .thenReturn(Optional.of(modelPricing)); + when(billingStrategyFactory.getStrategy(PlanType.STANDARD)).thenReturn(new StandardBillingStrategy()); when(billRepository.sumIncludedTokensUsedForMonth(eq("CUST-001"), any(), any())).thenReturn(100000); when(billRepository.save(any(Bill.class))).thenAnswer(invocation -> invocation.getArgument(0)); @@ -173,6 +218,7 @@ void should_return_bill_with_full_overage_when_calculate_bill_given_zero_remaini void should_return_bill_with_zero_tokens_when_calculate_bill_given_zero_usage() { UsageRequest request = UsageRequest.builder() .customerId("CUST-001") + .modelId(MODEL_ID) .promptTokens(0) .completionTokens(0) .build(); @@ -180,6 +226,9 @@ void should_return_bill_with_zero_tokens_when_calculate_bill_given_zero_usage() when(customerRepository.findById("CUST-001")).thenReturn(Optional.of(customer)); when(customerSubscriptionRepository.findActiveSubscription(eq("CUST-001"), any(LocalDate.class))) .thenReturn(Optional.of(subscription)); + when(modelPricingRepository.findByPlanIdAndModelId("PLAN-STARTER", MODEL_ID)) + .thenReturn(Optional.of(modelPricing)); + when(billingStrategyFactory.getStrategy(PlanType.STANDARD)).thenReturn(new StandardBillingStrategy()); when(billRepository.sumIncludedTokensUsedForMonth(eq("CUST-001"), any(), any())).thenReturn(0); when(billRepository.save(any(Bill.class))).thenAnswer(invocation -> invocation.getArgument(0)); @@ -196,6 +245,7 @@ void should_return_bill_with_zero_tokens_when_calculate_bill_given_zero_usage() void should_throw_customer_not_found_exception_when_calculate_bill_given_invalid_customer_id() { UsageRequest request = UsageRequest.builder() .customerId("INVALID-CUSTOMER") + .modelId(MODEL_ID) .promptTokens(1000) .completionTokens(500) .build(); @@ -214,6 +264,7 @@ void should_throw_customer_not_found_exception_when_calculate_bill_given_invalid void should_throw_no_active_subscription_exception_when_calculate_bill_given_customer_without_subscription() { UsageRequest request = UsageRequest.builder() .customerId("CUST-001") + .modelId(MODEL_ID) .promptTokens(1000) .completionTokens(500) .build(); @@ -234,6 +285,7 @@ void should_throw_no_active_subscription_exception_when_calculate_bill_given_cus void should_use_correct_month_boundaries_when_calculate_bill_given_mid_month_request() { UsageRequest request = UsageRequest.builder() .customerId("CUST-001") + .modelId(MODEL_ID) .promptTokens(1000) .completionTokens(500) .build(); @@ -241,6 +293,9 @@ void should_use_correct_month_boundaries_when_calculate_bill_given_mid_month_req when(customerRepository.findById("CUST-001")).thenReturn(Optional.of(customer)); when(customerSubscriptionRepository.findActiveSubscription(eq("CUST-001"), any(LocalDate.class))) .thenReturn(Optional.of(subscription)); + when(modelPricingRepository.findByPlanIdAndModelId("PLAN-STARTER", MODEL_ID)) + .thenReturn(Optional.of(modelPricing)); + when(billingStrategyFactory.getStrategy(PlanType.STANDARD)).thenReturn(new StandardBillingStrategy()); when(billRepository.sumIncludedTokensUsedForMonth(eq("CUST-001"), any(), any())).thenReturn(0); when(billRepository.save(any(Bill.class))).thenAnswer(invocation -> invocation.getArgument(0)); @@ -269,4 +324,147 @@ void should_use_correct_month_boundaries_when_calculate_bill_given_mid_month_req assertThat(monthEnd.getSecond()).isEqualTo(0); assertThat(monthEnd.getMonth()).isEqualTo(monthStart.getMonth().plus(1)); } + + @Test + @DisplayName("Should return bill with split charges when calculating bill for premium plan") + void should_return_bill_with_split_charges_when_calculate_bill_given_premium_plan_usage() { + PricingPlan premiumPlan = PricingPlan.builder() + .id("PLAN-PREMIUM") + .name("Premium Plan") + .planType(PlanType.PREMIUM) + .monthlyQuota(0) + .overageRatePer1k(BigDecimal.ZERO) + .createdAt(LocalDateTime.now()) + .build(); + + CustomerSubscription premiumSubscription = CustomerSubscription.builder() + .id(UUID.randomUUID()) + .customerId("CUST-PREMIUM") + .plan(premiumPlan) + .effectiveFrom(LocalDate.of(2026, 1, 1)) + .effectiveTo(null) + .createdAt(LocalDateTime.now()) + .build(); + + ModelPricing premiumModelPricing = ModelPricing.builder() + .id(UUID.randomUUID()) + .planId("PLAN-PREMIUM") + .modelId("reasoning-model") + .promptRatePer1k(new BigDecimal("0.03")) + .completionRatePer1k(new BigDecimal("0.06")) + .createdAt(LocalDateTime.now()) + .build(); + + Customer premiumCustomer = Customer.builder() + .id("CUST-PREMIUM") + .name("Premium Customer") + .createdAt(LocalDateTime.now()) + .build(); + + UsageRequest request = UsageRequest.builder() + .customerId("CUST-PREMIUM") + .modelId("reasoning-model") + .promptTokens(10000) + .completionTokens(20000) + .build(); + + when(customerRepository.findById("CUST-PREMIUM")).thenReturn(Optional.of(premiumCustomer)); + when(customerSubscriptionRepository.findActiveSubscription(eq("CUST-PREMIUM"), any(LocalDate.class))) + .thenReturn(Optional.of(premiumSubscription)); + when(modelPricingRepository.findByPlanIdAndModelId("PLAN-PREMIUM", "reasoning-model")) + .thenReturn(Optional.of(premiumModelPricing)); + when(billingStrategyFactory.getStrategy(PlanType.PREMIUM)).thenReturn(new PremiumBillingStrategy()); + when(billRepository.save(any(Bill.class))).thenAnswer(invocation -> invocation.getArgument(0)); + + Bill result = billingService.calculateBill(request); + + assertThat(result.getIncludedTokensUsed()).isEqualTo(0); + assertThat(result.getOverageTokens()).isEqualTo(0); + assertThat(result.getPromptCharge()).isEqualByComparingTo(new BigDecimal("0.30")); + assertThat(result.getCompletionCharge()).isEqualByComparingTo(new BigDecimal("1.20")); + assertThat(result.getTotalCharge()).isEqualByComparingTo(new BigDecimal("1.50")); + + verify(billingStrategyFactory).getStrategy(PlanType.PREMIUM); + verify(billRepository, never()).sumIncludedTokensUsedForMonth(any(), any(), any()); + } + + @Test + @DisplayName("Should throw ModelPricingNotFoundException when model pricing is not configured") + void should_throw_model_pricing_not_found_exception_when_calculate_bill_given_unknown_model() { + UsageRequest request = UsageRequest.builder() + .customerId("CUST-001") + .modelId("unknown-model") + .promptTokens(1000) + .completionTokens(500) + .build(); + + when(customerRepository.findById("CUST-001")).thenReturn(Optional.of(customer)); + when(customerSubscriptionRepository.findActiveSubscription(eq("CUST-001"), any(LocalDate.class))) + .thenReturn(Optional.of(subscription)); + when(modelPricingRepository.findByPlanIdAndModelId("PLAN-STARTER", "unknown-model")) + .thenReturn(Optional.empty()); + + assertThatThrownBy(() -> billingService.calculateBill(request)) + .isInstanceOf(ModelPricingNotFoundException.class) + .hasMessageContaining("unknown-model"); + + verify(billRepository, never()).save(any()); + } + + @Test + @DisplayName("Should not query quota usage when calculating bill for premium plan with zero quota") + void should_return_zero_remaining_quota_when_calculate_bill_given_premium_plan() { + PricingPlan premiumPlan = PricingPlan.builder() + .id("PLAN-PREMIUM") + .name("Premium Plan") + .planType(PlanType.PREMIUM) + .monthlyQuota(0) + .overageRatePer1k(BigDecimal.ZERO) + .createdAt(LocalDateTime.now()) + .build(); + + CustomerSubscription premiumSubscription = CustomerSubscription.builder() + .id(UUID.randomUUID()) + .customerId("CUST-PREMIUM") + .plan(premiumPlan) + .effectiveFrom(LocalDate.of(2026, 1, 1)) + .effectiveTo(null) + .createdAt(LocalDateTime.now()) + .build(); + + ModelPricing premiumModelPricing = ModelPricing.builder() + .id(UUID.randomUUID()) + .planId("PLAN-PREMIUM") + .modelId("fast-model") + .promptRatePer1k(new BigDecimal("0.01")) + .completionRatePer1k(new BigDecimal("0.02")) + .createdAt(LocalDateTime.now()) + .build(); + + Customer premiumCustomer = Customer.builder() + .id("CUST-PREMIUM") + .name("Premium Customer") + .createdAt(LocalDateTime.now()) + .build(); + + UsageRequest request = UsageRequest.builder() + .customerId("CUST-PREMIUM") + .modelId("fast-model") + .promptTokens(5000) + .completionTokens(3000) + .build(); + + when(customerRepository.findById("CUST-PREMIUM")).thenReturn(Optional.of(premiumCustomer)); + when(customerSubscriptionRepository.findActiveSubscription(eq("CUST-PREMIUM"), any(LocalDate.class))) + .thenReturn(Optional.of(premiumSubscription)); + when(modelPricingRepository.findByPlanIdAndModelId("PLAN-PREMIUM", "fast-model")) + .thenReturn(Optional.of(premiumModelPricing)); + when(billingStrategyFactory.getStrategy(PlanType.PREMIUM)).thenReturn(new PremiumBillingStrategy()); + when(billRepository.save(any(Bill.class))).thenAnswer(invocation -> invocation.getArgument(0)); + + billingService.calculateBill(request); + + verify(billRepository, never()).sumIncludedTokensUsedForMonth(any(), any(), any()); + verify(billingStrategyFactory).getStrategy(PlanType.PREMIUM); + } } diff --git a/src/test/java/org/tw/token_billing/service/strategy/BillingStrategyFactoryTest.java b/src/test/java/org/tw/token_billing/service/strategy/BillingStrategyFactoryTest.java new file mode 100644 index 0000000..8cf88d5 --- /dev/null +++ b/src/test/java/org/tw/token_billing/service/strategy/BillingStrategyFactoryTest.java @@ -0,0 +1,64 @@ +package org.tw.token_billing.service.strategy; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.List; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.tw.token_billing.domain.PlanType; + +class BillingStrategyFactoryTest { + + private BillingStrategyFactory factory; + private StandardBillingStrategy standardStrategy; + private PremiumBillingStrategy premiumStrategy; + + @BeforeEach + void setUp() { + standardStrategy = new StandardBillingStrategy(); + premiumStrategy = new PremiumBillingStrategy(); + factory = new BillingStrategyFactory(List.of(standardStrategy, premiumStrategy)); + } + + @Test + @DisplayName("Should return StandardBillingStrategy when getting strategy for STANDARD plan type") + void should_return_standard_strategy_when_get_strategy_given_standard_plan_type() { + BillingStrategy result = factory.getStrategy(PlanType.STANDARD); + + assertThat(result).isInstanceOf(StandardBillingStrategy.class); + assertThat(result.supportedPlanType()).isEqualTo(PlanType.STANDARD); + } + + @Test + @DisplayName("Should return PremiumBillingStrategy when getting strategy for PREMIUM plan type") + void should_return_premium_strategy_when_get_strategy_given_premium_plan_type() { + BillingStrategy result = factory.getStrategy(PlanType.PREMIUM); + + assertThat(result).isInstanceOf(PremiumBillingStrategy.class); + assertThat(result.supportedPlanType()).isEqualTo(PlanType.PREMIUM); + } + + @Test + @DisplayName("Should throw IllegalArgumentException when getting strategy for null plan type") + void should_throw_illegal_argument_exception_when_get_strategy_given_null_plan_type() { + assertThatThrownBy(() -> factory.getStrategy(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + @DisplayName("Should build strategy map correctly when constructed with list of strategies") + void should_build_strategy_map_correctly_when_construct_given_list_of_strategies() { + BillingStrategyFactory newFactory = new BillingStrategyFactory( + List.of(new StandardBillingStrategy(), new PremiumBillingStrategy()) + ); + + BillingStrategy standardResult = newFactory.getStrategy(PlanType.STANDARD); + BillingStrategy premiumResult = newFactory.getStrategy(PlanType.PREMIUM); + + assertThat(standardResult).isInstanceOf(StandardBillingStrategy.class); + assertThat(premiumResult).isInstanceOf(PremiumBillingStrategy.class); + } +} diff --git a/src/test/java/org/tw/token_billing/service/strategy/PremiumBillingStrategyTest.java b/src/test/java/org/tw/token_billing/service/strategy/PremiumBillingStrategyTest.java new file mode 100644 index 0000000..a9774d6 --- /dev/null +++ b/src/test/java/org/tw/token_billing/service/strategy/PremiumBillingStrategyTest.java @@ -0,0 +1,172 @@ +package org.tw.token_billing.service.strategy; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.math.BigDecimal; +import java.time.LocalDateTime; +import java.util.UUID; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.tw.token_billing.domain.Bill; +import org.tw.token_billing.domain.BillingContext; +import org.tw.token_billing.domain.ModelPricing; +import org.tw.token_billing.domain.PlanType; + +class PremiumBillingStrategyTest { + + private PremiumBillingStrategy strategy; + + @BeforeEach + void setUp() { + strategy = new PremiumBillingStrategy(); + } + + @Test + @DisplayName("Should return PREMIUM when getting supported plan type") + void should_return_premium_when_supported_plan_type_given_strategy_instance() { + PlanType result = strategy.supportedPlanType(); + + assertThat(result).isEqualTo(PlanType.PREMIUM); + } + + @Test + @DisplayName("Should return bill with split charges when calculating with valid context") + void should_return_bill_with_split_charges_when_calculate_given_valid_context() { + ModelPricing modelPricing = ModelPricing.builder() + .id(UUID.randomUUID()) + .planId("PLAN-PREMIUM") + .modelId("reasoning-model") + .promptRatePer1k(new BigDecimal("0.03")) + .completionRatePer1k(new BigDecimal("0.06")) + .createdAt(LocalDateTime.now()) + .build(); + + BillingContext context = BillingContext.builder() + .customerId("CUST-PREMIUM") + .modelId("reasoning-model") + .promptTokens(10000) + .completionTokens(20000) + .remainingQuota(0) + .modelPricing(modelPricing) + .build(); + + Bill bill = strategy.calculate(context); + + assertThat(bill.getPromptCharge()).isEqualByComparingTo(new BigDecimal("0.30")); + assertThat(bill.getCompletionCharge()).isEqualByComparingTo(new BigDecimal("1.20")); + assertThat(bill.getTotalCharge()).isEqualByComparingTo(new BigDecimal("1.50")); + } + + @Test + @DisplayName("Should return bill with zero included tokens for premium plan") + void should_return_bill_with_zero_included_tokens_when_calculate_given_premium_plan() { + ModelPricing modelPricing = ModelPricing.builder() + .id(UUID.randomUUID()) + .planId("PLAN-PREMIUM") + .modelId("fast-model") + .promptRatePer1k(new BigDecimal("0.01")) + .completionRatePer1k(new BigDecimal("0.02")) + .createdAt(LocalDateTime.now()) + .build(); + + BillingContext context = BillingContext.builder() + .customerId("CUST-PREMIUM") + .modelId("fast-model") + .promptTokens(5000) + .completionTokens(3000) + .remainingQuota(0) + .modelPricing(modelPricing) + .build(); + + Bill bill = strategy.calculate(context); + + assertThat(bill.getIncludedTokensUsed()).isEqualTo(0); + assertThat(bill.getOverageTokens()).isEqualTo(0); + } + + @Test + @DisplayName("Should return bill with zero charges when tokens are zero") + void should_return_bill_with_zero_charges_when_calculate_given_zero_tokens() { + ModelPricing modelPricing = ModelPricing.builder() + .id(UUID.randomUUID()) + .planId("PLAN-PREMIUM") + .modelId("fast-model") + .promptRatePer1k(new BigDecimal("0.01")) + .completionRatePer1k(new BigDecimal("0.02")) + .createdAt(LocalDateTime.now()) + .build(); + + BillingContext context = BillingContext.builder() + .customerId("CUST-PREMIUM") + .modelId("fast-model") + .promptTokens(0) + .completionTokens(0) + .remainingQuota(0) + .modelPricing(modelPricing) + .build(); + + Bill bill = strategy.calculate(context); + + assertThat(bill.getPromptCharge()).isEqualByComparingTo(BigDecimal.ZERO); + assertThat(bill.getCompletionCharge()).isEqualByComparingTo(BigDecimal.ZERO); + assertThat(bill.getTotalCharge()).isEqualByComparingTo(BigDecimal.ZERO); + } + + @Test + @DisplayName("Should return bill with only prompt charge when completion tokens are zero") + void should_return_bill_with_only_prompt_charge_when_calculate_given_zero_completion_tokens() { + ModelPricing modelPricing = ModelPricing.builder() + .id(UUID.randomUUID()) + .planId("PLAN-PREMIUM") + .modelId("reasoning-model") + .promptRatePer1k(new BigDecimal("0.03")) + .completionRatePer1k(new BigDecimal("0.06")) + .createdAt(LocalDateTime.now()) + .build(); + + BillingContext context = BillingContext.builder() + .customerId("CUST-PREMIUM") + .modelId("reasoning-model") + .promptTokens(10000) + .completionTokens(0) + .remainingQuota(0) + .modelPricing(modelPricing) + .build(); + + Bill bill = strategy.calculate(context); + + assertThat(bill.getPromptCharge()).isEqualByComparingTo(new BigDecimal("0.30")); + assertThat(bill.getCompletionCharge()).isEqualByComparingTo(BigDecimal.ZERO); + assertThat(bill.getTotalCharge()).isEqualByComparingTo(new BigDecimal("0.30")); + } + + @Test + @DisplayName("Should return bill with only completion charge when prompt tokens are zero") + void should_return_bill_with_only_completion_charge_when_calculate_given_zero_prompt_tokens() { + ModelPricing modelPricing = ModelPricing.builder() + .id(UUID.randomUUID()) + .planId("PLAN-PREMIUM") + .modelId("reasoning-model") + .promptRatePer1k(new BigDecimal("0.03")) + .completionRatePer1k(new BigDecimal("0.06")) + .createdAt(LocalDateTime.now()) + .build(); + + BillingContext context = BillingContext.builder() + .customerId("CUST-PREMIUM") + .modelId("reasoning-model") + .promptTokens(0) + .completionTokens(20000) + .remainingQuota(0) + .modelPricing(modelPricing) + .build(); + + Bill bill = strategy.calculate(context); + + assertThat(bill.getPromptCharge()).isEqualByComparingTo(BigDecimal.ZERO); + assertThat(bill.getCompletionCharge()).isEqualByComparingTo(new BigDecimal("1.20")); + assertThat(bill.getTotalCharge()).isEqualByComparingTo(new BigDecimal("1.20")); + } +} diff --git a/src/test/java/org/tw/token_billing/service/strategy/StandardBillingStrategyTest.java b/src/test/java/org/tw/token_billing/service/strategy/StandardBillingStrategyTest.java new file mode 100644 index 0000000..b5604b3 --- /dev/null +++ b/src/test/java/org/tw/token_billing/service/strategy/StandardBillingStrategyTest.java @@ -0,0 +1,141 @@ +package org.tw.token_billing.service.strategy; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.math.BigDecimal; +import java.time.LocalDateTime; +import java.util.UUID; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.tw.token_billing.domain.Bill; +import org.tw.token_billing.domain.BillingContext; +import org.tw.token_billing.domain.ModelPricing; +import org.tw.token_billing.domain.PlanType; + +class StandardBillingStrategyTest { + + private StandardBillingStrategy strategy; + + @BeforeEach + void setUp() { + strategy = new StandardBillingStrategy(); + } + + @Test + @DisplayName("Should return STANDARD when getting supported plan type") + void should_return_standard_when_supported_plan_type_given_strategy_instance() { + PlanType result = strategy.supportedPlanType(); + + assertThat(result).isEqualTo(PlanType.STANDARD); + } + + @Test + @DisplayName("Should return bill with all tokens included when usage is within quota") + void should_return_bill_with_all_tokens_included_when_calculate_given_usage_within_quota() { + ModelPricing modelPricing = ModelPricing.builder() + .id(UUID.randomUUID()) + .planId("PLAN-STARTER") + .modelId("fast-model") + .overageRatePer1k(new BigDecimal("0.02")) + .createdAt(LocalDateTime.now()) + .build(); + + BillingContext context = BillingContext.builder() + .customerId("CUST-001") + .modelId("fast-model") + .promptTokens(1000) + .completionTokens(500) + .remainingQuota(10000) + .modelPricing(modelPricing) + .build(); + + Bill bill = strategy.calculate(context); + + assertThat(bill.getIncludedTokensUsed()).isEqualTo(1500); + assertThat(bill.getOverageTokens()).isEqualTo(0); + assertThat(bill.getTotalCharge()).isEqualByComparingTo(BigDecimal.ZERO); + assertThat(bill.getModelId()).isEqualTo("fast-model"); + } + + @Test + @DisplayName("Should return bill with overage when usage exceeds quota") + void should_return_bill_with_overage_when_calculate_given_usage_exceeds_quota() { + ModelPricing modelPricing = ModelPricing.builder() + .id(UUID.randomUUID()) + .planId("PLAN-FREE") + .modelId("fast-model") + .overageRatePer1k(new BigDecimal("0.02")) + .createdAt(LocalDateTime.now()) + .build(); + + BillingContext context = BillingContext.builder() + .customerId("CUST-001") + .modelId("fast-model") + .promptTokens(8000) + .completionTokens(5000) + .remainingQuota(10000) + .modelPricing(modelPricing) + .build(); + + Bill bill = strategy.calculate(context); + + assertThat(bill.getIncludedTokensUsed()).isEqualTo(10000); + assertThat(bill.getOverageTokens()).isEqualTo(3000); + assertThat(bill.getTotalCharge()).isEqualByComparingTo(new BigDecimal("0.06")); + } + + @Test + @DisplayName("Should return bill with all overage when remaining quota is zero") + void should_return_bill_with_all_overage_when_calculate_given_zero_remaining_quota() { + ModelPricing modelPricing = ModelPricing.builder() + .id(UUID.randomUUID()) + .planId("PLAN-STARTER") + .modelId("fast-model") + .overageRatePer1k(new BigDecimal("0.02")) + .createdAt(LocalDateTime.now()) + .build(); + + BillingContext context = BillingContext.builder() + .customerId("CUST-001") + .modelId("fast-model") + .promptTokens(1000) + .completionTokens(500) + .remainingQuota(0) + .modelPricing(modelPricing) + .build(); + + Bill bill = strategy.calculate(context); + + assertThat(bill.getIncludedTokensUsed()).isEqualTo(0); + assertThat(bill.getOverageTokens()).isEqualTo(1500); + assertThat(bill.getTotalCharge()).isEqualByComparingTo(new BigDecimal("0.03")); + } + + @Test + @DisplayName("Should return bill with null prompt and completion charges for standard plan") + void should_return_bill_with_null_prompt_completion_charges_when_calculate_given_standard_plan() { + ModelPricing modelPricing = ModelPricing.builder() + .id(UUID.randomUUID()) + .planId("PLAN-STARTER") + .modelId("fast-model") + .overageRatePer1k(new BigDecimal("0.02")) + .createdAt(LocalDateTime.now()) + .build(); + + BillingContext context = BillingContext.builder() + .customerId("CUST-001") + .modelId("fast-model") + .promptTokens(1000) + .completionTokens(500) + .remainingQuota(10000) + .modelPricing(modelPricing) + .build(); + + Bill bill = strategy.calculate(context); + + assertThat(bill.getPromptCharge()).isNull(); + assertThat(bill.getCompletionCharge()).isNull(); + } +} diff --git a/src/test/resources/application.yml b/src/test/resources/application.yml new file mode 100644 index 0000000..31402d8 --- /dev/null +++ b/src/test/resources/application.yml @@ -0,0 +1,19 @@ +spring: + application: + name: token-billing-test + + datasource: + url: jdbc:h2:mem:testdb;DB_CLOSE_DELAY=-1;MODE=PostgreSQL + username: sa + password: + driver-class-name: org.h2.Driver + + jpa: + hibernate: + ddl-auto: none + show-sql: true + database-platform: org.hibernate.dialect.H2Dialect + + flyway: + enabled: true + locations: classpath:db/migration