diff --git a/cookbook/transformers/deepseek_v4_flash.py b/cookbook/transformers/deepseek_v4_flash.py deleted file mode 100644 index 869f4cc8..00000000 --- a/cookbook/transformers/deepseek_v4_flash.py +++ /dev/null @@ -1,155 +0,0 @@ -import os - -import twinkle -from peft import LoraConfig -from transformers import AutoConfig -from twinkle import DeviceMesh, Platform, get_device_placement, get_logger -from twinkle.dataloader import DataLoader -from twinkle.dataset import Dataset, DatasetMeta -from twinkle.model import TransformersModel -from twinkle.preprocessor import SelfCognitionProcessor - -logger = get_logger() -# `deepseek-ai/DeepSeek-V4-Flash` uses mixed FP4/FP8 weights. -# Convert the checkpoint before training by following: -# https://gitcode.com/cann/cann-recipes-train/blob/master/llm_pretrain/deepseekv4/README.md#%E6%A8%A1%E5%9E%8B%E6%9D%83%E9%87%8D%E5%87%86%E5%A4%87 -# Install `transformers==5.8.0` before running this cookbook. -MODEL_ID = os.environ.get('MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4-Flash') -DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition') -TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'DeepseekV4Template') -OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output') - -NUM_LAYERS = int(os.environ.get('NUM_LAYERS', '4')) - -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4')) -GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '2')) -LR = float(os.environ.get('LR', '1e-4')) -MAX_STEPS = int(os.environ.get('MAX_STEPS', '0')) -SAVE_STEPS = int(os.environ.get('SAVE_STEPS', '50')) -RESHARD_AFTER_FORWARD = os.environ.get('RESHARD_AFTER_FORWARD', '1') == '1' -GRADIENT_CHECKPOINTING = True -IGNORE_MISMATCHED_SIZES = False -LORA_TARGET_MODULES = [ - 'q_a_proj', - 'q_b_proj', - 'kv_proj', - 'o_b_proj', - 'gate_proj', - 'up_proj', - 'down_proj', -] -ADAPTER_NAME = 'default' - -device_mesh = DeviceMesh.from_sizes( - fsdp_size=8, - dp_size=1, - ep_size=8, - device_type=Platform.get_platform().device_prefix(), -) - 
-twinkle.initialize(mode='local', global_device_mesh=device_mesh) - - -def create_dataset(data_slice=None): - dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=data_slice or range(1000))) - dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID) - dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) - dataset.encode(batched=True) - return dataset - - -def eval(model): - dataset = create_dataset(data_slice=range(100)) - dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, device_mesh=device_mesh) - for _, batch in enumerate(dataloader): - if callable(batch): - batch = batch() - model.forward_only(inputs=batch, adapter_name=ADAPTER_NAME) - model.calculate_loss(adapter_name=ADAPTER_NAME) - return model.calculate_metric(is_training=False, adapter_name=ADAPTER_NAME) - - -def train(): - dataset = create_dataset() - dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, device_mesh=device_mesh) - - config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True) - if NUM_LAYERS is not None and hasattr(config, 'num_hidden_layers'): - config.num_hidden_layers = NUM_LAYERS - if hasattr(config, 'use_cache'): - config.use_cache = False - - model = TransformersModel( - model_id=MODEL_ID, - config=config, - device_mesh=device_mesh, - strategy='native_fsdp', - memory_efficient_init=True, - ignore_mismatched_sizes=IGNORE_MISMATCHED_SIZES, - fsdp_config={ - 'reshard_after_forward': RESHARD_AFTER_FORWARD, - 'expert_parallel': { - 'enabled': True, - 'router_dtype': 'fp32', - 'keep_router_logits': False, - }, - }, - ) - - lora_config = LoraConfig(r=8, lora_alpha=32, target_modules=LORA_TARGET_MODULES) - model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRAD_ACCUM_STEPS) - - if not GRADIENT_CHECKPOINTING: - model.model.gradient_checkpointing_disable() - - model.set_template(TEMPLATE_ID, model_id=MODEL_ID, adapter_name=ADAPTER_NAME) - model.set_optimizer('AdamW', lr=LR, foreach=False, 
adapter_name=ADAPTER_NAME) - model.set_lr_scheduler( - scheduler_cls='CosineWarmupScheduler', - num_warmup_steps=5, - num_training_steps=len(dataloader), - adapter_name=ADAPTER_NAME, - ) - - logger.info(get_device_placement()) - logger.info(model.get_train_configs(adapter_name=ADAPTER_NAME)) - logger.info( - f'Total steps: {len(dataloader)}, batch_size={BATCH_SIZE}, ' - f'grad_accum={GRAD_ACCUM_STEPS}, lr={LR:.2e}, ' - f'num_layers={NUM_LAYERS}, ignore_mismatched_sizes={IGNORE_MISMATCHED_SIZES}, ' - f'gradient_checkpointing={GRADIENT_CHECKPOINTING}, ' - f'reshard_after_forward={RESHARD_AFTER_FORWARD}, ' - f'lora_target_modules={LORA_TARGET_MODULES}') - - best_loss = float('inf') - for step, batch in enumerate(dataloader): - if MAX_STEPS and step >= MAX_STEPS: - break - if callable(batch): - batch = batch() - model.forward_backward( - inputs=batch, - adapter_name=ADAPTER_NAME, - ) - model.clip_grad_and_step( - adapter_name=ADAPTER_NAME, - gradient_accumulation_steps=GRAD_ACCUM_STEPS, - ) - - if step % 20 == 0: - metric = model.calculate_metric(is_training=True, adapter_name=ADAPTER_NAME) - logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}') - - if step > 0 and step % SAVE_STEPS == 0: - metrics = eval(model) - logger.info(f'Eval metric: {metrics}') - loss = float(metrics['loss']) - if loss < best_loss: - model.save(name=f'checkpoint-{step}', output_dir=OUTPUT_DIR, adapter_name=ADAPTER_NAME) - best_loss = loss - - model.save(name='last-checkpoint', output_dir=OUTPUT_DIR, adapter_name=ADAPTER_NAME) - - -if __name__ == '__main__': - train() diff --git a/cookbook/transformers/deepseek_v4_flash.sh b/cookbook/transformers/deepseek_v4_flash.sh deleted file mode 100644 index bbdb58ff..00000000 --- a/cookbook/transformers/deepseek_v4_flash.sh +++ /dev/null @@ -1,6 +0,0 @@ -# `deepseek-ai/DeepSeek-V4-Flash` uses mixed FP4/FP8 weights. 
-# Convert the checkpoint before training by following: -# https://gitcode.com/cann/cann-recipes-train/blob/master/llm_pretrain/deepseekv4/README.md#%E6%A8%A1%E5%9E%8B%E6%9D%83%E9%87%8D%E5%87%86%E5%A4%87 -# Install `transformers==5.8.0` before running this cookbook. - -CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 cookbook/transformers/deepseek_v4_flash.py diff --git a/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py b/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py index 170010df..af72efa1 100644 --- a/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py +++ b/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py @@ -35,11 +35,12 @@ RESUME_ONLY_MODEL = os.environ.get('RESUME_ONLY_MODEL', '0') == '1' IGNORE_DATA_SKIP = os.environ.get('IGNORE_DATA_SKIP', '0') == '1' ADAPTER_NAME = os.environ.get('ADAPTER_NAME', 'default') +NUM_GPUS = int(os.environ.get('NUM_GPUS', '8')) device_mesh = DeviceMesh.from_sizes( - fsdp_size=8, + fsdp_size=NUM_GPUS, dp_size=1, - ep_size=8, + ep_size=NUM_GPUS, device_type=Platform.get_platform().device_prefix(), ) twinkle.initialize(mode='local', global_device_mesh=device_mesh) diff --git a/cookbook/transformers/ep_fsdp2_lora_deepseek_v4_multinode.sh b/cookbook/transformers/ep_fsdp2_lora_deepseek_v4_multinode.sh new file mode 100644 index 00000000..7344474e --- /dev/null +++ b/cookbook/transformers/ep_fsdp2_lora_deepseek_v4_multinode.sh @@ -0,0 +1,26 @@ + +# `deepseek-ai/DeepSeek-V4-Flash` uses mixed FP4/FP8 weights. +# Convert the checkpoint before training by following: +# https://gitcode.com/cann/cann-recipes-train/blob/master/llm_pretrain/deepseekv4/README.md#%E6%A8%A1%E5%9E%8B%E6%9D%83%E9%87%8D%E5%87%86%E5%A4%87 +# Install `transformers==5.8.0` before running this cookbook. + +export DSV4_MODEL_ID="ms://deepseek-ai/DeepSeek-V4-Flash-bf16" +export DATASET_ID="ms://swift/self-cognition" +# The following environment variables are required for multi-node training. Adjust the values according to your cluster setup. 
+export GLOO_SOCKET_IFNAME="eth0" # Use ifconfig to check the network interface name +export HCCL_SOCKET_IFNAME="eth0" +export HCCL_EXEC_TIMEOUT=1200 +export HCCL_CONNECT_TIMEOUT=1200 +export NNODES=4 +export NUM_GPUS=64 +export MASTER_ADDR="node0" # Replace with the IP address or hostname of the master node +export MASTER_PORT=29500 # Replace with an open port on the master node +export HCCL_IF_BASE_PORT=20000 + +torchrun --nnodes=$NNODES --node_rank=$NODE_RANK --nproc_per_node=16 \ + --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT ep_fsdp2_lora_deepseek_v4.py + +# NODE_RANK=0 OUTPUT_DIR=./output sh ep_fsdp2_lora_deepseek_v4_multinode.sh +# NODE_RANK=1 OUTPUT_DIR=./output sh ep_fsdp2_lora_deepseek_v4_multinode.sh +# NODE_RANK=2 OUTPUT_DIR=./output sh ep_fsdp2_lora_deepseek_v4_multinode.sh +# NODE_RANK=3 OUTPUT_DIR=./output sh ep_fsdp2_lora_deepseek_v4_multinode.sh diff --git a/src/twinkle/model/base.py b/src/twinkle/model/base.py index 303eaf74..a4d4ea06 100644 --- a/src/twinkle/model/base.py +++ b/src/twinkle/model/base.py @@ -165,7 +165,6 @@ def _try_init_process_group(self): 'init_method': 'env://', 'rank': Platform.get_rank(), 'world_size': Platform.get_world_size(), - 'timeout': timedelta(seconds=int(os.environ.get('TWINKLE_DIST_TIMEOUT_SECONDS', '7200'))), } if self._should_bind_device_id_for_process_group(backend): init_kwargs['device_id'] = torch.device(Platform.get_local_device()) diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py index 7b3d414b..d0434991 100644 --- a/src/twinkle/model/transformers/strategy/accelerate.py +++ b/src/twinkle/model/transformers/strategy/accelerate.py @@ -7,150 +7,6 @@ from .load_context import fsdp_pretrained_load_context -def _patch_accelerate_fsdp2_load_full_state_dict(): - """Allow Accelerate FSDP2 state-dict loading to handle unsharded buffers. - - Some Transformers models keep persistent buffers in `state_dict`. 
FSDP2 - shards parameters as DTensors, but those buffers can remain ordinary - tensors; older Accelerate versions assume every state-dict entry has - `device_mesh` and fail on such buffers. - """ - import accelerate.utils.fsdp_utils as fsdp_utils - import torch - import torch.distributed as dist - from torch.distributed.tensor import DTensor, Partial, Replicate, Shard, distribute_tensor - - if getattr(fsdp_utils.fsdp2_load_full_state_dict, '_twinkle_patched', False): - return - - original = fsdp_utils.fsdp2_load_full_state_dict - - def patched_fsdp2_load_full_state_dict(accelerator, model, full_sd, cpu_offload=False): - meta_sharded_sd = model.state_dict() - sharded_sd = {} - - def _infer_parameter_dtype(model, param_name, empty_param): - old_param = _get_state_dict_param_for_dtype_inference(model, param_name) - is_torch_e4m3fn_available = hasattr(torch, 'float8_e4m3fn') - is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn - casting_dtype = None - if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn: - casting_dtype = old_param.dtype - return old_param is not None and old_param.is_contiguous(), casting_dtype - - def _cast_and_contiguous(tensor, to_contiguous, dtype): - if dtype is not None: - tensor = tensor.to(dtype=dtype) - if to_contiguous: - tensor = tensor.contiguous() - return tensor - - def _dtensor_from_replicated_full_tensor(full_tensor, device_mesh, placements): - if device_mesh.device_type == 'cuda': - return distribute_tensor(full_tensor, device_mesh, placements) - - local_tensor = full_tensor - for mesh_dim, placement in enumerate(placements): - if isinstance(placement, Shard): - # All ranks already received the full tensor via broadcast. - # Split locally to avoid distribute_tensor's scatter path, - # which is fragile on some torch_npu/HCCL versions. 
- local_tensor = placement._shard_tensor( - local_tensor, - device_mesh, - mesh_dim, - src_data_rank=None, - ) - elif isinstance(placement, Replicate): - continue - elif isinstance(placement, Partial): - raise NotImplementedError('FSDP2 full-state loading does not support Partial placements.') - else: - raise NotImplementedError(f'Unsupported DTensor placement: {placement}') - return DTensor.from_local( - local_tensor, - device_mesh=device_mesh, - placements=placements, - run_check=False, - shape=full_tensor.shape, - stride=full_tensor.stride(), - ) - - def _load_full_value(param_name, sharded_param): - if param_name not in full_sd: - raise KeyError( - f"Parameter '{param_name}' found in sharded model state dict but missing from full state dict. " - f'Full state dict has {len(full_sd)} keys, sharded has {len(meta_sharded_sd)} keys.') - full_value = full_sd[param_name].detach() - if isinstance(full_value, DTensor): - full_value = full_value.to_local() - device = sharded_param.device_mesh.device_type if isinstance(sharded_param, DTensor) else accelerator.device - return full_value.to(device).contiguous() - - def _tensor_debug(tensor): - if isinstance(tensor, DTensor): - return (f'type=DTensor shape={tuple(tensor.size())} dtype={tensor.dtype} ' - f'placements={tensor.placements} mesh={tensor.device_mesh}') - if hasattr(tensor, 'size') and hasattr(tensor, 'dtype'): - return f'type={type(tensor).__name__} shape={tuple(tensor.size())} dtype={tensor.dtype}' - return f'type={type(tensor).__name__}' - - for param_name, sharded_param in meta_sharded_sd.items(): - if isinstance(sharded_param, DTensor): - device_mesh = sharded_param.device_mesh - placements = sharded_param.placements - if accelerator.is_main_process: - full_param = _load_full_value(param_name, sharded_param) - else: - full_param = torch.empty( - sharded_param.size(), - device=device_mesh.device_type, - dtype=sharded_param.dtype, - ) - - dist.broadcast(full_param, src=0, group=dist.group.WORLD) - sharded_tensor 
= _dtensor_from_replicated_full_tensor(full_param, device_mesh, placements) - to_contiguous, casting_dtype = _infer_parameter_dtype(model, param_name, full_param) - sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype) - if cpu_offload: - sharded_tensor = sharded_tensor.to('cpu') - sharded_sd[param_name] = sharded_tensor - continue - - if accelerator.is_main_process: - full_value = _load_full_value(param_name, sharded_param) - else: - full_value = torch.empty( - sharded_param.size(), - device=accelerator.device, - dtype=sharded_param.dtype, - ) - - dist.broadcast(full_value, src=0, group=dist.group.WORLD) - to_contiguous, casting_dtype = _infer_parameter_dtype(model, param_name, full_value) - full_value = _cast_and_contiguous(full_value, to_contiguous, casting_dtype) - if cpu_offload: - full_value = full_value.to('cpu') - sharded_sd[param_name] = full_value - - model.load_state_dict(sharded_sd, assign=True) - return model - - patched_fsdp2_load_full_state_dict._twinkle_patched = True - patched_fsdp2_load_full_state_dict._twinkle_original = original - fsdp_utils.fsdp2_load_full_state_dict = patched_fsdp2_load_full_state_dict - - -def _get_state_dict_param_for_dtype_inference(model, param_name: str): - try: - return model.get_parameter_or_buffer(param_name) - except AttributeError: - if '.' in param_name: - base_param_name, param_name = param_name.rsplit('.', 1) - model = model.get_submodule(base_param_name) - return getattr(model, param_name) - - class AccelerateStrategy: """A training strategy that uses `accelerate` to wrap models. 
@@ -172,8 +28,6 @@ def __init__( from accelerate import Accelerator from accelerate.utils import InitProcessGroupKwargs - _patch_accelerate_fsdp2_load_full_state_dict() - self.device_mesh = device_mesh self.mixed_precision = mixed_precision self._memory_efficient_init = memory_efficient_init @@ -349,3 +203,21 @@ def get_full_state_dict(self, model) -> dict: state_dict[name] = local.cpu() del local return state_dict + + def get_adapter_state_dict(self, model, adapter_name: str) -> dict: + """Collect only LoRA adapter parameters.""" + from twinkle.utils import torch_util + unwrapped = self.unwrap_model(model) + state_dict = {} + adapter_suffix = f'.{adapter_name}.' + for name, param in unwrapped.named_parameters(): + if not _is_lora_state_key(name) or adapter_suffix not in name: + continue + local = torch_util.to_local_tensor(param) + state_dict[name] = local.cpu() + del local + return state_dict + + +def _is_lora_state_key(name: str) -> bool: + return '.lora_' in name or 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py index b695f50b..e1e21ff4 100644 --- a/src/twinkle/model/transformers/strategy/native_fsdp.py +++ b/src/twinkle/model/transformers/strategy/native_fsdp.py @@ -1,4 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
+import os import torch import torch.distributed as dist from torch import nn @@ -53,8 +54,11 @@ def capture_pre_ep_state_if_needed(self, model, *, enable_ep: bool) -> None: return if not (enable_ep and self.use_rank0_pretrained_broadcast()): return - is_rank0 = dist.is_available() and dist.is_initialized() and dist.get_rank() == 0 - self.set_rank0_pre_ep_full_state_dict(clone_state_dict_to_cpu(model.state_dict()) if is_rank0 else {}) + local_rank = Platform.get_local_rank() + if local_rank < 0: + raise RuntimeError('Native FSDP node-local pre-EP state capture requires LOCAL_RANK.') + is_source_rank = dist.is_available() and dist.is_initialized() and local_rank == 0 + self.set_rank0_pre_ep_full_state_dict(clone_state_dict_to_cpu(model.state_dict()) if is_source_rank else {}) self._pre_ep_state_captured = True def prepare_adapter_config(self, config_or_dir, *, enable_ep: bool): @@ -131,15 +135,19 @@ def wrap_model(self, model, optimizer=None): adapter_source_sd = {} adapter_full_sd = {} if use_meta: - is_rank0 = (dist.get_rank() == 0) + local_rank = Platform.get_local_rank() + if local_rank < 0: + raise RuntimeError('Native FSDP node-local state loading requires LOCAL_RANK.') + is_source_rank = local_rank == 0 if ep_enabled and self._rank0_pre_ep_full_state_dict is not None: - original_sd = self._rank0_pre_ep_full_state_dict if is_rank0 else {} + original_sd = self._rank0_pre_ep_full_state_dict if is_source_rank else {} else: - original_sd = model.state_dict() if is_rank0 else {} + original_sd = model.state_dict() if is_source_rank else {} adapter_source_sd = _collect_adapter_source_state(model.state_dict()) - adapter_full_sd = self._adapter_full_state_dict if is_rank0 and self._adapter_full_state_dict else {} - saved_buffers = _get_non_persistent_buffers(model) if is_rank0 else {} - if is_rank0: + adapter_full_sd = ( + self._adapter_full_state_dict if is_source_rank and self._adapter_full_state_dict else {}) + saved_buffers = _get_non_persistent_buffers(model) if 
is_source_rank else {} + if is_source_rank: model = model.to(torch.device('meta')) if hasattr(model, 'tie_weights'): model.tie_weights() @@ -342,6 +350,39 @@ def get_full_state_dict(self, model) -> dict: return state_dict + def get_adapter_state_dict(self, model, adapter_name: str) -> dict: + """Collect only LoRA adapter parameters, with EP-aware all-gather.""" + unwrapped = self.unwrap_model(model) + state_dict = {} + + ep_fsdp_mesh = self.ep_fsdp_device_mesh + ep_group = None + ep_world_size = 1 + if ep_fsdp_mesh is not None: + ep_group = ep_fsdp_mesh['ep'].get_group() + ep_world_size = ep_fsdp_mesh['ep'].size() + + ep_expert_names = _detect_ep_expert_names(unwrapped) if ep_world_size > 1 else set() + adapter_suffix = f'.{adapter_name}.' + + for name, param in unwrapped.named_parameters(): + if not _is_lora_state_key(name) or adapter_suffix not in name: + continue + + local_full = torch_util.to_local_tensor(param) + if name in ep_expert_names and ep_world_size > 1 and ep_group is not None: + local_full = local_full.contiguous().to(Platform.get_local_device()) + gathered = [torch.empty_like(local_full) for _ in range(ep_world_size)] + dist.all_gather(gathered, local_full, group=ep_group) + local_full = torch.cat(gathered, dim=_ep_expert_state_dict_gather_dim(name)) + state_dict[name] = local_full.cpu() + del gathered, local_full + else: + state_dict[name] = local_full.cpu() + del local_full + + return state_dict + def _detect_ep_expert_names(model: nn.Module) -> Set[str]: candidate_names = set() @@ -534,6 +575,25 @@ def _build_rank_to_ep_rank(ep_fsdp_device_mesh: Optional[TorchDeviceMesh]) -> Di return rank_to_ep_rank +def _get_local_rank_info() -> tuple[int, int, int, List[int]]: + """Return local-rank topology for node-local state-dict fanout.""" + rank = dist.get_rank() + world_size = dist.get_world_size() + local_rank = Platform.get_local_rank() + if 'LOCAL_WORLD_SIZE' not in os.environ and 'LOCAL_SIZE' not in os.environ: + raise RuntimeError('Native FSDP 
node-local state loading requires LOCAL_WORLD_SIZE or LOCAL_SIZE.') + local_world_size = Platform.get_local_world_size() + if local_rank < 0 or local_world_size <= 0 or world_size % local_world_size != 0: + raise RuntimeError(f'Invalid local rank topology: rank={rank}, world_size={world_size}, ' + f'local_rank={local_rank}, local_world_size={local_world_size}.') + node_start = rank - local_rank + node_ranks = list(range(node_start, min(node_start + local_world_size, world_size))) + if rank not in node_ranks or len(node_ranks) != local_world_size: + raise RuntimeError(f'Invalid local rank group: rank={rank}, local_rank={local_rank}, ' + f'local_world_size={local_world_size}, node_ranks={node_ranks}.') + return rank, world_size, node_start, node_ranks + + def _find_experts_in_layer(layer_mod: nn.Module, experts_map: Dict[str, nn.Module]) -> Optional[nn.Module]: """Find the experts module inside a decoder layer, if any.""" for module in layer_mod.modules(): @@ -793,7 +853,11 @@ def _broadcast_sharded_state_dict( meta_sharded_sd = model.state_dict() sharded_sd = {} - is_rank0 = (dist.get_rank() == 0) + rank, _, local_source_rank, local_ranks = _get_local_rank_info() + is_rank0 = (rank == 0) + is_source_rank = rank == local_source_rank + use_local_broadcast = Platform.device_backend() != 'hccl' + local_group = dist.new_group(ranks=local_ranks, use_local_synchronization=True) if use_local_broadcast else None expert_shard_specs = expert_shard_specs or {} rank_to_ep_rank = rank_to_ep_rank or {} adapter_source_sd = adapter_source_sd or {} @@ -819,6 +883,23 @@ def _broadcast_sharded_state_dict( source_keys = metadata_holder[1] or {} adapter_metadata = metadata_holder[2] or {} + def _broadcast_from_local_source(full_tensor): + if is_source_rank: + if full_tensor is None: + raise RuntimeError(f'Local source rank {local_source_rank} does not have full state_dict tensor.') + if use_local_broadcast: + dist.broadcast(full_tensor, src=local_source_rank, group=local_group) + return full_tensor + + if 
is_source_rank: + for target_rank in local_ranks: + if target_rank == rank: + continue + dist.send(full_tensor, dst=target_rank) + else: + dist.recv(full_tensor, src=local_source_rank) + return full_tensor + def _dtensor_from_replicated_full_tensor(full_tensor, device_mesh, placements): local_tensor = full_tensor for mesh_dim, placement in enumerate(placements): @@ -849,36 +930,20 @@ def _dtensor_from_replicated_full_tensor(full_tensor, device_mesh, placements): ) def _broadcast_adapter_source_tensor(full_tensor, sharded_param): - if not isinstance(sharded_param, DTensor): - dist.broadcast(full_tensor, src=0) - return full_tensor - mesh = sharded_param.device_mesh.mesh - source_rank = int(mesh.flatten()[0].item()) - dist.broadcast(full_tensor, src=source_rank, group=sharded_param.device_mesh.get_group()) - return full_tensor + return _broadcast_from_local_source(full_tensor) def _scatter_ep_adapter_tensor(param_name, full_tensor, sharded_param): local_shape = tuple(sharded_param.size()) _, source_dtype = adapter_metadata[param_name] local_tensor = torch.empty(local_shape, device=device_type, dtype=source_dtype) - - if is_rank0: - shard_dim = _ep_expert_state_dict_gather_dim(param_name) - local_dim = local_shape[shard_dim] - world_size = dist.get_world_size() - for rank in range(world_size): - if rank not in rank_to_ep_rank: - raise RuntimeError(f'Missing EP rank mapping for global rank {rank}.') - ep_rank = rank_to_ep_rank[rank] - start = ep_rank * local_dim - chunk = full_tensor.narrow(shard_dim, start, local_dim).contiguous().to(device_type) - if rank == 0: - local_tensor.copy_(chunk) - else: - dist.send(chunk, dst=rank) - else: - dist.recv(local_tensor, src=0) - + shard_dim = _ep_expert_state_dict_gather_dim(param_name) + local_dim = local_shape[shard_dim] + local_tensor = _scatter_ep_tensor_from_source( + full_tensor, + local_tensor, + shard_dim=shard_dim, + shard_size=local_dim, + ) return local_tensor def _get_adapter_source(param_name): @@ -903,27 +968,35 @@ 
def _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param): _, source_dtype = source_metadata[param_name] local_tensor = torch.empty(local_shape, device=device_type, dtype=source_dtype) - if is_rank0: + if is_source_rank: if full_tensor.size(0) != num_experts: raise RuntimeError(f"EP expert parameter '{param_name}' expects {num_experts} experts, " f'but source state has shape {tuple(full_tensor.shape)}. ' 'Rank0 must capture the full pre-EP state_dict before apply_expert_parallel().') - world_size = dist.get_world_size() - for rank in range(world_size): - if rank not in rank_to_ep_rank: - raise RuntimeError(f'Missing EP rank mapping for global rank {rank}.') - ep_rank = rank_to_ep_rank[rank] - start = ep_rank * experts_per_rank - end = start + experts_per_rank - chunk = full_tensor[start:end].contiguous() - chunk_gpu = chunk.to(device_type) - if rank == 0: - local_tensor.copy_(chunk_gpu) + local_tensor = _scatter_ep_tensor_from_source( + full_tensor, + local_tensor, + shard_dim=0, + shard_size=experts_per_rank, + ) + return local_tensor + + def _scatter_ep_tensor_from_source(full_tensor, local_tensor, *, shard_dim: int, shard_size: int): + if is_source_rank: + if full_tensor is None: + raise RuntimeError(f'Local source rank {local_source_rank} does not have full state_dict tensor.') + for target_rank in local_ranks: + if target_rank not in rank_to_ep_rank: + raise RuntimeError(f'Missing EP rank mapping for global rank {target_rank}.') + ep_rank = rank_to_ep_rank[target_rank] + start = ep_rank * shard_size + chunk = full_tensor.narrow(shard_dim, start, shard_size).contiguous().to(device_type) + if target_rank == rank: + local_tensor.copy_(chunk) else: - dist.send(chunk_gpu, dst=rank) + dist.send(chunk, dst=target_rank) else: - dist.recv(local_tensor, src=0) - + dist.recv(local_tensor, src=local_source_rank) return local_tensor for param_name, sharded_param in meta_sharded_sd.items(): @@ -950,7 +1023,7 @@ def _scatter_ep_expert_tensor(param_name, 
full_tensor, sharded_param): full_tensor = torch.empty(source_shape, device=device_type, dtype=source_dtype) if not is_ep_adapter_param: full_tensor = _broadcast_adapter_source_tensor(full_tensor, sharded_param) - elif is_rank0: + elif is_source_rank: source_key = source_keys[param_name] if source_key not in full_sd: raise KeyError( @@ -979,7 +1052,7 @@ def _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param): raise RuntimeError(f"Parameter '{param_name}' shape mismatch before broadcast: " f'sharded logical shape={tuple(shape)}, source shape={source_shape}.') if not is_adapter_param: - dist.broadcast(full_tensor, src=0) + full_tensor = _broadcast_from_local_source(full_tensor) torch_util.synchronize() if isinstance(sharded_param, DTensor): diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index a9a80b8f..e8b8c725 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -209,7 +209,12 @@ def __init__( def _should_init_empty_pretrained_model_on_this_rank(self) -> bool: use_rank0_broadcast = getattr(self.strategy, 'use_rank0_pretrained_broadcast', lambda: False) - return bool(use_rank0_broadcast() and dist.is_available() and dist.is_initialized() and dist.get_rank() != 0) + if not (use_rank0_broadcast() and dist.is_available() and dist.is_initialized()): + return False + local_rank = Platform.get_local_rank() + if local_rank < 0: + raise RuntimeError('Native FSDP memory_efficient_init requires LOCAL_RANK.') + return local_rank != 0 def _init_empty_model_from_config(self, model_cls, **kwargs): from accelerate import init_empty_weights @@ -907,15 +912,10 @@ def save(self, name: Optional[str] = None, output_dir: Optional[str] = None, int # Full model save processed_state_dict = self.strategy.get_full_state_dict(self.model) else: - # LoRA adapter save (EP-aware via strategy.get_full_state_dict) - full_state = 
self.strategy.get_full_state_dict(self.model) - adapter_marker = '.lora_' + # LoRA adapter save. Avoid collecting the full base model for large FSDP/EP jobs. + adapter_state = self.strategy.get_adapter_state_dict(self.model, adapter_name) adapter_suffix = f'.{adapter_name}.' - for key, value in full_state.items(): - if adapter_marker not in key: - continue - if adapter_suffix not in key: - continue + for key, value in adapter_state.items(): normalized = key.replace(adapter_suffix, '.') processed_state_dict[normalized] = value