Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions cuda_core/cuda/core/experimental/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

import weakref
from collections import namedtuple
from typing import Optional, Union
from warnings import warn
Expand Down Expand Up @@ -60,12 +61,12 @@ class KernelAttributes:
def __new__(self, *args, **kwargs):
    """Refuse direct construction; any arguments are accepted and ignored.

    Always raises ``RuntimeError`` pointing the caller at the Kernel APIs.
    """
    message = "KernelAttributes cannot be instantiated directly. Please use Kernel APIs."
    raise RuntimeError(message)

# Names of the per-instance attributes used by this class: the weak reference to
# the owning kernel, the attribute cache, the backend version string, and the
# loader mapping.
# NOTE(review): this is a plain class attribute called `slots`, so it does not
# restrict instance attributes the way `__slots__` would (Kernel in this file
# declares `__slots__`) -- confirm whether `__slots__` was intended.
slots = ("_kernel", "_cache", "_backend_version", "_loader")

@classmethod
def _init(cls, handle):
def _init(cls, kernel):
self = super().__new__(cls)
self._handle = handle
self._kernel = weakref.ref(kernel)
self._cache = {}

self._backend_version = "new" if (_py_major_ver >= 12 and _driver_ver >= 12000) else "old"
Expand All @@ -74,20 +75,23 @@ def _init(cls, handle):

def _get_cached_attribute(self, device_id: int, attribute: driver.CUfunction_attribute) -> int:
    """Helper function to get a cached attribute or fetch and cache it if not present.

    Results are cached per ``(device_id, attribute)`` pair, so the loader's
    ``"attribute"`` entry is only called on a cache miss.

    Parameters
    ----------
    device_id : int
        Device the attribute is requested for. With the "old" backend it is not
        passed to the loader and a ``RuntimeWarning`` is emitted instead.
    attribute : driver.CUfunction_attribute
        Attribute to look up.

    Returns
    -------
    int
        The cached or freshly fetched attribute value.

    Raises
    ------
    RuntimeError
        If the weakly referenced kernel no longer exists.
    """
    cache_key = device_id, attribute
    # The key tuple built in this call doubles as the "missing" sentinel, which
    # avoids a separate membership test before the lookup.
    result = self._cache.get(cache_key, cache_key)
    if result is not cache_key:
        return result
    # Only a weak reference to the kernel is held, so it may have expired.
    kernel = self._kernel()
    if kernel is None:
        raise RuntimeError("Cannot access kernel attributes for expired Kernel object")
    if self._backend_version == "new":
        result = handle_return(self._loader["attribute"](attribute, kernel._handle, device_id))
    else:  # "old" backend
        warn(
            "Device ID argument is ignored when getting attribute from kernel when cuda version < 12. ",
            RuntimeWarning,
            stacklevel=2,
        )
        result = handle_return(self._loader["attribute"](attribute, kernel._handle))
    self._cache[cache_key] = result
    return result

def max_threads_per_block(self, device_id: int = None) -> int:
Expand Down Expand Up @@ -365,7 +369,7 @@ class Kernel:

"""

# "__weakref__" is listed so that instances can be weakly referenced:
# KernelAttributes._init stores weakref.ref(kernel) rather than the kernel itself.
__slots__ = ("_handle", "_module", "_attributes", "_occupancy", "__weakref__")

def __new__(self, *args, **kwargs):
    """Refuse direct construction; any arguments are accepted and ignored.

    Always raises ``RuntimeError`` pointing the caller at the ObjectCode APIs.
    """
    message = "Kernel objects cannot be instantiated directly. Please use ObjectCode APIs."
    raise RuntimeError(message)
Expand All @@ -385,7 +389,7 @@ def _from_obj(cls, obj, mod):
def attributes(self) -> KernelAttributes:
    """Get the read-only attributes of this kernel."""
    # Built lazily on first access and reused afterwards. The kernel itself (not
    # its raw handle) is passed because KernelAttributes._init wraps its argument
    # in weakref.ref.
    if self._attributes is None:
        self._attributes = KernelAttributes._init(self)
    return self._attributes

def _get_arguments_info(self, param_info=False) -> tuple[int, list[ParamInfo]]:
Expand Down
Loading