Closed
Description
Description
While reviewing #8478, we encountered a BayesRidge
returning some NaN arrarys for the prediction and for the std of the posterior. Apparently, this behavior is triggered when the provided y
is constant. It is in fact due to a division by zero during fit
when initializing alpha such as 1. / np.var(y)
.
I am not really familiar with the code base, but I would assume that the predictions should not be NaN.
ping @agramfort @TomDLT @jnothman
NB: the division by zero warning was not printing during the testing in the original PR which is strange. Is pytest not catching the printing when test is failing? @lesteve
Steps/Code to Reproduce
import numpy as np
from sklearn.linear_model import BayesianRidge
X = np.random.random((100, 100))
y = np.ones(100)
estimator = BayesianRidge()
y_pred, std = estimator.fit(X, y).predict(X, return_std=True)
print(y_pred)
print(std)
Expected Results
Prediction should be:
[ 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
I would expect maybe an array of zero for the std.
Actual Results
/home/lemaitre/Documents/code/toolbox/scikit-learn/sklearn/linear_model/bayes.py:165: RuntimeWarning: divide by zero encountered in double_scalars
alpha_ = 1. / np.var(y)
/home/lemaitre/Documents/code/toolbox/scikit-learn/sklearn/linear_model/bayes.py:212: RuntimeWarning: invalid value encountered in true_divide
(lambda_ + alpha_ * eigen_vals_)))
[ nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
nan nan nan nan nan nan nan nan nan nan]
[ nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
nan nan nan nan nan nan nan nan nan nan]
Versions
Linux-4.8.17-040817-generic-x86_64-with-debian-stretch-sid
Python 3.6.1 |Continuum Analytics, Inc.| (default, May 11 2017, 13:09:58)
[GCC 4.4.7 20120313 (Red Hat 4.4.7-1)]
NumPy 1.13.1
SciPy 0.19.1
Scikit-Learn 0.20.dev0