Support MTP loss mask_type v1 and multi mtp_config#1919
Conversation
| loss = torch.tensor(0.0, device=DEVICE) | ||
| for key in model_outputs.model_fields: | ||
| value = getattr(model_outputs, key) | ||
| if "loss" in key and isinstance(value, torch.Tensor): |
There was a problem hiding this comment.
The loss part should be autonomous to the model, rather than being hardcoded here.
| mask_type = self.loss_cfg.mask_type | ||
| if mask_type == "v1": | ||
| self.process_loss_weight_v1() | ||
| elif mask_type is not None: | ||
| raise NotImplementedError(f"Unknown MTP Loss Mask Type: {mask_type}") |
There was a problem hiding this comment.
The calculation logic of loss should not be hard-coded here; please implement a new loss_context.
| import types | ||
| from pathlib import Path | ||
| from typing import TYPE_CHECKING, Annotated, Literal, Self, Sequence, TypedDict, cast | ||
| from typing import TYPE_CHECKING, Annotated, List, Literal, Self, Sequence, TypedDict, cast |
There was a problem hiding this comment.
Use python builtin list
| if key == "mtp_loss" and isinstance(value, dict): | ||
| for mtp_loss_name, mtp_loss in value.items(): | ||
| loss += mtp_loss | ||
| elif "loss" in key and isinstance(value, torch.Tensor): |
There was a problem hiding this comment.
| if key == "mtp_loss" and isinstance(value, dict): | |
| for mtp_loss_name, mtp_loss in value.items(): | |
| loss += mtp_loss | |
| elif "loss" in key and isinstance(value, torch.Tensor): | |
| elif "loss" in key: | |
| loss_values = list(value.values()) if isinstance(value, dict) else [value] | |
| loss_values = [i for i in loss_values if isinstance(i, torch.Tensor)] | |
| for value in loss_values: | |
| loss += value |
| output, | ||
| layer_hidden_states, | ||
| position_embeddings, | ||
| balancing_ctx, | ||
| z_ctx, | ||
| mtp_seq_ctx, | ||
| mtp_loss_ctx_dict, | ||
| keep_router: bool, |
| Args: | ||
| idx (int): 1-indexed MTP layer depth to bind. | ||
| """ | ||
| self.mtp_depth = idx |
There was a problem hiding this comment.
Depth and layer count are two different things. The user configures the number of layers in the config, whereas the layer index — which layer this actually is — is known at construction time. I'd suggest moving the bind_mtp_depth logic into LossContext, so that a single MTPConfig can produce LossContext instances for the different layers.
| for mtp_config in self.config.mtp_config: | ||
| self._mtp_forward( | ||
| mtp_config=mtp_config, | ||
| output=output, | ||
| layer_hidden_states=layer_hidden_states, | ||
| position_embeddings=position_embeddings, |
There was a problem hiding this comment.
Let's keep this part of the code unchanged for now — we shouldn't be doing code cleanup/refactoring in this PR.
| global_mtp_idx = 0 # Track global MTP layer index across all mtp_configs | ||
| for mtp_name in self.mtp_block.keys(): | ||
| mtp_block = self.mtp_block[mtp_name] | ||
| mtp_config = next((cfg for cfg in self.config.mtp_config if cfg.name == mtp_name), None) # type: ignore |
There was a problem hiding this comment.
Just pass mtp_config directly into MTPBlock — no need for all this indirection.
| mtp_layer = checkpoint_wrapper(mtp_layer, checkpoint_impl=CheckpointImpl.REENTRANT) | ||
| mtp_block.layers[local_mtp_idx] = mtp_layer | ||
|
|
||
| reshard_after_forward = local_mtp_idx != len(mtp_block.layers) - 1 |
There was a problem hiding this comment.
reshard_after_forward should only be set to True for the last layer of the last mtp block
| mtp_layer = checkpoint_wrapper(mtp_layer, checkpoint_impl=CheckpointImpl.REENTRANT) | ||
| mtp_block.layers[local_mtp_idx] = mtp_layer |
There was a problem hiding this comment.
| mtp_layer = checkpoint_wrapper(mtp_layer, checkpoint_impl=CheckpointImpl.REENTRANT) | |
| mtp_block.layers[local_mtp_idx] = mtp_layer | |
| mtp_layer = checkpoint_wrapper(mtp_layer, checkpoint_impl=CheckpointImpl.REENTRANT) | |
| mtp_block.layers[local_mtp_idx] = mtp_layer |
| self.rotary_emb = self.build_rotary_embedding(config) | ||
| self.embed_tokens = self.build_embeddings(config) | ||
| self.mtp_block = self.build_mtp_block(config) if config.mtp_config is not None else None | ||
| self.mtp_block = self.build_mtp_block_dict(config) if config.mtp_config is not None else None |
There was a problem hiding this comment.
Could we switch this to a ModuleList too? That way it'd be symmetric with the config.
mtp config example:
Refer to HAOCHENYE#2