Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ dependencies = [
"opencv-python-headless<=4.12.0.88",
"pydantic",
"tensorboard",
"tenacity",
"xxhash",
"imageio",
"timm",
Expand Down
1 change: 1 addition & 0 deletions requirements/runtime.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ transformers_stream_generator
opencv-python-headless
pydantic
tensorboard
tenacity
xxhash
imageio
timm
97 changes: 85 additions & 12 deletions tests/rl/test_rollout_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from unittest.mock import AsyncMock, MagicMock, patch

import httpx

from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status
from xtuner.v1.rl.rollout.controller import RolloutController
from xtuner.v1.rl.rollout.sglang import SGLangWorker
Expand Down Expand Up @@ -180,6 +179,9 @@ def _build_mock_error_rollout_worker(
safe_post_result: HttpRequestResult,
safe_handle_response=None,
):
async def passthrough_response(rollout_state, http_response):
return rollout_state

worker = RolloutWorker.__new__(RolloutWorker)
worker.receive_abort_request = threading.Event()
worker.enable_partial_rollout = False
Expand All @@ -190,7 +192,7 @@ def _build_mock_error_rollout_worker(
worker.logger = MagicMock()
worker._get_request_payload = MagicMock(return_value={"input_ids": [1], "max_tokens": 128})
worker._safe_post_request = AsyncMock(return_value=safe_post_result)
worker._safe_handle_response = AsyncMock(side_effect=safe_handle_response)
worker._safe_handle_response = AsyncMock(side_effect=safe_handle_response or passthrough_response)
return worker

async def test_generate_returns_aborted_when_abort_flag_is_set(self):
Expand Down Expand Up @@ -348,21 +350,11 @@ def server_error_result():
payload={"input_ids": [1]},
)

async def invalid_response(rollout_state, http_response):
rollout_state.status = Status.FAILED
return rollout_state

cases = [
("timeout", timeout_result(), None, ("Request failed", "3")),
("request_error", request_error_result(), None, ("Request failed", "3")),
("client_error", client_error_result(), None, ("Client error",)),
("server_error", server_error_result(), None, ("Server error",)),
(
"invalid_response",
HttpRequestResult(response=object()),
invalid_response,
("Invalid rollout response", "3"),
),
]

async def run_case(case_name, safe_post_result, safe_handle_response, expected_messages):
Expand Down Expand Up @@ -390,6 +382,87 @@ async def run_case(case_name, safe_post_result, safe_handle_response, expected_m
with patch("xtuner.v1.rl.rollout.worker.asyncio.sleep", new=AsyncMock()):
await asyncio.gather(*(run_case(*case) for case in cases))

async def test_generate_attempt_counts_for_retryable_and_non_retryable_errors(self):
# max_retry_per_sample=3 表示首轮请求 + 3 次 retry;非 retryable HTTP 错误不应重试。
def timeout_result():
error = httpx.TimeoutException("Mocked timeout error")
return HttpRequestResult(
error_type=HttpRequestErrorType.from_exception(error),
exception=error,
url="http://test/generate",
payload={"input_ids": [1]},
)

def client_error_result():
req = httpx.Request("POST", "http://test/generate")
response = httpx.Response(400, request=req)
error = httpx.HTTPStatusError("Mocked client error", request=req, response=response)
return HttpRequestResult(
error_type=HttpRequestErrorType.from_exception(error),
exception=error,
url="http://test/generate",
payload={"input_ids": [1]},
)

def server_error_result():
req = httpx.Request("POST", "http://test/generate")
response = httpx.Response(500, request=req)
error = httpx.HTTPStatusError("Mocked server error", request=req, response=response)
return HttpRequestResult(
error_type=HttpRequestErrorType.from_exception(error),
exception=error,
url="http://test/generate",
payload={"input_ids": [1]},
)

for safe_post_result, expected_calls in (
(timeout_result(), 4),
(client_error_result(), 1),
(server_error_result(), 1),
):
worker = self._build_mock_error_rollout_worker(safe_post_result=safe_post_result)
await worker.generate(RolloutState(message=[{"role": "user", "content": "Hello!"}]))
self.assertEqual(worker._safe_post_request.await_count, expected_calls)

async def test_generate_retries_invalid_response_for_configured_attempts(self):
# invalid response 回到 response handler 内判定;handler 抛 retryable 异常后由 Tenacity 重试。
from xtuner.v1.rl.rollout import worker as rollout_worker

async def invalid_response(rollout_state, http_response):
raise rollout_worker._RetryableInvalidRolloutResponseError("missing finish_reason")

worker = self._build_mock_error_rollout_worker(
safe_post_result=HttpRequestResult(response=object()),
safe_handle_response=invalid_response,
)

result_state = await worker.generate(RolloutState(message=[{"role": "user", "content": "Hello!"}]))

self.assertEqual(worker._safe_post_request.await_count, 4)
self.assertEqual(worker._safe_handle_response.await_count, 4)
self.assertEqual(result_state.status, Status.FAILED)
self.assertIn("missing finish_reason", result_state.error_msg)

async def test_generate_retries_response_handler_failed_status_for_backwards_compatibility(self):
# Some rollout backends still signal invalid responses by returning
# Status.FAILED instead of raising the retryable invalid-response error.
async def failed_response(rollout_state, http_response):
rollout_state.status = Status.FAILED
rollout_state.error_msg = "handler returned incomplete token data"
return rollout_state

worker = self._build_mock_error_rollout_worker(
safe_post_result=HttpRequestResult(response=object()),
safe_handle_response=failed_response,
)

result_state = await worker.generate(RolloutState(message=[{"role": "user", "content": "Hello!"}]))

self.assertEqual(worker._safe_post_request.await_count, 4)
self.assertEqual(worker._safe_handle_response.await_count, 4)
self.assertEqual(result_state.status, Status.FAILED)
self.assertIn("handler returned incomplete token data", result_state.error_msg)


class TestRolloutHealthChecker(unittest.TestCase):
def _build_checker(self, workers_info):
Expand Down
128 changes: 128 additions & 0 deletions tests/utils/test_retry_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import unittest


class TestSyncRetryPolicies(unittest.TestCase):
def test_retry_trace_store_bootstrap_uses_incremental_wait_and_reraises_last_value_error(self):
from xtuner.v1.utils.retry_utils import retry_trace_store_bootstrap

attempts = 0
sleep_seconds: list[float] = []

def attempt_once():
nonlocal attempts
attempts += 1
raise ValueError(f"attempt {attempts}")

retryer = retry_trace_store_bootstrap(
attempts=3,
wait_start_seconds=0.2,
wait_increment_seconds=0.2,
wait_max_seconds=2.0,
sleep=lambda seconds: sleep_seconds.append(seconds),
)

with self.assertRaisesRegex(ValueError, "attempt 3"):
retryer(attempt_once)

self.assertEqual(attempts, 3)
self.assertEqual(sleep_seconds, [0.2, 0.4])


class TestAsyncRetryPolicies(unittest.IsolatedAsyncioTestCase):
async def test_retry_rollout_request_uses_configured_attempts_and_fixed_wait(self):
from xtuner.v1.utils.retry_utils import retry_rollout_request

class RetryableRolloutError(Exception):
pass

attempts = 0
sleep_seconds: list[float] = []
before_retry_attempts: list[int] = []

async def attempt_once():
nonlocal attempts
attempts += 1
raise RetryableRolloutError(f"attempt {attempts}")

async def sleep(seconds: float):
sleep_seconds.append(seconds)

retryer = retry_rollout_request(
attempts=3,
retry_exceptions=(RetryableRolloutError,),
wait_seconds=0.1,
before_retry=lambda retry_state: before_retry_attempts.append(retry_state.attempt_number),
sleep=sleep,
)

with self.assertRaisesRegex(RetryableRolloutError, "attempt 3"):
await retryer(attempt_once)

self.assertEqual(attempts, 3)
self.assertEqual(sleep_seconds, [0.1, 0.1])
self.assertEqual(before_retry_attempts, [1, 2])

async def test_retry_sandbox_acquire_uses_policy_waits_for_unhealthy_and_create_failures(self):
from xtuner.v1.utils.retry_utils import retry_sandbox_acquire

class UnhealthySandboxError(Exception):
pass

class SandboxError(Exception):
pass

attempts = 0
sleep_seconds: list[float] = []

async def attempt_once():
nonlocal attempts
attempts += 1
if attempts == 1:
raise UnhealthySandboxError("unhealthy")
raise SandboxError(f"attempt {attempts}")

async def sleep(seconds: float):
sleep_seconds.append(seconds)

retryer = retry_sandbox_acquire(
attempts=3,
unhealthy_exceptions=(UnhealthySandboxError,),
create_wait_multiplier=2.0,
create_wait_min_seconds=2.0,
create_wait_max_seconds=8.0,
sleep=sleep,
)

with self.assertRaisesRegex(SandboxError, "attempt 3"):
await retryer(attempt_once)

self.assertEqual(attempts, 3)
self.assertEqual(sleep_seconds, [0.0, 4.0])

async def test_poll_sandbox_health_returns_last_result_after_timeout(self):
from tenacity import stop_after_attempt

from xtuner.v1.utils.retry_utils import poll_sandbox_health

attempts = 0
sleep_seconds: list[float] = []

async def poll_once():
nonlocal attempts
attempts += 1
return False

async def sleep(seconds: float):
sleep_seconds.append(seconds)

retryer = poll_sandbox_health(
stop_strategy=stop_after_attempt(3),
wait_seconds=0.1,
sleep=sleep,
)

result = await retryer(poll_once)

self.assertFalse(result)
self.assertEqual(attempts, 3)
self.assertEqual(sleep_seconds, [0.1, 0.1])
Loading
Loading