@@ -59,6 +59,7 @@ struct _interpreter {
59
59
PyObject *s_python_function_errorbar;
60
60
PyObject *s_python_function_annotate;
61
61
PyObject *s_python_function_tight_layout;
62
+ PyObject *s_python_colormap;
62
63
PyObject *s_python_empty_tuple;
63
64
PyObject *s_python_function_stem;
64
65
PyObject *s_python_function_xkcd;
@@ -115,9 +116,13 @@ struct _interpreter {
115
116
116
117
PyObject* matplotlibname = PyString_FromString (" matplotlib" );
117
118
PyObject* pyplotname = PyString_FromString (" matplotlib.pyplot" );
119
+ PyObject* mpl_toolkits = PyString_FromString (" mpl_toolkits" );
120
+ PyObject* axis3d = PyString_FromString (" mpl_toolkits.mplot3d" );
118
121
PyObject* pylabname = PyString_FromString (" pylab" );
119
- if (!pyplotname || !pylabname || !matplotlibname) {
120
- throw std::runtime_error (" couldnt create string" );
122
+ PyObject* cmname = PyString_FromString (" matplotlib.cm" );
123
+ if (!pyplotname || !pylabname || !matplotlibname || !mpl_toolkits ||
124
+ !axis3d || !cmname) {
125
+ throw std::runtime_error (" couldnt create string" );
121
126
}
122
127
123
128
PyObject* matplotlib = PyImport_Import (matplotlibname);
@@ -134,11 +139,22 @@ struct _interpreter {
134
139
Py_DECREF (pyplotname);
135
140
if (!pymod) { throw std::runtime_error (" Error loading module matplotlib.pyplot!" ); }
136
141
142
+ s_python_colormap = PyImport_Import (cmname);
143
+ Py_DECREF (cmname);
144
+ if (!s_python_colormap) { throw std::runtime_error (" Error loading module matplotlib.cm!" ); }
137
145
138
146
PyObject* pylabmod = PyImport_Import (pylabname);
139
147
Py_DECREF (pylabname);
140
148
if (!pylabmod) { throw std::runtime_error (" Error loading module pylab!" ); }
141
149
150
+ PyObject* mpl_toolkitsmod = PyImport_Import (mpl_toolkits);
151
+ Py_DECREF (mpl_toolkitsmod);
152
+ if (!mpl_toolkitsmod) { throw std::runtime_error (" Error loading module mpl_toolkits!" ); }
153
+
154
+ PyObject* axis3dmod = PyImport_Import (axis3d);
155
+ Py_DECREF (axis3dmod);
156
+ if (!axis3dmod) { throw std::runtime_error (" Error loading module mpl_toolkits.mplot3d!" ); }
157
+
142
158
s_python_function_show = PyObject_GetAttrString (pymod, " show" );
143
159
s_python_function_close = PyObject_GetAttrString (pymod, " close" );
144
160
s_python_function_draw = PyObject_GetAttrString (pymod, " draw" );
@@ -325,6 +341,30 @@ PyObject* get_array(const std::vector<Numeric>& v)
325
341
return varray;
326
342
}
327
343
344
+ template <typename Numeric>
345
+ PyObject* get_2darray (const std::vector<::std::vector<Numeric>>& v)
346
+ {
347
+ detail::_interpreter::get (); // interpreter needs to be initialized for the numpy commands to work
348
+ if (v.size () < 1 ) throw std::runtime_error (" get_2d_array v too small" );
349
+
350
+ npy_intp vsize[2 ] = {static_cast <npy_intp>(v.size ()),
351
+ static_cast <npy_intp>(v[0 ].size ())};
352
+
353
+ PyArrayObject *varray =
354
+ (PyArrayObject *)PyArray_SimpleNew (2 , vsize, NPY_DOUBLE);
355
+
356
+ double *vd_begin = static_cast <double *>(PyArray_DATA (varray));
357
+
358
+ for (const ::std::vector<Numeric> &v_row : v) {
359
+ if (v_row.size () != static_cast <size_t >(vsize[1 ]))
360
+ throw std::runtime_error (" Missmatched array size" );
361
+ std::copy (v_row.begin (), v_row.end (), vd_begin);
362
+ vd_begin += vsize[1 ];
363
+ }
364
+
365
+ return reinterpret_cast <PyObject *>(varray);
366
+ }
367
+
328
368
#else // fallback if we don't have numpy: copy every element of the given vector
329
369
330
370
template <typename Numeric>
@@ -369,6 +409,76 @@ bool plot(const std::vector<Numeric> &x, const std::vector<Numeric> &y, const st
369
409
return res;
370
410
}
371
411
412
+ template <typename Numeric>
413
+ void plot_surface (const std::vector<::std::vector<Numeric>> &x,
414
+ const std::vector<::std::vector<Numeric>> &y,
415
+ const std::vector<::std::vector<Numeric>> &z,
416
+ const std::map<std::string, std::string> &keywords =
417
+ std::map<std::string, std::string>()) {
418
+ assert (x.size () == y.size ());
419
+ assert (y.size () == z.size ());
420
+
421
+ // using numpy arrays
422
+ PyObject *xarray = get_2darray (x);
423
+ PyObject *yarray = get_2darray (y);
424
+ PyObject *zarray = get_2darray (z);
425
+
426
+ // construct positional args
427
+ PyObject *args = PyTuple_New (3 );
428
+ PyTuple_SetItem (args, 0 , xarray);
429
+ PyTuple_SetItem (args, 1 , yarray);
430
+ PyTuple_SetItem (args, 2 , zarray);
431
+
432
+ // Build up the kw args.
433
+ PyObject *kwargs = PyDict_New ();
434
+ PyDict_SetItemString (kwargs, " rstride" , PyInt_FromLong (1 ));
435
+ PyDict_SetItemString (kwargs, " cstride" , PyInt_FromLong (1 ));
436
+
437
+ PyObject *python_colormap_coolwarm = PyObject_GetAttrString (
438
+ detail::_interpreter::get ().s_python_colormap , " coolwarm" );
439
+
440
+ PyDict_SetItemString (kwargs, " cmap" , python_colormap_coolwarm);
441
+
442
+ for (std::map<std::string, std::string>::const_iterator it = keywords.begin ();
443
+ it != keywords.end (); ++it) {
444
+ PyDict_SetItemString (kwargs, it->first .c_str (),
445
+ PyString_FromString (it->second .c_str ()));
446
+ }
447
+
448
+
449
+ PyObject *fig =
450
+ PyObject_CallObject (detail::_interpreter::get ().s_python_function_figure ,
451
+ detail::_interpreter::get ().s_python_empty_tuple );
452
+ if (!fig) throw std::runtime_error (" Call to figure() failed." );
453
+
454
+ PyObject *gca_kwargs = PyDict_New ();
455
+ PyDict_SetItemString (gca_kwargs, " projection" , PyString_FromString (" 3d" ));
456
+
457
+ PyObject *gca = PyObject_GetAttrString (fig, " gca" );
458
+ if (!gca) throw std::runtime_error (" No gca" );
459
+ Py_INCREF (gca);
460
+ PyObject *axis = PyObject_Call (
461
+ gca, detail::_interpreter::get ().s_python_empty_tuple , gca_kwargs);
462
+
463
+ if (!axis) throw std::runtime_error (" No axis" );
464
+ Py_INCREF (axis);
465
+
466
+ Py_DECREF (gca);
467
+ Py_DECREF (gca_kwargs);
468
+
469
+ PyObject *plot_surface = PyObject_GetAttrString (axis, " plot_surface" );
470
+ if (!plot_surface) throw std::runtime_error (" No surface" );
471
+ Py_INCREF (plot_surface);
472
+ PyObject *res = PyObject_Call (plot_surface, args, kwargs);
473
+ if (!res) throw std::runtime_error (" failed surface" );
474
+ Py_DECREF (plot_surface);
475
+
476
+ Py_DECREF (axis);
477
+ Py_DECREF (args);
478
+ Py_DECREF (kwargs);
479
+ if (res) Py_DECREF (res);
480
+ }
481
+
372
482
template <typename Numeric>
373
483
bool stem (const std::vector<Numeric> &x, const std::vector<Numeric> &y, const std::map<std::string, std::string>& keywords)
374
484
{
0 commit comments