Commit d4d4af8
MNT Move entropy to private function (scikit-learn#31294)
Co-authored-by: Jérémie du Boisberranger <[email protected]>
1 parent 4872503 commit d4d4af8

File tree: 6 files changed (+51 -22 lines)

doc/modules/array_api.rst

Lines changed: 0 additions & 1 deletion

@@ -132,7 +132,6 @@ base estimator also does:
 Metrics
 -------
 
-- :func:`sklearn.metrics.cluster.entropy`
 - :func:`sklearn.metrics.accuracy_score`
 - :func:`sklearn.metrics.d2_tweedie_score`
 - :func:`sklearn.metrics.explained_variance_score`
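
The removed bullet takes `entropy` off the list of array-API-capable metrics, consistent with the function going private; the surrounding entries such as `accuracy_score` remain dispatchable. A hedged sketch of that dispatch, assuming the optional array-api-compat dependency and PyTorch are installed (the data values are illustrative, not from this commit):

    import torch
    from sklearn import config_context
    from sklearn.metrics import accuracy_score

    # With array API dispatch enabled, supported metrics accept non-NumPy arrays.
    with config_context(array_api_dispatch=True):
        y_true = torch.asarray([0, 1, 1, 0])
        y_pred = torch.asarray([0, 1, 0, 0])
        print(accuracy_score(y_true, y_pred))  # 0.75
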
Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+- :func:`metrics.cluster.entropy` is deprecated and will be removed in v1.10.
+  By :user:`Lucy Liu <lucyleeow>`
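
For code that still calls the public function, one possible migration path (an assumption, not a recommendation stated in this changelog entry) is to compute the label entropy directly from the label counts with SciPy, which uses the same natural-log convention:

    import numpy as np
    from scipy.stats import entropy as scipy_entropy

    labels = np.asarray([0, 0, 42.0])
    _, counts = np.unique(labels, return_counts=True)  # counts per distinct label
    print(scipy_entropy(counts))  # ~0.6365141, matching metrics.cluster.entropy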

sklearn/metrics/cluster/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -14,6 +14,7 @@
     adjusted_rand_score,
     completeness_score,
     contingency_matrix,
+    # TODO(1.10): Remove
     entropy,
     expected_mutual_information,
     fowlkes_mallows_score,
@@ -40,6 +41,7 @@
     "consensus_score",
     "contingency_matrix",
     "davies_bouldin_score",
+    # TODO(1.10): Remove
     "entropy",
     "expected_mutual_information",
     "fowlkes_mallows_score",

sklearn/metrics/cluster/_supervised.py

Lines changed: 28 additions & 11 deletions

@@ -14,6 +14,7 @@
 import numpy as np
 from scipy import sparse as sp
 
+from ...utils import deprecated
 from ...utils._array_api import _max_precision_float_dtype, get_namespace_and_device
 from ...utils._param_validation import Hidden, Interval, StrOptions, validate_params
 from ...utils.multiclass import type_of_target
@@ -530,8 +531,8 @@ def homogeneity_completeness_v_measure(labels_true, labels_pred, *, beta=1.0):
     if len(labels_true) == 0:
         return 1.0, 1.0, 1.0
 
-    entropy_C = entropy(labels_true)
-    entropy_K = entropy(labels_pred)
+    entropy_C = _entropy(labels_true)
+    entropy_K = _entropy(labels_pred)
 
     contingency = contingency_matrix(labels_true, labels_pred, sparse=True)
     MI = mutual_info_score(None, None, contingency=contingency)
@@ -1042,7 +1043,7 @@ def adjusted_mutual_info_score(
     # Calculate the expected value for the mutual information
     emi = expected_mutual_information(contingency, n_samples)
     # Calculate entropy for each labeling
-    h_true, h_pred = entropy(labels_true), entropy(labels_pred)
+    h_true, h_pred = _entropy(labels_true), _entropy(labels_pred)
     normalizer = _generalized_average(h_true, h_pred, average_method)
     denominator = normalizer - emi
     # Avoid 0.0 / 0.0 when expectation equals maximum, i.e. a perfect match.
@@ -1168,7 +1169,7 @@ def normalized_mutual_info_score(
         return 0.0
 
     # Calculate entropy for each labeling
-    h_true, h_pred = entropy(labels_true), entropy(labels_pred)
+    h_true, h_pred = _entropy(labels_true), _entropy(labels_pred)
 
     normalizer = _generalized_average(h_true, h_pred, average_method)
     return float(mi / normalizer)
@@ -1272,13 +1273,7 @@ def fowlkes_mallows_score(labels_true, labels_pred, *, sparse="deprecated"):
     return float(np.sqrt(tk / pk) * np.sqrt(tk / qk)) if tk != 0.0 else 0.0
 
 
-@validate_params(
-    {
-        "labels": ["array-like"],
-    },
-    prefer_skip_nested_validation=True,
-)
-def entropy(labels):
+def _entropy(labels):
     """Calculate the entropy for a labeling.
 
     Parameters
@@ -1312,3 +1307,25 @@ def entropy(labels):
     # Always convert the result as a Python scalar (on CPU) instead of a device
     # specific scalar array.
     return float(-xp.sum((pi / pi_sum) * (xp.log(pi) - log(pi_sum))))
+
+
+# TODO(1.10): Remove
+@deprecated("`entropy` is deprecated in 1.8 and will be removed in 1.10.")
+def entropy(labels):
+    """Calculate the entropy for a labeling.
+
+    Parameters
+    ----------
+    labels : array-like of shape (n_samples,), dtype=int
+        The labels.
+
+    Returns
+    -------
+    entropy : float
+        The entropy for a labeling.
+
+    Notes
+    -----
+    The logarithm used is the natural logarithm (base-e).
+    """
+    return _entropy(labels)
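
The rename does not change the math: `_entropy` still computes the natural-log entropy of the label distribution, H = -sum(p_i * ln p_i), via the array-API-aware formulation above. A minimal NumPy-only sketch of the same formula, for intuition (the literal values mirror test_entropy below, and the empty-input convention of returning 1.0 is taken from that test):

    import numpy as np

    def entropy_sketch(labels):
        # Count occurrences of each distinct label.
        _, counts = np.unique(np.asarray(labels), return_counts=True)
        if counts.size == 0:
            return 1.0  # convention for empty input, per test_entropy
        p = counts / counts.sum()
        # Natural-log entropy: H = -sum(p * ln p).
        return float(-np.sum(p * np.log(p)))

    print(entropy_sketch([0, 0, 42.0]))  # ~0.6365141
    print(entropy_sketch([]))            # 1.0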

sklearn/metrics/cluster/tests/test_supervised.py

Lines changed: 19 additions & 9 deletions

@@ -10,7 +10,6 @@
     adjusted_rand_score,
     completeness_score,
     contingency_matrix,
-    entropy,
     expected_mutual_information,
     fowlkes_mallows_score,
     homogeneity_completeness_v_measure,
@@ -21,7 +20,12 @@
     rand_score,
     v_measure_score,
 )
-from sklearn.metrics.cluster._supervised import _generalized_average, check_clusterings
+from sklearn.metrics.cluster._supervised import (
+    _entropy,
+    _generalized_average,
+    check_clusterings,
+    entropy,
+)
 from sklearn.utils import assert_all_finite
 from sklearn.utils._array_api import (
     _get_namespace_device_dtype_ids,
@@ -267,10 +271,16 @@ def test_int_overflow_mutual_info_fowlkes_mallows_score():
     assert_all_finite(fowlkes_mallows_score(x, y))
 
 
+# TODO(1.10): Remove
+def test_public_entropy_deprecation():
+    with pytest.warns(FutureWarning, match="Function entropy is deprecated"):
+        entropy([0, 0, 42.0])
+
+
 def test_entropy():
-    assert_almost_equal(entropy([0, 0, 42.0]), 0.6365141, 5)
-    assert_almost_equal(entropy([]), 1)
-    assert entropy([1, 1, 1, 1]) == 0
+    assert_almost_equal(_entropy([0, 0, 42.0]), 0.6365141, 5)
+    assert_almost_equal(_entropy([]), 1)
+    assert _entropy([1, 1, 1, 1]) == 0
 
 
 @pytest.mark.parametrize(
@@ -284,9 +294,9 @@ def test_entropy_array_api(array_namespace, device, dtype_name):
     empty_int32_labels = xp.asarray([], dtype=xp.int32, device=device)
     int_labels = xp.asarray([1, 1, 1, 1], device=device)
     with config_context(array_api_dispatch=True):
-        assert entropy(float_labels) == pytest.approx(0.6365141, abs=1e-5)
-        assert entropy(empty_int32_labels) == 1
-        assert entropy(int_labels) == 0
+        assert _entropy(float_labels) == pytest.approx(0.6365141, abs=1e-5)
+        assert _entropy(empty_int32_labels) == 1
+        assert _entropy(int_labels) == 0
 
 
 def test_contingency_matrix():
@@ -339,7 +349,7 @@ def test_v_measure_and_mutual_information(seed=36):
         v_measure_score(labels_a, labels_b),
        2.0
         * mutual_info_score(labels_a, labels_b)
-        / (entropy(labels_a) + entropy(labels_b)),
+        / (_entropy(labels_a) + _entropy(labels_b)),
         0,
     )
     avg = "arithmetic"
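
test_public_entropy_deprecation matches on "Function entropy is deprecated" rather than the decorator argument because, as far as I can tell, `sklearn.utils.deprecated` prefixes "Function <name> is deprecated" to the extra message and emits a FutureWarning. A small sketch of that pattern (`old_fn` is a made-up example, not part of scikit-learn):

    import pytest
    from sklearn.utils import deprecated

    @deprecated("use something_else instead.")
    def old_fn():
        return 42

    # The warning message starts with the function name, so tests match on that.
    with pytest.warns(FutureWarning, match="Function old_fn is deprecated"):
        assert old_fn() == 42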

sklearn/tests/test_public_functions.py

Lines changed: 0 additions & 1 deletion

@@ -223,7 +223,6 @@ def _check_function_param_validation(
     "sklearn.metrics.classification_report",
     "sklearn.metrics.cluster.adjusted_mutual_info_score",
     "sklearn.metrics.cluster.contingency_matrix",
-    "sklearn.metrics.cluster.entropy",
     "sklearn.metrics.cluster.fowlkes_mallows_score",
     "sklearn.metrics.cluster.homogeneity_completeness_v_measure",
     "sklearn.metrics.cluster.normalized_mutual_info_score",
