@@ -846,33 +846,36 @@ def test_check_alpha():
846846 https://github.com/scikit-learn/scikit-learn/issues/10772
847847 """
848848 _ALPHA_MIN = 1e-10
849+ X = np .array ([[2 , 1 ], [1 , 1 ]])
850+ y = np .array ([0 , 1 ])
849851 b = BernoulliNB (alpha = 0 , force_alpha = True )
850- assert b ._check_alpha () == 0
852+ b .fit (X , y )
853+ assert b ._alpha == 0
851854
852855 alphas = np .array ([0.0 , 1.0 ])
853856
854857 b = BernoulliNB (alpha = alphas , force_alpha = True )
855- # We manually set `n_features_in_` not to have `_check_alpha` err
856- b .n_features_in_ = alphas .shape [0 ]
857- assert_array_equal (b ._check_alpha (), alphas )
858+ b .fit (X , y )
859+ assert_array_equal (b ._alpha , alphas )
858860
859861 msg = (
860862 "alpha too small will result in numeric errors, setting alpha = %.1e"
861863 % _ALPHA_MIN
862864 )
863865 b = BernoulliNB (alpha = 0 , force_alpha = False )
864866 with pytest .warns (UserWarning , match = msg ):
865- assert b ._check_alpha () == _ALPHA_MIN
867+ b .fit (X , y )
868+ assert b ._alpha == _ALPHA_MIN
866869
867870 b = BernoulliNB (alpha = 0 )
868871 with pytest .warns (UserWarning , match = msg ):
869- assert b ._check_alpha () == _ALPHA_MIN
872+ b .fit (X , y )
873+ assert b ._alpha == _ALPHA_MIN
870874
871875 b = BernoulliNB (alpha = alphas , force_alpha = False )
872- # We manually set `n_features_in_` not to have `_check_alpha` err
873- b .n_features_in_ = alphas .shape [0 ]
874876 with pytest .warns (UserWarning , match = msg ):
875- assert_array_equal (b ._check_alpha (), np .array ([_ALPHA_MIN , 1.0 ]))
877+ b .fit (X , y )
878+ assert_array_equal (b ._alpha , np .array ([_ALPHA_MIN , 1.0 ]))
876879
877880
878881def test_alpha_vector ():
0 commit comments