@@ -99,6 +99,7 @@ struct _interpreter {
99
99
PyObject *s_python_function_barh;
100
100
PyObject *s_python_function_colorbar;
101
101
PyObject *s_python_function_subplots_adjust;
102
+ PyObject *s_python_function_rcparams;
102
103
103
104
104
105
/* For now, _interpreter is implemented as a singleton since its currently not possible to have
@@ -189,6 +190,7 @@ struct _interpreter {
189
190
}
190
191
191
192
PyObject* matplotlib = PyImport_Import (matplotlibname);
193
+
192
194
Py_DECREF (matplotlibname);
193
195
if (!matplotlib) {
194
196
PyErr_Print ();
@@ -201,6 +203,8 @@ struct _interpreter {
201
203
PyObject_CallMethod (matplotlib, const_cast <char *>(" use" ), const_cast <char *>(" s" ), s_backend.c_str ());
202
204
}
203
205
206
+
207
+
204
208
PyObject* pymod = PyImport_Import (pyplotname);
205
209
Py_DECREF (pyplotname);
206
210
if (!pymod) { throw std::runtime_error (" Error loading module matplotlib.pyplot!" ); }
@@ -264,6 +268,7 @@ struct _interpreter {
264
268
s_python_function_barh = safe_import (pymod, " barh" );
265
269
s_python_function_colorbar = PyObject_GetAttrString (pymod, " colorbar" );
266
270
s_python_function_subplots_adjust = safe_import (pymod," subplots_adjust" );
271
+ s_python_function_rcparams = PyObject_GetAttrString (pymod, " rcParams" );
267
272
#ifndef WITHOUT_NUMPY
268
273
s_python_function_imshow = safe_import (pymod, " imshow" );
269
274
#endif
@@ -464,6 +469,7 @@ template <typename Numeric>
464
469
void plot_surface (const std::vector<::std::vector<Numeric>> &x,
465
470
const std::vector<::std::vector<Numeric>> &y,
466
471
const std::vector<::std::vector<Numeric>> &z,
472
+ const long fig_number=0 ,
467
473
const std::map<std::string, std::string> &keywords =
468
474
std::map<std::string, std::string>())
469
475
{
@@ -516,14 +522,29 @@ void plot_surface(const std::vector<::std::vector<Numeric>> &x,
516
522
517
523
for (std::map<std::string, std::string>::const_iterator it = keywords.begin ();
518
524
it != keywords.end (); ++it) {
519
- PyDict_SetItemString (kwargs, it->first .c_str (),
520
- PyString_FromString (it->second .c_str ()));
525
+ if (it->first == " linewidth" || it->first == " alpha" ) {
526
+ PyDict_SetItemString (kwargs, it->first .c_str (),
527
+ PyFloat_FromDouble (std::stod (it->second )));
528
+ } else {
529
+ PyDict_SetItemString (kwargs, it->first .c_str (),
530
+ PyString_FromString (it->second .c_str ()));
531
+ }
521
532
}
522
533
523
-
524
- PyObject *fig =
525
- PyObject_CallObject (detail::_interpreter::get ().s_python_function_figure ,
526
- detail::_interpreter::get ().s_python_empty_tuple );
534
+ PyObject *fig_args = PyTuple_New (1 );
535
+ PyObject* fig = nullptr ;
536
+ PyTuple_SetItem (fig_args, 0 , PyLong_FromLong (fig_number));
537
+ PyObject *fig_exists =
538
+ PyObject_CallObject (
539
+ detail::_interpreter::get ().s_python_function_fignum_exists , fig_args);
540
+ if (!PyObject_IsTrue (fig_exists)) {
541
+ fig = PyObject_CallObject (detail::_interpreter::get ().s_python_function_figure ,
542
+ detail::_interpreter::get ().s_python_empty_tuple );
543
+ } else {
544
+ fig = PyObject_CallObject (detail::_interpreter::get ().s_python_function_figure ,
545
+ fig_args);
546
+ }
547
+ Py_DECREF (fig_exists);
527
548
if (!fig) throw std::runtime_error (" Call to figure() failed." );
528
549
529
550
PyObject *gca_kwargs = PyDict_New ();
@@ -559,6 +580,7 @@ template <typename Numeric>
559
580
void plot3 (const std::vector<Numeric> &x,
560
581
const std::vector<Numeric> &y,
561
582
const std::vector<Numeric> &z,
583
+ const long fig_number=0 ,
562
584
const std::map<std::string, std::string> &keywords =
563
585
std::map<std::string, std::string>())
564
586
{
@@ -607,9 +629,18 @@ void plot3(const std::vector<Numeric> &x,
607
629
PyString_FromString (it->second .c_str ()));
608
630
}
609
631
610
- PyObject *fig =
611
- PyObject_CallObject (detail::_interpreter::get ().s_python_function_figure ,
612
- detail::_interpreter::get ().s_python_empty_tuple );
632
+ PyObject *fig_args = PyTuple_New (1 );
633
+ PyObject* fig = nullptr ;
634
+ PyTuple_SetItem (fig_args, 0 , PyLong_FromLong (fig_number));
635
+ PyObject *fig_exists =
636
+ PyObject_CallObject (detail::_interpreter::get ().s_python_function_fignum_exists , fig_args);
637
+ if (!PyObject_IsTrue (fig_exists)) {
638
+ fig = PyObject_CallObject (detail::_interpreter::get ().s_python_function_figure ,
639
+ detail::_interpreter::get ().s_python_empty_tuple );
640
+ } else {
641
+ fig = PyObject_CallObject (detail::_interpreter::get ().s_python_function_figure ,
642
+ fig_args);
643
+ }
613
644
if (!fig) throw std::runtime_error (" Call to figure() failed." );
614
645
615
646
PyObject *gca_kwargs = PyDict_New ();
@@ -911,6 +942,103 @@ bool scatter(const std::vector<NumericX>& x,
911
942
return res;
912
943
}
913
944
945
+ template <typename NumericX, typename NumericY, typename NumericZ>
946
+ bool scatter (const std::vector<NumericX>& x,
947
+ const std::vector<NumericY>& y,
948
+ const std::vector<NumericZ>& z,
949
+ const double s=1.0 , // The marker size in points**2
950
+ const long fig_number=0 ,
951
+ const std::map<std::string, std::string> & keywords = {}) {
952
+ detail::_interpreter::get ();
953
+
954
+ // Same as with plot_surface: We lazily load the modules here the first time
955
+ // this function is called because I'm not sure that we can assume "matplotlib
956
+ // installed" implies "mpl_toolkits installed" on all platforms, and we don't
957
+ // want to require it for people who don't need 3d plots.
958
+ static PyObject *mpl_toolkitsmod = nullptr , *axis3dmod = nullptr ;
959
+ if (!mpl_toolkitsmod) {
960
+ detail::_interpreter::get ();
961
+
962
+ PyObject* mpl_toolkits = PyString_FromString (" mpl_toolkits" );
963
+ PyObject* axis3d = PyString_FromString (" mpl_toolkits.mplot3d" );
964
+ if (!mpl_toolkits || !axis3d) { throw std::runtime_error (" couldnt create string" ); }
965
+
966
+ mpl_toolkitsmod = PyImport_Import (mpl_toolkits);
967
+ Py_DECREF (mpl_toolkits);
968
+ if (!mpl_toolkitsmod) { throw std::runtime_error (" Error loading module mpl_toolkits!" ); }
969
+
970
+ axis3dmod = PyImport_Import (axis3d);
971
+ Py_DECREF (axis3d);
972
+ if (!axis3dmod) { throw std::runtime_error (" Error loading module mpl_toolkits.mplot3d!" ); }
973
+ }
974
+
975
+ assert (x.size () == y.size ());
976
+ assert (y.size () == z.size ());
977
+
978
+ PyObject *xarray = detail::get_array (x);
979
+ PyObject *yarray = detail::get_array (y);
980
+ PyObject *zarray = detail::get_array (z);
981
+
982
+ // construct positional args
983
+ PyObject *args = PyTuple_New (3 );
984
+ PyTuple_SetItem (args, 0 , xarray);
985
+ PyTuple_SetItem (args, 1 , yarray);
986
+ PyTuple_SetItem (args, 2 , zarray);
987
+
988
+ // Build up the kw args.
989
+ PyObject *kwargs = PyDict_New ();
990
+
991
+ for (std::map<std::string, std::string>::const_iterator it = keywords.begin ();
992
+ it != keywords.end (); ++it) {
993
+ PyDict_SetItemString (kwargs, it->first .c_str (),
994
+ PyString_FromString (it->second .c_str ()));
995
+ }
996
+ PyObject *fig_args = PyTuple_New (1 );
997
+ PyObject* fig = nullptr ;
998
+ PyTuple_SetItem (fig_args, 0 , PyLong_FromLong (fig_number));
999
+ PyObject *fig_exists =
1000
+ PyObject_CallObject (detail::_interpreter::get ().s_python_function_fignum_exists , fig_args);
1001
+ if (!PyObject_IsTrue (fig_exists)) {
1002
+ fig = PyObject_CallObject (detail::_interpreter::get ().s_python_function_figure ,
1003
+ detail::_interpreter::get ().s_python_empty_tuple );
1004
+ } else {
1005
+ fig = PyObject_CallObject (detail::_interpreter::get ().s_python_function_figure ,
1006
+ fig_args);
1007
+ }
1008
+ Py_DECREF (fig_exists);
1009
+ if (!fig) throw std::runtime_error (" Call to figure() failed." );
1010
+
1011
+ PyObject *gca_kwargs = PyDict_New ();
1012
+ PyDict_SetItemString (gca_kwargs, " projection" , PyString_FromString (" 3d" ));
1013
+
1014
+ PyObject *gca = PyObject_GetAttrString (fig, " gca" );
1015
+ if (!gca) throw std::runtime_error (" No gca" );
1016
+ Py_INCREF (gca);
1017
+ PyObject *axis = PyObject_Call (
1018
+ gca, detail::_interpreter::get ().s_python_empty_tuple , gca_kwargs);
1019
+
1020
+ if (!axis) throw std::runtime_error (" No axis" );
1021
+ Py_INCREF (axis);
1022
+
1023
+ Py_DECREF (gca);
1024
+ Py_DECREF (gca_kwargs);
1025
+
1026
+ PyObject *plot3 = PyObject_GetAttrString (axis, " scatter" );
1027
+ if (!plot3) throw std::runtime_error (" No 3D line plot" );
1028
+ Py_INCREF (plot3);
1029
+ PyObject *res = PyObject_Call (plot3, args, kwargs);
1030
+ if (!res) throw std::runtime_error (" Failed 3D line plot" );
1031
+ Py_DECREF (plot3);
1032
+
1033
+ Py_DECREF (axis);
1034
+ Py_DECREF (args);
1035
+ Py_DECREF (kwargs);
1036
+ Py_DECREF (fig);
1037
+ if (res) Py_DECREF (res);
1038
+ return res;
1039
+
1040
+ }
1041
+
914
1042
template <typename Numeric>
915
1043
bool boxplot (const std::vector<std::vector<Numeric>>& data,
916
1044
const std::vector<std::string>& labels = {},
@@ -1139,9 +1267,9 @@ bool contour(const std::vector<NumericX>& x, const std::vector<NumericY>& y,
1139
1267
const std::map<std::string, std::string>& keywords = {}) {
1140
1268
assert (x.size () == y.size () && x.size () == z.size ());
1141
1269
1142
- PyObject* xarray = get_array (x);
1143
- PyObject* yarray = get_array (y);
1144
- PyObject* zarray = get_array (z);
1270
+ PyObject* xarray = detail:: get_array (x);
1271
+ PyObject* yarray = detail:: get_array (y);
1272
+ PyObject* zarray = detail:: get_array (z);
1145
1273
1146
1274
PyObject* plot_args = PyTuple_New (3 );
1147
1275
PyTuple_SetItem (plot_args, 0 , xarray);
@@ -2094,12 +2222,14 @@ inline void axvspan(double xmin, double xmax, double ymin = 0., double ymax = 1.
2094
2222
2095
2223
// construct keyword args
2096
2224
PyObject* kwargs = PyDict_New ();
2097
- for (std::map<std::string, std::string>::const_iterator it = keywords.begin (); it != keywords.end (); ++it)
2098
- {
2099
- if (it->first == " linewidth" || it->first == " alpha" )
2100
- PyDict_SetItemString (kwargs, it->first .c_str (), PyFloat_FromDouble (std::stod (it->second )));
2101
- else
2102
- PyDict_SetItemString (kwargs, it->first .c_str (), PyString_FromString (it->second .c_str ()));
2225
+ for (auto it = keywords.begin (); it != keywords.end (); ++it) {
2226
+ if (it->first == " linewidth" || it->first == " alpha" ) {
2227
+ PyDict_SetItemString (kwargs, it->first .c_str (),
2228
+ PyFloat_FromDouble (std::stod (it->second )));
2229
+ } else {
2230
+ PyDict_SetItemString (kwargs, it->first .c_str (),
2231
+ PyString_FromString (it->second .c_str ()));
2232
+ }
2103
2233
}
2104
2234
2105
2235
PyObject* res = PyObject_Call (detail::_interpreter::get ().s_python_function_axvspan , args, kwargs);
@@ -2319,6 +2449,25 @@ inline void save(const std::string& filename)
2319
2449
Py_DECREF (res);
2320
2450
}
2321
2451
2452
+ inline void rcparams (const std::map<std::string, std::string>& keywords = {}) {
2453
+ detail::_interpreter::get ();
2454
+ PyObject* args = PyTuple_New (0 );
2455
+ PyObject* kwargs = PyDict_New ();
2456
+ for (auto it = keywords.begin (); it != keywords.end (); ++it) {
2457
+ if (" text.usetex" == it->first )
2458
+ PyDict_SetItemString (kwargs, it->first .c_str (), PyLong_FromLong (std::stoi (it->second .c_str ())));
2459
+ else PyDict_SetItemString (kwargs, it->first .c_str (), PyString_FromString (it->second .c_str ()));
2460
+ }
2461
+
2462
+ PyObject * update = PyObject_GetAttrString (detail::_interpreter::get ().s_python_function_rcparams , " update" );
2463
+ PyObject * res = PyObject_Call (update, args, kwargs);
2464
+ if (!res) throw std::runtime_error (" Call to rcParams.update() failed." );
2465
+ Py_DECREF (args);
2466
+ Py_DECREF (kwargs);
2467
+ Py_DECREF (update);
2468
+ Py_DECREF (res);
2469
+ }
2470
+
2322
2471
inline void clf () {
2323
2472
detail::_interpreter::get ();
2324
2473
0 commit comments