Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 14a360f

Browse files
committed
Added private _alpha attribute for testing ease
1 parent 63f9fd8 commit 14a360f

File tree

2 files changed

+16
-13
lines changed

2 files changed

+16
-13
lines changed

sklearn/naive_bayes.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -706,8 +706,8 @@ def partial_fit(self, X, y, classes=None, sample_weight=None):
706706
# be called by the user explicitly just once after several consecutive
707707
# calls to partial_fit and prior any call to predict[_[log_]proba]
708708
# to avoid computing the smooth log probas at each call to partial fit
709-
alpha = self._check_alpha()
710-
self._update_feature_log_prob(alpha)
709+
self._alpha = self._check_alpha()
710+
self._update_feature_log_prob(self._alpha)
711711
self._update_class_log_prior(class_prior=class_prior)
712712
return self
713713

@@ -760,8 +760,8 @@ def fit(self, X, y, sample_weight=None):
760760
n_classes = Y.shape[1]
761761
self._init_counters(n_classes, n_features)
762762
self._count(X, Y)
763-
alpha = self._check_alpha()
764-
self._update_feature_log_prob(alpha)
763+
self._alpha = self._check_alpha()
764+
self._update_feature_log_prob(self._alpha)
765765
self._update_class_log_prior(class_prior=class_prior)
766766
return self
767767

sklearn/tests/test_naive_bayes.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -846,33 +846,36 @@ def test_check_alpha():
846846
https://github.com/scikit-learn/scikit-learn/issues/10772
847847
"""
848848
_ALPHA_MIN = 1e-10
849+
X = np.array([[2, 1], [1, 1]])
850+
y = np.array([0, 1])
849851
b = BernoulliNB(alpha=0, force_alpha=True)
850-
assert b._check_alpha() == 0
852+
b.fit(X, y)
853+
assert b._alpha == 0
851854

852855
alphas = np.array([0.0, 1.0])
853856

854857
b = BernoulliNB(alpha=alphas, force_alpha=True)
855-
# We manually set `n_features_in_` not to have `_check_alpha` err
856-
b.n_features_in_ = alphas.shape[0]
857-
assert_array_equal(b._check_alpha(), alphas)
858+
b.fit(X, y)
859+
assert_array_equal(b._alpha, alphas)
858860

859861
msg = (
860862
"alpha too small will result in numeric errors, setting alpha = %.1e"
861863
% _ALPHA_MIN
862864
)
863865
b = BernoulliNB(alpha=0, force_alpha=False)
864866
with pytest.warns(UserWarning, match=msg):
865-
assert b._check_alpha() == _ALPHA_MIN
867+
b.fit(X, y)
868+
assert b._alpha == _ALPHA_MIN
866869

867870
b = BernoulliNB(alpha=0)
868871
with pytest.warns(UserWarning, match=msg):
869-
assert b._check_alpha() == _ALPHA_MIN
872+
b.fit(X, y)
873+
assert b._alpha == _ALPHA_MIN
870874

871875
b = BernoulliNB(alpha=alphas, force_alpha=False)
872-
# We manually set `n_features_in_` not to have `_check_alpha` err
873-
b.n_features_in_ = alphas.shape[0]
874876
with pytest.warns(UserWarning, match=msg):
875-
assert_array_equal(b._check_alpha(), np.array([_ALPHA_MIN, 1.0]))
877+
b.fit(X, y)
878+
assert_array_equal(b._alpha, np.array([_ALPHA_MIN, 1.0]))
876879

877880

878881
def test_alpha_vector():

0 commit comments

Comments
 (0)