Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 5b5bf70

Browse files
committed
TST: add image tests
FIX: remove scipy dependency FIX: y scaling correction FIX: whitespace
1 parent 2e9bc18 commit 5b5bf70

File tree

7 files changed

+111
-1
lines changed

7 files changed

+111
-1
lines changed

lib/mpl_toolkits/mplot3d/tests/test_axes3d.py

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import itertools
33
import platform
44

5+
import numpy as np
56
import pytest
67

78
from mpl_toolkits.mplot3d import Axes3D, axes3d, proj3d, art3d
@@ -14,6 +15,7 @@
1415
from matplotlib.testing.decorators import image_comparison, check_figures_equal
1516
from matplotlib.testing.widgets import mock_event
1617
from matplotlib.collections import LineCollection, PolyCollection
18+
from matplotlib.cbook import hexbin
1719
from matplotlib.patches import Circle, PathPatch
1820
from matplotlib.path import Path
1921
from matplotlib.text import Text
@@ -32,7 +34,42 @@ def plot_cuboid(ax, scale):
3234
pts = itertools.combinations(np.array(list(itertools.product(r, r, r))), 2)
3335
for start, end in pts:
3436
if np.sum(np.abs(start - end)) == r[1] - r[0]:
35-
ax.plot3D(*zip(start*np.array(scale), end*np.array(scale)))
37+
ax.plot3D(*zip(start * np.array(scale), end * np.array(scale)))
38+
39+
40+
def get_gaussian_bars(mu=(0, 0),
41+
sigma=([0.8, 0.3],
42+
[0.3, 0.5]),
43+
range=(-3, 3),
44+
res=8,
45+
seed=123):
46+
np.random.seed(seed)
47+
sl = slice(*range, complex(res))
48+
xy = np.array(np.mgrid[sl, sl][::-1]).T - mu
49+
p = np.linalg.inv(sigma)
50+
exp = np.sum(np.moveaxis(xy.T, 0, 1) * (p @ np.moveaxis(xy, 0, -1)), 1)
51+
z = np.exp(-exp / 2) / np.sqrt(np.linalg.det(sigma)) / np.pi / 2
52+
return *xy.T, z, '0.8'
53+
54+
55+
def get_gaussian_hexs(mu=(0, 0),
56+
sigma=([0.8, 0.3],
57+
[0.3, 0.5]),
58+
n=10_000,
59+
res=8,
60+
seed=123):
61+
np.random.seed(seed)
62+
xy = np.random.multivariate_normal(mu, sigma, n)
63+
xyz, (xmin, xmax), (ymin, ymax), (nx, ny) = hexbin(*xy.T, gridsize=res)
64+
dxy = np.array([(xmax - xmin) / nx, (ymax - ymin) / ny / np.sqrt(3)]) * 0.95
65+
return *xyz, dxy
66+
67+
68+
def get_bar3d_test_data():
69+
return {
70+
'rect': get_gaussian_bars(),
71+
'hex': get_gaussian_hexs()
72+
}
3673

3774

3875
@check_figures_equal(extensions=["png"])
@@ -220,6 +257,79 @@ def test_bar3d_lightsource():
220257
np.testing.assert_array_max_ulp(color, collection._facecolor3d[1::6], 4)
221258

222259

260+
@pytest.fixture(params=[get_bar3d_test_data])
261+
def bar3d_test_data(request):
262+
return request.param()
263+
264+
265+
class TestBar3D:
266+
267+
shapes = ('rect', 'hex')
268+
269+
def _plot_bar3d(self, ax, x, y, z, dxy, shape, azim=None, elev=None, **kws):
270+
271+
api_function = ax.hexbar3d if shape == 'hex' else ax.bar3d_grid
272+
bars = api_function(x, y, z, dxy, **kws)
273+
274+
if azim:
275+
ax.azim = azim
276+
if elev:
277+
ax.elev = elev
278+
279+
return bars
280+
281+
@mpl3d_image_comparison(['bar3d_with_1d_data.png'])
282+
def test_bar3d_with_1d_data(self):
283+
fig, axes = plt.subplots(1, 2, subplot_kw={'projection': '3d'})
284+
for ax, shape in zip(axes, self.shapes):
285+
self._plot_bar3d(ax, 0, 0, 1, '0.8', shape, ec='0.5', lw=0.5)
286+
287+
@mpl3d_image_comparison(['bar3d_zsort.png', 'bar3d_zsort_hex.png'])
288+
def test_bar3d_zsort(self):
289+
for shape in self.shapes:
290+
fig, axes = plt.subplots(2, 4, subplot_kw={'projection': '3d'})
291+
elev = 45
292+
azim0, astep = -22.5, 45
293+
camera = itertools.product(np.r_[azim0:(180 + azim0):astep],
294+
(elev, -elev))
295+
# sourcery skip: no-loop-in-tests
296+
for ax, (azim, elev) in zip(axes.T.ravel(), camera):
297+
self._plot_bar3d(ax,
298+
[0, 1], [0, 1], [1, 2],
299+
'0.8',
300+
shape,
301+
azim=azim, elev=elev,
302+
ec='0.5', lw=0.5)
303+
304+
@mpl3d_image_comparison(['bar3d_with_2d_data.png'])
305+
def test_bar3d_with_2d_data(self, bar3d_test_data):
306+
fig, axes = plt.subplots(1, 2, subplot_kw={'projection': '3d'})
307+
for ax, shape in zip(axes, self.shapes):
308+
x, y, z, dxy = bar3d_test_data[shape]
309+
self._plot_bar3d(ax, x, y, z, dxy, shape, ec='0.5', lw=0.5)
310+
311+
def _gen_bar3d_subplots(self, bar3d_test_data):
312+
config = dict(edgecolors='0.5', lw=0.5)
313+
fig, axes = plt.subplots(2, 2, subplot_kw={'projection': '3d'})
314+
for i, shape in enumerate(self.shapes):
315+
x, y, z, dxy = bar3d_test_data[shape]
316+
for j, shade in enumerate((0, 1)):
317+
yield (axes[i, j], x, y, z, dxy, shape), {**config, 'shade': shade}
318+
319+
@mpl3d_image_comparison(['bar3d_facecolors.png'])
320+
def test_bar3d_facecolors(self, bar3d_test_data):
321+
for (ax, x, y, z, dxy, shape), kws in self._gen_bar3d_subplots(bar3d_test_data):
322+
bars = self._plot_bar3d(
323+
ax, x, y, z, dxy, shape, **kws,
324+
facecolors=list(mcolors.CSS4_COLORS)[:x.size]
325+
)
326+
327+
@mpl3d_image_comparison(['bar3d_cmap.png'])
328+
def test_bar3d_cmap(self, bar3d_test_data):
329+
for (ax, x, y, z, dxy, shape), kws in self._gen_bar3d_subplots(bar3d_test_data):
330+
bars = self._plot_bar3d(ax, x, y, z, dxy, shape, cmap='viridis', **kws)
331+
332+
223333
@mpl3d_image_comparison(
224334
['contour3d.png'], style='mpl20',
225335
tol=0.002 if platform.machine() in ('aarch64', 'ppc64le', 's390x') else 0)

0 commit comments

Comments
 (0)