diff --git a/setup.py b/setup.py index a6b6e33afd7b0..3e2887cbea89e 100755 --- a/setup.py +++ b/setup.py @@ -200,7 +200,7 @@ def build_extensions(self): print(f"Using old NumPy C API (version 1.7) for extension {ext.name}") if sklearn._OPENMP_SUPPORTED: - openmp_flag = get_openmp_flag(self.compiler) + openmp_flag = get_openmp_flag() for e in self.extensions: e.extra_compile_args += openmp_flag diff --git a/sklearn/_build_utils/openmp_helpers.py b/sklearn/_build_utils/openmp_helpers.py index b89d8e97f95c6..ed9bf0ea3eea0 100644 --- a/sklearn/_build_utils/openmp_helpers.py +++ b/sklearn/_build_utils/openmp_helpers.py @@ -12,12 +12,7 @@ from .pre_build_helpers import compile_test_program -def get_openmp_flag(compiler): - if hasattr(compiler, "compiler"): - compiler = compiler.compiler[0] - else: - compiler = compiler.__class__.__name__ - +def get_openmp_flag(): if sys.platform == "win32": return ["/openmp"] elif sys.platform == "darwin" and "openmp" in os.getenv("CPPFLAGS", ""): @@ -66,7 +61,7 @@ def check_openmp_support(): if flag.startswith(("-L", "-Wl,-rpath", "-l", "-Wl,--sysroot=/")) ] - extra_postargs = get_openmp_flag + extra_postargs = get_openmp_flag() openmp_exception = None try: diff --git a/sklearn/_build_utils/pre_build_helpers.py b/sklearn/_build_utils/pre_build_helpers.py index 9068390f2afad..2c0e5ef3ada47 100644 --- a/sklearn/_build_utils/pre_build_helpers.py +++ b/sklearn/_build_utils/pre_build_helpers.py @@ -10,18 +10,11 @@ from setuptools.command.build_ext import customize_compiler, new_compiler -def compile_test_program(code, extra_preargs=[], extra_postargs=[]): +def compile_test_program(code, extra_preargs=None, extra_postargs=None): """Check that some C code can be compiled and run""" ccompiler = new_compiler() customize_compiler(ccompiler) - # extra_(pre/post)args can be a callable to make it possible to get its - # value from the compiler - if callable(extra_preargs): - extra_preargs = extra_preargs(ccompiler) - if callable(extra_postargs): - extra_postargs = extra_postargs(ccompiler) - start_dir = os.path.abspath(".") with tempfile.TemporaryDirectory() as tmp_dir: