Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 3d73b6a

Browse files
Issue #1507 Add validation model options (#1534)
Fixes #1507 # Description - Adds validation of model options - Add TypeSchema to validate model options, which are not DataArray or UgridDataArray - Clean up unused kwargs logic in ``Modflow6Model.__init__`` - Small refactor: move some logic to concatenate schemata to separate function, force to list to allow concatenating (hidden bug that showed up in mypy after . - Small refactor: Add TypeAlias SchemaDict # Checklist - [x] Links to correct issue - [x] Update changelog, if changes affect users - [x] PR title starts with ``Issue #nr``, e.g. ``Issue #737`` - [x] Unit tests were added - [ ] **If feature added**: Added/extended example
1 parent ab2af5e commit 3d73b6a

File tree

10 files changed

+233
-68
lines changed

10 files changed

+233
-68
lines changed

docs/api/changelog.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,10 @@ Changed
5555
often than not perfectly rectangular in shape.
5656
- :func:`imod.prepare.create_partition_labels` now returns a griddata with the
5757
name ``"label"`` instead of ``"idomain"``.
58-
58+
- Upon providing the wrong type to one of the options of
59+
:class:`imod.mf6.GroundwaterFlowModel`,
60+
:class:`imod.mf6.GroundwaterTransportModel`, this will throw a
61+
``ValidationError`` upon initialization and writing.
5962

6063
[1.0.0rc3] - 2025-04-17
6164
-----------------------

imod/common/utilities/schemata.py

Lines changed: 100 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1-
from typing import Tuple
1+
from collections import defaultdict
2+
from collections.abc import Mapping
3+
from copy import deepcopy
4+
from typing import Any, Optional, Protocol
25

3-
from imod.schemata import BaseSchema
6+
from imod.common.statusinfo import NestedStatusInfo, StatusInfo, StatusInfoBase
7+
from imod.schemata import BaseSchema, SchemataDict, ValidationError
48

59

