diff --git a/tools/extract_triton_kernels.sh b/tools/extract_triton_kernels.sh new file mode 100755 index 0000000000..6ff9e8fa63 --- /dev/null +++ b/tools/extract_triton_kernels.sh @@ -0,0 +1,88 @@ +#!/bin/bash +set -euo pipefail + +# Thin launcher for the triton kernel extraction pipeline. +# +# This script sets machine-specific paths and invokes the Python pipeline +# once per dataset category, using pre-computed allow-list files. +# +# Usage: +# bash tools/extract_triton_kernels.sh [gpu_ids] +# +# Args: +# gpu_ids (optional): comma-separated GPU IDs, e.g. "0,2,5,7" +# +# Examples: +# bash tools/extract_triton_kernels.sh # auto-detect GPUs +# bash tools/extract_triton_kernels.sh 0,2,5,7 # specified GPUs + +# ============================================================ +# Arguments +# ============================================================ + +GPU_ARG="${1:-}" + +# ============================================================ +# Machine-specific path configuration +# +# Edit the variables below to match your local environment. +# ============================================================ + +# Root directory containing graph data (base for resolving paths in allow-lists). +GRAPH_DIR="/path/to/input_graph_data" + +# Root directory for pipeline output (cache + extracted kernels). +OUTPUT_BASE="/path/to/dataset_output" + +# Dataset categories to process. +CATEGORIES=(sole_op_subgraphs fusible_subgraphs typical_subgraphs) + +# Allow-list files (one per category, same order as CATEGORIES). +ALLOW_LISTS=( + "${OUTPUT_BASE}/hf_sole_op_samples_v2_all_expanded.txt" + "${OUTPUT_BASE}/hf_fusible_samples_v2_all_expanded.txt" + "${OUTPUT_BASE}/hf_typical_samples_v2_all_expanded.txt" +) + +# ============================================================ +# Environment setup +# ============================================================ + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" + +# Validate that placeholder paths have been configured. +if [ "$GRAPH_DIR" = "/path/to/input_graph_data" ] || [ "$OUTPUT_BASE" = "/path/to/dataset_output" ]; then + echo "ERROR: Edit GRAPH_DIR and OUTPUT_BASE in this script before running." >&2 + exit 1 +fi + +export PYTHONPATH="${REPO_ROOT}:${PYTHONPATH:-}" + +# ============================================================ +# Build GPU args +# ============================================================ + +GPU_ARGS=() +if [ -n "$GPU_ARG" ]; then + IFS=',' read -ra GPU_IDS <<< "$GPU_ARG" + GPU_ARGS=(--gpu-ids "${GPU_IDS[@]}") +fi + +# ============================================================ +# Run pipeline for each category +# ============================================================ + +for i in "${!CATEGORIES[@]}"; do + CATEGORY="${CATEGORIES[$i]}" + + python3 -m tools.triton_kernel_extractor extract \ + --allow-list "${ALLOW_LISTS[$i]}" \ + --graph-dir "$GRAPH_DIR" \ + --output-dir "${OUTPUT_BASE}/${CATEGORY}_inductor_dump" \ + --max-autotune-no-cudagraphs \ + --enable-cache-analysis \ + "${GPU_ARGS[@]}" +done + +echo "All categories processed." diff --git a/tools/triton_kernel_extractor/README.md b/tools/triton_kernel_extractor/README.md new file mode 100644 index 0000000000..2276bc35f9 --- /dev/null +++ b/tools/triton_kernel_extractor/README.md @@ -0,0 +1,178 @@ +# Triton Kernel Extractor + +A pipeline that compiles computational subgraphs through TorchInductor, filters +the results by kernel-level speedup, and extracts the autotuning-selected Triton +kernel source together with the corresponding PTX assembly from the inductor +compilation cache. + +## Background + +When `torch.compile` processes a model via the TorchInductor backend with +`TORCH_COMPILE_DEBUG=1`, the compiler produces a per-graph cache directory +containing: + +- **`output_code.py`** — the generated Python wrapper that calls into Triton + kernels via `async_compile.triton('kernel_name', '''...''')`. The kernels + appearing here are the final, autotuning-selected implementations adopted by + the inductor scheduler. +- **`triton/0/{HASH}/`** — one directory per autotuning candidate + configuration (varying `XBLOCK`, `YBLOCK`, `num_warps`, etc.), each holding + the compiled artifacts (`.ptx`, `.cubin`, `.ttir`, `.llir`, `.source`, + `.json`). When autotuning explores N configurations for a kernel, N + directories are created. +- **`*.best_config`** — a JSON file written by the Triton autotuner recording + the winning configuration. Its `triton_cache_hash` field maps back to one of + the `triton/0/{HASH}/` directories. + +This pipeline automates the full workflow: compile → filter → clean → extract → +pair, producing clean `(subgraph, triton_kernel, ptx)` triples ready for +downstream analysis. + +## Pipeline Steps + +The pipeline executes five steps on the samples enumerated from `--graph-dir` +(recursive scan) or `--allow-list` (explicit paths): + +### Step 1: Multi-GPU Parallel Compilation + +Compiles each subgraph sample using `graph_net_bench.torch.test_compiler +--kernel-time` in an isolated subprocess. Samples are distributed across +available GPUs in round-robin fashion, with one `ProcessPoolExecutor` worker per +GPU. Each subprocess receives a dedicated `CUDA_VISIBLE_DEVICES` and an +isolated `TORCHINDUCTOR_CACHE_DIR`. Pass `--max-autotune-no-cudagraphs` to enable +Inductor's `max-autotune-no-cudagraphs` mode (via +`torch.compile(mode="max-autotune-no-cudagraphs")`), which activates comprehensive +autotuning including `max_autotune_gemm`, +`coordinate_descent_tuning`, and `epilogue_fusion`. + +### Step 2: Speedup Filtering + +Parses the `[Speedup][kernel]:` metric from each sample's compilation log (the +last occurrence is used). Samples achieving a speedup >= 1.0 are moved to +`kept/`; the rest are moved to `discarded/`. + +### Step 3: Temporary File Cleanup + +Recursively removes `__pycache__/` directories, `*.pyc`, and `*.pyo` files from +the output tree to reduce storage footprint before extraction. + +### Step 4: Kernel and PTX Extraction + +For each kept sample that contains `original_graph/graph_hash.txt`: + +1. Copies `original_graph/model.py` (the source subgraph) into the output. +2. Parses `output_code.py` to extract all Triton kernel definitions using a + regex equivalent of the original Perl one-liner. +3. Writes each kernel source to `triton_kernel/{kernel_name}.py`. +4. Locates the corresponding PTX for each kernel by scanning `triton/0/` and + disambiguating via `.best_config` when multiple autotuning candidates exist, + then writes it to `ptx/{kernel_name}.ptx`. + +Output is written atomically (`.tmp` directory + `rename`) so that an +interrupted run never leaves half-written data. + +### Step 5: Empty Sample Cleanup + +Removes output samples that contain `original_graph/` but no `triton_kernel/` +directory (i.e., samples where no Triton kernels were extracted). + +## PTX Resolution Algorithm + +Each Triton kernel may have been compiled under multiple autotuning +configurations. The algorithm to locate the winning PTX is: + +1. Scan `triton/0/*/` for directories containing `{kernel_name}.ptx`. +2. If exactly one candidate exists, use it directly (no autotuning was needed). +3. If multiple candidates exist, collect `triton_cache_hash` values from all + `*.best_config` files in the sample, and select the candidate whose directory + name matches one of these hashes. + +## Output Structure + +``` +{output-dir}/ + {sample_name}/ # compilation cache (kept/discarded) + kept/ + discarded/ + extracted/ + {sample_name}/ + original_graph/ + model.py # source subgraph + triton_kernel/ + triton_poi_fused_xxx_0.py # Triton kernel source + triton_poi_fused_yyy_1.py + ptx/ + triton_poi_fused_xxx_0.ptx # corresponding PTX assembly + triton_poi_fused_yyy_1.ptx + analysis/ # if --enable-cache-analysis +``` + +## Usage + +```bash +# With allow-list: read sample paths from file, resolve against --graph-dir +python3 -m tools.triton_kernel_extractor extract \ + --allow-list /data/typical_samples_expanded.txt \ + --graph-dir /data/graphs/typical_subgraphs \ + --output-dir /data/output/typical_inductor_dump \ + --gpu-ids 0 2 5 7 + +# Without allow-list: recursively find all model.py in --graph-dir +python3 -m tools.triton_kernel_extractor extract \ + --graph-dir /data/graphs/typical_subgraphs \ + --output-dir /data/output/typical_inductor_dump \ + --gpu-ids 0 2 5 7 \ + --max-autotune-no-cudagraphs \ + --enable-cache-analysis + +# Cache analysis standalone: +python3 -m tools.triton_kernel_extractor analyze [--output-dir DIR] +``` + +### CLI Arguments + +#### `extract` subcommand + +| Argument | Required | Default | Description | +|---------------------------|----------|----------------------|-----------------------------------------------------------------------------| +| `--allow-list` | No | `None` | Text file with sample paths (one per line), relative to `--graph-dir`. When omitted, `--graph-dir` is scanned recursively for `model.py` | +| `--graph-dir` | Yes | — | Input graph data root. Scanned for `model.py` by default; path resolution base when `--allow-list` is given | +| `--output-dir` | Yes | — | Pipeline output directory (compilation cache, extracted kernels, analysis) | +| `--gpu-ids` | No | Auto-detected | GPU IDs for parallel compilation. Auto-detected via `CUDA_VISIBLE_DEVICES` or `nvidia-smi` when omitted | +| `--max-autotune-no-cudagraphs` | No | `False` | Enable Inductor `max-autotune-no-cudagraphs` mode for compilation | +| `--enable-cache-analysis` | No | `False` | Run cache analysis (statistics, plots) after extraction | + +#### `analyze` subcommand + +| Argument | Required | Default | Description | +|----------------|----------|----------------------|-------------------------------------------------| +| `cache_dir` | Yes | — | Inductor cache directory to analyze | +| `--output-dir` | No | `/analysis` | Directory for analysis output | + +## Module Structure + +``` +triton_kernel_extractor/ + __init__.py # package marker + __main__.py # CLI entry point (subcommands: extract, analyze) + config.py # PipelineConfig, constants + sample_enumerator.py # enumerate samples from graph-dir or allow-list + compiler.py # Step 1: multi-GPU parallel compilation + speedup_filter.py # Step 2: filter by kernel speedup + temp_cleaner.py # Step 3: remove __pycache__ / *.pyc / *.pyo + kernel_extractor.py # Step 4: extract Triton kernels and PTX + empty_sample_cleaner.py # Step 5: remove samples without Triton kernels + pipeline.py # orchestrate Steps 1-5 + cache_analyzer.py # analyze cache: logs, statistics, plots +``` + +## Idempotency and Resume + +Every step implements skip logic to support safe re-execution: + +- **Compilation** skips samples whose log already contains `[Speedup][kernel]:` + or that already exist under `kept/` or `discarded/`. +- **Filtering** skips samples already classified into `kept/` or `discarded/`. +- **Extraction** skips output samples that already exist in the output directory. + Stale `.tmp` directories from prior interrupted runs are cleaned up + automatically on startup. diff --git a/tools/triton_kernel_extractor/__init__.py b/tools/triton_kernel_extractor/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tools/triton_kernel_extractor/__main__.py b/tools/triton_kernel_extractor/__main__.py new file mode 100644 index 0000000000..c55aa002ab --- /dev/null +++ b/tools/triton_kernel_extractor/__main__.py @@ -0,0 +1,282 @@ +"""CLI entry point for the triton kernel extraction pipeline. + +Subcommands +----------- + +**extract** (default when no subcommand is given):: + + python3 -m tools.triton_kernel_extractor [extract] \\ + --graph-dir /data/graphs/typical_subgraphs \\ + --output-dir /data/output/typical_inductor_dump \\ + [--allow-list paths.txt] \\ + [--gpu-ids 0 2 5 7] + +**analyze**:: + + python3 -m tools.triton_kernel_extractor analyze \\ + [--output-dir DIR] + +When ``--allow-list`` is provided, sample paths are read from the file and +resolved relative to ``--graph-dir``. Otherwise ``--graph-dir`` is scanned +recursively for ``model.py`` files. + +When ``--gpu-ids`` is omitted the script auto-detects all available GPUs +by parsing the output of ``nvidia-smi -L``. +""" + +from __future__ import annotations + +import argparse +import logging +import os +import re +import subprocess +import sys +from pathlib import Path + +from .config import PipelineConfig + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# GPU detection (shared by the extract subcommand) +# --------------------------------------------------------------------------- + + +def _detect_gpu_ids() -> list[int]: + """Auto-detect available GPU IDs. + + Priority order (matching the original bash script): + 1. ``CUDA_VISIBLE_DEVICES`` environment variable + 2. ``nvidia-smi -L`` output + + Returns a list of integer GPU indices. Raises ``RuntimeError`` + when no GPUs are found. + """ + # Priority 1: honour CUDA_VISIBLE_DEVICES if set. + cuda_env = os.environ.get("CUDA_VISIBLE_DEVICES", "").strip() + if cuda_env: + try: + return [int(x) for x in cuda_env.split(",") if x.strip()] + except ValueError: + pass # Fall through to nvidia-smi. + + # Priority 2: auto-detect from nvidia-smi. + try: + result = subprocess.run( + ["nvidia-smi", "-L"], + capture_output=True, + text=True, + timeout=10, + ) + ids = [int(m) for m in re.findall(r"GPU (\d+):", result.stdout)] + except (FileNotFoundError, subprocess.TimeoutExpired): + ids = [] + + if not ids: + raise RuntimeError( + "No GPUs detected. Pass --gpu-ids explicitly or check nvidia-smi." + ) + return ids + + +# --------------------------------------------------------------------------- +# Subcommand: extract +# --------------------------------------------------------------------------- + + +def _add_extract_parser(subparsers: argparse._SubParsersAction) -> None: + parser = subparsers.add_parser( + "extract", + help="Run the compilation and extraction pipeline.", + description=( + "Compile graph samples and extract " + "(subgraph, triton_kernel, ptx) triples." + ), + ) + parser.add_argument( + "--allow-list", + type=Path, + default=None, + help=( + "Text file with sample paths (one per line), relative to --graph-dir. " + "When omitted, --graph-dir is scanned recursively for model.py." + ), + ) + parser.add_argument( + "--graph-dir", + type=Path, + required=True, + help=( + "Root directory of input graph data. " + "Scanned recursively for model.py by default. " + "When --allow-list is given, used as base for resolving relative paths." + ), + ) + parser.add_argument( + "--output-dir", + type=Path, + required=True, + help="Directory for pipeline output (compilation cache, extracted kernels, analysis).", + ) + parser.add_argument( + "--gpu-ids", + type=int, + nargs="*", + default=None, + help=( + "GPU IDs to use for parallel compilation. " + "Auto-detected via nvidia-smi when omitted." + ), + ) + parser.add_argument( + "--enable-cache-analysis", + action="store_true", + default=False, + help="Run cache analysis (statistics, plots) after extraction.", + ) + parser.add_argument( + "--max-autotune-no-cudagraphs", + action="store_true", + default=False, + help=( + "Enable Inductor max-autotune-no-cudagraphs mode during compilation " + "(passes mode='max-autotune-no-cudagraphs' to torch.compile " + "via graph_net_bench.torch.test_compiler)." + ), + ) + parser.set_defaults(func=_run_extract) + + +def _run_extract(args: argparse.Namespace) -> None: + from .pipeline import run_pipeline + + gpu_ids = args.gpu_ids if args.gpu_ids else _detect_gpu_ids() + + config = PipelineConfig( + gpu_ids=gpu_ids, + graph_dir=args.graph_dir, + output_dir=args.output_dir, + allow_list=args.allow_list, + max_autotune_no_cudagraphs=args.max_autotune_no_cudagraphs, + ) + + logger.info( + "Using %d GPU(s): %s", + len(config.gpu_ids), + " ".join(str(g) for g in config.gpu_ids), + ) + + # Unset CUDA_VISIBLE_DEVICES in the parent process so that worker + # subprocesses start with a clean slate and receive only the per-GPU + # value assigned by compiler.py. + os.environ.pop("CUDA_VISIBLE_DEVICES", None) + + run_pipeline(config, enable_cache_analysis=args.enable_cache_analysis) + + +# --------------------------------------------------------------------------- +# Subcommand: analyze +# --------------------------------------------------------------------------- + + +def _add_analyze_parser(subparsers: argparse._SubParsersAction) -> None: + parser = subparsers.add_parser( + "analyze", + help="Analyze an inductor cache directory (logs, statistics, plots).", + description=( + "Concatenate compiler logs, compute speedup statistics, and " + "generate distribution plots for an inductor cache directory." + ), + ) + parser.add_argument( + "cache_dir", + type=Path, + help="Inductor cache directory to analyze.", + ) + parser.add_argument( + "--output-dir", + type=Path, + default=None, + help=("Directory for analysis output. " "Defaults to /analysis."), + ) + parser.set_defaults(func=_run_analyze) + + +def _run_analyze(args: argparse.Namespace) -> None: + from .cache_analyzer import analyze_cache + + cache_dir: Path = args.cache_dir + output_dir: Path = args.output_dir or (cache_dir / "analysis") + analyze_cache(cache_dir, output_dir) + + +# --------------------------------------------------------------------------- +# Backward-compatible argument detection +# --------------------------------------------------------------------------- + + +def _needs_implicit_extract(argv: list[str]) -> bool: + """Return True if *argv* does not start with a known subcommand. + + When the first argument is a flag like ``--graph-dir`` rather than a + subcommand name, we prepend ``extract`` for convenience. + """ + if not argv: + return False + known_subcommands = {"extract", "analyze"} + first = argv[0] + if first in known_subcommands: + return False + if first in ("-h", "--help"): + return False + return True + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + + +def main(argv: list[str] | None = None) -> None: + logging.basicConfig( + format="%(message)s", + level=logging.INFO, + stream=sys.stderr, + ) + + if argv is None: + argv = sys.argv[1:] + + # Backward compatibility: insert "extract" when no subcommand is given. + if _needs_implicit_extract(argv): + argv = ["extract"] + argv + + parser = argparse.ArgumentParser( + prog="python3 -m tools.triton_kernel_extractor", + description=( + "Triton kernel extraction toolkit: compile, filter, extract, " + "and analyze TorchInductor compilation caches." + ), + ) + subparsers = parser.add_subparsers(dest="command") + _add_extract_parser(subparsers) + _add_analyze_parser(subparsers) + + args = parser.parse_args(argv) + + if not hasattr(args, "func"): + parser.print_help() + sys.exit(1) + + try: + args.func(args) + except KeyboardInterrupt: + logger.info("") + logger.info("Interrupted.") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/tools/triton_kernel_extractor/cache_analyzer.py b/tools/triton_kernel_extractor/cache_analyzer.py new file mode 100644 index 0000000000..ab0e51fc49 --- /dev/null +++ b/tools/triton_kernel_extractor/cache_analyzer.py @@ -0,0 +1,480 @@ +"""Analyze an inductor cache directory and produce summary statistics and plots. + +This module replaces the standalone ``analyze_inductor_cache.sh`` script. It +concatenates compiler logs from all sample states (root, kept, discarded), +computes speedup statistics, and generates distribution plots. + +The analysis can be invoked via the CLI:: + + python3 -m tools.triton_kernel_extractor analyze [--output-dir DIR] +""" + +from __future__ import annotations + +import logging +import re +import subprocess +import sys +from datetime import datetime +from pathlib import Path + +from .config import SPEEDUP_E2E_PATTERN, SPEEDUP_KERNEL_PATTERN, is_sample_dir + +logger = logging.getLogger(__name__) + +_SPEEDUP_E2E_RE = re.compile(SPEEDUP_E2E_PATTERN) +_SPEEDUP_KERNEL_RE = re.compile(SPEEDUP_KERNEL_PATTERN) + + +# --------------------------------------------------------------------------- +# Step 1: Log concatenation +# --------------------------------------------------------------------------- + + +def _concat_logs(search_dir: Path) -> tuple[str, int]: + """Concatenate ``test_compiler_log.log`` from all samples under *search_dir*. + + Returns the combined text and the number of log files found. + """ + if not search_dir.is_dir(): + return "", 0 + + parts: list[str] = [] + count = 0 + for sample_dir in sorted(search_dir.iterdir()): + if not sample_dir.is_dir(): + continue + if not is_sample_dir(sample_dir.name): + continue + log_file = sample_dir / "test_compiler_log.log" + if log_file.is_file(): + parts.append(log_file.read_text(encoding="utf-8", errors="replace")) + count += 1 + return "\n".join(parts), count + + +def concatenate_logs(cache_dir: Path, output_dir: Path) -> tuple[Path, Path, Path]: + """Concatenate logs from root, kept, and discarded sample directories. + + Writes three files to *output_dir* and returns their paths: + ``(all_log, kept_log, discarded_log)``. + """ + root_text, root_count = _concat_logs(cache_dir) + kept_text, kept_count = _concat_logs(cache_dir / "kept") + discarded_text, discarded_count = _concat_logs(cache_dir / "discarded") + + all_text = "\n".join(filter(None, [root_text, kept_text, discarded_text])) + + all_log = output_dir / "all_samples.log" + kept_log = output_dir / "kept_samples.log" + discarded_log = output_dir / "discarded_samples.log" + + all_log.write_text(all_text, encoding="utf-8") + kept_log.write_text(kept_text, encoding="utf-8") + discarded_log.write_text(discarded_text, encoding="utf-8") + + total = root_count + kept_count + discarded_count + logger.info( + " Logs concatenated: %d total (%d root, %d kept, %d discarded)", + total, + root_count, + kept_count, + discarded_count, + ) + logger.info(" All: %s", all_log) + logger.info(" Kept: %s", kept_log) + logger.info(" Discarded: %s", discarded_log) + + return all_log, kept_log, discarded_log + + +# --------------------------------------------------------------------------- +# Step 2: Summary statistics +# --------------------------------------------------------------------------- + + +def _parse_speedups(text: str, pattern: re.Pattern[str]) -> list[float]: + """Extract all speedup values matching *pattern* from log text.""" + return [float(m) for m in pattern.findall(text)] + + +def _percentile(values: list[float], p: float) -> float: + """Compute the *p*-th percentile (0–100) of a sorted list.""" + if not values: + return 0.0 + k = (len(values) - 1) * p / 100.0 + f = int(k) + c = f + 1 if f + 1 < len(values) else f + return values[f] + (k - f) * (values[c] - values[f]) + + +def _format_speedup_stats(values: list[float], label: str) -> str: + """Format a block of descriptive statistics for a speedup distribution.""" + lines: list[str] = [] + n = len(values) + lines.append(f" Samples with {label} speedup: {n}") + + if n == 0: + return "\n".join(lines) + + values_sorted = sorted(values) + mean = sum(values_sorted) / n + median = _percentile(values_sorted, 50) + + lines.append("") + lines.append(f" Mean: {mean:.4f}") + lines.append(f" Median: {median:.4f}") + lines.append(f" Min: {values_sorted[0]:.4f}") + lines.append(f" Max: {values_sorted[-1]:.4f}") + lines.append(f" P5: {_percentile(values_sorted, 5):.4f}") + lines.append(f" P25: {_percentile(values_sorted, 25):.4f}") + lines.append(f" P75: {_percentile(values_sorted, 75):.4f}") + lines.append(f" P95: {_percentile(values_sorted, 95):.4f}") + + ge2 = sum(1 for v in values_sorted if v >= 2.0) + ge1_5 = sum(1 for v in values_sorted if v >= 1.5) + ge1 = sum(1 for v in values_sorted if v >= 1.0) + lt1 = sum(1 for v in values_sorted if v < 1.0) + lt0_5 = sum(1 for v in values_sorted if v < 0.5) + + lines.append("") + lines.append(f" Speedup >= 2.0: {ge2} ({ge2/n*100:.1f}%)") + lines.append(f" Speedup >= 1.5: {ge1_5} ({ge1_5/n*100:.1f}%)") + lines.append(f" Speedup >= 1.0: {ge1} ({ge1/n*100:.1f}%)") + lines.append(f" Speedup < 1.0: {lt1} ({lt1/n*100:.1f}%) [negative optimization]") + lines.append(f" Speedup < 0.5: {lt0_5} ({lt0_5/n*100:.1f}%) [severe regression]") + + return "\n".join(lines) + + +def _count_subdirs(directory: Path) -> int: + """Count immediate subdirectories of *directory*.""" + if not directory.is_dir(): + return 0 + return sum(1 for d in directory.iterdir() if d.is_dir()) + + +def generate_summary( + cache_dir: Path, + all_log_text: str, + discarded_log_text: str, + output_dir: Path, +) -> Path: + """Generate a text summary report and return its path.""" + kernel_speedups = _parse_speedups(all_log_text, _SPEEDUP_KERNEL_RE) + e2e_speedups = _parse_speedups(all_log_text, _SPEEDUP_E2E_RE) + + # Count samples in each state. + root_samples = ( + sum(1 for d in cache_dir.iterdir() if d.is_dir() and is_sample_dir(d.name)) + if cache_dir.is_dir() + else 0 + ) + kept_samples = _count_subdirs(cache_dir / "kept") + discarded_samples = _count_subdirs(cache_dir / "discarded") + total_samples = root_samples + kept_samples + discarded_samples + + def pct(n: int) -> str: + return f"{n/total_samples*100:.1f}" if total_samples > 0 else "0.0" + + lines: list[str] = [ + "Inductor Cache Analysis Report", + "==============================", + f"Cache dir: {cache_dir}", + f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", + "", + "Sample Counts", + "-------------", + f" Total: {total_samples}", + f" Kept: {kept_samples} ({pct(kept_samples)}%)", + f" Discarded: {discarded_samples} ({pct(discarded_samples)}%)", + f" Unclassified (root): {root_samples}", + "", + "Kernel Speedup Distribution (primary — used for kept/discarded filtering)", + "-------------------------------------------------------------------------", + _format_speedup_stats(kernel_speedups, "Kernel"), + "", + "E2E Speedup Distribution (secondary — includes framework overhead)", + "-------------------------------------------------------------------", + _format_speedup_stats(e2e_speedups, "E2E"), + ] + + # Failure analysis from discarded logs. + if discarded_log_text: + neg_opt = sum(1 for v in kernel_speedups if 0 < v < 1.0) + error_lines = sum( + 1 + for line in discarded_log_text.splitlines() + if re.search(r"ERROR|Exception|Traceback", line) + ) + lines.extend( + [ + "", + "Failure/Discard Analysis", + "------------------------", + f" Negative optimization (0 < kernel speedup < 1): {neg_opt}", + f" Logs with errors/exceptions: {error_lines} lines", + ] + ) + + report = "\n".join(lines) + "\n" + summary_file = output_dir / "summary.txt" + summary_file.write_text(report, encoding="utf-8") + + # Also print to console. + logger.info("%s", report) + logger.info("Summary saved to: %s", summary_file) + + return summary_file + + +# --------------------------------------------------------------------------- +# Step 3: Plots +# --------------------------------------------------------------------------- + + +def _check_plotting_deps() -> bool: + """Return True if matplotlib and numpy are importable.""" + try: + import matplotlib # noqa: F401 + import numpy # noqa: F401 + + return True + except ImportError: + return False + + +def generate_plots( + all_log_text: str, + all_log_path: Path, + output_dir: Path, +) -> None: + """Generate speedup distribution plots. + + Produces: + - ``speedup_histogram.png`` — raw and log2 histograms of kernel speedup. + - ``speedup_cdf.png`` — cumulative distribution function. + - ``violin.png`` — via ``graph_net_visual.plot_violin`` (best effort). + - ``ESt.png`` — via ``graph_net_visual.plot_ESt`` (best effort). + """ + if not _check_plotting_deps(): + logger.warning( + " Skipping plots: matplotlib or numpy not installed. " + "Run: pip install matplotlib numpy" + ) + return + + kernel_speedups = _parse_speedups(all_log_text, _SPEEDUP_KERNEL_RE) + if not kernel_speedups: + logger.info(" No kernel speedup data for plotting.") + return + + _generate_builtin_plots(kernel_speedups, output_dir) + _run_visual_plot("graph_net_visual.plot_violin", all_log_path, output_dir) + _run_visual_plot( + "graph_net_visual.plot_ESt", + all_log_path, + output_dir, + extra_args=["--disable-aggregation-mode"], + ) + + +def _generate_builtin_plots( + kernel_speedups: list[float], + output_dir: Path, +) -> None: + """Generate histogram and CDF plots using matplotlib.""" + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + import numpy as np + + speedups = np.array(kernel_speedups) + + # --- Histogram: raw + log2 side by side --- + fig, axes = plt.subplots(1, 2, figsize=(18, 7)) + + ax1 = axes[0] + bins = np.concatenate( + [ + np.arange(0, 1.0, 0.1), + np.arange(1.0, 2.0, 0.1), + np.arange(2.0, max(5.0, float(np.percentile(speedups, 99))) + 0.5, 0.5), + ] + ) + ax1.hist(speedups, bins=bins, color="steelblue", edgecolor="white", alpha=0.85) + ax1.axvline( + x=1.0, color="red", linestyle="--", linewidth=1.5, label="speedup = 1.0" + ) + ax1.axvline( + x=float(np.median(speedups)), + color="orange", + linestyle="-", + linewidth=1.5, + label=f"median = {np.median(speedups):.3f}", + ) + ax1.set_xlabel("Kernel Speedup", fontsize=14) + ax1.set_ylabel("Count", fontsize=14) + ax1.set_title("Kernel Speedup Distribution", fontsize=16) + ax1.legend(fontsize=12) + + n_total = len(speedups) + n_pos = int(np.sum(speedups >= 1.0)) + n_neg = int(np.sum(speedups < 1.0)) + stats_text = ( + f"Total: {n_total}\n" + f"Speedup >= 1: {n_pos} ({n_pos/n_total*100:.1f}%)\n" + f"Speedup < 1: {n_neg} ({n_neg/n_total*100:.1f}%)\n" + f"Mean: {np.mean(speedups):.3f}\n" + f"Median: {np.median(speedups):.3f}" + ) + ax1.text( + 0.97, + 0.97, + stats_text, + transform=ax1.transAxes, + fontsize=10, + verticalalignment="top", + horizontalalignment="right", + bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.8), + ) + + # Log2 histogram. + positive = speedups[speedups > 0] + log2_sp = np.log2(positive) + ax2 = axes[1] + ax2.hist(log2_sp, bins=60, color="darkorange", edgecolor="white", alpha=0.85) + ax2.axvline( + x=0, color="red", linestyle="--", linewidth=1.5, label="log2(speedup) = 0" + ) + ax2.axvline( + x=float(np.median(log2_sp)), + color="blue", + linestyle="-", + linewidth=1.5, + label=f"median = {np.median(log2_sp):.3f}", + ) + ax2.set_xlabel("log2(Kernel Speedup)", fontsize=14) + ax2.set_ylabel("Count", fontsize=14) + ax2.set_title("log2(Kernel Speedup) Distribution", fontsize=16) + ax2.legend(fontsize=12) + + plt.tight_layout() + hist_path = output_dir / "speedup_histogram.png" + plt.savefig(str(hist_path), dpi=200, bbox_inches="tight") + plt.close(fig) + logger.info(" Saved: %s", hist_path) + + # --- CDF --- + fig2, ax3 = plt.subplots(figsize=(10, 6)) + sorted_sp = np.sort(speedups) + cdf = np.arange(1, len(sorted_sp) + 1) / len(sorted_sp) + ax3.plot(sorted_sp, cdf, color="steelblue", linewidth=2) + ax3.axvline( + x=1.0, color="red", linestyle="--", linewidth=1.2, label="speedup = 1.0" + ) + + cdf_at_1 = float(np.searchsorted(sorted_sp, 1.0) / len(sorted_sp)) + ax3.axhline(y=cdf_at_1, color="gray", linestyle=":", linewidth=1, alpha=0.7) + ax3.plot(1.0, cdf_at_1, "ro", markersize=8) + ax3.text(1.05, cdf_at_1, f"{cdf_at_1*100:.1f}% below 1.0", fontsize=12, color="red") + ax3.set_xlabel("Kernel Speedup", fontsize=14) + ax3.set_ylabel("Cumulative Fraction", fontsize=14) + ax3.set_title("Kernel Speedup CDF", fontsize=16) + ax3.set_xlim(0, min(5.0, float(np.percentile(speedups, 99.5)))) + ax3.legend(fontsize=12) + ax3.grid(True, alpha=0.3) + + cdf_path = output_dir / "speedup_cdf.png" + plt.savefig(str(cdf_path), dpi=200, bbox_inches="tight") + plt.close(fig2) + logger.info(" Saved: %s", cdf_path) + + +def _run_visual_plot( + module: str, + log_path: Path, + output_dir: Path, + *, + extra_args: list[str] | None = None, +) -> None: + """Run a graph_net_visual plotting module (best effort, never fatal).""" + cmd = [ + sys.executable, + "-m", + module, + "--benchmark-path", + str(log_path), + "--output-dir", + str(output_dir), + ] + if extra_args: + cmd.extend(extra_args) + + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=120, + ) + # Log any saved/error lines from stdout+stderr. + for line in (result.stdout + result.stderr).splitlines(): + if re.search(r"saved|Saved|Error|Warning", line, re.IGNORECASE): + logger.info(" %s: %s", module.rsplit(".", 1)[-1], line.strip()) + except (FileNotFoundError, subprocess.TimeoutExpired, OSError) as exc: + logger.debug(" %s unavailable: %s", module, exc) + + +# --------------------------------------------------------------------------- +# Public entry point +# --------------------------------------------------------------------------- + + +def analyze_cache(cache_dir: Path, output_dir: Path) -> None: + """Run the full analysis pipeline on an inductor cache directory.""" + if not cache_dir.is_dir(): + logger.error("Cache directory does not exist: %s", cache_dir) + return + + output_dir.mkdir(parents=True, exist_ok=True) + + logger.info("======================================================") + logger.info(" Inductor Cache Analysis") + logger.info(" Input: %s", cache_dir) + logger.info(" Output: %s", output_dir) + logger.info("======================================================") + + # Step 1: Concatenate logs. + logger.info("") + logger.info("=== Step 1: Concatenating log files ===") + all_log, _kept_log, discarded_log = concatenate_logs(cache_dir, output_dir) + + all_log_text = all_log.read_text(encoding="utf-8", errors="replace") + discarded_log_text = discarded_log.read_text(encoding="utf-8", errors="replace") + + # Step 2: Summary statistics. + logger.info("") + logger.info("=== Step 2: Summary statistics ===") + generate_summary(cache_dir, all_log_text, discarded_log_text, output_dir) + + # Step 3: Plots. + logger.info("") + logger.info("=== Step 3: Generating plots ===") + generate_plots(all_log_text, all_log, output_dir) + + # Report output files. + logger.info("") + logger.info("======================================================") + logger.info(" Analysis complete!") + logger.info(" Output directory: %s", output_dir) + logger.info("======================================================") + output_files = sorted( + f + for f in output_dir.iterdir() + if f.is_file() and f.suffix in {".txt", ".log", ".png"} + ) + for f in output_files: + size_kb = f.stat().st_size / 1024 + logger.info(" %s (%.1f KB)", f.name, size_kb) diff --git a/tools/triton_kernel_extractor/compiler.py b/tools/triton_kernel_extractor/compiler.py new file mode 100644 index 0000000000..a97f44654a --- /dev/null +++ b/tools/triton_kernel_extractor/compiler.py @@ -0,0 +1,245 @@ +"""Step 1: Multi-GPU parallel compilation of graph samples.""" + +from __future__ import annotations + +import base64 +import json +import logging +import os +import shutil +import subprocess +import sys +from concurrent.futures import Future, ProcessPoolExecutor, as_completed +from pathlib import Path + +from .config import PipelineConfig +from .sample_enumerator import compute_unique_dir + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Single-sample compilation +# --------------------------------------------------------------------------- + + +def _is_already_compiled( + log_file: Path, + cache_dir: Path, + unique_dir: str, +) -> bool: + """Check whether a sample has already been compiled in a prior run. + + Mirrors the bash resume logic:: + + { [ -f "$log_file" ] && grep -q '[Speedup][kernel]:' "$log_file"; } \ + || [ -d "$cache_dir/kept/$unique_dir" ] \ + || [ -d "$cache_dir/discarded/$unique_dir" ] + """ + if log_file.is_file(): + try: + content = log_file.read_text(encoding="utf-8", errors="replace") + if "[Speedup][kernel]:" in content: + return True + except OSError: + pass + + if (cache_dir / "kept" / unique_dir).is_dir(): + return True + if (cache_dir / "discarded" / unique_dir).is_dir(): + return True + + return False + + +def _compile_one_sample( + sample_path: str, + graph_dir: str, + cache_dir: Path, + gpu_id: int, + progress_label: str, + compiler_config: str | None = None, +) -> str: + """Compile a single graph sample on a specific GPU. + + Returns one of ``"compiled"``, ``"skipped"``, or ``"failed"``. + """ + unique_dir = compute_unique_dir(sample_path, graph_dir) + + sample_cache_dir = cache_dir / unique_dir + log_file = sample_cache_dir / "test_compiler_log.log" + + if _is_already_compiled(log_file, cache_dir, unique_dir): + logger.info("%s SKIP: %s", progress_label, sample_path) + return "skipped" + + # Remove incomplete cache from a prior interrupted attempt. + if sample_cache_dir.exists(): + shutil.rmtree(sample_cache_dir) + sample_cache_dir.mkdir(parents=True) + + logger.info("%s Compiling: %s", progress_label, sample_path) + + # Build a clean environment for the compiler subprocess. + env = os.environ.copy() + env["CUDA_VISIBLE_DEVICES"] = str(gpu_id) + env["TORCH_COMPILE_DEBUG"] = "1" + env["TORCHINDUCTOR_CACHE_DIR"] = str(sample_cache_dir) + # Ensure graph_net_bench is importable by the subprocess (tools/ -> repo root). + repo_root = str(Path(__file__).resolve().parents[2]) + env["PYTHONPATH"] = f"{repo_root}:{env.get('PYTHONPATH', '')}" + + result = subprocess.run( + [ + sys.executable, + "-m", + "graph_net_bench.torch.test_compiler", + "--model-path", + sample_path, + "--kernel-time", + "--warmup", + "5", + "--trials", + "100", + *(["--config", compiler_config] if compiler_config else []), + ], + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + ) + + # Write combined stdout+stderr to the log file (matches bash "> log 2>&1"). + log_file.write_text(result.stdout or "", encoding="utf-8") + + # Copy the original graph source into the cache. + og_subdir = sample_cache_dir / "original_graph" + og_subdir.mkdir(exist_ok=True) + model_src = Path(sample_path) + if model_src.is_dir(): + for item in model_src.iterdir(): + dest = og_subdir / item.name + if item.is_dir(): + shutil.copytree(str(item), str(dest), dirs_exist_ok=True) + else: + shutil.copy2(str(item), str(dest)) + + if result.returncode != 0: + return "failed" + return "compiled" + + +# --------------------------------------------------------------------------- +# Per-GPU sequential chunk worker +# --------------------------------------------------------------------------- + + +def _compile_chunk( + samples: list[str], + graph_dir: str, + cache_dir: Path, + gpu_id: int, + compiler_config: str | None = None, +) -> dict[str, int]: + """Process a chunk of samples sequentially on one GPU. + + This function is the top-level callable submitted to the process pool. + Each invocation runs in its own process, isolating CUDA state. + """ + total = len(samples) + stats = {"compiled": 0, "skipped": 0, "failed": 0} + + for idx, sample_path in enumerate(samples, 1): + label = f"[GPU{gpu_id} {idx}/{total}]" + status = _compile_one_sample( + sample_path=sample_path, + graph_dir=graph_dir, + cache_dir=cache_dir, + gpu_id=gpu_id, + progress_label=label, + compiler_config=compiler_config, + ) + stats[status] += 1 + + logger.info( + "[GPU%d] Done: %d compiled, %d skipped, %d failed (total: %d)", + gpu_id, + stats["compiled"], + stats["skipped"], + stats["failed"], + total, + ) + return stats + + +# --------------------------------------------------------------------------- +# Multi-GPU orchestrator +# --------------------------------------------------------------------------- + + +def compile_all_samples( + samples: list[str], + config: PipelineConfig, +) -> dict[str, int]: + """Split samples across GPUs round-robin and compile in parallel. + + Each GPU gets its own worker process that processes its chunk + sequentially, matching the original bash behaviour of one + ``compile_worker`` per GPU. + + Returns aggregated ``{"compiled": N, "skipped": N, "failed": N}``. + """ + gpu_ids = config.gpu_ids + num_gpus = len(gpu_ids) + + # Build base64-encoded config for test_compiler --config, if needed. + compiler_config: str | None = None + if config.max_autotune_no_cudagraphs: + config_dict = {"mode": "max-autotune-no-cudagraphs"} + compiler_config = base64.b64encode(json.dumps(config_dict).encode()).decode() + + # Round-robin assignment (mirrors bash: gpu_id = GPU_IDS[local_idx % NUM_GPUS]). + chunks: dict[int, list[str]] = {gid: [] for gid in gpu_ids} + for idx, sample in enumerate(samples): + gid = gpu_ids[idx % num_gpus] + chunks[gid].append(sample) + + aggregated: dict[str, int] = {"compiled": 0, "skipped": 0, "failed": 0} + + # Use one process per GPU. max_workers == num_gpus ensures no GPU + # contention. + with ProcessPoolExecutor(max_workers=num_gpus) as executor: + future_to_gpu: dict[Future[dict[str, int]], int] = {} + + for gid in gpu_ids: + chunk = chunks[gid] + if not chunk: + continue + future = executor.submit( + _compile_chunk, + samples=chunk, + graph_dir=str(config.graph_dir), + cache_dir=config.output_dir, + gpu_id=gid, + compiler_config=compiler_config, + ) + future_to_gpu[future] = gid + logger.info(" Launched worker GPU %d (%d samples)", gid, len(chunk)) + + logger.info(" Waiting for %d workers...", len(future_to_gpu)) + + has_errors = False + for future in as_completed(future_to_gpu): + gid = future_to_gpu[future] + try: + stats = future.result() + for key in aggregated: + aggregated[key] += stats[key] + except Exception: + has_errors = True + logger.exception("Worker GPU %d raised an exception", gid) + + if has_errors: + logger.warning("WARNING: Some workers had errors. Check logs for details.") + + return aggregated diff --git a/tools/triton_kernel_extractor/config.py b/tools/triton_kernel_extractor/config.py new file mode 100644 index 0000000000..76a5057e5d --- /dev/null +++ b/tools/triton_kernel_extractor/config.py @@ -0,0 +1,46 @@ +"""Pipeline configuration types and constants.""" + +from __future__ import annotations + +import dataclasses +from pathlib import Path + +# Log patterns emitted by graph_net_bench.torch.test_compiler with --kernel-time. +SPEEDUP_KERNEL_PATTERN = r"\[Speedup\]\[kernel\]:\s*([\d.]+)" +SPEEDUP_E2E_PATTERN = r"\[Speedup\]\[e2e\]:\s*([\d.]+)" + +# Subdirectory names reserved for internal bookkeeping inside the cache directory. +# These are skipped when iterating over sample directories. +RESERVED_DIR_NAMES = frozenset({"kept", "discarded"}) + +# Prefix used by temporary pipeline artifacts (chunk files, worker logs, sample +# lists). Directories whose name starts with this prefix are skipped during +# sample iteration. +RESERVED_DIR_PREFIX = "_" + +# Minimum kernel speedup required to keep a compiled sample. +SPEEDUP_THRESHOLD = 1.0 + + +def is_sample_dir(name: str) -> bool: + """Return True if *name* is a real sample directory, not a reserved one. + + Filters out ``kept``, ``discarded``, and directories starting with ``_`` + (temporary pipeline artifacts such as chunk files and worker logs). + """ + if name in RESERVED_DIR_NAMES: + return False + if name.startswith(RESERVED_DIR_PREFIX): + return False + return True + + +@dataclasses.dataclass(frozen=True) +class PipelineConfig: + """Immutable top-level configuration for a single pipeline run.""" + + gpu_ids: list[int] + graph_dir: Path + output_dir: Path + allow_list: Path | None = None + max_autotune_no_cudagraphs: bool = False diff --git a/tools/triton_kernel_extractor/empty_sample_cleaner.py b/tools/triton_kernel_extractor/empty_sample_cleaner.py new file mode 100644 index 0000000000..f16745f1c9 --- /dev/null +++ b/tools/triton_kernel_extractor/empty_sample_cleaner.py @@ -0,0 +1,45 @@ +"""Step 5: Remove output samples that have no extracted triton kernels.""" + +from __future__ import annotations + +import logging +import shutil +from pathlib import Path + +logger = logging.getLogger(__name__) + + +def clean_empty_kernel_samples(output_dir: Path) -> tuple[int, int]: + """Delete samples that contain ``original_graph/`` but no ``triton_kernel/``. + + Returns: + A tuple of ``(removed_count, kept_count)``. + """ + if not output_dir.is_dir(): + logger.warning("Output directory does not exist: %s", output_dir) + return 0, 0 + + total = 0 + removed = 0 + + for sample_dir in sorted(output_dir.iterdir()): + if not sample_dir.is_dir(): + continue + total += 1 + + has_graph = (sample_dir / "original_graph").is_dir() + has_kernel = (sample_dir / "triton_kernel").is_dir() + + if has_graph and not has_kernel: + logger.info(" Removing (no triton_kernel): %s", sample_dir.name) + shutil.rmtree(sample_dir) + removed += 1 + + kept = total - removed + logger.info( + "Cleanup: %d removed (no triton_kernel), %d kept (total: %d)", + removed, + kept, + total, + ) + return removed, kept diff --git a/tools/triton_kernel_extractor/kernel_extractor.py b/tools/triton_kernel_extractor/kernel_extractor.py new file mode 100644 index 0000000000..5e7a34691f --- /dev/null +++ b/tools/triton_kernel_extractor/kernel_extractor.py @@ -0,0 +1,254 @@ +"""Step 4: Extract autotuning-selected triton kernels and corresponding PTX.""" + +from __future__ import annotations + +import json +import logging +import re +import shutil +from pathlib import Path + +logger = logging.getLogger(__name__) + +# Compiled regex that replaces the original perl one-liner: +# +# perl -0777 -ne ' +# while (/async_compile\.triton\(\x27([^\x27]+)\x27,\s*\x27\x27\x27(.*?)\x27\x27\x27/gs) { +# print "===KERNEL_NAME===$1\n$2\n===KERNEL_END===\n"; +# }' +# +# Captures: group(1) = kernel name, group(2) = kernel source code. +_TRITON_KERNEL_PATTERN = re.compile( + r"async_compile\.triton\('([^']+)',\s*'''(.*?)'''", + re.DOTALL, +) + + +def _collect_best_config_hashes(sample_cache_dir: Path) -> set[str]: + """Gather all autotuning-selected cache hashes from a sample directory. + + TorchInductor writes ``.best_config`` JSON files (one per autotuned kernel) + in 2-char prefix subdirectories of the sample cache. Each file contains a + ``triton_cache_hash`` field identifying the winning configuration among + multiple compiled candidates in ``triton/0/``. + + This function is called once per sample and the result is reused for every + kernel in that sample. + """ + hashes: set[str] = set() + for bc_path in sample_cache_dir.rglob("*.best_config"): + try: + data = json.loads(bc_path.read_text(encoding="utf-8")) + cache_hash = data.get("triton_cache_hash") + if cache_hash: + hashes.add(cache_hash) + except (OSError, json.JSONDecodeError): + logger.debug("Skipping malformed .best_config: %s", bc_path) + return hashes + + +def _find_best_ptx( + sample_cache_dir: Path, + kernel_name: str, + best_hashes: set[str], +) -> str | None: + """Locate the corresponding PTX for a given kernel via autotuning results. + + The inductor cache compiles each triton kernel into one or more candidate + configurations under ``triton/0/{HASH}/``. When autotuning runs, multiple + candidate directories exist for the same kernel and a ``.best_config`` file + records the winning ``triton_cache_hash``. + + Resolution strategy: + - 0 candidates → return ``None`` (no PTX compiled for this kernel). + - 1 candidate → return its PTX (no disambiguation needed). + - N candidates → intersect directory names with *best_hashes*; the match + identifies the autotuning winner. + """ + triton_base = sample_cache_dir / "triton" / "0" + if not triton_base.is_dir(): + return None + + # Collect all triton/0/{hash}/ dirs that contain this kernel's PTX. + ptx_filename = f"{kernel_name}.ptx" + candidates: list[Path] = [ + ptx_file + for hash_dir in triton_base.iterdir() + if hash_dir.is_dir() + for ptx_file in [hash_dir / ptx_filename] + if ptx_file.is_file() + ] + + if not candidates: + logger.debug( + "No PTX found for kernel %s in %s", kernel_name, sample_cache_dir.name + ) + return None + + if len(candidates) == 1: + try: + return candidates[0].read_text(encoding="utf-8", errors="replace") + except OSError: + logger.warning("Cannot read PTX file: %s", candidates[0]) + return None + + # Multiple candidates: pick the one whose parent dir matches a best_config hash. + for ptx_path in candidates: + if ptx_path.parent.name in best_hashes: + try: + return ptx_path.read_text(encoding="utf-8", errors="replace") + except OSError: + logger.warning("Cannot read PTX file: %s", ptx_path) + return None + + # Fallback: no .best_config match (should not happen based on validation). + logger.warning( + "Multiple PTX candidates for %s but no .best_config match in %s", + kernel_name, + sample_cache_dir.name, + ) + return None + + +def extract_kernels_from_file( + output_code_path: Path, +) -> list[tuple[str, str]]: + """Parse an ``output_code.py`` and return ``(name, source)`` pairs. + + The file is read entirely into memory (``output_code.py`` files produced by + TorchInductor are typically well under 1 MB). Returns an empty list if the + file cannot be read. + """ + try: + content = output_code_path.read_text(encoding="utf-8", errors="replace") + except OSError: + logger.warning("Cannot read output_code.py: %s", output_code_path) + return [] + return _TRITON_KERNEL_PATTERN.findall(content) + + +def extract_triton_kernels( + cache_dir: Path, + output_dir: Path, +) -> tuple[int, int, int, int, int]: + """Walk kept samples, extract autotuning-selected triton kernels and corresponding PTX. + + For every kept sample that contains ``original_graph/graph_hash.txt``: + + 1. Copy ``original_graph/model.py`` into the output. + 2. Parse every ``output_code.py`` found in the sample tree. + 3. Write each extracted kernel to ``triton_kernel/{name}.py``. + 4. Locate the corresponding PTX for each kernel and write it to + ``ptx/{name}.ptx``. + + The output uses an atomic ``.tmp`` + ``rename`` pattern so that an + interrupted run never leaves a half-written sample directory. + + Returns: + ``(processed_files, total_kernels, total_ptx, extracted_samples, skip_count)`` + """ + kept_dir = cache_dir / "kept" + if not kept_dir.is_dir(): + logger.error("Kept directory does not exist: %s", kept_dir) + return 0, 0, 0, 0, 0 + + output_dir.mkdir(parents=True, exist_ok=True) + + # Clean up stale .tmp directories from a previous interrupted run. + for stale in output_dir.iterdir(): + if stale.is_dir() and stale.name.endswith(".tmp"): + shutil.rmtree(stale, ignore_errors=True) + + # Collect eligible samples (must contain original_graph/graph_hash.txt). + eligible: list[Path] = [ + d + for d in sorted(kept_dir.iterdir()) + if d.is_dir() and (d / "original_graph" / "graph_hash.txt").is_file() + ] + total = len(eligible) + + processed_files = 0 + total_kernels = 0 + total_ptx = 0 + extracted_samples = 0 + skip_count = 0 + + for idx, sample_cache_dir in enumerate(eligible, 1): + sample_name = sample_cache_dir.name + dest_sample_dir = output_dir / sample_name + + # Resume: skip if the final output already exists. + if dest_sample_dir.exists(): + skip_count += 1 + continue + + logger.info("[%d/%d] Extracting: %s", idx, total, sample_name) + + # Write to a temporary directory; rename atomically on success. + tmp_dir = dest_sample_dir.with_name(f"{sample_name}.tmp") + if tmp_dir.exists(): + shutil.rmtree(tmp_dir) + tmp_dir.mkdir(parents=True) + + # Copy original model source when available. + model_src = sample_cache_dir / "original_graph" / "model.py" + if model_src.is_file(): + og_dir = tmp_dir / "original_graph" + og_dir.mkdir() + shutil.copy2(str(model_src), str(og_dir / "model.py")) + + # Pre-collect autotuning best-config hashes once per sample. + best_hashes = _collect_best_config_hashes(sample_cache_dir) + + # Track kernel names already written for this sample to detect + # duplicates across multiple output_code.py files. + seen_kernels: set[str] = set() + + # Find and process all output_code.py files within the sample. + for output_code_path in sorted(sample_cache_dir.rglob("output_code.py")): + processed_files += 1 + kernels = extract_kernels_from_file(output_code_path) + if not kernels: + continue + + triton_dir = tmp_dir / "triton_kernel" + triton_dir.mkdir(exist_ok=True) + + for name, source in kernels: + if name in seen_kernels: + logger.debug( + "Duplicate kernel %s in %s, skipping", name, sample_name + ) + continue + seen_kernels.add(name) + # Strip trailing whitespace then add exactly one newline, + # matching the bash `printf '%s\n'` semantics. + (triton_dir / f"{name}.py").write_text( + source.rstrip() + "\n", encoding="utf-8" + ) + total_kernels += 1 + + # Locate and write the corresponding PTX for this kernel. + ptx_content = _find_best_ptx(sample_cache_dir, name, best_hashes) + if ptx_content is not None: + ptx_dir = tmp_dir / "ptx" + ptx_dir.mkdir(exist_ok=True) + (ptx_dir / f"{name}.ptx").write_text(ptx_content, encoding="utf-8") + total_ptx += 1 + + # Atomic completion: rename .tmp → final (same filesystem guarantees + # a single rename(2) syscall). + tmp_dir.rename(dest_sample_dir) + extracted_samples += 1 + + logger.info( + "Extraction: %d files, %d kernels, %d ptx, %d samples, %d skipped (total: %d)", + processed_files, + total_kernels, + total_ptx, + extracted_samples, + skip_count, + total, + ) + logger.info("Output: %s", output_dir) + return processed_files, total_kernels, total_ptx, extracted_samples, skip_count diff --git a/tools/triton_kernel_extractor/pipeline.py b/tools/triton_kernel_extractor/pipeline.py new file mode 100644 index 0000000000..07711e0253 --- /dev/null +++ b/tools/triton_kernel_extractor/pipeline.py @@ -0,0 +1,91 @@ +"""Orchestrate the five-step extraction pipeline for a single run.""" + +from __future__ import annotations + +import logging + +from .compiler import compile_all_samples +from .config import PipelineConfig +from .empty_sample_cleaner import clean_empty_kernel_samples +from .kernel_extractor import extract_triton_kernels +from .sample_enumerator import enumerate_graph_dir, enumerate_list_samples +from .speedup_filter import filter_samples_by_speedup +from .temp_cleaner import clean_temp_files + +logger = logging.getLogger(__name__) + + +def _load_samples(config: PipelineConfig) -> list[str]: + """Load sample paths from allow-list or by scanning graph_dir.""" + if config.allow_list is not None: + if not config.allow_list.is_file(): + raise FileNotFoundError(f"Allow-list file not found: {config.allow_list}") + return enumerate_list_samples(config.allow_list, config.graph_dir) + return enumerate_graph_dir(config.graph_dir) + + +def run_pipeline( + config: PipelineConfig, + *, + enable_cache_analysis: bool = False, +) -> None: + """Run the full five-step pipeline.""" + logger.info("") + logger.info("======================================================") + logger.info(" Graph dir: %s", config.graph_dir) + logger.info(" Output dir: %s", config.output_dir) + logger.info(" GPUs: %s", " ".join(str(g) for g in config.gpu_ids)) + if config.allow_list: + logger.info(" Allow list: %s", config.allow_list) + if config.max_autotune_no_cudagraphs: + logger.info(" Autotune: mode='max-autotune-no-cudagraphs'") + logger.info("======================================================") + + samples = _load_samples(config) + logger.info(" Samples: %d", len(samples)) + + if not samples: + logger.error("No samples found.") + return + + config.output_dir.mkdir(parents=True, exist_ok=True) + + # Step 1: Parallel compilation. + num_gpus = len(config.gpu_ids) + logger.info("") + logger.info("=== Step 1: Parallel compilation (%d GPUs) ===", num_gpus) + compile_all_samples(samples, config) + + # Step 2: Filter by speedup. + logger.info("") + logger.info("=== Step 2: Filter by speedup ===") + filter_samples_by_speedup(config.output_dir) + + # Step 3: Clean temp files. + logger.info("") + logger.info("=== Step 3: Clean temp files ===") + clean_temp_files(config.output_dir) + + # Step 4: Extract triton kernels. + extraction_dir = config.output_dir / "extracted" + logger.info("") + logger.info( + "=== Step 4: Extract autotuning-selected triton kernels and corresponding PTX ===" + ) + extract_triton_kernels(config.output_dir, extraction_dir) + + # Step 5: Clean samples without triton kernels. + logger.info("") + logger.info("=== Step 5: Clean samples without triton kernels ===") + clean_empty_kernel_samples(extraction_dir) + + if enable_cache_analysis: + from .cache_analyzer import analyze_cache + + logger.info("") + logger.info("=== Cache analysis ===") + analysis_dir = config.output_dir / "analysis" + analyze_cache(config.output_dir, analysis_dir) + + logger.info("") + logger.info("Done.") diff --git a/tools/triton_kernel_extractor/sample_enumerator.py b/tools/triton_kernel_extractor/sample_enumerator.py new file mode 100644 index 0000000000..90e6e35faa --- /dev/null +++ b/tools/triton_kernel_extractor/sample_enumerator.py @@ -0,0 +1,56 @@ +"""Enumerate graph samples for the extraction pipeline.""" + +from __future__ import annotations + +import logging +from pathlib import Path + +logger = logging.getLogger(__name__) + + +def enumerate_list_samples(allow_list: Path, graph_dir: Path) -> list[str]: + """Read relative sample paths from *allow_list* and resolve against *graph_dir*. + + Returns absolute paths to sample directories. Blank lines are skipped. + + Raises: + FileNotFoundError: If *allow_list* does not exist. + """ + paths: list[str] = [] + with open(allow_list, encoding="utf-8") as fh: + for raw in fh: + stripped = raw.strip() + if stripped: + paths.append(str(graph_dir / stripped)) + return paths + + +def enumerate_graph_dir(graph_dir: Path) -> list[str]: + """Discover samples by recursively scanning *graph_dir* for ``model.py``. + + Returns the sorted list of parent directories that contain a ``model.py``. + + Raises: + FileNotFoundError: If *graph_dir* does not exist. + """ + if not graph_dir.is_dir(): + raise FileNotFoundError(f"Graph directory not found: {graph_dir}") + parents = sorted( + {str(p.parent) for p in graph_dir.rglob("model.py") if p.is_file()} + ) + return parents + + +def compute_unique_dir(sample_path: str, graph_dir: str) -> str: + """Derive a flat directory name from a sample path. + + Uses the relative portion below *graph_dir*, with ``/`` replaced by ``_``. + If *sample_path* is not under *graph_dir*, the full path is flattened + with leading slashes stripped to avoid reserved-prefix collisions. + """ + prefix = graph_dir.rstrip("/") + "/" + if sample_path.startswith(prefix): + rel = sample_path[len(prefix) :] + else: + rel = sample_path.lstrip("/") + return rel.replace("/", "_") diff --git a/tools/triton_kernel_extractor/speedup_filter.py b/tools/triton_kernel_extractor/speedup_filter.py new file mode 100644 index 0000000000..ecec553aa3 --- /dev/null +++ b/tools/triton_kernel_extractor/speedup_filter.py @@ -0,0 +1,101 @@ +"""Step 2: Partition compiled samples into *kept* and *discarded* by speedup.""" + +from __future__ import annotations + +import logging +import re +import shutil +from pathlib import Path + +from .config import ( + SPEEDUP_KERNEL_PATTERN, + SPEEDUP_THRESHOLD, + is_sample_dir, +) + +logger = logging.getLogger(__name__) + +_SPEEDUP_RE = re.compile(SPEEDUP_KERNEL_PATTERN) + + +def _parse_kernel_speedup(log_file: Path) -> float | None: + """Extract the last ``[Speedup][kernel]:`` value from a compiler log. + + Returns ``None`` when the log does not exist or contains no speedup line. + """ + if not log_file.is_file(): + return None + try: + content = log_file.read_text(encoding="utf-8", errors="replace") + except OSError: + return None + + last_match: re.Match[str] | None = None + for m in _SPEEDUP_RE.finditer(content): + last_match = m + + if last_match is None: + return None + + try: + return float(last_match.group(1)) + except ValueError: + return None + + +def filter_samples_by_speedup(cache_dir: Path) -> tuple[int, int, int]: + """Move compiled samples into ``kept/`` or ``discarded/`` sub-directories. + + A sample is *kept* when the last ``[Speedup][kernel]:`` value in its + ``test_compiler_log.log`` is >= ``SPEEDUP_THRESHOLD``. Samples that + have already been classified are silently skipped. + + Returns: + A tuple of ``(kept_count, discarded_count, skip_count)``. + """ + kept_dir = cache_dir / "kept" + discarded_dir = cache_dir / "discarded" + kept_dir.mkdir(parents=True, exist_ok=True) + discarded_dir.mkdir(parents=True, exist_ok=True) + + # Collect candidate directories (snapshot the listing to avoid mutating + # the iterator while moving entries). + candidates: list[Path] = [ + d for d in sorted(cache_dir.iterdir()) if d.is_dir() and is_sample_dir(d.name) + ] + total = len(candidates) + + kept_count = 0 + discarded_count = 0 + skip_count = 0 + + for idx, sample_cache_dir in enumerate(candidates, 1): + sample_name = sample_cache_dir.name + + # Skip if already classified in a previous (possibly interrupted) run. + if (kept_dir / sample_name).exists() or (discarded_dir / sample_name).exists(): + skip_count += 1 + continue + + log_file = sample_cache_dir / "test_compiler_log.log" + speedup = _parse_kernel_speedup(log_file) + should_keep = speedup is not None and speedup >= SPEEDUP_THRESHOLD + + if should_keep: + shutil.move(str(sample_cache_dir), str(kept_dir / sample_name)) + kept_count += 1 + else: + shutil.move(str(sample_cache_dir), str(discarded_dir / sample_name)) + discarded_count += 1 + + label = "KEPT" if should_keep else "DISCARDED" + logger.info("[%d/%d] %s: %s", idx, total, label, sample_name) + + logger.info( + "Filter: %d kept, %d discarded, %d skipped (total: %d)", + kept_count, + discarded_count, + skip_count, + total, + ) + return kept_count, discarded_count, skip_count diff --git a/tools/triton_kernel_extractor/temp_cleaner.py b/tools/triton_kernel_extractor/temp_cleaner.py new file mode 100644 index 0000000000..6144300b49 --- /dev/null +++ b/tools/triton_kernel_extractor/temp_cleaner.py @@ -0,0 +1,41 @@ +"""Step 3: Remove Python bytecode caches from a directory tree.""" + +from __future__ import annotations + +import logging +import shutil +from pathlib import Path + +logger = logging.getLogger(__name__) + + +def clean_temp_files(directory: Path) -> None: + """Recursively delete ``__pycache__`` dirs, ``*.pyc``, and ``*.pyo`` files. + + Mirrors the original bash implementation:: + + find "$dir" -type d -name "__pycache__" -exec rm -rf {} + + find "$dir" -type f -name "*.pyc" -delete + find "$dir" -type f -name "*.pyo" -delete + + Silently skips entries that vanish between discovery and deletion (e.g. + a ``.pyc`` inside a ``__pycache__`` that was already removed). + """ + if not directory.is_dir(): + logger.warning("Directory does not exist, skipping: %s", directory) + return + + logger.info("Cleaning temp files from %s ...", directory) + + # Collect first, then delete — avoids mutating the tree during iteration. + pycache_dirs = sorted(directory.rglob("__pycache__"), reverse=True) + for d in pycache_dirs: + if d.is_dir(): + shutil.rmtree(d, ignore_errors=True) + + for pattern in ("*.pyc", "*.pyo"): + for f in directory.rglob(pattern): + try: + f.unlink() + except FileNotFoundError: + pass