2
2
import itertools
3
3
import platform
4
4
5
+ import numpy as np
5
6
import pytest
6
7
7
8
from mpl_toolkits .mplot3d import Axes3D , axes3d , proj3d , art3d
14
15
from matplotlib .testing .decorators import image_comparison , check_figures_equal
15
16
from matplotlib .testing .widgets import mock_event
16
17
from matplotlib .collections import LineCollection , PolyCollection
18
+ from matplotlib .cbook import hexbin
17
19
from matplotlib .patches import Circle , PathPatch
18
20
from matplotlib .path import Path
19
21
from matplotlib .text import Text
@@ -32,7 +34,42 @@ def plot_cuboid(ax, scale):
32
34
pts = itertools .combinations (np .array (list (itertools .product (r , r , r ))), 2 )
33
35
for start , end in pts :
34
36
if np .sum (np .abs (start - end )) == r [1 ] - r [0 ]:
35
- ax .plot3D (* zip (start * np .array (scale ), end * np .array (scale )))
37
+ ax .plot3D (* zip (start * np .array (scale ), end * np .array (scale )))
38
+
39
+
40
+ def get_gaussian_bars (mu = (0 , 0 ),
41
+ sigma = ([0.8 , 0.3 ],
42
+ [0.3 , 0.5 ]),
43
+ range = (- 3 , 3 ),
44
+ res = 8 ,
45
+ seed = 123 ):
46
+ np .random .seed (seed )
47
+ sl = slice (* range , complex (res ))
48
+ xy = np .array (np .mgrid [sl , sl ][::- 1 ]).T - mu
49
+ p = np .linalg .inv (sigma )
50
+ exp = np .sum (np .moveaxis (xy .T , 0 , 1 ) * (p @ np .moveaxis (xy , 0 , - 1 )), 1 )
51
+ z = np .exp (- exp / 2 ) / np .sqrt (np .linalg .det (sigma )) / np .pi / 2
52
+ return * xy .T , z , '0.8'
53
+
54
+
55
+ def get_gaussian_hexs (mu = (0 , 0 ),
56
+ sigma = ([0.8 , 0.3 ],
57
+ [0.3 , 0.5 ]),
58
+ n = 10_000 ,
59
+ res = 8 ,
60
+ seed = 123 ):
61
+ np .random .seed (seed )
62
+ xy = np .random .multivariate_normal (mu , sigma , n )
63
+ xyz , (xmin , xmax ), (ymin , ymax ), (nx , ny ) = hexbin (* xy .T , gridsize = res )
64
+ dxy = np .array ([(xmax - xmin ) / nx , (ymax - ymin ) / ny / np .sqrt (3 )]) * 0.95
65
+ return * xyz , dxy
66
+
67
+
68
+ def get_bar3d_test_data ():
69
+ return {
70
+ 'rect' : get_gaussian_bars (),
71
+ 'hex' : get_gaussian_hexs ()
72
+ }
36
73
37
74
38
75
@check_figures_equal (extensions = ["png" ])
@@ -220,6 +257,79 @@ def test_bar3d_lightsource():
220
257
np .testing .assert_array_max_ulp (color , collection ._facecolor3d [1 ::6 ], 4 )
221
258
222
259
260
+ @pytest .fixture (params = [get_bar3d_test_data ])
261
+ def bar3d_test_data (request ):
262
+ return request .param ()
263
+
264
+
265
+ class TestBar3D :
266
+
267
+ shapes = ('rect' , 'hex' )
268
+
269
+ def _plot_bar3d (self , ax , x , y , z , dxy , shape , azim = None , elev = None , ** kws ):
270
+
271
+ api_function = ax .hexbar3d if shape == 'hex' else ax .bar3d_grid
272
+ bars = api_function (x , y , z , dxy , ** kws )
273
+
274
+ if azim :
275
+ ax .azim = azim
276
+ if elev :
277
+ ax .elev = elev
278
+
279
+ return bars
280
+
281
+ @mpl3d_image_comparison (['bar3d_with_1d_data.png' ])
282
+ def test_bar3d_with_1d_data (self ):
283
+ fig , axes = plt .subplots (1 , 2 , subplot_kw = {'projection' : '3d' })
284
+ for ax , shape in zip (axes , self .shapes ):
285
+ self ._plot_bar3d (ax , 0 , 0 , 1 , '0.8' , shape , ec = '0.5' , lw = 0.5 )
286
+
287
+ @mpl3d_image_comparison (['bar3d_zsort.png' , 'bar3d_zsort_hex.png' ])
288
+ def test_bar3d_zsort (self ):
289
+ for shape in self .shapes :
290
+ fig , axes = plt .subplots (2 , 4 , subplot_kw = {'projection' : '3d' })
291
+ elev = 45
292
+ azim0 , astep = - 22.5 , 45
293
+ camera = itertools .product (np .r_ [azim0 :(180 + azim0 ):astep ],
294
+ (elev , - elev ))
295
+ # sourcery skip: no-loop-in-tests
296
+ for ax , (azim , elev ) in zip (axes .T .ravel (), camera ):
297
+ self ._plot_bar3d (ax ,
298
+ [0 , 1 ], [0 , 1 ], [1 , 2 ],
299
+ '0.8' ,
300
+ shape ,
301
+ azim = azim , elev = elev ,
302
+ ec = '0.5' , lw = 0.5 )
303
+
304
+ @mpl3d_image_comparison (['bar3d_with_2d_data.png' ])
305
+ def test_bar3d_with_2d_data (self , bar3d_test_data ):
306
+ fig , axes = plt .subplots (1 , 2 , subplot_kw = {'projection' : '3d' })
307
+ for ax , shape in zip (axes , self .shapes ):
308
+ x , y , z , dxy = bar3d_test_data [shape ]
309
+ self ._plot_bar3d (ax , x , y , z , dxy , shape , ec = '0.5' , lw = 0.5 )
310
+
311
+ def _gen_bar3d_subplots (self , bar3d_test_data ):
312
+ config = dict (edgecolors = '0.5' , lw = 0.5 )
313
+ fig , axes = plt .subplots (2 , 2 , subplot_kw = {'projection' : '3d' })
314
+ for i , shape in enumerate (self .shapes ):
315
+ x , y , z , dxy = bar3d_test_data [shape ]
316
+ for j , shade in enumerate ((0 , 1 )):
317
+ yield (axes [i , j ], x , y , z , dxy , shape ), {** config , 'shade' : shade }
318
+
319
+ @mpl3d_image_comparison (['bar3d_facecolors.png' ])
320
+ def test_bar3d_facecolors (self , bar3d_test_data ):
321
+ for (ax , x , y , z , dxy , shape ), kws in self ._gen_bar3d_subplots (bar3d_test_data ):
322
+ bars = self ._plot_bar3d (
323
+ ax , x , y , z , dxy , shape , ** kws ,
324
+ facecolors = list (mcolors .CSS4_COLORS )[:x .size ]
325
+ )
326
+
327
+ @mpl3d_image_comparison (['bar3d_cmap.png' ])
328
+ def test_bar3d_cmap (self , bar3d_test_data ):
329
+ for (ax , x , y , z , dxy , shape ), kws in self ._gen_bar3d_subplots (bar3d_test_data ):
330
+ bars = self ._plot_bar3d (ax , x , y , z , dxy , shape , cmap = 'viridis' , ** kws )
331
+
332
+
223
333
@mpl3d_image_comparison (
224
334
['contour3d.png' ], style = 'mpl20' ,
225
335
tol = 0.002 if platform .machine () in ('aarch64' , 'ppc64le' , 's390x' ) else 0 )
0 commit comments