diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst
index c5e01ad56c32d..a379ac17b3e5f 100644
--- a/doc/modules/model_evaluation.rst
+++ b/doc/modules/model_evaluation.rst
@@ -70,6 +70,7 @@ Scoring                           Function
 'neg_log_loss'                  :func:`metrics.log_loss`                          requires ``predict_proba`` support
 'precision' etc.                :func:`metrics.precision_score`                   suffixes apply as with 'f1'
 'recall' etc.                   :func:`metrics.recall_score`                      suffixes apply as with 'f1'
+'jaccard' etc.                  :func:`metrics.jaccard_similarity_score`          suffixes apply as with 'f1'
 'roc_auc'                       :func:`metrics.roc_auc_score`
 
 **Clustering**
@@ -698,24 +699,31 @@ with a ground truth label set :math:`y_i` and predicted label set
 
     J(y_i, \hat{y}_i) = \frac{|y_i \cap \hat{y}_i|}{|y_i \cup \hat{y}_i|}.
 
-In binary and multiclass classification, the Jaccard similarity coefficient
-score is equal to the classification accuracy.
+:func:`jaccard_similarity_score` works like :func:`precision_recall_fscore_support`:
+it is a naively set-wise measure that applies natively only to binary and
+multilabel targets.
 
-::
+In the multilabel case with binary label indicators: ::
 
     >>> import numpy as np
     >>> from sklearn.metrics import jaccard_similarity_score
-    >>> y_pred = [0, 2, 1, 3]
-    >>> y_true = [0, 1, 2, 3]
+    >>> y_true = np.array([[0, 1], [1, 1]])
+    >>> y_pred = np.ones((2, 2))
     >>> jaccard_similarity_score(y_true, y_pred)
-    0.5
+    0.75
     >>> jaccard_similarity_score(y_true, y_pred, normalize=False)
-    2
+    1.5
 
-In the multilabel case with binary label indicators: ::
+Multiclass problems are binarized and treated like the corresponding
+multilabel problem: ::
 
-    >>> jaccard_similarity_score(np.array([[0, 1], [1, 1]]), np.ones((2, 2)))
-    0.75
+    >>> y_pred = [0, 2, 1, 3]
+    >>> y_true = [0, 1, 2, 3]
+    >>> jaccard_similarity_score(y_true, y_pred, average='macro')
+    0.5
+    >>> jaccard_similarity_score(y_true, y_pred, average='micro')
+    0.33...
+    >>> jaccard_similarity_score(y_true, y_pred, average=None)
+    array([1., 0., 0., 1.])
 
 .. _precision_recall_f_measure_metrics:
 
diff --git a/doc/whats_new/v0.21.rst b/doc/whats_new/v0.21.rst
index fdd0230fc840b..3e91b2bfa1609 100644
--- a/doc/whats_new/v0.21.rst
+++ b/doc/whats_new/v0.21.rst
@@ -160,6 +160,13 @@ Support for Python 3.4 and below has been officially dropped.
   metrics such as recall, specificity, fall out and miss rate.
   :issue:`11179` by :user:`Shangwu Yao ` and `Joel Nothman`_.
 
+- |Feature| |Fix| :func:`metrics.jaccard_similarity_score` now accepts an
+  ``average`` argument like :func:`metrics.precision_recall_fscore_support`,
+  as a naively set-wise measure applying only to binary and multilabel
+  targets. Multiclass input is now binarized and treated like the
+  corresponding multilabel problem.
+  :issue:`10083` by :user:`Gaurav Dhingra ` and `Joel Nothman`_.
+
 - |Enhancement| Use label `accuracy` instead of `micro-average` on
   :func:`metrics.classification_report` to avoid confusion. `micro-average` is
   only shown for multi-label or multi-class with a subset of classes because
@@ -167,15 +174,15 @@ Support for Python 3.4 and below has been officially dropped.
   :issue:`12334` by :user:`Emmanuel Arias `,
   `Joel Nothman`_ and `Andreas Müller`_
 
+- |Fix| The metric :func:`metrics.r2_score` is degenerate with a single sample
+  and now it returns NaN and raises :class:`exceptions.UndefinedMetricWarning`.
+  :issue:`12855` by :user:`Pawel Sendyk `.
+
 - |API| The parameter ``labels`` in :func:`metrics.hamming_loss` is deprecated
   in version 0.21 and will be removed in version 0.23.
   :issue:`10580` by :user:`Reshama Shaikh ` and
   `Sandra Mitrovic `.
-- |Fix| The metric :func:`metrics.r2_score` is degenerate with a single sample
-  and now it returns NaN and raises :class:`exceptions.UndefinedMetricWarning`.
-  :issue:`12855` by :user:`Pawel Sendyk `.
-
 - |Efficiency| The pairwise manhattan distances with sparse input now uses the
   BLAS shipped with scipy instead of the bundled BLAS. :issue:`12732` by
   :user:`Jérémie du Boisberranger `
 
