Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 9d19657

Browse files
Baggins800Benno Evers
authored and
Benno Evers
committed
Add 3D scatter plots, allow more than one 3d plot on the same figure and make rcparams changeable.
1 parent 80bc9cd commit 9d19657

File tree

1 file changed

+167
-18
lines changed

1 file changed

+167
-18
lines changed

matplotlibcpp.h

+167-18
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ struct _interpreter {
9999
PyObject *s_python_function_barh;
100100
PyObject *s_python_function_colorbar;
101101
PyObject *s_python_function_subplots_adjust;
102+
PyObject *s_python_function_rcparams;
102103

103104

104105
/* For now, _interpreter is implemented as a singleton since its currently not possible to have
@@ -189,6 +190,7 @@ struct _interpreter {
189190
}
190191

191192
PyObject* matplotlib = PyImport_Import(matplotlibname);
193+
192194
Py_DECREF(matplotlibname);
193195
if (!matplotlib) {
194196
PyErr_Print();
@@ -201,6 +203,8 @@ struct _interpreter {
201203
PyObject_CallMethod(matplotlib, const_cast<char*>("use"), const_cast<char*>("s"), s_backend.c_str());
202204
}
203205

206+
207+
204208
PyObject* pymod = PyImport_Import(pyplotname);
205209
Py_DECREF(pyplotname);
206210
if (!pymod) { throw std::runtime_error("Error loading module matplotlib.pyplot!"); }
@@ -264,6 +268,7 @@ struct _interpreter {
264268
s_python_function_barh = safe_import(pymod, "barh");
265269
s_python_function_colorbar = PyObject_GetAttrString(pymod, "colorbar");
266270
s_python_function_subplots_adjust = safe_import(pymod,"subplots_adjust");
271+
s_python_function_rcparams = PyObject_GetAttrString(pymod, "rcParams");
267272
#ifndef WITHOUT_NUMPY
268273
s_python_function_imshow = safe_import(pymod, "imshow");
269274
#endif
@@ -464,6 +469,7 @@ template <typename Numeric>
464469
void plot_surface(const std::vector<::std::vector<Numeric>> &x,
465470
const std::vector<::std::vector<Numeric>> &y,
466471
const std::vector<::std::vector<Numeric>> &z,
472+
const long fig_number=0,
467473
const std::map<std::string, std::string> &keywords =
468474
std::map<std::string, std::string>())
469475
{
@@ -516,14 +522,29 @@ void plot_surface(const std::vector<::std::vector<Numeric>> &x,
516522

517523
for (std::map<std::string, std::string>::const_iterator it = keywords.begin();
518524
it != keywords.end(); ++it) {
519-
PyDict_SetItemString(kwargs, it->first.c_str(),
520-
PyString_FromString(it->second.c_str()));
525+
if (it->first == "linewidth" || it->first == "alpha") {
526+
PyDict_SetItemString(kwargs, it->first.c_str(),
527+
PyFloat_FromDouble(std::stod(it->second)));
528+
} else {
529+
PyDict_SetItemString(kwargs, it->first.c_str(),
530+
PyString_FromString(it->second.c_str()));
531+
}
521532
}
522533

523-
524-
PyObject *fig =
525-
PyObject_CallObject(detail::_interpreter::get().s_python_function_figure,
526-
detail::_interpreter::get().s_python_empty_tuple);
534+
PyObject *fig_args = PyTuple_New(1);
535+
PyObject* fig = nullptr;
536+
PyTuple_SetItem(fig_args, 0, PyLong_FromLong(fig_number));
537+
PyObject *fig_exists =
538+
PyObject_CallObject(
539+
detail::_interpreter::get().s_python_function_fignum_exists, fig_args);
540+
if (!PyObject_IsTrue(fig_exists)) {
541+
fig = PyObject_CallObject(detail::_interpreter::get().s_python_function_figure,
542+
detail::_interpreter::get().s_python_empty_tuple);
543+
} else {
544+
fig = PyObject_CallObject(detail::_interpreter::get().s_python_function_figure,
545+
fig_args);
546+
}
547+
Py_DECREF(fig_exists);
527548
if (!fig) throw std::runtime_error("Call to figure() failed.");
528549

529550
PyObject *gca_kwargs = PyDict_New();
@@ -559,6 +580,7 @@ template <typename Numeric>
559580
void plot3(const std::vector<Numeric> &x,
560581
const std::vector<Numeric> &y,
561582
const std::vector<Numeric> &z,
583+
const long fig_number=0,
562584
const std::map<std::string, std::string> &keywords =
563585
std::map<std::string, std::string>())
564586
{
@@ -607,9 +629,18 @@ void plot3(const std::vector<Numeric> &x,
607629
PyString_FromString(it->second.c_str()));
608630
}
609631

610-
PyObject *fig =
611-
PyObject_CallObject(detail::_interpreter::get().s_python_function_figure,
612-
detail::_interpreter::get().s_python_empty_tuple);
632+
PyObject *fig_args = PyTuple_New(1);
633+
PyObject* fig = nullptr;
634+
PyTuple_SetItem(fig_args, 0, PyLong_FromLong(fig_number));
635+
PyObject *fig_exists =
636+
PyObject_CallObject(detail::_interpreter::get().s_python_function_fignum_exists, fig_args);
637+
if (!PyObject_IsTrue(fig_exists)) {
638+
fig = PyObject_CallObject(detail::_interpreter::get().s_python_function_figure,
639+
detail::_interpreter::get().s_python_empty_tuple);
640+
} else {
641+
fig = PyObject_CallObject(detail::_interpreter::get().s_python_function_figure,
642+
fig_args);
643+
}
613644
if (!fig) throw std::runtime_error("Call to figure() failed.");
614645

615646
PyObject *gca_kwargs = PyDict_New();
@@ -911,6 +942,103 @@ bool scatter(const std::vector<NumericX>& x,
911942
return res;
912943
}
913944

945+
template<typename NumericX, typename NumericY, typename NumericZ>
946+
bool scatter(const std::vector<NumericX>& x,
947+
const std::vector<NumericY>& y,
948+
const std::vector<NumericZ>& z,
949+
const double s=1.0, // The marker size in points**2
950+
const long fig_number=0,
951+
const std::map<std::string, std::string> & keywords = {}) {
952+
detail::_interpreter::get();
953+
954+
// Same as with plot_surface: We lazily load the modules here the first time
955+
// this function is called because I'm not sure that we can assume "matplotlib
956+
// installed" implies "mpl_toolkits installed" on all platforms, and we don't
957+
// want to require it for people who don't need 3d plots.
958+
static PyObject *mpl_toolkitsmod = nullptr, *axis3dmod = nullptr;
959+
if (!mpl_toolkitsmod) {
960+
detail::_interpreter::get();
961+
962+
PyObject* mpl_toolkits = PyString_FromString("mpl_toolkits");
963+
PyObject* axis3d = PyString_FromString("mpl_toolkits.mplot3d");
964+
if (!mpl_toolkits || !axis3d) { throw std::runtime_error("couldnt create string"); }
965+
966+
mpl_toolkitsmod = PyImport_Import(mpl_toolkits);
967+
Py_DECREF(mpl_toolkits);
968+
if (!mpl_toolkitsmod) { throw std::runtime_error("Error loading module mpl_toolkits!"); }
969+
970+
axis3dmod = PyImport_Import(axis3d);
971+
Py_DECREF(axis3d);
972+
if (!axis3dmod) { throw std::runtime_error("Error loading module mpl_toolkits.mplot3d!"); }
973+
}
974+
975+
assert(x.size() == y.size());
976+
assert(y.size() == z.size());
977+
978+
PyObject *xarray = detail::get_array(x);
979+
PyObject *yarray = detail::get_array(y);
980+
PyObject *zarray = detail::get_array(z);
981+
982+
// construct positional args
983+
PyObject *args = PyTuple_New(3);
984+
PyTuple_SetItem(args, 0, xarray);
985+
PyTuple_SetItem(args, 1, yarray);
986+
PyTuple_SetItem(args, 2, zarray);
987+
988+
// Build up the kw args.
989+
PyObject *kwargs = PyDict_New();
990+
991+
for (std::map<std::string, std::string>::const_iterator it = keywords.begin();
992+
it != keywords.end(); ++it) {
993+
PyDict_SetItemString(kwargs, it->first.c_str(),
994+
PyString_FromString(it->second.c_str()));
995+
}
996+
PyObject *fig_args = PyTuple_New(1);
997+
PyObject* fig = nullptr;
998+
PyTuple_SetItem(fig_args, 0, PyLong_FromLong(fig_number));
999+
PyObject *fig_exists =
1000+
PyObject_CallObject(detail::_interpreter::get().s_python_function_fignum_exists, fig_args);
1001+
if (!PyObject_IsTrue(fig_exists)) {
1002+
fig = PyObject_CallObject(detail::_interpreter::get().s_python_function_figure,
1003+
detail::_interpreter::get().s_python_empty_tuple);
1004+
} else {
1005+
fig = PyObject_CallObject(detail::_interpreter::get().s_python_function_figure,
1006+
fig_args);
1007+
}
1008+
Py_DECREF(fig_exists);
1009+
if (!fig) throw std::runtime_error("Call to figure() failed.");
1010+
1011+
PyObject *gca_kwargs = PyDict_New();
1012+
PyDict_SetItemString(gca_kwargs, "projection", PyString_FromString("3d"));
1013+
1014+
PyObject *gca = PyObject_GetAttrString(fig, "gca");
1015+
if (!gca) throw std::runtime_error("No gca");
1016+
Py_INCREF(gca);
1017+
PyObject *axis = PyObject_Call(
1018+
gca, detail::_interpreter::get().s_python_empty_tuple, gca_kwargs);
1019+
1020+
if (!axis) throw std::runtime_error("No axis");
1021+
Py_INCREF(axis);
1022+
1023+
Py_DECREF(gca);
1024+
Py_DECREF(gca_kwargs);
1025+
1026+
PyObject *plot3 = PyObject_GetAttrString(axis, "scatter");
1027+
if (!plot3) throw std::runtime_error("No 3D line plot");
1028+
Py_INCREF(plot3);
1029+
PyObject *res = PyObject_Call(plot3, args, kwargs);
1030+
if (!res) throw std::runtime_error("Failed 3D line plot");
1031+
Py_DECREF(plot3);
1032+
1033+
Py_DECREF(axis);
1034+
Py_DECREF(args);
1035+
Py_DECREF(kwargs);
1036+
Py_DECREF(fig);
1037+
if (res) Py_DECREF(res);
1038+
return res;
1039+
1040+
}
1041+
9141042
template<typename Numeric>
9151043
bool boxplot(const std::vector<std::vector<Numeric>>& data,
9161044
const std::vector<std::string>& labels = {},
@@ -1139,9 +1267,9 @@ bool contour(const std::vector<NumericX>& x, const std::vector<NumericY>& y,
11391267
const std::map<std::string, std::string>& keywords = {}) {
11401268
assert(x.size() == y.size() && x.size() == z.size());
11411269

1142-
PyObject* xarray = get_array(x);
1143-
PyObject* yarray = get_array(y);
1144-
PyObject* zarray = get_array(z);
1270+
PyObject* xarray = detail::get_array(x);
1271+
PyObject* yarray = detail::get_array(y);
1272+
PyObject* zarray = detail::get_array(z);
11451273

11461274
PyObject* plot_args = PyTuple_New(3);
11471275
PyTuple_SetItem(plot_args, 0, xarray);
@@ -2094,12 +2222,14 @@ inline void axvspan(double xmin, double xmax, double ymin = 0., double ymax = 1.
20942222

20952223
// construct keyword args
20962224
PyObject* kwargs = PyDict_New();
2097-
for(std::map<std::string, std::string>::const_iterator it = keywords.begin(); it != keywords.end(); ++it)
2098-
{
2099-
if (it->first == "linewidth" || it->first == "alpha")
2100-
PyDict_SetItemString(kwargs, it->first.c_str(), PyFloat_FromDouble(std::stod(it->second)));
2101-
else
2102-
PyDict_SetItemString(kwargs, it->first.c_str(), PyString_FromString(it->second.c_str()));
2225+
for (auto it = keywords.begin(); it != keywords.end(); ++it) {
2226+
if (it->first == "linewidth" || it->first == "alpha") {
2227+
PyDict_SetItemString(kwargs, it->first.c_str(),
2228+
PyFloat_FromDouble(std::stod(it->second)));
2229+
} else {
2230+
PyDict_SetItemString(kwargs, it->first.c_str(),
2231+
PyString_FromString(it->second.c_str()));
2232+
}
21032233
}
21042234

21052235
PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_axvspan, args, kwargs);
@@ -2319,6 +2449,25 @@ inline void save(const std::string& filename)
23192449
Py_DECREF(res);
23202450
}
23212451

2452+
inline void rcparams(const std::map<std::string, std::string>& keywords = {}) {
2453+
detail::_interpreter::get();
2454+
PyObject* args = PyTuple_New(0);
2455+
PyObject* kwargs = PyDict_New();
2456+
for (auto it = keywords.begin(); it != keywords.end(); ++it) {
2457+
if ("text.usetex" == it->first)
2458+
PyDict_SetItemString(kwargs, it->first.c_str(), PyLong_FromLong(std::stoi(it->second.c_str())));
2459+
else PyDict_SetItemString(kwargs, it->first.c_str(), PyString_FromString(it->second.c_str()));
2460+
}
2461+
2462+
PyObject * update = PyObject_GetAttrString(detail::_interpreter::get().s_python_function_rcparams, "update");
2463+
PyObject * res = PyObject_Call(update, args, kwargs);
2464+
if(!res) throw std::runtime_error("Call to rcParams.update() failed.");
2465+
Py_DECREF(args);
2466+
Py_DECREF(kwargs);
2467+
Py_DECREF(update);
2468+
Py_DECREF(res);
2469+
}
2470+
23222471
inline void clf() {
23232472
detail::_interpreter::get();
23242473

0 commit comments

Comments
 (0)