4
4
5
5
import numpy as np
6
6
import pytest
7
- from scipy .stats import multivariate_normal
8
7
9
8
from mpl_toolkits .mplot3d import Axes3D , axes3d , proj3d , art3d
10
9
from mpl_toolkits .mplot3d .axes3d import _Quaternion as Quaternion
@@ -40,15 +39,15 @@ def get_gaussian_bars(mu=(0, 0),
40
39
sigma = ([0.8 , 0.3 ],
41
40
[0.3 , 0.5 ]),
42
41
range = (- 3 , 3 ),
43
- res = 2 ** 3 ,
42
+ res = 8 ,
44
43
seed = 123 ):
45
44
np .random .seed (seed )
46
- rv = multivariate_normal (mu , np .array (sigma ))
47
45
sl = slice (* range , complex (res ))
48
- xy = np .array (np .mgrid [sl , sl ][::- 1 ])
49
- z = rv .pdf (xy .transpose (1 , 2 , 0 )).T
50
-
51
- return * xy , z
46
+ xy = np .array (np .mgrid [sl , sl ][::- 1 ]).T - mu
47
+ p = np .linalg .inv (sigma )
48
+ exp = np .sum (np .moveaxis (xy .T , 0 , 1 ) * (p @ np .moveaxis (xy , 0 , - 1 )), 1 )
49
+ z = np .exp (- exp / 2 ) / np .sqrt (np .linalg .det (sigma )) / np .pi / 2
50
+ return * xy .T , z
52
51
53
52
54
53
def get_gaussian_hexs (mu = (0 , 0 ),
@@ -58,8 +57,8 @@ def get_gaussian_hexs(mu=(0, 0),
58
57
res = 8 ,
59
58
seed = 123 ):
60
59
np .random .seed (seed )
61
- rv = multivariate_normal (mu , np . array ( sigma ) )
62
- xyz , (xmin , xmax ), (ymin , ymax ), (nx , ny ) = hexbin (* rv . rvs ( n ) .T , gridsize = res )
60
+ xy = np . random . multivariate_normal (mu , sigma , n )
61
+ xyz , (xmin , xmax ), (ymin , ymax ), (nx , ny ) = hexbin (* xy .T , gridsize = res )
63
62
dxy = np .array ([(xmax - xmin ) / nx , (ymax - ymin ) / ny ]) * 0.9
64
63
return * xyz , dxy
65
64
@@ -299,8 +298,6 @@ def test_bar3d_with_2d_data(bar3d_class):
299
298
# 'bar3d_facecolors_Bar3DCollection-1.png',
300
299
# 'bar3d_facecolors_HexBar3DCollection-0.png'
301
300
# 'bar3d_facecolors_HexBar3DCollection-1.png'])
302
-
303
-
304
301
@pytest .mark .parametrize ('shade' , (0 , 1 ))
305
302
@pytest .mark .mpl_image_compare (style = 'default' , remove_text = True )
306
303
def test_bar3d_facecolors (bar3d_class , shade ):
@@ -317,8 +314,6 @@ def test_bar3d_facecolors(bar3d_class, shade):
317
314
# 'bar3d_cmap_Bar3DCollection-1.png',
318
315
# 'bar3d_cmap_HexBar3DCollection-0.png'
319
316
# 'bar3d_cmap_HexBar3DCollection-1.png'])
320
-
321
-
322
317
@pytest .mark .parametrize ('shade' , (0 , 1 ))
323
318
@pytest .mark .mpl_image_compare (style = 'default' , remove_text = True )
324
319
def test_bar3d_cmap (bar3d_class , shade ):
0 commit comments