File tree Expand file tree Collapse file tree 2 files changed +10
-2
lines changed Expand file tree Collapse file tree 2 files changed +10
-2
lines changed Original file line number Diff line number Diff line change @@ -385,7 +385,7 @@ def sample(self, n_samples=1):
385
385
386
386
_ , n_features = self .means_ .shape
387
387
rng = check_random_state (self .random_state )
388
- n_samples_comp = rng .multinomial (n_samples , self .weights_ ). astype ( int )
388
+ n_samples_comp = rng .multinomial (n_samples , self .weights_ )
389
389
390
390
if self .covariance_type == 'full' :
391
391
X = np .vstack ([
Original file line number Diff line number Diff line change @@ -935,7 +935,8 @@ def test_sample():
935
935
gmm .sample , 0 )
936
936
937
937
# Just to make sure the class samples correctly
938
- X_s , y_s = gmm .sample (20000 )
938
+ n_samples = 20000
939
+ X_s , y_s = gmm .sample (n_samples )
939
940
for k in range (n_features ):
940
941
if covar_type == 'full' :
941
942
assert_array_almost_equal (gmm .covariances_ [k ],
@@ -956,6 +957,13 @@ def test_sample():
956
957
for k in range (n_features )])
957
958
assert_array_almost_equal (gmm .means_ , means_s , decimal = 1 )
958
959
960
+ # Check that sizes that are drawn match what is requested
961
+ assert_equal (X_s .shape , (n_samples , n_components ))
962
+ for sample_size in range (1 , 50 ):
963
+ X_s , _ = gmm .sample (sample_size )
964
+ assert_equal (X_s .shape , (sample_size , n_components ))
965
+
966
+
959
967
960
968
@ignore_warnings (category = ConvergenceWarning )
961
969
def test_init ():
You can’t perform that action at this time.
0 commit comments