diff --git a/sklearn/metrics/classification.py b/sklearn/metrics/classification.py
index 7d1653c426c68..5ae254b6028f7 100644
--- a/sklearn/metrics/classification.py
+++ b/sklearn/metrics/classification.py
@@ -577,7 +577,8 @@ class labels [2]_.
     return 1 - k
 
 
-def jaccard_similarity_score(y_true, y_pred, normalize=True,
+def jaccard_similarity_score(y_true, y_pred, labels=None, pos_label=1,
+                             average='samples', normalize='true-if-samples',
                              sample_weight=None):
     """Jaccard similarity coefficient score
 
@@ -596,23 +597,57 @@ def jaccard_similarity_score(y_true, y_pred, normalize=True,
     y_pred : 1d array-like, or label indicator array / sparse matrix
         Predicted labels, as returned by a classifier.
 
+    labels : list, optional
+        The set of labels to include when ``average != 'binary'``, and their
+        order if ``average is None``. Labels present in the data can be
+        excluded, for example to calculate a multiclass average ignoring a
+        majority negative class, while labels not present in the data will
+        result in 0 components in a macro average. For multilabel targets,
+        labels are column indices. By default, all labels in ``y_true`` and
+        ``y_pred`` are used in sorted order.
+
+    pos_label : str or int, 1 by default
+        The class to report if ``average='binary'`` and the data is binary.
+        If the data are multiclass or multilabel, this will be ignored;
+        setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
+        scores for that label only.
+
+    average : string, ['samples' (default), 'binary', 'micro', 'macro', None, \
+            'weighted']
+        If ``None``, the scores for each class are returned. Otherwise, this
+        determines the type of averaging performed on the data:
+
+        ``'binary'``:
+            Only report results for the class specified by ``pos_label``.
+            This is applicable only if targets (``y_{true,pred}``) are binary.
+        ``'micro'``:
+            Calculate metrics globally by counting the total true positives,
+            false negatives and false positives.
+        ``'macro'``:
+            Calculate metrics for each label, and find their unweighted
+            mean. This does not take label imbalance into account.
+        ``'weighted'``:
+            Calculate metrics for each label, and find their average, weighted
+            by support (the number of true instances for each label). This
+            alters 'macro' to account for label imbalance.
+        ``'samples'``:
+            Calculate metrics for each instance, and find their average (only
+            meaningful for multilabel classification).
+
     normalize : bool, optional (default=True)
         If ``False``, return the sum of the Jaccard similarity coefficient
         over the sample set. Otherwise, return the average of Jaccard
-        similarity coefficient.
+        similarity coefficient. ``normalize`` is only applicable when
+        ``average='samples'``. The default value 'true-if-samples' behaves like
+        True, but does not raise an error with other values of ``average``.
 
     sample_weight : array-like of shape = [n_samples], optional
         Sample weights.
 
     Returns
     -------
-    score : float
-        If ``normalize == True``, return the average Jaccard similarity
-        coefficient, else it returns the sum of the Jaccard similarity
-        coefficient over the sample set.
-
-        The best performance is 1 with ``normalize == True`` and the number
-        of samples with ``normalize == False``.
+    score : float (if average is not None) or array of floats, shape =\
+        [n_unique_labels]
 
     See also
     --------
@@ -620,48 +655,78 @@ def jaccard_similarity_score(y_true, y_pred, normalize=True,
 
     Notes
     -----
-    In binary and multiclass classification, this function is equivalent
-    to the ``accuracy_score``. It differs in the multilabel classification
-    problem.
+    :func:`jaccard_similarity_score` may be a poor metric if there are no
+    positives for some samples or classes.
 
     References
     ----------
     .. [1] `Wikipedia entry for the Jaccard index
            <https://en.wikipedia.org/wiki/Jaccard_index>`_
 
-
     Examples
     --------
+    >>> import numpy as np
     >>> from sklearn.metrics import jaccard_similarity_score
-    >>> y_pred = [0, 2, 1, 3]
-    >>> y_true = [0, 1, 2, 3]
-    >>> jaccard_similarity_score(y_true, y_pred)
+
+    In the multilabel case:
+
+    >>> y_true = np.array([[1, 0, 1], [0, 0, 1], [1, 1, 1]])
+    >>> y_pred = np.array([[0, 1, 1], [1, 1, 1], [0, 0, 1]])
+    >>> jaccard_similarity_score(y_true, y_pred, average='samples')
+    ... # doctest: +ELLIPSIS
+    0.33...
+    >>> jaccard_similarity_score(y_true, y_pred, average='micro')
+    ... # doctest: +ELLIPSIS
+    0.33...
+    >>> jaccard_similarity_score(y_true, y_pred, average='weighted')
     0.5
-    >>> jaccard_similarity_score(y_true, y_pred, normalize=False)
-    2
+    >>> jaccard_similarity_score(y_true, y_pred, average=None)
+    array([0., 0., 1.])
 
-    In the multilabel case with binary label indicators:
+    In the multiclass case:
 
-    >>> import numpy as np
-    >>> jaccard_similarity_score(np.array([[0, 1], [1, 1]]),\
-        np.ones((2, 2)))
-    0.75
+    >>> jaccard_similarity_score(np.array([0, 1, 2, 3]),
+    ...                          np.array([0, 2, 2, 3]), average='macro')
+    0.625
     """
+    if average != 'samples' and normalize != 'true-if-samples':
+        raise ValueError("'normalize' is only meaningful with "
+                         "`average='samples'`, got `average='%s'`."
+                         % average)
+    labels = _check_set_wise_labels(y_true, y_pred, average, labels,
+                                    pos_label)
+    if labels is _ALL_ZERO:
+        warnings.warn('Jaccard is ill-defined and being set to 0.0 with no '
+                      'true or predicted samples', UndefinedMetricWarning)
+        return 0.
+    samplewise = average == 'samples'
+    MCM = multilabel_confusion_matrix(y_true, y_pred,
+                                      sample_weight=sample_weight,
+                                      labels=labels, samplewise=samplewise)
+    numerator = MCM[:, 1, 1]
+    denominator = MCM[:, 1, 1] + MCM[:, 0, 1] + MCM[:, 1, 0]
 
-    # Compute accuracy for each possible representation
-    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
-    check_consistent_length(y_true, y_pred, sample_weight)
-    if y_type.startswith('multilabel'):
-        with np.errstate(divide='ignore', invalid='ignore'):
-            # oddly, we may get an "invalid" rather than a "divide" error here
-            pred_or_true = count_nonzero(y_true + y_pred, axis=1)
-            pred_and_true = count_nonzero(y_true.multiply(y_pred), axis=1)
-            score = pred_and_true / pred_or_true
-            score[pred_or_true == 0.0] = 1.0
+    if average == 'micro':
+        numerator = np.array([numerator.sum()])
+        denominator = np.array([denominator.sum()])
+
+    jaccard = _prf_divide(numerator, denominator, 'jaccard',
+                          'true or predicted', average, ('jaccard',))
+    if average is None:
+        return jaccard
+    if not normalize:
+        return np.sum(jaccard * (1 if sample_weight is None
+                                 else sample_weight))
+    if average == 'weighted':
+        weights = MCM[:, 1, 0] + MCM[:, 1, 1]
+        if not np.any(weights):
+            # numerator is 0, and warning should have already been issued
+            weights = None
+    elif average == 'samples' and sample_weight is not None:
+        weights = sample_weight
     else:
-        score = y_true == y_pred
-
-    return _weighted_sum(score, sample_weight, normalize)
+        weights = None
+    return np.average(jaccard, weights=weights)
 
 
 def matthews_corrcoef(y_true, y_pred, sample_weight=None):
@@ -1056,8 +1121,10 @@ def _prf_divide(numerator, denominator, metric, modifier, average, warn_for):
     The metric, modifier and average arguments are used only for determining
     an appropriate warning.
     """
-    result = numerator / denominator
     mask = denominator == 0.0
+    denominator = denominator.copy()
+    denominator[mask] = 1
+    result = numerator / denominator
 
     if not np.any(mask):
         return result
@@ -1091,6 +1158,41 @@ def _prf_divide(numerator, denominator, metric, modifier, average, warn_for):
     return result
 
 
+_ALL_ZERO = object()  # sentinel for special, degenerate case
+
+
+def _check_set_wise_labels(y_true, y_pred, average, labels, pos_label):
+    """Validation associated with set-wise metrics
+
+    Returns identified labels or _ALL_ZERO sentinel
+    """
+    average_options = (None, 'micro', 'macro', 'weighted', 'samples')
+    if average not in average_options and average != 'binary':
+        raise ValueError('average has to be one of ' +
+                         str(average_options))
+
+    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
+    present_labels = unique_labels(y_true, y_pred)
+    if average == 'binary':
+        if y_type == 'binary':
+            if pos_label not in present_labels:
+                if len(present_labels) < 2:
+                    return _ALL_ZERO
+                else:
+                    raise ValueError("pos_label=%r is not a valid label: "
+                                     "%r" % (pos_label, present_labels))
+            labels = [pos_label]
+        else:
+            raise ValueError("Target is %s but average='binary'. Please "
+                             "choose another average setting." % y_type)
+    elif pos_label not in (None, 1):
+        warnings.warn("Note that pos_label (set to %r) is ignored when "
+                      "average != 'binary' (got %r). You may use "
+                      "labels=[pos_label] to specify a single positive class."
+                      % (pos_label, average), UserWarning)
+    return labels
+
+
 def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
                                     pos_label=1, average=None,
                                     warn_for=('precision', 'recall',
@@ -1234,35 +1336,12 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
         array([2, 2, 2]))
 
     """
-    average_options = (None, 'micro', 'macro', 'weighted', 'samples')
-    if average not in average_options and average != 'binary':
-        raise ValueError('average has to be one of ' +
-                         str(average_options))
     if beta <= 0:
         raise ValueError("beta should be >0 in the F-beta score")
-
-    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
-    check_consistent_length(y_true, y_pred, sample_weight)
-    present_labels = unique_labels(y_true, y_pred)
-
-    if average == 'binary':
-        if y_type == 'binary':
-            if pos_label not in present_labels:
-                if len(present_labels) < 2:
-                    # Only negative labels
-                    return (0., 0., 0., 0)
-                else:
-                    raise ValueError("pos_label=%r is not a valid label: %r" %
-                                     (pos_label, present_labels))
-            labels = [pos_label]
-        else:
-            raise ValueError("Target is %s but average='binary'. Please "
-                             "choose another average setting." % y_type)
-    elif pos_label not in (None, 1):
-        warnings.warn("Note that pos_label (set to %r) is ignored when "
-                      "average != 'binary' (got %r). You may use "
-                      "labels=[pos_label] to specify a single positive class."
-                      % (pos_label, average), UserWarning)
+    labels = _check_set_wise_labels(y_true, y_pred, average, labels,
+                                    pos_label)
+    if labels is _ALL_ZERO:
+        return (0., 0., 0., 0)
 
     # Calculate tp_sum, pred_sum, true_sum ###
     samplewise = average == 'samples'
diff --git a/sklearn/metrics/scorer.py b/sklearn/metrics/scorer.py
index f93736ed097a3..2f840ccb5a362 100644
--- a/sklearn/metrics/scorer.py
+++ b/sklearn/metrics/scorer.py
@@ -28,7 +28,7 @@
                      f1_score, roc_auc_score, average_precision_score,
                      precision_score, recall_score, log_loss,
                      balanced_accuracy_score, explained_variance_score,
-                     brier_score_loss)
+                     brier_score_loss, jaccard_similarity_score)
 
 from .cluster import adjusted_rand_score
 from .cluster import homogeneity_score
