Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit bc2cbbb

Browse files
committed
FIX: remove scipy dependency
1 parent 2b4ffed commit bc2cbbb

File tree

1 file changed

+8
-13
lines changed

1 file changed

+8
-13
lines changed

lib/mpl_toolkits/mplot3d/tests/test_axes3d.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import numpy as np
66
import pytest
7-
from scipy.stats import multivariate_normal
87

98
from mpl_toolkits.mplot3d import Axes3D, axes3d, proj3d, art3d
109
from mpl_toolkits.mplot3d.axes3d import _Quaternion as Quaternion
@@ -40,15 +39,15 @@ def get_gaussian_bars(mu=(0, 0),
4039
sigma=([0.8, 0.3],
4140
[0.3, 0.5]),
4241
range=(-3, 3),
43-
res=2 ** 3,
42+
res=8,
4443
seed=123):
4544
np.random.seed(seed)
46-
rv = multivariate_normal(mu, np.array(sigma))
4745
sl = slice(*range, complex(res))
48-
xy = np.array(np.mgrid[sl, sl][::-1])
49-
z = rv.pdf(xy.transpose(1, 2, 0)).T
50-
51-
return *xy, z
46+
xy = np.array(np.mgrid[sl, sl][::-1]).T - mu
47+
p = np.linalg.inv(sigma)
48+
exp = np.sum(np.moveaxis(xy.T, 0, 1) * (p @ np.moveaxis(xy, 0, -1)), 1)
49+
z = np.exp(-exp / 2) / np.sqrt(np.linalg.det(sigma)) / np.pi / 2
50+
return *xy.T, z
5251

5352

5453
def get_gaussian_hexs(mu=(0, 0),
@@ -58,8 +57,8 @@ def get_gaussian_hexs(mu=(0, 0),
5857
res=8,
5958
seed=123):
6059
np.random.seed(seed)
61-
rv = multivariate_normal(mu, np.array(sigma))
62-
xyz, (xmin, xmax), (ymin, ymax), (nx, ny) = hexbin(*rv.rvs(n).T, gridsize=res)
60+
xy = np.random.multivariate_normal(mu, sigma, n)
61+
xyz, (xmin, xmax), (ymin, ymax), (nx, ny) = hexbin(*xy.T, gridsize=res)
6362
dxy = np.array([(xmax - xmin) / nx, (ymax - ymin) / ny]) * 0.9
6463
return *xyz, dxy
6564

@@ -299,8 +298,6 @@ def test_bar3d_with_2d_data(bar3d_class):
299298
# 'bar3d_facecolors_Bar3DCollection-1.png',
300299
# 'bar3d_facecolors_HexBar3DCollection-0.png'
301300
# 'bar3d_facecolors_HexBar3DCollection-1.png'])
302-
303-
304301
@pytest.mark.parametrize('shade', (0, 1))
305302
@pytest.mark.mpl_image_compare(style='default', remove_text=True)
306303
def test_bar3d_facecolors(bar3d_class, shade):
@@ -317,8 +314,6 @@ def test_bar3d_facecolors(bar3d_class, shade):
317314
# 'bar3d_cmap_Bar3DCollection-1.png',
318315
# 'bar3d_cmap_HexBar3DCollection-0.png'
319316
# 'bar3d_cmap_HexBar3DCollection-1.png'])
320-
321-
322317
@pytest.mark.parametrize('shade', (0, 1))
323318
@pytest.mark.mpl_image_compare(style='default', remove_text=True)
324319
def test_bar3d_cmap(bar3d_class, shade):

0 commit comments

Comments
 (0)