|
6 | 6 | import os
|
7 | 7 | from typing import Optional
|
8 | 8 |
|
9 |
| -from cuda.pathfinder._dynamic_libs.supported_nvidia_libs import IS_WINDOWS |
| 9 | +from cuda.pathfinder._headers import supported_nvidia_headers |
| 10 | +from cuda.pathfinder._headers.supported_nvidia_headers import IS_WINDOWS |
| 11 | +from cuda.pathfinder._utils.env_vars import get_cuda_home_or_path |
10 | 12 | from cuda.pathfinder._utils.find_sub_dirs import find_sub_dirs_all_sitepackages
|
11 | 13 |
|
12 | 14 |
|
13 |
| -@functools.cache |
14 |
| -def find_nvidia_header_directory(libname: str) -> Optional[str]: |
15 |
| - if libname != "nvshmem": |
16 |
| - raise RuntimeError(f"UNKNOWN {libname=}") |
| 15 | +def _abs_norm(path: Optional[str]) -> Optional[str]: |
| 16 | + if path: |
| 17 | + return os.path.normpath(os.path.abspath(path)) |
| 18 | + return None |
| 19 | + |
| 20 | + |
| 21 | +def _joined_isfile(dirpath: str, basename: str) -> bool: |
| 22 | + return os.path.isfile(os.path.join(dirpath, basename)) |
17 | 23 |
|
18 |
| - if libname == "nvshmem" and IS_WINDOWS: |
| 24 | + |
| 25 | +def _find_nvshmem_header_directory() -> Optional[str]: |
| 26 | + if IS_WINDOWS: |
19 | 27 | # nvshmem has no Windows support.
|
20 | 28 | return None
|
21 | 29 |
|
22 | 30 | # Installed from a wheel
|
23 | 31 | nvidia_sub_dirs = ("nvidia", "nvshmem", "include")
|
24 | 32 | hdr_dir: str # help mypy
|
25 | 33 | for hdr_dir in find_sub_dirs_all_sitepackages(nvidia_sub_dirs):
|
26 |
| - nvshmem_h_path = os.path.join(hdr_dir, "nvshmem.h") |
27 |
| - if os.path.isfile(nvshmem_h_path): |
| 34 | + if _joined_isfile(hdr_dir, "nvshmem.h"): |
28 | 35 | return hdr_dir
|
29 | 36 |
|
30 | 37 | conda_prefix = os.environ.get("CONDA_PREFIX")
|
31 | 38 | if conda_prefix and os.path.isdir(conda_prefix):
|
32 | 39 | hdr_dir = os.path.join(conda_prefix, "include")
|
33 |
| - nvshmem_h_path = os.path.join(hdr_dir, "nvshmem.h") |
34 |
| - if os.path.isfile(nvshmem_h_path): |
| 40 | + if _joined_isfile(hdr_dir, "nvshmem.h"): |
35 | 41 | return hdr_dir
|
36 | 42 |
|
37 | 43 | for hdr_dir in sorted(glob.glob("/usr/include/nvshmem_*"), reverse=True):
|
38 |
| - nvshmem_h_path = os.path.join(hdr_dir, "nvshmem.h") |
39 |
| - if os.path.isfile(nvshmem_h_path): |
| 44 | + if _joined_isfile(hdr_dir, "nvshmem.h"): |
40 | 45 | return hdr_dir
|
41 | 46 |
|
42 | 47 | return None
|
| 48 | + |
| 49 | + |
| 50 | +def _find_based_on_ctk_layout(libname: str, h_basename: str, anchor_point: str) -> Optional[str]: |
| 51 | + parts = [anchor_point] |
| 52 | + if libname == "nvvm": |
| 53 | + parts.append(libname) |
| 54 | + parts.append("include") |
| 55 | + idir = os.path.join(*parts) |
| 56 | + if libname == "cccl": |
| 57 | + cdir = os.path.join(idir, "cccl") # CTK 13 |
| 58 | + if _joined_isfile(cdir, h_basename): |
| 59 | + return cdir |
| 60 | + if _joined_isfile(idir, h_basename): |
| 61 | + return idir |
| 62 | + return None |
| 63 | + |
| 64 | + |
| 65 | +def _find_based_on_conda_layout(libname: str, h_basename: str, conda_prefix: str) -> Optional[str]: |
| 66 | + if IS_WINDOWS: |
| 67 | + anchor_point = os.path.join(conda_prefix, "Library") |
| 68 | + if not os.path.isdir(anchor_point): |
| 69 | + return None |
| 70 | + else: |
| 71 | + targets_include_path = glob.glob(os.path.join(conda_prefix, "targets", "*", "include")) |
| 72 | + if not targets_include_path: |
| 73 | + return None |
| 74 | + if len(targets_include_path) != 1: |
| 75 | + # Conda does not support multiple architectures. |
| 76 | + # QUESTION(PR#956): Do we want to issue a warning? |
| 77 | + return None |
| 78 | + anchor_point = os.path.dirname(targets_include_path[0]) |
| 79 | + return _find_based_on_ctk_layout(libname, h_basename, anchor_point) |
| 80 | + |
| 81 | + |
| 82 | +def _find_ctk_header_directory(libname: str) -> Optional[str]: |
| 83 | + h_basename = supported_nvidia_headers.SUPPORTED_HEADERS_CTK[libname] |
| 84 | + candidate_dirs = supported_nvidia_headers.SUPPORTED_SITE_PACKAGE_HEADER_DIRS_CTK[libname] |
| 85 | + |
| 86 | + # Installed from a wheel |
| 87 | + for cdir in candidate_dirs: |
| 88 | + hdr_dir: str # help mypy |
| 89 | + for hdr_dir in find_sub_dirs_all_sitepackages(tuple(cdir.split("/"))): |
| 90 | + if _joined_isfile(hdr_dir, h_basename): |
| 91 | + return hdr_dir |
| 92 | + |
| 93 | + conda_prefix = os.getenv("CONDA_PREFIX") |
| 94 | + if conda_prefix: # noqa: SIM102 |
| 95 | + if result := _find_based_on_conda_layout(libname, h_basename, conda_prefix): |
| 96 | + return result |
| 97 | + |
| 98 | + cuda_home = get_cuda_home_or_path() |
| 99 | + if cuda_home: # noqa: SIM102 |
| 100 | + if result := _find_based_on_ctk_layout(libname, h_basename, cuda_home): |
| 101 | + return result |
| 102 | + |
| 103 | + return None |
| 104 | + |
| 105 | + |
| 106 | +@functools.cache |
| 107 | +def find_nvidia_header_directory(libname: str) -> Optional[str]: |
| 108 | + """Locate the header directory for a supported NVIDIA library. |
| 109 | +
|
| 110 | + Args: |
| 111 | + libname (str): The short name of the library whose headers are needed |
| 112 | + (e.g., ``"nvrtc"``, ``"cusolver"``, ``"nvshmem"``). |
| 113 | +
|
| 114 | + Returns: |
| 115 | + str or None: Absolute path to the discovered header directory, or ``None`` |
| 116 | + if the headers cannot be found. |
| 117 | +
|
| 118 | + Raises: |
| 119 | + RuntimeError: If ``libname`` is not in the supported set. |
| 120 | +
|
| 121 | + Search order: |
| 122 | + 1. **NVIDIA Python wheels** |
| 123 | +
|
| 124 | + - Scan installed distributions (``site-packages``) for header layouts |
| 125 | + shipped in NVIDIA wheels (e.g., ``cuda-toolkit[nvrtc]``). |
| 126 | +
|
| 127 | + 2. **Conda environments** |
| 128 | +
|
| 129 | + - Check Conda-style installation prefixes, which use platform-specific |
| 130 | + include directory layouts. |
| 131 | +
|
| 132 | + 3. **CUDA Toolkit environment variables** |
| 133 | +
|
| 134 | + - Use ``CUDA_HOME`` or ``CUDA_PATH`` (in that order). |
| 135 | +
|
| 136 | + Notes: |
| 137 | + - The ``SUPPORTED_HEADERS_CTK`` dictionary maps each supported CUDA Toolkit |
| 138 | + (CTK) library to the name of its canonical header (e.g., ``"cublas" → |
| 139 | + "cublas.h"``). This is used to verify that the located directory is valid. |
| 140 | +
|
| 141 | + - The only supported non-CTK library at present is ``nvshmem``. |
| 142 | + """ |
| 143 | + |
| 144 | + if libname == "nvshmem": |
| 145 | + return _abs_norm(_find_nvshmem_header_directory()) |
| 146 | + |
| 147 | + if libname in supported_nvidia_headers.SUPPORTED_HEADERS_CTK: |
| 148 | + return _abs_norm(_find_ctk_header_directory(libname)) |
| 149 | + |
| 150 | + raise RuntimeError(f"UNKNOWN {libname=}") |
0 commit comments