@@ -482,6 +482,7 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False,
 accuracy_scorer = make_scorer(accuracy_score)
 f1_scorer = make_scorer(f1_score)
 balanced_accuracy_scorer = make_scorer(balanced_accuracy_score)
+jaccard_similarity_scorer = make_scorer(jaccard_similarity_score)
 
 # Score functions that need decision values
 roc_auc_scorer = make_scorer(roc_auc_score, greater_is_better=True,
@@ -534,8 +535,9 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False,
 
 for name, metric in [('precision', precision_score),
-                     ('recall', recall_score), ('f1', f1_score)]:
-    SCORERS[name] = make_scorer(metric)
+                     ('recall', recall_score), ('f1', f1_score),
+                     ('jaccard', jaccard_similarity_score)]:
+    SCORERS[name] = make_scorer(metric, average='binary')
     for average in ['macro', 'micro', 'samples', 'weighted']:
         qualified_name = '{0}_{1}'.format(name, average)
         SCORERS[qualified_name] = make_scorer(metric, pos_label=None,
diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py
index a1ddc4654c462..c28236e8bf7f2 100644
--- a/sklearn/metrics/tests/test_classification.py
+++ b/sklearn/metrics/tests/test_classification.py
@@ -11,7 +11,7 @@
 from sklearn import svm
 
 from sklearn.datasets import make_multilabel_classification
-from sklearn.preprocessing import label_binarize
+from sklearn.preprocessing import label_binarize, LabelBinarizer
 from sklearn.utils.validation import check_random_state
 from sklearn.utils.testing import assert_raises, clean_warning_registry
 from sklearn.utils.testing import assert_raise_message
 
@@ -1136,7 +1136,40 @@ def test_multilabel_hamming_loss():
                          hamming_loss, y1, y2, labels=[0, 1])
 
 
