diff --git a/.github/workflows/unit_tests.yaml b/.github/workflows/unit_tests.yaml index e780c190..7bde1bf0 100644 --- a/.github/workflows/unit_tests.yaml +++ b/.github/workflows/unit_tests.yaml @@ -23,14 +23,13 @@ jobs: run: pip3 install .[test] - name: prepare input eta files - run: | - python tests/grid/generate_eta_files.py + run: python tests/grid/generate_eta_files.py - name: Run serial-cpu tests - run: coverage run --rcfile=setup.cfg -m pytest -x tests + run: coverage run --rcfile=setup.cfg -m pytest tests - name: Run parallel-cpu tests - run: mpiexec -np 6 --oversubscribe coverage run --rcfile=setup.cfg -m mpi4py -m pytest -x tests/mpi + run: mpiexec -np 6 --oversubscribe coverage run --rcfile=setup.cfg -m mpi4py -m pytest tests/mpi - name: Output code coverage run: | diff --git a/README.md b/README.md index 3ea17add..6a387192 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ NDSL submodules `gt4py` and `dace` to point to vetted versions, use `git clone - NDSL is __NOT__ available on `pypi`. Installation of the package has to be local, via `pip install ./NDSL` (`-e` supported). The packages has a few options: - `ndsl[test]`: installs the test packages (based on `pytest`) -- `ndsl[demos]`: installs extra requirements to run [NDSL exmpales](./examples/NDSL/) +- `ndsl[demos]`: installs extra requirements to run [NDSL examples](./examples/NDSL/) - `ndsl[docs]`: installs extra requirements to build the docs - `ndsl[develop]`: installs tools for development, docs, and tests. 
@@ -34,8 +34,6 @@ For CPU backends: - 3.11.x >= Python < 3.12.x - Compilers: - GNU 11.2+ -- Libraries: - - Boost headers 1.76+ (no lib installed, just headers) For GPU backends (the above plus): diff --git a/docs/index.md b/docs/index.md index 24053f27..650a2846 100644 --- a/docs/index.md +++ b/docs/index.md @@ -25,8 +25,6 @@ For CPU backends: - 3.11.x >= Python < 3.12.x - Compilers: - GNU 11.2+ -- Libraries: - - Boost headers 1.76+ (no lib installed, just headers) For GPU backends (the above plus): @@ -88,18 +86,6 @@ git submodule update --init --recursive ``` - Pace requires GCC > 9.2, MPI, and Python 3.8 on your system, and CUDA is required to run with a GPU backend. -You will also need the headers of the boost libraries in your `$PATH` (boost itself does not need to be installed). -If installed outside the standard header locations, gt4py requires that `$BOOST_ROOT` be set: - -```bash -cd BOOST/ROOT -wget https://boostorg.jfrog.io/artifactory/main/release/1.79.0/source/boost_1_79_0.tar.gz -tar -xzf boost_1_79_0.tar.gz -mkdir -p boost_1_79_0/include -mv boost_1_79_0/boost boost_1_79_0/include/ -export BOOST_ROOT=BOOST/ROOT/boost_1_79_0 -``` - - We recommend creating a python `venv` or conda environment specifically for Pace. ```bash diff --git a/examples/NDSL/01_gt4py_basics.ipynb b/examples/NDSL/01_gt4py_basics.ipynb index ce66cfa2..fa3797d7 100644 --- a/examples/NDSL/01_gt4py_basics.ipynb +++ b/examples/NDSL/01_gt4py_basics.ipynb @@ -40,7 +40,7 @@ "metadata": {}, "outputs": [], "source": [ - "from gt4py.cartesian.gtscript import PARALLEL, computation, interval, stencil\n", + "from ndsl.dsl.gt4py import PARALLEL, computation, interval, stencil\n", "from ndsl.dsl.typing import FloatField\n", "from ndsl.quantity import Quantity\n", "import numpy as np" @@ -109,14 +109,14 @@ "\n", "We see that this stencil does not contain any explicit loops. 
As mentioned above in the notebook, GT4Py has a particular computation policy that implicitly executes in parallel within an `IJ` plane and is user defined in the `K` interval. This execution policy in the `K` interval is dictated by the `computation` and `interval` keywords. \n", "\n", - "- `with computation(PARALLEL)` means that there's no order preference to executing the `K` interval. This also means that the `K` interval can be computed in parallel to potentially gain performace if computational resources are available.\n", + "- `with computation(PARALLEL)` means that there's no order preference to executing the `K` interval. This also means that the `K` interval can be computed in parallel to potentially gain performance if computational resources are available.\n", "\n", "- `interval(...)` means that the entire `K` interval is executed. Instead of `(...)`, more specific intervals can be specified using a tuple of two integers. For example... \n", "\n", " - `interval(0,2)` : The interval `K` = 0 to 1 is executed.\n", " - `interval(0,-1)` : The interval `K` = 0 to N-2 (where N is the size of `K`) is executed.\n", "\n", - "The decorator `@stencil(backend=backend)` (Note: `stencil` comes from the package `gt4py.cartesian.gtscript`) converts `copy_stencil` to use the specified `backend` to \"compile\" the stencil. `stencil` can also be a function call to create a stencil object." + "The decorator `@stencil(backend=backend)` (Note: `stencil` comes from the package `ndsl.dsl.gt4py`) converts `copy_stencil` to use the specified `backend` to \"compile\" the stencil. `stencil` can also be a function call to create a stencil object." ] }, { @@ -269,7 +269,7 @@ "metadata": {}, "outputs": [], "source": [ - "from gt4py.cartesian.gtscript import FORWARD, BACKWARD\n", + "from ndsl.dsl.gt4py import FORWARD, BACKWARD\n", "\n", "nx = 5\n", "ny = 5\n", @@ -470,7 +470,7 @@ "\n", "GT4Py also has the capability to create functions in order to better organize code. 
The main difference between a GT4Py function call and a GT4Py stencil is that a function does not (and cannot) contain the keywords `computation` and `interval`. However, array index referencing within a GT4py function is the same as in a GT4Py stencil.\n", "\n", - "GT4Py functions can be created by using the decorator `function` (Note: `function` originates from the package `gt4py.cartesian.gtscript`)." + "GT4Py functions can be created by using the decorator `function` (Note: `function` originates from the package `ndsl.dsl.gt4py`)." ] }, { @@ -479,7 +479,7 @@ "metadata": {}, "outputs": [], "source": [ - "from gt4py.cartesian.gtscript import function\n", + "from ndsl.dsl.gt4py import function\n", "\n", "@function\n", "def plus_one(field: FloatField):\n", diff --git a/examples/NDSL/02_NDSL_basics.ipynb b/examples/NDSL/02_NDSL_basics.ipynb index eac17cad..1c572ce1 100644 --- a/examples/NDSL/02_NDSL_basics.ipynb +++ b/examples/NDSL/02_NDSL_basics.ipynb @@ -26,14 +26,14 @@ "outputs": [], "source": [ "from ndsl import StencilFactory\n", - "from ndsl.boilerplate import get_factories_single_tile_numpy\n", + "from ndsl.boilerplate import get_factories_single_tile\n", "\n", "nx = 6\n", "ny = 6\n", "nz = 1\n", "nhalo = 1\n", "\n", - "stencil_factory, _ = get_factories_single_tile_numpy(nx, ny, nz, nhalo)" + "stencil_factory, _ = get_factories_single_tile(nx, ny, nz, nhalo)" ] }, { @@ -59,8 +59,8 @@ "metadata": {}, "outputs": [], "source": [ + "from ndsl.dsl.gt4py import PARALLEL, computation, interval\n", "from ndsl.dsl.typing import FloatField\n", - "from gt4py.cartesian.gtscript import PARALLEL, computation, interval\n", "\n", "def copy_field_stencil(field_in: FloatField, field_out: FloatField):\n", " with computation(PARALLEL), interval(...):\n", @@ -150,7 +150,7 @@ "print(\"Plotting qty_in at K = 0\")\n", "qty_in.plot_k_level(0)\n", "print(\"Plotting qty_out at K = 0\")\n", - "qty_out.plot_k_level(0)\n" + "qty_out.plot_k_level(0)" ] }, { @@ -189,7 +189,7 @@ "\n", "The 
next example will create a stencil that takes a `Quantity` as an input, shift the input by 1 in the `-J` direction, and write it to an output `Quantity`. This stencil is defined in `copy_field_offset_stencil`.\n", "\n", - "Note that in `copy_field_offset_stencil`, the shift in the J dimension is performed by referencing the `J` object from `gt4py.cartesian.gtscript` for simplicity. This reference will apply the shift in J to the entire input domain. Another way to perform the shift without referencing the `J` object is to write `[0,-1,0]` (assuming that the variable being modified is 3-dimensional) instead of `[J-1]`.\n", + "Note that in `copy_field_offset_stencil`, the shift in the `J` dimension is performed by referencing the `J` object from `ndsl.dsl.gt4py` for simplicity. This reference will apply the shift in `J` to the entire input domain. Another way to perform the shift without referencing the `J` object is to write `[0,-1,0]` (assuming that the variable being modified is 3-dimensional) instead of `[J-1]`.\n", "\n", "With the stencil in place, a class `CopyFieldOffset` is defined using the `StencilFactory` object and `copy_field_offset_stencil`. The class is instantiated and demonstrated to shift `qty_in` by 1 in the J-dimension and write to `qty_out`." ] @@ -200,7 +200,7 @@ "metadata": {}, "outputs": [], "source": [ - "from gt4py.cartesian.gtscript import J\n", + "from ndsl.dsl.gt4py import J\n", "\n", "def copy_field_offset_stencil(field_in: FloatField, field_out: FloatField):\n", " with computation(PARALLEL), interval(...):\n", @@ -230,7 +230,7 @@ " gt4py_backend=backend\n", " )\n", "\n", - "print(\"Initialize qty_out to zeros\")\n" + "print(\"Initialize qty_out to zeros\")" ] }, { @@ -251,7 +251,7 @@ "source": [ "### **Limits to offset : Cannot set offset outside of usable domain**\n", "\n", - "Note that when the copy offset by -1 in the j-direction is performed, the 'halo' region at J = 8 is copied over due to the `J` shift. 
This means that there are limits to the shift amount since choosing a large shift amount may result in accessing a data region that does not exist. The following example shows this by trying to perform a shift by -2 in the j-direction." + "Note that when the copy offset by `-1` in the `j`-direction is performed, the 'halo' region at `J = 8` is copied over due to the `J` shift. This means that there are limits to the shift amount since choosing a large shift amount may result in accessing a data region that does not exist. The following example shows this by trying to perform a shift by `-2` in the `j`-direction." ] }, { @@ -282,7 +282,7 @@ " \n", "copy_field_offset = CopyFieldOffset(stencil_factory)\n", "\n", - "copy_field_offset(qty_in, qty_out)\n" + "copy_field_offset(qty_in, qty_out)" ] }, { @@ -300,8 +300,6 @@ "metadata": {}, "outputs": [], "source": [ - "from gt4py.cartesian.gtscript import J\n", - "\n", "def copy_field_offset_output_stencil(field_in: FloatField, field_out: FloatField):\n", " with computation(PARALLEL), interval(...):\n", " field_out[0,1,0] = field_in\n", @@ -322,14 +320,13 @@ " ):\n", " self._copy_field_offset_output(field_in, field_out)\n", " \n", - "copy_field_offset_output = CopyFieldOffsetOutput(stencil_factory)\n", - " " + "copy_field_offset_output = CopyFieldOffsetOutput(stencil_factory)" ] } ], "metadata": { "kernelspec": { - "display_name": "gt4py_jupyter", + "display_name": ".venv", "language": "python", "name": "python3" }, diff --git a/examples/NDSL/03_orchestration_basics.ipynb b/examples/NDSL/03_orchestration_basics.ipynb index 01a77dd8..48f0d702 100644 --- a/examples/NDSL/03_orchestration_basics.ipynb +++ b/examples/NDSL/03_orchestration_basics.ipynb @@ -24,7 +24,7 @@ "outputs": [], "source": [ "import numpy as np\n", - "from gt4py.cartesian.gtscript import (\n", + "from ndsl.dsl.gt4py import (\n", " PARALLEL,\n", " computation,\n", " interval,\n", diff --git a/external/gt4py b/external/gt4py index 1ba0a972..45324c88 160000 --- 
a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit 1ba0a97282037a6756f5da23d207a362383e5743 +Subproject commit 45324c88e57b5e8dfc974efa70fa2f2e5e10677f diff --git a/ndsl/__init__.py b/ndsl/__init__.py index 5ec303de..f6bb6117 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -1,3 +1,4 @@ +from . import dsl # isort:skip from .comm.communicator import CubedSphereCommunicator, TileCommunicator from .comm.local_comm import LocalComm from .comm.mpi import MPIComm @@ -28,6 +29,7 @@ from .performance.profiler import NullProfiler, Profiler from .performance.report import Experiment, Report, TimeReport from .quantity import Quantity +from .quantity.field_bundle import FieldBundle, FieldBundleType # Break circular import from .testing.dummy_comm import DummyComm from .types import Allocator from .utils import MetaEnumStr diff --git a/ndsl/checkpointer/snapshots.py b/ndsl/checkpointer/snapshots.py index aa806b21..1447c5fb 100644 --- a/ndsl/checkpointer/snapshots.py +++ b/ndsl/checkpointer/snapshots.py @@ -1,10 +1,10 @@ import collections import numpy as np +import xarray as xr from ndsl.checkpointer.base import Checkpointer from ndsl.optional_imports import cupy as cp -from ndsl.optional_imports import xarray as xr def make_dims(savepoint_dim, label, data_list): @@ -39,12 +39,7 @@ def dataset(self) -> "xr.Dataset": data_vars[f"{variable_name}"] = make_dims( savepoint_dim, variable_name, self._arrays[variable_name] ) - if xr is None: - raise ModuleNotFoundError( - "xarray must be installed to use Snapshots.dataset" - ) - else: - return xr.Dataset(data_vars=data_vars) + return xr.Dataset(data_vars=data_vars) class SnapshotCheckpointer(Checkpointer): @@ -54,10 +49,6 @@ class SnapshotCheckpointer(Checkpointer): """ def __init__(self, rank: int): - if xr is None: - raise ModuleNotFoundError( - "xarray must be installed to use SnapshotCheckpointer" - ) self._rank = rank self._snapshots = _Snapshots() diff --git a/ndsl/checkpointer/thresholds.py 
b/ndsl/checkpointer/thresholds.py index ded73b39..fbf0e956 100644 --- a/ndsl/checkpointer/thresholds.py +++ b/ndsl/checkpointer/thresholds.py @@ -9,12 +9,6 @@ from ndsl.quantity import Quantity -try: - import cupy as cp -except ImportError: - cp = None - - SavepointName = str VariableName = str ArrayLike = Union[Quantity, np.ndarray] diff --git a/ndsl/checkpointer/validation.py b/ndsl/checkpointer/validation.py index 8af11317..00da3d13 100644 --- a/ndsl/checkpointer/validation.py +++ b/ndsl/checkpointer/validation.py @@ -4,6 +4,7 @@ from typing import MutableMapping, Tuple import numpy as np +import xarray as xr from ndsl.checkpointer.base import Checkpointer from ndsl.checkpointer.thresholds import ( @@ -12,7 +13,6 @@ SavepointThresholds, cast_to_ndarray, ) -from ndsl.optional_imports import xarray as xr def _clip_pace_array_to_target( @@ -109,8 +109,6 @@ def __call__(self, savepoint_name: str, **kwargs: ArrayLike) -> None: Raises: AssertionError: if the thresholds on any variable are not met """ - if xr is None: - raise ModuleNotFoundError("xarray is not installed") nc_file = os.path.join(self._savepoint_data_path, savepoint_name + ".nc") ds = xr.open_dataset(nc_file) diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index 014142da..c523affb 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -10,17 +10,12 @@ from ndsl.comm.comm_abc import ReductionOperator from ndsl.comm.partitioner import CubedSpherePartitioner, Partitioner, TilePartitioner from ndsl.halo.updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater +from ndsl.optional_imports import cupy from ndsl.performance.timer import NullTimer, Timer from ndsl.quantity import Quantity, QuantityHaloSpec, QuantityMetadata from ndsl.types import NumpyModule -try: - import cupy -except ImportError: - cupy = None - - def to_numpy(array, dtype=None) -> np.ndarray: """ Input array can be a numpy array or a cupy array. Returns numpy array. 
diff --git a/ndsl/comm/mpi.py b/ndsl/comm/mpi.py index 3c466950..0b4a5540 100644 --- a/ndsl/comm/mpi.py +++ b/ndsl/comm/mpi.py @@ -1,10 +1,7 @@ -try: - import mpi4py - from mpi4py import MPI -except ImportError: - MPI = None from typing import Dict, List, Optional, TypeVar, cast +from mpi4py import MPI + from ndsl.comm.comm_abc import Comm, ReductionOperator, Request @@ -12,22 +9,22 @@ class MPIComm(Comm): - _op_mapping: Dict[ReductionOperator, mpi4py.MPI.Op] = { - ReductionOperator.OP_NULL: mpi4py.MPI.OP_NULL, - ReductionOperator.MAX: mpi4py.MPI.MAX, - ReductionOperator.MIN: mpi4py.MPI.MIN, - ReductionOperator.SUM: mpi4py.MPI.SUM, - ReductionOperator.PROD: mpi4py.MPI.PROD, - ReductionOperator.LAND: mpi4py.MPI.LAND, - ReductionOperator.BAND: mpi4py.MPI.BAND, - ReductionOperator.LOR: mpi4py.MPI.LOR, - ReductionOperator.BOR: mpi4py.MPI.BOR, - ReductionOperator.LXOR: mpi4py.MPI.LXOR, - ReductionOperator.BXOR: mpi4py.MPI.BXOR, - ReductionOperator.MAXLOC: mpi4py.MPI.MAXLOC, - ReductionOperator.MINLOC: mpi4py.MPI.MINLOC, - ReductionOperator.REPLACE: mpi4py.MPI.REPLACE, - ReductionOperator.NO_OP: mpi4py.MPI.NO_OP, + _op_mapping: Dict[ReductionOperator, MPI.Op] = { + ReductionOperator.OP_NULL: MPI.OP_NULL, + ReductionOperator.MAX: MPI.MAX, + ReductionOperator.MIN: MPI.MIN, + ReductionOperator.SUM: MPI.SUM, + ReductionOperator.PROD: MPI.PROD, + ReductionOperator.LAND: MPI.LAND, + ReductionOperator.BAND: MPI.BAND, + ReductionOperator.LOR: MPI.LOR, + ReductionOperator.BOR: MPI.BOR, + ReductionOperator.LXOR: MPI.LXOR, + ReductionOperator.BXOR: MPI.BXOR, + ReductionOperator.MAXLOC: MPI.MAXLOC, + ReductionOperator.MINLOC: MPI.MINLOC, + ReductionOperator.REPLACE: MPI.REPLACE, + ReductionOperator.NO_OP: MPI.NO_OP, } def __init__(self): @@ -84,4 +81,4 @@ def Allreduce(self, sendobj_or_inplace: T, recvobj: T, op: ReductionOperator) -> return self._comm.Allreduce(sendobj_or_inplace, recvobj, self._op_mapping[op]) def Allreduce_inplace(self, recvobj: T, op: ReductionOperator) -> T: - 
return self._comm.Allreduce(mpi4py.MPI.IN_PLACE, recvobj, self._op_mapping[op]) + return self._comm.Allreduce(MPI.IN_PLACE, recvobj, self._op_mapping[op]) diff --git a/ndsl/constants.py b/ndsl/constants.py index eb89bd17..7092f361 100644 --- a/ndsl/constants.py +++ b/ndsl/constants.py @@ -158,8 +158,10 @@ class ConstantVersions(Enum): raise RuntimeError("Constant selector failed, bad code.") SECONDS_PER_DAY = Float(86400.0) -SBC = 5.670400e-8 +SBC = Float(5.670400e-8) """Stefan-Boltzmann constant (W/m^2/K^4)""" +RHO_H2O = Float(1000.0) +"""Density of water in kg/m^3""" CV_AIR = CP_AIR - RDGAS """Heat capacity of dry air at constant volume""" RDG = -RDGAS / GRAV @@ -200,11 +202,13 @@ class ConstantVersions(Enum): """Saturation vapor pressure at H2O 3pt (Pa)""" T_WFR = TICE - Float(40.0) """homogeneous freezing temperature""" -TICE0 = TICE - Float(0.01) +TICE0 = Float(2.7315e2) +""" Temp at 0C""" T_MIN = Float(178.0) """Minimum temperature to freeze-dry all water vapor""" T_SAT_MIN = TICE - Float(160.0) +"""Minimum temperature used in saturation calculations""" LAT2 = np.power((HLV + HLF), 2, dtype=Float) """Used in bigg mechanism""" -TTP = 2.7316e2 +TTP = Float(2.7316e2) """Temperature of H2O triple point""" diff --git a/ndsl/debug/__init__.py b/ndsl/debug/__init__.py new file mode 100644 index 00000000..4f255c54 --- /dev/null +++ b/ndsl/debug/__init__.py @@ -0,0 +1,4 @@ +from .config import ndsl_debugger + + +__all__ = ["ndsl_debugger"] diff --git a/ndsl/debug/config.py b/ndsl/debug/config.py new file mode 100644 index 00000000..94ee7119 --- /dev/null +++ b/ndsl/debug/config.py @@ -0,0 +1,51 @@ +""" +This module provides configuration for the global debugger `ndsl_debugger` + +When loading, the configuration will be searched in the global environment variable +`NDSL_DEBUG_CONFIG` + +Configuration is a yaml file of the shape +```yaml +stencils_or_class: + - copy_corners_x_nord + - copy_corners_y_nord + - DGridShallowWaterLagrangianDynamics.__call__ 
+track_parameter_by_name: + - fy +``` + +Global variable: + ndsl_debugger: Debugger accessible throughout the middleware, default to `None` + if there is no configuration +""" + +import os + +import yaml + +from ndsl.comm.mpi import MPIComm +from ndsl.debug.debugger import Debugger +from ndsl.logging import ndsl_log + + +ndsl_debugger = None + + +def _set_debugger(): + config = os.getenv("NDSL_DEBUG_CONFIG", "") + if not os.path.exists(config): + if config != "": + ndsl_log.warning( + f"NDSL_DEBUG_CONFIG set but path {config} does not exists." + ) + else: + return + with open(config) as file: + config_dict = yaml.load(file.read(), Loader=yaml.SafeLoader) + global ndsl_debugger + ndsl_debugger = Debugger(rank=MPIComm().Get_rank(), **config_dict) + ndsl_log.info("[NDSL Debugger] On") + ndsl_log.debug(f"[NDSL Debugger] Config:\n{config_dict}") + + +_set_debugger() diff --git a/ndsl/debug/debugger.py b/ndsl/debug/debugger.py new file mode 100644 index 00000000..7e1f60fe --- /dev/null +++ b/ndsl/debug/debugger.py @@ -0,0 +1,109 @@ +import dataclasses +import numbers +import os +import pathlib + +import pandas as pd +import xarray as xr + +from ndsl.logging import ndsl_log +from ndsl.quantity import Quantity + + +@dataclasses.dataclass +class Debugger: + """Debugger relying on `ndsl.debug.config` for setup capable + of doing automatic data save on external configuration.""" + + # Configuration + stencils_or_class: list[str] = dataclasses.field(default_factory=list) + track_parameter_by_name: list[str] = dataclasses.field(default_factory=list) + save_compute_domain_only: bool = False + dir_name: str = "./" + + # Runtime data + rank: int = -1 + calls_count: dict[str, int] = dataclasses.field(default_factory=dict) + track_parameter_count: dict[str, int] = dataclasses.field(default_factory=dict) + + def _to_xarray(self, data, name) -> xr.DataArray: + if isinstance(data, Quantity): + if self.save_compute_domain_only: + mem = data.field + shp = data.field.shape + else: + mem = 
data.data + shp = data.shape + elif hasattr(data, "shape"): + mem = data + shp = data.shape + elif ( + pd.api.types.is_numeric_dtype(data) + or pd.api.types.is_string_dtype(data) + or isinstance(data, numbers.Number) + ): + return xr.DataArray(data) + else: + ndsl_log.error(f"[Debugger] Cannot save data of type {type(data)}") + return xr.DataArray([0]) + return xr.DataArray(mem, dims=[f"dim_{i}_{s}" for i, s in enumerate(shp)]) + + def track_data(self, data_as_dict, source_as_name, is_in) -> None: + for name, data in data_as_dict.items(): + if name not in self.track_parameter_by_name: + continue + + if name not in self.track_parameter_count: + self.track_parameter_count[name] = 0 + count = self.track_parameter_count[name] + + path = pathlib.Path(f"{self.dir_name}/debug/tracks/{name}/R{self.rank}/") + os.makedirs(path, exist_ok=True) + path = pathlib.Path( + f"{path}/{count}_{name}_{source_as_name}-{'In' if is_in else 'Out'}.nc4" + ) + try: + self._to_xarray(data, name).to_netcdf(path) + except ValueError as e: + from ndsl import ndsl_log + + ndsl_log.error(f"[Debugger] Failure to save {data}: {e}") + + self.track_parameter_count[name] += 1 + + def save_as_dataset(self, data_as_dict, savename, is_in) -> None: + """Save dictionary of data to NetCDF + + Note: Unknown types in the dictionary won't be saved.
+ """ + if savename not in self.stencils_or_class: + return + + data_arrays = {} + for name, data in data_as_dict.items(): + if dataclasses.is_dataclass(data): + for field in dataclasses.fields(data): + data_arrays[f"{name}.{field.name}"] = self._to_xarray( + getattr(data, field.name), field.name + ) + else: + data_arrays[name] = self._to_xarray(data, name) + + call_count = ( + self.calls_count[savename] if savename in self.calls_count.keys() else 0 + ) + path = pathlib.Path(f"{self.dir_name}/debug/savepoints/R{self.rank}/") + os.makedirs(path, exist_ok=True) + path = pathlib.Path( + f"{path}/{savename}-Call{call_count}-{'In' if is_in else 'Out'}.nc4" + ) + try: + xr.Dataset(data_arrays).to_netcdf(path) + except ValueError as e: + ndsl_log.error(f"[DebugInfo] Failure to save {savename}: {e}") + + def increment_call_count(self, savename: str): + """Increment the call count for this savename""" + if savename not in self.calls_count.keys(): + self.calls_count[savename] = 0 + self.calls_count[savename] += 1 diff --git a/ndsl/debug/tooling.py b/ndsl/debug/tooling.py new file mode 100644 index 00000000..4ec89bbc --- /dev/null +++ b/ndsl/debug/tooling.py @@ -0,0 +1,44 @@ +import inspect +from functools import wraps +from typing import Any, Callable + +from ndsl.debug.config import ndsl_debugger + + +def instrument(func) -> Callable: + @wraps(func) + def wrapper(self, *args: Any, **kwargs: Any): + if ndsl_debugger is None: + return func(self, *args, **kwargs) + savename = func.__qualname__ + params = inspect.signature(func).parameters + data_as_dict = {} + + # Positional + positional_count = 0 + for name, param in params.items(): + if param.kind in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + if positional_count == 0: # self + positional_count += 1 + continue + if positional_count < len(args) + 1: + data_as_dict[name] = args[positional_count - 1] + positional_count += 1 + # Keyword arguments + for name, value in kwargs.items(): + 
if name in params: + data_as_dict[name] = value + if ndsl_debugger is not None: + ndsl_debugger.save_as_dataset(data_as_dict, func.__qualname__, is_in=True) + ndsl_debugger.track_data(data_as_dict, func.__qualname__, is_in=True) + r = func(self, *args, **kwargs) + if ndsl_debugger is not None: + ndsl_debugger.save_as_dataset(data_as_dict, func.__qualname__, is_in=False) + ndsl_debugger.track_data(data_as_dict, func.__qualname__, is_in=False) + ndsl_debugger.increment_call_count(savename) + return r + + return wrapper diff --git a/ndsl/dsl/__init__.py b/ndsl/dsl/__init__.py index b7034e07..e3fe0cc8 100644 --- a/ndsl/dsl/__init__.py +++ b/ndsl/dsl/__init__.py @@ -1,8 +1,25 @@ -import gt4py.cartesian.config +# Literal precision for both GT4Py & NDSL +import os +import sys from ndsl.comm.mpi import MPI +gt4py_config_module = "gt4py.cartesian.config" +if gt4py_config_module in sys.modules: + raise RuntimeError( + "`GT4Py` config imported before `ndsl` imported." + " Please import `ndsl.dsl` or any `ndsl` module " + " before any `gt4py` imports." 
+ ) +NDSL_GLOBAL_PRECISION = int(os.getenv("PACE_FLOAT_PRECISION", "64")) +os.environ["GT4PY_LITERAL_PRECISION"] = str(NDSL_GLOBAL_PRECISION) + + +# Set cache names for default gt backends workflow +import gt4py.cartesian.config # noqa: E402 + + if MPI is not None: import os diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index c09b69cc..5c30367f 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -32,12 +32,7 @@ report_memory_static_analysis, ) from ndsl.logging import ndsl_log - - -try: - import cupy as cp -except ImportError: - cp = None +from ndsl.optional_imports import cupy as cp def dace_inhibitor(func: Callable) -> Callable: @@ -111,7 +106,6 @@ def _simplify( validate=validate, validate_all=validate_all, verbose=verbose, - skip=["ConstantPropagation"], ).apply_pass(sdfg, {}) diff --git a/ndsl/dsl/gt4py/__init__.py b/ndsl/dsl/gt4py/__init__.py new file mode 100644 index 00000000..7c051fb0 --- /dev/null +++ b/ndsl/dsl/gt4py/__init__.py @@ -0,0 +1,53 @@ +# Import gt4py functions, all future references to gt4py should come through here +from gt4py.cartesian.gtscript import ( + BACKWARD, + FORWARD, + IJ, + IJK, + IK, + JK, + PARALLEL, + Field, + GlobalTable, + I, + J, + K, + Sequence, + abs, + acos, + acosh, + asin, + asinh, + atan, + atanh, + cbrt, + ceil, + compile_assert, + computation, + cos, + cosh, + exp, + externals, + floor, + function, + gamma, + horizontal, + interval, + isfinite, + isinf, + isnan, + log, + log10, + max, + min, + mod, + region, + sin, + sinh, + sqrt, + stencil, + tan, + tanh, + trunc, + types, +) diff --git a/ndsl/dsl/gt4py_utils.py b/ndsl/dsl/gt4py_utils.py index 6c0254e9..5f3619e7 100644 --- a/ndsl/dsl/gt4py_utils.py +++ b/ndsl/dsl/gt4py_utils.py @@ -1,19 +1,16 @@ from functools import wraps from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union -import gt4py import numpy as np +from gt4py import storage as gt_storage +from gt4py.cartesian import 
backend as gt_backend from ndsl.constants import N_HALO_DEFAULT from ndsl.dsl.typing import DTypes, Field, Float from ndsl.logging import ndsl_log +from ndsl.optional_imports import cupy as cp -try: - import cupy as cp -except ImportError: - cp = None - # If True, automatically transfers memory between CPU and GPU (see gt4py.storage) managed_memory = True @@ -188,7 +185,7 @@ def make_storage_data( backend=backend, ) - storage = gt4py.storage.from_array( + storage = gt_storage.from_array( data, dtype, backend=backend, @@ -340,7 +337,7 @@ def make_storage_from_shape( mask = (False, False, True) # Assume 1D is a k-field else: mask = (n_dims * (True,)) + ((3 - n_dims) * (False,)) - storage = gt4py.storage.zeros( + storage = gt_storage.zeros( shape, dtype, backend=backend, @@ -448,7 +445,7 @@ def asarray(array, to_type=np.ndarray, dtype=None, order=None): def is_gpu_backend(backend: str) -> bool: - return gt4py.cartesian.backend.from_name(backend).storage_info["device"] == "gpu" + return gt_backend.from_name(backend).storage_info["device"] == "gpu" def zeros(shape, dtype=Float, *, backend: str): diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index daf78091..d829bbd0 100644 --- a/ndsl/dsl/stencil.py +++ b/ndsl/dsl/stencil.py @@ -17,31 +17,29 @@ ) import dace -import gt4py import numpy as np +from gt4py.cartesian import config as gt_config +from gt4py.cartesian import definitions as gt_definitions from gt4py.cartesian import gtscript from gt4py.cartesian.gtc.passes.oir_pipeline import DefaultPipeline, OirPipeline +from gt4py.cartesian.stencil_object import StencilObject from ndsl.comm.comm_abc import Comm from ndsl.comm.communicator import Communicator from ndsl.comm.decomposition import block_waiting_for_compilation, unblock_waiting_tiles from ndsl.comm.mpi import MPI from ndsl.constants import X_DIM, X_DIMS, Y_DIM, Y_DIMS, Z_DIM, Z_DIMS +from ndsl.debug import ndsl_debugger from ndsl.dsl.dace.orchestration import SDFGConvertible from ndsl.dsl.stencil_config 
import CompilationConfig, RunMode, StencilConfig from ndsl.dsl.typing import Float, Index3D, cast_to_index3d from ndsl.initialization.sizer import GridSizer, SubtileGridSizer from ndsl.logging import ndsl_log from ndsl.quantity import Quantity +from ndsl.quantity.field_bundle import FieldBundleType, MarkupFieldBundleType from ndsl.testing.comparison import LegacyMetric -try: - import cupy as cp -except ImportError: - cp = np - - def report_difference(args, kwargs, args_copy, kwargs_copy, function_name, gt_id): report_head = f"comparing against numpy for func {function_name}, gt_id {gt_id}:" report_segments = [] @@ -294,10 +292,11 @@ def __init__( externals = {} self.externals = externals self._func_name = func.__name__ + self._func_qualname = func.__qualname__ stencil_kwargs = self.stencil_config.stencil_kwargs( skip_passes=skip_passes, func=func ) - self.stencil_object = None + self.stencil_object: StencilObject | None = None self._argument_names = tuple(inspect.getfullargspec(func).args) @@ -305,8 +304,8 @@ def __init__( dace.Config.set( "default_build_folder", value="{gt_root}/{gt_cache}/dacecache".format( - gt_root=gt4py.cartesian.config.cache_settings["root_path"], - gt_cache=gt4py.cartesian.config.cache_settings["dir_name"], + gt_root=gt_config.cache_settings["root_path"], + gt_cache=gt_config.cache_settings["dir_name"], ), ) @@ -335,13 +334,21 @@ def __init__( ): block_waiting_for_compilation(MPI.COMM_WORLD, compilation_config) + # Field Bundle might have dropped a placeholder type that we now + # have to resolve to the proper type. 
+ for name, types in func.__annotations__.items(): + if isinstance(types, MarkupFieldBundleType): + func.__annotations__[name] = FieldBundleType.T( + types.name, do_markup=False + ) + self.stencil_object = gtscript.stencil( definition=func, externals=externals, dtypes={float: Float}, **stencil_kwargs, build_info=(build_info := {}), - ) + ) # type: ignore if ( compilation_config.use_minimal_caching @@ -375,13 +382,17 @@ def nothing_function(*args, **kwargs): setattr(self, "__call__", nothing_function) def __call__(self, *args, **kwargs) -> None: + # Verbose stencil execution if self.stencil_config.verbose: ndsl_log.debug(f"Running {self._func_name}") + + # Marshal arguments args_list = list(args) _convert_quantities_to_storage(args_list, kwargs) args = tuple(args_list) - args_as_kwargs = dict(zip(self._argument_names, args)) + + # Ranks comparison tool if self.comm is not None: differences = compare_ranks(self.comm, {**args_as_kwargs, **kwargs}) if len(differences) > 0: @@ -389,6 +400,14 @@ def __call__(self, *args, **kwargs) -> None: f"rank {self.comm.Get_rank()} has differences {differences} " f"before calling {self._func_name}" ) + + # Debugger actions if turned on + if ndsl_debugger: + all_args = args_as_kwargs | kwargs + ndsl_debugger.save_as_dataset(all_args, self._func_qualname, is_in=True) + ndsl_debugger.track_data(all_args, self._func_qualname, is_in=True) + + # Execute stencil if self.stencil_config.compilation_config.validate_args: if __debug__ and "origin" in kwargs: raise TypeError("origin cannot be passed to FrozenStencil call") @@ -401,7 +420,7 @@ def __call__(self, *args, **kwargs) -> None: domain=self.domain, validate_args=True, exec_info=self._timing_collector.exec_info, - ) + ) # type: ignore else: self.stencil_object.run( **args_as_kwargs, @@ -409,6 +428,15 @@ def __call__(self, *args, **kwargs) -> None: **self._stencil_run_kwargs, exec_info=self._timing_collector.exec_info, ) + + # Debugger actions if turned on + if ndsl_debugger: + all_args = 
args_as_kwargs | kwargs + ndsl_debugger.save_as_dataset(all_args, self._func_qualname, is_in=False) + ndsl_debugger.track_data(all_args, self._func_qualname, is_in=False) + ndsl_debugger.increment_call_count(self._func_qualname) + + # Ranks comparison tool if self.comm is not None: differences = compare_ranks(self.comm, {**args_as_kwargs, **kwargs}) if len(differences) > 0: @@ -466,7 +494,7 @@ def _get_written_fields(cls, field_info) -> List[str]: if field_info[field_name] and bool( field_info[field_name].access - & gt4py.cartesian.definitions.AccessKind.WRITE # type: ignore + & gt_definitions.AccessKind.WRITE # type: ignore ) ] return write_fields diff --git a/ndsl/dsl/typing.py b/ndsl/dsl/typing.py index 1cae1063..5f60f401 100644 --- a/ndsl/dsl/typing.py +++ b/ndsl/dsl/typing.py @@ -1,8 +1,8 @@ import os from typing import Tuple, TypeAlias, Union, cast -import gt4py.cartesian.gtscript as gtscript import numpy as np +from gt4py.cartesian import gtscript # A Field @@ -22,11 +22,6 @@ DTypes = Union[bool, np.bool_, int, np.int32, np.int64, float, np.float32, np.float64] -# Depreciated version of get_precision, but retained for a PACE dependency -def floating_point_precision() -> int: - return int(os.getenv("PACE_FLOAT_PRECISION", "64")) - - def get_precision() -> int: return int(os.getenv("PACE_FLOAT_PRECISION", "64")) diff --git a/ndsl/grid/generation.py b/ndsl/grid/generation.py index c32ceb3f..7b28c2ff 100644 --- a/ndsl/grid/generation.py +++ b/ndsl/grid/generation.py @@ -296,7 +296,8 @@ def __init__( self._dy_agrid = None self._dx_center = None self._dy_center = None - self._area = None + self._area: Optional[Quantity] = None + self._area64: Optional[Quantity] = None self._area_c = None if eta_file is not None or ak is not None or bk is not None: ( @@ -1490,9 +1491,18 @@ def area(self) -> Quantity: the area of each a-grid cell """ if self._area is None: - self._area = self._compute_area() + self._area, self._area64 = self._compute_area() return self._area + 
@property + def area64(self) -> Quantity: + """ + the area of each a-grid cell, at 64-bit precision + """ + if self._area64 is None: + self._area, self._area64 = self._compute_area() + return self._area64 + @property def area_c(self) -> Quantity: """ @@ -2089,7 +2099,7 @@ def _compute_dxdy_center_cartesian(self): return dx_center, dy_center - def _compute_area_cube_sphere(self): + def _compute_area_cube_sphere(self) -> tuple[Quantity, Quantity]: area_64 = self.quantity_factory.zeros( [X_DIM, Y_DIM], "m^2", @@ -2106,9 +2116,9 @@ def _compute_area_cube_sphere(self): ) self._comm.halo_update(area_64, n_points=self._halo) - return quantity_cast_to_model_float(self.quantity_factory, area_64) + return quantity_cast_to_model_float(self.quantity_factory, area_64), area_64 - def _compute_area_cartesian(self): + def _compute_area_cartesian(self) -> tuple[Quantity, Quantity]: area_64 = self.quantity_factory.zeros( [X_DIM, Y_DIM], "m^2", @@ -2116,7 +2126,7 @@ def _compute_area_cartesian(self): allow_mismatch_float_precision=True, ) area_64.data[:, :] = self._dx_const * self._dy_const - return quantity_cast_to_model_float(self.quantity_factory, area_64) + return quantity_cast_to_model_float(self.quantity_factory, area_64), area_64 def _compute_area_c_cube_sphere(self): area_cgrid_64 = self.quantity_factory.zeros( diff --git a/ndsl/grid/helper.py b/ndsl/grid/helper.py index 1a82d053..2fbc34a3 100644 --- a/ndsl/grid/helper.py +++ b/ndsl/grid/helper.py @@ -7,7 +7,7 @@ # TODO: if we can remove translate tests in favor of checkpointer tests, # we can remove this "disallowed" import (ndsl.util does not depend on ndsl.dsl) try: - from ndsl.dsl.gt4py_utils import split_cartesian_into_storages + from ndsl.dsl.gt4py_utils import is_gpu_backend, split_cartesian_into_storages except ImportError: split_cartesian_into_storages = None import ndsl.constants as constants @@ -93,7 +93,7 @@ def new_from_metric_terms(cls, metric_terms: MetricTerms) -> "HorizontalGridData 
lon_agrid=metric_terms.lon_agrid, lat_agrid=metric_terms.lat_agrid, area=metric_terms.area, - area_64=metric_terms.area, + area_64=metric_terms.area64, rarea=metric_terms.rarea, rarea_c=metric_terms.rarea_c, dx=metric_terms.dx, @@ -233,7 +233,10 @@ def ptop(self) -> Float: """ if self.bk.view[0] != 0: raise ValueError("ptop is not well-defined when top-of-atmosphere bk != 0") - return Float(self.ak.view[0]) + if is_gpu_backend(self.ak.gt4py_backend): + return Float(self.ak.view[0].get()) + else: + return Float(self.ak.view[0]) @dataclasses.dataclass(frozen=True) diff --git a/ndsl/halo/data_transformer.py b/ndsl/halo/data_transformer.py index 9f9ab2f6..f3133974 100644 --- a/ndsl/halo/data_transformer.py +++ b/ndsl/halo/data_transformer.py @@ -225,9 +225,12 @@ def finalize(self): self.synchronize() # Push the buffers back in the cache - Buffer.push_to_cache(self._pack_buffer) + if self._pack_buffer is not None: + Buffer.push_to_cache(self._pack_buffer) self._pack_buffer = None - Buffer.push_to_cache(self._unpack_buffer) + + if self._unpack_buffer is not None: + Buffer.push_to_cache(self._unpack_buffer) self._unpack_buffer = None @staticmethod diff --git a/ndsl/initialization/allocator.py b/ndsl/initialization/allocator.py index 869086f6..85ee17dd 100644 --- a/ndsl/initialization/allocator.py +++ b/ndsl/initialization/allocator.py @@ -1,11 +1,11 @@ from typing import Callable, Optional, Sequence import numpy as np +from gt4py import storage as gt_storage from ndsl.constants import SPATIAL_DIMS from ndsl.dsl.typing import Float from ndsl.initialization.sizer import GridSizer -from ndsl.optional_imports import gt4py from ndsl.quantity import Quantity, QuantityHaloSpec @@ -20,13 +20,13 @@ def __init__(self, backend: str): self.backend = backend def empty(self, *args, **kwargs) -> np.ndarray: - return gt4py.storage.empty(*args, backend=self.backend, **kwargs) + return gt_storage.empty(*args, backend=self.backend, **kwargs) def ones(self, *args, **kwargs) -> np.ndarray: - 
return gt4py.storage.ones(*args, backend=self.backend, **kwargs) + return gt_storage.ones(*args, backend=self.backend, **kwargs) def zeros(self, *args, **kwargs) -> np.ndarray: - return gt4py.storage.zeros(*args, backend=self.backend, **kwargs) + return gt_storage.zeros(*args, backend=self.backend, **kwargs) class QuantityFactory: @@ -112,6 +112,28 @@ def from_array( base.data[:] = base.np.asarray(data) return base + def from_compute_array( + self, + data: np.ndarray, + dims: Sequence[str], + units: str, + allow_mismatch_float_precision: bool = False, + ): + """ + Create a Quantity from a numpy array. + + That numpy array must correspond to the correct shape and extent + of the compute domain for the given dims. + """ + base = self.zeros( + dims=dims, + units=units, + dtype=data.dtype, + allow_mismatch_float_precision=allow_mismatch_float_precision, + ) + base.view[:] = base.np.asarray(data) + return base + def _allocate( self, allocator: Callable, diff --git a/ndsl/io.py b/ndsl/io.py index b07248bb..0aa2b7d4 100644 --- a/ndsl/io.py +++ b/ndsl/io.py @@ -1,9 +1,9 @@ from typing import TextIO import cftime +import xarray as xr import ndsl.filesystem as filesystem -from ndsl.optional_imports import xarray as xr from ndsl.quantity import Quantity @@ -22,7 +22,7 @@ def to_xarray_dataset(state) -> xr.Dataset: data_vars = { - name: value.data_array for name, value in state.items() if name != "time" + name: value.data_as_xarray for name, value in state.items() if name != "time" } if "time" in state: data_vars["time"] = state["time"] @@ -47,7 +47,7 @@ def _extract_time(value: xr.DataArray) -> cftime.datetime: """Extract time value from read-in state.""" if value.ndim > 0: raise ValueError( - "State must be representative of a single scalar time. " f"Got {value}." + f"State must be representative of a single scalar time. Got {value}." 
) time = value.item() if not isinstance(time, cftime.datetime): @@ -69,7 +69,8 @@ def read_state(filename: str) -> dict: """ out_dict = {} with filesystem.open(filename, "rb") as f: - ds = xr.open_dataset(f, use_cftime=True) + time_coder = xr.coders.CFDatetimeCoder(use_cftime=True) + ds = xr.open_dataset(f, decode_times=time_coder) for name, value in ds.data_vars.items(): if name == "time": out_dict[name] = _extract_time(value) diff --git a/ndsl/monitor/netcdf_monitor.py b/ndsl/monitor/netcdf_monitor.py index 204d7b94..945483d6 100644 --- a/ndsl/monitor/netcdf_monitor.py +++ b/ndsl/monitor/netcdf_monitor.py @@ -5,13 +5,13 @@ import fsspec import numpy as np +import xarray as xr from ndsl.comm.communicator import Communicator from ndsl.dsl.typing import Float, get_precision from ndsl.filesystem import get_fs from ndsl.logging import ndsl_log from ndsl.monitor.convert import to_numpy -from ndsl.optional_imports import xarray as xr from ndsl.quantity import Quantity diff --git a/ndsl/monitor/zarr_monitor.py b/ndsl/monitor/zarr_monitor.py index 214171be..20d5db7a 100644 --- a/ndsl/monitor/zarr_monitor.py +++ b/ndsl/monitor/zarr_monitor.py @@ -2,14 +2,13 @@ from typing import List, Tuple, Union import cftime +import xarray as xr import ndsl.constants as constants from ndsl.comm.partitioner import Partitioner, subtile_slice from ndsl.logging import ndsl_log from ndsl.monitor.convert import to_numpy -from ndsl.optional_imports import cupy -from ndsl.optional_imports import xarray as xr -from ndsl.optional_imports import zarr +from ndsl.optional_imports import cupy, zarr from ndsl.utils import list_by_dims diff --git a/ndsl/namelist.py b/ndsl/namelist.py index 8df5c207..e040c2cc 100644 --- a/ndsl/namelist.py +++ b/ndsl/namelist.py @@ -20,41 +20,68 @@ class NamelistDefaults: u_max = 350.0 do_f3d = False inline_q = False - do_skeb = False # save dissipation estimate + do_skeb = False + """Save dissipation estimate""" use_logp = False moist_phys = True check_negative = False 
- # gfdl_cloud_mucrophys.F90 - tau_r2g = 900.0 # rain freezing during fast_sat - tau_smlt = 900.0 # snow melting - tau_g2r = 600.0 # graupel melting to rain - tau_imlt = 600.0 # cloud ice melting - tau_i2s = 1000.0 # cloud ice to snow auto - conversion - tau_l2r = 900.0 # cloud water to rain auto - conversion - tau_g2v = 1200.0 # graupel sublimation - tau_v2g = 21600.0 # graupel deposition -- make it a slow process - sat_adj0 = 0.90 # adjustment factor (0: no, 1: full) during fast_sat_adj - ql_gen = 1.0e-3 # max new cloud water during remapping step if fast_sat_adj = .t. - ql_mlt = 2.0e-3 # max value of cloud water allowed from melted cloud ice - qs_mlt = 1.0e-6 # max cloud water due to snow melt - ql0_max = 2.0e-3 # max cloud water value (auto converted to rain) - t_sub = 184.0 # min temp for sublimation of cloud ice - qi_gen = 1.82e-6 # max cloud ice generation during remapping step - qi_lim = 1.0 # cloud ice limiter to prevent large ice build up - qi0_max = 1.0e-4 # max cloud ice value (by other sources) - rad_snow = True # consider snow in cloud fraction calculation - rad_rain = True # consider rain in cloud fraction calculation - rad_graupel = True # consider graupel in cloud fraction calculation - tintqs = False # use temperature in the saturation mixing in PDF - dw_ocean = 0.10 # base value for ocean - dw_land = 0.15 # base value for subgrid deviation / variability over land + # gfdl_cloud_microphys.F90 + tau_r2g = 900.0 + """rain freezing during fast_sat""" + tau_smlt = 900.0 + """snow melting""" + tau_g2r = 600.0 + """graupel melting to rain""" + tau_imlt = 600.0 + """cloud ice melting""" + tau_i2s = 1000.0 + """cloud ice to snow auto - conversion""" + tau_l2r = 900.0 + """cloud water to rain auto - conversion""" + tau_g2v = 1200.0 + """graupel sublimation""" + tau_v2g = 21600.0 + """graupel deposition -- make it a slow process""" + sat_adj0 = 0.90 + """adjustment factor (0: no, 1: full) during fast_sat_adj""" + ql_gen = 1.0e-3 + """max new cloud water 
during remapping step if fast_sat_adj = .t.""" + ql_mlt = 2.0e-3 + """max value of cloud water allowed from melted cloud ice""" + qs_mlt = 1.0e-6 + """max cloud water due to snow melt""" + ql0_max = 2.0e-3 + """max cloud water value (auto converted to rain)""" + t_sub = 184.0 + """min temp for sublimation of cloud ice""" + qi_gen = 1.82e-6 + """max cloud ice generation during remapping step""" + qi_lim = 1.0 + """cloud ice limiter to prevent large ice build up""" + qi0_max = 1.0e-4 + """max cloud ice value (by other sources)""" + rad_snow = True + """consider snow in cloud fraction calculation""" + rad_rain = True + """consider rain in cloud fraction calculation""" + rad_graupel = True + """consider graupel in cloud fraction calculation""" + tintqs = False + """use temperature in the saturation mixing in PDF""" + dw_ocean = 0.10 + """base value for ocean""" + dw_land = 0.15 + """base value for subgrid deviation / variability over land""" # cloud scheme 0 - ? # 1: old fvgfs gfdl) mp implementation # 2: binary cloud scheme (0 / 1) icloud_f = 0 - cld_min = 0.05 # !< minimum cloud fraction - tau_l2v = 300.0 # cloud water to water vapor (evaporation) - tau_v2l = 90.0 # water vapor to cloud water (condensation) + cld_min = 0.05 + """minimum cloud fraction""" + tau_l2v = 300.0 + """cloud water to water vapor (evaporation)""" + tau_v2l = 90.0 + """water vapor to cloud water (condensation)""" c2l_ord = 4 regional = False m_split = 0 @@ -63,70 +90,148 @@ class NamelistDefaults: use_old_omega = True use_logp = False rf_fast = False - p_ref = 1e5 # Surface pressure used to construct a horizontally-uniform reference + p_ref = 1e5 + """Surface pressure used to construct a horizontally-uniform reference""" adiabatic = False nf_omega = 1 fv_sg_adj = -1 n_sponge = 1 fast_sat_adj = True - qc_crt = 5.0e-8 # Minimum condensate mixing ratio to allow partial cloudiness - c_cracw = 0.8 # Rain accretion efficiency - c_paut = ( - 0.5 # Autoconversion cloud water to rain (use 0.5 to reduce 
autoconversion) - ) - c_pgacs = 0.01 # Snow to graupel "accretion" eff. (was 0.1 in zetac) - c_psaci = 0.05 # Accretion: cloud ice to snow (was 0.1 in zetac) - ccn_l = 300.0 # CCN over land (cm^-3) - ccn_o = 100.0 # CCN over ocean (cm^-3) - const_vg = False # Fall velocity tuning constant of graupel - const_vi = False # Fall velocity tuning constant of ice - const_vr = False # Fall velocity tuning constant of rain water - const_vs = False # Fall velocity tuning constant of snow - vi_fac = 1.0 # if const_vi: 1/3 - vs_fac = 1.0 # if const_vs: 1. - vg_fac = 1.0 # if const_vg: 2. - vr_fac = 1.0 # if const_vr: 4. - de_ice = False # To prevent excessive build-up of cloud ice from external sources - do_qa = True # Do inline cloud fraction - do_sedi_heat = False # Transport of heat in sedimentation - do_sedi_w = True # Transport of vertical motion in sedimentation - fix_negative = True # Fix negative water species - irain_f = 0 # Cloud water to rain auto conversion scheme - mono_prof = False # Perform terminal fall with mono ppm scheme - mp_time = 225.0 # Maximum microphysics timestep (sec) - prog_ccn = False # Do prognostic ccn (yi ming's method) - qi0_crt = 8e-05 # Cloud ice to snow autoconversion threshold - qs0_crt = 0.003 # Snow to graupel density threshold (0.6e-3 in purdue lin scheme) - rh_inc = 0.2 # RH increment for complete evaporation of cloud water and cloud ice - rh_inr = 0.3 # RH increment for minimum evaporation of rain - rthresh = 1e-05 # Critical cloud drop radius (micrometers) - sedi_transport = True # Transport of momentum in sedimentation - use_ppm = False # Use ppm fall scheme - vg_max = 16.0 # Maximum fall speed for graupel - vi_max = 1.0 # Maximum fall speed for ice - vr_max = 16.0 # Maximum fall speed for rain - vs_max = 2.0 # Maximum fall speed for snow - z_slope_ice = True # Use linear mono slope for autoconversions - z_slope_liq = True # Use linear mono slope for autoconversions - tice = 273.16 # set tice = 165. 
to turn off ice - phase phys (kessler emulator) - alin = 842.0 # "a" in lin1983 - clin = 4.8 # "c" in lin 1983, 4.8 -- > 6. (to enhance ql -- > qs) - isatmedmf = 0 # which version of satmedmfvdif to use - dspheat = False # flag for tke dissipative heating - xkzm_h = 1.0 # background vertical diffusion for heat q over ocean - xkzm_m = 1.0 # background vertical diffusion for momentum over ocean - xkzm_hl = 1.0 # background vertical diffusion for heat q over land - xkzm_ml = 1.0 # background vertical diffusion for momentum over land - xkzm_hi = 1.0 # background vertical diffusion for heat q over ice - xkzm_mi = 1.0 # background vertical diffusion for momentum over ice - xkzm_s = 1.0 # sigma threshold for background mom. diffusion - xkzm_lim = 0.01 # background vertical diffusion limit - xkzminv = 0.15 # diffusivity in inversion layers - xkgdx = 25.0e3 # background vertical diffusion threshold - rlmn = 30.0 # lower-limiter on asymtotic mixing length in satmedmfdiff - rlmx = 300.0 # upper-limiter on asymtotic mixing length in satmedmfdiff - do_dk_hb19 = False # flag for using hb19 background diff formula in satmedmfdiff - cap_k0_land = False # flag for applying limiter on background diff in inversion layer over land in satmedmfdiff + qc_crt = 5.0e-8 + """Minimum condensate mixing ratio to allow partial cloudiness""" + c_cracw = 0.8 + """Rain accretion efficiency""" + c_paut = 0.5 + """Autoconversion cloud water to rain (use 0.5 to reduce autoconversion)""" + c_pgacs = 0.01 + """Snow to graupel "accretion" eff. 
(was 0.1 in zetac)""" + c_psaci = 0.05 + """Accretion: cloud ice to snow (was 0.1 in zetac)""" + ccn_l = 300.0 + """CCN over land (cm^-3)""" + ccn_o = 100.0 + """CCN over ocean (cm^-3)""" + const_vg = False + """Fall velocity tuning constant of graupel""" + const_vi = False + """Fall velocity tuning constant of ice""" + const_vr = False + """Fall velocity tuning constant of rain water""" + const_vs = False + """Fall velocity tuning constant of snow""" + vi_fac = 1.0 + """if const_vi: 1/3""" + vs_fac = 1.0 + """if const_vs: 1.""" + vg_fac = 1.0 + """if const_vg: 2.""" + vr_fac = 1.0 + """if const_vr: 4.""" + de_ice = False + """To prevent excessive build-up of cloud ice from external sources""" + do_qa = True + """Do inline cloud fraction""" + do_sedi_heat = False + """Transport of heat in sedimentation""" + do_sedi_w = True + """Transport of vertical motion in sedimentation""" + fix_negative = True + """Fix negative water species""" + irain_f = 0 + """Cloud water to rain auto conversion scheme""" + mono_prof = False + """Perform terminal fall with mono ppm scheme""" + mp_time = 225.0 + """Maximum microphysics timestep (sec)""" + prog_ccn = False + """Do prognostic ccn (yi ming's method)""" + qi0_crt = 8e-05 + """Cloud ice to snow autoconversion threshold""" + qs0_crt = 0.003 + """Snow to graupel density threshold (0.6e-3 in purdue lin scheme)""" + rh_inc = 0.2 + """RH increment for complete evaporation of cloud water and cloud ice""" + rh_inr = 0.3 + """RH increment for minimum evaporation of rain""" + rthresh = 1e-05 + """Critical cloud drop radius (micrometers)""" + sedi_transport = True + """Transport of momentum in sedimentation""" + use_ppm = False + """Use ppm fall scheme""" + vg_max = 16.0 + """Maximum fall speed for graupel""" + vi_max = 1.0 + """Maximum fall speed for ice""" + vr_max = 16.0 + """Maximum fall speed for rain""" + vs_max = 2.0 + """Maximum fall speed for snow""" + z_slope_ice = True + """Use linear mono slope for autoconversions""" + 
z_slope_liq = True + """Use linear mono slope for autoconversions""" + tice = 273.16 + """set tice = 165. to turn off ice - phase phys (kessler emulator)""" + alin = 842.0 + """value for 'a' in lin1983""" + clin = 4.8 + """"c" in lin 1983, 4.8 -- > 6. (to enhance ql -- > qs)""" + mom4ice = False + lsm = 1 + redrag = False + isatmedmf = 0 + """which version of satmedmfvdif to use""" + dspheat = False + """flag for tke dissipative heating""" + xkzm_h = 1.0 + """background vertical diffusion for heat q over ocean""" + xkzm_m = 1.0 + """background vertical diffusion for momentum over ocean""" + xkzm_hl = 1.0 + """background vertical diffusion for heat q over land""" + xkzm_ml = 1.0 + """background vertical diffusion for momentum over land""" + xkzm_hi = 1.0 + """background vertical diffusion for heat q over ice""" + xkzm_mi = 1.0 + """background vertical diffusion for momentum over ice""" + xkzm_ho = 1.0 + """background vertical diffusion for heat q over ocean""" + xkzm_mo = 1.0 + """background vertical diffusion for momentum over ocean""" + xkzm_s = 1.0 + """sigma threshold for background mom. 
diffusion""" + xkzm_lim = 0.01 + """background vertical diffusion limit""" + xkzminv = 0.15 + """diffusivity in inversion layers""" + xkgdx = 25.0e3 + """background vertical diffusion threshold""" + rlmn = 30.0 + """lower-limiter on asymptotic mixing length in satmedmfdiff""" + rlmx = 300.0 + """upper-limiter on asymptotic mixing length in satmedmfdiff""" + do_dk_hb19 = False + """flag for using hb19 background diff formula in satmedmfdiff""" + cap_k0_land = True + """flag for applying limiter on background diff in inversion layer over land in satmedmfdiff""" + ncld = 1 + """choice of cloud scheme""" + c0s_shal = 0.002 + """c_e for shallow convection (Han and Pan, 2011, eq(6))""" + c1_shal = 5.0e-4 + """conversion parameter of detrainment from liquid water into convective precipitation""" + clam_shal = 0.3 + """conversion parameter of detrainment from liquid water into grid-scale cloud water""" + pgcon_shal = 0.55 + """control the reduction in momentum transport""" + asolfac_shal = 0.89 + """aerosol-aware parameter based on Lim & Hong (2012): asolfac = cx / c0s(=.002), cx = min([-0.7 ln(Nccn) + 24]*1.e-4, c0s), Nccn: CCN number concentration in cm^(-3); until a realistic Nccn is provided, typical Nccns are assumed, as Nccn=100 for sea and Nccn=7000 for land""" + lsoil = 4 + """Number of soil levels in land surface model""" + sw_dynamics = False + """flag for turning on shallow water conditions in dyn core""" @@ -309,6 +414,9 @@ class Namelist: tice: float = NamelistDefaults.tice alin: float = NamelistDefaults.alin clin: float = NamelistDefaults.clin + mom4ice: bool = NamelistDefaults.mom4ice + lsm: int = NamelistDefaults.lsm + redrag: bool = NamelistDefaults.redrag isatmedmf: int = NamelistDefaults.isatmedmf dspheat: bool = NamelistDefaults.dspheat xkzm_h: float = NamelistDefaults.xkzm_h @@ -317,6 +425,8 @@ class Namelist: xkzm_ml: float = NamelistDefaults.xkzm_ml xkzm_hi: float = NamelistDefaults.xkzm_hi xkzm_mi: float = 
NamelistDefaults.xkzm_mi + xkzm_ho: float = NamelistDefaults.xkzm_ho + xkzm_mo: float = NamelistDefaults.xkzm_mo xkzm_s: float = NamelistDefaults.xkzm_s xkzm_lim: float = NamelistDefaults.xkzm_lim xkzminv: float = NamelistDefaults.xkzminv @@ -325,8 +435,12 @@ class Namelist: rlmx: float = NamelistDefaults.rlmx do_dk_hb19: bool = NamelistDefaults.do_dk_hb19 cap_k0_land: bool = NamelistDefaults.cap_k0_land - # c0s_shal: Any - # c1_shal: Any + c0s_shal: float = NamelistDefaults.c0s_shal + c1_shal: float = NamelistDefaults.c1_shal + clam_shal: float = NamelistDefaults.clam_shal + pgcon_shal: float = NamelistDefaults.pgcon_shal + asolfac_shal: float = NamelistDefaults.asolfac_shal + ncld: int = NamelistDefaults.ncld # cal_pre: Any # cdmbgwd: Any # cnvcld: Any @@ -353,7 +467,6 @@ class Namelist: # ivegsrc: Any # ldiag3d: Any # lwhtr: Any - # ncld: int # nst_anl: Any # pdfcld: Any # pre_rad: Any @@ -416,69 +529,68 @@ class Namelist: u_max: float = NamelistDefaults.u_max do_f3d: bool = NamelistDefaults.do_f3d inline_q: bool = NamelistDefaults.inline_q - do_skeb: bool = NamelistDefaults.do_skeb # save dissipation estimate + do_skeb: bool = NamelistDefaults.do_skeb + """save dissipation estimate""" use_logp: bool = NamelistDefaults.use_logp moist_phys: bool = NamelistDefaults.moist_phys check_negative: bool = NamelistDefaults.check_negative # gfdl_cloud_microphys.F90 - tau_r2g: float = NamelistDefaults.tau_r2g # rain freezing during fast_sat - tau_smlt: float = NamelistDefaults.tau_smlt # snow melting - tau_g2r: float = NamelistDefaults.tau_g2r # graupel melting to rain - tau_imlt: float = NamelistDefaults.tau_imlt # cloud ice melting - tau_i2s: float = NamelistDefaults.tau_i2s # cloud ice to snow auto - conversion - tau_l2r: float = NamelistDefaults.tau_l2r # cloud water to rain auto - conversion - tau_g2v: float = NamelistDefaults.tau_g2v # graupel sublimation - tau_v2g: float = ( - NamelistDefaults.tau_v2g - ) # graupel deposition -- make it a slow process - sat_adj0: 
float = ( - NamelistDefaults.sat_adj0 - ) # adjustment factor (0: no 1: full) during fast_sat_adj - ql_gen: float = ( - 1.0e-3 # max new cloud water during remapping step if fast_sat_adj = .t. - ) - ql_mlt: float = ( - NamelistDefaults.ql_mlt - ) # max value of cloud water allowed from melted cloud ice - qs_mlt: float = NamelistDefaults.qs_mlt # max cloud water due to snow melt - ql0_max: float = ( - NamelistDefaults.ql0_max - ) # max cloud water value (auto converted to rain) - t_sub: float = NamelistDefaults.t_sub # min temp for sublimation of cloud ice - qi_gen: float = ( - NamelistDefaults.qi_gen - ) # max cloud ice generation during remapping step - qi_lim: float = ( - NamelistDefaults.qi_lim - ) # cloud ice limiter to prevent large ice build up - qi0_max: float = NamelistDefaults.qi0_max # max cloud ice value (by other sources) - rad_snow: bool = ( - NamelistDefaults.rad_snow - ) # consider snow in cloud fraction calculation - rad_rain: bool = ( - NamelistDefaults.rad_rain - ) # consider rain in cloud fraction calculation - rad_graupel: bool = ( - NamelistDefaults.rad_graupel - ) # consider graupel in cloud fraction calculation - tintqs: bool = ( - NamelistDefaults.tintqs - ) # use temperature in the saturation mixing in PDF - dw_ocean: float = NamelistDefaults.dw_ocean # base value for ocean - dw_land: float = ( - NamelistDefaults.dw_land - ) # base value for subgrid deviation / variability over land + tau_r2g: float = NamelistDefaults.tau_r2g + """rain freezing during fast_sat""" + tau_smlt: float = NamelistDefaults.tau_smlt + """snow melting""" + tau_g2r: float = NamelistDefaults.tau_g2r + """graupel melting to rain""" + tau_imlt: float = NamelistDefaults.tau_imlt + """cloud ice melting""" + tau_i2s: float = NamelistDefaults.tau_i2s + """cloud ice to snow auto - conversion""" + tau_l2r: float = NamelistDefaults.tau_l2r + """cloud water to rain auto - conversion""" + tau_g2v: float = NamelistDefaults.tau_g2v + """graupel sublimation""" + tau_v2g: float = 
NamelistDefaults.tau_v2g + """graupel deposition -- make it a slow process""" + sat_adj0: float = NamelistDefaults.sat_adj0 + """adjustment factor (0: no 1: full) during fast_sat_adj""" + ql_gen: float = 1.0e-3 + """max new cloud water during remapping step if fast_sat_adj = .t.""" + ql_mlt: float = NamelistDefaults.ql_mlt + """max value of cloud water allowed from melted cloud ice""" + qs_mlt: float = NamelistDefaults.qs_mlt + """max cloud water due to snow melt""" + ql0_max: float = NamelistDefaults.ql0_max + """max cloud water value (auto converted to rain)""" + t_sub: float = NamelistDefaults.t_sub + """min temp for sublimation of cloud ice""" + qi_gen: float = NamelistDefaults.qi_gen + """max cloud ice generation during remapping step""" + qi_lim: float = NamelistDefaults.qi_lim + """cloud ice limiter to prevent large ice build up""" + qi0_max: float = NamelistDefaults.qi0_max + """max cloud ice value (by other sources)""" + rad_snow: bool = NamelistDefaults.rad_snow + """consider snow in cloud fraction calculation""" + rad_rain: bool = NamelistDefaults.rad_rain + """consider rain in cloud fraction calculation""" + rad_graupel: bool = NamelistDefaults.rad_graupel + """consider graupel in cloud fraction calculation""" + tintqs: bool = NamelistDefaults.tintqs + """use temperature in the saturation mixing in PDF""" + dw_ocean: float = NamelistDefaults.dw_ocean + """base value for ocean""" + dw_land: float = NamelistDefaults.dw_land + """base value for subgrid deviation / variability over land""" # cloud scheme 0 - ? 
# 1: old fvgfs gfdl) mp implementation # 2: binary cloud scheme (0 / 1) icloud_f: int = NamelistDefaults.icloud_f - cld_min: float = NamelistDefaults.cld_min # !< minimum cloud fraction - tau_l2v: float = ( - NamelistDefaults.tau_l2v - ) # cloud water to water vapor (evaporation) - tau_v2l: float = ( - NamelistDefaults.tau_v2l - ) # water vapor to cloud water (condensation) + cld_min: float = NamelistDefaults.cld_min + """minimum cloud fraction""" + tau_l2v: float = NamelistDefaults.tau_l2v + """cloud water to water vapor (evaporation)""" + tau_v2l: float = NamelistDefaults.tau_v2l + """water vapor to cloud water (condensation)""" c2l_ord: int = NamelistDefaults.c2l_ord regional: bool = NamelistDefaults.regional m_split: int = NamelistDefaults.m_split @@ -490,7 +602,9 @@ class Namelist: nf_omega: int = NamelistDefaults.nf_omega fv_sg_adj: int = NamelistDefaults.fv_sg_adj n_sponge: int = NamelistDefaults.n_sponge + lsoil: int = NamelistDefaults.lsoil daily_mean: bool = False + sw_dynamics: bool = NamelistDefaults.sw_dynamics """Flag to replace cosz with daily mean value in physics""" @classmethod diff --git a/ndsl/optional_imports.py b/ndsl/optional_imports.py index 990a954d..d1079fb7 100644 --- a/ndsl/optional_imports.py +++ b/ndsl/optional_imports.py @@ -14,11 +14,6 @@ def __call__(self, *args, **kwargs): except ModuleNotFoundError as err: zarr = RaiseWhenAccessed(err) -try: - import xarray -except ModuleNotFoundError as err: - xarray = None - try: import cupy except ImportError: @@ -30,14 +25,3 @@ def __call__(self, *args, **kwargs): cupy.cuda.runtime.deviceSynchronize() except cupy.cuda.runtime.CUDARuntimeError: cupy = None - - -try: - import gt4py -except ImportError: - gt4py = None - -try: - import dace -except ImportError: - dace = None diff --git a/ndsl/quantity/__init__.py b/ndsl/quantity/__init__.py index fce0cf63..43751528 100644 --- a/ndsl/quantity/__init__.py +++ b/ndsl/quantity/__init__.py @@ -1,9 +1,11 @@ -from ndsl.quantity.metadata import 
QuantityHaloSpec, QuantityMetadata -from ndsl.quantity.quantity import Quantity +from .metadata import QuantityHaloSpec, QuantityMetadata +from .quantity import Quantity __all__ = [ "Quantity", "QuantityMetadata", "QuantityHaloSpec", + "FieldBundle", + "FieldBundleType", ] diff --git a/ndsl/quantity/field_bundle.py b/ndsl/quantity/field_bundle.py new file mode 100644 index 00000000..df637fe2 --- /dev/null +++ b/ndsl/quantity/field_bundle.py @@ -0,0 +1,189 @@ +import copy +from dataclasses import dataclass +from typing import Any + +from gt4py.cartesian import gtscript + +from ndsl.dsl.typing import Float +from ndsl.initialization.allocator import QuantityFactory +from ndsl.quantity.quantity import Quantity + + +# ToDo: This is 4th dimensions restricted. We need a concept +# of data dimensions index here to be able to extend to N dimensions +_DataDimensionIndex = int +_FieldBundleIndexer = dict[str, _DataDimensionIndex] + + +class FieldBundle: + """Field Bundle wraps a nD array (3D + n Data Dimensions) into a complex + indexing scheme that allows a dual name-based and index-based + access to the underlying memory. It is paired with the `FieldBundleType` + which provides a way to type hint parameters for stencils in the `gtscript`. + + WARNING: The present implementation only allows for 4D array. + """ + + _quantity: Quantity + _indexer: _FieldBundleIndexer = {} + + def __init__( + self, + bundle_name: str, + quantity: Quantity, + mapping: _FieldBundleIndexer = {}, + register_type: bool = False, + ): + """ + Initialize a bundle from a nD quantity. + + Dev note: current implementation limits to 4D inputs. + + Args: + bundle_name: name of the bundle, accessible via `name`. + quantity: data inputs as a nD array. + mapping: sparse dict of [name, index] to be able to call tracers by name. + register_type: boolean to register the type as part of initialization. 
+ """ + if len(quantity.shape) != 4: + raise NotImplementedError("FieldBundle implementation restricted to 4D") + + self.name = bundle_name + self._quantity = quantity + self._indexer = mapping + if register_type: + # ToDo: extend the dims below to work with more than 4 dims + assert len(quantity.shape) == 4 + FieldBundleType.register(bundle_name, quantity.shape[3:]) + + def map(self, index: _DataDimensionIndex, name: str): + """Map a single `index` to `name`.""" + self._indexer[name] = index + + @property + def quantity(self) -> Quantity: + return self._quantity + + @property + def shape(self) -> tuple[int, ...]: + return self._quantity.shape + + def groupby(self, name: str) -> Quantity: + """Not implemented""" + raise NotImplementedError + + def __getattr__(self, name: str) -> Quantity: + """Allow referencing a sub-array via `field.a_name`""" + if name not in self._indexer.keys(): + # This replicates as closely as possible the default behavior of getattr + # without breaking orchestration + return None # type: ignore + # ToDo: extend the dims below to work with more than 4 dims + assert len(self._quantity.data.shape) == 4 + return Quantity( + data=self._quantity.data[:, :, :, self.index(name)], + dims=self._quantity.dims[:-1], + units=self._quantity.units, + origin=self._quantity.origin[:-1], + extent=self._quantity.extent[:-1], + ) + + def index(self, name: str) -> int: + """Get index from name.""" + return self._indexer[name] + + @property + def __array_interface__(self): + """Memory interface for CPU.""" + return self._quantity.__array_interface__ + + @property + def __cuda_array_interface__(self): + """Memory interface for GPU memory as defined by cupy.""" + return self._quantity.__cuda_array_interface__ + + def __descriptor__(self) -> Any: + """Data descriptor for DaCe.""" + return self._quantity.__descriptor__() + + @staticmethod + def extend_3D_quantity_factory( + quantity_factory: QuantityFactory, + extra_dims: dict[str, int], + ) -> QuantityFactory: + 
"""Create an nD quantity factory from a cartesian 3D factory. + + Args: + quantity_factory: Cartesian 3D factory. + extra_dims: dict of [name, size] of the data dimensions to add. + """ + new_factory = copy.copy(quantity_factory) + new_factory.set_extra_dim_lengths( + **{ + **extra_dims, + } + ) + return new_factory + + +@dataclass +class MarkupFieldBundleType: + """Markup a field bundle to delay specialization. + + Properties: + name: name of the future type to look into the registrar. + """ + + name: str + + +class FieldBundleType: + """Field Bundle Types to help with static sizing of Data Dimensions. + + Methods: + register: Register a type by sizing its data dimensions + T: access any registered types for type hinting. + """ + + _field_type_registrar: dict[str, gtscript._FieldDescriptor] = {} + + @classmethod + def register( + cls, name: str, data_dims: tuple[int], dtype=Float + ) -> gtscript._FieldDescriptor: + """Register a type by name, giving the size of its data dimensions. + + The same type cannot be registered twice and will error out. + + Args: + name: Type name, to be re-used with `T`. + data_dims: tuple of int giving the size of each data dimension. + dtype: Inner data type, defaults to Float. + """ + if name in cls._field_type_registrar.keys(): + raise RuntimeError(f"Registering {name} a second time!") + cls._field_type_registrar[name] = gtscript.Field[ + gtscript.IJK, (dtype, (data_dims)) + ] + return cls._field_type_registrar[name] + + @classmethod + def T( + cls, name: str, do_markup: bool = True + ) -> gtscript._FieldDescriptor | MarkupFieldBundleType: + """ + Get registered type. + + Dev note: The markup feature is to allow early parsing (at file import) + to go ahead - while we will resolve the full type when calling the stencil.
+ + Args: + name: name of the type as registered via `register` + do_markup: if name not registered, markup for a future specialization + at stencil call time + """ + if name not in cls._field_type_registrar: + if do_markup: + return MarkupFieldBundleType(name) + raise RuntimeError(f"FieldBundle type {name} as not been registered!") + return cls._field_type_registrar[name] diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 4f80fff1..6d83cd89 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -1,14 +1,17 @@ import warnings from typing import Any, Iterable, Optional, Sequence, Tuple, Union, cast +import dace import matplotlib.pyplot as plt import numpy as np +import xarray as xr +from gt4py import storage as gt_storage +from gt4py.cartesian import backend as gt_backend from mpi4py import MPI import ndsl.constants as constants from ndsl.dsl.typing import Float, is_float -from ndsl.optional_imports import cupy, dace, gt4py -from ndsl.optional_imports import xarray as xr +from ndsl.optional_imports import cupy from ndsl.quantity.bounds import BoundedArrayView from ndsl.quantity.metadata import QuantityHaloSpec, QuantityMetadata from ndsl.types import NumpyModule @@ -77,7 +80,7 @@ def __init__( ) if gt4py_backend is not None: - gt4py_backend_cls = gt4py.cartesian.backend.from_name(gt4py_backend) + gt4py_backend_cls = gt_backend.from_name(gt4py_backend) assert gt4py_backend_cls is not None is_optimal_layout = gt4py_backend_cls.storage_info["is_optimal_layout"] @@ -153,9 +156,16 @@ def from_data_array( gt4py_backend=gt4py_backend, ) - def to_netcdf(self, path: str, name="var", rank: int = -1) -> None: + def to_netcdf(self, path: str, name="var", rank: int = -1, all_data=False) -> None: if rank < 0 or MPI.COMM_WORLD.Get_rank() == rank: - self.data_array.to_dataset(name=name).to_netcdf(f"{path}__r{rank}.nc4") + if all_data: + self.data_as_xarray.to_dataset(name=name).to_netcdf( + f"{path}__r{rank}.nc4" + ) + else: + 
self.field_as_xarray.to_dataset(name=name).to_netcdf( + f"{path}__r{rank}.nc4" + ) def halo_spec(self, n_halo: int) -> QuantityHaloSpec: return QuantityHaloSpec( @@ -192,7 +202,7 @@ def sel(self, **kwargs: Union[slice, int]) -> np.ndarray: def _initialize_data(self, data, origin, gt4py_backend: str, dimensions: Tuple): """Allocates an ndarray with optimal memory layout, and copies the data over.""" - storage = gt4py.storage.from_array( + storage = gt_storage.from_array( data, data.dtype, backend=gt4py_backend, @@ -239,6 +249,10 @@ def view(self) -> BoundedArrayView: """a view into the computational domain of the underlying data""" return self._compute_domain_view + @property + def field(self) -> np.ndarray | cupy.ndarray: + return self._compute_domain_view[:] + @property def data(self) -> Union[np.ndarray, cupy.ndarray]: """the underlying array of data""" @@ -260,16 +274,14 @@ def extent(self) -> Tuple[int, ...]: return self.metadata.extent @property - def data_array(self, full_data=False) -> xr.DataArray: - """Returns an Xarray.DataArray of the view (domain) + def field_as_xarray(self) -> xr.DataArray: + """Returns an Xarray.DataArray of the field (domain)""" + return xr.DataArray(self.field, dims=self.dims, attrs=self.attrs) - Args: - full_data: Return the entire data (halo included) instead of the view - """ - if full_data: - return xr.DataArray(self.data[:], dims=self.dims, attrs=self.attrs) - else: - return xr.DataArray(self.view[:], dims=self.dims, attrs=self.attrs) + @property + def data_as_xarray(self) -> xr.DataArray: + """Returns an Xarray.DataArray of the underlying array""" + return xr.DataArray(self.data, dims=self.dims, attrs=self.attrs) @property def np(self) -> NumpyModule: @@ -293,13 +305,7 @@ def __descriptor__(self) -> Any: If the internal data given doesn't follow the protocol it will most likely fail. 
""" - if dace: - return dace.data.create_datadescriptor(self.data) - else: - raise ImportError( - "Attempt to use DaCe orchestrated backend but " - "DaCe module is not available." - ) + return dace.data.create_datadescriptor(self.data) def transpose( self, diff --git a/ndsl/restart/_legacy_restart.py b/ndsl/restart/_legacy_restart.py index 01f9bdb8..7983f8a9 100644 --- a/ndsl/restart/_legacy_restart.py +++ b/ndsl/restart/_legacy_restart.py @@ -2,12 +2,13 @@ import os from typing import BinaryIO, Generator, Iterable +import xarray as xr + import ndsl.constants as constants import ndsl.filesystem as filesystem import ndsl.io as io from ndsl.comm.communicator import Communicator from ndsl.comm.partitioner import get_tile_index -from ndsl.optional_imports import xarray as xr from ndsl.quantity import Quantity from ndsl.restart._properties import RESTART_PROPERTIES, RestartProperties diff --git a/ndsl/stencils/basic_operations.py b/ndsl/stencils/basic_operations.py index b46123a3..18f44afb 100644 --- a/ndsl/stencils/basic_operations.py +++ b/ndsl/stencils/basic_operations.py @@ -1,7 +1,5 @@ -import gt4py.cartesian.gtscript as gtscript -from gt4py.cartesian.gtscript import PARALLEL, computation, interval - -from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ +from ndsl.dsl.gt4py import FORWARD, PARALLEL, computation, function, interval +from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ, IntField, IntFieldIJ def copy_defn(q_in: FloatField, q_out: FloatField): @@ -18,11 +16,8 @@ def copy_defn(q_in: FloatField, q_out: FloatField): def adjustmentfactor_stencil_defn(adjustment: FloatFieldIJ, q_out: FloatField): """ - Multiplies every element of q_out - by every element of the adjustment - field over the interval, replacing - the elements of q_out by the result - of the multiplication. + Multiplies every element of q_out by every element of the adjustment field over the + interval, replacing the elements of q_out by the result of the multiplication. 
Args: adjustment: adjustment field @@ -34,8 +29,7 @@ def adjustmentfactor_stencil_defn(adjustment: FloatFieldIJ, q_out: FloatField): def set_value_defn(q_out: FloatField, value: Float): """ - Sets every element of q_out to the - value specified by value argument. + Sets every element of q_out to the value specified by value argument. Args: q_out: output field @@ -47,11 +41,8 @@ def set_value_defn(q_out: FloatField, value: Float): def adjust_divide_stencil(adjustment: FloatField, q_out: FloatField): """ - Divides every element of q_out - by every element of the adjustment - field over the interval, replacing - the elements of q_out by the result - of the multiplication. + Divides every element of q_out by every element of the adjustment field over the + interval, replacing the elements of q_out by the result of the multiplication. Args: adjustment: adjustment field @@ -61,36 +52,61 @@ def adjust_divide_stencil(adjustment: FloatField, q_out: FloatField): q_out = q_out / adjustment -@gtscript.function +def select_k( + in_field: FloatField, + out_field: FloatFieldIJ, + k_mask: IntField, + k_select: IntFieldIJ, +): + """ + Saves a specific k-index of a 3D field to a new 2D array. The k-value can be + different for each i,j point. + + Args: + in_field: A 3D array to select from + out_field: A 2D field to save values in + k_mask: a field that lists each k-index + k_select: the k-value to extract from in_field + """ + # TODO: refactor this using THIS_K instead of a mask + with computation(FORWARD), interval(...): + if k_mask == k_select: + out_field = in_field + + +def average_in( + q_out: FloatField, + adjustment: FloatField, +): + """ + Averages every element of q_out with every element of the adjustment field, + overwriting q_out. 
+ + Args: + adjustment: adjustment field + q_out: output field + """ + with computation(PARALLEL), interval(...): + q_out = (q_out + adjustment) * 0.5 + + +@function def sign(a, b): """ - Defines asignb as the absolute value - of a, and checks if b is positive - or negative, assigning the analogus - sign value to asignb. asignb is returned + Defines a_sign_b as the absolute value of a, and checks if b is positive or + negative, assigning the analogous sign value to a_sign_b. a_sign_b is returned. Args: a: A number b: A number """ - asignb = abs(a) - if b > 0: - asignb = asignb - else: - asignb = -asignb - return asignb + a_sign_b = abs(a) + return a_sign_b if b > 0 else -a_sign_b -@gtscript.function +@function def dim(a, b): """ - Performs a check on the difference - between the values in arguments - a and b. The variable diff is set - to the difference between a and b - when the difference is positive, - otherwise it is set to zero. The - function returns the diff variable. + Calculates a - b, camped to 0, i.e. max(a - b, 0). """ - diff = a - b if a - b > 0 else 0 - return diff + return max(a - b, 0) diff --git a/ndsl/stencils/testing/README.md b/ndsl/stencils/testing/README.md index 098dd489..d08aa7e4 100644 --- a/ndsl/stencils/testing/README.md +++ b/ndsl/stencils/testing/README.md @@ -83,4 +83,4 @@ Upon failure, the test will drop a `netCDF` file in a `./.translate-errors` dire ## Environment variables -- `PACE_TEST_N_THRESHOLD_SAMPLES`: Upon failure the system will try to perturb the output in an attempt to check for numerical instability. This means re-running the test for N samples. Default is `10`, `0` or less turns this feature off. +- `NDSL_TEST_N_THRESHOLD_SAMPLES`: Upon failure the system will try to perturb the output in an attempt to check for numerical instability. This means re-running the test for N samples. Default is `0`, which turns this feature off. 
diff --git a/ndsl/stencils/testing/serialbox_to_netcdf.py b/ndsl/stencils/testing/serialbox_to_netcdf.py index f514ae0c..e0dac98f 100644 --- a/ndsl/stencils/testing/serialbox_to_netcdf.py +++ b/ndsl/stencils/testing/serialbox_to_netcdf.py @@ -12,7 +12,7 @@ import serialbox except ModuleNotFoundError: raise ModuleNotFoundError( - "Serialbox couldn't be imported, make sure it's in your PYTHONPATH or you env" + "Serialbox couldn't be imported, make sure it's in your PYTHONPATH or your env" ) diff --git a/ndsl/stencils/testing/test_translate.py b/ndsl/stencils/testing/test_translate.py index 5b9bc773..18271a4a 100644 --- a/ndsl/stencils/testing/test_translate.py +++ b/ndsl/stencils/testing/test_translate.py @@ -6,10 +6,10 @@ import numpy as np import pytest -import ndsl.dsl.gt4py_utils as gt_utils from ndsl.comm.communicator import CubedSphereCommunicator, TileCommunicator from ndsl.comm.mpi import MPI, MPIComm from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner +from ndsl.dsl import gt4py_utils as gt_utils from ndsl.dsl.dace.dace_config import DaceConfig from ndsl.dsl.stencil import CompilationConfig, StencilConfig from ndsl.quantity import Quantity @@ -25,6 +25,7 @@ OUTDIR = "./.translate-outputs" GPU_MAX_ERR = 1e-10 GPU_NEAR_ZERO = 1e-15 +N_THRESHOLD_SAMPLES = int(os.getenv("NDSL_TEST_N_THRESHOLD_SAMPLES", 0)) def platform(): @@ -89,7 +90,7 @@ def process_override(threshold_overrides, testobj, test_name, backend): testobj.skip_test = bool(match["skip_test"]) elif len(matches) > 1: raise Exception( - "misconfigured threshold overrides file, more than 1 specification for " + "Misconfigured threshold overrides file, more than 1 specification for " + test_name + " with backend=" + backend @@ -98,9 +99,6 @@ def process_override(threshold_overrides, testobj, test_name, backend): ) -N_THRESHOLD_SAMPLES = int(os.getenv("PACE_TEST_N_THRESHOLD_SAMPLES", 10)) - - def get_thresholds(testobj, input_data): _get_thresholds(testobj.compute, input_data) @@ 
-158,7 +156,7 @@ def test_sequential_savepoint( ): if case.testobj is None: pytest.xfail( - f"no translate object available for savepoint {case.savepoint_name}" + f"No translate object available for savepoint {case.savepoint_name}." ) stencil_config = StencilConfig( compilation_config=CompilationConfig(backend=backend), @@ -178,7 +176,7 @@ def test_sequential_savepoint( if case.testobj.skip_test: return if not case.exists: - pytest.skip(f"Data at rank {case.grid.rank} does not exists") + pytest.skip(f"Data at rank {case.grid.rank} does not exist.") input_data = dataset_to_dict(case.ds_in) input_names = ( case.testobj.serialnames(case.testobj.in_vars["data_vars"]) @@ -188,7 +186,7 @@ def test_sequential_savepoint( input_data = {name: input_data[name] for name in input_names} except KeyError as e: raise KeyError( - f"Variable {e} was described in the translate test but cannot be found in the NetCDF" + f"Variable {e} was described in the translate test but cannot be found in the NetCDF." ) original_input_data = copy.deepcopy(input_data) # give the user a chance to load data from other savepoints to allow @@ -208,7 +206,7 @@ def test_sequential_savepoint( try: ref_data = all_ref_data[varname] except KeyError: - raise KeyError(f"Output {varname} couldn't be found in output data") + raise KeyError(f"Output {varname} couldn't be found in output data.") if hasattr(case.testobj, "subset_output"): ref_data = case.testobj.subset_output(varname, ref_data) with subtests.test(varname=varname): diff --git a/ndsl/stencils/testing/translate.py b/ndsl/stencils/testing/translate.py index 8132d290..0acee958 100644 --- a/ndsl/stencils/testing/translate.py +++ b/ndsl/stencils/testing/translate.py @@ -6,16 +6,12 @@ import ndsl.dsl.gt4py_utils as utils from ndsl.dsl.stencil import StencilFactory from ndsl.dsl.typing import Field, Float, Int # noqa: F401 +from ndsl.optional_imports import cupy as cp from ndsl.quantity import Quantity from ndsl.stencils.testing.grid import Grid # type: 
ignore from ndsl.stencils.testing.savepoint import DataLoader -try: - import cupy as cp -except ImportError: - cp = None - logger = logging.getLogger(__name__) @@ -85,14 +81,14 @@ def setup(self, inputs) -> None: def compute_func(self, **inputs) -> Optional[dict[str, Any]]: """Compute function to transform the dictionary of `inputs`. - Must return a dictionnary of updated variables""" + Must return a dictionary of updated variables""" raise NotImplementedError("Implement a child class compute method") def compute(self, inputs) -> dict[str, Any]: - """Transform inputs from NetCDF to gt4py.storagers, run compute_func then slice + """Transform inputs from NetCDF to gt4py.storages, run compute_func then slice the outputs based on specifications. - Return: Dictonnary of storages reshaped for comparison + Return: Dictionary of storages reshaped for comparison """ self.setup(inputs) return self.slice_output(self.compute_from_storage(inputs)) @@ -201,7 +197,7 @@ def collect_start_indices(self, datashape, varinfo): def make_storage_data_input_vars( self, inputs, storage_vars=None, dict_4d=True ) -> None: - """From a set of raw inputs (straight from NetCDF), use the `in_vars` dictionnary to update inputs to + """From a set of raw inputs (straight from NetCDF), use the `in_vars` dictionary to update inputs to their configured shape. 
Return: None @@ -214,6 +210,9 @@ def make_storage_data_input_vars( inputs_out[p] = inputs_in[p] for d, info in storage_vars.items(): serialname = info["serialname"] if "serialname" in info else d + index_variable = ( + info["index_variable"] if "index_variable" in info else False + ) self.update_info(info, inputs_in) if "kaxis" in info: inputs_in[serialname] = np.moveaxis( @@ -231,6 +230,8 @@ def make_storage_data_input_vars( dummy_axes = info.get("dummy_axes", None) axis = info.get("axis", 2) + if index_variable: + inputs_in[serialname] -= 1 inputs_out[d] = self.make_storage_data( np.squeeze(inputs_in[serialname]), istart=istart, @@ -258,9 +259,16 @@ def slice_output(self, inputs, out_data=None) -> dict[str, Any]: info = self.out_vars[var] self.update_info(info, inputs) serialname = info["serialname"] if "serialname" in info else var + index_variable = ( + info["index_variable"] if "index_variable" in info else False + ) ds = self.grid.default_domain_dict() ds.update(info) data_result = as_numpy(out_data[var]) + if index_variable: + if isinstance(data_result, dict): + raise TypeError(f"Variable {serialname} is a 4D dict, not an index") + data_result += 1 if isinstance(data_result, dict): names_4d = info.get("names_4d", utils.tracer_variables) var4d = np.zeros( diff --git a/ndsl/stencils/tridiag.py b/ndsl/stencils/tridiag.py new file mode 100644 index 00000000..e1fe50c0 --- /dev/null +++ b/ndsl/stencils/tridiag.py @@ -0,0 +1,89 @@ +from gt4py.cartesian.gtscript import BACKWARD, FORWARD, computation, interval + +from ndsl.dsl.typing import BoolFieldIJ, FloatField + + +def tridiag_solve( + a: FloatField, + b: FloatField, + c: FloatField, + d: FloatField, + x: FloatField, + delta: FloatField, +): + """ + This stencil solves a square, k x k tridiagonal matrix system + with coefficients a, b, and c, and vectors p and d using the Thomas algorithm: + ! ### ### ### ### ### ###! + ! #b(0), c(0), 0 , 0 , 0 , . . . , 0 # # x(0) # # d(0) #! + ! #a(1), b(1), c(1), 0 , 0 , . . . 
, 0 # # x(1) # # d(1) #! + ! # 0 , a(2), b(2), c(2), 0 , . . . , 0 # # x(2) # # d(2) #! + ! # 0 , 0 , a(3), b(3), c(3), . . . , 0 # # x(3) # # d(3) #! + ! # 0 , 0 , 0 , a(4), b(4), . . . , 0 # # x(4) # # d(4) #! + ! # . . # # . # = # . #! + ! # . . # # . # # . #! + ! # . . # # . # # . #! + ! # 0 , . . . , 0 , a(k-2), b(k-2), c(k-2), 0 # #x(k-3)# #d(k-3)#! + ! # 0 , . . . , 0 , 0 , a(k-1), b(k-1), c(k-1)# #x(k-2)# #d(k-2)#! + ! # 0 , . . . , 0 , 0 , 0 , a(k) , b(k) # #x(k-1)# #d(k-1)#! + ! ### ### ### ### ### ###! + + Args: + a (in): lower-diagonal matrix coefficients + b (in): diagonal matrix coefficients + c (in): upper-diagonal matrix coefficients + d (in): Result vector + x (out): The vector to solve for + delta (out): d post-pivot + """ + with computation(FORWARD): # Forward sweep + with interval(0, 1): + x = c / b + delta = d / b + with interval(1, None): + x = c / (b - a * x[0, 0, -1]) + delta = (d - a * delta[0, 0, -1]) / (b - a * x[0, 0, -1]) + with computation(BACKWARD): # Reverse sweep + with interval(-1, None): + x = delta + with interval(0, -1): + x = delta - x * x[0, 0, 1] + + +def masked_tridiag_solve( + a: FloatField, + b: FloatField, + c: FloatField, + d: FloatField, + x: FloatField, + delta: FloatField, + mask: BoolFieldIJ, +): + """ + Same as tridiag_solve but restricted to a subset of horizontal points + + Args: + a (in): lower-diagonal matrix coefficients + b (in): diagonal matrix coefficients + c (in): upper-diagonal matrix coefficients + d (in): Result vector + mask (in): Columns to execute the stencil on + x (out): The vector to solve for + delta (out): d post-pivot + """ + with computation(FORWARD): # Forward sweep + with interval(0, 1): + if mask: + x = c / b + delta = d / b + with interval(1, None): + if mask: + x = c / (b - a * x[0, 0, -1]) + delta = (d - a * delta[0, 0, -1]) / (b - a * x[0, 0, -1]) + with computation(BACKWARD): # Reverse sweep + with interval(-1, None): + if mask: + x = delta + with interval(0, -1): + if mask: + x = 
delta - x * x[0, 0, 1] diff --git a/ndsl/testing/README.md b/ndsl/testing/README.md index 9154956c..549d3467 100644 --- a/ndsl/testing/README.md +++ b/ndsl/testing/README.md @@ -2,107 +2,114 @@ ## Summary -NDSL exposes a "Translate" test system which allows the automatic numerical regression test against a pre-defined sets of NetCDFs. +NDSL exposes a "Translate test" system which allows automatic numerical regression testing against a pre-defined set of NetCDF files. -To write a translate test, derive from `TranslateFortranData2Py`. The system works by matching name of Translate class and data, e.g.: +To write a translate test, derive from `TranslateFortranData2Py`. The system works by matching the name of the Translate test class with input/output data, e.g.: -- if `TranslateNAME` is the name of the translate class -- then the name of the data should be `NAME-In.nc` for the inputs and `NAME-Out.nc` for outputs that'll be check. +- if `TranslateNAME` is the name of the translate test class +- then the name of the data is expected to be `NAME-In.nc` for inputs and `NAME-Out.nc` for outputs that'll be checked. -The test runs via the `pytest` harness and can be triggered with the `pytest` commands. +The tests run via the `pytest` harness and can be triggered with the `pytest` commands. Options ares: -- --insert-assert-print: Print statements that would be substituted for insert_assert(), instead of writing to files -- --insert-assert-fail: Fail tests which include one or more insert_assert() calls -- --backend=BACKEND: Backend to execute the test with, can only be one. -- --which_modules=WHICH_MODULES: Whitelist of modules to run. Only the part after Translate, e.g. in TranslateXYZ it'd be XYZ -- --skip_modules=SKIP_MODULES: Blacklist of modules to not run. Only the part after Translate, e.g. in TranslateXYZ it'd be XYZ -- --which_rank=WHICH_RANK: Restrict test to a single rank -- --data_path=DATA_PATH: Path of Netcdf input and outputs.
Naming pattern needs to be XYZ-In and XYZ-Out for a test class named TranslateXYZ -- --threshold_overrides_file=THRESHOLD_OVERRIDES_FILE: Path to a yaml overriding the default error threshold for a custom value. -- --print_failures: Print the failures detail. Default to True. -- --failure_stride=FAILURE_STRIDE: How many indices of failures to print from worst to best. Default to 1. -- --grid=GRID: Grid loading mode. "file" looks for "Grid-Info.nc", "compute" does the same but recomputes MetricTerms, "default" creates a simple grid with no metrics terms. Default to "file". -- --topology=TOPOLOGY Topology of the grid. "cubed-sphere" means a 6-faced grid, "doubly-periodic" means a 1 tile grid. Default to "cubed-sphere". -- --multimodal_metric: Use the multi-modal float metric. Default to False. - -More options of `pytest` are available when doing `pytest --help`. +- `--insert-assert-print`: Print statements that would be substituted for `insert_assert()`, instead of writing to files. +- `--insert-assert-fail`: Fail tests which include one or more `insert_assert()` calls. +- `--backend=`: Backend to execute the test with. Can only be one. +- `--which_modules=`: List of modules to run. Only the part after _Translate_, e.g. for `TranslateXYZ` the name would be `XYZ`. +- `--skip_modules=`: List of modules to skip. Only the part after _Translate_, e.g. for `TranslateXYZ` the name would be `XYZ`. +- `--which_rank=`: Restrict test to a single rank. +- `--data_path=`: Path of NetCDF inputs and outputs. The expected naming pattern is `XYZ-In.nc` and `XYZ-Out.nc` for a test class named `TranslateXYZ`. +- `--threshold_overrides_file=`: Path to a yaml file overriding the default error thresholds with (granular) custom values. +- `--print_failures=`: Print failures in detail. Defaults to `True`. +- `--failure_stride=`: How many indices of failures to print from worst to best. Defaults to 1. +- `--grid=<"file"|"compute"|"default">`: Grid loading mode. 
`"file"` looks for `"Grid-Info.nc"`, `"compute"` does the same but recomputes MetricTerms, `"default"` creates a simple grid with no metrics terms. Defaults to `"file"`. +- `--topology=<"cubed-sphere"|"doubly-periodic">`: Topology of the grid. `"cubed-sphere"` means a 6-faced grid, `"doubly-periodic"` means a 1 tile grid. Defaults to `"cubed-sphere"`. +- `--multimodal_metric=`: Use the multi-modal float metric. Defaults to `False`. + +To list all options of `pytest`, try `pytest --help`. ## Metrics -There is three state of a test in `pytest`: FAIL, PASS and XFAIL (expected fail). To clear the PASS status, the output data contained in `NAME-Out.nc` is compared to the computed data via the `TranslateNAME` test. Because this system was developed to port Fortran numerics to many targets (mostly C, but also Python, and CPU/GPU), we can't rely on bit-to-bit comparison and have been developing a couple of metrics. +There are three exit states for a test in `pytest`: `FAIL`, `PASS`, and `XFAIL` (expected fail). To clear the `PASS` status, the output data contained in `NAME-Out.nc` is compared to the computed data via the `TranslateNAME` test. Because this system was developed to port Fortran numerics to other target languages (mostly C, but also Python, and CPU/GPU), we can't rely on bit-to-bit comparison and have been developing a couple of metrics. ### Legacy metric -The legacy metric was used throughout the development of the dynamical core and microphysics scheme at 64-bit precision. It tries to solve differences over big and small amplitude values with a single formula that goes as follows: $`\|computed-reference|/reference`$ where `reference` has been purged of 0. -NaN values are considered no-pass. -To pass the metric has to be lower than `1e-14`, any value lower than `1e-18` will be considered pass by default. The pass threshold can be overridden (see below). 
+The legacy metric was used throughout the development of the dynamical core and microphysics scheme at 64-bit precision. It tries to solve differences over big and small amplitude values with a single formula that goes as follows: + +$`|computed-reference| / reference`$ + +where `reference` has been purged of 0. `NaN` values are considered no-pass. + +To pass, the metric has to be lower than `1e-14`; any value lower than `1e-18` will be considered a pass by default. These thresholds can be overridden (see below). ### Multi-modal metric -Moving to mixed precision code, the legacy metric didn't give enough flexibility to account for 32-bit precision errors that could accumulate. Another metric was built with the intent of breaking the one-fit-all concept and giving back flexibility. The metric is a combination of three differences: +Moving to mixed precision code, the legacy metric didn't give enough flexibility to account for 32-bit precision errors that could accumulate. The multi-modal metric was built with the intent of breaking the "one-threshold-fits-all" concept and giving back flexibility. The metric is a combination of three differences: -- _Absolute Difference_ ($`|computed-reference|`. +For the _near zero_ override, `ignore_near_zero_errors` is specified to allow some fields to pass with higher relative error if the absolute error is very small. It is also possible to define a global near zero value for all remaining fields not specified in `ignore_near_zero_errors`. This is done by specifying `all_other_near_zero`. The override yaml file should have one of the following formats: -### One near zero value for all variables +#### One near zero value for all variables -```Stencil_name: +```yaml +Stencil_name: - backend: max_error: near_zero: ignore_near_zero_errors: - - - - - - ... + - + - + - ...
``` -### Variable specific near zero value +#### Variable specific near zero values -```Stencil_name: +```yaml +Stencil_name: - backend: max_error: ignore_near_zero_errors: - : - : - ... + : + : + ... ``` -### [optional] Global near zero value for remaining fields +#### [optional] Global near zero value for remaining fields -```Stencil_name: +```yaml +Stencil_name: - backend: max_error: ignore_near_zero_errors: - : - : - all_other_near_zero: - ... + : + : + all_other_near_zero: ``` where fields other than `var1` and `var2` will use `global_value`. -### Multimodal overrides +#### Multimodal overrides -```Stencil_name: +```yaml +Stencil_name: - backend: multimodal: absolute_eps: diff --git a/ndsl/viz/__init__.py b/ndsl/viz/__init__.py new file mode 100644 index 00000000..38f61a6f --- /dev/null +++ b/ndsl/viz/__init__.py @@ -0,0 +1,4 @@ +from .cube_sphere import plot_cube_sphere + + +__all__ = ["plot_cube_sphere"] diff --git a/ndsl/viz/cube_sphere.py b/ndsl/viz/cube_sphere.py new file mode 100644 index 00000000..09018d1f --- /dev/null +++ b/ndsl/viz/cube_sphere.py @@ -0,0 +1,36 @@ +import numpy as np +from cartopy import crs as ccrs +from matplotlib import pyplot as plt + +from ndsl import Quantity, ndsl_log +from ndsl.comm.communicator import Communicator +from ndsl.grid import GridData +from ndsl.viz.fv3 import pcolormesh_cube + + +def plot_cube_sphere( + quantity: Quantity, + k_level: int, + comm: Communicator, + grid_data: GridData, + save_to_path: str, +): + if len(quantity.shape) < 2 or len(quantity.shape) > 3: + ndsl_log.error( + f"[Plot Cube] Can't plot quantity with shape == {quantity.shape}" + ) + return + + data = comm.gather(quantity) + lat = comm.gather(grid_data.lat) + lon = comm.gather(grid_data.lon) + + if comm.rank == 0: + fig, ax = plt.subplots(1, 1, subplot_kw={"projection": ccrs.Robinson()}) + pcolormesh_cube( + lat.view[:] * 180.0 / np.pi, + lon.view[:] * 180.0 / np.pi, + data.view[:] if len(data.shape) == 3 else data.view[:, :, :, k_level], + 
ax=ax, + ) + fig.savefig(save_to_path) diff --git a/ndsl/viz/fv3/README.md b/ndsl/viz/fv3/README.md new file mode 100644 index 00000000..f6623c02 --- /dev/null +++ b/ndsl/viz/fv3/README.md @@ -0,0 +1,14 @@ +# Acknowledgment + +This code was originally developed by AI2 and is reused here under the MIT license (see below). + +## MIT License + +The MIT License (MIT) +Copyright (c) 2019, The Allen Institute for Artificial Intelligence + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
diff --git a/ndsl/viz/fv3/__init__.py b/ndsl/viz/fv3/__init__.py new file mode 100644 index 00000000..8f1ff012 --- /dev/null +++ b/ndsl/viz/fv3/__init__.py @@ -0,0 +1,41 @@ +from ._constants import ( + COORD_X_CENTER, + COORD_X_OUTER, + COORD_Y_CENTER, + COORD_Y_OUTER, + VAR_LAT_CENTER, + VAR_LAT_OUTER, + VAR_LON_CENTER, + VAR_LON_OUTER, +) +from ._plot_cube import pcolormesh_cube, plot_cube +from ._plot_diagnostics import plot_diurnal_cycle, plot_time_series +from ._plot_helpers import infer_cmap_params +from ._styles import use_colorblind_friendly_style, wong_palette +from ._timestep_histograms import ( + plot_daily_and_hourly_hist, + plot_daily_hist, + plot_hourly_hist, +) + + +__all__ = [ + "plot_daily_and_hourly_hist", + "plot_daily_hist", + "plot_hourly_hist", + "plot_cube", + "pcolormesh_cube", + "plot_diurnal_cycle", + "plot_time_series", + "infer_cmap_params", + "use_colorblind_friendly_style", + "wong_palette", + "COORD_X_CENTER", + "COORD_Y_CENTER", + "COORD_X_OUTER", + "COORD_Y_OUTER", + "VAR_LON_CENTER", + "VAR_LAT_CENTER", + "VAR_LON_OUTER", + "VAR_LAT_OUTER", +] diff --git a/ndsl/viz/fv3/_constants.py b/ndsl/viz/fv3/_constants.py new file mode 100644 index 00000000..4eb699a3 --- /dev/null +++ b/ndsl/viz/fv3/_constants.py @@ -0,0 +1,9 @@ +COORD_X_CENTER = "x" +COORD_X_OUTER = "x_interface" +COORD_Y_CENTER = "y" +COORD_Y_OUTER = "y_interface" +VAR_LON_CENTER = "lon" +VAR_LAT_CENTER = "lat" +VAR_LON_OUTER = "lonb" +VAR_LAT_OUTER = "latb" +INIT_TIME_DIM = "initialization_time" diff --git a/ndsl/viz/fv3/_masking.py b/ndsl/viz/fv3/_masking.py new file mode 100644 index 00000000..80280e15 --- /dev/null +++ b/ndsl/viz/fv3/_masking.py @@ -0,0 +1,117 @@ +import numpy as np + + +def _mask_antimeridian_quads(lonb: np.ndarray, central_longitude: float): + """Computes mask of cubed-sphere tile grid quadrilaterals bisected by a + projection system's antimeridian, in order to avoid cartopy plotting + artifacts + + Args: + lonb (np.ndarray): + Array of grid edge 
longitudes, of dimensions (npy + 1, npx + 1, + tile) + central_longitude (float): + Central longitude from which the antimeridian is computed + + Returns: + mask (np.ndarray): + Boolean array of grid centers, False = excluded, of dimensions + (npy, npx, tile) + + + Example: + masked_array = np.where( + mask_antimeridian_quads(lonb, central_longitude), + array, + np.nan + ) + """ + antimeridian = (central_longitude + 180.0) % 360.0 + mask = np.full([lonb.shape[0] - 1, lonb.shape[1] - 1, lonb.shape[2]], True) + for tile in range(6): + tile_lonb = lonb[:, :, tile] + tile_mask = mask[:, :, tile] + for ix in range(tile_lonb.shape[0] - 1): + for iy in range(tile_lonb.shape[1] - 1): + vertex_indices = ([ix, ix + 1, ix, ix + 1], [iy, iy, iy + 1, iy + 1]) + vertices = tile_lonb[vertex_indices] + if ( + sum(_periodic_equal_or_less_than(vertices, antimeridian)) != 4 + and sum(_periodic_greater_than(vertices, antimeridian)) != 4 + and sum((_periodic_difference(vertices, antimeridian) < 90.0)) == 4 + ): + tile_mask[ix, iy] = False + mask[:, :, tile] = tile_mask + + return mask + + +def _periodic_equal_or_less_than(x1, x2, period=360.0): + """Compute whether x1 is less than or equal to x2, where + the difference between the two is the shortest distance on a periodic domain + + Args: + x1 (float), x2 (float): + Values to be compared + Period (float, optional): + Period of domain. Default 360 (degrees). 
+ + Returns: + Less_than_or_equal (Bool): + Whether x1 is less than or equal to x2 + """ + return np.where( + np.abs(x1 - x2) <= period / 2.0, + np.where(x1 - x2 <= 0, True, False), + np.where( + x1 - x2 >= 0, + np.where(x1 - (x2 + period) <= 0, True, False), + np.where((x1 + period) - x2 <= 0, True, False), + ), + ) + + +def _periodic_greater_than(x1, x2, period=360.0): + """Compute whether x1 is greater than x2, where + the difference between the two is the shortest distance on a periodic domain + + Args: + x1 (float), x2 (float): + Values to be compared + Period (float, optional): + Period of domain. Default 360 (degrees). + + Returns: + Greater_than (Bool): + Whether x1 is greater than x2 + """ + return np.where( + np.abs(x1 - x2) <= period / 2.0, + np.where(x1 - x2 > 0, True, False), + np.where( + x1 - x2 >= 0, + np.where(x1 - (x2 + period) > 0, True, False), + np.where((x1 + period) - x2 > 0, True, False), + ), + ) + + +def _periodic_difference(x1, x2, period=360.0): + """Compute difference between x1 and x2, where + the difference is the shortest distance on a periodic domain + + Args: + x1 (float), x2 (float): + Values to be compared + Period (float, optional): + Period of domain. Default 360 (degrees). 
+ + Returns: + Difference (float): + Difference between x1 and x2 + """ + return np.where( + np.abs(x1 - x2) <= period / 2.0, + x1 - x2, + np.where(x1 - x2 >= 0, x1 - (x2 + period), (x1 + period) - x2), + ) diff --git a/ndsl/viz/fv3/_plot_cube.py b/ndsl/viz/fv3/_plot_cube.py new file mode 100644 index 00000000..8942d494 --- /dev/null +++ b/ndsl/viz/fv3/_plot_cube.py @@ -0,0 +1,621 @@ +from __future__ import annotations + +import os +import warnings +from functools import partial + +import cartopy +import numpy as np +import xarray as xr +from cartopy import crs as ccrs +from matplotlib import pyplot as plt + +from ._constants import ( + COORD_X_CENTER, + COORD_X_OUTER, + COORD_Y_CENTER, + COORD_Y_OUTER, + VAR_LAT_CENTER, + VAR_LAT_OUTER, + VAR_LON_CENTER, + VAR_LON_OUTER, +) +from ._masking import _mask_antimeridian_quads +from ._plot_helpers import ( + _align_grid_var_dims, + _align_plot_var_dims, + _get_var_label, + infer_cmap_params, +) +from .grid_metadata import GridMetadata, GridMetadataFV3, GridMetadataScream + + +if os.getenv("CARTOPY_EXTERNAL_DOWNLOADER") != "natural_earth": + # workaround to host our own global-scale coastline shapefile instead + # of unreliable cartopy source + cartopy.config["downloaders"][("shapefiles", "natural_earth")].url_template = ( + "https://raw.githubusercontent.com/ai2cm/" + "vcm-ml-example-data/main/fv3net/fv3viz/coastline_shapefiles/" + "{resolution}_{category}/ne_{resolution}_{name}.zip" + ) + +WRAPPER_GRID_METADATA = GridMetadataFV3( + COORD_X_CENTER, + COORD_Y_CENTER, + COORD_X_OUTER, + COORD_Y_OUTER, + "tile", + VAR_LON_CENTER, + VAR_LON_OUTER, + VAR_LAT_CENTER, + VAR_LAT_OUTER, +) + + +def plot_cube( + ds: xr.Dataset, + var_name: str, + grid_metadata: GridMetadata = WRAPPER_GRID_METADATA, + plotting_function: str = "pcolormesh", + ax: plt.axes = None, + row: str = None, + col: str = None, + col_wrap: int = None, + projection: ccrs.Projection = None, + colorbar: bool = True, + cmap_percentiles_lim: bool = True, + 
cbar_label: str = None, + coastlines: bool = True, + coastlines_kwargs: dict = None, + **kwargs, +): + """Plots an xr.DataArray containing tiled cubed sphere gridded data + onto a global map projection, with optional faceting of additional dims + + Args: + ds: + Dataset containing variable to be plotted, along with the grid + variables defining cell center latitudes and longitudes and the + cell bounds latitudes and longitudes, which must share common + dimension names + var_name: + name of the data variable in `ds` to be plotted + grid_metadata: + a GridMetadata (GridMetadataFV3 or GridMetadataScream) data structure that + defines the names of plot and grid variable dimensions and the names + of the grid variables themselves; defaults to those used by the + fv3gfs Python wrapper (i.e., 'x', 'y', 'x_interface', 'y_interface' and + 'lat', 'lon', 'latb', 'lonb') + plotting_function: + Name of matplotlib 2-d plotting function. Available + options are "pcolormesh", "contour", and "contourf". Defaults to + "pcolormesh". + ax: + Axes onto which the map should be plotted; must be created with + a cartopy projection argument. If not supplied, axes are generated + with a projection. If ax is supplied, faceting is disabled. + row: + Name of dimension to be faceted along subplot rows. Must not be a + tile, lat, or lon dimension. Defaults to no row facets. + col: + Name of dimension to be faceted along subplot columns. Must not be + a tile, lat, or lon dimension. Defaults to no column facets. + col_wrap: + If only one of `col`, `row` is specified, number of columns to plot + before wrapping onto next row. Defaults to None, i.e. no limit. + projection: + Cartopy projection object to be used in creating axes. Ignored + if cartopy geo-axes are supplied. Defaults to Robinson projection. + colorbar: + Flag for whether to plot a colorbar. Defaults to True. + cmap_percentiles_lim: + If False, use the absolute min/max to set color limits. + If True, use 2/98 percentile values.
+ cbar_label: + If provided, use this as the color bar label. + coastlines: + Whether to plot coastlines on map. Default True. + coastlines_kwargs: + Dict of arguments to be passed to the cartopy axes' + `coastlines` method if `coastlines` flag is set to True. + **kwargs: Additional keyword arguments to be passed to the plotting function. + + Returns: + figure (plt.Figure): + matplotlib figure object onto which axes grid is created + axes (np.ndarray): + Array of `plt.axes` objects associated with map subplots if faceting; + otherwise array containing single axes object. + handles (list): + List or nested list of matplotlib object handles associated with + map subplots if faceting; otherwise list of single object handle. + cbar (plt.colorbar): + object handle associated with figure, if `colorbar` + arg is True, else None. + facet_grid (xarray.plot.facetgrid): + xarray plotting facetgrid for multi-axes case. In single-axes case, + returns None. + + Example: + # plot diag winds at two times + fig, axes, hs, cbar, facet_grid = plot_cube( + diag_ds.isel(time = slice(2, 4)), + 'VGRD850', + plotting_function = "contourf", + col = "time", + coastlines = True, + colorbar = True, + vmin = -20, + vmax = 20 + ) + """ + + mappable_ds = _mappable_var(ds, var_name, grid_metadata) + array = mappable_ds[var_name].values + + kwargs["vmin"], kwargs["vmax"], kwargs["cmap"] = infer_cmap_params( + array, + vmin=kwargs.get("vmin"), + vmax=kwargs.get("vmax"), + cmap=kwargs.get("cmap"), + robust=cmap_percentiles_lim, + ) + if isinstance(grid_metadata, GridMetadataFV3): + _plot_func_short = partial( + _plot_cube_axes, + lat=mappable_ds[grid_metadata.lat].values, + lon=mappable_ds[grid_metadata.lon].values, + latb=mappable_ds[grid_metadata.latb].values, + lonb=mappable_ds[grid_metadata.lonb].values, + plotting_function=plotting_function, + **kwargs, + ) + elif isinstance(grid_metadata, GridMetadataScream): + _plot_func_short = partial( + _plot_scream_axes, + lat=mappable_ds[grid_metadata.lat].values, + lon=mappable_ds[grid_metadata.lon].values, +
plotting_function=plotting_function, + **kwargs, + ) + else: + raise ValueError( + "grid_metadata needs to be either GridMetadataFV3 or GridMetadataScream, " + f"but got {type(grid_metadata)}" + ) + + projection = ccrs.Robinson() if not projection else projection + + if ax is None and (row or col): + # facets + facet_grid = xr.plot.FacetGrid( + data=mappable_ds, + row=row, + col=col, + col_wrap=col_wrap, + subplot_kws={"projection": projection}, + ) + facet_grid = facet_grid.map(_plot_func_short, var_name) + fig = facet_grid.fig + axes = facet_grid.axes + handles = facet_grid._mappables + else: + # single axes + if ax is None: + fig, ax = plt.subplots(1, 1, subplot_kw={"projection": projection}) + else: + fig = ax.figure + handle = _plot_func_short(array, ax=ax) + axes = np.array(ax) + handles = [handle] + facet_grid = None + + if coastlines: + coastlines_kwargs = dict() if not coastlines_kwargs else coastlines_kwargs + [ax.coastlines(**coastlines_kwargs) for ax in axes.flatten()] + + if colorbar: + if facet_grid is not None: # faceted layout (row/col given and no ax) + fig.subplots_adjust( + bottom=0.1, top=0.9, left=0.1, right=0.8, wspace=0.02, hspace=0.02 + ) + cb_ax = fig.add_axes([0.83, 0.1, 0.02, 0.8]) + else: + fig.subplots_adjust(wspace=0.25) + cb_ax = ax.inset_axes([1.05, 0, 0.02, 1]) + cbar = plt.colorbar(handles[0], cax=cb_ax, extend="both") + cbar.set_label(cbar_label or _get_var_label(ds[var_name].attrs, var_name)) + else: + cbar = None + + return fig, axes, handles, cbar, facet_grid + + +def _mappable_var( + ds: xr.Dataset, + var_name: str, + grid_metadata: GridMetadata = WRAPPER_GRID_METADATA, +): + """Converts a dataset into a format for plotting across cubed-sphere tiles by + checking and ordering its grid variable and plotting variable dimensions + + Args: + ds: + Dataset containing the variable to be plotted, along with grid variables. + var_name: + Name of variable to be plotted.
+ grid_metadata: + vcm.cubedsphere.GridMetadata object describing dim + names and grid variable names + Returns: + ds (xr.Dataset): Dataset containing variable to be plotted as well as grid + variables, all of whose dimensions are ordered for plotting. + """ + mappable_ds = xr.Dataset() + for var, dims in grid_metadata.coord_vars.items(): + mappable_ds[var] = _align_grid_var_dims(ds[var], required_dims=dims) + if isinstance(grid_metadata, GridMetadataFV3): + var_da = _align_plot_var_dims(ds[var_name], grid_metadata.y, grid_metadata.x) + return mappable_ds.merge(var_da) + elif isinstance(grid_metadata, GridMetadataScream): + return mappable_ds.merge(ds[var_name]) + + +def pcolormesh_cube( + lat: np.ndarray, lon: np.ndarray, array: np.ndarray, ax: plt.axes = None, **kwargs +): + """Plots tiled cubed sphere. This function applies nan to gridcells which cross + the antimeridian, and then iteratively plots rectangles of array which avoid nan + gridcells. This is done to avoid artifacts when plotting gridlines with the + `edgecolor` argument. In comparison to :py:func:`plot_cube`, this function takes + np.ndarrays of the lat and lon cell corners and the variable to be plotted + at cell centers, and makes only one plot on an optionally specified axes object. + + Args: + lat: + Array of latitudes with dimensions (tile, ny + 1, nx + 1). + Should be given at cell corners. + lon: + Array of longitudes with dimensions (tile, ny + 1, nx + 1). + Should be given at cell corners. + array: + Array of variables values at cell centers, of dimensions (tile, ny, nx) + ax: + Matplotlib geoaxes object onto which plotting function will be + called. Default None uses current axes. + **kwargs: + Keyword arguments to be passed to plotting function. 
+ + Returns: + p_handle (obj): + matplotlib object handle associated with a segment of the map subplot + """ + all_handles = _pcolormesh_cube_all_handles(lat, lon, array, ax=ax, **kwargs) + return all_handles[-1] + + +def _pcolormesh_cube_all_handles( + lat: np.ndarray, lon: np.ndarray, array: np.ndarray, ax: plt.axes = None, **kwargs +): + if lat.shape != lon.shape: + raise ValueError("lat and lon should have the same shape") + if ax is None: + ax = plt.gca() + central_longitude = ax.projection.proj4_params["lon_0"] + array = np.where( + _mask_antimeridian_quads(lon.T, central_longitude), array.T, np.nan + ).T + # oddly a PlateCarree transform seems to be needed here even for non-PlateCarree + # projections?? very puzzling, but it seems to be the case. + kwargs["transform"] = kwargs.get("transform", ccrs.PlateCarree()) + kwargs["vmin"] = kwargs.get("vmin", np.nanmin(array)) + kwargs["vmax"] = kwargs.get("vmax", np.nanmax(array)) + + def plot(x, y, array): + return ax.pcolormesh(x, y, array, **kwargs) + + handles = _apply_to_non_non_nan_segments( + plot, lat, center_longitudes(lon, central_longitude), array + ) + return handles + + +class UpdateablePColormesh: + def __init__(self, lat, lon, array: np.ndarray, ax: plt.axes = None, **kwargs): + self.handles = _pcolormesh_cube_all_handles(lat, lon, array, ax=ax, **kwargs) + plt.colorbar(self.handles[-1], ax=ax) + self.lat = lat + self.lon = lon + self.ax = self.handles[-1].axes  # axes actually drawn on (ax may be None) + + def update(self, array): + central_longitude = self.ax.projection.proj4_params["lon_0"] + array = np.where( + _mask_antimeridian_quads(self.lon.T, central_longitude), array.T, np.nan + ).T + + iter_handles = iter(self.handles) + + def update_handle(x, y, array): + handle = next(iter_handles) + handle.set_array(array.ravel()) + + _apply_to_non_non_nan_segments(update_handle, self.lat, self.lon, array) + + +def _apply_to_non_non_nan_segments(func, lat, lon, array): + """ + Applies func to disjoint rectangular segments of array covering all non-nan values.
+ + Args: + func: + Function to be applied to non-nan segments of array. + lat: + Array of latitudes with dimensions (tile, ny + 1, nx + 1). + Should be given at cell corners. + lon: + Array of longitudes with dimensions (tile, ny + 1, nx + 1). + Should be given at cell corners. + array: + Array of variables values at cell centers, of dimensions (tile, ny, nx) + + Returns: + list of return values of func + """ + all_handles = [] + for tile in range(array.shape[0]): + x = lon[tile, :, :] + y = lat[tile, :, :] + for x_plot, y_plot, array_plot in _segment_plot_inputs(x, y, array[tile, :, :]): + all_handles.append(func(x_plot, y_plot, array_plot)) + return all_handles + + +def _segment_plot_inputs(x, y, masked_array): + """Takes in two arrays at corners of grid cells and an array at grid cell centers + which may contain NaNs. Yields 3-tuples of rectangular segments of + these arrays which cover all non-nan points without duplicates, and don't contain + NaNs. + """ + is_nan = np.isnan(masked_array) + if np.sum(is_nan) == 0: # contiguous section, just plot it + if masked_array.size > 0: + yield (x, y, masked_array) + else: + x_nans = np.sum(is_nan, axis=1) / is_nan.shape[1] + y_nans = np.sum(is_nan, axis=0) / is_nan.shape[0] + if x_nans.max() >= y_nans.max(): # most nan-y line is in first dimension + i_split = x_nans.argmax() + if x_nans[i_split] == 1.0: # split cleanly along line + yield from _segment_plot_inputs( + x[: i_split + 1, :], + y[: i_split + 1, :], + masked_array[:i_split, :], + ) + yield from _segment_plot_inputs( + x[i_split + 1 :, :], + y[i_split + 1 :, :], + masked_array[i_split + 1 :, :], + ) + else: + # split to create segments of complete nans + # which subsequent recursive calls will split on and remove + i_start = 0 + i_end = 1 + while i_end < is_nan.shape[1]: + while ( + i_end < is_nan.shape[1] + and is_nan[i_split, i_start] == is_nan[i_split, i_end] + ): + i_end += 1 + # we have a largest-possible contiguous segment of nans/not nans +
yield from _segment_plot_inputs( + x[:, i_start : i_end + 1], + y[:, i_start : i_end + 1], + masked_array[:, i_start:i_end], + ) + i_start = i_end # start the next segment + else: + # put most nan-y line in first dimension + # so the first part of this if block catches it + yield from _segment_plot_inputs( + x.T, + y.T, + masked_array.T, + ) + + +def center_longitudes(lon_array, central_longitude): + return np.where( + lon_array < (central_longitude + 180.0) % 360.0, + lon_array, + lon_array - 360.0, + ) + + +def _validate_cube_shape(lat_shape, lon_shape, latb_shape, lonb_shape, array_shape): + if (lon_shape[-1] != 6) or (lat_shape[-1] != 6) or (array_shape[-1] != 6): + raise ValueError( + """Last axis of each array must have six elements for + cubed-sphere tiles.""" + ) + + if ( + (lon_shape[0] != lat_shape[0]) + or (lat_shape[0] != array_shape[0]) + or (lon_shape[1] != lat_shape[1]) + or (lat_shape[1] != array_shape[1]) + ): + raise ValueError( + """Horizontal axis lengths of lat and lon must be equal to + those of array.""" + ) + + if (len(lonb_shape) != 3) or (len(latb_shape) != 3) or (len(array_shape) != 3): + raise ValueError("Lonb, latb, and data_var each must be 3-dimensional.") + + if (lonb_shape[-1] != 6) or (latb_shape[-1] != 6) or (array_shape[-1] != 6): + raise ValueError( + "Tile axis of each array must have six elements for cubed-sphere tiles." 
+ ) + + if ( + (lonb_shape[0] != latb_shape[0]) + or (latb_shape[0] != (array_shape[0] + 1)) + or (lonb_shape[1] != latb_shape[1]) + or (latb_shape[1] != (array_shape[1] + 1)) + ): + raise ValueError( + """Horizontal axis lengths of latb and lonb + must be one greater than those of array.""" + ) + + if (len(lon_shape) != 3) or (len(lat_shape) != 3) or (len(array_shape) != 3): + raise ValueError("Lon, lat, and data_var each must be 3-dimensional.") + + +def _plot_cube_axes( + array: np.ndarray, + lat: np.ndarray, + lon: np.ndarray, + latb: np.ndarray, + lonb: np.ndarray, + plotting_function: str, + ax: plt.axes = None, + **kwargs, +): + """Plots tiled cubed sphere for a given subplot axis, + using np.ndarrays for all data + + Args: + array: + Array of variables values at cell centers, of dimensions (npy, npx, + tile) + lat: + Array of latitudes of cell centers, of dimensions (npy, npx, tile) + lon: + Array of longitudes of cell centers, of dimensions (npy, npx, tile) + latb: + Array of latitudes of cell edges, of dimensions (npy + 1, npx + 1, + tile) + lonb: + Array of longitudes of cell edges, of dimensions (npy + 1, npx + 1, + tile) + plotting_function: + Name of matplotlib 2-d plotting function. Available options + are "pcolormesh", "contour", and "contourf". + ax: + Matplotlib geoaxes object onto which plotting function will be + called. Default None uses current axes. + **kwargs: + Keyword arguments to be passed to plotting function. 
+ + Returns: + p_handle (obj): + matplotlib object handle associated with map subplot + """ + _validate_cube_shape(lon.shape, lat.shape, lonb.shape, latb.shape, array.shape) + + if ax is None: + ax = plt.gca() + + if plotting_function in ["pcolormesh", "contour", "contourf"]: + _plotting_function = getattr(ax, plotting_function) + else: + raise ValueError( + """Plotting functions only include pcolormesh, contour, + and contourf.""" + ) + + if "vmin" not in kwargs: + kwargs["vmin"] = np.nanmin(array) + + if "vmax" not in kwargs: + kwargs["vmax"] = np.nanmax(array) + + if np.isnan(kwargs["vmin"]): + kwargs["vmin"] = -0.1 + if np.isnan(kwargs["vmax"]): + kwargs["vmax"] = 0.1 + + if plotting_function != "pcolormesh": + if "levels" not in kwargs: + kwargs["n_levels"] = 11 if "n_levels" not in kwargs else kwargs["n_levels"] + kwargs["levels"] = np.linspace( + kwargs["vmin"], kwargs["vmax"], kwargs["n_levels"] + ) + + central_longitude = ax.projection.proj4_params["lon_0"] + + masked_array = np.where( + _mask_antimeridian_quads(lonb, central_longitude), array, np.nan + ) + + for tile in range(6): + if plotting_function == "pcolormesh": + x = lonb[:, :, tile] + y = latb[:, :, tile] + else: + # contouring + x = center_longitudes(lon[:, :, tile], central_longitude) + y = lat[:, :, tile] + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + p_handle = _plotting_function( + x, y, masked_array[:, :, tile], transform=ccrs.PlateCarree(), **kwargs + ) + + ax.set_global() + + return p_handle + + +def _plot_scream_axes( + array: np.ndarray, + lat: np.ndarray, + lon: np.ndarray, + plotting_function: str, + ax: plt.axes = None, + **kwargs, +): + if ax is None: + ax = plt.gca() + if plotting_function in ["pcolormesh", "contour", "contourf"]: + mapping = { + "pcolormesh": "tripcolor", + "contour": "tricontour", + "contourf": "tricontourf", + } + _plotting_function = getattr(ax, mapping[plotting_function]) + else: + raise ValueError( + """Plotting functions only include 
pcolormesh, contour, + and contourf.""" + ) + if "vmin" not in kwargs: + kwargs["vmin"] = np.nanmin(array) + + if "vmax" not in kwargs: + kwargs["vmax"] = np.nanmax(array) + + if np.isnan(kwargs["vmin"]): + kwargs["vmin"] = -0.1 + if np.isnan(kwargs["vmax"]): + kwargs["vmax"] = 0.1 + + if plotting_function != "pcolormesh": + if "levels" not in kwargs: + kwargs["n_levels"] = 11 if "n_levels" not in kwargs else kwargs["n_levels"] + kwargs["levels"] = np.linspace( + kwargs["vmin"], kwargs["vmax"], kwargs["n_levels"] + ) + lon = np.where(lon > 180, lon - 360, lon) + p_handle = _plotting_function( + lon.flatten(), + lat.flatten(), + array.flatten(), + transform=ccrs.PlateCarree(), + **kwargs, + ) + ax.set_global() + return p_handle diff --git a/ndsl/viz/fv3/_plot_diagnostics.py b/ndsl/viz/fv3/_plot_diagnostics.py new file mode 100644 index 00000000..9c102759 --- /dev/null +++ b/ndsl/viz/fv3/_plot_diagnostics.py @@ -0,0 +1,125 @@ +""" +Some helper functions for creating diagnostic plots. + +These are specifically for usage in fv3net. + +Uses the general purpose plotting functions in +fv3viz such as plot_cube. 
+ + +""" +import os + +import matplotlib.pyplot as plt +import numpy as np +import xarray as xr +from scipy.stats import binned_statistic + +from ._constants import COORD_X_CENTER, COORD_Y_CENTER, INIT_TIME_DIM + + +STACK_DIMS = ["tile", INIT_TIME_DIM, COORD_X_CENTER, COORD_Y_CENTER] + + +def _mask_nan_lines(x, y): + nan_mask = np.isfinite(y) + return np.array(x)[nan_mask], np.array(y)[nan_mask] + + +def plot_diurnal_cycle( + merged_ds, var, stack_dims=STACK_DIMS, num_time_bins=24, title=None, ylabel=None +): + """ + + Args: + merged_ds (xr.dataset): + can either provide a merged dataset with a "dataset" dim + that will be used to plot separate lines for each variable, or a + single dataset with no "dataset" dim + var (str): + name of variable to plot + num_time_bins (int): + number of bins per day + title(str): + optional plot title + + Returns: + matplotlib figure + """ + plt.clf() + fig = plt.figure() + if "dataset" not in merged_ds.dims: + merged_ds = xr.concat([merged_ds], "dataset") + for label in merged_ds["dataset"].values: + # TODO this function mixes computation, plotting, and implicitly + # I/O via deferred dask calculations. + # and should be extensively refactored. 
+ ds = merged_ds.sel(dataset=label) + if len([dim for dim in ds.dims if dim in stack_dims]) > 1: + ds = ds.stack(sample=stack_dims).dropna("sample") + local_time = ds["local_time"].values.flatten() + data_var = ds[var].values.flatten() + bin_means, bin_edges, _ = binned_statistic( + local_time, data_var, bins=num_time_bins + ) + bin_centers = [ + 0.5 * (bin_edges[i] + bin_edges[i + 1]) for i in range(num_time_bins) + ] + bin_centers, bin_means = _mask_nan_lines(bin_centers, bin_means) + plt.plot(bin_centers, bin_means, label=label) + plt.xlabel("local_time [hr]") + plt.ylabel(ylabel or var) + plt.legend(loc="lower left") + if title: + plt.title(title) + return fig + + +# function below here are from the previous design and probably outdated +# leaving for now as it might be adapted to work with new design + + +def plot_time_series( + ds, + vars_to_plot, + output_dir, + plot_filename="time_series.png", + time_var=INIT_TIME_DIM, + xlabel=None, + ylabel=None, + title=None, +): + """Plot one or more variables as a time series. 
+ + Args: + ds (xr.dataset): + dataset containing time series variables to plot + vars_to_plot(list[str]): + data variables to plot + output_dir (str): + output directory to save figure into + plot_filename (str): + filename to save figure to + time_var (str): + name of time dimension + xlabel (str): + x axis label + ylabel (str): + y axis label + title (str): + plot title + Returns: + matplotlib figure + """ + plt.clf() + for var in vars_to_plot: + time = ds[time_var].values + plt.plot(time, ds[var].values, label=var) + if xlabel: + plt.xlabel(xlabel) + if ylabel: + plt.ylabel(ylabel) + plt.legend() + if title: + plt.title(title) + plt.savefig(os.path.join(output_dir, plot_filename)) diff --git a/ndsl/viz/fv3/_plot_helpers.py b/ndsl/viz/fv3/_plot_helpers.py new file mode 100644 index 00000000..75da6983 --- /dev/null +++ b/ndsl/viz/fv3/_plot_helpers.py @@ -0,0 +1,172 @@ +import textwrap +from typing import Optional, Tuple + +import numpy as np + + +def _align_grid_var_dims(da, required_dims): + missing_dims = set(required_dims).difference(da.dims) + if len(missing_dims) > 0: + raise ValueError( + f"Grid variable {da.name} missing dims {missing_dims}. " + "Incompatible grid metadata may have been passed." + ) + redundant_dims = set(da.dims).difference(required_dims) + if len(redundant_dims) == 0: + da_out = da.transpose(*required_dims) + else: + redundant_dims_index = {dim: 0 for dim in redundant_dims} + da_out = ( + da.isel(redundant_dims_index) + .drop_vars(redundant_dims, errors="ignore") + .transpose(*required_dims) + ) + return da_out + + +def _align_plot_var_dims(da, coord_y_center, coord_x_center): + first_dims = [coord_y_center, coord_x_center, "tile"] + missing_dims = set(first_dims).difference(set(da.dims)) + if len(missing_dims) > 0: + raise ValueError( + f"Data array to be plotted {da.name} missing dims {missing_dims}. " + "Incompatible grid metadata may have been passed." 
+ ) + rest = [dim for dim in da.dims if dim not in first_dims] + xpose_dims = first_dims + rest + return da.transpose(*xpose_dims) + + +def _min_max_from_percentiles(x, min_percentile=2, max_percentile=98): + """Use +/- small percentile to determine bounds for colorbar. Avoids the case + where an outlier in the data causes the color scale to be washed out. + + Args: + x: array of data values + min_percentile: lower percentile to use instead of absolute min + max_percentile: upper percentile to use instead of absolute max + + Returns: + Tuple of values at min_percentile, max_percentile + """ + x = np.array(x).flatten() + x = x[~np.isnan(x)] + if len(x) == 0: + # all values of x are equal to np.nan + xmin, xmax = np.nan, np.nan + else: + xmin, xmax = np.percentile(x, [min_percentile, max_percentile]) + return xmin, xmax + + +def _infer_color_limits( + xmin: float, xmax: float, vmin: float = None, vmax: float = None, cmap: str = None +): + """Auto-magical handling of color limits and colormap if not supplied by + user + + Args: + xmin (float): + Smallest value in data to be plotted + xmax (float): + Largest value in data to be plotted + vmin (float, optional): + Colormap minimum value. Default None. + vmax (float, optional): + Colormap maximum value. Default None. + cmap (str, optional): + Name of colormap. Default None. + + Returns: + vmin (float) + Inferred colormap minimum value if not supplied, or user value if + supplied. + vmax (float) + Inferred colormap maximum value if not supplied, or user value if + supplied. + cmap (str) + Inferred colormap if not supplied, or user value if supplied.
+ + Example: + # choose limits and cmap for data spanning 0 + >>>> _infer_color_limits(-10, 20) + (-20, 20, 'RdBu_r') + """ + if vmin is None and vmax is None: + if xmin < 0 and xmax > 0: + cmap = "RdBu_r" if not cmap else cmap + vabs_max = np.max([np.abs(xmin), np.abs(xmax)]) + vmin, vmax = (-vabs_max, vabs_max) + else: + vmin, vmax = xmin, xmax + cmap = "viridis" if not cmap else cmap + elif vmin is None: + if xmin < 0 and vmax > 0: + vmin = -vmax + cmap = "RdBu_r" if not cmap else cmap + else: + vmin = xmin + cmap = "viridis" if not cmap else cmap + elif vmax is None: + if xmax > 0 and vmin < 0: + vmax = -vmin + cmap = "RdBu_r" if not cmap else cmap + else: + vmax = xmax + cmap = "viridis" if not cmap else cmap + elif not cmap: + cmap = "RdBu_r" if vmin == -vmax else "viridis" + + return vmin, vmax, cmap + + +def _get_var_label(attrs: dict, var_name: str, max_line_length: int = 30): + """Get the label for the variable on the colorbar + + Args: + attrs (dict): + Variable aattribute dict + var_name (str): + Short name of variable + max_line_length (int, optional): + Max number of characters on each line of returned label. + Defaults to 30. + + Returns: + var_label (str) + long_name [units], var_name [units] or var_name depending on attrs + """ + if "long_name" in attrs: + var_label = attrs["long_name"] + else: + var_label = var_name + if "units" in attrs: + var_label += f" [{attrs['units']}]" + return "\n".join(textwrap.wrap(var_label, max_line_length)) + + +def infer_cmap_params( + data: np.ndarray, + vmin: Optional[float] = None, + vmax: Optional[float] = None, + cmap: Optional[str] = None, + robust: bool = False, +) -> Tuple[float, float, str]: + """Determine useful colorbar limits and cmap for given data. + + Args: + data: The data to be plotted. + vmin: Optional minimum for colorbar. + vmax: Optional maximum for colorbar. + cmap: Optional colormap to use. + robust: If true, use 2nd and 98th percentiles for colorbar limits. 
+ + Returns: + Tuple of (vmin, vmax, cmap). + """ + if robust: + xmin, xmax = _min_max_from_percentiles(data) + else: + xmin, xmax = np.nanmin(data), np.nanmax(data) + vmin, vmax, cmap = _infer_color_limits(xmin, xmax, vmin, vmax, cmap) + return vmin, vmax, cmap diff --git a/ndsl/viz/fv3/_styles.py b/ndsl/viz/fv3/_styles.py new file mode 100644 index 00000000..2eb221ea --- /dev/null +++ b/ndsl/viz/fv3/_styles.py @@ -0,0 +1,18 @@ +import matplotlib.pyplot as plt +from cycler import cycler + + +# adapted from https://davidmathlogic.com/colorblind +wong_palette = [ + "#56B4E9", + "#E69F00", + "#009E73", + "#0072B2", + "#D55E00", + "#CC79A7", + "#F0E442", # put yellow last, remove black +] + + +def use_colorblind_friendly_style(): + plt.rcParams["axes.prop_cycle"] = cycler("color", wong_palette) diff --git a/ndsl/viz/fv3/_timestep_histograms.py b/ndsl/viz/fv3/_timestep_histograms.py new file mode 100644 index 00000000..38985478 --- /dev/null +++ b/ndsl/viz/fv3/_timestep_histograms.py @@ -0,0 +1,36 @@ +import datetime +from typing import Sequence, Union + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from matplotlib.axes import Axes + + +def plot_daily_and_hourly_hist( + time_list: Sequence[Union[datetime.datetime, np.datetime64]], +) -> plt.figure: + """Given a sequence of datetimes (anything that can be handled by pandas) create + and return 2-subplot figure with histograms of daily and hourly counts.""" + fig, axes = plt.subplots(1, 2, figsize=(8, 3)) + plot_daily_hist(axes[0], time_list) + plot_hourly_hist(axes[1], time_list) + fig.suptitle(f"total count: {len(time_list)}") + plt.tight_layout() + return fig + + +def plot_daily_hist(ax: Axes, time_list: Sequence[datetime.datetime]): + """Given list of datetimes, plot histogram of count per calendar day on ax""" + ser = pd.Series(time_list) + groupby_list = [ser.dt.year, ser.dt.month, ser.dt.day] + ser.groupby(groupby_list).count().plot(ax=ax, kind="bar", title="Daily count") + 
ax.set_ylabel("Count") + + +def plot_hourly_hist(ax: Axes, time_list: Sequence[datetime.datetime]): + """Given list of datetimes, plot histogram of count per UTC hour on ax""" + ser = pd.Series(time_list) + ser.groupby(ser.dt.hour).count().plot(ax=ax, kind="bar", title="Hourly count") + ax.set_ylabel("Count") + ax.set_xlabel("UTC hour") diff --git a/ndsl/viz/fv3/grid_metadata.py b/ndsl/viz/fv3/grid_metadata.py new file mode 100644 index 00000000..2171360e --- /dev/null +++ b/ndsl/viz/fv3/grid_metadata.py @@ -0,0 +1,47 @@ +import abc +import dataclasses + + +class GridMetadata(abc.ABC): + @property + @abc.abstractmethod + def coord_vars(self) -> dict: + ... + + +@dataclasses.dataclass +class GridMetadataFV3(GridMetadata): + x: str = "x" + y: str = "y" + x_interface: str = "x_interface" + y_interface: str = "y_interface" + tile: str = "tile" + lon: str = "lon" + lonb: str = "lonb" + lat: str = "lat" + latb: str = "latb" + + @property + def coord_vars(self): + coord_vars = { + self.lonb: [self.y_interface, self.x_interface, self.tile], + self.latb: [self.y_interface, self.x_interface, self.tile], + self.lon: [self.y, self.x, self.tile], + self.lat: [self.y, self.x, self.tile], + } + return coord_vars + + +@dataclasses.dataclass +class GridMetadataScream(GridMetadata): + ncol: str = "ncol" + lon: str = "lon" + lat: str = "lat" + + @property + def coord_vars(self): + coord_vars = { + self.lon: [self.ncol], + self.lat: [self.ncol], + } + return coord_vars diff --git a/setup.py b/setup.py index e9f1c2e6..453c4e07 100644 --- a/setup.py +++ b/setup.py @@ -38,6 +38,7 @@ def local_pkg(name: str, relative_path: str) -> str: "dask", # for xarray "numpy==1.26.4", "matplotlib", # for plotting in boilerplate + "cartopy", # for plotting in ndsl.viz ] @@ -59,7 +60,7 @@ def local_pkg(name: str, relative_path: str) -> str: packages=find_namespace_packages(include=["ndsl", "ndsl.*"]), include_package_data=True, url="https://github.com/NOAA-GFDL/NDSL", - version="2025.03.00", + 
version="2025.05.00", zip_safe=False, entry_points={ "console_scripts": [ diff --git a/tests/checkpointer/test_snapshot.py b/tests/checkpointer/test_snapshot.py index 89d368ec..f9806528 100644 --- a/tests/checkpointer/test_snapshot.py +++ b/tests/checkpointer/test_snapshot.py @@ -1,20 +1,14 @@ import numpy as np -import pytest +import xarray as xr from ndsl.checkpointer import SnapshotCheckpointer -from ndsl.optional_imports import xarray as xr -requires_xarray = pytest.mark.skipif(xr is None, reason="xarray is not installed") - - -@requires_xarray def test_snapshot_checkpointer_no_data(): checkpointer = SnapshotCheckpointer(rank=0) xr.testing.assert_identical(checkpointer.dataset, xr.Dataset()) -@requires_xarray def test_snapshot_checkpointer_one_snapshot(): checkpointer = SnapshotCheckpointer(rank=0) val1 = np.random.randn(2, 3, 4) @@ -33,7 +27,6 @@ def test_snapshot_checkpointer_one_snapshot(): ) -@requires_xarray def test_snapshot_checkpointer_multiple_snapshots(): checkpointer = SnapshotCheckpointer(rank=0) val1 = np.random.randn(2, 2, 3, 4) diff --git a/tests/checkpointer/test_validation.py b/tests/checkpointer/test_validation.py index 0c08d52b..5e1b90a3 100644 --- a/tests/checkpointer/test_validation.py +++ b/tests/checkpointer/test_validation.py @@ -3,13 +3,10 @@ import numpy as np import pytest +import xarray as xr from ndsl.checkpointer import SavepointThresholds, Threshold, ValidationCheckpointer from ndsl.checkpointer.validation import _clip_pace_array_to_target -from ndsl.optional_imports import xarray as xr - - -requires_xarray = pytest.mark.skipif(xr is None, reason="xarray is not installed") def get_dataset( @@ -30,7 +27,6 @@ def get_dataset( return xr.Dataset(data_vars=data_vars) -@requires_xarray def test_validation_validates_onevar_onecall(): temp_dir = tempfile.TemporaryDirectory() nx_compute = 12 @@ -65,7 +61,6 @@ def test_validation_validates_onevar_onecall(): pytest.param(1.0, 0.99, id="absolute_failure"), ], ) -@requires_xarray def 
test_validation_asserts_onevar_onecall(relative_threshold, absolute_threshold): temp_dir = tempfile.TemporaryDirectory() nx_compute = 12 @@ -110,7 +105,6 @@ def test_validation_asserts_onevar_onecall(relative_threshold, absolute_threshol pytest.param(1.0, 0.99, id="absolute_threshold"), ], ) -@requires_xarray def test_validation_passes_onevar_two_calls(relative_threshold, absolute_threshold): temp_dir = tempfile.TemporaryDirectory() nx_compute = 12 @@ -161,7 +155,6 @@ def test_validation_passes_onevar_two_calls(relative_threshold, absolute_thresho pytest.param(1.0, 0.99, id="absolute_failure"), ], ) -@requires_xarray def test_validation_asserts_onevar_two_calls(relative_threshold, absolute_threshold): temp_dir = tempfile.TemporaryDirectory() nx_compute = 12 @@ -214,7 +207,6 @@ def test_validation_asserts_onevar_two_calls(relative_threshold, absolute_thresh pytest.param(1.0, 0.99, id="absolute_failure"), ], ) -@requires_xarray def test_validation_asserts_twovar_onecall(relative_threshold, absolute_threshold): temp_dir = tempfile.TemporaryDirectory() nx_compute = 12 diff --git a/tests/conftest.py b/tests/conftest.py index b2ebcd30..761e773f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,13 +3,12 @@ try: - import gt4py -except ModuleNotFoundError: - gt4py = None -try: - import cupy -except ModuleNotFoundError: - cupy = None + import ndsl.dsl # noqa: F401 +except ModuleNotFoundError as error: + error.msg = f"NDSL cannot be loaded because {error.msg}" + raise error + +from ndsl.optional_imports import cupy @pytest.fixture(params=["numpy", "cupy"]) @@ -17,10 +16,7 @@ def backend(request): if cupy is None and request.param.endswith("cupy"): if request.config.getoption("--gpu-only"): raise ModuleNotFoundError("cupy must be installed to run gpu tests") - else: - pytest.skip("cupy is not available for GPU backend") - elif gt4py is None and request.param.startswith("gt4py"): - pytest.skip("gt4py backend is not available") + pytest.skip("cupy is not available for 
GPU backend") elif request.config.getoption("--gpu-only") and not request.param.endswith("cupy"): pytest.skip("running gpu tests only") else: diff --git a/tests/dsl/test_caches.py b/tests/dsl/test_caches.py index 768238e2..d550c123 100644 --- a/tests/dsl/test_caches.py +++ b/tests/dsl/test_caches.py @@ -1,5 +1,8 @@ +import os +import shutil + import pytest -from gt4py.cartesian.gtscript import PARALLEL, Field, computation, interval +from gt4py.cartesian import config as gt_config from gt4py.storage import empty, ones from ndsl import ( @@ -12,6 +15,7 @@ ) from ndsl.comm.mpi import MPI from ndsl.dsl.dace.orchestration import orchestrate +from ndsl.dsl.gt4py import PARALLEL, Field, computation, interval def _make_storage( @@ -83,11 +87,6 @@ def __call__(self): MPI is not None, reason="relocatibility checked with a one-rank setup" ) def test_relocatability_orchestration(backend): - import os - import shutil - - from gt4py.cartesian import config as gt_config - original_root_directory = gt_config.cache_settings["root_path"] working_dir = str(os.getcwd()) @@ -142,14 +141,8 @@ def test_relocatability_orchestration(backend): MPI is not None, reason="relocatibility checked with a one-rank setup" ) def test_relocatability(backend: str): - import os - import shutil - - import gt4py - from gt4py.cartesian import config as gt_config - # Restore original dir name - gt4py.cartesian.config.cache_settings["dir_name"] = os.environ.get( + gt_config.cache_settings["dir_name"] = os.environ.get( "GT_CACHE_DIR_NAME", f".gt_cache_{MPI.COMM_WORLD.Get_rank():06}" ) diff --git a/tests/dsl/test_skip_passes.py b/tests/dsl/test_skip_passes.py index e0173b7b..22b840cb 100644 --- a/tests/dsl/test_skip_passes.py +++ b/tests/dsl/test_skip_passes.py @@ -5,7 +5,6 @@ HorizontalExecutionMerging, ) from gt4py.cartesian.gtc.passes.oir_pipeline import DefaultPipeline -from gt4py.cartesian.gtscript import PARALLEL, computation, interval from ndsl import ( CompilationConfig, @@ -15,6 +14,7 @@ 
StencilFactory, ) from ndsl.constants import X_DIM, Y_DIM, Z_DIM +from ndsl.dsl.gt4py import PARALLEL, computation, interval from ndsl.dsl.typing import FloatField diff --git a/tests/dsl/test_stencil.py b/tests/dsl/test_stencil.py index 180b7ba2..5348f346 100644 --- a/tests/dsl/test_stencil.py +++ b/tests/dsl/test_stencil.py @@ -1,7 +1,7 @@ -from gt4py.cartesian.gtscript import PARALLEL, Field, computation, interval from gt4py.storage import empty, ones from ndsl import CompilationConfig, GridIndexing, StencilConfig, StencilFactory +from ndsl.dsl.gt4py import PARALLEL, Field, computation, interval def _make_storage( diff --git a/tests/dsl/test_stencil_factory.py b/tests/dsl/test_stencil_factory.py index 2af1218d..65bf1cf2 100644 --- a/tests/dsl/test_stencil_factory.py +++ b/tests/dsl/test_stencil_factory.py @@ -1,6 +1,5 @@ import numpy as np import pytest -from gt4py.cartesian.gtscript import PARALLEL, computation, horizontal, interval, region from ndsl import ( CompilationConfig, @@ -11,6 +10,7 @@ StencilFactory, ) from ndsl.constants import X_DIM, Y_DIM, Z_DIM +from ndsl.dsl.gt4py import PARALLEL, computation, horizontal, interval, region from ndsl.dsl.gt4py_utils import make_storage_from_shape from ndsl.dsl.stencil import CompareToNumpyStencil, get_stencils_with_varied_bounds from ndsl.dsl.typing import FloatField diff --git a/tests/dsl/test_stencil_wrapper.py b/tests/dsl/test_stencil_wrapper.py index 986883dc..94ea894c 100644 --- a/tests/dsl/test_stencil_wrapper.py +++ b/tests/dsl/test_stencil_wrapper.py @@ -4,7 +4,6 @@ import gt4py.cartesian.gtscript import numpy as np import pytest -from gt4py.cartesian.gtscript import PARALLEL, computation, interval from ndsl import ( CompilationConfig, @@ -14,6 +13,7 @@ Quantity, StencilConfig, ) +from ndsl.dsl.gt4py import PARALLEL, computation, interval from ndsl.dsl.gt4py_utils import make_storage_from_shape from ndsl.dsl.stencil import _convert_quantities_to_storage from ndsl.dsl.typing import Float, FloatField diff 
--git a/tests/mpi/mpi_comm.py b/tests/mpi/mpi_comm.py index a799f761..0052f777 100644 --- a/tests/mpi/mpi_comm.py +++ b/tests/mpi/mpi_comm.py @@ -1,8 +1,6 @@ -try: - from mpi4py import MPI -except ImportError: - MPI = None +from mpi4py import MPI + -if MPI is not None and MPI.COMM_WORLD.Get_size() == 1: +if MPI.COMM_WORLD.Get_size() == 1: # not run as a parallel test, disable MPI tests MPI = None diff --git a/tests/quantity/test_fieldbundle.py b/tests/quantity/test_fieldbundle.py new file mode 100644 index 00000000..f30d81c1 --- /dev/null +++ b/tests/quantity/test_fieldbundle.py @@ -0,0 +1,77 @@ +from ndsl.boilerplate import get_factories_single_tile +from ndsl.constants import X_DIM, Y_DIM, Z_DIM +from ndsl.dsl.gt4py import PARALLEL, computation, interval +from ndsl.dsl.typing import FloatField +from ndsl.quantity.field_bundle import FieldBundle, FieldBundleType + + +def assign_4d_field_stcl(field_4d: FieldBundleType.T("Tracers")): # type: ignore # noqa + with computation(PARALLEL), interval(...): + field_4d[0, 0, 0][1] = 63.63 + field_4d[0, 0, 0][3] = 63.63 + + +def assign_3d_field_stcl(field_3d: FloatField): + with computation(PARALLEL), interval(...): + field_3d = 121.121 + + +def test_field_bundle(): + # Grid & Factories + NX = 2 + NY = 2 + NZ = 2 + N4th = 5 + stencil_factory, quantity_factory = get_factories_single_tile(NX, NY, NZ, 1) + + # Type register + FieldBundleType.register("Tracers", (N4th,)) + + # Make stencils + assign_4d_field = stencil_factory.from_dims_halo( + func=assign_4d_field_stcl, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + assign_3d_field = stencil_factory.from_dims_halo( + func=assign_3d_field_stcl, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + + # "Input" data + new_quantity_factory = FieldBundle.extend_3D_quantity_factory( + quantity_factory, {"tracers": N4th} + ) + data = new_quantity_factory.ones([X_DIM, Y_DIM, Z_DIM, "tracers"], units="kg/g") + + # Build Bundle + tracers = FieldBundle( + bundle_name="tracers", + quantity=data, + 
mapping={"vapor": 0, "cloud": 2}, + ) + + # Test + tracers.quantity.field[:, :, :, :] = 48.4 + tracers.quantity.field[:, :, :, 2] = 21.21 + + assign_4d_field(tracers.quantity) + + assert (tracers.quantity.field[:, :, :, 0] == 48.4).all() + assert (tracers.quantity.field[:, :, :, 1] == 63.63).all() + assert (tracers.quantity.field[:, :, :, 2] == 21.21).all() + assert (tracers.quantity.field[:, :, :, 3] == 63.63).all() + assert (tracers.quantity.field[:, :, :, 4] == 48.4).all() + + tracers.vapor.field[:] = 1000.1000 + + assert (tracers.quantity.field[:, :, :, 0] == 1000.1000).all() + assert (tracers.quantity.field[:, :, :, 1] == 63.63).all() + + assign_3d_field(tracers.cloud) + + assert (tracers.cloud.field[:] == 121.121).all() + assert (tracers.quantity.field[:, :, :, 2] == tracers.cloud.field[:]).all() + + +if __name__ == "__main__": + test_field_bundle() diff --git a/tests/quantity/test_quantity.py b/tests/quantity/test_quantity.py index 61e92025..2b4954d1 100644 --- a/tests/quantity/test_quantity.py +++ b/tests/quantity/test_quantity.py @@ -5,14 +5,6 @@ from ndsl.quantity.bounds import _shift_slice -try: - import xarray as xr -except ModuleNotFoundError: - xr = None - -requires_xarray = pytest.mark.skipif(xr is None, reason="xarray is not installed") - - @pytest.fixture(params=["empty", "one", "five"]) def extent_1d(request, backend, n_halo): if request.param == "empty": @@ -260,13 +252,12 @@ def test_shift_slice(slice_in, shift, extent, slice_out): ), ], ) -@requires_xarray def test_to_data_array(quantity): - assert quantity.data_array.attrs == quantity.attrs - assert quantity.data_array.dims == quantity.dims - assert quantity.data_array.shape == quantity.extent - np.testing.assert_array_equal(quantity.data_array.values, quantity.view[:]) + assert quantity.field_as_xarray.attrs == quantity.attrs + assert quantity.field_as_xarray.dims == quantity.dims + assert quantity.field_as_xarray.shape == quantity.extent + 
np.testing.assert_array_equal(quantity.field_as_xarray.values, quantity.view[:]) if quantity.extent == quantity.data.shape: assert ( - quantity.data_array.data.ctypes.data == quantity.data.ctypes.data + quantity.field_as_xarray.data.ctypes.data == quantity.data.ctypes.data ), "data memory address is not equal" diff --git a/tests/quantity/test_storage.py b/tests/quantity/test_storage.py index 2cdb8d49..bc39f61f 100644 --- a/tests/quantity/test_storage.py +++ b/tests/quantity/test_storage.py @@ -2,16 +2,7 @@ import pytest from ndsl import Quantity - - -try: - import gt4py -except ImportError: - gt4py = None -try: - import cupy as cp -except ImportError: - cp = None +from ndsl.optional_imports import cupy as cp @pytest.fixture @@ -72,8 +63,7 @@ def test_numpy(quantity, backend): assert quantity.np is np -@pytest.mark.skipif(gt4py is None, reason="requires gt4py") -def test_modifying_numpy_data_modifies_view(): +def test_modifying_numpy_data_modifies_view_and_field(): shape = (6, 6) data = np.zeros(shape, dtype=float) quantity = Quantity( @@ -91,11 +81,38 @@ def test_modifying_numpy_data_modifies_view(): assert quantity.view[0, 0] == 1 assert quantity.view[2, 2] == 5 assert quantity.view[4, 4] == 3 + assert quantity.field[0, 0] == 1 + assert quantity.field[2, 2] == 5 + assert quantity.field[4, 4] == 3 assert quantity.data[0, 0] == 1 assert quantity.data[2, 2] == 5 assert quantity.data[4, 4] == 3 +def test_data_and_field_access_right_full_array_and_compute_domain(): + """Test halo read/write align with data (full array) and field (compute domain)""" + shape = (6, 6) + data = np.zeros(shape, dtype=float) + quantity = Quantity( + data, + origin=(1, 1), + extent=(5, 5), + dims=["dim1", "dim2"], + units="units", + gt4py_backend="numpy", + ) + assert np.all(quantity.data == 0) + # Write compute domain - test data is written with the offset + quantity.field[:] = 11.11 + assert np.all(quantity.field == 11.11) + assert np.all(quantity.data[1:-1, 1:-1] == 11.11) + assert 
np.all(quantity.data[0:1, 0:1] == 0) + # Write halo and test field has been left untouched + quantity.data[0:1, 0:1] = 33 + assert np.all(quantity.data[0:1, 0:1] == 33) + assert np.all(quantity.field == 11.11) + + @pytest.mark.parametrize("backend", ["numpy", "cupy"], indirect=True) def test_data_exists(quantity, backend): if "numpy" in backend: @@ -104,6 +121,14 @@ def test_data_exists(quantity, backend): assert isinstance(quantity.data, cp.ndarray) +@pytest.mark.parametrize("backend", ["numpy", "cupy"], indirect=True) +def test_field_exists(quantity, backend): + if "numpy" in backend: + assert isinstance(quantity.field, np.ndarray) + else: + assert isinstance(quantity.field, cp.ndarray) + + @pytest.mark.parametrize("backend", ["numpy", "cupy"], indirect=True) def test_accessing_data_does_not_break_view( data, origin, extent, dims, units, gt4py_backend @@ -118,6 +143,7 @@ def test_accessing_data_does_not_break_view( ) quantity.data[origin] = -1.0 assert quantity.data[origin] == quantity.view[tuple(0 for _ in origin)] + assert quantity.data[origin] == quantity.field[tuple(0 for _ in origin)] # run using cupy backend even though unused, to mark this as a "gpu" test diff --git a/tests/test_boilerplate.py b/tests/test_boilerplate.py index c0211fb3..574163b5 100644 --- a/tests/test_boilerplate.py +++ b/tests/test_boilerplate.py @@ -1,8 +1,8 @@ import numpy as np -from gt4py.cartesian.gtscript import PARALLEL, computation, interval from ndsl import QuantityFactory, StencilFactory from ndsl.constants import X_DIM, Y_DIM, Z_DIM +from ndsl.dsl.gt4py import PARALLEL, computation, interval from ndsl.dsl.typing import FloatField diff --git a/tests/test_cube_scatter_gather.py b/tests/test_cube_scatter_gather.py index aee22533..75be7acf 100644 --- a/tests/test_cube_scatter_gather.py +++ b/tests/test_cube_scatter_gather.py @@ -23,12 +23,6 @@ from ndsl.performance import Timer -try: - import gt4py -except ImportError: - gt4py = None - - @pytest.fixture(params=[(1, 1), (3, 3)]) 
def layout(request): return request.param diff --git a/tests/test_g2g_communication.py b/tests/test_g2g_communication.py index 17a58785..40595669 100644 --- a/tests/test_g2g_communication.py +++ b/tests/test_g2g_communication.py @@ -16,15 +16,10 @@ TilePartitioner, ) from ndsl.constants import X_DIM, Y_DIM, Z_DIM +from ndsl.optional_imports import cupy as cp from ndsl.performance import Timer -try: - import cupy as cp -except ModuleNotFoundError: - cp = None - - @pytest.fixture(params=[(1, 1), (3, 3)]) def layout(request, fast): if fast and request.param == (1, 1): diff --git a/tests/test_legacy_restart.py b/tests/test_legacy_restart.py index 2034c04c..65514747 100644 --- a/tests/test_legacy_restart.py +++ b/tests/test_legacy_restart.py @@ -2,14 +2,9 @@ import tempfile import cftime - - -try: - import xarray as xr -except ModuleNotFoundError: - xr = None import numpy as np import pytest +import xarray as xr import ndsl.io as io from ndsl import ( @@ -28,8 +23,6 @@ ) -requires_xarray = pytest.mark.skipif(xr is None, reason="xarray is not installed") - TEST_DIRECTORY = os.path.dirname(os.path.realpath(__file__)) DATA_DIRECTORY = os.path.join(TEST_DIRECTORY, "data") @@ -39,7 +32,6 @@ def layout(request): return request.param -@requires_xarray def get_c12_restart_state_list(layout, only_names, tracer_properties): total_ranks = 6 * layout[0] * layout[1] shared_buffer = {} @@ -65,7 +57,6 @@ def get_c12_restart_state_list(layout, only_names, tracer_properties): @pytest.mark.parametrize("layout", [(1, 1), (3, 3)]) @pytest.mark.cpu_only -@requires_xarray def test_open_c12_restart(layout): tracer_properties = {} only_names = None @@ -133,7 +124,6 @@ def test_open_c12_restart(layout): }, ], ) -@requires_xarray @pytest.mark.cpu_only def test_open_c12_restart_tracer_properties(layout, tracer_properties): only_names = None @@ -150,7 +140,6 @@ def test_open_c12_restart_tracer_properties(layout, tracer_properties): @pytest.mark.parametrize("layout", [(1, 1), (3, 3)]) 
@pytest.mark.cpu_only -@requires_xarray def test_open_c12_restart_empty_to_state_without_crashing(layout): total_ranks = 6 * layout[0] * layout[1] ny = 12 / layout[0] @@ -193,7 +182,6 @@ def test_open_c12_restart_empty_to_state_without_crashing(layout): @pytest.mark.parametrize("layout", [(1, 1), (3, 3)]) @pytest.mark.cpu_only -@requires_xarray def test_open_c12_restart_to_allocated_state_without_crashing(layout): total_ranks = 6 * layout[0] * layout[1] ny = 12 / layout[0] @@ -288,7 +276,6 @@ def result_dims(data_array, new_dims): @pytest.mark.cpu_only -@requires_xarray def test_apply_dims(data_array, new_dims, result_dims): result = _apply_dims(data_array, new_dims) np.testing.assert_array_equal(result.values, data_array.values) @@ -381,7 +368,6 @@ def test_get_rank_suffix_invalid_total_ranks(invalid_total_ranks): @pytest.mark.cpu_only -@requires_xarray def test_read_state_incorrectly_encoded_time(): with tempfile.NamedTemporaryFile(mode="w", suffix=".nc") as file: state_ds = xr.DataArray(0.0, name="time").to_dataset() @@ -391,7 +377,6 @@ def test_read_state_incorrectly_encoded_time(): @pytest.mark.cpu_only -@requires_xarray def test_read_state_non_scalar_time(): with tempfile.NamedTemporaryFile(mode="w", suffix=".nc") as file: state_ds = xr.DataArray([0.0, 1.0], dims=["T"], name="time").to_dataset() @@ -405,7 +390,6 @@ def test_read_state_non_scalar_time(): [["time", "air_temperature"], ["air_temperature"]], ids=lambda x: f"{x}", ) -@requires_xarray def test_open_c12_restart_only_names(layout, only_names): tracer_properties = {} c12_restart_state_list = get_c12_restart_state_list( diff --git a/tests/test_netcdf_monitor.py b/tests/test_netcdf_monitor.py index 7a21dd78..19614486 100644 --- a/tests/test_netcdf_monitor.py +++ b/tests/test_netcdf_monitor.py @@ -5,6 +5,7 @@ import cftime import numpy as np import pytest +import xarray as xr from ndsl import ( CubedSphereCommunicator, @@ -14,11 +15,8 @@ Quantity, TilePartitioner, ) -from ndsl.optional_imports import 
xarray as xr -requires_xarray = pytest.mark.skipif(xr is None, reason="xarray is not installed") - logger = logging.getLogger(__name__) @@ -37,7 +35,6 @@ pytest.param((5, 4, 4), 0, 1, ("z", "y", "x_interface"), id="cell_edge"), ], ) -@requires_xarray def test_monitor_store_multi_rank_state( layout, nt, time_chunk_size, tmpdir, shape, ny_rank_add, nx_rank_add, dims, numpy ): diff --git a/tests/test_tile_scatter_gather.py b/tests/test_tile_scatter_gather.py index 6d56dd6f..918d1ea4 100644 --- a/tests/test_tile_scatter_gather.py +++ b/tests/test_tile_scatter_gather.py @@ -15,12 +15,6 @@ ) -try: - import gt4py -except ImportError: - gt4py = None - - @pytest.fixture(params=[(1, 1), (3, 3)]) def layout(request): return request.param diff --git a/tests/test_zarr_monitor.py b/tests/test_zarr_monitor.py index e40d5210..a67b6599 100644 --- a/tests/test_zarr_monitor.py +++ b/tests/test_zarr_monitor.py @@ -1,16 +1,11 @@ -import tempfile - - -try: - import zarr -except ModuleNotFoundError: - zarr = None import copy import logging +import tempfile from datetime import datetime, timedelta import cftime import pytest +import xarray as xr from ndsl import CubedSpherePartitioner, DummyComm, Quantity, TilePartitioner from ndsl.constants import ( @@ -23,11 +18,12 @@ Z_DIM, ) from ndsl.monitor.zarr_monitor import ZarrMonitor, array_chunks, get_calendar -from ndsl.optional_imports import xarray as xr +from ndsl.optional_imports import RaiseWhenAccessed, zarr -requires_zarr = pytest.mark.skipif(zarr is None, reason="zarr is not installed") -requires_xarray = pytest.mark.skipif(xr is None, reason="xarray is not installed") +requires_zarr = pytest.mark.skipif( + isinstance(zarr, RaiseWhenAccessed), reason="zarr is not installed" +) logger = logging.getLogger("test_zarr_monitor") @@ -144,7 +140,6 @@ def state_list(base_state, n_times, start_time, time_step, numpy): @requires_zarr -@requires_xarray def test_monitor_file_store(state_list, cube_partitioner, numpy, start_time): with 
tempfile.TemporaryDirectory(suffix=".zarr") as tempdir: monitor = ZarrMonitor(tempdir, cube_partitioner) @@ -155,14 +150,12 @@ def test_monitor_file_store(state_list, cube_partitioner, numpy, start_time): @requires_zarr -@requires_xarray def validate_xarray_can_open(dirname): # just checking there are no crashes, validate_group checks data xr.open_zarr(dirname) @requires_zarr -@requires_xarray def validate_store(states, filename, numpy, start_time): nt = len(states) calendar = get_calendar(start_time) @@ -225,7 +218,6 @@ def validate_array_values(name, array): ], ) @requires_zarr -@requires_xarray def test_monitor_file_store_multi_rank_state( layout, nt, tmpdir_factory, shape, ny_rank_add, nx_rank_add, dims, numpy ): @@ -327,7 +319,6 @@ def test_monitor_file_store_multi_rank_state( ], ) @requires_zarr -@requires_xarray def test_array_chunks(layout, tile_array_shape, array_dims, target): result = array_chunks(layout, tile_array_shape, array_dims) assert result == target @@ -344,7 +335,6 @@ def _assert_no_nulls(dataset: "xr.Dataset"): @pytest.mark.parametrize("mask_and_scale", [True, False]) @requires_zarr -@requires_xarray def test_open_zarr_without_nans(cube_partitioner, numpy, backend, mask_and_scale): store = {} @@ -359,7 +349,6 @@ def test_open_zarr_without_nans(cube_partitioner, numpy, backend, mask_and_scale @requires_zarr -@requires_xarray def test_values_preserved(cube_partitioner, numpy): dims = ("y", "x") units = "m" @@ -395,7 +384,6 @@ def state_list_with_inconsistent_calendars(base_state, numpy): @requires_zarr -@requires_xarray def test_monitor_file_store_inconsistent_calendars( state_list_with_inconsistent_calendars, cube_partitioner, numpy ): @@ -444,7 +432,6 @@ def zarr_monitor_single_rank(zarr_store, cube_partitioner): @requires_zarr -@requires_xarray def test_transposed_diags_write_across_ranks(diag, cube_partitioner, zarr_store): layout = (1, 1) total_ranks = 6 * layout[0] * layout[1] @@ -470,7 +457,6 @@ def 
test_transposed_diags_write_across_ranks(diag, cube_partitioner, zarr_store) @requires_zarr -@requires_xarray def test_transposed_diags_write_across_timesteps(diag, zarr_monitor_single_rank): # verify that we can store transposed diags across time time_1 = cftime.DatetimeJulian(2010, 6, 20, 6, 0, 0) @@ -486,7 +472,6 @@ def test_transposed_diags_write_across_timesteps(diag, zarr_monitor_single_rank) @requires_zarr -@requires_xarray def test_diags_fail_different_dim_set(diag, numpy, zarr_monitor_single_rank): time_1 = cftime.DatetimeJulian(2010, 6, 20, 6, 0, 0) time_2 = cftime.DatetimeJulian(2010, 6, 20, 6, 15, 0) @@ -504,7 +489,6 @@ def test_diags_fail_different_dim_set(diag, numpy, zarr_monitor_single_rank): @requires_zarr -@requires_xarray def test_diags_only_consistent_units_attrs_required(diag, zarr_monitor_single_rank): time_1 = cftime.DatetimeJulian(2010, 6, 20, 6, 0, 0) time_2 = cftime.DatetimeJulian(2010, 6, 20, 6, 15, 0)