Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 55672f9

Browse files
committed
add shape assertions to test
1 parent cfc280d commit 55672f9

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

sklearn/mixture/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,7 @@ def sample(self, n_samples=1):
385385

386386
_, n_features = self.means_.shape
387387
rng = check_random_state(self.random_state)
388-
n_samples_comp = rng.multinomial(n_samples, self.weights_).astype(int)
388+
n_samples_comp = rng.multinomial(n_samples, self.weights_)
389389

390390
if self.covariance_type == 'full':
391391
X = np.vstack([

sklearn/mixture/tests/test_gaussian_mixture.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -935,7 +935,8 @@ def test_sample():
935935
gmm.sample, 0)
936936

937937
# Just to make sure the class samples correctly
938-
X_s, y_s = gmm.sample(20000)
938+
n_samples = 20000
939+
X_s, y_s = gmm.sample(n_samples)
939940
for k in range(n_features):
940941
if covar_type == 'full':
941942
assert_array_almost_equal(gmm.covariances_[k],
@@ -956,6 +957,13 @@ def test_sample():
956957
for k in range(n_features)])
957958
assert_array_almost_equal(gmm.means_, means_s, decimal=1)
958959

960+
# Check that sizes that are drawn match what is requested
961+
assert_equal(X_s.shape, (n_samples, n_components))
962+
for sample_size in range(1, 50):
963+
X_s, _ = gmm.sample(sample_size)
964+
assert_equal(X_s.shape, (sample_size, n_components))
965+
966+
959967

960968
@ignore_warnings(category=ConvergenceWarning)
961969
def test_init():

0 commit comments

Comments
 (0)