[MRG+1] Implements Multiclass hinge loss #3607


Closed
wants to merge 1 commit

Conversation

SaurabhJha
Contributor

Implements multiclass hinge loss. Fixes #3451

@SaurabhJha
Contributor Author

I apologise for the failing test. Can anyone please help me with this error? I will include the complete traceback:

Traceback (most recent call last):
  File "/Library/Python/2.7/site-packages/nose/case.py", line 197, in runTest
    self.test(*self.arg)
  File "/Users/saurabhjha/scikit-learn/sklearn/metrics/tests/test_classification.py", line 1072, in test_hinge_loss_multilabel
    hinge_loss_by_hand_part3,
  File "/System/Library/Frameworks/Python.framework/Versions/2.7/Extras/lib/python/numpy/core/fromnumeric.py", line 2374, in mean
    return mean(axis, dtype, out)
ValueError: axis(=3) out of bounds

@Winterflower

Are the variables hinge_loss_by_hand_part1, hinge_loss_by_hand_part2, and hinge_loss_by_hand_part3 references to primitives or to numpy arrays? From a quick look at the test, the hinge_loss_by_hand_part* variables refer to floats, and that is why the numpy call is giving you an error.

If I understand correctly, numpy.mean (or np.mean) accepts an array as an argument. (See the docs at http://docs.scipy.org/doc/numpy/reference/generated/numpy.mean.html).
Could you try changing it to

hinge_loss_by_hand = np.mean(np.array([hinge_loss_by_hand_part1,
                                       hinge_loss_by_hand_part2,
                                       hinge_loss_by_hand_part3]))

and see if the test passes?
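For reference, a minimal sketch of the difference (the part values below are hypothetical stand-ins for the test's variables):

```python
import numpy as np

# Hypothetical per-sample losses standing in for hinge_loss_by_hand_part1..3.
part1, part2, part3 = 0.5, 1.2, 0.8

# np.mean(part1, part2, part3) would treat the extra positional arguments
# as `axis` and `dtype`, which is what produces "axis(=...) out of bounds".
# Passing a single sequence averages the values as intended:
hinge_loss_by_hand = np.mean([part1, part2, part3])
print(hinge_loss_by_hand)  # (0.5 + 1.2 + 0.8) / 3
```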

margins_array = []
if label_vector is None:
    raise ValueError("label_vector\
required in multilabel classification")
Member

Sorry for being dumb, but what exactly is label_vector? Isn't it simply np.unique(y)?

Member

Also, we need a multi_class option that calculates the OvR or Crammer-Singer hinge loss, as @mblondel has suggested.

@MechCoder
Member

+1 for @Winterflower's suggestion. But I'm not sure the code even reaches that branch, since you have ravelled pred_decision.

@SaurabhJha
Contributor Author

Can anyone please review the new commit? I have tried to address the comments.

@MechCoder
Member

@SaurabhJha I'm trying to have a look now.

pred_decision = column_or_1d(pred_decision)
pred_decision = np.array(pred_decision)
lb = LabelEncoder()
encoded_labels = lb.fit(y_true)
Member

You need not store the return value of lb.fit(y_true), since it returns lb. You can simply do

lb = LabelEncoder()
lb.fit(y_true)

and use lb.classes_ below.

Member

or even better

le = LabelEncoder()
y_bin = le.fit_transform(y)

so that it can be used later on.
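As a quick illustration of the fit_transform pattern suggested here (the labels below are hypothetical):

```python
from sklearn.preprocessing import LabelEncoder

y = ["ham", "spam", "spam", "eggs"]  # hypothetical string labels

le = LabelEncoder()
y_bin = le.fit_transform(y)  # fit and encode in one call

print(list(le.classes_))  # ['eggs', 'ham', 'spam'] -- sorted unique labels
print(list(y_bin))        # [1, 2, 2, 0] -- indices into le.classes_
```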

@MechCoder
Member

@SaurabhJha Please update your code according to this vectorized implementation. After that we can look at further clean ups. #3607 (comment)

@SaurabhJha
Contributor Author

Do you think we need the optional parameter "label_vector"? I tried, but I cannot think of any other way to do it.

@MechCoder
Member

No, I do not think so.

I have done it without label_vector; you can just use the transformed classes from the LabelEncoder and probably raise an error:

hinge_loss = np.ones(y_true.shape[0])
classes_ = np.unique(y_bin)
if len(classes_) != pred_decision.shape[1]:
    raise ValueError("...")
for i in classes_:
    class_slice = classes_ != i
    pos_index = y_bin == i
    hinge_loss[pos_index] -= pred_decision[pos_index, i]
    hinge_loss[pos_index] += np.max(pred_decision[pos_index][:, class_slice], axis=1)
return np.mean(hinge_loss)

Does this seem ok to you?

@SaurabhJha
Contributor Author

I think I see the issue. For a four-label problem, a typical pred_decision row looks like this:
[1.27272363, 0.0342046, -0.68379465, -1.40169096]

You need all the labels in order to match each column to its respective label. So let's say the labels are
[0, 1, 2, 3]
and we want to compute the hinge loss for label 2; then pred_decision[0], pred_decision[1], and pred_decision[3] correspond to the negative labels.
I don't think we can get the complete picture of the multiclass problem unless we pass the true labels.

In fact, your first test will fail

if len(classes_) != pred_decision.shape[1]:
For the particular test case I wrote, it will raise a 3 != 4 error.

What do you think? Please correct me if I am wrong

@MechCoder
Member

if len(classes_) != pred_decision.shape[1]:

classes_ = [0, 1, 2, 3] and pred_decision = [1.27272363, 0.0342046, -0.68379465, -1.40169096], so len(classes_) = 4. I do not understand how it would throw a 3 != 4 error.

The true labels that you pass are in y_true. I still do not understand why you need a label vector; the label vector is nothing but np.unique(y_true).

@MechCoder
Member

For example, look at the fit method of any estimator in scikit-learn. You are not required to pass the labels that y contains; they are computed internally from y itself.

@MechCoder
Member

Just in case I am wrong, @arjoly, can you please verify that this comment is correct (#3607 (comment)), or can you think of a better way?

@jnothman
Member

jnothman commented Sep 2, 2014

labels are provided in multiclass metrics:

  • if the problem can be decomposed as binary metrics (e.g. F1) and it's useful to return the per-class metric, then an ordering needs to be provided (or returned) for the output to be interpretable
  • to ensure that if a particular subsample does not contain any instances of some class, the metric can be adjusted for that missing class, so that metrics calculated on different subsamples are on the same scale
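To make the second point concrete: with current scikit-learn, the labels argument lets hinge_loss score a subsample that is missing a class (the decision scores below are hypothetical):

```python
import numpy as np
from sklearn.metrics import hinge_loss

# Hypothetical subsample: three classes exist overall, but these samples
# only contain classes 0 and 1.
y_true = [0, 1, 1]
pred_decision = np.array([
    [1.2, 0.1, -0.5],
    [-0.3, 0.9, -0.1],
    [0.2, 0.6, -0.4],
])

# Without `labels`, only two classes would be inferred from y_true, which
# conflicts with the 3-column decision matrix; naming all classes fixes that.
loss = hinge_loss(y_true, pred_decision, labels=[0, 1, 2])
print(loss)  # per-sample losses are [0.0, 0.0, 0.6], so the mean is 0.2
```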

@MechCoder
Member

@jnothman Thanks for the clarification!

@MechCoder
Member

I was thinking that we were not going to compute a per-class hinge_loss, just the mean across all samples; hence my comment.

pred_decision = column_or_1d(pred_decision)
pred_decision = np.array(pred_decision)
le = LabelEncoder()
le.fit(y_true)
Member

The label encoding needs to account for labels not included in y_true, if they are provided in label_vector (which should be called labels).

@jnothman
Member

jnothman commented Sep 3, 2014

@SaurabhJha, I don't think this code is correct.

Could you please describe the algorithm in words (pseudocode) for clarification? Then we can help you get the code to match its description.

@MechCoder
Member

@jnothman The idea is that for each sample we compute 1 - (the pred decision for the sample's true class) + max(the pred decisions for the other classes), then take the mean across all samples.

I had come up with this version.

hinge_loss = np.ones(y_true.shape[0])
classes_ = np.unique(y_bin)
for i in classes_:
    class_slice = classes_ != i
    true_index = y_bin == i
    hinge_loss[true_index] -= pred_decision[true_index, i]
    hinge_loss[true_index] += np.max(pred_decision[true_index][:, class_slice], axis=1)
return np.mean(hinge_loss)
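Sanity-checking that sketch on concrete numbers (all values hypothetical; note that, unlike the final metric, this version does not clip negative per-sample losses at zero):

```python
import numpy as np

y_bin = np.array([0, 2, 1])  # encoded true labels for 3 samples
pred_decision = np.array([
    [1.0, -0.2, 0.3],
    [-0.5, 0.1, 0.8],
    [0.4, 0.9, -0.1],
])

hinge_loss = np.ones(y_bin.shape[0])
classes_ = np.unique(y_bin)
for i in classes_:
    class_slice = classes_ != i
    true_index = y_bin == i
    # 1 - (decision for the true class) + max(decision over the others)
    hinge_loss[true_index] -= pred_decision[true_index, i]
    hinge_loss[true_index] += np.max(
        pred_decision[true_index][:, class_slice], axis=1)

print(np.mean(hinge_loss))  # mean of [0.3, 0.3, 0.5] = 1.1 / 3
```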

raise TypeError("pred_decision should be an array of floats.")
y_true_unique = np.unique(y_true)
if np.size(y_true_unique) > 2:
if (labels is None and len(pred_decision.shape) > 1 and
Member

At present this will fail if pred_decision is a list (I think). You need to call check_array(pred_decision, ensure_2d=False) somewhere, and replace len(pred_decision.shape) with pred_decision.ndim.
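A sketch of the suggested check (the input values are hypothetical):

```python
import numpy as np
from sklearn.utils import check_array

pred_decision = [[1.2, -0.3], [-0.7, 0.4]]  # a plain list, as a caller might pass

# check_array converts list input to an ndarray; ensure_2d=False also
# allows the 1-d decision values that binary problems produce.
pred_decision = check_array(pred_decision, ensure_2d=False)
print(pred_decision.ndim)  # 2 -- .ndim is now always available
```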

Member

Also, please write a test for this. You can replace one of the pred_decision numpy arrays in your tests with a list.

Contributor Author

Please correct me if I am wrong: so your concern is calling .shape on pred_decision when it's not a numpy array?

Member

Yes.


@MechCoder
Member

That's it from my side.

@SaurabhJha
Contributor Author

I wonder if I should squash the upcoming commit into the previous commit as well. I think I should.

@MechCoder
Member

you could just do

git commit -a --amend
git push -f origin "name_of_branch"

@MechCoder
Member

Okay, thanks. Any other final reviews, @jnothman @arjoly? I shall merge in 72 hours, i.e. Monday around this time, if nobody complains.

hinge_loss, y_true, pred_decision)


def test_hinge_loss_multiclass_reptition_of_labels_with_missing_labels():
Member

reptition -> repetition

Member

Actually, don't mention repetition. Repetition should be what's normal in multiclass evaluation, and it is included in all tests.

@jnothman
Member

In case @MechCoder threatens to merge again, I'm happy with this PR (except for the test name issue), but we should have better documentation of how it fits into the invariance tests, and we should probably explicitly exclude it from the current _WITH_LABELS test, which is inapplicable.

@MechCoder
Member

In case @MechCoder threatens to merge again,

That was a really "unethical" way of getting your attention, but hey it worked ;)

@SaurabhJha SaurabhJha force-pushed the 3451 branch 2 times, most recently from 76fae17 to eaaf80f Compare October 31, 2014 16:50
# Currently, invariance of string and integer labels cannot be tested
# in common invariance tests because invariance tests for multiclass
# decision functions are not implemented yet.
y_true = ['blue', 'green', 'red',
Contributor Author

@MechCoder Please check whether this message is what was intended.

@MechCoder
Member

@jnothman I cherry-picked Saurabh's commit, made a few minor cosmetic changes, and I want to push it directly from the command line. Sorry, but how do I do it? I do not want to screw anything up.

@MechCoder
Member

I did this.

git checkout -b multiclass_hinge
git cherry-pick saurabh 3451
# made changes
git commit -a --amend

Do I just do git rebase -i upstream/master?

@jnothman
Member

jnothman commented Nov 3, 2014

Okay, my git squashing/merging workflow in brief. In .git/config I have:

[remote "upstream"]
    url = git@github.com:scikit-learn/scikit-learn.git
    fetch = +refs/heads/master:refs/remotes/upstream/master
    fetch = +refs/pull/*/head:refs/remotes/upstream/pr/*

Then I do something like:

$ git checkout master
$ git pull upstream master
$ git fetch upstream  # fetches all PRs
$ git checkout upstream/pr/3607
$ git squash master  # uses git rebase -i; from https://github.com/jnothman/git-squash
$ git checkout -  # back to master
$ git merge -  # cherry-picks in squashed commit
$ git push upstream master

(And in practice this looks like

$ gchm
$ glum
$ gfu
$ gch upstream/pr/3607
$ git squash master
$ gch -
$ gm -
$ gpum

with some bits obviously still needing an alias)

@MechCoder
Member

I think I made a mistake: I removed your latest commit from master (2f275de). I think I somehow typed upstream instead of origin somewhere while typing. Extremely sorry :/

@jnothman
Member

jnothman commented Nov 3, 2014

Surely you can't have done that without a force push. Never ever force push to scikit-learn/master...


@jnothman
Member

jnothman commented Nov 3, 2014

Anyway, I've restored the head to what it was recently.

@MechCoder
Member

Sorry. I've learnt from my mistakes. Now trying out this comment (#3607 (comment))

@MechCoder
Member

@SaurabhJha There were some minor typos in the docs, which I committed as 8eee4bc after removing the metric from the metrics_with_labels tests (and adding a TODO note), since its argument is pred_decision rather than y_true as in the other metrics. Thanks for your contribution!

@MechCoder MechCoder closed this Nov 3, 2014
@MechCoder
Member

@jnothman Thanks for your tips! I think I did it correctly. On a side note, this (https://twitter.com/heathercmiller/status/526770571728531456) ;)

@SaurabhJha
Contributor Author

Thank you @jnothman @MechCoder for your review. Good to see my first contribution :-)

Successfully merging this pull request may close these issues.

Add multiclass support to hinge_loss