diff --git a/pyproject.toml b/pyproject.toml index a3af546cf6..8d67642cb4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ dependencies = [ "opencv-python-headless<=4.12.0.88", "pydantic", "tensorboard", + "tenacity", "xxhash", "imageio", "timm", diff --git a/requirements/runtime.txt b/requirements/runtime.txt index ce60f33cf4..d248bc014e 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -17,6 +17,7 @@ transformers_stream_generator opencv-python-headless pydantic tensorboard +tenacity xxhash imageio timm diff --git a/tests/rl/test_rollout_logic.py b/tests/rl/test_rollout_logic.py index 8aa7501de7..740057c316 100644 --- a/tests/rl/test_rollout_logic.py +++ b/tests/rl/test_rollout_logic.py @@ -17,7 +17,6 @@ from unittest.mock import AsyncMock, MagicMock, patch import httpx - from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status from xtuner.v1.rl.rollout.controller import RolloutController from xtuner.v1.rl.rollout.sglang import SGLangWorker @@ -180,6 +179,9 @@ def _build_mock_error_rollout_worker( safe_post_result: HttpRequestResult, safe_handle_response=None, ): + async def passthrough_response(rollout_state, http_response): + return rollout_state + worker = RolloutWorker.__new__(RolloutWorker) worker.receive_abort_request = threading.Event() worker.enable_partial_rollout = False @@ -190,7 +192,7 @@ def _build_mock_error_rollout_worker( worker.logger = MagicMock() worker._get_request_payload = MagicMock(return_value={"input_ids": [1], "max_tokens": 128}) worker._safe_post_request = AsyncMock(return_value=safe_post_result) - worker._safe_handle_response = AsyncMock(side_effect=safe_handle_response) + worker._safe_handle_response = AsyncMock(side_effect=safe_handle_response or passthrough_response) return worker async def test_generate_returns_aborted_when_abort_flag_is_set(self): @@ -348,21 +350,11 @@ def server_error_result(): payload={"input_ids": [1]}, ) - async def invalid_response(rollout_state, http_response): - rollout_state.status = Status.FAILED - return rollout_state - cases = [ ("timeout", timeout_result(), None, ("Request failed", "3")), ("request_error", request_error_result(), None, ("Request failed", "3")), ("client_error", client_error_result(), None, ("Client error",)), ("server_error", server_error_result(), None, ("Server error",)), - ( - "invalid_response", - HttpRequestResult(response=object()), - invalid_response, - ("Invalid rollout response", "3"), - ), ] async def run_case(case_name, safe_post_result, safe_handle_response, expected_messages): @@ -390,6 +382,87 @@ async def run_case(case_name, safe_post_result, safe_handle_response, expected_m with patch("xtuner.v1.rl.rollout.worker.asyncio.sleep", new=AsyncMock()): await asyncio.gather(*(run_case(*case) for case in cases)) + async def test_generate_attempt_counts_for_retryable_and_non_retryable_errors(self): + # max_retry_per_sample=3 表示首轮请求 + 3 次 retry;非 retryable HTTP 错误不应重试。 + def timeout_result(): + error = httpx.TimeoutException("Mocked timeout error") + return HttpRequestResult( + error_type=HttpRequestErrorType.from_exception(error), + exception=error, + url="http://test/generate", + payload={"input_ids": [1]}, + ) + + def client_error_result(): + req = httpx.Request("POST", "http://test/generate") + response = httpx.Response(400, request=req) + error = httpx.HTTPStatusError("Mocked client error", request=req, response=response) + return HttpRequestResult( + error_type=HttpRequestErrorType.from_exception(error), + exception=error, + url="http://test/generate", + payload={"input_ids": [1]}, + ) + + def server_error_result(): + req = httpx.Request("POST", "http://test/generate") + response = httpx.Response(500, request=req) + error = httpx.HTTPStatusError("Mocked server error", request=req, response=response) + return HttpRequestResult( + error_type=HttpRequestErrorType.from_exception(error), + exception=error, + url="http://test/generate", + payload={"input_ids": [1]}, + ) + + for safe_post_result, expected_calls in ( + (timeout_result(), 4), + (client_error_result(), 1), + (server_error_result(), 1), + ): + worker = self._build_mock_error_rollout_worker(safe_post_result=safe_post_result) + await worker.generate(RolloutState(message=[{"role": "user", "content": "Hello!"}])) + self.assertEqual(worker._safe_post_request.await_count, expected_calls) + + async def test_generate_retries_invalid_response_for_configured_attempts(self): + # invalid response 回到 response handler 内判定;handler 抛 retryable 异常后由 Tenacity 重试。 + from xtuner.v1.rl.rollout import worker as rollout_worker + + async def invalid_response(rollout_state, http_response): + raise rollout_worker._RetryableInvalidRolloutResponseError("missing finish_reason") + + worker = self._build_mock_error_rollout_worker( + safe_post_result=HttpRequestResult(response=object()), + safe_handle_response=invalid_response, + ) + + result_state = await worker.generate(RolloutState(message=[{"role": "user", "content": "Hello!"}])) + + self.assertEqual(worker._safe_post_request.await_count, 4) + self.assertEqual(worker._safe_handle_response.await_count, 4) + self.assertEqual(result_state.status, Status.FAILED) + self.assertIn("missing finish_reason", result_state.error_msg) + + async def test_generate_retries_response_handler_failed_status_for_backwards_compatibility(self): + # Some rollout backends still signal invalid responses by returning + # Status.FAILED instead of raising the retryable invalid-response error. + async def failed_response(rollout_state, http_response): + rollout_state.status = Status.FAILED + rollout_state.error_msg = "handler returned incomplete token data" + return rollout_state + + worker = self._build_mock_error_rollout_worker( + safe_post_result=HttpRequestResult(response=object()), + safe_handle_response=failed_response, + ) + + result_state = await worker.generate(RolloutState(message=[{"role": "user", "content": "Hello!"}])) + + self.assertEqual(worker._safe_post_request.await_count, 4) + self.assertEqual(worker._safe_handle_response.await_count, 4) + self.assertEqual(result_state.status, Status.FAILED) + self.assertIn("handler returned incomplete token data", result_state.error_msg) + class TestRolloutHealthChecker(unittest.TestCase): def _build_checker(self, workers_info): diff --git a/tests/utils/test_retry_utils.py b/tests/utils/test_retry_utils.py new file mode 100644 index 0000000000..167694a783 --- /dev/null +++ b/tests/utils/test_retry_utils.py @@ -0,0 +1,128 @@ +import unittest + + +class TestSyncRetryPolicies(unittest.TestCase): + def test_retry_trace_store_bootstrap_uses_incremental_wait_and_reraises_last_value_error(self): + from xtuner.v1.utils.retry_utils import retry_trace_store_bootstrap + + attempts = 0 + sleep_seconds: list[float] = [] + + def attempt_once(): + nonlocal attempts + attempts += 1 + raise ValueError(f"attempt {attempts}") + + retryer = retry_trace_store_bootstrap( + attempts=3, + wait_start_seconds=0.2, + wait_increment_seconds=0.2, + wait_max_seconds=2.0, + sleep=lambda seconds: sleep_seconds.append(seconds), + ) + + with self.assertRaisesRegex(ValueError, "attempt 3"): + retryer(attempt_once) + + self.assertEqual(attempts, 3) + self.assertEqual(sleep_seconds, [0.2, 0.4]) + + +class TestAsyncRetryPolicies(unittest.IsolatedAsyncioTestCase): + async def test_retry_rollout_request_uses_configured_attempts_and_fixed_wait(self): + from xtuner.v1.utils.retry_utils import retry_rollout_request + + class RetryableRolloutError(Exception): + pass + + attempts = 0 + sleep_seconds: list[float] = [] + before_retry_attempts: list[int] = [] + + async def attempt_once(): + nonlocal attempts + attempts += 1 + raise RetryableRolloutError(f"attempt {attempts}") + + async def sleep(seconds: float): + sleep_seconds.append(seconds) + + retryer = retry_rollout_request( + attempts=3, + retry_exceptions=(RetryableRolloutError,), + wait_seconds=0.1, + before_retry=lambda retry_state: before_retry_attempts.append(retry_state.attempt_number), + sleep=sleep, + ) + + with self.assertRaisesRegex(RetryableRolloutError, "attempt 3"): + await retryer(attempt_once) + + self.assertEqual(attempts, 3) + self.assertEqual(sleep_seconds, [0.1, 0.1]) + self.assertEqual(before_retry_attempts, [1, 2]) + + async def test_retry_sandbox_acquire_uses_policy_waits_for_unhealthy_and_create_failures(self): + from xtuner.v1.utils.retry_utils import retry_sandbox_acquire + + class UnhealthySandboxError(Exception): + pass + + class SandboxError(Exception): + pass + + attempts = 0 + sleep_seconds: list[float] = [] + + async def attempt_once(): + nonlocal attempts + attempts += 1 + if attempts == 1: + raise UnhealthySandboxError("unhealthy") + raise SandboxError(f"attempt {attempts}") + + async def sleep(seconds: float): + sleep_seconds.append(seconds) + + retryer = retry_sandbox_acquire( + attempts=3, + unhealthy_exceptions=(UnhealthySandboxError,), + create_wait_multiplier=2.0, + create_wait_min_seconds=2.0, + create_wait_max_seconds=8.0, + sleep=sleep, + ) + + with self.assertRaisesRegex(SandboxError, "attempt 3"): + await retryer(attempt_once) + + self.assertEqual(attempts, 3) + self.assertEqual(sleep_seconds, [0.0, 4.0]) + + async def test_poll_sandbox_health_returns_last_result_after_timeout(self): + from tenacity import stop_after_attempt + + from xtuner.v1.utils.retry_utils import poll_sandbox_health + + attempts = 0 + sleep_seconds: list[float] = [] + + async def poll_once(): + nonlocal attempts + attempts += 1 + return False + + async def sleep(seconds: float): + sleep_seconds.append(seconds) + + retryer = poll_sandbox_health( + stop_strategy=stop_after_attempt(3), + wait_seconds=0.1, + sleep=sleep, + ) + + result = await retryer(poll_once) + + self.assertFalse(result) + self.assertEqual(attempts, 3) + self.assertEqual(sleep_seconds, [0.1, 0.1]) diff --git a/xtuner/v1/rl/agent_loop/sandbox_agent_loop/sandbox.py b/xtuner/v1/rl/agent_loop/sandbox_agent_loop/sandbox.py index e664cff874..3346730ea6 100644 --- a/xtuner/v1/rl/agent_loop/sandbox_agent_loop/sandbox.py +++ b/xtuner/v1/rl/agent_loop/sandbox_agent_loop/sandbox.py @@ -49,6 +49,7 @@ async def hook(client, item, record) -> None ) from xtuner.v1.rl.agent_loop.sandbox_agent_loop.trace import span from xtuner.v1.utils import get_logger +from xtuner.v1.utils.retry_utils import poll_sandbox_health, retry_sandbox_acquire # ───────────────────────────────────────────────────────────────── @@ -826,6 +827,10 @@ async def _sandbox_alive(client: Any, timeout_sec: float = 5.0) -> bool: # ───────────────────────────────────────────────────────────────── +class _UnhealthySandboxError(RuntimeError): + pass + + class SandboxPool: """Per-run sandbox client pool: lazily acquires + caches clients by name. @@ -873,7 +878,7 @@ async def get(self, name: str, *, record: StageRecord | None = None) -> Any: self.validate_name(name) spec = self._specs[name] try: - client, env_id = await self._acquire_ready(spec) + client, env_id = await self._acquire_ready(spec, name=name) except Exception as exc: if record is not None: record.status = StageStatus.FAILED @@ -931,30 +936,24 @@ def _url_of(client: Any) -> str | None: return str(val) return None - async def _acquire_ready(self, spec: SandboxSpec) -> tuple[Any, str]: - last_err: Exception | None = None - for attempt in range(1, self._max_attempts + 1): - try: - create_kwargs: dict[str, Any] = {} - if spec.key: - create_kwargs["key"] = spec.key - if spec.env_vars: - create_kwargs["env_vars"] = spec.env_vars - if spec.resources: - create_kwargs["resources"] = spec.resources - if self._create_limiter is not None: - await self._create_limiter.acquire() - client, env_id = await self._provider.create( - image_tag=spec.image, - ttl_seconds=spec.ttl_seconds, - **create_kwargs, - ) - except Exception as exc: - last_err = exc - await asyncio.sleep(min(2**attempt, 8)) - continue + async def _acquire_ready(self, spec: SandboxSpec, *, name: str | None = None) -> tuple[Any, str]: + async def _attempt_once() -> tuple[Any, str]: + create_kwargs: dict[str, Any] = {} + if spec.key: + create_kwargs["key"] = spec.key + if spec.env_vars: + create_kwargs["env_vars"] = spec.env_vars + if spec.resources: + create_kwargs["resources"] = spec.resources + if self._create_limiter is not None: + await self._create_limiter.acquire() + client, env_id = await self._provider.create( + image_tag=spec.image, + ttl_seconds=spec.ttl_seconds, + **create_kwargs, + ) - if await self._wait_healthy(client): + if await self._wait_healthy(client, name=name): return client, env_id try: @@ -971,29 +970,46 @@ async def _acquire_ready(self, spec: SandboxSpec) -> tuple[Any, str]: f"aclose of unhealthy sandbox env_id={env_id} failed:\n" f"{''.join(traceback.format_exception(type(exc), exc, exc.__traceback__)).rstrip()}" ) - last_err = RuntimeError(f"sandbox {env_id} unhealthy") - - last_err_msg = ( - "".join(traceback.format_exception(type(last_err), last_err, last_err.__traceback__)).rstrip() - if last_err is not None - else "unknown" + raise _UnhealthySandboxError(f"sandbox {env_id} unhealthy") + + retryer = retry_sandbox_acquire( + attempts=self._max_attempts, + unhealthy_exceptions=(_UnhealthySandboxError,), + logger=get_logger(), + sandbox_name=name, + sleep=asyncio.sleep, ) - raise RuntimeError(f"could not acquire a healthy sandbox after {self._max_attempts} attempts: {last_err_msg}") + try: + return await retryer(_attempt_once) + except Exception as exc: + last_err_msg = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)).rstrip() + raise RuntimeError( + f"could not acquire a healthy sandbox after {self._max_attempts} attempts: {last_err_msg}" + ) from exc - async def _wait_healthy(self, client: Any) -> bool: - deadline = time.monotonic() + self._health_max_wait_sec - while time.monotonic() < deadline: + async def _wait_healthy(self, client: Any, *, name: str | None = None) -> bool: + async def _poll_once() -> bool: try: h = await client.health_check() - if h.get("ok"): - return True + return bool(h.get("ok")) except Exception as exc: get_logger().debug( "health poll error:\n" f"{''.join(traceback.format_exception(type(exc), exc, exc.__traceback__)).rstrip()}" ) - await asyncio.sleep(self._health_poll_interval_sec) - return False + return False + + if self._health_max_wait_sec <= 0: + return False + + retryer = poll_sandbox_health( + max_wait_seconds=self._health_max_wait_sec, + wait_seconds=self._health_poll_interval_sec, + logger=get_logger(), + sandbox_name=name, + sleep=asyncio.sleep, + ) + return await retryer(_poll_once) # ───────────────────────────────────────────────────────────────── diff --git a/xtuner/v1/rl/rollout/trace_store.py b/xtuner/v1/rl/rollout/trace_store.py index cc98ba6619..fea9984a6b 100644 --- a/xtuner/v1/rl/rollout/trace_store.py +++ b/xtuner/v1/rl/rollout/trace_store.py @@ -6,6 +6,7 @@ from pydantic import BaseModel, ConfigDict, Field from xtuner.v1.utils import get_logger +from xtuner.v1.utils.retry_utils import retry_trace_store_bootstrap _STORE_NAME = "rollout_trace_store" @@ -334,6 +335,19 @@ def get_objects(self, keys: list[str]) -> list[ray.ObjectRef]: return [self.objects[key] for key in keys if key in self.objects] +def _create_or_lookup_store_once(): + try: + return RolloutTraceStore.options( + name=_STORE_NAME, + namespace=_STORE_NAMESPACE, + ).remote() + except ValueError as exc: + try: + return ray.get_actor(_STORE_NAME, namespace=_STORE_NAMESPACE) + except ValueError: + raise exc from None + + def get_store(): """Process-local cached handle to the singleton store actor. @@ -355,28 +369,18 @@ def get_store(): except ValueError: pass - import time as _time - - for attempt in range(10): - try: - _handle_cache = RolloutTraceStore.options( - name=_STORE_NAME, - namespace=_STORE_NAMESPACE, - ).remote() - return _handle_cache - except ValueError as exc: - try: - _handle_cache = ray.get_actor(_STORE_NAME, namespace=_STORE_NAMESPACE) - return _handle_cache - except ValueError: - get_logger().debug(f"RolloutTraceStore bootstrap retry {attempt}: {exc}") - _time.sleep(0.2 * (attempt + 1)) - continue - - raise RuntimeError( - f"RolloutTraceStore: failed to acquire named actor " - f"{_STORE_NAME!r} in namespace {_STORE_NAMESPACE!r} after retries" - ) + try: + retryer = retry_trace_store_bootstrap( + logger=get_logger(), + sleep=time.sleep, + ) + _handle_cache = retryer(_create_or_lookup_store_once) + return _handle_cache + except ValueError as exc: + raise RuntimeError( + f"RolloutTraceStore: failed to acquire named actor " + f"{_STORE_NAME!r} in namespace {_STORE_NAMESPACE!r} after retries" + ) from exc # reasoning rl may be not initialized diff --git a/xtuner/v1/rl/rollout/worker.py b/xtuner/v1/rl/rollout/worker.py index 8cd21fe2fd..4ba7446840 100644 --- a/xtuner/v1/rl/rollout/worker.py +++ b/xtuner/v1/rl/rollout/worker.py @@ -37,6 +37,7 @@ ) from xtuner.v1.utils import get_logger from xtuner.v1.utils.httpx_utils import HttpRequestErrorType, HttpRequestResult +from xtuner.v1.utils.retry_utils import retry_rollout_request from .session_server import SessionServerActor from .utils import ROLLOUT_RAY_GET_TIMEOUT, PartialRolloutHandler @@ -50,6 +51,20 @@ ROLLOUT_CONCURRENCY_GROUP_GENERATE = "generate" +class _RetryableRolloutRequestError(Exception): + def __init__(self, http_result: HttpRequestResult) -> None: + super().__init__( + f"retryable rollout request error {http_result.error_type} with message: {http_result.error_msg}" + ) + self.http_result = http_result + + +class _RetryableInvalidRolloutResponseError(Exception): + def __init__(self, reason: str = "Invalid rollout response") -> None: + super().__init__(reason) + self.reason = reason + + class RolloutConfig(BaseModel): """Rollout worker configuration for XTuner. @@ -728,8 +743,14 @@ async def generate(self, rollout_state: RolloutState) -> RolloutState: rollout_state.status = Status.COMPLETED return rollout_state - for attempt in range(max_retries + 1): - is_last_attempt = attempt == max_retries + def _prepare_before_retry(retry_state): + nonlocal rollout_state, payload + rollout_state, payload = self._prepare_request_payload( + rollout_state, request_max_tokens, discard_response=True + ) + + async def _attempt_once() -> None: + nonlocal rollout_state http_result = await self._safe_post_request(endpoint_url, headers=headers, payload=payload) # Case 1: HTTP Request is Successful @@ -742,33 +763,19 @@ async def generate(self, rollout_state: RolloutState) -> RolloutState: rollout_state.sample_params = rollout_state.sample_params.model_copy( update={"max_tokens": request_max_tokens} ) - return rollout_state + return if rollout_state.status == Status.COMPLETED: - return rollout_state + return if rollout_state.status == Status.ABORTED: rollout_state.sample_params = rollout_state.sample_params.model_copy( update={"max_tokens": request_max_tokens} ) - return rollout_state + return - if is_last_attempt: - # Case 1.2: Invalid rollout response and no retries left, so we return FAILED - self.logger.warning( - f"Invalid rollout response for request {uid} after {max_retries} attempts, marking as FAILED." - ) - rollout_state.status = Status.FAILED - rollout_state.error_msg = f"Invalid rollout response after {max_retries} attempts." - return rollout_state - - # Case 1.3: Invalid rollout response but we have retries left - self.logger.warning( - f"Invalid rollout response for request {uid}, retrying {attempt + 1}/{max_retries}." - ) - rollout_state, payload = self._prepare_request_payload( - rollout_state, request_max_tokens, discard_response=True + error_msg = rollout_state.error_msg or "response handler returned without completing or aborting" + raise _RetryableInvalidRolloutResponseError( + f"response handler returned status {rollout_state.status.value}: {error_msg}" ) - await asyncio.sleep(0.1) - continue # Case 2: Error occurred during HTTP Request if http_result.error_type == HttpRequestErrorType.REQUEST_ABORTED: @@ -778,7 +785,7 @@ async def generate(self, rollout_state: RolloutState) -> RolloutState: rollout_state.sample_params = rollout_state.sample_params.model_copy( update={"max_tokens": request_max_tokens} ) - return rollout_state + return if http_result.is_client_error: # Case 2.2: A non-retryable client error occurred (such as 4xx HTTP status) @@ -789,7 +796,7 @@ async def generate(self, rollout_state: RolloutState) -> RolloutState: f"Client error {http_result.error_type} with message: {http_result.error_msg}" ) rollout_state.status = Status.FAILED - return rollout_state + return if http_result.is_server_error: # Case 2.3: A non-retryable server error occurred (such as 5xx HTTP status) @@ -800,32 +807,48 @@ async def generate(self, rollout_state: RolloutState) -> RolloutState: f"Server error {http_result.error_type} with message: {http_result.error_msg}" ) rollout_state.status = Status.FAILED - return rollout_state + return # Case 3: Retryable error occurred during HTTP Request if http_result.is_retryable: - if is_last_attempt: - self.logger.warning( - f"rollout request {uid} to {http_result.url} failed after {max_retries} attempts due to retryable error {http_result.error_type} with {http_result.error_msg}" - ) - rollout_state.error_msg = f"Request failed after {max_retries} attempts due to retryable error {http_result.error_type} with message: {http_result.error_msg}" - rollout_state.status = Status.FAILED - return rollout_state - - self.logger.warning( - f"rollout request {uid} to {http_result.url} failed due to retryable error {http_result.error_type} with {http_result.error_msg}, retrying {attempt + 1}/{max_retries}." - ) - rollout_state, payload = self._prepare_request_payload( - rollout_state, request_max_tokens, discard_response=True - ) - await asyncio.sleep(0.1) - continue + raise _RetryableRolloutRequestError(http_result) # Case 4: Unknown error occurred during HTTP Request and stop the rollout if http_result.is_unknown_error: raise RuntimeError( f"Unexpected error during rollout request {uid} to {http_result.url}: {http_result.exception}" ) + + return + + retryer = retry_rollout_request( + attempts=max_retries + 1, + retry_exceptions=( + _RetryableRolloutRequestError, + _RetryableInvalidRolloutResponseError, + ), + logger=self.logger, + request_uid=uid, + before_retry=_prepare_before_retry, + sleep=asyncio.sleep, + ) + try: + await retryer(_attempt_once) + except _RetryableInvalidRolloutResponseError as e: + self.logger.warning( + f"Invalid rollout response for request {uid} after {max_retries} attempts, marking as FAILED. Last error: {e.reason}" + ) + rollout_state.status = Status.FAILED + rollout_state.error_msg = f"Invalid rollout response after {max_retries} attempts: {e.reason}" + return rollout_state + except _RetryableRolloutRequestError as e: + http_result = e.http_result + self.logger.warning( + f"rollout request {uid} to {http_result.url} failed after {max_retries} attempts due to retryable error {http_result.error_type} with {http_result.error_msg}" + ) + rollout_state.error_msg = f"Request failed after {max_retries} attempts due to retryable error {http_result.error_type} with message: {http_result.error_msg}" + rollout_state.status = Status.FAILED + return rollout_state return rollout_state finally: if rollout_state.status == Status.FAILED: @@ -1018,14 +1041,15 @@ async def _safe_handle_response(self, rollout_state: RolloutState, http_response self.logger.warning( f"finish_reason is missing in response meta_info when waiting for aborted message {uid}, defaulting to 'abort'. Response: {response}" ) + rollout_state.error_msg = "Missing finish_reason in response meta_info" + return rollout_state else: - rollout_state.finish_reason = "error" - rollout_state.status = Status.FAILED self.logger.warning( - f"finish_reason is missing in response meta_info for message {uid}, defaulting to 'error'. Response: {response}" + f"finish_reason is missing in response meta_info for message {uid}. Response: {response}" + ) + raise _RetryableInvalidRolloutResponseError( + f"Missing finish_reason in response meta_info for message {uid}" ) - rollout_state.error_msg = "Missing finish_reason in response meta_info" - return rollout_state returned_response = response.get("text", "") # 获取response_ids && respoonse_ids if ( @@ -1070,17 +1094,11 @@ async def _safe_handle_response(self, rollout_state: RolloutState, http_response if validation_errors: error_msg = f"Incomplete rollout data for msg {uid}: {', '.join(validation_errors)}" self.logger.error(error_msg) - rollout_state.routed_experts = routed_experts - rollout_state.status = Status.FAILED - rollout_state.error_msg = error_msg - return rollout_state + raise _RetryableInvalidRolloutResponseError(error_msg) elif rollout_status == Status.FAILED: error_msg = f"Rollout failed for msg {uid} with finish_reason {finish_reason}" self.logger.error(error_msg) - rollout_state.routed_experts = routed_experts - rollout_state.status = Status.FAILED - rollout_state.error_msg = error_msg - return rollout_state + raise _RetryableInvalidRolloutResponseError(error_msg) if self.enable_partial_rollout: prompt_tokens = response["meta_info"]["prompt_tokens"] @@ -1104,6 +1122,8 @@ async def _safe_handle_response(self, rollout_state: RolloutState, http_response rollout_state.finish_reason = finish_reason rollout_state.status = rollout_status return rollout_state + except _RetryableInvalidRolloutResponseError: + raise except KeyError as e: response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")} error_msg = f"Missing expected key {e} in response {response_for_log} for {uid}" @@ -1134,15 +1154,20 @@ async def _safe_handle_response(self, rollout_state: RolloutState, http_response finish_reason = response["choices"][0]["finish_reason"] rollout_status = update_status_from_finish_reason(finish_reason) if rollout_status == Status.COMPLETED and not returned_response: - self.logger.error(f"Empty response text for msg {uid} with finish_reason {finish_reason}") - rollout_state.status = Status.FAILED - rollout_state.error_msg = "Empty response text" - return rollout_state + error_msg = f"Empty response text for msg {uid} with finish_reason {finish_reason}" + self.logger.error(error_msg) + raise _RetryableInvalidRolloutResponseError(error_msg) + if rollout_status == Status.FAILED: + error_msg = f"Rollout failed for msg {uid} with finish_reason {finish_reason}" + self.logger.error(error_msg) + raise _RetryableInvalidRolloutResponseError(error_msg) rollout_state.response = returned_response rollout_state.finish_reason = finish_reason rollout_state.status = rollout_status return rollout_state + except _RetryableInvalidRolloutResponseError: + raise except KeyError as e: response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")} error_msg = f"Missing expected key {e} in response {response_for_log} for {uid}" diff --git a/xtuner/v1/utils/retry_utils.py b/xtuner/v1/utils/retry_utils.py new file mode 100644 index 0000000000..664f050988 --- /dev/null +++ b/xtuner/v1/utils/retry_utils.py @@ -0,0 +1,220 @@ +import asyncio +import time +from collections.abc import Awaitable, Callable, Mapping +from typing import Any, TypeVar + +from tenacity import ( + AsyncRetrying, + RetryCallState, + Retrying, + retry_if_exception_type, + retry_if_result, + stop_after_attempt, + stop_after_delay, + wait_exponential, + wait_fixed, + wait_incrementing, +) + + +_ResultT = TypeVar("_ResultT") + +XTUNER_ROLLOUT_RETRY_WAIT_SECONDS = 0.1 +XTUNER_SANDBOX_CREATE_RETRY_WAIT_MULTIPLIER = 2.0 +XTUNER_SANDBOX_CREATE_RETRY_WAIT_MIN_SECONDS = 2.0 +XTUNER_SANDBOX_CREATE_RETRY_WAIT_MAX_SECONDS = 8.0 +XTUNER_TRACE_STORE_RETRY_ATTEMPTS = 10 +XTUNER_TRACE_STORE_RETRY_WAIT_START_SECONDS = 0.2 +XTUNER_TRACE_STORE_RETRY_WAIT_INCREMENT_SECONDS = 0.2 +XTUNER_TRACE_STORE_RETRY_WAIT_MAX_SECONDS = 2.0 + + +def _retry_before_sleep_with_logging( + *, + operation: str, + logger: Any | None = None, + context: Mapping[str, Any] | None = None, + before_retry: Callable[[RetryCallState], None] | None = None, +) -> Callable[[RetryCallState], None]: + def _before_sleep(retry_state: RetryCallState) -> None: + if logger is not None: + assert retry_state.outcome is not None + exception = retry_state.outcome.exception() + next_action = getattr(retry_state, "next_action", None) + next_sleep = getattr(next_action, "sleep", None) + retry_in = f", retrying in {next_sleep:.2f}s" if next_sleep is not None else ", retrying" + if exception is None: + reason = f"result={retry_state.outcome.result()!r}" + else: + reason = f"{type(exception).__name__}: {exception}" + context_suffix = "" + if context: + context_suffix = " " + " ".join(f"{key}={value}" for key, value in context.items()) + logger.warning( + f"{operation} failed on attempt {retry_state.attempt_number}{context_suffix} with {reason}{retry_in}." + ) + if before_retry is not None: + before_retry(retry_state) + + return _before_sleep + + +def _build_async_exception_retryer( + *, + operation: str, + attempts: int, + retry_exceptions: type[BaseException] | tuple[type[BaseException], ...], + wait_strategy: Any, + logger: Any | None = None, + context: Mapping[str, Any] | None = None, + before_retry: Callable[[RetryCallState], None] | None = None, + sleep: Callable[[float], Awaitable[Any]] = asyncio.sleep, +) -> AsyncRetrying: + return AsyncRetrying( + stop=stop_after_attempt(attempts), + wait=wait_strategy, + retry=retry_if_exception_type(retry_exceptions), + before_sleep=_retry_before_sleep_with_logging( + operation=operation, + logger=logger, + context=context, + before_retry=before_retry, + ), + sleep=sleep, + reraise=True, + ) + + +def retry_rollout_request( + *, + attempts: int, + retry_exceptions: type[BaseException] | tuple[type[BaseException], ...], + logger: Any | None = None, + request_uid: Any | None = None, + wait_seconds: float | None = None, + before_retry: Callable[[RetryCallState], None] | None = None, + sleep: Callable[[float], Awaitable[Any]] = asyncio.sleep, +) -> AsyncRetrying: + wait_seconds = wait_seconds if wait_seconds is not None else XTUNER_ROLLOUT_RETRY_WAIT_SECONDS + context = {"uid": request_uid} if request_uid is not None else None + return _build_async_exception_retryer( + operation="rollout_request", + attempts=attempts, + retry_exceptions=retry_exceptions, + wait_strategy=wait_fixed(wait_seconds), + logger=logger, + context=context, + before_retry=before_retry, + sleep=sleep, + ) + + +def retry_sandbox_acquire( + *, + attempts: int, + unhealthy_exceptions: type[BaseException] | tuple[type[BaseException], ...], + logger: Any | None = None, + sandbox_name: str | None = None, + create_wait_multiplier: float | None = None, + create_wait_min_seconds: float | None = None, + create_wait_max_seconds: float | None = None, + before_retry: Callable[[RetryCallState], None] | None = None, + sleep: Callable[[float], Awaitable[Any]] = asyncio.sleep, +) -> AsyncRetrying: + create_wait_multiplier = ( + create_wait_multiplier if create_wait_multiplier is not None else XTUNER_SANDBOX_CREATE_RETRY_WAIT_MULTIPLIER + ) + create_wait_min_seconds = ( + create_wait_min_seconds + if create_wait_min_seconds is not None + else XTUNER_SANDBOX_CREATE_RETRY_WAIT_MIN_SECONDS + ) + create_wait_max_seconds = ( + create_wait_max_seconds + if create_wait_max_seconds is not None + else XTUNER_SANDBOX_CREATE_RETRY_WAIT_MAX_SECONDS + ) + context = {"sandbox": sandbox_name} if sandbox_name else None + create_wait = wait_exponential( + multiplier=create_wait_multiplier, + min=create_wait_min_seconds, + max=create_wait_max_seconds, + ) + + def _wait_strategy(retry_state: RetryCallState) -> float: + assert retry_state.outcome is not None + exception = retry_state.outcome.exception() + if exception is not None and isinstance(exception, unhealthy_exceptions): + return 0.0 + return create_wait(retry_state) + + return _build_async_exception_retryer( + operation="sandbox_acquire", + attempts=attempts, + retry_exceptions=(Exception,), + wait_strategy=_wait_strategy, + logger=logger, + context=context, + before_retry=before_retry, + sleep=sleep, + ) + + +def poll_sandbox_health( + *, + max_wait_seconds: float | None = None, + stop_strategy: Any | None = None, + wait_seconds: float, + logger: Any | None = None, + sandbox_name: str | None = None, + sleep: Callable[[float], Awaitable[Any]] = asyncio.sleep, +) -> AsyncRetrying: + if stop_strategy is None: + if max_wait_seconds is None: + raise ValueError("max_wait_seconds or stop_strategy must be provided") + stop_strategy = stop_after_delay(max_wait_seconds) + elif max_wait_seconds is not None: + raise ValueError("Only one of max_wait_seconds and stop_strategy may be provided") + context = {"sandbox": sandbox_name} if sandbox_name else None + return AsyncRetrying( + stop=stop_strategy, + wait=wait_fixed(wait_seconds), + retry=retry_if_result(lambda healthy: not healthy), + retry_error_callback=lambda retry_state: retry_state.outcome.result(), + before_sleep=_retry_before_sleep_with_logging( + operation="sandbox_health_poll", + logger=logger, + context=context, + ), + sleep=sleep, + ) + + +def retry_trace_store_bootstrap( + *, + attempts: int | None = None, + wait_start_seconds: float = XTUNER_TRACE_STORE_RETRY_WAIT_START_SECONDS, + wait_increment_seconds: float = XTUNER_TRACE_STORE_RETRY_WAIT_INCREMENT_SECONDS, + wait_max_seconds: float | None = None, + logger: Any | None = None, + before_retry: Callable[[RetryCallState], None] | None = None, + sleep: Callable[[float], Any] = time.sleep, +) -> Retrying: + attempts = attempts if attempts is not None else XTUNER_TRACE_STORE_RETRY_ATTEMPTS + wait_max_seconds = wait_max_seconds if wait_max_seconds is not None else XTUNER_TRACE_STORE_RETRY_WAIT_MAX_SECONDS + return Retrying( + stop=stop_after_attempt(attempts), + wait=wait_incrementing( + start=wait_start_seconds, + increment=wait_increment_seconds, + max=wait_max_seconds, + ), + retry=retry_if_exception_type((ValueError,)), + before_sleep=_retry_before_sleep_with_logging( + operation="trace_store_bootstrap", + logger=logger, + before_retry=before_retry, + ), + sleep=sleep, + reraise=True, + )