Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions applications/ColossalChat/start_code_verifier.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import os
from typing import List, Optional

from coati.distributed.reward.code_reward.utils import check_correctness # Assuming utils.py is in the same directory
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, Header, HTTPException
from pydantic import BaseModel

app = FastAPI()

_API_KEY = os.environ.get("CODE_VERIFIER_API_KEY", "")
_MAX_TIMEOUT = 30


class CheckCorrectnessRequest(BaseModel):
in_outs: Optional[dict]
Expand All @@ -21,12 +25,14 @@ class CheckCorrectnessResponse(BaseModel):


@app.post("/check_correctness", response_model=CheckCorrectnessResponse)
def check_correctness_api(request: CheckCorrectnessRequest):
def check_correctness_api(request: CheckCorrectnessRequest, x_api_key: str = Header(...)):
if not _API_KEY or x_api_key != _API_KEY:
raise HTTPException(status_code=401, detail="Unauthorized")
try:
result, metadata = check_correctness(
in_outs=request.in_outs,
generation=request.generation,
timeout=request.timeout,
timeout=min(request.timeout, _MAX_TIMEOUT),
debug=request.debug,
eval_mode=request.eval_mode,
)
Expand Down
48 changes: 48 additions & 0 deletions tests/test_invariant_start_code_verifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import importlib.util
from pathlib import Path

import pytest
from fastapi.testclient import TestClient

# Load the actual production app from the file under test
spec = importlib.util.spec_from_file_location(
"start_code_verifier", Path(__file__).parent / "applications/ColossalChat/start_code_verifier.py"
)
module = importlib.util.load_from_spec(spec)
spec.loader.exec_module(module)
app = module.app

client = TestClient(app, raise_server_exceptions=False)

VALID_PAYLOAD = {"generation": "def f(x): return x", "in_outs": "{}", "timeout": 5, "debug": False, "eval_mode": "exec"}


@pytest.mark.parametrize(
"headers,payload",
[
# Exact exploit: no auth, arbitrary code execution attempt
(
{},
{
"generation": "__import__('os').system('id')",
"in_outs": "{}",
"timeout": 10,
"debug": False,
"eval_mode": "exec",
},
),
# Missing token (boundary: empty Authorization header)
({"Authorization": ""}, VALID_PAYLOAD),
# Malformed token
({"Authorization": "Bearer not.a.valid.jwt"}, VALID_PAYLOAD),
# Expired/fake token
({"Authorization": "Bearer eyJhbGciOiJIUzI1NiJ9.eyJleHAiOjF9.invalid"}, VALID_PAYLOAD),
],
)
def test_check_correctness_requires_authentication(headers, payload):
"""Invariant: /check_correctness must reject unauthenticated or invalidly-authenticated requests with 401 or 403."""
response = client.post("/check_correctness", json=payload, headers=headers)
assert response.status_code in (401, 403), (
f"Expected 401 or 403 for unauthenticated request, got {response.status_code}. "
"The endpoint must not be publicly accessible without valid credentials."
)