[MRG] class_weight for Bagging, AdaBoost & GradientBoosting classifiers #4114

Closed
8 changes: 6 additions & 2 deletions doc/whats_new.rst
@@ -154,8 +154,12 @@ Enhancements
faster in general. By `Joel Nothman`_.

- Add ``class_weight`` parameter to automatically weight samples by class
frequency for :class:`ensemble.RandomForestClassifier`,
:class:`tree.DecisionTreeClassifier`, :class:`ensemble.ExtraTreesClassifier`
frequency for :class:`ensemble.AdaBoostClassifier`,
:class:`ensemble.BaggingClassifier`,
:class:`ensemble.ExtraTreesClassifier`,
:class:`ensemble.GradientBoostingClassifier`,
:class:`ensemble.RandomForestClassifier`,
:class:`tree.DecisionTreeClassifier`,
and :class:`tree.ExtraTreeClassifier`. By `Trevor Stephens`_.

Documentation improvements
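For context, here is a minimal sketch of how the proposed parameter would be used. It follows the API in this branch; ``class_weight`` on ``BaggingClassifier`` is this PR's proposal and may not exist in a released scikit-learn, so treat the snippet as illustrative only.

# Illustrative only: uses the class_weight argument as proposed in this PR.
from sklearn.datasets import make_classification
from sklearn.ensemble import BaggingClassifier
from sklearn.tree import DecisionTreeClassifier

# A deliberately imbalanced toy problem (any two-class data would do).
X, y = make_classification(n_samples=200, weights=[0.9, 0.1], random_state=0)

clf = BaggingClassifier(
    base_estimator=DecisionTreeClassifier(),
    n_estimators=10,
    class_weight='auto',   # proposed preset: inverse class-frequency weights
    random_state=0)
clf.fit(X, y)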
89 changes: 80 additions & 9 deletions sklearn/ensemble/bagging.py
@@ -13,11 +13,12 @@

from ..base import ClassifierMixin, RegressorMixin
from ..externals.joblib import Parallel, delayed
from ..externals.six import with_metaclass
from ..externals.six import with_metaclass, string_types
from ..externals.six.moves import zip
from ..metrics import r2_score, accuracy_score
from ..tree import DecisionTreeClassifier, DecisionTreeRegressor
from ..utils import check_random_state, check_X_y, check_array, column_or_1d
from ..utils import check_X_y, check_array, column_or_1d
from ..utils import check_random_state, compute_class_weight
from ..utils.random import sample_without_replacement
from ..utils.validation import has_fit_parameter, check_is_fitted

@@ -31,7 +32,7 @@


def _parallel_build_estimators(n_estimators, ensemble, X, y, sample_weight,
seeds, verbose):
class_weight, seeds, verbose):
"""Private function used to build a batch of estimators within a job."""
# Retrieve settings
n_samples, n_features = X.shape
@@ -51,7 +52,6 @@ def _parallel_build_estimators(n_estimators, ensemble, X, y, sample_weight,
support_sample_weight = has_fit_parameter(ensemble.base_estimator_,
"sample_weight")


# Build estimators
estimators = []
estimators_samples = []
@@ -98,6 +98,29 @@ def _parallel_build_estimators(n_estimators, ensemble, X, y, sample_weight,

                curr_sample_weight[not_indices] = 0

                if class_weight == 'subsample':
                    indices = np.where(curr_sample_weight > 0)

            if class_weight == 'subsample':

                classes_full = np.unique(y)
                y_boot = y[indices]
                classes_boot = np.unique(y_boot)

                # Get class weights for the bootstrap sample, covering all
                # classes in case some were missing from the bootstrap sample
                expanded_class_weight = np.choose(
                    np.searchsorted(classes_boot, classes_full),
                    compute_class_weight('auto', classes_boot, y_boot),
                    mode='clip')

                # Expand weights over the original y for this output
                expanded_class_weight = expanded_class_weight[
                    np.searchsorted(classes_full, y)]

                # Multiply all weights by sample & bootstrap weights
                curr_sample_weight = curr_sample_weight * expanded_class_weight

            estimator.fit(X[:, features], y, sample_weight=curr_sample_weight)
            samples = curr_sample_weight > 0.

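The block above relies on a small trick: class weights are computed only from the classes present in the bootstrap or sub-sample, and ``np.choose(..., mode='clip')`` maps any class missing from the draw onto a neighbouring present class so the later indexing never fails. Below is a numpy-only sketch of that expansion with toy data; the inverse-frequency weights stand in for whatever ``compute_class_weight('auto', ...)`` returns, so the exact values are an assumption.

import numpy as np

y = np.array([0, 0, 1, 1, 1, 2])                          # full training labels
curr_sample_weight = np.array([1., 2., 1., 0., 0., 0.])   # bootstrap draw counts

indices = np.where(curr_sample_weight > 0)[0]
classes_full = np.unique(y)                                # [0, 1, 2]
y_boot = y[indices]                                        # [0, 0, 1]
classes_boot, counts = np.unique(y_boot, return_counts=True)  # class 2 absent

# Inverse-frequency weights for the classes actually seen in the draw.
weights_boot = y_boot.size / (classes_boot.size * counts.astype(float))

# Map every full class onto a bootstrap class; the missing class 2 is
# clipped onto the last present class instead of raising an index error.
expanded_class_weight = np.choose(
    np.searchsorted(classes_boot, classes_full), weights_boot, mode='clip')

# One weight per original sample, folded into the bootstrap counts.
curr_sample_weight = curr_sample_weight * expanded_class_weight[
    np.searchsorted(classes_full, y)]
print(curr_sample_weight)   # [0.75 1.5  1.5  0.   0.   0. ]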
@@ -203,6 +226,7 @@ def __init__(self,
bootstrap=True,
bootstrap_features=False,
oob_score=False,
class_weight=None,
n_jobs=1,
random_state=None,
verbose=0):
@@ -215,6 +239,7 @@ def __init__(self,
self.bootstrap = bootstrap
self.bootstrap_features = bootstrap_features
self.oob_score = oob_score
self.class_weight = class_weight
self.n_jobs = n_jobs
self.random_state = random_state
self.verbose = verbose
@@ -250,7 +275,7 @@ def fit(self, X, y, sample_weight=None):

# Remap output
n_samples, self.n_features_ = X.shape
y = self._validate_y(y)
y, expanded_class_weight = self._validate_y_class_weight(y)

# Check parameters
self._validate_estimator()
@@ -275,6 +300,13 @@ def fit(self, X, y, sample_weight=None):
            raise ValueError("Out of bag estimation only available"
                             " if bootstrap=True")

        # Apply class_weights to sample weights
        if expanded_class_weight is not None:
            if sample_weight is not None:
                sample_weight = sample_weight * expanded_class_weight
            else:
                sample_weight = expanded_class_weight

        # Free allocated memory, if any
        self.estimators_ = None

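A tiny illustration of the combination rule above, with assumed toy arrays: the per-sample class weights multiply into, rather than replace, any ``sample_weight`` the user passes to ``fit``.

import numpy as np

sample_weight = np.array([1.0, 0.5, 1.0])          # user-supplied (may be None)
expanded_class_weight = np.array([2.0, 2.0, 0.5])  # one weight per sample

if expanded_class_weight is not None:
    if sample_weight is not None:
        sample_weight = sample_weight * expanded_class_weight
    else:
        sample_weight = expanded_class_weight
print(sample_weight)   # [2.  1.  0.5]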
@@ -290,6 +322,7 @@ def fit(self, X, y, sample_weight=None):
X,
y,
sample_weight,
self.class_weight,
seeds[starts[i]:starts[i + 1]],
verbose=self.verbose)
for i in range(n_jobs))
@@ -311,9 +344,9 @@ def fit(self, X, y, sample_weight=None):
def _set_oob_score(self, X, y):
"""Calculate out of bag predictions and score."""

def _validate_y(self, y):
def _validate_y_class_weight(self, y):
# Default implementation
return column_or_1d(y, warn=True)
return column_or_1d(y, warn=True), None


class BaggingClassifier(BaseBagging, ClassifierMixin):
@@ -365,6 +398,23 @@ class BaggingClassifier(BaseBagging, ClassifierMixin):
Whether to use out-of-bag samples to estimate
the generalization error.

class_weight : dict, "auto", "subsample" or None, optional
Weights associated with classes in the form ``{class_label: weight}``.
If not given, all classes are supposed to have weight one.

The "auto" mode uses the values of y to automatically adjust
weights inversely proportional to class frequencies in the input data.

The "subsample" mode is the same as "auto" except that weights are
computed based on the bootstrap or sub-sample for every tree grown as
defined by the ``max_features`` and/or ``bootstrap`` options.

Note that these weights will be multiplied with sample_weight (passed
through the fit method) if sample_weight is specified.

Note that this is supported only if the base estimator supports
sample weighting.

n_jobs : int, optional (default=1)
The number of jobs to run in parallel for both `fit` and `predict`.
If -1, then the number of jobs is set to the number of cores.
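To make the ``class_weight`` presets documented a few lines above concrete, here is a hedged numpy sketch of how a dict and an 'auto'-style preset each become one weight per training sample. The inverse-frequency formula below is an assumption chosen to illustrate "inversely proportional to class frequencies"; it is not necessarily the exact normalisation used by ``compute_class_weight``.

import numpy as np

y = np.array([0, 0, 0, 0, 1])                       # imbalanced toy labels
classes, counts = np.unique(y, return_counts=True)

# 'auto'-style preset: rarer classes receive proportionally larger weights.
auto_class_weight = y.size / (classes.size * counts.astype(float))   # [0.625, 2.5]

# dict preset: the user fixes the per-class weights directly.
dict_class_weight = {0: 1.0, 1: 5.0}

# Either way the classifier expands them to per-sample weights and
# multiplies them into any sample_weight passed to fit.
per_sample_auto = auto_class_weight[np.searchsorted(classes, y)]
per_sample_dict = np.array([dict_class_weight[label] for label in y])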
@@ -432,6 +482,7 @@ def __init__(self,
bootstrap=True,
bootstrap_features=False,
oob_score=False,
class_weight=None,
n_jobs=1,
random_state=None,
verbose=0):
@@ -444,6 +495,7 @@ def __init__(self,
bootstrap=bootstrap,
bootstrap_features=bootstrap_features,
oob_score=oob_score,
class_weight=class_weight,
n_jobs=n_jobs,
random_state=random_state,
verbose=verbose)
@@ -492,12 +544,31 @@ def _set_oob_score(self, X, y):
        self.oob_decision_function_ = oob_decision_function
        self.oob_score_ = oob_score

    def _validate_y(self, y):
    def _validate_y_class_weight(self, y):
        y = column_or_1d(y, warn=True)
        expanded_class_weight = None

        if self.class_weight is not None:
            y_original = np.copy(y)

        self.classes_, y = np.unique(y, return_inverse=True)
        self.n_classes_ = len(self.classes_)

        return y
        if self.class_weight is not None:
            valid_presets = ('auto', 'subsample')
            if isinstance(self.class_weight, string_types):
                if self.class_weight not in valid_presets:
                    raise ValueError('Valid presets for class_weight include '
                                     '"auto" and "subsample". Given "%s".'
                                     % self.class_weight)

            if self.class_weight != 'subsample':
                expanded_class_weight = compute_class_weight(
                    self.class_weight, self.classes_, y_original)
                expanded_class_weight = expanded_class_weight[
                    np.searchsorted(self.classes_, y_original)]

        return y, expanded_class_weight

def predict(self, X):
"""Predict class for X.
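Finally, a standalone, hedged re-creation of the validation flow in ``_validate_y_class_weight`` above. It covers only the string-preset path, uses plain numpy in place of ``compute_class_weight``, and the function name and weight formula are assumptions for illustration rather than the PR's exact implementation.

import numpy as np

def validate_y_class_weight(y, class_weight):
    """Encode y and, for presets other than 'subsample', pre-expand weights."""
    y = np.asarray(y).ravel()
    classes, y_encoded = np.unique(y, return_inverse=True)
    expanded_class_weight = None

    if isinstance(class_weight, str) and class_weight not in ('auto', 'subsample'):
        raise ValueError('Valid presets for class_weight include '
                         '"auto" and "subsample". Given "%s".' % class_weight)

    if class_weight == 'auto':
        # Inverse-frequency stand-in for compute_class_weight('auto', ...),
        # expanded to one weight per training sample.
        counts = np.bincount(y_encoded)
        class_weights = y_encoded.size / (classes.size * counts.astype(float))
        expanded_class_weight = class_weights[y_encoded]

    return y_encoded, expanded_class_weight

y_enc, w = validate_y_class_weight(['a', 'a', 'a', 'b'], 'auto')
# y_enc -> [0 0 0 1]; w -> approximately [0.667 0.667 0.667 2.0]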