Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Support pickling of Basic objects #377

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Dec 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions symengine/lib/pywrapper.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "pywrapper.h"
#include <symengine/serialize-cereal.h>

#if PY_MAJOR_VERSION >= 3
#define PyInt_FromLong PyLong_FromLong
Expand Down Expand Up @@ -269,4 +270,88 @@ int PyFunction::compare(const Basic &o) const {
return unified_compare(get_vec(), s.get_vec());
}

inline PyObject* get_pickle_module() {
static PyObject *module = NULL;
if (module == NULL) {
module = PyImport_ImportModule("pickle");
}
if (module == NULL) {
throw SymEngineException("error importing pickle module.");
}
return module;
}

RCP<const Basic> load_basic(cereal::PortableBinaryInputArchive &ar, RCP<const Symbol> &)
{
bool is_pysymbol;
std::string name;
ar(is_pysymbol);
ar(name);
if (is_pysymbol) {
std::string pickle_str;
ar(pickle_str);
PyObject *module = get_pickle_module();
PyObject *pickle_bytes = PyBytes_FromStringAndSize(pickle_str.data(), pickle_str.size());
PyObject *obj = PyObject_CallMethod(module, "loads", "O", pickle_bytes);
if (obj == NULL) {
throw SerializationError("error when loading pickled symbol subclass object");
}
RCP<const Basic> result = make_rcp<PySymbol>(name, obj);
Py_XDECREF(pickle_bytes);
return result;
} else {
return symbol(name);
}
}

void save_basic(cereal::PortableBinaryOutputArchive &ar, const Symbol &b)
{
bool is_pysymbol = is_a_sub<PySymbol>(b);
ar(is_pysymbol);
ar(b.__str__());
if (is_pysymbol) {
RCP<const PySymbol> p = rcp_static_cast<const PySymbol>(b.rcp_from_this());
PyObject *module = get_pickle_module();
PyObject *pickle_bytes = PyObject_CallMethod(module, "dumps", "O", p->get_py_object());
if (pickle_bytes == NULL) {
throw SerializationError("error when pickling symbol subclass object");
}
Py_ssize_t size;
char* buffer;
PyBytes_AsStringAndSize(pickle_bytes, &buffer, &size);
std::string pickle_str(buffer, size);
ar(pickle_str);
Py_XDECREF(pickle_bytes);
}
}

std::string wrapper_dumps(const Basic &x)
{
std::ostringstream oss;
unsigned short major = SYMENGINE_MAJOR_VERSION;
unsigned short minor = SYMENGINE_MINOR_VERSION;
cereal::PortableBinaryOutputArchive{oss}(major, minor,
x.rcp_from_this());
return oss.str();
}

RCP<const Basic> wrapper_loads(const std::string &serialized)
{
unsigned short major, minor;
RCP<const Basic> obj;
std::istringstream iss(serialized);
cereal::PortableBinaryInputArchive iarchive{iss};
iarchive(major, minor);
if (major != SYMENGINE_MAJOR_VERSION or minor != SYMENGINE_MINOR_VERSION) {
throw SerializationError(StreamFmt()
<< "SymEngine-" << SYMENGINE_MAJOR_VERSION
<< "." << SYMENGINE_MINOR_VERSION
<< " was asked to deserialize an object "
<< "created using SymEngine-" << major << "."
<< minor << ".");
}
iarchive(obj);
return obj;
}

} // SymEngine
3 changes: 3 additions & 0 deletions symengine/lib/pywrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,9 @@ class PyFunction : public FunctionWrapper {
virtual hash_t __hash__() const;
};

std::string wrapper_dumps(const Basic &x);
RCP<const Basic> wrapper_loads(const std::string &s);

}

#endif //SYMENGINE_PYWRAPPER_H
5 changes: 4 additions & 1 deletion symengine/lib/symengine.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
unsigned int hash() nogil except +
vec_basic get_args() nogil
int __cmp__(const Basic &o) nogil

ctypedef RCP[const Number] rcp_const_number "SymEngine::RCP<const SymEngine::Number>"
ctypedef unordered_map[int, rcp_const_basic] umap_int_basic "SymEngine::umap_int_basic"
ctypedef unordered_map[int, rcp_const_basic].iterator umap_int_basic_iterator "SymEngine::umap_int_basic::iterator"
Expand All @@ -193,7 +194,6 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
bool eq(const Basic &a, const Basic &b) nogil except +
bool neq(const Basic &a, const Basic &b) nogil except +


