
FIX binary/multiclass jaccard_similarity_score and extend to handle averaging #13092


Closed

97 commits
64e30d6
multiclass jaccard similarity not equal to accuracy_score
gxyd Nov 5, 2017
a495cfc
add space and fix input
gxyd Nov 8, 2017
fcba7f0
score being an n_class size array and weight already taken care of
gxyd Nov 10, 2017
d49ccab
add space to fix printing of doctest
gxyd Nov 24, 2017
615ac9a
add support for 'average' of type 'macro', 'micro', 'weighted'
gxyd Nov 24, 2017
78b2a84
add tests and make documentation changes
gxyd Nov 25, 2017
41f7e2b
use 'average' for 'multilabel' classification
gxyd Nov 29, 2017
a7d0111
introduce average='binary', average='samples'
gxyd Nov 29, 2017
057815a
show errors and warning before anything
gxyd Dec 1, 2017
f1bd76f
write separate functions
gxyd Dec 1, 2017
581d540
completely okay API and improved doctest
gxyd Dec 2, 2017
aefe921
fix lgtm error and better control flow
gxyd Dec 2, 2017
83df958
add normalize in API
gxyd Dec 2, 2017
041c668
raise ValueError for not providing 'average' in multiclass
gxyd Dec 12, 2017
39b92b1
fixed errors with multiclass for different average values
gxyd Dec 12, 2017
a0712b5
fix tests, use assert_raise_message instead
gxyd Dec 12, 2017
113072a
add common_test for jaccard_similarity_score
gxyd Dec 15, 2017
c52d577
use `average='none-samples'` instead of 'normalize=False'
gxyd Dec 17, 2017
2e2d762
average='micro' in multiclass case is equivalent to accuracy_score
gxyd Dec 17, 2017
5504a00
fixes to multilabel case
gxyd Dec 17, 2017
b30ba53
add error message for `average='samples'` for non-multilabel case
gxyd Dec 17, 2017
8d0ca20
add none-samples in common test
gxyd Dec 20, 2017
ce89b5f
add support for `labels` in multilabel classification
gxyd Dec 28, 2017
192bb2d
fix multilabel classification
gxyd Dec 30, 2017
149af2a
fix for multiclass
gxyd Dec 31, 2017
40fca72
corrected 'macro', 'weighted' for multiclass only 'micro' remains
gxyd Jan 1, 2018
4b50447
fix completely logic of average='micro', now only 'binary' remains
gxyd Jan 2, 2018
fd099e5
remove 'warn' from API, after discussion on PR with jnothman
gxyd Jan 2, 2018
8c9c614
fix average='binary'
gxyd Jan 2, 2018
a7d3b40
fix doctest, now test_common and lgtm remain to be fixed
gxyd Jan 2, 2018
8a7e673
this fixes lgtm?
gxyd Jan 2, 2018
6e75c5a
fix average='micro' for multiclass jaccard
gxyd Jan 7, 2018
c800598
add smart tests
gxyd Jan 7, 2018
c3fa41d
first fix for test_common
gxyd Jan 8, 2018
27ffebf
fixes LGTM errors?
gxyd Jan 8, 2018
e017ccf
remove warning from tests/test_classification.py
gxyd Jan 9, 2018
9ee4c11
simplify code for multilabel jaccard
gxyd Jan 10, 2018
d1311c7
address Joel's comments
gxyd Jan 11, 2018
d3f76d5
fix doc and add average=None test for multiclass
gxyd Jan 11, 2018
319b5d3
fix none-samples jaccard_similarity score to return array of scores
gxyd Jan 12, 2018
07a05e6
take care of zero weights for average='weighted'
gxyd Jan 13, 2018
3a312a3
fix test_common
gxyd Jan 13, 2018
9251d29
don't bother testing 'sample_weight' and fix test
gxyd Jan 13, 2018
225c0f2
remove average='none-samples' as a possibility
gxyd Jan 16, 2018
d7fe5ca
fix average='weighted'
gxyd Jan 16, 2018
414ae8b
use np.in1d instead of np.isin (unavailable in version < 1.13.0)
gxyd Jan 16, 2018
3ac79bd
address Joel's comments
gxyd Jan 17, 2018
3673407
fix lgtm error
gxyd Jan 17, 2018
d3d7ca9
fix lgtm
gxyd Jan 17, 2018
04768c5
Fix flake8 errors
gxyd Jan 18, 2018
54fe344
code coverage
gxyd Jan 18, 2018
ee54853
fix flake8
gxyd Jan 18, 2018
551804d
improve doc
gxyd Jan 18, 2018
37737c2
Merge branch 'master' into jaccard-sim
gxyd Jan 22, 2018
0d45a44
add what's new entry and address Joel's comments
gxyd Jan 22, 2018
785bb36
improve doc's entry
gxyd Jan 22, 2018
7c1314a
use normalize='true-if-samples' for internal use
gxyd Jan 23, 2018
90e0c5c
address Joel's comments all, but one
gxyd Jan 24, 2018
8ff62bc
add jaccard similarity score to scorers
gxyd Jan 29, 2018
f5d03d0
use make_scorer with average='binary'
gxyd Jan 30, 2018
c3279ff
fix import and pep8
gxyd Jan 30, 2018
a673683
fix doc
gxyd Jan 30, 2018
0b507aa
use 'jaccard' instead of 'jaccard_similarity'
gxyd Jan 30, 2018
5bb690d
collect common validation code between prfs and jaccard
gxyd Jan 30, 2018
f1e1b69
update docstring and name
gxyd Jan 30, 2018
9606d52
fix pep8
gxyd Jan 30, 2018
e1d7e28
a little more refactoring
gxyd Jan 31, 2018
a2a09da
change answer for zeroed multiclass and binary averaging
gxyd Jan 31, 2018
c873dce
fix for edge cases
gxyd Feb 6, 2018
4978dfd
update (refactoring) function name and add doc example
gxyd Feb 7, 2018
c73605f
Merge branch 'master' into HEAD
jnothman Nov 6, 2018
b536ac6
Merge branch 'master' into jaccard
jnothman Nov 6, 2018
4fe8a1f
Fix merge error
jnothman Nov 6, 2018
095a02e
Merge branch 'master' of github.com:scikit-learn/scikit-learn into ja…
jnothman Nov 11, 2018
2c9b356
WIP
jnothman Nov 13, 2018
0e9e12d
Make tests pass
jnothman Nov 18, 2018
95dfada
Credit in what's new
jnothman Nov 18, 2018
99fdd5c
Clean merge error in what's new
jnothman Nov 18, 2018
80520e9
Remove debug print
jnothman Nov 18, 2018
4ba98bc
PEP8
jnothman Nov 18, 2018
5b5f04c
new array printing format
jnothman Nov 18, 2018
dfe58f4
new array printing format #2
jnothman Nov 18, 2018
7422982
Revert changes to v0.20.rst
jnothman Nov 18, 2018
1f495c1
Merge branch 'master' into jaccard-sim
jnothman Nov 20, 2018
afa7759
Remove changes due to bad merge
jnothman Nov 24, 2018
55b1e83
Avoid assert_equal
jnothman Nov 24, 2018
6b71c18
cosmit
jnothman Nov 24, 2018
03d89de
Clean up validation
jnothman Nov 24, 2018
46c1274
reuse warning code
jnothman Nov 24, 2018
a779926
Merge branch 'master' of github.com:scikit-learn/scikit-learn into ja…
jnothman Nov 25, 2018
e082e62
WIP
jnothman Jan 2, 2019
1e9373e
Merge branch 'master' into jaccard-sim
jnothman Feb 4, 2019
28dcca4
Clean what's new
jnothman Feb 4, 2019
7fd7201
Merge branch 'master' into jaccard-sim
jnothman Feb 4, 2019
27cf502
FIX coax tests to pass
jnothman Feb 5, 2019
7943540
Merge branch 'jaccard-sim' of github.com:jnothman/scikit-learn into j…
jnothman Feb 5, 2019
47776c0
Address Adrin's comments
jnothman Feb 12, 2019
28 changes: 18 additions & 10 deletions doc/modules/model_evaluation.rst
@@ -70,6 +70,7 @@ Scoring Function
'neg_log_loss' :func:`metrics.log_loss` requires ``predict_proba`` support
'precision' etc. :func:`metrics.precision_score` suffixes apply as with 'f1'
'recall' etc. :func:`metrics.recall_score` suffixes apply as with 'f1'
'jaccard' etc. :func:`metrics.jaccard_similarity_score` suffixes apply as with 'f1'
'roc_auc' :func:`metrics.roc_auc_score`

**Clustering**
@@ -698,24 +699,31 @@ with a ground truth label set :math:`y_i` and predicted label set

J(y_i, \hat{y}_i) = \frac{|y_i \cap \hat{y}_i|}{|y_i \cup \hat{y}_i|}.

In binary and multiclass classification, the Jaccard similarity coefficient
score is equal to the classification accuracy.
:func:`jaccard_similarity_score` works like :func:`precision_recall_fscore_support`
as a naively set-wise measure applying only to binary and multilabel targets.

::
In the multilabel case with binary label indicators: ::

>>> import numpy as np
>>> from sklearn.metrics import jaccard_similarity_score
>>> y_pred = [0, 2, 1, 3]
>>> y_true = [0, 1, 2, 3]
>>> y_true = np.array([[0, 1], [1, 1]])
>>> y_pred = np.ones((2, 2))
>>> jaccard_similarity_score(y_true, y_pred)
0.5
0.75
>>> jaccard_similarity_score(y_true, y_pred, normalize=False)
2
1.5

In the multilabel case with binary label indicators: ::
Multiclass problems are binarized and treated like the corresponding
multilabel problem: ::

>>> jaccard_similarity_score(np.array([[0, 1], [1, 1]]), np.ones((2, 2)))
0.75
>>> y_pred = [0, 2, 1, 3]
>>> y_true = [0, 1, 2, 3]
>>> jaccard_similarity_score(y_true, y_pred, average='macro')
0.5
>>> jaccard_similarity_score(y_true, y_pred, average='micro')
0.33...
>>> jaccard_similarity_score(y_true, y_pred, average=None)
array([1., 0., 0., 1.])

.. _precision_recall_f_measure_metrics:

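As a minimal sketch of the multiclass averaging documented above (assuming the ``average`` parameter introduced in this PR), the documented values can be reproduced by binarizing the problem and computing per-class Jaccard scores by hand; ``label_binarize`` is used here purely for illustration:

    import numpy as np
    from sklearn.preprocessing import label_binarize

    y_true = np.array([0, 1, 2, 3])
    y_pred = np.array([0, 2, 1, 3])

    # Binarize the multiclass problem: one indicator column per class.
    classes = np.unique(np.concatenate([y_true, y_pred]))
    Y_true = label_binarize(y_true, classes=classes)
    Y_pred = label_binarize(y_pred, classes=classes)

    # Per-class Jaccard = |intersection| / |union| of each column's positives.
    intersection = np.logical_and(Y_true, Y_pred).sum(axis=0)
    union = np.logical_or(Y_true, Y_pred).sum(axis=0)
    per_class = intersection / union

    print(per_class)                         # average=None    -> [1. 0. 0. 1.]
    print(per_class.mean())                  # average='macro' -> 0.5
    print(intersection.sum() / union.sum())  # average='micro' -> 0.33...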
15 changes: 11 additions & 4 deletions doc/whats_new/v0.21.rst
@@ -160,22 +160,29 @@ Support for Python 3.4 and below has been officially dropped.
metrics such as recall, specificity, fall out and miss rate.
:issue:`11179` by :user:`Shangwu Yao <ShangwuYao>` and `Joel Nothman`_.

- |Feature| |Fix| :func:`metrics.jaccard_similarity_score` now accepts an
``average`` argument like :func:`metrics.precision_recall_fscore_support`, as
a naively set-wise measure applying only to binary and multilabel targets. It
now binarizes multiclass input and treats it like the corresponding
multilabel problem.
:issue:`10083` by :user:`Gaurav Dhingra <gxyd>` and `Joel Nothman`_.

- |Enhancement| Use label `accuracy` instead of `micro-average` on
:func:`metrics.classification_report` to avoid confusion. `micro-average` is
only shown for multi-label or multi-class with a subset of classes because
it is otherwise identical to accuracy.
:issue:`12334` by :user:`Emmanuel Arias <[email protected]>`,
`Joel Nothman`_ and `Andreas Müller`_

- |Fix| The metric :func:`metrics.r2_score` is degenerate with a single sample
and now it returns NaN and raises :class:`exceptions.UndefinedMetricWarning`.
:issue:`12855` by :user:`Pawel Sendyk <psendyk>.`

- |API| The parameter ``labels`` in :func:`metrics.hamming_loss` is deprecated
in version 0.21 and will be removed in version 0.23.
:issue:`10580` by :user:`Reshama Shaikh <reshamas>` and `Sandra
Mitrovic <SandraMNE>`.

- |Fix| The metric :func:`metrics.r2_score` is degenerate with a single sample
and now it returns NaN and raises :class:`exceptions.UndefinedMetricWarning`.
:issue:`12855` by :user:`Pawel Sendyk <psendyk>.`

- |Efficiency| The pairwise manhattan distances with sparse input now uses the
BLAS shipped with scipy instead of the bundled BLAS. :issue:`12732` by
:user:`Jérémie du Boisberranger <jeremiedbb>`
207 changes: 143 additions & 64 deletions sklearn/metrics/classification.py
@@ -577,7 +577,8 @@ class labels [2]_.
return 1 - k


def jaccard_similarity_score(y_true, y_pred, normalize=True,
def jaccard_similarity_score(y_true, y_pred, labels=None, pos_label=1,
average='samples', normalize='true-if-samples',
Review comment (Member): just a note that this is not backward compatible with users calling it with positional arguments [sigh]! But I'm not sure what we should do in these cases.

Review comment (Member Author): If we deprecate the current function and make jaccard_score that would solve it :)

sample_weight=None):
"""Jaccard similarity coefficient score

@@ -596,72 +597,136 @@ def jaccard_similarity_score(y_true, y_pred, normalize=True,
y_pred : 1d array-like, or label indicator array / sparse matrix
Predicted labels, as returned by a classifier.

labels : list, optional
The set of labels to include when ``average != 'binary'``, and their
order if ``average is None``. Labels present in the data can be
excluded, for example to calculate a multiclass average ignoring a
majority negative class, while labels not present in the data will
result in 0 components in a macro average. For multilabel targets,
labels are column indices. By default, all labels in ``y_true`` and
``y_pred`` are used in sorted order.

pos_label : str or int, 1 by default
Review comment (Member): -> (default=1)?

The class to report if ``average='binary'`` and the data is binary.
If the data are multiclass or multilabel, this will be ignored;
setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
scores for that label only.

average : string, ['samples' (default), 'binary', 'micro', 'macro', None, \
'weighted']
If ``None``, the scores for each class are returned. Otherwise, this
determines the type of averaging performed on the data:

``'binary'``:
Only report results for the class specified by ``pos_label``.
This is applicable only if targets (``y_{true,pred}``) are binary.
``'micro'``:
Calculate metrics globally by counting the total true positives,
false negatives and false positives.
``'macro'``:
Calculate metrics for each label, and find their unweighted
mean. This does not take label imbalance into account.
``'weighted'``:
Calculate metrics for each label, and find their average, weighted
by support (the number of true instances for each label). This
alters 'macro' to account for label imbalance.
``'samples'``:
Calculate metrics for each instance, and find their average (only
meaningful for multilabel classification).

normalize : bool, optional (default=True)
Review comment (Member): default is true-if-samples and not True

If ``False``, return the sum of the Jaccard similarity coefficient
over the sample set. Otherwise, return the average of Jaccard
similarity coefficient.
similarity coefficient. ``normalize`` is only applicable when
``average='samples'``. The default value 'true-if-samples' behaves like
True, but does not raise an error with other values of `average`.

sample_weight : array-like of shape = [n_samples], optional
Sample weights.

Returns
-------
score : float
If ``normalize == True``, return the average Jaccard similarity
coefficient, else it returns the sum of the Jaccard similarity
coefficient over the sample set.

The best performance is 1 with ``normalize == True`` and the number
of samples with ``normalize == False``.
score : float (if average is not None) or array of floats, shape =\
[n_unique_labels]

See also
--------
accuracy_score, hamming_loss, zero_one_loss

Notes
-----
In binary and multiclass classification, this function is equivalent
to the ``accuracy_score``. It differs in the multilabel classification
problem.
:func:`jaccard_similarity_score` may be a poor metric if there are no
positives for some samples or classes.

References
----------
.. [1] `Wikipedia entry for the Jaccard index
<https://en.wikipedia.org/wiki/Jaccard_index>`_


Examples
--------
>>> import numpy as np
>>> from sklearn.metrics import jaccard_similarity_score
>>> y_pred = [0, 2, 1, 3]
>>> y_true = [0, 1, 2, 3]
>>> jaccard_similarity_score(y_true, y_pred)

In the multilabel case:

>>> y_true = np.array([[1, 0, 1], [0, 0, 1], [1, 1, 1]])
>>> y_pred = np.array([[0, 1, 1], [1, 1, 1], [0, 0, 1]])
>>> jaccard_similarity_score(y_true, y_pred, average='samples')
... # doctest: +ELLIPSIS
0.33...
>>> jaccard_similarity_score(y_true, y_pred, average='micro')
... # doctest: +ELLIPSIS
Review comment (Member): I think this is redundant, it's already set above (and it generates the odd empty ... line in the output).

Review comment (Member Author): I think these flags are per-statement, so I don't see how "it's already set above".

Review comment (Member): The scope of those flags is at least per-block, for example:

    >>> cov.covariance_  # doctest: +ELLIPSIS
    array([[0.7569..., 0.2818...],
           [0.2818..., 0.3928...]])
    >>> cov.location_
    array([0.0622..., 0.0193...])

0.33...
>>> jaccard_similarity_score(y_true, y_pred, average='weighted')
0.5
>>> jaccard_similarity_score(y_true, y_pred, normalize=False)
2
>>> jaccard_similarity_score(y_true, y_pred, average=None)
array([0., 0., 1.])

In the multilabel case with binary label indicators:
In the multiclass case:

>>> import numpy as np
>>> jaccard_similarity_score(np.array([[0, 1], [1, 1]]),\
np.ones((2, 2)))
0.75
>>> jaccard_similarity_score(np.array([0, 1, 2, 3]),
... np.array([0, 2, 2, 3]), average='macro')
0.625
"""
if average != 'samples' and normalize != 'true-if-samples':
raise ValueError("'normalize' is only meaningful with "
"`average='samples'`, got `average='%s'`."
% average)
labels = _check_set_wise_labels(y_true, y_pred, average, labels,
pos_label)
if labels is _ALL_ZERO:
warnings.warn('Jaccard is ill-defined and being set to 0.0 with no '
'true or predicted samples', UndefinedMetricWarning)
return 0.
samplewise = average == 'samples'
MCM = multilabel_confusion_matrix(y_true, y_pred,
sample_weight=sample_weight,
labels=labels, samplewise=samplewise)
numerator = MCM[:, 1, 1]
denominator = MCM[:, 1, 1] + MCM[:, 0, 1] + MCM[:, 1, 0]

# Compute accuracy for each possible representation
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
check_consistent_length(y_true, y_pred, sample_weight)
if y_type.startswith('multilabel'):
with np.errstate(divide='ignore', invalid='ignore'):
# oddly, we may get an "invalid" rather than a "divide" error here
pred_or_true = count_nonzero(y_true + y_pred, axis=1)
pred_and_true = count_nonzero(y_true.multiply(y_pred), axis=1)
score = pred_and_true / pred_or_true
score[pred_or_true == 0.0] = 1.0
if average == 'micro':
numerator = np.array([numerator.sum()])
denominator = np.array([denominator.sum()])

jaccard = _prf_divide(numerator, denominator, 'jaccard',
'true or predicted', average, ('jaccard',))
if average is None:
return jaccard
if not normalize:
return np.sum(jaccard * (1 if sample_weight is None
else sample_weight))
if average == 'weighted':
weights = MCM[:, 1, 0] + MCM[:, 1, 1]
if not np.any(weights):
# numerator is 0, and warning should have already been issued
weights = None
elif average == 'samples' and sample_weight is not None:
weights = sample_weight
else:
score = y_true == y_pred

return _weighted_sum(score, sample_weight, normalize)
weights = None
return np.average(jaccard, weights=weights)
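As a hedged illustration of the computation above (not the shipped code itself), the per-label values in the docstring's multilabel example can be reproduced directly from ``multilabel_confusion_matrix`` output, where each per-label 2x2 matrix is laid out as [[TN, FP], [FN, TP]]:

    import numpy as np
    from sklearn.metrics import multilabel_confusion_matrix

    y_true = np.array([[1, 0, 1], [0, 0, 1], [1, 1, 1]])
    y_pred = np.array([[0, 1, 1], [1, 1, 1], [0, 0, 1]])

    MCM = multilabel_confusion_matrix(y_true, y_pred)
    tp = MCM[:, 1, 1]
    fp = MCM[:, 0, 1]
    fn = MCM[:, 1, 0]

    per_label = tp / (tp + fp + fn)                # average=None       -> [0. 0. 1.]
    print(per_label.mean())                        # average='macro'
    print(tp.sum() / (tp + fp + fn).sum())         # average='micro'    -> 0.33...
    support = tp + fn
    print(np.average(per_label, weights=support))  # average='weighted' -> 0.5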


def matthews_corrcoef(y_true, y_pred, sample_weight=None):
@@ -1056,8 +1121,10 @@ def _prf_divide(numerator, denominator, metric, modifier, average, warn_for):
The metric, modifier and average arguments are used only for determining
an appropriate warning.
"""
result = numerator / denominator
mask = denominator == 0.0
denominator = denominator.copy()
denominator[mask] = 1
result = numerator / denominator
if not np.any(mask):
return result
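A tiny standalone illustration of the guard above (illustrative only, not the library code): zero denominators are replaced by 1 before dividing, so the affected entries come out as 0.0 instead of nan and the division itself raises no runtime warning:

    import numpy as np

    numerator = np.array([2.0, 0.0, 3.0])
    denominator = np.array([4.0, 0.0, 3.0])

    mask = denominator == 0.0
    safe = denominator.copy()
    safe[mask] = 1                # 0/0 would give nan plus a RuntimeWarning
    result = numerator / safe
    print(result)                 # [0.5 0.  1. ]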

@@ -1091,6 +1158,41 @@ def _prf_divide(numerator, denominator, metric, modifier, average, warn_for):
return result


_ALL_ZERO = object() # sentinel for special, degenerate case


def _check_set_wise_labels(y_true, y_pred, average, labels, pos_label):
"""Validation associated with set-wise metrics

Returns identified labels or _ALL_ZERO sentinel
"""
average_options = (None, 'micro', 'macro', 'weighted', 'samples')
if average not in average_options and average != 'binary':
raise ValueError('average has to be one of ' +
str(average_options))

y_type, y_true, y_pred = _check_targets(y_true, y_pred)
present_labels = unique_labels(y_true, y_pred)
if average == 'binary':
if y_type == 'binary':
if pos_label not in present_labels:
if len(present_labels) < 2:
return _ALL_ZERO
else:
raise ValueError("pos_label=%r is not a valid label: "
"%r" % (pos_label, present_labels))
labels = [pos_label]
else:
raise ValueError("Target is %s but average='binary'. Please "
"choose another average setting." % y_type)
elif pos_label not in (None, 1):
warnings.warn("Note that pos_label (set to %r) is ignored when "
"average != 'binary' (got %r). You may use "
"labels=[pos_label] to specify a single positive class."
% (pos_label, average), UserWarning)
return labels
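A hedged sketch of the validation this helper centralises (assuming the public metrics delegate to it as in this PR): multiclass targets with ``average='binary'`` raise a ValueError, while a non-default ``pos_label`` combined with another ``average`` only triggers a UserWarning:

    from sklearn.metrics import jaccard_similarity_score

    y_true = [0, 1, 2, 2]
    y_pred = [0, 2, 1, 2]

    try:
        jaccard_similarity_score(y_true, y_pred, average='binary')
    except ValueError as exc:
        print(exc)  # Target is multiclass but average='binary'. Please choose ...

    # pos_label is ignored (with a UserWarning) when average != 'binary'.
    print(jaccard_similarity_score(y_true, y_pred, average='macro', pos_label=2))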


def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
pos_label=1, average=None,
warn_for=('precision', 'recall',
@@ -1234,35 +1336,12 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
array([2, 2, 2]))

"""
average_options = (None, 'micro', 'macro', 'weighted', 'samples')
if average not in average_options and average != 'binary':
raise ValueError('average has to be one of ' +
str(average_options))
if beta <= 0:
raise ValueError("beta should be >0 in the F-beta score")

y_type, y_true, y_pred = _check_targets(y_true, y_pred)
check_consistent_length(y_true, y_pred, sample_weight)
present_labels = unique_labels(y_true, y_pred)

if average == 'binary':
if y_type == 'binary':
if pos_label not in present_labels:
if len(present_labels) < 2:
# Only negative labels
return (0., 0., 0., 0)
else:
raise ValueError("pos_label=%r is not a valid label: %r" %
(pos_label, present_labels))
labels = [pos_label]
else:
raise ValueError("Target is %s but average='binary'. Please "
"choose another average setting." % y_type)
elif pos_label not in (None, 1):
warnings.warn("Note that pos_label (set to %r) is ignored when "
"average != 'binary' (got %r). You may use "
"labels=[pos_label] to specify a single positive class."
% (pos_label, average), UserWarning)
labels = _check_set_wise_labels(y_true, y_pred, average, labels,
pos_label)
if labels is _ALL_ZERO:
return (0., 0., 0., 0)

# Calculate tp_sum, pred_sum, true_sum ###
samplewise = average == 'samples'
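A hedged sketch of the shared ``_ALL_ZERO`` short-circuit above, as it is expected to behave under this PR for ``precision_recall_fscore_support``: binary targets containing only the negative class return all-zero scores instead of raising:

    from sklearn.metrics import precision_recall_fscore_support

    # Only the negative class is present; pos_label defaults to 1.
    y_true = [0, 0, 0]
    y_pred = [0, 0, 0]

    print(precision_recall_fscore_support(y_true, y_pred, average='binary'))
    # expected per the early return above: (0.0, 0.0, 0.0, 0)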
8 changes: 5 additions & 3 deletions sklearn/metrics/scorer.py
@@ -28,7 +28,7 @@
f1_score, roc_auc_score, average_precision_score,
precision_score, recall_score, log_loss,
balanced_accuracy_score, explained_variance_score,
brier_score_loss)
brier_score_loss, jaccard_similarity_score)

from .cluster import adjusted_rand_score
from .cluster import homogeneity_score
@@ -482,6 +482,7 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False,
accuracy_scorer = make_scorer(accuracy_score)
f1_scorer = make_scorer(f1_score)
balanced_accuracy_scorer = make_scorer(balanced_accuracy_score)
jaccard_similarity_scorer = make_scorer(jaccard_similarity_score)

# Score functions that need decision values
roc_auc_scorer = make_scorer(roc_auc_score, greater_is_better=True,
@@ -534,8 +535,9 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False,


for name, metric in [('precision', precision_score),
('recall', recall_score), ('f1', f1_score)]:
SCORERS[name] = make_scorer(metric)
('recall', recall_score), ('f1', f1_score),
('jaccard', jaccard_similarity_score)]:
SCORERS[name] = make_scorer(metric, average='binary')
for average in ['macro', 'micro', 'samples', 'weighted']:
qualified_name = '{0}_{1}'.format(name, average)
SCORERS[qualified_name] = make_scorer(metric, pos_label=None,
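A hedged usage sketch of the scorer strings registered above, assuming the names produced by the loop ('jaccard', 'jaccard_macro', 'jaccard_micro', 'jaccard_weighted', 'jaccard_samples'); since the plain 'jaccard' scorer uses average='binary', a multiclass dataset needs one of the averaged variants:

    from sklearn.datasets import load_iris
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import cross_val_score

    X, y = load_iris(return_X_y=True)
    clf = LogisticRegression(solver='lbfgs', multi_class='auto', max_iter=1000)

    # 'jaccard_macro' is assumed to resolve to make_scorer(jaccard_similarity_score,
    # pos_label=None, average='macro') via the loop above.
    scores = cross_val_score(clf, X, y, cv=3, scoring='jaccard_macro')
    print(scores.mean())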