Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/api/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ Fixed
:meth:`imod.mf6.LayeredWell.from_imod5_cap_data` now regrids the iMOD5 CAP
data to the MODFLOW6 target discretization.
- Fixed confusing warning about inconsistent IPF columns when loading GEN files.
- Fix bug where iMOD Python would error on writing a model where package
settings were specified as dask array, which could happen when loading a model
lazily with :meth:`imod.mf6.Modflow6Simulation.from_file` and not
computing the data before writing.

Changed
~~~~~~~
Expand Down
6 changes: 3 additions & 3 deletions imod/common/utilities/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from imod.common.utilities.clip import clip_by_grid
from imod.common.utilities.dataclass_type import DataclassType, EmptyRegridMethod
from imod.common.utilities.dtype import is_integer
from imod.common.utilities.value_filters import is_valid
from imod.common.utilities.value_filters import enforce_scalar, is_valid
from imod.typing import Imod5DataDict
from imod.typing.grid import (
GridDataArray,
Expand Down Expand Up @@ -48,7 +48,7 @@ def handle_extra_coords(coordname: str, target_grid: GridDataArray, variable_dat
if hasattr(variable_data, "coords"):
if coordname in target_grid.coords:
return variable_data.assign_coords(
{coordname: target_grid.coords[coordname].values[()]}
{coordname: enforce_scalar(target_grid.coords[coordname])}
)
elif coordname in variable_data.coords:
return variable_data.drop_vars(coordname)
Expand All @@ -73,7 +73,7 @@ def _regrid_array(

# skip regridding for scalar arrays with no valid values (such as "None")
scalar_da: bool = is_scalar(da)
if scalar_da and not is_valid(da.values[()]):
if scalar_da and not is_valid(enforce_scalar(da)):
return None

# the dataarray might be a scalar. If it is, then it does not need regridding.
Expand Down
10 changes: 5 additions & 5 deletions imod/common/utilities/value_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import xarray as xr
from xarray.core.utils import is_scalar

from imod.typing import GridDataset
from imod.typing import GridDataArray, GridDataset


def is_valid(value: Any) -> bool:
Expand All @@ -29,16 +29,16 @@ def is_valid(value: Any) -> bool:


def is_empty_dataarray(da: Any) -> bool:
return isinstance(da, xr.DataArray) and da.isnull().all().item()
return isinstance(da, xr.DataArray) and enforce_scalar(da.isnull().all())


def get_scalar_variables(ds: GridDataset) -> list[str]:
"""Returns scalar variables in a dataset."""
return [var for var, arr in ds.variables.items() if is_scalar(arr)]


def enforce_scalar(a: np.ndarray) -> np.ndarray:
def enforce_scalar(a: GridDataArray) -> Any:
"""Enforce scalar value from array."""
if a.size == 1:
return a.item()
return ValueError(f"Array has size {a.size}, expected size 1.")
return a.compute().item()
raise ValueError(f"Array has size {a.size}, expected size 1.")
5 changes: 2 additions & 3 deletions imod/formats/array_io/reading.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import numpy as np
import xarray as xr

from imod.common.utilities.value_filters import enforce_scalar
from imod.util import nested_dict, spatial, time


Expand Down Expand Up @@ -114,8 +113,8 @@ def _scalar_z_coord(coords, tops, bots):
top = np.unique(tops)
bot = np.unique(bots)
if top.size == bot.size == 1:
top = enforce_scalar(top)
bot = enforce_scalar(bot)
top = top.item()
bot = bot.item()
dz = top - bot
z = top - 0.5 * dz
coords["dz"] = float(dz) # cast from array
Expand Down
11 changes: 6 additions & 5 deletions imod/mf6/adv.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from imod.common.interfaces.iregridpackage import IRegridPackage
from imod.common.utilities.dataclass_type import DataclassType
from imod.common.utilities.value_filters import enforce_scalar
from imod.mf6.package import Package
from imod.schemata import AllValueSchema, DimsSchema, DTypeSchema
from imod.typing import GridDataArray
Expand All @@ -43,11 +44,11 @@ def __init__(self, ats_percel: Optional[float] = None, validate: bool = True):

def _render(self, directory, pkgname, globaltimes, binary):
render_dict = {}
render_dict["scheme"] = self.dataset["scheme"].item()
if "ats_percel" in self.dataset and self._valid(
self.dataset["ats_percel"].item()
):
render_dict["ats_percel"] = self.dataset["ats_percel"].item()
render_dict["scheme"] = enforce_scalar(self.dataset["scheme"])
if "ats_percel" in self.dataset:
ats_percel = enforce_scalar(self.dataset["ats_percel"])
if self._valid(ats_percel):
render_dict["ats_percel"] = ats_percel
return self._template.render(render_dict)

def mask(self, _) -> Package:
Expand Down
5 changes: 3 additions & 2 deletions imod/mf6/exchangebase.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import xarray as xr

from imod.common.utilities.value_filters import enforce_scalar
from imod.mf6.package import Package

_pkg_id_to_type = {"gwfgwf": "GWF6-GWF6", "gwfgwt": "GWF6-GWT6", "gwtgwt": "GWT6-GWT6"}
Expand All @@ -21,13 +22,13 @@ class ExchangeBase(Package):
def model_name1(self) -> str:
if "model_name_1" not in self.dataset:
raise ValueError("model_name_1 not present in dataset")
return self.dataset["model_name_1"].item()
return enforce_scalar(self.dataset["model_name_1"])

@property
def model_name2(self) -> str:
if "model_name_2" not in self.dataset:
raise ValueError("model_name_2 not present in dataset")
return self.dataset["model_name_2"].item()
return enforce_scalar(self.dataset["model_name_2"])

def package_name(self) -> str:
return f"{self.model_name1}_{self.model_name2}"
Expand Down
3 changes: 2 additions & 1 deletion imod/mf6/hfb.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
_extract_zbounds_from_vertical_polygons,
_prepare_index_names,
)
from imod.common.utilities.value_filters import enforce_scalar
from imod.logging import LogLevel, init_log_decorator, logger
from imod.mf6.boundary_condition import BoundaryCondition
from imod.mf6.dis import StructuredDiscretization
Expand Down Expand Up @@ -567,7 +568,7 @@ def from_file(cls, path: str | Path, **kwargs) -> Self:
Refer to the xarray documentation for the possible keyword arguments.
"""
instance = super().from_file(path, **kwargs)
geometry = json.loads(instance.dataset["geometry"].values.item())
geometry = json.loads(enforce_scalar(instance.dataset["geometry"]))
instance.line_data = gpd.GeoDataFrame.from_features(geometry)

return instance
Expand Down
54 changes: 28 additions & 26 deletions imod/mf6/lak.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
import textwrap
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict
from typing import Any, Dict, Optional

import jinja2
import numpy as np
import pandas as pd
import xarray as xr

from imod import mf6
from imod.common.utilities.value_filters import enforce_scalar
from imod.logging import init_log_decorator
from imod.mf6.boundary_condition import BoundaryCondition
from imod.mf6.package import Package
Expand Down Expand Up @@ -296,7 +297,7 @@ def create_outlet_data(outlets, name_to_number):

# Convert names to numbers
for var in ("lakein", "lakeout"):
name = outlet.dataset[var].item()
name = enforce_scalar(outlet.dataset[var])
if (
var == "lakeout" and name == ""
): # the outlet lakeout can be outside of the model
Expand All @@ -317,7 +318,7 @@ def create_outlet_data(outlets, name_to_number):
if "time" in outlet.dataset[var].dims:
value = 0.0
else:
value = outlet.dataset[var].item()
value = enforce_scalar(outlet.dataset[var])
if value is None:
value = np.nan
else:
Expand Down Expand Up @@ -797,29 +798,31 @@ def __init__(

@staticmethod
def from_lakes_and_outlets(
lakes,
outlets=None,
print_input=False,
print_stage=False,
print_flows=False,
save_flows=False,
stagefile=None,
budgetfile=None,
budgetcsvfile=None,
package_convergence_filename=None,
time_conversion=None,
length_conversion=None,
lakes: list[LakeData],
outlets: Optional[list[OutletBase]] = None,
print_input: bool = False,
print_stage: bool = False,
print_flows: bool = False,
save_flows: bool = False,
stagefile: Optional[str] = None,
budgetfile: Optional[str] = None,
budgetcsvfile: Optional[str] = None,
package_convergence_filename: Optional[str] = None,
time_conversion: Optional[float] = None,
length_conversion: Optional[float] = None,
):
package_content = {}
package_content: dict[str, Any] = {}
name_to_number = {
lake["boundname"].item(): i + 1 for i, lake in enumerate(lakes)
enforce_scalar(lake["boundname"]): i + 1 for i, lake in enumerate(lakes)
}

# Package data
lake_numbers = list(name_to_number.values())
n_connection = [lake["connection_type"].count().values[()] for lake in lakes]
n_connection = [
enforce_scalar(lake["connection_type"].count()) for lake in lakes
]
package_content["lake_starting_stage"] = xr.DataArray(
data=[lake["starting_stage"].item() for lake in lakes],
data=[enforce_scalar(lake["starting_stage"]) for lake in lakes],
dims=[LAKE_DIM],
)
package_content["lake_number"] = xr.DataArray(
Expand Down Expand Up @@ -924,7 +927,7 @@ def _render(self, directory, pkgname, globaltimes, binary):
"time_conversion",
"length_conversion",
):
value = self[var].item()
value = enforce_scalar(self[var])
if self._valid(value):
d[var] = value

Expand All @@ -944,7 +947,7 @@ def _render(self, directory, pkgname, globaltimes, binary):
self.dataset["lake_starting_stage"],
):
nconn = (self.dataset["connection_lake_number"] == number).sum()
row = tuple(a.item() for a in (number, stage, nconn, name))
row = tuple(enforce_scalar(a) for a in (number, stage, nconn, name))
packagedata.append(row)
d["packagedata"] = packagedata

Expand Down Expand Up @@ -1055,7 +1058,7 @@ def __init__(self, period_number):
for tssname in self._period_data:
if len(period_data[tssname].dims) > 0:
for index in period_data.coords["index"].values:
value = period_data[tssname].sel(index=index).item()
value = enforce_scalar(period_data[tssname].sel(index=index))
isvalid = False
if isinstance(value, str):
isvalid = value != ""
Expand Down Expand Up @@ -1161,15 +1164,14 @@ def _write_laketable_files(self, directory, lake_number_to_filename):

# count number of rows
stage_col = table.sel({"column": "stage"})
d["nrow"] = (
stage_col.where(pd.api.types.is_numeric_dtype).count().values[()]
)
stage_counts = stage_col.where(pd.api.types.is_numeric_dtype).count()
d["nrow"] = enforce_scalar(stage_counts)

# check if the barea column is present for this table (and not filled with nan's)
has_barea_column = "barea" in table.coords["column"]
if has_barea_column:
barea_column = table.sel({"column": "barea"})
has_barea_column = barea_column.notnull().any().item()
has_barea_column = enforce_scalar(barea_column.notnull().any())

columns = ["stage", "sarea", "volume"]
if has_barea_column:
Expand Down
17 changes: 8 additions & 9 deletions imod/mf6/npf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, cast
from typing import Optional

import numpy as np
import xarray as xr
Expand All @@ -10,6 +10,7 @@
from imod.common.utilities.regrid import (
_regrid_package_data,
)
from imod.common.utilities.value_filters import enforce_scalar
from imod.logging import init_log_decorator
from imod.mf6.package import Package
from imod.mf6.regrid.regrid_schemes import (
Expand All @@ -32,17 +33,15 @@


def _dataarray_to_bool(griddataarray: GridDataArray) -> bool:
if griddataarray is None or griddataarray.values is None:
return False
scalar_value: Optional[bool] = enforce_scalar(griddataarray)

if griddataarray.values.size != 1:
raise ValueError("DataArray is not a single value")
if scalar_value is None:
return False

if griddataarray.values.dtype != bool:
raise ValueError("DataArray is not a boolean")
if not isinstance(scalar_value, bool):
raise ValueError("DataArray does not contain a boolean value")

bool_value = cast(bool, griddataarray.values.item())
return bool_value
return scalar_value


class NodePropertyFlow(Package, IRegridPackage):
Expand Down
6 changes: 3 additions & 3 deletions imod/mf6/oc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np

from imod.common.interfaces.iregridpackage import IRegridPackage
from imod.common.utilities.value_filters import is_empty_dataarray
from imod.common.utilities.value_filters import enforce_scalar, is_empty_dataarray
from imod.logging import init_log_decorator
from imod.mf6.package import Package
from imod.mf6.write_context import WriteContext
Expand Down Expand Up @@ -172,11 +172,11 @@ def _render(self, directory, pkgname, globaltimes, binary):
package_times = self.dataset[datavar].coords["time"].values
starts = np.searchsorted(globaltimes, package_times) + 1
for i, s in enumerate(starts):
setting = self.dataset[datavar].isel(time=i).values[()]
setting = enforce_scalar(self.dataset[datavar].isel(time=i))
periods[s][key] = self._get_ocsetting(setting)

else:
setting = self.dataset[datavar].item()
setting = enforce_scalar(self.dataset[datavar])
periods[1][key] = self._get_ocsetting(setting)

d["periods"] = periods
Expand Down
3 changes: 2 additions & 1 deletion imod/mf6/timedis.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import cftime
import numpy as np

from imod.common.utilities.value_filters import enforce_scalar
from imod.logging import init_log_decorator
from imod.mf6.package import Package
from imod.mf6.write_context import WriteContext
Expand Down Expand Up @@ -79,7 +80,7 @@ def _render(self, directory, pkgname, globaltimes, binary):
"start_date_time": start_date_time,
}
if "ats_filename" in self.dataset:
d["ats_filename"] = self.dataset["ats_filename"].item()
d["ats_filename"] = enforce_scalar(self.dataset["ats_filename"])
timestep_duration = self.dataset["timestep_duration"]
n_timesteps = self.dataset["n_timesteps"]
timestep_multiplier = self.dataset["timestep_multiplier"]
Expand Down
5 changes: 3 additions & 2 deletions imod/mf6/wel.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from imod.common.utilities.grid import broadcast_to_full_domain
from imod.common.utilities.layer import create_layered_top
from imod.common.utilities.schemata import validation_pkg_error_message
from imod.common.utilities.value_filters import enforce_scalar
from imod.logging import init_log_decorator, logger
from imod.logging.logging_decorators import standard_log_decorator
from imod.logging.loglevel import LogLevel
Expand Down Expand Up @@ -1107,8 +1108,8 @@ def _assign_wells_to_layers(

index_names = wells_df.index.names

minimum_k = self.dataset["minimum_k"].item()
minimum_thickness = self.dataset["minimum_thickness"].item()
minimum_k = enforce_scalar(self.dataset["minimum_k"])
minimum_thickness = enforce_scalar(self.dataset["minimum_thickness"])

# Unset multi-index, because assign_wells cannot deal with
# multi-indices which is returned by self.dataset.to_dataframe() in
Expand Down
Loading
Loading