Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 2b686a9

Browse files
authored
server: refactor child --> router communication (ggml-org#24821)
* server: refactor child --> router communication * fix wakeup case * add docs * improve update_status() * nits
1 parent 4b48a53 commit 2b686a9

8 files changed

Lines changed: 173 additions & 91 deletions

File tree

common/arg.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,6 @@ static handle_model_result common_params_handle_model(struct common_params_model
303303

304304
if (!model.docker_repo.empty()) {
305305
model.path = common_docker_resolve_model(model.docker_repo);
306-
model.name = model.docker_repo;
307306
} else if (!model.hf_repo.empty()) {
308307
// If -m was used with -hf, treat the model "path" as the hf_file to download
309308
if (model.hf_file.empty() && !model.path.empty()) {
@@ -323,7 +322,6 @@ static handle_model_result common_params_handle_model(struct common_params_model
323322
throw std::runtime_error("failed to download model from Hugging Face");
324323
}
325324

326-
model.name = model.hf_repo;
327325
model.path = download_result.model_path;
328326

329327
if (!download_result.mmproj_path.empty()) {

common/common.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,16 @@ struct common_params_model {
295295
std::string hf_repo = ""; // HF repo // NOLINT
296296
std::string hf_file = ""; // HF file // NOLINT
297297
std::string docker_repo = ""; // Docker repo // NOLINT
298-
std::string name = ""; // in format <user>/<model>[:<tag>] (tag is optional) // NOLINT
298+
299+
std::string get_name() {
300+
if (!hf_repo.empty()) {
301+
return hf_repo;
302+
}
303+
if (!docker_repo.empty()) {
304+
return docker_repo;
305+
}
306+
return path;
307+
}
299308
};
300309

301310
// draft-model-based speculative decoding parameters

tools/server/README-dev.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,17 @@ That requires `JSON.stringify` when formatted to message content:
180180
}
181181
```
182182

183+
### Router mode: how child <--> router communicates
184+
185+
Upon spawning a new child process using `subprocess`, both child and router listen to the stdout/stderr (combined)
186+
187+
For the direction from child to router:
188+
- Generic messages are logs, it will be forwarded to router's stdout
189+
- Special state update messages are prefixed by `cmd_child_to_router:state:`, followed by a JSON. See `server_models::handle_child_state` for more
190+
191+
For the direction from router to child:
192+
- When server sends `cmd_router_to_child:exit`, the child should exit gracefully --> if after `DEFAULT_STOP_TIMEOUT` and the child is still running, force-kill it
193+
183194
### Model management API (router mode)
184195

185196
Model management API was added via PR [#23976](https://github.com/ggml-org/llama.cpp/pull/23976)

tools/server/server-context.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,6 @@ enum slot_state {
6363
SLOT_STATE_GENERATING,
6464
};
6565

66-
enum server_state {
67-
SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
68-
SERVER_STATE_READY, // Server is ready and model is loaded
69-
};
70-
7166
struct server_slot {
7267
int id;
7368

@@ -773,6 +768,8 @@ struct server_context_impl {
773768
// note: chat_params must not be refreshed upon existing sleeping state
774769
server_chat_params chat_params;
775770

771+
server_state_callback_t callback_state = [](server_state, json) -> void {};
772+
776773
server_context_impl() {
777774
mtmd_helper_log_set(common_log_default_callback, nullptr);
778775
}
@@ -1244,8 +1241,8 @@ struct server_context_impl {
12441241
if (!params_base.model_alias.empty()) {
12451242
// backward compat: use first alias as model name
12461243
model_name = *params_base.model_alias.begin();
1247-
} else if (!params_base.model.name.empty()) {
1248-
model_name = params_base.model.name;
1244+
} else if (!params_base.model.get_name().empty()) {
1245+
model_name = params_base.model.get_name();
12491246
} else {
12501247
// fallback: derive model name from file name
12511248
auto model_path = std::filesystem::path(params_base.model.path);
@@ -3734,8 +3731,11 @@ struct server_res_generator : server_http_res {
37343731
}
37353732
};
37363733

3737-
void server_context::on_sleeping_changed(std::function<void(bool)> callback) {
3738-
impl->queue_tasks.on_sleeping_state(std::move(callback));
3734+
void server_context::set_state_callback(server_state_callback_t callback) {
3735+
impl->callback_state = std::move(callback);
3736+
impl->queue_tasks.on_sleeping_state([this](bool sleeping) {
3737+
impl->callback_state(sleeping ? SERVER_STATE_SLEEPING : SERVER_STATE_READY, {});
3738+
});
37393739
}
37403740

37413741
// compute the number of tokens before the last user message in the prompt

tools/server/server-context.h

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,31 @@ struct server_context_meta {
5252
uint64_t model_size;
5353
};
5454

55+
enum server_state {
56+
// SERVER_STATE_DOWNLOADING,
57+
SERVER_STATE_LOADING,
58+
SERVER_STATE_READY,
59+
SERVER_STATE_SLEEPING,
60+
};
61+
62+
static std::string server_state_to_str(server_state state) {
63+
switch (state) {
64+
case SERVER_STATE_LOADING: return "loading";
65+
case SERVER_STATE_READY: return "ready";
66+
case SERVER_STATE_SLEEPING: return "sleeping";
67+
default: GGML_ASSERT(false && "invalid server_state");
68+
}
69+
}
70+
71+
static server_state server_state_from_str(const std::string & str) {
72+
if (str == "loading") return SERVER_STATE_LOADING;
73+
if (str == "ready") return SERVER_STATE_READY;
74+
if (str == "sleeping") return SERVER_STATE_SLEEPING;
75+
GGML_ASSERT(false && "invalid server_state string");
76+
}
77+
78+
using server_state_callback_t = std::function<void(server_state, json /* payload */)>;
79+
5580
struct server_context {
5681
std::unique_ptr<server_context_impl> impl;
5782

@@ -79,9 +104,8 @@ struct server_context {
79104
// not thread-safe, should only be used from the main thread
80105
server_context_meta get_meta() const;
81106

82-
// register a callback to be called when sleeping state changes
83-
// must be set before load_model() is called
84-
void on_sleeping_changed(std::function<void(bool)> callback);
107+
// note: must be set before load_model() is called
108+
void set_state_callback(server_state_callback_t callback);
85109
};
86110

87111

tools/server/server-models.cpp

Lines changed: 76 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "server-common.h"
22
#include "server-models.h"
3+
#include "server-context.h"
34

45
#include "build-info.h"
56
#include "preset.h"
@@ -44,9 +45,7 @@ extern char **environ;
4445
#define DEFAULT_STOP_TIMEOUT 10 // seconds
4546

4647
#define CMD_ROUTER_TO_CHILD_EXIT "cmd_router_to_child:exit"
47-
#define CMD_CHILD_TO_ROUTER_READY "cmd_child_to_router:ready" // also sent when waking up from sleep
48-
#define CMD_CHILD_TO_ROUTER_SLEEP "cmd_child_to_router:sleep"
49-
#define CMD_CHILD_TO_ROUTER_INFO "cmd_child_to_router:info:" // followed by json string
48+
#define CMD_CHILD_TO_ROUTER_STATE "cmd_child_to_router:state:" // followed by json string
5049

5150
// address for child process, this is needed because router may run on 0.0.0.0
5251
// ref: https://github.com/ggml-org/llama.cpp/issues/17862
@@ -904,12 +903,8 @@ void server_models::load(const std::string & name) {
904903
while (fgets(buffer, vec_buf.size(), stdout_file) != nullptr) {
905904
LOG("[%5d] %s", port, buffer);
906905
std::string str(buffer);
907-
if (string_starts_with(buffer, CMD_CHILD_TO_ROUTER_READY)) {
908-
this->update_status(name, SERVER_MODEL_STATUS_LOADED, 0);
909-
} else if (string_starts_with(buffer, CMD_CHILD_TO_ROUTER_INFO)) {
910-
this->update_loaded_info(name, str);
911-
} else if (string_starts_with(buffer, CMD_CHILD_TO_ROUTER_SLEEP)) {
912-
this->update_status(name, SERVER_MODEL_STATUS_SLEEPING, 0);
906+
if (string_starts_with(buffer, CMD_CHILD_TO_ROUTER_STATE)) {
907+
this->handle_child_state(name, str);
913908
}
914909
}
915910
} else {
@@ -976,7 +971,10 @@ void server_models::load(const std::string & name) {
976971
subprocess_destroy(&child_proc->get());
977972

978973
// update status and exit code
979-
this->update_status(name, SERVER_MODEL_STATUS_UNLOADED, exit_code);
974+
this->update_status(name, {
975+
SERVER_MODEL_STATUS_UNLOADED,
976+
exit_code
977+
});
980978
SRV_INF("instance name=%s exited with status %d\n", name.c_str(), exit_code);
981979
});
982980

@@ -1016,7 +1014,8 @@ struct server_models_download_res : public common_download_callback {
10161014
common_download_model(model, opts);
10171015
is_ok = true;
10181016
} catch (const std::exception & e) {
1019-
SRV_ERR("download failed for model name=%s: %s\n", model.name.c_str(), e.what());
1017+
auto model_name = model.get_name();
1018+
SRV_ERR("download failed for model name=%s: %s\n", model_name.c_str(), e.what());
10201019
is_ok = false;
10211020
}
10221021
return is_ok;
@@ -1036,7 +1035,7 @@ struct server_models_download_res : public common_download_callback {
10361035
};
10371036

10381037
void server_models::download(common_params_model && model, common_download_opts && opts) {
1039-
std::string name = model.name;
1038+
std::string name = model.get_name();
10401039
GGML_ASSERT(name == model.hf_repo);
10411040

10421041
std::unique_lock<std::mutex> lk(mutex);
@@ -1064,9 +1063,10 @@ void server_models::download(common_params_model && model, common_download_opts
10641063
inst.th = std::thread([this, dl = std::move(dl)]() {
10651064
dl->opts.callback = dl.get();
10661065
bool ok = dl->run();
1066+
auto model_name = dl->model.get_name();
10671067
SRV_INF("download finished for model name=%s with status=%s\n",
1068-
dl->model.name.c_str(), ok ? "success" : "failure");
1069-
update_download_progress(dl->model.name, {}, true, ok);
1068+
model_name.c_str(), ok ? "success" : "failure");
1069+
update_download_progress(model_name, {}, true, ok);
10701070
// need_reload is set inside update_download_progress under the mutex;
10711071
// the next load_models() call will clean up this instance
10721072
});
@@ -1130,51 +1130,34 @@ void server_models::unload_all() {
11301130
}
11311131
}
11321132

1133-
void server_models::update_status(const std::string & name, server_model_status status, int exit_code) {
1133+
void server_models::update_status(const std::string & name, const update_status_args & args) {
11341134
std::unique_lock<std::mutex> lk(mutex);
11351135
auto it = mapping.find(name);
11361136
if (it != mapping.end()) {
11371137
auto & meta = it->second.meta;
1138-
meta.status = status;
1139-
meta.exit_code = exit_code;
1138+
meta.status = args.status;
1139+
meta.exit_code = args.exit_code;
1140+
if (!args.loaded_info.is_null()) {
1141+
meta.loaded_info = args.loaded_info;
1142+
}
11401143
}
11411144
// broadcast status change to SSE
11421145
{
11431146
json data = {
1144-
{"status", server_model_status_to_string(status)},
1147+
{"status", server_model_status_to_string(args.status)},
11451148
};
1146-
if (status == SERVER_MODEL_STATUS_UNLOADED) {
1147-
data["exit_code"] = exit_code;
1149+
if (args.status == SERVER_MODEL_STATUS_UNLOADED) {
1150+
data["exit_code"] = args.exit_code;
1151+
}
1152+
if (!args.loaded_info.is_null()) {
1153+
data["info"] = args.loaded_info;
11481154
}
11491155
// note: notify_sse doesn't acquire the lock, so no deadlock here
11501156
notify_sse("status_change", name, data);
11511157
}
11521158
cv.notify_all();
11531159
}
11541160

1155-
void server_models::update_loaded_info(const std::string & name, std::string & raw_info) {
1156-
if (!string_starts_with(raw_info, CMD_CHILD_TO_ROUTER_INFO)) {
1157-
SRV_WRN("invalid loaded info format from child for model name=%s: %s\n", name.c_str(), raw_info.c_str());
1158-
return;
1159-
}
1160-
1161-
json info;
1162-
try {
1163-
info = json::parse(raw_info.substr(strlen(CMD_CHILD_TO_ROUTER_INFO)));
1164-
} catch (const std::exception & e) {
1165-
SRV_WRN("failed to parse loaded info from child for model name=%s: %s\n", name.c_str(), e.what());
1166-
return;
1167-
}
1168-
1169-
std::unique_lock<std::mutex> lk(mutex);
1170-
auto it = mapping.find(name);
1171-
if (it != mapping.end()) {
1172-
auto & meta = it->second.meta;
1173-
meta.loaded_info = info;
1174-
}
1175-
cv.notify_all();
1176-
}
1177-
11781161
void server_models::update_download_progress(const std::string & name, const common_download_progress & progress, bool done, bool ok) {
11791162
json curr;
11801163
{
@@ -1323,21 +1306,54 @@ server_http_res_ptr server_models::proxy_request(const server_http_req & req, co
13231306
return proxy;
13241307
}
13251308

1326-
bool server_models::is_child_server() {
1309+
void server_models::handle_child_state(const std::string & name, const std::string & raw_input) {
1310+
server_state state;
1311+
json payload;
1312+
1313+
try {
1314+
json data = json::parse(raw_input.substr(strlen(CMD_CHILD_TO_ROUTER_STATE)));
1315+
state = server_state_from_str(json_value(data, "state", std::string()));
1316+
payload = json_value(data, "payload", json{});
1317+
} catch (const std::exception & e) {
1318+
SRV_ERR("failed to parse child state update for name=%s: %s\n", name.c_str(), e.what());
1319+
return;
1320+
}
1321+
1322+
switch (state) {
1323+
case SERVER_STATE_LOADING:
1324+
{
1325+
// do nothing for now
1326+
// TODO: report loading progress for first load and wakeup from sleep
1327+
} break;
1328+
case SERVER_STATE_READY:
1329+
{
1330+
update_status(name, {
1331+
SERVER_MODEL_STATUS_LOADED,
1332+
0,
1333+
// note: payload can be empty if this is a wakeup from sleep
1334+
payload.size() > 0 ? payload : nullptr
1335+
});
1336+
} break;
1337+
case SERVER_STATE_SLEEPING:
1338+
{
1339+
update_status(name, { SERVER_MODEL_STATUS_SLEEPING });
1340+
} break;
1341+
default:
1342+
// should never happen, but just in case
1343+
GGML_ASSERT(false && "unexpected state from child server");
1344+
}
1345+
}
1346+
1347+
//
1348+
// server_child
1349+
//
1350+
1351+
bool server_child::is_child() {
13271352
const char * router_port = std::getenv("LLAMA_SERVER_ROUTER_PORT");
13281353
return router_port != nullptr;
13291354
}
13301355

1331-
std::thread server_models::setup_child_server(const std::function<void(int)> & shutdown_handler, const json & model_info) {
1332-
// send a notification to the router server that a model instance is ready
1333-
common_log_pause(common_log_main());
1334-
fflush(stdout);
1335-
fprintf(stdout, "%s\n", CMD_CHILD_TO_ROUTER_READY);
1336-
fflush(stdout);
1337-
fprintf(stdout, "%s%s\n", CMD_CHILD_TO_ROUTER_INFO, safe_json_to_str(model_info).c_str());
1338-
fflush(stdout);
1339-
common_log_resume(common_log_main());
1340-
1356+
std::thread server_child::setup(const std::function<void(int)> & shutdown_handler) {
13411357
// setup thread for monitoring stdin
13421358
return std::thread([shutdown_handler]() {
13431359
// wait for EOF on stdin
@@ -1363,10 +1379,14 @@ std::thread server_models::setup_child_server(const std::function<void(int)> & s
13631379
});
13641380
}
13651381

1366-
void server_models::notify_router_sleeping_state(bool is_sleeping) {
1382+
void server_child::notify_to_router(const std::string & state, const json & payload) {
1383+
json data = {
1384+
{"state", state},
1385+
{"payload", payload},
1386+
};
13671387
common_log_pause(common_log_main());
13681388
fflush(stdout);
1369-
fprintf(stdout, "%s\n", is_sleeping ? CMD_CHILD_TO_ROUTER_SLEEP : CMD_CHILD_TO_ROUTER_READY);
1389+
fprintf(stdout, "%s%s\n", CMD_CHILD_TO_ROUTER_STATE, safe_json_to_str(data).c_str());
13701390
fflush(stdout);
13711391
common_log_resume(common_log_main());
13721392
}
@@ -1644,7 +1664,6 @@ void server_models_routes::init_routes() {
16441664
common_params_model model;
16451665
common_download_opts opts;
16461666

1647-
model.name = name;
16481667
model.hf_repo = name;
16491668
opts.bearer_token = params.hf_token;
16501669
opts.download_mmproj = true;

0 commit comments

Comments
 (0)