Skip to content

Support MTP loss mask_type v1 and multi mtp_config#1919

Open
x54-729 wants to merge 4 commits into
InternLM:mainfrom
x54-729:mtp_260616
Open

Support MTP loss mask_type v1 and multi mtp_config#1919
x54-729 wants to merge 4 commits into
InternLM:mainfrom
x54-729:mtp_260616

Conversation

@x54-729

@x54-729 x54-729 commented Jun 16, 2026

Copy link
Copy Markdown

mtp config example:

model_cfg.text_config.mtp_config = [
    MTPConfig(name="normal", num_layers=TEXT_MTP_LAYERS, share_weights=TEXT_MTP_LAYERS>1, loss_scaling_factor=TEXT_MTP_FACTOR),
    MTPConfig(name="sci", num_layers=NUM_MTP_LAYERS, share_weights=NUM_MTP_LAYERS>1, loss_scaling_factor=NUM_MTP_FACTOR, loss_cfg=SciMTPLossConfig(mask_type="v1"))
]

Refer to HAOCHENYE#2

loss = torch.tensor(0.0, device=DEVICE)
for key in model_outputs.model_fields:
value = getattr(model_outputs, key)
if "loss" in key and isinstance(value, torch.Tensor):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The loss part should be autonomous to the model, rather than being hardcoded here.

Comment thread xtuner/v1/loss/mtp_loss.py Outdated
Comment on lines +172 to +176
mask_type = self.loss_cfg.mask_type
if mask_type == "v1":
self.process_loss_weight_v1()
elif mask_type is not None:
raise NotImplementedError(f"Unknown MTP Loss Mask Type: {mask_type}")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The calculation logic of loss should not be hard-coded here; please implement a new loss_context.

Comment thread xtuner/v1/loss/mtp_loss.py
Comment thread xtuner/v1/model/moe/moe.py Outdated
import types
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Literal, Self, Sequence, TypedDict, cast
from typing import TYPE_CHECKING, Annotated, List, Literal, Self, Sequence, TypedDict, cast

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use python builtin list

Comment on lines +572 to +575
if key == "mtp_loss" and isinstance(value, dict):
for mtp_loss_name, mtp_loss in value.items():
loss += mtp_loss
elif "loss" in key and isinstance(value, torch.Tensor):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if key == "mtp_loss" and isinstance(value, dict):
for mtp_loss_name, mtp_loss in value.items():
loss += mtp_loss
elif "loss" in key and isinstance(value, torch.Tensor):
elif "loss" in key:
loss_values = list(value.values()) if isinstance(value, dict) else [value]
loss_values = [i for i in loss_values if isinstance(i, torch.Tensor)]
for value in loss_values:
loss += value

Comment on lines +693 to +700
output,
layer_hidden_states,
position_embeddings,
balancing_ctx,
z_ctx,
mtp_seq_ctx,
mtp_loss_ctx_dict,
keep_router: bool,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing typehint here

Comment thread xtuner/v1/loss/mtp_loss.py Outdated
Args:
idx (int): 1-indexed MTP layer depth to bind.
"""
self.mtp_depth = idx

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Depth and layer count are two different things. The user configures the number of layers in the config, whereas the layer index — which layer this actually is — is known at construction time. I'd suggest moving the bind_mtp_depth logic into LossContext, so that a single MTPConfig can produce LossContext instances for the different layers.

Comment on lines +868 to +873
for mtp_config in self.config.mtp_config:
self._mtp_forward(
mtp_config=mtp_config,
output=output,
layer_hidden_states=layer_hidden_states,
position_embeddings=position_embeddings,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's keep this part of the code unchanged for now — we shouldn't be doing code cleanup/refactoring in this PR.

global_mtp_idx = 0 # Track global MTP layer index across all mtp_configs
for mtp_name in self.mtp_block.keys():
mtp_block = self.mtp_block[mtp_name]
mtp_config = next((cfg for cfg in self.config.mtp_config if cfg.name == mtp_name), None) # type: ignore

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just pass mtp_config directly into MTPBlock — no need for all this indirection.

mtp_layer = checkpoint_wrapper(mtp_layer, checkpoint_impl=CheckpointImpl.REENTRANT)
mtp_block.layers[local_mtp_idx] = mtp_layer

reshard_after_forward = local_mtp_idx != len(mtp_block.layers) - 1

@HAOCHENYE HAOCHENYE Jun 18, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reshard_after_forward should only be set to True for the last layer of the last mtp block

Comment on lines +1192 to +1193
mtp_layer = checkpoint_wrapper(mtp_layer, checkpoint_impl=CheckpointImpl.REENTRANT)
mtp_block.layers[local_mtp_idx] = mtp_layer

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
mtp_layer = checkpoint_wrapper(mtp_layer, checkpoint_impl=CheckpointImpl.REENTRANT)
mtp_block.layers[local_mtp_idx] = mtp_layer
mtp_layer = checkpoint_wrapper(mtp_layer, checkpoint_impl=CheckpointImpl.REENTRANT)
mtp_block.layers[local_mtp_idx] = mtp_layer

self.rotary_emb = self.build_rotary_embedding(config)
self.embed_tokens = self.build_embeddings(config)
self.mtp_block = self.build_mtp_block(config) if config.mtp_config is not None else None
self.mtp_block = self.build_mtp_block_dict(config) if config.mtp_config is not None else None

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we switch this to a ModuleList too? That way it'd be symmetric with the config.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants