diff --git a/doc/developers/performance.rst b/doc/developers/performance.rst index 36419894eafd6..c6fcc99b26102 100644 --- a/doc/developers/performance.rst +++ b/doc/developers/performance.rst @@ -341,9 +341,16 @@ memory alignment, direct blas calls... Using OpenMP ------------ -Since scikit-learn can be built without OpenMP, it's necessary to -protect each direct call to OpenMP. This can be done using the following -syntax:: +Since scikit-learn can be built without OpenMP, it's necessary to protect each +direct call to OpenMP. + +There are some helpers in +`sklearn/utils/_openmp_helpers.pyx <https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/utils/_openmp_helpers.pyx>`_ +that you can reuse for the main useful functionalities and that already have the +necessary protection to be built without OpenMP. + +If the helpers are not enough, you need to protect your OpenMP code using the +following syntax:: # importing OpenMP IF SKLEARN_OPENMP_PARALLELISM_ENABLED: diff --git a/sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp index d03c7e5fa0b2a..c1bade148c988 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp @@ -1,5 +1,4 @@ cimport numpy as cnp -cimport openmp from cython cimport final from cython.operator cimport dereference as deref @@ -73,7 +72,8 @@ cpdef DTYPE_t[::1] _sqeuclidean_row_norms32( ) with nogil, parallel(num_threads=num_threads): - thread_num = openmp.omp_get_thread_num() + thread_num = _openmp_thread_num() + for i in prange(n, schedule='static'): # Upcasting the i-th row of X from float32 to float64 for j in range(d):