diff --git a/external/dace b/external/dace index d5fbadb6..d186d86d 160000 --- a/external/dace +++ b/external/dace @@ -1 +1 @@ -Subproject commit d5fbadb626389e425fac5ed93d2a880811eca41f +Subproject commit d186d86dea15f7852545dcde0c4f5b9e6d4f072b diff --git a/external/gt4py b/external/gt4py index eef3c0ee..e256ec5f 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit eef3c0ee9de9c4eb8f57650b64abf7863c05fc83 +Subproject commit e256ec5f2ae79e6240f7b5a4a29c9647877c12f4 diff --git a/ndsl/__init__.py b/ndsl/__init__.py index d468d58f..72ea237f 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -10,6 +10,7 @@ from .constants import ConstantVersions from .dsl.caches.codepath import FV3CodePath from .quantity import Quantity +from .dsl.optimization_config import OptimizationConfig from .dsl.ndsl_runtime import NDSLRuntime from .dsl.stencil import FrozenStencil, GridIndexing, StencilFactory, TimingCollector from .dsl.stencil_config import CompilationConfig, RunMode, StencilConfig @@ -90,6 +91,7 @@ "MetaEnumStr", "State", "LocalState", + "OptimizationConfig", "NDSLRuntime", "Local", "DiagManagerMonitor", diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index 65d72018..abb70ec8 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -786,7 +786,7 @@ def __init__( "Communicator needs to be instantiated with communication subsystem" f" derived from `comm_abc.Comm`, got {type(comm)}." 
) - if comm.Get_size() != partitioner.total_ranks: + if comm.Get_size() < partitioner.total_ranks: raise ValueError( f"was given a partitioner for {partitioner.total_ranks} ranks but a " f"comm object with only {comm.Get_size()} ranks, are we running " diff --git a/ndsl/config/backend.py b/ndsl/config/backend.py index 2807cf6a..605b86d7 100644 --- a/ndsl/config/backend.py +++ b/ndsl/config/backend.py @@ -52,6 +52,8 @@ class BackendLoopOrder(Enum): "orch:dace:cpu:KJI": "dace:cpu_KJI", "st:dace:gpu:KJI": "dace:gpu", "orch:dace:gpu:KJI": "dace:gpu", + "st:dace:gpu:IJK": "dace:gpu_IJK", + "orch:dace:gpu:IJK": "dace:gpu_IJK", } """Internal: match the NDSL backend names with the GT4Py names""" diff --git a/ndsl/dsl/caches/cache_location.py b/ndsl/dsl/caches/cache_location.py index 87d608dd..d4313815 100644 --- a/ndsl/dsl/caches/cache_location.py +++ b/ndsl/dsl/caches/cache_location.py @@ -7,46 +7,48 @@ def identify_code_path( partitioner: Partitioner, single_code_path: bool, ) -> FV3CodePath: - """Determine which code path your rank will hit. + """ + Determine which code path your rank will hit. - If single_code_path is True, single_code_path is True, - only one code path exists (case of doubly periodic grid). + If single_code_path is True, only one code path exists, + e.g. in case of a doubly periodic grid. If single_code_path is False, we are in the case of the - cube-sphere and we will look at our position on the tile.""" + cube-sphere and we will look at our position on the tile. + """ # Doubly-periodic or single tile grid - if single_code_path: + if single_code_path or partitioner.layout == (1, 1): return FV3CodePath.All # Cube-sphere - if partitioner.layout == (1, 1): - return FV3CodePath.All - elif partitioner.layout[0] == 1 or partitioner.layout[1] == 1: + if partitioner.layout[0] <= 1 or partitioner.layout[1] <= 1: raise NotImplementedError( - f"Build for layout {partitioner.layout} is not handled" + f"Build for layout {partitioner.layout} is not handled." 
) - else: - if partitioner.tile.on_tile_bottom(rank): - if partitioner.tile.on_tile_left(rank): - return FV3CodePath.BottomLeft - if partitioner.tile.on_tile_right(rank): - return FV3CodePath.BottomRight - else: - return FV3CodePath.Bottom - if partitioner.tile.on_tile_top(rank): - if partitioner.tile.on_tile_left(rank): - return FV3CodePath.TopLeft - if partitioner.tile.on_tile_right(rank): - return FV3CodePath.TopRight - else: - return FV3CodePath.Top - else: - if partitioner.tile.on_tile_left(rank): - return FV3CodePath.Left - if partitioner.tile.on_tile_right(rank): - return FV3CodePath.Right - else: - return FV3CodePath.Center + + # Bottom row + if partitioner.tile.on_tile_bottom(rank): + if partitioner.tile.on_tile_left(rank): + return FV3CodePath.BottomLeft + if partitioner.tile.on_tile_right(rank): + return FV3CodePath.BottomRight + return FV3CodePath.Bottom + + # Top row + if partitioner.tile.on_tile_top(rank): + if partitioner.tile.on_tile_left(rank): + return FV3CodePath.TopLeft + if partitioner.tile.on_tile_right(rank): + return FV3CodePath.TopRight + return FV3CodePath.Top + + # Left & right column with corners already handled + if partitioner.tile.on_tile_left(rank): + return FV3CodePath.Left + if partitioner.tile.on_tile_right(rank): + return FV3CodePath.Right + + return FV3CodePath.Center def get_cache_fullpath(code_path: FV3CodePath) -> str: diff --git a/ndsl/dsl/caches/codepath.py b/ndsl/dsl/caches/codepath.py index 61591ccf..3d90a9e2 100644 --- a/ndsl/dsl/caches/codepath.py +++ b/ndsl/dsl/caches/codepath.py @@ -3,10 +3,12 @@ class FV3CodePath(enum.Enum): """Enum listing all possible code paths on a cube sphere. + For any layout the cube sphere has up to 9 different code paths depending on the positioning of the rank on the tile and which of the edge/corner cases it has to handle, as well as the possibility for all boundary computations in the 1x1 layout case. 
+ Since the framework inlines code to optimize, we _cannot_ pre-suppose which code being kept and/or ejected. This enum serves as the ground truth to map rank to the proper generated code. diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py index 013f4083..b722df77 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -10,14 +10,20 @@ from gt4py.cartesian.utils.compiler import cxx_compiler_defaults, gpu_configuration from ndsl import LocalComm +from ndsl.comm import Comm from ndsl.comm.communicator import Communicator from ndsl.comm.partitioner import Partitioner from ndsl.config import Backend from ndsl.dsl import NDSL_COMPILER_SILENCE, NDSL_GLOBAL_PRECISION from ndsl.dsl.caches.cache_location import identify_code_path from ndsl.dsl.caches.codepath import FV3CodePath +from ndsl.dsl.dace.hardware_config import get_gpu_hardware_defaults from ndsl.optional_imports import cupy as cp -from ndsl.performance.collector import NullPerformanceCollector, PerformanceCollector +from ndsl.performance.collector import ( + AbstractPerformanceCollector, + NullPerformanceCollector, + PerformanceCollector, +) if TYPE_CHECKING: @@ -166,8 +172,8 @@ def __init__( Args: communicator: used for setting the distributed caches backend: string for the backend - tile_nx: x/y domain size for a single time - tile_nz: z domain size for a single time + tile_nx: x/y domain size for a single tile + tile_nz: z domain size for a single tile orchestration: orchestration mode from DaCeOrchestration time: trigger performance collection, available to user with `performance_collector` @@ -181,16 +187,12 @@ def __init__( # ToDo: DaceConfig becomes a bit more than a read-only config # with this. 
Should be refactored into a DaceExecutor carrying a config self.loaded_dace_executables: DaceExecutables = {} - self.performance_collector = ( - PerformanceCollector( - "InternalOrchestrationTimer", - comm=( - LocalComm(0, 6, {}) if communicator is None else communicator.comm - ), + if not time: + self.performance_collector: AbstractPerformanceCollector = ( + NullPerformanceCollector() ) - if time - else NullPerformanceCollector() - ) + else: + self.set_timer(communicator.comm if communicator else None) # Temporary. This is a bit too out of the ordinary for the common user. # We should refactor the architecture to allow for a `gtc:orchestrated:dace:X` @@ -265,21 +267,29 @@ def __init__( march_option = "-mcpu=native" if is_arm_neoverse else "-march=native" # Removed --fast-math gpu_config = gpu_configuration(GT4PY_COMPILE_OPT_LEVEL) + gpu_cflags = " ".join(gpu_config.gpu_compile_flags).strip() dace.config.Config.set( "compiler", "cuda", "args", - value=f"-std=c++14 {warnings_policy} -Xcompiler -fPIC -O{optimization_level} -Xcompiler {march_option} {gpu_config.gpu_compile_flags}", + value=f"-std=c++14 {warnings_policy} -Xcompiler -fPIC -O{optimization_level} -Xcompiler {march_option} {gpu_cflags}", ) - cuda_sm = cp.cuda.Device(0).compute_capability if cp else 60 - dace.config.Config.set("compiler", "cuda", "cuda_arch", value=f"{cuda_sm}") - # Block size/thread count is defaulted to an average value for recent - # hardware (Pascal and upward). The problem of setting an optimized - # block/thread is both hardware and problem dependant. Fine tuners - # available in DaCe should be relied on for further tuning of this value. 
+ # Target compilation for hardware micro-code capacities + gpu_defaults = get_gpu_hardware_defaults() dace.config.Config.set( - "compiler", "cuda", "default_block_size", value="64,8,1" + "compiler", + "cuda", + "cuda_arch", + value=f"{gpu_defaults.compute_capability}", + ) + + # Default block size for kernels launch + dace.config.Config.set( + "compiler", + "cuda", + "default_block_size", + value=str(gpu_defaults.block_size)[1:-1], ) # Potentially buggy - deactivate dace.config.Config.set( @@ -346,6 +356,9 @@ def __init__( value="c", ) + # Debug lineinfo is incorrect anyway for the stencils + dace.config.Config.set("compiler", "lineinfo", value="none") + # Attempt to kill the dace.conf to avoid confusion dace_conf_to_kill = dace.config.Config.cfg_filename() if dace_conf_to_kill is not None: @@ -413,4 +426,20 @@ def from_dict(cls, data: dict) -> Self: config.rank_size = data["rank_size"] config.layout = data["layout"] config.tile_resolution = data["tile_resolution"] - return config + # TODO + # Computed properties like `self.code_path` and `self.do_compile` + # aren't updated. + # We also don't `set_distributed_caches()` based on that updated + # information. + raise NotImplementedError( + "Implementation of `DaceConfig.from_dict()` is incomplete." + ) + + def set_timer(self, comm: Comm | None) -> None: + """Set timer on configuration externally""" + # TODO: this absolutely should not be a on a Configuration object + # and even less setup outside. Madness, we have lost our ways... 
+ self.performance_collector = PerformanceCollector( + "InternalOrchestrationTimer", + comm=(LocalComm(0, 6, {}) if comm is None else comm), + ) diff --git a/ndsl/dsl/dace/hardware_config.py b/ndsl/dsl/dace/hardware_config.py new file mode 100644 index 00000000..bbd367dc --- /dev/null +++ b/ndsl/dsl/dace/hardware_config.py @@ -0,0 +1,126 @@ +import dataclasses +import sys +from pathlib import Path +from typing import Literal + +from ndsl import ndsl_log +from ndsl.optional_imports import cupy as cp + + +GPUVendor = Literal["Nvidia"] | Literal["AMD"] | Literal["Intel"] | Literal["Unknown"] + +# Taken straight out of https://pcisig.com/membership/member-companies +_VENDOR_PCI_SIGNATURES: dict[int, GPUVendor] = { + 0x10DE: "Nvidia", + 0x1002: "AMD", + 0x8086: "Intel", + 0x0: "Unknown", +} + +# Cached copy of the hardware default +_GPU_HARDWARE_DEFAULTS = None + + +def _get_vendor() -> GPUVendor: + """Retrieve vendor using the current device PCI id to query the PCI vendor + from the kernel logs. + + ⚠️ Only works on Linux - kicks back to "Unknown" in other cases. 
+ """ + if not sys.platform.startswith("linux"): + ndsl_log.info("GPU hardware detection only possible on Linux system.") + return "Unknown" + + pci_device_id = cp.cuda.runtime.deviceGetPCIBusId(0) + dev_path = Path("/sys", "bus", "pci", "devices", f"{pci_device_id}") + if not dev_path.exists(): + ndsl_log.info(f"GPU detection: PCI device not found at {dev_path}.") + return "Unknown" + + with open(dev_path / "vendor", "r") as f: + vendor_str = f.read().strip().replace("0x", "") + vendor_id = int(vendor_str, 16) + + if vendor_id not in _VENDOR_PCI_SIGNATURES: + ndsl_log.error(f"Unknown GPU vendor with PCI-SIG ID of {vendor_id:#X}.") + return "Unknown" + + return _VENDOR_PCI_SIGNATURES[vendor_id] + + +@dataclasses.dataclass +class GPUHardwareDefaults: + """Compute defaults for common GPUs""" + + vendor: GPUVendor + block_size: list[int] = dataclasses.field(default_factory=list) + compute_capability: int = -1 # Nvidia specific + + +def get_gpu_hardware_defaults() -> GPUHardwareDefaults: + """Retrieve default values for GPU computation configuration.""" + global _GPU_HARDWARE_DEFAULTS + if _GPU_HARDWARE_DEFAULTS is not None: + return _GPU_HARDWARE_DEFAULTS # type: ignore[unreachable] + + if cp is None or not cp.cuda.is_available(): + ndsl_log.warning("No cupy - defaulting for GPU hardware") + _GPU_HARDWARE_DEFAULTS = GPUHardwareDefaults( + vendor="Unknown", + block_size=[ + 8, + 1, + 1, + ], # Smaller common denominator of massively parallel hardware + ) + return _GPU_HARDWARE_DEFAULTS + + # Who goes there + vendor = _get_vendor() + if vendor == "Nvidia": + compute_capability = int(cp.cuda.Device(0).compute_capability) + # Default block size based on compute capability + if compute_capability > 80: + # Covers: + # - Blackwell (100+) + # - Hopper (90-100) + # - Ampere (80-90) + block_sizes = [128, 1, 1] + elif compute_capability > 60: + # Covers: + # - Volta (70-80) + # - Pascal (60-70) + block_sizes = [64, 8, 1] + else: + # For older hardware - we default to the safe 
warp-size since + # the dawn of GPGPU on Nvidia hardware + block_sizes = [32, 1, 1] + + _GPU_HARDWARE_DEFAULTS = GPUHardwareDefaults( + vendor=vendor, + block_size=block_sizes, + compute_capability=compute_capability, + ) + elif vendor == "AMD": + _GPU_HARDWARE_DEFAULTS = GPUHardwareDefaults( + vendor=vendor, + block_size=[64, 1, 1], # Default RDNA architecture is Wave64 + ) + elif vendor == "Intel": + _GPU_HARDWARE_DEFAULTS = GPUHardwareDefaults( + vendor=vendor, + block_size=[32, 1, 1], # Intel can run 8, 16 or 32 - but SIMD betters in 32 + ) + else: + _GPU_HARDWARE_DEFAULTS = GPUHardwareDefaults( + vendor=vendor, + block_size=[ + 8, + 1, + 1, + ], # Smaller common denominator of massively parallel hardware + ) + + ndsl_log.info(f"GPU vendor detected: {_GPU_HARDWARE_DEFAULTS.vendor}") + + return _GPU_HARDWARE_DEFAULTS diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 4da02544..03068abf 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -4,28 +4,30 @@ import os from collections.abc import Callable, Sequence from pathlib import Path +from pprint import pformat from typing import Any -from dace import SDFG, CompiledSDFG +from dace import SDFG, CompiledSDFG, DeviceType from dace import compiletime as DaceCompiletime from dace import dtypes from dace import method as dace_method from dace import nodes from dace import program as dace_program from dace.dtypes import DeviceType as DaceDeviceType +from dace.dtypes import ScheduleType from dace.dtypes import StorageType as DaceStorageType from dace.frontend.python.common import SDFGConvertible from dace.frontend.python.parser import DaceProgram from dace.sdfg.analysis.schedule_tree import treenodes as tn from dace.transformation.auto.auto_optimize import make_transients_persistent -from dace.transformation.dataflow import MapExpansion +from dace.transformation.dataflow import MapCollapse, MapExpansion +from dace.transformation.dataflow.add_threadblock_map 
import AddThreadBlockMap from dace.transformation.helpers import get_parent_map from gt4py import storage as gt_storage import ndsl.dsl.dace.replacements # noqa # We load in the DaCe replacements -from ndsl import ndsl_log +from ndsl import Backend, OptimizationConfig, ndsl_log from ndsl.comm.mpi import MPI -from ndsl.config import BackendLoopOrder from ndsl.dsl.dace.build import get_sdfg_path, write_build_info from ndsl.dsl.dace.dace_config import ( DEACTIVATE_DISTRIBUTED_DACE_COMPILE, @@ -33,19 +35,15 @@ DaCeOrchestration, ) from ndsl.dsl.dace.dace_executable import DaceExecutable +from ndsl.dsl.dace.hardware_config import get_gpu_hardware_defaults from ndsl.dsl.dace.labeler import set_label from ndsl.dsl.dace.sdfg_debug_passes import ( negative_delp_checker, negative_qtracers_checker, sdfg_nan_checker, ) -from ndsl.dsl.dace.stree import CPUPipeline -from ndsl.dsl.dace.stree.optimizations import ( - AxisIterator, - CartesianAxisMerge, - CartesianRefineTransients, - CleanUpScheduleTree, -) +from ndsl.dsl.dace.stree import CPUPipeline, GPUPipeline +from ndsl.dsl.dace.stree.pipeline import StreePipeline from ndsl.dsl.dace.utils import ( DaCeProgress, memory_static_analysis, @@ -55,10 +53,7 @@ from ndsl.quantity import Quantity, State -_INTERNAL__SCHEDULE_TREE_OPTIMIZATION: bool = ( - os.environ.get("NDSL_STREE_OPT", "False") == "True" -) -"""INTERNAL: Developer flag to turn the untested schedule tree roundtrip optimizer.""" +_INTERNAL__SCHEDULE_TREE_OPTIMIZATION_PASSES: list[tn.ScheduleNodeVisitor] | None = None def dace_inhibitor(func: Callable) -> Callable: @@ -149,8 +144,36 @@ def _tree_as_sdfg(stree: tn.ScheduleTreeRoot) -> SDFG: return stree.as_sdfg(skip={"ScalarToSymbolPromotion", "ControlFlowRaising"}) +def _optimization_pipeline( + config: OptimizationConfig, + device_type: DeviceType, + backend: Backend, + *, + passes: list[tn.ScheduleNodeVisitor] | None = None, + cache_directory: Path | None = None, +) -> StreePipeline: + if device_type == device_type.CPU: 
+ return CPUPipeline( + config, backend, passes=passes, cache_directory=cache_directory + ) + + if device_type == DeviceType.GPU: + return GPUPipeline( + config, backend, passes=passes, cache_directory=cache_directory + ) + + raise ValueError( + f"Unknown device type `{device_type}`, expected {DeviceType.CPU} or {DeviceType.GPU}." + ) + + def _build_sdfg( - dace_program: DaceProgram, sdfg: SDFG, config: DaceConfig, args: Any, kwargs: Any + dace_program: DaceProgram, + sdfg: SDFG, + config: DaceConfig, + optimization_config: OptimizationConfig, + args: Any, + kwargs: Any, ) -> None: """Build the .so out of the SDFG on the top tile ranks only.""" is_compiling = True if DEACTIVATE_DISTRIBUTED_DACE_COMPILE else config.do_compile @@ -158,6 +181,7 @@ def _build_sdfg( backend_name = config.get_backend() if is_compiling: + ndsl_log.debug(f"Compiling config:\n{pformat(optimization_config, indent=2)}") # Fully specialize all known symbols and then propagate these changes in the simplify # pass that follows. This is not only a smart idea in general, but also simplifies (haha) # the schedule tree (optimization) roundtrip. 
@@ -170,27 +194,43 @@ def _build_sdfg( repl_dict[sym] = val my_sdfg.replace_dict(repl_dict) - if config.verbose_orchestration: - sdfg.save( - os.path.abspath(f"{sdfg.build_folder}/00-combined_from_stencils.sdfgz"), - compress=True, - ) + if config.verbose_orchestration: + ndsl_log.debug("saving 00-combined_from_stencils.sdfgz") + sdfg.save( + os.path.abspath( + f"{sdfg.build_folder}/00-combined_from_stencils.sdfgz" + ), + compress=True, + ) with DaCeProgress(config, "Simplify (1)"): _simplify(sdfg) if config.verbose_orchestration: + ndsl_log.debug("saving 01-simplify.sdfgz") sdfg.save( os.path.abspath(f"{sdfg.build_folder}/01-simplify_1.sdfgz"), compress=True, ) - if _INTERNAL__SCHEDULE_TREE_OPTIMIZATION: + if optimization_config.stree.enabled: # Here be 🐉 - but tests exists in test_optimization.py with DaCeProgress(config, "Schedule Tree: generate from SDFG"): # Break all loops into uni-dimensional loops to simplify optimizations - sdfg.apply_transformations_repeated(MapExpansion, validate=True) + sdfg.apply_transformations_repeated( + MapExpansion, + options={ + "inner_schedule": ( + ScheduleType.GPU_Device + if device_type is DeviceType.GPU + else ScheduleType.Default + ) + }, + validate=True, + print_report=True, + ) stree = sdfg.as_schedule_tree() if config.verbose_orchestration: + ndsl_log.debug("saving 02-pre_opt.stree.txt") with open( os.path.abspath(f"{sdfg.build_folder}/02-pre_opt.stree.txt"), "w+", @@ -198,45 +238,16 @@ def _build_sdfg( f.write(stree.as_string()) with DaCeProgress(config, "Schedule Tree: optimization"): - passes = [] - if backend_name.loop_order == BackendLoopOrder.IJK: - passes.extend( - [ - CleanUpScheduleTree(), - CartesianAxisMerge(AxisIterator._I), - CartesianAxisMerge(AxisIterator._J), - CartesianAxisMerge(AxisIterator._K), - CartesianRefineTransients(backend_name), - ] - ) - elif backend_name.loop_order == BackendLoopOrder.KJI: - passes.extend( - [ - CleanUpScheduleTree(), - CartesianAxisMerge(AxisIterator._K), - 
CartesianAxisMerge(AxisIterator._J), - CartesianAxisMerge(AxisIterator._I), - CartesianRefineTransients(backend_name), - ] - ) - elif backend_name.loop_order == BackendLoopOrder.KIJ: - passes.extend( - [ - CleanUpScheduleTree(), - CartesianAxisMerge(AxisIterator._K), - CartesianAxisMerge(AxisIterator._I), - CartesianAxisMerge(AxisIterator._J), - CartesianRefineTransients(backend_name), - ] - ) - else: - raise NotImplementedError( - f"Loop order {backend_name.loop_order} has no schedule tree pipeline" - ) - CPUPipeline(passes=passes, cache_directory=Path(sdfg.build_folder)).run( - stree, verbose=config.verbose_schedule_tree_optimizations + pipeline = _optimization_pipeline( + optimization_config, + device_type, + backend_name, + cache_directory=Path(sdfg.build_folder), + passes=_INTERNAL__SCHEDULE_TREE_OPTIMIZATION_PASSES, ) + pipeline.run(stree, verbose=config.verbose_schedule_tree_optimizations) if config.verbose_orchestration: + ndsl_log.debug("saving 03-post_opt.stree.txt") with open( os.path.abspath(f"{sdfg.build_folder}/03-post_opt.stree.txt"), "w+", @@ -246,48 +257,90 @@ def _build_sdfg( with DaCeProgress(config, "Schedule Tree: go back to SDFG"): sdfg = _tree_as_sdfg(stree) if config.verbose_orchestration: + ndsl_log.debug("saving 04-from_stree.sdfgz") sdfg.save( os.path.abspath(f"{sdfg.build_folder}/04-from_stree.sdfgz"), compress=True, ) - # Make the transients array persistents - if config.is_gpu_backend(): - # TODO - # The following should happen on the stree level - _to_gpu(sdfg) + # We want all maps properly collapse to make sure the codegen will see nD parallel + # axis as a single kernelizable map + with DaCeProgress(config, "Collapse maps"): + # allow `MapCollapse` to collapse maps with different schedules + sdfg.apply_transformations_repeated(MapCollapse, permissive=True) + + with DaCeProgress(config, "Make transient persistents"): + # Make the transients array persistents + if config.is_gpu_backend(): + # TODO + # The following should happen on 
the stree level + _to_gpu(sdfg) + make_transients_persistent(sdfg=sdfg, device=device_type) - sdfg.apply_gpu_transformations() + # Upload args to device + _upload_to_device(list(args) + list(kwargs.values())) + else: + # TODO + # The following should happen on the stree level + for _sd, _aname, arr in sdfg.arrays_recursive(): + if arr.shape == (1,): + arr.storage = DaceStorageType.Register + make_transients_persistent(sdfg=sdfg, device=device_type) - make_transients_persistent(sdfg=sdfg, device=device_type) + if config.is_gpu_backend(): + with DaCeProgress(config, "Apply GPU transformations"): + # Set block size on GPU maps and collect callback + # tasklets to exclude next + gpu_defaults = get_gpu_hardware_defaults() + exclude_taskslets_list = [] + + for me, _state in sdfg.all_nodes_recursive(): + if ( + isinstance(me, nodes.MapEntry) + and me.map.schedule == ScheduleType.GPU_Device + ): + if me.map.gpu_block_size is None: + me.map.gpu_block_size = gpu_defaults.block_size + + if isinstance(me, nodes.Tasklet) and "callback_" in me.label: + exclude_taskslets_list.append(me.label) + + sdfg.apply_transformations_repeated(AddThreadBlockMap) + + if optimization_config.gpu.common_gpu_xforms: + with DaCeProgress(config, "Apply common GPU xforms"): + # Apply common GPU transforms (includes a simplify) + # while making sure tasklet remain on the host + from dace.transformation.interstate import GPUTransformSDFG + + sdfg.apply_transformations( + GPUTransformSDFG, + options={ + "exclude_tasklets": ",".join(exclude_taskslets_list), + "host_data": ["__pystate"], + }, + ) + else: + with DaCeProgress(config, "GPU simplify"): + _simplify(sdfg) - # Upload args to device - _upload_to_device(list(args) + list(kwargs.values())) + if config.verbose_orchestration: + ndsl_log.debug("saving 05-apply_gpu_xforms.sdfgz") + sdfg.save( + os.path.abspath( + f"{sdfg.build_folder}/05-apply_gpu_xforms.sdfgz" + ), + compress=True, + ) else: - # TODO - # The following should happen on the stree level 
- for _sd, _aname, arr in sdfg.arrays_recursive(): - if arr.shape == (1,): - arr.storage = DaceStorageType.Register - make_transients_persistent(sdfg=sdfg, device=device_type) - - # Build non-constants & non-transients from the sdfg_kwargs - sdfg_kwargs = dace_program._create_sdfg_args(sdfg, args, kwargs) - for k in dace_program.constant_args: - if k in sdfg_kwargs: - del sdfg_kwargs[k] - sdfg_kwargs = {k: v for k, v in sdfg_kwargs.items() if v is not None} - for k, tup in dace_program.resolver.closure_arrays.items(): - if k in sdfg_kwargs and tup[1].transient: - del sdfg_kwargs[k] - - with DaCeProgress(config, "Simplify (2)"): - _simplify(sdfg) - if config.verbose_orchestration: - sdfg.save( - os.path.abspath(f"{sdfg.build_folder}/05-simplify_2.sdfgz"), - compress=True, - ) + with DaCeProgress(config, "Simplify (2)"): + _simplify(sdfg) + if config.verbose_orchestration: + ndsl_log.debug("saving 05-simplify_2.sdfgz") + sdfg.save( + os.path.abspath(f"{sdfg.build_folder}/05-simplify_2.sdfgz"), + compress=True, + ) # Move all memory that can be into a pool to lower memory pressure for GPU # We skip this memory optimization for CPU because we don't have a memory # pool available yet (DaCe v1) @@ -316,7 +369,12 @@ def _build_sdfg( # Compile with DaCeProgress(config, "Codegen & compile"): - sdfg.compile() + compiled_sdfg = sdfg.compile() + config.loaded_dace_executables[dace_program] = DaceExecutable( + compiled_sdfg=compiled_sdfg, + arguments={}, + arguments_hash=0, + ) # Printing analysis of the compiled SDFG with DaCeProgress(config, "Build finished. 
Running memory static analysis"): @@ -355,22 +413,30 @@ def _build_sdfg( ) MPI.COMM_WORLD.Barrier() - with DaCeProgress(config, "Loading"): - sdfg_path = get_sdfg_path(dace_program.name, config, override_run_only=True) - if sdfg_path is None: - raise ValueError("Couldn't load SDFG post build") - compiledSDFG, _ = dace_program.load_precompiled_sdfg( - sdfg_path, *args, **kwargs - ) - config.loaded_dace_executables[dace_program] = DaceExecutable( - compiled_sdfg=compiledSDFG, - arguments={}, - arguments_hash=0, - ) + if not is_compiling: + with DaCeProgress(config, "Loading"): + sdfg_path = get_sdfg_path( + dace_program.name, config, override_run_only=True + ) + if sdfg_path is None: + raise ValueError("Couldn't load SDFG post build") + compiledSDFG, _ = dace_program.load_precompiled_sdfg( + sdfg_path, *args, **kwargs + ) + config.loaded_dace_executables[dace_program] = DaceExecutable( + compiled_sdfg=compiledSDFG, + arguments={}, + arguments_hash=0, + ) def _call_sdfg( - dace_program: DaceProgram, sdfg: SDFG, config: DaceConfig, args: Any, kwargs: Any + dace_program: DaceProgram, + sdfg: SDFG, + config: DaceConfig, + optimization_config: OptimizationConfig, + args: Any, + kwargs: Any, ) -> list | None: """Dispatch to either SDFG execution and/or build.""" @@ -382,7 +448,7 @@ def _call_sdfg( and dace_program not in config.loaded_dace_executables # already cached ): ndsl_log.info("Building DaCe orchestration") - _build_sdfg(dace_program, sdfg, config, args, kwargs) + _build_sdfg(dace_program, sdfg, config, optimization_config, args, kwargs) if mode not in [DaCeOrchestration.BuildAndRun, DaCeOrchestration.Run]: raise ValueError(f"Unexpected DaceOrchestration mode `{mode}`.") @@ -489,9 +555,15 @@ class _LazyComputepathFunction(SDFGConvertible): that will be compiled but not regenerated. 
""" - def __init__(self, func: Callable, config: DaceConfig) -> None: + def __init__( + self, + func: Callable, + config: DaceConfig, + optimization_config: OptimizationConfig, + ) -> None: self.func = func self.config = config + self.optimization_config = optimization_config self.daceprog: DaceProgram = dace_program(self.func) self._sdfg = None @@ -507,6 +579,7 @@ def __call__(self, *args, **kwargs): # type: ignore[no-untyped-def] self.daceprog, sdfg, self.config, + self.optimization_config, args, kwargs, ) @@ -571,12 +644,13 @@ def __call__(self, *args, **kwargs): # type: ignore[no-untyped-def] **kwargs, ) # Label the code (this is the topmost code) - if sdfg is not None and _INTERNAL__SCHEDULE_TREE_OPTIMIZATION: + if sdfg is not None and self.lazy_method.optimization_config.stree.enabled: set_label(sdfg, type(self.obj_to_bind).__qualname__, is_top_sdfg=True) return _call_sdfg( self.daceprog, sdfg, self.lazy_method.config, + self.lazy_method.optimization_config, args, kwargs, ) @@ -584,7 +658,7 @@ def __call__(self, *args, **kwargs): # type: ignore[no-untyped-def] def __sdfg__(self, *args, **kwargs): # type: ignore[no-untyped-def] sdfg = _parse_sdfg(self.daceprog, self.lazy_method.config, *args, **kwargs) # Label the code - if sdfg is not None and _INTERNAL__SCHEDULE_TREE_OPTIMIZATION: + if sdfg is not None and self.lazy_method.optimization_config.stree.enabled: set_label(sdfg, type(self.obj_to_bind).__qualname__, is_top_sdfg=False) return sdfg @@ -599,9 +673,15 @@ def closure_resolver(self, constant_args, given_args, parent_closure=None): # t constant_args, given_args, parent_closure ) - def __init__(self, func: Callable, config: DaceConfig): + def __init__( + self, + func: Callable, + config: DaceConfig, + optimization_config: OptimizationConfig, + ) -> None: self.func = func self.config = config + self.optimization_config = optimization_config def __get__(self, obj: object, objtype: Any = None) -> SDFGEnabledCallable: """Return SDFGEnabledCallable wrapping 
original obj.method from cache. @@ -620,6 +700,7 @@ def orchestrate( config: DaceConfig, method_to_orchestrate: str = "__call__", dace_compiletime_args: Sequence[str] | None = None, + optimization_config: OptimizationConfig | None = None, ) -> None: """ Orchestrate a method of an object with DaCe. @@ -650,6 +731,11 @@ def orchestrate( if dace_compiletime_args is None: dace_compiletime_args = [] + if optimization_config is None: + opt_config = OptimizationConfig() + else: + opt_config = optimization_config + func: Callable = type.__getattribute__(type(obj), method_to_orchestrate) # Flag argument as dace.constant @@ -672,7 +758,7 @@ def orchestrate( # Build DaCe orchestrated wrapper # This is a JIT object, e.g. DaCe compilation will happen on call - wrapped = _LazyComputepathMethod(func, config).__get__(obj) + wrapped = _LazyComputepathMethod(func, config, opt_config).__get__(obj) if method_to_orchestrate == "__call__": # Grab the function from the type of the child class @@ -724,6 +810,7 @@ def closure_resolver(self, constant_args, given_args, parent_closure=None): # t def orchestrate_function( config: DaceConfig, dace_compiletime_args: Sequence[str] | None = None, + optimization_config: OptimizationConfig | None = None, ) -> Callable[..., Any] | _LazyComputepathFunction: """ Decorator orchestrating a method of an object with DaCe. 
@@ -738,11 +825,16 @@ def orchestrate_function( if dace_compiletime_args is None: dace_compiletime_args = [] + if optimization_config is None: + opt_config = OptimizationConfig() + else: + opt_config = optimization_config + def _decorator(func: Callable[..., Any]): # type: ignore[no-untyped-def] def _wrapper(*args, **kwargs): # type: ignore[no-untyped-def] for argument in dace_compiletime_args: func.__annotations__[argument] = DaceCompiletime - return _LazyComputepathFunction(func, config) + return _LazyComputepathFunction(func, config, opt_config) return _wrapper(func) if config.is_dace_orchestrated() else func diff --git a/ndsl/dsl/dace/stree/optimizations/__init__.py b/ndsl/dsl/dace/stree/optimizations/__init__.py index 73497f93..b1f69aa1 100644 --- a/ndsl/dsl/dace/stree/optimizations/__init__.py +++ b/ndsl/dsl/dace/stree/optimizations/__init__.py @@ -1,11 +1,26 @@ -from .axis_merge import AxisIterator, CartesianAxisMerge +from .axis_merge import CartesianAxisMerge +from .cartesian_merge import CartesianMerge from .clean_tree import CleanUpScheduleTree +from .kernelize_maps import KernelizeMaps +from .offgrid_conditionals import ( + ExtractOffgridConditionals, + InlineOffgridConditionals, + MergeConditionals, +) from .refine_transients import CartesianRefineTransients +from .remove_loops import InlineVertical2DWrite +from .statistics import TreeOptimizationStatistics __all__ = [ - "AxisIterator", "CartesianAxisMerge", - "CartesianRefineTransients", + "CartesianMerge", "CleanUpScheduleTree", + "KernelizeMaps", + "ExtractOffgridConditionals", + "InlineOffgridConditionals", + "MergeConditionals", + "CartesianRefineTransients", + "InlineVertical2DWrite", + "TreeOptimizationStatistics", ] diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index c042badf..196d3d0e 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -1,5 +1,3 @@ -from __future__ 
import annotations - import copy import dace @@ -7,29 +5,18 @@ from dace.sdfg.analysis.schedule_tree import treenodes as tn from ndsl import ndsl_log -from ndsl.dsl.dace.stree.optimizations.memlet_helpers import ( +from ndsl.dsl.dace.stree.optimizations.common import ( AxisIterator, - no_data_dependencies_on_cartesian_axis, -) -from ndsl.dsl.dace.stree.optimizations.tree_common_op import ( detect_cycle, + get_next_node, + is_axis_for, + is_axis_map, + last_node, list_index, + no_data_dependencies_on_cartesian_axis, swap_node_position_in_tree, ) - - -# Buggy passes that should work -PUSH_IFSCOPE_DOWNWARD = False # Crashing the overall stree - bad algorithmics - - -def _is_axis_map(node: tn.MapScope, axis: AxisIterator) -> bool: - """Returns true if node is a map over the given axis.""" - map_parameter = node.node.map.params - return len(map_parameter) == 1 and map_parameter[0].startswith(axis.as_str()) - - -def _is_axis_for(node: tn.ForScope, axis: AxisIterator) -> bool: - return node.loop.loop_variable.startswith(axis.as_str()) +from ndsl.dsl.dace.stree.optimizations.replace_axis_symbol import ReplaceAxisSymbol def _both_same_single_axis_maps( @@ -39,18 +26,17 @@ def _both_same_single_axis_maps( ( len(first.node.map.params) == 1 and len(second.node.map.params) == 1 ) # Single axis - and _is_axis_map(first, axis) # Correct axis in first map - and _is_axis_map(second, axis) # Correct axis in second map + and is_axis_map(first, axis) # Correct axis in first map + and is_axis_map(second, axis) # Correct axis in second map ) def _can_merge_axis_maps( first: tn.MapScope, second: tn.MapScope, axis: AxisIterator ) -> bool: - if _both_same_single_axis_maps(first, second, axis): - if no_data_dependencies_on_cartesian_axis(first, second, axis): - return True - return False + return _both_same_single_axis_maps( + first, second, axis + ) and no_data_dependencies_on_cartesian_axis(first, second, axis) class InsertOvercomputationGuard(tn.ScheduleNodeTransformer): @@ -82,89 +68,45 
@@ def visit_MapScope(self, node: tn.MapScope) -> tn.MapScope: all_children_are_maps = all( [isinstance(child, tn.MapScope) for child in node.children] ) - if not all_children_are_maps: - if self._merged_range != self._original_range: - if_scope = tn.IfScope( - condition=self._execution_condition(), - children=node.children, - parent=node, - ) - # Re-parent to IF - for child in node.children: - child.parent = if_scope - node.children = [if_scope] + if all_children_are_maps: + node.children = self.visit(node.children) return node - node.children = self.visit(node.children) + if self._merged_range != self._original_range: + if_scope = tn.IfScope( + condition=self._execution_condition(), + children=node.children, + parent=node, + ) + # Re-parent to IF + for child in node.children: + child.parent = if_scope + node.children = [if_scope] return node -def _get_next_node( - nodes: list[tn.ScheduleTreeNode], node: tn.ScheduleTreeNode -) -> tn.ScheduleTreeNode: - return nodes[list_index(nodes, node) + 1] - - -def _last_node(nodes: list[tn.ScheduleTreeNode], node: tn.ScheduleTreeNode) -> bool: - return list_index(nodes, node) >= len(nodes) - 1 - - -class ReplaceAxisSymbol(tn.ScheduleNodeVisitor): - def __init__(self, axis: AxisIterator) -> None: - self._axis = axis - - def visit_MapScope( - self, - map_scope: tn.MapScope, - axis_replacements: dict[str, str] | None = None, - ) -> None: - if axis_replacements is None: - axis_replacements = {} - - for index, param in enumerate(map_scope.node.params): - if param in axis_replacements: - map_scope.node.params[index] = axis_replacements[param] - - # visit children - for child in map_scope.children: - self.visit(child, axis_replacements=axis_replacements) - - def visit_TaskletNode( - self, - node: tn.TaskletNode, - axis_replacements: dict[str, str] | None = None, - ) -> None: - if not axis_replacements: - # Noop if there are no replacements to do. 
- return - - for memlets in node.in_memlets.values(): - memlets.replace(axis_replacements) - for memlets in node.out_memlets.values(): - memlets.replace(axis_replacements) - - class CartesianAxisMerge(tn.ScheduleNodeTransformer): """Merge a cartesian axis if they are contiguous in code-flow. Can do: - merge a given axis with the next maps at the same recursion level - - can overcompute (eager) to allow for more merging at the cost of an if + - can overcompute to allow for more merging at the cost of an if It expects: - All Maps and ForLoop are on a single axis - but doesn't check for it. Args: axis: AxisIterator to be merged - eager: overcompute with a conditional guard + overcompute: merge at the cost of an if statement. """ - def __init__(self, axis: AxisIterator, *, eager: bool = True) -> None: + def __init__(self, axis: AxisIterator, *, overcompute: bool) -> None: self.axis = axis - self.eager = eager + self.overcompute = overcompute def __str__(self) -> str: - return f"CartesianAxisMerge_{self.axis.name}_{'eager' if self.eager else ''}" + suffix = "_overcompute" if self.overcompute else "" + return f"CartesianAxisMerge_{self.axis.name}{suffix}" def _merge_node( self, node: tn.ScheduleTreeNode, nodes: list[tn.ScheduleTreeNode] @@ -179,9 +121,6 @@ def _merge_node( if isinstance(node, tn.MapScope): return self._map_overcompute_merge(node, nodes) - if PUSH_IFSCOPE_DOWNWARD and isinstance(node, tn.IfScope): - return self._push_ifelse_down(node, nodes) - if isinstance(node, tn.ForScope): return self._for_merge(node) @@ -197,7 +136,7 @@ def _merge_node( def _for_merge(self, the_for_scope: tn.ForScope) -> int: merged = 0 - if _is_axis_for(the_for_scope, self.axis): + if is_axis_for(the_for_scope, AxisIterator._K): # TODO: if the for scope is on a cartesian axis it can be # merged with other for scope going in the same direction pass @@ -206,7 +145,7 @@ def _for_merge(self, the_for_scope: tn.ForScope) -> int: if ( len(the_for_scope.children) == 1 and 
isinstance(the_for_scope.children[0], tn.MapScope) - and _is_axis_map(the_for_scope.children[0], self.axis) + and is_axis_map(the_for_scope.children[0], self.axis) ): swap_node_position_in_tree(the_for_scope, the_for_scope.children[0]) merged += 1 @@ -248,92 +187,19 @@ def _push_tasklet_down( return merged - def _push_ifelse_down( - self, the_if: tn.IfScope, nodes: list[tn.ScheduleTreeNode] - ) -> int: - merged = 0 - - # Recurse down if/else/elif - if_index = list_index(nodes, the_if) - if len(the_if.children) != 0: - merged += self._merge_node(the_if.children[0], the_if.children) - for else_index in range(if_index + 1, len(nodes)): - else_node = nodes[else_index] - if else_index < len(nodes) and ( - isinstance(else_node, tn.ElseScope) - or isinstance(else_node, tn.ElifScope) - ): - merged += self._merge_node(else_node, else_node.children) - else: - break - - # Look at swapping if/else/elif first map w/ control flow - - # Gather all first maps - if they do not exists, get out - all_maps = [] - if isinstance(the_if.children[0], tn.MapScope): - all_maps.append(the_if.children[0]) - else: - return merged - for else_index in range(if_index + 1, len(nodes)): - else_node = nodes[else_index] - if else_index < len(nodes) and ( - isinstance(else_node, tn.ElseScope) - or isinstance(else_node, tn.ElifScope) - ): - if isinstance(else_node.children[0], tn.MapScope): - all_maps.append(else_node.children[0]) - else: - return merged - - else: - break - - # Check for mergeability - if len(all_maps) > 1: - the_map = all_maps[0] - for _map in all_maps[1:]: - if not _can_merge_axis_maps(the_map, _map, self.axis): - return merged - - # We are good to go - swap it all - inner_if_map = the_if.children[0] - - # Swap IF & maps - if_index = list_index(nodes, the_if) - swap_node_position_in_tree(the_if, inner_if_map) - - # Swap ELIF/ELSE & maps - for else_index in range(if_index + 1, len(nodes)): - if else_index < len(nodes) and ( - isinstance(nodes[else_index], tn.ElseScope) - or 
isinstance(nodes[else_index], tn.ElifScope) - ): - swap_node_position_in_tree( - nodes[else_index], nodes[else_index].children[0] - ) - else: - break - - # Merge the Maps - assert isinstance(nodes[if_index], tn.MapScope) - merged += self._map_overcompute_merge(nodes[if_index], nodes) - - return merged - def _map_overcompute_merge( self, the_map: tn.MapScope, nodes: list[tn.ScheduleTreeNode] ) -> int: # End of nodes OR # Not the right axis # --> recurse - if _last_node(nodes, the_map) or not _is_axis_map(the_map, self.axis): + if last_node(nodes, the_map) or not is_axis_map(the_map, self.axis): merged = 0 for child in the_map.children: merged += self._merge_node(child, the_map.children) return merged - next_node = _get_next_node(nodes, the_map) + next_node = get_next_node(nodes, the_map) # Next node is not a MapScope - no merge if not isinstance(next_node, tn.MapScope): @@ -345,7 +211,6 @@ def _map_overcompute_merge( # Over compute to merge: # - force-merge by expanding the ranges - # - then, guard children to only run in their respective range first_range = the_map.node.map.range second_range = next_node.node.map.range merged_range = dace.subsets.Range( @@ -358,8 +223,15 @@ def _map_overcompute_merge( ] ) - # push IfScope down if children are just maps - axis_as_str = the_map.node.params[0] + # only overcompute if configured - otherwise no merge + if not self.overcompute and ( + first_range != merged_range or second_range != merged_range + ): + return 0 + + # - then, guard children to only run in their respective range + axis_as_str = the_map.node.map.params[0] + assert isinstance(axis_as_str, str) first_map = InsertOvercomputationGuard( axis_as_str, merged_range=merged_range, original_range=first_range ).visit(the_map) @@ -368,7 +240,9 @@ def _map_overcompute_merge( merged_range=merged_range, original_range=second_range, ).visit(next_node) - merged_children: list[tn.MapScope] = [ + assert isinstance(first_map, tn.MapScope) + assert isinstance(second_map, 
tn.MapScope) + merged_children: list[tn.ScheduleTreeNode] = [ *first_map.children, *second_map.children, ] @@ -384,11 +258,13 @@ def _map_overcompute_merge( # K-maps use unique iterators (i.e. every k-map iterates over `k__[0-9]*`). # After merge, we need to replace the axis symbols of the second map's children # with the axis symbol of the first map. - if next_node.node.map.params[0] != the_map.node.map.params[0]: - replacements = {next_node.node.map.params[0]: the_map.node.map.params[0]} - ReplaceAxisSymbol(self.axis).visit( - first_map, axis_replacements=replacements - ) + if second_map.node.map.params[0] != first_map.node.map.params[0]: + replacements = { + dace.symbol(second_map.node.map.params[0]): dace.symbol( + first_map.node.map.params[0] + ) + } + ReplaceAxisSymbol(replacements).visit(first_map) # delete now-merged second_map del nodes[list_index(nodes, next_node)] diff --git a/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py new file mode 100644 index 00000000..16d72380 --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py @@ -0,0 +1,57 @@ +from dace.sdfg.analysis.schedule_tree import treenodes as tn + +from ndsl.config import Backend, BackendLoopOrder +from ndsl.dsl.dace.stree.optimizations.axis_merge import CartesianAxisMerge +from ndsl.dsl.dace.stree.optimizations.common import AxisIterator +from ndsl.dsl.dace.stree.optimizations.offgrid_conditionals import ( + ExtractOffgridConditionals, + InlineOffgridConditionals, + MergeConditionals, +) + + +class CartesianMerge(tn.ScheduleNodeTransformer): + """Merge Cartesian computation blocks. + + Args: + backend: The loop order influences the merge order. + overcompute: Whether to merge at the cost of an if statement. Defaults to True. 
+ """ + + def __init__(self, backend: Backend, *, overcompute: bool = True) -> None: + super().__init__() + self._backend = backend + self._overcompute = overcompute + + def __str__(self) -> str: + return "CartesianMerge" + + def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: + for axis in self._backend_order(): + InlineOffgridConditionals(axis).visit(node) + MergeConditionals().visit(node) + + for axis in self._backend_order(): + CartesianAxisMerge(axis, overcompute=self._overcompute).visit(node) + + ExtractOffgridConditionals().visit(node) + MergeConditionals().visit(node) + + def _backend_order(self) -> tuple[AxisIterator, AxisIterator, AxisIterator]: + if self._backend.loop_order == BackendLoopOrder.IJK: + return (AxisIterator._I, AxisIterator._J, AxisIterator._K) + + if self._backend.loop_order == BackendLoopOrder.IKJ: + return (AxisIterator._I, AxisIterator._K, AxisIterator._J) + + if self._backend.loop_order == BackendLoopOrder.JIK: + return (AxisIterator._J, AxisIterator._I, AxisIterator._K) + + if self._backend.loop_order == BackendLoopOrder.JKI: + return (AxisIterator._J, AxisIterator._K, AxisIterator._I) + + if self._backend.loop_order == BackendLoopOrder.KIJ: + return (AxisIterator._K, AxisIterator._I, AxisIterator._J) + + assert self._backend.loop_order == BackendLoopOrder.KJI + return (AxisIterator._K, AxisIterator._J, AxisIterator._I) diff --git a/ndsl/dsl/dace/stree/optimizations/clean_tree.py b/ndsl/dsl/dace/stree/optimizations/clean_tree.py index 5e9ab522..93798f42 100644 --- a/ndsl/dsl/dace/stree/optimizations/clean_tree.py +++ b/ndsl/dsl/dace/stree/optimizations/clean_tree.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from dace.sdfg.analysis.schedule_tree import treenodes as tn from ndsl import ndsl_log @@ -9,6 +7,7 @@ class CleanUpScheduleTree(tn.ScheduleNodeTransformer): """Remove `StateBoundary` nodes from children of ScheduleTreeScopes.""" def __init__(self) -> None: + super().__init__() 
self._removed_state_boundaries = 0 def __str__(self) -> str: @@ -27,40 +26,37 @@ def _remove_state_boundaries_from_children( def visit_WhileScope(self, node: tn.WhileScope) -> tn.WhileScope: self._remove_state_boundaries_from_children(node) - for child in node.children: - self.visit(child) + self.generic_visit(node) return node def visit_ForScope(self, node: tn.ForScope) -> tn.ForScope: self._remove_state_boundaries_from_children(node) - for child in node.children: - self.visit(child) + self.generic_visit(node) return node def visit_MapScope(self, node: tn.MapScope) -> tn.MapScope: self._remove_state_boundaries_from_children(node) - for child in node.children: - self.visit(child) + self.generic_visit(node) return node def visit_IfScope(self, node: tn.IfScope) -> tn.IfScope: self._remove_state_boundaries_from_children(node) - for child in node.children: - self.visit(child) + + self.generic_visit(node) return node - def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: + def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> tn.ScheduleTreeRoot: self._removed_state_boundaries = 0 self._remove_state_boundaries_from_children(node) - for child in node.children: - self.visit(child) + self.generic_visit(node) ndsl_log.debug(f"{self}: removed {self._removed_state_boundaries} nodes") + return node diff --git a/ndsl/dsl/dace/stree/optimizations/common/__init__.py b/ndsl/dsl/dace/stree/optimizations/common/__init__.py new file mode 100644 index 00000000..2e342912 --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/common/__init__.py @@ -0,0 +1,25 @@ +from .memlet import AxisIterator, no_data_dependencies_on_cartesian_axis # isort: skip +from .loops import is_axis_for, is_axis_map, is_cartesian_axis +from .topology import ( + detect_cycle, + get_next_node, + last_node, + list_index, + reparent_scope_node, + swap_node_position_in_tree, +) + + +__all__ = [ + "AxisIterator", + "no_data_dependencies_on_cartesian_axis", + "is_axis_map", + 
"is_cartesian_axis", + "is_axis_for", + "get_next_node", + "last_node", + "swap_node_position_in_tree", + "detect_cycle", + "list_index", + "reparent_scope_node", +] diff --git a/ndsl/dsl/dace/stree/optimizations/common/loops.py b/ndsl/dsl/dace/stree/optimizations/common/loops.py new file mode 100644 index 00000000..1f057954 --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/common/loops.py @@ -0,0 +1,29 @@ +import dace.sdfg.analysis.schedule_tree.treenodes as tn + +from ndsl.dsl.dace.stree.optimizations.common import AxisIterator + + +def is_axis_map(node: tn.MapScope, axis: AxisIterator) -> bool: + """Returns true if node is a Map over the given axis.""" + if len(node.node.map.params) != 1: + return False + + param = node.node.map.params[0] + assert isinstance(param, str) + return axis.is_equal(param) + + +def is_cartesian_axis(node: tn.MapScope | tn.ForScope) -> bool: + """Returns true if the given node is a map over any cartesian axis.""" + for axis in AxisIterator: + if (isinstance(node, tn.MapScope) and is_axis_map(node, axis)) or ( + isinstance(node, tn.ForScope) and is_axis_for(node, axis) + ): + return True + + return False + + +def is_axis_for(node: tn.ForScope, axis: AxisIterator) -> bool: + """Returns true if node is a For over the given axis.""" + return axis.is_equal(node.loop.loop_variable) diff --git a/ndsl/dsl/dace/stree/optimizations/memlet_helpers.py b/ndsl/dsl/dace/stree/optimizations/common/memlet.py similarity index 61% rename from ndsl/dsl/dace/stree/optimizations/memlet_helpers.py rename to ndsl/dsl/dace/stree/optimizations/common/memlet.py index 0626133e..266be19d 100644 --- a/ndsl/dsl/dace/stree/optimizations/memlet_helpers.py +++ b/ndsl/dsl/dace/stree/optimizations/common/memlet.py @@ -2,6 +2,7 @@ import dace.sdfg.analysis.schedule_tree.treenodes as stree from dace.memlet import Memlet +from dace.symbolic import symbol from ndsl import ndsl_log @@ -17,77 +18,79 @@ def as_str(self) -> str: def as_cartesian_index(self) -> int: return 
self.value[1] + def is_equal(self, other: str) -> bool: + if self == AxisIterator._K: + return other.startswith(self.as_str()) + + return other == self.as_str() + + +def normalize_cartesian_indexation(index: symbol, axis: AxisIterator) -> symbol: + """Return a normalize indexation symbol for cartesian indexation.""" + rename_maps = {} + for symb in index.free_symbols: + if symb.name.startswith(axis.as_str()): + rename_maps[symb] = symbol(axis.as_str()) + return index.subs(rename_maps) + def no_data_dependencies_on_cartesian_axis( first: stree.MapScope, second: stree.MapScope, axis: AxisIterator, ) -> bool: - """Check for read after write. Allow when indexation on the axis - is not offset.""" + """Check for read after write and write after write with different offsets.""" write_collector = MemletCollector(collect_reads=False) write_collector.visit(first) + other_writes = MemletCollector(collect_reads=False) + other_writes.visit(second) read_collector = MemletCollector(collect_writes=False) read_collector.visit(second) + for write in write_collector.out_memlets: # TODO: this can be optimized to allow non-overlapping intervals and such in the future - if write.subset.dims() <= axis.as_cartesian_index(): + axis_index = axis.as_cartesian_index() + + if write.subset.dims() <= axis_index: # Dimension does not exist continue - previous_axis_index = write.subset[axis.as_cartesian_index()][0] + previous_axis_index = normalize_cartesian_indexation( + write.subset[axis_index][0], axis + ) + + # Write-after-write with an offset case + for other_write in other_writes.out_memlets: + if write.data == other_write.data: + if previous_axis_index != normalize_cartesian_indexation( + other_write.subset[axis_index][0], axis + ): + ndsl_log.debug( + f"[{axis.name} Merge] Found write after write conflict " + f"for {write.data} " + f"w/ different offset to {axis.name} (" + f"first write at {previous_axis_index}, " + f"second write at {other_write.subset[axis_index][0]})" + ) + return False 
+ + # Read-after-write with an offset case for read in read_collector.in_memlets: if write.data == read.data: - if previous_axis_index != read.subset[axis.as_cartesian_index()][0]: + if previous_axis_index != normalize_cartesian_indexation( + read.subset[axis_index][0], axis + ): ndsl_log.debug( f"[{axis.name} Merge] Found read after write conflict " f"for {write.data} " f"w/ different offset to {axis.name} (" - f"write at {write.subset[axis.as_cartesian_index()][0]}, " - f"read at {read.subset[axis.as_cartesian_index()][0]})" + f"write at {write.subset[axis_index][0]}, " + f"read at {read.subset[axis_index][0]})" ) return False - return True - -def no_data_dependencies( - first: stree.MapScope, - second: stree.MapScope, - restrict_check_to_k: bool = False, -) -> bool: - write_collector = MemletCollector(collect_reads=False) - write_collector.visit(first) - read_collector = MemletCollector(collect_writes=False) - read_collector.visit(second) - for write in write_collector.out_memlets: - # Make sure we don't have read after write conditions. 
- # TODO: this can be optimized to allow non-overlapping intervals and such in the future - if restrict_check_to_k: - if write.subset.dims() < 3: - # Case of 2D write - no K dependency - continue - - previous_k_index = write.subset[2][0] - for read in read_collector.in_memlets: - if write.data == read.data: - if previous_k_index != read.subset[2][0]: - print( - "[K Merge] Found read after write conflict " - f"for {write.data} " - "w/ different offset to K (" - f"write at {write.subset[2][0]}, " - f"read at {read.subset[2][0]})" - ) - return False - - else: - if write.data in [read.data for read in read_collector.in_memlets]: - print( - f"[All dims merge] Found potential read after write conflict for {write.data}" - ) - return False return True diff --git a/ndsl/dsl/dace/stree/optimizations/tree_common_op.py b/ndsl/dsl/dace/stree/optimizations/common/topology.py similarity index 65% rename from ndsl/dsl/dace/stree/optimizations/tree_common_op.py rename to ndsl/dsl/dace/stree/optimizations/common/topology.py index 1253ba81..fa06f3db 100644 --- a/ndsl/dsl/dace/stree/optimizations/tree_common_op.py +++ b/ndsl/dsl/dace/stree/optimizations/common/topology.py @@ -3,12 +3,30 @@ import dace.sdfg.analysis.schedule_tree.treenodes as tn +def reparent_scope_node( + original_parent: tn.ScheduleTreeScope, + new_parent: tn.ScheduleTreeScope, + *, + prepend: bool = True, +) -> None: + """Re-parent children between two scope nodes""" + + for child in original_parent.children: + child.parent = new_parent + + if prepend: + new_parent.children = [*original_parent.children, *new_parent.children] + else: + new_parent.children = [*new_parent.children, *original_parent.children] + + def swap_node_position_in_tree( top_node: tn.ScheduleTreeScope, child_node: tn.ScheduleTreeScope ) -> None: """Top node becomes child, child becomes top node.""" # Ensue parent/children relationship is valid tn.validate_children_and_parents_align(top_node) + assert top_node.parent is not None # Take refs before 
swap top_children = top_node.parent.children @@ -51,3 +69,15 @@ def list_index( """Check if node is in list with "is" operator.""" # compare with "is" to get memory comparison. ".index()" uses value comparison return next(index for index, element in enumerate(collection) if element is node) + + +def get_next_node( + nodes: list[tn.ScheduleTreeNode], node: tn.ScheduleTreeNode +) -> tn.ScheduleTreeNode: + """Get next node in the children from given node""" + return nodes[list_index(nodes, node) + 1] + + +def last_node(nodes: list[tn.ScheduleTreeNode], node: tn.ScheduleTreeNode) -> bool: + """Test for last node of list""" + return list_index(nodes, node) >= len(nodes) - 1 diff --git a/ndsl/dsl/dace/stree/optimizations/kernelize_maps.py b/ndsl/dsl/dace/stree/optimizations/kernelize_maps.py new file mode 100644 index 00000000..11135ef6 --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/kernelize_maps.py @@ -0,0 +1,82 @@ +from copy import deepcopy + +from dace.sdfg.analysis.schedule_tree import treenodes as tn + +from ndsl import Backend +from ndsl.config import BackendLoopOrder, BackendTargetDevice +from ndsl.dsl.dace.stree.optimizations.common import ( + AxisIterator, + is_axis_map, + is_cartesian_axis, +) + + +class _KernelizeMap(tn.ScheduleNodeTransformer): + def __init__(self, axis: AxisIterator) -> None: + super().__init__() + self._axis = axis + + def __str__(self) -> str: + return f"KernelizeMap_{self._axis}" + + def _count_cartesian_children(self, node: tn.ScheduleTreeScope) -> int: + cartesian_children = 0 + for child in node.children: + if isinstance(child, (tn.MapScope, tn.ForScope)) and is_cartesian_axis( + child + ): + cartesian_children += 1 + return cartesian_children + + def visit_MapScope(self, node: tn.MapScope) -> tn.MapScope | list[tn.MapScope]: + # if this is a map on a cartesian axis + # and the children contain two or more cartesian axes + if is_axis_map(node, self._axis) and self._count_cartesian_children(node) > 1: + kernelized_maps: 
list[tn.MapScope] = [] + current_children: list[tn.ScheduleTreeNode] = [] + + for child in node.children: + current_children.append(child) + if isinstance(child, (tn.MapScope, tn.ForScope)) and is_cartesian_axis( + child + ): + kernelized_maps.append( + tn.MapScope( + node=deepcopy(node.node), + children=[child for child in current_children], + parent=node.parent, + state=node.state, + ) + ) + current_children = [] + return kernelized_maps + + return self.generic_visit(node) + + +class KernelizeMaps(tn.ScheduleNodeVisitor): + def __init__(self, backend: Backend) -> None: + super().__init__() + self._backend = backend + + if self._backend.device != BackendTargetDevice.GPU: + raise ValueError( + "The transformation `KernelizeMaps` is only intended to run on GPUs." + ) + + def __str__(self) -> str: + return "KernelizeMaps" + + def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: + for axis in self._axis_order(): + _KernelizeMap(axis).visit(node) + + def _axis_order(self) -> list[AxisIterator]: + if self._backend.loop_order == BackendLoopOrder.IJK: + return [AxisIterator._J, AxisIterator._I] + if self._backend.loop_order == BackendLoopOrder.KJI: + return [AxisIterator._J, AxisIterator._K] + + raise NotImplementedError( + f"KernelizeMaps is not configured for loop order {self._backend.loop_order}." + ) diff --git a/ndsl/dsl/dace/stree/optimizations/offgrid_conditionals.py b/ndsl/dsl/dace/stree/optimizations/offgrid_conditionals.py new file mode 100644 index 00000000..de4c21a2 --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/offgrid_conditionals.py @@ -0,0 +1,129 @@ +from dace.sdfg.analysis.schedule_tree import treenodes as tn + +from ndsl import ndsl_log +from ndsl.dsl.dace.stree.optimizations.common import ( + AxisIterator, + get_next_node, + is_axis_map, + last_node, + list_index, +) + + +class InlineOffgridConditionals(tn.ScheduleNodeVisitor): + """ + Push offgrid conditional inside their cartesian block, duplicating the + conditional if needed. 
+ + Turning: + ``` + if a_flag == 0: + map i, j, k: + ... + map i, j, k: + ... + ``` + into + ``` + map i, j, k: + if a_flag == 0: + ... + map i, j, k: + if a_flag == 0: + ... + ``` + """ + + _axis: AxisIterator + + def __init__(self, axis: AxisIterator) -> None: + super().__init__() + self._axis = axis + + def __str__(self) -> str: + return f"InlineOffgridConditionals_{self._axis}" + + def visit_IfScope(self, node: tn.IfScope) -> None: + assert node.parent is not None # just to keep pyright happy + + # For now, skip in case there's an `elif` or `else` following. + if not last_node(node.parent.children, node): + next_node = get_next_node(node.parent.children, node) + if isinstance(next_node, (tn.ElifScope, tn.ElseScope)): + ndsl_log.debug( + "Can't handle conditionals with `elif` and `else` blocks yet :(" + ) + return + + if not all( + [ + isinstance(child, tn.MapScope) and is_axis_map(child, self._axis) + for child in node.children + ] + ): + return + + # If all children are maps over the correct axis, move the if inside. + new_nodes: list[tn.MapScope] = [] + + for child in node.children: + assert isinstance( + child, tn.MapScope + ) # otherwise the condition above is wrong + + if_scope = tn.IfScope( + condition=node.condition, children=child.children, parent=child + ) + + for map_child in child.children: + map_child.parent = if_scope # re-parent to new if_scope + + child.children = [if_scope] + child.parent = node.parent # re-parent to parent of old if_scope + new_nodes.append(child) + + insert_at = list_index(node.parent.children, node) + node.parent.children[insert_at:insert_at] = new_nodes + node.parent.children.remove(node) + + +class ExtractOffgridConditionals(tn.ScheduleNodeTransformer): + """ + Push offgrid conditional outside of their cartesian block. + + This is the inverse transform of InlineOffgridConditionals. 
+ """ + + def __str__(self) -> str: + return "ExtractOffgridConditionals" + + +class MergeConditionals(tn.ScheduleNodeTransformer): + """ + Merge consecutive and equal conditionals. + + Turning: + ``` + if a_flag == 0: + map i, j, k: + ... + if a_flag == 0: + map i, j, k: + ... + ``` + into + ``` + if a_flag == 0: + map i, j, k: + ... + map i, j, k: + ... + ``` + + Outside of user code, combination of ExtractOffgridConditionals, + InlineOffgridConditionals and CartesianMapMerge can lead to this + pattern. + """ + + def __str__(self) -> str: + return "MergeConditionals" diff --git a/ndsl/dsl/dace/stree/optimizations/refine_transients.py b/ndsl/dsl/dace/stree/optimizations/refine_transients.py index e838d913..39d213fc 100644 --- a/ndsl/dsl/dace/stree/optimizations/refine_transients.py +++ b/ndsl/dsl/dace/stree/optimizations/refine_transients.py @@ -1,11 +1,11 @@ import warnings import dace.data -import dace.sdfg.analysis.schedule_tree.treenodes as stree +from dace.sdfg.analysis.schedule_tree import treenodes as tn from ndsl import ndsl_log from ndsl.config import Backend, BackendFramework -from ndsl.dsl.dace.stree.optimizations.memlet_helpers import AxisIterator +from ndsl.dsl.dace.stree.optimizations.common import AxisIterator def _change_index_of_tuple( @@ -34,15 +34,11 @@ def _reduce_cartesian_axis_size_to_1( are atomic""" # Dev Note: Better dataflow analysis would look at exactly - # what's goin on here! + # what's going on here! # Assume 3D cartesian! 
if len(transient_data.shape) < 3: - warnings.warn( - f"Potential non-3D array: {transient_data}, skipping.", - UserWarning, - stacklevel=2, - ) + ndsl_log.debug(f"Potential non-3D array: {transient_data}, skipping.") return False read_write_range: dace.subsets.Range = dace.subsets.union( @@ -78,7 +74,7 @@ def _reduce_cartesian_axis_size_to_1( return True -class CollectTransientRangeAccess(stree.ScheduleNodeVisitor): +class CollectTransientRangeAccess(tn.ScheduleNodeVisitor): """Unionize all transient arrays access into a single Range.""" def __init__(self) -> None: @@ -100,13 +96,14 @@ def __str__(self) -> str: def _find_first_map_or_loop( self, - node: stree.TaskletNode, + node: tn.TaskletNode, axis: AxisIterator, ) -> dace.nodes.MapEntry | None: parent = node.parent while parent is not None: - if isinstance(parent, stree.MapScope): - for p in parent.node.params: + if isinstance(parent, tn.MapScope): + for p in parent.node.map.params: + assert isinstance(p, str) if p.startswith(axis.as_str()): return parent.node @@ -115,8 +112,8 @@ def _find_first_map_or_loop( def _record_access( self, - node: stree.TaskletNode, - memlets: stree.MemletSet, + node: tn.TaskletNode, + memlets: tn.MemletSet, recording_set: dict[str, dace.subsets.Range | None], ) -> None: for memlet in memlets: @@ -149,11 +146,11 @@ def _record_access( AxisIterator._K.as_cartesian_index() ].add(map_entry) - def visit_TaskletNode(self, node: stree.TaskletNode) -> None: + def visit_TaskletNode(self, node: tn.TaskletNode) -> None: self._record_access(node, node.input_memlets(), self.transients_range_writes) self._record_access(node, node.output_memlets(), self.transients_range_reads) - def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: + def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: self.containers = node.containers for name, data in self.containers.items(): if data.transient and isinstance(data, dace.data.Array): @@ -165,7 +162,7 @@ def 
visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: self.visit(child) -class RebuildMemletsFromContainers(stree.ScheduleNodeVisitor): +class RebuildMemletsFromContainers(tn.ScheduleNodeVisitor): """Rebuild memlets from containers to ensure they are scope to the right size.""" def __init__(self, refined_arrays: set[str]) -> None: @@ -174,7 +171,7 @@ def __init__(self, refined_arrays: set[str]) -> None: def __str__(self) -> str: return "RefineTransientAxis" - def visit_TaskletNode(self, node: stree.TaskletNode) -> None: + def visit_TaskletNode(self, node: tn.TaskletNode) -> None: for memlet in [*node.output_memlets(), *node.input_memlets()]: if memlet.data not in self._refined_arrays: continue @@ -191,13 +188,13 @@ def visit_TaskletNode(self, node: stree.TaskletNode) -> None: if array.shape[index] == 1: memlet.subset.ranges[index] = (0, 0, 1) - def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: + def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: self.containers = node.containers for child in node.children: self.visit(child) -class CartesianRefineTransients(stree.ScheduleNodeTransformer): +class CartesianRefineTransients(tn.ScheduleNodeTransformer): """Refine (reduce dimensionality) of transients based on their true use in the cartesian dimensions. @@ -210,7 +207,7 @@ class CartesianRefineTransients(stree.ScheduleNodeTransformer): cartesian axis) it will reduce that axis to 1 if all access are atomic (exactly _one_ element of the array is ever worked on in a single loop) - It will refuse to merge if the transient is used in multiple loops of for - a given axis - irrigardless of it's access pattern (e.g. even if it could be + a given axis - regardless of it's access pattern (e.g. even if it could be refine because it's always written first.) 
def _constant_value(node: ast.AST) -> int | float | None:
    """Evaluate `node` as a constant numeric expression.

    Returns None when the expression is not a plain number, e.g. when it
    refers to a symbol that is only known at runtime.
    """
    try:
        # Explicit, empty namespaces: names of the calling scope must never
        # leak into the evaluated loop bound, and builtins are not needed.
        value = eval(ast.unparse(node), {"__builtins__": {}}, {})
    except Exception:
        # Any failure means "not a compile-time constant": never inline.
        return None

    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return None
    return value


def _single_iteration_value(
    init_statement: ast.AST,
    condition: ast.AST,
    loop_variable: str,
) -> int | float | None:
    """Return the only value `loop_variable` takes, or None.

    None is returned whenever the loop cannot be proven to execute exactly
    once from its init statement and its condition. The update statement is
    not inspected: the step is assumed to move the variable towards (and
    past) the bound by at least one.
    """
    # Expect `loop_variable = <constant>`
    if not (
        isinstance(init_statement, ast.Assign)
        and len(init_statement.targets) == 1
        and isinstance(init_statement.targets[0], ast.Name)
        and init_statement.targets[0].id == loop_variable
    ):
        return None

    # Expect `loop_variable <op> <constant>` with a single comparison
    if isinstance(condition, ast.Expr):
        condition = condition.value
    if not (
        isinstance(condition, ast.Compare)
        and len(condition.ops) == 1
        and isinstance(condition.left, ast.Name)
        and condition.left.id == loop_variable
    ):
        return None

    init_value = _constant_value(init_statement.value)
    bound_value = _constant_value(condition.comparators[0])
    if init_value is None or bound_value is None:
        return None

    # The direction of the comparison matters: `k = 1; k < 0` never runs and
    # `k = 0; k <= 1` runs twice, although both bounds are "one apart".
    operator = condition.ops[0]
    if isinstance(operator, ast.Lt):
        runs_once = bound_value - init_value == 1
    elif isinstance(operator, ast.Gt):
        runs_once = init_value - bound_value == 1
    elif isinstance(operator, (ast.LtE, ast.GtE)):
        runs_once = bound_value == init_value
    else:
        runs_once = False

    return init_value if runs_once else None


class InlineVertical2DWrite(tn.ScheduleNodeVisitor):
    """Inline K index value for 2D write vertical while removing for loop.

    Transforming:
    ```
    for __k = 0; __k < 1; __k = __k + 1:
        map __j, __i:
            field[__i, __j] = tasklet(field_in[__i, __j, __k])
    ```

    Into
    ```
    map __j, __i:
        field[__i, __j] = tasklet(field_in[__i, __j, 0])
    ```

    Only K-axis for-loops whose init value and bound are compile-time
    constants describing exactly one iteration are inlined. Every other
    for-loop is left untouched.
    """

    def __init__(self) -> None:
        super().__init__()
        self._for_scopes_removed = 0

    def __str__(self) -> str:
        return "InlineVertical2DWrite"

    def visit_ForScope(self, the_for: tn.ForScope) -> None:
        """Replace a single-iteration K loop by its (specialized) children."""
        if not is_axis_for(the_for, AxisIterator._K):
            return

        parent = the_for.parent
        assert parent is not None  # just to keep pyright happy

        loop = the_for.loop
        try:
            init_statement = loop.init_statement.code[0]
            condition = loop.loop_condition.code[0]
        except (AttributeError, IndexError, TypeError):
            # No plain init statement / condition to reason about: keep loop
            return

        init_value = _single_iteration_value(
            init_statement, condition, loop.loop_variable
        )
        if init_value is None:
            return

        # Sanity check _before_ the tree is modified
        children = list(the_for.children)
        assert len(children) > 0

        ReplaceAxisSymbol({dace.symbol(loop.loop_variable): str(init_value)}).visit(
            the_for
        )

        # Swap the ForScope for its children in a single slice assignment.
        # The node is located once via `list_index`; no equality-based
        # `list.remove` is needed afterwards.
        insert_at = list_index(parent.children, the_for)
        for child in children:
            child.parent = parent
        parent.children[insert_at : insert_at + 1] = children
        self._for_scopes_removed += 1

        # The hoisted nodes may contain further K loops to inline
        for child in children:
            self.visit(child)

    def visit_ScheduleTreeRoot(self, the_root: tn.ScheduleTreeRoot) -> None:
        """Run the pass on a whole tree and log how many loops were inlined."""
        self._for_scopes_removed = 0

        # Iterate over a snapshot: `visit_ForScope` rewrites `children`
        for child in list(the_root.children):
            self.visit(child)

        ndsl_log.debug(f"🚀 {self}: {self._for_scopes_removed} inlined")
class ReplaceAxisSymbol(tn.ScheduleNodeVisitor):
    """Replace axis symbols in memlets, masklet code and if-conditions.

    `axis_replacements` maps the symbol (or symbol name) to replace to its
    replacement. Memlets are updated through `Memlet.replace`. The code of
    tasklets whose label starts with "masklet" and the conditions of if
    scopes are rewritten textually, matching whole identifiers only.
    """

    def __init__(self, axis_replacements: dict[str | symbol, str | symbol]) -> None:
        super().__init__()
        self._axis_replacements = axis_replacements

    def _replace_in_code(self, code: str) -> str:
        """Return `code` with every replacement applied to whole identifiers."""
        import re  # local import: `re` is not part of this module's imports

        for old, new in self._axis_replacements.items():
            replacement = str(new)
            # Lookarounds keep `__k` from matching inside `__k_1234` or
            # `foo__k`, which a plain `str.replace` would silently corrupt.
            # A callable replacement keeps backslashes in `new` literal.
            code = re.sub(
                rf"(?<!\w){re.escape(str(old))}(?!\w)",
                lambda _match, _text=replacement: _text,
                code,
            )
        return code

    def visit_TaskletNode(self, node: tn.TaskletNode) -> None:
        for memlet in itertools.chain(
            node.in_memlets.values(), node.out_memlets.values()
        ):
            memlet.replace(self._axis_replacements)

        if node.node.label.startswith("masklet"):
            original = node.node.code.as_string
            updated = self._replace_in_code(original)
            if updated != original:
                node.node.code.as_string = updated

    def visit_IfScope(self, node: tn.IfScope) -> None:
        original = node.condition.as_string
        updated = self._replace_in_code(original)
        if updated != original:
            node.condition.as_string = updated

        for child in node.children:
            self.visit(child)

    def __str__(self) -> str:
        return "ReplaceAxisSymbol"
class CountCartesianLoops(tn.ScheduleNodeVisitor):
    """Count the maps and for-loops iterating over each cartesian axis.

    Counters are indexed with `AxisIterator.as_cartesian_index()`.
    """

    def __init__(self) -> None:
        super().__init__()
        self._maps = [0, 0, 0]
        self._fors = [0, 0, 0]

    def visit_MapScope(self, node: tn.MapScope) -> None:
        for axis in AxisIterator:
            if is_axis_map(node, axis):
                self._maps[axis.as_cartesian_index()] += 1

        self.visit(node.children)

    def visit_ForScope(self, node: tn.ForScope) -> None:
        for axis in AxisIterator:
            if is_axis_for(node, axis):
                self._fors[axis.as_cartesian_index()] += 1

        self.visit(node.children)


class CountTransient(tn.ScheduleNodeVisitor):
    """Count transient arrays by their number of non-unit dimensions.

    `_counts[n]` holds the number of transient arrays with `n` dimensions of
    extent other than 1; the last bucket gathers everything with 4 or more.
    """

    def __init__(self) -> None:
        super().__init__()
        self._counts = [0, 0, 0, 0, 0]

    def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None:
        last_bucket = len(self._counts) - 1
        for data in node.containers.values():
            # Filter first: only transient arrays are inspected for a shape
            if not (isinstance(data, dace.data.Array) and data.transient):
                continue

            non_atomic_dims_count = sum(1 for extent in data.shape if extent != 1)
            self._counts[min(non_atomic_dims_count, last_bucket)] += 1


class TreeOptimizationStatistics:
    """Capture basic statistics on the schedule tree optimization actions"""

    @dataclasses.dataclass
    class Record:
        """Private record of a state of a tree"""

        cartesian_maps: list[int] = dataclasses.field(default_factory=lambda: [0, 0, 0])
        cartesian_fors: list[int] = dataclasses.field(default_factory=lambda: [0, 0, 0])
        transients: list[int] = dataclasses.field(
            default_factory=lambda: [0, 0, 0, 0, 0]
        )

    def __init__(self) -> None:
        self._original_record = TreeOptimizationStatistics.Record()
        self._optimized_record = TreeOptimizationStatistics.Record()

    def _record(
        self,
        record: Record,
        tree_root: tn.ScheduleTreeRoot,
    ) -> None:
        """Record the state of a tree"""
        loop_counter = CountCartesianLoops()
        loop_counter.visit(tree_root)
        record.cartesian_fors = loop_counter._fors
        record.cartesian_maps = loop_counter._maps

        transient_counter = CountTransient()
        transient_counter.visit(tree_root)
        record.transients = transient_counter._counts

    def original(self, tree_root: tn.ScheduleTreeRoot) -> None:
        """Record the original state of the tree, before optimization"""
        self._record(self._original_record, tree_root)

    def optimized(self, tree_root: tn.ScheduleTreeRoot) -> None:
        """Record the state of the tree after optimization"""
        self._record(self._optimized_record, tree_root)

    def report(self) -> str:
        """Craft a concise string reporting on the statistics"""
        msg = "Tree optimization:\n"
        msg += f"  Cartesian maps [I, J, K]: {self._original_record.cartesian_maps} -> {self._optimized_record.cartesian_maps}\n"
        msg += f"  Cartesian fors [I, J, K]: {self._original_record.cartesian_fors} -> {self._optimized_record.cartesian_fors}\n"
        msg += f"  Transients [Scalarized Array, 1D, 2D, 3D, 4D+]: {self._original_record.transients} -> {self._optimized_record.transients}\n"
        return msg
= passes + self.config = config def __hash__(self) -> int: return hash(repr(self)) @@ -27,10 +35,14 @@ def __repr__(self) -> str: def run( self, - stree: stree.ScheduleTreeRoot, + stree: tn.ScheduleTreeRoot, verbose: bool = False, - ) -> stree.ScheduleTreeRoot: + ) -> tn.ScheduleTreeRoot: + tree_stats = TreeOptimizationStatistics() + tree_stats.original(stree) + for i, p in enumerate(self.passes): + path: Path | None = None if verbose: path = self.cache_directory / f"pass{i}_{p}.txt" ndsl_log_on_rank_0.info(f"[Stree OPT] {p} (saving {path} after)") @@ -38,23 +50,40 @@ def run( p.visit(stree) if verbose: + assert path is not None with open(path, "w+") as f: f.write(stree.as_string()) + tree_stats.optimized(stree) + + if verbose: + ndsl_log_on_rank_0.info(tree_stats.report()) + return stree class CPUPipeline(StreePipeline): def __init__( self, + config: OptimizationConfig, + backend: Backend, *, - passes: list[stree.ScheduleNodeTransformer] | None = None, + passes: list[tn.ScheduleNodeVisitor] | None = None, cache_directory: Path | None = None, ) -> None: + if passes is None: + ppl_passes = [ + CleanUpScheduleTree(), + # TODO: Is it safe? Deactivate for now + # InlineVertical2DWrite(), + CartesianMerge(backend, overcompute=config.stree.merger.overcompute), + CartesianRefineTransients(backend), + ] + else: + ppl_passes = passes super().__init__( - passes=( - passes if passes is not None else [CartesianAxisMerge(AxisIterator._K)] - ), + config=config, + passes=ppl_passes, cache_directory=cache_directory, ) @@ -62,10 +91,27 @@ def __init__( class GPUPipeline(StreePipeline): def __init__( self, - passes: list[stree.ScheduleNodeTransformer] | None = None, + config: OptimizationConfig, + backend: Backend, + *, + passes: list[tn.ScheduleNodeVisitor] | None = None, cache_directory: Path | None = None, ) -> None: + if passes is None: + ppl_passes = [ + CleanUpScheduleTree(), + # TODO: Is it safe? 
Deactivate for now + # InlineVertical2DWrite(), + CartesianMerge(backend, overcompute=config.stree.merger.overcompute), + KernelizeMaps(backend), + # 🐞 Transient refine can't be used + # because of bugs transients showing in code generation + # CartesianRefineTransients(backend), + ] + else: + ppl_passes = passes super().__init__( - passes=passes if passes is not None else [], + config=config, + passes=ppl_passes, cache_directory=cache_directory, ) diff --git a/ndsl/dsl/ndsl_runtime.py b/ndsl/dsl/ndsl_runtime.py index a994c61a..294f5711 100644 --- a/ndsl/dsl/ndsl_runtime.py +++ b/ndsl/dsl/ndsl_runtime.py @@ -5,6 +5,7 @@ from collections.abc import Callable from typing import Any, Sequence +from ndsl import OptimizationConfig from ndsl.dsl.dace.orchestration import orchestrate from ndsl.dsl.stencil import StencilFactory from ndsl.dsl.typing import Float @@ -21,10 +22,22 @@ class NDSLRuntime: The __call__ function will automatically be orchestrated.""" - def __init__(self, stencil_factory: StencilFactory) -> None: + def __init__( + self, + stencil_factory: StencilFactory, + optimization_config: OptimizationConfig | None = None, + ) -> None: self._stencil_factory = stencil_factory # Use this flag to detect that the init wasn't done properly self._base_class_was_properly_super_init = True + if optimization_config is None: + # TODO + # - Decide where to put defaults. + # - For now, they are in the OptimizationConfig object itself. + # - We could have specialized defaults here for NDSLRuntime code. 
def _env_flag(name: str, default: str) -> bool:
    """Return True when environment variable `name` equals "true".

    The comparison is case-insensitive; `default` is used when the variable
    is not set.
    """
    return os.getenv(name, default).lower() == "true"


@dataclass
class OptimizationConfig:
    """Optimization switches used when orchestrating code.

    Defaults driven by environment variables are resolved every time a
    configuration is instantiated (not once at import time), so variables set
    after `ndsl` was imported are honored.
    """

    @dataclass
    class Tree:
        """Optimization using the Schedule Tree IR"""

        @dataclass
        class Merger:
            overcompute: bool = field(
                default_factory=lambda: _env_flag(
                    "NDSL_STREE_OVERCOMPUTE_MERGE", "True"
                )
            )
            """When merging, allow maps of different sizes to merge by inserting an
            if guard. Defaults to the value of `NDSL_STREE_OVERCOMPUTE_MERGE`
            (True if unset)."""

        enabled: bool = field(
            default_factory=lambda: _env_flag("NDSL_STREE_OPT", "False")
        )
        """Enable Schedule Tree transformations. Defaults to the value of
        `NDSL_STREE_OPT` (False if unset)."""
        merger: Merger = field(default_factory=Merger)

    @dataclass
    class GPU:
        """Optimization dedicated for GPU"""

        common_gpu_xforms: bool = False
        """DaCe common xforms bundled in `apply_gpu_transformations`"""

    stree: Tree = field(default_factory=Tree)
    gpu: GPU = field(default_factory=GPU)
_domain_from_dims(self, dimensions: Iterable[str]) -> list[int]: for dimension in dimensions: if dimension == I_DIM: result.append(self.domain[0]) - if dimension == I_INTERFACE_DIM: + elif dimension == I_INTERFACE_DIM: result.append(self.domain[0] + 1) - if dimension == J_DIM: + elif dimension == J_DIM: result.append(self.domain[1]) - if dimension == J_INTERFACE_DIM: + elif dimension == J_INTERFACE_DIM: result.append(self.domain[1] + 1) - if dimension == K_DIM: + elif dimension == K_DIM: result.append(self.domain[2]) - if dimension == K_INTERFACE_DIM: + elif dimension == K_INTERFACE_DIM: result.append(self.domain[2] + 1) + else: + raise ValueError(f"Unknown dimension '{dimension}'.") return result def get_shape( diff --git a/ndsl/initialization/subtile_grid_sizer.py b/ndsl/initialization/subtile_grid_sizer.py index 4a257080..c923afee 100644 --- a/ndsl/initialization/subtile_grid_sizer.py +++ b/ndsl/initialization/subtile_grid_sizer.py @@ -17,11 +17,18 @@ def __init__( n_halo: int, data_dimensions: dict[str, int], backend: Backend, + *, + pad_non_interface_dimensions: bool = False, ) -> None: super().__init__(nx, ny, nz, n_halo, data_dimensions) fortran_style_memory = backend.is_fortran_aligned() - self._pad_non_interface_dimensions = not fortran_style_memory + + # TODO: pad_non_interface_dimensions should not be kept. In general + # this should _always_ be False and non-interface dimensions never padded by default + self._pad_non_interface_dimensions = ( + not fortran_style_memory or pad_non_interface_dimensions + ) @classmethod def from_tile_params( @@ -36,6 +43,7 @@ def from_tile_params( data_dimensions: dict[str, int] | None = None, tile_partitioner: TilePartitioner | None = None, tile_rank: int = 0, + pad_non_interface_dimensions: bool = False, ) -> Self: """Create a SubtileGridSizer from parameters about the full tile. 
@@ -76,7 +84,15 @@ def from_tile_params( "SubtileGridSizer::from_tile_params: Compute domain extent must be greater than halo size" ) - return cls(nx, ny, nz, n_halo, data_dimensions, backend) + return cls( + nx, + ny, + nz, + n_halo, + data_dimensions, + backend, + pad_non_interface_dimensions=pad_non_interface_dimensions, + ) @classmethod def from_namelist( diff --git a/ndsl/quantity/local.py b/ndsl/quantity/local.py index f69480a7..37aee3eb 100644 --- a/ndsl/quantity/local.py +++ b/ndsl/quantity/local.py @@ -31,6 +31,7 @@ def __init__( # Initialize memory to obviously wrong value - Local should _not_ be expected # to be zero'ed. data[:] = 123456789 + self._on_gpu = backend.is_gpu_backend() super().__init__( data, @@ -45,5 +46,5 @@ def __init__( def __descriptor__(self) -> Any: """Locals uses `Quantity.__descriptor__` and flag itself as transient.""" data = dace.data.create_datadescriptor(self._data) - data.transient = True + data.transient = True if not self._on_gpu else False return data diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 0624a8c0..5d310674 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -287,7 +287,7 @@ def field(self) -> np.ndarray | cupy.ndarray: def data(self) -> np.ndarray | cupy.ndarray: """The underlying array of data""" warnings.warn( - "Quantity.data accessor is now deprecated. Use a slicing operation directly on" + "Quantity.data accessor is now deprecated. Use a slicing operation directly on " "the quantity, e.g. 
`my_quantity[:]` instead of `my_quantity.data[:]`", category=UserWarning, stacklevel=2, diff --git a/ndsl/stencils/testing/conftest.py b/ndsl/stencils/testing/conftest.py index 6e5b17af..652e132a 100644 --- a/ndsl/stencils/testing/conftest.py +++ b/ndsl/stencils/testing/conftest.py @@ -105,6 +105,12 @@ def pytest_addoption(parser: pytest.Parser) -> None: default=False, help="Do not generate logging report or NetCDF in .translate-errors", ) + parser.addoption( + "--pad_non_interface_dimensions", + action="store_true", + default=False, + help="Pad the non interface dimensions in all backends. Default to False.", + ) def pytest_configure(config: pytest.Config) -> None: @@ -255,6 +261,9 @@ def _sequential_savepoint_cases( topology_mode = metafunc.config.getoption("topology") sort_report = metafunc.config.getoption("sort_report") no_report = metafunc.config.getoption("no_report") + pad_non_interface_dimensions = metafunc.config.getoption( + "pad_non_interface_dimensions" + ) return _savepoint_cases( savepoint_names, @@ -268,6 +277,7 @@ def _sequential_savepoint_cases( topology_mode, sort_report=sort_report, no_report=no_report, + pad_non_interface_dimensions=pad_non_interface_dimensions, ) @@ -283,6 +293,7 @@ def _savepoint_cases( topology_mode: str, sort_report: str, no_report: bool, + pad_non_interface_dimensions: bool, ) -> list[SavepointCase]: grid_params = grid_params_from_f90nml(namelist) return_list = [] @@ -305,6 +316,7 @@ def _savepoint_cases( rank=rank, layout=grid_params["layout"], backend=backend, + pad_non_interface_dimensions=pad_non_interface_dimensions, ).python_grid() if grid_mode == "compute": _compute_grid_data( @@ -377,6 +389,9 @@ def _parallel_savepoint_cases( savepoint_names = _parallel_savepoint_names(metafunc, data_path) grid_mode = metafunc.config.getoption("grid") savepoint_to_replay = _get_savepoint_restriction(metafunc) + pad_non_interface_dimensions = metafunc.config.getoption( + "pad_non_interface_dimensions" + ) return _savepoint_cases( 
savepoint_names, @@ -390,6 +405,7 @@ def _parallel_savepoint_cases( topology_mode, sort_report=sort_report, no_report=no_report, + pad_non_interface_dimensions=pad_non_interface_dimensions, ) diff --git a/ndsl/stencils/testing/grid.py b/ndsl/stencils/testing/grid.py index 3af290e4..db24fd13 100644 --- a/ndsl/stencils/testing/grid.py +++ b/ndsl/stencils/testing/grid.py @@ -60,6 +60,7 @@ def _make( layout: tuple[int, int], rank: int, backend: Backend, + pad_non_interface_dimensions: bool = False, ) -> "Grid": shape_params = { "npx": npx, @@ -81,7 +82,15 @@ def _make( "js": N_HALO_DEFAULT, "je": ny + N_HALO_DEFAULT - 1, } - return cls(indices, shape_params, rank, layout, backend, local_indices=True) + return cls( + indices, + shape_params, + rank, + layout, + backend, + local_indices=True, + pad_non_interface_dimensions=pad_non_interface_dimensions, + ) @classmethod def from_namelist(cls, namelist: Namelist, rank: int, backend: Backend) -> "Grid": @@ -112,6 +121,7 @@ def __init__( backend: Backend, data_fields: dict | None = None, local_indices: bool = False, + pad_non_interface_dimensions: bool = False, ) -> None: if data_fields is None: data_fields = {} @@ -162,6 +172,7 @@ def __init__( self._grid_data: GridData | None = None self._driver_grid_data: DriverGridData | None = None self._damping_coefficients: DampingCoefficients | None = None + self._pad_non_interface_dimensions = pad_non_interface_dimensions @property def sizer(self) -> GridSizer: @@ -180,6 +191,7 @@ def sizer(self) -> GridSizer: }, layout=self.layout, backend=self.backend, + pad_non_interface_dimensions=self._pad_non_interface_dimensions, ) return self._sizer diff --git a/ndsl/stencils/testing/test_translate.py b/ndsl/stencils/testing/test_translate.py index 2bc0d4fc..66db8063 100644 --- a/ndsl/stencils/testing/test_translate.py +++ b/ndsl/stencils/testing/test_translate.py @@ -466,7 +466,7 @@ def _report_results( os.makedirs(detail_dir, exist_ok=True) # Summary - header = f"{savepoint_name} w/ 
f{backend.as_humanly_readable()}" + header = f"{savepoint_name} w/ {backend.as_humanly_readable()}" lines = [] for varname, metric in results.items(): lines.append(f"{varname}: {metric.one_line_report()}") diff --git a/ndsl/stencils/testing/translate.py b/ndsl/stencils/testing/translate.py index 29afc577..4011a1c0 100644 --- a/ndsl/stencils/testing/translate.py +++ b/ndsl/stencils/testing/translate.py @@ -68,10 +68,7 @@ def __init__( self.ordered_input_vars = None self.ignore_near_zero_errors: dict[str, Any] = {} self.skip_test = skip_test - if self.stencil_factory.backend.is_fortran_aligned(): - self.maxshape = self.grid.domain_shape_full() - else: - self.maxshape = self.grid.domain_shape_full(add=(1, 1, 1)) + self.maxshape = self.grid.domain_shape_full(add=(1, 1, 1)) def extra_data_load(self, data_loader: DataLoader): pass @@ -322,7 +319,15 @@ def new_from_serialized_data(cls, serializer, rank, layout, backend: Backend): grid_data[field] = read_serialized_data(serializer, grid_savepoint, field) return cls(grid_data, rank, layout, backend=backend) - def __init__(self, inputs, rank, layout, *, backend: Backend): + def __init__( + self, + inputs, + rank, + layout, + *, + backend: Backend, + pad_non_interface_dimensions: bool = False, + ): self.backend = backend self.indices = {} self.shape_params = {} @@ -338,6 +343,7 @@ def __init__(self, inputs, rank, layout, *, backend: Backend): del inputs[index] self.data = inputs + self._pad_non_interface_dimensions = pad_non_interface_dimensions def _make_composite_var_storage(self, varname, data3d, shape, count): for s in range(count): @@ -444,7 +450,12 @@ def make_grid_storage(self, pygrid): def python_grid(self): pygrid = Grid( - self.indices, self.shape_params, self.rank, self.layout, self.backend + self.indices, + self.shape_params, + self.rank, + self.layout, + self.backend, + pad_non_interface_dimensions=self._pad_non_interface_dimensions, ) self.make_grid_storage(pygrid) pygrid.add_data(self.data) diff --git 
a/ndsl/testing/comparison.py b/ndsl/testing/comparison.py index 3acfd723..e7fc93eb 100644 --- a/ndsl/testing/comparison.py +++ b/ndsl/testing/comparison.py @@ -339,7 +339,7 @@ def one_line_report(self) -> str: return f"❌ Numerical failures: {failed_indices}/{all_indices} failed - metric: {metric_thresholds}" def report(self, file_path: str | None = None) -> list[str]: - failed_indices = np.logical_not(self.success).nonzero() + failed_indices = np.atleast_1d(np.logical_not(self.success)).nonzero() # List all errors to terminal and file bad_indices_count = len(failed_indices[0]) if self.changing_column_map is not None: diff --git a/tests/dsl/dace/stree/__init__.py b/tests/dsl/dace/stree/__init__.py index 2fa38d13..b43c1d92 100644 --- a/tests/dsl/dace/stree/__init__.py +++ b/tests/dsl/dace/stree/__init__.py @@ -1,7 +1,7 @@ -from .sdfg_stree_tools import StreeOptimization, get_SDFG_and_purge +from .sdfg_stree_tools import StreePipeline, get_SDFG_and_purge __all__ = [ - "StreeOptimization", + "StreePipeline", "get_SDFG_and_purge", ] diff --git a/tests/dsl/dace/stree/common/__init__.py b/tests/dsl/dace/stree/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/dsl/dace/stree/common/test_loops.py b/tests/dsl/dace/stree/common/test_loops.py new file mode 100644 index 00000000..a2c12d76 --- /dev/null +++ b/tests/dsl/dace/stree/common/test_loops.py @@ -0,0 +1,96 @@ +from dace.sdfg import nodes +from dace.sdfg.analysis.schedule_tree import treenodes as tn +from dace.sdfg.state import LoopRegion + +from ndsl.dsl.dace.stree.optimizations.common import ( + AxisIterator, + is_axis_for, + is_axis_map, + is_cartesian_axis, +) + + +def test_is_axis_map_multiple_params() -> None: + node = tn.MapScope( + node=nodes.MapEntry( + nodes.Map("map_ij", ["__i", "__j"], [(0, 3, 1), (0, 4, 1)]) + ), + children=[], + ) + assert not is_axis_map(node, AxisIterator._I) + assert not is_axis_map(node, AxisIterator._J) + + +def test_is_axis_map_I() -> None: + node = 
tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_i", ["__i"], [(0, 3, 1)])), children=[] + ) + assert is_axis_map(node, AxisIterator._I) + + +def test_is_axis_map_not_I() -> None: + node = tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_other_i", ["__i0"], [(0, 3, 1)])), + children=[], + ) + assert not is_axis_map(node, AxisIterator._I) + + +def test_is_axis_map_K() -> None: + node = tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_k", ["__k_1234"], [(0, 3, 1)])), children=[] + ) + assert is_axis_map(node, AxisIterator._K) + + +def test_is_axis_map_wrong_iterator() -> None: + node = tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_i", ["__i"], [(0, 3, 1)])), children=[] + ) + assert not is_axis_map(node, AxisIterator._J) + + +def test_is_cartesian_axis() -> None: + map_i = tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_i", ["__i"], [(0, 3, 1)])), children=[] + ) + assert is_cartesian_axis(map_i) + + map_j = tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_j", ["__j"], [(0, 3, 1)])), children=[] + ) + assert is_cartesian_axis(map_j) + + map_k = tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_k", ["__k_1234"], [(0, 3, 1)])), children=[] + ) + assert is_cartesian_axis(map_k) + + for_k = tn.ForScope(loop=LoopRegion("for_k", loop_var="__k_1234"), children=[]) + assert is_cartesian_axis(for_k) + + map_non_cartesian = tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_other_i", ["__i0"], [(0, 3, 1)])), + children=[], + ) + assert not is_cartesian_axis(map_non_cartesian) + + +def test_is_axis_for_k() -> None: + node = tn.ForScope(loop=LoopRegion("for_k", loop_var="__k_1234"), children=[]) + assert is_axis_for(node, AxisIterator._K) + + +def test_is_axis_for_wrong_iterator() -> None: + node = tn.ForScope(loop=LoopRegion("for_k", loop_var="__k_1234"), children=[]) + assert not is_axis_for(node, AxisIterator._I) + + +def test_is_axis_for_i() -> None: + node = tn.ForScope(loop=LoopRegion("for_i", loop_var="__i"), children=[]) + assert is_axis_for(node, 
AxisIterator._I) + + +def test_is_axis_for_not_i() -> None: + node = tn.ForScope(loop=LoopRegion("for_i", loop_var="__i0"), children=[]) + assert not is_axis_for(node, AxisIterator._I) diff --git a/tests/dsl/dace/stree/common/test_memlet.py b/tests/dsl/dace/stree/common/test_memlet.py new file mode 100644 index 00000000..44fe15e1 --- /dev/null +++ b/tests/dsl/dace/stree/common/test_memlet.py @@ -0,0 +1,32 @@ +from dace.symbolic import symbol + +from ndsl.dsl.dace.stree.optimizations.common import AxisIterator +from ndsl.dsl.dace.stree.optimizations.common.memlet import ( + normalize_cartesian_indexation, +) + + +def test_normalize_cartesian_index(): + # Case of __k_id(node) - original case + original_symbol = symbol("__k_12345678789") + norm_symbol = normalize_cartesian_indexation(original_symbol, AxisIterator._K) + + assert norm_symbol == symbol("__k") + + # Case of offset + original_symbol = 1 + symbol("__k_12345678789") + norm_symbol = normalize_cartesian_indexation(original_symbol, AxisIterator._K) + + assert norm_symbol == symbol("__k") + 1 + + # Case of no-op (with offset) + original_symbol = 1 + symbol("__k") + norm_symbol = normalize_cartesian_indexation(original_symbol, AxisIterator._K) + + assert norm_symbol == symbol("__k") + 1 + + # Case of index named with _k - so not a cartesian axis + original_symbol = 1 + symbol("_kindex") + norm_symbol = normalize_cartesian_indexation(original_symbol, AxisIterator._K) + + assert norm_symbol == symbol("_kindex") + 1 diff --git a/tests/dsl/dace/stree/optimizations/__init__.py b/tests/dsl/dace/stree/optimizations/__init__.py index e69de29b..e0e56d60 100644 --- a/tests/dsl/dace/stree/optimizations/__init__.py +++ b/tests/dsl/dace/stree/optimizations/__init__.py @@ -0,0 +1,6 @@ +from typing import TypeAlias + +from ndsl import QuantityFactory, StencilFactory + + +Factories: TypeAlias = tuple[StencilFactory, QuantityFactory] diff --git a/tests/dsl/dace/stree/optimizations/test_kernelize_maps.py 
b/tests/dsl/dace/stree/optimizations/test_kernelize_maps.py new file mode 100644 index 00000000..3343b3eb --- /dev/null +++ b/tests/dsl/dace/stree/optimizations/test_kernelize_maps.py @@ -0,0 +1,182 @@ +import pytest +from dace import nodes +from dace.sdfg.state import LoopRegion + +from ndsl import Backend, NDSLRuntime, OptimizationConfig, orchestrate +from ndsl.boilerplate import get_factories_single_tile +from ndsl.constants import I_DIM, J_DIM, K_DIM +from ndsl.dsl.gt4py import BACKWARD, FORWARD, PARALLEL, computation, interval +from ndsl.dsl.stencil import StencilFactory +from ndsl.dsl.typing import FloatField +from tests.dsl.dace.stree import get_SDFG_and_purge +from tests.dsl.dace.stree.optimizations import Factories + + +def stencil_kernelize(in_field: FloatField, out_field: FloatField) -> None: # type: ignore + with computation(PARALLEL), interval(...): + value = in_field * 2 + tmp = value + + with computation(FORWARD), interval(0, -1): + tmp = 0.5 * (tmp + tmp[0, 0, 1]) + + with computation(PARALLEL), interval(...): + out_field = tmp + + +def stencil_only_serial_noop( + in_field: FloatField, out_field: FloatField +) -> None: # type:ignore + with computation(FORWARD), interval(...): + tmp = in_field + + with computation(BACKWARD), interval(...): + out_field = tmp + + +def stencil_only_parallel_noop( + in_field: FloatField, out_field: FloatField +) -> None: # type:ignore + with computation(PARALLEL), interval(0, 2): + out_field = in_field + + with computation(PARALLEL), interval(-2, None): + out_field = in_field + 1 + + +class OrchestratedCode(NDSLRuntime): + def __init__(self, stencil_factory: StencilFactory) -> None: + optimization_config = OptimizationConfig(OptimizationConfig.Tree(enabled=True)) + super().__init__(stencil_factory, optimization_config) + + methods_to_orchestrate = [ + "kernelize_k", + "only_serial_noop", + "only_parallel_noop", + ] + for method in methods_to_orchestrate: + orchestrate( + obj=self, + 
config=stencil_factory.config.dace_config, + method_to_orchestrate=method, + optimization_config=optimization_config, + ) + + self._stencil_kernelize_k = stencil_factory.from_dims_halo( + func=stencil_kernelize, + compute_dims=(I_DIM, J_DIM, K_DIM), + ) + self._stencil_only_serial_noop = stencil_factory.from_dims_halo( + func=stencil_only_serial_noop, + compute_dims=(I_DIM, J_DIM, K_DIM), + ) + self._stencil_only_parallel_noop = stencil_factory.from_dims_halo( + func=stencil_only_parallel_noop, + compute_dims=(I_DIM, J_DIM, K_DIM), + ) + + def kernelize_k(self, in_field: FloatField, out_field: FloatField) -> None: # type: ignore + self._stencil_kernelize_k(in_field, out_field) + + def only_serial_noop(self, in_field: FloatField, out_field: FloatField) -> None: # type: ignore + self._stencil_only_serial_noop(in_field, out_field) + + def only_parallel_noop(self, in_field: FloatField, out_field: FloatField) -> None: # type: ignore + self._stencil_only_parallel_noop(in_field, out_field) + + +class TestKernelizeMaps: + @pytest.fixture( + params=[ + "orch:dace:cpu:IJK", + pytest.param("orch:dace:gpu:IJK", marks=pytest.mark.gpu), + ] + ) + def factories(self, request: pytest.FixtureRequest) -> Factories: + domain = (3, 4, 5) + return get_factories_single_tile( + nx=domain[0], + ny=domain[1], + nz=domain[2], + nhalo=0, + backend=Backend(request.param), + ) + + def test_kernelize_k_gpu(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + code = OrchestratedCode(stencil_factory) + + in_field = quantity_factory.ones((I_DIM, J_DIM, K_DIM), "") + out_field = quantity_factory.zeros((I_DIM, J_DIM, K_DIM), "") + + code.kernelize_k(in_field, out_field) + + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + + if stencil_factory.backend.is_gpu_backend(): + # check for kernelization + all_maps = [ + node + for node, _ in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(node, nodes.MapEntry) + ] + + ij_maps = 0 + ijk_maps = 0 + for 
map_entry in all_maps: + if map_entry.map.params == ["__i", "__j"]: + ij_maps += 1 + elif len(map_entry.map.params) == 3: + params = map_entry.map.params + k_param = params[2] + if ( + params[0:2] == ["__i", "__j"] + and isinstance(k_param, str) + and k_param.startswith("__k") + ): + ijk_maps += 1 + + # expect two IJK-maps and one IJ-map + assert ij_maps == 1 + assert ijk_maps == 2 + assert len(all_maps) == 3 + + all_loop_regions = [ + node + for node, _ in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(node, LoopRegion) + ] + # expect one k-loop is preserved + assert len(all_loop_regions) == 1 + assert all_loop_regions[0].loop_variable.startswith("__k") + else: + # check that we keep IJ loops merged + all_maps = [ + node + for node, _ in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(node, nodes.MapEntry) + ] + + ij_maps = 0 + k_maps = 0 + for map_entry in all_maps: + if map_entry.map.params == ["__i", "__j"]: + ij_maps += 1 + elif len(map_entry.map.params) == 1: + param = map_entry.map.params[0] + if isinstance(param, str) and param.startswith("__k"): + k_maps += 1 + + # expect one IJ-map and two K-maps + assert ij_maps == 1 + assert k_maps == 2 + assert len(all_maps) == 3 + + all_loop_regions = [ + node + for node, _ in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(node, LoopRegion) + ] + # expect one k-loop is preserved + assert len(all_loop_regions) == 1 + assert all_loop_regions[0].loop_variable.startswith("__k") diff --git a/tests/dsl/dace/stree/optimizations/test_merge.py b/tests/dsl/dace/stree/optimizations/test_merge.py index d57e758a..d20e3c4c 100644 --- a/tests/dsl/dace/stree/optimizations/test_merge.py +++ b/tests/dsl/dace/stree/optimizations/test_merge.py @@ -1,18 +1,18 @@ -from typing import TypeAlias - import dace import pytest from dace import nodes from dace.sdfg.analysis.schedule_tree import treenodes as tn from dace.sdfg.state import LoopRegion -from ndsl import QuantityFactory, StencilFactory, orchestrate 
+from ndsl import OptimizationConfig, QuantityFactory, StencilFactory, orchestrate from ndsl.boilerplate import get_factories_single_tile_orchestrated from ndsl.config import Backend from ndsl.constants import I_DIM, J_DIM, K_DIM +from ndsl.dsl.dace.stree.pipeline import CartesianMerge, CleanUpScheduleTree from ndsl.dsl.gt4py import FORWARD, PARALLEL, K, computation, interval from ndsl.dsl.typing import FloatField -from tests.dsl.dace.stree import StreeOptimization, get_SDFG_and_purge +from tests.dsl.dace.stree import StreePipeline, get_SDFG_and_purge +from tests.dsl.dace.stree.optimizations import Factories def stencil(in_field: FloatField, out_field: FloatField) -> None: @@ -54,6 +54,7 @@ def __init__( stencil_factory: StencilFactory, quantity_factory: QuantityFactory, ) -> None: + config = OptimizationConfig(stree=OptimizationConfig.Tree(enabled=True)) orchestratable_methods = [ "trivial_merge", "missing_merge_of_forscope_and_map", @@ -66,6 +67,7 @@ def __init__( obj=self, config=stencil_factory.config.dace_config, method_to_orchestrate=method, + optimization_config=config, ) self.stencil = stencil_factory.from_dims_halo( @@ -130,9 +132,6 @@ def push_non_cartesian_for( self.stencil(in_field, out_field) -Factories: TypeAlias = tuple[StencilFactory, QuantityFactory] - - class TestStreeMergeMapsIJK: @pytest.fixture def factories(self) -> Factories: @@ -150,8 +149,7 @@ def test_trivial_merge(self, code: OrchestratedCode, factories: Factories) -> No in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - with StreeOptimization(): - code.trivial_merge(in_qty, out_qty) + code.trivial_merge(in_qty, out_qty) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) all_maps = [ @@ -160,7 +158,7 @@ def test_trivial_merge(self, code: OrchestratedCode, factories: Factories) -> No if isinstance(me, nodes.MapEntry) ] - assert len(all_maps) == 3 + assert len(all_maps) == 1 # all merged and collapsed assert 
(out_qty.field[:] == 2).all() def test_missing_merge_of_forscope_and_map( @@ -170,8 +168,7 @@ def test_missing_merge_of_forscope_and_map( in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - with StreeOptimization(): - code.missing_merge_of_forscope_and_map(in_qty, out_qty) + code.missing_merge_of_forscope_and_map(in_qty, out_qty) sdfg = get_SDFG_and_purge(stencil_factory).sdfg all_maps = [ @@ -179,7 +176,7 @@ def test_missing_merge_of_forscope_and_map( for map_entry, _ in sdfg.all_nodes_recursive() if isinstance(map_entry, nodes.MapEntry) ] - assert len(all_maps) == 4 # 2 IJ + 2 Ks + assert len(all_maps) == 3 # 1 IJ + 2 Ks all_loops = [ loop for loop, _ in sdfg.all_nodes_recursive() @@ -194,8 +191,7 @@ def test_overcompute_merge( in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - with StreeOptimization(): - code.overcompute_merge(in_qty, out_qty) + code.overcompute_merge(in_qty, out_qty) sdfg = get_SDFG_and_purge(stencil_factory).sdfg all_maps = [ @@ -203,7 +199,40 @@ def test_overcompute_merge( for me, state in sdfg.all_nodes_recursive() if isinstance(me, nodes.MapEntry) ] - assert len(all_maps) == 3 # All maps merged + assert len(all_maps) == 1 # All maps merged and collapsed + + def test_no_overcompute_merge( + self, code: OrchestratedCode, factories: Factories + ) -> None: + stencil_factory, quantity_factory = factories + in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") + + no_overcompute = [ + CleanUpScheduleTree(), + CartesianMerge(stencil_factory.backend, overcompute=False), + ] + + with StreePipeline(passes=no_overcompute): + code.overcompute_merge(in_qty, out_qty) + + sdfg = get_SDFG_and_purge(stencil_factory).sdfg + + all_maps = [ + me for me, _ in sdfg.all_nodes_recursive() if isinstance(me, nodes.MapEntry) + ] + k_maps = 0 + ij_maps = 0 + for 
map_entry in all_maps: + if len(map_entry.map.params) == 1 and map_entry.map.params[0].startswith( + "__k" + ): + k_maps += 1 + if map_entry.map.params == ["__i", "__j"]: + ij_maps += 1 + + assert ij_maps == 1 + assert k_maps == 2 def test_block_merge_when_dependencies_are_found( self, code: OrchestratedCode, factories: Factories @@ -212,9 +241,8 @@ def test_block_merge_when_dependencies_are_found( in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - with StreeOptimization(): - # Forbid merging when data dependencies are detected - code.block_merge_when_dependencies_are_found(in_qty, out_qty) + # Forbid merging when data dependencies are detected + code.block_merge_when_dependencies_are_found(in_qty, out_qty) sdfg = get_SDFG_and_purge(stencil_factory).sdfg all_maps = [ @@ -222,7 +250,7 @@ def test_block_merge_when_dependencies_are_found( for me, state in sdfg.all_nodes_recursive() if isinstance(me, nodes.MapEntry) ] - assert len(all_maps) == 4 # 2 IJ + 2 Ks (un-merged) + assert len(all_maps) == 3 # 1 IJ + 2 Ks (un-merged) def test_push_non_cartesian_for( self, code: OrchestratedCode, factories: Factories @@ -231,10 +259,9 @@ def test_push_non_cartesian_for( in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - with StreeOptimization(): - # Push non-cartesian ForScope inwards, which allow to potentially - # merge cartesian maps - code.push_non_cartesian_for(in_qty, out_qty) + # Push non-cartesian ForScope inwards, which allow to potentially + # merge cartesian maps + code.push_non_cartesian_for(in_qty, out_qty) sdfg = get_SDFG_and_purge(stencil_factory).sdfg all_maps = [ @@ -242,7 +269,7 @@ def test_push_non_cartesian_for( for me, state in sdfg.all_nodes_recursive() if isinstance(me, nodes.MapEntry) ] - assert len(all_maps) == 3 # All merged + assert len(all_maps) == 1 # All merged & collapsed for_loops = [ node for node, _ in 
sdfg.all_nodes_recursive() @@ -268,8 +295,7 @@ def test_trivial_merge(self, code: OrchestratedCode, factories: Factories) -> No in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - with StreeOptimization(): - code.trivial_merge(in_qty, out_qty) + code.trivial_merge(in_qty, out_qty) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) all_maps = [ @@ -278,7 +304,7 @@ def test_trivial_merge(self, code: OrchestratedCode, factories: Factories) -> No if isinstance(me, nodes.MapEntry) ] - assert len(all_maps) == 3 + assert len(all_maps) == 1 # all maps merged and collapsed assert (out_qty.field[:] == 2).all() def test_missing_merge_of_forscope_and_map( @@ -288,9 +314,8 @@ def test_missing_merge_of_forscope_and_map( in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - with StreeOptimization(): - # K iterative loop - blocks all merges - code.missing_merge_of_forscope_and_map(in_qty, out_qty) + # K iterative loop - blocks all merges + code.missing_merge_of_forscope_and_map(in_qty, out_qty) sdfg = get_SDFG_and_purge(stencil_factory).sdfg all_maps = [ @@ -298,7 +323,7 @@ def test_missing_merge_of_forscope_and_map( for map_entry, _ in sdfg.all_nodes_recursive() if isinstance(map_entry, nodes.MapEntry) ] - assert len(all_maps) == 8 # 2 KJI (all maps) + 1 for scope + assert len(all_maps) == 3 # 2 KJI (all maps) + 1 JI all_loops = [ loop for loop, _ in sdfg.all_nodes_recursive() @@ -313,9 +338,8 @@ def test_overcompute_merge( in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - with StreeOptimization(): - # Overcompute merge in K - we merge and introduce an If guard - code.overcompute_merge(in_qty, out_qty) + # Overcompute merge in K - we merge and introduce an If guard + code.overcompute_merge(in_qty, out_qty) sdfg = get_SDFG_and_purge(stencil_factory).sdfg all_maps = [ @@ -323,7 
+347,7 @@ def test_overcompute_merge( for me, state in sdfg.all_nodes_recursive() if isinstance(me, nodes.MapEntry) ] - assert len(all_maps) == 3 # All maps merged + assert len(all_maps) == 1 # All maps merged & collapsed def test_block_merge_when_dependencies_are_found( self, code: OrchestratedCode, factories: Factories @@ -332,9 +356,8 @@ def test_block_merge_when_dependencies_are_found( in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - with StreeOptimization(): - # Forbid merging when data dependencies are detected - code.block_merge_when_dependencies_are_found(in_qty, out_qty) + # Forbid merging when data dependencies are detected + code.block_merge_when_dependencies_are_found(in_qty, out_qty) sdfg = get_SDFG_and_purge(stencil_factory).sdfg all_maps = [ @@ -342,7 +365,7 @@ def test_block_merge_when_dependencies_are_found( for me, state in sdfg.all_nodes_recursive() if isinstance(me, nodes.MapEntry) ] - assert len(all_maps) == 6 # 2 * KJI + assert len(all_maps) == 2 # 2 * KJI def test_push_non_cartesian_for( self, code: OrchestratedCode, factories: Factories @@ -351,10 +374,9 @@ def test_push_non_cartesian_for( in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - with StreeOptimization(): - # Push non-cartesian ForScope inwards, which allow to potentially - # merge cartesian maps - code.push_non_cartesian_for(in_qty, out_qty) + # Push non-cartesian ForScope inwards, which allow to potentially + # merge cartesian maps + code.push_non_cartesian_for(in_qty, out_qty) sdfg = get_SDFG_and_purge(stencil_factory).sdfg all_maps = [ @@ -362,7 +384,7 @@ def test_push_non_cartesian_for( for me, state in sdfg.all_nodes_recursive() if isinstance(me, nodes.MapEntry) ] - assert len(all_maps) == 3 # All merged + assert len(all_maps) == 1 # All merged and collapsed for_loops = [ node for node, _ in sdfg.all_nodes_recursive() diff --git 
a/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py new file mode 100644 index 00000000..ce1446b2 --- /dev/null +++ b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py @@ -0,0 +1,153 @@ +import pytest +from dace import nodes + +from ndsl import ( + Backend, + NDSLRuntime, + OptimizationConfig, + StencilFactory, + orchestrate, + stencils, +) +from ndsl.boilerplate import get_factories_single_tile +from ndsl.constants import I_DIM, J_DIM, K_DIM +from ndsl.dsl.typing import FloatField +from tests.dsl.dace.stree import get_SDFG_and_purge +from tests.dsl.dace.stree.optimizations import Factories + + +class OrchestratedCode(NDSLRuntime): + def __init__(self, stencil_factory: StencilFactory) -> None: + config = OptimizationConfig(stree=OptimizationConfig.Tree(enabled=True)) + super().__init__(stencil_factory, config) + + methods_to_orchestrate = [ + "happy_case", + "happy_case_2", + "blocked_by_else", + "blocked_by_other_nodes", + ] + + for method in methods_to_orchestrate: + orchestrate( + obj=self, + config=stencil_factory.config.dace_config, + method_to_orchestrate=method, + optimization_config=config, + ) + + self._copy_stencil = stencil_factory.from_dims_halo( + func=stencils.copy, compute_dims=[I_DIM, J_DIM, K_DIM] + ) + + def happy_case(self, in_field: FloatField, out_field: FloatField) -> None: + if in_field[0, 0, 0] > 0: + self._copy_stencil(in_field, out_field) + self._copy_stencil(in_field, out_field) + + def happy_case_2(self, in_field: FloatField, out_field: FloatField) -> None: + if not in_field[0, 0, 0] > 0: + self._copy_stencil(in_field, out_field) + self._copy_stencil(in_field, out_field) + + def blocked_by_else(self, in_field: FloatField, out_field: FloatField) -> None: + self._copy_stencil(in_field, out_field) + + if in_field[0, 0, 0] > 0: + self._copy_stencil(in_field, out_field) + else: + self._copy_stencil(out_field, in_field) + + def blocked_by_other_nodes( 
+ self, in_field: FloatField, out_field: FloatField + ) -> None: + if in_field[0, 0, 0] > 0: + in_field[:] = 42.0 + self._copy_stencil(in_field, out_field) + self._copy_stencil(in_field, out_field) + + +class TestStreeInlineOffgridConditionals: + @pytest.fixture(params=["orch:dace:cpu:IJK", "orch:dace:cpu:KJI"]) + def factories(self, request: pytest.FixtureRequest) -> Factories: + domain = (3, 3, 4) + return get_factories_single_tile( + domain[0], domain[1], domain[2], 0, backend=Backend(request.param) + ) + + def test_happy_case(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + + code = OrchestratedCode(stencil_factory) + in_quantity = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_quantity = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") + + code.happy_case(in_quantity, out_quantity) + + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + + all_maps = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, nodes.MapEntry) + ] + assert len(all_maps) == 1 # all merged and collapsed + + def test_happy_case_2(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + + code = OrchestratedCode(stencil_factory) + in_quantity = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_quantity = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") + + code.happy_case_2(in_quantity, out_quantity) + + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + + all_maps = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, nodes.MapEntry) + ] + assert len(all_maps) == 1 # all merged and collapsed + + def test_blocked_by_else(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + + code = OrchestratedCode(stencil_factory) + in_quantity = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_quantity = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") + + 
code.blocked_by_else(in_quantity, out_quantity) + + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + + all_maps = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, nodes.MapEntry) + ] + assert len(all_maps) == 3 # 3 * IJK/KJI + + def test_blocked_by_other_nodes(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + + code = OrchestratedCode(stencil_factory) + in_quantity = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_quantity = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") + + code.blocked_by_other_nodes(in_quantity, out_quantity) + + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + + all_maps = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, nodes.MapEntry) + ] + + # ⚠️ Dev note: + # This should be just `assert len(all_maps) == 2`, but currently, the K-loops + # can't merge because the K-iterators are different. To be fixed (and simplified + # here) with a subsequent commit. 
+ assert len(all_maps) == 3 diff --git a/tests/dsl/dace/stree/optimizations/test_pipeline.py b/tests/dsl/dace/stree/optimizations/test_pipeline.py index 677790bc..89662b4e 100644 --- a/tests/dsl/dace/stree/optimizations/test_pipeline.py +++ b/tests/dsl/dace/stree/optimizations/test_pipeline.py @@ -1,10 +1,9 @@ -from ndsl import StencilFactory, orchestrate +from ndsl import OptimizationConfig, StencilFactory, orchestrate from ndsl.boilerplate import get_factories_single_tile_orchestrated from ndsl.config import Backend from ndsl.constants import I_DIM, J_DIM, K_DIM from ndsl.dsl.gt4py import PARALLEL, computation, interval from ndsl.dsl.typing import FloatField -from tests.dsl.dace.stree import StreeOptimization def double_map(in_field: FloatField, out_field: FloatField): @@ -17,7 +16,12 @@ def double_map(in_field: FloatField, out_field: FloatField): class TriviallyMergeableCode: def __init__(self, stencil_factory: StencilFactory): - orchestrate(obj=self, config=stencil_factory.config.dace_config) + config = OptimizationConfig(stree=OptimizationConfig.Tree(enabled=True)) + orchestrate( + obj=self, + config=stencil_factory.config.dace_config, + optimization_config=config, + ) self.stencil = stencil_factory.from_dims_halo( func=double_map, compute_dims=[I_DIM, J_DIM, K_DIM], @@ -37,7 +41,6 @@ def test_stree_roundtrip_no_opt(): in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - with StreeOptimization(): - code(in_qty, out_qty) + code(in_qty, out_qty) assert (out_qty.field[:] == 4).all() diff --git a/tests/dsl/dace/stree/optimizations/test_remove_loops.py b/tests/dsl/dace/stree/optimizations/test_remove_loops.py new file mode 100644 index 00000000..331fb44b --- /dev/null +++ b/tests/dsl/dace/stree/optimizations/test_remove_loops.py @@ -0,0 +1,299 @@ +import pytest +from dace import nodes +from dace.sdfg.state import LoopRegion + +from ndsl import OptimizationConfig, StencilFactory, orchestrate +from 
ndsl.boilerplate import get_factories_single_tile +from ndsl.config import Backend, BackendLoopOrder +from ndsl.constants import I_DIM, J_DIM, K_DIM, Float +from ndsl.dsl.dace.stree.optimizations import InlineVertical2DWrite +from ndsl.dsl.dace.stree.pipeline import ( + CartesianMerge, + CartesianRefineTransients, + CleanUpScheduleTree, +) +from ndsl.dsl.gt4py import FORWARD, computation, interval +from ndsl.dsl.typing import FloatField, FloatFieldIJ +from ndsl.stencils import copy +from tests.dsl.dace.stree import StreePipeline, get_SDFG_and_purge +from tests.dsl.dace.stree.optimizations import Factories + + +def stencil_simple_2D_write(in_field: FloatField, out_fieldIJ: FloatFieldIJ) -> None: + with computation(FORWARD), interval(0, 1): + out_fieldIJ = in_field + + +def stencil_multiple_2D_write( + in_field: FloatField, out_fieldIJ: FloatFieldIJ, out_fieldIJ_2: FloatFieldIJ +) -> None: + with computation(FORWARD), interval(0, 1): + out_fieldIJ = in_field + out_fieldIJ_2 = in_field + 1.0 + + +def stencil_2D_write_at_K(in_field: FloatField, out_fieldIJ: FloatFieldIJ) -> None: + with computation(FORWARD), interval(-1, None): + out_fieldIJ = in_field + + +def stencil_forward_at_K(in_field: FloatField, out_field: FloatField) -> None: + with computation(FORWARD), interval(...): + out_field = in_field + + +class OrchestratedCode: + def __init__(self, stencil_factory: StencilFactory) -> None: + config = OptimizationConfig(stree=OptimizationConfig.Tree(enabled=True)) + methods_to_orchestrate = [ + "write_at_0", + "write_at_top", + "do_not_inline", + "combined_stencils", + "multiple_statements", + ] + for method in methods_to_orchestrate: + orchestrate( + obj=self, + config=stencil_factory.config.dace_config, + method_to_orchestrate=method, + optimization_config=config, + ) + + self.stencil_simple_2D_write = stencil_factory.from_dims_halo( + func=stencil_simple_2D_write, + compute_dims=[I_DIM, J_DIM, K_DIM], + ) + self.stencil_2D_write_at_K = 
stencil_factory.from_dims_halo( + func=stencil_2D_write_at_K, + compute_dims=[I_DIM, J_DIM, K_DIM], + ) + self.stencil_do_not_inline = stencil_factory.from_dims_halo( + func=stencil_forward_at_K, + compute_dims=[I_DIM, J_DIM, K_DIM], + ) + self.stencil_copy = stencil_factory.from_dims_halo( + func=copy, + compute_dims=[I_DIM, J_DIM, K_DIM], + ) + self.stencil_multiple_2D_write = stencil_factory.from_dims_halo( + func=stencil_multiple_2D_write, + compute_dims=[I_DIM, J_DIM, K_DIM], + ) + + def write_at_0( + self, + in_field: FloatField, + out_field: FloatFieldIJ, + ) -> None: + self.stencil_simple_2D_write(in_field, out_field) + + def write_at_top( + self, + in_field: FloatField, + out_field: FloatFieldIJ, + ) -> None: + self.stencil_2D_write_at_K(in_field, out_field) + + def do_not_inline( + self, + in_field: FloatField, + out_field: FloatField, + ) -> None: + self.stencil_do_not_inline(in_field, out_field) + + def combined_stencils( + self, field: FloatField, field2: FloatField, fieldIJ: FloatFieldIJ + ) -> None: + self.stencil_copy(field, field2) + self.stencil_simple_2D_write(field2, fieldIJ) + + def multiple_statements( + self, in_field: FloatField, out_field: FloatFieldIJ, out_field2: FloatFieldIJ + ) -> None: + self.stencil_copy(in_field, in_field) + self.stencil_multiple_2D_write(in_field, out_field, out_field2) + + +class TestStree2DWriteInline: + @pytest.fixture(params=["orch:dace:cpu:IJK", "orch:dace:cpu:KJI"]) + def factories(self, request: pytest.FixtureRequest) -> Factories: + + domain = (3, 3, 4) + return get_factories_single_tile( + domain[0], domain[1], domain[2], 0, backend=Backend(request.param) + ) + + def test_common_2D_write(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + code = OrchestratedCode(stencil_factory) + pipeline = [ + CleanUpScheduleTree(), + InlineVertical2DWrite(), + CartesianMerge(stencil_factory.backend), + CartesianRefineTransients(stencil_factory.backend), + ] + + in_qty = 
quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_qty = quantity_factory.zeros([I_DIM, J_DIM], "") + in_qty.field[:, :, 0] = Float(32.0) + + with StreePipeline(passes=pipeline): + code.write_at_0(in_qty, out_qty) + + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + all_maps = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, nodes.MapEntry) + ] + all_loop_region = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, LoopRegion) + ] + + assert len(all_maps) == 1 # IJ/JI collapsed + assert len(all_loop_region) == 0 + assert (out_qty.field[:] == Float(32.0)).all() + + def test_2D_write_K_top(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + code = OrchestratedCode(stencil_factory) + pipeline = [ + CleanUpScheduleTree(), + InlineVertical2DWrite(), + CartesianMerge(stencil_factory.backend), + CartesianRefineTransients(stencil_factory.backend), + ] + + in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_qty = quantity_factory.zeros([I_DIM, J_DIM], "") + in_qty.field[:, :, -1] = Float(32.0) + + with StreePipeline(passes=pipeline): + code.write_at_top(in_qty, out_qty) + + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + all_maps = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, nodes.MapEntry) + ] + all_loop_region = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, LoopRegion) + ] + + assert len(all_maps) == 1 # IJ/JI collapsed + assert len(all_loop_region) == 0 + assert (out_qty.field[:] == Float(32.0)).all() + + def test_do_not_inline(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + code = OrchestratedCode(stencil_factory) + pipeline = [ + CleanUpScheduleTree(), + InlineVertical2DWrite(), + CartesianMerge(stencil_factory.backend), + 
CartesianRefineTransients(stencil_factory.backend), + ] + + in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") + + with StreePipeline(passes=pipeline): + code.do_not_inline(in_qty, out_qty) + + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + all_maps = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, nodes.MapEntry) + ] + all_loop_region = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, LoopRegion) + ] + + assert len(all_maps) == 1 # IJ/JI collapsed + assert len(all_loop_region) == 1 + assert (out_qty.field[:] == Float(1)).all() + + def test_combined_stencils(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + code = OrchestratedCode(stencil_factory) + pipeline = [ + CleanUpScheduleTree(), + InlineVertical2DWrite(), + CartesianMerge(stencil_factory.backend), + CartesianRefineTransients(stencil_factory.backend), + ] + + field = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + field_2 = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") + field_IJ = quantity_factory.zeros([I_DIM, J_DIM], "") + + with StreePipeline(passes=pipeline): + code.combined_stencils(field, field_2, field_IJ) + + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + all_maps = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, nodes.MapEntry) + ] + all_loop_region = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, LoopRegion) + ] + + assert ( + len(all_maps) == 2 # IJ + K + if stencil_factory.backend.loop_order == BackendLoopOrder.IJK + else 2 # KJI + JI + ) + assert len(all_loop_region) == 0 + assert (field_IJ.field[:] == Float(1)).all() + + def test_multiple_statements(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + code = 
OrchestratedCode(stencil_factory) + pipeline = [ + CleanUpScheduleTree(), + InlineVertical2DWrite(), + CartesianMerge(stencil_factory.backend), + CartesianRefineTransients(stencil_factory.backend), + ] + + field = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + field_IJ = quantity_factory.zeros([I_DIM, J_DIM], "") + field_IJ_2 = quantity_factory.zeros([I_DIM, J_DIM], "") + + field.field[:, :, 0] = Float(42.0) + with StreePipeline(passes=pipeline): + code.multiple_statements(field, field_IJ, field_IJ_2) + + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + all_maps = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, nodes.MapEntry) + ] + all_loop_region = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, LoopRegion) + ] + + assert ( + len(all_maps) == 2 # IJ + K + if stencil_factory.backend.loop_order == BackendLoopOrder.IJK + else 2 # KJI + JI + ) + assert len(all_loop_region) == 0 + assert (field_IJ.field[:] == Float(42.0)).all() + assert (field_IJ_2.field[:] == Float(43.0)).all() diff --git a/tests/dsl/dace/stree/optimizations/test_transient_refine.py b/tests/dsl/dace/stree/optimizations/test_transient_refine.py index 9795957a..3190f39d 100644 --- a/tests/dsl/dace/stree/optimizations/test_transient_refine.py +++ b/tests/dsl/dace/stree/optimizations/test_transient_refine.py @@ -1,10 +1,17 @@ -from ndsl import NDSLRuntime, Quantity, QuantityFactory, StencilFactory, orchestrate +from ndsl import ( + NDSLRuntime, + OptimizationConfig, + Quantity, + QuantityFactory, + StencilFactory, + orchestrate, +) from ndsl.boilerplate import get_factories_single_tile_orchestrated from ndsl.config import Backend from ndsl.constants import I_DIM, J_DIM, K_DIM from ndsl.dsl.gt4py import IJK, PARALLEL, Field, J, K, computation, interval from ndsl.dsl.typing import Float, FloatField -from tests.dsl.dace.stree import StreeOptimization, get_SDFG_and_purge +from 
tests.dsl.dace.stree import get_SDFG_and_purge DATADIM_SIZE = 8 @@ -39,7 +46,8 @@ class TransientRefineableCode(NDSLRuntime): def __init__( self, stencil_factory: StencilFactory, quantity_factory: QuantityFactory ) -> None: - super().__init__(stencil_factory) + config = OptimizationConfig(stree=OptimizationConfig.Tree(enabled=True)) + super().__init__(stencil_factory, optimization_config=config) orchestratable_methods = [ "refine_to_scalar", "refine_to_K_buffer", @@ -51,6 +59,7 @@ def __init__( obj=self, config=stencil_factory.config.dace_config, method_to_orchestrate=method, + optimization_config=config, ) self.stencil = stencil_factory.from_dims_halo( func=stencil, @@ -105,40 +114,39 @@ def test_stree_roundtrip_transient_is_refined() -> None: code = TransientRefineableCode(stencil_factory, quantity_factory) - with StreeOptimization(): - # Refine to scalar - code.refine_to_scalar(in_qty, out_qty) - precompiled_sdfg = get_SDFG_and_purge(stencil_factory) - for array in precompiled_sdfg.sdfg.arrays.values(): - if array.transient: - assert array.shape == (1, 1, 1) - - # Refine cartesian axis to buffers - # IJ merges - K is a buffer - code.refine_to_K_buffer(in_qty, out_qty) - precompiled_sdfg = get_SDFG_and_purge(stencil_factory) - for array in precompiled_sdfg.sdfg.arrays.values(): - if array.transient: - assert array.shape == ( - 1, - 1, - domain[2] + 1, # Quantity are domain size + 1 - ) - - # I merges - JK buffer - code.refine_to_JK_buffer(in_qty, out_qty) - precompiled_sdfg = get_SDFG_and_purge(stencil_factory) - for array in precompiled_sdfg.sdfg.arrays.values(): - if array.transient: - assert array.shape == ( - 1, - domain[1] + 1, # Quantity are domain size + 1 - domain[2] + 1, - ) - - # Refine to remaining data dimensions - code.do_not_refine_datadims(in_qty_ddim, out_qty_ddim) - precompiled_sdfg = get_SDFG_and_purge(stencil_factory) - for array in precompiled_sdfg.sdfg.arrays.values(): - if array.transient: - assert array.shape == (1, 1, 1, DATADIM_SIZE) or 
len(array.shape) == 1 + # Refine to scalar + code.refine_to_scalar(in_qty, out_qty) + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + for array in precompiled_sdfg.sdfg.arrays.values(): + if array.transient: + assert array.shape == (1, 1, 1) + + # Refine cartesian axis to buffers + # IJ merges - K is a buffer + code.refine_to_K_buffer(in_qty, out_qty) + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + for array in precompiled_sdfg.sdfg.arrays.values(): + if array.transient: + assert array.shape == ( + 1, + 1, + domain[2] + 1, # Quantity are domain size + 1 + ) + + # I merges - JK buffer + code.refine_to_JK_buffer(in_qty, out_qty) + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + for array in precompiled_sdfg.sdfg.arrays.values(): + if array.transient: + assert array.shape == ( + 1, + domain[1] + 1, # Quantity are domain size + 1 + domain[2] + 1, + ) + + # Refine to remaining data dimensions + code.do_not_refine_datadims(in_qty_ddim, out_qty_ddim) + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + for array in precompiled_sdfg.sdfg.arrays.values(): + if array.transient: + assert array.shape == (1, 1, 1, DATADIM_SIZE) or len(array.shape) == 1 diff --git a/tests/dsl/dace/stree/sdfg_stree_tools.py b/tests/dsl/dace/stree/sdfg_stree_tools.py index 6c664205..aeb149a5 100644 --- a/tests/dsl/dace/stree/sdfg_stree_tools.py +++ b/tests/dsl/dace/stree/sdfg_stree_tools.py @@ -1,6 +1,7 @@ from types import TracebackType import dace +from dace.sdfg.analysis.schedule_tree import treenodes as tn import ndsl.dsl.dace.orchestration as orch from ndsl import StencilFactory @@ -20,9 +21,13 @@ def get_SDFG_and_purge(stencil_factory: StencilFactory) -> dace.CompiledSDFG: return sdfg -class StreeOptimization: +class StreePipeline: + def __init__(self, *, passes: list[tn.ScheduleNodeVisitor] | None = None) -> None: + self.passes = passes + def __enter__(self) -> None: - orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = True + self.original_passes = 
orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION_PASSES + orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION_PASSES = self.passes def __exit__( self, @@ -30,4 +35,4 @@ def __exit__( exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> None: - orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = False + orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION_PASSES = self.original_passes diff --git a/tests/dsl/orchestration/test_boundaries_k.py b/tests/dsl/orchestration/test_boundaries_k.py new file mode 100644 index 00000000..80ea4a84 --- /dev/null +++ b/tests/dsl/orchestration/test_boundaries_k.py @@ -0,0 +1,196 @@ +import numpy as np +import pytest + +from ndsl import Backend, NDSLRuntime, StencilFactory, orchestrate +from ndsl.boilerplate import get_factories_single_tile +from ndsl.constants import I_DIM, J_DIM, K_DIM, K_INTERFACE_DIM +from ndsl.dsl.gt4py import BACKWARD, FORWARD, computation, interval +from ndsl.dsl.typing import FloatField +from tests.dsl.dace.stree.optimizations import Factories + + +def accumulate_down(in_field: FloatField, out_field: FloatField) -> None: # type: ignore + with computation(BACKWARD): + # handle top layer separately + with interval(-1, None): + out_field = in_field + + # accumulate "downwards" + with interval(0, -1): + out_field = out_field[0, 0, 1] + in_field + + +def accumulate_down_from_interface_field(interface_field: FloatField, out_field: FloatField) -> None: # type: ignore + with computation(BACKWARD): + # handle top layer separately + with interval(-1, None): + out_field = interface_field + interface_field[0, 0, 1] + + # accumulate "downwards" + with interval(0, -1): + out_field = out_field[0, 0, 1] + interface_field + + +def accumulate_on_interface(interface_field: FloatField, out_field: FloatField) -> None: # type: ignore + with computation(BACKWARD): + # handle top layer separately + with interval(-2, -1): + out_field = interface_field + interface_field[0, 0, 1] + + # accumulate "downwards" + with interval(0, -2): + out_field = 
out_field[0, 0, 1] + interface_field + + +def accumulate_up(in_field: FloatField, out_field: FloatField) -> None: # type: ignore + with computation(FORWARD): + # handle bottom layer separately + with interval(0, 1): + out_field = in_field + + # accumulate "upwards" + with interval(1, None): + out_field = out_field[0, 0, -1] + in_field + + +def accumulate_up_interface(in_field: FloatField, interface_field: FloatField) -> None: # type: ignore + with computation(FORWARD): + # handle bottom layer separately + with interval(0, 1): + interface_field = in_field + + # accumulate "upwards" + with interval(1, None): + interface_field = interface_field[0, 0, -1] + in_field[0, 0, -1] + + +class OrchestratedCode(NDSLRuntime): + def __init__(self, stencil_factory: StencilFactory) -> None: + super().__init__(stencil_factory) + + methods_to_orchestrate = [ + "accumulate_down", + "accumulate_down_from_interface_field", + "accumulate_on_interface", + "accumulate_up", + "accumulate_up_interface", + ] + + for method in methods_to_orchestrate: + orchestrate( + obj=self, + method_to_orchestrate=method, + config=stencil_factory.config.dace_config, + ) + + self._accumulate_down = stencil_factory.from_dims_halo( + func=accumulate_down, compute_dims=(I_DIM, J_DIM, K_DIM) + ) + + self._accumulate_down_from_interface_field = stencil_factory.from_dims_halo( + func=accumulate_down_from_interface_field, + compute_dims=(I_DIM, J_DIM, K_DIM), + ) + + self._accumulate_on_interface = stencil_factory.from_dims_halo( + func=accumulate_on_interface, compute_dims=(I_DIM, J_DIM, K_INTERFACE_DIM) + ) + + self._accumulate_up = stencil_factory.from_dims_halo( + func=accumulate_up, compute_dims=(I_DIM, J_DIM, K_DIM) + ) + + self._accumulate_up_interface = stencil_factory.from_dims_halo( + func=accumulate_up_interface, compute_dims=(I_DIM, J_DIM, K_INTERFACE_DIM) + ) + + def accumulate_down(self, in_field: FloatField, out_field: FloatField) -> None: # type: ignore + self._accumulate_down(in_field, out_field) 
+ + def accumulate_down_from_interface_field(self, interface_field: FloatField, out_field: FloatField) -> None: # type: ignore + self._accumulate_down_from_interface_field(interface_field, out_field) + + def accumulate_on_interface(self, interface_field: FloatField, out_field: FloatField) -> None: # type: ignore + self._accumulate_on_interface(interface_field, out_field) + + def accumulate_up(self, in_field: FloatField, out_field: FloatField) -> None: # type: ignore + self._accumulate_up(in_field, out_field) + + def accumulate_up_interface(self, in_field: FloatField, interface_field: FloatField) -> None: # type: ignore + self._accumulate_up_interface(in_field, interface_field) + + +class TestBoundariesK: + @pytest.fixture( + params=[ + "orch:dace:cpu:IJK", + "orch:dace:cpu:KJI", + "st:dace:cpu:IJK", + "st:dace:cpu:KJI", + ] + ) + def factories(self, request: pytest.FixtureRequest) -> Factories: + domain = (3, 4, 5) + return get_factories_single_tile( + nx=domain[0], + ny=domain[1], + nz=domain[2], + nhalo=0, + backend=Backend(request.param), + ) + + def test_accumulate_down(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + code = OrchestratedCode(stencil_factory) + + in_field = quantity_factory.ones((I_DIM, J_DIM, K_DIM), units="") + out_field = quantity_factory.zeros((I_DIM, J_DIM, K_DIM), units="") + + code.accumulate_down(in_field, out_field) + assert np.array_equal(out_field.field[0, 0, :], [5, 4, 3, 2, 1]) + + def test_accumulate_interface_field(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + code = OrchestratedCode(stencil_factory) + + interface_field = quantity_factory.ones( + (I_DIM, J_DIM, K_INTERFACE_DIM), units="" + ) + out_field = quantity_factory.zeros((I_DIM, J_DIM, K_DIM), units="") + + code.accumulate_down_from_interface_field(interface_field, out_field) + assert np.array_equal(out_field.field[0, 0, :], [6, 5, 4, 3, 2]) + + def test_accumulate_interface_domain(self, 
factories: Factories) -> None: + stencil_factory, quantity_factory = factories + code = OrchestratedCode(stencil_factory) + + interface_field = quantity_factory.ones( + (I_DIM, J_DIM, K_INTERFACE_DIM), units="" + ) + out_field = quantity_factory.zeros((I_DIM, J_DIM, K_DIM), units="") + + code.accumulate_on_interface(interface_field, out_field) + assert np.array_equal(out_field.field[0, 0, :], [6, 5, 4, 3, 2]) + + def test_accumulate_up(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + code = OrchestratedCode(stencil_factory) + + in_field = quantity_factory.ones((I_DIM, J_DIM, K_DIM), units="") + out_field = quantity_factory.zeros((I_DIM, J_DIM, K_DIM), units="") + + code.accumulate_up(in_field, out_field) + assert np.array_equal(out_field.field[0, 0, :], [1, 2, 3, 4, 5]) + + def test_accumulate_up_interface(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + code = OrchestratedCode(stencil_factory) + + in_field = quantity_factory.ones((I_DIM, J_DIM, K_DIM), units="") + interface_field = quantity_factory.zeros( + (I_DIM, J_DIM, K_INTERFACE_DIM), units="" + ) + + code.accumulate_up_interface(in_field, interface_field) + assert np.array_equal(interface_field.field[0, 0, :], [1, 2, 3, 4, 5, 6]) diff --git a/tests/test_ndsl_runtime.py b/tests/test_ndsl_runtime.py index 67e4f226..83694274 100644 --- a/tests/test_ndsl_runtime.py +++ b/tests/test_ndsl_runtime.py @@ -2,20 +2,19 @@ import pytest -from ndsl import NDSLRuntime, QuantityFactory, StencilFactory +from ndsl import ( + NDSLRuntime, + OptimizationConfig, + QuantityFactory, + StencilFactory, + stencils, +) from ndsl.boilerplate import ( get_factories_single_tile, get_factories_single_tile_orchestrated, ) from ndsl.config import Backend from ndsl.constants import I_DIM, J_DIM, K_DIM -from ndsl.dsl.gt4py import PARALLEL, computation, interval -from ndsl.dsl.typing import FloatField - - -def the_copy_stencil(from_: FloatField, to: FloatField) -> 
None: - with computation(PARALLEL), interval(...): - to = from_ class Code(NDSLRuntime): @@ -24,7 +23,7 @@ def __init__( ) -> None: super().__init__(stencil_factory) self.copy = stencil_factory.from_dims_halo( - the_copy_stencil, compute_dims=[I_DIM, J_DIM, K_DIM] + stencils.copy, compute_dims=[I_DIM, J_DIM, K_DIM] ) self.local = self.make_local(quantity_factory, [I_DIM, J_DIM, K_DIM]) @@ -105,3 +104,36 @@ def test_runtime_fail_when_not_super_init() -> None: RuntimeError, match="inherit from NDSLRuntime but didn't call super()" ): bad_code = BadCode_NoSuperInit() + + +def test_runtime_with_performance_config() -> None: + class CustomPerformanceConfig(NDSLRuntime): + def __init__( + self, + stencil_factory: StencilFactory, + optimization_config: OptimizationConfig, + ) -> None: + super().__init__(stencil_factory, optimization_config) + self.copy = stencil_factory.from_dims_halo( + stencils.copy, compute_dims=[I_DIM, J_DIM, K_DIM] + ) + + def __call__(self, src, dst) -> None: # type: ignore[no-untyped-def] + self.copy(src, dst) + + stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( + nx=5, ny=5, nz=3, nhalo=0, backend=Backend.cpu() + ) + + # setup code + config = OptimizationConfig() + code = CustomPerformanceConfig(stencil_factory, config) + + # setup inputs/outputs + src = quantity_factory.ones(dims=[I_DIM, J_DIM, K_DIM], units="n/a") + dst = quantity_factory.zeros(dims=[I_DIM, J_DIM, K_DIM], units="n/a") + + # call code with inputs/outputs + code(src, dst) + + assert (src.field[:] == dst.field[:]).all()