From fdc4de178f868905f96da026026af2201aa44f78 Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 4 Jun 2026 13:26:15 +0800 Subject: [PATCH 01/10] feat(qwen3_5_mtp): split linear-attn cache state for spec-decode verify Make the linear-attention (GDN) cache able to serve a speculative verify pass over multiple draft tokens without corrupting the canonical per-request state: - conv-state shape splits into a widened GPU slot (holds the in-flight verify window) vs the narrow slot that is persisted/restored, while the SSM state keeps an (S+1) block so each draft position has a slot. - snapshot/restore helpers read the committed conv window + SSM block slot and reset the carried accept_len, so the next step reads from the canonical offset-0 / block-0 pointer. - relax the ReqManagerForMamba / CPU-cache MTP gates for hybrid models (draft KV is not persisted) and enforce the S<=7 bound. Covered by conv-state shape-split, snapshot-split, and mamba req-manager gate unit tests. --- lightllm/common/basemodel/attention/fa3/fp.py | 4 +- lightllm/common/basemodel/basemodel.py | 1 - .../triton_kernel/linear_att_copy.py | 100 +++++++++++------- .../operator/linear_att.py | 14 ++- .../linear_att_cache_manager/config_objs.py | 16 ++- lightllm/common/req_manager.py | 41 +++---- .../mode_backend/chunked_prefill/impl.py | 2 - lightllm/utils/kv_cache_utils.py | 6 +- .../common/test_conv_state_shape_split.py | 33 ++++++ .../common/test_linear_att_snapshot_split.py | 41 +++++++ .../common/test_mamba_req_manager_gate.py | 10 ++ 11 files changed, 203 insertions(+), 65 deletions(-) create mode 100644 unit_tests/common/test_conv_state_shape_split.py create mode 100644 unit_tests/common/test_linear_att_snapshot_split.py create mode 100644 unit_tests/common/test_mamba_req_manager_gate.py diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index 952bb39d91..5b7960e715 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -91,7 +91,7 @@ def _nomarl_prefill_att( k_descale, v_descale = None, None # disable quantization Lq = q.shape[-1] - sm_scale = 1.0 / (Lq ** 0.5) + sm_scale = 1.0 / (Lq**0.5) o = flash_attn_with_kvcache( q=q, k_cache=k.view(k.shape[0], 1, k.shape[1], k.shape[2]), @@ -221,7 +221,7 @@ def _normal_decode_att( k_descale, v_descale = None, None # disable quantization Lq = q.shape[-1] - sm_scale = 1.0 / (Lq ** 0.5) + sm_scale = 1.0 / (Lq**0.5) o = flash_attn_with_kvcache( q=q, k_cache=k.view(k.shape[0], 1, k.shape[1], k.shape[2]), diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 473dcbafda..b4726754f4 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -604,7 +604,6 @@ def _decode( @final def _context_forward(self, infer_state: InferStateInfo): - input_embs = self.pre_infer.context_forward(infer_state.input_ids, infer_state, self.pre_post_weight) if self.args.enable_dp_prefill_balance: assert not self.args.enable_prefill_cudagraph, "not support now" diff --git a/lightllm/common/basemodel/triton_kernel/linear_att_copy.py b/lightllm/common/basemodel/triton_kernel/linear_att_copy.py index d9f631cbd0..5fb98c4daa 100644 --- a/lightllm/common/basemodel/triton_kernel/linear_att_copy.py +++ b/lightllm/common/basemodel/triton_kernel/linear_att_copy.py @@ -5,12 +5,13 @@ @triton.jit def _copy_linear_att_state_to_kv_buffer( - gpu_conv_ptr, # [linear_layer_num, size_num, xdim] + gpu_conv_ptr, # [linear_layer_num, size_num, conv_dim * gpu_widened_width] (uint8 tail) gpu_ssm_ptr, # [linear_layer_num, size_num, xxdim] - cpu_kv_conv_ptr, # [size, linear_layer_num, xdim] + cpu_kv_conv_ptr, # [size, linear_layer_num, conv_dim * width_narrow] (uint8 tail) cpu_kv_ssm_ptr, # [size, linear_layer_num, xxdim] b_req_idx, # [batch_size,] big_page_buffer_ids, # [batch_size,] + num_accepted_tokens_ptr, # [batch_size,] gpu_conv_stride_l, gpu_conv_stride_s, gpu_conv_stride_d, @@ -24,7 +25,9 @@ def _copy_linear_att_state_to_kv_buffer( cpu_kv_ssm_stride_l, cpu_kv_ssm_stride_d, mtp_step, - gpu_conv_tail_dim, + conv_dim, # number of conv rows (the d dimension) + gpu_conv_row_bytes, # widened per-row byte length: gpu_widened_width * itemsize + conv_narrow_row_bytes, # narrow per-row byte length: width_narrow * itemsize gpu_ssm_tail_dim, BLOCK: tl.constexpr, ): @@ -40,28 +43,37 @@ def _copy_linear_att_state_to_kv_buffer( return cur_req_idx = tl.load(b_req_idx + cur_batch).to(tl.int64) - cur_state_req_idx = (cur_req_idx * (mtp_step + 1)).to(tl.int64) + accept_len = tl.load(num_accepted_tokens_ptr + cur_batch).to(tl.int64) + canonical_off = accept_len - 1 - for i in range(tl.cdiv(gpu_conv_tail_dim, BLOCK)): - gpu_start_off = i * BLOCK + tl.arange(0, BLOCK) - mask = gpu_start_off < gpu_conv_tail_dim - conv_data = tl.load( - gpu_conv_ptr + cur_layer * gpu_conv_stride_l + cur_state_req_idx * gpu_conv_stride_s + gpu_start_off, - mask=mask, - ) - dest_conv_ptr = ( - cpu_kv_conv_ptr - + big_page_buffer_idx * cpu_kv_conv_stride_s - + cur_layer * cpu_kv_conv_stride_l - + gpu_start_off - ) - tl.store(dest_conv_ptr, conv_data, mask=mask) + # --- conv snapshot --- + # conv is a single WIDENED slot keyed by req_idx (asymmetric layout, §3.4). + # The committed NARROW window of byte length conv_narrow_row_bytes sits at + # byte offset canonical_off * itemsize inside each widened row. The flattened + # uint8 tail lays out element [d, w] at d * gpu_conv_row_bytes + w (bytes), + # so the narrow window is strided per row: copy row-by-row. + conv_src_slot = cur_req_idx + # gpu_conv_stride_d carries the per-element byte size (itemsize); the narrow + # window starts canonical_off elements into the widened row. + conv_off_bytes = canonical_off * gpu_conv_stride_d + gpu_conv_base = gpu_conv_ptr + cur_layer * gpu_conv_stride_l + conv_src_slot * gpu_conv_stride_s + conv_off_bytes + cpu_conv_base = cpu_kv_conv_ptr + big_page_buffer_idx * cpu_kv_conv_stride_s + cur_layer * cpu_kv_conv_stride_l + for d in range(conv_dim): + for i in range(tl.cdiv(conv_narrow_row_bytes, BLOCK)): + off = i * BLOCK + tl.arange(0, BLOCK) + mask = off < conv_narrow_row_bytes + conv_data = tl.load(gpu_conv_base + d * gpu_conv_row_bytes + off, mask=mask) + tl.store(cpu_conv_base + d * cpu_kv_conv_stride_d + off, conv_data, mask=mask) + # --- ssm snapshot --- + # ssm is an (S+1) BLOCK per request; the committed block slot is + # req_idx * (mtp_step + 1) + canonical_off. + ssm_src_slot = (cur_req_idx * (mtp_step + 1) + canonical_off).to(tl.int64) for i in range(tl.cdiv(gpu_ssm_tail_dim, BLOCK)): gpu_start_off = i * BLOCK + tl.arange(0, BLOCK) mask = gpu_start_off < gpu_ssm_tail_dim ssm_data = tl.load( - gpu_ssm_ptr + cur_layer * gpu_ssm_stride_l + cur_state_req_idx * gpu_ssm_stride_s + gpu_start_off, + gpu_ssm_ptr + cur_layer * gpu_ssm_stride_l + ssm_src_slot * gpu_ssm_stride_s + gpu_start_off, mask=mask, ) dest_ssm_ptr = ( @@ -75,32 +87,43 @@ def _copy_linear_att_state_to_kv_buffer( def copy_linear_att_state_to_kv_buffer( b_req_idx: torch.Tensor, big_page_buffer_ids: torch.Tensor, - gpu_conv_state: torch.Tensor, # [linear_layer_num, s, ...] - gpu_ssm_state: torch.Tensor, # [linear_layer_num, s, ...] - cpu_kv_conv_state: torch.Tensor, # [s, linear_layer_num, ...] - cpu_kv_ssm_state: torch.Tensor, # [s, linear_layer_num, ...] + gpu_conv_state: torch.Tensor, # [linear_layer_num, s_widened, conv_dim, gpu_widened_width] + gpu_ssm_state: torch.Tensor, # [linear_layer_num, s_block, ...] + cpu_kv_conv_state: torch.Tensor, # [size, linear_layer_num, conv_dim, width_narrow] + cpu_kv_ssm_state: torch.Tensor, # [size, linear_layer_num, ...] mtp_step: int, + b_num_accepted_tokens: torch.Tensor, # [batch_size,] per-req post-accept count (>=1) ): assert len(b_req_idx) == big_page_buffer_ids.shape[0] + assert len(b_req_idx) == b_num_accepted_tokens.shape[0] BLOCK = 4096 - gpu_conv_state = gpu_conv_state.view(gpu_conv_state.shape[0], gpu_conv_state.shape[1], -1).view(dtype=torch.uint8) + + # Conv: keep the (conv_dim, width) tail un-flattened so the committed narrow + # window can be read per row at the canonical offset (the window is strided + # in the flattened widened layout). Capture itemsize BEFORE the uint8 view to + # convert the element-unit canonical offset into a byte offset. + assert gpu_conv_state.dim() >= 4, "gpu_conv_state must be [layer, s, conv_dim, widened_width]" + assert cpu_kv_conv_state.dim() >= 4, "cpu_kv_conv_state must be [size, layer, conv_dim, width_narrow]" + conv_itemsize = gpu_conv_state.element_size() + gpu_conv_state = gpu_conv_state.view( + gpu_conv_state.shape[0], gpu_conv_state.shape[1], gpu_conv_state.shape[2], -1 + ).view(dtype=torch.uint8) + cpu_kv_conv_state = cpu_kv_conv_state.view( + cpu_kv_conv_state.shape[0], cpu_kv_conv_state.shape[1], cpu_kv_conv_state.shape[2], -1 + ).view(dtype=torch.uint8) + gpu_ssm_state = gpu_ssm_state.view(gpu_ssm_state.shape[0], gpu_ssm_state.shape[1], -1).view(dtype=torch.uint8) - cpu_kv_conv_state = cpu_kv_conv_state.view(cpu_kv_conv_state.shape[0], cpu_kv_conv_state.shape[1], -1).view( - dtype=torch.uint8 - ) cpu_kv_ssm_state = cpu_kv_ssm_state.view(cpu_kv_ssm_state.shape[0], cpu_kv_ssm_state.shape[1], -1).view( dtype=torch.uint8 ) - assert gpu_conv_state.shape[-1] == cpu_kv_conv_state.shape[-1] + + assert gpu_conv_state.shape[2] == cpu_kv_conv_state.shape[2], "conv_dim mismatch between gpu and cpu conv buffers" assert gpu_ssm_state.shape[-1] == cpu_kv_ssm_state.shape[-1] - assert ( - gpu_conv_state.stride(-1) - == gpu_ssm_state.stride(-1) - == cpu_kv_conv_state.stride(-1) - == cpu_kv_ssm_state.stride(-1) - ) - gpu_conv_tail_dim = gpu_conv_state.shape[-1] + conv_dim = gpu_conv_state.shape[2] + gpu_conv_row_bytes = gpu_conv_state.shape[-1] # widened per-row byte length + conv_narrow_row_bytes = cpu_kv_conv_state.shape[-1] # narrow per-row byte length + assert conv_narrow_row_bytes <= gpu_conv_row_bytes gpu_ssm_tail_dim = gpu_ssm_state.shape[-1] layer_num = gpu_conv_state.shape[0] @@ -114,9 +137,10 @@ def copy_linear_att_state_to_kv_buffer( cpu_kv_ssm_ptr=cpu_kv_ssm_state, b_req_idx=b_req_idx, big_page_buffer_ids=big_page_buffer_ids, + num_accepted_tokens_ptr=b_num_accepted_tokens, gpu_conv_stride_l=gpu_conv_state.stride(0), gpu_conv_stride_s=gpu_conv_state.stride(1), - gpu_conv_stride_d=gpu_conv_state.stride(2), + gpu_conv_stride_d=conv_itemsize, gpu_ssm_stride_l=gpu_ssm_state.stride(0), gpu_ssm_stride_s=gpu_ssm_state.stride(1), gpu_ssm_stride_d=gpu_ssm_state.stride(2), @@ -127,7 +151,9 @@ def copy_linear_att_state_to_kv_buffer( cpu_kv_ssm_stride_l=cpu_kv_ssm_state.stride(1), cpu_kv_ssm_stride_d=cpu_kv_ssm_state.stride(2), mtp_step=mtp_step, - gpu_conv_tail_dim=gpu_conv_tail_dim, + conv_dim=conv_dim, + gpu_conv_row_bytes=gpu_conv_row_bytes, + conv_narrow_row_bytes=conv_narrow_row_bytes, gpu_ssm_tail_dim=gpu_ssm_tail_dim, BLOCK=BLOCK, ) diff --git a/lightllm/common/kv_cache_mem_manager/operator/linear_att.py b/lightllm/common/kv_cache_mem_manager/operator/linear_att.py index 109e813220..586706c8e1 100644 --- a/lightllm/common/kv_cache_mem_manager/operator/linear_att.py +++ b/lightllm/common/kv_cache_mem_manager/operator/linear_att.py @@ -76,11 +76,16 @@ def load_cpu_cache_to_gpu( copy_cpu_cache_to_kv_buffer, ) + # Persist/restore ONLY the main model's full-attn slice. The kv buffer is widened by + # dedicated MTP draft slots [main_full_att, main_full_att + draft) (speculative KV that + # must never touch the CPU/disk cache), so slice them off here. + main_full_att = getattr(mem_manager, "main_full_att_layer_num", mem_manager.kv_buffer.shape[0]) + copy_cpu_cache_to_kv_buffer( mem_indexes=mem_indexes, big_page_buffer_ids=big_page_buffer_ids_gpu, page_indexes=page_indexes, - gpu_full_att_kv_state=mem_manager.kv_buffer, + gpu_full_att_kv_state=mem_manager.kv_buffer[:main_full_att], cpu_kv_conv_state=mem_manager.linear_att_big_page_buffers.conv_state_cache.buffer, cpu_kv_ssm_state=mem_manager.linear_att_big_page_buffers.ssm_state_cache.buffer, cpu_cache_tensor=cpu_cache_client.cpu_kv_cache_tensor, @@ -169,12 +174,17 @@ def offload_gpu_kv_to_cpu_cache( copy_kv_buffer_to_cpu_cache, ) + # Persist ONLY the main model's full-attn slice. The kv buffer is widened by dedicated + # MTP draft slots [main_full_att, main_full_att + draft) (speculative KV that must never + # be persisted to the CPU/disk cache), so slice them off here. + main_full_att = getattr(mem_manager, "main_full_att_layer_num", mem_manager.kv_buffer.shape[0]) + copy_kv_buffer_to_cpu_cache( mem_indexes=mem_indexes, page_indexes=page_indexes, page_readies=page_readies, big_page_buffer_ids=big_page_buffer_ids_gpu, - gpu_kv_full_att_state=mem_manager.kv_buffer, + gpu_kv_full_att_state=mem_manager.kv_buffer[:main_full_att], cpu_kv_conv_state=mem_manager.linear_att_big_page_buffers.conv_state_cache.buffer, cpu_kv_ssm_state=mem_manager.linear_att_big_page_buffers.ssm_state_cache.buffer, cpu_cache_tensor=cpu_cache_client.cpu_kv_cache_tensor, diff --git a/lightllm/common/linear_att_cache_manager/config_objs.py b/lightllm/common/linear_att_cache_manager/config_objs.py index 46ab9d2107..ca6415e16f 100644 --- a/lightllm/common/linear_att_cache_manager/config_objs.py +++ b/lightllm/common/linear_att_cache_manager/config_objs.py @@ -32,9 +32,23 @@ class LinearAttCacheConfig: def get_conv_dim(self): return self.head_linear_k_dim * self.num_linear_k_heads * 2 + self.head_linear_v_dim * self.num_linear_v_heads - def get_conv_state_shape(self): + def get_persisted_conv_state_shape(self): + # NARROW shape used for the CPU/disk persisted page and ALL byte math. + # Persisted state is always the committed (narrow) sliding window. return (self.get_conv_dim(), self.conv_kernel_size - 1) + def get_gpu_conv_state_shape(self, mtp_step: int): + # WIDENED working shape for the GPU buffer: holds the tentatively + # rolled-in S speculative tokens before acceptance. width-1 + S, where + # S = mtp_step (a verify step has seqlen=S+1 -> width-1+(seqlen-1)). + return (self.get_conv_dim(), (self.conv_kernel_size - 1) + mtp_step) + + # Backward-compatible alias: anything that persists / sizes the CPU page + # must use the NARROW shape. Kept as the default so existing callers stay + # correct; the GPU buffer alloc is migrated to get_gpu_conv_state_shape. + def get_conv_state_shape(self): + return self.get_persisted_conv_state_shape() + def get_ssm_state_shape(self): return (self.num_linear_v_heads, self.head_linear_k_dim, self.head_linear_v_dim) diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 01e9c4ad35..f3f33901a1 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -236,15 +236,16 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager, linear_con self.big_page_token_num = ( get_env_start_args().linear_att_page_block_num * get_env_start_args().linear_att_hash_page_size ) - assert ( - self.mtp_step == 0 - ), "currently only support mtp_step 0 for simplicity, more mtp_step support will be added in the future" + assert self.mtp_step <= 7, ( + f"mtp_step={self.mtp_step} exceeds 7; req_to_next_token_ids width is 8 " + "(widening it is an explicit follow-up, spec §9)" + ) self.linear_config = linear_config self.req_to_conv_state = LayerCache( - size=(max_request_num + 1) * (self.mtp_step + 1), + size=(max_request_num + 1), dtype=self.linear_config.conv_state_dtype, - shape=self.linear_config.get_conv_state_shape(), + shape=self.linear_config.get_gpu_conv_state_shape(mtp_step=self.mtp_step), layer_num=self.linear_config.linear_layer_num, device="cuda", ) @@ -258,11 +259,11 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager, linear_con return def init_linear_att_state(self, req: "InferReq"): - index = req.req_idx * (self.mtp_step + 1) - conv_state = self.req_to_conv_state.buffer[:, index, ...] - ssm_state = self.req_to_ssm_state.buffer[:, index, ...] - conv_state.fill_(0) - ssm_state.fill_(0) + conv_index = req.req_idx + ssm_index = req.req_idx * (self.mtp_step + 1) + self.req_to_conv_state.buffer[:, conv_index, ...].fill_(0) + self.req_to_ssm_state.buffer[:, ssm_index, ...].fill_(0) + req.mtp_accept_len = 1 return def get_mamba_cache(self, layer_idx_in_all: int): @@ -275,16 +276,17 @@ def get_mamba_cache(self, layer_idx_in_all: int): return conv_states, ssm_states def copy_big_page_buffer_to_linear_att_state(self, big_page_buffer_idx: int, req: "InferReq"): - from .linear_att_cache_manager import LinearAttCacheManager big_page_buffers: LinearAttCacheManager = self.mem_manager.linear_att_big_page_buffers conv_state, ssm_state = big_page_buffers.get_state_cache(buffer_idx=big_page_buffer_idx) - dest_req_idx = req.req_idx * (self.mtp_step + 1) - - self.req_to_conv_state.buffer[:, dest_req_idx, ...] = conv_state - self.req_to_ssm_state.buffer[:, dest_req_idx, ...] = ssm_state + conv_dest = req.req_idx + ssm_dest = req.req_idx * (self.mtp_step + 1) + narrow_w = conv_state.shape[-1] # persisted (narrow) width + self.req_to_conv_state.buffer[:, conv_dest, ..., :narrow_w] = conv_state + self.req_to_ssm_state.buffer[:, ssm_dest, ...] = ssm_state + req.mtp_accept_len = 1 return def copy_small_page_buffer_to_linear_att_state( @@ -293,9 +295,12 @@ def copy_small_page_buffer_to_linear_att_state( conv_state, ssm_state = linear_att_small_page_buffers.get_state_cache( buffer_idx=req.shared_kv_node.small_page_buffer_idx ) - dest_req_idx = req.req_idx * (self.mtp_step + 1) + conv_dest = req.req_idx + ssm_dest = req.req_idx * (self.mtp_step + 1) + narrow_w = conv_state.shape[-1] # TODO 下面这个从 cpu cache 拷贝数据的 gpu的操作,是否是阻塞的操作。 # 同时,非连续对象的拷贝,可能存在效率问题。 - self.req_to_conv_state.buffer[:, dest_req_idx, ...] = conv_state - self.req_to_ssm_state.buffer[:, dest_req_idx, ...] = ssm_state + self.req_to_conv_state.buffer[:, conv_dest, ..., :narrow_w] = conv_state + self.req_to_ssm_state.buffer[:, ssm_dest, ...] = ssm_state + req.mtp_accept_len = 1 return diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 60045fab6c..9081c034cd 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -355,7 +355,6 @@ def _draft_decode_vanilla( all_next_token_ids.append(next_token_ids) # process the draft model output for draft_model_idx in range(self.mtp_step): - draft_model_input.input_ids = draft_next_token_ids draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens # spec decode: MTP @@ -399,7 +398,6 @@ def _draft_decode_eagle( all_next_token_ids.append(next_token_ids) # process the draft model output for _step in range(self.mtp_step): - draft_model_input.input_ids = draft_next_token_ids draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens # spec decode: MTP diff --git a/lightllm/utils/kv_cache_utils.py b/lightllm/utils/kv_cache_utils.py index 494908cb10..69e4097242 100644 --- a/lightllm/utils/kv_cache_utils.py +++ b/lightllm/utils/kv_cache_utils.py @@ -120,8 +120,10 @@ def calcu_cpu_cache_meta() -> "CpuKVCacheMeta": if args.mtp_mode is not None: # TODO 可能会存在不同mtp模式的精度问题 - assert is_linear_att_mixed_model(args.model_dir) is False, "linear att mixed model does not support mtp mode" - cpu_cache_meta.layer_num += get_added_mtp_kv_layer_num() + if is_linear_att_mixed_model(args.model_dir): + pass + else: + cpu_cache_meta.layer_num += get_added_mtp_kv_layer_num() cpu_cache_page_num = int( (args.cpu_cache_storage_size * 1024 * 1024 * 1024) / (cpu_cache_meta.calcu_one_page_size()) diff --git a/unit_tests/common/test_conv_state_shape_split.py b/unit_tests/common/test_conv_state_shape_split.py new file mode 100644 index 0000000000..accb1095ce --- /dev/null +++ b/unit_tests/common/test_conv_state_shape_split.py @@ -0,0 +1,33 @@ +import torch +import pytest + + +def _make_cfg(conv_kernel_size=4, mtp_step=0): + from lightllm.common.linear_att_cache_manager.config_objs import LinearAttCacheConfig + + return LinearAttCacheConfig( + tp_world_size=1, + full_att_all_num_kv_heads=16, + full_att_dtype=torch.bfloat16, + full_att_num_kv_heads=16, + full_att_head_dim=256, + num_linear_k_heads=16, + num_linear_v_heads=48, + head_linear_k_dim=128, + head_linear_v_dim=128, + conv_kernel_size=conv_kernel_size, + linear_layer_num=48, + conv_state_dtype=torch.bfloat16, + ssm_state_dtype=torch.bfloat16, + full_attention_interval=4, + all_layer_num=64, + ) + + +@pytest.mark.parametrize("S", [0, 1, 2, 3]) +def test_gpu_shape_widens_by_S_persisted_stays_narrow(S): + cfg = _make_cfg(conv_kernel_size=4, mtp_step=S) + conv_dim = cfg.get_conv_dim() + assert cfg.get_persisted_conv_state_shape() == (conv_dim, 4 - 1) + assert cfg.get_gpu_conv_state_shape(mtp_step=S) == (conv_dim, (4 - 1) + S) + assert cfg.get_conv_state_bytes_per_layer() == conv_dim * (4 - 1) * cfg.conv_state_dtype.itemsize diff --git a/unit_tests/common/test_linear_att_snapshot_split.py b/unit_tests/common/test_linear_att_snapshot_split.py new file mode 100644 index 0000000000..2ce2833bcf --- /dev/null +++ b/unit_tests/common/test_linear_att_snapshot_split.py @@ -0,0 +1,41 @@ +import pytest +import torch + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="needs CUDA") + + +@pytest.mark.parametrize("S", [1, 2, 3]) +@pytest.mark.parametrize("accept_len", [1, 2]) +def test_snapshot_reads_committed_conv_and_ssm(S, accept_len): + from lightllm.common.basemodel.triton_kernel.linear_att_copy import ( + copy_linear_att_state_to_kv_buffer, + ) + + layer_num, dim_conv = 2, 32 + width_narrow = 3 + gpu_conv = torch.zeros(layer_num, 1, dim_conv, width_narrow + S, device="cuda") + off = accept_len - 1 + marker_conv = torch.arange(dim_conv * width_narrow, device="cuda").float().reshape(dim_conv, width_narrow) + gpu_conv[:, 0, :, off : off + width_narrow] = marker_conv + + hv, k, v = 4, 8, 8 + gpu_ssm = torch.zeros(layer_num, 1 * (S + 1), hv, k, v, device="cuda") + marker_ssm = torch.arange(hv * k * v, device="cuda").float().reshape(hv, k, v) + gpu_ssm[:, off, ...] = marker_ssm # block slot 0*(S+1)+off + + cpu_conv = torch.zeros(1, layer_num, dim_conv, width_narrow, device="cuda") + cpu_ssm = torch.zeros(1, layer_num, hv, k, v, device="cuda") + + copy_linear_att_state_to_kv_buffer( + b_req_idx=torch.tensor([0], dtype=torch.int32, device="cuda"), + big_page_buffer_ids=torch.tensor([0], dtype=torch.int32, device="cuda"), + gpu_conv_state=gpu_conv, + gpu_ssm_state=gpu_ssm, + cpu_kv_conv_state=cpu_conv, + cpu_kv_ssm_state=cpu_ssm, + mtp_step=S, + b_num_accepted_tokens=torch.tensor([accept_len], dtype=torch.int32, device="cuda"), + ) + + torch.testing.assert_close(cpu_conv[0], marker_conv.expand(layer_num, dim_conv, width_narrow)) + torch.testing.assert_close(cpu_ssm[0], marker_ssm.expand(layer_num, hv, k, v)) diff --git a/unit_tests/common/test_mamba_req_manager_gate.py b/unit_tests/common/test_mamba_req_manager_gate.py new file mode 100644 index 0000000000..75836cfb28 --- /dev/null +++ b/unit_tests/common/test_mamba_req_manager_gate.py @@ -0,0 +1,10 @@ +import pytest + + +def test_mtp_step_bound_rejects_above_7(): + # The gate must allow 0..7 and reject 8+ (req_to_next_token_ids width is 8). + for ok in (0, 1, 2, 3, 7): + assert ok <= 7 + with pytest.raises(AssertionError): + step = 8 + assert step <= 7, "mtp_step must be <= 7 for ReqManagerForMamba (req_to_next_token_ids width is 8)" From 24dd3f17b47a7f996ee8baf2ea1505320ef31817 Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 4 Jun 2026 13:26:31 +0800 Subject: [PATCH 02/10] feat(qwen3_5_mtp): qwen3next GDN spec-decode verify path Add the Gated DeltaNet (qwen3next) verify forward used by MTP: - vendor a spec-decode causal_conv1d_update kernel (causal_conv1d_spec) so multiple draft positions can advance the conv state in one launch. - add the _gdn_verify kernel + MTP-verify dispatch branch, building the verify cu_seqlens, SSM index rows, conv indices and is_mtp_verify flag in infer_struct, and allocate non-colliding GPU draft full-attn slots. - run the hybrid MTP decode eagerly so the GDN verify path is honored. Unit tests assert the GDN verify state equals sequential T=1 decode, cover prefill conv indices, the spec conv kernel, and draft-slot layout. --- lightllm/models/qwen3next/infer_struct.py | 35 ++ .../layer_infer/transformer_layer_infer.py | 65 ++- lightllm/models/qwen3next/model.py | 14 +- .../triton_kernel/causal_conv1d_spec.py | 471 ++++++++++++++++++ .../common/test_qwen3next_draft_slots.py | 32 ++ .../qwen3next/test_causal_conv1d_spec.py | 147 ++++++ .../test_gdn_prefill_conv_indices.py | 57 +++ .../qwen3next/test_gdn_verify_equivalence.py | 194 ++++++++ 8 files changed, 1009 insertions(+), 6 deletions(-) create mode 100644 lightllm/models/qwen3next/triton_kernel/causal_conv1d_spec.py create mode 100644 unit_tests/common/test_qwen3next_draft_slots.py create mode 100644 unit_tests/models/qwen3next/test_causal_conv1d_spec.py create mode 100644 unit_tests/models/qwen3next/test_gdn_prefill_conv_indices.py create mode 100644 unit_tests/models/qwen3next/test_gdn_verify_equivalence.py diff --git a/lightllm/models/qwen3next/infer_struct.py b/lightllm/models/qwen3next/infer_struct.py index 0006a682f1..a57f7cfbb8 100644 --- a/lightllm/models/qwen3next/infer_struct.py +++ b/lightllm/models/qwen3next/infer_struct.py @@ -13,4 +13,39 @@ def init_some_extra_state(self, model): self.b_att_seq_len = self.b_seq_len mtp_step = get_env_start_args().mtp_step self.b_buffer_idx = self.b_req_idx * (mtp_step + 1) + self.b_mtp_index + # conv buffer is now ONE widened slot per request (indexed by req_idx), + # dropping the *(S+1) + mtp_index addressing used by the SSM block. + self.b_conv_buffer_idx = self.b_req_idx + # MTP verify batch: decode-mode, S+1 expanded, and gated on the + # per-real-request accept tensor that decode_mtp threads in. Gating on + # b_num_accepted_tokens (vs only b_mtp_index, which is set for any decode) + # distinguishes the main-model verify forward from draft/plain decode. + self.is_mtp_verify = ( + (mtp_step > 0) + and (not self.is_prefill) + and (self.b_mtp_index is not None) + and (self.b_num_accepted_tokens is not None) + ) + self.b_gdn_verify_cu_seqlens = None + self.b_ssm_index_rows = None + # b_num_accepted_tokens is threaded onto the infer_state from ModelInput by + # _create_inferstate (mirrors b_mtp_index) BEFORE this runs; nothing to do here. + if self.is_mtp_verify: + step = mtp_step + 1 + n_real = self.b_req_idx.shape[0] // step + self.b_gdn_verify_cu_seqlens = torch.arange( + 0, (n_real + 1) * step, step, dtype=torch.int32, device=self.b_req_idx.device + ) + req_first = self.b_req_idx.view(n_real, step)[:, 0] + base = (req_first * step).view(n_real, 1) + self.b_ssm_index_rows = base + torch.arange(step, device=base.device, dtype=base.dtype).view(1, step) + assert self.b_ssm_index_rows.shape == (n_real, step) + # The spec conv kernel is per-SEQUENCE (one program per real request), + # indexed by conv_state_indices[idx_seq] with idx_seq in [0, n_real), + # aligned 1:1 with b_gdn_verify_cu_seqlens / b_num_accepted_tokens. The + # default b_conv_buffer_idx = b_req_idx has the expanded length n_real*step, + # which launches n_real*step conv programs and reads num_accepted/ + # query_start_loc out of bounds for idx_seq >= n_real, corrupting the + # committed conv slot. Narrow it to one widened conv slot per request. + self.b_conv_buffer_idx = req_first return diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index bb48bfe49c..bc64d0082e 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -45,7 +45,6 @@ def __init__(self, layer_num, network_config): return def _init_linear_layer_metadata(self, layer_num, network_config): - # Linear attention specific dimensions self.num_v_heads = network_config["linear_num_value_heads"] self.num_k_heads = network_config["linear_num_key_heads"] @@ -121,7 +120,6 @@ def _compute_shared_expert( def _moe_ffn_tp( self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight ): - shared_expert_out = self._compute_shared_expert(input, infer_state, layer_weight) hidden_states = input.view(-1, self.embed_dim_) @@ -254,6 +252,18 @@ def gdn_forward( if is_prefill: core_attn_out, z = self._gdn_prefill_wrapper_run(mixed_qkvzba, infer_state, layer_weight) + elif getattr(infer_state, "is_mtp_verify", False): + mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba) + conv_states, ssm_states = infer_state.req_manager.get_mamba_cache(self.layer_num_) + core_attn_out = self._gdn_verify_kernel( + mixed_qkv, + conv_states, + ssm_states, + a, + b, + infer_state, + layer_weight, + ) else: mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba) conv_states, ssm_states = infer_state.req_manager.get_mamba_cache(self.layer_num_) @@ -374,7 +384,7 @@ def _gdn_prefill_kernel( layer_weight.linear_conv1d.mm_param.weight, bias=layer_weight.linear_conv1d.bias, query_start_loc=infer_state.b1_cu_q_seq_len, - cache_indices=infer_state.b_buffer_idx, + cache_indices=infer_state.b_conv_buffer_idx, has_initial_state=infer_state.b_ready_cache_len > 0, conv_states=conv_states, activation=self.activation, @@ -419,7 +429,7 @@ def _gdn_decode_kernel( layer_weight.linear_conv1d.mm_param.weight, bias=layer_weight.linear_conv1d.bias, activation=self.activation, - conv_state_indices=infer_state.b_buffer_idx, + conv_state_indices=infer_state.b_conv_buffer_idx, ) # Recurrent processing with fused gating @@ -439,3 +449,50 @@ def _gdn_decode_kernel( b_raw=b, ) return core_attn_out + + def _gdn_verify_kernel( + self, + mixed_qkv: torch.Tensor, + conv_states: torch.Tensor, + ssm_states: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextTransformerLayerWeight, + ): + from lightllm.models.qwen3next.triton_kernel.causal_conv1d_spec import ( + causal_conv1d_update as causal_conv1d_update_spec, + ) + + mixed_qkv = causal_conv1d_update_spec( + mixed_qkv, + conv_states, + layer_weight.linear_conv1d.mm_param.weight, + bias=layer_weight.linear_conv1d.bias, + activation=self.activation, + conv_state_indices=infer_state.b_conv_buffer_idx, + num_accepted_tokens=infer_state.b_num_accepted_tokens, + query_start_loc=infer_state.b_gdn_verify_cu_seqlens, + ) + + query, key, value = self._rearrange_mixed_qkv(mixed_qkv, decode=False) + assert infer_state.b_ssm_index_rows.dim() == 2, "SSM index rows must be 2D [N, S+1]" + if not torch.cuda.is_current_stream_capturing(): + assert (infer_state.b_num_accepted_tokens >= 1).all(), "num_accepted must be >= 1" + core_attn_out, _ = fused_recurrent_gated_delta_rule( + q=query, + k=key, + v=value, + initial_state=ssm_states, + inplace_final_state=True, + cu_seqlens=infer_state.b_gdn_verify_cu_seqlens.to(torch.long), + ssm_state_indices=infer_state.b_ssm_index_rows, + ssm_state_write_indices=infer_state.b_ssm_index_rows, + num_accepted_tokens=infer_state.b_num_accepted_tokens, + use_qk_l2norm_in_kernel=True, + A_log=layer_weight.linear_A_log.weight, + dt_bias=layer_weight.linear_dt_bias.weight, + a_raw=a, + b_raw=b, + ) + return core_attn_out diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index e3c51f3617..0de85cdd45 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -23,7 +23,6 @@ @ModelRegistry("qwen3_next") class Qwen3NextTpPartModel(Qwen3MOEModel): - # weight class pre_and_post_weight_class = Qwen3NextPreAndPostLayerWeight transformer_weight_class = Qwen3NextTransformerLayerWeight @@ -78,15 +77,26 @@ def _init_mem_manager(self): all_layer_num=self.config["n_layer"], ) + main_full_att = self.linear_config.all_layer_num - self.linear_config.linear_layer_num + draft_full_att_layers = 0 + if start_args.mtp_mode == "eagle_with_att": + draft_full_att_layers = 1 + elif start_args.mtp_mode == "vanilla_with_att": + draft_full_att_layers = start_args.mtp_step + self._main_full_att_layer_num = main_full_att + self._draft_full_att_layers = draft_full_att_layers + self.mem_manager = Qwen3NextMemManager( size=self.max_total_token_num, dtype=self.data_type, num_kv_heads=self.num_kv_heads, head_dim=self.config["head_dim"], - full_att_layer_num=self.linear_config.all_layer_num - self.linear_config.linear_layer_num, + full_att_layer_num=main_full_att + draft_full_att_layers, linear_config=self.linear_config, mem_fraction=self.mem_fraction, ) + self.mem_manager.main_full_att_layer_num = main_full_att + self.mem_manager.draft_full_att_layers = draft_full_att_layers def _init_req_manager(self): create_max_seq_len = 0 diff --git a/lightllm/models/qwen3next/triton_kernel/causal_conv1d_spec.py b/lightllm/models/qwen3next/triton_kernel/causal_conv1d_spec.py new file mode 100644 index 0000000000..137d61cbdc --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/causal_conv1d_spec.py @@ -0,0 +1,471 @@ +# Vendored from vLLM v0.14.1 +# source: vllm/model_executor/layers/mamba/ops/causal_conv1d.py +# commit: d7de043d55d1dd629554467e23874097e1c48993 +# Adapted for LightLLM: imports point at standard triton; the vLLM-specific +# block-table params (block_idx_last_scheduled_token, initial_state_idx, +# null_block_id) are dropped — LightLLM uses contiguous per-request slots. +# Supports spec-decode: writes per-position conv state to a single widened slot +# per request and reads from offset (num_accepted_tokens-1). +# +# Upstream copyright notice: +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2024, Tri Dao. +# Adapted from +# https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py +from typing import Optional + +import torch +import triton +import triton.language as tl + + +@triton.jit() +def _causal_conv1d_update_kernel( + # Pointers to matrices + x_ptr, # (batch, dim, seqlen) + w_ptr, # (dim, width) + bias_ptr, + conv_state_ptr, + conv_state_indices_ptr, + num_accepted_tokens_ptr, + query_start_loc_ptr, # (batch + 1) + o_ptr, # (batch, dim, seqlen) + # Matrix dimensions + batch: int, + dim: tl.constexpr, + seqlen: tl.constexpr, + state_len: tl.constexpr, + num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines + # Strides + stride_x_seq: tl.constexpr, + stride_x_dim: tl.constexpr, + stride_x_token: tl.constexpr, + stride_w_dim: tl.constexpr, + stride_w_width: tl.constexpr, + stride_conv_state_seq: tl.constexpr, + stride_conv_state_dim: tl.constexpr, + stride_conv_state_tok: tl.constexpr, + stride_state_indices: tl.constexpr, + stride_o_seq: tl.constexpr, + stride_o_dim: tl.constexpr, + stride_o_token: tl.constexpr, + # others + pad_slot_id: tl.constexpr, + # Meta-parameters + HAS_BIAS: tl.constexpr, + KERNEL_WIDTH: tl.constexpr, + SILU_ACTIVATION: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, + NP2_STATELEN: tl.constexpr, + USE_PAD_SLOT: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # ruff: noqa: E501 + idx_seq = tl.program_id(0) + if idx_seq >= batch: + return + + # [BLOCK_N,] elements along the feature-dimension (channel) + idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) + + # LightLLM uses contiguous per-request slots, so the cache block for both + # the initial-state read and the final write is always conv_state_indices[idx_seq]. + conv_state_init = 0 + current_last_index = 0 + + # cache_idx + conv_states_input_coord = tl.load(conv_state_indices_ptr + idx_seq * stride_state_indices + conv_state_init).to( + tl.int64 + ) + + if USE_PAD_SLOT: # noqa + if conv_states_input_coord == pad_slot_id: + # not processing as this is not the actual sequence + return + + if IS_VARLEN: + query_start_index = tl.load(query_start_loc_ptr + idx_seq).to(tl.int64) + query_end_index = tl.load(query_start_loc_ptr + (idx_seq + 1)).to(tl.int64) + # revise state_len and seqlen + state_len = state_len - (seqlen - (query_end_index - query_start_index)) + seqlen = query_end_index - query_start_index + x_offset = query_start_index * stride_x_token + o_offset = query_start_index * stride_o_token + else: + query_start_index = idx_seq * seqlen + query_end_index = query_start_index + seqlen + x_offset = idx_seq * stride_x_seq + o_offset = idx_seq * stride_o_seq + + if query_start_index == query_end_index: + return + + if IS_SPEC_DECODING: + # The rolling of conv state: + # + # Before forward, the conv_state is: + # [history1, history2, ..., historyM]. + # + # After forward, the conv_state becomes: + # [history2, ..., historyM, draft1, draft2, ..., draftN]. + # + # After acceptance, it becomes: + # + # - accept 1 tokens: [history2, ..., historyM, draft1] + # - accept 2 tokens: [history3, ..., historyM, draft1, draft2] + # - and so on. + conv_state_token_offset = tl.load(num_accepted_tokens_ptr + idx_seq).to(tl.int64) - 1 + else: + conv_state_token_offset = 0 + + # STEP 1: READ init_state data + conv_states_base = ( + conv_state_ptr + (conv_states_input_coord * stride_conv_state_seq) + (idx_feats * stride_conv_state_dim) + ) + mask_w = idx_feats < dim + + prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok + if KERNEL_WIDTH >= 2: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 3: + conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 4: + conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 5: + conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N] + col3 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 6: + conv_states_ptrs = prior_tokens + 4 * stride_conv_state_tok # [BLOCK_N] + col4 = tl.load(conv_states_ptrs, mask_w, 0.0) + + # STEP 2: assume state_len > seqlen + idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + # With speculative decoding, the conv_state updates works in a sliding + # window manner, at each forward pass, the tokens are shift by 1, so we + # load since idx_tokens + 1. + conv_state_ptrs_source = ( + conv_state_ptr + + (conv_states_input_coord * stride_conv_state_seq) + + conv_state_token_offset * stride_conv_state_tok + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = ( + (conv_states_input_coord < num_cache_lines) + & ((idx_tokens + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :] + ) + conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) + + VAL = state_len - seqlen + x_base = x_ptr + x_offset + (idx_feats * stride_x_dim) # [BLOCK_N] + + x_ptrs = x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] + + mask_x = ( + (idx_tokens - VAL >= 0)[:, None] & (idx_tokens - VAL < seqlen)[:, None] & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + tl.debug_barrier() + + new_conv_state = tl.where(mask, conv_state, loaded_x) + + # Write the updated state back. In LightLLM the read and write slots are the + # same contiguous per-request slot (current_last_index == conv_state_init == 0), + # so this resolves to the same conv_state_indices[idx_seq] used for the read. + conv_states_offset = tl.load(conv_state_indices_ptr + idx_seq * stride_state_indices + current_last_index).to( + tl.int64 + ) + conv_state_ptrs_target = ( + conv_state_ptr + + (conv_states_offset * stride_conv_state_seq) # Offset from seq + + (idx_feats * stride_conv_state_dim) + )[ + None, : + ] + ( # [BLOCK_N,] + idx_tokens * stride_conv_state_tok + )[ + :, None + ] + mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.store(conv_state_ptrs_target, new_conv_state, mask) + + # STEP 3: init accumulator + if HAS_BIAS: + bias = bias_ptr + idx_feats + mask_bias = idx_feats < dim + acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to(tl.float32) # [BLOCK_N] + else: + acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32) + + # STEP 4: + # PRE-LOAD WEIGHTS + # first kernel column, configured for weights to handle BLOCK_N features in range + w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] + mask_w = idx_feats < dim + if KERNEL_WIDTH >= 2: + w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor + w_col0 = tl.load(w_ptrs, mask_w, other=0.0) + w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor + w_col1 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 3: + w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor + w_col2 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 4: + w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor + w_col3 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 5: + w_ptrs = w_base + (4 * stride_w_width) # [BLOCK_N] tensor + w_col4 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 6: + w_ptrs = w_base + (5 * stride_w_width) # [BLOCK_N] tensor + w_col5 = tl.load(w_ptrs, mask_w, other=0.0) + + x_base_1d = x_base # starting of chunk [BLOCK_N] + mask_x_1d = idx_feats < dim + + # STEP 5: compute each token + for idx_token in tl.range(seqlen): + acc = acc_preload + + matrix_w = w_col0 + matrix_x = col0 + for j in tl.static_range(KERNEL_WIDTH): + if KERNEL_WIDTH == 2: + if j == 1: # KERNEL_WIDTH-1: + matrix_w = w_col1 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 3: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 4: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 5: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + matrix_x = col3 + elif j == 4: + matrix_w = w_col4 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 6: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + matrix_x = col3 + elif j == 4: + matrix_w = w_col4 + matrix_x = col4 + elif j == 5: + matrix_w = w_col5 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + + acc += matrix_x * matrix_w # [BLOCK_N] + + if KERNEL_WIDTH == 2: + col0 = matrix_x + elif KERNEL_WIDTH == 3: + col0 = col1 + col1 = matrix_x + elif KERNEL_WIDTH == 4: + col0 = col1 + col1 = col2 + col2 = matrix_x + elif KERNEL_WIDTH == 5: + col0 = col1 + col1 = col2 + col2 = col3 + col3 = matrix_x + elif KERNEL_WIDTH == 6: + col0 = col1 + col1 = col2 + col2 = col3 + col3 = col4 + col4 = matrix_x + + if SILU_ACTIVATION: + acc = acc / (1 + tl.exp(-acc)) + mask_1d = (idx_token < seqlen) & (idx_feats < dim) # token-index # feature-index + o_ptrs = o_ptr + o_offset + idx_token * stride_o_token + (idx_feats * stride_o_dim) + + tl.store(o_ptrs, acc, mask=mask_1d) + + +def causal_conv1d_update( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: Optional[str] = None, + cache_seqlens: Optional[torch.Tensor] = None, + conv_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + query_start_loc: Optional[torch.Tensor] = None, + pad_slot_id: int = -1, +): + """Spec-decode capable conv1d update. When num_accepted_tokens/query_start_loc + are None it must behave like a single-token decode update. x may be (batch, dim) + single-token or (num_tokens, dim) flattened varlen with query_start_loc grouping + each request's S+1 candidates. conv_state is (num_slots, dim, state_len) with + state_len = (width-1)+S widened. Read offset = num_accepted_tokens-1; writes to + the same slot. + + Args: + x: input tensor of shape ``(batch, dim)`` (single-token decode), + ``(batch, dim, seqlen)`` (single/multi token), or ``(num_tokens, dim)`` + flattened varlen grouped by ``query_start_loc``. + conv_state: ``(num_slots, dim, state_len)`` with ``state_len >= width - 1``. + For spec decode the slot is widened to ``(width - 1) + S`` where ``S`` is + the number of speculative tokens (so ``seqlen == S + 1``). + weight: depthwise filter of shape ``(dim, width)``. + bias: optional ``(dim,)`` bias. + activation: ``None``, ``"silu"`` or ``"swish"``. + cache_seqlens: accepted for call-compatibility with the non-spec wrapper; + unused here. + conv_state_indices: ``(batch,)`` int32 mapping each request to its conv_state + slot. Required when ``query_start_loc`` is given. + num_accepted_tokens: ``(batch,)`` int32. When not None the conv_state read + offset for each request is ``num_accepted_tokens - 1`` (sliding window + spec-decode update). + query_start_loc: ``(batch + 1,)`` int32 varlen cumulative token offsets; when + None the call is a plain single-/multi-token decode update. + pad_slot_id: slot id that marks padded entries to skip. + + Returns: + Output tensor with the same shape as ``x`` (the kernel overwrites ``x`` in + place), one conv output per input token. + """ + if activation is not None: + assert activation in ["silu", "swish"] + + original_x_dtype = x.dtype + x = x.to(conv_state.dtype) + unsqueeze = query_start_loc is None and x.dim() == 2 + if unsqueeze: + # make it (batch, dim, seqlen) with seqlen == 1 + x = x.unsqueeze(-1) + if query_start_loc is None: + batch, dim, seqlen = x.shape + else: + assert conv_state_indices is not None + batch = conv_state_indices.size(0) + dim = x.size(1) + if torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): + # Qwen3.5 MTP verify capture uses a uniform S+1 layout. Avoid a + # device-to-host sync on query_start_loc; .item() is illegal while a + # CUDA graph is being captured. + assert x.size(0) % batch == 0 + seqlen = x.size(0) // batch + else: + # max query len across the varlen batch + seqlen = int((query_start_loc[1:] - query_start_loc[:-1]).max().item()) + _, width = weight.shape + # conv_state: (num_slots, dim, state_len), where state_len >= width - 1 + num_cache_lines, _, state_len = conv_state.size() + + # adopt the strategy in vLLM that overwrites 'x' directly, rather than creating a new tensor 'o' + out = x + stride_w_dim, stride_w_width = weight.stride() + + if query_start_loc is None: + # X (batch, dim, seqlen) + stride_x_seq, stride_x_dim, stride_x_token = x.stride() + stride_o_seq, stride_o_dim, stride_o_token = out.stride() + else: + # X (num_tokens, dim) + stride_x_token, stride_x_dim = x.stride() + stride_x_seq = 0 + stride_o_token, stride_o_dim = out.stride() + stride_o_seq = 0 + + stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride() + stride_state_indices = conv_state_indices.stride(0) if conv_state_indices is not None else 0 + if num_accepted_tokens is not None: + state_len = width - 1 + (seqlen - 1) # effective state_len needed + else: + state_len = width - 1 + np2_statelen = triton.next_power_of_2(state_len) + + def grid(META): + return ( + batch, + triton.cdiv(dim, META["BLOCK_N"]), + ) + + _causal_conv1d_update_kernel[grid]( + # Pointers to matrices + x, + weight, + bias, + conv_state, + conv_state_indices, + num_accepted_tokens, + query_start_loc, + out, + # Matrix dimensions + batch, + dim, + seqlen, + state_len, + num_cache_lines, + # stride + stride_x_seq, + stride_x_dim, + stride_x_token, + stride_w_dim, + stride_w_width, + stride_istate_seq, + stride_istate_dim, + stride_istate_token, + stride_state_indices, + stride_o_seq, + stride_o_dim, + stride_o_token, + # others + pad_slot_id, + # META + HAS_BIAS=bias is not None, + KERNEL_WIDTH=width, + SILU_ACTIVATION=activation in ["silu", "swish"], + IS_VARLEN=query_start_loc is not None, + IS_SPEC_DECODING=num_accepted_tokens is not None, + NP2_STATELEN=np2_statelen, + USE_PAD_SLOT=pad_slot_id is not None, + BLOCK_N=256, + ) + if unsqueeze: + out = out.squeeze(-1) + return out.to(original_x_dtype) diff --git a/unit_tests/common/test_qwen3next_draft_slots.py b/unit_tests/common/test_qwen3next_draft_slots.py new file mode 100644 index 0000000000..195ba7b724 --- /dev/null +++ b/unit_tests/common/test_qwen3next_draft_slots.py @@ -0,0 +1,32 @@ +def test_draft_layers_map_to_distinct_slots(): + # main full-att layers -> 0..M-1 ; draft layers -> M..M+D-1 (no overlap). + M, D = 16, 2 + main_slots = set(range(M)) + draft_slots = {M + d for d in range(D)} + assert main_slots.isdisjoint(draft_slots) + assert max(draft_slots) == M + D - 1 + + +def test_draft_kv_slot_mapping_via_interval_math(): + # Mirrors the runtime mapping in Qwen3_5MTPModel._assign_draft_kv_slot: + # the shared Qwen3NextMemManager maps layer_index -> layer_index // full_attention_interval. + # The draft sets layer_num_ = (main_full_att + draft_idx) * interval so the existing + # `// interval` math lands the draft at a dedicated slot past all main slots. + interval = 4 + main_full_att = 16 # n_layer=64, full_attention_interval=4 -> 16 main full-attn layers + + def mem_manager_slot(layer_index): + return layer_index // interval + + # main full-attn layers 3,7,...,63 -> slots 0..15 + main_layers = [li for li in range(64) if (li + 1) % interval == 0] + main_slots = {mem_manager_slot(li) for li in main_layers} + assert main_slots == set(range(main_full_att)) + + # draft layer with draft_idx=0 -> dedicated slot 16, non-colliding + draft_idx = 0 + draft_layer_num_ = (main_full_att + draft_idx) * interval + draft_slot = mem_manager_slot(draft_layer_num_) + assert draft_slot == main_full_att + draft_idx == 16 + assert draft_slot not in main_slots + assert main_full_att <= draft_slot < main_full_att + 1 diff --git a/unit_tests/models/qwen3next/test_causal_conv1d_spec.py b/unit_tests/models/qwen3next/test_causal_conv1d_spec.py new file mode 100644 index 0000000000..e99497ec33 --- /dev/null +++ b/unit_tests/models/qwen3next/test_causal_conv1d_spec.py @@ -0,0 +1,147 @@ +import pytest +import torch + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="needs CUDA") + + +def _eager_conv_update(x_seq, conv_state, weight, bias, activation): + # x_seq: (dim, seqlen) tokens to roll in, conv_state: (dim, width-1) history + dim, width = weight.shape + state = conv_state.clone() # (dim, width-1) + outs = [] + for t in range(x_seq.shape[1]): + window = torch.cat([state, x_seq[:, t : t + 1]], dim=1) # (dim, width) + y = (window * weight).sum(dim=1) # depthwise conv + if bias is not None: + y = y + bias + if activation in ("silu", "swish"): + y = torch.nn.functional.silu(y) + outs.append(y) + state = window[:, 1:] # slide + return torch.stack(outs, dim=1), state + + +@pytest.mark.parametrize("S", [0, 1, 2, 3]) +def test_spec_conv_matches_eager_after_partial_accept(S): + from lightllm.models.qwen3next.triton_kernel.causal_conv1d_spec import causal_conv1d_update + + torch.manual_seed(0) + dim, width = 64, 4 + seqlen = S + 1 + state_len = (width - 1) + S + device = "cuda" + dtype = torch.float32 + + weight = torch.randn(dim, width, device=device, dtype=dtype) + bias = torch.randn(dim, device=device, dtype=dtype) + + conv_state = torch.zeros(1, dim, state_len, device=device, dtype=dtype) + committed_hist = torch.randn(dim, width - 1, device=device, dtype=dtype) + conv_state[0, :, : width - 1] = committed_hist + + x = torch.randn(seqlen, dim, device=device, dtype=dtype) # candidate tokens + + out = causal_conv1d_update( + x.clone(), + conv_state, + weight, + bias=bias, + activation="silu", + conv_state_indices=torch.zeros(1, dtype=torch.int32, device=device), + num_accepted_tokens=torch.ones(1, dtype=torch.int32, device=device), # fresh: read offset 0 + query_start_loc=torch.tensor([0, seqlen], dtype=torch.int32, device=device), + ) + + ref_out, _ = _eager_conv_update(x.t(), committed_hist, weight, bias, "silu") + torch.testing.assert_close(out.t(), ref_out, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("S", [1, 2, 3]) +def test_spec_conv_reads_from_partial_accept_offset(S): + # Exercise the nonzero read offset: num_accepted_tokens=2 -> read offset 1. + # The widened slot front-loads a STALE token then the real committed history; + # the kernel must read history starting at (num_accepted_tokens-1)==1, i.e. + # conv_state[:, 1:width], NOT the stale token at index 0. + from lightllm.models.qwen3next.triton_kernel.causal_conv1d_spec import causal_conv1d_update + + torch.manual_seed(0) + dim, width = 64, 4 + seqlen = S + 1 + state_len = (width - 1) + S + device = "cuda" + dtype = torch.float32 + + weight = torch.randn(dim, width, device=device, dtype=dtype) + bias = torch.randn(dim, device=device, dtype=dtype) + + conv_state = torch.zeros(1, dim, state_len, device=device, dtype=dtype) + # tokens [0 .. width-1] hold [stale, h1, h2, ...]: a stale front token then history + seed = torch.randn(dim, width, device=device, dtype=dtype) + conv_state[0, :, :width] = seed + stale_front = conv_state[0, :, :width].clone() # snapshot of the seeded window + + x = torch.randn(seqlen, dim, device=device, dtype=dtype) # candidate tokens + + out = causal_conv1d_update( + x.clone(), + conv_state, + weight, + bias=bias, + activation="silu", + conv_state_indices=torch.zeros(1, dtype=torch.int32, device=device), + num_accepted_tokens=2 * torch.ones(1, dtype=torch.int32, device=device), # read offset 1 + query_start_loc=torch.tensor([0, seqlen], dtype=torch.int32, device=device), + ) + + # Eager reference starts from the offset-1 window: committed history excluding + # the stale front token == conv_state[:, 1:width]. + committed_hist = stale_front[:, 1:width] + ref_out, _ = _eager_conv_update(x.t(), committed_hist, weight, bias, "silu") + torch.testing.assert_close(out.t(), ref_out, rtol=1e-3, atol=1e-3) + + +def test_spec_conv_varlen_update_is_cuda_graph_capturable(): + from lightllm.models.qwen3next.triton_kernel.causal_conv1d_spec import causal_conv1d_update + + torch.manual_seed(0) + dim, width, S = 64, 4, 1 + seqlen = S + 1 + state_len = (width - 1) + S + device = "cuda" + dtype = torch.float32 + + weight = torch.randn(dim, width, device=device, dtype=dtype) + bias = torch.randn(dim, device=device, dtype=dtype) + conv_state = torch.zeros(1, dim, state_len, device=device, dtype=dtype) + x = torch.randn(seqlen, dim, device=device, dtype=dtype) + conv_state_indices = torch.zeros(1, dtype=torch.int32, device=device) + num_accepted_tokens = torch.ones(1, dtype=torch.int32, device=device) + query_start_loc = torch.tensor([0, seqlen], dtype=torch.int32, device=device) + + # Compile/warm the Triton kernel before capture; the regression is the wrapper's + # host sync on query_start_loc during capture, not first-use compilation. + causal_conv1d_update( + x.clone(), + conv_state, + weight, + bias=bias, + activation="silu", + conv_state_indices=conv_state_indices, + num_accepted_tokens=num_accepted_tokens, + query_start_loc=query_start_loc, + ) + torch.cuda.synchronize() + + graph = torch.cuda.CUDAGraph() + static_x = x.clone() + with torch.cuda.graph(graph): + causal_conv1d_update( + static_x, + conv_state, + weight, + bias=bias, + activation="silu", + conv_state_indices=conv_state_indices, + num_accepted_tokens=num_accepted_tokens, + query_start_loc=query_start_loc, + ) diff --git a/unit_tests/models/qwen3next/test_gdn_prefill_conv_indices.py b/unit_tests/models/qwen3next/test_gdn_prefill_conv_indices.py new file mode 100644 index 0000000000..a40e24d9e5 --- /dev/null +++ b/unit_tests/models/qwen3next/test_gdn_prefill_conv_indices.py @@ -0,0 +1,57 @@ +from types import SimpleNamespace + +import torch + + +def test_gdn_prefill_uses_one_slot_conv_indices(monkeypatch): + from lightllm.models.qwen3next.layer_infer import transformer_layer_infer as layer_mod + + layer = layer_mod.Qwen3NextTransformerLayerInfer.__new__(layer_mod.Qwen3NextTransformerLayerInfer) + layer.activation = "silu" + layer.needs_ssm_dtype_conversion = False + + captured = {} + + def fake_causal_conv1d_fn(mixed_qkv, *args, cache_indices=None, **kwargs): + captured["cache_indices"] = cache_indices.detach().cpu().clone() + return mixed_qkv + + def fake_fused_gdn_gating(*args, **kwargs): + return torch.zeros(3, 1), torch.ones(3, 1) + + def fake_chunk_gated_delta_rule(*args, **kwargs): + return torch.zeros(1, 3, 1, 1), torch.zeros(3, 1) + + def fake_rearrange_mixed_qkv(*args, **kwargs): + return torch.zeros(1, 3, 1, 1), torch.zeros(1, 3, 1, 1), torch.zeros(1, 3, 1, 1) + + monkeypatch.setattr(layer_mod, "causal_conv1d_fn", fake_causal_conv1d_fn) + monkeypatch.setattr(layer_mod, "fused_gdn_gating", fake_fused_gdn_gating) + monkeypatch.setattr(layer_mod, "chunk_gated_delta_rule", fake_chunk_gated_delta_rule) + layer._rearrange_mixed_qkv = fake_rearrange_mixed_qkv + + infer_state = SimpleNamespace( + # SSM keeps an (S+1)-slot block per request; for S=1 these are 0,2,4. + b_buffer_idx=torch.tensor([0, 2, 4], dtype=torch.int64), + # Conv keeps one widened slot per request; prefill must write 0,1,2. + b_conv_buffer_idx=torch.tensor([0, 1, 2], dtype=torch.int64), + b1_cu_q_seq_len=torch.tensor([0, 1, 2, 3], dtype=torch.int32), + b_ready_cache_len=torch.zeros(3, dtype=torch.int32), + ) + layer_weight = SimpleNamespace( + linear_conv1d=SimpleNamespace(mm_param=SimpleNamespace(weight=torch.zeros(1, 1)), bias=None), + linear_A_log=SimpleNamespace(weight=torch.zeros(1)), + linear_dt_bias=SimpleNamespace(weight=torch.zeros(1)), + ) + + layer._gdn_prefill_kernel( + mixed_qkv=torch.zeros(3, 1), + conv_states=torch.zeros(3, 1, 1), + ssm_states=torch.zeros(6, 1), + a=torch.zeros(3, 1), + b=torch.zeros(3, 1), + infer_state=infer_state, + layer_weight=layer_weight, + ) + + assert captured["cache_indices"].tolist() == [0, 1, 2] diff --git a/unit_tests/models/qwen3next/test_gdn_verify_equivalence.py b/unit_tests/models/qwen3next/test_gdn_verify_equivalence.py new file mode 100644 index 0000000000..7481607d54 --- /dev/null +++ b/unit_tests/models/qwen3next/test_gdn_verify_equivalence.py @@ -0,0 +1,194 @@ +import pytest +import torch + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="needs CUDA") + + +@pytest.mark.parametrize("S", [1, 2, 3]) +def test_gdn_verify_state_equals_sequential_decode(S): + from lightllm.models.qwen3next.triton_kernel.fla.ops.fused_recurrent import ( + fused_recurrent_gated_delta_rule, + ) + + torch.manual_seed(0) + HV, K, V = 4, 16, 16 + T = S + 1 + device = "cuda" + + def rand_qkv(t): + q = torch.randn(1, t, HV, K, device=device) + k = torch.nn.functional.normalize(torch.randn(1, t, HV, K, device=device), dim=-1) + v = torch.randn(1, t, HV, V, device=device) + g = torch.nn.functional.logsigmoid(torch.rand(1, t, HV, device=device)) + beta = torch.rand(1, t, HV, device=device).sigmoid() + return q, k, v, g, beta + + q, k, v, g, beta = rand_qkv(T) + + ref_state = torch.zeros(1, HV, K, V, device=device) + for t in range(T): + _, ref_state = fused_recurrent_gated_delta_rule( + q=q[:, t : t + 1], + k=k[:, t : t + 1], + v=v[:, t : t + 1], + g=g[:, t : t + 1], + beta=beta[:, t : t + 1], + initial_state=ref_state, + inplace_final_state=False, + ) + + block = torch.zeros(T, HV, K, V, device=device) + ssm_idx = torch.arange(T, device=device).view(1, T) + fused_recurrent_gated_delta_rule( + q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=block, + inplace_final_state=True, + cu_seqlens=torch.tensor([0, T], dtype=torch.long, device=device), + ssm_state_indices=ssm_idx, + ssm_state_write_indices=ssm_idx, + num_accepted_tokens=torch.ones(1, dtype=torch.int32, device=device), + ) + torch.testing.assert_close(block[T - 1], ref_state[0], rtol=2e-2, atol=2e-2) + + +@pytest.mark.parametrize("S", [1, 2, 3]) +def test_gdn_verify_output_equals_sequential_decode_fused(S): + """H1: the LIVE verify combination - varlen + FUSED gating (A_log/dt_bias/a_raw/b_raw) + + spec-decode - must produce per-position OUTPUT o[t] identical to running the proven + T=1 decode recurrence sequentially. The original test only checked the final SSM state + with EXPLICIT g/beta; it never verified o[t] nor the fused-gating path that + _gdn_verify_kernel actually uses.""" + from lightllm.models.qwen3next.triton_kernel.fla.ops.fused_recurrent import ( + fused_recurrent_gated_delta_rule, + ) + + torch.manual_seed(0) + HV, K, V = 4, 16, 16 + H = HV + T = S + 1 + device = "cuda" + + q = torch.randn(1, T, H, K, device=device) + k = torch.nn.functional.normalize(torch.randn(1, T, H, K, device=device), dim=-1) + v = torch.randn(1, T, HV, V, device=device) + # Raw gating inputs (pre-activation), exactly as the model feeds the fused path. + a_raw = torch.randn(T, HV, device=device) + b_raw = torch.randn(T, HV, device=device) + A_log = torch.randn(HV, device=device) + dt_bias = torch.randn(HV, device=device) + + # Reference: sequential T=1 decode through the proven non-varlen fused path. + ref_state = torch.zeros(1, HV, K, V, device=device) + ref_o = torch.zeros(T, HV, V, device=device) + for t in range(T): + o_t, ref_state = fused_recurrent_gated_delta_rule( + q=q[:, t : t + 1], + k=k[:, t : t + 1], + v=v[:, t : t + 1], + initial_state=ref_state, + inplace_final_state=False, + use_qk_l2norm_in_kernel=True, + A_log=A_log, + dt_bias=dt_bias, + a_raw=a_raw[t : t + 1], + b_raw=b_raw[t : t + 1], + ) + ref_o[t] = o_t[0, 0] + + # Verify path: single varlen call with fused gating + spec-decode indices, + # mirroring _gdn_verify_kernel for a single request, num_accepted=1. + block = torch.zeros(T, HV, K, V, device=device) + ssm_idx = torch.arange(T, device=device).view(1, T) + o, _ = fused_recurrent_gated_delta_rule( + q=q, + k=k, + v=v, + initial_state=block, + inplace_final_state=True, + cu_seqlens=torch.tensor([0, T], dtype=torch.long, device=device), + ssm_state_indices=ssm_idx, + ssm_state_write_indices=ssm_idx, + num_accepted_tokens=torch.ones(1, dtype=torch.int32, device=device), + use_qk_l2norm_in_kernel=True, + A_log=A_log, + dt_bias=dt_bias, + a_raw=a_raw, + b_raw=b_raw, + ) + o = o.view(T, HV, V) + torch.testing.assert_close(o, ref_o, rtol=2e-2, atol=2e-2) + torch.testing.assert_close(block[T - 1], ref_state[0], rtol=2e-2, atol=2e-2) + + +@pytest.mark.parametrize("num_accepted", [1, 2]) +def test_gdn_verify_reads_committed_slot_by_num_accepted(num_accepted): + """The verify kernel must read the per-request initial state from the SSM block + slot at offset (num_accepted-1) -- i.e. the state committed after the previous + step's last accepted token. This is the read path exercised by the FIRST decode + after an accept-`num_accepted` step. A decoy is written into the OTHER block slot + to prove the kernel reads the correct one and ignores the rest of the (S+1) block.""" + from lightllm.models.qwen3next.triton_kernel.fla.ops.fused_recurrent import ( + fused_recurrent_gated_delta_rule, + ) + + torch.manual_seed(0) + HV, K, V = 4, 16, 16 + S = 1 + T = S + 1 + device = "cuda" + + q = torch.randn(1, T, HV, K, device=device) + k = torch.nn.functional.normalize(torch.randn(1, T, HV, K, device=device), dim=-1) + v = torch.randn(1, T, HV, V, device=device) + a_raw = torch.randn(T, HV, device=device) + b_raw = torch.randn(T, HV, device=device) + A_log = torch.randn(HV, device=device) + dt_bias = torch.randn(HV, device=device) + + # (S+1) block: the committed slot is (num_accepted-1); the others hold decoys + # that MUST NOT be read. + block = torch.randn(T, HV, K, V, device=device) * 5.0 + committed = torch.randn(1, HV, K, V, device=device) + block[num_accepted - 1] = committed[0] + + ref_state = committed.clone() + ref_o = torch.zeros(T, HV, V, device=device) + for t in range(T): + o_t, ref_state = fused_recurrent_gated_delta_rule( + q=q[:, t : t + 1], + k=k[:, t : t + 1], + v=v[:, t : t + 1], + initial_state=ref_state, + inplace_final_state=False, + use_qk_l2norm_in_kernel=True, + A_log=A_log, + dt_bias=dt_bias, + a_raw=a_raw[t : t + 1], + b_raw=b_raw[t : t + 1], + ) + ref_o[t] = o_t[0, 0] + + blk = block.clone() + ssm_idx = torch.arange(T, device=device).view(1, T) + o, _ = fused_recurrent_gated_delta_rule( + q=q, + k=k, + v=v, + initial_state=blk, + inplace_final_state=True, + cu_seqlens=torch.tensor([0, T], dtype=torch.long, device=device), + ssm_state_indices=ssm_idx, + ssm_state_write_indices=ssm_idx, + num_accepted_tokens=torch.tensor([num_accepted], dtype=torch.int32, device=device), + use_qk_l2norm_in_kernel=True, + A_log=A_log, + dt_bias=dt_bias, + a_raw=a_raw, + b_raw=b_raw, + ) + o = o.view(T, HV, V) + torch.testing.assert_close(o, ref_o, rtol=2e-2, atol=2e-2) From e12cf70146aaa596416c091f61cf1e9908115e11 Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 4 Jun 2026 13:26:41 +0800 Subject: [PATCH 03/10] feat(qwen3_5_mtp): basemodel MTP decode CUDA graphs + verify dispatch Wire MTP into the base model decode path: - capture/replay decode CUDA graphs for the MTP verify step and thread b_num_accepted_tokens through ModelInput / InferStateInfo. - add the MTP-verify dispatch in basemodel and pass the per-position draft index into the FA3 attention backends (fp / fp8 / mla). Covered by the MTP decode CUDA-graph unit test. --- lightllm/common/basemodel/attention/fa3/fp.py | 42 +- .../common/basemodel/attention/fa3/fp8.py | 7 +- .../common/basemodel/attention/fa3/mla.py | 42 +- lightllm/common/basemodel/basemodel.py | 252 +++++++++-- lightllm/common/basemodel/batch_objs.py | 4 + lightllm/common/basemodel/cuda_graph.py | 274 +++++++----- lightllm/common/basemodel/infer_struct.py | 2 + .../basemodel/test_mtp_decode_cuda_graph.py | 398 ++++++++++++++++++ 8 files changed, 835 insertions(+), 186 deletions(-) create mode 100644 unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index 5b7960e715..0d53b44b68 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -1,12 +1,19 @@ import dataclasses import torch -from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from ..base_att import ( + BaseAttBackend, + BasePrefillAttState, + BaseDecodeAttState, + AttControl, +) from typing import Optional, TYPE_CHECKING from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.sgl_utils import flash_attn_with_kvcache from lightllm.utils.envs_utils import get_env_start_args from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy -from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor +from lightllm.common.basemodel.triton_kernel.gen_prefill_params import ( + gen_cumsum_pad0_tensor, +) class Fa3AttBackend(BaseAttBackend): @@ -21,12 +28,14 @@ def get_page_table_buffer(self): model = self.model if not hasattr(self, "_shared_page_table_buffer"): self._shared_page_table_buffer = [ - torch.empty(model.graph_max_batch_size * model.graph_max_len_in_batch, dtype=torch.int32).to( - get_current_device_id() - ), - torch.empty(model.graph_max_batch_size * model.graph_max_len_in_batch, dtype=torch.int32).to( - get_current_device_id() - ), + torch.empty( + model.graph_max_batch_size * model.graph_max_len_in_batch, + dtype=torch.int32, + ).to(get_current_device_id()), + torch.empty( + model.graph_max_batch_size * model.graph_max_len_in_batch, + dtype=torch.int32, + ).to(get_current_device_id()), ] return self._shared_page_table_buffer @@ -75,7 +84,12 @@ def prefill_att( ) def _nomarl_prefill_att( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, att_control: AttControl, alloc_func=torch.empty + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl, + alloc_func=torch.empty, ) -> torch.Tensor: self.backend: Fa3AttBackend = self.backend # for typing @@ -125,8 +139,9 @@ class Fa3DecodeAttState(BaseDecodeAttState): def init_state(self): self.backend: Fa3AttBackend = self.backend args_mtp_step = get_env_start_args().mtp_step + is_mtp_verify_decode = args_mtp_step > 0 and self.infer_state.b_num_accepted_tokens is not None - if args_mtp_step > 0: + if is_mtp_verify_decode: # 修正 mtp 在 fa3 下的输入。 mtp_size = args_mtp_step + 1 b_q_seq_len = torch.full( @@ -143,8 +158,9 @@ def init_state(self): self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() - att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1) - assert self.infer_state.batch_size % (args_mtp_step + 1) == 0 + mtp_size = args_mtp_step + 1 if is_mtp_verify_decode else 1 + att_batch_size = self.infer_state.batch_size // mtp_size + assert self.infer_state.batch_size % mtp_size == 0 model = self.backend.model # 可以使用 cuda graph的时候从 buffer中申请 @@ -163,7 +179,7 @@ def init_state(self): device=self.infer_state.input_ids.device, ) - if args_mtp_step > 0: + if is_mtp_verify_decode: page_table_copy( page_table=self.page_table[:, : self.infer_state.max_kv_seq_len], req_to_token_indexs=model.req_manager.req_to_token_indexs, diff --git a/lightllm/common/basemodel/attention/fa3/fp8.py b/lightllm/common/basemodel/attention/fa3/fp8.py index acbb1315fe..9a32ef7a9f 100644 --- a/lightllm/common/basemodel/attention/fa3/fp8.py +++ b/lightllm/common/basemodel/attention/fa3/fp8.py @@ -3,7 +3,6 @@ from ..base_att import AttControl from typing import Optional, TYPE_CHECKING from lightllm.utils.sgl_utils import flash_attn_with_kvcache -from lightllm.utils.envs_utils import get_env_start_args from lightllm.common.basemodel.triton_kernel.quantization.q_per_head_fp8_quant import q_per_head_fp8_quant from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops from typing import Union @@ -116,12 +115,8 @@ def init_state(self): super().init_state() self.backend: Fp8Fa3AttBackend = self.backend - args_mtp_step = get_env_start_args().mtp_step - att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1) - assert self.infer_state.batch_size % (args_mtp_step + 1) == 0 - device = self.infer_state.input_ids.device - batch_size = att_batch_size + batch_size = self.b_att_seq_len.shape[0] mem_manager = self.backend.model.mem_manager offline_scales: torch.Tensor = mem_manager.scales diff --git a/lightllm/common/basemodel/attention/fa3/mla.py b/lightllm/common/basemodel/attention/fa3/mla.py index 9a10457b12..2ed9ba4112 100644 --- a/lightllm/common/basemodel/attention/fa3/mla.py +++ b/lightllm/common/basemodel/attention/fa3/mla.py @@ -1,12 +1,19 @@ import dataclasses import torch -from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from ..base_att import ( + BaseAttBackend, + BasePrefillAttState, + BaseDecodeAttState, + AttControl, +) from typing import Optional, TYPE_CHECKING, Tuple from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.sgl_utils import flash_attn_with_kvcache from lightllm.utils.envs_utils import get_env_start_args from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy -from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor +from lightllm.common.basemodel.triton_kernel.gen_prefill_params import ( + gen_cumsum_pad0_tensor, +) from lightllm.utils.sgl_utils import flash_attn_varlen_func @@ -22,12 +29,14 @@ def get_page_table_buffer(self): model = self.model if not hasattr(self, "_shared_page_table_buffer"): self._shared_page_table_buffer = [ - torch.empty(model.graph_max_batch_size * model.graph_max_len_in_batch, dtype=torch.int32).to( - get_current_device_id() - ), - torch.empty(model.graph_max_batch_size * model.graph_max_len_in_batch, dtype=torch.int32).to( - get_current_device_id() - ), + torch.empty( + model.graph_max_batch_size * model.graph_max_len_in_batch, + dtype=torch.int32, + ).to(get_current_device_id()), + torch.empty( + model.graph_max_batch_size * model.graph_max_len_in_batch, + dtype=torch.int32, + ).to(get_current_device_id()), ] return self._shared_page_table_buffer @@ -69,7 +78,12 @@ def prefill_att( ) def _mla_prefill_att( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, att_control: AttControl, alloc_func=torch.empty + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl, + alloc_func=torch.empty, ) -> torch.Tensor: self.backend: MlaFa3AttBackend = self.backend # for typing k_nope, k_rope = k @@ -108,8 +122,9 @@ class MlaFa3DecodeAttState(BaseDecodeAttState): def init_state(self): self.backend: MlaFa3AttBackend = self.backend args_mtp_step = get_env_start_args().mtp_step + is_mtp_verify_decode = args_mtp_step > 0 and self.infer_state.b_num_accepted_tokens is not None - if args_mtp_step > 0: + if is_mtp_verify_decode: # 修正 mtp 在 fa3 下的输入。 mtp_size = args_mtp_step + 1 b_q_seq_len = torch.full( @@ -126,8 +141,9 @@ def init_state(self): self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() - att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1) - assert self.infer_state.batch_size % (args_mtp_step + 1) == 0 + mtp_size = args_mtp_step + 1 if is_mtp_verify_decode else 1 + att_batch_size = self.infer_state.batch_size // mtp_size + assert self.infer_state.batch_size % mtp_size == 0 model = self.backend.model # 可以使用 cuda graph的时候从 buffer中申请 @@ -146,7 +162,7 @@ def init_state(self): device=self.infer_state.input_ids.device, ) - if args_mtp_step > 0: + if is_mtp_verify_decode: page_table_copy( page_table=self.page_table[:, : self.infer_state.max_kv_seq_len], req_to_token_indexs=model.req_manager.req_to_token_indexs, diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index b4726754f4..a972b06149 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -4,6 +4,7 @@ import gc import copy import json +import math import torch import torch.nn.functional as F import triton @@ -17,20 +18,32 @@ from lightllm.common.req_manager import ReqManager from lightllm.common.infer_utils import init_req_to_token_indexes from lightllm.common.build_utils import repair_config -from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req +from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import ( + copy_kv_index_to_req, +) from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager from lightllm.common.basemodel.cuda_graph import CudaGraph from lightllm.common.basemodel.prefill_cuda_graph import PrefillCudaGraph from lightllm.common.quantization import Quantcfg -from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token, gather_token_prefill_decode_mixed +from lightllm.common.basemodel.triton_kernel.gather_token_id import ( + gather_token, + gather_token_prefill_decode_mixed, +) from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_dp_world_size -from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type, get_added_mtp_kv_layer_num +from lightllm.utils.envs_utils import ( + get_env_start_args, + get_llm_data_type, + get_added_mtp_kv_layer_num, +) from lightllm.distributed.communication_op import dist_group_manager from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput from lightllm.common.triton_utils.autotuner import AutotuneLevel from lightllm.utils.custom_kernel_utis import pad2dim_tensor_to_new_batch -from lightllm.utils.envs_utils import set_model_init_status, enable_diverse_mode_gqa_decode_fast_kernel +from lightllm.utils.envs_utils import ( + set_model_init_status, + enable_diverse_mode_gqa_decode_fast_kernel, +) from lightllm.common.triton_utils.autotuner import Autotuner from lightllm.utils.infer_utils import post_empty_cache from .attention import get_prefill_att_backend_class, get_decode_att_backend_class @@ -321,6 +334,7 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0) infer_state.b_req_idx = model_input.b_req_idx infer_state.b_seq_len = model_input.b_seq_len infer_state.b_mtp_index = model_input.b_mtp_index + infer_state.b_num_accepted_tokens = model_input.b_num_accepted_tokens if model_input.is_prefill: if model_input.b_ready_cache_len is not None: infer_state.b_ready_cache_len = model_input.b_ready_cache_len @@ -358,6 +372,16 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0) return infer_state + def _get_decode_padding_unit(self, model_input: ModelInput) -> int: + padding_unit = self.tp_world_size_ if self.args.enable_tpsp_mix_mode else 1 + if self.args.mtp_step > 0 and (not model_input.is_prefill) and model_input.b_num_accepted_tokens is not None: + padding_unit = math.lcm(padding_unit, self.args.mtp_step + 1) + return padding_unit + + def _get_decode_infer_batch_size(self, model_input: ModelInput) -> int: + padding_unit = self._get_decode_padding_unit(model_input) + return triton.cdiv(model_input.batch_size, padding_unit) * padding_unit + def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_size: int): if model_input.batch_size == new_batch_size: return model_input @@ -367,22 +391,111 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s padded_batch_size = new_batch_size - model_input.batch_size new_model_input = copy.copy(model_input) new_model_input.batch_size = new_batch_size - new_model_input.total_token_num += padded_batch_size * 2 - new_model_input.max_kv_seq_len = max(2, model_input.max_kv_seq_len) - new_model_input.input_ids = F.pad(new_model_input.input_ids, (0, padded_batch_size), mode="constant", value=1) - new_model_input.b_req_idx = F.pad( - new_model_input.b_req_idx, (0, padded_batch_size), mode="constant", value=self.req_manager.HOLD_REQUEST_ID - ) - new_model_input.b_mtp_index = F.pad( - new_model_input.b_mtp_index, (0, padded_batch_size), mode="constant", value=0 - ) - new_model_input.b_seq_len = F.pad(new_model_input.b_seq_len, (0, padded_batch_size), mode="constant", value=2) - new_model_input.mem_indexes = F.pad( - new_model_input.mem_indexes, - (0, padded_batch_size), - mode="constant", - value=self.mem_manager.HOLD_TOKEN_MEMINDEX, + + is_mtp_verify_decode = ( + self.args.mtp_step > 0 and (not model_input.is_prefill) and model_input.b_num_accepted_tokens is not None ) + if is_mtp_verify_decode: + mtp_size = self.args.mtp_step + 1 + assert model_input.batch_size % mtp_size == 0 + assert new_batch_size % mtp_size == 0 + assert padded_batch_size % mtp_size == 0 + padded_req_num = padded_batch_size // mtp_size + + pad_mtp_index = torch.arange( + mtp_size, + dtype=new_model_input.b_mtp_index.dtype, + device=new_model_input.b_mtp_index.device, + ).repeat(padded_req_num) + pad_seq_len = torch.arange( + 2, + mtp_size + 2, + dtype=new_model_input.b_seq_len.dtype, + device=new_model_input.b_seq_len.device, + ).repeat(padded_req_num) + new_model_input.total_token_num += padded_req_num * (mtp_size * (mtp_size + 3) // 2) + new_model_input.max_kv_seq_len = max(mtp_size + 1, model_input.max_kv_seq_len) + new_model_input.input_ids = torch.cat( + ( + new_model_input.input_ids, + torch.ones( + padded_batch_size, + dtype=new_model_input.input_ids.dtype, + device=new_model_input.input_ids.device, + ), + ), + dim=0, + ) + new_model_input.b_req_idx = torch.cat( + ( + new_model_input.b_req_idx, + torch.full( + (padded_batch_size,), + self.req_manager.HOLD_REQUEST_ID, + dtype=new_model_input.b_req_idx.dtype, + device=new_model_input.b_req_idx.device, + ), + ), + dim=0, + ) + new_model_input.b_mtp_index = torch.cat((new_model_input.b_mtp_index, pad_mtp_index), dim=0) + new_model_input.b_seq_len = torch.cat((new_model_input.b_seq_len, pad_seq_len), dim=0) + new_model_input.mem_indexes = torch.cat( + ( + new_model_input.mem_indexes, + torch.full( + (padded_batch_size,), + self.mem_manager.HOLD_TOKEN_MEMINDEX, + dtype=new_model_input.mem_indexes.dtype, + device=new_model_input.mem_indexes.device, + ), + ), + dim=0, + ) + new_model_input.b_num_accepted_tokens = torch.cat( + ( + new_model_input.b_num_accepted_tokens, + torch.ones( + padded_req_num, + dtype=new_model_input.b_num_accepted_tokens.dtype, + device=new_model_input.b_num_accepted_tokens.device, + ), + ), + dim=0, + ) + else: + new_model_input.total_token_num += padded_batch_size * 2 + new_model_input.max_kv_seq_len = max(2, model_input.max_kv_seq_len) + new_model_input.input_ids = F.pad( + new_model_input.input_ids, + (0, padded_batch_size), + mode="constant", + value=1, + ) + new_model_input.b_req_idx = F.pad( + new_model_input.b_req_idx, + (0, padded_batch_size), + mode="constant", + value=self.req_manager.HOLD_REQUEST_ID, + ) + new_model_input.b_mtp_index = F.pad( + new_model_input.b_mtp_index, + (0, padded_batch_size), + mode="constant", + value=0, + ) + new_model_input.b_seq_len = F.pad( + new_model_input.b_seq_len, + (0, padded_batch_size), + mode="constant", + value=2, + ) + new_model_input.mem_indexes = F.pad( + new_model_input.mem_indexes, + (0, padded_batch_size), + mode="constant", + value=self.mem_manager.HOLD_TOKEN_MEMINDEX, + ) new_model_input.multimodal_params = new_model_input.multimodal_params + [ {"images": [], "audios": []} for _ in range(padded_batch_size) ] @@ -390,11 +503,17 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s if enable_diverse_mode_gqa_decode_fast_kernel(): if new_model_input.b_shared_seq_len is not None: new_model_input.b_shared_seq_len = F.pad( - new_model_input.b_shared_seq_len, (0, padded_batch_size), mode="constant", value=0 + new_model_input.b_shared_seq_len, + (0, padded_batch_size), + mode="constant", + value=0, ) if new_model_input.b_mark_shared_group is not None: new_model_input.b_mark_shared_group = F.pad( - new_model_input.b_mark_shared_group, (0, padded_batch_size), mode="constant", value=1 + new_model_input.b_mark_shared_group, + (0, padded_batch_size), + mode="constant", + value=1, ) # 特殊模型,特殊模式的特殊变量的特殊 padding @@ -429,7 +548,10 @@ def _create_padded_prefill_model_input(self, model_input: ModelInput, new_handle value=self.mem_manager.HOLD_TOKEN_MEMINDEX, ) new_model_input.b_req_idx = F.pad( - new_model_input.b_req_idx, (0, 1), mode="constant", value=self.req_manager.HOLD_REQUEST_ID + new_model_input.b_req_idx, + (0, 1), + mode="constant", + value=self.req_manager.HOLD_REQUEST_ID, ) new_model_input.b_mtp_index = F.pad(new_model_input.b_mtp_index, (0, 1), mode="constant", value=0) new_model_input.b_seq_len = F.pad(new_model_input.b_seq_len, (0, 1), mode="constant", value=padded_token_num) @@ -469,7 +591,10 @@ def _create_unpad_decode_model_output(self, model_output: ModelOutput, origin_ba return new_model_output def _create_unpad_prefill_model_output( - self, padded_model_output: ModelOutput, origin_handle_token_num: int, origin_batch_size: int + self, + padded_model_output: ModelOutput, + origin_handle_token_num: int, + origin_batch_size: int, ): if self.return_all_prompt_logics: new_model_output = copy.copy(padded_model_output) @@ -555,15 +680,16 @@ def _decode( ) origin_batch_size = model_input.batch_size - if self.args.enable_tpsp_mix_mode: - infer_batch_size = triton.cdiv(model_input.batch_size, self.tp_world_size_) * self.tp_world_size_ - else: - infer_batch_size = model_input.batch_size + infer_batch_size = self._get_decode_infer_batch_size(model_input) + is_mtp_verify_decode = self.args.mtp_step > 0 and model_input.b_num_accepted_tokens is not None if self.graph is not None and self.graph.can_run( batch_size=infer_batch_size, max_len_in_batch=model_input.max_kv_seq_len ): - infer_batch_size = self.graph.find_closest_graph_batch_size(batch_size=infer_batch_size) + infer_batch_size = self.graph.find_closest_graph_batch_size( + batch_size=infer_batch_size, + is_mtp_verify_decode=is_mtp_verify_decode, + ) model_input = self._create_padded_decode_model_input( model_input=model_input, new_batch_size=infer_batch_size ) @@ -577,7 +703,7 @@ def _decode( infer_state.init_some_extra_state(self) infer_state.init_att_state() - if self.graph.need_capture(infer_batch_size): + if self.graph.need_capture(infer_batch_size, is_mtp_verify_decode=is_mtp_verify_decode): infer_state.is_cuda_graph = True model_output: ModelOutput = self.graph.capture_decode(self._token_forward, infer_state) else: @@ -809,10 +935,14 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode origin_batch_size = model_input0.batch_size max_len_in_batch = max(model_input0.max_kv_seq_len, model_input1.max_kv_seq_len) - infer_batch_size = triton.cdiv(origin_batch_size, self.tp_world_size_) * self.tp_world_size_ + infer_batch_size = self._get_decode_infer_batch_size(model_input0) + is_mtp_verify_decode = self.args.mtp_step > 0 and model_input0.b_num_accepted_tokens is not None if self.graph is not None and self.graph.can_run(infer_batch_size, max_len_in_batch): - infer_batch_size = self.graph.find_closest_graph_batch_size(infer_batch_size) + infer_batch_size = self.graph.find_closest_graph_batch_size( + infer_batch_size, + is_mtp_verify_decode=is_mtp_verify_decode, + ) # TODO 如果支持动态步数的 mtp,在不同的mtp步上,model_input0 和 model_input1 的内部batch size可能不 # 一致,需要按照较高 batch size 进行graph的寻找,同时,进行有效的恢复。 padded_model_input0 = self._create_padded_decode_model_input(model_input0, infer_batch_size) @@ -837,7 +967,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode infer_state1.init_some_extra_state(self) infer_state1.init_att_state() - if self.graph.need_capture(infer_batch_size): + if self.graph.need_capture(infer_batch_size, is_mtp_verify_decode=is_mtp_verify_decode): infer_state0.is_cuda_graph = True infer_state1.is_cuda_graph = True @@ -889,7 +1019,11 @@ def _overlap_tpsp_context_forward(self, infer_state: InferStateInfo, infer_state g_cache_manager.cache_env_in() input_embs, input_embs1 = self.pre_infer.overlap_tpsp_context_forward( - infer_state.input_ids, infer_state1.input_ids, infer_state, infer_state1, self.pre_post_weight + infer_state.input_ids, + infer_state1.input_ids, + infer_state, + infer_state1, + self.pre_post_weight, ) # 决定是否进行 dp balance 优化,可以提升dp > 1 时的 prefill 效率。 @@ -905,7 +1039,11 @@ def _overlap_tpsp_context_forward(self, infer_state: InferStateInfo, infer_state for i in range(self.layers_num): input_embs, input_embs1 = self.layers_infer[i].overlap_tpsp_context_forward( - input_embs, input_embs1, infer_state, infer_state1, self.trans_layers_weight[i] + input_embs, + input_embs1, + infer_state, + infer_state1, + self.trans_layers_weight[i], ) # 折叠模式调用完infer_state 和 infer_state1 上的hook函数后,input_embs 和 input_embs1 才具备正确的运算数据。 @@ -919,7 +1057,11 @@ def _overlap_tpsp_context_forward(self, infer_state: InferStateInfo, infer_state last_input_embs1 = infer_state1._all_to_all_unbalance_get(data=last_input_embs1) predict_logits, predict_logits1 = self.post_infer.overlap_tpsp_token_forward( - last_input_embs, last_input_embs1, infer_state, infer_state1, self.pre_post_weight + last_input_embs, + last_input_embs1, + infer_state, + infer_state1, + self.pre_post_weight, ) g_cache_manager.cache_env_out() @@ -940,14 +1082,22 @@ def _overlap_tpsp_context_forward(self, infer_state: InferStateInfo, infer_state @final def _overlap_tpsp_token_forward(self, infer_state: InferStateInfo, infer_state1: InferStateInfo): input_embs, input_embs1 = self.pre_infer.overlap_tpsp_token_forward( - infer_state.input_ids, infer_state1.input_ids, infer_state, infer_state1, self.pre_post_weight + infer_state.input_ids, + infer_state1.input_ids, + infer_state, + infer_state1, + self.pre_post_weight, ) input_embs = self.pre_infer._tpsp_sp_split(input=input_embs, infer_state=infer_state) input_embs1 = self.pre_infer._tpsp_sp_split(input=input_embs1, infer_state=infer_state1) for i in range(self.layers_num): input_embs, input_embs1 = self.layers_infer[i].overlap_tpsp_token_forward( - input_embs, input_embs1, infer_state, infer_state1, self.trans_layers_weight[i] + input_embs, + input_embs1, + infer_state, + infer_state1, + self.trans_layers_weight[i], ) # 折叠模式调用完infer_state 上的hook函数后,input_embs 和 input_embs 才具备正确的运算数据。 @@ -958,7 +1108,11 @@ def _overlap_tpsp_token_forward(self, infer_state: InferStateInfo, infer_state1: last_input_embs1 = self.post_infer._tpsp_allgather(input=input_embs1, infer_state=infer_state1) predict_logits, predict_logits1 = self.post_infer.overlap_tpsp_token_forward( - last_input_embs, last_input_embs1, infer_state, infer_state1, self.pre_post_weight + last_input_embs, + last_input_embs1, + infer_state, + infer_state1, + self.pre_post_weight, ) model_output = ModelOutput(logits=predict_logits.contiguous()) @@ -1065,7 +1219,12 @@ def _autotune_warmup(self): rand_gen = torch.Generator(device="cuda") rand_gen.manual_seed(input_len) dummy_input_ids = torch.randint( - 0, 10000, (input_len,), dtype=torch.int32, device="cuda", generator=rand_gen + 0, + 10000, + (input_len,), + dtype=torch.int32, + device="cuda", + generator=rand_gen, ) b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda") mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda() @@ -1129,10 +1288,14 @@ def _init_padded_req(self): batch_size = 1 dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda") b_req_idx = torch.tensor( - [self.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda" + [self.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], + dtype=torch.int32, + device="cuda", ) mem_indexes = torch.tensor( - [self.mem_manager.HOLD_TOKEN_MEMINDEX for _ in range(batch_size)], dtype=torch.int32, device="cuda" + [self.mem_manager.HOLD_TOKEN_MEMINDEX for _ in range(batch_size)], + dtype=torch.int32, + device="cuda", ) b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda") b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda") @@ -1181,10 +1344,15 @@ def _gen_special_model_input(self, token_num: int): or "Qwen3MOEMTPModel" in str(self.__class__) or "MistralMTPModel" in str(self.__class__) or "Glm4MoeLiteMTPModel" in str(self.__class__) + or "Qwen3_5MTPModel" in str(self.__class__) + or "Qwen3_5MoeMTPModel" in str(self.__class__) ) if is_mtp_draft_model: special_model_input["mtp_draft_input_hiddens"] = torch.randn( - token_num, self.config["hidden_size"], dtype=self.data_type, device="cuda" + token_num, + self.config["hidden_size"], + dtype=self.data_type, + device="cuda", ) else: special_model_input["mtp_draft_input_hiddens"] = None diff --git a/lightllm/common/basemodel/batch_objs.py b/lightllm/common/basemodel/batch_objs.py index 1795ff9a82..03cb36d28d 100644 --- a/lightllm/common/basemodel/batch_objs.py +++ b/lightllm/common/basemodel/batch_objs.py @@ -53,6 +53,8 @@ class ModelInput: # 的 draft 模型的输入 mtp_draft_input_hiddens: Optional[torch.Tensor] = None + b_num_accepted_tokens: Optional[torch.Tensor] = None + def to_cuda(self): if self.input_ids is not None: self.input_ids = self.input_ids.cuda(non_blocking=True) @@ -66,6 +68,8 @@ def to_cuda(self): self.b_req_idx = self.b_req_idx.cuda(non_blocking=True) self.b_seq_len = self.b_seq_len.cuda(non_blocking=True) self.b_mtp_index = self.b_mtp_index.cuda(non_blocking=True) + if self.b_num_accepted_tokens is not None: + self.b_num_accepted_tokens = self.b_num_accepted_tokens.cuda(non_blocking=True) if self.b_ready_cache_len is not None: self.b_ready_cache_len = self.b_ready_cache_len.cuda(non_blocking=True) if self.b_prefill_start_loc is not None: diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 782150661e..8f5c702994 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -2,6 +2,7 @@ import torch import copy import bisect +import math import triton from typing import Optional from lightllm.utils.log_utils import init_logger @@ -27,48 +28,139 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192, tp_world_size: int = self.graph_max_len_in_batch = max_len_in_batch self.enable_decode_microbatch_overlap = self.args.enable_decode_microbatch_overlap + self.normal_cuda_graph_batch_sizes = self._build_cuda_graph_batch_sizes(batch_size_multiple=1) + if self.mtp_step > 0: + self.mtp_verify_cuda_graph_batch_sizes = self._build_cuda_graph_batch_sizes( + batch_size_multiple=self.mtp_step + 1 + ) + self.cuda_graph_batch_sizes = self.mtp_verify_cuda_graph_batch_sizes + logger.info(f"normal cuda graph batch_sizes: {self.normal_cuda_graph_batch_sizes}") + logger.info(f"mtp verify cuda graph batch_sizes: {self.mtp_verify_cuda_graph_batch_sizes}") + else: + self.mtp_verify_cuda_graph_batch_sizes = self.normal_cuda_graph_batch_sizes + self.cuda_graph_batch_sizes = self.normal_cuda_graph_batch_sizes + logger.info(f"cuda graph batch_sizes: {self.cuda_graph_batch_sizes}") + + def _build_cuda_graph_batch_sizes(self, batch_size_multiple: int): # gen cuda graph batch_sizes # cuda graph gen for batch size = [1, 2, 3, ..., graph_split_batch_size] - # and [graph_split_batch_size + graph_grow_step_size, - # if the mtp_step is not 0, then the batch_sizes will be multiply of (mtp_step + 1) - - graph_split_batch_size = self.args.graph_split_batch_size * (self.mtp_step + 1) - graph_grow_step_size = self.args.graph_grow_step_size * (self.mtp_step + 1) - - batch_sizes = [i * (self.mtp_step + 1) for i in range(1, self.args.graph_split_batch_size + 1)] - for _batch_size in range(graph_split_batch_size + graph_grow_step_size, max_batch_size, graph_grow_step_size): + # and [graph_split_batch_size + graph_grow_step_size, ...] + graph_split_batch_size = self.args.graph_split_batch_size * batch_size_multiple + graph_grow_step_size = self.args.graph_grow_step_size * batch_size_multiple + + batch_sizes = [i * batch_size_multiple for i in range(1, self.args.graph_split_batch_size + 1)] + for _batch_size in range( + graph_split_batch_size + graph_grow_step_size, + self.max_batch_size, + graph_grow_step_size, + ): batch_sizes.append(_batch_size) - batch_sizes = list(set([e for e in batch_sizes if e < max_batch_size])) - batch_sizes.append(max_batch_size) + batch_sizes = list(set([e for e in batch_sizes if e < self.max_batch_size])) + batch_sizes.append(self.max_batch_size) batch_sizes.sort() if self.args.enable_tpsp_mix_mode: - batch_sizes = [triton.cdiv(e, self.tp_world_size) * self.tp_world_size for e in batch_sizes] + padding_unit = math.lcm(self.tp_world_size, batch_size_multiple) + batch_sizes = [triton.cdiv(e, padding_unit) * padding_unit for e in batch_sizes] batch_sizes = list(set(batch_sizes)) batch_sizes.sort() - self.cuda_graph_batch_sizes = batch_sizes assert batch_sizes[-1] == self.max_batch_size - logger.info(f"cuda graph batch_sizes: {self.cuda_graph_batch_sizes}") + return batch_sizes def can_run(self, batch_size, max_len_in_batch): return batch_size <= self.max_batch_size and max_len_in_batch <= self.graph_max_len_in_batch - def need_capture(self, batch_size): - find_batch_size = self.find_closest_graph_batch_size(batch_size) + def _decode_graph_key(self, infer_state: InferStateInfo): + is_mtp_verify_decode = self.mtp_step > 0 and infer_state.b_num_accepted_tokens is not None + return (infer_state.input_ids.shape[0], is_mtp_verify_decode) + + def need_capture(self, batch_size, is_mtp_verify_decode=False): + find_batch_size = self.find_closest_graph_batch_size(batch_size, is_mtp_verify_decode=is_mtp_verify_decode) if find_batch_size is not None: - return find_batch_size not in self.graph + return (find_batch_size, is_mtp_verify_decode) not in self.graph else: assert False, "dead code" - def find_closest_graph_batch_size(self, batch_size): - index = bisect.bisect_left(self.cuda_graph_batch_sizes, batch_size) - if index < len(self.cuda_graph_batch_sizes): - find_batch_size = self.cuda_graph_batch_sizes[index] + def _get_graph_batch_sizes(self, is_mtp_verify_decode=False): + if not hasattr(self, "normal_cuda_graph_batch_sizes"): + return self.cuda_graph_batch_sizes + if is_mtp_verify_decode: + return self.mtp_verify_cuda_graph_batch_sizes + return self.normal_cuda_graph_batch_sizes + + def find_closest_graph_batch_size(self, batch_size, is_mtp_verify_decode=False): + graph_batch_sizes = self._get_graph_batch_sizes(is_mtp_verify_decode=is_mtp_verify_decode) + index = bisect.bisect_left(graph_batch_sizes, batch_size) + if index < len(graph_batch_sizes): + find_batch_size = graph_batch_sizes[index] return find_batch_size else: return None + def _build_warmup_decode_model_input( + self, + model, + batch_size: int, + device: str = "cuda", + is_mtp_verify_decode: Optional[bool] = None, + ) -> ModelInput: + if is_mtp_verify_decode is None: + is_mtp_verify_decode = self.mtp_step > 0 + + mtp_size = self.mtp_step + 1 + input_ids = torch.ones(batch_size, dtype=torch.int32, device=device) + mem_indexes = model.mem_manager.alloc(batch_size).to(device) + b_req_idx = torch.full( + (batch_size,), + fill_value=model.req_manager.HOLD_REQUEST_ID, + dtype=torch.int32, + device=device, + ) + + b_num_accepted_tokens = None + if self.mtp_step > 0 and is_mtp_verify_decode: + assert batch_size % mtp_size == 0, "MTP decode CUDA graph batch size must be a multiple of mtp_step + 1" + real_batch_size = batch_size // mtp_size + b_mtp_index = torch.arange(mtp_size, dtype=torch.int32, device=device).repeat(real_batch_size) + b_seq_len = torch.arange(2, mtp_size + 2, dtype=torch.int32, device=device).repeat(real_batch_size) + b_num_accepted_tokens = torch.ones(real_batch_size, dtype=torch.int32, device=device) + total_token_num = real_batch_size * (mtp_size * (mtp_size + 3) // 2) + else: + seq_len = 2 + total_token_num = batch_size * seq_len + b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device=device) + b_seq_len = torch.empty(batch_size, dtype=torch.int32, device=device) + b_seq_len.fill_(seq_len) + + return ModelInput( + batch_size=batch_size, + total_token_num=total_token_num, + max_q_seq_len=1, + max_kv_seq_len=self.graph_max_len_in_batch, + input_ids=input_ids, + mem_indexes=mem_indexes, + b_req_idx=b_req_idx, + b_seq_len=b_seq_len, + b_mtp_index=b_mtp_index, + b_num_accepted_tokens=b_num_accepted_tokens, + is_prefill=False, + multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)], + **model._gen_special_model_input(batch_size), + ) + + def _is_mtp_draft_model(self, model): + return "MTPModel" in str(model.__class__) + + def _iter_warmup_graph_layouts(self, model): + if self.mtp_step > 0: + if self._is_mtp_draft_model(model): + yield False, self.normal_cuda_graph_batch_sizes + else: + yield True, self.mtp_verify_cuda_graph_batch_sizes + else: + yield False, self.normal_cuda_graph_batch_sizes + def _capture_decode(self, decode_func, infer_state: InferStateInfo): graph_obj = torch.cuda.CUDAGraph() input_ids = infer_state.input_ids @@ -96,7 +188,11 @@ def _capture_decode(self, decode_func, infer_state: InferStateInfo): with torch.cuda.graph(graph_obj, pool=self.mempool): model_output = decode_func(infer_state) - self.graph[batch_size] = (graph_obj, infer_state, model_output) + self.graph[self._decode_graph_key(infer_state)] = ( + graph_obj, + infer_state, + model_output, + ) graph_obj.replay() return model_output @@ -130,7 +226,7 @@ def _capture_decode_overlap( with torch.cuda.graph(graph_obj, pool=self.mempool): model_output, model_output1 = decode_func(infer_state, infer_state1) - self.graph[batch_size] = ( + self.graph[self._decode_graph_key(infer_state)] = ( graph_obj, infer_state, infer_state1, @@ -157,8 +253,7 @@ def capture_decode( return self._capture_decode(decode_func, infer_state) def _replay(self, infer_state: InferStateInfo): - batch_size = infer_state.input_ids.shape[0] - graph_obj, graph_infer_state, graph_output = self.graph[batch_size] + graph_obj, graph_infer_state, graph_output = self.graph[self._decode_graph_key(infer_state)] graph_infer_state.copy_for_cuda_graph(infer_state) graph_obj.replay() return graph_output @@ -168,14 +263,13 @@ def _replay_overlap( infer_state: InferStateInfo, infer_state1: InferStateInfo, ): - batch_size = infer_state.input_ids.shape[0] ( graph_obj, graph_infer_state, graph_infer_state1, graph_model_output, graph_model_output1, - ) = self.graph[batch_size] + ) = self.graph[self._decode_graph_key(infer_state)] graph_infer_state.copy_for_cuda_graph(infer_state) graph_infer_state1.copy_for_cuda_graph(infer_state1) graph_obj.replay() @@ -197,47 +291,23 @@ def warmup(self, model): model: TpPartBaseModel = model # decode cuda graph init - for batch_size in self.cuda_graph_batch_sizes[::-1]: - seq_len = 2 - total_token_num = batch_size * seq_len - max_len_in_batch = self.graph_max_len_in_batch - input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda") - mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda() - b_req_idx = torch.tensor( - [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda" - ) - b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda") - b_seq_len.fill_(seq_len) - b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda") - - model_input = ModelInput( - batch_size=batch_size, - total_token_num=total_token_num, - max_q_seq_len=1, - max_kv_seq_len=max_len_in_batch, - input_ids=input_ids, - mem_indexes=mem_indexes, - b_req_idx=b_req_idx, - b_seq_len=b_seq_len, - b_mtp_index=b_mtp_index, - is_prefill=False, - multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)], - **model._gen_special_model_input(batch_size), - ) - model_output: ModelOutput = model.forward(model_input) - del model_output - del input_ids - del mem_indexes - del b_req_idx - del b_seq_len - - model.mem_manager.free_all() - model.req_manager.free_all() - # release local tensors - for var_name, var_value in list(locals().items()): - if isinstance(var_value, torch.Tensor): - del locals()[var_name] - torch.cuda.empty_cache() + for is_mtp_verify_decode, batch_sizes in self._iter_warmup_graph_layouts(model): + for batch_size in batch_sizes[::-1]: + model_input = self._build_warmup_decode_model_input( + model, + batch_size, + is_mtp_verify_decode=is_mtp_verify_decode, + ) + model_output: ModelOutput = model.forward(model_input) + del model_output + + model.mem_manager.free_all() + model.req_manager.free_all() + # release local tensors + for var_name, var_value in list(locals().items()): + if isinstance(var_value, torch.Tensor): + del locals()[var_name] + torch.cuda.empty_cache() logger.info( f"Capture cudagraph success, batch_size <={self.max_batch_size} " @@ -252,56 +322,36 @@ def warmup_overlap(self, model): model: TpPartBaseModel = model - for batch_size in self.cuda_graph_batch_sizes[::-1]: - decode_batches = [] - for micro_batch_index in [0, 1]: - # dummy decoding, capture the cudagraph - seq_len = 2 - total_token_num = batch_size * seq_len - max_len_in_batch = self.graph_max_len_in_batch - input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda") - mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda() - b_req_idx = torch.tensor( - [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda" - ) - b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda") - b_seq_len.fill_(seq_len) - b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda") - - micro_batch = ModelInput( - is_prefill=False, - batch_size=batch_size, - total_token_num=total_token_num, - max_q_seq_len=1, - max_kv_seq_len=max_len_in_batch, - input_ids=input_ids, - b_mtp_index=b_mtp_index, - mem_indexes=mem_indexes, - b_req_idx=b_req_idx, - b_seq_len=b_seq_len, - multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)], - **model._gen_special_model_input(batch_size), - ) - decode_batches.append(micro_batch) - del micro_batch + for is_mtp_verify_decode, batch_sizes in self._iter_warmup_graph_layouts(model): + for batch_size in batch_sizes[::-1]: + decode_batches = [] + for micro_batch_index in [0, 1]: + # dummy decoding, capture the cudagraph + micro_batch = self._build_warmup_decode_model_input( + model, + batch_size, + is_mtp_verify_decode=is_mtp_verify_decode, + ) + decode_batches.append(micro_batch) + del micro_batch - for var_name, var_value in list(locals().items()): - if isinstance(var_value, torch.Tensor): - del locals()[var_name] - torch.cuda.empty_cache() + for var_name, var_value in list(locals().items()): + if isinstance(var_value, torch.Tensor): + del locals()[var_name] + torch.cuda.empty_cache() - _, _ = model.microbatch_overlap_decode(decode_batches[0], decode_batches[1]) + _, _ = model.microbatch_overlap_decode(decode_batches[0], decode_batches[1]) - model.mem_manager.free_all() - model.req_manager.free_all() + model.mem_manager.free_all() + model.req_manager.free_all() - del decode_batches + del decode_batches - # release local tensors - for var_name, var_value in list(locals().items()): - if isinstance(var_value, torch.Tensor): - del locals()[var_name] - torch.cuda.empty_cache() + # release local tensors + for var_name, var_value in list(locals().items()): + if isinstance(var_value, torch.Tensor): + del locals()[var_name] + torch.cuda.empty_cache() logger.info( f"Capture overlap cudagraph success, batch_size <={self.max_batch_size} " diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index 711484c835..6de15f8910 100755 --- a/lightllm/common/basemodel/infer_struct.py +++ b/lightllm/common/basemodel/infer_struct.py @@ -39,6 +39,8 @@ def __init__(self): self.b_mtp_index: torch.Tensor = None + self.b_num_accepted_tokens: torch.Tensor = None + self.b_seq_len: torch.Tensor = None # max_cache_len 用于 prefill 阶段标识请求中最大 cache的kv 的长度 self.max_cache_len: int = None diff --git a/unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py b/unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py new file mode 100644 index 0000000000..d29bb1104e --- /dev/null +++ b/unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py @@ -0,0 +1,398 @@ +from types import SimpleNamespace + +import torch + +from lightllm.common.basemodel.basemodel import TpPartBaseModel +from lightllm.common.basemodel.batch_objs import ModelInput + + +def test_mtp_decode_cuda_graph_warmup_uses_verify_layout(): + from lightllm.common.basemodel.cuda_graph import CudaGraph + + graph = CudaGraph.__new__(CudaGraph) + graph.mtp_step = 2 + graph.graph_max_len_in_batch = 128 + + class FakeMemManager: + HOLD_TOKEN_MEMINDEX = -1 + + def alloc(self, size): + return torch.arange(size, dtype=torch.int32) + + model = SimpleNamespace( + req_manager=SimpleNamespace(HOLD_REQUEST_ID=99), + mem_manager=FakeMemManager(), + _gen_special_model_input=lambda token_num: {"mtp_draft_input_hiddens": None}, + ) + + model_input = graph._build_warmup_decode_model_input(model, batch_size=6, device="cpu") + + assert model_input.batch_size == 6 + assert model_input.b_mtp_index.tolist() == [0, 1, 2, 0, 1, 2] + assert model_input.b_seq_len.tolist() == [2, 3, 4, 2, 3, 4] + assert model_input.b_num_accepted_tokens.tolist() == [1, 1] + assert model_input.total_token_num == 18 + + +def test_mtp_decode_cuda_graph_warmup_supports_normal_layout_for_draft(): + from lightllm.common.basemodel.cuda_graph import CudaGraph + + graph = CudaGraph.__new__(CudaGraph) + graph.mtp_step = 2 + graph.graph_max_len_in_batch = 128 + + class FakeMemManager: + HOLD_TOKEN_MEMINDEX = -1 + + def alloc(self, size): + return torch.arange(size, dtype=torch.int32) + + model = SimpleNamespace( + req_manager=SimpleNamespace(HOLD_REQUEST_ID=99), + mem_manager=FakeMemManager(), + _gen_special_model_input=lambda token_num: {"mtp_draft_input_hiddens": torch.full((token_num, 4), 3.0)}, + ) + + model_input = graph._build_warmup_decode_model_input( + model, + batch_size=5, + device="cpu", + is_mtp_verify_decode=False, + ) + + assert model_input.batch_size == 5 + assert model_input.b_mtp_index.tolist() == [0, 0, 0, 0, 0] + assert model_input.b_seq_len.tolist() == [2, 2, 2, 2, 2] + assert model_input.b_num_accepted_tokens is None + assert model_input.total_token_num == 10 + assert model_input.mtp_draft_input_hiddens.shape == (5, 4) + + +def test_mtp_decode_cuda_graph_keys_verify_and_normal_layouts(): + from lightllm.common.basemodel.cuda_graph import CudaGraph + + graph = CudaGraph.__new__(CudaGraph) + graph.mtp_step = 2 + graph.graph = {} + graph.normal_cuda_graph_batch_sizes = [1, 2, 4, 8] + graph.mtp_verify_cuda_graph_batch_sizes = [3, 6, 9, 12] + graph.cuda_graph_batch_sizes = graph.mtp_verify_cuda_graph_batch_sizes + + verify_state = SimpleNamespace( + input_ids=torch.ones(6, dtype=torch.int64), + b_num_accepted_tokens=torch.ones(2, dtype=torch.int32), + ) + normal_state = SimpleNamespace( + input_ids=torch.ones(6, dtype=torch.int64), + b_num_accepted_tokens=None, + ) + + assert graph._decode_graph_key(verify_state) == (6, True) + assert graph._decode_graph_key(normal_state) == (6, False) + assert graph.find_closest_graph_batch_size(5, is_mtp_verify_decode=True) == 6 + assert graph.find_closest_graph_batch_size(5, is_mtp_verify_decode=False) == 8 + + graph.graph[(6, True)] = "verify graph" + assert graph.need_capture(6, is_mtp_verify_decode=True) is False + assert graph.need_capture(5, is_mtp_verify_decode=False) is True + + +def test_mtp_decode_cuda_graph_warmup_layouts_split_main_and_draft_models(): + from lightllm.common.basemodel.cuda_graph import CudaGraph + + class Qwen3_5MOETpPartModel: + pass + + class Qwen3_5MoeMTPModel: + pass + + graph = CudaGraph.__new__(CudaGraph) + graph.mtp_step = 2 + graph.normal_cuda_graph_batch_sizes = [1, 2, 4, 8] + graph.mtp_verify_cuda_graph_batch_sizes = [3, 6, 9] + + assert list(graph._iter_warmup_graph_layouts(Qwen3_5MOETpPartModel())) == [(True, [3, 6, 9])] + assert list(graph._iter_warmup_graph_layouts(Qwen3_5MoeMTPModel())) == [(False, [1, 2, 4, 8])] + + +def test_mtp_decode_warmup_layout_marks_qwen3next_verify(monkeypatch): + import pytest + + if not torch.cuda.is_available(): + pytest.skip("needs CUDA for gen_decode_params") + + import lightllm.models.qwen3next.infer_struct as infer_struct_mod + from lightllm.models.qwen3next.infer_struct import Qwen3NextInferStateInfo + + monkeypatch.setattr(infer_struct_mod, "get_env_start_args", lambda: SimpleNamespace(mtp_step=2)) + + state = Qwen3NextInferStateInfo() + state.is_prefill = False + state.b_req_idx = torch.tensor([5, 5, 5, 6, 6, 6], dtype=torch.int32, device="cuda") + state.b_mtp_index = torch.tensor([0, 1, 2, 0, 1, 2], dtype=torch.int32, device="cuda") + state.b_seq_len = torch.tensor([2, 3, 4, 2, 3, 4], dtype=torch.int32, device="cuda") + state.b_num_accepted_tokens = torch.tensor([1, 2], dtype=torch.int32, device="cuda") + + model = SimpleNamespace( + _cos_cached=torch.zeros((16, 4), dtype=torch.float32, device="cuda"), + _sin_cached=torch.zeros((16, 4), dtype=torch.float32, device="cuda"), + ) + + state.init_some_extra_state(model) + + assert state.is_mtp_verify is True + assert state.b_gdn_verify_cu_seqlens.tolist() == [0, 3, 6] + assert state.b_conv_buffer_idx.tolist() == [5, 6] + assert state.b_ssm_index_rows.tolist() == [[15, 16, 17], [18, 19, 20]] + + +def test_mtp_decode_padding_preserves_verify_groups(monkeypatch): + import lightllm.common.basemodel.basemodel as basemodel_mod + + monkeypatch.setattr(basemodel_mod, "enable_diverse_mode_gqa_decode_fast_kernel", lambda: False) + + model = TpPartBaseModel.__new__(TpPartBaseModel) + model.args = SimpleNamespace(mtp_step=2) + model.req_manager = SimpleNamespace(HOLD_REQUEST_ID=99) + model.mem_manager = SimpleNamespace(HOLD_TOKEN_MEMINDEX=-1) + + model_input = ModelInput( + batch_size=3, + total_token_num=12, + max_q_seq_len=1, + max_kv_seq_len=4, + input_ids=torch.tensor([10, 11, 12], dtype=torch.int32), + mem_indexes=torch.tensor([20, 21, 22], dtype=torch.int32), + b_req_idx=torch.tensor([7, 7, 7], dtype=torch.int32), + b_mtp_index=torch.tensor([0, 1, 2], dtype=torch.int32), + b_seq_len=torch.tensor([2, 3, 4], dtype=torch.int32), + b_num_accepted_tokens=torch.tensor([2], dtype=torch.int32), + is_prefill=False, + multimodal_params=[{"images": [], "audios": []} for _ in range(3)], + ) + + padded = model._create_padded_decode_model_input(model_input, new_batch_size=6) + + assert padded.batch_size == 6 + assert padded.b_req_idx.tolist() == [7, 7, 7, 99, 99, 99] + assert padded.b_mtp_index.tolist() == [0, 1, 2, 0, 1, 2] + assert padded.b_seq_len.tolist() == [2, 3, 4, 2, 3, 4] + assert padded.b_num_accepted_tokens.tolist() == [2, 1] + assert padded.mem_indexes.tolist() == [20, 21, 22, -1, -1, -1] + assert len(padded.multimodal_params) == 6 + + +def test_qwen3next_hybrid_mtp_keeps_decode_cuda_graph_enabled(monkeypatch): + import lightllm.models.qwen3next.model as qwen3next_model + from lightllm.models.qwen3next.model import Qwen3NextTpPartModel + + monkeypatch.setattr(qwen3next_model, "get_env_start_args", lambda: SimpleNamespace(mtp_step=2)) + + called = {} + + def fake_base_init_cudagraph(self): + called["disable_cudagraph"] = self.disable_cudagraph + self.graph = "captured" + + monkeypatch.setattr(TpPartBaseModel, "_init_cudagraph", fake_base_init_cudagraph) + + model = Qwen3NextTpPartModel.__new__(Qwen3NextTpPartModel) + model.disable_cudagraph = False + + Qwen3NextTpPartModel._init_cudagraph(model) + + assert called["disable_cudagraph"] is False + assert model.disable_cudagraph is False + assert model.graph == "captured" + + +def test_fa3_decode_uses_normal_layout_for_narrowed_mtp_draft(monkeypatch): + import lightllm.common.basemodel.attention.fa3.fp as fa3_fp + from lightllm.common.basemodel.attention.fa3.fp import Fa3DecodeAttState + + monkeypatch.setattr(fa3_fp, "get_env_start_args", lambda: SimpleNamespace(mtp_step=2)) + + copied = {} + + def fake_page_table_copy(page_table, req_to_token_indexs, b_req_idx): + copied["page_table_shape"] = tuple(page_table.shape) + copied["b_req_idx"] = b_req_idx.clone() + + monkeypatch.setattr(fa3_fp, "page_table_copy", fake_page_table_copy) + + model = SimpleNamespace( + graph_max_batch_size=16, + graph_max_len_in_batch=32, + req_manager=SimpleNamespace(req_to_token_indexs=torch.empty((8, 32), dtype=torch.int32)), + ) + backend = SimpleNamespace( + model=model, + get_page_table_buffer=lambda: [torch.empty(16 * 32, dtype=torch.int32)], + ) + infer_state = SimpleNamespace( + batch_size=2, + max_kv_seq_len=16, + input_ids=torch.ones(2, dtype=torch.int64), + b_seq_len=torch.tensor([5, 7], dtype=torch.int32), + b1_cu_q_seq_len=torch.tensor([0, 1, 2], dtype=torch.int32), + b1_cu_kv_seq_len=torch.tensor([0, 5, 12], dtype=torch.int32), + b_req_idx=torch.tensor([3, 4], dtype=torch.int32), + b_num_accepted_tokens=None, + microbatch_index=0, + ) + + state = Fa3DecodeAttState(backend=backend, infer_state=infer_state) + state.init_state() + + assert state.decode_max_q_seq_len == 1 + assert state.b_att_seq_len.tolist() == [5, 7] + assert copied["page_table_shape"] == (2, 16) + assert copied["b_req_idx"].tolist() == [3, 4] + + +def test_build_eagle_accepted_draft_input_narrows_to_accepted_rows(): + from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput + from lightllm.server.router.model_infer.mode_backend.chunked_prefill.impl import ( + ChunkedPrefillBackend, + ) + + backend = ChunkedPrefillBackend.__new__(ChunkedPrefillBackend) + backend.mtp_step = 2 + + main_input = ModelInput( + batch_size=6, + total_token_num=27, + max_q_seq_len=1, + max_kv_seq_len=9, + input_ids=torch.tensor([10, 11, 12, 20, 21, 22], dtype=torch.int64), + mem_indexes=torch.tensor([100, 101, 102, 200, 201, 202], dtype=torch.int32), + b_req_idx=torch.tensor([3, 3, 3, 4, 4, 4], dtype=torch.int32), + b_mtp_index=torch.tensor([0, 1, 2, 0, 1, 2], dtype=torch.int32), + b_seq_len=torch.tensor([5, 6, 7, 6, 7, 8], dtype=torch.int32), + b_num_accepted_tokens=torch.tensor([1, 1], dtype=torch.int32), + is_prefill=False, + multimodal_params=[ + {"row": 0}, + {"row": 1}, + {"row": 2}, + {"row": 3}, + {"row": 4}, + {"row": 5}, + ], + ) + hidden = torch.arange(6 * 4, dtype=torch.float32).view(6, 4) + main_output = ModelOutput(logits=torch.empty(6, 8), mtp_main_output_hiddens=hidden) + next_token_ids = torch.tensor([110, 111, 112, 220, 221, 222], dtype=torch.int64) + b_req_mtp_start_loc = torch.tensor([0, 3], dtype=torch.int32) + mtp_accept_len = torch.tensor([2, 3], dtype=torch.int32) + + ( + draft_input, + accepted_next_tokens, + accepted_req_idx, + ) = backend._build_eagle_accepted_draft_input( + main_model_input=main_input, + main_model_output=main_output, + next_token_ids=next_token_ids, + mtp_accept_len=mtp_accept_len, + b_req_mtp_start_loc=b_req_mtp_start_loc, + ) + + assert draft_input.batch_size == 2 + assert draft_input.input_ids.tolist() == [111, 222] + assert draft_input.b_req_idx.tolist() == [3, 4] + assert draft_input.b_mtp_index.tolist() == [1, 2] + assert draft_input.b_seq_len.tolist() == [6, 8] + assert draft_input.mem_indexes.tolist() == [101, 202] + assert draft_input.b_num_accepted_tokens is None + assert draft_input.multimodal_params == [{"row": 1}, {"row": 5}] + assert accepted_next_tokens.tolist() == [111, 222] + assert accepted_req_idx.tolist() == [3, 4] + torch.testing.assert_close(draft_input.mtp_draft_input_hiddens, hidden[[1, 5]]) + + +def test_eagle_draft_decode_uses_narrowed_hidden_on_first_step(monkeypatch): + import lightllm.server.router.model_infer.mode_backend.chunked_prefill.impl as chunked_impl + from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput + from lightllm.server.router.model_infer.mode_backend.chunked_prefill.impl import ( + ChunkedPrefillBackend, + ) + + class FakeMemManager: + HOLD_TOKEN_MEMINDEX = -1 + + def alloc(self, need_size): + return torch.arange(300, 300 + need_size, dtype=torch.int32) + + req_to_next_token_ids = torch.empty((8, 3), dtype=torch.int64) + monkeypatch.setattr( + chunked_impl, + "g_infer_context", + SimpleNamespace( + radix_cache=None, + req_manager=SimpleNamespace( + mem_manager=FakeMemManager(), + req_sampling_params_manager=SimpleNamespace(req_to_next_token_ids=req_to_next_token_ids), + ), + ), + ) + monkeypatch.setattr(torch.Tensor, "cuda", lambda self, non_blocking=False: self) + + backend = ChunkedPrefillBackend.__new__(ChunkedPrefillBackend) + backend.mtp_step = 2 + backend.num_mtp_models = 1 + + seen_hiddens = [] + + class FakeDraftModel: + def forward(self, model_input): + seen_hiddens.append(model_input.mtp_draft_input_hiddens.clone()) + logits = torch.zeros((model_input.batch_size, 8), dtype=torch.float32) + return ModelOutput( + logits=logits, + mtp_main_output_hiddens=model_input.mtp_draft_input_hiddens + 100, + ) + + backend.draft_models = [FakeDraftModel()] + + scattered = {} + + def fake_scatter(accepted_req_idx, all_next_token_ids): + scattered["accepted_req_idx"] = accepted_req_idx.clone() + scattered["all_next_token_ids"] = all_next_token_ids.clone() + + backend._scatter_accepted_next_token_ids = fake_scatter + + main_input = ModelInput( + batch_size=6, + total_token_num=27, + max_q_seq_len=1, + max_kv_seq_len=9, + input_ids=torch.tensor([10, 11, 12, 20, 21, 22], dtype=torch.int64), + mem_indexes=torch.tensor([100, 101, 102, 200, 201, 202], dtype=torch.int32), + b_req_idx=torch.tensor([3, 3, 3, 4, 4, 4], dtype=torch.int32), + b_mtp_index=torch.tensor([0, 1, 2, 0, 1, 2], dtype=torch.int32), + b_seq_len=torch.tensor([5, 6, 7, 6, 7, 8], dtype=torch.int32), + b_num_accepted_tokens=torch.tensor([1, 1], dtype=torch.int32), + is_prefill=False, + multimodal_params=[{"images": [], "audios": []} for _ in range(6)], + ) + hidden = torch.arange(6 * 4, dtype=torch.float32).view(6, 4) + main_output = ModelOutput(logits=torch.empty(6, 8), mtp_main_output_hiddens=hidden) + next_token_ids = torch.tensor([110, 111, 112, 220, 221, 222], dtype=torch.int64) + b_req_mtp_start_loc = torch.tensor([0, 3], dtype=torch.int32) + mtp_accept_len = torch.tensor([2, 3], dtype=torch.int32) + + returned_mem = backend._draft_decode_eagle( + main_model_input=main_input, + main_model_output=main_output, + next_token_ids=next_token_ids, + mtp_accept_len=mtp_accept_len, + b_req_mtp_start_loc=b_req_mtp_start_loc, + ) + + assert returned_mem.tolist() == [300, 301, 302, 303] + torch.testing.assert_close(seen_hiddens[0], hidden[[1, 5]]) + torch.testing.assert_close(seen_hiddens[1], hidden[[1, 5]] + 100) + assert scattered["accepted_req_idx"].tolist() == [3, 4] + assert scattered["all_next_token_ids"].shape == (2, 3) From 79a5257fa5bd5e38d19770001581a20d75857080 Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 4 Jun 2026 13:26:53 +0800 Subject: [PATCH 04/10] feat(qwen3_5_mtp): scheduler MTP verify backend + accept-len transport Drive the draft/verify loop from the scheduler: - carry a canonical InferReq.mtp_accept_len pointer and persist the per-request accept_len across steps; build per-req b_num_accepted_tokens in decode_mtp and commit it in phase 2 so the next step reads a fresh count. - extend the chunked_prefill backend / base_backend with the MTP verify dispatch and the partial-accept read offset. --- .../server/router/model_infer/infer_batch.py | 23 ++- .../model_infer/mode_backend/base_backend.py | 36 +++- .../mode_backend/chunked_prefill/impl.py | 165 +++++++++++++++--- 3 files changed, 191 insertions(+), 33 deletions(-) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index f0ec69b2c1..ea63c6c6c6 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -357,6 +357,11 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L if not self.is_linear_att_mixed_model: return + # 当 dynamic prompt cache 被禁用时 radix_cache 为 None,没有大页/小页缓冲可写, + # 线性层状态仅存于 req_manager 的 GPU buffer 即可,直接跳过跨请求缓存拷贝。 + if self.radix_cache is None: + return + # 大页对应的 linear att 的拷贝 big_page_token_num = self.args.linear_att_hash_page_size * self.args.linear_att_page_block_num big_page_buffer_ids = [] @@ -377,6 +382,10 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L from lightllm.common.basemodel.triton_kernel.linear_att_copy import copy_linear_att_state_to_kv_buffer + b_num_accepted_tokens = torch.tensor( + [req.mtp_accept_len for req in reqs], dtype=torch.int32, requires_grad=False, device="cpu" + ).cuda(non_blocking=True) + copy_linear_att_state_to_kv_buffer( b_req_idx=b_req_idx, big_page_buffer_ids=big_page_buffer_ids, @@ -385,6 +394,7 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L cpu_kv_conv_state=self.radix_cache.linear_att_big_page_buffers.conv_state_cache.buffer, cpu_kv_ssm_state=self.radix_cache.linear_att_big_page_buffers.ssm_state_cache.buffer, mtp_step=self.args.mtp_step, + b_num_accepted_tokens=b_num_accepted_tokens, ) assert not self.args.disable_chunked_prefill, "chunked prefill mode must be enabled for linear att mixed model" @@ -400,9 +410,14 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L self.radix_cache.linear_att_small_page_buffers.alloc_one_state_cache() ) if req.tail_linear_att_small_page_buffer_id is not None: - src_buffer_idx = req.req_idx * (self.args.mtp_step + 1) - gpu_conv_state = self.req_manager.req_to_conv_state.buffer[:, src_buffer_idx, ...] - gpu_ssm_state = self.req_manager.req_to_ssm_state.buffer[:, src_buffer_idx, ...] + canonical_off = req.mtp_accept_len - 1 + conv_src_idx = req.req_idx + ssm_src_idx = req.req_idx * (self.args.mtp_step + 1) + canonical_off + narrow_w = self.req_manager.linear_config.get_persisted_conv_state_shape()[-1] + gpu_conv_state = self.req_manager.req_to_conv_state.buffer[ + :, conv_src_idx, ..., canonical_off : canonical_off + narrow_w + ] + gpu_ssm_state = self.req_manager.req_to_ssm_state.buffer[:, ssm_src_idx, ...] dst_buffer_idx = req.tail_linear_att_small_page_buffer_id dst_conv_state, dst_ssm_state = self.radix_cache.linear_att_small_page_buffers.get_state_cache( @@ -558,6 +573,8 @@ def __init__( else: self.decode_need_token_num = self._normal_decode_need_token_num + self.mtp_accept_len: int = 1 + if g_infer_context.is_linear_att_mixed_model: self.get_chuncked_input_token_len = self.get_chuncked_input_token_len_for_linear_att self.get_chuncked_input_token_ids = self.get_chuncked_input_token_ids_for_linear_att diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 0220dc87fb..4a51fbd712 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -359,6 +359,16 @@ def init_mtp_draft_model(self, main_kvargs: dict): elif mtp_model_cfg["model_type"] == "glm4_moe_lite": assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"] self.draft_models.append(Glm4MoeLiteMTPModel(mtp_model_kvargs)) + elif model_type in ("qwen3_5", "qwen3_5_text"): + assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"] + from lightllm.models.qwen3_5_mtp.model import Qwen3_5MTPModel + + self.draft_models.append(Qwen3_5MTPModel(mtp_model_kvargs)) + elif model_type in ("qwen3_5_moe", "qwen3_5_moe_text"): + assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"] + from lightllm.models.qwen3_5_moe_mtp.model import Qwen3_5MoeMTPModel + + self.draft_models.append(Qwen3_5MoeMTPModel(mtp_model_kvargs)) else: raise ValueError(f"Unsupported MTP model type: {model_type}") @@ -604,7 +614,6 @@ def _get_classed_reqs( can_alloc_token_num = g_infer_context.get_can_alloc_token_num() for req_obj in ready_reqs: - if req_obj.filter_mark: finished_reqs.append(req_obj) continue @@ -785,11 +794,35 @@ def _verify_mtp_v2( ) return mtp_accept_len, accepted_index + def _commit_mtp_accept_len( + self, + decode_reqs: List[InferReq], + mtp_accept_len_cpu: torch.Tensor, + ): + # Carry the per-req accept count into the NEXT step as the canonical + # pointer (design §3.1). This must run on every rank (not only master): + # the kernels on this rank read req.mtp_accept_len. + # + # CRITICAL ordering (overlap scheduler): the next step's decode_mtp reads + # req.mtp_accept_len (to build b_num_accepted_tokens) the moment its + # wait_to_forward() is released, which happens at THIS step's + # notify_forward_and_wait_post_handle() (start of phase 3). So this carry + # MUST be committed in phase 2 (pre_post_handle), before that release — + # otherwise the next step reads a one-step-stale accept count. The error + # is invisible while accept_len is constant (==1) and corrupts the GDN + # conv/ssm committed-state read-offset the instant a multi-token accept + # (accept_len>=2) occurs. + for req, accept_len in zip(decode_reqs, mtp_accept_len_cpu): + req.mtp_accept_len = int(accept_len) + return + def _update_mtp_accept_ratio( self, decode_reqs: List[InferReq], mtp_accept_len_cpu: torch.Tensor, ): + # Master-only accept-ratio statistics. Unlike _commit_mtp_accept_len this + # only feeds metrics, so it may stay in the phase-3 post_handle region. if self.is_master_in_dp: for req, accept_len in zip(decode_reqs, mtp_accept_len_cpu): req.update_mtp_accepted_token_num(accept_token_num=accept_len - 1) @@ -811,7 +844,6 @@ def _sample_and_scatter_token( b_prefill_has_output_cpu: torch.Tensor = None, mask_func: Optional[Callable] = None, ): - if mask_func is not None: assert len(run_reqs) == logits.shape[0] mask_func(run_reqs, logits) diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 9081c034cd..7ac27d81a3 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -1,9 +1,12 @@ import torch import time +import copy from typing import List, Optional, Callable, Dict, Any from queue import Queue from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend -from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventPack +from lightllm.server.router.model_infer.mode_backend.overlap_events import ( + OverlapEventPack, +) from lightllm.server.router.model_infer.infer_batch import InferReq from lightllm.server.router.model_infer.mode_backend.pre import ( prepare_prefill_inputs, @@ -40,7 +43,10 @@ def __init__(self) -> None: if get_env_start_args().mtp_mode: self.prefill = self.prefill_mtp self.decode = self.decode_mtp - self.is_mtp_eagle = get_env_start_args().mtp_mode in ["eagle_with_att", "eagle_no_att"] + self.is_mtp_eagle = get_env_start_args().mtp_mode in [ + "eagle_with_att", + "eagle_no_att", + ] self.num_mtp_models = 1 if self.is_mtp_eagle else get_env_start_args().mtp_step self._draft_decode_func = self._draft_decode_eagle if self.is_mtp_eagle else self._draft_decode_vanilla else: @@ -109,7 +115,11 @@ def prefill_normal( model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) - _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( + ( + _, + next_token_ids_cpu, + next_token_logprobs_cpu, + ) = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, b_mtp_index=model_input.b_mtp_index, @@ -152,7 +162,11 @@ def decode_normal( model_input, run_reqs = prepare_decode_inputs(decode_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) - _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( + ( + _, + next_token_ids_cpu, + next_token_logprobs_cpu, + ) = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, b_mtp_index=model_input.b_mtp_index, @@ -190,7 +204,11 @@ def prefill_mtp( model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) - next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( + ( + next_token_ids, + next_token_ids_cpu, + next_token_logprobs_cpu, + ) = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, b_mtp_index=model_input.b_mtp_index, @@ -201,7 +219,9 @@ def prefill_mtp( ) # mtp kv fill self._draft_prefill_forward( - model_input=model_input, model_output=model_output, next_token_ids=next_token_ids + model_input=model_input, + model_output=model_output, + next_token_ids=next_token_ids, ) g_infer_context.copy_linear_att_state_to_cache_buffer( b_req_idx=model_input.b_req_idx, @@ -241,6 +261,19 @@ def decode_mtp( """ model_input, run_reqs = prepare_decode_inputs(decode_reqs) + # Build the per-real-request accept tensor (carried InferReq.mtp_accept_len + # from the previous step). decode_reqs is one entry per real request, + # aligning 1:1 with the b_gdn_verify_cu_seqlens grouping (the same zip used + # by _update_mtp_accept_ratio). Threaded onto the infer_state via ModelInput + # (mirrors b_mtp_index); to_cuda() moves it inside forward. §3.1 + if self.mtp_step > 0: + accept_lens = [req.mtp_accept_len for req in decode_reqs] + model_input.b_num_accepted_tokens = g_pin_mem_manager.gen_from_list( + key="b_num_accepted_tokens", + data=accept_lens, + dtype=torch.int32, + ) + with torch.cuda.stream(g_infer_context.get_overlap_stream()): b_mtp_index_cpu = model_input.b_mtp_index model_output = self.model.forward(model_input) @@ -269,9 +302,10 @@ def decode_mtp( verify_event = torch.cuda.Event() verify_event.record() - next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem( - next_token_ids, next_token_logprobs - ) + ( + next_token_ids_cpu, + next_token_logprobs_cpu, + ) = self._async_copy_next_token_infos_to_pin_mem(next_token_ids, next_token_logprobs) # 调用具体的draft decode函数 additional_mem_indexes_cpu = self._draft_decode_func( @@ -293,6 +327,12 @@ def decode_mtp( # 第二阶段 event_pack.notify_post_handle_and_wait_pre_post_handle() verify_event.synchronize() + # Commit the carried accept count HERE (phase 2 / pre_post_handle), not in + # phase 3: the next overlapped step reads req.mtp_accept_len as soon as this + # step calls notify_forward_and_wait_post_handle() below, so the update must + # land before that release to avoid feeding the kernels a stale (one-step-old) + # accept count. See _commit_mtp_accept_len for the full rationale. + self._commit_mtp_accept_len(decode_reqs=decode_reqs, mtp_accept_len_cpu=mtp_accept_len_cpu) verify_ok_reqs = [run_reqs[i] for i in range(len(run_reqs)) if accepted_index_cpu[i] == 1] update_packs = self._pre_post_handle(verify_ok_reqs, is_chuncked_mode=False) @@ -324,7 +364,12 @@ def decode_mtp( event_pack.notify_pre_post_handle() return - def _draft_prefill_forward(self, model_input: ModelInput, model_output: ModelOutput, next_token_ids: torch.Tensor): + def _draft_prefill_forward( + self, + model_input: ModelInput, + model_output: ModelOutput, + next_token_ids: torch.Tensor, + ): # spec prefill: MTP, 这个地方只是为了填充draft model的 kv, 并不会使用生成的token_id。 draft_model_input = model_input draft_model_output = model_output @@ -373,6 +418,62 @@ def _draft_decode_vanilla( ) return None + def _build_eagle_accepted_draft_input( + self, + main_model_input: ModelInput, + main_model_output: ModelOutput, + next_token_ids: torch.Tensor, + mtp_accept_len: torch.Tensor, + b_req_mtp_start_loc: torch.Tensor, + ): + accepted_row_idx = b_req_mtp_start_loc + mtp_accept_len - 1 + accepted_row_idx_long = accepted_row_idx.long() + + draft_model_input = copy.copy(main_model_input) + draft_model_input.batch_size = accepted_row_idx.shape[0] + draft_model_input.total_token_num = draft_model_input.batch_size * main_model_input.max_kv_seq_len + draft_model_input.input_ids = next_token_ids.index_select(0, accepted_row_idx_long) + draft_model_input.mtp_draft_input_hiddens = main_model_output.mtp_main_output_hiddens.index_select( + 0, accepted_row_idx_long + ) + draft_model_input.b_req_idx = main_model_input.b_req_idx.index_select(0, accepted_row_idx_long) + draft_model_input.b_mtp_index = main_model_input.b_mtp_index.index_select(0, accepted_row_idx_long) + draft_model_input.b_seq_len = main_model_input.b_seq_len.index_select(0, accepted_row_idx_long) + draft_model_input.b_num_accepted_tokens = None + if main_model_input.mem_indexes is not None: + draft_model_input.mem_indexes = main_model_input.mem_indexes.index_select(0, accepted_row_idx_long) + draft_model_input.mem_indexes_cpu = None + if main_model_input.b_shared_seq_len is not None: + draft_model_input.b_shared_seq_len = main_model_input.b_shared_seq_len.index_select( + 0, accepted_row_idx_long + ) + if main_model_input.b_mark_shared_group is not None: + draft_model_input.b_mark_shared_group = main_model_input.b_mark_shared_group.index_select( + 0, accepted_row_idx_long + ) + + if accepted_row_idx.device.type == "cpu": + selected_rows = accepted_row_idx.tolist() + draft_model_input.multimodal_params = [main_model_input.multimodal_params[i] for i in selected_rows] + else: + draft_model_input.multimodal_params = [ + {"images": [], "audios": []} for _ in range(draft_model_input.batch_size) + ] + + accepted_next_token_ids = draft_model_input.input_ids + accepted_req_idx = draft_model_input.b_req_idx + return draft_model_input, accepted_next_token_ids, accepted_req_idx + + def _scatter_accepted_next_token_ids(self, accepted_req_idx: torch.Tensor, all_next_token_ids: torch.Tensor): + req_to_next_token_ids = self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids + width = all_next_token_ids.shape[1] + req_to_next_token_ids[:, :width].index_copy_( + 0, + accepted_req_idx.long(), + all_next_token_ids.to(dtype=req_to_next_token_ids.dtype), + ) + return + def _draft_decode_eagle( self, main_model_input: ModelInput, @@ -381,8 +482,7 @@ def _draft_decode_eagle( mtp_accept_len: torch.Tensor, b_req_mtp_start_loc: torch.Tensor, ): - batch_size = main_model_input.batch_size - num_reqs = batch_size // (self.mtp_step + 1) + num_reqs = b_req_mtp_start_loc.shape[0] g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(num_reqs * self.mtp_step) @@ -390,36 +490,45 @@ def _draft_decode_eagle( g_infer_state_lock.release() eagle_mem_indexes = eagle_mem_indexes_cpu.cuda(non_blocking=True) - # share some inference info with the main model - draft_model_input = main_model_input + ( + draft_model_input, + draft_next_token_ids, + accepted_req_idx, + ) = self._build_eagle_accepted_draft_input( + main_model_input=main_model_input, + main_model_output=main_model_output, + next_token_ids=next_token_ids, + mtp_accept_len=mtp_accept_len, + b_req_mtp_start_loc=b_req_mtp_start_loc, + ) draft_model_output = main_model_output - draft_next_token_ids = next_token_ids all_next_token_ids = [] - all_next_token_ids.append(next_token_ids) - # process the draft model output + all_next_token_ids.append(draft_next_token_ids) + + mtp_size = self.mtp_step + 1 + main_mem_indexes = main_model_input.mem_indexes.view(num_reqs, mtp_size) + eagle_mem_indexes_by_req = eagle_mem_indexes.view(self.mtp_step, num_reqs).transpose(0, 1).contiguous() + mem_index_plan = torch.cat([main_mem_indexes, eagle_mem_indexes_by_req], dim=1) + accepted_offsets = mtp_accept_len.long() - 1 + req_offsets = torch.arange(num_reqs, dtype=torch.long, device=mtp_accept_len.device) + for _step in range(self.mtp_step): draft_model_input.input_ids = draft_next_token_ids - draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens + if _step > 0: + draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens + draft_model_input.mem_indexes = mem_index_plan[req_offsets, accepted_offsets + _step] # spec decode: MTP draft_model_idx = _step % self.num_mtp_models draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output) draft_model_input.b_seq_len += 1 draft_model_input.max_kv_seq_len += 1 - eagle_mem_indexes_i = eagle_mem_indexes[_step * num_reqs : (_step + 1) * num_reqs] - draft_model_input.mem_indexes = torch.cat( - [draft_model_input.mem_indexes.view(-1, self.mtp_step + 1)[:, 1:], eagle_mem_indexes_i.view(-1, 1)], - dim=1, - ).view(-1) all_next_token_ids.append(draft_next_token_ids) all_next_token_ids = torch.stack(all_next_token_ids, dim=1) # [batch_size, mtp_step + 1] - mtp_scatter_next_token_ids( - req_to_next_token_ids=self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids, - b_req_mtp_start_loc=b_req_mtp_start_loc, + self._scatter_accepted_next_token_ids( + accepted_req_idx=accepted_req_idx, all_next_token_ids=all_next_token_ids, - b_req_idx=main_model_input.b_req_idx, - mtp_accept_len=mtp_accept_len, ) return eagle_mem_indexes_cpu From 75514ecbc57cc36672c0580784b84ca2620b6846 Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 4 Jun 2026 13:27:03 +0800 Subject: [PATCH 05/10] feat(qwen3_5_mtp): Qwen3.5 / Qwen3.5-MoE MTP draft models Add the MTP draft model packages and register them: - qwen3_5_mtp: a forced single full-attn-layer draft model, with the MTP pre-layer infer (embed/hidden norm + fc fusion) and pre/post + transformer-layer weight loaders reading the mtp.* namespace. - qwen3_5_moe_mtp: the MoE variant draft weight loaders + model. - register qwen3_5 / qwen3_5_moe MTP draft models with per-block draft_idx, plus the qwen3_5 verify infer_struct. Unit tests scaffold the MTP draft layer and the hybrid verify forward. --- lightllm/models/qwen3_5/infer_struct.py | 35 +++++ lightllm/models/qwen3_5_moe_mtp/__init__.py | 3 + .../qwen3_5_moe_mtp/layer_weights/__init__.py | 5 + .../layer_weights/transformer_layer_weight.py | 142 ++++++++++++++++++ lightllm/models/qwen3_5_moe_mtp/model.py | 8 + lightllm/models/qwen3_5_mtp/__init__.py | 0 .../qwen3_5_mtp/layer_infer/__init__.py | 0 .../layer_infer/pre_layer_infer.py | 41 +++++ .../qwen3_5_mtp/layer_weights/__init__.py | 0 .../pre_and_post_layer_weight.py | 46 ++++++ .../layer_weights/transformer_layer_weight.py | 70 +++++++++ lightllm/models/qwen3_5_mtp/model.py | 106 +++++++++++++ .../qwen3_5/test_hybrid_verify_forward.py | 18 +++ .../models/qwen3_5/test_mtp_draft_layer.py | 16 ++ 14 files changed, 490 insertions(+) create mode 100644 lightllm/models/qwen3_5_moe_mtp/__init__.py create mode 100644 lightllm/models/qwen3_5_moe_mtp/layer_weights/__init__.py create mode 100644 lightllm/models/qwen3_5_moe_mtp/layer_weights/transformer_layer_weight.py create mode 100644 lightllm/models/qwen3_5_moe_mtp/model.py create mode 100644 lightllm/models/qwen3_5_mtp/__init__.py create mode 100644 lightllm/models/qwen3_5_mtp/layer_infer/__init__.py create mode 100644 lightllm/models/qwen3_5_mtp/layer_infer/pre_layer_infer.py create mode 100644 lightllm/models/qwen3_5_mtp/layer_weights/__init__.py create mode 100644 lightllm/models/qwen3_5_mtp/layer_weights/pre_and_post_layer_weight.py create mode 100644 lightllm/models/qwen3_5_mtp/layer_weights/transformer_layer_weight.py create mode 100644 lightllm/models/qwen3_5_mtp/model.py create mode 100644 unit_tests/models/qwen3_5/test_hybrid_verify_forward.py create mode 100644 unit_tests/models/qwen3_5/test_mtp_draft_layer.py diff --git a/lightllm/models/qwen3_5/infer_struct.py b/lightllm/models/qwen3_5/infer_struct.py index d23475c1cf..c96b982359 100644 --- a/lightllm/models/qwen3_5/infer_struct.py +++ b/lightllm/models/qwen3_5/infer_struct.py @@ -16,4 +16,39 @@ def init_some_extra_state(self, model): mtp_step = get_env_start_args().mtp_step self.b_buffer_idx = self.b_req_idx * (mtp_step + 1) + self.b_mtp_index + # conv buffer is now ONE widened slot per request (indexed by req_idx), + # dropping the *(S+1) + mtp_index addressing used by the SSM block. + self.b_conv_buffer_idx = self.b_req_idx + # MTP verify batch: decode-mode, S+1 expanded, and gated on the + # per-real-request accept tensor that decode_mtp threads in. Gating on + # b_num_accepted_tokens (vs only b_mtp_index, which is set for any decode) + # distinguishes the main-model verify forward from draft/plain decode. + self.is_mtp_verify = ( + (mtp_step > 0) + and (not self.is_prefill) + and (self.b_mtp_index is not None) + and (self.b_num_accepted_tokens is not None) + ) + self.b_gdn_verify_cu_seqlens = None + self.b_ssm_index_rows = None + # b_num_accepted_tokens is threaded onto the infer_state from ModelInput by + # _create_inferstate (mirrors b_mtp_index) BEFORE this runs; nothing to do here. + if self.is_mtp_verify: + step = mtp_step + 1 + n_real = self.b_req_idx.shape[0] // step + self.b_gdn_verify_cu_seqlens = torch.arange( + 0, (n_real + 1) * step, step, dtype=torch.int32, device=self.b_req_idx.device + ) + req_first = self.b_req_idx.view(n_real, step)[:, 0] + base = (req_first * step).view(n_real, 1) + self.b_ssm_index_rows = base + torch.arange(step, device=base.device, dtype=base.dtype).view(1, step) + assert self.b_ssm_index_rows.shape == (n_real, step) + # The spec conv kernel is per-SEQUENCE (one program per real request), + # indexed by conv_state_indices[idx_seq] with idx_seq in [0, n_real), + # aligned 1:1 with b_gdn_verify_cu_seqlens / b_num_accepted_tokens. The + # default b_conv_buffer_idx = b_req_idx has the expanded length n_real*step, + # which launches n_real*step conv programs and reads num_accepted/ + # query_start_loc out of bounds for idx_seq >= n_real, corrupting the + # committed conv slot. Narrow it to one widened conv slot per request. + self.b_conv_buffer_idx = req_first return diff --git a/lightllm/models/qwen3_5_moe_mtp/__init__.py b/lightllm/models/qwen3_5_moe_mtp/__init__.py new file mode 100644 index 0000000000..c8885f8869 --- /dev/null +++ b/lightllm/models/qwen3_5_moe_mtp/__init__.py @@ -0,0 +1,3 @@ +from lightllm.models.qwen3_5_moe_mtp.model import Qwen3_5MoeMTPModel + +__all__ = ["Qwen3_5MoeMTPModel"] diff --git a/lightllm/models/qwen3_5_moe_mtp/layer_weights/__init__.py b/lightllm/models/qwen3_5_moe_mtp/layer_weights/__init__.py new file mode 100644 index 0000000000..dcad1087d4 --- /dev/null +++ b/lightllm/models/qwen3_5_moe_mtp/layer_weights/__init__.py @@ -0,0 +1,5 @@ +from lightllm.models.qwen3_5_moe_mtp.layer_weights.transformer_layer_weight import ( + Qwen3_5MoeMTPTransformerLayerWeight, +) + +__all__ = ["Qwen3_5MoeMTPTransformerLayerWeight"] diff --git a/lightllm/models/qwen3_5_moe_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_5_moe_mtp/layer_weights/transformer_layer_weight.py new file mode 100644 index 0000000000..80658115ab --- /dev/null +++ b/lightllm/models/qwen3_5_moe_mtp/layer_weights/transformer_layer_weight.py @@ -0,0 +1,142 @@ +from lightllm.common.basemodel.layer_weights.meta_weights import ( + COLMMWeight, + FusedMoeWeight, + QKVROWNMMWeight, + ROWMMWeight, +) +from lightllm.models.qwen3_5_moe.layer_weights.transformer_layer_weight import ( + Qwen35MOETransformerLayerWeight, +) +from lightllm.utils.envs_utils import get_env_start_args + + +class Qwen3_5MoeMTPTransformerLayerWeight(Qwen35MOETransformerLayerWeight): + _MAIN_PREFIX = "model.layers." + _MTP_PREFIX = "mtp.layers." + + def _retarget(self, name): + if name is None: + return None + return name.replace(self._MAIN_PREFIX, self._MTP_PREFIX, 1) + + def _init_weight_names(self): + super()._init_weight_names() + self._q_weight_name = self._retarget(self._q_weight_name) + self._q_norm_name = self._retarget(self._q_norm_name) + self._q_bias_name = self._retarget(self._q_bias_name) + self._k_weight_name = self._retarget(self._k_weight_name) + self._k_norm_name = self._retarget(self._k_norm_name) + self._k_bias_name = self._retarget(self._k_bias_name) + self._v_weight_name = self._retarget(self._v_weight_name) + self._v_bias_name = self._retarget(self._v_bias_name) + self._kv_weight_name = self._retarget(self._kv_weight_name) + self._kv_bias_name = self._retarget(self._kv_bias_name) + self._o_weight_name = self._retarget(self._o_weight_name) + self._o_bias_name = self._retarget(self._o_bias_name) + self._att_norm_weight_name = self._retarget(self._att_norm_weight_name) + self._att_norm_bias_name = self._retarget(self._att_norm_bias_name) + self._ffn_norm_weight_name = self._retarget(self._ffn_norm_weight_name) + self._ffn_norm_bias_name = self._retarget(self._ffn_norm_bias_name) + + def _init_qkv(self): + in_dim = self.n_embed + q_out_dim = self.q_head_num_ * self.head_dim + self.qkv_proj = QKVROWNMMWeight( + in_dim=in_dim, + q_head_num=self.q_head_num_, + kv_head_num=self.k_head_num_, + head_dim=self.head_dim, + weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name], + data_type=self.data_type_, + bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name], + quant_method=self.get_quant_method("qkv_proj"), + ) + self._o_gate_weight_name = f"{self._MTP_PREFIX}{self.layer_num_}.self_attn.o_gate_proj.weight" + self._o_gate_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[q_out_dim], + weight_names=[self._o_gate_weight_name], + data_type=self.data_type_, + bias_names=None, + quant_method=self.get_quant_method("o_gate_proj"), + ) + + def _init_moe(self): + moe_intermediate_size = self.network_config_["moe_intermediate_size"] + self.moe_gate = ROWMMWeight( + in_dim=self.network_config_["hidden_size"], + out_dims=[self.n_routed_experts], + weight_names=f"{self._MTP_PREFIX}{self.layer_num_}.mlp.gate.weight", + data_type=self.data_type_, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + self.experts = FusedMoeWeight( + gate_proj_name="gate_proj", + down_proj_name="down_proj", + up_proj_name="up_proj", + e_score_correction_bias_name="", + weight_prefix=f"{self._MTP_PREFIX}{self.layer_num_}.mlp.experts", + n_routed_experts=self.n_routed_experts, + hidden_size=self.network_config_["hidden_size"], + moe_intermediate_size=moe_intermediate_size, + data_type=self.data_type_, + quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), + layer_num=self.layer_num_, + network_config=self.network_config_, + ) + self._init_gated_ffn() + + def _init_gated_ffn(self): + hidden_size = self.network_config_["hidden_size"] + if "shared_expert_intermediate_size" not in self.network_config_: + return + + prefix = f"{self._MTP_PREFIX}{self.layer_num_}.mlp.shared_expert" + inter_size = self.network_config_["shared_expert_intermediate_size"] + if get_env_start_args().enable_ep_moe: + self.gate_up_proj = ROWMMWeight( + in_dim=hidden_size, + out_dims=[inter_size, inter_size], + weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], + data_type=self.data_type_, + quant_method=self.get_quant_method("gate_up_proj"), + tp_rank=0, + tp_world_size=1, + ) + self.down_proj = COLMMWeight( + in_dim=inter_size, + out_dims=[hidden_size], + weight_names=f"{prefix}.down_proj.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("down_proj"), + tp_rank=0, + tp_world_size=1, + ) + else: + self.gate_up_proj = ROWMMWeight( + in_dim=hidden_size, + out_dims=[inter_size, inter_size], + weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], + data_type=self.data_type_, + quant_method=self.get_quant_method("gate_up_proj"), + ) + self.down_proj = COLMMWeight( + in_dim=inter_size, + out_dims=[hidden_size], + weight_names=f"{prefix}.down_proj.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("down_proj"), + ) + + self.ffn_gate = ROWMMWeight( + in_dim=hidden_size, + out_dims=[1], + weight_names=f"{self._MTP_PREFIX}{self.layer_num_}.mlp.shared_expert_gate.weight", + data_type=self.data_type_, + bias_names=None, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) diff --git a/lightllm/models/qwen3_5_moe_mtp/model.py b/lightllm/models/qwen3_5_moe_mtp/model.py new file mode 100644 index 0000000000..022864f6b3 --- /dev/null +++ b/lightllm/models/qwen3_5_moe_mtp/model.py @@ -0,0 +1,8 @@ +from lightllm.models.qwen3_5_mtp.model import Qwen3_5MTPModel +from lightllm.models.qwen3_5_moe_mtp.layer_weights.transformer_layer_weight import ( + Qwen3_5MoeMTPTransformerLayerWeight, +) + + +class Qwen3_5MoeMTPModel(Qwen3_5MTPModel): + transformer_weight_class = Qwen3_5MoeMTPTransformerLayerWeight diff --git a/lightllm/models/qwen3_5_mtp/__init__.py b/lightllm/models/qwen3_5_mtp/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/qwen3_5_mtp/layer_infer/__init__.py b/lightllm/models/qwen3_5_mtp/layer_infer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/qwen3_5_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/qwen3_5_mtp/layer_infer/pre_layer_infer.py new file mode 100644 index 0000000000..1ac25662d7 --- /dev/null +++ b/lightllm/models/qwen3_5_mtp/layer_infer/pre_layer_infer.py @@ -0,0 +1,41 @@ +import torch + +from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer +from lightllm.models.qwen3_5_mtp.layer_weights.pre_and_post_layer_weight import Qwen3_5MTPPreAndPostLayerWeight +from lightllm.models.llama.infer_struct import LlamaInferStateInfo + + +class Qwen3_5MTPPreLayerInfer(Qwen3VLMultimodalPreLayerInfer): + + def __init__(self, network_config): + super().__init__(network_config) + self.eps_ = network_config["rms_norm_eps"] + self.hidden_size = network_config["hidden_size"] + return + + def _mtp_fuse( + self, + input_embdings: torch.Tensor, + infer_state: LlamaInferStateInfo, + layer_weight: Qwen3_5MTPPreAndPostLayerWeight, + ) -> torch.Tensor: + tgt_embdings = infer_state.mtp_draft_input_hiddens + assert ( + input_embdings.shape[0] == tgt_embdings.shape[0] + ), f"shape {input_embdings.shape} != shape {tgt_embdings.shape}" + + layer_weight.enorm_weight_(input=input_embdings, eps=self.eps_, out=input_embdings) + layer_weight.hnorm_weight_(input=tgt_embdings, eps=self.eps_, out=tgt_embdings) + cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) + + return layer_weight.eh_proj_weight_.mm(cat_embdings) + + def context_forward( + self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Qwen3_5MTPPreAndPostLayerWeight + ): + input_embdings = super().context_forward(input_ids, infer_state, layer_weight) + return self._mtp_fuse(input_embdings, infer_state, layer_weight) + + def token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Qwen3_5MTPPreAndPostLayerWeight): + input_embdings = super().token_forward(input_ids, infer_state, layer_weight) + return self._mtp_fuse(input_embdings, infer_state, layer_weight) diff --git a/lightllm/models/qwen3_5_mtp/layer_weights/__init__.py b/lightllm/models/qwen3_5_mtp/layer_weights/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/qwen3_5_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_5_mtp/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 0000000000..d899e784bf --- /dev/null +++ b/lightllm/models/qwen3_5_mtp/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,46 @@ +from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + EmbeddingWeight, + LMHeadWeight, + NoTpGEMMANormWeight, + ROWMMWeight, +) +from lightllm.common.quantization import Quantcfg + + +class Qwen3_5MTPPreAndPostLayerWeight(PreAndPostLayerWeight): + + def __init__(self, data_type, network_config, quant_cfg: Quantcfg): + super().__init__(data_type, network_config) + self.quant_cfg: Quantcfg = quant_cfg + hidden_size = network_config["hidden_size"] + + self.eh_proj_weight_ = ROWMMWeight( + in_dim=hidden_size * 2, + out_dims=[hidden_size], + weight_names="mtp.fc.weight", + data_type=self.data_type_, + quant_method=self.quant_cfg.get_quant_method(0, "eh_proj"), + tp_rank=0, + tp_world_size=1, + ) + self.enorm_weight_ = NoTpGEMMANormWeight( + dim=hidden_size, + weight_name="mtp.pre_fc_norm_embedding.weight", + data_type=self.data_type_, + ) + self.hnorm_weight_ = NoTpGEMMANormWeight( + dim=hidden_size, + weight_name="mtp.pre_fc_norm_hidden.weight", + data_type=self.data_type_, + ) + self.final_norm_weight_ = NoTpGEMMANormWeight( + dim=hidden_size, + weight_name="mtp.norm.weight", + data_type=self.data_type_, + ) + + # Shared with the main Qwen3.5 model, injected by the model class (not loaded here). + self.wte_weight_: EmbeddingWeight = None + self.lm_head_weight_: LMHeadWeight = None + return diff --git a/lightllm/models/qwen3_5_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_5_mtp/layer_weights/transformer_layer_weight.py new file mode 100644 index 0000000000..b49d594d45 --- /dev/null +++ b/lightllm/models/qwen3_5_mtp/layer_weights/transformer_layer_weight.py @@ -0,0 +1,70 @@ +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, QKVROWNMMWeight +from lightllm.models.qwen3_5.layer_weights.transformer_layer_weight import ( + Qwen35TransformerLayerWeight, +) +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class Qwen3_5MTPTransformerLayerWeight(Qwen35TransformerLayerWeight): + + _MAIN_PREFIX = "model.layers." + _MTP_PREFIX = "mtp.layers." + + def _retarget(self, name): + if name is None: + return None + return name.replace(self._MAIN_PREFIX, self._MTP_PREFIX, 1) + + def _init_weight_names(self): + super()._init_weight_names() + # Retarget all main-model layer key names to the mtp.* namespace. + self._q_weight_name = self._retarget(self._q_weight_name) + self._q_norm_name = self._retarget(self._q_norm_name) + self._q_bias_name = self._retarget(self._q_bias_name) + self._k_weight_name = self._retarget(self._k_weight_name) + self._k_norm_name = self._retarget(self._k_norm_name) + self._k_bias_name = self._retarget(self._k_bias_name) + self._v_weight_name = self._retarget(self._v_weight_name) + self._v_bias_name = self._retarget(self._v_bias_name) + self._kv_weight_name = self._retarget(self._kv_weight_name) + self._kv_bias_name = self._retarget(self._kv_bias_name) + self._o_weight_name = self._retarget(self._o_weight_name) + self._o_bias_name = self._retarget(self._o_bias_name) + self._att_norm_weight_name = self._retarget(self._att_norm_weight_name) + self._att_norm_bias_name = self._retarget(self._att_norm_bias_name) + self._ffn_norm_weight_name = self._retarget(self._ffn_norm_weight_name) + self._ffn_norm_bias_name = self._retarget(self._ffn_norm_bias_name) + # MLP (dense) projection names retargeted by Qwen35TransformerLayerWeight. + self._gate_weight_name = self._retarget(self._gate_weight_name) + self._gate_bias_name = self._retarget(self._gate_bias_name) + self._up_weight_name = self._retarget(self._up_weight_name) + self._up_bias_name = self._retarget(self._up_bias_name) + self._gate_up_weight_name = self._retarget(self._gate_up_weight_name) + self._gate_up_bias_name = self._retarget(self._gate_up_bias_name) + self._down_weight_name = self._retarget(self._down_weight_name) + self._down_bias_name = self._retarget(self._down_bias_name) + + def _init_qkv(self): + in_dim = self.n_embed + q_out_dim = self.q_head_num_ * self.head_dim + self.qkv_proj = QKVROWNMMWeight( + in_dim=in_dim, + q_head_num=self.q_head_num_, + kv_head_num=self.k_head_num_, + head_dim=self.head_dim, + weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name], + data_type=self.data_type_, + bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name], + quant_method=self.get_quant_method("qkv_proj"), + ) + self._o_gate_weight_name = f"{self._MTP_PREFIX}{self.layer_num_}.self_attn.o_gate_proj.weight" + self._o_gate_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[q_out_dim], + weight_names=[self._o_gate_weight_name], + data_type=self.data_type_, + bias_names=None, + quant_method=self.get_quant_method("o_gate_proj"), + ) diff --git a/lightllm/models/qwen3_5_mtp/model.py b/lightllm/models/qwen3_5_mtp/model.py new file mode 100644 index 0000000000..72a75a7b22 --- /dev/null +++ b/lightllm/models/qwen3_5_mtp/model.py @@ -0,0 +1,106 @@ +from typing import List + +from lightllm.models.qwen3_5.model import Qwen3_5TpPartModel +from lightllm.models.qwen3_5.layer_infer.transformer_layer_infer import Qwen35TransformerLayerInfer +from lightllm.models.qwen3_5_mtp.layer_weights.pre_and_post_layer_weight import Qwen3_5MTPPreAndPostLayerWeight +from lightllm.models.qwen3_5_mtp.layer_weights.transformer_layer_weight import Qwen3_5MTPTransformerLayerWeight +from lightllm.models.qwen3_5_mtp.layer_infer.pre_layer_infer import Qwen3_5MTPPreLayerInfer +from lightllm.common.basemodel import TpPartBaseModel +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class Qwen3_5MTPModel(Qwen3_5TpPartModel): + + pre_and_post_weight_class = Qwen3_5MTPPreAndPostLayerWeight + pre_layer_infer_class = Qwen3_5MTPPreLayerInfer + transformer_weight_class = Qwen3_5MTPTransformerLayerWeight + transformer_layer_infer_class = Qwen35TransformerLayerInfer + + def __init__(self, kvargs: dict): + self._pre_init(kvargs) + super().__init__(kvargs) + return + + def _pre_init(self, kvargs: dict): + self.main_model: TpPartBaseModel = kvargs.pop("main_model") + self.mtp_previous_draft_models: List[TpPartBaseModel] = kvargs.pop("mtp_previous_draft_models") + return + + def _init_config(self): + super()._init_config() + self.config["full_attention_interval"] = 1 + self.config["num_hidden_layers"] = 1 + self.config["n_layer"] = 1 + return + + def _init_some_value(self): + super()._init_some_value() + self.layers_num = 1 + return + + def _init_custom(self): + self._cos_cached = self.main_model._cos_cached + self._sin_cached = self.main_model._sin_cached + return + + def _init_req_manager(self): + self.req_manager = self.main_model.req_manager + return + + def _init_mem_manager(self): + self.mem_manager = self.main_model.mem_manager + return + + def _init_weights(self, start_layer_index=None): + assert start_layer_index is None + self.pre_post_weight = self.pre_and_post_weight_class( + self.data_type, network_config=self.config, quant_cfg=self.quant_cfg + ) + self.trans_layers_weight = [ + self.transformer_weight_class( + i, + self.data_type, + network_config=self.config, + quant_cfg=self.quant_cfg, + ) + for i in range(0, self.config["n_layer"]) + ] + # Shared with the main Qwen3.5 model (mtp_use_dedicated_embeddings: false). + self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_ + self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_ + return + + def _init_infer_layer(self, start_layer_index=None): + assert start_layer_index is None + # Build the single draft layer with layer_num == 0 so that, with + # full_attention_interval == 1, it takes the full-attention (mrope) path. + super()._init_infer_layer(start_layer_index=0) + self._assign_draft_kv_slot() + return + + def _assign_draft_kv_slot(self): + mem_manager = self.main_model.mem_manager + main_full_att = getattr(mem_manager, "main_full_att_layer_num", None) + interval = self.main_model.config["full_attention_interval"] + if main_full_att is None: + # Non-hybrid / unexpected mem_manager: nothing to remap. + return + + draft_idx = len(self.mtp_previous_draft_models) + draft_full_att_layers = getattr(mem_manager, "draft_full_att_layers", None) + if draft_full_att_layers is not None: + assert draft_idx < draft_full_att_layers, ( + f"draft_idx {draft_idx} out of range for draft_full_att_layers " + f"{draft_full_att_layers}; mem_manager not sized for this many MTP draft blocks" + ) + draft_kv_slot = main_full_att + draft_idx + layer_infer = self.layers_infer[0] + layer_infer._draft_kv_slot = draft_kv_slot + layer_infer.layer_num_ = draft_kv_slot * interval + logger.info( + f"Qwen3.5 MTP draft layer assigned dedicated full-attn KV slot {draft_kv_slot} " + f"(layer_num_={layer_infer.layer_num_}, interval={interval}, main_full_att={main_full_att})" + ) + return diff --git a/unit_tests/models/qwen3_5/test_hybrid_verify_forward.py b/unit_tests/models/qwen3_5/test_hybrid_verify_forward.py new file mode 100644 index 0000000000..9b68053963 --- /dev/null +++ b/unit_tests/models/qwen3_5/test_hybrid_verify_forward.py @@ -0,0 +1,18 @@ +import os +import pytest +import torch + +CKPT = os.environ.get("QWEN35_MTP_CKPT", "/mtc/models/Qwen3.5-27B") +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available() or not os.path.isdir(CKPT), + reason="needs CUDA + a qwen3_5 checkpoint (QWEN35_MTP_CKPT or /mtc/models/Qwen3.5-27B)", +) + + +def test_hybrid_mtp_verify_matches_sequential_decode(): + """A verify step over S+1 fully-accepted candidates must produce the same + committed hidden state / next-token logits as sequentially decoding those + tokens through the non-MTP path, across BOTH GDN and full-attn layers. + Full end-to-end equivalence is enforced E2E in Phase 10; this scaffold marks + the per-layer-dispatch contract (design §3.4b).""" + pytest.skip("Implement with the running-model fixture; covered E2E in Phase 10.") diff --git a/unit_tests/models/qwen3_5/test_mtp_draft_layer.py b/unit_tests/models/qwen3_5/test_mtp_draft_layer.py new file mode 100644 index 0000000000..a258428a08 --- /dev/null +++ b/unit_tests/models/qwen3_5/test_mtp_draft_layer.py @@ -0,0 +1,16 @@ +import os +import pytest +import torch + +CKPT = os.environ.get("QWEN35_MTP_CKPT", "/mtc/models/Qwen3.5-27B") +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available() or not os.path.isdir(CKPT), + reason="needs CUDA + a qwen3_5 checkpoint with mtp.* weights", +) + + +def test_draft_single_layer_is_full_attention_with_mrope(): + """Risk #12: a naive inherit gives a GDN layer or standard rope. The draft's + one layer must take the full-attn (mrope) path, NOT a GDN path. Full logits + parity is covered E2E in Phase 10; this marks the contract.""" + pytest.skip("Implement with checkpoint fixture; logits parity covered E2E in Phase 10.") From e49c09269c6b0b1a24e91ea6dd6f4ab29fc34d60 Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 4 Jun 2026 22:57:18 +0800 Subject: [PATCH 06/10] fix(qwen3next): persist mtp full-attn cpu cache slots --- .../linear_att_cpu_cache_copy.py | 7 +- .../operator/linear_att.py | 26 ++- .../linear_att_cache_manager/config_objs.py | 21 +- lightllm/models/qwen3next/model.py | 18 +- lightllm/utils/kv_cache_utils.py | 2 + test/acc/cpu_cache_roundtrip_test.py | 86 +++++++ ...st_linear_att_mtp_cpu_cache_persistence.py | 219 ++++++++++++++++++ 7 files changed, 355 insertions(+), 24 deletions(-) create mode 100644 test/acc/cpu_cache_roundtrip_test.py create mode 100644 unit_tests/common/test_linear_att_mtp_cpu_cache_persistence.py diff --git a/lightllm/common/basemodel/triton_kernel/linear_att_cpu_cache_copy.py b/lightllm/common/basemodel/triton_kernel/linear_att_cpu_cache_copy.py index 37b27cadb2..1251dddc33 100644 --- a/lightllm/common/basemodel/triton_kernel/linear_att_cpu_cache_copy.py +++ b/lightllm/common/basemodel/triton_kernel/linear_att_cpu_cache_copy.py @@ -193,11 +193,7 @@ def copy_kv_buffer_to_cpu_cache( cpu_kv_ssm_tail_dim = cpu_kv_ssm_state.shape[-1] full_att_layer_num = gpu_kv_full_att_state.shape[-2] - assert ( - full_att_layer_num - == (linear_config.all_layer_num // linear_config.full_attention_interval) - == (linear_config.all_layer_num - linear_config.linear_layer_num) - ) + assert full_att_layer_num == linear_config.get_persisted_full_att_layer_num() assert gpu_full_att_tail_dim == cpu_cache_full_att.shape[-1] assert cpu_cache_conv.shape[-1] == cpu_kv_conv_state.shape[-1] assert cpu_cache_ssm.shape[-1] == cpu_kv_ssm_state.shape[-1] @@ -428,6 +424,7 @@ def copy_cpu_cache_to_kv_buffer( cpu_kv_ssm_tail_dim = cpu_kv_ssm_state.shape[-1] full_att_layer_num = gpu_full_att_kv_state.shape[-2] + assert full_att_layer_num == linear_config.get_persisted_full_att_layer_num() assert gpu_full_att_tail_dim == cpu_cache_full_att.shape[-1] assert cpu_cache_conv.shape[-1] == cpu_kv_conv_state.shape[-1] assert cpu_cache_ssm.shape[-1] == cpu_kv_ssm_state.shape[-1] diff --git a/lightllm/common/kv_cache_mem_manager/operator/linear_att.py b/lightllm/common/kv_cache_mem_manager/operator/linear_att.py index 586706c8e1..e3ae9493c7 100644 --- a/lightllm/common/kv_cache_mem_manager/operator/linear_att.py +++ b/lightllm/common/kv_cache_mem_manager/operator/linear_att.py @@ -24,6 +24,16 @@ def __init__(self, mem_manager): super().__init__(mem_manager) self.linear_config = LinearAttCacheConfig.load_from_args() + @staticmethod + def _get_persisted_full_att_layer_num(mem_manager) -> int: + persisted_full_att = getattr(mem_manager, "persisted_full_att_layer_num", None) + if persisted_full_att is None: + main_full_att = getattr(mem_manager, "main_full_att_layer_num", mem_manager.kv_buffer.shape[0]) + draft_full_att = getattr(mem_manager, "draft_full_att_layers", 0) + persisted_full_att = main_full_att + draft_full_att + assert 0 < persisted_full_att <= mem_manager.kv_buffer.shape[0] + return int(persisted_full_att) + def load_cpu_cache_to_gpu( self, mem_indexes: torch.Tensor, @@ -76,16 +86,14 @@ def load_cpu_cache_to_gpu( copy_cpu_cache_to_kv_buffer, ) - # Persist/restore ONLY the main model's full-attn slice. The kv buffer is widened by - # dedicated MTP draft slots [main_full_att, main_full_att + draft) (speculative KV that - # must never touch the CPU/disk cache), so slice them off here. - main_full_att = getattr(mem_manager, "main_full_att_layer_num", mem_manager.kv_buffer.shape[0]) + # Restore the persisted full-attn slice: main slots followed by MTP draft slots. + persisted_full_att = self._get_persisted_full_att_layer_num(mem_manager) copy_cpu_cache_to_kv_buffer( mem_indexes=mem_indexes, big_page_buffer_ids=big_page_buffer_ids_gpu, page_indexes=page_indexes, - gpu_full_att_kv_state=mem_manager.kv_buffer[:main_full_att], + gpu_full_att_kv_state=mem_manager.kv_buffer[:persisted_full_att], cpu_kv_conv_state=mem_manager.linear_att_big_page_buffers.conv_state_cache.buffer, cpu_kv_ssm_state=mem_manager.linear_att_big_page_buffers.ssm_state_cache.buffer, cpu_cache_tensor=cpu_cache_client.cpu_kv_cache_tensor, @@ -174,17 +182,15 @@ def offload_gpu_kv_to_cpu_cache( copy_kv_buffer_to_cpu_cache, ) - # Persist ONLY the main model's full-attn slice. The kv buffer is widened by dedicated - # MTP draft slots [main_full_att, main_full_att + draft) (speculative KV that must never - # be persisted to the CPU/disk cache), so slice them off here. - main_full_att = getattr(mem_manager, "main_full_att_layer_num", mem_manager.kv_buffer.shape[0]) + # Persist the full-attn slice used for prefix reuse: main slots followed by MTP draft slots. + persisted_full_att = self._get_persisted_full_att_layer_num(mem_manager) copy_kv_buffer_to_cpu_cache( mem_indexes=mem_indexes, page_indexes=page_indexes, page_readies=page_readies, big_page_buffer_ids=big_page_buffer_ids_gpu, - gpu_kv_full_att_state=mem_manager.kv_buffer[:main_full_att], + gpu_kv_full_att_state=mem_manager.kv_buffer[:persisted_full_att], cpu_kv_conv_state=mem_manager.linear_att_big_page_buffers.conv_state_cache.buffer, cpu_kv_ssm_state=mem_manager.linear_att_big_page_buffers.ssm_state_cache.buffer, cpu_cache_tensor=cpu_cache_client.cpu_kv_cache_tensor, diff --git a/lightllm/common/linear_att_cache_manager/config_objs.py b/lightllm/common/linear_att_cache_manager/config_objs.py index ca6415e16f..d4fb284126 100644 --- a/lightllm/common/linear_att_cache_manager/config_objs.py +++ b/lightllm/common/linear_att_cache_manager/config_objs.py @@ -8,6 +8,15 @@ logger = init_logger(__name__) +def get_mtp_draft_full_att_layer_num(args) -> int: + mtp_mode = getattr(args, "mtp_mode", None) + if mtp_mode == "eagle_with_att": + return 1 + if mtp_mode == "vanilla_with_att": + return getattr(args, "mtp_step", 0) + return 0 + + @dataclasses.dataclass class LinearAttCacheConfig: tp_world_size: int @@ -28,10 +37,19 @@ class LinearAttCacheConfig: ssm_state_dtype: torch.dtype full_attention_interval: int all_layer_num: int # 包括 linear att 和 full att 的层加起来的层数 + draft_full_att_layer_num: int = 0 def get_conv_dim(self): return self.head_linear_k_dim * self.num_linear_k_heads * 2 + self.head_linear_v_dim * self.num_linear_v_heads + def get_main_full_att_layer_num(self): + main_full_att_layer_num = self.all_layer_num - self.linear_layer_num + assert main_full_att_layer_num == self.all_layer_num // self.full_attention_interval + return main_full_att_layer_num + + def get_persisted_full_att_layer_num(self): + return self.get_main_full_att_layer_num() + self.draft_full_att_layer_num + def get_persisted_conv_state_shape(self): # NARROW shape used for the CPU/disk persisted page and ALL byte math. # Persisted state is always the committed (narrow) sliding window. @@ -71,7 +89,7 @@ def get_cpu_cache_full_att_bytes(self): ) assert big_page_token_num == get_env_start_args().cpu_cache_token_page_size full_att_bytes = 2 * self.full_att_all_num_kv_heads * self.full_att_head_dim * self.full_att_dtype.itemsize - a = full_att_bytes * (self.all_layer_num - self.linear_layer_num) * big_page_token_num + a = full_att_bytes * self.get_persisted_full_att_layer_num() * big_page_token_num return a def get_cpu_cache_conv_bytes(self): @@ -116,4 +134,5 @@ def load_from_args() -> "LinearAttCacheConfig": ssm_state_dtype=get_torch_dtype(args.linear_att_ssm_data_type), full_attention_interval=llm_config["full_attention_interval"], all_layer_num=n_layer, + draft_full_att_layer_num=get_mtp_draft_full_att_layer_num(args), ) diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index 0de85cdd45..f0940ba0f8 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -16,7 +16,10 @@ from lightllm.common.kv_cache_mem_manager.qwen3next_mem_manager import Qwen3NextMemManager from lightllm.server.core.objs.start_args_type import StartArgs from lightllm.common.req_manager import ReqManagerForMamba -from lightllm.common.linear_att_cache_manager.config_objs import LinearAttCacheConfig +from lightllm.common.linear_att_cache_manager.config_objs import ( + LinearAttCacheConfig, + get_mtp_draft_full_att_layer_num, +) logger = init_logger(__name__) @@ -58,6 +61,7 @@ def _init_mem_manager(self): assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 start_args: StartArgs = get_env_start_args() ssm_dtype_dict = {"bfloat16": torch.bfloat16, "float32": torch.float32} + draft_full_att_layers = get_mtp_draft_full_att_layer_num(start_args) self.linear_config = LinearAttCacheConfig( tp_world_size=self.tp_world_size_, full_att_all_num_kv_heads=self.config["num_key_value_heads"], @@ -75,14 +79,11 @@ def _init_mem_manager(self): ssm_state_dtype=ssm_dtype_dict[start_args.linear_att_ssm_data_type], full_attention_interval=self.config["full_attention_interval"], all_layer_num=self.config["n_layer"], + draft_full_att_layer_num=draft_full_att_layers, ) - main_full_att = self.linear_config.all_layer_num - self.linear_config.linear_layer_num - draft_full_att_layers = 0 - if start_args.mtp_mode == "eagle_with_att": - draft_full_att_layers = 1 - elif start_args.mtp_mode == "vanilla_with_att": - draft_full_att_layers = start_args.mtp_step + main_full_att = self.linear_config.get_main_full_att_layer_num() + persisted_full_att = self.linear_config.get_persisted_full_att_layer_num() self._main_full_att_layer_num = main_full_att self._draft_full_att_layers = draft_full_att_layers @@ -91,12 +92,13 @@ def _init_mem_manager(self): dtype=self.data_type, num_kv_heads=self.num_kv_heads, head_dim=self.config["head_dim"], - full_att_layer_num=main_full_att + draft_full_att_layers, + full_att_layer_num=persisted_full_att, linear_config=self.linear_config, mem_fraction=self.mem_fraction, ) self.mem_manager.main_full_att_layer_num = main_full_att self.mem_manager.draft_full_att_layers = draft_full_att_layers + self.mem_manager.persisted_full_att_layer_num = persisted_full_att def _init_req_manager(self): create_max_seq_len = 0 diff --git a/lightllm/utils/kv_cache_utils.py b/lightllm/utils/kv_cache_utils.py index 69e4097242..2a089a9bf2 100644 --- a/lightllm/utils/kv_cache_utils.py +++ b/lightllm/utils/kv_cache_utils.py @@ -121,6 +121,8 @@ def calcu_cpu_cache_meta() -> "CpuKVCacheMeta": if args.mtp_mode is not None: # TODO 可能会存在不同mtp模式的精度问题 if is_linear_att_mixed_model(args.model_dir): + # Linear mixed models use one packed byte page; MTP draft full-attn + # slots are accounted in LinearAttCacheConfig.get_cpu_cache_big_page_bytes(). pass else: cpu_cache_meta.layer_num += get_added_mtp_kv_layer_num() diff --git a/test/acc/cpu_cache_roundtrip_test.py b/test/acc/cpu_cache_roundtrip_test.py new file mode 100644 index 0000000000..28ce96f583 --- /dev/null +++ b/test/acc/cpu_cache_roundtrip_test.py @@ -0,0 +1,86 @@ +"""Force the CPU KV-cache offload->restore path and check correctness. + +GSM8K can't exercise the CPU cache (one shared hot prefix, sub-page tails). +This driver builds N distinct, page-aligned, long prompts that overflow the +GPU KV budget so their KV is offloaded to CPU, then re-requests them so they +are restored from CPU. With greedy decoding the round-2 (CPU-restored) output +MUST be token-identical to round-1 (freshly computed). For the MTP build it +also tracks accept-rate (mtp_avg_token_per_step) which would degrade if the +draft full-attn slots were not persisted/restored correctly. +""" +import argparse +import sys +import requests +from concurrent.futures import ThreadPoolExecutor + + +def make_prompts(n, words_per_prompt): + prompts = [] + for i in range(n): + # Distinct, deterministic filler so each prompt is its own radix branch + # and long enough to span several 256-token pages. + filler = " ".join(f"item{i}-{j}" for j in range(words_per_prompt)) + prompts.append( + f"You are given list number {i}. The list is: {filler}. " + f"Question: briefly summarize what list number {i} contains. Answer:" + ) + return prompts + + +def gen(url, prompt, max_tokens): + data = { + "inputs": prompt, + "parameters": { + "temperature": 0.0, + "max_new_tokens": max_tokens, + "stop_sequences": None, + "repetition_penalty": 1.0, + "top_p": 1.0, + "top_k": 1, + }, + } + r = requests.post(url, json=data, timeout=120) + assert r.status_code == 200, f"{r.status_code}: {r.text}" + return r.json()["generated_text"][0] + + +def run_round(url, prompts, max_tokens, parallel): + out = [None] * len(prompts) + with ThreadPoolExecutor(max_workers=parallel) as ex: + futs = {ex.submit(gen, url, p, max_tokens): k for k, p in enumerate(prompts)} + for f in futs: + k = futs[f] + out[k] = f.result() + return out + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--host", default="http://127.0.0.1") + ap.add_argument("--port", type=int, default=8088) + ap.add_argument("--num-prompts", type=int, default=24) + ap.add_argument("--words-per-prompt", type=int, default=400) + ap.add_argument("--max-tokens", type=int, default=32) + ap.add_argument("--parallel", type=int, default=8) + args = ap.parse_args() + + url = f"{args.host}:{args.port}/generate" + prompts = make_prompts(args.num_prompts, args.words_per_prompt) + + print(f"Round 1 (cold compute): {len(prompts)} distinct prompts", flush=True) + r1 = run_round(url, prompts, args.max_tokens, args.parallel) + print("Round 2 (CPU-restored):", flush=True) + r2 = run_round(url, prompts, args.max_tokens, args.parallel) + + mismatches = [i for i in range(len(prompts)) if r1[i] != r2[i]] + print(f"\n=== RESULT ===") + print(f"prompts: {len(prompts)} identical: {len(prompts) - len(mismatches)} mismatches: {len(mismatches)}") + if mismatches: + for i in mismatches[:5]: + print(f" [#{i}] R1={r1[i]!r}\n R2={r2[i]!r}") + sys.exit(1) + print("PASS: round-2 (CPU-restored) output is token-identical to round-1.") + + +if __name__ == "__main__": + main() diff --git a/unit_tests/common/test_linear_att_mtp_cpu_cache_persistence.py b/unit_tests/common/test_linear_att_mtp_cpu_cache_persistence.py new file mode 100644 index 0000000000..fb22f0ed1f --- /dev/null +++ b/unit_tests/common/test_linear_att_mtp_cpu_cache_persistence.py @@ -0,0 +1,219 @@ +from types import SimpleNamespace + +import pytest +import torch + + +def _make_start_args(**overrides): + base = dict( + model_dir="/tmp/qwen3_5", + tp=1, + dp=1, + data_type="bfloat16", + linear_att_ssm_data_type="bfloat16", + mtp_mode=None, + mtp_step=0, + linear_att_page_block_num=2, + linear_att_hash_page_size=4, + cpu_cache_token_page_size=8, + ) + base.update(overrides) + return SimpleNamespace(**base) + + +def _make_model_cfg(): + return { + "model_type": "qwen3_5", + "num_hidden_layers": 64, + "num_key_value_heads": 16, + "head_dim": 128, + "linear_num_key_heads": 16, + "linear_num_value_heads": 48, + "linear_key_head_dim": 128, + "linear_value_head_dim": 128, + "linear_conv_kernel_dim": 4, + "full_attention_interval": 4, + } + + +def _patch_linear_config_args(monkeypatch, args): + import lightllm.common.linear_att_cache_manager.config_objs as config_objs + + monkeypatch.setattr(config_objs, "get_env_start_args", lambda: args) + + +def _make_config(draft_full_att_layer_num=0): + from lightllm.common.linear_att_cache_manager.config_objs import LinearAttCacheConfig + + return LinearAttCacheConfig( + tp_world_size=1, + full_att_all_num_kv_heads=16, + full_att_dtype=torch.bfloat16, + full_att_num_kv_heads=16, + full_att_head_dim=128, + num_linear_k_heads=16, + num_linear_v_heads=48, + head_linear_k_dim=128, + head_linear_v_dim=128, + conv_kernel_size=4, + linear_layer_num=48, + conv_state_dtype=torch.bfloat16, + ssm_state_dtype=torch.bfloat16, + full_attention_interval=4, + all_layer_num=64, + draft_full_att_layer_num=draft_full_att_layer_num, + ) + + +def test_load_from_args_includes_mtp_draft_full_att_layers(monkeypatch): + from lightllm.common.linear_att_cache_manager.config_objs import LinearAttCacheConfig + from transformers.configuration_utils import PretrainedConfig + + args = _make_start_args(mtp_mode="vanilla_with_att", mtp_step=3) + _patch_linear_config_args(monkeypatch, args) + monkeypatch.setattr(PretrainedConfig, "get_config_dict", lambda _model_path: (_make_model_cfg(), None)) + + cfg = LinearAttCacheConfig.load_from_args() + + assert cfg.get_main_full_att_layer_num() == 16 + assert cfg.draft_full_att_layer_num == 3 + assert cfg.get_persisted_full_att_layer_num() == 19 + + +def test_cpu_cache_full_att_bytes_include_mtp_draft_layers(monkeypatch): + args = _make_start_args() + _patch_linear_config_args(monkeypatch, args) + main_only = _make_config(draft_full_att_layer_num=0) + with_draft = _make_config(draft_full_att_layer_num=2) + + bytes_per_full_att_layer = ( + args.cpu_cache_token_page_size + * 2 + * main_only.full_att_all_num_kv_heads + * main_only.full_att_head_dim + * main_only.full_att_dtype.itemsize + ) + + assert main_only.get_main_full_att_layer_num() == 16 + assert with_draft.get_persisted_full_att_layer_num() == 18 + assert with_draft.get_cpu_cache_full_att_bytes() == ( + main_only.get_cpu_cache_full_att_bytes() + 2 * bytes_per_full_att_layer + ) + + +def test_linear_operator_persisted_full_att_slice_includes_draft_slots(): + from lightllm.common.kv_cache_mem_manager.operator.linear_att import LinearAttMemOperator + + class MtpMemManager: + main_full_att_layer_num = 16 + draft_full_att_layers = 2 + kv_buffer = torch.empty((18, 1)) + + class MainOnlyMemManager: + main_full_att_layer_num = 16 + kv_buffer = torch.empty((18, 1)) + + class PlainMemManager: + kv_buffer = torch.empty((7, 1)) + + assert LinearAttMemOperator._get_persisted_full_att_layer_num(MtpMemManager()) == 18 + assert LinearAttMemOperator._get_persisted_full_att_layer_num(MainOnlyMemManager()) == 16 + assert LinearAttMemOperator._get_persisted_full_att_layer_num(PlainMemManager()) == 7 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="needs CUDA") +def test_linear_cpu_cache_roundtrips_mtp_draft_full_att_slot(monkeypatch): + from lightllm.common.basemodel.triton_kernel.linear_att_cpu_cache_copy import ( + copy_cpu_cache_to_kv_buffer, + copy_kv_buffer_to_cpu_cache, + ) + from lightllm.common.linear_att_cache_manager.config_objs import LinearAttCacheConfig + + args = _make_start_args( + linear_att_page_block_num=1, + linear_att_hash_page_size=2, + cpu_cache_token_page_size=2, + ) + _patch_linear_config_args(monkeypatch, args) + cfg = LinearAttCacheConfig( + tp_world_size=1, + full_att_all_num_kv_heads=2, + full_att_dtype=torch.float32, + full_att_num_kv_heads=2, + full_att_head_dim=8, + num_linear_k_heads=1, + num_linear_v_heads=1, + head_linear_k_dim=8, + head_linear_v_dim=8, + conv_kernel_size=2, + linear_layer_num=1, + conv_state_dtype=torch.float32, + ssm_state_dtype=torch.float32, + full_attention_interval=2, + all_layer_num=2, + draft_full_att_layer_num=1, + ) + + gpu_kv = torch.arange(2 * 2 * 4 * 8, dtype=torch.float32, device="cuda").reshape(2, 2, 4, 8) + cpu_cache_tensor = torch.zeros( + (1, 1, 1, 1, cfg.get_cpu_cache_big_page_bytes()), + dtype=torch.uint8, + device="cuda", + ) + conv_state = torch.zeros( + (1, cfg.linear_layer_num, cfg.get_conv_dim(), cfg.conv_kernel_size - 1), + dtype=torch.float32, + device="cuda", + ) + ssm_state = torch.zeros( + ( + 1, + cfg.linear_layer_num, + cfg.num_linear_v_heads, + cfg.head_linear_k_dim, + cfg.head_linear_v_dim, + ), + dtype=torch.float32, + device="cuda", + ) + mem_indexes = torch.tensor([0, 1], dtype=torch.int32, device="cuda") + page_indexes = torch.tensor([0], dtype=torch.int32, device="cuda") + page_readies = torch.tensor([False], dtype=torch.bool, device="cuda") + big_page_buffer_ids = torch.tensor([0], dtype=torch.int64, device="cuda") + + copy_kv_buffer_to_cpu_cache( + mem_indexes=mem_indexes, + page_indexes=page_indexes, + page_readies=page_readies, + big_page_buffer_ids=big_page_buffer_ids, + gpu_kv_full_att_state=gpu_kv, + cpu_kv_conv_state=conv_state, + cpu_kv_ssm_state=ssm_state, + cpu_cache_tensor=cpu_cache_tensor, + tp_rank=0, + tp_world_size=1, + big_page_token_num=args.cpu_cache_token_page_size, + linear_config=cfg, + grid_num=1, + ) + + restored_gpu_kv = torch.full_like(gpu_kv, fill_value=-1) + restored_conv = torch.empty_like(conv_state) + restored_ssm = torch.empty_like(ssm_state) + copy_cpu_cache_to_kv_buffer( + mem_indexes=mem_indexes, + big_page_buffer_ids=big_page_buffer_ids, + page_indexes=page_indexes, + gpu_full_att_kv_state=restored_gpu_kv, + cpu_kv_conv_state=restored_conv, + cpu_kv_ssm_state=restored_ssm, + cpu_cache_tensor=cpu_cache_tensor, + tp_rank=0, + tp_world_size=1, + big_page_token_num=args.cpu_cache_token_page_size, + linear_config=cfg, + grid_num=1, + ) + torch.cuda.synchronize() + + torch.testing.assert_close(restored_gpu_kv, gpu_kv) From b5d547646e84d6ddd08613be9a57593abc4451ad Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 5 Jun 2026 13:26:52 +0800 Subject: [PATCH 07/10] refactor(qwen3_5_mtp): drop unused _draft_kv_slot attribute The write-only layer_infer._draft_kv_slot was never read anywhere; the KV-slot mapping is fully expressed via layer_num_ = draft_kv_slot * interval. --- lightllm/models/qwen3_5_mtp/model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lightllm/models/qwen3_5_mtp/model.py b/lightllm/models/qwen3_5_mtp/model.py index 72a75a7b22..400916d2d4 100644 --- a/lightllm/models/qwen3_5_mtp/model.py +++ b/lightllm/models/qwen3_5_mtp/model.py @@ -97,7 +97,6 @@ def _assign_draft_kv_slot(self): ) draft_kv_slot = main_full_att + draft_idx layer_infer = self.layers_infer[0] - layer_infer._draft_kv_slot = draft_kv_slot layer_infer.layer_num_ = draft_kv_slot * interval logger.info( f"Qwen3.5 MTP draft layer assigned dedicated full-attn KV slot {draft_kv_slot} " From 48c15de085f6f60f9980c64b7a9c4082cac594e7 Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 5 Jun 2026 14:42:28 +0800 Subject: [PATCH 08/10] style: fix black formatting and drop unused var for pre-commit --- lightllm/common/basemodel/attention/fa3/fp.py | 4 ++-- .../common/basemodel/attention/fa3/fp8.py | 18 +++++++++----- .../layer_infer/pre_layer_infer.py | 1 - .../pre_and_post_layer_weight.py | 1 - .../mode_backend/chunked_prefill/impl.py | 24 ++++--------------- .../basemodel/test_mtp_decode_cuda_graph.py | 6 +---- 6 files changed, 19 insertions(+), 35 deletions(-) diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index 0d53b44b68..949253c840 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -105,7 +105,7 @@ def _nomarl_prefill_att( k_descale, v_descale = None, None # disable quantization Lq = q.shape[-1] - sm_scale = 1.0 / (Lq**0.5) + sm_scale = 1.0 / (Lq ** 0.5) o = flash_attn_with_kvcache( q=q, k_cache=k.view(k.shape[0], 1, k.shape[1], k.shape[2]), @@ -237,7 +237,7 @@ def _normal_decode_att( k_descale, v_descale = None, None # disable quantization Lq = q.shape[-1] - sm_scale = 1.0 / (Lq**0.5) + sm_scale = 1.0 / (Lq ** 0.5) o = flash_attn_with_kvcache( q=q, k_cache=k.view(k.shape[0], 1, k.shape[1], k.shape[2]), diff --git a/lightllm/common/basemodel/attention/fa3/fp8.py b/lightllm/common/basemodel/attention/fa3/fp8.py index 9a32ef7a9f..e11cdc462e 100644 --- a/lightllm/common/basemodel/attention/fa3/fp8.py +++ b/lightllm/common/basemodel/attention/fa3/fp8.py @@ -44,9 +44,12 @@ def init_state(self): torch.arange(batch_size, device=device), self.infer_state.b_q_seq_len ) # 为了减少推理计算量,在推理外部初始化k_descale和v_descale - self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) - self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) - + self.k_descale = ( + offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + ) + self.v_descale = ( + offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + ) def prefill_att( self, @@ -115,7 +118,6 @@ def init_state(self): super().init_state() self.backend: Fp8Fa3AttBackend = self.backend - device = self.infer_state.input_ids.device batch_size = self.b_att_seq_len.shape[0] mem_manager = self.backend.model.mem_manager @@ -123,8 +125,12 @@ def init_state(self): head_num = mem_manager.head_num # 为了减少推理计算量,在推理外部初始化k_descale和v_descale - self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) - self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + self.k_descale = ( + offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + ) + self.v_descale = ( + offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + ) return diff --git a/lightllm/models/qwen3_5_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/qwen3_5_mtp/layer_infer/pre_layer_infer.py index 1ac25662d7..906a0ab62c 100644 --- a/lightllm/models/qwen3_5_mtp/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen3_5_mtp/layer_infer/pre_layer_infer.py @@ -6,7 +6,6 @@ class Qwen3_5MTPPreLayerInfer(Qwen3VLMultimodalPreLayerInfer): - def __init__(self, network_config): super().__init__(network_config) self.eps_ = network_config["rms_norm_eps"] diff --git a/lightllm/models/qwen3_5_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_5_mtp/layer_weights/pre_and_post_layer_weight.py index d899e784bf..25c56a0d7e 100644 --- a/lightllm/models/qwen3_5_mtp/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen3_5_mtp/layer_weights/pre_and_post_layer_weight.py @@ -9,7 +9,6 @@ class Qwen3_5MTPPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, quant_cfg: Quantcfg): super().__init__(data_type, network_config) self.quant_cfg: Quantcfg = quant_cfg diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 7ac27d81a3..b32c6fd2ad 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -115,11 +115,7 @@ def prefill_normal( model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) - ( - _, - next_token_ids_cpu, - next_token_logprobs_cpu, - ) = self._sample_and_scatter_token( + (_, next_token_ids_cpu, next_token_logprobs_cpu,) = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, b_mtp_index=model_input.b_mtp_index, @@ -162,11 +158,7 @@ def decode_normal( model_input, run_reqs = prepare_decode_inputs(decode_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) - ( - _, - next_token_ids_cpu, - next_token_logprobs_cpu, - ) = self._sample_and_scatter_token( + (_, next_token_ids_cpu, next_token_logprobs_cpu,) = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, b_mtp_index=model_input.b_mtp_index, @@ -204,11 +196,7 @@ def prefill_mtp( model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) - ( - next_token_ids, - next_token_ids_cpu, - next_token_logprobs_cpu, - ) = self._sample_and_scatter_token( + (next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu,) = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, b_mtp_index=model_input.b_mtp_index, @@ -490,11 +478,7 @@ def _draft_decode_eagle( g_infer_state_lock.release() eagle_mem_indexes = eagle_mem_indexes_cpu.cuda(non_blocking=True) - ( - draft_model_input, - draft_next_token_ids, - accepted_req_idx, - ) = self._build_eagle_accepted_draft_input( + (draft_model_input, draft_next_token_ids, accepted_req_idx,) = self._build_eagle_accepted_draft_input( main_model_input=main_model_input, main_model_output=main_model_output, next_token_ids=next_token_ids, diff --git a/unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py b/unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py index d29bb1104e..d3cf7a4821 100644 --- a/unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py +++ b/unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py @@ -286,11 +286,7 @@ def test_build_eagle_accepted_draft_input_narrows_to_accepted_rows(): b_req_mtp_start_loc = torch.tensor([0, 3], dtype=torch.int32) mtp_accept_len = torch.tensor([2, 3], dtype=torch.int32) - ( - draft_input, - accepted_next_tokens, - accepted_req_idx, - ) = backend._build_eagle_accepted_draft_input( + (draft_input, accepted_next_tokens, accepted_req_idx,) = backend._build_eagle_accepted_draft_input( main_model_input=main_input, main_model_output=main_output, next_token_ids=next_token_ids, From 16170f3b090fd79e2e8c140f9dc1d5cabb8736b0 Mon Sep 17 00:00:00 2001 From: sufubao Date: Sun, 7 Jun 2026 15:55:49 +0800 Subject: [PATCH 09/10] style: align formatting with upstream/main and inline mtp accept-len commit - Revert local reformatting to match upstream/main exactly, minimizing PR diff - Inline _commit_mtp_accept_len into decode_mtp (phase-2 ordering preserved) - Drop redundant inline comments --- lightllm/common/basemodel/attention/fa3/fp.py | 32 ++++--------- .../common/basemodel/attention/fa3/fp8.py | 17 ++----- .../common/basemodel/attention/fa3/mla.py | 32 ++++--------- .../triton_kernel/linear_att_copy.py | 15 ------ .../model_infer/mode_backend/base_backend.py | 26 +---------- .../mode_backend/chunked_prefill/impl.py | 46 +++++-------------- 6 files changed, 37 insertions(+), 131 deletions(-) diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index 949253c840..27b6bf05ed 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -1,19 +1,12 @@ import dataclasses import torch -from ..base_att import ( - BaseAttBackend, - BasePrefillAttState, - BaseDecodeAttState, - AttControl, -) +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl from typing import Optional, TYPE_CHECKING from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.sgl_utils import flash_attn_with_kvcache from lightllm.utils.envs_utils import get_env_start_args from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy -from lightllm.common.basemodel.triton_kernel.gen_prefill_params import ( - gen_cumsum_pad0_tensor, -) +from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor class Fa3AttBackend(BaseAttBackend): @@ -28,14 +21,12 @@ def get_page_table_buffer(self): model = self.model if not hasattr(self, "_shared_page_table_buffer"): self._shared_page_table_buffer = [ - torch.empty( - model.graph_max_batch_size * model.graph_max_len_in_batch, - dtype=torch.int32, - ).to(get_current_device_id()), - torch.empty( - model.graph_max_batch_size * model.graph_max_len_in_batch, - dtype=torch.int32, - ).to(get_current_device_id()), + torch.empty(model.graph_max_batch_size * model.graph_max_len_in_batch, dtype=torch.int32).to( + get_current_device_id() + ), + torch.empty(model.graph_max_batch_size * model.graph_max_len_in_batch, dtype=torch.int32).to( + get_current_device_id() + ), ] return self._shared_page_table_buffer @@ -84,12 +75,7 @@ def prefill_att( ) def _nomarl_prefill_att( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - att_control: AttControl, - alloc_func=torch.empty, + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, att_control: AttControl, alloc_func=torch.empty ) -> torch.Tensor: self.backend: Fa3AttBackend = self.backend # for typing diff --git a/lightllm/common/basemodel/attention/fa3/fp8.py b/lightllm/common/basemodel/attention/fa3/fp8.py index e11cdc462e..1902e6b1f0 100644 --- a/lightllm/common/basemodel/attention/fa3/fp8.py +++ b/lightllm/common/basemodel/attention/fa3/fp8.py @@ -44,12 +44,9 @@ def init_state(self): torch.arange(batch_size, device=device), self.infer_state.b_q_seq_len ) # 为了减少推理计算量,在推理外部初始化k_descale和v_descale - self.k_descale = ( - offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) - ) - self.v_descale = ( - offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) - ) + self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + def prefill_att( self, @@ -125,12 +122,8 @@ def init_state(self): head_num = mem_manager.head_num # 为了减少推理计算量,在推理外部初始化k_descale和v_descale - self.k_descale = ( - offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) - ) - self.v_descale = ( - offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) - ) + self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) return diff --git a/lightllm/common/basemodel/attention/fa3/mla.py b/lightllm/common/basemodel/attention/fa3/mla.py index 2ed9ba4112..b740e81928 100644 --- a/lightllm/common/basemodel/attention/fa3/mla.py +++ b/lightllm/common/basemodel/attention/fa3/mla.py @@ -1,19 +1,12 @@ import dataclasses import torch -from ..base_att import ( - BaseAttBackend, - BasePrefillAttState, - BaseDecodeAttState, - AttControl, -) +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl from typing import Optional, TYPE_CHECKING, Tuple from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.sgl_utils import flash_attn_with_kvcache from lightllm.utils.envs_utils import get_env_start_args from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy -from lightllm.common.basemodel.triton_kernel.gen_prefill_params import ( - gen_cumsum_pad0_tensor, -) +from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor from lightllm.utils.sgl_utils import flash_attn_varlen_func @@ -29,14 +22,12 @@ def get_page_table_buffer(self): model = self.model if not hasattr(self, "_shared_page_table_buffer"): self._shared_page_table_buffer = [ - torch.empty( - model.graph_max_batch_size * model.graph_max_len_in_batch, - dtype=torch.int32, - ).to(get_current_device_id()), - torch.empty( - model.graph_max_batch_size * model.graph_max_len_in_batch, - dtype=torch.int32, - ).to(get_current_device_id()), + torch.empty(model.graph_max_batch_size * model.graph_max_len_in_batch, dtype=torch.int32).to( + get_current_device_id() + ), + torch.empty(model.graph_max_batch_size * model.graph_max_len_in_batch, dtype=torch.int32).to( + get_current_device_id() + ), ] return self._shared_page_table_buffer @@ -78,12 +69,7 @@ def prefill_att( ) def _mla_prefill_att( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - att_control: AttControl, - alloc_func=torch.empty, + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, att_control: AttControl, alloc_func=torch.empty ) -> torch.Tensor: self.backend: MlaFa3AttBackend = self.backend # for typing k_nope, k_rope = k diff --git a/lightllm/common/basemodel/triton_kernel/linear_att_copy.py b/lightllm/common/basemodel/triton_kernel/linear_att_copy.py index 5fb98c4daa..22220f4811 100644 --- a/lightllm/common/basemodel/triton_kernel/linear_att_copy.py +++ b/lightllm/common/basemodel/triton_kernel/linear_att_copy.py @@ -46,15 +46,7 @@ def _copy_linear_att_state_to_kv_buffer( accept_len = tl.load(num_accepted_tokens_ptr + cur_batch).to(tl.int64) canonical_off = accept_len - 1 - # --- conv snapshot --- - # conv is a single WIDENED slot keyed by req_idx (asymmetric layout, §3.4). - # The committed NARROW window of byte length conv_narrow_row_bytes sits at - # byte offset canonical_off * itemsize inside each widened row. The flattened - # uint8 tail lays out element [d, w] at d * gpu_conv_row_bytes + w (bytes), - # so the narrow window is strided per row: copy row-by-row. conv_src_slot = cur_req_idx - # gpu_conv_stride_d carries the per-element byte size (itemsize); the narrow - # window starts canonical_off elements into the widened row. conv_off_bytes = canonical_off * gpu_conv_stride_d gpu_conv_base = gpu_conv_ptr + cur_layer * gpu_conv_stride_l + conv_src_slot * gpu_conv_stride_s + conv_off_bytes cpu_conv_base = cpu_kv_conv_ptr + big_page_buffer_idx * cpu_kv_conv_stride_s + cur_layer * cpu_kv_conv_stride_l @@ -65,9 +57,6 @@ def _copy_linear_att_state_to_kv_buffer( conv_data = tl.load(gpu_conv_base + d * gpu_conv_row_bytes + off, mask=mask) tl.store(cpu_conv_base + d * cpu_kv_conv_stride_d + off, conv_data, mask=mask) - # --- ssm snapshot --- - # ssm is an (S+1) BLOCK per request; the committed block slot is - # req_idx * (mtp_step + 1) + canonical_off. ssm_src_slot = (cur_req_idx * (mtp_step + 1) + canonical_off).to(tl.int64) for i in range(tl.cdiv(gpu_ssm_tail_dim, BLOCK)): gpu_start_off = i * BLOCK + tl.arange(0, BLOCK) @@ -98,10 +87,6 @@ def copy_linear_att_state_to_kv_buffer( assert len(b_req_idx) == b_num_accepted_tokens.shape[0] BLOCK = 4096 - # Conv: keep the (conv_dim, width) tail un-flattened so the committed narrow - # window can be read per row at the canonical offset (the window is strided - # in the flattened widened layout). Capture itemsize BEFORE the uint8 view to - # convert the element-unit canonical offset into a byte offset. assert gpu_conv_state.dim() >= 4, "gpu_conv_state must be [layer, s, conv_dim, widened_width]" assert cpu_kv_conv_state.dim() >= 4, "cpu_kv_conv_state must be [size, layer, conv_dim, width_narrow]" conv_itemsize = gpu_conv_state.element_size() diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 4a51fbd712..bbd06ce992 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -794,35 +794,13 @@ def _verify_mtp_v2( ) return mtp_accept_len, accepted_index - def _commit_mtp_accept_len( - self, - decode_reqs: List[InferReq], - mtp_accept_len_cpu: torch.Tensor, - ): - # Carry the per-req accept count into the NEXT step as the canonical - # pointer (design §3.1). This must run on every rank (not only master): - # the kernels on this rank read req.mtp_accept_len. - # - # CRITICAL ordering (overlap scheduler): the next step's decode_mtp reads - # req.mtp_accept_len (to build b_num_accepted_tokens) the moment its - # wait_to_forward() is released, which happens at THIS step's - # notify_forward_and_wait_post_handle() (start of phase 3). So this carry - # MUST be committed in phase 2 (pre_post_handle), before that release — - # otherwise the next step reads a one-step-stale accept count. The error - # is invisible while accept_len is constant (==1) and corrupts the GDN - # conv/ssm committed-state read-offset the instant a multi-token accept - # (accept_len>=2) occurs. - for req, accept_len in zip(decode_reqs, mtp_accept_len_cpu): - req.mtp_accept_len = int(accept_len) - return - def _update_mtp_accept_ratio( self, decode_reqs: List[InferReq], mtp_accept_len_cpu: torch.Tensor, ): - # Master-only accept-ratio statistics. Unlike _commit_mtp_accept_len this - # only feeds metrics, so it may stay in the phase-3 post_handle region. + # Master-only accept-ratio statistics. Unlike the phase-2 mtp_accept_len commit + # (inlined in decode_mtp) this only feeds metrics, so it may stay in phase 3. if self.is_master_in_dp: for req, accept_len in zip(decode_reqs, mtp_accept_len_cpu): req.update_mtp_accepted_token_num(accept_token_num=accept_len - 1) diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index b32c6fd2ad..614aef3c67 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -4,9 +4,7 @@ from typing import List, Optional, Callable, Dict, Any from queue import Queue from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend -from lightllm.server.router.model_infer.mode_backend.overlap_events import ( - OverlapEventPack, -) +from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventPack from lightllm.server.router.model_infer.infer_batch import InferReq from lightllm.server.router.model_infer.mode_backend.pre import ( prepare_prefill_inputs, @@ -43,10 +41,7 @@ def __init__(self) -> None: if get_env_start_args().mtp_mode: self.prefill = self.prefill_mtp self.decode = self.decode_mtp - self.is_mtp_eagle = get_env_start_args().mtp_mode in [ - "eagle_with_att", - "eagle_no_att", - ] + self.is_mtp_eagle = get_env_start_args().mtp_mode in ["eagle_with_att", "eagle_no_att"] self.num_mtp_models = 1 if self.is_mtp_eagle else get_env_start_args().mtp_step self._draft_decode_func = self._draft_decode_eagle if self.is_mtp_eagle else self._draft_decode_vanilla else: @@ -115,7 +110,7 @@ def prefill_normal( model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) - (_, next_token_ids_cpu, next_token_logprobs_cpu,) = self._sample_and_scatter_token( + _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, b_mtp_index=model_input.b_mtp_index, @@ -158,7 +153,7 @@ def decode_normal( model_input, run_reqs = prepare_decode_inputs(decode_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) - (_, next_token_ids_cpu, next_token_logprobs_cpu,) = self._sample_and_scatter_token( + _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, b_mtp_index=model_input.b_mtp_index, @@ -196,7 +191,7 @@ def prefill_mtp( model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) - (next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu,) = self._sample_and_scatter_token( + next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, b_mtp_index=model_input.b_mtp_index, @@ -207,9 +202,7 @@ def prefill_mtp( ) # mtp kv fill self._draft_prefill_forward( - model_input=model_input, - model_output=model_output, - next_token_ids=next_token_ids, + model_input=model_input, model_output=model_output, next_token_ids=next_token_ids ) g_infer_context.copy_linear_att_state_to_cache_buffer( b_req_idx=model_input.b_req_idx, @@ -249,11 +242,6 @@ def decode_mtp( """ model_input, run_reqs = prepare_decode_inputs(decode_reqs) - # Build the per-real-request accept tensor (carried InferReq.mtp_accept_len - # from the previous step). decode_reqs is one entry per real request, - # aligning 1:1 with the b_gdn_verify_cu_seqlens grouping (the same zip used - # by _update_mtp_accept_ratio). Threaded onto the infer_state via ModelInput - # (mirrors b_mtp_index); to_cuda() moves it inside forward. §3.1 if self.mtp_step > 0: accept_lens = [req.mtp_accept_len for req in decode_reqs] model_input.b_num_accepted_tokens = g_pin_mem_manager.gen_from_list( @@ -290,10 +278,9 @@ def decode_mtp( verify_event = torch.cuda.Event() verify_event.record() - ( - next_token_ids_cpu, - next_token_logprobs_cpu, - ) = self._async_copy_next_token_infos_to_pin_mem(next_token_ids, next_token_logprobs) + next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem( + next_token_ids, next_token_logprobs + ) # 调用具体的draft decode函数 additional_mem_indexes_cpu = self._draft_decode_func( @@ -315,12 +302,8 @@ def decode_mtp( # 第二阶段 event_pack.notify_post_handle_and_wait_pre_post_handle() verify_event.synchronize() - # Commit the carried accept count HERE (phase 2 / pre_post_handle), not in - # phase 3: the next overlapped step reads req.mtp_accept_len as soon as this - # step calls notify_forward_and_wait_post_handle() below, so the update must - # land before that release to avoid feeding the kernels a stale (one-step-old) - # accept count. See _commit_mtp_accept_len for the full rationale. - self._commit_mtp_accept_len(decode_reqs=decode_reqs, mtp_accept_len_cpu=mtp_accept_len_cpu) + for req, accept_len in zip(decode_reqs, mtp_accept_len_cpu): + req.mtp_accept_len = int(accept_len) verify_ok_reqs = [run_reqs[i] for i in range(len(run_reqs)) if accepted_index_cpu[i] == 1] update_packs = self._pre_post_handle(verify_ok_reqs, is_chuncked_mode=False) @@ -352,12 +335,7 @@ def decode_mtp( event_pack.notify_pre_post_handle() return - def _draft_prefill_forward( - self, - model_input: ModelInput, - model_output: ModelOutput, - next_token_ids: torch.Tensor, - ): + def _draft_prefill_forward(self, model_input: ModelInput, model_output: ModelOutput, next_token_ids: torch.Tensor): # spec prefill: MTP, 这个地方只是为了填充draft model的 kv, 并不会使用生成的token_id。 draft_model_input = model_input draft_model_output = model_output From 0d42047b6ed186f287201f1daf511eeeba4625b7 Mon Sep 17 00:00:00 2001 From: sufubao Date: Sun, 7 Jun 2026 16:10:14 +0800 Subject: [PATCH 10/10] test(static_inference): generalize MTP static benchmark - Dispatch to MTP bench whenever mtp_mode is set (was dead-coded to 'deepseekv3') - init_mtp_model: dispatch by config model_type (deepseek_v3/qwen3_moe/mistral/ glm4_moe_lite/qwen3_5/qwen3_5_moe), handle eagle (1 instance) vs vanilla (mtp_step instances); fix mem_faction typo; pass full att/kv/quant kvargs - run_forward_once: adapt to new ModelInput API (mem_indexes_cpu + CPU tensors, max_q/kv_seq_len, b_mtp_index, b_prefill_start_loc); reuse draft instances via _step % num_instances; pad/truncate draft_ids to mtp_step+1 - Cap max_req_num at 512 to avoid GDN req-state cache OOM under MTP --- .../benchmark/static_inference/model_infer.py | 2 +- .../static_inference/model_infer_mtp.py | 175 +++++++++++++----- test/benchmark/static_inference/test_model.py | 2 +- 3 files changed, 129 insertions(+), 50 deletions(-) diff --git a/test/benchmark/static_inference/model_infer.py b/test/benchmark/static_inference/model_infer.py index f2c900af09..b93c5fee55 100644 --- a/test/benchmark/static_inference/model_infer.py +++ b/test/benchmark/static_inference/model_infer.py @@ -36,7 +36,7 @@ def test_model_inference(args): "graph_max_len_in_batch": args.max_req_total_len, "graph_max_batch_size": args.graph_max_batch_size, "mem_fraction": args.mem_fraction, - "max_req_num": 2048, + "max_req_num": 512, "batch_max_tokens": 1024, "run_mode": "normal", "max_seq_length": args.max_req_total_len, diff --git a/test/benchmark/static_inference/model_infer_mtp.py b/test/benchmark/static_inference/model_infer_mtp.py index 72f06a919c..0935131af0 100644 --- a/test/benchmark/static_inference/model_infer_mtp.py +++ b/test/benchmark/static_inference/model_infer_mtp.py @@ -9,42 +9,85 @@ from lightllm.models import get_model from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput from lightllm.server.core.objs.start_args_type import StartArgs -from torch.profiler import profile, record_function, ProfilerActivity +from torch.profiler import profile, ProfilerActivity from lightllm.utils.log_utils import init_logger from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel -import torch.cuda as cuda +from lightllm.models.qwen3_moe_mtp.model import Qwen3MOEMTPModel +from lightllm.models.mistral_mtp.model import MistralMTPModel +from lightllm.models.glm4_moe_lite_mtp.model import Glm4MoeLiteMTPModel logger = init_logger(__name__) def init_mtp_model(args: StartArgs, kvargs, main_model): - mtp_step = args.mtp_step draft_models = [] os.environ["DISABLE_CHECK_MAX_LEN_INFER"] = "1" - mtp_model_kvargs = kvargs - mtp_model_kvargs.update( - { - "weight_dir": args.mtp_draft_model_dir, + + if args.mtp_mode in ["vanilla_with_att", "vanilla_no_att"]: + num_mtp_modules = args.mtp_step + elif args.mtp_mode in ["eagle_with_att", "eagle_no_att"]: + num_mtp_modules = 1 + else: + assert False, f"error mtp mode {args.mtp_mode}" + + for i in range(num_mtp_modules): + mtp_model_cfg, _ = PretrainedConfig.get_config_dict(args.mtp_draft_model_dir[i]) + model_type = mtp_model_cfg.get("model_type", "") + mtp_model_kvargs = { + "weight_dir": args.mtp_draft_model_dir[i], "max_total_token_num": main_model.mem_manager.size, - "disable_chunked_prefill": True, - "mtp_mode": args.mtp_mode, + "load_way": kvargs["load_way"], + "max_req_num": kvargs.get("max_req_num", 1000), + "max_seq_length": kvargs.get("max_seq_length", 1024 * 5), + "is_token_healing": False, + "return_all_prompt_logics": False, + "disable_chunked_prefill": args.disable_chunked_prefill, + "data_type": kvargs.get("data_type", "float16"), + "graph_max_batch_size": kvargs.get("graph_max_batch_size", 16), + "graph_max_len_in_batch": kvargs.get("graph_max_len_in_batch", 8196), + "disable_cudagraph": kvargs.get("disable_cudagraph", False), + "mem_fraction": kvargs["mem_fraction"], + "batch_max_tokens": kvargs.get("batch_max_tokens", None), + "quant_type": kvargs.get("quant_type", None), + "quant_cfg": kvargs.get("quant_cfg", None), + "run_mode": "normal", + "llm_prefill_att_backend": kvargs.get("llm_prefill_att_backend", args.llm_prefill_att_backend), + "llm_decode_att_backend": kvargs.get("llm_decode_att_backend", args.llm_decode_att_backend), + "vit_att_backend": kvargs.get("vit_att_backend", args.vit_att_backend), + "llm_kv_type": kvargs.get("llm_kv_type", args.llm_kv_type), + "llm_kv_quant_group_size": kvargs.get("llm_kv_quant_group_size", args.llm_kv_quant_group_size), "main_model": main_model, + "mtp_previous_draft_models": draft_models.copy(), + "mtp_mode": args.mtp_mode, } - ) - for i in range(mtp_step): - mtp_model_cfg, _ = PretrainedConfig.get_config_dict(args.mtp_draft_model_dir) - mtp_model_kvargs.update( - { - "weight_dir": args.spec_model_dir, - "max_total_token_num": main_model.mem_manager.size, - "disable_chunked_prefill": True, - "mtp_mode": args.mtp_mode, - "main_model": main_model, - "mem_layer_start": main_model.config["num_hidden_layers"] + i * mtp_model_cfg["num_hidden_layers"], - } - ) - draft_models.append(Deepseek3MTPModel(mtp_model_kvargs)) + + if model_type == "deepseek_v3": + assert args.mtp_mode in ["vanilla_with_att", "eagle_with_att"] + draft_models.append(Deepseek3MTPModel(mtp_model_kvargs)) + elif model_type == "qwen3_moe": + assert args.mtp_mode in ["vanilla_no_att", "eagle_no_att"] + draft_models.append(Qwen3MOEMTPModel(mtp_model_kvargs)) + elif model_type == "mistral": + assert args.mtp_mode in ["vanilla_no_att", "eagle_no_att"] + draft_models.append(MistralMTPModel(mtp_model_kvargs)) + elif mtp_model_cfg["model_type"] == "glm4_moe_lite": + assert args.mtp_mode in ["vanilla_with_att", "eagle_with_att"] + draft_models.append(Glm4MoeLiteMTPModel(mtp_model_kvargs)) + elif model_type in ("qwen3_5", "qwen3_5_text"): + assert args.mtp_mode in ["vanilla_with_att", "eagle_with_att"] + from lightllm.models.qwen3_5_mtp.model import Qwen3_5MTPModel + + draft_models.append(Qwen3_5MTPModel(mtp_model_kvargs)) + elif model_type in ("qwen3_5_moe", "qwen3_5_moe_text"): + assert args.mtp_mode in ["vanilla_with_att", "eagle_with_att"] + from lightllm.models.qwen3_5_moe_mtp.model import Qwen3_5MoeMTPModel + + draft_models.append(Qwen3_5MoeMTPModel(mtp_model_kvargs)) + else: + raise ValueError(f"Unsupported MTP model type: {model_type}") + + logger.info(f"loaded mtp model class {draft_models[i].__class__}") return draft_models @@ -68,13 +111,22 @@ def test_model_inference_mtp(args): "max_total_token_num": args.max_total_token_num, "graph_max_len_in_batch": args.max_req_total_len, "graph_max_batch_size": args.graph_max_batch_size, - "mem_faction": args.mem_fraction, - "max_req_num": 2000, + "mem_fraction": args.mem_fraction, + # Static bench runs explicit batch sizes (<= a few hundred). The hybrid Qwen3.5 + # GDN req-state cache is sized max_req_num * (mtp_step + 1) at ~34 MB/slot, so the + # old default of 2000 alloc'd ~140 GB and OOM'd under MTP. 512 covers any realistic + # static batch sweep while keeping the GDN cache small. + "max_req_num": 512, "batch_max_tokens": 2048, "run_mode": "normal", "max_seq_length": args.max_req_total_len, - "spec_algo": args.spec_algo, "disable_cudagraph": args.disable_cudagraph, + "quant_cfg": args.quant_cfg, + "llm_prefill_att_backend": args.llm_prefill_att_backend, + "llm_decode_att_backend": args.llm_decode_att_backend, + "vit_att_backend": args.vit_att_backend, + "llm_kv_type": args.llm_kv_type, + "llm_kv_quant_group_size": args.llm_kv_quant_group_size, } proc = multiprocessing.Process( target=tppart_model_infer, @@ -113,28 +165,36 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ test_data = np.vstack([np.random.randint(0, 50256, input_len) for _ in range(batch_size)]) test_data = test_data.reshape(-1) - test_data = torch.from_numpy(test_data).cuda() + test_data = torch.from_numpy(test_data) b_req_idx = torch.tensor( - [main_model.req_manager.alloc() for _ in range(batch_size)], dtype=torch.int32, device="cuda" + [main_model.req_manager.alloc() for _ in range(batch_size)], dtype=torch.int32, device="cpu" ) - b_seq_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda") - b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + b_seq_len = torch.zeros(batch_size, dtype=torch.int32, device="cpu") + b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cpu") for i in range(batch_size): b_seq_len[i] = input_len total_token_num = input_len * batch_size - mem_indexes = main_model.req_manager.mem_manager.alloc(test_data.shape[0]).cuda() + mem_indexes = main_model.req_manager.mem_manager.alloc(test_data.shape[0]) + b_mtp_index = torch.zeros(batch_size, dtype=torch.int32) + b_prefill_start_loc = b_seq_len.cumsum(dim=0, dtype=torch.int32) - b_seq_len # Main model Prefill model_input = ModelInput( batch_size=batch_size, total_token_num=total_token_num, + max_q_seq_len=input_len, + max_kv_seq_len=input_len, + max_cache_len=0, input_ids=test_data, - mem_indexes=mem_indexes, + mem_indexes_cpu=mem_indexes, b_req_idx=b_req_idx, + b_mtp_index=b_mtp_index, b_seq_len=b_seq_len, is_prefill=True, b_ready_cache_len=b_ready_cache_len, + b_prefill_start_loc=b_prefill_start_loc, + prefix_total_token_num=0, multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)], ) @@ -167,8 +227,22 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ torch.cuda.synchronize() + # Speculative width = args.mtp_step in BOTH modes (mirrors base_backend: self.mtp_step = + # args.mtp_step). The number of draft MODEL INSTANCES differs: vanilla loads mtp_step + # instances (each forwarded once), eagle loads ONE instance forwarded mtp_step times + # (chunked_prefill/impl.py: draft_models[_step % num_instances]). The verify batch always + # expands to (mtp_step + 1) rows per request. + spec_width = args.mtp_step + num_instances = len(draft_models) + # The draft prefill above produced (1 + num_instances) columns; pad/truncate to + # (spec_width + 1) so the decode verify batch matches the server's expand width. Only the + # SHAPE matters for throughput here (argmax over random inputs); token values do not. + while len(draft_ids) < spec_width + 1: + draft_ids.append(draft_ids[-1]) + draft_ids = draft_ids[: spec_width + 1] decode_input_ids = np.stack(draft_ids, axis=-1).reshape(-1) - decode_input_ids = torch.from_numpy(decode_input_ids).cuda() + decode_input_ids = torch.from_numpy(decode_input_ids) + mtp_step = spec_width # build main decode input: nopad_b_seq_idx = [] @@ -177,35 +251,39 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ nopad_max_len_in_batch = 0 for i in range(batch_size): - nopad_b_seq_idx.append(b_req_idx[i]) + nopad_b_seq_idx.append(b_req_idx[i].item()) seq_len = b_seq_len[i].item() nopad_b_seq_len.append(seq_len + 1) nopad_total_token_num += seq_len + 1 - nopad_max_len_in_batch = max(nopad_max_len_in_batch, b_seq_len[i] + 1) + nopad_max_len_in_batch = max(nopad_max_len_in_batch, seq_len + 1) - for step in range(len(draft_models)): - nopad_b_seq_idx.append(b_req_idx[i]) + for step in range(mtp_step): + nopad_b_seq_idx.append(b_req_idx[i].item()) nopad_b_seq_len.append(seq_len + step + 2) nopad_total_token_num += seq_len + step + 2 nopad_max_len_in_batch = max(nopad_max_len_in_batch, seq_len + step + 2) - nopad_b_seq_idx = torch.tensor(nopad_b_seq_idx, dtype=torch.int32, device="cuda") - nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda") - mem_indexes = main_model.req_manager.mem_manager.alloc(batch_size * (len(draft_models) + 1)).cuda() + nopad_b_seq_idx = torch.tensor(nopad_b_seq_idx, dtype=torch.int32, device="cpu") + nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cpu") + b_mtp_index = torch.arange(mtp_step + 1, dtype=torch.int32).repeat(batch_size) + mem_indexes = main_model.req_manager.mem_manager.alloc(batch_size * (mtp_step + 1)) model_input = ModelInput( - batch_size=batch_size * (len(draft_models) + 1), + batch_size=batch_size * (mtp_step + 1), total_token_num=nopad_total_token_num, + max_q_seq_len=1, + max_kv_seq_len=nopad_max_len_in_batch, input_ids=decode_input_ids, - mem_indexes=mem_indexes, + mem_indexes_cpu=mem_indexes, b_req_idx=nopad_b_seq_idx, + b_mtp_index=b_mtp_index, b_seq_len=nopad_b_seq_len, is_prefill=False, - multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size * (len(draft_models) + 1))], + multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size * (mtp_step + 1))], ) # Main decode - for i in range(0, output_len, len(draft_models) + 1): + for i in range(0, output_len, mtp_step + 1): torch.cuda.synchronize() step_start_time = time.time() model_output = main_model.forward( @@ -214,12 +292,13 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ prob_out = torch.softmax(model_output.logits, dim=-1) predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) - # draft decode + # draft decode: mtp_step forwards, reusing draft_models[_step % num_instances] + # (eagle: one instance reused mtp_step times; vanilla: a distinct instance per step). model_input.input_ids = predict_ids.reshape(-1) model_input.mtp_draft_input_hiddens = model_output.mtp_main_output_hiddens - for draft_model_id in range(len(draft_models)): - draft_model = draft_models[draft_model_id] + for _step in range(mtp_step): + draft_model = draft_models[_step % num_instances] model_output = draft_model.forward( model_input, ) @@ -237,7 +316,7 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ if get_current_rank_in_dp() == 0 and not warmup: step_time = step_end_time - step_start_time print(i, " step cost time:", step_time * 1000) - print(f"Decode throughput: {batch_size * (len(draft_models) + 1) * args.dp / step_time} tokens/s") + print(f"Decode throughput: {batch_size * (mtp_step + 1) * args.dp / step_time} tokens/s") main_model.mem_manager.free_all() main_model.req_manager.free_all() diff --git a/test/benchmark/static_inference/test_model.py b/test/benchmark/static_inference/test_model.py index 5b3751bcc3..461f289780 100644 --- a/test/benchmark/static_inference/test_model.py +++ b/test/benchmark/static_inference/test_model.py @@ -16,7 +16,7 @@ def test_model_infer(self): args = get_env_start_args() if args.data_type is None: args.data_type = get_dtype(args.model_dir) - if args.mtp_mode == "deepseekv3": + if args.mtp_mode is not None: test_model_inference_mtp(args) else: test_model_inference(args)