1
1
#include " json-schema-to-grammar.h"
2
+ #include " common.h"
3
+
2
4
#include < algorithm>
3
5
#include < fstream>
4
6
#include < map>
11
13
12
14
using json = nlohmann::ordered_json;
13
15
14
- template <typename Iterator>
15
- static std::string join (Iterator begin, Iterator end, const std::string & separator);
16
-
17
- static std::string repeat (const std::string & str, size_t n);
18
-
19
16
static std::string build_repetition (const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = " " ) {
20
17
auto has_max = max_items != std::numeric_limits<int >::max ();
21
18
@@ -128,8 +125,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
128
125
if (sub_len > 0 ) {
129
126
auto from_sub = from.substr (i + 1 );
130
127
auto to_sub = to.substr (i + 1 );
131
- auto sub_zeros = repeat (" 0" , sub_len);
132
- auto sub_nines = repeat (" 9" , sub_len);
128
+ auto sub_zeros = string_repeat (" 0" , sub_len);
129
+ auto sub_nines = string_repeat (" 9" , sub_len);
133
130
134
131
auto to_reached = false ;
135
132
out << " (" ;
@@ -188,8 +185,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
188
185
auto max_digits = max_s.length ();
189
186
190
187
for (auto digits = min_digits; digits < max_digits; digits++) {
191
- uniform_range (min_s, repeat (" 9" , digits));
192
- min_s = " 1" + repeat (" 0" , digits);
188
+ uniform_range (min_s, string_repeat (" 9" , digits));
189
+ min_s = " 1" + string_repeat (" 0" , digits);
193
190
out << " | " ;
194
191
}
195
192
uniform_range (min_s, max_s);
@@ -318,49 +315,6 @@ std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
318
315
std::unordered_set<char > NON_LITERAL_SET = {' |' , ' .' , ' (' , ' )' , ' [' , ' ]' , ' {' , ' }' , ' *' , ' +' , ' ?' };
319
316
std::unordered_set<char > ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {' ^' , ' $' , ' .' , ' [' , ' ]' , ' (' , ' )' , ' |' , ' {' , ' }' , ' *' , ' +' , ' ?' };
320
317
321
- template <typename Iterator>
322
- std::string join (Iterator begin, Iterator end, const std::string & separator) {
323
- std::ostringstream result;
324
- if (begin != end) {
325
- result << *begin;
326
- for (Iterator it = begin + 1 ; it != end; ++it) {
327
- result << separator << *it;
328
- }
329
- }
330
- return result.str ();
331
- }
332
-
333
- static std::vector<std::string> split (const std::string & str, const std::string & delimiter) {
334
- std::vector<std::string> tokens;
335
- size_t start = 0 ;
336
- size_t end = str.find (delimiter);
337
-
338
- while (end != std::string::npos) {
339
- tokens.push_back (str.substr (start, end - start));
340
- start = end + delimiter.length ();
341
- end = str.find (delimiter, start);
342
- }
343
-
344
- tokens.push_back (str.substr (start));
345
-
346
- return tokens;
347
- }
348
-
349
- static std::string repeat (const std::string & str, size_t n) {
350
- if (n == 0 ) {
351
- return " " ;
352
- }
353
-
354
- std::string result;
355
- result.reserve (str.length () * n);
356
-
357
- for (size_t i = 0 ; i < n; ++i) {
358
- result += str;
359
- }
360
-
361
- return result;
362
- }
363
-
364
318
static std::string replacePattern (const std::string & input, const std::regex & regex, const std::function<std::string(const std::smatch &)> & replacement) {
365
319
std::smatch match;
366
320
std::string result;
@@ -389,6 +343,7 @@ static std::string format_literal(const std::string & literal) {
389
343
390
344
class SchemaConverter {
391
345
private:
346
+ friend std::string build_grammar (const std::function<void (const llama_grammar_builder &)> & cb);
392
347
std::function<json(const std::string &)> _fetch_json;
393
348
bool _dotall;
394
349
std::map<std::string, std::string> _rules;
@@ -418,7 +373,7 @@ class SchemaConverter {
418
373
for (size_t i = 0 ; i < alt_schemas.size (); i++) {
419
374
rules.push_back (visit (alt_schemas[i], name + (name.empty () ? " alternative-" : " -" ) + std::to_string (i)));
420
375
}
421
- return join (rules. begin (), rules. end () , " | " );
376
+ return string_join (rules, " | " );
422
377
}
423
378
424
379
std::string _visit_pattern (const std::string & pattern, const std::string & name) {
@@ -481,7 +436,7 @@ class SchemaConverter {
481
436
for (const auto & item : ret) {
482
437
results.push_back (to_rule (item));
483
438
}
484
- return std::make_pair (join (results. begin (), results. end () , " " ), false );
439
+ return std::make_pair (string_join (results, " " ), false );
485
440
};
486
441
487
442
while (i < length) {
@@ -539,7 +494,7 @@ class SchemaConverter {
539
494
}
540
495
curly_brackets += ' }' ;
541
496
i++;
542
- auto nums = split (curly_brackets.substr (1 , curly_brackets.length () - 2 ), " ," );
497
+ auto nums = string_split (curly_brackets.substr (1 , curly_brackets.length () - 2 ), " ," );
543
498
int min_times = 0 ;
544
499
int max_times = std::numeric_limits<int >::max ();
545
500
try {
@@ -854,7 +809,7 @@ class SchemaConverter {
854
809
return ;
855
810
}
856
811
std::string pointer = ref.substr (ref.find (' #' ) + 1 );
857
- std::vector<std::string> tokens = split (pointer, " /" );
812
+ std::vector<std::string> tokens = string_split (pointer, " /" );
858
813
for (size_t i = 1 ; i < tokens.size (); ++i) {
859
814
std::string sel = tokens[i];
860
815
if (target.is_null () || !target.contains (sel)) {
@@ -905,7 +860,7 @@ class SchemaConverter {
905
860
for (const auto & v : schema[" enum" ]) {
906
861
enum_values.push_back (_generate_constant_rule (v));
907
862
}
908
- return _add_rule (rule_name, " (" + join (enum_values. begin (), enum_values. end () , " | " ) + " ) space" );
863
+ return _add_rule (rule_name, " (" + string_join (enum_values, " | " ) + " ) space" );
909
864
} else if ((schema_type.is_null () || schema_type == " object" )
910
865
&& (schema.contains (" properties" ) ||
911
866
(schema.contains (" additionalProperties" ) && schema[" additionalProperties" ] != true ))) {
@@ -1019,10 +974,10 @@ class SchemaConverter {
1019
974
1020
975
void check_errors () {
1021
976
if (!_errors.empty ()) {
1022
- throw std::runtime_error (" JSON schema conversion failed:\n " + join (_errors. begin (), _errors. end () , " \n " ));
977
+ throw std::runtime_error (" JSON schema conversion failed:\n " + string_join (_errors, " \n " ));
1023
978
}
1024
979
if (!_warnings.empty ()) {
1025
- fprintf (stderr, " WARNING: JSON schema conversion was incomplete: %s\n " , join (_warnings. begin (), _warnings. end () , " ; " ).c_str ());
980
+ fprintf (stderr, " WARNING: JSON schema conversion was incomplete: %s\n " , string_join (_warnings, " ; " ).c_str ());
1026
981
}
1027
982
}
1028
983
@@ -1036,10 +991,27 @@ class SchemaConverter {
1036
991
};
1037
992
1038
993
std::string json_schema_to_grammar (const json & schema) {
1039
- SchemaConverter converter ([](const std::string &) { return json::object (); }, /* dotall= */ false );
1040
- auto copy = schema;
1041
- converter.resolve_refs (copy, " input" );
1042
- converter.visit (copy, " " );
994
+ return build_grammar ([&](const llama_grammar_builder & callbacks) {
995
+ auto copy = schema;
996
+ callbacks.resolve_refs (copy);
997
+ callbacks.add_schema (" " , copy);
998
+ });
999
+ }
1000
+
1001
+ std::string build_grammar (const std::function<void (const llama_grammar_builder &)> & cb) {
1002
+ SchemaConverter converter ([&](const std::string &) { return json (); }, /* dotall= */ false );
1003
+ llama_grammar_builder builder {
1004
+ /* .add_rule = */ [&](const std::string & name, const std::string & rule) {
1005
+ return converter._add_rule (name, rule);
1006
+ },
1007
+ /* .add_schema = */ [&](const std::string & name, const nlohmann::ordered_json & schema) {
1008
+ return converter.visit (schema, name == " root" ? " " : name);
1009
+ },
1010
+ /* .resolve_refs = */ [&](nlohmann::ordered_json & schema) {
1011
+ converter.resolve_refs (schema, " " );
1012
+ }
1013
+ };
1014
+ cb (builder);
1043
1015
converter.check_errors ();
1044
1016
return converter.format_grammar ();
1045
1017
}
0 commit comments