Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 082c7f1

Browse files
committed
Merge pull request #29 from ogrisel/enh-pr-5540
TST improved test for df param
2 parents 78067ea + 712fe1f commit 082c7f1

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

sklearn/tests/test_base.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -237,13 +237,9 @@ def test_score_sample_weight():
237237

238238

239239
def test_clone_pandas_dataframe():
240-
class MockDataFrameWithEq(MockDataFrame):
241-
"""implemenets __eq__ operator to leverage the current test."""
242-
def __eq__(self, other):
243-
return self.array == other.array
244240

245241
class DummyEstimator(BaseEstimator, TransformerMixin):
246-
"""This is a dummpy class for generating numerical features
242+
"""This is a dummy class for generating numerical features
247243
248244
This feature extractor extracts numerical features from pandas data
249245
frame.
@@ -257,9 +253,9 @@ class DummyEstimator(BaseEstimator, TransformerMixin):
257253
Notes
258254
-----
259255
"""
260-
def __init__(self, df, toto):
256+
def __init__(self, df=None, scalar_param=1):
261257
self.df = df
262-
self.toto = toto
258+
self.scalar_param = scalar_param
263259

264260
def fit(self, X, y=None):
265261
pass
@@ -269,9 +265,10 @@ def transform(self, X, y=None):
269265

270266
# build and clone estimator
271267
d = np.arange(10)
272-
df = MockDataFrameWithEq(d)
273-
e = DummyEstimator(df, toto=1)
268+
df = MockDataFrame(d)
269+
e = DummyEstimator(df, scalar_param=1)
274270
cloned_e = clone(e)
275271

276272
# the test
277-
assert_equal(e.toto, cloned_e.toto)
273+
assert_true((e.df == cloned_e.df).values.all())
274+
assert_equal(e.scalar_param, cloned_e.scalar_param)

sklearn/utils/mocking.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class MockDataFrame(object):
1818
# have shape an length but don't support indexing.
1919
def __init__(self, array):
2020
self.array = array
21+
self.values = array
2122
self.shape = array.shape
2223
self.ndim = array.ndim
2324
# ugly hack to make iloc work.
@@ -32,6 +33,9 @@ def __array__(self):
3233
# method.
3334
return self.array
3435

36+
def __eq__(self, other):
37+
return MockDataFrame(self.array == other.array)
38+
3539

3640
class CheckingClassifier(BaseEstimator, ClassifierMixin):
3741
"""Dummy classifier to test pipelining and meta-estimators.

0 commit comments

Comments
 (0)