-def test_multilabel_jaccard_similarity_score():
+def test_jaccard_similarity_score_validation():
+    y_true = np.array([0, 1, 0, 1, 1])
+    y_pred = np.array([0, 1, 0, 1, 1])
+    assert_raise_message(ValueError, "pos_label=2 is not a valid label: "
+                         "array([0, 1])", jaccard_similarity_score, y_true,
+                         y_pred, average='binary', pos_label=2)
+
+    y_true = np.array([[0, 1, 1], [1, 0, 0]])
+    y_pred = np.array([[1, 1, 1], [1, 0, 1]])
+    msg1 = ("Target is multilabel-indicator but average='binary'. "
+            "Please choose another average setting.")
+    assert_raise_message(ValueError, msg1, jaccard_similarity_score, y_true,
+                         y_pred, average='binary', pos_label=-1)
+
+    y_true = np.array([0, 1, 1, 0, 2])
+    y_pred = np.array([1, 1, 1, 1, 0])
+    msg2 = ("Target is multiclass but average='binary'. Please choose "
+            "another average setting.")
+    assert_raise_message(ValueError, msg2, jaccard_similarity_score, y_true,
+                         y_pred, average='binary')
+    msg3 = ("Samplewise metrics are not available outside of multilabel "
+            "classification.")
+    assert_raise_message(ValueError, msg3, jaccard_similarity_score, y_true,
+                         y_pred, average='samples')
+
+    assert_warns_message(UserWarning,
+                         "Note that pos_label (set to 3) is ignored when "
+                         "average != 'binary' (got 'micro'). You may use "
+                         "labels=[pos_label] to specify a single positive "
+                         "class.", jaccard_similarity_score, y_true, y_pred,
+                         average='micro', pos_label=3)
+
+
+def test_multilabel_jaccard_similarity_score(recwarn):
     # Dense label indicator matrix format
     y1 = np.array([[0, 1, 1], [1, 0, 1]])
     y2 = np.array([[0, 0, 1], [1, 0, 1]])
