Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions lightllm/common/basemodel/attention/fa3/fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,9 @@ class Fa3DecodeAttState(BaseDecodeAttState):
def init_state(self):
self.backend: Fa3AttBackend = self.backend
args_mtp_step = get_env_start_args().mtp_step
is_mtp_verify_decode = args_mtp_step > 0 and self.infer_state.b_num_accepted_tokens is not None

if args_mtp_step > 0:
if is_mtp_verify_decode:
# 修正 mtp 在 fa3 下的输入。
mtp_size = args_mtp_step + 1
b_q_seq_len = torch.full(
Expand All @@ -143,8 +144,9 @@ def init_state(self):
self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int()
self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int()

att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1)
assert self.infer_state.batch_size % (args_mtp_step + 1) == 0
mtp_size = args_mtp_step + 1 if is_mtp_verify_decode else 1
att_batch_size = self.infer_state.batch_size // mtp_size
assert self.infer_state.batch_size % mtp_size == 0

model = self.backend.model
# 可以使用 cuda graph的时候从 buffer中申请
Expand All @@ -163,7 +165,7 @@ def init_state(self):
device=self.infer_state.input_ids.device,
)

if args_mtp_step > 0:
if is_mtp_verify_decode:
page_table_copy(
page_table=self.page_table[:, : self.infer_state.max_kv_seq_len],
req_to_token_indexs=model.req_manager.req_to_token_indexs,
Expand Down
8 changes: 1 addition & 7 deletions lightllm/common/basemodel/attention/fa3/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from ..base_att import AttControl
from typing import Optional, TYPE_CHECKING
from lightllm.utils.sgl_utils import flash_attn_with_kvcache
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.common.basemodel.triton_kernel.quantization.q_per_head_fp8_quant import q_per_head_fp8_quant
from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops
from typing import Union
Expand Down Expand Up @@ -116,12 +115,7 @@ def init_state(self):
super().init_state()
self.backend: Fp8Fa3AttBackend = self.backend

args_mtp_step = get_env_start_args().mtp_step
att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1)
assert self.infer_state.batch_size % (args_mtp_step + 1) == 0

device = self.infer_state.input_ids.device
batch_size = att_batch_size
batch_size = self.b_att_seq_len.shape[0]
mem_manager = self.backend.model.mem_manager

offline_scales: torch.Tensor = mem_manager.scales
Expand Down
10 changes: 6 additions & 4 deletions lightllm/common/basemodel/attention/fa3/mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,9 @@ class MlaFa3DecodeAttState(BaseDecodeAttState):
def init_state(self):
self.backend: MlaFa3AttBackend = self.backend
args_mtp_step = get_env_start_args().mtp_step
is_mtp_verify_decode = args_mtp_step > 0 and self.infer_state.b_num_accepted_tokens is not None

if args_mtp_step > 0:
if is_mtp_verify_decode:
# 修正 mtp 在 fa3 下的输入。
mtp_size = args_mtp_step + 1
b_q_seq_len = torch.full(
Expand All @@ -126,8 +127,9 @@ def init_state(self):
self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int()
self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int()

att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1)
assert self.infer_state.batch_size % (args_mtp_step + 1) == 0
mtp_size = args_mtp_step + 1 if is_mtp_verify_decode else 1
att_batch_size = self.infer_state.batch_size // mtp_size
assert self.infer_state.batch_size % mtp_size == 0

model = self.backend.model
# 可以使用 cuda graph的时候从 buffer中申请
Expand All @@ -146,7 +148,7 @@ def init_state(self):
device=self.infer_state.input_ids.device,
)

if args_mtp_step > 0:
if is_mtp_verify_decode:
page_table_copy(
page_table=self.page_table[:, : self.infer_state.max_kv_seq_len],
req_to_token_indexs=model.req_manager.req_to_token_indexs,
Expand Down
Loading
Loading