diff --git a/docs/api/changelog.rst b/docs/api/changelog.rst index a1bfc7b5e..dacf8e9c9 100644 --- a/docs/api/changelog.rst +++ b/docs/api/changelog.rst @@ -70,6 +70,10 @@ Fixed :meth:`imod.mf6.LayeredWell.from_imod5_cap_data` now regrids the iMOD5 CAP data to the MODFLOW6 target discretization. - Fixed confusing warning about inconsistent IPF columns when loading GEN files. +- Fixed bug where iMOD Python would error on writing a model where package + settings were specified as dask arrays, which can happen when loading a model + lazily with :meth:`imod.mf6.Modflow6Simulation.from_file` and not + computing the data before writing. Changed ~~~~~~~ diff --git a/imod/common/utilities/regrid.py b/imod/common/utilities/regrid.py index bffb13cb4..a3d5995b3 100644 --- a/imod/common/utilities/regrid.py +++ b/imod/common/utilities/regrid.py @@ -19,7 +19,7 @@ from imod.common.utilities.clip import clip_by_grid from imod.common.utilities.dataclass_type import DataclassType, EmptyRegridMethod from imod.common.utilities.dtype import is_integer -from imod.common.utilities.value_filters import is_valid +from imod.common.utilities.value_filters import enforce_scalar, is_valid from imod.typing import Imod5DataDict from imod.typing.grid import ( GridDataArray, @@ -48,7 +48,7 @@ def handle_extra_coords(coordname: str, target_grid: GridDataArray, variable_dat if hasattr(variable_data, "coords"): if coordname in target_grid.coords: return variable_data.assign_coords( - {coordname: target_grid.coords[coordname].values[()]} + {coordname: enforce_scalar(target_grid.coords[coordname])} ) elif coordname in variable_data.coords: return variable_data.drop_vars(coordname) @@ -73,7 +73,7 @@ def _regrid_array( # skip regridding for scalar arrays with no valid values (such as "None") scalar_da: bool = is_scalar(da) - if scalar_da and not is_valid(da.values[()]): + if scalar_da and not is_valid(enforce_scalar(da)): return None # the dataarray might be a scalar.
If it is, then it does not need regridding. diff --git a/imod/common/utilities/value_filters.py b/imod/common/utilities/value_filters.py index 6f94cacfb..4fa5aad97 100644 --- a/imod/common/utilities/value_filters.py +++ b/imod/common/utilities/value_filters.py @@ -4,7 +4,7 @@ import xarray as xr from xarray.core.utils import is_scalar -from imod.typing import GridDataset +from imod.typing import GridDataArray, GridDataset def is_valid(value: Any) -> bool: @@ -29,7 +29,7 @@ def is_valid(value: Any) -> bool: def is_empty_dataarray(da: Any) -> bool: - return isinstance(da, xr.DataArray) and da.isnull().all().item() + return isinstance(da, xr.DataArray) and enforce_scalar(da.isnull().all()) def get_scalar_variables(ds: GridDataset) -> list[str]: @@ -37,8 +37,8 @@ def get_scalar_variables(ds: GridDataset) -> list[str]: return [var for var, arr in ds.variables.items() if is_scalar(arr)] -def enforce_scalar(a: np.ndarray) -> np.ndarray: +def enforce_scalar(a: GridDataArray) -> Any: """Enforce scalar value from array.""" if a.size == 1: - return a.item() - return ValueError(f"Array has size {a.size}, expected size 1.") + return a.compute().item() + raise ValueError(f"Array has size {a.size}, expected size 1.") diff --git a/imod/formats/array_io/reading.py b/imod/formats/array_io/reading.py index 1f8ee8431..412513214 100644 --- a/imod/formats/array_io/reading.py +++ b/imod/formats/array_io/reading.py @@ -7,7 +7,6 @@ import numpy as np import xarray as xr -from imod.common.utilities.value_filters import enforce_scalar from imod.util import nested_dict, spatial, time @@ -114,8 +113,8 @@ def _scalar_z_coord(coords, tops, bots): top = np.unique(tops) bot = np.unique(bots) if top.size == bot.size == 1: - top = enforce_scalar(top) - bot = enforce_scalar(bot) + top = top.item() + bot = bot.item() dz = top - bot z = top - 0.5 * dz coords["dz"] = float(dz) # cast from array diff --git a/imod/mf6/adv.py b/imod/mf6/adv.py index 366dc1f48..e2dd0438e 100644 --- a/imod/mf6/adv.py +++ 
b/imod/mf6/adv.py @@ -17,6 +17,7 @@ from imod.common.interfaces.iregridpackage import IRegridPackage from imod.common.utilities.dataclass_type import DataclassType +from imod.common.utilities.value_filters import enforce_scalar from imod.mf6.package import Package from imod.schemata import AllValueSchema, DimsSchema, DTypeSchema from imod.typing import GridDataArray @@ -43,11 +44,11 @@ def __init__(self, ats_percel: Optional[float] = None, validate: bool = True): def _render(self, directory, pkgname, globaltimes, binary): render_dict = {} - render_dict["scheme"] = self.dataset["scheme"].item() - if "ats_percel" in self.dataset and self._valid( - self.dataset["ats_percel"].item() - ): - render_dict["ats_percel"] = self.dataset["ats_percel"].item() + render_dict["scheme"] = enforce_scalar(self.dataset["scheme"]) + if "ats_percel" in self.dataset: + ats_percel = enforce_scalar(self.dataset["ats_percel"]) + if self._valid(ats_percel): + render_dict["ats_percel"] = ats_percel return self._template.render(render_dict) def mask(self, _) -> Package: diff --git a/imod/mf6/exchangebase.py b/imod/mf6/exchangebase.py index cf078f028..3782c0da2 100644 --- a/imod/mf6/exchangebase.py +++ b/imod/mf6/exchangebase.py @@ -4,6 +4,7 @@ import numpy as np import xarray as xr +from imod.common.utilities.value_filters import enforce_scalar from imod.mf6.package import Package _pkg_id_to_type = {"gwfgwf": "GWF6-GWF6", "gwfgwt": "GWF6-GWT6", "gwtgwt": "GWT6-GWT6"} @@ -21,13 +22,13 @@ class ExchangeBase(Package): def model_name1(self) -> str: if "model_name_1" not in self.dataset: raise ValueError("model_name_1 not present in dataset") - return self.dataset["model_name_1"].item() + return enforce_scalar(self.dataset["model_name_1"]) @property def model_name2(self) -> str: if "model_name_2" not in self.dataset: raise ValueError("model_name_2 not present in dataset") - return self.dataset["model_name_2"].item() + return enforce_scalar(self.dataset["model_name_2"]) def package_name(self) -> 
str: return f"{self.model_name1}_{self.model_name2}" diff --git a/imod/mf6/hfb.py b/imod/mf6/hfb.py index 3397ac885..b459ab13b 100644 --- a/imod/mf6/hfb.py +++ b/imod/mf6/hfb.py @@ -27,6 +27,7 @@ _extract_zbounds_from_vertical_polygons, _prepare_index_names, ) +from imod.common.utilities.value_filters import enforce_scalar from imod.logging import LogLevel, init_log_decorator, logger from imod.mf6.boundary_condition import BoundaryCondition from imod.mf6.dis import StructuredDiscretization @@ -567,7 +568,7 @@ def from_file(cls, path: str | Path, **kwargs) -> Self: Refer to the xarray documentation for the possible keyword arguments. """ instance = super().from_file(path, **kwargs) - geometry = json.loads(instance.dataset["geometry"].values.item()) + geometry = json.loads(enforce_scalar(instance.dataset["geometry"])) instance.line_data = gpd.GeoDataFrame.from_features(geometry) return instance diff --git a/imod/mf6/lak.py b/imod/mf6/lak.py index bdc2b231b..65cc5e382 100644 --- a/imod/mf6/lak.py +++ b/imod/mf6/lak.py @@ -9,7 +9,7 @@ import textwrap from collections import defaultdict from pathlib import Path -from typing import Any, Dict +from typing import Any, Dict, Optional import jinja2 import numpy as np @@ -17,6 +17,7 @@ import xarray as xr from imod import mf6 +from imod.common.utilities.value_filters import enforce_scalar from imod.logging import init_log_decorator from imod.mf6.boundary_condition import BoundaryCondition from imod.mf6.package import Package @@ -296,7 +297,7 @@ def create_outlet_data(outlets, name_to_number): # Convert names to numbers for var in ("lakein", "lakeout"): - name = outlet.dataset[var].item() + name = enforce_scalar(outlet.dataset[var]) if ( var == "lakeout" and name == "" ): # the outlet lakeout can be outside of the model @@ -317,7 +318,7 @@ def create_outlet_data(outlets, name_to_number): if "time" in outlet.dataset[var].dims: value = 0.0 else: - value = outlet.dataset[var].item() + value = enforce_scalar(outlet.dataset[var]) 
if value is None: value = np.nan else: @@ -797,29 +798,31 @@ def __init__( @staticmethod def from_lakes_and_outlets( - lakes, - outlets=None, - print_input=False, - print_stage=False, - print_flows=False, - save_flows=False, - stagefile=None, - budgetfile=None, - budgetcsvfile=None, - package_convergence_filename=None, - time_conversion=None, - length_conversion=None, + lakes: list[LakeData], + outlets: Optional[list[OutletBase]] = None, + print_input: bool = False, + print_stage: bool = False, + print_flows: bool = False, + save_flows: bool = False, + stagefile: Optional[str] = None, + budgetfile: Optional[str] = None, + budgetcsvfile: Optional[str] = None, + package_convergence_filename: Optional[str] = None, + time_conversion: Optional[float] = None, + length_conversion: Optional[float] = None, ): - package_content = {} + package_content: dict[str, Any] = {} name_to_number = { - lake["boundname"].item(): i + 1 for i, lake in enumerate(lakes) + enforce_scalar(lake["boundname"]): i + 1 for i, lake in enumerate(lakes) } # Package data lake_numbers = list(name_to_number.values()) - n_connection = [lake["connection_type"].count().values[()] for lake in lakes] + n_connection = [ + enforce_scalar(lake["connection_type"].count()) for lake in lakes + ] package_content["lake_starting_stage"] = xr.DataArray( - data=[lake["starting_stage"].item() for lake in lakes], + data=[enforce_scalar(lake["starting_stage"]) for lake in lakes], dims=[LAKE_DIM], ) package_content["lake_number"] = xr.DataArray( @@ -924,7 +927,7 @@ def _render(self, directory, pkgname, globaltimes, binary): "time_conversion", "length_conversion", ): - value = self[var].item() + value = enforce_scalar(self[var]) if self._valid(value): d[var] = value @@ -944,7 +947,7 @@ def _render(self, directory, pkgname, globaltimes, binary): self.dataset["lake_starting_stage"], ): nconn = (self.dataset["connection_lake_number"] == number).sum() - row = tuple(a.item() for a in (number, stage, nconn, name)) + row = 
tuple(enforce_scalar(a) for a in (number, stage, nconn, name)) packagedata.append(row) d["packagedata"] = packagedata @@ -1055,7 +1058,7 @@ def __init__(self, period_number): for tssname in self._period_data: if len(period_data[tssname].dims) > 0: for index in period_data.coords["index"].values: - value = period_data[tssname].sel(index=index).item() + value = enforce_scalar(period_data[tssname].sel(index=index)) isvalid = False if isinstance(value, str): isvalid = value != "" @@ -1161,15 +1164,14 @@ def _write_laketable_files(self, directory, lake_number_to_filename): # count number of rows stage_col = table.sel({"column": "stage"}) - d["nrow"] = ( - stage_col.where(pd.api.types.is_numeric_dtype).count().values[()] - ) + stage_counts = stage_col.where(pd.api.types.is_numeric_dtype).count() + d["nrow"] = enforce_scalar(stage_counts) # check if the barea column is present for this table (and not filled with nan's) has_barea_column = "barea" in table.coords["column"] if has_barea_column: barea_column = table.sel({"column": "barea"}) - has_barea_column = barea_column.notnull().any().item() + has_barea_column = enforce_scalar(barea_column.notnull().any()) columns = ["stage", "sarea", "volume"] if has_barea_column: diff --git a/imod/mf6/npf.py b/imod/mf6/npf.py index 4831dbd38..bfc422b85 100644 --- a/imod/mf6/npf.py +++ b/imod/mf6/npf.py @@ -1,4 +1,4 @@ -from typing import Optional, cast +from typing import Optional import numpy as np import xarray as xr @@ -10,6 +10,7 @@ from imod.common.utilities.regrid import ( _regrid_package_data, ) +from imod.common.utilities.value_filters import enforce_scalar from imod.logging import init_log_decorator from imod.mf6.package import Package from imod.mf6.regrid.regrid_schemes import ( @@ -32,17 +33,15 @@ def _dataarray_to_bool(griddataarray: GridDataArray) -> bool: - if griddataarray is None or griddataarray.values is None: - return False + scalar_value: Optional[bool] = enforce_scalar(griddataarray) - if griddataarray.values.size 
!= 1: - raise ValueError("DataArray is not a single value") + if scalar_value is None: + return False - if griddataarray.values.dtype != bool: - raise ValueError("DataArray is not a boolean") + if not isinstance(scalar_value, bool): + raise ValueError("DataArray does not contain a boolean value") - bool_value = cast(bool, griddataarray.values.item()) - return bool_value + return scalar_value class NodePropertyFlow(Package, IRegridPackage): diff --git a/imod/mf6/oc.py b/imod/mf6/oc.py index 6921a3db8..677d361a4 100644 --- a/imod/mf6/oc.py +++ b/imod/mf6/oc.py @@ -6,7 +6,7 @@ import numpy as np from imod.common.interfaces.iregridpackage import IRegridPackage -from imod.common.utilities.value_filters import is_empty_dataarray +from imod.common.utilities.value_filters import enforce_scalar, is_empty_dataarray from imod.logging import init_log_decorator from imod.mf6.package import Package from imod.mf6.write_context import WriteContext @@ -172,11 +172,11 @@ def _render(self, directory, pkgname, globaltimes, binary): package_times = self.dataset[datavar].coords["time"].values starts = np.searchsorted(globaltimes, package_times) + 1 for i, s in enumerate(starts): - setting = self.dataset[datavar].isel(time=i).values[()] + setting = enforce_scalar(self.dataset[datavar].isel(time=i)) periods[s][key] = self._get_ocsetting(setting) else: - setting = self.dataset[datavar].item() + setting = enforce_scalar(self.dataset[datavar]) periods[1][key] = self._get_ocsetting(setting) d["periods"] = periods diff --git a/imod/mf6/timedis.py b/imod/mf6/timedis.py index f84154f00..d55c9a75d 100644 --- a/imod/mf6/timedis.py +++ b/imod/mf6/timedis.py @@ -1,6 +1,7 @@ import cftime import numpy as np +from imod.common.utilities.value_filters import enforce_scalar from imod.logging import init_log_decorator from imod.mf6.package import Package from imod.mf6.write_context import WriteContext @@ -79,7 +80,7 @@ def _render(self, directory, pkgname, globaltimes, binary): "start_date_time": 
start_date_time, } if "ats_filename" in self.dataset: - d["ats_filename"] = self.dataset["ats_filename"].item() + d["ats_filename"] = enforce_scalar(self.dataset["ats_filename"]) timestep_duration = self.dataset["timestep_duration"] n_timesteps = self.dataset["n_timesteps"] timestep_multiplier = self.dataset["timestep_multiplier"] diff --git a/imod/mf6/wel.py b/imod/mf6/wel.py index 83eca04bd..e39af906a 100644 --- a/imod/mf6/wel.py +++ b/imod/mf6/wel.py @@ -19,6 +19,7 @@ from imod.common.utilities.grid import broadcast_to_full_domain from imod.common.utilities.layer import create_layered_top from imod.common.utilities.schemata import validation_pkg_error_message +from imod.common.utilities.value_filters import enforce_scalar from imod.logging import init_log_decorator, logger from imod.logging.logging_decorators import standard_log_decorator from imod.logging.loglevel import LogLevel @@ -1107,8 +1108,8 @@ def _assign_wells_to_layers( index_names = wells_df.index.names - minimum_k = self.dataset["minimum_k"].item() - minimum_thickness = self.dataset["minimum_thickness"].item() + minimum_k = enforce_scalar(self.dataset["minimum_k"]) + minimum_thickness = enforce_scalar(self.dataset["minimum_thickness"]) # Unset multi-index, because assign_wells cannot deal with # multi-indices which is returned by self.dataset.to_dataframe() in diff --git a/imod/msw/coupler_mapping.py b/imod/msw/coupler_mapping.py index 793a1a1cf..1cad8080d 100644 --- a/imod/msw/coupler_mapping.py +++ b/imod/msw/coupler_mapping.py @@ -5,6 +5,7 @@ import xarray as xr from imod.common.interfaces.ipackagebase import IPackageBase +from imod.common.utilities.value_filters import enforce_scalar from imod.mf6.dis import StructuredDiscretization from imod.mf6.mf6_wel_adapter import Mf6Wel from imod.msw.fixed_format import VariableMetaData @@ -49,7 +50,7 @@ def _create_mod_id_rch( n_subunit = svat["subunit"].size # Sum active cells and convert to int for MyPy. 
This is the number of # modflow cells in the top layer. - n_mod_top = int(idomain_top_active.sum().item()) + n_mod_top = int(enforce_scalar(idomain_top_active.sum())) # idomain does not have a subunit dimension, so tile for n_subunits mod_id_1d: IntArray = np.tile(np.arange(1, n_mod_top + 1), (n_subunit, 1)) @@ -128,7 +129,7 @@ def _create_well_id( well_column = well_cellid.sel(dim_cellid="column").data - 1 # Sum active cells and convert to int for MyPy. This is the number of # modflow cells in the top layer. - n_mod = int(idomain_active.sum().item()) + n_mod = int(enforce_scalar(idomain_active.sum())) mod_id = xr.full_like(idomain_active, 0, dtype=np.int64) mod_id.data[idomain_active.data] = np.arange(1, n_mod + 1) diff --git a/imod/msw/model.py b/imod/msw/model.py index 1c604c352..42ecf9c0a 100644 --- a/imod/msw/model.py +++ b/imod/msw/model.py @@ -14,7 +14,6 @@ from imod.common.utilities.clip import clip_by_grid from imod.common.utilities.partitioninfo import create_partition_info from imod.common.utilities.regrid import regrid_imod5_cap_data -from imod.common.utilities.value_filters import enforce_scalar from imod.common.utilities.version import prepend_content_with_version_info from imod.mf6.dis import StructuredDiscretization from imod.mf6.mf6_wel_adapter import Mf6Wel @@ -228,8 +227,8 @@ def _get_starttime(self): year, time_since_start_year = to_metaswap_timeformat([starttime]) - year = int(enforce_scalar(year.values)) - time_since_start_year = float(enforce_scalar(time_since_start_year.values)) + year = int(year.item()) + time_since_start_year = float(time_since_start_year.item()) return year, time_since_start_year diff --git a/imod/schemata.py b/imod/schemata.py index ab7e338a0..60b7ffc0e 100644 --- a/imod/schemata.py +++ b/imod/schemata.py @@ -53,6 +53,7 @@ import xugrid as xu from numpy.typing import DTypeLike # noqa: F401 +from imod.common.utilities.value_filters import enforce_scalar from imod.typing import GridDataArray, ScalarAsDataArray from 
imod.typing.grid import notnull from imod.util.imports import MissingOptionalModule @@ -183,7 +184,7 @@ def validate(self, obj: ScalarAsDataArray, **kwargs) -> None: return # MODFLOW 6 is not case sensitive for string options. - value = obj.item() + value = enforce_scalar(obj) if isinstance(value, str): value = value.lower() diff --git a/imod/tests/test_common/test_utilities/test_value_filters.py b/imod/tests/test_common/test_utilities/test_value_filters.py new file mode 100644 index 000000000..d8c1bbb5b --- /dev/null +++ b/imod/tests/test_common/test_utilities/test_value_filters.py @@ -0,0 +1,70 @@ +import dask +import numpy as np +import pytest +import pytest_cases +import xarray as xr + +from imod.common.utilities.value_filters import enforce_scalar, is_empty_dataarray + + +class ScalarCases: + def case_int(self): + return 42 + + def case_float(self): + return 42.0 + + def case_bool(self): + return False + + def case_none(self): + return None + + def case_string(self): + return "test" + + +@pytest_cases.parametrize_with_cases("input_value", cases=ScalarCases) +def test_enforce_scalar(input_value): + da = xr.DataArray([input_value]) + assert enforce_scalar(da) == input_value + + da = xr.DataArray([input_value, input_value]) + with pytest.raises(ValueError): + enforce_scalar(da) + + +@pytest_cases.parametrize_with_cases("input_value", cases=ScalarCases) +def test_enforce_scalar_from_dask(input_value): + data = dask.array.from_array([input_value], chunks=1) + da = xr.DataArray(data) + assert enforce_scalar(da) == input_value + + data = dask.array.from_array([input_value, input_value], chunks=2) + da = xr.DataArray(data) + with pytest.raises(ValueError): + enforce_scalar(da) + + +def test_is_empty_dataarray(): + da = xr.DataArray([1, 2, 3]) + assert not is_empty_dataarray(da) + + da = xr.DataArray([None, None, None]) + assert is_empty_dataarray(da) + + da = xr.DataArray([np.nan, np.nan]) + assert is_empty_dataarray(da) + + da = xr.DataArray([1, None, 3]) + assert 
not is_empty_dataarray(da) + + da = xr.DataArray([1, np.nan, 3]) + assert not is_empty_dataarray(da) + + da = "not a DataArray" + assert not is_empty_dataarray(da) + + data = dask.array.from_array([np.nan, np.nan], chunks=2) + da = xr.DataArray(data) + assert is_empty_dataarray(da)