diff --git a/src/pyrecest/utils/pairwise_covariance_features.py b/src/pyrecest/utils/pairwise_covariance_features.py index 89193b808..dfbbc6335 100644 --- a/src/pyrecest/utils/pairwise_covariance_features.py +++ b/src/pyrecest/utils/pairwise_covariance_features.py @@ -158,8 +158,11 @@ def pairwise_covariance_shape_components( n_a = covariances_a.shape[2] n_b = covariances_b.shape[2] if n_a == 0 or n_b == 0: - empty = zeros((n_a, n_b), dtype=float64) - return empty, empty, empty + return ( + zeros((n_a, n_b), dtype=float64), + zeros((n_a, n_b), dtype=float64), + zeros((n_a, n_b), dtype=float64), + ) moved_covariances_a = _symmetrized_covariance_batch(covariances_a) moved_covariances_b = _symmetrized_covariance_batch(covariances_b) diff --git a/tests/test_covariance_empty_outputs.py b/tests/test_covariance_empty_outputs.py new file mode 100644 index 000000000..d687ccd99 --- /dev/null +++ b/tests/test_covariance_empty_outputs.py @@ -0,0 +1,13 @@ +from pyrecest.backend import zeros +from pyrecest.utils import pairwise_covariance_shape_components + + +def test_pairwise_covariance_empty_outputs_are_independent_objects(): + shape_cost, logdet_cost, shape_similarity = pairwise_covariance_shape_components( + zeros((2, 2, 0)), zeros((2, 2, 4)) + ) + + assert shape_cost.shape == (0, 4) + assert logdet_cost.shape == (0, 4) + assert shape_similarity.shape == (0, 4) + assert len({id(shape_cost), id(logdet_cost), id(shape_similarity)}) == 3