diff --git a/src/pyrecest/distributions/hypersphere_subset/abstract_hypersphere_subset_uniform_distribution.py b/src/pyrecest/distributions/hypersphere_subset/abstract_hypersphere_subset_uniform_distribution.py index ec495a716..07467a9e2 100644 --- a/src/pyrecest/distributions/hypersphere_subset/abstract_hypersphere_subset_uniform_distribution.py +++ b/src/pyrecest/distributions/hypersphere_subset/abstract_hypersphere_subset_uniform_distribution.py @@ -25,7 +25,7 @@ def pdf(self, xs): : Probability density at the given data points. """ xs = array(xs) - if xs.shape[-1] != self.input_dim: + if xs.ndim == 0 or xs.shape[-1] != self.input_dim: raise ValueError("Invalid shape of input data points.") manifold_size = self.get_manifold_size() if manifold_size == 0: diff --git a/tests/distributions/test_hypersphere_uniform_scalar_pdf.py b/tests/distributions/test_hypersphere_uniform_scalar_pdf.py new file mode 100644 index 000000000..4c7ae1801 --- /dev/null +++ b/tests/distributions/test_hypersphere_uniform_scalar_pdf.py @@ -0,0 +1,17 @@ +"""Regression tests for hyperspherical uniform input validation.""" + +import unittest + +from pyrecest.distributions import HypersphericalUniformDistribution + + +class HypersphericalUniformScalarPdfTest(unittest.TestCase): + def test_pdf_rejects_scalar_input_with_value_error(self): + dist = HypersphericalUniformDistribution(2) + + with self.assertRaisesRegex(ValueError, "Invalid shape"): + dist.pdf(1.0) + + +if __name__ == "__main__": + unittest.main()