RCP[const Symbol] rcp_static_cast_Symbol "SymEngine::rcp_static_cast<const SymEngine::Symbol>"(rcp_const_basic &b) nogil
RCP[const PySymbol] rcp_static_cast_PySymbol "SymEngine::rcp_static_cast<const SymEngine::PySymbol>"(rcp_const_basic &b) nogil
RCP[const Integer] rcp_static_cast_Integer "SymEngine::rcp_static_cast<const SymEngine::Integer>"(rcp_const_basic &b) nogil
Expand Down Expand Up @@ -370,6 +370,9 @@ cdef extern from "pywrapper.h" namespace "SymEngine":
PySymbol(string name, PyObject* pyobj)
PyObject* get_py_object()

string wrapper_dumps(const Basic &x) nogil except +
rcp_const_basic wrapper_loads(const string &s) nogil except +

cdef extern from "<symengine/integer.h>" namespace "SymEngine":
cdef cppclass Integer(Number):
Integer(int i) nogil
Expand Down
14 changes: 14 additions & 0 deletions symengine/lib/symengine_wrapper.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,10 @@ cdef list vec_pair_to_list(symengine.vec_pair& vec):
return result


def load_basic(bytes s):
return c2py(symengine.wrapper_loads(s))


repr_latex=[False]

cdef class Basic(object):
Expand All @@ -836,6 +840,10 @@ cdef class Basic(object):
def __repr__(self):
return self.__str__()

def __reduce__(self):
cdef bytes s = symengine.wrapper_dumps(deref(self.thisptr))
return (load_basic, (s,))

def _repr_latex_(self):
if repr_latex[0]:
return "${}$".format(latex(self))
Expand Down Expand Up @@ -1223,6 +1231,12 @@ cdef class Symbol(Expr):
import sympy
return sympy.Symbol(str(self))

def __reduce__(self):
if type(self) == Symbol:
return Basic.__reduce__(self)
else:
raise NotImplementedError("pickling for Symbol subclass not implemented")

def _sage_(self):
import sage.all as sage
return sage.SR.symbol(str(self))
Expand Down
46 changes: 44 additions & 2 deletions symengine/tests/test_pickling.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,50 @@
from symengine import symbols, sin, sinh, have_numpy, have_llvm
from symengine import symbols, sin, sinh, have_numpy, have_llvm, cos, Symbol
from symengine.utilities import raises
import pickle
import unittest


def test_basic():
x, y, z = symbols('x y z')
expr = sin(cos(x + y)/z)**2
s = pickle.dumps(expr)
expr2 = pickle.loads(s)
assert expr == expr2


class MySymbolBase(Symbol):
def __init__(self, name, attr):
super().__init__(name=name)
self.attr = attr

def __eq__(self, other):
if not isinstance(other, MySymbolBase):
return False
return self.name == other.name and self.attr == other.attr


class MySymbol(MySymbolBase):
def __reduce__(self):
return (self.__class__, (self.name, self.attr))


def test_pysymbol():
a = MySymbol("hello", attr=1)
b = pickle.loads(pickle.dumps(a + 2)) - 2
try:
assert a == b
finally:
a._unsafe_reset()
b._unsafe_reset()

a = MySymbolBase("hello", attr=1)
try:
raises(NotImplementedError, lambda: pickle.dumps(a))
raises(NotImplementedError, lambda: pickle.dumps(a + 2))
finally:
a._unsafe_reset()


@unittest.skipUnless(have_llvm, "No LLVM support")
@unittest.skipUnless(have_numpy, "Numpy not installed")
def test_llvm_double():
Expand All @@ -14,4 +57,3 @@ def test_llvm_double():
ll = pickle.loads(ss)
inp = [1, 2, 3]
assert np.allclose(l(inp), ll(inp))

2 changes: 1 addition & 1 deletion symengine_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
23abf31763620463500d5fad114d855afd66d011
36ac51d06e248657d828bfa4859cff32ab5f03ba