Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit f9a281e

Browse files
committed
utils : add t_max_predict_ms param to set prediction phase time limit
1 parent e8d99dd commit f9a281e

File tree

3 files changed

+23
-0
lines changed

3 files changed

+23
-0
lines changed

common/arg.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1774,6 +1774,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
17741774
params.sampling.samplers = common_sampler_types_from_names(sampler_names, true);
17751775
}
17761776
).set_sparam());
1777+
add_opt(common_arg(
1778+
{"--t-max-predict-ms"}, "MILLISECONDS",
1779+
string_format("time limit in ms for prediction phase; triggers if generation exceeds this time and a new-line was generated (default: %ld)", params.t_max_predict_ms),
1780+
[](common_params & params, const std::string & value) {
1781+
params.t_max_predict_ms = std::stoll(value);
1782+
}
1783+
).set_examples({LLAMA_EXAMPLE_SERVER}));
17771784
add_opt(common_arg(
17781785
{"-s", "--seed"}, "SEED",
17791786
string_format("RNG seed (default: %d, use random seed for %d)", params.sampling.seed, LLAMA_DEFAULT_SEED),

common/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,7 @@ struct common_params {
347347
int32_t control_vector_layer_end = -1; // layer range for control vector
348348
bool offline = false;
349349

350+
int64_t t_max_predict_ms = 0; // max time in ms to predict after first new line (0 = unlimited)
350351
int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
351352
int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
352353
// (which is more convenient to use for plotting)

tools/main/main.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,10 @@ int main(int argc, char ** argv) {
562562
embd_inp.push_back(decoder_start_token_id);
563563
}
564564

565+
// Add for --t-max-predict-ms
566+
bool seen_new_line = false;
567+
int64_t t_start_generation = 0;
568+
565569
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
566570
// predict
567571
if (!embd.empty()) {
@@ -739,6 +743,17 @@ int main(int argc, char ** argv) {
739743
// Console/Stream Output
740744
LOG("%s", token_str.c_str());
741745

746+
if (token_str.find('\n') != std::string::npos) {
747+
if (!seen_new_line) {
748+
seen_new_line = true;
749+
t_start_generation = ggml_time_us();
750+
} else if (params.t_max_predict_ms > 0 && (ggml_time_us() - t_start_generation > 1000.0f * params.t_max_predict_ms)) {
751+
LOG_DBG("stopped by time limit, t_max_predict_ms = %d ms\n", (int) params.t_max_predict_ms);
752+
n_remain = 0;
753+
break;
754+
}
755+
}
756+
742757
// Record Displayed Tokens To Log
743758
// Note: Generated tokens are created one by one hence this check
744759
if (embd.size() > 1) {

0 commit comments

Comments (0)