1111from typing import Callable , Sequence
1212import cuda .tile as ct
1313from cuda .tile ._execution import TileDispatcher
14+ from cuda .tile ._exception import TileCompilerTimeoutError , TileCompilerExecutionError
1415from cuda .tile ._cext import default_tile_context
1516import random
1617import torch
@@ -29,6 +30,11 @@ def __init__(self, *, num_ctas=None, occupancy=None, opt_level=3, **kwargs):
2930 self .occupancy = occupancy
3031 self .opt_level = opt_level
3132
def __getattr__(self, name):
    """Resolve attributes that normal lookup missed from ``self.kwargs``.

    Python calls this hook only after regular attribute lookup has failed,
    so attributes stored directly on the instance never reach this method.

    Args:
        name: Name of the attribute being looked up.

    Returns:
        The value stored under ``name`` in ``self.kwargs``.

    Raises:
        AttributeError: If ``name`` is not a key of ``self.kwargs``, or if
            ``kwargs`` itself has not been set on the instance yet.
    """
    # Fetch ``kwargs`` without going through this hook again. Writing
    # ``self.kwargs`` here would re-enter ``__getattr__`` whenever ``kwargs``
    # is missing (for example on an instance created via ``__new__`` by
    # copy/pickle before its state is restored) and recurse until
    # RecursionError.
    try:
        kwargs = object.__getattribute__(self, "kwargs")
    except AttributeError:
        raise AttributeError(name) from None
    if name in kwargs:
        return kwargs[name]
    raise AttributeError(f"Attribute {name} not found in {kwargs}")
37+
3238 def __str__ (self ):
3339 res = []
3440 for k , v in self .kwargs .items ():
@@ -63,13 +69,14 @@ def __len__(self):
6369 def __getitem__ (self , index ):
6470 return self .configs [index ]
6571
def filter(self, named_args: dict[str, Any], cfg: "Config") -> bool:
    """Decide whether ``cfg`` should be kept, using ``self.predicate_fn``.

    Args:
        named_args: Mapping passed to the predicate as its first positional
            argument.
        cfg: Configuration under consideration, passed to the predicate as
            its second positional argument.

    Returns:
        ``True`` when no predicate is configured (``self.predicate_fn`` is
        ``None``); otherwise the boolean returned by the predicate.

    Raises:
        TypeError: If the predicate returns anything other than a ``bool``.
    """
    predicate = self.predicate_fn
    if predicate is None:
        return True

    result = predicate(named_args, cfg)
    if not isinstance(result, bool):
        # Not every callable has ``__name__`` (functools.partial objects and
        # instances defining ``__call__`` do not). Fall back to ``repr`` so
        # the intended TypeError is raised instead of an AttributeError.
        predicate_name = getattr(predicate, "__name__", None) or repr(predicate)
        raise TypeError(
            f"Predicate function {predicate_name} must return "
            f"a boolean value, but returned {type(result).__name__} instead."
        )
    return result
