Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 69e58fe

Browse files
AustinSchuhBenno Evers
authored and
Benno Evers
committed
Add plot_surface
It probably hard-codes too many defaults, but it works well enough for my purposes.
1 parent 65933e3 commit 69e58fe

File tree

2 files changed

+138
-2
lines changed

2 files changed

+138
-2
lines changed

README.md

+26
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,32 @@ int main()
159159
160160
![quiver example](./examples/quiver.png)
161161
162+
When working with 3d functions, you might be interested in 3d plots:
163+
```cpp
164+
#include "../matplotlibcpp.h"
165+
166+
namespace plt = matplotlibcpp;
167+
168+
int main()
169+
{
170+
std::vector<std::vector<double>> x, y, z;
171+
for (double i = -5; i <= 5; i += 0.25) {
172+
std::vector<double> x_row, y_row, z_row;
173+
for (double j = -5; j <= 5; j += 0.25) {
174+
x_row.push_back(i);
175+
y_row.push_back(j);
176+
z_row.push_back(::std::sin(::std::hypot(x, y)));
177+
}
178+
x.push_back(x_row);
179+
y.push_back(y_row);
180+
z.push_back(z_row);
181+
}
182+
183+
plt::plot_surface(x, y, z);
184+
plt::show();
185+
}
186+
```
187+
162188
Installation
163189
------------
164190

matplotlibcpp.h

