8 changes: 6 additions & 2 deletions README.md
@@ -166,8 +166,12 @@ supported on Twinkle✨ framework.
 | | [deepseek-ai/DeepSeek-Prover-V2-7B](https://modelscope.cn/models/deepseek-ai/DeepSeek-Prover-V2-7B) | - | transformers>=4.39.3 | ✔ | [deepseek-ai/DeepSeek-Prover-V2-7B](https://huggingface.co/deepseek-ai/DeepSeek-Prover-V2-7B) |
 | | [deepseek-ai/DeepSeek-R1](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1) | - | transformers>=4.39.3 | ✔ | [deepseek-ai/DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1) |
 | deepSeek-r1-distill | [deepseek-ai/DeepSeek-R1-Distill-Qwen-7B](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B) | 1.5B/7B/14B/32B | transformers>=4.37 | ✔ | [deepseek-ai/DeepSeek-R1-Distill-Qwen-7B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B) |
-| DeepSeek V4 full series | [deepseek-ai/DeepSeek-V4-Flash](https://modelscope.cn/models/deepseek-ai/DeepSeek-V4-Flash) | 284B| transformers>=5.8.0 | ✔ | [deepseek-ai/DeepSeek-V4-Flash](https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash)
-| | [deepseek-ai/DeepSeek-V4-Pro](https://modelscope.cn/models/deepseek-ai/DeepSeek-V4-Pro) | 1.6T| transformers>=5.8.0 | ✔ | [deepseek-ai/DeepSeek-V4-Pro](https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro)
+| DeepSeek V4 full series | [deepseek-ai/DeepSeek-V4-Flash](https://modelscope.cn/models/deepseek-ai/DeepSeek-V4-Flash) | 284B| transformers>=5.8.0 | ✔ | [deepseek-ai/DeepSeek-V4-Flash](https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash)|
+| | [deepseek-ai/DeepSeek-V4-Pro](https://modelscope.cn/models/deepseek-ai/DeepSeek-V4-Pro) | 1.6T| transformers>=5.8.0 | ✔ | [deepseek-ai/DeepSeek-V4-Pro](https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro)|
+| Gemma4 full series | [google/gemma-4-E2B](https://www.modelscope.cn/models/google/gemma-4-E2B) | 2.3B effective (5.1B with embeddings) | transformers>=5.8.0 | ✘ | [google/gemma-4-E2B · Hugging Face](https://huggingface.co/google/gemma-4-E2B) |
+| | [google/gemma-4-E4B](https://www.modelscope.cn/models/google/gemma-4-E4B) | 4.5B effective (8B with embeddings) | transformers>=5.8.0 | ✘ | [google/gemma-4-E4B · Hugging Face](https://huggingface.co/google/gemma-4-E4B) |
+| | [google/gemma-4-31B](https://www.modelscope.cn/models/google/gemma-4-31B) | 30.7B | transformers>=5.8.0 | ✘ | [google/gemma-4-31B · Hugging Face](https://huggingface.co/google/gemma-4-31B) |
+| | [google/gemma-4-26B-A4B](https://www.modelscope.cn/models/google/gemma-4-26B-A4B) | 25.2B (Active 3.8B) | transformers>=5.8.0 | ✘ | [google/gemma-4-26B-A4B · Hugging Face](https://huggingface.co/google/gemma-4-26B-A4B) |

 ## Sample Code

8 changes: 6 additions & 2 deletions README_ZH.md
@@ -156,8 +156,12 @@ Twinkle✨ supports running the same algorithm interface on single-GPU, multi-node torchrun, Ray, Cl
 | | [deepseek-ai/DeepSeek-Prover-V2-7B](https://modelscope.cn/models/deepseek-ai/DeepSeek-Prover-V2-7B) | - | transformers>=4.39.3 | ✔ | [deepseek-ai/DeepSeek-Prover-V2-7B](https://huggingface.co/deepseek-ai/DeepSeek-Prover-V2-7B) |
 | | [deepseek-ai/DeepSeek-R1](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1) | - | transformers>=4.39.3 | ✔ | [deepseek-ai/DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1) |
 | deepSeek-r1-distill | [deepseek-ai/DeepSeek-R1-Distill-Qwen-7B](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B) | 1.5B/7B/14B/32B | transformers>=4.37 | ✔ | [deepseek-ai/DeepSeek-R1-Distill-Qwen-7B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B) |
-| DeepSeek V4 full series | [deepseek-ai/DeepSeek-V4-Flash](https://modelscope.cn/models/deepseek-ai/DeepSeek-V4-Flash) | 284B| transformers>=5.8.0 | ✔ | [deepseek-ai/DeepSeek-V4-Flash](https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash)
-| | [deepseek-ai/DeepSeek-V4-Pro](https://modelscope.cn/models/deepseek-ai/DeepSeek-V4-Pro) | 1.6T| transformers>=5.8.0 | ✔ | [deepseek-ai/DeepSeek-V4-Pro](https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro)
+| DeepSeek V4 full series | [deepseek-ai/DeepSeek-V4-Flash](https://modelscope.cn/models/deepseek-ai/DeepSeek-V4-Flash) | 284B| transformers>=5.8.0 | ✔ | [deepseek-ai/DeepSeek-V4-Flash](https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash)|
+| | [deepseek-ai/DeepSeek-V4-Pro](https://modelscope.cn/models/deepseek-ai/DeepSeek-V4-Pro) | 1.6T| transformers>=5.8.0 | ✔ | [deepseek-ai/DeepSeek-V4-Pro](https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro)|
+| Gemma4 full series | [google/gemma-4-E2B](https://www.modelscope.cn/models/google/gemma-4-E2B) | 2.3B effective (5.1B with embeddings) | transformers>=5.8.0 | ✘ | [google/gemma-4-E2B · Hugging Face](https://huggingface.co/google/gemma-4-E2B) |
+| | [google/gemma-4-E4B](https://www.modelscope.cn/models/google/gemma-4-E4B) | 4.5B effective (8B with embeddings) | transformers>=5.8.0 | ✘ | [google/gemma-4-E4B · Hugging Face](https://huggingface.co/google/gemma-4-E4B) |
+| | [google/gemma-4-31B](https://www.modelscope.cn/models/google/gemma-4-31B) | 30.7B | transformers>=5.8.0 | ✘ | [google/gemma-4-31B · Hugging Face](https://huggingface.co/google/gemma-4-31B) |
+| | [google/gemma-4-26B-A4B](https://www.modelscope.cn/models/google/gemma-4-26B-A4B) | 25.2B (Active 3.8B) | transformers>=5.8.0 | ✘ | [google/gemma-4-26B-A4B · Hugging Face](https://huggingface.co/google/gemma-4-26B-A4B) |

 ## Sample Code

171 changes: 171 additions & 0 deletions cookbook/mm/fsdp2_gemma4_mm.py
@@ -0,0 +1,171 @@
import os
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoConfig, Gemma4Config

import twinkle
from twinkle import DeviceMesh, Platform, get_device_placement, get_logger
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.model import TransformersModel
# from twinkle.preprocessor import SelfCognitionProcessor, LatexOCRProcessor

logger = get_logger()

########## Construct a device_mesh ##########
device_mesh = DeviceMesh.from_sizes(
    # fsdp_size=2,
    # dp_size=1,
    # ep_size=2,
    device_type=Platform.get_platform().device_prefix(),
)
# use torchrun mode
twinkle.initialize(mode='local', global_device_mesh=device_mesh)

########## hyperparameters ##########
IGNORE_MISMATCHED_SIZES = True
MODEL_PATH = 'ms://google/gemma-4-26b-a4b'
DATASET_PATH = 'ms://AI-ModelScope/LaTeX_OCR'
TRAIN_LEN = 2000
BATCH_SIZE = 4
METRIC_STEP = 10
SAVE_STEP = 10

### reduce model layers for debug
TEXT_NUM_LAYERS = 3
VISION_NUM_LAYERS = 3


from twinkle.preprocessor import Preprocessor
from twinkle.data_format import Message, Trajectory


class LatexOCRProcessor(Preprocessor):

    def __call__(self, rows):
        rows = self.map_col_to_row(rows)
        rows = [self.preprocess(row) for row in rows]
        col = self.map_row_to_col(rows)
        return col

    def preprocess(self, row) -> Trajectory:
        return Trajectory(
            messages=[
                Message(role='user', content='<image>Using LaTeX to perform OCR on the image.', images=[row['image']]),
                Message(role='assistant', content=row['text']),
            ]
        )


def eval(model, eval_dataloader):
    for step, batch in tqdm(enumerate(eval_dataloader)):
        model.forward_only(inputs=batch)
        model.calculate_loss()
    metrics = model.calculate_metric(is_training=False)
    return metrics

def train():

    ### prepare dataset and dataloader
    dataset = Dataset(dataset_meta=DatasetMeta(DATASET_PATH, data_slice=range(TRAIN_LEN)))
    # Set template to prepare encoding
    dataset.set_template('Template', model_id=MODEL_PATH)
    # Preprocess the dataset to standard format
    # dataset.map(preprocess_func=SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
    dataset.map(preprocess_func=LatexOCRProcessor)
    # Encode dataset
    dataset.encode()
    dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE)

    config, kwargs = AutoConfig.from_pretrained(
        MODEL_PATH,
        trust_remote_code=True,
        return_unused_kwargs=True,
        # code_revision=code_revision,
        # _commit_hash=commit_hash,
        # **hub_kwargs,
        # **kwargs,
    )

    if isinstance(config, Gemma4Config):  # reduce layers
        text_config = config.text_config
        vision_config = config.vision_config
        if TEXT_NUM_LAYERS is not None and hasattr(text_config, 'num_hidden_layers'):
            text_config.num_hidden_layers = TEXT_NUM_LAYERS
            logger.info(f' modify > text_config.num_hidden_layers = {text_config.num_hidden_layers}')
        if VISION_NUM_LAYERS is not None and hasattr(vision_config, 'num_hidden_layers'):
            vision_config.num_hidden_layers = VISION_NUM_LAYERS
            logger.info(f' modify > vision_config.num_hidden_layers = {vision_config.num_hidden_layers}')
        if hasattr(config, 'use_cache'):
            config.use_cache = False

    # Use a TransformersModel
    model = TransformersModel(
        model_id=MODEL_PATH,
        config=config,
        device_mesh=device_mesh,
        strategy='accelerate',  # 'native_fsdp' or 'accelerate'
        ignore_mismatched_sizes=IGNORE_MISMATCHED_SIZES,
        fsdp_config={
            'reshard_after_forward': True,
            'expert_parallel': {
                'enabled': True,
                'router_dtype': 'fp32',
                'keep_router_logits': False,
            }
        },
    )

    lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear')

    # Add a lora to model, with name `default`
    # Comment this to use full-parameter training
    model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2)
    # Add Optimizer for lora `default`
    model.set_optimizer(optimizer_cls='AdamW', lr=1e-4)

    # Add LRScheduler for lora `default`
    model.set_lr_scheduler(
        scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=len(dataloader))

    logger.info(get_device_placement())
    # Print the training config
    logger.info(model.get_train_configs())
    logger.info(f'Total steps: {len(dataloader)}')
    best_eval_loss = float('inf')
    # lora: 8G * 8
    # full: 18G * 8

    ### eval dataset and dataloader
    EVAL_LENGTH = 100
    eval_dataset = Dataset(dataset_meta=DatasetMeta(DATASET_PATH, data_slice=range(EVAL_LENGTH)))
    eval_dataset.set_template('Template', model_id=MODEL_PATH)
    # eval_dataset.map(preprocess_func=SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
    eval_dataset.map(preprocess_func=LatexOCRProcessor)
    eval_dataset.encode()
    eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=8)
    for step, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
        # Do forward and backward
        model.forward_backward(inputs=batch)
        # Step
        model.clip_grad_and_step()

        if step % METRIC_STEP == 0:
            # Print metric
            metric = model.calculate_metric(is_training=True)
            logger.info(f'Current is step {step} of {len(dataloader)}, Train metric: {metric}')

        if step % SAVE_STEP == 0:
            metrics = eval(model, eval_dataloader)
            metrics['step'] = step
            if float(metrics['loss']) < best_eval_loss:
                # model.save(f'checkpoint-{step}')
                best_eval_loss = float(metrics['loss'])
            metrics['best_eval_loss'] = best_eval_loss
            logger.info(f'Current is step {step} of {len(dataloader)}, Eval metric: {metrics}')

    # model.save(f'last-checkpoint')


if __name__ == '__main__':
    train()
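
The debug-only layer reduction in `train()` can be exercised without downloading a model. The following is a minimal sketch, assuming a config object that exposes nested `text_config` and `vision_config` attributes as the script does; `reduce_layers` and the `SimpleNamespace` stand-in (with made-up layer counts) are hypothetical and not part of Twinkle or transformers.

```python
from types import SimpleNamespace


def reduce_layers(config, text_layers=None, vision_layers=None):
    """Shrink a nested config in place for quick debugging runs."""
    for sub_name, num_layers in (('text_config', text_layers), ('vision_config', vision_layers)):
        sub_config = getattr(config, sub_name, None)
        # Only touch sub-configs that exist and actually carry the field.
        if num_layers is not None and sub_config is not None and hasattr(sub_config, 'num_hidden_layers'):
            sub_config.num_hidden_layers = num_layers
    if hasattr(config, 'use_cache'):
        config.use_cache = False
    return config


# Hypothetical stand-in for a loaded multimodal config; the numbers are invented.
config = SimpleNamespace(
    text_config=SimpleNamespace(num_hidden_layers=30),
    vision_config=SimpleNamespace(num_hidden_layers=27),
    use_cache=True,
)
reduce_layers(config, text_layers=3, vision_layers=3)
print(config.text_config.num_hidden_layers, config.vision_config.num_hidden_layers, config.use_cache)
```

Passing `None` for either count leaves that sub-config untouched, which matches the `is not None` guards in the script.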
3 changes: 3 additions & 0 deletions cookbook/mm/fsdp2_gemma4_mm.sh
@@ -0,0 +1,3 @@
export CUDA_VISIBLE_DEVICES=0,1

torchrun --nnodes=1 --nproc_per_node=2 fsdp2_gemma4_mm.py
11 changes: 10 additions & 1 deletion src/twinkle/processor/base.py
@@ -49,6 +49,7 @@ class InputProcessor:
         'pixel_values_videos',
         'video_grid_thw',
         'input_features',
+        'input_features_mask',
         'feature_attention_mask',
         'grid_thws',
     }
@@ -592,7 +593,15 @@ def is_mm_position_ids(position_ids):
         for field, values in vlm_fields.items():
             if values:
                 _values = []
-                for value in values:
+                for i, value in enumerate(values):
+                    if field == 'input_features':  # [freq_bins, time_steps] -> [freq_bins, time_steps, num_features]
+                        assert len(value.shape) == 2
+                        value = value.unsqueeze(-1)
+                    if field == 'input_features_mask':  # [freq_bins,] -> [freq_bins, time_steps]
+                        assert len(value.shape) == 1
+                        input_features_shape = vlm_fields['input_features'][i].shape
+                        assert value.shape[0] == input_features_shape[0]
+                        value = value.unsqueeze(1).expand(input_features_shape[:2])
                     if value.dim() == 1:
                         # image_thw may be squeezed
                         value = value.unsqueeze(0)
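
The shape handling added for `input_features` and `input_features_mask` can be checked in isolation. The sketch below mirrors the two new branches using NumPy as a stand-in for the torch calls (`np.expand_dims` in place of `unsqueeze`, `np.broadcast_to` in place of `expand`). The sizes are arbitrary and `normalize_audio_fields` is a hypothetical helper, not a function in Twinkle.

```python
import numpy as np


def normalize_audio_fields(features, mask):
    """Add a trailing axis to 2-D features and stretch a 1-D mask
    across the second axis of those features."""
    assert features.ndim == 2
    assert mask.ndim == 1
    assert mask.shape[0] == features.shape[0]
    features_3d = np.expand_dims(features, -1)  # (a, b) -> (a, b, 1)
    mask_2d = np.broadcast_to(mask[:, None], features.shape[:2])  # (a,) -> (a, b)
    return features_3d, mask_2d


features = np.zeros((5, 7), dtype=np.float32)  # arbitrary example sizes
mask = np.array([True, True, True, False, False])
features_3d, mask_2d = normalize_audio_fields(features, mask)
print(features_3d.shape, mask_2d.shape)
```

Like `expand`, `np.broadcast_to` returns a view rather than a copy, so every column of the widened mask shares the same underlying values and the result should be treated as read-only.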
39 changes: 25 additions & 14 deletions src/twinkle/template/base.py
@@ -40,7 +40,8 @@ def __init__(self,
                  **kwargs):
         self.model_id = model_id
         model_id = HubOperation.download_model(model_id, ignore_model=True)
-        if os.path.exists(os.path.join(model_id, 'preprocessor_config.json')):
+        if os.path.exists(os.path.join(model_id, 'preprocessor_config.json')) or os.path.exists(
+                os.path.join(model_id, 'processor_config.json')):
             from transformers import AutoProcessor
             self.processor = AutoProcessor.from_pretrained(model_id, **kwargs)
         else:
@@ -55,15 +56,17 @@ def __init__(self,
         self.truncation_strategy = truncation_strategy
         self.default_system = default_system
         self._test_support_assistant_tokens_mask()
-        self.pre_pipeline: List[Callable[[Trajectory], List[Trajectory]]] = [
-            self._add_default_system,  # Add a default system field
-            self._to_standard_reasoning_content,  # Convert thinking to standard field
-            self._build_standard_messages,  # turn to standard mm messages
+
+        self.pre_pipeline_names: List[str] = [
+            '_add_default_system',
+            '_to_standard_reasoning_content',
+            '_build_standard_messages',
         ]
-        self.post_pipeline: List[Callable[[InputFeature], List[InputFeature]]] = [
-            self._check_max_length,  # Check and split input_features
-            self._add_attention_fields,  # Add useful fields
-            self._roll_labels,  # roll labels
+
+        self.post_pipeline_names: List[str] = [
+            '_check_max_length',
+            '_add_attention_fields',
+            '_roll_labels',
         ]

     def parse_tool_call(self, decoded: str) -> List[Dict[str, Any]]:
@@ -166,7 +169,8 @@ def preprocess_audios(self, audios: List[AudioInput]) -> List[np.ndarray]:

     def _invoke_pre_pipeline(self, trajectories: List[Trajectory]) -> List[Trajectory]:
         current = trajectories
-        for pipeline in self.pre_pipeline:
+        for pipeline_name in self.pre_pipeline_names:
+            pipeline: Callable[[Trajectory], List[Trajectory]] = getattr(self, pipeline_name)
             next_batch = []
             for trajectory in current:
                 next_batch.extend(pipeline(trajectory))
@@ -175,7 +179,8 @@ def _invoke_pre_pipeline(self, trajectories: List[Trajectory]) -> List[Trajector

     def _invoke_post_pipeline(self, input_features: List[InputFeature]) -> List[InputFeature]:
         current = input_features
-        for pipeline in self.post_pipeline:
+        for pipeline_name in self.post_pipeline_names:
+            pipeline: Callable[[InputFeature], List[InputFeature]] = getattr(self, pipeline_name)
             next_batch = []
             for input_feature in current:
                 next_batch.extend(pipeline(input_feature))
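
The switch from stored bound methods to method names can be illustrated with a small stand-alone class. This is a sketch only: `MiniTemplate` and its two stages are hypothetical and simply mirror the lookup-then-flatten loop used by `_invoke_pre_pipeline` and `_invoke_post_pipeline`.

```python
from typing import Callable, List


class MiniTemplate:
    """Hypothetical stand-in for a template with a name-based pipeline."""

    def __init__(self):
        # Stage names are plain strings; methods are resolved when invoked.
        self.pre_pipeline_names: List[str] = ['_strip', '_split_on_semicolon']

    def _strip(self, item: str) -> List[str]:
        return [item.strip()]

    def _split_on_semicolon(self, item: str) -> List[str]:
        return [part for part in item.split(';') if part]

    def invoke(self, items: List[str]) -> List[str]:
        current = items
        for name in self.pre_pipeline_names:
            stage: Callable[[str], List[str]] = getattr(self, name)
            next_batch: List[str] = []
            for item in current:
                next_batch.extend(stage(item))  # a stage may emit zero or more items
            current = next_batch
        return current


template = MiniTemplate()
result = template.invoke(['  a;b ', 'c'])
print(result)

# Stages can be dropped or reordered by editing the list of names.
template.pre_pipeline_names.remove('_split_on_semicolon')
stripped_only = template.invoke(['  a;b ', 'c'])
print(stripped_only)
```

One consequence of storing names is that the pipeline attribute holds only strings, so a stage can be added, removed, or reordered by editing the list, and the method is looked up on the instance each time the pipeline runs.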
@@ -467,9 +472,15 @@ def _process_mm_string_format(self, messages: List, images: List, videos: List,

     def _build_standard_messages(self, trajectory: Trajectory) -> List[Trajectory]:
         # Extract trajectory-level media
-        images = self.preprocess_images(trajectory.pop('images', None) or [])
-        videos = self.preprocess_videos(trajectory.pop('videos', None) or [])
-        audios = self.preprocess_audios(trajectory.pop('audios', None) or [])
+        extracted_images = trajectory.pop(
+            'images', None) or [img for msg in trajectory['messages'] for img in msg.get('images', []) or []]
+        extracted_videos = trajectory.pop(
+            'videos', None) or [video for msg in trajectory['messages'] for video in msg.get('videos', []) or []]
+        extracted_audios = trajectory.pop(
+            'audios', None) or [audio for msg in trajectory['messages'] for audio in msg.get('audios', []) or []]
+        images = self.preprocess_images(extracted_images)
+        videos = self.preprocess_videos(extracted_videos)
+        audios = self.preprocess_audios(extracted_audios)

         trajectory['messages'] = self._process_mm_messages(trajectory['messages'], images, videos, audios)
         if not self.is_mm:
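
The new media extraction prefers trajectory-level media and falls back to whatever is attached to individual messages. A minimal sketch of that rule, assuming trajectories and messages behave like plain dicts; `collect_media` and the file names are hypothetical and stand in for the three near-identical expressions in the hunk above.

```python
from typing import Any, Dict, List


def collect_media(trajectory: Dict[str, Any], key: str) -> List[Any]:
    """Prefer trajectory-level media; otherwise gather it from each message."""
    top_level = trajectory.pop(key, None)
    if top_level:
        return top_level
    # `or []` guards against messages whose media field is explicitly None.
    return [item for msg in trajectory['messages'] for item in msg.get(key, []) or []]


# Hypothetical trajectories; the file names are placeholders.
per_message = {
    'messages': [
        {'role': 'user', 'content': '<image>Describe this.', 'images': ['a.png']},
        {'role': 'assistant', 'content': 'A formula.', 'images': None},
    ]
}
with_top_level = {
    'images': ['b.png'],
    'messages': [{'role': 'user', 'content': 'hi', 'images': ['ignored.png']}],
}

print(collect_media(per_message, 'images'))
print(collect_media(with_top_level, 'images'))
```

Because the fallback is driven by truthiness, an empty trajectory-level list behaves the same as a missing one, and message-level media is ignored whenever the trajectory-level list is non-empty.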