

Discrete Naive Bayes classifiers crash unnecessarily on degenerate one-class data #18974

@dpoznik

Description


Describe the bug

The Naive Bayes classifier allows for a degenerate case in which there is just one class. sklearn.naive_bayes.GaussianNB handles this case correctly, deterministically assigning the single class label to every sample at prediction time. In contrast, the discrete Naive Bayes classifiers do not. One can fit a model, but prediction can raise a cryptic IndexError. The error arises because the implementation implicitly, and unnecessarily, assumes that there are always at least two classes.
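The class-count mismatch can be illustrated without scikit-learn. The NumPy sketch below is a hypothetical re-creation of a label-encoding step. It assumes (this is not stated in the report) that the discrete estimators one-hot encode `y` in a binary style, producing a single indicator column whenever there are at most two classes, and then always pad that column to two. The helper name `encode_like_discrete_nb` is invented for illustration.

```python
import numpy as np


def encode_like_discrete_nb(y):
    """Hypothetical re-creation of a padding-based encoding step.

    Returns (classes, Y), where Y has one column per "class" that
    per-class statistics would be accumulated for.
    """
    y = np.asarray(y)
    classes = np.unique(y)
    if len(classes) > 2:
        # Ordinary one-hot encoding: one column per class.
        return classes, (y[:, None] == classes[None, :]).astype(int)
    if len(classes) == 2:
        # Binary-style encoding: a single column marking the second class.
        Y = (y == classes[1]).astype(int).reshape(-1, 1)
    else:
        # With only one class, the single indicator column is all zeros.
        Y = np.zeros((len(y), 1), dtype=int)
    # Assumed behavior: pad to two columns unconditionally, which
    # silently invents a second, empty class when there is only one.
    Y = np.concatenate((1 - Y, Y), axis=1)
    return classes, Y


classes, Y = encode_like_discrete_nb(np.ones(10))
print(classes.shape, Y.shape)  # (1,) (10, 2)
print(Y.sum(axis=0))           # [10  0]
```

Every sample lands in the first column and the second column stays empty, so per-class statistics would have two rows while the array of class labels has only one entry.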

I have submitted PR #18925 to fix this bug.
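A minimal sketch of one possible fix, not necessarily the change made in PR #18925: make the number of indicator columns always equal the number of classes, and in the one-class case mark every sample as belonging to that class. `encode_labels` is a hypothetical NumPy helper written for illustration.

```python
import numpy as np


def encode_labels(y):
    """Hypothetical one-hot encoder that tolerates a single class.

    Returns (classes, Y) with exactly one column of Y per entry in classes.
    """
    y = np.asarray(y)
    classes = np.unique(y)
    if len(classes) == 1:
        # Degenerate case: every sample belongs to the one class.
        Y = np.ones((len(y), 1), dtype=int)
    else:
        # One indicator column per class (covers the two-class case too).
        Y = (y[:, None] == classes[None, :]).astype(int)
    return classes, Y


classes, Y = encode_labels(np.ones(10))
print(Y.shape)  # (10, 1)
```

With a single column, the per-class statistics and the joint log-likelihood would each have one column, so an argmax over classes can only return 0 and every sample is assigned the one class label, as in the expected results.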

By contrast, the SVM classifier is inherently binary, and sklearn.svm.SVC appropriately raises an informative ValueError when one attempts to fit a model on data with just one class.

Steps/Code to Reproduce

import numpy as np
from sklearn import naive_bayes

rng = np.random.default_rng(0)
num_outcomes = 3
num_experiments = 12
num_samples = 10

X = rng.multinomial(
    n=num_experiments,
    pvals=np.ones(num_outcomes) / num_outcomes,
    size=num_samples)
y = np.ones(num_samples)

clf = naive_bayes.MultinomialNB(fit_prior=False)
clf.fit(X, y)
clf.predict(X)

Expected Results

array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

Actual Results

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-2-d1204a21d65a> in <module>
     15 clf = naive_bayes.MultinomialNB(fit_prior=False)
     16 clf.fit(X, y)
---> 17 clf.predict(X)

~/repos/scikit-learn/sklearn/naive_bayes.py in predict(self, X)
     74         X = self._check_X(X)
     75         jll = self._joint_log_likelihood(X)
---> 76         return self.classes_[np.argmax(jll, axis=1)]
     77 
     78     def predict_log_proba(self, X):

IndexError: index 1 is out of bounds for axis 0 with size 1
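To see why the argmax in predict can land on index 1, the sketch below hand-computes a multinomial Naive Bayes joint log-likelihood with NumPy on a toy one-class data set. It assumes equal class priors (as with fit_prior=False), Laplace smoothing with alpha=1, and a phantom second class with zero feature counts. The data and the helper name `joint_log_likelihood` are made up for illustration; this follows the textbook scoring rule rather than scikit-learn's own code.

```python
import numpy as np


def joint_log_likelihood(X, feature_count, alpha=1.0):
    """Multinomial NB log-likelihood per (sample, class), uniform priors."""
    smoothed = feature_count + alpha
    # log P(feature j | class c), normalized over features for each class
    feature_log_prob = np.log(smoothed) - np.log(smoothed.sum(axis=1, keepdims=True))
    n_classes = feature_count.shape[0]
    class_log_prior = np.full(n_classes, -np.log(n_classes))
    return X @ feature_log_prob.T + class_log_prior


# Toy data: two samples, both labelled with the single class 1.
X = np.array([[12, 0, 0],
              [0, 6, 6]])
classes = np.array([1.0])

# Row 0: counts for the real class; row 1: the empty phantom class.
feature_count = np.vstack([X.sum(axis=0), np.zeros(3)])

jll = joint_log_likelihood(X, feature_count)
idx = np.argmax(jll, axis=1)
print(idx)  # [0 1]

try:
    classes[idx]
except IndexError as exc:
    print(exc)  # index 1 is out of bounds for axis 0 with size 1
```

The second sample's counts fit the uniformly smoothed phantom class better than the pooled counts of the real class, so the phantom column wins the argmax. This is why the crash depends on the data rather than occurring on every call.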

Versions

System:
    python: 3.8.6 (default, Nov 13 2020, 23:21:02)  [Clang 11.0.0 (clang-1100.0.33.16)]
executable: /Users/dpoznik/.pyenv/versions/3.8.6/bin/python
   machine: macOS-10.15.7-x86_64-i386-64bit

Python dependencies:
          pip: 20.3
   setuptools: 50.3.2
      sklearn: 0.24.dev0
        numpy: 1.19.4
        scipy: 1.5.4
       Cython: 0.29.21
       pandas: 1.1.0
   matplotlib: 3.3.3
       joblib: 0.17.0
threadpoolctl: 2.1.0

Built with OpenMP: False