+112-2
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ struct _interpreter {
5959
PyObject *s_python_function_errorbar;
6060
PyObject *s_python_function_annotate;
6161
PyObject *s_python_function_tight_layout;
62+
PyObject *s_python_colormap;
6263
PyObject *s_python_empty_tuple;
6364
PyObject *s_python_function_stem;
6465
PyObject *s_python_function_xkcd;
@@ -115,9 +116,13 @@ struct _interpreter {
115116

116117
PyObject* matplotlibname = PyString_FromString("matplotlib");
117118
PyObject* pyplotname = PyString_FromString("matplotlib.pyplot");
119+
PyObject* mpl_toolkits = PyString_FromString("mpl_toolkits");
120+
PyObject* axis3d = PyString_FromString("mpl_toolkits.mplot3d");
118121
PyObject* pylabname = PyString_FromString("pylab");
119-
if (!pyplotname || !pylabname || !matplotlibname) {
120-
throw std::runtime_error("couldnt create string");
122+
PyObject* cmname = PyString_FromString("matplotlib.cm");
123+
if (!pyplotname || !pylabname || !matplotlibname || !mpl_toolkits ||
124+
!axis3d || !cmname) {
125+
throw std::runtime_error("couldnt create string");
121126
}
122127

123128
PyObject* matplotlib = PyImport_Import(matplotlibname);
@@ -134,11 +139,22 @@ struct _interpreter {
134139
Py_DECREF(pyplotname);
135140
if (!pymod) { throw std::runtime_error("Error loading module matplotlib.pyplot!"); }
136141

142+
s_python_colormap = PyImport_Import(cmname);
143+
Py_DECREF(cmname);
144+
if (!s_python_colormap) { throw std::runtime_error("Error loading module matplotlib.cm!"); }
137145

138146
PyObject* pylabmod = PyImport_Import(pylabname);
139147
Py_DECREF(pylabname);
140148
if (!pylabmod) { throw std::runtime_error("Error loading module pylab!"); }
141149

150+
PyObject* mpl_toolkitsmod = PyImport_Import(mpl_toolkits);
151+
Py_DECREF(mpl_toolkitsmod);
152+
if (!mpl_toolkitsmod) { throw std::runtime_error("Error loading module mpl_toolkits!"); }
153+
154+
PyObject* axis3dmod = PyImport_Import(axis3d);
155+
Py_DECREF(axis3dmod);
156+
if (!axis3dmod) { throw std::runtime_error("Error loading module mpl_toolkits.mplot3d!"); }
157+
142158
s_python_function_show = PyObject_GetAttrString(pymod, "show");
143159
s_python_function_close = PyObject_GetAttrString(pymod, "close");
144160
s_python_function_draw = PyObject_GetAttrString(pymod, "draw");
@@ -325,6 +341,30 @@ PyObject* get_array(const std::vector<Numeric>& v)
325341
return varray;
326342
}
327343

344+
template<typename Numeric>
345+
PyObject* get_2darray(const std::vector<::std::vector<Numeric>>& v)
346+
{
347+
detail::_interpreter::get(); //interpreter needs to be initialized for the numpy commands to work
348+
if (v.size() < 1) throw std::runtime_error("get_2d_array v too small");
349+
350+
npy_intp vsize[2] = {static_cast<npy_intp>(v.size()),
351+
static_cast<npy_intp>(v[0].size())};
352+
353+
PyArrayObject *varray =
354+
(PyArrayObject *)PyArray_SimpleNew(2, vsize, NPY_DOUBLE);
355+
356+
double *vd_begin = static_cast<double *>(PyArray_DATA(varray));
357+
358+
for (const ::std::vector<Numeric> &v_row : v) {
359+
if (v_row.size() != static_cast<size_t>(vsize[1]))
360+
throw std::runtime_error("Missmatched array size");
361+
std::copy(v_row.begin(), v_row.end(), vd_begin);
362+
vd_begin += vsize[1];
363+
}
364+
365+
return reinterpret_cast<PyObject *>(varray);
366+
}
367+
328368
#else // fallback if we don't have numpy: copy every element of the given vector
329369

330370
template<typename Numeric>
@@ -369,6 +409,76 @@ bool plot(const std::vector<Numeric> &x, const std::vector<Numeric> &y, const st
369409
return res;
370410
}
371411

412+
template <typename Numeric>
413+
void plot_surface(const std::vector<::std::vector<Numeric>> &x,
414+
const std::vector<::std::vector<Numeric>> &y,
415+
const std::vector<::std::vector<Numeric>> &z,
416+
const std::map<std::string, std::string> &keywords =
417+
std::map<std::string, std::string>()) {
418+
assert(x.size() == y.size());
419+
assert(y.size() == z.size());
420+
421+
// using numpy arrays
422+
PyObject *xarray = get_2darray(x);
423+
PyObject *yarray = get_2darray(y);
424+
PyObject *zarray = get_2darray(z);
425+
426+
// construct positional args
427+
PyObject *args = PyTuple_New(3);
428+
PyTuple_SetItem(args, 0, xarray);
429+
PyTuple_SetItem(args, 1, yarray);
430+
PyTuple_SetItem(args, 2, zarray);
431+
432+
// Build up the kw args.
433+
PyObject *kwargs = PyDict_New();
434+
PyDict_SetItemString(kwargs, "rstride", PyInt_FromLong(1));
435+
PyDict_SetItemString(kwargs, "cstride", PyInt_FromLong(1));
436+
437+
PyObject *python_colormap_coolwarm = PyObject_GetAttrString(
438+
detail::_interpreter::get().s_python_colormap, "coolwarm");
439+
440+
PyDict_SetItemString(kwargs, "cmap", python_colormap_coolwarm);
441+
442+
for (std::map<std::string, std::string>::const_iterator it = keywords.begin();
443+
it != keywords.end(); ++it) {
444+
PyDict_SetItemString(kwargs, it->first.c_str(),
445+
PyString_FromString(it->second.c_str()));
446+
}
447+
448+
449+
PyObject *fig =
450+
PyObject_CallObject(detail::_interpreter::get().s_python_function_figure,
451+
detail::_interpreter::get().s_python_empty_tuple);
452+
if (!fig) throw std::runtime_error("Call to figure() failed.");
453+
454+
PyObject *gca_kwargs = PyDict_New();
455+
PyDict_SetItemString(gca_kwargs, "projection", PyString_FromString("3d"));
456+
457+
PyObject *gca = PyObject_GetAttrString(fig, "gca");
458+
if (!gca) throw std::runtime_error("No gca");
459+
Py_INCREF(gca);
460+
PyObject *axis = PyObject_Call(
461+
gca, detail::_interpreter::get().s_python_empty_tuple, gca_kwargs);
462+
463+
if (!axis) throw std::runtime_error("No axis");
464+
Py_INCREF(axis);
465+
466+
Py_DECREF(gca);
467+
Py_DECREF(gca_kwargs);
468+
469+
PyObject *plot_surface = PyObject_GetAttrString(axis, "plot_surface");
470+
if (!plot_surface) throw std::runtime_error("No surface");
471+
Py_INCREF(plot_surface);
472+
PyObject *res = PyObject_Call(plot_surface, args, kwargs);
473+
if (!res) throw std::runtime_error("failed surface");
474+
Py_DECREF(plot_surface);
475+
476+
Py_DECREF(axis);
477+
Py_DECREF(args);
478+
Py_DECREF(kwargs);
479+
if (res) Py_DECREF(res);
480+
}
481+
372482
template<typename Numeric>
373483
bool stem(const std::vector<Numeric> &x, const std::vector<Numeric> &y, const std::map<std::string, std::string>& keywords)
374484
{

0 commit comments

Comments
 (0)