Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit ad6f094

Browse files
authored
[MRG+2] switch to multinomial composition for mixture sampling (scikit-learn#7702)
* Switch to multinomial composition for mixture sampling.
* Add shape assertions to test.
* Use n_components=3 to test the actual regression: n_components and n_features were equal, and one was used for the other in some places.
2 parents cd714b1 + 4e1c101 commit ad6f094

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

sklearn/mixture/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,7 @@ def sample(self, n_samples=1):
385385

386386
_, n_features = self.means_.shape
387387
rng = check_random_state(self.random_state)
388-
n_samples_comp = np.round(self.weights_ * n_samples).astype(int)
388+
n_samples_comp = rng.multinomial(n_samples, self.weights_)
389389

390390
if self.covariance_type == 'full':
391391
X = np.vstack([

sklearn/mixture/tests/test_gaussian_mixture.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -918,7 +918,7 @@ def test_property():
918918

919919
def test_sample():
920920
rng = np.random.RandomState(0)
921-
rand_data = RandomData(rng, scale=7)
921+
rand_data = RandomData(rng, scale=7, n_components=3)
922922
n_features, n_components = rand_data.n_features, rand_data.n_components
923923

924924
for covar_type in COVARIANCE_TYPE:
@@ -935,8 +935,10 @@ def test_sample():
935935
gmm.sample, 0)
936936

937937
# Just to make sure the class samples correctly
938-
X_s, y_s = gmm.sample(20000)
939-
for k in range(n_features):
938+
n_samples = 20000
939+
X_s, y_s = gmm.sample(n_samples)
940+
941+
for k in range(n_components):
940942
if covar_type == 'full':
941943
assert_array_almost_equal(gmm.covariances_[k],
942944
np.cov(X_s[y_s == k].T), decimal=1)
@@ -953,9 +955,17 @@ def test_sample():
953955
decimal=1)
954956

955957
means_s = np.array([np.mean(X_s[y_s == k], 0)
956-
for k in range(n_features)])
958+
for k in range(n_components)])
957959
assert_array_almost_equal(gmm.means_, means_s, decimal=1)
958960

961+
# Check shapes of sampled data, see
962+
# https://github.com/scikit-learn/scikit-learn/issues/7701
963+
assert_equal(X_s.shape, (n_samples, n_features))
964+
965+
for sample_size in range(1, 100):
966+
X_s, _ = gmm.sample(sample_size)
967+
assert_equal(X_s.shape, (sample_size, n_features))
968+
959969

960970
@ignore_warnings(category=ConvergenceWarning)
961971
def test_init():

0 commit comments

Comments
 (0)