Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit b1237b9

Browse files
ochafikmglambda
authored and committed
common: utils to split / join / repeat strings (from json converter) (ggml-org#11342)
* Factor string_join, string_split, string_repeat into common * json: refactor to surface a versatile builder * Update common.cpp
1 parent 7418d8b commit b1237b9

File tree

4 files changed

+90
-64
lines changed

4 files changed

+90
-64
lines changed

common/common.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,48 @@ void string_replace_all(std::string & s, const std::string & search, const std::
484484
s = std::move(builder);
485485
}
486486

487+
// Concatenates the elements of `values` in order, writing `separator` between
// each pair of neighbouring elements (never before the first or after the last).
// An empty `values` yields an empty string.
std::string string_join(const std::vector<std::string> & values, const std::string & separator) {
    std::ostringstream joined;
    bool is_first = true;
    for (const auto & value : values) {
        // The separator goes in front of every element except the first one.
        if (!is_first) {
            joined << separator;
        }
        joined << value;
        is_first = false;
    }
    return joined.str();
}
497+
498+
// Splits `str` at every occurrence of `delimiter` and returns the pieces in order.
//
// Empty pieces are kept: a leading, trailing or repeated delimiter each produce
// an empty string, so the result always contains at least one element (an empty
// `str` yields a single empty string).
//
// An empty `delimiter` is treated as "no delimiter present" and the whole input
// is returned as one piece.
std::vector<std::string> string_split(const std::string & str, const std::string & delimiter) {
    std::vector<std::string> parts;

    if (delimiter.empty()) {
        // find("") matches at the current position without consuming anything,
        // so the loop below would never advance and would grow `parts` forever.
        parts.push_back(str);
        return parts;
    }

    size_t start = 0;
    size_t end = str.find(delimiter);

    while (end != std::string::npos) {
        parts.push_back(str.substr(start, end - start));
        // Resume searching just past the delimiter that was matched.
        start = end + delimiter.length();
        end = str.find(delimiter, start);
    }

    // Whatever follows the last delimiter (possibly nothing) is the final piece.
    parts.push_back(str.substr(start));

    return parts;
}
513+
514+
// Returns `str` concatenated with itself `n` times.
// The result is empty when `n` is zero.
std::string string_repeat(const std::string & str, size_t n) {
    std::string repeated;

    if (n > 0) {
        // Size the buffer once up front instead of growing it on every append.
        repeated.reserve(str.length() * n);

        for (size_t remaining = n; remaining > 0; --remaining) {
            repeated.append(str);
        }
    }

    return repeated;
}
528+
487529
std::string string_from(bool value) {
488530
return value ? "true" : "false";
489531
}

common/common.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,10 @@ std::string string_format(const char * fmt, ...);
429429
std::string string_strip(const std::string & str);
430430
std::string string_get_sortable_timestamp();
431431

432+
std::string string_join(const std::vector<std::string> & values, const std::string & separator);
433+
std::vector<std::string> string_split(const std::string & str, const std::string & delimiter);
434+
std::string string_repeat(const std::string & str, size_t n);
435+
432436
void string_replace_all(std::string & s, const std::string & search, const std::string & replace);
433437

434438
template<class T>

common/json-schema-to-grammar.cpp

Lines changed: 35 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
#include "json-schema-to-grammar.h"
2+
#include "common.h"
3+
24
#include <algorithm>
35
#include <fstream>
46
#include <map>
@@ -11,11 +13,6 @@
1113

1214
using json = nlohmann::ordered_json;
1315

14-
template <typename Iterator>
15-
static std::string join(Iterator begin, Iterator end, const std::string & separator);
16-
17-
static std::string repeat(const std::string & str, size_t n);
18-
1916
static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") {
2017
auto has_max = max_items != std::numeric_limits<int>::max();
2118

@@ -128,8 +125,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
128125
if (sub_len > 0) {
129126
auto from_sub = from.substr(i + 1);
130127
auto to_sub = to.substr(i + 1);
131-
auto sub_zeros = repeat("0", sub_len);
132-
auto sub_nines = repeat("9", sub_len);
128+
auto sub_zeros = string_repeat("0", sub_len);
129+
auto sub_nines = string_repeat("9", sub_len);
133130

134131
auto to_reached = false;
135132
out << "(";
@@ -188,8 +185,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
188185
auto max_digits = max_s.length();
189186

190187
for (auto digits = min_digits; digits < max_digits; digits++) {
191-
uniform_range(min_s, repeat("9", digits));
192-
min_s = "1" + repeat("0", digits);
188+
uniform_range(min_s, string_repeat("9", digits));
189+
min_s = "1" + string_repeat("0", digits);
193190
out << " | ";
194191
}
195192
uniform_range(min_s, max_s);
@@ -318,49 +315,6 @@ std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
318315
std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};
319316
std::unordered_set<char> ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'};
320317

321-
template <typename Iterator>
322-
std::string join(Iterator begin, Iterator end, const std::string & separator) {
323-
std::ostringstream result;
324-
if (begin != end) {
325-
result << *begin;
326-
for (Iterator it = begin + 1; it != end; ++it) {
327-
result << separator << *it;
328-
}
329-
}
330-
return result.str();
331-
}
332-
333-
static std::vector<std::string> split(const std::string & str, const std::string & delimiter) {
334-
std::vector<std::string> tokens;
335-
size_t start = 0;
336-
size_t end = str.find(delimiter);
337-
338-
while (end != std::string::npos) {
339-
tokens.push_back(str.substr(start, end - start));
340-
start = end + delimiter.length();
341-
end = str.find(delimiter, start);
342-
}
343-
344-
tokens.push_back(str.substr(start));
345-
346-
return tokens;
347-
}
348-
349-
static std::string repeat(const std::string & str, size_t n) {
350-
if (n == 0) {
351-
return "";
352-
}
353-
354-
std::string result;
355-
result.reserve(str.length() * n);
356-
357-
for (size_t i = 0; i < n; ++i) {
358-
result += str;
359-
}
360-
361-
return result;
362-
}
363-
364318
static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function<std::string(const std::smatch &)> & replacement) {
365319
std::smatch match;
366320
std::string result;
@@ -389,6 +343,7 @@ static std::string format_literal(const std::string & literal) {
389343

390344
class SchemaConverter {
391345
private:
346+
friend std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb);
392347
std::function<json(const std::string &)> _fetch_json;
393348
bool _dotall;
394349
std::map<std::string, std::string> _rules;
@@ -418,7 +373,7 @@ class SchemaConverter {
418373
for (size_t i = 0; i < alt_schemas.size(); i++) {
419374
rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i)));
420375
}
421-
return join(rules.begin(), rules.end(), " | ");
376+
return string_join(rules, " | ");
422377
}
423378

