diff --git a/matplotlibcpp.h b/matplotlibcpp.h index 2797295..ade6878 100644 --- a/matplotlibcpp.h +++ b/matplotlibcpp.h @@ -500,6 +500,101 @@ void plot_surface(const std::vector<::std::vector> &x, if (res) Py_DECREF(res); } +template +void plot_scatter(const std::vector<::std::vector> &x, + const std::vector<::std::vector> &y, + const std::vector<::std::vector> &z) +{ + + //const std::map &keywords = std::map() + + // We lazily load the modules here the first time this function is called + // because I'm not sure that we can assume "matplotlib installed" implies + // "mpl_toolkits installed" on all platforms, and we don't want to require + // it for people who don't need 3d plots. + static PyObject *mpl_toolkitsmod = nullptr, *axis3dmod = nullptr; + if (!mpl_toolkitsmod) { + detail::_interpreter::get(); + + PyObject* mpl_toolkits = PyString_FromString("mpl_toolkits"); + PyObject* axis3d = PyString_FromString("mpl_toolkits.mplot3d"); + if (!mpl_toolkits || !axis3d) { throw std::runtime_error("couldnt create string"); } + + mpl_toolkitsmod = PyImport_Import(mpl_toolkits); + Py_DECREF(mpl_toolkits); + if (!mpl_toolkitsmod) { throw std::runtime_error("Error loading module mpl_toolkits!"); } + + axis3dmod = PyImport_Import(axis3d); + Py_DECREF(axis3d); + if (!axis3dmod) { throw std::runtime_error("Error loading module mpl_toolkits.mplot3d!"); } + } + + assert(x.size() == y.size()); + assert(y.size() == z.size()); + + // using numpy arrays + PyObject *xarray = get_2darray(x); + PyObject *yarray = get_2darray(y); + PyObject *zarray = get_2darray(z); + + // construct positional args + PyObject *args = PyTuple_New(3); + PyTuple_SetItem(args, 0, xarray); + PyTuple_SetItem(args, 1, yarray); + PyTuple_SetItem(args, 2, zarray); + + // Build up the kw args. + PyObject *kwargs = PyDict_New(); + //PyDict_SetItemString(kwargs, "rstride", PyInt_FromLong(1)); + //PyDict_SetItemString(kwargs, "cstride", PyInt_FromLong(1)); + + // PyObject *python_colormap_coolwarm = PyObject_GetAttrString( \ + detail::_interpreter::get().s_python_colormap, "coolwarm"); + + //PyDict_SetItemString(kwargs, "cmap", python_colormap_coolwarm); +/* + for (std::map::const_iterator it = keywords.begin(); + it != keywords.end(); ++it) { + PyDict_SetItemString(kwargs, it->first.c_str(), + PyString_FromString(it->second.c_str())); + } +*/ + + PyObject *fig = + PyObject_CallObject(detail::_interpreter::get().s_python_function_figure, + detail::_interpreter::get().s_python_empty_tuple); + if (!fig) throw std::runtime_error("Call to figure() failed."); + + PyObject *gca_kwargs = PyDict_New(); + PyDict_SetItemString(gca_kwargs, "projection", PyString_FromString("3d")); + + PyObject *gca = PyObject_GetAttrString(fig, "gca"); + if (!gca) throw std::runtime_error("No gca"); + Py_INCREF(gca); + PyObject *axis = PyObject_Call( + gca, detail::_interpreter::get().s_python_empty_tuple, gca_kwargs); + + if (!axis) throw std::runtime_error("No axis"); + Py_INCREF(axis); + + Py_DECREF(gca); + Py_DECREF(gca_kwargs); + + PyObject *scatter = PyObject_GetAttrString(axis, "scatter"); + if (!scatter) throw std::runtime_error("No scatter"); + + Py_INCREF(scatter); + PyObject *res = PyObject_Call(scatter, args, kwargs); + if (!res) throw std::runtime_error("failed scatter"); + Py_DECREF(scatter); + + Py_DECREF(axis); + Py_DECREF(args); + Py_DECREF(kwargs); + if (res) Py_DECREF(res); +} + template bool stem(const std::vector &x, const std::vector &y, const std::map& keywords) {