12 changes: 10 additions & 2 deletions orca_python/metrics/__init__.py
@@ -6,7 +6,6 @@
     ccr,
     gm,
     gmsec,
-    greater_is_better,
     mae,
     mmae,
     ms,
@@ -16,9 +15,14 @@
     tkendall,
     wkappa,
 )
+from .utils import (
+    compute_metric,
+    get_metric_names,
+    greater_is_better,
+    load_metric_as_scorer,
+)

 __all__ = [
-    "greater_is_better",
     "ccr",
     "amae",
     "gm",
@@ -32,4 +36,8 @@
     "spearman",
     "rps",
     "accuracy_off1",
+    "get_metric_names",
+    "greater_is_better",
+    "load_metric_as_scorer",
+    "compute_metric",
 ]
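
For orientation (not part of the diff): a minimal usage sketch of the helpers now re-exported from orca_python.metrics, assuming they behave exactly as the new tests in test_utils.py below exercise them; the label vectors are illustrative only.

from orca_python.metrics import (
    compute_metric,
    get_metric_names,
    greater_is_better,
    load_metric_as_scorer,
)

# Every registered metric name, e.g. ["accuracy_off1", "amae", "ccr", ...].
print(get_metric_names())

# Orientation of a metric: True means higher values are better.
print(greater_is_better("ccr"))  # True
print(greater_is_better("mae"))  # False

# Evaluate a metric directly on labels (illustrative labels).
y_true = [1, 2, 3, 1, 2, 3]
y_pred = [1, 3, 3, 1, 2, 3]
print(compute_metric("ccr", y_true, y_pred))

# Wrap a metric as a scikit-learn scorer; the scorer's sign is negative
# for metrics where lower values are better, such as "mae".
mae_scorer = load_metric_as_scorer("mae")
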
42 changes: 0 additions & 42 deletions orca_python/metrics/metrics.py
@@ -7,48 +7,6 @@
 from sklearn.metrics import confusion_matrix, recall_score


-def greater_is_better(metric_name):
-    """Determine if greater values indicate better classification performance.
-
-    Needed when declaring a new scorer through make_scorer from sklearn.
-
-    Parameters
-    ----------
-    metric_name : str
-        Name of the metric.
-
-    Returns
-    -------
-    greater_is_better : bool
-        True if greater values indicate better classification performance, False otherwise.
-
-    Examples
-    --------
-    >>> from orca_python.metrics.metrics import greater_is_better
-    >>> greater_is_better("ccr")
-    True
-    >>> greater_is_better("mze")
-    False
-    >>> greater_is_better("mae")
-    False
-
-    """
-    greater_is_better_metrics = [
-        "ccr",
-        "ms",
-        "gm",
-        "gmsec",
-        "tkendall",
-        "wkappa",
-        "spearman",
-        "accuracy_off1",
-    ]
-    if metric_name in greater_is_better_metrics:
-        return True
-    else:
-        return False
-
-
 def ccr(y_true, y_pred):
     """Calculate the Correctly Classified Ratio.
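
greater_is_better now lives in orca_python/metrics/utils.py, which is not part of this diff; the new tests below import a _METRICS registry from that module alongside compute_metric, get_metric_names and load_metric_as_scorer. A speculative sketch of what such a registry-based module might look like, with every detail beyond the names and error messages referenced by the tests assumed rather than taken from the real file:

from sklearn.metrics import make_scorer

from orca_python.metrics.metrics import ccr, mae, mze  # plus the remaining metrics

# Hypothetical registry: metric name -> (callable, greater-is-better flag).
_METRICS = {
    "ccr": (ccr, True),
    "mae": (mae, False),
    "mze": (mze, False),
    # ... one entry per metric exported by the package
}


def get_metric_names():
    """Return the registered metric names as a sorted list."""
    return sorted(_METRICS)


def greater_is_better(metric_name):
    """Return True when larger values of the metric indicate better performance."""
    if metric_name not in _METRICS:
        raise KeyError(f"Unrecognized metric name: '{metric_name}'.")
    return _METRICS[metric_name][1]


def compute_metric(metric_name, y_true, y_pred):
    """Evaluate the named metric on the given labels."""
    if metric_name not in _METRICS:
        raise KeyError(f"Unrecognized metric name: '{metric_name}'.")
    return _METRICS[metric_name][0](y_true, y_pred)


def load_metric_as_scorer(metric_name):
    """Wrap the named metric with scikit-learn's make_scorer."""
    if not isinstance(metric_name, str):
        raise TypeError("metric_name must be a string.")
    if metric_name not in _METRICS:
        raise KeyError(f"Unrecognized metric name: '{metric_name}'.")
    metric, gib = _METRICS[metric_name]
    return make_scorer(metric, greater_is_better=gib)
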
18 changes: 0 additions & 18 deletions orca_python/metrics/tests/test_metrics.py
@@ -10,7 +10,6 @@
     ccr,
     gm,
     gmsec,
-    greater_is_better,
     mae,
     mmae,
     ms,
@@ -22,23 +21,6 @@
 )


-def test_greater_is_better():
-    """Test the greater_is_better function."""
-    assert greater_is_better("accuracy_off1")
-    assert greater_is_better("ccr")
-    assert greater_is_better("gm")
-    assert greater_is_better("gmsec")
-    assert not greater_is_better("mae")
-    assert not greater_is_better("mmae")
-    assert not greater_is_better("amae")
-    assert greater_is_better("ms")
-    assert not greater_is_better("mze")
-    assert not greater_is_better("rps")
-    assert greater_is_better("tkendall")
-    assert greater_is_better("wkappa")
-    assert greater_is_better("spearman")
-
-
 def test_accuracy_off1():
     """Test the Accuracy that allows errors in adjacent classes."""
     y_true = np.array([0, 1, 2, 3, 4, 5])
171 changes: 171 additions & 0 deletions orca_python/metrics/tests/test_utils.py
@@ -0,0 +1,171 @@
"""Tests for the metrics module utilities."""

import numpy.testing as npt
import pytest

from orca_python.metrics import (
accuracy_off1,
amae,
ccr,
gm,
gmsec,
mae,
mmae,
ms,
mze,
rps,
spearman,
tkendall,
wkappa,
)
from orca_python.metrics.utils import (
_METRICS,
compute_metric,
get_metric_names,
greater_is_better,
load_metric_as_scorer,
)


def test_get_metric_names():
"""Test that get_metric_names returns all available metric names."""
all_metrics = get_metric_names()
expected_names = list(_METRICS.keys())

assert type(all_metrics) is list
assert all_metrics[:3] == ["accuracy_off1", "amae", "ccr"]
assert "rps" in all_metrics
npt.assert_array_equal(sorted(all_metrics), sorted(expected_names))


@pytest.mark.parametrize(
"metric_name, gib",
[
("accuracy_off1", True),
("amae", False),
("ccr", True),
("gm", True),
("gmsec", True),
("mae", False),
("mmae", False),
("ms", True),
("mze", False),
("rps", False),
("spearman", True),
("tkendall", True),
("wkappa", True),
],
)
def test_greater_is_better(metric_name, gib):
"""Test that greater_is_better returns the correct boolean for each metric."""
assert greater_is_better(metric_name) == gib


def test_greater_is_better_invalid_name():
"""Test that greater_is_better raises an error for an invalid metric name."""
error_msg = "Unrecognized metric name: 'roc_auc'."

with pytest.raises(KeyError, match=error_msg):
greater_is_better("roc_auc")


@pytest.mark.parametrize(
"metric_name, metric",
[
("rps", rps),
("ccr", ccr),
("accuracy_off1", accuracy_off1),
("gm", gm),
("gmsec", gmsec),
("mae", mae),
("mmae", mmae),
("amae", amae),
("ms", ms),
("mze", mze),
("tkendall", tkendall),
("wkappa", wkappa),
("spearman", spearman),
],
)
def test_load_metric_as_scorer(metric_name, metric):
"""Test that load_metric_as_scorer correctly loads the expected metric."""
metric_func = load_metric_as_scorer(metric_name)

assert metric_func._score_func == metric
assert metric_func._sign == (1 if greater_is_better(metric_name) else -1)


@pytest.mark.parametrize(
"metric_name, metric",
[
("ccr", ccr),
("accuracy_off1", accuracy_off1),
("gm", gm),
("gmsec", gmsec),
("mae", mae),
("mmae", mmae),
("amae", amae),
("ms", ms),
("mze", mze),
("tkendall", tkendall),
("wkappa", wkappa),
("spearman", spearman),
],
)
def test_correct_metric_output(metric_name, metric):
"""Test that the loaded metric function produces the same output as the
original metric."""
y_true = [1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3]
y_pred = [1, 3, 3, 1, 2, 3, 1, 2, 2, 1, 3, 1, 1, 2, 2, 2, 3, 3, 1, 3]
metric_func = load_metric_as_scorer(metric_name)
metric_true = metric(y_true, y_pred)
metric_pred = metric_func._score_func(y_true, y_pred)

npt.assert_almost_equal(metric_pred, metric_true, decimal=6)


def test_load_metric_invalid_name():
"""Test that loading an invalid metric raises the correct exception."""
error_msg = "metric_name must be a string."
with pytest.raises(TypeError, match=error_msg):
load_metric_as_scorer(123)

error_msg = "Unrecognized metric name: 'roc_auc'."
with pytest.raises(KeyError, match=error_msg):
load_metric_as_scorer("roc_auc")


@pytest.mark.parametrize(
"metric_name",
[
"ccr",
"accuracy_off1",
"gm",
"gmsec",
"mae",
"mmae",
"amae",
"ms",
"mze",
"tkendall",
"wkappa",
"spearman",
],
)
def test_compute_metric(metric_name) -> None:
"""Test that compute_metric returns the correct metric value."""
y_true = [1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3]
y_pred = [1, 3, 3, 1, 2, 3, 1, 2, 2, 1, 3, 1, 1, 2, 2, 2, 3, 3, 1, 3]
metric_value = compute_metric(metric_name, y_true, y_pred)
metric_func = load_metric_as_scorer(metric_name)
metric_true = metric_func._score_func(y_true, y_pred)

npt.assert_almost_equal(metric_value, metric_true, decimal=6)


def test_compute_metric_invalid_name():
"""Test that compute_metric raises an error for an invalid metric name."""
error_msg = "Unrecognized metric name: 'roc_auc'."

with pytest.raises(KeyError, match=error_msg):
compute_metric("roc_auc", [1, 2, 3], [1, 2, 3])