|
18 | 18 | import numpy as np
|
19 | 19 | import scipy.sparse as sp
|
20 | 20 |
|
21 |
| -from .. import _threadpool_controller |
22 | 21 | from ..base import (
|
23 | 22 | BaseEstimator,
|
24 | 23 | ClassNamePrefixFeaturesOutMixin,
|
|
32 | 31 | from ..utils._openmp_helpers import _openmp_effective_n_threads
|
33 | 32 | from ..utils._param_validation import Interval, StrOptions, validate_params
|
34 | 33 | from ..utils.extmath import row_norms, stable_cumsum
|
| 34 | +from ..utils.parallel import ( |
| 35 | + _get_threadpool_controller, |
| 36 | + _threadpool_controller_decorator, |
| 37 | +) |
35 | 38 | from ..utils.sparsefuncs import mean_variance_axis
|
36 | 39 | from ..utils.sparsefuncs_fast import assign_rows_csr
|
37 | 40 | from ..utils.validation import (
|
@@ -624,7 +627,7 @@ def _kmeans_single_elkan(
|
624 | 627 |
|
625 | 628 | # Threadpoolctl context to limit the number of threads in second level of
|
626 | 629 | # nested parallelism (i.e. BLAS) to avoid oversubscription.
|
627 |
| -@_threadpool_controller.wrap(limits=1, user_api="blas") |
| 630 | +@_threadpool_controller_decorator(limits=1, user_api="blas") |
628 | 631 | def _kmeans_single_lloyd(
|
629 | 632 | X,
|
630 | 633 | sample_weight,
|
@@ -827,7 +830,7 @@ def _labels_inertia(X, sample_weight, centers, n_threads=1, return_inertia=True)
|
827 | 830 |
|
828 | 831 |
|
829 | 832 | # Same as _labels_inertia but in a threadpool_limits context.
|
830 |
| -_labels_inertia_threadpool_limit = _threadpool_controller.wrap( |
| 833 | +_labels_inertia_threadpool_limit = _threadpool_controller_decorator( |
831 | 834 | limits=1, user_api="blas"
|
832 | 835 | )(_labels_inertia)
|
833 | 836 |
|
@@ -922,7 +925,7 @@ def _check_mkl_vcomp(self, X, n_samples):
|
922 | 925 |
|
923 | 926 | n_active_threads = int(np.ceil(n_samples / CHUNK_SIZE))
|
924 | 927 | if n_active_threads < self._n_threads:
|
925 |
| - modules = _threadpool_controller.info() |
| 928 | + modules = _get_threadpool_controller().info() |
926 | 929 | has_vcomp = "vcomp" in [module["prefix"] for module in modules]
|
927 | 930 | has_mkl = ("mkl", "intel") in [
|
928 | 931 | (module["internal_api"], module.get("threading_layer", None))
|
@@ -2144,7 +2147,7 @@ def fit(self, X, y=None, sample_weight=None):
|
2144 | 2147 |
|
2145 | 2148 | n_steps = (self.max_iter * n_samples) // self._batch_size
|
2146 | 2149 |
|
2147 |
| - with _threadpool_controller.limit(limits=1, user_api="blas"): |
| 2150 | + with _get_threadpool_controller().limit(limits=1, user_api="blas"): |
2148 | 2151 | # Perform the iterative optimization until convergence
|
2149 | 2152 | for i in range(n_steps):
|
2150 | 2153 | # Sample a minibatch from the full dataset
|
@@ -2270,7 +2273,7 @@ def partial_fit(self, X, y=None, sample_weight=None):
|
2270 | 2273 | # Initialize number of samples seen since last reassignment
|
2271 | 2274 | self._n_since_last_reassign = 0
|
2272 | 2275 |
|
2273 |
| - with _threadpool_controller.limit(limits=1, user_api="blas"): |
| 2276 | + with _get_threadpool_controller().limit(limits=1, user_api="blas"): |
2274 | 2277 | _mini_batch_step(
|
2275 | 2278 | X,
|
2276 | 2279 | sample_weight=sample_weight,
|
|
0 commit comments