
Commit d1371e5

TEST: take comments into account

1 parent 78974de

1 file changed: 8 additions, 15 deletions

sklearn/ensemble/tests/test_forest.py
@@ -206,21 +206,19 @@ def check_importances(X, y, name, criterion):
     assert_less(0 < X_new.shape[1], X.shape[1])
 
     # Check with sample weights
-    sample_weight = np.ones(y.shape)
-    sample_weight[y == 1] *= 100
-
+    sample_weight = check_random_state(0).randint(1, 10, len(X))
     est = ForestEstimator(n_estimators=20, random_state=0,
                           criterion=criterion)
     est.fit(X, y, sample_weight=sample_weight)
     importances = est.feature_importances_
     assert_true(np.all(importances >= 0.0))
 
-    for scale in [3, 10, 1000, 100000]:
+    for scale in [10, 100, 1000]:
         est = ForestEstimator(n_estimators=20, random_state=0,
                               criterion=criterion)
         est.fit(X, y, sample_weight=scale * sample_weight)
         importances_bis = est.feature_importances_
-        assert_almost_equal(importances, importances_bis)
+        assert_less(np.abs(importances - importances_bis).mean(), 0.0001)
 
 
 def test_importances():
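The loop above asserts that feature importances are approximately invariant under a global rescaling of the sample weights; comparing the mean absolute difference against a small tolerance allows for the slight fit-to-fit variations that exact equality would reject. A minimal self-contained sketch of that property (the make_classification parameters below are illustrative assumptions, not the test's exact fixture):

import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.utils import check_random_state

# Small classification problem with a few informative features.
X, y = make_classification(n_samples=500, n_features=10, n_informative=3,
                           random_state=0, shuffle=False)
sample_weight = check_random_state(0).randint(1, 10, len(X))

est = RandomForestClassifier(n_estimators=20, random_state=0)
est.fit(X, y, sample_weight=sample_weight)
importances = est.feature_importances_

for scale in [10, 100, 1000]:
    est = RandomForestClassifier(n_estimators=20, random_state=0)
    est.fit(X, y, sample_weight=scale * sample_weight)
    importances_bis = est.feature_importances_
    # Tree induction depends only on relative weights, so importances
    # should match up to small numerical differences.
    assert np.abs(importances - importances_bis).mean() < 0.0001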
@@ -232,7 +230,7 @@ def test_importances():
     for name, criterion in product(FOREST_CLASSIFIERS, ["gini", "entropy"]):
         yield check_importances, X, y, name, criterion
 
-    for name, criterion in product(FOREST_REGRESSORS, ["mse"]):
+    for name, criterion in product(FOREST_REGRESSORS, ["mse", "friedman_mse"]):
         yield check_importances, X, y, name, criterion
 
 
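The regressor sweep now also covers the "friedman_mse" split criterion, which uses mean squared error with Friedman's improvement score for potential splits. For illustration, a forest regressor would be configured with it roughly like this (the estimator choice here is an example, not part of the diff):

from sklearn.ensemble import RandomForestRegressor

# Select Friedman's improvement score as the split criterion.
est = RandomForestRegressor(n_estimators=20, criterion="friedman_mse",
                            random_state=0)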

@@ -242,10 +240,7 @@ def test_importances_asymptotic():
     # Understanding variable importances in forests of randomized trees, 2013).
 
     def binomial(k, n):
-        if k < 0 or k > n:
-            return 0
-        else:
-            return comb(int(n), int(k), exact=True)
+        return 0 if k < 0 or k > n else comb(int(n), int(k), exact=True)
 
     def entropy(samples):
         e = 0.
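The rewritten binomial helper collapses the range guard and the scipy call into one conditional expression. A standalone equivalent (importing comb from scipy.special, where the function lives in current SciPy releases):

from scipy.special import comb

def binomial(k, n):
    # C(n, k), defined as 0 outside the valid range 0 <= k <= n
    return 0 if k < 0 or k > n else comb(int(n), int(k), exact=True)

assert binomial(2, 4) == 6
assert binomial(5, 4) == 0  # out of range yields 0 instead of raising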
@@ -263,11 +258,9 @@ def mdi_importance(X_m, X, y):
 
         variables = list(range(p))
         variables.pop(X_m)
-        imp = 0.
+        values = [np.unique(X[:, i]) for i in range(p)]
 
-        values = []
-        for i in range(p):
-            values.append(np.unique(X[:, i]))
+        imp = 0.
 
         for k in range(p):
             # Weight of each B of size k
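The refactor swaps a loop-and-append for an equivalent list comprehension that collects the unique values of each feature column; for example, on a toy matrix:

import numpy as np

X = np.array([[0, 1],
              [0, 2],
              [1, 1]])
p = X.shape[1]

# One entry per column, holding that column's sorted unique values.
values = [np.unique(X[:, i]) for i in range(p)]
# values[0] -> array([0, 1]); values[1] -> array([1, 2])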
@@ -331,7 +324,7 @@ def mdi_importance(X_m, X, y):
 
     # Check correctness
     assert_almost_equal(entropy(y), sum(importances))
-    assert_less(((true_importances - importances) ** 2).sum(), 0.0005)
+    assert_less(np.abs(true_importances - importances).mean(), 0.01)
 
 
 def check_unfitted_feature_importances(name):
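The final correctness check moves from a summed squared error (tolerance 0.0005) to a mean absolute error (tolerance 0.01). On made-up numbers, the two styles compare like this:

import numpy as np

# Illustrative values only, not the test's actual importances.
true_importances = np.array([0.30, 0.30, 0.30, 0.00, 0.10])
importances = np.array([0.29, 0.31, 0.295, 0.005, 0.10])

old_style = ((true_importances - importances) ** 2).sum()  # ~0.00025
new_style = np.abs(true_importances - importances).mean()  # ~0.006
assert old_style < 0.0005 and new_style < 0.01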
