diff --git a/symengine/lib/pywrapper.cpp b/symengine/lib/pywrapper.cpp index c321ea39..aaf128d2 100644 --- a/symengine/lib/pywrapper.cpp +++ b/symengine/lib/pywrapper.cpp @@ -1,4 +1,5 @@ #include "pywrapper.h" +#include #if PY_MAJOR_VERSION >= 3 #define PyInt_FromLong PyLong_FromLong @@ -269,4 +270,88 @@ int PyFunction::compare(const Basic &o) const { return unified_compare(get_vec(), s.get_vec()); } +inline PyObject* get_pickle_module() { + static PyObject *module = NULL; + if (module == NULL) { + module = PyImport_ImportModule("pickle"); + } + if (module == NULL) { + throw SymEngineException("error importing pickle module."); + } + return module; +} + +RCP load_basic(cereal::PortableBinaryInputArchive &ar, RCP &) +{ + bool is_pysymbol; + std::string name; + ar(is_pysymbol); + ar(name); + if (is_pysymbol) { + std::string pickle_str; + ar(pickle_str); + PyObject *module = get_pickle_module(); + PyObject *pickle_bytes = PyBytes_FromStringAndSize(pickle_str.data(), pickle_str.size()); + PyObject *obj = PyObject_CallMethod(module, "loads", "O", pickle_bytes); + if (obj == NULL) { + throw SerializationError("error when loading pickled symbol subclass object"); + } + RCP result = make_rcp(name, obj); + Py_XDECREF(pickle_bytes); + return result; + } else { + return symbol(name); + } +} + +void save_basic(cereal::PortableBinaryOutputArchive &ar, const Symbol &b) +{ + bool is_pysymbol = is_a_sub(b); + ar(is_pysymbol); + ar(b.__str__()); + if (is_pysymbol) { + RCP p = rcp_static_cast(b.rcp_from_this()); + PyObject *module = get_pickle_module(); + PyObject *pickle_bytes = PyObject_CallMethod(module, "dumps", "O", p->get_py_object()); + if (pickle_bytes == NULL) { + throw SerializationError("error when pickling symbol subclass object"); + } + Py_ssize_t size; + char* buffer; + PyBytes_AsStringAndSize(pickle_bytes, &buffer, &size); + std::string pickle_str(buffer, size); + ar(pickle_str); + Py_XDECREF(pickle_bytes); + } +} + +std::string wrapper_dumps(const Basic &x) +{ + std::ostringstream oss; + unsigned short major = SYMENGINE_MAJOR_VERSION; + unsigned short minor = SYMENGINE_MINOR_VERSION; + cereal::PortableBinaryOutputArchive{oss}(major, minor, + x.rcp_from_this()); + return oss.str(); +} + +RCP wrapper_loads(const std::string &serialized) +{ + unsigned short major, minor; + RCP obj; + std::istringstream iss(serialized); + cereal::PortableBinaryInputArchive iarchive{iss}; + iarchive(major, minor); + if (major != SYMENGINE_MAJOR_VERSION or minor != SYMENGINE_MINOR_VERSION) { + throw SerializationError(StreamFmt() + << "SymEngine-" << SYMENGINE_MAJOR_VERSION + << "." << SYMENGINE_MINOR_VERSION + << " was asked to deserialize an object " + << "created using SymEngine-" << major << "." + << minor << "."); + } + iarchive(obj); + return obj; +} + } // SymEngine diff --git a/symengine/lib/pywrapper.h b/symengine/lib/pywrapper.h index ba5fe70a..175cc763 100644 --- a/symengine/lib/pywrapper.h +++ b/symengine/lib/pywrapper.h @@ -195,6 +195,9 @@ class PyFunction : public FunctionWrapper { virtual hash_t __hash__() const; }; +std::string wrapper_dumps(const Basic &x); +RCP wrapper_loads(const std::string &s); + } #endif //SYMENGINE_PYWRAPPER_H diff --git a/symengine/lib/symengine.pxd b/symengine/lib/symengine.pxd index 9d86d983..10a3b33c 100644 --- a/symengine/lib/symengine.pxd +++ b/symengine/lib/symengine.pxd @@ -183,6 +183,7 @@ cdef extern from "" namespace "SymEngine": unsigned int hash() nogil except + vec_basic get_args() nogil int __cmp__(const Basic &o) nogil + ctypedef RCP[const Number] rcp_const_number "SymEngine::RCP" ctypedef unordered_map[int, rcp_const_basic] umap_int_basic "SymEngine::umap_int_basic" ctypedef unordered_map[int, rcp_const_basic].iterator umap_int_basic_iterator "SymEngine::umap_int_basic::iterator" @@ -193,7 +194,6 @@ cdef extern from "" namespace "SymEngine": bool eq(const Basic &a, const Basic &b) nogil except + bool neq(const Basic &a, const Basic &b) nogil except + - RCP[const Symbol] rcp_static_cast_Symbol "SymEngine::rcp_static_cast"(rcp_const_basic &b) nogil RCP[const PySymbol] rcp_static_cast_PySymbol "SymEngine::rcp_static_cast"(rcp_const_basic &b) nogil RCP[const Integer] rcp_static_cast_Integer "SymEngine::rcp_static_cast"(rcp_const_basic &b) nogil @@ -370,6 +370,9 @@ cdef extern from "pywrapper.h" namespace "SymEngine": PySymbol(string name, PyObject* pyobj) PyObject* get_py_object() + string wrapper_dumps(const Basic &x) nogil except + + rcp_const_basic wrapper_loads(const string &s) nogil except + + cdef extern from "" namespace "SymEngine": cdef cppclass Integer(Number): Integer(int i) nogil diff --git a/symengine/lib/symengine_wrapper.pyx b/symengine/lib/symengine_wrapper.pyx index d66af058..5c5fb669 100644 --- a/symengine/lib/symengine_wrapper.pyx +++ b/symengine/lib/symengine_wrapper.pyx @@ -826,6 +826,10 @@ cdef list vec_pair_to_list(symengine.vec_pair& vec): return result +def load_basic(bytes s): + return c2py(symengine.wrapper_loads(s)) + + repr_latex=[False] cdef class Basic(object): @@ -836,6 +840,10 @@ cdef class Basic(object): def __repr__(self): return self.__str__() + def __reduce__(self): + cdef bytes s = symengine.wrapper_dumps(deref(self.thisptr)) + return (load_basic, (s,)) + def _repr_latex_(self): if repr_latex[0]: return "${}$".format(latex(self)) @@ -1223,6 +1231,12 @@ cdef class Symbol(Expr): import sympy return sympy.Symbol(str(self)) + def __reduce__(self): + if type(self) == Symbol: + return Basic.__reduce__(self) + else: + raise NotImplementedError("pickling for Symbol subclass not implemented") + def _sage_(self): import sage.all as sage return sage.SR.symbol(str(self)) diff --git a/symengine/tests/test_pickling.py b/symengine/tests/test_pickling.py index b7f14a18..4e70c0a6 100644 --- a/symengine/tests/test_pickling.py +++ b/symengine/tests/test_pickling.py @@ -1,7 +1,50 @@ -from symengine import symbols, sin, sinh, have_numpy, have_llvm +from symengine import symbols, sin, sinh, have_numpy, have_llvm, cos, Symbol +from symengine.utilities import raises import pickle import unittest + +def test_basic(): + x, y, z = symbols('x y z') + expr = sin(cos(x + y)/z)**2 + s = pickle.dumps(expr) + expr2 = pickle.loads(s) + assert expr == expr2 + + +class MySymbolBase(Symbol): + def __init__(self, name, attr): + super().__init__(name=name) + self.attr = attr + + def __eq__(self, other): + if not isinstance(other, MySymbolBase): + return False + return self.name == other.name and self.attr == other.attr + + +class MySymbol(MySymbolBase): + def __reduce__(self): + return (self.__class__, (self.name, self.attr)) + + +def test_pysymbol(): + a = MySymbol("hello", attr=1) + b = pickle.loads(pickle.dumps(a + 2)) - 2 + try: + assert a == b + finally: + a._unsafe_reset() + b._unsafe_reset() + + a = MySymbolBase("hello", attr=1) + try: + raises(NotImplementedError, lambda: pickle.dumps(a)) + raises(NotImplementedError, lambda: pickle.dumps(a + 2)) + finally: + a._unsafe_reset() + + @unittest.skipUnless(have_llvm, "No LLVM support") @unittest.skipUnless(have_numpy, "Numpy not installed") def test_llvm_double(): @@ -14,4 +57,3 @@ def test_llvm_double(): ll = pickle.loads(ss) inp = [1, 2, 3] assert np.allclose(l(inp), ll(inp)) - diff --git a/symengine_version.txt b/symengine_version.txt index 845dbe7f..37b09f41 100644 --- a/symengine_version.txt +++ b/symengine_version.txt @@ -1 +1 @@ -23abf31763620463500d5fad114d855afd66d011 +36ac51d06e248657d828bfa4859cff32ab5f03ba