@@ -1733,67 +1733,87 @@ def test_metrics_pos_label_error_str(metric, y_pred_threshold, dtype_y_str):
1733
1733
1734
1734
1735
1735
def check_array_api_metric (
1736
- metric , array_namespace , device , dtype , y_true_np , y_pred_np , sample_weight = None
1736
+ metric , array_namespace , device , dtype_name , y_true_np , y_pred_np , sample_weight
1737
1737
):
1738
1738
xp = _array_api_for_tests (array_namespace , device )
1739
+
1739
1740
y_true_xp = xp .asarray (y_true_np , device = device )
1740
1741
y_pred_xp = xp .asarray (y_pred_np , device = device )
1741
1742
1742
1743
metric_np = metric (y_true_np , y_pred_np , sample_weight = sample_weight )
1743
1744
1745
+ if sample_weight is not None :
1746
+ sample_weight = xp .asarray (sample_weight , device = device )
1747
+
1744
1748
with config_context (array_api_dispatch = True ):
1745
- if sample_weight is not None :
1746
- sample_weight = xp .asarray (sample_weight , device = device )
1747
1749
metric_xp = metric (y_true_xp , y_pred_xp , sample_weight = sample_weight )
1748
1750
1749
1751
assert_allclose (
1750
1752
metric_xp ,
1751
1753
metric_np ,
1752
- atol = _atol_for_type (dtype ),
1754
+ atol = _atol_for_type (dtype_name ),
1753
1755
)
1754
1756
1755
1757
1756
1758
def check_array_api_binary_classification_metric (
1757
- metric , array_namespace , device , dtype
1759
+ metric , array_namespace , device , dtype_name
1758
1760
):
1759
1761
y_true_np = np .array ([0 , 0 , 1 , 1 ])
1760
1762
y_pred_np = np .array ([0 , 1 , 0 , 1 ])
1763
+
1761
1764
check_array_api_metric (
1762
- metric , array_namespace , device , dtype , y_true_np = y_true_np , y_pred_np = y_pred_np
1765
+ metric ,
1766
+ array_namespace ,
1767
+ device ,
1768
+ dtype_name ,
1769
+ y_true_np = y_true_np ,
1770
+ y_pred_np = y_pred_np ,
1771
+ sample_weight = None ,
1772
+ )
1773
+
1774
+ sample_weight = np .array ([0.0 , 0.1 , 2.0 , 1.0 ], dtype = dtype_name )
1775
+
1776
+ check_array_api_metric (
1777
+ metric ,
1778
+ array_namespace ,
1779
+ device ,
1780
+ dtype_name ,
1781
+ y_true_np = y_true_np ,
1782
+ y_pred_np = y_pred_np ,
1783
+ sample_weight = sample_weight ,
1763
1784
)
1764
- if "sample_weight" in signature (metric ).parameters :
1765
- check_array_api_metric (
1766
- metric ,
1767
- array_namespace ,
1768
- device ,
1769
- dtype ,
1770
- y_true_np = y_true_np ,
1771
- y_pred_np = y_pred_np ,
1772
- sample_weight = np .array ([0.0 , 0.1 , 2.0 , 1.0 ]),
1773
- )
1774
1785
1775
1786
1776
1787
def check_array_api_multiclass_classification_metric (
1777
- metric , array_namespace , device , dtype
1788
+ metric , array_namespace , device , dtype_name
1778
1789
):
1779
1790
y_true_np = np .array ([0 , 1 , 2 , 3 ])
1780
1791
y_pred_np = np .array ([0 , 1 , 0 , 2 ])
1792
+
1781
1793
check_array_api_metric (
1782
- metric , array_namespace , device , dtype , y_true_np = y_true_np , y_pred_np = y_pred_np
1794
+ metric ,
1795
+ array_namespace ,
1796
+ device ,
1797
+ dtype_name ,
1798
+ y_true_np = y_true_np ,
1799
+ y_pred_np = y_pred_np ,
1800
+ sample_weight = None ,
1801
+ )
1802
+
1803
+ sample_weight = np .array ([0.0 , 0.1 , 2.0 , 1.0 ], dtype = dtype_name )
1804
+
1805
+ check_array_api_metric (
1806
+ metric ,
1807
+ array_namespace ,
1808
+ device ,
1809
+ dtype_name ,
1810
+ y_true_np = y_true_np ,
1811
+ y_pred_np = y_pred_np ,
1812
+ sample_weight = sample_weight ,
1783
1813
)
1784
- if "sample_weight" in signature (metric ).parameters :
1785
- check_array_api_metric (
1786
- metric ,
1787
- array_namespace ,
1788
- device ,
1789
- dtype ,
1790
- y_true_np = y_true_np ,
1791
- y_pred_np = y_pred_np ,
1792
- sample_weight = np .array ([0.0 , 0.1 , 2.0 , 1.0 ]),
1793
- )
1794
1814
1795
1815
1796
- metric_checkers = {
1816
+ array_api_metric_checkers = {
1797
1817
accuracy_score : [
1798
1818
check_array_api_binary_classification_metric ,
1799
1819
check_array_api_multiclass_classification_metric ,
@@ -1805,15 +1825,15 @@ def check_array_api_multiclass_classification_metric(
1805
1825
}
1806
1826
1807
1827
1808
- def yield_metric_checker_combinations (metric_checkers = metric_checkers ):
1828
+ def yield_metric_checker_combinations (metric_checkers = array_api_metric_checkers ):
1809
1829
for metric , checkers in metric_checkers .items ():
1810
1830
for checker in checkers :
1811
1831
yield metric , checker
1812
1832
1813
1833
1814
1834
@pytest .mark .parametrize (
1815
- "array_namespace, device, dtype " , yield_namespace_device_dtype_combinations ()
1835
+ "array_namespace, device, dtype_name " , yield_namespace_device_dtype_combinations ()
1816
1836
)
1817
1837
@pytest .mark .parametrize ("metric, check_func" , yield_metric_checker_combinations ())
1818
- def test_array_api_compliance (metric , array_namespace , device , dtype , check_func ):
1819
- check_func (metric , array_namespace , device , dtype )
1838
+ def test_array_api_compliance (metric , array_namespace , device , dtype_name , check_func ):
1839
+ check_func (metric , array_namespace , device , dtype_name )
0 commit comments