diff --git a/src/pyrecest/__init__.py b/src/pyrecest/__init__.py index e4347b639..9ae4ef64b 100644 --- a/src/pyrecest/__init__.py +++ b/src/pyrecest/__init__.py @@ -304,7 +304,9 @@ def _patch_pytorch_tile_facade() -> None: import pyrecest.backend as backend # pylint: disable=import-outside-toplevel - selected_backend_is_pytorch = getattr(backend, "__backend_name__", None) == "pytorch" + selected_backend_is_pytorch = ( + getattr(backend, "__backend_name__", None) == "pytorch" + ) try: import numpy as _np # pylint: disable=import-outside-toplevel diff --git a/src/pyrecest/_backend/_shared_numpy/_rng_pkg/__init__.py b/src/pyrecest/_backend/_shared_numpy/_rng_pkg/__init__.py index bdffc0a86..31095fd8a 100644 --- a/src/pyrecest/_backend/_shared_numpy/_rng_pkg/__init__.py +++ b/src/pyrecest/_backend/_shared_numpy/_rng_pkg/__init__.py @@ -1,8 +1,8 @@ """Compatibility package for shared NumPy RNG helpers.""" +import sys as _sys from importlib import util as _importlib_util from pathlib import Path as _Path -import sys as _sys import numpy as _np diff --git a/src/pyrecest/_backend/capabilities/__init__.py b/src/pyrecest/_backend/capabilities/__init__.py index c0ab40e5e..ca5279a64 100644 --- a/src/pyrecest/_backend/capabilities/__init__.py +++ b/src/pyrecest/_backend/capabilities/__init__.py @@ -74,7 +74,9 @@ def _patch_pytorch_dot_outer_device_contract() -> None: helper_names = ("dot", "outer") if all( - getattr(getattr(raw_pytorch, helper_name, None), "_pyrecest_device_contract", False) + getattr( + getattr(raw_pytorch, helper_name, None), "_pyrecest_device_contract", False + ) for helper_name in helper_names ): if getattr(backend, "__backend_name__", None) == "pytorch": diff --git a/src/pyrecest/_backend/jax/fft.py b/src/pyrecest/_backend/jax/fft.py index fb71a6dfd..fcd26b65e 100644 --- a/src/pyrecest/_backend/jax/fft.py +++ b/src/pyrecest/_backend/jax/fft.py @@ -30,11 +30,15 @@ def _normalize_real_fft_axis(axis): def rfft(a, n=None, axis=-1, norm=None): - return _fft.rfft(_jnp.asarray(a), n=n, axis=_normalize_real_fft_axis(axis), norm=norm) + return _fft.rfft( + _jnp.asarray(a), n=n, axis=_normalize_real_fft_axis(axis), norm=norm + ) def irfft(a, n=None, axis=-1, norm=None): - return _fft.irfft(_jnp.asarray(a), n=n, axis=_normalize_real_fft_axis(axis), norm=norm) + return _fft.irfft( + _jnp.asarray(a), n=n, axis=_normalize_real_fft_axis(axis), norm=norm + ) def fftn(a, s=None, axes=None, norm=None): diff --git a/src/pyrecest/_backend/pytorch/_dtype.py b/src/pyrecest/_backend/pytorch/_dtype.py index fdd0fdea1..c3b0b54ec 100644 --- a/src/pyrecest/_backend/pytorch/_dtype.py +++ b/src/pyrecest/_backend/pytorch/_dtype.py @@ -15,8 +15,8 @@ from pyrecest._backend._dtype_utils import ( get_default_dtype as _shared_get_default_dtype, ) +from torch import bool as torch_bool from torch import ( - bool as torch_bool, complex64, complex128, float32, diff --git a/src/pyrecest/_backend_submodules.py b/src/pyrecest/_backend_submodules.py index 55409b6da..ea7509735 100644 --- a/src/pyrecest/_backend_submodules.py +++ b/src/pyrecest/_backend_submodules.py @@ -573,7 +573,9 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): "(dimension must be 2 or 3)" ) - leading_shape = numpy_module.broadcast_shapes(tuple(a.shape[:-1]), tuple(b.shape[:-1])) + leading_shape = numpy_module.broadcast_shapes( + tuple(a.shape[:-1]), tuple(b.shape[:-1]) + ) if tuple(a.shape[:-1]) != leading_shape: a = torch_module.broadcast_to(a, leading_shape + (a_dim,)) if tuple(b.shape[:-1]) != leading_shape: diff --git a/src/pyrecest/_pytorch_dot_contract.py b/src/pyrecest/_pytorch_dot_contract.py index bdebd130c..3ea6337de 100644 --- a/src/pyrecest/_pytorch_dot_contract.py +++ b/src/pyrecest/_pytorch_dot_contract.py @@ -7,8 +7,8 @@ def patch_pytorch_dot_numpy_contract() -> None: """Make raw and public PyTorch ``dot`` follow NumPy's axis contract.""" try: - import pyrecest.backend as backend # pylint: disable=import-outside-toplevel import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel + import pyrecest.backend as backend # pylint: disable=import-outside-toplevel import torch # pylint: disable=import-outside-toplevel except ModuleNotFoundError: # pragma: no cover - PyTorch may be unavailable return diff --git a/src/pyrecest/backend_support/__init__.py b/src/pyrecest/backend_support/__init__.py index e11f3ec12..8318498fc 100644 --- a/src/pyrecest/backend_support/__init__.py +++ b/src/pyrecest/backend_support/__init__.py @@ -48,8 +48,8 @@ def _patch_raw_pytorch_assignment_scalar_tensor_indices() -> None: """Make raw PyTorch assignment helpers accept scalar integer tensor indices.""" try: - import pyrecest.backend as backend # pylint: disable=import-outside-toplevel import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel + import pyrecest.backend as backend # pylint: disable=import-outside-toplevel import torch as _torch # pylint: disable=import-outside-toplevel except ModuleNotFoundError: # pragma: no cover - PyTorch backend may be unavailable return @@ -226,7 +226,9 @@ def _patch_pytorch_copy_numpy_contract() -> None: try: import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel - except ModuleNotFoundError: # pragma: no cover - PyTorch backend import failed earlier + except ( + ModuleNotFoundError + ): # pragma: no cover - PyTorch backend import failed earlier return original_copy = raw_pytorch.copy @@ -258,7 +260,9 @@ def _patch_pytorch_clip_numpy_contract() -> None: try: import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel import torch # pylint: disable=import-outside-toplevel - except ModuleNotFoundError: # pragma: no cover - PyTorch backend import failed earlier + except ( + ModuleNotFoundError + ): # pragma: no cover - PyTorch backend import failed earlier return original_clip = raw_pytorch.clip @@ -351,7 +355,9 @@ def _patch_pytorch_broadcast_to_numpy_contract() -> None: import numpy as np # pylint: disable=import-outside-toplevel import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel import torch # pylint: disable=import-outside-toplevel - except ModuleNotFoundError: # pragma: no cover - PyTorch backend import failed earlier + except ( + ModuleNotFoundError + ): # pragma: no cover - PyTorch backend import failed earlier return original_broadcast_to = raw_pytorch.broadcast_to diff --git a/src/pyrecest/backend_support/_jax_random_empty_contract.py b/src/pyrecest/backend_support/_jax_random_empty_contract.py index b584a490f..00f6d88de 100644 --- a/src/pyrecest/backend_support/_jax_random_empty_contract.py +++ b/src/pyrecest/backend_support/_jax_random_empty_contract.py @@ -34,11 +34,17 @@ def _parse_randint_arguments(low, high, size, kwargs): return None return legacy_minval, legacy_maxval, size, kwargs if ( - raw_jax_random._looks_like_shape_sequence(low) # pylint: disable=protected-access + raw_jax_random._looks_like_shape_sequence( + low + ) # pylint: disable=protected-access and high is not None and size is not None - and raw_jax_random._looks_like_scalar_randint_bound(high) # pylint: disable=protected-access - and raw_jax_random._looks_like_scalar_randint_bound(size) # pylint: disable=protected-access + and raw_jax_random._looks_like_scalar_randint_bound( + high + ) # pylint: disable=protected-access + and raw_jax_random._looks_like_scalar_randint_bound( + size + ) # pylint: disable=protected-access ): return high, size, low, kwargs if high is None: @@ -55,17 +61,27 @@ def _empty_invalid_bound_result(low, high, size, args, kwargs): return None low, high, size, kwargs = parsed - low = raw_jax_random._validate_randint_bound(low, "low") # pylint: disable=protected-access - high = raw_jax_random._validate_randint_bound(high, "high") # pylint: disable=protected-access + low = raw_jax_random._validate_randint_bound( + low, "low" + ) # pylint: disable=protected-access + high = raw_jax_random._validate_randint_bound( + high, "high" + ) # pylint: disable=protected-access try: low, high = jnp.broadcast_arrays(low, high) except ValueError as exc: raise ValueError("low and high could not be broadcast together") from exc - shape = raw_jax_random._bounded_sampler_shape(size, low, high) # pylint: disable=protected-access - if not raw_jax_random._shape_has_no_samples(shape): # pylint: disable=protected-access + shape = raw_jax_random._bounded_sampler_shape( + size, low, high + ) # pylint: disable=protected-access + if not raw_jax_random._shape_has_no_samples( + shape + ): # pylint: disable=protected-access return None - state, has_state, remaining_kwargs = raw_jax_random._get_state(**kwargs) # pylint: disable=protected-access + state, has_state, remaining_kwargs = raw_jax_random._get_state( + **kwargs + ) # pylint: disable=protected-access state, result = raw_jax_random._randint( # pylint: disable=protected-access state, shape, diff --git a/src/pyrecest/backend_support/_torch_dtype_promotion_contract.py b/src/pyrecest/backend_support/_torch_dtype_promotion_contract.py index 1be1347e3..29dee2236 100644 --- a/src/pyrecest/backend_support/_torch_dtype_promotion_contract.py +++ b/src/pyrecest/backend_support/_torch_dtype_promotion_contract.py @@ -16,7 +16,9 @@ def patch_pytorch_dtype_promotion_contract() -> None: from pyrecest._backend.pytorch._common import ( # pylint: disable=import-outside-toplevel _normalize_dtype, ) - except ModuleNotFoundError: # pragma: no cover - PyTorch backend import failed earlier + except ( + ModuleNotFoundError + ): # pragma: no cover - PyTorch backend import failed earlier return _patch_pytorch_repeat_numpy_contract(raw_pytorch, torch) @@ -160,7 +162,9 @@ def _normalize_axis(axis, ndim): if axis < 0: axis += ndim if axis < 0 or axis >= ndim: - raise IndexError(f"axis {axis} is out of bounds for array of dimension {ndim}") + raise IndexError( + f"axis {axis} is out of bounds for array of dimension {ndim}" + ) return axis def _boundary(value, reference, axis): @@ -220,7 +224,10 @@ def _patch_pytorch_transpose_numpy_axes_contract(raw_pytorch, np) -> None: original_transpose = raw_pytorch.transpose if getattr(original_transpose, "_pyrecest_numpy_axes_contract", False): - if backend is not None and getattr(backend, "__backend_name__", None) == "pytorch": + if ( + backend is not None + and getattr(backend, "__backend_name__", None) == "pytorch" + ): backend.transpose = original_transpose return @@ -247,8 +254,7 @@ def _normalize_creation_shape(shape, torch, np): normalized_shape = (_operator_index(shape_array.item()),) else: normalized_shape = tuple( - _operator_index(one_dimension) - for one_dimension in shape_array.tolist() + _operator_index(one_dimension) for one_dimension in shape_array.tolist() ) if any(one_dimension < 0 for one_dimension in normalized_shape): raise ValueError("negative dimensions are not allowed") @@ -425,7 +431,9 @@ def _patch_pytorch_randint_empty_size_contract(raw_pytorch_random, torch) -> Non active_pytorch_backend = ( backend is not None and getattr(backend, "__backend_name__", None) == "pytorch" ) - backend_random = getattr(backend, "random", None) if active_pytorch_backend else None + backend_random = ( + getattr(backend, "random", None) if active_pytorch_backend else None + ) original_randint = raw_pytorch_random.randint if getattr(original_randint, "_pyrecest_empty_size_contract", False): if active_pytorch_backend and backend_random is not None: @@ -503,7 +511,9 @@ def _normalize_pad_pairs(pad_width, ndim, np): try: pad_pairs = np.broadcast_to(pad_width_array, (ndim, 2)) except ValueError as exc: - raise ValueError(f"pad_width must be broadcastable to shape ({ndim}, 2)") from exc + raise ValueError( + f"pad_width must be broadcastable to shape ({ndim}, 2)" + ) from exc if np.any(pad_pairs < 0): raise ValueError("index can't contain negative values") return tuple( @@ -525,7 +535,9 @@ def _normalize_constant_value_pairs(constant_values, ndim, np): def _filled_pad_block(shape, value, reference, torch): """Return a constant-filled block compatible with ``reference``.""" - scalar_value = torch.as_tensor(value, dtype=reference.dtype, device=reference.device) + scalar_value = torch.as_tensor( + value, dtype=reference.dtype, device=reference.device + ) if scalar_value.ndim != 0: raise ValueError("constant_values entries must be scalar") block = torch.empty(tuple(shape), dtype=reference.dtype, device=reference.device) diff --git a/src/pyrecest/backend_support/_torch_dtype_promotion_contract/__init__.py b/src/pyrecest/backend_support/_torch_dtype_promotion_contract/__init__.py index a672197d1..81ff2e927 100644 --- a/src/pyrecest/backend_support/_torch_dtype_promotion_contract/__init__.py +++ b/src/pyrecest/backend_support/_torch_dtype_promotion_contract/__init__.py @@ -8,13 +8,17 @@ def _load_base_contract_module(): - module_path = Path(__file__).resolve().parent.parent / "_torch_dtype_promotion_contract.py" + module_path = ( + Path(__file__).resolve().parent.parent / "_torch_dtype_promotion_contract.py" + ) spec = importlib.util.spec_from_file_location( "_pyrecest_torch_dtype_promotion_contract_base", module_path, ) if spec is None or spec.loader is None: - raise ImportError(f"Cannot load PyTorch dtype contract module from {module_path}") + raise ImportError( + f"Cannot load PyTorch dtype contract module from {module_path}" + ) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) return module @@ -31,7 +35,9 @@ def patch_pytorch_dtype_promotion_contract() -> None: import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel import pyrecest.backend as backend # pylint: disable=import-outside-toplevel import torch # pylint: disable=import-outside-toplevel - except ModuleNotFoundError: # pragma: no cover - PyTorch backend import failed earlier + except ( + ModuleNotFoundError + ): # pragma: no cover - PyTorch backend import failed earlier return _patch_pytorch_assignment_numpy_index_contract(raw_pytorch, backend, torch, np) @@ -69,11 +75,17 @@ def assignment(x, values, indices, axis=0): return assignment -def _patch_pytorch_assignment_numpy_index_contract(raw_pytorch, backend, torch, np) -> None: +def _patch_pytorch_assignment_numpy_index_contract( + raw_pytorch, backend, torch, np +) -> None: """Make PyTorch assignment helpers accept NumPy integer and boolean indices.""" helper_names = ("assignment", "assignment_by_sum") if all( - getattr(getattr(raw_pytorch, helper_name, None), "_pyrecest_numpy_index_contract", False) + getattr( + getattr(raw_pytorch, helper_name, None), + "_pyrecest_numpy_index_contract", + False, + ) for helper_name in helper_names ): if getattr(backend, "__backend_name__", None) == "pytorch": @@ -227,7 +239,9 @@ def _patch_pytorch_binary_device_contract(raw_pytorch, backend, torch) -> None: "power": False, } if all( - getattr(getattr(raw_pytorch, helper_name, None), "_pyrecest_device_contract", False) + getattr( + getattr(raw_pytorch, helper_name, None), "_pyrecest_device_contract", False + ) for helper_name in helpers ): if getattr(backend, "__backend_name__", None) == "pytorch": @@ -250,7 +264,9 @@ def _patch_pytorch_equality_device_contract(raw_pytorch, backend, torch) -> None """Keep equality-style helpers on an existing non-CPU tensor device.""" helper_names = ("equal", "less_equal", "array" + "_equal") if all( - getattr(getattr(raw_pytorch, helper_name, None), "_pyrecest_device_contract", False) + getattr( + getattr(raw_pytorch, helper_name, None), "_pyrecest_device_contract", False + ) for helper_name in helper_names ): if getattr(backend, "__backend_name__", None) == "pytorch": @@ -345,7 +361,9 @@ def _wrap_argsort_arraylike_helper(original_argsort, raw_pytorch, torch): if getattr(original_argsort, "_pyrecest_arraylike_contract", False): return original_argsort - def argsort(input, axis=-1, descending=False, stable=False, *, dim=None): # pylint: disable=redefined-builtin + def argsort( + input, axis=-1, descending=False, stable=False, *, dim=None + ): # pylint: disable=redefined-builtin if dim is not None: if axis != -1 and axis != dim: raise TypeError("argsort() got both 'axis' and 'dim'") @@ -377,7 +395,11 @@ def _patch_pytorch_arraylike_helper_contract(raw_pytorch, backend, torch) -> Non ) all_helper_names = (*helper_names, "argsort") if all( - getattr(getattr(raw_pytorch, helper_name, None), "_pyrecest_arraylike_contract", False) + getattr( + getattr(raw_pytorch, helper_name, None), + "_pyrecest_arraylike_contract", + False, + ) for helper_name in all_helper_names ): if getattr(backend, "__backend_name__", None) == "pytorch": @@ -395,7 +417,9 @@ def _patch_pytorch_arraylike_helper_contract(raw_pytorch, backend, torch) -> Non if getattr(backend, "__backend_name__", None) == "pytorch": setattr(backend, helper_name, wrapped_helper) - wrapped_argsort = _wrap_argsort_arraylike_helper(raw_pytorch.argsort, raw_pytorch, torch) + wrapped_argsort = _wrap_argsort_arraylike_helper( + raw_pytorch.argsort, raw_pytorch, torch + ) raw_pytorch.argsort = wrapped_argsort if getattr(backend, "__backend_name__", None) == "pytorch": backend.argsort = wrapped_argsort diff --git a/src/pyrecest/calibration/time_offset.py b/src/pyrecest/calibration/time_offset.py index a4b0ab84f..db2ed24ac 100644 --- a/src/pyrecest/calibration/time_offset.py +++ b/src/pyrecest/calibration/time_offset.py @@ -25,8 +25,14 @@ class TimeOffsetFitResult: metadata: Mapping[str, Any] = field(default_factory=dict) def summary(self) -> dict[str, Any]: - out: dict[str, Any] = {"metric": self.metric, "best_offset_s": self.best_offset_s, "evaluated_offsets": int(len(self.offsets_s))} - best_index = _best_metric_index(self.offsets_s, self.metric_values, self.counts, self.best_offset_s) + out: dict[str, Any] = { + "metric": self.metric, + "best_offset_s": self.best_offset_s, + "evaluated_offsets": int(len(self.offsets_s)), + } + best_index = _best_metric_index( + self.offsets_s, self.metric_values, self.counts, self.best_offset_s + ) if best_index is not None: out["best_metric_value"] = float(self.metric_values[best_index]) out["best_count"] = int(self.counts[best_index]) @@ -34,7 +40,12 @@ def summary(self) -> dict[str, Any]: return out -def _best_metric_index(offsets_s: np.ndarray, metric_values: np.ndarray, counts: np.ndarray, best_offset_s: float | None) -> int | None: +def _best_metric_index( + offsets_s: np.ndarray, + metric_values: np.ndarray, + counts: np.ndarray, + best_offset_s: float | None, +) -> int | None: if best_offset_s is None: return None offsets = np.asarray(offsets_s, dtype=float).reshape(-1) @@ -151,11 +162,19 @@ def apply_time_offset(times_s: np.ndarray, offset_s: float | None) -> np.ndarray def _validate_max_time_delta(max_time_delta_s: float | None) -> float | None: - return None if max_time_delta_s is None else _as_nonnegative_time_delta(max_time_delta_s, "max_time_delta_s") - - -def _finite_reference_rows(reference_times_s: np.ndarray, reference_values: np.ndarray | None = None) -> np.ndarray: - reference_times = _as_real_numeric_array(reference_times_s, "reference_times_s").reshape(-1) + return ( + None + if max_time_delta_s is None + else _as_nonnegative_time_delta(max_time_delta_s, "max_time_delta_s") + ) + + +def _finite_reference_rows( + reference_times_s: np.ndarray, reference_values: np.ndarray | None = None +) -> np.ndarray: + reference_times = _as_real_numeric_array( + reference_times_s, "reference_times_s" + ).reshape(-1) finite = np.isfinite(reference_times) if reference_values is not None: values = _as_real_numeric_array(reference_values, "reference_values") @@ -165,8 +184,12 @@ def _finite_reference_rows(reference_times_s: np.ndarray, reference_values: np.n return finite -def nearest_time_indices(reference_times_s: np.ndarray, query_times_s: np.ndarray) -> np.ndarray: - reference = _as_real_numeric_array(reference_times_s, "reference_times_s").reshape(-1) +def nearest_time_indices( + reference_times_s: np.ndarray, query_times_s: np.ndarray +) -> np.ndarray: + reference = _as_real_numeric_array(reference_times_s, "reference_times_s").reshape( + -1 + ) query = _as_real_numeric_array(query_times_s, "query_times_s").reshape(-1) finite_reference = _finite_reference_rows(reference) if not finite_reference.any(): @@ -183,14 +206,24 @@ def nearest_time_indices(reference_times_s: np.ndarray, query_times_s: np.ndarra insertion = np.searchsorted(sorted_reference, finite_query_values) right = np.clip(insertion, 0, sorted_reference.size - 1) left = np.clip(insertion - 1, 0, sorted_reference.size - 1) - use_right = np.abs(sorted_reference[right] - finite_query_values) < np.abs(sorted_reference[left] - finite_query_values) + use_right = np.abs(sorted_reference[right] - finite_query_values) < np.abs( + sorted_reference[left] - finite_query_values + ) nearest[finite_query] = original_indices[order[np.where(use_right, right, left)]] return nearest -def interpolate_reference_values(reference_times_s: np.ndarray, reference_values: np.ndarray, query_times_s: np.ndarray, *, max_time_delta_s: float | None = None) -> tuple[np.ndarray, np.ndarray]: +def interpolate_reference_values( + reference_times_s: np.ndarray, + reference_values: np.ndarray, + query_times_s: np.ndarray, + *, + max_time_delta_s: float | None = None, +) -> tuple[np.ndarray, np.ndarray]: max_time_delta = _validate_max_time_delta(max_time_delta_s) - reference_times = _as_real_numeric_array(reference_times_s, "reference_times_s").reshape(-1) + reference_times = _as_real_numeric_array( + reference_times_s, "reference_times_s" + ).reshape(-1) reference_values = _as_real_numeric_array(reference_values, "reference_values") query_times = _as_real_numeric_array(query_times_s, "query_times_s").reshape(-1) if reference_values.ndim not in (1, 2): @@ -201,14 +234,25 @@ def interpolate_reference_values(reference_times_s: np.ndarray, reference_values raise ValueError("reference_times_s length must match reference_values rows") finite_reference = _finite_reference_rows(reference_times, reference_values) if np.count_nonzero(finite_reference) < 2: - raise ValueError("at least two finite reference rows are required for interpolation") + raise ValueError( + "at least two finite reference rows are required for interpolation" + ) reference_times = reference_times[finite_reference] reference_values = reference_values[finite_reference] order = np.argsort(reference_times) reference_times = reference_times[order] reference_values = reference_values[order] - interpolated = np.column_stack([np.interp(query_times, reference_times, reference_values[:, dim]) for dim in range(reference_values.shape[1])]) - valid = np.isfinite(query_times) & (query_times >= reference_times[0]) & (query_times <= reference_times[-1]) + interpolated = np.column_stack( + [ + np.interp(query_times, reference_times, reference_values[:, dim]) + for dim in range(reference_values.shape[1]) + ] + ) + valid = ( + np.isfinite(query_times) + & (query_times >= reference_times[0]) + & (query_times <= reference_times[-1]) + ) if max_time_delta is not None: nearest = nearest_time_indices(reference_times, query_times) valid &= np.abs(reference_times[nearest] - query_times) <= max_time_delta @@ -216,39 +260,109 @@ def interpolate_reference_values(reference_times_s: np.ndarray, reference_values return interpolated, valid -def time_offset_error_summary(measurement_times_s: np.ndarray, measurement_values: np.ndarray, reference_times_s: np.ndarray, reference_values: np.ndarray, offset_s: float, *, max_time_delta_s: float | None = None) -> dict[str, float]: - measurement_values = _as_real_numeric_array(measurement_values, "measurement_values") +def time_offset_error_summary( + measurement_times_s: np.ndarray, + measurement_values: np.ndarray, + reference_times_s: np.ndarray, + reference_values: np.ndarray, + offset_s: float, + *, + max_time_delta_s: float | None = None, +) -> dict[str, float]: + measurement_values = _as_real_numeric_array( + measurement_values, "measurement_values" + ) if measurement_values.ndim == 1: measurement_values = measurement_values.reshape(-1, 1) elif measurement_values.ndim != 2: raise ValueError("measurement_values must be one- or two-dimensional") query_times = apply_time_offset(measurement_times_s, offset_s) if query_times.size != measurement_values.shape[0]: - raise ValueError("measurement_times_s length must match measurement_values rows") - reference_at_query, valid = interpolate_reference_values(reference_times_s, reference_values, query_times, max_time_delta_s=max_time_delta_s) + raise ValueError( + "measurement_times_s length must match measurement_values rows" + ) + reference_at_query, valid = interpolate_reference_values( + reference_times_s, + reference_values, + query_times, + max_time_delta_s=max_time_delta_s, + ) if measurement_values.shape[1] != reference_at_query.shape[1]: - raise ValueError("measurement_values and reference_values must have the same value dimension") + raise ValueError( + "measurement_values and reference_values must have the same value dimension" + ) valid &= np.isfinite(measurement_values).all(axis=1) - errors = np.linalg.norm(measurement_values[valid] - reference_at_query[valid], axis=1) + errors = np.linalg.norm( + measurement_values[valid] - reference_at_query[valid], axis=1 + ) return _error_stats(float(offset_s), errors, total_count=len(measurement_values)) -def time_offset_sweep(measurement_times_s: np.ndarray, measurement_values: np.ndarray, reference_times_s: np.ndarray, reference_values: np.ndarray, offsets_s: Iterable[float], *, max_time_delta_s: float | None = None) -> list[dict[str, float]]: - return [time_offset_error_summary(measurement_times_s, measurement_values, reference_times_s, reference_values, offset, max_time_delta_s=max_time_delta_s) for offset in offsets_s] - - -def fit_time_offset(measurement_times_s: np.ndarray, measurement_values: np.ndarray, reference_times_s: np.ndarray, reference_values: np.ndarray, offsets_s: Iterable[float], *, metric: str = "rmse", max_time_delta_s: float | None = None, metadata: Mapping[str, Any] | None = None) -> TimeOffsetFitResult: +def time_offset_sweep( + measurement_times_s: np.ndarray, + measurement_values: np.ndarray, + reference_times_s: np.ndarray, + reference_values: np.ndarray, + offsets_s: Iterable[float], + *, + max_time_delta_s: float | None = None, +) -> list[dict[str, float]]: + return [ + time_offset_error_summary( + measurement_times_s, + measurement_values, + reference_times_s, + reference_values, + offset, + max_time_delta_s=max_time_delta_s, + ) + for offset in offsets_s + ] + + +def fit_time_offset( + measurement_times_s: np.ndarray, + measurement_values: np.ndarray, + reference_times_s: np.ndarray, + reference_values: np.ndarray, + offsets_s: Iterable[float], + *, + metric: str = "rmse", + max_time_delta_s: float | None = None, + metadata: Mapping[str, Any] | None = None, +) -> TimeOffsetFitResult: metric = _validate_error_metric(metric) - summaries = time_offset_sweep(measurement_times_s, measurement_values, reference_times_s, reference_values, offsets_s, max_time_delta_s=max_time_delta_s) + summaries = time_offset_sweep( + measurement_times_s, + measurement_values, + reference_times_s, + reference_values, + offsets_s, + max_time_delta_s=max_time_delta_s, + ) offsets = np.array([row["time_offset_s"] for row in summaries], dtype=float) values = np.array([row[metric] for row in summaries], dtype=float) counts = np.array([row.get("count", 0.0) for row in summaries], dtype=float) finite = np.isfinite(values) & (counts > 0) - best = None if not finite.any() else float(offsets[np.where(finite)[0][int(np.nanargmin(values[finite]))]]) - return TimeOffsetFitResult(best_offset_s=best, metric=metric, offsets_s=offsets, metric_values=values, counts=counts.astype(int), summaries=summaries, metadata={} if metadata is None else dict(metadata)) - - -def aggregate_time_offset_sweeps(sweeps: Iterable[Iterable[Mapping[str, float]]], *, metric: str = "rmse") -> list[dict[str, float]]: + best = ( + None + if not finite.any() + else float(offsets[np.where(finite)[0][int(np.nanargmin(values[finite]))]]) + ) + return TimeOffsetFitResult( + best_offset_s=best, + metric=metric, + offsets_s=offsets, + metric_values=values, + counts=counts.astype(int), + summaries=summaries, + metadata={} if metadata is None else dict(metadata), + ) + + +def aggregate_time_offset_sweeps( + sweeps: Iterable[Iterable[Mapping[str, float]]], *, metric: str = "rmse" +) -> list[dict[str, float]]: metric = _validate_error_metric(metric) by_offset: dict[float, list[Mapping[str, float]]] = {} for sweep in sweeps: @@ -257,12 +371,32 @@ def aggregate_time_offset_sweeps(sweeps: Iterable[Iterable[Mapping[str, float]]] by_offset.setdefault(offset, []).append(row) rows: list[dict[str, float]] = [] for offset, parts in sorted(by_offset.items()): - counts = np.array([_as_nonnegative_summary_count(part.get("count", 0.0), "count") for part in parts], dtype=float) + counts = np.array( + [ + _as_nonnegative_summary_count(part.get("count", 0.0), "count") + for part in parts + ], + dtype=float, + ) row = {"time_offset_s": float(offset), "count": float(np.sum(counts))} for key in dict.fromkeys(("mean", "std", "rmse", "p95", "max", metric)): - values = np.array([_as_summary_scalar(part.get(key, np.nan), str(key), allow_nan=True) for part in parts], dtype=float) + values = np.array( + [ + _as_summary_scalar(part.get(key, np.nan), str(key), allow_nan=True) + for part in parts + ], + dtype=float, + ) if key == "std": - means = np.array([_as_summary_scalar(part.get("mean", np.nan), "mean", allow_nan=True) for part in parts], dtype=float) + means = np.array( + [ + _as_summary_scalar( + part.get("mean", np.nan), "mean", allow_nan=True + ) + for part in parts + ], + dtype=float, + ) row[key] = _aggregate_std_metric(values, means, counts) else: row[key] = _aggregate_summary_metric(key, values, counts) @@ -270,7 +404,9 @@ def aggregate_time_offset_sweeps(sweeps: Iterable[Iterable[Mapping[str, float]]] return rows -def _aggregate_summary_metric(key: str, values: np.ndarray, counts: np.ndarray) -> float: +def _aggregate_summary_metric( + key: str, values: np.ndarray, counts: np.ndarray +) -> float: valid = np.isfinite(values) & (counts > 0.0) if not valid.any(): return float("nan") @@ -281,22 +417,58 @@ def _aggregate_summary_metric(key: str, values: np.ndarray, counts: np.ndarray) return float(np.average(values[valid], weights=counts[valid])) -def _aggregate_std_metric(stds: np.ndarray, means: np.ndarray, counts: np.ndarray) -> float: +def _aggregate_std_metric( + stds: np.ndarray, means: np.ndarray, counts: np.ndarray +) -> float: valid = np.isfinite(stds) & np.isfinite(means) & (counts > 0.0) if not valid.any(): return float("nan") weights = counts[valid] pooled_mean = float(np.average(means[valid], weights=weights)) - second_moment = float(np.average(stds[valid] ** 2 + means[valid] ** 2, weights=weights)) + second_moment = float( + np.average(stds[valid] ** 2 + means[valid] ** 2, weights=weights) + ) return float(np.sqrt(max(0.0, second_moment - pooled_mean**2))) -def _error_stats(offset_s: float, errors: np.ndarray, *, total_count: int) -> dict[str, float]: +def _error_stats( + offset_s: float, errors: np.ndarray, *, total_count: int +) -> dict[str, float]: errors = np.asarray(errors, dtype=float).reshape(-1) errors = errors[np.isfinite(errors)] if errors.size == 0: - return {"time_offset_s": float(offset_s), "count": 0.0, "coverage": 0.0 if total_count else float("nan"), "mean": float("nan"), "std": float("nan"), "rmse": float("nan"), "p95": float("nan"), "max": float("nan")} - return {"time_offset_s": float(offset_s), "count": float(errors.size), "coverage": float(errors.size / total_count) if total_count > 0 else float("nan"), "mean": float(np.mean(errors)), "std": float(np.std(errors)), "rmse": float(np.sqrt(np.mean(errors**2))), "p95": float(np.percentile(errors, 95)), "max": float(np.max(errors))} - - -__all__ = ["TimeOffsetFitResult", "aggregate_time_offset_sweeps", "apply_time_offset", "fit_time_offset", "interpolate_reference_values", "make_offset_grid", "nearest_time_indices", "time_offset_error_summary", "time_offset_sweep"] + return { + "time_offset_s": float(offset_s), + "count": 0.0, + "coverage": 0.0 if total_count else float("nan"), + "mean": float("nan"), + "std": float("nan"), + "rmse": float("nan"), + "p95": float("nan"), + "max": float("nan"), + } + return { + "time_offset_s": float(offset_s), + "count": float(errors.size), + "coverage": ( + float(errors.size / total_count) if total_count > 0 else float("nan") + ), + "mean": float(np.mean(errors)), + "std": float(np.std(errors)), + "rmse": float(np.sqrt(np.mean(errors**2))), + "p95": float(np.percentile(errors, 95)), + "max": float(np.max(errors)), + } + + +__all__ = [ + "TimeOffsetFitResult", + "aggregate_time_offset_sweeps", + "apply_time_offset", + "fit_time_offset", + "interpolate_reference_values", + "make_offset_grid", + "nearest_time_indices", + "time_offset_error_summary", + "time_offset_sweep", +] diff --git a/src/pyrecest/distributions/abstract_dirac_distribution.py b/src/pyrecest/distributions/abstract_dirac_distribution.py index ac838347a..dd0572d95 100644 --- a/src/pyrecest/distributions/abstract_dirac_distribution.py +++ b/src/pyrecest/distributions/abstract_dirac_distribution.py @@ -13,7 +13,9 @@ arange, argmax, asarray, - copy as backend_copy, +) +from pyrecest.backend import copy as backend_copy +from pyrecest.backend import ( int32, int64, isclose, diff --git a/src/pyrecest/distributions/hypersphere_subset/hyperspherical_uniform_distribution.py b/src/pyrecest/distributions/hypersphere_subset/hyperspherical_uniform_distribution.py index 13eebb508..e7aa7dc6a 100644 --- a/src/pyrecest/distributions/hypersphere_subset/hyperspherical_uniform_distribution.py +++ b/src/pyrecest/distributions/hypersphere_subset/hyperspherical_uniform_distribution.py @@ -11,7 +11,9 @@ linalg, ones, pi, - random as backend_random, +) +from pyrecest.backend import random as backend_random +from pyrecest.backend import ( sin, sqrt, stack, diff --git a/src/pyrecest/evaluation/get_extract_mean.py b/src/pyrecest/evaluation/get_extract_mean.py index f387b2d93..78414f45b 100644 --- a/src/pyrecest/evaluation/get_extract_mean.py +++ b/src/pyrecest/evaluation/get_extract_mean.py @@ -88,7 +88,9 @@ def _extract_mtt_mean(filter_state): def get_extract_mean(manifold_name, mtt_scenario=False): normalized_name = _normalize_registry_name(manifold_name) - is_mtt_scenario = _coerce_mtt_scenario_flag(mtt_scenario) or "mtt" in normalized_name + is_mtt_scenario = ( + _coerce_mtt_scenario_flag(mtt_scenario) or "mtt" in normalized_name + ) registered_factory = _EXTRACT_MEAN_FACTORIES.get(normalized_name) if registered_factory is not None: return registered_factory(manifold_name, is_mtt_scenario) diff --git a/src/pyrecest/filters/adaptive_process_noise.py b/src/pyrecest/filters/adaptive_process_noise.py index 25bed3bce..6383ff5d5 100644 --- a/src/pyrecest/filters/adaptive_process_noise.py +++ b/src/pyrecest/filters/adaptive_process_noise.py @@ -151,7 +151,9 @@ def observe( accepted = _normalize_bool_flag(accepted, "accepted") if not accepted: return self.ratios_by_source.get(source, 1.0) - measurement_dim = _normalize_positive_integer(measurement_dim, "measurement_dim") + measurement_dim = _normalize_positive_integer( + measurement_dim, "measurement_dim" + ) nis = _normalize_nonnegative_finite_scalar(nis, "nis") ratio = nis / float(measurement_dim) previous = self.ratios_by_source.get(source, 1.0) diff --git a/src/pyrecest/filters/wrapped_normal_filter.py b/src/pyrecest/filters/wrapped_normal_filter.py index e2de0e247..df25d8d77 100644 --- a/src/pyrecest/filters/wrapped_normal_filter.py +++ b/src/pyrecest/filters/wrapped_normal_filter.py @@ -140,7 +140,9 @@ def update_nonlinear_progressive( ) if current_lambda <= 0: - raise ValueError("Progressive update with given threshold impossible") + raise ValueError( + "Progressive update with given threshold impossible" + ) current_lambda = maximum(current_lambda, MINIMUM_LAMBDA) current_lambda = minimum(current_lambda, lambda_) diff --git a/src/pyrecest/models/additive_noise.py b/src/pyrecest/models/additive_noise.py index f5effad75..1bb028967 100644 --- a/src/pyrecest/models/additive_noise.py +++ b/src/pyrecest/models/additive_noise.py @@ -411,7 +411,9 @@ def has_jacobian(self): def measurement_residual(self, measurement, state, **kwargs): """Return ``measurement - h(state)``.""" - return _array_difference(measurement, self.measurement_function(state, **kwargs)) + return _array_difference( + measurement, self.measurement_function(state, **kwargs) + ) def sample_measurement(self, state, n: int = 1, **kwargs): """Draw ``n`` samples from ``p(measurement | state)``.""" diff --git a/src/pyrecest/models/validation.py b/src/pyrecest/models/validation.py index cd977d367..2502ba0a8 100644 --- a/src/pyrecest/models/validation.py +++ b/src/pyrecest/models/validation.py @@ -384,7 +384,9 @@ def infer_state_dim_from_distribution( if hasattr(distribution, "d"): try: - support = _maybe_call(getattr(distribution, "d"), allow_methods=allow_methods) + support = _maybe_call( + getattr(distribution, "d"), allow_methods=allow_methods + ) support = validate_matrix(support, name="distribution.d") return _shape_tuple(support)[1] except (TypeError, ValueError): diff --git a/src/pyrecest/stability.py b/src/pyrecest/stability.py index 8bc49080d..f4fe020a5 100644 --- a/src/pyrecest/stability.py +++ b/src/pyrecest/stability.py @@ -33,7 +33,11 @@ def _patch_pytorch_raw_comparison_arraylike_contract() -> None: active_pytorch_backend = getattr(backend, "__backend_name__", None) == "pytorch" helper_names = ("greater", "less", "logical_or") if all( - getattr(getattr(raw_pytorch, helper_name, None), "_pyrecest_arraylike_contract", False) + getattr( + getattr(raw_pytorch, helper_name, None), + "_pyrecest_arraylike_contract", + False, + ) for helper_name in helper_names ): if active_pytorch_backend: @@ -177,7 +181,9 @@ def squeeze(a, axis=None): if not axes: return a - normalized_axes = tuple(one_axis + a.ndim if one_axis < 0 else one_axis for one_axis in axes) + normalized_axes = tuple( + one_axis + a.ndim if one_axis < 0 else one_axis for one_axis in axes + ) for one_axis, normalized_axis in zip(axes, normalized_axes): if normalized_axis < 0 or normalized_axis >= a.ndim: raise ValueError( @@ -188,7 +194,9 @@ def squeeze(a, axis=None): if any(a.shape[one_axis] != 1 for one_axis in normalized_axes): return a - squeeze_axis = normalized_axes[0] if len(normalized_axes) == 1 else normalized_axes + squeeze_axis = ( + normalized_axes[0] if len(normalized_axes) == 1 else normalized_axes + ) return original_squeeze(a, axis=squeeze_axis) squeeze.__name__ = getattr(original_squeeze, "__name__", "squeeze") diff --git a/src/pyrecest/utils/__init__.py b/src/pyrecest/utils/__init__.py index 5e7e6c432..507f17f3c 100644 --- a/src/pyrecest/utils/__init__.py +++ b/src/pyrecest/utils/__init__.py @@ -163,7 +163,9 @@ def _fit_standardization(self, features): "__name__", "_fit_standardization", ) - _fit_standardization.__doc__ = getattr(original_fit_standardization, "__doc__", None) + _fit_standardization.__doc__ = getattr( + original_fit_standardization, "__doc__", None + ) _fit_standardization._pyrecest_backend_std_contract = True LogisticPairwiseAssociationModel._fit_standardization = _fit_standardization @@ -193,7 +195,9 @@ def _prepare_prediction_features(features, expected_feature_dimension): if not _backend.all(_backend.isfinite(flattened)): raise ValueError("features must be finite") return flattened, () - return original_prepare_prediction_features(features, expected_feature_dimension) + return original_prepare_prediction_features( + features, expected_feature_dimension + ) _prepare_prediction_features.__name__ = getattr( original_prepare_prediction_features, @@ -323,7 +327,9 @@ def _prepare_prediction_features(features, expected_feature_dimension): ) _multisession_assignment_module.tracks_to_session_labels = tracks_to_session_labels -_multisession_assignment_module._validate_scalar_cost = _validate_multisession_scalar_cost +_multisession_assignment_module._validate_scalar_cost = ( + _validate_multisession_scalar_cost +) __all__ = [ "MultiSessionAssignmentResult", diff --git a/src/pyrecest/utils/candidate_pruning.py b/src/pyrecest/utils/candidate_pruning.py index f602698c0..e0ab810b8 100644 --- a/src/pyrecest/utils/candidate_pruning.py +++ b/src/pyrecest/utils/candidate_pruning.py @@ -377,7 +377,9 @@ def _as_probability_matrix( if probabilities.shape != shape: raise ValueError("probability_matrix must match cost_matrix shape") if np.any(np.isinf(probabilities)): - raise ValueError("probability_matrix may only contain finite probabilities or NaN") + raise ValueError( + "probability_matrix may only contain finite probabilities or NaN" + ) finite = np.isfinite(probabilities) if np.any(finite & ((probabilities < 0.0) | (probabilities > 1.0))): raise ValueError("finite probability_matrix entries must lie in [0, 1]") diff --git a/src/pyrecest/utils/cost_matrix_adjustments.py b/src/pyrecest/utils/cost_matrix_adjustments.py index 8ae3a5185..69664718c 100644 --- a/src/pyrecest/utils/cost_matrix_adjustments.py +++ b/src/pyrecest/utils/cost_matrix_adjustments.py @@ -81,7 +81,11 @@ def apply( def apply_cost_matrix_adjustment( cost_matrix: Any, - adjustment: CostMatrixAdjustment | Callable[[np.ndarray], Any] | CallableCostMatrixAdjustment, + adjustment: ( + CostMatrixAdjustment + | Callable[[np.ndarray], Any] + | CallableCostMatrixAdjustment + ), *, metadata: Mapping[str, Any] | None = None, ) -> CostMatrixAdjustmentResult: @@ -105,7 +109,9 @@ def apply_cost_matrix_adjustment( def compose_cost_matrix_adjustments( cost_matrix: Any, adjustments: Sequence[ - CostMatrixAdjustment | Callable[[np.ndarray], Any] | CallableCostMatrixAdjustment + CostMatrixAdjustment + | Callable[[np.ndarray], Any] + | CallableCostMatrixAdjustment ], *, metadata: Mapping[str, Any] | None = None, @@ -146,7 +152,9 @@ def additive_cost_matrix_adjustment( penalty = _as_cost_matrix(penalty_matrix) stored_diagnostics = dict(diagnostics or {}) - def _add(matrix: np.ndarray, _metadata: Mapping[str, Any]) -> CostMatrixAdjustmentResult: + def _add( + matrix: np.ndarray, _metadata: Mapping[str, Any] + ) -> CostMatrixAdjustmentResult: if matrix.shape != penalty.shape: raise ValueError( f"penalty_matrix shape {penalty.shape} does not match cost_matrix shape {matrix.shape}" diff --git a/tests/backend/test_numpy_random_size_validation.py b/tests/backend/test_numpy_random_size_validation.py index 1c47b496c..35ad26129 100644 --- a/tests/backend/test_numpy_random_size_validation.py +++ b/tests/backend/test_numpy_random_size_validation.py @@ -1,9 +1,7 @@ import numpy as np import pytest - from pyrecest._backend.numpy import random - _INVALID_SIZE_ARGUMENTS = ( True, np.bool_(True), diff --git a/tests/backend/test_numpy_random_uniform_validation.py b/tests/backend/test_numpy_random_uniform_validation.py index d9a1f88a8..a83df27ea 100644 --- a/tests/backend/test_numpy_random_uniform_validation.py +++ b/tests/backend/test_numpy_random_uniform_validation.py @@ -1,5 +1,4 @@ import pytest - from pyrecest._backend.numpy import random diff --git a/tests/backend/test_numpy_uniform_ragged_bounds.py b/tests/backend/test_numpy_uniform_ragged_bounds.py index 2fe03a215..ed70556cd 100644 --- a/tests/backend/test_numpy_uniform_ragged_bounds.py +++ b/tests/backend/test_numpy_uniform_ragged_bounds.py @@ -1,5 +1,4 @@ import pytest - from pyrecest._backend import numpy as backend diff --git a/tests/backend/test_pytorch_comparison_device_contract.py b/tests/backend/test_pytorch_comparison_device_contract.py index 906371969..205912146 100644 --- a/tests/backend/test_pytorch_comparison_device_contract.py +++ b/tests/backend/test_pytorch_comparison_device_contract.py @@ -1,10 +1,9 @@ import pytest - torch = pytest.importorskip("torch") -import pyrecest.backend_tools # noqa: E402,F401 import pyrecest._backend.pytorch as pytorch_backend # noqa: E402 +import pyrecest.backend_tools # noqa: E402,F401 def _non_cpu_device(): @@ -72,7 +71,9 @@ def test_raw_pytorch_isclose_prefers_existing_non_cpu_device_for_right_operand() def test_raw_pytorch_allclose_accepts_arraylike_against_cuda_operand(): if not torch.cuda.is_available(): - pytest.skip("allclose returns a host bool and cannot be exercised on meta tensors") + pytest.skip( + "allclose returns a host bool and cannot be exercised on meta tensors" + ) right = torch.tensor([1.0, float("nan")], device="cuda") diff --git a/tests/backend_support/conftest.py b/tests/backend_support/conftest.py index 76b59c91e..a3c77d2e7 100644 --- a/tests/backend_support/conftest.py +++ b/tests/backend_support/conftest.py @@ -1,6 +1,5 @@ from pathlib import Path - _placeholder_regression = Path(__file__).with_name( "test_pytorch_dot_outer_device_contract.py" ) diff --git a/tests/backend_support/test_common_dot_contract.py b/tests/backend_support/test_common_dot_contract.py index 72d6d5c4d..89e52db47 100644 --- a/tests/backend_support/test_common_dot_contract.py +++ b/tests/backend_support/test_common_dot_contract.py @@ -1,7 +1,6 @@ import numpy as np import numpy.testing as npt import pytest - from pyrecest._backend import _common as common diff --git a/tests/backend_support/test_jax_one_hot_contract.py b/tests/backend_support/test_jax_one_hot_contract.py index 44b47bec1..b96d4cba1 100644 --- a/tests/backend_support/test_jax_one_hot_contract.py +++ b/tests/backend_support/test_jax_one_hot_contract.py @@ -1,7 +1,6 @@ import importlib.util import pytest - from tests.support.backend_runner import run_backend_code pytestmark = pytest.mark.backend_portable diff --git a/tests/backend_support/test_jax_take_out_contract.py b/tests/backend_support/test_jax_take_out_contract.py index 36fc8ef8a..9022629b4 100644 --- a/tests/backend_support/test_jax_take_out_contract.py +++ b/tests/backend_support/test_jax_take_out_contract.py @@ -1,7 +1,6 @@ import importlib.util import pytest - from tests.support.backend_runner import run_backend_code diff --git a/tests/backend_support/test_pytorch_broadcast_to_contract.py b/tests/backend_support/test_pytorch_broadcast_to_contract.py index 924564f0b..43aed0f45 100644 --- a/tests/backend_support/test_pytorch_broadcast_to_contract.py +++ b/tests/backend_support/test_pytorch_broadcast_to_contract.py @@ -61,7 +61,9 @@ def test_pytorch_broadcast_to_shape_inputs_match_numpy_contract(): else: raise AssertionError(f"broadcast_to accepted boolean shape {invalid_shape!r}") """ - subprocess.run([sys.executable, "-c", code], check=True, env=_backend_test_env("pytorch")) + subprocess.run( + [sys.executable, "-c", code], check=True, env=_backend_test_env("pytorch") + ) @pytest.mark.backend_portable @@ -101,4 +103,6 @@ def test_raw_pytorch_broadcast_to_shape_inputs_with_numpy_public_backend(): else: raise AssertionError(f"raw broadcast_to accepted boolean shape {invalid_shape!r}") """ - subprocess.run([sys.executable, "-c", code], check=True, env=_backend_test_env("numpy")) + subprocess.run( + [sys.executable, "-c", code], check=True, env=_backend_test_env("numpy") + ) diff --git a/tests/backend_support/test_pytorch_diagonal_numpy_scalar_args.py b/tests/backend_support/test_pytorch_diagonal_numpy_scalar_args.py index a1e117efd..35cfea6e1 100644 --- a/tests/backend_support/test_pytorch_diagonal_numpy_scalar_args.py +++ b/tests/backend_support/test_pytorch_diagonal_numpy_scalar_args.py @@ -1,6 +1,5 @@ import numpy as np import pytest - from pyrecest._backend import _common as common diff --git a/tests/backend_support/test_pytorch_diff_contract.py b/tests/backend_support/test_pytorch_diff_contract.py index 3413f6a82..4f37605cb 100644 --- a/tests/backend_support/test_pytorch_diff_contract.py +++ b/tests/backend_support/test_pytorch_diff_contract.py @@ -5,7 +5,6 @@ import pytest - SCRIPT = """ import pyrecest # noqa: F401 # triggers raw-backend compatibility patches import torch diff --git a/tests/backend_support/test_pytorch_dot_outer_device_contract.py b/tests/backend_support/test_pytorch_dot_outer_device_contract.py index 580ea72aa..4d50885e3 100644 --- a/tests/backend_support/test_pytorch_dot_outer_device_contract.py +++ b/tests/backend_support/test_pytorch_dot_outer_device_contract.py @@ -1,7 +1,6 @@ import importlib.util import pytest - from tests.support.backend_runner import run_backend_code pytestmark = pytest.mark.backend_portable diff --git a/tests/backend_support/test_pytorch_equality_device_contract.py b/tests/backend_support/test_pytorch_equality_device_contract.py index 225f6a4b0..c6cab6fdc 100644 --- a/tests/backend_support/test_pytorch_equality_device_contract.py +++ b/tests/backend_support/test_pytorch_equality_device_contract.py @@ -1,7 +1,6 @@ import importlib.util import pytest - from tests.support.backend_runner import run_backend_code pytestmark = pytest.mark.backend_portable diff --git a/tests/backend_support/test_pytorch_fft_alias_numpy_arrays_contract.py b/tests/backend_support/test_pytorch_fft_alias_numpy_arrays_contract.py index b5a48a456..6efae09c0 100644 --- a/tests/backend_support/test_pytorch_fft_alias_numpy_arrays_contract.py +++ b/tests/backend_support/test_pytorch_fft_alias_numpy_arrays_contract.py @@ -57,4 +57,6 @@ def test_pytorch_fft_aliases_accept_matching_numpy_array_axes(): else: raise AssertionError("conflicting FFT axis aliases were accepted") """ - subprocess.run([sys.executable, "-c", code], check=True, env=_backend_test_env("pytorch")) + subprocess.run( + [sys.executable, "-c", code], check=True, env=_backend_test_env("pytorch") + ) diff --git a/tests/backend_support/test_pytorch_fft_scalar_axis_contract.py b/tests/backend_support/test_pytorch_fft_scalar_axis_contract.py index 96a4a374a..fe42f1bbb 100644 --- a/tests/backend_support/test_pytorch_fft_scalar_axis_contract.py +++ b/tests/backend_support/test_pytorch_fft_scalar_axis_contract.py @@ -17,7 +17,9 @@ def test_raw_pytorch_real_fft_accepts_numpy_scalar_array_axis_alias(): npt.assert_allclose(spectrum.numpy(), np.fft.rfft(vector, axis=axis)) reconstructed = pytorch_fft.irfft(spectrum, n=vector.size, axis=axis) - expected = np.fft.irfft(np.fft.rfft(vector, axis=axis), n=vector.size, axis=axis) + expected = np.fft.irfft( + np.fft.rfft(vector, axis=axis), n=vector.size, axis=axis + ) npt.assert_allclose(reconstructed.numpy(), expected) diff --git a/tests/backend_support/test_pytorch_isclose_equal_nan_contract.py b/tests/backend_support/test_pytorch_isclose_equal_nan_contract.py index e22a3e487..0e19a83f2 100644 --- a/tests/backend_support/test_pytorch_isclose_equal_nan_contract.py +++ b/tests/backend_support/test_pytorch_isclose_equal_nan_contract.py @@ -43,7 +43,9 @@ def test_public_pytorch_isclose_accepts_equal_nan_keyword_when_selected(): right = [np.nan, 1.0 + 1e-9, 2.0] result = backend.isclose(left, right, equal_nan=True) - expected = np.isclose(left, right, rtol=backend.rtol, atol=backend.atol, equal_nan=True) + expected = np.isclose( + left, right, rtol=backend.rtol, atol=backend.atol, equal_nan=True + ) npt.assert_array_equal(backend.to_numpy(result), expected) result_without_nan_match = backend.isclose(left, right, equal_nan=False) @@ -54,4 +56,6 @@ def test_public_pytorch_isclose_accepts_equal_nan_keyword_when_selected(): atol=backend.atol, equal_nan=False, ) - npt.assert_array_equal(backend.to_numpy(result_without_nan_match), expected_without_nan_match) + npt.assert_array_equal( + backend.to_numpy(result_without_nan_match), expected_without_nan_match + ) diff --git a/tests/backend_support/test_pytorch_matmul_device_contract.py b/tests/backend_support/test_pytorch_matmul_device_contract.py index 0232750c6..40a0160cb 100644 --- a/tests/backend_support/test_pytorch_matmul_device_contract.py +++ b/tests/backend_support/test_pytorch_matmul_device_contract.py @@ -1,7 +1,6 @@ import importlib.util import pytest - from tests.support.backend_runner import run_backend_code pytestmark = pytest.mark.backend_portable diff --git a/tests/backend_support/test_pytorch_minmax_device_contract.py b/tests/backend_support/test_pytorch_minmax_device_contract.py index b1dc3097a..491b278b4 100644 --- a/tests/backend_support/test_pytorch_minmax_device_contract.py +++ b/tests/backend_support/test_pytorch_minmax_device_contract.py @@ -5,7 +5,6 @@ import importlib.util import pytest - from tests.support.backend_runner import run_backend_code pytestmark = pytest.mark.backend_portable diff --git a/tests/backend_support/test_pytorch_non_native_dtype_contract.py b/tests/backend_support/test_pytorch_non_native_dtype_contract.py index a53d1d04c..393c137c7 100644 --- a/tests/backend_support/test_pytorch_non_native_dtype_contract.py +++ b/tests/backend_support/test_pytorch_non_native_dtype_contract.py @@ -1,5 +1,4 @@ import pytest - from tests.support.backend_runner import run_backend_code diff --git a/tests/backend_support/test_pytorch_pad_contract.py b/tests/backend_support/test_pytorch_pad_contract.py index b834d351b..bf133d236 100644 --- a/tests/backend_support/test_pytorch_pad_contract.py +++ b/tests/backend_support/test_pytorch_pad_contract.py @@ -1,7 +1,6 @@ import importlib.util import pytest - from tests.support.backend_runner import run_backend_code pytestmark = pytest.mark.backend_portable diff --git a/tests/backend_support/test_pytorch_randint_empty_size_contract.py b/tests/backend_support/test_pytorch_randint_empty_size_contract.py index 3662d09fb..f9727bb81 100644 --- a/tests/backend_support/test_pytorch_randint_empty_size_contract.py +++ b/tests/backend_support/test_pytorch_randint_empty_size_contract.py @@ -1,7 +1,6 @@ from __future__ import annotations import pytest - from tests.support.backend_runner import run_backend_code diff --git a/tests/backend_support/test_pytorch_rotation_stub_contract.py b/tests/backend_support/test_pytorch_rotation_stub_contract.py index 4223cb30a..e101d5169 100644 --- a/tests/backend_support/test_pytorch_rotation_stub_contract.py +++ b/tests/backend_support/test_pytorch_rotation_stub_contract.py @@ -5,7 +5,6 @@ import importlib.util import pytest - from pyrecest.exceptions import BackendNotSupportedError diff --git a/tests/backend_support/test_pytorch_rotation_stub_methods.py b/tests/backend_support/test_pytorch_rotation_stub_methods.py index 3901c92aa..fa3fb8af4 100644 --- a/tests/backend_support/test_pytorch_rotation_stub_methods.py +++ b/tests/backend_support/test_pytorch_rotation_stub_methods.py @@ -6,7 +6,6 @@ from pathlib import Path import pytest - from pyrecest.exceptions import BackendNotSupportedError diff --git a/tests/backend_support/test_pytorch_solve_sylvester_device.py b/tests/backend_support/test_pytorch_solve_sylvester_device.py index 5361e190e..504749749 100644 --- a/tests/backend_support/test_pytorch_solve_sylvester_device.py +++ b/tests/backend_support/test_pytorch_solve_sylvester_device.py @@ -5,7 +5,6 @@ torch: Any try: import torch - from pyrecest._backend import pytorch as pytorch_backend except ModuleNotFoundError: torch = None @@ -31,11 +30,11 @@ def test_solve_sylvester_keeps_common_dtype_for_mixed_array_like_inputs(self): self.assertEqual(result.dtype, pytorch_backend.float64) self.assertTrue(pytorch_backend.allclose(result, expected)) - @unittest.skipIf(torch is None or not torch.cuda.is_available(), "CUDA is not available") + @unittest.skipIf( + torch is None or not torch.cuda.is_available(), "CUDA is not available" + ) def test_solve_sylvester_aligns_mixed_tensor_devices(self): - a = torch.tensor( - [[2.0, 0.0], [0.0, 3.0]], dtype=torch.float32, device="cuda" - ) + a = torch.tensor([[2.0, 0.0], [0.0, 3.0]], dtype=torch.float32, device="cuda") b = torch.tensor([[2.0, 0.0], [0.0, 3.0]], dtype=torch.float32) q = torch.tensor([[8.0, 10.0], [10.0, 12.0]], dtype=torch.float64) diff --git a/tests/backend_support/test_pytorch_special_contract.py b/tests/backend_support/test_pytorch_special_contract.py index 9e422f729..420a563e4 100644 --- a/tests/backend_support/test_pytorch_special_contract.py +++ b/tests/backend_support/test_pytorch_special_contract.py @@ -29,8 +29,7 @@ def test_raw_pytorch_special_helpers_are_patched_under_default_backend(): if importlib.util.find_spec("torch") is None: pytest.skip("torch is not installed") - _run_python( - r''' + _run_python(r""" import math import pyrecest # noqa: F401 @@ -70,8 +69,7 @@ def assert_close(values, expected): assert math.isnan(pole_values[0]) assert math.isinf(pole_values[1]) and pole_values[1] > 0 assert math.isinf(pole_values[2]) and pole_values[2] < 0 -''' - ) +""") @pytest.mark.backend_portable @@ -80,7 +78,7 @@ def test_public_pytorch_special_helpers_accept_arraylike_and_out(): pytest.skip("torch is not installed") _run_python( - r''' + r""" import math import pyrecest.backend as backend @@ -116,6 +114,6 @@ def assert_close(module, values, expected): special_backend.polygamma(1, [1.0, 2.0]), [math.pi**2 / 6.0, math.pi**2 / 6.0 - 1.0], ) -''', +""", backend_name="pytorch", ) diff --git a/tests/backend_support/test_pytorch_tile_contract.py b/tests/backend_support/test_pytorch_tile_contract.py index 23fb45e66..ba3bab2c1 100644 --- a/tests/backend_support/test_pytorch_tile_contract.py +++ b/tests/backend_support/test_pytorch_tile_contract.py @@ -54,7 +54,9 @@ def test_pytorch_tile_scalar_and_array_repetitions_match_numpy_contract(): else: raise AssertionError(f"tile accepted non-integer repetitions {bad_reps!r}") """ - subprocess.run([sys.executable, "-c", code], check=True, env=_backend_test_env("pytorch")) + subprocess.run( + [sys.executable, "-c", code], check=True, env=_backend_test_env("pytorch") + ) @pytest.mark.backend_portable @@ -90,4 +92,6 @@ def test_raw_pytorch_tile_matches_numpy_contract_with_numpy_public_backend(): else: raise AssertionError(f"raw tile accepted non-integer repetitions {bad_reps!r}") """ - subprocess.run([sys.executable, "-c", code], check=True, env=_backend_test_env("numpy")) + subprocess.run( + [sys.executable, "-c", code], check=True, env=_backend_test_env("numpy") + ) diff --git a/tests/calibration/test_time_offset_real_numeric_validation.py b/tests/calibration/test_time_offset_real_numeric_validation.py index 5eebd8208..5b8a73b05 100644 --- a/tests/calibration/test_time_offset_real_numeric_validation.py +++ b/tests/calibration/test_time_offset_real_numeric_validation.py @@ -1,7 +1,6 @@ import unittest import numpy as np - from pyrecest.calibration import apply_time_offset, make_offset_grid diff --git a/tests/distributions/test_abstract_mixture_sample_validation.py b/tests/distributions/test_abstract_mixture_sample_validation.py index 889ec04e2..fac1ab9f6 100644 --- a/tests/distributions/test_abstract_mixture_sample_validation.py +++ b/tests/distributions/test_abstract_mixture_sample_validation.py @@ -1,15 +1,18 @@ import numpy as np import pytest - from pyrecest.backend import array, eye from pyrecest.distributions.hypertorus.hypertoroidal_mixture import HypertoroidalMixture -from pyrecest.distributions.hypertorus.toroidal_wrapped_normal_distribution import ToroidalWrappedNormalDistribution +from pyrecest.distributions.hypertorus.toroidal_wrapped_normal_distribution import ( + ToroidalWrappedNormalDistribution, +) @pytest.mark.parametrize("count", ["3", np.str_("3")]) def test_mixture_sample_rejects_text_count(count): vmf = ToroidalWrappedNormalDistribution(array([1.0, 0.0]), eye(2)) - mixture = HypertoroidalMixture([vmf, vmf.shift(array([1.0, 1.0]))], array([0.5, 0.5])) + mixture = HypertoroidalMixture( + [vmf, vmf.shift(array([1.0, 1.0]))], array([0.5, 0.5]) + ) with pytest.raises(ValueError, match="n must be a positive integer"): mixture.sample(count) diff --git a/tests/distributions/test_dirac_tensor_copy_contract.py b/tests/distributions/test_dirac_tensor_copy_contract.py index ae746ae58..4b8668c62 100644 --- a/tests/distributions/test_dirac_tensor_copy_contract.py +++ b/tests/distributions/test_dirac_tensor_copy_contract.py @@ -39,4 +39,6 @@ def test_pytorch_dirac_distribution_clones_input_tensor_storage(): assert dist.d.data_ptr() != samples.data_ptr() assert dist.w.data_ptr() != weights.data_ptr() """ - subprocess.run([sys.executable, "-c", code], check=True, env=_backend_test_env("pytorch")) + subprocess.run( + [sys.executable, "-c", code], check=True, env=_backend_test_env("pytorch") + ) diff --git a/tests/distributions/test_gaussian_distribution_sample_validation.py b/tests/distributions/test_gaussian_distribution_sample_validation.py index 8a8f8864d..0d91f54fc 100644 --- a/tests/distributions/test_gaussian_distribution_sample_validation.py +++ b/tests/distributions/test_gaussian_distribution_sample_validation.py @@ -1,8 +1,9 @@ import numpy as np import pytest - from pyrecest.backend import array, eye -from pyrecest.distributions.nonperiodic.gaussian_distribution import GaussianDistribution +from pyrecest.distributions.nonperiodic.gaussian_distribution import ( + GaussianDistribution, +) @pytest.mark.parametrize("count", ["3", b"3", np.str_("3"), np.bytes_(b"3")]) diff --git a/tests/distributions/test_uniform_order_validation.py b/tests/distributions/test_uniform_order_validation.py index 8c3d08623..2964121d8 100644 --- a/tests/distributions/test_uniform_order_validation.py +++ b/tests/distributions/test_uniform_order_validation.py @@ -1,7 +1,8 @@ import pytest - from pyrecest.backend import array -from pyrecest.distributions.hypertorus.hypertoroidal_uniform_distribution import HypertoroidalUniformDistribution +from pyrecest.distributions.hypertorus.hypertoroidal_uniform_distribution import ( + HypertoroidalUniformDistribution, +) def test_uniform_moment_rejects_invalid_order(): diff --git a/tests/evaluation/test_flat_empty_measurements.py b/tests/evaluation/test_flat_empty_measurements.py index 530b5cc48..c6372be4e 100644 --- a/tests/evaluation/test_flat_empty_measurements.py +++ b/tests/evaluation/test_flat_empty_measurements.py @@ -1,6 +1,5 @@ import numpy as np import numpy.testing as npt - from pyrecest.backend import array from pyrecest.evaluation import perform_predict_update_cycles from pyrecest.evaluation.configure_for_filter import register_filter_factory diff --git a/tests/evaluation/test_get_extract_mean_mtt_flag.py b/tests/evaluation/test_get_extract_mean_mtt_flag.py index e796e019d..b0b4e389f 100644 --- a/tests/evaluation/test_get_extract_mean_mtt_flag.py +++ b/tests/evaluation/test_get_extract_mean_mtt_flag.py @@ -1,5 +1,4 @@ import pytest - from pyrecest.evaluation.get_extract_mean import get_extract_mean diff --git a/tests/evaluation/test_iterate_configs_vector_parameter_shape.py b/tests/evaluation/test_iterate_configs_vector_parameter_shape.py index 9814dfff9..e8b0b0768 100644 --- a/tests/evaluation/test_iterate_configs_vector_parameter_shape.py +++ b/tests/evaluation/test_iterate_configs_vector_parameter_shape.py @@ -9,7 +9,9 @@ def test_iterate_configs_and_runs_uses_config_count_for_vector_parameters(monkey if backend.__backend_name__ not in ("numpy", "autograd"): pytest.skip("iterate_configs_and_runs stores object-valued filter states") - iterate_module = importlib.import_module("pyrecest.evaluation.iterate_configs_and_runs") + iterate_module = importlib.import_module( + "pyrecest.evaluation.iterate_configs_and_runs" + ) vector_parameter = np.array([1.0, 2.0]) calls = [] @@ -46,12 +48,14 @@ def fake_predict_update_cycles( "auto_warning_on_off": False, } - last_filter_states, runtimes, run_failed, *_ = iterate_module.iterate_configs_and_runs( - groundtruths, - measurements, - {"name": "dummy"}, - [{"name": "dummy_filter", "parameter": vector_parameter}], - evaluation_config, + last_filter_states, runtimes, run_failed, *_ = ( + iterate_module.iterate_configs_and_runs( + groundtruths, + measurements, + {"name": "dummy"}, + [{"name": "dummy_filter", "parameter": vector_parameter}], + evaluation_config, + ) ) assert np.shape(last_filter_states) == (1, 2) diff --git a/tests/evaluation/test_zero_measurements.py b/tests/evaluation/test_zero_measurements.py index 147ca1a39..8e2072953 100644 --- a/tests/evaluation/test_zero_measurements.py +++ b/tests/evaluation/test_zero_measurements.py @@ -1,5 +1,4 @@ import numpy as np - from pyrecest.distributions import GaussianDistribution from pyrecest.evaluation.generate_measurements import generate_measurements diff --git a/tests/filters/test_hypertoroidal_particle_filter_numpy_scalars.py b/tests/filters/test_hypertoroidal_particle_filter_numpy_scalars.py index 4a85c07ac..746ed25db 100644 --- a/tests/filters/test_hypertoroidal_particle_filter_numpy_scalars.py +++ b/tests/filters/test_hypertoroidal_particle_filter_numpy_scalars.py @@ -1,6 +1,5 @@ import numpy as np import pytest - from pyrecest.filters import HypertoroidalParticleFilter diff --git a/tests/filters/test_wrapped_normal_filter_constant_likelihood.py b/tests/filters/test_wrapped_normal_filter_constant_likelihood.py index dac5dbc0e..080a2ebc0 100644 --- a/tests/filters/test_wrapped_normal_filter_constant_likelihood.py +++ b/tests/filters/test_wrapped_normal_filter_constant_likelihood.py @@ -1,5 +1,4 @@ import numpy.testing as npt - from pyrecest.backend import array from pyrecest.distributions import WrappedNormalDistribution from pyrecest.filters.wrapped_normal_filter import WrappedNormalFilter diff --git a/tests/test_common_reduction_axis_scalar.py b/tests/test_common_reduction_axis_scalar.py index 5be0adc2a..2ad9d4fd1 100644 --- a/tests/test_common_reduction_axis_scalar.py +++ b/tests/test_common_reduction_axis_scalar.py @@ -1,5 +1,4 @@ import numpy as np - from pyrecest._backend import _common diff --git a/tests/test_common_reduction_axis_validation.py b/tests/test_common_reduction_axis_validation.py index a4b3b9d78..b76512df7 100644 --- a/tests/test_common_reduction_axis_validation.py +++ b/tests/test_common_reduction_axis_validation.py @@ -1,7 +1,6 @@ import importlib.util import pytest - from pyrecest._backend import _common diff --git a/tests/test_common_size_axis_validation.py b/tests/test_common_size_axis_validation.py index 30b3dc624..861e9e5e9 100644 --- a/tests/test_common_size_axis_validation.py +++ b/tests/test_common_size_axis_validation.py @@ -1,6 +1,5 @@ import numpy as np import pytest - from pyrecest._backend import _common diff --git a/tests/test_cost_matrix_adjustments.py b/tests/test_cost_matrix_adjustments.py index 7cdf757aa..2e795c7d6 100644 --- a/tests/test_cost_matrix_adjustments.py +++ b/tests/test_cost_matrix_adjustments.py @@ -2,7 +2,6 @@ import numpy as np import numpy.testing as npt - from pyrecest.utils.cost_matrix_adjustments import ( CallableCostMatrixAdjustment, CostMatrixAdjustmentResult, @@ -151,7 +150,9 @@ def test_numeric_validation(self): def test_named_adjustment_validation(self): with self.assertRaises(ValueError): - CallableCostMatrixAdjustment(name="", function=lambda matrix, metadata: matrix) + CallableCostMatrixAdjustment( + name="", function=lambda matrix, metadata: matrix + ) with self.assertRaises(ValueError): CallableCostMatrixAdjustment(name="bad", function=None) # type: ignore[arg-type] diff --git a/tests/test_evidence_metadata_validation.py b/tests/test_evidence_metadata_validation.py index eeea66646..8629e0a73 100644 --- a/tests/test_evidence_metadata_validation.py +++ b/tests/test_evidence_metadata_validation.py @@ -1,5 +1,4 @@ import pytest - from pyrecest.evidence import EvidenceComputationMode diff --git a/tests/test_evidence_mode_validation.py b/tests/test_evidence_mode_validation.py index a9697335b..17527b1a2 100644 --- a/tests/test_evidence_mode_validation.py +++ b/tests/test_evidence_mode_validation.py @@ -1,5 +1,4 @@ import pytest - from pyrecest.evidence import resolve_evidence_computation_mode diff --git a/tests/test_jax_linalg_matrix_power.py b/tests/test_jax_linalg_matrix_power.py index 50a6ce1b9..8d17f3773 100644 --- a/tests/test_jax_linalg_matrix_power.py +++ b/tests/test_jax_linalg_matrix_power.py @@ -1,7 +1,6 @@ """Regression tests for JAX linalg static-argument normalization.""" import pytest - from tests.support.backend_runner import run_backend_code diff --git a/tests/test_multisession_scalar_cost_validation.py b/tests/test_multisession_scalar_cost_validation.py index 5e6f02558..0b8893570 100644 --- a/tests/test_multisession_scalar_cost_validation.py +++ b/tests/test_multisession_scalar_cost_validation.py @@ -23,13 +23,17 @@ def test_boolean_scalar_costs_are_rejected(self): for name, value in invalid_costs: with self.subTest(name=name): - with self.assertRaisesRegex(ValueError, f"{name} must be a finite scalar"): + with self.assertRaisesRegex( + ValueError, f"{name} must be a finite scalar" + ): solve_multisession_assignment( {}, session_sizes=[1], **{name: value}, ) - with self.assertRaisesRegex(ValueError, f"{name} must be a finite scalar"): + with self.assertRaisesRegex( + ValueError, f"{name} must be a finite scalar" + ): multisession_assignment_module.solve_multisession_assignment( {}, session_sizes=[1], diff --git a/tests/test_pytorch_close_missing_values.py b/tests/test_pytorch_close_missing_values.py index 07b9652f7..ff9a7af07 100644 --- a/tests/test_pytorch_close_missing_values.py +++ b/tests/test_pytorch_close_missing_values.py @@ -21,7 +21,9 @@ def test_isclose_accepts_equal_nan_for_raw_backend(self): left = pytorch_backend.array([1.0, float("nan"), 3.0]) right = pytorch_backend.array([1.0, float("nan"), 4.0]) - self.assertEqual(pytorch_backend.isclose(left, right).tolist(), [True, False, False]) + self.assertEqual( + pytorch_backend.isclose(left, right).tolist(), [True, False, False] + ) self.assertEqual( pytorch_backend.isclose(left, right, equal_nan=True).tolist(), [True, True, False], diff --git a/tests/test_pytorch_conj_array_like_contract.py b/tests/test_pytorch_conj_array_like_contract.py index c376f962b..87be13555 100644 --- a/tests/test_pytorch_conj_array_like_contract.py +++ b/tests/test_pytorch_conj_array_like_contract.py @@ -1,5 +1,4 @@ import pytest - from tests.support.backend_runner import run_backend_code diff --git a/tests/test_pytorch_cross_contract.py b/tests/test_pytorch_cross_contract.py index 2fe9aebf7..89611fe52 100644 --- a/tests/test_pytorch_cross_contract.py +++ b/tests/test_pytorch_cross_contract.py @@ -1,8 +1,7 @@ import numpy as np import numpy.testing as npt -import pytest - import pyrecest.backend as backend +import pytest @pytest.mark.skipif( diff --git a/tests/test_pytorch_prod_aliases.py b/tests/test_pytorch_prod_aliases.py index c8dec5a75..92a710d5b 100644 --- a/tests/test_pytorch_prod_aliases.py +++ b/tests/test_pytorch_prod_aliases.py @@ -1,6 +1,5 @@ -import pytest - import pyrecest.backend as backend +import pytest from pyrecest.backend import array diff --git a/tests/test_track_completion_text_candidate_validation.py b/tests/test_track_completion_text_candidate_validation.py index e774bb92c..d0c646fcb 100644 --- a/tests/test_track_completion_text_candidate_validation.py +++ b/tests/test_track_completion_text_candidate_validation.py @@ -18,7 +18,9 @@ def provider(*args): def _assert_candidate_rejected(candidate): try: enumerate_fragment_completion_paths( - [[0, None]], direction="suffix", candidate_provider=_provider_with(candidate) + [[0, None]], + direction="suffix", + candidate_provider=_provider_with(candidate), ) except ValueError as exc: assert "candidate observations must be non-negative integers" in str(exc) diff --git a/tests/tracking/test_hypothesis_replay.py b/tests/tracking/test_hypothesis_replay.py index db20bc8c2..500130454 100644 --- a/tests/tracking/test_hypothesis_replay.py +++ b/tests/tracking/test_hypothesis_replay.py @@ -108,7 +108,9 @@ def test_hypothesis_replay_rejects_text_count_fields() -> None: for field_name, value, message in invalid_cases: with pytest.raises(ValueError, match=message): - HypothesisReplay(hypothesis_id="bad-count", records=[], **{field_name: value}) + HypothesisReplay( + hypothesis_id="bad-count", records=[], **{field_name: value} + ) def test_rank_replayed_hypotheses_calls_replay_function() -> None: diff --git a/tests/utils/test_logistic_association_scalar_prediction.py b/tests/utils/test_logistic_association_scalar_prediction.py index 0c1aaed38..6499cc41a 100644 --- a/tests/utils/test_logistic_association_scalar_prediction.py +++ b/tests/utils/test_logistic_association_scalar_prediction.py @@ -1,6 +1,5 @@ import numpy as np import pytest - from pyrecest.utils import LogisticPairwiseAssociationModel