Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 45 additions & 3 deletions examples/ablation/run_agent_loop_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@
# Third ancestor directory of this file's resolved path.
REPO_ROOT = Path(__file__).resolve().parents[2]
# "diagnostics" directory located next to this script.
OUT_DIR = Path(__file__).parent / "diagnostics"
# MS MARCO passage JSON under the repository's tests/benchmark/data directory.
DEFAULT_MSMARCO_PATH = REPO_ROOT / "tests" / "benchmark" / "data" / "msmarco_passage.json"

# LLM settings that _resolve_llm_settings falls back to for any preset other
# than "deepseek" when the corresponding value is not supplied.
DEFAULT_LLM_BASE_URL = "http://localhost:8012/v1"
DEFAULT_LLM_MODEL = "Qwen3.6-27B"
DEFAULT_API_KEY_ENV = "OPENAI_API_KEY"

# LLM settings that _resolve_llm_settings falls back to for the "deepseek"
# preset when the corresponding value is not supplied.
DEEPSEEK_LLM_BASE_URL = "https://api.deepseek.com/v1"
DEEPSEEK_LLM_MODEL = "deepseek-v4-flash"
# Name of an environment variable, not the key itself.
DEEPSEEK_API_KEY_ENV = "DEEPSEEK_API_KEY"

_AGENT_LOOP_EXTRA_CONTEXT = """Benchmark context:
- You are evaluating retrieval, not general knowledge.
Expand Down Expand Up @@ -66,6 +72,27 @@ def _load_local_env(paths: list[Path] | None = None) -> None:
os.environ[key] = value.strip().strip("\"'")


def _resolve_llm_settings(
*,
preset: str,
llm_base_url: str | None,
model: str | None,
api_key_env: str | None,
) -> tuple[str, str, str]:
"""Resolve provider defaults while preserving explicit CLI overrides."""
if preset == "deepseek":
return (
llm_base_url or DEEPSEEK_LLM_BASE_URL,
model or DEEPSEEK_LLM_MODEL,
api_key_env or DEEPSEEK_API_KEY_ENV,
)
return (
llm_base_url or DEFAULT_LLM_BASE_URL,
model or DEFAULT_LLM_MODEL,
api_key_env or DEFAULT_API_KEY_ENV,
)


@dataclass(slots=True)
class AgentLoopRow:
qid: str
Expand Down Expand Up @@ -493,9 +520,18 @@ async def amain(argv: list[str] | None = None) -> int:
parser.add_argument("--sqlite-db-path", type=Path, required=True)
parser.add_argument("--subset", type=int, default=20)
parser.add_argument("--corpus-limit", type=int, default=0)
parser.add_argument("--llm-base-url", default="http://localhost:8012/v1")
parser.add_argument("--model", default="Qwen3.6-27B")
parser.add_argument("--api-key-env", default="OPENAI_API_KEY")
parser.add_argument(
"--llm-preset",
choices=("local", "deepseek"),
default="local",
help=(
"Provider preset for omitted LLM settings. "
"deepseek => api.deepseek.com/v1, deepseek-v4-flash, DEEPSEEK_API_KEY."
),
)
parser.add_argument("--llm-base-url", default=None)
parser.add_argument("--model", default=None)
parser.add_argument("--api-key-env", default=None)
parser.add_argument("--max-turns", type=int, default=5)
parser.add_argument(
"--llm-timeout",
Expand Down Expand Up @@ -543,6 +579,12 @@ async def amain(argv: list[str] | None = None) -> int:
raise SystemExit("--preflight-timeout must be positive")
if args.resume and args.out_jsonl is None:
raise SystemExit("--resume requires --out-jsonl")
args.llm_base_url, args.model, args.api_key_env = _resolve_llm_settings(
preset=args.llm_preset,
llm_base_url=args.llm_base_url,
model=args.model,
api_key_env=args.api_key_env,
)
if not args.msmarco_path.exists():
raise SystemExit(f"{args.msmarco_path} does not exist")
if not args.sqlite_db_path.exists():
Expand Down
39 changes: 39 additions & 0 deletions tests/test_agent_search_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,45 @@ def test_agent_loop_load_local_env_without_overriding_shell_env(
assert os.environ["DEEPSEEK_API_KEY"] == "from_shell"


def test_agent_loop_deepseek_preset_resolves_provider_defaults() -> None:
    """The deepseek preset supplies every setting when none is given."""
    expected = (
        "https://api.deepseek.com/v1",
        "deepseek-v4-flash",
        "DEEPSEEK_API_KEY",
    )
    resolved = loop_runner._resolve_llm_settings(
        preset="deepseek",
        llm_base_url=None,
        model=None,
        api_key_env=None,
    )
    assert resolved == expected


def test_agent_loop_llm_preset_preserves_explicit_overrides() -> None:
    """Explicitly supplied settings win over the deepseek preset values."""
    overrides = {
        "llm_base_url": "https://example.test/v1",
        "model": "custom-model",
        "api_key_env": "CUSTOM_KEY",
    }
    resolved = loop_runner._resolve_llm_settings(preset="deepseek", **overrides)
    assert resolved == (
        "https://example.test/v1",
        "custom-model",
        "CUSTOM_KEY",
    )


def test_agent_loop_local_preset_keeps_existing_defaults() -> None:
    """The local preset resolves to the localhost endpoint, Qwen model and OPENAI_API_KEY."""
    expected = (
        "http://localhost:8012/v1",
        "Qwen3.6-27B",
        "OPENAI_API_KEY",
    )
    resolved = loop_runner._resolve_llm_settings(
        preset="local",
        llm_base_url=None,
        model=None,
        api_key_env=None,
    )
    assert resolved == expected


def test_llm_preflight_error_message_names_endpoint_and_skip_hint() -> None:
msg = loop_runner._llm_preflight_error_message(
"http://127.0.0.1:18012/v1",
Expand Down
Loading