[MRG] FEA multilabel confusion matrix #11179


Merged

Conversation

ShangwuYao
Contributor

@ShangwuYao ShangwuYao commented May 31, 2018

Reference Issues/PRs

Start adding metrics for #5516, continue on and close #10628

Fixes #3452

What does this implement/fix? Explain your changes.

  • implement multilabel_confusion_matrix and fix up edge cases that fail tests
  • benchmark multiclass implementation against incumbent P/R/F/S
  • benchmark multilabel implementation with benchmarks/bench_multilabel_metrics.py extended to consider non-micro averaging, sample_weight and perhaps other cases
  • optimize speed based on line-profiling
  • directly test multilabel_confusion_matrix
  • document under model_evaluation.rst
  • document how to calculate fall-out, miss-rate, sensitivity, specificity from multilabel_confusion_matrix (a sketch follows this list)
  • refactor jaccard similarity implementation once [MRG] average parameter for jaccard_similarity_score #10083 is merged
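
A sketch for the fall-out/miss-rate/sensitivity/specificity documentation item above (this assumes the (n_classes, 2, 2) output layout where each 2x2 block is [[tn, fp], [fn, tp]]; the example data and variable names are illustrative):

import numpy as np
from sklearn.metrics import multilabel_confusion_matrix

y_true = np.array([[1, 0, 1], [0, 1, 0], [1, 1, 0]])
y_pred = np.array([[1, 0, 0], [0, 1, 1], [0, 0, 1]])

# One 2x2 confusion matrix per label: [[tn, fp], [fn, tp]]
mcm = multilabel_confusion_matrix(y_true, y_pred)
tn, fp = mcm[:, 0, 0], mcm[:, 0, 1]
fn, tp = mcm[:, 1, 0], mcm[:, 1, 1]

sensitivity = tp / (tp + fn)  # recall / true positive rate
specificity = tn / (tn + fp)  # true negative rate
fall_out = fp / (fp + tn)     # false positive rate
miss_rate = fn / (fn + tp)    # false negative rate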

@ShangwuYao ShangwuYao changed the title from "Multilabel confusion append" to "Continue on multilabel confusion matrix" on May 31, 2018
@jnothman
Member

jnothman commented Jun 2, 2018

Nb: In #5516 I preferred adding scorers, but not metric functions, for things like fallout.

What are benchmark results looking like?

@jnothman
Member

jnothman commented Jun 2, 2018

Also, please prefix your PR title with [WIP] until tests and features are complete.

@ShangwuYao ShangwuYao changed the title from "Continue on multilabel confusion matrix" to "[WIP] Continue on multilabel confusion matrix" on Jun 2, 2018
@sklearn-lgtm

This pull request introduces 1 alert when merging cdd619d into a31a906 - view on LGTM.com

new alerts:

  • 1 for Wrong name for an argument in a call

Comment posted by LGTM.com

@sklearn-lgtm

This pull request introduces 1 alert when merging ec82be3 into a31a906 - view on LGTM.com

new alerts:

  • 1 for Wrong name for an argument in a call

Comment posted by LGTM.com

@ShangwuYao
Contributor Author

ShangwuYao commented Jun 9, 2018

In the multiclass case, the benchmark of precision_recall_fscore_support_with_multilabel_confusion_matrix is much slower than the original implementation (10 times slower in some cases) because of its use of confusion_matrix, as shown in #10628.

I have optimized the speed for the multilabel-indicator case (replacing bincount with point-wise multiplication and removing an unnecessary, expensive call to _check_targets). For that case, precision_recall_fscore_support_with_multilabel_confusion_matrix performs very close to the original implementation (slightly faster).

precision_recall_fscore_support_with_multilabel_confusion_matrix exists only for debugging and optimization purposes and will be removed; this implementation passes the tests for precision_recall_fscore_support.
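
For reference, the multilabel-indicator fast path described above can be sketched roughly as follows (a hedged sketch, not the exact code in this PR; the helper name and sample_weight handling are illustrative):

import numpy as np

def per_label_counts(y_true, y_pred, sample_weight=None):
    # y_true, y_pred: binary indicator arrays of shape (n_samples, n_labels).
    if sample_weight is None:
        sample_weight = np.ones(y_true.shape[0])
    w = np.asarray(sample_weight, dtype=float)[:, np.newaxis]
    tp = (y_true * y_pred * w).sum(axis=0)        # predicted 1, truly 1
    fp = ((1 - y_true) * y_pred * w).sum(axis=0)  # predicted 1, truly 0
    fn = (y_true * (1 - y_pred) * w).sum(axis=0)  # predicted 0, truly 1
    tn = w.sum() - tp - fp - fn                   # everything else
    return tp, fp, fn, tn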

@jnothman
Member

jnothman commented Jun 9, 2018 via email

@ShangwuYao
Contributor Author

OK, I will ping you when I've finished.
I hope to contribute to something critical; if you find anything suitable, I would love to give it a try.

@jnothman
Member

jnothman commented Jun 10, 2018 via email

@ShangwuYao
Contributor Author

That issue looks interesting; I am looking into it.
I think the work on multilabel_confusion_matrix is finished. Could you review it when you have the time, @jnothman? Thanks!

Member

@jnothman jnothman left a comment

Unless I'm mistaken, you're no longer using this in precision_recall_fscore_support due to poor benchmarks in the multilabel case. I consider its use in that function to be a key goal.

Is there a faster way to count true positives, false positives and false negatives for each class without using confusion_matrix?

Also, could you please change your benchmark plots to show comparable curves with the same colour but different markers on them depending on whether they are using mlcm or not?
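
For concreteness, the requested plot style could look roughly like this (a sketch with made-up timing data; only the shared-colour / distinct-marker convention matters):

import numpy as np
import matplotlib.pyplot as plt

sizes = np.array([1000, 5000, 10000])  # hypothetical benchmark sizes
incumbent = {"micro": [0.01, 0.05, 0.10], "macro": [0.02, 0.08, 0.15]}
with_mlcm = {"micro": [0.011, 0.052, 0.11], "macro": [0.021, 0.079, 0.16]}

for colour, avg in zip(["C0", "C1"], ["micro", "macro"]):
    # Same colour for the same averaging, different markers per implementation.
    plt.plot(sizes, incumbent[avg], color=colour, marker="o", label=avg + " (incumbent)")
    plt.plot(sizes, with_mlcm[avg], color=colour, marker="x", label=avg + " (mlcm)")
plt.xlabel("n_samples")
plt.ylabel("time (s)")
plt.legend()
plt.show()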

Thanks a lot!

Btw, that issue I pointed you to is a long-term wish, not critical for the next release.

labels=None, samplewise=False):
"""Returns a confusion matrix for each output of a multilabel problem

Multiclass tasks will be treated as if binarised under a one-vs-rest
Member

Perhaps say (i.e. where y is 1d)

raise ValueError("Samplewise confusion is not useful outside of "
"multilabel classification.")
present_labels = unique_labels(y_true, y_pred)
C = confusion_matrix(y_true, y_pred, sample_weight=sample_weight,
Member

Yes, we should try to do this with LabelEncoder and bincount, or with confusion_matrix without its validation.
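
A rough sketch of that LabelEncoder + bincount counting for 1d multiclass input (the helper name and signature are illustrative; this is close to how the incumbent precision_recall_fscore_support counts per-class statistics):

import numpy as np
from sklearn.preprocessing import LabelEncoder

def per_class_counts(y_true, y_pred, sample_weight=None):
    # Encode labels to 0..n_classes-1 so bincount can be used directly.
    le = LabelEncoder().fit(np.concatenate([y_true, y_pred]))
    t, p = le.transform(y_true), le.transform(y_pred)
    n = len(le.classes_)
    sw = None if sample_weight is None else np.asarray(sample_weight, dtype=float)

    correct = t == p
    tp = np.bincount(t[correct], minlength=n,
                     weights=None if sw is None else sw[correct])
    true_sum = np.bincount(t, weights=sw, minlength=n)
    pred_sum = np.bincount(p, weights=sw, minlength=n)
    fn = true_sum - tp  # truly this class, predicted as another
    fp = pred_sum - tp  # predicted as this class, truly another
    return tp, fp, fn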

Contributor Author

@ShangwuYao ShangwuYao Jun 15, 2018

I tried using LabelEncoder and bincount as you suggested, but it performs pretty much the same as the original implementation, so there is no speedup in doing that.

Interestingly, though, in the multiclass and binary cases of multilabel_confusion_matrix, replacing the call to confusion_matrix with the LabelEncoder and bincount implementation makes multilabel_confusion_matrix faster than confusion_matrix.

In [7]: y_true = np.random.randint(0, 2, (300,))

In [8]: y_pred = np.random.randint(0, 2, (300,))

In [9]: %timeit multilabel_confusion_matrix(y_true, y_pred)
308 µs ± 3.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [10]: %timeit confusion_matrix(y_true, y_pred)
488 µs ± 5.72 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

So there should be room for improvement here; I will let you know when I dig deeper.

@eamanu eamanu mentioned this pull request Oct 6, 2018
@jnothman
Member

jnothman commented Oct 7, 2018

Do you have some references for users?

I invented the name. The function is mostly there to help us implement, and to facilitate the implementation of, arbitrary metrics based on set-wise/binary metrics: Precision, Recall, F1, Fβ, Jaccard, Specificity, etc. The point is that it calculates sufficient statistics for these metrics, just as confusion_matrix calculates sufficient statistics for multiclass metrics (ignoring sample_weight perhaps), and contingency_matrix calculates sufficient statistics for clustering metrics.

Because the need for it is mostly internal, and it would mostly be used through the metric aggregates, it is unlikely to be referenced in the literature as such.

It makes sense, however, that utiml similarly implements something like this, because from it they can derive all the metrics they need for multilabel evaluation. Their "absolute matrix" is not identical to our confusion_matrix. It is merely the sum over axis 0 of the output of multilabel_confusion_matrix.
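
For illustration, that relationship is just (made-up indicator data):

import numpy as np
from sklearn.metrics import multilabel_confusion_matrix

y_true = np.array([[1, 0, 1], [0, 1, 0]])
y_pred = np.array([[1, 0, 0], [0, 1, 1]])

mcm = multilabel_confusion_matrix(y_true, y_pred)  # shape (n_labels, 2, 2)
pooled = mcm.sum(axis=0)  # counts pooled over labels, like utiml's "absolute matrix"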

The binary case (at least with 1d input) needs to be handled with a 2x2x2 matrix, in order to support those metrics above. One clarifying alternative would be to make a separate function for multiclass input (which needs to be binarised) from that for multilabel input.

I'd be okay calling this binarized_confusion_matrix, in correspondence with LabelBinarizer (but without any affinity in meaning to Binarizer!).

@qinhanmin2014
Member

(1) Regarding the name: I'm fine with both multilabel_confusion_matrix and binarized_confusion_matrix, so I'll follow your final decision. Personally, I prefer binarized_confusion_matrix, because if I put binarized_confusion_matrix and confusion_matrix together, I can figure out some of the relationships and differences between them.
(2) Regarding the R package, apologies, I didn't read the doc carefully. I don't think we need to care about their Absolute Matrix and Proportional Matrix.
(3) Regarding the binary case, I'm fine with the current implementation.

So the remaining things, @jnothman:
(1) your final decision about the name
(2) the 4 minor review comments above

@qinhanmin2014
Member

And @jnothman another annoying thing :) Do you think it's acceptable?

multilabel_confusion_matrix([0, 0], [0, 0])
array([[[0, 0],
        [0, 2]]], dtype=int64)

@jnothman
Member

Do you think it's acceptable?

I think that is consistent, and don't have a problem with it.

>>> multilabel_confusion_matrix([0, 0], [0, 0])
array([[[0, 0],
        [0, 2]]])
>>> multilabel_confusion_matrix([0, 0], [0, 0], labels=[0, 1])
array([[[0, 0],
        [0, 2]],

       [[2, 0],
        [0, 0]]])

@jnothman
Member

I'm happy to rename to binarized_confusion_matrix, as long as you reckon that makes sense for the case that the input is already multiple labels.

@qinhanmin2014
Member

I think that is consistent, and don't have a problem with it.

I see, yes it's reasonable but a bit tricky.

I'm happy to rename to binarized_confusion_matrix

I think both names are fine and will follow your decision. (binarized_confusion_matrix seems more straightforward but multilabel_confusion_matrix seems to be consistent with R utiml).

I still want to know your opinion about the 4 review comments above (#11179 (review)), especially the second and the fourth ones :)

@jnothman
Member

@TomDLT what do you think of the name binarized_confusion_matrix vs multilabel_confusion_matrix vs other??

@jnothman
Member

Any further opinions on the name binarized_confusion_matrix vs multilabel_confusion_matrix vs other for a function which returns a 2x2 confusion matrix for each class in multilabel or multiclass data?

@qinhanmin2014
Member

FYI, a test is failing (apologies, I don't have time to investigate now):

=================================== FAILURES ===================================
___________________ test_multilabel_confusion_matrix_errors ____________________
    def test_multilabel_confusion_matrix_errors():
        y_true = np.array([[1, 0, 1], [0, 1, 0], [1, 1, 0]])
        y_pred = np.array([[1, 0, 0], [0, 1, 1], [0, 0, 1]])
    
        # Bad sample_weight
        assert_raise_message(ValueError, "inconsistent numbers of samples",
                             multilabel_confusion_matrix,
                             y_true, y_pred, sample_weight=[1, 2])
        assert_raise_message(ValueError, "could not be broadcast",
                             multilabel_confusion_matrix,
                             y_true, y_pred,
                             sample_weight=[[1, 2, 3],
                                            [2, 3, 4],
>                                           [3, 4, 5]])
y_pred     = array([[1, 0, 0],
       [0, 1, 1],
       [0, 0, 1]])
y_true     = array([[1, 0, 1],
       [0, 1, 0],
       [1, 1, 0]])
/home/travis/build/scikit-learn/scikit-learn/sklearn/metrics/tests/test_classification.py:484: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
exceptions = <type 'exceptions.ValueError'>, message = 'could not be broadcast'
function = <function multilabel_confusion_matrix at 0x7f3042ff0c80>
args = (array([[1, 0, 1],
       [0, 1, 0],
       [1, 1, 0]]), array([[1, 0, 0],
       [0, 1, 1],
       [0, 0, 1]]))
kwargs = {'sample_weight': [[1, 2, 3], [2, 3, 4], [3, 4, 5]]}
e = ValueError('a.shape[axis] != len(repeats)',)
error_message = 'a.shape[axis] != len(repeats)'
    def assert_raise_message(exceptions, message, function, *args, **kwargs):
        """Helper function to test the message raised in an exception.
    
        Given an exception, a callable to raise the exception, and
        a message string, tests that the correct exception is raised and
        that the message is a substring of the error thrown. Used to test
        that the specific message thrown during an exception is correct.
    
        Parameters
        ----------
        exceptions : exception or tuple of exception
            An Exception object.
    
        message : str
            The error message or a substring of the error message.
    
        function : callable
            Callable object to raise error.
    
        *args : the positional arguments to `function`.
    
        **kwargs : the keyword arguments to `function`.
        """
        try:
            function(*args, **kwargs)
        except exceptions as e:
            error_message = str(e)
            if message not in error_message:
                raise AssertionError("Error message does not include the expected"
                                     " string: %r. Observed error message: %r" %
>                                    (message, error_message))
E               AssertionError: Error message does not include the expected string: 'could not be broadcast'. Observed error message: 'a.shape[axis] != len(repeats)'
args       = (array([[1, 0, 1],
       [0, 1, 0],
       [1, 1, 0]]), array([[1, 0, 0],
       [0, 1, 1],
       [0, 0, 1]]))
e          = ValueError('a.shape[axis] != len(repeats)',)
error_message = 'a.shape[axis] != len(repeats)'
exceptions = <type 'exceptions.ValueError'>
function   = <function multilabel_confusion_matrix at 0x7f3042ff0c80>
kwargs     = {'sample_weight': [[1, 2, 3], [2, 3, 4], [3, 4, 5]]}
message    = 'could not be broadcast'

Member

@qinhanmin2014 qinhanmin2014 left a comment

@jnothman Do we need to wait for @TomDLT 's opinion about the name? If not, I guess we can merge.

@TomDLT
Member

TomDLT commented Oct 30, 2018

  • multilabel_confusion_matrix makes it clear on which problem you can use the function.
  • binarized_confusion_matrix makes it clear what is actually computed.

Both names are fine. I would be slightly in favor of multilabel_confusion_matrix, since this may help users find the function. What is actually computed can be found in the docstring: "Multiclass data will be treated as if binarized under a one-vs-rest transformation."

@qinhanmin2014
Member

I think we can merge. Thanks all for the great work!

@qinhanmin2014 qinhanmin2014 merged commit 6555631 into scikit-learn:master Oct 30, 2018
@jnothman
Member

And thanks @ShangwuYao for making this happen!
