[MRG+2] Add class_weight support to the forests & trees #3961

Merged
merged 10 commits on Jan 13, 2015
7 changes: 7 additions & 0 deletions doc/whats_new.rst
@@ -153,6 +153,11 @@ Enhancements
- DBSCAN now supports sparse input and sample weights, and should be
faster in general. By `Joel Nothman`_.

- Add ``class_weight`` parameter to automatically weight samples by class
frequency for :class:`ensemble.RandomForestClassifier`,
:class:`tree.DecisionTreeClassifier`, :class:`ensemble.ExtraTreesClassifier`
and :class:`tree.ExtraTreeClassifier`. By `Trevor Stephens`_.

Documentation improvements
..........................

@@ -3183,3 +3188,5 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.
.. _Matteo Visconti di Oleggio Castello: http://www.mvdoc.me

.. _Raghav R V: https://github.com/ragv

.. _Trevor Stephens: http://trevorstephens.com/
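
As a quick illustration of the new parameter described in the changelog entry above, here is a minimal usage sketch. The toy data is invented for illustration, and it assumes the scikit-learn release containing this change, where the string presets are "auto" and "subsample":

```python
from sklearn.ensemble import RandomForestClassifier

# Toy imbalanced problem: class 0 outnumbers class 1 four to one.
X = [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]
y = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]

# "auto" weights samples inversely to class frequency over the full y;
# "subsample" recomputes those weights on each tree's bootstrap sample.
clf = RandomForestClassifier(n_estimators=10, class_weight='auto',
                             random_state=0)
clf.fit(X, y)
```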
157 changes: 140 additions & 17 deletions sklearn/ensemble/forest.py
@@ -41,8 +41,6 @@ class calls the ``fit`` method of each sub-estimator on random samples

from __future__ import division

import numpy as np

from warnings import warn
from abc import ABCMeta, abstractmethod

@@ -58,7 +56,7 @@ class calls the ``fit`` method of each sub-estimator on random samples
from ..tree import (DecisionTreeClassifier, DecisionTreeRegressor,
ExtraTreeClassifier, ExtraTreeRegressor)
from ..tree._tree import DTYPE, DOUBLE
from ..utils import check_random_state, check_array
from ..utils import check_random_state, check_array, compute_class_weight
from ..utils.validation import DataConversionWarning, check_is_fitted
from .base import BaseEnsemble, _partition_estimators

@@ -72,7 +70,7 @@ class calls the ``fit`` method of each sub-estimator on random samples


def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees,
verbose=0):
verbose=0, class_weight=None):
"""Private function used to fit a single tree in parallel."""
if verbose > 1:
print("building tree %d of %d" % (tree_idx + 1, n_trees))
@@ -89,6 +87,32 @@ def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees,
sample_counts = np.bincount(indices, minlength=n_samples)
curr_sample_weight *= sample_counts

if class_weight == 'subsample':

expanded_class_weight = [curr_sample_weight]

for k in range(y.shape[1]):
y_full = y[:, k]
classes_full = np.unique(y_full)
y_boot = y[indices, k]
classes_boot = np.unique(y_boot)

# Get class weights for the bootstrap sample, covering all
# classes in case some were missing from the bootstrap sample
weight_k = np.choose(
np.searchsorted(classes_boot, classes_full),
compute_class_weight('auto', classes_boot, y_boot),
mode='clip')

# Expand weights over the original y for this output
weight_k = weight_k[np.searchsorted(classes_full, y_full)]
expanded_class_weight.append(weight_k)

# Multiply all weights by sample & bootstrap weights
curr_sample_weight = np.prod(expanded_class_weight,
axis=0,
dtype=np.float64)

tree.fit(X, y, sample_weight=curr_sample_weight, check_input=False)

tree.indices_ = sample_counts > 0.
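
A stand-alone sketch of the "subsample" branch above, for a single output column and made-up data. It assumes the API at the time of this PR (the 'auto' preset and positional arguments to compute_class_weight); current releases use 'balanced' and keyword arguments:

```python
import numpy as np
from sklearn.utils import compute_class_weight

y_full = np.array([0, 0, 1, 1, 2])       # original targets for one output
classes_full = np.unique(y_full)          # [0, 1, 2]
indices = np.array([0, 1, 1, 3, 3])       # hypothetical bootstrap draw
y_boot = y_full[indices]                  # [0, 0, 0, 1, 1]; class 2 missing
classes_boot = np.unique(y_boot)          # [0, 1]

# Per-class 'auto' weights computed on the bootstrap sample, then mapped back
# onto classes_full; mode='clip' reuses the last weight for classes that were
# not drawn (class 2 here), as in the diff above.
weight_per_class = np.choose(
    np.searchsorted(classes_boot, classes_full),
    compute_class_weight('auto', classes_boot, y_boot),
    mode='clip')

# Expand to one weight per original sample for this output.
weight_per_sample = weight_per_class[np.searchsorted(classes_full, y_full)]
```

In the forest itself this per-sample vector is then multiplied into curr_sample_weight together with the bootstrap counts before the tree is fitted.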
@@ -122,7 +146,8 @@ def __init__(self,
n_jobs=1,
random_state=None,
verbose=0,
warm_start=False):
warm_start=False,
class_weight=None):
super(BaseForest, self).__init__(
base_estimator=base_estimator,
n_estimators=n_estimators,
@@ -134,6 +159,7 @@ def __init__(self,
self.random_state = random_state
self.verbose = verbose
self.warm_start = warm_start
self.class_weight = class_weight

def apply(self, X):
"""Apply trees in the forest to X, return leaf indices.
@@ -213,11 +239,17 @@ def fit(self, X, y, sample_weight=None):

self.n_outputs_ = y.shape[1]

y = self._validate_y(y)
y, expanded_class_weight = self._validate_y_class_weight(y)

if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous:
y = np.ascontiguousarray(y, dtype=DOUBLE)

if expanded_class_weight is not None:
if sample_weight is not None:
sample_weight = sample_weight * expanded_class_weight
else:
sample_weight = expanded_class_weight

# Check parameters
self._validate_estimator()

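For context, a tiny numeric sketch (invented values) of how the class-derived weights combine with a user-supplied sample_weight in fit:

```python
import numpy as np

sample_weight = np.array([1.0, 1.0, 0.5, 2.0])               # user-provided
expanded_class_weight = np.array([0.625, 0.625, 2.5, 2.5])   # from class_weight

# Element-wise product: rare-class samples keep their boost even when the
# user also weights individual samples.
combined = sample_weight * expanded_class_weight  # [0.625, 0.625, 1.25, 5.0]
```
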
@@ -261,7 +293,7 @@ def fit(self, X, y, sample_weight=None):
backend="threading")(
delayed(_parallel_build_trees)(
t, self, X, y, sample_weight, i, len(trees),
verbose=self.verbose)
verbose=self.verbose, class_weight=self.class_weight)
for i, t in enumerate(trees))

# Collect newly grown trees
@@ -281,9 +313,9 @@
def _set_oob_score(self, X, y):
"""Calculate out of bag predictions and score."""

def _validate_y(self, y):
def _validate_y_class_weight(self, y):
# Default implementation
return y
return y, None

@property
def feature_importances_(self):
@@ -324,7 +356,8 @@ def __init__(self,
n_jobs=1,
random_state=None,
verbose=0,
warm_start=False):
warm_start=False,
class_weight=None):

super(ForestClassifier, self).__init__(
base_estimator,
@@ -335,7 +368,8 @@ def __init__(self,
n_jobs=n_jobs,
random_state=random_state,
verbose=verbose,
warm_start=warm_start)
warm_start=warm_start,
class_weight=class_weight)

def _set_oob_score(self, X, y):
"""Compute out-of-bag score"""
@@ -381,8 +415,12 @@ def _set_oob_score(self, X, y):

self.oob_score_ = oob_score / self.n_outputs_

def _validate_y(self, y):
def _validate_y_class_weight(self, y):
y = np.copy(y)
expanded_class_weight = None

if self.class_weight is not None:
y_original = np.copy(y)

self.classes_ = []
self.n_classes_ = []
@@ -392,7 +430,52 @@ def _validate_y(self, y):
self.classes_.append(classes_k)
self.n_classes_.append(classes_k.shape[0])

return y
if self.class_weight is not None:
valid_presets = ('auto', 'subsample')
if isinstance(self.class_weight, six.string_types):
if self.class_weight not in valid_presets:
raise ValueError('Valid presets for class_weight include '
'"auto" and "subsample". Given "%s".'
% self.class_weight)
if self.warm_start:
warn('class_weight presets "auto" or "subsample" are '
'not recommended for warm_start if the fitted data '
'differs from the full dataset. In order to use '
'"auto" weights, use compute_class_weight("auto", '
'classes, y). In place of y you can use a large '
'enough sample of the full training set target to '
'properly estimate the class frequency '
'distributions. Pass the resulting weights as the '
'class_weight parameter.')
elif self.n_outputs_ > 1:
if not hasattr(self.class_weight, "__iter__"):
raise ValueError("For multi-output, class_weight should "
"be a list of dicts, or a valid string.")
elif len(self.class_weight) != self.n_outputs_:
raise ValueError("For multi-output, number of elements "
"in class_weight should match number of "
"outputs.")

if self.class_weight != 'subsample' or not self.bootstrap:
expanded_class_weight = []
for k in range(self.n_outputs_):
if self.class_weight in valid_presets:
class_weight_k = 'auto'
elif self.n_outputs_ == 1:
class_weight_k = self.class_weight
else:
class_weight_k = self.class_weight[k]
weight_k = compute_class_weight(class_weight_k,
self.classes_[k],
y_original[:, k])
weight_k = weight_k[np.searchsorted(self.classes_[k],
y_original[:, k])]
expanded_class_weight.append(weight_k)
expanded_class_weight = np.prod(expanded_class_weight,
axis=0,
dtype=np.float64)

return y, expanded_class_weight
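
A sketch of the multi-output expansion performed above, with a hypothetical two-output y and explicit per-output dicts (again assuming the positional compute_class_weight signature of this era):

```python
import numpy as np
from sklearn.utils import compute_class_weight

y = np.array([[0, 1],
              [0, 1],
              [1, 0],
              [1, 1]])
class_weight = [{0: 1.0, 1: 2.0},   # weights for output 0
                {0: 3.0, 1: 1.0}]   # weights for output 1

expanded = []
for k in range(y.shape[1]):
    classes_k = np.unique(y[:, k])
    weight_k = compute_class_weight(class_weight[k], classes_k, y[:, k])
    # Map the per-class weights onto every sample for this output.
    expanded.append(weight_k[np.searchsorted(classes_k, y[:, k])])

# The per-sample weight is the product across outputs, as in the code above:
# e.g. sample 2 has class 1 for output 0 and class 0 for output 1,
# giving 2.0 * 3.0 = 6.0.
expanded_class_weight = np.prod(expanded, axis=0, dtype=np.float64)
```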

def predict(self, X):
"""Predict class for X.
@@ -717,6 +800,24 @@ class RandomForestClassifier(ForestClassifier):
and add more estimators to the ensemble, otherwise, just fit a whole
new forest.

class_weight : dict, list of dicts, "auto", "subsample" or None, optional

Weights associated with classes in the form ``{class_label: weight}``.
If not given, all classes are supposed to have weight one. For
multi-output problems, a list of dicts can be provided in the same
order as the columns of y.

The "auto" mode uses the values of y to automatically adjust
weights inversely proportional to class frequencies in the input data.

The "subsample" mode is the same as "auto" except that weights are
computed based on the bootstrap sample for every tree grown.

For multi-output, the weights of each column of y will be multiplied.

Note that these weights will be multiplied with sample_weight (passed
through the fit method) if sample_weight is specified.

Attributes
----------
estimators_ : list of DecisionTreeClassifier
@@ -765,7 +866,8 @@ def __init__(self,
n_jobs=1,
random_state=None,
verbose=0,
warm_start=False):
warm_start=False,
class_weight=None):
super(RandomForestClassifier, self).__init__(
base_estimator=DecisionTreeClassifier(),
n_estimators=n_estimators,
@@ -778,7 +880,8 @@ def __init__(self,
n_jobs=n_jobs,
random_state=random_state,
verbose=verbose,
warm_start=warm_start)
warm_start=warm_start,
class_weight=class_weight)

self.criterion = criterion
self.max_depth = max_depth
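
A usage sketch of the new constructor argument documented above, with toy data invented for illustration:

```python
from sklearn.ensemble import RandomForestClassifier

X = [[0], [1], [2], [3], [4], [5]]
y = [0, 0, 0, 0, 1, 1]

# Explicit per-class weights...
clf_dict = RandomForestClassifier(n_estimators=5, class_weight={0: 1, 1: 5},
                                  random_state=0).fit(X, y)

# ...or the "subsample" preset, which recomputes 'auto'-style weights on each
# tree's bootstrap sample (it falls back to weights computed on the full y
# when bootstrap=False).
clf_sub = RandomForestClassifier(n_estimators=5, class_weight='subsample',
                                 random_state=0).fit(X, y)
```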
@@ -1027,6 +1130,24 @@ class ExtraTreesClassifier(ForestClassifier):
and add more estimators to the ensemble, otherwise, just fit a whole
new forest.

class_weight : dict, list of dicts, "auto", "subsample" or None, optional

Weights associated with classes in the form ``{class_label: weight}``.
If not given, all classes are supposed to have weight one. For
multi-output problems, a list of dicts can be provided in the same
order as the columns of y.

The "auto" mode uses the values of y to automatically adjust
weights inversely proportional to class frequencies in the input data.

The "subsample" mode is the same as "auto" except that weights are
computed based on the bootstrap sample for every tree grown.

For multi-output, the weights of each column of y will be multiplied.

Note that these weights will be multiplied with sample_weight (passed
through the fit method) if sample_weight is specified.

Attributes
----------
estimators_ : list of DecisionTreeClassifier
@@ -1078,7 +1199,8 @@ def __init__(self,
n_jobs=1,
random_state=None,
verbose=0,
warm_start=False):
warm_start=False,
class_weight=None):
super(ExtraTreesClassifier, self).__init__(
base_estimator=ExtraTreeClassifier(),
n_estimators=n_estimators,
@@ -1090,7 +1212,8 @@ def __init__(self,
n_jobs=n_jobs,
random_state=random_state,
verbose=verbose,
warm_start=warm_start)
warm_start=warm_start,
class_weight=class_weight)

self.criterion = criterion
self.max_depth = max_depth
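
And a multi-output sketch with ExtraTreesClassifier, passing one dict per column of y as described in the docstring above (hypothetical data):

```python
import numpy as np
from sklearn.ensemble import ExtraTreesClassifier

X = np.array([[0, 1], [1, 1], [2, 0], [3, 0], [4, 1], [5, 0]])
y = np.array([[0, 1],
              [0, 1],
              [0, 0],
              [1, 0],
              [1, 1],
              [1, 0]])

clf = ExtraTreesClassifier(n_estimators=5,
                           class_weight=[{0: 1, 1: 3},   # output 0
                                         {0: 2, 1: 1}],  # output 1
                           random_state=0)
clf.fit(X, y)
```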