[MRG+1] Implements Multiclass hinge loss #3607
@@ -64,6 +64,12 @@ Enhancements
     to `Rohit Sivaprasad`_), as well as evaluation metrics (by
     `Joel Nothman`_).
 
+   - Add ``sample_weight`` parameter to `metrics.jaccard_similarity_score`.
+     By `Jatin Shah`.
Review comment: indent this back to where it was, please.
Reply: The new indentation is more in line with the rest of the file. Could you have a look at the source, please?
Reply: The new indentation is correct. (The indentation of the whole file is a bit awkward, IMO.)
+
+   - Add support for multiclass in `metrics.hinge_loss`. Added ``labels=None``
+     as an optional parameter. By `Saurabh Jha`.
+
    - Add ``multi_class="multinomial"`` option in
      :class:`linear_model.LogisticRegression` to implement a Logistic
      Regression solver that minimizes the cross-entropy or multinomial loss
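A minimal sketch of the new ``labels`` usage described in that changelog entry (the decision values below are invented for illustration, not taken from the PR):

    import numpy as np
    from sklearn.metrics import hinge_loss

    # Three samples, one per class; each row is a decision_function output
    # and the columns follow the order given in ``labels``.
    y_true = [0, 1, 2]
    pred_decision = np.array([[ 1.0, -0.5, -0.5],
                              [-0.5,  1.0, -0.5],
                              [-0.5, -0.5,  1.0]])

    # ``labels`` lists every class; this matters when y_true happens to
    # miss some of them.
    loss = hinge_loss(y_true, pred_decision, labels=[0, 1, 2])
    # Each true class outscores the runner-up by 1.5 > 1, so loss == 0.0.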
@@ -16,6 +16,7 @@
 # Joel Nothman <[email protected]>
 # Noel Dawe <[email protected]>
 # Jatin Shah <[email protected]>
+# Saurabh Jha <[email protected]>
 # License: BSD 3 clause
 
 from __future__ import division
 
@@ -1376,14 +1377,20 @@ def log_loss(y_true, y_pred, eps=1e-15, normalize=True, sample_weight=None):
     return _weighted_sum(loss, sample_weight, normalize)
 
 
-def hinge_loss(y_true, pred_decision, pos_label=None, neg_label=None):
+def hinge_loss(y_true, pred_decision, labels=None):
     """Average hinge loss (non-regularized)
Review comment: Isn't it a convention to not document optional parameters? For example, pos_label and neg_label are not documented here.
Reply: I do not know why ...
Reply: There is no such convention, whereas ...
Reply: Yes! I did not notice that.
Reply: Should I remove them?
Reply: git blame points me to @arjoly. Maybe he can confirm.
Reply: Okay, I will leave this as it is for now.
Review comment: It seems ...
-    Assuming labels in y_true are encoded with +1 and -1, when a prediction
-    mistake is made, ``margin = y_true * pred_decision`` is always negative
-    (since the signs disagree), implying ``1 - margin`` is always greater than
-    1. The cumulated hinge loss is therefore an upper bound of the number of
-    mistakes made by the classifier.
+    In the binary case, assuming labels in y_true are encoded with +1 and -1,
+    when a prediction mistake is made, ``margin = y_true * pred_decision`` is
+    always negative (since the signs disagree), implying ``1 - margin`` is
+    always greater than 1. The cumulated hinge loss is therefore an upper
+    bound of the number of mistakes made by the classifier.
+
+    In the multiclass case, the function expects that either all the labels
+    are included in y_true or an optional labels argument is provided which
+    contains all the labels. The multiclass margin is calculated according
+    to Crammer-Singer's method. As in the binary case, the cumulated hinge
+    loss is an upper bound of the number of mistakes made by the classifier.

Review comment: Some description like this belongs in ...
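Spelled out, the Crammer-Singer margin that the paragraph above refers to is (our notation, summarizing reference [2] below; not part of the patch):

    L(x_i, y_i) = \max\left(0,\; 1 + \max_{k \neq y_i} f_k(x_i) - f_{y_i}(x_i)\right)

where f_k(x_i) is the decision value for class k on sample i, and the function reports the mean of these per-sample terms.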
 
     Parameters
     ----------
 
@@ -1394,6 +1401,9 @@ def hinge_loss(y_true, pred_decision, pos_label=None, neg_label=None):
     pred_decision : array, shape = [n_samples] or [n_samples, n_classes]
         Predicted decisions, as output by decision_function (floats).
 
+    labels : array, optional, default None
+        Contains all the labels for the problem. Used in multiclass hinge loss.
+
     Returns
     -------
     loss : float
@@ -1403,6 +1413,16 @@ def hinge_loss(y_true, pred_decision, pos_label=None, neg_label=None):
     .. [1] `Wikipedia entry on the Hinge loss
            <http://en.wikipedia.org/wiki/Hinge_loss>`_
 
+    .. [2] Koby Crammer, Yoram Singer. On the Algorithmic
+           Implementation of Multiclass Kernel-based Vector
+           Machines. Journal of Machine Learning Research 2,
+           (2001), 265-292.
+
+    .. [3] `L1 and L2 Regularization for Multiclass Hinge Loss Models
+           by Robert C. Moore, John DeNero
+           <http://www.ttic.edu/sigml/symposium2011/papers/Moore+DeNero_Regularization.pdf>`_
+
     Examples
     --------
     >>> from sklearn import svm
@@ -1420,27 +1440,56 @@ def hinge_loss(y_true, pred_decision, pos_label=None, neg_label=None):
     >>> hinge_loss([-1, 1, 1], pred_decision)  # doctest: +ELLIPSIS
     0.30...
 
+    In the multiclass case:
+
+    >>> X = np.array([[0], [1], [2], [3]])
+    >>> Y = np.array([0, 1, 2, 3])
+    >>> labels = np.array([0, 1, 2, 3])
+    >>> est = svm.LinearSVC()
+    >>> est.fit(X, Y)
+    LinearSVC(C=1.0, class_weight=None, dual=True, fit_intercept=True,
+         intercept_scaling=1, loss='l2', max_iter=1000, multi_class='ovr',
+         penalty='l2', random_state=None, tol=0.0001, verbose=0)
+    >>> pred_decision = est.decision_function([[-1], [2], [3]])
+    >>> y_true = [0, 2, 3]
+    >>> hinge_loss(y_true, pred_decision, labels)  # doctest: +ELLIPSIS
+    0.56...
     """

Review comment: nitpick: space after ...
-    # TODO: multi-class hinge-loss
     check_consistent_length(y_true, pred_decision)
+    pred_decision = check_array(pred_decision, ensure_2d=False)
     y_true = column_or_1d(y_true)
-    pred_decision = column_or_1d(pred_decision)
-
-    # the rest of the code assumes that positive and negative labels
-    # are encoded as +1 and -1 respectively
-    lbin = LabelBinarizer(neg_label=-1)
-    y_true = lbin.fit_transform(y_true)[:, 0]
-
-    if len(lbin.classes_) > 2 or (pred_decision.ndim == 2
-                                  and pred_decision.shape[1] != 1):
-        raise ValueError("Multi-class hinge loss not supported")
-    pred_decision = np.ravel(pred_decision)
-
-    try:
-        margin = y_true * pred_decision
-    except TypeError:
-        raise TypeError("pred_decision should be an array of floats.")
+    y_true_unique = np.unique(y_true)
+    if y_true_unique.size > 2:
+        if (labels is None and pred_decision.ndim > 1 and
+                (np.size(y_true_unique) != pred_decision.shape[1])):
+            raise ValueError("Please include all labels in y_true "
+                             "or pass labels as third argument")
+        if labels is None:
+            labels = y_true_unique
+        le = LabelEncoder()
+        le.fit(labels)
+        y_true = le.transform(y_true)
+        mask = np.ones_like(pred_decision, dtype=bool)
+        mask[np.arange(y_true.shape[0]), y_true] = False
+        margin = pred_decision[~mask]
+        margin -= np.max(pred_decision[mask].reshape(y_true.shape[0], -1),
+                         axis=1)

Review comment: maybe a comment here to explain the ...
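In the spirit of that request, here is what the masking does, as a standalone sketch (the numbers are invented; only the mask logic mirrors the patch):

    import numpy as np

    # Decision values for 3 samples over 4 classes; true classes are [0, 2, 3].
    pred_decision = np.array([[ 1.0, -0.3, -0.2, -0.1],
                              [-0.5,  0.2,  0.9, -0.4],
                              [-0.2,  0.1,  0.3,  0.8]])
    y_true = np.array([0, 2, 3])

    # mask is True everywhere except at each sample's true-class column.
    mask = np.ones_like(pred_decision, dtype=bool)
    mask[np.arange(y_true.shape[0]), y_true] = False

    # ~mask picks the true-class score of each sample (shape: [n_samples]).
    margin = pred_decision[~mask]
    # mask picks the remaining n_classes - 1 scores per sample; the reshape
    # restores the per-sample rows so np.max finds the strongest competitor,
    # as in Crammer-Singer.
    margin -= np.max(pred_decision[mask].reshape(y_true.shape[0], -1), axis=1)
    # margin is now f_y(x) - max_{k != y} f_k(x), here [1.1, 0.7, 0.5].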
+    else:
+        # Handles binary class case
+        # this code assumes that positive and negative labels
+        # are encoded as +1 and -1 respectively
+        pred_decision = column_or_1d(pred_decision)
+        pred_decision = np.ravel(pred_decision)
+
+        lbin = LabelBinarizer(neg_label=-1)
+        y_true = lbin.fit_transform(y_true)[:, 0]
+
+        try:
+            margin = y_true * pred_decision
+        except TypeError:
+            raise TypeError("pred_decision should be an array of floats.")
 
     losses = 1 - margin
-    # The hinge doesn't penalize good enough predictions.
+    # The hinge_loss doesn't penalize good enough predictions.
     losses[losses <= 0] = 0
     return np.mean(losses)
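To see the docstring's "upper bound on the number of mistakes" claim in action, a small NumPy sketch of the binary case (independent of the patched function; values invented):

    import numpy as np

    # Binary labels in {-1, +1} and some raw decision values.
    y_true = np.array([-1, 1, 1, -1])
    pred_decision = np.array([-1.2, 0.3, -0.4, 0.8])   # two sign mistakes

    margin = y_true * pred_decision           # negative where signs disagree
    losses = np.clip(1 - margin, 0, None)     # hinge: max(0, 1 - margin)

    n_mistakes = np.sum(margin < 0)           # 2
    print(losses.sum() >= n_mistakes)         # True: summed hinge >= mistakes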
Review comment: due to -> by?
Reply: I thought this is how you do it in scientific references. Does it matter much?
Reply: ok