From 67a81d6bcd96f834cc410af4ad65f60b6d076ea6 Mon Sep 17 00:00:00 2001 From: Dayuxiaoshui <792179245@qq.com> Date: Fri, 15 May 2026 17:39:39 +0800 Subject: [PATCH 1/2] feat(tools): add backward graph generation and validation tools This commit introduces backward graph generation pipeline integrated with GraphNet's test_compiler framework. Changes: - graph_net/torch/extractor.py: add try/except for capture_sparse_compute to support PyTorch versions where the config does not exist. - graph_net/torch/sample_pass/backward_graph_extractor.py: - switch module from train() to eval() to avoid dropout/BN side effects - clone forward inputs with detach().clone() to avoid inplace modification - add _is_pure_shape_graph() to skip subgraphs with only shape ops - tools/backward_graph_test.py: - batch backward FX Graph generation via aot_autograd - integrated test_compiler validation with auto-generated weight_meta.py - default GRAPH_NET_FLUCTUATION_DETECT_THRESHOLD=0.5 and trials=10 - tools/backward_kernel_dedup.py: - Triton kernel dedup analysis for backward graphs --- graph_net/torch/extractor.py | 5 +- .../sample_pass/backward_graph_extractor.py | 42 +- tools/backward_graph_test.py | 538 ++++++++++++++++++ tools/backward_kernel_dedup.py | 188 ++++++ 4 files changed, 771 insertions(+), 2 deletions(-) create mode 100755 tools/backward_graph_test.py create mode 100755 tools/backward_kernel_dedup.py diff --git a/graph_net/torch/extractor.py b/graph_net/torch/extractor.py index 568ad995ad..949b713213 100644 --- a/graph_net/torch/extractor.py +++ b/graph_net/torch/extractor.py @@ -8,7 +8,10 @@ torch._dynamo.config.capture_scalar_outputs = True torch._dynamo.config.capture_dynamic_output_shape_ops = True -torch._dynamo.config.capture_sparse_compute = True +try: + torch._dynamo.config.capture_sparse_compute = True +except AttributeError: + pass torch._dynamo.config.raise_on_ctx_manager_usage = False torch._dynamo.config.allow_rnn = True diff --git a/graph_net/torch/sample_pass/backward_graph_extractor.py b/graph_net/torch/sample_pass/backward_graph_extractor.py index e50d3b7e9e..10bc7d43a4 100644 --- a/graph_net/torch/sample_pass/backward_graph_extractor.py +++ b/graph_net/torch/sample_pass/backward_graph_extractor.py @@ -27,7 +27,11 @@ def __call__(self): module, forward_inputs = get_torch_module_and_inputs( self.model_path, use_dummy_inputs=False, device=self.device ) - module.train() + module.eval() + + if self._is_pure_shape_graph(module): + print(f"[Skip] Pure shape graph: {self.model_path}") + return eval_forward_dir = os.path.join( self.output_dir, "eval_forward", self.rel_model_path @@ -35,6 +39,10 @@ def __call__(self): if not os.path.exists(eval_forward_dir): shutil.copytree(self.model_path, eval_forward_dir) + forward_inputs = [ + inp.detach().clone() if isinstance(inp, torch.Tensor) else inp + for inp in forward_inputs + ] forward_inputs = self.set_requires_grad_for_forward_inputs( self.model_path, module, forward_inputs ) @@ -117,6 +125,38 @@ def _remove_none_from_output(self, gm): gm.recompile() return gm + def _is_pure_shape_graph(self, module): + """Check if the graph only contains shape manipulation ops.""" + shape_only_ops = { + torch.ops.aten.view, + torch.ops.aten.reshape, + torch.ops.aten.squeeze, + torch.ops.aten.unsqueeze, + torch.ops.aten.permute, + torch.ops.aten.transpose, + torch.ops.aten.expand, + torch.ops.aten.flatten, + torch.ops.aten.t, + "view", + "reshape", + "squeeze", + "unsqueeze", + "permute", + "transpose", + "expand", + "flatten", + "t", + } + for node in module.graph.nodes: + if node.op in {"placeholder", "output", "get_attr"}: + continue + if node.op == "call_function" and node.target in shape_only_ops: + continue + if node.op == "call_method" and node.target in shape_only_ops: + continue + return False + return True + def _requires_grad(self, name, tensor): if not tensor.is_floating_point(): return False diff --git a/tools/backward_graph_test.py b/tools/backward_graph_test.py new file mode 100755 index 0000000000..3e36885bd9 --- /dev/null +++ b/tools/backward_graph_test.py @@ -0,0 +1,538 @@ +#!/usr/bin/env python3 +"""Batch backward graph generation and test_compiler validation tool. + +Usage: + # Only generate backward graphs + python tools/backward_graph_test.py \ + --sample-root /path/to/samples \ + --limit 100 \ + --output-dir /tmp/bw_results + + # Generate + test_compiler + kernel collection + python tools/backward_graph_test.py \ + --sample-root /path/to/samples \ + --limit 20 \ + --output-dir /tmp/bw_results \ + --test-compiler \ + --collect-kernels \ + --device cuda +""" + +import argparse +import inspect +import json +import os +import shutil +import subprocess +import sys +import traceback +from pathlib import Path + +import torch +from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_func + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) +from graph_net.torch.fx_graph_module_util import ( + get_torch_module_and_inputs, +) +from graph_net.torch.fx_graph_module_util import _get_tensor_metas as get_tensor_metas + + +def find_samples(root_dir, limit=-1): + """Recursively find sample directories containing model.py.""" + samples = [] + for dirpath, _, filenames in os.walk(root_dir): + if "model.py" in filenames: + samples.append(dirpath) + if limit > 0 and len(samples) >= limit: + break + return samples + + +def load_model_from_path(model_path, device="cpu"): + """Load GraphModule and inputs from a sample directory.""" + module, inputs = get_torch_module_and_inputs( + model_path, use_dummy_inputs=False, device=device + ) + return module, inputs + + +def set_requires_grad_by_meta(model_path, module, inputs): + """Set requires_grad based on weight_meta original_name.""" + try: + tensor_metas = get_tensor_metas(model_path) + except Exception: + return inputs + + name2tensor_meta = {tm.name: tm for tm in tensor_metas} + param_names = list(inspect.signature(module.forward).parameters.keys()) + for input_idx, name in enumerate(param_names): + if input_idx >= len(inputs): + break + tensor = inputs[input_idx] + if not isinstance(tensor, torch.Tensor): + continue + if not tensor.is_floating_point(): + continue + tm = name2tensor_meta.get(name) + if tm is None: + continue + check_name = tm.original_name or tm.name + nograd_keywords = [ + "running_mean", + "running_var", + "num_batches_tracked", + "mask", + "indices", + "position_ids", + "anchor", + ] + if any(kw in check_name for kw in nograd_keywords): + tensor.requires_grad = False + else: + tensor.requires_grad = True + return inputs + + +def is_pure_shape_graph(module): + """Check if the graph only contains shape manipulation ops.""" + if not hasattr(module, "graph"): + return False + shape_only_ops = { + torch.ops.aten.view, + torch.ops.aten.reshape, + torch.ops.aten.squeeze, + torch.ops.aten.unsqueeze, + torch.ops.aten.permute, + torch.ops.aten.transpose, + torch.ops.aten.expand, + torch.ops.aten.flatten, + torch.ops.aten.t, + "view", + "reshape", + "squeeze", + "unsqueeze", + "permute", + "transpose", + "expand", + "flatten", + "t", + } + for node in module.graph.nodes: + if node.op in {"placeholder", "output", "get_attr"}: + continue + if node.op == "call_function" and node.target in shape_only_ops: + continue + if node.op == "call_method" and node.target in shape_only_ops: + continue + return False + return True + + +def _tensor_meta_py_str(name, shape, dtype, device, mean=0.0, std=1.0): + dtype_str = str(dtype).replace("torch.", "torch.") + return ( + f"class Program_input_tensor_meta_{name}:\n" + f' name = "{name}"\n' + f" shape = {list(shape)}\n" + f' dtype = "{dtype_str}"\n' + f' device = "{device}"\n' + f" mean = {mean:.3f}\n" + f" std = {std:.3f}\n" + f" data = None\n" + ) + + +def _save_backward_model(gm, backward_inputs, output_path): + """Save backward GraphModule with model.py, input_meta.py, weight_meta.py.""" + os.makedirs(output_path, exist_ok=True) + + # model.py + model_py_path = os.path.join(output_path, "model.py") + code = gm.code + if "class GraphModule" not in code: + code = "import torch\n\n" "class GraphModule(torch.nn.Module):\n" + "\n".join( + " " + line if line.strip() else "" for line in code.split("\n") + ) + with open(model_py_path, "w", encoding="utf-8") as f: + f.write(code) + + # GraphNet test_compiler reads weight_meta.py as model inputs. + # Write backward graph inputs into weight_meta.py. + param_names = list(inspect.signature(gm.forward).parameters.keys()) + weight_meta_lines = [] + for idx, name in enumerate(param_names): + if idx < len(backward_inputs): + t = backward_inputs[idx] + if isinstance(t, torch.Tensor): + weight_meta_lines.append( + _tensor_meta_py_str( + name, + t.shape, + t.dtype, + str(t.device), + mean=0.0, + std=1.0, + ) + ) + weight_meta_path = os.path.join(output_path, "weight_meta.py") + with open(weight_meta_path, "w", encoding="utf-8") as f: + f.write("\n".join(weight_meta_lines)) + + # input_meta.py: empty (test_compiler does not use it) + input_meta_path = os.path.join(output_path, "input_meta.py") + with open(input_meta_path, "w", encoding="utf-8") as f: + f.write("") + + +def capture_backward_graph(module, inputs, device="cpu"): + """Capture forward and backward FX Graph via aot_autograd. + + Returns: + (backward_gm, backward_inputs) or (None, None) if no valid grad pairs. + """ + gm_holder = {} + backward_inputs = [] + + def forward_compiler(fx_gm, fwd_inputs): + gm_holder["forward_gm"] = fx_gm + return fx_gm + + def backward_compiler(fx_gm, bwd_inputs): + gm_holder["backward_gm"] = fx_gm + placeholders = [n for n in fx_gm.graph.nodes if n.op == "placeholder"] + origin_forward = fx_gm.forward + fx_gm._original_forward = origin_forward + + def wrapped_forward(*args): + for node, arg in zip(placeholders, args): + backward_inputs.append(arg.detach().clone()) + return origin_forward(*args) + + fx_gm.forward = wrapped_forward + return make_boxed_func(fx_gm) + + compiled = aot_module_simplified( + module, + inputs, + fw_compiler=forward_compiler, + bw_compiler=backward_compiler, + ) + outs = compiled(*inputs) + outs = [outs] if isinstance(outs, torch.Tensor) else outs + valid_pairs = [ + (out, torch.ones_like(out)) + for out in outs + if isinstance(out, torch.Tensor) and out.requires_grad + ] + + if not valid_pairs: + return None, None + + tensors, grads = zip(*valid_pairs) + torch.autograd.backward(tensors, grads) + + backward_gm = gm_holder.get("backward_gm") + if backward_gm is not None: + # Restore original forward for correct signature when saving + if hasattr(backward_gm, "_original_forward"): + backward_gm.forward = backward_gm._original_forward + backward_gm = _remove_none_from_output(backward_gm) + return backward_gm, backward_inputs + + +def _remove_none_from_output(gm): + output_node = next( + (n for n in gm.graph.nodes if n.op == "output"), + None, + ) + if output_node is None: + return gm + outs = ( + output_node.args[0] + if output_node and isinstance(output_node.args, (tuple, list)) + else output_node.args + ) + if isinstance(outs, (tuple, list)): + new_outs = tuple(out for out in outs if out is not None) + if new_outs != outs: + output_node.args = (new_outs,) + + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + return gm + + +def run_test_compiler(backward_model_path, device="cuda", compiler="nope", trials=10): + """Run graph_net_bench torch test_compiler on a backward model.""" + env = os.environ.copy() + env["GRAPH_NET_FLUCTUATION_DETECT_THRESHOLD"] = "0.5" + cmd = [ + sys.executable, + "-m", + "graph_net_bench.torch.test_compiler", + f"--model-path={backward_model_path}", + f"--compiler={compiler}", + f"--device={device}", + "--warmup=3", + f"--trials={trials}", + "--log-prompt=graph-net-backward-test-compiler-log", + ] + try: + result = subprocess.run( + cmd, capture_output=True, text=True, timeout=300, env=env + ) + stdout = result.stdout + stderr = result.stderr + success = result.returncode == 0 and "[Result] status: success" in stderr + return success, stdout, stderr, result.returncode + except subprocess.TimeoutExpired: + return False, "", "Timeout", -1 + except Exception as e: + return False, "", str(e), -1 + + +def collect_triton_kernels(backward_model_path, device="cuda"): + """Collect Triton kernels by running torch.compile with inductor backend.""" + kernels = [] + try: + import logging + + triton_logger = logging.getLogger("torch._inductor.codecache") + triton_handler = None + kernel_codes = [] + + class KernelCaptureHandler(logging.Handler): + def emit(self, record): + msg = record.getMessage() + if "triton" in msg.lower() and ".py" in msg.lower(): + kernel_codes.append(msg) + + triton_handler = KernelCaptureHandler() + triton_logger.addHandler(triton_handler) + triton_logger.setLevel(logging.DEBUG) + + module, inputs = get_torch_module_and_inputs( + backward_model_path, use_dummy_inputs=False, device=device + ) + compiled = torch.compile(module, backend="inductor") + _ = compiled(*inputs) + torch.cuda.synchronize() if "cuda" in device else None + + triton_logger.removeHandler(triton_handler) + kernels = kernel_codes + except Exception as e: + kernels = [f"Error collecting kernels: {e}"] + return kernels + + +def process_single_sample( + sample_path, + output_dir, + device="cpu", + test_compiler=False, + collect_kernels=False, + replace_inplace=False, + skip_pure_shape=True, +): + """Process a single sample: generate backward graph, optionally test and collect kernels. + + Returns: + dict with status and paths. + """ + rel_path = os.path.relpath( + sample_path, os.path.dirname(os.path.dirname(sample_path)) + ) + result = { + "sample": sample_path, + "rel_path": rel_path, + "status": "unknown", + "reason": "", + "backward_path": None, + "test_compiler_success": None, + "kernels": [], + } + + try: + module, inputs = load_model_from_path(sample_path, device=device) + module.eval() + + if skip_pure_shape and is_pure_shape_graph(module): + result["status"] = "skipped" + result["reason"] = "pure_shape_graph" + return result + + inputs = [ + inp.detach().clone() if isinstance(inp, torch.Tensor) else inp + for inp in inputs + ] + inputs = set_requires_grad_by_meta(sample_path, module, inputs) + + backward_gm, backward_inputs = capture_backward_graph( + module, inputs, device=device + ) + + if backward_gm is None: + result["status"] = "failed" + result["reason"] = "no_valid_grad_pairs" + return result + + backward_dir = os.path.join(output_dir, "backward_graphs", rel_path) + os.makedirs(backward_dir, exist_ok=True) + _save_backward_model(backward_gm, backward_inputs, backward_dir) + + # Copy graph_net.json if it exists + src_json = os.path.join(sample_path, "graph_net.json") + if os.path.exists(src_json): + shutil.copy2(src_json, os.path.join(backward_dir, "graph_net.json")) + + result["backward_path"] = backward_dir + result["status"] = "success" + + if test_compiler: + success, stdout, stderr, rc = run_test_compiler( + backward_dir, device=device, compiler="nope", trials=10 + ) + result["test_compiler_success"] = success + result["test_compiler_rc"] = rc + result["test_compiler_stderr"] = ( + stderr[-2000:] if len(stderr) > 2000 else stderr + ) + if not success: + result["status"] = "test_compiler_failed" + + if collect_kernels and result["status"] == "success": + kernels = collect_triton_kernels(backward_dir, device=device) + result["kernels"] = kernels + + except Exception as e: + result["status"] = "exception" + result["reason"] = f"{type(e).__name__}: {e}" + result["traceback"] = traceback.format_exc() + + return result + + +def main(): + parser = argparse.ArgumentParser( + description="Batch backward graph generation and validation." + ) + parser.add_argument( + "--sample-root", + type=str, + required=True, + help="Root directory containing subdirectories with model.py", + ) + parser.add_argument( + "--limit", + type=int, + default=-1, + help="Maximum number of samples to process (-1 for all)", + ) + parser.add_argument( + "--output-dir", + type=str, + required=True, + help="Output directory for backward graphs and results", + ) + parser.add_argument( + "--test-compiler", + action="store_true", + help="Run test_compiler on generated backward graphs", + ) + parser.add_argument( + "--collect-kernels", + action="store_true", + help="Collect Triton kernels from backward graphs", + ) + parser.add_argument( + "--replace-inplace", + action="store_true", + help="Auto replace inplace=True with inplace=False in model code", + ) + parser.add_argument( + "--skip-pure-shape", + action="store_true", + default=True, + help="Skip pure shape operation subgraphs (default: True)", + ) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + help="Device for model execution", + ) + parser.add_argument( + "--resume", + action="store_true", + help="Skip already processed samples in output dir", + ) + args = parser.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + samples = find_samples(args.sample_root, limit=args.limit) + print(f"Found {len(samples)} samples under {args.sample_root}") + + results = [] + stats = {"total": 0, "success": 0, "skipped": 0, "failed": 0, "exception": 0} + + for idx, sample_path in enumerate(samples): + print(f"[{idx + 1}/{len(samples)}] Processing {sample_path} ...") + stats["total"] += 1 + + if args.resume: + rel_path = os.path.relpath( + sample_path, os.path.dirname(os.path.dirname(sample_path)) + ) + backward_dir = os.path.join(args.output_dir, "backward_graphs", rel_path) + if os.path.exists(os.path.join(backward_dir, "model.py")): + print(" [Resume] Skip already processed.") + continue + + result = process_single_sample( + sample_path, + args.output_dir, + device=args.device, + test_compiler=args.test_compiler, + collect_kernels=args.collect_kernels, + replace_inplace=args.replace_inplace, + skip_pure_shape=args.skip_pure_shape, + ) + results.append(result) + + status = result["status"] + if status in stats: + stats[status] += 1 + else: + stats["exception"] += 1 + + print(f" Status: {status}") + if result.get("reason"): + print(f" Reason: {result['reason']}") + if result.get("test_compiler_success") is not None: + print( + f" test_compiler: {'success' if result['test_compiler_success'] else 'failed'}" + ) + + # Save results + summary = { + "args": vars(args), + "stats": stats, + "results": results, + } + result_path = os.path.join(args.output_dir, "backward_results.json") + with open(result_path, "w", encoding="utf-8") as f: + json.dump(summary, f, indent=2, ensure_ascii=False) + + print("\n=== Summary ===") + print(f"Total: {stats['total']}") + print(f"Success: {stats['success']}") + print(f"Skipped: {stats['skipped']}") + print(f"Failed: {stats['failed']}") + print(f"Exception: {stats['exception']}") + print(f"Results saved to: {result_path}") + + +if __name__ == "__main__": + main() diff --git a/tools/backward_kernel_dedup.py b/tools/backward_kernel_dedup.py new file mode 100755 index 0000000000..d32ad66de9 --- /dev/null +++ b/tools/backward_kernel_dedup.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python3 +"""Backward kernel dedup analysis tool. + +Usage: + python tools/backward_kernel_dedup.py \ + --backward-dir /tmp/bw_results/backward_graphs \ + --tag typical_backward \ + --output /tmp/bw_dedup.json +""" + +import argparse +import hashlib +import json +import os +import sys +from pathlib import Path + +import torch + + +def compile_and_extract_kernels(model_path, device="cuda"): + """Compile a backward model with inductor and extract kernel sources.""" + kernels = [] + try: + module_name = Path(model_path).name + # Import the model dynamically + import importlib.util + + spec = importlib.util.spec_from_file_location( + module_name, os.path.join(model_path, "model.py") + ) + mod = importlib.util.module_from_spec(spec) + sys.modules[module_name] = mod + spec.loader.exec_module(mod) + model = mod.GraphModule().to(device) + + # We need dummy inputs; try to load from weight_meta / input_meta if available + inputs = [] + for meta_file in ["input_meta.py", "weight_meta.py"]: + meta_path = os.path.join(model_path, meta_file) + if not os.path.exists(meta_path): + continue + spec = importlib.util.spec_from_file_location( + f"{module_name}_{meta_file}", meta_path + ) + meta_mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(meta_mod) + for attr_name in dir(meta_mod): + attr = getattr(meta_mod, attr_name) + if ( + isinstance(attr, type) + and hasattr(attr, "shape") + and hasattr(attr, "dtype") + ): + shape = attr.shape + dtype_str = attr.dtype.replace("torch.", "") + dtype_map = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "int64": torch.int64, + "int32": torch.int32, + "bool": torch.bool, + } + dtype = dtype_map.get(dtype_str, torch.float32) + if "int" in dtype_str or "bool" in dtype_str: + t = torch.zeros(shape, dtype=dtype, device=device) + else: + t = torch.randn(shape, dtype=dtype, device=device) + inputs.append(t) + + if not inputs: + return [] + + compiled = torch.compile(model, backend="inductor") + _ = compiled(*inputs) + if "cuda" in device: + torch.cuda.synchronize() + + # Try to read inductor generated code from cache + cache_dir = os.path.expanduser("~/.torchinductor") + if os.path.exists(cache_dir): + for root, _, files in os.walk(cache_dir): + for f in files: + if f.endswith(".py"): + with open(os.path.join(root, f), "r", encoding="utf-8") as fp: + content = fp.read() + if "triton" in content.lower(): + kernels.append(content) + except Exception as e: + print(f"Error extracting kernels from {model_path}: {e}") + return kernels + + +def hash_kernel(kernel_code): + """Compute a simple hash for a kernel source.""" + # Normalize by removing comments and extra whitespace + lines = [ + line.strip() + for line in kernel_code.split("\n") + if line.strip() and not line.strip().startswith("#") + ] + normalized = "\n".join(lines) + return hashlib.md5(normalized.encode("utf-8")).hexdigest() + + +def analyze_kernels(backward_dir, tag): + """Analyze kernel dedup for all backward graphs under backward_dir.""" + samples = [] + for dirpath, _, filenames in os.walk(backward_dir): + if "model.py" in filenames: + samples.append(dirpath) + + all_hashes = [] + sample_kernels = [] + + for sample_path in samples: + print(f"Processing {sample_path} ...") + kernels = compile_and_extract_kernels(sample_path) + hashes = [hash_kernel(k) for k in kernels] + all_hashes.extend(hashes) + sample_kernels.append( + { + "path": sample_path, + "kernel_count": len(kernels), + "hashes": hashes, + } + ) + + total = len(all_hashes) + unique = len(set(all_hashes)) + dedup_rate = (1 - unique / total) * 100 if total > 0 else 0 + + summary = { + "tag": tag, + "total_samples": len(samples), + "total_kernel_instances": total, + "unique_kernels": unique, + "dedup_rate_percent": round(dedup_rate, 2), + "avg_kernels_per_graph": round(total / len(samples), 2) if samples else 0, + "per_sample": sample_kernels, + } + return summary + + +def main(): + parser = argparse.ArgumentParser(description="Backward kernel dedup analysis.") + parser.add_argument( + "--backward-dir", + type=str, + required=True, + help="Directory containing backward graph subdirectories", + ) + parser.add_argument( + "--tag", + type=str, + default="backward", + help="Tag for the analysis (e.g., typical_backward, fusible_backward)", + ) + parser.add_argument( + "--output", + type=str, + required=True, + help="Output JSON path for dedup results", + ) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + help="Device for compilation", + ) + args = parser.parse_args() + + summary = analyze_kernels(args.backward_dir, args.tag) + with open(args.output, "w", encoding="utf-8") as f: + json.dump(summary, f, indent=2, ensure_ascii=False) + + print(f"\n=== Dedup Summary ({args.tag}) ===") + print(f"Total samples: {summary['total_samples']}") + print(f"Total kernel instances: {summary['total_kernel_instances']}") + print(f"Unique kernels: {summary['unique_kernels']}") + print(f"Dedup rate: {summary['dedup_rate_percent']}%") + print(f"Avg kernels/graph: {summary['avg_kernels_per_graph']}") + print(f"Result saved to: {args.output}") + + +if __name__ == "__main__": + main() From ea832a1c351c0433965a0edb1c7658034f64e0f6 Mon Sep 17 00:00:00 2001 From: Dayuxiaoshui <792179245@qq.com> Date: Mon, 18 May 2026 11:40:53 +0800 Subject: [PATCH 2/2] fix: address review feedback - remove is_pure_shape_graph and standalone tools - Remove _is_pure_shape_graph() from backward_graph_extractor.py per reviewer feedback (incomplete op whitelist, not maintainable) - Remove tools/backward_graph_test.py (use existing shell script graph_net/test/backward_graph_extractor.sh for batch processing) - Remove tools/backward_kernel_dedup.py (use existing graph_hash.txt based dedup in graph_net/tools/deduplicated.py) --- .../sample_pass/backward_graph_extractor.py | 36 -- tools/backward_graph_test.py | 538 ------------------ tools/backward_kernel_dedup.py | 188 ------ 3 files changed, 762 deletions(-) delete mode 100755 tools/backward_graph_test.py delete mode 100755 tools/backward_kernel_dedup.py diff --git a/graph_net/torch/sample_pass/backward_graph_extractor.py b/graph_net/torch/sample_pass/backward_graph_extractor.py index 10bc7d43a4..25ca3b81f7 100644 --- a/graph_net/torch/sample_pass/backward_graph_extractor.py +++ b/graph_net/torch/sample_pass/backward_graph_extractor.py @@ -29,10 +29,6 @@ def __call__(self): ) module.eval() - if self._is_pure_shape_graph(module): - print(f"[Skip] Pure shape graph: {self.model_path}") - return - eval_forward_dir = os.path.join( self.output_dir, "eval_forward", self.rel_model_path ) @@ -125,38 +121,6 @@ def _remove_none_from_output(self, gm): gm.recompile() return gm - def _is_pure_shape_graph(self, module): - """Check if the graph only contains shape manipulation ops.""" - shape_only_ops = { - torch.ops.aten.view, - torch.ops.aten.reshape, - torch.ops.aten.squeeze, - torch.ops.aten.unsqueeze, - torch.ops.aten.permute, - torch.ops.aten.transpose, - torch.ops.aten.expand, - torch.ops.aten.flatten, - torch.ops.aten.t, - "view", - "reshape", - "squeeze", - "unsqueeze", - "permute", - "transpose", - "expand", - "flatten", - "t", - } - for node in module.graph.nodes: - if node.op in {"placeholder", "output", "get_attr"}: - continue - if node.op == "call_function" and node.target in shape_only_ops: - continue - if node.op == "call_method" and node.target in shape_only_ops: - continue - return False - return True - def _requires_grad(self, name, tensor): if not tensor.is_floating_point(): return False diff --git a/tools/backward_graph_test.py b/tools/backward_graph_test.py deleted file mode 100755 index 3e36885bd9..0000000000 --- a/tools/backward_graph_test.py +++ /dev/null @@ -1,538 +0,0 @@ -#!/usr/bin/env python3 -"""Batch backward graph generation and test_compiler validation tool. - -Usage: - # Only generate backward graphs - python tools/backward_graph_test.py \ - --sample-root /path/to/samples \ - --limit 100 \ - --output-dir /tmp/bw_results - - # Generate + test_compiler + kernel collection - python tools/backward_graph_test.py \ - --sample-root /path/to/samples \ - --limit 20 \ - --output-dir /tmp/bw_results \ - --test-compiler \ - --collect-kernels \ - --device cuda -""" - -import argparse -import inspect -import json -import os -import shutil -import subprocess -import sys -import traceback -from pathlib import Path - -import torch -from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_func - -sys.path.insert(0, str(Path(__file__).resolve().parents[1])) -from graph_net.torch.fx_graph_module_util import ( - get_torch_module_and_inputs, -) -from graph_net.torch.fx_graph_module_util import _get_tensor_metas as get_tensor_metas - - -def find_samples(root_dir, limit=-1): - """Recursively find sample directories containing model.py.""" - samples = [] - for dirpath, _, filenames in os.walk(root_dir): - if "model.py" in filenames: - samples.append(dirpath) - if limit > 0 and len(samples) >= limit: - break - return samples - - -def load_model_from_path(model_path, device="cpu"): - """Load GraphModule and inputs from a sample directory.""" - module, inputs = get_torch_module_and_inputs( - model_path, use_dummy_inputs=False, device=device - ) - return module, inputs - - -def set_requires_grad_by_meta(model_path, module, inputs): - """Set requires_grad based on weight_meta original_name.""" - try: - tensor_metas = get_tensor_metas(model_path) - except Exception: - return inputs - - name2tensor_meta = {tm.name: tm for tm in tensor_metas} - param_names = list(inspect.signature(module.forward).parameters.keys()) - for input_idx, name in enumerate(param_names): - if input_idx >= len(inputs): - break - tensor = inputs[input_idx] - if not isinstance(tensor, torch.Tensor): - continue - if not tensor.is_floating_point(): - continue - tm = name2tensor_meta.get(name) - if tm is None: - continue - check_name = tm.original_name or tm.name - nograd_keywords = [ - "running_mean", - "running_var", - "num_batches_tracked", - "mask", - "indices", - "position_ids", - "anchor", - ] - if any(kw in check_name for kw in nograd_keywords): - tensor.requires_grad = False - else: - tensor.requires_grad = True - return inputs - - -def is_pure_shape_graph(module): - """Check if the graph only contains shape manipulation ops.""" - if not hasattr(module, "graph"): - return False - shape_only_ops = { - torch.ops.aten.view, - torch.ops.aten.reshape, - torch.ops.aten.squeeze, - torch.ops.aten.unsqueeze, - torch.ops.aten.permute, - torch.ops.aten.transpose, - torch.ops.aten.expand, - torch.ops.aten.flatten, - torch.ops.aten.t, - "view", - "reshape", - "squeeze", - "unsqueeze", - "permute", - "transpose", - "expand", - "flatten", - "t", - } - for node in module.graph.nodes: - if node.op in {"placeholder", "output", "get_attr"}: - continue - if node.op == "call_function" and node.target in shape_only_ops: - continue - if node.op == "call_method" and node.target in shape_only_ops: - continue - return False - return True - - -def _tensor_meta_py_str(name, shape, dtype, device, mean=0.0, std=1.0): - dtype_str = str(dtype).replace("torch.", "torch.") - return ( - f"class Program_input_tensor_meta_{name}:\n" - f' name = "{name}"\n' - f" shape = {list(shape)}\n" - f' dtype = "{dtype_str}"\n' - f' device = "{device}"\n' - f" mean = {mean:.3f}\n" - f" std = {std:.3f}\n" - f" data = None\n" - ) - - -def _save_backward_model(gm, backward_inputs, output_path): - """Save backward GraphModule with model.py, input_meta.py, weight_meta.py.""" - os.makedirs(output_path, exist_ok=True) - - # model.py - model_py_path = os.path.join(output_path, "model.py") - code = gm.code - if "class GraphModule" not in code: - code = "import torch\n\n" "class GraphModule(torch.nn.Module):\n" + "\n".join( - " " + line if line.strip() else "" for line in code.split("\n") - ) - with open(model_py_path, "w", encoding="utf-8") as f: - f.write(code) - - # GraphNet test_compiler reads weight_meta.py as model inputs. - # Write backward graph inputs into weight_meta.py. - param_names = list(inspect.signature(gm.forward).parameters.keys()) - weight_meta_lines = [] - for idx, name in enumerate(param_names): - if idx < len(backward_inputs): - t = backward_inputs[idx] - if isinstance(t, torch.Tensor): - weight_meta_lines.append( - _tensor_meta_py_str( - name, - t.shape, - t.dtype, - str(t.device), - mean=0.0, - std=1.0, - ) - ) - weight_meta_path = os.path.join(output_path, "weight_meta.py") - with open(weight_meta_path, "w", encoding="utf-8") as f: - f.write("\n".join(weight_meta_lines)) - - # input_meta.py: empty (test_compiler does not use it) - input_meta_path = os.path.join(output_path, "input_meta.py") - with open(input_meta_path, "w", encoding="utf-8") as f: - f.write("") - - -def capture_backward_graph(module, inputs, device="cpu"): - """Capture forward and backward FX Graph via aot_autograd. - - Returns: - (backward_gm, backward_inputs) or (None, None) if no valid grad pairs. - """ - gm_holder = {} - backward_inputs = [] - - def forward_compiler(fx_gm, fwd_inputs): - gm_holder["forward_gm"] = fx_gm - return fx_gm - - def backward_compiler(fx_gm, bwd_inputs): - gm_holder["backward_gm"] = fx_gm - placeholders = [n for n in fx_gm.graph.nodes if n.op == "placeholder"] - origin_forward = fx_gm.forward - fx_gm._original_forward = origin_forward - - def wrapped_forward(*args): - for node, arg in zip(placeholders, args): - backward_inputs.append(arg.detach().clone()) - return origin_forward(*args) - - fx_gm.forward = wrapped_forward - return make_boxed_func(fx_gm) - - compiled = aot_module_simplified( - module, - inputs, - fw_compiler=forward_compiler, - bw_compiler=backward_compiler, - ) - outs = compiled(*inputs) - outs = [outs] if isinstance(outs, torch.Tensor) else outs - valid_pairs = [ - (out, torch.ones_like(out)) - for out in outs - if isinstance(out, torch.Tensor) and out.requires_grad - ] - - if not valid_pairs: - return None, None - - tensors, grads = zip(*valid_pairs) - torch.autograd.backward(tensors, grads) - - backward_gm = gm_holder.get("backward_gm") - if backward_gm is not None: - # Restore original forward for correct signature when saving - if hasattr(backward_gm, "_original_forward"): - backward_gm.forward = backward_gm._original_forward - backward_gm = _remove_none_from_output(backward_gm) - return backward_gm, backward_inputs - - -def _remove_none_from_output(gm): - output_node = next( - (n for n in gm.graph.nodes if n.op == "output"), - None, - ) - if output_node is None: - return gm - outs = ( - output_node.args[0] - if output_node and isinstance(output_node.args, (tuple, list)) - else output_node.args - ) - if isinstance(outs, (tuple, list)): - new_outs = tuple(out for out in outs if out is not None) - if new_outs != outs: - output_node.args = (new_outs,) - - gm.graph.eliminate_dead_code() - gm.graph.lint() - gm.recompile() - return gm - - -def run_test_compiler(backward_model_path, device="cuda", compiler="nope", trials=10): - """Run graph_net_bench torch test_compiler on a backward model.""" - env = os.environ.copy() - env["GRAPH_NET_FLUCTUATION_DETECT_THRESHOLD"] = "0.5" - cmd = [ - sys.executable, - "-m", - "graph_net_bench.torch.test_compiler", - f"--model-path={backward_model_path}", - f"--compiler={compiler}", - f"--device={device}", - "--warmup=3", - f"--trials={trials}", - "--log-prompt=graph-net-backward-test-compiler-log", - ] - try: - result = subprocess.run( - cmd, capture_output=True, text=True, timeout=300, env=env - ) - stdout = result.stdout - stderr = result.stderr - success = result.returncode == 0 and "[Result] status: success" in stderr - return success, stdout, stderr, result.returncode - except subprocess.TimeoutExpired: - return False, "", "Timeout", -1 - except Exception as e: - return False, "", str(e), -1 - - -def collect_triton_kernels(backward_model_path, device="cuda"): - """Collect Triton kernels by running torch.compile with inductor backend.""" - kernels = [] - try: - import logging - - triton_logger = logging.getLogger("torch._inductor.codecache") - triton_handler = None - kernel_codes = [] - - class KernelCaptureHandler(logging.Handler): - def emit(self, record): - msg = record.getMessage() - if "triton" in msg.lower() and ".py" in msg.lower(): - kernel_codes.append(msg) - - triton_handler = KernelCaptureHandler() - triton_logger.addHandler(triton_handler) - triton_logger.setLevel(logging.DEBUG) - - module, inputs = get_torch_module_and_inputs( - backward_model_path, use_dummy_inputs=False, device=device - ) - compiled = torch.compile(module, backend="inductor") - _ = compiled(*inputs) - torch.cuda.synchronize() if "cuda" in device else None - - triton_logger.removeHandler(triton_handler) - kernels = kernel_codes - except Exception as e: - kernels = [f"Error collecting kernels: {e}"] - return kernels - - -def process_single_sample( - sample_path, - output_dir, - device="cpu", - test_compiler=False, - collect_kernels=False, - replace_inplace=False, - skip_pure_shape=True, -): - """Process a single sample: generate backward graph, optionally test and collect kernels. - - Returns: - dict with status and paths. - """ - rel_path = os.path.relpath( - sample_path, os.path.dirname(os.path.dirname(sample_path)) - ) - result = { - "sample": sample_path, - "rel_path": rel_path, - "status": "unknown", - "reason": "", - "backward_path": None, - "test_compiler_success": None, - "kernels": [], - } - - try: - module, inputs = load_model_from_path(sample_path, device=device) - module.eval() - - if skip_pure_shape and is_pure_shape_graph(module): - result["status"] = "skipped" - result["reason"] = "pure_shape_graph" - return result - - inputs = [ - inp.detach().clone() if isinstance(inp, torch.Tensor) else inp - for inp in inputs - ] - inputs = set_requires_grad_by_meta(sample_path, module, inputs) - - backward_gm, backward_inputs = capture_backward_graph( - module, inputs, device=device - ) - - if backward_gm is None: - result["status"] = "failed" - result["reason"] = "no_valid_grad_pairs" - return result - - backward_dir = os.path.join(output_dir, "backward_graphs", rel_path) - os.makedirs(backward_dir, exist_ok=True) - _save_backward_model(backward_gm, backward_inputs, backward_dir) - - # Copy graph_net.json if it exists - src_json = os.path.join(sample_path, "graph_net.json") - if os.path.exists(src_json): - shutil.copy2(src_json, os.path.join(backward_dir, "graph_net.json")) - - result["backward_path"] = backward_dir - result["status"] = "success" - - if test_compiler: - success, stdout, stderr, rc = run_test_compiler( - backward_dir, device=device, compiler="nope", trials=10 - ) - result["test_compiler_success"] = success - result["test_compiler_rc"] = rc - result["test_compiler_stderr"] = ( - stderr[-2000:] if len(stderr) > 2000 else stderr - ) - if not success: - result["status"] = "test_compiler_failed" - - if collect_kernels and result["status"] == "success": - kernels = collect_triton_kernels(backward_dir, device=device) - result["kernels"] = kernels - - except Exception as e: - result["status"] = "exception" - result["reason"] = f"{type(e).__name__}: {e}" - result["traceback"] = traceback.format_exc() - - return result - - -def main(): - parser = argparse.ArgumentParser( - description="Batch backward graph generation and validation." - ) - parser.add_argument( - "--sample-root", - type=str, - required=True, - help="Root directory containing subdirectories with model.py", - ) - parser.add_argument( - "--limit", - type=int, - default=-1, - help="Maximum number of samples to process (-1 for all)", - ) - parser.add_argument( - "--output-dir", - type=str, - required=True, - help="Output directory for backward graphs and results", - ) - parser.add_argument( - "--test-compiler", - action="store_true", - help="Run test_compiler on generated backward graphs", - ) - parser.add_argument( - "--collect-kernels", - action="store_true", - help="Collect Triton kernels from backward graphs", - ) - parser.add_argument( - "--replace-inplace", - action="store_true", - help="Auto replace inplace=True with inplace=False in model code", - ) - parser.add_argument( - "--skip-pure-shape", - action="store_true", - default=True, - help="Skip pure shape operation subgraphs (default: True)", - ) - parser.add_argument( - "--device", - type=str, - default="cuda" if torch.cuda.is_available() else "cpu", - help="Device for model execution", - ) - parser.add_argument( - "--resume", - action="store_true", - help="Skip already processed samples in output dir", - ) - args = parser.parse_args() - - os.makedirs(args.output_dir, exist_ok=True) - samples = find_samples(args.sample_root, limit=args.limit) - print(f"Found {len(samples)} samples under {args.sample_root}") - - results = [] - stats = {"total": 0, "success": 0, "skipped": 0, "failed": 0, "exception": 0} - - for idx, sample_path in enumerate(samples): - print(f"[{idx + 1}/{len(samples)}] Processing {sample_path} ...") - stats["total"] += 1 - - if args.resume: - rel_path = os.path.relpath( - sample_path, os.path.dirname(os.path.dirname(sample_path)) - ) - backward_dir = os.path.join(args.output_dir, "backward_graphs", rel_path) - if os.path.exists(os.path.join(backward_dir, "model.py")): - print(" [Resume] Skip already processed.") - continue - - result = process_single_sample( - sample_path, - args.output_dir, - device=args.device, - test_compiler=args.test_compiler, - collect_kernels=args.collect_kernels, - replace_inplace=args.replace_inplace, - skip_pure_shape=args.skip_pure_shape, - ) - results.append(result) - - status = result["status"] - if status in stats: - stats[status] += 1 - else: - stats["exception"] += 1 - - print(f" Status: {status}") - if result.get("reason"): - print(f" Reason: {result['reason']}") - if result.get("test_compiler_success") is not None: - print( - f" test_compiler: {'success' if result['test_compiler_success'] else 'failed'}" - ) - - # Save results - summary = { - "args": vars(args), - "stats": stats, - "results": results, - } - result_path = os.path.join(args.output_dir, "backward_results.json") - with open(result_path, "w", encoding="utf-8") as f: - json.dump(summary, f, indent=2, ensure_ascii=False) - - print("\n=== Summary ===") - print(f"Total: {stats['total']}") - print(f"Success: {stats['success']}") - print(f"Skipped: {stats['skipped']}") - print(f"Failed: {stats['failed']}") - print(f"Exception: {stats['exception']}") - print(f"Results saved to: {result_path}") - - -if __name__ == "__main__": - main() diff --git a/tools/backward_kernel_dedup.py b/tools/backward_kernel_dedup.py deleted file mode 100755 index d32ad66de9..0000000000 --- a/tools/backward_kernel_dedup.py +++ /dev/null @@ -1,188 +0,0 @@ -#!/usr/bin/env python3 -"""Backward kernel dedup analysis tool. - -Usage: - python tools/backward_kernel_dedup.py \ - --backward-dir /tmp/bw_results/backward_graphs \ - --tag typical_backward \ - --output /tmp/bw_dedup.json -""" - -import argparse -import hashlib -import json -import os -import sys -from pathlib import Path - -import torch - - -def compile_and_extract_kernels(model_path, device="cuda"): - """Compile a backward model with inductor and extract kernel sources.""" - kernels = [] - try: - module_name = Path(model_path).name - # Import the model dynamically - import importlib.util - - spec = importlib.util.spec_from_file_location( - module_name, os.path.join(model_path, "model.py") - ) - mod = importlib.util.module_from_spec(spec) - sys.modules[module_name] = mod - spec.loader.exec_module(mod) - model = mod.GraphModule().to(device) - - # We need dummy inputs; try to load from weight_meta / input_meta if available - inputs = [] - for meta_file in ["input_meta.py", "weight_meta.py"]: - meta_path = os.path.join(model_path, meta_file) - if not os.path.exists(meta_path): - continue - spec = importlib.util.spec_from_file_location( - f"{module_name}_{meta_file}", meta_path - ) - meta_mod = importlib.util.module_from_spec(spec) - spec.loader.exec_module(meta_mod) - for attr_name in dir(meta_mod): - attr = getattr(meta_mod, attr_name) - if ( - isinstance(attr, type) - and hasattr(attr, "shape") - and hasattr(attr, "dtype") - ): - shape = attr.shape - dtype_str = attr.dtype.replace("torch.", "") - dtype_map = { - "float32": torch.float32, - "float16": torch.float16, - "bfloat16": torch.bfloat16, - "int64": torch.int64, - "int32": torch.int32, - "bool": torch.bool, - } - dtype = dtype_map.get(dtype_str, torch.float32) - if "int" in dtype_str or "bool" in dtype_str: - t = torch.zeros(shape, dtype=dtype, device=device) - else: - t = torch.randn(shape, dtype=dtype, device=device) - inputs.append(t) - - if not inputs: - return [] - - compiled = torch.compile(model, backend="inductor") - _ = compiled(*inputs) - if "cuda" in device: - torch.cuda.synchronize() - - # Try to read inductor generated code from cache - cache_dir = os.path.expanduser("~/.torchinductor") - if os.path.exists(cache_dir): - for root, _, files in os.walk(cache_dir): - for f in files: - if f.endswith(".py"): - with open(os.path.join(root, f), "r", encoding="utf-8") as fp: - content = fp.read() - if "triton" in content.lower(): - kernels.append(content) - except Exception as e: - print(f"Error extracting kernels from {model_path}: {e}") - return kernels - - -def hash_kernel(kernel_code): - """Compute a simple hash for a kernel source.""" - # Normalize by removing comments and extra whitespace - lines = [ - line.strip() - for line in kernel_code.split("\n") - if line.strip() and not line.strip().startswith("#") - ] - normalized = "\n".join(lines) - return hashlib.md5(normalized.encode("utf-8")).hexdigest() - - -def analyze_kernels(backward_dir, tag): - """Analyze kernel dedup for all backward graphs under backward_dir.""" - samples = [] - for dirpath, _, filenames in os.walk(backward_dir): - if "model.py" in filenames: - samples.append(dirpath) - - all_hashes = [] - sample_kernels = [] - - for sample_path in samples: - print(f"Processing {sample_path} ...") - kernels = compile_and_extract_kernels(sample_path) - hashes = [hash_kernel(k) for k in kernels] - all_hashes.extend(hashes) - sample_kernels.append( - { - "path": sample_path, - "kernel_count": len(kernels), - "hashes": hashes, - } - ) - - total = len(all_hashes) - unique = len(set(all_hashes)) - dedup_rate = (1 - unique / total) * 100 if total > 0 else 0 - - summary = { - "tag": tag, - "total_samples": len(samples), - "total_kernel_instances": total, - "unique_kernels": unique, - "dedup_rate_percent": round(dedup_rate, 2), - "avg_kernels_per_graph": round(total / len(samples), 2) if samples else 0, - "per_sample": sample_kernels, - } - return summary - - -def main(): - parser = argparse.ArgumentParser(description="Backward kernel dedup analysis.") - parser.add_argument( - "--backward-dir", - type=str, - required=True, - help="Directory containing backward graph subdirectories", - ) - parser.add_argument( - "--tag", - type=str, - default="backward", - help="Tag for the analysis (e.g., typical_backward, fusible_backward)", - ) - parser.add_argument( - "--output", - type=str, - required=True, - help="Output JSON path for dedup results", - ) - parser.add_argument( - "--device", - type=str, - default="cuda" if torch.cuda.is_available() else "cpu", - help="Device for compilation", - ) - args = parser.parse_args() - - summary = analyze_kernels(args.backward_dir, args.tag) - with open(args.output, "w", encoding="utf-8") as f: - json.dump(summary, f, indent=2, ensure_ascii=False) - - print(f"\n=== Dedup Summary ({args.tag}) ===") - print(f"Total samples: {summary['total_samples']}") - print(f"Total kernel instances: {summary['total_kernel_instances']}") - print(f"Unique kernels: {summary['unique_kernels']}") - print(f"Dedup rate: {summary['dedup_rate_percent']}%") - print(f"Avg kernels/graph: {summary['avg_kernels_per_graph']}") - print(f"Result saved to: {args.output}") - - -if __name__ == "__main__": - main()