610
def filter_schemata_dict(
7-
schemata_dict: dict[str, list[BaseSchema] | Tuple[BaseSchema, ...]],
11+
schemata_dict: SchemataDict,
812
schema_types: tuple[type[BaseSchema], ...],
913
) -> dict[str, list[BaseSchema]]:
1014
"""
@@ -37,3 +41,96 @@ def filter_schemata_dict(
3741
if schema_match:
3842
d[key] = schema_match
3943
return d
44+
45+
46+
def concatenate_schemata_dicts(
47+
schemata1: SchemataDict, schemata2: SchemataDict
48+
) -> SchemataDict:
49+
"""
50+
Concatenate two schemata dictionaries. If a key is present in both
51+
dictionaries, the values are concatenated into a list. If a key is only
52+
present in one dictionary, it is added to the new dictionary as is.
53+
"""
54+
schemata = deepcopy(schemata1)
55+
for key, value in schemata2.items():
56+
if key not in schemata.keys():
57+
schemata[key] = value
58+
else:
59+
# Force to list to be able to concatenate
60+
schemata[key] = list(schemata[key]) + list(value)
61+
return schemata
62+
63+
64+
def validate_schemata_dict(
65+
schemata: SchemataDict, data: Mapping, **kwargs: Any
66+
) -> dict[str, list[ValidationError]]:
67+
"""
68+
Validate a data mapping against a schemata dictionary. Returns a dictionary
69+
of errors for each variable in the schemata dictionary. The errors are
70+
stored in a list for each variable.
71+
"""
72+
errors = defaultdict(list)
73+
for variable, var_schemata in schemata.items():
74+
for schema in var_schemata:
75+
if variable in data.keys():
76+
try:
77+
schema.validate(data[variable], **kwargs)
78+
except ValidationError as e:
79+
errors[variable].append(e)
80+
return errors
81+
82+
83+
def validation_pkg_error_message(pkg_errors: dict[str, list[ValidationError]]) -> str:
84+
messages = []
85+
for var, var_errors in pkg_errors.items():
86+
messages.append(f"- {var}")
87+
messages.extend(f" - {error}" for error in var_errors)
88+
return "\n" + "\n".join(messages)
89+
90+
91+
def pkg_errors_to_status_info(
92+
pkg_name: str,
93+
pkg_errors: dict[str, list[ValidationError]],
94+
footer_text: Optional[str],
95+
) -> StatusInfoBase:
96+
pkg_status_info = NestedStatusInfo(f"{pkg_name} package")
97+
for var_name, var_errors in pkg_errors.items():
98+
var_status_info = StatusInfo(var_name)
99+
for var_error in var_errors:
100+
var_status_info.add_error(str(var_error))
101+
pkg_status_info.add(var_status_info)
102+
pkg_status_info.set_footer_text(footer_text)
103+
return pkg_status_info
104+
105+
106+
class ValidateFuncProtocol(Protocol):
107+
"""
108+
Protocol for a method that validates a schemata dictionary, showing the
109+
call signature the method is expected to have.
110+
"""
111+
112+
def __call__(
113+
self, schemata: SchemataDict, **kwargs: Any
114+
) -> dict[str, list[ValidationError]]: ...
115+
116+
117+
def validate_with_error_message(
118+
validate_func: ValidateFuncProtocol,
119+
validate: bool,
120+
schemata: SchemataDict,
121+
**kwargs: Any,
122+
) -> None:
123+
"""
124+
Validate a validation function and create a validation error message if
125+
necessary. The validate_func is provided as an argument to allow providing
126+
overloaded methods. The validate_func should call validate_schemata_dict
127+
with a datatype of the object.
128+
"""
129+
130+
if not validate:
131+
return
132+
errors = validate_func(schemata, **kwargs)
133+
if len(errors) > 0:
134+
message = validation_pkg_error_message(errors)
135+
raise ValidationError(message)
136+
return

imod/mf6/model.py

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import collections
55
import inspect
66
import pathlib
7-
from copy import deepcopy
87
from pathlib import Path
98
from typing import Any, List, Optional, Tuple, Union
109

@@ -22,6 +21,12 @@
2221
from imod.common.statusinfo import NestedStatusInfo, StatusInfo, StatusInfoBase
2322
from imod.common.utilities.mask import _mask_all_packages
2423
from imod.common.utilities.regrid import _regrid_like
24+
from imod.common.utilities.schemata import (
25+
concatenate_schemata_dicts,
26+
pkg_errors_to_status_info,
27+
validate_schemata_dict,
28+
validate_with_error_message,
29+
)
2530
from imod.logging import LogLevel, logger, standard_log_decorator
2631
from imod.mf6.drn import Drainage
2732
from imod.mf6.ghb import GeneralHeadBoundary
@@ -30,11 +35,10 @@
3035
from imod.mf6.package import Package
3136
from imod.mf6.riv import River
3237
from imod.mf6.utilities.mf6hfb import merge_hfb_packages
33-
from imod.mf6.validation import pkg_errors_to_status_info
3438
from imod.mf6.validation_context import ValidationContext
3539
from imod.mf6.wel import GridAgnosticWell
3640
from imod.mf6.write_context import WriteContext
37-
from imod.schemata import ValidationError
41+
from imod.schemata import SchemataDict, ValidationError
3842
from imod.typing import GridDataArray
3943
from imod.util.regrid import RegridderWeightsCache
4044

@@ -51,6 +55,7 @@ def pkg_has_cleanup(pkg: Package):
5155

5256
class Modflow6Model(collections.UserDict, IModel, abc.ABC):
5357
_mandatory_packages: tuple[str, ...] = ()
58+
_init_schemata: SchemataDict = {}
5459
_model_id: Optional[str] = None
5560
_template: Template
5661

@@ -60,13 +65,27 @@ def _initialize_template(name: str) -> Template:
6065
env = jinja2.Environment(loader=loader, keep_trailing_newline=True)
6166
return env.get_template(name)
6267

63-
def __init__(self, **kwargs):
68+
def __init__(self):
6469
collections.UserDict.__init__(self)
65-
for k, v in kwargs.items():
66-
self[k] = v
67-
6870
self._options = {}
6971

72+
@standard_log_decorator()
73+
def validate_options(
74+
self, schemata: dict, **kwargs
75+
) -> dict[str, list[ValidationError]]:
76+
return validate_schemata_dict(schemata, self._options, **kwargs)
77+
78+
def validate_init_schemata_options(self, validate: bool) -> None:
79+
"""
80+
Run the "cheap" schema validations.
81+
82+
The expensive validations are run during writing. Some are only
83+
available then: e.g. idomain to determine active part of domain.
84+
"""
85+
validate_with_error_message(
86+
self.validate_options, validate, self._init_schemata
87+
)
88+
7089
def __setitem__(self, key, value):
7190
if len(key) > 16:
7291
raise KeyError(
@@ -227,6 +246,12 @@ def validate(self, model_name: str = "") -> StatusInfoBase:
227246
bottom = dis["bottom"]
228247

229248
model_status_info = NestedStatusInfo(f"{model_name} model")
249+
# Check model options
250+
option_errors = self.validate_options(self._init_schemata)
251+
model_status_info.add(
252+
pkg_errors_to_status_info("model options", option_errors, None)
253+
)
254+
# Validate packages
230255
for pkg_name, pkg in self.items():
231256
# Check for all schemata when writing. Types and dimensions
232257
# may have been changed after initialization...
@@ -235,12 +260,9 @@ def validate(self, model_name: str = "") -> StatusInfoBase:
235260
continue # some packages can be skipped
236261

237262
# Concatenate write and init schemata.
238-
schemata = deepcopy(pkg._init_schemata)
239-
for key, value in pkg._write_schemata.items():
240-
if key not in schemata.keys():
241-
schemata[key] = value
242-
else:
243-
schemata[key] += value
263+
schemata = concatenate_schemata_dicts(
264+
pkg._init_schemata, pkg._write_schemata
265+
)
244266

245267
pkg_errors = pkg._validate(
246268
schemata=schemata,

imod/mf6/model_gwf.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
SimulationAllocationOptions,
4040
SimulationDistributingOptions,
4141
)
42+
from imod.schemata import TypeSchema
4243
from imod.typing import GridDataArray, StressPeriodTimesType
4344
from imod.typing.grid import zeros_like
4445
from imod.util.regrid import RegridderWeightsCache
@@ -82,6 +83,15 @@ class GroundwaterFlowModel(Modflow6Model):
8283
_model_id = "gwf6"
8384
_template = Modflow6Model._initialize_template("gwf-nam.j2")
8485

86+
_init_schemata = {
87+
"listing_file": [TypeSchema(str)],
88+
"print_input": [TypeSchema(bool)],
89+
"print_flows": [TypeSchema(bool)],
90+
"save_flows": [TypeSchema(bool)],
91+
"newton": [TypeSchema(bool)],
92+
"under_relaxation": [TypeSchema(bool)],
93+
}
94+
8595
@init_log_decorator()
8696
def __init__(
8797
self,
@@ -91,6 +101,7 @@ def __init__(
91101
save_flows: bool = False,
92102
newton: bool = False,
93103
under_relaxation: bool = False,
104+
validate: bool = True,
94105
):
95106
super().__init__()
96107
self._options = {
@@ -101,6 +112,7 @@ def __init__(
101112
"newton": newton,
102113
"under_relaxation": under_relaxation,
103114
}
115+
self.validate_init_schemata_options(validate)
104116

105117
def clip_box(
106118
self,

imod/mf6/model_gwt.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from imod.logging import init_log_decorator
66
from imod.mf6.model import Modflow6Model
7+
from imod.schemata import TypeSchema
78

89

910
class GroundwaterTransportModel(Modflow6Model):
@@ -36,13 +37,21 @@ class GroundwaterTransportModel(Modflow6Model):
3637
_model_id = "gwt6"
3738
_template = Modflow6Model._initialize_template("gwt-nam.j2")
3839

40+
_init_schemata = {
41+
"listing_file": [TypeSchema(str)],
42+
"print_input": [TypeSchema(bool)],
43+
"print_flows": [TypeSchema(bool)],
44+
"save_flows": [TypeSchema(bool)],
45+
}
46+
3947
@init_log_decorator()
4048
def __init__(
4149
self,
4250
listing_file: Optional[str] = None,
4351
print_input: bool = False,
4452
print_flows: bool = False,
4553
save_flows: bool = False,
54+
validate: bool = True,
4655
):
4756
super().__init__()
4857
self._options = {
@@ -51,3 +60,4 @@ def __init__(
5160
"print_flows": print_flows,
5261
"save_flows": save_flows,
5362
}
63+
self.validate_init_schemata_options(validate)

imod/mf6/package.py

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import abc
44
import pathlib
5-
from collections import defaultdict
65
from copy import deepcopy
76
from pathlib import Path
87
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union, cast
@@ -24,7 +23,11 @@
2423
_regrid_like,
2524
)
2625
from imod.common.utilities.regrid_method_type import EmptyRegridMethod, RegridMethodType
27-
from imod.common.utilities.schemata import filter_schemata_dict
26+
from imod.common.utilities.schemata import (
27+
filter_schemata_dict,
28+
validate_schemata_dict,
29+
validate_with_error_message,
30+
)
2831
from imod.common.utilities.value_filters import is_valid
2932
from imod.logging import standard_log_decorator
3033
from imod.mf6.auxiliary_variables import (
@@ -37,12 +40,11 @@
3740
TRANSPORT_PACKAGES,
3841
PackageBase,
3942
)
40-
from imod.mf6.validation import validation_pkg_error_message
4143
from imod.mf6.write_context import WriteContext
4244
from imod.schemata import (
4345
AllNoDataSchema,
4446
EmptyIndexesSchema,
45-
SchemaType,
47+
SchemataDict,
4648
ValidationError,
4749
)
4850
from imod.typing import GridDataArray
@@ -63,8 +65,8 @@ class Package(PackageBase, IPackage, abc.ABC):
6365
"""
6466

6567
_pkg_id = ""
66-
_init_schemata: dict[str, list[SchemaType] | Tuple[SchemaType, ...]] = {}
67-
_write_schemata: dict[str, list[SchemaType] | Tuple[SchemaType, ...]] = {}
68+
_init_schemata: SchemataDict = {}
69+
_write_schemata: SchemataDict = {}
6870
_keyword_map: dict[str, str] = {}
6971
_regrid_method: RegridMethodType = EmptyRegridMethod()
7072
_template: jinja2.Template
@@ -326,17 +328,7 @@ def _write(
326328

327329
@standard_log_decorator()
328330
def _validate(self, schemata: dict, **kwargs) -> dict[str, list[ValidationError]]:
329-
errors = defaultdict(list)
330-
for variable, var_schemata in schemata.items():
331-
for schema in var_schemata:
332-
if (
333-
variable in self.dataset.keys()
334-
): # concentration only added to dataset if specified
335-
try:
336-
schema.validate(self.dataset[variable], **kwargs)
337-
except ValidationError as e:
338-
errors[variable].append(e)
339-
return errors
331+
return validate_schemata_dict(schemata, self.dataset, **kwargs)
340332

341333
def is_empty(self) -> bool:
342334
"""
@@ -355,20 +347,16 @@ def is_empty(self) -> bool:
355347
allnodata_errors = self._validate(allnodata_schemata)
356348
return len(allnodata_errors) > 0
357349

358-
def _validate_init_schemata(self, validate: bool):
350+
def _validate_init_schemata(self, validate: bool, **kwargs) -> None:
359351
"""
360352
Run the "cheap" schema validations.
361353
362354
The expensive validations are run during writing. Some are only
363355
available then: e.g. idomain to determine active part of domain.
364356
"""
365-
if not validate:
366-
return
367-
errors = self._validate(self._init_schemata)
368-
if len(errors) > 0:
369-
message = validation_pkg_error_message(errors)
370-
raise ValidationError(message)
371-
return
357+
validate_with_error_message(
358+
self._validate, validate, self._init_schemata, **kwargs
359+
)
372360

373361
def copy(self) -> Any:
374362
# All state should be contained in the dataset.

0 commit comments

Comments
 (0)