diff --git a/src/pyrecest/utils/track_evaluation/__init__.py b/src/pyrecest/utils/track_evaluation/__init__.py index 13220f760..65d2fdcaf 100644 --- a/src/pyrecest/utils/track_evaluation/__init__.py +++ b/src/pyrecest/utils/track_evaluation/__init__.py @@ -3,6 +3,8 @@ import numpy as np +from ._identity_f1 import patch_identity_f1 + d = runpy.run_path( str(Path(__file__).resolve().parents[1] / "track_evaluation.py"), run_name=__name__ ) @@ -20,6 +22,7 @@ def f(v): d["_optional_int_candidate"] = f +patch_identity_f1(d) for n in d["__all__"]: globals()[n] = d[n] __all__ = d["__all__"] diff --git a/src/pyrecest/utils/track_evaluation/_identity_f1.py b/src/pyrecest/utils/track_evaluation/_identity_f1.py new file mode 100644 index 000000000..0019fc261 --- /dev/null +++ b/src/pyrecest/utils/track_evaluation/_identity_f1.py @@ -0,0 +1,39 @@ +"""Local identity-set F1 fixes for track-evaluation exports.""" + +from __future__ import annotations + +from typing import Any + + +def patch_identity_f1(namespace: dict[str, Any]) -> None: + """Install an identity-set scorer with a zero-valued disjoint F1.""" + + safe_ratio = namespace["_safe_ratio"] + zero_ratio = namespace["_zero_ratio"] + + def score_identity_sets( + predicted: set[Any], + reference: set[Any], + *, + prefix: str, + predicted_total_name: str, + reference_total_name: str, + ) -> dict[str, float | int]: + true_positives = len(predicted & reference) + false_positives = len(predicted - reference) + false_negatives = len(reference - predicted) + precision = safe_ratio(true_positives, true_positives + false_positives) + recall = safe_ratio(true_positives, true_positives + false_negatives) + f1 = zero_ratio(2.0 * precision * recall, precision + recall) + return { + f"{prefix}_true_positives": true_positives, + f"{prefix}_false_positives": false_positives, + f"{prefix}_false_negatives": false_negatives, + f"{prefix}_precision": precision, + f"{prefix}_recall": recall, + f"{prefix}_f1": f1, + predicted_total_name: len(predicted), + reference_total_name: len(reference), + } + + namespace["_score_identity_sets"] = score_identity_sets diff --git a/tests/test_track_evaluation_identity_f1.py b/tests/test_track_evaluation_identity_f1.py new file mode 100644 index 000000000..b41dd4e21 --- /dev/null +++ b/tests/test_track_evaluation_identity_f1.py @@ -0,0 +1,31 @@ +from pyrecest.utils.track_evaluation import score_complete_tracks, score_track_links + + +def test_disjoint_track_identity_sets_have_zero_f1(): + complete_scores = score_complete_tracks([[0, 1]], [[2, 3]]) + + assert complete_scores["complete_track_true_positives"] == 0 + assert complete_scores["complete_track_false_positives"] == 1 + assert complete_scores["complete_track_false_negatives"] == 1 + assert complete_scores["complete_track_precision"] == 0.0 + assert complete_scores["complete_track_recall"] == 0.0 + assert complete_scores["complete_track_f1"] == 0.0 + + link_scores = score_track_links([[0, 1]], [[2, 3]]) + + assert link_scores["track_link_true_positives"] == 0 + assert link_scores["track_link_false_positives"] == 1 + assert link_scores["track_link_false_negatives"] == 1 + assert link_scores["track_link_precision"] == 0.0 + assert link_scores["track_link_recall"] == 0.0 + assert link_scores["track_link_f1"] == 0.0 + + +def test_empty_track_identity_sets_keep_perfect_precision_recall_and_f1(): + scores = score_complete_tracks([[None, None]], [[None, None]]) + + assert scores["complete_tracks"] == 0 + assert scores["reference_complete_tracks"] == 0 + assert scores["complete_track_precision"] == 1.0 + assert scores["complete_track_recall"] == 1.0 + assert scores["complete_track_f1"] == 1.0