diff --git a/sklearn/metrics/_regression.py b/sklearn/metrics/_regression.py index 8806719eabdd1..62251f9b96188 100644 --- a/sklearn/metrics/_regression.py +++ b/sklearn/metrics/_regression.py @@ -71,6 +71,11 @@ def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric", xp=None): dtype : str or list, default="numeric" the dtype argument passed to check_array. + xp : module, default=None + Precomputed array namespace module. When passed, typically from a caller + that has already performed inspection of its own inputs, skips array + namespace inspection. + Returns ------- type_true : one of {'continuous', continuous-multioutput'} @@ -398,7 +403,7 @@ def mean_absolute_percentage_error( dtype = _find_matching_floating_dtype(y_true, y_pred, sample_weight, xp=xp) y_type, y_true, y_pred, multioutput = _check_reg_targets( - y_true, y_pred, multioutput + y_true, y_pred, multioutput, dtype=dtype, xp=xp ) check_consistent_length(y_true, y_pred, sample_weight) epsilon = xp.asarray(xp.finfo(xp.float64).eps, dtype=dtype) @@ -1253,7 +1258,7 @@ def max_error(y_true, y_pred): np.int64(1) """ xp, _ = get_namespace(y_true, y_pred) - y_type, y_true, y_pred, _ = _check_reg_targets(y_true, y_pred, None) + y_type, y_true, y_pred, _ = _check_reg_targets(y_true, y_pred, None, xp=xp) if y_type == "continuous-multioutput": raise ValueError("Multioutput not supported in max_error") return xp.max(xp.abs(y_true - y_pred)) @@ -1352,7 +1357,7 @@ def mean_tweedie_deviance(y_true, y_pred, *, sample_weight=None, power=0): """ xp, _ = get_namespace(y_true, y_pred) y_type, y_true, y_pred, _ = _check_reg_targets( - y_true, y_pred, None, dtype=[xp.float64, xp.float32] + y_true, y_pred, None, dtype=[xp.float64, xp.float32], xp=xp ) if y_type == "continuous-multioutput": raise ValueError("Multioutput not supported in mean_tweedie_deviance")