Describe the bug
The Naive Bayes classifier allows for a degenerate case in which there is just one class. sklearn.naive_bayes.GaussianNB correctly handles this case, deterministically assigning the one class label to each sample at prediction time. In contrast, the discrete Naive Bayes classifiers do not handle data of this form properly. One can fit a model, but prediction can raise a cryptic IndexError. This error arises because the implementation implicitly, and unnecessarily, assumes that the number of classes will never be less than two.
I have submitted PR #18925 to fix this bug.
By way of contrast, the SVM classifier is inherently binary, and sklearn.svm.SVC appropriately raises an informative ValueError when one attempts to fit a model on data with just one class.
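The contrast between these two estimators can be checked with a short script. This is a minimal sketch, assuming a working scikit-learn installation. The data are synthetic, and the names `X`, `y`, `gnb_pred` and `svc_error` are illustrative only.

```python
import numpy as np
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC

rng = np.random.default_rng(0)
X = rng.normal(size=(10, 3))  # synthetic features, illustrative only
y = np.ones(10)               # every sample carries the same class label

# GaussianNB accepts one-class data and predicts that class for every sample.
gnb_pred = GaussianNB().fit(X, y).predict(X)

# SVC refuses to fit one-class data and raises a ValueError instead.
try:
    SVC().fit(X, y)
    svc_error = None
except ValueError as exc:
    svc_error = exc

print(gnb_pred)
print(type(svc_error).__name__)
```

Either behavior gives the user a clear outcome: a deterministic prediction or an informative error at fit time.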
Steps/Code to Reproduce
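import numpy as np
from sklearn import naive_bayes

rng = np.random.default_rng(0)
num_outcomes = 3
num_experiments = 12
num_samples = 10

# Draw count features from a uniform multinomial distribution.
X = rng.multinomial(
    n=num_experiments,
    pvals=np.ones(num_outcomes) / num_outcomes,
    size=num_samples)

# Degenerate case: every sample belongs to the same class.
y = np.ones(num_samples)

clf = naive_bayes.MultinomialNB(fit_prior=False)
clf.fit(X, y)      # fitting succeeds
clf.predict(X)     # prediction raises IndexError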
Expected Results
array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
Actual Results
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-2-d1204a21d65a> in <module>
15 clf = naive_bayes.MultinomialNB(fit_prior=False)
16 clf.fit(X, y)
---> 17 clf.predict(X)
~/repos/scikit-learn/sklearn/naive_bayes.py in predict(self, X)
74 X = self._check_X(X)
75 jll = self._joint_log_likelihood(X)
---> 76 return self.classes_[np.argmax(jll, axis=1)]
77
78 def predict_log_proba(self, X):
IndexError: index 1 is out of bounds for axis 0 with size 1
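The failing expression in the traceback can be reproduced with NumPy alone. The message implies that np.argmax returned index 1 while classes_ held a single entry, so the score array must have had at least two columns. The sketch below assumes exactly two columns and uses made-up score values. The names `classes` and `jll` stand in for the estimator's attributes and are not taken from the library.

```python
import numpy as np

# Stand-in for clf.classes_ after fitting on one-class data.
classes = np.array([1.0])

# Hypothetical scores for three samples; two columns is an assumption.
jll = np.array([[-3.0, -1.0],
                [-2.5, -0.5],
                [-4.0, -2.0]])

idx = np.argmax(jll, axis=1)  # column 1 has the larger score in every row

try:
    classes[idx]
    message = None
except IndexError as exc:
    message = str(exc)
print(message)

# With a single score column, argmax can only return 0, so the lookup works.
single_idx = np.argmax(jll[:, :1], axis=1)
print(classes[single_idx])
```

The second lookup shows that the indexing itself is sound once the number of score columns matches the number of fitted classes.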
Versions
System:
python: 3.8.6 (default, Nov 13 2020, 23:21:02) [Clang 11.0.0 (clang-1100.0.33.16)]
executable: /Users/dpoznik/.pyenv/versions/3.8.6/bin/python
machine: macOS-10.15.7-x86_64-i386-64bit
Python dependencies:
pip: 20.3
setuptools: 50.3.2
sklearn: 0.24.dev0
numpy: 1.19.4
scipy: 1.5.4
Cython: 0.29.21
pandas: 1.1.0
matplotlib: 3.3.3
joblib: 0.17.0
threadpoolctl: 2.1.0
Built with OpenMP: False