From abf9bc321294b6bcb5f4661df2e8c75244e20ba0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20Kr=C3=BCger?= Date: Tue, 14 Feb 2017 13:57:03 +0100 Subject: [PATCH 1/4] Introduce support for user defined functions --- hdr/sqlite_modern_cpp.h | 201 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 200 insertions(+), 1 deletion(-) diff --git a/hdr/sqlite_modern_cpp.h b/hdr/sqlite_modern_cpp.h index df729faf..91399d8a 100644 --- a/hdr/sqlite_modern_cpp.h +++ b/hdr/sqlite_modern_cpp.h @@ -316,6 +316,32 @@ namespace sqlite { } }; + namespace sql_function_binder { + template< + std::size_t Count, + typename Function, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) < Count), void>::type scalar( + sqlite3_context* db, + int count, + sqlite3_value** vals, + Values&&... values + ); + + template< + std::size_t Count, + typename Function, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) == Count), void>::type scalar( + sqlite3_context* db, + int, + sqlite3_value**, + Values&&... values + ); + } + class database { private: std::shared_ptr _db; @@ -362,6 +388,19 @@ namespace sqlite { return sqlite3_last_insert_rowid(_db.get()); } + template + void define(const std::string &name, Function&& func) { + typedef utility::function_traits traits; + + auto funcPtr = new auto(std::forward(func)); + sqlite3_create_function_v2( + _db.get(), name.c_str(), traits::arity, SQLITE_UTF8, funcPtr, + sql_function_binder::scalar::type>, + nullptr, nullptr, [](void* ptr){ + delete static_cast(ptr); + }); + } + }; template @@ -420,6 +459,9 @@ namespace sqlite { } ++db._inx; return db; + } + inline void store_result_in_db(sqlite3_context* db, const int& val) { + sqlite3_result_int(db, val); } inline void get_col_from_db(database_binder& db, int inx, int& val) { if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { @@ -428,6 +470,13 @@ namespace sqlite { val = sqlite3_column_int(db._stmt.get(), inx); } } + inline void get_val_from_db(sqlite3_value *value, int& val) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + val = 0; + } else { + val = sqlite3_value_int(value); + } + } // sqlite_int64 inline database_binder& operator <<(database_binder& db, const sqlite_int64& val) { @@ -438,6 +487,9 @@ namespace sqlite { ++db._inx; return db; + } + inline void store_result_in_db(sqlite3_context* db, const sqlite_int64& val) { + sqlite3_result_int64(db, val); } inline void get_col_from_db(database_binder& db, int inx, sqlite3_int64& i) { if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { @@ -446,6 +498,13 @@ namespace sqlite { i = sqlite3_column_int64(db._stmt.get(), inx); } } + inline void get_val_from_db(sqlite3_value *value, sqlite3_int64& i) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + i = 0; + } else { + i = sqlite3_value_int64(value); + } + } // float inline database_binder& operator <<(database_binder& db, const float& val) { @@ -456,6 +515,9 @@ namespace sqlite { ++db._inx; return db; + } + inline void store_result_in_db(sqlite3_context* db, const float& val) { + sqlite3_result_double(db, val); } inline void get_col_from_db(database_binder& db, int inx, float& f) { if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { @@ -464,6 +526,13 @@ namespace sqlite { f = float(sqlite3_column_double(db._stmt.get(), inx)); } } + inline void get_val_from_db(sqlite3_value *value, float& f) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + f = 0; + } else { + f = float(sqlite3_value_double(value)); + } + } // double inline database_binder& operator <<(database_binder& db, const double& val) { @@ -474,6 +543,9 @@ namespace sqlite { ++db._inx; return db; + } + inline void store_result_in_db(sqlite3_context* db, const double& val) { + sqlite3_result_double(db, val); } inline void get_col_from_db(database_binder& db, int inx, double& d) { if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { @@ -482,6 +554,13 @@ namespace sqlite { d = sqlite3_column_double(db._stmt.get(), inx); } } + inline void get_val_from_db(sqlite3_value *value, double& d) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + d = 0; + } else { + d = sqlite3_value_double(value); + } + } // vector template inline database_binder& operator<<(database_binder& db, const std::vector& vec) { @@ -494,6 +573,11 @@ namespace sqlite { ++db._inx; return db; } + template inline void store_result_in_db(sqlite3_context* db, const std::vector& vec) { + void const* buf = reinterpret_cast(vec.data()); + int bytes = vec.size() * sizeof(T); + sqlite3_result_blob(db, buf, bytes, SQLITE_TRANSIENT); + } template inline void get_col_from_db(database_binder& db, int inx, std::vector& vec) { if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { vec.clear(); @@ -503,6 +587,15 @@ namespace sqlite { vec = std::vector(buf, buf + bytes/sizeof(T)); } } + template inline void get_val_from_db(sqlite3_value *value, std::vector& vec) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + vec.clear(); + } else { + int bytes = sqlite3_value_bytes(value); + T const* buf = reinterpret_cast(sqlite3_value_blob(value)); + vec = std::vector(buf, buf + bytes/sizeof(T)); + } + } /* for nullptr support */ inline database_binder& operator <<(database_binder& db, std::nullptr_t) { @@ -512,6 +605,9 @@ namespace sqlite { } ++db._inx; return db; + } + inline void store_result_in_db(sqlite3_context* db, std::nullptr_t) { + sqlite3_result_null(db); } /* for nullptr support */ template inline database_binder& operator <<(database_binder& db, const std::unique_ptr& val) { @@ -532,6 +628,15 @@ namespace sqlite { _ptr_.reset(underling_ptr); } } + template inline void get_val_from_db(sqlite3_value *value, std::unique_ptr& _ptr_) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + _ptr_ = nullptr; + } else { + auto underling_ptr = new T(); + get_val_from_db(value, *underling_ptr); + _ptr_.reset(underling_ptr); + } + } // std::string inline void get_col_from_db(database_binder& db, int inx, std::string & s) { @@ -542,6 +647,14 @@ namespace sqlite { s = std::string(reinterpret_cast(sqlite3_column_text(db._stmt.get(), inx))); } } + inline void get_val_from_db(sqlite3_value *value, std::string & s) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + s = std::string(); + } else { + sqlite3_value_bytes(value); + s = std::string(reinterpret_cast(sqlite3_value_text(value))); + } + } // Convert char* to string to trigger op<<(..., const std::string ) template inline database_binder& operator <<(database_binder& db, const char(&STR)[N]) { return db << std::string(STR); } @@ -555,6 +668,9 @@ namespace sqlite { ++db._inx; return db; + } + inline void store_result_in_db(sqlite3_context* db, const std::string& val) { + sqlite3_result_text(db, val.data(), -1, SQLITE_TRANSIENT); } // std::u16string inline void get_col_from_db(database_binder& db, int inx, std::u16string & w) { @@ -565,6 +681,14 @@ namespace sqlite { w = std::u16string(reinterpret_cast(sqlite3_column_text16(db._stmt.get(), inx))); } } + inline void get_val_from_db(sqlite3_value *value, std::u16string & w) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + w = std::u16string(); + } else { + sqlite3_value_bytes16(value); + w = std::u16string(reinterpret_cast(sqlite3_value_text16(value))); + } + } inline database_binder& operator <<(database_binder& db, const std::u16string& txt) { @@ -575,6 +699,9 @@ namespace sqlite { ++db._inx; return db; + } + inline void store_result_in_db(sqlite3_context* db, const std::u16string& val) { + sqlite3_result_text16(db, val.data(), -1, SQLITE_TRANSIENT); } // std::optional support for NULL values #ifdef _MODERN_SQLITE_STD_OPTIONAL_SUPPORT @@ -590,13 +717,28 @@ namespace sqlite { ++db._inx; return db; } + template inline void store_result_in_db(sqlite3_context* db, const std::optional& val) { + if(val) { + store_result_in_db(db, *val); + } + sqlite3_result_null(db); + } template inline void get_col_from_db(database_binder& db, int inx, std::optional& o) { if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { o.reset(); } else { OptionalT v; - get_col_from_db(db, inx, v); + get_col_from_db(value, v); + o = std::move(v); + } + } + template inline void get_val_from_db(sqlite3_value *value, std::optional& o) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + o.reset(); + } else { + OptionalT v; + get_val_from_db(value, v); o = std::move(v); } } @@ -616,6 +758,12 @@ namespace sqlite { ++db._inx; return db; } + template inline void store_result_in_db(sqlite3_context* db, const boost::optional& val) { + if(val) { + store_result_in_db(db, *val); + } + sqlite3_result_null(db); + } template inline void get_col_from_db(database_binder& db, int inx, boost::optional& o) { if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { @@ -626,6 +774,15 @@ namespace sqlite { o = std::move(v); } } + template inline void get_val_from_db(sqlite3_value *value, boost::optional& o) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + o.reset(); + } else { + BoostOptionalT v; + get_val_from_db(value, v); + o = std::move(v); + } + } #endif // Some ppl are lazy so we have a operator for proper prep. statemant handling. @@ -634,4 +791,46 @@ namespace sqlite { // Convert the rValue binder to a reference and call first op<<, its needed for the call that creates the binder (be carefull of recursion here!) template database_binder& operator << (database_binder&& db, const T& val) { return db << val; } + namespace sql_function_binder { + template< + std::size_t Count, + typename Function, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) < Count), void>::type scalar( + sqlite3_context* db, + int count, + sqlite3_value** vals, + Values&&... values + ) { + typename utility::function_traits::template argument value{}; + get_val_from_db(vals[sizeof...(Values)], value); + + scalar(db, count, vals, std::forward(values)..., std::move(value)); + } + + template< + std::size_t Count, + typename Function, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) == Count), void>::type scalar( + sqlite3_context* db, + int, + sqlite3_value**, + Values&&... values + ) { + try { + store_result_in_db(db, + (*static_cast(sqlite3_user_data(db)))(std::move(values)...)); + } catch(sqlite_exception &e) { + sqlite3_result_error_code(db, e.get_code()); + sqlite3_result_error(db, e.what(), -1); + } catch(std::exception &e) { + sqlite3_result_error(db, e.what(), -1); + } catch(...) { + sqlite3_result_error(db, "Unknown error", -1); + } + } + } } From 3a14459604cbdbbf73c1eb8ba56f44df46bc6690 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20Kr=C3=BCger?= Date: Wed, 15 Feb 2017 13:38:32 +0100 Subject: [PATCH 2/4] Add support for custom aggregate functions --- hdr/sqlite_modern_cpp.h | 133 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 132 insertions(+), 1 deletion(-) diff --git a/hdr/sqlite_modern_cpp.h b/hdr/sqlite_modern_cpp.h index 91399d8a..dc524348 100644 --- a/hdr/sqlite_modern_cpp.h +++ b/hdr/sqlite_modern_cpp.h @@ -317,6 +317,47 @@ namespace sqlite { }; namespace sql_function_binder { + template< + typename ContextType, + std::size_t Count, + typename Functions + > + inline void step( + sqlite3_context* db, + int count, + sqlite3_value** vals + ); + + template< + std::size_t Count, + typename Functions, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) && sizeof...(Values) < Count), void>::type step( + sqlite3_context* db, + int count, + sqlite3_value** vals, + Values&&... values + ); + + template< + std::size_t Count, + typename Functions, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) == Count), void>::type step( + sqlite3_context* db, + int, + sqlite3_value**, + Values&&... values + ); + + template< + typename ContextType, + typename Functions + > + inline void final(sqlite3_context* db); + template< std::size_t Count, typename Function, @@ -401,6 +442,22 @@ namespace sqlite { }); } + template + void define(const std::string &name, StepFunction&& step, FinalFunction&& final) { + typedef utility::function_traits traits; + using ContextType = typename std::remove_reference>::type; + + auto funcPtr = new auto(std::make_pair(std::forward(step), std::forward(final))); + if(int result = sqlite3_create_function_v2( + _db.get(), name.c_str(), traits::arity - 1, SQLITE_UTF8, funcPtr, nullptr, + sql_function_binder::step::type>, + sql_function_binder::final::type>, + [](void* ptr){ + delete static_cast(ptr); + })) + exceptions::throw_sqlite_error(result); + } + }; template @@ -792,6 +849,80 @@ namespace sqlite { template database_binder& operator << (database_binder&& db, const T& val) { return db << val; } namespace sql_function_binder { + template + struct AggregateCtxt { + T obj; + bool constructed = true; + }; + + template< + typename ContextType, + std::size_t Count, + typename Functions + > + inline void step( + sqlite3_context* db, + int count, + sqlite3_value** vals + ) { + auto ctxt = static_cast*>(sqlite3_aggregate_context(db, sizeof(AggregateCtxt))); + if(!ctxt) return; + if(!ctxt->constructed) new(ctxt) AggregateCtxt(); + step(db, count, vals, ctxt->obj); + } + + template< + std::size_t Count, + typename Functions, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) && sizeof...(Values) < Count), void>::type step( + sqlite3_context* db, + int count, + sqlite3_value** vals, + Values&&... values + ) { + typename utility::function_traits::template argument value{}; + get_val_from_db(vals[sizeof...(Values) - 1], value); + + step(db, count, vals, std::forward(values)..., std::move(value)); + } + + template< + std::size_t Count, + typename Functions, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) == Count), void>::type step( + sqlite3_context* db, + int, + sqlite3_value**, + Values&&... values + ) { + static_cast(sqlite3_user_data(db))->first(std::forward(values)...); + }; + + template< + typename ContextType, + typename Functions + > + inline void final(sqlite3_context* db) { + try { + auto ctxt = static_cast*>(sqlite3_aggregate_context(db, sizeof(AggregateCtxt))); + if(!ctxt) return; + if(!ctxt->constructed) new(ctxt) AggregateCtxt(); + store_result_in_db(db, + static_cast(sqlite3_user_data(db))->second(ctxt->obj)); + } catch(sqlite_exception &e) { + sqlite3_result_error_code(db, e.get_code()); + sqlite3_result_error(db, e.what(), -1); + } catch(std::exception &e) { + sqlite3_result_error(db, e.what(), -1); + } catch(...) { + sqlite3_result_error(db, "Unknown error", -1); + } + } + template< std::size_t Count, typename Function, @@ -822,7 +953,7 @@ namespace sqlite { ) { try { store_result_in_db(db, - (*static_cast(sqlite3_user_data(db)))(std::move(values)...)); + (*static_cast(sqlite3_user_data(db)))(std::forward(values)...)); } catch(sqlite_exception &e) { sqlite3_result_error_code(db, e.get_code()); sqlite3_result_error(db, e.what(), -1); From bbfff7bda789180562cf3ea11214f9e9cec53635 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20Kr=C3=BCger?= Date: Wed, 15 Feb 2017 19:13:45 +0100 Subject: [PATCH 3/4] Added test and documentation --- README.md | 19 ++++++++++++++ tests/functions.cc | 65 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+) create mode 100644 tests/functions.cc diff --git a/README.md b/README.md index 82f1b4c0..94974f6d 100644 --- a/README.md +++ b/README.md @@ -304,6 +304,25 @@ sqlite::sqlite_exception has a get_code() member function to get the SQLITE3 err catch(sqlite::exceptions::constraint e) { } */ ``` +Custom SQL functions +---- + +To extend SQLite with custom functions, you just implement them in C++: + +```c++ + database db(":memory:"); + db.define("tgamma", [](double i) {return std::tgamma(i);}); + db << "CREATE TABLE numbers (number INTEGER);"; + + for(auto i=0; i!=10; ++i) + db << "INSERT INTO numbers VALUES (?);" << i; + + db << "SELECT number, tgamma(number+1) FROM numbers;" >> [](double number, double factorial) { + cout << number << "! = " << factorial << '\n'; + }; +``` + + NDK support ---- Just Make sure you are using the full path of your database file : diff --git a/tests/functions.cc b/tests/functions.cc new file mode 100644 index 00000000..f1fb073b --- /dev/null +++ b/tests/functions.cc @@ -0,0 +1,65 @@ +#include +#include +#include +#include +using namespace sqlite; +using namespace std; + +int main() +{ + try + { + database db(":memory:"); + + db.define("my_new_concat", [](std::string i, std::string j) {return i+j;}); + db.define("my_new_concat", [](std::string i, std::string j, std::string k) {return i+j+k;}); + db.define("add_integers", [](int i, int j) {return i+j;}); + std::string test1, test3; + int test2 = 0; + db << "select my_new_concat('Hello ','world!')" >> test1; + db << "select add_integers(1,1)" >> test2; + db << "select my_new_concat('a','b','c')" >> test3; + + if(test1 != "Hello world!" || test2 != 2 || test3 != "abc") { + cout << "Wrong result\n"; + exit(EXIT_FAILURE); + } + + db.define("my_count", [](int &i, int) {++i;}, [](int &i) {return i;}); + db.define("my_concat_aggregate", [](std::string &stored, std::string current) {stored += current;}, [](std::string &stored) {return stored;}); + db << "create table countable(i, s)"; + db << "insert into countable values(1, 'a')"; + db << "insert into countable values(2, 'b')"; + db << "insert into countable values(3, 'c')"; + db << "select my_count(i) from countable" >> test2; + db << "select my_concat_aggregate(s) from countable order by i" >> test3; + + if(test2 != 3 || test3 != "abc") { + cout << "Wrong result\n"; + exit(EXIT_FAILURE); + } + + db.define("tgamma", [](double i) {return std::tgamma(i);}); + db << "CREATE TABLE numbers (number INTEGER);"; + + for(auto i=0; i!=10; ++i) + db << "INSERT INTO numbers VALUES (?);" << i; + + db << "SELECT number, tgamma(number+1) FROM numbers;" >> [](double number, double factorial) { + cout << number << "! = " << factorial << '\n'; + }; + } + catch(sqlite_exception e) + { + cout << "Unexpected error " << e.what() << endl; + exit(EXIT_FAILURE); + } + catch(...) + { + cout << "Unknown error\n"; + exit(EXIT_FAILURE); + } + + cout << "OK\n"; + exit(EXIT_SUCCESS); +} From 5f421c21a9b8d3eec228ac8f4d74c0e272c10651 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20Kr=C3=BCger?= Date: Fri, 17 Feb 2017 10:49:53 +0100 Subject: [PATCH 4/4] Support (const) reference arguments --- hdr/sqlite_modern_cpp.h | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/hdr/sqlite_modern_cpp.h b/hdr/sqlite_modern_cpp.h index dc524348..06d6eeff 100644 --- a/hdr/sqlite_modern_cpp.h +++ b/hdr/sqlite_modern_cpp.h @@ -434,12 +434,13 @@ namespace sqlite { typedef utility::function_traits traits; auto funcPtr = new auto(std::forward(func)); - sqlite3_create_function_v2( + if(int result = sqlite3_create_function_v2( _db.get(), name.c_str(), traits::arity, SQLITE_UTF8, funcPtr, sql_function_binder::scalar::type>, nullptr, nullptr, [](void* ptr){ delete static_cast(ptr); - }); + })) + exceptions::throw_sqlite_error(result); } template @@ -455,7 +456,7 @@ namespace sqlite { [](void* ptr){ delete static_cast(ptr); })) - exceptions::throw_sqlite_error(result); + exceptions::throw_sqlite_error(result); } }; @@ -488,7 +489,7 @@ namespace sqlite { Function&& function, Values&&... values ) { - nth_argument_type value{}; + typename std::remove_cv>::type>::type value{}; get_col_from_db(db, sizeof...(Values), value); run(db, function, std::forward(values)..., std::move(value)); @@ -882,7 +883,13 @@ namespace sqlite { sqlite3_value** vals, Values&&... values ) { - typename utility::function_traits::template argument value{}; + typename std::remove_cv< + typename std::remove_reference< + typename utility::function_traits< + typename Functions::first_type + >::template argument + >::type + >::type value{}; get_val_from_db(vals[sizeof...(Values) - 1], value); step(db, count, vals, std::forward(values)..., std::move(value)); @@ -934,7 +941,11 @@ namespace sqlite { sqlite3_value** vals, Values&&... values ) { - typename utility::function_traits::template argument value{}; + typename std::remove_cv< + typename std::remove_reference< + typename utility::function_traits::template argument + >::type + >::type value{}; get_val_from_db(vals[sizeof...(Values)], value); scalar(db, count, vals, std::forward(values)..., std::move(value));