[MRG+1] Sparse One vs. Rest #3276


Closed — wants to merge 54 commits into from

Changes from all commits (54 commits)
a4fe5b0
Modified sparse OvR to handle sparse target data
hamsal Jun 13, 2014
8dbae82
Progress comment
hamsal Jun 13, 2014
4662b95
Introduce multiclass behvior into predict_ovr, minimal testing in tes…
hamsal Jun 13, 2014
0771ba0
Initialized sparse output in fit_ovr LabelBinarizer to be sparse if Y…
hamsal Jun 13, 2014
2a5ebfb
Enforced correct dtype in predict_ovr
hamsal Jun 13, 2014
1275ce5
Included first test for sparse ovr
hamsal Jun 18, 2014
c668fc7
Included first test for sparse ovr
hamsal Jun 18, 2014
94d7d04
Revised sparse ovr predict test
hamsal Jun 19, 2014
63aab29
Implemented construction of csc_matrix by column indicies in predict_ovr
hamsal Jun 19, 2014
7c51b1a
Revised formating and indentations
hamsal Jun 24, 2014
bba115a
Revised test_ovr_fit_predict_sparse to ensure identical results from …
hamsal Jun 24, 2014
5343138
Revised predict_ovr to loop over estimators in the multiclass case
hamsal Jun 24, 2014
b3860cd
Removed blank line
hamsal Jun 24, 2014
b492030
Defaulted label binarizer to set sparse_output=True when training ovr…
hamsal Jun 26, 2014
9a3c831
Revised predict_ovr to work with non integer labels
hamsal Jun 27, 2014
312e108
Fixed type, sp.issparse(y) -> True
hamsal Jun 27, 2014
ee4a715
Attempt to avoid a sparse effecieny warning by not converting a csc m…
hamsal Jun 27, 2014
b2d0f1e
Revised csc sparse matrix case in label_binarize in attempt to avoid …
hamsal Jun 27, 2014
8e2f9a2
Cast sparse array to csc in fit_ovr and wrote a get_col_ helper, remo…
hamsal Jun 28, 2014
4f82b66
Implemented tests for predict_proba and decison_function with a class…
hamsal Jun 30, 2014
0ae6dec
Restarting OrthogonalMatchingPursuitCV failure
hamsal Jun 30, 2014
6e5c3ae
Measured len of X in predict_ovr to allow for sparse data
hamsal Jul 1, 2014
6b9b53e
Revised column stacking in fit_ovr to be a generator expression
hamsal Jul 1, 2014
b35671a
Restartinig travis MD5 sums mismatch
hamsal Jul 2, 2014
c766dd6
Tested label binarizer with a sparse_output=True binary case
hamsal Jul 2, 2014
55cea43
len_X -> n_samples
hamsal Jul 5, 2014
986b43b
swithced to a dense label binarizer in the three class multiclass case
hamsal Jul 7, 2014
2454009
Corrected True-> False back from trials
hamsal Jul 7, 2014
86fa719
Corrected multiclass conditional
hamsal Jul 7, 2014
350ccc3
Documentation on working of _get_col
hamsal Jul 7, 2014
9a7635f
_get_col => getcol, formating revisions in fit_ovr
hamsal Jul 8, 2014
6805704
Removed special handling for multiclass case in fit_ovr
hamsal Jul 8, 2014
63c3b58
Supressed SparseEfficienyWarning by writing ignore_warning_class helper
hamsal Jul 10, 2014
1dd6e95
Revised ignore_warning_class to wrapp the function in a way to allow …
hamsal Jul 10, 2014
ebdae52
ignore_warning_class Documentation /travis rebuild
hamsal Jul 10, 2014
6a50b82
Implemeneted search of all warnings raised in assert_warns
hamsal Jul 11, 2014
99dfe1b
Fixed assert_warns call
hamsal Jul 11, 2014
ec0558c
Restart OrthogonalMatchingPursuitCV failure
hamsal Jul 11, 2014
01a4cb7
Cleaned unsused additions from multiclass.py
hamsal Jul 11, 2014
91d9354
Updated documentation in multiclass.py
hamsal Jul 14, 2014
104ee77
Rewrote expression for found in assert_warms idiomatically
hamsal Jul 14, 2014
b31e9b6
Edited mention of 2d array output to also include possiblity of spars…
hamsal Jul 14, 2014
37d7b19
Removed extra blank lines from assert_warns
hamsal Jul 14, 2014
a20cf27
A collection of small changes, Commented sparse_output = True desicio…
hamsal Jul 15, 2014
2c15f1b
Undo documentation edits to OvO fit and predict
hamsal Jul 15, 2014
09f4b25
Document the y_type_ attribute of LabelBinarizer
hamsal Jul 15, 2014
17407a8
Comment fit_ovr and predict_ovr as public functions
hamsal Jul 15, 2014
a3b909a
Removed make_mlb rename in test_ovr_fit_predict_sparse
hamsal Jul 15, 2014
911bff2
Use lb.classes_ in fit_ovr to maintain class dtype
hamsal Jul 15, 2014
80f57f2
Comment j_jobs > 1 in fit_orv
hamsal Jul 16, 2014
dbc67af
LabelBinarizer and ovr fail multioutput, test binary ovr
hamsal Jul 16, 2014
26d63c3
Fit binary target data on one line
hamsal Jul 17, 2014
855fdd1
Fix typo individual
hamsal Jul 17, 2014
e1dc470
Untab overindented line in predict docstring
hamsal Jul 17, 2014
3 changes: 2 additions & 1 deletion doc/modules/multiclass.rst
@@ -33,7 +33,8 @@ by decomposing such problems into binary classification problems.
several joint classification tasks. This is a generalization
of the multi-label classification task, where the set of classification
problem is restricted to binary classification, and of the multi-class
classification task. *The output format is a 2d numpy array.*
classification task. *The output format is a 2d numpy array or sparse
matrix.*

The set of labels can be different for each output variable.
For instance a sample could be assigned "pear" for an output variable that
111 changes: 94 additions & 17 deletions sklearn/multiclass.py
@@ -32,14 +32,19 @@
#
# License: BSD 3 clause

import array
import numpy as np
import warnings
import scipy.sparse as sp

from .base import BaseEstimator, ClassifierMixin, clone, is_classifier
from .base import MetaEstimatorMixin
from .preprocessing import LabelBinarizer
from .metrics.pairwise import euclidean_distances
from .utils import check_random_state
from .utils.multiclass import type_of_target
from .utils.multiclass import unique_labels
from .utils.validation import _num_samples
from .externals.joblib import Parallel
from .externals.joblib import delayed

@@ -81,24 +86,96 @@ def _check_estimator(estimator):


def fit_ovr(estimator, X, y, n_jobs=1):
"""Fit a one-vs-the-rest strategy."""
_check_estimator(estimator)
"""Fit a list of estimators using a one-vs-the-rest strategy.

lb = LabelBinarizer()
Y = lb.fit_transform(y)
Parameters
----------
estimator : estimator object
An estimator object implementing `fit` and one of `decision_function`
or `predict_proba`.

estimators = Parallel(n_jobs=n_jobs)(
delayed(_fit_binary)(estimator, X, Y[:, i], classes=["not %s" % i, i])
for i in range(Y.shape[1]))
X : {array-like, sparse matrix}, shape = [n_samples, n_features]
Data.

y : {array-like, sparse matrix}, shape = [n_samples] or
[n_samples, n_classes] Multi-class targets. An indicator matrix
turns on multilabel classification.

Returns
-------
estimators, lb : list of fitted binary estimators, and the fitted
LabelBinarizer used to transform the multiclass targets.
"""
_check_estimator(estimator)
# A sparse LabelBinarizer, with sparse_output=True, has been shown to
# outperform or match a dense label binarizer in all cases and has also
# resulted in less than or equal memory consumption in the fit_ovr
# function overall.
lb = LabelBinarizer(sparse_output=True)
Member:

Given the amount of thought and benchmarking efforts that went into making this decision, I think it's worth to at least explain the empirical results in a comment.

Contributor Author:

I have included a comment summarizing the benefits

Y = lb.fit_transform(y)
Y = Y.tocsc()
columns = (col.toarray().ravel() for col in Y.T)
# In cases where individual estimators are very fast to train, setting
# n_jobs > 1 can result in slower performance due to the overhead of
# spawning workers.
estimators = Parallel(n_jobs=n_jobs)(delayed(_fit_binary)
Member:

Please add a comment (or a sentence in the docs) stating that n_jobs > 1 can be slower than n_jobs == 1 when the individual binary classifiers are very fast to fit (as in the case in @arjoly's benchmark.)

You can add a comment in the source referencing this joblib issue.

(estimator,
X,
column,
classes=["not %s" % i,
lb.classes_[i]])
for i, column in enumerate(columns))
return estimators, lb
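The binarize-then-iterate pattern in fit_ovr can be sketched without scikit-learn. The sparse indicator below is built by hand and stands in for what `LabelBinarizer(sparse_output=True).fit_transform(y)` would return; each densified column is the binary target vector for one "class i vs. rest" subproblem.

```python
import numpy as np
import scipy.sparse as sp

y = np.array([0, 2, 1, 2])          # multiclass targets, 3 classes
n_samples, n_classes = len(y), 3

# Hand-built sparse one-hot indicator in CSC format, standing in for
# LabelBinarizer(sparse_output=True).fit_transform(y).tocsc()
Y = sp.csc_matrix((np.ones(n_samples, dtype=int),
                   (np.arange(n_samples), y)),
                  shape=(n_samples, n_classes))

# Materialize one dense column at a time instead of densifying Y whole;
# this keeps peak memory proportional to a single column.
columns = (Y.getcol(i).toarray().ravel() for i in range(n_classes))
binary_targets = [col for col in columns]
```

Each entry of `binary_targets` would be passed to one `_fit_binary` call in the real function.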


def predict_ovr(estimators, label_binarizer, X):
Member:

This function assumes that estimators comes from our fit_ovr; specifically, that all estimators in the list are of the same type (at least w.r.t. the threshold.) Since it is a public function, there's no telling what users might do; like learn SVMs for the first 5 classes and random forests for the other 5.

EDIT: ways to solve this:

  1. document
  2. document + raise exception if estimators is heterogeneous
  3. document + have a list of thresholds

Member:

or 4) deprecate and make them private...

Member:

I say (1) or perhaps (2)

Contributor Author:

I implemented suggestion number 2 and included a test for it.

"""Make predictions using the one-vs-the-rest strategy."""
Y = np.array([_predict_binary(e, X) for e in estimators])
"""Predict multi-class targets using the one vs rest strategy.

Member:

Please remove this blank line.

Parameters
----------
estimators : list of `n_classes` estimators
Estimators used for predictions. The list must be homogeneous with
respect to the type of the estimators. fit_ovr supplies this list as
part of its output.

label_binarizer : LabelBinarizer object
Object used to transform multiclass labels to binary labels and
vice-versa. fit_ovr supplies this object as part of its output.

X : {array-like, sparse matrix}, shape = [n_samples, n_features]
Data.

Returns
-------
y : {array-like, sparse matrix}, shape = [n_samples] or
[n_samples, n_classes]. Predicted multi-class targets.
"""
e_types = set([type(e) for e in estimators if not
isinstance(e, _ConstantPredictor)])
if len(e_types) > 1:
raise ValueError("List of estimators must contain estimators of the"
" same type but contains types {0}".format(e_types))
e = estimators[0]
thresh = 0 if hasattr(e, "decision_function") and is_classifier(e) else .5
return label_binarizer.inverse_transform(Y.T, threshold=thresh)

if label_binarizer.y_type_ == "multiclass":
Member:

Can y_type_ be binary, continuous or something else? Does the else: branch stay correct in all these cases?

EDIT: I think the cases to consider are "binary" and "multiclass-multioutput". Could you please check that these cases are covered by tests?

Contributor Author:

I don't see the need for binary data with One vs. Rest since it is not multiclass and would be the same thing as fitting a regular estimator. I am also not sure "multiclass-multioutput" would work since it does not fit with the scheme of One vs. Rest and there is no way to binarize it.

Member:

I'm not saying there is need for it. I'm asking what happens if a user
tries it.
On Jul 15, 2014 9:58 PM, "hamsal" [email protected] wrote:

In sklearn/multiclass.py:

 e = estimators[0]
 thresh = 0 if hasattr(e, "decision_function") and is_classifier(e) else .5
  • return label_binarizer.inverse_transform(Y.T, threshold=thresh)
  • if label_binarizer.y_type_ == "multiclass":

I don't see the need for binary data with One vs. Rest since it is not
multiclass and would be the same thing as fitting a regular estimator. I am
also not sure "multiclass-multioutput" would work since it does not fit
with the scheme of One vs. Rest and there is no way to binarize it.


Reply to this email directly or view it on GitHub
https://github.com/scikit-learn/scikit-learn/pull/3276/files#r14959192.

Contributor Author:

I have raised a ValueError in LabelBinarizer when multioutput data is used, because this is where the issue lies. I wrote tests for OvR and the LabelBinarizer to assert these errors, and finally I included a binary test case for OvR which demonstrates correctness.

maxima = np.empty(X.shape[0], dtype=float)
maxima.fill(-np.inf)
argmaxima = np.zeros(X.shape[0], dtype=int)
for i, e in enumerate(estimators):
pred = _predict_binary(e, X)
np.maximum(maxima, pred, out=maxima)
argmaxima[maxima == pred] = i
return label_binarizer.classes_[np.array(argmaxima.T)]
else:
n_samples = _num_samples(X)
indices = array.array('i')
indptr = array.array('i', [0])
for e in estimators:
indices.extend(np.where(_predict_binary(e, X) > thresh)[0])
indptr.append(len(indices))
data = np.ones(len(indices), dtype=int)
indicator = sp.csc_matrix((data, indices, indptr),
shape=(n_samples, len(estimators)))
return label_binarizer.inverse_transform(indicator)
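Both prediction branches of predict_ovr can be traced with made-up scores; the arrays below are stand-ins for the values `_predict_binary` would return for each fitted estimator.

```python
import array

import numpy as np
import scipy.sparse as sp

# Made-up per-class scores for 4 samples and 3 binary estimators.
scores = [np.array([0.9, 0.1, 0.8, 0.2]),   # class 0 vs. rest
          np.array([0.2, 0.7, 0.6, 0.1]),   # class 1 vs. rest
          np.array([0.1, 0.2, 0.3, 0.9])]   # class 2 vs. rest
n_samples, thresh = 4, 0.5
classes = np.array([0, 1, 2])

# Multiclass branch: keep a running elementwise maximum of the scores,
# recording which estimator last achieved it for each sample.
maxima = np.full(n_samples, -np.inf)
argmaxima = np.zeros(n_samples, dtype=int)
for i, pred in enumerate(scores):
    np.maximum(maxima, pred, out=maxima)
    argmaxima[maxima == pred] = i
multiclass_pred = classes[argmaxima]

# Multilabel branch: build a CSC indicator column by column from flat
# `indices`/`indptr` arrays, never densifying anything.
indices = array.array('i')
indptr = array.array('i', [0])
for pred in scores:
    indices.extend(np.where(pred > thresh)[0])
    indptr.append(len(indices))
data = np.ones(len(indices), dtype=int)
indicator = sp.csc_matrix((data, indices, indptr),
                          shape=(n_samples, len(scores)))
```

In the real function, `multiclass_pred` corresponds to the values returned via `label_binarizer.classes_`, and `indicator` is what gets handed to `label_binarizer.inverse_transform`.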


def predict_proba_ovr(estimators, X, is_multilabel):
@@ -190,9 +267,9 @@ def fit(self, X, y):
X : {array-like, sparse matrix}, shape = [n_samples, n_features]
Data.

y : array-like, shape = [n_samples] or [n_samples, n_classes]
Multi-class targets. An indicator matrix turns on multilabel
classification.
y : {array-like, sparse matrix}, shape = [n_samples] or
[n_samples, n_classes] Multi-class targets. An indicator matrix
Member:

y : {array-like, sparse matrix}, shape = [n_samples] or [n_samples, n_classes]

Could it be one line?

Contributor Author:

This extends over the line limit

Member:

Does it render well in the doc?

Contributor Author:

I am working on building the doc, but I am getting an ImportError: No module named sklearn.externals.six after running make html. I build scikit-learn with python setup.py build_ext --inplace and then add the directory to my PYTHONPATH. I am still trying to figure out the problem.

Member:

What happens with make doc?

Member:

This rule is in the scikit-learn folder.

Contributor Author:

Ok, I have gotten it to start building. Apparently I had not run python setup.py install in the scikit-learn folder before.

Contributor Author:

It looks bad in the documentation: [n_samples, n_classes] ends up in the body. Maybe it is better to move the entire statement shape = [n_samples] or [n_samples, n_classes] into the body?

Member:

shape = [n_samples] or [n_samples, n_classes] into the body?

Usually it is put in the header.

Contributor Author:

Although it is not entirely precise, another solution could be to shorten shape = [n_samples] or [n_samples, n_classes] to shape = [n_samples, n_classes].

turns on multilabel classification.

Returns
-------
@@ -216,8 +293,8 @@ def predict(self, X):

Returns
-------
y : array-like, shape = [n_samples]
Predicted multi-class targets.
y : {array-like, sparse matrix}, shape = [n_samples] or
[n_samples, n_classes]. Predicted multi-class targets.
"""
self._check_is_fitted()

@@ -242,7 +319,7 @@ def predict_proba(self, X):

Returns
-------
T : array-like, shape = [n_samples, n_classes]
T : {array-like, sparse matrix}, shape = [n_samples, n_classes]
Returns the probability of the sample for each class in the model,
where classes are ordered as they are in `self.classes_`.
"""
@@ -271,7 +348,7 @@ def decision_function(self, X):
@property
def multilabel_(self):
"""Whether this is a multilabel classifier"""
return self.label_binarizer_.multilabel_
return self.label_binarizer_.y_type_.startswith('multilabel')
Member:

The y_type_ attribute of LabelBinarizer is not documented, could you please add it to the Attributes section?

Contributor Author:

I have made the entry in the Attributes section


def score(self, X, y):
if self.multilabel_:
23 changes: 22 additions & 1 deletion sklearn/preprocessing/label.py
@@ -201,6 +201,13 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
`classes_` : array of shape [n_class]
Holds the label for each class.

`y_type_` : str
Represents the type of the target data as evaluated by
utils.multiclass.type_of_target. Possible types are 'continuous',
'continuous-multioutput', 'binary', 'multiclass',
'multiclass-multioutput', 'multilabel-sequences',
'multilabel-indicator', and 'unknown'.

`multilabel_` : boolean
True if the transformer was fitted on a multilabel rather than a
multiclass set of labels. The multilabel_ attribute is deprecated
@@ -301,6 +308,10 @@ def fit(self, y):
self : returns an instance of self.
"""
self.y_type_ = type_of_target(y)
if 'multioutput' in self.y_type_:
raise ValueError("Multioutput target data is not supported with "
"label binarization")

self.sparse_input_ = sp.issparse(y)
self.classes_ = unique_labels(y)
return self
@@ -462,6 +473,9 @@ def label_binarize(y, classes, neg_label=0, pos_label=1,
pos_label = -neg_label

y_type = type_of_target(y)
if 'multioutput' in y_type:
raise ValueError("Multioutput target data is not supported with label "
"binarization")

n_samples = y.shape[0] if sp.issparse(y) else len(y)
n_classes = len(classes)
@@ -517,14 +531,19 @@ def label_binarize(y, classes, neg_label=0, pos_label=1,

if pos_switch:
Y[Y == pos_label] = 0
else:
Y.data = astype(Y.data, int, copy=False)

# preserve label ordering
if np.any(classes != sorted_class):
indices = np.argsort(classes)
Y = Y[:, indices]

if y_type == "binary":
Y = Y[:, -1].reshape((-1, 1))
if sparse_output:
Y = Y.getcol(-1)
Member:

Why is this interesting?

Contributor Author:

I found, from setting lb = LabelBinarizer(sparse_output=True), that this case was previously unsupported and would throw an ambiguous error, so I thought it best to support it instead of creating a tailored error message.

else:
Y = Y[:, -1].reshape((-1, 1))

return Y
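A minimal check of the binary-case branch, assuming only scipy: `Y.getcol(-1)` keeps the result as a sparse column matrix, while the dense branch slices and reshapes, and the two agree elementwise. The toy matrix below is made up.

```python
import numpy as np
import scipy.sparse as sp

# Toy two-column binarization (columns for the negative and positive class).
Y = sp.csc_matrix(np.array([[1, 0],
                            [0, 1],
                            [1, 0]]))

sparse_col = Y.getcol(-1)                       # sparse, shape (3, 1)
dense_col = Y.toarray()[:, -1].reshape((-1, 1))  # dense equivalent

assert (sparse_col.toarray() == dense_col).all()
```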

@@ -600,6 +619,8 @@ def _inverse_binarize_thresholding(y, output_type, classes, threshold):

# Inverse transform data
if output_type == "binary":
if sp.issparse(y):
y = y.toarray()
if y.ndim == 2 and y.shape[1] == 2:
return classes[y[:, 1]]
else:
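The densify-then-index step in the binary branch of _inverse_binarize_thresholding can be verified with a toy indicator; `classes` and `Y` below are made up.

```python
import numpy as np
import scipy.sparse as sp

classes = np.array(['neg', 'pos'])
Y = sp.csr_matrix(np.array([[0, 1],
                            [1, 0],
                            [0, 1]]))

# As in the patch: densify sparse input before fancy indexing, then
# recover labels from column 1, which holds the positive class.
y = Y.toarray() if sp.issparse(Y) else Y
labels = classes[y[:, 1]]
```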
14 changes: 14 additions & 0 deletions sklearn/preprocessing/tests/test_label.py
@@ -194,6 +194,11 @@ def test_label_binarizer_errors():
y=np.array([[1, 2, 3], [2, 1, 3]]), output_type="binary",
classes=[1, 2, 3], threshold=0)

# Fail on multioutput data
assert_raises(ValueError, LabelBinarizer().fit, np.array([[1, 3], [2, 1]]))
assert_raises(ValueError, label_binarize, np.array([[1, 3], [2, 1]]),
[1, 2, 3])


def test_label_encoder():
"""Test LabelEncoder's transform and inverse_transform methods"""
@@ -467,6 +472,15 @@ def test_label_binarize_binary():

yield check_binarized_results, y, classes, pos_label, neg_label, expected

# Binary case where sparse_output = True will not result in a ValueError
y = [0, 1, 0]
classes = [0, 1]
pos_label = 3
neg_label = 0
expected = np.array([[3, 0], [0, 3], [3, 0]])[:, 1].reshape((-1, 1))

yield check_binarized_results, y, classes, pos_label, neg_label, expected
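The expected array in the new binary test case can be reproduced by hand: with pos_label=3 and neg_label=0, the two-column binarization is collapsed to its last column, which corresponds to the positive class.

```python
import numpy as np

# Two-column binarization of y = [0, 1, 0] with pos_label=3, neg_label=0:
# column 0 marks class 0, column 1 marks class 1.
two_col = np.array([[3, 0],
                    [0, 3],
                    [3, 0]])

# Binary case keeps only the last column, reshaped to (n_samples, 1).
expected = two_col[:, -1].reshape((-1, 1))
```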


def test_label_binarize_multiclass():
y = [0, 1, 2]