Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit a79ad0a

Browse files
Issue #788 and # 789 gwf gwt exchanges and ims (#791)
Fixes #788 and #789 # Description now adds split models to the solution group (ims) that the source model belonged to. When writing a split simulation with flow and transport, gwfgwt exchanges are created. The models are connected based on type and domain. # Checklist - [X] Links to correct issue - [ ] Update changelog, if changes affect users - [X] PR title starts with ``Issue #nr``, e.g. ``Issue #737`` - [X] Unit tests were added - [ ] **If feature added**: Added/extended example --------- Co-authored-by: Joeri van Engelen <[email protected]>
1 parent 8f970eb commit a79ad0a

File tree

9 files changed

+184
-46
lines changed

9 files changed

+184
-46
lines changed

imod/mf6/ims.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,34 @@ def __init__(
428428
super().__init__(dict_dataset)
429429
self._validate_init_schemata(validate)
430430

431+
def remove_model_from_solution(self, modelname: str) -> None:
432+
models_in_solution = self.get_models_in_solution()
433+
if modelname not in models_in_solution:
434+
raise ValueError(
435+
f"attempted to remove model {modelname} from solution, but it was not found."
436+
)
437+
filtered_models = [m for m in models_in_solution if m != modelname]
438+
439+
if len(filtered_models) == 0:
440+
self.dataset = self.dataset.drop_vars("modelnames")
441+
else:
442+
self.dataset.update({"modelnames": ("model", filtered_models)})
443+
444+
def add_model_to_solution(self, modelname: str) -> None:
445+
models_in_solution = self.get_models_in_solution()
446+
if modelname in models_in_solution:
447+
raise ValueError(
448+
f"attempted to add model {modelname} to solution, but it was already in it."
449+
)
450+
models_in_solution.append(modelname)
451+
self.dataset.update({"modelnames": ("model", models_in_solution)})
452+
453+
def get_models_in_solution(self) -> list[str]:
454+
models_in_solution = []
455+
if "modelnames" in self.dataset.keys():
456+
models_in_solution = list(self.dataset["modelnames"].values)
457+
return models_in_solution
458+
431459

432460
def SolutionPresetSimple(
433461
modelnames, print_option="summary", csv_output=False, no_ptc=False

imod/mf6/multimodel/exchange_creator.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import abc
2-
from typing import Dict, List
2+
from typing import Dict
33

44
import numpy as np
55
import pandas as pd
@@ -85,7 +85,7 @@ def _compute_geometric_information(self) -> pd.DataFrame:
8585
@classmethod
8686
@abc.abstractmethod
8787
def _create_global_to_local_idx(
88-
cls, partition_info: List[PartitionInfo], global_cell_indices: GridDataArray
88+
cls, partition_info: list[PartitionInfo], global_cell_indices: GridDataArray
8989
) -> Dict[int, pd.DataFrame]:
9090
"""
9191
abstract method that creates for each partition a mapping from global cell indices to local cells in that
@@ -94,7 +94,7 @@ def _create_global_to_local_idx(
9494
raise NotImplementedError
9595

9696
def __init__(
97-
self, submodel_labels: GridDataArray, partition_info: List[PartitionInfo]
97+
self, submodel_labels: GridDataArray, partition_info: list[PartitionInfo]
9898
):
9999
self._submodel_labels = submodel_labels
100100

@@ -110,7 +110,9 @@ def __init__(
110110

111111
self._geometric_information = self._compute_geometric_information()
112112

113-
def create_exchanges(self, model_name: str, layers: GridDataArray) -> List[GWFGWF]:
113+
def create_gwfgwf_exchanges(
114+
self, model_name: str, layers: GridDataArray
115+
) -> list[GWFGWF]:
114116
"""
115117
Create GroundWaterFlow-GroundWaterFlow exchanges based on the submodel_labels array provided in the class
116118
constructor. The layer parameter is used to extrude the cell connection through all the layers. An exchange
@@ -193,7 +195,7 @@ def create_exchanges(self, model_name: str, layers: GridDataArray) -> List[GWFGW
193195
return exchanges
194196

195197
def _create_global_cellidx_to_local_cellid_mapping(
196-
self, partition_info: List[PartitionInfo]
198+
self, partition_info: list[PartitionInfo]
197199
) -> Dict[int, pd.DataFrame]:
198200
global_to_local_idx = self._create_global_to_local_idx(
199201
partition_info, self._global_cell_indices

imod/mf6/simulation.py

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import imod.mf6.exchangebase
2222
from imod.mf6.gwfgwf import GWFGWF
2323
from imod.mf6.gwfgwt import GWFGWT
24+
from imod.mf6.ims import Solution
2425
from imod.mf6.model import Modflow6Model
2526
from imod.mf6.model_gwf import GroundwaterFlowModel
2627
from imod.mf6.model_gwt import GroundwaterTransportModel
@@ -930,21 +931,21 @@ def split(self, submodel_labels: xr.DataArray) -> Modflow6Simulation:
930931
new_simulation[package_name] = package
931932

932933
for model_name, model in original_models.items():
934+
solution_name = self.get_solution_name(model_name)
935+
new_simulation[solution_name].remove_model_from_solution(model_name)
933936
for submodel_partition_info in partition_info:
934937
new_model_name = f"{model_name}_{submodel_partition_info.id}"
935938
new_simulation[new_model_name] = slice_model(
936939
submodel_partition_info, model
937940
)
941+
new_simulation[solution_name].add_model_to_solution(new_model_name)
938942

939943
exchanges = []
940944
for model_name, model in original_models.items():
941-
exchanges += exchange_creator.create_exchanges(
942-
model_name, model.domain.layer
943-
)
944-
945-
new_simulation["solver"]["modelnames"] = xr.DataArray(
946-
list(get_models(new_simulation).keys())
947-
)
945+
if isinstance(model, GroundwaterFlowModel):
946+
exchanges += exchange_creator.create_gwfgwf_exchanges(
947+
model_name, model.domain.layer
948+
)
948949

949950
new_simulation._add_modelsplit_exchanges(exchanges)
950951
new_simulation._set_exchange_options()
@@ -1003,15 +1004,16 @@ def _add_modelsplit_exchanges(self, exchanges_list: list[GWFGWF]) -> None:
10031004
def _set_exchange_options(self):
10041005
# collect some options that we will auto-set
10051006
for exchange in self["split_exchanges"]:
1006-
model_name_1 = exchange.dataset["model_name_1"].values[()]
1007-
model_1 = self[model_name_1]
1008-
exchange.set_options(
1009-
save_flows=model_1["oc"].is_budget_output,
1010-
dewatered=model_1["npf"].is_dewatered,
1011-
variablecv=model_1["npf"].is_variable_vertical_conductance,
1012-
xt3d=model_1["npf"].get_xt3d_option(),
1013-
newton=model_1.is_use_newton(),
1014-
)
1007+
if isinstance(exchange, GWFGWF):
1008+
model_name_1 = exchange.dataset["model_name_1"].values[()]
1009+
model_1 = self[model_name_1]
1010+
exchange.set_options(
1011+
save_flows=model_1["oc"].is_budget_output,
1012+
dewatered=model_1["npf"].is_dewatered,
1013+
variablecv=model_1["npf"].is_variable_vertical_conductance,
1014+
xt3d=model_1["npf"].get_xt3d_option(),
1015+
newton=model_1.is_use_newton(),
1016+
)
10151017

10161018
def _filter_inactive_cells_from_exchanges(self) -> None:
10171019
for ex in self["split_exchanges"]:
@@ -1041,6 +1043,13 @@ def _filter_inactive_cells_exchange_domain(self, ex: GWFGWF, i: int) -> None:
10411043
active_exchange_domain = active_exchange_domain.dropna("index")
10421044
ex.dataset = ex.dataset.sel(index=active_exchange_domain["index"])
10431045

1046+
def get_solution_name(self, model_name: str) -> str:
1047+
for k, v in self.items():
1048+
if isinstance(v, Solution):
1049+
if model_name in v.dataset["modelnames"]:
1050+
return k
1051+
return None
1052+
10441053
def __repr__(self) -> str:
10451054
typename = type(self).__name__
10461055
INDENT = " "
@@ -1060,15 +1069,22 @@ def __repr__(self) -> str:
10601069
content = attrs + ["){}"]
10611070
return "\n".join(content)
10621071

1063-
def _generate_gwfgwt_exchanges(self):
1072+
def _generate_gwfgwt_exchanges(self) -> list[GWFGWT]:
10641073
flow_models = self.get_models_of_type("gwf6")
10651074
transport_models = self.get_models_of_type("gwt6")
1066-
10671075
# exchange for flow and transport
10681076
exchanges = []
1069-
if len(flow_models) == 1 and len(transport_models) > 0:
1070-
flow_model_name = list(flow_models.keys())[0]
1071-
for transport_model_name in transport_models.keys():
1072-
exchanges.append(GWFGWT(flow_model_name, transport_model_name))
1077+
1078+
for flow_model_name in flow_models:
1079+
tpt_models_of_flow_model = []
1080+
flow_model = self[flow_model_name]
1081+
for tpt_model_name in transport_models:
1082+
tpt_model = self[tpt_model_name]
1083+
if tpt_model.domain.equals(flow_model.domain):
1084+
tpt_models_of_flow_model.append(tpt_model_name)
1085+
1086+
if len(tpt_models_of_flow_model) > 0:
1087+
for transport_model_name in tpt_models_of_flow_model:
1088+
exchanges.append(GWFGWT(flow_model_name, transport_model_name))
10731089

10741090
return exchanges

imod/tests/test_mf6/test_mf6_ims.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,25 @@
66
from imod.schemata import ValidationError
77

88

9+
def create_ims() -> imod.mf6.Solution:
10+
return imod.mf6.Solution(
11+
modelnames=["GWF_1"],
12+
print_option="summary",
13+
csv_output=False,
14+
no_ptc=True,
15+
outer_dvclose=1.0e-4,
16+
outer_maximum=500,
17+
under_relaxation=None,
18+
inner_dvclose=1.0e-4,
19+
inner_rclose=0.001,
20+
inner_maximum=100,
21+
linear_acceleration="cg",
22+
scaling_method=None,
23+
reordering_method=None,
24+
relaxation_factor=0.97,
25+
)
26+
27+
928
def test_render():
1029
ims = imod.mf6.Solution(
1130
modelnames=["GWF_1"],
@@ -65,3 +84,24 @@ def test_wrong_dtype():
6584
reordering_method=None,
6685
relaxation_factor=0.97,
6786
)
87+
88+
89+
def test_drop_and_add_model():
90+
ims = create_ims()
91+
ims.remove_model_from_solution("GWF_1")
92+
assert "modelnames" not in ims.dataset.keys()
93+
ims.add_model_to_solution("GWF_2")
94+
assert "GWF_2" in ims.dataset["modelnames"].values
95+
96+
97+
def test_remove_non_present_model():
98+
ims = create_ims()
99+
ims.remove_model_from_solution("GWF_1")
100+
with pytest.raises(ValueError):
101+
ims.remove_model_from_solution("GWF_1")
102+
103+
104+
def test_add_already_present_model():
105+
ims = create_ims()
106+
with pytest.raises(ValueError):
107+
ims.add_model_to_solution("GWF_1")

imod/tests/test_mf6/test_mf6_simulation.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import imod
1616
from imod.mf6.model import Modflow6Model
17+
from imod.mf6.model_gwf import GroundwaterFlowModel
1718
from imod.mf6.multimodel.modelsplitter import PartitionInfo
1819
from imod.mf6.simulation import get_models, get_packages
1920
from imod.mf6.statusinfo import NestedStatusInfo, StatusInfo
@@ -298,18 +299,20 @@ def test_split_multiple_models(
298299

299300
simulation = setup_simulation
300301

301-
model_mock1 = MagicMock(spec_set=Modflow6Model)
302+
model_mock1 = MagicMock(spec_set=GroundwaterFlowModel)
302303
model_mock1._model_id = "test_model_id1"
303304

304-
model_mock2 = MagicMock(spec_set=Modflow6Model)
305+
model_mock2 = MagicMock(spec_set=GroundwaterFlowModel)
305306
model_mock2._model_id = "test_model_id2"
306307

307308
simulation["test_model1"] = model_mock1
308309
simulation["test_model2"] = model_mock2
309310

310-
simulation["solver"]["modelnames"] = ["test_model1", "test_model2"]
311+
simulation["solver"].dataset = xr.Dataset(
312+
{"modelnames": ["test_model1", "test_model2"]}
313+
)
311314

312-
slice_model_mock.return_value = MagicMock(spec_set=Modflow6Model)
315+
slice_model_mock.return_value = MagicMock(spec_set=GroundwaterFlowModel)
313316

314317
active = idomain.sel(layer=1)
315318
submodel_labels = xu.zeros_like(active).where(active.grid.face_y > 0.0, 1)
@@ -360,20 +363,22 @@ def test_split_multiple_models_creates_expected_number_of_exchanges(
360363

361364
simulation = setup_simulation
362365

363-
model_mock1 = MagicMock(spec_set=Modflow6Model)
366+
model_mock1 = MagicMock(spec_set=GroundwaterFlowModel)
364367
model_mock1._model_id = "test_model_id1"
365368
model_mock1.domain = idomain
366369

367-
model_mock2 = MagicMock(spec_set=Modflow6Model)
370+
model_mock2 = MagicMock(spec_set=GroundwaterFlowModel)
368371
model_mock2._model_id = "test_model_id2"
369372
model_mock2.domain = idomain
370373

371374
simulation["test_model1"] = model_mock1
372375
simulation["test_model2"] = model_mock2
373376

374-
simulation["solver"]["modelnames"] = ["test_model1", "test_model2"]
377+
simulation["solver"].dataset = xr.Dataset(
378+
{"modelnames": ["test_model1", "test_model2"]}
379+
)
375380

376-
slice_model_mock.return_value = MagicMock(spec_set=Modflow6Model)
381+
slice_model_mock.return_value = MagicMock(spec_set=GroundwaterFlowModel)
377382

378383
active = idomain.sel(layer=1)
379384
submodel_labels = xr.zeros_like(active).where(active.y > 50, 1)
@@ -390,9 +395,11 @@ def test_split_multiple_models_creates_expected_number_of_exchanges(
390395
submodel_labels, create_partition_info_mock()
391396
)
392397

393-
assert exchange_creator_mock.return_value.create_exchanges.call_count == 2
394-
call1 = exchange_creator_mock.return_value.create_exchanges.call_args_list[0][0]
395-
call2 = exchange_creator_mock.return_value.create_exchanges.call_args_list[1][0]
398+
# fmt: off
399+
assert exchange_creator_mock.return_value.create_gwfgwf_exchanges.call_count == 2 # noqa: E501
400+
call1 = exchange_creator_mock.return_value.create_gwfgwf_exchanges.call_args_list[0][0] # noqa: E501
401+
call2 = exchange_creator_mock.return_value.create_gwfgwf_exchanges.call_args_list[1][0] # noqa: E501
402+
# fmt: on
396403

397404
assert call1[0] == "test_model1"
398405
xr.testing.assert_equal(call1[1], idomain.layer)

imod/tests/test_mf6/test_multimodel/test_exchange_creator_structured.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def test_create_exchanges_validate_number_of_exchanges(
6868
layer = idomain.layer
6969

7070
# Act.
71-
exchanges = exchange_creator.create_exchanges(model_name, layer)
71+
exchanges = exchange_creator.create_gwfgwf_exchanges(model_name, layer)
7272

7373
# Assert.
7474
num_exchanges_x_direction = y_number_partitions * (x_number_partitions - 1)
@@ -296,7 +296,7 @@ def test_create_exchanges_validate_local_cell_ids(
296296
layer = idomain.layer
297297

298298
# Act.
299-
exchanges = exchange_creator.create_exchanges(model_name, layer)
299+
exchanges = exchange_creator.create_gwfgwf_exchanges(model_name, layer)
300300

301301
# Assert.
302302
assert len(exchanges) == len(expected_exchanges)
@@ -337,7 +337,7 @@ def test_exchange_geometric_information(
337337
layer = idomain.layer
338338

339339
# Act.
340-
exchanges = exchange_creator.create_exchanges(model_name, layer)
340+
exchanges = exchange_creator.create_gwfgwf_exchanges(model_name, layer)
341341

342342
# Assert.
343343
assert len(exchanges) == len(expected_exchanges)

imod/tests/test_mf6/test_multimodel/test_exchange_creator_unstructured.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_create_exchanges_unstructured_validate_number_of_exchanges(
4848
exchange_creator = ExchangeCreator_Unstructured(submodel_labels, partition_info)
4949

5050
# Act.
51-
exchanges = exchange_creator.create_exchanges("flow", idomain.layer)
51+
exchanges = exchange_creator.create_gwfgwf_exchanges("flow", idomain.layer)
5252

5353
# Assert.
5454
assert len(exchanges) == number_partitions - 1
@@ -75,7 +75,7 @@ def test_create_exchanges_unstructured_validate_exchange_locations(
7575
exchange_creator = ExchangeCreator_Unstructured(submodel_labels, partition_info)
7676

7777
# Act.
78-
exchanges = exchange_creator.create_exchanges("flow", idomain.layer)
78+
exchanges = exchange_creator.create_gwfgwf_exchanges("flow", idomain.layer)
7979

8080
# Assert.
8181
nlayer = 3
@@ -118,7 +118,7 @@ def test_create_exchanges_unstructured_validate_geometric_coefficients(
118118
exchange_creator = ExchangeCreator_Unstructured(submodel_labels, partition_info)
119119

120120
# Act.
121-
exchanges = exchange_creator.create_exchanges("flow", idomain.layer)
121+
exchanges = exchange_creator.create_gwfgwf_exchanges("flow", idomain.layer)
122122

123123
# Assert.
124124
assert np.allclose(exchanges[0].dataset["cl1"], expected_cl1)
@@ -159,7 +159,7 @@ def test_create_exchanges_unstructured_validate_auxiliary_coefficients(
159159
exchange_creator = ExchangeCreator_Unstructured(submodel_labels, partition_info)
160160

161161
# Act.
162-
_ = exchange_creator.create_exchanges("flow", idomain.layer)
162+
_ = exchange_creator.create_gwfgwf_exchanges("flow", idomain.layer)
163163
"""
164164
exchanges = exchange_creator.create_exchanges("flow", idomain.layer)
165165
"""

imod/tests/test_mf6/test_multimodel/test_mf6_modelsplitter.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,4 +101,3 @@ def test_slice_model_with_auxiliary_variables(tmp_path, flow_transport_simulatio
101101
assert "species_d" in list(split_simulation["flow_1"]["chd"].dataset.keys())
102102
assert "species_d" in list(split_simulation["flow_1"]["rch"].dataset.keys())
103103
assert "concentration" in list(split_simulation["flow_1"]["well"].dataset.keys())
104-
pass

0 commit comments

Comments
 (0)