Commit 55279b6

Handle prompt list
1 parent 38f7dea commit 55279b6

File tree

1 file changed: +13 -6 lines changed

llama_cpp/server/__main__.py

@@ -60,7 +60,7 @@ class Settings(BaseSettings):
 
 
 class CreateCompletionRequest(BaseModel):
-    prompt: str
+    prompt: Union[str, List[str]]
     suffix: Optional[str] = Field(None)
     max_tokens: int = 16
     temperature: float = 0.8
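
Note (not part of the diff): the request model's prompt field now validates either a plain string or a list of strings. A stripped-down sketch of the same Pydantic pattern, using a hypothetical PromptRequest class purely for illustration:

from typing import List, Union

from pydantic import BaseModel


class PromptRequest(BaseModel):
    # Mirrors the changed field: a single string or a list of string segments.
    prompt: Union[str, List[str]]


PromptRequest(prompt="Hello")              # still valid
PromptRequest(prompt=["Hello", " world"])  # now also valid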
@@ -100,10 +100,10 @@ class Config:
     response_model=CreateCompletionResponse,
 )
 def create_completion(request: CreateCompletionRequest):
-    if request.stream:
-        chunks: Iterator[llama_cpp.CompletionChunk] = llama(**request.dict())  # type: ignore
-        return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
-    return llama(
+    if isinstance(request.prompt, list):
+        request.prompt = "".join(request.prompt)
+
+    completion_or_chunks = llama(
         **request.dict(
             exclude={
                 "model",
@@ -117,6 +117,11 @@ def create_completion(request: CreateCompletionRequest):
             }
         )
     )
+    if request.stream:
+        chunks: Iterator[llama_cpp.CompletionChunk] = completion_or_chunks  # type: ignore
+        return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
+    completion: llama_cpp.Completion = completion_or_chunks  # type: ignore
+    return completion
 
 
 class CreateEmbeddingRequest(BaseModel):
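
Note (not part of the diff): the handler now normalizes a list prompt by joining its parts into a single string before calling the model, and only afterwards branches on request.stream. A minimal client sketch, assuming the server exposes the OpenAI-style /v1/completions route and is reachable on its defaults (localhost:8000); requests is an extra dependency:

import requests

# The prompt may now be sent as a list of segments; per this commit the server
# concatenates them with "".join(...) before running the completion.
resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={"prompt": ["Q: What are the planets? ", "A: "], "max_tokens": 32},
)
print(resp.json())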
@@ -259,4 +264,6 @@ def get_models() -> ModelList:
     import os
     import uvicorn
 
-    uvicorn.run(app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000)))
+    uvicorn.run(
+        app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
+    )
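
Note (not part of the diff): when "stream": true is set, the handler wraps the chunk iterator in an EventSourceResponse, so the reply arrives as server-sent events (one "data: {...}" line per chunk) rather than a single JSON body. A rough consumer sketch under the same assumptions as above (default HOST/PORT shown in the last hunk, requests installed):

import json

import requests

with requests.post(
    "http://localhost:8000/v1/completions",
    json={"prompt": "Q: What are the planets? A: ", "stream": True, "max_tokens": 32},
    stream=True,
) as resp:
    for line in resp.iter_lines():
        # Each SSE event carries one JSON-encoded completion chunk.
        if line.startswith(b"data: "):
            print(json.loads(line[len(b"data: "):]))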
