|
11 | 11 | from types import FunctionType |
12 | 12 | from typing import Tuple, Any, Optional, Sequence, Callable |
13 | 13 |
|
14 | | -from cuda.tile._debug import CUDA_TILE_LOGS |
15 | 14 | from cuda.tile._exception import ( |
16 | 15 | TileTypeError, |
17 | 16 | TileInternalError, |
18 | 17 | ConstantNotFoundError, TileSyntaxError, Loc, TileError |
19 | 18 | ) |
| 19 | +from cuda.tile._cext import TileContext |
20 | 20 | from cuda.tile._ir import ir |
21 | 21 | from cuda.tile._ir.ir import Operation, Function, Block, Var, Argument, IRContext, TypedOperation |
22 | 22 | from cuda.tile._ir.op_impl import op_implementations |
@@ -359,12 +359,15 @@ def infer_types_in_func(context: TypingContext, |
359 | 359 | return dataclasses.replace(func, parameters=tuple(new_params)) |
360 | 360 |
|
361 | 361 |
|
362 | | -def infer_types_pass(func: Function, args: Tuple[Argument, ...], pyfunc: FunctionType) -> Function: |
| 362 | +def infer_types_pass(func: Function, |
| 363 | + args: Tuple[Argument, ...], |
| 364 | + pyfunc: FunctionType, |
| 365 | + tile_context: TileContext) -> Function: |
363 | 366 | context = TypingContext(func.root_block.ctx) |
364 | 367 | try: |
365 | 368 | return infer_types_in_func(context, func, args) |
366 | 369 | except Exception as e: |
367 | | - if 'CUTILEIR' in CUDA_TILE_LOGS: |
| 370 | + if 'CUTILEIR' in tile_context.config.log_keys: |
368 | 371 | highlight_loc = e.loc if hasattr(e, 'loc') else None |
369 | 372 | code = (f"====Partial CuTile IR for {func}==== \n\n" |
370 | 373 | f"{func.to_string(highlight_loc=highlight_loc)}\n\n") |
|
0 commit comments