Skip to content

Commit ad3d548

Browse files
macekond and claude committed
fix(eval): thread-safe progress callbacks for --concurrency
Wrap Rich Console.print() calls in _make_progress_callbacks with a threading.Lock. Without this, concurrent ThreadPoolExecutor workers deadlock when stdout is piped (e.g. background process). The lock serializes only the print calls — actual evaluation work remains fully concurrent.

Add test_progress_callbacks_thread_safe to verify callbacks can be called from 8 concurrent threads without errors.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 04bfbc5 commit ad3d548

2 files changed

Lines changed: 55 additions & 4 deletions

File tree

packages/gooddata-eval/src/gooddata_eval/cli/main.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import argparse
55
import sys
6+
import threading
67
from datetime import datetime, timezone
78
from pathlib import Path
89

@@ -139,14 +140,22 @@ def _parse_model_arg(val: str) -> tuple[str | None, str]:
139140

140141

141142
def _make_progress_callbacks(console: Console):
142-
"""Build (on_item_start, on_run_done, on_item_done) callbacks that stream progress."""
143+
"""Build (on_item_start, on_run_done, on_item_done) callbacks that stream progress.
144+
145+
A threading lock guards all console.print() calls so that concurrent
146+
``--concurrency 2+`` workers do not deadlock on Rich's internal buffer
147+
when stdout is piped (e.g. running in a background process).
148+
"""
149+
_print_lock = threading.Lock()
143150

144151
def on_item_start(index: int, total: int, item: DatasetItem) -> None:
145-
console.print(f"[dim]\\[{index}/{total}][/dim] [cyan]{item.id}[/cyan] {_truncate(item.question)}")
152+
with _print_lock:
153+
console.print(f"[dim]\\[{index}/{total}][/dim] [cyan]{item.id}[/cyan] {_truncate(item.question)}")
146154

147155
def on_run_done(index: int, total: int, run_index: int, runs: int, passed: bool, latency: float) -> None:
148156
tag = "[green]pass[/green]" if passed else "[red]fail[/red]"
149-
console.print(f"[dim]\\[{index}/{total}][/dim] run {run_index}/{runs} {tag} [dim]{latency:.2f}s[/dim]")
157+
with _print_lock:
158+
console.print(f"[dim]\\[{index}/{total}][/dim] run {run_index}/{runs} {tag} [dim]{latency:.2f}s[/dim]")
150159

151160
def on_item_done(index: int, total: int, report: ItemReport) -> None:
152161
if report.skipped:
@@ -165,7 +174,8 @@ def on_item_done(index: int, total: int, report: ItemReport) -> None:
165174
f" [dim]({report.latency_s:.2f}s total, {report.avg_latency_s:.2f}s avg, "
166175
f"quality={quality_str}, {report.runs} run(s))[/dim]"
167176
)
168-
console.print(f"[dim]\\[{index}/{total}][/dim] -> {tag} [cyan]{report.id}[/cyan]{suffix}")
177+
with _print_lock:
178+
console.print(f"[dim]\\[{index}/{total}][/dim] -> {tag} [cyan]{report.id}[/cyan]{suffix}")
169179

170180
return on_item_start, on_run_done, on_item_done
171181

packages/gooddata-eval/tests/test_cli.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,3 +530,44 @@ def test_cli_rejects_negative_concurrency(monkeypatch, fixtures_dir):
530530
]
531531
)
532532
assert exit_code == 2
533+
534+
535+
def test_progress_callbacks_thread_safe():
    """Verify progress callbacks can be called from multiple threads without error.

    Submits 50 simulated items to an 8-worker thread pool. Each worker calls
    ``on_item_start``, ``on_run_done`` and ``on_item_done`` on callbacks that
    share one in-memory console. The test fails if any callback raises, or if
    any item id is missing from the captured output.
    """
    import io
    import re
    from concurrent.futures import ThreadPoolExecutor, as_completed

    item_count = 50

    console = Console(file=io.StringIO(), force_terminal=False)
    on_item_start, on_run_done, on_item_done = cli_main._make_progress_callbacks(console)

    def _worker(index: int) -> None:
        # Exceptions are deliberately NOT caught here: they are stored on the
        # future and surface through f.result() below with the real traceback.
        passed = index % 2 == 0
        item = DatasetItem(
            id=f"test-{index}",
            dataset_name="test",
            test_kind="general_question",
            question=f"Question {index}",
            expected_output="answer",
        )
        on_item_start(index, 100, item)
        on_run_done(index, 100, 1, 1, passed, 1.5)
        report = ItemReport(id=f"test-{index}", dataset_name="test", test_kind="general_question")
        report.runs = 1
        report.latency_s = 1.5
        report.pass_at_k = passed
        on_item_done(index, 100, report)

    with ThreadPoolExecutor(max_workers=8) as pool:
        futures = [pool.submit(_worker, i) for i in range(item_count)]
        for f in as_completed(futures):
            f.result()  # re-raise if any thread failed

    output = console.file.getvalue()
    # Match whole ids so "test-1" is not satisfied by "test-10".."test-19",
    # and check every submitted item rather than just two of them.
    seen_ids = set(re.findall(r"test-\d+", output))
    missing_ids = {f"test-{i}" for i in range(item_count)} - seen_ids
    assert not missing_ids, f"Items missing from progress output: {sorted(missing_ids)}"

0 commit comments

Comments
 (0)