diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index 7d1bba33774..cc8fc0711ed 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -446,7 +446,7 @@ def apply_tp( gate_out = gate_out.cast("float32") if fc1_latent_proj is not None: x = fc1_latent_proj(x) - gate_out, topk_weights, topk_idx = get_moe_scores( + gate_out, _, __ = get_moe_scores( gate_out, layer.n_group, layer.topk_group, @@ -458,11 +458,6 @@ def apply_tp( use_fused_cast=use_fused, ) - if layer.routed_scaling_factor_learnable: - safe_topk_indices = paddle.clip(topk_idx, min=0) - gathered_scales = F.embedding(safe_topk_indices, layer.per_expert_scale.unsqueeze(1)).squeeze(-1) - topk_weights = topk_weights * gathered_scales - ( permute_input, token_nums_per_expert, @@ -484,6 +479,12 @@ def apply_tp( self.moe_quant_type, topk_only_mode=True, ) + + if layer.routed_scaling_factor_learnable: + safe_topk_indices = paddle.clip(topk_idx, min=0) + gathered_scales = F.embedding(safe_topk_indices, layer.per_expert_scale.unsqueeze(1)).squeeze(-1) + topk_weights = topk_weights * gathered_scales + else: gate_out = gate_out.cast("float32") if fc1_latent_proj is not None: