@@ -60,7 +60,7 @@ class Settings(BaseSettings):
 
 
 class CreateCompletionRequest(BaseModel):
-    prompt: str
+    prompt: Union[str, List[str]]
     suffix: Optional[str] = Field(None)
     max_tokens: int = 16
     temperature: float = 0.8
@@ -100,10 +100,10 @@ class Config:
     response_model=CreateCompletionResponse,
 )
 def create_completion(request: CreateCompletionRequest):
-    if request.stream:
-        chunks: Iterator[llama_cpp.CompletionChunk] = llama(**request.dict())  # type: ignore
-        return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
-    return llama(
+    if isinstance(request.prompt, list):
+        request.prompt = "".join(request.prompt)
+
+    completion_or_chunks = llama(
         **request.dict(
             exclude={
                 "model",
@@ -117,6 +117,11 @@ def create_completion(request: CreateCompletionRequest):
             }
         )
     )
+    if request.stream:
+        chunks: Iterator[llama_cpp.CompletionChunk] = completion_or_chunks  # type: ignore
+        return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
+    completion: llama_cpp.Completion = completion_or_chunks  # type: ignore
+    return completion
 
 
 class CreateEmbeddingRequest(BaseModel):
@@ -259,4 +264,6 @@ def get_models() -> ModelList:
     import os
     import uvicorn
 
-    uvicorn.run(app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000)))
+    uvicorn.run(
+        app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
+    )
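For reference, a minimal client sketch of the changed behaviour: the request body may now pass `prompt` as a list of strings, which the handler joins into one string before calling the model, and non-streaming requests now return the completion object directly. The `/v1/completions` path and the response fields are not shown in this diff and are assumptions here; the host and port follow the `HOST`/`PORT` defaults above.

```python
import requests

# Hypothetical call against a locally running server (HOST/PORT defaults above).
# The /v1/completions path and the response shape are assumptions, not shown in this diff.
response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "prompt": ["Hello", ", world!"],  # list prompts are joined server-side
        "max_tokens": 16,
    },
)
print(response.json())
```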