2
2
#
3
3
# SPDX-License-Identifier: Apache-2.0
4
4
5
+ import weakref
5
6
from collections import namedtuple
6
7
from typing import Optional , Union
7
8
from warnings import warn
@@ -60,12 +61,12 @@ class KernelAttributes:
60
61
def __new__ (self , * args , ** kwargs ):
61
62
raise RuntimeError ("KernelAttributes cannot be instantiated directly. Please use Kernel APIs." )
62
63
63
- slots = ("_handle " , "_cache" , "_backend_version" , "_loader" )
64
+ slots = ("_kernel " , "_cache" , "_backend_version" , "_loader" )
64
65
65
66
@classmethod
66
- def _init (cls , handle ):
67
+ def _init (cls , kernel ):
67
68
self = super ().__new__ (cls )
68
- self ._handle = handle
69
+ self ._kernel = weakref . ref ( kernel )
69
70
self ._cache = {}
70
71
71
72
self ._backend_version = "new" if (_py_major_ver >= 12 and _driver_ver >= 12000 ) else "old"
@@ -74,20 +75,23 @@ def _init(cls, handle):
74
75
75
76
def _get_cached_attribute (self , device_id : int , attribute : driver .CUfunction_attribute ) -> int :
76
77
"""Helper function to get a cached attribute or fetch and cache it if not present."""
77
- if device_id in self ._cache and attribute in self ._cache [device_id ]:
78
- return self ._cache [device_id ][attribute ]
78
+ cache_key = device_id , attribute
79
+ result = self ._cache .get (cache_key , cache_key )
80
+ if result is not cache_key :
81
+ return result
82
+ kernel = self ._kernel ()
83
+ if kernel is None :
84
+ raise RuntimeError ("Cannot access kernel attributes for expired Kernel object" )
79
85
if self ._backend_version == "new" :
80
- result = handle_return (self ._loader ["attribute" ](attribute , self ._handle , device_id ))
86
+ result = handle_return (self ._loader ["attribute" ](attribute , kernel ._handle , device_id ))
81
87
else : # "old" backend
82
88
warn (
83
89
"Device ID argument is ignored when getting attribute from kernel when cuda version < 12. " ,
84
90
RuntimeWarning ,
85
91
stacklevel = 2 ,
86
92
)
87
- result = handle_return (self ._loader ["attribute" ](attribute , self ._handle ))
88
- if device_id not in self ._cache :
89
- self ._cache [device_id ] = {}
90
- self ._cache [device_id ][attribute ] = result
93
+ result = handle_return (self ._loader ["attribute" ](attribute , kernel ._handle ))
94
+ self ._cache [cache_key ] = result
91
95
return result
92
96
93
97
def max_threads_per_block (self , device_id : int = None ) -> int :
@@ -365,7 +369,7 @@ class Kernel:
365
369
366
370
"""
367
371
368
- __slots__ = ("_handle" , "_module" , "_attributes" , "_occupancy" )
372
+ __slots__ = ("_handle" , "_module" , "_attributes" , "_occupancy" , "__weakref__" )
369
373
370
374
def __new__ (self , * args , ** kwargs ):
371
375
raise RuntimeError ("Kernel objects cannot be instantiated directly. Please use ObjectCode APIs." )
@@ -385,7 +389,7 @@ def _from_obj(cls, obj, mod):
385
389
def attributes (self ) -> KernelAttributes :
386
390
"""Get the read-only attributes of this kernel."""
387
391
if self ._attributes is None :
388
- self ._attributes = KernelAttributes ._init (self . _handle )
392
+ self ._attributes = KernelAttributes ._init (self )
389
393
return self ._attributes
390
394
391
395
def _get_arguments_info (self , param_info = False ) -> tuple [int , list [ParamInfo ]]:
0 commit comments