From b85a0baaf0d1464c9fe0edec11fb5e1d4e29e8ec Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Fri, 26 Nov 2021 15:39:28 -0600 Subject: [PATCH 1/8] Support pickling of Basic objects --- symengine/lib/symengine.pxd | 3 +++ symengine/lib/symengine_wrapper.pyx | 8 ++++++++ symengine/tests/test_pickling.py | 11 ++++++++++- symengine_version.txt | 2 +- 4 files changed, 22 insertions(+), 2 deletions(-) diff --git a/symengine/lib/symengine.pxd b/symengine/lib/symengine.pxd index 9d86d9831..5ef0fc1e7 100644 --- a/symengine/lib/symengine.pxd +++ b/symengine/lib/symengine.pxd @@ -183,6 +183,8 @@ cdef extern from "" namespace "SymEngine": unsigned int hash() nogil except + vec_basic get_args() nogil int __cmp__(const Basic &o) nogil + string dumps() nogil except + + ctypedef RCP[const Number] rcp_const_number "SymEngine::RCP" ctypedef unordered_map[int, rcp_const_basic] umap_int_basic "SymEngine::umap_int_basic" ctypedef unordered_map[int, rcp_const_basic].iterator umap_int_basic_iterator "SymEngine::umap_int_basic::iterator" @@ -193,6 +195,7 @@ cdef extern from "" namespace "SymEngine": bool eq(const Basic &a, const Basic &b) nogil except + bool neq(const Basic &a, const Basic &b) nogil except + + RCP[const Basic] loads "SymEngine::Basic::loads"(const string &) nogil except + RCP[const Symbol] rcp_static_cast_Symbol "SymEngine::rcp_static_cast"(rcp_const_basic &b) nogil RCP[const PySymbol] rcp_static_cast_PySymbol "SymEngine::rcp_static_cast"(rcp_const_basic &b) nogil diff --git a/symengine/lib/symengine_wrapper.pyx b/symengine/lib/symengine_wrapper.pyx index d66af0584..10c5f06ff 100644 --- a/symengine/lib/symengine_wrapper.pyx +++ b/symengine/lib/symengine_wrapper.pyx @@ -826,6 +826,10 @@ cdef list vec_pair_to_list(symengine.vec_pair& vec): return result +def load_basic(bytes s): + return c2py(symengine.loads(s)) + + repr_latex=[False] cdef class Basic(object): @@ -836,6 +840,10 @@ cdef class Basic(object): def __repr__(self): return self.__str__() + def __reduce__(self): + cdef bytes s = deref(self.thisptr).dumps() + return (load_basic, (s,)) + def _repr_latex_(self): if repr_latex[0]: return "${}$".format(latex(self)) diff --git a/symengine/tests/test_pickling.py b/symengine/tests/test_pickling.py index b7f14a181..75c353b60 100644 --- a/symengine/tests/test_pickling.py +++ b/symengine/tests/test_pickling.py @@ -1,7 +1,16 @@ -from symengine import symbols, sin, sinh, have_numpy, have_llvm +from symengine import symbols, sin, sinh, have_numpy, have_llvm, cos import pickle import unittest + +def test_basic(): + x, y, z = symbols('x y z') + expr = sin(cos(x + y)/z)**2 + s = pickle.dumps(expr) + expr2 = pickle.loads(s) + assert expr == expr2 + + @unittest.skipUnless(have_llvm, "No LLVM support") @unittest.skipUnless(have_numpy, "Numpy not installed") def test_llvm_double(): diff --git a/symengine_version.txt b/symengine_version.txt index 845dbe7f9..f4d17b3e7 100644 --- a/symengine_version.txt +++ b/symengine_version.txt @@ -1 +1 @@ -23abf31763620463500d5fad114d855afd66d011 +4b841d144bbc1ecd4367e4bd7dd4e7c6be8fac05 From f096b4e4e3de30615d78873e85558c72eafaea65 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Tue, 30 Nov 2021 23:43:54 -0800 Subject: [PATCH 2/8] support serializing and deserializing pysymbol --- symengine/lib/pywrapper.cpp | 77 +++++++++++++++++++++++++++++ symengine/lib/pywrapper.h | 3 ++ symengine/lib/symengine.pxd | 6 +-- symengine/lib/symengine_wrapper.pyx | 10 +++- symengine/tests/test_pickling.py | 20 +++++++- 5 files changed, 109 insertions(+), 7 deletions(-) diff --git a/symengine/lib/pywrapper.cpp b/symengine/lib/pywrapper.cpp index c321ea39f..dc714ea69 100644 --- a/symengine/lib/pywrapper.cpp +++ b/symengine/lib/pywrapper.cpp @@ -1,4 +1,5 @@ #include "pywrapper.h" +#include #if PY_MAJOR_VERSION >= 3 #define PyInt_FromLong PyLong_FromLong @@ -269,4 +270,80 @@ int PyFunction::compare(const Basic &o) const { return unified_compare(get_vec(), s.get_vec()); } +inline PyObject* get_pickle_module() { + static PyObject *module = NULL; + if (module == NULL) { + module = PyImport_ImportModule("pickle"); + } + return module; +} + +RCP load_basic(cereal::PortableBinaryInputArchive &ar, RCP &) +{ + bool is_pysymbol; + std::string name; + ar(is_pysymbol); + ar(name); + if (is_pysymbol) { + std::string pickle_str; + ar(pickle_str); + PyObject *module = get_pickle_module(); + PyObject *pickle_bytes = PyBytes_FromStringAndSize(pickle_str.data(), pickle_str.size()); + PyObject *obj = PyObject_CallMethod(module, "loads", "O", pickle_bytes); + RCP result = make_rcp(name, obj); + Py_XDECREF(pickle_bytes); + return result; + } else { + return symbol(name); + } +} + +void save_basic(cereal::PortableBinaryOutputArchive &ar, const Symbol &b) +{ + bool is_pysymbol = is_a_sub(b); + ar(is_pysymbol); + ar(b.__str__()); + if (is_pysymbol) { + RCP p = rcp_static_cast(b.rcp_from_this()); + PyObject *module = get_pickle_module(); + PyObject *pickle_bytes = PyObject_CallMethod(module, "dumps", "O", p->get_py_object()); + Py_ssize_t size; + char* buffer; + PyBytes_AsStringAndSize(pickle_bytes, &buffer, &size); + std::string pickle_str(buffer, size); + ar(pickle_str); + Py_XDECREF(pickle_bytes); + } +} + +std::string wrapper_dumps(const Basic &x) +{ + std::cout << "qwe" << std::endl; + std::ostringstream oss; + unsigned short major = SYMENGINE_MAJOR_VERSION; + unsigned short minor = SYMENGINE_MINOR_VERSION; + cereal::PortableBinaryOutputArchive{oss}(major, minor, + x.rcp_from_this()); + return oss.str(); +} + +RCP wrapper_loads(const std::string &serialized) +{ + unsigned short major, minor; + RCP obj; + std::istringstream iss(serialized); + cereal::PortableBinaryInputArchive iarchive{iss}; + iarchive(major, minor); + if (major != SYMENGINE_MAJOR_VERSION or minor != SYMENGINE_MINOR_VERSION) { + throw SerializationError(StreamFmt() + << "SymEngine-" << SYMENGINE_MAJOR_VERSION + << "." << SYMENGINE_MINOR_VERSION + << " was asked to deserialize an object " + << "created using SymEngine-" << major << "." + << minor << "."); + } + iarchive(obj); + return obj; +} + } // SymEngine diff --git a/symengine/lib/pywrapper.h b/symengine/lib/pywrapper.h index ba5fe70a8..175cc763c 100644 --- a/symengine/lib/pywrapper.h +++ b/symengine/lib/pywrapper.h @@ -195,6 +195,9 @@ class PyFunction : public FunctionWrapper { virtual hash_t __hash__() const; }; +std::string wrapper_dumps(const Basic &x); +RCP wrapper_loads(const std::string &s); + } #endif //SYMENGINE_PYWRAPPER_H diff --git a/symengine/lib/symengine.pxd b/symengine/lib/symengine.pxd index 5ef0fc1e7..10a3b33c7 100644 --- a/symengine/lib/symengine.pxd +++ b/symengine/lib/symengine.pxd @@ -183,7 +183,6 @@ cdef extern from "" namespace "SymEngine": unsigned int hash() nogil except + vec_basic get_args() nogil int __cmp__(const Basic &o) nogil - string dumps() nogil except + ctypedef RCP[const Number] rcp_const_number "SymEngine::RCP" ctypedef unordered_map[int, rcp_const_basic] umap_int_basic "SymEngine::umap_int_basic" @@ -195,8 +194,6 @@ cdef extern from "" namespace "SymEngine": bool eq(const Basic &a, const Basic &b) nogil except + bool neq(const Basic &a, const Basic &b) nogil except + - RCP[const Basic] loads "SymEngine::Basic::loads"(const string &) nogil except + - RCP[const Symbol] rcp_static_cast_Symbol "SymEngine::rcp_static_cast"(rcp_const_basic &b) nogil RCP[const PySymbol] rcp_static_cast_PySymbol "SymEngine::rcp_static_cast"(rcp_const_basic &b) nogil RCP[const Integer] rcp_static_cast_Integer "SymEngine::rcp_static_cast"(rcp_const_basic &b) nogil @@ -373,6 +370,9 @@ cdef extern from "pywrapper.h" namespace "SymEngine": PySymbol(string name, PyObject* pyobj) PyObject* get_py_object() + string wrapper_dumps(const Basic &x) nogil except + + rcp_const_basic wrapper_loads(const string &s) nogil except + + cdef extern from "" namespace "SymEngine": cdef cppclass Integer(Number): Integer(int i) nogil diff --git a/symengine/lib/symengine_wrapper.pyx b/symengine/lib/symengine_wrapper.pyx index 10c5f06ff..5c5fb669c 100644 --- a/symengine/lib/symengine_wrapper.pyx +++ b/symengine/lib/symengine_wrapper.pyx @@ -827,7 +827,7 @@ cdef list vec_pair_to_list(symengine.vec_pair& vec): def load_basic(bytes s): - return c2py(symengine.loads(s)) + return c2py(symengine.wrapper_loads(s)) repr_latex=[False] @@ -841,7 +841,7 @@ cdef class Basic(object): return self.__str__() def __reduce__(self): - cdef bytes s = deref(self.thisptr).dumps() + cdef bytes s = symengine.wrapper_dumps(deref(self.thisptr)) return (load_basic, (s,)) def _repr_latex_(self): @@ -1231,6 +1231,12 @@ cdef class Symbol(Expr): import sympy return sympy.Symbol(str(self)) + def __reduce__(self): + if type(self) == Symbol: + return Basic.__reduce__(self) + else: + raise NotImplementedError("pickling for Symbol subclass not implemented") + def _sage_(self): import sage.all as sage return sage.SR.symbol(str(self)) diff --git a/symengine/tests/test_pickling.py b/symengine/tests/test_pickling.py index 75c353b60..6e75f133d 100644 --- a/symengine/tests/test_pickling.py +++ b/symengine/tests/test_pickling.py @@ -1,4 +1,4 @@ -from symengine import symbols, sin, sinh, have_numpy, have_llvm, cos +from symengine import symbols, sin, sinh, have_numpy, have_llvm, cos, Symbol import pickle import unittest @@ -11,6 +11,23 @@ def test_basic(): assert expr == expr2 +class MySymbol(Symbol): + def __init__(self, name, attr): + super().__init__(name=name) + self.attr = attr + + def __reduce__(self): + return (self.__class__, (self.name, self.attr)) + + +def test_pysymbol(): + a = MySymbol("hello", attr=1) + b = pickle.loads(pickle.dumps(a)) + assert b.attr == 1 + a._unsafe_reset() + b._unsafe_reset() + + @unittest.skipUnless(have_llvm, "No LLVM support") @unittest.skipUnless(have_numpy, "Numpy not installed") def test_llvm_double(): @@ -23,4 +40,3 @@ def test_llvm_double(): ll = pickle.loads(ss) inp = [1, 2, 3] assert np.allclose(l(inp), ll(inp)) - From 322f2f4f3a096495fcdef545edf8c08a57f50e30 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Wed, 1 Dec 2021 00:01:37 -0800 Subject: [PATCH 3/8] handle errors gracefully --- symengine/lib/pywrapper.cpp | 9 +++++++++ symengine/tests/test_pickling.py | 11 +++++++++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/symengine/lib/pywrapper.cpp b/symengine/lib/pywrapper.cpp index dc714ea69..077bd1ede 100644 --- a/symengine/lib/pywrapper.cpp +++ b/symengine/lib/pywrapper.cpp @@ -275,6 +275,9 @@ inline PyObject* get_pickle_module() { if (module == NULL) { module = PyImport_ImportModule("pickle"); } + if (module == NULL) { + throw SymEngineException("error importing pickle module.") + } return module; } @@ -290,6 +293,9 @@ RCP load_basic(cereal::PortableBinaryInputArchive &ar, RCP result = make_rcp(name, obj); Py_XDECREF(pickle_bytes); return result; @@ -307,6 +313,9 @@ void save_basic(cereal::PortableBinaryOutputArchive &ar, const Symbol &b) RCP p = rcp_static_cast(b.rcp_from_this()); PyObject *module = get_pickle_module(); PyObject *pickle_bytes = PyObject_CallMethod(module, "dumps", "O", p->get_py_object()); + if (pickle_bytes == NULL) { + throw SymEngineException("error when pickling symbol subclass object"); + } Py_ssize_t size; char* buffer; PyBytes_AsStringAndSize(pickle_bytes, &buffer, &size); diff --git a/symengine/tests/test_pickling.py b/symengine/tests/test_pickling.py index 6e75f133d..7c814ea63 100644 --- a/symengine/tests/test_pickling.py +++ b/symengine/tests/test_pickling.py @@ -1,4 +1,5 @@ from symengine import symbols, sin, sinh, have_numpy, have_llvm, cos, Symbol +from symengine.utilities import raises import pickle import unittest @@ -11,22 +12,28 @@ def test_basic(): assert expr == expr2 -class MySymbol(Symbol): +class MySymbolBase(Symbol): def __init__(self, name, attr): super().__init__(name=name) self.attr = attr + +class MySymbol(MySymbolBase): def __reduce__(self): return (self.__class__, (self.name, self.attr)) def test_pysymbol(): a = MySymbol("hello", attr=1) - b = pickle.loads(pickle.dumps(a)) + b = pickle.loads(pickle.dumps(a + 2)) - 2 assert b.attr == 1 a._unsafe_reset() b._unsafe_reset() + a = MySymbolBase("hello", attr=1) + raises(NotImplementedError, lambda: pickle.dumps(a + 2)) + a._unsafe_reset() + @unittest.skipUnless(have_llvm, "No LLVM support") @unittest.skipUnless(have_numpy, "Numpy not installed") From 6aa316b4e55a57e4215b265b793b9a6c03df20d8 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Wed, 1 Dec 2021 00:22:17 -0800 Subject: [PATCH 4/8] improve tests --- symengine/tests/test_pickling.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/symengine/tests/test_pickling.py b/symengine/tests/test_pickling.py index 7c814ea63..4e70c0a63 100644 --- a/symengine/tests/test_pickling.py +++ b/symengine/tests/test_pickling.py @@ -17,6 +17,11 @@ def __init__(self, name, attr): super().__init__(name=name) self.attr = attr + def __eq__(self, other): + if not isinstance(other, MySymbolBase): + return False + return self.name == other.name and self.attr == other.attr + class MySymbol(MySymbolBase): def __reduce__(self): @@ -26,13 +31,18 @@ def __reduce__(self): def test_pysymbol(): a = MySymbol("hello", attr=1) b = pickle.loads(pickle.dumps(a + 2)) - 2 - assert b.attr == 1 - a._unsafe_reset() - b._unsafe_reset() + try: + assert a == b + finally: + a._unsafe_reset() + b._unsafe_reset() a = MySymbolBase("hello", attr=1) - raises(NotImplementedError, lambda: pickle.dumps(a + 2)) - a._unsafe_reset() + try: + raises(NotImplementedError, lambda: pickle.dumps(a)) + raises(NotImplementedError, lambda: pickle.dumps(a + 2)) + finally: + a._unsafe_reset() @unittest.skipUnless(have_llvm, "No LLVM support") From 80b6a6eb48a6872e9c5d8cbac0dff9ae4993f368 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Wed, 1 Dec 2021 00:51:17 -0800 Subject: [PATCH 5/8] bump symengine version --- symengine_version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/symengine_version.txt b/symengine_version.txt index f4d17b3e7..37b09f412 100644 --- a/symengine_version.txt +++ b/symengine_version.txt @@ -1 +1 @@ -4b841d144bbc1ecd4367e4bd7dd4e7c6be8fac05 +36ac51d06e248657d828bfa4859cff32ab5f03ba From 782eef341da16067b1d0cc1257a8bd93820c2912 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Wed, 1 Dec 2021 00:54:11 -0800 Subject: [PATCH 6/8] remove debug line --- symengine/lib/pywrapper.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/symengine/lib/pywrapper.cpp b/symengine/lib/pywrapper.cpp index 077bd1ede..a45a7d237 100644 --- a/symengine/lib/pywrapper.cpp +++ b/symengine/lib/pywrapper.cpp @@ -327,7 +327,6 @@ void save_basic(cereal::PortableBinaryOutputArchive &ar, const Symbol &b) std::string wrapper_dumps(const Basic &x) { - std::cout << "qwe" << std::endl; std::ostringstream oss; unsigned short major = SYMENGINE_MAJOR_VERSION; unsigned short minor = SYMENGINE_MINOR_VERSION; From b6366236a01f77385f13f933566c27134d7cba8c Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Wed, 1 Dec 2021 00:55:15 -0800 Subject: [PATCH 7/8] fix typo --- symengine/lib/pywrapper.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/symengine/lib/pywrapper.cpp b/symengine/lib/pywrapper.cpp index a45a7d237..234c481a6 100644 --- a/symengine/lib/pywrapper.cpp +++ b/symengine/lib/pywrapper.cpp @@ -276,7 +276,7 @@ inline PyObject* get_pickle_module() { module = PyImport_ImportModule("pickle"); } if (module == NULL) { - throw SymEngineException("error importing pickle module.") + throw SymEngineException("error importing pickle module."); } return module; } From 1ca9c70ac72fe8807d656d39dca4c8bab9a6283b Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Wed, 1 Dec 2021 16:12:37 -0600 Subject: [PATCH 8/8] Throw SerializationError instead of SymEngineException --- symengine/lib/pywrapper.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/symengine/lib/pywrapper.cpp b/symengine/lib/pywrapper.cpp index 234c481a6..aaf128d28 100644 --- a/symengine/lib/pywrapper.cpp +++ b/symengine/lib/pywrapper.cpp @@ -294,7 +294,7 @@ RCP load_basic(cereal::PortableBinaryInputArchive &ar, RCP result = make_rcp(name, obj); Py_XDECREF(pickle_bytes); @@ -314,7 +314,7 @@ void save_basic(cereal::PortableBinaryOutputArchive &ar, const Symbol &b) PyObject *module = get_pickle_module(); PyObject *pickle_bytes = PyObject_CallMethod(module, "dumps", "O", p->get_py_object()); if (pickle_bytes == NULL) { - throw SymEngineException("error when pickling symbol subclass object"); + throw SerializationError("error when pickling symbol subclass object"); } Py_ssize_t size; char* buffer;