Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/pyrecest/evaluation/summarize_filter_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def summarize_filter_results(
warnings.warn("Using less than 1000 runs. This may lead to unreliable results.")

extract_mean = get_extract_mean(
scenario_config["manifold"], mtt_scenario=scenario_config["mtt"]
scenario_config["manifold"], mtt_scenario=scenario_config.get("mtt", False)
)
distance_function = get_distance_function(scenario_config["manifold"])
errors_all = determine_all_deviations(
Expand Down
34 changes: 34 additions & 0 deletions tests/test_evaluation_summarize_missing_mtt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import numpy as np
import pytest

import pyrecest.backend
from pyrecest.evaluation import summarize_filter_results


@pytest.mark.skipif(
pyrecest.backend.__backend_name__ in ("pytorch", "jax"),
reason="Not supported on this backend",
)
def test_summarize_filter_results_defaults_missing_mtt_flag():
n_runs = 2
groundtruths = np.empty((n_runs, 1), dtype=object)
for run in range(n_runs):
groundtruths[run, 0] = np.array([float(run), 0.0])

last_estimates = np.array([[[0.0, 0.0], [1.0, 0.0]]])
runtimes = np.ones((1, n_runs))
run_failed = np.zeros((1, n_runs), dtype=bool)
filter_configs = [{"name": "kf", "parameter": None}]

summarized = summarize_filter_results(
scenario_config={"manifold": "Euclidean"},
filter_configs=filter_configs,
runtimes=runtimes,
groundtruths=groundtruths,
run_failed=run_failed,
last_estimates=last_estimates,
)

assert summarized is filter_configs
assert summarized[0]["error_mean"] == 0
assert summarized[0]["failure_rate"] == 0
Loading