@@ -1144,13 +1177,135 @@
 
     # size(y1 \inter y2) = [1, 2]
     # size(y1 \union y2) = [2, 2]
 
-    assert_equal(jaccard_similarity_score(y1, y2), 0.75)
-    assert_equal(jaccard_similarity_score(y1, y1), 1)
-    assert_equal(jaccard_similarity_score(y2, y2), 1)
-    assert_equal(jaccard_similarity_score(y2, np.logical_not(y2)), 0)
-    assert_equal(jaccard_similarity_score(y1, np.logical_not(y1)), 0)
-    assert_equal(jaccard_similarity_score(y1, np.zeros(y1.shape)), 0)
-    assert_equal(jaccard_similarity_score(y2, np.zeros(y1.shape)), 0)
+    assert jaccard_similarity_score(y1, y2) == 0.75
+    assert jaccard_similarity_score(y1, y1) == 1
+    assert jaccard_similarity_score(y2, y2) == 1
+    assert jaccard_similarity_score(y2, np.logical_not(y2)) == 0
+    assert jaccard_similarity_score(y1, np.logical_not(y1)) == 0
+    assert jaccard_similarity_score(y1, np.zeros(y1.shape)) == 0
+    assert jaccard_similarity_score(y2, np.zeros(y1.shape)) == 0
+
+    y_true = np.array([[0, 1, 1], [1, 0, 0]])
+    y_pred = np.array([[1, 1, 1], [1, 0, 1]])
+    # average='macro'
+    assert_almost_equal(jaccard_similarity_score(y_true, y_pred,
+                                                 average='macro'), 2. / 3)
+    # average='micro'
+    assert_almost_equal(jaccard_similarity_score(y_true, y_pred,
+                                                 average='micro'), 3. / 5)
+    # average='samples' (default)
+    assert_almost_equal(jaccard_similarity_score(y_true, y_pred), 7. / 12)
+    assert_almost_equal(jaccard_similarity_score(y_true, y_pred,
+                                                 average='samples',
+                                                 labels=[0, 2]), 1. / 2)
+    assert_almost_equal(jaccard_similarity_score(y_true, y_pred,
+                                                 average='samples',
+                                                 labels=[1, 2]), 1. / 2)
+    # average='samples', normalize=False
+    assert_almost_equal(jaccard_similarity_score(y_true, y_pred,
+                                                 average='samples',
+                                                 normalize=False),
+                        7. / 6)
+    # average=None
+    assert_array_equal(jaccard_similarity_score(y_true, y_pred, average=None),
+                       np.array([1. / 2, 1., 1. / 2]))
+
+    y_true = np.array([[0, 1, 1], [1, 0, 1]])
+    y_pred = np.array([[1, 1, 1], [1, 0, 1]])
+    assert_almost_equal(jaccard_similarity_score(y_true, y_pred,
+                                                 average='macro'), 5. / 6)
+    # average='weighted'
+    assert_almost_equal(jaccard_similarity_score(y_true, y_pred,
+                                                 average='weighted'), 7. / 8)
+    # normalize error
+    msg1 = ("'normalize' is only meaningful with `average='samples'`, got "
+            "`average='macro'`.")
+    assert_raise_message(ValueError, msg1, jaccard_similarity_score, y_true,
+                         y_pred, average='macro', normalize=False)
+    assert_raise_message(ValueError, msg1, jaccard_similarity_score, y_true,
+                         y_pred, average='macro', normalize=True)
+    msg2 = 'Got 4 > 2'
+    assert_raise_message(ValueError, msg2, jaccard_similarity_score, y_true,
+                         y_pred, labels=[4])
+    msg3 = 'Got -1 < 0'
+    assert_raise_message(ValueError, msg3, jaccard_similarity_score, y_true,
+                         y_pred, labels=[-1])
+
+    msg = ('Jaccard is ill-defined and being set to 0.0 in labels '
+           'with no true or predicted samples.')
+    assert assert_warns_message(UndefinedMetricWarning, msg,
+                                jaccard_similarity_score,
+                                np.array([[0, 1]]),
+                                np.array([[0, 1]]),
+                                average='macro') == 0.5
+
+    msg = ('Jaccard is ill-defined and being set to 0.0 in samples '
+           'with no true or predicted labels.')
+    assert assert_warns_message(UndefinedMetricWarning, msg,
+                                jaccard_similarity_score,
+                                np.array([[0, 0], [1, 1]]),
+                                np.array([[0, 0], [1, 1]]),
+                                average='samples') == 0.5
+
+    assert not list(recwarn)
+
+
+def test_multiclass_jaccard_similarity_score(recwarn):
+    y_true = ['ant', 'ant', 'cat', 'cat', 'ant', 'cat', 'bird', 'bird']
+    y_pred = ['cat', 'ant', 'cat', 'cat', 'ant', 'bird', 'bird', 'cat']
+    labels = ['ant', 'bird', 'cat']
+    lb = LabelBinarizer()
+    lb.fit(labels)
+    y_true_bin = lb.transform(y_true)
+    y_pred_bin = lb.transform(y_pred)
+    multi_jaccard_similarity_score = partial(jaccard_similarity_score, y_true,
+                                             y_pred)
+    bin_jaccard_similarity_score = partial(jaccard_similarity_score,
+                                           y_true_bin, y_pred_bin)
+    multi_labels_list = [['ant', 'bird'], ['ant', 'cat'], ['cat', 'bird'],
+                         ['ant'], ['bird'], ['cat'], None]
+    bin_labels_list = [[0, 1], [0, 2], [2, 1], [0], [1], [2], None]
+
+    # other than average='samples'/'none-samples', test everything else here
+    for average in ('macro', 'weighted', 'micro', None):
+        for m_label, b_label in zip(multi_labels_list, bin_labels_list):
+            assert_almost_equal(
+                multi_jaccard_similarity_score(average=average,
+                                               labels=m_label),
+                bin_jaccard_similarity_score(average=average,
+                                             labels=b_label))
+
+    y_true = np.array([[0, 0], [0, 0], [0, 0]])
+    y_pred = np.array([[0, 0], [0, 0], [0, 0]])
+    with ignore_warnings():
+        assert (jaccard_similarity_score(y_true, y_pred, average='weighted')
+                == 0)
+
+    assert not list(recwarn)
+
+
+def test_average_binary_jaccard_similarity_score(recwarn):
+    # tp=0, fp=0, fn=1, tn=0
+    assert jaccard_similarity_score([1], [0], average='binary') == 0.
+    # tp=0, fp=0, fn=0, tn=1
+    msg = ('Jaccard is ill-defined and being set to 0.0 with '
+           'no true or predicted samples')
+    assert assert_warns_message(UndefinedMetricWarning,
+                                msg,
+                                jaccard_similarity_score,
+                                [0, 0], [0, 0],
+                                average='binary') == 0.
+    # tp=1, fp=0, fn=0, tn=0 (pos_label=0)
+    assert jaccard_similarity_score([0], [0], pos_label=0,
+                                    average='binary') == 1.
+    y_true = np.array([1, 0, 1, 1, 0])
+    y_pred = np.array([1, 0, 1, 1, 1])
+    assert_almost_equal(jaccard_similarity_score(y_true, y_pred,
+                                                 average='binary'), 3. / 4)
+    assert_almost_equal(jaccard_similarity_score(y_true, y_pred,
+                                                 average='binary',
+                                                 pos_label=0), 1. / 2)
+
+    assert not list(recwarn)
 
 
 @ignore_warnings
diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py
index 68e2c42600ebc..eb6b410d93084 100644
--- a/sklearn/metrics/tests/test_common.py
+++ b/sklearn/metrics/tests/test_common.py
@@ -142,26 +142,37 @@
     "weighted_f2_score": partial(fbeta_score, average="weighted", beta=2),
     "weighted_precision_score": partial(precision_score, average="weighted"),
     "weighted_recall_score": partial(recall_score, average="weighted"),
+    "weighted_jaccard_similarity_score":
+        partial(jaccard_similarity_score, average="weighted"),
 
     "micro_f0.5_score": partial(fbeta_score, average="micro", beta=0.5),
     "micro_f1_score": partial(f1_score, average="micro"),
     "micro_f2_score": partial(fbeta_score, average="micro", beta=2),
     "micro_precision_score": partial(precision_score, average="micro"),
     "micro_recall_score": partial(recall_score, average="micro"),
+    "micro_jaccard_similarity_score":
+        partial(jaccard_similarity_score, average="micro"),
 
     "macro_f0.5_score": partial(fbeta_score, average="macro", beta=0.5),
     "macro_f1_score": partial(f1_score, average="macro"),
     "macro_f2_score": partial(fbeta_score, average="macro", beta=2),
     "macro_precision_score": partial(precision_score, average="macro"),
     "macro_recall_score": partial(recall_score, average="macro"),
