Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 6daacba

Browse files
authored
Updates KernelAttributes to avoid possible dangling handles. (#957)
* Updates KernelAttributes to avoid possible dangling handles.
* Simplifies the caching logic in KernelAttributes.
1 parent 978154c commit 6daacba

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

cuda_core/cuda/core/experimental/_module.py

Lines changed: 16 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
import weakref
56
from collections import namedtuple
67
from typing import Optional, Union
78
from warnings import warn
@@ -60,12 +61,12 @@ class KernelAttributes:
6061
def __new__(self, *args, **kwargs):
6162
raise RuntimeError("KernelAttributes cannot be instantiated directly. Please use Kernel APIs.")
6263

63-
slots = ("_handle", "_cache", "_backend_version", "_loader")
64+
slots = ("_kernel", "_cache", "_backend_version", "_loader")
6465

6566
@classmethod
66-
def _init(cls, handle):
67+
def _init(cls, kernel):
6768
self = super().__new__(cls)
68-
self._handle = handle
69+
self._kernel = weakref.ref(kernel)
6970
self._cache = {}
7071

7172
self._backend_version = "new" if (_py_major_ver >= 12 and _driver_ver >= 12000) else "old"
@@ -74,20 +75,23 @@ def _init(cls, handle):
7475

7576
def _get_cached_attribute(self, device_id: int, attribute: driver.CUfunction_attribute) -> int:
7677
"""Helper function to get a cached attribute or fetch and cache it if not present."""
77-
if device_id in self._cache and attribute in self._cache[device_id]:
78-
return self._cache[device_id][attribute]
78+
cache_key = device_id, attribute
79+
result = self._cache.get(cache_key, cache_key)
80+
if result is not cache_key:
81+
return result
82+
kernel = self._kernel()
83+
if kernel is None:
84+
raise RuntimeError("Cannot access kernel attributes for expired Kernel object")
7985
if self._backend_version == "new":
80-
result = handle_return(self._loader["attribute"](attribute, self._handle, device_id))
86+
result = handle_return(self._loader["attribute"](attribute, kernel._handle, device_id))
8187
else: # "old" backend
8288
warn(
8389
"Device ID argument is ignored when getting attribute from kernel when cuda version < 12. ",
8490
RuntimeWarning,
8591
stacklevel=2,
8692
)
87-
result = handle_return(self._loader["attribute"](attribute, self._handle))
88-
if device_id not in self._cache:
89-
self._cache[device_id] = {}
90-
self._cache[device_id][attribute] = result
93+
result = handle_return(self._loader["attribute"](attribute, kernel._handle))
94+
self._cache[cache_key] = result
9195
return result
9296

9397
def max_threads_per_block(self, device_id: int = None) -> int:
@@ -365,7 +369,7 @@ class Kernel:
365369
366370
"""
367371

368-
__slots__ = ("_handle", "_module", "_attributes", "_occupancy")
372+
__slots__ = ("_handle", "_module", "_attributes", "_occupancy", "__weakref__")
369373

370374
def __new__(self, *args, **kwargs):
371375
raise RuntimeError("Kernel objects cannot be instantiated directly. Please use ObjectCode APIs.")
@@ -385,7 +389,7 @@ def _from_obj(cls, obj, mod):
385389
def attributes(self) -> KernelAttributes:
386390
"""Get the read-only attributes of this kernel."""
387391
if self._attributes is None:
388-
self._attributes = KernelAttributes._init(self._handle)
392+
self._attributes = KernelAttributes._init(self)
389393
return self._attributes
390394

391395
def _get_arguments_info(self, param_info=False) -> tuple[int, list[ParamInfo]]:

0 commit comments

Comments
 (0)