Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@ repos:
args: ["--config-file", "pyproject.toml"]
additional_dependencies: [types-PyYAML]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.18.2
hooks:
- id: mypy
name: mypy-tests
args: ["--config-file", "pyproject.toml"]
additional_dependencies: [types-PyYAML]
files: tests

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
Expand Down
6 changes: 3 additions & 3 deletions tests/checkpointer/test_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from ndsl.checkpointer import SnapshotCheckpointer


def test_snapshot_checkpointer_no_data():
def test_snapshot_checkpointer_no_data() -> None:
checkpointer = SnapshotCheckpointer(rank=0)
xr.testing.assert_identical(checkpointer.dataset, xr.Dataset())


def test_snapshot_checkpointer_one_snapshot():
def test_snapshot_checkpointer_one_snapshot() -> None:
checkpointer = SnapshotCheckpointer(rank=0)
val1 = np.random.randn(2, 3, 4)
checkpointer("savepoint_name", val1=val1)
Expand All @@ -27,7 +27,7 @@ def test_snapshot_checkpointer_one_snapshot():
)


def test_snapshot_checkpointer_multiple_snapshots():
def test_snapshot_checkpointer_multiple_snapshots() -> None:
checkpointer = SnapshotCheckpointer(rank=0)
val1 = np.random.randn(2, 2, 3, 4)
val2 = np.random.randn(1, 3, 2, 4)
Expand Down
16 changes: 9 additions & 7 deletions tests/checkpointer/test_thresholds.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,29 @@
)


def test_thresholds_no_trials():
def test_thresholds_no_trials() -> None:
checkpointer = ThresholdCalibrationCheckpointer()
with pytest.raises(InsufficientTrialsError):
checkpointer.thresholds


def test_thresholds_one_empty_trial():
def test_thresholds_one_empty_trial() -> None:
checkpointer = ThresholdCalibrationCheckpointer()
with checkpointer.trial():
pass
with pytest.raises(InsufficientTrialsError):
checkpointer.thresholds


def test_thresholds_two_empty_trials():
def test_thresholds_two_empty_trials() -> None:
checkpointer = ThresholdCalibrationCheckpointer()
for _ in range(2):
with checkpointer.trial():
pass
assert checkpointer.thresholds.savepoints == {}


def test_thresholds_one_data_trial():
def test_thresholds_one_data_trial() -> None:
checkpointer = ThresholdCalibrationCheckpointer()
with checkpointer.trial():
data = np.asarray([0.0, 0.0, 0.0])
Expand All @@ -50,7 +50,9 @@ def test_thresholds_one_data_trial():
pytest.param(1.0, [-5.0, 5.0, 10.0, 0.0], 3.0, 15.0, id="more_values"),
],
)
def test_thresholds_sufficient_trials(factor, values, rel_threshold, abs_threshold):
def test_thresholds_sufficient_trials(
factor: float, values: list[float], rel_threshold: float, abs_threshold: float
) -> None:
checkpointer = ThresholdCalibrationCheckpointer(factor=factor)
for val in values:
with checkpointer.trial():
Expand All @@ -63,7 +65,7 @@ def test_thresholds_sufficient_trials(factor, values, rel_threshold, abs_thresho
}


def test_thresholds_more_variables():
def test_thresholds_more_variables() -> None:
checkpointer = ThresholdCalibrationCheckpointer(factor=1.0)
with checkpointer.trial():
data1 = np.asarray([0.0, 0.0, 0.0])
Expand All @@ -83,7 +85,7 @@ def test_thresholds_more_variables():
}


def test_thresholds_two_calls():
def test_thresholds_two_calls() -> None:
checkpointer = ThresholdCalibrationCheckpointer(factor=1.0)
with checkpointer.trial():
data1 = np.asarray([0.0, 0.0, 0.0])
Expand Down
24 changes: 17 additions & 7 deletions tests/checkpointer/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

