1313
1414import numpy as np
1515from scipy import linalg
16-
1716from ..base import BaseEstimator , TransformerMixin , _ClassNamePrefixFeaturesOutMixin
1817from ..exceptions import ConvergenceWarning
1918
@@ -162,10 +161,12 @@ def fastica(
162161 max_iter = 200 ,
163162 tol = 1e-04 ,
164163 w_init = None ,
164+ whiten_solver = "svd" ,
165165 random_state = None ,
166166 return_X_mean = False ,
167167 compute_sources = True ,
168168 return_n_iter = False ,
169+ sign_flip = False ,
169170):
170171 """Perform Fast Independent Component Analysis.
171172
@@ -228,6 +229,18 @@ def my_g(x):
228229 Initial un-mixing array. If `w_init=None`, then an array of values
229230 drawn from a normal distribution is used.
230231
232+ whiten_solver : {"eigh", "svd"}, default="svd"
233+ The solver to use for whitening.
234+
235+ - "svd" is more stable numerically if the problem is degenerate, and
236+ often faster when `n_samples <= n_features`.
237+
238+ - "eigh" is generally more memory efficient when
239+ `n_samples >= n_features`, and can be faster when
240+ `n_samples >= 50 * n_features`.
241+
242+ .. versionadded:: 1.2
243+
231244 random_state : int, RandomState instance or None, default=None
232245 Used to initialize ``w_init`` when not specified, with a
233246 normal distribution. Pass an int, for reproducible results
@@ -244,6 +257,21 @@ def my_g(x):
244257 return_n_iter : bool, default=False
245258 Whether or not to return the number of iterations.
246259
260+ sign_flip : bool, default=False
261+ Used to determine whether to enable sign flipping during whitening for
262+ consistency in output between solvers.
263+
264+ - If `sign_flip=False` then the output of different choices for
265+ `whiten_solver` may not be equal. Both outputs will still be correct,
266+ but may differ numerically.
267+
268+ - If `sign_flip=True` then the output of both solvers will be
269+ reconciled during fit so that their outputs match. This may produce
270+ a different output for each solver when compared to
271+ `sign_flip=False`.
272+
273+ .. versionadded:: 1.2
274+
247275 Returns
248276 -------
249277 K : ndarray of shape (n_components, n_features) or None
@@ -300,7 +328,9 @@ def my_g(x):
300328 max_iter = max_iter ,
301329 tol = tol ,
302330 w_init = w_init ,
331+ whiten_solver = whiten_solver ,
303332 random_state = random_state ,
333+ sign_flip = sign_flip ,
304334 )
305335 S = est ._fit (X , compute_sources = compute_sources )
306336
@@ -378,12 +408,39 @@ def my_g(x):
378408 Initial un-mixing array. If `w_init=None`, then an array of values
379409 drawn from a normal distribution is used.
380410
411+ whiten_solver : {"eigh", "svd"}, default="svd"
412+ The solver to use for whitening.
413+
414+ - "svd" is more stable numerically if the problem is degenerate, and
415+ often faster when `n_samples <= n_features`.
416+
417+ - "eigh" is generally more memory efficient when
418+ `n_samples >= n_features`, and can be faster when
419+ `n_samples >= 50 * n_features`.
420+
421+ .. versionadded:: 1.2
422+
381423 random_state : int, RandomState instance or None, default=None
382424 Used to initialize ``w_init`` when not specified, with a
383425 normal distribution. Pass an int, for reproducible results
384426 across multiple function calls.
385427 See :term:`Glossary <random_state>`.
386428
429+ sign_flip : bool, default=False
430+ Used to determine whether to enable sign flipping during whitening for
431+ consistency in output between solvers.
432+
433+ - If `sign_flip=False` then the output of different choices for
434+ `whiten_solver` may not be equal. Both outputs will still be correct,
435+ but may differ numerically.
436+
437+ - If `sign_flip=True` then the output of both solvers will be
438+ reconciled during fit so that their outputs match. This may produce
439+ a different output for each solver when compared to
440+ `sign_flip=False`.
441+
442+ .. versionadded:: 1.2
443+
387444 Attributes
388445 ----------
389446 components_ : ndarray of shape (n_components, n_features)
@@ -457,7 +514,9 @@ def __init__(
457514 max_iter = 200 ,
458515 tol = 1e-4 ,
459516 w_init = None ,
517+ whiten_solver = "svd" ,
460518 random_state = None ,
519+ sign_flip = False ,
461520 ):
462521 super ().__init__ ()
463522 self .n_components = n_components
@@ -468,7 +527,9 @@ def __init__(
468527 self .max_iter = max_iter
469528 self .tol = tol
470529 self .w_init = w_init
530+ self .whiten_solver = whiten_solver
471531 self .random_state = random_state
532+ self .sign_flip = sign_flip
472533
473534 def _fit (self , X , compute_sources = False ):
474535 """Fit the model.
@@ -557,9 +618,33 @@ def g(x, fun_args):
557618 XT -= X_mean [:, np .newaxis ]
558619
559620 # Whitening and preprocessing by PCA
560- u , d , _ = linalg .svd (XT , full_matrices = False , check_finite = False )
621+ if self .whiten_solver == "eigh" :
622+ # Faster when n_samples >> n_features
623+ d , u = linalg .eigh (XT .dot (X ))
624+ sort_indices = np .argsort (d )[::- 1 ]
625+ eps = np .finfo (d .dtype ).eps
626+ degenerate_idx = d < eps
627+ if np .any (degenerate_idx ):
628+ warnings .warn (
629+ "There are some small singular values, using "
630+ "whiten_solver = 'svd' might lead to more "
631+ "accurate results."
632+ )
633+ d [degenerate_idx ] = eps # For numerical issues
634+ np .sqrt (d , out = d )
635+ d , u = d [sort_indices ], u [:, sort_indices ]
636+ elif self .whiten_solver == "svd" :
637+ u , d = linalg .svd (XT , full_matrices = False , check_finite = False )[:2 ]
638+ else :
639+ raise ValueError (
640+ "`whiten_solver` must be 'eigh' or 'svd' but got"
641+ f" { self .whiten_solver } instead"
642+ )
643+
644+ # Give consistent eigenvectors for both whitening solvers ("eigh" and "svd")
645+ if self .sign_flip :
646+ u *= np .sign (u [0 ])
561647
562- del _
563648 K = (u / d ).T [:n_components ] # see (6.33) p.140
564649 del u , d
565650 X1 = np .dot (K , XT )
0 commit comments