diff --git a/doc/whats_new/upcoming_changes/sklearn.naive_bayes/31556.fix.rst b/doc/whats_new/upcoming_changes/sklearn.naive_bayes/31556.fix.rst new file mode 100644 index 0000000000000..0f5b969bd9e6f --- /dev/null +++ b/doc/whats_new/upcoming_changes/sklearn.naive_bayes/31556.fix.rst @@ -0,0 +1,3 @@ +- :class:`naive_bayes.CategoricalNB` now correctly declares that it accepts + categorical features in the tags returned by its `__sklearn_tags__` method. + By :user:`Olivier Grisel <ogrisel>` diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index e5b03abbb903a..31a1b87af2916 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1433,6 +1433,7 @@ def partial_fit(self, X, y, classes=None, sample_weight=None): def __sklearn_tags__(self): tags = super().__sklearn_tags__() + tags.input_tags.categorical = True tags.input_tags.sparse = False tags.input_tags.positive_only = True return tags diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index 99cfe030a940f..f5638e7384e86 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -968,3 +968,12 @@ def test_predict_joint_proba(Estimator, global_random_seed): log_prob_x = logsumexp(jll, axis=1) log_prob_x_y = jll - np.atleast_2d(log_prob_x).T assert_allclose(est.predict_log_proba(X2), log_prob_x_y) + + +@pytest.mark.parametrize("Estimator", ALL_NAIVE_BAYES_CLASSES) +def test_categorical_input_tag(Estimator): + tags = Estimator().__sklearn_tags__() + if Estimator is CategoricalNB: + assert tags.input_tags.categorical + else: + assert not tags.input_tags.categorical diff --git a/sklearn/utils/_test_common/instance_generator.py b/sklearn/utils/_test_common/instance_generator.py index 221236f8bc998..8d88ad23eb5e9 100644 --- a/sklearn/utils/_test_common/instance_generator.py +++ b/sklearn/utils/_test_common/instance_generator.py @@ -144,7 +144,6 @@ MultiOutputRegressor, RegressorChain, ) -from sklearn.naive_bayes import CategoricalNB from 
sklearn.neighbors import ( KernelDensity, KNeighborsClassifier, @@ -898,15 +897,6 @@ def _yield_instances_for_check(check, estimator_orig): "sample_weight is not equivalent to removing/repeating samples." ), }, - CategoricalNB: { - # TODO: fix sample_weight handling of this estimator, see meta-issue #16298 - "check_sample_weight_equivalence_on_dense_data": ( - "sample_weight is not equivalent to removing/repeating samples." - ), - "check_sample_weight_equivalence_on_sparse_data": ( - "sample_weight is not equivalent to removing/repeating samples." - ), - }, ColumnTransformer: { "check_estimators_empty_data_messages": "FIXME", "check_estimators_nan_inf": "FIXME", diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 156448698a780..ccff3cb44cad5 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -3997,7 +3997,10 @@ def check_positive_only_tag_during_fit(name, estimator_orig): y = _enforce_estimator_tags_y(estimator, y) set_random_state(estimator, 0) X = _enforce_estimator_tags_X(estimator, X) - X -= X.mean() + # Make sure that the dtype of X stays unchanged: for instance, estimators + # that expect categorical inputs typically expect integer-encoded + # categories. + X -= X.mean().astype(X.dtype) if tags.input_tags.positive_only: with raises(ValueError, match="Negative values in data"):