diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst index e4bff3c124dc5..eaec51b0e5a6a 100644 --- a/doc/whats_new/v1.0.rst +++ b/doc/whats_new/v1.0.rst @@ -427,6 +427,10 @@ Changelog quantile regression. :pr:`19415` by :user:`Xavier Dupré ` and :user:`Oliver Grisel `. +- |Feature| :func:`metrics.mean_squared_log_error` now supports + `squared=False`. + :pr:`20326` by :user:`Uttam kumar `. + - |Efficiency| Improved speed of :func:`metrics.confusion_matrix` when labels are integral. :pr:`9843` by :user:`Jon Crall `. diff --git a/sklearn/metrics/_regression.py b/sklearn/metrics/_regression.py index a2d7fd0d41bcb..f20308b6c5660 100644 --- a/sklearn/metrics/_regression.py +++ b/sklearn/metrics/_regression.py @@ -21,6 +21,7 @@ # Konstantin Shmelkov # Christian Lorentzen # Ashutosh Hathidara +# Uttam kumar # License: BSD 3 clause import numpy as np @@ -437,7 +438,7 @@ def mean_squared_error( def mean_squared_log_error( - y_true, y_pred, *, sample_weight=None, multioutput="uniform_average" + y_true, y_pred, *, sample_weight=None, multioutput="uniform_average", squared=True ): """Mean squared logarithmic error regression loss. @@ -466,6 +467,9 @@ def mean_squared_log_error( 'uniform_average' : Errors of all outputs are averaged with uniform weight. + squared : bool, default=True + If True returns MSLE (mean squared log error) value. + If False returns RMSLE (root mean squared log error) value. Returns ------- @@ -480,6 +484,8 @@ def mean_squared_log_error( >>> y_pred = [2.5, 5, 4, 8] >>> mean_squared_log_error(y_true, y_pred) 0.039... + >>> mean_squared_log_error(y_true, y_pred, squared=False) + 0.199... >>> y_true = [[0.5, 1], [1, 2], [7, 6]] >>> y_pred = [[0.5, 2], [1, 2.5], [8, 8]] >>> mean_squared_log_error(y_true, y_pred) @@ -505,6 +511,7 @@ def mean_squared_log_error( np.log1p(y_pred), sample_weight=sample_weight, multioutput=multioutput, + squared=squared, )