diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index 350507e89..1de1e7999 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -228,3 +228,8 @@ def get_added_mtp_kv_layer_num() -> int: @lru_cache(maxsize=None) def get_pd_split_max_new_tokens() -> int: return int(os.getenv("LIGHTLLM_PD_SPLIT_MAX_NEW_TOKENS", 2048)) + + +@lru_cache(maxsize=None) +def get_lightllm_url_pool_maxsize() -> int: + return int(os.getenv("LIGHTLLM_URL_POOL_MAXSIZE", 256)) diff --git a/lightllm/utils/multimodal_utils.py b/lightllm/utils/multimodal_utils.py index 876283b93..eaf598aef 100644 --- a/lightllm/utils/multimodal_utils.py +++ b/lightllm/utils/multimodal_utils.py @@ -6,12 +6,68 @@ from io import BytesIO from fastapi import Request from functools import lru_cache +from collections import OrderedDict +from typing import Awaitable, Callable, Dict, Optional, Tuple +import asyncio from lightllm.utils.error_utils import ClientDisconnected +from lightllm.utils.envs_utils import get_lightllm_url_pool_maxsize from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) +class UrlResourcePool: + def __init__(self, maxsize: int = 256): + self._maxsize = maxsize + self._cache: "OrderedDict[Tuple[str, Optional[str]], bytes]" = OrderedDict() + self._inflight: Dict[Tuple[str, Optional[str]], asyncio.Task] = {} + self._lock = asyncio.Lock() + + @staticmethod + def _normalize_url(url: str) -> str: + return url.strip() + + async def get_or_create(self, url: str, proxy: Optional[str], loader: Callable[[], Awaitable[bytes]]) -> bytes: + key = (self._normalize_url(url), proxy) + + async with self._lock: + cached = self._cache.get(key) + if cached is not None: + self._cache.move_to_end(key) + logger.info(f"url_pool hit") + return cached + + task = self._inflight.get(key) + if task is None: + + async def _run_and_cache() -> bytes: + try: + content = await loader() + async with self._lock: + self._cache[key] = content + self._cache.move_to_end(key) + while len(self._cache) > self._maxsize: + self._cache.popitem(last=False) + return content + finally: + async with self._lock: + self._inflight.pop(key, None) + + task = asyncio.create_task(_run_and_cache()) + + def _consume_task_exception(completed_task: asyncio.Task) -> None: + if not completed_task.cancelled(): + completed_task.exception() + + task.add_done_callback(_consume_task_exception) + self._inflight[key] = task + + return await asyncio.shield(task) + + +URL_RESOURCE_POOL = UrlResourcePool(maxsize=get_lightllm_url_pool_maxsize()) + + def _httpx_async_client_proxy_kwargs(proxy) -> dict: """ httpx 0.28+ 使用 AsyncClient(proxy=...);更早版本使用 proxies=... @@ -46,22 +102,26 @@ def _get_xhttp_client(proxy=None): async def fetch_resource(url, request: Request, timeout, proxy=None): logger.info(f"Begin to download resource from url: {url}") + if request is not None and await request.is_disconnected(): + raise ClientDisconnected(reason=f"client disconnected during url download") + start_time = time.time() - client = _get_xhttp_client(proxy) - async with client.stream("GET", url, timeout=timeout) as response: - response.raise_for_status() - ans_bytes = [] - async for chunk in response.aiter_bytes(chunk_size=1024 * 1024): - if request is not None and await request.is_disconnected(): - await response.aclose() - raise ClientDisconnected(reason=f"client disconnected during download of {url}") - ans_bytes.append(chunk) - # 接收的数据不能大于128M - if len(ans_bytes) > 128: - raise Exception(f"url {url} recv data is too big") - - content = b"".join(ans_bytes) - end_time = time.time() - cost_time = end_time - start_time - logger.info(f"Download url {url} resource cost time: {cost_time} seconds") - return content + + async def _load() -> bytes: + client = _get_xhttp_client(proxy) + async with client.stream("GET", url, timeout=timeout) as response: + response.raise_for_status() + ans_bytes = [] + async for chunk in response.aiter_bytes(chunk_size=1024 * 1024): + ans_bytes.append(chunk) + # 接收的数据不能大于128M + if len(ans_bytes) > 128: + raise Exception("url data is too big") + + content = b"".join(ans_bytes) + end_time = time.time() + cost_time = end_time - start_time + logger.info(f"url download done, cost={cost_time:.3f}s") + return content + + return await URL_RESOURCE_POOL.get_or_create(url, proxy, _load)