Add CaloClusterGNN training pipeline #7

Open
zwl0331 wants to merge 2 commits into Mu2e:main from zwl0331:add-calo-cluster-gnn

Conversation


zwl0331 commented May 4, 2026

Summary

Adds a new CaloClusterGNN/ subdirectory containing the full training
pipeline for a Graph Neural Network calorimeter-clustering algorithm,
intended to run alongside the existing seed+BFS CaloClusterMaker in
Mu2e Offline. The deployed recipe is CCN+BFS10: CaloClusterNet
edge classifier followed by BFS-style traversal at ExpandCut = 10 MeV.

Layout follows the convention set by TrkQual/ and TrkPID/ — one
top-level subdirectory per algorithm, self-contained.
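The CCN+BFS10 recipe is only named above, so here is a minimal sketch of one plausible reading: the edge classifier prunes the graph at its frozen threshold, and a seed+BFS pass then walks the surviving edges, expanding only through hits at or above ExpandCut. The function name and the exact role of ExpandCut are assumptions for illustration, not the deployed code.

```python
from collections import deque

def ccn_bfs_clusters(hit_energy, kept_edges, expand_cut=10.0):
    """Hypothetical sketch of CCN+BFS10: cluster hits by BFS over
    classifier-approved edges (score >= tau_edge upstream), expanding
    the frontier only through hits with energy >= expand_cut (MeV)."""
    # Adjacency restricted to the edges the classifier kept.
    adj = {i: set() for i in range(len(hit_energy))}
    for a, b in kept_edges:
        adj[a].add(b)
        adj[b].add(a)

    seen, clusters = set(), []
    # Seed from the highest-energy unclustered hit, as in seed+BFS makers.
    for seed in sorted(range(len(hit_energy)), key=lambda i: -hit_energy[i]):
        if seed in seen:
            continue
        cluster, queue = [], deque([seed])
        seen.add(seed)
        while queue:
            i = queue.popleft()
            cluster.append(i)
            # Only hits at or above ExpandCut keep propagating the BFS.
            if hit_energy[i] < expand_cut:
                continue
            for j in adj[i]:
                if j not in seen:
                    seen.add(j)
                    queue.append(j)
        clusters.append(sorted(cluster))
    return clusters
```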

What's in here

CaloClusterGNN/
├── README.md            retrainer-facing docs
├── setup_env.sh         wraps setupmu2e-art.sh + ana 2.6.1
├── src/
│   ├── data/            graph builder, calo-entrant truth labels,
│   │                    normalisation, packed dataset
│   ├── geometry/        crystalId -> (x, y, disk) loader
│   ├── models/          SimpleEdgeNet, CaloClusterNet, layers, heads,
│   │                    deploy wrappers for ONNX export
│   ├── training/        losses, metrics, trainer
│   └── inference/       cluster_reco + postprocess (used by train-time
│                        evaluation scripts)
├── scripts/             build/pack/train/tune/evaluate pipeline,
│                        ONNX export + parity validation,
│                        failure audits, cluster-physics evaluation,
│                        Run1B no-field generalisation evaluation
├── configs/             5 YAML configs (one per training run)
├── tests/               110 unit tests (4 conditionally skipped on a
│                        fresh checkout for missing checkpoints)
├── splits/              frozen 35/7/8 v2 file split
└── data/                small geometry CSVs (crystal_geometry.csv etc.)

Two model classes train from the same pipeline:

| Model | Params | Frozen tau_edge | Use |
| --- | --- | --- | --- |
| SimpleEdgeNet | 215 K | 0.26 | Reference / A-B comparison |
| CaloClusterNet | 676 K | 0.20 | Production model (CCN+BFS10) |

Both share the input graph (one per calorimeter disk per event,
6 node features + 8 edge features) and z-score normalisation, so
swapping models in deployment is config-only.
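As a sketch of the shared normalisation step: a generic z-score with statistics frozen on the training split, applied identically to both model classes. Function names are illustrative; the real implementation lives in src/data/.

```python
import numpy as np

def fit_zscore(train_feats):
    """Fit per-feature mean/std on the training split only
    (sketch; the real stats are frozen and shipped with the model)."""
    mean = train_feats.mean(axis=0)
    std = train_feats.std(axis=0)
    std[std == 0] = 1.0  # guard against constant features
    return mean, std

def apply_zscore(feats, mean, std):
    """Apply the frozen training-split statistics to any split."""
    return (feats - mean) / std
```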

Headline result

On the MDC2025 mixed-pileup test set (276,688 events, 481,543
disk-graphs), CCN+BFS10 beats BFS on every downstream-relevant
cluster-physics metric for E_reco >= 50 MeV clusters (those that
matter for track finding):

| Metric | BFS | CCN+BFS10 | Change |
| --- | --- | --- | --- |
| Mean abs(dE) / MeV | 0.839 | 0.616 | -27% |
| 95th-pct abs(dE) / MeV | 3.520 | 2.338 | -34% |
| Mean centroid dr / mm | 1.589 | 1.292 | -19% |
| 95th-pct dr / mm | 3.606 | 2.294 | -36% |
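For orientation, the table's summary statistics could be computed along these lines. This is a hedged sketch: the actual definitions live in scripts/evaluate_cluster_physics.py and may differ, e.g. in how reco and truth clusters are matched.

```python
import numpy as np

def cluster_physics_summary(e_reco, e_true, centroid_reco, centroid_true):
    """Hypothetical sketch of the table's metrics: energy residual
    |dE| (MeV) and centroid displacement dr (mm), summarised by the
    mean and 95th percentile over matched clusters."""
    abs_de = np.abs(np.asarray(e_reco) - np.asarray(e_true))
    dr = np.linalg.norm(np.asarray(centroid_reco) - np.asarray(centroid_true), axis=1)
    return {
        "mean_abs_dE": float(abs_de.mean()),
        "p95_abs_dE": float(np.percentile(abs_de, 95)),
        "mean_dr": float(dr.mean()),
        "p95_dr": float(np.percentile(dr, 95)),
    }
```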

In the 95-110 MeV signal region (47,279 clusters), mean abs(dE)
drops from 0.368 MeV (BFS) to 0.210 MeV (-43%) and mean dr drops
from 0.559 mm to 0.460 mm (-18%).

Reproducibility

After running `source setup_env.sh`:

# 1. Build per-disk graphs from EventNtuple ROOT files (~10 min CPU).
bash scripts/build_all_graphs.sh
python3 scripts/pack_graphs.py

# 2. Train (CCN; production model).
python3 scripts/train_gnn.py \
    --config configs/calo_cluster_net.yaml \
    --device cuda --run-name calo_cluster_net_v2_stage1

# 3. Tune the edge threshold on val.
python3 scripts/tune_threshold.py \
    --config configs/calo_cluster_net.yaml \
    --checkpoint outputs/runs/calo_cluster_net_v2_stage1/checkpoints/best_model.pt

# 4. Evaluate once on test.
OMP_NUM_THREADS=4 PYTHONUNBUFFERED=1 python3 -u scripts/evaluate_test.py
OMP_NUM_THREADS=4 PYTHONUNBUFFERED=1 python3 -u scripts/evaluate_cluster_physics.py

# 5. Export to ONNX for deployment.
python3 scripts/export_onnx.py --model ccn   # also --model sen
python3 scripts/export_norm_stats.py
python3 scripts/validate_onnx.py --model ccn
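The parity check in step 5 presumably reduces to two numbers per model: the worst-case numeric drift and the count of threshold flips. A sketch of that comparison logic, with hypothetical names, applied to score arrays from the two backends:

```python
import numpy as np

def parity_report(torch_scores, onnx_scores, tau_edge, tol=1e-5):
    """Hypothetical sketch of the validate_onnx.py checks: the exported
    graph must agree with PyTorch within a numeric tolerance, and no
    edge may change sides of the frozen threshold tau_edge."""
    max_abs_diff = float(np.max(np.abs(torch_scores - onnx_scores)))
    flips = int(np.sum((torch_scores >= tau_edge) != (onnx_scores >= tau_edge)))
    return {
        "max_abs_diff": max_abs_diff,
        "flips": flips,
        "ok": max_abs_diff <= tol and flips == 0,
    }
```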

Frozen hyperparameters and the exact recipe values live in configs/
and are documented in the README.
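Since both models read the same graphs and normalisation, a run config plausibly only needs to pin the model class and its frozen recipe values. A hypothetical fragment for orientation only; the key names are invented, and the authoritative values live in configs/:

```yaml
# Hypothetical config fragment -- key names are illustrative only;
# the real schema lives in configs/calo_cluster_net.yaml.
model:
  name: CaloClusterNet      # swap to SimpleEdgeNet for the A-B reference
  tau_edge: 0.20            # frozen edge threshold from tune_threshold.py
clustering:
  expand_cut_mev: 10.0      # the "BFS10" in CCN+BFS10
features:
  node: 6
  edge: 8
```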

Coordinated PRs

  • Mu2e/EventNtuple#TBD — adds calomcsim.ancestorSimIds to
    SimInfo. The v2 training data uses calo-entrant ancestor truth,
    which requires this branch. Link this PR into the EventNtuple PR
    once it has a number.
  • Mu2e/Offline#TBD (pending) — the C++ inference modules
    (CaloHitGraphMaker, CaloClusterMakerGNN) under
    Offline/CaloCluster/. Loads the .onnx exported by this repo
    via art::ConfigFileLookupPolicy, asserts metadata_props
    agreement (model_version, node_features, edge_features)
    against FHiCL, and emits CaloClusterCollection under instance
    name "GNN" so existing BFS-reading analyses keep working.
    C++↔Python parity has been validated byte-exactly on the val
    split (100/100 disk-graphs, 8,502 hits) using a parity-dump
    analyzer + Python comparison harness.
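The metadata_props agreement described above amounts to a key-by-key contract check between the stamped .onnx and the FHiCL configuration. A Python-side sketch of the same logic (function and argument names are hypothetical; the real assertion lives in the C++ module):

```python
def check_deployment_contract(onnx_meta, fhicl_cfg):
    """Hypothetical sketch of the metadata_props contract: every
    stamped key must be present in the FHiCL config with the same
    value, otherwise the module should refuse to run."""
    required = ("model_version", "node_features", "edge_features")
    mismatches = {k: (onnx_meta.get(k), fhicl_cfg.get(k))
                  for k in required if onnx_meta.get(k) != fhicl_cfg.get(k)}
    if mismatches:
        raise ValueError(f"metadata_props disagree with FHiCL: {mismatches}")
    return True
```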

Tests

$ source setup_env.sh
$ python3 -m unittest discover -s tests -p "test_*.py" -v
...                                                                                                                                                         
Ran 110 tests in 0.16s
OK (skipped=4)                                                                                                                                              

The 4 skipped tests are conditional — they exercise loading a real
trained checkpoint or the exported .onnx and self-skip with a clear
message when those files aren't in the local checkout (the case for a
fresh clone).
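The self-skip behaviour described here matches unittest's stock conditional-skip decorator, which names the missing file in the skip message. A minimal sketch (the checkpoint path is taken from the commands above; the test body is invented):

```python
import os
import unittest

# Path from the training commands above; absent on a fresh clone.
CKPT = "outputs/runs/calo_cluster_net_v2_stage1/checkpoints/best_model.pt"

class TestCheckpointLoad(unittest.TestCase):
    # Skips, with a clear message, unless a trained checkpoint exists locally.
    @unittest.skipUnless(os.path.exists(CKPT), f"missing checkpoint: {CKPT}")
    def test_checkpoint_is_nonempty(self):
        self.assertGreater(os.path.getsize(CKPT), 0)
```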

Acknowledgement

Implementation, refactoring, and documentation drafting in this
subdirectory were assisted by Anthropic's Claude (Claude Code). All
scientific decisions, hyperparameter choices, validation results, and
the v1→v2 truth-definition campaign are my own work.



zwl0331 added 2 commits May 4, 2026 15:13
Adds a CaloClusterGNN/ subdirectory containing the full training
pipeline for the GNN calorimeter-clustering algorithm intended as a
parallel to the existing seed+BFS CaloClusterMaker in Mu2e Offline.
The deployed recipe is "CCN+BFS10": CaloClusterNet edge classifier +
BFS-style traversal with ExpandCut = 10 MeV.

Layout (modelled on TrkQual/, but Python-package shaped):

  CaloClusterGNN/
    README.md                   how to retrain, frozen hyperparams,
                                deployment cross-link
    setup_env.sh                wraps setupmu2e-art.sh + ana 2.6.1
    src/
      data/                     graph builder, calo-entrant truth labels,
                                normalisation, packed dataset
      geometry/                 crystalId -> (x, y, disk) loader
      models/                   SimpleEdgeNet, CaloClusterNet, layers, heads
      training/                 losses, metrics, trainer
      inference/                cluster reconstruction (cluster_reco.py),
                                postprocess (kept here so train-time eval
                                scripts work end to end)
    scripts/                    build/pack/train/tune/evaluate pipeline,
                                failure audits, cluster-physics eval, ancestry
                                validation, run1B no-field eval, plotting
    configs/                    five YAML configs (one per training run)
    tests/                      88 unit tests covering all of src/ above
    splits/                     frozen 35/7/8 v2 split file lists
    data/                       crystal_geometry.csv + crystal_neighbors.csv +
                                crystal_map_raw.csv (small lookup tables)

What does NOT live here:
* The deployment-side ONNX export / parity scripts (export_onnx.py,
  export_norm_stats.py, validate_onnx.py, dump_parity_payloads.py,
  compare_parity_dump.py) and the deploy wrappers
  (calo_cluster_net_deploy.py, simple_edge_net_deploy.py) -- those
  belong with the Mu2e/Offline integration PR, not the training repo.
* The `.onnx` artifacts themselves (shipped via the Mu2e data area, not
  versioned in MLTrain -- the same convention TrkQual follows).
* Large run outputs and processed graphs (regenerable from
  EventNtuple ROOT files via scripts/build_all_graphs.sh).

The v2 training data requires the `calomcsim.ancestorSimIds` branch
added in Mu2e/EventNtuple (PR pending). README cross-links there
once the EventNtuple PR has a number.

Test suite: 88/88 passing in this layout via
`python3 -m unittest discover -s tests -p "test_*.py" -v` after
`source setup_env.sh`.

Both trained models in CaloClusterGNN/ now have a complete
training-to-ONNX path inside MLTrain (consistent with the TrkQual
pattern of shipping conversion scripts alongside training).

New / restored:
* src/models/calo_cluster_net_deploy.py   tensor-API wrapper around
  CaloClusterNet (no PyG Data, no node-saliency head); used by ONNX
  export so torch.onnx.export can trace it.
* src/models/simple_edge_net_deploy.py    same shape for SimpleEdgeNet.
  No node head to bypass, so it's a thin pass-through.
* scripts/export_onnx.py                  --model {ccn,sen} flag with
  per-model presets (checkpoint, output path, model_version). Stamps
  metadata_props {model_version, node_features, edge_features} into
  the .onnx after export.
* scripts/export_norm_stats.py            writes the train-split z-score
  stats next to the .onnx as a flat JSON sidecar so the C++ side
  doesn't need a LibTorch dep to read 28 floats.
* scripts/validate_onnx.py                --model flag with per-model
  preset for tau_edge and tolerance. Asserts:
    - max abs-diff edge_logits within tol on the full val split
    - zero per-edge threshold flips at tau_edge (proxy for cluster-
      reco byte-equivalence with the deployed C++ pipeline)
* tests/test_calo_cluster_net_deploy.py   (9 tests)
* tests/test_export_onnx.py               (5 tests)
* tests/test_export_norm_stats.py         (8 tests)
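The flat JSON sidecar from export_norm_stats.py can be sketched with nothing but the standard library, which is exactly why the C++ side needs no LibTorch dependency to read it. Function name, key names, and layout here are assumptions; only the 28-float count ((6 node + 8 edge) features x mean and std) comes from the text above.

```python
import json

def write_norm_sidecar(path, node_mean, node_std, edge_mean, edge_std):
    """Hypothetical sketch of the norm-stats sidecar written next to
    the .onnx: 28 floats total, flat enough for any C++ JSON parser."""
    assert len(node_mean) == len(node_std) == 6   # 6 node features
    assert len(edge_mean) == len(edge_std) == 8   # 8 edge features
    stats = {
        "node_mean": list(map(float, node_mean)),
        "node_std": list(map(float, node_std)),
        "edge_mean": list(map(float, edge_mean)),
        "edge_std": list(map(float, edge_std)),
    }
    with open(path, "w") as f:
        json.dump(stats, f, indent=2)
    return stats
```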

README extended with an "Exporting a Trained Model to ONNX" section
that documents the full chain for both models, the
metadata_props deployment contract, and the per-model frozen
tau_edge/tol values used by validate_onnx.py.

Test count goes from 88 to 110 (4 conditionally skipped on a fresh
checkout when no trained checkpoint is present locally; this is by
design and the skip messages name the missing file).

Also acknowledges Claude assistance in README.
