Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 66ec10f

Browse files
MAINT Clean up code in FastICA (#19796)
* FIX code cleanup in FastICA * keep syntax X_mean Co-authored-by: Thomas J. Fan <[email protected]>
1 parent 5d8796b commit 66ec10f

File tree

1 file changed

+12
-13
lines changed

1 file changed

+12
-13
lines changed

sklearn/decomposition/_fastica.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -427,9 +427,8 @@ def _fit(self, X, compute_sources=False):
427427
-------
428428
X_new : ndarray of shape (n_samples, n_components)
429429
"""
430-
431-
X = self._validate_data(X, copy=self.whiten, dtype=FLOAT_DTYPES,
432-
ensure_min_samples=2).T
430+
XT = self._validate_data(X, copy=self.whiten, dtype=FLOAT_DTYPES,
431+
ensure_min_samples=2).T
433432
fun_args = {} if self.fun_args is None else self.fun_args
434433
random_state = check_random_state(self.random_state)
435434

@@ -454,7 +453,7 @@ def g(x, fun_args):
454453
% self.fun
455454
)
456455

457-
n_samples, n_features = X.shape
456+
n_features, n_samples = XT.shape
458457

459458
n_components = self.n_components
460459
if not self.whiten and n_components is not None:
@@ -471,24 +470,24 @@ def g(x, fun_args):
471470
)
472471

473472
if self.whiten:
474-
# Centering the columns (ie the variables)
475-
X_mean = X.mean(axis=-1)
476-
X -= X_mean[:, np.newaxis]
473+
# Centering the features of X
474+
X_mean = XT.mean(axis=-1)
475+
XT -= X_mean[:, np.newaxis]
477476

478477
# Whitening and preprocessing by PCA
479-
u, d, _ = linalg.svd(X, full_matrices=False, check_finite=False)
478+
u, d, _ = linalg.svd(XT, full_matrices=False, check_finite=False)
480479

481480
del _
482481
K = (u / d).T[:n_components] # see (6.33) p.140
483482
del u, d
484-
X1 = np.dot(K, X)
483+
X1 = np.dot(K, XT)
485484
# see (13.6) p.267 Here X1 is white and data
486485
# in X has been projected onto a subspace by PCA
487-
X1 *= np.sqrt(n_features)
486+
X1 *= np.sqrt(n_samples)
488487
else:
489488
# X must be casted to floats to avoid typing issues with numpy
490489
# 2.0 and the line below
491-
X1 = as_float_array(X, copy=False) # copy has been taken care of
490+
X1 = as_float_array(XT, copy=False) # copy has been taken care of
492491

493492
w_init = self.w_init
494493
if w_init is None:
@@ -519,9 +518,9 @@ def g(x, fun_args):
519518

520519
if compute_sources:
521520
if self.whiten:
522-
S = np.linalg.multi_dot([W, K, X]).T
521+
S = np.linalg.multi_dot([W, K, XT]).T
523522
else:
524-
S = np.dot(W, X).T
523+
S = np.dot(W, XT).T
525524
else:
526525
S = None
527526

0 commit comments

Comments
 (0)