From f903139863836e939b94033b3456824f1ed1216e Mon Sep 17 00:00:00 2001 From: hengtaoguo Date: Mon, 15 Jun 2026 20:24:32 +0000 Subject: [PATCH] video sft --- src/maxtext/configs/base.yml | 6 +- .../sft-vision-llava-video-178k.yml | 38 ++++ src/maxtext/configs/types.py | 5 + .../input_pipeline/hf_data_processing.py | 31 +++ .../input_pipeline/input_pipeline_utils.py | 36 +++- src/maxtext/multimodal/processor.py | 8 +- .../multimodal/processor_qwen3_omni.py | 164 +++++++++++++-- src/maxtext/multimodal/utils.py | 69 +++++-- src/maxtext/trainers/pre_train/train.py | 6 +- .../download_hf_multimodal_dataset.py | 194 ++++++++++++++++++ 10 files changed, 502 insertions(+), 55 deletions(-) create mode 100644 src/maxtext/configs/post_train/sft-vision-llava-video-178k.yml create mode 100644 tools/data_generation/download_hf_multimodal_dataset.py diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 9031cf4298..710a4dc995 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -52,7 +52,7 @@ load_full_state_path: "" # If enable_checkpointing is true, an asynchronous checkpointer will be used if # async_checkpointing is true, else a synchronous one is used. If you have # problems with the checkpointer we recommend trying the synchronous one. -enable_checkpointing: true +enable_checkpointing: false save_checkpoint_on_completion: true async_checkpointing: true checkpoint_period: 10_000 @@ -839,9 +839,7 @@ tpu_num_sparse_cores_to_trace: 2 # - upload xplane profiling, if it is enabled. # - upload training metrics, at the defined log_period interval. managed_mldiagnostics: false # Whether to enable the managed diagnostics -managed_mldiagnostics_on_demand_profiling: true # Enable on-demand profiling server by default managed_mldiagnostics_run_group: "" # Optional. Used to group multiple runs. -managed_mldiagnostics_region: "" # Optional. GCP region for managed mldiagnostics. If empty, it will be auto-detected by the SDK. # Dump HLO and jaxpr options dump_hlo: false @@ -1124,12 +1122,14 @@ remat_policy_for_vit: "minimal" # Remat policy for multimodal model's vision en image_size_for_vit: 896 # Default for Gemma3, and should be overwritten by model's config image_path: "" # Local image path used for decoding, can be multiple paths separated by comma, exp "/path/image1.jpg,/path/image2.jpg" video_path: "" # Local video path used for decoding, can be multiple paths separated by comma, exp "/path/video1.mp4,/path/video2.mp4" +video_directory: "" # Local video directory used for SFT training, e.g. "/mounted/LLaVA-Video-178K" audio_path: "" # Local audio path used for decoding, can be multiple paths separated by comma, exp "/path/audio1.wav,/path/audio2.wav" image_placeholder: "<|image|>" video_placeholder: "<|video|>" audio_placeholder: "<|audio|>" use_audio_in_video: false posemb_type_for_vit: "learn" +filter_sft_sequences_by_length: false # max_num_images_per_example only applies for training when your image column is a list of images. # -1 means no limit, and will pad to the max possible number of images determined by sequence length. # Set it to avoid unnecessary padding if you know the maximum number of images per example. diff --git a/src/maxtext/configs/post_train/sft-vision-llava-video-178k.yml b/src/maxtext/configs/post_train/sft-vision-llava-video-178k.yml new file mode 100644 index 0000000000..bc7cb0f368 --- /dev/null +++ b/src/maxtext/configs/post_train/sft-vision-llava-video-178k.yml @@ -0,0 +1,38 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +base_config: "base.yml" + +use_sft: true +use_tunix_gradient_accumulation: true +use_multimodal: true +sft_train_on_completion_only: true +packing: false # packing is not supported yet +freeze_vision_encoder_params: true +learning_rate: 2.e-5 + +# -------------- Model -------------- +model_name: "qwen3-omni-30b-a3b" +tokenizer_path: "Qwen/Qwen3-Omni-30B-A3B-Instruct" + +# -------------- HF pipeline -------------- +dataset_type: "hf" +hf_path: "parquet" +hf_train_files: "gs://hengtaoguo-maxtext-logs/datasets/LLaVA-Video-178K/0_30_s_academic_v0_1/*.parquet" +train_split: "train" +train_data_columns: ["query", "label"] +train_image_column: "video" + +# Local SSD path for videos on the TPU VM +video_directory: "/mounted/LLaVA-Video-178K/0_30_s_academic_v0_1" diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 1eca954e53..2747944c03 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1872,6 +1872,7 @@ class MultimodalGeneral(BaseModel): description="Maximum number of images per example for training with image lists. -1 means no limit.", ) video_path: PathStr = Field("", description="Path to a video for decoding.") + video_directory: PathStr = Field("", description="Local directory path containing video files for SFT.") audio_path: PathStr = Field("", description="Path to an audio file for decoding.") video_placeholder: str = Field("<|video|>", description="Placeholder string for video in text prompts.") audio_placeholder: str = Field("<|audio|>", description="Placeholder string for audio in text prompts.") @@ -1879,6 +1880,10 @@ class MultimodalGeneral(BaseModel): use_mrope: bool = Field(False, description="Enable Multi-dimensional RoPE for Qwen3-Omni models.") mrope_section: list[int] = Field([24, 20, 20], description="Dimensions for temporal, height, width in MRoPE.") position_id_per_seconds: int = Field(25, description="Temporal granularity for MRoPE (tokens per second).") + filter_sft_sequences_by_length: bool = Field( + False, + description="Filter out multimodal SFT sequences that exceed max_prefill_predict_length or max_target_length.", + ) class VisionTower(BaseModel): diff --git a/src/maxtext/input_pipeline/hf_data_processing.py b/src/maxtext/input_pipeline/hf_data_processing.py index 370f1895bd..347dcb321d 100644 --- a/src/maxtext/input_pipeline/hf_data_processing.py +++ b/src/maxtext/input_pipeline/hf_data_processing.py @@ -55,6 +55,25 @@ def vision_sft_preprocessing_pipeline( """pipeline for multimodal SFT with HF dataset""" assert len(text_columns) == 2, f"Need two text_columns for query and response, received {text_columns=}" + + # Format conversations if columns are missing + features_keys = list(dataset.features.keys()) if dataset.features else [] + if "conversations" in features_keys and not all(col in features_keys for col in text_columns): + def format_llava_video_dataset(example): + conversations = example["conversations"] + query = "" + label = "" + for turn in conversations: + if turn["from"] == "human" and not query: + query = turn["value"] + elif turn["from"] == "gpt" and not label: + label = turn["value"] + example[text_columns[0]] = query + example[text_columns[1]] = label + return example + + dataset = dataset.map(format_llava_video_dataset) + # Tunix GA requires per-micro-batch slicing at the data level, # whereas Native GA processes the full batch and splits it internally. if config.elastic_enabled: @@ -137,6 +156,18 @@ def vision_sft_preprocessing_pipeline( fn_kwargs={"column_name": text_columns[0], "config": config}, ) + # Filter out sequences exceeding max_prefill_predict_length or max_target_length + if getattr(config, "filter_sft_sequences_by_length", False): + max_prefill = getattr(config, "max_prefill_predict_length", 8192) + max_target = getattr(config, "max_target_length", 8192 + 512) + + def filter_by_length(example): + prefill_len = len(example[text_columns[0]]) + response_len = len(example[text_columns[1]]) + return (prefill_len <= max_prefill) and (prefill_len + response_len <= max_target) + + dataset = dataset.filter(filter_by_length) + dataset = input_pipeline_utils.HFDataSource( dataset=dataset, dataloading_host_index=dataloading_host_index, diff --git a/src/maxtext/input_pipeline/input_pipeline_utils.py b/src/maxtext/input_pipeline/input_pipeline_utils.py index 621b79bb47..ccbe9c517f 100644 --- a/src/maxtext/input_pipeline/input_pipeline_utils.py +++ b/src/maxtext/input_pipeline/input_pipeline_utils.py @@ -92,17 +92,25 @@ def _process_string(string_tensor): def reformat_prompt(example, column, image_placeholder, model_name): """reformat prompt for multimodal SFT""" - if isinstance(example["images"], list): - num_images = len(example["images"]) + if isinstance(example["images"], str): + example[column] = mm_processor.reformat_prompt( + example[column], image_placeholder, model_name, num_images=0, video_placeholder=image_placeholder, num_videos=1 + ) else: - num_images = 1 - example[column] = mm_processor.reformat_prompt(example[column], image_placeholder, model_name, num_images) + if isinstance(example["images"], list): + num_images = len(example["images"]) + else: + num_images = 1 + example[column] = mm_processor.reformat_prompt(example[column], image_placeholder, model_name, num_images) return example def reformat_response(example, column, model_name): """reformat response for multimodal SFT""" - example[column] = mm_processor.reformat_response(example[column][0], model_name) + val = example[column] + if isinstance(val, (list, tuple)) and len(val) > 0: + val = val[0] + example[column] = mm_processor.reformat_response(val, model_name) return example @@ -120,9 +128,17 @@ def merge_image_columns(example, image_columns, max_num_images_per_example): def pre_process_image_sft(example, image_column, config): - """pre-process image for multimodal SFT""" + """pre-process image or video for multimodal SFT""" def _process_image_fn(image): + if isinstance(image, str): + import os + + video_directory = getattr(config, "video_directory", "") + if video_directory: + image = os.path.join(video_directory, image) + return mm_processor.preprocess_image_for_training(image, config) + if isinstance(image, list): image = [np.array(mm_utils.convert_to_RGB(img)) for img in image] else: @@ -131,7 +147,7 @@ def _process_image_fn(image): image = mm_processor.preprocess_image_for_training(image, config) return image - example[image_column] = _process_image_fn(example[image_column]) + example[image_column] = _process_image_fn(example[image_column]) if example.get(image_column) is not None else None return example @@ -702,12 +718,12 @@ def _pad_image_and_mask(self, preprocessed_image: mm_utils.PreprocessorOutput) - if not isinstance(preprocessed_image, mm_utils.PreprocessorOutput): raise TypeError(f"Input must be multimodal_utils.PreprocessorOutput, but got {type(preprocessed_image)}") - if preprocessed_image.pixel_values is None: - raise ValueError("Input preprocessed_image must have pixel_values to pad images.") - if self.config.model_name and self.config.model_name.startswith("qwen3-omni"): return preprocessed_image + if preprocessed_image.pixel_values is None: + raise ValueError("Input preprocessed_image must have pixel_values to pad images.") + # Determine the maximum number of images/masks allowed. image_offsets = mm_processor.get_image_offsets(self.config, preprocessed_image) single_image_offset = image_offsets // preprocessed_image.pixel_values.shape[0] diff --git a/src/maxtext/multimodal/processor.py b/src/maxtext/multimodal/processor.py index 7c99800f2a..d2d3c39e64 100644 --- a/src/maxtext/multimodal/processor.py +++ b/src/maxtext/multimodal/processor.py @@ -69,9 +69,13 @@ def preprocess_image_for_training(image, config): return preprocess_mm_data_llama4(image) elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]: - from maxtext.multimodal.processor_qwen3_omni import preprocess_mm_data_qwen3_omni_for_training # pylint: disable=import-outside-toplevel + from maxtext.multimodal.processor_qwen3_omni import preprocess_mm_data_qwen3_omni_for_training, preprocess_mm_data_qwen3_omni_for_training_video # pylint: disable=import-outside-toplevel - return preprocess_mm_data_qwen3_omni_for_training(image, config) + if isinstance(image, str): + use_audio_in_video = getattr(config, "use_audio_in_video", False) + return preprocess_mm_data_qwen3_omni_for_training_video(image, config) + else: + return preprocess_mm_data_qwen3_omni_for_training(image, config) else: raise ValueError(f"Model {config.model_name} not supported for image preprocessing.") diff --git a/src/maxtext/multimodal/processor_qwen3_omni.py b/src/maxtext/multimodal/processor_qwen3_omni.py index b29b8acc84..e333ecf34b 100644 --- a/src/maxtext/multimodal/processor_qwen3_omni.py +++ b/src/maxtext/multimodal/processor_qwen3_omni.py @@ -344,6 +344,36 @@ def floor_by_factor(number: int, factor: int) -> int: return nframes +def _read_video_opencv(video_path, idx) -> np.ndarray: + """Robust fallback video reader using OpenCV.""" + import cv2 + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise RuntimeError(f"OpenCV failed to open video file: {video_path}") + + frames = [] + while True: + ret, frame = cap.read() + if not ret: + break + # OpenCV reads in BGR format, convert to RGB + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frames.append(frame) + cap.release() + + if len(frames) == 0: + raise RuntimeError(f"OpenCV decoded zero frames from video: {video_path}") + + selected_frames = [] + for i in idx: + clamped_i = min(i, len(frames) - 1) + selected_frames.append(frames[clamped_i]) + + video = np.stack(selected_frames, axis=0) + video = np.transpose(video, (0, 3, 1, 2)) + return video + + def _read_video_decord(video_path, video_start=0.0, video_end=None) -> tuple[np.ndarray, float]: """Read video using decord.VideoReader (torch-free version) @@ -370,24 +400,46 @@ def _read_video_decord(video_path, video_start=0.0, video_end=None) -> tuple[np. } try: vr = decord.VideoReader(video_path) - except Exception as e: - raise RuntimeError(f"Failed to read video from {video_path}: {e}") from e - total_frames, video_fps = len(vr), vr.get_avg_fps() - start_frame, end_frame, total_frames = calculate_video_frame_range( - video_config, - total_frames, - video_fps, - ) - nframes = smart_nframes(video_config, total_frames=total_frames, video_fps=video_fps) - - # Use numpy linspace instead of torch.linspace - idx = np.linspace(start_frame, end_frame, nframes).round().astype(int).tolist() - - video = vr.get_batch(idx).asnumpy() - # Convert from THWC to TCHW format using numpy - video = np.transpose(video, (0, 3, 1, 2)) + total_frames, video_fps = len(vr), vr.get_avg_fps() + start_frame, end_frame, total_frames = calculate_video_frame_range( + video_config, + total_frames, + video_fps, + ) + nframes = smart_nframes(video_config, total_frames=total_frames, video_fps=video_fps) + idx = np.linspace(start_frame, end_frame, nframes).round().astype(int).tolist() + video = vr.get_batch(idx).asnumpy() + video = np.transpose(video, (0, 3, 1, 2)) + sample_fps = nframes / max(total_frames, 1e-6) * video_fps + except Exception as decord_error: + import logging + logging.warning( + f"Decord failed to load/decode video {video_path} due to: {decord_error}. " + "Falling back to OpenCV video reader." + ) + try: + import cv2 + cap = cv2.VideoCapture(video_path) + video_fps = cap.get(cv2.CAP_PROP_FPS) or 25.0 + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 100 + cap.release() + + start_frame, end_frame, total_frames = calculate_video_frame_range( + video_config, + total_frames, + video_fps, + ) + nframes = smart_nframes(video_config, total_frames=total_frames, video_fps=video_fps) + idx = np.linspace(start_frame, end_frame, nframes).round().astype(int).tolist() + + video = _read_video_opencv(video_path, idx) + sample_fps = nframes / max(total_frames, 1e-6) * video_fps + except Exception as cv2_error: + raise RuntimeError( + f"Both Decord and OpenCV failed to decode video. " + f"Decord error: {decord_error}. OpenCV error: {cv2_error}" + ) from decord_error - sample_fps = nframes / max(total_frames, 1e-6) * video_fps return video, sample_fps @@ -554,6 +606,84 @@ def preprocess_mm_data_qwen3_omni_for_training(images, config): ) +def preprocess_mm_data_qwen3_omni_for_training_video(video_path, config): + """Preprocesses video (and audio) for Qwen3-Omni SFT training.""" + + import os + + if not os.path.exists(video_path): + raise FileNotFoundError( + f"Video file not found at local path: '{video_path}'. " + "Please make sure you have fully downloaded the dataset using the " + "download utility before running SFT training." + ) + + try: + video_array, _ = _read_video_decord(video_path) + video_processed, video_grid_thw = preprocess_video(video_array, config) + video_values = np.reshape( + video_processed, + ( + 1, + config.num_channels_for_vit, + config.temporal_patch_size_for_vit * video_grid_thw[0, 0], + config.patch_size_for_vit * video_grid_thw[0, 1], + config.patch_size_for_vit * video_grid_thw[0, 2], + ), + ) + except Exception as e: + import logging + logging.warning( + "\n" + "="*80 + "\n" + f"[DATASET CORRUPTION WARNING] BOTH DECORD AND OPENCV FAILED TO DECODE VIDEO:\n" + f"Path: {video_path}\n" + f"Error: {e}\n" + "Substituting dummy zero-video to prevent SFT training crash!\n" + "Please check if this video file is completely corrupted or empty.\n" + + "="*80 + "\n" + ) + grid_t = 1 + grid_h = 16 + grid_w = 16 + video_grid_thw = np.array([[grid_t, grid_h, grid_w]], dtype=np.int32) + fallback_t = config.temporal_patch_size_for_vit * grid_t + fallback_h = config.patch_size_for_vit * grid_h + fallback_w = config.patch_size_for_vit * grid_w + video_values = np.zeros( + (1, config.num_channels_for_vit, fallback_t, fallback_h, fallback_w), + dtype=np.float32 + ) + + processor_outputs = Qwen3OmniPreprocessorOutput( + num_videos=1, + video_values=video_values, + video_grid_thw=video_grid_thw, + video_second_per_grid=np.asarray([config.temporal_patch_size_for_vit], dtype=np.float32), + ) + + use_audio_in_video = getattr(config, "use_audio_in_video", False) + if use_audio_in_video: + try: + mt_audio = mm_utils.load_audio(video_path, sample_rate=SAMPLE_RATE) + mt_audio, mt_audio_mask = pre_process_audio_qwen3_omni(mt_audio) + processor_outputs.audio_values = mt_audio + processor_outputs.audio_mask = mt_audio_mask + audio_mask_sum = np.sum(mt_audio_mask, axis=-1) + audio_lengths = _get_feat_extract_output_lengths(audio_mask_sum) + processor_outputs.audio_lengths = np.array(audio_lengths, dtype=np.int32) + except Exception as e: + import logging + + logging.warning(f"Audio extraction failed for {video_path}: {e}. Using dummy audio.") + dummy_audio = np.zeros((1, 128, 3000), dtype=np.float32) + dummy_mask = np.zeros((1, 3000), dtype=np.int32) + processor_outputs.audio_values = dummy_audio + processor_outputs.audio_mask = dummy_mask + processor_outputs.audio_lengths = np.array([0], dtype=np.int32) + + return processor_outputs + + def preprocess_mm_data_qwen3_omni(config): """Placeholder for multimodal data preprocessing.""" processor_outputs = Qwen3OmniPreprocessorOutput() diff --git a/src/maxtext/multimodal/utils.py b/src/maxtext/multimodal/utils.py index 65b5670fc1..bb47f0f819 100644 --- a/src/maxtext/multimodal/utils.py +++ b/src/maxtext/multimodal/utils.py @@ -767,25 +767,52 @@ def window_function( def load_audio(data_path: str, sample_rate: int = 16000) -> np.ndarray: - """Load audio from a file path. - - Args: - data_path (str): The path to the audio file or video file. - sample_rate (int): The target sample rate in Hz. Default is 16000. - - Returns: - np.ndarray: The loaded audio waveform. - - Raises: - FileNotFoundError: If the audio file does not exist. - RuntimeError: If the audio file cannot be loaded. - """ + """Load audio from a file path (supporting both audio and video files).""" if not os.path.isfile(data_path): - raise FileNotFoundError(f"Audio file not found at path {data_path}. Please specify a valid audio file path") - if librosa is None: - raise ImportError("librosa is required for audio processing but not installed.") - try: - audio = librosa.load(data_path, sr=sample_rate)[0] - return audio - except Exception as e: - raise RuntimeError(f"Failed to load audio from {data_path}: {e}") from e + raise FileNotFoundError(f"Audio file not found at path {data_path}.") + + import soundfile as sf + import subprocess + import tempfile + + is_video = data_path.lower().endswith((".mp4", ".mkv", ".avi", ".mov", ".flv", ".webm")) + + if is_video: + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav: + temp_wav_path = temp_wav.name + + try: + cmd = [ + "ffmpeg", + "-y", + "-i", + data_path, + "-vn", + "-acodec", + "pcm_s16le", + "-ar", + str(sample_rate), + "-ac", + "1", + temp_wav_path, + ] + subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True) + audio, sr = sf.read(temp_wav_path) + assert sr == sample_rate, f"Sample rate mismatch: expected {sample_rate}, got {sr}" + return audio + except Exception as e: + raise RuntimeError(f"Failed to extract and load audio from video {data_path}: {e}") + finally: + if os.path.exists(temp_wav_path): + os.remove(temp_wav_path) + else: + try: + audio, sr = sf.read(data_path) + if sr != sample_rate: + if librosa is not None: + audio = librosa.resample(audio, orig_sr=sr, target_sr=sample_rate) + else: + raise RuntimeError(f"Audio sample rate {sr} does not match target {sample_rate} and librosa is not installed.") + return audio + except Exception as e: + raise RuntimeError(f"Failed to load audio from {data_path}: {e}") diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 047ddb97a8..4e4d0cc560 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -104,10 +104,12 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr # decimate proportion of data when per_device_batch_size<1 if is_train: for k, v in data.items(): - data[k] = v[: config.micro_batch_size_to_train_on, :] + if v is not None: + data[k] = v[: config.micro_batch_size_to_train_on, :] else: for k, v in data.items(): - data[k] = v[: config.micro_batch_size_to_eval_on, :] + if v is not None: + data[k] = v[: config.micro_batch_size_to_eval_on, :] mutable_collections = ["intermediates"] if config.mtp_num_layers > 0 and is_train: # The single model.apply call now triggers the entire chain if MTP is enabled: diff --git a/tools/data_generation/download_hf_multimodal_dataset.py b/tools/data_generation/download_hf_multimodal_dataset.py new file mode 100644 index 0000000000..f87823fac4 --- /dev/null +++ b/tools/data_generation/download_hf_multimodal_dataset.py @@ -0,0 +1,194 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Download a HuggingFace multimodal dataset, unzip video archives, +and generate local Parquet metadata files. +""" + +import argparse +import os +import re +import shutil +import tarfile +import zipfile +from huggingface_hub import hf_hub_download, list_repo_files +from datasets import load_dataset +import pyarrow.parquet as pq + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Download and prepare local multimodal dataset from HuggingFace Hub." + ) + parser.add_argument( + "--repo_id", + required=True, + help="HuggingFace dataset repository ID (e.g. lmms-lab/LLaVA-Video-178K)", + ) + parser.add_argument( + "--subset", + required=True, + help="Subset directory inside repository (e.g. 0_30_s_academic_v0_1)", + ) + parser.add_argument( + "--dataset_dir", + required=True, + help="Target local directory to write both videos and parquets.", + ) + parser.add_argument( + "--split", + default="all", + choices=["all", "caption", "open_ended", "multi_choice"], + help="Specific split to prepare (default: all).", + ) + parser.add_argument( + "--token", + default=None, + help="HuggingFace access token for gated datasets.", + ) + return parser.parse_args() + + +def main(): + args = parse_args() + + # Ensure the target local directory exists + os.makedirs(args.dataset_dir, exist_ok=True) + + print(f"Connecting to HuggingFace Hub to scan '{args.repo_id}' under subset '{args.subset}'...") + try: + all_files = list_repo_files(repo_id=args.repo_id, repo_type="dataset", token=args.token) + except Exception as e: + print(f"Error accessing HuggingFace repository: {e}") + return + + subset_prefix = args.subset.strip("/") + "/" + target_files = [f for f in all_files if f.startswith(subset_prefix)] + + if not target_files: + print(f"Error: No files found in HuggingFace repo matching subset prefix '{subset_prefix}'") + return + + # 1. Separate JSON metadata files based on split choice + split_patterns = { + "caption": r".*cap.*\.json", + "open_ended": r".*oe.*\.json", + "multi_choice": r".*mc.*\.json", + } + + json_files = [] + if args.split == "all": + json_files = [f for f in target_files if f.endswith(".json")] + else: + pattern = split_patterns[args.split] + json_files = [f for f in target_files if f.endswith(".json") and re.match(pattern, os.path.basename(f))] + + if not json_files: + print(f"Error: No metadata JSON files found matching split choice '{args.split}'") + return + + # 2. Identify video archive files + tar_files = [f for f in target_files if f.endswith(".tar.gz") or f.endswith(".tar") or f.endswith(".zip")] + + print("\n" + "="*80) + print(f"DATASET PREPARATION PLAN") + print(f"HuggingFace Repo: {args.repo_id}") + print(f"Subset / Split: {args.subset} / {args.split}") + print(f"Local Directory: {args.dataset_dir}") + print(f"Metadata JSONs: {len(json_files)}") + print(f"Video Archives: {len(tar_files)}") + print("="*80 + "\n") + + # 3. Download and extract video archives + staging_dir = os.path.join(args.dataset_dir, ".staging") + os.makedirs(staging_dir, exist_ok=True) + + downloaded_archives = [] + for i, f in enumerate(tar_files): + filename = os.path.basename(f) + print(f"[{i+1}/{len(tar_files)}] Downloading video archive: {filename} ...") + try: + local_path = hf_hub_download( + repo_id=args.repo_id, + filename=f, + repo_type="dataset", + local_dir=staging_dir, + token=args.token + ) + downloaded_archives.append(local_path) + except Exception as e: + print(f"Failed to download archive {filename}: {e}") + return + + for i, archive_path in enumerate(downloaded_archives): + print(f"[{i+1}/{len(downloaded_archives)}] Extracting video archive: {os.path.basename(archive_path)} ...") + try: + if archive_path.endswith(".tar.gz") or archive_path.endswith(".tar"): + with tarfile.open(archive_path, "r:gz" if archive_path.endswith(".tar.gz") else "r:") as tar: + tar.extractall(path=args.dataset_dir) + elif archive_path.endswith(".zip"): + with zipfile.ZipFile(archive_path, "r") as zip_ref: + zip_ref.extractall(path=args.dataset_dir) + except Exception as e: + print(f"Failed to extract archive {archive_path}: {e}") + return + + # Clean up staging directory + shutil.rmtree(staging_dir) + print("Video archives extracted successfully. Staging directory cleaned.") + + # 4. Download and convert JSON files to local Parquet files + local_json_paths = [] + for i, f in enumerate(json_files): + filename = os.path.basename(f) + print(f"[{i+1}/{len(json_files)}] Downloading metadata JSON: {filename} ...") + try: + local_path = hf_hub_download( + repo_id=args.repo_id, + filename=f, + repo_type="dataset", + local_dir=args.dataset_dir, + token=args.token + ) + local_json_paths.append(local_path) + except Exception as e: + print(f"Failed to download metadata JSON {filename}: {e}") + return + + print("\nConverting JSON files to local Parquet format...") + try: + ds = load_dataset("json", data_files=local_json_paths, split="train") + table = ds.data.table + + # Target filename indicates subset/split configurations + parquet_filename = f"llava-video-178k-{args.split}-00000-of-00001.parquet" + output_parquet_path = os.path.join(args.dataset_dir, parquet_filename) + pq.write_table(table, output_parquet_path, compression="zstd") + + print(f"Success! Local parquet file generated at: {output_parquet_path}") + except Exception as e: + print(f"Error during JSON-to-Parquet conversion: {e}") + return + finally: + # Always clean up temporary JSON files + for p in local_json_paths: + if os.path.exists(p): + os.remove(p) + + print(f"\nAll operations completed successfully! Dataset is ready locally at: {args.dataset_dir}\n") + + +if __name__ == "__main__": + main()