Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 0 additions & 155 deletions cookbook/transformers/deepseek_v4_flash.py

This file was deleted.

6 changes: 0 additions & 6 deletions cookbook/transformers/deepseek_v4_flash.sh

This file was deleted.

5 changes: 3 additions & 2 deletions cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,12 @@
RESUME_ONLY_MODEL = os.environ.get('RESUME_ONLY_MODEL', '0') == '1'
IGNORE_DATA_SKIP = os.environ.get('IGNORE_DATA_SKIP', '0') == '1'
ADAPTER_NAME = os.environ.get('ADAPTER_NAME', 'default')
NUM_GPUS = int(os.environ.get('NUM_GPUS', '8'))

device_mesh = DeviceMesh.from_sizes(
fsdp_size=8,
fsdp_size=NUM_GPUS,
dp_size=1,
ep_size=8,
ep_size=NUM_GPUS,
device_type=Platform.get_platform().device_prefix(),
)
twinkle.initialize(mode='local', global_device_mesh=device_mesh)
Expand Down
26 changes: 26 additions & 0 deletions cookbook/transformers/ep_fsdp2_lora_deepseek_v4_multinode.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@

# `deepseek-ai/DeepSeek-V4-Flash` uses mixed FP4/FP8 weights.
# Convert the checkpoint before training by following:
# https://gitcode.com/cann/cann-recipes-train/blob/master/llm_pretrain/deepseekv4/README.md#%E6%A8%A1%E5%9E%8B%E6%9D%83%E9%87%8D%E5%87%86%E5%A4%87
# Install `transformers==5.8.0` before running this cookbook.

export DSV4_MODEL_ID="ms://deepseek-ai/DeepSeek-V4-Flash-bf16"
export DATASET_ID="ms://swift/self-cognition"
# The following environment variables are required for multi-node training. Adjust the values according to your cluster setup.
export GLOO_SOCKET_IFNAME="eth0" # Use ifconfig to check the network interface name
export HCCL_SOCKET_IFNAME="eth0"
export HCCL_EXEC_TIMEOUT=1200
export HCCL_CONNECT_TIMEOUT=1200
export NNODES=4
export NUM_GPUS=64
export MASTER_ADDR="node0" # Replace with the IP address or hostname of the master node
export MASTER_PORT=29500 # Replace with an open port on the master node
export HCCL_IF_BASE_PORT=20000

torchrun --nnodes=$NNODES --node_rank=$NODE_RANK --nproc_per_node=16 \
--master_addr=$MASTER_ADDR --master_port=$MASTER_PORT ep_fsdp2_lora_deepseek_v4.py

# NODE_RANK=0 OUTPUT_DIR=./output sh ep_fsdp2_lora_deepseek_v4_multinode.sh
# NODE_RANK=1 OUTPUT_DIR=./output sh ep_fsdp2_lora_deepseek_v4_multinode.sh
# NODE_RANK=2 OUTPUT_DIR=./output sh ep_fsdp2_lora_deepseek_v4_multinode.sh
# NODE_RANK=3 OUTPUT_DIR=./output sh ep_fsdp2_lora_deepseek_v4_multinode.sh
1 change: 0 additions & 1 deletion src/twinkle/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ def _try_init_process_group(self):
'init_method': 'env://',
'rank': Platform.get_rank(),
'world_size': Platform.get_world_size(),
'timeout': timedelta(seconds=int(os.environ.get('TWINKLE_DIST_TIMEOUT_SECONDS', '7200'))),
}
if self._should_bind_device_id_for_process_group(backend):
init_kwargs['device_id'] = torch.device(Platform.get_local_device())
Expand Down
164 changes: 18 additions & 146 deletions src/twinkle/model/transformers/strategy/accelerate.py
Comment thread
tpx818 marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -7,150 +7,6 @@
from .load_context import fsdp_pretrained_load_context


def _patch_accelerate_fsdp2_load_full_state_dict():
"""Allow Accelerate FSDP2 state-dict loading to handle unsharded buffers.

Some Transformers models keep persistent buffers in `state_dict`. FSDP2
shards parameters as DTensors, but those buffers can remain ordinary
tensors; older Accelerate versions assume every state-dict entry has
`device_mesh` and fail on such buffers.
"""
import accelerate.utils.fsdp_utils as fsdp_utils
import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor, Partial, Replicate, Shard, distribute_tensor

if getattr(fsdp_utils.fsdp2_load_full_state_dict, '_twinkle_patched', False):
return

original = fsdp_utils.fsdp2_load_full_state_dict

def patched_fsdp2_load_full_state_dict(accelerator, model, full_sd, cpu_offload=False):
meta_sharded_sd = model.state_dict()
sharded_sd = {}

def _infer_parameter_dtype(model, param_name, empty_param):
old_param = _get_state_dict_param_for_dtype_inference(model, param_name)
is_torch_e4m3fn_available = hasattr(torch, 'float8_e4m3fn')
is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
casting_dtype = None
if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
casting_dtype = old_param.dtype
return old_param is not None and old_param.is_contiguous(), casting_dtype

def _cast_and_contiguous(tensor, to_contiguous, dtype):
if dtype is not None:
tensor = tensor.to(dtype=dtype)
if to_contiguous:
tensor = tensor.contiguous()
return tensor

