From 06891ebba9cd6a4bad682df771a6b5031b053532 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sat, 14 Jul 2018 16:56:11 -0400 Subject: [PATCH 01/38] WIP - First draft on Yeo-Johnson transform --- sklearn/preprocessing/data.py | 64 +++++++++++++++++++++++++++++++---- 1 file changed, 58 insertions(+), 6 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index e3c72d6884591..8fd6b6b202c7e 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -17,6 +17,7 @@ import numpy as np from scipy import sparse from scipy import stats +from scipy.optimize import brent from ..base import BaseEstimator, TransformerMixin from ..externals import six @@ -2491,10 +2492,15 @@ def fit(self, X, y=None): transformed = [] for col in X.T: - # the computation of lambda is influenced by NaNs and we need to - # get rid of them to compute them. - _, lmbda = stats.boxcox(col[~np.isnan(col)], lmbda=None) - col_trans = boxcox(col, lmbda) + if self.method == 'box-cox': + # the computation of lambda is influenced by NaNs and we need to + # get rid of them to compute them. + _, lmbda = stats.boxcox(col[~np.isnan(col)], lmbda=None) + col_trans = boxcox(col, lmbda) + else: # neo-johnson + lmbda = 1 #self._yeo_johnson_optimize(col) + col_trans = self._yeo_johnson_transform(col, lmbda) + self.lambdas_.append(lmbda) transformed.append(col_trans) @@ -2507,6 +2513,49 @@ def fit(self, X, y=None): return self + def _yeo_johnson_loglikelihood(self, x, lmbda): + psi = self._yeo_johnson_transform(x, lmbda) + n = x.shape[0] + + # Estimated mean and variance of the normal distribution + mu = psi.sum() / n + sig_sq = np.power(psi - mu, 2).sum() / n + + loglike = (-.5 / sig_sq) * np.power(psi - mu, 2).sum() + loglike += (lmbda - 1) * (np.sign(x) * np.log(np.abs(x) + 1)).sum() + + return loglike + + def _yeo_johnson_transform(self, x, lmbda): + + out = np.zeros(shape=x.shape) + pos = (x >= 0) # binary mask + + # when x >= 0 + if lmbda < 1e-19: + out[pos] = np.log(x[pos] + 1) + else: #lmbda != 0 + out[pos] = (np.power(x[pos] + 1, lmbda) - 1) / lmbda + + # when x < 0 + if lmbda < 2 - 1e-19: + out[~pos] = -(np.power(-x[~pos] + 1, 2 - lmbda) - 1) / (2 - lmbda) + else: # lmbda == 2 + out[~pos] = -np.log(-x[~pos] + 1) + + return out + + def _yeo_johnson_optimize(self, x): + """Find and return optimal lambda parameter of the transform by maximum + likelihood optimization. + """ + + def objective(lmbda, x): + # solver is a minimizer and we need to maximise the log likelihood + return -self._yeo_johnson_loglikelihood(x, lmbda) + + return brent(objective, brack=(0, 2), args=(x,)) + def transform(self, X): """Apply the power transform to each feature using the fitted lambdas. @@ -2518,8 +2567,11 @@ def transform(self, X): check_is_fitted(self, 'lambdas_') X = self._check_input(X, check_positive=True, check_shape=True) + trans_fun = {'box-cox': boxcox, + 'yeo-johnson': self._yeo_johnson_transform + }[self.method] for i, lmbda in enumerate(self.lambdas_): - X[:, i] = boxcox(X[:, i], lmbda) + X[:, i] = trans_fun(X[:, i], lmbda) if self.standardize: X = self._scaler.transform(X) @@ -2590,7 +2642,7 @@ def _check_input(self, X, check_positive=False, check_shape=False, "than fitting data. Should have {n}, data has {m}" .format(n=len(self.lambdas_), m=X.shape[1])) - valid_methods = ('box-cox',) + valid_methods = ('box-cox', 'yeo-johnson') if check_method and self.method not in valid_methods: raise ValueError("'method' must be one of {}, " "got {} instead." From a88d168ec89ce00259bc5efe384ff238d29d8531 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sat, 14 Jul 2018 19:01:10 -0400 Subject: [PATCH 02/38] Fixed lambda param optimization The issue was from an error in the log likelihood function --- sklearn/preprocessing/data.py | 46 +++++++++++++++++++---------------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 8fd6b6b202c7e..aefb410390970 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -17,7 +17,7 @@ import numpy as np from scipy import sparse from scipy import stats -from scipy.optimize import brent +from scipy import optimize from ..base import BaseEstimator, TransformerMixin from ..externals import six @@ -2498,7 +2498,7 @@ def fit(self, X, y=None): _, lmbda = stats.boxcox(col[~np.isnan(col)], lmbda=None) col_trans = boxcox(col, lmbda) else: # neo-johnson - lmbda = 1 #self._yeo_johnson_optimize(col) + lmbda = self._yeo_johnson_optimize(col) col_trans = self._yeo_johnson_transform(col, lmbda) self.lambdas_.append(lmbda) @@ -2513,24 +2513,14 @@ def fit(self, X, y=None): return self - def _yeo_johnson_loglikelihood(self, x, lmbda): - psi = self._yeo_johnson_transform(x, lmbda) - n = x.shape[0] - - # Estimated mean and variance of the normal distribution - mu = psi.sum() / n - sig_sq = np.power(psi - mu, 2).sum() / n - - loglike = (-.5 / sig_sq) * np.power(psi - mu, 2).sum() - loglike += (lmbda - 1) * (np.sign(x) * np.log(np.abs(x) + 1)).sum() - - return loglike - def _yeo_johnson_transform(self, x, lmbda): out = np.zeros(shape=x.shape) pos = (x >= 0) # binary mask + # Note: we're comparing lmbda to 1e-19 instead of strict equality to 0. + # See scipy/special/_boxcox.pxd for a rationale behind this + # when x >= 0 if lmbda < 1e-19: out[pos] = np.log(x[pos] + 1) @@ -2546,15 +2536,29 @@ def _yeo_johnson_transform(self, x, lmbda): return out def _yeo_johnson_optimize(self, x): - """Find and return optimal lambda parameter of the transform by maximum - likelihood optimization. + """Find and return optimal lambda parameter of the Yeo-Johnson + transform by MLE, for observed data x. + + Like for coxbox, MLE is done via the brent optimizer. """ - def objective(lmbda, x): - # solver is a minimizer and we need to maximise the log likelihood - return -self._yeo_johnson_loglikelihood(x, lmbda) + def _neg_log_likelihood(lmbda): + """Return the negative log likelihood of the observed data x as a + function of lambda.""" + psi = self._yeo_johnson_transform(x, lmbda) + n = x.shape[0] + + # Estimated mean and variance of the normal distribution + mu = psi.sum() / n + sig_sq = np.power(psi - mu, 2).sum() / n + + loglike = -n / 2 * np.log(sig_sq) + loglike += (lmbda - 1) * (np.sign(x) * np.log(np.abs(x) + 1)).sum() + + return -loglike - return brent(objective, brack=(0, 2), args=(x,)) + # choosing backet -2, 2 like for boxcox + return optimize.brent(_neg_log_likelihood, brack=(-2, 2)) def transform(self, X): """Apply the power transform to each feature using the fitted lambdas. From ee09d7fcf0c7426640dcbaef6e93c3f870a21baa Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sun, 15 Jul 2018 10:38:14 -0400 Subject: [PATCH 03/38] Some first tests Need to write inverse_transform to continue --- sklearn/preprocessing/tests/test_data.py | 39 +++++++++++++++++++----- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 2ff9dfd776a03..30548b0cb8ba2 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -2003,8 +2003,9 @@ def test_quantile_transform_valid_axis(): ". Got axis=2", quantile_transform, X.T, axis=2) -def test_power_transformer_notfitted(): - pt = PowerTransformer(method='box-cox') +@pytest.mark.parametrize("method", ['box-cox', 'yeo-johnson']) +def test_power_transformer_notfitted(method): + pt = PowerTransformer(method=method) X = np.abs(X_1col) assert_raises(NotFittedError, pt.transform, X) assert_raises(NotFittedError, pt.inverse_transform, X) @@ -2062,10 +2063,11 @@ def test_power_transformer_2d(): def test_power_transformer_strictly_positive_exception(): + # Exceptions should be raised for negative arrays and zero arrays when + # method is coxbox + pt = PowerTransformer(method='box-cox') pt.fit(np.abs(X_2d)) - - # Exceptions should be raised for negative arrays and zero arrays X_with_negatives = X_2d not_positive_message = 'strictly positive' @@ -2076,7 +2078,7 @@ def test_power_transformer_strictly_positive_exception(): pt.fit, X_with_negatives) assert_raise_message(ValueError, not_positive_message, - power_transform, X_with_negatives) + power_transform, X_with_negatives, 'box-cox') assert_raise_message(ValueError, not_positive_message, pt.transform, np.zeros(X_2d.shape)) @@ -2085,11 +2087,22 @@ def test_power_transformer_strictly_positive_exception(): pt.fit, np.zeros(X_2d.shape)) assert_raise_message(ValueError, not_positive_message, - power_transform, np.zeros(X_2d.shape)) + power_transform, np.zeros(X_2d.shape), 'box-cox') + # It should not raise any error for yeo-johnson + pt = PowerTransformer(method='yeo-johnson') + pt.fit(np.abs(X_2d)) + pt.transform(X_with_negatives) + pt.fit(X_with_negatives) + power_transform(X_with_negatives, method='yeo-johnson') + pt.transform(np.zeros(X_2d.shape)) + pt.fit(np.zeros(X_2d.shape)) + power_transform(np.zeros(X_2d.shape), method='yeo-johnson') -def test_power_transformer_shape_exception(): - pt = PowerTransformer(method='box-cox') + +@pytest.mark.parametrize("method", ['box-cox', 'yeo-johnson']) +def test_power_transformer_shape_exception(method): + pt = PowerTransformer(method=method) X = np.abs(X_2d) pt.fit(X) @@ -2122,3 +2135,13 @@ def test_power_transformer_lambda_zero(): pt.lambdas_ = np.array([0]) X_trans = pt.transform(X) assert_array_almost_equal(pt.inverse_transform(X_trans), X) + + +def test_power_transformer_lambda_one(): + # Make sure lambda = 1 corresponds to the identity for yeo-johnson + pt = PowerTransformer(method='yeo-johnson', standardize=False) + X = np.abs(X_2d)[:, 0:1] + + pt.lambdas_ = np.array([1]) + X_trans = pt.transform(X) + assert_array_almost_equal(X_trans, X) From aea084260d63579adc542a465bd25263d3998bcc Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sun, 15 Jul 2018 10:41:45 -0400 Subject: [PATCH 04/38] Put helper method for yeo-johnson at the end --- sklearn/preprocessing/data.py | 94 +++++++++++++++++------------------ 1 file changed, 47 insertions(+), 47 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index aefb410390970..65b4c1ce845bb 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2513,53 +2513,6 @@ def fit(self, X, y=None): return self - def _yeo_johnson_transform(self, x, lmbda): - - out = np.zeros(shape=x.shape) - pos = (x >= 0) # binary mask - - # Note: we're comparing lmbda to 1e-19 instead of strict equality to 0. - # See scipy/special/_boxcox.pxd for a rationale behind this - - # when x >= 0 - if lmbda < 1e-19: - out[pos] = np.log(x[pos] + 1) - else: #lmbda != 0 - out[pos] = (np.power(x[pos] + 1, lmbda) - 1) / lmbda - - # when x < 0 - if lmbda < 2 - 1e-19: - out[~pos] = -(np.power(-x[~pos] + 1, 2 - lmbda) - 1) / (2 - lmbda) - else: # lmbda == 2 - out[~pos] = -np.log(-x[~pos] + 1) - - return out - - def _yeo_johnson_optimize(self, x): - """Find and return optimal lambda parameter of the Yeo-Johnson - transform by MLE, for observed data x. - - Like for coxbox, MLE is done via the brent optimizer. - """ - - def _neg_log_likelihood(lmbda): - """Return the negative log likelihood of the observed data x as a - function of lambda.""" - psi = self._yeo_johnson_transform(x, lmbda) - n = x.shape[0] - - # Estimated mean and variance of the normal distribution - mu = psi.sum() / n - sig_sq = np.power(psi - mu, 2).sum() / n - - loglike = -n / 2 * np.log(sig_sq) - loglike += (lmbda - 1) * (np.sign(x) * np.log(np.abs(x) + 1)).sum() - - return -loglike - - # choosing backet -2, 2 like for boxcox - return optimize.brent(_neg_log_likelihood, brack=(-2, 2)) - def transform(self, X): """Apply the power transform to each feature using the fitted lambdas. @@ -2613,6 +2566,53 @@ def inverse_transform(self, X): return X + def _yeo_johnson_transform(self, x, lmbda): + + out = np.zeros(shape=x.shape) + pos = (x >= 0) # binary mask + + # Note: we're comparing lmbda to 1e-19 instead of strict equality to 0. + # See scipy/special/_boxcox.pxd for a rationale behind this + + # when x >= 0 + if lmbda < 1e-19: + out[pos] = np.log(x[pos] + 1) + else: #lmbda != 0 + out[pos] = (np.power(x[pos] + 1, lmbda) - 1) / lmbda + + # when x < 0 + if lmbda < 2 - 1e-19: + out[~pos] = -(np.power(-x[~pos] + 1, 2 - lmbda) - 1) / (2 - lmbda) + else: # lmbda == 2 + out[~pos] = -np.log(-x[~pos] + 1) + + return out + + def _yeo_johnson_optimize(self, x): + """Find and return optimal lambda parameter of the Yeo-Johnson + transform by MLE, for observed data x. + + Like for coxbox, MLE is done via the brent optimizer. + """ + + def _neg_log_likelihood(lmbda): + """Return the negative log likelihood of the observed data x as a + function of lambda.""" + psi = self._yeo_johnson_transform(x, lmbda) + n = x.shape[0] + + # Estimated mean and variance of the normal distribution + mu = psi.sum() / n + sig_sq = np.power(psi - mu, 2).sum() / n + + loglike = -n / 2 * np.log(sig_sq) + loglike += (lmbda - 1) * (np.sign(x) * np.log(np.abs(x) + 1)).sum() + + return -loglike + + # choosing backet -2, 2 like for boxcox + return optimize.brent(_neg_log_likelihood, brack=(-2, 2)) + def _check_input(self, X, check_positive=False, check_shape=False, check_method=False): """Validate the input before fit and transform. From fba12eba86a147e15f53a77b506d424ebf4f57a1 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sun, 15 Jul 2018 11:50:37 -0400 Subject: [PATCH 05/38] Added inverse transform + some tests --- sklearn/preprocessing/data.py | 41 +++++++++++++++++++----- sklearn/preprocessing/tests/test_data.py | 15 +++++++++ 2 files changed, 48 insertions(+), 8 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 65b4c1ce845bb..6b408719a4bba 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2497,7 +2497,7 @@ def fit(self, X, y=None): # get rid of them to compute them. _, lmbda = stats.boxcox(col[~np.isnan(col)], lmbda=None) col_trans = boxcox(col, lmbda) - else: # neo-johnson + else: # yeo-johnson lmbda = self._yeo_johnson_optimize(col) col_trans = self._yeo_johnson_transform(col, lmbda) @@ -2556,16 +2556,41 @@ def inverse_transform(self, X): if self.standardize: X = self._scaler.inverse_transform(X) + inv_fun = {'box-cox': self._box_cox_inverse_tranform, + 'yeo-johnson': self._yeo_johnson_inverse_transform + }[self.method] for i, lmbda in enumerate(self.lambdas_): - x = X[:, i] - if lmbda == 0: - x_inv = np.exp(x) - else: - x_inv = (x * lmbda + 1) ** (1 / lmbda) - X[:, i] = x_inv + X[:, i] = inv_fun(X[:, i], lmbda) return X + def _box_cox_inverse_tranform(self, x, lmbda): + if lmbda == 0: + x_inv = np.exp(x) + else: + x_inv = (x * lmbda + 1) ** (1 / lmbda) + + return x_inv + + def _yeo_johnson_inverse_transform(self, x, lmbda): + x_inv = np.zeros(x.shape) + pos = x >= 0 + + # when x >= 0 + if lmbda < 1e-19: + x_inv[pos] = np.exp(x[pos]) - 1 + else: # lmbda != 0 + x_inv[pos] = np.power(x[pos] * lmbda + 1, 1 / lmbda) - 1 + + # when x < 0 + if lmbda < 2 - 1e-19: + x_inv[~pos] = 1 - np.power(-(2 - lmbda) * x[~pos] + 1, + 1 / (2 - lmbda)) + else: # lmbda == 2 + x_inv[~pos] = 1 - np.exp(-x[~pos]) + + return x_inv + def _yeo_johnson_transform(self, x, lmbda): out = np.zeros(shape=x.shape) @@ -2577,7 +2602,7 @@ def _yeo_johnson_transform(self, x, lmbda): # when x >= 0 if lmbda < 1e-19: out[pos] = np.log(x[pos] + 1) - else: #lmbda != 0 + else: # lmbda != 0 out[pos] = (np.power(x[pos] + 1, lmbda) - 1) / lmbda # when x < 0 diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 30548b0cb8ba2..576f79b90104a 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -2011,6 +2011,21 @@ def test_power_transformer_notfitted(method): assert_raises(NotFittedError, pt.inverse_transform, X) +@pytest.mark.parametrize('standardize', [True, False]) +@pytest.mark.parametrize('X', [X_1col, X_2d]) +def test_power_transformer_inverse(standardize, X): + # Make sure we get the original input when applying transform and then + # inverse transform + pt = PowerTransformer(method='yeo-johnson', standardize=standardize) + X_trans = pt.fit_transform(X) + assert_almost_equal(X, pt.inverse_transform(X_trans)) + + X = np.abs(X) + pt = PowerTransformer(method='box-cox', standardize=standardize) + X_trans = pt.fit_transform(X) + assert_almost_equal(X, pt.inverse_transform(X_trans)) + + def test_power_transformer_1d(): X = np.abs(X_1col) From ed5a411610ccdf9387698da091acec27f4aea33a Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sun, 15 Jul 2018 15:26:16 -0400 Subject: [PATCH 06/38] Added test for the optimization procedures --- sklearn/preprocessing/tests/test_data.py | 25 ++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 576f79b90104a..ed94f01c3d810 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -2160,3 +2160,28 @@ def test_power_transformer_lambda_one(): pt.lambdas_ = np.array([1]) X_trans = pt.transform(X) assert_array_almost_equal(X_trans, X) + + +@pytest.mark.parametrize("method, lmbda", [('box-cox', .5), + ('yeo-johnson', .1)]) +def test_optimization_power_transformer(method, lmbda): + """Test the optimization procedure + + - set a predefined value for lambda + - apply inverse_transform to a normal dist (we get X_inv) + - apply fit_transform to X_inv (we get X_inv_trans) + - check that X_inv_trans is roughly equal to X + """ + + rng = np.random.RandomState(0) + n_samples = 1000 + X = rng.normal(size=(n_samples, 1)) + + pt = PowerTransformer(method=method, standardize=False) + pt.lambdas_ = [lmbda] + X_inv = pt.inverse_transform(X) + pt.lambdas_ = [9999] # just to make sure + X_inv_trans = pt.fit_transform(X_inv) + + assert_almost_equal(0, np.linalg.norm(X - X_inv_trans) / n_samples, + decimal=2) From 8bab32e3e1ac7e43a812fd2a9d605b812c05fec3 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sun, 15 Jul 2018 15:45:59 -0400 Subject: [PATCH 07/38] Created _box_cox_optimize method for better code symmetry --- sklearn/preprocessing/data.py | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 6b408719a4bba..b90b1fdd6fbc2 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2491,17 +2491,18 @@ def fit(self, X, y=None): self.lambdas_ = [] transformed = [] - for col in X.T: - if self.method == 'box-cox': - # the computation of lambda is influenced by NaNs and we need to - # get rid of them to compute them. - _, lmbda = stats.boxcox(col[~np.isnan(col)], lmbda=None) - col_trans = boxcox(col, lmbda) - else: # yeo-johnson - lmbda = self._yeo_johnson_optimize(col) - col_trans = self._yeo_johnson_transform(col, lmbda) + opt_fun = {'box-cox': self._box_cox_optimize, + 'yeo-johnson': self._yeo_johnson_optimize + }[self.method] + trans_fun = {'box-cox': boxcox, + 'yeo-johnson': self._yeo_johnson_transform + }[self.method] + for col in X.T: + lmbda = opt_fun(col) self.lambdas_.append(lmbda) + + col_trans = trans_fun(col, lmbda) transformed.append(col_trans) self.lambdas_ = np.array(self.lambdas_) @@ -2613,6 +2614,18 @@ def _yeo_johnson_transform(self, x, lmbda): return out + def _box_cox_optimize(self, x): + """Find and return optimal lambda parameter of the Box-Cox transform by + MLE, for observed data x. + + We here use scipy builtins which uses the brent optimizer. + """ + # the computation of lambda is influenced by NaNs and we need to + # get rid of them to compute them. + _, lmbda = stats.boxcox(x[~np.isnan(x)], lmbda=None) + + return lmbda + def _yeo_johnson_optimize(self, x): """Find and return optimal lambda parameter of the Yeo-Johnson transform by MLE, for observed data x. From 0525bab15b91c4e8bf7bf7ad72ea29fe8ce48978 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sun, 15 Jul 2018 16:30:03 -0400 Subject: [PATCH 08/38] Opt for yeo-johnson not influenced by Nan Also added related test --- sklearn/preprocessing/data.py | 7 +++++-- sklearn/preprocessing/tests/test_data.py | 19 +++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index b90b1fdd6fbc2..6578db0d81df6 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2620,8 +2620,8 @@ def _box_cox_optimize(self, x): We here use scipy builtins which uses the brent optimizer. """ - # the computation of lambda is influenced by NaNs and we need to - # get rid of them to compute them. + # the computation of lambda is influenced by NaNs so we need to + # get rid of them _, lmbda = stats.boxcox(x[~np.isnan(x)], lmbda=None) return lmbda @@ -2648,6 +2648,9 @@ def _neg_log_likelihood(lmbda): return -loglike + # the computation of lambda is influenced by NaNs so we need to + # get rid of them + x = x[~np.isnan(x)] # choosing backet -2, 2 like for boxcox return optimize.brent(_neg_log_likelihood, brack=(-2, 2)) diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index ed94f01c3d810..4c6e112345c5a 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -2185,3 +2185,22 @@ def test_optimization_power_transformer(method, lmbda): assert_almost_equal(0, np.linalg.norm(X - X_inv_trans) / n_samples, decimal=2) + + +@pytest.mark.parametrize('method', ['box-cox', 'yeo-johnson']) +def test_power_transformer_nans(method): + # Make sure lambda estimation is not influenced by NaN values + # and that transform() supports NaN silently + + X = np.abs(X_1col) + pt = PowerTransformer(method=method) + pt.fit(X) + lmbda_no_nans = pt.lambdas_[0] + + # concat nans at the end and check lambda stays the same + X = np.concatenate([X, np.full_like(X, np.nan)]) + pt.fit(X) + lmbda_nans = pt.lambdas_[0] + + assert_equal(lmbda_no_nans, lmbda_nans) + pt.transform(X) From 8e187c4ecf7d660904cd39162c20ed09774fbad2 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sun, 15 Jul 2018 16:51:50 -0400 Subject: [PATCH 09/38] Added doc --- sklearn/preprocessing/data.py | 43 +++++++++++++++++++++++++++-------- 1 file changed, 33 insertions(+), 10 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 6578db0d81df6..7c3164a73e37b 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2398,10 +2398,12 @@ class PowerTransformer(BaseEstimator, TransformerMixin): modeling issues related to heteroscedasticity (non-constant variance), or other situations where normality is desired. - Currently, PowerTransformer supports the Box-Cox transform. Box-Cox - requires input data to be strictly positive. The optimal parameter - for stabilizing variance and minimizing skewness is estimated through - maximum likelihood. + Currently, PowerTransformer supports the Box-Cox transform and the + Yeo-Johson transform. The optimal parameter for stabilizing variance and + minimizing skewness is estimated through maximum likelihood. + + Box-Cox requires input data to be strictly positive, while Yeo-Johnson + supports both positive or negative data. By default, zero-mean, unit-variance normalization is applied to the transformed data. @@ -2411,8 +2413,8 @@ class PowerTransformer(BaseEstimator, TransformerMixin): Parameters ---------- method : str, (default='box-cox') - The power transform method. Currently, 'box-cox' (Box-Cox transform) - is the only option available. + The power transform method. Available methods are 'box-cox' and + 'yeo-johnson'. standardize : boolean, default=True Set to True to apply zero-mean, unit-variance normalization to the @@ -2462,6 +2464,8 @@ class PowerTransformer(BaseEstimator, TransformerMixin): G.E.P. Box and D.R. Cox, "An Analysis of Transformations", Journal of the Royal Statistical Society B, 26, 211-252 (1964). + I.K. Yeo and R.A. Johnson, "A new family of power transformations to + improve normality or symmetry." Biometrika, 87(4), pp.954-959. (2000) """ def __init__(self, method='box-cox', standardize=True, copy=True): self.method = method @@ -2469,11 +2473,10 @@ def __init__(self, method='box-cox', standardize=True, copy=True): self.copy = copy def fit(self, X, y=None): - """Estimate the optimal parameter for each feature. + """Estimate the optimal parameter lambda for each feature. - The optimal parameter for minimizing skewness is estimated - on each feature independently. If the method is Box-Cox, - the lambdas are estimated using maximum likelihood. + The optimal lambda parameter for minimizing skewness is estimated on + each feature independently using maximum likelihood. Parameters ---------- @@ -2546,6 +2549,17 @@ def inverse_transform(self, X): else: X = (X_trans * lambda + 1) ** (1 / lambda) + The inverse of the Yeo-Johnson transformation is given by:: + + if X >= 0 and lambda == 0: + X = exp(X_trans) - 1 + elif X >= 0 and lambda != 0: + X = (X_trans * lambda + 1) ** (1 / lambda) - 1 + elif X < 0 and lambda != 2: + X = 1 - (-(2 - lambda) * X_trans + 1) ** (1 / (2 - lambda)) + elif X < 0 and lambda == 2: + X = 1 - exp(-X_trans) + Parameters ---------- X : array-like, shape (n_samples, n_features) @@ -2566,6 +2580,9 @@ def inverse_transform(self, X): return X def _box_cox_inverse_tranform(self, x, lmbda): + """Return inverse-transformed input x following Box-Cox inverse + transform with parameter lambda. + """ if lmbda == 0: x_inv = np.exp(x) else: @@ -2574,6 +2591,9 @@ def _box_cox_inverse_tranform(self, x, lmbda): return x_inv def _yeo_johnson_inverse_transform(self, x, lmbda): + """Return inverse-transformed input x following Yeo-Johnson inverse + transform with parameter lambda. + """ x_inv = np.zeros(x.shape) pos = x >= 0 @@ -2593,6 +2613,9 @@ def _yeo_johnson_inverse_transform(self, x, lmbda): return x_inv def _yeo_johnson_transform(self, x, lmbda): + """Return transformed input x following Yeo-Johnson transform with + parameter lambda. + """ out = np.zeros(shape=x.shape) pos = (x >= 0) # binary mask From 4173df34ef9548c730902dfa8cef26c1fdf13280 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sun, 15 Jul 2018 16:52:08 -0400 Subject: [PATCH 10/38] Better test for nan in transform() --- sklearn/preprocessing/tests/test_data.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 4c6e112345c5a..65ebbd130f9df 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -2203,4 +2203,6 @@ def test_power_transformer_nans(method): lmbda_nans = pt.lambdas_[0] assert_equal(lmbda_no_nans, lmbda_nans) - pt.transform(X) + + X_trans = pt.transform(X) + assert_array_equal(np.isnan(X_trans), np.isnan(X)) From 61e21836b9a0460a077cdd63dc6455c7f367614a Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sun, 15 Jul 2018 18:28:21 -0400 Subject: [PATCH 11/38] Updated more docs and example --- doc/modules/preprocessing.rst | 39 ++++++++----- .../preprocessing/plot_power_transformer.py | 58 +++++++++++-------- 2 files changed, 60 insertions(+), 37 deletions(-) diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst index 0474a8a665016..603f7368bb75f 100644 --- a/doc/modules/preprocessing.rst +++ b/doc/modules/preprocessing.rst @@ -309,20 +309,32 @@ Power transforms are a family of parametric, monotonic transformations that aim to map data from any distribution to as close to a Gaussian distribution as possible in order to stabilize variance and minimize skewness. -:class:`PowerTransformer` currently provides one such power transformation, -the Box-Cox transform. The Box-Cox transform is given by: +:class:`PowerTransformer` currently provides two such power transformations, +the Box-Cox transform and the Yeo-Johnson transform. The Box-Cox transform is +given by: .. math:: - y_i^{(\lambda)} = + x_i^{(\lambda)} = \begin{cases} - \dfrac{y_i^\lambda - 1}{\lambda} & \text{if } \lambda \neq 0, \\[8pt] - \ln{(y_i)} & \text{if } \lambda = 0, + \dfrac{x_i^\lambda - 1}{\lambda} & \text{if } \lambda \neq 0, \\[8pt] + \ln{(x_i)} & \text{if } \lambda = 0, \end{cases} -Box-Cox can only be applied to strictly positive data. The transformation is -parameterized by :math:`\lambda`, which is determined through maximum likelihood -estimation. Here is an example of using Box-Cox to map samples drawn from a -lognormal distribution to a normal distribution:: +while the Yeo-Johnson is given by: + +.. math:: + x_i^{(\lambda)} = + \begin{cases} + [(x_i + 1)^\lambda - 1] / \lambda & \text{if } \lambda \neq 0, x_i \geq 0, \\[8pt] + \ln{(x_i) + 1} & \text{if } \lambda = 0, x_i \geq 0 \\[8pt] + -[(-x_i + 1)^{2 - \lambda} - 1] / (2 - \lambda) & \text{if } \lambda \neq 2, x_i < 0, \\[8pt] + - \ln (- x_i + 1) & \text{if } \lambda = 2, x_i < 0 + \end{cases} + +Box-Cox can only be applied to strictly positive data. In both methods, the +transformation is parameterized by :math:`\lambda`, which is determined through +maximum likelihood estimation. Here is an example of using Box-Cox to map +samples drawn from a lognormal distribution to a normal distribution:: >>> pt = preprocessing.PowerTransformer(method='box-cox', standardize=False) >>> X_lognormal = np.random.RandomState(616).lognormal(size=(3, 3)) @@ -339,10 +351,11 @@ While the above example sets the `standardize` option to `False`, :class:`PowerTransformer` will apply zero-mean, unit-variance normalization to the transformed output by default. -Below are examples of Box-Cox applied to various probability distributions. -Note that when applied to certain distributions, Box-Cox achieves very -Gaussian-like results, but with others, it is ineffective. This highlights -the importance of visualizing the data before and after transformation. +Below are examples of Box-Cox and Yeo-Johnson applied to various probability +distributions. Note that when applied to certain distributions, the power +transforms achieve very Gaussian-like results, but with others, they are +ineffective. This highlights the importance of visualizing the data before and +after transformation. .. figure:: ../auto_examples/preprocessing/images/sphx_glr_plot_power_transformer_001.png :target: ../auto_examples/preprocessing/plot_power_transformer.html diff --git a/examples/preprocessing/plot_power_transformer.py b/examples/preprocessing/plot_power_transformer.py index 52ce0d3121f73..21e90941355ff 100644 --- a/examples/preprocessing/plot_power_transformer.py +++ b/examples/preprocessing/plot_power_transformer.py @@ -1,19 +1,19 @@ """ -========================================================== -Using PowerTransformer to apply the Box-Cox transformation -========================================================== +====================== +Using PowerTransformer +====================== -This example demonstrates the use of the Box-Cox transform through -:class:`preprocessing.PowerTransformer` to map data from various distributions -to a normal distribution. +This example demonstrates the use of the Box-Cox and Yeo-Johnson transforms +through :class:`preprocessing.PowerTransformer` to map data from various +distributions to a normal distribution. -Box-Cox is useful as a transformation in modeling problems where -homoscedasticity and normality are desired. Below are examples of Box-Cox -applied to six different probability distributions: Lognormal, Chi-squared, -Weibull, Gaussian, Uniform, and Bimodal. +The power transform is useful as a transformation in modeling problems where +homoscedasticity and normality are desired. Below are examples of Box-Cox and +Yeo-Johnwon applied to six different probability distributions: Lognormal, +Chi-squared, Weibull, Gaussian, Uniform, and Bimodal. -Note that the transformation successfully maps the data to a normal -distribution when applied to certain datasets, but is ineffective with others. +Note that the transformations successfully map the data to a normal +distribution when applied to certain datasets, but are ineffective with others. This highlights the importance of visualizing the data before and after transformation. Also note that while the standardize option is set to False for the plot examples, by default, :class:`preprocessing.PowerTransformer` also @@ -21,6 +21,7 @@ """ # Author: Eric Chang +# Author: Nicolas Hug # License: BSD 3 clause import numpy as np @@ -36,7 +37,8 @@ BINS = 100 -pt = PowerTransformer(method='box-cox', standardize=False) +bc = PowerTransformer(method='box-cox', standardize=False) +yj = PowerTransformer(method='yeo-johnson', standardize=False) rng = np.random.RandomState(304) size = (N_SAMPLES, 1) @@ -78,10 +80,11 @@ colors = ['firebrick', 'darkorange', 'goldenrod', 'seagreen', 'royalblue', 'darkorchid'] -fig, axes = plt.subplots(nrows=4, ncols=3) +fig, axes = plt.subplots(nrows=6, ncols=3) axes = axes.flatten() -axes_idxs = [(0, 3), (1, 4), (2, 5), (6, 9), (7, 10), (8, 11)] -axes_list = [(axes[i], axes[j]) for i, j in axes_idxs] +axes_idxs = [(0, 3, 6), (1, 4, 7), (2, 5, 8), (9, 12, 15), (10, 13, 16), + (11, 14, 17)] +axes_list = [(axes[i], axes[j], axes[k]) for (i, j, k) in axes_idxs] for distribution, color, axes in zip(distributions, colors, axes_list): @@ -89,20 +92,27 @@ # scale all distributions to the range [0, 10] X = minmax_scale(X, feature_range=(1e-10, 10)) - # perform power transform - X_trans = pt.fit_transform(X) - lmbda = round(pt.lambdas_[0], 2) + # perform power transforms + X_trans_bc = bc.fit_transform(X) + lmbda_bc = round(bc.lambdas_[0], 2) + X_trans_yj = yj.fit_transform(X) + lmbda_yj = round(yj.lambdas_[0], 2) - ax_original, ax_trans = axes + ax_original, ax_bc, ax_yj = axes ax_original.hist(X, color=color, bins=BINS) ax_original.set_title(name, fontsize=FONT_SIZE) ax_original.tick_params(axis='both', which='major', labelsize=FONT_SIZE) - ax_trans.hist(X_trans, color=color, bins=BINS) - ax_trans.set_title('{} after Box-Cox, $\lambda$ = {}'.format(name, lmbda), - fontsize=FONT_SIZE) - ax_trans.tick_params(axis='both', which='major', labelsize=FONT_SIZE) + for ax, X_trans, meth_name, lmbda in zip((ax_bc, ax_yj), + (X_trans_bc, X_trans_yj), + ('Box-Cox', 'Yeo-Johnson'), + (lmbda_bc, lmbda_yj)): + ax.hist(X_trans, color=color, bins=BINS) + ax.set_title('{} after {}, $\lambda$ = {}'.format(name, meth_name, + lmbda), + fontsize=FONT_SIZE) + ax.tick_params(axis='both', which='major', labelsize=FONT_SIZE) plt.tight_layout() From b1ac8d48beb63af2828025e7b923f0ee493bc9da Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sun, 15 Jul 2018 18:30:30 -0400 Subject: [PATCH 12/38] updated test --- sklearn/preprocessing/tests/test_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 65ebbd130f9df..f18ea668fa358 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -2180,7 +2180,7 @@ def test_optimization_power_transformer(method, lmbda): pt = PowerTransformer(method=method, standardize=False) pt.lambdas_ = [lmbda] X_inv = pt.inverse_transform(X) - pt.lambdas_ = [9999] # just to make sure + del pt.lambdas_ # just to make sure X_inv_trans = pt.fit_transform(X_inv) assert_almost_equal(0, np.linalg.norm(X - X_inv_trans) / n_samples, From 489bc70b32d01ecea6836fca12eb87aee299a2cc Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sun, 15 Jul 2018 18:43:42 -0400 Subject: [PATCH 13/38] Modified tests according to reviews --- sklearn/preprocessing/tests/test_data.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index f18ea668fa358..8a838747a60a9 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -62,6 +62,7 @@ from sklearn.pipeline import Pipeline from sklearn.model_selection import cross_val_predict from sklearn.svm import SVR +from sklearn.utils import shuffle from sklearn import datasets @@ -2161,9 +2162,12 @@ def test_power_transformer_lambda_one(): X_trans = pt.transform(X) assert_array_almost_equal(X_trans, X) - -@pytest.mark.parametrize("method, lmbda", [('box-cox', .5), - ('yeo-johnson', .1)]) +@pytest.mark.parametrize("method, lmbda",[('box-cox', .1), + ('box-cox', .5), + ('yeo-johnson', .1), + ('yeo-johnson', .5), + ('yeo-johnson', 1.), + ]) def test_optimization_power_transformer(method, lmbda): """Test the optimization procedure @@ -2175,16 +2179,19 @@ def test_optimization_power_transformer(method, lmbda): rng = np.random.RandomState(0) n_samples = 1000 - X = rng.normal(size=(n_samples, 1)) + X = rng.normal(loc=0, scale=1, size=(n_samples, 1)) pt = PowerTransformer(method=method, standardize=False) pt.lambdas_ = [lmbda] X_inv = pt.inverse_transform(X) - del pt.lambdas_ # just to make sure + + pt = PowerTransformer(method=method, standardize=False) X_inv_trans = pt.fit_transform(X_inv) assert_almost_equal(0, np.linalg.norm(X - X_inv_trans) / n_samples, decimal=2) + assert_almost_equal(0, X_inv_trans.mean(), decimal=1) + assert_almost_equal(1, X_inv_trans.std(), decimal=1) @pytest.mark.parametrize('method', ['box-cox', 'yeo-johnson']) @@ -2199,10 +2206,12 @@ def test_power_transformer_nans(method): # concat nans at the end and check lambda stays the same X = np.concatenate([X, np.full_like(X, np.nan)]) + X = shuffle(X, random_state=0) + pt.fit(X) lmbda_nans = pt.lambdas_[0] - assert_equal(lmbda_no_nans, lmbda_nans) + assert_almost_equal(lmbda_no_nans, lmbda_nans, decimal=7) X_trans = pt.transform(X) assert_array_equal(np.isnan(X_trans), np.isnan(X)) From 6783e3ac94e4bf1c81bf901eec74f315eb38ef5e Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sun, 15 Jul 2018 18:45:22 -0400 Subject: [PATCH 14/38] Changed default method from cox-box to yeo-johnson --- sklearn/preprocessing/data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 7c3164a73e37b..7b3189cea1334 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2412,7 +2412,7 @@ class PowerTransformer(BaseEstimator, TransformerMixin): Parameters ---------- - method : str, (default='box-cox') + method : str, (default='yeo-johnson') The power transform method. Available methods are 'box-cox' and 'yeo-johnson'. @@ -2467,7 +2467,7 @@ class PowerTransformer(BaseEstimator, TransformerMixin): I.K. Yeo and R.A. Johnson, "A new family of power transformations to improve normality or symmetry." Biometrika, 87(4), pp.954-959. (2000) """ - def __init__(self, method='box-cox', standardize=True, copy=True): + def __init__(self, method='yeo-johnson', standardize=True, copy=True): self.method = method self.standardize = standardize self.copy = copy From dfd1eccd27c458b6c95168025904d8f188a764af Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 16 Jul 2018 11:25:02 -0400 Subject: [PATCH 15/38] Addressed most comments from @glemaitre, fixed flake8 --- sklearn/preprocessing/data.py | 45 ++++++++++++++---------- sklearn/preprocessing/tests/test_data.py | 13 +++---- 2 files changed, 33 insertions(+), 25 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 7b3189cea1334..cdf2d2fa95da5 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2413,8 +2413,10 @@ class PowerTransformer(BaseEstimator, TransformerMixin): Parameters ---------- method : str, (default='yeo-johnson') - The power transform method. Available methods are 'box-cox' and - 'yeo-johnson'. + The power transform method. Available methods are: + + - 'box-cox' :ref:`(ref) ` + - 'yeo-johnson' :ref:`(ref) ` standardize : boolean, default=True Set to True to apply zero-mean, unit-variance normalization to the @@ -2461,11 +2463,16 @@ class PowerTransformer(BaseEstimator, TransformerMixin): References ---------- + + .. _box_cox_paper_ref: + G.E.P. Box and D.R. Cox, "An Analysis of Transformations", Journal of the Royal Statistical Society B, 26, 211-252 (1964). + .. _yeo_johnson_paper_ref: + I.K. Yeo and R.A. Johnson, "A new family of power transformations to - improve normality or symmetry." Biometrika, 87(4), pp.954-959. (2000) + improve normality or symmetry." Biometrika, 87(4), pp.954-959, (2000). """ def __init__(self, method='yeo-johnson', standardize=True, copy=True): self.method = method @@ -2494,18 +2501,18 @@ def fit(self, X, y=None): self.lambdas_ = [] transformed = [] - opt_fun = {'box-cox': self._box_cox_optimize, - 'yeo-johnson': self._yeo_johnson_optimize - }[self.method] - trans_fun = {'box-cox': boxcox, - 'yeo-johnson': self._yeo_johnson_transform - }[self.method] + optim_function = {'box-cox': self._box_cox_optimize, + 'yeo-johnson': self._yeo_johnson_optimize + }[self.method] + transform_function = {'box-cox': boxcox, + 'yeo-johnson': self._yeo_johnson_transform + }[self.method] for col in X.T: - lmbda = opt_fun(col) + lmbda = optim_function(col) self.lambdas_.append(lmbda) - col_trans = trans_fun(col, lmbda) + col_trans = transform_function(col, lmbda) transformed.append(col_trans) self.lambdas_ = np.array(self.lambdas_) @@ -2528,11 +2535,11 @@ def transform(self, X): check_is_fitted(self, 'lambdas_') X = self._check_input(X, check_positive=True, check_shape=True) - trans_fun = {'box-cox': boxcox, - 'yeo-johnson': self._yeo_johnson_transform - }[self.method] + transform_function = {'box-cox': boxcox, + 'yeo-johnson': self._yeo_johnson_transform + }[self.method] for i, lmbda in enumerate(self.lambdas_): - X[:, i] = trans_fun(X[:, i], lmbda) + X[:, i] = transform_function(X[:, i], lmbda) if self.standardize: X = self._scaler.transform(X) @@ -2660,13 +2667,13 @@ def _neg_log_likelihood(lmbda): """Return the negative log likelihood of the observed data x as a function of lambda.""" psi = self._yeo_johnson_transform(x, lmbda) - n = x.shape[0] + n_samples = x.shape[0] # Estimated mean and variance of the normal distribution - mu = psi.sum() / n - sig_sq = np.power(psi - mu, 2).sum() / n + mu = psi.sum() / n_samples + sig_sq = np.power(psi - mu, 2).sum() / n_samples - loglike = -n / 2 * np.log(sig_sq) + loglike = -n_samples / 2 * np.log(sig_sq) loglike += (lmbda - 1) * (np.sign(x) * np.log(np.abs(x) + 1)).sum() return -loglike diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 8a838747a60a9..6db0e89a2d621 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -2162,12 +2162,13 @@ def test_power_transformer_lambda_one(): X_trans = pt.transform(X) assert_array_almost_equal(X_trans, X) -@pytest.mark.parametrize("method, lmbda",[('box-cox', .1), - ('box-cox', .5), - ('yeo-johnson', .1), - ('yeo-johnson', .5), - ('yeo-johnson', 1.), - ]) + +@pytest.mark.parametrize("method, lmbda", [('box-cox', .1), + ('box-cox', .5), + ('yeo-johnson', .1), + ('yeo-johnson', .5), + ('yeo-johnson', 1.), + ]) def test_optimization_power_transformer(method, lmbda): """Test the optimization procedure From 78169f69ad67e522c30ac7e88747d5e138682763 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 16 Jul 2018 11:27:15 -0400 Subject: [PATCH 16/38] Removed box-cox specific checks in estimator_checks The default is now Yeo-Johnson which can handle negative data --- sklearn/utils/estimator_checks.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 02d91ee80791b..743cabdcb4345 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -853,9 +853,6 @@ def check_transformer_general(name, transformer, readonly_memmap=False): random_state=0, n_features=2, cluster_std=0.1) X = StandardScaler().fit_transform(X) X -= X.min() - if name == 'PowerTransformer': - # Box-Cox requires positive, non-zero data - X += 1 if readonly_memmap: X, y = create_memmap_backed_data([X, y]) @@ -981,9 +978,6 @@ def check_pipeline_consistency(name, estimator_orig): X, y = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]], random_state=0, n_features=2, cluster_std=0.1) X -= X.min() - if name == 'PowerTransformer': - # Box-Cox requires positive, non-zero data - X += 1 X = pairwise_estimator_convert_X(X, estimator_orig, kernel=rbf_kernel) estimator = clone(estimator_orig) y = multioutput_estimator_convert_y_2d(estimator, y) @@ -1045,9 +1039,6 @@ def check_estimators_dtypes(name, estimator_orig): methods = ["predict", "transform", "decision_function", "predict_proba"] for X_train in [X_train_32, X_train_64, X_train_int_64, X_train_int_32]: - if name == 'PowerTransformer': - # Box-Cox requires positive, non-zero data - X_train = np.abs(X_train) + 1 estimator = clone(estimator_orig) set_random_state(estimator, 1) estimator.fit(X_train, y) @@ -1162,9 +1153,6 @@ def check_estimators_pickle(name, estimator_orig): # some estimators can't do features less than 0 X -= X.min() - if name == 'PowerTransformer': - # Box-Cox requires positive, non-zero data - X += 1 X = pairwise_estimator_convert_X(X, estimator_orig, kernel=rbf_kernel) estimator = clone(estimator_orig) @@ -1517,9 +1505,6 @@ def check_estimators_fit_returns_self(name, estimator_orig, X, y = make_blobs(random_state=0, n_samples=9, n_features=4) # some want non-negative input X -= X.min() - if name == 'PowerTransformer': - # Box-Cox requires positive, non-zero data - X += 1 X = pairwise_estimator_convert_X(X, estimator_orig) estimator = clone(estimator_orig) @@ -1880,9 +1865,6 @@ def check_estimators_overwrite_params(name, estimator_orig): X, y = make_blobs(random_state=0, n_samples=9) # some want non-negative input X -= X.min() - if name == 'PowerTransformer': - # Box-Cox requires positive, non-zero data - X += 1 X = pairwise_estimator_convert_X(X, estimator_orig, kernel=rbf_kernel) estimator = clone(estimator_orig) y = multioutput_estimator_convert_y_2d(estimator, y) From f48a17b4fd5b28c67334d6e8f20f778f5dafa8b2 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 16 Jul 2018 11:36:00 -0400 Subject: [PATCH 17/38] More explicit variable names for mean and variance --- sklearn/preprocessing/data.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index cdf2d2fa95da5..9fb6910e14811 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2666,14 +2666,14 @@ def _yeo_johnson_optimize(self, x): def _neg_log_likelihood(lmbda): """Return the negative log likelihood of the observed data x as a function of lambda.""" - psi = self._yeo_johnson_transform(x, lmbda) + transformed = self._yeo_johnson_transform(x, lmbda) n_samples = x.shape[0] # Estimated mean and variance of the normal distribution - mu = psi.sum() / n_samples - sig_sq = np.power(psi - mu, 2).sum() / n_samples + est_mean = transformed.sum() / n_samples + est_var = np.power(transformed - est_mean, 2).sum() / n_samples - loglike = -n_samples / 2 * np.log(sig_sq) + loglike = -n_samples / 2 * np.log(est_var) loglike += (lmbda - 1) * (np.sign(x) * np.log(np.abs(x) + 1)).sum() return -loglike From 67eaa9873265176aa9d00375351283366cdcfab0 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 16 Jul 2018 13:33:21 -0400 Subject: [PATCH 18/38] Addressed comments from glemaitre --- doc/modules/preprocessing.rst | 24 ++++++----- ...r_transformer.py => map_data_to_normal.py} | 43 +++++++++++-------- sklearn/preprocessing/data.py | 31 ++++++------- sklearn/preprocessing/tests/test_data.py | 40 +++++++---------- 4 files changed, 69 insertions(+), 69 deletions(-) rename examples/preprocessing/{plot_power_transformer.py => map_data_to_normal.py} (72%) diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst index 603f7368bb75f..4f925ee197b0e 100644 --- a/doc/modules/preprocessing.rst +++ b/doc/modules/preprocessing.rst @@ -310,27 +310,29 @@ to map data from any distribution to as close to a Gaussian distribution as possible in order to stabilize variance and minimize skewness. :class:`PowerTransformer` currently provides two such power transformations, -the Box-Cox transform and the Yeo-Johnson transform. The Box-Cox transform is -given by: +the Yeo-Johnson transform and the Box-Cox transform. + +The Yeo-Johnson transform is given by: .. math:: x_i^{(\lambda)} = \begin{cases} - \dfrac{x_i^\lambda - 1}{\lambda} & \text{if } \lambda \neq 0, \\[8pt] - \ln{(x_i)} & \text{if } \lambda = 0, + [(x_i + 1)^\lambda - 1] / \lambda & \text{if } \lambda \neq 0, x_i \geq 0, \\[8pt] + \ln{(x_i) + 1} & \text{if } \lambda = 0, x_i \geq 0 \\[8pt] + -[(-x_i + 1)^{2 - \lambda} - 1] / (2 - \lambda) & \text{if } \lambda \neq 2, x_i < 0, \\[8pt] + - \ln (- x_i + 1) & \text{if } \lambda = 2, x_i < 0 \end{cases} -while the Yeo-Johnson is given by: +while the Box-Cox transform is given by: .. math:: x_i^{(\lambda)} = \begin{cases} - [(x_i + 1)^\lambda - 1] / \lambda & \text{if } \lambda \neq 0, x_i \geq 0, \\[8pt] - \ln{(x_i) + 1} & \text{if } \lambda = 0, x_i \geq 0 \\[8pt] - -[(-x_i + 1)^{2 - \lambda} - 1] / (2 - \lambda) & \text{if } \lambda \neq 2, x_i < 0, \\[8pt] - - \ln (- x_i + 1) & \text{if } \lambda = 2, x_i < 0 + \dfrac{x_i^\lambda - 1}{\lambda} & \text{if } \lambda \neq 0, \\[8pt] + \ln{(x_i)} & \text{if } \lambda = 0, \end{cases} + Box-Cox can only be applied to strictly positive data. In both methods, the transformation is parameterized by :math:`\lambda`, which is determined through maximum likelihood estimation. Here is an example of using Box-Cox to map @@ -357,8 +359,8 @@ transforms achieve very Gaussian-like results, but with others, they are ineffective. This highlights the importance of visualizing the data before and after transformation. -.. figure:: ../auto_examples/preprocessing/images/sphx_glr_plot_power_transformer_001.png - :target: ../auto_examples/preprocessing/plot_power_transformer.html +.. figure:: ../auto_examples/preprocessing/images/sphx_glr_map_data_to_normal_001.png + :target: ../auto_examples/preprocessing/map_data_to_normal.html :align: center :scale: 100 diff --git a/examples/preprocessing/plot_power_transformer.py b/examples/preprocessing/map_data_to_normal.py similarity index 72% rename from examples/preprocessing/plot_power_transformer.py rename to examples/preprocessing/map_data_to_normal.py index 21e90941355ff..2811d660f18c1 100644 --- a/examples/preprocessing/plot_power_transformer.py +++ b/examples/preprocessing/map_data_to_normal.py @@ -1,7 +1,7 @@ """ -====================== -Using PowerTransformer -====================== +================================= +Map data to a normal distribution +================================= This example demonstrates the use of the Box-Cox and Yeo-Johnson transforms through :class:`preprocessing.PowerTransformer` to map data from various @@ -18,6 +18,9 @@ transformation. Also note that while the standardize option is set to False for the plot examples, by default, :class:`preprocessing.PowerTransformer` also applies zero-mean, unit-variance standardization to the transformed outputs. + +For comparison, we also add the output from +:class:`preprocessing.QuantileTransformer`. """ # Author: Eric Chang @@ -28,6 +31,7 @@ import matplotlib.pyplot as plt from sklearn.preprocessing import PowerTransformer, minmax_scale +from sklearn.preprocessing import QuantileTransformer print(__doc__) @@ -37,9 +41,10 @@ BINS = 100 +rng = np.random.RandomState(304) bc = PowerTransformer(method='box-cox', standardize=False) yj = PowerTransformer(method='yeo-johnson', standardize=False) -rng = np.random.RandomState(304) +qt = QuantileTransformer(output_distribution='normal', random_state=rng) size = (N_SAMPLES, 1) @@ -80,11 +85,12 @@ colors = ['firebrick', 'darkorange', 'goldenrod', 'seagreen', 'royalblue', 'darkorchid'] -fig, axes = plt.subplots(nrows=6, ncols=3) +fig, axes = plt.subplots(nrows=8, ncols=3, figsize=plt.figaspect(3)) axes = axes.flatten() -axes_idxs = [(0, 3, 6), (1, 4, 7), (2, 5, 8), (9, 12, 15), (10, 13, 16), - (11, 14, 17)] -axes_list = [(axes[i], axes[j], axes[k]) for (i, j, k) in axes_idxs] +axes_idxs = [(0, 3, 6, 9), (1, 4, 7, 10), (2, 5, 8, 11), (12, 15, 18, 21), + (13, 16, 19, 22), (14, 17, 20, 23)] +axes_list = [(axes[i], axes[j], axes[k], axes[l]) + for (i, j, k, l) in axes_idxs] for distribution, color, axes in zip(distributions, colors, axes_list): @@ -92,26 +98,29 @@ # scale all distributions to the range [0, 10] X = minmax_scale(X, feature_range=(1e-10, 10)) - # perform power transforms + # perform power transforms and quantile transform X_trans_bc = bc.fit_transform(X) lmbda_bc = round(bc.lambdas_[0], 2) X_trans_yj = yj.fit_transform(X) lmbda_yj = round(yj.lambdas_[0], 2) + X_trans_qt = qt.fit_transform(X) - ax_original, ax_bc, ax_yj = axes + ax_original, ax_bc, ax_yj, ax_qt = axes ax_original.hist(X, color=color, bins=BINS) ax_original.set_title(name, fontsize=FONT_SIZE) ax_original.tick_params(axis='both', which='major', labelsize=FONT_SIZE) - for ax, X_trans, meth_name, lmbda in zip((ax_bc, ax_yj), - (X_trans_bc, X_trans_yj), - ('Box-Cox', 'Yeo-Johnson'), - (lmbda_bc, lmbda_yj)): + for ax, X_trans, meth_name, lmbda in zip( + (ax_bc, ax_yj, ax_qt), + (X_trans_bc, X_trans_yj, X_trans_qt), + ('Box-Cox', 'Yeo-Johnson', 'Quantile transform'), + (lmbda_bc, lmbda_yj, None)): ax.hist(X_trans, color=color, bins=BINS) - ax.set_title('{} after {}, $\lambda$ = {}'.format(name, meth_name, - lmbda), - fontsize=FONT_SIZE) + title = '{} after {}'.format(name, meth_name) + if lmbda is not None: + title += ', $\lambda$ = {}'.format(lmbda) + ax.set_title(title, fontsize=FONT_SIZE) ax.tick_params(axis='both', which='major', labelsize=FONT_SIZE) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 9fb6910e14811..9864e932f6fd2 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2415,8 +2415,8 @@ class PowerTransformer(BaseEstimator, TransformerMixin): method : str, (default='yeo-johnson') The power transform method. Available methods are: - - 'box-cox' :ref:`(ref) ` - - 'yeo-johnson' :ref:`(ref) ` + - 'yeo-johnson' [1]_, works with postive and negative values + - 'box-cox' [2]_, only works with strictly positive values standardize : boolean, default=True Set to True to apply zero-mean, unit-variance normalization to the @@ -2464,15 +2464,12 @@ class PowerTransformer(BaseEstimator, TransformerMixin): References ---------- - .. _box_cox_paper_ref: + .. [1] I.K. Yeo and R.A. Johnson, "A new family of power transformations to + improve normality or symmetry." Biometrika, 87(4), pp.954-959, + (2000). - G.E.P. Box and D.R. Cox, "An Analysis of Transformations", Journal of the - Royal Statistical Society B, 26, 211-252 (1964). - - .. _yeo_johnson_paper_ref: - - I.K. Yeo and R.A. Johnson, "A new family of power transformations to - improve normality or symmetry." Biometrika, 87(4), pp.954-959, (2000). + .. [2] G.E.P. Box and D.R. Cox, "An Analysis of Transformations", Journal + of the Royal Statistical Society B, 26, 211-252 (1964). """ def __init__(self, method='yeo-johnson', standardize=True, copy=True): self.method = method @@ -2601,7 +2598,7 @@ def _yeo_johnson_inverse_transform(self, x, lmbda): """Return inverse-transformed input x following Yeo-Johnson inverse transform with parameter lambda. """ - x_inv = np.zeros(x.shape) + x_inv = np.zeros(x.shape, dtype=x.dtype) pos = x >= 0 # when x >= 0 @@ -2624,8 +2621,8 @@ def _yeo_johnson_transform(self, x, lmbda): parameter lambda. """ - out = np.zeros(shape=x.shape) - pos = (x >= 0) # binary mask + out = np.zeros(shape=x.shape, dtype=x.dtype) + pos = x >= 0 # binary mask # Note: we're comparing lmbda to 1e-19 instead of strict equality to 0. # See scipy/special/_boxcox.pxd for a rationale behind this @@ -2660,18 +2657,18 @@ def _yeo_johnson_optimize(self, x): """Find and return optimal lambda parameter of the Yeo-Johnson transform by MLE, for observed data x. - Like for coxbox, MLE is done via the brent optimizer. + Like for Box-Cox, MLE is done via the brent optimizer. """ def _neg_log_likelihood(lmbda): """Return the negative log likelihood of the observed data x as a function of lambda.""" - transformed = self._yeo_johnson_transform(x, lmbda) + x_trans = self._yeo_johnson_transform(x, lmbda) n_samples = x.shape[0] # Estimated mean and variance of the normal distribution - est_mean = transformed.sum() / n_samples - est_var = np.power(transformed - est_mean, 2).sum() / n_samples + est_mean = x_trans.sum() / n_samples + est_var = np.power(x_trans - est_mean, 2).sum() / n_samples loglike = -n_samples / 2 * np.log(est_var) loglike += (lmbda - 1) * (np.sign(x) * np.log(np.abs(x) + 1)).sum() diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 6db0e89a2d621..9c3cfb3c3707e 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -2012,17 +2012,14 @@ def test_power_transformer_notfitted(method): assert_raises(NotFittedError, pt.inverse_transform, X) +@pytest.mark.parametrize('method', ['box-cox', 'yeo-johnson']) @pytest.mark.parametrize('standardize', [True, False]) @pytest.mark.parametrize('X', [X_1col, X_2d]) -def test_power_transformer_inverse(standardize, X): +def test_power_transformer_inverse(method, standardize, X): # Make sure we get the original input when applying transform and then # inverse transform - pt = PowerTransformer(method='yeo-johnson', standardize=standardize) - X_trans = pt.fit_transform(X) - assert_almost_equal(X, pt.inverse_transform(X_trans)) - - X = np.abs(X) - pt = PowerTransformer(method='box-cox', standardize=standardize) + X = np.abs(X) if method == 'box-cox' else X + pt = PowerTransformer(method=method, standardize=standardize) X_trans = pt.fit_transform(X) assert_almost_equal(X, pt.inverse_transform(X_trans)) @@ -2078,7 +2075,7 @@ def test_power_transformer_2d(): assert isinstance(pt.lambdas_, np.ndarray) -def test_power_transformer_strictly_positive_exception(): +def test_power_transformer_boxcox_strictly_positive_exception(): # Exceptions should be raised for negative arrays and zero arrays when # method is coxbox @@ -2105,15 +2102,12 @@ def test_power_transformer_strictly_positive_exception(): assert_raise_message(ValueError, not_positive_message, power_transform, np.zeros(X_2d.shape), 'box-cox') - # It should not raise any error for yeo-johnson - pt = PowerTransformer(method='yeo-johnson') - pt.fit(np.abs(X_2d)) - pt.transform(X_with_negatives) - pt.fit(X_with_negatives) - power_transform(X_with_negatives, method='yeo-johnson') - pt.transform(np.zeros(X_2d.shape)) - pt.fit(np.zeros(X_2d.shape)) - power_transform(np.zeros(X_2d.shape), method='yeo-johnson') + +@pytest.mark.parametrize('X', [X_2d, np.abs(X_2d), -np.abs(X_2d), + np.zeros(X_2d.shape)]) +def test_power_transformer_yeojohnson_any_input(X): + # Yeo-Johnson method should support any kind of input + power_transform(X, method='yeo-johnson') @pytest.mark.parametrize("method", ['box-cox', 'yeo-johnson']) @@ -2170,13 +2164,11 @@ def test_power_transformer_lambda_one(): ('yeo-johnson', 1.), ]) def test_optimization_power_transformer(method, lmbda): - """Test the optimization procedure - - - set a predefined value for lambda - - apply inverse_transform to a normal dist (we get X_inv) - - apply fit_transform to X_inv (we get X_inv_trans) - - check that X_inv_trans is roughly equal to X - """ + # Test the optimization procedure: + # - set a predefined value for lambda + # - apply inverse_transform to a normal dist (we get X_inv) + # - apply fit_transform to X_inv (we get X_inv_trans) + # - check that X_inv_trans is roughly equal to X rng = np.random.RandomState(0) n_samples = 1000 From 7a2bce7c50d6eb062c8b98a8747cf15d531c72a0 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 16 Jul 2018 13:37:11 -0400 Subject: [PATCH 19/38] Changed number of bins in plots to auto --- examples/preprocessing/map_data_to_normal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/preprocessing/map_data_to_normal.py b/examples/preprocessing/map_data_to_normal.py index 2811d660f18c1..bfac9265a68c3 100644 --- a/examples/preprocessing/map_data_to_normal.py +++ b/examples/preprocessing/map_data_to_normal.py @@ -38,7 +38,7 @@ N_SAMPLES = 3000 FONT_SIZE = 6 -BINS = 100 +BINS = 'auto' rng = np.random.RandomState(304) From 2b56a9d4a09779cbef16b4e598fccdd1ace35665 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 16 Jul 2018 15:22:34 -0400 Subject: [PATCH 20/38] Fixed Nan issues (ignored warnings) Also removed other checks about positive data now that Yeo-Johnson is the default --- sklearn/preprocessing/data.py | 15 +++++++++------ sklearn/preprocessing/tests/test_common.py | 22 ++++++++++------------ 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 9864e932f6fd2..acaca0e819dca 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2506,11 +2506,12 @@ def fit(self, X, y=None): }[self.method] for col in X.T: - lmbda = optim_function(col) - self.lambdas_.append(lmbda) + with np.errstate(invalid='ignore'): # hide NaN warnings + lmbda = optim_function(col) + self.lambdas_.append(lmbda) - col_trans = transform_function(col, lmbda) - transformed.append(col_trans) + col_trans = transform_function(col, lmbda) + transformed.append(col_trans) self.lambdas_ = np.array(self.lambdas_) transformed = np.array(transformed) @@ -2536,7 +2537,8 @@ def transform(self, X): 'yeo-johnson': self._yeo_johnson_transform }[self.method] for i, lmbda in enumerate(self.lambdas_): - X[:, i] = transform_function(X[:, i], lmbda) + with np.errstate(invalid='ignore'): # hide NaN warnings + X[:, i] = transform_function(X[:, i], lmbda) if self.standardize: X = self._scaler.transform(X) @@ -2579,7 +2581,8 @@ def inverse_transform(self, X): 'yeo-johnson': self._yeo_johnson_inverse_transform }[self.method] for i, lmbda in enumerate(self.lambdas_): - X[:, i] = inv_fun(X[:, i], lmbda) + with np.errstate(invalid='ignore'): # hide NaN warnings + X[:, i] = inv_fun(X[:, i], lmbda) return X diff --git a/sklearn/preprocessing/tests/test_common.py b/sklearn/preprocessing/tests/test_common.py index cbb77e4884040..dc37c888b98d4 100644 --- a/sklearn/preprocessing/tests/test_common.py +++ b/sklearn/preprocessing/tests/test_common.py @@ -36,25 +36,23 @@ def _get_valid_samples_by_column(X, col): @pytest.mark.parametrize( - "est, func, support_sparse, strictly_positive", - [(MaxAbsScaler(), maxabs_scale, True, False), - (MinMaxScaler(), minmax_scale, False, False), - (StandardScaler(), scale, False, False), - (StandardScaler(with_mean=False), scale, True, False), - (PowerTransformer(), power_transform, False, True), - (QuantileTransformer(n_quantiles=10), quantile_transform, True, False), - (RobustScaler(), robust_scale, False, False), - (RobustScaler(with_centering=False), robust_scale, True, False)] + "est, func, support_sparse", + [(MaxAbsScaler(), maxabs_scale, True), + (MinMaxScaler(), minmax_scale, False), + (StandardScaler(), scale, False), + (StandardScaler(with_mean=False), scale, True), + (PowerTransformer(), power_transform, False), + (QuantileTransformer(n_quantiles=10), quantile_transform, True), + (RobustScaler(), robust_scale, False), + (RobustScaler(with_centering=False), robust_scale, True)] ) -def test_missing_value_handling(est, func, support_sparse, strictly_positive): +def test_missing_value_handling(est, func, support_sparse): # check that the preprocessing method let pass nan rng = np.random.RandomState(42) X = iris.data.copy() n_missing = 50 X[rng.randint(X.shape[0], size=n_missing), rng.randint(X.shape[1], size=n_missing)] = np.nan - if strictly_positive: - X += np.nanmin(X) + 0.1 X_train, X_test = train_test_split(X, random_state=1) # sanity check assert not np.all(np.isnan(X_train), axis=0).any() From e928d26b62b4c86afe38ff4f1cce98a702f8895e Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 16 Jul 2018 15:36:14 -0400 Subject: [PATCH 21/38] Fixed docstring example issue --- sklearn/preprocessing/data.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index acaca0e819dca..140dbb93719f3 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2437,13 +2437,13 @@ class PowerTransformer(BaseEstimator, TransformerMixin): >>> pt = PowerTransformer() >>> data = [[1, 2], [3, 2], [4, 5]] >>> print(pt.fit(data)) - PowerTransformer(copy=True, method='box-cox', standardize=True) + PowerTransformer(copy=True, method='yeo-johnson', standardize=True) >>> print(pt.lambdas_) # doctest: +ELLIPSIS - [ 1.051... -2.345...] + [1.38668178e+00 5.93926346e-09] >>> print(pt.transform(data)) # doctest: +ELLIPSIS - [[-1.332... -0.707...] - [ 0.256... -0.707...] - [ 1.076... 1.414...]] + [[-1.31616039 -0.70710678] + [ 0.20998268 -0.70710678] + [ 1.1061777 1.41421356]] See also -------- From 5273212be1a7f1459b8b7661059ca926948ddb83 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 16 Jul 2018 16:44:45 -0400 Subject: [PATCH 22/38] Updated whatsnew --- doc/whats_new/v0.20.rst | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index 5b9216926b834..973e4979a460c 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -137,12 +137,14 @@ Preprocessing DataFrames. :issue:`9012` by `Andreas Müller`_ and `Joris Van den Bossche`_, and :issue:`11315` by :user:`Thomas Fan `. -- Added :class:`preprocessing.PowerTransformer`, which implements the Box-Cox - power transformation, allowing users to map data from any distribution to a - Gaussian distribution. This is useful as a variance-stabilizing transformation - in situations where normality and homoscedasticity are desirable. +- Added :class:`preprocessing.PowerTransformer`, which implements the + Yeo-Johnson and Box-Cox power transformations, allowing users to map data + from any distribution to a Gaussian distribution. This is useful as a + variance-stabilizing transformation in situations where normality and + homoscedasticity are desirable. :issue:`10210` by :user:`Eric Chang ` and - :user:`Maniteja Nandana `. + :user:`Maniteja Nandana `, and :issue:`10261` by :user:`Nicolas + Hug Date: Mon, 16 Jul 2018 18:06:11 -0400 Subject: [PATCH 23/38] Addressed comments from glemaitre --- doc/whats_new/v0.20.rst | 4 ++-- examples/preprocessing/map_data_to_normal.py | 2 +- sklearn/preprocessing/data.py | 17 ++++++++++----- sklearn/preprocessing/tests/test_common.py | 23 +++++++++++--------- 4 files changed, 28 insertions(+), 18 deletions(-) diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index 973e4979a460c..abdccc78205d6 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -143,8 +143,8 @@ Preprocessing variance-stabilizing transformation in situations where normality and homoscedasticity are desirable. :issue:`10210` by :user:`Eric Chang ` and - :user:`Maniteja Nandana `, and :issue:`10261` by :user:`Nicolas - Hug `, and :issue:`11520` by :user:`Nicolas + Hug `. - Added the :class:`compose.TransformedTargetRegressor` which transforms the target y before fitting a regression model. The predictions are mapped diff --git a/examples/preprocessing/map_data_to_normal.py b/examples/preprocessing/map_data_to_normal.py index bfac9265a68c3..e2f69a3f9d52d 100644 --- a/examples/preprocessing/map_data_to_normal.py +++ b/examples/preprocessing/map_data_to_normal.py @@ -24,7 +24,7 @@ """ # Author: Eric Chang -# Author: Nicolas Hug + Nicolas Hug # License: BSD 3 clause import numpy as np diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 140dbb93719f3..7ff7945178582 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2438,9 +2438,9 @@ class PowerTransformer(BaseEstimator, TransformerMixin): >>> data = [[1, 2], [3, 2], [4, 5]] >>> print(pt.fit(data)) PowerTransformer(copy=True, method='yeo-johnson', standardize=True) - >>> print(pt.lambdas_) # doctest: +ELLIPSIS + >>> print(pt.lambdas_) [1.38668178e+00 5.93926346e-09] - >>> print(pt.transform(data)) # doctest: +ELLIPSIS + >>> print(pt.transform(data)) [[-1.31616039 -0.70710678] [ 0.20998268 -0.70710678] [ 1.1061777 1.41421356]] @@ -2600,6 +2600,11 @@ def _box_cox_inverse_tranform(self, x, lmbda): def _yeo_johnson_inverse_transform(self, x, lmbda): """Return inverse-transformed input x following Yeo-Johnson inverse transform with parameter lambda. + + Note + ---- + We're comparing lmbda to 1e-19 instead of strict equality to 0. See + scipy/special/_boxcox.pxd for a rationale behind this """ x_inv = np.zeros(x.shape, dtype=x.dtype) pos = x >= 0 @@ -2622,14 +2627,16 @@ def _yeo_johnson_inverse_transform(self, x, lmbda): def _yeo_johnson_transform(self, x, lmbda): """Return transformed input x following Yeo-Johnson transform with parameter lambda. + + Note + ---- + We're comparing lmbda to 1e-19 instead of strict equality to 0. See + scipy/special/_boxcox.pxd for a rationale behind this """ out = np.zeros(shape=x.shape, dtype=x.dtype) pos = x >= 0 # binary mask - # Note: we're comparing lmbda to 1e-19 instead of strict equality to 0. - # See scipy/special/_boxcox.pxd for a rationale behind this - # when x >= 0 if lmbda < 1e-19: out[pos] = np.log(x[pos] + 1) diff --git a/sklearn/preprocessing/tests/test_common.py b/sklearn/preprocessing/tests/test_common.py index dc37c888b98d4..ac904d99e8af3 100644 --- a/sklearn/preprocessing/tests/test_common.py +++ b/sklearn/preprocessing/tests/test_common.py @@ -36,23 +36,26 @@ def _get_valid_samples_by_column(X, col): @pytest.mark.parametrize( - "est, func, support_sparse", - [(MaxAbsScaler(), maxabs_scale, True), - (MinMaxScaler(), minmax_scale, False), - (StandardScaler(), scale, False), - (StandardScaler(with_mean=False), scale, True), - (PowerTransformer(), power_transform, False), - (QuantileTransformer(n_quantiles=10), quantile_transform, True), - (RobustScaler(), robust_scale, False), - (RobustScaler(with_centering=False), robust_scale, True)] + "est, func, support_sparse, strictly_positive", + [(MaxAbsScaler(), maxabs_scale, True, False), + (MinMaxScaler(), minmax_scale, False, False), + (StandardScaler(), scale, False, False), + (StandardScaler(with_mean=False), scale, True, False), + (PowerTransformer('yeo-johnson'), power_transform, False, False), + (PowerTransformer('box-cox'), power_transform, False, True), + (QuantileTransformer(n_quantiles=10), quantile_transform, True, False), + (RobustScaler(), robust_scale, False, False), + (RobustScaler(with_centering=False), robust_scale, True, False)] ) -def test_missing_value_handling(est, func, support_sparse): +def test_missing_value_handling(est, func, support_sparse, strictly_positive): # check that the preprocessing method let pass nan rng = np.random.RandomState(42) X = iris.data.copy() n_missing = 50 X[rng.randint(X.shape[0], size=n_missing), rng.randint(X.shape[1], size=n_missing)] = np.nan + if strictly_positive: + X += np.nanmin(X) + 0.1 X_train, X_test = train_test_split(X, random_state=1) # sanity check assert not np.all(np.isnan(X_train), axis=0).any() From 0c543e34161552e3b1d39b197fe956f2c8804dc5 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 16 Jul 2018 18:12:15 -0400 Subject: [PATCH 24/38] Fixed comment issue --- examples/preprocessing/map_data_to_normal.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/preprocessing/map_data_to_normal.py b/examples/preprocessing/map_data_to_normal.py index e2f69a3f9d52d..e068ad5bd4bde 100644 --- a/examples/preprocessing/map_data_to_normal.py +++ b/examples/preprocessing/map_data_to_normal.py @@ -24,7 +24,7 @@ """ # Author: Eric Chang - Nicolas Hug +# Nicolas Hug # License: BSD 3 clause import numpy as np @@ -42,8 +42,8 @@ rng = np.random.RandomState(304) -bc = PowerTransformer(method='box-cox', standardize=False) -yj = PowerTransformer(method='yeo-johnson', standardize=False) +bc = PowerTransformer(method='box-cox', standardize=True) +yj = PowerTransformer(method='yeo-johnson', standardize=True) qt = QuantileTransformer(output_distribution='normal', random_state=rng) size = (N_SAMPLES, 1) From a0d86a0c1937d0b3fae6ef3321ad86e6356cce25 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 16 Jul 2018 18:53:51 -0400 Subject: [PATCH 25/38] Updated example --- examples/preprocessing/map_data_to_normal.py | 37 ++++++++++++-------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/examples/preprocessing/map_data_to_normal.py b/examples/preprocessing/map_data_to_normal.py index e068ad5bd4bde..02210e367c781 100644 --- a/examples/preprocessing/map_data_to_normal.py +++ b/examples/preprocessing/map_data_to_normal.py @@ -15,12 +15,20 @@ Note that the transformations successfully map the data to a normal distribution when applied to certain datasets, but are ineffective with others. This highlights the importance of visualizing the data before and after -transformation. Also note that while the standardize option is set to False for -the plot examples, by default, :class:`preprocessing.PowerTransformer` also -applies zero-mean, unit-variance standardization to the transformed outputs. +transformation. + +Also note that even though Box-Cox seems to perform better than Yeo-Johnson for +lognormal and chi-squared distributions, keep in mind that Box-Cox does not +support inputs with negative values. For comparison, we also add the output from -:class:`preprocessing.QuantileTransformer`. +:class:`preprocessing.QuantileTransformer`. It can force any arbitrary +distribution into a gaussian, provided that there are enough training samples +(thousands). Because it is a non-parametric method, it is harder to interpret +than the parametric ones (Box-Cox and Yeo-Johnson). + +On "small" datasets (less than a few hundred points), the quantile transformer +is prone to overfitting. The use of the power transform is then recommended. """ # Author: Eric Chang @@ -30,20 +38,21 @@ import numpy as np import matplotlib.pyplot as plt -from sklearn.preprocessing import PowerTransformer, minmax_scale +from sklearn.preprocessing import PowerTransformer from sklearn.preprocessing import QuantileTransformer +from sklearn.model_selection import train_test_split print(__doc__) -N_SAMPLES = 3000 +N_SAMPLES = 1000 FONT_SIZE = 6 BINS = 'auto' rng = np.random.RandomState(304) -bc = PowerTransformer(method='box-cox', standardize=True) -yj = PowerTransformer(method='yeo-johnson', standardize=True) +bc = PowerTransformer(method='box-cox') +yj = PowerTransformer(method='yeo-johnson') qt = QuantileTransformer(output_distribution='normal', random_state=rng) size = (N_SAMPLES, 1) @@ -95,19 +104,18 @@ for distribution, color, axes in zip(distributions, colors, axes_list): name, X = distribution - # scale all distributions to the range [0, 10] - X = minmax_scale(X, feature_range=(1e-10, 10)) + X_train, X_test = train_test_split(X, test_size=.5) # perform power transforms and quantile transform - X_trans_bc = bc.fit_transform(X) + X_trans_bc = bc.fit(X_train).transform(X_test) lmbda_bc = round(bc.lambdas_[0], 2) - X_trans_yj = yj.fit_transform(X) + X_trans_yj = yj.fit(X_train).transform(X_test) lmbda_yj = round(yj.lambdas_[0], 2) - X_trans_qt = qt.fit_transform(X) + X_trans_qt = qt.fit(X_train).transform(X_test) ax_original, ax_bc, ax_yj, ax_qt = axes - ax_original.hist(X, color=color, bins=BINS) + ax_original.hist(X_train, color=color, bins=BINS) ax_original.set_title(name, fontsize=FONT_SIZE) ax_original.tick_params(axis='both', which='major', labelsize=FONT_SIZE) @@ -122,6 +130,7 @@ title += ', $\lambda$ = {}'.format(lmbda) ax.set_title(title, fontsize=FONT_SIZE) ax.tick_params(axis='both', which='major', labelsize=FONT_SIZE) + ax.set_xlim([-3.5, 3.5]) plt.tight_layout() From 7b18937c711b4a3a7cf60e5a3c858e68b021472f Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 16 Jul 2018 18:56:13 -0400 Subject: [PATCH 26/38] fixed minor typos --- doc/glossary.rst | 2 +- sklearn/preprocessing/data.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/glossary.rst b/doc/glossary.rst index cea07ed1a5cfc..049139370036e 100644 --- a/doc/glossary.rst +++ b/doc/glossary.rst @@ -294,7 +294,7 @@ General Concepts convergence of the training loss, to avoid over-fitting. This is generally done by monitoring the generalization score on a validation set. When available, it is activated through the parameter - ``early_stopping`` or by setting a postive :term:`n_iter_no_change`. + ``early_stopping`` or by setting a positive :term:`n_iter_no_change`. estimator instance We sometimes use this terminology to distinguish an :term:`estimator` diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 7ff7945178582..097860d5800af 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2415,7 +2415,7 @@ class PowerTransformer(BaseEstimator, TransformerMixin): method : str, (default='yeo-johnson') The power transform method. Available methods are: - - 'yeo-johnson' [1]_, works with postive and negative values + - 'yeo-johnson' [1]_, works with positive and negative values - 'box-cox' [2]_, only works with strictly positive values standardize : boolean, default=True From 800a2c2319c3c7a2ad687ac8a7855a4954950d78 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 16 Jul 2018 19:10:38 -0400 Subject: [PATCH 27/38] Updated comment in whatsnew following ogrisel comments --- doc/whats_new/v0.20.rst | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index 6f66db20b9d4a..11ffc64674c3e 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -138,10 +138,11 @@ Preprocessing and :issue:`11315` by :user:`Thomas Fan `. - Added :class:`preprocessing.PowerTransformer`, which implements the - Yeo-Johnson and Box-Cox power transformations, allowing users to map data - from any distribution to a Gaussian distribution. This is useful as a - variance-stabilizing transformation in situations where normality and - homoscedasticity are desirable. + Yeo-Johnson and Box-Cox power transformations. Power transformations try to + find a set of feature-wise parametric transformations to approximately map + data to a Gaussian distribution centered at zero and with unit variance. + This is useful as a variance-stabilizing transformation in situations where + normality and homoscedasticity are desirable. :issue:`10210` by :user:`Eric Chang ` and :user:`Maniteja Nandana `, and :issue:`11520` by :user:`Nicolas Hug `. From 53afa9f10fa36a0d0c7081a4bf2f2dd1a55c6c1d Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 16 Jul 2018 19:15:51 -0400 Subject: [PATCH 28/38] Renamed plot example --- doc/modules/preprocessing.rst | 4 ++-- .../{map_data_to_normal.py => plot_map_data_to_normal.py} | 0 2 files changed, 2 insertions(+), 2 deletions(-) rename examples/preprocessing/{map_data_to_normal.py => plot_map_data_to_normal.py} (100%) diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst index f44d5ad748141..1d479892de826 100644 --- a/doc/modules/preprocessing.rst +++ b/doc/modules/preprocessing.rst @@ -359,8 +359,8 @@ transforms achieve very Gaussian-like results, but with others, they are ineffective. This highlights the importance of visualizing the data before and after transformation. -.. figure:: ../auto_examples/preprocessing/images/sphx_glr_map_data_to_normal_001.png - :target: ../auto_examples/preprocessing/map_data_to_normal.html +.. figure:: ../auto_examples/preprocessing/images/sphx_glr_plot_map_data_to_normal_001.png + :target: ../auto_examples/preprocessing/plot_map_data_to_normal.html :align: center :scale: 100 diff --git a/examples/preprocessing/map_data_to_normal.py b/examples/preprocessing/plot_map_data_to_normal.py similarity index 100% rename from examples/preprocessing/map_data_to_normal.py rename to examples/preprocessing/plot_map_data_to_normal.py From a0d97eea65e88dc79bf21aa649a5808197a47743 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 17 Jul 2018 11:28:17 -0400 Subject: [PATCH 29/38] Should fix test for Python 2.7 --- sklearn/preprocessing/tests/test_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 9c3cfb3c3707e..f991dab3fe533 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -2171,7 +2171,7 @@ def test_optimization_power_transformer(method, lmbda): # - check that X_inv_trans is roughly equal to X rng = np.random.RandomState(0) - n_samples = 1000 + n_samples = 20000 X = rng.normal(loc=0, scale=1, size=(n_samples, 1)) pt = PowerTransformer(method=method, standardize=False) @@ -2204,7 +2204,7 @@ def test_power_transformer_nans(method): pt.fit(X) lmbda_nans = pt.lambdas_[0] - assert_almost_equal(lmbda_no_nans, lmbda_nans, decimal=7) + assert_almost_equal(lmbda_no_nans, lmbda_nans, decimal=5) X_trans = pt.transform(X) assert_array_equal(np.isnan(X_trans), np.isnan(X)) From 23c3ddd72d5ce0a82136cf45b78debeafbdb64ba Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 17 Jul 2018 11:28:42 -0400 Subject: [PATCH 30/38] Should fix example plot --- examples/preprocessing/plot_map_data_to_normal.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/preprocessing/plot_map_data_to_normal.py b/examples/preprocessing/plot_map_data_to_normal.py index 02210e367c781..b4515d77fcc3d 100644 --- a/examples/preprocessing/plot_map_data_to_normal.py +++ b/examples/preprocessing/plot_map_data_to_normal.py @@ -47,7 +47,7 @@ N_SAMPLES = 1000 FONT_SIZE = 6 -BINS = 'auto' +BINS = 30#'auto' rng = np.random.RandomState(304) @@ -94,7 +94,7 @@ colors = ['firebrick', 'darkorange', 'goldenrod', 'seagreen', 'royalblue', 'darkorchid'] -fig, axes = plt.subplots(nrows=8, ncols=3, figsize=plt.figaspect(3)) +fig, axes = plt.subplots(nrows=8, ncols=3, figsize=plt.figaspect(2)) axes = axes.flatten() axes_idxs = [(0, 3, 6, 9), (1, 4, 7, 10), (2, 5, 8, 11), (12, 15, 18, 21), (13, 16, 19, 22), (14, 17, 20, 23)] @@ -125,13 +125,14 @@ ('Box-Cox', 'Yeo-Johnson', 'Quantile transform'), (lmbda_bc, lmbda_yj, None)): ax.hist(X_trans, color=color, bins=BINS) - title = '{} after {}'.format(name, meth_name) + title = 'After {}'.format(meth_name) if lmbda is not None: - title += ', $\lambda$ = {}'.format(lmbda) + title += '\n$\lambda$ = {}'.format(lmbda) ax.set_title(title, fontsize=FONT_SIZE) ax.tick_params(axis='both', which='major', labelsize=FONT_SIZE) ax.set_xlim([-3.5, 3.5]) plt.tight_layout() -plt.show() +plt.savefig('lol.png') +#plt.show() From 0c3b268c3a97cb4271974c8db013193a87605302 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 17 Jul 2018 12:18:50 -0400 Subject: [PATCH 31/38] Addressed comments from TomDLT --- sklearn/preprocessing/data.py | 34 ++++++++++++++---------- sklearn/preprocessing/tests/test_data.py | 2 +- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 097860d5800af..2bbb4b4b61772 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2495,31 +2495,31 @@ def fit(self, X, y=None): """ X = self._check_input(X, check_positive=True, check_method=True) - self.lambdas_ = [] - transformed = [] - optim_function = {'box-cox': self._box_cox_optimize, 'yeo-johnson': self._yeo_johnson_optimize }[self.method] - transform_function = {'box-cox': boxcox, - 'yeo-johnson': self._yeo_johnson_transform - }[self.method] - + self.lambdas_ = [] for col in X.T: with np.errstate(invalid='ignore'): # hide NaN warnings lmbda = optim_function(col) self.lambdas_.append(lmbda) - - col_trans = transform_function(col, lmbda) - transformed.append(col_trans) - self.lambdas_ = np.array(self.lambdas_) - transformed = np.array(transformed) if self.standardize: + transform_function = {'box-cox': boxcox, + 'yeo-johnson': self._yeo_johnson_transform + }[self.method] + transformed = [] + for col, lmbda in zip(X.T, self.lambdas_): + with np.errstate(invalid='ignore'): # hide NaN warnings + col_trans = transform_function(col, lmbda) + transformed.append(col_trans) + transformed = np.array(transformed) + self._scaler = StandardScaler() self._scaler.fit(X=transformed.T) + return self def transform(self, X): @@ -2529,6 +2529,11 @@ def transform(self, X): ---------- X : array-like, shape (n_samples, n_features) The data to be transformed using a power transformation. + + Returns + ------- + X_trans : array-like, shape (n_samples, n_features) + The transformed data. """ check_is_fitted(self, 'lambdas_') X = self._check_input(X, check_positive=True, check_shape=True) @@ -2688,7 +2693,7 @@ def _neg_log_likelihood(lmbda): # the computation of lambda is influenced by NaNs so we need to # get rid of them x = x[~np.isnan(x)] - # choosing backet -2, 2 like for boxcox + # choosing bracket -2, 2 like for boxcox return optimize.brent(_neg_log_likelihood, brack=(-2, 2)) def _check_input(self, X, check_positive=False, check_shape=False, @@ -2700,7 +2705,8 @@ def _check_input(self, X, check_positive=False, check_shape=False, X : array-like, shape (n_samples, n_features) check_positive : bool - If True, check that all data is positive and non-zero. + If True, check that all data is positive and non-zero (only if + self.method is box-cox). check_shape : bool If True, check that n_features matches the length of self.lambdas_ diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index f991dab3fe533..295c80733454c 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -2077,7 +2077,7 @@ def test_power_transformer_2d(): def test_power_transformer_boxcox_strictly_positive_exception(): # Exceptions should be raised for negative arrays and zero arrays when - # method is coxbox + # method is boxcox pt = PowerTransformer(method='box-cox') pt.fit(np.abs(X_2d)) From 7be037647f8d991777feae6e71f3bd588a8dc951 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 17 Jul 2018 19:31:36 +0200 Subject: [PATCH 32/38] OPTIM implement fit_transform --- sklearn/preprocessing/data.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 2bbb4b4b61772..de521eda50b32 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2477,6 +2477,13 @@ def __init__(self, method='yeo-johnson', standardize=True, copy=True): self.copy = copy def fit(self, X, y=None): + self._fit(X, y=y, force_compute_transform=False) + return self + + def fit_transform(self, X, y=None): + return self._fit(X, y, force_compute_transform=True) + + def _fit(self, X, y=None, force_compute_transform=False): """Estimate the optimal parameter lambda for each feature. The optimal lambda parameter for minimizing skewness is estimated on @@ -2505,7 +2512,7 @@ def fit(self, X, y=None): self.lambdas_.append(lmbda) self.lambdas_ = np.array(self.lambdas_) - if self.standardize: + if self.standardize or force_compute_transform: transform_function = {'box-cox': boxcox, 'yeo-johnson': self._yeo_johnson_transform }[self.method] @@ -2514,13 +2521,16 @@ def fit(self, X, y=None): with np.errstate(invalid='ignore'): # hide NaN warnings col_trans = transform_function(col, lmbda) transformed.append(col_trans) - transformed = np.array(transformed) + transformed = np.array(transformed).T + if self.standardize: self._scaler = StandardScaler() - self._scaler.fit(X=transformed.T) - + if force_compute_transform: + transformed = self._scaler.fit_transform(transformed) + else: + self._scaler.fit(X=transformed) - return self + return transformed def transform(self, X): """Apply the power transform to each feature using the fitted lambdas. From 420476c26cbc10fb506a91a508c6596adc4fc0eb Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 17 Jul 2018 14:40:44 -0400 Subject: [PATCH 33/38] Added test fit_transform() == fit().transform() --- sklearn/preprocessing/data.py | 21 +++++++++++++-------- sklearn/preprocessing/tests/test_data.py | 12 ++++++++++++ 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index de521eda50b32..e4c1d0cf44e87 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2477,13 +2477,6 @@ def __init__(self, method='yeo-johnson', standardize=True, copy=True): self.copy = copy def fit(self, X, y=None): - self._fit(X, y=y, force_compute_transform=False) - return self - - def fit_transform(self, X, y=None): - return self._fit(X, y, force_compute_transform=True) - - def _fit(self, X, y=None, force_compute_transform=False): """Estimate the optimal parameter lambda for each feature. The optimal lambda parameter for minimizing skewness is estimated on @@ -2500,6 +2493,13 @@ def _fit(self, X, y=None, force_compute_transform=False): ------- self : object """ + self._fit(X, y=y, force_compute_transform=False) + return self + + def fit_transform(self, X, y=None): + return self._fit(X, y, force_compute_transform=True) + + def _fit(self, X, y=None, force_compute_transform=False): X = self._check_input(X, check_positive=True, check_method=True) optim_function = {'box-cox': self._box_cox_optimize, @@ -2512,11 +2512,11 @@ def _fit(self, X, y=None, force_compute_transform=False): self.lambdas_.append(lmbda) self.lambdas_ = np.array(self.lambdas_) + transformed = [] if self.standardize or force_compute_transform: transform_function = {'box-cox': boxcox, 'yeo-johnson': self._yeo_johnson_transform }[self.method] - transformed = [] for col, lmbda in zip(X.T, self.lambdas_): with np.errstate(invalid='ignore'): # hide NaN warnings col_trans = transform_function(col, lmbda) @@ -2585,6 +2585,11 @@ def inverse_transform(self, X): ---------- X : array-like, shape (n_samples, n_features) The transformed data. + + Returns + ------- + X : array-like, shape (n_samples, n_features) + The original data """ check_is_fitted(self, 'lambdas_') X = self._check_input(X, check_shape=True) diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 295c80733454c..0ec3ac60bced8 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -2208,3 +2208,15 @@ def test_power_transformer_nans(method): X_trans = pt.transform(X) assert_array_equal(np.isnan(X_trans), np.isnan(X)) + + +@pytest.mark.parametrize('method', ['box-cox', 'yeo-johnson']) +@pytest.mark.parametrize('standardize', [True, False]) +def test_power_transformer_fit_transform(method, standardize): + # check that fit_transform() and fit().transform() return the same values + X = X_1col + if method == 'box-cox': + X = np.abs(X) + + pt = PowerTransformer(method, standardize) + assert_array_almost_equal(pt.fit(X).transform(X), pt.fit_transform(X)) From 1287f948fe4edbb67c012c94f7a7ae54214218e9 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 17 Jul 2018 15:53:57 -0400 Subject: [PATCH 34/38] Added tests for the copy parameter --- sklearn/preprocessing/data.py | 28 ++++++------ sklearn/preprocessing/tests/test_data.py | 58 ++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 14 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index e4c1d0cf44e87..68f359cb35088 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2493,15 +2493,18 @@ def fit(self, X, y=None): ------- self : object """ - self._fit(X, y=y, force_compute_transform=False) + self._fit(X, y=y, force_transform=False) return self def fit_transform(self, X, y=None): - return self._fit(X, y, force_compute_transform=True) + return self._fit(X, y, force_transform=True) - def _fit(self, X, y=None, force_compute_transform=False): + def _fit(self, X, y=None, force_transform=False): X = self._check_input(X, check_positive=True, check_method=True) + if not self.copy and not force_transform: # if call from fit() + X = X.copy() # force copy so that fit does not change X inplace + optim_function = {'box-cox': self._box_cox_optimize, 'yeo-johnson': self._yeo_johnson_optimize }[self.method] @@ -2512,25 +2515,22 @@ def _fit(self, X, y=None, force_compute_transform=False): self.lambdas_.append(lmbda) self.lambdas_ = np.array(self.lambdas_) - transformed = [] - if self.standardize or force_compute_transform: + if self.standardize or force_transform: transform_function = {'box-cox': boxcox, 'yeo-johnson': self._yeo_johnson_transform }[self.method] - for col, lmbda in zip(X.T, self.lambdas_): + for i, lmbda in enumerate(self.lambdas_): with np.errstate(invalid='ignore'): # hide NaN warnings - col_trans = transform_function(col, lmbda) - transformed.append(col_trans) - transformed = np.array(transformed).T + X[:, i] = transform_function(X[:, i], lmbda) if self.standardize: - self._scaler = StandardScaler() - if force_compute_transform: - transformed = self._scaler.fit_transform(transformed) + self._scaler = StandardScaler(copy=self.copy) + if force_transform: + X = self._scaler.fit_transform(X) else: - self._scaler.fit(X=transformed) + self._scaler.fit(X) - return transformed + return X def transform(self, X): """Apply the power transform to each feature using the fitted lambdas. diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 0ec3ac60bced8..f5ea7a9dd8edc 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -2220,3 +2220,61 @@ def test_power_transformer_fit_transform(method, standardize): pt = PowerTransformer(method, standardize) assert_array_almost_equal(pt.fit(X).transform(X), pt.fit_transform(X)) + + +@pytest.mark.parametrize('method', ['box-cox', 'yeo-johnson']) +@pytest.mark.parametrize('standardize', [True, False]) +def test_power_transformer_copy_True(method, standardize): + # Check that neither fit, transform, fit_transform nor inverse_transform + # modify X inplace when copy=True + X = X_1col + if method == 'box-cox': + X = np.abs(X) + + X_original = X.copy() + assert X is not X_original # sanity checks + assert_array_almost_equal(X, X_original) + + pt = PowerTransformer(method, standardize, copy=True) + + pt.fit(X) + assert_array_almost_equal(X, X_original) + X_trans = pt.transform(X) + assert X_trans is not X + + X_trans = pt.fit_transform(X) + assert_array_almost_equal(X, X_original) + assert X_trans is not X + + X_inv_trans = pt.inverse_transform(X_trans) + assert X_trans is not X_inv_trans + + +@pytest.mark.parametrize('method', ['box-cox', 'yeo-johnson']) +@pytest.mark.parametrize('standardize', [True, False]) +def test_power_transformer_copy_False(method, standardize): + # check that when copy=False fit doesn't change X inplace but transform, + # fit_transform and inverse_transform do. + X = X_1col + if method == 'box-cox': + X = np.abs(X) + + X_original = X.copy() + assert X is not X_original # sanity checks + assert_array_almost_equal(X, X_original) + + pt = PowerTransformer(method, standardize, copy=False) + + pt.fit(X) + assert_array_almost_equal(X, X_original) # fit didn't change X + + X_trans = pt.transform(X) + assert X_trans is X + + if method == 'box-cox': + X = np.abs(X) + X_trans = pt.fit_transform(X) + assert X_trans is X + + X_inv_trans = pt.inverse_transform(X_trans) + assert X_trans is X_inv_trans From 0ce4b3633aab62e309616eb5c9baa9ee277677e5 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 17 Jul 2018 15:54:09 -0400 Subject: [PATCH 35/38] Fixed flake8 issues in example plot --- examples/preprocessing/plot_map_data_to_normal.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/preprocessing/plot_map_data_to_normal.py b/examples/preprocessing/plot_map_data_to_normal.py index b4515d77fcc3d..b8b7625f3c02b 100644 --- a/examples/preprocessing/plot_map_data_to_normal.py +++ b/examples/preprocessing/plot_map_data_to_normal.py @@ -47,7 +47,7 @@ N_SAMPLES = 1000 FONT_SIZE = 6 -BINS = 30#'auto' +BINS = 30 rng = np.random.RandomState(304) @@ -134,5 +134,4 @@ plt.tight_layout() -plt.savefig('lol.png') -#plt.show() +plt.show() From 597a85dcced70489e20dff381ea11e99494aea1a Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 17 Jul 2018 16:32:10 -0400 Subject: [PATCH 36/38] set copy to False for the scaler --- sklearn/preprocessing/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 68f359cb35088..0ae8ebb80ee92 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2524,7 +2524,7 @@ def _fit(self, X, y=None, force_transform=False): X[:, i] = transform_function(X[:, i], lmbda) if self.standardize: - self._scaler = StandardScaler(copy=self.copy) + self._scaler = StandardScaler(copy=False) if force_transform: X = self._scaler.fit_transform(X) else: From 8022cc3a89390348664a2284ce6b5d9127d2d76f Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 17 Jul 2018 18:06:12 -0400 Subject: [PATCH 37/38] Addressed comments from glemaitre --- sklearn/preprocessing/data.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 0ae8ebb80ee92..1273050d70cad 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2621,8 +2621,8 @@ def _yeo_johnson_inverse_transform(self, x, lmbda): """Return inverse-transformed input x following Yeo-Johnson inverse transform with parameter lambda. - Note - ---- + Notes + ----- We're comparing lmbda to 1e-19 instead of strict equality to 0. See scipy/special/_boxcox.pxd for a rationale behind this """ @@ -2648,8 +2648,8 @@ def _yeo_johnson_transform(self, x, lmbda): """Return transformed input x following Yeo-Johnson transform with parameter lambda. - Note - ---- + Notes + ----- We're comparing lmbda to 1e-19 instead of strict equality to 0. See scipy/special/_boxcox.pxd for a rationale behind this """ @@ -2721,7 +2721,7 @@ def _check_input(self, X, check_positive=False, check_shape=False, check_positive : bool If True, check that all data is positive and non-zero (only if - self.method is box-cox). + ``self.method=='box-cox'``). check_shape : bool If True, check that n_features matches the length of self.lambdas_ From c0a01dfc67f54afb494995f098146f85c453ec5f Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 20 Jul 2018 13:23:42 -0400 Subject: [PATCH 38/38] Updated plot_all_scaling.py example --- examples/preprocessing/plot_all_scaling.py | 32 ++++++++++++---------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/examples/preprocessing/plot_all_scaling.py b/examples/preprocessing/plot_all_scaling.py index 92cd635e2a06d..07fd3662da448 100755 --- a/examples/preprocessing/plot_all_scaling.py +++ b/examples/preprocessing/plot_all_scaling.py @@ -87,6 +87,8 @@ MaxAbsScaler().fit_transform(X)), ('Data after robust scaling', RobustScaler(quantile_range=(25, 75)).fit_transform(X)), + ('Data after power transformation (Yeo-Johnson)', + PowerTransformer(method='yeo-johnson').fit_transform(X)), ('Data after power transformation (Box-Cox)', PowerTransformer(method='box-cox').fit_transform(X)), ('Data after quantile transformation (gaussian pdf)', @@ -294,21 +296,21 @@ def make_plot(item_idx): make_plot(4) ############################################################################## -# PowerTransformer (Box-Cox) -# -------------------------- +# PowerTransformer +# ---------------- # -# ``PowerTransformer`` applies a power transformation to each -# feature to make the data more Gaussian-like. Currently, -# ``PowerTransformer`` implements the Box-Cox transform. The Box-Cox transform -# finds the optimal scaling factor to stabilize variance and mimimize skewness -# through maximum likelihood estimation. By default, ``PowerTransformer`` also -# applies zero-mean, unit variance normalization to the transformed output. -# Note that Box-Cox can only be applied to positive, non-zero data. Income and -# number of households happen to be strictly positive, but if negative values -# are present, a constant can be added to each feature to shift it into the -# positive range - this is known as the two-parameter Box-Cox transform. +# ``PowerTransformer`` applies a power transformation to each feature to make +# the data more Gaussian-like. Currently, ``PowerTransformer`` implements the +# Yeo-Johnson and Box-Cox transforms. The power transform finds the optimal +# scaling factor to stabilize variance and mimimize skewness through maximum +# likelihood estimation. By default, ``PowerTransformer`` also applies +# zero-mean, unit variance normalization to the transformed output. Note that +# Box-Cox can only be applied to strictly positive data. Income and number of +# households happen to be strictly positive, but if negative values are present +# the Yeo-Johnson transformed is to be preferred. make_plot(5) +make_plot(6) ############################################################################## # QuantileTransformer (Gaussian output) @@ -319,7 +321,7 @@ def make_plot(item_idx): # Note that this non-parametetric transformer introduces saturation artifacts # for extreme values. -make_plot(6) +make_plot(7) ################################################################### # QuantileTransformer (uniform output) @@ -337,7 +339,7 @@ def make_plot(item_idx): # any outlier by setting them to the a priori defined range boundaries (0 and # 1). -make_plot(7) +make_plot(8) ############################################################################## # Normalizer @@ -350,6 +352,6 @@ def make_plot(item_idx): # transformed data only lie in the positive quadrant. This would not be the # case if some original features had a mix of positive and negative values. -make_plot(8) +make_plot(9) plt.show()