424379
std::string _visit_pattern(const std::string & pattern, const std::string & name) {
@@ -481,7 +436,7 @@ class SchemaConverter {
481436
for (const auto & item : ret) {
482437
results.push_back(to_rule(item));
483438
}
484-
return std::make_pair(join(results.begin(), results.end(), " "), false);
439+
return std::make_pair(string_join(results, " "), false);
485440
};
486441

487442
while (i < length) {
@@ -539,7 +494,7 @@ class SchemaConverter {
539494
}
540495
curly_brackets += '}';
541496
i++;
542-
auto nums = split(curly_brackets.substr(1, curly_brackets.length() - 2), ",");
497+
auto nums = string_split(curly_brackets.substr(1, curly_brackets.length() - 2), ",");
543498
int min_times = 0;
544499
int max_times = std::numeric_limits<int>::max();
545500
try {
@@ -854,7 +809,7 @@ class SchemaConverter {
854809
return;
855810
}
856811
std::string pointer = ref.substr(ref.find('#') + 1);
857-
std::vector<std::string> tokens = split(pointer, "/");
812+
std::vector<std::string> tokens = string_split(pointer, "/");
858813
for (size_t i = 1; i < tokens.size(); ++i) {
859814
std::string sel = tokens[i];
860815
if (target.is_null() || !target.contains(sel)) {
@@ -905,7 +860,7 @@ class SchemaConverter {
905860
for (const auto & v : schema["enum"]) {
906861
enum_values.push_back(_generate_constant_rule(v));
907862
}
908-
return _add_rule(rule_name, "(" + join(enum_values.begin(), enum_values.end(), " | ") + ") space");
863+
return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space");
909864
} else if ((schema_type.is_null() || schema_type == "object")
910865
&& (schema.contains("properties") ||
911866
(schema.contains("additionalProperties") && schema["additionalProperties"] != true))) {
@@ -1019,10 +974,10 @@ class SchemaConverter {
1019974

1020975
void check_errors() {
1021976
if (!_errors.empty()) {
1022-
throw std::runtime_error("JSON schema conversion failed:\n" + join(_errors.begin(), _errors.end(), "\n"));
977+
throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n"));
1023978
}
1024979
if (!_warnings.empty()) {
1025-
fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", join(_warnings.begin(), _warnings.end(), "; ").c_str());
980+
fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str());
1026981
}
1027982
}
1028983

@@ -1036,10 +991,27 @@ class SchemaConverter {
1036991
};
1037992

1038993
std::string json_schema_to_grammar(const json & schema) {
1039-
SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false);
1040-
auto copy = schema;
1041-
converter.resolve_refs(copy, "input");
1042-
converter.visit(copy, "");
994+
return build_grammar([&](const llama_grammar_builder & callbacks) {
995+
auto copy = schema;
996+
callbacks.resolve_refs(copy);
997+
callbacks.add_schema("", copy);
998+
});
999+
}
1000+
1001+
// Builds a grammar string by letting the caller drive a SchemaConverter through callbacks.
//
// A function-local SchemaConverter is created, `cb` is invoked with a
// llama_grammar_builder whose members forward to that converter, and the
// converter's formatted grammar is returned.
//
// check_errors() runs after `cb` returns and before formatting; as defined in
// this file it throws std::runtime_error when errors were recorded and prints
// recorded warnings to stderr.
std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb) {
1002+
// The fetch callback ignores its argument and returns a default-constructed json; dotall is off.
SchemaConverter converter([&](const std::string &) { return json(); }, /* dotall= */ false);
1003+
// Every lambda below captures `converter` by reference.
llama_grammar_builder builder {
1004+
/* .add_rule = */ [&](const std::string & name, const std::string & rule) {
1005+
return converter._add_rule(name, rule);
1006+
},
1007+
/* .add_schema = */ [&](const std::string & name, const nlohmann::ordered_json & schema) {
1008+
// The name "root" is replaced by the empty name before the schema is visited.
return converter.visit(schema, name == "root" ? "" : name);
1009+
},
1010+
/* .resolve_refs = */ [&](nlohmann::ordered_json & schema) {
1011+
// References are resolved with an empty string as the second argument.
converter.resolve_refs(schema, "");
1012+
}
1013+
};
1014+
cb(builder);
10431015
converter.check_errors();
10441016
return converter.format_grammar();
10451017
}

common/json-schema-to-grammar.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,12 @@
55
#define JSON_ASSERT GGML_ASSERT
66
#include "json.hpp"
77

8-
std::string json_schema_to_grammar(const nlohmann::ordered_json& schema);
8+
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema);
9+
10+
// Bundle of callbacks that build_grammar() passes to its `cb` argument.
// build_grammar() fills each member with a lambda that captures a function-local
// SchemaConverter by reference and forwards to it.
struct llama_grammar_builder {
// Takes (name, rule) and returns a string; build_grammar() forwards it to SchemaConverter::_add_rule(name, rule).
11+
std::function<std::string(const std::string &, const std::string &)> add_rule;
// Takes (name, schema) and returns a string; build_grammar() forwards it to SchemaConverter::visit(), passing "" when name is "root".
12+
std::function<std::string(const std::string &, const nlohmann::ordered_json &)> add_schema;
// Takes the schema by non-const reference; build_grammar() forwards it to SchemaConverter::resolve_refs(schema, "").
13+
std::function<void(nlohmann::ordered_json &)> resolve_refs;
14+
};
15+
16+
std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb);

0 commit comments

Comments
 (0)