def get_dataset(
n_savepoints: int, n_vars: int, n_ranks: int, nx: int, ny: int, nz: int
):
) -> xr.Dataset:
data_vars = {}
for i in range(n_vars):
data_vars["data{}".format(i)] = xr.DataArray(
Expand All @@ -27,7 +27,7 @@ def get_dataset(
return xr.Dataset(data_vars=data_vars)


def test_validation_validates_onevar_onecall():
def test_validation_validates_onevar_onecall() -> None:
temp_dir = tempfile.TemporaryDirectory()
nx_compute = 12
nz = 20
Expand Down Expand Up @@ -61,7 +61,9 @@ def test_validation_validates_onevar_onecall():
pytest.param(1.0, 0.99, id="absolute_failure"),
],
)
def test_validation_asserts_onevar_onecall(relative_threshold, absolute_threshold):
def test_validation_asserts_onevar_onecall(
relative_threshold: float, absolute_threshold: float
) -> None:
temp_dir = tempfile.TemporaryDirectory()
nx_compute = 12
nz = 20
Expand Down Expand Up @@ -105,7 +107,9 @@ def test_validation_asserts_onevar_onecall(relative_threshold, absolute_threshol
pytest.param(1.0, 0.99, id="absolute_threshold"),
],
)
def test_validation_passes_onevar_two_calls(relative_threshold, absolute_threshold):
def test_validation_passes_onevar_two_calls(
relative_threshold: float, absolute_threshold: float
) -> None:
temp_dir = tempfile.TemporaryDirectory()
nx_compute = 12
nz = 20
Expand Down Expand Up @@ -155,7 +159,9 @@ def test_validation_passes_onevar_two_calls(relative_threshold, absolute_thresho
pytest.param(1.0, 0.99, id="absolute_failure"),
],
)
def test_validation_asserts_onevar_two_calls(relative_threshold, absolute_threshold):
def test_validation_asserts_onevar_two_calls(
relative_threshold: float, absolute_threshold: float
) -> None:
temp_dir = tempfile.TemporaryDirectory()
nx_compute = 12
nz = 20
Expand Down Expand Up @@ -207,7 +213,9 @@ def test_validation_asserts_onevar_two_calls(relative_threshold, absolute_thresh
pytest.param(1.0, 0.99, id="absolute_failure"),
],
)
def test_validation_asserts_twovar_onecall(relative_threshold, absolute_threshold):
def test_validation_asserts_twovar_onecall(
relative_threshold: float, absolute_threshold: float
) -> None:
temp_dir = tempfile.TemporaryDirectory()
nx_compute = 12
nz = 20
Expand Down Expand Up @@ -272,6 +280,8 @@ def test_validation_asserts_twovar_onecall(relative_threshold, absolute_threshol
),
],
)
def test_clip_pace_array_to_target(array, target_shape, target_array):
def test_clip_pace_array_to_target(
array: np.ndarray, target_shape: tuple, target_array: np.ndarray
) -> None:
clipped = _clip_pace_array_to_target(array, target_shape=target_shape)
np.testing.assert_array_equal(clipped, target_array)
File renamed without changes.
4 changes: 2 additions & 2 deletions tests/config/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from ndsl import Backend


def test_backend_building():
def test_backend_building() -> None:
Backend("st:python:cpu:IJK")
Backend("st:numpy:cpu:IJK")
Backend("st:gt:cpu:IJK")
Expand All @@ -23,7 +23,7 @@ def test_backend_building():
Backend(unknown_backend)


def test_backend_operators():
def test_backend_operators() -> None:
backend_A = Backend("st:numpy:cpu:IJK")
backend_B = Backend("st:numpy:cpu:IJK")

Expand Down
8 changes: 5 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Literal

import numpy as np
import pytest

Expand All @@ -6,15 +8,15 @@


@pytest.fixture(params=["numpy", pytest.param("cupy", marks=pytest.mark.gpu)])
def backend(request):
def backend(request: pytest.FixtureRequest) -> Literal["numpy"] | Literal["cupy"]:
if request.param == "cupy" and cupy is None:
raise ModuleNotFoundError("cupy must be installed to run gpu tests")

return request.param


@pytest.fixture
def ndsl_backend(backend: str):
def ndsl_backend(backend: str) -> Backend:
if backend == "numpy":
return Backend("st:numpy:cpu:IJK")

Expand All @@ -25,7 +27,7 @@ def ndsl_backend(backend: str):


@pytest.fixture
def numpy(backend: str):
def numpy(backend: str): # type: ignore[no-untyped-def]
if backend == "numpy":
return np

Expand Down
6 changes: 3 additions & 3 deletions tests/dsl/dace/stree/optimizations/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from tests.dsl.dace.stree import StreeOptimization


def double_map(in_field: FloatField, out_field: FloatField):
def double_map(in_field: FloatField, out_field: FloatField) -> None:
with computation(PARALLEL), interval(...):
out_field = in_field

Expand All @@ -23,11 +23,11 @@ def __init__(self, stencil_factory: StencilFactory):
compute_dims=[I_DIM, J_DIM, K_DIM],
)

def __call__(self, in_field: FloatField, out_field: FloatField):
def __call__(self, in_field: FloatField, out_field: FloatField) -> None:
self.stencil(in_field, out_field)


def test_stree_roundtrip_no_opt():
def test_stree_roundtrip_no_opt() -> None:
domain = (3, 3, 4)
stencil_factory, quantity_factory = get_factories_single_tile_orchestrated(
domain[0], domain[1], domain[2], 0, backend=Backend.cpu()
Expand Down
4 changes: 2 additions & 2 deletions tests/dsl/dace/stree/optimizations/test_transient_refine.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ def stencil_with_JK_offset(in_field: FloatField, out_field: FloatField) -> None:
out_field = in_field[J + 1, K + 1] + 3


def stencil_with_ddim(in_field: DDIM_TYPE, out_field: DDIM_TYPE) -> None:
def stencil_with_ddim(in_field: DDIM_TYPE, out_field: DDIM_TYPE) -> None: # type: ignore[valid-type]
with computation(PARALLEL), interval(...):
n = 0
while n < DATADIM_SIZE:
out_field[0, 0, 0][n] = in_field[0, 0, 0][n] + 4
out_field[0, 0, 0][n] = in_field[0, 0, 0][n] + 4 # type: ignore[index]
n = n + 1


Expand Down
24 changes: 12 additions & 12 deletions tests/dsl/orchestration/test_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ndsl.dsl.typing import FloatField


def _stencil(out: Field[float]):
def _stencil(out: Field[float]) -> None:
with computation(PARALLEL), interval(...):
out = out + 1

Expand All @@ -30,34 +30,34 @@ class AState(State):


class OrchestratedProgram:
def __init__(self, stencil_factory: StencilFactory):
def __init__(self, stencil_factory: StencilFactory) -> None:
orchestrate(obj=self, config=stencil_factory.config.dace_config)
self.stencil = stencil_factory.from_dims_halo(_stencil, [I_DIM, J_DIM, K_DIM])

def __call__(self, out_qty):
def __call__(self, out_qty) -> None:
self.stencil(out_qty)


class DSLTypeProgram(NDSLRuntime):
def __init__(self, stencil_factory: StencilFactory):
def __init__(self, stencil_factory: StencilFactory) -> None:
super().__init__(stencil_factory)
self.stencil = stencil_factory.from_dims_halo(_stencil, [I_DIM, J_DIM, K_DIM])

def __call__(self, a_quantity: Quantity, a_state: AState):
def __call__(self, a_quantity: Quantity, a_state: AState) -> None:
self.stencil(a_quantity)
self.stencil(a_state.the_quantity)


class GTTypeProgram(NDSLRuntime):
def __init__(self, stencil_factory: StencilFactory):
def __init__(self, stencil_factory: StencilFactory) -> None:
super().__init__(stencil_factory)
self.stencil = stencil_factory.from_dims_halo(_stencil, [I_DIM, J_DIM, K_DIM])

def __call__(self, a_quantity: FloatField):
def __call__(self, a_quantity: FloatField) -> None:
self.stencil(a_quantity)


def test_memory_reallocation_blind_type():
def test_memory_reallocation_blind_type() -> None:
stencil_factory, quantity_factory = get_factories_single_tile_orchestrated(
5, 5, 2, 0
)
Expand All @@ -77,7 +77,7 @@ def test_memory_reallocation_blind_type():


@pytest.mark.xfail(reason="See https://github.com/NOAA-GFDL/NDSL/issues/436")
def test_memory_reallocation_dsl_typehint():
def test_memory_reallocation_dsl_typehint() -> None:
stencil_factory, quantity_factory = get_factories_single_tile_orchestrated(
5, 5, 2, 0
)
Expand All @@ -95,7 +95,7 @@ def test_memory_reallocation_dsl_typehint():
assert (state_B.the_quantity.field[0, 0, :] == 2).all()


def test_memory_reallocation_gt4py_typehint():
def test_memory_reallocation_gt4py_typehint() -> None:
stencil_factory, quantity_factory = get_factories_single_tile_orchestrated(
5, 5, 2, 0
)
Expand All @@ -111,7 +111,7 @@ def test_memory_reallocation_gt4py_typehint():
assert (qty_E.field[0, 0, :] == 2).all()


def test_default_types_are_compiletime():
def test_default_types_are_compiletime() -> None:
stencil_factory, quantity_factory = get_factories_single_tile_orchestrated(
5, 5, 2, 0
)
Expand All @@ -121,7 +121,7 @@ def test_default_types_are_compiletime():
code(qty_A, state_A)


def test_dace_call_argument_caching():
def test_dace_call_argument_caching() -> None:
stencil_factory, quantity_factory = get_factories_single_tile_orchestrated(
5, 5, 2, 0, backend=Backend.cpu()
)
Expand Down
12 changes: 6 additions & 6 deletions tests/dsl/orchestration/test_reset_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,23 @@
from ndsl.dsl.typing import FloatField, Int, IntFieldIJ


def reset_mask(dp1: FloatField, pe1: FloatField, mask: IntFieldIJ):
def reset_mask(dp1: FloatField, pe1: FloatField, mask: IntFieldIJ) -> None:
with computation(PARALLEL), interval(...):
dp1 = pe1[0, 0, 1] - pe1
with computation(FORWARD), interval(0, 1):
mask = 0


def conditional_copy(dp1: FloatField, pe1: FloatField, mask: IntFieldIJ):
def conditional_copy(dp1: FloatField, pe1: FloatField, mask: IntFieldIJ) -> None:
with computation(PARALLEL), interval(0, 1):
if mask == 1:
dp1 = pe1


class OrchestratedProgramm(NDSLRuntime):
class OrchestratedProgram(NDSLRuntime):
def __init__(
self, stencil_factory: StencilFactory, quantity_factory: QuantityFactory
):
) -> None:
super().__init__(stencil_factory)

self._reset_mask = self._stencil_factory.from_dims_halo(
Expand All @@ -42,14 +42,14 @@ def mask_has_been_reset(self) -> bool:
return (self._mask.field[:] == 0).all()


def test_conditional_copy_with_mask():
def test_conditional_copy_with_mask() -> None:
stencil_factory, quantity_factory = get_factories_single_tile_orchestrated(
nx=4, ny=5, nz=6, nhalo=1
)

dp1 = quantity_factory.zeros(dims=[I_DIM, J_DIM, K_DIM], units="")
pe1 = quantity_factory.zeros(dims=[I_DIM, J_DIM, K_DIM], units="")
code = OrchestratedProgramm(stencil_factory, quantity_factory)
code = OrchestratedProgram(stencil_factory, quantity_factory)

code(dp1, pe1)

Expand Down
Loading