diff --git a/src/pyrecest/filters/gaussian_hypothesis_mixture.py b/src/pyrecest/filters/gaussian_hypothesis_mixture.py index 55d5c70f3..30004bde5 100644 --- a/src/pyrecest/filters/gaussian_hypothesis_mixture.py +++ b/src/pyrecest/filters/gaussian_hypothesis_mixture.py @@ -7,8 +7,18 @@ import numpy as np -_TEXT_OR_BOOL_KINDS = {"b", "S", "U"} -_TEXT_OR_BOOL_SCALAR_TYPES = (bool, np.bool_, str, bytes, bytearray, np.str_, np.bytes_) +_INVALID_FLOAT_ARRAY_KINDS = {"b", "S", "U", "c"} +_INVALID_FLOAT_ARRAY_SCALAR_TYPES = ( + bool, + np.bool_, + str, + bytes, + bytearray, + np.str_, + np.bytes_, + complex, + np.complexfloating, +) @dataclass(frozen=True) @@ -91,16 +101,16 @@ def _as_float_array(value: Any, name: str) -> np.ndarray: try: array = np.asarray(value) except (TypeError, ValueError, OverflowError) as exc: - raise ValueError(f"{name} must contain numeric values") from exc - if array.dtype.kind in _TEXT_OR_BOOL_KINDS or ( + raise ValueError(f"{name} must contain real numeric values") from exc + if array.dtype.kind in _INVALID_FLOAT_ARRAY_KINDS or ( array.dtype.kind == "O" - and any(isinstance(item, _TEXT_OR_BOOL_SCALAR_TYPES) for item in array.flat) + and any(isinstance(item, _INVALID_FLOAT_ARRAY_SCALAR_TYPES) for item in array.flat) ): - raise ValueError(f"{name} must contain numeric values") + raise ValueError(f"{name} must contain real numeric values") try: return array.astype(float, copy=False) except (TypeError, ValueError, OverflowError) as exc: - raise ValueError(f"{name} must contain numeric values") from exc + raise ValueError(f"{name} must contain real numeric values") from exc def _as_log_weight(value: Any) -> float: diff --git a/tests/filters/test_gaussian_hypothesis_mixture.py b/tests/filters/test_gaussian_hypothesis_mixture.py index 7fe960ce9..b7cf5f752 100644 --- a/tests/filters/test_gaussian_hypothesis_mixture.py +++ b/tests/filters/test_gaussian_hypothesis_mixture.py @@ -52,6 +52,23 @@ def test_log_weights_reject_bool_and_text_inputs(self): with self.assertRaisesRegex(ValueError, "log_weights"): normalize_log_weights(log_weights) + def test_complex_numeric_inputs_are_rejected_without_dropping_imaginary_parts(self): + with self.assertRaisesRegex(ValueError, "mean"): + WeightedGaussianHypothesis(np.array([1.0 + 2.0j]), np.array([[1.0]])) + + with self.assertRaisesRegex(ValueError, "covariance"): + WeightedGaussianHypothesis( + np.array([0.0]), np.array([[1.0 + 2.0j]]) + ) + + with self.assertRaisesRegex(ValueError, "log_weight"): + WeightedGaussianHypothesis( + np.array([0.0]), np.array([[1.0]]), log_weight=1.0 + 2.0j + ) + + with self.assertRaisesRegex(ValueError, "log_weights"): + normalize_log_weights([0.0, 1.0 + 2.0j]) + def test_hypotheses_reject_nonfinite_mean_and_covariance(self): with self.assertRaisesRegex(ValueError, "mean"): WeightedGaussianHypothesis(np.array([np.nan]), np.array([[1.0]]))