From 090b9ff65eae926583276654f9a0f20cfce96aee Mon Sep 17 00:00:00 2001 From: Houman Safaai Date: Mon, 8 Dec 2025 09:58:26 -0500 Subject: [PATCH 01/12] add layer-based red evolution --- README.md | 157 +++-- configs/README.md | 96 ++- configs/cluster_analysis/README.md | 115 ++++ .../mobilenetv2_cifar10_full.yaml | 50 ++ .../resnet18_cifar10_full.yaml | 110 ++++ .../resnet50_imagenet100.yaml | 113 ++++ .../cluster_analysis/vgg16_cifar10_full.yaml | 57 ++ configs/paper/README.md | 145 ++--- docs/METRIC_CONSISTENCY.md | 24 +- docs/README.md | 37 +- docs/api_reference.md | 273 ++++++--- scripts/README.md | 21 +- scripts/run_experiment.py | 211 +++++++ slurm-48692030.out | 4 - slurm_jobs/run_cluster_analysis_resnet18.sh | 69 +++ slurm_jobs/run_cluster_analysis_resnet50.sh | 76 +++ src/alignment/analysis/README.md | 63 +- src/alignment/analysis/__init__.py | 16 + src/alignment/analysis/cascade_analysis.py | 154 +++++ src/alignment/analysis/clustering/__init__.py | 16 + .../analysis/clustering/cross_layer_halo.py | 156 +++++ .../analysis/clustering/metric_clustering.py | 64 ++ .../analysis/visualization/__init__.py | 15 + .../analysis/visualization/cluster_plots.py | 553 ++++++++++++++++++ src/alignment/configs/config_loader.py | 2 +- src/alignment/experiments/__init__.py | 10 + src/alignment/experiments/base.py | 2 +- src/alignment/metrics/information/__init__.py | 1 + .../metrics/information/synergy_continuous.py | 267 +++++++++ src/alignment/pruning/strategies/__init__.py | 5 + .../pruning/strategies/cluster_aware.py | 552 +++++++++++++++++ src/alignment/training/README.md | 63 +- src/alignment/training/__init__.py | 13 + src/alignment/{ => training}/evaluation.py | 0 34 files changed, 3192 insertions(+), 318 deletions(-) create mode 100644 configs/cluster_analysis/README.md create mode 100644 configs/cluster_analysis/mobilenetv2_cifar10_full.yaml create mode 100644 configs/cluster_analysis/resnet18_cifar10_full.yaml create mode 100644 
configs/cluster_analysis/resnet50_imagenet100.yaml create mode 100644 configs/cluster_analysis/vgg16_cifar10_full.yaml delete mode 100644 slurm-48692030.out create mode 100644 slurm_jobs/run_cluster_analysis_resnet18.sh create mode 100644 slurm_jobs/run_cluster_analysis_resnet50.sh create mode 100644 src/alignment/analysis/cascade_analysis.py create mode 100644 src/alignment/analysis/clustering/__init__.py create mode 100644 src/alignment/analysis/clustering/cross_layer_halo.py create mode 100644 src/alignment/analysis/clustering/metric_clustering.py create mode 100644 src/alignment/analysis/visualization/cluster_plots.py create mode 100644 src/alignment/metrics/information/synergy_continuous.py create mode 100644 src/alignment/pruning/strategies/cluster_aware.py rename src/alignment/{ => training}/evaluation.py (100%) diff --git a/README.md b/README.md index 874d5b30..fef64657 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,17 @@ # Alignment Framework -Neural network alignment analysis and pruning framework. +Neural network analysis and structured pruning using alignment metrics and information theory. ## Overview -Tools for analyzing and pruning neural networks using alignment metrics, information theory, and structured pruning strategies. +This framework provides tools for analyzing and pruning neural networks through: -**Supported architectures**: MLPs, CNNs (ResNet, VGG), Transformers, LLMs (LLaMA, Mistral) +- **Alignment metrics**: Rayleigh quotient, activation-based importance +- **Information-theoretic analysis**: Mutual information, redundancy, synergy +- **Cluster-based analysis**: Functional type identification, cross-layer halo tracking +- **Structured pruning**: Channel/neuron removal with multiple scoring strategies + +**Supported architectures**: MLPs, CNNs (ResNet, VGG, MobileNet), Transformers, LLMs (LLaMA, Mistral, Qwen) ## Installation @@ -20,8 +25,6 @@ pip install -e . 
## Quick Start -### Run Experiments - ```bash # Vision model analysis python scripts/run_experiment.py --config configs/examples/mnist_basic.yaml @@ -29,48 +32,20 @@ python scripts/run_experiment.py --config configs/examples/mnist_basic.yaml # CNN pruning python scripts/run_experiment.py --config configs/examples/resnet_pruning.yaml -# LLM importance scoring -python scripts/run_experiment.py --config configs/examples/llm_alignment.yaml -``` - -### Programmatic Usage - -```python -from alignment import ModelWrapper, get_metric - -wrapper = ModelWrapper(model) -rq = get_metric('rayleigh_quotient') +# LLM analysis +python scripts/run_experiment.py --config configs/paper/llama3_8b_full.yaml -outputs, activations = wrapper.forward_with_activations(inputs) -weights = wrapper.get_layer_weights() -scores = rq.compute(activations['layer_input'], weights['layer']) +# Cluster-based analysis +python scripts/run_experiment.py --config configs/cluster_analysis/resnet18_cifar10_full.yaml ``` -## Configuration - -Experiments use YAML configuration files: - -```yaml -model: - name: "resnet18" - pretrained: true - -dataset: - name: "cifar10" - batch_size: 128 +## Experiment Types -alignment_methods: - - "rayleigh_quotient" - - "pairwise_redundancy_gaussian" - -pruning: - enabled: true - algorithms: ["alignment"] - sparsity_levels: [0.3, 0.5, 0.7] - structured: true -``` - -See `configs/template.yaml` for all parameters. +| Type | Description | Config Example | +|------|-------------|----------------| +| `alignment_analysis` | General alignment metrics | `mnist_basic.yaml` | +| `llm_alignment` | LLM supernode/SCAR analysis | `llama3_8b_full.yaml` | +| `cluster_analysis` | Metric-space clustering with halos | `resnet18_cifar10_full.yaml` | ## Metrics @@ -80,6 +55,20 @@ See `configs/template.yaml` for all parameters. 
| Alignment | `rayleigh_quotient`, `delta_alignment` | | Information | `mutual_information_gaussian`, `pairwise_redundancy_gaussian`, `gaussian_pid_synergy_mmi` | | SCAR (LLM) | `scar_activation_power`, `scar_taylor`, `scar_curvature`, `scar_loss_proxy` | +| Synergy | `synergy_continuous_target` (with logit margin) | + +## Cluster-Based Analysis + +The cluster analysis framework groups channels/neurons into functional types: + +| Type | Characteristics | Pruning Implication | +|------|-----------------|---------------------| +| Critical | High RQ, Low Redundancy, High Synergy | Protect | +| Redundant | Moderate RQ, High Redundancy | Target for pruning | +| Synergistic | Moderate RQ, High Synergy | Preserve pairs | +| Background | Low on all metrics | Safe to remove | + +Cross-layer halo analysis tracks downstream dependencies to predict cascade effects. ## Pruning Strategies @@ -87,35 +76,83 @@ See `configs/template.yaml` for all parameters. |----------|-------------| | `magnitude` | Prune by weight magnitude | | `alignment` | Prune by alignment score | -| `hybrid` | Combine magnitude and alignment | +| `composite` | Combine multiple metrics | +| `cluster_aware` | Use cluster membership and halo analysis | | `random` | Random baseline | -| `global` | Cross-layer pruning | ## Project Structure ``` alignment/ -├── configs/ # YAML configuration files -│ ├── examples/ # Example experiments -│ └── template.yaml # Parameter reference -├── scripts/ # Entry points -│ ├── run_experiment.py -│ └── run_analysis.py -├── src/alignment/ # Main package -│ ├── analysis/ # Visualization -│ ├── experiments/ # Experiment classes -│ ├── metrics/ # Alignment metrics -│ ├── models/ # Model wrappers -│ └── pruning/ # Pruning strategies -├── tests/ # Unit tests -└── docs/ # Documentation +├── configs/ +│ ├── cluster_analysis/ # Cluster-based analysis configs +│ ├── paper/ # Paper experiment configs +│ └── examples/ # Example configs +├── scripts/ +│ ├── run_experiment.py # Main entry 
point +│ └── run_analysis.py # Post-hoc analysis +├── src/alignment/ +│ ├── analysis/ # Visualization, clustering, cascade analysis +│ ├── experiments/ # Experiment classes +│ ├── metrics/ # Importance metrics +│ ├── models/ # Model wrappers +│ └── pruning/ # Pruning strategies +├── tests/ # Unit tests +└── docs/ # Documentation ``` +## Key Modules + +### Analysis +- `MetricSpaceClustering`: K-means clustering in (RQ, Redundancy, Synergy) space +- `CrossLayerHaloAnalysis`: Track downstream channel dependencies +- `CascadeAnalysis`: Validate importance via ablation +- `UnifiedVisualizer`: Generate analysis plots + +### Experiments +- `GeneralAlignmentExperiment`: Vision model analysis +- `LLMAlignmentExperiment`: LLM supernode and SCAR analysis +- `ClusterAnalysisExperiment`: Cluster-based analysis for any architecture + +### Metrics +- `RayleighQuotient`: Input-weight alignment +- `PairwiseRedundancyGaussian`: Gaussian MI-based redundancy +- `SynergyContinuousTarget`: PID synergy with continuous target +- SCAR metrics for LLMs + ## Documentation - [Usage Guide](docs/usage.md) - Running experiments and configuration - [API Reference](docs/api_reference.md) - Core classes and functions -- [LLM Guide](docs/llm_guide.md) - LLM-specific analysis and pruning +- [LLM Guide](docs/llm_guide.md) - LLM-specific analysis +- [Metric Consistency](docs/METRIC_CONSISTENCY.md) - Theory-code verification + +## Configuration + +```yaml +experiment_type: cluster_analysis # or llm_alignment, alignment_analysis + +model: + name: resnet18 + pretrained: true + +dataset: + name: cifar10 + batch_size: 128 + +clustering: + n_clusters: 4 + compute_stability: true + +halo_analysis: + percentile: 90.0 + +pruning: + ratios: [0.3, 0.5, 0.7] + methods: [magnitude, taylor, cluster_aware] +``` + +See `configs/template.yaml` for complete parameter reference. 
## Testing diff --git a/configs/README.md b/configs/README.md index 59e82b58..93c15154 100644 --- a/configs/README.md +++ b/configs/README.md @@ -5,60 +5,106 @@ ``` configs/ ├── template.yaml # Complete template with all options -├── examples/ # Ready-to-use examples -│ ├── mnist_basic.yaml # MNIST RQ analysis -│ ├── resnet_pruning.yaml # ResNet pruning -│ ├── llama3_pruning.yaml # Llama-3 pruning -│ ├── llm_alignment.yaml # LLM supernode analysis -│ └── vision_comprehensive.yaml -└── projects/ # Project configs - ├── llm_supernode.yaml - └── vision_synergy.yaml +├── cluster_analysis/ # Cluster-based analysis configs +│ ├── resnet18_cifar10_full.yaml +│ ├── vgg16_cifar10_full.yaml +│ └── mobilenetv2_cifar10_full.yaml +├── paper/ # LLM paper experiment configs +│ ├── llama3_8b_full.yaml +│ ├── llama2_7b_full.yaml +│ ├── mistral_7b_full.yaml +│ └── qwen2_7b_full.yaml +└── examples/ # Example configs + ├── mnist_basic.yaml + ├── resnet_pruning.yaml + └── llm_alignment.yaml ``` ## Usage ```bash -python scripts/run_experiment.py --config configs/examples/llama3_comprehensive_pruning.yaml.yaml -python scripts/run_experiment.py --config configs/examples/vision_pruning_test.yaml +python scripts/run_experiment.py --config configs/cluster_analysis/resnet18_cifar10_full.yaml +python scripts/run_experiment.py --config configs/paper/llama3_8b_full.yaml +python scripts/run_experiment.py --config configs/examples/resnet_pruning.yaml +``` + +## Experiment Types + +| Type | Description | +|------|-------------| +| `alignment_analysis` | General alignment metrics | +| `llm_alignment` | LLM supernode/SCAR analysis | +| `cluster_analysis` | Metric-space clustering with halos | ## Configuration Blocks | Block | Purpose | |-------|---------| -| `experiment` | Name, type (`alignment_analysis` or `llm_alignment`), seed, device | -| `model` | Architecture, pretrained, tracked_layers. 
For LLMs: model_id, dtype | +| `experiment` | Name, type, seed, device | +| `model` | Architecture, pretrained, tracked_layers | | `dataset` | Dataset name, batch_size, data_path | -| `metrics` | `enabled`: list of metrics. `num_samples`: calibration samples. `composite_weights`: for composite scoring | -| `training` | `enabled`, epochs, learning_rate, optimizer | -| `supernode` | Detection settings: score_metric, core_fraction, protect_core | -| `pruning` | strategy, sparsity_levels, scoring, direction, structured | +| `metrics` | Enabled metrics, num_samples, composite_weights | +| `clustering` | n_clusters, compute_stability, n_bootstrap | +| `halo_analysis` | percentile, use_activation_weight | +| `cascade_analysis` | n_remove_per_cluster | +| `supernode` | Detection settings for LLMs | +| `pruning` | Strategy, sparsity_levels, scoring | | `llm` | LLM-specific: scar_metrics, evaluate_perplexity | -| `cnn` | CNN-specific: mode (unfold, patchwise) | -| `analysis` | save_scores, generate_plots, plots to enable | -| `visualization` | format, dpi | ## Metrics -Specify metrics to compute in `metrics.enabled`: +Available metrics for `metrics.enabled`: - `rayleigh_quotient` - Input-weight alignment - `activation_l2_norm` - Activation magnitude -- `activation_outlier_index` - Outlier detection - `pairwise_redundancy_gaussian` - Pairwise redundancy - `synergy_gaussian_mmi` - Synergistic information - `mutual_information_gaussian` - MI estimate ## Composite Scoring -Define weights in `metrics.composite_weights` for combining metrics: +Define weights in `metrics.composite_weights`: ```yaml metrics: composite_weights: activation_l2_norm: 0.2 rayleigh_quotient: 0.3 - pairwise_redundancy_gaussian: -0.2 # Negative penalizes redundancy + pairwise_redundancy_gaussian: -0.2 +``` + +## Cluster Analysis Configuration + +```yaml +experiment_type: cluster_analysis + +clustering: + n_clusters: 4 + compute_stability: true + n_bootstrap: 50 + +halo_analysis: + percentile: 90.0 + 
use_activation_weight: true + +cascade_analysis: + n_remove_per_cluster: 5 ``` -Used when `pruning.scoring: "composite"` or `supernode.score_metric: "composite"`. +## LLM Configuration + +```yaml +experiment_type: llm_alignment + +model_config: + model_id: "meta-llama/Llama-3.1-8B" + torch_dtype: "bfloat16" + +do_scar_metrics: true +scar_num_samples: 100 + +supernode: + enabled: true + core_fraction: 0.01 + protect_core: true +``` diff --git a/configs/cluster_analysis/README.md b/configs/cluster_analysis/README.md new file mode 100644 index 00000000..f7159d80 --- /dev/null +++ b/configs/cluster_analysis/README.md @@ -0,0 +1,115 @@ +# Cluster Analysis Experiment Configurations + +This directory contains configurations for **cluster-based neural network analysis** - a general framework that works on any architecture. + +## Overview + +The cluster-based analysis pipeline identifies functional types of neurons/channels by clustering them in metric space: + +1. **Metric Computation**: RQ (alignment), Redundancy (Gaussian MI), Synergy (with continuous target) +2. **Clustering**: K-means in metric space → 4 functional types +3. **Cross-Layer Halo Analysis**: Track downstream dependencies +4. **Cascade Testing**: Validate cluster damage predictions +5. **Pruning Experiments**: Compare cluster-aware vs baselines + +## Supported Architectures + +- **Vision**: ResNet, VGG, MobileNet, EfficientNet, etc. 
+- **LLMs**: Can be adapted for FFN analysis (see LLM configs) +- **Any model** with Conv2d or Linear layers + +## Configuration Files + +| Config | Model | Dataset | Purpose | +|--------|-------|---------|---------| +| `resnet18_cifar10_full.yaml` | ResNet-18 | CIFAR-10 | Full analysis | +| `vgg16_cifar10_full.yaml` | VGG-16-BN | CIFAR-10 | Full analysis | +| `mobilenetv2_cifar10_full.yaml` | MobileNetV2 | CIFAR-10 | Full analysis | +| `resnet50_imagenet100.yaml` | ResNet-50 | ImageNet-100 | Large-scale analysis | + +## Running Experiments + +Use the unified `run_experiment.py` script (same as all other experiments): + +```bash +# Run full analysis (experiment_type is read from config) +python scripts/run_experiment.py --config configs/cluster_analysis/resnet18_cifar10_full.yaml + +# Override device +python scripts/run_experiment.py --config configs/cluster_analysis/resnet18_cifar10_full.yaml --device cuda:1 + +# Override seed for reproducibility study +python scripts/run_experiment.py --config configs/cluster_analysis/resnet18_cifar10_full.yaml --seed 123 + +# Specify output directory +python scripts/run_experiment.py --config configs/cluster_analysis/vgg16_cifar10_full.yaml \ + --output-dir results/cluster_analysis/vgg16_run1 +``` + +## Key Configuration Options + +### Metrics +```yaml +metrics: + n_calibration_samples: 5000 # Samples for metric computation + synergy_target: logit_margin # Continuous target for synergy + synergy_num_pairs: 10 # Partners per channel for synergy +``` + +### Clustering +```yaml +clustering: + n_clusters: 4 # 4 functional types + compute_stability: true # Bootstrap stability analysis + n_bootstrap: 50 # Number of bootstrap samples +``` + +### Halo Analysis +```yaml +halo_analysis: + percentile: 90.0 # Halo membership threshold + use_activation_weight: true # Weight influence by activation std +``` + +### Pruning +```yaml +pruning: + ratios: [0.1, 0.3, 0.5, 0.7] # Sparsity levels to test + methods: + - magnitude # L2 norm baseline + 
- taylor # Taylor importance + - network_slimming # BN gamma + - composite # Per-channel RQ+Red+Syn + - cluster_aware # Full cluster + halo aware +``` + +## Output Structure + +``` +results/cluster_analysis/resnet18_cifar10/ +├── results.json # Full results +├── figures/ +│ ├── cluster_scatter_*.png # Metric space plots +│ ├── cluster_evolution.png # Composition by depth +│ ├── influence_matrix_*.png # Cross-layer influence +│ ├── cascade_*.png # Damage by cluster type +│ └── halo_properties_*.png # Halo redundancy/synergy +└── metrics/ + └── layer_metrics.npz # Raw per-channel metrics +``` + +## Functional Types + +The 4-cluster structure identifies: + +| Type | Characteristics | Pruning Implication | +|------|-----------------|---------------------| +| **Critical** | High RQ, Low Red, High Syn | Protect (max 30% removal) | +| **Redundant** | Mod RQ, High Red, Low Syn | Target for pruning | +| **Synergistic** | Mod RQ, Low Red, High Syn | Preserve pairs | +| **Background** | Low on all metrics | Safe to remove | + +## Related Papers + +- Vision paper: `drafts/alignment_notes/vision_synergy_icml_v3.tex` +- LLM paper: `drafts/LLM_prune/scar_paper_icml_v4.tex` diff --git a/configs/cluster_analysis/mobilenetv2_cifar10_full.yaml b/configs/cluster_analysis/mobilenetv2_cifar10_full.yaml new file mode 100644 index 00000000..c0460143 --- /dev/null +++ b/configs/cluster_analysis/mobilenetv2_cifar10_full.yaml @@ -0,0 +1,50 @@ +# MobileNetV2 on CIFAR-10 - Full Cluster Analysis + +experiment_name: mobilenetv2_cifar10_cluster_analysis +experiment_type: cluster_analysis + +model: + name: mobilenet_v2 + pretrained: true + num_classes: 10 + +dataset: + name: cifar10 + root: ./data + train_batch_size: 128 + test_batch_size: 256 + +metrics: + n_calibration_samples: 5000 + compute_rq: true + compute_redundancy: true + compute_synergy: true + synergy_target: logit_margin + synergy_num_pairs: 10 + +clustering: + enabled: true + n_clusters: 4 + compute_stability: true + n_bootstrap: 50 
+ +halo_analysis: + enabled: true + percentile: 90.0 + +cascade_analysis: + enabled: true + n_remove_per_cluster: 5 + +pruning: + enabled: true + ratios: [0.1, 0.3, 0.5, 0.7] + +visualization: + enabled: true + +output: + dir: results/vision/mobilenetv2_cifar10 + +device: cuda +seed: 42 diff --git a/configs/cluster_analysis/resnet18_cifar10_full.yaml b/configs/cluster_analysis/resnet18_cifar10_full.yaml new file mode 100644 index 00000000..3a91c9c6 --- /dev/null +++ b/configs/cluster_analysis/resnet18_cifar10_full.yaml @@ -0,0 +1,110 @@ +# ResNet-18 on CIFAR-10 - Full Cluster Analysis +# This config runs the complete cluster-based analysis pipeline + +experiment_name: resnet18_cifar10_cluster_analysis +experiment_type: cluster_analysis + +# Model configuration +model: + name: resnet18 + pretrained: true + num_classes: 10 + +# Dataset configuration +dataset: + name: cifar10 + root: ./data + train_batch_size: 128 + test_batch_size: 256 + num_workers: 4 + augmentation: true + +# Metric computation +metrics: + n_calibration_samples: 5000 + + # Rayleigh Quotient + compute_rq: true + rq_shrinkage: true + + # Redundancy (Gaussian pairwise MI) + compute_redundancy: true + redundancy_sampling: all # all, random, top_k + + # Synergy with continuous target + compute_synergy: true + synergy_target: logit_margin # logit_margin, correct_logit, logit_pc1 + synergy_num_pairs: 10 + synergy_sampling: top_k # random, top_k, all + +# Clustering +clustering: + enabled: true + n_clusters: 4 + normalize_features: true + compute_stability: true + n_bootstrap: 50 + +# Cross-layer halo analysis +halo_analysis: + enabled: true + percentile: 90.0 + use_activation_weight: true + compute_influence_matrix: true + +# Cascade/damage analysis +cascade_analysis: + enabled: true + n_remove_per_cluster: 5 + damage_sample_fraction: 0.2 + +# Pruning experiments +pruning: + enabled: true + ratios: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + + methods: + - name: random + - name: magnitude + - name: taylor + - 
name: rq_only + - name: composite # RQ + Red + Syn + - name: cluster_aware # Full cluster + halo aware + + fine_tuning: + epochs: 10 + lr: 0.0001 + +# Baselines to compare against +baselines: + - magnitude + - taylor + - network_slimming + +# Visualization +visualization: + enabled: true + figures: + - cluster_scatter + - cluster_evolution + - influence_matrix + - cascade_test + - halo_properties + - pruning_curves + - metric_distributions + +# Output +output: + dir: results/vision/resnet18_cifar10 + save_metrics: true + save_clusters: true + save_figures: true + +# Pre-training (required for ImageNet pretrained models on CIFAR) +# This fine-tunes the model on CIFAR-10 before running experiments +pretrain_epochs: 20 +pretrain_lr: 0.001 + +# Hardware +device: cuda +seed: 42 diff --git a/configs/cluster_analysis/resnet50_imagenet100.yaml b/configs/cluster_analysis/resnet50_imagenet100.yaml new file mode 100644 index 00000000..cb630c39 --- /dev/null +++ b/configs/cluster_analysis/resnet50_imagenet100.yaml @@ -0,0 +1,113 @@ +# ResNet-50 on ImageNet-100 - Full Cluster Analysis +# This config runs the complete cluster-based analysis pipeline on ImageNet subset +# +# ImageNet-100 is a 100-class subset of ImageNet for tractable experiments. +# Classes are selected to maintain semantic diversity. + +experiment_name: resnet50_imagenet100_cluster_analysis +experiment_type: cluster_analysis + +# Model configuration +model: + name: resnet50 + pretrained: true + num_classes: 100 + weights: IMAGENET1K_V2 # Use torchvision pretrained weights + +# Dataset configuration +dataset: + name: imagenet100 + root: ./data/imagenet100 # Path to ImageNet-100 subset + train_batch_size: 64 + test_batch_size: 128 + num_workers: 8 + image_size: 224 + normalize: true + # If using full ImageNet, specify class indices: + # class_indices: [0, 10, 20, ...] 
# 100 selected classes + +# Metric computation +metrics: + n_calibration_samples: 5000 + + # Rayleigh Quotient + compute_rq: true + rq_shrinkage: true + + # Redundancy (Gaussian pairwise MI) + compute_redundancy: true + redundancy_sampling: all + + # Synergy with continuous target + compute_synergy: true + synergy_target: logit_margin + synergy_num_pairs: 10 + synergy_sampling: top_k + +# Clustering +clustering: + enabled: true + n_clusters: 4 + normalize_features: true + compute_stability: true + n_bootstrap: 30 # Fewer for larger model + +# Cross-layer halo analysis +halo_analysis: + enabled: true + percentile: 90.0 + use_activation_weight: true # Account for BN scaling + compute_influence_matrix: true + +# Cascade/damage analysis +cascade_analysis: + enabled: true + n_remove_per_cluster: 5 + damage_sample_fraction: 0.1 # Smaller for faster computation + +# Pruning experiments +pruning: + enabled: true + ratios: [0.2, 0.3, 0.4, 0.5, 0.6] + + methods: + - name: random + - name: magnitude + - name: taylor + - name: composite # RQ + Red + Syn (per-channel) + - name: cluster_aware # Full cluster + halo aware + + fine_tuning: + epochs: 5 # Fewer epochs for ImageNet + lr: 0.00001 + weight_decay: 0.0001 + +# Baselines to compare against +baselines: + - magnitude + - taylor + - network_slimming + +# Visualization +visualization: + enabled: true + figures: + - cluster_scatter + - cluster_evolution + - influence_matrix + - cascade_test + - halo_properties + - pruning_curves + - metric_distributions + +# Output +output: + dir: results/vision/resnet50_imagenet100 + save_metrics: true + save_clusters: true + save_figures: true + +# Hardware +device: cuda +seed: 42 +num_workers: 8 diff --git a/configs/cluster_analysis/vgg16_cifar10_full.yaml b/configs/cluster_analysis/vgg16_cifar10_full.yaml new file mode 100644 index 00000000..4e774f6e --- /dev/null +++ b/configs/cluster_analysis/vgg16_cifar10_full.yaml @@ -0,0 +1,57 @@ +# VGG-16-BN on CIFAR-10 - Full Cluster Analysis + 
+experiment_name: vgg16_cifar10_cluster_analysis +experiment_type: cluster_analysis + +model: + name: vgg16_bn + pretrained: true + num_classes: 10 + +dataset: + name: cifar10 + root: ./data + train_batch_size: 128 + test_batch_size: 256 + +metrics: + n_calibration_samples: 5000 + compute_rq: true + compute_redundancy: true + compute_synergy: true + synergy_target: logit_margin + synergy_num_pairs: 10 + +clustering: + enabled: true + n_clusters: 4 + compute_stability: true + n_bootstrap: 50 + +halo_analysis: + enabled: true + percentile: 90.0 + use_activation_weight: true + +cascade_analysis: + enabled: true + n_remove_per_cluster: 5 + +pruning: + enabled: true + ratios: [0.1, 0.3, 0.5, 0.7] + methods: + - magnitude + - taylor + - network_slimming + - composite + - cluster_aware + +visualization: + enabled: true + +output: + dir: results/vision/vgg16_cifar10 + +device: cuda +seed: 42 diff --git a/configs/paper/README.md b/configs/paper/README.md index 56511ab8..9a27f922 100644 --- a/configs/paper/README.md +++ b/configs/paper/README.md @@ -1,92 +1,56 @@ # SCAR Paper Experiment Configurations -Comprehensive configurations for generating all results in the SCAR paper. +Configurations for generating results in the SCAR LLM pruning paper. 
## Configurations | Config | Model | Layers | FFN Width | Runtime | |--------|-------|--------|-----------|---------| -| `llama3_8b_full.yaml` | LLaMA-3.1-8B | 32 | 14336 | ~6-8h | -| `mistral_7b_full.yaml` | Mistral-7B | 32 | 14336 | ~4-6h | -| `llama2_7b_full.yaml` | LLaMA-2-7B | 32 | 11008 | ~4-6h | -| `qwen2_7b_full.yaml` | Qwen2-7B | 28 | 18944 | ~4-6h | +| `llama3_8b_full.yaml` | LLaMA-3.1-8B | 32 | 14336 | 6-8h | +| `mistral_7b_full.yaml` | Mistral-7B | 32 | 14336 | 4-6h | +| `llama2_7b_full.yaml` | LLaMA-2-7B | 32 | 11008 | 4-6h | +| `qwen2_7b_full.yaml` | Qwen2-7B | 28 | 18944 | 4-6h | ## Quick Start -### Run all experiments: +Run all experiments: ```bash -sbatch ../slurm_jobs/run_paper_experiments.sh +sbatch slurm_jobs/paper/run_all_paper.sh ``` -### Run single model: +Run single model: ```bash -python -m alignment.experiments.llm_alignment \ - --config configs/paper/llama3_8b_full.yaml +python scripts/run_experiment.py --config configs/paper/llama3_8b_full.yaml ``` -## What's Included - -### Pruning Methods (All Configs) +## Pruning Methods | Category | Methods | |----------|---------| -| **Alignment-based** | `rayleigh_quotient`, `gaussian_mi_analytic`, `average_redundancy` | -| **SCAR (gradient-based)** | `scar_loss_proxy`, `scar_taylor`, `scar_activation_power`, `scar_curvature` | -| **Supernode-aware** | `supernode_protection_score`, `supernode_connectivity_score` | -| **Generalized** | `generalized_importance` (no outlier assumption) | -| **Cross-layer** | `cross_layer_importance` (SCAR-aligned downstream dependency) | -| **Magnitude baseline** | `activation_l2_norm` | -| **SOTA baselines** | `wanda`, `sparsegpt` | - -### Analyses - -1. **Supernode Distribution** - - Loss proxy histograms by layer - - Concentration across depth - - Top 1%, 5%, 10% highlighting - -2. **Supernode Robustness** - - Bootstrap stability analysis (10 resamples) - - Jaccard similarity between metrics - - Spearman correlation heatmaps - - Cross-metric consistency - -3. 
**Supernode Summary** - - Halo vs non-halo metrics by layer - - Outlier z-score analysis - -4. **Halo Redundancy Analysis** - - Within-halo redundancy - - Within-non-halo redundancy - - Cross-group redundancy - - Depth comparison plots - - Comprehensive 4-panel figures - -5. **Cross-Layer Importance** - - Downstream importance (next layer dependency) - - Layer transition efficiency - - Importance vs redundancy scatter - -6. **Generalized Importance** - - Works without clear supernode structure - - Neighborhood-based redundancy - - Downstream propagation - -### Evaluation Benchmarks - -**Perplexity:** -- WikiText-2 -- C4 (validation subset) - -**Zero-shot:** -- HellaSwag, PIQA, BoolQ, WinoGrande -- ARC-Easy, ARC-Challenge, OpenBookQA - -**Few-shot:** -- HellaSwag (5-shot) -- PIQA (5-shot) -- ARC-Challenge (5-shot) -- MMLU (5-shot, full) +| Alignment-based | `rayleigh_quotient`, `gaussian_mi_analytic`, `average_redundancy` | +| SCAR (gradient-based) | `scar_loss_proxy`, `scar_taylor`, `scar_activation_power`, `scar_curvature` | +| Supernode-aware | `supernode_protection_score`, `supernode_connectivity_score` | +| Generalized | `generalized_importance` (no outlier assumption) | +| Cross-layer | `cross_layer_importance` (downstream dependency) | +| Magnitude baseline | `activation_l2_norm` | +| SOTA baselines | `wanda`, `sparsegpt` | + +## Analyses + +1. **Supernode Distribution**: Loss proxy histograms, concentration across depth +2. **Supernode Robustness**: Bootstrap stability, Jaccard similarity, cross-metric consistency +3. **Supernode Summary**: Halo vs non-halo metrics by layer +4. **Halo Redundancy**: Within-halo, within-non-halo, cross-group redundancy +5. **Cross-Layer Importance**: Downstream importance, layer transition efficiency +6. 
**Generalized Importance**: Neighborhood-based scoring without outlier assumption + +## Evaluation Benchmarks + +**Perplexity**: WikiText-2, C4 + +**Zero-shot**: HellaSwag, PIQA, BoolQ, WinoGrande, ARC-Easy, ARC-Challenge, OpenBookQA + +**Few-shot**: HellaSwag (5-shot), PIQA (5-shot), ARC-Challenge (5-shot), MMLU (5-shot) ## Output Structure @@ -95,44 +59,31 @@ results/paper// ├── metrics/ │ ├── layer_metrics.json │ ├── supernode_analysis.json -│ ├── supernode_robustness.json -│ ├── halo_redundancy.json -│ └── cross_layer_analysis.json +│ └── halo_redundancy.json ├── evaluation/ │ ├── perplexity_results.json │ └── benchmark_results.json ├── pruning/ -│ ├── sparsity_curves.json -│ └── per_method_results.json +│ └── sparsity_curves.json └── figures/ ├── fig1_supernode_distribution.pdf ├── fig2_halo_redundancy.pdf - ├── fig3_cross_layer_importance.pdf - ├── fig4_pruning_curves.pdf - ├── supernode_robustness/ - │ ├── jaccard_heatmap.pdf - │ ├── spearman_heatmap.pdf - │ └── bootstrap_stability.pdf - └── supplementary/ + └── fig3_pruning_curves.pdf ``` -## Key Differences from `examples/llama3_comprehensive_pruning.yaml` - -The paper configs include everything from the comprehensive pruning config plus: +## Features -1. ✅ Structured for paper figure generation (PDF output) -2. ✅ All SOTA baselines (Wanda, SparseGPT) -3. ✅ Supernode robustness analysis -4. ✅ Supernode summary/outlier analysis -5. ✅ Generalized importance (no outlier assumption) -6. ✅ Cross-layer importance (SCAR-aligned) -7. ✅ Selection modes (low/high) -8. ✅ Additional evaluation metrics (bits_per_byte) -9. 
✅ Comprehensive scatter pair analysis +Compared to example configs, paper configs include: +- PDF figure output +- All SOTA baselines (Wanda, SparseGPT) +- Supernode robustness analysis +- Generalized importance (outlier-free) +- Cross-layer importance analysis +- Comprehensive evaluation metrics ## Resource Requirements -- **GPU**: 1x A100 80GB (recommended) or H100 +- **GPU**: 1x A100 80GB or H100 - **Memory**: ~60GB GPU memory for 8B models -- **Storage**: ~50GB per model for full results -- **Time**: ~20-30 hours total for all 4 models +- **Storage**: ~50GB per model +- **Time**: ~20-30 hours total for all models diff --git a/docs/METRIC_CONSISTENCY.md b/docs/METRIC_CONSISTENCY.md index 8d6f8cf5..53b0e266 100644 --- a/docs/METRIC_CONSISTENCY.md +++ b/docs/METRIC_CONSISTENCY.md @@ -7,12 +7,12 @@ definitions in `drafts/alignment_notes/main.tex` and `drafts/alignment_notes/new | Metric | LaTeX Reference | Code Implementation | Status | |--------|-----------------|---------------------|--------| -| Rayleigh Quotient | Eq. 3.1 in new.tex | `src/alignment/metrics/rayleigh/rayleigh_quotient.py` | ✅ Consistent | -| Pairwise Redundancy | Eq. 5.1-5.2 in new.tex | `src/alignment/metrics/information/redundancy.py` | ✅ Consistent | -| Composite Score | Eq. 6.1 in new.tex | `src/alignment/metrics/composite.py` | ✅ Consistent | -| Class-conditioned RQ | Eq. 4.1-4.3 in new.tex | `src/alignment/metrics/conditional_metrics.py` | ✅ Consistent | -| Gaussian MI | Section 3.2 in new.tex | `src/alignment/metrics/information/gaussian_mi.py` | ✅ Consistent | -| PID Synergy | Eq. 5.4 in new.tex | `src/alignment/metrics/information/gaussian_pid.py` | ✅ Consistent | +| Rayleigh Quotient | Eq. 3.1 in new.tex | `src/alignment/metrics/rayleigh/rayleigh_quotient.py` | [x] Consistent | +| Pairwise Redundancy | Eq. 5.1-5.2 in new.tex | `src/alignment/metrics/information/redundancy.py` | [x] Consistent | +| Composite Score | Eq. 
6.1 in new.tex | `src/alignment/metrics/composite.py` | [x] Consistent | +| Class-conditioned RQ | Eq. 4.1-4.3 in new.tex | `src/alignment/metrics/conditional_metrics.py` | [x] Consistent | +| Gaussian MI | Section 3.2 in new.tex | `src/alignment/metrics/information/gaussian_mi.py` | [x] Consistent | +| PID Synergy | Eq. 5.4 in new.tex | `src/alignment/metrics/information/gaussian_pid.py` | [x] Consistent | --- @@ -36,7 +36,7 @@ rq_values = numerator / denominator ### Verification - **Formula**: Matches exactly. Computes w^T Σ w / w^T w - **Normalization**: Code supports both absolute and relative (divided by trace) modes -- **Status**: ✅ **CONSISTENT** +- **Status**: [x] **CONSISTENT** --- @@ -65,7 +65,7 @@ mi_with_refs = -0.5 * torch.log(1.0 - rho_sq) - **Formula**: Matches exactly. Uses -0.5 * log(1 - ρ²) - **Correlation**: Computed from normalized activations (equivalent to ρ in theory) - **Clamping**: Properly handles edge cases (ρ² < 1) -- **Status**: ✅ **CONSISTENT** +- **Status**: [x] **CONSISTENT** --- @@ -96,7 +96,7 @@ for metric_name, weight in self.metric_weights.items(): - **Formula**: Matches. 
Supports arbitrary metric weights - **Log RQ**: Correctly applies log transform when configured - **Signs**: Redundancy can be given negative weight (penalty) -- **Status**: ✅ **CONSISTENT** +- **Status**: [x] **CONSISTENT** --- @@ -134,7 +134,7 @@ delta_rq = rq_uncond - rq_cond - **Per-class RQ**: Correctly computes RQ with class-specific covariance - **Weighted average**: Uses class proportions p(y) as weights - **Delta RQ**: Matches definition exactly -- **Status**: ✅ **CONSISTENT** +- **Status**: [x] **CONSISTENT** --- @@ -164,7 +164,7 @@ mi_scores = 0.5 * torch.log(output_var / noise_variance + 1.0) ### Verification - **Formula**: Matches the Gaussian channel capacity formula - **RQ Connection**: log(MI) ∝ log(RQ) for fixed noise (documented in code) -- **Status**: ✅ **CONSISTENT** +- **Status**: [x] **CONSISTENT** --- @@ -193,7 +193,7 @@ S = I_z_y12 - I_z_y1 - I_z_y2 + R_mmi - **MMI Redundancy**: Uses min correctly - **Synergy formula**: Matches exactly - **Gaussian MI terms**: All I() computed using same Gaussian formulas -- **Status**: ✅ **CONSISTENT** +- **Status**: [x] **CONSISTENT** --- diff --git a/docs/README.md b/docs/README.md index fea94546..c6be3a64 100644 --- a/docs/README.md +++ b/docs/README.md @@ -5,18 +5,47 @@ - [Usage Guide](usage.md) - Running experiments and configuration - [API Reference](api_reference.md) - Core classes and functions - [LLM Guide](llm_guide.md) - LLM-specific analysis and pruning +- [Metric Consistency](METRIC_CONSISTENCY.md) - Theory-code verification ## Configuration - [Template](../configs/template.yaml) - Complete parameter reference +- [Cluster Analysis](../configs/cluster_analysis/) - Cluster-based analysis configs +- [Paper Configs](../configs/paper/) - LLM paper experiment configs - [Examples](../configs/examples/) - Example configurations -## Quick Start +## Quick Reference + +### Experiment Types + +| Type | Description | +|------|-------------| +| `alignment_analysis` | General alignment metrics for vision 
models | +| `llm_alignment` | LLM supernode and SCAR analysis | +| `cluster_analysis` | Metric-space clustering with halo analysis | + +### Key Classes + +| Class | Module | Purpose | +|-------|--------|---------| +| `MetricSpaceClustering` | `analysis.clustering` | Cluster channels by functional type | +| `CrossLayerHaloAnalysis` | `analysis.clustering` | Track downstream dependencies | +| `CascadeAnalysis` | `analysis` | Validate importance via ablation | +| `LLMAlignmentExperiment` | `experiments` | LLM analysis runner | +| `ClusterAnalysisExperiment` | `experiments` | Cluster analysis runner | + +### Running Experiments ```bash -# Run experiment +# Vision/general analysis python scripts/run_experiment.py --config configs/examples/mnist_basic.yaml -# Generate analysis -python scripts/run_analysis.py --results-dir ./results --output-dir ./plots --quick +# LLM analysis +python scripts/run_experiment.py --config configs/paper/llama3_8b_full.yaml + +# Cluster-based analysis +python scripts/run_experiment.py --config configs/cluster_analysis/resnet18_cifar10_full.yaml + +# Post-hoc analysis +python scripts/run_analysis.py --results-dir ./results --output-dir ./plots ``` diff --git a/docs/api_reference.md b/docs/api_reference.md index 2335e1b1..e035fffc 100644 --- a/docs/api_reference.md +++ b/docs/api_reference.md @@ -11,7 +11,7 @@ from alignment import ModelWrapper wrapper = ModelWrapper( model, # PyTorch model - tracked_layers=None, # List of layer names or None (auto-detect) + tracked_layers=None, # Layer names or None (auto-detect) track_inputs=True, track_outputs=True ) @@ -31,6 +31,8 @@ metric.requires_outputs # bool metric.compute(inputs, weights, outputs, **kwargs) # Returns scores ``` +--- + ## Metrics ### Rayleigh Quotient @@ -42,10 +44,7 @@ rq = get_metric('rayleigh_quotient', relative=True, regularization=1e-6 ) -scores = rq.compute(inputs, weights) # [num_neurons] - -# Class-conditioned -results = rq.compute_class_conditioned(inputs, weights, targets, 
return_delta_rq=True) +scores = rq.compute(inputs, weights) ``` ### Redundancy @@ -59,92 +58,126 @@ redundancy = get_metric('pairwise_redundancy_gaussian', scores = redundancy.compute(outputs=layer_outputs) ``` -### Synergy +### Synergy (Continuous Target) ```python -synergy = get_metric('synergy_gaussian_mmi', num_pairs=10) -scores = synergy.compute(inputs, weights, targets=labels) +from alignment.metrics.information import SynergyContinuousTarget + +synergy = SynergyContinuousTarget( + target_type='logit_margin', # or 'correct_logit', 'logit_pc1' + num_pairs=10, + sampling_strategy='top_k' +) +scores = synergy.compute(outputs=activations, logits=logits, labels=labels) ``` -## Services +--- -### ActivationCaptureService +## Clustering Analysis + +### MetricSpaceClustering + +Clusters channels in (RQ, Redundancy, Synergy) space. ```python -from alignment.services import ActivationCaptureService +from alignment.analysis.clustering import MetricSpaceClustering, ClusterResult -capture = ActivationCaptureService(model_wrapper) -data = capture.capture(input_batch, layers=['conv1'], include_weights=True) +clusterer = MetricSpaceClustering(n_clusters=4, seed=42) +result = clusterer.fit(rq_scores, redundancy_scores, synergy_scores, layer_name="conv1") + +# Result attributes +result.labels # Cluster assignments [n_channels] +result.centroids # Cluster centers [n_clusters, 3] +result.silhouette # Silhouette score +result.type_mapping # {cluster_id: 'critical'|'redundant'|'synergistic'|'background'} +result.type_counts # {'critical': N, ...} ``` -### NodeScoringService +### CrossLayerHaloAnalysis + +Analyzes downstream dependencies via halos. 
```python -from alignment.services import NodeScoringService +from alignment.analysis.clustering import CrossLayerHaloAnalysis, HaloResult -scorer = NodeScoringService( - metrics={'rq': rq_metric, 'redundancy': redundancy_metric}, - gamma_redundancy=0.4, - delta_rq=0.3 -) -scores = scorer.compute_composite_scores(inputs, weights, targets) -``` +halo_analyzer = CrossLayerHaloAnalysis(percentile=90.0, use_activation_weight=True) -### MaskOperations +# Compute influence matrix +influence = halo_analyzer.compute_influence(weights, activations) -```python -from alignment.services import MaskOperations +# Find halo for a cluster +halo_indices, rel_influence = halo_analyzer.find_halo(influence, cluster_indices) -mask = MaskOperations.create_structured_mask(scores, amount=0.5, mode='low') -stats = MaskOperations.get_mask_statistics(mask) +# Analyze halo properties +halo_result = halo_analyzer.analyze_halo( + halo_indices, next_layer_redundancy, next_layer_synergy, + layer_name="layer2", cluster_name="critical" +) ``` -## Pruning +### CascadeAnalysis -### Quick Pruning +Validates importance via channel ablation. ```python -from alignment.pruning.orchestrator import prune_with_all_options +from alignment.analysis import CascadeAnalysis, DamagePrediction -result = prune_with_all_options( - model, - target_sparsity=0.7, - distribution='adaptive_sensitivity', - scoring='composite', - direction='low', - val_loader=val_loader, - eval_fn=evaluate -) +cascade = CascadeAnalysis(model, test_loader, device="cuda") +baseline = cascade.baseline() + +# Ablate specific channels +result = cascade.ablate(layer_name="conv1", indices=[0, 5, 10]) +# result.accuracy_drop, result.loss_increase + +# Test by cluster type +results = cascade.by_cluster(layer_name, labels, type_mapping, n_rm=5) ``` -### Dependency-Aware Pruning +--- + +## Experiments + +### ClusterAnalysisExperiment + +General cluster-based analysis for any architecture. 
```python -from alignment.pruning.dependency_aware import DependencyAwarePruning +from alignment.experiments import ClusterAnalysisExperiment, ClusterAnalysisConfig + +config = ClusterAnalysisConfig( + model_name="resnet18", + dataset_name="cifar10", + n_clusters=4, + synergy_target="logit_margin", + halo_percentile=90.0, + device="cuda" +) -pruner = DependencyAwarePruning(model) -result = pruner.prune(layer_scores={'conv1': scores1}, amount=0.5, mode='low') +experiment = ClusterAnalysisExperiment(config, model, train_loader, test_loader) +results = experiment.run() +experiment.generate_figures() ``` -## Model Wrappers +### LLMAlignmentExperiment -### TransformerWrapperEnhanced +LLM-specific analysis with SCAR metrics. ```python -from alignment.models.transformer_enhanced import TransformerWrapperEnhanced +from alignment.experiments import LLMAlignmentExperiment -wrapper = TransformerWrapperEnhanced( - transformer_model, - track_qkv=True, - track_per_head=True -) -head_repr = wrapper.extract_attention_heads(attn_output) -``` +experiment = LLMAlignmentExperiment(config) +experiment.setup() -## Experiments +scores = experiment.compute_importance_scores(num_samples=100) +scar_scores = experiment.compute_scar_supernode_metrics() +masks = experiment.apply_pruning(sparsity=0.3, metric="scar_loss_proxy", mode="low") +perplexity = experiment.evaluate_perplexity("wikitext", "test", num_samples=100) +``` ### GeneralAlignmentExperiment +Vision model alignment analysis. 
+ ```python from alignment.experiments import GeneralAlignmentExperiment @@ -152,63 +185,104 @@ experiment = GeneralAlignmentExperiment.from_yaml("config.yaml") results = experiment.run() ``` -### LLMAlignmentExperiment +--- + +## Visualization + +### Cluster Plots ```python -from alignment.experiments import LLMAlignmentExperiment +from alignment.analysis.visualization import ( + plot_metric_scatter, + plot_cluster_evolution, + plot_influence_matrix, + plot_cascade_test, + plot_halo_properties +) -experiment = LLMAlignmentExperiment(config) -experiment.setup() +# Metric space scatter (RQ vs Red, RQ vs Syn, Red vs Syn) +plot_metric_scatter(rq, redundancy, synergy, labels, type_mapping, + layer_name, save_path) -# Compute importance scores -scores = experiment.compute_importance_scores(num_samples=100) +# Cluster composition across depth +plot_cluster_evolution(layer_results, save_path) -# Compute SCAR metrics -scar_scores = experiment.compute_scar_supernode_metrics() +# Cross-cluster influence heatmap +plot_influence_matrix(flow_dict, layer_name, save_path) -# Analyze supernode connections -supernode_analysis = experiment.analyze_supernode_connections( - scar_scores=scar_scores, - supernode_fraction=0.01, # Top 1% as supernodes - follower_fraction=0.10, # Top 10% by weight from supernodes - supernode_metric="scar_activation_power", # Metric for supernode identification - cross_layer_analysis=True, # Enable next-layer analysis - compute_metrics=["activation", "rayleigh_quotient", "mutual_information", "redundancy"], - compare_by_connection=True, # Compare high vs low connected neurons - target_layers=["model.layers.10.mlp.down_proj"], # Specific layers (None = use tracked_layers, [] = all) - plots_dir="./plots" -) +# Cascade damage by cluster type +plot_cascade_test(cascade_results, save_path) ``` -## Visualization - ### UnifiedVisualizer ```python from alignment.analysis.visualization import UnifiedVisualizer viz = UnifiedVisualizer() - -# Basic plots 
viz.plot_layer_scores(scores, metric_name, plot_type='violin', save_path='plot.png') viz.plot_importance_histogram(scores, layer_name, metric_name, plots_dir) viz.plot_scatter_2d(x, y, xlabel, ylabel, title, save_path) viz.plot_heatmap(data, title, cmap, save_path) -viz.plot_pruning_performance(results, metrics, save_path) +``` + +--- + +## Pruning -# Supernode analysis plots -viz.plot_supernode_activation_distribution( - activation_values, threshold_value, threshold_percentile, - layer_name, metric_name="scar_activation_power", save_path=path +### Quick Pruning + +```python +from alignment.pruning.orchestrator import prune_with_all_options + +result = prune_with_all_options( + model, + target_sparsity=0.7, + distribution='adaptive_sensitivity', + scoring='composite', + direction='low', + val_loader=val_loader, + eval_fn=evaluate ) -viz.plot_outgoing_weights_distribution(weights, layer_name, save_path=path) -viz.plot_supernode_influence(influence_values, threshold_value, threshold_percentile, layer_name, save_path=path) -viz.plot_correlation_matrix(corr_matrix, title, xlabel, ylabel, save_path=path) -viz.plot_1d_histogram(values, xlabel, ylabel, title, vline, vline_label, color, save_path=path) -viz.plot_rq_vs_mi(rq_scores, mi_scores, redundancy_scores, layer_name, save_path=path) -viz.plot_redundancy_comparison(high_redundancy, low_redundancy, high_mean, low_mean, layer_name, save_dir) ``` +### Dependency-Aware Pruning + +```python +from alignment.pruning.dependency_aware import DependencyAwarePruning + +pruner = DependencyAwarePruning(model) +result = pruner.prune(layer_scores={'conv1': scores1}, amount=0.5, mode='low') +``` + +--- + +## Services + +### ActivationCaptureService + +```python +from alignment.services import ActivationCaptureService + +capture = ActivationCaptureService(model_wrapper) +data = capture.capture(input_batch, layers=['conv1'], include_weights=True) +``` + +### NodeScoringService + +```python +from alignment.services import 
NodeScoringService + +scorer = NodeScoringService( + metrics={'rq': rq_metric, 'redundancy': redundancy_metric}, + gamma_redundancy=0.4, + delta_rq=0.3 +) +scores = scorer.compute_composite_scores(inputs, weights, targets) +``` + +--- + ## Configuration Parameters ### Metric Parameters @@ -222,10 +296,25 @@ viz.plot_redundancy_comparison(high_redundancy, low_redundancy, high_mean, low_m - `num_pairs` (int): Partners to sample - `aggregation` (str): 'mean', 'median', 'max', 'sum' +**SynergyContinuousTarget** +- `target_type` (str): 'logit_margin', 'correct_logit', 'logit_pc1' +- `num_pairs` (int): Partner neurons per channel +- `sampling_strategy` (str): 'random', 'top_k', 'all' + +### Clustering Parameters + +**MetricSpaceClustering** +- `n_clusters` (int): Number of clusters (default: 4) +- `seed` (int): Random seed + +**CrossLayerHaloAnalysis** +- `percentile` (float): Halo membership threshold (default: 90.0) +- `use_activation_weight` (bool): Weight influence by activation std + ### Pruning Parameters -**Strategy**: 'magnitude', 'alignment', 'composite', 'movement', 'adaptive' +**Strategy**: 'magnitude', 'alignment', 'composite', 'cluster_aware', 'random' -**Distribution**: 'uniform', 'global_threshold', 'adaptive_sensitivity', 'cascading' +**Distribution**: 'uniform', 'global_threshold', 'adaptive_sensitivity' -**Direction**: 'low' (prune unimportant), 'high' (ablation), 'random' (baseline) +**Direction**: 'low' (prune unimportant), 'high' (ablation) diff --git a/scripts/README.md b/scripts/README.md index 9e5de1bb..37fba2ad 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -7,18 +7,23 @@ Entry points for experiments and analysis. 
Run experiments from YAML configuration: ```bash +# Vision analysis python scripts/run_experiment.py --config configs/examples/mnist_basic.yaml -python scripts/run_experiment.py --config configs/examples/resnet_pruning.yaml -python scripts/run_experiment.py --config configs/examples/llm_alignment.yaml + +# LLM analysis +python scripts/run_experiment.py --config configs/paper/llama3_8b_full.yaml + +# Cluster-based analysis +python scripts/run_experiment.py --config configs/cluster_analysis/resnet18_cifar10_full.yaml ``` Options: - `--config PATH` - Configuration file (required) -- `--device STRING` - Override device +- `--device STRING` - Override device (e.g., cuda:0) - `--seed INT` - Override random seed - `--output-dir PATH` - Override output directory - `--analysis-only` - Regenerate plots from existing results -- `--experiment-dir PATH` - Existing experiment directory +- `--experiment-dir PATH` - Existing experiment directory (with --analysis-only) ## run_analysis.py @@ -35,3 +40,11 @@ Options: - `--output-dir PATH` - Output directory - `--analyses LIST` - Specific analyses to run - `--quick` - Run all analyses with defaults + +## analyze_halo_redundancy.py + +Specialized script for halo redundancy analysis: + +```bash +python scripts/analyze_halo_redundancy.py --results-dir ./results --output-dir ./plots +``` diff --git a/scripts/run_experiment.py b/scripts/run_experiment.py index 3beb8506..17efd3b8 100644 --- a/scripts/run_experiment.py +++ b/scripts/run_experiment.py @@ -68,10 +68,218 @@ def patched_tqdm(*args, **kwargs): from alignment.pruning.experiments.cascading_layer import CascadingLayerPruningExperiment from alignment.pruning.experiments.layer_wise import LayerIsolatedPruningExperiment from alignment.experiments.llm_experiments import LLMAlignmentExperiment +from alignment.experiments.cluster_experiments import ( + ClusterAnalysisExperiment, + ClusterAnalysisConfig, + VisionExperiment, # backward compat + VisionExperimentConfig, # backward compat +) 
logger = logging.getLogger(__name__)
+
+
+def _create_cluster_experiment(config):
+    """Create ClusterAnalysisExperiment from unified config."""
+    import torch
+    import torchvision
+    import torchvision.transforms as transforms
+
+    # Helper to safely get nested config values
+    def _get_nested(obj, key, default):
+        """Return the attribute *key* of *obj* if it exists and is a dict; otherwise *default*."""
+        if hasattr(obj, key):
+            val = getattr(obj, key)
+            if isinstance(val, dict):
+                return val
+            return default
+        return default
+
+    # Extract nested configs with proper defaults
+    model_cfg = _get_nested(config, "model", {})
+    dataset_cfg = _get_nested(config, "dataset", {})
+    metrics_cfg = _get_nested(config, "metrics", {})
+    clustering_cfg = _get_nested(config, "clustering", {})
+    halo_cfg = _get_nested(config, "halo_analysis", {})
+
+    # Build ClusterAnalysisConfig from the loaded config
+    cluster_config = ClusterAnalysisConfig(
+        model_name=getattr(config, "model_name", model_cfg.get("name", "resnet18") if isinstance(model_cfg, dict) else "resnet18"),
+        dataset_name=getattr(config, "dataset_name", dataset_cfg.get("name", "cifar10") if isinstance(dataset_cfg, dict) else "cifar10"),
+        n_calibration=getattr(config, "n_calibration", metrics_cfg.get("n_calibration_samples", 5000) if isinstance(metrics_cfg, dict) else 5000),
+        n_clusters=getattr(config, "n_clusters", clustering_cfg.get("n_clusters", 4) if isinstance(clustering_cfg, dict) else 4),
+        synergy_target=getattr(config, "synergy_target", metrics_cfg.get("synergy_target", "logit_margin") if isinstance(metrics_cfg, dict) else "logit_margin"),
+        synergy_pairs=getattr(config, "synergy_pairs", metrics_cfg.get("synergy_num_pairs", 10) if isinstance(metrics_cfg, dict) else 10),
+        halo_percentile=getattr(config, "halo_percentile", halo_cfg.get("percentile", 90.0) if isinstance(halo_cfg, dict) else 90.0),
+        output_dir=getattr(config, "experiment_dir", "results/cluster_analysis"),
+        device=getattr(config, "device", "cuda"),
+        seed=getattr(config, 
"seed", 42), + ) + + # Load model + model_name = cluster_config.model_name.lower() + num_classes = 10 if "cifar" in cluster_config.dataset_name.lower() else 1000 + + if "resnet18" in model_name: + model = torchvision.models.resnet18(pretrained=True) + model.fc = torch.nn.Linear(model.fc.in_features, num_classes) + elif "resnet50" in model_name: + model = torchvision.models.resnet50(pretrained=True) + model.fc = torch.nn.Linear(model.fc.in_features, num_classes) + elif "vgg16" in model_name: + model = torchvision.models.vgg16_bn(pretrained=True) + model.classifier[-1] = torch.nn.Linear(model.classifier[-1].in_features, num_classes) + elif "mobilenet" in model_name: + model = torchvision.models.mobilenet_v2(pretrained=True) + model.classifier[-1] = torch.nn.Linear(model.classifier[-1].in_features, num_classes) + else: + raise ValueError(f"Unknown model: {model_name}") + + # Load dataset + dataset_name = cluster_config.dataset_name.lower() + if "cifar10" in dataset_name: + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) + test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform) + elif "cifar100" in dataset_name: + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), + ]) + train_dataset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform) + test_dataset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform) + else: + raise ValueError(f"Unknown dataset: {dataset_name}") + + batch_size = getattr(config, "batch_size", 128) + train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4) + test_loader = 
torch.utils.data.DataLoader(test_dataset, batch_size=batch_size * 2, shuffle=False, num_workers=4)
+
+    # Fine-tune the model on target dataset before experiments
+    # This is necessary because we replaced the classifier head with random weights
+    model = _finetune_model_for_dataset(
+        model, train_loader, test_loader,
+        device=cluster_config.device,
+        epochs=getattr(config, "pretrain_epochs", 20),
+        lr=getattr(config, "pretrain_lr", 0.001),
+    )
+
+    return ClusterAnalysisExperiment(cluster_config, model, train_loader, test_loader)
+
+
+def _finetune_model_for_dataset(
+    model: torch.nn.Module,
+    train_loader: torch.utils.data.DataLoader,
+    test_loader: torch.utils.data.DataLoader,
+    device: str = "cuda",
+    epochs: int = 20,
+    lr: float = 0.001,
+) -> torch.nn.Module:
+    """
+    Fine-tune a pretrained model on the target dataset.
+
+    This is necessary when using ImageNet pretrained models on CIFAR-10/100
+    because the classifier head is replaced with random weights.
+
+    Args:
+        model: Model with replaced classifier head
+        train_loader: Training data loader
+        test_loader: Test data loader
+        device: Device to train on
+        epochs: Number of fine-tuning epochs
+        lr: Learning rate
+
+    Returns:
+        Fine-tuned model
+    """
+    import torch.optim as optim
+
+    model = model.to(device)
+
+    # Check initial accuracy
+    model.eval()
+    correct, total = 0, 0
+    with torch.no_grad():
+        for x, y in test_loader:
+            x, y = x.to(device), y.to(device)
+            out = model(x)
+            correct += (out.argmax(1) == y).sum().item()
+            total += y.size(0)
+    initial_acc = correct / total
+
+    # If already trained (>50% accuracy), skip fine-tuning
+    if initial_acc > 0.5:
+        logger.info(f"Model already trained (accuracy: {initial_acc:.2%}), skipping fine-tuning")
+        return model
+
+    logger.info(f"Fine-tuning model on target dataset (initial accuracy: {initial_acc:.2%})...")
+
+    # Use different learning rates for pretrained vs new layers:
+    # nothing is frozen; pretrained backbone params just train at 0.1x the base LR
+    pretrained_params = []
+    new_params = []
+
+    for name, param in model.named_parameters():
+        if 'fc' in name or 'classifier' in name:
+            new_params.append(param)
+        else:
+            pretrained_params.append(param)
+
+    optimizer = optim.Adam([
+        {'params': pretrained_params, 'lr': lr * 0.1},  # Lower LR for pretrained
+        {'params': new_params, 'lr': lr},  # Higher LR for new classifier
+    ], weight_decay=1e-4)
+
+    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
+    criterion = torch.nn.CrossEntropyLoss()
+
+    best_acc = 0
+    best_state = None
+
+    for epoch in range(epochs):
+        # Train
+        model.train()
+        train_loss = 0
+        for x, y in train_loader:
+            x, y = x.to(device), y.to(device)
+            optimizer.zero_grad()
+            out = model(x)
+            loss = criterion(out, y)
+            loss.backward()
+            optimizer.step()
+            train_loss += loss.item()
+
+        scheduler.step()
+
+        # Evaluate
+        model.eval()
+        correct, total = 0, 0
+        with torch.no_grad():
+            for x, y in test_loader:
+                x, y = x.to(device), y.to(device)
+                out = model(x)
+                correct += (out.argmax(1) == y).sum().item()
+                total += y.size(0)
+
+        acc = correct / total
+        if acc > best_acc:
+            best_acc = acc
+            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
+
+        if (epoch + 1) % 5 == 0 or epoch == 0:
+            logger.info(f"  Epoch {epoch+1}/{epochs}: loss={train_loss/len(train_loader):.4f}, acc={acc:.2%}")
+
+    # Load best model
+    if best_state is not None:
+        model.load_state_dict({k: v.to(device) for k, v in best_state.items()})
+
+    logger.info(f"Fine-tuning complete. 
Best accuracy: {best_acc:.2%}") + + return model + + def run_post_analysis(config, results_file: Path, output_dir: Path): """Run post-experiment analysis using AnalysisRunner.""" post_analysis_config = getattr(config, "post_analysis", {}) @@ -214,6 +422,9 @@ def main(): experiment = LLMAlignmentExperiment(config) elif experiment_type in {"alignment_analysis", "vision_synergy", "general_alignment"}: experiment = GeneralAlignmentExperiment(config) + elif experiment_type in {"cluster_analysis", "vision_cluster_analysis", "metric_cluster_analysis"}: + # Cluster-based analysis experiment (works for any architecture) + experiment = _create_cluster_experiment(config) elif experiment_type == "layer_isolated_pruning": experiment = LayerIsolatedPruningExperiment(config) elif experiment_type == "cascading_layer_pruning": diff --git a/slurm-48692030.out b/slurm-48692030.out deleted file mode 100644 index 0fec100d..00000000 --- a/slurm-48692030.out +++ /dev/null @@ -1,4 +0,0 @@ -Running vision synergy experiment with config: configs/projects/vision_synergy.yaml -Working directory: /var/slurmd/spool/slurmd - -python: can't open file '/var/slurmd/spool/slurmd/scripts/run_experiment.py': [Errno 2] No such file or directory diff --git a/slurm_jobs/run_cluster_analysis_resnet18.sh b/slurm_jobs/run_cluster_analysis_resnet18.sh new file mode 100644 index 00000000..360719df --- /dev/null +++ b/slurm_jobs/run_cluster_analysis_resnet18.sh @@ -0,0 +1,69 @@ +#!/bin/bash +#SBATCH --job-name=cluster_analysis_resnet18 +#SBATCH --output=logs/cluster_analysis_resnet18_%j.out +#SBATCH --error=logs/cluster_analysis_resnet18_%j.err +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=8 +#SBATCH --time=4:00:00 +#SBATCH --mem=64GB +#SBATCH --partition=kempner_eng +#SBATCH --account=kempner_dev + +# ============================================================================ +# CLUSTER-BASED ANALYSIS: ResNet-18 on CIFAR-10 +# 
============================================================================ +# Full cluster-based analysis including: +# - Per-channel metrics (RQ, Redundancy, Synergy with continuous target) +# - K-means clustering into functional types +# - Cross-layer halo analysis +# - Cascade damage testing +# - Visualization generation +# +# Expected runtime: ~1-2 hours on single GPU +# ============================================================================ + +echo "============================================================================" +echo "Cluster-Based Analysis: ResNet-18 on CIFAR-10" +echo "============================================================================" +echo "Job ID: $SLURM_JOB_ID" +echo "Node: $(hostname)" +echo "Start time: $(date)" +echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader 2>/dev/null || echo 'N/A')" +echo "" + +# Environment setup +module purge +module load cuda/12.2.0-fasrc01 +eval "$(conda shell.bash hook)" +conda activate networkAlignmentAnalysis + +cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment + +# Create directories +mkdir -p logs +mkdir -p results/cluster_analysis/resnet18_cifar10 + +export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK + +echo "" +echo "Running ResNet-18 cluster analysis..." +echo "" + +python scripts/run_experiment.py \ + --config configs/cluster_analysis/resnet18_cifar10_full.yaml \ + --device cuda + +EXIT_CODE=$? 
+ +echo "" +echo "============================================================================" +echo "ResNet-18 cluster analysis completed at $(date)" +echo "Exit code: $EXIT_CODE" +echo "============================================================================" +echo "" +echo "Results saved to: results/cluster_analysis/resnet18_cifar10/" + +exit $EXIT_CODE + diff --git a/slurm_jobs/run_cluster_analysis_resnet50.sh b/slurm_jobs/run_cluster_analysis_resnet50.sh new file mode 100644 index 00000000..f34b623a --- /dev/null +++ b/slurm_jobs/run_cluster_analysis_resnet50.sh @@ -0,0 +1,76 @@ +#!/bin/bash +#SBATCH --job-name=cluster_analysis_resnet50 +#SBATCH --output=logs/cluster_analysis_resnet50_%j.out +#SBATCH --error=logs/cluster_analysis_resnet50_%j.err +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=8 +#SBATCH --time=12:00:00 +#SBATCH --mem=128GB +#SBATCH --partition=kempner_eng +#SBATCH --account=kempner_dev + +# ============================================================================ +# CLUSTER-BASED ANALYSIS: ResNet-50 on ImageNet-100 +# ============================================================================ +# Full cluster-based analysis including: +# - Per-channel metrics (RQ, Redundancy, Synergy with continuous target) +# - K-means clustering into functional types +# - Cross-layer halo analysis with activation weighting +# - Cascade damage testing +# - Pruning experiments with fine-tuning +# - Visualization generation +# +# Expected runtime: ~6-10 hours on single GPU (A100) +# ============================================================================ + +echo "============================================================================" +echo "Cluster-Based Analysis: ResNet-50 on ImageNet-100" +echo "============================================================================" +echo "Job ID: $SLURM_JOB_ID" +echo "Node: $(hostname)" +echo "Start time: $(date)" +echo "GPU: $(nvidia-smi --query-gpu=name 
--format=csv,noheader 2>/dev/null || echo 'N/A')" +echo "" + +# Environment setup +module purge +module load cuda/12.2.0-fasrc01 +eval "$(conda shell.bash hook)" +conda activate networkAlignmentAnalysis + +cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment + +# Create directories +mkdir -p logs +mkdir -p results/vision/resnet50_imagenet100 + +export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK + +# Check for ImageNet-100 data +if [ ! -d "data/imagenet100" ]; then + echo "WARNING: ImageNet-100 data not found at data/imagenet100" + echo "Please download or symlink the ImageNet-100 subset before running." + echo "" +fi + +echo "" +echo "Running ResNet-50 cluster analysis..." +echo "" + +python scripts/run_experiment.py \ + --config configs/cluster_analysis/resnet50_imagenet100.yaml \ + --device cuda + +EXIT_CODE=$? + +echo "" +echo "============================================================================" +echo "ResNet-50 cluster analysis completed at $(date)" +echo "Exit code: $EXIT_CODE" +echo "============================================================================" +echo "" +echo "Results saved to: results/vision/resnet50_imagenet100/" + +exit $EXIT_CODE diff --git a/src/alignment/analysis/README.md b/src/alignment/analysis/README.md index 81ef5314..47f48083 100644 --- a/src/alignment/analysis/README.md +++ b/src/alignment/analysis/README.md @@ -1,16 +1,66 @@ # Analysis Module -Result analysis, visualization, and reporting. +Result analysis, visualization, clustering, and reporting. 
## Components +### Core Analysis - `AnalysisRunner` - Unified entry point for analysis tasks -- `UnifiedVisualizer` - Plot generation -- `UnifiedReporter` - Report generation (HTML, Markdown, JSON) - `ResultAggregator` - Result collection and summarization +### Clustering +- `MetricSpaceClustering` - K-means in (RQ, Redundancy, Synergy) space +- `CrossLayerHaloAnalysis` - Track downstream channel dependencies +- `CascadeAnalysis` - Validate importance via ablation +- `DamagePrediction` - Correlate scores with true damage + +### Visualization +- `UnifiedVisualizer` - General plot generation +- `plot_metric_scatter` - Cluster scatter plots +- `plot_cluster_evolution` - Composition across depth +- `plot_influence_matrix` - Cross-layer influence heatmaps +- `plot_cascade_test` - Damage by cluster type + +### Reporting +- `UnifiedReporter` - Report generation (HTML, Markdown, JSON) + ## Usage +### Cluster Analysis + +```python +from alignment.analysis.clustering import MetricSpaceClustering, CrossLayerHaloAnalysis + +# Cluster channels +clusterer = MetricSpaceClustering(n_clusters=4) +result = clusterer.fit(rq, redundancy, synergy, "layer1") + +# Analyze halos +halo_analyzer = CrossLayerHaloAnalysis(percentile=90.0) +halo_idx, influence = halo_analyzer.find_halo(weights, cluster_indices) +``` + +### Cascade Testing + +```python +from alignment.analysis import CascadeAnalysis + +cascade = CascadeAnalysis(model, test_loader, device="cuda") +cascade.baseline() +results = cascade.by_cluster(layer_name, labels, type_mapping, n_rm=5) +``` + +### Visualization + +```python +from alignment.analysis.visualization import plot_metric_scatter, plot_cluster_evolution + +plot_metric_scatter(rq, red, syn, labels, type_map, "layer1", "scatter.png") +plot_cluster_evolution(layer_results, "evolution.png") +``` + +### General Analysis + ```python from alignment.analysis import AnalysisRunner, AnalysisConfig @@ -23,10 +73,6 @@ runner = AnalysisRunner(config) runner.run() ``` -```bash 
-python scripts/run_analysis.py --results-dir ./results --output-dir ./plots --quick -``` - ## Available Analyses - `histograms` - Importance score distributions @@ -34,3 +80,6 @@ python scripts/run_analysis.py --results-dir ./results --output-dir ./plots --qu - `heatmaps` - Layer-metric heatmaps - `pruning_curves` - Sparsity vs performance - `scar_analysis` - SCAR metrics (LLM) +- `cluster_scatter` - Metric space cluster plots +- `cluster_evolution` - Cluster composition by depth +- `cascade_test` - Ablation damage analysis diff --git a/src/alignment/analysis/__init__.py b/src/alignment/analysis/__init__.py index a667545d..d361ea56 100644 --- a/src/alignment/analysis/__init__.py +++ b/src/alignment/analysis/__init__.py @@ -22,6 +22,12 @@ # Unified Analysis Runner from .analysis_runner import AnalysisRunner, AnalysisConfig, run_analysis_from_config +# Clustering Analysis +from .clustering import MetricSpaceClustering, ClusterResult, CrossLayerHaloAnalysis, HaloResult + +# Cascade Analysis +from .cascade_analysis import CascadeAnalysis, DamagePrediction, CascadeResult, DamageResult + __all__ = [ # Aggregation "ResultAggregator", @@ -37,4 +43,14 @@ "AnalysisRunner", "AnalysisConfig", "run_analysis_from_config", + # Clustering + "MetricSpaceClustering", + "ClusterResult", + "CrossLayerHaloAnalysis", + "HaloResult", + # Cascade Analysis + "CascadeAnalysis", + "DamagePrediction", + "CascadeResult", + "DamageResult", ] diff --git a/src/alignment/analysis/cascade_analysis.py b/src/alignment/analysis/cascade_analysis.py new file mode 100644 index 00000000..06d6a53b --- /dev/null +++ b/src/alignment/analysis/cascade_analysis.py @@ -0,0 +1,154 @@ +""" +Cascade and damage analysis for pruning validation. + +Implements: +1. Cascade test: measure downstream disruption when removing channels +2. Damage prediction: correlate scores with true accuracy drop +3. 
Cluster-specific ablation: compare damage by functional type +""" + +import logging +from dataclasses import dataclass +from typing import Dict, List, Optional, Any +import numpy as np + +logger = logging.getLogger(__name__) + +try: + import torch + import torch.nn as nn + HAS_TORCH = True +except ImportError: + HAS_TORCH = False + + +@dataclass +class CascadeResult: + """Result of cascade test for a cluster type.""" + layer_name: str + cluster_type: str + n_removed: int + accuracy_drop: float + loss_increase: float + + +@dataclass +class DamageResult: + """Result of damage prediction analysis.""" + layer_name: str + method: str + spearman: float + top_k_recall: Dict[int, float] + + +class CascadeAnalysis: + """Analyze cascade effects of channel removal.""" + + def __init__(self, model, dataloader, device="cuda"): + self.model = model + self.loader = dataloader + self.device = device + self._baseline = None + + def baseline(self): + """Compute baseline accuracy/loss.""" + if not HAS_TORCH: + return {"acc": 0., "loss": 0.} + self.model.eval() + correct, total, loss_sum = 0, 0, 0. + crit = nn.CrossEntropyLoss() + with torch.no_grad(): + for x, y in self.loader: + x, y = x.to(self.device), y.to(self.device) + out = self.model(x) + loss_sum += crit(out, y).item() * x.size(0) + correct += (out.argmax(1) == y).sum().item() + total += y.size(0) + self._baseline = {"acc": correct/total, "loss": loss_sum/total} + return self._baseline + + def ablate(self, layer_name: str, indices: List[int]) -> CascadeResult: + """Remove channels and measure effect.""" + if self._baseline is None: + self.baseline() + layer = dict(self.model.named_modules()).get(layer_name) + if layer is None or not hasattr(layer, 'weight'): + return CascadeResult(layer_name, "", len(indices), 0., 0.) 
+ orig_w = layer.weight.data.clone() + orig_b = layer.bias.data.clone() if layer.bias is not None else None + layer.weight.data[indices] = 0 + if orig_b is not None: + layer.bias.data[indices] = 0 + new = self._eval() + layer.weight.data = orig_w + if orig_b is not None: + layer.bias.data = orig_b + return CascadeResult(layer_name, "", len(indices), + self._baseline["acc"] - new["acc"], + new["loss"] - self._baseline["loss"]) + + def by_cluster(self, layer: str, labels: np.ndarray, + types: Dict[int, str], n_rm: int = 5) -> Dict[str, CascadeResult]: + """Run cascade test per cluster type.""" + results = {} + for cid, ctype in types.items(): + idx = np.where(labels == cid)[0] + if len(idx) == 0: + continue + rm = np.random.choice(idx, min(n_rm, len(idx)), replace=False).tolist() + r = self.ablate(layer, rm) + r.cluster_type = ctype + results[ctype] = r + return results + + def _eval(self): + self.model.eval() + correct, total, loss_sum = 0, 0, 0. + crit = nn.CrossEntropyLoss() + with torch.no_grad(): + for x, y in self.loader: + x, y = x.to(self.device), y.to(self.device) + out = self.model(x) + loss_sum += crit(out, y).item() * x.size(0) + correct += (out.argmax(1) == y).sum().item() + total += y.size(0) + return {"acc": correct/total, "loss": loss_sum/total} + + +class DamagePrediction: + """Predict damage from importance scores.""" + + def __init__(self, cascade: CascadeAnalysis, layer: str): + self.cascade = cascade + self.layer = layer + self._damages = None + + def compute_damages(self, n_ch: int, frac: float = 0.2) -> np.ndarray: + """Compute true per-channel damage.""" + damages = np.zeros(n_ch) + test_idx = np.random.choice(n_ch, max(1, int(n_ch * frac)), replace=False) + for i in test_idx: + r = self.cascade.ablate(self.layer, [int(i)]) + damages[i] = r.accuracy_drop + self._damages = damages + return damages + + def evaluate(self, scores: np.ndarray, method: str = "composite", + top_ks: List[int] = [10, 20, 50]) -> DamageResult: + """Evaluate score vs 
damage correlation.""" + from scipy import stats + if self._damages is None: + raise ValueError("Call compute_damages first") + mask = self._damages != 0 + if mask.sum() < 5: + return DamageResult(self.layer, method, 0., {}) + d, s = self._damages[mask], scores[mask] + rho, _ = stats.spearmanr(s, -d) + recall = {} + by_d = np.argsort(-d) + by_s = np.argsort(s) + for k in top_ks: + k = min(k, len(d)) + overlap = len(set(by_d[:k]) & set(by_s[:k])) + recall[k] = overlap / k if k > 0 else 0. + return DamageResult(self.layer, method, float(rho) if not np.isnan(rho) else 0., recall) diff --git a/src/alignment/analysis/clustering/__init__.py b/src/alignment/analysis/clustering/__init__.py new file mode 100644 index 00000000..0b996c3b --- /dev/null +++ b/src/alignment/analysis/clustering/__init__.py @@ -0,0 +1,16 @@ +""" +Clustering analysis module for neural network channels. + +Provides clustering in (RQ, Redundancy, Synergy) space to identify +functional types: Critical, Redundant, Synergistic, Background. 
+""" + +from .metric_clustering import MetricSpaceClustering, ClusterResult +from .cross_layer_halo import CrossLayerHaloAnalysis, HaloResult + +__all__ = [ + "MetricSpaceClustering", + "ClusterResult", + "CrossLayerHaloAnalysis", + "HaloResult", +] diff --git a/src/alignment/analysis/clustering/cross_layer_halo.py b/src/alignment/analysis/clustering/cross_layer_halo.py new file mode 100644 index 00000000..9e5c1337 --- /dev/null +++ b/src/alignment/analysis/clustering/cross_layer_halo.py @@ -0,0 +1,156 @@ +"""Cross-layer halo analysis.""" +import numpy as np +from dataclasses import dataclass +from typing import Dict, List, Optional + + +@dataclass +class HaloResult: + """Result of halo analysis for a cluster.""" + layer_name: str + source_cluster: str + halo_indices: np.ndarray + halo_size: int + halo_redundancy_mean: float + halo_synergy_mean: float + influence_scores: Optional[np.ndarray] = None + + +class CrossLayerHaloAnalysis: + """ + Analyze downstream halos of clusters. + + A halo is the set of channels in the next layer that receive + disproportionate input from a given cluster. + """ + + def __init__(self, percentile: float = 90.0, use_activation_weight: bool = True): + """ + Args: + percentile: Threshold percentile for halo membership + use_activation_weight: Whether to weight by activation std + """ + self.percentile = percentile + self.use_activation_weight = use_activation_weight + + def compute_influence(self, weights: np.ndarray, activations: Optional[np.ndarray] = None) -> np.ndarray: + """ + Compute influence scores. 
+ + Args: + weights: Weight matrix [out_channels, in_channels] + activations: Optional activations [batch, in_channels] + + Returns: + Influence matrix [out_channels, in_channels] + """ + w = np.abs(weights) + if activations is not None and self.use_activation_weight: + std = np.std(activations, axis=0) + w = w * std[None, :] + return w + + def find_halo( + self, + influence: np.ndarray, + cluster_indices: np.ndarray, + ) -> tuple: + """ + Find receivers that get high relative influence from cluster. + + Args: + influence: Influence matrix [out, in] + cluster_indices: Indices of channels in source cluster + + Returns: + (halo_indices, relative_influence_scores) + """ + # Sum influence from cluster members + infl_from_cluster = influence[:, cluster_indices].sum(axis=1) + # Normalize by total incoming + total_infl = influence.sum(axis=1) + 1e-10 + rel_infl = infl_from_cluster / total_infl + # Threshold + thresh = np.percentile(rel_infl, self.percentile) + halo_mask = rel_infl >= thresh + return np.where(halo_mask)[0], rel_infl + + def analyze_halo( + self, + halo_indices: np.ndarray, + redundancy: np.ndarray, + synergy: np.ndarray, + layer_name: str = "", + cluster_name: str = "", + ) -> HaloResult: + """ + Compute properties of a halo. 
+ + Args: + halo_indices: Indices of halo channels + redundancy: Per-channel redundancy in next layer + synergy: Per-channel synergy in next layer + layer_name: Layer identifier + cluster_name: Source cluster type + + Returns: + HaloResult with summary statistics + """ + if len(halo_indices) == 0: + return HaloResult( + layer_name=layer_name, + source_cluster=cluster_name, + halo_indices=halo_indices, + halo_size=0, + halo_redundancy_mean=0.0, + halo_synergy_mean=0.0, + ) + + red_mean = float(np.mean(redundancy[halo_indices])) + syn_mean = float(np.mean(synergy[halo_indices])) + + return HaloResult( + layer_name=layer_name, + source_cluster=cluster_name, + halo_indices=halo_indices, + halo_size=len(halo_indices), + halo_redundancy_mean=red_mean, + halo_synergy_mean=syn_mean, + ) + + def compute_cluster_to_cluster_flow( + self, + influence: np.ndarray, + source_labels: np.ndarray, + target_labels: np.ndarray, + source_types: Dict[int, str], + target_types: Dict[int, str], + ) -> Dict[str, Dict[str, float]]: + """ + Compute cluster-to-cluster influence matrix. 
+ + Args: + influence: [out, in] influence matrix + source_labels: Cluster labels for source layer + target_labels: Cluster labels for target layer + source_types: Mapping from cluster ID to type name + target_types: Mapping from cluster ID to type name + + Returns: + Nested dict: flow[source_type][target_type] = mean influence + """ + flow = {} + for src_id, src_type in source_types.items(): + flow[src_type] = {} + src_mask = source_labels == src_id + src_infl = influence[:, src_mask].sum(axis=1) # [out] + + for tgt_id, tgt_type in target_types.items(): + tgt_mask = target_labels == tgt_id + if tgt_mask.sum() > 0: + mean_infl = float(np.mean(src_infl[tgt_mask])) + flow[src_type][tgt_type] = mean_infl + else: + flow[src_type][tgt_type] = 0.0 + + return flow diff --git a/src/alignment/analysis/clustering/metric_clustering.py b/src/alignment/analysis/clustering/metric_clustering.py new file mode 100644 index 00000000..06ce39db --- /dev/null +++ b/src/alignment/analysis/clustering/metric_clustering.py @@ -0,0 +1,64 @@ +"""Metric-space clustering for channels.""" +import numpy as np +from dataclasses import dataclass +from typing import Dict, List, Any + +try: + from sklearn.cluster import KMeans + from sklearn.metrics import silhouette_score + HAS_SK = True +except ImportError: + HAS_SK = False + + +@dataclass +class ClusterResult: + layer_name: str + n_channels: int + n_clusters: int + labels: np.ndarray + centroids: np.ndarray + silhouette: float + type_mapping: Dict[int, str] + type_counts: Dict[str, int] + + +class MetricSpaceClustering: + def __init__(self, n_clusters=4, seed=42): + self.n_clusters = n_clusters + self.seed = seed + + def fit(self, rq, red, syn, name="layer"): + rq = np.asarray(rq).flatten() + red = np.asarray(red).flatten() + syn = np.asarray(syn).flatten() + n = len(rq) + X = np.column_stack([np.log(np.clip(rq, 1e-10, None)), red, syn]) + X = (X - X.mean(0)) / (X.std(0) + 1e-8) + if HAS_SK and n >= self.n_clusters: + km = KMeans(self.n_clusters, 
random_state=self.seed, n_init=10) + lab = km.fit_predict(X) + cen = km.cluster_centers_ + sil = silhouette_score(X, lab) if n > self.n_clusters else 0. + else: + lab, cen, sil = np.zeros(n, int), np.zeros((1, 3)), 0. + tm = self._types(cen) + tc = {t: int((lab == k).sum()) for k, t in tm.items()} + return ClusterResult(name, n, len(cen), lab, cen, sil, tm, tc) + + def _types(self, c): + if len(c) < 4: + return {i: "unknown" for i in range(len(c))} + m, used = {}, set() + i = int(np.argmax(c[:, 0] - c[:, 1])) + m[i] = "critical"; used.add(i) + rem = [j for j in range(len(c)) if j not in used] + i = rem[int(np.argmax([c[j, 1] for j in rem]))] + m[i] = "redundant"; used.add(i) + rem = [j for j in range(len(c)) if j not in used] + i = rem[int(np.argmax([c[j, 2] for j in rem]))] + m[i] = "synergistic"; used.add(i) + for j in range(len(c)): + if j not in m: + m[j] = "background" + return m diff --git a/src/alignment/analysis/visualization/__init__.py b/src/alignment/analysis/visualization/__init__.py index 8c245c1e..57066681 100644 --- a/src/alignment/analysis/visualization/__init__.py +++ b/src/alignment/analysis/visualization/__init__.py @@ -46,6 +46,15 @@ plot_halo_redundancy_heatmap, ) +# Cluster visualization plots +from .cluster_plots import ( + plot_metric_scatter, + plot_cluster_evolution, + plot_influence_matrix, + plot_cascade_test, + plot_halo_properties, +) + __all__ = [ # Primary "UnifiedVisualizer", @@ -58,4 +67,10 @@ "plot_halo_redundancy_by_depth", "plot_halo_redundancy_comprehensive", "plot_halo_redundancy_heatmap", + # Cluster plots + "plot_metric_scatter", + "plot_cluster_evolution", + "plot_influence_matrix", + "plot_cascade_test", + "plot_halo_properties", ] diff --git a/src/alignment/analysis/visualization/cluster_plots.py b/src/alignment/analysis/visualization/cluster_plots.py new file mode 100644 index 00000000..86e5fa42 --- /dev/null +++ b/src/alignment/analysis/visualization/cluster_plots.py @@ -0,0 +1,553 @@ +""" +Cluster visualization module 
for vision network analysis. + +Provides visualizations for: +1. Metric space scatter plots (RQ vs Red, RQ vs Syn, Red vs Syn) +2. Cluster composition across depth (stacked bars) +3. Cross-layer influence matrices (heatmaps) +4. Cluster stability analysis +5. Cascade test results by cluster type +""" + +import logging +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union +import numpy as np + +logger = logging.getLogger(__name__) + +try: + import matplotlib.pyplot as plt + import matplotlib.patches as mpatches + HAS_MPL = True +except ImportError: + HAS_MPL = False + + +CLUSTER_COLORS = { + "critical": "#e74c3c", + "redundant": "#3498db", + "synergistic": "#2ecc71", + "background": "#95a5a6", + "unknown": "#bdc3c7", +} + + +def plot_metric_scatter( + rq: np.ndarray, + redundancy: np.ndarray, + synergy: np.ndarray, + labels: np.ndarray, + type_mapping: Dict[int, str], + layer_name: str = "", + save_path: Optional[Path] = None, + figsize: Tuple[int, int] = (15, 5), +) -> Optional["plt.Figure"]: + """ + Plot 2D projections of metric space with cluster colors. 
+ + Creates 3 subplots: RQ vs Red, RQ vs Syn, Red vs Syn + """ + if not HAS_MPL: + return None + + fig, axes = plt.subplots(1, 3, figsize=figsize) + log_rq = np.log(np.clip(rq, 1e-10, None)) + + pairs = [ + (log_rq, redundancy, "log(RQ)", "Redundancy"), + (log_rq, synergy, "log(RQ)", "Synergy"), + (redundancy, synergy, "Redundancy", "Synergy"), + ] + + for ax, (x, y, xl, yl) in zip(axes, pairs): + for cid, ctype in type_mapping.items(): + mask = labels == cid + color = CLUSTER_COLORS.get(ctype, "#999999") + ax.scatter(x[mask], y[mask], c=color, label=ctype, alpha=0.6, s=20) + ax.set_xlabel(xl) + ax.set_ylabel(yl) + ax.legend() + ax.grid(True, alpha=0.3) + + fig.suptitle(f"Metric Space Clusters: {layer_name}", fontsize=14) + plt.tight_layout() + + if save_path: + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + plt.savefig(save_path, dpi=150, bbox_inches='tight') + logger.info(f"Saved cluster scatter to {save_path}") + + return fig + + +def plot_cluster_evolution( + layer_results: List[Dict[str, Any]], + save_path: Optional[Path] = None, + figsize: Tuple[int, int] = (12, 6), +) -> Optional["plt.Figure"]: + """ + Plot cluster composition across network depth as stacked bars. 
+ + Args: + layer_results: List of dicts with 'layer_name' and 'type_counts' + """ + if not HAS_MPL: + return None + + layers = [r["layer_name"] for r in layer_results] + types = ["critical", "redundant", "synergistic", "background"] + + # Build data matrix + data = {t: [] for t in types} + for r in layer_results: + tc = r.get("type_counts", {}) + total = sum(tc.values()) or 1 + for t in types: + data[t].append(tc.get(t, 0) / total * 100) + + fig, ax = plt.subplots(figsize=figsize) + x = np.arange(len(layers)) + bottom = np.zeros(len(layers)) + + for t in types: + color = CLUSTER_COLORS.get(t, "#999999") + ax.bar(x, data[t], bottom=bottom, label=t, color=color, alpha=0.8) + bottom += np.array(data[t]) + + ax.set_xlabel("Layer") + ax.set_ylabel("Percentage of Channels") + ax.set_title("Cluster Composition Across Depth") + ax.set_xticks(x) + ax.set_xticklabels(layers, rotation=45, ha='right') + ax.legend(loc='upper right') + ax.set_ylim(0, 100) + + plt.tight_layout() + + if save_path: + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + plt.savefig(save_path, dpi=150, bbox_inches='tight') + logger.info(f"Saved cluster evolution to {save_path}") + + return fig + + +def plot_influence_matrix( + flow: Dict[str, Dict[str, float]], + layer_name: str = "", + save_path: Optional[Path] = None, + figsize: Tuple[int, int] = (8, 6), +) -> Optional["plt.Figure"]: + """ + Plot cluster-to-cluster influence matrix as heatmap. 
+ + Args: + flow: Nested dict flow[source_type][target_type] = value + """ + if not HAS_MPL: + return None + + types = ["critical", "redundant", "synergistic", "background"] + matrix = np.zeros((len(types), len(types))) + + for i, src in enumerate(types): + for j, tgt in enumerate(types): + matrix[i, j] = flow.get(src, {}).get(tgt, 0) + + # Normalize rows + row_sums = matrix.sum(axis=1, keepdims=True) + row_sums[row_sums == 0] = 1 + matrix_norm = matrix / row_sums + + fig, ax = plt.subplots(figsize=figsize) + im = ax.imshow(matrix_norm, cmap='YlOrRd', aspect='auto') + + ax.set_xticks(np.arange(len(types))) + ax.set_yticks(np.arange(len(types))) + ax.set_xticklabels([t.capitalize() for t in types]) + ax.set_yticklabels([t.capitalize() for t in types]) + ax.set_xlabel("Target Cluster (Layer ℓ+1)") + ax.set_ylabel("Source Cluster (Layer ℓ)") + ax.set_title(f"Cross-Cluster Influence: {layer_name}") + + # Add annotations + for i in range(len(types)): + for j in range(len(types)): + ax.text(j, i, f"{matrix_norm[i, j]:.2f}", + ha="center", va="center", color="black", fontsize=10) + + plt.colorbar(im, ax=ax, label="Normalized Influence") + plt.tight_layout() + + if save_path: + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + plt.savefig(save_path, dpi=150, bbox_inches='tight') + logger.info(f"Saved influence matrix to {save_path}") + + return fig + + +def plot_cascade_test( + results: Dict[str, Any], + save_path: Optional[Path] = None, + figsize: Tuple[int, int] = (10, 6), +) -> Optional["plt.Figure"]: + """ + Plot cascade test results by cluster type. 
+ + Args: + results: Dict mapping cluster_type to CascadeResult + """ + if not HAS_MPL: + return None + + types = list(results.keys()) + acc_drops = [results[t].accuracy_drop * 100 for t in types] + colors = [CLUSTER_COLORS.get(t, "#999999") for t in types] + + fig, ax = plt.subplots(figsize=figsize) + x = np.arange(len(types)) + bars = ax.bar(x, acc_drops, color=colors, alpha=0.8) + + ax.set_xlabel("Cluster Type") + ax.set_ylabel("Accuracy Drop (%)") + ax.set_title("Cascade Damage by Cluster Type") + ax.set_xticks(x) + ax.set_xticklabels([t.capitalize() for t in types]) + + # Add value labels + for bar, val in zip(bars, acc_drops): + ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.1, + f'{val:.2f}%', ha='center', va='bottom', fontsize=10) + + ax.axhline(y=np.mean(acc_drops), color='gray', linestyle='--', + label=f'Mean: {np.mean(acc_drops):.2f}%') + ax.legend() + + plt.tight_layout() + + if save_path: + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + plt.savefig(save_path, dpi=150, bbox_inches='tight') + logger.info(f"Saved cascade test to {save_path}") + + return fig + + +def plot_halo_properties( + halo_results: List[Dict[str, Any]], + save_path: Optional[Path] = None, + figsize: Tuple[int, int] = (12, 5), +) -> Optional["plt.Figure"]: + """ + Plot halo redundancy and synergy by source cluster type. 
+ + Args: + halo_results: List of dicts with cluster_type, halo_red, halo_syn + """ + if not HAS_MPL: + return None + + types = [r["cluster_type"] for r in halo_results] + reds = [r.get("halo_red", 0) for r in halo_results] + syns = [r.get("halo_syn", 0) for r in halo_results] + colors = [CLUSTER_COLORS.get(t, "#999999") for t in types] + + fig, axes = plt.subplots(1, 2, figsize=figsize) + x = np.arange(len(types)) + + axes[0].bar(x, reds, color=colors, alpha=0.8) + axes[0].set_ylabel("Halo Redundancy") + axes[0].set_title("Halo Redundancy by Source Cluster") + axes[0].set_xticks(x) + axes[0].set_xticklabels([t.capitalize() for t in types]) + axes[0].axhline(y=np.mean(reds), color='gray', linestyle='--') + + axes[1].bar(x, syns, color=colors, alpha=0.8) + axes[1].set_ylabel("Halo Synergy") + axes[1].set_title("Halo Synergy by Source Cluster") + axes[1].set_xticks(x) + axes[1].set_xticklabels([t.capitalize() for t in types]) + axes[1].axhline(y=np.mean(syns), color='gray', linestyle='--') + + plt.tight_layout() + + if save_path: + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + plt.savefig(save_path, dpi=150, bbox_inches='tight') + logger.info(f"Saved halo properties to {save_path}") + + return fig + + +def plot_centroid_evolution( + layer_centroids: List[Dict[str, Any]], + save_path: Optional[Path] = None, + figsize: Tuple[int, int] = (15, 5), +) -> Optional["plt.Figure"]: + """ + Plot how cluster centroids evolve across network depth. 
+ + Creates 2D trajectory plots showing centroid movement in: + - log(RQ) vs Redundancy + - log(RQ) vs Synergy + - Redundancy vs Synergy + + Args: + layer_centroids: List of dicts with 'layer_name', 'depth', 'centroids', 'type_mapping' + where centroids is [n_clusters, 3] array (log_rq, red, syn) + save_path: Optional path to save figure + """ + if not HAS_MPL: + return None + + if not layer_centroids: + return None + + fig, axes = plt.subplots(1, 3, figsize=figsize) + + types = ["critical", "redundant", "synergistic", "background"] + + # Collect centroid trajectories by type + trajectories = {t: {"log_rq": [], "red": [], "syn": [], "depth": []} for t in types} + + for layer_data in layer_centroids: + centroids = np.array(layer_data["centroids"]) # [n_clusters, 3] + type_mapping = layer_data.get("type_mapping", {}) + depth = layer_data.get("depth", 0) + + for cluster_id, cluster_type in type_mapping.items(): + if cluster_type in trajectories and int(cluster_id) < len(centroids): + c = centroids[int(cluster_id)] + trajectories[cluster_type]["log_rq"].append(c[0]) + trajectories[cluster_type]["red"].append(c[1]) + trajectories[cluster_type]["syn"].append(c[2]) + trajectories[cluster_type]["depth"].append(depth) + + # Plot pairs + pairs = [ + ("log_rq", "red", "log(RQ)", "Redundancy"), + ("log_rq", "syn", "log(RQ)", "Synergy"), + ("red", "syn", "Redundancy", "Synergy"), + ] + + for ax, (x_key, y_key, x_label, y_label) in zip(axes, pairs): + for ctype in types: + traj = trajectories[ctype] + if not traj["depth"]: + continue + + # Sort by depth + sorted_idx = np.argsort(traj["depth"]) + x_vals = np.array(traj[x_key])[sorted_idx] + y_vals = np.array(traj[y_key])[sorted_idx] + depths = np.array(traj["depth"])[sorted_idx] + + color = CLUSTER_COLORS.get(ctype, "#999999") + + # Plot trajectory with arrows + ax.plot(x_vals, y_vals, '-', color=color, alpha=0.7, linewidth=2, label=ctype) + + # Add markers with depth coloring + scatter = ax.scatter(x_vals, y_vals, c=depths, 
cmap='viridis', + s=80, edgecolors=color, linewidths=2, zorder=5) + + # Add start/end markers + if len(x_vals) > 0: + ax.scatter(x_vals[0], y_vals[0], marker='o', s=150, + facecolors='white', edgecolors=color, linewidths=3, zorder=6) + ax.scatter(x_vals[-1], y_vals[-1], marker='s', s=150, + facecolors=color, edgecolors='black', linewidths=2, zorder=6) + + ax.set_xlabel(x_label, fontsize=11) + ax.set_ylabel(y_label, fontsize=11) + ax.grid(True, alpha=0.3) + ax.legend(loc='best', fontsize=9) + + # Add colorbar for depth + cbar = fig.colorbar(scatter, ax=axes, orientation='horizontal', + fraction=0.05, pad=0.12, aspect=40) + cbar.set_label('Layer Depth', fontsize=11) + + fig.suptitle('Cluster Centroid Evolution Across Network Depth\n(○ = early layers, ■ = late layers)', + fontsize=13, fontweight='bold') + plt.tight_layout(rect=[0, 0.08, 1, 0.95]) + + if save_path: + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + plt.savefig(save_path, dpi=150, bbox_inches='tight') + logger.info(f"Saved centroid evolution to {save_path}") + + return fig + + +def plot_centroid_depth_profiles( + layer_centroids: List[Dict[str, Any]], + save_path: Optional[Path] = None, + figsize: Tuple[int, int] = (12, 8), +) -> Optional["plt.Figure"]: + """ + Plot each metric's centroid value vs depth for each cluster type. + + Shows how log(RQ), Redundancy, and Synergy change with depth + for each functional type. 
+ """ + if not HAS_MPL: + return None + + if not layer_centroids: + return None + + fig, axes = plt.subplots(3, 1, figsize=figsize, sharex=True) + + types = ["critical", "redundant", "synergistic", "background"] + metrics = [("log_rq", "log(RQ) - Alignment"), + ("red", "Redundancy"), + ("syn", "Synergy")] + + # Collect data + data = {t: {"depth": [], "log_rq": [], "red": [], "syn": []} for t in types} + + for layer_data in layer_centroids: + centroids = np.array(layer_data["centroids"]) + type_mapping = layer_data.get("type_mapping", {}) + depth = layer_data.get("depth", 0) + + for cluster_id, cluster_type in type_mapping.items(): + if cluster_type in data and int(cluster_id) < len(centroids): + c = centroids[int(cluster_id)] + data[cluster_type]["depth"].append(depth) + data[cluster_type]["log_rq"].append(c[0]) + data[cluster_type]["red"].append(c[1]) + data[cluster_type]["syn"].append(c[2]) + + for ax, (metric_key, metric_label) in zip(axes, metrics): + for ctype in types: + d = data[ctype] + if not d["depth"]: + continue + + sorted_idx = np.argsort(d["depth"]) + depths = np.array(d["depth"])[sorted_idx] + values = np.array(d[metric_key])[sorted_idx] + + color = CLUSTER_COLORS.get(ctype, "#999999") + ax.plot(depths, values, 'o-', color=color, label=ctype, + linewidth=2, markersize=6) + + ax.set_ylabel(metric_label, fontsize=11) + ax.grid(True, alpha=0.3) + ax.legend(loc='best', fontsize=9, ncol=2) + + axes[-1].set_xlabel('Layer Depth', fontsize=11) + fig.suptitle('Cluster Centroid Metrics vs Network Depth', fontsize=13, fontweight='bold') + plt.tight_layout() + + if save_path: + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + plt.savefig(save_path, dpi=150, bbox_inches='tight') + logger.info(f"Saved centroid depth profiles to {save_path}") + + return fig + + +def plot_pruning_comparison( + results: Dict[str, Dict[float, Dict[str, float]]], + baseline_acc: float, + save_path: Optional[Path] = None, + figsize: Tuple[int, int] = 
(12, 6), +) -> Optional["plt.Figure"]: + """ + Plot pruning accuracy comparison across methods and sparsity levels. + + Args: + results: Dict mapping method -> {ratio -> {'accuracy_after_ft': float}} + baseline_acc: Baseline (unpruned) accuracy + save_path: Optional path to save figure + """ + if not HAS_MPL: + return None + + fig, ax = plt.subplots(figsize=figsize) + + # Colors for methods + method_colors = { + 'random': '#95a5a6', + 'magnitude': '#e74c3c', + 'taylor': '#3498db', + 'composite': '#9b59b6', + 'cluster_aware': '#2ecc71', + 'network_slimming': '#f39c12', + 'chip': '#1abc9c', + } + + method_markers = { + 'random': 'o', + 'magnitude': 's', + 'taylor': '^', + 'composite': 'd', + 'cluster_aware': '*', + 'network_slimming': 'v', + 'chip': 'p', + } + + for method, ratio_results in results.items(): + if not ratio_results: + continue + + ratios = sorted(ratio_results.keys()) + accs = [] + for r in ratios: + data = ratio_results[r] + if isinstance(data, dict) and 'accuracy_after_ft' in data: + accs.append(data['accuracy_after_ft'] * 100) + elif isinstance(data, dict) and 'error' not in data: + accs.append(0) + else: + accs.append(None) + + # Filter out None values + valid = [(r, a) for r, a in zip(ratios, accs) if a is not None] + if not valid: + continue + + ratios_plot, accs_plot = zip(*valid) + ratios_pct = [r * 100 for r in ratios_plot] + + color = method_colors.get(method, '#333333') + marker = method_markers.get(method, 'o') + label = method.replace('_', ' ').title() + + ax.plot(ratios_pct, accs_plot, marker=marker, color=color, + label=label, linewidth=2, markersize=8) + + # Add baseline + ax.axhline(y=baseline_acc * 100, color='gray', linestyle='--', + label=f'Unpruned ({baseline_acc*100:.1f}%)', linewidth=1.5) + + ax.set_xlabel('Channel Sparsity (%)', fontsize=12) + ax.set_ylabel('Test Accuracy (%)', fontsize=12) + ax.set_title('Pruning Method Comparison', fontsize=14) + ax.legend(loc='lower left', fontsize=10) + ax.grid(True, alpha=0.3) + + # Set 
reasonable y-axis limits + ax.set_ylim([60, 100]) + + plt.tight_layout() + + if save_path: + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + plt.savefig(save_path, dpi=150, bbox_inches='tight') + logger.info(f"Saved pruning comparison to {save_path}") + + return fig diff --git a/src/alignment/configs/config_loader.py b/src/alignment/configs/config_loader.py index d7d64128..321b39ae 100644 --- a/src/alignment/configs/config_loader.py +++ b/src/alignment/configs/config_loader.py @@ -362,7 +362,7 @@ def _map_nested_to_flat_config(nested_config: Dict[str, Any]) -> Dict[str, Any]: flat_config["alignment_composite_weights"] = alignment_block["composite_weights"] flat_config["supernode_config"] = nested_config.get("supernode", {}) - + # Map supernode-related nested configs directly flat_config["supernode"] = nested_config.get("supernode", {}) flat_config["supernode_robustness"] = nested_config.get("supernode_robustness", {}) diff --git a/src/alignment/experiments/__init__.py b/src/alignment/experiments/__init__.py index 82e9ddb3..2eeddc4c 100644 --- a/src/alignment/experiments/__init__.py +++ b/src/alignment/experiments/__init__.py @@ -19,6 +19,12 @@ ) from .general_alignment import GeneralAlignmentConfig, GeneralAlignmentExperiment from .llm_experiments import LLMAlignmentExperiment +from .cluster_experiments import ( + ClusterAnalysisExperiment, + ClusterAnalysisConfig, + VisionExperiment, # backward compat + VisionExperimentConfig, # backward compat +) # Training utilities from .training_utils import convert_training_history, create_experiment_trainer, evaluate_with_metrics, train_with_metrics @@ -31,6 +37,10 @@ "GeneralAlignmentExperiment", "GeneralAlignmentConfig", "LLMAlignmentExperiment", + "ClusterAnalysisExperiment", + "ClusterAnalysisConfig", + "VisionExperiment", # backward compat alias + "VisionExperimentConfig", # backward compat alias # Configuration components "TrainingConfig", "PruningConfig", diff --git 
a/src/alignment/experiments/base.py b/src/alignment/experiments/base.py index a6202b38..00d6a2f5 100644 --- a/src/alignment/experiments/base.py +++ b/src/alignment/experiments/base.py @@ -168,7 +168,7 @@ class ExperimentConfig: do_scar_metrics: bool = False # Whether to compute SCAR-style supernode metrics (T_i, R_i, L_i) scar_num_samples: int = 0 # Number of calibration samples for SCAR (0 => align with alignment_data_num_samples) scar_max_length: int = 512 # Max sequence length for SCAR calibration passes - + # Supernode analysis configs (nested dicts from YAML) supernode: Dict[str, Any] = field(default_factory=dict) # Core supernode analysis config supernode_robustness: Dict[str, Any] = field(default_factory=dict) # Robustness analysis config diff --git a/src/alignment/metrics/information/__init__.py b/src/alignment/metrics/information/__init__.py index c48d56a8..35bf063e 100644 --- a/src/alignment/metrics/information/__init__.py +++ b/src/alignment/metrics/information/__init__.py @@ -12,6 +12,7 @@ from .pid import UniqueInformationX, UniqueInformationY from .redundancy import AverageRedundancy from .synergy_mmi import SynergyGaussianMMI +from .synergy_continuous import SynergyContinuousTarget # Import higher-order metrics if available try: diff --git a/src/alignment/metrics/information/synergy_continuous.py b/src/alignment/metrics/information/synergy_continuous.py new file mode 100644 index 00000000..ec93383a --- /dev/null +++ b/src/alignment/metrics/information/synergy_continuous.py @@ -0,0 +1,267 @@ +""" +Synergy metric using continuous target (logit margin). + +This addresses the mixed discrete-continuous MI issue by using a continuous +decision variable T (e.g., logit margin) instead of discrete labels Z. + +Under Gaussian approximation, all MI terms can be computed from covariances +of (T, Y_i, Y_j), making the synergy computation well-defined. 
+ +References: +- Barrett (2015): Gaussian PID +- Williams & Beer (2010): PID foundations +""" + +import logging +from typing import Any, Optional, Literal + +import torch +import numpy as np + +from ...core.base import BaseMetric +from ...core.registry import register_metric + +logger = logging.getLogger(__name__) + + +@register_metric("synergy_continuous_target") +class SynergyContinuousTarget(BaseMetric): + """ + Compute per-neuron synergy using Gaussian MI with continuous target. + + Instead of discrete labels Z, we use a continuous decision variable T: + - 'logit_margin': T = f_z(x) - max_{c!=z} f_c(x) (correct - max incorrect) + - 'correct_logit': T = f_z(x) (correct class logit) + - 'logit_pc1': T = PC1 of logits (first principal component) + + This allows proper Gaussian MI estimation for synergy: + S(T; Y_i, Y_j) = I(T; Y_i, Y_j) - I(T; Y_i) - I(T; Y_j) + min(I(T; Y_i), I(T; Y_j)) + + Example: + >>> metric = SynergyContinuousTarget(target_type='logit_margin', num_pairs=10) + >>> synergy = metric.compute(outputs=activations, logits=logits, labels=labels) + """ + + def __init__( + self, + target_type: Literal['logit_margin', 'correct_logit', 'logit_pc1'] = 'logit_margin', + num_pairs: int = 10, + sampling_strategy: str = 'top_k', # 'random', 'top_k', 'all' + eps: float = 1e-8, + **config: Any, + ): + """ + Initialize synergy metric. 
+ + Args: + target_type: Type of continuous target to use + num_pairs: Number of partner neurons per neuron + sampling_strategy: 'random', 'top_k' (highest synergy), 'all' + eps: Numerical stability epsilon + """ + super().__init__(**config) + self.target_type = target_type + self.num_pairs = num_pairs + self.sampling_strategy = sampling_strategy + self.eps = eps + + @property + def requires_inputs(self) -> bool: + return False + + @property + def requires_weights(self) -> bool: + return False + + @property + def requires_outputs(self) -> bool: + return True + + def compute( + self, + outputs: torch.Tensor, # [batch, n_neurons] + logits: torch.Tensor, # [batch, n_classes] + labels: torch.Tensor, # [batch] + inputs: Optional[torch.Tensor] = None, + weights: Optional[torch.Tensor] = None, + **kwargs: Any, + ) -> torch.Tensor: + """ + Compute per-neuron synergy scores. + + Args: + outputs: Layer activations [batch, n_neurons] + logits: Model logits [batch, n_classes] + labels: True labels [batch] + + Returns: + Per-neuron synergy scores [n_neurons] + """ + device = outputs.device + dtype = outputs.dtype + + # Flatten if needed + if outputs.ndim > 2: + # Conv layer: [B, C, H, W] -> [B, C] via GAP + outputs = outputs.mean(dim=(2, 3)) if outputs.ndim == 4 else outputs.reshape(outputs.shape[0], -1) + + batch_size, n_neurons = outputs.shape + + # Compute continuous target T + T = self._compute_target(logits, labels) # [batch] + + if T is None: + logger.warning("Could not compute continuous target") + return torch.zeros(n_neurons, device=device, dtype=dtype) + + # Compute synergy for each neuron + synergy = torch.zeros(n_neurons, device=device, dtype=dtype) + + # Precompute individual MIs with target + mi_individual = self._compute_mi_batch(T, outputs) # [n_neurons] + + for i in range(n_neurons): + partners = self._sample_partners(i, n_neurons, mi_individual) + if len(partners) == 0: + continue + + syn_values = [] + for j in partners: + s = self._compute_pairwise_synergy( + 
T, outputs[:, i], outputs[:, j], + mi_individual[i], mi_individual[j] + ) + syn_values.append(s) + + if syn_values: + synergy[i] = torch.stack(syn_values).mean() + + return synergy + + def _compute_target(self, logits: torch.Tensor, labels: torch.Tensor) -> Optional[torch.Tensor]: + """Compute continuous target variable T.""" + batch_size, n_classes = logits.shape + + if self.target_type == 'logit_margin': + # T = correct_logit - max_incorrect_logit + correct_logits = logits[torch.arange(batch_size), labels] + # Mask out correct class + mask = torch.ones_like(logits, dtype=torch.bool) + mask[torch.arange(batch_size), labels] = False + max_incorrect = logits.masked_fill(~mask, float('-inf')).max(dim=1)[0] + T = correct_logits - max_incorrect + + elif self.target_type == 'correct_logit': + T = logits[torch.arange(batch_size), labels] + + elif self.target_type == 'logit_pc1': + # First principal component of logits + logits_centered = logits - logits.mean(dim=0) + _, _, V = torch.linalg.svd(logits_centered, full_matrices=False) + T = (logits_centered @ V[0]).squeeze() + + else: + logger.warning(f"Unknown target type: {self.target_type}") + return None + + return T + + def _compute_mi_batch(self, T: torch.Tensor, Y: torch.Tensor) -> torch.Tensor: + """ + Compute I(T; Y_i) for all neurons using Gaussian approximation. 
+ + I(T; Y) = 0.5 * log(var(T) / var(T|Y)) + = 0.5 * log(1 / (1 - rho^2)) + = -0.5 * log(1 - rho^2) + """ + T = T.float() + Y = Y.float() + + # Compute correlations + T_centered = T - T.mean() + Y_centered = Y - Y.mean(dim=0) + + T_std = T_centered.std() + self.eps + Y_std = Y_centered.std(dim=0) + self.eps + + # Pearson correlation + rho = (T_centered[:, None] * Y_centered).mean(dim=0) / (T_std * Y_std) + rho = torch.clamp(rho, -1 + self.eps, 1 - self.eps) + + # Gaussian MI + mi = -0.5 * torch.log(1 - rho ** 2) + return torch.clamp(mi, min=0.0) + + def _compute_pairwise_synergy( + self, + T: torch.Tensor, + Y_i: torch.Tensor, + Y_j: torch.Tensor, + mi_i: torch.Tensor, + mi_j: torch.Tensor, + ) -> torch.Tensor: + """ + Compute synergy S(T; Y_i, Y_j) using MMI redundancy axiom. + + S = I(T; Y_i, Y_j) - I(T; Y_i) - I(T; Y_j) + min(I(T; Y_i), I(T; Y_j)) + """ + # Joint MI via 3x3 covariance + joint = torch.stack([T, Y_i, Y_j], dim=1) # [batch, 3] + joint_centered = joint - joint.mean(dim=0) + cov = (joint_centered.T @ joint_centered) / (joint.shape[0] - 1 + self.eps) + + # Add regularization + cov = cov + self.eps * torch.eye(3, device=cov.device, dtype=cov.dtype) + + # I(T; [Y_i, Y_j]) = 0.5 * log(det(cov_T) * det(cov_YiYj) / det(cov_all)) + var_T = cov[0, 0] + cov_Y = cov[1:, 1:] + det_all = torch.linalg.det(cov) + det_Y = torch.linalg.det(cov_Y) + + if det_all <= 0 or det_Y <= 0 or var_T <= 0: + return torch.tensor(0.0, device=T.device) + + mi_joint = 0.5 * torch.log(var_T * det_Y / det_all) + mi_joint = torch.clamp(mi_joint, min=0.0) + + # MMI redundancy + redundancy = torch.min(mi_i, mi_j) + + # Synergy + synergy = mi_joint - mi_i - mi_j + redundancy + + return synergy + + def _sample_partners( + self, + neuron_idx: int, + n_neurons: int, + mi_scores: Optional[torch.Tensor] = None, + ) -> list: + """Sample partner neurons for synergy computation.""" + available = [j for j in range(n_neurons) if j != neuron_idx] + + if len(available) == 0: + return [] + + k = 
min(self.num_pairs, len(available)) + + if self.sampling_strategy == 'all': + return available + + elif self.sampling_strategy == 'random': + idx = torch.randperm(len(available))[:k] + return [available[i] for i in idx.tolist()] + + elif self.sampling_strategy == 'top_k': + # Sample partners with highest MI (likely to have interesting synergy) + if mi_scores is not None: + partner_mi = mi_scores[available] + _, top_idx = torch.topk(partner_mi, k) + return [available[i] for i in top_idx.tolist()] + else: + idx = torch.randperm(len(available))[:k] + return [available[i] for i in idx.tolist()] + + return [] diff --git a/src/alignment/pruning/strategies/__init__.py b/src/alignment/pruning/strategies/__init__.py index f41f9be0..d15d3e53 100644 --- a/src/alignment/pruning/strategies/__init__.py +++ b/src/alignment/pruning/strategies/__init__.py @@ -4,6 +4,7 @@ from .alignment_based import AlignmentPruning, GlobalAlignmentPruning, HybridPruning from .cascading import CascadingAlignmentPruning +from .cluster_aware import ClusterAwarePruning, ClusterAwarePruningConfig, CompositePruning from .gradient import FisherPruning, GradientPruning, MomentumPruning from .llm_baselines import WandaPruning, SparseGPTPruning from .magnitude import GlobalMagnitudePruning, IterativeMagnitudePruning, MagnitudePruning @@ -34,6 +35,10 @@ "HybridPruning", "GlobalAlignmentPruning", "CascadingAlignmentPruning", + # Cluster-aware (vision paper) + "ClusterAwarePruning", + "ClusterAwarePruningConfig", + "CompositePruning", # LLM Baselines (Wanda, SparseGPT) "WandaPruning", "SparseGPTPruning", diff --git a/src/alignment/pruning/strategies/cluster_aware.py b/src/alignment/pruning/strategies/cluster_aware.py new file mode 100644 index 00000000..11f5cf19 --- /dev/null +++ b/src/alignment/pruning/strategies/cluster_aware.py @@ -0,0 +1,552 @@ +""" +Cluster-aware pruning strategy with halo scoring and cluster constraints. 
+ +This implements the cluster-and-halo pruning approach from the vision paper: + +Score_i = α·log(RQ_i) + β·Syn_i - γ·Red_i + λ·HaloSyn_i + +With constraints: +1. Protect critical: prune at most p_C fraction from critical cluster per layer +2. Target redundant/background: prioritize pruning from these clusters +3. Synergy-pair constraint: don't prune both members of top synergistic pairs + +References: +- Channel Clusters and Halo Dependencies for Structured Pruning (ICML 2026) +""" + +import logging +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Set, Tuple + +import numpy as np +import torch +import torch.nn as nn + +from ..base import BasePruningStrategy, PruningConfig + +logger = logging.getLogger(__name__) + + +@dataclass +class ClusterAwarePruningConfig(PruningConfig): + """Configuration for cluster-aware pruning.""" + + # Score weights (Eq. 14 in paper) + alpha: float = 1.0 # Weight for log(RQ) + beta: float = 0.5 # Weight for Synergy + gamma: float = 0.3 # Weight for Redundancy (subtracted) + lambda_halo: float = 0.5 # Weight for HaloSyn term + + # Cluster constraints + protect_critical_frac: float = 0.3 # Max fraction of critical to prune per layer + target_redundant: bool = True # Prioritize pruning redundant/background + + # Synergy-pair constraint + synergy_pair_constraint: bool = True + top_synergy_pairs: int = 10 # Number of top synergy pairs to protect + + # Halo parameters + halo_percentile: float = 90.0 + use_activation_weight: bool = True + + # Clustering + n_clusters: int = 4 + + # Structured pruning (default True for channels) + structured: bool = True + + +class ClusterAwarePruning(BasePruningStrategy): + """ + Cluster-aware structured pruning with halo scoring. + + This strategy: + 1. Computes per-channel metrics (RQ, Redundancy, Synergy) + 2. Clusters channels into functional types (Critical, Redundant, Synergistic, Background) + 3. Computes downstream halo synergy for each channel + 4. 
Scores channels using composite formula with halo term + 5. Applies cluster constraints during selection + + Example: + >>> config = ClusterAwarePruningConfig(amount=0.5) + >>> strategy = ClusterAwarePruning(config) + >>> scores = strategy.compute_importance_scores( + ... module, inputs=activations, logits=logits, labels=labels + ... ) + """ + + def __init__( + self, + config: Optional[ClusterAwarePruningConfig] = None, + precomputed_metrics: Optional[Dict[str, np.ndarray]] = None, + precomputed_clusters: Optional[Dict[str, Any]] = None, + precomputed_halos: Optional[Dict[str, np.ndarray]] = None, + ): + """ + Initialize cluster-aware pruning. + + Args: + config: Pruning configuration + precomputed_metrics: Optional dict with 'rq', 'redundancy', 'synergy' arrays + precomputed_clusters: Optional dict with 'labels', 'type_mapping' + precomputed_halos: Optional dict with 'halo_syn' per-channel array + """ + super().__init__(config or ClusterAwarePruningConfig()) + self.config: ClusterAwarePruningConfig + + self.precomputed_metrics = precomputed_metrics + self.precomputed_clusters = precomputed_clusters + self.precomputed_halos = precomputed_halos + + # Cache for computed values + self._metrics_cache = {} + self._cluster_cache = {} + self._halo_cache = {} + + def compute_importance_scores( + self, + module: nn.Module, + inputs: Optional[torch.Tensor] = None, + outputs: Optional[torch.Tensor] = None, + logits: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + next_layer_weights: Optional[torch.Tensor] = None, + next_layer_metrics: Optional[Dict[str, np.ndarray]] = None, + layer_name: str = "", + **kwargs: Any, + ) -> torch.Tensor: + """ + Compute cluster-aware importance scores. + + Args: + module: Module to compute scores for + inputs: Input activations [batch, channels, ...] + outputs: Output activations [batch, channels, ...] 
+ logits: Model logits [batch, n_classes] + labels: True labels [batch] + next_layer_weights: Weights of next layer for halo computation + next_layer_metrics: Metrics of next layer channels for halo synergy + layer_name: Layer identifier for caching + + Returns: + Per-channel importance scores [n_channels] + """ + # Get number of channels + if hasattr(module, 'weight'): + n_channels = module.weight.shape[0] + elif outputs is not None: + n_channels = outputs.shape[1] + else: + raise ValueError("Cannot determine number of channels") + + device = module.weight.device if hasattr(module, 'weight') else 'cpu' + + # 1. Get or compute per-channel metrics + metrics = self._get_metrics( + module, inputs, outputs, logits, labels, n_channels, layer_name + ) + + # 2. Get or compute clusters + clusters = self._get_clusters(metrics, n_channels, layer_name) + + # 3. Get or compute halo synergy + halo_syn = self._get_halo_syn( + module, outputs, next_layer_weights, next_layer_metrics, + clusters, n_channels, layer_name + ) + + # 4. Compute composite scores (Eq. 14) + log_rq = np.log(np.clip(metrics['rq'], 1e-10, None)) + + # Normalize each component to [0, 1] for stable weighting + log_rq_norm = self._normalize(log_rq) + syn_norm = self._normalize(metrics['synergy']) + red_norm = self._normalize(metrics['redundancy']) + halo_syn_norm = self._normalize(halo_syn) + + scores = ( + self.config.alpha * log_rq_norm + + self.config.beta * syn_norm - + self.config.gamma * red_norm + + self.config.lambda_halo * halo_syn_norm + ) + + return torch.from_numpy(scores).float().to(device) + + def select_channels_to_prune( + self, + scores: torch.Tensor, + n_prune: int, + layer_name: str = "", + ) -> List[int]: + """ + Select channels to prune with cluster constraints. 
+ + Args: + scores: Per-channel importance scores [n_channels] + n_prune: Number of channels to prune + layer_name: Layer identifier for cached clusters + + Returns: + List of channel indices to prune + """ + n_channels = len(scores) + scores_np = scores.cpu().numpy() + + # Get cluster info + clusters = self._cluster_cache.get(layer_name, {}) + labels = clusters.get('labels', np.zeros(n_channels, dtype=int)) + type_mapping = clusters.get('type_mapping', {0: 'unknown'}) + + # Invert type_mapping for lookup + type_to_id = {v: k for k, v in type_mapping.items()} + + # Initialize selection + selected = set() + protected = set() + + # 1. Apply critical protection constraint + if self.config.protect_critical_frac < 1.0: + critical_id = type_to_id.get('critical', -1) + if critical_id >= 0: + critical_idx = np.where(labels == critical_id)[0] + max_prune_critical = int(len(critical_idx) * self.config.protect_critical_frac) + # Sort critical by score (ascending = low score first to prune) + critical_sorted = sorted(critical_idx, key=lambda i: scores_np[i]) + # Protect the top (1 - frac) of critical channels + protected.update(critical_sorted[max_prune_critical:]) + + # 2. Apply synergy-pair constraint + if self.config.synergy_pair_constraint: + metrics = self._metrics_cache.get(layer_name, {}) + synergy_pairs = self._get_top_synergy_pairs( + metrics, self.config.top_synergy_pairs + ) + # Don't prune both members of a pair + # We'll handle this during selection + else: + synergy_pairs = [] + + # 3. 
Sort channels by score (ascending - low scores get pruned) + # But prioritize redundant/background + if self.config.target_redundant: + redundant_id = type_to_id.get('redundant', -1) + background_id = type_to_id.get('background', -1) + + # Create priority: redundant/background first, then others + priority = np.zeros(n_channels) + if redundant_id >= 0: + priority[labels == redundant_id] = -1 # Higher priority to prune + if background_id >= 0: + priority[labels == background_id] = -1 + + # Combined score: priority first, then actual score + combined = list(zip(priority, scores_np, range(n_channels))) + combined.sort() # Sort by (priority, score) + sorted_idx = [i for _, _, i in combined] + else: + sorted_idx = np.argsort(scores_np).tolist() + + # 4. Select channels respecting constraints + pair_set = set() + for i, j in synergy_pairs: + pair_set.add((min(i, j), max(i, j))) + + for idx in sorted_idx: + if len(selected) >= n_prune: + break + + # Skip protected channels + if idx in protected: + continue + + # Check synergy-pair constraint + if self.config.synergy_pair_constraint: + pair_conflict = False + for i, j in pair_set: + if (idx == i and j in selected) or (idx == j and i in selected): + pair_conflict = True + break + if pair_conflict: + continue + + selected.add(idx) + + return list(selected) + + def _get_metrics( + self, + module: nn.Module, + inputs: Optional[torch.Tensor], + outputs: Optional[torch.Tensor], + logits: Optional[torch.Tensor], + labels: Optional[torch.Tensor], + n_channels: int, + layer_name: str, + ) -> Dict[str, np.ndarray]: + """Get or compute per-channel metrics.""" + if layer_name in self._metrics_cache: + return self._metrics_cache[layer_name] + + if self.precomputed_metrics is not None: + self._metrics_cache[layer_name] = self.precomputed_metrics + return self.precomputed_metrics + + # Compute metrics from scratch + metrics = { + 'rq': np.ones(n_channels), + 'redundancy': np.zeros(n_channels), + 'synergy': np.zeros(n_channels), + } + + 
if outputs is None: + logger.warning("No outputs provided, using default metrics") + self._metrics_cache[layer_name] = metrics + return metrics + + # Flatten spatial dims if needed + if outputs.ndim == 4: + # Conv: [B, C, H, W] -> [B, C] + acts = outputs.mean(dim=(2, 3)).detach().cpu().numpy() + else: + acts = outputs.detach().cpu().numpy() + + # 1. RQ proxy (activation variance / weight norm^2) + var = np.var(acts, axis=0) + if hasattr(module, 'weight'): + w = module.weight.detach().cpu().numpy() + w_flat = w.reshape(w.shape[0], -1) + w_norm_sq = np.sum(w_flat ** 2, axis=1) + metrics['rq'] = var / (w_norm_sq[:len(var)] + 1e-10) + else: + metrics['rq'] = var + + # 2. Redundancy (Gaussian pairwise MI) + if n_channels > 1: + corr = np.corrcoef(acts.T) + corr = np.clip(corr, -0.999, 0.999) + mi_matrix = -0.5 * np.log(1 - corr ** 2) + np.fill_diagonal(mi_matrix, 0) + metrics['redundancy'] = np.mean(mi_matrix, axis=1) + + # 3. Synergy with continuous target + if logits is not None and labels is not None: + logits_np = logits.detach().cpu().numpy() + labels_np = labels.detach().cpu().numpy() + + # Logit margin as target + batch_size = logits_np.shape[0] + correct_logits = logits_np[np.arange(batch_size), labels_np] + mask = np.ones_like(logits_np, dtype=bool) + mask[np.arange(batch_size), labels_np] = False + max_incorrect = np.where(mask, logits_np, -np.inf).max(axis=1) + T = correct_logits - max_incorrect + + synergy = np.zeros(n_channels) + for i in range(n_channels): + mi_i = self._gaussian_mi(T, acts[:, i]) + # Top-k partners by redundancy + partners = np.argsort(-metrics['redundancy'])[:10] + syn_vals = [] + for j in partners: + if i == j: + continue + mi_j = self._gaussian_mi(T, acts[:, j]) + mi_joint = self._gaussian_mi_joint(T, acts[:, i], acts[:, j]) + s = mi_joint - mi_i - mi_j + min(mi_i, mi_j) + syn_vals.append(s) + synergy[i] = np.mean(syn_vals) if syn_vals else 0. 
+ metrics['synergy'] = synergy + + self._metrics_cache[layer_name] = metrics + return metrics + + def _get_clusters( + self, + metrics: Dict[str, np.ndarray], + n_channels: int, + layer_name: str, + ) -> Dict[str, Any]: + """Get or compute clusters.""" + if layer_name in self._cluster_cache: + return self._cluster_cache[layer_name] + + if self.precomputed_clusters is not None: + self._cluster_cache[layer_name] = self.precomputed_clusters + return self.precomputed_clusters + + # Cluster using MetricSpaceClustering + from ...analysis.clustering import MetricSpaceClustering + + clusterer = MetricSpaceClustering( + n_clusters=self.config.n_clusters, + seed=42, + ) + result = clusterer.fit( + metrics['rq'], + metrics['redundancy'], + metrics['synergy'], + layer_name, + ) + + clusters = { + 'labels': result.labels, + 'centroids': result.centroids, + 'type_mapping': result.type_mapping, + 'type_counts': result.type_counts, + } + + self._cluster_cache[layer_name] = clusters + return clusters + + def _get_halo_syn( + self, + module: nn.Module, + outputs: Optional[torch.Tensor], + next_layer_weights: Optional[torch.Tensor], + next_layer_metrics: Optional[Dict[str, np.ndarray]], + clusters: Dict[str, Any], + n_channels: int, + layer_name: str, + ) -> np.ndarray: + """Compute per-channel halo synergy.""" + if layer_name in self._halo_cache: + return self._halo_cache[layer_name] + + if self.precomputed_halos is not None: + halo_syn = self.precomputed_halos.get('halo_syn', np.zeros(n_channels)) + self._halo_cache[layer_name] = halo_syn + return halo_syn + + # If no next layer info, return zeros + if next_layer_weights is None or next_layer_metrics is None: + halo_syn = np.zeros(n_channels) + self._halo_cache[layer_name] = halo_syn + return halo_syn + + # Compute halo for each channel + from ...analysis.clustering import CrossLayerHaloAnalysis + + halo_analyzer = CrossLayerHaloAnalysis( + percentile=self.config.halo_percentile, + 
use_activation_weight=self.config.use_activation_weight, + ) + + # Get weights and compute influence + w_np = next_layer_weights.detach().cpu().numpy() + if w_np.ndim == 4: + # Conv: [out, in, k, k] -> [out, in] + influence = np.abs(w_np).sum(axis=(2, 3)) + else: + influence = np.abs(w_np) + + # Weight by activation std if available + if outputs is not None and self.config.use_activation_weight: + if outputs.ndim == 4: + acts = outputs.mean(dim=(2, 3)).detach().cpu().numpy() + else: + acts = outputs.detach().cpu().numpy() + std = np.std(acts, axis=0) + n_in = min(influence.shape[1], len(std)) + influence[:, :n_in] = influence[:, :n_in] * std[:n_in] + + # Get next layer synergy + next_syn = next_layer_metrics.get('synergy', np.zeros(influence.shape[0])) + + # Per-channel halo synergy + halo_syn = np.zeros(n_channels) + for i in range(min(n_channels, influence.shape[1])): + # Find receivers that get high influence from this channel + infl_i = influence[:, i] + total_infl = influence.sum(axis=1) + 1e-10 + rel_infl = infl_i / total_infl + thresh = np.percentile(rel_infl, self.config.halo_percentile) + halo_mask = rel_infl >= thresh + if halo_mask.sum() > 0: + halo_syn[i] = np.mean(next_syn[halo_mask]) + + self._halo_cache[layer_name] = halo_syn + return halo_syn + + def _get_top_synergy_pairs( + self, + metrics: Dict[str, np.ndarray], + top_k: int, + ) -> List[Tuple[int, int]]: + """Get top synergy pairs for constraint.""" + synergy = metrics.get('synergy', np.array([])) + n = len(synergy) + if n < 2: + return [] + + # Simple heuristic: pair high-synergy channels + top_idx = np.argsort(-synergy)[:min(top_k * 2, n)] + pairs = [] + for i in range(0, len(top_idx) - 1, 2): + pairs.append((int(top_idx[i]), int(top_idx[i + 1]))) + return pairs[:top_k] + + def _normalize(self, x: np.ndarray) -> np.ndarray: + """Normalize array to [0, 1].""" + x_min, x_max = x.min(), x.max() + if x_max > x_min: + return (x - x_min) / (x_max - x_min) + return x + + def _gaussian_mi(self, x: 
np.ndarray, y: np.ndarray) -> float: + """Compute Gaussian MI between two variables.""" + rho = np.corrcoef(x, y)[0, 1] + rho = np.clip(rho, -0.999, 0.999) + return max(0, -0.5 * np.log(1 - rho ** 2)) + + def _gaussian_mi_joint(self, t: np.ndarray, y1: np.ndarray, y2: np.ndarray) -> float: + """Compute Gaussian MI I(T; [Y1, Y2]).""" + joint = np.column_stack([t, y1, y2]) + cov = np.cov(joint.T) + 1e-8 * np.eye(3) + var_t = cov[0, 0] + cov_y = cov[1:, 1:] + det_all = np.linalg.det(cov) + det_y = np.linalg.det(cov_y) + if det_all <= 0 or det_y <= 0 or var_t <= 0: + return 0. + return max(0, 0.5 * np.log(var_t * det_y / det_all)) + + +class CompositePruning(ClusterAwarePruning): + """ + Composite per-channel pruning (baseline without cluster constraints). + + Uses the same scoring as ClusterAwarePruning but without: + - Cluster constraints (protect critical, target redundant) + - Synergy-pair constraints + - Halo term (lambda = 0) + + This corresponds to the "Composite" baseline in the paper. + """ + + def __init__( + self, + config: Optional[ClusterAwarePruningConfig] = None, + **kwargs, + ): + if config is None: + config = ClusterAwarePruningConfig() + + # Disable cluster constraints and halo + config.protect_critical_frac = 1.0 # No protection + config.target_redundant = False + config.synergy_pair_constraint = False + config.lambda_halo = 0.0 # No halo term + + super().__init__(config, **kwargs) + + def select_channels_to_prune( + self, + scores: torch.Tensor, + n_prune: int, + layer_name: str = "", + ) -> List[int]: + """Simple selection by score (no constraints).""" + scores_np = scores.cpu().numpy() + sorted_idx = np.argsort(scores_np) + return sorted_idx[:n_prune].tolist() diff --git a/src/alignment/training/README.md b/src/alignment/training/README.md index 949680af..46fffa19 100644 --- a/src/alignment/training/README.md +++ b/src/alignment/training/README.md @@ -1,24 +1,65 @@ # Training Module -Training utilities and callbacks. 
+Training utilities, trainers, and evaluation functions. ## Components -- `BaseTrainer` - Training loop with metric tracking -- `AlignmentMetricsCallback` - Track alignment during training +### Trainers +- `BaseTrainer` - Base trainer class +- `ExperimentTrainer` - Trainer for alignment experiments +- `TensorizedNetworkWrapper` - Multi-network training + +### Evaluation +- `evaluate_classification()` - Classification accuracy and loss +- `evaluate_perplexity()` - Language model perplexity +- `evaluate_regression()` - Regression MSE and MAE +- `evaluate_model()` - General dispatcher +- `EvaluationManager` - Track evaluation metrics over time + +### Callbacks +- `AlignmentCallback` - Track alignment metrics during training ## Usage +### Training + ```python -from alignment.training.callbacks import AlignmentMetricsCallback +from alignment.training import ExperimentTrainer, ExperimentTrainingConfig -callback = AlignmentMetricsCallback( - metrics={'rq': get_metric('rayleigh_quotient')}, - layers=['conv1'], - frequency=100 +config = ExperimentTrainingConfig( + epochs=10, + learning_rate=0.001, + batch_size=128 ) +trainer = ExperimentTrainer(model, config) +trainer.train(train_loader, val_loader) +``` + +### Evaluation + +```python +from alignment.training import evaluate_classification, evaluate_perplexity + +# Classification +results = evaluate_classification(model, test_loader, device="cuda") +# Returns: {"loss": 0.32, "accuracy": 91.5} + +# Language modeling +results = evaluate_perplexity(model, text_loader, device="cuda") +# Returns: {"perplexity": 12.4, "loss": 2.52} +``` + +### Evaluation Manager + +```python +from alignment.training import EvaluationManager + +manager = EvaluationManager(task="classification") -# In training loop -callback.on_batch_end(wrapper, inputs, targets, step) -history = callback.get_history() +for epoch in range(epochs): + train(...) 
+ results = manager.evaluate(model, val_loader, step=epoch) + +best = manager.get_best(metric="accuracy") +history = manager.get_history() ``` diff --git a/src/alignment/training/__init__.py b/src/alignment/training/__init__.py index bd175cc2..e5d8cf61 100644 --- a/src/alignment/training/__init__.py +++ b/src/alignment/training/__init__.py @@ -5,6 +5,13 @@ from .base import BaseTrainer, TrainingConfig from .experiment_trainer import ExperimentTrainer, ExperimentTrainingConfig from .multi_network import TensorizedNetworkWrapper, train_networks_fully_tensorized +from .evaluation import ( + evaluate_classification, + evaluate_perplexity, + evaluate_regression, + evaluate_model, + EvaluationManager, +) __all__ = [ "BaseTrainer", @@ -13,4 +20,10 @@ "TensorizedNetworkWrapper", "ExperimentTrainer", "ExperimentTrainingConfig", + # Evaluation + "evaluate_classification", + "evaluate_perplexity", + "evaluate_regression", + "evaluate_model", + "EvaluationManager", ] diff --git a/src/alignment/evaluation.py b/src/alignment/training/evaluation.py similarity index 100% rename from src/alignment/evaluation.py rename to src/alignment/training/evaluation.py From 7675e8bd19a0b53863baaf024609fbb8014a09be Mon Sep 17 00:00:00 2001 From: Houman Safaai Date: Mon, 8 Dec 2025 12:46:02 -0500 Subject: [PATCH 02/12] refactor visualization --- .../analysis/visualization/__init__.py | 70 +- .../analysis/visualization/cluster_plots.py | 275 ++++++-- .../analysis/visualization/metric_plots.py | 661 ++++++++++++++++++ .../analysis/visualization/pruning_plots.py | 360 ++++++++++ 4 files changed, 1298 insertions(+), 68 deletions(-) create mode 100644 src/alignment/analysis/visualization/metric_plots.py diff --git a/src/alignment/analysis/visualization/__init__.py b/src/alignment/analysis/visualization/__init__.py index 57066681..3136f996 100644 --- a/src/alignment/analysis/visualization/__init__.py +++ b/src/alignment/analysis/visualization/__init__.py @@ -6,6 +6,8 @@ - PruningVisualizer: 
Specialized pruning analysis plots (advanced use) - AlignmentVisualizer: Specialized alignment plots (advanced use) - HaloPlots: Halo redundancy visualization (by layer depth) +- MetricPlots: Histogram/distribution plots for metrics (RQ, Redundancy, Synergy) +- ClusterPlots: Cluster-specific visualizations (scatter, evolution, cascade) For most use cases, use UnifiedVisualizer: @@ -15,12 +17,21 @@ viz.plot_layer_scores(scores, "Rayleigh Quotient") viz.plot_pruning_before_after(sparsities, before, after) -For advanced pruning visualizations (multi-seed, ablations): +For metric histograms and distributions: - from alignment.analysis.visualization import PruningVisualizer + from alignment.analysis.visualization import ( + plot_metric_histogram, + plot_metric_violin, + plot_multi_metric_histogram, + ) + + plot_metric_histogram(rq_values, "rq", layer_name="conv1", highlight_percentile=95) + +For pruning comparisons (unified for both vision and LLM): + + from alignment.analysis.visualization import plot_unified_pruning_comparison - viz = PruningVisualizer() - viz.plot_sparsity_perplexity_curves(df, save_path="curves.png") + plot_unified_pruning_comparison(results, baseline_value=0.92, metric='accuracy') For halo redundancy analysis: @@ -36,9 +47,29 @@ from .unified_visualizer import UnifiedVisualizer, plot_quick_summary, generate_experiment_visualizations # Specialized visualizers (for advanced use cases) -from .pruning_plots import PruningVisualizer +from .pruning_plots import ( + PruningVisualizer, + # Unified pruning functions (work for both vision and LLM) + plot_unified_pruning_comparison, + plot_pruning_accuracy_loss_grid, + plot_pruning_recovery_chart, + PRUNING_METHOD_COLORS, +) from .alignment_plots import AlignmentVisualizer +# Metric distribution plots (histograms, violins, correlations) +from .metric_plots import ( + plot_metric_histogram, + plot_metric_violin, + plot_metric_boxplot, + plot_multi_metric_histogram, + plot_metric_scatter_matrix, + 
plot_metric_correlation_heatmap, + plot_layer_metric_heatmap, + plot_top_neurons_bar, + METRIC_COLORS, +) + # Halo redundancy plots from .halo_plots import ( plot_halo_redundancy_by_depth, @@ -53,6 +84,12 @@ plot_influence_matrix, plot_cascade_test, plot_halo_properties, + plot_pruning_comparison, + plot_metric_distributions_for_layer, + plot_layer_metric_summary, + plot_centroid_evolution, + plot_centroid_depth_profiles, + CLUSTER_COLORS, ) __all__ = [ @@ -60,9 +97,24 @@ "UnifiedVisualizer", "plot_quick_summary", "generate_experiment_visualizations", - # Specialized + # Specialized visualizers "PruningVisualizer", "AlignmentVisualizer", + # Unified pruning plots (vision + LLM) + "plot_unified_pruning_comparison", + "plot_pruning_accuracy_loss_grid", + "plot_pruning_recovery_chart", + "PRUNING_METHOD_COLORS", + # Metric distribution plots + "plot_metric_histogram", + "plot_metric_violin", + "plot_metric_boxplot", + "plot_multi_metric_histogram", + "plot_metric_scatter_matrix", + "plot_metric_correlation_heatmap", + "plot_layer_metric_heatmap", + "plot_top_neurons_bar", + "METRIC_COLORS", # Halo plots "plot_halo_redundancy_by_depth", "plot_halo_redundancy_comprehensive", @@ -73,4 +125,10 @@ "plot_influence_matrix", "plot_cascade_test", "plot_halo_properties", + "plot_pruning_comparison", + "plot_metric_distributions_for_layer", + "plot_layer_metric_summary", + "plot_centroid_evolution", + "plot_centroid_depth_profiles", + "CLUSTER_COLORS", ] diff --git a/src/alignment/analysis/visualization/cluster_plots.py b/src/alignment/analysis/visualization/cluster_plots.py index 86e5fa42..34dfafc6 100644 --- a/src/alignment/analysis/visualization/cluster_plots.py +++ b/src/alignment/analysis/visualization/cluster_plots.py @@ -1,5 +1,5 @@ """ -Cluster visualization module for vision network analysis. +Cluster visualization module for neural network analysis. Provides visualizations for: 1. Metric space scatter plots (RQ vs Red, RQ vs Syn, Red vs Syn) @@ -7,6 +7,9 @@ 3. 
Cross-layer influence matrices (heatmaps) 4. Cluster stability analysis 5. Cascade test results by cluster type +6. Pruning comparisons (unified with LLM experiments) + +This module works for both vision (ResNet, VGG) and LLM (Qwen, Llama) experiments. """ import logging @@ -23,6 +26,24 @@ except ImportError: HAS_MPL = False +# Import unified pruning functions +from .pruning_plots import ( + plot_unified_pruning_comparison, + plot_pruning_accuracy_loss_grid, + plot_pruning_recovery_chart, + PRUNING_METHOD_COLORS, +) + +# Import metric plotting functions +from .metric_plots import ( + plot_metric_histogram, + plot_metric_violin, + plot_multi_metric_histogram, + plot_metric_correlation_heatmap, + plot_top_neurons_bar, + METRIC_COLORS, +) + CLUSTER_COLORS = { "critical": "#e74c3c", @@ -468,86 +489,216 @@ def plot_pruning_comparison( """ Plot pruning accuracy comparison across methods and sparsity levels. + This is a convenience wrapper around plot_unified_pruning_comparison + for backward compatibility with vision experiments. + Args: results: Dict mapping method -> {ratio -> {'accuracy_after_ft': float}} baseline_acc: Baseline (unpruned) accuracy save_path: Optional path to save figure + figsize: Figure size + + Returns: + Matplotlib Figure or None """ if not HAS_MPL: return None - fig, ax = plt.subplots(figsize=figsize) + # Use the unified pruning comparison function + return plot_unified_pruning_comparison( + results=results, + baseline_value=baseline_acc, + metric='accuracy', + higher_is_better=True, + title='Pruning Method Comparison', + save_path=save_path, + figsize=figsize, + ) + + +def plot_metric_distributions_for_layer( + metrics: Dict[str, np.ndarray], + layer_name: str = "", + save_dir: Optional[Path] = None, + figsize: Tuple[int, int] = (15, 5), +) -> Optional["plt.Figure"]: + """ + Plot histograms of all metrics (RQ, Redundancy, Synergy) for a single layer. 
- # Colors for methods - method_colors = { - 'random': '#95a5a6', - 'magnitude': '#e74c3c', - 'taylor': '#3498db', - 'composite': '#9b59b6', - 'cluster_aware': '#2ecc71', - 'network_slimming': '#f39c12', - 'chip': '#1abc9c', - } - - method_markers = { - 'random': 'o', - 'magnitude': 's', - 'taylor': '^', - 'composite': 'd', - 'cluster_aware': '*', - 'network_slimming': 'v', - 'chip': 'p', - } - - for method, ratio_results in results.items(): - if not ratio_results: - continue - - ratios = sorted(ratio_results.keys()) - accs = [] - for r in ratios: - data = ratio_results[r] - if isinstance(data, dict) and 'accuracy_after_ft' in data: - accs.append(data['accuracy_after_ft'] * 100) - elif isinstance(data, dict) and 'error' not in data: - accs.append(0) - else: - accs.append(None) - - # Filter out None values - valid = [(r, a) for r, a in zip(ratios, accs) if a is not None] - if not valid: - continue - - ratios_plot, accs_plot = zip(*valid) - ratios_pct = [r * 100 for r in ratios_plot] + This is a convenience function that combines multiple metric histograms + into a single figure for a given layer. 
+ + Args: + metrics: Dict with keys like 'rq', 'redundancy', 'synergy' mapping to arrays + layer_name: Name of the layer for title + save_dir: Directory to save figure (filename auto-generated) + figsize: Figure size - color = method_colors.get(method, '#333333') - marker = method_markers.get(method, 'o') - label = method.replace('_', ' ').title() + Returns: + Matplotlib Figure or None + """ + if not HAS_MPL or not metrics: + return None + + # Use the multi-metric histogram from metric_plots + save_path = None + if save_dir: + safe_name = layer_name.replace('.', '_').replace('/', '_') + save_path = Path(save_dir) / f"metric_distributions_{safe_name}.png" + + return plot_multi_metric_histogram( + metrics=metrics, + layer_name=layer_name, + bins=30, + save_path=save_path, + figsize=figsize, + ) + + +def plot_layer_metric_summary( + layer_metrics: Dict[str, Dict[str, np.ndarray]], + save_path: Optional[Path] = None, + figsize: Tuple[int, int] = (14, 10), +) -> Optional["plt.Figure"]: + """ + Plot summary statistics of all metrics across all layers. 
+ + Creates a 2x2 grid with: + - Mean values heatmap + - Standard deviation heatmap + - Metric correlations (averaged across layers) + - Layer-wise metric ranges + + Args: + layer_metrics: Dict mapping layer_name -> {metric_name -> values} + save_path: Optional path to save figure + figsize: Figure size - ax.plot(ratios_pct, accs_plot, marker=marker, color=color, - label=label, linewidth=2, markersize=8) + Returns: + Matplotlib Figure or None + """ + if not HAS_MPL or not layer_metrics: + return None - # Add baseline - ax.axhline(y=baseline_acc * 100, color='gray', linestyle='--', - label=f'Unpruned ({baseline_acc*100:.1f}%)', linewidth=1.5) + layer_names = list(layer_metrics.keys()) + if not layer_names: + return None - ax.set_xlabel('Channel Sparsity (%)', fontsize=12) - ax.set_ylabel('Test Accuracy (%)', fontsize=12) - ax.set_title('Pruning Method Comparison', fontsize=14) - ax.legend(loc='lower left', fontsize=10) - ax.grid(True, alpha=0.3) + # Get all metric names from first layer + metric_names = list(layer_metrics[layer_names[0]].keys()) + if not metric_names: + return None - # Set reasonable y-axis limits - ax.set_ylim([60, 100]) + fig, axes = plt.subplots(2, 2, figsize=figsize) + + # 1. Mean values heatmap + ax = axes[0, 0] + mean_matrix = np.zeros((len(layer_names), len(metric_names))) + for i, layer in enumerate(layer_names): + for j, metric in enumerate(metric_names): + values = layer_metrics[layer].get(metric, []) + mean_matrix[i, j] = np.mean(values) if len(values) > 0 else 0 + + im = ax.imshow(mean_matrix, aspect='auto', cmap='YlOrRd') + ax.set_xticks(range(len(metric_names))) + ax.set_yticks(range(len(layer_names))) + ax.set_xticklabels(metric_names, rotation=45, ha='right', fontsize=9) + ax.set_yticklabels(layer_names, fontsize=8) + ax.set_title('Mean Values by Layer', fontsize=12) + plt.colorbar(im, ax=ax) + + # 2. 
Std values heatmap + ax = axes[0, 1] + std_matrix = np.zeros((len(layer_names), len(metric_names))) + for i, layer in enumerate(layer_names): + for j, metric in enumerate(metric_names): + values = layer_metrics[layer].get(metric, []) + std_matrix[i, j] = np.std(values) if len(values) > 0 else 0 + + im = ax.imshow(std_matrix, aspect='auto', cmap='Blues') + ax.set_xticks(range(len(metric_names))) + ax.set_yticks(range(len(layer_names))) + ax.set_xticklabels(metric_names, rotation=45, ha='right', fontsize=9) + ax.set_yticklabels(layer_names, fontsize=8) + ax.set_title('Std Dev by Layer', fontsize=12) + plt.colorbar(im, ax=ax) + + # 3. Average correlation across layers + ax = axes[1, 0] + corr_sum = np.zeros((len(metric_names), len(metric_names))) + n_valid = 0 + for layer in layer_names: + try: + data = np.column_stack([ + np.asarray(layer_metrics[layer].get(m, [])).flatten() + for m in metric_names + ]) + if data.shape[0] > 10: # Need enough samples + corr = np.corrcoef(data.T) + if not np.any(np.isnan(corr)): + corr_sum += corr + n_valid += 1 + except: + continue + if n_valid > 0: + avg_corr = corr_sum / n_valid + im = ax.imshow(avg_corr, cmap='RdBu_r', vmin=-1, vmax=1) + ax.set_xticks(range(len(metric_names))) + ax.set_yticks(range(len(metric_names))) + ax.set_xticklabels(metric_names, rotation=45, ha='right', fontsize=9) + ax.set_yticklabels(metric_names, fontsize=9) + for i in range(len(metric_names)): + for j in range(len(metric_names)): + ax.text(j, i, f'{avg_corr[i, j]:.2f}', ha='center', va='center', fontsize=8) + plt.colorbar(im, ax=ax) + ax.set_title('Average Metric Correlation', fontsize=12) + + # 4. 
Metric ranges across depth + ax = axes[1, 1] + for i, metric in enumerate(metric_names): + means = [] + stds = [] + for layer in layer_names: + values = layer_metrics[layer].get(metric, []) + if len(values) > 0: + means.append(np.mean(values)) + stds.append(np.std(values)) + else: + means.append(0) + stds.append(0) + + color = METRIC_COLORS.get(metric.lower(), f'C{i}') + x = np.arange(len(layer_names)) + ax.plot(x, means, 'o-', label=metric, color=color, linewidth=1.5, markersize=4) + ax.fill_between(x, np.array(means) - np.array(stds), + np.array(means) + np.array(stds), alpha=0.2, color=color) + + ax.set_xticks(range(len(layer_names))) + ax.set_xticklabels(layer_names, rotation=45, ha='right', fontsize=8) + ax.set_ylabel('Value (mean ± std)', fontsize=10) + ax.set_title('Metrics Across Depth', fontsize=12) + ax.legend(loc='best', fontsize=9) + ax.grid(True, alpha=0.3) + + fig.suptitle('Layer Metric Summary', fontsize=14, fontweight='bold') plt.tight_layout() if save_path: save_path = Path(save_path) save_path.parent.mkdir(parents=True, exist_ok=True) - plt.savefig(save_path, dpi=150, bbox_inches='tight') - logger.info(f"Saved pruning comparison to {save_path}") + fig.savefig(save_path, dpi=150, bbox_inches='tight') + logger.info(f"Saved layer metric summary to {save_path}") return fig + + +# Legacy function name kept for backward compatibility +def plot_vision_pruning_comparison( + results: Dict[str, Dict[float, Dict[str, float]]], + baseline_acc: float, + save_path: Optional[Path] = None, + figsize: Tuple[int, int] = (12, 6), +) -> Optional["plt.Figure"]: + """Alias for plot_pruning_comparison for backward compatibility.""" + return plot_pruning_comparison(results, baseline_acc, save_path, figsize) diff --git a/src/alignment/analysis/visualization/metric_plots.py b/src/alignment/analysis/visualization/metric_plots.py new file mode 100644 index 00000000..56fefbb2 --- /dev/null +++ b/src/alignment/analysis/visualization/metric_plots.py @@ -0,0 +1,661 @@ +""" 
+Metric visualization module for alignment analysis. + +Provides histogram, distribution, and comparison plots for metrics like: +- Rayleigh Quotient (RQ) +- Redundancy (Gaussian MI) +- Synergy (PID-based) +- Taylor Saliency +- Outlier Index + +These visualizations are used by both vision (ClusterAnalysisExperiment) and +LLM (LLMAlignmentExperiment) experiments. +""" + +import logging +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np + +logger = logging.getLogger(__name__) + +try: + import matplotlib.pyplot as plt + from matplotlib.figure import Figure + HAS_MPL = True +except ImportError: + HAS_MPL = False + Figure = None + +try: + import seaborn as sns + HAS_SEABORN = True +except (ImportError, AttributeError): + HAS_SEABORN = False + + +# Standard color scheme for metrics +METRIC_COLORS = { + "rq": "#3498db", # Blue + "rayleigh_quotient": "#3498db", + "redundancy": "#e74c3c", # Red + "synergy": "#2ecc71", # Green + "taylor": "#9b59b6", # Purple + "magnitude": "#f39c12", # Orange + "outlier_index": "#1abc9c", # Teal + "activation": "#34495e", # Dark gray +} + + +def plot_metric_histogram( + values: np.ndarray, + metric_name: str = "metric", + layer_name: str = "", + bins: int = 50, + log_scale: bool = False, + highlight_percentile: Optional[float] = None, + save_path: Optional[Union[str, Path]] = None, + figsize: Tuple[int, int] = (10, 6), + title: Optional[str] = None, +) -> Optional[Figure]: + """ + Plot histogram of metric values with optional percentile highlighting. 
+ + Args: + values: 1D array of metric values + metric_name: Name of the metric (for labels and colors) + layer_name: Layer name for title + bins: Number of histogram bins + log_scale: Whether to use log scale on x-axis + highlight_percentile: If set, highlight values above this percentile + save_path: Optional path to save figure + figsize: Figure size + title: Custom title (auto-generated if None) + + Returns: + Matplotlib Figure or None if matplotlib not available + """ + if not HAS_MPL: + logger.warning("matplotlib not available for plotting") + return None + + values = np.asarray(values).flatten() + if len(values) == 0: + logger.warning(f"Empty values array for {metric_name}") + return None + + fig, ax = plt.subplots(figsize=figsize) + + # Get color for this metric + color = METRIC_COLORS.get(metric_name.lower(), "#7f8c8d") + + # Handle log scale + plot_values = np.log10(np.clip(values, 1e-10, None)) if log_scale else values + + # Plot histogram + n, bin_edges, patches = ax.hist( + plot_values, bins=bins, color=color, alpha=0.7, + edgecolor='white', linewidth=0.5 + ) + + # Highlight percentile if requested + if highlight_percentile is not None: + threshold = np.percentile(values, highlight_percentile) + threshold_plot = np.log10(max(threshold, 1e-10)) if log_scale else threshold + ax.axvline(threshold_plot, color='red', linestyle='--', linewidth=2, + label=f'{highlight_percentile}th percentile: {threshold:.4f}') + + # Color bars above threshold + for patch, left_edge in zip(patches, bin_edges[:-1]): + if left_edge >= threshold_plot: + patch.set_facecolor('#e74c3c') + patch.set_alpha(0.9) + ax.legend() + + # Add statistics + stats_text = f"n={len(values):,}\nμ={np.mean(values):.4f}\nσ={np.std(values):.4f}" + stats_text += f"\nmed={np.median(values):.4f}" + ax.text(0.98, 0.98, stats_text, transform=ax.transAxes, + verticalalignment='top', horizontalalignment='right', + fontsize=9, bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5)) + + # Labels + xlabel = 
f"log10({metric_name})" if log_scale else metric_name.replace('_', ' ').title() + ax.set_xlabel(xlabel, fontsize=12) + ax.set_ylabel("Count", fontsize=12) + + if title is None: + title = f"{metric_name.replace('_', ' ').title()} Distribution" + if layer_name: + title += f" - {layer_name}" + ax.set_title(title, fontsize=14, fontweight='bold') + + ax.grid(True, alpha=0.3, axis='y') + plt.tight_layout() + + if save_path: + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(save_path, dpi=150, bbox_inches='tight') + logger.info(f"Saved histogram to {save_path}") + + return fig + + +def plot_metric_violin( + layer_metrics: Dict[str, np.ndarray], + metric_name: str = "metric", + save_path: Optional[Union[str, Path]] = None, + figsize: Tuple[int, int] = (14, 6), + max_points_per_layer: int = 1000, +) -> Optional[Figure]: + """ + Plot violin plot of metric values across layers. + + Args: + layer_metrics: Dict mapping layer_name -> metric values array + metric_name: Name of the metric + save_path: Optional path to save figure + figsize: Figure size + max_points_per_layer: Subsample if more points than this + + Returns: + Matplotlib Figure or None + """ + if not HAS_MPL: + return None + + if not layer_metrics: + logger.warning(f"Empty layer_metrics for {metric_name}") + return None + + # Prepare data + plot_data = [] + layer_names = [] + + for layer_name, values in layer_metrics.items(): + values = np.asarray(values).flatten() + if len(values) == 0: + continue + + # Subsample if needed + if len(values) > max_points_per_layer: + values = np.random.choice(values, max_points_per_layer, replace=False) + + plot_data.append(values) + layer_names.append(layer_name) + + if not plot_data: + return None + + fig, ax = plt.subplots(figsize=figsize) + color = METRIC_COLORS.get(metric_name.lower(), "#7f8c8d") + + if HAS_SEABORN: + # Use seaborn for nicer violins + import pandas as pd + df_rows = [] + for layer_name, values in zip(layer_names, 
plot_data): + for v in values: + df_rows.append({"Layer": layer_name, "Value": v}) + df = pd.DataFrame(df_rows) + sns.violinplot(data=df, x="Layer", y="Value", ax=ax, color=color, alpha=0.7) + else: + # Fallback to matplotlib violin + parts = ax.violinplot(plot_data, positions=range(len(layer_names)), showmeans=True) + for pc in parts['bodies']: + pc.set_facecolor(color) + pc.set_alpha(0.7) + + ax.set_xticks(range(len(layer_names))) + ax.set_xticklabels(layer_names, rotation=45, ha='right', fontsize=9) + ax.set_ylabel(metric_name.replace('_', ' ').title(), fontsize=12) + ax.set_title(f"{metric_name.replace('_', ' ').title()} Distribution Across Layers", + fontsize=14, fontweight='bold') + ax.grid(True, alpha=0.3, axis='y') + + plt.tight_layout() + + if save_path: + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(save_path, dpi=150, bbox_inches='tight') + logger.info(f"Saved violin plot to {save_path}") + + return fig + + +def plot_metric_boxplot( + layer_metrics: Dict[str, np.ndarray], + metric_name: str = "metric", + save_path: Optional[Union[str, Path]] = None, + figsize: Tuple[int, int] = (14, 6), + show_outliers: bool = True, +) -> Optional[Figure]: + """ + Plot boxplot of metric values across layers. 
+ + Args: + layer_metrics: Dict mapping layer_name -> metric values array + metric_name: Name of the metric + save_path: Optional path to save figure + figsize: Figure size + show_outliers: Whether to show outlier points + + Returns: + Matplotlib Figure or None + """ + if not HAS_MPL: + return None + + if not layer_metrics: + return None + + layer_names = list(layer_metrics.keys()) + data = [np.asarray(layer_metrics[l]).flatten() for l in layer_names] + + fig, ax = plt.subplots(figsize=figsize) + color = METRIC_COLORS.get(metric_name.lower(), "#7f8c8d") + + bp = ax.boxplot(data, patch_artist=True, showfliers=show_outliers) + + for patch in bp['boxes']: + patch.set_facecolor(color) + patch.set_alpha(0.7) + + ax.set_xticklabels(layer_names, rotation=45, ha='right', fontsize=9) + ax.set_ylabel(metric_name.replace('_', ' ').title(), fontsize=12) + ax.set_title(f"{metric_name.replace('_', ' ').title()} by Layer", + fontsize=14, fontweight='bold') + ax.grid(True, alpha=0.3, axis='y') + + plt.tight_layout() + + if save_path: + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(save_path, dpi=150, bbox_inches='tight') + logger.info(f"Saved boxplot to {save_path}") + + return fig + + +def plot_multi_metric_histogram( + metrics: Dict[str, np.ndarray], + layer_name: str = "", + bins: int = 30, + save_path: Optional[Union[str, Path]] = None, + figsize: Tuple[int, int] = (15, 5), +) -> Optional[Figure]: + """ + Plot histograms of multiple metrics side by side. 
+ + Args: + metrics: Dict mapping metric_name -> values array + layer_name: Layer name for title + bins: Number of histogram bins + save_path: Optional path to save figure + figsize: Figure size + + Returns: + Matplotlib Figure or None + """ + if not HAS_MPL: + return None + + if not metrics: + return None + + metric_names = list(metrics.keys()) + n_metrics = len(metric_names) + + fig, axes = plt.subplots(1, n_metrics, figsize=figsize) + if n_metrics == 1: + axes = [axes] + + for ax, metric_name in zip(axes, metric_names): + values = np.asarray(metrics[metric_name]).flatten() + if len(values) == 0: + continue + + color = METRIC_COLORS.get(metric_name.lower(), "#7f8c8d") + + ax.hist(values, bins=bins, color=color, alpha=0.7, + edgecolor='white', linewidth=0.5) + ax.set_xlabel(metric_name.replace('_', ' ').title(), fontsize=11) + ax.set_ylabel("Count", fontsize=11) + + # Add statistics + stats = f"μ={np.mean(values):.3f}\nσ={np.std(values):.3f}" + ax.text(0.98, 0.98, stats, transform=ax.transAxes, + verticalalignment='top', horizontalalignment='right', + fontsize=9, bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5)) + ax.grid(True, alpha=0.3, axis='y') + + title = "Metric Distributions" + if layer_name: + title += f" - {layer_name}" + fig.suptitle(title, fontsize=14, fontweight='bold') + plt.tight_layout() + + if save_path: + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(save_path, dpi=150, bbox_inches='tight') + logger.info(f"Saved multi-metric histogram to {save_path}") + + return fig + + +def plot_metric_scatter_matrix( + metrics: Dict[str, np.ndarray], + labels: Optional[np.ndarray] = None, + label_names: Optional[Dict[int, str]] = None, + layer_name: str = "", + save_path: Optional[Union[str, Path]] = None, + figsize: Tuple[int, int] = (12, 12), + max_points: int = 2000, +) -> Optional[Figure]: + """ + Plot scatter matrix of all metric pairs with optional cluster coloring. 
+ + Args: + metrics: Dict mapping metric_name -> values array + labels: Optional cluster labels for coloring + label_names: Optional mapping of label id -> name + layer_name: Layer name for title + save_path: Optional path to save figure + figsize: Figure size + max_points: Maximum points to plot (subsampled if more) + + Returns: + Matplotlib Figure or None + """ + if not HAS_MPL: + return None + + if not metrics or len(metrics) < 2: + return None + + metric_names = list(metrics.keys()) + n_metrics = len(metric_names) + + # Get sample size + n_samples = len(metrics[metric_names[0]]) + + # Subsample if needed + if n_samples > max_points: + idx = np.random.choice(n_samples, max_points, replace=False) + else: + idx = np.arange(n_samples) + + fig, axes = plt.subplots(n_metrics, n_metrics, figsize=figsize) + + # Cluster colors + cluster_colors = { + "critical": "#e74c3c", + "redundant": "#3498db", + "synergistic": "#2ecc71", + "background": "#95a5a6", + } + + for i, m1 in enumerate(metric_names): + for j, m2 in enumerate(metric_names): + ax = axes[i, j] + + if i == j: + # Diagonal: histogram + values = np.asarray(metrics[m1])[idx] + color = METRIC_COLORS.get(m1.lower(), "#7f8c8d") + ax.hist(values, bins=30, color=color, alpha=0.7) + ax.set_ylabel("Count" if j == 0 else "") + else: + # Off-diagonal: scatter + x = np.asarray(metrics[m2])[idx] + y = np.asarray(metrics[m1])[idx] + + if labels is not None: + labels_subset = np.asarray(labels)[idx] + unique_labels = np.unique(labels_subset) + for lbl in unique_labels: + mask = labels_subset == lbl + lbl_name = label_names.get(lbl, str(lbl)) if label_names else str(lbl) + color = cluster_colors.get(lbl_name.lower(), "#7f8c8d") + ax.scatter(x[mask], y[mask], c=color, alpha=0.5, s=10, label=lbl_name) + else: + ax.scatter(x, y, alpha=0.3, s=10, c='#34495e') + + # Labels only on edges + if i == n_metrics - 1: + ax.set_xlabel(m2.replace('_', ' ').title(), fontsize=10) + if j == 0: + ax.set_ylabel(m1.replace('_', ' ').title(), 
fontsize=10) + + ax.tick_params(labelsize=8) + + # Add legend to top-right plot if labels provided + if labels is not None and label_names: + axes[0, -1].legend(loc='upper right', fontsize=8) + + title = "Metric Scatter Matrix" + if layer_name: + title += f" - {layer_name}" + fig.suptitle(title, fontsize=14, fontweight='bold', y=1.02) + plt.tight_layout() + + if save_path: + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(save_path, dpi=150, bbox_inches='tight') + logger.info(f"Saved scatter matrix to {save_path}") + + return fig + + +def plot_metric_correlation_heatmap( + metrics: Dict[str, np.ndarray], + layer_name: str = "", + save_path: Optional[Union[str, Path]] = None, + figsize: Tuple[int, int] = (8, 6), +) -> Optional[Figure]: + """ + Plot correlation heatmap between metrics. + + Args: + metrics: Dict mapping metric_name -> values array + layer_name: Layer name for title + save_path: Optional path to save figure + figsize: Figure size + + Returns: + Matplotlib Figure or None + """ + if not HAS_MPL: + return None + + if not metrics or len(metrics) < 2: + return None + + metric_names = list(metrics.keys()) + n_metrics = len(metric_names) + + # Compute correlation matrix + data = np.column_stack([np.asarray(metrics[m]).flatten() for m in metric_names]) + corr = np.corrcoef(data.T) + + fig, ax = plt.subplots(figsize=figsize) + + if HAS_SEABORN: + sns.heatmap(corr, xticklabels=metric_names, yticklabels=metric_names, + annot=True, fmt='.2f', cmap='RdBu_r', center=0, + vmin=-1, vmax=1, ax=ax) + else: + im = ax.imshow(corr, cmap='RdBu_r', vmin=-1, vmax=1) + ax.set_xticks(range(n_metrics)) + ax.set_yticks(range(n_metrics)) + ax.set_xticklabels(metric_names, rotation=45, ha='right') + ax.set_yticklabels(metric_names) + + # Add annotations + for i in range(n_metrics): + for j in range(n_metrics): + ax.text(j, i, f'{corr[i, j]:.2f}', ha='center', va='center', fontsize=10) + + plt.colorbar(im, ax=ax, label='Correlation') + 
+ title = "Metric Correlations" + if layer_name: + title += f" - {layer_name}" + ax.set_title(title, fontsize=14, fontweight='bold') + plt.tight_layout() + + if save_path: + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(save_path, dpi=150, bbox_inches='tight') + logger.info(f"Saved correlation heatmap to {save_path}") + + return fig + + +def plot_layer_metric_heatmap( + layer_metrics: Dict[str, Dict[str, float]], + save_path: Optional[Union[str, Path]] = None, + figsize: Tuple[int, int] = (12, 8), + normalize_metrics: bool = True, +) -> Optional[Figure]: + """ + Plot heatmap of mean metric values across layers. + + Args: + layer_metrics: Dict mapping layer_name -> {metric_name -> mean_value} + save_path: Optional path to save figure + figsize: Figure size + normalize_metrics: Whether to normalize each metric to [0, 1] + + Returns: + Matplotlib Figure or None + """ + if not HAS_MPL: + return None + + if not layer_metrics: + return None + + layer_names = list(layer_metrics.keys()) + metric_names = list(layer_metrics[layer_names[0]].keys()) + + # Build matrix + matrix = np.zeros((len(layer_names), len(metric_names))) + for i, layer in enumerate(layer_names): + for j, metric in enumerate(metric_names): + matrix[i, j] = layer_metrics[layer].get(metric, 0) + + # Normalize if requested + if normalize_metrics: + for j in range(len(metric_names)): + col = matrix[:, j] + col_min, col_max = col.min(), col.max() + if col_max > col_min: + matrix[:, j] = (col - col_min) / (col_max - col_min) + + fig, ax = plt.subplots(figsize=figsize) + + if HAS_SEABORN: + sns.heatmap(matrix, xticklabels=metric_names, yticklabels=layer_names, + annot=True, fmt='.2f', cmap='YlOrRd', ax=ax) + else: + im = ax.imshow(matrix, cmap='YlOrRd', aspect='auto') + ax.set_xticks(range(len(metric_names))) + ax.set_yticks(range(len(layer_names))) + ax.set_xticklabels(metric_names, rotation=45, ha='right') + ax.set_yticklabels(layer_names) + plt.colorbar(im, 
ax=ax) + + ax.set_title("Mean Metrics Across Layers", fontsize=14, fontweight='bold') + ax.set_xlabel("Metric", fontsize=12) + ax.set_ylabel("Layer", fontsize=12) + plt.tight_layout() + + if save_path: + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(save_path, dpi=150, bbox_inches='tight') + logger.info(f"Saved layer metric heatmap to {save_path}") + + return fig + + +def plot_top_neurons_bar( + values: np.ndarray, + metric_name: str = "metric", + layer_name: str = "", + top_k: int = 20, + show_bottom: bool = True, + save_path: Optional[Union[str, Path]] = None, + figsize: Tuple[int, int] = (12, 8), +) -> Optional[Figure]: + """ + Plot bar chart of top-k (and optionally bottom-k) neurons by metric value. + + Args: + values: 1D array of metric values per neuron + metric_name: Name of the metric + layer_name: Layer name for title + top_k: Number of top/bottom neurons to show + show_bottom: Whether to also show bottom-k neurons + save_path: Optional path to save figure + figsize: Figure size + + Returns: + Matplotlib Figure or None + """ + if not HAS_MPL: + return None + + values = np.asarray(values).flatten() + if len(values) == 0: + return None + + sorted_idx = np.argsort(values) + top_idx = sorted_idx[-top_k:][::-1] + bottom_idx = sorted_idx[:top_k] + + n_plots = 2 if show_bottom else 1 + fig, axes = plt.subplots(n_plots, 1, figsize=figsize) + if n_plots == 1: + axes = [axes] + + # Top neurons + ax = axes[0] + ax.bar(range(top_k), values[top_idx], color='#2ecc71', alpha=0.8) + ax.set_xticks(range(top_k)) + ax.set_xticklabels([f"N{i}" for i in top_idx], rotation=45, ha='right', fontsize=9) + ax.set_ylabel(metric_name.replace('_', ' ').title(), fontsize=11) + ax.set_title(f"Top {top_k} Highest {metric_name.replace('_', ' ').title()}", fontsize=12) + ax.grid(True, alpha=0.3, axis='y') + + # Bottom neurons + if show_bottom: + ax = axes[1] + ax.bar(range(top_k), values[bottom_idx], color='#e74c3c', alpha=0.8) + 
ax.set_xticks(range(top_k)) + ax.set_xticklabels([f"N{i}" for i in bottom_idx], rotation=45, ha='right', fontsize=9) + ax.set_ylabel(metric_name.replace('_', ' ').title(), fontsize=11) + ax.set_title(f"Top {top_k} Lowest {metric_name.replace('_', ' ').title()}", fontsize=12) + ax.grid(True, alpha=0.3, axis='y') + + title = f"{metric_name.replace('_', ' ').title()} - Extreme Neurons" + if layer_name: + title += f" ({layer_name})" + fig.suptitle(title, fontsize=14, fontweight='bold', y=1.02) + plt.tight_layout() + + if save_path: + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(save_path, dpi=150, bbox_inches='tight') + logger.info(f"Saved top neurons bar chart to {save_path}") + + return fig diff --git a/src/alignment/analysis/visualization/pruning_plots.py b/src/alignment/analysis/visualization/pruning_plots.py index a4949cc6..a1296854 100644 --- a/src/alignment/analysis/visualization/pruning_plots.py +++ b/src/alignment/analysis/visualization/pruning_plots.py @@ -711,3 +711,363 @@ def _plot_summary_stats(self, ax, results): table[(0, i)].set_text_props(weight="bold", color="white") ax.set_title("Accuracy at Key Sparsity Levels", pad=20) + + +# ============================================================================== +# Unified Pruning Comparison Functions (for both Vision and LLM experiments) +# ============================================================================== + +# Standard colors for pruning methods +PRUNING_METHOD_COLORS = { + # Standard methods + "random": "#95a5a6", + "magnitude": "#e74c3c", + "taylor": "#3498db", + "gradient": "#ff7f0e", + "fisher": "#2ca02c", + # Cluster/composite methods + "composite": "#9b59b6", + "cluster_aware": "#2ecc71", + "rq_only": "#f39c12", + # LLM-specific methods + "wanda": "#1abc9c", + "sparsegpt": "#e91e63", + "scar": "#00bcd4", + # Low/high variants + "magnitude_low": "#e74c3c", + "magnitude_high": "#c0392b", + "gradient_low": "#ff7f0e", + "gradient_high": 
"#d35400", + "scar_low": "#00bcd4", + "scar_high": "#0097a7", + # Network slimming + "network_slimming": "#8e44ad", + "chip": "#16a085", +} + +PRUNING_METHOD_MARKERS = { + "random": "o", + "magnitude": "s", + "taylor": "^", + "gradient": "d", + "composite": "p", + "cluster_aware": "*", + "wanda": "v", + "sparsegpt": "<", + "scar": ">", +} + + +def plot_unified_pruning_comparison( + results: Dict[str, Dict[float, Dict[str, Any]]], + baseline_value: Optional[float] = None, + metric: str = "accuracy", + higher_is_better: bool = True, + title: str = "Pruning Method Comparison", + save_path: Optional[Union[str, Path]] = None, + figsize: Tuple[int, int] = (12, 7), + show_before_ft: bool = False, + x_as_percentage: bool = True, +) -> "plt.Figure": + """ + Unified pruning comparison plot that works for both vision and LLM experiments. + + Args: + results: Dict mapping method_name -> {sparsity -> {metric: value, ...}} + Supports both 'accuracy' (vision) and 'perplexity' (LLM) metrics. + Can include 'accuracy_before_ft', 'accuracy_after_ft', 'perplexity', etc. + baseline_value: Baseline (unpruned) metric value + metric: Which metric to plot ('accuracy', 'perplexity', 'loss', etc.) + higher_is_better: Whether higher metric values are better (True for accuracy, False for perplexity) + title: Plot title + save_path: Optional path to save figure + figsize: Figure size + show_before_ft: Whether to show before-fine-tuning values (dashed lines) + x_as_percentage: Whether to display x-axis as percentage + + Returns: + Matplotlib Figure + + Example: + >>> results = { + ... 'magnitude': {0.3: {'accuracy_after_ft': 0.85}, 0.5: {'accuracy_after_ft': 0.78}}, + ... 'cluster_aware': {0.3: {'accuracy_after_ft': 0.88}, 0.5: {'accuracy_after_ft': 0.82}}, + ... 
} + >>> plot_unified_pruning_comparison(results, baseline_value=0.92, metric='accuracy') + """ + fig, ax = plt.subplots(figsize=figsize) + + # Determine the actual metric key to use + # Support multiple naming conventions + metric_keys = { + 'accuracy': ['accuracy_after_ft', 'accuracy', 'acc'], + 'perplexity': ['perplexity', 'ppl'], + 'loss': ['loss', 'test_loss'], + } + + for method, method_results in results.items(): + if not method_results: + continue + + sparsities = sorted([s for s in method_results.keys() if isinstance(s, (int, float))]) + values = [] + values_before = [] + + for sparsity in sparsities: + data = method_results[sparsity] + if isinstance(data, dict): + # Try different key names + value = None + for key in metric_keys.get(metric, [metric]): + if key in data: + value = data[key] + break + + if value is None and 'error' not in data: + # Try any key containing the metric name + for k, v in data.items(): + if metric in k.lower() and isinstance(v, (int, float)): + value = v + break + + values.append(value) + + # Check for before-fine-tuning value + if show_before_ft: + before_key = f'{metric}_before_ft' if metric != 'accuracy' else 'accuracy_before_ft' + values_before.append(data.get(before_key)) + else: + values.append(data if isinstance(data, (int, float)) else None) + values_before.append(None) + + # Filter out None values + valid_data = [(s, v) for s, v in zip(sparsities, values) if v is not None] + if not valid_data: + continue + + valid_sparsities, valid_values = zip(*valid_data) + + # Convert to percentage for display + if metric == 'accuracy' and max(valid_values) <= 1.0: + valid_values = [v * 100 for v in valid_values] + + x_values = [s * 100 for s in valid_sparsities] if x_as_percentage else list(valid_sparsities) + + # Get style + color = PRUNING_METHOD_COLORS.get(method.lower(), None) + marker = PRUNING_METHOD_MARKERS.get(method.lower().split('_')[0], 'o') + label = method.replace('_', ' ').title() + + # Plot main line (after fine-tuning) 
+ ax.plot(x_values, valid_values, marker=marker, color=color, + label=label, linewidth=2.5, markersize=8) + + # Plot before fine-tuning (dashed) if requested + if show_before_ft: + valid_before = [(s * 100 if x_as_percentage else s, v) + for s, v in zip(sparsities, values_before) if v is not None] + if valid_before: + x_before, y_before = zip(*valid_before) + if metric == 'accuracy' and max(y_before) <= 1.0: + y_before = [v * 100 for v in y_before] + ax.plot(x_before, y_before, linestyle='--', color=color, + alpha=0.5, linewidth=1.5) + + # Add baseline + if baseline_value is not None: + baseline_display = baseline_value * 100 if metric == 'accuracy' and baseline_value <= 1.0 else baseline_value + ax.axhline(y=baseline_display, color='gray', linestyle='--', linewidth=1.5, + label=f'Unpruned ({baseline_display:.1f}{"%" if metric == "accuracy" else ""})') + + # Labels and formatting + xlabel = "Sparsity (%)" if x_as_percentage else "Sparsity" + ax.set_xlabel(xlabel, fontsize=12) + + ylabel = metric.replace('_', ' ').title() + if metric == 'accuracy': + ylabel += " (%)" + ax.set_ylabel(ylabel, fontsize=12) + + ax.set_title(title, fontsize=14, fontweight='bold') + ax.legend(loc='lower left' if higher_is_better else 'upper left', fontsize=10) + ax.grid(True, alpha=0.3) + + plt.tight_layout() + + if save_path: + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(save_path, dpi=300, bbox_inches='tight') + logger.info(f"Saved unified pruning comparison to {save_path}") + + return fig + + +def plot_pruning_accuracy_loss_grid( + results: Dict[str, Dict[float, Dict[str, float]]], + baseline_acc: Optional[float] = None, + baseline_loss: Optional[float] = None, + title: str = "Pruning Analysis", + save_path: Optional[Union[str, Path]] = None, + figsize: Tuple[int, int] = (14, 6), +) -> "plt.Figure": + """ + Plot pruning results showing both accuracy and loss in a grid. 
+ + Args: + results: Dict mapping method -> {sparsity -> {'accuracy': v, 'loss': v}} + baseline_acc: Baseline accuracy + baseline_loss: Baseline loss + title: Plot title + save_path: Optional path to save figure + figsize: Figure size + + Returns: + Matplotlib Figure + """ + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize) + + for method, method_results in results.items(): + if not method_results: + continue + + sparsities = sorted([s for s in method_results.keys() if isinstance(s, (int, float))]) + accuracies = [] + losses = [] + + for s in sparsities: + data = method_results[s] + if isinstance(data, dict): + acc = data.get('accuracy_after_ft') or data.get('accuracy') + loss = data.get('loss') or data.get('test_loss') + accuracies.append(acc * 100 if acc and acc <= 1.0 else acc) + losses.append(loss) + else: + accuracies.append(None) + losses.append(None) + + color = PRUNING_METHOD_COLORS.get(method.lower(), None) + marker = PRUNING_METHOD_MARKERS.get(method.lower().split('_')[0], 'o') + label = method.replace('_', ' ').title() + + # Plot accuracy + valid_acc = [(s * 100, a) for s, a in zip(sparsities, accuracies) if a is not None] + if valid_acc: + x, y = zip(*valid_acc) + ax1.plot(x, y, marker=marker, color=color, label=label, linewidth=2, markersize=7) + + # Plot loss + valid_loss = [(s * 100, l) for s, l in zip(sparsities, losses) if l is not None] + if valid_loss: + x, y = zip(*valid_loss) + ax2.plot(x, y, marker=marker, color=color, label=label, linewidth=2, markersize=7) + + # Baselines + if baseline_acc is not None: + baseline_acc_display = baseline_acc * 100 if baseline_acc <= 1.0 else baseline_acc + ax1.axhline(y=baseline_acc_display, color='gray', linestyle='--', + linewidth=1.5, label='Unpruned') + if baseline_loss is not None: + ax2.axhline(y=baseline_loss, color='gray', linestyle='--', + linewidth=1.5, label='Unpruned') + + # Format axes + ax1.set_xlabel("Sparsity (%)", fontsize=12) + ax1.set_ylabel("Accuracy (%)", fontsize=12) + 
ax1.set_title("Accuracy vs Sparsity", fontsize=13) + ax1.legend(loc='lower left', fontsize=9) + ax1.grid(True, alpha=0.3) + + ax2.set_xlabel("Sparsity (%)", fontsize=12) + ax2.set_ylabel("Loss", fontsize=12) + ax2.set_title("Loss vs Sparsity", fontsize=13) + ax2.legend(loc='upper left', fontsize=9) + ax2.grid(True, alpha=0.3) + + fig.suptitle(title, fontsize=14, fontweight='bold', y=1.02) + plt.tight_layout() + + if save_path: + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(save_path, dpi=300, bbox_inches='tight') + logger.info(f"Saved pruning accuracy/loss grid to {save_path}") + + return fig + + +def plot_pruning_recovery_chart( + results: Dict[str, Dict[float, Dict[str, float]]], + baseline_value: float, + metric: str = "accuracy", + title: str = "Accuracy Recovery After Pruning", + save_path: Optional[Union[str, Path]] = None, + figsize: Tuple[int, int] = (12, 6), +) -> "plt.Figure": + """ + Plot the percentage of original performance retained after pruning. 
+ + Args: + results: Dict mapping method -> {sparsity -> {metric: value}} + baseline_value: Baseline (unpruned) value + metric: Which metric to use + title: Plot title + save_path: Optional path to save figure + figsize: Figure size + + Returns: + Matplotlib Figure + """ + fig, ax = plt.subplots(figsize=figsize) + + for method, method_results in results.items(): + if not method_results: + continue + + sparsities = sorted([s for s in method_results.keys() if isinstance(s, (int, float))]) + recoveries = [] + + for s in sparsities: + data = method_results[s] + if isinstance(data, dict): + value = data.get(f'{metric}_after_ft') or data.get(metric) + if value is not None and baseline_value > 0: + recovery = (value / baseline_value) * 100 + recoveries.append(recovery) + else: + recoveries.append(None) + else: + recoveries.append(None) + + valid_data = [(s * 100, r) for s, r in zip(sparsities, recoveries) if r is not None] + if not valid_data: + continue + + x, y = zip(*valid_data) + + color = PRUNING_METHOD_COLORS.get(method.lower(), None) + marker = PRUNING_METHOD_MARKERS.get(method.lower().split('_')[0], 'o') + label = method.replace('_', ' ').title() + + ax.plot(x, y, marker=marker, color=color, label=label, linewidth=2.5, markersize=8) + + # Reference lines + ax.axhline(y=100, color='gray', linestyle='--', linewidth=1.5, label='100% (Unpruned)') + ax.axhline(y=90, color='green', linestyle=':', linewidth=1, alpha=0.7, label='90% Threshold') + + ax.set_xlabel("Sparsity (%)", fontsize=12) + ax.set_ylabel(f"{metric.title()} Recovery (%)", fontsize=12) + ax.set_title(title, fontsize=14, fontweight='bold') + ax.legend(loc='lower left', fontsize=10) + ax.grid(True, alpha=0.3) + ax.set_ylim([50, 105]) + + plt.tight_layout() + + if save_path: + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(save_path, dpi=300, bbox_inches='tight') + logger.info(f"Saved pruning recovery chart to {save_path}") + + return fig From 
169704618f747763ac35ff835b185c3aa978c7f0 Mon Sep 17 00:00:00 2001 From: Houman Safaai Date: Mon, 8 Dec 2025 12:53:26 -0500 Subject: [PATCH 03/12] correct import --- scripts/run_experiment.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/run_experiment.py b/scripts/run_experiment.py index 17efd3b8..6f7abe2e 100644 --- a/scripts/run_experiment.py +++ b/scripts/run_experiment.py @@ -18,6 +18,7 @@ from datetime import datetime from pathlib import Path +import torch import yaml # Add the project root and src directory to Python path From 453166e1deab1ae16006314f45093be0a8b0b842 Mon Sep 17 00:00:00 2001 From: Houman Safaai Date: Mon, 8 Dec 2025 15:49:24 -0500 Subject: [PATCH 04/12] refactor config/ update vision configs --- configs/README.md | 99 ++- configs/{paper => prune_llm}/README.md | 0 .../{paper => prune_llm}/llama2_7b_full.yaml | 0 configs/prune_llm/llama2_7b_unified.yaml | 387 ++++++++++ .../{paper => prune_llm}/llama3_8b_full.yaml | 0 configs/prune_llm/llama3_8b_unified.yaml | 451 ++++++++++++ .../{paper => prune_llm}/mistral_7b_full.yaml | 0 configs/prune_llm/mistral_7b_unified.yaml | 386 ++++++++++ .../{paper => prune_llm}/qwen2_7b_full.yaml | 4 +- configs/prune_llm/qwen2_7b_unified.yaml | 387 ++++++++++ configs/unified_template.yaml | 272 +++++++ .../README.md | 0 .../mobilenetv2_cifar10_full.yaml | 0 .../mobilenetv2_cifar10_unified.yaml | 392 ++++++++++ .../resnet18_cifar10_full.yaml | 0 .../resnet18_cifar10_unified.yaml | 401 ++++++++++ .../resnet50_imagenet100.yaml | 0 .../resnet50_imagenet100_unified.yaml | 409 +++++++++++ .../vgg16_cifar10_full.yaml | 0 .../vision_prune/vgg16_cifar10_unified.yaml | 373 ++++++++++ docs/source/developer_guide/extensibility.rst | 275 +++++++ docs/source/developer_guide/index.rst | 30 + .../{paper => prune_llm}/run_all_paper.sh | 0 .../{paper => prune_llm}/run_llama2_7b.sh | 0 .../{paper => prune_llm}/run_llama3_8b.sh | 7 +- .../{paper => prune_llm}/run_mistral_7b.sh | 0 .../{paper => prune_llm}/run_qwen2_7b.sh 
| 0 .../run_cluster_analysis_resnet18.sh | 4 +- .../run_cluster_analysis_resnet50.sh | 0 .../analysis/visualization/cluster_plots.py | 2 +- src/alignment/configs/__init__.py | 72 +- src/alignment/configs/config_loader.py | 385 ++++++++++ src/alignment/configs/unified_config.py | 683 ++++++++++++++++++ src/alignment/core/__init__.py | 170 ++++- src/alignment/core/protocols.py | 502 ++++++++++++- src/alignment/core/registry.py | 590 ++++++++++++++- 36 files changed, 6218 insertions(+), 63 deletions(-) rename configs/{paper => prune_llm}/README.md (100%) rename configs/{paper => prune_llm}/llama2_7b_full.yaml (100%) create mode 100644 configs/prune_llm/llama2_7b_unified.yaml rename configs/{paper => prune_llm}/llama3_8b_full.yaml (100%) create mode 100644 configs/prune_llm/llama3_8b_unified.yaml rename configs/{paper => prune_llm}/mistral_7b_full.yaml (100%) create mode 100644 configs/prune_llm/mistral_7b_unified.yaml rename configs/{paper => prune_llm}/qwen2_7b_full.yaml (100%) create mode 100644 configs/prune_llm/qwen2_7b_unified.yaml create mode 100644 configs/unified_template.yaml rename configs/{cluster_analysis => vision_prune}/README.md (100%) rename configs/{cluster_analysis => vision_prune}/mobilenetv2_cifar10_full.yaml (100%) create mode 100644 configs/vision_prune/mobilenetv2_cifar10_unified.yaml rename configs/{cluster_analysis => vision_prune}/resnet18_cifar10_full.yaml (100%) create mode 100644 configs/vision_prune/resnet18_cifar10_unified.yaml rename configs/{cluster_analysis => vision_prune}/resnet50_imagenet100.yaml (100%) create mode 100644 configs/vision_prune/resnet50_imagenet100_unified.yaml rename configs/{cluster_analysis => vision_prune}/vgg16_cifar10_full.yaml (100%) create mode 100644 configs/vision_prune/vgg16_cifar10_unified.yaml create mode 100644 docs/source/developer_guide/extensibility.rst create mode 100644 docs/source/developer_guide/index.rst rename slurm_jobs/{paper => prune_llm}/run_all_paper.sh (100%) rename slurm_jobs/{paper => 
prune_llm}/run_llama2_7b.sh (100%) rename slurm_jobs/{paper => prune_llm}/run_llama3_8b.sh (91%) rename slurm_jobs/{paper => prune_llm}/run_mistral_7b.sh (100%) rename slurm_jobs/{paper => prune_llm}/run_qwen2_7b.sh (100%) rename slurm_jobs/{ => prune_vision}/run_cluster_analysis_resnet18.sh (96%) rename slurm_jobs/{ => prune_vision}/run_cluster_analysis_resnet50.sh (100%) create mode 100644 src/alignment/configs/unified_config.py diff --git a/configs/README.md b/configs/README.md index 93c15154..b9908763 100644 --- a/configs/README.md +++ b/configs/README.md @@ -5,12 +5,16 @@ ``` configs/ ├── template.yaml # Complete template with all options -├── cluster_analysis/ # Cluster-based analysis configs +├── unified_template.yaml # Unified format template +├── vision_prune/ # Vision model pruning configs │ ├── resnet18_cifar10_full.yaml +│ ├── resnet18_cifar10_unified.yaml # Unified format version +│ ├── resnet50_imagenet100.yaml │ ├── vgg16_cifar10_full.yaml │ └── mobilenetv2_cifar10_full.yaml -├── paper/ # LLM paper experiment configs +├── prune_llm/ # LLM pruning configs │ ├── llama3_8b_full.yaml +│ ├── llama3_8b_unified.yaml # Unified format version │ ├── llama2_7b_full.yaml │ ├── mistral_7b_full.yaml │ └── qwen2_7b_full.yaml @@ -108,3 +112,94 @@ supernode: core_fraction: 0.01 protect_core: true ``` + +## Unified Configuration Format + +The framework supports a **unified configuration format** that works consistently +across both vision and LLM experiments. Files with `_unified.yaml` suffix use this format. 
+ +### Unified Metric Names + +| Unified Name | Vision Aliases | LLM Aliases | +|--------------|---------------|-------------| +| `rayleigh_quotient` | `rq`, `compute_rq` | `rayleigh_quotient` | +| `redundancy` | `compute_redundancy` | `gaussian_mi_analytic`, `average_redundancy` | +| `synergy` | `compute_synergy` | `synergy_gaussian_mmi` | +| `magnitude` | `weight_magnitude` | `activation_l2_norm` | +| `scar` | - | `scar_*` (LLM-specific) | + +### Unified Structure + +```yaml +experiment: + name: "my_experiment" + type: "cluster_analysis" # or "llm_alignment" + seed: 42 + device: "cuda" + output_dir: "./results/..." + +model: + name: "resnet18" # or "hf_causal_lm" + # Vision: num_classes, pretrained + # LLM: model_id, dtype, device_map + +dataset: + name: "cifar10" # or "wikitext" + batch_size: 128 + +calibration: + num_samples: 5000 # Vision: ~5000, LLM: ~128 + +metrics: + rayleigh_quotient: + enabled: true + redundancy: + enabled: true + synergy: + enabled: true + magnitude: + enabled: true + composite_weights: + rayleigh_quotient: 0.33 + redundancy: -0.33 + synergy: 0.33 + +clustering: # Vision + enabled: true + n_clusters: 4 + +supernode: # LLM (alternative to clustering) + enabled: true + score_metric: "scar_loss_proxy" + +pruning: + enabled: true + ratios: [0.1, 0.2, 0.3, 0.4, 0.5] + algorithms: [...] + selection_modes: ["low", "high"] + +visualization: + enabled: true + format: "png" + +output: + dir: "./results/..." + save_metrics: true +``` + +### Loading Unified Configs + +```python +from alignment.configs import load_unified_config + +# Works with both old and unified formats! 
+config = load_unified_config("configs/vision_prune/resnet18_cifar10_unified.yaml") + +# Access in a consistent way +print(config.experiment.name) +print(config.model.name) +print(config.pruning.ratios) + +# Validate +warnings = config.validate() +``` diff --git a/configs/paper/README.md b/configs/prune_llm/README.md similarity index 100% rename from configs/paper/README.md rename to configs/prune_llm/README.md diff --git a/configs/paper/llama2_7b_full.yaml b/configs/prune_llm/llama2_7b_full.yaml similarity index 100% rename from configs/paper/llama2_7b_full.yaml rename to configs/prune_llm/llama2_7b_full.yaml diff --git a/configs/prune_llm/llama2_7b_unified.yaml b/configs/prune_llm/llama2_7b_unified.yaml new file mode 100644 index 00000000..cee0950d --- /dev/null +++ b/configs/prune_llm/llama2_7b_unified.yaml @@ -0,0 +1,387 @@ +# ============================================================================= +# LLAMA-2-7B COMPREHENSIVE ANALYSIS - UNIFIED FORMAT +# ============================================================================= +# Same structure as LLaMA-3.1-8B for cross-model generalization results +# Llama-2 uses different FFN structure (no SwiGLU gate) +# +# Usage: python scripts/run_experiment.py --config configs/prune_llm/llama2_7b_unified.yaml +# Estimated runtime: ~4-6 hours on 1x A100 +# ============================================================================= + +# ----------------------------------------------------------------------------- +# EXPERIMENT +# ----------------------------------------------------------------------------- +experiment: + name: "llama2_7b_paper_results" + type: "llm_alignment" + seed: 42 + device: "cuda" + output_dir: "./results/paper/llama2_7b" + +# ----------------------------------------------------------------------------- +# MODEL +# ----------------------------------------------------------------------------- +model: + name: "hf_causal_lm" + model_id: "meta-llama/Llama-2-7b-hf" + dtype: "bfloat16" + 
device_map: "auto" + trust_remote_code: true + + tracked_layers: + - "model.model.layers.*.mlp.up_proj" + - "model.model.layers.*.mlp.gate_proj" + - "model.model.layers.*.mlp.down_proj" + +# ----------------------------------------------------------------------------- +# DATASET +# ----------------------------------------------------------------------------- +dataset: + name: "wikitext" + subset: "wikitext-2-raw-v1" + split: "train" + batch_size: 1 + num_workers: 0 + +# ----------------------------------------------------------------------------- +# CALIBRATION +# ----------------------------------------------------------------------------- +calibration: + num_samples: 128 + max_length: 2048 + batch_size: 4 + +# ----------------------------------------------------------------------------- +# METRICS +# ----------------------------------------------------------------------------- +# Note: supernode_protection_score, supernode_connectivity_score are computed +# by the supernode analysis pipeline, not as standalone metrics +metrics: + rayleigh_quotient: + enabled: true + relative: true + regularization: 1.0e-6 + + redundancy: + enabled: true + + magnitude: + enabled: true + + scar: + enabled: true + num_samples: 64 + max_length: 512 + + composite_weights: + rayleigh_quotient: 0.25 + redundancy: -0.25 + magnitude: 0.15 + scar_loss_proxy: 0.35 + +# ----------------------------------------------------------------------------- +# SUPERNODE +# ----------------------------------------------------------------------------- +supernode: + enabled: true + score_metric: "scar_loss_proxy" + core_fraction: 0.01 + halo_fraction: 0.10 + follower_fraction: 0.10 + protect_core: true + cross_layer_analysis: true + compare_by_connection: true + compute_metrics: + - "activation" + - "rayleigh_quotient" + - "mutual_information" + - "redundancy" + +# ----------------------------------------------------------------------------- +# CLUSTERING (disabled for LLM) +# 
----------------------------------------------------------------------------- +clustering: + enabled: false + n_clusters: 4 + features: ["rayleigh_quotient", "redundancy", "magnitude"] + +# ----------------------------------------------------------------------------- +# HALO ANALYSIS +# ----------------------------------------------------------------------------- +halo_analysis: + enabled: true + percentile: 90.0 + sample_pairs: 2000 + max_refs: 512 + max_pairs_per_group: 1000 + +# ----------------------------------------------------------------------------- +# CASCADE ANALYSIS +# ----------------------------------------------------------------------------- +cascade_analysis: + enabled: true + n_remove_per_group: 5 + damage_sample_fraction: 0.2 + +# ----------------------------------------------------------------------------- +# PRUNING +# ----------------------------------------------------------------------------- +pruning: + enabled: true + ratios: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + selection_modes: ["low", "high"] + distribution: "uniform" + structured: true + dependency_aware: true + + algorithms: + - "random" + - "magnitude" + - "rayleigh_quotient" + - "redundancy" + - "scar_loss_proxy" + - "supernode_protection_score" + - "supernode_connectivity_score" + - "generalized_importance" + - "cross_layer_importance" + - "wanda" + - "sparsegpt" + + scoring_methods: + - "random" + - "magnitude" + - "rayleigh_quotient" + - "redundancy" + - "scar_loss_proxy" + - "scar_taylor" + - "supernode_protection_score" + - "supernode_connectivity_score" + - "generalized_importance" + - "cross_layer_importance" + - "within_layer_importance" + - "wanda" + - "sparsegpt" + + fine_tune: + enabled: false + epochs: 1 + learning_rate: 1.0e-5 + +# ----------------------------------------------------------------------------- +# EVALUATION +# ----------------------------------------------------------------------------- +evaluation: + enabled: true + accuracy: false + loss: true + + 
perplexity_enabled: true + perplexity_datasets: + - name: "wikitext" + subset: "wikitext-2-raw-v1" + split: "test" + - name: "c4" + split: "validation" + max_samples: 1000 + + benchmarks_enabled: true + benchmark_tasks: + - name: "hellaswag" + num_fewshot: 0 + - name: "piqa" + num_fewshot: 0 + - name: "boolq" + num_fewshot: 0 + - name: "winogrande" + num_fewshot: 0 + - name: "arc_easy" + num_fewshot: 0 + - name: "arc_challenge" + num_fewshot: 0 + - name: "openbookqa" + num_fewshot: 0 + - name: "hellaswag" + num_fewshot: 5 + - name: "piqa" + num_fewshot: 5 + - name: "arc_challenge" + num_fewshot: 5 + - name: "mmlu" + num_fewshot: 5 + benchmark_fewshot: 0 + benchmark_batch_size: 8 + +# ----------------------------------------------------------------------------- +# VISUALIZATION +# ----------------------------------------------------------------------------- +visualization: + enabled: true + format: "pdf" + dpi: 300 + + histograms: true + violin_plots: true + correlation_heatmap: true + pruning_comparison: true + pruning_recovery: true + supernode_distribution: true + halo_structure: true + cross_layer_heatmap: true + + scatter_pairs: + - ["magnitude", "rayleigh_quotient"] + - ["magnitude", "scar_loss_proxy"] + - ["rayleigh_quotient", "scar_loss_proxy"] + - ["redundancy", "rayleigh_quotient"] + - ["scar_loss_proxy", "supernode_connectivity_score"] + - ["cross_layer_importance", "magnitude"] + +# ----------------------------------------------------------------------------- +# OUTPUT +# ----------------------------------------------------------------------------- +output: + dir: "./results/paper/llama2_7b" + save_metrics: true + save_figures: true + save_checkpoints: false + +# ----------------------------------------------------------------------------- +# EXTRA +# ----------------------------------------------------------------------------- +extra: + analysis: + # Llama-2 has 32 layers + layer_indices: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 31] + 
save_scores: true + generate_plots: true + metrics: + - "activation_l2_norm" + - "rayleigh_quotient" + - "gaussian_mi_analytic" + - "average_redundancy" + - "scar_activation_power" + - "scar_curvature" + - "scar_loss_proxy" + - "scar_taylor" + - "supernode_protection_score" + - "supernode_connectivity_score" + - "halo_redundancy" + - "cross_layer_importance" + - "within_layer_importance" + plots: + histograms: true + scatter_plots: true + pruning_curves: true + redundancy_heatmaps: true + scatter_pairs: + - ["activation_l2_norm", "rayleigh_quotient"] + - ["scar_activation_power", "scar_loss_proxy"] + - ["rayleigh_quotient", "scar_loss_proxy"] + - ["average_redundancy", "rayleigh_quotient"] + - ["scar_loss_proxy", "supernode_connectivity_score"] + - ["cross_layer_importance", "activation_l2_norm"] + + supernode_robustness: + enabled: true + supernode_fraction: 0.01 + num_bootstrap_samples: 10 + batch_size: 32 + max_samples: 256 + metrics: + - "scar_activation_power" + - "scar_loss_proxy" + - "scar_taylor" + - "rayleigh_quotient" + - "gaussian_mi_analytic" + - "activation_l2_norm" + target_layers: null + + supernode_summary: + enabled: true + outlier_analysis: true + + halo_analysis: + enabled: true + supernode_fraction: 0.01 + halo_fraction: 0.10 + num_samples: 8 + max_length: 256 + sample_pairs: 2000 + max_samples: 2000 + max_pairs_per_group: 1000 + plots: + depth_comparison: true + histograms: true + heatmaps: true + comprehensive: true + + cross_layer: + enabled: true + max_refs: 512 + rq_weight: 0.25 + downstream_weight: 0.35 + within_redundancy_weight: 0.25 + activation_weight: 0.15 + normalize: true + + generalized_importance: + enabled: true + neighborhood_fraction: 0.10 + propagation_weight: 0.3 + redundancy_penalty: 0.5 + num_samples: 8 + max_length: 256 + + visualization: + save_plots: true + format: "pdf" + dpi: 300 + style: "seaborn-v0_8-paper" + histograms: true + scatter_plots: true + heatmaps: true + supernode_distribution: + enabled: true + 
plot_loss_proxy_histogram: true + plot_concentration_by_layer: true + highlight_top_percent: [1, 5, 10] + halo_structure: + enabled: true + plot_redundancy_by_depth: true + plot_protection_vs_connection: true + plot_within_vs_cross_group: true + cross_layer: + enabled: true + plot_downstream_importance_by_layer: true + plot_importance_vs_redundancy: true + plot_efficiency_heatmap: true + pruning_curves: + enabled: true + plot_sparsity_vs_perplexity: true + plot_sparsity_vs_accuracy: true + metrics_to_compare: + - "supernode_connectivity_score" + - "cross_layer_importance" + - "generalized_importance" + - "scar_loss_proxy" + - "wanda" + - "sparsegpt" + - "activation_l2_norm" + - "random" + supernode_robustness: + enabled: true + jaccard_heatmap: true + spearman_heatmap: true + bootstrap_stability: true + consistency_bars: true + scatter_pairs: + - ["activation_l2_norm", "scar_loss_proxy"] + - ["scar_loss_proxy", "supernode_connectivity_score"] + - ["cross_layer_importance", "activation_l2_norm"] + - ["generalized_importance", "scar_loss_proxy"] + + do_scar_metrics: true + do_directed_redundancy: true + do_connectivity_pruning: true + do_halo_analysis: true + do_generalized_importance: true diff --git a/configs/paper/llama3_8b_full.yaml b/configs/prune_llm/llama3_8b_full.yaml similarity index 100% rename from configs/paper/llama3_8b_full.yaml rename to configs/prune_llm/llama3_8b_full.yaml diff --git a/configs/prune_llm/llama3_8b_unified.yaml b/configs/prune_llm/llama3_8b_unified.yaml new file mode 100644 index 00000000..b421ae2e --- /dev/null +++ b/configs/prune_llm/llama3_8b_unified.yaml @@ -0,0 +1,451 @@ +# ============================================================================= +# LLAMA-3.1-8B COMPREHENSIVE ANALYSIS - UNIFIED FORMAT +# ============================================================================= +# This is the same config as prune_llm/llama3_8b_full.yaml +# but converted to the unified format. 
+# +# Key features of unified format: +# - Consistent section names across vision/LLM experiments +# - Unified metric naming (redundancy instead of gaussian_mi_analytic) +# - All experiment-specific settings in `extra:` section +# - Same pruning/evaluation/visualization structure +# +# Usage: python scripts/run_experiment.py --config configs/unified/llama3_8b_unified.yaml +# Estimated runtime: ~6-8 hours on 1x A100 +# ============================================================================= + +# ----------------------------------------------------------------------------- +# EXPERIMENT +# ----------------------------------------------------------------------------- +experiment: + name: "llama3_8b_paper_results" + type: "llm_alignment" + seed: 42 + device: "cuda" + output_dir: "./results/paper/llama3_8b" + +# ----------------------------------------------------------------------------- +# MODEL +# ----------------------------------------------------------------------------- +model: + name: "hf_causal_lm" + model_id: "meta-llama/Llama-3.1-8B" + dtype: "bfloat16" + device_map: "auto" + trust_remote_code: true + + tracked_layers: + - "model.model.layers.*.mlp.up_proj" + - "model.model.layers.*.mlp.gate_proj" + - "model.model.layers.*.mlp.down_proj" + +# ----------------------------------------------------------------------------- +# DATASET +# ----------------------------------------------------------------------------- +dataset: + name: "wikitext" + subset: "wikitext-2-raw-v1" + split: "train" + batch_size: 1 + num_workers: 0 + +# ----------------------------------------------------------------------------- +# CALIBRATION +# ----------------------------------------------------------------------------- +calibration: + num_samples: 128 + max_length: 2048 + batch_size: 4 + +# ----------------------------------------------------------------------------- +# METRICS +# ----------------------------------------------------------------------------- +# Unified naming with 
LLM-specific extensions +# Core metrics (same names as vision): +# rayleigh_quotient, redundancy, synergy, magnitude +# LLM-specific: +# scar (activation_power, curvature, loss_proxy, taylor) +# Note: supernode_protection_score, supernode_connectivity_score are computed +# by the supernode analysis pipeline, not as standalone metrics +# ----------------------------------------------------------------------------- +metrics: + rayleigh_quotient: + enabled: true + relative: true + regularization: 1.0e-6 + + redundancy: + enabled: true + # LLM uses: gaussian_mi_analytic, average_redundancy + + magnitude: + enabled: true + # LLM uses: activation_l2_norm + + # LLM-specific SCAR metrics + scar: + enabled: true + num_samples: 64 + max_length: 512 + # Computes: scar_activation_power, scar_curvature, scar_loss_proxy, scar_taylor + + composite_weights: + rayleigh_quotient: 0.25 + redundancy: -0.25 + magnitude: 0.15 + scar_loss_proxy: 0.35 + +# ----------------------------------------------------------------------------- +# SUPERNODE (LLM outlier detection - alternative to clustering) +# ----------------------------------------------------------------------------- +supernode: + enabled: true + score_metric: "scar_loss_proxy" + core_fraction: 0.01 + halo_fraction: 0.10 + follower_fraction: 0.10 + protect_core: true + cross_layer_analysis: true + compare_by_connection: true + compute_metrics: + - "activation" + - "rayleigh_quotient" + - "mutual_information" + - "redundancy" + +# ----------------------------------------------------------------------------- +# CLUSTERING (disabled for LLM, uses supernode instead) +# ----------------------------------------------------------------------------- +clustering: + enabled: false + n_clusters: 4 + features: ["rayleigh_quotient", "redundancy", "magnitude"] + +# ----------------------------------------------------------------------------- +# HALO ANALYSIS +# ----------------------------------------------------------------------------- 
+halo_analysis: + enabled: true + percentile: 90.0 + sample_pairs: 2000 + max_refs: 512 + max_pairs_per_group: 1000 + +# ----------------------------------------------------------------------------- +# CASCADE ANALYSIS +# ----------------------------------------------------------------------------- +cascade_analysis: + enabled: true + n_remove_per_group: 5 + damage_sample_fraction: 0.2 + +# ----------------------------------------------------------------------------- +# PRUNING +# ----------------------------------------------------------------------------- +pruning: + enabled: true + ratios: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + selection_modes: ["low", "high"] + distribution: "uniform" + structured: true + dependency_aware: true + + # Algorithms to test (unified + LLM-specific) + algorithms: + # Core (same as vision) + - "random" + - "magnitude" + - "rayleigh_quotient" + - "redundancy" + + # LLM-specific + - "scar_loss_proxy" + - "supernode_protection_score" + - "supernode_connectivity_score" + - "generalized_importance" + - "cross_layer_importance" + + # SOTA baselines + - "wanda" + - "sparsegpt" + + scoring_methods: + - "random" + - "magnitude" + - "rayleigh_quotient" + - "redundancy" + - "scar_loss_proxy" + - "scar_taylor" + - "supernode_protection_score" + - "supernode_connectivity_score" + - "generalized_importance" + - "cross_layer_importance" + - "within_layer_importance" + - "wanda" + - "sparsegpt" + + fine_tune: + enabled: false + epochs: 1 + learning_rate: 1.0e-5 + +# ----------------------------------------------------------------------------- +# EVALUATION +# ----------------------------------------------------------------------------- +evaluation: + enabled: true + accuracy: false # Not applicable to LLM + loss: true + + # LLM-specific: Perplexity + perplexity_enabled: true + perplexity_datasets: + - name: "wikitext" + subset: "wikitext-2-raw-v1" + split: "test" + - name: "c4" + split: "validation" + max_samples: 1000 + + # LLM-specific: Benchmarks + 
benchmarks_enabled: true + benchmark_tasks: + - name: "hellaswag" + num_fewshot: 0 + - name: "piqa" + num_fewshot: 0 + - name: "boolq" + num_fewshot: 0 + - name: "winogrande" + num_fewshot: 0 + - name: "arc_easy" + num_fewshot: 0 + - name: "arc_challenge" + num_fewshot: 0 + - name: "openbookqa" + num_fewshot: 0 + - name: "hellaswag" + num_fewshot: 5 + - name: "piqa" + num_fewshot: 5 + - name: "arc_challenge" + num_fewshot: 5 + - name: "mmlu" + num_fewshot: 5 + benchmark_fewshot: 0 + benchmark_batch_size: 8 + +# ----------------------------------------------------------------------------- +# VISUALIZATION +# ----------------------------------------------------------------------------- +visualization: + enabled: true + format: "pdf" + dpi: 300 + + # Plot types (same as vision) + histograms: true + violin_plots: true + correlation_heatmap: true + pruning_comparison: true + pruning_recovery: true + + # LLM-specific plots + supernode_distribution: true + halo_structure: true + cross_layer_heatmap: true + + scatter_pairs: + - ["magnitude", "rayleigh_quotient"] + - ["magnitude", "scar_loss_proxy"] + - ["rayleigh_quotient", "scar_loss_proxy"] + - ["redundancy", "rayleigh_quotient"] + - ["scar_loss_proxy", "supernode_connectivity_score"] + - ["cross_layer_importance", "magnitude"] + +# ----------------------------------------------------------------------------- +# OUTPUT +# ----------------------------------------------------------------------------- +output: + dir: "./results/paper/llama3_8b" + save_metrics: true + save_figures: true + save_checkpoints: false + +# ----------------------------------------------------------------------------- +# EXTRA (LLM-specific settings not in unified schema) +# ----------------------------------------------------------------------------- +extra: + # Analysis configuration (same as original llama3_8b_full.yaml) + analysis: + layer_indices: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 31] + save_scores: true + 
generate_plots: true + metrics: + # Standard alignment metrics + - "activation_l2_norm" + - "rayleigh_quotient" + - "gaussian_mi_analytic" + - "average_redundancy" + # SCAR metrics + - "scar_activation_power" + - "scar_curvature" + - "scar_loss_proxy" + - "scar_taylor" + # Supernode-aware metrics + - "supernode_protection_score" + - "supernode_connectivity_score" + # Extended metrics + - "halo_redundancy" + - "multi_supernode" + - "cross_layer_redundancy" + - "cross_layer_importance" + - "within_layer_importance" + plots: + histograms: true + scatter_plots: true + pruning_curves: true + redundancy_heatmaps: true + scatter_pairs: + # Alignment vs magnitude + - ["activation_l2_norm", "rayleigh_quotient"] + - ["activation_l2_norm", "gaussian_mi_analytic"] + # SCAR relationships + - ["scar_activation_power", "scar_loss_proxy"] + - ["scar_activation_power", "scar_curvature"] + - ["scar_taylor", "scar_loss_proxy"] + # Alignment vs SCAR + - ["rayleigh_quotient", "scar_loss_proxy"] + - ["gaussian_mi_analytic", "scar_loss_proxy"] + # Redundancy relationships + - ["average_redundancy", "rayleigh_quotient"] + - ["average_redundancy", "activation_l2_norm"] + # Cross-layer + - ["scar_loss_proxy", "supernode_connectivity_score"] + - ["cross_layer_importance", "activation_l2_norm"] + + # Supernode robustness analysis + supernode_robustness: + enabled: true + supernode_fraction: 0.01 + num_bootstrap_samples: 10 + batch_size: 32 + max_samples: 256 + metrics: + - "scar_activation_power" + - "scar_loss_proxy" + - "scar_taylor" + - "rayleigh_quotient" + - "gaussian_mi_analytic" + - "activation_l2_norm" + target_layers: null + + # Supernode summary analysis + supernode_summary: + enabled: true + outlier_analysis: true + + # Halo analysis (detailed settings) + halo_analysis: + enabled: true + supernode_fraction: 0.01 + halo_fraction: 0.10 + num_samples: 8 + max_length: 256 + sample_pairs: 2000 + max_samples: 2000 + max_pairs_per_group: 1000 + plots: + depth_comparison: true + 
histograms: true + heatmaps: true + comprehensive: true + + # Multi-supernode clustering + multi_supernode: + enabled: true + supernode_fraction: 0.05 + n_clusters: 4 + halo_fraction: 0.10 + clustering_features: "weights" + + # Cross-layer importance weights + cross_layer: + enabled: true + max_refs: 512 + rq_weight: 0.25 + downstream_weight: 0.35 + within_redundancy_weight: 0.25 + activation_weight: 0.15 + normalize: true + + # Generalized importance + generalized_importance: + enabled: true + neighborhood_fraction: 0.10 + propagation_weight: 0.3 + redundancy_penalty: 0.5 + num_samples: 8 + max_length: 256 + + # Visualization for paper figures (detailed) + visualization: + save_plots: true + format: "pdf" + dpi: 300 + style: "seaborn-v0_8-paper" + histograms: true + scatter_plots: true + heatmaps: true + # Figure 1: Supernode distribution + supernode_distribution: + enabled: true + plot_loss_proxy_histogram: true + plot_concentration_by_layer: true + highlight_top_percent: [1, 5, 10] + # Figure 2: Halo redundancy structure + halo_structure: + enabled: true + plot_redundancy_by_depth: true + plot_protection_vs_connection: true + plot_within_vs_cross_group: true + # Figure 3: Cross-layer importance + cross_layer: + enabled: true + plot_downstream_importance_by_layer: true + plot_importance_vs_redundancy: true + plot_efficiency_heatmap: true + # Figure 4: Pruning comparison curves + pruning_curves: + enabled: true + plot_sparsity_vs_perplexity: true + plot_sparsity_vs_accuracy: true + metrics_to_compare: + - "supernode_connectivity_score" + - "cross_layer_importance" + - "generalized_importance" + - "scar_loss_proxy" + - "wanda" + - "sparsegpt" + - "activation_l2_norm" + - "random" + # Supernode robustness plots + supernode_robustness: + enabled: true + jaccard_heatmap: true + spearman_heatmap: true + bootstrap_stability: true + consistency_bars: true + scatter_pairs: + - ["activation_l2_norm", "scar_loss_proxy"] + - ["scar_loss_proxy", 
"supernode_connectivity_score"] + - ["downstream_importance", "within_layer_redundancy"] + - ["cross_layer_importance", "activation_l2_norm"] + - ["generalized_importance", "scar_loss_proxy"] + + # Advanced analysis flags + do_scar_metrics: true + do_directed_redundancy: true + do_connectivity_pruning: true + do_halo_analysis: true + do_generalized_importance: true diff --git a/configs/paper/mistral_7b_full.yaml b/configs/prune_llm/mistral_7b_full.yaml similarity index 100% rename from configs/paper/mistral_7b_full.yaml rename to configs/prune_llm/mistral_7b_full.yaml diff --git a/configs/prune_llm/mistral_7b_unified.yaml b/configs/prune_llm/mistral_7b_unified.yaml new file mode 100644 index 00000000..735bbc6a --- /dev/null +++ b/configs/prune_llm/mistral_7b_unified.yaml @@ -0,0 +1,386 @@ +# ============================================================================= +# MISTRAL-7B COMPREHENSIVE ANALYSIS - UNIFIED FORMAT +# ============================================================================= +# Same structure as LLaMA-3.1-8B for cross-model generalization results +# +# Usage: python scripts/run_experiment.py --config configs/prune_llm/mistral_7b_unified.yaml +# Estimated runtime: ~4-6 hours on 1x A100 +# ============================================================================= + +# ----------------------------------------------------------------------------- +# EXPERIMENT +# ----------------------------------------------------------------------------- +experiment: + name: "mistral_7b_paper_results" + type: "llm_alignment" + seed: 42 + device: "cuda" + output_dir: "./results/paper/mistral_7b" + +# ----------------------------------------------------------------------------- +# MODEL +# ----------------------------------------------------------------------------- +model: + name: "hf_causal_lm" + model_id: "mistralai/Mistral-7B-v0.1" + dtype: "bfloat16" + device_map: "auto" + trust_remote_code: true + + tracked_layers: + - 
"model.model.layers.*.mlp.up_proj" + - "model.model.layers.*.mlp.gate_proj" + - "model.model.layers.*.mlp.down_proj" + +# ----------------------------------------------------------------------------- +# DATASET +# ----------------------------------------------------------------------------- +dataset: + name: "wikitext" + subset: "wikitext-2-raw-v1" + split: "train" + batch_size: 1 + num_workers: 0 + +# ----------------------------------------------------------------------------- +# CALIBRATION +# ----------------------------------------------------------------------------- +calibration: + num_samples: 128 + max_length: 2048 + batch_size: 4 + +# ----------------------------------------------------------------------------- +# METRICS +# ----------------------------------------------------------------------------- +# Note: supernode_protection_score, supernode_connectivity_score are computed +# by the supernode analysis pipeline, not as standalone metrics +metrics: + rayleigh_quotient: + enabled: true + relative: true + regularization: 1.0e-6 + + redundancy: + enabled: true + + magnitude: + enabled: true + + scar: + enabled: true + num_samples: 64 + max_length: 512 + + composite_weights: + rayleigh_quotient: 0.25 + redundancy: -0.25 + magnitude: 0.15 + scar_loss_proxy: 0.35 + +# ----------------------------------------------------------------------------- +# SUPERNODE +# ----------------------------------------------------------------------------- +supernode: + enabled: true + score_metric: "scar_loss_proxy" + core_fraction: 0.01 + halo_fraction: 0.10 + follower_fraction: 0.10 + protect_core: true + cross_layer_analysis: true + compare_by_connection: true + compute_metrics: + - "activation" + - "rayleigh_quotient" + - "mutual_information" + - "redundancy" + +# ----------------------------------------------------------------------------- +# CLUSTERING (disabled for LLM) +# ----------------------------------------------------------------------------- +clustering: + 
enabled: false + n_clusters: 4 + features: ["rayleigh_quotient", "redundancy", "magnitude"] + +# ----------------------------------------------------------------------------- +# HALO ANALYSIS +# ----------------------------------------------------------------------------- +halo_analysis: + enabled: true + percentile: 90.0 + sample_pairs: 2000 + max_refs: 512 + max_pairs_per_group: 1000 + +# ----------------------------------------------------------------------------- +# CASCADE ANALYSIS +# ----------------------------------------------------------------------------- +cascade_analysis: + enabled: true + n_remove_per_group: 5 + damage_sample_fraction: 0.2 + +# ----------------------------------------------------------------------------- +# PRUNING +# ----------------------------------------------------------------------------- +pruning: + enabled: true + ratios: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + selection_modes: ["low", "high"] + distribution: "uniform" + structured: true + dependency_aware: true + + algorithms: + - "random" + - "magnitude" + - "rayleigh_quotient" + - "redundancy" + - "scar_loss_proxy" + - "supernode_protection_score" + - "supernode_connectivity_score" + - "generalized_importance" + - "cross_layer_importance" + - "wanda" + - "sparsegpt" + + scoring_methods: + - "random" + - "magnitude" + - "rayleigh_quotient" + - "redundancy" + - "scar_loss_proxy" + - "scar_taylor" + - "supernode_protection_score" + - "supernode_connectivity_score" + - "generalized_importance" + - "cross_layer_importance" + - "within_layer_importance" + - "wanda" + - "sparsegpt" + + fine_tune: + enabled: false + epochs: 1 + learning_rate: 1.0e-5 + +# ----------------------------------------------------------------------------- +# EVALUATION +# ----------------------------------------------------------------------------- +evaluation: + enabled: true + accuracy: false + loss: true + + perplexity_enabled: true + perplexity_datasets: + - name: "wikitext" + subset: "wikitext-2-raw-v1" 
+ split: "test" + - name: "c4" + split: "validation" + max_samples: 1000 + + benchmarks_enabled: true + benchmark_tasks: + - name: "hellaswag" + num_fewshot: 0 + - name: "piqa" + num_fewshot: 0 + - name: "boolq" + num_fewshot: 0 + - name: "winogrande" + num_fewshot: 0 + - name: "arc_easy" + num_fewshot: 0 + - name: "arc_challenge" + num_fewshot: 0 + - name: "openbookqa" + num_fewshot: 0 + - name: "hellaswag" + num_fewshot: 5 + - name: "piqa" + num_fewshot: 5 + - name: "arc_challenge" + num_fewshot: 5 + - name: "mmlu" + num_fewshot: 5 + benchmark_fewshot: 0 + benchmark_batch_size: 8 + +# ----------------------------------------------------------------------------- +# VISUALIZATION +# ----------------------------------------------------------------------------- +visualization: + enabled: true + format: "pdf" + dpi: 300 + + histograms: true + violin_plots: true + correlation_heatmap: true + pruning_comparison: true + pruning_recovery: true + supernode_distribution: true + halo_structure: true + cross_layer_heatmap: true + + scatter_pairs: + - ["magnitude", "rayleigh_quotient"] + - ["magnitude", "scar_loss_proxy"] + - ["rayleigh_quotient", "scar_loss_proxy"] + - ["redundancy", "rayleigh_quotient"] + - ["scar_loss_proxy", "supernode_connectivity_score"] + - ["cross_layer_importance", "magnitude"] + +# ----------------------------------------------------------------------------- +# OUTPUT +# ----------------------------------------------------------------------------- +output: + dir: "./results/paper/mistral_7b" + save_metrics: true + save_figures: true + save_checkpoints: false + +# ----------------------------------------------------------------------------- +# EXTRA +# ----------------------------------------------------------------------------- +extra: + analysis: + # Mistral has 32 layers + layer_indices: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 31] + save_scores: true + generate_plots: true + metrics: + - "activation_l2_norm" + - 
"rayleigh_quotient" + - "gaussian_mi_analytic" + - "average_redundancy" + - "scar_activation_power" + - "scar_curvature" + - "scar_loss_proxy" + - "scar_taylor" + - "supernode_protection_score" + - "supernode_connectivity_score" + - "halo_redundancy" + - "cross_layer_importance" + - "within_layer_importance" + plots: + histograms: true + scatter_plots: true + pruning_curves: true + redundancy_heatmaps: true + scatter_pairs: + - ["activation_l2_norm", "rayleigh_quotient"] + - ["scar_activation_power", "scar_loss_proxy"] + - ["rayleigh_quotient", "scar_loss_proxy"] + - ["average_redundancy", "rayleigh_quotient"] + - ["scar_loss_proxy", "supernode_connectivity_score"] + - ["cross_layer_importance", "activation_l2_norm"] + + supernode_robustness: + enabled: true + supernode_fraction: 0.01 + num_bootstrap_samples: 10 + batch_size: 32 + max_samples: 256 + metrics: + - "scar_activation_power" + - "scar_loss_proxy" + - "scar_taylor" + - "rayleigh_quotient" + - "gaussian_mi_analytic" + - "activation_l2_norm" + target_layers: null + + supernode_summary: + enabled: true + outlier_analysis: true + + halo_analysis: + enabled: true + supernode_fraction: 0.01 + halo_fraction: 0.10 + num_samples: 8 + max_length: 256 + sample_pairs: 2000 + max_samples: 2000 + max_pairs_per_group: 1000 + plots: + depth_comparison: true + histograms: true + heatmaps: true + comprehensive: true + + cross_layer: + enabled: true + max_refs: 512 + rq_weight: 0.25 + downstream_weight: 0.35 + within_redundancy_weight: 0.25 + activation_weight: 0.15 + normalize: true + + generalized_importance: + enabled: true + neighborhood_fraction: 0.10 + propagation_weight: 0.3 + redundancy_penalty: 0.5 + num_samples: 8 + max_length: 256 + + visualization: + save_plots: true + format: "pdf" + dpi: 300 + style: "seaborn-v0_8-paper" + histograms: true + scatter_plots: true + heatmaps: true + supernode_distribution: + enabled: true + plot_loss_proxy_histogram: true + plot_concentration_by_layer: true + 
highlight_top_percent: [1, 5, 10] + halo_structure: + enabled: true + plot_redundancy_by_depth: true + plot_protection_vs_connection: true + plot_within_vs_cross_group: true + cross_layer: + enabled: true + plot_downstream_importance_by_layer: true + plot_importance_vs_redundancy: true + plot_efficiency_heatmap: true + pruning_curves: + enabled: true + plot_sparsity_vs_perplexity: true + plot_sparsity_vs_accuracy: true + metrics_to_compare: + - "supernode_connectivity_score" + - "cross_layer_importance" + - "generalized_importance" + - "scar_loss_proxy" + - "wanda" + - "sparsegpt" + - "activation_l2_norm" + - "random" + supernode_robustness: + enabled: true + jaccard_heatmap: true + spearman_heatmap: true + bootstrap_stability: true + consistency_bars: true + scatter_pairs: + - ["activation_l2_norm", "scar_loss_proxy"] + - ["scar_loss_proxy", "supernode_connectivity_score"] + - ["cross_layer_importance", "activation_l2_norm"] + - ["generalized_importance", "scar_loss_proxy"] + + do_scar_metrics: true + do_directed_redundancy: true + do_connectivity_pruning: true + do_halo_analysis: true + do_generalized_importance: true diff --git a/configs/paper/qwen2_7b_full.yaml b/configs/prune_llm/qwen2_7b_full.yaml similarity index 100% rename from configs/paper/qwen2_7b_full.yaml rename to configs/prune_llm/qwen2_7b_full.yaml index aefd568e..899ddd87 100644 --- a/configs/paper/qwen2_7b_full.yaml +++ b/configs/prune_llm/qwen2_7b_full.yaml @@ -61,7 +61,7 @@ metrics: rayleigh_quotient: relative: true regularization: 1.0e-6 - + # ============================================================================ # LLM-SPECIFIC SETTINGS # ============================================================================ @@ -223,7 +223,7 @@ pruning: distribution: "uniform" structured: true dependency_aware: true - + sparsity_levels: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] selection_modes: ["low", "high"] diff --git a/configs/prune_llm/qwen2_7b_unified.yaml b/configs/prune_llm/qwen2_7b_unified.yaml 
new file mode 100644 index 00000000..eaa616e5 --- /dev/null +++ b/configs/prune_llm/qwen2_7b_unified.yaml @@ -0,0 +1,387 @@ +# ============================================================================= +# QWEN2-7B COMPREHENSIVE ANALYSIS - UNIFIED FORMAT +# ============================================================================= +# Same structure as LLaMA-3.1-8B for cross-model generalization results +# Qwen2 uses different FFN structure (larger intermediate size) +# +# Usage: python scripts/run_experiment.py --config configs/prune_llm/qwen2_7b_unified.yaml +# Estimated runtime: ~4-6 hours on 1x A100 +# ============================================================================= + +# ----------------------------------------------------------------------------- +# EXPERIMENT +# ----------------------------------------------------------------------------- +experiment: + name: "qwen2_7b_paper_results" + type: "llm_alignment" + seed: 42 + device: "cuda" + output_dir: "./results/paper/qwen2_7b" + +# ----------------------------------------------------------------------------- +# MODEL +# ----------------------------------------------------------------------------- +model: + name: "hf_causal_lm" + model_id: "Qwen/Qwen2-7B" + dtype: "bfloat16" + device_map: "auto" + trust_remote_code: true + + tracked_layers: + - "model.model.layers.*.mlp.up_proj" + - "model.model.layers.*.mlp.gate_proj" + - "model.model.layers.*.mlp.down_proj" + +# ----------------------------------------------------------------------------- +# DATASET +# ----------------------------------------------------------------------------- +dataset: + name: "wikitext" + subset: "wikitext-2-raw-v1" + split: "train" + batch_size: 1 + num_workers: 0 + +# ----------------------------------------------------------------------------- +# CALIBRATION +# ----------------------------------------------------------------------------- +calibration: + num_samples: 128 + max_length: 2048 + batch_size: 4 + +# 
----------------------------------------------------------------------------- +# METRICS +# ----------------------------------------------------------------------------- +# Note: supernode_protection_score, supernode_connectivity_score are computed +# by the supernode analysis pipeline, not as standalone metrics +metrics: + rayleigh_quotient: + enabled: true + relative: true + regularization: 1.0e-6 + + redundancy: + enabled: true + + magnitude: + enabled: true + + scar: + enabled: true + num_samples: 64 + max_length: 512 + + composite_weights: + rayleigh_quotient: 0.25 + redundancy: -0.25 + magnitude: 0.15 + scar_loss_proxy: 0.35 + +# ----------------------------------------------------------------------------- +# SUPERNODE +# ----------------------------------------------------------------------------- +supernode: + enabled: true + score_metric: "scar_loss_proxy" + core_fraction: 0.01 + halo_fraction: 0.10 + follower_fraction: 0.10 + protect_core: true + cross_layer_analysis: true + compare_by_connection: true + compute_metrics: + - "activation" + - "rayleigh_quotient" + - "mutual_information" + - "redundancy" + +# ----------------------------------------------------------------------------- +# CLUSTERING (disabled for LLM) +# ----------------------------------------------------------------------------- +clustering: + enabled: false + n_clusters: 4 + features: ["rayleigh_quotient", "redundancy", "magnitude"] + +# ----------------------------------------------------------------------------- +# HALO ANALYSIS +# ----------------------------------------------------------------------------- +halo_analysis: + enabled: true + percentile: 90.0 + sample_pairs: 2000 + max_refs: 512 + max_pairs_per_group: 1000 + +# ----------------------------------------------------------------------------- +# CASCADE ANALYSIS +# ----------------------------------------------------------------------------- +cascade_analysis: + enabled: true + n_remove_per_group: 5 + damage_sample_fraction: 
0.2 + +# ----------------------------------------------------------------------------- +# PRUNING +# ----------------------------------------------------------------------------- +pruning: + enabled: true + ratios: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + selection_modes: ["low", "high"] + distribution: "uniform" + structured: true + dependency_aware: true + + algorithms: + - "random" + - "magnitude" + - "rayleigh_quotient" + - "redundancy" + - "scar_loss_proxy" + - "supernode_protection_score" + - "supernode_connectivity_score" + - "generalized_importance" + - "cross_layer_importance" + - "wanda" + - "sparsegpt" + + scoring_methods: + - "random" + - "magnitude" + - "rayleigh_quotient" + - "redundancy" + - "scar_loss_proxy" + - "scar_taylor" + - "supernode_protection_score" + - "supernode_connectivity_score" + - "generalized_importance" + - "cross_layer_importance" + - "within_layer_importance" + - "wanda" + - "sparsegpt" + + fine_tune: + enabled: false + epochs: 1 + learning_rate: 1.0e-5 + +# ----------------------------------------------------------------------------- +# EVALUATION +# ----------------------------------------------------------------------------- +evaluation: + enabled: true + accuracy: false + loss: true + + perplexity_enabled: true + perplexity_datasets: + - name: "wikitext" + subset: "wikitext-2-raw-v1" + split: "test" + - name: "c4" + split: "validation" + max_samples: 1000 + + benchmarks_enabled: true + benchmark_tasks: + - name: "hellaswag" + num_fewshot: 0 + - name: "piqa" + num_fewshot: 0 + - name: "boolq" + num_fewshot: 0 + - name: "winogrande" + num_fewshot: 0 + - name: "arc_easy" + num_fewshot: 0 + - name: "arc_challenge" + num_fewshot: 0 + - name: "openbookqa" + num_fewshot: 0 + - name: "hellaswag" + num_fewshot: 5 + - name: "piqa" + num_fewshot: 5 + - name: "arc_challenge" + num_fewshot: 5 + - name: "mmlu" + num_fewshot: 5 + benchmark_fewshot: 0 + benchmark_batch_size: 8 + +# 
----------------------------------------------------------------------------- +# VISUALIZATION +# ----------------------------------------------------------------------------- +visualization: + enabled: true + format: "pdf" + dpi: 300 + + histograms: true + violin_plots: true + correlation_heatmap: true + pruning_comparison: true + pruning_recovery: true + supernode_distribution: true + halo_structure: true + cross_layer_heatmap: true + + scatter_pairs: + - ["magnitude", "rayleigh_quotient"] + - ["magnitude", "scar_loss_proxy"] + - ["rayleigh_quotient", "scar_loss_proxy"] + - ["redundancy", "rayleigh_quotient"] + - ["scar_loss_proxy", "supernode_connectivity_score"] + - ["cross_layer_importance", "magnitude"] + +# ----------------------------------------------------------------------------- +# OUTPUT +# ----------------------------------------------------------------------------- +output: + dir: "./results/paper/qwen2_7b" + save_metrics: true + save_figures: true + save_checkpoints: false + +# ----------------------------------------------------------------------------- +# EXTRA +# ----------------------------------------------------------------------------- +extra: + analysis: + # Qwen2-7B has 28 layers + layer_indices: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 27] + save_scores: true + generate_plots: true + metrics: + - "activation_l2_norm" + - "rayleigh_quotient" + - "gaussian_mi_analytic" + - "average_redundancy" + - "scar_activation_power" + - "scar_curvature" + - "scar_loss_proxy" + - "scar_taylor" + - "supernode_protection_score" + - "supernode_connectivity_score" + - "halo_redundancy" + - "cross_layer_importance" + - "within_layer_importance" + plots: + histograms: true + scatter_plots: true + pruning_curves: true + redundancy_heatmaps: true + scatter_pairs: + - ["activation_l2_norm", "rayleigh_quotient"] + - ["scar_activation_power", "scar_loss_proxy"] + - ["rayleigh_quotient", "scar_loss_proxy"] + - ["average_redundancy", "rayleigh_quotient"] + 
- ["scar_loss_proxy", "supernode_connectivity_score"] + - ["cross_layer_importance", "activation_l2_norm"] + + supernode_robustness: + enabled: true + supernode_fraction: 0.01 + num_bootstrap_samples: 10 + batch_size: 32 + max_samples: 256 + metrics: + - "scar_activation_power" + - "scar_loss_proxy" + - "scar_taylor" + - "rayleigh_quotient" + - "gaussian_mi_analytic" + - "activation_l2_norm" + target_layers: null + + supernode_summary: + enabled: true + outlier_analysis: true + + halo_analysis: + enabled: true + supernode_fraction: 0.01 + halo_fraction: 0.10 + num_samples: 8 + max_length: 256 + sample_pairs: 2000 + max_samples: 2000 + max_pairs_per_group: 1000 + plots: + depth_comparison: true + histograms: true + heatmaps: true + comprehensive: true + + cross_layer: + enabled: true + max_refs: 512 + rq_weight: 0.25 + downstream_weight: 0.35 + within_redundancy_weight: 0.25 + activation_weight: 0.15 + normalize: true + + generalized_importance: + enabled: true + neighborhood_fraction: 0.10 + propagation_weight: 0.3 + redundancy_penalty: 0.5 + num_samples: 8 + max_length: 256 + + visualization: + save_plots: true + format: "pdf" + dpi: 300 + style: "seaborn-v0_8-paper" + histograms: true + scatter_plots: true + heatmaps: true + supernode_distribution: + enabled: true + plot_loss_proxy_histogram: true + plot_concentration_by_layer: true + highlight_top_percent: [1, 5, 10] + halo_structure: + enabled: true + plot_redundancy_by_depth: true + plot_protection_vs_connection: true + plot_within_vs_cross_group: true + cross_layer: + enabled: true + plot_downstream_importance_by_layer: true + plot_importance_vs_redundancy: true + plot_efficiency_heatmap: true + pruning_curves: + enabled: true + plot_sparsity_vs_perplexity: true + plot_sparsity_vs_accuracy: true + metrics_to_compare: + - "supernode_connectivity_score" + - "cross_layer_importance" + - "generalized_importance" + - "scar_loss_proxy" + - "wanda" + - "sparsegpt" + - "activation_l2_norm" + - "random" + 
supernode_robustness: + enabled: true + jaccard_heatmap: true + spearman_heatmap: true + bootstrap_stability: true + consistency_bars: true + scatter_pairs: + - ["activation_l2_norm", "scar_loss_proxy"] + - ["scar_loss_proxy", "supernode_connectivity_score"] + - ["cross_layer_importance", "activation_l2_norm"] + - ["generalized_importance", "scar_loss_proxy"] + + do_scar_metrics: true + do_directed_redundancy: true + do_connectivity_pruning: true + do_halo_analysis: true + do_generalized_importance: true diff --git a/configs/unified_template.yaml b/configs/unified_template.yaml new file mode 100644 index 00000000..b7116566 --- /dev/null +++ b/configs/unified_template.yaml @@ -0,0 +1,272 @@ +# ============================================================================= +# UNIFIED ALIGNMENT FRAMEWORK - CONFIG TEMPLATE +# ============================================================================= +# Works for both Vision (ResNet, VGG) and LLM (Llama, Qwen) experiments +# Usage: python scripts/run_experiment.py --config configs/unified_template.yaml +# ============================================================================= + +# ----------------------------------------------------------------------------- +# EXPERIMENT (unified structure) +# ----------------------------------------------------------------------------- +experiment: + name: "my_experiment" + # Types: "alignment_analysis" (general), "cluster_analysis" (vision clustering), + # "llm_alignment" (LLM pruning) + type: "alignment_analysis" + seed: 42 + device: "cuda" + output_dir: "./results" + +# ----------------------------------------------------------------------------- +# MODEL (unified for vision and LLM) +# ----------------------------------------------------------------------------- +model: + # Vision: "resnet18", "resnet50", "vgg16", "mobilenet_v2" + # LLM: "hf_causal_lm" (requires model_id) + name: "resnet18" + pretrained: true + num_classes: 10 # Vision only + + # HuggingFace (LLM only) + 
model_id: null # e.g., "meta-llama/Llama-3.1-8B" + dtype: "bfloat16" + device_map: "auto" + + # Layers to track (null = auto-detect) + tracked_layers: null + +# ----------------------------------------------------------------------------- +# DATASET (unified) +# ----------------------------------------------------------------------------- +dataset: + name: "cifar10" # Vision: cifar10, cifar100, imagenet + # LLM: wikitext, c4 + root: "./data" + batch_size: 128 + num_workers: 4 + +# ----------------------------------------------------------------------------- +# CALIBRATION (samples for metric computation) +# ----------------------------------------------------------------------------- +calibration: + num_samples: 5000 # Vision: typically 5000 + # LLM: typically 64-128 + max_length: 2048 # LLM only: sequence length + +# ----------------------------------------------------------------------------- +# METRICS (unified naming with aliases) +# ----------------------------------------------------------------------------- +metrics: + # Core metrics (available for both vision and LLM) + rayleigh_quotient: + enabled: true + relative: true + regularization: 1.0e-6 + + redundancy: # Gaussian pairwise MI + enabled: true + sampling: "all" # all, random, top_k + num_pairs: 10 # For pairwise computation + + synergy: # PID-based synergy with target + enabled: true + target: "logit_margin" # Vision: logit_margin, correct_logit + # LLM: loss_proxy, perplexity_delta + num_pairs: 10 + + magnitude: + enabled: true + type: "l2_norm" # l2_norm, l1_norm, max + + # LLM-specific metrics (ignored for vision) + scar: + enabled: false # Enable for LLM + activation_power: true + curvature: true + loss_proxy: true + taylor: true + +# ----------------------------------------------------------------------------- +# CLUSTERING (unified for both) +# ----------------------------------------------------------------------------- +clustering: + enabled: true + n_clusters: 4 + # Cluster type names 
 (semantic mapping) + type_names: ["critical", "redundant", "synergistic", "background"] + normalize_features: true + + # Features to cluster on + features: + - "rayleigh_quotient" + - "redundancy" + - "synergy" + + # Stability analysis + stability: + enabled: true + n_bootstrap: 50 + +# ----------------------------------------------------------------------------- +# SUPERNODE DETECTION (alternative to clustering) +# ----------------------------------------------------------------------------- +supernode: + enabled: false # Enable for outlier-based analysis + score_metric: "synergy" # Or: rayleigh_quotient, scar_loss_proxy + core_fraction: 0.01 # Top 1% as supernodes + halo_fraction: 0.10 # 10% as halo neurons + protect_core: true + +# ----------------------------------------------------------------------------- +# HALO / CROSS-LAYER ANALYSIS (unified) +# ----------------------------------------------------------------------------- +halo_analysis: + enabled: true + percentile: 90.0 # Keep neurons above this influence percentile (90.0 = top 10%) + use_activation_weight: true # Weight by activation magnitude + + # Compute influence/dependency matrices + compute_influence_matrix: true + + # For LLM: analyze downstream impact + cross_layer: + enabled: true + max_refs: 512 + +# ----------------------------------------------------------------------------- +# CASCADE / DAMAGE ANALYSIS +# ----------------------------------------------------------------------------- +cascade_analysis: + enabled: true + n_remove_per_group: 5 # Neurons to remove per cluster/supernode + damage_sample_fraction: 0.2 # Fraction of data to evaluate damage + +# ----------------------------------------------------------------------------- +# PRUNING (unified interface) +# ----------------------------------------------------------------------------- +pruning: + enabled: true + + # Sparsity levels + ratios: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + + # Methods to compare (available for both vision and LLM) + methods: + # Baselines + - name: 
"random" + - name: "magnitude" + + # Alignment-based + - name: "rayleigh_quotient" + selection: "low" # Prune low-RQ neurons + - name: "redundancy" + selection: "high" # Prune high-redundancy neurons + - name: "synergy" + selection: "low" # Prune low-synergy neurons + + # Composite methods + - name: "composite" + weights: + rayleigh_quotient: 0.33 + redundancy: -0.33 # Negative = penalize redundancy + synergy: 0.33 + + # Structure-aware + - name: "cluster_aware" # Respects cluster boundaries + - name: "supernode_aware" # Protects supernodes + + # LLM-specific (ignored for vision) + - name: "wanda" + - name: "sparsegpt" + - name: "scar_loss_proxy" + + # Selection modes + selection_modes: ["low", "high"] + + # Fine-tuning after pruning + fine_tune: + enabled: true + epochs: 10 + learning_rate: 0.0001 + +# ----------------------------------------------------------------------------- +# EVALUATION +# ----------------------------------------------------------------------------- +evaluation: + # Vision + accuracy: true + loss: true + + # LLM + perplexity: + enabled: false # Enable for LLM + datasets: ["wikitext"] + + benchmarks: # LLM benchmarks + enabled: false + tasks: ["hellaswag", "piqa", "arc_easy"] + num_fewshot: 0 + +# ----------------------------------------------------------------------------- +# VISUALIZATION (unified) +# ----------------------------------------------------------------------------- +visualization: + enabled: true + format: "png" # png, pdf, svg + dpi: 300 + + # Plots to generate (works for both) + plots: + # Metric distributions + histograms: true + violin_plots: true + correlation_heatmap: true + + # Clustering/Supernode + cluster_scatter: true # 2D/3D metric space scatter + cluster_evolution: true # Type composition across depth + + # Cross-layer + influence_matrix: true + halo_properties: true + + # Pruning + pruning_comparison: true + pruning_recovery: true + + # Cascade + cascade_test: true + + # Scatter plot pairs + scatter_pairs: + - 
["rayleigh_quotient", "redundancy"] + - ["rayleigh_quotient", "synergy"] + - ["redundancy", "synergy"] + +# ----------------------------------------------------------------------------- +# OUTPUT +# ----------------------------------------------------------------------------- +output: + dir: "./results" + save_metrics: true + save_clusters: true + save_figures: true + save_checkpoints: false + +# ============================================================================= +# PRESETS (inherit and override) +# ============================================================================= +# You can create preset files and inherit from them: +# +# # In configs/presets/vision_base.yaml: +# _preset: vision_base +# experiment: +# type: cluster_analysis +# model: +# name: resnet18 +# +# # In your config: +# _inherit: presets/vision_base +# model: +# name: resnet50 # Override +# ============================================================================= diff --git a/configs/cluster_analysis/README.md b/configs/vision_prune/README.md similarity index 100% rename from configs/cluster_analysis/README.md rename to configs/vision_prune/README.md diff --git a/configs/cluster_analysis/mobilenetv2_cifar10_full.yaml b/configs/vision_prune/mobilenetv2_cifar10_full.yaml similarity index 100% rename from configs/cluster_analysis/mobilenetv2_cifar10_full.yaml rename to configs/vision_prune/mobilenetv2_cifar10_full.yaml diff --git a/configs/vision_prune/mobilenetv2_cifar10_unified.yaml b/configs/vision_prune/mobilenetv2_cifar10_unified.yaml new file mode 100644 index 00000000..63015ff8 --- /dev/null +++ b/configs/vision_prune/mobilenetv2_cifar10_unified.yaml @@ -0,0 +1,392 @@ +# ============================================================================= +# MobileNetV2 on CIFAR-10 - UNIFIED FORMAT (ENHANCED) +# ============================================================================= +# Full cluster analysis pipeline for MobileNetV2 on CIFAR-10. 
+# MobileNetV2 is particularly interesting for pruning research due to: +# - Inverted residual blocks with depthwise separable convolutions +# - Already efficient architecture - pruning must be careful +# - Linear bottleneck design affects which channels are safe to prune +# +# Key features: +# - Uses unified metric naming +# - Comprehensive evaluation metrics +# - Full visualization pipeline for paper figures +# - Layer-wise sensitivity analysis +# +# Usage: python scripts/run_experiment.py --config configs/vision_prune/mobilenetv2_cifar10_unified.yaml +# ============================================================================= + +# ----------------------------------------------------------------------------- +# EXPERIMENT +# ----------------------------------------------------------------------------- +experiment: + name: "mobilenetv2_cifar10_cluster_analysis" + type: "cluster_analysis" + seed: 42 + device: "cuda" + output_dir: "./results/vision/mobilenetv2_cifar10" + +# ----------------------------------------------------------------------------- +# MODEL +# ----------------------------------------------------------------------------- +model: + name: "mobilenet_v2" + pretrained: true + num_classes: 10 + +# ----------------------------------------------------------------------------- +# DATASET +# ----------------------------------------------------------------------------- +dataset: + name: "cifar10" + root: "./data" + batch_size: 128 + num_workers: 4 + +# ----------------------------------------------------------------------------- +# CALIBRATION +# ----------------------------------------------------------------------------- +calibration: + num_samples: 5000 + +# ----------------------------------------------------------------------------- +# METRICS +# ----------------------------------------------------------------------------- +metrics: + rayleigh_quotient: + enabled: true + relative: true + shrinkage: true + + redundancy: + enabled: true + sampling: 
"all" + + synergy: + enabled: true + target: "logit_margin" + num_pairs: 10 + sampling: "top_k" + + magnitude: + enabled: true + + taylor: + enabled: true + criterion: "gradient_weight" + + activation_sparsity: + enabled: true + threshold: 0.01 + + composite_weights: + rayleigh_quotient: 0.33 + redundancy: -0.33 + synergy: 0.33 + +# ----------------------------------------------------------------------------- +# CLUSTERING +# ----------------------------------------------------------------------------- +clustering: + enabled: true + n_clusters: 4 + type_names: ["critical", "redundant", "synergistic", "background"] + normalize_features: true + features: ["rayleigh_quotient", "redundancy", "synergy"] + + stability_enabled: true + n_bootstrap: 50 + +# ----------------------------------------------------------------------------- +# HALO ANALYSIS +# ----------------------------------------------------------------------------- +halo_analysis: + enabled: true + percentile: 90.0 + use_activation_weight: true + compute_influence_matrix: true + +# ----------------------------------------------------------------------------- +# CASCADE ANALYSIS +# ----------------------------------------------------------------------------- +cascade_analysis: + enabled: true + n_remove_per_group: 5 + damage_sample_fraction: 0.2 + +# ----------------------------------------------------------------------------- +# PRUNING +# ----------------------------------------------------------------------------- +pruning: + enabled: true + ratios: [0.1, 0.2, 0.3, 0.4, 0.5] # MobileNet is already efficient - conservative pruning + selection_modes: ["low", "high"] + + algorithms: + - "random" + - "magnitude" + - "taylor" + - "rayleigh_quotient" + - "redundancy" + - "synergy" + - "composite" + - "cluster_aware" + + scoring_methods: + - "random" + - "magnitude" + - "taylor" + - "rayleigh_quotient" + - "redundancy" + - "synergy" + - "composite" + - "cluster_aware" + + fine_tune: + enabled: true + epochs: 15 # 
MobileNet may need more fine-tuning + learning_rate: 0.0001 + weight_decay: 0.00001 + +# ----------------------------------------------------------------------------- +# EVALUATION (Enhanced for Vision) +# ----------------------------------------------------------------------------- +evaluation: + enabled: true + + # Classification metrics + accuracy: true + top1_accuracy: true + top5_accuracy: true + loss: true + + # Per-class analysis + per_class_accuracy: true + confusion_matrix: true + + # Calibration metrics + calibration_enabled: true + expected_calibration_error: true + reliability_diagram: true + + # Efficiency metrics - especially important for MobileNet + compute_flops: true + compute_params: true + compute_memory: true + measure_latency: true + latency_batch_sizes: [1, 8, 32, 128] + + # Mobile-specific metrics + measure_cpu_latency: true + measure_peak_memory: true + + # Robustness (optional) + robustness_enabled: false + corruption_types: ["gaussian_noise", "shot_noise", "gaussian_blur", "contrast", "brightness"] + corruption_severities: [1, 3, 5] + + # Transfer evaluation (optional) + transfer_enabled: false + transfer_datasets: ["cifar100", "svhn"] + +# ----------------------------------------------------------------------------- +# BENCHMARKS (Vision-specific) +# ----------------------------------------------------------------------------- +benchmarks: + enabled: true + + tasks: + - name: "cifar10_test" + dataset: "cifar10" + split: "test" + enabled: true + + - name: "cifar100_transfer" + dataset: "cifar100" + split: "test" + enabled: false + + inference: + warmup_iterations: 10 + benchmark_iterations: 100 + batch_sizes: [1, 8, 32, 128] + devices: ["cuda", "cpu"] # CPU is relevant for MobileNet + + adversarial: + enabled: false + attacks: ["fgsm", "pgd"] + epsilons: [0.01, 0.03, 0.1] + +# ----------------------------------------------------------------------------- +# VISUALIZATION (Enhanced) +# 
----------------------------------------------------------------------------- +visualization: + enabled: true + format: "pdf" + dpi: 300 + style: "seaborn-v0_8-paper" + + histograms: true + violin_plots: true + correlation_heatmap: true + cluster_scatter: true + cluster_evolution: true + influence_matrix: true + halo_properties: true + pruning_comparison: true + pruning_recovery: true + cascade_test: true + + # Additional analysis plots + metric_distributions: true + layer_importance_heatmap: true + sensitivity_curves: true + efficiency_tradeoffs: true + + scatter_pairs: + - ["rayleigh_quotient", "redundancy"] + - ["rayleigh_quotient", "synergy"] + - ["redundancy", "synergy"] + - ["magnitude", "rayleigh_quotient"] + - ["magnitude", "taylor"] + - ["taylor", "rayleigh_quotient"] + +# ----------------------------------------------------------------------------- +# OUTPUT +# ----------------------------------------------------------------------------- +output: + dir: "./results/vision/mobilenetv2_cifar10" + save_metrics: true + save_clusters: true + save_figures: true + save_checkpoints: true + save_per_layer: true + +# ----------------------------------------------------------------------------- +# EXTRA (Vision-specific detailed settings) +# ----------------------------------------------------------------------------- +extra: + pretrain_epochs: 20 + pretrain_lr: 0.001 + + baselines: + - "magnitude" + - "taylor" + - "network_slimming" + - "geometric_median" + + analysis: + layer_indices: "all" + save_scores: true + generate_plots: true + + # Analyze different block types separately + analyze_by_block_type: true + block_types: + - "expansion" # 1x1 expansion convs + - "depthwise" # 3x3 depthwise convs + - "projection" # 1x1 projection/bottleneck convs + + metrics: + - "rayleigh_quotient" + - "redundancy" + - "synergy" + - "magnitude" + - "taylor" + - "activation_sparsity" + + plots: + histograms: true + scatter_plots: true + pruning_curves: true + layer_comparison: 
true + block_type_comparison: true # Compare across expansion/depthwise/projection + + scatter_pairs: + - ["rayleigh_quotient", "redundancy"] + - ["rayleigh_quotient", "synergy"] + - ["magnitude", "taylor"] + - ["redundancy", "synergy"] + + sensitivity_analysis: + enabled: true + per_layer: true + per_block_type: true # Sensitivity by block type + ratios: [0.1, 0.2, 0.3, 0.4, 0.5] + metric: "accuracy" + output_dir: "sensitivity" + + structured_pruning: + enabled: true + granularity: "filter" + # MobileNet-specific: careful with depthwise layers + skip_depthwise: true + importance_criteria: + - "l1_norm" + - "l2_norm" + - "taylor" + - "alignment" + + feature_analysis: + enabled: true + compute_feature_rank: true + compute_channel_redundancy: true + visualize_filters: false + num_samples_to_visualize: 10 + + efficiency: + track_flops: true + track_params: true + track_memory: true + track_latency: true + track_cpu_latency: true # Important for mobile deployment + baseline_comparison: true + + visualization: + save_plots: true + format: "pdf" + dpi: 300 + style: "seaborn-v0_8-paper" + + metric_distributions: + enabled: true + by_layer: true + by_cluster: true + by_block_type: true # MobileNet-specific + + cluster_analysis: + enabled: true + scatter_3d: true + cluster_evolution_by_layer: true + cluster_purity: true + + pruning_comparison: + enabled: true + accuracy_vs_sparsity: true + accuracy_vs_flops: true + accuracy_vs_params: true + accuracy_vs_latency: true # Important for MobileNet + methods_to_compare: + - "random" + - "magnitude" + - "taylor" + - "composite" + - "cluster_aware" + + layer_importance: + enabled: true + heatmap: true + bar_chart: true + by_block_type: true # MobileNet-specific + + fine_tuning_recovery: + enabled: true + by_method: true + by_sparsity: true + + efficiency_tradeoffs: + enabled: true + accuracy_vs_flops: true + accuracy_vs_latency: true + accuracy_vs_params: true + accuracy_vs_cpu_latency: true # Mobile deployment consideration diff 
--git a/configs/cluster_analysis/resnet18_cifar10_full.yaml b/configs/vision_prune/resnet18_cifar10_full.yaml similarity index 100% rename from configs/cluster_analysis/resnet18_cifar10_full.yaml rename to configs/vision_prune/resnet18_cifar10_full.yaml diff --git a/configs/vision_prune/resnet18_cifar10_unified.yaml b/configs/vision_prune/resnet18_cifar10_unified.yaml new file mode 100644 index 00000000..a8a0e0be --- /dev/null +++ b/configs/vision_prune/resnet18_cifar10_unified.yaml @@ -0,0 +1,401 @@ +# ============================================================================= +# ResNet-18 on CIFAR-10 - UNIFIED FORMAT (ENHANCED) +# ============================================================================= +# Full cluster analysis pipeline for ResNet-18 on CIFAR-10 with comprehensive +# evaluation, benchmarks, and analysis sections for vision pruning research. +# +# Key features: +# - Uses unified metric naming (rayleigh_quotient, redundancy, synergy, magnitude) +# - Comprehensive evaluation metrics (accuracy, efficiency, per-class) +# - Full visualization pipeline for paper figures +# - Layer-wise sensitivity analysis +# +# Usage: python scripts/run_experiment.py --config configs/vision_prune/resnet18_cifar10_unified.yaml +# ============================================================================= + +# ----------------------------------------------------------------------------- +# EXPERIMENT +# ----------------------------------------------------------------------------- +experiment: + name: "resnet18_cifar10_cluster_analysis" + type: "cluster_analysis" + seed: 42 + device: "cuda" + output_dir: "./results/vision/resnet18_cifar10" + +# ----------------------------------------------------------------------------- +# MODEL +# ----------------------------------------------------------------------------- +model: + name: "resnet18" + pretrained: true + num_classes: 10 + +# ----------------------------------------------------------------------------- +# DATASET 
+# ----------------------------------------------------------------------------- +dataset: + name: "cifar10" + root: "./data" + batch_size: 128 + num_workers: 4 + +# ----------------------------------------------------------------------------- +# CALIBRATION +# ----------------------------------------------------------------------------- +calibration: + num_samples: 5000 + +# ----------------------------------------------------------------------------- +# METRICS +# ----------------------------------------------------------------------------- +# Unified naming convention: +# rayleigh_quotient (alias: rq, compute_rq) +# redundancy (alias: gaussian_mi_analytic, average_redundancy, pairwise_redundancy) +# synergy (alias: synergy_gaussian_mmi) +# magnitude (alias: activation_l2_norm) +# ----------------------------------------------------------------------------- +metrics: + rayleigh_quotient: + enabled: true + relative: true + shrinkage: true + + redundancy: + enabled: true + sampling: "all" # all, random, top_k + + synergy: + enabled: true + target: "logit_margin" # logit_margin, correct_logit, logit_pc1 + num_pairs: 10 + sampling: "top_k" + + magnitude: + enabled: true + + taylor: + enabled: true + criterion: "gradient_weight" # gradient_weight, gradient_activation + + activation_sparsity: + enabled: true + threshold: 0.01 + + # Composite weights for combined scoring + composite_weights: + rayleigh_quotient: 0.33 + redundancy: -0.33 # Negative = penalize redundancy + synergy: 0.33 + +# ----------------------------------------------------------------------------- +# CLUSTERING +# ----------------------------------------------------------------------------- +clustering: + enabled: true + n_clusters: 4 + type_names: ["critical", "redundant", "synergistic", "background"] + normalize_features: true + features: ["rayleigh_quotient", "redundancy", "synergy"] + + stability_enabled: true + n_bootstrap: 50 + +# 
----------------------------------------------------------------------------- +# HALO ANALYSIS (Cross-layer dependencies) +# ----------------------------------------------------------------------------- +halo_analysis: + enabled: true + percentile: 90.0 + use_activation_weight: true + compute_influence_matrix: true + +# ----------------------------------------------------------------------------- +# CASCADE ANALYSIS (Damage testing) +# ----------------------------------------------------------------------------- +cascade_analysis: + enabled: true + n_remove_per_group: 5 + damage_sample_fraction: 0.2 + +# ----------------------------------------------------------------------------- +# PRUNING +# ----------------------------------------------------------------------------- +pruning: + enabled: true + ratios: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + selection_modes: ["low", "high"] + + # Algorithms to test (unified naming) + algorithms: + - "random" + - "magnitude" + - "taylor" + - "rayleigh_quotient" + - "redundancy" + - "synergy" + - "composite" + - "cluster_aware" + + # Scoring methods (for comparison) + scoring_methods: + - "random" + - "magnitude" + - "taylor" + - "rayleigh_quotient" + - "redundancy" + - "synergy" + - "composite" + - "cluster_aware" + + fine_tune: + enabled: true + epochs: 10 + learning_rate: 0.0001 + weight_decay: 0.0001 + +# ----------------------------------------------------------------------------- +# EVALUATION (Enhanced for Vision) +# ----------------------------------------------------------------------------- +evaluation: + enabled: true + + # Classification metrics + accuracy: true + top1_accuracy: true + top5_accuracy: true + loss: true + + # Per-class analysis + per_class_accuracy: true + confusion_matrix: true + + # Calibration metrics + calibration_enabled: true + expected_calibration_error: true + reliability_diagram: true + + # Efficiency metrics + compute_flops: true + compute_params: true + compute_memory: true + measure_latency: 
true + latency_batch_sizes: [1, 8, 32, 128] + + # Robustness (optional - requires corruption data) + robustness_enabled: false + corruption_types: ["gaussian_noise", "shot_noise", "impulse_noise", "gaussian_blur", "contrast", "brightness"] + corruption_severities: [1, 3, 5] + + # Transfer evaluation (optional) + transfer_enabled: false + transfer_datasets: ["cifar100", "svhn"] + +# ----------------------------------------------------------------------------- +# BENCHMARKS (Vision-specific) +# ----------------------------------------------------------------------------- +benchmarks: + enabled: true + + # Standard test benchmarks + tasks: + - name: "cifar10_test" + dataset: "cifar10" + split: "test" + enabled: true + + - name: "cifar100_transfer" + dataset: "cifar100" + split: "test" + enabled: false + + # Inference benchmarks + inference: + warmup_iterations: 10 + benchmark_iterations: 100 + batch_sizes: [1, 8, 32, 128] + devices: ["cuda"] + + # Adversarial robustness (optional) + adversarial: + enabled: false + attacks: ["fgsm", "pgd"] + epsilons: [0.01, 0.03, 0.1] + +# ----------------------------------------------------------------------------- +# VISUALIZATION (Enhanced) +# ----------------------------------------------------------------------------- +visualization: + enabled: true + format: "pdf" # pdf for paper quality + dpi: 300 + style: "seaborn-v0_8-paper" + + # Basic plots + histograms: true + violin_plots: true + correlation_heatmap: true + cluster_scatter: true + cluster_evolution: true + influence_matrix: true + halo_properties: true + pruning_comparison: true + pruning_recovery: true + cascade_test: true + + # Additional analysis plots + metric_distributions: true + layer_importance_heatmap: true + sensitivity_curves: true + efficiency_tradeoffs: true + + # Scatter plot pairs (unified naming) + scatter_pairs: + - ["rayleigh_quotient", "redundancy"] + - ["rayleigh_quotient", "synergy"] + - ["redundancy", "synergy"] + - ["magnitude", "rayleigh_quotient"] 
+ - ["magnitude", "taylor"] + - ["taylor", "rayleigh_quotient"] + +# ----------------------------------------------------------------------------- +# OUTPUT +# ----------------------------------------------------------------------------- +output: + dir: "./results/vision/resnet18_cifar10" + save_metrics: true + save_clusters: true + save_figures: true + save_checkpoints: true + save_per_layer: true + +# ----------------------------------------------------------------------------- +# EXTRA (Vision-specific detailed settings) +# ----------------------------------------------------------------------------- +extra: + # Pre-training for ImageNet pretrained models on CIFAR + pretrain_epochs: 20 + pretrain_lr: 0.001 + + # Baselines to compare against + baselines: + - "magnitude" + - "taylor" + - "network_slimming" + - "geometric_median" + + # Layer-wise analysis + analysis: + layer_indices: "all" # or specific: [0, 2, 4, 6, 8] + save_scores: true + generate_plots: true + + # Metrics to compute per layer + metrics: + - "rayleigh_quotient" + - "redundancy" + - "synergy" + - "magnitude" + - "taylor" + - "activation_sparsity" + + # Plots to generate + plots: + histograms: true + scatter_plots: true + pruning_curves: true + layer_comparison: true + filter_correlation: true + + scatter_pairs: + - ["rayleigh_quotient", "redundancy"] + - ["rayleigh_quotient", "synergy"] + - ["magnitude", "taylor"] + - ["redundancy", "synergy"] + + # Pruning sensitivity analysis + sensitivity_analysis: + enabled: true + per_layer: true + ratios: [0.1, 0.2, 0.3, 0.4, 0.5] + metric: "accuracy" + output_dir: "sensitivity" + + # Structured pruning options + structured_pruning: + enabled: true + granularity: "filter" # filter, channel, block + importance_criteria: + - "l1_norm" + - "l2_norm" + - "taylor" + - "alignment" + + # Feature analysis + feature_analysis: + enabled: true + compute_feature_rank: true + compute_channel_redundancy: true + visualize_filters: false # Set true for filter visualization 
(slow) + num_samples_to_visualize: 10 + + # Efficiency tracking + efficiency: + track_flops: true + track_params: true + track_memory: true + track_latency: true + baseline_comparison: true + + # Paper figure generation + visualization: + save_plots: true + format: "pdf" + dpi: 300 + style: "seaborn-v0_8-paper" + + # Figure 1: Metric distributions by layer + metric_distributions: + enabled: true + by_layer: true + by_cluster: true + + # Figure 2: Cluster analysis + cluster_analysis: + enabled: true + scatter_3d: true + cluster_evolution_by_layer: true + cluster_purity: true + + # Figure 3: Pruning comparison + pruning_comparison: + enabled: true + accuracy_vs_sparsity: true + accuracy_vs_flops: true + accuracy_vs_params: true + methods_to_compare: + - "random" + - "magnitude" + - "taylor" + - "composite" + - "cluster_aware" + - "network_slimming" + + # Figure 4: Layer-wise importance + layer_importance: + enabled: true + heatmap: true + bar_chart: true + + # Figure 5: Recovery after fine-tuning + fine_tuning_recovery: + enabled: true + by_method: true + by_sparsity: true + + # Figure 6: Efficiency vs Accuracy tradeoffs + efficiency_tradeoffs: + enabled: true + accuracy_vs_flops: true + accuracy_vs_latency: true + accuracy_vs_params: true diff --git a/configs/cluster_analysis/resnet50_imagenet100.yaml b/configs/vision_prune/resnet50_imagenet100.yaml similarity index 100% rename from configs/cluster_analysis/resnet50_imagenet100.yaml rename to configs/vision_prune/resnet50_imagenet100.yaml diff --git a/configs/vision_prune/resnet50_imagenet100_unified.yaml b/configs/vision_prune/resnet50_imagenet100_unified.yaml new file mode 100644 index 00000000..fe55bac8 --- /dev/null +++ b/configs/vision_prune/resnet50_imagenet100_unified.yaml @@ -0,0 +1,409 @@ +# ============================================================================= +# ResNet-50 on ImageNet-100 - UNIFIED FORMAT (ENHANCED) +# ============================================================================= +# 
Full cluster analysis pipeline for ResNet-50 on ImageNet-100 subset. +# ImageNet-100 is a 100-class subset of ImageNet for tractable experiments +# while maintaining realistic scale and complexity. +# +# Key features: +# - Uses unified metric naming +# - Comprehensive evaluation metrics +# - Full visualization pipeline for paper figures +# - Layer-wise sensitivity analysis +# - Higher resolution images (224x224) +# +# Usage: python scripts/run_experiment.py --config configs/vision_prune/resnet50_imagenet100_unified.yaml +# ============================================================================= + +# ----------------------------------------------------------------------------- +# EXPERIMENT +# ----------------------------------------------------------------------------- +experiment: + name: "resnet50_imagenet100_cluster_analysis" + type: "cluster_analysis" + seed: 42 + device: "cuda" + output_dir: "./results/vision/resnet50_imagenet100" + +# ----------------------------------------------------------------------------- +# MODEL +# ----------------------------------------------------------------------------- +model: + name: "resnet50" + pretrained: true + num_classes: 100 + weights: "IMAGENET1K_V2" # Use torchvision pretrained weights + +# ----------------------------------------------------------------------------- +# DATASET +# ----------------------------------------------------------------------------- +dataset: + name: "imagenet100" + root: "./data/imagenet100" + batch_size: 64 + num_workers: 8 + image_size: 224 + normalize: true + +# ----------------------------------------------------------------------------- +# CALIBRATION +# ----------------------------------------------------------------------------- +calibration: + num_samples: 5000 + +# ----------------------------------------------------------------------------- +# METRICS +# ----------------------------------------------------------------------------- +metrics: + rayleigh_quotient: + enabled: true + 
relative: true + shrinkage: true + + redundancy: + enabled: true + sampling: "all" + + synergy: + enabled: true + target: "logit_margin" + num_pairs: 10 + sampling: "top_k" + + magnitude: + enabled: true + + taylor: + enabled: true + criterion: "gradient_weight" + + activation_sparsity: + enabled: true + threshold: 0.01 + + composite_weights: + rayleigh_quotient: 0.33 + redundancy: -0.33 + synergy: 0.33 + +# ----------------------------------------------------------------------------- +# CLUSTERING +# ----------------------------------------------------------------------------- +clustering: + enabled: true + n_clusters: 4 + type_names: ["critical", "redundant", "synergistic", "background"] + normalize_features: true + features: ["rayleigh_quotient", "redundancy", "synergy"] + + stability_enabled: true + n_bootstrap: 30 # Fewer for larger model + +# ----------------------------------------------------------------------------- +# HALO ANALYSIS +# ----------------------------------------------------------------------------- +halo_analysis: + enabled: true + percentile: 90.0 + use_activation_weight: true + compute_influence_matrix: true + +# ----------------------------------------------------------------------------- +# CASCADE ANALYSIS +# ----------------------------------------------------------------------------- +cascade_analysis: + enabled: true + n_remove_per_group: 5 + damage_sample_fraction: 0.1 # Smaller for faster computation + +# ----------------------------------------------------------------------------- +# PRUNING +# ----------------------------------------------------------------------------- +pruning: + enabled: true + ratios: [0.2, 0.3, 0.4, 0.5, 0.6] + selection_modes: ["low", "high"] + + algorithms: + - "random" + - "magnitude" + - "taylor" + - "rayleigh_quotient" + - "redundancy" + - "synergy" + - "composite" + - "cluster_aware" + - "network_slimming" + + scoring_methods: + - "random" + - "magnitude" + - "taylor" + - "rayleigh_quotient" + - 
"redundancy" + - "synergy" + - "composite" + - "cluster_aware" + + fine_tune: + enabled: true + epochs: 5 # Fewer epochs for ImageNet + learning_rate: 0.00001 + weight_decay: 0.0001 + +# ----------------------------------------------------------------------------- +# EVALUATION (Enhanced for Vision) +# ----------------------------------------------------------------------------- +evaluation: + enabled: true + + # Classification metrics + accuracy: true + top1_accuracy: true + top5_accuracy: true + loss: true + + # Per-class analysis + per_class_accuracy: true + confusion_matrix: true + + # Calibration metrics + calibration_enabled: true + expected_calibration_error: true + reliability_diagram: true + + # Efficiency metrics + compute_flops: true + compute_params: true + compute_memory: true + measure_latency: true + latency_batch_sizes: [1, 8, 32, 64] + + # Robustness (optional - requires ImageNet-C) + robustness_enabled: false + corruption_types: ["gaussian_noise", "shot_noise", "impulse_noise", "defocus_blur", "glass_blur", "motion_blur", "zoom_blur", "snow", "frost", "fog", "brightness", "contrast", "elastic_transform", "pixelate", "jpeg_compression"] + corruption_severities: [1, 3, 5] + + # Transfer evaluation (optional) + transfer_enabled: false + transfer_datasets: ["imagenet_v2", "imagenet_a", "imagenet_r"] + +# ----------------------------------------------------------------------------- +# BENCHMARKS (Vision-specific) +# ----------------------------------------------------------------------------- +benchmarks: + enabled: true + + tasks: + - name: "imagenet100_val" + dataset: "imagenet100" + split: "val" + enabled: true + + - name: "imagenet_v2" + dataset: "imagenet_v2" + enabled: false + + - name: "imagenet_a" + dataset: "imagenet_a" + enabled: false + + inference: + warmup_iterations: 10 + benchmark_iterations: 50 # Fewer for larger model + batch_sizes: [1, 8, 32, 64] + devices: ["cuda"] + + adversarial: + enabled: false + attacks: ["fgsm", "pgd"] + 
epsilons: [0.01, 0.03] + +# ----------------------------------------------------------------------------- +# VISUALIZATION (Enhanced) +# ----------------------------------------------------------------------------- +visualization: + enabled: true + format: "pdf" + dpi: 300 + style: "seaborn-v0_8-paper" + + histograms: true + violin_plots: true + correlation_heatmap: true + cluster_scatter: true + cluster_evolution: true + influence_matrix: true + halo_properties: true + pruning_comparison: true + pruning_recovery: true + cascade_test: true + metric_distributions: true + + # Additional analysis plots + layer_importance_heatmap: true + sensitivity_curves: true + efficiency_tradeoffs: true + block_analysis: true # ResNet block structure + + scatter_pairs: + - ["rayleigh_quotient", "redundancy"] + - ["rayleigh_quotient", "synergy"] + - ["redundancy", "synergy"] + - ["magnitude", "rayleigh_quotient"] + - ["magnitude", "taylor"] + - ["taylor", "rayleigh_quotient"] + +# ----------------------------------------------------------------------------- +# OUTPUT +# ----------------------------------------------------------------------------- +output: + dir: "./results/vision/resnet50_imagenet100" + save_metrics: true + save_clusters: true + save_figures: true + save_checkpoints: true + save_per_layer: true + +# ----------------------------------------------------------------------------- +# EXTRA (Vision-specific detailed settings) +# ----------------------------------------------------------------------------- +extra: + pretrain_epochs: 10 # Less for ImageNet pretrained + pretrain_lr: 0.0001 + + baselines: + - "magnitude" + - "taylor" + - "network_slimming" + - "geometric_median" + - "hrank" # HRank pruning for ResNet + + analysis: + layer_indices: "all" + save_scores: true + generate_plots: true + + # Analyze ResNet stages separately + analyze_by_stage: true + stages: + - "conv1" # Initial 7x7 conv + - "layer1" # Stage 1 (56x56) + - "layer2" # Stage 2 (28x28) + - "layer3" # 
Stage 3 (14x14) + - "layer4" # Stage 4 (7x7) + + metrics: + - "rayleigh_quotient" + - "redundancy" + - "synergy" + - "magnitude" + - "taylor" + - "activation_sparsity" + - "feature_rank" # For HRank-style analysis + + plots: + histograms: true + scatter_plots: true + pruning_curves: true + layer_comparison: true + stage_comparison: true # Compare across ResNet stages + residual_analysis: true # Analyze skip connections + + scatter_pairs: + - ["rayleigh_quotient", "redundancy"] + - ["rayleigh_quotient", "synergy"] + - ["magnitude", "taylor"] + - ["redundancy", "synergy"] + - ["magnitude", "feature_rank"] + + sensitivity_analysis: + enabled: true + per_layer: true + per_stage: true # Sensitivity by ResNet stage + ratios: [0.2, 0.3, 0.4, 0.5] + metric: "accuracy" + output_dir: "sensitivity" + + structured_pruning: + enabled: true + granularity: "filter" + # ResNet-specific: handle residual connections + preserve_residual_dimensions: true + importance_criteria: + - "l1_norm" + - "l2_norm" + - "taylor" + - "alignment" + - "feature_rank" + + feature_analysis: + enabled: true + compute_feature_rank: true + compute_channel_redundancy: true + visualize_filters: false + num_samples_to_visualize: 10 + # ResNet-specific + analyze_residual_contribution: true + + efficiency: + track_flops: true + track_params: true + track_memory: true + track_latency: true + baseline_comparison: true + + visualization: + save_plots: true + format: "pdf" + dpi: 300 + style: "seaborn-v0_8-paper" + + metric_distributions: + enabled: true + by_layer: true + by_cluster: true + by_stage: true # ResNet-specific + + cluster_analysis: + enabled: true + scatter_3d: true + cluster_evolution_by_layer: true + cluster_purity: true + + pruning_comparison: + enabled: true + accuracy_vs_sparsity: true + accuracy_vs_flops: true + accuracy_vs_params: true + methods_to_compare: + - "random" + - "magnitude" + - "taylor" + - "composite" + - "cluster_aware" + - "network_slimming" + - "hrank" + + layer_importance: + 
enabled: true + heatmap: true + bar_chart: true + by_stage: true # ResNet-specific + + fine_tuning_recovery: + enabled: true + by_method: true + by_sparsity: true + + efficiency_tradeoffs: + enabled: true + accuracy_vs_flops: true + accuracy_vs_latency: true + accuracy_vs_params: true + + # ResNet-specific visualizations + residual_analysis: + enabled: true + skip_connection_importance: true + bottleneck_analysis: true diff --git a/configs/cluster_analysis/vgg16_cifar10_full.yaml b/configs/vision_prune/vgg16_cifar10_full.yaml similarity index 100% rename from configs/cluster_analysis/vgg16_cifar10_full.yaml rename to configs/vision_prune/vgg16_cifar10_full.yaml diff --git a/configs/vision_prune/vgg16_cifar10_unified.yaml b/configs/vision_prune/vgg16_cifar10_unified.yaml new file mode 100644 index 00000000..846377da --- /dev/null +++ b/configs/vision_prune/vgg16_cifar10_unified.yaml @@ -0,0 +1,373 @@ +# ============================================================================= +# VGG-16-BN on CIFAR-10 - UNIFIED FORMAT (ENHANCED) +# ============================================================================= +# Full cluster analysis pipeline for VGG-16 with batch normalization. +# VGG is particularly interesting for pruning due to its high redundancy +# in fully-connected layers and uniform filter sizes. 
+# +# Key features: +# - Uses unified metric naming +# - Comprehensive evaluation metrics +# - Full visualization pipeline for paper figures +# - Layer-wise sensitivity analysis +# +# Usage: python scripts/run_experiment.py --config configs/vision_prune/vgg16_cifar10_unified.yaml +# ============================================================================= + +# ----------------------------------------------------------------------------- +# EXPERIMENT +# ----------------------------------------------------------------------------- +experiment: + name: "vgg16_cifar10_cluster_analysis" + type: "cluster_analysis" + seed: 42 + device: "cuda" + output_dir: "./results/vision/vgg16_cifar10" + +# ----------------------------------------------------------------------------- +# MODEL +# ----------------------------------------------------------------------------- +model: + name: "vgg16_bn" + pretrained: true + num_classes: 10 + +# ----------------------------------------------------------------------------- +# DATASET +# ----------------------------------------------------------------------------- +dataset: + name: "cifar10" + root: "./data" + batch_size: 128 + num_workers: 4 + +# ----------------------------------------------------------------------------- +# CALIBRATION +# ----------------------------------------------------------------------------- +calibration: + num_samples: 5000 + +# ----------------------------------------------------------------------------- +# METRICS +# ----------------------------------------------------------------------------- +metrics: + rayleigh_quotient: + enabled: true + relative: true + shrinkage: true + + redundancy: + enabled: true + sampling: "all" + + synergy: + enabled: true + target: "logit_margin" + num_pairs: 10 + sampling: "top_k" + + magnitude: + enabled: true + + taylor: + enabled: true + criterion: "gradient_weight" + + activation_sparsity: + enabled: true + threshold: 0.01 + + composite_weights: + rayleigh_quotient: 0.33 + 
redundancy: -0.33 + synergy: 0.33 + +# ----------------------------------------------------------------------------- +# CLUSTERING +# ----------------------------------------------------------------------------- +clustering: + enabled: true + n_clusters: 4 + type_names: ["critical", "redundant", "synergistic", "background"] + normalize_features: true + features: ["rayleigh_quotient", "redundancy", "synergy"] + + stability_enabled: true + n_bootstrap: 50 + +# ----------------------------------------------------------------------------- +# HALO ANALYSIS +# ----------------------------------------------------------------------------- +halo_analysis: + enabled: true + percentile: 90.0 + use_activation_weight: true + compute_influence_matrix: true + +# ----------------------------------------------------------------------------- +# CASCADE ANALYSIS +# ----------------------------------------------------------------------------- +cascade_analysis: + enabled: true + n_remove_per_group: 5 + damage_sample_fraction: 0.2 + +# ----------------------------------------------------------------------------- +# PRUNING +# ----------------------------------------------------------------------------- +pruning: + enabled: true + ratios: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8] # VGG can be pruned more aggressively + selection_modes: ["low", "high"] + + algorithms: + - "random" + - "magnitude" + - "taylor" + - "rayleigh_quotient" + - "redundancy" + - "synergy" + - "composite" + - "cluster_aware" + - "network_slimming" + + scoring_methods: + - "random" + - "magnitude" + - "taylor" + - "rayleigh_quotient" + - "redundancy" + - "synergy" + - "composite" + - "cluster_aware" + + fine_tune: + enabled: true + epochs: 10 + learning_rate: 0.0001 + weight_decay: 0.0001 + +# ----------------------------------------------------------------------------- +# EVALUATION (Enhanced for Vision) +# ----------------------------------------------------------------------------- +evaluation: + enabled: true + 
+ # Classification metrics + accuracy: true + top1_accuracy: true + top5_accuracy: true + loss: true + + # Per-class analysis + per_class_accuracy: true + confusion_matrix: true + + # Calibration metrics + calibration_enabled: true + expected_calibration_error: true + reliability_diagram: true + + # Efficiency metrics + compute_flops: true + compute_params: true + compute_memory: true + measure_latency: true + latency_batch_sizes: [1, 8, 32, 128] + + # Robustness (optional) + robustness_enabled: false + corruption_types: ["gaussian_noise", "shot_noise", "gaussian_blur", "contrast", "brightness"] + corruption_severities: [1, 3, 5] + + # Transfer evaluation (optional) + transfer_enabled: false + transfer_datasets: ["cifar100", "svhn"] + +# ----------------------------------------------------------------------------- +# BENCHMARKS (Vision-specific) +# ----------------------------------------------------------------------------- +benchmarks: + enabled: true + + tasks: + - name: "cifar10_test" + dataset: "cifar10" + split: "test" + enabled: true + + - name: "cifar100_transfer" + dataset: "cifar100" + split: "test" + enabled: false + + inference: + warmup_iterations: 10 + benchmark_iterations: 100 + batch_sizes: [1, 8, 32, 128] + devices: ["cuda"] + + adversarial: + enabled: false + attacks: ["fgsm", "pgd"] + epsilons: [0.01, 0.03, 0.1] + +# ----------------------------------------------------------------------------- +# VISUALIZATION (Enhanced) +# ----------------------------------------------------------------------------- +visualization: + enabled: true + format: "pdf" + dpi: 300 + style: "seaborn-v0_8-paper" + + histograms: true + violin_plots: true + correlation_heatmap: true + cluster_scatter: true + cluster_evolution: true + influence_matrix: true + halo_properties: true + pruning_comparison: true + pruning_recovery: true + cascade_test: true + + # Additional analysis plots + metric_distributions: true + layer_importance_heatmap: true + sensitivity_curves: true + 
efficiency_tradeoffs: true + + scatter_pairs: + - ["rayleigh_quotient", "redundancy"] + - ["rayleigh_quotient", "synergy"] + - ["redundancy", "synergy"] + - ["magnitude", "rayleigh_quotient"] + - ["magnitude", "taylor"] + - ["taylor", "rayleigh_quotient"] + +# ----------------------------------------------------------------------------- +# OUTPUT +# ----------------------------------------------------------------------------- +output: + dir: "./results/vision/vgg16_cifar10" + save_metrics: true + save_clusters: true + save_figures: true + save_checkpoints: true + save_per_layer: true + +# ----------------------------------------------------------------------------- +# EXTRA (Vision-specific detailed settings) +# ----------------------------------------------------------------------------- +extra: + pretrain_epochs: 20 + pretrain_lr: 0.001 + + baselines: + - "magnitude" + - "taylor" + - "network_slimming" + - "geometric_median" + + analysis: + layer_indices: "all" + save_scores: true + generate_plots: true + + metrics: + - "rayleigh_quotient" + - "redundancy" + - "synergy" + - "magnitude" + - "taylor" + - "activation_sparsity" + + plots: + histograms: true + scatter_plots: true + pruning_curves: true + layer_comparison: true + filter_correlation: true + + scatter_pairs: + - ["rayleigh_quotient", "redundancy"] + - ["rayleigh_quotient", "synergy"] + - ["magnitude", "taylor"] + - ["redundancy", "synergy"] + + sensitivity_analysis: + enabled: true + per_layer: true + ratios: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] # More ratios for VGG + metric: "accuracy" + output_dir: "sensitivity" + + structured_pruning: + enabled: true + granularity: "filter" + importance_criteria: + - "l1_norm" + - "l2_norm" + - "taylor" + - "alignment" + + feature_analysis: + enabled: true + compute_feature_rank: true + compute_channel_redundancy: true + visualize_filters: false + num_samples_to_visualize: 10 + + efficiency: + track_flops: true + track_params: true + track_memory: true + 
track_latency: true + baseline_comparison: true + + visualization: + save_plots: true + format: "pdf" + dpi: 300 + style: "seaborn-v0_8-paper" + + metric_distributions: + enabled: true + by_layer: true + by_cluster: true + + cluster_analysis: + enabled: true + scatter_3d: true + cluster_evolution_by_layer: true + cluster_purity: true + + pruning_comparison: + enabled: true + accuracy_vs_sparsity: true + accuracy_vs_flops: true + accuracy_vs_params: true + methods_to_compare: + - "random" + - "magnitude" + - "taylor" + - "composite" + - "cluster_aware" + - "network_slimming" + + layer_importance: + enabled: true + heatmap: true + bar_chart: true + + fine_tuning_recovery: + enabled: true + by_method: true + by_sparsity: true + + efficiency_tradeoffs: + enabled: true + accuracy_vs_flops: true + accuracy_vs_latency: true + accuracy_vs_params: true diff --git a/docs/source/developer_guide/extensibility.rst b/docs/source/developer_guide/extensibility.rst new file mode 100644 index 00000000..5008fcc5 --- /dev/null +++ b/docs/source/developer_guide/extensibility.rst @@ -0,0 +1,275 @@ +Extending the Framework +======================= + +This guide explains how to extend the alignment framework with custom components +using the registry system. + +Overview +-------- + +The framework uses a **registry system** where components register themselves +using decorators. When you import a module with ``@register_*`` decorators, +the components automatically become available. + +This enables: + +1. **Plugin-based architecture** - Add new components without modifying core code +2. **Configuration-driven instantiation** - Create components from config files +3. **Auto-discovery** - Automatically find and register components from packages +4. **Metadata tracking** - Store component capabilities, requirements, and documentation + +Available Registries +-------------------- + +.. 
list-table:: + :header-rows: 1 + + * - Registry + - Decorator + - Description + * - ``METRIC_REGISTRY`` + - ``@register_metric`` + - Per-neuron metrics (RQ, MI, etc.) + * - ``ANALYZER_REGISTRY`` + - ``@register_analyzer`` + - Analysis pipelines (clustering, halo) + * - ``PRUNER_REGISTRY`` + - ``@register_pruner`` + - Pruning strategies + * - ``VISUALIZER_REGISTRY`` + - ``@register_visualizer`` + - Visualization components + * - ``EVALUATOR_REGISTRY`` + - ``@register_evaluator`` + - Model evaluation + * - ``EXPERIMENT_REGISTRY`` + - ``@register_experiment`` + - Full experiment pipelines + +Creating a Custom Metric +------------------------ + +Here's how to create and register a custom alignment metric: + +.. code-block:: python + + from alignment.core.registry import register_metric + from alignment.core.protocols import BaseMetric + import torch + + @register_metric( + "activation_kurtosis", + category="statistical", + description="Measures kurtosis of activation distributions per neuron", + tags=["statistics", "distribution", "outlier"], + aliases=["kurtosis", "act_kurt"], + ) + class ActivationKurtosis(BaseMetric): + """ + Compute excess kurtosis of activations for each neuron. + + High kurtosis indicates heavy tails (potential outlier neurons). + """ + + name = "activation_kurtosis" + requires_inputs = False + requires_weights = False + requires_outputs = True + + def __init__(self, fisher: bool = True): + """ + Args: + fisher: If True, compute excess kurtosis (subtract 3). + """ + self.fisher = fisher + + def compute( + self, + inputs=None, + weights=None, + outputs=None, + **kwargs + ) -> torch.Tensor: + """ + Compute kurtosis for each neuron/channel. 
+ + Args: + outputs: Activations [batch_size, num_neurons] + + Returns: + Kurtosis values [num_neurons] + """ + if outputs is None: + raise ValueError("ActivationKurtosis requires outputs") + + # Handle different tensor shapes + if outputs.dim() == 4: + # Conv layer: [batch, channels, h, w] -> [batch, channels] + outputs = outputs.mean(dim=(2, 3)) + + # Compute per-neuron statistics + mean = outputs.mean(dim=0) + std = outputs.std(dim=0) + 1e-8 + z = (outputs - mean) / std + m4 = (z ** 4).mean(dim=0) + + if self.fisher: + return m4 - 3.0 + return m4 + +Creating a Custom Analyzer +-------------------------- + +Analyzers perform higher-level analysis on metrics: + +.. code-block:: python + + from alignment.core.registry import register_analyzer + from alignment.core.protocols import BaseAnalyzer + import numpy as np + + @register_analyzer( + "layer_similarity", + category="comparison", + description="Analyze similarity between layers using CKA", + tags=["cka", "similarity", "cross-layer"], + ) + class LayerSimilarityAnalyzer(BaseAnalyzer): + """Analyze representational similarity between layers using CKA.""" + + name = "layer_similarity" + requires = ["activations"] + provides = ["similarity_matrix", "layer_clusters"] + + def __init__(self, method: str = "linear_cka"): + self.method = method + + def analyze(self, metrics, model=None, activations=None, **kwargs): + """Compute layer-to-layer similarity matrix.""" + if activations is None: + raise ValueError("LayerSimilarityAnalyzer requires activations") + + # Your analysis logic here + layer_names = list(activations.keys()) + n_layers = len(layer_names) + similarity_matrix = np.zeros((n_layers, n_layers)) + + # ... compute CKA similarity ... 
+ + return { + "similarity_matrix": similarity_matrix.tolist(), + "layer_names": layer_names, + "method": self.method, + } + + def visualize(self, results, output_dir=None, **kwargs): + """Generate similarity heatmap.""" + # Your visualization logic here + return [] # List of saved figure paths + +Creating a Custom Pruner +------------------------ + +Pruning strategies define how to select neurons for removal: + +.. code-block:: python + + from alignment.core.registry import register_pruner + from alignment.core.protocols import BasePruner + import torch + + @register_pruner( + "entropy_based", + category="information", + description="Prune neurons with low activation entropy", + tags=["entropy", "information", "diversity"], + ) + class EntropyBasedPruner(BasePruner): + """Prune neurons based on activation entropy.""" + + name = "entropy_based" + structured = True + + def __init__(self, n_bins: int = 50): + self.n_bins = n_bins + + def compute_importance(self, model, layer_name, activations=None, **kwargs): + """Compute entropy-based importance scores.""" + if activations is None: + raise ValueError("Requires activations") + + # Compute entropy per neuron + n_neurons = activations.size(-1) + entropies = torch.zeros(n_neurons) + + for i in range(n_neurons): + hist = torch.histc(activations[..., i], bins=self.n_bins) + probs = hist / hist.sum() + probs = probs[probs > 0] + entropies[i] = -(probs * torch.log2(probs)).sum() + + return entropies + +Using Custom Components +----------------------- + +Once registered, custom components can be used by name: + +.. 
code-block:: python + + from alignment.core.registry import get_metric, initialize_registries + + # Initialize (discovers built-in + custom components) + initialize_registries() + + # Use by name + metric = get_metric("activation_kurtosis", fisher=True) + scores = metric.compute(outputs=activations) + + # Use alias + metric = get_metric("kurtosis") # Same as "activation_kurtosis" + + # Search for metrics + from alignment.core import METRIC_REGISTRY + statistical_metrics = METRIC_REGISTRY.search(tags=["statistics"]) + +Plugin Discovery +---------------- + +Place your custom components in these locations for auto-discovery: + +- ``./plugins/`` (project-local) +- ``~/.alignment/plugins/`` (user-global) + +They will be automatically loaded when the framework initializes. + +Or manually load from a custom location: + +.. code-block:: python + + from alignment.core.registry import discover_plugins + + discover_plugins(["./my_custom_plugins/"]) + +Using in Configuration Files +---------------------------- + +Custom components can be referenced in YAML configs by name: + +.. code-block:: yaml + + pruning: + algorithms: + - "activation_kurtosis" # Your custom metric! + - "entropy_based" # Your custom pruner! + - "magnitude" # Built-in + +Best Practices +-------------- + +1. **Use meaningful names**: Choose descriptive, unique names +2. **Add metadata**: Tags and descriptions help discoverability +3. **Follow protocols**: Implement the required interface methods +4. **Document**: Add docstrings explaining what your component does +5. **Test**: Include tests for your custom components +6. **Handle edge cases**: Check for None inputs, empty tensors, etc. 
diff --git a/docs/source/developer_guide/index.rst b/docs/source/developer_guide/index.rst new file mode 100644 index 00000000..0548bec0 --- /dev/null +++ b/docs/source/developer_guide/index.rst @@ -0,0 +1,30 @@ +Developer Guide +=============== + +This section contains documentation for developers who want to extend or contribute +to the alignment framework. + +.. toctree:: + :maxdepth: 2 + + extensibility + internal/index + +Overview +-------- + +The alignment framework is designed to be highly extensible. You can add: + +- **Custom Metrics**: Define new per-neuron alignment metrics +- **Custom Analyzers**: Create new analysis pipelines (clustering, halo, etc.) +- **Custom Pruners**: Implement new pruning strategies +- **Custom Visualizers**: Add new plot types +- **Custom Evaluators**: Define new evaluation methods + +See :doc:`extensibility` for detailed instructions and examples. + +Internal Documentation +---------------------- + +The :doc:`internal/index` section contains documentation for maintainers about +codebase organization and documentation structure. 
diff --git a/slurm_jobs/paper/run_all_paper.sh b/slurm_jobs/prune_llm/run_all_paper.sh similarity index 100% rename from slurm_jobs/paper/run_all_paper.sh rename to slurm_jobs/prune_llm/run_all_paper.sh diff --git a/slurm_jobs/paper/run_llama2_7b.sh b/slurm_jobs/prune_llm/run_llama2_7b.sh similarity index 100% rename from slurm_jobs/paper/run_llama2_7b.sh rename to slurm_jobs/prune_llm/run_llama2_7b.sh diff --git a/slurm_jobs/paper/run_llama3_8b.sh b/slurm_jobs/prune_llm/run_llama3_8b.sh similarity index 91% rename from slurm_jobs/paper/run_llama3_8b.sh rename to slurm_jobs/prune_llm/run_llama3_8b.sh index 15569ec5..35f9bd17 100755 --- a/slurm_jobs/paper/run_llama3_8b.sh +++ b/slurm_jobs/prune_llm/run_llama3_8b.sh @@ -8,7 +8,7 @@ #SBATCH --cpus-per-task=16 #SBATCH --time=12:00:00 #SBATCH --mem=320GB -#SBATCH --partition=kempner_h100 +#SBATCH --partition=kempner_eng #SBATCH --account=kempner_dev # ============================================================================ @@ -55,8 +55,11 @@ echo "" echo "Running LLaMA-3.1-8B full paper analysis..." 
echo "" +# python scripts/run_experiment.py \ +# --config configs/paper/llama3_8b_full.yaml \ +# --device cuda python scripts/run_experiment.py \ - --config configs/paper/llama3_8b_full.yaml \ + --config configs/prune_llm/llama3_8b_unified.yaml \ --device cuda echo "" diff --git a/slurm_jobs/paper/run_mistral_7b.sh b/slurm_jobs/prune_llm/run_mistral_7b.sh similarity index 100% rename from slurm_jobs/paper/run_mistral_7b.sh rename to slurm_jobs/prune_llm/run_mistral_7b.sh diff --git a/slurm_jobs/paper/run_qwen2_7b.sh b/slurm_jobs/prune_llm/run_qwen2_7b.sh similarity index 100% rename from slurm_jobs/paper/run_qwen2_7b.sh rename to slurm_jobs/prune_llm/run_qwen2_7b.sh diff --git a/slurm_jobs/run_cluster_analysis_resnet18.sh b/slurm_jobs/prune_vision/run_cluster_analysis_resnet18.sh similarity index 96% rename from slurm_jobs/run_cluster_analysis_resnet18.sh rename to slurm_jobs/prune_vision/run_cluster_analysis_resnet18.sh index 360719df..4afda69c 100644 --- a/slurm_jobs/run_cluster_analysis_resnet18.sh +++ b/slurm_jobs/prune_vision/run_cluster_analysis_resnet18.sh @@ -50,9 +50,9 @@ export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK echo "" echo "Running ResNet-18 cluster analysis..." echo "" - + python scripts/run_experiment.py \ - --config configs/cluster_analysis/resnet18_cifar10_full.yaml \ + --config configs/vision_prune/resnet18_cifar10_unified.yaml \ --device cuda EXIT_CODE=$? 
diff --git a/slurm_jobs/run_cluster_analysis_resnet50.sh b/slurm_jobs/prune_vision/run_cluster_analysis_resnet50.sh similarity index 100% rename from slurm_jobs/run_cluster_analysis_resnet50.sh rename to slurm_jobs/prune_vision/run_cluster_analysis_resnet50.sh diff --git a/src/alignment/analysis/visualization/cluster_plots.py b/src/alignment/analysis/visualization/cluster_plots.py index 34dfafc6..b76950f4 100644 --- a/src/alignment/analysis/visualization/cluster_plots.py +++ b/src/alignment/analysis/visualization/cluster_plots.py @@ -640,7 +640,7 @@ def plot_layer_metric_summary( n_valid += 1 except: continue - + if n_valid > 0: avg_corr = corr_sum / n_valid im = ax.imshow(avg_corr, cmap='RdBu_r', vmin=-1, vmax=1) diff --git a/src/alignment/configs/__init__.py b/src/alignment/configs/__init__.py index d1d18b9c..804ff44c 100644 --- a/src/alignment/configs/__init__.py +++ b/src/alignment/configs/__init__.py @@ -1,8 +1,78 @@ """ Configuration management for the Neural Network Alignment framework. + +This module provides: +- Legacy config loading (load_config) +- Unified config schema (UnifiedConfig) +- Config validation + +Usage: + # Legacy (still works) + from alignment.configs import load_config + config = load_config("path/to/config.yaml") + + # New unified config (recommended) + from alignment.configs import load_unified_config, UnifiedConfig + config = load_unified_config("path/to/config.yaml") + + # Programmatic config + from alignment.configs import UnifiedConfig, ExperimentConfig + config = UnifiedConfig( + experiment=ExperimentConfig(name="my_exp", type="cluster_analysis"), + ... 
+ ) """ from .config_loader import load_config, save_config from .config_validator import validate_config -__all__ = ["load_config", "save_config", "validate_config"] +# Unified config system +from .unified_config import ( + # Main config class + UnifiedConfig, + # Sub-config classes + ExperimentConfig, + ModelConfig, + DatasetConfig, + CalibrationConfig, + MetricsConfig, + MetricItemConfig, + ClusteringConfig, + SupernodeConfig, + HaloConfig, + CascadeConfig, + PruningConfig, + PruningMethodConfig, + EvaluationConfig, + VisualizationConfig, + OutputConfig, + # Loading functions + load_unified_config, + create_config_template, +) + +__all__ = [ + # Legacy + "load_config", + "save_config", + "validate_config", + # Unified config + "UnifiedConfig", + "ExperimentConfig", + "ModelConfig", + "DatasetConfig", + "CalibrationConfig", + "MetricsConfig", + "MetricItemConfig", + "ClusteringConfig", + "SupernodeConfig", + "HaloConfig", + "CascadeConfig", + "PruningConfig", + "PruningMethodConfig", + "EvaluationConfig", + "VisualizationConfig", + "OutputConfig", + "load_unified_config", + "create_config_template", +] diff --git a/src/alignment/configs/config_loader.py b/src/alignment/configs/config_loader.py index 321b39ae..2d3e0350 100644 --- a/src/alignment/configs/config_loader.py +++ b/src/alignment/configs/config_loader.py @@ -1,5 +1,7 @@ """ Configuration loading and saving utilities. + +Supports both original format and unified format configs. 
""" import json @@ -15,9 +17,387 @@ logger = logging.getLogger(__name__) +# ============================================================================= +# UNIFIED FORMAT DETECTION AND CONVERSION +# ============================================================================= + +# Metric name mappings: unified -> original +METRIC_UNIFIED_TO_ORIGINAL = { + "rayleigh_quotient": "rayleigh_quotient", + "redundancy": "gaussian_mi_analytic", + "synergy": "synergy_gaussian_mmi", + "magnitude": "activation_l2_norm", +} + +# Reverse mapping: original -> unified +METRIC_ORIGINAL_TO_UNIFIED = {v: k for k, v in METRIC_UNIFIED_TO_ORIGINAL.items()} +METRIC_ORIGINAL_TO_UNIFIED.update({ + "average_redundancy": "redundancy", + "pairwise_redundancy_gaussian": "redundancy", + "gaussian_mi": "redundancy", +}) + + +def _is_unified_format(config_dict: Dict[str, Any]) -> bool: + """ + Detect if config is in unified format. + + Unified format characteristics: + - metrics block has nested dicts with 'enabled' keys (not a list) + - Has 'extra' section for experiment-specific settings + - Uses unified metric names (redundancy, magnitude, etc.) + """ + metrics = config_dict.get("metrics", {}) + if not isinstance(metrics, dict): + return False + + # Unified format: metrics.rayleigh_quotient.enabled exists + # Original format: metrics.enabled is a list + if "enabled" in metrics and isinstance(metrics["enabled"], list): + return False + + # Check for unified metric structure + unified_metrics = ["rayleigh_quotient", "redundancy", "synergy", "magnitude", "scar"] + for metric in unified_metrics: + if metric in metrics and isinstance(metrics[metric], dict): + if "enabled" in metrics[metric]: + return True + + # Check for 'extra' section (strong indicator of unified format) + if "extra" in config_dict: + return True + + return False + + +def _convert_unified_to_original(unified: Dict[str, Any]) -> Dict[str, Any]: + """ + Convert unified format config to original format. 
+ + This ensures that the unified config produces the exact same + ExperimentConfig as the original format would. + """ + original = {} + + # ------------------------------------------------------------------------- + # EXPERIMENT + # ------------------------------------------------------------------------- + if "experiment" in unified: + exp = unified["experiment"] + original["experiment"] = { + "name": exp.get("name", "experiment"), + "type": exp.get("type", "alignment_analysis"), + } + original["seed"] = exp.get("seed", 42) + original["device"] = exp.get("device", "cuda") + if "output_dir" in exp: + original["results_path"] = exp["output_dir"] + if "num_networks" in exp: + original["num_networks"] = exp["num_networks"] + + # ------------------------------------------------------------------------- + # MODEL + # ------------------------------------------------------------------------- + if "model" in unified: + model = unified["model"] + original["model"] = { + "name": model.get("name", "resnet18"), + "pretrained": model.get("pretrained", True), + } + # LLM fields + if "model_id" in model: + original["model"]["model_id"] = model["model_id"] + if "dtype" in model: + original["model"]["dtype"] = model["dtype"] + if "device_map" in model: + original["model"]["device_map"] = model["device_map"] + if "trust_remote_code" in model: + original["model"]["trust_remote_code"] = model["trust_remote_code"] + if "tracked_layers" in model: + original["model"]["tracked_layers"] = model["tracked_layers"] + if "num_classes" in model: + original["model"]["num_classes"] = model["num_classes"] + + # ------------------------------------------------------------------------- + # DATASET + # ------------------------------------------------------------------------- + if "dataset" in unified: + dataset = unified["dataset"] + original["dataset"] = { + "name": dataset.get("name", "cifar10"), + "batch_size": dataset.get("batch_size", 128), + "num_workers": dataset.get("num_workers", 4), + } + 
if "subset" in dataset: + original["dataset"]["subset"] = dataset["subset"] + if "split" in dataset: + original["dataset"]["split"] = dataset["split"] + if "root" in dataset: + original["dataset"]["data_path"] = dataset["root"] + + # ------------------------------------------------------------------------- + # CALIBRATION + # ------------------------------------------------------------------------- + if "calibration" in unified: + cal = unified["calibration"] + # Put calibration info in metrics block (where original format expects it) + if "metrics" not in original: + original["metrics"] = {} + original["metrics"]["num_samples"] = cal.get("num_samples", 5000) + # Also keep calibration block for LLM experiments + original["calibration"] = cal + + # ------------------------------------------------------------------------- + # METRICS - Convert unified names to original names + # ------------------------------------------------------------------------- + if "metrics" in unified: + metrics = unified["metrics"] + enabled_metrics = [] + metric_configs = {} + + # Check each unified metric + for unified_name, original_name in METRIC_UNIFIED_TO_ORIGINAL.items(): + if unified_name in metrics: + metric_cfg = metrics[unified_name] + if isinstance(metric_cfg, dict): + if metric_cfg.get("enabled", True): + enabled_metrics.append(original_name) + # Copy metric-specific params + params = {k: v for k, v in metric_cfg.items() if k != "enabled"} + if params: + metric_configs[original_name] = params + elif metric_cfg is True: + enabled_metrics.append(original_name) + + # Handle SCAR metrics (LLM-specific) + if "scar" in metrics: + scar = metrics["scar"] + if isinstance(scar, dict) and scar.get("enabled", True): + original["do_scar_metrics"] = True + original["scar_num_samples"] = scar.get("num_samples", 64) + original["scar_max_length"] = scar.get("max_length", 512) + + # Handle additional metrics + # Note: Skip analysis-derived metrics that are computed by analysis pipelines + # (not 
standalone metrics that can be computed independently) + ANALYSIS_DERIVED_METRICS = { + "supernode_protection_score", + "supernode_connectivity_score", + "scar_activation_power", + "scar_curvature", + "scar_loss_proxy", + "scar_taylor", + } + if "additional" in metrics: + for name, cfg in metrics["additional"].items(): + if name in ANALYSIS_DERIVED_METRICS: + continue # Skip analysis-derived metrics + if isinstance(cfg, dict) and cfg.get("enabled", True): + enabled_metrics.append(name) + + original["metrics"] = { + "enabled": enabled_metrics, + **metric_configs, + } + + # Composite weights - convert unified names to original + if "composite_weights" in metrics: + comp_weights = {} + for name, weight in metrics["composite_weights"].items(): + original_name = METRIC_UNIFIED_TO_ORIGINAL.get(name, name) + comp_weights[original_name] = weight + original["metrics"]["composite_weights"] = comp_weights + + # ------------------------------------------------------------------------- + # SUPERNODE (LLM outlier detection) + # ------------------------------------------------------------------------- + if "supernode" in unified: + original["supernode"] = unified["supernode"] + + # ------------------------------------------------------------------------- + # CLUSTERING (Vision) + # ------------------------------------------------------------------------- + if "clustering" in unified: + original["clustering"] = unified["clustering"] + + # ------------------------------------------------------------------------- + # HALO ANALYSIS + # ------------------------------------------------------------------------- + if "halo_analysis" in unified: + original["halo_analysis"] = unified["halo_analysis"] + if unified["halo_analysis"].get("enabled"): + original["do_halo_analysis"] = True + + # ------------------------------------------------------------------------- + # CASCADE ANALYSIS + # ------------------------------------------------------------------------- + if "cascade_analysis" in 
unified: + original["cascade_analysis"] = unified["cascade_analysis"] + + # ------------------------------------------------------------------------- + # PRUNING - Convert unified metric names in algorithms/scoring_methods + # ------------------------------------------------------------------------- + if "pruning" in unified: + pruning = unified["pruning"] + original_pruning = { + "enabled": pruning.get("enabled", True), + } + + # Ratios/sparsity levels + if "ratios" in pruning: + original_pruning["sparsity_levels"] = pruning["ratios"] + elif "sparsity_levels" in pruning: + original_pruning["sparsity_levels"] = pruning["sparsity_levels"] + + # Selection modes + if "selection_modes" in pruning: + original_pruning["selection_modes"] = pruning["selection_modes"] + + # Convert algorithm names + if "algorithms" in pruning: + converted_algorithms = [] + for alg in pruning["algorithms"]: + converted_algorithms.append(METRIC_UNIFIED_TO_ORIGINAL.get(alg, alg)) + original_pruning["algorithms"] = converted_algorithms + + # Convert scoring methods + if "scoring_methods" in pruning: + converted_scoring = [] + for method in pruning["scoring_methods"]: + converted_scoring.append(METRIC_UNIFIED_TO_ORIGINAL.get(method, method)) + original_pruning["scoring_methods"] = converted_scoring + + # Other pruning fields + for key in ["distribution", "structured", "dependency_aware", "target", "single_strategy"]: + if key in pruning: + original_pruning[key] = pruning[key] + + # Fine-tune settings + if "fine_tune" in pruning: + original_pruning["fine_tune"] = pruning["fine_tune"] + + original["pruning"] = original_pruning + + # ------------------------------------------------------------------------- + # EVALUATION + # ------------------------------------------------------------------------- + if "evaluation" in unified: + ev = unified["evaluation"] + original["evaluation"] = {"enabled": ev.get("enabled", True)} + + # Perplexity (LLM) + if ev.get("perplexity_enabled"): + 
original["do_perplexity_computation"] = True + if "perplexity_datasets" in ev: + # Convert to original format + original["evaluation"]["perplexity"] = { + "enabled": True, + "datasets": ev["perplexity_datasets"], + } + + # Benchmarks (LLM) + if ev.get("benchmarks_enabled"): + if "benchmark_tasks" in ev: + original["evaluation"]["benchmarks"] = ev["benchmark_tasks"] + original["evaluation"]["batch_size"] = ev.get("benchmark_batch_size", 8) + + # ------------------------------------------------------------------------- + # VISUALIZATION + # ------------------------------------------------------------------------- + if "visualization" in unified: + viz = unified["visualization"] + original["visualization"] = viz + if viz.get("enabled", True): + original["generate_plots"] = True + if "format" in viz: + original["plot_format"] = viz["format"] + if "dpi" in viz: + original["plot_dpi"] = viz["dpi"] + + # ------------------------------------------------------------------------- + # OUTPUT + # ------------------------------------------------------------------------- + if "output" in unified: + out = unified["output"] + if "dir" in out: + original["results_path"] = out["dir"] + # Also set experiment.output_dir for compatibility + if "experiment" in original: + original["experiment"]["output_dir"] = out["dir"] + + # ------------------------------------------------------------------------- + # EXTRA - Expand LLM-specific settings from extra block + # ------------------------------------------------------------------------- + if "extra" in unified: + extra = unified["extra"] + + # Analysis options (with all detailed settings) + if "analysis" in extra: + original["analysis"] = extra["analysis"] + + # Supernode robustness + if "supernode_robustness" in extra: + original["supernode_robustness"] = extra["supernode_robustness"] + + # Supernode summary + if "supernode_summary" in extra: + original["supernode_summary"] = extra["supernode_summary"] + + # Multi-supernode + if 
"multi_supernode" in extra: + original["multi_supernode"] = extra["multi_supernode"] + + # Cross-layer + if "cross_layer" in extra: + original["cross_layer"] = extra["cross_layer"] + + # Generalized importance + if "generalized_importance" in extra: + original["generalized_importance"] = extra["generalized_importance"] + if extra["generalized_importance"].get("enabled"): + original["do_generalized_importance"] = True + + # Halo analysis (detailed settings from extra override top-level) + if "halo_analysis" in extra: + if "halo_analysis" not in original: + original["halo_analysis"] = {} + original["halo_analysis"].update(extra["halo_analysis"]) + + # Visualization (detailed paper figure settings) + if "visualization" in extra: + if "visualization" not in original: + original["visualization"] = {} + original["visualization"].update(extra["visualization"]) + + # Top-level flags + for flag in ["do_scar_metrics", "do_directed_redundancy", "do_connectivity_pruning", + "do_halo_analysis", "do_generalized_importance"]: + if flag in extra: + original[flag] = extra[flag] + + # LLM block reconstruction + llm_block = {} + if original.get("do_scar_metrics"): + llm_block["scar_metrics"] = True + if "scar_num_samples" in original: + llm_block["scar_num_samples"] = original["scar_num_samples"] + if "scar_max_length" in original: + llm_block["scar_max_length"] = original["scar_max_length"] + if original.get("do_perplexity_computation"): + llm_block["evaluate_perplexity"] = True + if llm_block: + original["llm"] = llm_block + + logger.info("Converted unified config to original format") + return original + + def load_config(config_path: Union[str, Path]) -> ExperimentConfig: """ Load configuration from a YAML or JSON file. + + Supports both original format and unified format configs. + Unified format configs are automatically detected and converted. 
Args: config_path: Path to configuration file @@ -47,6 +427,11 @@ def load_config(config_path: Union[str, Path]) -> ExperimentConfig: # Handle environment variable substitution config_dict = _substitute_env_vars(config_dict) + # Detect and convert unified format to original format + if _is_unified_format(config_dict): + logger.info(f"Detected unified config format in {config_path}") + config_dict = _convert_unified_to_original(config_dict) + # Map nested config to flat ExperimentConfig structure config_dict = _map_nested_to_flat_config(config_dict) diff --git a/src/alignment/configs/unified_config.py b/src/alignment/configs/unified_config.py new file mode 100644 index 00000000..dc2832da --- /dev/null +++ b/src/alignment/configs/unified_config.py @@ -0,0 +1,683 @@ +""" +Unified configuration schema for the alignment framework. + +This module provides a single, validated configuration structure that works +for all experiment types (vision, LLM, custom). Configuration can be loaded +from YAML files and validated against the schema. + +Features: +- Unified structure for vision and LLM experiments +- Automatic validation with helpful error messages +- Default values with override capability +- Support for config inheritance/presets +- Registry-aware: validates that referenced components exist + +Usage: + from alignment.configs.unified_config import load_unified_config, UnifiedConfig + + # From YAML file + config = load_unified_config("configs/my_experiment.yaml") + + # Programmatically + config = UnifiedConfig( + experiment=ExperimentConfig(name="my_exp", type="cluster_analysis"), + model=ModelConfig(name="resnet18"), + ... 
+ ) +""" + +import logging +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import yaml + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# CONFIGURATION DATACLASSES +# ============================================================================= + +@dataclass +class ExperimentConfig: + """Experiment-level configuration.""" + name: str = "experiment" + type: str = "alignment_analysis" # alignment_analysis, cluster_analysis, llm_alignment + seed: int = 42 + device: str = "cuda" + output_dir: str = "./results" + num_networks: int = 1 + + # Backward compatibility aliases + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> "ExperimentConfig": + # Handle old-style flat configs + return cls( + name=d.get("name") or d.get("experiment_name", "experiment"), + type=d.get("type") or d.get("experiment_type", "alignment_analysis"), + seed=d.get("seed", 42), + device=d.get("device", "cuda"), + output_dir=d.get("output_dir", "./results"), + num_networks=d.get("num_networks", 1), + ) + + +@dataclass +class ModelConfig: + """Model configuration (works for both vision and LLM).""" + name: str = "resnet18" + pretrained: bool = True + num_classes: Optional[int] = None # Vision only + + # HuggingFace / LLM specific + model_id: Optional[str] = None # e.g., "meta-llama/Llama-3.1-8B" + dtype: str = "bfloat16" + device_map: Optional[str] = "auto" + trust_remote_code: bool = True + + # Layer tracking + tracked_layers: Optional[List[str]] = None + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> "ModelConfig": + return cls(**{k: v for k, v in d.items() if k in cls.__dataclass_fields__}) + + +@dataclass +class DatasetConfig: + """Dataset configuration.""" + name: str = "cifar10" + root: str = "./data" + batch_size: int = 128 + num_workers: int = 4 + + # LLM specific + subset: Optional[str] = None # e.g., 
"wikitext-2-raw-v1" + split: str = "train" + max_length: int = 2048 + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> "DatasetConfig": + return cls(**{k: v for k, v in d.items() if k in cls.__dataclass_fields__}) + + +@dataclass +class CalibrationConfig: + """Calibration data configuration.""" + num_samples: int = 5000 # Vision: 5000, LLM: 64-128 + max_length: int = 2048 # LLM only + batch_size: int = 4 + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> "CalibrationConfig": + # Handle old-style nested configs + if "n_calibration_samples" in d: + d["num_samples"] = d.pop("n_calibration_samples") + return cls(**{k: v for k, v in d.items() if k in cls.__dataclass_fields__}) + + +@dataclass +class MetricItemConfig: + """Configuration for a single metric.""" + enabled: bool = True + # Metric-specific parameters stored as dict + params: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class MetricsConfig: + """Metrics configuration (unified naming).""" + # Core metrics available for both vision and LLM + rayleigh_quotient: MetricItemConfig = field(default_factory=lambda: MetricItemConfig(enabled=True, params={"relative": True})) + redundancy: MetricItemConfig = field(default_factory=lambda: MetricItemConfig(enabled=True, params={"sampling": "all"})) + synergy: MetricItemConfig = field(default_factory=lambda: MetricItemConfig(enabled=True, params={"target": "logit_margin", "num_pairs": 10})) + magnitude: MetricItemConfig = field(default_factory=lambda: MetricItemConfig(enabled=True, params={"type": "l2_norm"})) + + # Additional metrics (optional) + additional: Dict[str, MetricItemConfig] = field(default_factory=dict) + + # Composite weights for combined scoring + composite_weights: Dict[str, float] = field(default_factory=lambda: { + "rayleigh_quotient": 0.33, + "redundancy": -0.33, + "synergy": 0.33, + }) + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> "MetricsConfig": + config = cls() + + # Handle 'enabled' list format + if 
"enabled" in d and isinstance(d["enabled"], list): + for metric_name in d["enabled"]: + canonical = _normalize_metric_name(metric_name) + if hasattr(config, canonical): + getattr(config, canonical).enabled = True + + # Handle individual metric configs + for metric_name in ["rayleigh_quotient", "redundancy", "synergy", "magnitude"]: + if metric_name in d: + if isinstance(d[metric_name], dict): + getattr(config, metric_name).params.update(d[metric_name]) + elif isinstance(d[metric_name], bool): + getattr(config, metric_name).enabled = d[metric_name] + + # Handle old-style boolean flags + if d.get("compute_rq"): + config.rayleigh_quotient.enabled = True + if d.get("compute_redundancy"): + config.redundancy.enabled = True + if d.get("compute_synergy"): + config.synergy.enabled = True + + if "composite_weights" in d: + config.composite_weights.update(d["composite_weights"]) + + return config + + +@dataclass +class ClusteringConfig: + """Clustering configuration.""" + enabled: bool = True + n_clusters: int = 4 + type_names: List[str] = field(default_factory=lambda: ["critical", "redundant", "synergistic", "background"]) + normalize_features: bool = True + features: List[str] = field(default_factory=lambda: ["rayleigh_quotient", "redundancy", "synergy"]) + + # Stability analysis + stability_enabled: bool = True + n_bootstrap: int = 50 + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> "ClusteringConfig": + if "compute_stability" in d: + d["stability_enabled"] = d.pop("compute_stability") + return cls(**{k: v for k, v in d.items() if k in cls.__dataclass_fields__}) + + +@dataclass +class SupernodeConfig: + """Supernode detection configuration.""" + enabled: bool = False + score_metric: str = "synergy" # Or: rayleigh_quotient, magnitude, scar_loss_proxy + core_fraction: float = 0.01 # Top 1% as supernodes + halo_fraction: float = 0.10 # 10% as halo + follower_fraction: float = 0.10 + protect_core: bool = True + + # Cross-layer analysis + cross_layer_analysis: bool 
= True + compare_by_connection: bool = True + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> "SupernodeConfig": + return cls(**{k: v for k, v in d.items() if k in cls.__dataclass_fields__}) + + +@dataclass +class HaloConfig: + """Halo/cross-layer analysis configuration.""" + enabled: bool = True + percentile: float = 90.0 + use_activation_weight: bool = True + compute_influence_matrix: bool = True + + # For LLM + max_refs: int = 512 + sample_pairs: int = 2000 + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> "HaloConfig": + return cls(**{k: v for k, v in d.items() if k in cls.__dataclass_fields__}) + + +@dataclass +class CascadeConfig: + """Cascade/damage analysis configuration.""" + enabled: bool = True + n_remove_per_group: int = 5 + damage_sample_fraction: float = 0.2 + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> "CascadeConfig": + if "n_remove_per_cluster" in d: + d["n_remove_per_group"] = d.pop("n_remove_per_cluster") + return cls(**{k: v for k, v in d.items() if k in cls.__dataclass_fields__}) + + +@dataclass +class PruningMethodConfig: + """Configuration for a single pruning method.""" + name: str + selection: str = "low" # low, high, random + weights: Optional[Dict[str, float]] = None # For composite methods + params: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class PruningConfig: + """Pruning configuration.""" + enabled: bool = True + ratios: List[float] = field(default_factory=lambda: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]) + + # Methods to compare + methods: List[PruningMethodConfig] = field(default_factory=list) + + # Selection modes + selection_modes: List[str] = field(default_factory=lambda: ["low", "high"]) + + # Fine-tuning + fine_tune_enabled: bool = True + fine_tune_epochs: int = 10 + fine_tune_lr: float = 0.0001 + + # Distribution strategy + distribution: str = "uniform" # uniform, global_threshold, adaptive + structured: bool = True + dependency_aware: bool = False + + @classmethod + def 
from_dict(cls, d: Dict[str, Any]) -> "PruningConfig": + config = cls() + + if "ratios" in d: + config.ratios = d["ratios"] + elif "sparsity_levels" in d: + config.ratios = d["sparsity_levels"] + + if "methods" in d: + config.methods = [] + for method in d["methods"]: + if isinstance(method, str): + config.methods.append(PruningMethodConfig(name=method)) + elif isinstance(method, dict): + config.methods.append(PruningMethodConfig(**method)) + + if "fine_tune" in d or "fine_tuning" in d: + ft = d.get("fine_tune") or d.get("fine_tuning", {}) + config.fine_tune_enabled = ft.get("enabled", True) + config.fine_tune_epochs = ft.get("epochs", 10) + config.fine_tune_lr = ft.get("lr") or ft.get("learning_rate", 0.0001) + + for key in ["enabled", "distribution", "structured", "dependency_aware", "selection_modes"]: + if key in d: + setattr(config, key, d[key]) + + return config + + +@dataclass +class EvaluationConfig: + """Evaluation configuration.""" + enabled: bool = True + accuracy: bool = True + loss: bool = True + + # LLM specific + perplexity_enabled: bool = False + perplexity_datasets: List[str] = field(default_factory=lambda: ["wikitext"]) + benchmarks_enabled: bool = False + benchmark_tasks: List[str] = field(default_factory=list) + benchmark_fewshot: int = 0 + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> "EvaluationConfig": + config = cls() + + if "perplexity" in d: + ppl = d["perplexity"] + config.perplexity_enabled = ppl.get("enabled", True) + if "datasets" in ppl: + config.perplexity_datasets = [ + ds["name"] if isinstance(ds, dict) else ds + for ds in ppl["datasets"] + ] + + if "benchmarks" in d: + config.benchmarks_enabled = True + config.benchmark_tasks = [ + b["name"] if isinstance(b, dict) else b + for b in d["benchmarks"] + ] + + return config + + +@dataclass +class VisualizationConfig: + """Visualization configuration.""" + enabled: bool = True + format: str = "png" # png, pdf, svg + dpi: int = 300 + + # Plot types to generate + histograms: 
bool = True + violin_plots: bool = True + correlation_heatmap: bool = True + cluster_scatter: bool = True + cluster_evolution: bool = True + influence_matrix: bool = True + halo_properties: bool = True + pruning_comparison: bool = True + pruning_recovery: bool = True + cascade_test: bool = True + + # Scatter plot pairs + scatter_pairs: List[List[str]] = field(default_factory=lambda: [ + ["rayleigh_quotient", "redundancy"], + ["rayleigh_quotient", "synergy"], + ["redundancy", "synergy"], + ]) + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> "VisualizationConfig": + config = cls() + + for key, value in d.items(): + if hasattr(config, key): + setattr(config, key, value) + + # Handle nested 'plots' dict + if "plots" in d: + for key, value in d["plots"].items(): + if hasattr(config, key): + setattr(config, key, value) + + # Handle 'figures' list + if "figures" in d: + for fig_type in d["figures"]: + attr_name = fig_type.replace("-", "_") + if hasattr(config, attr_name): + setattr(config, attr_name, True) + + return config + + +@dataclass +class OutputConfig: + """Output configuration.""" + dir: str = "./results" + save_metrics: bool = True + save_clusters: bool = True + save_figures: bool = True + save_checkpoints: bool = False + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> "OutputConfig": + return cls(**{k: v for k, v in d.items() if k in cls.__dataclass_fields__}) + + +# ============================================================================= +# UNIFIED CONFIG +# ============================================================================= + +@dataclass +class UnifiedConfig: + """ + Unified configuration for all experiment types. + + This config structure works for: + - Vision experiments (ResNet, VGG, etc. on CIFAR, ImageNet) + - LLM experiments (Llama, Qwen, etc.) + - Custom experiments + + All fields have sensible defaults, so you only need to specify + what you want to change. 
+ """ + experiment: ExperimentConfig = field(default_factory=ExperimentConfig) + model: ModelConfig = field(default_factory=ModelConfig) + dataset: DatasetConfig = field(default_factory=DatasetConfig) + calibration: CalibrationConfig = field(default_factory=CalibrationConfig) + metrics: MetricsConfig = field(default_factory=MetricsConfig) + clustering: ClusteringConfig = field(default_factory=ClusteringConfig) + supernode: SupernodeConfig = field(default_factory=SupernodeConfig) + halo_analysis: HaloConfig = field(default_factory=HaloConfig) + cascade_analysis: CascadeConfig = field(default_factory=CascadeConfig) + pruning: PruningConfig = field(default_factory=PruningConfig) + evaluation: EvaluationConfig = field(default_factory=EvaluationConfig) + visualization: VisualizationConfig = field(default_factory=VisualizationConfig) + output: OutputConfig = field(default_factory=OutputConfig) + + # Extra fields for custom extensions + extra: Dict[str, Any] = field(default_factory=dict) + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> "UnifiedConfig": + """Create config from dictionary (e.g., loaded from YAML).""" + config = cls() + + # Handle experiment config (with backward compat for flat style) + if "experiment" in d: + config.experiment = ExperimentConfig.from_dict(d["experiment"]) + else: + config.experiment = ExperimentConfig.from_dict(d) + + # Model + if "model" in d: + config.model = ModelConfig.from_dict(d["model"]) + + # Dataset + if "dataset" in d: + config.dataset = DatasetConfig.from_dict(d["dataset"]) + + # Calibration (can be in 'calibration' or 'metrics') + if "calibration" in d: + config.calibration = CalibrationConfig.from_dict(d["calibration"]) + elif "metrics" in d: + config.calibration = CalibrationConfig.from_dict(d["metrics"]) + + # Metrics + if "metrics" in d: + config.metrics = MetricsConfig.from_dict(d["metrics"]) + + # Clustering + if "clustering" in d: + config.clustering = ClusteringConfig.from_dict(d["clustering"]) + + # 
Supernode + if "supernode" in d: + config.supernode = SupernodeConfig.from_dict(d["supernode"]) + + # Halo + if "halo_analysis" in d: + config.halo_analysis = HaloConfig.from_dict(d["halo_analysis"]) + + # Cascade + if "cascade_analysis" in d: + config.cascade_analysis = CascadeConfig.from_dict(d["cascade_analysis"]) + + # Pruning + if "pruning" in d: + config.pruning = PruningConfig.from_dict(d["pruning"]) + + # Evaluation + if "evaluation" in d: + config.evaluation = EvaluationConfig.from_dict(d["evaluation"]) + + # Visualization + if "visualization" in d: + config.visualization = VisualizationConfig.from_dict(d["visualization"]) + + # Output + if "output" in d: + config.output = OutputConfig.from_dict(d["output"]) + + # Store any extra fields + known_keys = { + "experiment", "model", "dataset", "calibration", "metrics", + "clustering", "supernode", "halo_analysis", "cascade_analysis", + "pruning", "evaluation", "visualization", "output", + "experiment_name", "experiment_type", "name", "type", "seed", "device", + } + config.extra = {k: v for k, v in d.items() if k not in known_keys} + + return config + + def to_dict(self) -> Dict[str, Any]: + """Convert config to dictionary.""" + import dataclasses + + def convert(obj): + if dataclasses.is_dataclass(obj): + return {k: convert(v) for k, v in dataclasses.asdict(obj).items()} + elif isinstance(obj, list): + return [convert(v) for v in obj] + elif isinstance(obj, dict): + return {k: convert(v) for k, v in obj.items()} + return obj + + return convert(self) + + def save(self, path: Union[str, Path]) -> None: + """Save config to YAML file.""" + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w") as f: + yaml.dump(self.to_dict(), f, default_flow_style=False, sort_keys=False) + logger.info(f"Saved config to {path}") + + def validate(self) -> List[str]: + """ + Validate the configuration. 
+ + Returns: + List of validation warnings/errors (empty if valid) + """ + warnings = [] + + # Check experiment type + valid_types = ["alignment_analysis", "cluster_analysis", "llm_alignment", "general_alignment"] + if self.experiment.type not in valid_types: + warnings.append(f"Unknown experiment type: {self.experiment.type}") + + # Check LLM-specific requirements + if self.experiment.type == "llm_alignment": + if not self.model.model_id: + warnings.append("LLM experiments require model.model_id") + if self.calibration.num_samples > 256: + warnings.append("LLM calibration typically uses < 256 samples") + + # Check metric consistency + if self.clustering.enabled: + for feature in self.clustering.features: + normalized = _normalize_metric_name(feature) + if hasattr(self.metrics, normalized): + if not getattr(self.metrics, normalized).enabled: + warnings.append(f"Clustering uses '{feature}' but it's not enabled in metrics") + + return warnings + + +# ============================================================================= +# HELPER FUNCTIONS +# ============================================================================= + +def _normalize_metric_name(name: str) -> str: + """Normalize metric names to canonical form.""" + aliases = { + "rq": "rayleigh_quotient", + "rayleigh": "rayleigh_quotient", + "alignment": "rayleigh_quotient", + "gaussian_mi": "redundancy", + "gaussian_mi_analytic": "redundancy", + "average_redundancy": "redundancy", + "pairwise_redundancy": "redundancy", + "pairwise_redundancy_gaussian": "redundancy", + "synergy_gaussian_mmi": "synergy", + "synergy_continuous": "synergy", + "activation_l2_norm": "magnitude", + "l2_norm": "magnitude", + } + return aliases.get(name.lower(), name.lower()) + + +def load_unified_config(path: Union[str, Path]) -> UnifiedConfig: + """ + Load a unified config from a YAML file. 
+ + Supports: + - New unified format + - Old LLM format (with 'experiment:' block) + - Old vision format (with flat 'experiment_name:') + - Config inheritance via '_inherit' key + + Args: + path: Path to YAML config file + + Returns: + UnifiedConfig instance + """ + path = Path(path) + + if not path.exists(): + raise FileNotFoundError(f"Config file not found: {path}") + + with open(path) as f: + raw = yaml.safe_load(f) + + # Handle inheritance + if "_inherit" in raw: + base_path = path.parent / raw.pop("_inherit") + base_config = load_unified_config(base_path) + # Merge: raw overrides base + base_dict = base_config.to_dict() + _deep_merge(base_dict, raw) + raw = base_dict + + config = UnifiedConfig.from_dict(raw) + + # Validate + warnings = config.validate() + for warning in warnings: + logger.warning(f"Config validation: {warning}") + + return config + + +def _deep_merge(base: Dict, override: Dict) -> None: + """Deep merge override into base (in-place).""" + for key, value in override.items(): + if key in base and isinstance(base[key], dict) and isinstance(value, dict): + _deep_merge(base[key], value) + else: + base[key] = value + + +def create_config_template(experiment_type: str = "cluster_analysis") -> UnifiedConfig: + """ + Create a template config for a specific experiment type. 
+ + Args: + experiment_type: Type of experiment + + Returns: + UnifiedConfig with appropriate defaults + """ + config = UnifiedConfig() + config.experiment.type = experiment_type + + if experiment_type == "llm_alignment": + # LLM defaults + config.model.name = "hf_causal_lm" + config.model.model_id = "meta-llama/Llama-3.1-8B" + config.dataset.name = "wikitext" + config.calibration.num_samples = 128 + config.calibration.max_length = 2048 + config.supernode.enabled = True + config.clustering.enabled = False + config.evaluation.perplexity_enabled = True + + elif experiment_type == "cluster_analysis": + # Vision clustering defaults + config.model.name = "resnet18" + config.model.num_classes = 10 + config.dataset.name = "cifar10" + config.calibration.num_samples = 5000 + config.clustering.enabled = True + config.supernode.enabled = False + + return config diff --git a/src/alignment/core/__init__.py b/src/alignment/core/__init__.py index 3a0e5edd..3f5eaabb 100644 --- a/src/alignment/core/__init__.py +++ b/src/alignment/core/__init__.py @@ -3,45 +3,187 @@ This module provides the foundational abstractions, protocols, and registries used throughout the framework. 
+ +Key Components: +- **Protocols**: Interface definitions that components must implement +- **Registry**: Central registration system for discoverable components +- **Base Classes**: Optional abstract base classes for convenience + +Example - Registering a custom metric: + + from alignment.core import register_metric, BaseMetric + + @register_metric("my_metric", category="custom", tags=["experimental"]) + class MyMetric(BaseMetric): + name = "my_metric" + + def compute(self, outputs, **kwargs): + return per_neuron_scores + +Example - Using registered components: + + from alignment.core import get_metric, METRIC_REGISTRY + + # By name + metric = get_metric("rayleigh_quotient") + + # List available + print(METRIC_REGISTRY.list()) + + # Search by tag + print(METRIC_REGISTRY.search(tags=["alignment"])) """ from .base import BaseDataset, BaseExperiment, BaseMetric, BaseModel -from .protocols import AlignmentMetric, DatasetWrapper, Experiment, MetricAggregator + +# Protocols (interface definitions) +from .protocols import ( + AlignmentMetric, + DatasetWrapper, + Experiment, + MetricAggregator, + ResultReporter, + # New protocols for enhanced modularity + Analyzer, + Pruner, + Visualizer, + Evaluator, + Preprocessor, + # Base classes from protocols + BaseMetric as BaseMetricProtocol, + BaseAnalyzer, + BasePruner, + # Config dataclasses + MetricConfig, + AnalyzerConfig, + PrunerConfig, +) from .protocols import ModelWrapper as ModelWrapperProtocol -from .protocols import ResultReporter + +# Registry system from .registry import ( + # Core registry class Registry, - get_dataset, - get_experiment, - get_metric, - get_model, - register_dataset, - register_experiment, + ComponentInfo, + # Global registries + METRIC_REGISTRY, + MODEL_REGISTRY, + DATASET_REGISTRY, + EXPERIMENT_REGISTRY, + AGGREGATOR_REGISTRY, + REPORTER_REGISTRY, + ANALYZER_REGISTRY, + VISUALIZER_REGISTRY, + PRUNER_REGISTRY, + EVALUATOR_REGISTRY, + PREPROCESSOR_REGISTRY, + ALL_REGISTRIES, + # Registration 
decorators register_metric, register_model, + register_dataset, + register_experiment, + register_aggregator, + register_reporter, + register_analyzer, + register_visualizer, + register_pruner, + register_evaluator, + register_preprocessor, + # Getter functions + get_metric, + get_model, + get_dataset, + get_experiment, + get_aggregator, + get_reporter, + get_analyzer, + get_visualizer, + get_pruner, + get_evaluator, + get_preprocessor, + # Unified factory functions + create_component, + create_from_config, + list_all_components, + print_registry_summary, + # Discovery + discover_and_register, + discover_plugins, + initialize_registries, ) __all__ = [ - # Protocols + # Protocols (interfaces) "AlignmentMetric", "ModelWrapperProtocol", "DatasetWrapper", "Experiment", "MetricAggregator", "ResultReporter", + "Analyzer", + "Pruner", + "Visualizer", + "Evaluator", + "Preprocessor", + # Base classes + "BaseMetric", + "BaseMetricProtocol", + "BaseAnalyzer", + "BasePruner", + "BaseModel", + "BaseDataset", + "BaseExperiment", + # Config dataclasses + "MetricConfig", + "AnalyzerConfig", + "PrunerConfig", # Registry "Registry", + "ComponentInfo", + "METRIC_REGISTRY", + "MODEL_REGISTRY", + "DATASET_REGISTRY", + "EXPERIMENT_REGISTRY", + "AGGREGATOR_REGISTRY", + "REPORTER_REGISTRY", + "ANALYZER_REGISTRY", + "VISUALIZER_REGISTRY", + "PRUNER_REGISTRY", + "EVALUATOR_REGISTRY", + "PREPROCESSOR_REGISTRY", + "ALL_REGISTRIES", + # Registration decorators "register_metric", "register_model", "register_dataset", "register_experiment", + "register_aggregator", + "register_reporter", + "register_analyzer", + "register_visualizer", + "register_pruner", + "register_evaluator", + "register_preprocessor", + # Getter functions "get_metric", "get_model", "get_dataset", "get_experiment", - # Base classes - "BaseMetric", - "BaseModel", - "BaseDataset", - "BaseExperiment", + "get_aggregator", + "get_reporter", + "get_analyzer", + "get_visualizer", + "get_pruner", + "get_evaluator", + 
"get_preprocessor", + # Factory functions + "create_component", + "create_from_config", + "list_all_components", + "print_registry_summary", + # Discovery + "discover_and_register", + "discover_plugins", + "initialize_registries", ] diff --git a/src/alignment/core/protocols.py b/src/alignment/core/protocols.py index 5f386c74..ced0bfe1 100644 --- a/src/alignment/core/protocols.py +++ b/src/alignment/core/protocols.py @@ -3,9 +3,42 @@ These protocols define the interfaces that all implementations must follow, ensuring consistency and enabling easy extension of the framework. + +Protocols serve as contracts - any class implementing a protocol can be used +interchangeably, enabling plugin-based extensibility. + +Available Protocols: +- AlignmentMetric: Per-neuron/channel metrics (RQ, MI, redundancy, etc.) +- Analyzer: Analysis pipelines (clustering, halo analysis, etc.) +- Pruner: Pruning strategies +- Visualizer: Visualization components +- Evaluator: Model evaluation (accuracy, perplexity, etc.) 
+- ModelWrapper: Model wrappers for activation extraction +- DatasetWrapper: Dataset abstractions +- Experiment: Full experiment pipelines + +Example - Creating a custom metric: + + from alignment.core.protocols import AlignmentMetric + from alignment.core.registry import register_metric + + @register_metric("my_custom_metric", category="custom", tags=["experimental"]) + class MyCustomMetric: + '''My custom alignment metric.''' + + name = "my_custom_metric" + requires_inputs = False + requires_weights = True + requires_outputs = True + + def compute(self, outputs, weights, **kwargs): + # Your metric computation here + return per_neuron_scores """ -from typing import Any, Dict, List, Optional, Protocol, Tuple, Union +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Protocol, Tuple, Type, Union, runtime_checkable import torch import torch.nn as nn @@ -220,3 +253,470 @@ def report(self, results: Dict[str, Any], output_path: Optional[str] = None, **k def visualize(self, results: Dict[str, Any], plot_type: str, **kwargs: Any) -> Any: """Generate visualizations from results.""" ... + + +# ============================================================================= +# NEW PROTOCOLS FOR ENHANCED MODULARITY +# ============================================================================= + +@runtime_checkable +class Analyzer(Protocol): + """ + Protocol for analysis components (clustering, halo analysis, cross-layer, etc.). + + Analyzers take metric results and produce higher-level insights. + + Example implementations: + - KMeansClustering: Cluster neurons by metric values + - HaloAnalysis: Analyze cross-layer dependencies + - SupernodeDetection: Identify outlier neurons + """ + + @property + def name(self) -> str: + """Analyzer name.""" + ... + + @property + def requires(self) -> List[str]: + """List of required inputs (metric names, activations, etc.).""" + ... 
+ + @property + def provides(self) -> List[str]: + """List of outputs this analyzer produces.""" + ... + + def analyze( + self, + metrics: Dict[str, Any], + model: Optional[nn.Module] = None, + activations: Optional[Dict[str, torch.Tensor]] = None, + **kwargs: Any + ) -> Dict[str, Any]: + """ + Perform analysis on metrics/activations. + + Args: + metrics: Dictionary of metric results (per layer, per neuron) + model: Optional model for weight-based analysis + activations: Optional pre-computed activations + **kwargs: Additional analyzer-specific parameters + + Returns: + Dictionary of analysis results + """ + ... + + def visualize( + self, + results: Dict[str, Any], + output_dir: Optional[str] = None, + **kwargs: Any + ) -> List[str]: + """ + Generate visualizations from analysis results. + + Args: + results: Analysis results from analyze() + output_dir: Directory to save figures + **kwargs: Visualization parameters + + Returns: + List of paths to generated figures + """ + ... + + +@runtime_checkable +class Pruner(Protocol): + """ + Protocol for pruning strategies. + + Pruners compute importance scores and apply pruning to models. + + Example implementations: + - MagnitudePruning: Prune by weight magnitude + - GradientPruning: Prune by gradient-based importance + - AlignmentPruning: Prune by alignment metrics (RQ, MI, etc.) + - ClusterAwarePruning: Prune while respecting cluster structure + """ + + @property + def name(self) -> str: + """Pruner name.""" + ... + + @property + def structured(self) -> bool: + """Whether this is structured (channel/neuron) or unstructured (weight) pruning.""" + ... + + def compute_importance( + self, + model: nn.Module, + layer_name: str, + activations: Optional[torch.Tensor] = None, + gradients: Optional[torch.Tensor] = None, + **kwargs: Any + ) -> torch.Tensor: + """ + Compute importance scores for neurons/channels in a layer. 
+ + Args: + model: The model being pruned + layer_name: Name of the layer to compute importance for + activations: Optional activations for this layer + gradients: Optional gradients for this layer + **kwargs: Additional pruner-specific parameters + + Returns: + Tensor of importance scores [num_neurons] or [num_channels] + """ + ... + + def select_to_prune( + self, + scores: torch.Tensor, + amount: float, + **kwargs: Any + ) -> torch.Tensor: + """ + Select which neurons/channels to prune based on scores. + + Args: + scores: Importance scores from compute_importance() + amount: Fraction or number of neurons to prune + **kwargs: Additional selection parameters + + Returns: + Boolean mask or indices of neurons to prune + """ + ... + + def apply( + self, + model: nn.Module, + prune_mask: Dict[str, torch.Tensor], + **kwargs: Any + ) -> nn.Module: + """ + Apply pruning to the model. + + Args: + model: Model to prune + prune_mask: Dictionary mapping layer names to prune masks + **kwargs: Additional parameters + + Returns: + Pruned model + """ + ... + + +@runtime_checkable +class Visualizer(Protocol): + """ + Protocol for visualization components. + + Visualizers generate plots and figures from various data types. + + Example implementations: + - MetricHistogramVisualizer: Plot metric distributions + - PruningCurveVisualizer: Plot accuracy vs sparsity + - ClusterScatterVisualizer: Plot metric space clusters + """ + + @property + def name(self) -> str: + """Visualizer name.""" + ... + + @property + def plot_types(self) -> List[str]: + """List of plot types this visualizer can generate.""" + ... + + def plot( + self, + data: Any, + plot_type: str, + save_path: Optional[str] = None, + **kwargs: Any + ) -> Any: + """ + Generate a plot. + + Args: + data: Data to visualize (format depends on plot_type) + plot_type: Type of plot to generate + save_path: Optional path to save the figure + **kwargs: Plot-specific parameters (figsize, title, etc.) 
+ + Returns: + Matplotlib figure or other visualization object + """ + ... + + def plot_batch( + self, + data: Dict[str, Any], + output_dir: str, + **kwargs: Any + ) -> List[str]: + """ + Generate multiple plots from a batch of data. + + Args: + data: Dictionary of data to visualize + output_dir: Directory to save figures + **kwargs: Common plot parameters + + Returns: + List of paths to generated figures + """ + ... + + +@runtime_checkable +class Evaluator(Protocol): + """ + Protocol for model evaluation. + + Evaluators compute performance metrics on models. + + Example implementations: + - AccuracyEvaluator: Classification accuracy + - PerplexityEvaluator: Language model perplexity + - BenchmarkEvaluator: Run standard benchmarks (MMLU, etc.) + """ + + @property + def name(self) -> str: + """Evaluator name.""" + ... + + @property + def metrics(self) -> List[str]: + """List of metrics this evaluator computes.""" + ... + + def evaluate( + self, + model: nn.Module, + dataloader: DataLoader, + device: str = "cuda", + **kwargs: Any + ) -> Dict[str, float]: + """ + Evaluate the model. + + Args: + model: Model to evaluate + dataloader: Data to evaluate on + device: Device to run evaluation on + **kwargs: Additional evaluation parameters + + Returns: + Dictionary of metric name -> value + """ + ... + + +@runtime_checkable +class Preprocessor(Protocol): + """ + Protocol for data/activation preprocessing. + + Preprocessors transform raw data or activations before metric computation. + + Example implementations: + - CNNUnfoldPreprocessor: Unfold conv activations for covariance computation + - NormalizePreprocessor: Normalize activations + - PatchPreprocessor: Extract patches from spatial activations + """ + + @property + def name(self) -> str: + """Preprocessor name.""" + ... + + def preprocess( + self, + data: torch.Tensor, + **kwargs: Any + ) -> torch.Tensor: + """ + Preprocess data. 
+ + Args: + data: Input tensor to preprocess + **kwargs: Preprocessing parameters + + Returns: + Preprocessed tensor + """ + ... + + +# ============================================================================= +# BASE CLASSES (Optional abstract implementations) +# ============================================================================= + +class BaseMetric(ABC): + """ + Abstract base class for metrics with common functionality. + + Inherit from this for convenience, or just implement the AlignmentMetric protocol. + """ + + name: str = "base_metric" + requires_inputs: bool = False + requires_weights: bool = False + requires_outputs: bool = True + + @abstractmethod + def compute( + self, + inputs: Optional[torch.Tensor] = None, + weights: Optional[torch.Tensor] = None, + outputs: Optional[torch.Tensor] = None, + **kwargs: Any + ) -> torch.Tensor: + """Compute the metric.""" + pass + + def compute_distributed( + self, + inputs: Optional[torch.Tensor] = None, + weights: Optional[torch.Tensor] = None, + outputs: Optional[torch.Tensor] = None, + world_size: int = 1, + rank: int = 0, + **kwargs: Any, + ) -> torch.Tensor: + """Compute metric with distributed reduction.""" + result = self.compute(inputs, weights, outputs, **kwargs) + if world_size > 1: + import torch.distributed as dist + dist.all_reduce(result, op=dist.ReduceOp.SUM) + result = result / world_size + return result + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(name='{self.name}')" + + +class BaseAnalyzer(ABC): + """Abstract base class for analyzers.""" + + name: str = "base_analyzer" + requires: List[str] = [] + provides: List[str] = [] + + @abstractmethod + def analyze( + self, + metrics: Dict[str, Any], + model: Optional[nn.Module] = None, + activations: Optional[Dict[str, torch.Tensor]] = None, + **kwargs: Any + ) -> Dict[str, Any]: + """Perform analysis.""" + pass + + def visualize( + self, + results: Dict[str, Any], + output_dir: Optional[str] = None, + **kwargs: Any + ) -> 
List[str]: + """Default visualization (override for custom plots).""" + return [] + + +class BasePruner(ABC): + """Abstract base class for pruning strategies.""" + + name: str = "base_pruner" + structured: bool = True + + @abstractmethod + def compute_importance( + self, + model: nn.Module, + layer_name: str, + activations: Optional[torch.Tensor] = None, + gradients: Optional[torch.Tensor] = None, + **kwargs: Any + ) -> torch.Tensor: + """Compute importance scores.""" + pass + + def select_to_prune( + self, + scores: torch.Tensor, + amount: float, + **kwargs: Any + ) -> torch.Tensor: + """Select neurons to prune (default: lowest scores).""" + n_prune = int(len(scores) * amount) if amount < 1 else int(amount) + _, indices = torch.sort(scores) + mask = torch.zeros_like(scores, dtype=torch.bool) + mask[indices[:n_prune]] = True + return mask + + def apply( + self, + model: nn.Module, + prune_mask: Dict[str, torch.Tensor], + **kwargs: Any + ) -> nn.Module: + """Apply pruning masks to model.""" + for name, mask in prune_mask.items(): + layer = dict(model.named_modules()).get(name) + if layer is not None and hasattr(layer, 'weight'): + with torch.no_grad(): + layer.weight.data[mask] = 0 + return model + + +# ============================================================================= +# CONFIGURATION DATACLASSES +# ============================================================================= + +@dataclass +class MetricConfig: + """Configuration for a metric.""" + name: str + enabled: bool = True + params: Dict[str, Any] = None + + def __post_init__(self): + if self.params is None: + self.params = {} + + +@dataclass +class AnalyzerConfig: + """Configuration for an analyzer.""" + name: str + enabled: bool = True + params: Dict[str, Any] = None + + def __post_init__(self): + if self.params is None: + self.params = {} + + +@dataclass +class PrunerConfig: + """Configuration for a pruner.""" + name: str + amount: float = 0.5 + selection: str = "low" # low, high, random + 
params: Dict[str, Any] = None + + def __post_init__(self): + if self.params is None: + self.params = {} diff --git a/src/alignment/core/registry.py b/src/alignment/core/registry.py index b9dc32ba..648f71b5 100644 --- a/src/alignment/core/registry.py +++ b/src/alignment/core/registry.py @@ -2,72 +2,195 @@ Central registry for managing all framework components. This module provides a unified registration system for metrics, models, -datasets, and experiments, making them easily discoverable and instantiable. +datasets, experiments, analyzers, visualizers, and pruning strategies, +making them easily discoverable and instantiable. + +The registry system enables: +1. Plugin-based architecture - add new components without modifying core code +2. Configuration-driven instantiation - create components from config files +3. Auto-discovery - automatically find and register components from packages +4. Metadata tracking - store component capabilities, requirements, and documentation + +Usage: + # Register a new metric + @register_metric("my_custom_metric", category="information", requires_pairs=True) + class MyCustomMetric(BaseMetric): + ... 
+ + # Use the metric + metric = get_metric("my_custom_metric", param1=value1) + + # Or from config + metric = create_from_config({"name": "my_custom_metric", "param1": value1}) """ import logging -from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Set, Type, TypeVar, Union logger = logging.getLogger(__name__) T = TypeVar("T") +@dataclass +class ComponentInfo: + """Metadata about a registered component.""" + name: str + cls: Type[Any] + category: str = "default" + description: str = "" + requires: List[str] = field(default_factory=list) # Dependencies + provides: List[str] = field(default_factory=list) # What it provides + config_schema: Optional[Dict[str, Any]] = None # Expected config params + tags: Set[str] = field(default_factory=set) # Searchable tags + version: str = "1.0.0" + author: str = "" + extra: Dict[str, Any] = field(default_factory=dict) + + class Registry: - """Generic registry for framework components.""" + """ + Generic registry for framework components with enhanced metadata and discovery. + + Features: + - Category-based organization + - Tag-based search + - Dependency tracking + - Config schema validation + - Plugin discovery + """ - def __init__(self, name: str): + def __init__(self, name: str, base_protocol: Optional[Type] = None): """ Initialize a registry. 
Args: name: Name of the registry (e.g., "metrics", "models") + base_protocol: Optional protocol/base class that all registered items should follow """ self.name = name - self._registry: Dict[str, Type[Any]] = {} - self._metadata: Dict[str, Dict[str, Any]] = {} - - def register(self, name: str, cls: Optional[Type[T]] = None, **metadata: Any) -> Union[Callable[[Type[T]], Type[T]], Type[T]]: + self.base_protocol = base_protocol + self._registry: Dict[str, ComponentInfo] = {} + self._by_category: Dict[str, List[str]] = {} + self._by_tag: Dict[str, Set[str]] = {} + self._aliases: Dict[str, str] = {} # alias -> canonical name + + def register( + self, + name: str, + cls: Optional[Type[T]] = None, + category: str = "default", + description: str = "", + requires: Optional[List[str]] = None, + provides: Optional[List[str]] = None, + tags: Optional[List[str]] = None, + aliases: Optional[List[str]] = None, + config_schema: Optional[Dict[str, Any]] = None, + **extra: Any + ) -> Union[Callable[[Type[T]], Type[T]], Type[T]]: """ - Register a class in the registry. + Register a class in the registry with comprehensive metadata. Can be used as a decorator or called directly. 
Args: - name: Name to register the class under + name: Canonical name to register the class under cls: Class to register (if not using as decorator) - **metadata: Additional metadata to store with the registration + category: Category for organization (e.g., "information", "gradient", "magnitude") + description: Human-readable description + requires: List of dependencies (other registered components) + provides: List of capabilities this component provides + tags: Searchable tags for discovery + aliases: Alternative names that map to this component + config_schema: Schema for expected configuration parameters + **extra: Additional metadata Returns: Registered class or decorator function + + Example: + @registry.register( + "rayleigh_quotient", + category="alignment", + description="Measures alignment between weights and activation covariance", + requires=["activations", "weights"], + provides=["per_neuron_score"], + tags=["alignment", "covariance", "rq"], + aliases=["rq", "rayleigh", "alignment_score"] + ) + class RayleighQuotient: + ... 
""" + requires = requires or [] + provides = provides or [] + tags = set(tags) if tags else set() + aliases = aliases or [] def decorator(cls_to_register: Type[T]) -> Type[T]: if name in self._registry: logger.warning(f"Overwriting existing registration '{name}' in {self.name} registry") - self._registry[name] = cls_to_register - self._metadata[name] = metadata + + # Create component info + info = ComponentInfo( + name=name, + cls=cls_to_register, + category=category, + description=description or cls_to_register.__doc__ or "", + requires=requires, + provides=provides, + config_schema=config_schema, + tags=tags, + extra=extra, + ) + + self._registry[name] = info + + # Index by category + if category not in self._by_category: + self._by_category[category] = [] + if name not in self._by_category[category]: + self._by_category[category].append(name) + + # Index by tags + for tag in tags: + if tag not in self._by_tag: + self._by_tag[tag] = set() + self._by_tag[tag].add(name) + + # Register aliases + for alias in aliases: + self._aliases[alias] = name + # Also add aliases to tag index + if alias not in self._by_tag: + self._by_tag[alias] = set() + self._by_tag[alias].add(name) # Add registry info to the class setattr(cls_to_register, "_registry_name", name) setattr(cls_to_register, "_registry", self.name) + setattr(cls_to_register, "_registry_info", info) - logger.debug(f"Registered '{name}' in {self.name} registry") + logger.debug(f"Registered '{name}' in {self.name} registry (category={category})") return cls_to_register if cls is None: - # Used as decorator return decorator else: - # Direct registration return decorator(cls) + + def resolve_name(self, name: str) -> str: + """Resolve an alias to its canonical name.""" + return self._aliases.get(name, name) def get(self, name: str) -> Type[Any]: """ - Get a registered class by name. + Get a registered class by name (supports aliases). 
Args: - name: Name of the registered class + name: Name or alias of the registered class Returns: The registered class @@ -75,25 +198,86 @@ def get(self, name: str) -> Type[Any]: Raises: KeyError: If name is not registered """ - if name not in self._registry: + canonical = self.resolve_name(name) + if canonical not in self._registry: available = list(self._registry.keys()) - raise KeyError(f"'{name}' not found in {self.name} registry. " f"Available: {available}") - return self._registry[name] + raise KeyError(f"'{name}' not found in {self.name} registry. Available: {available}") + return self._registry[canonical].cls + + def get_info(self, name: str) -> ComponentInfo: + """Get full component info for a registered class.""" + canonical = self.resolve_name(name) + if canonical not in self._registry: + raise KeyError(f"'{name}' not found in {self.name} registry") + return self._registry[canonical] def get_metadata(self, name: str) -> Dict[str, Any]: - """Get metadata for a registered class.""" - return self._metadata.get(name, {}) - - def list(self) -> List[str]: - """List all registered names.""" + """Get metadata for a registered class (legacy compatibility).""" + info = self.get_info(name) + return { + "category": info.category, + "description": info.description, + "requires": info.requires, + "provides": info.provides, + "tags": list(info.tags), + **info.extra + } + + def list(self, category: Optional[str] = None) -> List[str]: + """ + List registered names, optionally filtered by category. + + Args: + category: Optional category to filter by + + Returns: + List of registered names + """ + if category: + return self._by_category.get(category, []) return list(self._registry.keys()) + + def list_categories(self) -> List[str]: + """List all categories.""" + return list(self._by_category.keys()) + + def search(self, query: str = "", tags: Optional[List[str]] = None) -> List[str]: + """ + Search for components by name substring or tags. 
+ + Args: + query: Substring to search in names and descriptions + tags: Tags to filter by (AND logic) + + Returns: + List of matching component names + """ + results = set(self._registry.keys()) + + # Filter by query + if query: + query_lower = query.lower() + results = { + name for name, info in self._registry.items() + if query_lower in name.lower() + or query_lower in info.description.lower() + or any(query_lower in tag for tag in info.tags) + } + + # Filter by tags (AND logic) + if tags: + for tag in tags: + tag_matches = self._by_tag.get(tag, set()) + results = results & tag_matches + + return list(results) def create(self, name: str, **kwargs: Any) -> Any: """ Create an instance of a registered class. Args: - name: Name of the registered class + name: Name or alias of the registered class **kwargs: Arguments to pass to the class constructor Returns: @@ -101,17 +285,61 @@ def create(self, name: str, **kwargs: Any) -> Any: """ cls = self.get(name) return cls(**kwargs) + + def create_from_config(self, config: Dict[str, Any]) -> Any: + """ + Create an instance from a configuration dictionary. + + The config should have a 'name' or 'type' key specifying the component, + and other keys are passed as constructor arguments. 
+ + Args: + config: Configuration dictionary + + Returns: + Instance of the registered class + """ + config = dict(config) # Copy to avoid mutation + name = config.pop("name", None) or config.pop("type", None) + if not name: + raise ValueError("Config must have 'name' or 'type' key") + return self.create(name, **config) def __contains__(self, name: str) -> bool: - """Check if a name is registered.""" - return name in self._registry + """Check if a name or alias is registered.""" + canonical = self.resolve_name(name) + return canonical in self._registry def __len__(self) -> int: """Get number of registered items.""" return len(self._registry) + def __iter__(self): + """Iterate over registered names.""" + return iter(self._registry.keys()) + + def summary(self) -> str: + """Get a human-readable summary of the registry.""" + lines = [f"\n=== {self.name.upper()} REGISTRY ==="] + lines.append(f"Total: {len(self)} components in {len(self._by_category)} categories\n") + + for category, names in sorted(self._by_category.items()): + lines.append(f" {category}:") + for name in sorted(names): + info = self._registry[name] + desc = info.description[:50] + "..." if len(info.description) > 50 else info.description + lines.append(f" - {name}: {desc}") + + return "\n".join(lines) + + +# ============================================================================= +# GLOBAL REGISTRIES +# ============================================================================= +# These registries are the central point for all framework components. +# Components register themselves using decorators, and are instantiated +# via configuration or direct API calls. 
-# Create global registries METRIC_REGISTRY = Registry("metrics") MODEL_REGISTRY = Registry("models") DATASET_REGISTRY = Registry("datasets") @@ -119,15 +347,47 @@ def __len__(self) -> int: AGGREGATOR_REGISTRY = Registry("aggregators") REPORTER_REGISTRY = Registry("reporters") +# New registries for enhanced modularity +ANALYZER_REGISTRY = Registry("analyzers") # Analysis pipelines (clustering, halo, etc.) +VISUALIZER_REGISTRY = Registry("visualizers") # Visualization components +PRUNER_REGISTRY = Registry("pruners") # Pruning strategies +EVALUATOR_REGISTRY = Registry("evaluators") # Evaluation methods (accuracy, perplexity, etc.) +PREPROCESSOR_REGISTRY = Registry("preprocessors") # Data/activation preprocessors + +# Collect all registries for easy iteration +ALL_REGISTRIES = { + "metrics": METRIC_REGISTRY, + "models": MODEL_REGISTRY, + "datasets": DATASET_REGISTRY, + "experiments": EXPERIMENT_REGISTRY, + "aggregators": AGGREGATOR_REGISTRY, + "reporters": REPORTER_REGISTRY, + "analyzers": ANALYZER_REGISTRY, + "visualizers": VISUALIZER_REGISTRY, + "pruners": PRUNER_REGISTRY, + "evaluators": EVALUATOR_REGISTRY, + "preprocessors": PREPROCESSOR_REGISTRY, +} + + +# ============================================================================= +# DECORATOR FUNCTIONS FOR REGISTRATION +# ============================================================================= -# Decorator functions for registration def register_metric(name: str, **metadata: Any) -> Callable: - """Register a metric class.""" + """ + Register a metric class. + + Example: + @register_metric("my_metric", category="information", tags=["mi", "entropy"]) + class MyMetric(BaseMetric): + ... 
+ """ return METRIC_REGISTRY.register(name, **metadata) def register_model(name: str, **metadata: Any) -> Callable: - """Register a model class.""" + """Register a model/architecture class.""" return MODEL_REGISTRY.register(name, **metadata) @@ -151,7 +411,49 @@ def register_reporter(name: str, **metadata: Any) -> Callable: return REPORTER_REGISTRY.register(name, **metadata) -# Getter functions +def register_analyzer(name: str, **metadata: Any) -> Callable: + """ + Register an analyzer class (clustering, halo analysis, etc.). + + Example: + @register_analyzer("kmeans_clustering", category="clustering") + class KMeansClustering(BaseAnalyzer): + ... + """ + return ANALYZER_REGISTRY.register(name, **metadata) + + +def register_visualizer(name: str, **metadata: Any) -> Callable: + """Register a visualizer class.""" + return VISUALIZER_REGISTRY.register(name, **metadata) + + +def register_pruner(name: str, **metadata: Any) -> Callable: + """ + Register a pruning strategy. + + Example: + @register_pruner("magnitude", category="baseline", tags=["simple", "weight-based"]) + class MagnitudePruning(BasePruner): + ... 
+ """ + return PRUNER_REGISTRY.register(name, **metadata) + + +def register_evaluator(name: str, **metadata: Any) -> Callable: + """Register an evaluator class.""" + return EVALUATOR_REGISTRY.register(name, **metadata) + + +def register_preprocessor(name: str, **metadata: Any) -> Callable: + """Register a preprocessor class.""" + return PREPROCESSOR_REGISTRY.register(name, **metadata) + + +# ============================================================================= +# GETTER FUNCTIONS +# ============================================================================= + def get_metric(name: str, **kwargs: Any) -> Any: """Get a metric instance by name.""" return METRIC_REGISTRY.create(name, **kwargs) @@ -182,28 +484,240 @@ def get_reporter(name: str, **kwargs: Any) -> Any: return REPORTER_REGISTRY.create(name, **kwargs) -# Auto-discovery function -def discover_and_register(module_path: str, registry_type: str = "all") -> None: +def get_analyzer(name: str, **kwargs: Any) -> Any: + """Get an analyzer instance by name.""" + return ANALYZER_REGISTRY.create(name, **kwargs) + + +def get_visualizer(name: str, **kwargs: Any) -> Any: + """Get a visualizer instance by name.""" + return VISUALIZER_REGISTRY.create(name, **kwargs) + + +def get_pruner(name: str, **kwargs: Any) -> Any: + """Get a pruner instance by name.""" + return PRUNER_REGISTRY.create(name, **kwargs) + + +def get_evaluator(name: str, **kwargs: Any) -> Any: + """Get an evaluator instance by name.""" + return EVALUATOR_REGISTRY.create(name, **kwargs) + + +def get_preprocessor(name: str, **kwargs: Any) -> Any: + """Get a preprocessor instance by name.""" + return PREPROCESSOR_REGISTRY.create(name, **kwargs) + + +# ============================================================================= +# UNIFIED COMPONENT FACTORY +# ============================================================================= + +def create_component( + registry_name: str, + component_name: str, + **kwargs: Any +) -> Any: + """ + Create a 
component from any registry by name. + + Args: + registry_name: Name of the registry ("metrics", "pruners", etc.) + component_name: Name of the component within that registry + **kwargs: Arguments to pass to the component constructor + + Returns: + Instance of the component + + Example: + metric = create_component("metrics", "rayleigh_quotient", relative=True) + pruner = create_component("pruners", "magnitude", amount=0.5) + """ + if registry_name not in ALL_REGISTRIES: + raise KeyError(f"Unknown registry: {registry_name}. Available: {list(ALL_REGISTRIES.keys())}") + return ALL_REGISTRIES[registry_name].create(component_name, **kwargs) + + +def create_from_config(config: Dict[str, Any], registry_name: Optional[str] = None) -> Any: + """ + Create a component from a configuration dictionary. + + The config should have 'registry' and 'name' keys, or just 'name' if registry_name is provided. + + Args: + config: Configuration dictionary with at least 'name' key + registry_name: Optional registry name (if not in config) + + Returns: + Instance of the component + + Example: + config = {"registry": "metrics", "name": "rayleigh_quotient", "relative": True} + metric = create_from_config(config) + + # Or with explicit registry + config = {"name": "magnitude", "amount": 0.5} + pruner = create_from_config(config, registry_name="pruners") + """ + config = dict(config) + reg_name = config.pop("registry", None) or registry_name + if not reg_name: + raise ValueError("Config must have 'registry' key or registry_name must be provided") + + return ALL_REGISTRIES[reg_name].create_from_config(config) + + +def list_all_components() -> Dict[str, List[str]]: + """List all registered components across all registries.""" + return {name: reg.list() for name, reg in ALL_REGISTRIES.items()} + + +def print_registry_summary(registry_name: Optional[str] = None) -> None: + """Print a summary of registered components.""" + if registry_name: + if registry_name in ALL_REGISTRIES: + 
print(ALL_REGISTRIES[registry_name].summary()) + else: + print(f"Unknown registry: {registry_name}") + else: + for name, registry in ALL_REGISTRIES.items(): + if len(registry) > 0: + print(registry.summary()) + + +# ============================================================================= +# AUTO-DISCOVERY +# ============================================================================= + +def discover_and_register(module_path: str, registry_type: str = "all") -> int: """ Auto-discover and register components from a module. + Components are discovered by importing modules, which triggers + the @register_* decorators. + Args: - module_path: Python module path to scan + module_path: Python module path to scan (e.g., "alignment.metrics") registry_type: Type of components to register ("all", "metrics", etc.) + + Returns: + Number of modules imported """ import importlib import pkgutil + count = 0 try: module = importlib.import_module(module_path) # Recursively walk through submodules - for importer, modname, ispkg in pkgutil.walk_packages(path=module.__path__, prefix=module.__name__ + ".", onerror=lambda x: None): + for importer, modname, ispkg in pkgutil.walk_packages( + path=module.__path__, + prefix=module.__name__ + ".", + onerror=lambda x: None + ): try: importlib.import_module(modname) + count += 1 logger.debug(f"Imported module: {modname}") except Exception as e: logger.warning(f"Failed to import {modname}: {e}") except Exception as e: logger.error(f"Failed to discover components from {module_path}: {e}") + + return count + + +def discover_plugins(plugin_dirs: Optional[List[str]] = None) -> int: + """ + Discover and load plugins from specified directories. + + Plugins are Python files or packages that register components using + the @register_* decorators. This allows users to extend the framework + without modifying core code. + + Args: + plugin_dirs: List of directories to search for plugins. 
+ Defaults to ["./plugins", "~/.alignment/plugins"] + + Returns: + Number of plugins loaded + """ + import importlib.util + import sys + + if plugin_dirs is None: + plugin_dirs = [ + "./plugins", + os.path.expanduser("~/.alignment/plugins"), + ] + + count = 0 + for plugin_dir in plugin_dirs: + plugin_path = Path(plugin_dir) + if not plugin_path.exists(): + continue + + logger.info(f"Scanning for plugins in: {plugin_path}") + + # Find all Python files + for py_file in plugin_path.glob("**/*.py"): + if py_file.name.startswith("_"): + continue + + module_name = f"alignment_plugin_{py_file.stem}" + + try: + spec = importlib.util.spec_from_file_location(module_name, py_file) + if spec and spec.loader: + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + count += 1 + logger.info(f"Loaded plugin: {py_file}") + except Exception as e: + logger.warning(f"Failed to load plugin {py_file}: {e}") + + return count + + +# ============================================================================= +# INITIALIZATION +# ============================================================================= + +_initialized = False + +def initialize_registries(discover_builtin: bool = True, discover_plugins_flag: bool = True) -> None: + """ + Initialize all registries by discovering and registering built-in components. + + This should be called once at application startup. 
+ + Args: + discover_builtin: Whether to discover built-in components + discover_plugins_flag: Whether to discover user plugins + """ + global _initialized + if _initialized: + return + + if discover_builtin: + # Discover built-in components from alignment package + builtin_modules = [ + "alignment.metrics", + "alignment.pruning.strategies", + "alignment.analysis", + "alignment.models", + ] + for module in builtin_modules: + try: + discover_and_register(module) + except Exception as e: + logger.debug(f"Could not discover from {module}: {e}") + + if discover_plugins_flag: + discover_plugins() + + _initialized = True + logger.info(f"Registries initialized: {sum(len(r) for r in ALL_REGISTRIES.values())} components") From 34ca490efa45d92a28f626ca4252d1160820fa7a Mon Sep 17 00:00:00 2001 From: Houman Safaai Date: Mon, 8 Dec 2025 16:15:50 -0500 Subject: [PATCH 05/12] refactor config/ update vision configs --- configs/prune_llm/llama2_7b_unified.yaml | 23 ++- configs/prune_llm/llama3_8b_unified.yaml | 35 ++++- configs/prune_llm/mistral_7b_unified.yaml | 23 ++- configs/prune_llm/qwen2_7b_unified.yaml | 24 +++- src/alignment/configs/config_loader.py | 166 +++++++++++++++++++--- 5 files changed, 236 insertions(+), 35 deletions(-) diff --git a/configs/prune_llm/llama2_7b_unified.yaml b/configs/prune_llm/llama2_7b_unified.yaml index cee0950d..91d82814 100644 --- a/configs/prune_llm/llama2_7b_unified.yaml +++ b/configs/prune_llm/llama2_7b_unified.yaml @@ -17,6 +17,8 @@ experiment: seed: 42 device: "cuda" output_dir: "./results/paper/llama2_7b" + num_networks: 1 + save_activations: true # ----------------------------------------------------------------------------- # MODEL @@ -135,15 +137,21 @@ pruning: dependency_aware: true algorithms: - - "random" - - "magnitude" + # Alignment-based - "rayleigh_quotient" - "redundancy" + - "average_redundancy" + # SCAR-based - "scar_loss_proxy" + # Supernode-aware - "supernode_protection_score" - "supernode_connectivity_score" + # Generalized 
- "generalized_importance" - "cross_layer_importance" + # Magnitude baseline + - "magnitude" + # SOTA baselines - "wanda" - "sparsegpt" @@ -152,8 +160,11 @@ pruning: - "magnitude" - "rayleigh_quotient" - "redundancy" + - "average_redundancy" - "scar_loss_proxy" - "scar_taylor" + - "scar_activation_power" + - "scar_curvature" - "supernode_protection_score" - "supernode_connectivity_score" - "generalized_importance" @@ -174,8 +185,10 @@ evaluation: enabled: true accuracy: false loss: true + bits_per_byte: true perplexity_enabled: true + evaluation_num_samples: 100 perplexity_datasets: - name: "wikitext" subset: "wikitext-2-raw-v1" @@ -245,6 +258,12 @@ output: save_figures: true save_checkpoints: false +# ----------------------------------------------------------------------------- +# PERFORMANCE SETTINGS +# ----------------------------------------------------------------------------- +performance: + eval_batches: null # null = use all batches + # ----------------------------------------------------------------------------- # EXTRA # ----------------------------------------------------------------------------- diff --git a/configs/prune_llm/llama3_8b_unified.yaml b/configs/prune_llm/llama3_8b_unified.yaml index b421ae2e..29c8edeb 100644 --- a/configs/prune_llm/llama3_8b_unified.yaml +++ b/configs/prune_llm/llama3_8b_unified.yaml @@ -23,6 +23,8 @@ experiment: seed: 42 device: "cuda" output_dir: "./results/paper/llama3_8b" + num_networks: 1 + save_activations: true # ----------------------------------------------------------------------------- # MODEL @@ -152,30 +154,41 @@ pruning: # Algorithms to test (unified + LLM-specific) algorithms: - # Core (same as vision) - - "random" - - "magnitude" + # Alignment-based (core metrics) - "rayleigh_quotient" - - "redundancy" + - "redundancy" # maps to gaussian_mi_analytic + - "average_redundancy" # info-theoretic redundancy - # LLM-specific + # SCAR-based (gradient-informed) - "scar_loss_proxy" + + # Supernode-aware (novel 
contribution) - "supernode_protection_score" - "supernode_connectivity_score" + + # Generalized (no outlier assumption) - "generalized_importance" + + # Cross-layer importance - "cross_layer_importance" + # Magnitude baseline + - "magnitude" # maps to activation_l2_norm + # SOTA baselines - "wanda" - "sparsegpt" scoring_methods: - "random" - - "magnitude" + - "magnitude" # activation_l2_norm - "rayleigh_quotient" - - "redundancy" + - "redundancy" # gaussian_mi_analytic + - "average_redundancy" - "scar_loss_proxy" - "scar_taylor" + - "scar_activation_power" + - "scar_curvature" - "supernode_protection_score" - "supernode_connectivity_score" - "generalized_importance" @@ -196,9 +209,11 @@ evaluation: enabled: true accuracy: false # Not applicable to LLM loss: true + bits_per_byte: true # Compression efficiency metric # LLM-specific: Perplexity perplexity_enabled: true + evaluation_num_samples: 100 # Samples per benchmark perplexity_datasets: - name: "wikitext" subset: "wikitext-2-raw-v1" @@ -272,6 +287,12 @@ output: save_figures: true save_checkpoints: false +# ----------------------------------------------------------------------------- +# PERFORMANCE SETTINGS +# ----------------------------------------------------------------------------- +performance: + eval_batches: null # null = use all batches + # ----------------------------------------------------------------------------- # EXTRA (LLM-specific settings not in unified schema) # ----------------------------------------------------------------------------- diff --git a/configs/prune_llm/mistral_7b_unified.yaml b/configs/prune_llm/mistral_7b_unified.yaml index 735bbc6a..6eadae48 100644 --- a/configs/prune_llm/mistral_7b_unified.yaml +++ b/configs/prune_llm/mistral_7b_unified.yaml @@ -16,6 +16,8 @@ experiment: seed: 42 device: "cuda" output_dir: "./results/paper/mistral_7b" + num_networks: 1 + save_activations: true # ----------------------------------------------------------------------------- # MODEL @@ -134,15 
+136,21 @@ pruning: dependency_aware: true algorithms: - - "random" - - "magnitude" + # Alignment-based - "rayleigh_quotient" - "redundancy" + - "average_redundancy" + # SCAR-based - "scar_loss_proxy" + # Supernode-aware - "supernode_protection_score" - "supernode_connectivity_score" + # Generalized - "generalized_importance" - "cross_layer_importance" + # Magnitude baseline + - "magnitude" + # SOTA baselines - "wanda" - "sparsegpt" @@ -151,8 +159,11 @@ pruning: - "magnitude" - "rayleigh_quotient" - "redundancy" + - "average_redundancy" - "scar_loss_proxy" - "scar_taylor" + - "scar_activation_power" + - "scar_curvature" - "supernode_protection_score" - "supernode_connectivity_score" - "generalized_importance" @@ -173,8 +184,10 @@ evaluation: enabled: true accuracy: false loss: true + bits_per_byte: true perplexity_enabled: true + evaluation_num_samples: 100 perplexity_datasets: - name: "wikitext" subset: "wikitext-2-raw-v1" @@ -244,6 +257,12 @@ output: save_figures: true save_checkpoints: false +# ----------------------------------------------------------------------------- +# PERFORMANCE SETTINGS +# ----------------------------------------------------------------------------- +performance: + eval_batches: null # null = use all batches + # ----------------------------------------------------------------------------- # EXTRA # ----------------------------------------------------------------------------- diff --git a/configs/prune_llm/qwen2_7b_unified.yaml b/configs/prune_llm/qwen2_7b_unified.yaml index eaa616e5..5d139da6 100644 --- a/configs/prune_llm/qwen2_7b_unified.yaml +++ b/configs/prune_llm/qwen2_7b_unified.yaml @@ -17,6 +17,8 @@ experiment: seed: 42 device: "cuda" output_dir: "./results/paper/qwen2_7b" + num_networks: 1 + save_activations: true # ----------------------------------------------------------------------------- # MODEL @@ -135,15 +137,21 @@ pruning: dependency_aware: true algorithms: - - "random" - - "magnitude" + # Alignment-based - 
"rayleigh_quotient" - "redundancy" + - "average_redundancy" + # SCAR-based - "scar_loss_proxy" + # Supernode-aware - "supernode_protection_score" - "supernode_connectivity_score" + # Generalized - "generalized_importance" - "cross_layer_importance" + # Magnitude baseline + - "magnitude" + # SOTA baselines - "wanda" - "sparsegpt" @@ -152,8 +160,11 @@ pruning: - "magnitude" - "rayleigh_quotient" - "redundancy" + - "average_redundancy" - "scar_loss_proxy" - "scar_taylor" + - "scar_activation_power" + - "scar_curvature" - "supernode_protection_score" - "supernode_connectivity_score" - "generalized_importance" @@ -174,6 +185,9 @@ evaluation: enabled: true accuracy: false loss: true + bits_per_byte: true + + evaluation_num_samples: 100 perplexity_enabled: true perplexity_datasets: @@ -245,6 +259,12 @@ output: save_figures: true save_checkpoints: false +# ----------------------------------------------------------------------------- +# PERFORMANCE SETTINGS +# ----------------------------------------------------------------------------- +performance: + eval_batches: null # null = use all batches + # ----------------------------------------------------------------------------- # EXTRA # ----------------------------------------------------------------------------- diff --git a/src/alignment/configs/config_loader.py b/src/alignment/configs/config_loader.py index 2d3e0350..91664a25 100644 --- a/src/alignment/configs/config_loader.py +++ b/src/alignment/configs/config_loader.py @@ -25,8 +25,10 @@ METRIC_UNIFIED_TO_ORIGINAL = { "rayleigh_quotient": "rayleigh_quotient", "redundancy": "gaussian_mi_analytic", + "average_redundancy": "average_redundancy", # Keep as-is "synergy": "synergy_gaussian_mmi", "magnitude": "activation_l2_norm", + "taylor": "taylor", # Vision taylor importance } # Reverse mapping: original -> unified @@ -94,6 +96,8 @@ def _convert_unified_to_original(unified: Dict[str, Any]) -> Dict[str, Any]: original["results_path"] = exp["output_dir"] if "num_networks" in 
exp: original["num_networks"] = exp["num_networks"] + if "save_activations" in exp: + original["save_activations"] = exp["save_activations"] # ------------------------------------------------------------------------- # MODEL @@ -295,12 +299,26 @@ def _convert_unified_to_original(unified: Dict[str, Any]) -> Dict[str, Any]: "datasets": ev["perplexity_datasets"], } + # bits_per_byte + if ev.get("bits_per_byte"): + original["evaluation"]["bits_per_byte"] = True + + # evaluation_num_samples + if "evaluation_num_samples" in ev: + original["evaluation_num_samples"] = ev["evaluation_num_samples"] + # Benchmarks (LLM) if ev.get("benchmarks_enabled"): if "benchmark_tasks" in ev: original["evaluation"]["benchmarks"] = ev["benchmark_tasks"] original["evaluation"]["batch_size"] = ev.get("benchmark_batch_size", 8) + # ------------------------------------------------------------------------- + # PERFORMANCE + # ------------------------------------------------------------------------- + if "performance" in unified: + original["performance"] = unified["performance"] + # ------------------------------------------------------------------------- # VISUALIZATION # ------------------------------------------------------------------------- @@ -326,32 +344,34 @@ def _convert_unified_to_original(unified: Dict[str, Any]) -> Dict[str, Any]: original["experiment"]["output_dir"] = out["dir"] # ------------------------------------------------------------------------- - # EXTRA - Expand LLM-specific settings from extra block + # EXTRA - Expand LLM-specific settings from extra block to top-level # ------------------------------------------------------------------------- if "extra" in unified: extra = unified["extra"] - # Analysis options (with all detailed settings) + # Analysis options (with all detailed settings) - TOP LEVEL if "analysis" in extra: original["analysis"] = extra["analysis"] - # Supernode robustness + # Supernode robustness - TOP LEVEL if "supernode_robustness" in extra: 
original["supernode_robustness"] = extra["supernode_robustness"] - # Supernode summary + # Supernode summary - TOP LEVEL if "supernode_summary" in extra: original["supernode_summary"] = extra["supernode_summary"] - # Multi-supernode + # Multi-supernode - TOP LEVEL if "multi_supernode" in extra: original["multi_supernode"] = extra["multi_supernode"] - # Cross-layer + # Cross-layer - TOP LEVEL if "cross_layer" in extra: original["cross_layer"] = extra["cross_layer"] + if extra["cross_layer"].get("enabled"): + original["do_connectivity_pruning"] = True - # Generalized importance + # Generalized importance - TOP LEVEL if "generalized_importance" in extra: original["generalized_importance"] = extra["generalized_importance"] if extra["generalized_importance"].get("enabled"): @@ -362,31 +382,133 @@ def _convert_unified_to_original(unified: Dict[str, Any]) -> Dict[str, Any]: if "halo_analysis" not in original: original["halo_analysis"] = {} original["halo_analysis"].update(extra["halo_analysis"]) + if extra["halo_analysis"].get("enabled"): + original["do_halo_analysis"] = True - # Visualization (detailed paper figure settings) + # Visualization (detailed paper figure settings) - MERGE with top-level if "visualization" in extra: if "visualization" not in original: original["visualization"] = {} - original["visualization"].update(extra["visualization"]) + # Merge extra.visualization into top-level visualization + extra_viz = extra["visualization"] + for key, value in extra_viz.items(): + original["visualization"][key] = value - # Top-level flags + # Top-level flags from extra for flag in ["do_scar_metrics", "do_directed_redundancy", "do_connectivity_pruning", "do_halo_analysis", "do_generalized_importance"]: if flag in extra: original[flag] = extra[flag] - # LLM block reconstruction - llm_block = {} - if original.get("do_scar_metrics"): - llm_block["scar_metrics"] = True - if "scar_num_samples" in original: - llm_block["scar_num_samples"] = original["scar_num_samples"] - if 
"scar_max_length" in original: - llm_block["scar_max_length"] = original["scar_max_length"] - if original.get("do_perplexity_computation"): - llm_block["evaluate_perplexity"] = True - if llm_block: - original["llm"] = llm_block + # Pretrain settings (for vision) + if "pretrain_epochs" in extra: + original["pretrain_epochs"] = extra["pretrain_epochs"] + if "pretrain_lr" in extra: + original["pretrain_lr"] = extra["pretrain_lr"] + + # Baselines (for vision) + if "baselines" in extra: + original["baselines"] = extra["baselines"] + + # Sensitivity analysis (for vision) + if "sensitivity_analysis" in extra: + original["sensitivity_analysis"] = extra["sensitivity_analysis"] + + # Structured pruning (for vision) + if "structured_pruning" in extra: + original["structured_pruning"] = extra["structured_pruning"] + + # Feature analysis (for vision) + if "feature_analysis" in extra: + original["feature_analysis"] = extra["feature_analysis"] + + # Efficiency tracking (for vision) + if "efficiency" in extra: + original["efficiency"] = extra["efficiency"] + + # ------------------------------------------------------------------------- + # BUILD LLM BLOCK - Reconstruct full llm: section as original expects + # ------------------------------------------------------------------------- + llm_block = {} + + # SCAR settings + if original.get("do_scar_metrics"): + llm_block["scar_metrics"] = True + if "scar_num_samples" in original: + llm_block["scar_num_samples"] = original["scar_num_samples"] + if "scar_max_length" in original: + llm_block["scar_max_length"] = original["scar_max_length"] + + # Perplexity settings + if original.get("do_perplexity_computation"): + llm_block["evaluate_perplexity"] = True + + # Build evaluation_metrics list from evaluation.benchmarks + evaluation_metrics = [] + if "evaluation" in original: + ev = original["evaluation"] + + # Perplexity metrics + if ev.get("perplexity", {}).get("enabled") or original.get("do_perplexity_computation"): + 
evaluation_metrics.extend(["perplexity", "loss", "bits_per_byte"]) + + # Build benchmark metrics from benchmark tasks + if "benchmarks" in ev: + for benchmark in ev["benchmarks"]: + if isinstance(benchmark, dict): + task_name = benchmark.get("name", "") + # Map task name to evaluation_metrics format + if task_name == "mmlu": + evaluation_metrics.append("accuracy_mmlu") + elif task_name == "hellaswag": + evaluation_metrics.append("accuracy_hellaswag") + elif task_name == "piqa": + evaluation_metrics.append("accuracy_piqa") + elif task_name == "boolq": + evaluation_metrics.append("accuracy_boolq") + elif task_name == "winogrande": + evaluation_metrics.append("accuracy_winogrande") + elif task_name == "arc_easy": + evaluation_metrics.append("accuracy_arc_easy") + elif task_name == "arc_challenge": + evaluation_metrics.append("accuracy_arc_challenge") + elif task_name == "openbookqa": + evaluation_metrics.append("accuracy_openbookqa") + elif task_name == "gsm8k": + evaluation_metrics.append("accuracy_gsm8k") + elif task_name == "truthfulqa": + evaluation_metrics.append("accuracy_truthfulqa") + + # Remove duplicates while preserving order + seen = set() + unique_metrics = [] + for m in evaluation_metrics: + if m not in seen: + seen.add(m) + unique_metrics.append(m) + + if unique_metrics: + llm_block["evaluation_metrics"] = unique_metrics + + if llm_block: + original["llm"] = llm_block + + # ------------------------------------------------------------------------- + # ENSURE TOP-LEVEL FLAGS ARE SET based on section enables + # ------------------------------------------------------------------------- + # Set do_scar_metrics if SCAR section is enabled + if unified.get("metrics", {}).get("scar", {}).get("enabled"): + original["do_scar_metrics"] = True + original["scar_num_samples"] = unified["metrics"]["scar"].get("num_samples", 64) + original["scar_max_length"] = unified["metrics"]["scar"].get("max_length", 512) + + # Set flags based on section enablement + if 
unified.get("supernode", {}).get("enabled"): + if unified["supernode"].get("cross_layer_analysis"): + original["do_connectivity_pruning"] = True + + if unified.get("halo_analysis", {}).get("enabled"): + original["do_halo_analysis"] = True logger.info("Converted unified config to original format") return original From c9a787017567d247540fd98493767fd8b0313401 Mon Sep 17 00:00:00 2001 From: Houman Safaai Date: Fri, 12 Dec 2025 13:35:04 -0500 Subject: [PATCH 06/12] refactor the prunning and score computations --- configs/prune_llm/README.md | 78 ++- configs/prune_llm/llama2_7b_full.yaml | 2 + configs/prune_llm/llama2_7b_unified.yaml | 10 + configs/prune_llm/llama3_8b_full.yaml | 2 + configs/prune_llm/llama3_8b_unified.yaml | 13 + configs/prune_llm/mistral_7b_full.yaml | 2 + configs/prune_llm/mistral_7b_unified.yaml | 10 + configs/prune_llm/qwen2_7b_full.yaml | 2 + configs/prune_llm/qwen2_7b_unified.yaml | 10 + configs/template.yaml | 254 +++++++- configs/unified_template.yaml | 124 +++- .../mobilenetv2_cifar10_full.yaml | 40 +- .../mobilenetv2_cifar10_unified.yaml | 83 ++- .../vision_prune/resnet18_cifar10_full.yaml | 62 +- .../resnet18_cifar10_unified.yaml | 94 ++- .../vision_prune/resnet50_imagenet100.yaml | 41 +- .../resnet50_imagenet100_unified.yaml | 83 ++- configs/vision_prune/vgg16_cifar10_full.yaml | 39 +- .../vision_prune/vgg16_cifar10_unified.yaml | 83 ++- docs/source/api/experiments.rst | 195 ------ scripts/run_experiment.py | 422 +++++++++++-- scripts/verify_pruning.py | 96 +++ slurm_jobs/prune_llm/run_all_paper.sh | 32 +- slurm_jobs/prune_llm/run_llama2_7b.sh | 19 +- slurm_jobs/prune_llm/run_llama3_8b.sh | 20 +- slurm_jobs/prune_llm/run_mistral_7b.sh | 19 +- slurm_jobs/prune_llm/run_qwen2_7b.sh | 19 +- .../run_cluster_analysis_resnet18.sh | 22 +- .../analysis/visualization/__init__.py | 15 + .../analysis/visualization/cluster_plots.py | 212 +++++++ .../analysis/visualization/halo_plots.py | 2 +- .../analysis/visualization/pruning_plots.py | 289 +++++++++ 
.../visualization/unified_visualizer.py | 8 +- src/alignment/configs/config_loader.py | 53 +- src/alignment/core/base.py | 102 +++- src/alignment/data/datasets/__init__.py | 91 +-- .../data/datasets/unified_dataset.py | 345 ----------- src/alignment/experiments/__init__.py | 29 +- src/alignment/experiments/base.py | 9 +- .../experiments/config_components.py | 292 --------- .../experiments/general_alignment.py | 20 +- src/alignment/experiments/llm_experiments.py | 95 ++- .../parallel_pruning_experiment.py | 572 ------------------ src/alignment/experiments/training_utils.py | 161 ----- src/alignment/infrastructure/README.md | 164 ++++- src/alignment/infrastructure/__init__.py | 33 +- .../infrastructure/configuration/__init__.py | 16 +- .../infrastructure/storage/__init__.py | 15 + .../infrastructure/storage/job_directory.py | 357 +++++++++++ .../infrastructure/storage/logging.py | 1 + src/alignment/metrics/__init__.py | 207 +++++++ src/alignment/metrics/information/__init__.py | 4 +- .../metrics/information/gaussian_mi.py | 87 +++ .../metrics/information/synergy_mmi.py | 13 +- .../metrics/rayleigh/rayleigh_quotient.py | 355 +++++++++-- .../metrics/spectral/spectral_alignment.py | 48 ++ .../metrics/spectral/spectral_classic.py | 48 ++ src/alignment/models/README.md | 137 +++++ src/alignment/pruning/README.md | 150 ++++- src/alignment/pruning/__init__.py | 24 + src/alignment/pruning/baselines.py | 369 ----------- src/alignment/pruning/dependency_aware.py | 24 +- src/alignment/pruning/experiments/__init__.py | 19 - .../pruning/experiments/cascading_layer.py | 479 --------------- .../pruning/experiments/eigenvector_based.py | 417 ------------- .../pruning/experiments/global_pruning.py | 392 ------------ .../pruning/experiments/layer_wise.py | 410 ------------- .../pruning/experiments/progressive.py | 360 ----------- src/alignment/pruning/orchestrator.py | 321 ---------- src/alignment/pruning/parallel_optimizer.py | 306 ---------- src/alignment/pruning/pipeline.py | 140 
+++++ src/alignment/pruning/strategies/__init__.py | 11 + src/alignment/pruning/strategies/adaptive.py | 256 +++++++- .../pruning/strategies/eigenvector.py | 288 +++++++++ src/alignment/pruning/strategies/ultimate.py | 281 --------- .../pruning/strategies/ultra_fast.py | 380 ------------ 76 files changed, 4569 insertions(+), 5714 deletions(-) create mode 100755 scripts/verify_pruning.py delete mode 100644 src/alignment/data/datasets/unified_dataset.py delete mode 100644 src/alignment/experiments/config_components.py delete mode 100644 src/alignment/experiments/parallel_pruning_experiment.py delete mode 100644 src/alignment/experiments/training_utils.py create mode 100644 src/alignment/infrastructure/storage/job_directory.py create mode 100644 src/alignment/models/README.md delete mode 100644 src/alignment/pruning/baselines.py delete mode 100644 src/alignment/pruning/experiments/__init__.py delete mode 100644 src/alignment/pruning/experiments/cascading_layer.py delete mode 100644 src/alignment/pruning/experiments/eigenvector_based.py delete mode 100644 src/alignment/pruning/experiments/global_pruning.py delete mode 100644 src/alignment/pruning/experiments/layer_wise.py delete mode 100644 src/alignment/pruning/experiments/progressive.py delete mode 100644 src/alignment/pruning/orchestrator.py delete mode 100644 src/alignment/pruning/parallel_optimizer.py create mode 100644 src/alignment/pruning/pipeline.py create mode 100644 src/alignment/pruning/strategies/eigenvector.py delete mode 100644 src/alignment/pruning/strategies/ultimate.py delete mode 100644 src/alignment/pruning/strategies/ultra_fast.py diff --git a/configs/prune_llm/README.md b/configs/prune_llm/README.md index 9a27f922..b8f85fc7 100644 --- a/configs/prune_llm/README.md +++ b/configs/prune_llm/README.md @@ -6,23 +6,58 @@ Configurations for generating results in the SCAR LLM pruning paper. 
| Config | Model | Layers | FFN Width | Runtime | |--------|-------|--------|-----------|---------| -| `llama3_8b_full.yaml` | LLaMA-3.1-8B | 32 | 14336 | 6-8h | -| `mistral_7b_full.yaml` | Mistral-7B | 32 | 14336 | 4-6h | -| `llama2_7b_full.yaml` | LLaMA-2-7B | 32 | 11008 | 4-6h | -| `qwen2_7b_full.yaml` | Qwen2-7B | 28 | 18944 | 4-6h | +| `llama3_8b_unified.yaml` | LLaMA-3.1-8B | 32 | 14336 | 6-8h | +| `mistral_7b_unified.yaml` | Mistral-7B | 32 | 14336 | 4-6h | +| `llama2_7b_unified.yaml` | LLaMA-2-7B | 32 | 11008 | 4-6h | +| `qwen2_7b_unified.yaml` | Qwen2-7B | 28 | 18944 | 4-6h | ## Quick Start Run all experiments: ```bash -sbatch slurm_jobs/paper/run_all_paper.sh +bash slurm_jobs/prune_llm/run_all_paper.sh ``` Run single model: ```bash -python scripts/run_experiment.py --config configs/paper/llama3_8b_full.yaml +python scripts/run_experiment.py --config configs/prune_llm/llama3_8b_unified.yaml ``` +Override base output directory: +```bash +python scripts/run_experiment.py \ + --config configs/prune_llm/llama3_8b_unified.yaml \ + --base-output-dir /path/to/your/output/dir +``` + +## Output Directory Structure + +Each job creates a unique directory based on timestamp and SLURM job ID: + +``` +/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/ +├── llama3_8b_paper_results_20241209_143052_12345678/ +│ ├── results/ # JSON results files +│ │ ├── results_20241209_143052.json +│ │ └── pruning_results.json +│ ├── logs/ # Experiment logs +│ │ └── experiment.log +│ ├── figures/ # All visualizations +│ │ ├── fig1_supernode_distribution.pdf +│ │ ├── fig2_halo_redundancy.pdf +│ │ └── fig3_pruning_curves.pdf +│ ├── checkpoints/ # Model checkpoints (if enabled) +│ ├── analysis/ # Post-analysis outputs +│ └── experiment_config.yaml +├── llama2_7b_paper_results_20241209_143100_12345679/ +│ └── ... 
+``` + +**Directory naming convention:** +- `{experiment_name}_{timestamp}_{job_id}` +- For SLURM jobs: `job_id` = `$SLURM_JOB_ID` +- For local runs: `job_id` = unique 8-character ID + ## Pruning Methods | Category | Methods | @@ -52,23 +87,24 @@ python scripts/run_experiment.py --config configs/paper/llama3_8b_full.yaml **Few-shot**: HellaSwag (5-shot), PIQA (5-shot), ARC-Challenge (5-shot), MMLU (5-shot) -## Output Structure +## Configuration Options + +### Base Output Directory +The `output.base_dir` setting controls where job directories are created: + +```yaml +output: + # Creates: {base_dir}/{experiment_name}_{timestamp}_{job_id}/ + base_dir: "/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM" + + # Fallback if base_dir is not set (legacy) + dir: "./results/paper/llama3_8b" ``` -results/paper// -├── metrics/ -│ ├── layer_metrics.json -│ ├── supernode_analysis.json -│ └── halo_redundancy.json -├── evaluation/ -│ ├── perplexity_results.json -│ └── benchmark_results.json -├── pruning/ -│ └── sparsity_curves.json -└── figures/ - ├── fig1_supernode_distribution.pdf - ├── fig2_halo_redundancy.pdf - └── fig3_pruning_curves.pdf + +Can be overridden via CLI: +```bash +python scripts/run_experiment.py --config ... 
--base-output-dir /new/path ``` ## Features diff --git a/configs/prune_llm/llama2_7b_full.yaml b/configs/prune_llm/llama2_7b_full.yaml index 3bcaa0b7..a8282c59 100644 --- a/configs/prune_llm/llama2_7b_full.yaml +++ b/configs/prune_llm/llama2_7b_full.yaml @@ -221,6 +221,8 @@ pruning: enabled: true target: "ffn" distribution: "uniform" + min_per_layer: 0.0 + max_per_layer: 0.95 structured: true dependency_aware: true diff --git a/configs/prune_llm/llama2_7b_unified.yaml b/configs/prune_llm/llama2_7b_unified.yaml index 91d82814..4c14434e 100644 --- a/configs/prune_llm/llama2_7b_unified.yaml +++ b/configs/prune_llm/llama2_7b_unified.yaml @@ -59,6 +59,13 @@ calibration: # Note: supernode_protection_score, supernode_connectivity_score are computed # by the supernode analysis pipeline, not as standalone metrics metrics: + # Optimization options for faster metric computation + optimization: + use_jit: false # Enable JIT-compiled computations (20-50% faster) + use_gpu_acceleration: false # Enable GPU-accelerated functions + force_cpu_for_large_ops: true # Prevent OOM for large covariance matrices + cpu_threshold: 100000000 # 1e8 elements threshold + rayleigh_quotient: enabled: true relative: true @@ -133,6 +140,8 @@ pruning: ratios: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] selection_modes: ["low", "high"] distribution: "uniform" + min_per_layer: 0.0 + max_per_layer: 0.95 structured: true dependency_aware: true @@ -253,6 +262,7 @@ visualization: # OUTPUT # ----------------------------------------------------------------------------- output: + base_dir: "/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM" dir: "./results/paper/llama2_7b" save_metrics: true save_figures: true diff --git a/configs/prune_llm/llama3_8b_full.yaml b/configs/prune_llm/llama3_8b_full.yaml index b72902a5..b68b47b3 100644 --- a/configs/prune_llm/llama3_8b_full.yaml +++ b/configs/prune_llm/llama3_8b_full.yaml @@ -272,6 +272,8 @@ pruning: enabled: true target: "ffn" distribution: "uniform" + 
min_per_layer: 0.0 + max_per_layer: 0.95 structured: true dependency_aware: true diff --git a/configs/prune_llm/llama3_8b_unified.yaml b/configs/prune_llm/llama3_8b_unified.yaml index 29c8edeb..da1f91a6 100644 --- a/configs/prune_llm/llama3_8b_unified.yaml +++ b/configs/prune_llm/llama3_8b_unified.yaml @@ -71,6 +71,13 @@ calibration: # by the supernode analysis pipeline, not as standalone metrics # ----------------------------------------------------------------------------- metrics: + # Optimization options for faster metric computation + optimization: + use_jit: false # Enable JIT-compiled computations (20-50% faster) + use_gpu_acceleration: false # Enable GPU-accelerated functions + force_cpu_for_large_ops: true # Prevent OOM for large covariance matrices + cpu_threshold: 100000000 # 1e8 elements threshold + rayleigh_quotient: enabled: true relative: true @@ -149,6 +156,8 @@ pruning: ratios: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] selection_modes: ["low", "high"] distribution: "uniform" + min_per_layer: 0.0 + max_per_layer: 0.95 structured: true dependency_aware: true @@ -281,7 +290,11 @@ visualization: # ----------------------------------------------------------------------------- # OUTPUT # ----------------------------------------------------------------------------- +# Uses job directory structure: creates unique folders for each run +# Directory format: {base_dir}/{experiment_name}_{timestamp}_{job_id}/ output: + base_dir: "/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM" + # dir is ignored when base_dir is set dir: "./results/paper/llama3_8b" save_metrics: true save_figures: true diff --git a/configs/prune_llm/mistral_7b_full.yaml b/configs/prune_llm/mistral_7b_full.yaml index 6777903e..32f75c5e 100644 --- a/configs/prune_llm/mistral_7b_full.yaml +++ b/configs/prune_llm/mistral_7b_full.yaml @@ -220,6 +220,8 @@ pruning: enabled: true target: "ffn" distribution: "uniform" + min_per_layer: 0.0 + max_per_layer: 0.95 structured: true dependency_aware: 
true diff --git a/configs/prune_llm/mistral_7b_unified.yaml b/configs/prune_llm/mistral_7b_unified.yaml index 6eadae48..9e87ee48 100644 --- a/configs/prune_llm/mistral_7b_unified.yaml +++ b/configs/prune_llm/mistral_7b_unified.yaml @@ -58,6 +58,13 @@ calibration: # Note: supernode_protection_score, supernode_connectivity_score are computed # by the supernode analysis pipeline, not as standalone metrics metrics: + # Optimization options for faster metric computation + optimization: + use_jit: false # Enable JIT-compiled computations (20-50% faster) + use_gpu_acceleration: false # Enable GPU-accelerated functions + force_cpu_for_large_ops: true # Prevent OOM for large covariance matrices + cpu_threshold: 100000000 # 1e8 elements threshold + rayleigh_quotient: enabled: true relative: true @@ -132,6 +139,8 @@ pruning: ratios: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] selection_modes: ["low", "high"] distribution: "uniform" + min_per_layer: 0.0 + max_per_layer: 0.95 structured: true dependency_aware: true @@ -252,6 +261,7 @@ visualization: # OUTPUT # ----------------------------------------------------------------------------- output: + base_dir: "/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM" dir: "./results/paper/mistral_7b" save_metrics: true save_figures: true diff --git a/configs/prune_llm/qwen2_7b_full.yaml b/configs/prune_llm/qwen2_7b_full.yaml index 899ddd87..7760b121 100644 --- a/configs/prune_llm/qwen2_7b_full.yaml +++ b/configs/prune_llm/qwen2_7b_full.yaml @@ -221,6 +221,8 @@ pruning: enabled: true target: "ffn" distribution: "uniform" + min_per_layer: 0.0 + max_per_layer: 0.95 structured: true dependency_aware: true diff --git a/configs/prune_llm/qwen2_7b_unified.yaml b/configs/prune_llm/qwen2_7b_unified.yaml index 5d139da6..e0a3b762 100644 --- a/configs/prune_llm/qwen2_7b_unified.yaml +++ b/configs/prune_llm/qwen2_7b_unified.yaml @@ -59,6 +59,13 @@ calibration: # Note: supernode_protection_score, supernode_connectivity_score are computed # by the 
supernode analysis pipeline, not as standalone metrics metrics: + # Optimization options for faster metric computation + optimization: + use_jit: false # Enable JIT-compiled computations (20-50% faster) + use_gpu_acceleration: false # Enable GPU-accelerated functions + force_cpu_for_large_ops: true # Prevent OOM for large covariance matrices + cpu_threshold: 100000000 # 1e8 elements threshold + rayleigh_quotient: enabled: true relative: true @@ -133,6 +140,8 @@ pruning: ratios: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] selection_modes: ["low", "high"] distribution: "uniform" + min_per_layer: 0.0 + max_per_layer: 0.95 structured: true dependency_aware: true @@ -254,6 +263,7 @@ visualization: # OUTPUT # ----------------------------------------------------------------------------- output: + base_dir: "/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM" dir: "./results/paper/qwen2_7b" save_metrics: true save_figures: true diff --git a/configs/template.yaml b/configs/template.yaml index f8a7057e..ae536b48 100644 --- a/configs/template.yaml +++ b/configs/template.yaml @@ -50,10 +50,47 @@ dataset: # ----------------------------------------------------------------------------- # METRICS # ----------------------------------------------------------------------------- -# Metrics to compute for alignment analysis -# Available: rayleigh_quotient, activation_l2_norm, activation_outlier_index, -# pairwise_redundancy_gaussian, synergy_gaussian_mmi, -# mutual_information_gaussian, weight_cosine_similarity +# Metrics to compute for alignment analysis and pruning. 
+# +# AVAILABLE METRICS (by category): +# +# ALIGNMENT (Rayleigh Quotient based): +# rayleigh_quotient (rq) - Alignment with input covariance +# rq_fast (rq_gap) - Fast RQ using GAP (10-100x faster for CNNs) +# rq_spatial - Spatial RQ treating locations as samples +# conditional_rayleigh_quotient - Class-conditioned RQ +# delta_rq - RQ_uncond - RQ_cond (class sensitivity) +# +# MUTUAL INFORMATION: +# gaussian_mi_analytic - Gaussian MI with Edgeworth corrections +# mi_about_class - MI between activations and class labels +# conditional_mi - Class-conditioned MI +# +# REDUNDANCY: +# pairwise_redundancy_gaussian - Pairwise Gaussian redundancy +# average_redundancy - Average redundancy per neuron +# halo_redundancy - Halo vs non-halo redundancy analysis +# cross_layer_redundancy - Downstream importance from next layer +# +# SYNERGY: +# synergy_gaussian_mmi - MMI-based synergy +# gaussian_pid_synergy_mmi - PID synergy +# +# ACTIVATION: +# activation_l2_norm - L2 norm of activations (NeMo/TensorRT style) +# activation_mean - Mean absolute activation +# activation_variance - Variance of activations +# activation_outlier_index - Outlier/supernode detection +# +# COMPOSITE: +# composite_importance - Weighted combination of metrics +# alignment_minus_redundancy - RQ - α*Redundancy +# cross_layer_importance - RQ + Downstream - Redundancy (SCAR-aligned) +# +# SPECTRAL: +# spectral_gap - Spectral gap of weight matrix +# eigenvalue_alignment - Wasserstein distance of eigenvalues +# metrics: enabled: - "rayleigh_quotient" @@ -63,25 +100,130 @@ metrics: # Number of data samples for metric computation num_samples: 128 - # Metric-specific settings + # ========================================================================= + # OPTIMIZATION OPTIONS + # ========================================================================= + # These options enable optimized implementations for faster computation. 
+ # + # use_jit: Use JIT-compiled functions (20-50% faster) + # - Available for: rayleigh_quotient, mutual_information, node_correlation, + # cosine_similarity, spectral_norm, eigenvalue_entropy + # - Requires: PyTorch JIT compilation support + # + # use_gpu_acceleration: Use GPU-accelerated functions + # - Available for: histogram (1D/2D), mutual_information, entropy, + # covariance, correlation + # - Requires: CUDA-enabled GPU + # - Best for: Large batch sizes, histogram-based metrics + # + # force_cpu_for_large_ops: Move large tensor ops to CPU (default: true) + # - Prevents OOM for covariance matrices of high-dim layers + # - cpu_threshold: Element count threshold (default: 1e8 = 100M) + # + optimization: + use_jit: false # Enable JIT-compiled metric computations + use_gpu_acceleration: false # Enable GPU-accelerated functions + force_cpu_for_large_ops: true # Move large ops to CPU to prevent OOM + cpu_threshold: 100000000 # 1e8 elements threshold for CPU fallback + + # ========================================================================= + # ALIGNMENT METRICS + # ========================================================================= rayleigh_quotient: + relative: true # Normalize by trace(Σ) for relative alignment + min_samples: 2 # Minimum samples for covariance estimation + scale_by_norm: false # Scale covariance by Frobenius norm + regularization: 1.0e-6 # Diagonal regularization for stability + + # Fast RQ for CNNs using Global Average Pooling + # 10-100x faster, good approximation for early conv layers + rq_fast: + relative: true + regularization: 1.0e-6 + + # Conditional RQ (class-conditioned) + conditional_rayleigh_quotient: relative: true + min_samples: 2 # Minimum samples per class regularization: 1.0e-6 + return_delta: false # If true, return ΔRQ = RQ_uncond - RQ_cond + # ========================================================================= + # MUTUAL INFORMATION METRICS + # 
========================================================================= + gaussian_mi_analytic: + expansion_order: 2 # Edgeworth order (0=Gaussian, 1-3 for corrections) + noise_std: 0.1 # Assumed noise standard deviation + regularization: 1.0e-6 # Covariance regularization + per_neuron: true # Per-neuron MI (true) or joint MI (false) + + mi_about_class: + method: "gaussian" # "gaussian" (fast) or "binning" (discrete) + bins: 10 # Number of bins for binning method + min_samples_per_class: 5 # Minimum samples per class + + # ========================================================================= + # REDUNDANCY METRICS + # ========================================================================= pairwise_redundancy_gaussian: - mode: "output_based" # "output_based" (fast) or "covariance_based" - num_pairs: 10 + mode: "output_based" # "output_based" (fast) or "covariance_based" + num_pairs: 10 # Partners to sample per neuron + sampling_strategy: "random" # "random", "nearest", or "all" + regularization: 1.0e-6 # For covariance_based mode + + average_redundancy: + min_samples: 2 # Minimum samples for computation + use_correlation: true # Use correlation (true) or covariance (false) + + halo_redundancy: + supernode_fraction: 0.01 # Top fraction as supernodes + halo_fraction: 0.10 # Fraction of non-supernodes as halo + max_samples: 1000 # Max activation samples + max_pairs_per_group: 500 # Max pairs per group + + cross_layer_redundancy: + max_refs: 512 # Max reference neurons for efficiency + compute_within: true # Also compute within-layer redundancy + # ========================================================================= + # SYNERGY METRICS + # ========================================================================= synergy_gaussian_mmi: - num_pairs: 10 + num_pairs: 10 # Pairs to sample for synergy + regularization: 1.0e-6 + + # ========================================================================= + # ACTIVATION METRICS + # 
========================================================================= + activation_l2_norm: + aggregate_method: "l2" # "l2", "mean", "max" + use_absolute: true # Take absolute value before aggregation + + activation_outlier_index: + quantile: 0.999 # High percentile for outlier detection + eps: 1.0e-6 # Numerical stability - # Composite score weights (for combining metrics into single importance score) - # Used when pruning.scoring: "composite" or supernode.score_metric: "composite" + # ========================================================================= + # COMPOSITE METRICS + # ========================================================================= + # Weights for combining metrics (negative = penalty) composite_weights: activation_l2_norm: 0.2 rayleigh_quotient: 0.3 pairwise_redundancy_gaussian: -0.2 # Negative = penalize redundancy synergy_gaussian_mmi: 0.2 + + composite_importance: + normalize_components: true # Normalize each component to [0, 1] + log_transform_rq: true # Apply log transform to RQ + + cross_layer_importance: + rq_weight: 0.25 # Weight for RQ (α) + downstream_weight: 0.35 # Weight for downstream importance (β) + within_redundancy_weight: 0.25 # Penalty for redundancy (γ) + activation_weight: 0.15 # Weight for activation magnitude + normalize: true # Normalize components + max_refs: 512 # Max reference neurons # ----------------------------------------------------------------------------- # TRAINING (optional) @@ -131,6 +273,49 @@ pruning: # "cascading" - sequential layer-by-layer pruning distribution: "uniform" + # Per-layer sparsity constraints + min_per_layer: 0.0 # Minimum sparsity per layer + max_per_layer: 0.95 # Maximum sparsity per layer (avoid removing everything) + + # Sensitivity configuration (used when distribution="adaptive_sensitivity") + # Measures how important each layer is to determine per-layer pruning amounts + sensitivity: + # Method options: + # "perturbation" - Add Gaussian noise, measure accuracy drop (slow, 
most accurate) + # "masking" - Random mask 30% weights, measure accuracy drop (slow) + # "activation_variance" - Use activation variance (FAST, single forward pass) [RECOMMENDED] + # "gradient" - Use gradient magnitude (FAST, single backward pass) + # "fisher" - Fisher information approximation (moderate speed) + # "weight_magnitude" - Use weight magnitude (FASTEST, no forward pass) + method: "activation_variance" + perturbation_scale: 0.1 # For "perturbation" method + num_trials: 3 # For "perturbation" and "masking" methods + + # Pruning strategies available via get_pruning_strategy(): + # BASELINES: + # "random" - Random pruning baseline + # "magnitude" - Weight magnitude pruning + # ALIGNMENT-BASED: + # "alignment" - Generic alignment (specify metric) + # "rayleigh_quotient" - RQ alignment score + # "hybrid" - Combine magnitude + alignment + # "global_alignment" - Global alignment-based + # "cascading_alignment" - Sequential layer pruning with score recomputation + # EIGENVECTOR/PCA: + # "eigenvector" - PCA-based pruning (prune low-variance neurons) + # MOVEMENT (Sanh et al. NeurIPS 2020): + # "movement" - Prune weights moving toward zero during training + # "adaptive_movement" - Auto-tuned movement pruning + # SENSITIVITY-ADAPTIVE: + # "adaptive_sensitivity" - Per-layer adaptive amounts based on sensitivity + # GRADIENT-BASED: + # "gradient" - Gradient magnitude pruning + # "fisher" - Fisher information pruning + # "momentum" - Momentum-based pruning + # LLM-SPECIFIC: + # "wanda" - Sun et al. 2023 + # "sparsegpt" - Frantar & Alistarh 2023 + # Structured pruning: remove entire neurons/channels (vs individual weights) # - true = prune whole neurons (hardware efficient, required for alignment metrics) # - false = prune individual weights (sparse matrices, less practical) @@ -164,14 +349,46 @@ llm: # ----------------------------------------------------------------------------- # How to preprocess convolutional layer activations for metric computation. 
# This affects how spatial dimensions (H, W) are handled when computing RQ, MI, etc. +# +# IMPORTANT: The choice of mode affects accuracy vs speed tradeoff significantly! +# +# COMPARISON TABLE: +# ┌─────────────────┬────────────┬─────────────────┬──────────────────────────────┐ +# │ Mode │ Speed │ Memory │ Best For │ +# ├─────────────────┼────────────┼─────────────────┼──────────────────────────────┤ +# │ unfold │ Slow │ O(B·P·C·K²) │ Accurate RQ/MI, later layers │ +# │ patchwise │ Moderate │ O(B·C·K²·P) │ Patch-level analysis │ +# │ spatial │ Fast │ O(B·H·W·C) │ Covariance metrics │ +# │ gap │ Fastest │ O(B·C) │ Quick experiments, early CNN │ +# │ channel_variance│ Fast │ O(C) │ Activation magnitude only │ +# └─────────────────┴────────────┴─────────────────┴──────────────────────────────┘ +# +# Where: B=batch, C=channels, H/W=spatial dims, K=kernel size, P=num patches +# cnn: # Preprocessing mode for CNN activations: # "unfold" (recommended): Unfolds spatial dims using kernel params, creates # [batch*patches, C*K*K] tensor. Best for RQ/covariance. + # Most accurate but slowest and highest memory. + # # "patchwise": Keeps patches separate [batch, features, patches]. Good for - # patch-level analysis. - # "channel_variance": Uses channel-wise variance summary. Faster but less accurate. + # patch-level analysis with weighted averaging. + # + # "spatial": Reshapes [B, C, H, W] -> [B*H*W, C]. Treats each spatial location + # as a sample. Fast and preserves spatial variance. RECOMMENDED for + # large feature maps where unfold is too slow. + # + # "gap": Global Average Pooling [B, C, H, W] -> [B, C]. Fastest but loses + # spatial information. Use for quick experiments or early layers. + # Alternatively, use rq_fast metric which internally does this. + # + # "channel_variance": Uses per-channel variance summary. For activation-based + # metrics only (not covariance-based like RQ). 
mode: "unfold" + + # Additional CNN options: + max_patches: 64 # Max patches to use (subsample if more) + weight_by_variance: true # Weight patches by their variance (for patchwise) # ----------------------------------------------------------------------------- # PERFORMANCE @@ -208,5 +425,18 @@ visualization: # ----------------------------------------------------------------------------- # OUTPUT PATHS # ----------------------------------------------------------------------------- +# Output directory structure: +# - If base_output_dir is set: Creates job-specific directories with structure: +# {base_output_dir}/{experiment_name}_{timestamp}_{job_id}/ +# results/, logs/, checkpoints/, figures/, analysis/ +# - If not set: Uses log_dir/checkpoint_dir directly (old behavior) +# +# For SLURM jobs, the job_id is automatically taken from SLURM_JOB_ID +# For local runs, a unique 8-character ID is generated +# +# Recommended for cluster runs: +# base_output_dir: "/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM" +base_output_dir: null + log_dir: "./logs" checkpoint_dir: "./checkpoints" diff --git a/configs/unified_template.yaml b/configs/unified_template.yaml index b7116566..a349b088 100644 --- a/configs/unified_template.yaml +++ b/configs/unified_template.yaml @@ -57,7 +57,48 @@ calibration: # METRICS (unified naming with aliases) # ----------------------------------------------------------------------------- metrics: + # ----------------------------------------------------------------------------- + # OPTIMIZATION OPTIONS + # ----------------------------------------------------------------------------- + # Enable optimized implementations for faster metric computation. 
+ # + # use_jit: Use JIT-compiled functions (20-50% faster) + # - Available for: rayleigh_quotient, mutual_information, node_correlation, + # cosine_similarity, spectral_norm, eigenvalue_entropy + # + # use_gpu_acceleration: Use GPU-accelerated functions + # - Available for: histogram (1D/2D), mutual_information, entropy, + # covariance, correlation + # - Best for: Large batch sizes, histogram-based metrics + # + optimization: + use_jit: false # Enable JIT-compiled metric computations + use_gpu_acceleration: false # Enable GPU-accelerated functions + force_cpu_for_large_ops: true # Move large ops to CPU to prevent OOM + cpu_threshold: 100000000 # 1e8 elements threshold for CPU fallback + + # ----------------------------------------------------------------------------- + # CNN preprocessing mode (for convolutional layers only) + # ----------------------------------------------------------------------------- + # How to handle spatial dimensions in CNNs when computing metrics like RQ: + # - "unfold": Unfold patches, exact RQ (slow, high memory for large spatial dims) + # - "patchwise": Same as unfold but keeps patches separate (slow) + # - "gap": Global Average Pooling (FAST, approximate, loses spatial info) + # - "spatial": Treat spatial locations as samples (FAST, preserves variance) + # Recommendation: + # - Use "unfold" for research/accuracy (default) + # - Use "gap" or "spatial" for speed during development/large images + cnn_mode: "unfold" + max_patches: 64 # Subsample patches if more than this (speed optimization) + # Core metrics (available for both vision and LLM) + # ----------------------------------------------------------------------------- + # Rayleigh Quotient variants: + # - "rayleigh_quotient" (default): Standard RQ, uses unfold for CNNs + # - "rq_fast" / "rq_gap": Uses GAP for CNNs (10-100x faster) + # - "rq_spatial": Treats spatial locations as samples + # - "rq_patchwise": Explicit patchwise computation + # 
----------------------------------------------------------------------------- rayleigh_quotient: enabled: true relative: true @@ -67,6 +108,7 @@ metrics: enabled: true sampling: "all" # all, random, top_k num_pairs: 10 # For pairwise computation + mode: "output_based" # output_based (FAST) or covariance_based (SLOW) synergy: # PID-based synergy with target enabled: true @@ -147,10 +189,66 @@ cascade_analysis: pruning: enabled: true - # Sparsity levels + # Distribution strategy: how to allocate pruning across layers + # Options: + # - "uniform": Same percentage per layer (default) + # - "global_threshold": Global score threshold across all layers + # - "adaptive_sensitivity": More pruning on robust layers, less on sensitive (RECOMMENDED) + # - "importance_weighted": Based on average importance scores + distribution: "uniform" + + # Per-layer constraints + min_per_layer: 0.0 # Minimum per-layer pruning fraction (0-1) + max_per_layer: 0.95 # Maximum per-layer pruning fraction (0-1) + + # Dependency-aware pruning (recommended for CNNs with BatchNorm, ResNets, etc.) + dependency_aware: false + + # ----------------------------------------------------------------------------- + # Sensitivity-based adaptive pruning (used when distribution="adaptive_sensitivity") + # ----------------------------------------------------------------------------- + sensitivity: + # Method to measure layer sensitivity. 
Options: + # - "perturbation": Add Gaussian noise, measure accuracy drop (slow, most accurate) + # - "masking": Random mask 30% weights, measure accuracy drop (slow) + # - "activation_variance": Use activation variance (FAST, single forward pass) + # - "gradient": Use gradient magnitude (FAST, single backward pass) + # - "fisher": Fisher information approximation (moderate speed) + # - "weight_magnitude": Use weight magnitude (FASTEST, no forward pass) + method: "activation_variance" + perturbation_scale: 0.1 # For "perturbation" method + num_trials: 3 # For "perturbation" and "masking" methods + + # Sparsity levels to evaluate ratios: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] - # Methods to compare (available for both vision and LLM) + # ----------------------------------------------------------------------------- + # Pruning methods/strategies to compare + # ----------------------------------------------------------------------------- + # Available strategies: + # BASELINES: + # - "random": Random pruning (baseline) + # - "magnitude": Weight magnitude pruning + # ALIGNMENT-BASED: + # - "rayleigh_quotient": RQ alignment score + # - "alignment": Generic alignment pruning (specify metric) + # - "hybrid": Combine magnitude + alignment + # - "global_alignment": Global alignment-based + # - "cascading_alignment": Sequential layer pruning with score recomputation + # EIGENVECTOR/PCA: + # - "eigenvector": PCA-based pruning (prune low-variance neurons) + # MOVEMENT (Sanh et al. NeurIPS 2020): + # - "movement": Prune weights moving toward zero during training + # - "adaptive_movement": Auto-tuned movement pruning + # GRADIENT-BASED: + # - "gradient": Gradient magnitude pruning + # - "fisher": Fisher information pruning + # - "momentum": Momentum-based pruning + # CLUSTER-AWARE: + # - "cluster_aware": Respects cluster boundaries + # LLM-SPECIFIC: + # - "wanda": Sun et al. 
2023 + # - "sparsegpt": Frantar & Alistarh 2023 methods: # Baselines - name: "random" @@ -175,12 +273,18 @@ pruning: - name: "cluster_aware" # Respects cluster boundaries - name: "supernode_aware" # Protects supernodes + # Advanced (uncomment to enable) + # - name: "eigenvector" # PCA-based pruning + # - name: "movement" # Training-aware (requires training history) + # - name: "cascading_alignment" # Progressive with score recomputation + # LLM-specific (ignored for vision) - name: "wanda" - name: "sparsegpt" - name: "scar_loss_proxy" - # Selection modes + # Selection modes: which neurons to prune based on scores + # Options: "low" (prune low scores), "high" (prune high scores), "random" selection_modes: ["low", "high"] # Fine-tuning after pruning @@ -246,8 +350,22 @@ visualization: # ----------------------------------------------------------------------------- # OUTPUT # ----------------------------------------------------------------------------- +# Output directory structure: +# - If base_dir is set: Creates job-specific directories with structure: +# {base_dir}/{experiment_name}_{timestamp}_{job_id}/ +# results/, logs/, checkpoints/, figures/, analysis/ +# - If only dir is set: Uses that as the output directory (old behavior) +# +# For SLURM jobs, the job_id is automatically taken from SLURM_JOB_ID +# For local runs, a unique 8-character ID is generated output: + # Base directory for job-specific outputs (recommended for cluster runs) + # Each experiment creates a unique subdirectory with timestamp and job ID + base_dir: null # e.g., "/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM" + + # Direct output directory (legacy, used if base_dir is null) dir: "./results" + save_metrics: true save_clusters: true save_figures: true diff --git a/configs/vision_prune/mobilenetv2_cifar10_full.yaml b/configs/vision_prune/mobilenetv2_cifar10_full.yaml index c0460143..6c1b4ee0 100644 --- a/configs/vision_prune/mobilenetv2_cifar10_full.yaml +++ 
b/configs/vision_prune/mobilenetv2_cifar10_full.yaml @@ -36,9 +36,47 @@ cascade_analysis: enabled: true n_remove_per_cluster: 5 +# Pruning experiments - Comprehensive metric testing +# MobileNetV2 uses inverted residuals with depthwise separable convs +# More sensitive to pruning - interesting to see which metrics matter pruning: enabled: true - ratios: [0.1, 0.3, 0.5, 0.7] + ratios: [0.1, 0.2, 0.3, 0.4, 0.5] # More conservative for MobileNet + + # Comprehensive algorithm list + algorithms: + # Baselines + - random + - magnitude + - taylor + + # Single metrics - prune LOW + - rq_low # Prune low Rayleigh Quotient + - redundancy_low # Prune low redundancy (MI) + - synergy_low # Prune low synergy + + # Single metrics - prune HIGH + - rq_high # Prune high RQ + - redundancy_high # Prune high redundancy + - synergy_high # Prune high synergy + - magnitude_high # Prune high magnitude + + # Composite combinations + - composite # score = RQ + syn - red (original) + - composite_pos_red # score = RQ + syn + red (flipped) + - rq_minus_red # score = RQ - redundancy + - rq_plus_red # score = RQ + redundancy + - magnitude_plus_rq # score = magnitude + RQ + - magnitude_minus_red # score = magnitude - redundancy + - magnitude_plus_red # score = magnitude + redundancy + + # Cluster-aware + - cluster_aware + - cluster_aware_protect_redundant + + fine_tuning: + epochs: 10 + lr: 0.0001 visualization: enabled: true diff --git a/configs/vision_prune/mobilenetv2_cifar10_unified.yaml b/configs/vision_prune/mobilenetv2_cifar10_unified.yaml index 63015ff8..88199e64 100644 --- a/configs/vision_prune/mobilenetv2_cifar10_unified.yaml +++ b/configs/vision_prune/mobilenetv2_cifar10_unified.yaml @@ -53,6 +53,13 @@ calibration: # METRICS # ----------------------------------------------------------------------------- metrics: + # Optimization options for faster metric computation + optimization: + use_jit: false # Enable JIT-compiled computations (20-50% faster) + use_gpu_acceleration: false # Enable 
GPU-accelerated functions + force_cpu_for_large_ops: true # Prevent OOM for large covariance matrices + cpu_threshold: 100000000 # 1e8 elements threshold + rayleigh_quotient: enabled: true relative: true @@ -117,33 +124,72 @@ cascade_analysis: # ----------------------------------------------------------------------------- # PRUNING # ----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- +# PRUNING - Comprehensive metric testing +# ----------------------------------------------------------------------------- +# MobileNet uses inverted residuals with depthwise separable convs +# More sensitive to pruning - interesting to see which metrics matter pruning: enabled: true - ratios: [0.1, 0.2, 0.3, 0.4, 0.5] # MobileNet is already efficient - conservative pruning - selection_modes: ["low", "high"] + distribution: "uniform" # uniform, global_threshold, adaptive_sensitivity + dependency_aware: true # MobileNet has inverted residuals + min_per_layer: 0.0 + max_per_layer: 0.95 + ratios: [0.1, 0.2, 0.3, 0.4, 0.5] # Conservative for MobileNet + # COMPREHENSIVE ALGORITHM LIST for exploration algorithms: - - "random" - - "magnitude" - - "taylor" - - "rayleigh_quotient" - - "redundancy" - - "synergy" - - "composite" - - "cluster_aware" - + # ========================================================================= + # BASELINES + # ========================================================================= + - "random" # Random baseline + - "magnitude" # Standard magnitude pruning (prune low) + - "taylor" # Gradient-based importance + + # ========================================================================= + # SINGLE METRICS - Prune LOW (assumes low = unimportant) + # ========================================================================= + - "rq_low" # Prune low Rayleigh Quotient + - "redundancy_low" # Prune low redundancy (MI) + - "synergy_low" # Prune low synergy + + # 
========================================================================= + # SINGLE METRICS - Prune HIGH (assumes high = unimportant) + # ========================================================================= + - "rq_high" # Prune high RQ (TEST: is high RQ bad?) + - "redundancy_high" # Prune high redundancy (TEST: is high corr bad?) + - "synergy_high" # Prune high synergy + - "magnitude_high" # Prune high magnitude (inverse of standard) + + # ========================================================================= + # COMPOSITE COMBINATIONS + # ========================================================================= + - "composite" # Original: score = RQ + syn - red (prune low) + - "composite_pos_red" # Flipped: score = RQ + syn + red (prune low) + - "rq_minus_red" # score = RQ - redundancy + - "rq_plus_red" # score = RQ + redundancy + - "magnitude_plus_rq" # score = magnitude + RQ + - "magnitude_minus_red" # score = magnitude - redundancy + - "magnitude_plus_red" # score = magnitude + redundancy + + # ========================================================================= + # CLUSTER-AWARE + # ========================================================================= + - "cluster_aware" # Original: protect critical, target redundant + - "cluster_aware_protect_redundant" # Inverted: protect redundant + scoring_methods: - "random" - "magnitude" - - "taylor" - - "rayleigh_quotient" - - "redundancy" - - "synergy" + - "rq_low" + - "rq_high" + - "redundancy_low" + - "redundancy_high" - "composite" - - "cluster_aware" - + - "composite_pos_red" + fine_tune: - enabled: true + enabled: false # Disabled to see pure pruning impact without recovery epochs: 15 # MobileNet may need more fine-tuning learning_rate: 0.0001 weight_decay: 0.00001 @@ -255,6 +301,7 @@ visualization: # OUTPUT # ----------------------------------------------------------------------------- output: + base_dir: "/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM" dir: 
"./results/vision/mobilenetv2_cifar10" save_metrics: true save_clusters: true diff --git a/configs/vision_prune/resnet18_cifar10_full.yaml b/configs/vision_prune/resnet18_cifar10_full.yaml index 3a91c9c6..b8020dfe 100644 --- a/configs/vision_prune/resnet18_cifar10_full.yaml +++ b/configs/vision_prune/resnet18_cifar10_full.yaml @@ -40,6 +40,10 @@ metrics: # Clustering clustering: enabled: true + distribution: "uniform" + dependency_aware: true + min_per_layer: 0.0 + max_per_layer: 0.95 n_clusters: 4 normalize_features: true compute_stability: true @@ -48,6 +52,10 @@ clustering: # Cross-layer halo analysis halo_analysis: enabled: true + distribution: "uniform" + dependency_aware: true + min_per_layer: 0.0 + max_per_layer: 0.95 percentile: 90.0 use_activation_weight: true compute_influence_matrix: true @@ -55,21 +63,53 @@ halo_analysis: # Cascade/damage analysis cascade_analysis: enabled: true + distribution: "uniform" + dependency_aware: true + min_per_layer: 0.0 + max_per_layer: 0.95 n_remove_per_cluster: 5 damage_sample_fraction: 0.2 -# Pruning experiments +# Pruning experiments - Comprehensive metric testing +# Tests individual metrics and combinations to validate CNN pruning assumptions pruning: enabled: true - ratios: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + distribution: "uniform" + dependency_aware: true + min_per_layer: 0.0 + max_per_layer: 0.95 + ratios: [0.1, 0.3, 0.5, 0.7] - methods: - - name: random - - name: magnitude - - name: taylor - - name: rq_only - - name: composite # RQ + Red + Syn - - name: cluster_aware # Full cluster + halo aware + # Comprehensive algorithm list for exploration + algorithms: + # Baselines + - random + - magnitude + - taylor + + # Single metrics - prune LOW + - rq_low # Prune low Rayleigh Quotient + - redundancy_low # Prune low redundancy (MI) + - synergy_low # Prune low synergy + + # Single metrics - prune HIGH (test assumptions) + - rq_high # Prune high RQ + - redundancy_high # Prune high redundancy + - synergy_high # Prune high 
synergy + - magnitude_high # Prune high magnitude + + # Composite combinations + - composite # score = RQ + syn - red (original) + - composite_pos_red # score = RQ + syn + red (flipped redundancy) + - rq_minus_red # score = RQ - redundancy + - rq_plus_red # score = RQ + redundancy + - magnitude_plus_rq # score = magnitude + RQ + - magnitude_minus_red # score = magnitude - redundancy + - magnitude_plus_red # score = magnitude + redundancy + + # Cluster-aware + - cluster_aware + - cluster_aware_protect_redundant fine_tuning: epochs: 10 @@ -84,6 +124,10 @@ baselines: # Visualization visualization: enabled: true + distribution: "uniform" + dependency_aware: true + min_per_layer: 0.0 + max_per_layer: 0.95 figures: - cluster_scatter - cluster_evolution diff --git a/configs/vision_prune/resnet18_cifar10_unified.yaml b/configs/vision_prune/resnet18_cifar10_unified.yaml index a8a0e0be..c2e0418a 100644 --- a/configs/vision_prune/resnet18_cifar10_unified.yaml +++ b/configs/vision_prune/resnet18_cifar10_unified.yaml @@ -56,6 +56,13 @@ calibration: # magnitude (alias: activation_l2_norm) # ----------------------------------------------------------------------------- metrics: + # Optimization options for faster metric computation + optimization: + use_jit: false # Enable JIT-compiled computations (20-50% faster) + use_gpu_acceleration: false # Enable GPU-accelerated functions + force_cpu_for_large_ops: true # Prevent OOM for large covariance matrices + cpu_threshold: 100000000 # 1e8 elements threshold + rayleigh_quotient: enabled: true relative: true @@ -119,37 +126,80 @@ cascade_analysis: damage_sample_fraction: 0.2 # ----------------------------------------------------------------------------- -# PRUNING +# PRUNING - Comprehensive testing of all metrics +# ----------------------------------------------------------------------------- +# This tests individual metrics and combinations to validate basic assumptions +# about what makes channels important in CNNs vs what works for 
LLMs. +# +# Key questions to answer: +# 1. Do low-RQ channels safely prune? (rq_low) +# 2. Is high redundancy bad (redundancy_high) or good (redundancy_low)? +# 3. Does synergy matter for CNNs? +# 4. What combinations work best? # ----------------------------------------------------------------------------- pruning: enabled: true - ratios: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] - selection_modes: ["low", "high"] + distribution: "uniform" # uniform, global_threshold, size_proportional, importance_weighted + dependency_aware: true # Propagate masks through BN/skip connections + min_per_layer: 0.0 + max_per_layer: 0.95 + # Include high sparsity (80%, 90%) to clearly see degradation + ratios: [0.1, 0.3, 0.5, 0.7, 0.8, 0.9] - # Algorithms to test (unified naming) + # COMPREHENSIVE ALGORITHM LIST for exploration algorithms: - - "random" - - "magnitude" - - "taylor" - - "rayleigh_quotient" - - "redundancy" - - "synergy" - - "composite" - - "cluster_aware" + # ========================================================================= + # BASELINES + # ========================================================================= + - "random" # Random baseline + - "magnitude" # Standard magnitude pruning (prune low) + - "taylor" # Gradient-based importance + + # ========================================================================= + # SINGLE METRICS - Prune LOW (assumes low = unimportant) + # ========================================================================= + - "rq_low" # Prune low Rayleigh Quotient + - "redundancy_low" # Prune low redundancy (MI) + - "synergy_low" # Prune low synergy + + # ========================================================================= + # SINGLE METRICS - Prune HIGH (assumes high = unimportant) + # ========================================================================= + - "rq_high" # Prune high RQ (TEST: is high RQ bad?) + - "redundancy_high" # Prune high redundancy (TEST: is high corr bad?) 
+ - "synergy_high" # Prune high synergy + - "magnitude_high" # Prune high magnitude (inverse of standard) + + # ========================================================================= + # COMPOSITE COMBINATIONS + # ========================================================================= + - "composite" # Original: score = RQ + syn - red (prune low) + - "composite_pos_red" # Flipped: score = RQ + syn + red (prune low) + - "rq_minus_red" # score = RQ - redundancy + - "rq_plus_red" # score = RQ + redundancy + - "magnitude_plus_rq" # score = magnitude + RQ + - "magnitude_minus_red" # score = magnitude - redundancy + - "magnitude_plus_red" # score = magnitude + redundancy + + # ========================================================================= + # CLUSTER-AWARE + # ========================================================================= + - "cluster_aware" # Original: protect critical, target redundant + - "cluster_aware_protect_redundant" # Inverted: protect redundant - # Scoring methods (for comparison) scoring_methods: - "random" - "magnitude" - - "taylor" - - "rayleigh_quotient" - - "redundancy" - - "synergy" + - "rq_low" + - "rq_high" + - "redundancy_low" + - "redundancy_high" + - "synergy_low" - "composite" - - "cluster_aware" + - "composite_pos_red" fine_tune: - enabled: true + enabled: false # Disabled to see pure pruning impact without recovery epochs: 10 learning_rate: 0.0001 weight_decay: 0.0001 @@ -261,7 +311,10 @@ visualization: # ----------------------------------------------------------------------------- # OUTPUT # ----------------------------------------------------------------------------- +# Uses job directory structure: creates unique folders for each run +# Directory format: {base_dir}/{experiment_name}_{timestamp}_{job_id}/ output: + base_dir: "/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM" dir: "./results/vision/resnet18_cifar10" save_metrics: true save_clusters: true @@ -274,7 +327,8 @@ output: # 
----------------------------------------------------------------------------- extra: # Pre-training for ImageNet pretrained models on CIFAR - pretrain_epochs: 20 + # Train until model achieves ~90% accuracy on CIFAR-10 + pretrain_epochs: 30 pretrain_lr: 0.001 # Baselines to compare against diff --git a/configs/vision_prune/resnet50_imagenet100.yaml b/configs/vision_prune/resnet50_imagenet100.yaml index cb630c39..69285c9e 100644 --- a/configs/vision_prune/resnet50_imagenet100.yaml +++ b/configs/vision_prune/resnet50_imagenet100.yaml @@ -65,17 +65,42 @@ cascade_analysis: n_remove_per_cluster: 5 damage_sample_fraction: 0.1 # Smaller for faster computation -# Pruning experiments +# Pruning experiments - Comprehensive metric testing +# ResNet-50 on ImageNet: larger model, more channels, tests scalability pruning: enabled: true - ratios: [0.2, 0.3, 0.4, 0.5, 0.6] + ratios: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6] - methods: - - name: random - - name: magnitude - - name: taylor - - name: composite # RQ + Red + Syn (per-channel) - - name: cluster_aware # Full cluster + halo aware + # Comprehensive algorithm list + algorithms: + # Baselines + - random + - magnitude + - taylor + + # Single metrics - prune LOW + - rq_low # Prune low Rayleigh Quotient + - redundancy_low # Prune low redundancy (MI) + - synergy_low # Prune low synergy + + # Single metrics - prune HIGH + - rq_high # Prune high RQ + - redundancy_high # Prune high redundancy + - synergy_high # Prune high synergy + - magnitude_high # Prune high magnitude + + # Composite combinations + - composite # score = RQ + syn - red (original) + - composite_pos_red # score = RQ + syn + red (flipped) + - rq_minus_red # score = RQ - redundancy + - rq_plus_red # score = RQ + redundancy + - magnitude_plus_rq # score = magnitude + RQ + - magnitude_minus_red # score = magnitude - redundancy + - magnitude_plus_red # score = magnitude + redundancy + + # Cluster-aware + - cluster_aware + - cluster_aware_protect_redundant fine_tuning: epochs: 5 # 
Fewer epochs for ImageNet diff --git a/configs/vision_prune/resnet50_imagenet100_unified.yaml b/configs/vision_prune/resnet50_imagenet100_unified.yaml index fe55bac8..beb88e21 100644 --- a/configs/vision_prune/resnet50_imagenet100_unified.yaml +++ b/configs/vision_prune/resnet50_imagenet100_unified.yaml @@ -55,6 +55,13 @@ calibration: # METRICS # ----------------------------------------------------------------------------- metrics: + # Optimization options for faster metric computation + optimization: + use_jit: false # Enable JIT-compiled computations (20-50% faster) + use_gpu_acceleration: false # Enable GPU-accelerated functions + force_cpu_for_large_ops: true # Prevent OOM for large covariance matrices + cpu_threshold: 100000000 # 1e8 elements threshold + rayleigh_quotient: enabled: true relative: true @@ -119,34 +126,71 @@ cascade_analysis: # ----------------------------------------------------------------------------- # PRUNING # ----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- +# PRUNING - Comprehensive metric testing +# ----------------------------------------------------------------------------- +# ResNet-50 on ImageNet: larger model, more channels, tests scalability pruning: enabled: true - ratios: [0.2, 0.3, 0.4, 0.5, 0.6] - selection_modes: ["low", "high"] + distribution: "uniform" # uniform, global_threshold, adaptive_sensitivity + dependency_aware: true # ResNet has skip connections + min_per_layer: 0.0 + max_per_layer: 0.95 + ratios: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6] + # COMPREHENSIVE ALGORITHM LIST for exploration algorithms: - - "random" - - "magnitude" - - "taylor" - - "rayleigh_quotient" - - "redundancy" - - "synergy" - - "composite" - - "cluster_aware" - - "network_slimming" - + # ========================================================================= + # BASELINES + # 
========================================================================= + - "random" # Random baseline + - "magnitude" # Standard magnitude pruning (prune low) + - "taylor" # Gradient-based importance + + # ========================================================================= + # SINGLE METRICS - Prune LOW (assumes low = unimportant) + # ========================================================================= + - "rq_low" # Prune low Rayleigh Quotient + - "redundancy_low" # Prune low redundancy (MI) + - "synergy_low" # Prune low synergy + + # ========================================================================= + # SINGLE METRICS - Prune HIGH (assumes high = unimportant) + # ========================================================================= + - "rq_high" # Prune high RQ (TEST: is high RQ bad?) + - "redundancy_high" # Prune high redundancy (TEST: is high corr bad?) + - "synergy_high" # Prune high synergy + - "magnitude_high" # Prune high magnitude (inverse of standard) + + # ========================================================================= + # COMPOSITE COMBINATIONS + # ========================================================================= + - "composite" # Original: score = RQ + syn - red (prune low) + - "composite_pos_red" # Flipped: score = RQ + syn + red (prune low) + - "rq_minus_red" # score = RQ - redundancy + - "rq_plus_red" # score = RQ + redundancy + - "magnitude_plus_rq" # score = magnitude + RQ + - "magnitude_minus_red" # score = magnitude - redundancy + - "magnitude_plus_red" # score = magnitude + redundancy + + # ========================================================================= + # CLUSTER-AWARE + # ========================================================================= + - "cluster_aware" # Original: protect critical, target redundant + - "cluster_aware_protect_redundant" # Inverted: protect redundant + scoring_methods: - "random" - "magnitude" - - "taylor" - - "rayleigh_quotient" - - "redundancy" - - "synergy" + 
- "rq_low" + - "rq_high" + - "redundancy_low" + - "redundancy_high" - "composite" - - "cluster_aware" - + - "composite_pos_red" + fine_tune: - enabled: true + enabled: false # Disabled to see pure pruning impact without recovery epochs: 5 # Fewer epochs for ImageNet learning_rate: 0.00001 weight_decay: 0.0001 @@ -258,6 +302,7 @@ visualization: # OUTPUT # ----------------------------------------------------------------------------- output: + base_dir: "/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM" dir: "./results/vision/resnet50_imagenet100" save_metrics: true save_clusters: true diff --git a/configs/vision_prune/vgg16_cifar10_full.yaml b/configs/vision_prune/vgg16_cifar10_full.yaml index 4e774f6e..6ff050c1 100644 --- a/configs/vision_prune/vgg16_cifar10_full.yaml +++ b/configs/vision_prune/vgg16_cifar10_full.yaml @@ -37,15 +37,46 @@ cascade_analysis: enabled: true n_remove_per_cluster: 5 +# Pruning experiments - Comprehensive metric testing +# VGG is highly pruneable (lots of redundant filters) - good for testing assumptions pruning: enabled: true - ratios: [0.1, 0.3, 0.5, 0.7] - methods: + ratios: [0.1, 0.3, 0.5, 0.7, 0.8] # VGG can handle higher sparsity + + # Comprehensive algorithm list + algorithms: + # Baselines + - random - magnitude - taylor - - network_slimming - - composite + + # Single metrics - prune LOW + - rq_low # Prune low Rayleigh Quotient + - redundancy_low # Prune low redundancy (MI) + - synergy_low # Prune low synergy + + # Single metrics - prune HIGH + - rq_high # Prune high RQ + - redundancy_high # Prune high redundancy + - synergy_high # Prune high synergy + - magnitude_high # Prune high magnitude + + # Composite combinations + - composite # score = RQ + syn - red (original) + - composite_pos_red # score = RQ + syn + red (flipped) + - rq_minus_red # score = RQ - redundancy + - rq_plus_red # score = RQ + redundancy + - magnitude_plus_rq # score = magnitude + RQ + - magnitude_minus_red # score = magnitude - redundancy + - 
magnitude_plus_red # score = magnitude + redundancy + + # Cluster-aware - cluster_aware + - cluster_aware_protect_redundant + + fine_tuning: + epochs: 10 + lr: 0.0001 visualization: enabled: true diff --git a/configs/vision_prune/vgg16_cifar10_unified.yaml b/configs/vision_prune/vgg16_cifar10_unified.yaml index 846377da..baa08658 100644 --- a/configs/vision_prune/vgg16_cifar10_unified.yaml +++ b/configs/vision_prune/vgg16_cifar10_unified.yaml @@ -51,6 +51,13 @@ calibration: # METRICS # ----------------------------------------------------------------------------- metrics: + # Optimization options for faster metric computation + optimization: + use_jit: false # Enable JIT-compiled computations (20-50% faster) + use_gpu_acceleration: false # Enable GPU-accelerated functions + force_cpu_for_large_ops: true # Prevent OOM for large covariance matrices + cpu_threshold: 100000000 # 1e8 elements threshold + rayleigh_quotient: enabled: true relative: true @@ -115,34 +122,71 @@ cascade_analysis: # ----------------------------------------------------------------------------- # PRUNING # ----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- +# PRUNING - Comprehensive metric testing +# ----------------------------------------------------------------------------- +# VGG is highly pruneable due to high redundancy - excellent for testing assumptions pruning: enabled: true - ratios: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8] # VGG can be pruned more aggressively - selection_modes: ["low", "high"] + distribution: "uniform" # uniform, global_threshold, adaptive_sensitivity + dependency_aware: false # VGG has no skip connections + min_per_layer: 0.0 + max_per_layer: 0.95 + ratios: [0.1, 0.3, 0.5, 0.7, 0.8] # VGG can be pruned aggressively + # COMPREHENSIVE ALGORITHM LIST for exploration algorithms: - - "random" - - "magnitude" - - "taylor" - - "rayleigh_quotient" - - "redundancy" - - 
"synergy" - - "composite" - - "cluster_aware" - - "network_slimming" - + # ========================================================================= + # BASELINES + # ========================================================================= + - "random" # Random baseline + - "magnitude" # Standard magnitude pruning (prune low) + - "taylor" # Gradient-based importance + + # ========================================================================= + # SINGLE METRICS - Prune LOW (assumes low = unimportant) + # ========================================================================= + - "rq_low" # Prune low Rayleigh Quotient + - "redundancy_low" # Prune low redundancy (MI) + - "synergy_low" # Prune low synergy + + # ========================================================================= + # SINGLE METRICS - Prune HIGH (assumes high = unimportant) + # ========================================================================= + - "rq_high" # Prune high RQ (TEST: is high RQ bad?) + - "redundancy_high" # Prune high redundancy (TEST: is high corr bad?) 
+ - "synergy_high" # Prune high synergy + - "magnitude_high" # Prune high magnitude (inverse of standard) + + # ========================================================================= + # COMPOSITE COMBINATIONS + # ========================================================================= + - "composite" # Original: score = RQ + syn - red (prune low) + - "composite_pos_red" # Flipped: score = RQ + syn + red (prune low) + - "rq_minus_red" # score = RQ - redundancy + - "rq_plus_red" # score = RQ + redundancy + - "magnitude_plus_rq" # score = magnitude + RQ + - "magnitude_minus_red" # score = magnitude - redundancy + - "magnitude_plus_red" # score = magnitude + redundancy + + # ========================================================================= + # CLUSTER-AWARE + # ========================================================================= + - "cluster_aware" # Original: protect critical, target redundant + - "cluster_aware_protect_redundant" # Inverted: protect redundant + scoring_methods: - "random" - "magnitude" - - "taylor" - - "rayleigh_quotient" - - "redundancy" - - "synergy" + - "rq_low" + - "rq_high" + - "redundancy_low" + - "redundancy_high" - "composite" - - "cluster_aware" - + - "composite_pos_red" + fine_tune: - enabled: true + enabled: false # Disabled to see pure pruning impact without recovery epochs: 10 learning_rate: 0.0001 weight_decay: 0.0001 @@ -250,6 +294,7 @@ visualization: # OUTPUT # ----------------------------------------------------------------------------- output: + base_dir: "/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM" dir: "./results/vision/vgg16_cifar10" save_metrics: true save_clusters: true diff --git a/docs/source/api/experiments.rst b/docs/source/api/experiments.rst index 59503f73..03d6bc81 100644 --- a/docs/source/api/experiments.rst +++ b/docs/source/api/experiments.rst @@ -160,201 +160,6 @@ Progressive Dropout Experiment # - Metric values for remaining neurons # - Layer-wise statistics -Layer-Isolated Pruning 
Experiment ---------------------------------- - -.. automodule:: alignment.experiments.layer_isolated - :members: - :undoc-members: - :show-inheritance: - -.. autoclass:: alignment.experiments.layer_isolated.LayerIsolatedPruningExperiment - :members: - :undoc-members: - - **Description:** - - This experiment prunes each layer in isolation to understand the contribution and - sensitivity of individual layers. It helps identify which layers are most critical - for network performance. - - **Key Configuration Options:** - - .. attribute:: target_layers - :type: Optional[List[str]] - :default: None - - Specific layers to analyze. If None, analyzes all layers. - - .. attribute:: isolation_mode - :type: str - :default: "sequential" - - How to isolate layers: - - - ``"sequential"``: Test one layer at a time - - ``"parallel"``: Test all layers independently - - ``"cumulative"``: Add layers progressively - - .. attribute:: restoration_mode - :type: str - :default: "full" - - How to restore between tests: - - - ``"full"``: Restore all weights - - ``"partial"``: Keep some pruning - - ``"none"``: No restoration - - **Example Usage:** - - .. code-block:: python - - config = ExperimentConfig( - name="layer_isolation_study", - model_name="resnet18", - dataset_name="cifar10", - metrics=["rayleigh_quotient"], - dropout_rates=[0.0, 0.3, 0.5, 0.7], - target_layers=["layer1.0.conv1", "layer2.0.conv1"], - isolation_mode="sequential" - ) - - experiment = LayerIsolatedPruningExperiment(config) - results = experiment.run() - -Cascading Layer Pruning Experiment ----------------------------------- - -.. automodule:: alignment.experiments.cascading - :members: - :undoc-members: - :show-inheritance: - -.. autoclass:: alignment.experiments.cascading.CascadingLayerPruningExperiment - :members: - :undoc-members: - - **Description:** - - This experiment implements cascading pruning where the effects of pruning early layers - propagate through the network. 
It studies how pruning decisions in one layer affect - subsequent layers. - - **Key Configuration Options:** - - .. attribute:: cascade_direction - :type: str - :default: "forward" - - Direction of cascading: - - - ``"forward"``: Input to output layers - - ``"backward"``: Output to input layers - - ``"middle_out"``: Start from middle layers - - .. attribute:: cascade_threshold - :type: float - :default: 0.01 - - Minimum activation threshold for cascading effects - - .. attribute:: recompute_scores - :type: bool - :default: True - - Whether to recompute importance scores after each layer is pruned - - .. attribute:: track_information_flow - :type: bool - :default: True - - Track how information flows through pruned network - - **Example Usage:** - - .. code-block:: python - - config = ExperimentConfig( - name="cascading_pruning_analysis", - model_name="resnet18", - dataset_name="cifar10", - metrics=["rayleigh_quotient", "mutual_information"], - dropout_rates=[0.0, 0.2, 0.4, 0.6], - cascade_direction="forward", - cascade_threshold=0.01, - recompute_scores=True - ) - - experiment = CascadingLayerPruningExperiment(config) - results = experiment.run() - -Eigenvector-based Dropout Experiment ------------------------------------- - -.. automodule:: alignment.experiments.eigenvector - :members: - :undoc-members: - :show-inheritance: - -.. autoclass:: alignment.experiments.eigenvector.EigenvectorDropoutExperiment - :members: - :undoc-members: - - **Description:** - - This experiment uses eigenvector analysis of weight matrices to identify important - directions in the network. Neurons are pruned based on their alignment with principal - components of the weight space. - - **Key Configuration Options:** - - .. attribute:: n_components_ratio - :type: float - :default: 0.95 - - Ratio of variance to preserve (0.95 = keep 95% of variance) - - .. 
attribute:: eigenvector_strategy - :type: str - :default: "low" - - Which components to prune: - - - ``"low"``: Remove low eigenvalue components - - ``"high"``: Remove high eigenvalue components - - ``"middle"``: Remove middle eigenvalue components - - .. attribute:: compute_layer_pca - :type: bool - :default: True - - Whether to compute PCA per layer or globally - - .. attribute:: weight_by_eigenvalue - :type: bool - :default: True - - Whether to weight importance by eigenvalue magnitude - - **Example Usage:** - - .. code-block:: python - - config = ExperimentConfig( - name="eigenvector_pruning", - model_name="resnet18", - dataset_name="cifar10", - metrics=["rayleigh_quotient"], - dropout_rates=[0.0, 0.3, 0.6, 0.9], - n_components_ratio=0.95, - eigenvector_strategy="low", - compute_layer_pca=True - ) - - experiment = EigenvectorDropoutExperiment(config) - results = experiment.run() - Experiment Runner ----------------- diff --git a/scripts/run_experiment.py b/scripts/run_experiment.py index 6f7abe2e..70de33b5 100644 --- a/scripts/run_experiment.py +++ b/scripts/run_experiment.py @@ -8,6 +8,18 @@ python scripts/run_experiment.py --config configs/examples/mnist_basic.yaml python scripts/run_experiment.py --config configs/examples/resnet_pruning.yaml --device cuda:0 python scripts/run_experiment.py --analysis-only --experiment-dir results/my_experiment_20240101 + +Job Directory Structure: + When base_output_dir is specified in config, experiments create unique job directories: + + {base_output_dir}/ + {experiment_name}_{timestamp}_{job_id}/ + results/ # JSON results files + logs/ # experiment.log + checkpoints/ # Model checkpoints + figures/ # All visualizations + analysis/ # Post-experiment analysis + experiment_config.yaml """ import argparse @@ -66,8 +78,6 @@ def patched_tqdm(*args, **kwargs): pass # tqdm not available, skip configuration from alignment.experiments.general_alignment import GeneralAlignmentExperiment -from 
alignment.pruning.experiments.cascading_layer import CascadingLayerPruningExperiment -from alignment.pruning.experiments.layer_wise import LayerIsolatedPruningExperiment from alignment.experiments.llm_experiments import LLMAlignmentExperiment from alignment.experiments.cluster_experiments import ( ClusterAnalysisExperiment, @@ -101,6 +111,26 @@ def _get_nested(obj, key, default): metrics_cfg = _get_nested(config, "metrics", {}) clustering_cfg = _get_nested(config, "clustering", {}) halo_cfg = _get_nested(config, "halo_analysis", {}) + pruning_cfg = _get_nested(config, "pruning", {}) + + # Get pruning settings + pruning_ratios = getattr(config, "pruning_amounts", None) or \ + (pruning_cfg.get("ratios") if isinstance(pruning_cfg, dict) else None) or \ + [0.1, 0.3, 0.5, 0.7] + + # Get fine-tuning settings + fine_tune_cfg = pruning_cfg.get("fine_tune", {}) if isinstance(pruning_cfg, dict) else {} + fine_tune_enabled = getattr(config, "fine_tune_after_pruning", + fine_tune_cfg.get("enabled", False) if isinstance(fine_tune_cfg, dict) else False) + fine_tune_epochs = getattr(config, "fine_tune_epochs", + fine_tune_cfg.get("epochs", 10) if isinstance(fine_tune_cfg, dict) else 10) + fine_tune_lr = getattr(config, "fine_tune_learning_rate", + fine_tune_cfg.get("learning_rate", 0.0001) if isinstance(fine_tune_cfg, dict) else 0.0001) + + # Get pruning algorithms/methods + pruning_methods = getattr(config, "pruning_strategies", None) or \ + (pruning_cfg.get("algorithms") if isinstance(pruning_cfg, dict) else None) or \ + ['random', 'magnitude', 'taylor', 'composite', 'cluster_aware'] # Build ClusterAnalysisConfig from the loaded config cluster_config = ClusterAnalysisConfig( @@ -111,6 +141,11 @@ def _get_nested(obj, key, default): synergy_target=getattr(config, "synergy_target", metrics_cfg.get("synergy_target", "logit_margin") if isinstance(metrics_cfg, dict) else "logit_margin"), synergy_pairs=getattr(config, "synergy_pairs", metrics_cfg.get("synergy_num_pairs", 10) if 
isinstance(metrics_cfg, dict) else 10), halo_percentile=getattr(config, "halo_percentile", halo_cfg.get("percentile", 90.0) if isinstance(halo_cfg, dict) else 90.0), + pruning_ratios=pruning_ratios, + pruning_methods=pruning_methods, + fine_tune_after_pruning=fine_tune_enabled, + fine_tune_epochs=fine_tune_epochs, + fine_tune_lr=fine_tune_lr, output_dir=getattr(config, "experiment_dir", "results/cluster_analysis"), device=getattr(config, "device", "cuda"), seed=getattr(config, "seed", 42), @@ -120,21 +155,38 @@ def _get_nested(obj, key, default): model_name = cluster_config.model_name.lower() num_classes = 10 if "cifar" in cluster_config.dataset_name.lower() else 1000 + # Check for pre-trained checkpoint + model_cfg = _get_nested(config, "model", {}) + checkpoint_path = model_cfg.get("checkpoint", None) if isinstance(model_cfg, dict) else None + checkpoint_path = checkpoint_path or getattr(config, "model_checkpoint", None) + if "resnet18" in model_name: - model = torchvision.models.resnet18(pretrained=True) + model = torchvision.models.resnet18(weights='IMAGENET1K_V1') model.fc = torch.nn.Linear(model.fc.in_features, num_classes) elif "resnet50" in model_name: - model = torchvision.models.resnet50(pretrained=True) + model = torchvision.models.resnet50(weights='IMAGENET1K_V1') model.fc = torch.nn.Linear(model.fc.in_features, num_classes) elif "vgg16" in model_name: - model = torchvision.models.vgg16_bn(pretrained=True) + model = torchvision.models.vgg16_bn(weights='IMAGENET1K_V1') model.classifier[-1] = torch.nn.Linear(model.classifier[-1].in_features, num_classes) elif "mobilenet" in model_name: - model = torchvision.models.mobilenet_v2(pretrained=True) + model = torchvision.models.mobilenet_v2(weights='IMAGENET1K_V1') model.classifier[-1] = torch.nn.Linear(model.classifier[-1].in_features, num_classes) else: raise ValueError(f"Unknown model: {model_name}") + # Load checkpoint if available, otherwise model needs to be trained + if checkpoint_path and 
os.path.exists(checkpoint_path): + logger.info(f"Loading model checkpoint from {checkpoint_path}") + state_dict = torch.load(checkpoint_path, map_location='cpu') + if 'model_state_dict' in state_dict: + state_dict = state_dict['model_state_dict'] + model.load_state_dict(state_dict) + needs_training = False + else: + logger.warning(f"No checkpoint found - model needs to be trained on {cluster_config.dataset_name}") + needs_training = True + # Load dataset dataset_name = cluster_config.dataset_name.lower() if "cifar10" in dataset_name: @@ -160,13 +212,43 @@ def _get_nested(obj, key, default): # Fine-tune the model on target dataset before experiments # This is necessary because we replaced the classifier head with random weights + # Get training settings from config (check multiple locations) + training_cfg = _get_nested(config, "training", {}) + extra_cfg = _get_nested(config, "extra", {}) + + # Check in order: training.epochs, extra.pretrain_epochs, config.pretrain_epochs + pretrain_epochs = ( + training_cfg.get("epochs") if isinstance(training_cfg, dict) else None + ) or ( + extra_cfg.get("pretrain_epochs") if isinstance(extra_cfg, dict) else None + ) or getattr(config, "pretrain_epochs", 30) + + pretrain_lr = ( + training_cfg.get("learning_rate") if isinstance(training_cfg, dict) else None + ) or ( + extra_cfg.get("pretrain_lr") if isinstance(extra_cfg, dict) else None + ) or getattr(config, "pretrain_lr", 0.001) + model = _finetune_model_for_dataset( model, train_loader, test_loader, device=cluster_config.device, - epochs=getattr(config, "pretrain_epochs", 20), - lr=getattr(config, "pretrain_lr", 0.001), + epochs=pretrain_epochs, + lr=pretrain_lr, ) + # Save the trained model checkpoint + output_dir = Path(cluster_config.output_dir) + checkpoint_dir = output_dir / "checkpoints" + checkpoint_dir.mkdir(exist_ok=True, parents=True) + trained_checkpoint = checkpoint_dir / "trained_model.pth" + torch.save({ + 'model_state_dict': model.state_dict(), + 'model_name': 
model_name, + 'dataset_name': dataset_name, + 'num_classes': num_classes, + }, trained_checkpoint) + logger.info(f"Saved trained model checkpoint to {trained_checkpoint}") + return ClusterAnalysisExperiment(cluster_config, model, train_loader, test_loader) @@ -210,9 +292,9 @@ def _finetune_model_for_dataset( total += y.size(0) initial_acc = correct / total - # If already trained (>50% accuracy), skip fine-tuning - if initial_acc > 0.5: - logger.info(f"Model already trained (accuracy: {initial_acc:.2%}), skipping fine-tuning") + # If already well-trained (>85% accuracy on CIFAR-10), skip fine-tuning + if initial_acc > 0.85: + logger.info(f"Model already well-trained (accuracy: {initial_acc:.2%}), skipping fine-tuning") return model logger.info(f"Fine-tuning model on target dataset (initial accuracy: {initial_acc:.2%})...") @@ -318,13 +400,230 @@ def run_post_analysis(config, results_file: Path, output_dir: Path): logger.error(f"Post-analysis failed: {e}") +def _regenerate_llm_visualizations(experiment, results: dict, output_dir: Path): + """ + Regenerate visualizations for LLM experiments from saved results. 
+ + Args: + experiment: LLMAlignmentExperiment instance + results: Loaded results dictionary from JSON + output_dir: Output directory for plots + """ + import numpy as np + import matplotlib + matplotlib.use('Agg') # Non-interactive backend + import matplotlib.pyplot as plt + + from alignment.analysis.visualization import UnifiedVisualizer + + # Determine plots directory + if (output_dir / "figures").exists(): + plots_dir = output_dir / "figures" + elif (output_dir / "plots").exists(): + plots_dir = output_dir / "plots" + else: + plots_dir = output_dir / "figures" + plots_dir.mkdir(parents=True, exist_ok=True) + + viz = UnifiedVisualizer() + + # Extract data from results + importance_scores = results.get("importance_scores", {}) + pruning_results = results.get("pruning_results", {}) + scar_results = results.get("scar_metrics", results.get("scar_scores", {})) + supernode_results = results.get("supernode_analysis", {}) + halo_results = results.get("halo_analysis", {}) + + logger.info(f"Regenerating plots to: {plots_dir}") + logger.info(f" - Found importance_scores: {bool(importance_scores)}") + logger.info(f" - Found pruning_results: {bool(pruning_results)}") + logger.info(f" - Found scar_results: {bool(scar_results)}") + logger.info(f" - Found supernode_results: {bool(supernode_results)}") + logger.info(f" - Found halo_results: {bool(halo_results)}") + + # Convert list values back to numpy arrays for plotting + def to_numpy(data): + if isinstance(data, list): + return np.array(data) + elif isinstance(data, dict): + return {k: to_numpy(v) for k, v in data.items()} + return data + + importance_scores = to_numpy(importance_scores) + scar_results = to_numpy(scar_results) + + # 1. 
SCAR metrics plots + if scar_results: + scar_plots_dir = plots_dir / "scar" + scar_plots_dir.mkdir(parents=True, exist_ok=True) + + try: + # Layer-wise SCAR distributions + for metric_name in ["scar_loss_proxy", "scar_activation_power", "scar_curvature", "scar_taylor"]: + # Check if this metric exists in any layer + has_metric = any(metric_name in layer_data for layer_data in scar_results.values() if isinstance(layer_data, dict)) + if has_metric: + try: + fig = viz.plot_scar_layer_scores( + scar_results, + metric_name=metric_name, + plot_type="violin", + save_path=scar_plots_dir / f"{metric_name}_layers.png", + ) + plt.close(fig) + logger.info(f" Generated: {metric_name}_layers.png") + except Exception as e: + logger.debug(f"Could not generate {metric_name} plot: {e}") + except Exception as e: + logger.warning(f"Error generating SCAR plots: {e}") + + # 2. Pruning comparison plots + if pruning_results: + pruning_plots_dir = plots_dir / "pruning" + pruning_plots_dir.mkdir(parents=True, exist_ok=True) + + try: + # Extract pruning data for visualization + all_pruning_data = [] + for method_name, method_results in pruning_results.items(): + if isinstance(method_results, dict): + for sparsity, data in method_results.items(): + if isinstance(data, dict): + entry = { + "method": method_name, + "sparsity": float(sparsity) if isinstance(sparsity, str) else sparsity, + "perplexity": data.get("perplexity", data.get("test_perplexity")), + "accuracy": data.get("accuracy"), + } + if entry["perplexity"] is not None or entry["accuracy"] is not None: + all_pruning_data.append(entry) + + if all_pruning_data: + # Generate pruning comparison plots + fig = viz.plot_pruning_comparison_curves( + all_pruning_data, + metric="perplexity", + title="Pruning Methods Comparison (Perplexity)", + save_path=pruning_plots_dir / "pruning_comparison_perplexity.png", + ) + if fig: + plt.close(fig) + logger.info(" Generated: pruning_comparison_perplexity.png") + except Exception as e: + 
logger.warning(f"Error generating pruning plots: {e}") + + # 3. Redundancy plots (if halo analysis results exist) + if halo_results: + redundancy_plots_dir = plots_dir / "redundancy" + redundancy_plots_dir.mkdir(parents=True, exist_ok=True) + + try: + for layer_name, layer_data in halo_results.items(): + if isinstance(layer_data, dict): + high_vals = layer_data.get("high_connected_redundancy", []) + low_vals = layer_data.get("low_connected_redundancy", []) + + if high_vals and low_vals: + high_vals = np.array(high_vals) if isinstance(high_vals, list) else high_vals + low_vals = np.array(low_vals) if isinstance(low_vals, list) else low_vals + + figs = viz.plot_redundancy_comparison( + high_vals, low_vals, layer_name, + save_dir=redundancy_plots_dir, + ) + for fig in figs: + plt.close(fig) + logger.info(f" Generated redundancy plots for: {layer_name}") + except Exception as e: + logger.warning(f"Error generating redundancy plots: {e}") + + # 4. Importance score histograms + if importance_scores: + importance_plots_dir = plots_dir / "importance" + importance_plots_dir.mkdir(parents=True, exist_ok=True) + + try: + for layer_name, layer_metrics in importance_scores.items(): + if isinstance(layer_metrics, dict): + for metric_name, values in layer_metrics.items(): + if values is not None and len(values) > 0: + try: + values_np = np.array(values) if isinstance(values, list) else values + fig, ax = plt.subplots(figsize=(10, 6)) + ax.hist(values_np.flatten(), bins=50, alpha=0.7, edgecolor='black') + ax.set_xlabel(metric_name) + ax.set_ylabel("Count") + ax.set_title(f"{metric_name} Distribution\n{layer_name}") + + safe_layer = layer_name.replace(".", "_").replace("/", "_") + save_path = importance_plots_dir / f"{metric_name}_{safe_layer}.png" + fig.savefig(save_path, dpi=150, bbox_inches='tight') + plt.close(fig) + except Exception as e: + logger.debug(f"Could not plot {metric_name} for {layer_name}: {e}") + except Exception as e: + logger.warning(f"Error generating 
importance plots: {e}") + + logger.info(f"Visualization regeneration complete. Plots saved to: {plots_dir}") + + +def _create_job_directory(config, args, timestamp: str) -> Path: + """ + Create job directory using the new unified structure. + + Priority for base_output_dir: + 1. --output-dir CLI argument (full path) + 2. config.base_output_dir (new field) + 3. config.experiment.output_dir or config.output.dir + 4. Default: ./results + + Returns: + Path to the created job directory. + """ + from alignment.infrastructure.storage import create_job_directory, get_slurm_job_id + + experiment_name = getattr(config, "name", "experiment") + + # If --output-dir is given as full path, use it directly + if args.output_dir: + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + return output_dir + + # Check for base_output_dir in config (new unified approach) + base_output_dir = getattr(config, "base_output_dir", None) + + # If base_output_dir is set, use new job directory structure + if base_output_dir: + job_id = get_slurm_job_id() # Will be None for non-SLURM runs + output_dir = create_job_directory( + base_output_dir=base_output_dir, + experiment_name=experiment_name, + timestamp=timestamp, + job_id=job_id, + create_subdirs=True, + ) + return output_dir + + # Fallback to old behavior: results/{name}_{timestamp} + output_dir = Path(f"results/{experiment_name}_{timestamp}") + output_dir.mkdir(parents=True, exist_ok=True) + + # Create subdirectories for consistency + for subdir in ["results", "logs", "checkpoints", "figures", "analysis"]: + (output_dir / subdir).mkdir(exist_ok=True) + + return output_dir + + def main(): """Main entry point.""" parser = argparse.ArgumentParser(description="Unified Alignment Experiment Runner") parser.add_argument("--config", type=str, required=True, help="Configuration file") parser.add_argument("--device", type=str, help="Override device") parser.add_argument("--seed", type=int, help="Override seed") - 
parser.add_argument("--output-dir", type=str, help="Override output directory") + parser.add_argument("--output-dir", type=str, help="Override output directory (full path)") + parser.add_argument("--base-output-dir", type=str, help="Override base output directory (creates job subdir)") parser.add_argument( "--analysis-only", action="store_true", @@ -353,6 +652,10 @@ def main(): for key, value in overrides.items(): if hasattr(config, key): setattr(config, key, value) + + # Override base_output_dir if provided via CLI + if args.base_output_dir: + config.base_output_dir = args.base_output_dir is_analysis_only = bool(args.analysis_only) @@ -366,40 +669,48 @@ def main(): config.experiment_dir = str(output_dir) config.checkpoint_dir = str(output_dir / "checkpoints") config.log_dir = str(output_dir / "logs") - plots_dir = output_dir / "plots" + + # Support both old 'plots' and new 'figures' directory names + if (output_dir / "figures").exists(): + plots_dir = output_dir / "figures" + else: + plots_dir = output_dir / "plots" config.plots_dir = str(plots_dir) plots_dir.mkdir(parents=True, exist_ok=True) config_save_path = output_dir / "experiment_config.yaml" timestamp = None else: - # Create timestamped output directory + # Create job directory with new structure timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - experiment_name = getattr(config, "name", "experiment") - - if args.output_dir: - output_dir = Path(args.output_dir) - else: - output_dir = Path(f"results/{experiment_name}_{timestamp}") - - output_dir.mkdir(parents=True, exist_ok=True) + output_dir = _create_job_directory(config, args, timestamp) config_save_path = output_dir / "experiment_config.yaml" config.save(config_save_path) + # Set paths - use new subdirectory structure config.checkpoint_dir = str(output_dir / "checkpoints") config.log_dir = str(output_dir / "logs") config.experiment_dir = str(output_dir) - plots_dir = output_dir / "plots" + # Use 'figures' subdirectory for new structure, 'plots' for 
compatibility + if (output_dir / "figures").exists(): + plots_dir = output_dir / "figures" + else: + plots_dir = output_dir / "plots" + plots_dir.mkdir(parents=True, exist_ok=True) config.plots_dir = str(plots_dir) Path(config.checkpoint_dir).mkdir(parents=True, exist_ok=True) Path(config.log_dir).mkdir(parents=True, exist_ok=True) - plots_dir.mkdir(parents=True, exist_ok=True) - # Setup logging - log_file = output_dir / "experiment.log" + # Setup logging - use logs subdirectory if it exists (new structure) + logs_subdir = output_dir / "logs" + if logs_subdir.exists(): + log_file = logs_subdir / "experiment.log" + else: + log_file = output_dir / "experiment.log" + logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", @@ -426,23 +737,29 @@ def main(): elif experiment_type in {"cluster_analysis", "vision_cluster_analysis", "metric_cluster_analysis"}: # Cluster-based analysis experiment (works for any architecture) experiment = _create_cluster_experiment(config) - elif experiment_type == "layer_isolated_pruning": - experiment = LayerIsolatedPruningExperiment(config) - elif experiment_type == "cascading_layer_pruning": - experiment = CascadingLayerPruningExperiment(config) else: - raise ValueError(f"Unknown experiment type: {experiment_type}") + raise ValueError(f"Unknown experiment type: {experiment_type}. 
" + f"Supported types: llm_alignment, llm_supernode, llm, " + f"alignment_analysis, vision_synergy, general_alignment, " + f"cluster_analysis, vision_cluster_analysis, metric_cluster_analysis") # Analysis-only mode if is_analysis_only: - if isinstance(experiment, GeneralAlignmentExperiment): + # Find results file - check both root and results subdirectory + results_subdir = output_dir / "results" + if results_subdir.exists(): + result_files = sorted(results_subdir.glob("results_*.json")) + else: result_files = sorted(output_dir.glob("results_*.json")) - if not result_files: - raise FileNotFoundError(f"No results_*.json found in {output_dir}") - results_path = result_files[-1] - with results_path.open("r") as f: - results = json.load(f) + + if not result_files: + raise FileNotFoundError(f"No results_*.json found in {output_dir} or {results_subdir}") + results_path = result_files[-1] + + with results_path.open("r") as f: + results = json.load(f) + if isinstance(experiment, GeneralAlignmentExperiment): experiment.train_results = results.get("train_results", {}) experiment.test_results = results.get("test_results", {}) experiment.dropout_results = results.get("dropout_results", {}) @@ -455,6 +772,25 @@ def main(): # Run post-analysis if configured run_post_analysis(config, results_path, output_dir) + + elif isinstance(experiment, LLMAlignmentExperiment): + # For LLM experiments, regenerate visualizations from saved results + logger.info("Regenerating LLM experiment visualizations from saved results...") + + # Setup experiment (loads model/tokenizer for any model-dependent plots) + try: + experiment.setup() + except Exception as e: + logger.warning(f"Could not setup model (may not be needed for plots): {e}") + + # Regenerate visualizations using the unified visualizer + if getattr(config, "generate_plots", True): + _regenerate_llm_visualizations(experiment, results, output_dir) + logger.info("Regenerated LLM visualizations from existing results") + + # Run post-analysis 
if configured + run_post_analysis(config, results_path, output_dir) + else: logger.warning(f"Analysis-only mode not supported for {experiment_type}") @@ -466,15 +802,21 @@ def main(): # Full experiment run results = experiment.run() - # Save results - results_file = output_dir / f"results_{timestamp}.json" + # Save results - use results subdirectory if it exists (new structure) + results_subdir = output_dir / "results" + if results_subdir.exists(): + results_file = results_subdir / f"results_{timestamp}.json" + else: + results_file = output_dir / f"results_{timestamp}.json" def convert_to_serializable(obj): if hasattr(obj, "tolist"): return obj.tolist() + elif hasattr(obj, "item"): + return obj.item() elif isinstance(obj, dict): return {k: convert_to_serializable(v) for k, v in obj.items()} - elif isinstance(obj, list): + elif isinstance(obj, (list, tuple)): return [convert_to_serializable(i) for i in obj] return obj diff --git a/scripts/verify_pruning.py b/scripts/verify_pruning.py new file mode 100755 index 00000000..ec77fad8 --- /dev/null +++ b/scripts/verify_pruning.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python +"""Quick verification script to test that CNN pruning is working correctly.""" +import torch +import torch.nn as nn +import torchvision +import torchvision.transforms as transforms +import numpy as np +import copy + +def load_cifar10(batch_size=128): + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + train = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) + test = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform) + return (torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=2), + torch.utils.data.DataLoader(test, batch_size=batch_size*2, shuffle=False, num_workers=2)) + +def evaluate(model, loader, device): + model.eval() + correct, total = 0, 0 + 
with torch.no_grad(): + for x, y in loader: + x, y = x.to(device), y.to(device) + correct += (model(x).argmax(1) == y).sum().item() + total += y.size(0) + return correct / total + +def train_model(model, loader, device, epochs=10): + model = model.to(device) + model.train() + opt = torch.optim.Adam(model.parameters(), lr=0.001) + for epoch in range(epochs): + for x, y in loader: + x, y = x.to(device), y.to(device) + opt.zero_grad() + loss = nn.CrossEntropyLoss()(model(x), y) + loss.backward() + opt.step() + if (epoch+1) % 5 == 0: + print(f" Epoch {epoch+1}/{epochs}") + return model + +def prune_layer(model, layer_name, layer, indices): + with torch.no_grad(): + layer.weight.data[indices] = 0 + if layer.bias is not None: + layer.bias.data[indices] = 0 + # Zero BatchNorm + for name, m in model.named_modules(): + if isinstance(m, nn.BatchNorm2d): + if layer_name.replace('conv','bn') in name or layer_name.replace('.conv','.bn') in name: + with torch.no_grad(): + m.weight.data[indices] = 0 + m.bias.data[indices] = 0 + m.running_mean.data[indices] = 0 + m.running_var.data[indices] = 1 + break + +def main(): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Device: {device}") + + train_loader, test_loader = load_cifar10() + + # Load and train + print("\nTraining ResNet18 on CIFAR-10...") + model = torchvision.models.resnet18(weights='IMAGENET1K_V1') + model.fc = nn.Linear(model.fc.in_features, 10) + model = train_model(model, train_loader, device, epochs=15) + baseline = evaluate(model, test_loader, device) + print(f"\nBaseline accuracy: {baseline:.2%}") + + # Get conv layers + convs = [(n,m) for n,m in model.named_modules() if isinstance(m, nn.Conv2d) and m.weight.shape[0]>1] + print(f"\nTesting pruning on {len(convs)} conv layers...") + + # Test: accuracy vs sparsity + print("\nAccuracy vs Sparsity (random pruning, all layers):") + for ratio in [0.1, 0.3, 0.5, 0.7, 0.8, 0.9]: + m = copy.deepcopy(model) + for name, layer in convs: + l = 
dict(m.named_modules())[name] + n_ch = layer.weight.shape[0] + n_prune = min(int(n_ch * ratio), n_ch - 1) + idx = np.random.choice(n_ch, n_prune, replace=False).tolist() + prune_layer(m, name, l, idx) + acc = evaluate(m, test_loader, device) + print(f" {ratio:.0%}: {acc:.2%} (drop: {baseline-acc:+.2%})") + + print("\nIf accuracy drops with higher sparsity, pruning is working!") + print("If random matches magnitude-based, model is over-parameterized.") + +if __name__ == "__main__": + main() diff --git a/slurm_jobs/prune_llm/run_all_paper.sh b/slurm_jobs/prune_llm/run_all_paper.sh index e238f884..36edaa53 100755 --- a/slurm_jobs/prune_llm/run_all_paper.sh +++ b/slurm_jobs/prune_llm/run_all_paper.sh @@ -5,33 +5,46 @@ # This script submits all 4 paper experiments as separate SLURM jobs # They will run in parallel if resources are available # +# Output Directory Structure: +# All results go to: /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/ +# Each job creates a unique directory: {model}_paper_results_{timestamp}_{job_id}/ +# results/ - JSON results files +# logs/ - experiment.log +# figures/ - All visualizations +# checkpoints/ - Model checkpoints +# analysis/ - Post-analysis outputs +# # Usage: # cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment -# bash slurm_jobs/paper/run_all_paper.sh +# bash slurm_jobs/prune_llm/run_all_paper.sh # ============================================================================ +OUTPUT_BASE="/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM" + echo "==============================================" echo "Submitting SCAR Paper Experiments" echo "==============================================" echo "" +echo "Output directory: $OUTPUT_BASE" +echo "" cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment # Submit all jobs echo "Submitting LLaMA-3.1-8B (main results)..." 
-JOB1=$(sbatch slurm_jobs/paper/run_llama3_8b.sh | awk '{print $4}') +JOB1=$(sbatch slurm_jobs/prune_llm/run_llama3_8b.sh | awk '{print $4}') echo " Job ID: $JOB1" echo "Submitting Mistral-7B (generalization)..." -JOB2=$(sbatch slurm_jobs/paper/run_mistral_7b.sh | awk '{print $4}') +JOB2=$(sbatch slurm_jobs/prune_llm/run_mistral_7b.sh | awk '{print $4}') echo " Job ID: $JOB2" echo "Submitting LLaMA-2-7B (generalization)..." -JOB3=$(sbatch slurm_jobs/paper/run_llama2_7b.sh | awk '{print $4}') +JOB3=$(sbatch slurm_jobs/prune_llm/run_llama2_7b.sh | awk '{print $4}') echo " Job ID: $JOB3" echo "Submitting Qwen2-7B (generalization)..." -JOB4=$(sbatch slurm_jobs/paper/run_qwen2_7b.sh | awk '{print $4}') +JOB4=$(sbatch slurm_jobs/prune_llm/run_qwen2_7b.sh | awk '{print $4}') echo " Job ID: $JOB4" echo "" @@ -44,11 +57,16 @@ echo "" echo "Monitor with:" echo " squeue -u \$USER" echo "" -echo "View logs:" +echo "View SLURM logs:" echo " tail -f logs/paper_llama3_8b_${JOB1}.out" echo " tail -f logs/paper_mistral_7b_${JOB2}.out" echo " tail -f logs/paper_llama2_7b_${JOB3}.out" echo " tail -f logs/paper_qwen2_7b_${JOB4}.out" echo "" echo "Expected runtime: ~6-8 hours per job" -echo "Results will be in: results/paper//" +echo "" +echo "Results will be in:" +echo " $OUTPUT_BASE/llama3_8b_paper_results_*_${JOB1}/" +echo " $OUTPUT_BASE/mistral_7b_paper_results_*_${JOB2}/" +echo " $OUTPUT_BASE/llama2_7b_paper_results_*_${JOB3}/" +echo " $OUTPUT_BASE/qwen2_7b_paper_results_*_${JOB4}/" diff --git a/slurm_jobs/prune_llm/run_llama2_7b.sh b/slurm_jobs/prune_llm/run_llama2_7b.sh index 9f335693..5e28a8f1 100755 --- a/slurm_jobs/prune_llm/run_llama2_7b.sh +++ b/slurm_jobs/prune_llm/run_llama2_7b.sh @@ -16,6 +16,15 @@ # ============================================================================ # Cross-model generalization experiment # Expected runtime: ~4-6 hours on H100 +# +# Output Directory Structure: +# /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/ +# 
llama2_7b_paper_results_{timestamp}_{SLURM_JOB_ID}/ +# results/ - JSON results files +# logs/ - experiment.log +# figures/ - All visualizations +# checkpoints/ - Model checkpoints +# analysis/ - Post-analysis outputs # ============================================================================ echo "============================================================================" @@ -25,6 +34,7 @@ echo "Job ID: $SLURM_JOB_ID" echo "Node: $(hostname)" echo "Start time: $(date)" echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" +echo "Output Base: /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM" echo "" # Environment setup @@ -35,8 +45,8 @@ conda activate networkAlignmentAnalysis cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment +# Create local logs directory for SLURM output files mkdir -p logs -mkdir -p results/paper/llama2_7b export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK export TOKENIZERS_PARALLELISM=false @@ -48,8 +58,10 @@ echo "" echo "Running LLaMA-2-7B full paper analysis..." 
echo "" +# The config has base_dir set, so outputs go to: +# /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/{name}_{timestamp}_{job_id}/ python scripts/run_experiment.py \ - --config configs/paper/llama2_7b_full.yaml \ + --config configs/prune_llm/llama2_7b_unified.yaml \ --device cuda echo "" @@ -57,4 +69,5 @@ echo "========================================================================== echo "LLaMA-2-7B completed at $(date)" echo "============================================================================" echo "" -echo "Results saved to: results/paper/llama2_7b/" +echo "Results saved to: /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/" +echo "Look for directory: llama2_7b_paper_results_*_$SLURM_JOB_ID" diff --git a/slurm_jobs/prune_llm/run_llama3_8b.sh b/slurm_jobs/prune_llm/run_llama3_8b.sh index 35f9bd17..a4a83c96 100755 --- a/slurm_jobs/prune_llm/run_llama3_8b.sh +++ b/slurm_jobs/prune_llm/run_llama3_8b.sh @@ -23,6 +23,15 @@ # - Full benchmark evaluation # # Expected runtime: ~6-8 hours on H100 +# +# Output Directory Structure: +# /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/ +# llama3_8b_paper_results_{timestamp}_{SLURM_JOB_ID}/ +# results/ - JSON results files +# logs/ - experiment.log +# figures/ - All visualizations +# checkpoints/ - Model checkpoints +# analysis/ - Post-analysis outputs # ============================================================================ echo "============================================================================" @@ -32,6 +41,7 @@ echo "Job ID: $SLURM_JOB_ID" echo "Node: $(hostname)" echo "Start time: $(date)" echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" +echo "Output Base: /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM" echo "" # Environment setup @@ -42,8 +52,8 @@ conda activate networkAlignmentAnalysis cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment +# Create local logs directory for SLURM output files mkdir -p logs -mkdir 
-p results/paper/llama3_8b export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK export TOKENIZERS_PARALLELISM=false @@ -55,9 +65,8 @@ echo "" echo "Running LLaMA-3.1-8B full paper analysis..." echo "" -# python scripts/run_experiment.py \ -# --config configs/paper/llama3_8b_full.yaml \ -# --device cuda +# The config has base_dir set, so outputs go to: +# /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/{name}_{timestamp}_{job_id}/ python scripts/run_experiment.py \ --config configs/prune_llm/llama3_8b_unified.yaml \ --device cuda @@ -67,4 +76,5 @@ echo "========================================================================== echo "LLaMA-3.1-8B completed at $(date)" echo "============================================================================" echo "" -echo "Results saved to: results/paper/llama3_8b/" +echo "Results saved to: /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/" +echo "Look for directory: llama3_8b_paper_results_*_$SLURM_JOB_ID" diff --git a/slurm_jobs/prune_llm/run_mistral_7b.sh b/slurm_jobs/prune_llm/run_mistral_7b.sh index 70429016..91efd866 100755 --- a/slurm_jobs/prune_llm/run_mistral_7b.sh +++ b/slurm_jobs/prune_llm/run_mistral_7b.sh @@ -16,6 +16,15 @@ # ============================================================================ # Cross-model generalization experiment # Expected runtime: ~4-6 hours on H100 +# +# Output Directory Structure: +# /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/ +# mistral_7b_paper_results_{timestamp}_{SLURM_JOB_ID}/ +# results/ - JSON results files +# logs/ - experiment.log +# figures/ - All visualizations +# checkpoints/ - Model checkpoints +# analysis/ - Post-analysis outputs # ============================================================================ echo "============================================================================" @@ -25,6 +34,7 @@ echo "Job ID: $SLURM_JOB_ID" echo "Node: $(hostname)" echo "Start time: $(date)" echo "GPU: $(nvidia-smi --query-gpu=name 
--format=csv,noheader | head -1)" +echo "Output Base: /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM" echo "" # Environment setup @@ -35,8 +45,8 @@ conda activate networkAlignmentAnalysis cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment +# Create local logs directory for SLURM output files mkdir -p logs -mkdir -p results/paper/mistral_7b export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK export TOKENIZERS_PARALLELISM=false @@ -48,8 +58,10 @@ echo "" echo "Running Mistral-7B full paper analysis..." echo "" +# The config has base_dir set, so outputs go to: +# /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/{name}_{timestamp}_{job_id}/ python scripts/run_experiment.py \ - --config configs/paper/mistral_7b_full.yaml \ + --config configs/prune_llm/mistral_7b_unified.yaml \ --device cuda echo "" @@ -57,4 +69,5 @@ echo "========================================================================== echo "Mistral-7B completed at $(date)" echo "============================================================================" echo "" -echo "Results saved to: results/paper/mistral_7b/" +echo "Results saved to: /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/" +echo "Look for directory: mistral_7b_paper_results_*_$SLURM_JOB_ID" diff --git a/slurm_jobs/prune_llm/run_qwen2_7b.sh b/slurm_jobs/prune_llm/run_qwen2_7b.sh index 18bd9116..31780718 100755 --- a/slurm_jobs/prune_llm/run_qwen2_7b.sh +++ b/slurm_jobs/prune_llm/run_qwen2_7b.sh @@ -17,6 +17,15 @@ # Cross-model generalization experiment # Qwen2 has different FFN architecture (28 layers, larger intermediate) # Expected runtime: ~4-6 hours on H100 +# +# Output Directory Structure: +# /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/ +# qwen2_7b_paper_results_{timestamp}_{SLURM_JOB_ID}/ +# results/ - JSON results files +# logs/ - experiment.log +# figures/ - All visualizations +# checkpoints/ - Model checkpoints +# analysis/ - Post-analysis outputs # 
============================================================================ echo "============================================================================" @@ -26,6 +35,7 @@ echo "Job ID: $SLURM_JOB_ID" echo "Node: $(hostname)" echo "Start time: $(date)" echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" +echo "Output Base: /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM" echo "" # Environment setup @@ -36,8 +46,8 @@ conda activate networkAlignmentAnalysis cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment +# Create local logs directory for SLURM output files mkdir -p logs -mkdir -p results/paper/qwen2_7b export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK export TOKENIZERS_PARALLELISM=false @@ -49,8 +59,10 @@ echo "" echo "Running Qwen2-7B full paper analysis..." echo "" +# The config has base_dir set, so outputs go to: +# /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/{name}_{timestamp}_{job_id}/ python scripts/run_experiment.py \ - --config configs/paper/qwen2_7b_full.yaml \ + --config configs/prune_llm/qwen2_7b_unified.yaml \ --device cuda echo "" @@ -58,4 +70,5 @@ echo "========================================================================== echo "Qwen2-7B completed at $(date)" echo "============================================================================" echo "" -echo "Results saved to: results/paper/qwen2_7b/" +echo "Results saved to: /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/" +echo "Look for directory: qwen2_7b_paper_results_*_$SLURM_JOB_ID" diff --git a/slurm_jobs/prune_vision/run_cluster_analysis_resnet18.sh b/slurm_jobs/prune_vision/run_cluster_analysis_resnet18.sh index 4afda69c..f3cde664 100644 --- a/slurm_jobs/prune_vision/run_cluster_analysis_resnet18.sh +++ b/slurm_jobs/prune_vision/run_cluster_analysis_resnet18.sh @@ -19,11 +19,22 @@ # - K-means clustering into functional types # - Cross-layer halo analysis # - Cascade damage testing -# - Visualization generation +# - 
Pruning experiments (without fine-tuning to see raw impact) +# - Organized visualization output +# +# Figure Organization: +# figures/01_distributions/ - Per-layer metric histograms +# figures/02_summary/ - Layer-wise violin plots, trends +# figures/03_clustering/ - Cluster scatter plots, evolution +# figures/04_cascade/ - Cascade damage test results +# figures/05_halo/ - Halo analysis plots +# figures/06_pruning/ - Pruning comparison charts # # Expected runtime: ~1-2 hours on single GPU # ============================================================================ +OUTPUT_BASE="/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM" + echo "============================================================================" echo "Cluster-Based Analysis: ResNet-18 on CIFAR-10" echo "============================================================================" @@ -31,6 +42,7 @@ echo "Job ID: $SLURM_JOB_ID" echo "Node: $(hostname)" echo "Start time: $(date)" echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader 2>/dev/null || echo 'N/A')" +echo "Output Base: $OUTPUT_BASE" echo "" # Environment setup @@ -41,14 +53,14 @@ conda activate networkAlignmentAnalysis cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment -# Create directories +# Create local logs directory for SLURM output files mkdir -p logs -mkdir -p results/cluster_analysis/resnet18_cifar10 export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK echo "" echo "Running ResNet-18 cluster analysis..." 
+echo "Fine-tuning after pruning: DISABLED (seeing raw pruning impact)" echo "" python scripts/run_experiment.py \ @@ -63,7 +75,7 @@ echo "ResNet-18 cluster analysis completed at $(date)" echo "Exit code: $EXIT_CODE" echo "============================================================================" echo "" -echo "Results saved to: results/cluster_analysis/resnet18_cifar10/" +echo "Results saved to: $OUTPUT_BASE/" +echo "Look for directory starting with: resnet18_cifar10_cluster_analysis_" exit $EXIT_CODE - diff --git a/src/alignment/analysis/visualization/__init__.py b/src/alignment/analysis/visualization/__init__.py index 3136f996..4ccb65d3 100644 --- a/src/alignment/analysis/visualization/__init__.py +++ b/src/alignment/analysis/visualization/__init__.py @@ -87,11 +87,20 @@ plot_pruning_comparison, plot_metric_distributions_for_layer, plot_layer_metric_summary, + plot_layer_metric_trends, + plot_metric_statistics_table, plot_centroid_evolution, plot_centroid_depth_profiles, CLUSTER_COLORS, ) +# New bar chart functions for pruning +from .pruning_plots import ( + plot_pruning_bar_comparison, + plot_pruning_heatmap, + plot_pruning_ranking, +) + __all__ = [ # Primary "UnifiedVisualizer", @@ -128,7 +137,13 @@ "plot_pruning_comparison", "plot_metric_distributions_for_layer", "plot_layer_metric_summary", + "plot_layer_metric_trends", + "plot_metric_statistics_table", "plot_centroid_evolution", "plot_centroid_depth_profiles", "CLUSTER_COLORS", + # New bar chart functions + "plot_pruning_bar_comparison", + "plot_pruning_heatmap", + "plot_pruning_ranking", ] diff --git a/src/alignment/analysis/visualization/cluster_plots.py b/src/alignment/analysis/visualization/cluster_plots.py index b76950f4..f32e3859 100644 --- a/src/alignment/analysis/visualization/cluster_plots.py +++ b/src/alignment/analysis/visualization/cluster_plots.py @@ -702,3 +702,215 @@ def plot_vision_pruning_comparison( ) -> Optional["plt.Figure"]: """Alias for plot_pruning_comparison for backward 
compatibility.""" return plot_pruning_comparison(results, baseline_acc, save_path, figsize) + + +def plot_layer_metric_trends( + layer_metrics: Dict[str, Dict[str, np.ndarray]], + metrics_to_plot: List[str] = ['rq', 'redundancy', 'synergy'], + smooth_window: int = 3, + show_ci: bool = True, + ci_percentile: float = 95, + save_path: Optional[Path] = None, + figsize: Tuple[int, int] = (14, 8), +) -> Optional["plt.Figure"]: + """ + Plot layer-wise metric trends with smoothing and confidence intervals. + + This produces cleaner, more interpretable plots than raw layer values. + + Args: + layer_metrics: Dict mapping layer_name -> {metric_name -> values} + metrics_to_plot: Which metrics to include + smooth_window: Window size for moving average smoothing (1 = no smoothing) + show_ci: Whether to show confidence intervals + ci_percentile: Percentile for confidence interval + save_path: Optional path to save figure + figsize: Figure size + + Returns: + Matplotlib Figure or None + """ + if not HAS_MPL or not layer_metrics: + return None + + layer_names = list(layer_metrics.keys()) + n_layers = len(layer_names) + + if n_layers == 0: + return None + + fig, axes = plt.subplots(len(metrics_to_plot), 1, figsize=figsize, sharex=True) + if len(metrics_to_plot) == 1: + axes = [axes] + + x = np.arange(n_layers) + + for ax, metric_name in zip(axes, metrics_to_plot): + # Extract layer-wise statistics + means = [] + stds = [] + ci_low = [] + ci_high = [] + + for layer in layer_names: + values = layer_metrics[layer].get(metric_name, np.array([])) + if len(values) > 0: + means.append(np.mean(values)) + stds.append(np.std(values)) + ci_low.append(np.percentile(values, (100 - ci_percentile) / 2)) + ci_high.append(np.percentile(values, 100 - (100 - ci_percentile) / 2)) + else: + means.append(0) + stds.append(0) + ci_low.append(0) + ci_high.append(0) + + means = np.array(means) + stds = np.array(stds) + ci_low = np.array(ci_low) + ci_high = np.array(ci_high) + + # Apply smoothing (moving 
average) + if smooth_window > 1 and n_layers >= smooth_window: + kernel = np.ones(smooth_window) / smooth_window + means_smooth = np.convolve(means, kernel, mode='same') + # Fix edges + for i in range(smooth_window // 2): + means_smooth[i] = np.mean(means[:i+smooth_window//2+1]) + means_smooth[-(i+1)] = np.mean(means[-(i+smooth_window//2+1):]) + else: + means_smooth = means + + # Get color for metric + color = METRIC_COLORS.get(metric_name.lower(), '#333333') + + # Plot smoothed line + ax.plot(x, means_smooth, 'o-', color=color, linewidth=2.5, markersize=6, + label=f'{metric_name.upper()} (smoothed)') + + # Show confidence interval + if show_ci: + ax.fill_between(x, ci_low, ci_high, alpha=0.2, color=color, + label=f'{int(ci_percentile)}% CI') + + # Also show raw means as smaller points + ax.scatter(x, means, s=20, alpha=0.5, color=color, zorder=5) + + ax.set_ylabel(f'{metric_name.upper()}', fontsize=11) + ax.grid(True, alpha=0.3) + ax.legend(loc='upper right', fontsize=9) + + # Add layer depth indicator + ax.axhline(y=np.mean(means), color='gray', linestyle='--', alpha=0.5) + + # X-axis labels on bottom plot only + axes[-1].set_xlabel('Layer Index', fontsize=11) + + # Simplify x-tick labels if too many layers + if n_layers > 20: + tick_positions = np.linspace(0, n_layers-1, min(10, n_layers), dtype=int) + axes[-1].set_xticks(tick_positions) + axes[-1].set_xticklabels([layer_names[i].split('.')[-1] for i in tick_positions], + rotation=45, ha='right', fontsize=9) + else: + axes[-1].set_xticks(x) + axes[-1].set_xticklabels([n.split('.')[-1] for n in layer_names], + rotation=45, ha='right', fontsize=9) + + fig.suptitle('Layer-wise Metric Trends (with smoothing)', fontsize=14, fontweight='bold') + plt.tight_layout() + + if save_path: + plt.savefig(save_path, dpi=300, bbox_inches='tight') + logger.info(f"Saved layer metric trends to {save_path}") + + return fig + + +def plot_metric_statistics_table( + layer_metrics: Dict[str, Dict[str, np.ndarray]], + save_path: 
Optional[Path] = None, + figsize: Tuple[int, int] = (12, 8), +) -> Optional["plt.Figure"]: + """ + Create a summary table showing global statistics for each metric. + + Args: + layer_metrics: Dict mapping layer_name -> {metric_name -> values} + save_path: Optional path to save figure + figsize: Figure size + + Returns: + Matplotlib Figure or None + """ + if not HAS_MPL or not layer_metrics: + return None + + # Aggregate all values per metric + metric_stats = {} + for layer_name, metrics in layer_metrics.items(): + for metric_name, values in metrics.items(): + if metric_name.startswith('_'): # Skip internal + continue + if metric_name not in metric_stats: + metric_stats[metric_name] = [] + if len(values) > 0: + metric_stats[metric_name].extend(values.tolist() if hasattr(values, 'tolist') else list(values)) + + if not metric_stats: + return None + + # Compute statistics + table_data = [] + for metric_name, all_values in metric_stats.items(): + if len(all_values) == 0: + continue + arr = np.array(all_values) + table_data.append([ + metric_name.upper(), + f'{np.mean(arr):.4f}', + f'{np.std(arr):.4f}', + f'{np.median(arr):.4f}', + f'{np.min(arr):.4f}', + f'{np.max(arr):.4f}', + f'{np.percentile(arr, 5):.4f}', + f'{np.percentile(arr, 95):.4f}', + ]) + + # Create figure with table + fig, ax = plt.subplots(figsize=figsize) + ax.axis('off') + + headers = ['Metric', 'Mean', 'Std', 'Median', 'Min', 'Max', 'P5', 'P95'] + + table = ax.table( + cellText=table_data, + colLabels=headers, + cellLoc='center', + loc='center', + ) + + table.auto_set_font_size(False) + table.set_fontsize(11) + table.scale(1.2, 1.8) + + # Style header row + for i in range(len(headers)): + table[(0, i)].set_facecolor('#4CAF50') + table[(0, i)].set_text_props(weight='bold', color='white') + + # Alternate row colors + for i in range(1, len(table_data) + 1): + for j in range(len(headers)): + if i % 2 == 0: + table[(i, j)].set_facecolor('#f0f0f0') + + ax.set_title('Metric Statistics Summary (All Layers)', 
fontsize=14, fontweight='bold', pad=20) + + plt.tight_layout() + + if save_path: + plt.savefig(save_path, dpi=300, bbox_inches='tight') + logger.info(f"Saved metric statistics table to {save_path}") + + return fig diff --git a/src/alignment/analysis/visualization/halo_plots.py b/src/alignment/analysis/visualization/halo_plots.py index 982c0124..04908b1f 100644 --- a/src/alignment/analysis/visualization/halo_plots.py +++ b/src/alignment/analysis/visualization/halo_plots.py @@ -237,7 +237,7 @@ def plot_halo_redundancy_comprehensive( colors.append('#2ecc71') if data_to_plot: - bp = ax.boxplot(data_to_plot, labels=labels, patch_artist=True) + bp = ax.boxplot(data_to_plot, labels=labels, patch_artist=True, showfliers=False) for patch, color in zip(bp['boxes'], colors): patch.set_facecolor(color) patch.set_alpha(0.6) diff --git a/src/alignment/analysis/visualization/pruning_plots.py b/src/alignment/analysis/visualization/pruning_plots.py index a1296854..b4effa3f 100644 --- a/src/alignment/analysis/visualization/pruning_plots.py +++ b/src/alignment/analysis/visualization/pruning_plots.py @@ -743,6 +743,27 @@ def _plot_summary_stats(self, ax, results): # Network slimming "network_slimming": "#8e44ad", "chip": "#16a085", + # Activation-based metrics + "activation_l2_norm": "#e74c3c", # Same as magnitude (they're aliases) + "activation_mean": "#c0392b", + "activation_variance": "#a93226", + # Generalized importance (no outlier assumption) + "generalized_importance": "#27ae60", # Green - distinct from supernode methods + "neighborhood_redundancy": "#2ecc71", + # Supernode/connectivity methods + "supernode_protection_score": "#3498db", # Blue + "supernode_connectivity_score": "#2980b9", + "directed_redundancy": "#1abc9c", + "cross_layer_importance": "#9b59b6", # Purple + "within_layer_importance": "#8e44ad", + # SCAR metrics + "scar_loss_proxy": "#e67e22", # Orange + "scar_activation_power": "#d35400", + "scar_taylor": "#f39c12", + "scar_curvature": "#f1c40f", + # Rayleigh 
quotient + "rayleigh_quotient": "#3498db", + "gaussian_mi_analytic": "#2ecc71", } PRUNING_METHOD_MARKERS = { @@ -1071,3 +1092,271 @@ def plot_pruning_recovery_chart( logger.info(f"Saved pruning recovery chart to {save_path}") return fig + + +# ============================================================================== +# NEW: Bar Charts for Clear Method Comparison +# ============================================================================== + +def plot_pruning_bar_comparison( + results: Dict[str, Dict[float, Dict[str, Any]]], + baseline_value: Optional[float] = None, + target_sparsity: float = 0.5, + metric: str = "accuracy", + show_before_ft: bool = True, + title: Optional[str] = None, + save_path: Optional[Union[str, Path]] = None, + figsize: Tuple[int, int] = (14, 7), + sort_by_performance: bool = True, +) -> "plt.Figure": + """ + Create a clear bar chart comparing pruning methods at a specific sparsity level. + + Args: + results: Dict mapping method -> {sparsity -> {metric: value}} + baseline_value: Baseline (unpruned) value + target_sparsity: Sparsity level to compare at (e.g., 0.5 for 50%) + metric: Which metric to plot + show_before_ft: Whether to also show before-fine-tuning values + title: Plot title + save_path: Path to save figure + figsize: Figure size + sort_by_performance: Whether to sort bars by performance + + Returns: + Matplotlib Figure + """ + methods = [] + after_ft_values = [] + before_ft_values = [] + + for method, method_results in results.items(): + sparsities = [s for s in method_results.keys() if isinstance(s, (int, float))] + if not sparsities: + continue + + closest = min(sparsities, key=lambda x: abs(x - target_sparsity)) + if abs(closest - target_sparsity) > 0.15: + continue + + data = method_results[closest] + if isinstance(data, dict) and 'error' not in data: + after_val = data.get('accuracy_after_ft') or data.get('accuracy') + before_val = data.get('accuracy_before_ft') + + if after_val is not None: + methods.append(method) + 
if metric == 'accuracy' and after_val <= 1.0: + after_val *= 100 + if before_val is not None: + before_val *= 100 + after_ft_values.append(after_val) + before_ft_values.append(before_val) + + if not methods: + logger.warning(f"No valid data found for sparsity {target_sparsity}") + return plt.figure() + + if sort_by_performance: + sorted_data = sorted(zip(methods, after_ft_values, before_ft_values), + key=lambda x: x[1], reverse=True) + methods, after_ft_values, before_ft_values = zip(*sorted_data) + methods, after_ft_values, before_ft_values = list(methods), list(after_ft_values), list(before_ft_values) + + fig, ax = plt.subplots(figsize=figsize) + x = np.arange(len(methods)) + width = 0.35 if show_before_ft and any(v is not None for v in before_ft_values) else 0.6 + + # Colors based on method category + colors = [] + for m in methods: + if 'random' in m.lower(): + colors.append('#95a5a6') + elif 'magnitude' in m.lower() and 'plus' not in m.lower() and 'minus' not in m.lower(): + colors.append('#e74c3c') + elif 'composite' in m.lower() or 'cluster' in m.lower(): + colors.append('#2ecc71') + elif 'rq' in m.lower(): + colors.append('#3498db') + elif 'redundancy' in m.lower(): + colors.append('#9b59b6') + elif 'synergy' in m.lower(): + colors.append('#f39c12') + else: + colors.append('#1abc9c') + + if show_before_ft and any(v is not None for v in before_ft_values): + bars_before = ax.bar(x - width/2, + [v if v is not None else 0 for v in before_ft_values], + width, label='Before Fine-tune', alpha=0.5, + color=colors, edgecolor='black', linewidth=0.5) + bars_after = ax.bar(x + width/2, after_ft_values, width, + label='After Fine-tune', color=colors, + edgecolor='black', linewidth=1) + else: + bars_after = ax.bar(x, after_ft_values, width, color=colors, + edgecolor='black', linewidth=1) + + if baseline_value is not None: + baseline_display = baseline_value * 100 if baseline_value <= 1.0 else baseline_value + ax.axhline(y=baseline_display, color='red', linestyle='--', 
linewidth=2, + label=f'Unpruned: {baseline_display:.1f}%') + + for i, val in enumerate(after_ft_values): + bar = bars_after[i] + height = bar.get_height() + ax.annotate(f'{val:.1f}%', + xy=(bar.get_x() + bar.get_width() / 2, height), + xytext=(0, 3), + textcoords="offset points", + ha='center', va='bottom', fontsize=9, fontweight='bold') + + ax.set_xlabel('Pruning Method', fontsize=12) + ax.set_ylabel(f'{metric.title()} (%)', fontsize=12) + title = title or f'Pruning Method Comparison at {int(target_sparsity*100)}% Sparsity' + ax.set_title(title, fontsize=14, fontweight='bold') + ax.set_xticks(x) + ax.set_xticklabels([m.replace('_', '\n') for m in methods], rotation=45, ha='right', fontsize=10) + ax.legend(loc='upper right', fontsize=10) + ax.grid(True, alpha=0.3, axis='y') + y_min = max(0, min(after_ft_values) - 10) + ax.set_ylim([y_min, 105]) + + plt.tight_layout() + + if save_path: + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(save_path, dpi=300, bbox_inches='tight') + logger.info(f"Saved pruning bar comparison to {save_path}") + + return fig + + +def plot_pruning_heatmap( + results: Dict[str, Dict[float, Dict[str, Any]]], + metric: str = "accuracy", + title: Optional[str] = None, + save_path: Optional[Union[str, Path]] = None, + figsize: Tuple[int, int] = (14, 10), + annotate: bool = True, +) -> "plt.Figure": + """ + Create a heatmap showing accuracy for all method-sparsity combinations. 
+ """ + all_sparsities = set() + for method_results in results.values(): + all_sparsities.update(s for s in method_results.keys() if isinstance(s, (int, float))) + sparsities = sorted(all_sparsities) + + methods = list(results.keys()) + data = np.full((len(methods), len(sparsities)), np.nan) + + for i, method in enumerate(methods): + method_results = results[method] + for j, s in enumerate(sparsities): + if s in method_results: + d = method_results[s] + if isinstance(d, dict) and 'error' not in d: + val = d.get('accuracy_after_ft') or d.get('accuracy') + if val is not None: + data[i, j] = val * 100 if val <= 1.0 else val + + fig, ax = plt.subplots(figsize=figsize) + im = ax.imshow(data, aspect='auto', cmap='RdYlGn', vmin=0, vmax=100) + cbar = plt.colorbar(im, ax=ax) + cbar.set_label(f'{metric.title()} (%)', fontsize=12) + + ax.set_xticks(np.arange(len(sparsities))) + ax.set_yticks(np.arange(len(methods))) + ax.set_xticklabels([f'{int(s*100)}%' for s in sparsities], fontsize=10) + ax.set_yticklabels([m.replace('_', ' ') for m in methods], fontsize=10) + + if annotate: + for i in range(len(methods)): + for j in range(len(sparsities)): + if not np.isnan(data[i, j]): + text_color = 'white' if data[i, j] < 50 else 'black' + ax.text(j, i, f'{data[i, j]:.1f}', ha='center', va='center', + color=text_color, fontsize=9, fontweight='bold') + + ax.set_xlabel('Sparsity Level', fontsize=12) + ax.set_ylabel('Pruning Method', fontsize=12) + title = title or 'Pruning Performance Heatmap' + ax.set_title(title, fontsize=14, fontweight='bold') + plt.tight_layout() + + if save_path: + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(save_path, dpi=300, bbox_inches='tight') + logger.info(f"Saved pruning heatmap to {save_path}") + + return fig + + +def plot_pruning_ranking( + results: Dict[str, Dict[float, Dict[str, Any]]], + metric: str = "accuracy", + title: Optional[str] = None, + save_path: Optional[Union[str, Path]] = None, + figsize: 
Tuple[int, int] = (12, 8), +) -> "plt.Figure": + """ + Create a ranking plot showing methods ordered by average performance. + """ + method_scores = {} + + for method, method_results in results.items(): + values = [] + for s, data in method_results.items(): + if isinstance(s, (int, float)) and isinstance(data, dict): + val = data.get('accuracy_after_ft') or data.get('accuracy') + if val is not None: + values.append(val * 100 if val <= 1.0 else val) + + if values: + method_scores[method] = { + 'mean': np.mean(values), + 'std': np.std(values), + } + + if not method_scores: + return plt.figure() + + sorted_methods = sorted(method_scores.items(), key=lambda x: x[1]['mean'], reverse=True) + methods = [m for m, _ in sorted_methods] + means = [d['mean'] for _, d in sorted_methods] + stds = [d['std'] for _, d in sorted_methods] + + fig, ax = plt.subplots(figsize=figsize) + y_pos = np.arange(len(methods)) + colors = plt.cm.RdYlGn(np.linspace(0.2, 0.8, len(methods)))[::-1] + + bars = ax.barh(y_pos, means, xerr=stds, color=colors, + edgecolor='black', linewidth=0.5, capsize=3) + + for bar, mean, std in zip(bars, means, stds): + width = bar.get_width() + ax.text(width + std + 1, bar.get_y() + bar.get_height()/2, + f'{mean:.1f}±{std:.1f}%', va='center', fontsize=10) + + ax.set_yticks(y_pos) + ax.set_yticklabels([m.replace('_', ' ') for m in methods], fontsize=11) + ax.set_xlabel(f'Average {metric.title()} (%)', fontsize=12) + title = title or 'Pruning Method Ranking' + ax.set_title(title, fontsize=14, fontweight='bold') + ax.grid(True, alpha=0.3, axis='x') + + for i in range(len(sorted_methods)): + ax.text(-2, i, f'#{i+1}', va='center', ha='right', fontsize=11, fontweight='bold') + + plt.tight_layout() + + if save_path: + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(save_path, dpi=300, bbox_inches='tight') + logger.info(f"Saved ranking plot to {save_path}") + + return fig diff --git 
a/src/alignment/analysis/visualization/unified_visualizer.py b/src/alignment/analysis/visualization/unified_visualizer.py index 08718d13..3d378b3f 100644 --- a/src/alignment/analysis/visualization/unified_visualizer.py +++ b/src/alignment/analysis/visualization/unified_visualizer.py @@ -190,7 +190,7 @@ def plot_layer_scores( pc.set_facecolor("lightblue") pc.set_alpha(0.7) elif plot_type == "box": - ax.boxplot(data, positions=positions, labels=layer_names) + ax.boxplot(data, positions=positions, labels=layer_names, showfliers=False) elif plot_type == "bar": means = [np.mean(d) for d in data] stds = [np.std(d) for d in data] @@ -530,7 +530,7 @@ def _to_array(z): data = [supernode_vals, non_supernode_vals] labels = [f"Supernodes\n(n={len(supernode_vals)})", f"Non-supernodes\n(n={len(non_supernode_vals)})"] - bp = ax.boxplot(data, labels=labels, patch_artist=True, notch=True) + bp = ax.boxplot(data, labels=labels, patch_artist=True, notch=True, showfliers=False) bp['boxes'][0].set_facecolor('coral') bp['boxes'][1].set_facecolor('steelblue') ax.set_ylabel(metric_name) @@ -2543,10 +2543,10 @@ def plot_redundancy_comparison( dpi=self.dpi, bbox_inches='tight') figures.append(fig) - # Plot 3: Box plot comparison + # Plot 3: Box plot comparison (without outlier dots for cleaner visualization) fig, ax = plt.subplots(figsize=(8, 6)) bp = ax.boxplot([high_vals, low_vals], labels=['High-connected', 'Low-connected'], - patch_artist=True, notch=True) + patch_artist=True, notch=True, showfliers=False) bp['boxes'][0].set_facecolor('coral') bp['boxes'][1].set_facecolor('steelblue') ax.set_ylabel("Absolute Pairwise Correlation (Redundancy)") diff --git a/src/alignment/configs/config_loader.py b/src/alignment/configs/config_loader.py index 91664a25..cdf724ec 100644 --- a/src/alignment/configs/config_loader.py +++ b/src/alignment/configs/config_loader.py @@ -159,6 +159,15 @@ def _convert_unified_to_original(unified: Dict[str, Any]) -> Dict[str, Any]: enabled_metrics = [] metric_configs 
= {} + # Extract optimization options (apply to all metrics) + optimization = metrics.get("optimization", {}) + global_optimization_opts = { + "use_jit": optimization.get("use_jit", False), + "use_gpu_acceleration": optimization.get("use_gpu_acceleration", False), + "force_cpu_for_large_ops": optimization.get("force_cpu_for_large_ops", True), + "cpu_threshold": optimization.get("cpu_threshold", 100000000), + } + # Check each unified metric for unified_name, original_name in METRIC_UNIFIED_TO_ORIGINAL.items(): if unified_name in metrics: @@ -166,12 +175,16 @@ def _convert_unified_to_original(unified: Dict[str, Any]) -> Dict[str, Any]: if isinstance(metric_cfg, dict): if metric_cfg.get("enabled", True): enabled_metrics.append(original_name) - # Copy metric-specific params + # Copy metric-specific params + optimization options params = {k: v for k, v in metric_cfg.items() if k != "enabled"} + # Apply global optimization options + params.update(global_optimization_opts) if params: metric_configs[original_name] = params elif metric_cfg is True: enabled_metrics.append(original_name) + # Apply optimization options even for simple enabled metrics + metric_configs[original_name] = global_optimization_opts.copy() # Handle SCAR metrics (LLM-specific) if "scar" in metrics: @@ -342,6 +355,9 @@ def _convert_unified_to_original(unified: Dict[str, Any]) -> Dict[str, Any]: # Also set experiment.output_dir for compatibility if "experiment" in original: original["experiment"]["output_dir"] = out["dir"] + # Handle base_output_dir for job directory structure + if "base_dir" in out: + original["base_output_dir"] = out["base_dir"] # ------------------------------------------------------------------------- # EXTRA - Expand LLM-specific settings from extra block to top-level @@ -668,12 +684,24 @@ def _map_nested_to_flat_config(nested_config: Dict[str, Any]) -> Dict[str, Any]: if enabled_metrics is not None: flat_config["metrics"] = enabled_metrics + # Extract optimization options (apply to 
all metrics) + optimization = metric_block.get("optimization", {}) + global_optimization_opts = { + "use_jit": optimization.get("use_jit", False), + "use_gpu_acceleration": optimization.get("use_gpu_acceleration", False), + "force_cpu_for_large_ops": optimization.get("force_cpu_for_large_ops", True), + "cpu_threshold": optimization.get("cpu_threshold", 100000000), + } + flat_config["metric_optimization"] = global_optimization_opts + metric_configs = flat_config.get("metric_configs", {}).copy() for metric_name, metric_cfg in metric_block.items(): - if metric_name == "enabled" or metric_cfg is None: + if metric_name in ("enabled", "optimization") or metric_cfg is None: continue if isinstance(metric_cfg, dict): - metric_configs[metric_name] = metric_cfg + # Merge optimization options into each metric config + merged_cfg = {**global_optimization_opts, **metric_cfg} + metric_configs[metric_name] = merged_cfg if metric_configs: flat_config["metric_configs"] = metric_configs @@ -939,6 +967,15 @@ def _map_nested_to_flat_config(nested_config: Dict[str, Any]) -> Dict[str, Any]: flat_config["pruning_amounts"] = pruning_block.get("sparsity_levels", nested_config.get("pruning_amounts", [0.1, 0.3, 0.5, 0.7, 0.9])) selection_modes = pruning_block.get("selection_modes", nested_config.get("pruning_selection_mode", "low")) flat_config["pruning_selection_mode"] = selection_modes + flat_config["pruning_distribution"] = pruning_block.get( + "distribution", nested_config.get("pruning_distribution", "uniform") + ) + flat_config["pruning_min_per_layer"] = pruning_block.get( + "min_per_layer", nested_config.get("pruning_min_per_layer", 0.0) + ) + flat_config["pruning_max_per_layer"] = pruning_block.get( + "max_per_layer", nested_config.get("pruning_max_per_layer", 0.95) + ) # Only set fine_tune defaults if not already set from fine_tune block above if "fine_tune_after_pruning" not in flat_config: flat_config["fine_tune_after_pruning"] = pruning_block.get("fine_tune_after_pruning", 
nested_config.get("fine_tune_after_pruning", True)) @@ -1067,6 +1104,16 @@ def _map_nested_to_flat_config(nested_config: Dict[str, Any]) -> Dict[str, Any]: # Map paths flat_config["log_dir"] = nested_config.get("results_path", "./logs") flat_config["checkpoint_dir"] = os.path.join(flat_config["log_dir"], "checkpoints") + + # Handle base_output_dir for job directory structure + # Priority: output.base_dir > experiment.base_output_dir > top-level base_output_dir + output_block = nested_config.get("output", {}) + if isinstance(output_block, dict) and "base_dir" in output_block: + flat_config["base_output_dir"] = output_block["base_dir"] + elif "base_output_dir" in experiment_block: + flat_config["base_output_dir"] = experiment_block["base_output_dir"] + elif "base_output_dir" in nested_config: + flat_config["base_output_dir"] = nested_config["base_output_dir"] return flat_config diff --git a/src/alignment/core/base.py b/src/alignment/core/base.py index 5d246d5f..c89f6183 100644 --- a/src/alignment/core/base.py +++ b/src/alignment/core/base.py @@ -20,7 +20,19 @@ class BaseMetric(ABC): - """Base class for all alignment metrics.""" + """Base class for all alignment metrics. 
+ + Common Configuration Options: + name: Optional custom name for the metric + force_cpu_for_large_ops: bool = True - Move to CPU for large tensors + cpu_threshold: int = 1e8 - Element count threshold for CPU fallback + use_jit: bool = False - Use JIT-compiled implementations when available + use_gpu_acceleration: bool = False - Use GPU-accelerated implementations + + Note: JIT and GPU acceleration provide 20-50% speedup but require: + - use_jit: PyTorch JIT compilation support + - use_gpu_acceleration: CUDA-enabled GPU + """ def __init__(self, name: Optional[str] = None, **config: Any): """ @@ -34,6 +46,16 @@ def __init__(self, name: Optional[str] = None, **config: Any): self.config = config self._force_cpu_for_large_ops = config.get("force_cpu_for_large_ops", True) self._cpu_threshold = config.get("cpu_threshold", 1e8) # 100M elements + + # Optimization options + self._use_jit = config.get("use_jit", False) + self._use_gpu_acceleration = config.get("use_gpu_acceleration", False) + + # Initialize JIT/GPU accelerated functions if requested + self._jit_compute = None + self._gpu_functions = None + if self._use_jit or self._use_gpu_acceleration: + self._setup_optimizations() @property def name(self) -> str: @@ -117,8 +139,84 @@ def _should_use_cpu(self, *tensors: torch.Tensor) -> bool: total_elements = sum(t.numel() for t in tensors if t is not None) return total_elements > self._cpu_threshold + def _setup_optimizations(self) -> None: + """ + Setup JIT and GPU optimized functions if available. + + Override this in subclasses to enable metric-specific optimizations. 
+ """ + # Import optimization modules lazily + if self._use_jit: + try: + from ..infrastructure.computing.optimized.jit import ( + compute_rayleigh_quotient_jit, + compute_mutual_information_gaussian_jit, + compute_node_correlation_jit, + compute_cosine_similarity_matrix_jit, + ) + self._jit_functions = { + "rayleigh_quotient": compute_rayleigh_quotient_jit, + "mutual_information": compute_mutual_information_gaussian_jit, + "node_correlation": compute_node_correlation_jit, + "cosine_similarity": compute_cosine_similarity_matrix_jit, + } + logger.debug(f"{self._name}: JIT functions loaded") + except ImportError as e: + logger.warning(f"Could not load JIT functions: {e}") + self._jit_functions = {} + else: + self._jit_functions = {} + + if self._use_gpu_acceleration: + try: + from ..infrastructure.computing.optimized.gpu import ( + gpu_histogram1d, + gpu_histogram2d, + gpu_mutual_information, + gpu_entropy, + GPUAcceleratedMetrics, + ) + self._gpu_functions = { + "histogram1d": gpu_histogram1d, + "histogram2d": gpu_histogram2d, + "mutual_information": gpu_mutual_information, + "entropy": gpu_entropy, + "fast_covariance": GPUAcceleratedMetrics.fast_covariance, + "fast_correlation": GPUAcceleratedMetrics.fast_correlation, + } + logger.debug(f"{self._name}: GPU functions loaded") + except ImportError as e: + logger.warning(f"Could not load GPU functions: {e}") + self._gpu_functions = {} + else: + self._gpu_functions = {} + + @property + def use_jit(self) -> bool: + """Whether JIT-compiled functions are enabled.""" + return self._use_jit and bool(getattr(self, "_jit_functions", {})) + + @property + def use_gpu_acceleration(self) -> bool: + """Whether GPU-accelerated functions are enabled.""" + return self._use_gpu_acceleration and bool(getattr(self, "_gpu_functions", {})) + + def _get_jit_function(self, name: str): + """Get a JIT-compiled function by name, or None if not available.""" + return getattr(self, "_jit_functions", {}).get(name) + + def _get_gpu_function(self, 
name: str): + """Get a GPU-accelerated function by name, or None if not available.""" + return getattr(self, "_gpu_functions", {}).get(name) + def __repr__(self) -> str: - return f"{self.__class__.__name__}(name='{self.name}')" + opt_flags = [] + if self._use_jit: + opt_flags.append("jit") + if self._use_gpu_acceleration: + opt_flags.append("gpu") + opt_str = f", optimizations=[{', '.join(opt_flags)}]" if opt_flags else "" + return f"{self.__class__.__name__}(name='{self.name}'{opt_str})" class BaseModel(ABC): diff --git a/src/alignment/data/datasets/__init__.py b/src/alignment/data/datasets/__init__.py index fb119881..a536917f 100644 --- a/src/alignment/data/datasets/__init__.py +++ b/src/alignment/data/datasets/__init__.py @@ -1,23 +1,31 @@ """ Dataset implementations for alignment analysis. -This module provides a unified dataset interface that can handle -various dataset types without code duplication. +NOTE: This module re-exports from alignment.dataops.datasets for backward compatibility. +The canonical location is alignment.dataops.datasets. 
""" -from typing import Optional, Tuple - -import torch.utils.data - -# Import for backward compatibility - these are now created dynamically -# but we import them to make them available at module level -from alignment.core.registry import DATASET_REGISTRY -from alignment.data.datasets.unified_dataset import DATASET_CONFIGS, UnifiedDataset - -# Try to import text datasets (optional - requires additional dependencies) +# Re-export everything from the canonical location +from alignment.dataops.datasets import ( + DATASET_CONFIGS, + UnifiedDataset, + get_dataset, + MNISTDataset, + FashionMNISTDataset, + CIFAR10Dataset, + CIFAR100Dataset, + ImageNetDataset, + SVHNDataset, +) + +# Try to import text datasets (optional) try: - from alignment.data.datasets.text_datasets import C4Dataset, TextDataset, WikiTextDataset, load_text_dataset - + from alignment.dataops.datasets import ( + TextDataset, + WikiTextDataset, + C4Dataset, + load_text_dataset, + ) HAS_TEXT_DATASETS = True except ImportError: HAS_TEXT_DATASETS = False @@ -26,47 +34,6 @@ C4Dataset = None load_text_dataset = None -# Get dynamically created dataset classes -MNISTDataset = DATASET_REGISTRY.get("mnist") -FashionMNISTDataset = DATASET_REGISTRY.get("fashion_mnist") -CIFAR10Dataset = DATASET_REGISTRY.get("cifar10") -CIFAR100Dataset = DATASET_REGISTRY.get("cifar100") -ImageNetDataset = DATASET_REGISTRY.get("imagenet") -SVHNDataset = DATASET_REGISTRY.get("svhn") - - -def get_dataset( - dataset_name: str, batch_size: int = 128, num_workers: int = 4, **kwargs -) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]: - """ - Get train and validation data loaders for a dataset. - - Args: - dataset_name: Name of the dataset (mnist, cifar10, etc.) 
- batch_size: Batch size for data loaders - num_workers: Number of worker processes - **kwargs: Additional arguments for dataset - - Returns: - Tuple of (train_loader, val_loader) - """ - # Get dataset class from registry - dataset_class = DATASET_REGISTRY.get(dataset_name) - if dataset_class is None: - available = list(DATASET_REGISTRY._registry.keys()) - raise ValueError(f"Unknown dataset: {dataset_name}. Available: {available}") - - # Create train and val datasets - train_dataset = dataset_class(train=True, **kwargs) - val_dataset = dataset_class(train=False, **kwargs) - - # Create data loaders - train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True) - - val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True) - - return train_loader, val_loader - __all__ = [ "UnifiedDataset", @@ -82,11 +49,9 @@ def get_dataset( # Add text datasets to exports if available if HAS_TEXT_DATASETS: - __all__.extend( - [ - "TextDataset", - "WikiTextDataset", - "C4Dataset", - "load_text_dataset", - ] - ) + __all__.extend([ + "TextDataset", + "WikiTextDataset", + "C4Dataset", + "load_text_dataset", + ]) diff --git a/src/alignment/data/datasets/unified_dataset.py b/src/alignment/data/datasets/unified_dataset.py deleted file mode 100644 index 3329816d..00000000 --- a/src/alignment/data/datasets/unified_dataset.py +++ /dev/null @@ -1,345 +0,0 @@ -""" -Unified dataset wrapper for alignment analysis. - -This module provides a single, flexible dataset class that can handle -various dataset types without code duplication. 
-""" - -import logging -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import torch -from torch.utils.data import Dataset -from torchvision import datasets, transforms - -from alignment.core.registry import register_dataset -from alignment.dataops.base import BaseDataset - -logger = logging.getLogger(__name__) - - -# Dataset configurations -DATASET_CONFIGS = { - "mnist": { - "dataset_class": datasets.MNIST, - "mean": 0.1307, - "std": 0.3081, - "num_classes": 10, - "input_shape": (1, 28, 28), - "class_names": [str(i) for i in range(10)], - "augmentation": {"rotation": 10, "translate": (0.1, 0.1), "scale": (0.9, 1.1)}, - }, - "fashion_mnist": { - "dataset_class": datasets.FashionMNIST, - "mean": 0.2860, - "std": 0.3530, - "num_classes": 10, - "input_shape": (1, 28, 28), - "class_names": ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"], - "augmentation": {"rotation": 10, "translate": (0.1, 0.1), "scale": (0.9, 1.1)}, - }, - "cifar10": { - "dataset_class": datasets.CIFAR10, - "mean": [0.4914, 0.4822, 0.4465], - "std": [0.2470, 0.2435, 0.2616], - "num_classes": 10, - "input_shape": (3, 32, 32), - "class_names": ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"], - "augmentation": { - "crop": 32, - "padding": 4, - "horizontal_flip": True, - "color_jitter": {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.1}, - }, - }, - "cifar100": { - "dataset_class": datasets.CIFAR100, - "mean": [0.5071, 0.4865, 0.4409], - "std": [0.2673, 0.2564, 0.2762], - "num_classes": 100, - "input_shape": (3, 32, 32), - "augmentation": { - "crop": 32, - "padding": 4, - "horizontal_flip": True, - "color_jitter": {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.1}, - "rotation": 15, - }, - }, - "imagenet": { - "dataset_class": datasets.ImageNet, - "mean": [0.485, 0.456, 0.406], - "std": [0.229, 0.224, 0.225], - 
"num_classes": 1000, - "input_shape": (3, 224, 224), - "augmentation": { - "random_resized_crop": 224, - "horizontal_flip": True, - "color_jitter": {"brightness": 0.4, "contrast": 0.4, "saturation": 0.4, "hue": 0.1}, - }, - "val_transforms": {"resize": 256, "center_crop": 224}, - }, - "svhn": { - "dataset_class": datasets.SVHN, - "mean": [0.4377, 0.4438, 0.4728], - "std": [0.1980, 0.2010, 0.1970], - "num_classes": 10, - "input_shape": (3, 32, 32), - "class_names": [str(i) for i in range(10)], - "split_arg": "split", # SVHN uses 'split' instead of 'train' - "augmentation": { - "crop": 32, - "padding": 4, - "horizontal_flip": False, # Numbers shouldn't be flipped - "color_jitter": {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.1}, - }, - }, -} - - -@register_dataset("unified") -class UnifiedDataset(BaseDataset): - """ - Unified dataset wrapper that can handle multiple dataset types. - - This class provides a single interface for various datasets, - eliminating code duplication while maintaining flexibility. - """ - - def __init__( - self, - dataset_type: str, - data_path: Optional[str] = None, - train: bool = True, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - download: bool = True, - normalize: bool = True, - augment: bool = False, - custom_config: Optional[Dict[str, Any]] = None, - **kwargs, - ): - """ - Initialize unified dataset. - - Args: - dataset_type: Type of dataset ('mnist', 'cifar10', 'imagenet', etc.) 
- data_path: Path to store/load data - train: Whether to load training set - transform: Additional transforms - target_transform: Transform for targets - download: Whether to download if not found - normalize: Whether to normalize - augment: Whether to apply augmentation - custom_config: Custom configuration to override defaults - **kwargs: Additional dataset-specific arguments - """ - # Get dataset configuration - if dataset_type not in DATASET_CONFIGS: - raise ValueError(f"Unknown dataset type: {dataset_type}. " f"Available: {list(DATASET_CONFIGS.keys())}") - - self.dataset_config = DATASET_CONFIGS[dataset_type].copy() - - # Apply custom config if provided - if custom_config: - self.dataset_config.update(custom_config) - - # Initialize base class - super().__init__( - name=dataset_type.upper(), - data_path=data_path or "./data", - train=train, - transform=transform, - target_transform=target_transform, - download=download, - normalize=normalize, - augment=augment, - **kwargs, - ) - - self.dataset_type = dataset_type - self._initialize_dataset(**kwargs) - - def _initialize_dataset(self, **kwargs): - """Initialize the underlying dataset.""" - dataset_class = self.dataset_config["dataset_class"] - - # Prepare dataset arguments - dataset_args = { - "root": self._data_path, - "transform": self.get_transform(), - "target_transform": self.target_transform, - "download": self.download, - } - - # Handle dataset-specific arguments - if self.dataset_type == "svhn": - # SVHN uses 'split' instead of 'train' - dataset_args["split"] = "train" if self.train else "test" - elif self.dataset_type == "imagenet": - # ImageNet uses 'split' parameter - dataset_args["split"] = "train" if self.train else "val" - else: - # Most datasets use 'train' parameter - dataset_args["train"] = self.train - - # Add any additional kwargs - dataset_args.update(kwargs) - - # Create dataset - try: - self._dataset = dataset_class(**dataset_args) - except Exception as e: - logger.error(f"Failed to 
initialize {self.dataset_type} dataset: {e}") - raise - - @property - def mean(self) -> Union[float, List[float]]: - """Dataset mean for normalization.""" - return self.dataset_config["mean"] - - @property - def std(self) -> Union[float, List[float]]: - """Dataset standard deviation for normalization.""" - return self.dataset_config["std"] - - @property - def num_classes(self) -> int: - """Number of classes in the dataset.""" - return self.dataset_config["num_classes"] - - @property - def input_shape(self) -> Tuple[int, ...]: - """Shape of a single input sample.""" - return self.dataset_config["input_shape"] - - @property - def class_names(self) -> List[str]: - """List of class names.""" - if "class_names" in self.dataset_config: - return self.dataset_config["class_names"] - else: - # Generate generic class names - return [f"class_{i}" for i in range(self.num_classes)] - - def _get_basic_transforms(self) -> List[Callable]: - """Get basic transforms based on dataset type.""" - transforms_list = [] - - # Add dataset-specific basic transforms - if self.dataset_type == "imagenet" and not self.train: - # ImageNet validation needs resize and center crop - val_config = self.dataset_config.get("val_transforms", {}) - if "resize" in val_config: - transforms_list.append(transforms.Resize(val_config["resize"])) - if "center_crop" in val_config: - transforms_list.append(transforms.CenterCrop(val_config["center_crop"])) - - # Always convert to tensor - transforms_list.append(transforms.ToTensor()) - - return transforms_list - - def _get_augmentation_transforms(self) -> List[Callable]: - """Get augmentation transforms based on dataset configuration.""" - if "augmentation" not in self.dataset_config: - return [] - - aug_config = self.dataset_config["augmentation"] - transforms_list = [] - - # Random crop with padding - if "crop" in aug_config and "padding" in aug_config: - transforms_list.append(transforms.RandomCrop(aug_config["crop"], padding=aug_config["padding"])) - - # 
Random resized crop (for ImageNet) - if "random_resized_crop" in aug_config: - transforms_list.append(transforms.RandomResizedCrop(aug_config["random_resized_crop"])) - - # Random horizontal flip - if aug_config.get("horizontal_flip", False): - transforms_list.append(transforms.RandomHorizontalFlip()) - - # Random rotation - if "rotation" in aug_config: - transforms_list.append(transforms.RandomRotation(aug_config["rotation"])) - - # Random affine - if "translate" in aug_config or "scale" in aug_config: - transforms_list.append(transforms.RandomAffine(degrees=0, translate=aug_config.get("translate"), scale=aug_config.get("scale"))) - - # Color jitter - if "color_jitter" in aug_config: - jitter_params = aug_config["color_jitter"] - transforms_list.append( - transforms.ColorJitter( - brightness=jitter_params.get("brightness", 0), - contrast=jitter_params.get("contrast", 0), - saturation=jitter_params.get("saturation", 0), - hue=jitter_params.get("hue", 0), - ) - ) - - return transforms_list - - def __len__(self) -> int: - """Get dataset length.""" - return len(self._dataset) - - def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Any]: - """Get a sample from the dataset.""" - return self._dataset[idx] - - def get_targets(self) -> torch.Tensor: - """Get all targets as a tensor.""" - if hasattr(self._dataset, "targets"): - return torch.tensor(self._dataset.targets) - elif hasattr(self._dataset, "labels"): - return torch.tensor(self._dataset.labels) - else: - # Fallback: iterate through dataset - targets = [] - for _, target in self._dataset: - targets.append(target) - return torch.tensor(targets) - - def add_dataset_type(self, name: str, dataset_class: type, config: Dict[str, Any]) -> None: - """ - Add a new dataset type to the registry. 
- - Args: - name: Name for the dataset type - dataset_class: Dataset class (e.g., from torchvision) - config: Configuration dictionary with required fields - """ - required_fields = ["mean", "std", "num_classes", "input_shape"] - for field in required_fields: - if field not in config: - raise ValueError(f"Config must include '{field}'") - - config["dataset_class"] = dataset_class - DATASET_CONFIGS[name] = config - logger.info(f"Added new dataset type: {name}") - - -# Register specific dataset types for backward compatibility -def create_dataset_class(dataset_type): - """Create a dataset class for a specific type.""" - - @register_dataset(dataset_type) - class SpecificDataset(UnifiedDataset): - """Dataset wrapper for specific dataset type.""" - - def __init__(self, **kwargs): - super().__init__(dataset_type=dataset_type, **kwargs) - - # Set proper class name - SpecificDataset.__name__ = f"{dataset_type.upper()}Dataset" - SpecificDataset.__qualname__ = f"{dataset_type.upper()}Dataset" - return SpecificDataset - - -# Create and register all dataset types -for dataset_type in DATASET_CONFIGS.keys(): - create_dataset_class(dataset_type) diff --git a/src/alignment/experiments/__init__.py b/src/alignment/experiments/__init__.py index 2eeddc4c..01ab0185 100644 --- a/src/alignment/experiments/__init__.py +++ b/src/alignment/experiments/__init__.py @@ -2,21 +2,10 @@ Experiments module for alignment analysis. This module provides various experiments for analyzing neural network alignment, -including general alignment analysis, multi-network experiments, and utilities. +including general alignment analysis, LLM alignment, and cluster-based analysis. 
""" from .base import BaseExperiment, ExperimentConfig - -# Configuration components -from .config_components import ( - CNNConfig, - EvaluationConfig, - MultiNetworkConfig, - PruningConfig, - TrainingConfig, - create_backward_compatible_config, - create_config_from_dict, -) from .general_alignment import GeneralAlignmentConfig, GeneralAlignmentExperiment from .llm_experiments import LLMAlignmentExperiment from .cluster_experiments import ( @@ -26,9 +15,6 @@ VisionExperimentConfig, # backward compat ) -# Training utilities -from .training_utils import convert_training_history, create_experiment_trainer, evaluate_with_metrics, train_with_metrics - __all__ = [ # Base classes "BaseExperiment", @@ -41,17 +27,4 @@ "ClusterAnalysisConfig", "VisionExperiment", # backward compat alias "VisionExperimentConfig", # backward compat alias - # Configuration components - "TrainingConfig", - "PruningConfig", - "EvaluationConfig", - "CNNConfig", - "MultiNetworkConfig", - "create_config_from_dict", - "create_backward_compatible_config", - # Training utilities - "create_experiment_trainer", - "train_with_metrics", - "evaluate_with_metrics", - "convert_training_history", ] diff --git a/src/alignment/experiments/base.py b/src/alignment/experiments/base.py index 00d6a2f5..de9a6d3e 100644 --- a/src/alignment/experiments/base.py +++ b/src/alignment/experiments/base.py @@ -68,6 +68,8 @@ class ExperimentConfig: # Metrics configuration metrics: List[str] = field(default_factory=lambda: ["rayleigh_quotient"]) metric_configs: Dict[str, Dict[str, Any]] = field(default_factory=dict) + # Optimization options for metric computation (use_jit, use_gpu_acceleration, etc.) 
+ metric_optimization: Dict[str, Any] = field(default_factory=dict) tracked_layers: Optional[List[str]] = None scale_by_norm: bool = False # Whether to scale alignment scores by weight norm force_cpu_for_large_metric_ops: bool = True # Move large operations to CPU @@ -113,6 +115,9 @@ class ExperimentConfig: pruning_alignment_metric: str = "rayleigh_quotient" pruning_hybrid_alpha: float = 0.5 pruning_scope: str = "layer" # "global" or "layer" + pruning_distribution: str = "uniform" + pruning_min_per_layer: float = 0.0 + pruning_max_per_layer: float = 0.95 fine_tune_learning_rate: Optional[float] = None # Will default to learning_rate * 0.1 alignment_structured_pruning: bool = False # Use structured pruning for alignment cascading_direction: str = "forward" # Direction for cascading pruning @@ -136,10 +141,12 @@ class ExperimentConfig: checkpoint_interval: int = 1000 save_best: bool = True - # Logging + # Logging and Output Directories log_dir: str = "./logs" log_interval: int = 100 plots_dir: str = "./plots" # Directory for saving plots + experiment_dir: Optional[str] = None # Root experiment directory (set by runner) + base_output_dir: Optional[str] = None # Base directory for creating job-specific output folders wandb_project: Optional[str] = None wandb_entity: Optional[str] = None diff --git a/src/alignment/experiments/config_components.py b/src/alignment/experiments/config_components.py deleted file mode 100644 index 4a509b08..00000000 --- a/src/alignment/experiments/config_components.py +++ /dev/null @@ -1,292 +0,0 @@ -""" -Composable configuration components for experiments. - -This module provides reusable configuration building blocks that can be -composed to create experiment configurations with less duplication. 
-""" - -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional - - -@dataclass -class TrainingConfig: - """Training-related configuration.""" - - train_before_dropout: bool = True - training_epochs: int = 10 - learning_rate: float = 0.001 - optimizer: str = "adam" - scheduler: Optional[str] = None - scheduler_config: Dict[str, Any] = field(default_factory=dict) - batch_size: int = 32 - gradient_clip_val: Optional[float] = None - early_stopping_patience: Optional[int] = None - - def to_dict(self) -> Dict[str, Any]: - """Convert to dictionary.""" - return { - "train_before_dropout": self.train_before_dropout, - "training_epochs": self.training_epochs, - "learning_rate": self.learning_rate, - "optimizer": self.optimizer, - "scheduler": self.scheduler, - "scheduler_config": self.scheduler_config, - "batch_size": self.batch_size, - "gradient_clip_val": self.gradient_clip_val, - "early_stopping_patience": self.early_stopping_patience, - } - - -@dataclass -class PruningConfig: - """Pruning/dropout configuration.""" - - dropout_rates: List[float] = field(default_factory=lambda: [0.0, 0.1, 0.3, 0.5, 0.7, 0.9]) - dropout_mode: str = "scaled" - num_random_trials: int = 3 - pruning_metric: str = "rayleigh_quotient" - pruning_strategy: str = "low" # "low", "high", "random" - exclude_classification_layer: bool = True - - def to_dict(self) -> Dict[str, Any]: - """Convert to dictionary.""" - return { - "dropout_rates": self.dropout_rates, - "dropout_mode": self.dropout_mode, - "num_random_trials": self.num_random_trials, - "pruning_metric": self.pruning_metric, - "pruning_strategy": self.pruning_strategy, - "exclude_classification_layer": self.exclude_classification_layer, - } - - -@dataclass -class EvaluationConfig: - """Evaluation configuration.""" - - eval_batches: Optional[int] = None - eval_frequency: int = 1 - compute_alignment_during_eval: bool = True - save_predictions: bool = False - - def to_dict(self) -> Dict[str, Any]: - """Convert to 
dictionary.""" - return { - "eval_batches": self.eval_batches, - "eval_frequency": self.eval_frequency, - "compute_alignment_during_eval": self.compute_alignment_during_eval, - "save_predictions": self.save_predictions, - } - - -@dataclass -class CNNConfig: - """CNN-specific configuration.""" - - cnn_mode: str = "unfold" # "unfold", "patchwise", "batch_patch_combined" - kernel_size: Optional[int] = None - stride: Optional[int] = None - padding: Optional[int] = None - - def to_dict(self) -> Dict[str, Any]: - """Convert to dictionary.""" - return {"cnn_mode": self.cnn_mode, "kernel_size": self.kernel_size, "stride": self.stride, "padding": self.padding} - - -@dataclass -class MultiNetworkConfig: - """Multi-network training configuration.""" - - num_networks: int = 1 - parallel_batch_size: Optional[int] = None - use_tensorized_training: bool = True - aggregate_metrics: bool = True - save_individual_networks: bool = False - - def to_dict(self) -> Dict[str, Any]: - """Convert to dictionary.""" - return { - "num_networks": self.num_networks, - "parallel_batch_size": self.parallel_batch_size, - "use_tensorized_training": self.use_tensorized_training, - "aggregate_metrics": self.aggregate_metrics, - "save_individual_networks": self.save_individual_networks, - } - - -# Factory functions for common configurations - - -def create_standard_training_config(epochs: int = 10, learning_rate: float = 0.001, optimizer: str = "adam", **kwargs) -> TrainingConfig: - """Create a standard training configuration.""" - return TrainingConfig(training_epochs=epochs, learning_rate=learning_rate, optimizer=optimizer, **kwargs) - - -def create_standard_pruning_config( - dropout_rates: Optional[List[float]] = None, metric: str = "rayleigh_quotient", strategy: str = "low", **kwargs -) -> PruningConfig: - """Create a standard pruning configuration.""" - return PruningConfig(dropout_rates=dropout_rates or [0.0, 0.1, 0.3, 0.5, 0.7, 0.9], pruning_metric=metric, pruning_strategy=strategy, **kwargs) - 
- -def create_quick_test_config(epochs: int = 2, dropout_rates: Optional[List[float]] = None) -> Dict[str, Any]: - """Create a configuration for quick testing.""" - return { - "training": create_standard_training_config(epochs=epochs), - "pruning": create_standard_pruning_config(dropout_rates=dropout_rates or [0.0, 0.5], num_random_trials=1), - "evaluation": EvaluationConfig(eval_batches=5), - } - - -# Backward compatibility helper - - -def flatten_config_dict(config_dict: Dict[str, Any]) -> Dict[str, Any]: - """ - Flatten a nested config dictionary for backward compatibility. - - Converts: - {"training": {"epochs": 10}, "pruning": {"dropout_rates": [0.1]}} - To: - {"epochs": 10, "dropout_rates": [0.1]} - """ - flat = {} - - for key, value in config_dict.items(): - if isinstance(value, dict): - flat.update(value) - elif hasattr(value, "to_dict"): - flat.update(value.to_dict()) - else: - flat[key] = value - - return flat - - -def unflatten_config_dict( - flat_dict: Dict[str, Any], - training_keys: Optional[List[str]] = None, - pruning_keys: Optional[List[str]] = None, - evaluation_keys: Optional[List[str]] = None, - cnn_keys: Optional[List[str]] = None, -) -> Dict[str, Any]: - """ - Unflatten a config dictionary into components. - - Converts flat dictionary into nested structure with components. 
- """ - # Default key mappings - if training_keys is None: - training_keys = [ - "train_before_dropout", - "training_epochs", - "learning_rate", - "optimizer", - "scheduler", - "scheduler_config", - "batch_size", - "gradient_clip_val", - "early_stopping_patience", - ] - - if pruning_keys is None: - pruning_keys = ["dropout_rates", "dropout_mode", "num_random_trials", "pruning_metric", "pruning_strategy", "exclude_classification_layer"] - - if evaluation_keys is None: - evaluation_keys = ["eval_batches", "eval_frequency", "compute_alignment_during_eval", "save_predictions"] - - if cnn_keys is None: - cnn_keys = ["cnn_mode", "kernel_size", "stride", "padding"] - - # Extract component configs - components = {} - remaining = flat_dict.copy() - - # Extract training config - training_dict = {} - for key in training_keys: - if key in remaining: - training_dict[key] = remaining.pop(key) - if training_dict: - components["training"] = TrainingConfig(**training_dict) - - # Extract pruning config - pruning_dict = {} - for key in pruning_keys: - if key in remaining: - pruning_dict[key] = remaining.pop(key) - if pruning_dict: - components["pruning"] = PruningConfig(**pruning_dict) - - # Extract evaluation config - eval_dict = {} - for key in evaluation_keys: - if key in remaining: - eval_dict[key] = remaining.pop(key) - if eval_dict: - components["evaluation"] = EvaluationConfig(**eval_dict) - - # Extract CNN config - cnn_dict = {} - for key in cnn_keys: - if key in remaining: - cnn_dict[key] = remaining.pop(key) - if cnn_dict: - components["cnn"] = CNNConfig(**cnn_dict) - - # Add remaining keys - components.update(remaining) - - return components - - -def create_config_from_dict(config_dict: Dict[str, Any]) -> Dict[str, Any]: - """ - Create configuration components from a dictionary. - - This is a convenience function that takes a flat or nested dictionary - and returns properly structured configuration components. 
- """ - # If already structured with components, return as-is - if any(key in config_dict for key in ["training", "pruning", "evaluation", "cnn"]): - return config_dict - - # Otherwise, unflatten into components - return unflatten_config_dict(config_dict) - - -def create_backward_compatible_config( - base_config: Any, - training: Optional[TrainingConfig] = None, - pruning: Optional[PruningConfig] = None, - evaluation: Optional[EvaluationConfig] = None, - **kwargs, -) -> Any: - """ - Create a backward-compatible configuration by merging components into a base config. - - This allows using the new component system with existing experiment classes. - """ - # Start with base config - if hasattr(base_config, "__dict__"): - config_dict = vars(base_config).copy() - else: - config_dict = base_config.copy() if isinstance(base_config, dict) else {} - - # Merge in component configs - if training: - config_dict.update(training.to_dict()) - if pruning: - config_dict.update(pruning.to_dict()) - if evaluation: - config_dict.update(evaluation.to_dict()) - - # Add any additional kwargs - config_dict.update(kwargs) - - # If base_config was a class instance, create new instance with merged values - if hasattr(base_config, "__class__") and hasattr(base_config.__class__, "__init__"): - return base_config.__class__(**config_dict) - - return config_dict diff --git a/src/alignment/experiments/general_alignment.py b/src/alignment/experiments/general_alignment.py index bf4e5c4f..7080a4a3 100644 --- a/src/alignment/experiments/general_alignment.py +++ b/src/alignment/experiments/general_alignment.py @@ -30,6 +30,7 @@ from alignment.models import ModelWrapper from alignment.pruning.base import PruningConfig from alignment.pruning.dependency_aware import DependencyAwarePruning +from alignment.pruning.pipeline import PruningPipelineOptions, run_pruning_pipeline from alignment.pruning.strategies import MagnitudePruning, RandomPruning from alignment.services import ActivationCaptureService, 
MaskOperations @@ -2445,6 +2446,8 @@ def _run_dependency_aware_pruning( """ Apply dependency-aware pruning by converting per-layer scores into masks that respect downstream dependencies (e.g., Conv blocks, residual connections). + + Uses the shared pruning pipeline with configurable distribution options. """ layer_scores: Dict[str, torch.Tensor] = {} layer_outputs = layer_outputs or {} @@ -2479,9 +2482,22 @@ def _run_dependency_aware_pruning( logger.warning("Dependency-aware pruning requested but no valid layer scores were computed.") return None - dep_pruner = DependencyAwarePruning(self.model) + # Use shared pruning pipeline with config-driven options + pipeline_options = PruningPipelineOptions( + distribution=getattr(self.config, "pruning_distribution", "uniform"), + dependency_aware=True, # Always true for this method + min_amount=getattr(self.config, "pruning_min_per_layer", 0.0), + max_amount=getattr(self.config, "pruning_max_per_layer", 0.95), + ) + try: - result = dep_pruner.prune(layer_scores, amount=amount, mode=selection_mode) + result = run_pruning_pipeline( + model=self.model, + layer_scores=layer_scores, + target_sparsity=amount, + selection_mode=selection_mode, + options=pipeline_options, + ) except ValueError as exc: logger.error(f"Dependency-aware pruning failed validation: {exc}") return None diff --git a/src/alignment/experiments/llm_experiments.py b/src/alignment/experiments/llm_experiments.py index 22eb77cb..ab2ab57a 100644 --- a/src/alignment/experiments/llm_experiments.py +++ b/src/alignment/experiments/llm_experiments.py @@ -11,6 +11,7 @@ from alignment.metrics import get_metric from alignment.models.transformers import TransformerWrapperEnhanced as TransformerWrapper from alignment.pruning import AlignmentPruning, PruningConfig +from alignment.pruning.pipeline import PruningPipelineOptions from alignment.pruning.strategies.llm_baselines import WandaPruning, SparseGPTPruning from alignment.services import MaskOperations from 
alignment.training.base import BaseTrainer # kept for compatibility if used elsewhere @@ -3122,13 +3123,19 @@ def _compute_rq_for_layer( ) -> Optional[torch.Tensor]: """Compute Rayleigh Quotient for a specific layer. - RQ measures how well each neuron's weight vector aligns with input covariance. + Standard RQ formula: RQ(w) = (w^T Σ w) / (w^T w) + where Σ is input covariance and w is a weight vector. - For different layer types: - - up_proj/gate_proj: weight [intermediate, hidden], input [hidden] - → RQ per intermediate neuron (output dim) - - down_proj: weight [hidden, intermediate], input [intermediate] - → Use weighted variance proxy per intermediate neuron + For down_proj layers: + - weight W: [hidden_dim, intermediate_dim] + - input X: [batch, intermediate_dim] + - Σ = Cov(X): [intermediate_dim, intermediate_dim] + + We compute RQ per ROW of W (each row w_i is [intermediate_dim]): + RQ_i = (w_i @ Σ @ w_i^T) / (w_i @ w_i^T) + + Then aggregate to get per-intermediate-neuron scores by looking at + how much each intermediate neuron j contributes across all output RQs. 
""" device = next(self.model.parameters()).device @@ -3141,7 +3148,7 @@ def _compute_rq_for_layer( for name, module in hf_model.named_modules(): if name == layer_name or name.endswith(layer_name): if hasattr(module, "weight"): - weight = module.weight.data.float() # [hidden_dim, intermediate_dim] + weight = module.weight.data.float() break if weight is None: @@ -3185,7 +3192,7 @@ def hook_fn(module, input, output): all_acts = all_acts.to(device) input_dim = all_acts.shape[1] - # Compute covariance of inputs + # Compute covariance of inputs: Σ = (X - μ)^T (X - μ) / (n-1) mean = all_acts.mean(dim=0, keepdim=True) centered = all_acts - mean cov = (centered.T @ centered) / (all_acts.shape[0] - 1) @@ -3194,17 +3201,52 @@ def hook_fn(module, input, output): weight = weight.to(device) out_dim, in_dim = weight.shape # weight: [out_dim, in_dim] - # For down_proj: FFN "neurons" are the intermediate dim (inputs to this layer) - # We want one score per intermediate neuron for consistency with up_proj/gate_proj if "down_proj" in layer_name: - # down_proj: weight is [hidden, intermediate], input is [intermediate] - # Return one score per intermediate neuron (input column) - # Use weighted input variance as RQ proxy - input_var = torch.var(all_acts, dim=0) # [intermediate_dim] - # Weight by outgoing connection strength (how much this neuron contributes) - col_norms = torch.norm(weight, dim=0) # [intermediate_dim] - rq_proxy = input_var * col_norms # [intermediate_dim] - return rq_proxy.cpu() + # weight W: [hidden_dim, intermediate_dim] + # cov Σ: [intermediate_dim, intermediate_dim] + # + # Standard RQ per OUTPUT neuron (row i of W): + # w_i = W[i, :] has shape [intermediate_dim] + # RQ_i = (w_i @ Σ @ w_i^T) / ||w_i||^2 + # + # Vectorized: W @ Σ @ W^T gives [hidden_dim, hidden_dim] + # Diagonal gives per-output-neuron RQ + + # Compute W @ Σ: [hidden_dim, intermediate_dim] + w_cov = weight @ cov # [hidden_dim, intermediate_dim] + + # Compute (W @ Σ) * W and sum over intermediate 
dim → w^T Σ w per row + w_cov_w = (w_cov * weight).sum(dim=1) # [hidden_dim] + + # Compute ||w||^2 per row + w_norm_sq = (weight ** 2).sum(dim=1) # [hidden_dim] + + # RQ per output neuron + rq_per_output = w_cov_w / (w_norm_sq + 1e-10) # [hidden_dim] + + # Now we need per-INTERMEDIATE-neuron scores for pruning. + # Contribution of intermediate neuron j to all output RQs: + # The term W[:, j] * Σ[j, :] @ W^T contributes to each output's RQ. + # + # Per-intermediate importance = how much does neuron j contribute to + # the total output variance? This is captured by: + # Σ[j, j] * ||W[:, j]||^2 (diagonal contribution) + # Plus weighted covariance contribution from correlations. + # + # Alternatively, use activation variance weighted by weight magnitude: + # This captures supernodes (high variance + high weight = high impact) + + # Diagonal of covariance = per-neuron variance + var_j = torch.diag(cov) # [intermediate_dim] + + # Column norms squared = weight contribution + col_norm_sq = (weight ** 2).sum(dim=0) # [intermediate_dim] + + # Per-intermediate RQ proxy: Var(j) * ||W[:, j]||^2 + # This is the diagonal contribution to output variance from neuron j + rq_per_intermediate = var_j * col_norm_sq + + return rq_per_intermediate.cpu() # For up_proj/gate_proj: weight [intermediate, hidden], input [hidden] # Check if weight columns align with input covariance @@ -5526,6 +5568,23 @@ def apply_pruning(self, sparsity: float = 0.2, metric: str = "activation_l2_norm if not self.importance_scores: raise ValueError("Must compute importance scores before pruning") + # Get pruning pipeline options from config + pruning_distribution = getattr(self.config, "pruning_distribution", "uniform") + pruning_min = getattr(self.config, "pruning_min_per_layer", 0.0) + pruning_max = getattr(self.config, "pruning_max_per_layer", 0.95) + + # Store options for reference (can be used by downstream methods) + self._pruning_options = PruningPipelineOptions( + distribution=pruning_distribution, + 
dependency_aware=getattr(self.config, "dependency_aware_pruning", False), + min_amount=pruning_min, + max_amount=pruning_max, + ) + + # Log pruning configuration + if pruning_distribution != "uniform": + logger.info(f"Using {pruning_distribution} distribution (min={pruning_min}, max={pruning_max})") + config = PruningConfig(amount=sparsity, structured=True, pruning_mode=mode) # For SCAR metrics, baseline methods (wanda, sparsegpt), and other pre-computed scores, diff --git a/src/alignment/experiments/parallel_pruning_experiment.py b/src/alignment/experiments/parallel_pruning_experiment.py deleted file mode 100644 index bf301add..00000000 --- a/src/alignment/experiments/parallel_pruning_experiment.py +++ /dev/null @@ -1,572 +0,0 @@ -""" -Parallel experiment runner for multi-network pruning analysis. - -This module provides efficient parallel training and analysis of multiple networks -with different seeds, computing metrics and performing pruning experiments. -""" - -import json -import logging -import multiprocessing as mp -import time -from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor -from dataclasses import dataclass -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import numpy as np -import torch -import torch.nn as nn - -from ..analysis.visualization.pruning_plots import PruningVisualizer -from ..metrics import get_metric -from ..models import ModelWrapper -from ..pruning import PruningConfig, get_pruning_strategy -from ..training.multi_network import train_networks_fully_tensorized - -logger = logging.getLogger(__name__) - - -@dataclass -class ParallelExperimentConfig: - """Configuration for parallel pruning experiments.""" - - num_networks: int = 5 - seeds: Optional[List[int]] = None - model_class: type = None - model_kwargs: Dict[str, Any] = None - dataset_name: str = "mnist" - batch_size: int = 128 - epochs: int = 10 - learning_rate: float = 0.001 - device: str = "cuda" - - # Pruning 
settings - pruning_strategies: List[str] = None - pruning_modes: List[str] = None - sparsity_levels: List[float] = None - fine_tune_epochs: int = 5 - - # Metrics settings - metrics_to_compute: List[str] = None - compute_rayleigh: bool = True - - # Output settings - output_dir: str = "results/parallel_pruning" - save_checkpoints: bool = True - create_visualizations: bool = True - - def __post_init__(self): - if self.seeds is None: - self.seeds = list(range(42, 42 + self.num_networks)) - if self.pruning_strategies is None: - self.pruning_strategies = ["magnitude", "gradient", "random"] - if self.pruning_modes is None: - self.pruning_modes = ["low", "high"] - if self.sparsity_levels is None: - self.sparsity_levels = [0.1, 0.3, 0.5, 0.7, 0.9] - if self.metrics_to_compute is None: - self.metrics_to_compute = ["rayleigh_quotient", "mutual_information"] - - -class ParallelPruningExperiment: - """ - Run parallel pruning experiments on multiple networks. - - This class handles: - 1. Training multiple networks with different seeds - 2. Computing alignment metrics (RQ, MI, etc.) - 3. Applying various pruning strategies - 4. Generating comprehensive visualizations - """ - - def __init__(self, config: ParallelExperimentConfig): - """ - Initialize parallel experiment. - - Args: - config: Experiment configuration - """ - self.config = config - self.device = torch.device(config.device) - self.output_dir = Path(config.output_dir) - self.output_dir.mkdir(parents=True, exist_ok=True) - - # Initialize visualizer - self.visualizer = PruningVisualizer() - - # Setup multiprocessing - self.num_workers = min(config.num_networks, mp.cpu_count()) - - def run(self) -> Dict[str, Any]: - """ - Run the complete parallel experiment. 
- - Returns: - Dictionary containing all results - """ - logger.info(f"Starting parallel pruning experiment with {self.config.num_networks} networks") - - # Phase 1: Train networks in parallel - logger.info("Phase 1: Training networks...") - networks, training_history = self._train_networks_parallel() - - # Phase 2: Compute initial metrics - logger.info("Phase 2: Computing initial metrics...") - initial_metrics = self._compute_metrics_parallel(networks) - - # Phase 3: Run pruning experiments - logger.info("Phase 3: Running pruning experiments...") - pruning_results = self._run_pruning_experiments_parallel(networks) - - # Phase 4: Generate visualizations - logger.info("Phase 4: Generating visualizations...") - if self.config.create_visualizations: - self._create_visualizations(pruning_results) - - # Phase 5: Save results - logger.info("Phase 5: Saving results...") - results = { - "config": self.config.__dict__, - "training_history": training_history, - "initial_metrics": initial_metrics, - "pruning_results": pruning_results, - "timestamp": time.strftime("%Y-%m-%d_%H-%M-%S"), - } - self._save_results(results) - - logger.info("Experiment complete!") - return results - - def _train_networks_parallel(self) -> Tuple[List[nn.Module], Dict[str, Any]]: - """Train multiple networks in parallel.""" - # Get data loaders - train_loader, val_loader = self._get_data_loaders() - - # Create networks with different seeds - networks = [] - for seed in self.config.seeds[: self.config.num_networks]: - torch.manual_seed(seed) - np.random.seed(seed) - - model = self.config.model_class(**self.config.model_kwargs) - networks.append(model) - - # Train using tensorized approach if possible - if len(networks) > 1 and self._can_use_tensorized_training(networks): - logger.info(f"Using tensorized training for {len(networks)} networks") - networks, history = train_networks_fully_tensorized( - networks=networks, - train_loader=train_loader, - val_loader=val_loader, - epochs=self.config.epochs, - 
optimizer_kwargs={"lr": self.config.learning_rate}, - device=self.config.device, - checkpoint_dir=self.output_dir / "checkpoints" if self.config.save_checkpoints else None, - ) - else: - # Fallback to parallel training with multiprocessing - logger.info(f"Using parallel training with {self.num_workers} workers") - with ProcessPoolExecutor(max_workers=self.num_workers) as executor: - futures = [] - for i, (network, seed) in enumerate(zip(networks, self.config.seeds)): - future = executor.submit(self._train_single_network, network, train_loader, val_loader, seed, i) - futures.append(future) - - trained_networks = [] - histories = [] - for future in futures: - net, hist = future.result() - trained_networks.append(net) - histories.append(hist) - - networks = trained_networks - history = self._aggregate_histories(histories) - - return networks, history - - def _compute_metrics_parallel(self, networks: List[nn.Module]) -> Dict[str, Any]: - """Compute metrics for all networks in parallel.""" - results = {} - - # Use thread pool for metric computation - with ThreadPoolExecutor(max_workers=self.num_workers) as executor: - futures = {} - - for metric_name in self.config.metrics_to_compute: - for i, network in enumerate(networks): - key = (metric_name, i) - future = executor.submit(self._compute_single_metric, network, metric_name) - futures[key] = future - - # Collect results - for (metric_name, net_idx), future in futures.items(): - if metric_name not in results: - results[metric_name] = [] - results[metric_name].append(future.result()) - - # Compute statistics - stats = {} - for metric_name, values in results.items(): - stats[metric_name] = {"mean": np.mean(values), "std": np.std(values), "values": values} - - return stats - - def _run_pruning_experiments_parallel(self, networks: List[nn.Module]) -> Dict[str, Any]: - """Run pruning experiments on all networks.""" - results = {} - - # Get data loaders for evaluation - train_loader, val_loader = self._get_data_loaders() - - 
# Iterate through strategies and modes - for strategy_name in self.config.pruning_strategies: - for mode in self.config.pruning_modes: - strategy_key = f"{strategy_name}_{mode}" - logger.info(f"Running {strategy_key} pruning...") - - # Results for this strategy across all networks - strategy_results = [] - - # Process each network - for net_idx, network in enumerate(networks): - seed_results = {} - - # Test different sparsity levels - for sparsity in self.config.sparsity_levels: - # Clone network for this experiment - net_copy = self._clone_network(network) - - # Apply pruning - config = PruningConfig(amount=sparsity, pruning_mode=mode) - strategy = get_pruning_strategy(strategy_name, config=config) - - # Prune all layers - for name, module in net_copy.named_modules(): - if isinstance(module, (nn.Linear, nn.Conv2d)): - strategy.prune(module) - - # Fine-tune if requested - if self.config.fine_tune_epochs > 0: - self._fine_tune_network(net_copy, train_loader, val_loader, epochs=self.config.fine_tune_epochs) - - # Evaluate - accuracy, loss = self._evaluate_network(net_copy, val_loader) - - seed_results[sparsity] = {"accuracy": accuracy, "loss": loss} - - strategy_results.append(seed_results) - - # Aggregate results across seeds - results[strategy_key] = self._aggregate_pruning_results(strategy_results) - - return results - - def _aggregate_pruning_results(self, seed_results: List[Dict[float, Dict[str, float]]]) -> Dict[float, Dict[str, Any]]: - """Aggregate pruning results across multiple seeds.""" - aggregated = {} - - # Get all sparsity levels - sparsities = sorted(seed_results[0].keys()) - - for sparsity in sparsities: - metrics = {} - - # Collect values for each metric - for metric in ["accuracy", "loss"]: - values = [sr[sparsity][metric] for sr in seed_results] - - metrics[metric] = values - metrics[f"{metric}_mean"] = np.mean(values) - metrics[f"{metric}_std"] = np.std(values) - - aggregated[sparsity] = { - "mean": {"accuracy": metrics["accuracy_mean"], "loss": 
metrics["loss_mean"]}, - "std": {"accuracy": metrics["accuracy_std"], "loss": metrics["loss_std"]}, - "raw_values": {"accuracy": metrics["accuracy"], "loss": metrics["loss"]}, - } - - return aggregated - - def _create_visualizations(self, pruning_results: Dict[str, Any]): - """Create comprehensive visualizations.""" - vis_dir = self.output_dir / "visualizations" - vis_dir.mkdir(exist_ok=True) - - # 1. Main performance comparison - self.visualizer.plot_pruning_performance( - pruning_results, - metrics=["accuracy", "loss"], - save_path=vis_dir / "performance_comparison.png", - title="Pruning Strategy Performance Comparison", - show_confidence=True, - ) - - # 2. Comprehensive comparison grid - # Convert format for grid plot - grid_results = {} - for strategy_key, strategy_data in pruning_results.items(): - grid_results[strategy_key] = {} - for sparsity, data in strategy_data.items(): - grid_results[strategy_key][sparsity] = {"accuracy": data["mean"]["accuracy"], "loss": data["mean"]["loss"]} - - self.visualizer.plot_pruning_comparison_grid(grid_results, save_path=vis_dir / "comparison_grid.png") - - # 3. 
Multi-seed analysis - # Reorganize data by strategy - seed_results = {} - for strategy_key in pruning_results: - seed_results[strategy_key] = [] - - # Extract raw values for each seed - for seed_idx in range(self.config.num_networks): - seed_data = {} - for sparsity, data in pruning_results[strategy_key].items(): - seed_data[sparsity] = {"accuracy": data["raw_values"]["accuracy"][seed_idx], "loss": data["raw_values"]["loss"][seed_idx]} - seed_results[strategy_key].append(seed_data) - - self.visualizer.plot_multi_seed_results(seed_results, metric="accuracy", save_path=vis_dir / "multi_seed_accuracy.png") - - self.visualizer.plot_multi_seed_results(seed_results, metric="loss", save_path=vis_dir / "multi_seed_loss.png") - - def _save_results(self, results: Dict[str, Any]): - """Save all results to disk.""" - - # Convert numpy arrays to lists for JSON serialization - def convert_to_serializable(obj): - if isinstance(obj, np.ndarray): - return obj.tolist() - elif isinstance(obj, (np.float32, np.float64)): - return float(obj) - elif isinstance(obj, (np.int32, np.int64)): - return int(obj) - elif isinstance(obj, dict): - return {k: convert_to_serializable(v) for k, v in obj.items()} - elif isinstance(obj, list): - return [convert_to_serializable(v) for v in obj] - else: - return obj - - serializable_results = convert_to_serializable(results) - - # Save as JSON - results_path = self.output_dir / "results.json" - with open(results_path, "w") as f: - json.dump(serializable_results, f, indent=2) - - # Save summary - summary_path = self.output_dir / "summary.txt" - with open(summary_path, "w") as f: - f.write(f"Parallel Pruning Experiment Summary\n") - f.write(f"{'='*50}\n\n") - f.write(f"Configuration:\n") - f.write(f" - Networks: {self.config.num_networks}\n") - f.write(f" - Strategies: {self.config.pruning_strategies}\n") - f.write(f" - Modes: {self.config.pruning_modes}\n") - f.write(f" - Sparsity levels: {self.config.sparsity_levels}\n") - f.write(f"\nBest performing 
strategies:\n") - - # Find best strategies - for sparsity in [0.5, 0.7, 0.9]: - f.write(f"\nAt {sparsity*100:.0f}% sparsity:\n") - best_acc = 0 - best_strategy = None - - for strategy, data in results["pruning_results"].items(): - if sparsity in data: - acc = data[sparsity]["mean"]["accuracy"] - if acc > best_acc: - best_acc = acc - best_strategy = strategy - - if best_strategy: - f.write(f" Best: {best_strategy} ({best_acc:.2f}%)\n") - - logger.info(f"Results saved to {self.output_dir}") - - # Helper methods - def _get_data_loaders(self): - """Get data loaders for the specified dataset.""" - # This is a placeholder - implement based on your data module - from ..data.datasets import get_dataset - - return get_dataset(self.config.dataset_name, batch_size=self.config.batch_size) - - def _can_use_tensorized_training(self, networks: List[nn.Module]) -> bool: - """Check if networks can be trained using tensorized approach.""" - # Simple check - all networks should have same architecture - if len(networks) < 2: - return False - - base_arch = str(networks[0]) - for net in networks[1:]: - if str(net) != base_arch: - return False - return True - - def _train_single_network(self, network: nn.Module, train_loader, val_loader, seed: int, idx: int) -> Tuple[nn.Module, Dict]: - """Train a single network (for multiprocessing).""" - # Set seeds - torch.manual_seed(seed) - np.random.seed(seed) - - # Move to device - device = torch.device(self.config.device if torch.cuda.is_available() else "cpu") - network = network.to(device) - - # Training loop - optimizer = torch.optim.Adam(network.parameters(), lr=self.config.learning_rate) - criterion = nn.CrossEntropyLoss() - - history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []} - - for epoch in range(self.config.epochs): - # Train - network.train() - train_loss = 0 - correct = 0 - total = 0 - - for inputs, targets in train_loader: - inputs, targets = inputs.to(device), targets.to(device) - - optimizer.zero_grad() - 
outputs = network(inputs) - loss = criterion(outputs, targets) - loss.backward() - optimizer.step() - - train_loss += loss.item() - _, predicted = outputs.max(1) - correct += predicted.eq(targets).sum().item() - total += targets.size(0) - - # Validate - val_loss, val_acc = self._evaluate_network(network, val_loader) - - history["train_loss"].append(train_loss / len(train_loader)) - history["train_acc"].append(100.0 * correct / total) - history["val_loss"].append(val_loss) - history["val_acc"].append(val_acc) - - return network.cpu(), history - - def _compute_single_metric(self, network: nn.Module, metric_name: str) -> float: - """Compute a single metric for a network.""" - metric = get_metric(metric_name)() - - # Get a sample batch for metric computation - train_loader, _ = self._get_data_loaders() - inputs, _ = next(iter(train_loader)) - inputs = inputs.to(self.device) - - network = network.to(self.device) - - # Compute metric (simplified - implement based on your metrics) - with torch.no_grad(): - if metric_name == "rayleigh_quotient": - # Example for first layer - if hasattr(network, "fc1"): - scores = metric.compute(inputs=inputs, weights=network.fc1.weight) - else: - # Find first linear layer - for module in network.modules(): - if isinstance(module, nn.Linear): - scores = metric.compute(inputs=inputs, weights=module.weight) - break - return scores.mean().item() - else: - # Implement other metrics - return 0.0 - - def _clone_network(self, network: nn.Module) -> nn.Module: - """Create a deep copy of a network.""" - import copy - - return copy.deepcopy(network) - - def _fine_tune_network(self, network: nn.Module, train_loader, val_loader, epochs: int): - """Fine-tune a pruned network.""" - device = torch.device(self.config.device if torch.cuda.is_available() else "cpu") - network = network.to(device) - - optimizer = torch.optim.Adam(network.parameters(), lr=self.config.learning_rate * 0.1) - criterion = nn.CrossEntropyLoss() - - for epoch in range(epochs): - 
network.train() - for inputs, targets in train_loader: - inputs, targets = inputs.to(device), targets.to(device) - - optimizer.zero_grad() - outputs = network(inputs) - loss = criterion(outputs, targets) - loss.backward() - optimizer.step() - - def _evaluate_network(self, network: nn.Module, val_loader) -> Tuple[float, float]: - """Evaluate a network.""" - device = torch.device(self.config.device if torch.cuda.is_available() else "cpu") - network = network.to(device) - network.eval() - - total_loss = 0 - correct = 0 - total = 0 - criterion = nn.CrossEntropyLoss() - - with torch.no_grad(): - for inputs, targets in val_loader: - inputs, targets = inputs.to(device), targets.to(device) - outputs = network(inputs) - loss = criterion(outputs, targets) - - total_loss += loss.item() - _, predicted = outputs.max(1) - correct += predicted.eq(targets).sum().item() - total += targets.size(0) - - accuracy = 100.0 * correct / total - avg_loss = total_loss / len(val_loader) - - return accuracy, avg_loss - - def _aggregate_histories(self, histories: List[Dict]) -> Dict: - """Aggregate training histories from multiple networks.""" - aggregated = {} - - for key in histories[0].keys(): - values = [h[key] for h in histories] - aggregated[key] = {"mean": np.mean(values, axis=0).tolist(), "std": np.std(values, axis=0).tolist(), "all": values} - - return aggregated - - -def run_parallel_pruning_experiment( - model_class: type, - model_kwargs: Dict[str, Any], - num_networks: int = 5, - dataset_name: str = "mnist", - output_dir: str = "results/parallel_pruning", - **kwargs, -) -> Dict[str, Any]: - """ - Convenience function to run a parallel pruning experiment. 
- - Args: - model_class: Model class to instantiate - model_kwargs: Arguments for model construction - num_networks: Number of networks to train - dataset_name: Name of dataset to use - output_dir: Directory to save results - **kwargs: Additional configuration options - - Returns: - Dictionary containing all results - """ - config = ParallelExperimentConfig( - num_networks=num_networks, model_class=model_class, model_kwargs=model_kwargs, dataset_name=dataset_name, output_dir=output_dir, **kwargs - ) - - experiment = ParallelPruningExperiment(config) - return experiment.run() diff --git a/src/alignment/experiments/training_utils.py b/src/alignment/experiments/training_utils.py deleted file mode 100644 index a011fdba..00000000 --- a/src/alignment/experiments/training_utils.py +++ /dev/null @@ -1,161 +0,0 @@ -""" -Utility functions for integrating ExperimentTrainer into experiments. -""" - -from dataclasses import asdict -from typing import Any, Dict, List, Optional, Union - -import torch -import torch.nn as nn - -from alignment.training import ExperimentTrainer, ExperimentTrainingConfig - - -def create_experiment_trainer(model: Union[nn.Module, List[nn.Module]], config: Dict[str, Any], device: str = "cuda") -> ExperimentTrainer: - """ - Create an ExperimentTrainer from experiment config. 
- - Args: - model: Model or list of models to train - config: Experiment configuration dictionary - device: Device to train on - - Returns: - Configured ExperimentTrainer instance - """ - # Extract training-related config - training_config = ExperimentTrainingConfig( - epochs=config.get("training_epochs", config.get("epochs", 10)), - learning_rate=config.get("learning_rate", 0.001), - batch_size=config.get("batch_size", 32), - optimizer=config.get("optimizer", "adam"), - optimizer_kwargs=config.get("optimizer_kwargs", {}), - scheduler=config.get("scheduler", None), - scheduler_kwargs=config.get("scheduler_kwargs", {}), - device=device, - log_interval=config.get("log_interval", 100), - eval_interval=config.get("eval_interval", 1), - checkpoint_dir=config.get("checkpoint_dir", None), - early_stopping_patience=config.get("early_stopping_patience", None), - gradient_clip_val=config.get("gradient_clip_val", None), - # Multi-network specific - num_networks=config.get("num_networks", 1), - tensorized=config.get("tensorized_training", True), - save_all_networks=config.get("save_all_networks", False), - metric_aggregation=config.get("metric_aggregation", "mean"), - ) - - # Create trainer - return ExperimentTrainer(model=model, config=training_config, loss_fn=nn.CrossEntropyLoss(), callbacks=[]) - - -def train_with_metrics( - trainer: ExperimentTrainer, - train_loader: torch.utils.data.DataLoader, - val_loader: Optional[torch.utils.data.DataLoader] = None, - compute_accuracy: bool = True, -) -> Dict[str, Any]: - """ - Train using ExperimentTrainer with standard metrics. 
- - Args: - trainer: ExperimentTrainer instance - train_loader: Training data loader - val_loader: Optional validation data loader - compute_accuracy: Whether to compute accuracy metric - - Returns: - Training history with metrics - """ - - # Define metric function - def metric_fn(outputs: torch.Tensor, targets: torch.Tensor) -> Dict[str, float]: - metrics = {} - if compute_accuracy: - _, predicted = outputs.max(1) - correct = predicted.eq(targets).sum().item() - total = targets.size(0) - metrics["accuracy"] = 100.0 * correct / total - return metrics - - # Train - history = trainer.train(train_loader=train_loader, val_loader=val_loader, metric_fn=metric_fn if compute_accuracy else None) - - return history - - -def convert_training_history(history: Dict[str, Any], num_networks: int = 1) -> Dict[str, Any]: - """ - Convert ExperimentTrainer history to experiment result format. - - Args: - history: Training history from ExperimentTrainer - num_networks: Number of networks trained - - Returns: - Converted results dictionary - """ - results = { - "training_epochs": len(history["train_loss"]), - "final_train_loss": history["train_loss"][-1] if history["train_loss"] else 0.0, - "final_train_accuracy": ( - history["train_metrics"][-1].get("accuracy", 0.0) if history["train_metrics"] and history["train_metrics"][-1] else 0.0 - ), - "training_history": history, - } - - if history["val_loss"]: - results["final_val_loss"] = history["val_loss"][-1] - results["final_val_accuracy"] = ( - history["val_metrics"][-1].get("accuracy", 0.0) if history["val_metrics"] and history["val_metrics"][-1] else 0.0 - ) - - # Add per-network results if multi-network - if num_networks > 1 and "per_network" in history: - results["per_network_results"] = {} - for i in range(num_networks): - network_history = history["per_network"][i] - results["per_network_results"][i] = { - "final_train_loss": network_history["train_loss"][-1] if network_history["train_loss"] else 0.0, - "final_train_accuracy": ( - 
network_history["train_metrics"][-1].get("accuracy", 0.0) - if network_history["train_metrics"] and network_history["train_metrics"][-1] - else 0.0 - ), - } - if network_history["val_loss"]: - results["per_network_results"][i]["final_val_loss"] = network_history["val_loss"][-1] - results["per_network_results"][i]["final_val_accuracy"] = ( - network_history["val_metrics"][-1].get("accuracy", 0.0) - if network_history["val_metrics"] and network_history["val_metrics"][-1] - else 0.0 - ) - - return results - - -def evaluate_with_metrics( - trainer: ExperimentTrainer, model: torch.nn.Module, data_loader: torch.utils.data.DataLoader, device: str = "cuda", compute_alignment: bool = True -) -> Dict[str, Any]: - """ - Evaluate model and compute metrics. - - Args: - trainer: The trainer instance - model: Model to evaluate - data_loader: Data loader for evaluation - device: Device to use - compute_alignment: Whether to compute alignment metrics - - Returns: - Dictionary of evaluation metrics - """ - # Basic evaluation - metrics = trainer.evaluate(model, data_loader, device=device) - - # Add alignment metrics if requested - if compute_alignment and hasattr(trainer, "compute_alignment_metrics"): - alignment_metrics = trainer.compute_alignment_metrics(model, data_loader, device=device) - metrics.update(alignment_metrics) - - return metrics diff --git a/src/alignment/infrastructure/README.md b/src/alignment/infrastructure/README.md index cad08151..34790c97 100644 --- a/src/alignment/infrastructure/README.md +++ b/src/alignment/infrastructure/README.md @@ -2,8 +2,166 @@ System utilities for computing, storage, and configuration. 
+## Usage Status + +| Component | Status | Description | +|-----------|--------|-------------| +| `storage/checkpoint.py` | ✅ ACTIVE | Model checkpoint save/load | +| `storage/logging.py` | ✅ ACTIVE | Logging setup and MetricLogger | +| `storage/job_directory.py` | ✅ ACTIVE | SLURM job directory management | +| `configuration/config.py` | ⚠️ AVAILABLE | Basic config utilities (use `alignment.configs` for main config) | +| `computing/distributed.py` | 🔧 AVAILABLE | Multi-GPU distributed computing (not currently integrated) | +| `computing/optimized/gpu.py` | ✅ INTEGRATED | GPU-accelerated histogram/MI (enable via config) | +| `computing/optimized/jit.py` | ✅ INTEGRATED | JIT-compiled metrics (enable via config) | + ## Components -- `computing/` - GPU utilities, distributed computing -- `storage/` - Checkpoint and result management -- `configuration/` - Configuration handling +### storage/ - Storage Infrastructure ✅ ACTIVE + +**checkpoint.py** - Model checkpoint utilities +```python +from alignment.infrastructure import save_checkpoint, load_checkpoint + +# Save model with optimizer state +save_checkpoint(model, optimizer, epoch=10, filepath="checkpoint.pt") + +# Load checkpoint +checkpoint = load_checkpoint("checkpoint.pt", model=model, optimizer=optimizer) +``` + +**logging.py** - Logging utilities +```python +from alignment.infrastructure import setup_logging, get_logger, MetricLogger + +# Setup logging +setup_logging(log_level="INFO", log_file="experiment.log") + +# Get a logger +logger = get_logger(__name__) + +# Track metrics over time +metric_logger = MetricLogger(log_dir="./logs", experiment_name="my_exp") +metric_logger.log({"loss": 0.5, "accuracy": 0.95}, step=100) +metric_logger.write_summary() +``` + +**job_directory.py** - SLURM job directory management +```python +from alignment.infrastructure.storage import create_job_directory, JobDirectory + +# Create unique job directory (auto-detects SLURM_JOB_ID) +job_dir = create_job_directory( + 
base_output_dir="/path/to/outputs", + experiment_name="llama3_pruning" +) +# Creates: /path/to/outputs/llama3_pruning_20241209_143052_12345/ +# ├── results/ +# ├── logs/ +# ├── checkpoints/ +# ├── figures/ +# └── analysis/ + +# Or use context manager +with JobDirectory("/path/to/outputs", "my_experiment") as job: + job.save_config(config) + job.save_results(results) +``` + +### computing/ - Computing Infrastructure 🔧 AVAILABLE + +**distributed.py** - Distributed training utilities +```python +from alignment.infrastructure import ( + setup_distributed, cleanup_distributed, + is_distributed, is_main_process, + get_rank, get_world_size +) + +# Setup distributed training +if setup_distributed(backend="nccl"): + print(f"Rank {get_rank()} of {get_world_size()}") + +# Check if main process (for logging) +if is_main_process(): + print("Only printed on rank 0") +``` + +**optimized/gpu.py** - GPU-accelerated operations +```python +from alignment.infrastructure.computing.optimized import ( + gpu_histogram1d, gpu_histogram2d, + gpu_mutual_information, gpu_entropy, + GPUAcceleratedMetrics +) + +# Fast GPU histogram +hist, edges = gpu_histogram1d(data, bins=100) + +# GPU mutual information +mi = gpu_mutual_information(x, y, bins=50) + +# JIT-compiled covariance +cov = GPUAcceleratedMetrics.fast_covariance(X) +``` + +**optimized/jit.py** - JIT-compiled metrics +```python +from alignment.infrastructure.computing.optimized import ( + JITRayleighQuotient, JITMutualInformation, JITNodeCorrelation +) + +# Create JIT-optimized metric +jit_rq = JITRayleighQuotient(epsilon=1e-8) +scores = jit_rq(inputs, weights) # Faster than regular RQ +``` + +### configuration/ - Configuration Utilities ⚠️ AVAILABLE + +Basic configuration utilities. For the main experiment configuration system, +use `alignment.configs` instead. 
+ +```python +from alignment.infrastructure.configuration import load_config, save_config + +# Load/save config files +config = load_config("config.yaml") +save_config(config, "output.yaml") +``` + +## Enabling Optimizations via Config + +JIT and GPU acceleration are now integrated into the metric system. Enable them +via YAML config: + +```yaml +metrics: + optimization: + use_jit: true # Enable JIT-compiled computations (20-50% faster) + use_gpu_acceleration: true # Enable GPU-accelerated functions + force_cpu_for_large_ops: true # Prevent OOM for large covariance matrices + cpu_threshold: 100000000 # 1e8 elements threshold +``` + +Or programmatically: + +```python +from alignment.metrics import get_optimization_status, get_metric_with_optimizations + +# Check what's available +status = get_optimization_status() +print(f"JIT available: {status['jit_available']}") +print(f"GPU available: {status['gpu_available']}") + +# Create a metric with optimizations +metric = get_metric_with_optimizations( + "rayleigh_quotient", + use_jit=True, + use_gpu_acceleration=True, + relative=True +) +``` + +## Future Integration Plans + +The `computing/distributed.py` component is ready for integration when +multi-GPU metric computation becomes a priority. diff --git a/src/alignment/infrastructure/__init__.py b/src/alignment/infrastructure/__init__.py index 4b49df61..5407c48b 100644 --- a/src/alignment/infrastructure/__init__.py +++ b/src/alignment/infrastructure/__init__.py @@ -1,19 +1,22 @@ """ -Infrastructure module for the alignment fra__all__ = [ - # Distributed computing - 'setup_distributed', - 'cleanup_distributed', - 'is_distributed', - 'is_main_process', - 'get_world_size', - 'get_rank', - 'GPUAcceleratedMetrics', - 'JITRayleighQuotient', - 'JITMutualInformation', - 'JITNodeCorrelation', - 'create_jit_metric', - # Storagedule provides utilities for distributed computing, storage, -configuration management, and optimization. +Infrastructure module for the alignment framework. 
+ +This module provides utilities for: +- Distributed computing (multi-GPU training) +- Storage (checkpointing, logging, job directories) +- GPU optimization (accelerated metric computations) +- JIT compilation (optimized metric functions) + +USAGE STATUS: +- Storage (checkpoint, logging, job_directory): ACTIVELY USED +- Configuration: See alignment.configs for the main config system +- Computing (distributed, GPU, JIT): AVAILABLE but not currently integrated + These are optimized implementations ready for future performance improvements. + +Example: + >>> from alignment.infrastructure import save_checkpoint, load_checkpoint + >>> from alignment.infrastructure import setup_logging, get_logger + >>> from alignment.infrastructure.storage import create_job_directory """ # Computing infrastructure diff --git a/src/alignment/infrastructure/configuration/__init__.py b/src/alignment/infrastructure/configuration/__init__.py index 80659a75..eb5a7d2d 100644 --- a/src/alignment/infrastructure/configuration/__init__.py +++ b/src/alignment/infrastructure/configuration/__init__.py @@ -1,15 +1,19 @@ -"""Configuration infrastructure for the alignment framework.""" +""" +Configuration infrastructure for the alignment framework. -from .config import Config, DataConfig, ExperimentConfig, MetricConfig, ModelConfig, load_config, merge_configs, save_config, validate_config +NOTE: This module provides basic configuration utilities. +For the main experiment configuration system, use alignment.configs instead: + from alignment.configs import ExperimentConfig, load_config + +This module contains simpler utilities that can be used standalone. 
+""" + +from .config import Config, ExperimentConfig, load_config, merge_configs, save_config __all__ = [ "Config", "ExperimentConfig", - "MetricConfig", - "ModelConfig", - "DataConfig", "load_config", "save_config", "merge_configs", - "validate_config", ] diff --git a/src/alignment/infrastructure/storage/__init__.py b/src/alignment/infrastructure/storage/__init__.py index 4ad0c11a..4d25360b 100644 --- a/src/alignment/infrastructure/storage/__init__.py +++ b/src/alignment/infrastructure/storage/__init__.py @@ -2,6 +2,14 @@ from .checkpoint import load_checkpoint, save_checkpoint, save_model_for_inference from .logging import MetricLogger, get_logger, log_metrics, setup_logging +from .job_directory import ( + create_job_directory, + get_job_directory_paths, + setup_job_logging, + get_slurm_job_id, + get_slurm_array_task_id, + JobDirectory, +) __all__ = [ # Checkpointing @@ -13,4 +21,11 @@ "get_logger", "log_metrics", "MetricLogger", + # Job Directory + "create_job_directory", + "get_job_directory_paths", + "setup_job_logging", + "get_slurm_job_id", + "get_slurm_array_task_id", + "JobDirectory", ] diff --git a/src/alignment/infrastructure/storage/job_directory.py b/src/alignment/infrastructure/storage/job_directory.py new file mode 100644 index 00000000..9e81a226 --- /dev/null +++ b/src/alignment/infrastructure/storage/job_directory.py @@ -0,0 +1,357 @@ +""" +Job directory management for organizing experiment outputs. + +This module provides utilities for creating unique, timestamped directories +for each experiment job. All results, logs, checkpoints, and visualizations +are stored within a single job directory for cleaner organization. 
+ +Directory Structure: + base_output_dir/ + {experiment_name}_{timestamp}_{job_id}/ + results/ + results.json + pruning_results.json + logs/ + experiment.log + checkpoints/ + figures/ + analysis/ +""" + +import logging +import os +import uuid +from datetime import datetime +from pathlib import Path +from typing import Optional, Union + +logger = logging.getLogger(__name__) + + +def get_slurm_job_id() -> Optional[str]: + """ + Get the SLURM job ID if running under SLURM. + + Returns: + SLURM job ID as string, or None if not running under SLURM. + """ + return os.environ.get("SLURM_JOB_ID") + + +def get_slurm_array_task_id() -> Optional[str]: + """ + Get the SLURM array task ID if running as part of a job array. + + Returns: + SLURM array task ID as string, or None if not a job array. + """ + return os.environ.get("SLURM_ARRAY_TASK_ID") + + +def generate_unique_id() -> str: + """ + Generate a short unique identifier. + + Returns: + 8-character unique ID. + """ + return uuid.uuid4().hex[:8] + + +def create_job_directory( + base_output_dir: Union[str, Path], + experiment_name: str, + timestamp: Optional[str] = None, + job_id: Optional[str] = None, + create_subdirs: bool = True, +) -> Path: + """ + Create a unique job directory for experiment outputs. + + The directory name is formatted as: + {experiment_name}_{timestamp}_{job_id} + + Where: + - experiment_name: Name of the experiment from config + - timestamp: ISO format timestamp (defaults to current time) + - job_id: SLURM job ID if available, otherwise a unique ID + + Args: + base_output_dir: Base directory for all experiment outputs. + experiment_name: Name of the experiment. + timestamp: Optional timestamp string. If None, uses current time. + job_id: Optional job ID. If None, uses SLURM_JOB_ID or generates unique ID. + create_subdirs: Whether to create standard subdirectories. + + Returns: + Path to the created job directory. + + Example: + >>> job_dir = create_job_directory( + ... 
"/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM", + ... "llama3_8b_pruning" + ... ) + >>> print(job_dir) + /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/llama3_8b_pruning_20241209_143052_12345678 + """ + base_output_dir = Path(base_output_dir) + + # Generate timestamp if not provided + if timestamp is None: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + # Get job ID + if job_id is None: + # Try SLURM job ID first + slurm_job_id = get_slurm_job_id() + slurm_array_id = get_slurm_array_task_id() + + if slurm_job_id: + if slurm_array_id: + job_id = f"{slurm_job_id}_{slurm_array_id}" + else: + job_id = slurm_job_id + else: + # Generate a unique ID for non-SLURM runs + job_id = generate_unique_id() + + # Sanitize experiment name (remove special characters) + safe_name = "".join(c if c.isalnum() or c in "_-" else "_" for c in experiment_name) + + # Create directory name + dir_name = f"{safe_name}_{timestamp}_{job_id}" + job_dir = base_output_dir / dir_name + + # Create the directory + job_dir.mkdir(parents=True, exist_ok=True) + + # Create standard subdirectories + if create_subdirs: + subdirs = ["results", "logs", "checkpoints", "figures", "analysis"] + for subdir in subdirs: + (job_dir / subdir).mkdir(exist_ok=True) + + logger.info(f"Created job directory: {job_dir}") + + return job_dir + + +def get_job_directory_paths(job_dir: Union[str, Path]) -> dict: + """ + Get standard paths within a job directory. + + Args: + job_dir: Path to the job directory. + + Returns: + Dictionary with paths to standard subdirectories and files. 
+ """ + job_dir = Path(job_dir) + + return { + "root": job_dir, + "results": job_dir / "results", + "logs": job_dir / "logs", + "checkpoints": job_dir / "checkpoints", + "figures": job_dir / "figures", + "analysis": job_dir / "analysis", + # Common file paths + "experiment_log": job_dir / "logs" / "experiment.log", + "config_file": job_dir / "experiment_config.yaml", + "results_file": job_dir / "results" / "results.json", + } + + +def setup_job_logging( + job_dir: Union[str, Path], + log_level: int = logging.INFO, +) -> logging.Logger: + """ + Setup logging to write to the job directory. + + Args: + job_dir: Path to the job directory. + log_level: Logging level. + + Returns: + Configured root logger. + """ + paths = get_job_directory_paths(job_dir) + log_file = paths["experiment_log"] + + # Ensure log directory exists + log_file.parent.mkdir(parents=True, exist_ok=True) + + # Configure logging + logging.basicConfig( + level=log_level, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[ + logging.FileHandler(log_file), + logging.StreamHandler(), + ], + ) + + root_logger = logging.getLogger() + root_logger.info(f"Logging to: {log_file}") + + return root_logger + + +class JobDirectory: + """ + Context manager for job directories. + + Provides convenient access to job directory paths and handles + cleanup on errors. + + Example: + >>> with JobDirectory(base_dir, "my_experiment") as job: + ... # Access paths + ... print(job.results_dir) + ... print(job.figures_dir) + ... + ... # Save files + ... job.save_config(config) + ... job.save_results(results) + """ + + def __init__( + self, + base_output_dir: Union[str, Path], + experiment_name: str, + timestamp: Optional[str] = None, + job_id: Optional[str] = None, + setup_logging: bool = True, + log_level: int = logging.INFO, + ): + """ + Initialize job directory. + + Args: + base_output_dir: Base directory for experiment outputs. + experiment_name: Name of the experiment. 
+ timestamp: Optional timestamp string. + job_id: Optional job ID. + setup_logging: Whether to configure logging. + log_level: Logging level. + """ + self.base_output_dir = Path(base_output_dir) + self.experiment_name = experiment_name + self._timestamp = timestamp + self._job_id = job_id + self._setup_logging = setup_logging + self._log_level = log_level + + self._job_dir: Optional[Path] = None + self._paths: Optional[dict] = None + + @property + def job_dir(self) -> Path: + """Get the job directory path.""" + if self._job_dir is None: + raise RuntimeError("JobDirectory not initialized. Use as context manager.") + return self._job_dir + + @property + def results_dir(self) -> Path: + """Get the results subdirectory.""" + return self._paths["results"] + + @property + def logs_dir(self) -> Path: + """Get the logs subdirectory.""" + return self._paths["logs"] + + @property + def checkpoints_dir(self) -> Path: + """Get the checkpoints subdirectory.""" + return self._paths["checkpoints"] + + @property + def figures_dir(self) -> Path: + """Get the figures subdirectory.""" + return self._paths["figures"] + + @property + def analysis_dir(self) -> Path: + """Get the analysis subdirectory.""" + return self._paths["analysis"] + + def __enter__(self) -> "JobDirectory": + """Create job directory and setup logging.""" + self._job_dir = create_job_directory( + self.base_output_dir, + self.experiment_name, + timestamp=self._timestamp, + job_id=self._job_id, + ) + self._paths = get_job_directory_paths(self._job_dir) + + if self._setup_logging: + setup_job_logging(self._job_dir, self._log_level) + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Handle cleanup on exit.""" + if exc_type is not None: + logger.error(f"Job failed with error: {exc_val}") + return False # Don't suppress exceptions + + def save_config(self, config, filename: str = "experiment_config.yaml"): + """ + Save experiment configuration to the job directory. 
+ + Args: + config: Configuration object with save() method or dict. + filename: Name of the config file. + """ + import json + import yaml + + config_path = self._job_dir / filename + + if hasattr(config, "save"): + config.save(config_path) + elif hasattr(config, "to_dict"): + with open(config_path, "w") as f: + yaml.dump(config.to_dict(), f, default_flow_style=False) + elif isinstance(config, dict): + with open(config_path, "w") as f: + yaml.dump(config, f, default_flow_style=False) + else: + raise ValueError(f"Cannot save config of type {type(config)}") + + logger.info(f"Saved config to: {config_path}") + + def save_results(self, results: dict, filename: str = "results.json"): + """ + Save results to the results subdirectory. + + Args: + results: Results dictionary. + filename: Name of the results file. + """ + import json + + results_path = self.results_dir / filename + + def convert_to_serializable(obj): + if hasattr(obj, "tolist"): + return obj.tolist() + elif hasattr(obj, "item"): + return obj.item() + elif isinstance(obj, dict): + return {k: convert_to_serializable(v) for k, v in obj.items()} + elif isinstance(obj, (list, tuple)): + return [convert_to_serializable(i) for i in obj] + return obj + + serializable = convert_to_serializable(results) + + with open(results_path, "w") as f: + json.dump(serializable, f, indent=2) + + logger.info(f"Saved results to: {results_path}") + return results_path diff --git a/src/alignment/infrastructure/storage/logging.py b/src/alignment/infrastructure/storage/logging.py index 9cd439f6..02deeb2c 100644 --- a/src/alignment/infrastructure/storage/logging.py +++ b/src/alignment/infrastructure/storage/logging.py @@ -4,6 +4,7 @@ import json import logging +import logging.config import sys from datetime import datetime from pathlib import Path diff --git a/src/alignment/metrics/__init__.py b/src/alignment/metrics/__init__.py index 3ec2fb82..7c0903a7 100644 --- a/src/alignment/metrics/__init__.py +++ 
b/src/alignment/metrics/__init__.py @@ -63,6 +63,112 @@ - cross_layer_importance: RQ + Downstream_Importance - Redundancy (SCAR logic) - layer_transition_efficiency: Downstream_Importance / (1 + Redundancy) +============================================================================= +METRIC CONFIGURATION OPTIONS +============================================================================= + +All metrics share these common options (via BaseMetric): + - force_cpu_for_large_ops: bool (default: True) - Move to CPU for large tensors + - cpu_threshold: int (default: 1e8) - Element count threshold for CPU fallback + +CNN Input Handling Modes (for convolutional layers): + Metrics handle 4D inputs [B, C, H, W] differently. The recommended approach + depends on the metric and use case: + + 1. "unfold" - Unfold spatial dims into patches [B*P, C*K*K] + - Most accurate for RQ/covariance metrics + - Memory: O(B * P * C * K^2) where P = H*W patches + - Use for: RQ, MI (when weights have kernel shape) + + 2. "spatial" - Reshape [B, C, H, W] -> [B*H*W, C] + - Treats each spatial location as a sample + - Preserves spatial variance information + - Memory: O(B * H * W * C) + - Use for: Fast covariance estimation + + 3. "gap" (Global Average Pooling) - [B, C, H, W] -> [B, C] + - Fastest, lowest memory + - Loses spatial information + - Memory: O(B * C) + - Use for: Quick experiments, early conv layers + + 4. 
"channel" - Per-channel statistics only + - Aggregates over spatial dims + - Memory: O(C) + - Use for: Activation magnitude metrics + +INDIVIDUAL METRIC OPTIONS: +-------------------------- + +rayleigh_quotient: + relative: bool = True # Normalize by trace(Σ) for relative alignment + min_samples: int = 2 # Minimum samples for covariance + scale_by_norm: bool = False # Scale covariance by Frobenius norm + regularization: float = 1e-6 # Diagonal regularization for stability + +rq_fast (FastRayleighQuotient): + relative: bool = True # Same as RQ + regularization: float = 1e-6 # Same as RQ + # Uses GAP for CNN inputs - 10-100x faster than patchwise + +rq_spatial (SpatialRayleighQuotient): + # Same options as RQ + # Uses spatial reshaping for CNN inputs + +gaussian_mi_analytic: + expansion_order: int = 2 # Edgeworth expansion order (0=pure Gaussian) + noise_std: float = 0.1 # Assumed noise standard deviation + regularization: float = 1e-6 # Covariance regularization + per_neuron: bool = True # Per-neuron MI (True) or joint MI (False) + +pairwise_redundancy_gaussian: + num_pairs: int = 10 # Partners to sample per neuron + sampling_strategy: str = "random" # "random", "nearest", "all" + mode: str = "output_based" # "output_based" (fast) or "covariance_based" + regularization: float = 1e-6 # For covariance_based mode + +average_redundancy: + min_samples: int = 2 # Minimum samples + use_correlation: bool = True # Use correlation (True) or covariance (False) + +mi_about_class: + method: str = "gaussian" # "gaussian" or "binning" + bins: int = 10 # For binning method + min_samples_per_class: int = 5 # Minimum samples per class + +conditional_rayleigh_quotient: + relative: bool = True # Same as RQ + min_samples: int = 2 # Minimum samples per class + regularization: float = 1e-6 # Same as RQ + return_delta: bool = False # Return ΔRQ = RQ_uncond - RQ_cond + +activation_l2_norm: + aggregate_method: str = "l2" # "l2", "mean", "max" + use_absolute: bool = True # Take absolute value 
before aggregation + +activation_outlier_index: + quantile: float = 0.999 # High percentile for outlier detection + eps: float = 1e-6 # Numerical stability + +composite_importance: + weights: dict # {metric_name: weight} - negative for penalties + normalize_components: bool = True # Normalize to [0, 1] before combining + log_transform_rq: bool = True # Apply log to RQ + +cross_layer_importance: + rq_weight: float = 0.25 # Weight for RQ (α) + downstream_weight: float = 0.35 # Weight for downstream importance (β) + within_redundancy_weight: float = 0.25 # Penalty for redundancy (γ) + activation_weight: float = 0.15 # Weight for activation magnitude + normalize: bool = True # Normalize components to [0, 1] + max_refs: int = 512 # Max reference neurons for efficiency + +halo_redundancy: + supernode_fraction: float = 0.01 # Top fraction as supernodes + halo_fraction: float = 0.10 # Fraction of non-supernodes as halo + max_samples: int = 1000 # Max activation samples + max_pairs_per_group: int = 500 # Max pairs for efficiency + ============================================================================= """ @@ -189,6 +295,105 @@ def get_metric_category(name: str) -> str: return 'other' +def get_optimization_status() -> dict: + """ + Check availability of metric optimizations (JIT, GPU). 
+ + Returns: + Dictionary with optimization availability status: + { + 'jit_available': bool, + 'gpu_available': bool, + 'cuda_available': bool, + 'available_jit_functions': list, + 'available_gpu_functions': list, + } + """ + import torch + + status = { + 'jit_available': False, + 'gpu_available': False, + 'cuda_available': torch.cuda.is_available(), + 'available_jit_functions': [], + 'available_gpu_functions': [], + } + + # Check JIT functions + try: + from ..infrastructure.computing.optimized.jit import ( + compute_rayleigh_quotient_jit, + compute_mutual_information_gaussian_jit, + compute_node_correlation_jit, + ) + status['jit_available'] = True + status['available_jit_functions'] = [ + 'rayleigh_quotient', + 'mutual_information', + 'node_correlation', + 'cosine_similarity', + 'eigenvalue_entropy', + 'spectral_norm', + ] + except ImportError: + pass + + # Check GPU functions + try: + from ..infrastructure.computing.optimized.gpu import ( + gpu_histogram1d, + gpu_mutual_information, + GPUAcceleratedMetrics, + ) + status['gpu_available'] = True + status['available_gpu_functions'] = [ + 'histogram1d', + 'histogram2d', + 'mutual_information', + 'entropy', + 'conditional_entropy', + 'fast_covariance', + 'fast_correlation', + ] + except ImportError: + pass + + return status + + +def get_metric_with_optimizations( + name: str, + use_jit: bool = False, + use_gpu_acceleration: bool = False, + **kwargs +): + """ + Get a metric instance with optimization options. + + Args: + name: Name of the metric + use_jit: Enable JIT-compiled computations (20-50% faster) + use_gpu_acceleration: Enable GPU-accelerated functions + **kwargs: Additional metric parameters + + Returns: + Instantiated metric object with optimizations enabled + + Example: + >>> metric = get_metric_with_optimizations( + ... "rayleigh_quotient", + ... use_jit=True, + ... relative=True + ... 
) + """ + return METRIC_REGISTRY.create( + name, + use_jit=use_jit, + use_gpu_acceleration=use_gpu_acceleration, + **kwargs + ) + + # For convenience, expose the registry and functions __all__ = [ "METRIC_REGISTRY", @@ -197,4 +402,6 @@ def get_metric_category(name: str) -> str: "get_recommended_metrics", "get_extended_metrics", "get_metric_category", + "get_optimization_status", + "get_metric_with_optimizations", ] diff --git a/src/alignment/metrics/information/__init__.py b/src/alignment/metrics/information/__init__.py index 35bf063e..89e54cfb 100644 --- a/src/alignment/metrics/information/__init__.py +++ b/src/alignment/metrics/information/__init__.py @@ -3,7 +3,7 @@ """ from .conditional_mutual_information import ConditionalMutualInformation -from .gaussian_mi import GaussianMIAnalytic +from .gaussian_mi import GaussianMIAnalytic, FastGaussianMI from .mi_projection import MIProjectionVsMeanInput from .mutual_information import MutualInformationBinning, MutualInformationGaussian from .pairwise_gaussian import PairwiseRedundancyGaussian @@ -27,11 +27,13 @@ "MutualInformationGaussian", "MutualInformationBinning", "GaussianMIAnalytic", + "FastGaussianMI", # Fast MI variant using GAP for CNNs # Redundancy "AverageRedundancy", "PairwiseRedundancyGaussian", # Synergy "SynergyGaussianMMI", + "SynergyContinuousTarget", # PID "SharedInformation", "UniqueInformationX", diff --git a/src/alignment/metrics/information/gaussian_mi.py b/src/alignment/metrics/information/gaussian_mi.py index d92385e8..b13d89b7 100644 --- a/src/alignment/metrics/information/gaussian_mi.py +++ b/src/alignment/metrics/information/gaussian_mi.py @@ -340,3 +340,90 @@ def compute(self, inputs: torch.Tensor, weights: torch.Tensor, outputs: Optional # Return same value for all neurons return torch.full((output_dim,), total_mi.item(), device=original_device) + + +@register_metric("mi_fast", aliases=["gaussian_mi_fast", "mi_gap"]) +class FastGaussianMI(GaussianMIAnalytic): + """ + Fast Gaussian MI 
approximation for CNNs using Global Average Pooling. + + Instead of unfolding patches or spatial reshaping, uses GAP to reduce + spatial dimensions to get a channel-wise computation. + + This is an APPROXIMATION but is: + - 10-100x faster than spatial MI + - Uses O(C^2) memory instead of O((C*H*W)^2) + - Better for early conv layers with large spatial dimensions + + For inputs [B, C_in, H, W] and weights [C_out, C_in, k, k]: + 1. Apply GAP: inputs -> [B, C_in] + 2. Sum weights over kernel: weights -> [C_out, C_in] + 3. Compute standard MI on channel dimension + + Example: + >>> mi_fast = FastGaussianMI() + >>> scores = mi_fast.compute(inputs=inputs, weights=conv.weight) + """ + + def __init__( + self, + noise_std: float = 0.1, + regularization: float = 1e-6, + **config, + ): + """ + Initialize fast MI metric. + + Args: + noise_std: Assumed noise standard deviation + regularization: Regularization for covariance + **config: Additional configuration + """ + super().__init__( + expansion_order=0, # Skip Edgeworth for speed + noise_std=noise_std, + regularization=regularization, + per_neuron=True, + **config, + ) + + def compute( + self, + inputs: torch.Tensor, + weights: torch.Tensor, + outputs: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + Compute fast MI using GAP for CNN inputs. 
+ + Args: + inputs: Input activations [B, C_in, H, W] or [B, C_in] + weights: Conv weights [C_out, C_in, k_H, k_W] or [C_out, C_in] + outputs: Not used + + Returns: + MI values for each output channel [C_out] + """ + if weights is None: + raise ValueError("FastMI requires weights") + if inputs is None: + raise ValueError("FastMI requires inputs") + + # Handle CNN inputs: [B, C_in, H, W] -> GAP -> [B, C_in] + if inputs.ndim == 4: + # Global Average Pooling over spatial dimensions + inputs = inputs.mean(dim=(2, 3)) # [B, C_in] + elif inputs.ndim == 3: + # Patchwise input [B, F, P] -> average over patches + inputs = inputs.mean(dim=2) # [B, F] + + # Handle Conv weights: [C_out, C_in, k_H, k_W] -> sum over kernel -> [C_out, C_in] + if weights.ndim == 4: + # Sum over kernel dimensions to get channel-wise weights + weights = weights.sum(dim=(2, 3)) # [C_out, C_in] + elif weights.ndim > 2: + weights = weights.reshape(weights.shape[0], -1) + + # Now use standard 2D MI computation + return super().compute(inputs=inputs, weights=weights, **kwargs) diff --git a/src/alignment/metrics/information/synergy_mmi.py b/src/alignment/metrics/information/synergy_mmi.py index d714436e..b9030995 100644 --- a/src/alignment/metrics/information/synergy_mmi.py +++ b/src/alignment/metrics/information/synergy_mmi.py @@ -40,18 +40,29 @@ class SynergyGaussianMMI(BaseMetric): >>> print(synergy.shape) # [num_neurons] """ - def __init__(self, num_pairs: int = 10, sampling_strategy: str = "random", **config: Any): + def __init__( + self, + num_pairs: int = 10, + sampling_strategy: str = "random", + regularization: float = 1e-6, + min_samples: int = 2, + **config: Any, + ): """ Initialize synergy metric. 
Args: num_pairs: Number of partner neurons to sample per neuron sampling_strategy: How to sample pairs ('random', 'nearest', 'all') + regularization: Regularization for numerical stability + min_samples: Minimum samples per class for reliable estimates **config: Additional configuration """ super().__init__(**config) self.num_pairs = num_pairs self.sampling_strategy = sampling_strategy + self.regularization = regularization + self.min_samples = min_samples @property def requires_inputs(self) -> bool: diff --git a/src/alignment/metrics/rayleigh/rayleigh_quotient.py b/src/alignment/metrics/rayleigh/rayleigh_quotient.py index 06af2048..1f6493a1 100644 --- a/src/alignment/metrics/rayleigh/rayleigh_quotient.py +++ b/src/alignment/metrics/rayleigh/rayleigh_quotient.py @@ -30,6 +30,10 @@ class RayleighQuotient(BaseMetric): When relative=True (default), normalizes by trace(C) to get a proportion of total variance. + + Optimization Options: + use_jit: bool = False - Use JIT-compiled RQ computation (20-50% faster) + use_gpu_acceleration: bool = False - Use GPU-accelerated covariance """ # Class-level set to track warned dimension pairs (avoid spam) @@ -52,6 +56,8 @@ def __init__( min_samples: Minimum samples required for covariance computation scale_by_norm: Whether to scale covariance by its Frobenius norm regularization: Small value added to diagonal for numerical stability (default: 1e-6) + use_jit: bool = False - Use JIT-compiled RQ computation + use_gpu_acceleration: bool = False - Use GPU-accelerated covariance **config: Additional configuration parameters """ super().__init__(**config) @@ -61,6 +67,14 @@ def __init__( self.regularization = regularization # Optional: class-conditioned covariance support (targets provided at compute time preferred) self._cc_targets = class_conditioned_targets + + # JIT-specific function reference + self._jit_rq_compute = None + if self._use_jit: + jit_fn = self._get_jit_function("rayleigh_quotient") + if jit_fn is not None: + 
self._jit_rq_compute = jit_fn + logger.info("RayleighQuotient: Using JIT-compiled computation") @property def requires_inputs(self) -> bool: @@ -190,9 +204,33 @@ def compute( use_class_cond = False # Fall back to unconditional if not use_class_cond: - # Compute unconditional covariance matrix - inputs_centered = inputs - inputs.mean(dim=0, keepdim=True) - cov = torch.matmul(inputs_centered.T, inputs_centered) / (batch_size - 1) + # Use JIT-compiled version if available for simple unconditional case + if self._jit_rq_compute is not None and not self.scale_by_norm: + try: + rq_values = self._jit_rq_compute(inputs, weights, self.regularization) + # JIT version doesn't normalize by trace, do it here if relative + if self.relative: + inputs_centered = inputs - inputs.mean(dim=0, keepdim=True) + cov = torch.matmul(inputs_centered.T, inputs_centered) / (batch_size - 1) + trace_cov = torch.trace(cov.float()) + if trace_cov > 1e-12: + rq_values = rq_values / trace_cov + return rq_values + except Exception as e: + logger.warning(f"JIT RQ failed, falling back to standard: {e}") + + # Use GPU-accelerated covariance if available + if self._use_gpu_acceleration and self._get_gpu_function("fast_covariance") is not None: + try: + cov = self._get_gpu_function("fast_covariance")(inputs) + except Exception as e: + logger.warning(f"GPU covariance failed, falling back to standard: {e}") + inputs_centered = inputs - inputs.mean(dim=0, keepdim=True) + cov = torch.matmul(inputs_centered.T, inputs_centered) / (batch_size - 1) + else: + # Compute unconditional covariance matrix + inputs_centered = inputs - inputs.mean(dim=0, keepdim=True) + cov = torch.matmul(inputs_centered.T, inputs_centered) / (batch_size - 1) return self._compute_from_covariance(cov, weights) @@ -350,7 +388,15 @@ def compute_class_conditioned( return {"rq_uncond": rq_uncond, "rq_cond": rq_cond, "delta_rq": delta_rq} - def _compute_patchwise(self, inputs: torch.Tensor, weights: torch.Tensor, weight_by_variance: bool = 
True, **kwargs: Any) -> torch.Tensor: + def _compute_patchwise( + self, + inputs: torch.Tensor, + weights: torch.Tensor, + weight_by_variance: bool = True, + max_patches: int = 64, + use_vectorized: bool = True, + **kwargs: Any + ) -> torch.Tensor: """ Compute patch-wise RQ for CNN layers. @@ -358,6 +404,8 @@ def _compute_patchwise(self, inputs: torch.Tensor, weights: torch.Tensor, weight inputs: Input patches [batch_size, features, num_patches] weights: Flattened weights [output_features, features] weight_by_variance: Whether to weight patches by their variance + max_patches: Maximum patches to use (subsample if more) + use_vectorized: Use vectorized computation (faster, more memory) Returns: RQ values [output_features] @@ -373,65 +421,144 @@ def _compute_patchwise(self, inputs: torch.Tensor, weights: torch.Tensor, weight if weights.ndim > 2: weights = weights.reshape(weights.shape[0], -1) - # Compute variance for each patch - patch_var = torch.var(inputs, dim=0, keepdim=False) # [features, num_patches] - patch_total_var = patch_var.sum(dim=0) # [num_patches] + # Subsample patches if too many (memory/speed optimization) + if num_patches > max_patches: + logger.debug(f"Subsampling patches: {num_patches} -> {max_patches}") + indices = torch.randperm(num_patches, device=inputs.device)[:max_patches] + inputs = inputs[:, :, indices] + num_patches = max_patches + + # Handle dimension mismatch + min_dim = min(features, weights.shape[1]) + if features != weights.shape[1]: + inputs = inputs[:, :min_dim, :] + weights = weights[:, :min_dim] + features = min_dim + + # Compute patch variances for weighting + patch_var = torch.var(inputs, dim=0) # [features, num_patches] + patch_weights = patch_var.sum(dim=0) if weight_by_variance else torch.ones(num_patches, device=inputs.device) + + if use_vectorized and num_patches <= 256: + # VECTORIZED: Compute all patches at once (faster but more memory) + return self._compute_patchwise_vectorized(inputs, weights, patch_weights) + else: + 
# LOOP: Compute patches one by one (slower but less memory)
+            return self._compute_patchwise_loop(inputs, weights, patch_weights)
+
+    def _compute_patchwise_vectorized(
+        self,
+        inputs: torch.Tensor,
+        weights: torch.Tensor,
+        patch_weights: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Vectorized patchwise RQ computation - no explicit loop.
+
+        Uses einsum for efficient batch computation over patches.
+        Memory: O(num_patches * features^2) for covariances
+        """
+        batch_size, features, num_patches = inputs.shape
+        output_features = weights.shape[0]
+        device = weights.device
+        eps = 1e-12
+
+        # Center inputs: [B, F, P]
+        inputs_centered = inputs - inputs.mean(dim=0, keepdim=True)
+
+        # Compute covariance for each patch using einsum
+        # cov[p] = X_centered[:, :, p].T @ X_centered[:, :, p] / (B-1)
+        #        = einsum('bi,bj->ij' for each patch p)
+        # Batched: einsum('bfp,bgp->fgp') gives [F, F, P]
+        all_covs = torch.einsum('bfp,bgp->fgp', inputs_centered, inputs_centered) / (batch_size - 1)
+
+        # Add regularization
+        if self.regularization > 0:
+            eye = torch.eye(features, device=device, dtype=all_covs.dtype)
+            all_covs = all_covs + self.regularization * eye.unsqueeze(-1)
+
+        # Compute w^T C w for all patches and neurons at once
+        # weights: [N, F], all_covs: [F, G, P] where F=G
+        # numerator[n, p] = sum_f sum_g w[n,f] * cov[f,g,p] * w[n,g]
+        #                 = einsum('nf,fgp,ng->np')
+        wc = torch.einsum('nf,fgp->ngp', weights, all_covs)  # [N, G, P]
+        numerator = torch.einsum('ngp,ng->np', wc, weights)  # [N, P]
+
+        # Compute w^T w (same for all patches)
+        denominator = (weights ** 2).sum(dim=1)  # [N]
+
+        # Compute RQ per patch
+        patch_rq = torch.zeros(output_features, num_patches, device=device)
+        valid_mask = denominator > eps
+        patch_rq[valid_mask, :] = numerator[valid_mask, :] / denominator[valid_mask].unsqueeze(1)
+
+        # Normalize by trace if relative
+        if self.relative:
+            # trace of each patch covariance: diagonal of [F, F, P] has shape
+            # [P, F] (diagonal appended last), so sum over the LAST dim for [P]
+            traces = torch.diagonal(all_covs, dim1=0, dim2=1).sum(dim=-1)  # [P]
+            traces = 
torch.clamp(traces, min=eps) + patch_rq = patch_rq / traces.unsqueeze(0) + + # Weighted average across patches + patch_weights = torch.clamp(patch_weights, min=0) + total_weight = patch_weights.sum() + if total_weight > eps: + final_rq = (patch_rq * patch_weights.unsqueeze(0)).sum(dim=1) / total_weight + else: + final_rq = patch_rq.mean(dim=1) + + return torch.nan_to_num(final_rq, nan=0.0, posinf=0.0, neginf=0.0) - # Initialize accumulators - weighted_rq_sum = torch.zeros(output_features, device=weights.device) + def _compute_patchwise_loop( + self, + inputs: torch.Tensor, + weights: torch.Tensor, + patch_weights: torch.Tensor + ) -> torch.Tensor: + """ + Loop-based patchwise RQ computation - lower memory usage. + + Memory: O(features^2) for one covariance at a time + """ + batch_size, features, num_patches = inputs.shape + output_features = weights.shape[0] + device = weights.device + eps = 1e-12 + + weighted_rq_sum = torch.zeros(output_features, device=device) total_weight = 0.0 - # Compute RQ for each patch for p in range(num_patches): patch_data = inputs[:, :, p] # [batch_size, features] - - # Center the data patch_data_centered = patch_data - patch_data.mean(dim=0, keepdim=True) - - # Compute covariance for this patch patch_cov = torch.matmul(patch_data_centered.T, patch_data_centered) / (batch_size - 1) - # Scale by norm if requested + if self.regularization > 0: + patch_cov = patch_cov + self.regularization * torch.eye(features, device=device, dtype=patch_cov.dtype) + if self.scale_by_norm: cov_norm = torch.norm(patch_cov, p="fro") if cov_norm > 0: patch_cov = patch_cov / cov_norm - # Handle dimension mismatch - min_dim = min(features, weights.shape[1]) - if features != weights.shape[1]: - patch_cov = patch_cov[:min_dim, :min_dim] - weights_adj = weights[:, :min_dim] - else: - weights_adj = weights + # Compute RQ + wc = torch.matmul(weights, patch_cov) + numerator = torch.sum(wc * weights, dim=1) + denominator = torch.sum(weights * weights, dim=1) - # 
Compute RQ for this patch - wc = torch.matmul(weights_adj, patch_cov) - numerator = torch.sum(wc * weights_adj, dim=1) - denominator = torch.sum(weights_adj * weights_adj, dim=1) - - eps = 1e-12 patch_rq = torch.zeros_like(numerator) valid_mask = denominator > eps patch_rq[valid_mask] = numerator[valid_mask] / denominator[valid_mask] - # Normalize by trace if relative if self.relative: - # Convert to float32 for trace (bfloat16 not supported) trace = torch.trace(patch_cov.float()) if trace > eps: patch_rq = patch_rq / trace - # Weight by patch variance if requested - if weight_by_variance: - patch_weight = patch_total_var[p].item() - else: - patch_weight = 1.0 - - weighted_rq_sum += patch_rq * patch_weight - total_weight += patch_weight + pw = patch_weights[p].item() + weighted_rq_sum += patch_rq * pw + total_weight += pw - # Average across patches if total_weight > 0: final_rq = weighted_rq_sum / total_weight else: @@ -526,3 +653,153 @@ def compute( final_rq = torch.zeros(output_channels, device=weights.device, dtype=weights.dtype) return final_rq + + +@register_metric("rq_fast", aliases=["rq_gap", "rayleigh_quotient_fast"]) +class FastRayleighQuotient(RayleighQuotient): + """ + Fast Rayleigh Quotient approximation for CNNs using Global Average Pooling. + + Instead of unfolding patches, this version uses GAP to reduce spatial dimensions + to get a channel-wise covariance, then computes RQ on the channel dimension. + + This is an APPROXIMATION but is: + - 10-100x faster than patchwise RQ + - Uses O(C^2) memory instead of O((C*k*k)^2) + - Better for early conv layers with large spatial dimensions + + For inputs [B, C_in, H, W] and weights [C_out, C_in, k, k]: + 1. Apply GAP: inputs -> [B, C_in] + 2. Sum weights over kernel: weights -> [C_out, C_in] + 3. 
Compute standard RQ on channel dimension + + Use cases: + - Quick experiments + - Early conv layers (layer1, layer2) where exact RQ is too slow + - When channel-level importance is sufficient + + Example: + >>> rq_fast = FastRayleighQuotient() + >>> # inputs: [batch, channels, height, width] + >>> scores = rq_fast.compute(inputs=inputs, weights=conv.weight) + """ + + def __init__(self, relative: bool = True, regularization: float = 1e-6, **config: Any): + """ + Initialize fast RQ metric. + + Args: + relative: Whether to normalize by trace(C) + regularization: Regularization for covariance + **config: Additional configuration + """ + super().__init__(relative=relative, regularization=regularization, **config) + + def compute( + self, + inputs: Optional[torch.Tensor] = None, + weights: Optional[torch.Tensor] = None, + outputs: Optional[torch.Tensor] = None, + **kwargs: Any, + ) -> torch.Tensor: + """ + Compute fast RQ using GAP for CNN inputs. + + Args: + inputs: Input activations [B, C_in, H, W] or [B, C_in] + weights: Conv weights [C_out, C_in, k_H, k_W] or [C_out, C_in] + outputs: Not used + + Returns: + RQ values for each output channel [C_out] + """ + if weights is None: + raise ValueError("FastRQ requires weights") + if inputs is None: + raise ValueError("FastRQ requires inputs") + + # Handle CNN inputs: [B, C_in, H, W] -> GAP -> [B, C_in] + if inputs.ndim == 4: + # Global Average Pooling over spatial dimensions + inputs = inputs.mean(dim=(2, 3)) # [B, C_in] + elif inputs.ndim == 3: + # Patchwise input [B, F, P] -> average over patches + inputs = inputs.mean(dim=2) # [B, F] + + # Handle Conv weights: [C_out, C_in, k_H, k_W] -> sum over kernel -> [C_out, C_in] + if weights.ndim == 4: + # Sum over kernel dimensions to get channel-wise weights + weights = weights.sum(dim=(2, 3)) # [C_out, C_in] + elif weights.ndim > 2: + weights = weights.reshape(weights.shape[0], -1) + + # Now use standard 2D RQ computation + return super().compute(inputs=inputs, weights=weights, 
**kwargs) + + +@register_metric("rq_spatial", aliases=["rayleigh_quotient_spatial"]) +class SpatialRayleighQuotient(RayleighQuotient): + """ + Spatial Rayleigh Quotient that treats spatial locations as additional samples. + + Instead of unfolding into patches, this version: + 1. Reshapes [B, C, H, W] -> [B*H*W, C] + 2. Computes channel-wise covariance over all spatial locations + 3. Computes RQ using kernel-summed weights + + This is FASTER than patchwise and captures spatial variation as sample diversity. + + Trade-off vs GAP: + - GAP averages spatially, losing spatial variance information + - Spatial treats each location as a sample, preserving variance + + Trade-off vs Patchwise: + - Patchwise captures kernel-local correlations + - Spatial captures global channel correlations only + + Example: + >>> rq_spatial = SpatialRayleighQuotient() + >>> scores = rq_spatial.compute(inputs=inputs, weights=conv.weight) + """ + + def compute( + self, + inputs: Optional[torch.Tensor] = None, + weights: Optional[torch.Tensor] = None, + outputs: Optional[torch.Tensor] = None, + **kwargs: Any, + ) -> torch.Tensor: + """ + Compute spatial RQ. 
+ + Args: + inputs: [B, C_in, H, W] or [B, C_in] + weights: [C_out, C_in, k_H, k_W] or [C_out, C_in] + outputs: Not used + + Returns: + RQ values [C_out] + """ + if weights is None: + raise ValueError("SpatialRQ requires weights") + if inputs is None: + raise ValueError("SpatialRQ requires inputs") + + # Handle CNN inputs: [B, C, H, W] -> [B*H*W, C] + if inputs.ndim == 4: + B, C, H, W = inputs.shape + # Permute to [B, H, W, C] then reshape to [B*H*W, C] + inputs = inputs.permute(0, 2, 3, 1).reshape(-1, C) + elif inputs.ndim == 3: + # Patchwise input [B, F, P] -> transpose and reshape [B*P, F] + B, F, P = inputs.shape + inputs = inputs.permute(0, 2, 1).reshape(-1, F) + + # Handle Conv weights: sum over kernel + if weights.ndim == 4: + weights = weights.sum(dim=(2, 3)) + elif weights.ndim > 2: + weights = weights.reshape(weights.shape[0], -1) + + # Standard 2D RQ on [samples, channels] + return super().compute(inputs=inputs, weights=weights, **kwargs) diff --git a/src/alignment/metrics/spectral/spectral_alignment.py b/src/alignment/metrics/spectral/spectral_alignment.py index fb209c86..7f8ae9ee 100644 --- a/src/alignment/metrics/spectral/spectral_alignment.py +++ b/src/alignment/metrics/spectral/spectral_alignment.py @@ -28,6 +28,18 @@ def __init__(self, normalize: bool = True): super().__init__() self.normalize = normalize + @property + def requires_inputs(self) -> bool: + return False + + @property + def requires_weights(self) -> bool: + return True + + @property + def requires_outputs(self) -> bool: + return False + def compute( self, inputs: Optional[torch.Tensor] = None, weights: Optional[torch.Tensor] = None, outputs: Optional[torch.Tensor] = None, **kwargs ) -> torch.Tensor: @@ -88,6 +100,18 @@ def __init__(self, p: float = 2.0, top_k: Optional[int] = None): self.top_k = top_k self._reference_eigenvalues = None + @property + def requires_inputs(self) -> bool: + return False + + @property + def requires_weights(self) -> bool: + return True + + @property + def 
requires_outputs(self) -> bool: + return False + def set_reference(self, weights: torch.Tensor): """Set reference weight matrix for comparison.""" self._reference_eigenvalues = self._compute_eigenvalues(weights) @@ -167,6 +191,18 @@ def __init__(self, n_components: int = 5, n_clusters: int = 10): self.n_components = n_components self.n_clusters = n_clusters + @property + def requires_inputs(self) -> bool: + return False + + @property + def requires_weights(self) -> bool: + return True + + @property + def requires_outputs(self) -> bool: + return True + def compute( self, inputs: Optional[torch.Tensor] = None, weights: Optional[torch.Tensor] = None, outputs: Optional[torch.Tensor] = None, **kwargs ) -> torch.Tensor: @@ -232,6 +268,18 @@ def __init__(self, max_iterations: int = 100, tolerance: float = 1e-6): self.max_iterations = max_iterations self.tolerance = tolerance + @property + def requires_inputs(self) -> bool: + return False + + @property + def requires_weights(self) -> bool: + return True + + @property + def requires_outputs(self) -> bool: + return False + def compute( self, inputs: Optional[torch.Tensor] = None, weights: Optional[torch.Tensor] = None, outputs: Optional[torch.Tensor] = None, **kwargs ) -> torch.Tensor: diff --git a/src/alignment/metrics/spectral/spectral_classic.py b/src/alignment/metrics/spectral/spectral_classic.py index 68c4c9ce..b66fa761 100644 --- a/src/alignment/metrics/spectral/spectral_classic.py +++ b/src/alignment/metrics/spectral/spectral_classic.py @@ -37,6 +37,18 @@ def __init__(self, n_components: Optional[int] = None, normalize: bool = True, e self.normalize = normalize self.epsilon = epsilon + @property + def requires_inputs(self) -> bool: + return True + + @property + def requires_weights(self) -> bool: + return True + + @property + def requires_outputs(self) -> bool: + return False + def compute(self, inputs: torch.Tensor, weights: torch.Tensor, outputs: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: """ Compute 
spectral alignment scores. @@ -112,6 +124,18 @@ def __init__(self, epsilon: float = 1e-8): super().__init__() self.epsilon = epsilon + @property + def requires_inputs(self) -> bool: + return False + + @property + def requires_weights(self) -> bool: + return True + + @property + def requires_outputs(self) -> bool: + return False + def compute(self, inputs: torch.Tensor, weights: torch.Tensor, outputs: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: """ Compute spectral norm ratio for each layer. @@ -160,6 +184,18 @@ def __init__(self, temperature: float = 1.0, epsilon: float = 1e-8): self.temperature = temperature self.epsilon = epsilon + @property + def requires_inputs(self) -> bool: + return True + + @property + def requires_weights(self) -> bool: + return True + + @property + def requires_outputs(self) -> bool: + return False # Computed from inputs and weights if not provided + def compute(self, inputs: torch.Tensor, weights: torch.Tensor, outputs: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: """ Compute eigenvalue entropy for neuron activations. @@ -230,6 +266,18 @@ def __init__(self, n_clusters: int = 5, similarity_type: str = "correlation", ep self.similarity_type = similarity_type self.epsilon = epsilon + @property + def requires_inputs(self) -> bool: + return False + + @property + def requires_weights(self) -> bool: + return True + + @property + def requires_outputs(self) -> bool: + return False + def compute(self, inputs: torch.Tensor, weights: torch.Tensor, outputs: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: """ Compute spectral clustering scores. diff --git a/src/alignment/models/README.md b/src/alignment/models/README.md new file mode 100644 index 00000000..ae7cb763 --- /dev/null +++ b/src/alignment/models/README.md @@ -0,0 +1,137 @@ +# Models Module + +Model wrappers and loaders for the alignment metrics framework. 
+ +## Quick Start + +```python +# Load a pretrained vision model +from alignment.models import ModelWrapper +import torchvision.models as tvm + +model = tvm.resnet18(pretrained=True) +wrapper = ModelWrapper(model) + +# Forward pass with activation capture +outputs, activations = wrapper.forward_with_activations(input_batch) +# activations: {'layer1_input': ..., 'layer1_output': ..., ...} + +# Get weights for alignment computation +weights = wrapper.get_layer_weights() +# weights: {'layer1': tensor[out_features, in_features], ...} +``` + +## Components + +### Model Wrappers + +| Class | Purpose | When to Use | +|-------|---------|-------------| +| `ModelWrapper` | General-purpose wrapper | Default for most experiments | +| `AlignmentNetwork` | Backward-compatible wrapper | Legacy code compatibility | +| `TransformerWrapperEnhanced` | Transformer with Q/K/V tracking | Per-head attention analysis | +| `LLaMAWrapper` | LLaMA-specific wrapper | FFN/attention analysis for LLaMA models | + +### Model Loaders (hub.py) + +Use these in YAML configs via the model registry: + +```yaml +# Vision model from torchvision +model: + name: "resnet18" # Shorthand - auto-loads via torchvision + +# LLM from Hugging Face +model: + name: "hf_causal_lm" + model_id: "meta-llama/Llama-3.1-8B" + dtype: "bfloat16" + device_map: "auto" +``` + +| Loader | Description | +|--------|-------------| +| `TorchvisionModel` | ResNet, VGG, MobileNet, etc. 
from torchvision | +| `TIMMModel` | Any model from timm library | +| `HFVisionModel` | Vision transformers from Hugging Face | +| `HFCausalLM` | Causal LMs (LLaMA, Mistral, GPT) from Hugging Face | + +### Custom Architectures (architectures/) + +Simple models for experiments: + +```python +from alignment.models import MLP, CNN2P2, create_model + +# Create MLP for MNIST +model = create_model('mlp', 'mnist', hidden_dims=[300, 200]) + +# Create CNN for CIFAR-10 +model = create_model('cnn2p2', 'cifar10') +``` + +### Hook Management (hooks.py) + +Automatic hook lifecycle management: + +```python +from alignment.models.hooks import HookManager + +hook_mgr = HookManager() + +# Temporary hooks with automatic cleanup +with hook_mgr.temporary_hooks(model, ['layer1', 'layer2']) as cache: + output = model(input_batch) + layer1_output = cache['layer1_output'] + layer2_input = cache['layer2_input'] +# Hooks automatically removed after context + +# Or use PersistentHookManager for long-running tracking +from alignment.models.hooks import PersistentHookManager + +persistent_mgr = PersistentHookManager() +persistent_mgr.register_persistent_hooks(model, ['layer1']) +# ... multiple forward passes ... 
+persistent_mgr.cleanup() # Manual cleanup required +``` + +## Integration with Experiments + +The experiment runner automatically wraps models: + +```python +# In experiments/base.py +self.wrapped_model = ModelWrapper(self.model, **wrapper_kwargs) + +# Activation capture via service layer +from alignment.services.activation_capture import ActivationCaptureService +service = ActivationCaptureService(wrapped_model) +data = service.capture(input_batch) +``` + +## Layer Auto-Discovery + +Wrappers automatically discover trackable layers: + +```python +wrapper = ModelWrapper(model) +print(wrapper.tracked_layers) +# ['layer1.0.conv1', 'layer1.0.conv2', ..., 'fc'] + +# Or specify explicitly +wrapper = ModelWrapper(model, tracked_layers=['layer4', 'fc']) +``` + +## CNN Preprocessing Modes + +For convolutional layers, choose preprocessing: + +| Mode | Description | Memory | Best For | +|------|-------------|--------|----------| +| `unfold` | Unfold with kernel params | High | Exact RQ computation | +| `patchwise` | Keep patch structure | Medium | Patch-level analysis | +| `flatten` | Simple flatten | Low | Quick experiments | + +```python +wrapper = ModelWrapper(model, preprocessing_mode='unfold') +``` diff --git a/src/alignment/pruning/README.md b/src/alignment/pruning/README.md index d260183d..e05389cf 100644 --- a/src/alignment/pruning/README.md +++ b/src/alignment/pruning/README.md @@ -1,6 +1,6 @@ # Pruning Module -Neural network pruning strategies. +Neural network pruning strategies and infrastructure. ## Strategies @@ -12,9 +12,42 @@ Neural network pruning strategies. 
### Alignment-Based - `AlignmentPruning` - Prune by alignment score - `HybridPruning` - Combine magnitude and alignment +- `GlobalAlignmentPruning` - Global alignment-based pruning +- `CascadingAlignmentPruning` - Sequential layer pruning with score recomputation + +### Gradient-Based +- `GradientPruning` - Prune by gradient magnitude +- `FisherPruning` - Fisher information-based pruning +- `MomentumPruning` - Momentum-based pruning + +### Eigenvector-Based +- `EigenvectorPruning` - PCA-based pruning (prune low-variance neurons) + +### Movement-Based (Sanh et al. NeurIPS 2020) +- `MovementPruning` - Prune weights moving toward zero during training +- `AdaptiveMovementPruning` - Adaptive movement pruning with auto-tuned amounts + +### Adaptive Sensitivity-Based +- `AdaptiveSensitivityPruning` - Layer-adaptive pruning based on sensitivity analysis ### Random - `RandomPruning` - Random baseline +- `LayerwiseRandomPruning` - Per-layer random pruning +- `BernoulliPruning` - Bernoulli mask pruning + +### Cluster-Aware +- `ClusterAwarePruning` - Cluster-based structured pruning +- `CompositePruning` - Composite pruning strategies + +### LLM Baselines +- `WandaPruning` - Sun et al. 
2023 +- `SparseGPTPruning` - Frantar & Alistarh 2023 + +### Parallel/Advanced +- `ParallelModePruning` - Multiple modes simultaneously +- `TensorizedPruning` - Tensorized pruning operations +- `AsyncParallelPruning` - Async parallel pruning +- `ParallelBatchPruning` - Batch parallel pruning ## Usage @@ -26,6 +59,112 @@ strategy = MagnitudePruning(config) mask = strategy.prune(layer, amount=0.5) ``` +### Eigenvector Pruning + +```python +from alignment.pruning import EigenvectorPruning, PruningConfig + +config = PruningConfig(amount=0.5, structured=True, pruning_mode='low') +strategy = EigenvectorPruning(config=config) + +# Prune neurons with low eigenvalue contribution (low variance) +mask = strategy.prune(layer, inputs=activations) +``` + +### Movement Pruning + +```python +from alignment.pruning import MovementPruning + +strategy = MovementPruning() + +# During training, track weight movement +for batch in train_loader: + loss.backward() + strategy.update_movement_history(model) # Call before optimizer.step() + optimizer.step() + +# After training, prune weights moving toward zero +mask = strategy.prune(layer, amount=0.5) +``` + +### Adaptive Sensitivity Pruning + +```python +from alignment.pruning import AdaptiveSensitivityPruning + +strategy = AdaptiveSensitivityPruning( + target_sparsity=0.7, + metric='rayleigh_quotient', + sensitivity_method='activation_variance', # FAST - single forward pass + min_amount=0.1, + max_amount=0.9 +) + +# Compute layer sensitivities and prune adaptively +sensitivities = strategy.compute_all_sensitivities( + model, + layer_names, + data_loader=val_loader # For fast methods +) + +# Print report showing per-layer sensitivity and pruning amounts +strategy.print_sensitivity_report() + +# Apply adaptive pruning +masks = strategy.prune_adaptive(model, layer_names, eval_fn=None, inputs_per_layer=None) +``` + +#### Sensitivity Methods + +| Method | Speed | Accuracy | Requirements | +|--------|-------|----------|--------------| +| 
`perturbation` | Slow | High | `eval_fn` | +| `masking` | Slow | High | `eval_fn` | +| `activation_variance` | **Fast** | Medium | `data_loader` | +| `gradient` | **Fast** | Medium | `data_loader` | +| `fisher` | Medium | Medium-High | `data_loader` | +| `weight_magnitude` | **Fastest** | Low | None | + +**Recommendation**: Use `activation_variance` for a good speed/accuracy tradeoff. + +### Cascading Pruning (Progressive) + +```python +from alignment.pruning import CascadingAlignmentPruning, PruningConfig + +config = PruningConfig(amount=0.5, structured=True) +strategy = CascadingAlignmentPruning( + metric='rayleigh_quotient', + direction='forward', # or 'backward' + config=config +) + +# Prune layer by layer, recomputing scores after each +masks = strategy.prune_model(model, get_layer_inputs_fn) +``` + +## Using the Pipeline + +```python +from alignment.pruning import run_pruning_pipeline, PruningPipelineOptions + +options = PruningPipelineOptions( + distribution="uniform", # or "global_threshold" + dependency_aware=True, # Handle dependent layers + min_amount=0.0, + max_amount=0.95 +) + +result = run_pruning_pipeline( + model=model, + layer_scores=scores_dict, + target_sparsity=0.5, + selection_mode="low", + options=options +) +``` + ## Configuration ```python @@ -34,6 +173,7 @@ config = PruningConfig( structured=False, # Structured vs unstructured iterative=False, # Single shot vs iterative global_pruning=False, # Global vs layer-wise + pruning_mode='low', # 'low', 'high', or 'random' ) ``` @@ -41,3 +181,11 @@ config = PruningConfig( - **Unstructured**: Remove individual weights (sparse matrices) - **Structured**: Remove entire neurons/channels (dense matrices) + +## Module Organization + +- `base.py` - Base classes (`BasePruningStrategy`, `PruningConfig`) +- `pipeline.py` - Shared pruning pipeline (`run_pruning_pipeline`) +- `dependency_aware.py` - Handle dependent layers (BatchNorm, etc.) 
+- `distribution.py` - Layer sparsity distribution strategies +- `strategies/` - All pruning strategy implementations diff --git a/src/alignment/pruning/__init__.py b/src/alignment/pruning/__init__.py index 249d41ab..c0d354b2 100644 --- a/src/alignment/pruning/__init__.py +++ b/src/alignment/pruning/__init__.py @@ -44,20 +44,26 @@ from typing import Optional, Type, Union from .base import BasePruningStrategy, IterativePruningStrategy, PruningConfig +from .pipeline import PruningPipelineOptions, run_pruning_pipeline from .strategies import ( + AdaptiveMovementPruning, + AdaptiveSensitivityPruning, AlignmentPruning, AsyncParallelPruning, BernoulliPruning, CascadingAlignmentPruning, + EigenvectorPruning, FisherPruning, GlobalAlignmentPruning, GlobalMagnitudePruning, GradientPruning, HybridPruning, IterativeMagnitudePruning, + LayerSensitivity, LayerwiseRandomPruning, MagnitudePruning, MomentumPruning, + MovementPruning, ParallelModePruning, RandomPruning, SparseGPTPruning, @@ -82,6 +88,13 @@ "hybrid": HybridPruning, "global_alignment": GlobalAlignmentPruning, "cascading_alignment": CascadingAlignmentPruning, + # Eigenvector-based (PCA pruning) + "eigenvector": EigenvectorPruning, + # Movement-based (Sanh et al. NeurIPS 2020) + "movement": MovementPruning, + "adaptive_movement": AdaptiveMovementPruning, + # Adaptive sensitivity-based + "adaptive_sensitivity": AdaptiveSensitivityPruning, # Random strategies (kept for backward compatibility) # Note: Consider using selection_mode='random' instead "random": RandomPruning, @@ -158,6 +171,14 @@ def list_pruning_strategies() -> list: "HybridPruning", "GlobalAlignmentPruning", "CascadingAlignmentPruning", + # Eigenvector (PCA) strategy + "EigenvectorPruning", + # Movement-based (Sanh et al. 
2020) + "MovementPruning", + "AdaptiveMovementPruning", + # Adaptive sensitivity-based + "AdaptiveSensitivityPruning", + "LayerSensitivity", # Random strategies "RandomPruning", "LayerwiseRandomPruning", @@ -172,4 +193,7 @@ def list_pruning_strategies() -> list: # Functions "get_pruning_strategy", "list_pruning_strategies", + # Pipeline helpers + "PruningPipelineOptions", + "run_pruning_pipeline", ] diff --git a/src/alignment/pruning/baselines.py b/src/alignment/pruning/baselines.py deleted file mode 100644 index 7d1fb132..00000000 --- a/src/alignment/pruning/baselines.py +++ /dev/null @@ -1,369 +0,0 @@ -""" -Pruning Baselines for Comparison with SCAR - -Implements: -- Wanda: Weight × Activation pruning (Sun et al., 2023) -- SparseGPT-style: Second-order one-shot pruning (Frantar & Alistarh, 2023) -- Magnitude: Simple weight magnitude pruning -""" - -import torch -import torch.nn as nn -from typing import Dict, List, Optional, Tuple, Any -import logging - -logger = logging.getLogger(__name__) - - -class WandaPruning: - """ - Wanda: Pruning by Weights and Activations - - Reference: Sun et al., "A Simple and Effective Pruning Approach for Large Language Models" (2023) - - Key idea: Score = |W| * ||X||_2 (weight magnitude × activation norm) - """ - - def __init__( - self, - sparsity: float = 0.5, - structured: bool = True, - prune_dim: int = 0, # 0 = prune rows (output neurons), 1 = prune columns (input features) - ): - self.sparsity = sparsity - self.structured = structured - self.prune_dim = prune_dim - - def compute_scores( - self, - weight: torch.Tensor, - activations: torch.Tensor, - ) -> torch.Tensor: - """ - Compute Wanda importance scores. 
- - Args: - weight: Weight matrix [out_features, in_features] - activations: Input activations [batch, seq_len, in_features] or [batch, in_features] - - Returns: - Importance scores per neuron (structured) or per weight (unstructured) - """ - # Flatten activations if needed - if activations.ndim == 3: - activations = activations.reshape(-1, activations.shape[-1]) - - # Compute activation norms (L2 norm across samples) - activation_norms = torch.norm(activations, p=2, dim=0) # [in_features] - - if self.structured: - # Structured: score per output neuron - # Score_i = sum_j |W_ij| * ||X_j||_2 - weight_abs = torch.abs(weight) # [out_features, in_features] - scores = torch.matmul(weight_abs, activation_norms) # [out_features] - else: - # Unstructured: score per weight - # Score_ij = |W_ij| * ||X_j||_2 - scores = torch.abs(weight) * activation_norms.unsqueeze(0) # [out_features, in_features] - - return scores - - def get_pruning_mask( - self, - scores: torch.Tensor, - sparsity: Optional[float] = None, - ) -> torch.Tensor: - """ - Create pruning mask based on scores. - - Returns: - Boolean mask where True = keep, False = prune - """ - if sparsity is None: - sparsity = self.sparsity - - if self.structured: - # Prune neurons with lowest scores - num_prune = int(sparsity * scores.numel()) - threshold = torch.kthvalue(scores, num_prune).values - mask = scores > threshold - else: - # Prune weights with lowest scores - flat_scores = scores.flatten() - num_prune = int(sparsity * flat_scores.numel()) - threshold = torch.kthvalue(flat_scores, num_prune).values - mask = scores > threshold - - return mask - - -class SparseGPTStylePruning: - """ - Simplified SparseGPT-style pruning using second-order information. 
- - Reference: Frantar & Alistarh, "SparseGPT: Massive Language Models Can Be - Accurately Pruned in One-Shot" (2023) - - Key idea: Use Hessian approximation to minimize reconstruction error - """ - - def __init__( - self, - sparsity: float = 0.5, - structured: bool = True, - block_size: int = 128, - percdamp: float = 0.01, - ): - self.sparsity = sparsity - self.structured = structured - self.block_size = block_size - self.percdamp = percdamp - - def compute_hessian_inverse( - self, - activations: torch.Tensor, - ) -> torch.Tensor: - """ - Compute inverse Hessian approximation from activations. - - H ≈ X^T X / n (Fisher approximation) - """ - # Flatten activations - if activations.ndim == 3: - activations = activations.reshape(-1, activations.shape[-1]) - - n_samples = activations.shape[0] - - # Compute H = X^T X / n - H = torch.matmul(activations.T, activations) / n_samples - - # Add damping for numerical stability - damp = self.percdamp * torch.diag(H).mean() - H = H + damp * torch.eye(H.shape[0], device=H.device, dtype=H.dtype) - - # Compute inverse (or pseudo-inverse for stability) - try: - H_inv = torch.linalg.inv(H) - except: - H_inv = torch.linalg.pinv(H) - - return H_inv - - def compute_scores( - self, - weight: torch.Tensor, - activations: torch.Tensor, - ) -> torch.Tensor: - """ - Compute SparseGPT-style importance scores. 
- - Score_i = W_i^2 / [H^{-1}]_{ii} (optimal brain surgeon criterion) - """ - H_inv = self.compute_hessian_inverse(activations) - - if self.structured: - # For structured pruning, aggregate across input dimension - # Score for row i = sum_j W_ij^2 / [H^{-1}]_{jj} - diag_H_inv = torch.diag(H_inv) # [in_features] - weight_sq = weight ** 2 # [out_features, in_features] - scores = torch.sum(weight_sq / diag_H_inv.unsqueeze(0), dim=1) # [out_features] - else: - # Unstructured: per-weight scores - diag_H_inv = torch.diag(H_inv) - scores = (weight ** 2) / diag_H_inv.unsqueeze(0) - - return scores - - def get_pruning_mask( - self, - scores: torch.Tensor, - sparsity: Optional[float] = None, - ) -> torch.Tensor: - """Create pruning mask (same as Wanda).""" - if sparsity is None: - sparsity = self.sparsity - - if self.structured: - num_prune = int(sparsity * scores.numel()) - threshold = torch.kthvalue(scores, max(1, num_prune)).values - mask = scores > threshold - else: - flat_scores = scores.flatten() - num_prune = int(sparsity * flat_scores.numel()) - threshold = torch.kthvalue(flat_scores, max(1, num_prune)).values - mask = scores > threshold - - return mask - - -class MagnitudePruning: - """ - Simple magnitude-based pruning baseline. 
- - Score = ||W_i||_p (L1 or L2 norm of weight row/column) - """ - - def __init__( - self, - sparsity: float = 0.5, - structured: bool = True, - norm_type: int = 2, # L1 or L2 - prune_dim: int = 0, - ): - self.sparsity = sparsity - self.structured = structured - self.norm_type = norm_type - self.prune_dim = prune_dim - - def compute_scores( - self, - weight: torch.Tensor, - activations: Optional[torch.Tensor] = None, # Not used, but kept for API consistency - ) -> torch.Tensor: - """Compute magnitude-based importance scores.""" - if self.structured: - # Norm per row (output neuron) - scores = torch.norm(weight, p=self.norm_type, dim=1) - else: - # Per-weight magnitude - scores = torch.abs(weight) - - return scores - - def get_pruning_mask( - self, - scores: torch.Tensor, - sparsity: Optional[float] = None, - ) -> torch.Tensor: - """Create pruning mask.""" - if sparsity is None: - sparsity = self.sparsity - - if self.structured: - num_prune = int(sparsity * scores.numel()) - threshold = torch.kthvalue(scores, max(1, num_prune)).values - mask = scores > threshold - else: - flat_scores = scores.flatten() - num_prune = int(sparsity * flat_scores.numel()) - threshold = torch.kthvalue(flat_scores, max(1, num_prune)).values - mask = scores > threshold - - return mask - - -def apply_structured_pruning( - module: nn.Linear, - mask: torch.Tensor, - prune_dim: int = 0, -) -> nn.Linear: - """ - Apply structured pruning to a Linear layer. 
- - Args: - module: Linear layer to prune - mask: Boolean mask (True = keep) - prune_dim: 0 = prune output neurons, 1 = prune input features - - Returns: - New Linear layer with reduced dimensions - """ - keep_indices = torch.where(mask)[0] - - if prune_dim == 0: - # Prune output neurons (rows) - new_out_features = keep_indices.numel() - new_weight = module.weight.data[keep_indices, :] - new_bias = module.bias.data[keep_indices] if module.bias is not None else None - - new_module = nn.Linear(module.in_features, new_out_features, bias=module.bias is not None) - new_module.weight.data = new_weight - if new_bias is not None: - new_module.bias.data = new_bias - else: - # Prune input features (columns) - new_in_features = keep_indices.numel() - new_weight = module.weight.data[:, keep_indices] - - new_module = nn.Linear(new_in_features, module.out_features, bias=module.bias is not None) - new_module.weight.data = new_weight - if module.bias is not None: - new_module.bias.data = module.bias.data - - return new_module - - -def compare_pruning_methods( - model: nn.Module, - calibration_data: torch.Tensor, - sparsity_levels: List[float] = [0.3, 0.5, 0.7], - methods: List[str] = ["magnitude", "wanda", "sparsegpt"], -) -> Dict[str, Dict[float, Dict[str, Any]]]: - """ - Compare different pruning methods on a model. 
- - Args: - model: Model to prune - calibration_data: Data for computing activation statistics - sparsity_levels: List of sparsity levels to test - methods: List of pruning methods to compare - - Returns: - Dictionary with results per method and sparsity level - """ - results = {method: {} for method in methods} - - # Initialize pruning methods - pruners = { - "magnitude": MagnitudePruning(structured=True), - "wanda": WandaPruning(structured=True), - "sparsegpt": SparseGPTStylePruning(structured=True), - } - - for method in methods: - if method not in pruners: - logger.warning(f"Unknown pruning method: {method}") - continue - - pruner = pruners[method] - - for sparsity in sparsity_levels: - pruner.sparsity = sparsity - - # Collect scores for all layers - layer_scores = {} - for name, module in model.named_modules(): - if isinstance(module, nn.Linear): - # Get activations for this layer (would need hooks in practice) - # This is a simplified version - scores = pruner.compute_scores( - module.weight.data, - calibration_data, - ) - layer_scores[name] = { - "scores": scores, - "mask": pruner.get_pruning_mask(scores, sparsity), - } - - results[method][sparsity] = { - "layer_scores": layer_scores, - "sparsity": sparsity, - } - - return results - - -# Registry for easy access -PRUNING_METHODS = { - "magnitude": MagnitudePruning, - "wanda": WandaPruning, - "sparsegpt": SparseGPTStylePruning, -} - - -def get_pruning_method(name: str, **kwargs) -> Any: - """Get a pruning method by name.""" - if name not in PRUNING_METHODS: - raise ValueError(f"Unknown pruning method: {name}. 
Available: {list(PRUNING_METHODS.keys())}") - return PRUNING_METHODS[name](**kwargs) - diff --git a/src/alignment/pruning/dependency_aware.py b/src/alignment/pruning/dependency_aware.py index e6eace68..fca4005c 100644 --- a/src/alignment/pruning/dependency_aware.py +++ b/src/alignment/pruning/dependency_aware.py @@ -131,7 +131,14 @@ def __init__(self, model: nn.Module): self.model = model self.dependency_graph = DependencyGraph(model) - def prune(self, layer_scores: Dict[str, torch.Tensor], amount: float, mode: str = "low", dry_run: bool = False) -> Dict[str, Any]: + def prune( + self, + layer_scores: Dict[str, torch.Tensor], + amount: float, + mode: str = "low", + dry_run: bool = False, + per_layer_amounts: Optional[Dict[str, float]] = None, + ) -> Dict[str, Any]: """ Apply structured pruning with dependency awareness. @@ -148,7 +155,7 @@ def prune(self, layer_scores: Dict[str, torch.Tensor], amount: float, mode: str - 'validation': Shape validation results """ # 1. Create initial masks from scores - initial_masks = self._create_initial_masks(layer_scores, amount, mode) + initial_masks = self._create_initial_masks(layer_scores, amount, mode, per_layer_amounts) # 2. 
Propagate masks to handle dependencies propagated_masks = self._propagate_masks(initial_masks) @@ -170,7 +177,13 @@ def prune(self, layer_scores: Dict[str, torch.Tensor], amount: float, mode: str return {"masks": propagated_masks, "stats": stats, "validation": validation} - def _create_initial_masks(self, layer_scores: Dict[str, torch.Tensor], amount: float, mode: str) -> Dict[str, torch.Tensor]: + def _create_initial_masks( + self, + layer_scores: Dict[str, torch.Tensor], + default_amount: float, + mode: str, + per_layer_amounts: Optional[Dict[str, float]] = None, + ) -> Dict[str, torch.Tensor]: """Create initial output masks from importance scores.""" from ..services.mask_ops import MaskOperations @@ -181,8 +194,11 @@ def _create_initial_masks(self, layer_scores: Dict[str, torch.Tensor], amount: f logger.warning(f"Layer {layer_name} not in dependency graph, skipping") continue + layer_amount = default_amount + if per_layer_amounts and layer_name in per_layer_amounts: + layer_amount = per_layer_amounts[layer_name] # Create structured mask (output neurons/channels) - mask = MaskOperations.create_structured_mask(scores, amount=amount, mode=mode) + mask = MaskOperations.create_structured_mask(scores, amount=layer_amount, mode=mode) initial_masks[layer_name] = mask diff --git a/src/alignment/pruning/experiments/__init__.py b/src/alignment/pruning/experiments/__init__.py deleted file mode 100644 index 2f6d7a09..00000000 --- a/src/alignment/pruning/experiments/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -""" -Pruning experiments for alignment analysis. 
-""" - -from .cascading_layer import CascadingConfig, CascadingLayerPruningExperiment -from .eigenvector_based import EigenvectorConfig, EigenvectorDropoutExperiment -from .global_pruning import GlobalDropoutConfig, GlobalDropoutExperiment -from .layer_wise import LayerIsolatedConfig, LayerIsolatedPruningExperiment - -__all__ = [ - "EigenvectorDropoutExperiment", - "EigenvectorConfig", - "CascadingLayerPruningExperiment", - "CascadingConfig", - "LayerIsolatedPruningExperiment", - "LayerIsolatedConfig", - "GlobalDropoutExperiment", - "GlobalDropoutConfig", -] diff --git a/src/alignment/pruning/experiments/cascading_layer.py b/src/alignment/pruning/experiments/cascading_layer.py deleted file mode 100644 index ff30e0df..00000000 --- a/src/alignment/pruning/experiments/cascading_layer.py +++ /dev/null @@ -1,479 +0,0 @@ -""" -Cascading layer pruning experiment. - -This module implements progressive pruning that cascades through layers, -where pruning in earlier layers affects later layers. -""" - -import logging -from dataclasses import asdict, dataclass, field -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple - -import numpy as np -import torch - -from alignment.core.registry import register_experiment -from alignment.experiments.base import BaseExperiment, ExperimentConfig -from alignment.experiments.training_utils import convert_training_history, create_experiment_trainer, train_with_metrics -from alignment.models import ModelWrapper - -logger = logging.getLogger(__name__) - - -@dataclass -class CascadingConfig(ExperimentConfig): - """Configuration for cascading layer pruning experiment.""" - - # Dropout configuration - dropout_rates: List[float] = field(default_factory=lambda: [0.0, 0.1, 0.3, 0.5, 0.7, 0.9]) - dropout_mode: str = "scaled" # "scaled" or "unscaled" - cascade_direction: str = "forward" # "forward" or "backward" - - # Pruning configuration - pruning_metric: str = "rayleigh_quotient" - pruning_strategy: str = "low" # "low", 
"high", "random" - exclude_classification_layer: bool = True - recompute_scores: bool = True # Whether to recompute scores after each layer pruning - - # CNN preprocessing mode - cnn_mode: str = "unfold" # "unfold", "patchwise", "batch_patch_combined" - - # Training configuration - train_before_dropout: bool = True - training_epochs: int = 10 - learning_rate: float = 0.001 - optimizer: str = "adam" - - # Evaluation - eval_batches: Optional[int] = None - num_random_trials: int = 3 - - -@register_experiment("cascading_layer_pruning") -class CascadingLayerPruningExperiment(BaseExperiment): - """ - Experiment for cascading layer pruning. - - This experiment: - 1. Trains a model (if configured) - 2. Progressively prunes layers in order (forward or backward) - 3. Optionally recomputes alignment scores after each layer - 4. Evaluates performance at different pruning levels - """ - - def __init__(self, config: CascadingConfig): - """Initialize cascading layer pruning experiment.""" - super().__init__(config) - self.original_weights = {} - self.layer_order = [] - - def _get_layer_order(self) -> List[str]: - """ - Get the order of layers for cascading pruning. 
- - Returns: - List of layer names in pruning order - """ - layers = [ - name for name in self.wrapped_model.tracked_layers if not (self.config.exclude_classification_layer and "classifier" in name.lower()) - ] - - if self.config.cascade_direction == "backward": - layers = layers[::-1] - - logger.info(f"Layer pruning order ({self.config.cascade_direction}): {layers}") - return layers - - def _get_layer_type(self, layer_name: str) -> str: - """Get the type of a layer (linear, conv, etc.).""" - layer_info = self.wrapped_model.get_layer_info(layer_name) - return layer_info.get("type", "unknown").lower() - - def _get_appropriate_metric(self, layer_name: str): - """Get the appropriate metric for a layer based on its type.""" - layer_type = self._get_layer_type(layer_name) - - # Use patchwise RQ for conv layers if using RQ metric - if self.config.pruning_metric == "rayleigh_quotient" and "conv" in layer_type: - # Check if patchwise variant exists - if "rq_patchwise" in self.metrics: - logger.debug(f"Using patchwise RQ for conv layer {layer_name}") - return self.metrics["rq_patchwise"] - - # Default to configured metric - return self.metrics[self.config.pruning_metric] - - def _compute_alignment_scores(self, layer_name: str, active_masks: Optional[Dict[str, torch.Tensor]] = None) -> torch.Tensor: - """ - Compute alignment scores for a specific layer. 
- - Args: - layer_name: Name of the layer to compute scores for - active_masks: Current active masks for all layers - - Returns: - Alignment scores for the layer - """ - metric = self._get_appropriate_metric(layer_name) - scores_list = [] - - # Apply current masks if provided - if active_masks: - self._apply_masks(active_masks) - - # Compute scores - eval_batches = self.config.eval_batches or len(self.data_loader) - - for batch_idx, (inputs, targets) in enumerate(self.data_loader): - if batch_idx >= eval_batches: - break - - inputs = inputs.to(self.config.device) - - # Forward pass - _, activations = self.wrapped_model.forward_with_activations(inputs) - - # Get layer data - layer_inputs = activations.get(f"{layer_name}_input") - layer_weights = self.wrapped_model.get_layer_weights()[layer_name] - - if layer_inputs is None or layer_weights is None: - continue - - # Preprocess activations based on CNN mode - from alignment.dataops.processing import preprocess_layer_activations - - layer_modules = dict(self.wrapped_model._model.named_modules()) - preprocessed = preprocess_layer_activations( - {f"{layer_name}_input": layer_inputs}, layer_modules, mode=self.config.cnn_mode if hasattr(self.config, "cnn_mode") else None - ) - layer_inputs = preprocessed.get(f"{layer_name}_input", layer_inputs) - - # Compute metric - if hasattr(metric, "requires_outputs") and metric.requires_outputs: - layer_outputs = activations.get(f"{layer_name}_output") - if layer_outputs is not None: - preprocessed_out = preprocess_layer_activations( - {f"{layer_name}_output": layer_outputs}, - layer_modules, - mode=self.config.cnn_mode if hasattr(self.config, "cnn_mode") else None, - ) - layer_outputs = preprocessed_out.get(f"{layer_name}_output", layer_outputs) - scores = metric.compute(inputs=layer_inputs, weights=layer_weights, outputs=layer_outputs) - else: - scores = metric.compute(inputs=layer_inputs, weights=layer_weights) - - scores_list.append(scores.cpu()) - - # Restore original weights - 
if active_masks: - self._restore_original_weights() - - # Aggregate scores - if scores_list: - # Stack scores from different batches and average across batches - return torch.stack(scores_list, dim=0).mean(dim=0) - else: - logger.warning(f"No scores computed for layer {layer_name}") - return torch.zeros(1) - - def _create_layer_mask(self, scores: torch.Tensor, dropout_rate: float, strategy: str = "low", layer_name: Optional[str] = None) -> torch.Tensor: - """ - Create a dropout mask for a single layer. - - Args: - scores: Alignment scores for the layer - dropout_rate: Fraction to drop - strategy: Pruning strategy - layer_name: Optional layer name to get actual layer size - - Returns: - Boolean mask (True = keep, False = drop) - """ - # Get actual layer size if layer name provided - if layer_name is not None: - layer = self.wrapped_model.get_layer(layer_name) - if layer is not None and hasattr(layer, "weight"): - # Use actual output dimension size - if len(layer.weight.shape) >= 1: - actual_neurons = layer.weight.shape[0] - if scores.numel() != actual_neurons: - logger.warning(f"Score size ({scores.numel()}) doesn't match layer size ({actual_neurons}). 
" f"Using layer size for mask.") - # Create mask based on actual layer size - if scores.numel() == 1: - # Single score - apply uniformly - return torch.ones(actual_neurons, dtype=torch.bool) - else: - # Resize scores if possible - scores = scores[:actual_neurons] if scores.numel() > actual_neurons else scores - - # Handle scalar scores (0-d tensor) - if scores.dim() == 0: - logger.warning("Scores is a scalar, creating single-neuron mask") - return torch.ones(1, dtype=torch.bool) - - # Get number of neurons - num_neurons = scores.numel() - num_drop = int(num_neurons * dropout_rate) - - if num_drop == 0: - return torch.ones(num_neurons, dtype=torch.bool) - - # Ensure scores is 1D - scores = scores.flatten() - - if strategy == "low": - # Drop lowest scoring neurons - sorted_indices = torch.argsort(scores) - mask = torch.ones(num_neurons, dtype=torch.bool) - mask[sorted_indices[:num_drop]] = False - elif strategy == "high": - # Drop highest scoring neurons - sorted_indices = torch.argsort(scores) - mask = torch.ones(num_neurons, dtype=torch.bool) - mask[sorted_indices[-num_drop:]] = False - else: # random - mask = torch.ones(num_neurons, dtype=torch.bool) - random_indices = torch.randperm(num_neurons)[:num_drop] - mask[random_indices] = False - - return mask - - def _apply_masks(self, masks: Dict[str, torch.Tensor]): - """Apply dropout masks to layers.""" - for layer_name, mask in masks.items(): - layer = self.wrapped_model.get_layer(layer_name) - if layer is None: - continue - - # Store original weights - if layer_name not in self.original_weights: - self.original_weights[layer_name] = layer.weight.data.clone() - if hasattr(layer, "bias") and layer.bias is not None: - self.original_weights[layer_name + "_bias"] = layer.bias.data.clone() - - # Apply mask - if hasattr(layer, "weight"): - layer.weight.data = self.original_weights[layer_name].clone() - if len(layer.weight.shape) == 2: # Linear layer - # Mask output neurons (rows) - layer.weight.data[~mask] = 0 - elif 
len(layer.weight.shape) == 4: # Conv layer - # Mask output channels - # Expand mask to match weight dimensions - expanded_mask = mask.view(-1, 1, 1, 1).expand_as(layer.weight) - layer.weight.data[~expanded_mask] = 0 - - if hasattr(layer, "bias") and layer.bias is not None: - layer.bias.data = self.original_weights[layer_name + "_bias"].clone() - layer.bias.data[~mask] = 0 - - def _restore_original_weights(self): - """Restore all layers to original weights.""" - for layer_name, original_weight in self.original_weights.items(): - if "_bias" in layer_name: - continue - - layer = self.wrapped_model.get_layer(layer_name) - if layer is not None and hasattr(layer, "weight"): - layer.weight.data = original_weight.clone() - - bias_key = layer_name + "_bias" - if bias_key in self.original_weights and hasattr(layer, "bias") and layer.bias is not None: - layer.bias.data = self.original_weights[bias_key].clone() - - def _evaluate_model(self) -> Tuple[float, float]: - """Evaluate model performance.""" - self.model.eval() - total_loss = 0.0 - correct = 0 - total = 0 - - criterion = torch.nn.CrossEntropyLoss() - - with torch.no_grad(): - for batch_idx, (inputs, targets) in enumerate(self.data_loader): - if self.config.eval_batches and batch_idx >= self.config.eval_batches: - break - - inputs, targets = inputs.to(self.config.device), targets.to(self.config.device) - outputs = self.model(inputs) - - loss = criterion(outputs, targets) - total_loss += loss.item() - - _, predicted = outputs.max(1) - total += targets.size(0) - correct += predicted.eq(targets).sum().item() - - avg_loss = total_loss / (batch_idx + 1) - accuracy = 100.0 * correct / total - - return avg_loss, accuracy - - def _train_model(self) -> Dict[str, Any]: - """Train the model if configured.""" - if not self.config.train_before_dropout: - logger.info("Skipping initial training") - return {} - - logger.info(f"Training model for {self.config.training_epochs} epochs") - - # Create trainer using the unified interface - 
trainer = create_experiment_trainer(self.model, asdict(self.config), device=self.config.device) - - # Train with metrics - history = train_with_metrics(trainer, self.data_loader, val_loader=None, compute_accuracy=True) # No validation in original implementation - - # Log final metrics (trainer already logs per-epoch) - if history["train_loss"]: - final_metrics = {"train_loss": history["train_loss"][-1], "train_accuracy": history["train_metrics"][-1].get("accuracy", 0.0)} - self.log_metrics(len(history["train_loss"]) - 1, final_metrics) - - # Return training results - return convert_training_history(history) - - def _cascading_prune(self, dropout_rate: float, strategy: str = "low") -> Tuple[Dict[str, torch.Tensor], Dict[str, List[float]]]: - """ - Perform cascading pruning at a specific dropout rate. - - Args: - dropout_rate: Target dropout rate - strategy: Pruning strategy - - Returns: - Tuple of (final masks, layer-wise scores history) - """ - masks = {} - scores_history = {} - - # Process layers in order - for layer_idx, layer_name in enumerate(self.layer_order): - logger.debug(f"Processing layer {layer_idx+1}/{len(self.layer_order)}: {layer_name}") - - # Compute scores (with previously pruned layers masked) - if self.config.recompute_scores and masks: - scores = self._compute_alignment_scores(layer_name, masks) - else: - # Use initial scores - scores = self._compute_alignment_scores(layer_name) - - scores_history[layer_name] = scores.flatten().tolist() if scores.dim() > 0 else [scores.item()] - - # Create mask for this layer - mask = self._create_layer_mask(scores, dropout_rate, strategy, layer_name) - masks[layer_name] = mask - - # Log pruning info - active_neurons = int(mask.sum().item()) - total_neurons = len(mask) - logger.debug(f" Layer {layer_name}: {active_neurons}/{total_neurons} neurons active") - - return masks, scores_history - - def run(self) -> Dict[str, Any]: - """ - Run the cascading layer pruning experiment. 
- - Returns: - Dictionary containing experiment results - """ - logger.info("Starting cascading layer pruning experiment") - - # Train model - training_results = self._train_model() - - # Get layer order - self.layer_order = self._get_layer_order() - - # Initialize results - results = { - "config": self.config.to_dict(), - "layer_order": self.layer_order, - "dropout_rates": self.config.dropout_rates, - "accuracies": {"low": [], "high": [], "random": []}, - "losses": {"low": [], "high": [], "random": []}, - "layer_scores": {}, - "cascade_masks": {}, - "training_results": training_results, # Include training results - } - - # Evaluate at each dropout rate - for dropout_idx, dropout_rate in enumerate(self.config.dropout_rates): - logger.info(f"Evaluating dropout rate: {dropout_rate}") - - # Store masks for this dropout rate - cascade_info = {} - - # Evaluate each strategy - for strategy in ["low", "high", "random"]: - if strategy == "random": - # Average over multiple trials - trial_losses = [] - trial_accs = [] - - for trial in range(self.config.num_random_trials): - # Perform cascading pruning - masks, _ = self._cascading_prune(dropout_rate, strategy) - - # Apply masks - self._apply_masks(masks) - - # Evaluate - loss, acc = self._evaluate_model() - trial_losses.append(loss) - trial_accs.append(acc) - - # Restore weights - self._restore_original_weights() - - avg_loss = np.mean(trial_losses) - avg_acc = np.mean(trial_accs) - - else: - # Perform cascading pruning - masks, scores_history = self._cascading_prune(dropout_rate, strategy) - - # Store cascade info - if strategy == "low": # Store detailed info only for one strategy - cascade_info = { - "masks": {k: v.tolist() for k, v in masks.items()}, - "scores": scores_history, - "active_neurons": {k: int(v.sum().item()) for k, v in masks.items()}, - } - - # Apply masks - self._apply_masks(masks) - - # Evaluate - avg_loss, avg_acc = self._evaluate_model() - - # Restore weights - self._restore_original_weights() - - # 
Store results - results["losses"][strategy].append(avg_loss) - results["accuracies"][strategy].append(avg_acc) - - logger.info(f" {strategy}: Loss={avg_loss:.4f}, Accuracy={avg_acc:.2f}%") - - # Log metrics - self.log_metrics( - dropout_idx * 3 + ["low", "high", "random"].index(strategy), - {f"{strategy}_loss": avg_loss, f"{strategy}_accuracy": avg_acc, "dropout_rate": dropout_rate}, - ) - - # Store cascade info - results["cascade_masks"][f"dropout_{dropout_rate}"] = cascade_info - - # Save results - self.results.update(results) - self.save_results() - - # Save final checkpoint - self.save_checkpoint(step=len(self.config.dropout_rates), metrics={"final_results": results}) - - logger.info("Cascading layer pruning experiment completed") - - return results diff --git a/src/alignment/pruning/experiments/eigenvector_based.py b/src/alignment/pruning/experiments/eigenvector_based.py deleted file mode 100644 index d134d00d..00000000 --- a/src/alignment/pruning/experiments/eigenvector_based.py +++ /dev/null @@ -1,417 +0,0 @@ -""" -Eigenvector dropout experiment. - -This module implements pruning based on PCA/eigendecomposition, -dropping neurons based on their contribution to principal components. 
-""" - -import logging -from dataclasses import asdict, dataclass, field -from typing import Any, Dict, List, Optional, Tuple - -import numpy as np -import torch - -from alignment.core.registry import register_experiment -from alignment.experiments.base import BaseExperiment, ExperimentConfig -from alignment.experiments.training_utils import convert_training_history, create_experiment_trainer, train_with_metrics - -logger = logging.getLogger(__name__) - - -@dataclass -class EigenvectorConfig(ExperimentConfig): - """Configuration for eigenvector dropout experiment.""" - - # Dropout configuration - dropout_rates: List[float] = field(default_factory=lambda: [0.0, 0.1, 0.3, 0.5, 0.7, 0.9]) - dropout_mode: str = "scaled" - - # Eigenvector configuration - compute_layer_pca: bool = True # Whether to compute PCA per layer - n_components_ratio: float = 0.99 # Variance ratio to keep - eigenvector_strategy: str = "low" # "low" = drop low eigenvalue components, "high" = drop high - - # Training configuration - train_before_dropout: bool = True - training_epochs: int = 10 - learning_rate: float = 0.001 - optimizer: str = "adam" - - # Evaluation - eval_batches: Optional[int] = None - exclude_classification_layer: bool = True - num_random_trials: int = 3 - - -@register_experiment("eigenvector_dropout") -class EigenvectorDropoutExperiment(BaseExperiment): - """ - Experiment for eigenvector-based dropout. - - This experiment: - 1. Trains a model (if configured) - 2. Computes PCA/eigendecomposition of layer activations - 3. Prunes neurons based on their eigenvalue rankings - 4. 
Evaluates performance at different pruning levels - """ - - def __init__(self, config: EigenvectorConfig): - """Initialize eigenvector dropout experiment.""" - super().__init__(config) - self.config = config - self.eigendecomposition = {} - self.original_weights = {} - - def _compute_layer_eigendecomposition(self) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]: - """ - Compute eigendecomposition for each layer's activations. - - Returns: - Dictionary mapping layer names to (eigenvalues, eigenvectors) - """ - logger.info("Computing eigendecomposition for each layer") - - layer_eigen = {} - eval_batches = self.config.eval_batches or len(self.data_loader) - - # Collect activations for each layer - layer_activations = { - name: [] for name in self.wrapped_model.tracked_layers if not (self.config.exclude_classification_layer and "classifier" in name.lower()) - } - - # Gather activations - self.model.eval() - with torch.no_grad(): - for batch_idx, (inputs, targets) in enumerate(self.data_loader): - if batch_idx >= eval_batches: - break - - inputs = inputs.to(self.config.device) - - # Forward pass - _, activations = self.wrapped_model.forward_with_activations(inputs) - - # Collect layer outputs - for layer_name in layer_activations: - layer_output = activations.get(f"{layer_name}_output") - if layer_output is not None: - # Flatten spatial dimensions for conv layers - if len(layer_output.shape) > 2: - layer_output = layer_output.flatten(2).mean(dim=2) - layer_activations[layer_name].append(layer_output.cpu()) - - # Compute eigendecomposition for each layer - for layer_name, act_list in layer_activations.items(): - if not act_list: - logger.warning(f"No activations collected for layer {layer_name}") - continue - - # Concatenate all activations - all_activations = torch.cat(act_list, dim=0) # [total_samples, features] - - # Compute covariance matrix - activations_centered = all_activations - all_activations.mean(dim=0, keepdim=True) - cov_matrix = 
torch.mm(activations_centered.t(), activations_centered) / (all_activations.size(0) - 1) - - # Eigendecomposition - eigenvalues, eigenvectors = torch.linalg.eigh(cov_matrix) - - # Sort by eigenvalue (descending) - idx = eigenvalues.argsort(descending=True) - eigenvalues = eigenvalues[idx] - eigenvectors = eigenvectors[:, idx] - - layer_eigen[layer_name] = (eigenvalues, eigenvectors) - - # Log variance explained - total_var = eigenvalues.sum() - if total_var > 0: - var_explained = eigenvalues.cumsum(0) / total_var - n_components = (var_explained <= self.config.n_components_ratio).sum() + 1 - logger.info( - f"Layer {layer_name}: {n_components}/{len(eigenvalues)} components " f"explain {self.config.n_components_ratio*100:.1f}% variance" - ) - - return layer_eigen - - def _create_eigenvector_masks( - self, eigendecomposition: Dict[str, Tuple[torch.Tensor, torch.Tensor]], dropout_rate: float, strategy: str = "low" - ) -> Dict[str, torch.Tensor]: - """ - Create dropout masks based on eigenvalue rankings. 
- - Args: - eigendecomposition: Layer eigenvalues and eigenvectors - dropout_rate: Fraction of neurons to drop - strategy: "low" drops low eigenvalue neurons, "high" drops high - - Returns: - Dictionary of masks for each layer - """ - masks = {} - - for layer_name, (eigenvalues, eigenvectors) in eigendecomposition.items(): - num_neurons = len(eigenvalues) - num_drop = int(num_neurons * dropout_rate) - - if num_drop == 0: - masks[layer_name] = torch.ones(num_neurons, dtype=torch.bool) - continue - - # Create mask based on eigenvalue ranking - mask = torch.ones(num_neurons, dtype=torch.bool) - - if strategy == "low": - # Drop neurons with lowest eigenvalues - mask[-num_drop:] = False - elif strategy == "high": - # Drop neurons with highest eigenvalues - mask[:num_drop] = False - else: # random - random_indices = torch.randperm(num_neurons)[:num_drop] - mask[random_indices] = False - - masks[layer_name] = mask - - logger.debug(f"Layer {layer_name}: keeping {mask.sum().item()}/{num_neurons} neurons") - - return masks - - def _project_weights_to_eigenspace(self, layer_name: str, eigenvectors: torch.Tensor): - """ - Project layer weights to eigenspace. 
- - Args: - layer_name: Name of the layer - eigenvectors: Eigenvector matrix for the layer - """ - layer = self.wrapped_model.get_layer(layer_name) - if layer is None or not hasattr(layer, "weight"): - return - - # Store original weights - if layer_name not in self.original_weights: - self.original_weights[layer_name] = layer.weight.data.clone() - if hasattr(layer, "bias") and layer.bias is not None: - self.original_weights[layer_name + "_bias"] = layer.bias.data.clone() - - # Project weights to eigenspace - weight = layer.weight.data - if len(weight.shape) == 2: # Linear layer - # Weight shape: [out_features, in_features] - # Eigenvectors shape: [in_features, n_components] - # We need to project the input dimension - projected_weight = torch.mm(weight, eigenvectors) - layer.weight.data = torch.mm(projected_weight, eigenvectors.t()) - elif len(weight.shape) == 4: # Conv layer - # For conv layers, we'll just apply the mask directly - # A more sophisticated approach would reshape and project properly - pass - - def _apply_eigenvector_masks(self, masks: Dict[str, torch.Tensor], eigendecomposition: Dict[str, Tuple[torch.Tensor, torch.Tensor]]): - """Apply eigenvector-based dropout masks.""" - for layer_name, mask in masks.items(): - layer = self.wrapped_model.get_layer(layer_name) - if layer is None: - continue - - eigenvalues, eigenvectors = eigendecomposition[layer_name] - - # Store original weights - if layer_name not in self.original_weights: - self.original_weights[layer_name] = layer.weight.data.clone() - if hasattr(layer, "bias") and layer.bias is not None: - self.original_weights[layer_name + "_bias"] = layer.bias.data.clone() - - # Apply mask in eigenspace - if hasattr(layer, "weight"): - weight = self.original_weights[layer_name].clone() - - if len(weight.shape) == 2: # Linear layer - # Keep only selected eigenvectors - selected_eigenvectors = eigenvectors[:, mask] - - # Project weights to reduced eigenspace and back - projected = torch.mm(weight, 
selected_eigenvectors) - layer.weight.data = torch.mm(projected, selected_eigenvectors.t()) - - # Also mask the output neurons directly - layer.weight.data[~mask] = 0 - - if hasattr(layer, "bias") and layer.bias is not None: - layer.bias.data = self.original_weights[layer_name + "_bias"].clone() - layer.bias.data[~mask] = 0 - - elif len(weight.shape) == 4: # Conv layer - # For conv layers, directly mask output channels - layer.weight.data = weight - layer.weight.data[~mask] = 0 - - if hasattr(layer, "bias") and layer.bias is not None: - layer.bias.data = self.original_weights[layer_name + "_bias"].clone() - layer.bias.data[~mask] = 0 - - def _restore_original_weights(self): - """Restore original weights.""" - for layer_name, original_weight in self.original_weights.items(): - if "_bias" in layer_name: - continue - - layer = self.wrapped_model.get_layer(layer_name) - if layer is not None and hasattr(layer, "weight"): - layer.weight.data = original_weight.clone() - - bias_key = layer_name + "_bias" - if bias_key in self.original_weights and hasattr(layer, "bias") and layer.bias is not None: - layer.bias.data = self.original_weights[bias_key].clone() - - def _evaluate_model(self) -> Tuple[float, float]: - """Evaluate model performance.""" - self.model.eval() - total_loss = 0.0 - correct = 0 - total = 0 - - criterion = torch.nn.CrossEntropyLoss() - - with torch.no_grad(): - for batch_idx, (inputs, targets) in enumerate(self.data_loader): - if self.config.eval_batches and batch_idx >= self.config.eval_batches: - break - - inputs, targets = inputs.to(self.config.device), targets.to(self.config.device) - outputs = self.model(inputs) - - loss = criterion(outputs, targets) - total_loss += loss.item() - - _, predicted = outputs.max(1) - total += targets.size(0) - correct += predicted.eq(targets).sum().item() - - avg_loss = total_loss / (batch_idx + 1) - accuracy = 100.0 * correct / total - - return avg_loss, accuracy - - def _train_model(self) -> Dict[str, Any]: - """Train 
the model if configured.""" - if not self.config.train_before_dropout: - logger.info("Skipping initial training") - return {} - - logger.info(f"Training model for {self.config.training_epochs} epochs") - - # Create trainer using the unified interface - trainer = create_experiment_trainer(self.model, asdict(self.config), device=self.config.device) - - # Train with metrics - history = train_with_metrics(trainer, self.data_loader, val_loader=None, compute_accuracy=True) # No validation in original implementation - - # Log final metrics (trainer already logs per-epoch) - if history["train_loss"]: - final_metrics = {"train_loss": history["train_loss"][-1], "train_accuracy": history["train_metrics"][-1].get("accuracy", 0.0)} - self.log_metrics(len(history["train_loss"]) - 1, final_metrics) - - # Return training results - return convert_training_history(history) - - def run(self) -> Dict[str, Any]: - """ - Run the eigenvector dropout experiment. - - Returns: - Dictionary containing experiment results - """ - logger.info("Starting eigenvector dropout experiment") - - # Train model - training_results = self._train_model() - - # Compute eigendecomposition - self.eigendecomposition = self._compute_layer_eigendecomposition() - - # Initialize results - results = { - "dropout_rates": self.config.dropout_rates, - "eigenvalues": { - layer: eigenvalues.tolist()[:20] for layer, (eigenvalues, _) in self.eigendecomposition.items() # Store top 20 eigenvalues - }, - "variance_explained": {}, - "accuracies": {"low": [], "high": [], "random": []}, - "losses": {"low": [], "high": [], "random": []}, - "training_results": training_results, # Include training results - } - - # Compute variance explained - for layer, (eigenvalues, _) in self.eigendecomposition.items(): - total_var = eigenvalues.sum() - if total_var > 0: - var_explained = (eigenvalues.cumsum(0) / total_var).tolist()[:20] - results["variance_explained"][layer] = var_explained - - # Evaluate at each dropout rate - for 
dropout_rate in self.config.dropout_rates: - logger.info(f"Evaluating dropout rate: {dropout_rate}") - - # Evaluate each strategy - for strategy in ["low", "high", "random"]: - if strategy == "random": - # Average over multiple trials - trial_losses = [] - trial_accs = [] - - for trial in range(self.config.num_random_trials): - # Create masks - masks = self._create_eigenvector_masks(self.eigendecomposition, dropout_rate, strategy) - - # Apply masks - self._apply_eigenvector_masks(masks, self.eigendecomposition) - - # Evaluate - loss, acc = self._evaluate_model() - trial_losses.append(loss) - trial_accs.append(acc) - - # Restore weights - self._restore_original_weights() - - avg_loss = np.mean(trial_losses) - avg_acc = np.mean(trial_accs) - - else: - # Create masks - masks = self._create_eigenvector_masks(self.eigendecomposition, dropout_rate, strategy) - - # Apply masks - self._apply_eigenvector_masks(masks, self.eigendecomposition) - - # Evaluate - avg_loss, avg_acc = self._evaluate_model() - - # Restore weights - self._restore_original_weights() - - # Store results - results["losses"][strategy].append(avg_loss) - results["accuracies"][strategy].append(avg_acc) - - logger.info(f" {strategy}: Loss={avg_loss:.4f}, Accuracy={avg_acc:.2f}%") - - # Log metrics - self.log_metrics( - len(results["losses"][strategy]) - 1, - {f"{strategy}_loss": avg_loss, f"{strategy}_accuracy": avg_acc, "dropout_rate": dropout_rate}, - ) - - # Save results - self.results.update(results) - self.save_results() - - # Save checkpoint - self.save_checkpoint(step=len(self.config.dropout_rates), metrics={"final_results": results}) - - logger.info("Eigenvector dropout experiment completed") - - return results diff --git a/src/alignment/pruning/experiments/global_pruning.py b/src/alignment/pruning/experiments/global_pruning.py deleted file mode 100644 index 5a572738..00000000 --- a/src/alignment/pruning/experiments/global_pruning.py +++ /dev/null @@ -1,392 +0,0 @@ -""" -Global dropout experiment 
for analyzing model alignment under dropout. - -This module implements experiments that apply the same dropout rate globally -across all layers and track changes in alignment metrics. -""" - -import logging -from dataclasses import asdict, dataclass, field -from pathlib import Path -from typing import Any, Dict, List, Optional - -import numpy as np -import torch -import torch.nn as nn - -from alignment.core.registry import register_experiment -from alignment.experiments.base import BaseExperiment, ExperimentConfig -from alignment.experiments.config_components import PruningConfig -from alignment.experiments.training_utils import convert_training_history, create_experiment_trainer, train_with_metrics - -logger = logging.getLogger(__name__) - - -@dataclass -class GlobalDropoutConfig(ExperimentConfig): - """Configuration for global dropout experiment.""" - - # Dropout configuration - dropout_rates: List[float] = field(default_factory=lambda: [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]) - dropout_structure: str = "random" # 'random', 'magnitude', 'gradient' - dropout_mode: str = "scaled" # 'scaled' or 'unscaled' - - # Pruning configuration (when using structured dropout) - pruning_mode: str = "global_joint" # 'global_joint', 'layer_wise', etc. - pruning_strategy: str = "low" # 'low', 'high', 'random' - exclude_classification_layer: bool = True - - # Training configuration - train_before_dropout: bool = True - training_epochs: int = 10 - learning_rate: float = 0.001 - optimizer: str = "adam" - - # Evaluation - num_samples: int = 1000 - apply_to_layers: Optional[List[str]] = None - eval_batches: Optional[int] = None - - -@register_experiment("global_dropout") -@register_experiment("progressive_dropout") # Backward compatibility -class GlobalDropoutExperiment(BaseExperiment): - """ - Experiment for applying global dropout to analyze alignment changes. - - This experiment: - 1. Trains a model (if configured) - 2. 
Applies the same dropout rate globally across all layers - 3. Tracks alignment metrics at each dropout level - 4. Analyzes how dropout affects model structure - """ - - def __init__(self, config: GlobalDropoutConfig): - """ - Initialize global dropout experiment. - """ - super().__init__(config) - - # Results storage - self.dropout_results = {"dropout_rates": self.config.dropout_rates, "metrics_by_rate": {}, "layer_statistics": {}} - - def _train_model(self) -> Dict[str, Any]: - """Train the model if configured.""" - if not self.config.train_before_dropout: - logger.info("Skipping initial training") - return {} - - logger.info(f"Training model for {self.config.training_epochs} epochs") - - # Create trainer using the unified interface - trainer = create_experiment_trainer(self.model, asdict(self.config), device=self.config.device) - - # Train with metrics - history = train_with_metrics(trainer, self.data_loader, val_loader=None, compute_accuracy=True) - - # Log final metrics - if history["train_loss"]: - final_metrics = {"train_loss": history["train_loss"][-1], "train_accuracy": history["train_metrics"][-1].get("accuracy", 0.0)} - self.log_metrics(len(history["train_loss"]) - 1, final_metrics) - - return convert_training_history(history) - - def run(self, models=None, dataset=None, **kwargs) -> Dict[str, Any]: - """ - Run the global dropout experiment. 
- - Returns: - Experiment results including metrics at each dropout rate - """ - logger.info("Starting global dropout experiment") - - # Collect initial model statistics - self._collect_initial_statistics() - - # Get data samples for evaluation - eval_data = self._get_evaluation_data() - - # Test each dropout rate - for dropout_rate in self.config.dropout_rates: - logger.info(f"Testing dropout rate: {dropout_rate}") - - # Create dropout masks for this rate - dropout_masks = self._create_dropout_masks(dropout_rate) - - # Apply dropout to model - if dropout_masks: - self.wrapped_model.apply_structured_dropout(dropout_masks, mode="multiplicative", permanent=False) - - # Compute metrics with dropout - metrics = self._evaluate_with_dropout(eval_data, dropout_rate) - - # Store results - self.dropout_results["metrics_by_rate"][dropout_rate] = metrics - - # Log progress - self.log_metrics(step=int(dropout_rate * 100), metrics=self._flatten_metrics(metrics, dropout_rate)) # Use dropout % as step - - # Restore original weights - if dropout_masks: - self.wrapped_model.restore_weights() - - # Optional: Save checkpoint - if self.config.checkpoint_interval > 0: - self.save_checkpoint(step=int(dropout_rate * 100), metrics=metrics) - - # Analyze results - self._analyze_dropout_effects() - - # Save final results - self.results.update(self.dropout_results) - self.save_results() - - return self.results - - def _collect_initial_statistics(self): - """Collect statistics about the initial model.""" - logger.info("Collecting initial model statistics") - - weights = self.wrapped_model.get_layer_weights() - - for layer_name, weight in weights.items(): - if weight is None: - continue - - # Compute weight statistics - stats = { - "shape": list(weight.shape), - "num_parameters": weight.numel(), - "mean": weight.mean().item(), - "std": weight.std().item(), - "min": weight.min().item(), - "max": weight.max().item(), - "sparsity": (weight == 0).float().mean().item(), - } - - # Compute norms - 
stats["l1_norm"] = weight.abs().sum().item() - stats["l2_norm"] = weight.pow(2).sum().sqrt().item() - - self.dropout_results["layer_statistics"][layer_name] = stats - - def _create_dropout_masks(self, dropout_rate: float) -> Dict[str, torch.Tensor]: - """ - Create dropout masks for specified layers. - - Args: - dropout_rate: Fraction of units to drop - - Returns: - Dictionary mapping layer names to binary masks - """ - if dropout_rate == 0.0: - return {} - - dropout_masks = {} - layers_to_apply = self.config.apply_to_layers or self.wrapped_model.tracked_layers - - for layer_name in layers_to_apply: - layer_info = self.wrapped_model.get_layer_info(layer_name) - - if "weight_shape" not in layer_info: - continue - - # Get number of units based on layer type - if layer_info["type"] == "Linear": - num_units = layer_info["out_features"] - elif layer_info["type"] in ["Conv2d", "Conv1d"]: - num_units = layer_info["out_channels"] - else: - continue - - # Create mask based on dropout structure - if self.config.dropout_structure == "random": - # Random dropout - mask = torch.rand(num_units) > dropout_rate - elif self.config.dropout_structure == "magnitude": - # Magnitude-based dropout (keep high magnitude units) - weights = self.wrapped_model.get_layer_weights([layer_name])[layer_name] - magnitudes = weights.abs().sum(dim=tuple(range(1, weights.ndim))) - threshold = torch.quantile(magnitudes, dropout_rate) - mask = magnitudes > threshold - elif self.config.dropout_structure == "gradient": - # Gradient-based dropout (requires gradients) - # For now, fallback to random - mask = torch.rand(num_units) > dropout_rate - else: - raise ValueError(f"Unknown dropout structure: {self.config.dropout_structure}") - - dropout_masks[layer_name] = mask.float().to(self.config.device) - - return dropout_masks - - def _get_evaluation_data(self) -> List[torch.Tensor]: - """Get data samples for evaluation.""" - eval_data = [] - total_samples = 0 - - for batch_idx, (inputs, targets) in 
enumerate(self.data_loader): - inputs = inputs.to(self.config.device) - eval_data.append(inputs) - - total_samples += inputs.size(0) - if total_samples >= self.config.num_samples: - break - - # Concatenate and trim to exact number - eval_data = torch.cat(eval_data, dim=0)[: self.config.num_samples] - - logger.info(f"Collected {eval_data.size(0)} samples for evaluation") - return eval_data - - def _evaluate_with_dropout(self, eval_data: torch.Tensor, dropout_rate: float) -> Dict[str, Any]: - """ - Evaluate metrics with current dropout settings. - - Args: - eval_data: Data to evaluate on - dropout_rate: Current dropout rate - - Returns: - Dictionary of metrics - """ - # Set model to eval mode (but keep dropout active via context manager) - self.model.eval() - - all_metrics = {} - batch_size = min(self.config.batch_size, eval_data.size(0)) - - # Process in batches - for i in range(0, eval_data.size(0), batch_size): - batch = eval_data[i : i + batch_size] - - # Compute metrics for batch - with torch.no_grad(): - batch_metrics = self.compute_metrics(batch) - - # Accumulate metrics - for metric_name, layer_results in batch_metrics.items(): - if metric_name not in all_metrics: - all_metrics[metric_name] = {} - - for layer_name, value in layer_results.items(): - if layer_name not in all_metrics[metric_name]: - all_metrics[metric_name][layer_name] = [] - all_metrics[metric_name][layer_name].append(value) - - # Average metrics across batches - averaged_metrics = {} - for metric_name, layer_results in all_metrics.items(): - averaged_metrics[metric_name] = {} - for layer_name, values in layer_results.items(): - averaged_metrics[metric_name][layer_name] = np.mean(values) - - # Add additional statistics - averaged_metrics["_statistics"] = { - "dropout_rate": dropout_rate, - "num_samples": eval_data.size(0), - "effective_sparsity": self._compute_effective_sparsity(), - } - - return averaged_metrics - - def _compute_effective_sparsity(self) -> Dict[str, float]: - """Compute 
effective sparsity after dropout.""" - sparsity = {} - - for name, module in self.model.named_modules(): - if hasattr(module, "weight") and module.weight is not None: - weight = module.weight - # Count zeros (including dropout-induced zeros) - num_zeros = (weight == 0).float().sum().item() - total_params = weight.numel() - sparsity[name] = num_zeros / total_params - - return sparsity - - def _flatten_metrics(self, metrics: Dict[str, Any], dropout_rate: float) -> Dict[str, float]: - """Flatten metrics dictionary for logging.""" - flat_metrics = {f"dropout_rate": dropout_rate} - - for metric_name, layer_results in metrics.items(): - if metric_name.startswith("_"): - continue # Skip internal metrics - - for layer_name, value in layer_results.items(): - key = f"{metric_name}/{layer_name}" - flat_metrics[key] = value - - return flat_metrics - - def _analyze_dropout_effects(self): - """Analyze how dropout affects alignment metrics.""" - logger.info("Analyzing dropout effects on alignment") - - analysis = {"metric_trends": {}, "layer_sensitivity": {}, "critical_dropout_rates": {}} - - # Analyze trends for each metric and layer - for metric_name in self.metrics.keys(): - analysis["metric_trends"][metric_name] = {} - - # Get all layers that have this metric - all_layers = set() - for rate_results in self.dropout_results["metrics_by_rate"].values(): - if metric_name in rate_results: - all_layers.update(rate_results[metric_name].keys()) - - for layer_name in all_layers: - # Collect values across dropout rates - rates = [] - values = [] - - for rate, results in self.dropout_results["metrics_by_rate"].items(): - if metric_name in results and layer_name in results[metric_name]: - rates.append(rate) - values.append(results[metric_name][layer_name]) - - if len(values) < 2: - continue - - # Compute trend statistics - values = np.array(values) - rates = np.array(rates) - - # Linear regression to find trend - coeffs = np.polyfit(rates, values, 1) - slope = coeffs[0] - - # Find 
critical points (large changes) - if len(values) > 2: - diffs = np.diff(values) - max_change_idx = np.argmax(np.abs(diffs)) - critical_rate = rates[max_change_idx + 1] - else: - critical_rate = None - - analysis["metric_trends"][metric_name][layer_name] = { - "slope": float(slope), - "initial_value": float(values[0]), - "final_value": float(values[-1]), - "percent_change": float((values[-1] - values[0]) / (values[0] + 1e-8) * 100), - "critical_dropout_rate": float(critical_rate) if critical_rate else None, - } - - # Compute layer sensitivity (how much each layer is affected by dropout) - for layer_name in self.wrapped_model.tracked_layers: - sensitivities = [] - - for metric_name in self.metrics.keys(): - if metric_name in analysis["metric_trends"] and layer_name in analysis["metric_trends"][metric_name]: - trend = analysis["metric_trends"][metric_name][layer_name] - sensitivities.append(abs(trend["percent_change"])) - - if sensitivities: - analysis["layer_sensitivity"][layer_name] = { - "mean_sensitivity": float(np.mean(sensitivities)), - "max_sensitivity": float(np.max(sensitivities)), - "sensitivity_scores": sensitivities, - } - - self.dropout_results["analysis"] = analysis - logger.info("Dropout analysis complete") diff --git a/src/alignment/pruning/experiments/layer_wise.py b/src/alignment/pruning/experiments/layer_wise.py deleted file mode 100644 index 9024daa2..00000000 --- a/src/alignment/pruning/experiments/layer_wise.py +++ /dev/null @@ -1,410 +0,0 @@ -""" -Layer-isolated pruning experiment. - -This module implements pruning where each layer is pruned independently -based on its alignment scores, without considering other layers. 
-""" - -import json -import logging -from dataclasses import asdict, dataclass, field -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple - -import numpy as np -import torch - -from alignment.core.registry import register_experiment -from alignment.experiments.base import BaseExperiment, ExperimentConfig -from alignment.experiments.training_utils import convert_training_history, create_experiment_trainer, train_with_metrics -from alignment.models import ModelWrapper - -logger = logging.getLogger(__name__) - - -@dataclass -class LayerIsolatedConfig(ExperimentConfig): - """Configuration for layer-isolated pruning experiment.""" - - # Dropout configuration - dropout_rates: List[float] = field(default_factory=lambda: [0.0, 0.1, 0.3, 0.5, 0.7, 0.9]) - dropout_mode: str = "scaled" # "scaled" or "unscaled" - - # Pruning configuration - pruning_metric: str = "rayleigh_quotient" # Metric to use for pruning decisions - pruning_strategy: str = "low" # "low", "high", "random" - exclude_classification_layer: bool = True - - # CNN preprocessing mode - cnn_mode: str = "unfold" # "unfold", "patchwise", "batch_patch_combined" - - # Training configuration - train_before_dropout: bool = True - training_epochs: int = 10 - learning_rate: float = 0.001 - optimizer: str = "adam" - - # Evaluation - eval_batches: Optional[int] = None # None for full dataset - num_random_trials: int = 3 # Number of random pruning trials - - -@register_experiment("layer_isolated_pruning") -class LayerIsolatedPruningExperiment(BaseExperiment): - """ - Experiment for layer-isolated pruning. - - This experiment: - 1. Trains a model (if configured) - 2. Computes alignment scores for each layer independently - 3. Prunes neurons in each layer based on their scores - 4. 
Evaluates performance at different pruning levels - """ - - def __init__(self, config: LayerIsolatedConfig): - """Initialize layer-isolated pruning experiment.""" - super().__init__(config) - # config is already handled by parent class - self.pruning_scores = {} - self.original_weights = {} - - def _compute_layer_scores(self) -> Dict[str, torch.Tensor]: - """ - Compute pruning scores for each layer independently. - - Returns: - Dictionary mapping layer names to score tensors - """ - logger.info(f"Computing {self.config.pruning_metric} scores for each layer") - - layer_scores = {} - metric = self.metrics[self.config.pruning_metric] - - # Evaluate on subset of data if configured - eval_batches = self.config.eval_batches or len(self.data_loader) - - for layer_name in self.wrapped_model.tracked_layers: - if self.config.exclude_classification_layer and "classifier" in layer_name.lower(): - logger.info(f"Skipping classification layer: {layer_name}") - continue - - scores_list = [] - - # Compute scores over batches - for batch_idx, (inputs, targets) in enumerate(self.data_loader): - if batch_idx >= eval_batches: - break - - inputs = inputs.to(self.config.device) - - # Forward pass with activation tracking - _, activations = self.wrapped_model.forward_with_activations(inputs) - - # Get layer-specific data - layer_inputs = activations.get(f"{layer_name}_input") - layer_weights = self.wrapped_model.get_layer_weights()[layer_name] - - if layer_inputs is None or layer_weights is None: - continue - - # Preprocess activations based on CNN mode - from alignment.dataops.processing import preprocess_layer_activations - - layer_modules = dict(self.wrapped_model._model.named_modules()) - preprocessed = preprocess_layer_activations( - {f"{layer_name}_input": layer_inputs}, layer_modules, mode=self.config.cnn_mode if hasattr(self.config, "cnn_mode") else None - ) - layer_inputs = preprocessed.get(f"{layer_name}_input", layer_inputs) - - # Compute metric scores - if hasattr(metric, 
"requires_outputs") and metric.requires_outputs: - layer_outputs = activations.get(f"{layer_name}_output") - if layer_outputs is not None: - preprocessed_out = preprocess_layer_activations( - {f"{layer_name}_output": layer_outputs}, - layer_modules, - mode=self.config.cnn_mode if hasattr(self.config, "cnn_mode") else None, - ) - layer_outputs = preprocessed_out.get(f"{layer_name}_output", layer_outputs) - scores = metric.compute(inputs=layer_inputs, weights=layer_weights, outputs=layer_outputs) - else: - scores = metric.compute(inputs=layer_inputs, weights=layer_weights) - - scores_list.append(scores.cpu()) - - # Aggregate scores across batches - if scores_list: - # Stack scores from different batches and average across batches - layer_scores[layer_name] = torch.stack(scores_list, dim=0).mean(dim=0) - logger.info(f"Layer {layer_name}: computed {len(layer_scores[layer_name])} scores") - else: - logger.warning(f"No scores computed for layer {layer_name}") - - return layer_scores - - def _create_dropout_masks(self, layer_scores: Dict[str, torch.Tensor], dropout_rate: float) -> Dict[str, Dict[str, torch.Tensor]]: - """ - Create dropout masks for each layer based on scores. 
- - Args: - layer_scores: Scores for each layer - dropout_rate: Fraction of neurons to drop - - Returns: - Dictionary with masks for each strategy - """ - masks = {"low": {}, "high": {}, "random": {}} - - for layer_name, scores in layer_scores.items(): - # Handle scalar scores (0-d tensor) - if scores.dim() == 0: - logger.warning(f"Scores for layer {layer_name} is a scalar; using probabilistic dropout when rate > 0") - keep_mask_scalar = True - if dropout_rate > 0: - # Drop the single neuron with probability equal to dropout_rate - keep_mask_scalar = torch.rand(1).item() > dropout_rate - keep_mask_tensor = torch.tensor([keep_mask_scalar], dtype=torch.bool) - # low/high identical for single neuron; random follows same probabilistic rule - masks["low"][layer_name] = keep_mask_tensor.clone() - masks["high"][layer_name] = keep_mask_tensor.clone() - masks["random"][layer_name] = keep_mask_tensor.clone() - continue - - # Get number of neurons and ensure scores is 1D - scores = scores.flatten() - num_neurons = scores.numel() - num_drop = int(num_neurons * dropout_rate) - - if num_drop == 0: - # No dropout for this layer - for strategy in masks: - masks[strategy][layer_name] = torch.ones(num_neurons, dtype=torch.bool) - continue - - # Sort scores to get indices - sorted_indices = torch.argsort(scores) - - # Low scores mask (drop lowest scoring neurons) - low_mask = torch.ones(num_neurons, dtype=torch.bool) - low_mask[sorted_indices[:num_drop]] = False - masks["low"][layer_name] = low_mask - - # High scores mask (drop highest scoring neurons) - high_mask = torch.ones(num_neurons, dtype=torch.bool) - high_mask[sorted_indices[-num_drop:]] = False - masks["high"][layer_name] = high_mask - - # Random mask - random_mask = torch.ones(num_neurons, dtype=torch.bool) - random_indices = torch.randperm(num_neurons)[:num_drop] - random_mask[random_indices] = False - masks["random"][layer_name] = random_mask - - logger.debug(f"Layer {layer_name}: dropping {num_drop}/{num_neurons} 
neurons") - - return masks - - def _apply_layer_masks(self, masks: Dict[str, torch.Tensor]): - """Apply dropout masks to each layer independently.""" - for layer_name, mask in masks.items(): - layer = self.wrapped_model.get_layer(layer_name) - if layer is None: - continue - - # Store original weights if not already stored - if layer_name not in self.original_weights: - self.original_weights[layer_name] = layer.weight.data.clone() - - # Apply mask based on layer type - if hasattr(layer, "weight"): - if len(layer.weight.shape) == 2: # Linear layer - # Mask output neurons - layer.weight.data = self.original_weights[layer_name].clone() - layer.weight.data[~mask] = 0 - - if hasattr(layer, "bias") and layer.bias is not None: - if layer_name + "_bias" not in self.original_weights: - self.original_weights[layer_name + "_bias"] = layer.bias.data.clone() - layer.bias.data = self.original_weights[layer_name + "_bias"].clone() - layer.bias.data[~mask] = 0 - - elif len(layer.weight.shape) == 4: # Conv layer - # Mask output channels - layer.weight.data = self.original_weights[layer_name].clone() - # Expand mask to match weight dimensions - expanded_mask = mask.view(-1, 1, 1, 1).expand_as(layer.weight) - layer.weight.data[~expanded_mask] = 0 - - if hasattr(layer, "bias") and layer.bias is not None: - if layer_name + "_bias" not in self.original_weights: - self.original_weights[layer_name + "_bias"] = layer.bias.data.clone() - layer.bias.data = self.original_weights[layer_name + "_bias"].clone() - layer.bias.data[~mask] = 0 - - def _restore_original_weights(self): - """Restore original weights to all layers.""" - for layer_name, original_weight in self.original_weights.items(): - if "_bias" in layer_name: - continue - - layer = self.wrapped_model.get_layer(layer_name) - if layer is not None and hasattr(layer, "weight"): - layer.weight.data = original_weight.clone() - - # Restore bias if exists - bias_key = layer_name + "_bias" - if bias_key in self.original_weights and 
hasattr(layer, "bias") and layer.bias is not None: - layer.bias.data = self.original_weights[bias_key].clone() - - def _evaluate_model(self) -> Tuple[float, float]: - """ - Evaluate model performance. - - Returns: - Tuple of (loss, accuracy) - """ - self.model.eval() - total_loss = 0.0 - correct = 0 - total = 0 - - criterion = torch.nn.CrossEntropyLoss() - - with torch.no_grad(): - for batch_idx, (inputs, targets) in enumerate(self.data_loader): - if self.config.eval_batches and batch_idx >= self.config.eval_batches: - break - - inputs, targets = inputs.to(self.config.device), targets.to(self.config.device) - outputs = self.model(inputs) - - loss = criterion(outputs, targets) - total_loss += loss.item() - - _, predicted = outputs.max(1) - total += targets.size(0) - correct += predicted.eq(targets).sum().item() - - avg_loss = total_loss / (batch_idx + 1) - accuracy = 100.0 * correct / total - - return avg_loss, accuracy - - def _train_model(self) -> Dict[str, Any]: - """Train the model if configured.""" - if not self.config.train_before_dropout: - logger.info("Skipping initial training (train_before_dropout=False)") - return {} - - logger.info(f"Training model for {self.config.training_epochs} epochs") - - # Create trainer using the unified interface - trainer = create_experiment_trainer(self.model, asdict(self.config), device=self.config.device) # Convert dataclass to dict - - # Train with metrics - history = train_with_metrics(trainer, self.data_loader, val_loader=None, compute_accuracy=True) # No validation in original implementation - - # Log final metrics (trainer already logs per-epoch) - if history["train_loss"]: - final_metrics = {"train_loss": history["train_loss"][-1], "train_accuracy": history["train_metrics"][-1].get("accuracy", 0.0)} - self.log_metrics(len(history["train_loss"]) - 1, final_metrics) - - # Return training results - return convert_training_history(history) - - def run(self) -> Dict[str, Any]: - """ - Run the layer-isolated pruning 
experiment. - - Returns: - Dictionary containing experiment results - """ - logger.info("Starting layer-isolated pruning experiment") - - # Train model if configured - training_results = self._train_model() - - # Compute pruning scores for each layer - layer_scores = self._compute_layer_scores() - self.pruning_scores = layer_scores - - # Initialize results - results = { - "dropout_rates": self.config.dropout_rates, - "layer_scores": {k: v.tolist() for k, v in layer_scores.items()}, - "accuracies": {"low": [], "high": [], "random": []}, - "losses": {"low": [], "high": [], "random": []}, - "per_layer_masks": {}, - "training_results": training_results, # Include training results - } - - # Evaluate at each dropout rate - for dropout_rate in self.config.dropout_rates: - logger.info(f"Evaluating dropout rate: {dropout_rate}") - - # Create masks for this dropout rate - masks = self._create_dropout_masks(layer_scores, dropout_rate) - - # Store mask info - results["per_layer_masks"][f"dropout_{dropout_rate}"] = { - layer: {"total_neurons": len(mask), "active_neurons": int(mask.sum().item())} for layer, mask in masks["low"].items() - } - - # Evaluate each strategy - for strategy in ["low", "high", "random"]: - if strategy == "random": - # Average over multiple random trials - trial_losses = [] - trial_accs = [] - - for trial in range(self.config.num_random_trials): - # Create new random masks - random_masks = self._create_dropout_masks(layer_scores, dropout_rate)["random"] - - # Apply masks - self._apply_layer_masks(random_masks) - - # Evaluate - loss, acc = self._evaluate_model() - trial_losses.append(loss) - trial_accs.append(acc) - - # Restore weights - self._restore_original_weights() - - # Average results - avg_loss = np.mean(trial_losses) - avg_acc = np.mean(trial_accs) - - else: - # Apply masks for this strategy - self._apply_layer_masks(masks[strategy]) - - # Evaluate - avg_loss, avg_acc = self._evaluate_model() - - # Restore original weights - 
self._restore_original_weights() - - # Store results - results["losses"][strategy].append(avg_loss) - results["accuracies"][strategy].append(avg_acc) - - logger.info(f" {strategy}: Loss={avg_loss:.4f}, Accuracy={avg_acc:.2f}%") - - # Log metrics - self.log_metrics( - len(results["losses"][strategy]) - 1, - {f"{strategy}_loss": avg_loss, f"{strategy}_accuracy": avg_acc, "dropout_rate": dropout_rate}, - ) - - # Save final results - self.results.update(results) - self.save_results() - - # Save checkpoint with final state - self.save_checkpoint(step=len(self.config.dropout_rates), metrics={"final_results": results}) - - logger.info("Layer-isolated pruning experiment completed") - - return results diff --git a/src/alignment/pruning/experiments/progressive.py b/src/alignment/pruning/experiments/progressive.py deleted file mode 100644 index 7687900b..00000000 --- a/src/alignment/pruning/experiments/progressive.py +++ /dev/null @@ -1,360 +0,0 @@ -""" -Progressive dropout experiment for alignment analysis. - -This experiment progressively applies dropout to model layers -and tracks how alignment metrics change. -""" - -import logging -from pathlib import Path -from typing import Any, Dict, List, Optional - -import numpy as np -import torch -import torch.nn as nn - -from alignment.core.registry import register_experiment -from alignment.experiments.base import BaseExperiment, ExperimentConfig - -logger = logging.getLogger(__name__) - - -@register_experiment("progressive_dropout") -class ProgressiveDropoutExperiment(BaseExperiment): - """ - Experiment that progressively applies dropout to model layers. - - This experiment: - 1. Starts with a trained model - 2. Progressively increases dropout rates - 3. Tracks alignment metrics at each dropout level - 4. Analyzes how alignment changes with dropout - """ - - def __init__(self, config: ExperimentConfig): - """ - Initialize progressive dropout experiment. 
- - Additional config parameters: - dropout_rates: List of dropout rates to test - dropout_structure: Type of structured dropout ('random', 'magnitude', 'gradient') - num_samples: Number of data samples to use for metrics - apply_to_layers: Specific layers to apply dropout to (None = all) - """ - super().__init__(config) - - # Experiment-specific config - self.dropout_rates = getattr(config, "dropout_rates", [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]) - self.dropout_structure = getattr(config, "dropout_structure", "random") - self.num_samples = getattr(config, "num_samples", 1000) - self.apply_to_layers = getattr(config, "apply_to_layers", None) - - # Dropout configuration - self.dropout_mode = getattr(config, "dropout_mode", "scaled") - self.pruning_mode = getattr(config, "pruning_mode", "global_joint") - self.pruning_strategy = getattr(config, "pruning_strategy", "low") - self.exclude_classification_layer = getattr(config, "exclude_classification_layer", True) - - # Results storage - self.dropout_results = {"dropout_rates": self.dropout_rates, "metrics_by_rate": {}, "layer_statistics": {}} - - def run(self, models=None, dataset=None, **kwargs) -> Dict[str, Any]: - """ - Run the progressive dropout experiment. 
- - Returns: - Experiment results including metrics at each dropout rate - """ - logger.info("Starting progressive dropout experiment") - - # Collect initial model statistics - self._collect_initial_statistics() - - # Get data samples for evaluation - eval_data = self._get_evaluation_data() - - # Test each dropout rate - for dropout_rate in self.dropout_rates: - logger.info(f"Testing dropout rate: {dropout_rate}") - - # Create dropout masks for this rate - dropout_masks = self._create_dropout_masks(dropout_rate) - - # Apply dropout to model - if dropout_masks: - self.wrapped_model.apply_structured_dropout(dropout_masks, mode="multiplicative", permanent=False) - - # Compute metrics with dropout - metrics = self._evaluate_with_dropout(eval_data, dropout_rate) - - # Store results - self.dropout_results["metrics_by_rate"][dropout_rate] = metrics - - # Log progress - self.log_metrics(step=int(dropout_rate * 100), metrics=self._flatten_metrics(metrics, dropout_rate)) # Use dropout % as step - - # Restore original weights - if dropout_masks: - self.wrapped_model.restore_weights() - - # Optional: Save checkpoint - if self.config.checkpoint_interval > 0: - self.save_checkpoint(step=int(dropout_rate * 100), metrics=metrics) - - # Analyze results - self._analyze_dropout_effects() - - # Save final results - self.results.update(self.dropout_results) - self.save_results() - - return self.results - - def _collect_initial_statistics(self): - """Collect statistics about the initial model.""" - logger.info("Collecting initial model statistics") - - weights = self.wrapped_model.get_layer_weights() - - for layer_name, weight in weights.items(): - if weight is None: - continue - - # Compute weight statistics - stats = { - "shape": list(weight.shape), - "num_parameters": weight.numel(), - "mean": weight.mean().item(), - "std": weight.std().item(), - "min": weight.min().item(), - "max": weight.max().item(), - "sparsity": (weight == 0).float().mean().item(), - } - - # Compute norms - 
stats["l1_norm"] = weight.abs().sum().item() - stats["l2_norm"] = weight.pow(2).sum().sqrt().item() - - self.dropout_results["layer_statistics"][layer_name] = stats - - def _create_dropout_masks(self, dropout_rate: float) -> Dict[str, torch.Tensor]: - """ - Create dropout masks for specified layers. - - Args: - dropout_rate: Fraction of units to drop - - Returns: - Dictionary mapping layer names to binary masks - """ - if dropout_rate == 0.0: - return {} - - dropout_masks = {} - layers_to_apply = self.apply_to_layers or self.wrapped_model.tracked_layers - - for layer_name in layers_to_apply: - layer_info = self.wrapped_model.get_layer_info(layer_name) - - if "weight_shape" not in layer_info: - continue - - # Get number of units based on layer type - if layer_info["type"] == "Linear": - num_units = layer_info["out_features"] - elif layer_info["type"] in ["Conv2d", "Conv1d"]: - num_units = layer_info["out_channels"] - else: - continue - - # Create mask based on dropout structure - if self.dropout_structure == "random": - # Random dropout - mask = torch.rand(num_units) > dropout_rate - elif self.dropout_structure == "magnitude": - # Magnitude-based dropout (keep high magnitude units) - weights = self.wrapped_model.get_layer_weights([layer_name])[layer_name] - magnitudes = weights.abs().sum(dim=tuple(range(1, weights.ndim))) - threshold = torch.quantile(magnitudes, dropout_rate) - mask = magnitudes > threshold - elif self.dropout_structure == "gradient": - # Gradient-based dropout (requires gradients) - # For now, fallback to random - mask = torch.rand(num_units) > dropout_rate - else: - raise ValueError(f"Unknown dropout structure: {self.dropout_structure}") - - dropout_masks[layer_name] = mask.float().to(self.config.device) - - return dropout_masks - - def _get_evaluation_data(self) -> List[torch.Tensor]: - """Get data samples for evaluation.""" - eval_data = [] - total_samples = 0 - - for batch_idx, (inputs, targets) in enumerate(self.data_loader): - inputs = 
inputs.to(self.config.device) - eval_data.append(inputs) - - total_samples += inputs.size(0) - if total_samples >= self.num_samples: - break - - # Concatenate and trim to exact number - eval_data = torch.cat(eval_data, dim=0)[: self.num_samples] - - logger.info(f"Collected {eval_data.size(0)} samples for evaluation") - return eval_data - - def _evaluate_with_dropout(self, eval_data: torch.Tensor, dropout_rate: float) -> Dict[str, Any]: - """ - Evaluate metrics with current dropout settings. - - Args: - eval_data: Data to evaluate on - dropout_rate: Current dropout rate - - Returns: - Dictionary of metrics - """ - # Set model to eval mode (but keep dropout active via context manager) - self.model.eval() - - all_metrics = {} - batch_size = min(self.config.batch_size, eval_data.size(0)) - - # Process in batches - for i in range(0, eval_data.size(0), batch_size): - batch = eval_data[i : i + batch_size] - - # Compute metrics for batch - with torch.no_grad(): - batch_metrics = self.compute_metrics(batch) - - # Accumulate metrics - for metric_name, layer_results in batch_metrics.items(): - if metric_name not in all_metrics: - all_metrics[metric_name] = {} - - for layer_name, value in layer_results.items(): - if layer_name not in all_metrics[metric_name]: - all_metrics[metric_name][layer_name] = [] - all_metrics[metric_name][layer_name].append(value) - - # Average metrics across batches - averaged_metrics = {} - for metric_name, layer_results in all_metrics.items(): - averaged_metrics[metric_name] = {} - for layer_name, values in layer_results.items(): - averaged_metrics[metric_name][layer_name] = np.mean(values) - - # Add additional statistics - averaged_metrics["_statistics"] = { - "dropout_rate": dropout_rate, - "num_samples": eval_data.size(0), - "effective_sparsity": self._compute_effective_sparsity(), - } - - return averaged_metrics - - def _compute_effective_sparsity(self) -> Dict[str, float]: - """Compute effective sparsity after dropout.""" - sparsity = {} - - for 
name, module in self.model.named_modules(): - if hasattr(module, "weight") and module.weight is not None: - weight = module.weight - # Count zeros (including dropout-induced zeros) - num_zeros = (weight == 0).float().sum().item() - total_params = weight.numel() - sparsity[name] = num_zeros / total_params - - return sparsity - - def _flatten_metrics(self, metrics: Dict[str, Any], dropout_rate: float) -> Dict[str, float]: - """Flatten metrics dictionary for logging.""" - flat_metrics = {f"dropout_rate": dropout_rate} - - for metric_name, layer_results in metrics.items(): - if metric_name.startswith("_"): - continue # Skip internal metrics - - for layer_name, value in layer_results.items(): - key = f"{metric_name}/{layer_name}" - flat_metrics[key] = value - - return flat_metrics - - def _analyze_dropout_effects(self): - """Analyze how dropout affects alignment metrics.""" - logger.info("Analyzing dropout effects on alignment") - - analysis = {"metric_trends": {}, "layer_sensitivity": {}, "critical_dropout_rates": {}} - - # Analyze trends for each metric and layer - for metric_name in self.metrics.keys(): - analysis["metric_trends"][metric_name] = {} - - # Get all layers that have this metric - all_layers = set() - for rate_results in self.dropout_results["metrics_by_rate"].values(): - if metric_name in rate_results: - all_layers.update(rate_results[metric_name].keys()) - - for layer_name in all_layers: - # Collect values across dropout rates - rates = [] - values = [] - - for rate, results in self.dropout_results["metrics_by_rate"].items(): - if metric_name in results and layer_name in results[metric_name]: - rates.append(rate) - values.append(results[metric_name][layer_name]) - - if len(values) < 2: - continue - - # Compute trend statistics - values = np.array(values) - rates = np.array(rates) - - # Linear regression to find trend - coeffs = np.polyfit(rates, values, 1) - slope = coeffs[0] - - # Find critical points (large changes) - if len(values) > 2: - diffs = 
np.diff(values) - max_change_idx = np.argmax(np.abs(diffs)) - critical_rate = rates[max_change_idx + 1] - else: - critical_rate = None - - analysis["metric_trends"][metric_name][layer_name] = { - "slope": float(slope), - "initial_value": float(values[0]), - "final_value": float(values[-1]), - "percent_change": float((values[-1] - values[0]) / (values[0] + 1e-8) * 100), - "critical_dropout_rate": float(critical_rate) if critical_rate else None, - } - - # Compute layer sensitivity (how much each layer is affected by dropout) - for layer_name in self.wrapped_model.tracked_layers: - sensitivities = [] - - for metric_name in self.metrics.keys(): - if metric_name in analysis["metric_trends"] and layer_name in analysis["metric_trends"][metric_name]: - - trend = analysis["metric_trends"][metric_name][layer_name] - sensitivities.append(abs(trend["percent_change"])) - - if sensitivities: - analysis["layer_sensitivity"][layer_name] = { - "mean_sensitivity": float(np.mean(sensitivities)), - "max_sensitivity": float(np.max(sensitivities)), - "sensitivity_scores": sensitivities, - } - - self.dropout_results["analysis"] = analysis - logger.info("Dropout analysis complete") diff --git a/src/alignment/pruning/orchestrator.py b/src/alignment/pruning/orchestrator.py deleted file mode 100644 index cc3e787a..00000000 --- a/src/alignment/pruning/orchestrator.py +++ /dev/null @@ -1,321 +0,0 @@ -""" -Master Pruning Orchestrator - Complete pruning pipeline. - -Coordinates all aspects of pruning: -- Distribution strategy (how to allocate across layers) -- Scoring method (single metric or composite) -- Dynamic vs static scoring -- Parallel optimization -- Dependency handling - -Provides simple high-level API for comprehensive pruning experiments. 
-""" - -import logging -from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional - -import torch -import torch.nn as nn - -logger = logging.getLogger(__name__) - - -@dataclass -class PruningPlan: - """Complete pruning plan.""" - - distribution_strategy: str - scoring_method: str - per_layer_amounts: Dict[str, float] - per_layer_scores: Dict[str, torch.Tensor] - expected_sparsity: float - use_dynamic_scores: bool = False - - -class MasterPruningOrchestrator: - """ - High-level orchestrator for complete pruning workflows. - - Handles everything: - - Distribution across layers (uniform, adaptive, global, etc.) - - Scoring (single metric, composite, dynamic) - - Direction (low, high, random) - - Dependencies (conv, attention) - - Parallelization (multiple strategies/networks) - - One-liner API for comprehensive experiments! - - Example: - >>> orchestrator = MasterPruningOrchestrator() - >>> result = orchestrator.prune_complete( - ... model, - ... target_sparsity=0.7, - ... distribution='adaptive_sensitivity', - ... scoring='composite', - ... use_dynamic=True, - ... train_loader=train_loader, - ... val_loader=val_loader - ... ) - >>> print(f"Accuracy: {result['baseline']}% → {result['final']}%") - """ - - def __init__(self, verbose: bool = True, parallel: bool = False, num_workers: int = 4): - """ - Initialize orchestrator. 
- - Args: - verbose: Print detailed progress - parallel: Use parallel optimization when possible - num_workers: Number of parallel workers - """ - self.verbose = verbose - self.parallel = parallel - self.num_workers = num_workers - - def prune_complete( - self, - model: nn.Module, - target_sparsity: float, - distribution: str = "adaptive_sensitivity", - scoring: str = "composite", - direction: str = "low", - use_dynamic: bool = False, - train_loader=None, - val_loader=None, - trainer_fn: Optional[Callable] = None, - eval_fn: Optional[Callable] = None, - layers: Optional[List[str]] = None, - fine_tune_epochs: int = 20, - ) -> Dict[str, Any]: - """ - Complete pruning workflow with all options. - - Args: - model: Model to prune - target_sparsity: Overall target (e.g., 0.7 for 70%) - distribution: How to distribute across layers - - 'uniform': Same % per layer - - 'global_threshold': Global score threshold - - 'adaptive_sensitivity': Based on layer sensitivity (RECOMMENDED) - - 'importance_weighted': Based on avg scores - - 'cascading': Sequential - scoring: How to score neurons - - 'magnitude': L1/L2 norm - - 'rayleigh_quotient': RQ alignment - - 'composite': Multi-criteria (RECOMMENDED) - - 'movement': Training-aware - direction: Which neurons to prune - - 'low': Prune low scores (default) - - 'high': Prune high scores (ablation) - - 'random': Random (baseline) - use_dynamic: Use training evolution (requires train_loader) - train_loader: For dynamic scoring & fine-tuning - val_loader: For evaluation - trainer_fn: Function(model, train_loader, epochs) for fine-tuning - eval_fn: Function(model, val_loader) -> accuracy - layers: Specific layers to prune (None = auto-detect) - fine_tune_epochs: Epochs to fine-tune after pruning - - Returns: - Complete results dictionary - """ - if self.verbose: - print("\n" + "=" * 80) - print("Master Pruning Orchestrator") - print("=" * 80) - print(f"Target sparsity: {target_sparsity:.0%}") - print(f"Distribution: {distribution}") - 
print(f"Scoring: {scoring}") - print(f"Direction: {direction}") - print(f"Dynamic scores: {use_dynamic}") - print("=" * 80 + "\n") - - # Auto-detect layers if needed - if layers is None: - from ..core.layer_detector import detect_trackable_layers - - layers = detect_trackable_layers(model) - if self.verbose: - print(f"Auto-detected {len(layers)} trackable layers\n") - - # Baseline evaluation - baseline_acc = None - if eval_fn and val_loader: - baseline_acc = eval_fn(model, val_loader) - if self.verbose: - print(f"Baseline accuracy: {baseline_acc:.2f}%\n") - - # Step 1: Compute scores - if self.verbose: - print("Step 1: Computing importance scores...") - - if use_dynamic and train_loader: - layer_scores = self._compute_dynamic_scores(model, train_loader, layers, scoring) - else: - layer_scores = self._compute_static_scores(model, val_loader or train_loader, layers, scoring) - - if self.verbose: - print(f"Computed scores for {len(layer_scores)} layers\n") - - # Step 2: Compute distribution - if self.verbose: - print(f"Step 2: Computing {distribution} distribution...") - - from .distribution import PruningDistributionManager - - dist_manager = PruningDistributionManager(strategy=distribution, target_sparsity=target_sparsity) - - per_layer_amounts = dist_manager.compute_distribution( - model, layers, layer_scores=layer_scores, eval_fn=lambda m: eval_fn(m, val_loader) if eval_fn and val_loader else None - ) - - if self.verbose: - dist_manager.print_distribution(per_layer_amounts, model, layer_scores) - - # Step 3: Create masks - if self.verbose: - print("Step 3: Creating pruning masks...") - - from ..services import MaskOperations - - masks = {} - for layer_name in layers: - if layer_name not in layer_scores or layer_name not in per_layer_amounts: - continue - - mask = MaskOperations.create_structured_mask(layer_scores[layer_name], amount=per_layer_amounts[layer_name], mode=direction) - masks[layer_name] = mask - - if self.verbose: - print(f"Created masks for 
{len(masks)} layers\n") - - # Step 4: Apply with dependency awareness - if self.verbose: - print("Step 4: Applying pruning (dependency-aware)...") - - from .dependency_aware import DependencyAwarePruning - - dep_pruner = DependencyAwarePruning(model) - - # Convert masks to scores for dependency pruner interface - pruning_result = dep_pruner.prune(layer_scores, amount=target_sparsity, dry_run=False) # Overall target - - if self.verbose: - print("Applied pruning\n") - - # Step 5: Fine-tune - if trainer_fn and train_loader and fine_tune_epochs > 0: - if self.verbose: - print(f"Step 5: Fine-tuning for {fine_tune_epochs} epochs...") - - trainer_fn(model, train_loader, epochs=fine_tune_epochs) - - if self.verbose: - print("Fine-tuning complete\n") - - # Step 6: Final evaluation - final_acc = None - if eval_fn and val_loader: - final_acc = eval_fn(model, val_loader) - if self.verbose: - print(f"Final accuracy: {final_acc:.2f}%") - if baseline_acc: - drop = baseline_acc - final_acc - print(f"Accuracy drop: {drop:.2f}%\n") - - # Return complete results - return { - "baseline_accuracy": baseline_acc, - "final_accuracy": final_acc, - "accuracy_drop": baseline_acc - final_acc if baseline_acc and final_acc else None, - "target_sparsity": target_sparsity, - "distribution_strategy": distribution, - "scoring_method": scoring, - "per_layer_amounts": per_layer_amounts, - "masks": masks, - "pruning_stats": pruning_result["stats"], - } - - def _compute_static_scores(self, model: nn.Module, data_loader, layers: List[str], scoring: str) -> Dict[str, torch.Tensor]: - """Compute scores on current (trained) model.""" - from ..metrics import get_metric - from ..models import BaseModelWrapper - from ..services import ActivationCaptureService, NodeScoringService - - # Wrap model - wrapper = BaseModelWrapper(model, tracked_layers=layers) - capture = ActivationCaptureService(wrapper) - - # Get batch - inputs, targets = next(iter(data_loader)) - if torch.cuda.is_available(): - inputs = 
inputs.cuda() - targets = targets.cuda() - - # Capture activations - data = capture.capture(inputs, include_weights=True) - - # Compute scores based on method - if scoring == "magnitude": - scores = {} - for layer in layers: - if layer in data.weights: - weights = data.weights[layer] - scores[layer] = weights.abs().mean(dim=list(range(1, weights.ndim))) - - elif scoring == "rayleigh_quotient": - rq = get_metric("rayleigh_quotient") - scores = {} - for layer in layers: - if layer in data.inputs and layer in data.weights: - scores[layer] = rq.compute(data.inputs[layer], data.weights[layer]) - - elif scoring == "composite": - scorer = NodeScoringService( - metrics={ - "rq": get_metric("rayleigh_quotient"), - "redundancy": get_metric("pairwise_redundancy_gaussian", mode="output_based", num_pairs=10), - "synergy": get_metric("synergy_gaussian_mmi", num_pairs=10), - } - ) - - layer_scores_obj = scorer.compute_layerwise_scores(data, targets) - scores = {name: ls.composite for name, ls in layer_scores_obj.items()} - - else: - raise ValueError(f"Unknown scoring method: {scoring}") - - return scores - - def _compute_dynamic_scores(self, model: nn.Module, train_loader, layers: List[str], scoring: str) -> Dict[str, torch.Tensor]: - """ - Compute scores using training dynamics. - - Note: Requires training with callback - placeholder for now. - Future: Integrate with training history. - """ - logger.warning( - "Dynamic scoring requires training with AlignmentMetricsCallback. " - "Falling back to static scores. " - "See dynamic_scoring.py for full implementation." - ) - - # Fallback to static - return self._compute_static_scores(model, train_loader, layers, scoring) - - -def prune_with_all_options(model: nn.Module, target_sparsity: float = 0.7, **kwargs) -> Dict: - """ - One-liner for complete pruning with all options. - - Args: - model: Model to prune - target_sparsity: Target overall sparsity - **kwargs: All options (distribution, scoring, etc.) 
- - Returns: - Complete results - """ - orchestrator = MasterPruningOrchestrator() - return orchestrator.prune_complete(model, target_sparsity, **kwargs) diff --git a/src/alignment/pruning/parallel_optimizer.py b/src/alignment/pruning/parallel_optimizer.py deleted file mode 100644 index 459fe60a..00000000 --- a/src/alignment/pruning/parallel_optimizer.py +++ /dev/null @@ -1,306 +0,0 @@ -""" -Parallel pruning optimizer for maximum efficiency. - -Speeds up pruning by parallelizing across: -1. Multiple networks (ensemble analysis) -2. Multiple strategies (compare approaches) -3. Multiple layers (concurrent processing) -""" - -import copy -import logging -from concurrent.futures import ThreadPoolExecutor -from typing import Callable, Dict, List, Optional, Tuple - -import torch -import torch.nn as nn - -logger = logging.getLogger(__name__) - - -class ParallelPruningOptimizer: - """ - Optimize pruning by parallelizing computation. - - Key optimizations: - 1. Shared activation capture (one forward pass for all metrics) - 2. Batched metric computation (vectorized across neurons) - 3. Parallel strategy comparison (test multiple approaches) - 4. Multi-network ensemble pruning - - Performance: N strategies × M networks in ~1.5x time of single case - - Example: - >>> optimizer = ParallelPruningOptimizer() - >>> results = optimizer.compare_strategies_parallel( - ... model, - ... strategies=['magnitude', 'alignment', 'composite'], - ... amounts=[0.3, 0.5, 0.7], - ... data_loader=val_loader - ... ) - >>> # Results for all strategy×amount combinations in parallel! - """ - - def __init__(self, num_workers: int = 4, use_gpu: bool = True, shared_computation: bool = True): - """ - Initialize parallel optimizer. 
- - Args: - num_workers: Number of parallel workers - use_gpu: Whether to use GPU for computation - shared_computation: Share activations/covariances across tasks - """ - self.num_workers = num_workers - self.use_gpu = use_gpu - self.shared_computation = shared_computation - - def compare_strategies_parallel( - self, base_model: nn.Module, strategies: List[str], amounts: List[float], data_loader, eval_fn: Callable, layers: Optional[List[str]] = None - ) -> Dict[Tuple[str, float], Dict]: - """ - Compare multiple pruning strategies in parallel. - - Args: - base_model: Base model (will be copied for each strategy) - strategies: List of strategy names to compare - amounts: List of pruning amounts to try - data_loader: Data for metric computation - eval_fn: Evaluation function - layers: Layers to prune (None = auto-detect) - - Returns: - Dict[(strategy, amount)] -> {'accuracy': X, 'mask': M, ...} - """ - # Create all strategy×amount combinations - experiments = [(strategy, amount) for strategy in strategies for amount in amounts] - - logger.info(f"Running {len(experiments)} experiments in parallel...") - - # Shared computation: capture activations once - if self.shared_computation: - shared_data = self._capture_shared_data(base_model, data_loader, layers) - else: - shared_data = None - - # Parallel execution - results = {} - - # For GPU, sequential is better (avoid memory issues) - # For CPU, can parallelize - if self.use_gpu or self.num_workers == 1: - # Sequential on GPU - for strategy, amount in experiments: - result = self._run_single_experiment(base_model, strategy, amount, eval_fn, shared_data, layers) - results[(strategy, amount)] = result - else: - # Parallel on CPU - with ThreadPoolExecutor(max_workers=self.num_workers) as executor: - futures = { - executor.submit(self._run_single_experiment, base_model, strategy, amount, eval_fn, shared_data, layers): (strategy, amount) - for strategy, amount in experiments - } - - for future in futures: - strategy, amount = 
futures[future] - results[(strategy, amount)] = future.result() - - # Print comparison - self._print_comparison(results, strategies, amounts) - - return results - - def _capture_shared_data(self, model: nn.Module, data_loader, layers: Optional[List[str]]) -> Dict: - """Capture activations and weights once for all strategies.""" - from ..models import BaseModelWrapper - from ..services import ActivationCaptureService - - wrapper = BaseModelWrapper(model, tracked_layers=layers) - capture = ActivationCaptureService(wrapper) - - # Capture on a batch - inputs, targets = next(iter(data_loader)) - if self.use_gpu and torch.cuda.is_available(): - inputs = inputs.cuda() - targets = targets.cuda() - - data = capture.capture(inputs, include_weights=True) - - return {"activation_data": data, "targets": targets} - - def _run_single_experiment( - self, base_model: nn.Module, strategy: str, amount: float, eval_fn: Callable, shared_data: Optional[Dict], layers: Optional[List[str]] - ) -> Dict: - """Run a single pruning experiment.""" - from ..metrics import get_metric - from ..services import MaskOperations, NodeScoringService - - # Clone model - model = copy.deepcopy(base_model) - - # Compute scores using shared data - if shared_data and strategy in ["alignment", "composite"]: - data = shared_data["activation_data"] - targets = shared_data["targets"] - - if strategy == "alignment": - scorer = NodeScoringService(metrics={"rq": get_metric("rayleigh_quotient")}) - else: # composite - scorer = NodeScoringService( - metrics={"rq": get_metric("rayleigh_quotient"), "redundancy": get_metric("pairwise_redundancy_gaussian", mode="output_based")} - ) - - layer_scores = scorer.compute_layerwise_scores(data, targets) - scores_dict = {name: layer_scores[name].composite for name in layer_scores} - - elif strategy == "magnitude": - # Magnitude scores - scores_dict = {} - for name, module in model.named_modules(): - if layers is None or name in layers: - if hasattr(module, "weight"): - 
scores_dict[name] = module.weight.abs().mean(dim=list(range(1, module.weight.ndim))) - - elif strategy == "random": - # Random scores - scores_dict = {} - for name, module in model.named_modules(): - if layers is None or name in layers: - if hasattr(module, "weight"): - out_dim = module.weight.shape[0] - scores_dict[name] = torch.rand(out_dim) - - else: - raise ValueError(f"Unknown strategy: {strategy}") - - # Create masks - masks = {} - for layer_name, scores in scores_dict.items(): - mask = MaskOperations.create_structured_mask(scores, amount, mode="low") - masks[layer_name] = mask - - # Apply pruning - for layer_name, mask in masks.items(): - module = dict(model.named_modules())[layer_name] - if hasattr(module, "weight"): - module.weight.data *= mask.unsqueeze(1).float() - - # Evaluate - accuracy = eval_fn(model) - - return {"strategy": strategy, "amount": amount, "accuracy": accuracy, "masks": masks} - - def _print_comparison(self, results: Dict, strategies: List[str], amounts: List[float]): - """Print comparison table.""" - print("\n" + "=" * 80) - print("Parallel Strategy Comparison") - print("=" * 80) - - # Create table - print(f"\n{'Strategy':<20} ", end="") - for amount in amounts: - print(f"{amount:>8.0%} ", end="") - print() - print("-" * 80) - - for strategy in strategies: - print(f"{strategy:<20} ", end="") - for amount in amounts: - key = (strategy, amount) - if key in results: - acc = results[key]["accuracy"] - print(f"{acc:>8.2f}% ", end="") - else: - print(f"{'N/A':>9} ", end="") - print() - - print("=" * 80 + "\n") - - def prune_ensemble_parallel(self, networks: List[nn.Module], strategy: str, amount: float, shared_inputs: torch.Tensor) -> List[Dict]: - """ - Prune multiple networks in parallel with shared computation. 
- - Args: - networks: List of networks (same architecture) - strategy: Pruning strategy - amount: Pruning amount - shared_inputs: Input batch (same for all networks) - - Returns: - List of results per network - """ - # Shared: Compute covariance once - if self.shared_computation: - shared_cov = torch.cov(shared_inputs.T) - else: - shared_cov = None - - # Process each network - results = [] - - for net_idx, network in enumerate(networks): - # Compute scores (using shared covariance if available) - scores = self._compute_scores_with_shared_cov(network, shared_inputs, shared_cov, strategy) - - # Prune - masks = self._create_and_apply_masks(network, scores, amount) - - results.append({"network_idx": net_idx, "masks": masks, "strategy": strategy, "amount": amount}) - - logger.info(f"Pruned {len(networks)} networks in parallel") - - return results - - def _compute_scores_with_shared_cov( - self, network: nn.Module, inputs: torch.Tensor, shared_cov: Optional[torch.Tensor], strategy: str - ) -> Dict[str, torch.Tensor]: - """Compute scores using shared covariance.""" - from ..metrics import get_metric - - scores = {} - - if strategy == "alignment" or strategy == "composite": - rq = get_metric("rayleigh_quotient") - - for name, module in network.named_modules(): - if hasattr(module, "weight"): - if shared_cov is not None: - # Use shared covariance (FAST!) 
- weights = module.weight - if weights.ndim > 2: - weights = weights.reshape(weights.shape[0], -1) - - # RQ = (w @ cov @ w.T).diag() / (w @ w.T).diag() / tr(cov) - wc = weights @ shared_cov - numerator = (wc * weights).sum(dim=1) - denominator = (weights**2).sum(dim=1) - rq_scores = numerator / (denominator + 1e-12) - rq_scores = rq_scores / (shared_cov.trace() + 1e-12) - - scores[name] = rq_scores - else: - # Compute normally - scores[name] = rq.compute(inputs, module.weight) - - else: # magnitude or other - for name, module in network.named_modules(): - if hasattr(module, "weight"): - scores[name] = module.weight.abs().mean(dim=list(range(1, module.weight.ndim))) - - return scores - - def _create_and_apply_masks(self, network: nn.Module, scores: Dict[str, torch.Tensor], amount: float) -> Dict[str, torch.Tensor]: - """Create and apply masks.""" - from ..services import MaskOperations - - masks = {} - - for layer_name, layer_scores in scores.items(): - mask = MaskOperations.create_structured_mask(layer_scores, amount, mode="low") - masks[layer_name] = mask - - # Apply - module = dict(network.named_modules())[layer_name] - if hasattr(module, "weight"): - module.weight.data *= mask.unsqueeze(1).float() - - return masks diff --git a/src/alignment/pruning/pipeline.py b/src/alignment/pruning/pipeline.py new file mode 100644 index 00000000..c52342bc --- /dev/null +++ b/src/alignment/pruning/pipeline.py @@ -0,0 +1,140 @@ +""" +Shared structured pruning pipeline. + +Provides a thin wrapper around existing pruning utilities (distribution +manager, dependency-aware pruning, mask ops) so experiments can invoke +pruning consistently through configuration instead of bespoke loops. 
+""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import torch +import torch.nn as nn + +from .dependency_aware import DependencyAwarePruning +from .distribution import PruningDistributionManager +from ..services.mask_ops import MaskOperations + +logger = logging.getLogger(__name__) + + +@dataclass +class PruningPipelineOptions: + """Options controlling the pruning pipeline behaviour.""" + + distribution: str = "uniform" + dependency_aware: bool = False + min_amount: float = 0.0 + max_amount: float = 0.95 + + +def _ensure_tensor(scores) -> torch.Tensor: + if isinstance(scores, torch.Tensor): + return scores + return torch.as_tensor(scores, dtype=torch.float32) + + +def _apply_masks_to_modules(layer_modules: Dict[str, nn.Module], masks: Dict[str, torch.Tensor]) -> None: + """Apply channel masks directly to module weights/biases.""" + for layer_name, mask in masks.items(): + module = layer_modules.get(layer_name) + if module is None or not hasattr(module, "weight"): + continue + + mask = mask.to(module.weight.device) + weight_mask = MaskOperations.expand_neuron_mask_to_weights(mask, module.weight.data.shape, dim=0) + masked_weights = MaskOperations.apply_mask_to_weights(module.weight.data, weight_mask, mode="zero") + + with torch.no_grad(): + module.weight.data.copy_(masked_weights) + if getattr(module, "bias", None) is not None and mask.numel() == module.bias.data.numel(): + bias_mask = mask.bool() + module.bias.data[~bias_mask] = 0.0 + + +def run_pruning_pipeline( + model: nn.Module, + layer_scores: Dict[str, torch.Tensor], + *, + layer_modules: Optional[Dict[str, nn.Module]] = None, + target_sparsity: float, + selection_mode: str = "low", + options: Optional[PruningPipelineOptions] = None, +) -> Dict[str, Any]: + """ + Execute pruning using shared infrastructure. + + Args: + model: Model to prune. + layer_scores: Dict mapping layer_name -> channel scores tensor. 
+ layer_modules: Optional mapping layer_name -> module (defaults to model.named_modules()). + target_sparsity: Overall sparsity target (0-1). + selection_mode: 'low', 'high', or 'random'. + options: Additional pipeline options. + + Returns: + Dict with 'masks' and optional stats from dependency-aware pruning. + """ + if not layer_scores: + logger.warning("No layer scores provided; skipping pruning") + return {"masks": {}} + + options = options or PruningPipelineOptions() + layer_modules = layer_modules or dict(model.named_modules()) + + tensor_scores = {name: _ensure_tensor(score) for name, score in layer_scores.items()} + layer_names = [name for name in tensor_scores.keys() if name in layer_modules] + + if not layer_names: + logger.warning("Layer score names do not match model modules; skipping pruning") + return {"masks": {}} + + distribution = options.distribution or "uniform" + + if options.dependency_aware: + manager = PruningDistributionManager( + strategy=distribution, + target_sparsity=target_sparsity, + min_amount=options.min_amount, + max_amount=options.max_amount, + ) + per_layer_amounts = manager.compute_distribution(model, layer_names, layer_scores=tensor_scores) + + dep_pruner = DependencyAwarePruning(model) + result = dep_pruner.prune( + tensor_scores, + amount=target_sparsity, + mode=selection_mode, + per_layer_amounts=per_layer_amounts, + ) + flat_masks = {} + for name, data in result["masks"].items(): + output_mask = data.get("output_mask") if isinstance(data, dict) else None + if output_mask is not None: + flat_masks[name] = output_mask + result["masks"] = flat_masks + return result + + if distribution in {"global_threshold", "global"}: + masks = MaskOperations.global_threshold_mask(tensor_scores, global_amount=target_sparsity, mode=selection_mode) + else: + manager = PruningDistributionManager( + strategy=distribution, + target_sparsity=target_sparsity, + min_amount=options.min_amount, + max_amount=options.max_amount, + ) + per_layer_amounts = 
manager.compute_distribution(model, layer_names, layer_scores=tensor_scores) + masks = {} + for name in layer_names: + amount = per_layer_amounts.get(name, target_sparsity) + masks[name] = MaskOperations.create_structured_mask(tensor_scores[name], amount=amount, mode=selection_mode) + + _apply_masks_to_modules(layer_modules, masks) + + stats = {name: MaskOperations.get_mask_statistics(mask) for name, mask in masks.items()} + return {"masks": masks, "stats": stats} diff --git a/src/alignment/pruning/strategies/__init__.py b/src/alignment/pruning/strategies/__init__.py index d15d3e53..42682dfb 100644 --- a/src/alignment/pruning/strategies/__init__.py +++ b/src/alignment/pruning/strategies/__init__.py @@ -2,10 +2,13 @@ Pruning strategies for the alignment framework. """ +from .adaptive import AdaptiveSensitivityPruning, LayerSensitivity from .alignment_based import AlignmentPruning, GlobalAlignmentPruning, HybridPruning from .cascading import CascadingAlignmentPruning from .cluster_aware import ClusterAwarePruning, ClusterAwarePruningConfig, CompositePruning +from .eigenvector import EigenvectorPruning from .gradient import FisherPruning, GradientPruning, MomentumPruning +from .movement import AdaptiveMovementPruning, MovementPruning from .llm_baselines import WandaPruning, SparseGPTPruning from .magnitude import GlobalMagnitudePruning, IterativeMagnitudePruning, MagnitudePruning from .parallel import AsyncParallelPruning, ParallelModePruning, TensorizedPruning @@ -35,6 +38,14 @@ "HybridPruning", "GlobalAlignmentPruning", "CascadingAlignmentPruning", + # Eigenvector-based (PCA pruning) + "EigenvectorPruning", + # Movement-based (Sanh et al. 
2020) + "MovementPruning", + "AdaptiveMovementPruning", + # Adaptive sensitivity-based + "AdaptiveSensitivityPruning", + "LayerSensitivity", # Cluster-aware (vision paper) "ClusterAwarePruning", "ClusterAwarePruningConfig", diff --git a/src/alignment/pruning/strategies/adaptive.py b/src/alignment/pruning/strategies/adaptive.py index 3169a724..73a99d33 100644 --- a/src/alignment/pruning/strategies/adaptive.py +++ b/src/alignment/pruning/strategies/adaptive.py @@ -54,13 +54,25 @@ class AdaptiveSensitivityPruning(BasePruningStrategy): >>> # Overall: 70% average """ + # Available sensitivity methods + SENSITIVITY_METHODS = [ + "perturbation", # Add Gaussian noise, measure accuracy drop (slow but accurate) + "masking", # Random mask 30% weights, measure accuracy drop (slow) + "activation_variance", # Use activation variance as proxy (fast, single forward pass) + "gradient", # Use gradient magnitude (fast, single backward pass) + "fisher", # Fisher information approximation (moderate speed) + "weight_magnitude", # Use weight magnitude as proxy (fastest, no forward pass) + ] + def __init__( self, target_sparsity: float = 0.5, metric: str = "rayleigh_quotient", - sensitivity_method: str = "perturbation", # 'perturbation', 'gradient', 'hessian' + sensitivity_method: str = "perturbation", min_amount: float = 0.1, max_amount: float = 0.9, + perturbation_scale: float = 0.1, + num_trials: int = 3, config: Optional[PruningConfig] = None, **metric_kwargs, ): @@ -69,10 +81,18 @@ def __init__( Args: target_sparsity: Target overall sparsity (0-1) - metric: Metric to use for importance scores - sensitivity_method: How to measure sensitivity - min_amount: Minimum pruning per layer - max_amount: Maximum pruning per layer + metric: Metric to use for importance scores within each layer + sensitivity_method: How to measure layer sensitivity. 
Options: + - 'perturbation': Add Gaussian noise, measure accuracy drop (slow but accurate) + - 'masking': Random mask weights, measure accuracy drop (slow) + - 'activation_variance': Use activation variance (fast, single forward pass) + - 'gradient': Use gradient magnitude (fast, single backward pass) + - 'fisher': Fisher information approximation (moderate speed) + - 'weight_magnitude': Use weight magnitude as proxy (fastest) + min_amount: Minimum pruning per layer (0-1) + max_amount: Maximum pruning per layer (0-1) + perturbation_scale: Scale of perturbation for 'perturbation' method + num_trials: Number of trials for perturbation/masking methods config: Pruning configuration **metric_kwargs: Arguments for metric initialization """ @@ -83,6 +103,14 @@ def __init__( self.sensitivity_method = sensitivity_method self.min_amount = min_amount self.max_amount = max_amount + self.perturbation_scale = perturbation_scale + self.num_trials = num_trials + + if sensitivity_method not in self.SENSITIVITY_METHODS: + raise ValueError( + f"Unknown sensitivity_method: {sensitivity_method}. " + f"Available: {self.SENSITIVITY_METHODS}" + ) # Will be populated during sensitivity analysis self.layer_sensitivities: Dict[str, LayerSensitivity] = {} @@ -115,7 +143,12 @@ def compute_importance_scores(self, module: nn.Module, inputs: Optional[torch.Te return scores def measure_layer_sensitivity( - self, model: nn.Module, layer_name: str, eval_fn: Callable, perturbation_scale: float = 0.1, num_trials: int = 3 + self, + model: nn.Module, + layer_name: str, + eval_fn: Optional[Callable] = None, + data_loader=None, + cached_activations: Optional[Dict[str, torch.Tensor]] = None, ) -> float: """ Measure sensitivity of a single layer. 
@@ -124,31 +157,66 @@ def measure_layer_sensitivity( model: Full model layer_name: Name of layer to test eval_fn: Function that evaluates model and returns accuracy - perturbation_scale: Scale of perturbation - num_trials: Number of trials to average + (required for perturbation/masking methods) + data_loader: Data loader for gradient/fisher methods + cached_activations: Pre-captured activations for activation_variance method Returns: - Sensitivity (accuracy drop when layer perturbed) + Sensitivity score (higher = more sensitive = prune less) """ - # Get baseline accuracy - baseline_acc = eval_fn(model) - - # Get layer layer = dict(model.named_modules())[layer_name] if not hasattr(layer, "weight"): return 0.0 + # Fast methods that don't require eval_fn + if self.sensitivity_method == "weight_magnitude": + # Use average weight magnitude as sensitivity proxy + # Higher magnitude = more important = more sensitive + return layer.weight.data.abs().mean().item() + + elif self.sensitivity_method == "activation_variance": + # Use activation variance as sensitivity proxy + # Higher variance = more information = more sensitive + if cached_activations is not None and layer_name in cached_activations: + activations = cached_activations[layer_name] + # Variance across batch dimension + return activations.var(dim=0).mean().item() + else: + logger.warning(f"No cached activations for {layer_name}, using weight magnitude") + return layer.weight.data.abs().mean().item() + + elif self.sensitivity_method == "gradient": + # Use gradient magnitude as sensitivity proxy + # Requires a backward pass with data + if data_loader is None: + logger.warning("gradient method requires data_loader, using weight magnitude") + return layer.weight.data.abs().mean().item() + return self._compute_gradient_sensitivity(model, layer_name, data_loader) + + elif self.sensitivity_method == "fisher": + # Fisher information approximation + if data_loader is None: + logger.warning("fisher method requires 
data_loader, using weight magnitude") + return layer.weight.data.abs().mean().item() + return self._compute_fisher_sensitivity(model, layer_name, data_loader) + + # Slow methods that require eval_fn + if eval_fn is None: + raise ValueError(f"sensitivity_method '{self.sensitivity_method}' requires eval_fn") + + # Get baseline accuracy + baseline_acc = eval_fn(model) + # Store original weight original_weight = layer.weight.data.clone() - # Measure sensitivity via perturbation + # Measure sensitivity via perturbation/masking sensitivities = [] - for _ in range(num_trials): - # Perturb layer + for _ in range(self.num_trials): if self.sensitivity_method == "perturbation": - perturbation = perturbation_scale * torch.randn_like(layer.weight) + perturbation = self.perturbation_scale * torch.randn_like(layer.weight) layer.weight.data = original_weight + perturbation elif self.sensitivity_method == "masking": @@ -166,26 +234,129 @@ def measure_layer_sensitivity( # Restore layer.weight.data = original_weight - # Average sensitivity avg_sensitivity = sum(sensitivities) / len(sensitivities) - logger.debug(f"{layer_name}: sensitivity = {avg_sensitivity:.4f}") return avg_sensitivity - def compute_all_sensitivities(self, model: nn.Module, layer_names: List[str], eval_fn: Callable) -> Dict[str, LayerSensitivity]: + def _compute_gradient_sensitivity( + self, model: nn.Module, layer_name: str, data_loader + ) -> float: + """Compute sensitivity using gradient magnitude.""" + layer = dict(model.named_modules())[layer_name] + + model.train() + total_grad = None + num_batches = 0 + + # Get a few batches + for i, (inputs, targets) in enumerate(data_loader): + if i >= 3: # Only use 3 batches for speed + break + + if torch.cuda.is_available(): + inputs = inputs.cuda() + targets = targets.cuda() + + model.zero_grad() + outputs = model(inputs) + loss = torch.nn.functional.cross_entropy(outputs, targets) + loss.backward() + + if layer.weight.grad is not None: + grad = layer.weight.grad.abs() + 
if total_grad is None: + total_grad = grad.clone() + else: + total_grad += grad + num_batches += 1 + + model.eval() + + if total_grad is None or num_batches == 0: + return 0.0 + + # Average gradient magnitude = sensitivity + return (total_grad / num_batches).mean().item() + + def _compute_fisher_sensitivity( + self, model: nn.Module, layer_name: str, data_loader + ) -> float: + """Compute sensitivity using Fisher information approximation.""" + layer = dict(model.named_modules())[layer_name] + + model.train() + fisher_diag = None + num_batches = 0 + + for i, (inputs, targets) in enumerate(data_loader): + if i >= 3: # Only use 3 batches for speed + break + + if torch.cuda.is_available(): + inputs = inputs.cuda() + targets = targets.cuda() + + model.zero_grad() + outputs = model(inputs) + + # For Fisher, we use log-likelihood gradient squared + log_probs = torch.nn.functional.log_softmax(outputs, dim=1) + # Sample from predicted distribution + with torch.no_grad(): + sampled_labels = torch.multinomial(torch.exp(log_probs), 1).squeeze() + + loss = torch.nn.functional.nll_loss(log_probs, sampled_labels) + loss.backward() + + if layer.weight.grad is not None: + grad_sq = layer.weight.grad ** 2 + if fisher_diag is None: + fisher_diag = grad_sq.clone() + else: + fisher_diag += grad_sq + num_batches += 1 + + model.eval() + + if fisher_diag is None or num_batches == 0: + return 0.0 + + # Fisher information = sensitivity + return (fisher_diag / num_batches).mean().item() + + def compute_all_sensitivities( + self, + model: nn.Module, + layer_names: List[str], + eval_fn: Optional[Callable] = None, + data_loader=None, + cached_activations: Optional[Dict[str, torch.Tensor]] = None, + ) -> Dict[str, LayerSensitivity]: """ Compute sensitivities for all specified layers. 
Args: model: Model to analyze layer_names: Layers to analyze - eval_fn: Evaluation function + eval_fn: Evaluation function (required for perturbation/masking methods) + data_loader: Data loader (required for gradient/fisher methods) + cached_activations: Pre-captured activations (for activation_variance method) Returns: Dict mapping layer names to LayerSensitivity objects """ - logger.info(f"Computing sensitivities for {len(layer_names)} layers...") + logger.info(f"Computing sensitivities for {len(layer_names)} layers using '{self.sensitivity_method}' method...") + + # Validate requirements + if self.sensitivity_method in ["perturbation", "masking"] and eval_fn is None: + raise ValueError(f"sensitivity_method '{self.sensitivity_method}' requires eval_fn") + if self.sensitivity_method in ["gradient", "fisher"] and data_loader is None: + raise ValueError(f"sensitivity_method '{self.sensitivity_method}' requires data_loader") + + # For activation_variance, capture activations if not provided + if self.sensitivity_method == "activation_variance" and cached_activations is None and data_loader is not None: + cached_activations = self._capture_activations(model, layer_names, data_loader) sensitivities = {} @@ -193,7 +364,13 @@ def compute_all_sensitivities(self, model: nn.Module, layer_names: List[str], ev layer = dict(model.named_modules())[layer_name] # Measure sensitivity - sens = self.measure_layer_sensitivity(model, layer_name, eval_fn) + sens = self.measure_layer_sensitivity( + model, + layer_name, + eval_fn=eval_fn, + data_loader=data_loader, + cached_activations=cached_activations, + ) # Get layer size size = layer.weight.numel() @@ -210,6 +387,39 @@ def compute_all_sensitivities(self, model: nn.Module, layer_names: List[str], ev return sensitivities + def _capture_activations( + self, model: nn.Module, layer_names: List[str], data_loader + ) -> Dict[str, torch.Tensor]: + """Capture activations for all layers in a single forward pass.""" + activations = {} + hooks 
= [] + + def make_hook(name): + def hook(module, input, output): + if isinstance(output, tuple): + output = output[0] + activations[name] = output.detach() + return hook + + # Register hooks + for name, module in model.named_modules(): + if name in layer_names: + hooks.append(module.register_forward_hook(make_hook(name))) + + # Forward pass with one batch + model.eval() + with torch.no_grad(): + inputs, _ = next(iter(data_loader)) + if torch.cuda.is_available(): + inputs = inputs.cuda() + model(inputs) + + # Remove hooks + for hook in hooks: + hook.remove() + + return activations + def _compute_adaptive_amounts(self, sensitivities: Dict[str, LayerSensitivity]) -> Dict[str, LayerSensitivity]: """ Compute adaptive pruning amounts based on sensitivities. diff --git a/src/alignment/pruning/strategies/eigenvector.py b/src/alignment/pruning/strategies/eigenvector.py new file mode 100644 index 00000000..a4bceec2 --- /dev/null +++ b/src/alignment/pruning/strategies/eigenvector.py @@ -0,0 +1,288 @@ +""" +Eigenvector-based pruning strategy. + +This module implements pruning based on PCA/eigendecomposition, +dropping neurons based on their contribution to principal components. +Neurons aligned with low-variance directions are pruned first. +""" + +import logging +from typing import Dict, Optional, Tuple + +import torch +import torch.nn as nn + +from ..base import BasePruningStrategy, PruningConfig + +logger = logging.getLogger(__name__) + + +class EigenvectorPruning(BasePruningStrategy): + """ + Eigenvector-based pruning strategy. + + This strategy computes eigendecomposition of the activation covariance + matrix and prunes neurons based on their contribution to the principal + components. Neurons with low eigenvalue contributions are considered + less important and pruned first. 
+ + Two modes are supported: + - 'low': Prune neurons aligned with low-variance directions (default) + - 'high': Prune neurons aligned with high-variance directions (ablation) + + Examples: + >>> from alignment.pruning.strategies import EigenvectorPruning + >>> from alignment.pruning import PruningConfig + >>> + >>> config = PruningConfig(amount=0.5, structured=True, pruning_mode='low') + >>> strategy = EigenvectorPruning(config=config) + >>> + >>> # Compute scores from activations + >>> scores = strategy.compute_importance_scores(module, inputs=activations) + >>> strategy.prune(module, inputs=activations) + """ + + def __init__( + self, + config: Optional[PruningConfig] = None, + variance_threshold: float = 0.99, + use_correlation: bool = False, + regularization: float = 1e-6, + ): + """ + Initialize eigenvector pruning. + + Args: + config: Pruning configuration + variance_threshold: Fraction of variance to explain (for ranking) + use_correlation: If True, use correlation matrix instead of covariance + regularization: Small value added to diagonal for numerical stability + """ + super().__init__(config) + self.variance_threshold = variance_threshold + self.use_correlation = use_correlation + self.regularization = regularization + + # Force structured pruning (eigenvector pruning is inherently structured) + if self.config: + self.config.structured = True + + def _compute_activation_covariance( + self, + activations: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute covariance matrix and mean of activations. + + Args: + activations: Input activations [batch, neurons, ...] 
or [batch, neurons] + + Returns: + Tuple of (covariance matrix, mean vector) + """ + # Flatten spatial dimensions if present (for conv layers) + if activations.dim() > 2: + # [B, C, H, W] -> [B*H*W, C] + batch_size, channels = activations.shape[:2] + activations = activations.permute(0, 2, 3, 1).reshape(-1, channels) + elif activations.dim() == 2: + # [B, N] is already correct + pass + else: + raise ValueError(f"Unexpected activation shape: {activations.shape}") + + # Compute mean + mean = activations.mean(dim=0) + + # Center activations + centered = activations - mean + + # Compute covariance: (X^T @ X) / (n-1) + n_samples = centered.shape[0] + cov = (centered.T @ centered) / max(n_samples - 1, 1) + + # Add regularization for numerical stability + cov = cov + self.regularization * torch.eye(cov.shape[0], device=cov.device) + + # Optionally convert to correlation matrix + if self.use_correlation: + std = torch.sqrt(torch.diag(cov)) + std = torch.clamp(std, min=1e-8) # Avoid division by zero + cov = cov / (std.unsqueeze(0) * std.unsqueeze(1)) + + return cov, mean + + def _compute_eigendecomposition( + self, + cov: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute eigendecomposition of covariance matrix. + + Args: + cov: Covariance matrix [N, N] + + Returns: + Tuple of (eigenvalues, eigenvectors) sorted by eigenvalue descending + """ + # Compute eigendecomposition + try: + eigenvalues, eigenvectors = torch.linalg.eigh(cov) + except Exception as e: + logger.warning(f"Eigendecomposition failed: {e}. 
Using SVD fallback.") + # SVD fallback for numerical issues + U, S, Vh = torch.linalg.svd(cov) + eigenvalues = S + eigenvectors = U + + # Sort by eigenvalue descending (largest first) + sorted_indices = torch.argsort(eigenvalues, descending=True) + eigenvalues = eigenvalues[sorted_indices] + eigenvectors = eigenvectors[:, sorted_indices] + + return eigenvalues, eigenvectors + + def compute_importance_scores( + self, + module: nn.Module, + inputs: Optional[torch.Tensor] = None, + **kwargs + ) -> torch.Tensor: + """ + Compute neuron importance based on eigenvalue contributions. + + Each neuron's importance is based on its contribution to the + variance explained by the principal components. + + Args: + module: Module to compute scores for + inputs: Input activations [batch, neurons, ...] + + Returns: + Importance scores per neuron (higher = more important) + """ + if inputs is None: + raise ValueError("EigenvectorPruning requires input activations") + + # Get number of neurons from module + if hasattr(module, 'weight'): + n_neurons = module.weight.shape[0] + else: + raise ValueError("Module must have weight attribute") + + # Compute covariance of activations + cov, _ = self._compute_activation_covariance(inputs) + + # Compute eigendecomposition + eigenvalues, eigenvectors = self._compute_eigendecomposition(cov) + + # Compute importance scores for each neuron + # Score = sum of (eigenvalue * squared loading) for each neuron + # This measures how much variance each neuron contributes to + + # Normalize eigenvalues to get variance explained + total_variance = eigenvalues.sum() + if total_variance > 0: + variance_explained = eigenvalues / total_variance + else: + variance_explained = torch.ones_like(eigenvalues) / len(eigenvalues) + + # Compute neuron scores + # For each neuron i: score_i = sum_j(lambda_j * v_ij^2) + # where lambda_j is the j-th eigenvalue and v_ij is the loading + loadings_squared = eigenvectors ** 2 # [neurons, components] + neuron_scores = 
(loadings_squared * variance_explained.unsqueeze(0)).sum(dim=1) + + # Ensure scores are positive and on correct device + neuron_scores = torch.clamp(neuron_scores, min=0) + + if neuron_scores.device != module.weight.device: + neuron_scores = neuron_scores.to(module.weight.device) + + logger.debug(f"Eigenvalue importance: min={neuron_scores.min():.4f}, " + f"max={neuron_scores.max():.4f}, mean={neuron_scores.mean():.4f}") + + return neuron_scores + + def get_variance_explained( + self, + module: nn.Module, + inputs: torch.Tensor, + n_components: Optional[int] = None + ) -> Tuple[torch.Tensor, float]: + """ + Get cumulative variance explained by top N components. + + Args: + module: Module to analyze + inputs: Input activations + n_components: Number of components (None = all) + + Returns: + Tuple of (cumulative variance ratios, total variance) + """ + cov, _ = self._compute_activation_covariance(inputs) + eigenvalues, _ = self._compute_eigendecomposition(cov) + + total_variance = eigenvalues.sum().item() + cumulative = torch.cumsum(eigenvalues, dim=0) / total_variance + + if n_components is not None: + cumulative = cumulative[:n_components] + + return cumulative, total_variance + + def prune( + self, + module: nn.Module, + inputs: Optional[torch.Tensor] = None, + amount: Optional[float] = None, + **kwargs + ) -> torch.Tensor: + """ + Prune module based on eigenvector importance. 
+ + Args: + module: Module to prune + inputs: Input activations (required) + amount: Fraction to prune (overrides config) + + Returns: + Pruning mask + """ + if inputs is None: + raise ValueError("EigenvectorPruning requires input activations") + + amount = amount if amount is not None else self.config.amount + + # Compute importance scores + scores = self.compute_importance_scores(module, inputs) + + # Create structured mask + n_neurons = scores.numel() + k = int(amount * n_neurons) + + if k == 0: + mask = torch.ones_like(module.weight) + else: + keep_mask = torch.ones(n_neurons, dtype=torch.bool, device=scores.device) + + if self.config.pruning_mode == "low": + # Prune neurons with LOWEST eigenvalue contribution + _, indices_to_prune = torch.topk(scores, k, largest=False) + else: # 'high' mode + # Prune neurons with HIGHEST eigenvalue contribution (ablation) + _, indices_to_prune = torch.topk(scores, k, largest=True) + + keep_mask[indices_to_prune] = False + + # Expand to weight dimensions + if len(module.weight.shape) == 2: # Linear + mask = keep_mask.unsqueeze(1).expand_as(module.weight).float() + else: # Conv + mask = keep_mask.view(-1, 1, 1, 1).expand_as(module.weight).float() + + # Apply pruning + self.apply_pruning(module, mask) + + return mask diff --git a/src/alignment/pruning/strategies/ultimate.py b/src/alignment/pruning/strategies/ultimate.py deleted file mode 100644 index c43bacb9..00000000 --- a/src/alignment/pruning/strategies/ultimate.py +++ /dev/null @@ -1,281 +0,0 @@ -""" -Ultimate pruning strategy combining all best practices. - -Combines: -1. Adaptive layer-wise amounts (sensitivity-based) -2. Composite scoring (redundancy-aware) -3. Progressive stages (safe → refined) -4. Dependency-aware application - -Expected: Best possible accuracy retention at high sparsity. 
-""" - -import logging -from typing import Any, Callable, Dict, List, Optional - -import torch.nn as nn - -from ..dependency_aware import DependencyAwarePruning -from .adaptive import AdaptiveSensitivityPruning - -logger = logging.getLogger(__name__) - - -class UltimatePruningStrategy: - """ - State-of-the-art pruning combining multiple advanced techniques. - - Multi-stage progressive pruning with adaptive per-layer amounts - and redundancy-aware composite scoring. - - Stages: - 1. Sensitivity Analysis → adaptive per-layer amounts - 2. Coarse Pruning (magnitude) → safe initial pruning - 3. Refined Pruning (composite) → redundancy-aware - 4. Cleanup → remove truly dead neurons - - Expected performance: - - 70% sparsity: ~5% accuracy drop (vs 10% for magnitude) - - 85% sparsity: ~12% drop (vs 20% for magnitude) - - Example: - >>> strategy = UltimatePruningStrategy( - ... target_sparsity=0.7, - ... stages='full' # or 'fast' for fewer stages - ... ) - >>> result = strategy.prune(model, train_loader, val_loader) - >>> print(f"Final accuracy: {result['accuracy']:.2f}%") - """ - - def __init__( - self, - target_sparsity: float = 0.7, - stages: str = "full", # 'full', 'fast', 'custom' - sensitivity_based: bool = True, - use_redundancy: bool = True, - fine_tune_epochs_per_stage: int = 10, - **config, - ): - """ - Initialize ultimate pruning strategy. 
- - Args: - target_sparsity: Target overall sparsity - stages: Pruning schedule - - 'full': 4 stages (best quality) - - 'fast': 2 stages (faster) - - 'custom': Use custom stage config - sensitivity_based: Use adaptive per-layer amounts - use_redundancy: Use redundancy-aware scoring - fine_tune_epochs_per_stage: Fine-tuning between stages - """ - self.target_sparsity = target_sparsity - self.stages_mode = stages - self.sensitivity_based = sensitivity_based - self.use_redundancy = use_redundancy - self.fine_tune_epochs_per_stage = fine_tune_epochs_per_stage - - # Initialize sub-strategies - if sensitivity_based: - self.adaptive_pruner = AdaptiveSensitivityPruning(target_sparsity=target_sparsity) - - self.dependency_pruner = DependencyAwarePruning - - # Define pruning stages - self.stages = self._get_pruning_stages(stages) - - def _get_pruning_stages(self, mode: str) -> List[Dict]: - """Define pruning stages based on mode.""" - if mode == "full": - return [ - { - "name": "Initial (Magnitude)", - "target_fraction": 0.5, # 50% of final target - "metric": "magnitude", - "fine_tune_epochs": self.fine_tune_epochs_per_stage, - }, - { - "name": "Intermediate (Alignment)", - "target_fraction": 0.75, # 75% of final target - "metric": "rayleigh_quotient", - "fine_tune_epochs": self.fine_tune_epochs_per_stage, - }, - { - "name": "Refined (Composite)", - "target_fraction": 0.95, # 95% of final target - "metric": "composite", - "fine_tune_epochs": self.fine_tune_epochs_per_stage * 2, - }, - { - "name": "Cleanup", - "target_fraction": 1.0, # 100% of target - "metric": "composite", - "fine_tune_epochs": self.fine_tune_epochs_per_stage * 2, - }, - ] - - elif mode == "fast": - return [ - {"name": "Magnitude", "target_fraction": 0.7, "metric": "magnitude", "fine_tune_epochs": self.fine_tune_epochs_per_stage}, - {"name": "Composite", "target_fraction": 1.0, "metric": "composite", "fine_tune_epochs": self.fine_tune_epochs_per_stage * 2}, - ] - - else: # one-shot - return [ - {"name": 
"One-Shot Composite", "target_fraction": 1.0, "metric": "composite", "fine_tune_epochs": self.fine_tune_epochs_per_stage * 3} - ] - - def prune( - self, - model: nn.Module, - train_loader, - val_loader, - layers_to_prune: Optional[List[str]] = None, - trainer_fn: Optional[Callable] = None, - eval_fn: Optional[Callable] = None, - ) -> Dict[str, Any]: - """ - Execute ultimate pruning strategy. - - Args: - model: Model to prune - train_loader: Training data (for fine-tuning) - val_loader: Validation data (for evaluation) - layers_to_prune: Specific layers (None = auto-detect) - trainer_fn: Function(model, train_loader, epochs) for fine-tuning - eval_fn: Function(model, val_loader) -> accuracy - - Returns: - Results dictionary with masks, stats, accuracy history - """ - # Auto-detect layers if not specified - if layers_to_prune is None: - from ...core.layer_detector import detect_trackable_layers - - layers_to_prune = detect_trackable_layers(model) - - logger.info(f"Pruning {len(layers_to_prune)} layers with {self.stages_mode} strategy") - - # Baseline evaluation - if eval_fn: - baseline_acc = eval_fn(model, val_loader) - logger.info(f"Baseline accuracy: {baseline_acc:.2f}%") - else: - baseline_acc = None - - # Step 1: Sensitivity analysis (if enabled) - if self.sensitivity_based and eval_fn: - logger.info("Stage 0: Computing layer sensitivities...") - sensitivities = self.adaptive_pruner.compute_all_sensitivities(model, layers_to_prune, eval_fn=lambda m: eval_fn(m, val_loader)) - self.adaptive_pruner.print_sensitivity_report() - else: - sensitivities = None - - # Track results - results = {"baseline_accuracy": baseline_acc, "stage_results": [], "final_masks": {}, "sensitivity_report": sensitivities} - - # Execute stages - for stage_idx, stage in enumerate(self.stages): - logger.info(f"\n{'='*80}") - logger.info(f"Stage {stage_idx + 1}/{len(self.stages)}: {stage['name']}") - logger.info(f"{'='*80}") - - # Compute target amount for this stage - stage_target = 
self.target_sparsity * stage["target_fraction"] - - # Prune - stage_result = self._execute_stage(model, layers_to_prune, stage_target, stage["metric"], sensitivities) - - # Fine-tune if trainer provided - if trainer_fn and stage["fine_tune_epochs"] > 0: - logger.info(f"Fine-tuning for {stage['fine_tune_epochs']} epochs...") - trainer_fn(model, train_loader, epochs=stage["fine_tune_epochs"]) - - # Evaluate - if eval_fn: - stage_acc = eval_fn(model, val_loader) - logger.info(f"Accuracy after stage: {stage_acc:.2f}%") - stage_result["accuracy"] = stage_acc - - results["stage_results"].append(stage_result) - - # Final evaluation - if eval_fn: - final_acc = eval_fn(model, val_loader) - results["final_accuracy"] = final_acc - results["accuracy_drop"] = baseline_acc - final_acc if baseline_acc else None - - logger.info(f"\n{'='*80}") - logger.info("FINAL RESULTS") - logger.info(f"{'='*80}") - logger.info(f"Baseline: {baseline_acc:.2f}%") - logger.info(f"Final: {final_acc:.2f}%") - logger.info(f"Drop: {results['accuracy_drop']:.2f}%") - logger.info(f"Sparsity: {self.target_sparsity:.1%}") - logger.info(f"{'='*80}\n") - - return results - - def _execute_stage(self, model: nn.Module, layer_names: List[str], stage_target: float, metric_name: str, sensitivities: Optional[Dict]) -> Dict: - """Execute a single pruning stage.""" - - stage_result = {"metric": metric_name, "target": stage_target, "masks": {}} - - # Compute layer-specific amounts if adaptive - if sensitivities: - # Use adaptive amounts, scaled to stage target - layer_amounts = {name: sens.recommended_amount * (stage_target / self.target_sparsity) for name, sens in sensitivities.items()} - else: - # Uniform amount - layer_amounts = {name: stage_target for name in layer_names} - - # Compute scores and masks per layer - layer_scores = {} - - for layer_name in layer_names: - layer = dict(model.named_modules())[layer_name] - layer_amounts.get(layer_name, stage_target) - - # Compute scores based on metric - if metric_name 
== "magnitude": - scores = layer.weight.abs().flatten() - if scores.ndim > 1: - scores = scores.mean(dim=list(range(1, scores.ndim))) - - elif metric_name == "rayleigh_quotient": - # Would need inputs - skip for now or use cached - scores = layer.weight.norm(dim=1) # Fallback - - elif metric_name == "composite" and self.use_redundancy: - # Use redundancy-aware composite - # Would need full pipeline - simplified here - scores = layer.weight.norm(dim=1) - - else: - scores = layer.weight.norm(dim=1) - - layer_scores[layer_name] = scores - - # Apply with dependency awareness - pruner = self.dependency_pruner(model) - result = pruner.prune(layer_scores, amount=stage_target, dry_run=False) - - stage_result["masks"] = result["masks"] - stage_result["stats"] = result["stats"] - - return stage_result - - -def create_ultimate_pruner(target_sparsity: float = 0.7, mode: str = "full", **config) -> UltimatePruningStrategy: - """ - Factory function for creating ultimate pruning strategy. - - Args: - target_sparsity: Target overall sparsity - mode: 'full' (best quality), 'fast' (faster), 'oneshot' - **config: Additional configuration - - Returns: - Configured UltimatePruningStrategy - """ - return UltimatePruningStrategy(target_sparsity=target_sparsity, stages=mode, **config) diff --git a/src/alignment/pruning/strategies/ultra_fast.py b/src/alignment/pruning/strategies/ultra_fast.py deleted file mode 100644 index 2fb2a5e0..00000000 --- a/src/alignment/pruning/strategies/ultra_fast.py +++ /dev/null @@ -1,380 +0,0 @@ -"""Ultra-fast parallel pruning strategy for alignment experiments.""" - -import logging -import time -from typing import Any, Dict, List, Optional - -import torch -import torch.nn as nn - -from ..base import BasePruningStrategy - -logger = logging.getLogger(__name__) - - -class UltraFastParallelPruning(BasePruningStrategy): - """ - Ultra-fast parallel pruning that evaluates all configurations simultaneously. 
- - This strategy processes multiple networks and pruning amounts in a single pass, - dramatically reducing computation time compared to sequential processing. - """ - - def __init__(self, config=None): - super().__init__(config) - self.networks = None - self.data_loader = None - self.original_states = None - - def setup(self, networks: List[nn.Module], data_loader): - """Setup the pruning strategy with networks and data.""" - self.networks = networks - self.data_loader = data_loader - - # Save original states efficiently - self.original_states = [] - for model in self.networks: - state = {name: module.weight.data.clone() for name, module in model.named_modules() if hasattr(module, "weight")} - self.original_states.append(state) - - def run_pruning_experiments(self, strategies: List[str], selection_modes: List[str], pruning_amounts: List[float]) -> Dict[str, Any]: - """Run ultra-fast parallel pruning experiments.""" - logger.info("Running ULTRA-FAST PARALLEL pruning experiments") - - results = {"strategies": {}} - - # Process each strategy - for strategy_name in strategies: - logger.info(f"Testing pruning strategy: {strategy_name}") - - for selection_mode in selection_modes: - logger.info(f" Selection mode: {selection_mode}") - - # Always use ultra-parallel implementation - batch_results = self._ultra_fast_parallel_pruning(strategy_name, selection_mode, pruning_amounts) - - # Store results - strategy_key = f"{strategy_name}_{selection_mode}" - strategy_results = { - "sparsities": batch_results["sparsities"].mean(dim=0).tolist(), - "accuracies_before_finetune": batch_results["accuracies_before"].mean(dim=0).tolist(), - "accuracies_after_finetune": batch_results["accuracies_after"].mean(dim=0).tolist(), - "losses_before_finetune": batch_results["losses_before"].mean(dim=0).tolist(), - "losses_after_finetune": batch_results["losses_after"].mean(dim=0).tolist(), - "improvements": (batch_results["accuracies_after"] - 
batch_results["accuracies_before"]).mean(dim=0).tolist(), - } - - # Add standard deviations if multiple networks - if len(self.networks) > 1: - strategy_results["accuracies_before_finetune_std"] = batch_results["accuracies_before"].std(dim=0).tolist() - strategy_results["accuracies_after_finetune_std"] = batch_results["accuracies_after"].std(dim=0).tolist() - - results["strategies"][strategy_key] = strategy_results - - # Restore original weights - for net_idx, model in enumerate(self.networks): - for name, module in model.named_modules(): - if name in self.original_states[net_idx]: - module.weight.data = self.original_states[net_idx][name] - - return results - - def _ultra_fast_parallel_pruning(self, strategy_name: str, selection_mode: str, pruning_amounts: List[float]) -> Dict[str, torch.Tensor]: - """ - Ultra-fast version that processes all networks and pruning amounts truly in parallel. - Uses a single forward pass per batch for ALL configurations. - """ - num_networks = len(self.networks) - num_amounts = len(pruning_amounts) - total_configs = num_networks * num_amounts - - logger.info(f" Processing {total_configs} configurations in TRUE parallel") - logger.info(f" Networks: {num_networks}, Sparsity levels: {num_amounts}") - - # Initialize result tensors on GPU for speed - device = self.config.device if hasattr(self.config, "device") else "cuda" - accuracies_before = torch.zeros(num_networks, num_amounts, device=device) - losses_before = torch.zeros(num_networks, num_amounts, device=device) - sparsities = torch.zeros(num_networks, num_amounts) - - # Create all masks upfront - logger.info(" Creating masks for all configurations...") - all_masks = self._create_all_masks(strategy_name, selection_mode, pruning_amounts) - - # Evaluate all configurations in a single pass per batch - logger.info(" Starting TRULY PARALLEL evaluation...") - start_time = time.time() - - # Pre-allocate for all configurations - all_correct = torch.zeros(total_configs, device=device) - 
all_loss = torch.zeros(total_configs, device=device) - total_samples = 0 - - # Set all networks to eval mode - for net in self.networks: - net.eval() - - criterion = nn.CrossEntropyLoss(reduction="none") - eval_batches = getattr(self.config, "eval_batches", None) if hasattr(self.config, "eval_batches") else None - batch_count = 0 - - with torch.no_grad(): - for inputs, targets in self.data_loader: - inputs = inputs.to(device) - targets = targets.to(device) - batch_size = targets.size(0) - - # Collect outputs from ALL configurations in one go - all_outputs = [] - - for net_idx in range(num_networks): - net = self.networks[net_idx] - - for amount_idx in range(num_amounts): - # Apply configuration - self._apply_mask_config(net, all_masks[net_idx][amount_idx], self.original_states[net_idx]) - - # Forward pass - outputs = net(inputs) - all_outputs.append(outputs) - - # Calculate sparsity (only once) - if batch_count == 0: - sparsities[net_idx, amount_idx] = self._calculate_sparsity(net) - - # Stack all outputs for vectorized processing - stacked_outputs = torch.stack(all_outputs, dim=0) # [total_configs, batch_size, num_classes] - - # Compute losses for all configs at once - expanded_targets = targets.unsqueeze(0).expand(total_configs, -1) - all_batch_losses = criterion(stacked_outputs.reshape(-1, stacked_outputs.size(-1)), expanded_targets.reshape(-1)).reshape( - total_configs, batch_size - ) - - # Sum losses - all_loss += all_batch_losses.sum(dim=1) - - # Get predictions and count correct - all_preds = stacked_outputs.argmax(dim=2) # [total_configs, batch_size] - correct = all_preds.eq(expanded_targets).sum(dim=1) - all_correct += correct - - total_samples += batch_size - batch_count += 1 - - # Check if we've evaluated enough batches - if eval_batches is not None and batch_count >= eval_batches: - break - - # Reshape results back to [num_networks, num_amounts] - all_correct = all_correct.reshape(num_networks, num_amounts) - all_loss = all_loss.reshape(num_networks, 
num_amounts) - - # Convert to accuracies and average losses - accuracies_before = (all_correct * 100.0 / total_samples).cpu() - losses_before = (all_loss / batch_count).cpu() - - eval_time = time.time() - start_time - logger.info(f" Parallel evaluation completed in {eval_time:.2f} seconds") - logger.info(f" Average accuracy: {accuracies_before.mean():.2f}%") - - # For now, no fine-tuning in ultra-fast mode (can be added if needed) - accuracies_after = accuracies_before.clone() - losses_after = losses_before.clone() - - # Reset networks to train mode - for net in self.networks: - net.train() - - return { - "accuracies_before": accuracies_before, - "losses_before": losses_before, - "accuracies_after": accuracies_after, - "losses_after": losses_after, - "sparsities": sparsities, - } - - def compute_importance_scores(self, module: nn.Module, inputs: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: - """Compute importance scores for a module (not used in parallel mode).""" - # This is implemented for compatibility with BasePruningStrategy - # The actual importance computation happens in _create_all_masks - return module.weight.data.abs() - - def _create_all_masks(self, strategy_name: str, selection_mode: str, pruning_amounts: List[float]) -> List[List[Dict[str, torch.Tensor]]]: - """Create all masks for all networks and pruning amounts.""" - all_masks = [] - - for net_idx, net in enumerate(self.networks): - network_masks = [] - - for amount in pruning_amounts: - if strategy_name == "magnitude": - masks = self._create_magnitude_masks(net, amount, selection_mode) - elif strategy_name == "random": - masks = self._create_random_masks(net, amount) - elif strategy_name == "alignment": - # For alignment, we need to capture inputs first - masks = self._create_alignment_masks(net, amount, selection_mode) - else: - raise ValueError(f"Unknown strategy: {strategy_name}") - - network_masks.append(masks) - - all_masks.append(network_masks) - - return all_masks - - def 
_create_magnitude_masks(self, model: nn.Module, amount: float, selection_mode: str) -> Dict[str, torch.Tensor]: - """Create magnitude-based masks for a model.""" - masks = {} - - for name, module in model.named_modules(): - if hasattr(module, "weight"): - weight = module.weight.data - importance = weight.abs() - - structured = ( - getattr(self.config, "alignment_structured_pruning", True) if hasattr(self.config, "alignment_structured_pruning") else True - ) - - if structured and len(weight.shape) >= 2: - # Structured pruning - prune entire neurons - neuron_importance = importance.mean(dim=tuple(range(1, len(weight.shape)))) - mask = self._create_structured_mask(neuron_importance, amount, selection_mode, weight.shape) - else: - # Unstructured pruning - mask = self._create_mask(importance, amount, selection_mode) - - masks[name] = mask - - return masks - - def _create_random_masks(self, model: nn.Module, amount: float) -> Dict[str, torch.Tensor]: - """Create random masks for a model.""" - masks = {} - - for name, module in model.named_modules(): - if hasattr(module, "weight"): - weight = module.weight.data - - structured = ( - getattr(self.config, "alignment_structured_pruning", True) if hasattr(self.config, "alignment_structured_pruning") else True - ) - - if structured and len(weight.shape) >= 2: - # Structured random pruning - num_neurons = weight.shape[0] - num_to_prune = int(amount * num_neurons) - mask = torch.ones(num_neurons, device=weight.device) - if num_to_prune > 0: - indices = torch.randperm(num_neurons)[:num_to_prune] - mask[indices] = 0 - # Expand to match weight dimensions - mask = mask.unsqueeze(1).expand_as(weight) - else: - # Unstructured random pruning - mask = torch.rand_like(weight) > amount - - masks[name] = mask.float() - - return masks - - def _create_alignment_masks(self, model: nn.Module, amount: float, selection_mode: str) -> Dict[str, torch.Tensor]: - """Create alignment-based masks (simplified for speed).""" - # For ultra-fast mode, 
we'll use a simplified alignment metric - # based on weight magnitudes weighted by gradient magnitudes - # (Full alignment computation would require forward passes) - return self._create_magnitude_masks(model, amount, selection_mode) - - def _create_mask(self, importance: torch.Tensor, amount: float, selection_mode: str) -> torch.Tensor: - """Create a binary mask based on importance scores.""" - if amount == 0: - return torch.ones_like(importance) - elif amount >= 1: - return torch.zeros_like(importance) - - flat_importance = importance.flatten() - k = int(amount * flat_importance.numel()) - - if k == 0: - return torch.ones_like(importance) - - # Use topk-based selection to guarantee exactly k weights are pruned - # This avoids non-monotonic behavior caused by ties at threshold values - mask = torch.ones(flat_importance.numel(), dtype=torch.bool, device=importance.device) - - if selection_mode == "low": - # Prune k weights with LOWEST scores - _, indices = torch.topk(flat_importance, k, largest=False) - elif selection_mode == "high": - # Prune k weights with HIGHEST scores - _, indices = torch.topk(flat_importance, k, largest=True) - elif selection_mode == "random": - indices = torch.randperm(flat_importance.numel(), device=flat_importance.device)[:k] - else: - raise ValueError(f"Unknown selection mode: {selection_mode}") - - mask[indices] = False - mask = mask.view(importance.shape) - - return mask.float() - - def _create_structured_mask(self, neuron_importance: torch.Tensor, amount: float, selection_mode: str, weight_shape: torch.Size) -> torch.Tensor: - """Create a structured mask that prunes entire neurons.""" - num_neurons = neuron_importance.numel() - num_to_prune = int(amount * num_neurons) - - if num_to_prune == 0: - mask = torch.ones_like(neuron_importance) - elif num_to_prune >= num_neurons: - mask = torch.zeros_like(neuron_importance) - else: - if selection_mode == "low": - _, indices = torch.topk(neuron_importance, num_neurons - num_to_prune) - mask = 
torch.zeros_like(neuron_importance) - mask[indices] = 1 - elif selection_mode == "high": - _, indices = torch.topk(neuron_importance, num_to_prune) - mask = torch.ones_like(neuron_importance) - mask[indices] = 0 - elif selection_mode == "random": - mask = torch.ones_like(neuron_importance) - indices = torch.randperm(num_neurons)[:num_to_prune] - mask[indices] = 0 - else: - raise ValueError(f"Unknown selection mode: {selection_mode}") - - # Expand mask to match weight dimensions - if len(weight_shape) == 2: - # Linear layer: [out_features, in_features] - mask = mask.unsqueeze(1).expand_as(torch.zeros(weight_shape)) - elif len(weight_shape) == 4: - # Conv layer: [out_channels, in_channels, height, width] - mask = mask.view(-1, 1, 1, 1).expand_as(torch.zeros(weight_shape)) - - return mask - - def _apply_mask_config(self, model: nn.Module, masks: Dict[str, torch.Tensor], original_state: Dict[str, torch.Tensor]): - """Apply masks to a model after restoring original weights.""" - with torch.no_grad(): - for name, module in model.named_modules(): - if name in original_state: - # Restore original weights - module.weight.data = original_state[name].clone() - - # Apply mask if exists - if name in masks: - module.weight.data *= masks[name] - - def _calculate_sparsity(self, model: nn.Module) -> float: - """Calculate the sparsity of a model.""" - total_params = 0 - zero_params = 0 - - for module in model.modules(): - if hasattr(module, "weight"): - weight = module.weight.data - total_params += weight.numel() - zero_params += (weight == 0).sum().item() - - return zero_params / total_params if total_params > 0 else 0.0 From c9f613eac635506c6038fb608bead9633b38abbd Mon Sep 17 00:00:00 2001 From: Houman Safaai Date: Fri, 9 Jan 2026 10:42:55 -0500 Subject: [PATCH 07/12] fix bugs in llm prunning --- configs/prune_llm/llama2_7b_full.yaml | 3 + configs/prune_llm/llama3_8b_full.yaml | 4 + configs/prune_llm/mistral_7b_full.yaml | 3 + configs/prune_llm/qwen2_7b_full.yaml | 3 + 
scripts/generate_scar_paper_tables.py | 152 +++ slurm_jobs/run_paper_experiments.sh | 16 +- .../analysis/visualization/pruning_plots.py | 5 + src/alignment/experiments/llm_experiments.py | 948 +++++++++++------- src/alignment/pruning/distribution.py | 3 +- 9 files changed, 790 insertions(+), 347 deletions(-) create mode 100644 scripts/generate_scar_paper_tables.py diff --git a/configs/prune_llm/llama2_7b_full.yaml b/configs/prune_llm/llama2_7b_full.yaml index a8282c59..908789d4 100644 --- a/configs/prune_llm/llama2_7b_full.yaml +++ b/configs/prune_llm/llama2_7b_full.yaml @@ -238,6 +238,7 @@ pruning: - "supernode_connectivity_score" - "generalized_importance" - "activation_l2_norm" + - "weight_magnitude" - "wanda" - "sparsegpt" - "cross_layer_importance" @@ -245,6 +246,7 @@ pruning: scoring_methods: - "random" - "activation_l2_norm" + - "weight_magnitude" - "rayleigh_quotient" - "gaussian_mi_analytic" - "average_redundancy" @@ -366,6 +368,7 @@ visualization: - "wanda" - "sparsegpt" - "activation_l2_norm" + - "weight_magnitude" - "random" supernode_robustness: diff --git a/configs/prune_llm/llama3_8b_full.yaml b/configs/prune_llm/llama3_8b_full.yaml index b68b47b3..19dce566 100644 --- a/configs/prune_llm/llama3_8b_full.yaml +++ b/configs/prune_llm/llama3_8b_full.yaml @@ -299,6 +299,8 @@ pruning: # Magnitude baseline - "activation_l2_norm" + # Weight-only magnitude baseline (channel-group) + - "weight_magnitude" # SOTA baselines - "wanda" @@ -310,6 +312,7 @@ pruning: scoring_methods: - "random" - "activation_l2_norm" + - "weight_magnitude" - "rayleigh_quotient" - "gaussian_mi_analytic" - "average_redundancy" @@ -440,6 +443,7 @@ visualization: - "wanda" - "sparsegpt" - "activation_l2_norm" + - "weight_magnitude" - "random" # Supernode robustness plots diff --git a/configs/prune_llm/mistral_7b_full.yaml b/configs/prune_llm/mistral_7b_full.yaml index 32f75c5e..e32fe907 100644 --- a/configs/prune_llm/mistral_7b_full.yaml +++ b/configs/prune_llm/mistral_7b_full.yaml @@ 
-237,6 +237,7 @@ pruning: - "supernode_connectivity_score" - "generalized_importance" - "activation_l2_norm" + - "weight_magnitude" - "wanda" - "sparsegpt" - "cross_layer_importance" @@ -244,6 +245,7 @@ pruning: scoring_methods: - "random" - "activation_l2_norm" + - "weight_magnitude" - "rayleigh_quotient" - "gaussian_mi_analytic" - "average_redundancy" @@ -365,6 +367,7 @@ visualization: - "wanda" - "sparsegpt" - "activation_l2_norm" + - "weight_magnitude" - "random" supernode_robustness: diff --git a/configs/prune_llm/qwen2_7b_full.yaml b/configs/prune_llm/qwen2_7b_full.yaml index 7760b121..ea6e7ba8 100644 --- a/configs/prune_llm/qwen2_7b_full.yaml +++ b/configs/prune_llm/qwen2_7b_full.yaml @@ -238,6 +238,7 @@ pruning: - "supernode_connectivity_score" - "generalized_importance" - "activation_l2_norm" + - "weight_magnitude" - "wanda" - "sparsegpt" - "cross_layer_importance" @@ -245,6 +246,7 @@ pruning: scoring_methods: - "random" - "activation_l2_norm" + - "weight_magnitude" - "rayleigh_quotient" - "gaussian_mi_analytic" - "average_redundancy" @@ -366,6 +368,7 @@ visualization: - "wanda" - "sparsegpt" - "activation_l2_norm" + - "weight_magnitude" - "random" supernode_robustness: diff --git a/scripts/generate_scar_paper_tables.py b/scripts/generate_scar_paper_tables.py new file mode 100644 index 00000000..d388a92b --- /dev/null +++ b/scripts/generate_scar_paper_tables.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 +""" +Generate LaTeX tables for the SCAR ICML draft from a saved experiment results JSON. + +Why: +- Avoid manual copy/paste drift between `results_*.json` and `drafts/LLM_prune/scar_paper_icml_v5.tex`. +- Make it easy to update tables after rerunning experiments. 
+ +Usage: + python scripts/generate_scar_paper_tables.py \ + --results /abs/path/to/results_YYYYMMDD_HHMMSS.json \ + --sparsity 0.5 \ + --best-mode +""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path +from typing import Any, Dict, Optional, Tuple + + +METHODS = [ + ("Magnitude (channel)", "weight_magnitude"), + ("Wanda (channel)", "wanda"), + ("SparseGPT (channel)", "sparsegpt"), + ("Act. L2", "activation_l2_norm"), + ("RQ", "rayleigh_quotient"), + ("SCAR-LP", "scar_loss_proxy"), + ("SCAR-Prot", "supernode_protection_score"), + ("SCAR-Conn", "supernode_connectivity_score"), +] + + +def _get_pruned_entry( + pruning_results: Dict[str, Any], + metric: str, + mode: str, + sparsity: float, +) -> Optional[Dict[str, Any]]: + for v in pruning_results.values(): + if not isinstance(v, dict): + continue + if v.get("metric") == metric and v.get("mode") == mode and v.get("sparsity") == sparsity: + return v + return None + + +def _pick_mode( + pruning_results: Dict[str, Any], + metric: str, + sparsity: float, + best_mode: bool, +) -> Tuple[str, Optional[Dict[str, Any]]]: + if not best_mode: + entry = _get_pruned_entry(pruning_results, metric=metric, mode="low", sparsity=sparsity) + return "low", entry + + low = _get_pruned_entry(pruning_results, metric=metric, mode="low", sparsity=sparsity) + high = _get_pruned_entry(pruning_results, metric=metric, mode="high", sparsity=sparsity) + + def ppl(x: Optional[Dict[str, Any]]) -> float: + if not x: + return float("inf") + v = x.get("perplexity") + return float(v) if v is not None else float("inf") + + if ppl(low) <= ppl(high): + return "low", low + return "high", high + + +def _fmt(x: Any, digits: int = 1) -> str: + if x is None: + return "--" + try: + xf = float(x) + except Exception: + return "--" + if xf != xf: # NaN + return "--" + return f"{xf:.{digits}f}" + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--results", required=True, type=str, help="Path to 
results_*.json") + ap.add_argument("--sparsity", type=float, default=0.5, help="Target sparsity level (default: 0.5)") + ap.add_argument( + "--best-mode", + action="store_true", + help="Pick the better of low/high for each method by perplexity.", + ) + args = ap.parse_args() + + path = Path(args.results) + obj = json.loads(path.read_text()) + + pruning_results = obj.get("pruning_results") or {} + evaluation = obj.get("evaluation") or {} + + baseline_metrics = evaluation.get("baseline_metrics") or {} + baseline_ppl = evaluation.get("baseline_perplexity") + + # Metrics used in the draft's Table 1 + cols = [ + ("PPL$\\downarrow$", "perplexity", 1), + ("MMLU", "accuracy_mmlu", 1), + ("Hella", "accuracy_hellaswag", 1), + ("PIQA", "accuracy_piqa", 1), + ("BoolQ", "accuracy_boolq", 1), + ] + + print("% Auto-generated by scripts/generate_scar_paper_tables.py") + print("\\begin{tabular}{@{}l" + "c" * len(cols) + "@{}}") + print("\\toprule") + print("Method & " + " & ".join(h for h, _, _ in cols) + " \\\\") + print("\\midrule") + + # Unpruned row + unpruned_vals = {"perplexity": baseline_ppl, **(baseline_metrics if isinstance(baseline_metrics, dict) else {})} + print( + "Unpruned & " + + " & ".join(_fmt(unpruned_vals.get(k), d) for _, k, d in cols) + + " \\\\" + ) + print("\\midrule") + + for label, metric in METHODS: + mode, entry = _pick_mode(pruning_results, metric=metric, sparsity=args.sparsity, best_mode=args.best_mode) + if entry is None: + row = [label] + ["--"] * len(cols) + print(row[0] + " & " + " & ".join(row[1:]) + " \\\\") + continue + + vals = {k: entry.get(k) for _, k, _ in cols} + # Be robust: some evaluators store e.g. 'accuracy_mmlu' etc; keep '--' if missing. 
+ print( + f"{label} & " + + " & ".join(_fmt(vals.get(k), d) for _, k, d in cols) + + f" % ({metric}, {mode})" + + " \\\\" + ) + + print("\\bottomrule") + print("\\end{tabular}") + + +if __name__ == "__main__": + main() + diff --git a/slurm_jobs/run_paper_experiments.sh b/slurm_jobs/run_paper_experiments.sh index a13462fa..fd106cc4 100644 --- a/slurm_jobs/run_paper_experiments.sh +++ b/slurm_jobs/run_paper_experiments.sh @@ -51,8 +51,8 @@ echo "==============================================" echo "Experiment 1: LLaMA-3.1-8B (Main Results)" echo "==============================================" -python -m alignment.experiments.llm_alignment \ - --config configs/paper/llama3_8b_full.yaml \ +python scripts/run_experiment.py \ + --config configs/prune_llm/llama3_8b_full.yaml \ 2>&1 | tee logs/llama3_8b_paper.log echo "LLaMA-3.1-8B completed at $(date)" @@ -65,8 +65,8 @@ echo "==============================================" echo "Experiment 2: Mistral-7B (Generalization)" echo "==============================================" -python -m alignment.experiments.llm_alignment \ - --config configs/paper/mistral_7b_full.yaml \ +python scripts/run_experiment.py \ + --config configs/prune_llm/mistral_7b_full.yaml \ 2>&1 | tee logs/mistral_7b_paper.log echo "Mistral-7B completed at $(date)" @@ -79,8 +79,8 @@ echo "==============================================" echo "Experiment 3: LLaMA-2-7B (Generalization)" echo "==============================================" -python -m alignment.experiments.llm_alignment \ - --config configs/paper/llama2_7b_full.yaml \ +python scripts/run_experiment.py \ + --config configs/prune_llm/llama2_7b_full.yaml \ 2>&1 | tee logs/llama2_7b_paper.log echo "LLaMA-2-7B completed at $(date)" @@ -93,8 +93,8 @@ echo "==============================================" echo "Experiment 4: Qwen2-7B (Generalization)" echo "==============================================" -python -m alignment.experiments.llm_alignment \ - --config configs/paper/qwen2_7b_full.yaml \ 
+python scripts/run_experiment.py \ + --config configs/prune_llm/qwen2_7b_full.yaml \ 2>&1 | tee logs/qwen2_7b_paper.log echo "Qwen2-7B completed at $(date)" diff --git a/src/alignment/analysis/visualization/pruning_plots.py b/src/alignment/analysis/visualization/pruning_plots.py index b4effa3f..e6d804a6 100644 --- a/src/alignment/analysis/visualization/pruning_plots.py +++ b/src/alignment/analysis/visualization/pruning_plots.py @@ -745,6 +745,10 @@ def _plot_summary_stats(self, ax, results): "chip": "#16a085", # Activation-based metrics "activation_l2_norm": "#e74c3c", # Same as magnitude (they're aliases) + # Weight-only magnitude (channel-group) + "weight_magnitude": "#e74c3c", + "weight_magnitude_low": "#e74c3c", + "weight_magnitude_high": "#c0392b", "activation_mean": "#c0392b", "activation_variance": "#a93226", # Generalized importance (no outlier assumption) @@ -769,6 +773,7 @@ def _plot_summary_stats(self, ax, results): PRUNING_METHOD_MARKERS = { "random": "o", "magnitude": "s", + "weight_magnitude": "s", "taylor": "^", "gradient": "d", "composite": "p", diff --git a/src/alignment/experiments/llm_experiments.py b/src/alignment/experiments/llm_experiments.py index ab2ab57a..7e2d10ed 100644 --- a/src/alignment/experiments/llm_experiments.py +++ b/src/alignment/experiments/llm_experiments.py @@ -1924,7 +1924,10 @@ def compute_scar_supernode_metrics( activation_power_i = E[u_i^2] taylor_i = E[ | (g_u_i * u_i) | ] (first-order saliency) curvature_i = E[ (v_i^T g_y)^2 ] (Rayleigh-style curvature along v_i) - loss_proxy_i = 0.5 * activation_power_i * curvature_i + loss_proxy_i = 0.5 * E[(u_i * (v_i^T g_y))^2] (joint second moment; matches paper Eq. loss-proxy) + + Notes: + - We also compute a factored approximation (0.5 * E[u_i^2] * E[(v_i^T g_y)^2]) for diagnostics. 
""" if not getattr(self.config, "do_scar_metrics", False): logger.info("SCAR metrics disabled in config; skipping compute_scar_supernode_metrics.") @@ -1993,9 +1996,10 @@ def compute_scar_supernode_metrics( continue scar_state[layer_name] = { - "u_sqr_sum": None, # sum over tokens of u^2 - "R_sum": None, # sum over tokens of (v_i^T g_y)^2 - "T_sum": None, # sum over tokens of |g_u_i * u_i| + "u_sqr_sum": None, # sum over tokens of u^2 + "R_sum": None, # sum over tokens of (v_i^T g_y)^2 + "T_sum": None, # sum over tokens of |g_u_i * u_i| + "loss_proxy_sum": None, # sum over tokens of (u_i * (v_i^T g_y))^2 "count": 0, # number of tokens seen } @@ -2015,11 +2019,14 @@ def fwd_hook(mod: nn.Module, inputs: Tuple[torch.Tensor, ...], output: torch.Ten state = scar_state[name] m = u_flat.shape[-1] if state["u_sqr_sum"] is None: - state["u_sqr_sum"] = torch.zeros(m, device=u_flat.device, dtype=u_flat.dtype) + # Accumulate in float32 for numerical stability (bfloat16 accumulation is too lossy) + state["u_sqr_sum"] = torch.zeros(m, device=u_flat.device, dtype=torch.float32) state["R_sum"] = torch.zeros_like(state["u_sqr_sum"]) state["T_sum"] = torch.zeros_like(state["u_sqr_sum"]) + state["loss_proxy_sum"] = torch.zeros_like(state["u_sqr_sum"]) - state["u_sqr_sum"] += (u_flat * u_flat).sum(dim=0) + u_flat_f = u_flat.float() + state["u_sqr_sum"] += (u_flat_f * u_flat_f).sum(dim=0) state["count"] += u_flat.shape[0] # Store u for first-order saliency computation in backward @@ -2076,17 +2083,24 @@ def bwd_hook(mod: nn.Module, grad_input: Tuple[torch.Tensor, ...], grad_output: # Curvature: R_i = E[ (v_i^T g_y)^2 ] # s = g_y * W_down => [N_tokens, m] try: - s_flat = torch.matmul(g_y_flat, weight) # [N_tokens, m] + s_flat = torch.matmul(g_y_flat.float(), weight.float()) # [N_tokens, m] except Exception as e: logger.error(f"SCAR metrics: failed to compute W_down^T g_y for layer {name}: {e}") return - state["R_sum"] += (s_flat * s_flat).sum(dim=0) + s2 = (s_flat * s_flat).sum(dim=0) + 
state["R_sum"] += s2 # First-order Taylor saliency: E[ |g_u_i * u_i| ] - t_contrib = torch.abs(g_u_flat * u_flat).sum(dim=0) + u_flat_f = u_flat.float() + g_u_flat_f = g_u_flat.float() + t_contrib = torch.abs(g_u_flat_f * u_flat_f).sum(dim=0) state["T_sum"] += t_contrib + # Loss proxy: 0.5 * E[(u_i * (v_i^T g_y))^2] (joint moment) + q = u_flat_f * s_flat + state["loss_proxy_sum"] += (q * q).sum(dim=0) + return fwd_hook, bwd_hook fwd_hook, bwd_hook = make_hooks(layer_name) @@ -2146,13 +2160,17 @@ def bwd_hook(mod: nn.Module, grad_input: Tuple[torch.Tensor, ...], grad_output: u2_mean = state["u_sqr_sum"] / float(count) R_vals = state["R_sum"] / float(count) T_vals = state["T_sum"] / float(count) - loss_proxy = 0.5 * u2_mean * R_vals + # Exact joint estimator used by the paper definition + loss_proxy_joint = 0.5 * (state["loss_proxy_sum"] / float(count)) + # Diagnostic: separable approximation (can diverge if u^2 and (v^T g)^2 correlate) + loss_proxy_factored = 0.5 * u2_mean * R_vals scar_scores[layer_name] = { "scar_activation_power": u2_mean, "scar_taylor": T_vals, "scar_curvature": R_vals, - "scar_loss_proxy": loss_proxy, + "scar_loss_proxy": loss_proxy_joint, + "scar_loss_proxy_factored": loss_proxy_factored, } # Also attach these scores into importance_scores for later use in pruning @@ -2160,7 +2178,11 @@ def bwd_hook(mod: nn.Module, grad_input: Tuple[torch.Tensor, ...], grad_output: layer_scores["scar_activation_power"] = u2_mean layer_scores["scar_taylor"] = T_vals layer_scores["scar_curvature"] = R_vals - layer_scores["scar_loss_proxy"] = loss_proxy + layer_scores["scar_loss_proxy"] = loss_proxy_joint + layer_scores["scar_loss_proxy_factored"] = loss_proxy_factored + # Now that scar_loss_proxy exists, we can compute the configured supernode mask on this layer. + # This ensures 'protect_core' works during pruning even when score_metric='scar_loss_proxy'. 
+ self._apply_supernode_selection(layer_scores, composite=None) self.importance_scores[layer_name] = layer_scores logger.info(f"SCAR metrics: computed metrics for {len(scar_scores)} FFN layers.") @@ -2230,18 +2252,44 @@ def compute_baseline_pruning_scores( model = self.wrapped_model._model device = next(model.parameters()).device - # Track which layers to compute scores for (MLP layers from tracked_layers) - target_layers = [] - for layer_name in self.importance_scores.keys(): - # Check if it's an MLP layer - if any(mlp_pattern in layer_name for mlp_pattern in ["up_proj", "gate_proj", "down_proj", "fc", "mlp"]): - target_layers.append(layer_name) - - if not target_layers: - logger.warning("No target layers found for baseline pruning scores") + # --------------------------------------------------------------------- + # IMPORTANT: Channel/group adaptation (matches paper + structured FFN pruning) + # + # A "channel" corresponds to: + # - row i of gate_proj and up_proj (out_features = intermediate_dim) + # - column i of down_proj (in_features = intermediate_dim) + # + # So for baseline methods (Wanda/SparseGPT) we compute a *group score* per channel: + # score_i = score_gate_row_i + score_up_row_i + score_down_col_i + # and store that 1D score (length = intermediate_dim) for pruning. 
+ # --------------------------------------------------------------------- + import re + + layer_indices = set() + for k in self.importance_scores.keys(): + m = re.search(r"layers\.(\d+)\.mlp", k) + if m: + layer_indices.add(int(m.group(1))) + + if not layer_indices: + logger.warning("No MLP layers found in importance_scores; cannot compute baseline channel scores.") return {} - - logger.info(f"Computing baseline scores for {len(target_layers)} layers") + + underlying_model = self._get_underlying_model() + module_dict = dict(underlying_model.named_modules()) + + def _resolve_mlp_path(layer_idx: int) -> Optional[str]: + candidates = [ + f"model.model.layers.{layer_idx}.mlp", + f"model.layers.{layer_idx}.mlp", + f"layers.{layer_idx}.mlp", + ] + for p in candidates: + if p in module_dict: + return p + return None + + logger.info(f"Computing baseline channel scores for {len(layer_indices)} MLP layers") # Compute Wanda scores if "wanda" in strategies: @@ -2250,35 +2298,54 @@ def compute_baseline_pruning_scores( wanda = WandaPruning(num_calibration_samples=num_calibration_samples) wanda.calibrate(model, calib_dataloader, device=str(device)) - # Compute scores for each tracked layer - for layer_name in target_layers: - # Find the module - module = None - for name, mod in model.named_modules(): - if name == layer_name or layer_name.endswith(name) or name.endswith(layer_name.split('.')[-1]): - if isinstance(mod, nn.Linear): - module = mod - break - - if module is not None: - try: - # Get structured scores (per output neuron) - scores = wanda.get_structured_scores(module, layer_name=layer_name, dim=0) - - # Store in importance_scores - if layer_name not in self.importance_scores: - self.importance_scores[layer_name] = {} - self.importance_scores[layer_name]["wanda"] = scores - - if layer_name not in results: - results[layer_name] = {} - results[layer_name]["wanda"] = scores - - logger.debug(f"Wanda scores for {layer_name}: shape {scores.shape}, mean {scores.mean():.4f}") - 
except Exception as e: - logger.warning(f"Failed to compute Wanda scores for {layer_name}: {e}") + for layer_idx in sorted(layer_indices): + mlp_path = _resolve_mlp_path(layer_idx) + if mlp_path is None: + logger.warning(f"Wanda: could not resolve MLP path for layer {layer_idx}") + continue + + gate_name = f"{mlp_path}.gate_proj" + up_name = f"{mlp_path}.up_proj" + down_name = f"{mlp_path}.down_proj" + + if gate_name not in module_dict or up_name not in module_dict or down_name not in module_dict: + logger.warning(f"Wanda: missing projections for {mlp_path}") + continue + + gate = module_dict[gate_name] + up = module_dict[up_name] + down = module_dict[down_name] + + if not all(isinstance(m, nn.Linear) for m in (gate, up, down)): + logger.warning(f"Wanda: projections for {mlp_path} are not all nn.Linear; skipping") + continue + + try: + # gate/up: per output channel (rows) => dim=0 + gate_scores = wanda.get_structured_scores(gate, layer_name=gate_name, dim=0) + up_scores = wanda.get_structured_scores(up, layer_name=up_name, dim=0) + # down: per input channel (columns) => dim=1 + down_scores = wanda.get_structured_scores(down, layer_name=down_name, dim=1) + + channel_scores = (gate_scores + up_scores + down_scores).detach() + + # Store the channel-group score for pruning under all three projection names + for store_name in (gate_name, up_name, down_name): + if store_name not in self.importance_scores: + self.importance_scores[store_name] = {} + self.importance_scores[store_name]["wanda"] = channel_scores + + if store_name not in results: + results[store_name] = {} + results[store_name]["wanda"] = channel_scores + + logger.debug( + f"Wanda channel scores for {mlp_path}: shape={tuple(channel_scores.shape)}, mean={channel_scores.mean().item():.4f}" + ) + except Exception as e: + logger.warning(f"Failed to compute Wanda channel scores for {mlp_path}: {e}") - logger.info(f"Wanda: computed scores for {len([k for k in results if 'wanda' in results.get(k, {})])} layers") + 
logger.info(f"Wanda: computed channel scores for {len(layer_indices)} MLP layers") except Exception as e: logger.error(f"Wanda calibration failed: {e}") import traceback @@ -2291,41 +2358,140 @@ def compute_baseline_pruning_scores( sparsegpt = SparseGPTPruning(num_calibration_samples=num_calibration_samples) sparsegpt.calibrate(model, calib_dataloader, device=str(device)) - # Compute scores for each tracked layer - for layer_name in target_layers: - # Find the module - module = None - for name, mod in model.named_modules(): - if name == layer_name or layer_name.endswith(name) or name.endswith(layer_name.split('.')[-1]): - if isinstance(mod, nn.Linear): - module = mod - break - - if module is not None: - try: - # Get structured scores (per output neuron) - scores = sparsegpt.get_structured_scores(module, layer_name=layer_name, dim=0) - - # Store in importance_scores - if layer_name not in self.importance_scores: - self.importance_scores[layer_name] = {} - self.importance_scores[layer_name]["sparsegpt"] = scores - - if layer_name not in results: - results[layer_name] = {} - results[layer_name]["sparsegpt"] = scores - - logger.debug(f"SparseGPT scores for {layer_name}: shape {scores.shape}, mean {scores.mean():.4f}") - except Exception as e: - logger.warning(f"Failed to compute SparseGPT scores for {layer_name}: {e}") - - logger.info(f"SparseGPT: computed scores for {len([k for k in results if 'sparsegpt' in results.get(k, {})])} layers") + for layer_idx in sorted(layer_indices): + mlp_path = _resolve_mlp_path(layer_idx) + if mlp_path is None: + logger.warning(f"SparseGPT: could not resolve MLP path for layer {layer_idx}") + continue + + gate_name = f"{mlp_path}.gate_proj" + up_name = f"{mlp_path}.up_proj" + down_name = f"{mlp_path}.down_proj" + + if gate_name not in module_dict or up_name not in module_dict or down_name not in module_dict: + logger.warning(f"SparseGPT: missing projections for {mlp_path}") + continue + + gate = module_dict[gate_name] + up = 
module_dict[up_name] + down = module_dict[down_name] + + if not all(isinstance(m, nn.Linear) for m in (gate, up, down)): + logger.warning(f"SparseGPT: projections for {mlp_path} are not all nn.Linear; skipping") + continue + + try: + gate_scores = sparsegpt.get_structured_scores(gate, layer_name=gate_name, dim=0) + up_scores = sparsegpt.get_structured_scores(up, layer_name=up_name, dim=0) + down_scores = sparsegpt.get_structured_scores(down, layer_name=down_name, dim=1) + + channel_scores = (gate_scores + up_scores + down_scores).detach() + + for store_name in (gate_name, up_name, down_name): + if store_name not in self.importance_scores: + self.importance_scores[store_name] = {} + self.importance_scores[store_name]["sparsegpt"] = channel_scores + + if store_name not in results: + results[store_name] = {} + results[store_name]["sparsegpt"] = channel_scores + + logger.debug( + f"SparseGPT channel scores for {mlp_path}: shape={tuple(channel_scores.shape)}, mean={channel_scores.mean().item():.4f}" + ) + except Exception as e: + logger.warning(f"Failed to compute SparseGPT channel scores for {mlp_path}: {e}") + + logger.info(f"SparseGPT: computed channel scores for {len(layer_indices)} MLP layers") except Exception as e: logger.error(f"SparseGPT calibration failed: {e}") import traceback logger.error(traceback.format_exc()) return results + + def compute_weight_magnitude_channel_scores(self) -> Dict[str, Dict[str, torch.Tensor]]: + """ + Compute a fast, calibration-free structured *channel* baseline using weight magnitudes. + + For each MLP layer and intermediate channel i: + score_i = ||W_gate[i,:]||_2 + ||W_up[i,:]||_2 + ||W_down[:,i]||_2 + + This matches the "Magnitude (channel)" baseline described in the paper. 
+ + Returns: + Dict mapping module_name -> {"weight_magnitude": score_tensor} + """ + import re + + underlying_model = self._get_underlying_model() + module_dict = dict(underlying_model.named_modules()) + + # Identify MLP layer indices based on already-tracked layer names + layer_indices = set() + for k in self.importance_scores.keys(): + m = re.search(r"layers\.(\d+)\.mlp", k) + if m: + layer_indices.add(int(m.group(1))) + + if not layer_indices: + logger.warning("weight_magnitude: no MLP layers found in importance_scores; skipping") + return {} + + def _resolve_mlp_path(layer_idx: int) -> Optional[str]: + candidates = [ + f"model.model.layers.{layer_idx}.mlp", + f"model.layers.{layer_idx}.mlp", + f"layers.{layer_idx}.mlp", + ] + for p in candidates: + if p in module_dict: + return p + return None + + results: Dict[str, Dict[str, torch.Tensor]] = {} + + for layer_idx in sorted(layer_indices): + mlp_path = _resolve_mlp_path(layer_idx) + if mlp_path is None: + logger.warning(f"weight_magnitude: could not resolve MLP path for layer {layer_idx}") + continue + + gate_name = f"{mlp_path}.gate_proj" + up_name = f"{mlp_path}.up_proj" + down_name = f"{mlp_path}.down_proj" + + if gate_name not in module_dict or up_name not in module_dict or down_name not in module_dict: + logger.warning(f"weight_magnitude: missing projections for {mlp_path}") + continue + + gate = module_dict[gate_name] + up = module_dict[up_name] + down = module_dict[down_name] + + if not all(isinstance(m, nn.Linear) for m in (gate, up, down)): + logger.warning(f"weight_magnitude: projections for {mlp_path} are not all nn.Linear; skipping") + continue + + # gate/up: row norms (out_features = intermediate_dim) + gate_score = torch.norm(gate.weight.detach().float(), p=2, dim=1) + up_score = torch.norm(up.weight.detach().float(), p=2, dim=1) + # down: column norms (in_features = intermediate_dim) + down_score = torch.norm(down.weight.detach().float(), p=2, dim=0) + + channel_scores = (gate_score + up_score + 
down_score).detach() + + for store_name in (gate_name, up_name, down_name): + if store_name not in self.importance_scores: + self.importance_scores[store_name] = {} + self.importance_scores[store_name]["weight_magnitude"] = channel_scores + + if store_name not in results: + results[store_name] = {} + results[store_name]["weight_magnitude"] = channel_scores + + logger.info(f"Computed weight_magnitude channel scores for {len(layer_indices)} MLP layers") + return results @staticmethod def _normalize_scores_tensor(scores: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: @@ -3831,7 +3997,7 @@ def compute_group_redundancy(acts: torch.Tensor) -> Tuple[torch.Tensor, float, f return results - def compute_directed_redundancy( + def compute_halo_redundancy_within_hidden_outputs( self, scar_scores: Dict[str, Dict[str, torch.Tensor]], supernode_fraction: float = 0.01, @@ -3839,14 +4005,17 @@ def compute_directed_redundancy( num_samples: int = 8, ) -> Dict[str, Dict[str, Any]]: """ - Compute DIRECTED REDUNDANCY: redundancy among neurons in the "halo" connected to supernodes. + (Legacy/diagnostic) Compute redundancy among *hidden-dimension* output neurons that are strongly + influenced by supernodes. + + Note: This is NOT the SCAR paper's "directed redundancy" (which is defined on loss-relevant + per-channel contribution signals). This helper is kept for exploratory plots and is not used + for pruning decisions. - This is the key novel contribution for SCAR pruning: - 1. Identify supernodes (top neurons by activation power/loss proxy) - 2. Find "halo" neurons: those with large weight connections TO supernodes - 3. Compute pairwise redundancy WITHIN the halo - 4. Neurons in the halo that are highly redundant with each other are safe to prune - 5. Neurons with low redundancy (unique information) should be protected + It: + 1. Identifies intermediate-dim supernodes (top by loss proxy / activation power) + 2. 
Defines a "halo" in the *hidden* output space: hidden neurons receiving large total weight from supernodes + 3. Computes within-halo redundancy using activation correlation Args: scar_scores: SCAR metrics per layer from compute_scar_supernode_metrics @@ -4299,9 +4468,9 @@ def capture_hook(module, inputs, outputs): full_directed_redundancy[supernode_indices] = supernode_total_influence layer_scores["directed_redundancy"] = full_directed_redundancy - # Create "supernode protection score" - higher = more important, harder to prune - # All neurons get a base score based on their activation (L2 norm) - # Then supernodes get a large boost + # Diagnostic score (NOT used as SCAR-Prot): downstream influence of supernodes. + # This measures which *supernodes* strongly explain downstream hidden activations. + # We keep it under a separate key to avoid clobbering the SCAR-Prot pruning score. base_protection = torch.zeros(intermediate_dim) # Get base importance from L2 norm, RQ, or scar_loss_proxy if available @@ -4323,9 +4492,9 @@ def capture_hook(module, inputs, outputs): # Supernodes get very high protection (10x boost above max base) max_base = base_protection.max().item() if base_protection.max() > 0 else 1.0 - protection_scores = base_protection.clone() - protection_scores[supernode_indices] = max_base * 10 + supernode_total_influence - layer_scores["supernode_protection_score"] = protection_scores + downstream_influence_scores = base_protection.clone() + downstream_influence_scores[supernode_indices] = max_base * 10 + supernode_total_influence + layer_scores["supernode_downstream_influence_score"] = downstream_influence_scores self.importance_scores[layer_name] = layer_scores @@ -4362,275 +4531,348 @@ def compute_supernode_connectivity_pruning_score( plots_dir: Optional[Union[str, Path]] = None, ) -> Dict[str, Dict[str, Any]]: """ - Compute supernode-connectivity based pruning score. - - Algorithm: - 1. 
For each layer, identify supernodes (top neurons by activation power) - 2. Compute weight connections from each neuron to supernodes in next layer - 3. Partition neurons into: - - High-connectivity: Strong weights to supernodes → compute redundancy, prune redundant - - Low-connectivity: Weak weights to supernodes → low importance, can prune - 4. Create composite pruning score: - - Low connectivity neurons get low score (safe to prune) - - High connectivity neurons: score = base_importance - redundancy_penalty - - This captures the insight that neurons weakly connected to important supernodes - are likely less important, while among strongly connected neurons, the redundant - ones are safer to prune. + Compute SCAR-style halo-aware pruning scores (paper-aligned). + + This routine computes, per FFN channel i in each layer: + - **Supernodes**: top `supernode_fraction` by `scar_loss_proxy` + - **Connectivity** Conn_i: overlap of downstream write pattern |v_i| with the aggregated + supernode write pattern a = Σ_{s in supernodes} |v_s| + - **Halo**: top `high_connectivity_fraction` of non-supernodes by Conn_i + - **Loss-relevant redundancy to core** (halo only): using the scalar contribution + q_i = u_i * (v_i^T g_y), compute Gaussian MI to each supernode and take the max + - **Protection** Protect_i in [0, 1] (halo only): 1 - normalized(redundancy_to_core) + + It then produces two **importance scores** (high = keep; prune with mode="low"): + - `supernode_protection_score` (SCAR-Prot): LP_i * Protect_i (non-halo Protect=1) + - `supernode_connectivity_score` (SCAR-Conn): LP_i * ((1-Conn_i) + Conn_i * Protect_i) + + Notes: + - `redundancy_weight` is retained for backward compatibility but not used in the + paper-aligned estimator (MI already yields a redundancy scale). 
Args: scar_scores: SCAR scores dictionary with supernode metrics supernode_fraction: Fraction of neurons considered supernodes - high_connectivity_fraction: Fraction of neurons considered "high connectivity" - redundancy_weight: Weight for redundancy penalty in high-connectivity group - num_samples: Calibration samples for redundancy computation + high_connectivity_fraction: Halo fraction (fraction of non-supernodes placed in halo) + redundancy_weight: (unused) kept for backward compatibility + num_samples: Calibration samples for redundancy / protection computation plots_dir: Directory to save analysis plots Returns: Dictionary with pruning scores and analysis per layer """ - logger.info("Computing supernode-connectivity based pruning score...") - logger.info(f" Supernode fraction: {supernode_fraction*100:.1f}%") - logger.info(f" High-connectivity fraction: {high_connectivity_fraction*100:.1f}%") - + logger.info("Computing SCAR halo connectivity + protection pruning scores...") + logger.info(f" Supernode fraction (rho): {supernode_fraction*100:.1f}%") + logger.info(f" Halo fraction (eta): {high_connectivity_fraction*100:.1f}%") + + eps = 1e-8 results: Dict[str, Dict[str, Any]] = {} - - # Get underlying HF model + + # Underlying HF model for module lookup / hook registration hf_model = self.model if hasattr(hf_model, "model"): hf_model = hf_model.model - - # Get calibration texts + + module_dict = dict(hf_model.named_modules()) + + # Calibration texts calibration_texts: List[str] = [] if hasattr(self, "dataset") and hasattr(self.dataset, "texts"): calibration_texts = list(self.dataset.texts)[:num_samples] - if not calibration_texts: - logger.warning("No calibration texts available") + logger.warning("No calibration texts available for SCAR protection/connectivity computation") return {} - - # Setup plots directory - if plots_dir: - plots_dir = Path(plots_dir) - scatter_dir = plots_dir / "scatter" - scatter_dir.mkdir(parents=True, exist_ok=True) - - # Get all layer 
names with SCAR scores + + # Determine which layers to process (down_proj layers only) layer_names = [ln for ln in scar_scores.keys() if "mlp.down_proj" in ln] - - for idx, layer_name in enumerate(layer_names): - layer_metrics = scar_scores[layer_name] - - # Get supernode identification metric - # Note: Use explicit None checks to avoid tensor boolean ambiguity - supernode_metric = layer_metrics.get("scar_activation_power") - if supernode_metric is None: - supernode_metric = layer_metrics.get("scar_loss_proxy") - if supernode_metric is None: + if not layer_names: + logger.warning("No down_proj layers found in scar_scores; skipping SCAR connectivity/pruning score") + return {} + + # ------------------------------------------------------------------ + # Phase 1: Per-layer supernodes + connectivity + halo indices (weights-only) + # ------------------------------------------------------------------ + plan: Dict[str, Dict[str, Any]] = {} + for layer_name in layer_names: + layer_metrics = scar_scores.get(layer_name, {}) or {} + lp = layer_metrics.get("scar_loss_proxy") + if lp is None: + # fallbacks for older runs + lp = layer_metrics.get("scar_activation_power") + if lp is None: continue - - supernode_metric = supernode_metric.float().cpu() - num_neurons = supernode_metric.numel() # intermediate_dim - - # Identify supernodes in THIS layer - num_supernodes = max(1, int(supernode_fraction * num_neurons)) - _, sorted_indices = torch.sort(supernode_metric, descending=True) - supernode_indices = sorted_indices[:num_supernodes] - supernode_mask = torch.zeros(num_neurons, dtype=torch.bool) - supernode_mask[supernode_indices] = True - - # Get weights from this layer's neurons to NEXT layer - # down_proj: [hidden_dim, intermediate_dim] - current layer output - # next layer's up_proj/gate_proj: [intermediate_dim', hidden_dim] - receives from hidden - - # For simplicity, we use the current layer's down_proj weights - # to estimate connectivity importance (how much each neuron 
contributes to output) - down_proj_weight = None - for name, module in hf_model.named_modules(): - if name == layer_name or (name.endswith("mlp.down_proj") and name in layer_name): - if hasattr(module, "weight"): - down_proj_weight = module.weight.detach().float().cpu() - break - - if down_proj_weight is None: + + lp_cpu = lp.detach().float().cpu() + m = lp_cpu.numel() + if m == 0: continue - - hidden_dim, intermediate_dim = down_proj_weight.shape - - # Compute connectivity score: how much each neuron (column) contributes - # to outputs that go to supernodes in the NEXT layer - # For within-layer analysis: use total weight magnitude per neuron - neuron_output_magnitude = down_proj_weight.abs().sum(dim=0) # [intermediate_dim] - - # If we have next layer info, use supernode indices from next layer - # For now, use supernode influence from THIS layer - # Supernodes have high activation, so neurons with large weights TO those outputs matter - supernode_influence = down_proj_weight[:, supernode_indices].abs().sum(dim=1) # [hidden_dim] - - # For each neuron in intermediate_dim, compute its "supernode connectivity" - # = sum of |weights| to hidden dimensions that have high supernode_influence - num_high_influence = max(1, int(high_connectivity_fraction * hidden_dim)) - _, high_influence_hidden = torch.topk(supernode_influence, num_high_influence) - - # Connectivity score: how much does each intermediate neuron contribute to - # hidden dimensions that strongly connect to supernodes - connectivity_score = down_proj_weight[high_influence_hidden, :].abs().sum(dim=0) # [intermediate_dim] - - # Partition into high and low connectivity - num_high_conn = max(1, int(high_connectivity_fraction * intermediate_dim)) - _, conn_sorted = torch.sort(connectivity_score, descending=True) - high_conn_indices = conn_sorted[:num_high_conn] - low_conn_indices = conn_sorted[num_high_conn:] - - high_conn_mask = torch.zeros(intermediate_dim, dtype=torch.bool) - high_conn_mask[high_conn_indices] = 
True - - logger.info(f" {layer_name}:") - logger.info(f" {num_supernodes} supernodes, {len(high_conn_indices)} high-connectivity neurons") - - # Compute redundancy among high-connectivity neurons - # Capture activations for redundancy computation - activations: List[torch.Tensor] = [] - - def capture_hook(module, inputs, outputs): - if inputs and inputs[0] is not None: - inp = inputs[0].detach().float() - if inp.ndim == 3: - inp = inp.reshape(-1, inp.shape[-1]) - activations.append(inp.cpu()) - - hook_handle = None - for name, module in hf_model.named_modules(): - if name == layer_name or (name.endswith("mlp.down_proj") and name in layer_name): - hook_handle = module.register_forward_hook(capture_hook) - break - - if hook_handle: - self.model.eval() - with torch.no_grad(): - for text in calibration_texts: - inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=256) - inputs = {k: v.to(self.config.device) for k, v in inputs.items()} - try: - self.model(**inputs) - except Exception: - pass - hook_handle.remove() - - # Compute redundancy for high-connectivity neurons - redundancy_scores = torch.zeros(intermediate_dim) - - if activations: - all_acts = torch.cat(activations, dim=0) # [N, intermediate_dim] - high_acts = all_acts[:, high_conn_indices] # [N, num_high_conn] - - # Compute pairwise correlation among high-connectivity neurons - high_centered = high_acts - high_acts.mean(dim=0, keepdim=True) - high_cov = (high_centered.T @ high_centered) / (high_acts.shape[0] - 1) - high_std = torch.sqrt(torch.diag(high_cov) + 1e-8) - high_corr = high_cov / (high_std.unsqueeze(0) * high_std.unsqueeze(1) + 1e-8) - high_corr = torch.clamp(high_corr, -1, 1) - high_corr.fill_diagonal_(0) - - # Per-neuron redundancy = mean |correlation| with others - per_neuron_redundancy = high_corr.abs().mean(dim=1) - redundancy_scores[high_conn_indices] = per_neuron_redundancy - - logger.info(f" High-connectivity redundancy: mean={per_neuron_redundancy.mean():.4f}, 
max={per_neuron_redundancy.max():.4f}") - - # Build composite pruning score - # Low connectivity → low score (safe to prune) - # High connectivity with high redundancy → lower score (redundant, can prune) - # High connectivity with low redundancy → high score (important, protect) - - # Normalize connectivity to [0, 1] - conn_normalized = connectivity_score.clone() - if conn_normalized.max() > conn_normalized.min(): - conn_normalized = (conn_normalized - conn_normalized.min()) / (conn_normalized.max() - conn_normalized.min()) - - # Base importance from activation power - base_importance = supernode_metric.clone() - if base_importance.max() > base_importance.min(): - base_importance = (base_importance - base_importance.min()) / (base_importance.max() - base_importance.min()) - - # Composite score: - # - Start with base importance (activation power) - # - Multiply by connectivity (low connectivity = low score) - # - Subtract redundancy penalty for high-connectivity neurons - pruning_score = base_importance * (0.5 + 0.5 * conn_normalized) # [0, 1] - pruning_score[high_conn_mask] = pruning_score[high_conn_mask] - redundancy_weight * redundancy_scores[high_conn_mask] - - # Supernodes always get high protection - pruning_score[supernode_mask] = pruning_score.max() + 1.0 - - # Store in importance_scores + + module = module_dict.get(layer_name) + if module is None or not hasattr(module, "weight"): + logger.warning(f"SCAR connectivity: could not resolve module/weight for {layer_name}") + continue + + # Identify supernodes by LP + num_supernodes = max(1, int(supernode_fraction * m)) + _, super_idx = torch.topk(lp_cpu, k=num_supernodes, largest=True) + super_idx = super_idx.long() + super_mask = torch.zeros(m, dtype=torch.bool) + super_mask[super_idx] = True + + # Compute Conn_i from down_proj weights (write-pattern overlap) + W = module.weight.detach().float().cpu() # [hidden_dim, m] + abs_W = W.abs() + a = abs_W[:, super_idx].sum(dim=1) # [hidden_dim] + a_norm = a.sum() + 
eps + v_norm = abs_W.sum(dim=0) + eps # [m] + conn_num = (abs_W * a.unsqueeze(1)).sum(dim=0) # [m] + conn = (conn_num / (v_norm * a_norm + eps)).clamp(0.0, 1.0) + + # Halo: top eta among non-supernodes by Conn + non_super_idx = (~super_mask).nonzero(as_tuple=True)[0] + if non_super_idx.numel() == 0: + continue + num_halo = max(1, int(high_connectivity_fraction * non_super_idx.numel())) + halo_scores = conn[non_super_idx] + _, halo_rel = torch.topk(halo_scores, k=num_halo, largest=True) + halo_idx = non_super_idx[halo_rel].long() + + plan[layer_name] = { + "lp_cpu": lp_cpu, + "conn_cpu": conn, + "super_idx_cpu": super_idx, + "halo_idx_cpu": halo_idx, + "m": m, + # device-side indices + streaming sums (initialized lazily in hooks) + "super_idx": None, + "halo_idx": None, + "sum_q_super": None, + "sum_q2_super": None, + "sum_q_halo": None, + "sum_q2_halo": None, + "sum_q_halo_super": None, + "count": 0, + } + + if not plan: + logger.warning("SCAR connectivity: no layers eligible after filtering; skipping") + return {} + + # ------------------------------------------------------------------ + # Phase 2: Calibration passes to estimate redundancy-to-core via q=u*(v^T g_y) + # ------------------------------------------------------------------ + hooks: List[Any] = [] + + def make_hooks(name: str): + def fwd_hook(mod: nn.Module, inputs: Tuple[torch.Tensor, ...], output: torch.Tensor): + if not inputs or inputs[0] is None: + return + mod._scar_conn_last_u = inputs[0].detach() + + def bwd_hook(mod: nn.Module, grad_input: Tuple[torch.Tensor, ...], grad_output: Tuple[torch.Tensor, ...]): + st = plan.get(name) + if st is None: + return + if not grad_output or grad_output[0] is None: + return + if not hasattr(mod, "weight"): + return + if not hasattr(mod, "_scar_conn_last_u"): + return + + u = mod._scar_conn_last_u + delattr(mod, "_scar_conn_last_u") + + g_y = grad_output[0] + weight = mod.weight + + # Flatten to [N_tokens, dim] + if u.ndim > 2: + u_flat = u.reshape(-1, 
u.shape[-1]) + else: + u_flat = u.reshape(-1, u.shape[-1]) + if g_y.ndim > 2: + g_y_flat = g_y.reshape(-1, g_y.shape[-1]) + else: + g_y_flat = g_y.reshape(-1, g_y.shape[-1]) + + if u_flat.numel() == 0: + return + + # Move indices to the correct device once + if st["super_idx"] is None or st["super_idx"].device != u_flat.device: + st["super_idx"] = st["super_idx_cpu"].to(device=u_flat.device) + if st["halo_idx"] is None or st["halo_idx"].device != u_flat.device: + st["halo_idx"] = st["halo_idx_cpu"].to(device=u_flat.device) + + super_idx_dev = st["super_idx"] + halo_idx_dev = st["halo_idx"] + + # Compute q = u * (W_down^T g_y) ONLY for channels we need (supernodes + halo). + # This avoids a full [N, m] GEMM per layer. + idx_union = torch.cat([super_idx_dev, halo_idx_dev], dim=0) # [|M|+|H|] + try: + W_sel = weight.index_select(1, idx_union).float() # [hidden_dim, |M|+|H|] + s_sel = torch.matmul(g_y_flat.float(), W_sel) # [N, |M|+|H|] + u_sel = u_flat.index_select(1, idx_union).float() # [N, |M|+|H|] + except Exception: + return + + q_sel = u_sel * s_sel # [N, |M|+|H|] + n_super = super_idx_dev.numel() + q_super = q_sel[:, :n_super] # [N, |M|] + q_halo = q_sel[:, n_super:] # [N, |H|] + + N = q_sel.shape[0] + + # Initialize streaming sums on first batch + if st["sum_q_super"] is None: + st["sum_q_super"] = torch.zeros(q_super.shape[1], device=q_super.device, dtype=torch.float32) + st["sum_q2_super"] = torch.zeros_like(st["sum_q_super"]) + st["sum_q_halo"] = torch.zeros(q_halo.shape[1], device=q_halo.device, dtype=torch.float32) + st["sum_q2_halo"] = torch.zeros_like(st["sum_q_halo"]) + st["sum_q_halo_super"] = torch.zeros( + (q_halo.shape[1], q_super.shape[1]), device=q_halo.device, dtype=torch.float32 + ) + + st["sum_q_super"] += q_super.sum(dim=0) + st["sum_q2_super"] += (q_super * q_super).sum(dim=0) + st["sum_q_halo"] += q_halo.sum(dim=0) + st["sum_q2_halo"] += (q_halo * q_halo).sum(dim=0) + st["sum_q_halo_super"] += q_halo.transpose(0, 1) @ q_super # [|H|,|M|] + 
st["count"] += N + + return fwd_hook, bwd_hook + + for layer_name, module in module_dict.items(): + if layer_name not in plan: + continue + fwd, bwd = make_hooks(layer_name) + hooks.append(module.register_forward_hook(fwd)) + hooks.append(module.register_full_backward_hook(bwd)) + + # Run calibration (full forward+backward) for q-statistics + self.model.eval() + device = torch.device(self.config.device) + + # Try to use halo_analysis.max_length if present + halo_cfg = getattr(self.config, "halo_analysis", {}) or {} + if hasattr(halo_cfg, "__dict__"): + halo_cfg = vars(halo_cfg) + max_length = int(halo_cfg.get("max_length", 256)) + + try: + for idx, text in enumerate(calibration_texts): + inputs = self.tokenizer( + text, + return_tensors="pt", + truncation=True, + max_length=max_length, + ) + inputs = {k: v.to(device) for k, v in inputs.items()} + + labels = inputs["input_ids"].clone() + pad_token_id = getattr(self.tokenizer, "pad_token_id", None) or getattr(self.tokenizer, "eos_token_id", None) + labels[labels == pad_token_id] = -100 + inputs["labels"] = labels + + self.model.zero_grad(set_to_none=True) + out = self.model(**inputs) + loss = out.loss + loss.backward() + + if (idx + 1) % 1 == 0: + logger.info(f" SCAR q-stats: processed {idx+1}/{len(calibration_texts)} samples, loss={loss.item():.4f}") + finally: + for h in hooks: + try: + h.remove() + except Exception: + pass + + # ------------------------------------------------------------------ + # Phase 3: Compute Protect + final importance scores; store into importance_scores + # ------------------------------------------------------------------ + for layer_name, st in plan.items(): + N = int(st.get("count", 0)) + if N <= 1 or st["sum_q_halo_super"] is None: + logger.warning(f"SCAR connectivity: insufficient q-stats for {layer_name} (N={N}); skipping layer") + continue + + sum_q_super = st["sum_q_super"].detach().cpu() + sum_q2_super = st["sum_q2_super"].detach().cpu() + sum_q_halo = 
st["sum_q_halo"].detach().cpu() + sum_q2_halo = st["sum_q2_halo"].detach().cpu() + sum_q_halo_super = st["sum_q_halo_super"].detach().cpu() + + mean_super = sum_q_super / float(N) + mean_halo = sum_q_halo / float(N) + + cov = (sum_q_halo_super / float(N)) - (mean_halo.unsqueeze(1) * mean_super.unsqueeze(0)) + var_halo = (sum_q2_halo / float(N)) - (mean_halo * mean_halo) + var_super = (sum_q2_super / float(N)) - (mean_super * mean_super) + + denom = torch.sqrt(var_halo.clamp_min(0).unsqueeze(1) * var_super.clamp_min(0).unsqueeze(0) + eps) + corr = torch.where(denom > 0, cov / denom, torch.zeros_like(cov)) + corr = corr.clamp(-0.9999, 0.9999) + + rho_sq = (corr * corr).clamp(0.0, 0.9999) + mi = -0.5 * torch.log(1 - rho_sq) + + redundancy_to_core = mi.max(dim=1).values # [|H|] + red_min = redundancy_to_core.min() + red_max = redundancy_to_core.max() + if red_max > red_min: + red_norm = (redundancy_to_core - red_min) / (red_max - red_min + eps) + else: + red_norm = torch.zeros_like(redundancy_to_core) + protect_halo = (1.0 - red_norm).clamp(0.0, 1.0) + + m = st["m"] + lp = st["lp_cpu"].float() + conn = st["conn_cpu"].float() + super_idx = st["super_idx_cpu"] + halo_idx = st["halo_idx_cpu"] + + protect_full = torch.ones(m, dtype=torch.float32) + protect_full[halo_idx] = protect_halo + protect_full[super_idx] = 1.0 + + # SCAR-Prot and SCAR-Conn importance scores (high=keep) + prot_score = (lp * protect_full).float() + conn_score = (lp * ((1.0 - conn) + conn * protect_full)).float() + + # Explicitly protect supernodes (also enforced later by apply_pruning via supernode_mask) + prot_boost = float(prot_score.max().item()) + 1.0 + conn_boost = float(conn_score.max().item()) + 1.0 + prot_score[super_idx] = prot_boost + conn_score[super_idx] = conn_boost + + halo_mask = torch.zeros(m, dtype=torch.bool) + halo_mask[halo_idx] = True + + super_mask = torch.zeros(m, dtype=torch.bool) + super_mask[super_idx] = True + layer_scores = self.importance_scores.get(layer_name, {}) - 
layer_scores["supernode_connectivity_score"] = pruning_score - layer_scores["connectivity_score"] = connectivity_score - layer_scores["redundancy_in_high_conn"] = redundancy_scores - layer_scores["high_connectivity_mask"] = high_conn_mask + layer_scores["supernode_protection_score"] = prot_score + layer_scores["supernode_connectivity_score"] = conn_score + layer_scores["connectivity_score"] = conn + layer_scores["protection_score"] = protect_full + layer_scores["halo_mask"] = halo_mask + layer_scores["supernode_mask"] = super_mask self.importance_scores[layer_name] = layer_scores - - # Generate scatter plots - if plots_dir and activations: - import matplotlib.pyplot as plt - viz = UnifiedVisualizer() - - try: - # Scatter 1: Connectivity vs Base Importance - fig, ax = plt.subplots(figsize=(10, 8)) - colors = ['#e74c3c' if supernode_mask[i] else '#3498db' if high_conn_mask[i] else '#95a5a6' - for i in range(intermediate_dim)] - ax.scatter(conn_normalized.numpy(), base_importance.numpy(), c=colors, alpha=0.5, s=10) - ax.set_xlabel("Connectivity Score (normalized)") - ax.set_ylabel("Base Importance (activation power)") - ax.set_title(f"Connectivity vs Importance\n{layer_name}") - ax.legend(handles=[ - plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='#e74c3c', label='Supernode'), - plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='#3498db', label='High-Connectivity'), - plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='#95a5a6', label='Low-Connectivity'), - ]) - ax.grid(True, alpha=0.3) - fig.savefig(scatter_dir / f"connectivity_vs_importance_{layer_name.replace('.', '_')}.png", dpi=150, bbox_inches="tight") - plt.close(fig) - - # Scatter 2: Redundancy vs Connectivity (for high-connectivity only) - if high_conn_mask.sum() > 0: - fig, ax = plt.subplots(figsize=(10, 8)) - hc_conn = conn_normalized[high_conn_mask].numpy() - hc_red = redundancy_scores[high_conn_mask].numpy() - ax.scatter(hc_conn, hc_red, alpha=0.5, s=20, c='#3498db') - 
ax.set_xlabel("Connectivity Score") - ax.set_ylabel("Redundancy Score") - ax.set_title(f"Redundancy vs Connectivity (High-Conn Neurons)\n{layer_name}") - ax.grid(True, alpha=0.3) - fig.savefig(scatter_dir / f"redundancy_vs_connectivity_{layer_name.replace('.', '_')}.png", dpi=150, bbox_inches="tight") - plt.close(fig) - - # Scatter 3: Final Pruning Score vs Base Importance - fig, ax = plt.subplots(figsize=(10, 8)) - ax.scatter(base_importance.numpy(), pruning_score.numpy(), c=colors, alpha=0.5, s=10) - ax.set_xlabel("Base Importance (activation power)") - ax.set_ylabel("Final Pruning Score") - ax.set_title(f"Pruning Score vs Base Importance\n{layer_name}") - ax.plot([0, 1], [0, 1], 'k--', alpha=0.3, label='y=x') - ax.legend() - ax.grid(True, alpha=0.3) - fig.savefig(scatter_dir / f"pruning_score_vs_importance_{layer_name.replace('.', '_')}.png", dpi=150, bbox_inches="tight") - plt.close(fig) - - except Exception as e: - logger.warning(f" Failed to generate scatter plots: {e}") - + results[layer_name] = { - "num_supernodes": num_supernodes, - "num_high_connectivity": len(high_conn_indices), - "num_low_connectivity": len(low_conn_indices), - "mean_redundancy_high_conn": float(redundancy_scores[high_conn_mask].mean().item()) if high_conn_mask.sum() > 0 else 0, - "pruning_score_range": [float(pruning_score.min().item()), float(pruning_score.max().item())], + "num_supernodes": int(super_idx.numel()), + "num_halo": int(halo_idx.numel()), + "q_samples": N, + "conn_mean": float(conn.mean().item()), + "protect_halo_mean": float(protect_halo.mean().item()) if protect_halo.numel() else 0.0, + "redundancy_to_core_mean": float(redundancy_to_core.mean().item()) if redundancy_to_core.numel() else 0.0, } - - logger.info(f"Computed supernode-connectivity pruning score for {len(results)} layers") + + logger.info(f"Computed SCAR protection/connectivity scores for {len(results)} layers") return results def analyze_halo_vs_nonhalo_redundancy( @@ -5594,6 +5836,8 @@ def 
apply_pruning(self, sparsity: float = 0.2, metric: str = "activation_l2_norm "scar_loss_proxy", "scar_activation_power", "scar_taylor", "scar_curvature", # Supernode/connectivity metrics "directed_redundancy", "supernode_protection_score", "supernode_connectivity_score", + # Weight-only structured baseline (channel-group weight magnitude) + "weight_magnitude", # Generalized importance (no outlier assumption) "generalized_importance", "neighborhood_redundancy", # LLM baseline methods (computed by compute_baseline_pruning_scores) @@ -6357,6 +6601,15 @@ def run(self) -> Dict[str, Any]: import traceback logger.error(traceback.format_exc()) + # Fast, calibration-free channel magnitude baseline (paper: "Magnitude (channel)") + if "weight_magnitude" in pruning_strategies: + try: + self.compute_weight_magnitude_channel_scores() + except Exception as mag_err: + logger.error(f"Failed weight_magnitude score computation: {mag_err}") + import traceback + logger.error(traceback.format_exc()) + # Example: per-layer histogram with top-5 annotations # self.plot_layer_importance_histogram( # layer_name="model.layers.1.mlp.up_proj", @@ -6404,6 +6657,25 @@ def run(self) -> Dict[str, Any]: baseline_ppl = self.evaluate_perplexity(dataset=self.config.evaluation_dataset, num_samples=self.config.evaluation_num_samples) results["evaluation"]["baseline_perplexity"] = baseline_ppl + # For paper tables/plots: evaluate the unpruned model once on the full configured benchmark suite. + # (This avoids hard-coding "Unpruned" numbers in the manuscript.) 
+ try: + llm_cfg = getattr(self.config, "llm", {}) or {} + eval_metrics = llm_cfg.get("evaluation_metrics") or getattr(self.config, "evaluation_metrics", ["perplexity"]) + if isinstance(eval_metrics, str): + eval_metrics = [eval_metrics] + if eval_metrics: + baseline_eval = self.evaluate_multiple_metrics( + metrics=eval_metrics, + num_samples=self.config.evaluation_num_samples, + ) + results["evaluation"]["baseline_metrics"] = baseline_eval + # Keep baseline_perplexity in sync if evaluate_multiple_metrics produced it + if results["evaluation"].get("baseline_perplexity") is None and baseline_eval.get("perplexity") is not None: + results["evaluation"]["baseline_perplexity"] = baseline_eval.get("perplexity") + except Exception as e: + logger.warning(f"Failed baseline full-metric evaluation: {e}") + if self.config.do_pruning_experiments: sparsity_levels = self.config.pruning_amounts diff --git a/src/alignment/pruning/distribution.py b/src/alignment/pruning/distribution.py index 803d01d5..24f26aad 100644 --- a/src/alignment/pruning/distribution.py +++ b/src/alignment/pruning/distribution.py @@ -111,7 +111,8 @@ def compute_distribution( def _uniform_distribution(self, layer_names: List[str]) -> Dict[str, float]: """Same amount for all layers.""" - return {name: self.target_sparsity for name in layer_names} + amount = max(self.min_amount, min(self.max_amount, self.target_sparsity)) + return {name: amount for name in layer_names} def _global_threshold_distribution(self, layer_scores: Dict[str, torch.Tensor], model: nn.Module) -> Dict[str, float]: """ From 8851b59296f6ff03627d20032034c193ce76b072 Mon Sep 17 00:00:00 2001 From: Houman Safaai Date: Fri, 9 Jan 2026 12:52:03 -0500 Subject: [PATCH 08/12] fix some issues in cnn q-compute --- .../examples/llama3_extended_analysis.yaml | 2 +- configs/examples/vision_pruning_test.yaml | 2 +- configs/prune_llm/llama2_7b_full.yaml | 5 + configs/prune_llm/llama2_7b_unified.yaml | 3 + configs/prune_llm/llama3_8b_full.yaml | 8 + 
configs/prune_llm/llama3_8b_unified.yaml | 5 + configs/prune_llm/mistral_7b_full.yaml | 5 + configs/prune_llm/mistral_7b_unified.yaml | 3 + configs/prune_llm/qwen2_7b_full.yaml | 5 + configs/prune_llm/qwen2_7b_unified.yaml | 3 + configs/vision_prune/README.md | 2 +- docs/METRIC_CONSISTENCY.md | 2 +- scripts/collect_paper_artifacts.py | 604 +++++++++++++ scripts/run_experiment.py | 14 +- slurm_jobs/prune_llm/paper/README.md | 43 + .../paper/run_llama3_8b_calibration_array.sh | 91 ++ .../paper/run_llama3_8b_noprotect.sh | 71 ++ ...run_llama3_8b_positive_redundancy_array.sh | 80 ++ .../paper/run_llama3_8b_protect_baselines.sh | 72 ++ slurm_jobs/prune_llm/paper/submit_suite.sh | 50 ++ slurm_jobs/prune_llm/run_all_paper.sh | 10 +- slurm_jobs/prune_llm/run_llama2_7b.sh | 12 +- slurm_jobs/prune_llm/run_llama3_8b.sh | 12 +- slurm_jobs/prune_llm/run_mistral_7b.sh | 12 +- slurm_jobs/prune_llm/run_qwen2_7b.sh | 12 +- src/alignment/configs/config_loader.py | 8 + .../dataops/datasets/text_datasets.py | 88 +- src/alignment/dataops/processing/layers.py | 168 +++- src/alignment/experiments/llm_experiments.py | 796 ++++++++++++------ src/alignment/metrics/__init__.py | 2 +- src/alignment/metrics/halo_redundancy.py | 2 +- .../metrics/information/gaussian_mi.py | 2 +- .../metrics/information/synergy_continuous.py | 17 +- .../preprocessing/layer_preprocessing.py | 168 +++- src/alignment/services/activation_capture.py | 101 ++- .../metrics/test_synergy_continuous_target.py | 24 + .../unit/services/test_activation_capture.py | 123 +++ 37 files changed, 2217 insertions(+), 410 deletions(-) create mode 100644 scripts/collect_paper_artifacts.py create mode 100644 slurm_jobs/prune_llm/paper/README.md create mode 100644 slurm_jobs/prune_llm/paper/run_llama3_8b_calibration_array.sh create mode 100644 slurm_jobs/prune_llm/paper/run_llama3_8b_noprotect.sh create mode 100644 slurm_jobs/prune_llm/paper/run_llama3_8b_positive_redundancy_array.sh create mode 100644 
slurm_jobs/prune_llm/paper/run_llama3_8b_protect_baselines.sh create mode 100644 slurm_jobs/prune_llm/paper/submit_suite.sh create mode 100644 tests/unit/metrics/test_synergy_continuous_target.py create mode 100644 tests/unit/services/test_activation_capture.py diff --git a/configs/examples/llama3_extended_analysis.yaml b/configs/examples/llama3_extended_analysis.yaml index ba3fa662..dfeb64d2 100644 --- a/configs/examples/llama3_extended_analysis.yaml +++ b/configs/examples/llama3_extended_analysis.yaml @@ -8,7 +8,7 @@ # 3. Cross-layer redundancy (redundancy with previous layer) # 4. Layer transition efficiency (new information per layer) # -# Based on theoretical framework in drafts/alignment_notes/new.tex +# Based on theoretical framework in drafts/alignment_notes/alignment_red.tex # ============================================================================ experiment: diff --git a/configs/examples/vision_pruning_test.yaml b/configs/examples/vision_pruning_test.yaml index 24dfca76..4a82510c 100644 --- a/configs/examples/vision_pruning_test.yaml +++ b/configs/examples/vision_pruning_test.yaml @@ -1,7 +1,7 @@ # Vision Pruning Test (AlexNet on ImageNet) # Comprehensive metrics with pruning strategies # -# Based on alignment_notes/main.tex and vision_synergy_icml.tex: +# Based on drafts/alignment_notes/alignment_red.tex: # - RQ measures alignment with input covariance # - Gaussian MI is directly related to RQ for linear-Gaussian models # - Redundancy I(Y_i; Y_j) = -0.5 * log(1 - ρ²) measures overlap between neurons diff --git a/configs/prune_llm/llama2_7b_full.yaml b/configs/prune_llm/llama2_7b_full.yaml index 908789d4..2978ddf3 100644 --- a/configs/prune_llm/llama2_7b_full.yaml +++ b/configs/prune_llm/llama2_7b_full.yaml @@ -138,6 +138,11 @@ supernode: follower_fraction: 0.10 halo_fraction: 0.10 protect_core: true + protect_core_metrics: + - "scar_loss_proxy" # SCAR-LP + - "supernode_protection_score" # SCAR-Prot + - "supernode_connectivity_score" # SCAR-Conn + 
positive_redundancy: false cross_layer_analysis: true compare_by_connection: true diff --git a/configs/prune_llm/llama2_7b_unified.yaml b/configs/prune_llm/llama2_7b_unified.yaml index 4c14434e..9fc234f6 100644 --- a/configs/prune_llm/llama2_7b_unified.yaml +++ b/configs/prune_llm/llama2_7b_unified.yaml @@ -98,6 +98,7 @@ supernode: halo_fraction: 0.10 follower_fraction: 0.10 protect_core: true + positive_redundancy: false cross_layer_analysis: true compare_by_connection: true compute_metrics: @@ -160,6 +161,7 @@ pruning: - "cross_layer_importance" # Magnitude baseline - "magnitude" + - "weight_magnitude" # SOTA baselines - "wanda" - "sparsegpt" @@ -167,6 +169,7 @@ pruning: scoring_methods: - "random" - "magnitude" + - "weight_magnitude" - "rayleigh_quotient" - "redundancy" - "average_redundancy" diff --git a/configs/prune_llm/llama3_8b_full.yaml b/configs/prune_llm/llama3_8b_full.yaml index 19dce566..b817d2c6 100644 --- a/configs/prune_llm/llama3_8b_full.yaml +++ b/configs/prune_llm/llama3_8b_full.yaml @@ -175,6 +175,14 @@ supernode: follower_fraction: 0.10 halo_fraction: 0.10 protect_core: true + # Apply hard supernode protection only for the listed pruning metrics. + # If omitted, legacy behavior is to protect for *all* pruning metrics. 
+ protect_core_metrics: + - "scar_loss_proxy" # SCAR-LP + - "supernode_protection_score" # SCAR-Prot + - "supernode_connectivity_score" # SCAR-Conn + # If true, treat anti-correlated q-signals as NON-redundant (recommended ablation) + positive_redundancy: false cross_layer_analysis: true compare_by_connection: true diff --git a/configs/prune_llm/llama3_8b_unified.yaml b/configs/prune_llm/llama3_8b_unified.yaml index da1f91a6..acacb074 100644 --- a/configs/prune_llm/llama3_8b_unified.yaml +++ b/configs/prune_llm/llama3_8b_unified.yaml @@ -114,6 +114,8 @@ supernode: halo_fraction: 0.10 follower_fraction: 0.10 protect_core: true + # If true, treat anti-correlated q-signals as NON-redundant (recommended ablation) + positive_redundancy: false cross_layer_analysis: true compare_by_connection: true compute_metrics: @@ -183,6 +185,8 @@ pruning: # Magnitude baseline - "magnitude" # maps to activation_l2_norm + # Weight-only magnitude baseline (channel-group) + - "weight_magnitude" # SOTA baselines - "wanda" @@ -191,6 +195,7 @@ pruning: scoring_methods: - "random" - "magnitude" # activation_l2_norm + - "weight_magnitude" - "rayleigh_quotient" - "redundancy" # gaussian_mi_analytic - "average_redundancy" diff --git a/configs/prune_llm/mistral_7b_full.yaml b/configs/prune_llm/mistral_7b_full.yaml index e32fe907..06c621a6 100644 --- a/configs/prune_llm/mistral_7b_full.yaml +++ b/configs/prune_llm/mistral_7b_full.yaml @@ -137,6 +137,11 @@ supernode: follower_fraction: 0.10 halo_fraction: 0.10 protect_core: true + protect_core_metrics: + - "scar_loss_proxy" # SCAR-LP + - "supernode_protection_score" # SCAR-Prot + - "supernode_connectivity_score" # SCAR-Conn + positive_redundancy: false cross_layer_analysis: true compare_by_connection: true diff --git a/configs/prune_llm/mistral_7b_unified.yaml b/configs/prune_llm/mistral_7b_unified.yaml index 9e87ee48..d6c4ac13 100644 --- a/configs/prune_llm/mistral_7b_unified.yaml +++ b/configs/prune_llm/mistral_7b_unified.yaml @@ -97,6 +97,7 @@ 
supernode: halo_fraction: 0.10 follower_fraction: 0.10 protect_core: true + positive_redundancy: false cross_layer_analysis: true compare_by_connection: true compute_metrics: @@ -159,6 +160,7 @@ pruning: - "cross_layer_importance" # Magnitude baseline - "magnitude" + - "weight_magnitude" # SOTA baselines - "wanda" - "sparsegpt" @@ -166,6 +168,7 @@ pruning: scoring_methods: - "random" - "magnitude" + - "weight_magnitude" - "rayleigh_quotient" - "redundancy" - "average_redundancy" diff --git a/configs/prune_llm/qwen2_7b_full.yaml b/configs/prune_llm/qwen2_7b_full.yaml index ea6e7ba8..646b29ac 100644 --- a/configs/prune_llm/qwen2_7b_full.yaml +++ b/configs/prune_llm/qwen2_7b_full.yaml @@ -138,6 +138,11 @@ supernode: follower_fraction: 0.10 halo_fraction: 0.10 protect_core: true + protect_core_metrics: + - "scar_loss_proxy" # SCAR-LP + - "supernode_protection_score" # SCAR-Prot + - "supernode_connectivity_score" # SCAR-Conn + positive_redundancy: false cross_layer_analysis: true compare_by_connection: true diff --git a/configs/prune_llm/qwen2_7b_unified.yaml b/configs/prune_llm/qwen2_7b_unified.yaml index e0a3b762..430f62e5 100644 --- a/configs/prune_llm/qwen2_7b_unified.yaml +++ b/configs/prune_llm/qwen2_7b_unified.yaml @@ -98,6 +98,7 @@ supernode: halo_fraction: 0.10 follower_fraction: 0.10 protect_core: true + positive_redundancy: false cross_layer_analysis: true compare_by_connection: true compute_metrics: @@ -160,6 +161,7 @@ pruning: - "cross_layer_importance" # Magnitude baseline - "magnitude" + - "weight_magnitude" # SOTA baselines - "wanda" - "sparsegpt" @@ -167,6 +169,7 @@ pruning: scoring_methods: - "random" - "magnitude" + - "weight_magnitude" - "rayleigh_quotient" - "redundancy" - "average_redundancy" diff --git a/configs/vision_prune/README.md b/configs/vision_prune/README.md index f7159d80..cb009eae 100644 --- a/configs/vision_prune/README.md +++ b/configs/vision_prune/README.md @@ -111,5 +111,5 @@ The 4-cluster structure identifies: ## Related Papers -- 
Vision paper: `drafts/alignment_notes/vision_synergy_icml_v3.tex` +- Vision paper: `drafts/alignment_notes/alignment_red.tex` - LLM paper: `drafts/LLM_prune/scar_paper_icml_v4.tex` diff --git a/docs/METRIC_CONSISTENCY.md b/docs/METRIC_CONSISTENCY.md index 53b0e266..1e0bb1a3 100644 --- a/docs/METRIC_CONSISTENCY.md +++ b/docs/METRIC_CONSISTENCY.md @@ -1,7 +1,7 @@ # Metric Consistency with Theoretical Definitions This document verifies that the implemented metrics are consistent with the theoretical -definitions in `drafts/alignment_notes/main.tex` and `drafts/alignment_notes/new.tex`. +definitions in `drafts/alignment_notes/alignment_red.tex`. ## Summary diff --git a/scripts/collect_paper_artifacts.py b/scripts/collect_paper_artifacts.py new file mode 100644 index 00000000..a96bb643 --- /dev/null +++ b/scripts/collect_paper_artifacts.py @@ -0,0 +1,604 @@ +#!/usr/bin/env python3 +""" +Collect SCAR paper artifacts (tables + key figures) from experiment job directories. + +This script is meant to be run AFTER the SLURM batch suite finishes. + +What it does: +- Finds the latest job directory for each expected experiment name under --results-base +- Loads the corresponding results_*.json +- Generates LaTeX table snippets into drafts/LLM_prune/paper_artifacts/tables/ +- Copies a small set of "paper figure" images into drafts/LLM_prune/ as the placeholder_*.png files + used by `drafts/LLM_prune/scar_paper_icml_v5.tex` (so the paper auto-fills without manual edits). 
+ +Example: + python scripts/collect_paper_artifacts.py \ + --results-base /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM \ + --draft-dir /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment/drafts/LLM_prune +""" + +from __future__ import annotations + +import argparse +import json +import shutil +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Tuple + + +# ----------------------------- +# Helpers +# ----------------------------- + + +def _latest_results_json(job_dir: Path) -> Optional[Path]: + candidates: List[Path] = [] + if (job_dir / "results").exists(): + candidates.extend(sorted((job_dir / "results").glob("results_*.json"))) + candidates.extend(sorted(job_dir.glob("results_*.json"))) + if not candidates: + return None + candidates.sort(key=lambda p: p.stat().st_mtime, reverse=True) + return candidates[0] + + +def find_latest_job_dir(results_base: Path, experiment_name: str) -> Optional[Path]: + """ + Job directories are created as: + {experiment_name}_{timestamp}_{job_id}/ + """ + if not results_base.exists(): + return None + candidates = [p for p in results_base.iterdir() if p.is_dir() and p.name.startswith(f"{experiment_name}_")] + candidates.sort(key=lambda p: p.stat().st_mtime, reverse=True) + for job_dir in candidates: + if _latest_results_json(job_dir) is not None: + return job_dir + return None + + +def load_results(job_dir: Path) -> Dict[str, Any]: + path = _latest_results_json(job_dir) + if path is None: + raise FileNotFoundError(f"No results_*.json found in {job_dir} or {job_dir/'results'}") + return json.loads(path.read_text()) + + +def _get_pruned_entry( + pruning_results: Dict[str, Any], + metric: str, + mode: str, + sparsity: float, +) -> Optional[Dict[str, Any]]: + for v in pruning_results.values(): + if not isinstance(v, dict): + continue + if v.get("metric") == metric and v.get("mode") == mode and float(v.get("sparsity", -1)) == float(sparsity): + return v + return None + + +def _pick_mode( + 
pruning_results: Dict[str, Any], + metric: str, + sparsity: float, + mode: str, +) -> Tuple[str, Optional[Dict[str, Any]]]: + """ + mode: + - "low" or "high": choose that mode + - "best": choose the better of low/high by perplexity + """ + if mode in {"low", "high"}: + return mode, _get_pruned_entry(pruning_results, metric=metric, mode=mode, sparsity=sparsity) + + low = _get_pruned_entry(pruning_results, metric=metric, mode="low", sparsity=sparsity) + high = _get_pruned_entry(pruning_results, metric=metric, mode="high", sparsity=sparsity) + + def ppl(x: Optional[Dict[str, Any]]) -> float: + if not x: + return float("inf") + v = x.get("perplexity") + return float(v) if v is not None else float("inf") + + if ppl(low) <= ppl(high): + return "low", low + return "high", high + + +def _fmt(x: Any, digits: int = 1) -> str: + if x is None: + return "--" + try: + xf = float(x) + except Exception: + return "--" + if xf != xf: # NaN + return "--" + return f"{xf:.{digits}f}" + + +def write_text(path: Path, content: str) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(content) + + +def latex_tabular(rows: List[List[str]], align_spec: str) -> str: + lines = [] + lines.append("% Auto-generated by scripts/collect_paper_artifacts.py") + lines.append(f"\\begin{{tabular}}{{{align_spec}}}") + lines.append("\\toprule") + for i, r in enumerate(rows): + lines.append(" & ".join(r) + " \\\\") + if i == 0: + lines.append("\\midrule") + lines.append("\\bottomrule") + lines.append("\\end{tabular}") + lines.append("") + return "\n".join(lines) + + +def safe_copy(src: Path, dst: Path) -> bool: + if not src.exists(): + return False + dst.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(src, dst) + return True + + +# ----------------------------- +# Table generators +# ----------------------------- + + +def make_table_main_50( + llama3_main: Dict[str, Any], + llama3_noprotect: Optional[Dict[str, Any]], + llama3_protect_baselines: Optional[Dict[str, Any]], + 
out_path: Path, + sparsity: float = 0.5, + mode: str = "best", +) -> None: + pruning_main = llama3_main.get("pruning_results") or {} + evaluation_main = llama3_main.get("evaluation") or {} + + # Columns used in the draft Table 1 + cols: List[Tuple[str, str, int]] = [ + ("PPL$\\downarrow$", "perplexity", 1), + ("MMLU", "accuracy_mmlu", 1), + ("Hella", "accuracy_hellaswag", 1), + ("PIQA", "accuracy_piqa", 1), + ("BoolQ", "accuracy_boolq", 1), + ] + + rows: List[List[str]] = [] + rows.append(["Method"] + [h for h, _, _ in cols]) + + # Unpruned row + baseline_ppl = evaluation_main.get("baseline_perplexity") + baseline_metrics = evaluation_main.get("baseline_metrics") or {} + unpruned_vals = {"perplexity": baseline_ppl, **(baseline_metrics if isinstance(baseline_metrics, dict) else {})} + rows.append(["Unpruned"] + [_fmt(unpruned_vals.get(k), d) for _, k, d in cols]) + + # Main rows from the "main" run + methods_main: List[Tuple[str, str]] = [ + ("Magnitude (channel)", "weight_magnitude"), + ("Wanda (channel)", "wanda"), + ("SparseGPT (channel)", "sparsegpt"), + ("Act. 
L2", "activation_l2_norm"), + ("RQ", "rayleigh_quotient"), + ("SCAR-LP", "scar_loss_proxy"), + ("SCAR-Prot", "supernode_protection_score"), + ("SCAR-Conn", "supernode_connectivity_score"), + ] + + for label, metric in methods_main: + picked_mode, entry = _pick_mode(pruning_main, metric=metric, sparsity=sparsity, mode=mode) + if entry is None: + rows.append([label] + ["--"] * len(cols)) + continue + rows.append([label] + [_fmt(entry.get(k), d) for _, k, d in cols]) + + # Optional extra rows (if those experiments are present) + if llama3_noprotect is not None: + pr = llama3_noprotect.get("pruning_results") or {} + picked_mode, entry = _pick_mode(pr, metric="scar_loss_proxy", sparsity=sparsity, mode="low") + if entry is not None: + rows.append(["LP-no-protect"] + [_fmt(entry.get(k), d) for _, k, d in cols]) + + if llama3_protect_baselines is not None: + pr = llama3_protect_baselines.get("pruning_results") or {} + for label, metric in [("Protect+Magnitude", "weight_magnitude"), ("Protect+Wanda", "wanda")]: + picked_mode, entry = _pick_mode(pr, metric=metric, sparsity=sparsity, mode="low") + if entry is not None: + rows.append([label] + [_fmt(entry.get(k), d) for _, k, d in cols]) + + content = latex_tabular(rows, align_spec="@{}l" + "c" * len(cols) + "@{}") + write_text(out_path, content) + + +def make_table_sparsity_tradeoff( + llama3_main: Dict[str, Any], + out_path: Path, + sparsities: List[float], + mode: str = "low", +) -> None: + pruning = llama3_main.get("pruning_results") or {} + baseline_ppl = (llama3_main.get("evaluation") or {}).get("baseline_perplexity") + + methods: List[Tuple[str, str]] = [ + ("Wanda (channel)", "wanda"), + ("SparseGPT (channel)", "sparsegpt"), + ("SCAR-LP", "scar_loss_proxy"), + ("SCAR-Prot", "supernode_protection_score"), + ("SCAR-Conn", "supernode_connectivity_score"), + ] + + header = ["Method"] + [f"{int(100*s)}\\%" for s in sparsities] + rows: List[List[str]] = [header] + if baseline_ppl is not None: + rows.append(["Baseline"] + 
[_fmt(baseline_ppl, 1)] * len(sparsities)) + + for label, metric in methods: + row = [label] + for s in sparsities: + picked_mode, entry = _pick_mode(pruning, metric=metric, sparsity=s, mode=mode) + row.append(_fmt(entry.get("perplexity") if entry else None, 1)) + rows.append(row) + + content = latex_tabular(rows, align_spec="@{}l" + "c" * len(sparsities) + "@{}") + write_text(out_path, content) + + +def make_table_generalization( + model_results: List[Tuple[str, Dict[str, Any]]], + out_path: Path, + sparsity: float = 0.5, + mode: str = "best", +) -> None: + rows: List[List[str]] = [] + rows.append(["Model", "Method", "PPL$\\downarrow$", "MMLU", "Avg.$\\uparrow$"]) + + for model_label, res in model_results: + pr = res.get("pruning_results") or {} + # Wanda baseline + _, wanda = _pick_mode(pr, metric="wanda", sparsity=sparsity, mode=mode) + _, scar = _pick_mode(pr, metric="supernode_connectivity_score", sparsity=sparsity, mode=mode) + + def avg_acc(entry: Optional[Dict[str, Any]]) -> Optional[float]: + if not entry: + return None + accs = [float(v) for k, v in entry.items() if isinstance(k, str) and k.startswith("accuracy_") and v is not None] + return sum(accs) / len(accs) if accs else None + + for method_label, entry in [("Wanda (channel)", wanda), ("SCAR-Conn", scar)]: + rows.append( + [ + model_label, + method_label, + _fmt(entry.get("perplexity") if entry else None, 1), + _fmt(entry.get("accuracy_mmlu") if entry else None, 1), + _fmt(avg_acc(entry), 1), + ] + ) + + content = latex_tabular(rows, align_spec="@{}llccc@{}") + write_text(out_path, content) + + +def make_table_halo_redundancy( + llama3_main: Dict[str, Any], + out_path: Path, +) -> None: + halo = llama3_main.get("halo_analysis") or {} + agg = halo.get("aggregate") or {} + + # The current analysis uses |corr|; we expose it directly for the draft table. 
+ rows: List[List[str]] = [] + rows.append(["Group Pair", "Mean", "Std"]) + + def get(group: str, key: str) -> Any: + return (agg.get(group) or {}).get(key) + + rows.append(["Within-Halo", _fmt(get("halo_halo", "mean"), 3), _fmt(get("halo_halo", "std"), 3)]) + rows.append(["Within-Non-Halo", _fmt(get("non_halo", "mean"), 3), _fmt(get("non_halo", "std"), 3)]) + rows.append(["Cross (Halo $\\leftrightarrow$ Non-Halo)", _fmt(get("cross", "mean"), 3), _fmt(get("cross", "std"), 3)]) + + content = latex_tabular(rows, align_spec="@{}lcc@{}") + write_text(out_path, content) + + +def make_table_full_benchmarks_50( + llama3_main: Dict[str, Any], + out_path: Path, + sparsity: float = 0.5, + mode: str = "best", +) -> None: + """ + Appendix table: a wider benchmark set at a single sparsity. + """ + pruning = llama3_main.get("pruning_results") or {} + evaluation = llama3_main.get("evaluation") or {} + + cols: List[Tuple[str, str, int]] = [ + ("PPL$\\downarrow$", "perplexity", 1), + ("MMLU", "accuracy_mmlu", 1), + ("Hella", "accuracy_hellaswag", 1), + ("PIQA", "accuracy_piqa", 1), + ("BoolQ", "accuracy_boolq", 1), + ("WinoG", "accuracy_winogrande", 1), + ("ARC-E", "accuracy_arc_easy", 1), + ("ARC-C", "accuracy_arc_challenge", 1), + ("OBQA", "accuracy_openbookqa", 1), + ] + + rows: List[List[str]] = [] + rows.append(["Method"] + [h for h, _, _ in cols]) + + baseline_ppl = evaluation.get("baseline_perplexity") + baseline_metrics = evaluation.get("baseline_metrics") or {} + unpruned_vals = {"perplexity": baseline_ppl, **(baseline_metrics if isinstance(baseline_metrics, dict) else {})} + rows.append(["Unpruned"] + [_fmt(unpruned_vals.get(k), d) for _, k, d in cols]) + + methods: List[Tuple[str, str]] = [ + ("Random", "random"), + ("Magnitude (channel)", "weight_magnitude"), + ("Wanda (channel)", "wanda"), + ("SparseGPT (channel)", "sparsegpt"), + ("Act. 
L2", "activation_l2_norm"), + ("RQ", "rayleigh_quotient"), + ("Gaussian MI (analytic)", "gaussian_mi_analytic"), + ("SCAR-LP", "scar_loss_proxy"), + ("SCAR-Prot", "supernode_protection_score"), + ("SCAR-Conn", "supernode_connectivity_score"), + ] + + for label, metric in methods: + picked_mode, entry = _pick_mode(pruning, metric=metric, sparsity=sparsity, mode=mode if metric != "random" else "low") + if entry is None: + rows.append([label] + ["--"] * len(cols)) + continue + rows.append([label] + [_fmt(entry.get(k), d) for _, k, d in cols]) + + content = latex_tabular(rows, align_spec="@{}l" + "c" * len(cols) + "@{}") + write_text(out_path, content) + + +def make_table_supernode_control( + llama3_main: Dict[str, Any], + llama3_noprotect: Optional[Dict[str, Any]], + out_path: Path, + sparsity: float = 0.5, +) -> None: + rows: List[List[str]] = [] + rows.append(["Strategy", "PPL", "Relative"]) + + pr_main = llama3_main.get("pruning_results") or {} + _, scar = _pick_mode(pr_main, metric="supernode_connectivity_score", sparsity=sparsity, mode="low") + _, wanda = _pick_mode(pr_main, metric="wanda", sparsity=sparsity, mode="low") + + remove_core_early_ppl = None + if llama3_noprotect is not None: + pr = llama3_noprotect.get("pruning_results") or {} + high = _get_pruned_entry(pr, metric="scar_loss_proxy", mode="high", sparsity=sparsity) + if high is not None: + remove_core_early_ppl = high.get("perplexity") + + rows.append(["Remove supernodes early", _fmt(remove_core_early_ppl, 1), "$\\gg$ worse"]) + rows.append(["Wanda (channel)", _fmt(wanda.get("perplexity") if wanda else None, 1), "worse"]) + rows.append(["SCAR-Conn (protect supernodes)", _fmt(scar.get("perplexity") if scar else None, 1), "best"]) + + content = latex_tabular(rows, align_spec="@{}lcc@{}") + write_text(out_path, content) + + +def make_table_calibration_sensitivity( + results_base: Path, + prefix: str, + out_path: Path, + sparsity: float = 0.5, +) -> None: + """ + Collect all runs whose experiment name 
starts with `prefix` and build a calibration sensitivity table. + Intended to work with the job-array naming convention from: + slurm_jobs/prune_llm/paper/run_llama3_8b_calibration_array.sh + """ + if not results_base.exists(): + return + + job_dirs = [p for p in results_base.iterdir() if p.is_dir() and p.name.startswith(f"{prefix}_")] + if not job_dirs: + return + + # Keep the most recent run per (dataset_name, n_samples) + best: Dict[Tuple[str, int], Tuple[float, Dict[str, Any]]] = {} + + for job_dir in sorted(job_dirs, key=lambda p: p.stat().st_mtime, reverse=True): + try: + res = load_results(job_dir) + except Exception: + continue + + cfg = res.get("config") or {} + dataset_name = str(cfg.get("dataset_name", "unknown")) + n = cfg.get("alignment_data_num_samples") + try: + n_int = int(n) + except Exception: + continue + + pr = res.get("pruning_results") or {} + entry = _get_pruned_entry(pr, metric="supernode_connectivity_score", mode="low", sparsity=sparsity) + ppl = None if entry is None else entry.get("perplexity") + try: + ppl_f = float(ppl) if ppl is not None else float("inf") + except Exception: + ppl_f = float("inf") + + key = (dataset_name, n_int) + if key not in best: + best[key] = (job_dir.stat().st_mtime, {"dataset": dataset_name, "n": n_int, "ppl": ppl_f}) + + if not best: + return + + pretty_name = { + "wikitext": "WikiText-2", + "c4": "C4", + "mixed_wikitext_c4": "Mixed (Wiki + C4)", + "mixed_wiki_c4": "Mixed (Wiki + C4)", + "mixed": "Mixed (Wiki + C4)", + } + + rows: List[List[str]] = [] + rows.append(["Dataset", "\\# seqs", "PPL"]) + + # Sort rows for readability: wikitext first, then c4, then others; within by n desc. 
+ def sort_key(item: Tuple[Tuple[str, int], Tuple[float, Dict[str, Any]]]) -> Tuple[int, str, int]: + (ds, n), _ = item + group = 0 if ds == "wikitext" else 1 if ds == "c4" else 2 + return (group, ds, -n) + + for (_, _), (_, rec) in sorted(best.items(), key=sort_key): + ds = str(rec["dataset"]) + rows.append([pretty_name.get(ds, ds), str(rec["n"]), _fmt(rec["ppl"], 1)]) + + content = latex_tabular(rows, align_spec="@{}llc@{}") + write_text(out_path, content) + + +# ----------------------------- +# Main +# ----------------------------- + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--results-base", required=True, type=str, help="Base directory containing job dirs (OUTPUT_BASE)") + ap.add_argument( + "--draft-dir", + required=True, + type=str, + help="Path to drafts/LLM_prune (where placeholder_*.png live)", + ) + ap.add_argument( + "--artifacts-dir", + type=str, + default=None, + help="Where to write paper_artifacts/ (default: {draft-dir}/paper_artifacts)", + ) + ap.add_argument("--sparsity", type=float, default=0.5, help="Primary sparsity level for tables (default: 0.5)") + ap.add_argument("--mode", type=str, default="best", choices=["low", "high", "best"], help="Mode selection for tables") + + # Canonical experiment names (these are the ExperimentConfig.name values) + ap.add_argument("--llama3-main-name", type=str, default="llama3_8b_paper_results") + ap.add_argument("--mistral-main-name", type=str, default="mistral_7b_paper_results") + ap.add_argument("--llama2-main-name", type=str, default="llama2_7b_paper_results") + ap.add_argument("--qwen2-main-name", type=str, default="qwen2_7b_paper_results") + ap.add_argument("--llama3-noprotect-name", type=str, default="llama3_8b_paper_results_noprotect") + ap.add_argument("--llama3-protect-baselines-name", type=str, default="llama3_8b_paper_results_protect_baselines") + ap.add_argument("--llama3-calib-prefix", type=str, default="llama3_8b_paper_results_calib") + + args = ap.parse_args() + + 
results_base = Path(args.results_base) + draft_dir = Path(args.draft_dir) + artifacts_dir = Path(args.artifacts_dir) if args.artifacts_dir else (draft_dir / "paper_artifacts") + tables_dir = artifacts_dir / "tables" + + # Locate + load main runs + llama3_main_dir = find_latest_job_dir(results_base, args.llama3_main_name) + if llama3_main_dir is None: + raise FileNotFoundError(f"Could not find a run for '{args.llama3_main_name}' under {results_base}") + llama3_main = load_results(llama3_main_dir) + + mistral_dir = find_latest_job_dir(results_base, args.mistral_main_name) + llama2_dir = find_latest_job_dir(results_base, args.llama2_main_name) + qwen2_dir = find_latest_job_dir(results_base, args.qwen2_main_name) + + mistral = load_results(mistral_dir) if mistral_dir else None + llama2 = load_results(llama2_dir) if llama2_dir else None + qwen2 = load_results(qwen2_dir) if qwen2_dir else None + + # Optional control runs + llama3_noprotect_dir = find_latest_job_dir(results_base, args.llama3_noprotect_name) + llama3_protect_baselines_dir = find_latest_job_dir(results_base, args.llama3_protect_baselines_name) + llama3_noprotect = load_results(llama3_noprotect_dir) if llama3_noprotect_dir else None + llama3_protect_baselines = load_results(llama3_protect_baselines_dir) if llama3_protect_baselines_dir else None + + # Generate tables + make_table_main_50( + llama3_main=llama3_main, + llama3_noprotect=llama3_noprotect, + llama3_protect_baselines=llama3_protect_baselines, + out_path=tables_dir / "table_main_50.tex", + sparsity=float(args.sparsity), + mode=args.mode, + ) + make_table_sparsity_tradeoff( + llama3_main=llama3_main, + out_path=tables_dir / "table_sparsity_ppl.tex", + sparsities=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + mode="low", + ) + make_table_supernode_control( + llama3_main=llama3_main, + llama3_noprotect=llama3_noprotect, + out_path=tables_dir / "table_supernode_control.tex", + sparsity=float(args.sparsity), + ) + make_table_halo_redundancy( + 
llama3_main=llama3_main, + out_path=tables_dir / "table_halo_redundancy.tex", + ) + make_table_full_benchmarks_50( + llama3_main=llama3_main, + out_path=tables_dir / "table_full_benchmarks_50.tex", + sparsity=float(args.sparsity), + mode=args.mode, + ) + make_table_calibration_sensitivity( + results_base=results_base, + prefix=args.llama3_calib_prefix, + out_path=tables_dir / "table_calibration_sensitivity.tex", + sparsity=float(args.sparsity), + ) + + model_rows: List[Tuple[str, Dict[str, Any]]] = [] + if mistral is not None: + model_rows.append(("Mistral-7B", mistral)) + if llama2 is not None: + model_rows.append(("Llama-2-7B", llama2)) + if qwen2 is not None: + model_rows.append(("Qwen2-7B", qwen2)) + if model_rows: + make_table_generalization(model_rows, out_path=tables_dir / "table_generalization_50.tex", sparsity=float(args.sparsity)) + + # Copy key figures into the draft placeholders (no LaTeX edits needed) + figs = Path(llama3_main_dir) / "figures" + mapping = [ + (figs / "pruning" / "pruning_comparison.png", draft_dir / "placeholder_sparsity_curves.png"), + (figs / "histograms" / "histogram_scar_loss_proxy.png", draft_dir / "placeholder_supernode_distribution.png"), + (figs / "supernode_summary" / "halo_nonhalo_metrics_by_layer.png", draft_dir / "placeholder_halo_redundancy.png"), + (figs / "supernode_summary" / "supernode_outlier_zscores.png", draft_dir / "placeholder_supernode_analysis.png"), + ] + + copied = 0 + for src, dst in mapping: + if safe_copy(src, dst): + copied += 1 + + print("\n=== Paper artifacts collected ===") + print(f"Main run: {llama3_main_dir}") + print(f"Artifacts dir: {artifacts_dir}") + print(f"Tables written: {tables_dir}") + print(f"Placeholder figures copied into draft: {copied}/{len(mapping)}") + if llama3_noprotect_dir: + print(f"Found noprotect control: {llama3_noprotect_dir}") + if llama3_protect_baselines_dir: + print(f"Found protect-baselines: {llama3_protect_baselines_dir}") + + +if __name__ == "__main__": + main() + diff 
--git a/scripts/run_experiment.py b/scripts/run_experiment.py index 70de33b5..8b2afe5e 100644 --- a/scripts/run_experiment.py +++ b/scripts/run_experiment.py @@ -644,14 +644,12 @@ def main(): if args.seed: overrides["seed"] = args.seed - # Load config - from alignment.configs.config_loader import load_config as proper_load_config - config = proper_load_config(args.config) - - # Apply overrides - for key, value in overrides.items(): - if hasattr(config, key): - setattr(config, key, value) + # Load config (support key=value overrides passed after args) + # Example: + # python scripts/run_experiment.py --config ... name="llama3_8b_paper_main" supernode.protect_core=false + from alignment.configs.config_loader import load_config_with_overrides as proper_load_config + cli_overrides = [x for x in (unknown or []) if isinstance(x, str) and "=" in x] + config = proper_load_config(args.config, overrides=overrides or None, cli_args=cli_overrides or None) # Override base_output_dir if provided via CLI if args.base_output_dir: diff --git a/slurm_jobs/prune_llm/paper/README.md b/slurm_jobs/prune_llm/paper/README.md new file mode 100644 index 00000000..e9a9d89e --- /dev/null +++ b/slurm_jobs/prune_llm/paper/README.md @@ -0,0 +1,43 @@ +### SCAR paper experiment suite (batch + collection) + +This folder contains **SLURM batch scripts** that run a complete ICML-style paper suite: + +- **Main results + generalization** (4 models) +- **Key controls / ablations** on Llama-3.1-8B: + - **LP-no-protect** + **remove-supernodes-early** (mode=high) control + - **Protect+Wanda** and **Protect+Magnitude** (baseline + supernode protection) + - **Positive-only redundancy** ablation (anti-correlation does NOT count as redundancy) + - **Calibration sensitivity** sweep (dataset + sample-count) + +All jobs write to a single `OUTPUT_BASE` using the unified job directory structure: + +`{OUTPUT_BASE}/{experiment_name}_{timestamp}_{job_id}/` + +### How to run + +- **Set output base** (or let scripts use 
the default in each file): + +```bash +export OUTPUT_BASE="/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM" +``` + +- **Submit the full suite**: + +```bash +bash slurm_jobs/prune_llm/paper/submit_suite.sh +``` + +### How to collect artifacts (tables + placeholder figures) + +After jobs finish: + +```bash +python scripts/collect_paper_artifacts.py \ + --results-base "$OUTPUT_BASE" \ + --draft-dir /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment/drafts/LLM_prune +``` + +This will: +- write LaTeX snippets to `drafts/LLM_prune/paper_artifacts/tables/*.tex` +- copy key plots into `drafts/LLM_prune/placeholder_*.png` so the draft auto-fills. + diff --git a/slurm_jobs/prune_llm/paper/run_llama3_8b_calibration_array.sh b/slurm_jobs/prune_llm/paper/run_llama3_8b_calibration_array.sh new file mode 100644 index 00000000..d2807c35 --- /dev/null +++ b/slurm_jobs/prune_llm/paper/run_llama3_8b_calibration_array.sh @@ -0,0 +1,91 @@ +#!/bin/bash +#SBATCH --job-name=paper_llama3_calib +#SBATCH --output=logs/paper_llama3_calib_%A_%a.out +#SBATCH --error=logs/paper_llama3_calib_%A_%a.err +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:4 +#SBATCH --cpus-per-task=16 +#SBATCH --time=06:00:00 +#SBATCH --mem=320GB +#SBATCH --partition=kempner_eng +#SBATCH --account=kempner_dev +#SBATCH --array=0-4 + +# ---------------------------------------------------------------------------- +# LLaMA-3.1-8B SWEEP: calibration sensitivity for SCAR-Conn @ 50% sparsity +# +# Task mapping: +# 0: wikitext, n=128 +# 1: wikitext, n=64 +# 2: wikitext, n=32 +# 3: c4, n=128 +# 4: mixed_wikitext_c4, n=128 +# +# Notes: +# - We restrict pruning to SCAR-Conn at 50% and evaluate perplexity only (fast). 
+# ---------------------------------------------------------------------------- + +set -euo pipefail + +DATASETS=("wikitext" "wikitext" "wikitext" "c4" "mixed_wikitext_c4") +NSAMPLES=(128 64 32 128 128) +TAGS=("wikitext_128" "wikitext_64" "wikitext_32" "c4_128" "mixed_128") + +IDX="${SLURM_ARRAY_TASK_ID}" +DATASET="${DATASETS[$IDX]}" +N="${NSAMPLES[$IDX]}" +TAG="${TAGS[$IDX]}" + +echo "============================================================================" +echo "SCAR Paper Sweep: LLaMA-3.1-8B calibration sensitivity (${TAG})" +echo "============================================================================" +echo "Job ID: ${SLURM_JOB_ID} Array Task: ${SLURM_ARRAY_TASK_ID}" +echo "Node: $(hostname)" +echo "Start time: $(date)" +echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" + +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" +echo "Output Base: $OUTPUT_BASE" +echo "Calibration dataset: ${DATASET}" +echo "Calibration samples: ${N}" +echo "" + +module purge +module load cuda/12.2.0-fasrc01 +eval "$(conda shell.bash hook)" +conda activate networkAlignmentAnalysis + +cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment +mkdir -p logs + +export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK +export TOKENIZERS_PARALLELISM=false +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +export HF_HOME=/n/home13/hsafaai/.cache/huggingface +export HF_TOKEN=$(cat /n/home13/hsafaai/.cache/huggingface/token) + +python scripts/run_experiment.py \ + --config configs/prune_llm/llama3_8b_full.yaml \ + --device cuda \ + --base-output-dir "$OUTPUT_BASE" \ + name="llama3_8b_paper_results_calib_${TAG}" \ + generate_plots=false \ + dataset_name="${DATASET}" \ + alignment_data_num_samples="${N}" \ + scar_num_samples="${N}" \ + pruning_strategies="['supernode_connectivity_score']" \ + pruning_amounts="[0.5]" \ + pruning_selection_mode="['low']" \ + "llm.evaluation_metrics=['perplexity']" \ + do_directed_redundancy=false 
\ + do_halo_analysis=false \ + do_generalized_importance=false \ + supernode_robustness.enabled=false \ + supernode_summary.enabled=false + +echo "" +echo "============================================================================" +echo "LLaMA-3.1-8B calibration sweep (${TAG}) completed at $(date)" +echo "============================================================================" + diff --git a/slurm_jobs/prune_llm/paper/run_llama3_8b_noprotect.sh b/slurm_jobs/prune_llm/paper/run_llama3_8b_noprotect.sh new file mode 100644 index 00000000..3263a5ab --- /dev/null +++ b/slurm_jobs/prune_llm/paper/run_llama3_8b_noprotect.sh @@ -0,0 +1,71 @@ +#!/bin/bash +#SBATCH --job-name=paper_llama3_noprotect +#SBATCH --output=logs/paper_llama3_noprotect_%j.out +#SBATCH --error=logs/paper_llama3_noprotect_%j.err +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:4 +#SBATCH --cpus-per-task=16 +#SBATCH --time=06:00:00 +#SBATCH --mem=320GB +#SBATCH --partition=kempner_eng +#SBATCH --account=kempner_dev + +# ---------------------------------------------------------------------------- +# LLaMA-3.1-8B CONTROL: LP-no-protect + "remove supernodes early" (mode=high) +# +# Produces (at 50%): +# - LP-no-protect: metric=scar_loss_proxy, mode=low, protect_core=false +# - Remove-core-early metric=scar_loss_proxy, mode=high, protect_core=false +# ---------------------------------------------------------------------------- + +set -euo pipefail + +echo "============================================================================" +echo "SCAR Paper Control: LLaMA-3.1-8B (no-protect LP control)" +echo "============================================================================" +echo "Job ID: $SLURM_JOB_ID" +echo "Node: $(hostname)" +echo "Start time: $(date)" +echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" + +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" +echo "Output Base: $OUTPUT_BASE" +echo "" + +module 
purge +module load cuda/12.2.0-fasrc01 +eval "$(conda shell.bash hook)" +conda activate networkAlignmentAnalysis + +cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment +mkdir -p logs + +export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK +export TOKENIZERS_PARALLELISM=false +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +export HF_HOME=/n/home13/hsafaai/.cache/huggingface +export HF_TOKEN=$(cat /n/home13/hsafaai/.cache/huggingface/token) + +python scripts/run_experiment.py \ + --config configs/prune_llm/llama3_8b_full.yaml \ + --device cuda \ + --base-output-dir "$OUTPUT_BASE" \ + name="llama3_8b_paper_results_noprotect" \ + generate_plots=false \ + supernode.protect_core=false \ + pruning_strategies="['scar_loss_proxy']" \ + pruning_amounts="[0.5]" \ + pruning_selection_mode="['low','high']" \ + do_connectivity_pruning=false \ + do_directed_redundancy=false \ + do_halo_analysis=false \ + do_generalized_importance=false \ + supernode_robustness.enabled=false \ + supernode_summary.enabled=false + +echo "" +echo "============================================================================" +echo "LLaMA-3.1-8B no-protect control completed at $(date)" +echo "============================================================================" + diff --git a/slurm_jobs/prune_llm/paper/run_llama3_8b_positive_redundancy_array.sh b/slurm_jobs/prune_llm/paper/run_llama3_8b_positive_redundancy_array.sh new file mode 100644 index 00000000..240edc85 --- /dev/null +++ b/slurm_jobs/prune_llm/paper/run_llama3_8b_positive_redundancy_array.sh @@ -0,0 +1,80 @@ +#!/bin/bash +#SBATCH --job-name=paper_llama3_posred +#SBATCH --output=logs/paper_llama3_posred_%A_%a.out +#SBATCH --error=logs/paper_llama3_posred_%A_%a.err +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:4 +#SBATCH --cpus-per-task=16 +#SBATCH --time=06:00:00 +#SBATCH --mem=320GB +#SBATCH --partition=kempner_eng +#SBATCH --account=kempner_dev +#SBATCH --array=0-1 + +# 
---------------------------------------------------------------------------- +# LLaMA-3.1-8B ABLATION: positive-only redundancy vs rho^2 redundancy +# +# Task 0: positive_redundancy=false (rho^2 counts anti-correlation as redundancy) +# Task 1: positive_redundancy=true (rho^+ only; anti-correlation NOT redundant) +# ---------------------------------------------------------------------------- + +set -euo pipefail + +if [ "${SLURM_ARRAY_TASK_ID}" -eq 0 ]; then + POS_RED="false" + TAG="rho2" +else + POS_RED="true" + TAG="posonly" +fi + +echo "============================================================================" +echo "SCAR Paper Ablation: LLaMA-3.1-8B (positive redundancy = ${POS_RED})" +echo "============================================================================" +echo "Job ID: ${SLURM_JOB_ID} Array Task: ${SLURM_ARRAY_TASK_ID}" +echo "Node: $(hostname)" +echo "Start time: $(date)" +echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" + +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" +echo "Output Base: $OUTPUT_BASE" +echo "" + +module purge +module load cuda/12.2.0-fasrc01 +eval "$(conda shell.bash hook)" +conda activate networkAlignmentAnalysis + +cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment +mkdir -p logs + +export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK +export TOKENIZERS_PARALLELISM=false +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +export HF_HOME=/n/home13/hsafaai/.cache/huggingface +export HF_TOKEN=$(cat /n/home13/hsafaai/.cache/huggingface/token) + +python scripts/run_experiment.py \ + --config configs/prune_llm/llama3_8b_full.yaml \ + --device cuda \ + --base-output-dir "$OUTPUT_BASE" \ + name="llama3_8b_paper_results_posred_${TAG}" \ + generate_plots=false \ + supernode.positive_redundancy="${POS_RED}" \ + supernode.protect_core=true \ + "supernode.protect_core_metrics=['supernode_connectivity_score']" \ + pruning_strategies="['supernode_connectivity_score']" 
\ + pruning_amounts="[0.5]" \ + pruning_selection_mode="['low']" \ + do_directed_redundancy=false \ + do_halo_analysis=false \ + do_generalized_importance=false \ + supernode_robustness.enabled=false \ + supernode_summary.enabled=false + +echo "" +echo "============================================================================" +echo "LLaMA-3.1-8B pos-redundancy ablation (${TAG}) completed at $(date)" +echo "============================================================================" + diff --git a/slurm_jobs/prune_llm/paper/run_llama3_8b_protect_baselines.sh b/slurm_jobs/prune_llm/paper/run_llama3_8b_protect_baselines.sh new file mode 100644 index 00000000..eafcd13f --- /dev/null +++ b/slurm_jobs/prune_llm/paper/run_llama3_8b_protect_baselines.sh @@ -0,0 +1,72 @@ +#!/bin/bash +#SBATCH --job-name=paper_llama3_protect_base +#SBATCH --output=logs/paper_llama3_protect_base_%j.out +#SBATCH --error=logs/paper_llama3_protect_base_%j.err +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:4 +#SBATCH --cpus-per-task=16 +#SBATCH --time=08:00:00 +#SBATCH --mem=320GB +#SBATCH --partition=kempner_eng +#SBATCH --account=kempner_dev + +# ---------------------------------------------------------------------------- +# LLaMA-3.1-8B CONTROL: Protect+Baseline variants +# +# Produces (at 50%): +# - Protect+Wanda: metric=wanda, protect_core_metrics includes wanda +# - Protect+Magnitude: metric=weight_magnitude, protect_core_metrics includes weight_magnitude +# ---------------------------------------------------------------------------- + +set -euo pipefail + +echo "============================================================================" +echo "SCAR Paper Control: LLaMA-3.1-8B (protect baselines)" +echo "============================================================================" +echo "Job ID: $SLURM_JOB_ID" +echo "Node: $(hostname)" +echo "Start time: $(date)" +echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" + 
+OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" +echo "Output Base: $OUTPUT_BASE" +echo "" + +module purge +module load cuda/12.2.0-fasrc01 +eval "$(conda shell.bash hook)" +conda activate networkAlignmentAnalysis + +cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment +mkdir -p logs + +export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK +export TOKENIZERS_PARALLELISM=false +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +export HF_HOME=/n/home13/hsafaai/.cache/huggingface +export HF_TOKEN=$(cat /n/home13/hsafaai/.cache/huggingface/token) + +python scripts/run_experiment.py \ + --config configs/prune_llm/llama3_8b_full.yaml \ + --device cuda \ + --base-output-dir "$OUTPUT_BASE" \ + name="llama3_8b_paper_results_protect_baselines" \ + generate_plots=false \ + supernode.protect_core=true \ + "supernode.protect_core_metrics=['wanda','weight_magnitude']" \ + pruning_strategies="['wanda','weight_magnitude']" \ + pruning_amounts="[0.5]" \ + pruning_selection_mode="['low']" \ + do_connectivity_pruning=false \ + do_directed_redundancy=false \ + do_halo_analysis=false \ + do_generalized_importance=false \ + supernode_robustness.enabled=false \ + supernode_summary.enabled=false + +echo "" +echo "============================================================================" +echo "LLaMA-3.1-8B protect-baselines completed at $(date)" +echo "============================================================================" + diff --git a/slurm_jobs/prune_llm/paper/submit_suite.sh b/slurm_jobs/prune_llm/paper/submit_suite.sh new file mode 100644 index 00000000..bc6c29aa --- /dev/null +++ b/slurm_jobs/prune_llm/paper/submit_suite.sh @@ -0,0 +1,50 @@ +#!/bin/bash +# ============================================================================ +# SUBMIT FULL SCAR PAPER SUITE (main + controls/ablations) +# ============================================================================ +# Usage: +# cd 
/n/holylabs/kempner_dev/Users/hsafaai/Code/alignment +# bash slurm_jobs/prune_llm/paper/submit_suite.sh +# +# Output: +# Uses OUTPUT_BASE (exported or defaulted below). +# ============================================================================ + +set -euo pipefail + +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" + +echo "==============================================" +echo "Submitting SCAR Paper Suite" +echo "==============================================" +echo "OUTPUT_BASE: $OUTPUT_BASE" +echo "" + +cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment +mkdir -p logs + +echo "---- Main results + generalization (4 models) ----" +export OUTPUT_BASE +bash slurm_jobs/prune_llm/run_all_paper.sh +echo "" + +echo "---- Controls / ablations (Llama-3.1-8B) ----" +JOB_NP=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/prune_llm/paper/run_llama3_8b_noprotect.sh | awk '{print $4}') +echo " noprotect/control: $JOB_NP" + +JOB_PB=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/prune_llm/paper/run_llama3_8b_protect_baselines.sh | awk '{print $4}') +echo " protect-baselines: $JOB_PB" + +JOB_POSRED=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/prune_llm/paper/run_llama3_8b_positive_redundancy_array.sh | awk '{print $4}') +echo " pos-redundancy array: $JOB_POSRED" + +JOB_CALIB=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/prune_llm/paper/run_llama3_8b_calibration_array.sh | awk '{print $4}') +echo " calibration array: $JOB_CALIB" + +echo "" +echo "==============================================" +echo "All suite jobs submitted" +echo "==============================================" +echo "Monitor with: squeue -u \$USER" +echo "" + diff --git a/slurm_jobs/prune_llm/run_all_paper.sh b/slurm_jobs/prune_llm/run_all_paper.sh index 36edaa53..e6e8f443 100755 --- a/slurm_jobs/prune_llm/run_all_paper.sh +++ b/slurm_jobs/prune_llm/run_all_paper.sh @@ -19,7 +19,7 @@ # bash 
slurm_jobs/prune_llm/run_all_paper.sh # ============================================================================ -OUTPUT_BASE="/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM" +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" echo "==============================================" echo "Submitting SCAR Paper Experiments" @@ -32,19 +32,19 @@ cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment # Submit all jobs echo "Submitting LLaMA-3.1-8B (main results)..." -JOB1=$(sbatch slurm_jobs/prune_llm/run_llama3_8b.sh | awk '{print $4}') +JOB1=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/prune_llm/run_llama3_8b.sh | awk '{print $4}') echo " Job ID: $JOB1" echo "Submitting Mistral-7B (generalization)..." -JOB2=$(sbatch slurm_jobs/prune_llm/run_mistral_7b.sh | awk '{print $4}') +JOB2=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/prune_llm/run_mistral_7b.sh | awk '{print $4}') echo " Job ID: $JOB2" echo "Submitting LLaMA-2-7B (generalization)..." -JOB3=$(sbatch slurm_jobs/prune_llm/run_llama2_7b.sh | awk '{print $4}') +JOB3=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/prune_llm/run_llama2_7b.sh | awk '{print $4}') echo " Job ID: $JOB3" echo "Submitting Qwen2-7B (generalization)..." 
-JOB4=$(sbatch slurm_jobs/prune_llm/run_qwen2_7b.sh | awk '{print $4}') +JOB4=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/prune_llm/run_qwen2_7b.sh | awk '{print $4}') echo " Job ID: $JOB4" echo "" diff --git a/slurm_jobs/prune_llm/run_llama2_7b.sh b/slurm_jobs/prune_llm/run_llama2_7b.sh index 5e28a8f1..c9d36de8 100755 --- a/slurm_jobs/prune_llm/run_llama2_7b.sh +++ b/slurm_jobs/prune_llm/run_llama2_7b.sh @@ -34,7 +34,8 @@ echo "Job ID: $SLURM_JOB_ID" echo "Node: $(hostname)" echo "Start time: $(date)" echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" -echo "Output Base: /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM" +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" +echo "Output Base: $OUTPUT_BASE" echo "" # Environment setup @@ -58,16 +59,15 @@ echo "" echo "Running LLaMA-2-7B full paper analysis..." echo "" -# The config has base_dir set, so outputs go to: -# /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/{name}_{timestamp}_{job_id}/ python scripts/run_experiment.py \ - --config configs/prune_llm/llama2_7b_unified.yaml \ - --device cuda + --config configs/prune_llm/llama2_7b_full.yaml \ + --device cuda \ + --base-output-dir "$OUTPUT_BASE" echo "" echo "============================================================================" echo "LLaMA-2-7B completed at $(date)" echo "============================================================================" echo "" -echo "Results saved to: /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/" +echo "Results saved to: $OUTPUT_BASE/" echo "Look for directory: llama2_7b_paper_results_*_$SLURM_JOB_ID" diff --git a/slurm_jobs/prune_llm/run_llama3_8b.sh b/slurm_jobs/prune_llm/run_llama3_8b.sh index a4a83c96..b7fa3c22 100755 --- a/slurm_jobs/prune_llm/run_llama3_8b.sh +++ b/slurm_jobs/prune_llm/run_llama3_8b.sh @@ -41,7 +41,8 @@ echo "Job ID: $SLURM_JOB_ID" echo "Node: $(hostname)" echo "Start time: $(date)" 
echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" -echo "Output Base: /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM" +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" +echo "Output Base: $OUTPUT_BASE" echo "" # Environment setup @@ -65,16 +66,15 @@ echo "" echo "Running LLaMA-3.1-8B full paper analysis..." echo "" -# The config has base_dir set, so outputs go to: -# /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/{name}_{timestamp}_{job_id}/ python scripts/run_experiment.py \ - --config configs/prune_llm/llama3_8b_unified.yaml \ - --device cuda + --config configs/prune_llm/llama3_8b_full.yaml \ + --device cuda \ + --base-output-dir "$OUTPUT_BASE" echo "" echo "============================================================================" echo "LLaMA-3.1-8B completed at $(date)" echo "============================================================================" echo "" -echo "Results saved to: /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/" +echo "Results saved to: $OUTPUT_BASE/" echo "Look for directory: llama3_8b_paper_results_*_$SLURM_JOB_ID" diff --git a/slurm_jobs/prune_llm/run_mistral_7b.sh b/slurm_jobs/prune_llm/run_mistral_7b.sh index 91efd866..c4f04b7e 100755 --- a/slurm_jobs/prune_llm/run_mistral_7b.sh +++ b/slurm_jobs/prune_llm/run_mistral_7b.sh @@ -34,7 +34,8 @@ echo "Job ID: $SLURM_JOB_ID" echo "Node: $(hostname)" echo "Start time: $(date)" echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" -echo "Output Base: /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM" +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" +echo "Output Base: $OUTPUT_BASE" echo "" # Environment setup @@ -58,16 +59,15 @@ echo "" echo "Running Mistral-7B full paper analysis..." 
echo "" -# The config has base_dir set, so outputs go to: -# /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/{name}_{timestamp}_{job_id}/ python scripts/run_experiment.py \ - --config configs/prune_llm/mistral_7b_unified.yaml \ - --device cuda + --config configs/prune_llm/mistral_7b_full.yaml \ + --device cuda \ + --base-output-dir "$OUTPUT_BASE" echo "" echo "============================================================================" echo "Mistral-7B completed at $(date)" echo "============================================================================" echo "" -echo "Results saved to: /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/" +echo "Results saved to: $OUTPUT_BASE/" echo "Look for directory: mistral_7b_paper_results_*_$SLURM_JOB_ID" diff --git a/slurm_jobs/prune_llm/run_qwen2_7b.sh b/slurm_jobs/prune_llm/run_qwen2_7b.sh index 31780718..a81d62b9 100755 --- a/slurm_jobs/prune_llm/run_qwen2_7b.sh +++ b/slurm_jobs/prune_llm/run_qwen2_7b.sh @@ -35,7 +35,8 @@ echo "Job ID: $SLURM_JOB_ID" echo "Node: $(hostname)" echo "Start time: $(date)" echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" -echo "Output Base: /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM" +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" +echo "Output Base: $OUTPUT_BASE" echo "" # Environment setup @@ -59,16 +60,15 @@ echo "" echo "Running Qwen2-7B full paper analysis..." 
echo "" -# The config has base_dir set, so outputs go to: -# /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/{name}_{timestamp}_{job_id}/ python scripts/run_experiment.py \ - --config configs/prune_llm/qwen2_7b_unified.yaml \ - --device cuda + --config configs/prune_llm/qwen2_7b_full.yaml \ + --device cuda \ + --base-output-dir "$OUTPUT_BASE" echo "" echo "============================================================================" echo "Qwen2-7B completed at $(date)" echo "============================================================================" echo "" -echo "Results saved to: /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/" +echo "Results saved to: $OUTPUT_BASE/" echo "Look for directory: qwen2_7b_paper_results_*_$SLURM_JOB_ID" diff --git a/src/alignment/configs/config_loader.py b/src/alignment/configs/config_loader.py index cdf724ec..4590a908 100644 --- a/src/alignment/configs/config_loader.py +++ b/src/alignment/configs/config_loader.py @@ -1206,6 +1206,14 @@ def load_config_with_overrides( key, value = arg.split("=", 1) # Convert value to appropriate type try: + # Common CLI convenience: YAML-style booleans/nulls + raw = value.strip() + low = raw.lower() + if low in {"true", "false"}: + value = (low == "true") + elif low in {"none", "null"}: + value = None + else: value = eval(value) except Exception: pass # Keep as string diff --git a/src/alignment/dataops/datasets/text_datasets.py b/src/alignment/dataops/datasets/text_datasets.py index feed16fb..02e2adf7 100644 --- a/src/alignment/dataops/datasets/text_datasets.py +++ b/src/alignment/dataops/datasets/text_datasets.py @@ -116,7 +116,21 @@ class C4Dataset(IterableDataset): """ def __init__(self, tokenizer: Any, split: str = "validation", max_length: int = 512, max_samples: Optional[int] = None): - self.tokenizer = tokenizer + from transformers import AutoTokenizer, PreTrainedTokenizerBase + + # Accept either a tokenizer object or a model ID string + if isinstance(tokenizer, 
PreTrainedTokenizerBase): + hf_tokenizer = tokenizer + elif isinstance(tokenizer, str): + hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer) + else: + raise TypeError(f"tokenizer must be a string or PreTrainedTokenizerBase, got {type(tokenizer)}") + + # If no pad token exists, set it to the eos token (common for causal LM) + if hf_tokenizer.pad_token is None: + hf_tokenizer.pad_token = hf_tokenizer.eos_token + + self.tokenizer = hf_tokenizer self.max_length = max_length self.max_samples = max_samples @@ -125,14 +139,33 @@ def __init__(self, tokenizer: Any, split: str = "validation", max_length: int = logger.info(f"Loading C4 dataset ({split}, streaming)") self.dataset = load_dataset("allenai/c4", "en", split=split, streaming=True) + # For LLM calibration/analysis we often need raw texts (e.g., to build a reusable + # calibration set). If max_samples is specified, materialize that many texts so + # downstream code can rely on a `.texts` attribute (like WikiText). + self.texts: Optional[List[str]] = None + if self.max_samples is not None: + texts: List[str] = [] + for item in self.dataset: + text = item.get("text") + if not text or len(text.strip()) == 0: + continue + texts.append(text) + if len(texts) >= self.max_samples: + break + self.texts = texts + logger.info(f"Materialized {len(self.texts)} C4 texts for split='{split}'") + def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: count = 0 - for item in self.dataset: + # If we materialized texts, iterate over them for deterministic reuse. 
+ iterable = self.texts if self.texts is not None else self.dataset + + for item in iterable: if self.max_samples and count >= self.max_samples: break - text = item["text"] + text = item if isinstance(item, str) else item.get("text") if not text or len(text.strip()) == 0: continue @@ -170,6 +203,51 @@ def load_text_dataset( elif dataset_name == "c4": return C4Dataset(tokenizer, split, max_length, max_samples) + elif dataset_name in {"mixed_wikitext_c4", "mixed_wiki_c4", "mixed"}: + # Lightweight "mixed" calibration set: combine WikiText + C4 raw texts. + # This is useful for robustness/sensitivity experiments. + # + # Supported kwargs: + # - wikitext_name: which WikiText subset to use (default: wikitext-2-raw-v1) + # - wikitext_fraction: fraction of samples drawn from WikiText when max_samples is set (default: 0.5) + from transformers import AutoTokenizer, PreTrainedTokenizerBase + + if isinstance(tokenizer, PreTrainedTokenizerBase): + hf_tokenizer = tokenizer + elif isinstance(tokenizer, str): + hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer) + else: + raise TypeError(f"tokenizer must be a string or PreTrainedTokenizerBase, got {type(tokenizer)}") + + if hf_tokenizer.pad_token is None: + hf_tokenizer.pad_token = hf_tokenizer.eos_token + + wikitext_name = kwargs.get("wikitext_name", "wikitext-2-raw-v1") + wikitext_fraction = float(kwargs.get("wikitext_fraction", 0.5)) + + if max_samples is None: + # Default to a small mixed set if caller didn't specify a budget. + max_samples = 512 + + n_wiki = int(round(max_samples * wikitext_fraction)) + n_c4 = max_samples - n_wiki + + wiki_ds = WikiTextDataset(hf_tokenizer, split=split, max_length=max_length, dataset_name=wikitext_name) + wiki_texts = list(getattr(wiki_ds, "texts", []))[:n_wiki] + + c4_ds = C4Dataset(hf_tokenizer, split=split, max_length=max_length, max_samples=n_c4) + c4_texts = list(getattr(c4_ds, "texts", []))[:n_c4] + + # Interleave for better mixing. 
+ mixed_texts: List[str] = [] + for i in range(max(len(wiki_texts), len(c4_texts))): + if i < len(wiki_texts): + mixed_texts.append(wiki_texts[i]) + if i < len(c4_texts): + mixed_texts.append(c4_texts[i]) + + return TextDataset(mixed_texts, hf_tokenizer, max_length=max_length) + elif dataset_name == "ptb": from datasets import load_dataset @@ -181,7 +259,9 @@ def load_text_dataset( return TextDataset(texts, tokenizer, max_length) else: - raise ValueError(f"Unknown dataset: {dataset_name}. " f"Supported: wikitext, c4, ptb") + raise ValueError( + f"Unknown dataset: {dataset_name}. Supported: wikitext, c4, ptb, mixed_wikitext_c4" + ) # Register datasets in alignment registry if needed diff --git a/src/alignment/dataops/processing/layers.py b/src/alignment/dataops/processing/layers.py index a1ed207d..925e1643 100644 --- a/src/alignment/dataops/processing/layers.py +++ b/src/alignment/dataops/processing/layers.py @@ -115,18 +115,51 @@ def _unfold_mode(self, activation: torch.Tensor, layer: nn.Module, is_input: boo Returns: Tensor of shape [batch_size * num_patches, features] """ - b, c, h, w = activation.shape - - if is_input and isinstance(layer, nn.Conv2d): - # For inputs, unfold based on the layer's kernel parameters - unfold_params = self._get_unfold_params(layer) - unfolded = torch.nn.functional.unfold(activation, kernel_size=layer.kernel_size, **unfold_params) - # [b, features*kernel_size, num_patches] -> [b*num_patches, features] - unfolded = unfolded.transpose(1, 2).contiguous() - return unfolded.view(-1, unfolded.size(2)) - else: - # For outputs or non-conv layers, just flatten spatial dims - return activation.reshape(b, c, -1).permute(0, 2, 1).reshape(-1, c) + if isinstance(layer, nn.Conv2d): + if activation.ndim != 4: + raise ValueError(f"Expected 4D tensor for Conv2d, got {activation.ndim}D") + + b, c, h, w = activation.shape + + if is_input: + # Unfold based on the layer's kernel parameters so feature dimension matches weight flattening + unfold_params = 
self._get_unfold_params(layer) + unfolded = torch.nn.functional.unfold(activation, kernel_size=layer.kernel_size, **unfold_params) + # [b, features, num_patches] -> [b*num_patches, features] + unfolded = unfolded.transpose(1, 2).contiguous() + return unfolded.view(-1, unfolded.size(2)) + + # Output: treat each spatial location as a sample (node = output channel) + # [b, c, h, w] -> [b*h*w, c] + return activation.permute(0, 2, 3, 1).reshape(-1, c) + + if isinstance(layer, nn.Conv1d): + if activation.ndim != 3: + raise ValueError(f"Expected 3D tensor for Conv1d, got {activation.ndim}D") + + b, c, l = activation.shape + + if is_input: + # Use 2D unfold trick on [b, c, 1, l] to respect stride/padding/dilation + x4 = activation.unsqueeze(2) # [b, c, 1, l] + k = layer.kernel_size[0] if isinstance(layer.kernel_size, tuple) else layer.kernel_size + s = layer.stride[0] if isinstance(layer.stride, tuple) else layer.stride + p = layer.padding[0] if isinstance(layer.padding, tuple) else layer.padding + d = layer.dilation[0] if isinstance(layer.dilation, tuple) else layer.dilation + unfolded = torch.nn.functional.unfold( + x4, + kernel_size=(1, k), + dilation=(1, d), + padding=(0, p), + stride=(1, s), + ) # [b, c*k, num_patches] + unfolded = unfolded.transpose(1, 2).contiguous() + return unfolded.view(-1, unfolded.size(2)) # [b*num_patches, c*k] + + # Output: [b, c, l] -> [b*l, c] + return activation.permute(0, 2, 1).reshape(-1, c) + + raise ValueError(f"Expected Conv layer, got {type(layer)}") def _patchwise_mode(self, activation: torch.Tensor, layer: nn.Module, is_input: bool) -> torch.Tensor: """ @@ -135,16 +168,47 @@ def _patchwise_mode(self, activation: torch.Tensor, layer: nn.Module, is_input: Returns: Tensor of shape [batch_size, features, num_patches] """ - b, c, h, w = activation.shape + if isinstance(layer, nn.Conv2d): + if activation.ndim != 4: + raise ValueError(f"Expected 4D tensor for Conv2d, got {activation.ndim}D") + + b, c, h, w = activation.shape + + if 
is_input: + # Unfold to get kernel patches + unfold_params = self._get_unfold_params(layer) + unfolded = torch.nn.functional.unfold(activation, kernel_size=layer.kernel_size, **unfold_params) + return unfolded # [b, features, patches] + + # Output: reshape spatial dims to patches (node = output channel) + return activation.reshape(b, c, h * w) # [b, c, patches] + + if isinstance(layer, nn.Conv1d): + if activation.ndim != 3: + raise ValueError(f"Expected 3D tensor for Conv1d, got {activation.ndim}D") + + b, c, l = activation.shape + + if is_input: + # Unfold 1D input into kernel patches: [b, c*k, patches] + x4 = activation.unsqueeze(2) # [b, c, 1, l] + k = layer.kernel_size[0] if isinstance(layer.kernel_size, tuple) else layer.kernel_size + s = layer.stride[0] if isinstance(layer.stride, tuple) else layer.stride + p = layer.padding[0] if isinstance(layer.padding, tuple) else layer.padding + d = layer.dilation[0] if isinstance(layer.dilation, tuple) else layer.dilation + unfolded = torch.nn.functional.unfold( + x4, + kernel_size=(1, k), + dilation=(1, d), + padding=(0, p), + stride=(1, s), + ) + return unfolded # [b, c*k, patches] + + # Output: already [b, c, l] = [b, c, patches] + return activation - if is_input and isinstance(layer, nn.Conv2d): - # Unfold to get patches - unfold_params = self._get_unfold_params(layer) - unfolded = torch.nn.functional.unfold(activation, kernel_size=layer.kernel_size, **unfold_params) - return unfolded # [b, features, patches] - else: - # For outputs, reshape spatial dims to patches - return activation.reshape(b, c, h * w) + raise ValueError(f"Expected Conv layer, got {type(layer)}") def _batch_patch_combined_mode(self, activation: torch.Tensor, layer: nn.Module, is_input: bool) -> torch.Tensor: """ @@ -166,30 +230,48 @@ def _get_unfold_params(self, layer: nn.Module) -> Dict[str, Any]: def get_output_shape(self, input_shape: Tuple[int, ...], layer: nn.Module) -> Tuple[int, ...]: """Get expected output shape after preprocessing.""" - 
if len(input_shape) != 4: - raise ValueError(f"Expected 4D input shape, got {len(input_shape)}D") - - b, c, h, w = input_shape - - if self.mode == "unfold" or self.mode == "batch_patch_combined": - if isinstance(layer, nn.Conv2d): - # Calculate number of patches - out_h = (h + 2 * layer.padding[0] - layer.kernel_size[0]) // layer.stride[0] + 1 - out_w = (w + 2 * layer.padding[1] - layer.kernel_size[1]) // layer.stride[1] + 1 - num_patches = out_h * out_w - features = c * layer.kernel_size[0] * layer.kernel_size[1] + if isinstance(layer, nn.Conv2d): + if len(input_shape) != 4: + raise ValueError(f"Expected 4D input shape for Conv2d, got {len(input_shape)}D") + b, c, h, w = input_shape + + # Output spatial size (PyTorch conv2d formula; floor division) + k_h, k_w = layer.kernel_size + s_h, s_w = layer.stride + p_h, p_w = layer.padding + d_h, d_w = layer.dilation + out_h = (h + 2 * p_h - d_h * (k_h - 1) - 1) // s_h + 1 + out_w = (w + 2 * p_w - d_w * (k_w - 1) - 1) // s_w + 1 + num_patches = max(0, out_h) * max(0, out_w) + features = c * k_h * k_w + + if self.mode in {"unfold", "batch_patch_combined"}: return (b * num_patches, features) - else: - return (b * h * w, c) - elif self.mode == "patchwise": - if isinstance(layer, nn.Conv2d): - out_h = (h + 2 * layer.padding[0] - layer.kernel_size[0]) // layer.stride[0] + 1 - out_w = (w + 2 * layer.padding[1] - layer.kernel_size[1]) // layer.stride[1] + 1 - num_patches = out_h * out_w - features = c * layer.kernel_size[0] * layer.kernel_size[1] + if self.mode == "patchwise": return (b, features, num_patches) - else: - return (b, c, h * w) + + raise ValueError(f"Unknown mode: {self.mode}") + + if isinstance(layer, nn.Conv1d): + if len(input_shape) != 3: + raise ValueError(f"Expected 3D input shape for Conv1d, got {len(input_shape)}D") + b, c, l = input_shape + k = layer.kernel_size[0] if isinstance(layer.kernel_size, tuple) else layer.kernel_size + s = layer.stride[0] if isinstance(layer.stride, tuple) else layer.stride + p = 
layer.padding[0] if isinstance(layer.padding, tuple) else layer.padding + d = layer.dilation[0] if isinstance(layer.dilation, tuple) else layer.dilation + out_l = (l + 2 * p - d * (k - 1) - 1) // s + 1 + num_patches = max(0, out_l) + features = c * k + + if self.mode in {"unfold", "batch_patch_combined"}: + return (b * num_patches, features) + if self.mode == "patchwise": + return (b, features, num_patches) + + raise ValueError(f"Unknown mode: {self.mode}") + + raise ValueError(f"Expected Conv layer, got {type(layer)}") class AttentionPreprocessor(LayerPreprocessor): diff --git a/src/alignment/experiments/llm_experiments.py b/src/alignment/experiments/llm_experiments.py index 7e2d10ed..d8fff89d 100644 --- a/src/alignment/experiments/llm_experiments.py +++ b/src/alignment/experiments/llm_experiments.py @@ -2038,16 +2038,8 @@ def bwd_hook(mod: nn.Module, grad_input: Tuple[torch.Tensor, ...], grad_output: # Gradient w.r.t. module input (u) if not grad_input or grad_input[0] is None: return - if not grad_output or grad_output[0] is None: - return g_u = grad_input[0] - g_y = grad_output[0] - - if not hasattr(mod, "weight"): - return - - weight = mod.weight # [hidden_dim, m] # Retrieve stored u from forward hook (if available) if not hasattr(mod, "_scar_last_u"): @@ -2068,11 +2060,6 @@ def bwd_hook(mod: nn.Module, grad_input: Tuple[torch.Tensor, ...], grad_output: else: g_u_flat = g_u.reshape(-1, g_u.shape[-1]) - if g_y.ndim > 2: - g_y_flat = g_y.reshape(-1, g_y.shape[-1]) - else: - g_y_flat = g_y.reshape(-1, g_y.shape[-1]) - # Ensure shapes are consistent if u_flat.shape != g_u_flat.shape: logger.warning( @@ -2080,13 +2067,10 @@ def bwd_hook(mod: nn.Module, grad_input: Tuple[torch.Tensor, ...], grad_output: ) return - # Curvature: R_i = E[ (v_i^T g_y)^2 ] - # s = g_y * W_down => [N_tokens, m] - try: - s_flat = torch.matmul(g_y_flat.float(), weight.float()) # [N_tokens, m] - except Exception as e: - logger.error(f"SCAR metrics: failed to compute W_down^T g_y for layer 
{name}: {e}") - return + # NOTE: In backprop through y=W_down u, PyTorch already computes: + # g_u = dL/du = W_down^T * dL/dy + # So s_i := (v_i^T g_y) is exactly g_u_i. No extra GEMM needed. + s_flat = g_u_flat.float() s2 = (s_flat * s_flat).sum(dim=0) state["R_sum"] += s2 @@ -2183,6 +2167,24 @@ def bwd_hook(mod: nn.Module, grad_input: Tuple[torch.Tensor, ...], grad_output: # Now that scar_loss_proxy exists, we can compute the configured supernode mask on this layer. # This ensures 'protect_core' works during pruning even when score_metric='scar_loss_proxy'. self._apply_supernode_selection(layer_scores, composite=None) + # Propagate the supernode mask to sibling MLP projections so that channel-level protection + # works regardless of which projection holds the pruning scores (e.g., Wanda stores channel + # scores on gate/up/down; alignment metrics often live on gate/up). + try: + mask = layer_scores.get("supernode_mask") + if mask is not None and isinstance(layer_name, str) and "down_proj" in layer_name: + for sibling_proj in ("gate_proj", "up_proj"): + sibling_name = layer_name.replace("down_proj", sibling_proj) + if sibling_name in self.importance_scores: + sib_scores = self.importance_scores.get(sibling_name, {}) + sib_scores["supernode_mask"] = mask + if "supernode_core_size" in layer_scores: + sib_scores["supernode_core_size"] = layer_scores["supernode_core_size"] + if "supernode_threshold" in layer_scores: + sib_scores["supernode_threshold"] = layer_scores["supernode_threshold"] + self.importance_scores[sibling_name] = sib_scores + except Exception as _prop_err: + logger.debug(f"Failed to propagate supernode mask for {layer_name}: {_prop_err}") self.importance_scores[layer_name] = layer_scores logger.info(f"SCAR metrics: computed metrics for {len(scar_scores)} FFN layers.") @@ -2587,6 +2589,41 @@ def _apply_supernode_selection(self, layer_scores: Dict[str, torch.Tensor], comp layer_scores["supernode_core_size"] = num_core layer_scores["supernode_threshold"] 
= sorted_scores[min(num_core - 1, sorted_scores.shape[0] - 1)].item() + def _should_protect_supernodes_for_metric(self, metric: str) -> bool: + """ + Decide whether supernode protection (i.e., forcing core channels to be kept) should be applied + for a given pruning metric. + + Backward-compatible behavior: + - If `supernode.protect_core_metrics` is NOT set, protection applies to *all* metrics + (matching the legacy behavior when `protect_core: true`). + """ + cfg = getattr(self.config, "supernode", {}) or getattr(self.config, "supernode_config", {}) or {} + if not cfg.get("enabled", False): + return False + if not cfg.get("protect_core", True): + return False + + protect_metrics = cfg.get("protect_core_metrics", None) + if protect_metrics is None: + return True + + # Accept a few convenient string shorthands. + if isinstance(protect_metrics, str): + token = protect_metrics.strip().lower() + if token in {"all", "true", "yes", "1"}: + return True + if token in {"none", "false", "no", "0", ""}: + return False + # comma-separated list + protect_metrics = [m.strip() for m in protect_metrics.split(",") if m.strip()] + + try: + return metric in set(protect_metrics) + except TypeError: + # If the config value is malformed, fall back to "protect everything" (safer). 
+ return True + def analyze_supernode_connections( self, scar_scores: Dict[str, Dict[str, torch.Tensor]], @@ -4567,6 +4604,10 @@ def compute_supernode_connectivity_pruning_score( eps = 1e-8 results: Dict[str, Dict[str, Any]] = {} + supernode_cfg = getattr(self.config, "supernode", {}) or getattr(self.config, "supernode_config", {}) or {} + positive_redundancy = bool(supernode_cfg.get("positive_redundancy", False)) + if positive_redundancy: + logger.info(" Redundancy: using positive-only correlation (anti-correlation does NOT count as redundancy)") # Underlying HF model for module lookup / hook registration hf_model = self.model @@ -4673,9 +4714,7 @@ def bwd_hook(mod: nn.Module, grad_input: Tuple[torch.Tensor, ...], grad_output: st = plan.get(name) if st is None: return - if not grad_output or grad_output[0] is None: - return - if not hasattr(mod, "weight"): + if not grad_input or grad_input[0] is None: return if not hasattr(mod, "_scar_conn_last_u"): return @@ -4683,18 +4722,17 @@ def bwd_hook(mod: nn.Module, grad_input: Tuple[torch.Tensor, ...], grad_output: u = mod._scar_conn_last_u delattr(mod, "_scar_conn_last_u") - g_y = grad_output[0] - weight = mod.weight + g_u = grad_input[0] # Flatten to [N_tokens, dim] if u.ndim > 2: u_flat = u.reshape(-1, u.shape[-1]) else: u_flat = u.reshape(-1, u.shape[-1]) - if g_y.ndim > 2: - g_y_flat = g_y.reshape(-1, g_y.shape[-1]) + if g_u.ndim > 2: + g_u_flat = g_u.reshape(-1, g_u.shape[-1]) else: - g_y_flat = g_y.reshape(-1, g_y.shape[-1]) + g_u_flat = g_u.reshape(-1, g_u.shape[-1]) if u_flat.numel() == 0: return @@ -4708,13 +4746,12 @@ def bwd_hook(mod: nn.Module, grad_input: Tuple[torch.Tensor, ...], grad_output: super_idx_dev = st["super_idx"] halo_idx_dev = st["halo_idx"] - # Compute q = u * (W_down^T g_y) ONLY for channels we need (supernodes + halo). - # This avoids a full [N, m] GEMM per layer. + # Compute q = u * s where s := dL/du is already computed by backprop. + # We only materialize the supernode+halo indices. 
idx_union = torch.cat([super_idx_dev, halo_idx_dev], dim=0) # [|M|+|H|] try: - W_sel = weight.index_select(1, idx_union).float() # [hidden_dim, |M|+|H|] - s_sel = torch.matmul(g_y_flat.float(), W_sel) # [N, |M|+|H|] u_sel = u_flat.index_select(1, idx_union).float() # [N, |M|+|H|] + s_sel = g_u_flat.index_select(1, idx_union).float() # [N, |M|+|H|] except Exception: return @@ -4723,7 +4760,7 @@ def bwd_hook(mod: nn.Module, grad_input: Tuple[torch.Tensor, ...], grad_output: q_super = q_sel[:, :n_super] # [N, |M|] q_halo = q_sel[:, n_super:] # [N, |H|] - N = q_flat.shape[0] + N = q_sel.shape[0] # Initialize streaming sums on first batch if st["sum_q_super"] is None: @@ -4816,7 +4853,8 @@ def bwd_hook(mod: nn.Module, grad_input: Tuple[torch.Tensor, ...], grad_output: corr = torch.where(denom > 0, cov / denom, torch.zeros_like(cov)) corr = corr.clamp(-0.9999, 0.9999) - rho_sq = (corr * corr).clamp(0.0, 0.9999) + corr_eff = torch.clamp(corr, min=0.0) if positive_redundancy else corr + rho_sq = (corr_eff * corr_eff).clamp(0.0, 0.9999) mi = -0.5 * torch.log(1 - rho_sq) redundancy_to_core = mi.max(dim=1).values # [|H|] @@ -4885,230 +4923,470 @@ def analyze_halo_vs_nonhalo_redundancy( sample_pairs: int = 2000, ) -> Dict[str, Dict[str, Any]]: """ - Analyze redundancy patterns between halo and non-halo neurons. - - Computes pairwise |correlation| for three groups: - 1. Halo-Halo: Both neurons in halo (high connectivity to supernodes) - 2. Non-halo-Non-halo: Both neurons NOT in halo (low connectivity) - 3. Cross-group: One in halo, one not - - This helps validate whether halo membership correlates with redundancy. 
- - Args: - scar_scores: SCAR metrics from compute_scar_metrics - supernode_fraction: Fraction of neurons to consider as supernodes - halo_fraction: Fraction of non-supernodes to consider as halo - num_samples: Number of calibration samples - max_length: Max sequence length - sample_pairs: Number of pairs to sample per group (for efficiency) - + Paper-aligned halo redundancy analysis using the loss-relevant contribution signal. + + We compare redundancy between three groups (per layer), then aggregate across layers: + 1) **Halo-Halo**: both channels in the halo (high Conn to supernode write pattern) + 2) **Non-halo**: both channels outside halo and outside supernodes + 3) **Cross**: one halo channel and one non-halo channel + + Signal: + \(q_i = u_i s_i\) where \(u\) is the FFN post-gate activation (down_proj input) and + \(s=\nabla_u \mathcal{L}\) (down_proj grad_input[0]). + + Redundancy proxy: + - \(\rho_{ij}=\mathrm{corr}(q_i,q_j)\) over calibration tokens + - Optional **positive-only** redundancy: \(\rho^+_{ij}=\max(0,\rho_{ij})\) + - \(\mathrm{Red}(i,j) = -\tfrac12 \log(1-(\rho_{ij})^2)\) + + Notes: + - Supernodes are identified by `scar_loss_proxy` when available (paper definition). + - Halo membership is identified by Conn overlap with the aggregated supernode write pattern + (same as `compute_supernode_connectivity_pruning_score`). 
+ Returns: - Dictionary with per-layer and aggregate redundancy statistics + Dict with: + - per_layer: per-layer group stats + - aggregate: aggregated stats across layers """ - logger.info("="*60) - logger.info("ANALYZING HALO vs NON-HALO REDUNDANCY") - logger.info("="*60) - - # Get HF model - hf_model = self.wrapped_model._model if hasattr(self.wrapped_model, '_model') else self.model - - # Calibration texts - calibration_texts = [ - "The quick brown fox jumps over the lazy dog.", - "Machine learning models require careful tuning.", - "In the beginning, there was darkness, then light.", - "The stock market experienced significant volatility.", - "Scientists discovered a new species of deep-sea fish.", - "The conference will be held in San Francisco next month.", - "Programming languages continue to evolve.", - "Climate change poses challenges for future generations.", - ][:num_samples] - - results = {} - aggregate = {"halo_halo": [], "non_halo": [], "cross": []} - - for layer_name, layer_metrics in scar_scores.items(): - if "scar_activation_power" not in layer_metrics: + logger.info("=" * 60) + logger.info("ANALYZING HALO vs NON-HALO REDUNDANCY (q-signal, paper-aligned)") + logger.info("=" * 60) + + eps = 1e-8 + halo_cfg = getattr(self.config, "halo_analysis", {}) or {} + if hasattr(halo_cfg, "__dict__"): + halo_cfg = vars(halo_cfg) + + # Use positive-only redundancy when configured (matches SCAR ablation) + supernode_cfg = getattr(self.config, "supernode", {}) or getattr(self.config, "supernode_config", {}) or {} + positive_redundancy = bool(supernode_cfg.get("positive_redundancy", False)) + if positive_redundancy: + logger.info(" Redundancy: using positive-only correlation (anti-correlation does NOT count as redundancy)") + + # Respect optional config bounds + max_pairs_per_group = int(halo_cfg.get("max_pairs_per_group", sample_pairs)) + pairs_per_group = max(1, min(int(sample_pairs), max_pairs_per_group)) + max_group_channels = 
int(halo_cfg.get("max_group_channels", 512)) + + # Prefer the same calibration texts used in SCAR / importance computation + calibration_texts: List[str] = [] + if getattr(self.config, "importance_computation_texts", None): + calibration_texts = list(self.config.importance_computation_texts) + elif getattr(self, "dataset", None) is not None and hasattr(self.dataset, "texts"): + calibration_texts = list(self.dataset.texts) + if not calibration_texts: + # Last-resort fallback (keeps the analysis runnable in isolation) + calibration_texts = [ + "The quick brown fox jumps over the lazy dog.", + "Machine learning models require careful tuning.", + "In the beginning, there was darkness, then light.", + "The stock market experienced significant volatility.", + "Scientists discovered a new species of deep-sea fish.", + "The conference will be held in San Francisco next month.", + "Programming languages continue to evolve.", + "Climate change poses challenges for future generations.", + ] + + if not calibration_texts: + logger.warning("No calibration texts available for halo redundancy analysis") + return {} + + num_samples = max(1, int(num_samples)) + calibration_texts = calibration_texts[: min(num_samples, len(calibration_texts))] + + # Underlying HF model for module lookup / hook registration + hf_model: nn.Module = self.model + if hasattr(hf_model, "model"): + hf_model = getattr(hf_model, "model") + module_dict = dict(hf_model.named_modules()) + + # Only analyze FFN down_proj layers (intermediate channels) + layer_names = [ln for ln in scar_scores.keys() if "mlp.down_proj" in ln] + if not layer_names: + logger.warning("No down_proj layers found in scar_scores for halo redundancy analysis") + return {} + + # Helper: sample pair positions (indices into a group of size n) + def sample_pairs_pos(n: int, p: int) -> Tuple[torch.Tensor, torch.Tensor]: + if n < 2 or p <= 0: + return torch.empty(0, dtype=torch.long), torch.empty(0, dtype=torch.long) + i = torch.randint(low=0, 
high=n, size=(p,), dtype=torch.long) + j = torch.randint(low=0, high=n, size=(p,), dtype=torch.long) + # ensure i != j (resample j where equal) + same = i == j + tries = 0 + while same.any() and tries < 10: + j[same] = torch.randint(low=0, high=n, size=(int(same.sum().item()),), dtype=torch.long) + same = i == j + tries += 1 + # if still equal, shift deterministically + if same.any(): + j[same] = (j[same] + 1) % n + return i, j + + # ------------------------------------------------------------------ + # Phase 1: Per-layer supernodes + connectivity halo (weights-only) and pair plans + # ------------------------------------------------------------------ + plan: Dict[str, Dict[str, Any]] = {} + for layer_name in layer_names: + layer_metrics = scar_scores.get(layer_name, {}) or {} + lp = layer_metrics.get("scar_loss_proxy") + if lp is None: + lp = layer_metrics.get("scar_activation_power") + if lp is None: continue - - # Get down_proj weights - down_proj_weight = None - for name, module in hf_model.named_modules(): - if name == layer_name and hasattr(module, 'weight'): - down_proj_weight = module.weight.data.float().cpu() - break - - if down_proj_weight is None: + + lp_cpu = lp.detach().float().cpu() + m = int(lp_cpu.numel()) + if m <= 0: continue - - hidden_dim, intermediate_dim = down_proj_weight.shape - - # Step 1: Identify supernodes - supernode_metric = layer_metrics["scar_activation_power"].float().cpu() - num_supernodes = max(1, int(supernode_fraction * intermediate_dim)) - _, supernode_indices = torch.topk(supernode_metric, num_supernodes) - supernode_mask = torch.zeros(intermediate_dim, dtype=torch.bool) - supernode_mask[supernode_indices] = True - - # Step 2: Compute connection strength to supernodes - connection_strength = down_proj_weight.abs().sum(dim=0) - - # Step 3: Define halo (excluding supernodes) - non_supernode_mask = ~supernode_mask - non_supernode_indices = non_supernode_mask.nonzero(as_tuple=True)[0] - non_supernode_connection = 
connection_strength[non_supernode_indices] - - num_halo = max(1, int(halo_fraction * len(non_supernode_indices))) - _, halo_relative_indices = torch.topk(non_supernode_connection, num_halo) - halo_indices = non_supernode_indices[halo_relative_indices] - - halo_mask = torch.zeros(intermediate_dim, dtype=torch.bool) - halo_mask[halo_indices] = True - - # Non-halo = not supernode and not halo - non_halo_mask = non_supernode_mask & ~halo_mask - non_halo_indices = non_halo_mask.nonzero(as_tuple=True)[0] - - logger.info(f" {layer_name}: {num_supernodes} supernodes, {len(halo_indices)} halo, {len(non_halo_indices)} non-halo") - - # Step 4: Capture activations - activations: List[torch.Tensor] = [] - - def capture_hook(module, inputs, outputs): - if inputs and inputs[0] is not None: - inp = inputs[0].detach().float() - if inp.ndim == 3: - inp = inp.reshape(-1, inp.shape[-1]) - activations.append(inp.cpu()) - - hook_handle = None - for name, module in hf_model.named_modules(): - if name == layer_name: - hook_handle = module.register_forward_hook(capture_hook) - break - - if hook_handle is None: + + module = module_dict.get(layer_name) + if module is None or not hasattr(module, "weight"): + logger.warning(f"Halo redundancy: could not resolve module/weight for {layer_name}") continue - - hf_model.eval() - with torch.no_grad(): - for text in calibration_texts: - inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length) - inputs = {k: v.to(self.config.device) for k, v in inputs.items()} - try: - hf_model(**inputs) - except Exception: - pass - - hook_handle.remove() - - if not activations: + + # Identify supernodes by LP (paper definition) + num_supernodes = max(1, int(supernode_fraction * m)) + _, super_idx = torch.topk(lp_cpu, k=num_supernodes, largest=True) + super_idx = super_idx.long() + super_mask = torch.zeros(m, dtype=torch.bool) + super_mask[super_idx] = True + + # Compute Conn_i from down_proj weights (write-pattern overlap) + W = 
module.weight.detach().float().cpu() # [hidden_dim, m] + abs_W = W.abs() + a = abs_W[:, super_idx].sum(dim=1) # [hidden_dim] + a_norm = a.sum() + eps + v_norm = abs_W.sum(dim=0) + eps # [m] + conn_num = (abs_W * a.unsqueeze(1)).sum(dim=0) # [m] + conn = (conn_num / (v_norm * a_norm + eps)).clamp(0.0, 1.0) + + non_super_idx = (~super_mask).nonzero(as_tuple=True)[0] + if non_super_idx.numel() < 2: continue - - all_acts = torch.cat(activations, dim=0) - - # Step 5: Compute correlation matrix - centered = all_acts - all_acts.mean(dim=0, keepdim=True) - std = centered.std(dim=0, keepdim=True) - std = torch.where(std > 1e-8, std, torch.ones_like(std)) - normalized = centered / std - corr = (normalized.T @ normalized) / (all_acts.shape[0] - 1) - corr = torch.clamp(corr, -1, 1) - abs_corr = torch.abs(corr) - abs_corr.fill_diagonal_(0) - - # Step 6: Sample correlations for each group - halo_idx = halo_indices.numpy() - non_halo_idx = non_halo_indices.numpy() - - # Halo-Halo - if len(halo_idx) > 1: - hh_corr = abs_corr[np.ix_(halo_idx, halo_idx)] - hh_vals = hh_corr[torch.triu(torch.ones_like(hh_corr), diagonal=1).bool()].numpy() - if len(hh_vals) > sample_pairs: - hh_vals = np.random.choice(hh_vals, sample_pairs, replace=False) - else: - hh_vals = np.array([]) - - # Non-halo - Non-halo - if len(non_halo_idx) > 1: - nh_corr = abs_corr[np.ix_(non_halo_idx, non_halo_idx)] - nh_vals = nh_corr[torch.triu(torch.ones_like(nh_corr), diagonal=1).bool()].numpy() - if len(nh_vals) > sample_pairs: - nh_vals = np.random.choice(nh_vals, sample_pairs, replace=False) - else: - nh_vals = np.array([]) - - # Cross-group - if len(halo_idx) > 0 and len(non_halo_idx) > 0: - cross_corr = abs_corr[np.ix_(halo_idx, non_halo_idx)] - cross_vals = cross_corr.flatten().numpy() - if len(cross_vals) > sample_pairs: - cross_vals = np.random.choice(cross_vals, sample_pairs, replace=False) - else: - cross_vals = np.array([]) - - # Store results - results[layer_name] = { - "num_supernodes": num_supernodes, - 
"num_halo": len(halo_idx), - "num_non_halo": len(non_halo_idx), - "halo_halo": { - "mean": float(np.mean(hh_vals)) if len(hh_vals) > 0 else 0, - "std": float(np.std(hh_vals)) if len(hh_vals) > 0 else 0, - "median": float(np.median(hh_vals)) if len(hh_vals) > 0 else 0, - }, - "non_halo": { - "mean": float(np.mean(nh_vals)) if len(nh_vals) > 0 else 0, - "std": float(np.std(nh_vals)) if len(nh_vals) > 0 else 0, - "median": float(np.median(nh_vals)) if len(nh_vals) > 0 else 0, - }, - "cross": { - "mean": float(np.mean(cross_vals)) if len(cross_vals) > 0 else 0, - "std": float(np.std(cross_vals)) if len(cross_vals) > 0 else 0, - "median": float(np.median(cross_vals)) if len(cross_vals) > 0 else 0, - }, + num_halo = max(1, int(halo_fraction * non_super_idx.numel())) + _, halo_rel = torch.topk(conn[non_super_idx], k=num_halo, largest=True) + halo_idx = non_super_idx[halo_rel].long() + + halo_mask = torch.zeros(m, dtype=torch.bool) + halo_mask[halo_idx] = True + non_halo_idx = ((~super_mask) & (~halo_mask)).nonzero(as_tuple=True)[0].long() + if halo_idx.numel() < 2 or non_halo_idx.numel() < 2: + continue + + # Subsample channels to keep the analysis lightweight and comparable across layers. + halo_sel = halo_idx + if halo_sel.numel() > max_group_channels: + perm = torch.randperm(halo_sel.numel()) + halo_sel = halo_sel[perm[:max_group_channels]] + + non_halo_target = min(int(halo_sel.numel()), int(non_halo_idx.numel()), max_group_channels) + if non_halo_target < 2: + continue + perm = torch.randperm(non_halo_idx.numel()) + non_halo_sel = non_halo_idx[perm[:non_halo_target]] + + # If the halo selection was larger, trim to match (keeps pair sampling symmetric). 
+ if halo_sel.numel() > non_halo_sel.numel(): + halo_sel = halo_sel[: non_halo_sel.numel()] + + H = int(halo_sel.numel()) + NH = int(non_halo_sel.numel()) + if H < 2 or NH < 2: + continue + + P = int(min(pairs_per_group, H * (H - 1) // 2, NH * (NH - 1) // 2)) + if P <= 0: + continue + + hh_i_cpu, hh_j_cpu = sample_pairs_pos(H, P) + nn_i_cpu, nn_j_cpu = sample_pairs_pos(NH, P) + cross_h_cpu = torch.randint(low=0, high=H, size=(P,), dtype=torch.long) + cross_n_cpu = torch.randint(low=0, high=NH, size=(P,), dtype=torch.long) + + plan[layer_name] = { + "num_supernodes": int(num_supernodes), + "m": int(m), + "halo_idx_cpu": halo_sel, + "nonhalo_idx_cpu": non_halo_sel, + "hh_i_cpu": hh_i_cpu, + "hh_j_cpu": hh_j_cpu, + "nn_i_cpu": nn_i_cpu, + "nn_j_cpu": nn_j_cpu, + "cross_h_cpu": cross_h_cpu, + "cross_n_cpu": cross_n_cpu, + # device-side cached tensors + "halo_idx": None, + "nonhalo_idx": None, + "hh_i": None, + "hh_j": None, + "nn_i": None, + "nn_j": None, + "cross_h": None, + "cross_n": None, + # streaming sums + "sum_q_halo": None, + "sum_q2_halo": None, + "sum_q_nonhalo": None, + "sum_q2_nonhalo": None, + "sum_qij_hh": None, + "sum_qij_nn": None, + "sum_qij_cross": None, + "count": 0, } - - # Aggregate - aggregate["halo_halo"].extend(hh_vals.tolist()) - aggregate["non_halo"].extend(nh_vals.tolist()) - aggregate["cross"].extend(cross_vals.tolist()) - - # Compute aggregate statistics - aggregate_stats = {} - for group in ["halo_halo", "non_halo", "cross"]: - vals = aggregate[group] + + if not plan: + logger.warning("Halo redundancy: no eligible layers after filtering; skipping") + return {} + + # ------------------------------------------------------------------ + # Phase 2: Calibration passes (forward+backward) to accumulate q correlations + # ------------------------------------------------------------------ + hooks: List[Any] = [] + + def make_hooks(name: str): + def fwd_hook(mod: nn.Module, inputs: Tuple[torch.Tensor, ...], output: torch.Tensor): + if not inputs or 
inputs[0] is None: + return + mod._halo_last_u = inputs[0].detach() + + def bwd_hook(mod: nn.Module, grad_input: Tuple[torch.Tensor, ...], grad_output: Tuple[torch.Tensor, ...]): + st = plan.get(name) + if st is None: + return + if not grad_input or grad_input[0] is None: + return + if not hasattr(mod, "_halo_last_u"): + return + + u = mod._halo_last_u + delattr(mod, "_halo_last_u") + g_u = grad_input[0] + + # Flatten to [N_tokens, dim] + u_flat = u.reshape(-1, u.shape[-1]) if u.ndim > 2 else u.reshape(-1, u.shape[-1]) + g_u_flat = g_u.reshape(-1, g_u.shape[-1]) if g_u.ndim > 2 else g_u.reshape(-1, g_u.shape[-1]) + if u_flat.shape != g_u_flat.shape or u_flat.numel() == 0: + return + + # Move cached indices/pairs to the correct device once + dev = u_flat.device + if st["halo_idx"] is None or st["halo_idx"].device != dev: + st["halo_idx"] = st["halo_idx_cpu"].to(device=dev) + if st["nonhalo_idx"] is None or st["nonhalo_idx"].device != dev: + st["nonhalo_idx"] = st["nonhalo_idx_cpu"].to(device=dev) + if st["hh_i"] is None or st["hh_i"].device != dev: + st["hh_i"] = st["hh_i_cpu"].to(device=dev) + st["hh_j"] = st["hh_j_cpu"].to(device=dev) + st["nn_i"] = st["nn_i_cpu"].to(device=dev) + st["nn_j"] = st["nn_j_cpu"].to(device=dev) + st["cross_h"] = st["cross_h_cpu"].to(device=dev) + st["cross_n"] = st["cross_n_cpu"].to(device=dev) + + halo_idx = st["halo_idx"] + nonhalo_idx = st["nonhalo_idx"] + idx_union = torch.cat([halo_idx, nonhalo_idx], dim=0) # [H + NH] + + try: + u_sel = u_flat.index_select(1, idx_union).float() + s_sel = g_u_flat.index_select(1, idx_union).float() + except Exception: + return + + q_sel = u_sel * s_sel + H = int(halo_idx.numel()) + q_h = q_sel[:, :H] + q_n = q_sel[:, H:] + N = int(q_sel.shape[0]) + if N <= 0: + return + + # Initialize sums on first batch + if st["sum_q_halo"] is None: + st["sum_q_halo"] = torch.zeros(H, device=dev, dtype=torch.float32) + st["sum_q2_halo"] = torch.zeros_like(st["sum_q_halo"]) + st["sum_q_nonhalo"] = 
torch.zeros(q_n.shape[1], device=dev, dtype=torch.float32) + st["sum_q2_nonhalo"] = torch.zeros_like(st["sum_q_nonhalo"]) + P = int(st["hh_i"].numel()) + st["sum_qij_hh"] = torch.zeros(P, device=dev, dtype=torch.float32) + st["sum_qij_nn"] = torch.zeros(P, device=dev, dtype=torch.float32) + st["sum_qij_cross"] = torch.zeros(P, device=dev, dtype=torch.float32) + + st["sum_q_halo"] += q_h.sum(dim=0) + st["sum_q2_halo"] += (q_h * q_h).sum(dim=0) + st["sum_q_nonhalo"] += q_n.sum(dim=0) + st["sum_q2_nonhalo"] += (q_n * q_n).sum(dim=0) + + # Pair cross-products (vectorized) + hh_i = st["hh_i"] + hh_j = st["hh_j"] + nn_i = st["nn_i"] + nn_j = st["nn_j"] + ch = st["cross_h"] + cn = st["cross_n"] + + if hh_i.numel() > 0: + qi = q_h.index_select(1, hh_i) + qj = q_h.index_select(1, hh_j) + st["sum_qij_hh"] += (qi * qj).sum(dim=0) + + qi = q_n.index_select(1, nn_i) + qj = q_n.index_select(1, nn_j) + st["sum_qij_nn"] += (qi * qj).sum(dim=0) + + qi = q_h.index_select(1, ch) + qj = q_n.index_select(1, cn) + st["sum_qij_cross"] += (qi * qj).sum(dim=0) + + st["count"] += N + + return fwd_hook, bwd_hook + + for layer_name, module in module_dict.items(): + if layer_name not in plan: + continue + fwd, bwd = make_hooks(layer_name) + hooks.append(module.register_forward_hook(fwd)) + hooks.append(module.register_full_backward_hook(bwd)) + + self.model.eval() + device = torch.device(self.config.device) + + try: + for idx, text in enumerate(calibration_texts): + inputs = self.tokenizer( + text, + return_tensors="pt", + truncation=True, + max_length=int(max_length), + ) + inputs = {k: v.to(device) for k, v in inputs.items()} + + labels = inputs["input_ids"].clone() + pad_token_id = getattr(self.tokenizer, "pad_token_id", None) or getattr(self.tokenizer, "eos_token_id", None) + labels[labels == pad_token_id] = -100 + inputs["labels"] = labels + + self.model.zero_grad(set_to_none=True) + out = self.model(**inputs) + loss = out.loss + loss.backward() + + if (idx + 1) % 1 == 0: + logger.info(f" 
Halo q-stats: processed {idx+1}/{len(calibration_texts)} samples, loss={loss.item():.4f}") + finally: + for h in hooks: + try: + h.remove() + except Exception: + pass + + # ------------------------------------------------------------------ + # Phase 3: Compute redundancy distributions and aggregate across layers + # ------------------------------------------------------------------ + per_layer: Dict[str, Dict[str, Any]] = {} + agg_vals: Dict[str, List[float]] = {"halo_halo": [], "non_halo": [], "cross": []} + + def corr_to_red(corr: torch.Tensor) -> torch.Tensor: + corr = corr.clamp(-0.9999, 0.9999) + if positive_redundancy: + corr = torch.clamp(corr, min=0.0) + rho_sq = (corr * corr).clamp(0.0, 0.9999) + return (-0.5 * torch.log(1.0 - rho_sq + eps)).float() + + for layer_name, st in plan.items(): + N = int(st.get("count", 0)) + if N <= 1 or st["sum_qij_hh"] is None: + continue + + sum_q_h = st["sum_q_halo"].detach().cpu() + sum_q2_h = st["sum_q2_halo"].detach().cpu() + sum_q_n = st["sum_q_nonhalo"].detach().cpu() + sum_q2_n = st["sum_q2_nonhalo"].detach().cpu() + + mean_h = sum_q_h / float(N) + mean_n = sum_q_n / float(N) + var_h = (sum_q2_h / float(N)) - (mean_h * mean_h) + var_n = (sum_q2_n / float(N)) - (mean_n * mean_n) + std_h = torch.sqrt(torch.clamp(var_h, min=eps)) + std_n = torch.sqrt(torch.clamp(var_n, min=eps)) + + hh_i = st["hh_i_cpu"] + hh_j = st["hh_j_cpu"] + nn_i = st["nn_i_cpu"] + nn_j = st["nn_j_cpu"] + ch = st["cross_h_cpu"] + cn = st["cross_n_cpu"] + + # E[q_i q_j] + e_hh = st["sum_qij_hh"].detach().cpu() / float(N) + e_nn = st["sum_qij_nn"].detach().cpu() / float(N) + e_cn = st["sum_qij_cross"].detach().cpu() / float(N) + + # corr (halo-halo) + cov = e_hh - (mean_h[hh_i] * mean_h[hh_j]) + corr_hh = cov / (std_h[hh_i] * std_h[hh_j] + eps) + red_hh = corr_to_red(corr_hh) + + # corr (non-halo, non-halo) + cov = e_nn - (mean_n[nn_i] * mean_n[nn_j]) + corr_nn = cov / (std_n[nn_i] * std_n[nn_j] + eps) + red_nn = corr_to_red(corr_nn) + + # corr 
(cross) + cov = e_cn - (mean_h[ch] * mean_n[cn]) + corr_cn = cov / (std_h[ch] * std_n[cn] + eps) + red_cn = corr_to_red(corr_cn) + + def stats(x: torch.Tensor) -> Dict[str, float]: + if x.numel() == 0: + return {"mean": 0.0, "std": 0.0, "median": 0.0, "count": 0} + return { + "mean": float(x.mean().item()), + "std": float(x.std(unbiased=False).item()), + "median": float(x.median().item()), + "count": int(x.numel()), + } + + per_layer[layer_name] = { + "num_supernodes": int(st.get("num_supernodes", 0)), + "num_halo": int(st["halo_idx_cpu"].numel()), + "num_non_halo": int(st["nonhalo_idx_cpu"].numel()), + "halo_halo": stats(red_hh), + "non_halo": stats(red_nn), + "cross": stats(red_cn), + } + + agg_vals["halo_halo"].extend(red_hh.tolist()) + agg_vals["non_halo"].extend(red_nn.tolist()) + agg_vals["cross"].extend(red_cn.tolist()) + + aggregate_stats: Dict[str, Dict[str, Any]] = {} + for group, vals in agg_vals.items(): + if not vals: + aggregate_stats[group] = {"mean": 0.0, "std": 0.0, "median": 0.0, "count": 0} + continue + arr = np.asarray(vals, dtype=np.float64) aggregate_stats[group] = { - "mean": float(np.mean(vals)) if vals else 0, - "std": float(np.std(vals)) if vals else 0, - "median": float(np.median(vals)) if vals else 0, - "count": len(vals), + "mean": float(arr.mean()), + "std": float(arr.std()), + "median": float(np.median(arr)), + "count": int(arr.size), } - - # Log summary - logger.info("\nHALO vs NON-HALO REDUNDANCY SUMMARY:") + + logger.info("\nHALO vs NON-HALO REDUNDANCY SUMMARY (q-signal):") logger.info(f" Halo-Halo: mean={aggregate_stats['halo_halo']['mean']:.4f}") logger.info(f" Non-halo: mean={aggregate_stats['non_halo']['mean']:.4f}") logger.info(f" Cross-group: mean={aggregate_stats['cross']['mean']:.4f}") - - # Interpretation - hh_mean = aggregate_stats['halo_halo']['mean'] - nh_mean = aggregate_stats['non_halo']['mean'] - cross_mean = aggregate_stats['cross']['mean'] - - if hh_mean > nh_mean * 1.2: - logger.info(" → Halo neurons MORE 
redundant than non-halo ✓") - elif nh_mean > hh_mean * 1.2: - logger.info(" → Non-halo neurons MORE redundant (consider revising halo definition)") - else: - logger.info(" → Similar redundancy in both groups") - - if cross_mean < min(hh_mean, nh_mean) * 0.8: - logger.info(" → Cross-group correlation LOW (groups carry different info) ✓") - + return { - "per_layer": results, + "signal": "q", + "positive_redundancy": positive_redundancy, + "pairs_per_group": pairs_per_group, + "max_group_channels": max_group_channels, + "per_layer": per_layer, "aggregate": aggregate_stats, } @@ -5220,17 +5498,24 @@ def get_metric(metric_name, fallback_size): if mi.sum() == 0: mi = taylor # Taylor score relates to information content - # Identify supernodes (top by activation power) + # Identify supernodes (paper-aligned: top by loss proxy when available) + supernode_metric = loss_proxy if loss_proxy is not None and loss_proxy.numel() == intermediate_dim else activation_power num_supernodes = max(1, int(supernode_fraction * intermediate_dim)) - _, supernode_indices = torch.topk(activation_power, num_supernodes) + _, supernode_indices = torch.topk(supernode_metric, num_supernodes) supernode_mask = torch.zeros(intermediate_dim, dtype=torch.bool) supernode_mask[supernode_indices] = True # Identify halo (high connectivity to supernodes among non-supernodes) non_supernode_mask = ~supernode_mask non_supernode_indices = non_supernode_mask.nonzero(as_tuple=True)[0] - connection_strength = down_proj_weight.abs().sum(dim=0) - non_supernode_connection = connection_strength[non_supernode_indices] + # Paper-aligned Conn using overlap with aggregated supernode write pattern + abs_W = down_proj_weight.abs() + a = abs_W[:, supernode_indices].sum(dim=1) + a_norm = a.sum() + 1e-8 + v_norm = abs_W.sum(dim=0) + 1e-8 + conn_num = (abs_W * a.unsqueeze(1)).sum(dim=0) + conn = (conn_num / (v_norm * a_norm + 1e-8)).clamp(0.0, 1.0) + non_supernode_connection = conn[non_supernode_indices] num_halo = max(1, 
int(halo_fraction * len(non_supernode_indices))) _, halo_relative_indices = torch.topk(non_supernode_connection, num_halo) @@ -5836,6 +6121,8 @@ def apply_pruning(self, sparsity: float = 0.2, metric: str = "activation_l2_norm "scar_loss_proxy", "scar_activation_power", "scar_taylor", "scar_curvature", # Supernode/connectivity metrics "directed_redundancy", "supernode_protection_score", "supernode_connectivity_score", + # Random baseline (scores are generated and stored in importance_scores) + "random", # Weight-only structured baseline (channel-group weight magnitude) "weight_magnitude", # Generalized importance (no outlier assumption) @@ -5871,9 +6158,8 @@ def apply_pruning(self, sparsity: float = 0.2, metric: str = "activation_l2_norm # Get importance scores scores = self.importance_scores[layer_name][metric].clone() - supernode_cfg = getattr(self.config, "supernode", {}) or getattr(self.config, "supernode_config", {}) or {} core_mask = self.importance_scores[layer_name].get("supernode_mask") - if supernode_cfg.get("enabled") and supernode_cfg.get("protect_core", True) and core_mask is not None: + if core_mask is not None and self._should_protect_supernodes_for_metric(metric): margin = torch.abs(scores).max().detach().item() + 1.0 if mode == "low": scores[core_mask] = scores.max() + margin @@ -6007,6 +6293,7 @@ def _prune_attention_layers( mode=mode, sparsity=sparsity, layer_key=ref_layer, + metric=metric, ) if neuron_mask is None: continue @@ -6062,6 +6349,7 @@ def _create_attention_neuron_mask( mode: str, sparsity: float, layer_key: str, + metric: str, ) -> Tuple[Optional[torch.Tensor], Optional[int], Optional[int]]: """ Convert per-neuron attention scores into a shared mask aligned with heads. 
@@ -6070,9 +6358,9 @@ def _create_attention_neuron_mask( scores = scores.flatten() device = scores.device - supernode_cfg = getattr(self.config, "supernode", {}) or getattr(self.config, "supernode_config", {}) or {} core_mask = self.importance_scores.get(layer_key, {}).get("supernode_mask") - if supernode_cfg.get("enabled") and supernode_cfg.get("protect_core", True) and core_mask is not None: + do_protect = core_mask is not None and self._should_protect_supernodes_for_metric(metric) + if do_protect: margin = torch.abs(scores).max().detach().item() + 1.0 if mode == "low": scores[core_mask] = scores.max() + margin @@ -6110,7 +6398,7 @@ def _create_attention_neuron_mask( head_keep = MaskOperations.create_structured_mask(head_scores, amount=sparsity, mode=mode) # Ensure that any head containing a protected core neuron is always kept. - if core_mask is not None and core_mask.numel() == scores.numel(): + if do_protect and core_mask is not None and core_mask.numel() == scores.numel(): core_heads = core_mask.view(num_heads, head_dim).any(dim=1) if core_heads.any(): head_keep = head_keep | core_heads.to(head_keep.device) diff --git a/src/alignment/metrics/__init__.py b/src/alignment/metrics/__init__.py index 7c0903a7..79948a7d 100644 --- a/src/alignment/metrics/__init__.py +++ b/src/alignment/metrics/__init__.py @@ -2,7 +2,7 @@ Metrics for measuring neural network alignment, redundancy, and synergy. ============================================================================= -METRIC TAXONOMY (from alignment_notes/main.tex and new.tex) +METRIC TAXONOMY (from drafts/alignment_notes/alignment_red.tex) ============================================================================= 1. 
ALIGNMENT METRICS (Rayleigh Quotient based) diff --git a/src/alignment/metrics/halo_redundancy.py b/src/alignment/metrics/halo_redundancy.py index 9ce998c9..f9240781 100644 --- a/src/alignment/metrics/halo_redundancy.py +++ b/src/alignment/metrics/halo_redundancy.py @@ -147,7 +147,7 @@ def correlation_to_redundancy(corr: torch.Tensor) -> torch.Tensor: """ Convert correlation to redundancy using Gaussian MI formula. - Theory (from alignment_notes/new.tex, Eq. 5.1): + Theory (from drafts/alignment_notes/alignment_red.tex): I(Y_i; Y_j) = -0.5 * log(1 - ρ²) This is the mutual information between jointly Gaussian variables. diff --git a/src/alignment/metrics/information/gaussian_mi.py b/src/alignment/metrics/information/gaussian_mi.py index b13d89b7..1300067b 100644 --- a/src/alignment/metrics/information/gaussian_mi.py +++ b/src/alignment/metrics/information/gaussian_mi.py @@ -289,7 +289,7 @@ def compute(self, inputs: torch.Tensor, weights: torch.Tensor, outputs: Optional # - RQ = (w^T Σ_x w) / (w^T w) -- normalizes by weight norm (scale-invariant) # - MI = 0.5 * log(1 + (w^T Σ_x w) / σ_n²) -- uses raw signal variance! 
# - # From the theory (see alignment_notes/main.tex Section 3): + # From the theory (see drafts/alignment_notes/alignment_red.tex): # For noisy linear neuron y = w^T X + n where n ~ N(0, σ_n²): # I(X; y) = 0.5 * log(1 + (w^T Σ_X w) / σ_n²) # diff --git a/src/alignment/metrics/information/synergy_continuous.py b/src/alignment/metrics/information/synergy_continuous.py index ec93383a..97798359 100644 --- a/src/alignment/metrics/information/synergy_continuous.py +++ b/src/alignment/metrics/information/synergy_continuous.py @@ -104,7 +104,22 @@ def compute( if outputs.ndim > 2: # Conv layer: [B, C, H, W] -> [B, C] via GAP outputs = outputs.mean(dim=(2, 3)) if outputs.ndim == 4 else outputs.reshape(outputs.shape[0], -1) - + + # Handle batch mismatch (common when upstream preprocessing unfolds CNN outputs) + # If outputs has more samples than logits/labels, aggregate back to per-example activations. + # This makes synergy w.r.t. per-example target T well-defined. + if outputs.shape[0] != logits.shape[0]: + if outputs.shape[0] % logits.shape[0] == 0: + num_patches = outputs.shape[0] // logits.shape[0] + n_neurons = outputs.shape[1] + outputs = outputs.view(logits.shape[0], num_patches, n_neurons).mean(dim=1) + else: + logger.warning( + f"SynergyContinuousTarget: Batch mismatch outputs={outputs.shape[0]} vs logits={logits.shape[0]}; " + "cannot safely aggregate. Returning zeros." 
+ ) + return torch.zeros(outputs.shape[-1], device=device, dtype=dtype) + batch_size, n_neurons = outputs.shape # Compute continuous target T diff --git a/src/alignment/preprocessing/layer_preprocessing.py b/src/alignment/preprocessing/layer_preprocessing.py index 73cf7e49..cfc0a65f 100644 --- a/src/alignment/preprocessing/layer_preprocessing.py +++ b/src/alignment/preprocessing/layer_preprocessing.py @@ -115,18 +115,51 @@ def _unfold_mode(self, activation: torch.Tensor, layer: nn.Module, is_input: boo Returns: Tensor of shape [batch_size * num_patches, features] """ - b, c, h, w = activation.shape - - if is_input and isinstance(layer, nn.Conv2d): - # For inputs, unfold based on the layer's kernel parameters - unfold_params = self._get_unfold_params(layer) - unfolded = torch.nn.functional.unfold(activation, kernel_size=layer.kernel_size, **unfold_params) - # [b, features*kernel_size, num_patches] -> [b*num_patches, features] - unfolded = unfolded.transpose(1, 2).contiguous() - return unfolded.view(-1, unfolded.size(2)) - else: - # For outputs or non-conv layers, just flatten spatial dims - return activation.reshape(b, c, -1).permute(0, 2, 1).reshape(-1, c) + if isinstance(layer, nn.Conv2d): + if activation.ndim != 4: + raise ValueError(f"Expected 4D tensor for Conv2d, got {activation.ndim}D") + + b, c, h, w = activation.shape + + if is_input: + # Unfold based on the layer's kernel parameters so feature dimension matches weight flattening + unfold_params = self._get_unfold_params(layer) + unfolded = torch.nn.functional.unfold(activation, kernel_size=layer.kernel_size, **unfold_params) + # [b, features, num_patches] -> [b*num_patches, features] + unfolded = unfolded.transpose(1, 2).contiguous() + return unfolded.view(-1, unfolded.size(2)) + + # Output: treat each spatial location as a sample (node = output channel) + # [b, c, h, w] -> [b*h*w, c] + return activation.permute(0, 2, 3, 1).reshape(-1, c) + + if isinstance(layer, nn.Conv1d): + if activation.ndim != 3: + 
raise ValueError(f"Expected 3D tensor for Conv1d, got {activation.ndim}D") + + b, c, l = activation.shape + + if is_input: + # Use 2D unfold trick on [b, c, 1, l] to respect stride/padding/dilation + x4 = activation.unsqueeze(2) # [b, c, 1, l] + k = layer.kernel_size[0] if isinstance(layer.kernel_size, tuple) else layer.kernel_size + s = layer.stride[0] if isinstance(layer.stride, tuple) else layer.stride + p = layer.padding[0] if isinstance(layer.padding, tuple) else layer.padding + d = layer.dilation[0] if isinstance(layer.dilation, tuple) else layer.dilation + unfolded = torch.nn.functional.unfold( + x4, + kernel_size=(1, k), + dilation=(1, d), + padding=(0, p), + stride=(1, s), + ) # [b, c*k, num_patches] + unfolded = unfolded.transpose(1, 2).contiguous() + return unfolded.view(-1, unfolded.size(2)) # [b*num_patches, c*k] + + # Output: [b, c, l] -> [b*l, c] + return activation.permute(0, 2, 1).reshape(-1, c) + + raise ValueError(f"Expected Conv layer, got {type(layer)}") def _patchwise_mode(self, activation: torch.Tensor, layer: nn.Module, is_input: bool) -> torch.Tensor: """ @@ -135,16 +168,47 @@ def _patchwise_mode(self, activation: torch.Tensor, layer: nn.Module, is_input: Returns: Tensor of shape [batch_size, features, num_patches] """ - b, c, h, w = activation.shape + if isinstance(layer, nn.Conv2d): + if activation.ndim != 4: + raise ValueError(f"Expected 4D tensor for Conv2d, got {activation.ndim}D") + + b, c, h, w = activation.shape + + if is_input: + # Unfold to get kernel patches + unfold_params = self._get_unfold_params(layer) + unfolded = torch.nn.functional.unfold(activation, kernel_size=layer.kernel_size, **unfold_params) + return unfolded # [b, features, patches] + + # Output: reshape spatial dims to patches (node = output channel) + return activation.reshape(b, c, h * w) # [b, c, patches] + + if isinstance(layer, nn.Conv1d): + if activation.ndim != 3: + raise ValueError(f"Expected 3D tensor for Conv1d, got {activation.ndim}D") + + b, c, l = 
activation.shape + + if is_input: + # Unfold 1D input into kernel patches: [b, c*k, patches] + x4 = activation.unsqueeze(2) # [b, c, 1, l] + k = layer.kernel_size[0] if isinstance(layer.kernel_size, tuple) else layer.kernel_size + s = layer.stride[0] if isinstance(layer.stride, tuple) else layer.stride + p = layer.padding[0] if isinstance(layer.padding, tuple) else layer.padding + d = layer.dilation[0] if isinstance(layer.dilation, tuple) else layer.dilation + unfolded = torch.nn.functional.unfold( + x4, + kernel_size=(1, k), + dilation=(1, d), + padding=(0, p), + stride=(1, s), + ) + return unfolded # [b, c*k, patches] + + # Output: already [b, c, l] = [b, c, patches] + return activation - if is_input and isinstance(layer, nn.Conv2d): - # Unfold to get patches - unfold_params = self._get_unfold_params(layer) - unfolded = torch.nn.functional.unfold(activation, kernel_size=layer.kernel_size, **unfold_params) - return unfolded # [b, features, patches] - else: - # For outputs, reshape spatial dims to patches - return activation.reshape(b, c, h * w) + raise ValueError(f"Expected Conv layer, got {type(layer)}") def _batch_patch_combined_mode(self, activation: torch.Tensor, layer: nn.Module, is_input: bool) -> torch.Tensor: """ @@ -166,30 +230,48 @@ def _get_unfold_params(self, layer: nn.Module) -> Dict[str, Any]: def get_output_shape(self, input_shape: Tuple[int, ...], layer: nn.Module) -> Tuple[int, ...]: """Get expected output shape after preprocessing.""" - if len(input_shape) != 4: - raise ValueError(f"Expected 4D input shape, got {len(input_shape)}D") - - b, c, h, w = input_shape - - if self.mode == "unfold" or self.mode == "batch_patch_combined": - if isinstance(layer, nn.Conv2d): - # Calculate number of patches - out_h = (h + 2 * layer.padding[0] - layer.kernel_size[0]) // layer.stride[0] + 1 - out_w = (w + 2 * layer.padding[1] - layer.kernel_size[1]) // layer.stride[1] + 1 - num_patches = out_h * out_w - features = c * layer.kernel_size[0] * layer.kernel_size[1] 
+ if isinstance(layer, nn.Conv2d): + if len(input_shape) != 4: + raise ValueError(f"Expected 4D input shape for Conv2d, got {len(input_shape)}D") + b, c, h, w = input_shape + + # Output spatial size (PyTorch conv2d formula; floor division) + k_h, k_w = layer.kernel_size + s_h, s_w = layer.stride + p_h, p_w = layer.padding + d_h, d_w = layer.dilation + out_h = (h + 2 * p_h - d_h * (k_h - 1) - 1) // s_h + 1 + out_w = (w + 2 * p_w - d_w * (k_w - 1) - 1) // s_w + 1 + num_patches = max(0, out_h) * max(0, out_w) + features = c * k_h * k_w + + if self.mode in {"unfold", "batch_patch_combined"}: return (b * num_patches, features) - else: - return (b * h * w, c) - elif self.mode == "patchwise": - if isinstance(layer, nn.Conv2d): - out_h = (h + 2 * layer.padding[0] - layer.kernel_size[0]) // layer.stride[0] + 1 - out_w = (w + 2 * layer.padding[1] - layer.kernel_size[1]) // layer.stride[1] + 1 - num_patches = out_h * out_w - features = c * layer.kernel_size[0] * layer.kernel_size[1] + if self.mode == "patchwise": return (b, features, num_patches) - else: - return (b, c, h * w) + + raise ValueError(f"Unknown mode: {self.mode}") + + if isinstance(layer, nn.Conv1d): + if len(input_shape) != 3: + raise ValueError(f"Expected 3D input shape for Conv1d, got {len(input_shape)}D") + b, c, l = input_shape + k = layer.kernel_size[0] if isinstance(layer.kernel_size, tuple) else layer.kernel_size + s = layer.stride[0] if isinstance(layer.stride, tuple) else layer.stride + p = layer.padding[0] if isinstance(layer.padding, tuple) else layer.padding + d = layer.dilation[0] if isinstance(layer.dilation, tuple) else layer.dilation + out_l = (l + 2 * p - d * (k - 1) - 1) // s + 1 + num_patches = max(0, out_l) + features = c * k + + if self.mode in {"unfold", "batch_patch_combined"}: + return (b * num_patches, features) + if self.mode == "patchwise": + return (b, features, num_patches) + + raise ValueError(f"Unknown mode: {self.mode}") + + raise ValueError(f"Expected Conv layer, got 
{type(layer)}") class AttentionPreprocessor(LayerPreprocessor): diff --git a/src/alignment/services/activation_capture.py b/src/alignment/services/activation_capture.py index 86c06cc1..7784d07e 100644 --- a/src/alignment/services/activation_capture.py +++ b/src/alignment/services/activation_capture.py @@ -80,39 +80,98 @@ def capture( # Capture activations using model wrapper try: + # Prefer capturing RAW activations (we do preprocessing here for consistency) # Prefer passing explicit layers where supported - output, activations = self.model_wrapper.forward_with_activations(input_batch, layers=layers) + output, activations = self.model_wrapper.forward_with_activations( + input_batch, layers=layers, preprocess=False + ) except TypeError as e: - # Backwards compatibility: some wrappers don't accept 'layers' kwarg - if "unexpected keyword argument 'layers'" in str(e): - logger.warning( - "Model wrapper.forward_with_activations does not accept 'layers'; " - "capturing activations for all tracked layers instead." - ) - output, activations = self.model_wrapper.forward_with_activations(input_batch) - else: + # Backwards compatibility: wrappers may not accept 'layers' and/or 'preprocess' + msg = str(e) + try: + if "unexpected keyword argument 'preprocess'" in msg: + # Retry without preprocess kwarg + output, activations = self.model_wrapper.forward_with_activations(input_batch, layers=layers) + elif "unexpected keyword argument 'layers'" in msg: + logger.warning( + "Model wrapper.forward_with_activations does not accept 'layers'; " + "capturing activations for all tracked layers instead." 
+ ) + # Retry without layers (and without preprocess, which may also be unsupported) + try: + output, activations = self.model_wrapper.forward_with_activations(input_batch, preprocess=False) + except TypeError: + output, activations = self.model_wrapper.forward_with_activations(input_batch) + else: + raise + except Exception: logger.error(f"Failed to capture activations: {e}") raise except Exception as e: logger.error(f"Failed to capture activations: {e}") raise - # Separate inputs and outputs - inputs = {} - outputs = {} + # Preprocess if requested (unified, layer-aware preprocessing) + processed_activations = activations + if preprocess and mode not in {"none", "preserve_spatial"}: + if mode in {"unfold", "patchwise", "batch_patch_combined"}: + try: + # Use the canonical preprocessing utilities that distinguish _input vs _output + from alignment.dataops.processing import preprocess_layer_activations + + model = getattr(self.model_wrapper, "model", None) or getattr(self.model_wrapper, "_model", None) + if model is None: + raise AttributeError("Model wrapper has no .model or ._model attribute") + + layer_modules = dict(model.named_modules()) + + # Only preprocess the activations we will actually consume + to_process = {} + for layer in layers: + in_key = f"{layer}_input" + out_key = f"{layer}_output" + if in_key in activations: + to_process[in_key] = activations[in_key] + if out_key in activations: + to_process[out_key] = activations[out_key] + # Legacy compatibility: some wrappers store outputs under the bare layer name + if out_key not in activations and layer in activations: + to_process[layer] = activations[layer] + + processed_activations = preprocess_layer_activations(to_process, layer_modules, mode=mode) + except Exception as e: + logger.warning(f"Layer-aware preprocessing failed ({e}); falling back to simple flatten") + processed_activations = activations + elif mode == "flatten": + # Simple flattening: [B, ...] 
-> [B, -1] + processed_activations = {} + for name, tensor in activations.items(): + if isinstance(tensor, torch.Tensor) and tensor.ndim > 2: + processed_activations[name] = tensor.reshape(tensor.shape[0], -1) + else: + processed_activations[name] = tensor + else: + logger.warning(f"Unknown preprocessing mode '{mode}', using 'flatten'") + processed_activations = {} + for name, tensor in activations.items(): + if isinstance(tensor, torch.Tensor) and tensor.ndim > 2: + processed_activations[name] = tensor.reshape(tensor.shape[0], -1) + else: + processed_activations[name] = tensor + + # Separate inputs and outputs (keys normalized to layer names) + inputs: Dict[str, torch.Tensor] = {} + outputs: Dict[str, torch.Tensor] = {} for layer in layers: input_key = f"{layer}_input" output_key = f"{layer}_output" - if input_key in activations: - inputs[layer] = activations[input_key] - if output_key in activations: - outputs[layer] = activations[output_key] - - # Preprocess if requested - if preprocess: - inputs = self._preprocess_activations(inputs, mode) - outputs = self._preprocess_activations(outputs, mode) + if input_key in processed_activations: + inputs[layer] = processed_activations[input_key] + if output_key in processed_activations: + outputs[layer] = processed_activations[output_key] + elif layer in processed_activations: + outputs[layer] = processed_activations[layer] # Capture weights if requested weights = {} diff --git a/tests/unit/metrics/test_synergy_continuous_target.py b/tests/unit/metrics/test_synergy_continuous_target.py new file mode 100644 index 00000000..978c12f3 --- /dev/null +++ b/tests/unit/metrics/test_synergy_continuous_target.py @@ -0,0 +1,24 @@ +import torch + +from alignment.metrics.information.synergy_continuous import SynergyContinuousTarget + + +def test_synergy_continuous_target_aggregates_unfolded_outputs(device): + """ + When CNN preprocessing expands outputs to [B*P, C], synergy w.r.t. a per-example + target must aggregate back to [B, C]. 
We verify that repeating each example's + activations P times yields identical synergy scores after aggregation. + """ + B, P, C, n_classes = 8, 7, 5, 10 + logits = torch.randn(B, n_classes, device=device) + labels = torch.randint(0, n_classes, (B,), device=device) + + outputs_base = torch.randn(B, C, device=device) + outputs_unfolded = outputs_base.repeat_interleave(P, dim=0) # [B*P, C] + + metric = SynergyContinuousTarget(target_type="logit_margin", num_pairs=2, sampling_strategy="top_k") + s_base = metric.compute(outputs=outputs_base, logits=logits, labels=labels) + s_unfolded = metric.compute(outputs=outputs_unfolded, logits=logits, labels=labels) + + torch.testing.assert_close(s_base, s_unfolded, rtol=1e-5, atol=1e-6) + diff --git a/tests/unit/services/test_activation_capture.py b/tests/unit/services/test_activation_capture.py new file mode 100644 index 00000000..feaad965 --- /dev/null +++ b/tests/unit/services/test_activation_capture.py @@ -0,0 +1,123 @@ +import torch +import torch.nn as nn + +from alignment.models import ModelWrapper +from alignment.services.activation_capture import ActivationCaptureService + + +def test_activation_capture_conv2d_unfold_matches_conv(device): + """ + For Conv2d with bias=False, verify that: + - inputs are unfolded into patches [B*P, C_in*kH*kW] + - outputs are spatial-flattened [B*P, C_out] + - outputs == inputs @ W^T (exact conv equivalence via unfold) + """ + + class SimpleConv(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 4, kernel_size=3, stride=1, padding=1, bias=False) + + def forward(self, x): + return self.conv(x) + + model = SimpleConv().to(device) + wrapper = ModelWrapper(model, tracked_layers=["conv"], preprocessing_mode="auto") + service = ActivationCaptureService(wrapper, default_mode="unfold") + + B, C, H, W = 2, 3, 8, 8 + x = torch.randn(B, C, H, W, device=device) + + data = service.capture(x, layers=["conv"], include_weights=True, preprocess=True) + + # Shapes + assert 
"conv" in data.inputs + assert "conv" in data.outputs + assert "conv" in data.weights + + # For stride=1,pad=1,k=3 => H_out=W_out=H=W, P=H*W + P = H * W + F = C * 3 * 3 + assert tuple(data.inputs["conv"].shape) == (B * P, F) + assert tuple(data.outputs["conv"].shape) == (B * P, 4) + assert tuple(data.weights["conv"].shape) == (4, F) + + # Numerical equivalence: conv(x) == unfold(x) @ W^T + pred = data.inputs["conv"] @ data.weights["conv"].T + torch.testing.assert_close(pred, data.outputs["conv"], rtol=1e-4, atol=1e-5) + + +def test_activation_capture_conv2d_patchwise_matches_conv(device): + """ + Patchwise mode keeps patches separate: + inputs: [B, F, P] + outputs: [B, C_out, P] + and should still satisfy patchwise linear equivalence for bias=False. + """ + + class SimpleConv(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 5, kernel_size=3, stride=1, padding=1, bias=False) + + def forward(self, x): + return self.conv(x) + + model = SimpleConv().to(device) + wrapper = ModelWrapper(model, tracked_layers=["conv"], preprocessing_mode="auto") + service = ActivationCaptureService(wrapper, default_mode="patchwise") + + B, C, H, W = 2, 3, 6, 7 + x = torch.randn(B, C, H, W, device=device) + + data = service.capture(x, layers=["conv"], include_weights=True, preprocess=True) + + P = H * W + F = C * 3 * 3 + assert tuple(data.inputs["conv"].shape) == (B, F, P) + assert tuple(data.outputs["conv"].shape) == (B, 5, P) + assert tuple(data.weights["conv"].shape) == (5, F) + + # Compare in [B, P, C_out] form + x_patches = data.inputs["conv"].permute(0, 2, 1) # [B, P, F] + y_pred = x_patches @ data.weights["conv"].T # [B, P, C_out] + y_true = data.outputs["conv"].permute(0, 2, 1) # [B, P, C_out] + torch.testing.assert_close(y_pred, y_true, rtol=1e-4, atol=1e-5) + + +def test_activation_capture_conv1d_unfold_matches_conv(device): + """ + Conv1d support: unfold mode should produce inputs [B*P, C_in*k] and + outputs [B*P, C_out], matching the Conv1d 
forward when bias=False. + """ + + class SimpleConv1d(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv1d(2, 3, kernel_size=5, stride=2, padding=2, dilation=1, bias=False) + + def forward(self, x): + return self.conv(x) + + model = SimpleConv1d().to(device) + wrapper = ModelWrapper(model, tracked_layers=["conv"], preprocessing_mode="auto") + service = ActivationCaptureService(wrapper, default_mode="unfold") + + B, C, L = 2, 2, 17 + x = torch.randn(B, C, L, device=device) + + data = service.capture(x, layers=["conv"], include_weights=True, preprocess=True) + + # Output length: floor((L + 2p - d*(k-1) - 1)/s + 1) + k, s, p, d = 5, 2, 2, 1 + L_out = (L + 2 * p - d * (k - 1) - 1) // s + 1 + P = L_out + F = C * k + + assert tuple(data.inputs["conv"].shape) == (B * P, F) + assert tuple(data.outputs["conv"].shape) == (B * P, 3) + assert tuple(data.weights["conv"].shape) == (3, F) + + pred = data.inputs["conv"] @ data.weights["conv"].T + torch.testing.assert_close(pred, data.outputs["conv"], rtol=1e-4, atol=1e-5) + From 6b7415263cb747d21f1a3045e5379c8a8f3fb691 Mon Sep 17 00:00:00 2001 From: Houman Safaai Date: Fri, 9 Jan 2026 15:32:02 -0500 Subject: [PATCH 09/12] fix nan/inf --- .gitignore | 3 + configs/prune_llm/README.md | 2 +- scripts/README.md | 11 +- scripts/analyze_halo_redundancy.py | 410 ------------ scripts/collect_paper_artifacts.py | 604 ------------------ scripts/generate_scar_paper_tables.py | 152 ----- scripts/run_experiment.py | 96 ++- slurm-54745862.out | 37 ++ slurm_jobs/prune_llm/paper/README.md | 43 -- .../paper/run_llama3_8b_calibration_array.sh | 91 --- .../paper/run_llama3_8b_noprotect.sh | 71 -- ...run_llama3_8b_positive_redundancy_array.sh | 80 --- .../paper/run_llama3_8b_protect_baselines.sh | 72 --- slurm_jobs/prune_llm/paper/submit_suite.sh | 50 -- slurm_jobs/prune_llm/run_all_paper.sh | 72 --- slurm_jobs/prune_llm/run_llama2_7b.sh | 73 --- slurm_jobs/prune_llm/run_llama3_8b.sh | 80 --- 
slurm_jobs/prune_llm/run_mistral_7b.sh | 73 --- slurm_jobs/prune_llm/run_qwen2_7b.sh | 74 --- .../run_cluster_analysis_resnet18.sh | 81 --- .../run_cluster_analysis_resnet50.sh | 76 --- slurm_jobs/run_comprehensive_pruning.sh | 102 --- slurm_jobs/run_llama3_full_benchmark.sh | 73 --- slurm_jobs/run_llama3_scar_pruning.sh | 53 -- slurm_jobs/run_minitron_comparison.sh | 92 --- slurm_jobs/run_multimodel_pruning.sh | 88 --- slurm_jobs/run_paper_experiments.sh | 118 ---- slurm_jobs/run_supernode_robustness.sh | 91 --- .../analysis/clustering/cross_layer_halo.py | 7 +- .../analysis/visualization/__init__.py | 13 + .../analysis/visualization/cluster_plots.py | 120 ++++ .../analysis/visualization/paper_plots.py | 312 +++++++++ src/alignment/configs/config_loader.py | 4 +- src/alignment/experiments/llm_experiments.py | 219 ++++++- 34 files changed, 764 insertions(+), 2779 deletions(-) delete mode 100644 scripts/analyze_halo_redundancy.py delete mode 100644 scripts/collect_paper_artifacts.py delete mode 100644 scripts/generate_scar_paper_tables.py create mode 100644 slurm-54745862.out delete mode 100644 slurm_jobs/prune_llm/paper/README.md delete mode 100644 slurm_jobs/prune_llm/paper/run_llama3_8b_calibration_array.sh delete mode 100644 slurm_jobs/prune_llm/paper/run_llama3_8b_noprotect.sh delete mode 100644 slurm_jobs/prune_llm/paper/run_llama3_8b_positive_redundancy_array.sh delete mode 100644 slurm_jobs/prune_llm/paper/run_llama3_8b_protect_baselines.sh delete mode 100644 slurm_jobs/prune_llm/paper/submit_suite.sh delete mode 100755 slurm_jobs/prune_llm/run_all_paper.sh delete mode 100755 slurm_jobs/prune_llm/run_llama2_7b.sh delete mode 100755 slurm_jobs/prune_llm/run_llama3_8b.sh delete mode 100755 slurm_jobs/prune_llm/run_mistral_7b.sh delete mode 100755 slurm_jobs/prune_llm/run_qwen2_7b.sh delete mode 100644 slurm_jobs/prune_vision/run_cluster_analysis_resnet18.sh delete mode 100644 slurm_jobs/prune_vision/run_cluster_analysis_resnet50.sh delete mode 100644 
slurm_jobs/run_comprehensive_pruning.sh delete mode 100644 slurm_jobs/run_llama3_full_benchmark.sh delete mode 100755 slurm_jobs/run_llama3_scar_pruning.sh delete mode 100755 slurm_jobs/run_minitron_comparison.sh delete mode 100644 slurm_jobs/run_multimodel_pruning.sh delete mode 100644 slurm_jobs/run_paper_experiments.sh delete mode 100644 slurm_jobs/run_supernode_robustness.sh create mode 100644 src/alignment/analysis/visualization/paper_plots.py diff --git a/.gitignore b/.gitignore index 49a061d8..29d16233 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,9 @@ # Archive and draft directories _arxiv/ +# Ignore drafts by default, but keep the SCAR paper folder tracked (it has its own .gitignore). drafts/ +!drafts/LLM_prune/ +!drafts/LLM_prune/** checkpoints/ results/ logs/ diff --git a/configs/prune_llm/README.md b/configs/prune_llm/README.md index b8f85fc7..849efcd5 100644 --- a/configs/prune_llm/README.md +++ b/configs/prune_llm/README.md @@ -15,7 +15,7 @@ Configurations for generating results in the SCAR LLM pruning paper. 
Run all experiments: ```bash -bash slurm_jobs/prune_llm/run_all_paper.sh +bash drafts/LLM_prune/paper/slurm/run_all_paper.sh ``` Run single model: diff --git a/scripts/README.md b/scripts/README.md index 37fba2ad..b1808bc7 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -11,7 +11,7 @@ Run experiments from YAML configuration: python scripts/run_experiment.py --config configs/examples/mnist_basic.yaml # LLM analysis -python scripts/run_experiment.py --config configs/paper/llama3_8b_full.yaml +python scripts/run_experiment.py --config configs/prune_llm/llama3_8b_full.yaml # Cluster-based analysis python scripts/run_experiment.py --config configs/cluster_analysis/resnet18_cifar10_full.yaml @@ -41,10 +41,7 @@ Options: - `--analyses LIST` - Specific analyses to run - `--quick` - Run all analyses with defaults -## analyze_halo_redundancy.py +## Paper-specific helpers -Specialized script for halo redundancy analysis: - -```bash -python scripts/analyze_halo_redundancy.py --results-dir ./results --output-dir ./plots -``` +The SCAR/LLM-pruning paper batch scripts and artifact collectors live under: +- `drafts/LLM_prune/paper/` diff --git a/scripts/analyze_halo_redundancy.py b/scripts/analyze_halo_redundancy.py deleted file mode 100644 index c6830334..00000000 --- a/scripts/analyze_halo_redundancy.py +++ /dev/null @@ -1,410 +0,0 @@ -#!/usr/bin/env python3 -""" -Analyze redundancy patterns between halo and non-halo neurons. - -This script computes pairwise correlations/redundancy between: -1. Both neurons in halo (halo-halo) -2. Both neurons NOT in halo (non-halo-non-halo) -3. 
One neuron in halo, one not (halo-non-halo / cross-group) - -This helps understand whether: -- Halo neurons are indeed redundant with each other -- Non-halo neurons are also redundant (suggesting they too could be pruned safely) -- Halo vs non-halo neurons carry independent information -""" - -import argparse -import json -import os -import sys -from pathlib import Path - -import matplotlib.pyplot as plt -import numpy as np -import torch -from tqdm import tqdm - -# Add src to path -sys.path.insert(0, str(Path(__file__).parent.parent / "src")) - - -def compute_redundancy_matrix(activations: torch.Tensor) -> torch.Tensor: - """Compute pairwise |correlation| matrix.""" - # Center - centered = activations - activations.mean(dim=0, keepdim=True) - # Normalize - std = centered.std(dim=0, keepdim=True) - std = torch.where(std > 1e-8, std, torch.ones_like(std)) - normalized = centered / std - # Correlation - corr = (normalized.T @ normalized) / (activations.shape[0] - 1) - corr = torch.clamp(corr, -1, 1) - return torch.abs(corr) - - -def analyze_layer_redundancy( - model, - tokenizer, - layer_idx: int, - calibration_texts: list, - supernode_fraction: float = 0.01, - halo_fraction: float = 0.1, - device: str = "cuda", - sample_size: int = 1000, # Sample for efficiency -): - """Analyze redundancy patterns for a single layer.""" - - # Find the layer - hf_model = model - layer_name = f"model.layers.{layer_idx}.mlp.down_proj" - - # Get down_proj weight - down_proj = None - for name, module in hf_model.named_modules(): - if f"layers.{layer_idx}.mlp.down_proj" in name: - down_proj = module.weight.data.float() # [hidden_dim, intermediate_dim] - break - - if down_proj is None: - print(f"Could not find down_proj for layer {layer_idx}") - return None - - hidden_dim, intermediate_dim = down_proj.shape - - # Step 1: Identify supernodes (by weight magnitude as proxy for activation power) - neuron_magnitude = down_proj.abs().sum(dim=0) # [intermediate_dim] - num_supernodes = max(1, 
int(supernode_fraction * intermediate_dim)) - _, supernode_indices = torch.topk(neuron_magnitude, num_supernodes) - supernode_mask = torch.zeros(intermediate_dim, dtype=torch.bool) - supernode_mask[supernode_indices] = True - - # Step 2: Compute connection strength to supernodes - supernode_weights = down_proj[:, supernode_indices] # [hidden_dim, num_supernodes] - - # For intermediate neurons: connection = how much their outputs influence hidden dims that supernodes also influence - # Simpler: use sum of absolute weights in down_proj row as proxy - connection_strength = down_proj.abs().sum(dim=0) # [intermediate_dim] - - # Step 3: Define halo (excluding supernodes) - non_supernode_indices = (~supernode_mask).nonzero(as_tuple=True)[0] - non_supernode_connection = connection_strength[non_supernode_indices] - - num_halo = max(1, int(halo_fraction * len(non_supernode_indices))) - _, halo_relative_indices = torch.topk(non_supernode_connection, num_halo) - halo_indices = non_supernode_indices[halo_relative_indices] - - halo_mask = torch.zeros(intermediate_dim, dtype=torch.bool) - halo_mask[halo_indices] = True - - # Non-halo = not supernode and not halo - non_halo_mask = ~supernode_mask & ~halo_mask - non_halo_indices = non_halo_mask.nonzero(as_tuple=True)[0] - - print(f"Layer {layer_idx}: {num_supernodes} supernodes, {len(halo_indices)} halo, {len(non_halo_indices)} non-halo") - - # Step 4: Capture activations - activations = [] - - def capture_hook(module, inputs, outputs): - if inputs and inputs[0] is not None: - inp = inputs[0].detach().float() - if inp.ndim == 3: - inp = inp.reshape(-1, inp.shape[-1]) - activations.append(inp.cpu()) - - hook_handle = None - for name, module in hf_model.named_modules(): - if f"layers.{layer_idx}.mlp.down_proj" in name: - hook_handle = module.register_forward_hook(capture_hook) - break - - if hook_handle is None: - print(f"Could not hook layer {layer_idx}") - return None - - hf_model.eval() - with torch.no_grad(): - for text in 
tqdm(calibration_texts[:20], desc=f"Layer {layer_idx} activations"): - inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256) - inputs = {k: v.to(device) for k, v in inputs.items()} - try: - hf_model(**inputs) - except Exception as e: - pass - - hook_handle.remove() - - if not activations: - print(f"No activations captured for layer {layer_idx}") - return None - - all_acts = torch.cat(activations, dim=0) # [N, intermediate_dim] - print(f" Captured {all_acts.shape[0]} activation samples") - - # Step 5: Sample for efficiency - if all_acts.shape[0] > sample_size: - indices = torch.randperm(all_acts.shape[0])[:sample_size] - all_acts = all_acts[indices] - - # Step 6: Compute full correlation matrix - print(" Computing correlation matrix...") - corr_matrix = compute_redundancy_matrix(all_acts) - corr_matrix.fill_diagonal_(0) # Exclude self-correlations - - # Step 7: Extract correlations for different groups - halo_idx_np = halo_indices.numpy() - non_halo_idx_np = non_halo_indices.numpy() - - # Sample indices for efficiency (if too many) - max_per_group = 500 - if len(halo_idx_np) > max_per_group: - halo_idx_np = np.random.choice(halo_idx_np, max_per_group, replace=False) - if len(non_halo_idx_np) > max_per_group: - non_halo_idx_np = np.random.choice(non_halo_idx_np, max_per_group, replace=False) - - # Halo-Halo correlations - halo_halo_corr = corr_matrix[np.ix_(halo_idx_np, halo_idx_np)] - halo_halo_values = halo_halo_corr[torch.triu(torch.ones_like(halo_halo_corr), diagonal=1).bool()].numpy() - - # Non-halo - Non-halo correlations - non_halo_non_halo_corr = corr_matrix[np.ix_(non_halo_idx_np, non_halo_idx_np)] - non_halo_values = non_halo_non_halo_corr[torch.triu(torch.ones_like(non_halo_non_halo_corr), diagonal=1).bool()].numpy() - - # Cross-group: halo - non-halo correlations - cross_corr = corr_matrix[np.ix_(halo_idx_np, non_halo_idx_np)] - cross_values = cross_corr.flatten().numpy() - - results = { - "layer_idx": layer_idx, - 
"num_supernodes": num_supernodes, - "num_halo": len(halo_idx_np), - "num_non_halo": len(non_halo_idx_np), - "halo_halo": { - "mean": float(np.mean(halo_halo_values)), - "std": float(np.std(halo_halo_values)), - "median": float(np.median(halo_halo_values)), - "values": halo_halo_values.tolist()[:5000], # Limit for storage - }, - "non_halo_non_halo": { - "mean": float(np.mean(non_halo_values)), - "std": float(np.std(non_halo_values)), - "median": float(np.median(non_halo_values)), - "values": non_halo_values.tolist()[:5000], - }, - "cross_group": { - "mean": float(np.mean(cross_values)), - "std": float(np.std(cross_values)), - "median": float(np.median(cross_values)), - "values": cross_values.tolist()[:5000], - }, - } - - return results - - -def plot_redundancy_comparison(results_list: list, output_path: str): - """Create comparison plots for redundancy patterns.""" - - fig, axes = plt.subplots(2, 2, figsize=(14, 12)) - - # Plot 1: Histogram comparison for all layers combined - ax = axes[0, 0] - all_halo_halo = [] - all_non_halo = [] - all_cross = [] - - for r in results_list: - all_halo_halo.extend(r["halo_halo"]["values"]) - all_non_halo.extend(r["non_halo_non_halo"]["values"]) - all_cross.extend(r["cross_group"]["values"]) - - bins = np.linspace(0, 1, 50) - ax.hist(all_halo_halo, bins=bins, alpha=0.5, label=f'Halo-Halo (μ={np.mean(all_halo_halo):.3f})', density=True, color='#e74c3c') - ax.hist(all_non_halo, bins=bins, alpha=0.5, label=f'Non-halo (μ={np.mean(all_non_halo):.3f})', density=True, color='#3498db') - ax.hist(all_cross, bins=bins, alpha=0.5, label=f'Cross-group (μ={np.mean(all_cross):.3f})', density=True, color='#2ecc71') - ax.set_xlabel("|Correlation|") - ax.set_ylabel("Density") - ax.set_title("Pairwise |Correlation| Distribution (All Layers)") - ax.legend() - ax.grid(True, alpha=0.3) - - # Plot 2: Mean redundancy by layer - ax = axes[0, 1] - layers = [r["layer_idx"] for r in results_list] - halo_means = [r["halo_halo"]["mean"] for r in results_list] - 
non_halo_means = [r["non_halo_non_halo"]["mean"] for r in results_list] - cross_means = [r["cross_group"]["mean"] for r in results_list] - - x = np.arange(len(layers)) - width = 0.25 - ax.bar(x - width, halo_means, width, label='Halo-Halo', color='#e74c3c', alpha=0.8) - ax.bar(x, non_halo_means, width, label='Non-halo', color='#3498db', alpha=0.8) - ax.bar(x + width, cross_means, width, label='Cross-group', color='#2ecc71', alpha=0.8) - ax.set_xlabel("Layer") - ax.set_ylabel("Mean |Correlation|") - ax.set_title("Mean Redundancy by Layer") - ax.set_xticks(x) - ax.set_xticklabels([f"L{l}" for l in layers]) - ax.legend() - ax.grid(True, alpha=0.3, axis='y') - - # Plot 3: Box plots for each group - ax = axes[1, 0] - data_to_plot = [ - all_halo_halo[:2000], # Sample for visibility - all_non_halo[:2000], - all_cross[:2000], - ] - bp = ax.boxplot(data_to_plot, labels=['Halo-Halo', 'Non-halo', 'Cross-group'], patch_artist=True) - colors = ['#e74c3c', '#3498db', '#2ecc71'] - for patch, color in zip(bp['boxes'], colors): - patch.set_facecolor(color) - patch.set_alpha(0.6) - ax.set_ylabel("|Correlation|") - ax.set_title("Redundancy Distribution by Group") - ax.grid(True, alpha=0.3, axis='y') - - # Plot 4: Summary statistics table - ax = axes[1, 1] - ax.axis('off') - - summary_text = """ -REDUNDANCY ANALYSIS SUMMARY -========================== - -Group | Mean | Std | Median -------------------|---------|---------|-------- -Halo-Halo | {:.4f} | {:.4f} | {:.4f} -Non-halo-Non-halo | {:.4f} | {:.4f} | {:.4f} -Cross (Halo-NonH) | {:.4f} | {:.4f} | {:.4f} - -KEY INSIGHTS: -------------- -• If Halo-Halo >> Non-halo: Halo neurons are more redundant - → Current approach is correct - -• If Non-halo ~ Halo: Non-halo neurons also redundant - → Could prune more aggressively - -• If Cross << within-group: Groups carry different info - → Important to distinguish halo vs non-halo - -• If Halo-Halo high, Cross low: Halo is an "echo chamber" - → Safe to prune redundant halo members -""".format( 
- np.mean(all_halo_halo), np.std(all_halo_halo), np.median(all_halo_halo), - np.mean(all_non_halo), np.std(all_non_halo), np.median(all_non_halo), - np.mean(all_cross), np.std(all_cross), np.median(all_cross), - ) - - ax.text(0.05, 0.95, summary_text, transform=ax.transAxes, fontsize=10, - verticalalignment='top', fontfamily='monospace', - bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5)) - - plt.tight_layout() - plt.savefig(output_path, dpi=150, bbox_inches='tight') - plt.close() - - print(f"Saved plot to {output_path}") - - -def main(): - parser = argparse.ArgumentParser(description="Analyze halo vs non-halo redundancy patterns") - parser.add_argument("--model", type=str, default="meta-llama/Llama-3.1-8B", - help="Model name or path") - parser.add_argument("--layers", type=str, default="8,12,16,20,24", - help="Comma-separated layer indices to analyze") - parser.add_argument("--supernode-fraction", type=float, default=0.01, - help="Fraction of neurons to consider as supernodes") - parser.add_argument("--halo-fraction", type=float, default=0.1, - help="Fraction of non-supernodes to consider as halo") - parser.add_argument("--output-dir", type=str, default="./halo_analysis", - help="Output directory for plots and results") - parser.add_argument("--device", type=str, default="cuda", - help="Device to use") - args = parser.parse_args() - - # Create output directory - os.makedirs(args.output_dir, exist_ok=True) - - # Load model - print(f"Loading model: {args.model}") - from transformers import AutoModelForCausalLM, AutoTokenizer - - tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) - model = AutoModelForCausalLM.from_pretrained( - args.model, - torch_dtype=torch.float16, - device_map=args.device, - trust_remote_code=True, - ) - - # Calibration texts - calibration_texts = [ - "The quick brown fox jumps over the lazy dog.", - "Machine learning models require careful tuning of hyperparameters.", - "In the beginning, there was darkness, 
and then there was light.", - "The stock market experienced significant volatility today.", - "Scientists have discovered a new species of deep-sea fish.", - "The conference will be held in San Francisco next month.", - "Programming languages continue to evolve with new features.", - "Climate change poses significant challenges for future generations.", - "The recipe calls for two cups of flour and one cup of sugar.", - "Historical documents reveal interesting details about past civilizations.", - ] * 5 # Repeat for more samples - - # Analyze layers - layers = [int(l.strip()) for l in args.layers.split(",")] - results_list = [] - - for layer_idx in layers: - print(f"\nAnalyzing layer {layer_idx}...") - result = analyze_layer_redundancy( - model, tokenizer, layer_idx, calibration_texts, - supernode_fraction=args.supernode_fraction, - halo_fraction=args.halo_fraction, - device=args.device, - ) - if result: - results_list.append(result) - - # Save results - results_path = os.path.join(args.output_dir, "halo_redundancy_results.json") - with open(results_path, "w") as f: - json.dump(results_list, f, indent=2) - print(f"\nSaved results to {results_path}") - - # Create plots - plot_path = os.path.join(args.output_dir, "halo_redundancy_comparison.png") - plot_redundancy_comparison(results_list, plot_path) - - # Print summary - print("\n" + "="*60) - print("SUMMARY") - print("="*60) - - all_halo = [v for r in results_list for v in r["halo_halo"]["values"]] - all_non_halo = [v for r in results_list for v in r["non_halo_non_halo"]["values"]] - all_cross = [v for r in results_list for v in r["cross_group"]["values"]] - - print(f"Halo-Halo redundancy: mean={np.mean(all_halo):.4f}, std={np.std(all_halo):.4f}") - print(f"Non-halo redundancy: mean={np.mean(all_non_halo):.4f}, std={np.std(all_non_halo):.4f}") - print(f"Cross-group redundancy: mean={np.mean(all_cross):.4f}, std={np.std(all_cross):.4f}") - - # Interpretation - print("\nINTERPRETATION:") - if np.mean(all_halo) > 
np.mean(all_non_halo) * 1.2: - print("✓ Halo neurons are MORE redundant than non-halo → Current approach valid") - elif np.mean(all_halo) < np.mean(all_non_halo) * 0.8: - print("✗ Non-halo neurons are MORE redundant → Consider revising halo definition") - else: - print("≈ Similar redundancy in both groups → May need different selection criteria") - - if np.mean(all_cross) < np.mean(all_halo) * 0.8: - print("✓ Cross-group correlation LOW → Groups carry different information") - else: - print("≈ Cross-group correlation similar → Information not well separated by halo") - - -if __name__ == "__main__": - main() diff --git a/scripts/collect_paper_artifacts.py b/scripts/collect_paper_artifacts.py deleted file mode 100644 index a96bb643..00000000 --- a/scripts/collect_paper_artifacts.py +++ /dev/null @@ -1,604 +0,0 @@ -#!/usr/bin/env python3 -""" -Collect SCAR paper artifacts (tables + key figures) from experiment job directories. - -This script is meant to be run AFTER the SLURM batch suite finishes. - -What it does: -- Finds the latest job directory for each expected experiment name under --results-base -- Loads the corresponding results_*.json -- Generates LaTeX table snippets into drafts/LLM_prune/paper_artifacts/tables/ -- Copies a small set of "paper figure" images into drafts/LLM_prune/ as the placeholder_*.png files - used by `drafts/LLM_prune/scar_paper_icml_v5.tex` (so the paper auto-fills without manual edits). 
- -Example: - python scripts/collect_paper_artifacts.py \ - --results-base /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM \ - --draft-dir /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment/drafts/LLM_prune -""" - -from __future__ import annotations - -import argparse -import json -import shutil -from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Tuple - - -# ----------------------------- -# Helpers -# ----------------------------- - - -def _latest_results_json(job_dir: Path) -> Optional[Path]: - candidates: List[Path] = [] - if (job_dir / "results").exists(): - candidates.extend(sorted((job_dir / "results").glob("results_*.json"))) - candidates.extend(sorted(job_dir.glob("results_*.json"))) - if not candidates: - return None - candidates.sort(key=lambda p: p.stat().st_mtime, reverse=True) - return candidates[0] - - -def find_latest_job_dir(results_base: Path, experiment_name: str) -> Optional[Path]: - """ - Job directories are created as: - {experiment_name}_{timestamp}_{job_id}/ - """ - if not results_base.exists(): - return None - candidates = [p for p in results_base.iterdir() if p.is_dir() and p.name.startswith(f"{experiment_name}_")] - candidates.sort(key=lambda p: p.stat().st_mtime, reverse=True) - for job_dir in candidates: - if _latest_results_json(job_dir) is not None: - return job_dir - return None - - -def load_results(job_dir: Path) -> Dict[str, Any]: - path = _latest_results_json(job_dir) - if path is None: - raise FileNotFoundError(f"No results_*.json found in {job_dir} or {job_dir/'results'}") - return json.loads(path.read_text()) - - -def _get_pruned_entry( - pruning_results: Dict[str, Any], - metric: str, - mode: str, - sparsity: float, -) -> Optional[Dict[str, Any]]: - for v in pruning_results.values(): - if not isinstance(v, dict): - continue - if v.get("metric") == metric and v.get("mode") == mode and float(v.get("sparsity", -1)) == float(sparsity): - return v - return None - - -def _pick_mode( - 
pruning_results: Dict[str, Any], - metric: str, - sparsity: float, - mode: str, -) -> Tuple[str, Optional[Dict[str, Any]]]: - """ - mode: - - "low" or "high": choose that mode - - "best": choose the better of low/high by perplexity - """ - if mode in {"low", "high"}: - return mode, _get_pruned_entry(pruning_results, metric=metric, mode=mode, sparsity=sparsity) - - low = _get_pruned_entry(pruning_results, metric=metric, mode="low", sparsity=sparsity) - high = _get_pruned_entry(pruning_results, metric=metric, mode="high", sparsity=sparsity) - - def ppl(x: Optional[Dict[str, Any]]) -> float: - if not x: - return float("inf") - v = x.get("perplexity") - return float(v) if v is not None else float("inf") - - if ppl(low) <= ppl(high): - return "low", low - return "high", high - - -def _fmt(x: Any, digits: int = 1) -> str: - if x is None: - return "--" - try: - xf = float(x) - except Exception: - return "--" - if xf != xf: # NaN - return "--" - return f"{xf:.{digits}f}" - - -def write_text(path: Path, content: str) -> None: - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(content) - - -def latex_tabular(rows: List[List[str]], align_spec: str) -> str: - lines = [] - lines.append("% Auto-generated by scripts/collect_paper_artifacts.py") - lines.append(f"\\begin{{tabular}}{{{align_spec}}}") - lines.append("\\toprule") - for i, r in enumerate(rows): - lines.append(" & ".join(r) + " \\\\") - if i == 0: - lines.append("\\midrule") - lines.append("\\bottomrule") - lines.append("\\end{tabular}") - lines.append("") - return "\n".join(lines) - - -def safe_copy(src: Path, dst: Path) -> bool: - if not src.exists(): - return False - dst.parent.mkdir(parents=True, exist_ok=True) - shutil.copy2(src, dst) - return True - - -# ----------------------------- -# Table generators -# ----------------------------- - - -def make_table_main_50( - llama3_main: Dict[str, Any], - llama3_noprotect: Optional[Dict[str, Any]], - llama3_protect_baselines: Optional[Dict[str, Any]], - 
out_path: Path, - sparsity: float = 0.5, - mode: str = "best", -) -> None: - pruning_main = llama3_main.get("pruning_results") or {} - evaluation_main = llama3_main.get("evaluation") or {} - - # Columns used in the draft Table 1 - cols: List[Tuple[str, str, int]] = [ - ("PPL$\\downarrow$", "perplexity", 1), - ("MMLU", "accuracy_mmlu", 1), - ("Hella", "accuracy_hellaswag", 1), - ("PIQA", "accuracy_piqa", 1), - ("BoolQ", "accuracy_boolq", 1), - ] - - rows: List[List[str]] = [] - rows.append(["Method"] + [h for h, _, _ in cols]) - - # Unpruned row - baseline_ppl = evaluation_main.get("baseline_perplexity") - baseline_metrics = evaluation_main.get("baseline_metrics") or {} - unpruned_vals = {"perplexity": baseline_ppl, **(baseline_metrics if isinstance(baseline_metrics, dict) else {})} - rows.append(["Unpruned"] + [_fmt(unpruned_vals.get(k), d) for _, k, d in cols]) - - # Main rows from the "main" run - methods_main: List[Tuple[str, str]] = [ - ("Magnitude (channel)", "weight_magnitude"), - ("Wanda (channel)", "wanda"), - ("SparseGPT (channel)", "sparsegpt"), - ("Act. 
L2", "activation_l2_norm"), - ("RQ", "rayleigh_quotient"), - ("SCAR-LP", "scar_loss_proxy"), - ("SCAR-Prot", "supernode_protection_score"), - ("SCAR-Conn", "supernode_connectivity_score"), - ] - - for label, metric in methods_main: - picked_mode, entry = _pick_mode(pruning_main, metric=metric, sparsity=sparsity, mode=mode) - if entry is None: - rows.append([label] + ["--"] * len(cols)) - continue - rows.append([label] + [_fmt(entry.get(k), d) for _, k, d in cols]) - - # Optional extra rows (if those experiments are present) - if llama3_noprotect is not None: - pr = llama3_noprotect.get("pruning_results") or {} - picked_mode, entry = _pick_mode(pr, metric="scar_loss_proxy", sparsity=sparsity, mode="low") - if entry is not None: - rows.append(["LP-no-protect"] + [_fmt(entry.get(k), d) for _, k, d in cols]) - - if llama3_protect_baselines is not None: - pr = llama3_protect_baselines.get("pruning_results") or {} - for label, metric in [("Protect+Magnitude", "weight_magnitude"), ("Protect+Wanda", "wanda")]: - picked_mode, entry = _pick_mode(pr, metric=metric, sparsity=sparsity, mode="low") - if entry is not None: - rows.append([label] + [_fmt(entry.get(k), d) for _, k, d in cols]) - - content = latex_tabular(rows, align_spec="@{}l" + "c" * len(cols) + "@{}") - write_text(out_path, content) - - -def make_table_sparsity_tradeoff( - llama3_main: Dict[str, Any], - out_path: Path, - sparsities: List[float], - mode: str = "low", -) -> None: - pruning = llama3_main.get("pruning_results") or {} - baseline_ppl = (llama3_main.get("evaluation") or {}).get("baseline_perplexity") - - methods: List[Tuple[str, str]] = [ - ("Wanda (channel)", "wanda"), - ("SparseGPT (channel)", "sparsegpt"), - ("SCAR-LP", "scar_loss_proxy"), - ("SCAR-Prot", "supernode_protection_score"), - ("SCAR-Conn", "supernode_connectivity_score"), - ] - - header = ["Method"] + [f"{int(100*s)}\\%" for s in sparsities] - rows: List[List[str]] = [header] - if baseline_ppl is not None: - rows.append(["Baseline"] + 
[_fmt(baseline_ppl, 1)] * len(sparsities)) - - for label, metric in methods: - row = [label] - for s in sparsities: - picked_mode, entry = _pick_mode(pruning, metric=metric, sparsity=s, mode=mode) - row.append(_fmt(entry.get("perplexity") if entry else None, 1)) - rows.append(row) - - content = latex_tabular(rows, align_spec="@{}l" + "c" * len(sparsities) + "@{}") - write_text(out_path, content) - - -def make_table_generalization( - model_results: List[Tuple[str, Dict[str, Any]]], - out_path: Path, - sparsity: float = 0.5, - mode: str = "best", -) -> None: - rows: List[List[str]] = [] - rows.append(["Model", "Method", "PPL$\\downarrow$", "MMLU", "Avg.$\\uparrow$"]) - - for model_label, res in model_results: - pr = res.get("pruning_results") or {} - # Wanda baseline - _, wanda = _pick_mode(pr, metric="wanda", sparsity=sparsity, mode=mode) - _, scar = _pick_mode(pr, metric="supernode_connectivity_score", sparsity=sparsity, mode=mode) - - def avg_acc(entry: Optional[Dict[str, Any]]) -> Optional[float]: - if not entry: - return None - accs = [float(v) for k, v in entry.items() if isinstance(k, str) and k.startswith("accuracy_") and v is not None] - return sum(accs) / len(accs) if accs else None - - for method_label, entry in [("Wanda (channel)", wanda), ("SCAR-Conn", scar)]: - rows.append( - [ - model_label, - method_label, - _fmt(entry.get("perplexity") if entry else None, 1), - _fmt(entry.get("accuracy_mmlu") if entry else None, 1), - _fmt(avg_acc(entry), 1), - ] - ) - - content = latex_tabular(rows, align_spec="@{}llccc@{}") - write_text(out_path, content) - - -def make_table_halo_redundancy( - llama3_main: Dict[str, Any], - out_path: Path, -) -> None: - halo = llama3_main.get("halo_analysis") or {} - agg = halo.get("aggregate") or {} - - # The current analysis uses |corr|; we expose it directly for the draft table. 
- rows: List[List[str]] = [] - rows.append(["Group Pair", "Mean", "Std"]) - - def get(group: str, key: str) -> Any: - return (agg.get(group) or {}).get(key) - - rows.append(["Within-Halo", _fmt(get("halo_halo", "mean"), 3), _fmt(get("halo_halo", "std"), 3)]) - rows.append(["Within-Non-Halo", _fmt(get("non_halo", "mean"), 3), _fmt(get("non_halo", "std"), 3)]) - rows.append(["Cross (Halo $\\leftrightarrow$ Non-Halo)", _fmt(get("cross", "mean"), 3), _fmt(get("cross", "std"), 3)]) - - content = latex_tabular(rows, align_spec="@{}lcc@{}") - write_text(out_path, content) - - -def make_table_full_benchmarks_50( - llama3_main: Dict[str, Any], - out_path: Path, - sparsity: float = 0.5, - mode: str = "best", -) -> None: - """ - Appendix table: a wider benchmark set at a single sparsity. - """ - pruning = llama3_main.get("pruning_results") or {} - evaluation = llama3_main.get("evaluation") or {} - - cols: List[Tuple[str, str, int]] = [ - ("PPL$\\downarrow$", "perplexity", 1), - ("MMLU", "accuracy_mmlu", 1), - ("Hella", "accuracy_hellaswag", 1), - ("PIQA", "accuracy_piqa", 1), - ("BoolQ", "accuracy_boolq", 1), - ("WinoG", "accuracy_winogrande", 1), - ("ARC-E", "accuracy_arc_easy", 1), - ("ARC-C", "accuracy_arc_challenge", 1), - ("OBQA", "accuracy_openbookqa", 1), - ] - - rows: List[List[str]] = [] - rows.append(["Method"] + [h for h, _, _ in cols]) - - baseline_ppl = evaluation.get("baseline_perplexity") - baseline_metrics = evaluation.get("baseline_metrics") or {} - unpruned_vals = {"perplexity": baseline_ppl, **(baseline_metrics if isinstance(baseline_metrics, dict) else {})} - rows.append(["Unpruned"] + [_fmt(unpruned_vals.get(k), d) for _, k, d in cols]) - - methods: List[Tuple[str, str]] = [ - ("Random", "random"), - ("Magnitude (channel)", "weight_magnitude"), - ("Wanda (channel)", "wanda"), - ("SparseGPT (channel)", "sparsegpt"), - ("Act. 
L2", "activation_l2_norm"), - ("RQ", "rayleigh_quotient"), - ("Gaussian MI (analytic)", "gaussian_mi_analytic"), - ("SCAR-LP", "scar_loss_proxy"), - ("SCAR-Prot", "supernode_protection_score"), - ("SCAR-Conn", "supernode_connectivity_score"), - ] - - for label, metric in methods: - picked_mode, entry = _pick_mode(pruning, metric=metric, sparsity=sparsity, mode=mode if metric != "random" else "low") - if entry is None: - rows.append([label] + ["--"] * len(cols)) - continue - rows.append([label] + [_fmt(entry.get(k), d) for _, k, d in cols]) - - content = latex_tabular(rows, align_spec="@{}l" + "c" * len(cols) + "@{}") - write_text(out_path, content) - - -def make_table_supernode_control( - llama3_main: Dict[str, Any], - llama3_noprotect: Optional[Dict[str, Any]], - out_path: Path, - sparsity: float = 0.5, -) -> None: - rows: List[List[str]] = [] - rows.append(["Strategy", "PPL", "Relative"]) - - pr_main = llama3_main.get("pruning_results") or {} - _, scar = _pick_mode(pr_main, metric="supernode_connectivity_score", sparsity=sparsity, mode="low") - _, wanda = _pick_mode(pr_main, metric="wanda", sparsity=sparsity, mode="low") - - remove_core_early_ppl = None - if llama3_noprotect is not None: - pr = llama3_noprotect.get("pruning_results") or {} - high = _get_pruned_entry(pr, metric="scar_loss_proxy", mode="high", sparsity=sparsity) - if high is not None: - remove_core_early_ppl = high.get("perplexity") - - rows.append(["Remove supernodes early", _fmt(remove_core_early_ppl, 1), "$\\gg$ worse"]) - rows.append(["Wanda (channel)", _fmt(wanda.get("perplexity") if wanda else None, 1), "worse"]) - rows.append(["SCAR-Conn (protect supernodes)", _fmt(scar.get("perplexity") if scar else None, 1), "best"]) - - content = latex_tabular(rows, align_spec="@{}lcc@{}") - write_text(out_path, content) - - -def make_table_calibration_sensitivity( - results_base: Path, - prefix: str, - out_path: Path, - sparsity: float = 0.5, -) -> None: - """ - Collect all runs whose experiment name 
starts with `prefix` and build a calibration sensitivity table. - Intended to work with the job-array naming convention from: - slurm_jobs/prune_llm/paper/run_llama3_8b_calibration_array.sh - """ - if not results_base.exists(): - return - - job_dirs = [p for p in results_base.iterdir() if p.is_dir() and p.name.startswith(f"{prefix}_")] - if not job_dirs: - return - - # Keep the most recent run per (dataset_name, n_samples) - best: Dict[Tuple[str, int], Tuple[float, Dict[str, Any]]] = {} - - for job_dir in sorted(job_dirs, key=lambda p: p.stat().st_mtime, reverse=True): - try: - res = load_results(job_dir) - except Exception: - continue - - cfg = res.get("config") or {} - dataset_name = str(cfg.get("dataset_name", "unknown")) - n = cfg.get("alignment_data_num_samples") - try: - n_int = int(n) - except Exception: - continue - - pr = res.get("pruning_results") or {} - entry = _get_pruned_entry(pr, metric="supernode_connectivity_score", mode="low", sparsity=sparsity) - ppl = None if entry is None else entry.get("perplexity") - try: - ppl_f = float(ppl) if ppl is not None else float("inf") - except Exception: - ppl_f = float("inf") - - key = (dataset_name, n_int) - if key not in best: - best[key] = (job_dir.stat().st_mtime, {"dataset": dataset_name, "n": n_int, "ppl": ppl_f}) - - if not best: - return - - pretty_name = { - "wikitext": "WikiText-2", - "c4": "C4", - "mixed_wikitext_c4": "Mixed (Wiki + C4)", - "mixed_wiki_c4": "Mixed (Wiki + C4)", - "mixed": "Mixed (Wiki + C4)", - } - - rows: List[List[str]] = [] - rows.append(["Dataset", "\\# seqs", "PPL"]) - - # Sort rows for readability: wikitext first, then c4, then others; within by n desc. 
- def sort_key(item: Tuple[Tuple[str, int], Tuple[float, Dict[str, Any]]]) -> Tuple[int, str, int]: - (ds, n), _ = item - group = 0 if ds == "wikitext" else 1 if ds == "c4" else 2 - return (group, ds, -n) - - for (_, _), (_, rec) in sorted(best.items(), key=sort_key): - ds = str(rec["dataset"]) - rows.append([pretty_name.get(ds, ds), str(rec["n"]), _fmt(rec["ppl"], 1)]) - - content = latex_tabular(rows, align_spec="@{}llc@{}") - write_text(out_path, content) - - -# ----------------------------- -# Main -# ----------------------------- - - -def main() -> None: - ap = argparse.ArgumentParser() - ap.add_argument("--results-base", required=True, type=str, help="Base directory containing job dirs (OUTPUT_BASE)") - ap.add_argument( - "--draft-dir", - required=True, - type=str, - help="Path to drafts/LLM_prune (where placeholder_*.png live)", - ) - ap.add_argument( - "--artifacts-dir", - type=str, - default=None, - help="Where to write paper_artifacts/ (default: {draft-dir}/paper_artifacts)", - ) - ap.add_argument("--sparsity", type=float, default=0.5, help="Primary sparsity level for tables (default: 0.5)") - ap.add_argument("--mode", type=str, default="best", choices=["low", "high", "best"], help="Mode selection for tables") - - # Canonical experiment names (these are the ExperimentConfig.name values) - ap.add_argument("--llama3-main-name", type=str, default="llama3_8b_paper_results") - ap.add_argument("--mistral-main-name", type=str, default="mistral_7b_paper_results") - ap.add_argument("--llama2-main-name", type=str, default="llama2_7b_paper_results") - ap.add_argument("--qwen2-main-name", type=str, default="qwen2_7b_paper_results") - ap.add_argument("--llama3-noprotect-name", type=str, default="llama3_8b_paper_results_noprotect") - ap.add_argument("--llama3-protect-baselines-name", type=str, default="llama3_8b_paper_results_protect_baselines") - ap.add_argument("--llama3-calib-prefix", type=str, default="llama3_8b_paper_results_calib") - - args = ap.parse_args() - - 
results_base = Path(args.results_base) - draft_dir = Path(args.draft_dir) - artifacts_dir = Path(args.artifacts_dir) if args.artifacts_dir else (draft_dir / "paper_artifacts") - tables_dir = artifacts_dir / "tables" - - # Locate + load main runs - llama3_main_dir = find_latest_job_dir(results_base, args.llama3_main_name) - if llama3_main_dir is None: - raise FileNotFoundError(f"Could not find a run for '{args.llama3_main_name}' under {results_base}") - llama3_main = load_results(llama3_main_dir) - - mistral_dir = find_latest_job_dir(results_base, args.mistral_main_name) - llama2_dir = find_latest_job_dir(results_base, args.llama2_main_name) - qwen2_dir = find_latest_job_dir(results_base, args.qwen2_main_name) - - mistral = load_results(mistral_dir) if mistral_dir else None - llama2 = load_results(llama2_dir) if llama2_dir else None - qwen2 = load_results(qwen2_dir) if qwen2_dir else None - - # Optional control runs - llama3_noprotect_dir = find_latest_job_dir(results_base, args.llama3_noprotect_name) - llama3_protect_baselines_dir = find_latest_job_dir(results_base, args.llama3_protect_baselines_name) - llama3_noprotect = load_results(llama3_noprotect_dir) if llama3_noprotect_dir else None - llama3_protect_baselines = load_results(llama3_protect_baselines_dir) if llama3_protect_baselines_dir else None - - # Generate tables - make_table_main_50( - llama3_main=llama3_main, - llama3_noprotect=llama3_noprotect, - llama3_protect_baselines=llama3_protect_baselines, - out_path=tables_dir / "table_main_50.tex", - sparsity=float(args.sparsity), - mode=args.mode, - ) - make_table_sparsity_tradeoff( - llama3_main=llama3_main, - out_path=tables_dir / "table_sparsity_ppl.tex", - sparsities=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - mode="low", - ) - make_table_supernode_control( - llama3_main=llama3_main, - llama3_noprotect=llama3_noprotect, - out_path=tables_dir / "table_supernode_control.tex", - sparsity=float(args.sparsity), - ) - make_table_halo_redundancy( - 
llama3_main=llama3_main, - out_path=tables_dir / "table_halo_redundancy.tex", - ) - make_table_full_benchmarks_50( - llama3_main=llama3_main, - out_path=tables_dir / "table_full_benchmarks_50.tex", - sparsity=float(args.sparsity), - mode=args.mode, - ) - make_table_calibration_sensitivity( - results_base=results_base, - prefix=args.llama3_calib_prefix, - out_path=tables_dir / "table_calibration_sensitivity.tex", - sparsity=float(args.sparsity), - ) - - model_rows: List[Tuple[str, Dict[str, Any]]] = [] - if mistral is not None: - model_rows.append(("Mistral-7B", mistral)) - if llama2 is not None: - model_rows.append(("Llama-2-7B", llama2)) - if qwen2 is not None: - model_rows.append(("Qwen2-7B", qwen2)) - if model_rows: - make_table_generalization(model_rows, out_path=tables_dir / "table_generalization_50.tex", sparsity=float(args.sparsity)) - - # Copy key figures into the draft placeholders (no LaTeX edits needed) - figs = Path(llama3_main_dir) / "figures" - mapping = [ - (figs / "pruning" / "pruning_comparison.png", draft_dir / "placeholder_sparsity_curves.png"), - (figs / "histograms" / "histogram_scar_loss_proxy.png", draft_dir / "placeholder_supernode_distribution.png"), - (figs / "supernode_summary" / "halo_nonhalo_metrics_by_layer.png", draft_dir / "placeholder_halo_redundancy.png"), - (figs / "supernode_summary" / "supernode_outlier_zscores.png", draft_dir / "placeholder_supernode_analysis.png"), - ] - - copied = 0 - for src, dst in mapping: - if safe_copy(src, dst): - copied += 1 - - print("\n=== Paper artifacts collected ===") - print(f"Main run: {llama3_main_dir}") - print(f"Artifacts dir: {artifacts_dir}") - print(f"Tables written: {tables_dir}") - print(f"Placeholder figures copied into draft: {copied}/{len(mapping)}") - if llama3_noprotect_dir: - print(f"Found noprotect control: {llama3_noprotect_dir}") - if llama3_protect_baselines_dir: - print(f"Found protect-baselines: {llama3_protect_baselines_dir}") - - -if __name__ == "__main__": - main() - diff 
--git a/scripts/generate_scar_paper_tables.py b/scripts/generate_scar_paper_tables.py deleted file mode 100644 index d388a92b..00000000 --- a/scripts/generate_scar_paper_tables.py +++ /dev/null @@ -1,152 +0,0 @@ -#!/usr/bin/env python3 -""" -Generate LaTeX tables for the SCAR ICML draft from a saved experiment results JSON. - -Why: -- Avoid manual copy/paste drift between `results_*.json` and `drafts/LLM_prune/scar_paper_icml_v5.tex`. -- Make it easy to update tables after rerunning experiments. - -Usage: - python scripts/generate_scar_paper_tables.py \ - --results /abs/path/to/results_YYYYMMDD_HHMMSS.json \ - --sparsity 0.5 \ - --best-mode -""" - -from __future__ import annotations - -import argparse -import json -from pathlib import Path -from typing import Any, Dict, Optional, Tuple - - -METHODS = [ - ("Magnitude (channel)", "weight_magnitude"), - ("Wanda (channel)", "wanda"), - ("SparseGPT (channel)", "sparsegpt"), - ("Act. L2", "activation_l2_norm"), - ("RQ", "rayleigh_quotient"), - ("SCAR-LP", "scar_loss_proxy"), - ("SCAR-Prot", "supernode_protection_score"), - ("SCAR-Conn", "supernode_connectivity_score"), -] - - -def _get_pruned_entry( - pruning_results: Dict[str, Any], - metric: str, - mode: str, - sparsity: float, -) -> Optional[Dict[str, Any]]: - for v in pruning_results.values(): - if not isinstance(v, dict): - continue - if v.get("metric") == metric and v.get("mode") == mode and v.get("sparsity") == sparsity: - return v - return None - - -def _pick_mode( - pruning_results: Dict[str, Any], - metric: str, - sparsity: float, - best_mode: bool, -) -> Tuple[str, Optional[Dict[str, Any]]]: - if not best_mode: - entry = _get_pruned_entry(pruning_results, metric=metric, mode="low", sparsity=sparsity) - return "low", entry - - low = _get_pruned_entry(pruning_results, metric=metric, mode="low", sparsity=sparsity) - high = _get_pruned_entry(pruning_results, metric=metric, mode="high", sparsity=sparsity) - - def ppl(x: Optional[Dict[str, Any]]) -> float: - if not 
x: - return float("inf") - v = x.get("perplexity") - return float(v) if v is not None else float("inf") - - if ppl(low) <= ppl(high): - return "low", low - return "high", high - - -def _fmt(x: Any, digits: int = 1) -> str: - if x is None: - return "--" - try: - xf = float(x) - except Exception: - return "--" - if xf != xf: # NaN - return "--" - return f"{xf:.{digits}f}" - - -def main() -> None: - ap = argparse.ArgumentParser() - ap.add_argument("--results", required=True, type=str, help="Path to results_*.json") - ap.add_argument("--sparsity", type=float, default=0.5, help="Target sparsity level (default: 0.5)") - ap.add_argument( - "--best-mode", - action="store_true", - help="Pick the better of low/high for each method by perplexity.", - ) - args = ap.parse_args() - - path = Path(args.results) - obj = json.loads(path.read_text()) - - pruning_results = obj.get("pruning_results") or {} - evaluation = obj.get("evaluation") or {} - - baseline_metrics = evaluation.get("baseline_metrics") or {} - baseline_ppl = evaluation.get("baseline_perplexity") - - # Metrics used in the draft's Table 1 - cols = [ - ("PPL$\\downarrow$", "perplexity", 1), - ("MMLU", "accuracy_mmlu", 1), - ("Hella", "accuracy_hellaswag", 1), - ("PIQA", "accuracy_piqa", 1), - ("BoolQ", "accuracy_boolq", 1), - ] - - print("% Auto-generated by scripts/generate_scar_paper_tables.py") - print("\\begin{tabular}{@{}l" + "c" * len(cols) + "@{}}") - print("\\toprule") - print("Method & " + " & ".join(h for h, _, _ in cols) + " \\\\") - print("\\midrule") - - # Unpruned row - unpruned_vals = {"perplexity": baseline_ppl, **(baseline_metrics if isinstance(baseline_metrics, dict) else {})} - print( - "Unpruned & " - + " & ".join(_fmt(unpruned_vals.get(k), d) for _, k, d in cols) - + " \\\\" - ) - print("\\midrule") - - for label, metric in METHODS: - mode, entry = _pick_mode(pruning_results, metric=metric, sparsity=args.sparsity, best_mode=args.best_mode) - if entry is None: - row = [label] + ["--"] * len(cols) - 
print(row[0] + " & " + " & ".join(row[1:]) + " \\\\") - continue - - vals = {k: entry.get(k) for _, k, _ in cols} - # Be robust: some evaluators store e.g. 'accuracy_mmlu' etc; keep '--' if missing. - print( - f"{label} & " - + " & ".join(_fmt(vals.get(k), d) for _, k, d in cols) - + f" % ({metric}, {mode})" - + " \\\\" - ) - - print("\\bottomrule") - print("\\end{tabular}") - - -if __name__ == "__main__": - main() - diff --git a/scripts/run_experiment.py b/scripts/run_experiment.py index 8b2afe5e..5fa8a8ea 100644 --- a/scripts/run_experiment.py +++ b/scripts/run_experiment.py @@ -138,7 +138,26 @@ def _get_nested(obj, key, default): dataset_name=getattr(config, "dataset_name", dataset_cfg.get("name", "cifar10") if isinstance(dataset_cfg, dict) else "cifar10"), n_calibration=getattr(config, "n_calibration", metrics_cfg.get("n_calibration_samples", 5000) if isinstance(metrics_cfg, dict) else 5000), n_clusters=getattr(config, "n_clusters", clustering_cfg.get("n_clusters", 4) if isinstance(clustering_cfg, dict) else 4), + activation_samples=getattr( + config, + "activation_samples", + metrics_cfg.get("activation_samples", "flatten_spatial") if isinstance(metrics_cfg, dict) else "flatten_spatial", + ), + spatial_samples_per_image=int( + getattr( + config, + "spatial_samples_per_image", + metrics_cfg.get("spatial_samples_per_image", 16) if isinstance(metrics_cfg, dict) else 16, + ) + ), synergy_target=getattr(config, "synergy_target", metrics_cfg.get("synergy_target", "logit_margin") if isinstance(metrics_cfg, dict) else "logit_margin"), + synergy_candidate_pool=int( + getattr( + config, + "synergy_candidate_pool", + metrics_cfg.get("synergy_candidate_pool", 50) if isinstance(metrics_cfg, dict) else 50, + ) + ), synergy_pairs=getattr(config, "synergy_pairs", metrics_cfg.get("synergy_num_pairs", 10) if isinstance(metrics_cfg, dict) else 10), halo_percentile=getattr(config, "halo_percentile", halo_cfg.get("percentile", 90.0) if isinstance(halo_cfg, dict) else 90.0), 
pruning_ratios=pruning_ratios, @@ -153,24 +172,33 @@ def _get_nested(obj, key, default): # Load model model_name = cluster_config.model_name.lower() - num_classes = 10 if "cifar" in cluster_config.dataset_name.lower() else 1000 + dataset_name = cluster_config.dataset_name.lower() + # Prefer explicit num_classes from config.model.num_classes when present + num_classes = ( + int(model_cfg.get("num_classes")) if isinstance(model_cfg, dict) and model_cfg.get("num_classes") is not None + else (10 if "cifar10" in dataset_name else 100 if "cifar100" in dataset_name else 100 if "imagenet100" in dataset_name else 1000) + ) # Check for pre-trained checkpoint model_cfg = _get_nested(config, "model", {}) checkpoint_path = model_cfg.get("checkpoint", None) if isinstance(model_cfg, dict) else None checkpoint_path = checkpoint_path or getattr(config, "model_checkpoint", None) + pretrained = bool(model_cfg.get("pretrained", True)) if isinstance(model_cfg, dict) else True + weights_name = model_cfg.get("weights", None) if isinstance(model_cfg, dict) else None + weights_arg = weights_name if pretrained else None + if "resnet18" in model_name: - model = torchvision.models.resnet18(weights='IMAGENET1K_V1') + model = torchvision.models.resnet18(weights=weights_arg or 'IMAGENET1K_V1') model.fc = torch.nn.Linear(model.fc.in_features, num_classes) elif "resnet50" in model_name: - model = torchvision.models.resnet50(weights='IMAGENET1K_V1') + model = torchvision.models.resnet50(weights=weights_arg or 'IMAGENET1K_V1') model.fc = torch.nn.Linear(model.fc.in_features, num_classes) elif "vgg16" in model_name: - model = torchvision.models.vgg16_bn(weights='IMAGENET1K_V1') + model = torchvision.models.vgg16_bn(weights=weights_arg or 'IMAGENET1K_V1') model.classifier[-1] = torch.nn.Linear(model.classifier[-1].in_features, num_classes) elif "mobilenet" in model_name: - model = torchvision.models.mobilenet_v2(weights='IMAGENET1K_V1') + model = torchvision.models.mobilenet_v2(weights=weights_arg or 
'IMAGENET1K_V1') model.classifier[-1] = torch.nn.Linear(model.classifier[-1].in_features, num_classes) else: raise ValueError(f"Unknown model: {model_name}") @@ -188,27 +216,56 @@ def _get_nested(obj, key, default): needs_training = True # Load dataset - dataset_name = cluster_config.dataset_name.lower() if "cifar10" in dataset_name: transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), ]) - train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) - test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform) + root = dataset_cfg.get("root", "./data") if isinstance(dataset_cfg, dict) else "./data" + train_dataset = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=transform) + test_dataset = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=transform) elif "cifar100" in dataset_name: transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), ]) - train_dataset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform) - test_dataset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform) + root = dataset_cfg.get("root", "./data") if isinstance(dataset_cfg, dict) else "./data" + train_dataset = torchvision.datasets.CIFAR100(root=root, train=True, download=True, transform=transform) + test_dataset = torchvision.datasets.CIFAR100(root=root, train=False, download=True, transform=transform) + elif "imagenet100" in dataset_name: + # Expected folder structure: {root}/train/* and {root}/val/* (ImageFolder) + root = dataset_cfg.get("root", "./data/imagenet100") if isinstance(dataset_cfg, dict) else "./data/imagenet100" + train_dir = Path(root) / "train" + val_dir = Path(root) / "val" + if not 
train_dir.exists() or not val_dir.exists(): + raise FileNotFoundError( + f"ImageNet-100 not found. Expected ImageFolder dirs at: {train_dir} and {val_dir}" + ) + + imagenet_mean = (0.485, 0.456, 0.406) + imagenet_std = (0.229, 0.224, 0.225) + image_size = int(dataset_cfg.get("image_size", 224)) if isinstance(dataset_cfg, dict) else 224 + train_transform = transforms.Compose([ + transforms.RandomResizedCrop(image_size), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(imagenet_mean, imagenet_std), + ]) + val_transform = transforms.Compose([ + transforms.Resize(int(image_size * 256 / 224)), + transforms.CenterCrop(image_size), + transforms.ToTensor(), + transforms.Normalize(imagenet_mean, imagenet_std), + ]) + train_dataset = torchvision.datasets.ImageFolder(root=str(train_dir), transform=train_transform) + test_dataset = torchvision.datasets.ImageFolder(root=str(val_dir), transform=val_transform) else: raise ValueError(f"Unknown dataset: {dataset_name}") - batch_size = getattr(config, "batch_size", 128) - train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4) - test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size * 2, shuffle=False, num_workers=4) + batch_size = int(dataset_cfg.get("batch_size", getattr(config, "batch_size", 128))) if isinstance(dataset_cfg, dict) else int(getattr(config, "batch_size", 128)) + num_workers = int(dataset_cfg.get("num_workers", 4)) if isinstance(dataset_cfg, dict) else 4 + train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) + test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size * 2, shuffle=False, num_workers=num_workers) # Fine-tune the model on target dataset before experiments # This is necessary because we replaced the classifier head with random weights @@ -229,12 +286,13 @@ def _get_nested(obj, key, default): 
extra_cfg.get("pretrain_lr") if isinstance(extra_cfg, dict) else None ) or getattr(config, "pretrain_lr", 0.001) - model = _finetune_model_for_dataset( - model, train_loader, test_loader, - device=cluster_config.device, - epochs=pretrain_epochs, - lr=pretrain_lr, - ) + if needs_training: + model = _finetune_model_for_dataset( + model, train_loader, test_loader, + device=cluster_config.device, + epochs=pretrain_epochs, + lr=pretrain_lr, + ) # Save the trained model checkpoint output_dir = Path(cluster_config.output_dir) diff --git a/slurm-54745862.out b/slurm-54745862.out new file mode 100644 index 00000000..b8070148 --- /dev/null +++ b/slurm-54745862.out @@ -0,0 +1,37 @@ +============================================== +Submitting SCAR Paper Experiments +============================================== + +Output directory: /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM + +Submitting LLaMA-3.1-8B (main results)... + Job ID: 54745874 +Submitting Mistral-7B (generalization)... + Job ID: 54745875 +Submitting LLaMA-2-7B (generalization)... + Job ID: 54745878 +Submitting Qwen2-7B (generalization)... + Job ID: 54745879 + +============================================== +All jobs submitted! 
+============================================== + +Job IDs: 54745874, 54745875, 54745878, 54745879 + +Monitor with: + squeue -u $USER + +View SLURM logs: + tail -f logs/paper_llama3_8b_54745874.out + tail -f logs/paper_mistral_7b_54745875.out + tail -f logs/paper_llama2_7b_54745878.out + tail -f logs/paper_qwen2_7b_54745879.out + +Expected runtime: ~6-8 hours per job + +Results will be in: + /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/llama3_8b_paper_results_*_54745874/ + /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/mistral_7b_paper_results_*_54745875/ + /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/llama2_7b_paper_results_*_54745878/ + /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/qwen2_7b_paper_results_*_54745879/ diff --git a/slurm_jobs/prune_llm/paper/README.md b/slurm_jobs/prune_llm/paper/README.md deleted file mode 100644 index e9a9d89e..00000000 --- a/slurm_jobs/prune_llm/paper/README.md +++ /dev/null @@ -1,43 +0,0 @@ -### SCAR paper experiment suite (batch + collection) - -This folder contains **SLURM batch scripts** that run a complete ICML-style paper suite: - -- **Main results + generalization** (4 models) -- **Key controls / ablations** on Llama-3.1-8B: - - **LP-no-protect** + **remove-supernodes-early** (mode=high) control - - **Protect+Wanda** and **Protect+Magnitude** (baseline + supernode protection) - - **Positive-only redundancy** ablation (anti-correlation does NOT count as redundancy) - - **Calibration sensitivity** sweep (dataset + sample-count) - -All jobs write to a single `OUTPUT_BASE` using the unified job directory structure: - -`{OUTPUT_BASE}/{experiment_name}_{timestamp}_{job_id}/` - -### How to run - -- **Set output base** (or let scripts use the default in each file): - -```bash -export OUTPUT_BASE="/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM" -``` - -- **Submit the full suite**: - -```bash -bash slurm_jobs/prune_llm/paper/submit_suite.sh -``` - -### How to 
collect artifacts (tables + placeholder figures) - -After jobs finish: - -```bash -python scripts/collect_paper_artifacts.py \ - --results-base "$OUTPUT_BASE" \ - --draft-dir /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment/drafts/LLM_prune -``` - -This will: -- write LaTeX snippets to `drafts/LLM_prune/paper_artifacts/tables/*.tex` -- copy key plots into `drafts/LLM_prune/placeholder_*.png` so the draft auto-fills. - diff --git a/slurm_jobs/prune_llm/paper/run_llama3_8b_calibration_array.sh b/slurm_jobs/prune_llm/paper/run_llama3_8b_calibration_array.sh deleted file mode 100644 index d2807c35..00000000 --- a/slurm_jobs/prune_llm/paper/run_llama3_8b_calibration_array.sh +++ /dev/null @@ -1,91 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=paper_llama3_calib -#SBATCH --output=logs/paper_llama3_calib_%A_%a.out -#SBATCH --error=logs/paper_llama3_calib_%A_%a.err -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gres=gpu:4 -#SBATCH --cpus-per-task=16 -#SBATCH --time=06:00:00 -#SBATCH --mem=320GB -#SBATCH --partition=kempner_eng -#SBATCH --account=kempner_dev -#SBATCH --array=0-4 - -# ---------------------------------------------------------------------------- -# LLaMA-3.1-8B SWEEP: calibration sensitivity for SCAR-Conn @ 50% sparsity -# -# Task mapping: -# 0: wikitext, n=128 -# 1: wikitext, n=64 -# 2: wikitext, n=32 -# 3: c4, n=128 -# 4: mixed_wikitext_c4, n=128 -# -# Notes: -# - We restrict pruning to SCAR-Conn at 50% and evaluate perplexity only (fast). 
-# ---------------------------------------------------------------------------- - -set -euo pipefail - -DATASETS=("wikitext" "wikitext" "wikitext" "c4" "mixed_wikitext_c4") -NSAMPLES=(128 64 32 128 128) -TAGS=("wikitext_128" "wikitext_64" "wikitext_32" "c4_128" "mixed_128") - -IDX="${SLURM_ARRAY_TASK_ID}" -DATASET="${DATASETS[$IDX]}" -N="${NSAMPLES[$IDX]}" -TAG="${TAGS[$IDX]}" - -echo "============================================================================" -echo "SCAR Paper Sweep: LLaMA-3.1-8B calibration sensitivity (${TAG})" -echo "============================================================================" -echo "Job ID: ${SLURM_JOB_ID} Array Task: ${SLURM_ARRAY_TASK_ID}" -echo "Node: $(hostname)" -echo "Start time: $(date)" -echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" - -OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" -echo "Output Base: $OUTPUT_BASE" -echo "Calibration dataset: ${DATASET}" -echo "Calibration samples: ${N}" -echo "" - -module purge -module load cuda/12.2.0-fasrc01 -eval "$(conda shell.bash hook)" -conda activate networkAlignmentAnalysis - -cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment -mkdir -p logs - -export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK -export TOKENIZERS_PARALLELISM=false -export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True -export HF_HOME=/n/home13/hsafaai/.cache/huggingface -export HF_TOKEN=$(cat /n/home13/hsafaai/.cache/huggingface/token) - -python scripts/run_experiment.py \ - --config configs/prune_llm/llama3_8b_full.yaml \ - --device cuda \ - --base-output-dir "$OUTPUT_BASE" \ - name="llama3_8b_paper_results_calib_${TAG}" \ - generate_plots=false \ - dataset_name="${DATASET}" \ - alignment_data_num_samples="${N}" \ - scar_num_samples="${N}" \ - pruning_strategies="['supernode_connectivity_score']" \ - pruning_amounts="[0.5]" \ - pruning_selection_mode="['low']" \ - "llm.evaluation_metrics=['perplexity']" \ - do_directed_redundancy=false 
\ - do_halo_analysis=false \ - do_generalized_importance=false \ - supernode_robustness.enabled=false \ - supernode_summary.enabled=false - -echo "" -echo "============================================================================" -echo "LLaMA-3.1-8B calibration sweep (${TAG}) completed at $(date)" -echo "============================================================================" - diff --git a/slurm_jobs/prune_llm/paper/run_llama3_8b_noprotect.sh b/slurm_jobs/prune_llm/paper/run_llama3_8b_noprotect.sh deleted file mode 100644 index 3263a5ab..00000000 --- a/slurm_jobs/prune_llm/paper/run_llama3_8b_noprotect.sh +++ /dev/null @@ -1,71 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=paper_llama3_noprotect -#SBATCH --output=logs/paper_llama3_noprotect_%j.out -#SBATCH --error=logs/paper_llama3_noprotect_%j.err -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gres=gpu:4 -#SBATCH --cpus-per-task=16 -#SBATCH --time=06:00:00 -#SBATCH --mem=320GB -#SBATCH --partition=kempner_eng -#SBATCH --account=kempner_dev - -# ---------------------------------------------------------------------------- -# LLaMA-3.1-8B CONTROL: LP-no-protect + "remove supernodes early" (mode=high) -# -# Produces (at 50%): -# - LP-no-protect: metric=scar_loss_proxy, mode=low, protect_core=false -# - Remove-core-early metric=scar_loss_proxy, mode=high, protect_core=false -# ---------------------------------------------------------------------------- - -set -euo pipefail - -echo "============================================================================" -echo "SCAR Paper Control: LLaMA-3.1-8B (no-protect LP control)" -echo "============================================================================" -echo "Job ID: $SLURM_JOB_ID" -echo "Node: $(hostname)" -echo "Start time: $(date)" -echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" - -OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" -echo "Output Base: $OUTPUT_BASE" -echo "" - 
-module purge -module load cuda/12.2.0-fasrc01 -eval "$(conda shell.bash hook)" -conda activate networkAlignmentAnalysis - -cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment -mkdir -p logs - -export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK -export TOKENIZERS_PARALLELISM=false -export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True -export HF_HOME=/n/home13/hsafaai/.cache/huggingface -export HF_TOKEN=$(cat /n/home13/hsafaai/.cache/huggingface/token) - -python scripts/run_experiment.py \ - --config configs/prune_llm/llama3_8b_full.yaml \ - --device cuda \ - --base-output-dir "$OUTPUT_BASE" \ - name="llama3_8b_paper_results_noprotect" \ - generate_plots=false \ - supernode.protect_core=false \ - pruning_strategies="['scar_loss_proxy']" \ - pruning_amounts="[0.5]" \ - pruning_selection_mode="['low','high']" \ - do_connectivity_pruning=false \ - do_directed_redundancy=false \ - do_halo_analysis=false \ - do_generalized_importance=false \ - supernode_robustness.enabled=false \ - supernode_summary.enabled=false - -echo "" -echo "============================================================================" -echo "LLaMA-3.1-8B no-protect control completed at $(date)" -echo "============================================================================" - diff --git a/slurm_jobs/prune_llm/paper/run_llama3_8b_positive_redundancy_array.sh b/slurm_jobs/prune_llm/paper/run_llama3_8b_positive_redundancy_array.sh deleted file mode 100644 index 240edc85..00000000 --- a/slurm_jobs/prune_llm/paper/run_llama3_8b_positive_redundancy_array.sh +++ /dev/null @@ -1,80 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=paper_llama3_posred -#SBATCH --output=logs/paper_llama3_posred_%A_%a.out -#SBATCH --error=logs/paper_llama3_posred_%A_%a.err -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gres=gpu:4 -#SBATCH --cpus-per-task=16 -#SBATCH --time=06:00:00 -#SBATCH --mem=320GB -#SBATCH --partition=kempner_eng -#SBATCH --account=kempner_dev -#SBATCH --array=0-1 - -# 
---------------------------------------------------------------------------- -# LLaMA-3.1-8B ABLATION: positive-only redundancy vs rho^2 redundancy -# -# Task 0: positive_redundancy=false (rho^2 counts anti-correlation as redundancy) -# Task 1: positive_redundancy=true (rho^+ only; anti-correlation NOT redundant) -# ---------------------------------------------------------------------------- - -set -euo pipefail - -if [ "${SLURM_ARRAY_TASK_ID}" -eq 0 ]; then - POS_RED="false" - TAG="rho2" -else - POS_RED="true" - TAG="posonly" -fi - -echo "============================================================================" -echo "SCAR Paper Ablation: LLaMA-3.1-8B (positive redundancy = ${POS_RED})" -echo "============================================================================" -echo "Job ID: ${SLURM_JOB_ID} Array Task: ${SLURM_ARRAY_TASK_ID}" -echo "Node: $(hostname)" -echo "Start time: $(date)" -echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" - -OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" -echo "Output Base: $OUTPUT_BASE" -echo "" - -module purge -module load cuda/12.2.0-fasrc01 -eval "$(conda shell.bash hook)" -conda activate networkAlignmentAnalysis - -cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment -mkdir -p logs - -export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK -export TOKENIZERS_PARALLELISM=false -export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True -export HF_HOME=/n/home13/hsafaai/.cache/huggingface -export HF_TOKEN=$(cat /n/home13/hsafaai/.cache/huggingface/token) - -python scripts/run_experiment.py \ - --config configs/prune_llm/llama3_8b_full.yaml \ - --device cuda \ - --base-output-dir "$OUTPUT_BASE" \ - name="llama3_8b_paper_results_posred_${TAG}" \ - generate_plots=false \ - supernode.positive_redundancy="${POS_RED}" \ - supernode.protect_core=true \ - "supernode.protect_core_metrics=['supernode_connectivity_score']" \ - pruning_strategies="['supernode_connectivity_score']" 
\ - pruning_amounts="[0.5]" \ - pruning_selection_mode="['low']" \ - do_directed_redundancy=false \ - do_halo_analysis=false \ - do_generalized_importance=false \ - supernode_robustness.enabled=false \ - supernode_summary.enabled=false - -echo "" -echo "============================================================================" -echo "LLaMA-3.1-8B pos-redundancy ablation (${TAG}) completed at $(date)" -echo "============================================================================" - diff --git a/slurm_jobs/prune_llm/paper/run_llama3_8b_protect_baselines.sh b/slurm_jobs/prune_llm/paper/run_llama3_8b_protect_baselines.sh deleted file mode 100644 index eafcd13f..00000000 --- a/slurm_jobs/prune_llm/paper/run_llama3_8b_protect_baselines.sh +++ /dev/null @@ -1,72 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=paper_llama3_protect_base -#SBATCH --output=logs/paper_llama3_protect_base_%j.out -#SBATCH --error=logs/paper_llama3_protect_base_%j.err -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gres=gpu:4 -#SBATCH --cpus-per-task=16 -#SBATCH --time=08:00:00 -#SBATCH --mem=320GB -#SBATCH --partition=kempner_eng -#SBATCH --account=kempner_dev - -# ---------------------------------------------------------------------------- -# LLaMA-3.1-8B CONTROL: Protect+Baseline variants -# -# Produces (at 50%): -# - Protect+Wanda: metric=wanda, protect_core_metrics includes wanda -# - Protect+Magnitude: metric=weight_magnitude, protect_core_metrics includes weight_magnitude -# ---------------------------------------------------------------------------- - -set -euo pipefail - -echo "============================================================================" -echo "SCAR Paper Control: LLaMA-3.1-8B (protect baselines)" -echo "============================================================================" -echo "Job ID: $SLURM_JOB_ID" -echo "Node: $(hostname)" -echo "Start time: $(date)" -echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" - 
-OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" -echo "Output Base: $OUTPUT_BASE" -echo "" - -module purge -module load cuda/12.2.0-fasrc01 -eval "$(conda shell.bash hook)" -conda activate networkAlignmentAnalysis - -cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment -mkdir -p logs - -export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK -export TOKENIZERS_PARALLELISM=false -export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True -export HF_HOME=/n/home13/hsafaai/.cache/huggingface -export HF_TOKEN=$(cat /n/home13/hsafaai/.cache/huggingface/token) - -python scripts/run_experiment.py \ - --config configs/prune_llm/llama3_8b_full.yaml \ - --device cuda \ - --base-output-dir "$OUTPUT_BASE" \ - name="llama3_8b_paper_results_protect_baselines" \ - generate_plots=false \ - supernode.protect_core=true \ - "supernode.protect_core_metrics=['wanda','weight_magnitude']" \ - pruning_strategies="['wanda','weight_magnitude']" \ - pruning_amounts="[0.5]" \ - pruning_selection_mode="['low']" \ - do_connectivity_pruning=false \ - do_directed_redundancy=false \ - do_halo_analysis=false \ - do_generalized_importance=false \ - supernode_robustness.enabled=false \ - supernode_summary.enabled=false - -echo "" -echo "============================================================================" -echo "LLaMA-3.1-8B protect-baselines completed at $(date)" -echo "============================================================================" - diff --git a/slurm_jobs/prune_llm/paper/submit_suite.sh b/slurm_jobs/prune_llm/paper/submit_suite.sh deleted file mode 100644 index bc6c29aa..00000000 --- a/slurm_jobs/prune_llm/paper/submit_suite.sh +++ /dev/null @@ -1,50 +0,0 @@ -#!/bin/bash -# ============================================================================ -# SUBMIT FULL SCAR PAPER SUITE (main + controls/ablations) -# ============================================================================ -# Usage: -# cd 
/n/holylabs/kempner_dev/Users/hsafaai/Code/alignment -# bash slurm_jobs/prune_llm/paper/submit_suite.sh -# -# Output: -# Uses OUTPUT_BASE (exported or defaulted below). -# ============================================================================ - -set -euo pipefail - -OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" - -echo "==============================================" -echo "Submitting SCAR Paper Suite" -echo "==============================================" -echo "OUTPUT_BASE: $OUTPUT_BASE" -echo "" - -cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment -mkdir -p logs - -echo "---- Main results + generalization (4 models) ----" -export OUTPUT_BASE -bash slurm_jobs/prune_llm/run_all_paper.sh -echo "" - -echo "---- Controls / ablations (Llama-3.1-8B) ----" -JOB_NP=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/prune_llm/paper/run_llama3_8b_noprotect.sh | awk '{print $4}') -echo " noprotect/control: $JOB_NP" - -JOB_PB=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/prune_llm/paper/run_llama3_8b_protect_baselines.sh | awk '{print $4}') -echo " protect-baselines: $JOB_PB" - -JOB_POSRED=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/prune_llm/paper/run_llama3_8b_positive_redundancy_array.sh | awk '{print $4}') -echo " pos-redundancy array: $JOB_POSRED" - -JOB_CALIB=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/prune_llm/paper/run_llama3_8b_calibration_array.sh | awk '{print $4}') -echo " calibration array: $JOB_CALIB" - -echo "" -echo "==============================================" -echo "All suite jobs submitted" -echo "==============================================" -echo "Monitor with: squeue -u \$USER" -echo "" - diff --git a/slurm_jobs/prune_llm/run_all_paper.sh b/slurm_jobs/prune_llm/run_all_paper.sh deleted file mode 100755 index e6e8f443..00000000 --- a/slurm_jobs/prune_llm/run_all_paper.sh +++ /dev/null @@ -1,72 +0,0 @@ -#!/bin/bash -# 
============================================================================ -# SUBMIT ALL PAPER EXPERIMENTS -# ============================================================================ -# This script submits all 4 paper experiments as separate SLURM jobs -# They will run in parallel if resources are available -# -# Output Directory Structure: -# All results go to: /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/ -# Each job creates a unique directory: {model}_paper_results_{timestamp}_{job_id}/ -# results/ - JSON results files -# logs/ - experiment.log -# figures/ - All visualizations -# checkpoints/ - Model checkpoints -# analysis/ - Post-analysis outputs -# -# Usage: -# cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment -# bash slurm_jobs/prune_llm/run_all_paper.sh -# ============================================================================ - -OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" - -echo "==============================================" -echo "Submitting SCAR Paper Experiments" -echo "==============================================" -echo "" -echo "Output directory: $OUTPUT_BASE" -echo "" - -cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment - -# Submit all jobs -echo "Submitting LLaMA-3.1-8B (main results)..." -JOB1=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/prune_llm/run_llama3_8b.sh | awk '{print $4}') -echo " Job ID: $JOB1" - -echo "Submitting Mistral-7B (generalization)..." -JOB2=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/prune_llm/run_mistral_7b.sh | awk '{print $4}') -echo " Job ID: $JOB2" - -echo "Submitting LLaMA-2-7B (generalization)..." -JOB3=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/prune_llm/run_llama2_7b.sh | awk '{print $4}') -echo " Job ID: $JOB3" - -echo "Submitting Qwen2-7B (generalization)..." 
-JOB4=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/prune_llm/run_qwen2_7b.sh | awk '{print $4}') -echo " Job ID: $JOB4" - -echo "" -echo "==============================================" -echo "All jobs submitted!" -echo "==============================================" -echo "" -echo "Job IDs: $JOB1, $JOB2, $JOB3, $JOB4" -echo "" -echo "Monitor with:" -echo " squeue -u \$USER" -echo "" -echo "View SLURM logs:" -echo " tail -f logs/paper_llama3_8b_${JOB1}.out" -echo " tail -f logs/paper_mistral_7b_${JOB2}.out" -echo " tail -f logs/paper_llama2_7b_${JOB3}.out" -echo " tail -f logs/paper_qwen2_7b_${JOB4}.out" -echo "" -echo "Expected runtime: ~6-8 hours per job" -echo "" -echo "Results will be in:" -echo " $OUTPUT_BASE/llama3_8b_paper_results_*_${JOB1}/" -echo " $OUTPUT_BASE/mistral_7b_paper_results_*_${JOB2}/" -echo " $OUTPUT_BASE/llama2_7b_paper_results_*_${JOB3}/" -echo " $OUTPUT_BASE/qwen2_7b_paper_results_*_${JOB4}/" diff --git a/slurm_jobs/prune_llm/run_llama2_7b.sh b/slurm_jobs/prune_llm/run_llama2_7b.sh deleted file mode 100755 index c9d36de8..00000000 --- a/slurm_jobs/prune_llm/run_llama2_7b.sh +++ /dev/null @@ -1,73 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=paper_llama2_7b -#SBATCH --output=logs/paper_llama2_7b_%j.out -#SBATCH --error=logs/paper_llama2_7b_%j.err -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gres=gpu:4 -#SBATCH --cpus-per-task=16 -#SBATCH --time=10:00:00 -#SBATCH --mem=320GB -#SBATCH --partition=kempner_h100 -#SBATCH --account=kempner_dev - -# ============================================================================ -# LLAMA-2-7B PAPER RESULTS (Generalization) -# ============================================================================ -# Cross-model generalization experiment -# Expected runtime: ~4-6 hours on H100 -# -# Output Directory Structure: -# /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/ -# llama2_7b_paper_results_{timestamp}_{SLURM_JOB_ID}/ -# results/ - JSON results files -# logs/ - 
experiment.log -# figures/ - All visualizations -# checkpoints/ - Model checkpoints -# analysis/ - Post-analysis outputs -# ============================================================================ - -echo "============================================================================" -echo "SCAR Paper: LLaMA-2-7B (Generalization)" -echo "============================================================================" -echo "Job ID: $SLURM_JOB_ID" -echo "Node: $(hostname)" -echo "Start time: $(date)" -echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" -OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" -echo "Output Base: $OUTPUT_BASE" -echo "" - -# Environment setup -module purge -module load cuda/12.2.0-fasrc01 -eval "$(conda shell.bash hook)" -conda activate networkAlignmentAnalysis - -cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment - -# Create local logs directory for SLURM output files -mkdir -p logs - -export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK -export TOKENIZERS_PARALLELISM=false -export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True -export HF_HOME=/n/home13/hsafaai/.cache/huggingface -export HF_TOKEN=$(cat /n/home13/hsafaai/.cache/huggingface/token) - -echo "" -echo "Running LLaMA-2-7B full paper analysis..." 
-echo "" - -python scripts/run_experiment.py \ - --config configs/prune_llm/llama2_7b_full.yaml \ - --device cuda \ - --base-output-dir "$OUTPUT_BASE" - -echo "" -echo "============================================================================" -echo "LLaMA-2-7B completed at $(date)" -echo "============================================================================" -echo "" -echo "Results saved to: $OUTPUT_BASE/" -echo "Look for directory: llama2_7b_paper_results_*_$SLURM_JOB_ID" diff --git a/slurm_jobs/prune_llm/run_llama3_8b.sh b/slurm_jobs/prune_llm/run_llama3_8b.sh deleted file mode 100755 index b7fa3c22..00000000 --- a/slurm_jobs/prune_llm/run_llama3_8b.sh +++ /dev/null @@ -1,80 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=paper_llama3_8b -#SBATCH --output=logs/paper_llama3_8b_%j.out -#SBATCH --error=logs/paper_llama3_8b_%j.err -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gres=gpu:4 -#SBATCH --cpus-per-task=16 -#SBATCH --time=12:00:00 -#SBATCH --mem=320GB -#SBATCH --partition=kempner_eng -#SBATCH --account=kempner_dev - -# ============================================================================ -# LLAMA-3.1-8B PAPER RESULTS -# ============================================================================ -# Full SCAR analysis including: -# - Supernode distribution & robustness -# - Halo redundancy analysis -# - Cross-layer importance -# - Within-layer importance -# - All pruning methods + SOTA baselines (Wanda, SparseGPT) -# - Full benchmark evaluation -# -# Expected runtime: ~6-8 hours on H100 -# -# Output Directory Structure: -# /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/ -# llama3_8b_paper_results_{timestamp}_{SLURM_JOB_ID}/ -# results/ - JSON results files -# logs/ - experiment.log -# figures/ - All visualizations -# checkpoints/ - Model checkpoints -# analysis/ - Post-analysis outputs -# ============================================================================ - -echo 
"============================================================================" -echo "SCAR Paper: LLaMA-3.1-8B" -echo "============================================================================" -echo "Job ID: $SLURM_JOB_ID" -echo "Node: $(hostname)" -echo "Start time: $(date)" -echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" -OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" -echo "Output Base: $OUTPUT_BASE" -echo "" - -# Environment setup -module purge -module load cuda/12.2.0-fasrc01 -eval "$(conda shell.bash hook)" -conda activate networkAlignmentAnalysis - -cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment - -# Create local logs directory for SLURM output files -mkdir -p logs - -export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK -export TOKENIZERS_PARALLELISM=false -export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True -export HF_HOME=/n/home13/hsafaai/.cache/huggingface -export HF_TOKEN=$(cat /n/home13/hsafaai/.cache/huggingface/token) - -echo "" -echo "Running LLaMA-3.1-8B full paper analysis..." 
-echo "" - -python scripts/run_experiment.py \ - --config configs/prune_llm/llama3_8b_full.yaml \ - --device cuda \ - --base-output-dir "$OUTPUT_BASE" - -echo "" -echo "============================================================================" -echo "LLaMA-3.1-8B completed at $(date)" -echo "============================================================================" -echo "" -echo "Results saved to: $OUTPUT_BASE/" -echo "Look for directory: llama3_8b_paper_results_*_$SLURM_JOB_ID" diff --git a/slurm_jobs/prune_llm/run_mistral_7b.sh b/slurm_jobs/prune_llm/run_mistral_7b.sh deleted file mode 100755 index c4f04b7e..00000000 --- a/slurm_jobs/prune_llm/run_mistral_7b.sh +++ /dev/null @@ -1,73 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=paper_mistral_7b -#SBATCH --output=logs/paper_mistral_7b_%j.out -#SBATCH --error=logs/paper_mistral_7b_%j.err -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gres=gpu:4 -#SBATCH --cpus-per-task=16 -#SBATCH --time=10:00:00 -#SBATCH --mem=320GB -#SBATCH --partition=kempner_h100 -#SBATCH --account=kempner_dev - -# ============================================================================ -# MISTRAL-7B PAPER RESULTS (Generalization) -# ============================================================================ -# Cross-model generalization experiment -# Expected runtime: ~4-6 hours on H100 -# -# Output Directory Structure: -# /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/ -# mistral_7b_paper_results_{timestamp}_{SLURM_JOB_ID}/ -# results/ - JSON results files -# logs/ - experiment.log -# figures/ - All visualizations -# checkpoints/ - Model checkpoints -# analysis/ - Post-analysis outputs -# ============================================================================ - -echo "============================================================================" -echo "SCAR Paper: Mistral-7B (Generalization)" -echo "============================================================================" -echo "Job ID: $SLURM_JOB_ID" 
-echo "Node: $(hostname)" -echo "Start time: $(date)" -echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" -OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" -echo "Output Base: $OUTPUT_BASE" -echo "" - -# Environment setup -module purge -module load cuda/12.2.0-fasrc01 -eval "$(conda shell.bash hook)" -conda activate networkAlignmentAnalysis - -cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment - -# Create local logs directory for SLURM output files -mkdir -p logs - -export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK -export TOKENIZERS_PARALLELISM=false -export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True -export HF_HOME=/n/home13/hsafaai/.cache/huggingface -export HF_TOKEN=$(cat /n/home13/hsafaai/.cache/huggingface/token) - -echo "" -echo "Running Mistral-7B full paper analysis..." -echo "" - -python scripts/run_experiment.py \ - --config configs/prune_llm/mistral_7b_full.yaml \ - --device cuda \ - --base-output-dir "$OUTPUT_BASE" - -echo "" -echo "============================================================================" -echo "Mistral-7B completed at $(date)" -echo "============================================================================" -echo "" -echo "Results saved to: $OUTPUT_BASE/" -echo "Look for directory: mistral_7b_paper_results_*_$SLURM_JOB_ID" diff --git a/slurm_jobs/prune_llm/run_qwen2_7b.sh b/slurm_jobs/prune_llm/run_qwen2_7b.sh deleted file mode 100755 index a81d62b9..00000000 --- a/slurm_jobs/prune_llm/run_qwen2_7b.sh +++ /dev/null @@ -1,74 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=paper_qwen2_7b -#SBATCH --output=logs/paper_qwen2_7b_%j.out -#SBATCH --error=logs/paper_qwen2_7b_%j.err -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gres=gpu:4 -#SBATCH --cpus-per-task=16 -#SBATCH --time=10:00:00 -#SBATCH --mem=320GB -#SBATCH --partition=kempner_h100 -#SBATCH --account=kempner_dev - -# ============================================================================ -# 
QWEN2-7B PAPER RESULTS (Generalization) -# ============================================================================ -# Cross-model generalization experiment -# Qwen2 has different FFN architecture (28 layers, larger intermediate) -# Expected runtime: ~4-6 hours on H100 -# -# Output Directory Structure: -# /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/ -# qwen2_7b_paper_results_{timestamp}_{SLURM_JOB_ID}/ -# results/ - JSON results files -# logs/ - experiment.log -# figures/ - All visualizations -# checkpoints/ - Model checkpoints -# analysis/ - Post-analysis outputs -# ============================================================================ - -echo "============================================================================" -echo "SCAR Paper: Qwen2-7B (Generalization)" -echo "============================================================================" -echo "Job ID: $SLURM_JOB_ID" -echo "Node: $(hostname)" -echo "Start time: $(date)" -echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" -OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" -echo "Output Base: $OUTPUT_BASE" -echo "" - -# Environment setup -module purge -module load cuda/12.2.0-fasrc01 -eval "$(conda shell.bash hook)" -conda activate networkAlignmentAnalysis - -cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment - -# Create local logs directory for SLURM output files -mkdir -p logs - -export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK -export TOKENIZERS_PARALLELISM=false -export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True -export HF_HOME=/n/home13/hsafaai/.cache/huggingface -export HF_TOKEN=$(cat /n/home13/hsafaai/.cache/huggingface/token) - -echo "" -echo "Running Qwen2-7B full paper analysis..." 
-echo "" - -python scripts/run_experiment.py \ - --config configs/prune_llm/qwen2_7b_full.yaml \ - --device cuda \ - --base-output-dir "$OUTPUT_BASE" - -echo "" -echo "============================================================================" -echo "Qwen2-7B completed at $(date)" -echo "============================================================================" -echo "" -echo "Results saved to: $OUTPUT_BASE/" -echo "Look for directory: qwen2_7b_paper_results_*_$SLURM_JOB_ID" diff --git a/slurm_jobs/prune_vision/run_cluster_analysis_resnet18.sh b/slurm_jobs/prune_vision/run_cluster_analysis_resnet18.sh deleted file mode 100644 index f3cde664..00000000 --- a/slurm_jobs/prune_vision/run_cluster_analysis_resnet18.sh +++ /dev/null @@ -1,81 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=cluster_analysis_resnet18 -#SBATCH --output=logs/cluster_analysis_resnet18_%j.out -#SBATCH --error=logs/cluster_analysis_resnet18_%j.err -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gres=gpu:1 -#SBATCH --cpus-per-task=8 -#SBATCH --time=4:00:00 -#SBATCH --mem=64GB -#SBATCH --partition=kempner_eng -#SBATCH --account=kempner_dev - -# ============================================================================ -# CLUSTER-BASED ANALYSIS: ResNet-18 on CIFAR-10 -# ============================================================================ -# Full cluster-based analysis including: -# - Per-channel metrics (RQ, Redundancy, Synergy with continuous target) -# - K-means clustering into functional types -# - Cross-layer halo analysis -# - Cascade damage testing -# - Pruning experiments (without fine-tuning to see raw impact) -# - Organized visualization output -# -# Figure Organization: -# figures/01_distributions/ - Per-layer metric histograms -# figures/02_summary/ - Layer-wise violin plots, trends -# figures/03_clustering/ - Cluster scatter plots, evolution -# figures/04_cascade/ - Cascade damage test results -# figures/05_halo/ - Halo analysis plots -# figures/06_pruning/ - Pruning 
comparison charts -# -# Expected runtime: ~1-2 hours on single GPU -# ============================================================================ - -OUTPUT_BASE="/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM" - -echo "============================================================================" -echo "Cluster-Based Analysis: ResNet-18 on CIFAR-10" -echo "============================================================================" -echo "Job ID: $SLURM_JOB_ID" -echo "Node: $(hostname)" -echo "Start time: $(date)" -echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader 2>/dev/null || echo 'N/A')" -echo "Output Base: $OUTPUT_BASE" -echo "" - -# Environment setup -module purge -module load cuda/12.2.0-fasrc01 -eval "$(conda shell.bash hook)" -conda activate networkAlignmentAnalysis - -cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment - -# Create local logs directory for SLURM output files -mkdir -p logs - -export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK - -echo "" -echo "Running ResNet-18 cluster analysis..." -echo "Fine-tuning after pruning: DISABLED (seeing raw pruning impact)" -echo "" - -python scripts/run_experiment.py \ - --config configs/vision_prune/resnet18_cifar10_unified.yaml \ - --device cuda - -EXIT_CODE=$? 
- -echo "" -echo "============================================================================" -echo "ResNet-18 cluster analysis completed at $(date)" -echo "Exit code: $EXIT_CODE" -echo "============================================================================" -echo "" -echo "Results saved to: $OUTPUT_BASE/" -echo "Look for directory starting with: resnet18_cifar10_cluster_analysis_" - -exit $EXIT_CODE diff --git a/slurm_jobs/prune_vision/run_cluster_analysis_resnet50.sh b/slurm_jobs/prune_vision/run_cluster_analysis_resnet50.sh deleted file mode 100644 index f34b623a..00000000 --- a/slurm_jobs/prune_vision/run_cluster_analysis_resnet50.sh +++ /dev/null @@ -1,76 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=cluster_analysis_resnet50 -#SBATCH --output=logs/cluster_analysis_resnet50_%j.out -#SBATCH --error=logs/cluster_analysis_resnet50_%j.err -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gres=gpu:1 -#SBATCH --cpus-per-task=8 -#SBATCH --time=12:00:00 -#SBATCH --mem=128GB -#SBATCH --partition=kempner_eng -#SBATCH --account=kempner_dev - -# ============================================================================ -# CLUSTER-BASED ANALYSIS: ResNet-50 on ImageNet-100 -# ============================================================================ -# Full cluster-based analysis including: -# - Per-channel metrics (RQ, Redundancy, Synergy with continuous target) -# - K-means clustering into functional types -# - Cross-layer halo analysis with activation weighting -# - Cascade damage testing -# - Pruning experiments with fine-tuning -# - Visualization generation -# -# Expected runtime: ~6-10 hours on single GPU (A100) -# ============================================================================ - -echo "============================================================================" -echo "Cluster-Based Analysis: ResNet-50 on ImageNet-100" -echo "============================================================================" -echo "Job ID: $SLURM_JOB_ID" -echo 
"Node: $(hostname)" -echo "Start time: $(date)" -echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader 2>/dev/null || echo 'N/A')" -echo "" - -# Environment setup -module purge -module load cuda/12.2.0-fasrc01 -eval "$(conda shell.bash hook)" -conda activate networkAlignmentAnalysis - -cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment - -# Create directories -mkdir -p logs -mkdir -p results/vision/resnet50_imagenet100 - -export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK - -# Check for ImageNet-100 data -if [ ! -d "data/imagenet100" ]; then - echo "WARNING: ImageNet-100 data not found at data/imagenet100" - echo "Please download or symlink the ImageNet-100 subset before running." - echo "" -fi - -echo "" -echo "Running ResNet-50 cluster analysis..." -echo "" - -python scripts/run_experiment.py \ - --config configs/cluster_analysis/resnet50_imagenet100.yaml \ - --device cuda - -EXIT_CODE=$? - -echo "" -echo "============================================================================" -echo "ResNet-50 cluster analysis completed at $(date)" -echo "Exit code: $EXIT_CODE" -echo "============================================================================" -echo "" -echo "Results saved to: results/vision/resnet50_imagenet100/" - -exit $EXIT_CODE diff --git a/slurm_jobs/run_comprehensive_pruning.sh b/slurm_jobs/run_comprehensive_pruning.sh deleted file mode 100644 index 6d8e84c8..00000000 --- a/slurm_jobs/run_comprehensive_pruning.sh +++ /dev/null @@ -1,102 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=comp_prune -#SBATCH --output=logs/comprehensive_pruning_%j.out -#SBATCH --error=logs/comprehensive_pruning_%j.err -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gres=gpu:4 -#SBATCH --cpus-per-task=16 -#SBATCH --time=24:00:00 -#SBATCH --mem=320GB -#SBATCH --partition=kempner_eng -#SBATCH --account=kempner_dev - -# ============================================================================ -# COMPREHENSIVE LLM PRUNING COMPARISON -# 
============================================================================ -# Compares ALL custom pruning methods vs SOTA baselines -# Expected runtime: 6-12 hours on H100 -# ============================================================================ - -echo "============================================================================" -echo "COMPREHENSIVE LLM PRUNING COMPARISON" -echo "============================================================================" -echo "Job ID: $SLURM_JOB_ID" -echo "Node: $(hostname)" -echo "Start time: $(date)" -echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" -echo "" - -# Environment setup -module purge -module load cuda/12.2.0-fasrc01 -eval "$(conda shell.bash hook)" -conda activate networkAlignmentAnalysis - -cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment - -mkdir -p logs - -export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK -export TOKENIZERS_PARALLELISM=false -export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True -export HF_HOME=/n/home13/hsafaai/.cache/huggingface -export HF_TOKEN=$(cat /n/home13/hsafaai/.cache/huggingface/token) - -echo "============================================================================" -echo "PRUNING METHODS TO COMPARE:" -echo "============================================================================" -echo "" -echo "ALIGNMENT-BASED (Our Methods):" -echo " - rayleigh_quotient (RQ - alignment measure)" -echo " - gaussian_mi_analytic (MI - mutual information)" -echo " - average_redundancy (Information-theoretic)" -echo "" -echo "SCAR-BASED (Gradient-Informed):" -echo " - scar_loss_proxy (Activation power + curvature)" -echo "" -echo "SUPERNODE-AWARE (Novel Contribution):" -echo " - supernode_protection_score (Protects unique halo neurons)" -echo " - supernode_connectivity_score (Low connectivity = safe to prune)" -echo "" -echo "GENERALIZED (No Outlier Assumption):" -echo " - generalized_importance (Works without supernode structure)" -echo "" -echo "ANALYSIS:" 
-echo " - Halo vs Non-halo redundancy comparison" -echo "" -echo "MAGNITUDE-BASED (Baseline):" -echo " - activation_l2_norm (Standard magnitude)" -echo "" -echo "SOTA BASELINES:" -echo " - wanda (Sun et al., 2023)" -echo " - sparsegpt (Frantar & Alistarh, 2023)" -echo "" -echo "============================================================================" -echo "SPARSITY LEVELS: 10%, 20%, 30%, 40%, 50%, 60%, 70%, 80%, 90%" -echo "SELECTION MODES: low (prune lowest), high (prune highest)" -echo "============================================================================" -echo "" -echo "EVALUATION BENCHMARKS:" -echo " - Perplexity, Loss, Bits-per-Byte (WikiText-2)" -echo " - MMLU (57 subjects)" -echo " - HellaSwag (Commonsense)" -echo " - ARC-Easy/Challenge (Science)" -echo " - WinoGrande (Schemas)" -echo " - PIQA (Physical intuition)" -echo " - BoolQ (Boolean QA)" -echo " - GSM8k (Math)" -echo " - TruthfulQA" -echo " - MBPP/HumanEval (Code)" -echo "============================================================================" -echo "" - -python scripts/run_experiment.py \ - --config configs/examples/llama3_comprehensive_pruning.yaml \ - --device cuda - -echo "" -echo "============================================================================" -echo "Comprehensive pruning comparison completed at $(date)" -echo "============================================================================" - diff --git a/slurm_jobs/run_llama3_full_benchmark.sh b/slurm_jobs/run_llama3_full_benchmark.sh deleted file mode 100644 index 6ce7d85a..00000000 --- a/slurm_jobs/run_llama3_full_benchmark.sh +++ /dev/null @@ -1,73 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=llama3_bench -#SBATCH --output=logs/llama3_bench_%j.out -#SBATCH --error=logs/llama3_bench_%j.err -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gres=gpu:4 -#SBATCH --cpus-per-task=16 -#SBATCH --time=12:00:00 -#SBATCH --mem=320GB -#SBATCH --partition=kempner_h100 -#SBATCH --account=kempner_dev - -echo 
"==========================================" -echo "LLaMA-3 Full Benchmark Suite" -echo "Including NVIDIA Minitron benchmarks + Wanda/SparseGPT baselines" -echo "==========================================" -echo "Job ID: $SLURM_JOB_ID" -echo "Node: $(hostname)" -echo "Start time: $(date)" -echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" -echo "" - -# Environment setup -module purge -module load cuda/12.2.0-fasrc01 -eval "$(conda shell.bash hook)" -conda activate networkAlignmentAnalysis - -cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment - -# Make logs directory if it doesn't exist -mkdir -p logs - -export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK -export TOKENIZERS_PARALLELISM=false -export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True -export HF_HOME=/n/home13/hsafaai/.cache/huggingface -export HF_TOKEN=$(cat /n/home13/hsafaai/.cache/huggingface/token) - -echo "FULL BENCHMARK SUITE:" -echo "=====================" -echo "Language Model Metrics:" -echo " - Perplexity, Loss, Bits-per-byte" -echo "" -echo "NVIDIA Minitron Benchmarks (https://arxiv.org/abs/2407.14679):" -echo " - MMLU (Massive Multitask Language Understanding)" -echo " - HellaSwag (Commonsense reasoning)" -echo " - ARC-Challenge (Hard science questions)" -echo " - WinoGrande (Winograd schemas)" -echo " - PIQA (Physical intuition)" -echo " - TruthfulQA (Truthfulness)" -echo "" -echo "Additional Benchmarks:" -echo " - ARC-Easy (Science questions)" -echo " - BoolQ (Boolean questions)" -echo "" -echo "Pruning Strategies:" -echo " - Magnitude (L2), SCAR, RQ, MI, Redundancy" -echo " - Supernode protection/connectivity" -echo " - Wanda (Sun et al., 2023)" -echo " - SparseGPT (Frantar & Alistarh, 2023)" -echo "" - -python scripts/run_experiment.py \ - --config configs/examples/llama3_full_benchmark.yaml \ - --device cuda - -echo "" -echo "==========================================" -echo "Full benchmark completed at $(date)" -echo "==========================================" - 
diff --git a/slurm_jobs/run_llama3_scar_pruning.sh b/slurm_jobs/run_llama3_scar_pruning.sh deleted file mode 100755 index a95c3f78..00000000 --- a/slurm_jobs/run_llama3_scar_pruning.sh +++ /dev/null @@ -1,53 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=llama3_scar -#SBATCH --output=logs/llama3_scar_%j.out -#SBATCH --error=logs/llama3_scar_%j.err -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gres=gpu:4 -#SBATCH --cpus-per-task=16 -#SBATCH --time=8:00:00 -#SBATCH --mem=320GB -#SBATCH --partition=kempner_h100 -#SBATCH --account=kempner_dev - -echo "==========================================" -echo "LLaMA-3 SCAR-Based Pruning with Supernode Protection" -echo "==========================================" -echo "Job ID: $SLURM_JOB_ID" -echo "Node: $(hostname)" -echo "Start time: $(date)" -echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" -echo "" - -# Environment setup -module purge -module load cuda/12.2.0-fasrc01 -eval "$(conda shell.bash hook)" -conda activate networkAlignmentAnalysis - -cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment - -# Make logs directory if it doesn't exist -mkdir -p logs - -export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK -export TOKENIZERS_PARALLELISM=false -export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True -export HF_HOME=/n/home13/hsafaai/.cache/huggingface -export HF_TOKEN=$(cat /n/home13/hsafaai/.cache/huggingface/token) - -echo "Running SCAR-based pruning experiment..." 
-echo "Pruning metrics: L2 norm, SCAR loss proxy" -echo "Selection modes: low, high, random" -echo "Evaluation: perplexity, bits_per_byte, MMLU" -echo "" - -python scripts/run_experiment.py \ - --config configs/examples/llama3_scar_pruning.yaml \ - --device cuda - -echo "" -echo "==========================================" -echo "Experiment completed at $(date)" -echo "==========================================" diff --git a/slurm_jobs/run_minitron_comparison.sh b/slurm_jobs/run_minitron_comparison.sh deleted file mode 100755 index e3948410..00000000 --- a/slurm_jobs/run_minitron_comparison.sh +++ /dev/null @@ -1,92 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=minitron_cmp -#SBATCH --output=logs/minitron_cmp_%j.out -#SBATCH --error=logs/minitron_cmp_%j.err -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gres=gpu:4 -#SBATCH --cpus-per-task=16 -#SBATCH --time=24:00:00 -#SBATCH --mem=320GB -#SBATCH --partition=kempner_h100 -#SBATCH --account=kempner_dev - -echo "==========================================" -echo "NVIDIA MINITRON-COMPATIBLE BENCHMARK" -echo "Reference: https://arxiv.org/abs/2408.11796" -echo "==========================================" -echo "Job ID: $SLURM_JOB_ID" -echo "Node: $(hostname)" -echo "Start time: $(date)" -echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" -echo "" - -# Environment setup -module purge -module load cuda/12.2.0-fasrc01 -eval "$(conda shell.bash hook)" -conda activate networkAlignmentAnalysis - -cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment - -mkdir -p logs - -export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK -export TOKENIZERS_PARALLELISM=false -export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True -export HF_HOME=/n/home13/hsafaai/.cache/huggingface -export HF_TOKEN=$(cat /n/home13/hsafaai/.cache/huggingface/token) - -echo "==========================================" -echo "NVIDIA Minitron Reference (Llama 3.1 8B → 4B, 50% pruned):" -echo "" -echo " Benchmark │ Baseline │ Pruned │ 
Few-shot" -echo " ─────────────────┼──────────┼─────────┼──────────" -echo " Winogrande │ 77.3% │ 73.5% │ 5-shot" -echo " ARC-Challenge │ 57.9% │ 55.6% │ 25-shot" -echo " MMLU │ 65.3% │ 60.5% │ 5-shot" -echo " HellaSwag │ 81.8% │ 76.1% │ 10-shot" -echo " GSM8k │ 48.6% │ 41.2% │ 5-shot+CoT" -echo " TruthfulQA │ 45.0% │ 42.9% │ 0-shot" -echo " MBPP │ 42.3% │ 32.4% │ 0-shot" -echo " HumanEval │ 24.8% │ - │ 0-shot" -echo "==========================================" -echo "" -echo "PRUNING METHODS BEING COMPARED:" -echo "" -echo " ALIGNMENT-BASED (Our Novel Methods):" -echo " - rayleigh_quotient (RQ)" -echo " - gaussian_mi_analytic (MI)" -echo " - average_redundancy" -echo "" -echo " SCAR-BASED (Gradient-Informed):" -echo " - scar_loss_proxy" -echo "" -echo " SUPERNODE-AWARE (Novel Contribution):" -echo " - supernode_protection_score" -echo " - supernode_connectivity_score" -echo "" -echo " MAGNITUDE-BASED (Baseline):" -echo " - activation_l2_norm" -echo "" -echo " SOTA BASELINES (NVIDIA Minitron comparison):" -echo " - wanda (Sun et al., 2023)" -echo " - sparsegpt (Frantar & Alistarh, 2023)" -echo "" -echo "Sparsity Levels: 25%, 50%, 75%" -echo "Selection Modes: low, high" -echo "==========================================" -echo "" - -python scripts/run_experiment.py \ - --config configs/examples/llama3_minitron_comparison.yaml \ - --device cuda - -echo "" -echo "==========================================" -echo "Minitron comparison completed at $(date)" -echo "==========================================" -echo "" -echo "Results saved to: results/llama3_minitron_comparison_*/" -echo "Check plots/pruning/ for comparison plots" -echo "==========================================" diff --git a/slurm_jobs/run_multimodel_pruning.sh b/slurm_jobs/run_multimodel_pruning.sh deleted file mode 100644 index d993586d..00000000 --- a/slurm_jobs/run_multimodel_pruning.sh +++ /dev/null @@ -1,88 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=multimodel_prune -#SBATCH 
--output=logs/multimodel_pruning_%j_%a.out -#SBATCH --error=logs/multimodel_pruning_%j_%a.err -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gres=gpu:4 -#SBATCH --cpus-per-task=16 -#SBATCH --time=24:00:00 -#SBATCH --mem=320GB -#SBATCH --partition=kempner_eng -#SBATCH --account=kempner_dev -#SBATCH --array=0-4 - -# ============================================================================ -# MULTI-MODEL PRUNING COMPARISON (Array Job) -# ============================================================================ -# Runs SCAR pruning comparison across multiple LLM architectures -# -# Models tested: -# 0: Mistral-7B (mistralai/Mistral-7B-v0.1) -# 1: Llama-2-7B (meta-llama/Llama-2-7b-hf) -# 2: Gemma-2B (google/gemma-2b) - smaller/faster -# 3: Phi-3 Mini (microsoft/Phi-3-mini-4k-instruct) - smaller/faster -# 4: Qwen2-7B (Qwen/Qwen2-7B) -# -# Expected runtime per model: 6-12 hours on H100 -# ============================================================================ - -# Define model configs as array -CONFIGS=( - "configs/examples/mistral7b_pruning.yaml" - "configs/examples/llama2_7b_pruning.yaml" - "configs/examples/gemma2b_pruning.yaml" - "configs/examples/phi3_mini_pruning.yaml" - "configs/examples/qwen2_7b_pruning.yaml" -) - -MODEL_NAMES=( - "Mistral-7B" - "Llama-2-7B" - "Gemma-2B" - "Phi-3-Mini" - "Qwen2-7B" -) - -# Get config for this array task -CONFIG=${CONFIGS[$SLURM_ARRAY_TASK_ID]} -MODEL_NAME=${MODEL_NAMES[$SLURM_ARRAY_TASK_ID]} - -echo "============================================================================" -echo "MULTI-MODEL PRUNING: ${MODEL_NAME}" -echo "============================================================================" -echo "Job ID: $SLURM_JOB_ID (Array Task: $SLURM_ARRAY_TASK_ID)" -echo "Node: $(hostname)" -echo "Config: $CONFIG" -echo "Start time: $(date)" -echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" -echo "" - -# Environment setup -module purge -module load cuda/12.2.0-fasrc01 -eval 
"$(conda shell.bash hook)" -conda activate networkAlignmentAnalysis - -cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment - -mkdir -p logs - -export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK -export TOKENIZERS_PARALLELISM=false -export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True -export HF_HOME=/n/home13/hsafaai/.cache/huggingface -export HF_TOKEN=$(cat /n/home13/hsafaai/.cache/huggingface/token) - -echo "============================================================================" -echo "Running experiment for ${MODEL_NAME}..." -echo "============================================================================" - -python scripts/run_experiment.py \ - --config "$CONFIG" \ - --device cuda - -echo "" -echo "============================================================================" -echo "${MODEL_NAME} pruning completed at $(date)" -echo "============================================================================" diff --git a/slurm_jobs/run_paper_experiments.sh b/slurm_jobs/run_paper_experiments.sh deleted file mode 100644 index fd106cc4..00000000 --- a/slurm_jobs/run_paper_experiments.sh +++ /dev/null @@ -1,118 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=scar_paper -#SBATCH --output=/n/holylabs/kempner_dev/Users/hsafaai/Code/alignment/logs/paper_%j.out -#SBATCH --error=/n/holylabs/kempner_dev/Users/hsafaai/Code/alignment/logs/paper_%j.err -#SBATCH --nodes=1 -#SBATCH --ntasks=1 -#SBATCH --cpus-per-task=8 -#SBATCH --mem=80G -#SBATCH --gres=gpu:1 -#SBATCH --time=30:00:00 -#SBATCH --partition=kempner_h100 -#SBATCH --account=kempner_dev - -# ============================================================================ -# SCAR PAPER EXPERIMENTS -# ============================================================================ -# This script runs all experiments needed for the SCAR paper: -# 1. LLaMA-3.1-8B full analysis (main results) -# 2. Mistral-7B (generalization) -# 3. LLaMA-2-7B (generalization) -# 4. 
Qwen2-7B (generalization) -# -# Expected runtime: ~20-30 hours total (6-8h per model) -# ============================================================================ - -set -e - -echo "==============================================" -echo "SCAR Paper Experiments" -echo "==============================================" -echo "Start time: $(date)" -echo "Job ID: $SLURM_JOB_ID" -echo "Node: $(hostname)" -echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" -echo "==============================================" - -# Setup -cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment -source ~/.bashrc -conda activate alignment - -# Create output directories -mkdir -p logs -mkdir -p results/paper - -# ============================================================================ -# Experiment 1: LLaMA-3.1-8B (Main Results) -# ============================================================================ -echo "" -echo "==============================================" -echo "Experiment 1: LLaMA-3.1-8B (Main Results)" -echo "==============================================" - -python scripts/run_experiment.py \ - --config configs/prune_llm/llama3_8b_full.yaml \ - 2>&1 | tee logs/llama3_8b_paper.log - -echo "LLaMA-3.1-8B completed at $(date)" - -# ============================================================================ -# Experiment 2: Mistral-7B (Generalization) -# ============================================================================ -echo "" -echo "==============================================" -echo "Experiment 2: Mistral-7B (Generalization)" -echo "==============================================" - -python scripts/run_experiment.py \ - --config configs/prune_llm/mistral_7b_full.yaml \ - 2>&1 | tee logs/mistral_7b_paper.log - -echo "Mistral-7B completed at $(date)" - -# ============================================================================ -# Experiment 3: LLaMA-2-7B (Generalization) -# 
============================================================================ -echo "" -echo "==============================================" -echo "Experiment 3: LLaMA-2-7B (Generalization)" -echo "==============================================" - -python scripts/run_experiment.py \ - --config configs/prune_llm/llama2_7b_full.yaml \ - 2>&1 | tee logs/llama2_7b_paper.log - -echo "LLaMA-2-7B completed at $(date)" - -# ============================================================================ -# Experiment 4: Qwen2-7B (Generalization) -# ============================================================================ -echo "" -echo "==============================================" -echo "Experiment 4: Qwen2-7B (Generalization)" -echo "==============================================" - -python scripts/run_experiment.py \ - --config configs/prune_llm/qwen2_7b_full.yaml \ - 2>&1 | tee logs/qwen2_7b_paper.log - -echo "Qwen2-7B completed at $(date)" - -# ============================================================================ -# Summary -# ============================================================================ -echo "" -echo "==============================================" -echo "All experiments completed!" 
-echo "==============================================" -echo "End time: $(date)" -echo "" -echo "Results saved to:" -echo " - results/paper/llama3_8b/" -echo " - results/paper/mistral_7b/" -echo " - results/paper/llama2_7b/" -echo " - results/paper/qwen2_7b/" -echo "" -echo "Figures ready for paper in:" -echo " - results/paper/*/figures/" diff --git a/slurm_jobs/run_supernode_robustness.sh b/slurm_jobs/run_supernode_robustness.sh deleted file mode 100644 index 6915ab07..00000000 --- a/slurm_jobs/run_supernode_robustness.sh +++ /dev/null @@ -1,91 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=sn_robust -#SBATCH --output=/n/holylabs/kempner_dev/Users/hsafaai/Code/alignment/logs/supernode_robustness_%j.out -#SBATCH --error=/n/holylabs/kempner_dev/Users/hsafaai/Code/alignment/logs/supernode_robustness_%j.err -#SBATCH --partition=kempner_h100 -#SBATCH --account=kempner_dev -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --cpus-per-task=16 -#SBATCH --gres=gpu:1 -#SBATCH --mem=128G -#SBATCH --time=04:00:00 - -# ============================================================================ -# SUPERNODE ROBUSTNESS ANALYSIS JOB -# ============================================================================ -# Analyzes the consistency of supernode identification across: -# - Different metrics (RQ, MI, SCAR, magnitude) -# - Different data batches (bootstrap sampling) -# -# Key outputs: -# - Jaccard similarity heatmaps -# - Spearman correlation heatmaps -# - Bootstrap stability distributions -# - Cross-metric consistency plots -# ============================================================================ - -set -e - -echo "============================================================================" -echo "SUPERNODE ROBUSTNESS ANALYSIS" -echo "============================================================================" -echo "Job ID: $SLURM_JOB_ID" -echo "Node: $(hostname)" -echo "Start time: $(date)" -echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head 
-1)" -echo "" - -# Setup environment -cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment - -# Initialize conda -source ~/.bashrc -eval "$(conda shell.bash hook)" - -# Activate environment -conda activate networkAlignmentAnalysis - -# Verify Python environment -echo "Python: $(which python)" -echo "PyTorch: $(python -c 'import torch; print(torch.__version__)')" -echo "" - -# Check GPU availability -python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}'); print(f'GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else \"N/A\"}')" -echo "" - -echo "============================================================================" -echo "ANALYSIS FOCUS:" -echo "============================================================================" -echo "" -echo "1. CROSS-METRIC CONSISTENCY:" -echo " - Do different metrics identify the same neurons as supernodes?" -echo " - Metrics: SCAR activation power, SCAR loss proxy, SCAR taylor," -echo " Rayleigh quotient, Gaussian MI, Activation L2 norm" -echo "" -echo "2. BOOTSTRAP STABILITY:" -echo " - Are supernodes consistent across different input samples?" -echo " - 10 bootstrap resamples per layer" -echo "" -echo "3. 
LAYERS ANALYZED:" -echo " - Layer 5 (early)" -echo " - Layer 15 (middle)" -echo " - Layer 25 (late)" -echo "" -echo "============================================================================" -echo "" - -# Set HuggingFace cache (optional but recommended for cluster) -export HF_HOME=/n/holyscratch01/kempner_dev/Users/hsafaai/huggingface_cache -export TRANSFORMERS_CACHE=/n/holyscratch01/kempner_dev/Users/hsafaai/huggingface_cache - -# Run the experiment -python scripts/run_experiment.py \ - --config configs/examples/llama3_supernode_robustness.yaml - -echo "" -echo "============================================================================" -echo "Job completed at: $(date)" -echo "============================================================================" - diff --git a/src/alignment/analysis/clustering/cross_layer_halo.py b/src/alignment/analysis/clustering/cross_layer_halo.py index 9e5c1337..a088083c 100644 --- a/src/alignment/analysis/clustering/cross_layer_halo.py +++ b/src/alignment/analysis/clustering/cross_layer_halo.py @@ -137,19 +137,20 @@ def compute_cluster_to_cluster_flow( target_types: Mapping from cluster ID to type name Returns: - Nested dict: flow[source_type][target_type] = mean influence + Nested dict: flow[source_type][target_type] = normalized influence mass """ flow = {} for src_id, src_type in source_types.items(): flow[src_type] = {} src_mask = source_labels == src_id src_infl = influence[:, src_mask].sum(axis=1) # [out] + denom = float(src_infl.sum()) + 1e-10 for tgt_id, tgt_type in target_types.items(): tgt_mask = target_labels == tgt_id if tgt_mask.sum() > 0: - mean_infl = float(np.mean(src_infl[tgt_mask])) - flow[src_type][tgt_type] = mean_infl + # Fraction of total outgoing influence mass from src cluster + flow[src_type][tgt_type] = float(src_infl[tgt_mask].sum()) / denom else: flow[src_type][tgt_type] = 0.0 diff --git a/src/alignment/analysis/visualization/__init__.py b/src/alignment/analysis/visualization/__init__.py index 
4ccb65d3..0fe5c59a 100644 --- a/src/alignment/analysis/visualization/__init__.py +++ b/src/alignment/analysis/visualization/__init__.py @@ -77,6 +77,14 @@ plot_halo_redundancy_heatmap, ) +# Paper-specific plots (SCAR draft) +from .paper_plots import ( + plot_loss_proxy_concentration, + plot_halo_structure, + plot_supernode_halo_summary, + plot_scar_schematic, +) + # Cluster visualization plots from .cluster_plots import ( plot_metric_scatter, @@ -128,6 +136,11 @@ "plot_halo_redundancy_by_depth", "plot_halo_redundancy_comprehensive", "plot_halo_redundancy_heatmap", + # Paper plots + "plot_loss_proxy_concentration", + "plot_halo_structure", + "plot_supernode_halo_summary", + "plot_scar_schematic", # Cluster plots "plot_metric_scatter", "plot_cluster_evolution", diff --git a/src/alignment/analysis/visualization/cluster_plots.py b/src/alignment/analysis/visualization/cluster_plots.py index f32e3859..f4635926 100644 --- a/src/alignment/analysis/visualization/cluster_plots.py +++ b/src/alignment/analysis/visualization/cluster_plots.py @@ -103,6 +103,126 @@ def plot_metric_scatter( return fig +def plot_metric_scatter_3d( + rq: np.ndarray, + redundancy: np.ndarray, + synergy: np.ndarray, + labels: np.ndarray, + type_mapping: Dict[int, str], + layer_name: str = "", + save_path: Optional[Path] = None, + figsize: Tuple[int, int] = (7, 6), + max_points: int = 20000, +) -> Optional["plt.Figure"]: + """ + Plot a 3D scatter in (log(RQ), Redundancy, Synergy) space. + + This is primarily intended for the vision paper's representative "cluster_3d_scatter.png". 
+ """ + if not HAS_MPL: + return None + + try: + from mpl_toolkits.mplot3d import Axes3D # noqa: F401 + except Exception: + return None + + log_rq = np.log(np.clip(np.asarray(rq).reshape(-1), 1e-10, None)) + red = np.asarray(redundancy).reshape(-1) + syn = np.asarray(synergy).reshape(-1) + lab = np.asarray(labels).reshape(-1).astype(int) + + n = int(log_rq.shape[0]) + if n == 0: + return None + + # Downsample for plot readability + if n > max_points: + rng = np.random.default_rng(0) + idx = rng.choice(np.arange(n), size=max_points, replace=False) + log_rq = log_rq[idx] + red = red[idx] + syn = syn[idx] + lab = lab[idx] + + fig = plt.figure(figsize=figsize) + ax = fig.add_subplot(111, projection="3d") + + for cid, ctype in type_mapping.items(): + mask = lab == int(cid) + if mask.sum() == 0: + continue + color = CLUSTER_COLORS.get(ctype, "#999999") + ax.scatter( + log_rq[mask], + red[mask], + syn[mask], + c=color, + label=ctype, + alpha=0.55, + s=10, + depthshade=False, + ) + + ax.set_xlabel("log(RQ)") + ax.set_ylabel("Redundancy") + ax.set_zlabel("Synergy") + ax.set_title(f"Metric Space Clusters (3D): {layer_name}") + ax.legend(loc="best") + + plt.tight_layout() + + if save_path: + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + plt.savefig(save_path, dpi=200, bbox_inches="tight") + logger.info(f"Saved 3D cluster scatter to {save_path}") + + return fig + + +def plot_pruning_by_cluster_type( + pruned: Dict[str, int], + total: Dict[str, int], + save_path: Optional[Path] = None, + title: str = "Pruned fraction by cluster type", + figsize: Tuple[int, int] = (7, 4), +) -> Optional["plt.Figure"]: + """Bar chart showing fraction pruned per cluster type.""" + if not HAS_MPL: + return None + + types = ["critical", "redundant", "synergistic", "background"] + frac = [] + for t in types: + denom = float(total.get(t, 0) or 0) + num = float(pruned.get(t, 0) or 0) + frac.append(num / denom if denom > 0 else 0.0) + + fig, ax = 
plt.subplots(figsize=figsize) + x = np.arange(len(types)) + colors = [CLUSTER_COLORS.get(t, "#999999") for t in types] + ax.bar(x, frac, color=colors, alpha=0.85) + ax.set_xticks(x) + ax.set_xticklabels([t.capitalize() for t in types]) + ax.set_ylim(0, 1.0) + ax.set_ylabel("Fraction pruned") + ax.set_title(title) + ax.grid(True, alpha=0.25, axis="y") + + for i, v in enumerate(frac): + ax.text(i, min(0.98, v + 0.03), f"{v:.2f}", ha="center", va="bottom", fontsize=10) + + plt.tight_layout() + + if save_path: + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + plt.savefig(save_path, dpi=200, bbox_inches="tight") + logger.info(f"Saved pruning-by-cluster plot to {save_path}") + + return fig + def plot_cluster_evolution( layer_results: List[Dict[str, Any]], save_path: Optional[Path] = None, diff --git a/src/alignment/analysis/visualization/paper_plots.py b/src/alignment/analysis/visualization/paper_plots.py new file mode 100644 index 00000000..5e665d56 --- /dev/null +++ b/src/alignment/analysis/visualization/paper_plots.py @@ -0,0 +1,312 @@ +""" +Paper-oriented plots for the SCAR LLM pruning draft. 
+ +These are intentionally lightweight and deterministic, meant to produce: +- Loss-proxy concentration plots (supernode heavy-tail) +- Halo structure plots (Conn vs redundancy/protection) +- Summary plots for the mechanism evidence section +- A simple schematic diagram of the SCAR pipeline +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union + +import matplotlib + +# Non-interactive backend for cluster jobs +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.patches import FancyArrowPatch, FancyBboxPatch + +logger = logging.getLogger(__name__) + + +def _to_numpy(x: Any) -> np.ndarray: + try: + import torch + + if isinstance(x, torch.Tensor): + return x.detach().cpu().numpy() + except Exception: + pass + if isinstance(x, np.ndarray): + return x + return np.asarray(x) + + +def _save(fig: plt.Figure, save_path: Union[str, Path], dpi: int = 300) -> None: + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(save_path, dpi=dpi, bbox_inches="tight") + logger.info(f"[Saved] {save_path}") + + +def plot_loss_proxy_concentration( + loss_proxy: Any, + rho: float = 0.01, + layer_label: str = "", + save_path: Optional[Union[str, Path]] = None, + dpi: int = 300, +) -> plt.Figure: + """ + Two-panel plot: + (Left) sorted LP values (heavy tail) + (Right) cumulative proxy mass vs fraction of channels kept + """ + lp = _to_numpy(loss_proxy).astype(np.float64).reshape(-1) + lp = lp[np.isfinite(lp)] + lp = np.maximum(lp, 0.0) + + fig, axes = plt.subplots(1, 2, figsize=(12, 4.0)) + if lp.size == 0: + for ax in axes: + ax.axis("off") + return fig + + rho = float(rho) + rho = min(max(rho, 1e-6), 0.5) + + lp_sorted = np.sort(lp)[::-1] + n = lp_sorted.size + k = max(1, int(round(rho * n))) + threshold = lp_sorted[k - 1] + + total = float(lp_sorted.sum()) if float(lp_sorted.sum()) > 0 else 1.0 + 
cum_mass = np.cumsum(lp_sorted) / total + frac = (np.arange(n) + 1) / float(n) + top_mass = float(cum_mass[k - 1]) + + # Panel A: sorted values + ax = axes[0] + ax.plot(frac, lp_sorted, color="#2c3e50", linewidth=1.5) + ax.axvline(x=rho, color="#c0392b", linestyle="--", linewidth=2, label=f"Top {rho*100:.1f}%") + ax.set_yscale("log") + ax.set_xlabel("Fraction of channels (sorted by LP)") + ax.set_ylabel("Loss proxy (LP)") + title = "Loss-proxy heavy tail" + if layer_label: + title += f"\n{layer_label}" + ax.set_title(title) + ax.grid(True, alpha=0.25) + ax.legend(loc="upper right") + + # Panel B: cumulative mass + ax = axes[1] + ax.plot(frac, cum_mass, color="#2980b9", linewidth=2.0) + ax.axvline(x=rho, color="#c0392b", linestyle="--", linewidth=2) + ax.scatter([rho], [top_mass], color="#c0392b", zorder=5) + ax.set_xlabel("Fraction of channels kept (top by LP)") + ax.set_ylabel("Cumulative LP mass") + ax.set_ylim(0, 1.02) + ax.set_title(f"Top {rho*100:.1f}% mass = {top_mass*100:.1f}%") + ax.grid(True, alpha=0.25) + + plt.tight_layout() + + if save_path is not None: + _save(fig, save_path, dpi=dpi) + return fig + + +def plot_halo_structure( + conn: Any, + redundancy_to_core: Any, + protect: Any, + super_mask: Any, + halo_mask: Any, + layer_label: str = "", + save_path: Optional[Union[str, Path]] = None, + dpi: int = 300, + max_points: int = 60000, +) -> plt.Figure: + """ + Two-panel plot: + (Left) Conn vs redundancy-to-core (halo channels) + (Right) Protect vs Conn (all channels; halo emphasized) + """ + conn_np = _to_numpy(conn).astype(np.float64).reshape(-1) + red_np = _to_numpy(redundancy_to_core).astype(np.float64).reshape(-1) + prot_np = _to_numpy(protect).astype(np.float64).reshape(-1) + super_np = _to_numpy(super_mask).astype(bool).reshape(-1) + halo_np = _to_numpy(halo_mask).astype(bool).reshape(-1) + + n = int(conn_np.size) + if n == 0: + fig, _ = plt.subplots(figsize=(10, 4)) + return fig + + # Downsample for plotting stability + idx_all = np.arange(n) + 
if n > max_points: + rng = np.random.default_rng(0) + idx_all = rng.choice(idx_all, size=max_points, replace=False) + + idx_halo = idx_all[halo_np[idx_all] & (~super_np[idx_all])] + idx_non = idx_all[(~halo_np[idx_all]) & (~super_np[idx_all])] + idx_sup = idx_all[super_np[idx_all]] + + fig, axes = plt.subplots(1, 2, figsize=(12, 4.2)) + + # Panel A: Conn vs redundancy-to-core (halo only, since redundancy is defined there) + ax = axes[0] + x = conn_np[idx_halo] + y = red_np[idx_halo] + finite = np.isfinite(x) & np.isfinite(y) + x = x[finite] + y = y[finite] + ax.scatter(x, y, s=10, alpha=0.35, color="#1f77b4", edgecolors="none") + ax.set_xlabel(r"Connectivity $\mathrm{Conn}$") + ax.set_ylabel(r"Redundancy to core $\mathrm{Red}^{\rightarrow \mathcal{M}}$") + title = "Halo redundancy structure" + if layer_label: + title += f"\n{layer_label}" + ax.set_title(title) + ax.grid(True, alpha=0.25) + if y.size > 0 and np.nanmin(y) > 0: + ax.set_yscale("log") + + # Panel B: Protect vs Conn (all channels) + ax = axes[1] + ax.scatter(conn_np[idx_non], prot_np[idx_non], s=6, alpha=0.15, color="#7f8c8d", label="Non-halo", edgecolors="none") + ax.scatter(conn_np[idx_halo], prot_np[idx_halo], s=10, alpha=0.35, color="#1f77b4", label="Halo", edgecolors="none") + if idx_sup.size > 0: + ax.scatter(conn_np[idx_sup], prot_np[idx_sup], s=14, alpha=0.7, color="#c0392b", label="Supernodes", edgecolors="none") + ax.set_xlabel(r"Connectivity $\mathrm{Conn}$") + ax.set_ylabel(r"Protection $\mathrm{Protect}$") + ax.set_title("Protection vs connectivity") + ax.set_ylim(-0.02, 1.02) + ax.grid(True, alpha=0.25) + ax.legend(loc="lower left", frameon=True) + + plt.tight_layout() + + if save_path is not None: + _save(fig, save_path, dpi=dpi) + return fig + + +def plot_supernode_halo_summary( + layer_indices: Sequence[int], + top_mass_ratios: Sequence[float], + halo_aggregate: Dict[str, Any], + rho: float = 0.01, + save_path: Optional[Union[str, Path]] = None, + dpi: int = 300, +) -> plt.Figure: + """ 
+ Two-panel plot: + (Left) top-rho LP mass ratio across layers + (Right) halo/non-halo redundancy summary bars (from halo_analysis.aggregate) + """ + layers = np.asarray(list(layer_indices), dtype=int) + ratios = np.asarray(list(top_mass_ratios), dtype=np.float64) + + fig, axes = plt.subplots(1, 2, figsize=(12, 4.0)) + + ax = axes[0] + ax.plot(layers, ratios, "o-", color="#2c3e50", linewidth=2) + ax.set_xlabel("Layer index") + ax.set_ylabel(f"Top-{rho*100:.1f}% LP mass ratio") + ax.set_ylim(0, 1.02) + ax.set_title("Supernode concentration across layers") + ax.grid(True, alpha=0.25) + + ax = axes[1] + groups = [("Within-Halo", "halo_halo"), ("Within-Non-Halo", "non_halo"), ("Cross", "cross")] + means = [] + stds = [] + for _, key in groups: + rec = halo_aggregate.get(key) or {} + means.append(float(rec.get("mean", 0.0))) + stds.append(float(rec.get("std", 0.0))) + + x = np.arange(len(groups)) + ax.bar(x, means, yerr=stds, capsize=4, color=["#1f77b4", "#7f8c8d", "#2ecc71"], alpha=0.85) + ax.set_xticks(x) + ax.set_xticklabels([g[0] for g in groups], rotation=15, ha="right") + ax.set_ylabel("Redundancy (Gaussian MI, nats)") + ax.set_title("Halo redundancy vs non-halo (avg.)") + ax.grid(True, alpha=0.25, axis="y") + + plt.tight_layout() + if save_path is not None: + _save(fig, save_path, dpi=dpi) + return fig + + +def plot_scar_schematic( + save_path: Optional[Union[str, Path]] = None, + dpi: int = 300, +) -> plt.Figure: + """ + Generate a simple schematic of SCAR (supernodes + halos) as a flowchart. + This is model-agnostic and can be generated during artifact collection. 
+ """ + fig = plt.figure(figsize=(12, 4.5)) + ax = fig.add_subplot(111) + ax.set_axis_off() + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + + def box(x, y, w, h, text, fc="#ecf0f1", ec="#2c3e50"): + p = FancyBboxPatch( + (x, y), + w, + h, + boxstyle="round,pad=0.02,rounding_size=0.02", + linewidth=1.6, + edgecolor=ec, + facecolor=fc, + ) + ax.add_patch(p) + ax.text(x + w / 2, y + h / 2, text, ha="center", va="center", fontsize=10) + + def arrow(x1, y1, x2, y2, color="#2c3e50"): + a = FancyArrowPatch((x1, y1), (x2, y2), arrowstyle="->", linewidth=1.6, color=color, mutation_scale=12) + ax.add_patch(a) + + # Left: FFN depiction (conceptual) + box(0.02, 0.62, 0.18, 0.26, "FFN layer\n(MLP channels)", fc="#f8f9f9") + ax.text(0.11, 0.71, "u channels", ha="center", va="center", fontsize=9) + # Draw "channels": a few vertical ticks, color some as supernodes/halo + for i, x in enumerate(np.linspace(0.05, 0.19, 9)): + c = "#7f8c8d" + if i in (2, 3): + c = "#c0392b" # supernodes + if i in (5, 6): + c = "#1f77b4" # halo + ax.plot([x, x], [0.66, 0.84], color=c, linewidth=3) + ax.text(0.03, 0.58, "Supernodes (red): high LP\nHalo (blue): high Conn + redundant", fontsize=9, ha="left", va="top") + + # Middle: compute steps + box(0.28, 0.70, 0.20, 0.20, "Calibration\nforward+backward", fc="#fdf2e9", ec="#d35400") + box(0.52, 0.70, 0.22, 0.20, r"Loss proxy\n$\mathrm{LP}_i=\frac12\mathbb{E}[(u_i s_i)^2]$", fc="#fdf2e9", ec="#d35400") + box(0.78, 0.70, 0.20, 0.20, r"Supernodes\n(top-$\rho$ by LP)\nprotect core", fc="#fdebd0", ec="#c0392b") + + arrow(0.20, 0.80, 0.28, 0.80) + arrow(0.48, 0.80, 0.52, 0.80) + arrow(0.74, 0.80, 0.78, 0.80) + + # Bottom: halo + redundancy + pruning + box(0.28, 0.35, 0.22, 0.20, r"Connectivity\n$\mathrm{Conn}_j$ from $|v_j|$ overlap", fc="#e8f6ff", ec="#2980b9") + box(0.54, 0.35, 0.20, 0.20, r"Halo\n(top-$\eta$ non-core by Conn)", fc="#e8f6ff", ec="#2980b9") + box(0.78, 0.35, 0.20, 0.20, r"Redundancy\n$\mathrm{Red}^{\rightarrow\mathcal{M}}$ from $q=u\!\odot\!s$", 
fc="#eafaf1", ec="#27ae60") + box(0.52, 0.06, 0.46, 0.20, r"Score + prune\n(prune low-$\mathrm{LP}$ first,\nboost halo followers; respect caps)", fc="#f8f9f9", ec="#2c3e50") + + arrow(0.62, 0.70, 0.39, 0.55) + arrow(0.50, 0.45, 0.54, 0.45) + arrow(0.74, 0.45, 0.78, 0.45) + arrow(0.88, 0.35, 0.75, 0.26) + arrow(0.64, 0.35, 0.64, 0.26) + + ax.text(0.02, 0.97, "SCAR schematic (supernodes + halos for structured FFN channel pruning)", fontsize=12, fontweight="bold", ha="left", va="top") + + plt.tight_layout() + if save_path is not None: + _save(fig, save_path, dpi=dpi) + return fig + diff --git a/src/alignment/configs/config_loader.py b/src/alignment/configs/config_loader.py index 4590a908..c4b9bc65 100644 --- a/src/alignment/configs/config_loader.py +++ b/src/alignment/configs/config_loader.py @@ -4,6 +4,7 @@ Supports both original format and unified format configs. """ +import ast import json import logging import os @@ -1214,7 +1215,8 @@ def load_config_with_overrides( elif low in {"none", "null"}: value = None else: - value = eval(value) + # Parse Python-literal values (lists, dicts, numbers, quoted strings) safely. 
+ value = ast.literal_eval(value) except Exception: pass # Keep as string diff --git a/src/alignment/experiments/llm_experiments.py b/src/alignment/experiments/llm_experiments.py index d8fff89d..3266618d 100644 --- a/src/alignment/experiments/llm_experiments.py +++ b/src/alignment/experiments/llm_experiments.py @@ -2272,11 +2272,11 @@ def compute_baseline_pruning_scores( m = re.search(r"layers\.(\d+)\.mlp", k) if m: layer_indices.add(int(m.group(1))) - + if not layer_indices: logger.warning("No MLP layers found in importance_scores; cannot compute baseline channel scores.") return {} - + underlying_model = self._get_underlying_model() module_dict = dict(underlying_model.named_modules()) @@ -2393,17 +2393,17 @@ def _resolve_mlp_path(layer_idx: int) -> Optional[str]: if store_name not in self.importance_scores: self.importance_scores[store_name] = {} self.importance_scores[store_name]["sparsegpt"] = channel_scores - + if store_name not in results: results[store_name] = {} results[store_name]["sparsegpt"] = channel_scores - + logger.debug( f"SparseGPT channel scores for {mlp_path}: shape={tuple(channel_scores.shape)}, mean={channel_scores.mean().item():.4f}" ) except Exception as e: logger.warning(f"Failed to compute SparseGPT channel scores for {mlp_path}: {e}") - + logger.info(f"SparseGPT: computed channel scores for {len(layer_indices)} MLP layers") except Exception as e: logger.error(f"SparseGPT calibration failed: {e}") @@ -4044,7 +4044,7 @@ def compute_halo_redundancy_within_hidden_outputs( """ (Legacy/diagnostic) Compute redundancy among *hidden-dimension* output neurons that are strongly influenced by supernodes. - + Note: This is NOT the SCAR paper's "directed redundancy" (which is defined on loss-relevant per-channel contribution signals). This helper is kept for exploratory plots and is not used for pruning decisions. 
@@ -4601,19 +4601,19 @@ def compute_supernode_connectivity_pruning_score( logger.info("Computing SCAR halo connectivity + protection pruning scores...") logger.info(f" Supernode fraction (rho): {supernode_fraction*100:.1f}%") logger.info(f" Halo fraction (eta): {high_connectivity_fraction*100:.1f}%") - + eps = 1e-8 results: Dict[str, Dict[str, Any]] = {} supernode_cfg = getattr(self.config, "supernode", {}) or getattr(self.config, "supernode_config", {}) or {} positive_redundancy = bool(supernode_cfg.get("positive_redundancy", False)) if positive_redundancy: logger.info(" Redundancy: using positive-only correlation (anti-correlation does NOT count as redundancy)") - + # Underlying HF model for module lookup / hook registration hf_model = self.model if hasattr(hf_model, "model"): hf_model = hf_model.model - + module_dict = dict(hf_model.named_modules()) # Calibration texts @@ -4623,7 +4623,7 @@ def compute_supernode_connectivity_pruning_score( if not calibration_texts: logger.warning("No calibration texts available for SCAR protection/connectivity computation") return {} - + # Determine which layers to process (down_proj layers only) layer_names = [ln for ln in scar_scores.keys() if "mlp.down_proj" in ln] if not layer_names: @@ -4647,19 +4647,19 @@ def compute_supernode_connectivity_pruning_score( m = lp_cpu.numel() if m == 0: continue - + module = module_dict.get(layer_name) if module is None or not hasattr(module, "weight"): logger.warning(f"SCAR connectivity: could not resolve module/weight for {layer_name}") continue - + # Identify supernodes by LP num_supernodes = max(1, int(supernode_fraction * m)) _, super_idx = torch.topk(lp_cpu, k=num_supernodes, largest=True) super_idx = super_idx.long() super_mask = torch.zeros(m, dtype=torch.bool) super_mask[super_idx] = True - + # Compute Conn_i from down_proj weights (write-pattern overlap) W = module.weight.detach().float().cpu() # [hidden_dim, m] abs_W = W.abs() @@ -4761,7 +4761,7 @@ def bwd_hook(mod: nn.Module, 
grad_input: Tuple[torch.Tensor, ...], grad_output: q_halo = q_sel[:, n_super:] # [N, |H|] N = q_sel.shape[0] - + # Initialize streaming sums on first batch if st["sum_q_super"] is None: st["sum_q_super"] = torch.zeros(q_super.shape[1], device=q_super.device, dtype=torch.float32) @@ -4876,6 +4876,13 @@ def bwd_hook(mod: nn.Module, grad_input: Tuple[torch.Tensor, ...], grad_output: protect_full[halo_idx] = protect_halo protect_full[super_idx] = 1.0 + # Store redundancy-to-core in full channel space (defined only for halo channels) + redundancy_full = torch.full((m,), float("nan"), dtype=torch.float32) + try: + redundancy_full[halo_idx] = redundancy_to_core.float() + except Exception: + pass + # SCAR-Prot and SCAR-Conn importance scores (high=keep) prot_score = (lp * protect_full).float() conn_score = (lp * ((1.0 - conn) + conn * protect_full)).float() @@ -4897,10 +4904,11 @@ def bwd_hook(mod: nn.Module, grad_input: Tuple[torch.Tensor, ...], grad_output: layer_scores["supernode_connectivity_score"] = conn_score layer_scores["connectivity_score"] = conn layer_scores["protection_score"] = protect_full + layer_scores["redundancy_to_core"] = redundancy_full layer_scores["halo_mask"] = halo_mask layer_scores["supernode_mask"] = super_mask self.importance_scores[layer_name] = layer_scores - + results[layer_name] = { "num_supernodes": int(super_idx.numel()), "num_halo": int(halo_idx.numel()), @@ -4909,7 +4917,7 @@ def bwd_hook(mod: nn.Module, grad_input: Tuple[torch.Tensor, ...], grad_output: "protect_halo_mean": float(protect_halo.mean().item()) if protect_halo.numel() else 0.0, "redundancy_to_core_mean": float(redundancy_to_core.mean().item()) if redundancy_to_core.numel() else 0.0, } - + logger.info(f"Computed SCAR protection/connectivity scores for {len(results)} layers") return results @@ -4924,7 +4932,7 @@ def analyze_halo_vs_nonhalo_redundancy( ) -> Dict[str, Dict[str, Any]]: """ Paper-aligned halo redundancy analysis using the loss-relevant contribution signal. 
- + We compare redundancy between three groups (per layer), then aggregate across layers: 1) **Halo-Halo**: both channels in the halo (high Conn to supernode write pattern) 2) **Non-halo**: both channels outside halo and outside supernodes @@ -4943,7 +4951,7 @@ def analyze_halo_vs_nonhalo_redundancy( - Supernodes are identified by `scar_loss_proxy` when available (paper definition). - Halo membership is identified by Conn overlap with the aggregated supernode write pattern (same as `compute_supernode_connectivity_pruning_score`). - + Returns: Dict with: - per_layer: per-layer group stats @@ -5036,7 +5044,7 @@ def sample_pairs_pos(n: int, p: int) -> Tuple[torch.Tensor, torch.Tensor]: lp = layer_metrics.get("scar_activation_power") if lp is None: continue - + lp_cpu = lp.detach().float().cpu() m = int(lp_cpu.numel()) if m <= 0: @@ -5046,14 +5054,14 @@ def sample_pairs_pos(n: int, p: int) -> Tuple[torch.Tensor, torch.Tensor]: if module is None or not hasattr(module, "weight"): logger.warning(f"Halo redundancy: could not resolve module/weight for {layer_name}") continue - + # Identify supernodes by LP (paper definition) num_supernodes = max(1, int(supernode_fraction * m)) _, super_idx = torch.topk(lp_cpu, k=num_supernodes, largest=True) super_idx = super_idx.long() super_mask = torch.zeros(m, dtype=torch.bool) super_mask[super_idx] = True - + # Compute Conn_i from down_proj weights (write-pattern overlap) W = module.weight.detach().float().cpu() # [hidden_dim, m] abs_W = W.abs() @@ -5069,7 +5077,7 @@ def sample_pairs_pos(n: int, p: int) -> Tuple[torch.Tensor, torch.Tensor]: num_halo = max(1, int(halo_fraction * non_super_idx.numel())) _, halo_rel = torch.topk(conn[non_super_idx], k=num_halo, largest=True) halo_idx = non_super_idx[halo_rel].long() - + halo_mask = torch.zeros(m, dtype=torch.bool) halo_mask[halo_idx] = True non_halo_idx = ((~super_mask) & (~halo_mask)).nonzero(as_tuple=True)[0].long() @@ -5145,7 +5153,7 @@ def sample_pairs_pos(n: int, p: int) -> 
Tuple[torch.Tensor, torch.Tensor]: # Phase 2: Calibration passes (forward+backward) to accumulate q correlations # ------------------------------------------------------------------ hooks: List[Any] = [] - + def make_hooks(name: str): def fwd_hook(mod: nn.Module, inputs: Tuple[torch.Tensor, ...], output: torch.Tensor): if not inputs or inputs[0] is None: @@ -5282,7 +5290,7 @@ def bwd_hook(mod: nn.Module, grad_input: Tuple[torch.Tensor, ...], grad_output: h.remove() except Exception: pass - + # ------------------------------------------------------------------ # Phase 3: Compute redundancy distributions and aggregate across layers # ------------------------------------------------------------------ @@ -5300,7 +5308,7 @@ def corr_to_red(corr: torch.Tensor) -> torch.Tensor: N = int(st.get("count", 0)) if N <= 1 or st["sum_qij_hh"] is None: continue - + sum_q_h = st["sum_q_halo"].detach().cpu() sum_q2_h = st["sum_q2_halo"].detach().cpu() sum_q_n = st["sum_q_nonhalo"].detach().cpu() @@ -5324,7 +5332,7 @@ def corr_to_red(corr: torch.Tensor) -> torch.Tensor: e_hh = st["sum_qij_hh"].detach().cpu() / float(N) e_nn = st["sum_qij_nn"].detach().cpu() / float(N) e_cn = st["sum_qij_cross"].detach().cpu() / float(N) - + # corr (halo-halo) cov = e_hh - (mean_h[hh_i] * mean_h[hh_j]) corr_hh = cov / (std_h[hh_i] * std_h[hh_j] + eps) @@ -5362,7 +5370,7 @@ def stats(x: torch.Tensor) -> Dict[str, float]: agg_vals["halo_halo"].extend(red_hh.tolist()) agg_vals["non_halo"].extend(red_nn.tolist()) agg_vals["cross"].extend(red_cn.tolist()) - + aggregate_stats: Dict[str, Dict[str, Any]] = {} for group, vals in agg_vals.items(): if not vals: @@ -5375,12 +5383,12 @@ def stats(x: torch.Tensor) -> Dict[str, float]: "median": float(np.median(arr)), "count": int(arr.size), } - + logger.info("\nHALO vs NON-HALO REDUNDANCY SUMMARY (q-signal):") logger.info(f" Halo-Halo: mean={aggregate_stats['halo_halo']['mean']:.4f}") logger.info(f" Non-halo: mean={aggregate_stats['non_halo']['mean']:.4f}") 
logger.info(f" Cross-group: mean={aggregate_stats['cross']['mean']:.4f}") - + return { "signal": "q", "positive_redundancy": positive_redundancy, @@ -7224,6 +7232,159 @@ def restore_weights(): except Exception as e: logger.error(f"Failed to generate pruning visualizations: {e}") + # ------------------------------------------------------------------ + # Paper-oriented mechanism figures (supernodes + halo structure) + # ------------------------------------------------------------------ + if getattr(self.config, "generate_plots", True): + try: + from alignment.analysis.visualization import ( + plot_halo_structure, + plot_loss_proxy_concentration, + plot_supernode_halo_summary, + ) + + plots_dir = Path(getattr(self.config, "plots_dir", Path(self.config.log_dir) / "plots")) + paper_dir = plots_dir / "paper" + paper_dir.mkdir(parents=True, exist_ok=True) + + # 1) Loss proxy concentration for a representative layer + rho = float((getattr(self.config, "supernode", {}) or {}).get("core_fraction", 0.01)) + down_layers = sorted([k for k in scar_scores.keys() if "mlp.down_proj" in k]) + if down_layers: + # Choose a stable "middle" layer as representative + mid_layer = down_layers[len(down_layers) // 2] + lp = scar_scores.get(mid_layer, {}).get("scar_loss_proxy") + if lp is not None: + plot_loss_proxy_concentration( + loss_proxy=lp, + rho=rho, + layer_label=mid_layer, + save_path=paper_dir / "fig_supernode_distribution.png", + dpi=getattr(self.config, "plot_dpi", 300), + ) + + # 2) Halo structure (global): aggregate across many layers for a cleaner story + if down_layers: + conn_all = [] + prot_all = [] + red_all = [] + halo_all = [] + super_all = [] + for ln in down_layers: + layer_scores = self.importance_scores.get(ln, {}) + conn = layer_scores.get("connectivity_score") + prot = layer_scores.get("protection_score") + red = layer_scores.get("redundancy_to_core") + halo_mask = layer_scores.get("halo_mask") + super_mask = layer_scores.get("supernode_mask") + if ( + conn is None 
+ or prot is None + or red is None + or halo_mask is None + or super_mask is None + ): + continue + # Ensure consistent shapes + try: + if conn.numel() == 0 or conn.numel() != prot.numel() or conn.numel() != halo_mask.numel(): + continue + if red.numel() != conn.numel() or super_mask.numel() != conn.numel(): + continue + except Exception: + continue + + conn_all.append(conn.detach().cpu()) + prot_all.append(prot.detach().cpu()) + red_all.append(red.detach().cpu()) + halo_all.append(halo_mask.detach().cpu()) + super_all.append(super_mask.detach().cpu()) + + if conn_all: + import torch + + conn_cat = torch.cat(conn_all, dim=0) + prot_cat = torch.cat(prot_all, dim=0) + red_cat = torch.cat(red_all, dim=0) + halo_cat = torch.cat(halo_all, dim=0) + super_cat = torch.cat(super_all, dim=0) + + plot_halo_structure( + conn=conn_cat, + redundancy_to_core=red_cat, + protect=prot_cat, + super_mask=super_cat, + halo_mask=halo_cat, + layer_label="All layers (aggregated)", + save_path=paper_dir / "fig_halo_structure.png", + dpi=getattr(self.config, "plot_dpi", 300), + ) + + # 2b) Halo structure (example layer): keep a representative layer for debugging/supplementary + if down_layers: + mid_layer = down_layers[len(down_layers) // 2] + layer_scores = self.importance_scores.get(mid_layer, {}) + conn = layer_scores.get("connectivity_score") + prot = layer_scores.get("protection_score") + red = layer_scores.get("redundancy_to_core") + halo_mask = layer_scores.get("halo_mask") + super_mask = layer_scores.get("supernode_mask") + if conn is not None and prot is not None and red is not None and halo_mask is not None and super_mask is not None: + plot_halo_structure( + conn=conn, + redundancy_to_core=red, + protect=prot, + super_mask=super_mask, + halo_mask=halo_mask, + layer_label=mid_layer, + save_path=paper_dir / "fig_halo_structure_layer.png", + dpi=getattr(self.config, "plot_dpi", 300), + ) + + # 3) Supernode mass ratio across layers + halo redundancy summary + try: + halo_agg = 
(results.get("halo_analysis") or {}).get("aggregate") or {} + # Compute top-rho mass ratio per layer from scar_loss_proxy + layer_idxs: List[int] = [] + ratios: List[float] = [] + for ln in down_layers: + lp = scar_scores.get(ln, {}).get("scar_loss_proxy") + if lp is None: + continue + lp_cpu = lp.detach().float().cpu() + m = int(lp_cpu.numel()) + if m <= 0: + continue + k = max(1, int(round(rho * m))) + top = torch.topk(lp_cpu, k=k, largest=True).values + denom = float(lp_cpu.sum().item()) if float(lp_cpu.sum().item()) > 0 else 1.0 + ratio = float(top.sum().item()) / denom + try: + idx = int(ln.split("layers.")[-1].split(".")[0]) + except Exception: + idx = len(layer_idxs) + layer_idxs.append(idx) + ratios.append(ratio) + + if layer_idxs and halo_agg: + # Sort by layer index for plotting + order = np.argsort(np.asarray(layer_idxs)) + layer_idxs_sorted = [layer_idxs[i] for i in order] + ratios_sorted = [ratios[i] for i in order] + plot_supernode_halo_summary( + layer_indices=layer_idxs_sorted, + top_mass_ratios=ratios_sorted, + halo_aggregate=halo_agg, + rho=rho, + save_path=paper_dir / "fig_supernode_analysis.png", + dpi=getattr(self.config, "plot_dpi", 300), + ) + except Exception as _summary_err: + logger.debug(f"Paper summary plot skipped: {_summary_err}") + + except Exception as e: + logger.warning(f"Failed to generate paper mechanism figures: {e}") + return results From 5cf00922f912dbc15cc5fa8bfd66e913213b2a77 Mon Sep 17 00:00:00 2001 From: Houman Safaai Date: Fri, 9 Jan 2026 15:33:19 -0500 Subject: [PATCH 10/12] fix nan/inf --- slurm-54745862.out | 37 ------------------------------------- 1 file changed, 37 deletions(-) delete mode 100644 slurm-54745862.out diff --git a/slurm-54745862.out b/slurm-54745862.out deleted file mode 100644 index b8070148..00000000 --- a/slurm-54745862.out +++ /dev/null @@ -1,37 +0,0 @@ -============================================== -Submitting SCAR Paper Experiments -============================================== - -Output 
directory: /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM - -Submitting LLaMA-3.1-8B (main results)... - Job ID: 54745874 -Submitting Mistral-7B (generalization)... - Job ID: 54745875 -Submitting LLaMA-2-7B (generalization)... - Job ID: 54745878 -Submitting Qwen2-7B (generalization)... - Job ID: 54745879 - -============================================== -All jobs submitted! -============================================== - -Job IDs: 54745874, 54745875, 54745878, 54745879 - -Monitor with: - squeue -u $USER - -View SLURM logs: - tail -f logs/paper_llama3_8b_54745874.out - tail -f logs/paper_mistral_7b_54745875.out - tail -f logs/paper_llama2_7b_54745878.out - tail -f logs/paper_qwen2_7b_54745879.out - -Expected runtime: ~6-8 hours per job - -Results will be in: - /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/llama3_8b_paper_results_*_54745874/ - /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/mistral_7b_paper_results_*_54745875/ - /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/llama2_7b_paper_results_*_54745878/ - /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/qwen2_7b_paper_results_*_54745879/ From 99b98391267748d59ddf73c699540feffd9fd3b1 Mon Sep 17 00:00:00 2001 From: Houman Safaai Date: Fri, 9 Jan 2026 15:39:08 -0500 Subject: [PATCH 11/12] fix nan/inf --- .gitignore | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.gitignore b/.gitignore index 29d16233..ba8ce996 100644 --- a/.gitignore +++ b/.gitignore @@ -177,6 +177,10 @@ Thumbs.db *.swo *~ +# SLURM default output files (created when submitting scripts without explicit --output) +slurm-*.out +slurm-*.err + # OS .DS_Store .DS_Store? 
From f0dd89e74466acbe79f9ee6c1b91eaee9a49ac05 Mon Sep 17 00:00:00 2001 From: Houman Safaai Date: Fri, 9 Jan 2026 19:38:58 -0500 Subject: [PATCH 12/12] fix import --- src/alignment/experiments/llm_experiments.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/alignment/experiments/llm_experiments.py b/src/alignment/experiments/llm_experiments.py index 3266618d..f9c3aa93 100644 --- a/src/alignment/experiments/llm_experiments.py +++ b/src/alignment/experiments/llm_experiments.py @@ -7301,8 +7301,6 @@ def restore_weights(): super_all.append(super_mask.detach().cpu()) if conn_all: - import torch - conn_cat = torch.cat(conn_all, dim=0) prot_cat = torch.cat(prot_all, dim=0) red_cat = torch.cat(red_all, dim=0)