Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions lightllm/utils/envs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,8 @@ def get_added_mtp_kv_layer_num() -> int:
@lru_cache(maxsize=None)
def get_pd_split_max_new_tokens() -> int:
    """Return the LIGHTLLM_PD_SPLIT_MAX_NEW_TOKENS setting as an int.

    Falls back to 2048 when the variable is unset. The result is memoized by
    ``lru_cache``, so the environment is only consulted on the first call.
    """
    configured = os.environ.get("LIGHTLLM_PD_SPLIT_MAX_NEW_TOKENS", 2048)
    return int(configured)


@lru_cache(maxsize=None)
def get_lightllm_url_pool_maxsize() -> int:
    """Return the LIGHTLLM_URL_POOL_MAXSIZE setting as an int.

    Falls back to 256 when the variable is unset. The result is memoized by
    ``lru_cache``, so the environment is only consulted on the first call.
    """
    configured = os.environ.get("LIGHTLLM_URL_POOL_MAXSIZE", 256)
    return int(configured)
96 changes: 78 additions & 18 deletions lightllm/utils/multimodal_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,68 @@
from io import BytesIO
from fastapi import Request
from functools import lru_cache
from collections import OrderedDict
from typing import Awaitable, Callable, Dict, Optional, Tuple
import asyncio
from lightllm.utils.error_utils import ClientDisconnected
from lightllm.utils.envs_utils import get_lightllm_url_pool_maxsize
from lightllm.utils.log_utils import init_logger

# Logger for this module, obtained from init_logger with the module's __name__.
logger = init_logger(__name__)


class UrlResourcePool:
    """LRU store of downloaded payloads that also de-duplicates in-flight loads.

    Entries are keyed by ``(stripped url, proxy)``. Concurrent requests for a
    key that is still being loaded await the same task instead of invoking
    their own loader.
    """

    def __init__(self, maxsize: int = 256):
        # Maximum number of payloads kept; older entries are evicted first.
        self._maxsize = maxsize
        # Finished payloads, ordered from least to most recently used.
        self._cache: "OrderedDict[Tuple[str, Optional[str]], bytes]" = OrderedDict()
        # Loads that have been started but have not finished yet.
        self._inflight: Dict[Tuple[str, Optional[str]], asyncio.Task] = {}
        self._lock = asyncio.Lock()

    @staticmethod
    def _normalize_url(url: str) -> str:
        """Return ``url`` without leading or trailing whitespace."""
        return url.strip()

    @staticmethod
    def _swallow_task_exception(finished: asyncio.Task) -> None:
        """Mark a finished task's exception as retrieved.

        Waiters still receive the exception through their own ``await``.
        """
        if finished.cancelled():
            return
        finished.exception()

    async def _load_and_store(self, key: Tuple[str, Optional[str]], loader: Callable[[], Awaitable[bytes]]) -> bytes:
        """Run ``loader``, cache its result under ``key`` and trim the cache."""
        try:
            payload = await loader()
            async with self._lock:
                self._cache[key] = payload
                self._cache.move_to_end(key)
                # Drop least recently used entries until the size limit holds.
                while len(self._cache) > self._maxsize:
                    self._cache.popitem(last=False)
            return payload
        finally:
            # Success or failure, the key must stop counting as in flight.
            async with self._lock:
                self._inflight.pop(key, None)

    async def get_or_create(self, url: str, proxy: Optional[str], loader: Callable[[], Awaitable[bytes]]) -> bytes:
        """Return the payload for ``(url, proxy)``, loading it at most once at a time.

        Args:
            url: Resource address; surrounding whitespace is ignored for keying.
            proxy: Proxy setting, which is part of the key.
            loader: Zero-argument coroutine function that produces the payload.

        Returns:
            The cached payload, or the result of the (possibly shared) load.

        Raises:
            Whatever ``loader`` raises, for every caller awaiting that load.
        """
        key = (self._normalize_url(url), proxy)

        async with self._lock:
            hit = self._cache.get(key)
            if hit is not None:
                self._cache.move_to_end(key)
                logger.info("url_pool hit")
                return hit

            pending = self._inflight.get(key)
            if pending is None:
                pending = asyncio.create_task(self._load_and_store(key, loader))
                pending.add_done_callback(self._swallow_task_exception)
                self._inflight[key] = pending

        # shield() keeps the shared load running if this caller is cancelled.
        return await asyncio.shield(pending)


# Module-level pool instance used by fetch_resource; it is built at import time
# with the capacity returned by get_lightllm_url_pool_maxsize().
URL_RESOURCE_POOL = UrlResourcePool(maxsize=get_lightllm_url_pool_maxsize())


def _httpx_async_client_proxy_kwargs(proxy) -> dict:
"""
httpx 0.28+ 使用 AsyncClient(proxy=...);更早版本使用 proxies=...
Expand Down Expand Up @@ -46,22 +102,26 @@ def _get_xhttp_client(proxy=None):

async def fetch_resource(url, request: Request, timeout, proxy=None):
    """Download ``url`` through URL_RESOURCE_POOL and return the body as bytes.

    Args:
        url: Address fetched with an HTTP GET.
        request: Incoming FastAPI request, or None. It is polled once, before
            the download is scheduled, to detect a client that already left.
        timeout: Timeout forwarded to the HTTP client's ``stream`` call.
        proxy: Optional proxy forwarded to ``_get_xhttp_client`` and to the pool.

    Returns:
        The complete response body.

    Raises:
        ClientDisconnected: If ``request`` reports a disconnect before the
            download is scheduled.
        Exception: If the body exceeds 128 MiB. Errors raised by
            ``response.raise_for_status()`` propagate unchanged.
    """
    logger.info(f"Begin to download resource from url: {url}")
    if request is not None and await request.is_disconnected():
        raise ClientDisconnected(reason=f"client disconnected during url download")

    # Received data must not exceed 128M.
    max_bytes = 128 * 1024 * 1024

    async def _load() -> bytes:
        # The loader does not poll `request`: the pool may hand this same
        # download to several callers, so one caller leaving must not abort it.
        start_time = time.time()
        client = _get_xhttp_client(proxy)
        chunks = []
        received = 0
        async with client.stream("GET", url, timeout=timeout) as response:
            response.raise_for_status()
            async for chunk in response.aiter_bytes(chunk_size=1024 * 1024):
                chunks.append(chunk)
                # Count bytes rather than chunks so the cap does not depend on
                # how the transport happens to split the body.
                received += len(chunk)
                if received > max_bytes:
                    raise Exception("url data is too big")

        content = b"".join(chunks)
        cost_time = time.time() - start_time
        logger.info(f"url download done, cost={cost_time:.3f}s")
        return content

    return await URL_RESOURCE_POOL.get_or_create(url, proxy, _load)
Loading