diff --git a/cuda_core/cuda/core/experimental/_module.py b/cuda_core/cuda/core/experimental/_module.py index 63bb6ff26..c659a8d78 100644 --- a/cuda_core/cuda/core/experimental/_module.py +++ b/cuda_core/cuda/core/experimental/_module.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +import weakref from collections import namedtuple from typing import Optional, Union from warnings import warn @@ -60,12 +61,12 @@ class KernelAttributes: def __new__(self, *args, **kwargs): raise RuntimeError("KernelAttributes cannot be instantiated directly. Please use Kernel APIs.") - slots = ("_handle", "_cache", "_backend_version", "_loader") + slots = ("_kernel", "_cache", "_backend_version", "_loader") @classmethod - def _init(cls, handle): + def _init(cls, kernel): self = super().__new__(cls) - self._handle = handle + self._kernel = weakref.ref(kernel) self._cache = {} self._backend_version = "new" if (_py_major_ver >= 12 and _driver_ver >= 12000) else "old" @@ -74,20 +75,23 @@ def _init(cls, handle): def _get_cached_attribute(self, device_id: int, attribute: driver.CUfunction_attribute) -> int: """Helper function to get a cached attribute or fetch and cache it if not present.""" - if device_id in self._cache and attribute in self._cache[device_id]: - return self._cache[device_id][attribute] + cache_key = device_id, attribute + result = self._cache.get(cache_key, cache_key) + if result is not cache_key: + return result + kernel = self._kernel() + if kernel is None: + raise RuntimeError("Cannot access kernel attributes for expired Kernel object") if self._backend_version == "new": - result = handle_return(self._loader["attribute"](attribute, self._handle, device_id)) + result = handle_return(self._loader["attribute"](attribute, kernel._handle, device_id)) else: # "old" backend warn( "Device ID argument is ignored when getting attribute from kernel when cuda version < 12. ", RuntimeWarning, stacklevel=2, ) - result = handle_return(self._loader["attribute"](attribute, self._handle)) - if device_id not in self._cache: - self._cache[device_id] = {} - self._cache[device_id][attribute] = result + result = handle_return(self._loader["attribute"](attribute, kernel._handle)) + self._cache[cache_key] = result return result def max_threads_per_block(self, device_id: int = None) -> int: @@ -365,7 +369,7 @@ class Kernel: """ - __slots__ = ("_handle", "_module", "_attributes", "_occupancy") + __slots__ = ("_handle", "_module", "_attributes", "_occupancy", "__weakref__") def __new__(self, *args, **kwargs): raise RuntimeError("Kernel objects cannot be instantiated directly. Please use ObjectCode APIs.") @@ -385,7 +389,7 @@ def _from_obj(cls, obj, mod): def attributes(self) -> KernelAttributes: """Get the read-only attributes of this kernel.""" if self._attributes is None: - self._attributes = KernelAttributes._init(self._handle) + self._attributes = KernelAttributes._init(self) return self._attributes def _get_arguments_info(self, param_info=False) -> tuple[int, list[ParamInfo]]: