[MRG] Fix pass sample weights to final estimator #15773

Merged 33 commits on Dec 9, 2019
33 commits
08eb587
added versionadded to three algorithm classes
J-A16 Nov 20, 2019
cb36b47
added versionadded to three algorithm classes
J-A16 Nov 20, 2019
2af9349
added versionadded to three algorithm classes
J-A16 Nov 20, 2019
613347f
Update _regression.py
J-A16 Nov 20, 2019
8ebdaf2
Update _regression.py
J-A16 Nov 20, 2019
f081b86
Update _unsupervised.py
J-A16 Nov 20, 2019
5ffdccd
added versionadded to three algorithm classes
J-A16 Nov 20, 2019
4ceca1a
Merge branch 'master' of https://github.com/J-A16/scikit-learn
J-A16 Nov 20, 2019
6e4fc9e
Update _regression.py
J-A16 Nov 20, 2019
090e2a7
Update _regression.py
J-A16 Nov 20, 2019
5df84b8
added versionadded comment to NearestNeighbors, KNeighborsRegressor a…
J-A16 Nov 23, 2019
9d8446b
Merge branch 'master' of https://github.com/J-A16/scikit-learn
J-A16 Nov 23, 2019
db7afac
added versionadded comment to NearestNeighbors, KNeighborsRegressor a…
J-A16 Nov 23, 2019
f40de4d
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn
J-A16 Nov 24, 2019
ac093b7
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn
J-A16 Nov 27, 2019
e0d78ce
Passed sample_weight to base_estimator.fit()
J-A16 Dec 3, 2019
2c1a103
Fix pass sample weights to final estimator
J-A16 Dec 4, 2019
2cc4e1d
Fix pass sample weights to final estimator
J-A16 Dec 4, 2019
af91dbc
added test to test_ransac.py
J-A16 Dec 4, 2019
4b33817
added test to test_ransac.py
J-A16 Dec 4, 2019
9b62a7b
Linter changes
J-A16 Dec 4, 2019
816f7a5
Undo
J-A16 Dec 4, 2019
14d8185
better test
J-A16 Dec 4, 2019
57e361e
lint correction
J-A16 Dec 4, 2019
2bc9103
added imports
J-A16 Dec 4, 2019
2a6638d
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn
J-A16 Dec 4, 2019
4b7a474
added an entry to what's new
J-A16 Dec 4, 2019
5251f0c
fixed title underline length
J-A16 Dec 4, 2019
125dade
Update doc/whats_new/v0.23.rst
J-A16 Dec 4, 2019
fb3e5bc
Update doc/whats_new/v0.23.rst
J-A16 Dec 4, 2019
16b9c8a
fixed what's new entry
J-A16 Dec 7, 2019
2de97b8
resolve conflict
J-A16 Dec 9, 2019
6b2be6f
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn
J-A16 Dec 9, 2019
8 changes: 8 additions & 0 deletions doc/whats_new/v0.23.rst
@@ -58,6 +58,14 @@ Changelog
   :func:`datasets.make_moons` now accept two-element tuple.
   :pr:`15707` by :user:`Maciej J Mikulski <mjmikulski>`

+:mod:`sklearn.linear_model`
+...........................
+
+- |Fix| Fixed a bug where if a `sample_weight` parameter was passed to the fit
+  method of :class:`linear_model.RANSACRegressor`, it would not be passed to
+  the wrapped `base_estimator` during the fitting of the final model.
+  :pr:`15573` by :user:`Jeremy Alexandre <J-A16>`.
+
 :mod:`sklearn.preprocessing`
 ............................

10 changes: 9 additions & 1 deletion sklearn/linear_model/_ransac.py
@@ -328,6 +328,7 @@ def fit(self, X, y, sample_weight=None):
         inlier_mask_best = None
         X_inlier_best = None
         y_inlier_best = None
+        inlier_best_idxs_subset = None
         self.n_skips_no_inliers_ = 0
         self.n_skips_invalid_data_ = 0
         self.n_skips_invalid_model_ = 0
@@ -404,6 +405,7 @@ def fit(self, X, y, sample_weight=None):
             inlier_mask_best = inlier_mask_subset
             X_inlier_best = X_inlier_subset
             y_inlier_best = y_inlier_subset
+            inlier_best_idxs_subset = inlier_idxs_subset

             max_trials = min(
                 max_trials,
@@ -441,7 +443,13 @@ def fit(self, X, y, sample_weight=None):
                           ConvergenceWarning)

         # estimate final model using all inliers
-        base_estimator.fit(X_inlier_best, y_inlier_best)
+        if sample_weight is None:
+            base_estimator.fit(X_inlier_best, y_inlier_best)
+        else:
+            base_estimator.fit(
+                X_inlier_best,
+                y_inlier_best,
+                sample_weight=sample_weight[inlier_best_idxs_subset])

         self.estimator_ = base_estimator
         self.inlier_mask_ = inlier_mask_best
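
The change to the final fit can be illustrated with a small stand-alone sketch. It is not scikit-learn code: `weighted_line_fit` is a hypothetical helper standing in for `base_estimator.fit`, and the data, weights and `inlier_idxs` are made-up values. It only shows why subsetting `sample_weight` by the best inlier indices and passing it on changes the fitted model, compared with the old call that dropped the weights.

```python
# Toy illustration only: a weighted 1-D least-squares fit stands in for
# base_estimator.fit. None of these names come from scikit-learn.

def weighted_line_fit(xs, ys, ws=None):
    """Return (slope, intercept) minimising sum(w * (y - a*x - b)**2)."""
    if ws is None:
        ws = [1.0] * len(xs)  # unweighted fit: every sample counts equally
    total = sum(ws)
    x_mean = sum(w * x for w, x in zip(ws, xs)) / total
    y_mean = sum(w * y for w, y in zip(ws, ys)) / total
    sxy = sum(w * (x - x_mean) * (y - y_mean) for w, x, y in zip(ws, xs, ys))
    sxx = sum(w * (x - x_mean) ** 2 for w, x in zip(ws, xs))
    slope = sxy / sxx
    return slope, y_mean - slope * x_mean


xs = [0.0, 1.0, 2.0, 3.0, 4.0]
ys = [0.0, 1.0, 2.5, 3.0, 40.0]            # the last point is an outlier
sample_weight = [1.0, 1.0, 5.0, 1.0, 1.0]  # made-up per-sample weights
inlier_idxs = [0, 1, 2, 3]                 # hypothetical best consensus set

# Subset data and weights with the same indices so they stay aligned.
x_in = [xs[i] for i in inlier_idxs]
y_in = [ys[i] for i in inlier_idxs]
w_in = [sample_weight[i] for i in inlier_idxs]

unweighted = weighted_line_fit(x_in, y_in)      # final fit without weights
weighted = weighted_line_fit(x_in, y_in, w_in)  # final fit with subset weights
print(unweighted, weighted)
```

The two fits give different slopes, which is the kind of discrepancy the new test in this PR is meant to catch.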
20 changes: 20 additions & 0 deletions sklearn/linear_model/tests/test_ransac.py
@@ -10,6 +10,8 @@
 from sklearn.utils._testing import assert_almost_equal
 from sklearn.utils._testing import assert_raises_regexp
 from sklearn.utils._testing import assert_raises
+from sklearn.utils._testing import assert_allclose
+from sklearn.datasets import make_regression
 from sklearn.linear_model import LinearRegression, RANSACRegressor, Lasso
 from sklearn.linear_model._ransac import _dynamic_max_trials
 from sklearn.exceptions import ConvergenceWarning
@@ -494,3 +496,21 @@ def test_ransac_fit_sample_weight():
     base_estimator = Lasso()
     ransac_estimator = RANSACRegressor(base_estimator)
     assert_raises(ValueError, ransac_estimator.fit, X, y, weights)
+
+
+def test_ransac_final_model_fit_sample_weight():
+    X, y = make_regression(n_samples=1000, random_state=10)
+    rng = check_random_state(42)
+    sample_weight = rng.randint(1, 4, size=y.shape[0])
+    sample_weight = sample_weight / sample_weight.sum()
+    ransac = RANSACRegressor(base_estimator=LinearRegression(), random_state=0)
+    ransac.fit(X, y, sample_weight=sample_weight)
+
+    final_model = LinearRegression()
+    mask_samples = ransac.inlier_mask_
+    final_model.fit(
+        X[mask_samples], y[mask_samples],
+        sample_weight=sample_weight[mask_samples]
+    )
+
+    assert_allclose(ransac.estimator_.coef_, final_model.coef_)
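
The test selects weights with the boolean `inlier_mask_`, whereas the patch indexes `sample_weight` with the integer array `inlier_best_idxs_subset`. The comparison is only meaningful if both select the same samples. The sketch below uses plain lists and made-up values, and assumes the index array holds the positions of the `True` entries of the mask. The variable names suggest this, but the hunks above do not show it.

```python
# Hypothetical values; only the relationship between the two selectors matters.
sample_weight = [0.1, 0.4, 0.2, 0.3, 0.5]
inlier_mask = [True, False, True, True, False]  # plays the role of inlier_mask_

# Integer positions of the True entries (assumed analogue of the index array).
inlier_idxs = [i for i, keep in enumerate(inlier_mask) if keep]

# Selecting by mask (as the test does) and by index (as the patch does).
weights_by_mask = [w for w, keep in zip(sample_weight, inlier_mask) if keep]
weights_by_idx = [sample_weight[i] for i in inlier_idxs]

print(weights_by_mask == weights_by_idx)
```

Under that assumption both selections return the same weights in the same order, so the reference model in the test is fitted on exactly the data the patched final fit sees.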