7380
7481
7582def _shape_dtype_stride (arg : Any ) -> tuple [tuple [int , ...], str , tuple [int , ...] | None ]:
@@ -123,20 +130,6 @@ def _time_ms(run_once, *, get_args, stream, warmup=2, rep=10):
123130 return ms / max (1 , rep )
124131
125132
126- def _get_grid (grid_fn , named_args : dict [str , Any ]) -> tuple [int , ...]:
127- grid_sig = inspect .signature (grid_fn )
128- grid_keys = set (grid_sig .parameters .keys ())
129- kwargs = {}
130- for k in grid_keys :
131- if k not in named_args :
132- raise TypeError (
133- f"Function parameter { k } in grid_fn is not in kernel parameters, "
134- f"available parameters are { list (named_args .keys ())} "
135- )
136- kwargs [k ] = named_args [k ]
137- return grid_fn (** kwargs )
138-
139-
140133@dataclass
141134class TunedResult :
142135 # The tuned parameters
@@ -156,10 +149,13 @@ def __getattr__(self, name):
156149
157150
158151def _make_trial_args (
159- args_fn : Callable , kwargs : dict [str , Any ], kernel , transforms : dict [str , Callable [[Any ], Any ]]
152+ args_fn : Callable [[Config ], tuple [Any , ...]],
153+ cfg : Config ,
154+ kernel : TileDispatcher ,
155+ transforms : dict [str , Callable [[Any ], Any ]]
160156) -> tuple [dict [str , Any ], tuple [Any , ...]]:
161157 """Make trial runtime arguments applying the transforms."""
162- args = args_fn (** kwargs )
158+ args = args_fn (cfg )
163159
164160 trial_named_args = {}
165161 trial_args = []
@@ -186,16 +182,6 @@ def _normalize_search_space(space: SearchSpace | Sequence[Config]) -> SearchSpac
186182 )
187183
188184
189- def _safe_args_fn (args_fn : Callable , kwargs : dict [str , Any ]) -> tuple [Any , ...]:
190- try :
191- return args_fn (** kwargs )
192- except TypeError :
193- raise TypeError (
194- f"Invalid parameters for args_fn, "
195- f"should be the same as the search space config argument keys: { list (kwargs .keys ())} "
196- )
197-
198-
199185@contextmanager
200186def compiler_timeout (timeout_sec : int ):
201187 old_timeout = default_tile_context .config .compiler_timeout_sec
@@ -219,7 +205,7 @@ def clear_cache(self, key=None):
219205
220206 def __call__ (self ,
221207 stream , grid_fn , kernel ,
222- args_fn : Callable ,
208+ args_fn : Callable [[ Config ], tuple [ Any , ...]] ,
223209 transforms : dict [str , Callable ] = {},
224210 * ,
225211 key_fn = _default_key ,
@@ -228,24 +214,52 @@ def __call__(self,
228214 seed : int | None = None ,
229215 force_retune : bool = False ) -> TunedResult :
230216 """
217+ Run the autotuned kernel and return its result.
218+
231219 It performs the following steps:
232- 1) picks or reuses the cached config and kernel,
233- 2) runs the kernel with the best config,
220+ 1) picks a configuration from the search space or reuses the cached
221+ best configuration for the given key (unless ``force_retune=True``),
222+ 2) launches the kernel with the best configuration,
234223 3) returns the tuned result.
235224
236225 Args:
237- stream: The stream.
238- grid_fn: The grid function.
239- kernel: The kernel.
240- args_fn: The function from the search space parameters to the runtime arguments.
241- transforms: The transforms functions for runtime arguments if needed.
242- key_fn: The key function.
243- max_iter: The maximum number of valid condigurations to sample from the search space.
244- compiler_time_limit_sec: The compilation time limit for each kernel.
245- seed: The seed for the random number generator. Default is None.
246- force_retune: Force retuning even if the config is found in the cache. Default is False.
226+ stream:
227+ CUDA stream to use for all kernel launches during tuning and
228+ for the final run.
229+ grid_fn:
230+ Callable that takes the named arguments and a single
231+ positional :class:`Config` object and returns a tuple of grid
232+ dimensions.
233+ kernel:
234+ The kernel to autotune.
235+ args_fn:
236+ Callable that takes a single positional :class:`Config` and
237+ returns a tuple of runtime arguments for ``kernel``.
238+ transforms:
239+ Optional transform or sequence of transforms applied to the
240+ runtime arguments before each kernel launch. Use this to
241+ perform lightweight pre-/post-processing without changing
242+ the search space.
243+ key_fn:
244+ Optional function that maps the named arguments to a hashable
245+ cache key. When omitted, a default key derived from argument
246+ shapes/dtypes is used. The key is used to look up and store
247+ the best config in the autotuner cache.
248+ max_iter:
249+ Maximum number of (valid) configurations to sample from the
250+ search space.
251+ compiler_time_limit_sec:
252+ The compilation time limit for each kernel.
253+ seed:
254+ Optional seed for the random number generator used when
255+ sampling configurations. If ``None``, the global random number
256+ generator state is used.
257+ force_retune:
258+ If ``True``, ignore any cached best config for this key and
259+ re-run the search. The new best config is then written back
260+ to the cache.
247261 """
248- key = key_fn (kernel , _safe_args_fn ( args_fn , self ._search_space .configs [0 ]. kwargs ))
262+ key = key_fn (kernel , args_fn ( self ._search_space .configs [0 ]))
249263 if not force_retune and key in self ._cache :
250264 best_idx , best_grid , best_kernel = self ._cache [key ]
251265 logger .debug (f"Using cached config for key { key } : { self ._search_space [best_idx ]} " )
@@ -259,13 +273,13 @@ def __call__(self,
259273 break
260274 cfg = self ._search_space [cfg_idx ]
261275 trial_named_args , trial_args = _make_trial_args (
262- args_fn , cfg . kwargs , kernel , transforms
276+ args_fn , cfg , kernel , transforms
263277 )
264- if not self ._search_space .filter (trial_named_args ):
278+ if not self ._search_space .filter (trial_named_args , cfg ):
265279 logger .debug (f"Config { cfg } filtered out by predicate function" )
266280 continue
267281
268- grid = _get_grid ( grid_fn , trial_named_args )
282+ grid = grid_fn ( trial_named_args , cfg )
269283 updated_kernel = ct .kernel (
270284 kernel ._pyfunc ,
271285 num_ctas = cfg .num_ctas ,
@@ -280,11 +294,14 @@ def run_once(args):
280294 with compiler_timeout (compiler_time_limit_sec ):
281295 time_ms = _time_ms (
282296 run_once ,
283- get_args = lambda : _make_trial_args (args_fn , cfg . kwargs , kernel , transforms )[1 ], # noqa
297+ get_args = lambda : _make_trial_args (args_fn , cfg , kernel , transforms )[1 ], # noqa
284298 stream = stream ,
285299 )
286- except Exception as e :
287- logger .debug (f"{ cfg } failed to run: { e } " )
300+ except TileCompilerTimeoutError as e :
301+ logger .debug (f"{ cfg } compilation timeout: { e } " )
302+ continue
303+ except TileCompilerExecutionError as e :
304+ logger .debug (f"{ cfg } compilation error: { e } " )
288305 continue
289306
290307 if time_ms < best_time_ms :
@@ -303,7 +320,7 @@ def run_once(args):
303320 best_cfg = self ._search_space [best_idx ]
304321
305322 # Use the original runtime arguments to run the kernel with the best config
306- best_packed_args = args_fn (** best_cfg . kwargs )
323+ best_packed_args = args_fn (best_cfg )
307324 ct .launch (stream , best_grid , best_kernel , best_packed_args )
308325
309326 # Return the tuned result
0 commit comments