import os

import boto3

# Branding and the DynamoDB API-key table come from the environment so the
# same server image can be deployed per-customer without code changes.
title_message = os.getenv("TITLEMESSAGE", "🦙 llama.cpp Python API")
apitable = os.getenv("APITABLE")


def check_and_update_api_key(api_key, invocation_type, credit_cost=1):
    """Validate *api_key* against DynamoDB and atomically deduct credits.

    Looks the key up in the table named by the ``APITABLE`` environment
    variable, verifies it is marked ``Authorized`` and holds at least
    *credit_cost* ``Credits``, then — in a single conditional write —
    decrements the balance and increments the per-*invocation_type*
    counter inside the ``TotalInvocations`` map.

    Args:
        api_key: The bearer token presented by the client.
        invocation_type: Counter bucket to increment (e.g. ``"text"``).
        credit_cost: Credits to charge for this call (default 1).

    Returns:
        ``(True, "")`` on success, otherwise ``(False, reason)`` where
        *reason* is a user-presentable message.
    """
    if not apitable:
        # Fail closed on misconfiguration instead of letting boto3 raise
        # an opaque error deep inside the request handler.
        return False, "API key service is not configured. Contact support."

    dynamodb = boto3.resource("dynamodb")
    table = dynamodb.Table(apitable)

    # NOTE(review): do not log the raw key — it is a credential.
    response = table.get_item(Key={"ApiKey": api_key})
    item = response.get("Item")

    if not item or not item.get("Authorized"):
        return False, "API key not authorized."

    creditval = item.get("Credits", 0)
    if creditval < credit_cost:
        return False, (
            f"API Key does not have enough credits, have {creditval}, "
            f"need {credit_cost}"
        )

    # Deduct credits and bump the per-type counter in one conditional
    # write so concurrent requests cannot overspend the balance.
    # if_not_exists() seeds the counter at 0 for new invocation types.
    try:
        table.update_item(
            Key={"ApiKey": api_key},
            UpdateExpression=(
                "SET Credits = Credits - :cost, "
                "TotalInvocations.#type = "
                "if_not_exists(TotalInvocations.#type, :startval) + :newval"
            ),
            ExpressionAttributeNames={"#type": invocation_type},
            ExpressionAttributeValues={
                ":cost": credit_cost,
                ":newval": 1,
                ":startval": 0,
            },
            ConditionExpression="attribute_exists(ApiKey) AND Credits >= :cost",
            ReturnValues="UPDATED_NEW",
        )
        return True, ""
    except Exception as e:
        # Broad catch is deliberate: any DynamoDB failure (throttle,
        # condition failure, network) must map to a clean denial.
        print(f"Error updating item: {e}")
        return False, (
            "There was an error with that API key. Please check and try "
            "again, otherwise contact support."
        )
# Setup Bearer authentication scheme
bearer_scheme = HTTPBearer(auto_error=False)


async def authenticate(
    settings: Settings = Depends(get_server_settings),
    authorization: Optional[str] = Depends(bearer_scheme),
):
    """FastAPI dependency enforcing the server's bearer-token API key.

    When no ``api_key`` is configured the server is open and the
    dependency succeeds unconditionally; otherwise the presented bearer
    credential must match ``settings.api_key`` exactly.

    Returns:
        The validated credential (or ``True`` when auth is disabled).

    Raises:
        HTTPException: 401 when the credential is missing or wrong.
    """
    # Skip API key check if it's not set in settings
    if settings.api_key is None:
        return True

    # check bearer credentials against the api_key
    if authorization and authorization.credentials == settings.api_key:
        # NOTE(review): per-key DynamoDB credit metering is staged here
        # but intentionally disabled until billing is rolled out:
        # goodkey, message = check_and_update_api_key(
        #     api_key=authorization.credentials, invocation_type="text"
        # )
        # if not goodkey:
        #     raise HTTPException(
        #         status_code=status.HTTP_401_UNAUTHORIZED,
        #         detail=message,
        #     )
        return authorization.credentials

    # raise http error 401
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid API key. Check API key and credits.",
    )
Check API key and credits.", ) diff --git a/llama_cpp/server/settings.py b/llama_cpp/server/settings.py index daa913fac..f905c5717 100644 --- a/llama_cpp/server/settings.py +++ b/llama_cpp/server/settings.py @@ -3,7 +3,7 @@ import multiprocessing from typing import Optional, List, Literal, Union -from pydantic import Field +from pydantic import Field, root_validator from pydantic_settings import BaseSettings import llama_cpp @@ -67,12 +67,12 @@ class ModelSettings(BaseSettings): n_threads: int = Field( default=max(multiprocessing.cpu_count() // 2, 1), ge=1, - description="The number of threads to use.", + description="The number of threads to use. Use -1 for max cpu threads", ) n_threads_batch: int = Field( default=max(multiprocessing.cpu_count() // 2, 1), ge=0, - description="The number of threads to use when batch processing.", + description="The number of threads to use when batch processing. Use -1 for max cpu threads", ) rope_scaling_type: int = Field( default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED @@ -163,6 +163,15 @@ class ModelSettings(BaseSettings): verbose: bool = Field( default=True, description="Whether to print debug information." ) + @root_validator(pre=True) # pre=True to ensure this runs before any other validation + def set_dynamic_defaults(cls, values): + # If n_threads or n_threads_batch is -1, set it to multiprocessing.cpu_count() + cpu_count = multiprocessing.cpu_count() + if values.get('n_threads', 0) == -1: + values['n_threads'] = cpu_count + if values.get('n_threads_batch', 0) == -1: + values['n_threads_batch'] = cpu_count + return values class ServerSettings(BaseSettings):