Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/pyrecest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,9 @@ def _patch_pytorch_tile_facade() -> None:

import pyrecest.backend as backend # pylint: disable=import-outside-toplevel

selected_backend_is_pytorch = getattr(backend, "__backend_name__", None) == "pytorch"
selected_backend_is_pytorch = (
getattr(backend, "__backend_name__", None) == "pytorch"
)

try:
import numpy as _np # pylint: disable=import-outside-toplevel
Expand Down
2 changes: 1 addition & 1 deletion src/pyrecest/_backend/_shared_numpy/_rng_pkg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Compatibility package for shared NumPy RNG helpers."""

import sys as _sys
from importlib import util as _importlib_util
from pathlib import Path as _Path
import sys as _sys

import numpy as _np

Expand Down
8 changes: 6 additions & 2 deletions src/pyrecest/_backend/jax/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,15 @@ def _normalize_real_fft_axis(axis):


def rfft(a, n=None, axis=-1, norm=None):
return _fft.rfft(_jnp.asarray(a), n=n, axis=_normalize_real_fft_axis(axis), norm=norm)
return _fft.rfft(
_jnp.asarray(a), n=n, axis=_normalize_real_fft_axis(axis), norm=norm
)


def irfft(a, n=None, axis=-1, norm=None):
return _fft.irfft(_jnp.asarray(a), n=n, axis=_normalize_real_fft_axis(axis), norm=norm)
return _fft.irfft(
_jnp.asarray(a), n=n, axis=_normalize_real_fft_axis(axis), norm=norm
)


def fftn(a, s=None, axes=None, norm=None):
Expand Down
2 changes: 1 addition & 1 deletion src/pyrecest/_backend/pytorch/_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from pyrecest._backend._dtype_utils import (
get_default_dtype as _shared_get_default_dtype,
)
from torch import bool as torch_bool
from torch import (
bool as torch_bool,
complex64,
complex128,
float32,
Expand Down
4 changes: 3 additions & 1 deletion src/pyrecest/_backend_submodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,9 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
"(dimension must be 2 or 3)"
)

leading_shape = numpy_module.broadcast_shapes(tuple(a.shape[:-1]), tuple(b.shape[:-1]))
leading_shape = numpy_module.broadcast_shapes(
tuple(a.shape[:-1]), tuple(b.shape[:-1])
)
if tuple(a.shape[:-1]) != leading_shape:
a = torch_module.broadcast_to(a, leading_shape + (a_dim,))
if tuple(b.shape[:-1]) != leading_shape:
Expand Down
14 changes: 10 additions & 4 deletions src/pyrecest/backend_support/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def _patch_raw_pytorch_assignment_scalar_tensor_indices() -> None:
"""Make raw PyTorch assignment helpers accept scalar integer tensor indices."""

try:
import pyrecest.backend as backend # pylint: disable=import-outside-toplevel
import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel
import pyrecest.backend as backend # pylint: disable=import-outside-toplevel
import torch as _torch # pylint: disable=import-outside-toplevel
except ModuleNotFoundError: # pragma: no cover - PyTorch backend may be unavailable
return
Expand Down Expand Up @@ -226,7 +226,9 @@ def _patch_pytorch_copy_numpy_contract() -> None:

try:
import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel
except ModuleNotFoundError: # pragma: no cover - PyTorch backend import failed earlier
except (
ModuleNotFoundError
): # pragma: no cover - PyTorch backend import failed earlier
return

original_copy = raw_pytorch.copy
Expand Down Expand Up @@ -258,7 +260,9 @@ def _patch_pytorch_clip_numpy_contract() -> None:
try:
import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel
import torch # pylint: disable=import-outside-toplevel
except ModuleNotFoundError: # pragma: no cover - PyTorch backend import failed earlier
except (
ModuleNotFoundError
): # pragma: no cover - PyTorch backend import failed earlier
return

original_clip = raw_pytorch.clip
Expand Down Expand Up @@ -351,7 +355,9 @@ def _patch_pytorch_broadcast_to_numpy_contract() -> None:
import numpy as np # pylint: disable=import-outside-toplevel
import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel
import torch # pylint: disable=import-outside-toplevel
except ModuleNotFoundError: # pragma: no cover - PyTorch backend import failed earlier
except (
ModuleNotFoundError
): # pragma: no cover - PyTorch backend import failed earlier
return

original_broadcast_to = raw_pytorch.broadcast_to
Expand Down
32 changes: 24 additions & 8 deletions src/pyrecest/backend_support/_jax_random_empty_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,17 @@ def _parse_randint_arguments(low, high, size, kwargs):
return None
return legacy_minval, legacy_maxval, size, kwargs
if (
raw_jax_random._looks_like_shape_sequence(low) # pylint: disable=protected-access
raw_jax_random._looks_like_shape_sequence(
low
) # pylint: disable=protected-access
and high is not None
and size is not None
and raw_jax_random._looks_like_scalar_randint_bound(high) # pylint: disable=protected-access
and raw_jax_random._looks_like_scalar_randint_bound(size) # pylint: disable=protected-access
and raw_jax_random._looks_like_scalar_randint_bound(
high
) # pylint: disable=protected-access
and raw_jax_random._looks_like_scalar_randint_bound(
size
) # pylint: disable=protected-access
):
return high, size, low, kwargs
if high is None:
Expand All @@ -55,17 +61,27 @@ def _empty_invalid_bound_result(low, high, size, args, kwargs):
return None
low, high, size, kwargs = parsed

low = raw_jax_random._validate_randint_bound(low, "low") # pylint: disable=protected-access
high = raw_jax_random._validate_randint_bound(high, "high") # pylint: disable=protected-access
low = raw_jax_random._validate_randint_bound(
low, "low"
) # pylint: disable=protected-access
high = raw_jax_random._validate_randint_bound(
high, "high"
) # pylint: disable=protected-access
try:
low, high = jnp.broadcast_arrays(low, high)
except ValueError as exc:
raise ValueError("low and high could not be broadcast together") from exc

shape = raw_jax_random._bounded_sampler_shape(size, low, high) # pylint: disable=protected-access
if not raw_jax_random._shape_has_no_samples(shape): # pylint: disable=protected-access
shape = raw_jax_random._bounded_sampler_shape(
size, low, high
) # pylint: disable=protected-access
if not raw_jax_random._shape_has_no_samples(
shape
): # pylint: disable=protected-access
return None
state, has_state, remaining_kwargs = raw_jax_random._get_state(**kwargs) # pylint: disable=protected-access
state, has_state, remaining_kwargs = raw_jax_random._get_state(
**kwargs
) # pylint: disable=protected-access
state, result = raw_jax_random._randint( # pylint: disable=protected-access
state,
shape,
Expand Down
28 changes: 20 additions & 8 deletions src/pyrecest/backend_support/_torch_dtype_promotion_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ def patch_pytorch_dtype_promotion_contract() -> None:
from pyrecest._backend.pytorch._common import ( # pylint: disable=import-outside-toplevel
_normalize_dtype,
)
except ModuleNotFoundError: # pragma: no cover - PyTorch backend import failed earlier
except (
ModuleNotFoundError
): # pragma: no cover - PyTorch backend import failed earlier
return

_patch_pytorch_repeat_numpy_contract(raw_pytorch, torch)
Expand Down Expand Up @@ -160,7 +162,9 @@ def _normalize_axis(axis, ndim):
if axis < 0:
axis += ndim
if axis < 0 or axis >= ndim:
raise IndexError(f"axis {axis} is out of bounds for array of dimension {ndim}")
raise IndexError(
f"axis {axis} is out of bounds for array of dimension {ndim}"
)
return axis

def _boundary(value, reference, axis):
Expand Down Expand Up @@ -220,7 +224,10 @@ def _patch_pytorch_transpose_numpy_axes_contract(raw_pytorch, np) -> None:

original_transpose = raw_pytorch.transpose
if getattr(original_transpose, "_pyrecest_numpy_axes_contract", False):
if backend is not None and getattr(backend, "__backend_name__", None) == "pytorch":
if (
backend is not None
and getattr(backend, "__backend_name__", None) == "pytorch"
):
backend.transpose = original_transpose
return

Expand All @@ -247,8 +254,7 @@ def _normalize_creation_shape(shape, torch, np):
normalized_shape = (_operator_index(shape_array.item()),)
else:
normalized_shape = tuple(
_operator_index(one_dimension)
for one_dimension in shape_array.tolist()
_operator_index(one_dimension) for one_dimension in shape_array.tolist()
)
if any(one_dimension < 0 for one_dimension in normalized_shape):
raise ValueError("negative dimensions are not allowed")
Expand Down Expand Up @@ -425,7 +431,9 @@ def _patch_pytorch_randint_empty_size_contract(raw_pytorch_random, torch) -> Non
active_pytorch_backend = (
backend is not None and getattr(backend, "__backend_name__", None) == "pytorch"
)
backend_random = getattr(backend, "random", None) if active_pytorch_backend else None
backend_random = (
getattr(backend, "random", None) if active_pytorch_backend else None
)
original_randint = raw_pytorch_random.randint
if getattr(original_randint, "_pyrecest_empty_size_contract", False):
if active_pytorch_backend and backend_random is not None:
Expand Down Expand Up @@ -503,7 +511,9 @@ def _normalize_pad_pairs(pad_width, ndim, np):
try:
pad_pairs = np.broadcast_to(pad_width_array, (ndim, 2))
except ValueError as exc:
raise ValueError(f"pad_width must be broadcastable to shape ({ndim}, 2)") from exc
raise ValueError(
f"pad_width must be broadcastable to shape ({ndim}, 2)"
) from exc
if np.any(pad_pairs < 0):
raise ValueError("index can't contain negative values")
return tuple(
Expand All @@ -525,7 +535,9 @@ def _normalize_constant_value_pairs(constant_values, ndim, np):

def _filled_pad_block(shape, value, reference, torch):
"""Return a constant-filled block compatible with ``reference``."""
scalar_value = torch.as_tensor(value, dtype=reference.dtype, device=reference.device)
scalar_value = torch.as_tensor(
value, dtype=reference.dtype, device=reference.device
)
if scalar_value.ndim != 0:
raise ValueError("constant_values entries must be scalar")
block = torch.empty(tuple(shape), dtype=reference.dtype, device=reference.device)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@


def _load_base_contract_module():
module_path = Path(__file__).resolve().parent.parent / "_torch_dtype_promotion_contract.py"
module_path = (
Path(__file__).resolve().parent.parent / "_torch_dtype_promotion_contract.py"
)
spec = importlib.util.spec_from_file_location(
"_pyrecest_torch_dtype_promotion_contract_base",
module_path,
)
if spec is None or spec.loader is None:
raise ImportError(f"Cannot load PyTorch dtype contract module from {module_path}")
raise ImportError(
f"Cannot load PyTorch dtype contract module from {module_path}"
)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
Expand All @@ -30,7 +34,9 @@ def patch_pytorch_dtype_promotion_contract() -> None:
import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel
import pyrecest.backend as backend # pylint: disable=import-outside-toplevel
import torch # pylint: disable=import-outside-toplevel
except ModuleNotFoundError: # pragma: no cover - PyTorch backend import failed earlier
except (
ModuleNotFoundError
): # pragma: no cover - PyTorch backend import failed earlier
return

_patch_pytorch_assignment_numpy_index_contract(raw_pytorch, backend, torch, np)
Expand Down Expand Up @@ -66,11 +72,17 @@ def assignment(x, values, indices, axis=0):
return assignment


def _patch_pytorch_assignment_numpy_index_contract(raw_pytorch, backend, torch, np) -> None:
def _patch_pytorch_assignment_numpy_index_contract(
raw_pytorch, backend, torch, np
) -> None:
"""Make PyTorch assignment helpers accept NumPy integer and boolean indices."""
helper_names = ("assignment", "assignment_by_sum")
if all(
getattr(getattr(raw_pytorch, helper_name, None), "_pyrecest_numpy_index_contract", False)
getattr(
getattr(raw_pytorch, helper_name, None),
"_pyrecest_numpy_index_contract",
False,
)
for helper_name in helper_names
):
if getattr(backend, "__backend_name__", None) == "pytorch":
Expand Down Expand Up @@ -224,7 +236,9 @@ def _patch_pytorch_binary_device_contract(raw_pytorch, backend, torch) -> None:
"power": False,
}
if all(
getattr(getattr(raw_pytorch, helper_name, None), "_pyrecest_device_contract", False)
getattr(
getattr(raw_pytorch, helper_name, None), "_pyrecest_device_contract", False
)
for helper_name in helpers
):
if getattr(backend, "__backend_name__", None) == "pytorch":
Expand All @@ -247,7 +261,9 @@ def _patch_pytorch_equality_device_contract(raw_pytorch, backend, torch) -> None
"""Keep equality-style helpers on an existing non-CPU tensor device."""
helper_names = ("equal", "less_equal", "array" + "_equal")
if all(
getattr(getattr(raw_pytorch, helper_name, None), "_pyrecest_device_contract", False)
getattr(
getattr(raw_pytorch, helper_name, None), "_pyrecest_device_contract", False
)
for helper_name in helper_names
):
if getattr(backend, "__backend_name__", None) == "pytorch":
Expand Down
4 changes: 3 additions & 1 deletion src/pyrecest/distributions/abstract_dirac_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
arange,
argmax,
asarray,
copy as backend_copy,
)
from pyrecest.backend import copy as backend_copy
from pyrecest.backend import (
int32,
int64,
isclose,
Expand Down
4 changes: 3 additions & 1 deletion src/pyrecest/filters/adaptive_process_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ def observe(
accepted = _normalize_bool_flag(accepted, "accepted")
if not accepted:
return self.ratios_by_source.get(source, 1.0)
measurement_dim = _normalize_positive_integer(measurement_dim, "measurement_dim")
measurement_dim = _normalize_positive_integer(
measurement_dim, "measurement_dim"
)
nis = _normalize_nonnegative_finite_scalar(nis, "nis")
ratio = nis / float(measurement_dim)
previous = self.ratios_by_source.get(source, 1.0)
Expand Down
4 changes: 3 additions & 1 deletion src/pyrecest/filters/wrapped_normal_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,9 @@ def update_nonlinear_progressive(
)

if current_lambda <= 0:
raise ValueError("Progressive update with given threshold impossible")
raise ValueError(
"Progressive update with given threshold impossible"
)

current_lambda = maximum(current_lambda, MINIMUM_LAMBDA)
current_lambda = minimum(current_lambda, lambda_)
Expand Down
4 changes: 3 additions & 1 deletion src/pyrecest/models/additive_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,9 @@ def has_jacobian(self):

def measurement_residual(self, measurement, state, **kwargs):
"""Return ``measurement - h(state)``."""
return _array_difference(measurement, self.measurement_function(state, **kwargs))
return _array_difference(
measurement, self.measurement_function(state, **kwargs)
)

def sample_measurement(self, state, n: int = 1, **kwargs):
"""Draw ``n`` samples from ``p(measurement | state)``."""
Expand Down
4 changes: 3 additions & 1 deletion src/pyrecest/models/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,9 @@ def infer_state_dim_from_distribution(

if hasattr(distribution, "d"):
try:
support = _maybe_call(getattr(distribution, "d"), allow_methods=allow_methods)
support = _maybe_call(
getattr(distribution, "d"), allow_methods=allow_methods
)
support = validate_matrix(support, name="distribution.d")
return _shape_tuple(support)[1]
except (TypeError, ValueError):
Expand Down
8 changes: 6 additions & 2 deletions src/pyrecest/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,9 @@ def _fit_standardization(self, features):
"__name__",
"_fit_standardization",
)
_fit_standardization.__doc__ = getattr(original_fit_standardization, "__doc__", None)
_fit_standardization.__doc__ = getattr(
original_fit_standardization, "__doc__", None
)
_fit_standardization._pyrecest_backend_std_contract = True
LogisticPairwiseAssociationModel._fit_standardization = _fit_standardization

Expand Down Expand Up @@ -279,7 +281,9 @@ def _fit_standardization(self, features):
)

_multisession_assignment_module.tracks_to_session_labels = tracks_to_session_labels
_multisession_assignment_module._validate_scalar_cost = _validate_multisession_scalar_cost
_multisession_assignment_module._validate_scalar_cost = (
_validate_multisession_scalar_cost
)

__all__ = [
"MultiSessionAssignmentResult",
Expand Down
4 changes: 3 additions & 1 deletion src/pyrecest/utils/candidate_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,9 @@ def _as_probability_matrix(
if probabilities.shape != shape:
raise ValueError("probability_matrix must match cost_matrix shape")
if np.any(np.isinf(probabilities)):
raise ValueError("probability_matrix may only contain finite probabilities or NaN")
raise ValueError(
"probability_matrix may only contain finite probabilities or NaN"
)
finite = np.isfinite(probabilities)
if np.any(finite & ((probabilities < 0.0) | (probabilities > 1.0))):
raise ValueError("finite probability_matrix entries must lie in [0, 1]")
Expand Down
Loading