Commit a5c9dfe

Fix autotuner args_fn/grid_fn/predicate_fn to not rely on lambda function argument names
1 parent dc354a8 commit a5c9dfe

File tree

6 files changed: +204 -153 lines changed
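Before this change the autotuner inspected the parameter names of the grid_fn/args_fn/predicate_fn lambdas and matched them against the search-space config keys; after it, the callbacks receive the Config object positionally (grid_fn and predicate_fn also receive the kernel's named arguments) and read tuned values as attributes. A condensed, standalone sketch of the two conventions, with problem sizes made up purely for illustration:

import math

# Made-up sizes so the sketch runs standalone.
SeqLen_Q, Batch, Heads = 4096, 2, 16

# Old convention: the autotuner matched the lambda's parameter names
# (TILE_M, TILE_N, ...) against the search-space config keys.
old_grid_fn = lambda TILE_M: (math.ceil(SeqLen_Q / TILE_M), Batch * Heads, 1)

# New convention: grid_fn gets the kernel's named arguments plus the Config,
# args_fn gets just the Config; parameter names no longer matter and tuned
# values are read as attributes (cfg.TILE_M, cfg.TILE_N, ...).
def grid_fn(named_args, cfg):
    return (math.ceil(SeqLen_Q / cfg.TILE_M), Batch * Heads, 1)

def args_fn(cfg):
    return (cfg.TILE_M, cfg.TILE_N)  # plus the runtime tensors in real use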

samples/AttentionFMHA.py

Lines changed: 3 additions & 3 deletions
@@ -255,12 +255,12 @@ def cutile_autotune_fmha(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
     # --- Tune/Get the best configuration for the FMHA Kernel ---
     tuned_result = autotuner(
         torch.cuda.current_stream(),
-        grid_fn=lambda TILE_M: (math.ceil(SeqLen_Q / TILE_M), Batch * Heads, 1),
+        grid_fn=lambda named_args, cfg: (math.ceil(SeqLen_Q / cfg.TILE_M), Batch * Heads, 1),
         kernel=fmha_kernel,
-        args_fn=lambda TILE_M, TILE_N: (
+        args_fn=lambda cfg: (
             Q, K, V, Out,
             qk_scale, input_pos, D_k, Heads,
-            TILE_M, TILE_N, query_group_size, causal, (SeqLen_KV % TILE_N) == 0
+            cfg.TILE_M, cfg.TILE_N, query_group_size, causal, (SeqLen_KV % cfg.TILE_N) == 0
         ),
     )
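predicate_fn is not exercised by this sample, but SearchSpace.filter (see the autotuner.py diff below) now calls it the same way: positionally with the trial's named runtime arguments and the Config, and it must return a bool. A hedged sketch with hypothetical tuning keys:

# Hypothetical predicate under the new convention; the autotuner calls it as
# predicate_fn(named_args, cfg) and raises TypeError on non-bool return values.
def predicate_fn(named_args, cfg):
    # Keep only configurations whose tile footprint fits a made-up budget.
    return cfg.TILE_M * cfg.TILE_N <= 128 * 128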

samples/templates/AttentionFMHA.py

Lines changed: 3 additions & 3 deletions
@@ -136,12 +136,12 @@ def cutile_autotune_fmha(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
     # --- Tune/Get the best configuration for the FMHA Kernel ---
     tuned_result = autotuner(
         torch.cuda.current_stream(),
-        grid_fn=lambda TILE_M: (math.ceil(SeqLen_Q / TILE_M), Batch * Heads, 1),
+        grid_fn=lambda named_args, cfg: (math.ceil(SeqLen_Q / cfg.TILE_M), Batch * Heads, 1),
         kernel=fmha_kernel,
-        args_fn=lambda TILE_M, TILE_N: (
+        args_fn=lambda cfg: (
             Q, K, V, Out,
             qk_scale, input_pos, D_k, Heads,
-            TILE_M, TILE_N, query_group_size, causal, (SeqLen_KV % TILE_N) == 0
+            cfg.TILE_M, cfg.TILE_N, query_group_size, causal, (SeqLen_KV % cfg.TILE_N) == 0
         ),
     )

samples/utils/autotuner.py

Lines changed: 69 additions & 52 deletions
@@ -11,6 +11,7 @@
 from typing import Callable, Sequence
 import cuda.tile as ct
 from cuda.tile._execution import TileDispatcher
+from cuda.tile._exception import TileCompilerTimeoutError, TileCompilerExecutionError
 from cuda.tile._cext import default_tile_context
 import random
 import torch
@@ -29,6 +30,11 @@ def __init__(self, *, num_ctas=None, occupancy=None, opt_level=3, **kwargs):
         self.occupancy = occupancy
         self.opt_level = opt_level
 
+    def __getattr__(self, name):
+        if name in self.kwargs:
+            return self.kwargs[name]
+        raise AttributeError(f"Attribute {name} not found in {self.kwargs}")
+
     def __str__(self):
         res = []
         for k, v in self.kwargs.items():
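The new __getattr__ is what makes cfg.TILE_M work in the callbacks above: attribute lookups that miss on the instance fall back to the kwargs captured at construction. A simplified, standalone sketch of that behaviour (the real Config also carries num_ctas, occupancy and opt_level):

class ConfigSketch:
    """Simplified stand-in for Config showing only the kwargs fallback."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def __getattr__(self, name):
        # Invoked only when normal attribute lookup fails, so regular
        # attributes such as self.kwargs are unaffected.
        if name in self.kwargs:
            return self.kwargs[name]
        raise AttributeError(f"Attribute {name} not found in {self.kwargs}")


cfg = ConfigSketch(TILE_M=64, TILE_N=128)
assert cfg.TILE_M == 64   # resolved through kwargs
# cfg.BLOCK_K would raise AttributeError, mirroring the committed code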
@@ -63,13 +69,14 @@ def __len__(self):
     def __getitem__(self, index):
        return self.configs[index]
 
-    def filter(self, named_args: dict[str, Any]) -> bool:
+    def filter(self, named_args: dict[str, Any], cfg: Config) -> bool:
         if self.predicate_fn is None:
             return True
-        predicate_sig = inspect.signature(self.predicate_fn)
-        predicate_keys = set(predicate_sig.parameters.keys())
-        kwargs = {k: named_args[k] for k in predicate_keys}
-        return self.predicate_fn(**kwargs)
+        result = self.predicate_fn(named_args, cfg)
+        if not isinstance(result, bool):
+            raise TypeError(f"Predicate function {self.predicate_fn.__name__} must return "
+                            f"a boolean value, but returned {type(result).__name__} instead.")
+        return result
 
 
 def _shape_dtype_stride(arg: Any) -> tuple[tuple[int, ...], str, tuple[int, ...] | None]:
@@ -123,20 +130,6 @@ def _time_ms(run_once, *, get_args, stream, warmup=2, rep=10):
     return ms / max(1, rep)
 
 
-def _get_grid(grid_fn, named_args: dict[str, Any]) -> tuple[int, ...]:
-    grid_sig = inspect.signature(grid_fn)
-    grid_keys = set(grid_sig.parameters.keys())
-    kwargs = {}
-    for k in grid_keys:
-        if k not in named_args:
-            raise TypeError(
-                f"Function parameter {k} in grid_fn is not in kernel parameters, "
-                f"available parameters are {list(named_args.keys())}"
-            )
-        kwargs[k] = named_args[k]
-    return grid_fn(**kwargs)
-
-
 @dataclass
 class TunedResult:
     # The tuned parameters
@@ -156,10 +149,13 @@ def __getattr__(self, name):
 
 
 def _make_trial_args(
-    args_fn: Callable, kwargs: dict[str, Any], kernel, transforms: dict[str, Callable[[Any], Any]]
+    args_fn: Callable[[Config], tuple[Any, ...]],
+    cfg: Config,
+    kernel: TileDispatcher,
+    transforms: dict[str, Callable[[Any], Any]]
 ) -> tuple[dict[str, Any], tuple[Any, ...]]:
     """Make trial runtime arguments applying the transforms."""
-    args = args_fn(**kwargs)
+    args = args_fn(cfg)
 
     trial_named_args = {}
     trial_args = []
@@ -186,16 +182,6 @@ def _normalize_search_space(space: SearchSpace | Sequence[Config]) -> SearchSpace
     )
 
 
-def _safe_args_fn(args_fn: Callable, kwargs: dict[str, Any]) -> tuple[Any, ...]:
-    try:
-        return args_fn(**kwargs)
-    except TypeError:
-        raise TypeError(
-            f"Invalid parameters for args_fn, "
-            f"should be the same as the search space config argument keys: {list(kwargs.keys())}"
-        )
-
-
 @contextmanager
 def compiler_timeout(timeout_sec: int):
     old_timeout = default_tile_context.config.compiler_timeout_sec
@@ -219,7 +205,7 @@ def clear_cache(self, key=None):
 
     def __call__(self,
                  stream, grid_fn, kernel,
-                 args_fn: Callable,
+                 args_fn: Callable[[Config], tuple[Any, ...]],
                  transforms: dict[str, Callable] = {},
                  *,
                  key_fn=_default_key,
@@ -228,24 +214,52 @@ def __call__(self,
                  seed: int | None = None,
                  force_retune: bool = False) -> TunedResult:
         """
+        Run the autotuned kernel and return its result.
+
         It performs the following steps:
-        1) picks or reuses the cached config and kernel,
-        2) runs the kernel with the best config,
+        1) picks a configuration from the search space or reuses the cached
+           best configuration for the given key (unless ``force_retune=True``),
+        2) launches the kernel with the best configuration,
         3) returns the tuned result.
 
         Args:
-            stream: The stream.
-            grid_fn: The grid function.
-            kernel: The kernel.
-            args_fn: The function from the search space parameters to the runtime arguments.
-            transforms: The transforms functions for runtime arguments if needed.
-            key_fn: The key function.
-            max_iter: The maximum number of valid condigurations to sample from the search space.
-            compiler_time_limit_sec: The compilation time limit for each kernel.
-            seed: The seed for the random number generator. Default is None.
-            force_retune: Force retuning even if the config is found in the cache. Default is False.
+            stream:
+                CUDA stream to use for all kernel launches during tuning and
+                for the final run.
+            grid_fn:
+                Callable that takes the named arguments and a single
+                positional :class:`Config` object and returns a tuple of grid
+                dimensions.
+            kernel:
+                The kernel to autotune.
+            args_fn:
+                Callable that takes a single positional :class:`Config` and
+                returns a tuple of runtime arguments for ``kernel``.
+            transforms:
+                Optional transform or sequence of transforms applied to the
+                runtime arguments before each kernel launch. Use this to
+                perform lightweight pre-/post-processing without changing
+                the search space.
+            key_fn:
+                Optional function that maps the named arguments to a hashable
+                cache key. When omitted, a default key derived from argument
+                shapes/dtypes is used. The key is used to look up and store
+                the best config in the autotuner cache.
+            max_iter:
+                Maximum number of (valid) configurations to sample from the
+                search space.
+            compiler_time_limit_sec:
+                The compilation time limit for each kernel.
+            seed:
+                Optional seed for the random number generator used when
+                sampling configurations. If ``None``, the global random number
+                generator state is used.
+            force_retune:
+                If ``True``, ignore any cached best config for this key and
+                re-run the search. The new best config is then written back
+                to the cache.
         """
-        key = key_fn(kernel, _safe_args_fn(args_fn, self._search_space.configs[0].kwargs))
+        key = key_fn(kernel, args_fn(self._search_space.configs[0]))
         if not force_retune and key in self._cache:
             best_idx, best_grid, best_kernel = self._cache[key]
             logger.debug(f"Using cached config for key {key}: {self._search_space[best_idx]}")
@@ -259,13 +273,13 @@ def __call__(self,
                 break
             cfg = self._search_space[cfg_idx]
             trial_named_args, trial_args = _make_trial_args(
-                args_fn, cfg.kwargs, kernel, transforms
+                args_fn, cfg, kernel, transforms
             )
-            if not self._search_space.filter(trial_named_args):
+            if not self._search_space.filter(trial_named_args, cfg):
                 logger.debug(f"Config {cfg} filtered out by predicate function")
                 continue
 
-            grid = _get_grid(grid_fn, trial_named_args)
+            grid = grid_fn(trial_named_args, cfg)
             updated_kernel = ct.kernel(
                 kernel._pyfunc,
                 num_ctas=cfg.num_ctas,
@@ -280,11 +294,14 @@ def run_once(args):
                 with compiler_timeout(compiler_time_limit_sec):
                     time_ms = _time_ms(
                         run_once,
-                        get_args=lambda: _make_trial_args(args_fn, cfg.kwargs, kernel, transforms)[1], # noqa
+                        get_args=lambda: _make_trial_args(args_fn, cfg, kernel, transforms)[1], # noqa
                         stream=stream,
                     )
-            except Exception as e:
-                logger.debug(f"{cfg} failed to run: {e}")
+            except TileCompilerTimeoutError as e:
+                logger.debug(f"{cfg} compilation timeout: {e}")
+                continue
+            except TileCompilerExecutionError as e:
+                logger.debug(f"{cfg} compilation error: {e}")
                 continue
 
             if time_ms < best_time_ms:
@@ -303,7 +320,7 @@ def run_once(args):
         best_cfg = self._search_space[best_idx]
 
         # Use the original runtime arguments to run the kernel with the best config
-        best_packed_args = args_fn(**best_cfg.kwargs)
+        best_packed_args = args_fn(best_cfg)
         ct.launch(stream, best_grid, best_kernel, best_packed_args)
 
         # Return the tuned result
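A hedged usage sketch of the keyword-only options documented in the new docstring, modelled on the samples/AttentionFMHA.py call; the autotuner instance, fmha_kernel and the tensors are assumed to exist as in that sample:

# Assumes fmha_kernel, Q, K, V, Out, qk_scale, etc. exist as in samples/AttentionFMHA.py.
tuned_result = autotuner(
    torch.cuda.current_stream(),
    grid_fn=lambda named_args, cfg: (math.ceil(SeqLen_Q / cfg.TILE_M), Batch * Heads, 1),
    kernel=fmha_kernel,
    args_fn=lambda cfg: (
        Q, K, V, Out,
        qk_scale, input_pos, D_k, Heads,
        cfg.TILE_M, cfg.TILE_N, query_group_size, causal, (SeqLen_KV % cfg.TILE_N) == 0
    ),
    seed=0,             # reproducible sampling of candidate configs
    force_retune=True,  # ignore any cached best config for this key and re-tune
)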
