2121# Konstantin Shmelkov <[email protected] > 2222# Christian Lorentzen <[email protected] > 2323# Ashutosh Hathidara <[email protected] > 24+ 2425# License: BSD 3 clause
2526
2627import numpy as np
@@ -437,7 +438,7 @@ def mean_squared_error(
437438
438439
439440def mean_squared_log_error (
440- y_true , y_pred , * , sample_weight = None , multioutput = "uniform_average"
441+ y_true , y_pred , * , sample_weight = None , multioutput = "uniform_average" , squared = True
441442):
442443 """Mean squared logarithmic error regression loss.
443444
@@ -466,6 +467,9 @@ def mean_squared_log_error(
466467
467468 'uniform_average' :
468469 Errors of all outputs are averaged with uniform weight.
470+ squared : bool, default=True
471+ If True returns MSLE (mean squared log error) value.
472+ If False returns RMSLE (root mean squared log error) value.
469473
470474 Returns
471475 -------
@@ -480,6 +484,8 @@ def mean_squared_log_error(
480484 >>> y_pred = [2.5, 5, 4, 8]
481485 >>> mean_squared_log_error(y_true, y_pred)
482486 0.039...
487+ >>> mean_squared_log_error(y_true, y_pred, squared=False)
488+ 0.199...
483489 >>> y_true = [[0.5, 1], [1, 2], [7, 6]]
484490 >>> y_pred = [[0.5, 2], [1, 2.5], [8, 8]]
485491 >>> mean_squared_log_error(y_true, y_pred)
@@ -505,6 +511,7 @@ def mean_squared_log_error(
505511 np .log1p (y_pred ),
506512 sample_weight = sample_weight ,
507513 multioutput = multioutput ,
514+ squared = squared ,
508515 )
509516
510517
0 commit comments