Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 770706c

Browse files
committed
TST: refactor tests
1 parent 42af8e1 commit 770706c

21 files changed

+85
-109
lines changed

lib/mpl_toolkits/mplot3d/art3d.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,15 +79,15 @@
7979

8080
# Base hexagon for creating prisms (HexBar3DCollection).
8181
# sides are ordered anti-clockwise from left: ['W', 'SW', 'SE', 'E', 'NE', 'NW']
82-
HEXAGON = np.array([
83-
[-2, 1],
82+
HEXAGON = np.array([ # autopep8: off
83+
[-2, 1],
8484
[-2, -1],
85-
[0, -2],
86-
[2, -1],
87-
[2, 1],
88-
[0, 2]
85+
[ 0, -2],
86+
[ 2, -1],
87+
[ 2, 1],
88+
[ 0, 2]
8989
]) / 4
90-
90+
# autopep8: on
9191
# ---------------------------------------------------------------------------- #
9292

9393

lib/mpl_toolkits/mplot3d/tests/test_axes3d.py

Lines changed: 78 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def get_gaussian_bars(mu=(0, 0),
4747
p = np.linalg.inv(sigma)
4848
exp = np.sum(np.moveaxis(xy.T, 0, 1) * (p @ np.moveaxis(xy, 0, -1)), 1)
4949
z = np.exp(-exp / 2) / np.sqrt(np.linalg.det(sigma)) / np.pi / 2
50-
return *xy.T, z
50+
return *xy.T, z, '0.8'
5151

5252

5353
def get_gaussian_hexs(mu=(0, 0),
@@ -59,14 +59,15 @@ def get_gaussian_hexs(mu=(0, 0),
5959
np.random.seed(seed)
6060
xy = np.random.multivariate_normal(mu, sigma, n)
6161
xyz, (xmin, xmax), (ymin, ymax), (nx, ny) = hexbin(*xy.T, gridsize=res)
62-
dxy = np.array([(xmax - xmin) / nx, (ymax - ymin) / ny / np.sqrt(3)]) * 0.9
62+
dxy = np.array([(xmax - xmin) / nx, (ymax - ymin) / ny / np.sqrt(3)]) * 0.95
6363
return *xyz, dxy
6464

6565

66-
bar3d_data_generators = {
67-
art3d.Bar3DCollection: get_gaussian_bars,
68-
art3d.HexBar3DCollection: get_gaussian_hexs
69-
}
66+
def get_bar3d_test_data():
67+
return {
68+
'rect': get_gaussian_bars(),
69+
'hex': get_gaussian_hexs()
70+
}
7071

7172

7273
@check_figures_equal(extensions=["png"])
@@ -254,102 +255,77 @@ def test_bar3d_lightsource():
254255
np.testing.assert_array_max_ulp(color, collection._facecolor3d[1::6], 4)
255256

256257

257-
@pytest.fixture(params=(art3d.Bar3DCollection, art3d.HexBar3DCollection))
258-
def bar3d_class(request):
259-
return request.param
260-
261-
262-
# @mpl3d_image_comparison(baseline_images=['bar3d_with_1d_data_Bar3DCollection.png',
263-
# 'bar3d_with_1d_data_HexBar3DCollection.png'])
264-
@pytest.mark.mpl_image_compare(style='default', remove_text=True)
265-
def test_bar3d_with_1d_data(bar3d_class):
266-
fig, ax = plt.subplots(subplot_kw={'projection': '3d'})
267-
_plot_bar3d(ax, bar3d_class, 0, 0, 1, ec='0.5', lw=0.5)
268-
return fig
269-
270-
271-
# @mpl3d_image_comparison(baseline_images=['bar3d_zsort_Bar3DCollection.png',
272-
# 'bar3d_zsort_HexBar3DCollection.png'])
273-
@pytest.mark.mpl_image_compare(style='default', remove_text=True)
274-
def test_bar3d_zsort(bar3d_class):
275-
fig, axes = plt.subplots(2, 4, subplot_kw={'projection': '3d'})
276-
elev = 45
277-
azim0, astep = -22.5, 45
278-
camera = itertools.product(np.r_[azim0:(180 + azim0):astep], (elev, -elev))
279-
# sourcery skip: no-loop-in-tests
280-
for ax, (azim, elev) in zip(axes.T.ravel(), camera):
281-
_plot_bar3d(ax, bar3d_class,
282-
[0, 1], [0, 1], [1, 2],
283-
azim=azim, elev=elev,
284-
ec='0.5', lw=0.5)
285-
return fig
286-
287-
288-
# @mpl3d_image_comparison(baseline_images=['bar3d_with_2d_data_Bar3DCollection.png',
289-
# 'bar3d_with_2d_data_HexBar3DCollection.png'])
290-
@pytest.mark.mpl_image_compare(style='default', remove_text=True)
291-
def test_bar3d_with_2d_data(bar3d_class):
292-
fig, ax = plt.subplots(subplot_kw={'projection': '3d'})
293-
_plot_bar3d(ax, bar3d_class, *bar3d_data_generators[bar3d_class](),
294-
ec='0.5', lw=0.5)
295-
return fig
296-
297-
298-
# @mpl3d_image_comparison(baseline_images=['bar3d_facecolors_Bar3DCollection-0.png',
299-
# 'bar3d_facecolors_Bar3DCollection-1.png',
300-
# 'bar3d_facecolors_HexBar3DCollection-0.png'
301-
# 'bar3d_facecolors_HexBar3DCollection-1.png'])
302-
@pytest.mark.parametrize('shade', (0, 1))
303-
@pytest.mark.mpl_image_compare(style='default', remove_text=True)
304-
def test_bar3d_facecolors(bar3d_class, shade):
305-
fig, ax = plt.subplots(subplot_kw={'projection': '3d'})
306-
307-
xyz = bar3d_data_generators[bar3d_class]()
308-
bars = _plot_bar3d(ax, bar3d_class, *xyz,
309-
facecolors=list(mcolors.CSS4_COLORS)[:xyz[0].size],
310-
edgecolors='0.5', lw=0.5,
311-
shade=shade)
312-
return fig
313-
314-
315-
# @mpl3d_image_comparison(baseline_images=['bar3d_cmap_Bar3DCollection-0.png',
316-
# 'bar3d_cmap_Bar3DCollection-1.png',
317-
# 'bar3d_cmap_HexBar3DCollection-0.png'
318-
# 'bar3d_cmap_HexBar3DCollection-1.png'])
319-
@pytest.mark.parametrize('shade', (0, 1))
320-
@pytest.mark.mpl_image_compare(style='default', remove_text=True)
321-
def test_bar3d_cmap(bar3d_class, shade):
322-
fig, ax = plt.subplots(subplot_kw={'projection': '3d'})
323-
324-
xyz = bar3d_data_generators[bar3d_class]()
325-
bars = _plot_bar3d(ax, bar3d_class, *xyz,
326-
cmap='viridis',
327-
shade=shade,
328-
edgecolors='0.5', lw=0.5)
329-
return fig
330-
331-
332-
def _plot_bar3d(ax, kls, x, y, z, dxy='0.8', azim=None, elev=None, **kws):
333-
334-
bars = kls(x, y, z, dxy=dxy, **kws)
335-
ax.add_collection(bars)
336-
337-
viewlim = np.array([(np.min(x), np.max(np.add(x, bars.dx))),
338-
(np.min(y), np.max(np.add(y, bars.dy))),
339-
(min(bars.z0, np.min(z)), np.max(z))])
340-
341-
if kls is art3d.HexBar3DCollection:
342-
viewlim[:2, 0] = viewlim[:2, 0] - np.array([bars.dx / 2, bars.dy / 2]).T
343-
344-
ax.auto_scale_xyz(*viewlim, False)
345-
# ax.set(xlabel='x', ylabel='y', zlabel='z')
346-
347-
if azim:
348-
ax.azim = azim
349-
if elev:
350-
ax.elev = elev
351-
352-
return bars
258+
@pytest.fixture(params=[get_bar3d_test_data])
259+
def bar3d_test_data(request):
260+
return request.param()
261+
262+
263+
class TestBar3D:
264+
265+
shapes = ('rect', 'hex')
266+
267+
def _plot_bar3d(self, ax, x, y, z, dxy, shape, azim=None, elev=None, **kws):
268+
269+
api_function = ax.hexbar3d if shape == 'hex' else ax.bar3d_grid
270+
bars = api_function(x, y, z, dxy, **kws)
271+
272+
if azim:
273+
ax.azim = azim
274+
if elev:
275+
ax.elev = elev
276+
277+
return bars
278+
279+
@mpl3d_image_comparison(['bar3d_with_1d_data.png'])
280+
def test_bar3d_with_1d_data(self):
281+
fig, axes = plt.subplots(1, 2, subplot_kw={'projection': '3d'})
282+
for ax, shape in zip(axes, self.shapes):
283+
self._plot_bar3d(ax, 0, 0, 1, '0.8', shape, ec='0.5', lw=0.5)
284+
285+
@mpl3d_image_comparison(['bar3d_zsort.png', 'bar3d_zsort_hex.png'])
286+
def test_bar3d_zsort(self):
287+
for shape in self.shapes:
288+
fig, axes = plt.subplots(2, 4, subplot_kw={'projection': '3d'})
289+
elev = 45
290+
azim0, astep = -22.5, 45
291+
camera = itertools.product(np.r_[azim0:(180 + azim0):astep],
292+
(elev, -elev))
293+
# sourcery skip: no-loop-in-tests
294+
for ax, (azim, elev) in zip(axes.T.ravel(), camera):
295+
self._plot_bar3d(ax,
296+
[0, 1], [0, 1], [1, 2],
297+
'0.8',
298+
shape,
299+
azim=azim, elev=elev,
300+
ec='0.5', lw=0.5)
301+
302+
@mpl3d_image_comparison(['bar3d_with_2d_data.png'])
303+
def test_bar3d_with_2d_data(self, bar3d_test_data):
304+
fig, axes = plt.subplots(1, 2, subplot_kw={'projection': '3d'})
305+
for ax, shape in zip(axes, self.shapes):
306+
x, y, z, dxy = bar3d_test_data[shape]
307+
self._plot_bar3d(ax, x, y, z, dxy, shape, ec='0.5', lw=0.5)
308+
309+
def _gen_bar3d_subplots(self, bar3d_test_data):
310+
config = dict(edgecolors='0.5', lw=0.5)
311+
fig, axes = plt.subplots(2, 2, subplot_kw={'projection': '3d'})
312+
for i, shape in enumerate(self.shapes):
313+
x, y, z, dxy = bar3d_test_data[shape]
314+
for j, shade in enumerate((0, 1)):
315+
yield (axes[i, j], x, y, z, dxy, shape), {**config, 'shade': shade}
316+
317+
@mpl3d_image_comparison(['bar3d_facecolors.png'])
318+
def test_bar3d_facecolors(self, bar3d_test_data):
319+
for (ax, x, y, z, dxy, shape), kws in self._gen_bar3d_subplots(bar3d_test_data):
320+
bars = self._plot_bar3d(
321+
ax, x, y, z, dxy, shape, **kws,
322+
facecolors=list(mcolors.CSS4_COLORS)[:x.size]
323+
)
324+
325+
@mpl3d_image_comparison(['bar3d_cmap.png'])
326+
def test_bar3d_cmap(self, bar3d_test_data):
327+
for (ax, x, y, z, dxy, shape), kws in self._gen_bar3d_subplots(bar3d_test_data):
328+
bars = self._plot_bar3d(ax, x, y, z, dxy, shape, cmap='viridis', **kws)
353329

354330

355331
@mpl3d_image_comparison(

0 commit comments

Comments
 (0)