Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 509b9ff

Browse files
RMSLE (root mean squared log error) (#20326)
Co-authored-by: Olivier Grisel <[email protected]>
1 parent 1e2c899 commit 509b9ff

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

doc/whats_new/v1.0.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,10 @@ Changelog
427427
quantile regression. :pr:`19415` by :user:`Xavier Dupré <sdpython>`
428428
and :user:`Oliver Grisel <ogrisel>`.
429429

430+
- |Feature| :func:`metrics.mean_squared_log_error` now supports
431+
`squared=False`.
432+
:pr:`20326` by :user:`Uttam kumar <helper-uttam>`.
433+
430434
- |Efficiency| Improved speed of :func:`metrics.confusion_matrix` when labels
431435
are integral.
432436
:pr:`9843` by :user:`Jon Crall <Erotemic>`.

sklearn/metrics/_regression.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
# Konstantin Shmelkov <[email protected]>
2222
# Christian Lorentzen <[email protected]>
2323
# Ashutosh Hathidara <[email protected]>
24+
# Uttam kumar <[email protected]>
2425
# License: BSD 3 clause
2526

2627
import numpy as np
@@ -437,7 +438,7 @@ def mean_squared_error(
437438

438439

439440
def mean_squared_log_error(
440-
y_true, y_pred, *, sample_weight=None, multioutput="uniform_average"
441+
y_true, y_pred, *, sample_weight=None, multioutput="uniform_average", squared=True
441442
):
442443
"""Mean squared logarithmic error regression loss.
443444
@@ -466,6 +467,9 @@ def mean_squared_log_error(
466467
467468
'uniform_average' :
468469
Errors of all outputs are averaged with uniform weight.
470+
squared : bool, default=True
471+
If True returns MSLE (mean squared log error) value.
472+
If False returns RMSLE (root mean squared log error) value.
469473
470474
Returns
471475
-------
@@ -480,6 +484,8 @@ def mean_squared_log_error(
480484
>>> y_pred = [2.5, 5, 4, 8]
481485
>>> mean_squared_log_error(y_true, y_pred)
482486
0.039...
487+
>>> mean_squared_log_error(y_true, y_pred, squared=False)
488+
0.199...
483489
>>> y_true = [[0.5, 1], [1, 2], [7, 6]]
484490
>>> y_pred = [[0.5, 2], [1, 2.5], [8, 8]]
485491
>>> mean_squared_log_error(y_true, y_pred)
@@ -505,6 +511,7 @@ def mean_squared_log_error(
505511
np.log1p(y_pred),
506512
sample_weight=sample_weight,
507513
multioutput=multioutput,
514+
squared=squared,
508515
)
509516

510517

0 commit comments

Comments
 (0)