+    "macro_jaccard_similarity_score":
+        partial(jaccard_similarity_score, average="macro"),
 
     "samples_f0.5_score": partial(fbeta_score, average="samples", beta=0.5),
     "samples_f1_score": partial(f1_score, average="samples"),
     "samples_f2_score": partial(fbeta_score, average="samples", beta=2),
     "samples_precision_score": partial(precision_score, average="samples"),
     "samples_recall_score": partial(recall_score, average="samples"),
+    "samples_jaccard_similarity_score":
+        partial(jaccard_similarity_score, average="samples"),
 
     "cohen_kappa_score": cohen_kappa_score,
+
+    "binary_jaccard_similarity_score":
+        partial(jaccard_similarity_score, average="binary")
 }
 
@@ -248,6 +259,8 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
     "samples_precision_score",
     "samples_recall_score",
     "coverage_error",
+    "jaccard_similarity_score",
+    "unnormalized_jaccard_similarity_score",
     "unnormalized_multilabel_confusion_matrix_sample",
     "label_ranking_loss",
     "label_ranking_average_precision_score",
@@ -268,6 +281,10 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
     "micro_average_precision_score",
     "samples_average_precision_score",
 
+    "jaccard_similarity_score",
+    "unnormalized_jaccard_similarity_score",
+    "binary_jaccard_similarity_score",
+
     # with default average='binary', multiclass is prohibited
     "precision_score",
     "recall_score",
@@ -286,7 +303,8 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
 
 # Metrics with an "average" argument
 METRICS_WITH_AVERAGING = {
-    "precision_score", "recall_score", "f1_score", "f2_score", "f0.5_score"
+    "precision_score", "recall_score", "f1_score", "f2_score", "f0.5_score",
+    "binary_jaccard_similarity_score"
 }
 
 # Threshold-based metrics with an "average" argument
@@ -331,15 +349,22 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
"hamming_loss", "precision_score", "recall_score", "f1_score", "f2_score", "f0.5_score", + "jaccard_similarity_score", + "binary_jaccard_similarity_score", "weighted_f0.5_score", "weighted_f1_score", "weighted_f2_score", "weighted_precision_score", "weighted_recall_score", + "weighted_jaccard_similarity_score", "micro_f0.5_score", "micro_f1_score", "micro_f2_score", "micro_precision_score", "micro_recall_score", + "micro_jaccard_similarity_score", "macro_f0.5_score", "macro_f1_score", "macro_f2_score", "macro_precision_score", "macro_recall_score", + "macro_jaccard_similarity_score", + + "unnormalized_jaccard_similarity_score", "unnormalized_multilabel_confusion_matrix", "unnormalized_multilabel_confusion_matrix_sample", @@ -377,16 +402,21 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): "weighted_f0.5_score", "weighted_f1_score", "weighted_f2_score", "weighted_precision_score", "weighted_recall_score", + "weighted_jaccard_similarity_score", "macro_f0.5_score", "macro_f1_score", "macro_f2_score", "macro_precision_score", "macro_recall_score", + "macro_jaccard_similarity_score", "micro_f0.5_score", "micro_f1_score", "micro_f2_score", "micro_precision_score", "micro_recall_score", + "micro_jaccard_similarity_score", + "unnormalized_multilabel_confusion_matrix", "samples_f0.5_score", "samples_f1_score", "samples_f2_score", "samples_precision_score", "samples_recall_score", + "samples_jaccard_similarity_score", } # Regression metrics with "multioutput-continuous" format support @@ -403,6 +433,10 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): "jaccard_similarity_score", "unnormalized_jaccard_similarity_score", "zero_one_loss", "unnormalized_zero_one_loss", + "micro_jaccard_similarity_score", "macro_jaccard_similarity_score", + "binary_jaccard_similarity_score", + "samples_jaccard_similarity_score", + "f1_score", "micro_f1_score", "macro_f1_score", "weighted_recall_score", # P = R = F = accuracy in multiclass case @@ -430,7 +464,8 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): "precision_score", "recall_score", "f2_score", "f0.5_score", "weighted_f0.5_score", "weighted_f1_score", "weighted_f2_score", - "weighted_precision_score", "unnormalized_multilabel_confusion_matrix", + "weighted_precision_score", "weighted_jaccard_similarity_score", + "unnormalized_multilabel_confusion_matrix", "macro_f0.5_score", "macro_f2_score", "macro_precision_score", "macro_recall_score", "log_loss", "hinge_loss" @@ -451,6 +486,9 @@ def test_symmetry(): y_true = random_state.randint(0, 2, size=(20, )) y_pred = random_state.randint(0, 2, size=(20, )) + y_true_bin = random_state.randint(0, 2, size=(20, 25)) + y_pred_bin = random_state.randint(0, 2, size=(20, 25)) + # We shouldn't forget any metrics assert_equal(SYMMETRIC_METRICS.union( NOT_SYMMETRIC_METRICS, set(THRESHOLDED_METRICS), @@ -464,8 +502,17 @@ def test_symmetry(): # Symmetric metric for name in SYMMETRIC_METRICS: metric = ALL_METRICS[name] - assert_allclose(metric(y_true, y_pred), metric(y_pred, y_true), - err_msg="%s is not symmetric" % name) + if name in METRIC_UNDEFINED_BINARY: + if name in MULTILABELS_METRICS: + assert_allclose(metric(y_true_bin, y_pred_bin), + metric(y_pred_bin, y_true_bin), + err_msg="%s is not symmetric" % name) + else: + assert False, "This case is currently unhandled" + else: + assert_allclose(metric(y_true, y_pred), + metric(y_pred, y_true), + err_msg="%s is not symmetric" % name) # Not symmetric metrics for name in NOT_SYMMETRIC_METRICS: @@ -834,6 +881,8 @@ def 
 
 @pytest.mark.parametrize('name', METRICS_WITH_NORMALIZE_OPTION)
 def test_normalize_option_binary_classification(name):