def _dtensor_from_replicated_full_tensor(full_tensor, device_mesh, placements):
if device_mesh.device_type == 'cuda':
return distribute_tensor(full_tensor, device_mesh, placements)

local_tensor = full_tensor
for mesh_dim, placement in enumerate(placements):
if isinstance(placement, Shard):
# All ranks already received the full tensor via broadcast.
# Split locally to avoid distribute_tensor's scatter path,
# which is fragile on some torch_npu/HCCL versions.
local_tensor = placement._shard_tensor(
local_tensor,
device_mesh,
mesh_dim,
src_data_rank=None,
)
elif isinstance(placement, Replicate):
continue
elif isinstance(placement, Partial):
raise NotImplementedError('FSDP2 full-state loading does not support Partial placements.')
else:
raise NotImplementedError(f'Unsupported DTensor placement: {placement}')
return DTensor.from_local(
local_tensor,
device_mesh=device_mesh,
placements=placements,
run_check=False,
shape=full_tensor.shape,
stride=full_tensor.stride(),
)

def _load_full_value(param_name, sharded_param):
if param_name not in full_sd:
raise KeyError(
f"Parameter '{param_name}' found in sharded model state dict but missing from full state dict. "
f'Full state dict has {len(full_sd)} keys, sharded has {len(meta_sharded_sd)} keys.')
full_value = full_sd[param_name].detach()
if isinstance(full_value, DTensor):
full_value = full_value.to_local()
device = sharded_param.device_mesh.device_type if isinstance(sharded_param, DTensor) else accelerator.device
return full_value.to(device).contiguous()

def _tensor_debug(tensor):
if isinstance(tensor, DTensor):
return (f'type=DTensor shape={tuple(tensor.size())} dtype={tensor.dtype} '
f'placements={tensor.placements} mesh={tensor.device_mesh}')
if hasattr(tensor, 'size') and hasattr(tensor, 'dtype'):
return f'type={type(tensor).__name__} shape={tuple(tensor.size())} dtype={tensor.dtype}'
return f'type={type(tensor).__name__}'

for param_name, sharded_param in meta_sharded_sd.items():
if isinstance(sharded_param, DTensor):
device_mesh = sharded_param.device_mesh
placements = sharded_param.placements
if accelerator.is_main_process:
full_param = _load_full_value(param_name, sharded_param)
else:
full_param = torch.empty(
sharded_param.size(),
device=device_mesh.device_type,
dtype=sharded_param.dtype,
)

dist.broadcast(full_param, src=0, group=dist.group.WORLD)
sharded_tensor = _dtensor_from_replicated_full_tensor(full_param, device_mesh, placements)
to_contiguous, casting_dtype = _infer_parameter_dtype(model, param_name, full_param)
sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype)
if cpu_offload:
sharded_tensor = sharded_tensor.to('cpu')
sharded_sd[param_name] = sharded_tensor
continue

if accelerator.is_main_process:
full_value = _load_full_value(param_name, sharded_param)
else:
full_value = torch.empty(
sharded_param.size(),
device=accelerator.device,
dtype=sharded_param.dtype,
)

dist.broadcast(full_value, src=0, group=dist.group.WORLD)
to_contiguous, casting_dtype = _infer_parameter_dtype(model, param_name, full_value)
full_value = _cast_and_contiguous(full_value, to_contiguous, casting_dtype)
if cpu_offload:
full_value = full_value.to('cpu')
sharded_sd[param_name] = full_value

model.load_state_dict(sharded_sd, assign=True)
return model

patched_fsdp2_load_full_state_dict._twinkle_patched = True
patched_fsdp2_load_full_state_dict._twinkle_original = original
fsdp_utils.fsdp2_load_full_state_dict = patched_fsdp2_load_full_state_dict


def _get_state_dict_param_for_dtype_inference(model, param_name: str):
try:
return model.get_parameter_or_buffer(param_name)
except AttributeError:
if '.' in param_name:
base_param_name, param_name = param_name.rsplit('.', 1)
model = model.get_submodule(base_param_name)
return getattr(model, param_name)


class AccelerateStrategy:
"""A training strategy that uses `accelerate` to wrap models.

Expand All @@ -172,8 +28,6 @@ def __init__(
from accelerate import Accelerator
from accelerate.utils import InitProcessGroupKwargs

_patch_accelerate_fsdp2_load_full_state_dict()

self.device_mesh = device_mesh
self.mixed_precision = mixed_precision
self._memory_efficient_init = memory_efficient_init
Expand Down Expand Up @@ -349,3 +203,21 @@ def get_full_state_dict(self, model) -> dict:
state_dict[name] = local.cpu()
del local
return state_dict

def get_adapter_state_dict(self, model, adapter_name: str) -> dict:
"""Collect only LoRA adapter parameters."""
from twinkle.utils import torch_util
unwrapped = self.unwrap_model(model)
state_dict = {}
adapter_suffix = f'.{adapter_name}.'
for name, param in unwrapped.named_parameters():
if not _is_lora_state_key(name) or adapter_suffix not in name:
continue
local = torch_util.to_local_tensor(param)
state_dict[name] = local.cpu()
del local
return state_dict


def _is_lora_state_key(name: str) -> bool:
return 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name
Loading
Loading