From 2b6c67662685b323088b09b7234bd04946e91c20 Mon Sep 17 00:00:00 2001
From: fangfangssj <1135470306@qq.com>
Date: Fri, 15 May 2026 08:48:27 +0000
Subject: [PATCH] feat(agent): support MoE, multimodal, audio, seq2seq, diffusion model extraction
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Extend GraphNet Agent to correctly identify and extract computation graphs
for a wider range of model architectures beyond basic text/vision models.

Key changes:
- ModelMetadata: add architecture_type field ("text"/"vision"/"seq2seq"/
  "audio"/"multimodal"/"diffusion"/"moe")
- ConfigMetadataAnalyzer: use AutoConfig.from_pretrained() for rich config
  introspection; classify architecture via transformers' own task mapping
  tables (MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, etc.) — no hardcoded
  lists; add per-arch input shape builders covering whisper decoder_input_ids,
  CLIP text_config seq_len, diffusion sample/timestep/encoder_hidden_states,
  MoE/seq2seq specific inputs; add field-based model_type inference fallback
  for configs missing model_type (e.g. prajjwal1/bert-tiny)
- TemplateCodeGenerator: branch model loader by arch (AutoModelForSeq2SeqLM
  for seq2seq, UNet2DConditionModel for diffusion); add diffusion-specific
  script generation using positional args; inject inferred model_type into
  config when absent
- LLMCodeFixer: extend _SYSTEM_PROMPT with MoE and diffusion input specs and
  error patterns; add MoE routing / UNet / seq2seq / GQA / audio fields to
  _extract_key_fields
- GraphNetAgent: add _resolve_model_dir() to detect diffusers pipelines
  (model_index.json) and automatically redirect to the unet/ subdir

Tested on: bert-tiny (text), convnextv2 (vision), t5-small (seq2seq),
whisper-tiny (audio), clip-vit-base-patch32 (multimodal),
tiny-random-MixtralForCausalLM (moe), tiny-stable-diffusion-pipe (diffusion)

Co-Authored-By: Claude Sonnet 4.6
---
 .../agent/code_generator/llm_code_fixer.py    |  50 +-
 .../code_generator/template_generator.py      | 106 +++-
 graph_net/agent/graph_net_agent.py            |  26 +
 .../config_metadata_analyzer.py               | 490 +++++++++++++++---
 .../agent/metadata_analyzer/model_metadata.py |   3 +
 5 files changed, 576 insertions(+), 99 deletions(-)

diff --git a/graph_net/agent/code_generator/llm_code_fixer.py b/graph_net/agent/code_generator/llm_code_fixer.py
index 2a56a6f9ab..5c6242f25e 100644
--- a/graph_net/agent/code_generator/llm_code_fixer.py
+++ b/graph_net/agent/code_generator/llm_code_fixer.py
@@ -64,6 +64,27 @@
 attention_mask = torch.ones((1, 64), dtype=torch.long).to(device)
 decoder_input_ids = torch.randint(0, min(vocab_size-1, 1000), (1, 32), dtype=torch.long).to(device)
 
+**MoE models** (mixtral / qwen2_moe / deepseek_v2 / dbrx / olmoe, etc.):
+  - Architecturally still text models; inputs are identical to the text case (input_ids + attention_mask)
+  - Load with AutoModel.from_config(config) as usual; no special handling is needed
+  ⚠️ vocab_size is usually large (32000+); strictly use min(vocab_size-1, 30000) as the randint upper bound
+  Key config fields: num_local_experts (Mixtral) / num_experts (Qwen2-MoE) / n_routed_experts (DeepSeek)
+
+**Diffusion models** (UNet2DConditionModel / DiT / stable-diffusion / SDXL, etc.):
+  from diffusers import UNet2DConditionModel
+  _config = UNet2DConditionModel.load_config(model_dir)
+  model = UNet2DConditionModel.from_config(_config)
+  # Read the key dimensions from the config
+  in_channels = _config.get("in_channels", 4)
+  sample_size = _config.get("sample_size", 64)
+  cross_attention_dim = _config.get("cross_attention_dim", 768)
+  sample = torch.randn(1, in_channels, sample_size, sample_size).to(device)
+  timestep = torch.tensor([1]).to(device)
+  encoder_hidden_states = torch.randn(1, 77, cross_attention_dim).to(device)
+  # The call must use positional arguments, not **inputs
+  wrapped(sample, timestep, encoder_hidden_states)
+  ⚠️ dynamic must be False; the call format is fixed to positional arguments, never unpack an **inputs dict
+
 ## [Common errors → fixes]
 | Error keyword | Fix |
 |---|---|
@@ -77,6 +98,9 @@
 | "sentencepiece" / "tiktoken" ImportError | Do not use a tokenizer; build input_ids directly with torch.randint |
 | "PendingUnbackedSymbolNotFound" | Confirm dynamic=False (do not change it to True) |
 | decoder_input_ids missing | Seq2Seq models need both input_ids and decoder_input_ids |
+| "encoder_hidden_states" required (UNet) | Diffusion models must receive encoder_hidden_states as a positional argument; it cannot be omitted |
+| UNet sample/timestep shape error | Check that in_channels/sample_size/cross_attention_dim are read correctly from the config |
+| MoE expert routing RuntimeError | Inputs are the same as for plain text models; usually the vocab is out of range, check that the randint upper bound is < vocab_size |
 """
@@ -312,9 +336,29 @@ def _extract_key_fields(model_dir: Path) -> str:
         "patch_size",
         "num_mel_bins",
         "chunk_length",
+        # MoE routing (field names vary across models)
+        "num_local_experts",
+        "num_experts_per_tok",
+        "num_experts",
+        "n_routed_experts",
+        "moe_intermediate_size",
+        "num_shared_experts",
+        # Diffusion / UNet
+        "in_channels",
+        "sample_size",
+        "cross_attention_dim",
+        "layers_per_block",
+        # Seq2Seq
+        "is_encoder_decoder",
+        "decoder_start_token_id",
+        # GQA (Llama/Mistral family)
+        "num_key_value_heads",
+        # Audio
+        "feature_size",
+        "sample_rate",
     ]
     result = {k: cfg[k] for k in keys if k in cfg}
-    # For nested configs, keep only model_type
+    # For nested configs, keep only the key fields
     for nested in ("audio_config", "vision_config", "text_config"):
         if isinstance(result.get(nested), dict):
             result[nested] = {
@@ -326,6 +370,10 @@ def _extract_key_fields(model_dir: Path) -> str:
                     "num_channels",
                     "num_mel_bins",
                     "hidden_size",
+                    "num_local_experts",
+                    "num_experts",
+                    "n_routed_experts",
+                    "sample_rate",
                 )
                 if k in result[nested]
             }
diff --git a/graph_net/agent/code_generator/template_generator.py b/graph_net/agent/code_generator/template_generator.py
index b2ec415f3f..a3332f695b 100644
--- a/graph_net/agent/code_generator/template_generator.py
+++ b/graph_net/agent/code_generator/template_generator.py
@@ -65,16 +65,19 @@ def _model_short_name(model_id: str) -> str:
         return model_id.replace("/", "_")
 
     def _generate_code(self, model_dir: Path, model_metadata: ModelMetadata) -> str:
-        """Generate complete extraction script code string"""
-        # Generate model loading code
-        load_code = self._generate_model_loader(model_dir, model_metadata)
+        """Generate complete extraction script code string."""
+        if model_metadata.architecture_type == "diffusion":
+            return self._generate_diffusion_code(model_dir, model_metadata)
+        return self._generate_standard_code(model_dir, model_metadata)
 
-        # Generate input construction code
+    def _generate_standard_code(
+        self, model_dir: Path, model_metadata: ModelMetadata
+    ) -> str:
+        """Generate standard (transformers-based) extraction script."""
+        load_code = self._generate_model_loader(model_dir, model_metadata)
         input_code = self._generate_input_code(model_metadata)
-
         short_name = self._model_short_name(model_metadata.model_id)
 
-        # Generate main code
         code = f"""import torch
 try:
     from transformers import AutoModel
@@ -102,6 +105,48 @@ def main():
     with torch.no_grad():
         wrapped(**inputs)
 
+if __name__ == "__main__":
+    main()
+"""
+        return code
+
+    def _generate_diffusion_code(
+        self, model_dir: Path, model_metadata: ModelMetadata
+    ) -> str:
+        """Generate extraction script for diffusion models (diffusers UNet)."""
+        load_code = self._generate_model_loader(model_dir, model_metadata)
+        input_code = self._generate_input_code(model_metadata)
+        short_name = 
self._model_short_name(model_metadata.model_id) + + # Diffusion model forward takes positional args, not **inputs dict + code = f"""import torch +try: + from diffusers import UNet2DConditionModel +except ImportError: + raise ImportError("diffusers is required. Install with: pip install diffusers") + +import graph_net + +def main(): + # Load model +{self._indent(load_code, 4)} + + # Prepare inputs +{self._indent(input_code, 4)} + + # Extract graph + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device).eval() + + sample = inputs["sample"].to(device) + timestep = inputs["timestep"].to(device) + encoder_hidden_states = inputs["encoder_hidden_states"].to(device) + + wrapped = graph_net.torch.extract(name="{short_name}", dynamic=False)(model).eval() + + with torch.no_grad(): + wrapped(sample, timestep, encoder_hidden_states) + if __name__ == "__main__": main() """ @@ -110,14 +155,47 @@ def main(): def _generate_model_loader( self, model_dir: Path, model_metadata: ModelMetadata ) -> str: - """Generate model loading code — config only, random weights""" + """Generate model loading code based on architecture type.""" model_path = str(model_dir).replace("\\", "/") - - return ( - f"from transformers import AutoConfig\n" - f'_config = AutoConfig.from_pretrained("{model_path}", trust_remote_code=True)\n' - f"model = AutoModel.from_config(_config)" - ) + arch = model_metadata.architecture_type + + if arch == "seq2seq": + return ( + f"from transformers import AutoConfig, AutoModelForSeq2SeqLM\n" + f'_config = AutoConfig.from_pretrained("{model_path}", trust_remote_code=True)\n' + f"model = AutoModelForSeq2SeqLM.from_config(_config)" + ) + elif arch == "diffusion": + return ( + f"from diffusers import UNet2DConditionModel\n" + f'_config = UNet2DConditionModel.load_config("{model_path}")\n' + f"model = UNet2DConditionModel.from_config(_config)" + ) + else: + # text, moe, vision, multimodal, audio, None → AutoModel + # If model_type is not present in config.json (e.g. prajjwal1/bert-tiny), + # inject the inferred model_type so AutoConfig can resolve the class. 
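+            # Illustrative example (values hypothetical): prajjwal1/bert-tiny ships a
+            # config with only structural fields (hidden_size, num_attention_heads, ...)
+            # and no "model_type"; once "bert" is injected, AutoConfig can dispatch to
+            # the matching config class.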
+            model_type = model_metadata.model_type
+            if model_type:
+                return (
+                    f"import json as _json, os as _os, tempfile as _tmp\n"
+                    f"from transformers import AutoConfig, AutoModel\n"
+                    f'_raw = _json.load(open(_os.path.join("{model_path}", "config.json")))\n'
+                    f'if "model_type" not in _raw:\n'
+                    f'    _raw["model_type"] = "{model_type}"\n'
+                    f"    _td = _tmp.mkdtemp()\n"
+                    f"    # Context manager ensures the temp config is flushed before\n"
+                    f"    # AutoConfig reads it back.\n"
+                    f'    with open(_os.path.join(_td, "config.json"), "w") as _f:\n'
+                    f"        _json.dump(_raw, _f)\n"
+                    f"    _config = AutoConfig.from_pretrained(_td, trust_remote_code=True)\n"
+                    f"else:\n"
+                    f'    _config = AutoConfig.from_pretrained("{model_path}", trust_remote_code=True)\n'
+                    f"model = AutoModel.from_config(_config)"
+                )
+            else:
+                return (
+                    f"from transformers import AutoConfig, AutoModel\n"
+                    f'_config = AutoConfig.from_pretrained("{model_path}", trust_remote_code=True)\n'
+                    f"model = AutoModel.from_config(_config)"
+                )
 
     def _generate_input_code(self, model_metadata: ModelMetadata) -> str:
         """Generate input tensor construction code based on model metadata"""
@@ -129,7 +207,7 @@ def _generate_input_code(self, model_metadata: ModelMetadata) -> str:
             shape_tuple = f"({', '.join(map(str, shape))})"
 
             if dtype == "int64":
-                if "input_ids" in name.lower():
+                if name.lower().endswith("input_ids"):  # matches input_ids and decoder_input_ids
                     safe_vocab_size = self._calculate_safe_vocab_size(model_metadata)
                     lines.append(
                         f'inputs["{name}"] = torch.randint(0, {safe_vocab_size}, {shape_tuple}, dtype={torch_dtype})'
diff --git a/graph_net/agent/graph_net_agent.py b/graph_net/agent/graph_net_agent.py
index 4339bc65d6..2a2b234eb8 100644
--- a/graph_net/agent/graph_net_agent.py
+++ b/graph_net/agent/graph_net_agent.py
@@ -108,6 +108,7 @@ def extract_sample(self, model_id: str) -> ExtractionStatus:
         self.logger.info(f"Starting extraction for model: {model_id}")
 
         model_dir = self._fetch_model(model_id)
+        model_dir = self._resolve_model_dir(model_dir)
         model_metadata = self._analyze_model(model_dir)
 
         script_path = self._generate_script(model_dir, model_metadata, model_id)
@@ -199,6 +200,31 @@ def _fetch_model(self, model_id: str) -> Path:
         self.logger.info(f"Model downloaded to: {model_dir}")
         return model_dir
 
+    def _resolve_model_dir(self, model_dir: Path) -> Path:
+        """
+        For diffusers pipeline repos (identified by model_index.json at root),
+        resolve to the UNet subdirectory, which contains the actual UNet config.
+        Returns model_dir unchanged for non-pipeline repos.
+        """
+        model_index = model_dir / "model_index.json"
+        if not model_index.exists():
+            return model_dir
+
+        # It's a diffusers pipeline — find the unet subdirectory
+        unet_dir = model_dir / "unet"
+        if unet_dir.is_dir() and (unet_dir / "config.json").exists():
+            self.logger.info(
+                f"Detected diffusers pipeline; using UNet subdir: {unet_dir}"
+            )
+            return unet_dir
+
+        # Pipeline without unet/ (e.g., image-to-image or non-SD pipeline)
+        self.logger.warning(
+            f"Diffusers pipeline detected but no unet/ subdir found in {model_dir}; "
+            "proceeding with root dir."
+ ) + return model_dir + def _analyze_model(self, model_dir: Path): """Analyze model configuration to extract metadata""" self.logger.info("Analyzing model configuration") diff --git a/graph_net/agent/metadata_analyzer/config_metadata_analyzer.py b/graph_net/agent/metadata_analyzer/config_metadata_analyzer.py index 7b8501f5c0..b15df28b0a 100644 --- a/graph_net/agent/metadata_analyzer/config_metadata_analyzer.py +++ b/graph_net/agent/metadata_analyzer/config_metadata_analyzer.py @@ -2,7 +2,7 @@ import json from pathlib import Path -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from graph_net.agent.metadata_analyzer.base import BaseMetadataAnalyzer from graph_net.agent.metadata_analyzer.model_metadata import ModelMetadata @@ -22,12 +22,23 @@ ] +def _cfg_get(cfg: Any, key: str, default: Any = None) -> Any: + """Unified attribute/key access for both PretrainedConfig objects and dicts.""" + if isinstance(cfg, dict): + return cfg.get(key, default) + return getattr(cfg, key, default) + + class ConfigMetadataAnalyzer(BaseMetadataAnalyzer): - """Analyzer that extracts metadata from config.json""" + """Analyzer that extracts metadata from config.json, using transformers AutoConfig + when available to leverage rich config object properties for architecture detection. + """ def analyze(self, model_dir: Path) -> ModelMetadata: """ - Analyze model by parsing config.json + Analyze model by parsing config.json (with transformers AutoConfig if available). + Also handles diffusers-style configs that lack a 'model_type' key but have + '_class_name' (e.g., UNet2DConditionModel). Args: model_dir: Path to model directory @@ -39,28 +50,46 @@ def analyze(self, model_dir: Path) -> ModelMetadata: AnalysisError: If analysis fails """ config_path = model_dir / "config.json" - if not config_path.exists(): raise AnalysisError(f"config.json not found in {model_dir}") try: + # Primary path: load via AutoConfig to get a rich PretrainedConfig object + cfg_obj = None + try: + from transformers import AutoConfig + + cfg_obj = AutoConfig.from_pretrained( + str(model_dir), trust_remote_code=True + ) + except Exception: + pass # fall back to dict-only mode + + # Always parse raw dict as fallback / supplementary info with open(config_path, "r", encoding="utf-8") as f: - config = json.load(f) + cfg_dict = json.load(f) - # Extract model type - model_type = self._infer_model_type(config) - - # Extract input shapes and dtypes - input_shapes, input_dtypes = self._extract_input_info(config) - - # Extract vocab_size - vocab_size = config.get("vocab_size") + arch_type = self._classify_architecture(cfg_obj, cfg_dict) + input_shapes, input_dtypes = self._extract_input_info( + cfg_obj, cfg_dict, arch_type + ) - # Try to get actual embedding size from model weights + vocab_size = ( + _cfg_get(cfg_obj, "vocab_size") + if cfg_obj is not None + else cfg_dict.get("vocab_size") + ) embedding_size = self._get_embedding_size(model_dir) - - # Get model_id from directory name or config - model_id = self._get_model_id(model_dir, config) + model_id = self._get_model_id(model_dir, cfg_dict) + model_type = ( + _cfg_get(cfg_obj, "model_type") + if cfg_obj is not None + else cfg_dict.get("model_type") + ) + # If model_type is still missing, try field-based heuristic inference. + # This handles models with incomplete config.json (e.g., prajjwal1/bert-tiny). 
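+            # (For instance, a config that exposes type_vocab_size but no model_type
+            # is inferred as "bert" by _infer_model_type_from_fields below.)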
+ if not model_type: + model_type = self._infer_model_type_from_fields(cfg_dict) return ModelMetadata( model_id=model_id, @@ -69,19 +98,350 @@ def analyze(self, model_dir: Path) -> ModelMetadata: model_type=model_type, vocab_size=vocab_size, embedding_size=embedding_size, + architecture_type=arch_type, ) except json.JSONDecodeError as e: raise AnalysisError(f"Failed to parse config.json: {e}") from e + except AnalysisError: + raise except Exception as e: raise AnalysisError(f"Failed to analyze model: {e}") from e + # ------------------------------------------------------------------ + # Architecture classification + # ------------------------------------------------------------------ + + @staticmethod + def _classify_architecture(cfg_obj: Any, cfg_dict: Dict) -> Optional[str]: + """ + Classify model architecture type using transformers' own task mapping tables + when available, falling back to config field inspection. + + Priority order (high → low): + diffusion > audio > multimodal > moe > seq2seq > vision > text + """ + model_type = ( + _cfg_get(cfg_obj, "model_type") or cfg_dict.get("model_type") or "" + ).lower() + + # 1. Diffusion models (diffusers ecosystem) + # UNet2DConditionModel config has both in_channels and sample_size. + has_in_channels = _cfg_get(cfg_obj, "in_channels") or cfg_dict.get( + "in_channels" + ) + has_sample_size = _cfg_get(cfg_obj, "sample_size") or cfg_dict.get( + "sample_size" + ) + if has_in_channels and has_sample_size: + return "diffusion" + + # 2. Audio models + # Use the union of transformers' audio task mapping tables — no hardcoded list. + try: + from transformers.models.auto.modeling_auto import ( + MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, + MODEL_FOR_CTC_MAPPING_NAMES, + MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, + MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES, + MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES, + ) + + all_audio: set = ( + set(MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES) + | set(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES) + | set(MODEL_FOR_CTC_MAPPING_NAMES) + | set(MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES) + | set(MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES) + ) + if model_type in all_audio: + return "audio" + except ImportError: + # Attribute-based fallback + if _cfg_get(cfg_obj, "num_mel_bins") or cfg_dict.get("num_mel_bins"): + return "audio" + if _cfg_get(cfg_obj, "feat_extract_norm") or cfg_dict.get( + "feat_extract_norm" + ): + return "audio" + + # 3. Multimodal VLMs + # Use transformers' multimodal task mapping tables — no hardcoded list. + try: + from transformers.models.auto.modeling_auto import ( + MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, + MODEL_FOR_MULTIMODAL_LM_MAPPING_NAMES, + ) + + all_multimodal: set = set(MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES) | set( + MODEL_FOR_MULTIMODAL_LM_MAPPING_NAMES + ) + if model_type in all_multimodal: + return "multimodal" + except ImportError: + pass + # Fallback: check sub_configs / dict keys for vision+text pair + if cfg_obj is not None: + sub_configs = getattr(cfg_obj, "sub_configs", {}) + has_vision = "vision_config" in sub_configs or hasattr( + cfg_obj, "vision_config" + ) + has_text = "text_config" in sub_configs or hasattr(cfg_obj, "text_config") + if has_vision and has_text: + return "multimodal" + elif cfg_dict.get("vision_config") and cfg_dict.get("text_config"): + return "multimodal" + + # 4. MoE (Mixture of Experts) + # No common base class in transformers; detect via public config field names. 
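+        # e.g. Mixtral exposes num_local_experts / num_experts_per_tok, Qwen2-MoE
+        # exposes num_experts, and DeepSeek-V2 exposes n_routed_experts.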
+        # We only look at non-private fields (skip leading underscores) so that
+        # internal attributes set on PretrainedConfig objects are never matched.
+        _MOE_PUBLIC_PREFIXES = (
+            "num_experts",  # num_experts, num_experts_per_tok
+            "n_routed_experts",  # DeepSeek-V2/V3
+            "num_local_experts",  # Mixtral
+            "num_shared_experts",
+            "expert_capacity",  # switch_transformers
+            "moe_topk",
+            "moe_k",
+            "moe_intermediate_size",
+        )
+        if cfg_obj is not None:
+            public_fields = {k for k in vars(cfg_obj).keys() if not k.startswith("_")}
+            if any(
+                k.lower().startswith(p)
+                for k in public_fields
+                for p in _MOE_PUBLIC_PREFIXES
+            ):
+                return "moe"
+        else:
+            if any(
+                k.lower().startswith(p)
+                for k in cfg_dict.keys()
+                for p in _MOE_PUBLIC_PREFIXES
+            ):
+                return "moe"
+
+        # 5. Seq2Seq (encoder-decoder)
+        # Use the standard transformers PretrainedConfig attribute.
+        # Must be checked AFTER audio so that whisper (which also sets
+        # is_encoder_decoder=True) is classified as audio.
+        is_enc_dec = (
+            getattr(cfg_obj, "is_encoder_decoder", False)
+            if cfg_obj is not None
+            else cfg_dict.get("is_encoder_decoder", False)
+        )
+        if is_enc_dec:
+            return "seq2seq"
+
+        # 6. Pure vision (no vocabulary)
+        has_image = (
+            _cfg_get(cfg_obj, "image_size")
+            or cfg_dict.get("image_size")
+            or cfg_dict.get("num_channels")
+        )
+        has_vocab = _cfg_get(cfg_obj, "vocab_size") or cfg_dict.get("vocab_size")
+        if has_image and not has_vocab:
+            return "vision"
+
+        # 7. Default: text
+        return "text"
+
+    # ------------------------------------------------------------------
+    # Input shape / dtype extraction
+    # ------------------------------------------------------------------
+
+    def _extract_input_info(
+        self,
+        cfg_obj: Any,
+        cfg_dict: Dict,
+        arch_type: Optional[str],
+    ) -> Tuple[Dict[str, List[int]], Dict[str, str]]:
+        """
+        Build input_shapes and input_dtypes based on the detected architecture type.
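+        Unrecognized architectures and MoE models fall back to standard text
+        inputs (input_ids + attention_mask).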
+ """ + if arch_type == "diffusion": + return self._inputs_diffusion(cfg_obj, cfg_dict) + if arch_type == "audio": + return self._inputs_audio(cfg_obj, cfg_dict) + if arch_type == "multimodal": + return self._inputs_multimodal(cfg_obj, cfg_dict) + if arch_type == "seq2seq": + return self._inputs_seq2seq(cfg_obj, cfg_dict) + if arch_type == "vision": + return self._inputs_vision(cfg_obj, cfg_dict) + # text / moe / None → standard NLP inputs + return self._inputs_text(cfg_obj, cfg_dict) + + def _inputs_text( + self, cfg_obj: Any, cfg_dict: Dict + ) -> Tuple[Dict[str, List[int]], Dict[str, str]]: + raw_len = _cfg_get(cfg_obj, "max_position_embeddings") or cfg_dict.get( + "max_position_embeddings", 512 + ) + seq_len = min(int(raw_len), _MAX_SEQ_LEN) + shapes = { + "input_ids": [1, seq_len], + "attention_mask": [1, seq_len], + } + dtypes = { + "input_ids": "int64", + "attention_mask": "int64", + } + return shapes, dtypes + + def _inputs_seq2seq( + self, cfg_obj: Any, cfg_dict: Dict + ) -> Tuple[Dict[str, List[int]], Dict[str, str]]: + enc_len = 64 + dec_len = 32 + shapes = { + "input_ids": [1, enc_len], + "attention_mask": [1, enc_len], + "decoder_input_ids": [1, dec_len], + } + dtypes = { + "input_ids": "int64", + "attention_mask": "int64", + "decoder_input_ids": "int64", + } + return shapes, dtypes + + def _inputs_vision( + self, cfg_obj: Any, cfg_dict: Dict + ) -> Tuple[Dict[str, List[int]], Dict[str, str]]: + raw_size = _cfg_get(cfg_obj, "image_size") or cfg_dict.get("image_size", 224) + if isinstance(raw_size, (list, tuple)): + raw_size = raw_size[0] + image_size = min(int(raw_size), _MAX_IMAGE_SIZE) + num_channels = _cfg_get(cfg_obj, "num_channels") or cfg_dict.get( + "num_channels", 3 + ) + shapes = {"pixel_values": [1, int(num_channels), image_size, image_size]} + dtypes = {"pixel_values": "float32"} + return shapes, dtypes + + def _inputs_multimodal( + self, cfg_obj: Any, cfg_dict: Dict + ) -> Tuple[Dict[str, List[int]], Dict[str, str]]: + # Text branch — prefer text_config.max_position_embeddings (e.g., CLIP has 77) + txt_cfg = _cfg_get(cfg_obj, "text_config") or cfg_dict.get("text_config", {}) + if hasattr(txt_cfg, "max_position_embeddings"): + raw_len = txt_cfg.max_position_embeddings + elif isinstance(txt_cfg, dict): + raw_len = txt_cfg.get("max_position_embeddings", None) + else: + raw_len = None + if raw_len is None: + raw_len = _cfg_get(cfg_obj, "max_position_embeddings") or cfg_dict.get( + "max_position_embeddings", 512 + ) + seq_len = min(int(raw_len), _MAX_SEQ_LEN) + + # Vision branch — prefer sub vision_config + vis_cfg = _cfg_get(cfg_obj, "vision_config") or cfg_dict.get( + "vision_config", {} + ) + if hasattr(vis_cfg, "image_size"): + raw_size = vis_cfg.image_size + num_channels = getattr(vis_cfg, "num_channels", 3) + elif isinstance(vis_cfg, dict): + raw_size = vis_cfg.get("image_size", 224) + num_channels = vis_cfg.get("num_channels", 3) + else: + raw_size = _cfg_get(cfg_obj, "image_size") or cfg_dict.get( + "image_size", 224 + ) + num_channels = _cfg_get(cfg_obj, "num_channels") or cfg_dict.get( + "num_channels", 3 + ) + if isinstance(raw_size, (list, tuple)): + raw_size = raw_size[0] + image_size = min(int(raw_size), _MAX_IMAGE_SIZE) + + shapes = { + "input_ids": [1, seq_len], + "attention_mask": [1, seq_len], + "pixel_values": [1, int(num_channels), image_size, image_size], + } + dtypes = { + "input_ids": "int64", + "attention_mask": "int64", + "pixel_values": "float32", + } + return shapes, dtypes + + def _inputs_audio( + self, cfg_obj: Any, cfg_dict: Dict + ) -> 
Tuple[Dict[str, List[int]], Dict[str, str]]: + model_type = ( + _cfg_get(cfg_obj, "model_type") or cfg_dict.get("model_type") or "" + ).lower() + + if model_type == "whisper": + num_mel_bins = int( + _cfg_get(cfg_obj, "num_mel_bins") or cfg_dict.get("num_mel_bins", 80) + ) + # chunk_length (seconds) × sample_rate / hop_length ≈ frames + # whisper default: 30s × 16000 / 160 = 3000 frames + frames = 3000 + shapes = { + "input_features": [1, num_mel_bins, frames], + # Whisper is encoder-decoder; explicitly pass decoder_input_ids + # to avoid internal conflict between decoder_input_ids and + # decoder_inputs_embeds auto-generation. + "decoder_input_ids": [1, 1], + } + dtypes = { + "input_features": "float32", + "decoder_input_ids": "int64", + } + elif model_type == "clap": + shapes = {"input_features": [1, 1, 1001, 64]} + dtypes = {"input_features": "float32"} + else: + # wav2vec2, hubert, unispeech, wavlm, sew, etc. — 1 second at 16 kHz + shapes = {"input_values": [1, 16000]} + dtypes = {"input_values": "float32"} + return shapes, dtypes + + def _inputs_diffusion( + self, cfg_obj: Any, cfg_dict: Dict + ) -> Tuple[Dict[str, List[int]], Dict[str, str]]: + in_channels = int( + _cfg_get(cfg_obj, "in_channels") or cfg_dict.get("in_channels", 4) + ) + sample_size_raw = _cfg_get(cfg_obj, "sample_size") or cfg_dict.get( + "sample_size", 64 + ) + if isinstance(sample_size_raw, (list, tuple)): + sample_size = int(sample_size_raw[0]) + else: + sample_size = int(sample_size_raw) + cross_dim = int( + _cfg_get(cfg_obj, "cross_attention_dim") + or cfg_dict.get("cross_attention_dim", 768) + ) + shapes = { + "sample": [1, in_channels, sample_size, sample_size], + "timestep": [1], + "encoder_hidden_states": [1, 77, cross_dim], + } + dtypes = { + "sample": "float32", + "timestep": "int64", + "encoder_hidden_states": "float32", + } + return shapes, dtypes + + # ------------------------------------------------------------------ + # Legacy helpers (kept for backward compatibility) + # ------------------------------------------------------------------ + def _infer_model_type(self, config: Dict) -> Optional[str]: - """Infer model type from config""" - # Check common model type indicators + """Infer model type from raw config dict (legacy path).""" if "model_type" in config: return config["model_type"] - - # Check architecture field if "architectures" in config and config["architectures"]: arch = config["architectures"][0].lower() if "bert" in arch: @@ -92,92 +452,56 @@ def _infer_model_type(self, config: Dict) -> Optional[str]: return "resnet" elif "vit" in arch or "vision" in arch: return "vit" - return None - def _extract_input_info( - self, config: Dict - ) -> tuple[Dict[str, List[int]], Dict[str, str]]: + @staticmethod + def _infer_model_type_from_fields(cfg_dict: Dict) -> Optional[str]: """ - Extract input shapes and dtypes from config - - Returns: - Tuple of (input_shapes, input_dtypes) + Last-resort model_type inference based on field name signatures. + Used when config.json has neither 'model_type' nor 'architectures'. """ - input_shapes = {} - input_dtypes = {} - - # Common patterns for NLP models - if "max_position_embeddings" in config or "vocab_size" in config: - # NLP model (BERT, GPT, etc.) - # Cap to _MAX_SEQ_LEN: large models set max_position_embeddings to - # 131072+ which causes OOM via O(n²) attention during graph tracing. 
- raw_len = config.get("max_position_embeddings", 512) - max_length = min(raw_len, _MAX_SEQ_LEN) - batch_size = 1 - input_shapes["input_ids"] = [batch_size, max_length] - input_dtypes["input_ids"] = "int64" - - # Add attention_mask if present - if "attention_mask" not in input_shapes: - input_shapes["attention_mask"] = [batch_size, max_length] - input_dtypes["attention_mask"] = "int64" - - # Common patterns for vision models - elif "image_size" in config or "num_channels" in config: - # Vision model (ResNet, ViT, etc.) - # image_size may be an int or a [H, W] list - raw_size = config.get("image_size", 224) - if isinstance(raw_size, (list, tuple)): - raw_size = raw_size[0] - image_size = min(int(raw_size), _MAX_IMAGE_SIZE) - num_channels = config.get("num_channels", 3) - batch_size = 1 - input_shapes["pixel_values"] = [ - batch_size, - num_channels, - image_size, - image_size, - ] - input_dtypes["pixel_values"] = "float32" - - # Fallback: use default values - if not input_shapes: - # Default to common NLP input - input_shapes["input_ids"] = [1, 128] - input_dtypes["input_ids"] = "int64" - - return input_shapes, input_dtypes + keys = set(cfg_dict.keys()) + # BERT: type_vocab_size is unique to BERT/RoBERTa family + if "type_vocab_size" in keys: + return "bert" + # GPT-2: n_head / n_layer naming + if "n_head" in keys and "n_layer" in keys: + return "gpt2" + # Generic transformer-like text model + if { + "vocab_size", + "hidden_size", + "num_hidden_layers", + "num_attention_heads", + } <= keys: + return "bert" + return None def _get_model_id(self, model_dir: Path, config: Dict) -> str: - """Get model ID from directory or config""" - # Try to get from config first + """Get model ID from directory or config.""" if "name_or_path" in config: return config["name_or_path"] - - # Fallback to directory name return model_dir.name def _get_embedding_size(self, model_dir: Path) -> Optional[int]: - """Get actual embedding layer size from model weights""" + """Get actual embedding layer size from model weights.""" model_file = self._find_model_file(model_dir) if not model_file: return None - if model_file.suffix == ".safetensors": return self._get_embedding_size_from_safetensors(model_file) else: return self._get_embedding_size_from_pytorch(model_file) def _find_model_file(self, model_dir: Path) -> Optional[Path]: - """Find model weight file (pytorch_model*.bin or model.safetensors)""" + """Find model weight file (pytorch_model*.bin or model.safetensors).""" model_files = list(model_dir.glob("pytorch_model*.bin")) if not model_files: model_files = list(model_dir.glob("model.safetensors")) return model_files[0] if model_files else None def _get_embedding_size_from_safetensors(self, model_file: Path) -> Optional[int]: - """Extract embedding size from safetensors file""" + """Extract embedding size from safetensors file.""" try: from safetensors import safe_open @@ -192,20 +516,18 @@ def _get_embedding_size_from_safetensors(self, model_file: Path) -> Optional[int return None def _get_embedding_size_from_pytorch(self, model_file: Path) -> Optional[int]: - """Extract embedding size from PyTorch .bin file""" + """Extract embedding size from PyTorch .bin file.""" try: import torch state_dict = torch.load(model_file, map_location="cpu") - # Check known embedding keys first for key in _EMBEDDING_WEIGHT_KEYS: if key in state_dict: tensor = state_dict[key] if tensor is not None and len(tensor.shape) >= 1: return int(tensor.shape[0]) - # Fallback: search by pattern for key, tensor in state_dict.items(): if "embedding" 
in key.lower() and "weight" in key.lower(): if tensor is not None and len(tensor.shape) >= 1: diff --git a/graph_net/agent/metadata_analyzer/model_metadata.py b/graph_net/agent/metadata_analyzer/model_metadata.py index 35d1eb7d91..cf57331528 100644 --- a/graph_net/agent/metadata_analyzer/model_metadata.py +++ b/graph_net/agent/metadata_analyzer/model_metadata.py @@ -16,6 +16,9 @@ class ModelMetadata: embedding_size: Optional[ int ] = None # Actual embedding layer size (from model weights) + architecture_type: Optional[ + str + ] = None # e.g., "text", "vision", "seq2seq", "audio", "multimodal", "diffusion", "moe" def __post_init__(self): """Validate metadata"""