From fd99194cb0ccd2fae08f6989b5ccd208df6575a1 Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Wed, 13 May 2026 11:20:46 +0800 Subject: [PATCH 1/2] feat: add flash attention 4 --- docker/Dockerfile | 2 + .../common/basemodel/attention/__init__.py | 1 + .../basemodel/attention/create_utils.py | 2 + .../basemodel/attention/fa4/__init__.py | 0 lightllm/common/basemodel/attention/fa4/fp.py | 149 ++++++++++++++++++ .../basemodel/attention_vit/create_utils.py | 2 + .../basemodel/attention_vit/fa4/__init__.py | 0 .../common/basemodel/attention_vit/fa4/fp.py | 38 +++++ lightllm/server/api_cli.py | 6 +- lightllm/server/api_start.py | 11 +- lightllm/server/core/objs/start_args_type.py | 6 +- lightllm/utils/backend_validator.py | 43 +++++ lightllm/utils/dist_check_utils.py | 2 +- lightllm/utils/fa4_utils.py | 46 ++++++ 14 files changed, 300 insertions(+), 8 deletions(-) create mode 100644 lightllm/common/basemodel/attention/fa4/__init__.py create mode 100644 lightllm/common/basemodel/attention/fa4/fp.py create mode 100644 lightllm/common/basemodel/attention_vit/fa4/__init__.py create mode 100644 lightllm/common/basemodel/attention_vit/fa4/fp.py create mode 100644 lightllm/utils/fa4_utils.py diff --git a/docker/Dockerfile b/docker/Dockerfile index bba404c965..f3f1885d5a 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -62,6 +62,8 @@ RUN export CPATH=/usr/local/cuda/targets/x86_64-linux/include/cccl:/usr/local/cu RUN apt-get update && apt-get install -y libnuma-dev && rm -rf /var/lib/apt/lists/* +RUN pip install --no-cache-dir "flash-attn-4==4.0.0b13" + ENV CUDA_HOME=/usr/local/cuda \ GDRCOPY_HOME=/usr/src/gdrdrv-2.4.4/ diff --git a/lightllm/common/basemodel/attention/__init__.py b/lightllm/common/basemodel/attention/__init__.py index 10cd3b0864..896bc63bdd 100644 --- a/lightllm/common/basemodel/attention/__init__.py +++ b/lightllm/common/basemodel/attention/__init__.py @@ -4,6 +4,7 @@ from .triton.int8kv import Int8kvTritonAttBackend from .triton.mla import MlaTritonAttBackend from .fa3.fp import Fa3AttBackend +from .fa4.fp import Fa4AttBackend from .fa3.fp8 import Fp8Fa3AttBackend from .fa3.mla import MlaFa3AttBackend from .flashinfer.fp8 import Fp8FlashInferAttBackend diff --git a/lightllm/common/basemodel/attention/create_utils.py b/lightllm/common/basemodel/attention/create_utils.py index 594e81a9b4..1191f8b088 100644 --- a/lightllm/common/basemodel/attention/create_utils.py +++ b/lightllm/common/basemodel/attention/create_utils.py @@ -9,6 +9,7 @@ from .triton.int8kv import Int8kvTritonAttBackend from .triton.mla import MlaTritonAttBackend from .fa3.fp import Fa3AttBackend +from .fa4.fp import Fa4AttBackend from .fa3.fp8 import Fp8Fa3AttBackend from .fa3.mla import MlaFa3AttBackend from .flashinfer.fp8 import Fp8FlashInferAttBackend @@ -24,6 +25,7 @@ "None": { "triton": TritonAttBackend, "fa3": Fa3AttBackend, + "fa4": Fa4AttBackend, "flashinfer": FlashInferAttBackend, }, "int4kv": { diff --git a/lightllm/common/basemodel/attention/fa4/__init__.py b/lightllm/common/basemodel/attention/fa4/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/attention/fa4/fp.py b/lightllm/common/basemodel/attention/fa4/fp.py new file mode 100644 index 0000000000..af2fb32723 --- /dev/null +++ b/lightllm/common/basemodel/attention/fa4/fp.py @@ -0,0 +1,149 @@ +import dataclasses +import torch + +from ..base_att import AttControl +from ..fa3.fp import Fa3AttBackend, Fa3PrefillAttState, Fa3DecodeAttState +from lightllm.utils.fa4_utils import ( + ensure_fa4_available, + ensure_fa4_supported_gpu, + flash_attn_varlen_func, + unwrap_fa4_output, +) + + +class Fa4AttBackend(Fa3AttBackend): + def __init__(self, model): + ensure_fa4_available() + ensure_fa4_supported_gpu() + super().__init__(model=model) + + def create_att_prefill_state(self, infer_state) -> "Fa4PrefillAttState": + return Fa4PrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "Fa4DecodeAttState": + return Fa4DecodeAttState(backend=self, infer_state=infer_state) + + +def _sm90_fa4_paged_kv_tile_n( + head_dim: int, + head_dim_v: int, + window_size: tuple[int, int], +) -> int | None: + major, _minor = torch.cuda.get_device_capability() + if major != 9: + return None + + is_local = window_size != (-1, -1) + if head_dim <= 64: + return 128 + if head_dim <= 96: + return 128 if is_local else 144 + if head_dim <= 128: + return 128 + if head_dim <= 192: + return 96 if is_local else (128 if head_dim_v <= 128 else 112) + return 64 if is_local else 80 + + +def _ensure_fa4_paged_kv_supported( + head_dim: int, + head_dim_v: int, + window_size: tuple[int, int], + page_size: int, +) -> None: + tile_n = _sm90_fa4_paged_kv_tile_n(head_dim, head_dim_v, window_size) + if tile_n is None or page_size == tile_n or tile_n >= 128: + return + + raise RuntimeError( + "FA4 SM90 paged KV requires page_size == tile_n for this shape; " + f"current page_size={page_size}, required_page_size={tile_n}, " + f"head_dim={head_dim}, head_dim_v={head_dim_v}, window_size={window_size}. " + "LightLLM's current FA4 wrapper uses token-granular KV pages, so this shape would need " + "the removed repack fallback to run. Please set the FA4 KV cache page size to " + f"{tile_n} tokens for this model/shape, or switch --llm_prefill_att_backend/" + "--llm_decode_att_backend to another backend." + ) + + +@dataclasses.dataclass +class Fa4PrefillAttState(Fa3PrefillAttState): + def _normal_prefill_att( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, att_control: AttControl, alloc_func=torch.empty + ) -> torch.Tensor: + if att_control.use_sliding_window: + window_size = att_control.sliding_window + else: + window_size = (-1, -1) + + if att_control.use_att_sink: + sink_weight = att_control.sink_weight + else: + sink_weight = None + + head_dim = q.shape[-1] + head_dim_v = v.shape[-1] + softmax_scale = 1.0 / (head_dim ** 0.5) + _ensure_fa4_paged_kv_supported(head_dim, head_dim_v, window_size, page_size=1) + + out = flash_attn_varlen_func( + q=q, + k=k.view(k.shape[0], 1, k.shape[1], k.shape[2]), + v=v.view(v.shape[0], 1, v.shape[1], v.shape[2]), + cu_seqlens_q=self.cu_seqlens_q, + seqused_k=self.infer_state.b_seq_len.int(), + max_seqlen_q=self.infer_state.max_q_seq_len, + max_seqlen_k=self.infer_state.max_kv_seq_len, + page_table=self.page_table, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + learnable_sink=sink_weight, + softcap=0.0, + return_lse=False, + ) + return unwrap_fa4_output(out) + + +@dataclasses.dataclass +class Fa4DecodeAttState(Fa3DecodeAttState): + def _normal_decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl, + alloc_func=torch.empty, + ): + if att_control.use_sliding_window: + window_size = att_control.sliding_window + else: + window_size = (-1, -1) + + if att_control.use_att_sink: + sink_weight = att_control.sink_weight + else: + sink_weight = None + + head_dim = q.shape[-1] + head_dim_v = v.shape[-1] + softmax_scale = 1.0 / (head_dim ** 0.5) + _ensure_fa4_paged_kv_supported(head_dim, head_dim_v, window_size, page_size=1) + + out = flash_attn_varlen_func( + q=q, + k=k.view(k.shape[0], 1, k.shape[1], k.shape[2]), + v=v.view(v.shape[0], 1, v.shape[1], v.shape[2]), + cu_seqlens_q=self.cu_seqlens_q, + seqused_k=self.b_att_seq_len.int(), + max_seqlen_q=self.decode_max_q_seq_len, + max_seqlen_k=self.infer_state.max_kv_seq_len, + page_table=self.page_table, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + learnable_sink=sink_weight, + softcap=0.0, + return_lse=False, + ) + return unwrap_fa4_output(out) diff --git a/lightllm/common/basemodel/attention_vit/create_utils.py b/lightllm/common/basemodel/attention_vit/create_utils.py index 67f830ba0d..c4a56dd4db 100644 --- a/lightllm/common/basemodel/attention_vit/create_utils.py +++ b/lightllm/common/basemodel/attention_vit/create_utils.py @@ -4,6 +4,7 @@ from lightllm.utils.backend_validator import _validate from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend from lightllm.common.basemodel.attention_vit.fa3.fp import Fa3VitAttBackend +from lightllm.common.basemodel.attention_vit.fa4.fp import Fa4VitAttBackend from lightllm.common.basemodel.attention_vit.triton.fp import TritonVitAttBackend from lightllm.common.basemodel.attention_vit.sdpa.fp import SdpaVitAttBackend from lightllm.common.basemodel.attention_vit.xformers.fp import XformersVitAttBackend @@ -15,6 +16,7 @@ "triton": TritonVitAttBackend, "sdpa": SdpaVitAttBackend, "fa3": Fa3VitAttBackend, + "fa4": Fa4VitAttBackend, "xformers": XformersVitAttBackend, } diff --git a/lightllm/common/basemodel/attention_vit/fa4/__init__.py b/lightllm/common/basemodel/attention_vit/fa4/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/attention_vit/fa4/fp.py b/lightllm/common/basemodel/attention_vit/fa4/fp.py new file mode 100644 index 0000000000..a685fc4972 --- /dev/null +++ b/lightllm/common/basemodel/attention_vit/fa4/fp.py @@ -0,0 +1,38 @@ +import torch + +from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend +from lightllm.utils.fa4_utils import ( + ensure_fa4_available, + ensure_fa4_supported_gpu, + _flash_attn_fwd, +) + + +class Fa4VitAttBackend(BaseVitAttBackend): + def __init__(self): + ensure_fa4_available() + ensure_fa4_supported_gpu() + + @staticmethod + def _vit_att_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + ) -> None: + head_dim = q.shape[-1] + return _flash_attn_fwd( + out=o, + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + softmax_scale=head_dim ** -0.5, + causal=False, + return_lse=False, + ) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 2db6c67e77..2037a5cb63 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -395,7 +395,7 @@ def make_argument_parser() -> argparse.ArgumentParser: "--llm_prefill_att_backend", type=str, nargs="+", - choices=["auto", "triton", "fa3", "flashinfer"], + choices=["auto", "triton", "fa3", "fa4", "flashinfer"], default=["auto"], help="""prefill attention kernel used in llm. auto: automatically select best backend based on GPU and available packages @@ -405,7 +405,7 @@ def make_argument_parser() -> argparse.ArgumentParser: "--llm_decode_att_backend", type=str, nargs="+", - choices=["auto", "triton", "fa3", "flashinfer"], + choices=["auto", "triton", "fa3", "fa4", "flashinfer"], default=["auto"], help="""decode attention kernel used in llm. auto: automatically select best backend based on GPU and available packages @@ -415,7 +415,7 @@ def make_argument_parser() -> argparse.ArgumentParser: "--vit_att_backend", type=str, nargs="+", - choices=["auto", "triton", "fa3", "sdpa", "xformers"], + choices=["auto", "triton", "fa3", "fa4", "sdpa", "xformers"], default=["auto"], help="""vit attention kernel used in vlm. auto: automatically select best backend based on GPU and available packages diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 654ba0f3e5..3a29566283 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -196,7 +196,7 @@ def normal_or_p_d_start(args): assert args.enable_tpsp_mix_mode and args.dp > 1, "need set --enable_tpsp_mix_mode firstly and --dp > 1" if args.enable_ep_moe: - allowed_ep_att_backends = {"auto", "fa3", "triton"} + allowed_ep_att_backends = {"auto", "fa3", "fa4", "triton"} for backend in args.llm_prefill_att_backend: assert backend in allowed_ep_att_backends, ( "When --enable_ep_moe is enabled, --llm_prefill_att_backend must be one of " @@ -208,6 +208,15 @@ def normal_or_p_d_start(args): f"{sorted(allowed_ep_att_backends)}; flashinfer is not supported." ) + requested_backends = ( + list(args.llm_prefill_att_backend) + list(args.llm_decode_att_backend) + list(args.vit_att_backend) + ) + if "fa4" in requested_backends: + from lightllm.utils.fa4_utils import ensure_fa4_available, ensure_fa4_supported_gpu + + ensure_fa4_available() + ensure_fa4_supported_gpu() + # mtp params check if args.mtp_mode is not None: assert args.mtp_draft_model_dir is not None diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 6d0ee07465..ca325b2f13 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -137,13 +137,13 @@ class StartArgs: vit_quant_type: Optional[str] = field(default=None) vit_quant_cfg: Optional[str] = field(default=None) llm_prefill_att_backend: List[str] = field( - default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "flashinfer"]} + default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "fa4", "flashinfer"]} ) llm_decode_att_backend: List[str] = field( - default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "flashinfer"]} + default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "fa4", "flashinfer"]} ) vit_att_backend: List[str] = field( - default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "sdpa", "xformers"]} + default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "fa4", "sdpa", "xformers"]} ) llm_kv_type: str = field( default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv_sph", "fp8kv_spt", "fp8kv_dsa"]} diff --git a/lightllm/utils/backend_validator.py b/lightllm/utils/backend_validator.py index ab5c0a88a1..039b2ed64b 100644 --- a/lightllm/utils/backend_validator.py +++ b/lightllm/utils/backend_validator.py @@ -59,6 +59,47 @@ def _validate_fa3(): return True, None +def _validate_fa4(): + """Validate FA4 with ground truth.""" + from lightllm.utils.fa4_utils import flash_attn_varlen_func, is_fa4_supported_gpu, unwrap_fa4_output + + if not is_fa4_supported_gpu(): + return False, "FA4 requires Hopper/Blackwell-class GPU" + if flash_attn_varlen_func is None: + return False, "flash_attn_varlen_func is None" + + batch, heads, seq, dim = 1, 4, 8, 64 + q = torch.randn(batch, heads, seq, dim, dtype=torch.bfloat16, device="cuda") + k = torch.randn(batch, heads, seq, dim, dtype=torch.bfloat16, device="cuda") + v = torch.randn(batch, heads, seq, dim, dtype=torch.bfloat16, device="cuda") + + expected = _compute_ground_truth(q, k, v) + + q_flat = q.transpose(1, 2).reshape(batch * seq, heads, dim) + k_flat = k.transpose(1, 2).reshape(batch * seq, heads, dim) + v_flat = v.transpose(1, 2).reshape(batch * seq, heads, dim) + cu_seqlens = torch.arange(0, batch * seq + 1, seq, dtype=torch.int32, device="cuda") + + out = flash_attn_varlen_func( + q=q_flat, + k=k_flat, + v=v_flat, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=seq, + max_seqlen_k=seq, + softmax_scale=1.0 / (dim ** 0.5), + causal=True, + return_lse=False, + ) + out = unwrap_fa4_output(out).reshape(batch, seq, heads, dim).transpose(1, 2) + torch.cuda.synchronize() + + if not torch.allclose(out, expected, rtol=1e-2, atol=1e-2): + return False, f"Output mismatch: max diff {(out - expected).abs().max().item():.6f}" + return True, None + + def _validate_flashinfer(): """Validate FlashInfer with ground truth.""" capability = torch.cuda.get_device_capability() @@ -242,6 +283,8 @@ def _run_in_subprocess(backend_name, pipe): try: if backend_name == "fa3": success, err = _validate_fa3() + elif backend_name == "fa4": + success, err = _validate_fa4() elif backend_name == "xformers": success, err = _validate_xformers() elif backend_name == "sdpa": diff --git a/lightllm/utils/dist_check_utils.py b/lightllm/utils/dist_check_utils.py index 12b0b81993..0f2548be8d 100644 --- a/lightllm/utils/dist_check_utils.py +++ b/lightllm/utils/dist_check_utils.py @@ -17,7 +17,7 @@ logger = init_logger(__name__) _CUSTOM_ALLREDUCE_WORLD_SIZES = (2, 4, 6, 8) -_TWO_GPU_CHECK_TIMEOUT_SECONDS = 600.0 +_TWO_GPU_CHECK_TIMEOUT_SECONDS = 600.0 # 给flashinfer jit编译预留足够时间 def _start_two_gpu_check_timeout_watchdog(backend_name: str) -> threading.Event: diff --git a/lightllm/utils/fa4_utils.py b/lightllm/utils/fa4_utils.py new file mode 100644 index 0000000000..d7dfae722f --- /dev/null +++ b/lightllm/utils/fa4_utils.py @@ -0,0 +1,46 @@ +import torch + +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + +try: + from flash_attn.cute import flash_attn_varlen_func + from flash_attn.cute.interface import _flash_attn_fwd + + HAS_FA4 = True +except Exception: + flash_attn_varlen_func = None + _flash_attn_fwd = None + HAS_FA4 = False + logger.warning("flash-attn-4 is not installed") + + +def is_fa4_supported_gpu() -> bool: + if not torch.cuda.is_available(): + return False + major, _minor = torch.cuda.get_device_capability() + return major in (9, 10, 11, 12) + + +def ensure_fa4_available() -> None: + if not HAS_FA4: + raise ImportError( + "flash-attn-4 is unavailable. Install it first, e.g. `pip install flash-attn-4`, " + "or install from the local flash-attention repo." + ) + + +def ensure_fa4_supported_gpu() -> None: + if not torch.cuda.is_available(): + raise RuntimeError("FA4 backend requires CUDA, but CUDA is not available.") + major, minor = torch.cuda.get_device_capability() + if major not in (9, 10, 11, 12): + raise RuntimeError( + f"FA4 backend requires Hopper/Blackwell-class GPUs (SM90/SM100/SM110/SM120). " + f"Current device capability is {major}.{minor}." + ) + + +def unwrap_fa4_output(output): + return output[0] if isinstance(output, tuple) else output From db91dd036b030cda5067b40f19e03fee57380188 Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Fri, 6 Mar 2026 19:23:31 +0800 Subject: [PATCH 2/2] feat: page size > 1 support --- docker/Dockerfile | 2 - .../basemodel/attention/create_utils.py | 18 +- lightllm/common/basemodel/attention/fa3/fp.py | 4 +- lightllm/common/basemodel/attention/fa4/fp.py | 36 +- .../basemodel/attention/flashinfer/fp.py | 4 +- .../basemodel/attention/paged_fa3/__init__.py | 0 .../basemodel/attention/paged_fa3/fp.py | 188 +++++++ .../basemodel/attention/paged_fa3/mla.py | 174 ++++++ .../attention/paged_flashinfer/__init__.py | 0 .../attention/paged_flashinfer/fp.py | 193 +++++++ .../attention/paged_flashinfer/mla.py | 184 ++++++ .../common/basemodel/attention/triton/fp.py | 4 +- .../basemodel/triton_kernel/fa3_utils.py | 8 + .../triton_kernel/repack_kv_index.py | 63 +++ .../deepseek2_mem_manager.py | 6 +- .../kv_cache_mem_manager/mem_manager.py | 82 ++- lightllm/common/req_manager.py | 17 +- lightllm/server/api_start.py | 42 +- .../dynamic_prompt/paged_radix_cache.py | 532 ++++++++++++++++++ .../server/router/model_infer/infer_batch.py | 13 + .../model_infer/mode_backend/base_backend.py | 5 +- .../generic_padded_pre_process.py | 30 +- .../mode_backend/generic_pre_process.py | 24 +- .../router/req_queue/chunked_prefill/impl.py | 5 +- lightllm/utils/dist_check_utils.py | 2 +- lightllm/utils/envs_utils.py | 10 + lightllm/utils/fa4_utils.py | 40 ++ requirements.txt | 1 + .../triton_kernel/test_repack_kv_index.py | 49 +- 29 files changed, 1674 insertions(+), 62 deletions(-) create mode 100644 lightllm/common/basemodel/attention/paged_fa3/__init__.py create mode 100644 lightllm/common/basemodel/attention/paged_fa3/fp.py create mode 100644 lightllm/common/basemodel/attention/paged_fa3/mla.py create mode 100644 lightllm/common/basemodel/attention/paged_flashinfer/__init__.py create mode 100644 lightllm/common/basemodel/attention/paged_flashinfer/fp.py create mode 100644 lightllm/common/basemodel/attention/paged_flashinfer/mla.py create mode 100644 lightllm/server/router/dynamic_prompt/paged_radix_cache.py diff --git a/docker/Dockerfile b/docker/Dockerfile index f3f1885d5a..bba404c965 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -62,8 +62,6 @@ RUN export CPATH=/usr/local/cuda/targets/x86_64-linux/include/cccl:/usr/local/cu RUN apt-get update && apt-get install -y libnuma-dev && rm -rf /var/lib/apt/lists/* -RUN pip install --no-cache-dir "flash-attn-4==4.0.0b13" - ENV CUDA_HOME=/usr/local/cuda \ GDRCOPY_HOME=/usr/src/gdrdrv-2.4.4/ diff --git a/lightllm/common/basemodel/attention/create_utils.py b/lightllm/common/basemodel/attention/create_utils.py index 1191f8b088..dda6b4cb94 100644 --- a/lightllm/common/basemodel/attention/create_utils.py +++ b/lightllm/common/basemodel/attention/create_utils.py @@ -1,5 +1,5 @@ """Attention backend selection utilities.""" -from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.envs_utils import get_env_start_args, get_page_size from lightllm.utils.log_utils import init_logger from lightllm.utils.backend_validator import validate from typing import Dict @@ -12,21 +12,27 @@ from .fa4.fp import Fa4AttBackend from .fa3.fp8 import Fp8Fa3AttBackend from .fa3.mla import MlaFa3AttBackend +from .paged_fa3.fp import PagedFa3AttBackend +from .paged_fa3.mla import PagedMlaFa3AttBackend from .flashinfer.fp8 import Fp8FlashInferAttBackend from .flashinfer.fp import FlashInferAttBackend from .flashinfer.mla import MlaFlashInferAttBackend from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend from .nsa.fp8_flashmla_sparse import NsaFlashMlaFp8SparseAttBackend +from .paged_flashinfer.fp import PagedFlashInferAttBackend +from .paged_flashinfer.mla import PagedMlaFlashInferAttBackend logger = init_logger(__name__) +_PAGE_ENABLED = get_page_size() > 1 + # Backend class mappings by data type data_type_to_backend = { "None": { - "triton": TritonAttBackend, - "fa3": Fa3AttBackend, + "triton": TritonAttBackend, # triton backend supports arbitrary page size + "fa3": PagedFa3AttBackend if _PAGE_ENABLED else Fa3AttBackend, "fa4": Fa4AttBackend, - "flashinfer": FlashInferAttBackend, + "flashinfer": PagedFlashInferAttBackend if _PAGE_ENABLED else FlashInferAttBackend, }, "int4kv": { "triton": Int4kvTritonAttBackend, @@ -49,8 +55,8 @@ mla_data_type_to_backend = { "None": { "triton": MlaTritonAttBackend, - "fa3": MlaFa3AttBackend, - "flashinfer": MlaFlashInferAttBackend, + "fa3": PagedMlaFa3AttBackend if _PAGE_ENABLED else MlaFa3AttBackend, + "flashinfer": PagedMlaFlashInferAttBackend if _PAGE_ENABLED else MlaFlashInferAttBackend, }, } diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index 952bb39d91..9568e4a892 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -66,7 +66,7 @@ def prefill_att( alloc_func=torch.empty, ) -> torch.Tensor: assert att_control.use_alibi is False - return self._nomarl_prefill_att( + return self._normal_prefill_att( q=q, k=k, v=v, @@ -74,7 +74,7 @@ def prefill_att( alloc_func=alloc_func, ) - def _nomarl_prefill_att( + def _normal_prefill_att( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, att_control: AttControl, alloc_func=torch.empty ) -> torch.Tensor: self.backend: Fa3AttBackend = self.backend # for typing diff --git a/lightllm/common/basemodel/attention/fa4/fp.py b/lightllm/common/basemodel/attention/fa4/fp.py index af2fb32723..dbc0f4ab1e 100644 --- a/lightllm/common/basemodel/attention/fa4/fp.py +++ b/lightllm/common/basemodel/attention/fa4/fp.py @@ -2,16 +2,17 @@ import torch from ..base_att import AttControl -from ..fa3.fp import Fa3AttBackend, Fa3PrefillAttState, Fa3DecodeAttState +from ..paged_fa3.fp import PagedFa3AttBackend, PagedFa3PrefillAttState, PagedFa3DecodeAttState from lightllm.utils.fa4_utils import ( ensure_fa4_available, ensure_fa4_supported_gpu, flash_attn_varlen_func, + sm90_fa4_paged_kv_tile_n, unwrap_fa4_output, ) -class Fa4AttBackend(Fa3AttBackend): +class Fa4AttBackend(PagedFa3AttBackend): def __init__(self, model): ensure_fa4_available() ensure_fa4_supported_gpu() @@ -29,20 +30,7 @@ def _sm90_fa4_paged_kv_tile_n( head_dim_v: int, window_size: tuple[int, int], ) -> int | None: - major, _minor = torch.cuda.get_device_capability() - if major != 9: - return None - - is_local = window_size != (-1, -1) - if head_dim <= 64: - return 128 - if head_dim <= 96: - return 128 if is_local else 144 - if head_dim <= 128: - return 128 - if head_dim <= 192: - return 96 if is_local else (128 if head_dim_v <= 128 else 112) - return 64 if is_local else 80 + return sm90_fa4_paged_kv_tile_n(head_dim=head_dim, head_dim_v=head_dim_v, window_size=window_size) def _ensure_fa4_paged_kv_supported( @@ -67,7 +55,7 @@ def _ensure_fa4_paged_kv_supported( @dataclasses.dataclass -class Fa4PrefillAttState(Fa3PrefillAttState): +class Fa4PrefillAttState(PagedFa3PrefillAttState): def _normal_prefill_att( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, att_control: AttControl, alloc_func=torch.empty ) -> torch.Tensor: @@ -84,12 +72,12 @@ def _normal_prefill_att( head_dim = q.shape[-1] head_dim_v = v.shape[-1] softmax_scale = 1.0 / (head_dim ** 0.5) - _ensure_fa4_paged_kv_supported(head_dim, head_dim_v, window_size, page_size=1) + _ensure_fa4_paged_kv_supported(head_dim, head_dim_v, window_size, page_size=self.backend.page_size) out = flash_attn_varlen_func( q=q, - k=k.view(k.shape[0], 1, k.shape[1], k.shape[2]), - v=v.view(v.shape[0], 1, v.shape[1], v.shape[2]), + k=k.view(-1, self.backend.page_size, k.shape[1], k.shape[2]), + v=v.view(-1, self.backend.page_size, v.shape[1], v.shape[2]), cu_seqlens_q=self.cu_seqlens_q, seqused_k=self.infer_state.b_seq_len.int(), max_seqlen_q=self.infer_state.max_q_seq_len, @@ -106,7 +94,7 @@ def _normal_prefill_att( @dataclasses.dataclass -class Fa4DecodeAttState(Fa3DecodeAttState): +class Fa4DecodeAttState(PagedFa3DecodeAttState): def _normal_decode_att( self, q: torch.Tensor, @@ -128,12 +116,12 @@ def _normal_decode_att( head_dim = q.shape[-1] head_dim_v = v.shape[-1] softmax_scale = 1.0 / (head_dim ** 0.5) - _ensure_fa4_paged_kv_supported(head_dim, head_dim_v, window_size, page_size=1) + _ensure_fa4_paged_kv_supported(head_dim, head_dim_v, window_size, page_size=self.backend.page_size) out = flash_attn_varlen_func( q=q, - k=k.view(k.shape[0], 1, k.shape[1], k.shape[2]), - v=v.view(v.shape[0], 1, v.shape[1], v.shape[2]), + k=k.view(-1, self.backend.page_size, k.shape[1], k.shape[2]), + v=v.view(-1, self.backend.page_size, v.shape[1], v.shape[2]), cu_seqlens_q=self.cu_seqlens_q, seqused_k=self.b_att_seq_len.int(), max_seqlen_q=self.decode_max_q_seq_len, diff --git a/lightllm/common/basemodel/attention/flashinfer/fp.py b/lightllm/common/basemodel/attention/flashinfer/fp.py index 91a004ec2e..37478be76f 100644 --- a/lightllm/common/basemodel/attention/flashinfer/fp.py +++ b/lightllm/common/basemodel/attention/flashinfer/fp.py @@ -99,14 +99,14 @@ def prefill_att( and att_control.use_sliding_window is False and att_control.use_att_sink is False ) - return self._nomarl_prefill_att( + return self._normal_prefill_att( q=q, k=k, v=v, alloc_func=alloc_func, ) - def _nomarl_prefill_att( + def _normal_prefill_att( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, alloc_func=torch.empty ) -> torch.Tensor: self.backend: FlashInferAttBackend = self.backend # for typing diff --git a/lightllm/common/basemodel/attention/paged_fa3/__init__.py b/lightllm/common/basemodel/attention/paged_fa3/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/attention/paged_fa3/fp.py b/lightllm/common/basemodel/attention/paged_fa3/fp.py new file mode 100644 index 0000000000..5c01538c42 --- /dev/null +++ b/lightllm/common/basemodel/attention/paged_fa3/fp.py @@ -0,0 +1,188 @@ +import dataclasses +import torch +import triton +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from lightllm.utils.dist_utils import get_current_device_id +from lightllm.utils.sgl_utils import flash_attn_with_kvcache +from lightllm.utils.envs_utils import get_env_start_args, get_page_size +from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy +from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor + + +class PagedFa3AttBackend(BaseAttBackend): + def __init__(self, model, page_size=None): + super().__init__(model=model) + self.page_size = page_size or get_page_size() + self.get_page_table_buffer() + + def get_page_table_buffer(self): + model = self.model + if not hasattr(self, "_shared_page_table_buffer"): + shared_len = model.graph_max_batch_size * triton.cdiv(model.graph_max_len_in_batch, self.page_size) + self._shared_page_table_buffer = [ + torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()), + torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()), + ] + return self._shared_page_table_buffer + + def create_att_prefill_state(self, infer_state): + return PagedFa3PrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state): + return PagedFa3DecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class PagedFa3PrefillAttState(BasePrefillAttState): + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + page_table: torch.Tensor = None + + def init_state(self): + self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + table_len = triton.cdiv(self.infer_state.max_kv_seq_len, self.backend.page_size) + self.page_table = torch.empty( + (self.infer_state.batch_size, table_len), + dtype=torch.int32, + device=self.infer_state.input_ids.device, + ) + page_table_copy( + page_table=self.page_table, + req_to_token_indexs=self.infer_state.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx, + ) + + def prefill_att(self, q, k, v, att_control: AttControl = AttControl(), alloc_func=torch.empty): + assert att_control.use_alibi is False + return self._normal_prefill_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) + + def _normal_prefill_att(self, q, k, v, att_control: AttControl, alloc_func=torch.empty): + if att_control.use_sliding_window: + window_size = att_control.sliding_window + else: + window_size = (-1, -1) + + if att_control.use_att_sink: + sink_weight = att_control.sink_weight + else: + sink_weight = None + + sm_scale = 1.0 / (q.shape[-1] ** 0.5) + return flash_attn_with_kvcache( + q=q, + k_cache=k.view(-1, self.backend.page_size, k.shape[1], k.shape[2]), + v_cache=v.view(-1, self.backend.page_size, v.shape[1], v.shape[2]), + page_table=self.page_table, + cache_seqlens=self.infer_state.b_seq_len, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k_new=self.cu_seqlens_k, + max_seqlen_q=self.infer_state.max_q_seq_len, + softmax_scale=sm_scale, + causal=True, + window_size=window_size, + softcap=0.0, + k_descale=None, + v_descale=None, + return_softmax_lse=False, + sinks=sink_weight, + ) + + +@dataclasses.dataclass +class PagedFa3DecodeAttState(BaseDecodeAttState): + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + page_table: torch.Tensor = None + b_att_seq_len: torch.Tensor = None + decode_max_q_seq_len: int = None + + def init_state(self): + args_mtp_step = get_env_start_args().mtp_step + if args_mtp_step > 0: + mtp_size = args_mtp_step + 1 + b_q_seq_len = torch.full( + (self.infer_state.b_seq_len.shape[0] // mtp_size,), + fill_value=mtp_size, + dtype=torch.int32, + device=self.infer_state.b_seq_len.device, + ) + b_kv_seq_len = self.infer_state.b_seq_len[mtp_size - 1 :: mtp_size] + b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len) + self.cu_seqlens_q = b1_cu_q_seq_len.int() + self.cu_seqlens_k = b1_cu_kv_seq_len.int() + else: + self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + + att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1) + assert self.infer_state.batch_size % (args_mtp_step + 1) == 0 + model = self.backend.model + table_len = triton.cdiv(self.infer_state.max_kv_seq_len, self.backend.page_size) + if ( + self.infer_state.batch_size <= model.graph_max_batch_size + and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch + ): + page_buffer = self.backend.get_page_table_buffer() + shared_table_len = triton.cdiv(model.graph_max_len_in_batch, self.backend.page_size) + self.page_table = page_buffer[self.infer_state.microbatch_index][ + : att_batch_size * shared_table_len + ].reshape(att_batch_size, shared_table_len) + else: + self.page_table = torch.empty( + (att_batch_size, table_len), + dtype=torch.int32, + device=self.infer_state.input_ids.device, + ) + + if args_mtp_step > 0: + page_table_copy( + page_table=self.page_table[:, :table_len], + req_to_token_indexs=model.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx[args_mtp_step :: (args_mtp_step + 1)], + ) + self.b_att_seq_len = self.infer_state.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous() + self.decode_max_q_seq_len = args_mtp_step + 1 + else: + page_table_copy( + page_table=self.page_table[:, :table_len], + req_to_token_indexs=model.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx, + ) + self.b_att_seq_len = self.infer_state.b_seq_len + self.decode_max_q_seq_len = 1 + + def decode_att(self, q, k, v, att_control: AttControl = AttControl(), alloc_func=torch.empty): + assert att_control.use_alibi is False + return self._normal_decode_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) + + def _normal_decode_att(self, q, k, v, att_control: AttControl, alloc_func=torch.empty): + if att_control.use_sliding_window: + window_size = att_control.sliding_window + else: + window_size = (-1, -1) + + if att_control.use_att_sink: + sink_weight = att_control.sink_weight + else: + sink_weight = None + + sm_scale = 1.0 / (q.shape[-1] ** 0.5) + return flash_attn_with_kvcache( + q=q, + k_cache=k.view(-1, self.backend.page_size, k.shape[1], k.shape[2]), + v_cache=v.view(-1, self.backend.page_size, v.shape[1], v.shape[2]), + page_table=self.page_table, + cache_seqlens=self.b_att_seq_len, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k_new=self.cu_seqlens_k, + max_seqlen_q=self.decode_max_q_seq_len, + softmax_scale=sm_scale, + causal=True, + window_size=window_size, + softcap=0.0, + k_descale=None, + v_descale=None, + return_softmax_lse=False, + sinks=sink_weight, + ) diff --git a/lightllm/common/basemodel/attention/paged_fa3/mla.py b/lightllm/common/basemodel/attention/paged_fa3/mla.py new file mode 100644 index 0000000000..2e33c05409 --- /dev/null +++ b/lightllm/common/basemodel/attention/paged_fa3/mla.py @@ -0,0 +1,174 @@ +import dataclasses +import torch +import triton +from typing import Tuple +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from lightllm.utils.dist_utils import get_current_device_id +from lightllm.utils.sgl_utils import flash_attn_with_kvcache, flash_attn_varlen_func +from lightllm.utils.envs_utils import get_env_start_args, get_page_size +from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy +from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor + + +class PagedMlaFa3AttBackend(BaseAttBackend): + def __init__(self, model, page_size=None): + super().__init__(model=model) + self.page_size = page_size or get_page_size() + self.get_page_table_buffer() + + def get_page_table_buffer(self): + model = self.model + if not hasattr(self, "_shared_page_table_buffer"): + shared_len = model.graph_max_batch_size * triton.cdiv(model.graph_max_len_in_batch, self.page_size) + self._shared_page_table_buffer = [ + torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()), + torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()), + ] + return self._shared_page_table_buffer + + def create_att_prefill_state(self, infer_state): + return PagedMlaFa3PrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state): + return PagedMlaFa3DecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class PagedMlaFa3PrefillAttState(BasePrefillAttState): + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + + def init_state(self): + self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + + def prefill_att( + self, q, k: Tuple[torch.Tensor, torch.Tensor], v, att_control: AttControl = AttControl(), alloc_func=torch.empty + ): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + return self._mla_prefill_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) + + def _mla_prefill_att( + self, q, k: Tuple[torch.Tensor, torch.Tensor], v, att_control: AttControl, alloc_func=torch.empty + ): + k_nope, k_rope = k + q_head_num = q.shape[1] + k = torch.cat([k_nope, torch.repeat_interleave(k_rope, q_head_num, dim=-2)], dim=-1) + assert q.ndim == 3 and k.ndim == 3 and v.ndim == 3 + assert att_control.mla_prefill + return flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k=self.cu_seqlens_k, + max_seqlen_q=self.infer_state.max_q_seq_len, + max_seqlen_k=self.infer_state.max_kv_seq_len, + softmax_scale=att_control.mla_prefill_dict["softmax_scale"], + causal=True, + return_softmax_lse=False, + ) + + +@dataclasses.dataclass +class PagedMlaFa3DecodeAttState(BaseDecodeAttState): + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + page_table: torch.Tensor = None + b_att_seq_len: torch.Tensor = None + decode_max_q_seq_len: int = None + + def init_state(self): + args_mtp_step = get_env_start_args().mtp_step + if args_mtp_step > 0: + mtp_size = args_mtp_step + 1 + b_q_seq_len = torch.full( + (self.infer_state.b_seq_len.shape[0] // mtp_size,), + fill_value=mtp_size, + dtype=torch.int32, + device=self.infer_state.b_seq_len.device, + ) + b_kv_seq_len = self.infer_state.b_seq_len[mtp_size - 1 :: mtp_size] + b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len) + self.cu_seqlens_q = b1_cu_q_seq_len.int() + self.cu_seqlens_k = b1_cu_kv_seq_len.int() + else: + self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + + att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1) + assert self.infer_state.batch_size % (args_mtp_step + 1) == 0 + model = self.backend.model + table_len = triton.cdiv(self.infer_state.max_kv_seq_len, self.backend.page_size) + if ( + self.infer_state.batch_size <= model.graph_max_batch_size + and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch + ): + page_buffer = self.backend.get_page_table_buffer() + shared_table_len = triton.cdiv(model.graph_max_len_in_batch, self.backend.page_size) + self.page_table = page_buffer[self.infer_state.microbatch_index][ + : att_batch_size * shared_table_len + ].reshape(att_batch_size, shared_table_len) + else: + self.page_table = torch.empty( + (att_batch_size, table_len), + dtype=torch.int32, + device=self.infer_state.input_ids.device, + ) + + if args_mtp_step > 0: + page_table_copy( + page_table=self.page_table[:, :table_len], + req_to_token_indexs=model.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx[args_mtp_step :: (args_mtp_step + 1)], + ) + self.b_att_seq_len = self.infer_state.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous() + self.decode_max_q_seq_len = args_mtp_step + 1 + else: + page_table_copy( + page_table=self.page_table[:, :table_len], + req_to_token_indexs=model.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx, + ) + self.b_att_seq_len = self.infer_state.b_seq_len + self.decode_max_q_seq_len = 1 + + def decode_att( + self, q: Tuple[torch.Tensor, torch.Tensor], k, v, att_control: AttControl = AttControl(), alloc_func=torch.empty + ): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + assert v is None + return self._mla_decode_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) + + def _mla_decode_att( + self, q: Tuple[torch.Tensor, torch.Tensor], k, v, att_control: AttControl, alloc_func=torch.empty + ): + q_nope, q_rope = q + qk_rope_head_dim = 64 + kv_lora_rank = k.shape[-1] - qk_rope_head_dim + return flash_attn_with_kvcache( + q=q_rope, + k_cache=k[:, :, -qk_rope_head_dim:].view(-1, self.backend.page_size, 1, qk_rope_head_dim), + v_cache=k[:, :, :-qk_rope_head_dim].view(-1, self.backend.page_size, 1, kv_lora_rank), + qv=q_nope, + page_table=self.page_table, + cache_seqlens=self.b_att_seq_len, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k_new=self.cu_seqlens_k, + max_seqlen_q=self.decode_max_q_seq_len, + softmax_scale=att_control.mla_decode_dict["softmax_scale"], + causal=True, + window_size=(-1, -1), + softcap=0.0, + k_descale=None, + v_descale=None, + return_softmax_lse=False, + ) diff --git a/lightllm/common/basemodel/attention/paged_flashinfer/__init__.py b/lightllm/common/basemodel/attention/paged_flashinfer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/attention/paged_flashinfer/fp.py b/lightllm/common/basemodel/attention/paged_flashinfer/fp.py new file mode 100644 index 0000000000..b1807ca30b --- /dev/null +++ b/lightllm/common/basemodel/attention/paged_flashinfer/fp.py @@ -0,0 +1,193 @@ +import dataclasses +import torch +import triton +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from lightllm.utils.dist_utils import get_dp_world_size, get_current_device_id +from ...triton_kernel.repack_kv_index import paged_repack_kv_index +from lightllm.utils.envs_utils import get_page_size +from ..flashinfer.env_utils import set_flashinfer_envs + + +class PagedFlashInferAttBackend(BaseAttBackend): + def __init__(self, model, page_size=None): + set_flashinfer_envs() + super().__init__(model=model) + self.page_size = page_size or get_page_size() + tp_world_size = get_dp_world_size() + self.tp_q_head_num = model.config["num_attention_heads"] // tp_world_size + self.tp_kv_head_num = max(model.config["num_key_value_heads"] // tp_world_size, 1) + head_dim = model.config["hidden_size"] // model.config["num_attention_heads"] + self.head_dim = model.config.get("head_dim", head_dim) + self.workspace_buffer = torch.empty(512 * 1024 * 1024, dtype=torch.int8, device=get_current_device_id()) + self.max_seq_length = model.max_seq_length + buffer_len = model.graph_max_batch_size * triton.cdiv(self.max_seq_length, self.page_size) + self.kv_indices_buffer = [ + torch.empty(buffer_len, dtype=torch.int32, device=get_current_device_id()), + torch.empty(buffer_len, dtype=torch.int32, device=get_current_device_id()), + ] + self.q_data_type = model.data_type + self.kv_data_type = model.data_type + + def create_att_prefill_state(self, infer_state): + return PagedFlashInferPrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state): + return PagedFlashInferDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class PagedFlashInferPrefillAttState(BasePrefillAttState): + prefill_wrapper: object = None + + def init_state(self): + self.backend: PagedFlashInferAttBackend = self.backend + import flashinfer + + batch_size = self.infer_state.batch_size + device = self.infer_state.input_ids.device + q_starts = self.infer_state.b1_cu_q_seq_len.int() + kv_starts = self.infer_state.b1_cu_kv_seq_len.int() + b_page_len = triton.cdiv(self.infer_state.b_seq_len, self.backend.page_size) + kv_starts[1:] = b_page_len.cumsum(0) + kv_last_page_len = self.infer_state.b_seq_len - (b_page_len - 1) * self.backend.page_size + kv_indices = torch.empty( + batch_size * triton.cdiv(self.backend.max_seq_length, self.backend.page_size), + dtype=torch.int32, + device=device, + ) + paged_repack_kv_index( + self.infer_state.req_manager.req_to_token_indexs, + self.infer_state.b_req_idx, + b_page_len, + kv_starts[:-1], + triton.cdiv(self.infer_state.max_kv_seq_len, self.backend.page_size), + kv_indices, + ) + self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( + self.backend.workspace_buffer, + qo_indptr_buf=q_starts, + paged_kv_indptr_buf=kv_starts, + paged_kv_indices_buf=kv_indices, + paged_kv_last_page_len_buf=kv_last_page_len, + ) + self.prefill_wrapper.plan( + q_starts, + kv_starts, + kv_indices, + kv_last_page_len, + self.backend.tp_q_head_num, + self.backend.tp_kv_head_num, + self.backend.head_dim, + self.backend.page_size, + causal=True, + pos_encoding_mode="NONE", + logits_soft_cap=0.0, + q_data_type=self.backend.q_data_type, + kv_data_type=self.backend.kv_data_type, + ) + + def prefill_att(self, q, k, v, att_control: AttControl = AttControl(), alloc_func=torch.empty): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + o_tensor = alloc_func(q.shape, q.dtype, device="cuda") + self.prefill_wrapper.run( + q, + ( + k.view(-1, self.backend.page_size, k.shape[1], k.shape[2]), + v.view(-1, self.backend.page_size, v.shape[1], v.shape[2]), + ), + out=o_tensor, + ) + return o_tensor + + +@dataclasses.dataclass +class PagedFlashInferDecodeAttState(BaseDecodeAttState): + kv_last_page_len_buffer: torch.Tensor = None + kv_indices: torch.Tensor = None + kv_starts: torch.Tensor = None + decode_wrapper: object = None + + def init_state(self): + import flashinfer + + self.backend: PagedFlashInferAttBackend = self.backend + device = self.infer_state.input_ids.device + model = self.backend.model + b_page_len = triton.cdiv(self.infer_state.b_seq_len, self.backend.page_size) + self.kv_last_page_len_buffer = self.infer_state.b_seq_len - (b_page_len - 1) * self.backend.page_size + buffer_len = self.infer_state.batch_size * triton.cdiv(self.backend.max_seq_length, self.backend.page_size) + if ( + self.infer_state.batch_size <= model.graph_max_batch_size + and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch + ): + self.kv_indices = self.backend.kv_indices_buffer[self.infer_state.microbatch_index][:buffer_len] + else: + self.kv_indices = torch.empty(buffer_len, dtype=torch.int32, device=device) + + self.kv_starts = self.infer_state.b1_cu_kv_seq_len.int() + self.kv_starts[1:] = b_page_len.cumsum(0) + paged_repack_kv_index( + self.infer_state.req_manager.req_to_token_indexs, + self.infer_state.b_req_idx, + b_page_len, + self.kv_starts[:-1], + triton.cdiv(self.infer_state.max_kv_seq_len, self.backend.page_size), + self.kv_indices, + ) + self.decode_wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( + self.backend.workspace_buffer, + "NHD", + use_cuda_graph=True, + use_tensor_cores=True, + paged_kv_indptr_buffer=self.kv_starts, + paged_kv_indices_buffer=self.kv_indices, + paged_kv_last_page_len_buffer=self.kv_last_page_len_buffer, + ) + self.decode_wrapper.plan( + self.kv_starts, + self.kv_indices, + self.kv_last_page_len_buffer, + self.backend.tp_q_head_num, + self.backend.tp_kv_head_num, + self.backend.head_dim, + self.backend.page_size, + q_data_type=self.backend.q_data_type, + kv_data_type=self.backend.kv_data_type, + non_blocking=True, + ) + + def copy_for_decode_cuda_graph(self, new_state): + super().copy_for_decode_cuda_graph(new_state) + self.decode_wrapper.plan( + new_state.kv_starts, + new_state.kv_indices, + new_state.kv_last_page_len_buffer, + new_state.backend.tp_q_head_num, + new_state.backend.tp_kv_head_num, + new_state.backend.head_dim, + new_state.backend.page_size, + q_data_type=new_state.backend.q_data_type, + kv_data_type=new_state.backend.kv_data_type, + non_blocking=True, + ) + + def decode_att(self, q, k, v, att_control: AttControl = AttControl(), alloc_func=torch.empty): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + o_tensor = alloc_func(q.shape, q.dtype) + self.decode_wrapper.run( + q, + ( + k.view(-1, self.backend.page_size, k.shape[1], k.shape[2]), + v.view(-1, self.backend.page_size, v.shape[1], v.shape[2]), + ), + out=o_tensor, + ) + return o_tensor diff --git a/lightllm/common/basemodel/attention/paged_flashinfer/mla.py b/lightllm/common/basemodel/attention/paged_flashinfer/mla.py new file mode 100644 index 0000000000..c9ea38052f --- /dev/null +++ b/lightllm/common/basemodel/attention/paged_flashinfer/mla.py @@ -0,0 +1,184 @@ +import dataclasses +import torch +import triton +from typing import Tuple +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from lightllm.utils.dist_utils import get_dp_world_size, get_current_device_id +from ...triton_kernel.repack_kv_index import paged_repack_kv_index +from lightllm.utils.envs_utils import get_page_size +from ..flashinfer.env_utils import set_flashinfer_envs + + +class PagedMlaFlashInferAttBackend(BaseAttBackend): + def __init__(self, model, page_size=None): + set_flashinfer_envs() + super().__init__(model=model) + self.page_size = page_size or get_page_size() + num_heads = model.config["num_attention_heads"] + self.tp_q_head_num = num_heads // get_dp_world_size() + self.qk_nope_head_dim = model.qk_nope_head_dim + self.qk_rope_head_dim = model.qk_rope_head_dim + self.kv_lora_rank = model.kv_lora_rank + self.v_head_dim = model.v_head_dim + self.q_data_type = model.data_type + self.kv_data_type = model.data_type + self.workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device=get_current_device_id()) + self.max_seq_length = model.max_seq_length + self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5) + buffer_len = model.graph_max_batch_size * triton.cdiv(self.max_seq_length, self.page_size) + self.kv_indices_buffer = [ + torch.empty(buffer_len, dtype=torch.int32, device=get_current_device_id()), + torch.empty(buffer_len, dtype=torch.int32, device=get_current_device_id()), + ] + + from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale + + if model.config["rope_scaling"] is not None: + rope_scaling = model.config["rope_scaling"] + mscale_all_dim = rope_scaling.get("mscale_all_dim", 0) + scaling_factor = rope_scaling["factor"] + if mscale_all_dim: + mscale = get_deepseek_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale + + def create_att_prefill_state(self, infer_state): + return PagedMlaFlashInferPrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state): + return PagedMlaFlashInferDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class PagedMlaFlashInferPrefillAttState(BasePrefillAttState): + prefill_wrapper: object = None + + def init_state(self): + self.backend: PagedMlaFlashInferAttBackend = self.backend + import flashinfer + + q_starts = self.infer_state.b1_cu_q_seq_len.int() + kv_starts = self.infer_state.b1_cu_kv_seq_len.int() + if self.prefill_wrapper is None: + self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( + self.backend.workspace_buffer, "NHD" + ) + self.prefill_wrapper.plan( + qo_indptr=q_starts, + kv_indptr=kv_starts, + num_qo_heads=self.backend.tp_q_head_num, + num_kv_heads=self.backend.tp_q_head_num, + head_dim_qk=self.backend.qk_nope_head_dim + self.backend.qk_rope_head_dim, + head_dim_vo=self.backend.v_head_dim, + q_data_type=self.backend.q_data_type, + causal=True, + sm_scale=self.backend.softmax_scale, + ) + + def prefill_att( + self, q, k: Tuple[torch.Tensor, torch.Tensor], v, att_control: AttControl = AttControl(), alloc_func=torch.empty + ): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + k_nope, k_rope = k + o_tensor = alloc_func((q.shape[0], q.shape[1], v.shape[-1]), q.dtype, device="cuda") + q_head_num = q.shape[1] + k = torch.cat([k_nope, torch.repeat_interleave(k_rope, q_head_num, dim=-2)], dim=-1) + self.prefill_wrapper.run(q, k, v, out=o_tensor) + return o_tensor + + +@dataclasses.dataclass +class PagedMlaFlashInferDecodeAttState(BaseDecodeAttState): + kv_indices: torch.Tensor = None + kv_starts: torch.Tensor = None + decode_wrapper: object = None + + def init_state(self): + import flashinfer + + self.backend: PagedMlaFlashInferAttBackend = self.backend + model = self.backend.model + device = self.infer_state.input_ids.device + batch_size = self.infer_state.batch_size + self.kv_starts = self.infer_state.b1_cu_kv_seq_len + self.q_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") + buffer_len = batch_size * triton.cdiv(self.backend.max_seq_length, self.backend.page_size) + if batch_size <= model.graph_max_batch_size and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch: + self.kv_indices = self.backend.kv_indices_buffer[self.infer_state.microbatch_index][:buffer_len] + else: + self.kv_indices = torch.empty(buffer_len, dtype=torch.int32, device=device) + + b_page_len = triton.cdiv(self.infer_state.b_seq_len, self.backend.page_size) + self.kv_starts[1:] = b_page_len.cumsum(0) + paged_repack_kv_index( + self.infer_state.req_manager.req_to_token_indexs, + self.infer_state.b_req_idx, + b_page_len, + self.kv_starts[:-1], + triton.cdiv(self.infer_state.max_kv_seq_len, self.backend.page_size), + self.kv_indices, + ) + self.decode_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( + self.backend.workspace_buffer, + use_cuda_graph=True, + qo_indptr=self.q_indptr, + kv_indices=self.kv_indices, + kv_indptr=self.kv_starts, + kv_len_arr=self.infer_state.b_seq_len, + ) + self.decode_wrapper.plan( + self.q_indptr, + self.kv_starts, + self.kv_indices, + self.infer_state.b_seq_len, + self.backend.tp_q_head_num, + self.backend.kv_lora_rank, + self.backend.qk_rope_head_dim, + self.backend.page_size, + False, + self.backend.softmax_scale, + self.backend.q_data_type, + self.backend.kv_data_type, + ) + + def copy_for_decode_cuda_graph(self, new_state): + super().copy_for_decode_cuda_graph(new_state) + self.decode_wrapper.plan( + new_state.q_indptr, + new_state.kv_starts, + new_state.kv_indices, + new_state.infer_state.b_seq_len, + new_state.backend.tp_q_head_num, + new_state.backend.kv_lora_rank, + new_state.backend.qk_rope_head_dim, + new_state.backend.page_size, + False, + new_state.backend.softmax_scale, + new_state.backend.q_data_type, + new_state.backend.kv_data_type, + ) + + def decode_att( + self, q: Tuple[torch.Tensor, torch.Tensor], k, v, att_control: AttControl = AttControl(), alloc_func=torch.empty + ): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + assert v is None + q_nope, q_rope = q + qk_rope_head_dim = 64 + o_tensor = alloc_func(q_nope.shape, dtype=q_nope.dtype, device=q_nope.device) + self.decode_wrapper.run( + q_nope, + q_rope, + k[:, :, :-qk_rope_head_dim].view(-1, self.backend.page_size, 1, k.shape[-1] - qk_rope_head_dim), + k[:, :, -qk_rope_head_dim:].view(-1, self.backend.page_size, 1, qk_rope_head_dim), + out=o_tensor, + return_lse=False, + ) + return o_tensor diff --git a/lightllm/common/basemodel/attention/triton/fp.py b/lightllm/common/basemodel/attention/triton/fp.py index a1370a7045..627b6a84c3 100644 --- a/lightllm/common/basemodel/attention/triton/fp.py +++ b/lightllm/common/basemodel/attention/triton/fp.py @@ -30,7 +30,7 @@ def prefill_att( assert att_control.tp_alibi is not None return self._alibi_prefill_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) else: - return self._nomarl_prefill_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) + return self._normal_prefill_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) def _alibi_prefill_att( self, @@ -59,7 +59,7 @@ def _alibi_prefill_att( ) return out - def _nomarl_prefill_att( + def _normal_prefill_att( self, q: torch.Tensor, k: torch.Tensor, diff --git a/lightllm/common/basemodel/triton_kernel/fa3_utils.py b/lightllm/common/basemodel/triton_kernel/fa3_utils.py index 0a524b63b6..f9d1c9e9c6 100644 --- a/lightllm/common/basemodel/triton_kernel/fa3_utils.py +++ b/lightllm/common/basemodel/triton_kernel/fa3_utils.py @@ -1,5 +1,6 @@ import triton import triton.language as tl +from lightllm.utils.envs_utils import get_page_size @triton.jit @@ -37,6 +38,13 @@ def page_table_copy( assert page_table.dim() == 2, "page_table should be 2D" assert req_to_token_indexs.dim() == 2, "req_to_token_indexs should be 2D" + page_size = get_page_size() + if page_size > 1: + max_seq_len_k = page_table.shape[1] * page_size + sampled = req_to_token_indexs[b_req_idx, :max_seq_len_k:page_size] + page_table.copy_(sampled // page_size) + return + max_seq_len_k = page_table.shape[1] batch_size = page_table.size(0) BLOCK_SIZE = 128 diff --git a/lightllm/common/basemodel/triton_kernel/repack_kv_index.py b/lightllm/common/basemodel/triton_kernel/repack_kv_index.py index e86d2e819e..d50a0a230b 100644 --- a/lightllm/common/basemodel/triton_kernel/repack_kv_index.py +++ b/lightllm/common/basemodel/triton_kernel/repack_kv_index.py @@ -2,6 +2,7 @@ import triton import triton.language as tl +from lightllm.utils.envs_utils import get_page_size @triton.jit @@ -33,6 +34,40 @@ def _fwd_kernel_repack_kv_index( return +@triton.jit +def _fwd_kernel_repack_page_kv_index_from_tokens( + req_to_token_indexs, + req_index, + out_kv_index, + seq_len, + start_loc, + page_size, + token_stride_h, + SEQ_BLOCK: tl.constexpr, +): + cur_batch = tl.program_id(0) + start_seq_n = tl.program_id(1) + + cur_batch_seq_len = tl.load(seq_len + cur_batch) + cur_batch_req_idx = tl.load(req_index + cur_batch) + cur_batch_start_loc = tl.load(start_loc + cur_batch) + + offs_seq = (start_seq_n * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK)) * page_size + block_end_loc = tl.minimum((start_seq_n + 1) * SEQ_BLOCK, cur_batch_seq_len) * page_size + token_data = tl.load( + req_to_token_indexs + token_stride_h * cur_batch_req_idx + offs_seq, + mask=offs_seq < block_end_loc, + other=0, + ) + page_data = token_data // page_size + + offs_seq = start_seq_n * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK) + block_end_loc = tl.minimum((start_seq_n + 1) * SEQ_BLOCK, cur_batch_seq_len) + out_kv_index_ptr = out_kv_index + cur_batch_start_loc + offs_seq + tl.store(out_kv_index_ptr, page_data, mask=offs_seq < block_end_loc) + return + + @torch.no_grad() def repack_kv_index(kv_index, req_index, seq_len, start_loc, max_seq_len, out_kv_index): batch_size = req_index.shape[0] @@ -58,6 +93,34 @@ def repack_kv_index(kv_index, req_index, seq_len, start_loc, max_seq_len, out_kv return +@torch.no_grad() +def paged_repack_kv_index(kv_index, req_index, seq_len, start_loc, max_seq_len, out_kv_index): + page_size = get_page_size() + assert page_size > 1 + batch_size = req_index.shape[0] + # flashinfer requires out_kv_index to be zeroed before use + out_kv_index.zero_() + BLOCK = 64 + grid = ( + batch_size, + triton.cdiv(max_seq_len, BLOCK), + ) + + _fwd_kernel_repack_page_kv_index_from_tokens[grid]( + kv_index, + req_index, + out_kv_index, + seq_len, + start_loc, + page_size, + kv_index.stride(0), + SEQ_BLOCK=BLOCK, + num_warps=8, + num_stages=1, + ) + return + + def repack_kv_ref(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, output): for b, sl, start in zip(b_req_idx, b_seq_len, b_start_loc): output[start : start + sl] = req_to_token_indexs[b][:sl] diff --git a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py index d49c8d7e73..276bbf54bc 100644 --- a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py @@ -1,6 +1,7 @@ import torch import os import torch.distributed as dist +import triton from lightllm.server.pd_io_struct import KVMoveTask from .mem_manager import MemoryManager from typing import List, Union, Any @@ -10,6 +11,7 @@ from lightllm.distributed.pynccl import PyNcclCommunicator from lightllm.common.kv_trans_kernel.nixl_kv_trans import mla_page_io from .operator import Deepseek2MemOperator +from lightllm.utils.envs_utils import get_page_size logger = init_logger(__name__) @@ -30,7 +32,9 @@ def get_cell_size(self): return self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype) def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): - self.kv_buffer = torch.empty((layer_num, size + 1, head_num, head_dim), dtype=dtype, device="cuda") + page_size = get_page_size() + alloc_size = ((size // page_size) + 1) * page_size if page_size > 1 else size + 1 + self.kv_buffer = torch.empty((layer_num, alloc_size, head_num, head_dim), dtype=dtype, device="cuda") def alloc_kv_move_buffer(self, max_req_total_len): self.kv_move_buffer = torch.empty( diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 0a1deba499..37df531831 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -3,6 +3,7 @@ import torch import torch.distributed as dist import torch.multiprocessing as mp +import triton from typing import List, Tuple, Any, Union from lightllm.server.pd_io_struct import KVMoveTask from lightllm.utils.log_utils import init_logger @@ -11,7 +12,7 @@ from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory from lightllm.common.kv_trans_kernel.kv_trans import kv_trans from lightllm.utils.dist_utils import get_current_rank_in_node, get_node_world_size -from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args +from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args, get_page_size from lightllm.distributed.pynccl import PyNcclCommunicator from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.config_utils import get_num_key_value_heads @@ -38,6 +39,9 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False self.dtype = dtype # profile the max total token num if the size is None self.profile_size(mem_fraction) + page_size = get_page_size() + if page_size > 1: + self.size = (self.size // page_size) * page_size self.allocator = KvCacheAllocator(self.size) @@ -86,7 +90,9 @@ def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): # 分配,内部实际也没有管理,这个token是预留来对一些特殊的运行模式,如多dp下,overlap microbatch # 等模式下 padding 一些请求,使推理过程可以正常运行采用的,其索引值为size,存储在HOLD_TOKEN_MEMINDEX # 成员变量中,其与 req_manager 中的HOLD_REQUEST_ID具有类似的作用和意义。 - self.kv_buffer = torch.empty((layer_num, size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda") + page_size = get_page_size() + alloc_size = ((size // page_size) + 1) * page_size if page_size > 1 else size + 1 + self.kv_buffer = torch.empty((layer_num, alloc_size, 2 * head_num, head_dim), dtype=dtype, device="cuda") def alloc_kv_move_buffer(self, max_req_total_len): """ @@ -322,7 +328,79 @@ def _free_buffers(self): def alloc(self, need_size) -> torch.Tensor: return self.allocator.alloc(need_size) + def calc_real_need_token_num(self, need_token_num, b_seq_len, b_ready_cache_len=None): + return max(need_token_num, self._get_need_paged_token_num(b_seq_len, b_ready_cache_len)) + + def alloc_token_indexes( + self, need_size, b_seq_len=None, b_ready_cache_len=None, b_last_mem_index=None + ) -> torch.Tensor: + page_size = get_page_size() + if page_size > 1 and b_seq_len is not None: + return self._alloc_paged_mem_indices(page_size, b_seq_len, b_ready_cache_len, b_last_mem_index) + return self.alloc(need_size) + + def _expand_to_page_mem_indices(self, free_index: Union[torch.Tensor, List[int]]): + page_size = get_page_size() + if page_size > 1: + if isinstance(free_index, list): + free_index = torch.tensor(free_index, dtype=torch.int32) + base_indices = free_index[free_index % page_size == 0] + if len(base_indices) == 0: + return free_index + page_offsets = torch.arange(page_size, dtype=base_indices.dtype, device=base_indices.device) + return (base_indices[:, None] + page_offsets[None, :]).reshape(-1) + + return free_index + + def _expand_by_page_size(self, b_token_len, page_size): + b_page_len = triton.cdiv(b_token_len, page_size) + need_pages_num = int(b_page_len.sum().item()) + p_token_len = torch.full((need_pages_num,), page_size, dtype=b_token_len.dtype, device=b_token_len.device) + cumsum_pages = torch.cumsum(b_page_len, dim=0) + last_page_positions = cumsum_pages - 1 + remainders = b_token_len - (b_page_len - 1) * page_size + p_token_len[last_page_positions] = remainders + return need_pages_num, p_token_len + + def _alloc_paged_mem_indices(self, page_size, b_seq_len, b_ready_cache_len, b_last_mem_index): + b_seq_len = b_seq_len.cpu() + if b_ready_cache_len is not None: + b_ready_cache_len = b_ready_cache_len.cpu() + b_token_len = b_seq_len - b_ready_cache_len + total_pages_needed, p_token_len = self._expand_by_page_size(b_token_len, page_size) + paged_token_idxs = self.alloc(total_pages_needed * page_size) + pages = paged_token_idxs.view(-1, page_size) + mask = torch.arange(page_size, device=p_token_len.device) < p_token_len.unsqueeze(1) + return pages[mask] + + assert b_last_mem_index is not None + b_last_mem_index = b_last_mem_index.cpu() + need_new_page_mask = (b_seq_len - 1) % page_size == 0 + new_pages_num = int(need_new_page_mask.sum().item()) + token_idxs = torch.zeros_like(b_seq_len, device=b_seq_len.device) + if new_pages_num > 0: + new_pages_tokens = self.alloc(new_pages_num * page_size) + token_idxs[need_new_page_mask] = new_pages_tokens[::page_size] + mask = ~need_new_page_mask + if mask.any(): + token_idxs[mask] = b_last_mem_index[mask] + 1 + return token_idxs + + def _get_need_paged_token_num(self, b_seq_len, b_ready_cache_len=None): + page_size = get_page_size() + if page_size == 1: + return 0 + + if b_ready_cache_len is not None: + need_tokens_array = b_seq_len - b_ready_cache_len + need_pages_array = triton.cdiv(need_tokens_array, page_size) + need_new_pages = need_pages_array.sum() + else: + need_new_pages = ((b_seq_len - 1) % page_size == 0).sum() + return need_new_pages * page_size + def free(self, free_index: Union[torch.Tensor, List[int]]) -> None: + free_index = self._expand_to_page_mem_indices(free_index) self.allocator.free(free_index) def free_all(self): diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 01e9c4ad35..4017aa467f 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -1,7 +1,6 @@ import torch import collections from lightllm.common.linear_att_cache_manager.config_objs import LinearAttCacheConfig - from lightllm.utils.log_utils import init_logger from .kv_cache_mem_manager import MemoryManager from typing import List, Optional, TYPE_CHECKING @@ -78,6 +77,22 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana def alloc(self): return self.req_list.alloc() + def calc_real_need_token_num(self, need_token_num, b_seq_len, b_ready_cache_len=None): + return self.mem_manager.calc_real_need_token_num(need_token_num, b_seq_len, b_ready_cache_len) + + def calc_last_mem_index_in_prefill(self, mem_indices, b_seq_len, b_ready_cache_len=None): + b_token_len = b_seq_len + if b_ready_cache_len is not None: + b_token_len = b_seq_len - b_ready_cache_len + b_token_len_cumsum = torch.cumsum(b_token_len, dim=0) + b_last_mem_index = mem_indices[b_token_len_cumsum - 1] + return b_last_mem_index + + def alloc_token_indexes( + self, need_size, b_seq_len=None, b_ready_cache_len=None, b_last_mem_index=None + ) -> torch.Tensor: + return self.mem_manager.alloc_token_indexes(need_size, b_seq_len, b_ready_cache_len, b_last_mem_index) + def free(self, free_req_indexes: List[int], free_token_index): for req_index in free_req_indexes: self.req_list.free(req_index) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 3a29566283..425476a4e1 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -11,7 +11,7 @@ from .embed_cache.manager import start_cache_manager from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import set_env_start_args, set_unique_server_name, get_unique_server_name -from lightllm.utils.envs_utils import get_lightllm_gunicorn_keep_alive +from lightllm.utils.envs_utils import get_lightllm_gunicorn_keep_alive, get_page_size, set_page_size from .detokenization.manager import start_detokenization_process from .router.manager import start_router_process from lightllm.utils.process_check import is_process_active @@ -29,6 +29,20 @@ logger = init_logger(__name__) +def _auto_set_fa4_page_size(args, requested_backends): + if "fa4" not in requested_backends or "PAGE_SIZE" in os.environ: + return + + from lightllm.utils.fa4_utils import infer_fa4_page_size + + page_size = infer_fa4_page_size(args.model_dir) + if page_size is None: + return + + set_page_size(page_size) + logger.info(f"auto set PAGE_SIZE={page_size} for FA4 backend") + + def setup_signal_handlers(http_server_process, process_manager): def signal_handler(sig, frame): if sig == signal.SIGINT: @@ -208,23 +222,35 @@ def normal_or_p_d_start(args): f"{sorted(allowed_ep_att_backends)}; flashinfer is not supported." ) - requested_backends = ( - list(args.llm_prefill_att_backend) + list(args.llm_decode_att_backend) + list(args.vit_att_backend) - ) + llm_requested_backends = list(args.llm_prefill_att_backend) + list(args.llm_decode_att_backend) + requested_backends = llm_requested_backends + list(args.vit_att_backend) if "fa4" in requested_backends: - from lightllm.utils.fa4_utils import ensure_fa4_available, ensure_fa4_supported_gpu - - ensure_fa4_available() - ensure_fa4_supported_gpu() + _auto_set_fa4_page_size(args, llm_requested_backends) # mtp params check if args.mtp_mode is not None: assert args.mtp_draft_model_dir is not None assert args.mtp_step > 0 + assert get_page_size() == 1, "page_size > 1 is not supported with MTP, please set PAGE_SIZE=1" else: assert args.mtp_draft_model_dir is None assert args.mtp_step == 0 + # page_size > 1 compatibility check + if get_page_size() > 1: + assert args.run_mode not in ( + "prefill", + "decode", + ), "page_size > 1 is not supported with RPyC PD split mode, please set PAGE_SIZE=1" + assert args.run_mode not in ( + "nixl_prefill", + "nixl_decode", + ), "page_size > 1 is not supported with NIXL PD split mode, please set PAGE_SIZE=1" + assert ( + not args.enable_dp_prefill_balance + ), "page_size > 1 is not supported with DP prefill balance, please set PAGE_SIZE=1" + assert not args.enable_cpu_cache, "page_size > 1 is not supported with CPU cache, please set PAGE_SIZE=1" + if args.afs_image_embed_dir is not None: os.makedirs(args.afs_image_embed_dir, mode=0o777, exist_ok=True) os.chmod(args.afs_image_embed_dir, 0o777) diff --git a/lightllm/server/router/dynamic_prompt/paged_radix_cache.py b/lightllm/server/router/dynamic_prompt/paged_radix_cache.py new file mode 100644 index 0000000000..ff6c187877 --- /dev/null +++ b/lightllm/server/router/dynamic_prompt/paged_radix_cache.py @@ -0,0 +1,532 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/managers/router/radix_cache.py +import torch +import numpy as np +import collections +from typing import Tuple, Dict, Set, List, Optional, Union +from sortedcontainers import SortedSet +from .shared_arr import SharedArray +from lightllm.utils.envs_utils import get_page_size + + +class UniqueTimeIdGenerator: + def __init__(self): + self.counter = 0 + + def generate_time_id(self): + self.counter += 1 + return self.counter + + +time_gen = UniqueTimeIdGenerator() + + +class TreeNode: + def __init__(self): + self.children: Dict[int, TreeNode] = {} + self.parent: TreeNode = None + self.token_id_key: torch.Tensor = None + self.token_mem_index_value: torch.Tensor = None + self.ref_counter = 0 + self.time_id = time_gen.generate_time_id() + + self.node_value_len = 0 + self.node_prefix_total_len = 0 + self.page_size = get_page_size() + self._page_size_is_power_of_2 = (self.page_size & (self.page_size - 1)) == 0 + self._page_size_mask = self.page_size - 1 if self._page_size_is_power_of_2 else None + + def get_compare_key(self): + return (0 if self.ref_counter == 0 else 1, len(self.children), self.time_id) + + def _compute_key(self, tokens: torch.Tensor): + page_tokens = tokens[: self.page_size] + return page_tokens.item() if self.page_size == 1 else page_tokens.cpu().numpy().tobytes() + + def split_node(self, prefix_len): + split_parent_node = TreeNode() + split_parent_node.parent = self.parent + split_parent_node.parent.children[self._compute_key(self.token_id_key)] = split_parent_node + split_parent_node.token_id_key = self.token_id_key[0:prefix_len] + split_parent_node.token_mem_index_value = self.token_mem_index_value[0:prefix_len] + split_parent_node.children = {} + split_parent_node.children[self._compute_key(self.token_id_key[prefix_len:])] = self + split_parent_node.ref_counter = self.ref_counter + + new_len = len(split_parent_node.token_mem_index_value) + split_parent_node.node_value_len = new_len + split_parent_node.node_prefix_total_len = split_parent_node.parent.node_prefix_total_len + new_len + + self.token_id_key = self.token_id_key[prefix_len:] + self.token_mem_index_value = self.token_mem_index_value[prefix_len:] + self.parent = split_parent_node + new_len = len(self.token_mem_index_value) + self.node_value_len = new_len + self.node_prefix_total_len = self.parent.node_prefix_total_len + new_len + return split_parent_node + + def add_and_return_new_child(self, token_id_key, token_mem_index_value): + child = TreeNode() + child.token_id_key = token_id_key + child.token_mem_index_value = token_mem_index_value + child_key = child._compute_key(child.token_id_key) + assert child_key not in self.children.keys() + self.children[child_key] = child + child.parent = self + + new_len = len(child.token_mem_index_value) + child.node_value_len = new_len + child.node_prefix_total_len = child.parent.node_prefix_total_len + new_len + return child + + def remove_child(self, child_node: "TreeNode"): + del self.children[child_node._compute_key(child_node.token_id_key)] + child_node.parent = None + return + + def update_time(self): + self.time_id = time_gen.generate_time_id() + + def is_leaf(self): + return len(self.children) == 0 + + +def match(t1: torch.Tensor, t2: torch.Tensor) -> int: + t1_flat = t1.flatten() + t2_flat = t2.flatten() + min_len = min(t1_flat.size(0), t2_flat.size(0)) + diff = t1_flat[:min_len] != t2_flat[:min_len] + mismatch_indices = torch.nonzero(diff) + + if mismatch_indices.numel() == 0: + return min_len + else: + return mismatch_indices[0].item() + + +class PagedRadixCache: + def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None): + from lightllm.common.kv_cache_mem_manager import MemoryManager + + self.mem_manager: MemoryManager = mem_manager + self._key_dtype = torch.int64 + self._value_dtype = torch.int64 + self.page_size = get_page_size() + self._page_size_is_power_of_2 = (self.page_size & (self.page_size - 1)) == 0 + self._page_size_mask = self.page_size - 1 if self._page_size_is_power_of_2 else None + + self.root_node = TreeNode() + self.root_node.token_id_key = torch.zeros((0,), device="cpu", dtype=self._key_dtype) + self.root_node.token_mem_index_value = torch.zeros((0,), device="cpu", dtype=self._value_dtype) + self.root_node.ref_counter = 1 + + self.evict_tree_set: Set[TreeNode] = SortedSet(key=lambda x: x.get_compare_key()) + self.evict_tree_set.add(self.root_node) + + self.refed_tokens_num = SharedArray(f"{unique_name}_refed_tokens_num_{rank_in_node}", (1,), dtype=np.int64) + self.refed_tokens_num.arr[0] = 0 + self.tree_total_tokens_num = SharedArray( + f"{unique_name}_tree_total_tokens_num_{rank_in_node}", (1,), dtype=np.int64 + ) + self.tree_total_tokens_num.arr[0] = 0 + + def _align_prefix_len(self, prefix_len: int) -> int: + if self.page_size <= 1: + return prefix_len + if prefix_len % self.page_size == 0: + return prefix_len + if self._page_size_is_power_of_2: + return prefix_len & ~self._page_size_mask + return (prefix_len // self.page_size) * self.page_size + + def _get_page_aligned_key(self, key, value=None, free_truncated=False): + aligned_len = len(key) + if aligned_len == 0: + return None, None + if self.page_size > 1 and aligned_len % self.page_size != 0: + aligned_len = self._align_prefix_len(aligned_len) + if free_truncated and aligned_len < len(key) and self.mem_manager is not None and value is not None: + truncated_value = value[aligned_len:] + if len(truncated_value) > 0: + base = truncated_value[0] - truncated_value[0] % self.page_size + full_page = torch.arange( + base, base + self.page_size, dtype=truncated_value.dtype, device=truncated_value.device + ) + self.mem_manager.free(full_page) + return ( + key[:aligned_len] if aligned_len > 0 else None, + value[:aligned_len] if value is not None and aligned_len > 0 else None, + ) + return key, value + + def insert(self, key, value=None) -> Tuple[int, Optional[TreeNode]]: + if value is None: + value = key + + assert len(key) == len(value) + key, value = self._get_page_aligned_key(key, value, free_truncated=True) + if key is None: + return 0, None + return self._insert_helper(self.root_node, key, value) + + def _insert_helper(self, node: TreeNode, key, value) -> Tuple[int, Optional[TreeNode]]: + handle_stack = collections.deque() + update_list = collections.deque() + handle_stack.append((node, key, value)) + + ans_prefix_len = 0 + ans_node = None + + while len(handle_stack) != 0: + node, key, value = handle_stack.popleft() + ans_tuple = self._insert_helper_no_recursion(node=node, key=key, value=value) + if len(ans_tuple) == 4: + (_prefix_len, new_node, new_key, new_value) = ans_tuple + ans_prefix_len += _prefix_len + handle_stack.append((new_node, new_key, new_value)) + else: + _prefix_len, ans_node = ans_tuple + ans_prefix_len += _prefix_len + + update_list.append(node) + + while len(update_list) != 0: + cur_node: TreeNode = update_list.pop() + cur_node.update_time() + if cur_node.is_leaf(): + self.evict_tree_set.add(cur_node) + + assert ans_node is not None + + return ans_prefix_len, ans_node + + def _insert_helper_no_recursion( + self, node: TreeNode, key: torch.Tensor, value: torch.Tensor + ) -> Union[Tuple[int, Optional[TreeNode]], Tuple[int, TreeNode, torch.Tensor, torch.Tensor]]: + if node.is_leaf(): + self.evict_tree_set.discard(node) + + child_key = node._compute_key(key) + if child_key in node.children.keys(): + child: TreeNode = node.children[child_key] + prefix_len = match(key, child.token_id_key) + prefix_len = self._align_prefix_len(prefix_len) + if prefix_len == 0: + new_node = node.add_and_return_new_child(key, value) + self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) + if new_node.is_leaf(): + self.evict_tree_set.add(new_node) + return 0, new_node + if prefix_len == len(key): + if prefix_len == len(child.token_id_key): + if child.is_leaf(): + self.evict_tree_set.discard(child) + child.update_time() + if child.is_leaf(): + self.evict_tree_set.add(child) + return prefix_len, child + elif prefix_len < len(child.token_id_key): + if child.is_leaf(): + self.evict_tree_set.discard(child) + + split_parent_node = child.split_node(prefix_len) + + if split_parent_node.is_leaf(): + self.evict_tree_set.add(split_parent_node) + if child.is_leaf(): + self.evict_tree_set.add(child) + + return prefix_len, split_parent_node + else: + assert False, "can not run to here" + + elif prefix_len < len(key) and prefix_len < len(child.token_id_key): + if child.is_leaf(): + self.evict_tree_set.discard(child) + + key = key[prefix_len:] + value = value[prefix_len:] + split_parent_node = child.split_node(prefix_len) + new_node = split_parent_node.add_and_return_new_child(key, value) + self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) + + if split_parent_node.is_leaf(): + self.evict_tree_set.add(split_parent_node) + if new_node.is_leaf(): + self.evict_tree_set.add(new_node) + + if child.is_leaf(): + self.evict_tree_set.add(child) + return prefix_len, new_node + elif prefix_len < len(key) and prefix_len == len(child.token_id_key): + return (prefix_len, child, key[prefix_len:], value[prefix_len:]) + else: + assert False, "can not run to here" + + else: + new_node = node.add_and_return_new_child(key, value) + self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) + if new_node.is_leaf(): + self.evict_tree_set.add(new_node) + return 0, new_node + + def match_prefix(self, key, update_refs=False): + assert len(key) != 0 + key, _ = self._get_page_aligned_key(key) + if key is None: + return None, 0, None + ans_value_list = [] + tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) + if tree_node != self.root_node: + if len(ans_value_list) != 0: + value = torch.concat(ans_value_list) + else: + value = torch.zeros((0,), device="cpu", dtype=self._value_dtype) + return tree_node, len(value), value + else: + if update_refs: + self.dec_node_ref_counter(self.root_node) + return None, 0, None + + def _match_prefix_helper( + self, node: TreeNode, key: torch.Tensor, ans_value_list: list, update_refs=False + ) -> TreeNode: + handle_stack = collections.deque() + update_list = collections.deque() + handle_stack.append((node, key)) + + ans_node = None + + while len(handle_stack) != 0: + node, key = handle_stack.popleft() + ans_tuple = self._match_prefix_helper_no_recursion( + node=node, key=key, ans_value_list=ans_value_list, update_refs=update_refs + ) + if isinstance(ans_tuple, tuple): + new_node, new_key = ans_tuple + handle_stack.append((new_node, new_key)) + else: + ans_node = ans_tuple + + update_list.append(node) + + while len(update_list) != 0: + cur_node: TreeNode = update_list.pop() + cur_node.update_time() + if cur_node.is_leaf(): + self.evict_tree_set.add(cur_node) + + return ans_node + + def _match_prefix_helper_no_recursion( + self, node: TreeNode, key: torch.Tensor, ans_value_list: list, update_refs=False + ) -> TreeNode: + if node.is_leaf(): + self.evict_tree_set.discard(node) + + if update_refs: + node.ref_counter += 1 + if node.ref_counter == 1: + self.refed_tokens_num.arr[0] += len(node.token_mem_index_value) + + if len(key) == 0: + return node + + child_key = node._compute_key(key) + if child_key not in node.children.keys(): + return node + else: + child = node.children[child_key] + prefix_len = match(key, child.token_id_key) + prefix_len = self._align_prefix_len(prefix_len) + if prefix_len == 0: + return node + if prefix_len == len(child.token_id_key): + ans_value_list.append(child.token_mem_index_value) + return (child, key[prefix_len:]) + elif prefix_len < len(child.token_id_key): + if child.is_leaf(): + self.evict_tree_set.discard(child) + + split_parent_node = child.split_node(prefix_len) + ans_value_list.append(split_parent_node.token_mem_index_value) + + if update_refs: + split_parent_node.ref_counter += 1 + if split_parent_node.ref_counter == 1: + self.refed_tokens_num.arr[0] += len(split_parent_node.token_mem_index_value) + + if child.is_leaf(): + self.evict_tree_set.add(child) + if split_parent_node.is_leaf(): + self.evict_tree_set.add(split_parent_node) + + return split_parent_node + else: + assert False, "error state" + + def evict(self, need_remove_tokens, evict_callback): + if self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0] < need_remove_tokens: + assert False, f"""can not free tree tokens {need_remove_tokens}, + tree_total_tokens_num {self.tree_total_tokens_num.arr[0]}, + refed_tokens_num {self.refed_tokens_num.arr[0]}""" + num_evicted = 0 + while num_evicted < need_remove_tokens: + node: TreeNode = self.evict_tree_set.pop(0) + assert ( + node.ref_counter == 0 and len(node.children) == 0 and node != self.root_node + ), "error evict tree node state" + num_evicted += len(node.token_mem_index_value) + evict_callback(node.token_mem_index_value) + self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value) + parent_node: TreeNode = node.parent + parent_node.remove_child(node) + if parent_node.is_leaf(): + self.evict_tree_set.add(parent_node) + + return + + def _try_merge(self, child_node: TreeNode) -> Optional[TreeNode]: + parent_node = child_node.parent + if ( + parent_node is None + or parent_node == self.root_node + or parent_node.ref_counter != 0 + or len(parent_node.children) != 1 + or child_node.ref_counter != 0 + ): + return None + + if child_node.is_leaf(): + self.evict_tree_set.discard(child_node) + + child_node.token_id_key = torch.cat([parent_node.token_id_key, child_node.token_id_key]) + child_node.token_mem_index_value = torch.cat( + [parent_node.token_mem_index_value, child_node.token_mem_index_value] + ) + child_node.node_value_len = len(child_node.token_mem_index_value) + child_node.time_id = max(parent_node.time_id, child_node.time_id) + + grandparent_node = parent_node.parent + key_in_grandparent = grandparent_node._compute_key(parent_node.token_id_key) + grandparent_node.children[key_in_grandparent] = child_node + child_node.parent = grandparent_node + + parent_node.parent = None + + if child_node.is_leaf(): + self.evict_tree_set.add(child_node) + + return child_node + + def merge_unreferenced_nodes(self): + worklist = collections.deque( + [ + node + for node in self.evict_tree_set + if node.ref_counter == 0 and node.parent is not None and node.parent != self.root_node + ] + ) + + while worklist: + node = worklist.popleft() + if node.parent is None: + continue + merged_node = self._try_merge(node) + if merged_node: + worklist.append(merged_node) + + def clear_tree_nodes(self): + while True: + node: TreeNode = self.evict_tree_set.pop(0) + if node != self.root_node: + parent_node: TreeNode = node.parent + parent_node.remove_child(node) + if parent_node.is_leaf(): + self.evict_tree_set.add(parent_node) + else: + break + + self.tree_total_tokens_num.arr[0] = 0 + self.refed_tokens_num.arr[0] = 0 + return + + def dec_node_ref_counter(self, node: TreeNode): + if node is None: + return + old_node = node + if old_node.is_leaf(): + self.evict_tree_set.discard(old_node) + + while node is not None: + if node.ref_counter == 1: + self.refed_tokens_num.arr[0] -= len(node.token_mem_index_value) + node.ref_counter -= 1 + node = node.parent + + if old_node.is_leaf(): + self.evict_tree_set.add(old_node) + return + + def add_node_ref_counter(self, node: TreeNode): + if node is None: + return + old_node = node + if old_node.is_leaf(): + self.evict_tree_set.discard(old_node) + + while node is not None: + if node.ref_counter == 0: + self.refed_tokens_num.arr[0] += len(node.token_mem_index_value) + node.ref_counter += 1 + node = node.parent + + if old_node.is_leaf(): + self.evict_tree_set.add(old_node) + return + + def get_mem_index_value_by_node(self, node: TreeNode) -> Optional[torch.Tensor]: + if node is None: + return None + + ans_list = [] + while node is not None: + ans_list.append(node.token_mem_index_value) + node = node.parent + + ans_list.reverse() + return torch.concat(ans_list, dim=0) + + def get_refed_tokens_num(self): + return self.refed_tokens_num.arr[0] + + def get_tree_total_tokens_num(self): + return self.tree_total_tokens_num.arr[0] + + def print_self(self, indent=0): + self._print_helper(self.root_node, indent) + + def _print_helper(self, node: TreeNode, indent): + print( + " " * indent, + f"k: {node.token_id_key[0:10]} v: {node.token_mem_index_value[0:10]} refs: {node.ref_counter} \ + time_id: {node.time_id} prefix_total_len: {node.node_prefix_total_len} \ + node_value_len: {node.node_value_len}", + ) + for _, child in node.children.items(): + self._print_helper(child, indent=indent + 2) + return + + def free_radix_cache_to_get_enough_token(self, need_token_num): + assert self.mem_manager is not None + if need_token_num > self.mem_manager.allocator.can_use_mem_size: + need_evict_token_num = need_token_num - self.mem_manager.allocator.can_use_mem_size + release_mems = [] + + def release_mem(mem_index): + release_mems.append(mem_index) + return + + self.evict(need_evict_token_num, release_mem) + mem_index = torch.concat(release_mems) + self.mem_manager.free(mem_index) + return diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index f0ec69b2c1..410e7a724d 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -513,6 +513,7 @@ def __init__( self.shm_index = shm_index self.multimodal_params = multimodal_params self.vocab_size = vocab_size + self.last_kv_mem_index = -1 # 请求需要被暂停 self.wait_pause = False @@ -626,6 +627,7 @@ def _match_radix_cache(self): # 从 cpu 到 gpu 是流内阻塞操作 g_infer_context.req_manager.req_to_token_indexs[self.req_idx, 0:ready_cache_len] = value_tensor self.cur_kv_len = int(ready_cache_len) # 序列化问题, 该对象可能为numpy.int64,用 int(*)转换 + self.last_kv_mem_index = value_tensor[-1].item() if ready_cache_len > 0 else -1 self.shm_req.prompt_cache_len = self.cur_kv_len # 记录 prompt cache 的命中长度 self.shm_req.shm_cur_kv_len = self.cur_kv_len @@ -662,6 +664,7 @@ def _linear_match_radix_cache(self): # 从 cpu 到 gpu 是流内阻塞操作 g_infer_context.req_manager.req_to_token_indexs[self.req_idx, 0:ready_cache_len] = value_tensor self.cur_kv_len = int(ready_cache_len) # 序列化问题, 该对象可能为numpy.int64,用 int(*)转换 + self.last_kv_mem_index = value_tensor[-1].item() if ready_cache_len > 0 else -1 self.shm_req.prompt_cache_len = self.cur_kv_len # 记录 prompt cache 的命中长度 assert self.tail_linear_att_small_page_buffer_id is None # 恢复linear att 状态 @@ -677,6 +680,7 @@ def _linear_match_radix_cache(self): # 从 cpu 到 gpu 是流内阻塞操作 g_infer_context.req_manager.req_to_token_indexs[self.req_idx, 0:ready_cache_len] = value_tensor self.cur_kv_len = int(ready_cache_len) # 序列化问题, 该对象可能为numpy.int64,用 int(*)转换 + self.last_kv_mem_index = value_tensor[-1].item() if ready_cache_len > 0 else -1 self.shm_req.prompt_cache_len = self.cur_kv_len # 记录 prompt cache 的命中长度 assert self.tail_linear_att_small_page_buffer_id is None # 恢复linear att 状态 @@ -717,6 +721,12 @@ def _linear_match_radix_cache(self): big_page_shared_node = radix_cache.deref_to_first_big_page_node(node=share_node) self.shared_kv_node = big_page_shared_node self.cur_kv_len = int(shared_kv_len) # 序列化问题, 该对象可能为numpy.int64,用 int(*)转换 + if need_tokens > 0: + self.last_kv_mem_index = tail_mems[-1].item() + else: + self.last_kv_mem_index = ( + value_tensor[cur_big_page_tokens - 1].item() if cur_big_page_tokens > 0 else -1 + ) self.shm_req.prompt_cache_len = self.cur_kv_len # 记录 prompt cache 的命中长度 else: # 没有充足的token 容量时, 直接找到最接近的大页,进行大页恢复 @@ -731,6 +741,9 @@ def _linear_match_radix_cache(self): self.req_idx, 0:ready_cache_len ] = value_tensor[0:ready_cache_len] self.cur_kv_len = int(ready_cache_len) # 序列化问题, 该对象可能为numpy.int64,用 int(*)转换 + self.last_kv_mem_index = ( + value_tensor[ready_cache_len - 1].item() if ready_cache_len > 0 else -1 + ) self.shm_req.prompt_cache_len = self.cur_kv_len # 记录 prompt cache 的命中长度 assert self.tail_linear_att_small_page_buffer_id is None # 恢复linear att 状态 diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 0220dc87fb..665e8dbb2d 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -9,6 +9,7 @@ from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.log_utils import init_logger from lightllm.models import get_model +from lightllm.server.router.dynamic_prompt.paged_radix_cache import PagedRadixCache from lightllm.server.router.model_infer.infer_batch import InferReq, InferReqUpdatePack from lightllm.server.router.token_load import TokenLoad from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock @@ -31,6 +32,7 @@ from lightllm.utils.dist_utils import get_dp_rank_in_node, create_new_group_for_current_node from lightllm.utils.envs_utils import ( get_env_start_args, + get_page_size, enable_radix_tree_timer_merge, get_radix_tree_merge_update_delta, ) @@ -194,7 +196,8 @@ def init_model(self, kvargs): linear_att_small_page_buffers=self.linear_att_cache_manager, ) else: - self.radix_cache = RadixCache( + radix_cache_class = PagedRadixCache if get_page_size() > 1 else RadixCache + self.radix_cache = radix_cache_class( unique_name=get_unique_server_name(), total_token_num=self.model.mem_manager.size, rank_in_node=self.rank_in_node, diff --git a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py index d3796b6392..840d5a3e62 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py @@ -97,8 +97,18 @@ def padded_prepare_prefill_inputs( # dynamic prompt cache 准备 token g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num) + token_num = g_infer_context.req_manager.calc_real_need_token_num( + input_ids.shape[0] - padded_req_num, b_seq_len[: len(req_objs)], b_ready_cache_len[: len(req_objs)] + ) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num) + mem_indexes = g_infer_context.req_manager.alloc_token_indexes( + input_ids.shape[0] - padded_req_num, b_seq_len[: len(req_objs)], b_ready_cache_len[: len(req_objs)] + ) + b_last_mem_index = g_infer_context.req_manager.calc_last_mem_index_in_prefill( + mem_indexes, b_seq_len[: len(req_objs)], b_ready_cache_len[: len(req_objs)] + ) + for i, req in enumerate(req_objs): + req.last_kv_mem_index = b_last_mem_index[i].item() g_infer_state_lock.release() if padded_req_num > 0: @@ -152,6 +162,7 @@ def padded_prepare_decode_inputs( b_mtp_index = [] b_seq_len = [] b_q_seq_len = [] + b_last_mem_index = [] args_mtp_step = get_env_start_args().mtp_step batch_multimodal_params = [] for req in req_objs: @@ -164,6 +175,7 @@ def padded_prepare_decode_inputs( total_token_num += seq_len b_mtp_index.append(0) batch_multimodal_params.append(req.multimodal_params) + b_last_mem_index.append(req.last_kv_mem_index) # process the draft tokens. for step in range(req.mtp_step): run_reqs.append(req) @@ -199,13 +211,23 @@ def padded_prepare_decode_inputs( b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu") b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu") b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu") + b_last_mem_index = torch.tensor(b_last_mem_index, dtype=torch.int32, device="cpu") # dynamic prompt cache 准备 token padded_mem_indexes_num = padded_req_num * (args_mtp_step + 1) g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0] - padded_mem_indexes_num) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0] - padded_mem_indexes_num) + token_num = g_infer_context.req_manager.calc_real_need_token_num( + b_seq_len.shape[0] - padded_mem_indexes_num, b_seq_len[: len(b_last_mem_index)] + ) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num) + mem_indexes = g_infer_context.req_manager.alloc_token_indexes( + b_seq_len.shape[0] - padded_mem_indexes_num, + b_seq_len[: len(b_last_mem_index)], + b_last_mem_index=b_last_mem_index, + ) + for i, req in enumerate(req_objs): + req.last_kv_mem_index = mem_indexes[i].item() g_infer_state_lock.release() if padded_mem_indexes_num > 0: diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index ae1af19565..e6aa0b6c61 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -71,8 +71,16 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool) -> # dynamic prompt cache 准备 token g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0]) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0]) + token_num = g_infer_context.req_manager.calc_real_need_token_num( + input_ids.shape[0], b_seq_len, b_ready_cache_len + ) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num) + mem_indexes = g_infer_context.req_manager.alloc_token_indexes(input_ids.shape[0], b_seq_len, b_ready_cache_len) + b_last_mem_index = g_infer_context.req_manager.calc_last_mem_index_in_prefill( + mem_indexes, b_seq_len, b_ready_cache_len + ) + for i, req in enumerate(req_objs): + req.last_kv_mem_index = b_last_mem_index[i].item() g_infer_state_lock.release() model_input = ModelInput( @@ -105,6 +113,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In b_mtp_index = [] b_seq_len = [] b_q_seq_len = [] + b_last_mem_index = [] multimodal_params = [] for req in req_objs: run_reqs.append(req) @@ -116,6 +125,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In total_token_num += seq_len b_mtp_index.append(0) multimodal_params.append(req.multimodal_params) + b_last_mem_index.append(req.last_kv_mem_index) # process the draft tokens. for step in range(req.mtp_step): run_reqs.append(req) @@ -133,6 +143,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu") b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu") b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu") + b_last_mem_index = torch.tensor(b_last_mem_index, dtype=torch.int32, device="cpu") if enable_diverse_mode_gqa_decode_fast_kernel(): b_shared_seq_len, b_mark_shared_group = build_diverse_shared_group_infos(run_reqs=run_reqs) @@ -143,8 +154,13 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In # dynamic prompt cache 准备 token g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0]) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0]) + token_num = g_infer_context.req_manager.calc_real_need_token_num(b_seq_len.shape[0], b_seq_len) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num) + mem_indexes = g_infer_context.req_manager.alloc_token_indexes( + b_seq_len.shape[0], b_seq_len, b_last_mem_index=b_last_mem_index + ) + for i, req in enumerate(req_objs): + req.last_kv_mem_index = mem_indexes[i].item() g_infer_state_lock.release() model_input = ModelInput( diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl.py b/lightllm/server/router/req_queue/chunked_prefill/impl.py index 884b5930b0..6eec51367a 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl.py @@ -3,6 +3,7 @@ from ...batch import Batch, Req from lightllm.server.router.req_queue.base_queue import BaseQueue from lightllm.common.basemodel.infer_lock import g_router_lock +from lightllm.utils.envs_utils import get_page_size from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -38,9 +39,11 @@ def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() with g_router_lock.obj: + page_size = get_page_size() + page_remaining = len(self.cache_len_list) * (page_size - 1) if page_size > 1 else 0 ok_token_num = ( need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index) - < self.max_total_tokens + < self.max_total_tokens - page_remaining ) ok_req_num = len(self.cache_len_list) <= self.running_max_req_size diff --git a/lightllm/utils/dist_check_utils.py b/lightllm/utils/dist_check_utils.py index 0f2548be8d..fc17f14ee2 100644 --- a/lightllm/utils/dist_check_utils.py +++ b/lightllm/utils/dist_check_utils.py @@ -17,7 +17,7 @@ logger = init_logger(__name__) _CUSTOM_ALLREDUCE_WORLD_SIZES = (2, 4, 6, 8) -_TWO_GPU_CHECK_TIMEOUT_SECONDS = 600.0 # 给flashinfer jit编译预留足够时间 +_TWO_GPU_CHECK_TIMEOUT_SECONDS = 60.0 def _start_two_gpu_check_timeout_watchdog(backend_name: str) -> threading.Event: diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index 2bdd4005fa..db8e8c19d0 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -165,6 +165,16 @@ def get_triton_autotune_level(): return int(os.getenv("LIGHTLLM_TRITON_AUTOTUNE_LEVEL", 0)) +@lru_cache(maxsize=None) +def get_page_size(): + return int(os.getenv("PAGE_SIZE", 1)) + + +def set_page_size(page_size: int): + os.environ["PAGE_SIZE"] = str(page_size) + get_page_size.cache_clear() + + g_model_init_done = False diff --git a/lightllm/utils/fa4_utils.py b/lightllm/utils/fa4_utils.py index d7dfae722f..4780a61c50 100644 --- a/lightllm/utils/fa4_utils.py +++ b/lightllm/utils/fa4_utils.py @@ -42,5 +42,45 @@ def ensure_fa4_supported_gpu() -> None: ) +def sm90_fa4_paged_kv_tile_n(head_dim: int, head_dim_v: int, window_size: tuple[int, int] = (-1, -1)) -> int | None: + major, _minor = torch.cuda.get_device_capability() + if major != 9: + return None + + is_local = window_size != (-1, -1) + if head_dim <= 64: + return 128 + if head_dim <= 96: + return 128 if is_local else 144 + if head_dim <= 128: + return 128 + if head_dim <= 192: + return 96 if is_local else (128 if head_dim_v <= 128 else 112) + return 64 if is_local else 80 + + +def infer_fa4_page_size(model_dir: str) -> int | None: + from transformers.configuration_utils import PretrainedConfig + from lightllm.utils.device_utils import is_sm100_gpu + + if is_sm100_gpu(): + return 128 + + model_cfg, _ = PretrainedConfig.get_config_dict(model_dir) + llm_config = model_cfg.get("text_config", model_cfg) + + head_dim = llm_config.get("head_dim") + if head_dim is None: + head_dim = llm_config["hidden_size"] // llm_config["num_attention_heads"] + head_dim_v = llm_config.get("v_head_dim", head_dim) + + window_size = (-1, -1) + sliding_window = llm_config.get("sliding_window", None) + if sliding_window is not None and not llm_config.get("full_attention_interval", None): + window_size = (sliding_window - 1, sliding_window - 1) + + return sm90_fa4_paged_kv_tile_n(head_dim=head_dim, head_dim_v=head_dim_v, window_size=window_size) + + def unwrap_fa4_output(output): return output[0] if isinstance(output, tuple) else output diff --git a/requirements.txt b/requirements.txt index f124ce76f5..e80bc98196 100644 --- a/requirements.txt +++ b/requirements.txt @@ -98,3 +98,4 @@ nixl==1.1.0 xformers==0.0.35 redis==7.3.0 litellm>=1.52.0,<1.85 +flash-attn-4[13]==4.0.0b15 diff --git a/unit_tests/common/basemodel/triton_kernel/test_repack_kv_index.py b/unit_tests/common/basemodel/triton_kernel/test_repack_kv_index.py index b5184d3caa..0bab0ae540 100644 --- a/unit_tests/common/basemodel/triton_kernel/test_repack_kv_index.py +++ b/unit_tests/common/basemodel/triton_kernel/test_repack_kv_index.py @@ -1,7 +1,8 @@ import torch import pytest from lightllm.utils.log_utils import init_logger -from lightllm.common.basemodel.triton_kernel.repack_kv_index import repack_kv_index +from lightllm.common.basemodel.triton_kernel.repack_kv_index import repack_kv_index, paged_repack_kv_index +from lightllm.utils.envs_utils import get_page_size logger = init_logger(__name__) @@ -41,3 +42,49 @@ def repack_kv_ref(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, output repack_kv_ref(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, ref) repack_kv_index(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, MAX_SEQ_LEN, output) assert torch.allclose(output.float(), ref.float()) + + +@pytest.mark.parametrize( + "batch, max_seq_len, page_size", + [ + (1, 16, 4), + (8, 32, 4), + (16, 128, 8), + ], +) +def test_paged_repack_kv_index(batch, max_seq_len, page_size, monkeypatch): + def repack_page_kv_ref(req_to_token_indexs, b_req_idx, b_page_len, b_start_loc, output, page_size): + for b, sl, start in zip(b_req_idx, b_page_len, b_start_loc): + output[start : start + sl] = req_to_token_indexs[b][: sl * page_size : page_size] // page_size + + BATCH, MAX_SEQ_LEN = batch, max_seq_len + max_page_len = (MAX_SEQ_LEN + page_size - 1) // page_size + total_token_len = 2 * MAX_SEQ_LEN + total_page_len = (total_token_len + page_size - 1) // page_size + + req_to_token_indexs = torch.empty((2 * BATCH, total_token_len), dtype=torch.int32, device="cuda") + page_offsets = torch.arange(page_size, dtype=torch.int32, device="cuda") + for row in range(2 * BATCH): + page_ids = torch.arange(row * total_page_len, (row + 1) * total_page_len, dtype=torch.int32, device="cuda") + req_to_token_indexs[row] = (page_ids[:, None] * page_size + page_offsets[None, :]).reshape(-1)[:total_token_len] + + b_req_idx = torch.randperm(BATCH, device="cuda", dtype=torch.int32) + b_seq_len = torch.randint(1, MAX_SEQ_LEN + 1, (BATCH,), device="cuda", dtype=torch.int32) + b_page_len = (b_seq_len + page_size - 1) // page_size + b_start_loc = torch.cat( + [torch.zeros((1,), dtype=torch.int32, device="cuda"), b_page_len[:-1].cumsum(dim=0, dtype=torch.int32)] + ) + + output = torch.zeros((b_page_len.sum(),), dtype=torch.int32, device="cuda") + ref = torch.zeros((b_page_len.sum(),), dtype=torch.int32, device="cuda") + + monkeypatch.setenv("PAGE_SIZE", str(page_size)) + get_page_size.cache_clear() + try: + repack_page_kv_ref(req_to_token_indexs, b_req_idx, b_page_len, b_start_loc, ref, page_size) + paged_repack_kv_index(req_to_token_indexs, b_req_idx, b_page_len, b_start_loc, max_page_len, output) + finally: + monkeypatch.delenv("PAGE_SIZE", raising=False) + get_page_size.cache_clear() + + assert torch.equal(output, ref)