diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index f1ba8c6d3a439..251870907f251 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -153,6 +153,11 @@ Enhancements
   - DBSCAN now supports sparse input and sample weights, and should be
     faster in general. By `Joel Nothman`_.
 
+   - Add ``class_weight`` parameter to automatically weight samples by class
+     frequency for :class:`ensemble.RandomForestClassifier`,
+     :class:`tree.DecisionTreeClassifier`, :class:`ensemble.ExtraTreesClassifier`
+     and :class:`tree.ExtraTreeClassifier`. By `Trevor Stephens`_.
+
 Documentation improvements
 ..........................
 
@@ -3183,3 +3188,5 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.
 .. _Matteo Visconti di Oleggio Castello: http://www.mvdoc.me
 
 .. _Raghav R V: https://github.com/ragv
+
+.. _Trevor Stephens: http://trevorstephens.com/
diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py
index 01a9bbf056f01..076df001b9899 100644
--- a/sklearn/ensemble/forest.py
+++ b/sklearn/ensemble/forest.py
@@ -41,8 +41,6 @@ class calls the ``fit`` method of each sub-estimator on random samples
 
 from __future__ import division
 
-import numpy as np
-
 from warnings import warn
 from abc import ABCMeta, abstractmethod
 
@@ -58,7 +56,7 @@ class calls the ``fit`` method of each sub-estimator on random samples
 from ..tree import (DecisionTreeClassifier, DecisionTreeRegressor,
                     ExtraTreeClassifier, ExtraTreeRegressor)
 from ..tree._tree import DTYPE, DOUBLE
-from ..utils import check_random_state, check_array
+from ..utils import check_random_state, check_array, compute_class_weight
 from ..utils.validation import DataConversionWarning, check_is_fitted
 from .base import BaseEnsemble, _partition_estimators
 
@@ -72,7 +70,7 @@ class calls the ``fit`` method of each sub-estimator on random samples
 
 
 def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees,
-                          verbose=0):
+                          verbose=0, class_weight=None):
     """Private function used to fit a single tree in parallel."""
     if verbose > 1:
         print("building tree %d of %d" % (tree_idx + 1, n_trees))
@@ -89,6 +87,32 @@ def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees,
         sample_counts = np.bincount(indices, minlength=n_samples)
         curr_sample_weight *= sample_counts
 
+        if class_weight == 'subsample':
+
+            expanded_class_weight = [curr_sample_weight]
+
+            for k in range(y.shape[1]):
+                y_full = y[:, k]
+                classes_full = np.unique(y_full)
+                y_boot = y[indices, k]
+                classes_boot = np.unique(y_boot)
+
+                # Get class weights for the bootstrap sample, covering all
+                # classes in case some were missing from the bootstrap sample
+                weight_k = np.choose(
+                    np.searchsorted(classes_boot, classes_full),
+                    compute_class_weight('auto', classes_boot, y_boot),
+                    mode='clip')
+
+                # Expand weights over the original y for this output
+                weight_k = weight_k[np.searchsorted(classes_full, y_full)]
+                expanded_class_weight.append(weight_k)
+
+            # Multiply all weights by sample & bootstrap weights
+            curr_sample_weight = np.prod(expanded_class_weight,
+                                         axis=0,
+                                         dtype=np.float64)
+
         tree.fit(X, y, sample_weight=curr_sample_weight, check_input=False)
 
         tree.indices_ = sample_counts > 0.
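For orientation, the 'subsample' branch added in the hunk above computes 'auto' class weights on each tree's bootstrap draw and expands them back over the full y before multiplying them into the per-sample weights. The snippet below is a minimal standalone sketch of that idea, not part of the patch: it assumes the same-era compute_class_weight helper, imported from sklearn.utils exactly as the patch does (later releases expose it from sklearn.utils.class_weight and rename the 'auto' preset to 'balanced'), and all array names are illustrative.

import numpy as np
from sklearn.utils import compute_class_weight

# Toy single-output target and one bootstrap draw of the same length.
y = np.array([0, 0, 0, 1, 1, 2])
rng = np.random.RandomState(0)
indices = rng.randint(0, len(y), len(y))

y_boot = y[indices]
classes_boot = np.unique(y_boot)
classes_full = np.unique(y)

# 'auto' weights computed on the bootstrap sample only.
weight_boot = compute_class_weight('auto', classes_boot, y_boot)

# Map every class of the full target onto a bootstrap-sample weight;
# mode='clip' keeps the lookup in bounds when a class is missing from
# the bootstrap draw, mirroring the np.choose/searchsorted step above.
weight_full = np.choose(np.searchsorted(classes_boot, classes_full),
                        weight_boot, mode='clip')

# Expand to one weight per sample, which the patch then multiplies
# into curr_sample_weight.
sample_weight = weight_full[np.searchsorted(classes_full, y)]
print(sample_weight)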
@@ -122,7 +146,8 @@ def __init__(self,
                  n_jobs=1,
                  random_state=None,
                  verbose=0,
-                 warm_start=False):
+                 warm_start=False,
+                 class_weight=None):
         super(BaseForest, self).__init__(
             base_estimator=base_estimator,
             n_estimators=n_estimators,
@@ -134,6 +159,7 @@ def __init__(self,
         self.random_state = random_state
         self.verbose = verbose
         self.warm_start = warm_start
+        self.class_weight = class_weight
 
     def apply(self, X):
         """Apply trees in the forest to X, return leaf indices.
@@ -213,11 +239,17 @@ def fit(self, X, y, sample_weight=None):
 
         self.n_outputs_ = y.shape[1]
 
-        y = self._validate_y(y)
+        y, expanded_class_weight = self._validate_y_class_weight(y)
 
         if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous:
             y = np.ascontiguousarray(y, dtype=DOUBLE)
 
+        if expanded_class_weight is not None:
+            if sample_weight is not None:
+                sample_weight = sample_weight * expanded_class_weight
+            else:
+                sample_weight = expanded_class_weight
+
         # Check parameters
         self._validate_estimator()
 
@@ -261,7 +293,7 @@ def fit(self, X, y, sample_weight=None):
                              backend="threading")(
                 delayed(_parallel_build_trees)(
                     t, self, X, y, sample_weight, i, len(trees),
-                    verbose=self.verbose)
+                    verbose=self.verbose, class_weight=self.class_weight)
                 for i, t in enumerate(trees))
 
             # Collect newly grown trees
@@ -281,9 +313,9 @@ def fit(self, X, y, sample_weight=None):
     def _set_oob_score(self, X, y):
         """Calculate out of bag predictions and score."""
 
-    def _validate_y(self, y):
+    def _validate_y_class_weight(self, y):
         # Default implementation
-        return y
+        return y, None
 
     @property
     def feature_importances_(self):
@@ -324,7 +356,8 @@ def __init__(self,
                  n_jobs=1,
                  random_state=None,
                  verbose=0,
-                 warm_start=False):
+                 warm_start=False,
+                 class_weight=None):
 
         super(ForestClassifier, self).__init__(
             base_estimator,
@@ -335,7 +368,8 @@ def __init__(self,
             n_jobs=n_jobs,
             random_state=random_state,
             verbose=verbose,
-            warm_start=warm_start)
+            warm_start=warm_start,
+            class_weight=class_weight)
 
     def _set_oob_score(self, X, y):
         """Compute out-of-bag score"""
@@ -381,8 +415,12 @@ def _set_oob_score(self, X, y):
 
         self.oob_score_ = oob_score / self.n_outputs_
 
-    def _validate_y(self, y):
+    def _validate_y_class_weight(self, y):
         y = np.copy(y)
+        expanded_class_weight = None
+
+        if self.class_weight is not None:
+            y_original = np.copy(y)
 
         self.classes_ = []
         self.n_classes_ = []
@@ -392,7 +430,52 @@ def _validate_y(self, y):
             self.classes_.append(classes_k)
             self.n_classes_.append(classes_k.shape[0])
 
-        return y
+        if self.class_weight is not None:
+            valid_presets = ('auto', 'subsample')
+            if isinstance(self.class_weight, six.string_types):
+                if self.class_weight not in valid_presets:
+                    raise ValueError('Valid presets for class_weight include '
+                                     '"auto" and "subsample". Given "%s".'
+                                     % self.class_weight)
+                if self.warm_start:
+                    warn('class_weight presets "auto" or "subsample" are '
+                         'not recommended for warm_start if the fitted data '
+                         'differs from the full dataset. In order to use '
+                         '"auto" weights, use compute_class_weight("auto", '
+                         'classes, y). In place of y you can use a large '
+                         'enough sample of the full training set target to '
+                         'properly estimate the class frequency '
+                         'distributions. Pass the resulting weights as the '
+                         'class_weight parameter.')
+            elif self.n_outputs_ > 1:
+                if not hasattr(self.class_weight, "__iter__"):
+                    raise ValueError("For multi-output, class_weight should "
+                                     "be a list of dicts, or a valid string.")
+                elif len(self.class_weight) != self.n_outputs_:
+                    raise ValueError("For multi-output, number of elements "
+                                     "in class_weight should match number of "
+                                     "outputs.")
+
+            if self.class_weight != 'subsample' or not self.bootstrap:
+                expanded_class_weight = []
+                for k in range(self.n_outputs_):
+                    if self.class_weight in valid_presets:
+                        class_weight_k = 'auto'
+                    elif self.n_outputs_ == 1:
+                        class_weight_k = self.class_weight
+                    else:
+                        class_weight_k = self.class_weight[k]
+                    weight_k = compute_class_weight(class_weight_k,
+                                                    self.classes_[k],
+                                                    y_original[:, k])
+                    weight_k = weight_k[np.searchsorted(self.classes_[k],
+                                                        y_original[:, k])]
+                    expanded_class_weight.append(weight_k)
+                expanded_class_weight = np.prod(expanded_class_weight,
+                                                axis=0,
+                                                dtype=np.float64)
+
+        return y, expanded_class_weight
 
     def predict(self, X):
         """Predict class for X.
@@ -717,6 +800,24 @@ class RandomForestClassifier(ForestClassifier):
         and add more estimators to the ensemble, otherwise, just fit a whole
         new forest.
 
+    class_weight : dict, list of dicts, "auto", "subsample" or None, optional
+
+        Weights associated with classes in the form ``{class_label: weight}``.
+        If not given, all classes are supposed to have weight one. For
+        multi-output problems, a list of dicts can be provided in the same
+        order as the columns of y.
+
+        The "auto" mode uses the values of y to automatically adjust
+        weights inversely proportional to class frequencies in the input data.
+
+        The "subsample" mode is the same as "auto" except that weights are
+        computed based on the bootstrap sample for every tree grown.
+
+        For multi-output, the weights of each column of y will be multiplied.
+
+        Note that these weights will be multiplied with sample_weight (passed
+        through the fit method) if sample_weight is specified.
+
     Attributes
     ----------
     estimators_ : list of DecisionTreeClassifier
@@ -765,7 +866,8 @@ def __init__(self,
                  n_jobs=1,
                  random_state=None,
                  verbose=0,
-                 warm_start=False):
+                 warm_start=False,
+                 class_weight=None):
         super(RandomForestClassifier, self).__init__(
             base_estimator=DecisionTreeClassifier(),
             n_estimators=n_estimators,
@@ -778,7 +880,8 @@ def __init__(self,
             n_jobs=n_jobs,
             random_state=random_state,
             verbose=verbose,
-            warm_start=warm_start)
+            warm_start=warm_start,
+            class_weight=class_weight)
 
         self.criterion = criterion
         self.max_depth = max_depth
@@ -1027,6 +1130,24 @@ class ExtraTreesClassifier(ForestClassifier):
         and add more estimators to the ensemble, otherwise, just fit a whole
         new forest.
 
+    class_weight : dict, list of dicts, "auto", "subsample" or None, optional
+
+        Weights associated with classes in the form ``{class_label: weight}``.
+        If not given, all classes are supposed to have weight one. For
+        multi-output problems, a list of dicts can be provided in the same
+        order as the columns of y.
+
+        The "auto" mode uses the values of y to automatically adjust
+        weights inversely proportional to class frequencies in the input data.
+
+        The "subsample" mode is the same as "auto" except that weights are
+        computed based on the bootstrap sample for every tree grown.
+
+        For multi-output, the weights of each column of y will be multiplied.
+
+        Note that these weights will be multiplied with sample_weight (passed
+        through the fit method) if sample_weight is specified.
+
     Attributes
     ----------
     estimators_ : list of DecisionTreeClassifier
@@ -1078,7 +1199,8 @@ def __init__(self,
                  n_jobs=1,
                  random_state=None,
                  verbose=0,
-                 warm_start=False):
+                 warm_start=False,
+                 class_weight=None):
         super(ExtraTreesClassifier, self).__init__(
             base_estimator=ExtraTreeClassifier(),
             n_estimators=n_estimators,
@@ -1090,7 +1212,8 @@ def __init__(self,
             n_jobs=n_jobs,
             random_state=random_state,
             verbose=verbose,
-            warm_start=warm_start)
+            warm_start=warm_start,
+            class_weight=class_weight)
 
         self.criterion = criterion
         self.max_depth = max_depth
diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py
index b37d760f7eaf4..cd0697af20500 100644
--- a/sklearn/ensemble/tests/test_forest.py
+++ b/sklearn/ensemble/tests/test_forest.py
@@ -235,7 +235,6 @@ def test_unfitted_feature_importances():
         yield check_unfitted_feature_importances, name
 
 
-
 def check_oob_score(name, X, y, n_estimators=20):
     """Check that oob prediction is a good estimation
     of the generalization error."""
@@ -712,7 +711,6 @@ def check_memory_layout(name, dtype):
     y = iris.target
     assert_array_equal(est.fit(X, y).predict(X), y)
 
-
     # Strided
     X = np.asarray(iris.data[::3], dtype=dtype)
     y = iris.target[::3]
@@ -747,6 +745,102 @@ def test_1d_input():
         yield check_1d_input, name, X, X_2d, y
 
 
+def check_class_weights(name):
+    """Check class_weights resemble sample_weights behavior."""
+    ForestClassifier = FOREST_CLASSIFIERS[name]
+
+    # Iris is balanced, so no effect expected for using 'auto' weights
+    clf1 = ForestClassifier(random_state=0)
+    clf1.fit(iris.data, iris.target)
+    clf2 = ForestClassifier(class_weight='auto', random_state=0)
+    clf2.fit(iris.data, iris.target)
+    assert_almost_equal(clf1.feature_importances_, clf2.feature_importances_)
+
+    # Make a multi-output problem with three copies of Iris
+    iris_multi = np.vstack((iris.target, iris.target, iris.target)).T
+    # Create user-defined weights that should balance over the outputs
+    clf3 = ForestClassifier(class_weight=[{0: 2., 1: 2., 2: 1.},
+                                          {0: 2., 1: 1., 2: 2.},
+                                          {0: 1., 1: 2., 2: 2.}],
+                            random_state=0)
+    clf3.fit(iris.data, iris_multi)
+    assert_almost_equal(clf2.feature_importances_, clf3.feature_importances_)
+    # Check against multi-output "auto" which should also have no effect
+    clf4 = ForestClassifier(class_weight='auto', random_state=0)
+    clf4.fit(iris.data, iris_multi)
+    assert_almost_equal(clf3.feature_importances_, clf4.feature_importances_)
+
+    # Inflate importance of class 1, check against user-defined weights
+    sample_weight = np.ones(iris.target.shape)
+    sample_weight[iris.target == 1] *= 100
+    class_weight = {0: 1., 1: 100., 2: 1.}
+    clf1 = ForestClassifier(random_state=0)
+    clf1.fit(iris.data, iris.target, sample_weight)
+    clf2 = ForestClassifier(class_weight=class_weight, random_state=0)
+    clf2.fit(iris.data, iris.target)
+    assert_almost_equal(clf1.feature_importances_, clf2.feature_importances_)
+
+    # Check that sample_weight and class_weight are multiplicative
+    clf1 = ForestClassifier(random_state=0)
+    clf1.fit(iris.data, iris.target, sample_weight**2)
+    clf2 = ForestClassifier(class_weight=class_weight, random_state=0)
+    clf2.fit(iris.data, iris.target, sample_weight)
+    assert_almost_equal(clf1.feature_importances_, clf2.feature_importances_)
+
+
+def test_class_weights():
+    for name in FOREST_CLASSIFIERS:
+        yield check_class_weights, name
+
+
+def check_class_weight_auto_and_bootstrap_multi_output(name):
+    """Test class_weight works for multi-output"""
+    ForestClassifier = FOREST_CLASSIFIERS[name]
+    _y = np.vstack((y, np.array(y) * 2)).T
+    clf = ForestClassifier(class_weight='auto', random_state=0)
+    clf.fit(X, _y)
+    clf = ForestClassifier(class_weight=[{-1: 0.5, 1: 1.}, {-2: 1., 2: 1.}],
+                           random_state=0)
+    clf.fit(X, _y)
+    clf = ForestClassifier(class_weight='subsample', random_state=0)
+    clf.fit(X, _y)
+
+
+def test_class_weight_auto_and_bootstrap_multi_output():
+    for name in FOREST_CLASSIFIERS:
+        yield check_class_weight_auto_and_bootstrap_multi_output, name
+
+
+def check_class_weight_errors(name):
+    """Test if class_weight raises errors and warnings when expected."""
+    ForestClassifier = FOREST_CLASSIFIERS[name]
+    _y = np.vstack((y, np.array(y) * 2)).T
+
+    # Invalid preset string
+    clf = ForestClassifier(class_weight='the larch', random_state=0)
+    assert_raises(ValueError, clf.fit, X, y)
+    assert_raises(ValueError, clf.fit, X, _y)
+
+    # Warning warm_start with preset
+    clf = ForestClassifier(class_weight='auto', warm_start=True,
+                           random_state=0)
+    assert_warns(UserWarning, clf.fit, X, y)
+    assert_warns(UserWarning, clf.fit, X, _y)
+
+    # Not a list or preset for multi-output
+    clf = ForestClassifier(class_weight=1, random_state=0)
+    assert_raises(ValueError, clf.fit, X, _y)
+
+    # Incorrect length list for multi-output
+    clf = ForestClassifier(class_weight=[{-1: 0.5, 1: 1.}], random_state=0)
+    assert_raises(ValueError, clf.fit, X, _y)
+
+
+def test_class_weight_errors():
+    for name in FOREST_CLASSIFIERS:
+        yield check_class_weight_errors, name
+
+
 def check_warm_start(name, random_state=42):
     """Test if fitting incrementally with warm start gives a forest of the
     right size and the same results as a normal fit."""
diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py
index 96c02a2088b25..c8f68b21bbe58 100644
--- a/sklearn/tree/tests/test_tree.py
+++ b/sklearn/tree/tests/test_tree.py
@@ -848,6 +848,78 @@ def test_sample_weight_invalid():
     assert_raises(ValueError, clf.fit, X, y, sample_weight=sample_weight)
 
 
+def check_class_weights(name):
+    """Check class_weights resemble sample_weights behavior."""
+    TreeClassifier = CLF_TREES[name]
+
+    # Iris is balanced, so no effect expected for using 'auto' weights
+    clf1 = TreeClassifier(random_state=0)
+    clf1.fit(iris.data, iris.target)
+    clf2 = TreeClassifier(class_weight='auto', random_state=0)
+    clf2.fit(iris.data, iris.target)
+    assert_almost_equal(clf1.feature_importances_, clf2.feature_importances_)
+
+    # Make a multi-output problem with three copies of Iris
+    iris_multi = np.vstack((iris.target, iris.target, iris.target)).T
+    # Create user-defined weights that should balance over the outputs
+    clf3 = TreeClassifier(class_weight=[{0: 2., 1: 2., 2: 1.},
+                                        {0: 2., 1: 1., 2: 2.},
+                                        {0: 1., 1: 2., 2: 2.}],
+                          random_state=0)
+    clf3.fit(iris.data, iris_multi)
+    assert_almost_equal(clf2.feature_importances_, clf3.feature_importances_)
+    # Check against multi-output "auto" which should also have no effect
+    clf4 = TreeClassifier(class_weight='auto', random_state=0)
+    clf4.fit(iris.data, iris_multi)
+    assert_almost_equal(clf3.feature_importances_, clf4.feature_importances_)
+
+    # Inflate importance of class 1, check against user-defined weights
+    sample_weight = np.ones(iris.target.shape)
+    sample_weight[iris.target == 1] *= 100
+    class_weight = {0: 1., 1: 100., 2: 1.}
+    clf1 = TreeClassifier(random_state=0)
+    clf1.fit(iris.data, iris.target, sample_weight)
+    clf2 = TreeClassifier(class_weight=class_weight, random_state=0)
+    clf2.fit(iris.data, iris.target)
+    assert_almost_equal(clf1.feature_importances_, clf2.feature_importances_)
+
+    # Check that sample_weight and class_weight are multiplicative
+    clf1 = TreeClassifier(random_state=0)
+    clf1.fit(iris.data, iris.target, sample_weight**2)
+    clf2 = TreeClassifier(class_weight=class_weight, random_state=0)
+    clf2.fit(iris.data, iris.target, sample_weight)
+    assert_almost_equal(clf1.feature_importances_, clf2.feature_importances_)
+
+
+def test_class_weights():
+    for name in CLF_TREES:
+        yield check_class_weights, name
+
+
+def check_class_weight_errors(name):
+    """Test if class_weight raises errors and warnings when expected."""
+    TreeClassifier = CLF_TREES[name]
+    _y = np.vstack((y, np.array(y) * 2)).T
+
+    # Invalid preset string
+    clf = TreeClassifier(class_weight='the larch', random_state=0)
+    assert_raises(ValueError, clf.fit, X, y)
+    assert_raises(ValueError, clf.fit, X, _y)
+
+    # Not a list or preset for multi-output
+    clf = TreeClassifier(class_weight=1, random_state=0)
+    assert_raises(ValueError, clf.fit, X, _y)
+
+    # Incorrect length list for multi-output
+    clf = TreeClassifier(class_weight=[{-1: 0.5, 1: 1.}], random_state=0)
+    assert_raises(ValueError, clf.fit, X, _y)
+
+
+def test_class_weight_errors():
+    for name in CLF_TREES:
+        yield check_class_weight_errors, name
+
+
 def test_max_leaf_nodes():
     """Test greedy trees with max_depth + 1 leafs. """
     from sklearn.tree._tree import TREE_LEAF
@@ -988,7 +1060,7 @@ def check_sparse_input(tree, dataset, max_depth=None):
 
     if tree in CLF_TREES:
         assert_array_almost_equal(s.predict_proba(X_sparse_test),
-                                      y_proba)
+                                  y_proba)
         assert_array_almost_equal(s.predict_log_proba(X_sparse_test),
                                   y_log_proba)
 
@@ -1078,6 +1150,7 @@ def check_sparse_criterion(tree, dataset):
                     "trees".format(tree))
     assert_array_almost_equal(s.predict(X), d.predict(X))
 
+
 def test_sparse_criterion():
     for tree, dataset in product(SPARSE_TREES,
                                  ["sparse-pos", "sparse-neg", "sparse-mix",
@@ -1104,7 +1177,7 @@ def check_explicit_sparse_zeros(tree, max_depth=3,
         n_nonzero_i = random_state.binomial(n_samples, 0.5)
         indices_i = random_state.permutation(samples)[:n_nonzero_i]
         indices.append(indices_i)
-        data_i = random_state.binomial(3, 0.5, size=(n_nonzero_i, )) - 1
+        data_i = random_state.binomial(3, 0.5, size=(n_nonzero_i, )) - 1
         data.append(data_i)
         offset += n_nonzero_i
         indptr.append(offset)
diff --git a/sklearn/tree/tree.py b/sklearn/tree/tree.py
index 2e146ca9cd430..6dd0cf2999bcc 100644
--- a/sklearn/tree/tree.py
+++ b/sklearn/tree/tree.py
@@ -25,7 +25,7 @@
 from ..base import BaseEstimator, ClassifierMixin, RegressorMixin
 from ..externals import six
 from ..feature_selection.from_model import _LearntSelectorMixin
-from ..utils import check_array, check_random_state
+from ..utils import check_array, check_random_state, compute_class_weight
 from ..utils.validation import NotFittedError, check_is_fitted
 
 
@@ -81,7 +81,8 @@ def __init__(self,
                  min_weight_fraction_leaf,
                  max_features,
                  max_leaf_nodes,
-                 random_state):
+                 random_state,
+                 class_weight=None):
         self.criterion = criterion
         self.splitter = splitter
         self.max_depth = max_depth
@@ -91,6 +92,7 @@ def __init__(self,
         self.max_features = max_features
         self.random_state = random_state
         self.max_leaf_nodes = max_leaf_nodes
+        self.class_weight = class_weight
 
         self.n_features_ = None
         self.n_outputs_ = None
@@ -146,6 +148,7 @@ def fit(self, X, y, sample_weight=None, check_input=True):
         is_classification = isinstance(self, ClassifierMixin)
 
         y = np.atleast_1d(y)
+        expanded_class_weight = None
 
         if y.ndim == 1:
             # reshape is necessary to preserve the data contiguity against vs
@@ -160,11 +163,45 @@ def fit(self, X, y, sample_weight=None, check_input=True):
             self.classes_ = []
             self.n_classes_ = []
 
+            if self.class_weight is not None:
+                y_original = np.copy(y)
+
             for k in range(self.n_outputs_):
                 classes_k, y[:, k] = np.unique(y[:, k], return_inverse=True)
                 self.classes_.append(classes_k)
                 self.n_classes_.append(classes_k.shape[0])
 
+            if self.class_weight is not None:
+                if isinstance(self.class_weight, six.string_types):
+                    if self.class_weight != "auto":
+                        raise ValueError('The only supported preset for '
+                                         'class_weight is "auto". Given "%s".'
+                                         % self.class_weight)
+                elif self.n_outputs_ > 1:
+                    if not hasattr(self.class_weight, "__iter__"):
+                        raise ValueError('For multi-output, class_weight '
+                                         'should be a list of dicts, or '
+                                         '"auto".')
+                    elif len(self.class_weight) != self.n_outputs_:
+                        raise ValueError("For multi-output, number of "
+                                         "elements in class_weight should "
+                                         "match number of outputs.")
+                expanded_class_weight = []
+                for k in range(self.n_outputs_):
+                    if self.n_outputs_ == 1 or self.class_weight == 'auto':
+                        class_weight_k = self.class_weight
+                    else:
+                        class_weight_k = self.class_weight[k]
+                    weight_k = compute_class_weight(class_weight_k,
+                                                    self.classes_[k],
+                                                    y_original[:, k])
+                    weight_k = weight_k[np.searchsorted(self.classes_[k],
+                                                        y_original[:, k])]
+                    expanded_class_weight.append(weight_k)
+                expanded_class_weight = np.prod(expanded_class_weight,
+                                                axis=0,
+                                                dtype=np.float64)
+
         else:
             self.classes_ = [None] * self.n_outputs_
             self.n_classes_ = [1] * self.n_outputs_
@@ -240,6 +277,12 @@ def fit(self, X, y, sample_weight=None, check_input=True):
                                  "number of samples=%d" %
                                  (len(sample_weight), n_samples))
 
+        if expanded_class_weight is not None:
+            if sample_weight is not None:
+                sample_weight = sample_weight * expanded_class_weight
+            else:
+                sample_weight = expanded_class_weight
+
         # Set min_weight_leaf from min_weight_fraction_leaf
         if self.min_weight_fraction_leaf != 0. and sample_weight is not None:
             min_weight_leaf = (self.min_weight_fraction_leaf *
@@ -428,6 +471,21 @@ class DecisionTreeClassifier(BaseDecisionTree, ClassifierMixin):
         If None then unlimited number of leaf nodes.
         If not None then ``max_depth`` will be ignored.
 
+    class_weight : dict, list of dicts, "auto" or None, optional
+                   (default=None)
+        Weights associated with classes in the form ``{class_label: weight}``.
+        If not given, all classes are supposed to have weight one. For
+        multi-output problems, a list of dicts can be provided in the same
+        order as the columns of y.
+
+        The "auto" mode uses the values of y to automatically adjust
+        weights inversely proportional to class frequencies in the input data.
+
+        For multi-output, the weights of each column of y will be multiplied.
+
+        Note that these weights will be multiplied with sample_weight (passed
+        through the fit method) if sample_weight is specified.
+
     random_state : int, RandomState instance or None, optional (default=None)
         If int, random_state is the seed used by the random number generator;
         If RandomState instance, random_state is the random number generator;
@@ -497,7 +555,8 @@ def __init__(self,
                  min_weight_fraction_leaf=0.,
                  max_features=None,
                  random_state=None,
-                 max_leaf_nodes=None):
+                 max_leaf_nodes=None,
+                 class_weight=None):
         super(DecisionTreeClassifier, self).__init__(
             criterion=criterion,
             splitter=splitter,
@@ -507,6 +566,7 @@ def __init__(self,
             min_weight_fraction_leaf=min_weight_fraction_leaf,
             max_features=max_features,
             max_leaf_nodes=max_leaf_nodes,
+            class_weight=class_weight,
             random_state=random_state)
 
     def predict_proba(self, X):
@@ -751,7 +811,8 @@ def __init__(self,
                  min_weight_fraction_leaf=0.,
                  max_features="auto",
                  random_state=None,
-                 max_leaf_nodes=None):
+                 max_leaf_nodes=None,
+                 class_weight=None):
         super(ExtraTreeClassifier, self).__init__(
             criterion=criterion,
             splitter=splitter,
@@ -761,6 +822,7 @@ def __init__(self,
             min_weight_fraction_leaf=min_weight_fraction_leaf,
             max_features=max_features,
             max_leaf_nodes=max_leaf_nodes,
+            class_weight=class_weight,
            random_state=random_state)
 
 
diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index 1c1d1970756df..22ec1168b0dcc 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -773,6 +773,8 @@ def check_class_weight_classifiers(name, Classifier):
         classifier = Classifier(class_weight=class_weight)
         if hasattr(classifier, "n_iter"):
             classifier.set_params(n_iter=100)
+        if hasattr(classifier, "min_weight_fraction_leaf"):
+            classifier.set_params(min_weight_fraction_leaf=0.01)
 
         set_random_state(classifier)
         classifier.fit(X_train, y_train)
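Taken together, the hunks above add one user-facing knob. The sketch below shows how it is meant to be called, assuming a scikit-learn build that contains this patch; the 'auto' and 'subsample' strings are the presets documented in the docstrings above (newer scikit-learn releases renamed 'auto' to 'balanced' and 'subsample' to 'balanced_subsample', so adjust the strings there). The dataset and weight values are illustrative only.

import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier

# Imbalanced two-class problem: roughly 90% of samples in class 0.
X, y = make_classification(n_samples=1000, weights=[0.9, 0.1], random_state=0)

# Explicit per-class weights, given as a dict {class_label: weight}.
tree = DecisionTreeClassifier(class_weight={0: 1., 1: 9.}, random_state=0)
tree.fit(X, y)

# 'auto' re-weights samples inversely to class frequencies in y.
forest = RandomForestClassifier(class_weight='auto', random_state=0)
forest.fit(X, y)

# 'subsample' (forests only) recomputes the 'auto' weights on each
# tree's bootstrap sample instead of on the full training set.
forest = RandomForestClassifier(class_weight='subsample', random_state=0)
forest.fit(X, y)

# Multi-output y takes one dict per output column; the per-output
# weights are multiplied together, and any sample_weight passed to
# fit() is multiplied on top of that.
y_multi = np.vstack((y, 1 - y)).T
forest = RandomForestClassifier(class_weight=[{0: 1., 1: 9.}, {0: 9., 1: 1.}],
                                random_state=0)
forest.fit(X, y_multi, sample_weight=np.ones(len(y)))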