diff --git a/src/pyrecest/evaluation/summarize_filter_results.py b/src/pyrecest/evaluation/summarize_filter_results.py index 89a58e7c0..ec28abac9 100644 --- a/src/pyrecest/evaluation/summarize_filter_results.py +++ b/src/pyrecest/evaluation/summarize_filter_results.py @@ -38,7 +38,7 @@ def summarize_filter_results( warnings.warn("Using less than 1000 runs. This may lead to unreliable results.") extract_mean = get_extract_mean( - scenario_config["manifold"], mtt_scenario=scenario_config["mtt"] + scenario_config["manifold"], mtt_scenario=scenario_config.get("mtt", False) ) distance_function = get_distance_function(scenario_config["manifold"]) errors_all = determine_all_deviations( diff --git a/tests/test_evaluation_summarize_missing_mtt.py b/tests/test_evaluation_summarize_missing_mtt.py new file mode 100644 index 000000000..04e85671a --- /dev/null +++ b/tests/test_evaluation_summarize_missing_mtt.py @@ -0,0 +1,34 @@ +import numpy as np +import pytest + +import pyrecest.backend +from pyrecest.evaluation import summarize_filter_results + + +@pytest.mark.skipif( + pyrecest.backend.__backend_name__ in ("pytorch", "jax"), + reason="Not supported on this backend", +) +def test_summarize_filter_results_defaults_missing_mtt_flag(): + n_runs = 2 + groundtruths = np.empty((n_runs, 1), dtype=object) + for run in range(n_runs): + groundtruths[run, 0] = np.array([float(run), 0.0]) + + last_estimates = np.array([[[0.0, 0.0], [1.0, 0.0]]]) + runtimes = np.ones((1, n_runs)) + run_failed = np.zeros((1, n_runs), dtype=bool) + filter_configs = [{"name": "kf", "parameter": None}] + + summarized = summarize_filter_results( + scenario_config={"manifold": "Euclidean"}, + filter_configs=filter_configs, + runtimes=runtimes, + groundtruths=groundtruths, + run_failed=run_failed, + last_estimates=last_estimates, + ) + + assert summarized is filter_configs + assert summarized[0]["error_mean"] == 0 + assert summarized[0]["failure_rate"] == 0