diff --git a/scikits/learn/base.py b/scikits/learn/base.py
index 7deca23023012..77d6fbf71d5f5 100644
--- a/scikits/learn/base.py
+++ b/scikits/learn/base.py
@@ -10,7 +10,7 @@
 
 import numpy as np
 
-from .metrics import explained_variance_score
+from .metrics import r2_score
 
 ################################################################################
 def clone(estimator, safe=True):
@@ -236,7 +236,7 @@ class RegressorMixin(object):
         """
 
     def score(self, X, y):
-        """ Returns the explained variance of the prediction
+        """ Returns the coefficient of determination of the prediction
 
         Parameters
         ----------
@@ -249,7 +249,7 @@ def score(self, X, y):
         -------
         z : float
         """
-        return explained_variance_score(y, self.predict(X))
+        return r2_score(y, self.predict(X))
 
 
 ################################################################################
diff --git a/scikits/learn/grid_search.py b/scikits/learn/grid_search.py
index e2dfe27639912..9a24fd008fb72 100644
--- a/scikits/learn/grid_search.py
+++ b/scikits/learn/grid_search.py
@@ -176,7 +176,7 @@ class GridSearchCV(BaseEstimator):
     >>> svr = SVR()
     >>> clf = GridSearchCV(svr, parameters, n_jobs=1)
     >>> clf.fit(X, y).predict([[-0.8, -1]])
-    array([ 1.14])
+    array([ 1.13101459])
     """
 
     def __init__(self, estimator, param_grid, loss_func=None, score_func=None,
diff --git a/scikits/learn/linear_model/base.py b/scikits/learn/linear_model/base.py
index d67e8a5b28ea8..e0b6c62f74020 100644
--- a/scikits/learn/linear_model/base.py
+++ b/scikits/learn/linear_model/base.py
@@ -12,7 +12,7 @@
 import numpy as np
 
 from ..base import BaseEstimator, RegressorMixin
-from ..metrics import explained_variance_score
+from ..metrics import r2_score
 
 ###
 ### TODO: intercept for all models
@@ -41,9 +41,9 @@ def predict(self, X):
         X = np.asanyarray(X)
         return np.dot(X, self.coef_) + self.intercept_
 
-    def _explained_variance(self, X, y):
+    def _r2_score(self, X, y):
         """Compute explained variance a.k.a. r^2"""
-        return explained_variance_score(y, self.predict(X))
+        return r2_score(y, self.predict(X))
 
     @staticmethod
     def _center_data(X, y, fit_intercept):
diff --git a/scikits/learn/linear_model/bayes.py b/scikits/learn/linear_model/bayes.py
index d040f52e7dda4..81e8c69ebcd9f 100644
--- a/scikits/learn/linear_model/bayes.py
+++ b/scikits/learn/linear_model/bayes.py
@@ -207,7 +207,7 @@ def fit(self, X, y, **params):
         self._set_intercept(Xmean, ymean)
 
         # Store explained variance for __str__
-        self.explained_variance_ = self._explained_variance(X, y)
+        self.r2_score_ = self._r2_score(X, y)
 
         return self
 
@@ -420,5 +420,5 @@ def fit(self, X, y, **params):
         self._set_intercept(Xmean, ymean)
 
         # Store explained variance for __str__
-        self.explained_variance_ = self._explained_variance(X, y)
+        self.r2_score_ = self._r2_score(X, y)
         return self
diff --git a/scikits/learn/linear_model/coordinate_descent.py b/scikits/learn/linear_model/coordinate_descent.py
index 68ba8184549c6..7b65f96e5b461 100644
--- a/scikits/learn/linear_model/coordinate_descent.py
+++ b/scikits/learn/linear_model/coordinate_descent.py
@@ -102,7 +102,7 @@ def fit(self, X, y, maxit=1000, tol=1e-4, coef_init=None, **params):
                     ' to increase the number of interations')
 
         # Store explained variance for __str__
-        self.explained_variance_ = self._explained_variance(X, y)
+        self.r2_score_ = self._r2_score(X, y)
 
         # return self for chaining fit and predict calls
         return self
@@ -354,7 +354,7 @@ def fit(self, X, y, cv=None, **fit_params):
 
         self.coef_ = model.coef_
         self.intercept_ = model.intercept_
-        self.explained_variance_ = model.explained_variance_
+        self.r2_score_ = model.r2_score_
         self.alpha = model.alpha
         self.alphas = np.asarray(alphas)
         return self
diff --git a/scikits/learn/metrics.py b/scikits/learn/metrics.py
index f79b4e783c09d..c2e646146201d 100644
--- a/scikits/learn/metrics.py
+++ b/scikits/learn/metrics.py
@@ -499,7 +499,7 @@ def precision_recall_curve(y_true, probas_pred):
 def explained_variance_score(y_true, y_pred):
     """Explained variance regression score function
 
-    Best possible score is 1.0, lower values are worst.
+    Best possible score is 1.0, lower values are worse.
 
     Note: the explained variance is not a symmetric function.
 
@@ -512,6 +512,25 @@ def explained_variance_score(y_true, y_pred):
     y_pred : array-like
     """
     return 1 - np.var(y_true - y_pred) / np.var(y_true)
+
+
+def r2_score(y_true, y_pred):
+    """R^2 (coefficient of determination) regression score function
+
+    Best possible score is 1.0, lower values are worse.
+
+    Note: not a symmetric function.
+
+    Returns the R^2 score.
+
+    Parameters
+    ----------
+    y_true : array-like
+
+    y_pred : array-like
+    """
+    return 1 - ((y_true - y_pred)**2).sum() / ((y_true - y_true.mean())**2).sum()
+
 
 
 ###############################################################################
diff --git a/scikits/learn/tests/test_metrics.py b/scikits/learn/tests/test_metrics.py
index 93fac32958ae1..f3c8896651d5c 100644
--- a/scikits/learn/tests/test_metrics.py
+++ b/scikits/learn/tests/test_metrics.py
@@ -2,7 +2,7 @@
 
 import numpy as np
 import nose
-from numpy.testing import assert_
+from nose.tools import assert_true
 from numpy.testing import assert_array_almost_equal
 from numpy.testing import assert_array_equal
 from numpy.testing import assert_equal, assert_almost_equal
@@ -13,6 +13,7 @@
 from ..metrics import classification_report
 from ..metrics import confusion_matrix
 from ..metrics import explained_variance_score
+from ..metrics import r2_score
 from ..metrics import f1_score
 from ..metrics import mean_square_error
 from ..metrics import precision_recall_curve
@@ -222,6 +223,9 @@ def test_losses():
     assert_almost_equal(explained_variance_score(y_true, y_pred), -0.04, 2)
     assert_almost_equal(explained_variance_score(y_true, y_true), 1.00, 2)
 
+    assert_almost_equal(r2_score(y_true, y_pred), -0.04, 2)
+    assert_almost_equal(r2_score(y_true, y_true), 1.00, 2)
+
 
 def test_symmetry():
     """Test the symmetry of score and loss functions"""
@@ -233,8 +237,10 @@ def test_symmetry():
 
     assert_almost_equal(mean_square_error(y_true, y_pred),
                         mean_square_error(y_pred, y_true))
     # not symmetric
-    assert_(explained_variance_score(y_true, y_pred) != \
+    assert_true(explained_variance_score(y_true, y_pred) != \
             explained_variance_score(y_pred, y_true))
+    assert_true(r2_score(y_true, y_pred) != \
+                r2_score(y_pred, y_true))
     # FIXME: precision and recall aren't symmetric either
 
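Side note for reviewers (not part of the patch): the two formulas in `metrics.py` above differ exactly when the residuals have a nonzero mean. `explained_variance_score` looks only at the variance of the residuals, so a constant prediction offset is invisible to it, while `r2_score` compares the summed squared error to the total sum of squares and penalizes the offset. A minimal standalone sketch, using the two formulas as written in the patch (the toy data is made up for illustration):

```python
import numpy as np

def explained_variance(y_true, y_pred):
    # Formula from explained_variance_score in metrics.py above.
    return 1 - np.var(y_true - y_pred) / np.var(y_true)

def r2(y_true, y_pred):
    # Formula from the r2_score function added by this patch.
    return 1 - ((y_true - y_pred)**2).sum() / ((y_true - y_true.mean())**2).sum()

y_true = np.array([1.0, 2.0, 3.0, 4.0])
y_pred = y_true + 0.5  # constant offset: residuals are all 0.5, so their variance is 0

print(explained_variance(y_true, y_pred))  # 1.0 -- the systematic bias goes unnoticed
print(r2(y_true, y_pred))                  # 0.8 -- the systematic bias is penalized
```

This is the motivation for switching `RegressorMixin.score` to `r2_score`: a regressor with a systematically biased prediction no longer receives a perfect score.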