From 2562f5ae5615626680f85066acbce8e941664658 Mon Sep 17 00:00:00 2001
From: Olivier Chafik
Date: Wed, 2 Oct 2024 17:44:22 +0100
Subject: [PATCH] `tool-call`: save/restore prompt cache

---
 examples/agent/run.py      | 56 ++++++++++++++++++++++++--------
 examples/server/server.cpp | 65 ++++++++++++++++++++++++++++----------
 2 files changed, 92 insertions(+), 29 deletions(-)

diff --git a/examples/agent/run.py b/examples/agent/run.py
index 40d18622b5398..2c848da26e36b 100644
--- a/examples/agent/run.py
+++ b/examples/agent/run.py
@@ -10,6 +10,7 @@
 # ///
 import json
 import asyncio
+import hashlib
 import logging
 import os
 import aiohttp
@@ -157,24 +158,51 @@ async def main(
     if openai:
         api_key = os.environ.get('OPENAI_API_KEY')
 
-    tool_map, tools = await discover_tools(tools or [], logger=logger)
-
-    sys.stdout.write(f'🛠️ Tools: {", ".join(tool_map.keys()) if tool_map else ""}\n')
+    completions_url = f'{endpoint}chat/completions'
 
-    messages = [
-        dict(
-            role='user',
-            content=goal,
-        )
-    ]
+    tool_map, tools = await discover_tools(tools or [], logger=logger)
 
     headers = {
         'Content-Type': 'application/json',
         'Authorization': f'Bearer {api_key}'
     }
     async with aiohttp.ClientSession(headers=headers) as session:
+
+        prompt_session_file = None
+        if not openai:
+            prompt_session_file = 'session.' + hashlib.sha256(json.dumps(dict(
+                model=model,
+                tools=tools,
+            )).encode()).hexdigest() + '.bin'
+
+            if os.path.exists(prompt_session_file):
+                logger.info('Found prompt cache %s', prompt_session_file)
+            else:
+                payload = dict(
+                    messages=[dict(role='user', content='')],
+                    tools=tools,
+                    max_tokens=1,
+                    save_filename=prompt_session_file,
+                )
+                logger.info('Computing prompt cache %s', prompt_session_file)
+                logger.debug('Calling %s: %s', completions_url, json.dumps(payload, indent=2))
+                async with aiohttp.ClientSession(headers=headers) as warmup_session:
+                    async with warmup_session.post(completions_url, json=payload) as response:
+                        logger.debug('Response: %s', response)
+                        response.raise_for_status()
+                        response = await response.json()
+
+        sys.stdout.write(f'🛠️ Tools: {", ".join(tool_map.keys()) if tool_map else ""}\n')
+
+        messages = [
+            dict(
+                role='user',
+                content=goal,
+            )
+        ]
+
         for i in range(max_iterations or sys.maxsize):
-            url = f'{endpoint}chat/completions'
             payload = dict(
                 messages=messages,
                 model=model,
@@ -185,12 +213,14 @@ async def main(
                     seed=seed,
                     cache_prompt=cache_prompt,
                 )) # type: ignore
+            if prompt_session_file and os.path.exists(prompt_session_file):
+                payload['restore_filename'] = prompt_session_file
 
-            logger.debug('Calling %s with %s', url, json.dumps(payload, indent=2))
-            async with session.post(url, json=payload) as response:
-                logger.debug('Response: %s', response)
+            logger.debug('Calling %s with %s', completions_url, json.dumps(payload, indent=2))
+            async with session.post(completions_url, json=payload) as response:
                 response.raise_for_status()
                 response = await response.json()
+                logger.debug('Response: %s', response)
 
             assert len(response['choices']) == 1
             choice = response['choices'][0]
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 61b900a085a16..045305a31f2ba 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -126,6 +126,7 @@ struct server_task_result {
 struct slot_params {
     bool stream       = true;
     bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
+    std::string save_filepath; // Where to save the slot data when done.
 
     int32_t n_keep    = 0; // number of tokens to keep from initial prompt
     int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
@@ -249,6 +250,34 @@ struct server_slot {
         return state != SLOT_STATE_IDLE;
     }
 
+    struct restore_results {
+        size_t token_count;
+        size_t nread;
+    };
+
+    restore_results restore(struct llama_context * ctx, const std::string & filepath) {
+        cache_tokens.resize(n_ctx);
+        size_t token_count = 0;
+        size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), id + 1, cache_tokens.data(), cache_tokens.size(), &token_count);
+        if (nread == 0) {
+            cache_tokens.resize(0);
+            throw std::runtime_error("Unable to restore slot, no available space in KV cache or invalid slot save file");
+        }
+        cache_tokens.resize(token_count);
+        return {token_count, nread};
+    }
+
+    struct save_results {
+        size_t token_count;
+        size_t nwrite;
+    };
+
+    save_results save(struct llama_context * ctx, const std::string & filepath) const {
+        const size_t token_count = cache_tokens.size();
+        const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), id + 1, cache_tokens.data(), token_count);
+        return {token_count, nwrite};
+    }
+
     void add_token(const completion_token_output & token) {
         if (!is_processing()) {
             SLT_WRN(*this, "%s", "slot is not processing\n");
@@ -893,6 +922,7 @@ struct server_context {
         slot.sparams.seed          = json_value(data, "seed",     default_sparams.seed);
         slot.sparams.n_probs       = json_value(data, "n_probs",  default_sparams.n_probs);
         slot.sparams.min_keep      = json_value(data, "min_keep", default_sparams.min_keep);
+        slot.params.save_filepath  = params.slot_save_path + json_value(data, "save_filename", std::string());
 
         // process "json_schema" and "grammar"
         if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
@@ -1581,6 +1611,12 @@
                         break;
                     }
 
+                    if (task.data.contains("restore_filename")) {
+                        std::string filename = task.data.at("restore_filename");
+                        std::string filepath = params.slot_save_path + filename;
+                        slot->restore(ctx, filepath);
+                    }
+
                     if (task.data.contains("system_prompt")) {
                         std::string sys_prompt = json_value(task.data, "system_prompt", std::string());
                         system_prompt_set(sys_prompt);
@@ -1698,13 +1734,12 @@
                         break;
                     }
 
-                    const size_t token_count = slot->cache_tokens.size();
                     const int64_t t_start = ggml_time_us();
 
                     std::string filename = task.data.at("filename");
                     std::string filepath = task.data.at("filepath");
 
-                    const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), token_count);
+                    auto save_results = slot->save(ctx, filepath);
 
                     const int64_t t_end = ggml_time_us();
                     const double t_save_ms = (t_end - t_start) / 1000.0;
@@ -1716,8 +1751,9 @@
                     result.data = json {
                         { "id_slot",   id_slot },
                         { "filename",  filename },
-                        { "n_saved",   token_count }, // tokens saved
-                        { "n_written", nwrite }, // bytes written
+                        { "filepath",  filepath },
+                        { "n_saved",   save_results.token_count }, // tokens saved
+                        { "n_written", save_results.nwrite }, // bytes written
                         { "timings", {
                             { "save_ms", t_save_ms }
                         } }
@@ -1744,15 +1780,7 @@
                     std::string filename = task.data.at("filename");
                     std::string filepath = task.data.at("filepath");
 
-                    slot->cache_tokens.resize(slot->n_ctx);
-                    size_t token_count = 0;
-                    size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
-                    if (nread == 0) {
-                        slot->cache_tokens.resize(0);
-                        send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
-                        break;
-                    }
-                    slot->cache_tokens.resize(token_count);
+                    auto restore_results = slot->restore(ctx, filepath);
 
                     const int64_t t_end = ggml_time_us();
                     const double t_restore_ms = (t_end - t_start) / 1000.0;
@@ -1763,9 +1791,10 @@
                     result.error = false;
                     result.data = json {
                         { "id_slot",    id_slot },
+                        { "filepath",   filepath },
                         { "filename",   filename },
-                        { "n_restored", token_count }, // tokens restored
-                        { "n_read",     nread }, // bytes read
+                        { "n_restored", restore_results.token_count }, // tokens restored
+                        { "n_read",     restore_results.nread }, // bytes read
                         { "timings", {
                             { "restore_ms", t_restore_ms }
                         } }
@@ -2284,6 +2313,10 @@
                     slot.print_timings();
                    send_final_response(slot);
                    metrics.on_prediction(slot);
+
+                    if (!slot.params.save_filepath.empty()) {
+                        slot.save(ctx, slot.params.save_filepath);
+                    }
                }
 
                slot.i_batch = -1;
@@ -2865,7 +2898,7 @@ int main(int argc, char ** argv) {
 
     const auto handle_infill = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
         json data = json::parse(req.body);
-        return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
+        return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
     };
 
     // TODO: maybe merge this function with "handle_completions_generic"
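
Not part of the patch above: a minimal client-side sketch of the two request fields this change introduces (`save_filename` / `restore_filename`), for illustration only. It assumes a llama.cpp server built with this patch, started with `--slot-save-path <dir>` so the session files land in a known directory, and serving the OpenAI-compatible chat completions route on localhost:8080; the base URL, the session file name, and the use of the `requests` library are assumptions, only the two field names come from the diff.

# Hypothetical usage sketch (not from the patch): warm a prompt cache once, then reuse it.
# Field names `save_filename`/`restore_filename` match the handling in server.cpp above;
# the base URL and file name below are illustrative assumptions.
import requests

BASE_URL = 'http://localhost:8080/v1/'   # assumed endpoint, mirrors run.py's f'{endpoint}chat/completions'
SESSION_FILE = 'session.example.bin'     # illustrative; run.py derives this from a sha256 of model + tools

# 1) Warm-up request: 1 token; after the response finishes, the server saves the slot's
#    cached tokens under its --slot-save-path using this filename (run.py also passes
#    `tools` here so the cached prefix covers the tool definitions).
warm = requests.post(BASE_URL + 'chat/completions', json={
    'messages': [{'role': 'user', 'content': ''}],
    'max_tokens': 1,
    'save_filename': SESSION_FILE,
})
warm.raise_for_status()

# 2) Later requests restore that slot state instead of reprocessing the shared prefix.
resp = requests.post(BASE_URL + 'chat/completions', json={
    'messages': [{'role': 'user', 'content': 'What can you do?'}],
    'restore_filename': SESSION_FILE,
})
resp.raise_for_status()
print(resp.json()['choices'][0]['message']['content'])

Both filenames are resolved server-side against --slot-save-path (see the save_filepath and restore_filename handling in server.cpp), so the client only ever passes a basename.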