+    if name in METRIC_UNDEFINED_BINARY:
+        return
     # Test in the binary case
     n_samples = 20
     random_state = check_random_state(0)
@@ -851,6 +900,8 @@ def test_normalize_option_binary_classification(name):
 
 @pytest.mark.parametrize('name', METRICS_WITH_NORMALIZE_OPTION)
 def test_normalize_option_multiclass_classification(name):
+    if name in METRIC_UNDEFINED_MULTICLASS:
+        return
     # Test in the multiclass case
     random_state = check_random_state(0)
     y_true = random_state.randint(0, 4, size=(20, ))
diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py
index 2ab4b6b72e3a7..8f5fef58b6ac7 100644
--- a/sklearn/metrics/tests/test_score_objects.py
+++ b/sklearn/metrics/tests/test_score_objects.py
@@ -3,9 +3,9 @@
 import shutil
 import os
 import numbers
+from functools import partial
 
 import numpy as np
-
 import pytest
 
 from sklearn.utils.testing import assert_almost_equal
@@ -18,7 +18,8 @@
 from sklearn.base import BaseEstimator
 
 from sklearn.metrics import (f1_score, r2_score, roc_auc_score, fbeta_score,
-                             log_loss, precision_score, recall_score)
+                             log_loss, precision_score, recall_score,
+                             jaccard_similarity_score)
 from sklearn.metrics import cluster as cluster_module
 from sklearn.metrics.scorer import (check_scoring, _PredictScorer,
                                     _passthrough_scorer)
@@ -52,7 +53,9 @@
                'roc_auc', 'average_precision', 'precision',
                'precision_weighted', 'precision_macro', 'precision_micro',
                'recall', 'recall_weighted', 'recall_macro', 'recall_micro',
-               'neg_log_loss', 'log_loss', 'brier_score_loss']
+               'neg_log_loss', 'log_loss', 'brier_score_loss',
+               'jaccard', 'jaccard_weighted', 'jaccard_macro',
+               'jaccard_micro']
 
 # All supervised cluster scorers (They behave like classification metric)
 CLUSTER_SCORERS = ["adjusted_rand_score",
@@ -64,7 +67,8 @@
                    "normalized_mutual_info_score",
                    "fowlkes_mallows_score"]
 
-MULTILABEL_ONLY_SCORERS = ['precision_samples', 'recall_samples', 'f1_samples']
+MULTILABEL_ONLY_SCORERS = ['precision_samples', 'recall_samples', 'f1_samples',
+                           'jaccard_samples']
 
 
 def _make_estimators(X_train, y_train, y_ml_train):
@@ -285,7 +289,9 @@ def test_classification_scores():
     clf.fit(X_train, y_train)
 
     for prefix, metric in [('f1', f1_score), ('precision', precision_score),
-                           ('recall', recall_score)]:
+                           ('recall', recall_score),
+                           ('jaccard', partial(jaccard_similarity_score,
+                                               average='binary'))]:
 
         score1 = get_scorer('%s_weighted' % prefix)(clf, X_test, y_test)
         score2 = metric(y_test, clf.predict(X_test), pos_label=None,