diff --git a/symengine/lib/pywrapper.cpp b/symengine/lib/pywrapper.cpp index aaf128d2..4f45a0b7 100644 --- a/symengine/lib/pywrapper.cpp +++ b/symengine/lib/pywrapper.cpp @@ -281,29 +281,49 @@ inline PyObject* get_pickle_module() { return module; } +PyObject* pickle_loads(const std::string &pickle_str) { + PyObject *module = get_pickle_module(); + PyObject *pickle_bytes = PyBytes_FromStringAndSize(pickle_str.data(), pickle_str.size()); + PyObject *obj = PyObject_CallMethod(module, "loads", "O", pickle_bytes); + Py_XDECREF(pickle_bytes); + if (obj == NULL) { + throw SerializationError("error when loading pickled symbol subclass object"); + } + return obj; +} + RCP load_basic(cereal::PortableBinaryInputArchive &ar, RCP &) { bool is_pysymbol; + bool store_pickle; std::string name; ar(is_pysymbol); ar(name); if (is_pysymbol) { std::string pickle_str; ar(pickle_str); - PyObject *module = get_pickle_module(); - PyObject *pickle_bytes = PyBytes_FromStringAndSize(pickle_str.data(), pickle_str.size()); - PyObject *obj = PyObject_CallMethod(module, "loads", "O", pickle_bytes); - if (obj == NULL) { - throw SerializationError("error when loading pickled symbol subclass object"); - } - RCP result = make_rcp(name, obj); - Py_XDECREF(pickle_bytes); + ar(store_pickle); + PyObject *obj = pickle_loads(pickle_str); + RCP result = make_rcp(name, obj, store_pickle); + Py_XDECREF(obj); return result; } else { return symbol(name); } } +std::string pickle_dumps(const PyObject * obj) { + PyObject *module = get_pickle_module(); + PyObject *pickle_bytes = PyObject_CallMethod(module, "dumps", "O", obj); + if (pickle_bytes == NULL) { + throw SerializationError("error when pickling symbol subclass object"); + } + Py_ssize_t size; + char* buffer; + PyBytes_AsStringAndSize(pickle_bytes, &buffer, &size); + return std::string(buffer, size); +} + void save_basic(cereal::PortableBinaryOutputArchive &ar, const Symbol &b) { bool is_pysymbol = is_a_sub(b); @@ -311,17 +331,11 @@ void save_basic(cereal::PortableBinaryOutputArchive &ar, const Symbol &b) ar(b.__str__()); if (is_pysymbol) { RCP p = rcp_static_cast(b.rcp_from_this()); - PyObject *module = get_pickle_module(); - PyObject *pickle_bytes = PyObject_CallMethod(module, "dumps", "O", p->get_py_object()); - if (pickle_bytes == NULL) { - throw SerializationError("error when pickling symbol subclass object"); - } - Py_ssize_t size; - char* buffer; - PyBytes_AsStringAndSize(pickle_bytes, &buffer, &size); - std::string pickle_str(buffer, size); + PyObject *obj = p->get_py_object(); + std::string pickle_str = pickle_dumps(obj); ar(pickle_str); - Py_XDECREF(pickle_bytes); + ar(p->store_pickle); + Py_XDECREF(obj); } } diff --git a/symengine/lib/pywrapper.h b/symengine/lib/pywrapper.h index 175cc763..20a6dbee 100644 --- a/symengine/lib/pywrapper.h +++ b/symengine/lib/pywrapper.h @@ -8,6 +8,9 @@ namespace SymEngine { +std::string pickle_dumps(const PyObject *); +PyObject* pickle_loads(const std::string &); + /* * PySymbol is a subclass of Symbol that keeps a reference to a Python object. * When subclassing a Symbol from Python, the information stored in subclassed @@ -27,16 +30,30 @@ namespace SymEngine { class PySymbol : public Symbol { private: PyObject* obj; + std::string bytes; public: - PySymbol(const std::string& name, PyObject* obj) : Symbol(name), obj(obj) { - Py_INCREF(obj); + const bool store_pickle; + PySymbol(const std::string& name, PyObject* obj, bool store_pickle) : + Symbol(name), obj(obj), store_pickle(store_pickle) { + if (store_pickle) { + bytes = pickle_dumps(obj); + } else { + Py_INCREF(obj); + } } PyObject* get_py_object() const { - return obj; + if (store_pickle) { + return pickle_loads(bytes); + } else { + Py_INCREF(obj); + return obj; + } } virtual ~PySymbol() { - // TODO: This is never called because of the cyclic reference. - Py_DECREF(obj); + if (not store_pickle) { + // TODO: This is never called because of the cyclic reference. + Py_DECREF(obj); + } } }; diff --git a/symengine/lib/symengine.pxd b/symengine/lib/symengine.pxd index ef94aa19..06353c25 100644 --- a/symengine/lib/symengine.pxd +++ b/symengine/lib/symengine.pxd @@ -195,7 +195,7 @@ cdef extern from "" namespace "SymEngine": bool neq(const Basic &a, const Basic &b) nogil except + RCP[const Symbol] rcp_static_cast_Symbol "SymEngine::rcp_static_cast"(rcp_const_basic &b) nogil - RCP[const PySymbol] rcp_static_cast_PySymbol "SymEngine::rcp_static_cast"(rcp_const_basic &b) nogil + RCP[const PySymbol] rcp_static_cast_PySymbol "SymEngine::rcp_static_cast"(rcp_const_basic &b) nogil except + RCP[const Integer] rcp_static_cast_Integer "SymEngine::rcp_static_cast"(rcp_const_basic &b) nogil RCP[const Rational] rcp_static_cast_Rational "SymEngine::rcp_static_cast"(rcp_const_basic &b) nogil RCP[const Complex] rcp_static_cast_Complex "SymEngine::rcp_static_cast"(rcp_const_basic &b) nogil @@ -367,8 +367,8 @@ cdef extern from "pywrapper.h" namespace "SymEngine": cdef extern from "pywrapper.h" namespace "SymEngine": cdef cppclass PySymbol(Symbol): - PySymbol(string name, PyObject* pyobj) - PyObject* get_py_object() + PySymbol(string name, PyObject* pyobj, bool use_pickle) except + + PyObject* get_py_object() except + string wrapper_dumps(const Basic &x) nogil except + rcp_const_basic wrapper_loads(const string &s) nogil except + @@ -477,7 +477,7 @@ cdef extern from "" namespace "SymEngine": rcp_const_basic make_rcp_Symbol "SymEngine::make_rcp"(string name) nogil rcp_const_basic make_rcp_Dummy "SymEngine::make_rcp"() nogil rcp_const_basic make_rcp_Dummy "SymEngine::make_rcp"(string name) nogil - rcp_const_basic make_rcp_PySymbol "SymEngine::make_rcp"(string name, PyObject * pyobj) nogil + rcp_const_basic make_rcp_PySymbol "SymEngine::make_rcp"(string name, PyObject * pyobj, bool use_pickle) except + rcp_const_basic make_rcp_Constant "SymEngine::make_rcp"(string name) nogil rcp_const_basic make_rcp_Infty "SymEngine::make_rcp"(RCP[const Number] i) nogil rcp_const_basic make_rcp_NaN "SymEngine::make_rcp"() nogil diff --git a/symengine/lib/symengine_wrapper.pyx b/symengine/lib/symengine_wrapper.pyx index 5b2858be..09ec7e46 100644 --- a/symengine/lib/symengine_wrapper.pyx +++ b/symengine/lib/symengine_wrapper.pyx @@ -46,6 +46,7 @@ cpdef void assign_to_capsule(object capsule, object value): cdef object c2py(rcp_const_basic o): cdef Basic r + cdef PyObject *obj if (symengine.is_a_Add(deref(o))): r = Expr.__new__(Add) elif (symengine.is_a_Mul(deref(o))): @@ -74,7 +75,10 @@ cdef object c2py(rcp_const_basic o): r = Dummy.__new__(Dummy) elif (symengine.is_a_Symbol(deref(o))): if (symengine.is_a_PySymbol(deref(o))): - return (deref(symengine.rcp_static_cast_PySymbol(o)).get_py_object()) + obj = deref(symengine.rcp_static_cast_PySymbol(o)).get_py_object() + result = (obj) + Py_XDECREF(obj); + return result r = Symbol.__new__(Symbol) elif (symengine.is_a_Constant(deref(o))): r = S.Pi @@ -1216,16 +1220,26 @@ cdef class Expr(Basic): cdef class Symbol(Expr): - """ Symbol is a class to store a symbolic variable with a given name. + Subclassing Symbol leads to a memory leak due to a cycle in reference counting. + To avoid this with a performance penalty, set the kwarg store_pickle=True + in the constructor and support the pickle protocol in the subclass by + implmenting __reduce__. """ def __init__(Basic self, name, *args, **kwargs): + cdef cppbool store_pickle; if type(self) == Symbol: self.thisptr = symengine.make_rcp_Symbol(name.encode("utf-8")) else: - self.thisptr = symengine.make_rcp_PySymbol(name.encode("utf-8"), self) + store_pickle = kwargs.pop("store_pickle", False) + if store_pickle: + # First set the pointer to a regular symbol so that when pickle.dumps + # is called when the PySymbol is created, methods like name works. + self.thisptr = symengine.make_rcp_Symbol(name.encode("utf-8")) + self.thisptr = symengine.make_rcp_PySymbol(name.encode("utf-8"), self, + store_pickle) def _sympy_(self): import sympy