diff --git a/README.md b/README.md index f98056bb..8a6c053a 100644 --- a/README.md +++ b/README.md @@ -166,8 +166,12 @@ supported on Twinkle✨ framework. | | [deepseek-ai/DeepSeek-Prover-V2-7B](https://modelscope.cn/models/deepseek-ai/DeepSeek-Prover-V2-7B) | - | transformers>=4.39.3 | ✔ | [deepseek-ai/DeepSeek-Prover-V2-7B](https://huggingface.co/deepseek-ai/DeepSeek-Prover-V2-7B) | | | [deepseek-ai/DeepSeek-R1](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1) | - | transformers>=4.39.3 | ✔ | [deepseek-ai/DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1) | | deepSeek-r1-distill | [deepseek-ai/DeepSeek-R1-Distill-Qwen-7B](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B) | 1.5B/7B/14B/32B | transformers>=4.37 | ✔ | [deepseek-ai/DeepSeek-R1-Distill-Qwen-7B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B) | -| DeepSeek V4全系列 | [deepseek-ai/DeepSeek-V4-Flash](https://modelscope.cn/models/deepseek-ai/DeepSeek-V4-Flash) | 284B| transformers>=5.8.0 | ✔ | [deepseek-ai/DeepSeek-V4-Flash](https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash) -| | [deepseek-ai/DeepSeek-V4-Pro](https://modelscope.cn/models/deepseek-ai/DeepSeek-V4-Pro) | 1.6T| transformers>=5.8.0 | ✔ | [deepseek-ai/DeepSeek-V4-Pro](https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro) +| DeepSeek V4全系列 | [deepseek-ai/DeepSeek-V4-Flash](https://modelscope.cn/models/deepseek-ai/DeepSeek-V4-Flash) | 284B| transformers>=5.8.0 | ✔ | [deepseek-ai/DeepSeek-V4-Flash](https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash)| +| | [deepseek-ai/DeepSeek-V4-Pro](https://modelscope.cn/models/deepseek-ai/DeepSeek-V4-Pro) | 1.6T| transformers>=5.8.0 | ✔ | [deepseek-ai/DeepSeek-V4-Pro](https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro)| +| Gemma4全系列 | [google/gemma-4-E2B](https://www.modelscope.cn/models/google/gemma-4-E2B) | 2.3B effective (5.1B with embeddings) | transformers>=5.8.0 | ✘ | [google/gemma-4-E2B · Hugging Face](https://huggingface.co/google/gemma-4-E2B) | +| | [google/gemma-4-E4B](https://www.modelscope.cn/models/google/gemma-4-E4B) | 4.5B effective (8B with embeddings) | transformers>=5.8.0 | ✘ | [google/gemma-4-E4B · Hugging Face](https://huggingface.co/google/gemma-4-E4B) | +| | [google/gemma-4-31B](https://www.modelscope.cn/models/google/gemma-4-31B) | 30.7B | transformers>=5.8.0 | ✘ | [google/gemma-4-31B · Hugging Face](https://huggingface.co/google/gemma-4-31B) | +| | [google/gemma-4-26B-A4B](https://www.modelscope.cn/models/google/gemma-4-26B-A4B) | 25.2B (Active 3.8B) | transformers>=5.8.0 | ✘ | [google/gemma-4-26B-A4B · Hugging Face](https://huggingface.co/google/gemma-4-26B-A4B) | ## Sample Code diff --git a/README_ZH.md b/README_ZH.md index 0a8da692..23a6466f 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -156,8 +156,12 @@ Twinkle✨支持相同的算法接口运行在单GPU、torchrun多机、Ray、Cl | | [deepseek-ai/DeepSeek-Prover-V2-7B](https://modelscope.cn/models/deepseek-ai/DeepSeek-Prover-V2-7B) | - | transformers>=4.39.3 | ✔ | [deepseek-ai/DeepSeek-Prover-V2-7B](https://huggingface.co/deepseek-ai/DeepSeek-Prover-V2-7B) | | | [deepseek-ai/DeepSeek-R1](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1) | - | transformers>=4.39.3 | ✔ | [deepseek-ai/DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1) | | deepSeek-r1-distill | [deepseek-ai/DeepSeek-R1-Distill-Qwen-7B](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B) | 1.5B/7B/14B/32B | transformers>=4.37 | ✔ | [deepseek-ai/DeepSeek-R1-Distill-Qwen-7B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B) | -| DeepSeek V4全系列 | [deepseek-ai/DeepSeek-V4-Flash](https://modelscope.cn/models/deepseek-ai/DeepSeek-V4-Flash) | 284B| transformers>=5.8.0 | ✔ | [deepseek-ai/DeepSeek-V4-Flash](https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash) -| | [deepseek-ai/DeepSeek-V4-Pro](https://modelscope.cn/models/deepseek-ai/DeepSeek-V4-Pro) | 1.6T| transformers>=5.8.0 | ✔ | [deepseek-ai/DeepSeek-V4-Pro](https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro) +| DeepSeek V4全系列 | [deepseek-ai/DeepSeek-V4-Flash](https://modelscope.cn/models/deepseek-ai/DeepSeek-V4-Flash) | 284B| transformers>=5.8.0 | ✔ | [deepseek-ai/DeepSeek-V4-Flash](https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash)| +| | [deepseek-ai/DeepSeek-V4-Pro](https://modelscope.cn/models/deepseek-ai/DeepSeek-V4-Pro) | 1.6T| transformers>=5.8.0 | ✔ | [deepseek-ai/DeepSeek-V4-Pro](https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro)| +| Gemma4全系列 | [google/gemma-4-E2B](https://www.modelscope.cn/models/google/gemma-4-E2B) | 2.3B effective (5.1B with embeddings) | transformers>=5.8.0 | ✘ | [google/gemma-4-E2B · Hugging Face](https://huggingface.co/google/gemma-4-E2B) | +| | [google/gemma-4-E4B](https://www.modelscope.cn/models/google/gemma-4-E4B) | 4.5B effective (8B with embeddings) | transformers>=5.8.0 | ✘ | [google/gemma-4-E4B · Hugging Face](https://huggingface.co/google/gemma-4-E4B) | +| | [google/gemma-4-31B](https://www.modelscope.cn/models/google/gemma-4-31B) | 30.7B | transformers>=5.8.0 | ✘ | [google/gemma-4-31B · Hugging Face](https://huggingface.co/google/gemma-4-31B) | +| | [google/gemma-4-26B-A4B](https://www.modelscope.cn/models/google/gemma-4-26B-A4B) | 25.2B (Active 3.8B) | transformers>=5.8.0 | ✘ | [google/gemma-4-26B-A4B · Hugging Face](https://huggingface.co/google/gemma-4-26B-A4B) | ## 示例代码 diff --git a/cookbook/mm/fsdp2_gemma4_mm.py b/cookbook/mm/fsdp2_gemma4_mm.py new file mode 100644 index 00000000..77805187 --- /dev/null +++ b/cookbook/mm/fsdp2_gemma4_mm.py @@ -0,0 +1,171 @@ +import os +from peft import LoraConfig +from tqdm import tqdm +from transformers import AutoConfig +from transformers import ( + Gemma4Config, +) + +import twinkle +from twinkle import DeviceMesh, Platform, get_device_placement, get_logger +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.model import TransformersModel +# from twinkle.preprocessor import SelfCognitionProcessor, LatexOCRProcessor + +logger = get_logger() + +########## Construct a device_mesh ########## +device_mesh = DeviceMesh.from_sizes( + # fsdp_size=2, + # dp_size=1, + # ep_size=2, + device_type=Platform.get_platform().device_prefix(), +) +# use torchrun mode +twinkle.initialize(mode='local', global_device_mesh=device_mesh) + +########## hyperparameters ########## +IGNORE_MISMATCHED_SIZES = True +MODEL_PATH = 'ms://google/gemma-4-26b-a4b' +DATASET_PATH = 'ms://AI-ModelScope/LaTeX_OCR' +TRAIN_LEN = 2000 +BATCH_SIZE = 4 +METRIC_STEP = 10 +SAVE_STEP = 10 + +### reduce model layers for debug +TEXT_NUM_LAYERS = 3 +VISION_NUM_LAYERS = 3 + + +from twinkle.preprocessor import Preprocessor +from twinkle.data_format import Message, Trajectory +class LatexOCRProcessor(Preprocessor): + + def __call__(self, rows): + rows = self.map_col_to_row(rows) + rows = [self.preprocess(row) for row in rows] + col = self.map_row_to_col(rows) + return col + + def preprocess(self, row) -> Trajectory: + return Trajectory( + messages=[ + Message(role='user', content='Using LaTeX to perform OCR on the image.', images=[row['image']]), + Message(role='assistant', content=row['text']), + ] + ) + +def eval(model, eval_dataloader): + for step, batch in tqdm(enumerate(eval_dataloader)): + model.forward_only(inputs=batch) + model.calculate_loss() + metrics = model.calculate_metric(is_training=False) + return metrics + +def train(): + + ### prepare dataset and dataloader + dataset = Dataset(dataset_meta=DatasetMeta(DATASET_PATH, data_slice=range(TRAIN_LEN))) + # Set template to prepare encoding + dataset.set_template('Template', model_id=MODEL_PATH) + # Preprocess the dataset to standard format + # dataset.map(preprocess_func=SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) + dataset.map(preprocess_func=LatexOCRProcessor) + # Encode dataset + dataset.encode() + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) + + config, kwargs = AutoConfig.from_pretrained( + MODEL_PATH, + trust_remote_code=True, + return_unused_kwargs=True, + # code_revision=code_revision, + # _commit_hash=commit_hash, + # **hub_kwargs, + # **kwargs, + ) + + if isinstance(config, Gemma4Config): # 减层 + text_config = config.text_config + vision_config = config.vision_config + if TEXT_NUM_LAYERS is not None and hasattr(text_config, 'num_hidden_layers'): + text_config.num_hidden_layers = TEXT_NUM_LAYERS + logger.info(f' modify > text_config.num_hidden_layers = {text_config.num_hidden_layers}') + if VISION_NUM_LAYERS is not None and hasattr(vision_config, 'num_hidden_layers'): + vision_config.num_hidden_layers = VISION_NUM_LAYERS + logger.info(f' modify > vision_config.num_hidden_layers = {vision_config.num_hidden_layers}') + if hasattr(config, 'use_cache'): + config.use_cache = False + + # Use a TransformersModel + model = TransformersModel( + model_id=MODEL_PATH, + config=config, + device_mesh=device_mesh, + strategy='accelerate', # native_fsdp、 accelerate + ignore_mismatched_sizes=IGNORE_MISMATCHED_SIZES, + fsdp_config={ + 'reshard_after_forward': True, + 'expert_parallel': { + 'enabled': True, + 'router_dtype': 'fp32', + 'keep_router_logits': False, + } + }, + ) + + lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') + + # Add a lora to model, with name `default` + # Comment this to use full-parameter training + model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2) + # Add Optimizer for lora `default` + model.set_optimizer(optimizer_cls='AdamW', lr=1e-4) + + # Add LRScheduler for lora `default` + model.set_lr_scheduler( + scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=len(dataloader)) + + logger.info(get_device_placement()) + # Print the training config + logger.info(model.get_train_configs()) + logger.info(f'Total steps: {len(dataloader)}') + best_eval_loss = float('inf') + # lora: 8G * 8 + # full: 18G * 8 + + ### eval dataset and dataloader + EVAL_LENGTH = 100 + eval_dataset = Dataset(dataset_meta=DatasetMeta(DATASET_PATH, data_slice=range(EVAL_LENGTH))) + eval_dataset.set_template('Template', model_id=MODEL_PATH) + # eval_dataset.map(preprocess_func=SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) + eval_dataset.map(preprocess_func=LatexOCRProcessor) + eval_dataset.encode() + eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=8) + for step, batch in tqdm(enumerate(dataloader), total=len(dataloader)): + # Do forward and backward + model.forward_backward(inputs=batch) + # Step + model.clip_grad_and_step() + + if step % METRIC_STEP == 0: + # Print metric + metric = model.calculate_metric(is_training=True) + logger.info(f'Current is step {step} of {len(dataloader)}, Train metric: {metric}') + + if step % SAVE_STEP == 0: + metrics = eval(model, eval_dataloader) + metrics['step'] = step + if float(metrics['loss']) < best_eval_loss: + # model.save(f'checkpoint-{step}') + best_eval_loss = float(metrics['loss']) + metrics['best_eval_loss'] = best_eval_loss + logger.info(f'Current is step {step} of {len(dataloader)}, Eval metric: {metrics}') + + # model.save(f'last-checkpoint') + + +if __name__ == '__main__': + train() diff --git a/cookbook/mm/fsdp2_gemma4_mm.sh b/cookbook/mm/fsdp2_gemma4_mm.sh new file mode 100644 index 00000000..c67113d8 --- /dev/null +++ b/cookbook/mm/fsdp2_gemma4_mm.sh @@ -0,0 +1,3 @@ +export CUDA_VISIBLE_DEVICES=0,1 + +torchrun --nnodes=1 --nproc_per_node=2 fsdp2_gemma4_mm.py diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py index d6e1eed9..f63833a3 100644 --- a/src/twinkle/processor/base.py +++ b/src/twinkle/processor/base.py @@ -49,6 +49,7 @@ class InputProcessor: 'pixel_values_videos', 'video_grid_thw', 'input_features', + 'input_features_mask', 'feature_attention_mask', 'grid_thws', } @@ -592,7 +593,15 @@ def is_mm_position_ids(position_ids): for field, values in vlm_fields.items(): if values: _values = [] - for value in values: + for i, value in enumerate(values): + if field == 'input_features': # [freq_bins, time_steps] -> [freq_bins, time_steps, num_features] + assert len(value.shape) == 2 + value = value.unsqueeze(-1) + if field == 'input_features_mask': # [freq_bins,] -> [freq_bins, time_steps] + assert len(value.shape) == 1 + input_features_shape = vlm_fields['input_features'][i].shape + assert value.shape[0] == input_features_shape[0] + value = value.unsqueeze(1).expand(input_features_shape[:2]) if value.dim() == 1: # image_thw may be squeezed value = value.unsqueeze(0) diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 50ba3e5e..039e2244 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -40,7 +40,8 @@ def __init__(self, **kwargs): self.model_id = model_id model_id = HubOperation.download_model(model_id, ignore_model=True) - if os.path.exists(os.path.join(model_id, 'preprocessor_config.json')): + if os.path.exists(os.path.join(model_id, 'preprocessor_config.json')) or os.path.exists( + os.path.join(model_id, 'processor_config.json')): from transformers import AutoProcessor self.processor = AutoProcessor.from_pretrained(model_id, **kwargs) else: @@ -55,15 +56,17 @@ def __init__(self, self.truncation_strategy = truncation_strategy self.default_system = default_system self._test_support_assistant_tokens_mask() - self.pre_pipeline: List[Callable[[Trajectory], List[Trajectory]]] = [ - self._add_default_system, # Add a default system field - self._to_standard_reasoning_content, # Convert thinking to standard field - self._build_standard_messages, # turn to standard mm messages + + self.pre_pipeline_names: List[str] = [ + '_add_default_system', + '_to_standard_reasoning_content', + '_build_standard_messages', ] - self.post_pipeline: List[Callable[[InputFeature], List[InputFeature]]] = [ - self._check_max_length, # Check and split input_features - self._add_attention_fields, # Add useful fields - self._roll_labels, # roll labels + + self.post_pipeline_names: List[str] = [ + '_check_max_length', + '_add_attention_fields', + '_roll_labels', ] def parse_tool_call(self, decoded: str) -> List[Dict[str, Any]]: @@ -166,7 +169,8 @@ def preprocess_audios(self, audios: List[AudioInput]) -> List[np.ndarray]: def _invoke_pre_pipeline(self, trajectories: List[Trajectory]) -> List[Trajectory]: current = trajectories - for pipeline in self.pre_pipeline: + for pipeline_name in self.pre_pipeline_names: + pipeline: Callable[[Trajectory], List[Trajectory]] = getattr(self, pipeline_name) next_batch = [] for trajectory in current: next_batch.extend(pipeline(trajectory)) @@ -175,7 +179,8 @@ def _invoke_pre_pipeline(self, trajectories: List[Trajectory]) -> List[Trajector def _invoke_post_pipeline(self, input_features: List[InputFeature]) -> List[InputFeature]: current = input_features - for pipeline in self.post_pipeline: + for pipeline_name in self.post_pipeline_names: + pipeline: Callable[[InputFeature], List[InputFeature]] = getattr(self, pipeline_name) next_batch = [] for input_feature in current: next_batch.extend(pipeline(input_feature)) @@ -467,9 +472,15 @@ def _process_mm_string_format(self, messages: List, images: List, videos: List, def _build_standard_messages(self, trajectory: Trajectory) -> List[Trajectory]: # Extract trajectory-level media - images = self.preprocess_images(trajectory.pop('images', None) or []) - videos = self.preprocess_videos(trajectory.pop('videos', None) or []) - audios = self.preprocess_audios(trajectory.pop('audios', None) or []) + extracted_images = trajectory.pop( + 'images', None) or [img for msg in trajectory['messages'] for img in msg.get('images', []) or []] + extracted_videos = trajectory.pop( + 'videos', None) or [video for msg in trajectory['messages'] for video in msg.get('videos', []) or []] + extracted_audios = trajectory.pop( + 'audios', None) or [audio for msg in trajectory['messages'] for audio in msg.get('audios', []) or []] + images = self.preprocess_images(extracted_images) + videos = self.preprocess_videos(extracted_videos) + audios = self.preprocess_audios(extracted_audios) trajectory['messages'] = self._process_mm_messages(trajectory['messages'], images, videos, audios) if not self.is_mm: