diff --git a/src/pyrecest/__init__.py b/src/pyrecest/__init__.py index e4347b639..9ae4ef64b 100644 --- a/src/pyrecest/__init__.py +++ b/src/pyrecest/__init__.py @@ -304,7 +304,9 @@ def _patch_pytorch_tile_facade() -> None: import pyrecest.backend as backend # pylint: disable=import-outside-toplevel - selected_backend_is_pytorch = getattr(backend, "__backend_name__", None) == "pytorch" + selected_backend_is_pytorch = ( + getattr(backend, "__backend_name__", None) == "pytorch" + ) try: import numpy as _np # pylint: disable=import-outside-toplevel diff --git a/src/pyrecest/_backend/_shared_numpy/_rng_pkg/__init__.py b/src/pyrecest/_backend/_shared_numpy/_rng_pkg/__init__.py index bdffc0a86..31095fd8a 100644 --- a/src/pyrecest/_backend/_shared_numpy/_rng_pkg/__init__.py +++ b/src/pyrecest/_backend/_shared_numpy/_rng_pkg/__init__.py @@ -1,8 +1,8 @@ """Compatibility package for shared NumPy RNG helpers.""" +import sys as _sys from importlib import util as _importlib_util from pathlib import Path as _Path -import sys as _sys import numpy as _np diff --git a/src/pyrecest/_backend/capabilities/__init__.py b/src/pyrecest/_backend/capabilities/__init__.py index c0ab40e5e..ca5279a64 100644 --- a/src/pyrecest/_backend/capabilities/__init__.py +++ b/src/pyrecest/_backend/capabilities/__init__.py @@ -74,7 +74,9 @@ def _patch_pytorch_dot_outer_device_contract() -> None: helper_names = ("dot", "outer") if all( - getattr(getattr(raw_pytorch, helper_name, None), "_pyrecest_device_contract", False) + getattr( + getattr(raw_pytorch, helper_name, None), "_pyrecest_device_contract", False + ) for helper_name in helper_names ): if getattr(backend, "__backend_name__", None) == "pytorch": diff --git a/src/pyrecest/_backend/jax/fft.py b/src/pyrecest/_backend/jax/fft.py index fb71a6dfd..fcd26b65e 100644 --- a/src/pyrecest/_backend/jax/fft.py +++ b/src/pyrecest/_backend/jax/fft.py @@ -30,11 +30,15 @@ def _normalize_real_fft_axis(axis): def rfft(a, n=None, axis=-1, norm=None): - return _fft.rfft(_jnp.asarray(a), n=n, axis=_normalize_real_fft_axis(axis), norm=norm) + return _fft.rfft( + _jnp.asarray(a), n=n, axis=_normalize_real_fft_axis(axis), norm=norm + ) def irfft(a, n=None, axis=-1, norm=None): - return _fft.irfft(_jnp.asarray(a), n=n, axis=_normalize_real_fft_axis(axis), norm=norm) + return _fft.irfft( + _jnp.asarray(a), n=n, axis=_normalize_real_fft_axis(axis), norm=norm + ) def fftn(a, s=None, axes=None, norm=None): diff --git a/src/pyrecest/_backend/pytorch/_dtype.py b/src/pyrecest/_backend/pytorch/_dtype.py index fdd0fdea1..c3b0b54ec 100644 --- a/src/pyrecest/_backend/pytorch/_dtype.py +++ b/src/pyrecest/_backend/pytorch/_dtype.py @@ -15,8 +15,8 @@ from pyrecest._backend._dtype_utils import ( get_default_dtype as _shared_get_default_dtype, ) +from torch import bool as torch_bool from torch import ( - bool as torch_bool, complex64, complex128, float32, diff --git a/src/pyrecest/_backend_submodules.py b/src/pyrecest/_backend_submodules.py index 55409b6da..ea7509735 100644 --- a/src/pyrecest/_backend_submodules.py +++ b/src/pyrecest/_backend_submodules.py @@ -573,7 +573,9 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): "(dimension must be 2 or 3)" ) - leading_shape = numpy_module.broadcast_shapes(tuple(a.shape[:-1]), tuple(b.shape[:-1])) + leading_shape = numpy_module.broadcast_shapes( + tuple(a.shape[:-1]), tuple(b.shape[:-1]) + ) if tuple(a.shape[:-1]) != leading_shape: a = torch_module.broadcast_to(a, leading_shape + (a_dim,)) if tuple(b.shape[:-1]) != leading_shape: diff --git a/src/pyrecest/_pytorch_dot_contract.py b/src/pyrecest/_pytorch_dot_contract.py index bdebd130c..3ea6337de 100644 --- a/src/pyrecest/_pytorch_dot_contract.py +++ b/src/pyrecest/_pytorch_dot_contract.py @@ -7,8 +7,8 @@ def patch_pytorch_dot_numpy_contract() -> None: """Make raw and public PyTorch ``dot`` follow NumPy's axis contract.""" try: - import pyrecest.backend as backend # pylint: disable=import-outside-toplevel import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel + import pyrecest.backend as backend # pylint: disable=import-outside-toplevel import torch # pylint: disable=import-outside-toplevel except ModuleNotFoundError: # pragma: no cover - PyTorch may be unavailable return diff --git a/src/pyrecest/backend_support/__init__.py b/src/pyrecest/backend_support/__init__.py index e11f3ec12..8318498fc 100644 --- a/src/pyrecest/backend_support/__init__.py +++ b/src/pyrecest/backend_support/__init__.py @@ -48,8 +48,8 @@ def _patch_raw_pytorch_assignment_scalar_tensor_indices() -> None: """Make raw PyTorch assignment helpers accept scalar integer tensor indices.""" try: - import pyrecest.backend as backend # pylint: disable=import-outside-toplevel import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel + import pyrecest.backend as backend # pylint: disable=import-outside-toplevel import torch as _torch # pylint: disable=import-outside-toplevel except ModuleNotFoundError: # pragma: no cover - PyTorch backend may be unavailable return @@ -226,7 +226,9 @@ def _patch_pytorch_copy_numpy_contract() -> None: try: import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel - except ModuleNotFoundError: # pragma: no cover - PyTorch backend import failed earlier + except ( + ModuleNotFoundError + ): # pragma: no cover - PyTorch backend import failed earlier return original_copy = raw_pytorch.copy @@ -258,7 +260,9 @@ def _patch_pytorch_clip_numpy_contract() -> None: try: import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel import torch # pylint: disable=import-outside-toplevel - except ModuleNotFoundError: # pragma: no cover - PyTorch backend import failed earlier + except ( + ModuleNotFoundError + ): # pragma: no cover - PyTorch backend import failed earlier return original_clip = raw_pytorch.clip @@ -351,7 +355,9 @@ def _patch_pytorch_broadcast_to_numpy_contract() -> None: import numpy as np # pylint: disable=import-outside-toplevel import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel import torch # pylint: disable=import-outside-toplevel - except ModuleNotFoundError: # pragma: no cover - PyTorch backend import failed earlier + except ( + ModuleNotFoundError + ): # pragma: no cover - PyTorch backend import failed earlier return original_broadcast_to = raw_pytorch.broadcast_to diff --git a/src/pyrecest/backend_support/_jax_random_empty_contract.py b/src/pyrecest/backend_support/_jax_random_empty_contract.py index b584a490f..00f6d88de 100644 --- a/src/pyrecest/backend_support/_jax_random_empty_contract.py +++ b/src/pyrecest/backend_support/_jax_random_empty_contract.py @@ -34,11 +34,17 @@ def _parse_randint_arguments(low, high, size, kwargs): return None return legacy_minval, legacy_maxval, size, kwargs if ( - raw_jax_random._looks_like_shape_sequence(low) # pylint: disable=protected-access + raw_jax_random._looks_like_shape_sequence( + low + ) # pylint: disable=protected-access and high is not None and size is not None - and raw_jax_random._looks_like_scalar_randint_bound(high) # pylint: disable=protected-access - and raw_jax_random._looks_like_scalar_randint_bound(size) # pylint: disable=protected-access + and raw_jax_random._looks_like_scalar_randint_bound( + high + ) # pylint: disable=protected-access + and raw_jax_random._looks_like_scalar_randint_bound( + size + ) # pylint: disable=protected-access ): return high, size, low, kwargs if high is None: @@ -55,17 +61,27 @@ def _empty_invalid_bound_result(low, high, size, args, kwargs): return None low, high, size, kwargs = parsed - low = raw_jax_random._validate_randint_bound(low, "low") # pylint: disable=protected-access - high = raw_jax_random._validate_randint_bound(high, "high") # pylint: disable=protected-access + low = raw_jax_random._validate_randint_bound( + low, "low" + ) # pylint: disable=protected-access + high = raw_jax_random._validate_randint_bound( + high, "high" + ) # pylint: disable=protected-access try: low, high = jnp.broadcast_arrays(low, high) except ValueError as exc: raise ValueError("low and high could not be broadcast together") from exc - shape = raw_jax_random._bounded_sampler_shape(size, low, high) # pylint: disable=protected-access - if not raw_jax_random._shape_has_no_samples(shape): # pylint: disable=protected-access + shape = raw_jax_random._bounded_sampler_shape( + size, low, high + ) # pylint: disable=protected-access + if not raw_jax_random._shape_has_no_samples( + shape + ): # pylint: disable=protected-access return None - state, has_state, remaining_kwargs = raw_jax_random._get_state(**kwargs) # pylint: disable=protected-access + state, has_state, remaining_kwargs = raw_jax_random._get_state( + **kwargs + ) # pylint: disable=protected-access state, result = raw_jax_random._randint( # pylint: disable=protected-access state, shape, diff --git a/src/pyrecest/backend_support/_torch_dtype_promotion_contract.py b/src/pyrecest/backend_support/_torch_dtype_promotion_contract.py index 1be1347e3..29dee2236 100644 --- a/src/pyrecest/backend_support/_torch_dtype_promotion_contract.py +++ b/src/pyrecest/backend_support/_torch_dtype_promotion_contract.py @@ -16,7 +16,9 @@ def patch_pytorch_dtype_promotion_contract() -> None: from pyrecest._backend.pytorch._common import ( # pylint: disable=import-outside-toplevel _normalize_dtype, ) - except ModuleNotFoundError: # pragma: no cover - PyTorch backend import failed earlier + except ( + ModuleNotFoundError + ): # pragma: no cover - PyTorch backend import failed earlier return _patch_pytorch_repeat_numpy_contract(raw_pytorch, torch) @@ -160,7 +162,9 @@ def _normalize_axis(axis, ndim): if axis < 0: axis += ndim if axis < 0 or axis >= ndim: - raise IndexError(f"axis {axis} is out of bounds for array of dimension {ndim}") + raise IndexError( + f"axis {axis} is out of bounds for array of dimension {ndim}" + ) return axis def _boundary(value, reference, axis): @@ -220,7 +224,10 @@ def _patch_pytorch_transpose_numpy_axes_contract(raw_pytorch, np) -> None: original_transpose = raw_pytorch.transpose if getattr(original_transpose, "_pyrecest_numpy_axes_contract", False): - if backend is not None and getattr(backend, "__backend_name__", None) == "pytorch": + if ( + backend is not None + and getattr(backend, "__backend_name__", None) == "pytorch" + ): backend.transpose = original_transpose return @@ -247,8 +254,7 @@ def _normalize_creation_shape(shape, torch, np): normalized_shape = (_operator_index(shape_array.item()),) else: normalized_shape = tuple( - _operator_index(one_dimension) - for one_dimension in shape_array.tolist() + _operator_index(one_dimension) for one_dimension in shape_array.tolist() ) if any(one_dimension < 0 for one_dimension in normalized_shape): raise ValueError("negative dimensions are not allowed") @@ -425,7 +431,9 @@ def _patch_pytorch_randint_empty_size_contract(raw_pytorch_random, torch) -> Non active_pytorch_backend = ( backend is not None and getattr(backend, "__backend_name__", None) == "pytorch" ) - backend_random = getattr(backend, "random", None) if active_pytorch_backend else None + backend_random = ( + getattr(backend, "random", None) if active_pytorch_backend else None + ) original_randint = raw_pytorch_random.randint if getattr(original_randint, "_pyrecest_empty_size_contract", False): if active_pytorch_backend and backend_random is not None: @@ -503,7 +511,9 @@ def _normalize_pad_pairs(pad_width, ndim, np): try: pad_pairs = np.broadcast_to(pad_width_array, (ndim, 2)) except ValueError as exc: - raise ValueError(f"pad_width must be broadcastable to shape ({ndim}, 2)") from exc + raise ValueError( + f"pad_width must be broadcastable to shape ({ndim}, 2)" + ) from exc if np.any(pad_pairs < 0): raise ValueError("index can't contain negative values") return tuple( @@ -525,7 +535,9 @@ def _normalize_constant_value_pairs(constant_values, ndim, np): def _filled_pad_block(shape, value, reference, torch): """Return a constant-filled block compatible with ``reference``.""" - scalar_value = torch.as_tensor(value, dtype=reference.dtype, device=reference.device) + scalar_value = torch.as_tensor( + value, dtype=reference.dtype, device=reference.device + ) if scalar_value.ndim != 0: raise ValueError("constant_values entries must be scalar") block = torch.empty(tuple(shape), dtype=reference.dtype, device=reference.device) diff --git a/src/pyrecest/backend_support/_torch_dtype_promotion_contract/__init__.py b/src/pyrecest/backend_support/_torch_dtype_promotion_contract/__init__.py index 73843254a..92f490a9c 100644 --- a/src/pyrecest/backend_support/_torch_dtype_promotion_contract/__init__.py +++ b/src/pyrecest/backend_support/_torch_dtype_promotion_contract/__init__.py @@ -7,13 +7,17 @@ def _load_base_contract_module(): - module_path = Path(__file__).resolve().parent.parent / "_torch_dtype_promotion_contract.py" + module_path = ( + Path(__file__).resolve().parent.parent / "_torch_dtype_promotion_contract.py" + ) spec = importlib.util.spec_from_file_location( "_pyrecest_torch_dtype_promotion_contract_base", module_path, ) if spec is None or spec.loader is None: - raise ImportError(f"Cannot load PyTorch dtype contract module from {module_path}") + raise ImportError( + f"Cannot load PyTorch dtype contract module from {module_path}" + ) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) return module @@ -30,7 +34,9 @@ def patch_pytorch_dtype_promotion_contract() -> None: import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel import pyrecest.backend as backend # pylint: disable=import-outside-toplevel import torch # pylint: disable=import-outside-toplevel - except ModuleNotFoundError: # pragma: no cover - PyTorch backend import failed earlier + except ( + ModuleNotFoundError + ): # pragma: no cover - PyTorch backend import failed earlier return _patch_pytorch_assignment_numpy_index_contract(raw_pytorch, backend, torch, np) @@ -67,11 +73,17 @@ def assignment(x, values, indices, axis=0): return assignment -def _patch_pytorch_assignment_numpy_index_contract(raw_pytorch, backend, torch, np) -> None: +def _patch_pytorch_assignment_numpy_index_contract( + raw_pytorch, backend, torch, np +) -> None: """Make PyTorch assignment helpers accept NumPy integer and boolean indices.""" helper_names = ("assignment", "assignment_by_sum") if all( - getattr(getattr(raw_pytorch, helper_name, None), "_pyrecest_numpy_index_contract", False) + getattr( + getattr(raw_pytorch, helper_name, None), + "_pyrecest_numpy_index_contract", + False, + ) for helper_name in helper_names ): if getattr(backend, "__backend_name__", None) == "pytorch": @@ -225,7 +237,9 @@ def _patch_pytorch_binary_device_contract(raw_pytorch, backend, torch) -> None: "power": False, } if all( - getattr(getattr(raw_pytorch, helper_name, None), "_pyrecest_device_contract", False) + getattr( + getattr(raw_pytorch, helper_name, None), "_pyrecest_device_contract", False + ) for helper_name in helpers ): if getattr(backend, "__backend_name__", None) == "pytorch": @@ -248,7 +262,9 @@ def _patch_pytorch_equality_device_contract(raw_pytorch, backend, torch) -> None """Keep equality-style helpers on an existing non-CPU tensor device.""" helper_names = ("equal", "less_equal", "array" + "_equal") if all( - getattr(getattr(raw_pytorch, helper_name, None), "_pyrecest_device_contract", False) + getattr( + getattr(raw_pytorch, helper_name, None), "_pyrecest_device_contract", False + ) for helper_name in helper_names ): if getattr(backend, "__backend_name__", None) == "pytorch": @@ -343,7 +359,9 @@ def _wrap_argsort_arraylike_helper(original_argsort, raw_pytorch, torch): if getattr(original_argsort, "_pyrecest_arraylike_contract", False): return original_argsort - def argsort(input, axis=-1, descending=False, stable=False, *, dim=None): # pylint: disable=redefined-builtin + def argsort( + input, axis=-1, descending=False, stable=False, *, dim=None + ): # pylint: disable=redefined-builtin if dim is not None: if axis != -1 and axis != dim: raise TypeError("argsort() got both 'axis' and 'dim'") @@ -375,7 +393,11 @@ def _patch_pytorch_arraylike_helper_contract(raw_pytorch, backend, torch) -> Non ) all_helper_names = (*helper_names, "argsort") if all( - getattr(getattr(raw_pytorch, helper_name, None), "_pyrecest_arraylike_contract", False) + getattr( + getattr(raw_pytorch, helper_name, None), + "_pyrecest_arraylike_contract", + False, + ) for helper_name in all_helper_names ): if getattr(backend, "__backend_name__", None) == "pytorch": @@ -393,7 +415,9 @@ def _patch_pytorch_arraylike_helper_contract(raw_pytorch, backend, torch) -> Non if getattr(backend, "__backend_name__", None) == "pytorch": setattr(backend, helper_name, wrapped_helper) - wrapped_argsort = _wrap_argsort_arraylike_helper(raw_pytorch.argsort, raw_pytorch, torch) + wrapped_argsort = _wrap_argsort_arraylike_helper( + raw_pytorch.argsort, raw_pytorch, torch + ) raw_pytorch.argsort = wrapped_argsort if getattr(backend, "__backend_name__", None) == "pytorch": backend.argsort = wrapped_argsort diff --git a/src/pyrecest/distributions/abstract_dirac_distribution.py b/src/pyrecest/distributions/abstract_dirac_distribution.py index ac838347a..dd0572d95 100644 --- a/src/pyrecest/distributions/abstract_dirac_distribution.py +++ b/src/pyrecest/distributions/abstract_dirac_distribution.py @@ -13,7 +13,9 @@ arange, argmax, asarray, - copy as backend_copy, +) +from pyrecest.backend import copy as backend_copy +from pyrecest.backend import ( int32, int64, isclose, diff --git a/src/pyrecest/distributions/hypersphere_subset/hyperspherical_uniform_distribution.py b/src/pyrecest/distributions/hypersphere_subset/hyperspherical_uniform_distribution.py index 13eebb508..e7aa7dc6a 100644 --- a/src/pyrecest/distributions/hypersphere_subset/hyperspherical_uniform_distribution.py +++ b/src/pyrecest/distributions/hypersphere_subset/hyperspherical_uniform_distribution.py @@ -11,7 +11,9 @@ linalg, ones, pi, - random as backend_random, +) +from pyrecest.backend import random as backend_random +from pyrecest.backend import ( sin, sqrt, stack, diff --git a/src/pyrecest/evaluation/get_extract_mean.py b/src/pyrecest/evaluation/get_extract_mean.py index f387b2d93..78414f45b 100644 --- a/src/pyrecest/evaluation/get_extract_mean.py +++ b/src/pyrecest/evaluation/get_extract_mean.py @@ -88,7 +88,9 @@ def _extract_mtt_mean(filter_state): def get_extract_mean(manifold_name, mtt_scenario=False): normalized_name = _normalize_registry_name(manifold_name) - is_mtt_scenario = _coerce_mtt_scenario_flag(mtt_scenario) or "mtt" in normalized_name + is_mtt_scenario = ( + _coerce_mtt_scenario_flag(mtt_scenario) or "mtt" in normalized_name + ) registered_factory = _EXTRACT_MEAN_FACTORIES.get(normalized_name) if registered_factory is not None: return registered_factory(manifold_name, is_mtt_scenario) diff --git a/src/pyrecest/filters/adaptive_process_noise.py b/src/pyrecest/filters/adaptive_process_noise.py index 25bed3bce..6383ff5d5 100644 --- a/src/pyrecest/filters/adaptive_process_noise.py +++ b/src/pyrecest/filters/adaptive_process_noise.py @@ -151,7 +151,9 @@ def observe( accepted = _normalize_bool_flag(accepted, "accepted") if not accepted: return self.ratios_by_source.get(source, 1.0) - measurement_dim = _normalize_positive_integer(measurement_dim, "measurement_dim") + measurement_dim = _normalize_positive_integer( + measurement_dim, "measurement_dim" + ) nis = _normalize_nonnegative_finite_scalar(nis, "nis") ratio = nis / float(measurement_dim) previous = self.ratios_by_source.get(source, 1.0) diff --git a/src/pyrecest/filters/wrapped_normal_filter.py b/src/pyrecest/filters/wrapped_normal_filter.py index e2de0e247..df25d8d77 100644 --- a/src/pyrecest/filters/wrapped_normal_filter.py +++ b/src/pyrecest/filters/wrapped_normal_filter.py @@ -140,7 +140,9 @@ def update_nonlinear_progressive( ) if current_lambda <= 0: - raise ValueError("Progressive update with given threshold impossible") + raise ValueError( + "Progressive update with given threshold impossible" + ) current_lambda = maximum(current_lambda, MINIMUM_LAMBDA) current_lambda = minimum(current_lambda, lambda_) diff --git a/src/pyrecest/models/additive_noise.py b/src/pyrecest/models/additive_noise.py index f5effad75..1bb028967 100644 --- a/src/pyrecest/models/additive_noise.py +++ b/src/pyrecest/models/additive_noise.py @@ -411,7 +411,9 @@ def has_jacobian(self): def measurement_residual(self, measurement, state, **kwargs): """Return ``measurement - h(state)``.""" - return _array_difference(measurement, self.measurement_function(state, **kwargs)) + return _array_difference( + measurement, self.measurement_function(state, **kwargs) + ) def sample_measurement(self, state, n: int = 1, **kwargs): """Draw ``n`` samples from ``p(measurement | state)``.""" diff --git a/src/pyrecest/models/validation.py b/src/pyrecest/models/validation.py index cd977d367..2502ba0a8 100644 --- a/src/pyrecest/models/validation.py +++ b/src/pyrecest/models/validation.py @@ -384,7 +384,9 @@ def infer_state_dim_from_distribution( if hasattr(distribution, "d"): try: - support = _maybe_call(getattr(distribution, "d"), allow_methods=allow_methods) + support = _maybe_call( + getattr(distribution, "d"), allow_methods=allow_methods + ) support = validate_matrix(support, name="distribution.d") return _shape_tuple(support)[1] except (TypeError, ValueError): diff --git a/src/pyrecest/stability.py b/src/pyrecest/stability.py index 2f3ed50c1..833b19e53 100644 --- a/src/pyrecest/stability.py +++ b/src/pyrecest/stability.py @@ -33,7 +33,11 @@ def _patch_pytorch_raw_comparison_arraylike_contract() -> None: active_pytorch_backend = getattr(backend, "__backend_name__", None) == "pytorch" helper_names = ("greater", "less", "logical_or") if all( - getattr(getattr(raw_pytorch, helper_name, None), "_pyrecest_arraylike_contract", False) + getattr( + getattr(raw_pytorch, helper_name, None), + "_pyrecest_arraylike_contract", + False, + ) for helper_name in helper_names ): if active_pytorch_backend: diff --git a/src/pyrecest/utils/__init__.py b/src/pyrecest/utils/__init__.py index 5e7e6c432..507f17f3c 100644 --- a/src/pyrecest/utils/__init__.py +++ b/src/pyrecest/utils/__init__.py @@ -163,7 +163,9 @@ def _fit_standardization(self, features): "__name__", "_fit_standardization", ) - _fit_standardization.__doc__ = getattr(original_fit_standardization, "__doc__", None) + _fit_standardization.__doc__ = getattr( + original_fit_standardization, "__doc__", None + ) _fit_standardization._pyrecest_backend_std_contract = True LogisticPairwiseAssociationModel._fit_standardization = _fit_standardization @@ -193,7 +195,9 @@ def _prepare_prediction_features(features, expected_feature_dimension): if not _backend.all(_backend.isfinite(flattened)): raise ValueError("features must be finite") return flattened, () - return original_prepare_prediction_features(features, expected_feature_dimension) + return original_prepare_prediction_features( + features, expected_feature_dimension + ) _prepare_prediction_features.__name__ = getattr( original_prepare_prediction_features, @@ -323,7 +327,9 @@ def _prepare_prediction_features(features, expected_feature_dimension): ) _multisession_assignment_module.tracks_to_session_labels = tracks_to_session_labels -_multisession_assignment_module._validate_scalar_cost = _validate_multisession_scalar_cost +_multisession_assignment_module._validate_scalar_cost = ( + _validate_multisession_scalar_cost +) __all__ = [ "MultiSessionAssignmentResult", diff --git a/src/pyrecest/utils/candidate_pruning.py b/src/pyrecest/utils/candidate_pruning.py index f602698c0..e0ab810b8 100644 --- a/src/pyrecest/utils/candidate_pruning.py +++ b/src/pyrecest/utils/candidate_pruning.py @@ -377,7 +377,9 @@ def _as_probability_matrix( if probabilities.shape != shape: raise ValueError("probability_matrix must match cost_matrix shape") if np.any(np.isinf(probabilities)): - raise ValueError("probability_matrix may only contain finite probabilities or NaN") + raise ValueError( + "probability_matrix may only contain finite probabilities or NaN" + ) finite = np.isfinite(probabilities) if np.any(finite & ((probabilities < 0.0) | (probabilities > 1.0))): raise ValueError("finite probability_matrix entries must lie in [0, 1]") diff --git a/tests/backend/test_numpy_random_size_validation.py b/tests/backend/test_numpy_random_size_validation.py index 1c47b496c..35ad26129 100644 --- a/tests/backend/test_numpy_random_size_validation.py +++ b/tests/backend/test_numpy_random_size_validation.py @@ -1,9 +1,7 @@ import numpy as np import pytest - from pyrecest._backend.numpy import random - _INVALID_SIZE_ARGUMENTS = ( True, np.bool_(True), diff --git a/tests/backend/test_numpy_random_uniform_validation.py b/tests/backend/test_numpy_random_uniform_validation.py index d9a1f88a8..a83df27ea 100644 --- a/tests/backend/test_numpy_random_uniform_validation.py +++ b/tests/backend/test_numpy_random_uniform_validation.py @@ -1,5 +1,4 @@ import pytest - from pyrecest._backend.numpy import random diff --git a/tests/backend/test_numpy_uniform_ragged_bounds.py b/tests/backend/test_numpy_uniform_ragged_bounds.py index 2fe03a215..ed70556cd 100644 --- a/tests/backend/test_numpy_uniform_ragged_bounds.py +++ b/tests/backend/test_numpy_uniform_ragged_bounds.py @@ -1,5 +1,4 @@ import pytest - from pyrecest._backend import numpy as backend diff --git a/tests/backend/test_pytorch_comparison_device_contract.py b/tests/backend/test_pytorch_comparison_device_contract.py index 906371969..205912146 100644 --- a/tests/backend/test_pytorch_comparison_device_contract.py +++ b/tests/backend/test_pytorch_comparison_device_contract.py @@ -1,10 +1,9 @@ import pytest - torch = pytest.importorskip("torch") -import pyrecest.backend_tools # noqa: E402,F401 import pyrecest._backend.pytorch as pytorch_backend # noqa: E402 +import pyrecest.backend_tools # noqa: E402,F401 def _non_cpu_device(): @@ -72,7 +71,9 @@ def test_raw_pytorch_isclose_prefers_existing_non_cpu_device_for_right_operand() def test_raw_pytorch_allclose_accepts_arraylike_against_cuda_operand(): if not torch.cuda.is_available(): - pytest.skip("allclose returns a host bool and cannot be exercised on meta tensors") + pytest.skip( + "allclose returns a host bool and cannot be exercised on meta tensors" + ) right = torch.tensor([1.0, float("nan")], device="cuda") diff --git a/tests/backend_support/conftest.py b/tests/backend_support/conftest.py index 76b59c91e..a3c77d2e7 100644 --- a/tests/backend_support/conftest.py +++ b/tests/backend_support/conftest.py @@ -1,6 +1,5 @@ from pathlib import Path - _placeholder_regression = Path(__file__).with_name( "test_pytorch_dot_outer_device_contract.py" ) diff --git a/tests/backend_support/test_common_dot_contract.py b/tests/backend_support/test_common_dot_contract.py index 72d6d5c4d..89e52db47 100644 --- a/tests/backend_support/test_common_dot_contract.py +++ b/tests/backend_support/test_common_dot_contract.py @@ -1,7 +1,6 @@ import numpy as np import numpy.testing as npt import pytest - from pyrecest._backend import _common as common diff --git a/tests/backend_support/test_jax_one_hot_contract.py b/tests/backend_support/test_jax_one_hot_contract.py index 44b47bec1..b96d4cba1 100644 --- a/tests/backend_support/test_jax_one_hot_contract.py +++ b/tests/backend_support/test_jax_one_hot_contract.py @@ -1,7 +1,6 @@ import importlib.util import pytest - from tests.support.backend_runner import run_backend_code pytestmark = pytest.mark.backend_portable diff --git a/tests/backend_support/test_jax_take_out_contract.py b/tests/backend_support/test_jax_take_out_contract.py index 36fc8ef8a..9022629b4 100644 --- a/tests/backend_support/test_jax_take_out_contract.py +++ b/tests/backend_support/test_jax_take_out_contract.py @@ -1,7 +1,6 @@ import importlib.util import pytest - from tests.support.backend_runner import run_backend_code diff --git a/tests/backend_support/test_pytorch_broadcast_to_contract.py b/tests/backend_support/test_pytorch_broadcast_to_contract.py index 924564f0b..43aed0f45 100644 --- a/tests/backend_support/test_pytorch_broadcast_to_contract.py +++ b/tests/backend_support/test_pytorch_broadcast_to_contract.py @@ -61,7 +61,9 @@ def test_pytorch_broadcast_to_shape_inputs_match_numpy_contract(): else: raise AssertionError(f"broadcast_to accepted boolean shape {invalid_shape!r}") """ - subprocess.run([sys.executable, "-c", code], check=True, env=_backend_test_env("pytorch")) + subprocess.run( + [sys.executable, "-c", code], check=True, env=_backend_test_env("pytorch") + ) @pytest.mark.backend_portable @@ -101,4 +103,6 @@ def test_raw_pytorch_broadcast_to_shape_inputs_with_numpy_public_backend(): else: raise AssertionError(f"raw broadcast_to accepted boolean shape {invalid_shape!r}") """ - subprocess.run([sys.executable, "-c", code], check=True, env=_backend_test_env("numpy")) + subprocess.run( + [sys.executable, "-c", code], check=True, env=_backend_test_env("numpy") + ) diff --git a/tests/backend_support/test_pytorch_diagonal_numpy_scalar_args.py b/tests/backend_support/test_pytorch_diagonal_numpy_scalar_args.py index a1e117efd..35cfea6e1 100644 --- a/tests/backend_support/test_pytorch_diagonal_numpy_scalar_args.py +++ b/tests/backend_support/test_pytorch_diagonal_numpy_scalar_args.py @@ -1,6 +1,5 @@ import numpy as np import pytest - from pyrecest._backend import _common as common diff --git a/tests/backend_support/test_pytorch_diff_contract.py b/tests/backend_support/test_pytorch_diff_contract.py index 3413f6a82..4f37605cb 100644 --- a/tests/backend_support/test_pytorch_diff_contract.py +++ b/tests/backend_support/test_pytorch_diff_contract.py @@ -5,7 +5,6 @@ import pytest - SCRIPT = """ import pyrecest # noqa: F401 # triggers raw-backend compatibility patches import torch diff --git a/tests/backend_support/test_pytorch_dot_outer_device_contract.py b/tests/backend_support/test_pytorch_dot_outer_device_contract.py index 580ea72aa..4d50885e3 100644 --- a/tests/backend_support/test_pytorch_dot_outer_device_contract.py +++ b/tests/backend_support/test_pytorch_dot_outer_device_contract.py @@ -1,7 +1,6 @@ import importlib.util import pytest - from tests.support.backend_runner import run_backend_code pytestmark = pytest.mark.backend_portable diff --git a/tests/backend_support/test_pytorch_equality_device_contract.py b/tests/backend_support/test_pytorch_equality_device_contract.py index 225f6a4b0..c6cab6fdc 100644 --- a/tests/backend_support/test_pytorch_equality_device_contract.py +++ b/tests/backend_support/test_pytorch_equality_device_contract.py @@ -1,7 +1,6 @@ import importlib.util import pytest - from tests.support.backend_runner import run_backend_code pytestmark = pytest.mark.backend_portable diff --git a/tests/backend_support/test_pytorch_fft_alias_numpy_arrays_contract.py b/tests/backend_support/test_pytorch_fft_alias_numpy_arrays_contract.py index b5a48a456..6efae09c0 100644 --- a/tests/backend_support/test_pytorch_fft_alias_numpy_arrays_contract.py +++ b/tests/backend_support/test_pytorch_fft_alias_numpy_arrays_contract.py @@ -57,4 +57,6 @@ def test_pytorch_fft_aliases_accept_matching_numpy_array_axes(): else: raise AssertionError("conflicting FFT axis aliases were accepted") """ - subprocess.run([sys.executable, "-c", code], check=True, env=_backend_test_env("pytorch")) + subprocess.run( + [sys.executable, "-c", code], check=True, env=_backend_test_env("pytorch") + ) diff --git a/tests/backend_support/test_pytorch_fft_scalar_axis_contract.py b/tests/backend_support/test_pytorch_fft_scalar_axis_contract.py index 96a4a374a..fe42f1bbb 100644 --- a/tests/backend_support/test_pytorch_fft_scalar_axis_contract.py +++ b/tests/backend_support/test_pytorch_fft_scalar_axis_contract.py @@ -17,7 +17,9 @@ def test_raw_pytorch_real_fft_accepts_numpy_scalar_array_axis_alias(): npt.assert_allclose(spectrum.numpy(), np.fft.rfft(vector, axis=axis)) reconstructed = pytorch_fft.irfft(spectrum, n=vector.size, axis=axis) - expected = np.fft.irfft(np.fft.rfft(vector, axis=axis), n=vector.size, axis=axis) + expected = np.fft.irfft( + np.fft.rfft(vector, axis=axis), n=vector.size, axis=axis + ) npt.assert_allclose(reconstructed.numpy(), expected) diff --git a/tests/backend_support/test_pytorch_isclose_equal_nan_contract.py b/tests/backend_support/test_pytorch_isclose_equal_nan_contract.py index e22a3e487..0e19a83f2 100644 --- a/tests/backend_support/test_pytorch_isclose_equal_nan_contract.py +++ b/tests/backend_support/test_pytorch_isclose_equal_nan_contract.py @@ -43,7 +43,9 @@ def test_public_pytorch_isclose_accepts_equal_nan_keyword_when_selected(): right = [np.nan, 1.0 + 1e-9, 2.0] result = backend.isclose(left, right, equal_nan=True) - expected = np.isclose(left, right, rtol=backend.rtol, atol=backend.atol, equal_nan=True) + expected = np.isclose( + left, right, rtol=backend.rtol, atol=backend.atol, equal_nan=True + ) npt.assert_array_equal(backend.to_numpy(result), expected) result_without_nan_match = backend.isclose(left, right, equal_nan=False) @@ -54,4 +56,6 @@ def test_public_pytorch_isclose_accepts_equal_nan_keyword_when_selected(): atol=backend.atol, equal_nan=False, ) - npt.assert_array_equal(backend.to_numpy(result_without_nan_match), expected_without_nan_match) + npt.assert_array_equal( + backend.to_numpy(result_without_nan_match), expected_without_nan_match + ) diff --git a/tests/backend_support/test_pytorch_matmul_device_contract.py b/tests/backend_support/test_pytorch_matmul_device_contract.py index 0232750c6..40a0160cb 100644 --- a/tests/backend_support/test_pytorch_matmul_device_contract.py +++ b/tests/backend_support/test_pytorch_matmul_device_contract.py @@ -1,7 +1,6 @@ import importlib.util import pytest - from tests.support.backend_runner import run_backend_code pytestmark = pytest.mark.backend_portable diff --git a/tests/backend_support/test_pytorch_minmax_device_contract.py b/tests/backend_support/test_pytorch_minmax_device_contract.py index b1dc3097a..491b278b4 100644 --- a/tests/backend_support/test_pytorch_minmax_device_contract.py +++ b/tests/backend_support/test_pytorch_minmax_device_contract.py @@ -5,7 +5,6 @@ import importlib.util import pytest - from tests.support.backend_runner import run_backend_code pytestmark = pytest.mark.backend_portable diff --git a/tests/backend_support/test_pytorch_non_native_dtype_contract.py b/tests/backend_support/test_pytorch_non_native_dtype_contract.py index a53d1d04c..393c137c7 100644 --- a/tests/backend_support/test_pytorch_non_native_dtype_contract.py +++ b/tests/backend_support/test_pytorch_non_native_dtype_contract.py @@ -1,5 +1,4 @@ import pytest - from tests.support.backend_runner import run_backend_code diff --git a/tests/backend_support/test_pytorch_pad_contract.py b/tests/backend_support/test_pytorch_pad_contract.py index 7241e5be1..af600b81e 100644 --- a/tests/backend_support/test_pytorch_pad_contract.py +++ b/tests/backend_support/test_pytorch_pad_contract.py @@ -1,7 +1,6 @@ import importlib.util import pytest - from tests.support.backend_runner import run_backend_code pytestmark = pytest.mark.backend_portable diff --git a/tests/backend_support/test_pytorch_randint_empty_size_contract.py b/tests/backend_support/test_pytorch_randint_empty_size_contract.py index 3662d09fb..f9727bb81 100644 --- a/tests/backend_support/test_pytorch_randint_empty_size_contract.py +++ b/tests/backend_support/test_pytorch_randint_empty_size_contract.py @@ -1,7 +1,6 @@ from __future__ import annotations import pytest - from tests.support.backend_runner import run_backend_code diff --git a/tests/backend_support/test_pytorch_rotation_stub_contract.py b/tests/backend_support/test_pytorch_rotation_stub_contract.py index 4223cb30a..e101d5169 100644 --- a/tests/backend_support/test_pytorch_rotation_stub_contract.py +++ b/tests/backend_support/test_pytorch_rotation_stub_contract.py @@ -5,7 +5,6 @@ import importlib.util import pytest - from pyrecest.exceptions import BackendNotSupportedError diff --git a/tests/backend_support/test_pytorch_rotation_stub_methods.py b/tests/backend_support/test_pytorch_rotation_stub_methods.py index 3901c92aa..fa3fb8af4 100644 --- a/tests/backend_support/test_pytorch_rotation_stub_methods.py +++ b/tests/backend_support/test_pytorch_rotation_stub_methods.py @@ -6,7 +6,6 @@ from pathlib import Path import pytest - from pyrecest.exceptions import BackendNotSupportedError diff --git a/tests/backend_support/test_pytorch_solve_sylvester_device.py b/tests/backend_support/test_pytorch_solve_sylvester_device.py index 5361e190e..504749749 100644 --- a/tests/backend_support/test_pytorch_solve_sylvester_device.py +++ b/tests/backend_support/test_pytorch_solve_sylvester_device.py @@ -5,7 +5,6 @@ torch: Any try: import torch - from pyrecest._backend import pytorch as pytorch_backend except ModuleNotFoundError: torch = None @@ -31,11 +30,11 @@ def test_solve_sylvester_keeps_common_dtype_for_mixed_array_like_inputs(self): self.assertEqual(result.dtype, pytorch_backend.float64) self.assertTrue(pytorch_backend.allclose(result, expected)) - @unittest.skipIf(torch is None or not torch.cuda.is_available(), "CUDA is not available") + @unittest.skipIf( + torch is None or not torch.cuda.is_available(), "CUDA is not available" + ) def test_solve_sylvester_aligns_mixed_tensor_devices(self): - a = torch.tensor( - [[2.0, 0.0], [0.0, 3.0]], dtype=torch.float32, device="cuda" - ) + a = torch.tensor([[2.0, 0.0], [0.0, 3.0]], dtype=torch.float32, device="cuda") b = torch.tensor([[2.0, 0.0], [0.0, 3.0]], dtype=torch.float32) q = torch.tensor([[8.0, 10.0], [10.0, 12.0]], dtype=torch.float64) diff --git a/tests/backend_support/test_pytorch_special_contract.py b/tests/backend_support/test_pytorch_special_contract.py index 9e422f729..420a563e4 100644 --- a/tests/backend_support/test_pytorch_special_contract.py +++ b/tests/backend_support/test_pytorch_special_contract.py @@ -29,8 +29,7 @@ def test_raw_pytorch_special_helpers_are_patched_under_default_backend(): if importlib.util.find_spec("torch") is None: pytest.skip("torch is not installed") - _run_python( - r''' + _run_python(r""" import math import pyrecest # noqa: F401 @@ -70,8 +69,7 @@ def assert_close(values, expected): assert math.isnan(pole_values[0]) assert math.isinf(pole_values[1]) and pole_values[1] > 0 assert math.isinf(pole_values[2]) and pole_values[2] < 0 -''' - ) +""") @pytest.mark.backend_portable @@ -80,7 +78,7 @@ def test_public_pytorch_special_helpers_accept_arraylike_and_out(): pytest.skip("torch is not installed") _run_python( - r''' + r""" import math import pyrecest.backend as backend @@ -116,6 +114,6 @@ def assert_close(module, values, expected): special_backend.polygamma(1, [1.0, 2.0]), [math.pi**2 / 6.0, math.pi**2 / 6.0 - 1.0], ) -''', +""", backend_name="pytorch", ) diff --git a/tests/backend_support/test_pytorch_tile_contract.py b/tests/backend_support/test_pytorch_tile_contract.py index 23fb45e66..ba3bab2c1 100644 --- a/tests/backend_support/test_pytorch_tile_contract.py +++ b/tests/backend_support/test_pytorch_tile_contract.py @@ -54,7 +54,9 @@ def test_pytorch_tile_scalar_and_array_repetitions_match_numpy_contract(): else: raise AssertionError(f"tile accepted non-integer repetitions {bad_reps!r}") """ - subprocess.run([sys.executable, "-c", code], check=True, env=_backend_test_env("pytorch")) + subprocess.run( + [sys.executable, "-c", code], check=True, env=_backend_test_env("pytorch") + ) @pytest.mark.backend_portable @@ -90,4 +92,6 @@ def test_raw_pytorch_tile_matches_numpy_contract_with_numpy_public_backend(): else: raise AssertionError(f"raw tile accepted non-integer repetitions {bad_reps!r}") """ - subprocess.run([sys.executable, "-c", code], check=True, env=_backend_test_env("numpy")) + subprocess.run( + [sys.executable, "-c", code], check=True, env=_backend_test_env("numpy") + ) diff --git a/tests/calibration/test_time_offset_real_numeric_validation.py b/tests/calibration/test_time_offset_real_numeric_validation.py index 5eebd8208..5b8a73b05 100644 --- a/tests/calibration/test_time_offset_real_numeric_validation.py +++ b/tests/calibration/test_time_offset_real_numeric_validation.py @@ -1,7 +1,6 @@ import unittest import numpy as np - from pyrecest.calibration import apply_time_offset, make_offset_grid diff --git a/tests/distributions/test_abstract_mixture_sample_validation.py b/tests/distributions/test_abstract_mixture_sample_validation.py index 889ec04e2..fac1ab9f6 100644 --- a/tests/distributions/test_abstract_mixture_sample_validation.py +++ b/tests/distributions/test_abstract_mixture_sample_validation.py @@ -1,15 +1,18 @@ import numpy as np import pytest - from pyrecest.backend import array, eye from pyrecest.distributions.hypertorus.hypertoroidal_mixture import HypertoroidalMixture -from pyrecest.distributions.hypertorus.toroidal_wrapped_normal_distribution import ToroidalWrappedNormalDistribution +from pyrecest.distributions.hypertorus.toroidal_wrapped_normal_distribution import ( + ToroidalWrappedNormalDistribution, +) @pytest.mark.parametrize("count", ["3", np.str_("3")]) def test_mixture_sample_rejects_text_count(count): vmf = ToroidalWrappedNormalDistribution(array([1.0, 0.0]), eye(2)) - mixture = HypertoroidalMixture([vmf, vmf.shift(array([1.0, 1.0]))], array([0.5, 0.5])) + mixture = HypertoroidalMixture( + [vmf, vmf.shift(array([1.0, 1.0]))], array([0.5, 0.5]) + ) with pytest.raises(ValueError, match="n must be a positive integer"): mixture.sample(count) diff --git a/tests/distributions/test_dirac_tensor_copy_contract.py b/tests/distributions/test_dirac_tensor_copy_contract.py index ae746ae58..4b8668c62 100644 --- a/tests/distributions/test_dirac_tensor_copy_contract.py +++ b/tests/distributions/test_dirac_tensor_copy_contract.py @@ -39,4 +39,6 @@ def test_pytorch_dirac_distribution_clones_input_tensor_storage(): assert dist.d.data_ptr() != samples.data_ptr() assert dist.w.data_ptr() != weights.data_ptr() """ - subprocess.run([sys.executable, "-c", code], check=True, env=_backend_test_env("pytorch")) + subprocess.run( + [sys.executable, "-c", code], check=True, env=_backend_test_env("pytorch") + ) diff --git a/tests/distributions/test_gaussian_distribution_sample_validation.py b/tests/distributions/test_gaussian_distribution_sample_validation.py index 8a8f8864d..0d91f54fc 100644 --- a/tests/distributions/test_gaussian_distribution_sample_validation.py +++ b/tests/distributions/test_gaussian_distribution_sample_validation.py @@ -1,8 +1,9 @@ import numpy as np import pytest - from pyrecest.backend import array, eye -from pyrecest.distributions.nonperiodic.gaussian_distribution import GaussianDistribution +from pyrecest.distributions.nonperiodic.gaussian_distribution import ( + GaussianDistribution, +) @pytest.mark.parametrize("count", ["3", b"3", np.str_("3"), np.bytes_(b"3")]) diff --git a/tests/distributions/test_uniform_order_validation.py b/tests/distributions/test_uniform_order_validation.py index 8c3d08623..2964121d8 100644 --- a/tests/distributions/test_uniform_order_validation.py +++ b/tests/distributions/test_uniform_order_validation.py @@ -1,7 +1,8 @@ import pytest - from pyrecest.backend import array -from pyrecest.distributions.hypertorus.hypertoroidal_uniform_distribution import HypertoroidalUniformDistribution +from pyrecest.distributions.hypertorus.hypertoroidal_uniform_distribution import ( + HypertoroidalUniformDistribution, +) def test_uniform_moment_rejects_invalid_order(): diff --git a/tests/evaluation/test_flat_empty_measurements.py b/tests/evaluation/test_flat_empty_measurements.py index 530b5cc48..c6372be4e 100644 --- a/tests/evaluation/test_flat_empty_measurements.py +++ b/tests/evaluation/test_flat_empty_measurements.py @@ -1,6 +1,5 @@ import numpy as np import numpy.testing as npt - from pyrecest.backend import array from pyrecest.evaluation import perform_predict_update_cycles from pyrecest.evaluation.configure_for_filter import register_filter_factory diff --git a/tests/evaluation/test_get_extract_mean_mtt_flag.py b/tests/evaluation/test_get_extract_mean_mtt_flag.py index e796e019d..b0b4e389f 100644 --- a/tests/evaluation/test_get_extract_mean_mtt_flag.py +++ b/tests/evaluation/test_get_extract_mean_mtt_flag.py @@ -1,5 +1,4 @@ import pytest - from pyrecest.evaluation.get_extract_mean import get_extract_mean diff --git a/tests/evaluation/test_iterate_configs_vector_parameter_shape.py b/tests/evaluation/test_iterate_configs_vector_parameter_shape.py index 9814dfff9..e8b0b0768 100644 --- a/tests/evaluation/test_iterate_configs_vector_parameter_shape.py +++ b/tests/evaluation/test_iterate_configs_vector_parameter_shape.py @@ -9,7 +9,9 @@ def test_iterate_configs_and_runs_uses_config_count_for_vector_parameters(monkey if backend.__backend_name__ not in ("numpy", "autograd"): pytest.skip("iterate_configs_and_runs stores object-valued filter states") - iterate_module = importlib.import_module("pyrecest.evaluation.iterate_configs_and_runs") + iterate_module = importlib.import_module( + "pyrecest.evaluation.iterate_configs_and_runs" + ) vector_parameter = np.array([1.0, 2.0]) calls = [] @@ -46,12 +48,14 @@ def fake_predict_update_cycles( "auto_warning_on_off": False, } - last_filter_states, runtimes, run_failed, *_ = iterate_module.iterate_configs_and_runs( - groundtruths, - measurements, - {"name": "dummy"}, - [{"name": "dummy_filter", "parameter": vector_parameter}], - evaluation_config, + last_filter_states, runtimes, run_failed, *_ = ( + iterate_module.iterate_configs_and_runs( + groundtruths, + measurements, + {"name": "dummy"}, + [{"name": "dummy_filter", "parameter": vector_parameter}], + evaluation_config, + ) ) assert np.shape(last_filter_states) == (1, 2) diff --git a/tests/evaluation/test_zero_measurements.py b/tests/evaluation/test_zero_measurements.py index 147ca1a39..8e2072953 100644 --- a/tests/evaluation/test_zero_measurements.py +++ b/tests/evaluation/test_zero_measurements.py @@ -1,5 +1,4 @@ import numpy as np - from pyrecest.distributions import GaussianDistribution from pyrecest.evaluation.generate_measurements import generate_measurements diff --git a/tests/filters/test_hypertoroidal_particle_filter_numpy_scalars.py b/tests/filters/test_hypertoroidal_particle_filter_numpy_scalars.py index 4a85c07ac..746ed25db 100644 --- a/tests/filters/test_hypertoroidal_particle_filter_numpy_scalars.py +++ b/tests/filters/test_hypertoroidal_particle_filter_numpy_scalars.py @@ -1,6 +1,5 @@ import numpy as np import pytest - from pyrecest.filters import HypertoroidalParticleFilter diff --git a/tests/filters/test_wrapped_normal_filter_constant_likelihood.py b/tests/filters/test_wrapped_normal_filter_constant_likelihood.py index dac5dbc0e..080a2ebc0 100644 --- a/tests/filters/test_wrapped_normal_filter_constant_likelihood.py +++ b/tests/filters/test_wrapped_normal_filter_constant_likelihood.py @@ -1,5 +1,4 @@ import numpy.testing as npt - from pyrecest.backend import array from pyrecest.distributions import WrappedNormalDistribution from pyrecest.filters.wrapped_normal_filter import WrappedNormalFilter diff --git a/tests/test_common_reduction_axis_scalar.py b/tests/test_common_reduction_axis_scalar.py index 5be0adc2a..2ad9d4fd1 100644 --- a/tests/test_common_reduction_axis_scalar.py +++ b/tests/test_common_reduction_axis_scalar.py @@ -1,5 +1,4 @@ import numpy as np - from pyrecest._backend import _common diff --git a/tests/test_common_reduction_axis_validation.py b/tests/test_common_reduction_axis_validation.py index a4b3b9d78..b76512df7 100644 --- a/tests/test_common_reduction_axis_validation.py +++ b/tests/test_common_reduction_axis_validation.py @@ -1,7 +1,6 @@ import importlib.util import pytest - from pyrecest._backend import _common diff --git a/tests/test_common_size_axis_validation.py b/tests/test_common_size_axis_validation.py index 30b3dc624..861e9e5e9 100644 --- a/tests/test_common_size_axis_validation.py +++ b/tests/test_common_size_axis_validation.py @@ -1,6 +1,5 @@ import numpy as np import pytest - from pyrecest._backend import _common diff --git a/tests/test_evidence_mode_validation.py b/tests/test_evidence_mode_validation.py index a9697335b..17527b1a2 100644 --- a/tests/test_evidence_mode_validation.py +++ b/tests/test_evidence_mode_validation.py @@ -1,5 +1,4 @@ import pytest - from pyrecest.evidence import resolve_evidence_computation_mode diff --git a/tests/test_jax_linalg_matrix_power.py b/tests/test_jax_linalg_matrix_power.py index 50a6ce1b9..8d17f3773 100644 --- a/tests/test_jax_linalg_matrix_power.py +++ b/tests/test_jax_linalg_matrix_power.py @@ -1,7 +1,6 @@ """Regression tests for JAX linalg static-argument normalization.""" import pytest - from tests.support.backend_runner import run_backend_code diff --git a/tests/test_multisession_scalar_cost_validation.py b/tests/test_multisession_scalar_cost_validation.py index 5e6f02558..0b8893570 100644 --- a/tests/test_multisession_scalar_cost_validation.py +++ b/tests/test_multisession_scalar_cost_validation.py @@ -23,13 +23,17 @@ def test_boolean_scalar_costs_are_rejected(self): for name, value in invalid_costs: with self.subTest(name=name): - with self.assertRaisesRegex(ValueError, f"{name} must be a finite scalar"): + with self.assertRaisesRegex( + ValueError, f"{name} must be a finite scalar" + ): solve_multisession_assignment( {}, session_sizes=[1], **{name: value}, ) - with self.assertRaisesRegex(ValueError, f"{name} must be a finite scalar"): + with self.assertRaisesRegex( + ValueError, f"{name} must be a finite scalar" + ): multisession_assignment_module.solve_multisession_assignment( {}, session_sizes=[1], diff --git a/tests/test_pytorch_close_missing_values.py b/tests/test_pytorch_close_missing_values.py index 07b9652f7..ff9a7af07 100644 --- a/tests/test_pytorch_close_missing_values.py +++ b/tests/test_pytorch_close_missing_values.py @@ -21,7 +21,9 @@ def test_isclose_accepts_equal_nan_for_raw_backend(self): left = pytorch_backend.array([1.0, float("nan"), 3.0]) right = pytorch_backend.array([1.0, float("nan"), 4.0]) - self.assertEqual(pytorch_backend.isclose(left, right).tolist(), [True, False, False]) + self.assertEqual( + pytorch_backend.isclose(left, right).tolist(), [True, False, False] + ) self.assertEqual( pytorch_backend.isclose(left, right, equal_nan=True).tolist(), [True, True, False], diff --git a/tests/test_pytorch_conj_array_like_contract.py b/tests/test_pytorch_conj_array_like_contract.py index c376f962b..87be13555 100644 --- a/tests/test_pytorch_conj_array_like_contract.py +++ b/tests/test_pytorch_conj_array_like_contract.py @@ -1,5 +1,4 @@ import pytest - from tests.support.backend_runner import run_backend_code diff --git a/tests/test_pytorch_cross_contract.py b/tests/test_pytorch_cross_contract.py index 2fe9aebf7..89611fe52 100644 --- a/tests/test_pytorch_cross_contract.py +++ b/tests/test_pytorch_cross_contract.py @@ -1,8 +1,7 @@ import numpy as np import numpy.testing as npt -import pytest - import pyrecest.backend as backend +import pytest @pytest.mark.skipif( diff --git a/tests/test_track_completion_text_candidate_validation.py b/tests/test_track_completion_text_candidate_validation.py index e774bb92c..d0c646fcb 100644 --- a/tests/test_track_completion_text_candidate_validation.py +++ b/tests/test_track_completion_text_candidate_validation.py @@ -18,7 +18,9 @@ def provider(*args): def _assert_candidate_rejected(candidate): try: enumerate_fragment_completion_paths( - [[0, None]], direction="suffix", candidate_provider=_provider_with(candidate) + [[0, None]], + direction="suffix", + candidate_provider=_provider_with(candidate), ) except ValueError as exc: assert "candidate observations must be non-negative integers" in str(exc) diff --git a/tests/utils/test_logistic_association_scalar_prediction.py b/tests/utils/test_logistic_association_scalar_prediction.py index 0c1aaed38..6499cc41a 100644 --- a/tests/utils/test_logistic_association_scalar_prediction.py +++ b/tests/utils/test_logistic_association_scalar_prediction.py @@ -1,6 +1,5 @@ import numpy as np import pytest - from pyrecest.utils import LogisticPairwiseAssociationModel