Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions cuda_core/cuda/core/experimental/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

import weakref
from collections import namedtuple
from typing import Optional, Union
from warnings import warn
Expand Down Expand Up @@ -60,12 +61,12 @@ class KernelAttributes:
def __new__(self, *args, **kwargs):
raise RuntimeError("KernelAttributes cannot be instantiated directly. Please use Kernel APIs.")

slots = ("_handle", "_cache", "_backend_version", "_loader")
slots = ("_kernel", "_cache", "_backend_version", "_loader")

@classmethod
def _init(cls, handle):
def _init(cls, kernel):
self = super().__new__(cls)
self._handle = handle
self._kernel = weakref.ref(kernel)
self._cache = {}

self._backend_version = "new" if (_py_major_ver >= 12 and _driver_ver >= 12000) else "old"
Expand All @@ -74,20 +75,21 @@ def _init(cls, handle):

def _get_cached_attribute(self, device_id: int, attribute: driver.CUfunction_attribute) -> int:
"""Helper function to get a cached attribute or fetch and cache it if not present."""
if device_id in self._cache and attribute in self._cache[device_id]:
return self._cache[device_id][attribute]
if (device_id, attribute) in self._cache:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WDYT about

        cache_key = (device_id, attribute)  # so this tuple doesn't have to be rebuilt
        result = self._cache.get(cache_key, cache_key)  # the tuple doubles as sentinel; there is only one cache lookup
        if result is not cache_key:
            return result

and then, at the very end of the method:

        self._cache[cache_key] = result
        return result

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good suggestion, thanks!

return self._cache[device_id, attribute]
kernel = self._kernel()
if kernel is None:
raise RuntimeError("Cannot access kernel attributes for expired Kernel object")
if self._backend_version == "new":
result = handle_return(self._loader["attribute"](attribute, self._handle, device_id))
result = handle_return(self._loader["attribute"](attribute, kernel._handle, device_id))
else: # "old" backend
warn(
"Device ID argument is ignored when getting attribute from kernel when cuda version < 12. ",
RuntimeWarning,
stacklevel=2,
)
result = handle_return(self._loader["attribute"](attribute, self._handle))
if device_id not in self._cache:
self._cache[device_id] = {}
self._cache[device_id][attribute] = result
result = handle_return(self._loader["attribute"](attribute, kernel._handle))
self._cache[device_id, attribute] = result
return result

def max_threads_per_block(self, device_id: int = None) -> int:
Expand Down Expand Up @@ -365,7 +367,7 @@ class Kernel:

"""

__slots__ = ("_handle", "_module", "_attributes", "_occupancy")
__slots__ = ("_handle", "_module", "_attributes", "_occupancy", "__weakref__")

def __new__(self, *args, **kwargs):
raise RuntimeError("Kernel objects cannot be instantiated directly. Please use ObjectCode APIs.")
Expand All @@ -385,7 +387,7 @@ def _from_obj(cls, obj, mod):
def attributes(self) -> KernelAttributes:
"""Get the read-only attributes of this kernel."""
if self._attributes is None:
self._attributes = KernelAttributes._init(self._handle)
self._attributes = KernelAttributes._init(self)
return self._attributes

def _get_arguments_info(self, param_info=False) -> tuple[int, list[ParamInfo]]:
Expand Down