@@ -14,6 +14,21 @@ cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
   }
   function_calling_utils::PreprocessRequest(json_body);
   auto tool_choice = json_body->get("tool_choice", Json::Value::null);
+  auto model_id = json_body->get("model", "").asString();
+  if (saved_models_.find(model_id) != saved_models_.end()) {
+    // check if model is started, if not start it first
+    Json::Value root;
+    root["model"] = model_id;
+    root["engine"] = engine_type;
+    auto ir = GetModelStatus(std::make_shared<Json::Value>(root));
+    auto status = std::get<0>(ir)["status_code"].asInt();
+    if (status != drogon::k200OK) {
+      CTL_INF("Model is not loaded, start loading it: " << model_id);
+      auto res = LoadModel(saved_models_.at(model_id));
+      // ignore return result
+    }
+  }
+
   auto engine_result = engine_service_->GetLoadedEngine(engine_type);
   if (engine_result.has_error()) {
     Json::Value res;
@@ -23,45 +38,42 @@ cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
     LOG_WARN << "Engine is not loaded yet";
     return cpp::fail(std::make_pair(stt, res));
   }
+
+  if (!model_id.empty()) {
+    if (auto model_service = model_service_.lock()) {
+      auto metadata_ptr = model_service->GetCachedModelMetadata(model_id);
+      if (metadata_ptr != nullptr &&
+          !metadata_ptr->tokenizer->chat_template.empty()) {
+        auto tokenizer = metadata_ptr->tokenizer;
+        auto messages = (*json_body)["messages"];
+        Json::Value messages_jsoncpp(Json::arrayValue);
+        for (auto message : messages) {
+          messages_jsoncpp.append(message);
+        }
 
-  {
-    auto model_id = json_body->get("model", "").asString();
-    if (!model_id.empty()) {
-      if (auto model_service = model_service_.lock()) {
-        auto metadata_ptr = model_service->GetCachedModelMetadata(model_id);
-        if (metadata_ptr != nullptr &&
-            !metadata_ptr->tokenizer->chat_template.empty()) {
-          auto tokenizer = metadata_ptr->tokenizer;
-          auto messages = (*json_body)["messages"];
-          Json::Value messages_jsoncpp(Json::arrayValue);
-          for (auto message : messages) {
-            messages_jsoncpp.append(message);
-          }
-
-          Json::Value tools(Json::arrayValue);
-          Json::Value template_data_json;
-          template_data_json["messages"] = messages_jsoncpp;
-          // template_data_json["tools"] = tools;
-
-          auto prompt_result = jinja::RenderTemplate(
-              tokenizer->chat_template, template_data_json,
-              tokenizer->bos_token, tokenizer->eos_token,
-              tokenizer->add_bos_token, tokenizer->add_eos_token,
-              tokenizer->add_generation_prompt);
-          if (prompt_result.has_value()) {
-            (*json_body)["prompt"] = prompt_result.value();
-            Json::Value stops(Json::arrayValue);
-            stops.append(tokenizer->eos_token);
-            (*json_body)["stop"] = stops;
-          } else {
-            CTL_ERR("Failed to render prompt: " + prompt_result.error());
-          }
+        Json::Value tools(Json::arrayValue);
+        Json::Value template_data_json;
+        template_data_json["messages"] = messages_jsoncpp;
+        // template_data_json["tools"] = tools;
+
+        auto prompt_result = jinja::RenderTemplate(
+            tokenizer->chat_template, template_data_json, tokenizer->bos_token,
+            tokenizer->eos_token, tokenizer->add_bos_token,
+            tokenizer->add_eos_token, tokenizer->add_generation_prompt);
+        if (prompt_result.has_value()) {
+          (*json_body)["prompt"] = prompt_result.value();
+          Json::Value stops(Json::arrayValue);
+          stops.append(tokenizer->eos_token);
+          (*json_body)["stop"] = stops;
+        } else {
+          CTL_ERR("Failed to render prompt: " + prompt_result.error());
         }
       }
     }
   }
 
-  CTL_INF("Json body inference: " + json_body->toStyledString());
+
+  CTL_DBG("Json body inference: " + json_body->toStyledString());
 
   auto cb = [q, tool_choice](Json::Value status, Json::Value res) {
     if (!tool_choice.isNull()) {
@@ -205,6 +217,10 @@ InferResult InferenceService::LoadModel(
     std::get<RemoteEngineI*>(engine_result.value())
         ->LoadModel(json_body, std::move(cb));
   }
+  if (!engine_service_->IsRemoteEngine(engine_type)) {
+    auto model_id = json_body->get("model", "").asString();
+    saved_models_[model_id] = json_body;
+  }
   return std::make_pair(stt, r);
 }
 
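Both hunks rely on a new `saved_models_` member whose declaration sits outside this diff. From the usage shown (`saved_models_[model_id] = json_body;` in `LoadModel` and `LoadModel(saved_models_.at(model_id))` in `HandleChatCompletion`, where `json_body` is apparently a `std::shared_ptr<Json::Value>`), it is presumably a map from model id to the original load-request body. A minimal sketch of that assumed declaration:

// Assumed declaration, not part of this diff: LoadModel() remembers the request
// body of every non-remote load keyed by model id, and HandleChatCompletion()
// replays it when the requested model is not currently loaded.
#include <json/value.h>
#include <memory>
#include <string>
#include <unordered_map>

class InferenceService {
  // ... existing members ...
  std::unordered_map<std::string, std::shared_ptr<Json::Value>> saved_models_;
};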