diff --git a/src/pyrecest/distributions/nonperiodic/gaussian_distribution.py b/src/pyrecest/distributions/nonperiodic/gaussian_distribution.py index 9666e785f..9b257e0fd 100644 --- a/src/pyrecest/distributions/nonperiodic/gaussian_distribution.py +++ b/src/pyrecest/distributions/nonperiodic/gaussian_distribution.py @@ -181,6 +181,7 @@ def ln_pdf(self, xs): many likelihoods are accumulated or when densities may underflow. """ xs = self._validate_evaluation_points(xs) + scalar_input = self.dim == 1 and ndim(xs) == 0 if pyrecest.backend.__backend_name__ == "numpy": from scipy.stats import multivariate_normal as mvn @@ -207,6 +208,8 @@ def ln_pdf(self, xs): else: raise NotImplementedError("Backend not supported") + if scalar_input: + log_pdf_vals = reshape(log_pdf_vals, ()) return log_pdf_vals log_pdf = ln_pdf diff --git a/tests/backend_support/test_gaussian_scalar_logpdf_contract.py b/tests/backend_support/test_gaussian_scalar_logpdf_contract.py new file mode 100644 index 000000000..cc4d9fc24 --- /dev/null +++ b/tests/backend_support/test_gaussian_scalar_logpdf_contract.py @@ -0,0 +1,36 @@ +import importlib.util + +import pytest + +from tests.support.backend_runner import run_backend_code + +pytestmark = pytest.mark.backend_portable + + +@pytest.mark.parametrize("backend_name", ["numpy", "pytorch", "jax"]) +def test_one_dimensional_gaussian_scalar_logpdf_preserves_scalar_shape(backend_name): + if backend_name == "pytorch" and importlib.util.find_spec("torch") is None: + pytest.skip("PyTorch is not installed") + if backend_name == "jax" and importlib.util.find_spec("jax") is None: + pytest.skip("JAX is not installed") + + code = """ +import pyrecest.backend as backend +from pyrecest.distributions import GaussianDistribution + + +distribution = GaussianDistribution(backend.array(0.0), backend.array(1.0)) +log_density = distribution.ln_pdf(backend.array(0.0)) +density = distribution.pdf(backend.array(0.0)) +batch_log_density = distribution.ln_pdf(backend.array([0.0])) + +assert tuple(backend.shape(backend.asarray(log_density))) == () +assert tuple(backend.shape(backend.asarray(density))) == () +assert tuple(backend.shape(backend.asarray(batch_log_density))) == (1,) +assert float(backend.to_numpy(log_density)) < 0.0 +assert float(backend.to_numpy(density)) > 0.0 +print("ok") +""" + result = run_backend_code(backend_name, code) + assert result.returncode == 0, result.stderr + assert "ok" in result.stdout