diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Cluster/Cluster.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AHDCCluster/AHDCCluster.java similarity index 72% rename from reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Cluster/Cluster.java rename to reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AHDCCluster/AHDCCluster.java index f0df2e98da..26f5a016fa 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Cluster/Cluster.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AHDCCluster/AHDCCluster.java @@ -1,4 +1,4 @@ -package org.jlab.rec.ahdc.Cluster; +package org.jlab.rec.ahdc.AHDCCluster; import org.jlab.rec.ahdc.Hit.Hit; import org.jlab.rec.ahdc.PreCluster.PreCluster; @@ -9,9 +9,9 @@ import org.jlab.geom.prim.Point3D; /** - * Cluster are compose by 2 PreCluster on layer with a different stereo angle + * AHDCCluster are compose by 2 PreCluster on layer with a different stereo angle */ -public class Cluster { +public class AHDCCluster { private int _trackId = -1; private double _Radius; @@ -50,7 +50,7 @@ private static double stereoTwistFromLine(Line3D line) { return wrapPi(phi1 - phi0); } - public Cluster(PreCluster precluster, PreCluster other_precluster) { + public AHDCCluster(PreCluster precluster, PreCluster other_precluster) { this._PreClusters_list = new ArrayList<>(); _PreClusters_list.add(precluster); _PreClusters_list.add(other_precluster); @@ -80,15 +80,44 @@ public Cluster(PreCluster precluster, PreCluster other_precluster) { this._V = this._Y / (this._X * this._X + this._Y * this._Y); } - public Cluster(double X, double Y, double Z) { + public AHDCCluster(double X, double Y, double Z) { this._X = X; this._Y = Y; this._Z = Z; } + /** Build an AHDCCluster from a single PreCluster (one layer of a superlayer). + * Used by the GNN path when a track covers a superlayer on only one + * stereo layer — no stereo pair is available, so Z is taken from the + * average wire-midpoint z of the PreCluster's hits rather than from a + * stereo-angle computation. DocaClusterRefiner falls back to a degenerate + * DocaCluster when {@code get_PreClusters_list().size() != 2}, so + * downstream is unaffected. */ + public AHDCCluster(PreCluster precluster) { + this._PreClusters_list = new ArrayList<>(); + _PreClusters_list.add(precluster); + this._Radius = precluster.get_Radius(); + this._Phi = precluster.get_Phi(); + this._X = precluster.get_X(); + this._Y = precluster.get_Y(); + this._Num_wire = (int) precluster.get_Num_wire(); + double r2 = this._X * this._X + this._Y * this._Y; + if (r2 > 0.0) { + this._U = this._X / r2; + this._V = this._Y / r2; + } + double zSum = 0.0; + int zCount = 0; + for (Hit h : precluster.get_hits_list()) { + Line3D line = h.getLine(); + if (line != null) { zSum += line.midpoint().z(); zCount++; } + } + this._Z = (zCount > 0) ? zSum / zCount : 0.0; + } + @Override public String toString() { - return "Cluster{" + "_X=" + _X + ", _Y=" + _Y + ", _Z=" + _Z + '}'; + return "AHDCCluster{" + "_X=" + _X + ", _Y=" + _Y + ", _Z=" + _Z + '}'; } public ArrayList get_PreClusters_list() { diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Cluster/ClusterFinder.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AHDCCluster/AHDCClusterFinder.java similarity index 86% rename from reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Cluster/ClusterFinder.java rename to reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AHDCCluster/AHDCClusterFinder.java index 87a4446b1c..36f85cae84 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Cluster/ClusterFinder.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AHDCCluster/AHDCClusterFinder.java @@ -1,21 +1,21 @@ -package org.jlab.rec.ahdc.Cluster; +package org.jlab.rec.ahdc.AHDCCluster; import org.jlab.rec.ahdc.PreCluster.PreCluster; import java.util.ArrayList; import java.util.List; -/** ClusterFinder +/** AHDCClusterFinder * * \todo description of what it does and how it works * */ -public class ClusterFinder { +public class AHDCClusterFinder { - private final ArrayList _AHDCClusters = new ArrayList<>(); - private final ArrayList _list_with_maybe_same_cluster = new ArrayList<>(); + private final ArrayList _AHDCClusters = new ArrayList<>(); + private final ArrayList _list_with_maybe_same_cluster = new ArrayList<>(); - public ClusterFinder() {} + public AHDCClusterFinder() {} private void find_associate_cluster(PreCluster precluster, List AHDC_precluster_list, int window, int minimal_distance, int super_layer, int layer, int associate_super_layer) { //System.out.println(" precluster superlayer " + precluster.get_Super_layer() + " ref superlayer " + super_layer + " layer " + precluster.get_Layer() + " ref " + layer); @@ -52,7 +52,7 @@ private void find_associate_cluster(PreCluster precluster, List AHDC if (best_precluster != null) { precluster.set_Used(true); best_precluster.set_Used(true); - Cluster new_Cluster = new Cluster(precluster, best_precluster); + AHDCCluster new_Cluster = new AHDCCluster(precluster, best_precluster); _list_with_maybe_same_cluster.add(new_Cluster); } } @@ -81,18 +81,18 @@ public void findCluster(List AHDC_precluster_list) { find_associate_cluster(precluster, AHDC_precluster_list, window, minimal_distance, 4, 2, 5); } - for (Cluster cluster : _list_with_maybe_same_cluster) { + for (AHDCCluster cluster : _list_with_maybe_same_cluster) { if (!containsCluster(_AHDCClusters, cluster.get_Phi(), cluster.get_Radius())) { _AHDCClusters.add(cluster); } } } - public boolean containsCluster(final List list, double phi, double radius) { + public boolean containsCluster(final List list, double phi, double radius) { return list.stream().anyMatch(o -> o.get_Radius() == (radius) && o.get_Phi() == phi); } - public ArrayList get_AHDCClusters() { + public ArrayList get_AHDCClusters() { return _AHDCClusters; } diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/GNNConstants.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/GNNConstants.java new file mode 100644 index 0000000000..502d3476c6 --- /dev/null +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/GNNConstants.java @@ -0,0 +1,42 @@ +package org.jlab.rec.ahdc.AI; + +/** Normalization and graph-construction constants for the GNN track finder. + * Mirrors track-finding/gnn/config.py — keep in sync with the training config. + */ +final class GNNConstants { + private GNNConstants() {} + + static final int NODE_FEAT_DIM = 10; + static final int EDGE_FEAT_DIM = 9; + + // Model architecture parameters (control the minimum graph size at inference). + // GravNet progressive-k reaches 2*k, topk uses k+1 → N_nodes >= 2*k + 2. + // The exported model clamps topk(k+1) to N internally (see + // track-finding/export_torchscript.py::_knn_indices), so any graph with + // >=3 nodes runs without crashing. Smaller graphs can't form any edge + // with the MAX_LAYER_GAP rule anyway, so we skip them here. + static final int MIN_NODES = 3; + + // Graph construction + static final int MAX_LAYER_GAP = 2; + static final double MAX_EDGE_DISTANCE = 35.0; // mm + static final double MAX_EDGE_DIST_SQ = MAX_EDGE_DISTANCE * MAX_EDGE_DISTANCE; + + // Feature normalization + static final double MAX_R = 100.0; // mm + static final double DOCA_STD = 10.0; // mm + static final double Z_HALF_LENGTH = 200.0; // mm + static final double STEREO_ANGLE_MAX = 0.03; // rad + static final double STEREO_SCALE = 1.0 / STEREO_ANGLE_MAX; + + // ATOF abs_layer convention from Python's build_graph + static final int ATOF_BAR_ABS_LAYER = 10; // component == 10 + static final int ATOF_WEDGE_ABS_LAYER = 11; // all other components + + // Track extraction: connected components at a single score threshold, matching + // gnn/evaluate.py (extract_tracks(..., method="cc", threshold=0.1)). Drop tracks + // with fewer than MIN_TRACK_NODES total nodes — same filter evaluate.py applies + // after the method call. + static final double TRACK_SCORE_THRESHOLD = 0.1; + static final int MIN_TRACK_NODES = 3; +} diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/GNNGraphBuilder.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/GNNGraphBuilder.java new file mode 100644 index 0000000000..c2da9df1e9 --- /dev/null +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/GNNGraphBuilder.java @@ -0,0 +1,227 @@ +package org.jlab.rec.ahdc.AI; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import org.jlab.geom.prim.Line3D; +import org.jlab.geom.prim.Point3D; +import org.jlab.geom.prim.Vector3D; +import org.jlab.io.base.DataBank; +import org.jlab.rec.ahdc.Hit.Hit; +import org.jlab.rec.ahdc.Track.AtofHitStub; + +/** Builds the graph tensors expected by the exported GNN edge scorer. + */ +final class GNNGraphBuilder { + + /** Container for the tensors + node provenance that the caller needs. */ + static final class GraphInput { + final float[][] nodeFeatures; // shape [N, 10] + final long[][] edgeIndex; // shape [2, E] + final float[][] edgeAttr; // shape [E, 9] + /** nodeToSource[i] is the backing Hit for AHDC nodes, or null for ATOF nodes. */ + final Hit[] nodeToSource; + /** nodeToAtof[i] is the backing ATOF hit for ATOF nodes, or null for AHDC nodes. */ + final AtofHitStub[] nodeToAtof; + + GraphInput(float[][] nodeFeatures, long[][] edgeIndex, float[][] edgeAttr, + Hit[] nodeToSource, AtofHitStub[] nodeToAtof) { + this.nodeFeatures = nodeFeatures; + this.edgeIndex = edgeIndex; + this.edgeAttr = edgeAttr; + this.nodeToSource = nodeToSource; + this.nodeToAtof = nodeToAtof; + } + } + + private GNNGraphBuilder() {} + + /** Build a graph from AHDC hits (required) plus the ATOF::hits bank (optional). */ + static GraphInput build(List ahdcHits, DataBank atofHitsBank) { + int nAhdc = ahdcHits == null ? 0 : ahdcHits.size(); + + // Node state buffers (grow as we append AHDC then ATOF nodes). + List nodeBuf = new ArrayList<>(); // per-node raw floats (see NodeField indexes) + List nodeLine = new ArrayList<>(); // wire line for AHDC; null for ATOF + List nodeHit = new ArrayList<>(); // backing Hit for AHDC; null for ATOF + List nodeAtof = new ArrayList<>(); // backing ATOF hit for ATOF; null for AHDC + + // --- AHDC nodes ------------------------------------------------------------- + for (int i = 0; i < nAhdc; i++) { + Hit h = ahdcHits.get(i); + Line3D line = h.getLine(); + if (line == null) continue; // missing geometry → skip (shouldn't happen after setWirePosition) + + Point3D mid = line.midpoint(); + Vector3D dir = line.toVector(); + double len = Math.max(dir.mag(), 1e-12); + double ux = dir.x() / len, uy = dir.y() / len, uz = dir.z() / len; + double stereo = Math.atan2(Math.sqrt(ux*ux + uy*uy), uz); + + int absLayer = (h.getSuperLayerId() - 1) * 2 + (h.getLayerId() - 1); + nodeBuf.add(new double[]{ + absLayer, // 0: abs_layer + h.getPhi(), // 1: phi + h.getRadius(), // 2: r + stereo, // 3: stereo_angle + mid.x(), // 4: x_mid + mid.y(), // 5: y_mid + mid.z(), // 6: z_mid + ux, // 7: ux + uy, // 8: uy + uz, // 9: uz + h.getX(), // 10: x (raw, for edge distance mask) + h.getY(), // 11: y (raw, for edge distance mask) + 0.0, // 12: det_type = 0 (AHDC) + }); + nodeLine.add(line); + nodeHit.add(h); + nodeAtof.add(null); + } + + // --- ATOF nodes ------------------------------------------------------------- + // Deduplicate by (sector, layer, component). + if (atofHitsBank != null) { + Set seen = new HashSet<>(); + int rows = atofHitsBank.rows(); + for (int r = 0; r < rows; r++) { + int sector = atofHitsBank.getInt("sector", r); + int layer = atofHitsBank.getInt("layer", r); + int component = atofHitsBank.getInt("component", r); + long key = (((long)sector * 1000L) + layer) * 1000L + component; + if (!seen.add(key)) continue; + + double x = atofHitsBank.getFloat("x", r); + double y = atofHitsBank.getFloat("y", r); + double radius = Math.hypot(x, y); + double phi = Math.atan2(y, x); + int absLayer = (component == 10) ? GNNConstants.ATOF_BAR_ABS_LAYER + : GNNConstants.ATOF_WEDGE_ABS_LAYER; + + nodeBuf.add(new double[]{ + absLayer, phi, radius, + 0.0, // stereo + x, y, 0.0, // mid + 0.0, 0.0, 1.0, // (ux, uy, uz) + x, y, // raw x, y (for edge mask) + 1.0, // det_type = 1 (ATOF) + }); + nodeLine.add(null); + nodeHit.add(null); + nodeAtof.add(new AtofHitStub(sector, layer, component, x, y)); + } + } + + int n = nodeBuf.size(); + if (n < 2) { + return new GraphInput(new float[0][GNNConstants.NODE_FEAT_DIM], + new long[][]{new long[0], new long[0]}, + new float[0][GNNConstants.EDGE_FEAT_DIM], + new Hit[0], new AtofHitStub[0]); + } + + // --- Node feature tensor [N, 10] -------------------------------------------- + float[][] nodeFeatures = new float[n][GNNConstants.NODE_FEAT_DIM]; + for (int i = 0; i < n; i++) { + double[] v = nodeBuf.get(i); + nodeFeatures[i][0] = (float)(v[0] / 11.0); + nodeFeatures[i][1] = (float)(v[1] / Math.PI); + nodeFeatures[i][2] = (float)(v[2] / GNNConstants.DOCA_STD); + nodeFeatures[i][3] = (float)(v[3] / GNNConstants.STEREO_ANGLE_MAX); + nodeFeatures[i][4] = (float)(v[4] / GNNConstants.MAX_R); + nodeFeatures[i][5] = (float)(v[5] / GNNConstants.MAX_R); + nodeFeatures[i][6] = (float)(v[6] / GNNConstants.Z_HALF_LENGTH); + nodeFeatures[i][7] = (float)(v[7] * GNNConstants.STEREO_SCALE); + nodeFeatures[i][8] = (float)(v[8] * GNNConstants.STEREO_SCALE); + nodeFeatures[i][9] = (float)(v[9]); + } + + // --- Edge construction (directed, layer_gap in [1, MAX_LAYER_GAP]) ----------- + int[] absLayer = new int[n]; + double[] xRaw = new double[n]; + double[] yRaw = new double[n]; + double[] rRaw = new double[n]; + double[] phiRaw = new double[n]; + double[] stereoRaw = new double[n]; + double[] detTypeRaw = new double[n]; + for (int i = 0; i < n; i++) { + double[] v = nodeBuf.get(i); + absLayer[i] = (int) v[0]; + phiRaw[i] = v[1]; + rRaw[i] = v[2]; + stereoRaw[i] = v[3]; + xRaw[i] = v[10]; + yRaw[i] = v[11]; + detTypeRaw[i] = v[12]; + } + + List edgePairs = new ArrayList<>(); + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + if (i == j) continue; + int gap = absLayer[j] - absLayer[i]; + if (gap < 1 || gap > GNNConstants.MAX_LAYER_GAP) continue; + double dx = xRaw[i] - xRaw[j]; + double dy = yRaw[i] - yRaw[j]; + if (dx*dx + dy*dy > GNNConstants.MAX_EDGE_DIST_SQ) continue; + edgePairs.add(new long[]{i, j}); + } + } + + int e = edgePairs.size(); + long[][] edgeIndex = new long[2][e]; + float[][] edgeAttr = new float[e][GNNConstants.EDGE_FEAT_DIM]; + + for (int k = 0; k < e; k++) { + long[] p = edgePairs.get(k); + int s = (int) p[0]; + int d = (int) p[1]; + edgeIndex[0][k] = s; + edgeIndex[1][k] = d; + + // dphi wrapped into [-pi, pi] + double dphi = phiRaw[s] - phiRaw[d]; + dphi = ((dphi + Math.PI) % (2.0 * Math.PI) + 2.0 * Math.PI) % (2.0 * Math.PI) - Math.PI; + double dlayer = (double)(absLayer[d] - absLayer[s]) / GNNConstants.MAX_LAYER_GAP; + + double doca, z1, z2; + Line3D ls = nodeLine.get(s); + Line3D ld = nodeLine.get(d); + if (ls != null && ld != null) { + doca = ls.distance(ld).length(); + z1 = clampZ(ls.distance(ld.midpoint()).origin().z()); + z2 = clampZ(ld.distance(ls.midpoint()).origin().z()); + } else { + double ex = xRaw[s] - xRaw[d]; + double ey = yRaw[s] - yRaw[d]; + doca = Math.hypot(ex, ey); + z1 = 0.0; + z2 = 0.0; + } + + double edgeDetType = 0.5 * (detTypeRaw[s] + detTypeRaw[d]); + + edgeAttr[k][0] = (float)(dphi / Math.PI); + edgeAttr[k][1] = (float) dlayer; + edgeAttr[k][2] = (float)(doca / GNNConstants.MAX_R); + edgeAttr[k][3] = (float)(z1 / GNNConstants.Z_HALF_LENGTH); + edgeAttr[k][4] = (float)(z2 / GNNConstants.Z_HALF_LENGTH); + edgeAttr[k][5] = (float)(rRaw[s] / GNNConstants.DOCA_STD); + edgeAttr[k][6] = (float)(rRaw[d] / GNNConstants.DOCA_STD); + edgeAttr[k][7] = (float)((stereoRaw[s] - stereoRaw[d]) / (2.0 * GNNConstants.STEREO_ANGLE_MAX)); + edgeAttr[k][8] = (float) edgeDetType; + } + + Hit[] nodeToHit = nodeHit.toArray(new Hit[0]); + AtofHitStub[] nodeToAtof = nodeAtof.toArray(new AtofHitStub[0]); + return new GraphInput(nodeFeatures, edgeIndex, edgeAttr, nodeToHit, nodeToAtof); + } + + private static double clampZ(double z) { + if (z < -GNNConstants.Z_HALF_LENGTH) return -GNNConstants.Z_HALF_LENGTH; + if (z > GNNConstants.Z_HALF_LENGTH) return GNNConstants.Z_HALF_LENGTH; + return z; + } +} diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/GNNPrediction.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/GNNPrediction.java new file mode 100644 index 0000000000..9663eac128 --- /dev/null +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/GNNPrediction.java @@ -0,0 +1,130 @@ +package org.jlab.rec.ahdc.AI; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.logging.Logger; + +import org.jlab.io.base.DataBank; +import org.jlab.rec.ahdc.AHDCCluster.AHDCCluster; +import org.jlab.rec.ahdc.Hit.Hit; +import org.jlab.rec.ahdc.PreCluster.PreCluster; +import org.jlab.rec.ahdc.PreCluster.PreClusterFinder; +import org.jlab.rec.ahdc.Track.AtofHitStub; +import org.jlab.rec.ahdc.Track.CandidateType; +import org.jlab.rec.ahdc.Track.TrackCandidate; + +/** Orchestrates GNN-based track finding: builds the graph, runs the exported + * edge scorer, extracts tracks via connected components on edge scores + * thresholded at 0.1, and converts each node-set back into a + * {@link TrackCandidate} carrying per-superlayer Clusters so the downstream + * helix fit / Kalman stages can consume it. Components that include ATOF + * graph nodes yield {@code AHDC_ATOF} candidates with those ATOF hits + * attached; the rest are {@code AHDC_ONLY}. + */ +public final class GNNPrediction { + + private static final Logger LOGGER = Logger.getLogger(GNNPrediction.class.getName()); + + public ArrayList prediction(List ahdcHits, + DataBank atofHitsBank, + ModelTrackFindingGNN model) { + ArrayList out = new ArrayList<>(); + if (ahdcHits == null || ahdcHits.isEmpty() || model == null) return out; + + GNNGraphBuilder.GraphInput g = GNNGraphBuilder.build(ahdcHits, atofHitsBank); + int nNodes = g.nodeToSource.length; + int nEdges = g.edgeIndex[0].length; + if (nNodes < GNNConstants.MIN_NODES || nEdges == 0) { + return out; // model cannot run on graphs this small + } + + float[] edgeScores; + try { + edgeScores = model.predictEdgeScores(g.nodeFeatures, g.edgeIndex, g.edgeAttr); + } catch (Exception ex) { + LOGGER.warning(() -> "GNN inference failed: " + ex); + return out; + } + + // Connected components at TRACK_SCORE_THRESHOLD, filtered to + // components of size >= MIN_TRACK_NODES. + List trackNodeSets = SeedExtendTrackExtractor.extract(edgeScores, g.edgeIndex, nNodes); + + for (int[] nodes : trackNodeSets) { + // Split the component's nodes: AHDC Hits become the candidate's hits, + // ATOF nodes (when present) are attached so the candidate is typed + // AHDC_ATOF. Only AHDC hits feed AHDC::track / AHDC::hits. + ArrayList trackHits = new ArrayList<>(nodes.length); + ArrayList trackAtof = new ArrayList<>(); + for (int n : nodes) { + Hit h = g.nodeToSource[n]; + if (h != null) { trackHits.add(h); continue; } + AtofHitStub a = g.nodeToAtof[n]; + if (a != null) trackAtof.add(a); + } + if (trackHits.isEmpty()) continue; + + ArrayList clusters = buildSuperlayerClusters(trackHits); + if (clusters.size() < 3) continue; // matches the downstream >=3 filter + + TrackCandidate candidate = new TrackCandidate(clusters); + if (!trackAtof.isEmpty()) { + candidate.setType(CandidateType.AHDC_ATOF); + for (AtofHitStub a : trackAtof) candidate.addAtofHit(a); + } + out.add(candidate); + } + + return out; + } + + /** One {@link AHDCCluster} per superlayer built from two {@link PreCluster}s (one + * per layer within the superlayer). Using real PreClusters — instead of the + * 3-arg {@code AHDCCluster(x,y,z)} constructor — keeps + * {@code Track.generateHitList()} and {@code DocaClusterRefiner}'s stereo + * pairing working for GNN-discovered tracks just like they do for MLP tracks. + */ + private static ArrayList buildSuperlayerClusters(List hits) { + // Feed the track's hits through the same preclustering the MLP path uses. + // findPreclusters mutates its input (it calls setUse(true) on consumed + // hits), so pass a copy and ensure each hit starts unmarked. + ArrayList hitsForPre = new ArrayList<>(hits.size()); + for (Hit h : hits) { h.setUse(false); hitsForPre.add(h); } + PreClusterFinder pcf = new PreClusterFinder(); + pcf.findPreclusters(hitsForPre); + ArrayList preclusters = pcf.get_AHDCPreClusters(); + + // Index by (superlayer, layer). If the GNN assigns two PreClusters of the + // same superlayer+layer to one track (rare — it would mean two disjoint + // wire runs on the same layer), keep the largest and drop the rest. + Map bySuperlayer = new HashMap<>(); + for (PreCluster pc : preclusters) { + int sl = pc.get_Super_layer(); + int layerIdx = pc.get_Layer() - 1; // layer is 1-based, slots are [0,1] + if (layerIdx < 0 || layerIdx > 1) continue; + PreCluster[] slot = bySuperlayer.computeIfAbsent(sl, k -> new PreCluster[2]); + PreCluster prev = slot[layerIdx]; + if (prev == null || pc.get_Num_wire() > prev.get_Num_wire()) slot[layerIdx] = pc; + } + + ArrayList clusters = new ArrayList<>(); + // Iterate superlayers in ascending order to keep downstream output stable. + // If both stereo layers have a PreCluster, pair them (full stereo cluster). + // If only one has hits, use the single-layer AHDCCluster(PreCluster) ctor — + // DocaClusterRefiner handles PreClusters_list.size() != 2 with a + // degenerate DocaCluster fallback, so the helix fit still runs. + for (int sl = 1; sl <= 5; sl++) { + PreCluster[] slot = bySuperlayer.get(sl); + if (slot == null) continue; + if (slot[0] != null && slot[1] != null) { + clusters.add(new AHDCCluster(slot[0], slot[1])); + } else { + PreCluster single = (slot[0] != null) ? slot[0] : slot[1]; + if (single != null) clusters.add(new AHDCCluster(single)); + } + } + return clusters; + } +} diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/ModelTrackFindingGNN.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/ModelTrackFindingGNN.java new file mode 100644 index 0000000000..57e8a39b33 --- /dev/null +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/ModelTrackFindingGNN.java @@ -0,0 +1,96 @@ +package org.jlab.rec.ahdc.AI; + +import java.io.IOException; +import java.nio.file.Paths; + +import ai.djl.MalformedModelException; +import ai.djl.inference.Predictor; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.Shape; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.training.util.ProgressBar; +import ai.djl.translate.NoopTranslator; +import org.jlab.utils.CLASResources; + +/** DJL wrapper around the GravNet TorchScript model exported from + * track-finding/export_torchscript.py. Runs per-event edge scoring. + * + * Exported forward signature (see SingleGraphEdgeScorer): + * forward(x: float32[N, 10], edge_index: int64[2, E], edge_attr: float32[E, 9]) + * -> float32[E] (sigmoid edge scores in [0, 1]) + */ +public class ModelTrackFindingGNN { + + private final ZooModel model; + /** Reused across every call. DJL's Predictor is NOT thread-safe, but the + * ALERT reconstruction engine is single-threaded, so one instance is fine + * and avoids the allocation/libtorch-graph-prep cost per event that + * dominated predictEdgeScores on small graphs. + */ + private final Predictor predictor; + + public ModelTrackFindingGNN() { + // Let libtorch pick sensible defaults: GravNet's cdist + topk + gather + // chain benefits from the graph optimizer and intra-op parallelism the + // MLP copy-paste was pinning off. Keep num_interop_threads=1 — there is + // only one event in flight at a time. + System.setProperty("ai.djl.pytorch.num_interop_threads", "1"); + + String path = CLASResources.getResourcePath("etc/data/nnet/rg-l/model_AHDC_GNN/"); + Criteria criteria = Criteria.builder() + .setTypes(NDList.class, NDList.class) + .optModelPath(Paths.get(path)) + .optEngine("PyTorch") + .optTranslator(new NoopTranslator()) + .optProgress(new ProgressBar()) + .build(); + + try { + model = criteria.loadModel(); + } catch (IOException | ModelNotFoundException | MalformedModelException ex) { + throw new RuntimeException(ex); + } + predictor = model.newPredictor(new NoopTranslator()); + } + + /** Score every edge in the input graph. + * + * @param nodeFeatures shape [N, 10] — see GNNConstants.NODE_FEAT_DIM + * @param edgeIndex shape [2, E] — int64 source / destination node ids + * @param edgeAttr shape [E, 9] — see GNNConstants.EDGE_FEAT_DIM + * @return float[E] of sigmoid edge scores in [0, 1] + */ + public float[] predictEdgeScores(float[][] nodeFeatures, long[][] edgeIndex, float[][] edgeAttr) throws Exception { + if (nodeFeatures == null || nodeFeatures.length == 0) return new float[0]; + int n = nodeFeatures.length; + int e = edgeIndex[0].length; + if (e == 0) return new float[0]; + + try (NDManager manager = NDManager.newBaseManager()) { + // Flatten x into a contiguous float[] so DJL builds a [N, node_dim] tensor. + int nodeDim = nodeFeatures[0].length; + float[] xFlat = new float[n * nodeDim]; + for (int i = 0; i < n; i++) System.arraycopy(nodeFeatures[i], 0, xFlat, i * nodeDim, nodeDim); + NDArray x = manager.create(xFlat, new Shape(n, nodeDim)); + + // edge_index is int64[2, E]; flatten row-major. + long[] edgeIndexFlat = new long[2 * e]; + System.arraycopy(edgeIndex[0], 0, edgeIndexFlat, 0, e); + System.arraycopy(edgeIndex[1], 0, edgeIndexFlat, e, e); + NDArray edgeIndexNd = manager.create(edgeIndexFlat, new Shape(2, e)); + + int edgeDim = edgeAttr[0].length; + float[] edgeAttrFlat = new float[e * edgeDim]; + for (int i = 0; i < e; i++) System.arraycopy(edgeAttr[i], 0, edgeAttrFlat, i * edgeDim, edgeDim); + NDArray edgeAttrNd = manager.create(edgeAttrFlat, new Shape(e, edgeDim)); + + NDList output = predictor.predict(new NDList(x, edgeIndexNd, edgeAttrNd)); + NDArray scoresNd = output.get(0); + return scoresNd.toFloatArray(); + } + } +} diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/SeedExtendTrackExtractor.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/SeedExtendTrackExtractor.java new file mode 100644 index 0000000000..abe3a9d4f1 --- /dev/null +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/SeedExtendTrackExtractor.java @@ -0,0 +1,63 @@ +package org.jlab.rec.ahdc.AI; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** Track extraction from per-edge scores via union-find connected components + * at a single threshold. Ports the {@code method="cc"} branch of + * {@code track-finding/gnn/inference.py::extract_tracks}, which is the + * extractor that gnn/evaluate.py uses. + */ +final class SeedExtendTrackExtractor { + + private SeedExtendTrackExtractor() {} + + /** @return list of node-index arrays, one per connected component of size + * ≥ {@link GNNConstants#MIN_TRACK_NODES}. */ + static List extract(float[] scores, long[][] edgeIndex, int nNodes) { + if (nNodes <= 0 || scores == null || edgeIndex == null || edgeIndex[0].length != scores.length) { + return new ArrayList<>(); + } + long[] src = edgeIndex[0]; + long[] dst = edgeIndex[1]; + + int[] parent = new int[nNodes]; + for (int i = 0; i < nNodes; i++) parent[i] = i; + for (int e = 0; e < scores.length; e++) { + if (scores[e] >= GNNConstants.TRACK_SCORE_THRESHOLD) { + union(parent, (int) src[e], (int) dst[e]); + } + } + + Map> groups = new HashMap<>(); + for (int i = 0; i < nNodes; i++) { + int r = find(parent, i); + groups.computeIfAbsent(r, k -> new ArrayList<>()).add(i); + } + + List out = new ArrayList<>(); + for (List members : groups.values()) { + if (members.size() < GNNConstants.MIN_TRACK_NODES) continue; + int[] arr = new int[members.size()]; + for (int i = 0; i < arr.length; i++) arr[i] = members.get(i); + out.add(arr); + } + return out; + } + + private static int find(int[] parent, int x) { + while (parent[x] != x) { + parent[x] = parent[parent[x]]; + x = parent[x]; + } + return x; + } + + private static void union(int[] parent, int a, int b) { + int ra = find(parent, a); + int rb = find(parent, b); + if (ra != rb) parent[ra] = rb; + } +} diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/TrackPrediction.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/TrackPrediction.java index 0937243b87..16673134ef 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/TrackPrediction.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/TrackPrediction.java @@ -1,6 +1,6 @@ package org.jlab.rec.ahdc.AI; -import org.jlab.rec.ahdc.Cluster.Cluster; +import org.jlab.rec.ahdc.AHDCCluster.AHDCCluster; import org.jlab.rec.ahdc.PreCluster.PreCluster; import java.util.ArrayList; @@ -10,7 +10,7 @@ public class TrackPrediction { private float prediction; private final ArrayList interclusters; private final ArrayList preclusters = new ArrayList<>(); - private ArrayList clusters = new ArrayList<>(); + private ArrayList clusters = new ArrayList<>(); public TrackPrediction(float prediction, ArrayList interclusters_) { this.prediction = prediction; @@ -25,28 +25,28 @@ public TrackPrediction(float prediction, ArrayList interclusters_) if (p.get_Super_layer() == 1) { for (PreCluster other : this.preclusters) { if (other.get_Super_layer() == 2 && other.get_Layer() == 1) - clusters.add(new Cluster(p, other)); + clusters.add(new AHDCCluster(p, other)); } } if (p.get_Super_layer() == 2 && p.get_Layer() == 2) { for (PreCluster other : this.preclusters) { if (other.get_Super_layer() == 3 && other.get_Layer() == 1) - clusters.add(new Cluster(p, other)); + clusters.add(new AHDCCluster(p, other)); } } if (p.get_Super_layer() == 3 && p.get_Layer() == 2) { for (PreCluster other : this.preclusters) { if (other.get_Super_layer() == 4 && other.get_Layer() == 1) - clusters.add(new Cluster(p, other)); + clusters.add(new AHDCCluster(p, other)); } } if (p.get_Super_layer() == 4 && p.get_Layer() == 2) { for (PreCluster other : this.preclusters) { if (other.get_Super_layer() == 5 && other.get_Layer() == 1) - clusters.add(new Cluster(p, other)); + clusters.add(new AHDCCluster(p, other)); } } @@ -67,7 +67,7 @@ public ArrayList getPreclusters() { return preclusters; } - public ArrayList getClusters() { + public ArrayList getClusters() { return clusters; } } diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Banks/RecoBankWriter.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Banks/RecoBankWriter.java index 0884ab547f..949202879a 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Banks/RecoBankWriter.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Banks/RecoBankWriter.java @@ -4,7 +4,7 @@ import org.jlab.io.base.DataEvent; import org.jlab.rec.ahdc.AI.InterCluster; import org.jlab.rec.ahdc.AI.TrackPrediction; -import org.jlab.rec.ahdc.Cluster.Cluster; +import org.jlab.rec.ahdc.AHDCCluster.AHDCCluster; import org.jlab.rec.ahdc.DocaCluster.DocaCluster; import org.jlab.rec.ahdc.Hit.Hit; import org.jlab.rec.ahdc.PreCluster.PreCluster; @@ -50,7 +50,7 @@ public DataBank fillPreClustersBank(DataEvent event, ArrayList preCl return bank; } - public DataBank fillClustersBank(DataEvent event, ArrayList clusters) { + public DataBank fillClustersBank(DataEvent event, ArrayList clusters) { if (clusters == null || clusters.size() == 0) return null; DataBank bank = event.createBank("AHDC::clusters", clusters.size()); diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Distance/Distance.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Distance/Distance.java index 0d9d10ce87..5139926efc 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Distance/Distance.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Distance/Distance.java @@ -1,7 +1,7 @@ package org.jlab.rec.ahdc.Distance; -import org.jlab.rec.ahdc.Cluster.Cluster; -import org.jlab.rec.ahdc.Track.Track; +import org.jlab.rec.ahdc.AHDCCluster.AHDCCluster; +import org.jlab.rec.ahdc.Track.TrackCandidate; import java.util.ArrayList; import java.util.Arrays; @@ -14,13 +14,13 @@ */ public class Distance { - private ArrayList _AHDCTracks; + private ArrayList _AHDCTrackCandidates; public Distance(){ - _AHDCTracks = new ArrayList<>(); + _AHDCTrackCandidates = new ArrayList<>(); } - public void find_track(List AHDC_Cluster){ + public void find_track(List AHDC_Cluster){ find_track_4_clusters(AHDC_Cluster); find_track_3_clusters(AHDC_Cluster); } @@ -41,14 +41,14 @@ public static List> computeCombinations2(List> lists) { return combinations; } - private void find_track_4_clusters(List AHDC_Cluster){ - List clusters_to_remove = new ArrayList<>(); - List layer1 = new ArrayList<>(); // List of all cluster with a radius equal 35 - List layer2 = new ArrayList<>(); // List of all cluster with a radius equal 45 - List layer3 = new ArrayList<>(); // List of all cluster with a radius equal 55 - List layer4 = new ArrayList<>(); // List of all cluster with a radius equal 65 + private void find_track_4_clusters(List AHDC_Cluster){ + List clusters_to_remove = new ArrayList<>(); + List layer1 = new ArrayList<>(); // List of all cluster with a radius equal 35 + List layer2 = new ArrayList<>(); // List of all cluster with a radius equal 45 + List layer3 = new ArrayList<>(); // List of all cluster with a radius equal 55 + List layer4 = new ArrayList<>(); // List of all cluster with a radius equal 65 - for(Cluster cluster : AHDC_Cluster){ + for(AHDCCluster cluster : AHDC_Cluster){ if(cluster.get_Radius() == 35){ layer1.add(cluster); } @@ -62,29 +62,29 @@ else if(cluster.get_Radius() == 65){ layer4.add(cluster); } } - List> merged_list = new ArrayList<>(); + List> merged_list = new ArrayList<>(); merged_list.add(layer1); merged_list.add(layer2); merged_list.add(layer3); merged_list.add(layer4); - List> all_combinations = computeCombinations2(merged_list); + List> all_combinations = computeCombinations2(merged_list); - List all_track = new ArrayList<>(); - for(List combination : all_combinations){ - all_track.add(new Track(combination)); + List all_track = new ArrayList<>(); + for(List combination : all_combinations){ + all_track.add(new TrackCandidate(combination)); } - List tracks_possible = new ArrayList<>(); - for(Track track : all_track){ + List tracks_possible = new ArrayList<>(); + for(TrackCandidate track : all_track){ if(track.get_Distance() < 45){ tracks_possible.add(track); } } double window = 3.8; - for(Track track : tracks_possible){ - List tracks_with_close_starting_point = new ArrayList<>(); - for(Track other_track : tracks_possible){ + for(TrackCandidate track : tracks_possible){ + List tracks_with_close_starting_point = new ArrayList<>(); + for(TrackCandidate other_track : tracks_possible){ if(other_track.get_Clusters().get(0).get_X() > track.get_Clusters().get(0).get_X() - window && other_track.get_Clusters().get(0).get_X() < track.get_Clusters().get(0).get_X() + window && other_track.get_Clusters().get(0).get_Y() > track.get_Clusters().get(0).get_Y() - window @@ -97,12 +97,12 @@ else if(cluster.get_Radius() == 65){ if(tracks_with_close_starting_point.size() > 0){ double chisq_min = Double.MAX_VALUE; - Track best_track = null; - for(Track other_track : tracks_with_close_starting_point){ + TrackCandidate best_track = null; + for(TrackCandidate other_track : tracks_with_close_starting_point){ ArrayList x_ = new ArrayList<>(); ArrayList y_ = new ArrayList<>(); ArrayList w_ = new ArrayList<>(); // weight for circlefit - for(Cluster cluster : other_track.get_Clusters()){ + for(AHDCCluster cluster : other_track.get_Clusters()){ x_.add(cluster.get_X()); y_.add(cluster.get_Y()); w_.add(1.); @@ -119,32 +119,32 @@ else if(cluster.get_Radius() == 65){ } if (best_track != null ){ clusters_to_remove.addAll(best_track.get_Clusters()); - _AHDCTracks.add(best_track); + _AHDCTrackCandidates.add(best_track); } } } - List clusters_to_remove_without_double = new ArrayList<>(); - for(Cluster cluster : clusters_to_remove){ + List clusters_to_remove_without_double = new ArrayList<>(); + for(AHDCCluster cluster : clusters_to_remove){ if(!containsCluster(clusters_to_remove_without_double, cluster.get_Phi(), cluster.get_Radius())){ clusters_to_remove_without_double.add(cluster); } } - for(Cluster cluster : clusters_to_remove_without_double){ + for(AHDCCluster cluster : clusters_to_remove_without_double){ AHDC_Cluster.remove(cluster); } } - public boolean containsCluster(final List list, double phi, double radius){ + public boolean containsCluster(final List list, double phi, double radius){ return list.stream().anyMatch(o -> o.get_Radius() == (radius) && o.get_Phi() == phi); } - private ArrayList> combination(List arr, ArrayList data, int start, + private ArrayList> combination(List arr, ArrayList data, int start, int end, int index, int r) { - ArrayList> all = new ArrayList<>(); + ArrayList> all = new ArrayList<>(); if (index == r) { all.add(data); } @@ -157,25 +157,25 @@ private ArrayList> combination(List arr, ArrayList AHDC_Cluster){ - ArrayList> all_combinations = combination(AHDC_Cluster, new ArrayList(),0, AHDC_Cluster.size()-1, 0, 3); + private void find_track_3_clusters(List AHDC_Cluster){ + ArrayList> all_combinations = combination(AHDC_Cluster, new ArrayList(),0, AHDC_Cluster.size()-1, 0, 3); - List all_track = new ArrayList<>(); - for(List combination : all_combinations){ - all_track.add(new Track(combination)); + List all_track = new ArrayList<>(); + for(List combination : all_combinations){ + all_track.add(new TrackCandidate(combination)); } - List tracks_possible = new ArrayList<>(); - for(Track track : all_track){ + List tracks_possible = new ArrayList<>(); + for(TrackCandidate track : all_track){ if(track.get_Distance() < 45){ tracks_possible.add(track); } } double window = 3.8; - for(Track track : tracks_possible) { - List tracks_with_close_starting_point = new ArrayList<>(); - for (Track other_track : tracks_possible) { + for(TrackCandidate track : tracks_possible) { + List tracks_with_close_starting_point = new ArrayList<>(); + for (TrackCandidate other_track : tracks_possible) { if (other_track.get_Clusters().get(0).get_X() > track.get_Clusters().get(0).get_X() - window && other_track.get_Clusters().get(0).get_X() < track.get_Clusters().get(0).get_X() + window && other_track.get_Clusters().get(0).get_Y() > track.get_Clusters().get(0).get_Y() - window @@ -188,12 +188,12 @@ private void find_track_3_clusters(List AHDC_Cluster){ if(tracks_with_close_starting_point.size() > 0){ double chisq_min = Double.MAX_VALUE; - Track best_track = null; - for(Track other_track : tracks_with_close_starting_point){ + TrackCandidate best_track = null; + for(TrackCandidate other_track : tracks_with_close_starting_point){ ArrayList x_ = new ArrayList<>(); ArrayList y_ = new ArrayList<>(); ArrayList w_ = new ArrayList<>(); // weight for circlefit - for(Cluster cluster : other_track.get_Clusters()){ + for(AHDCCluster cluster : other_track.get_Clusters()){ x_.add(cluster.get_X()); y_.add(cluster.get_Y()); w_.add(1.); @@ -209,17 +209,17 @@ private void find_track_3_clusters(List AHDC_Cluster){ } } if (best_track != null ){ - _AHDCTracks.add(best_track); + _AHDCTrackCandidates.add(best_track); } } } } - public ArrayList get_AHDCTracks() { - return _AHDCTracks; + public ArrayList get_AHDCTrackCandidates() { + return _AHDCTrackCandidates; } - public void set_AHDCTracks(ArrayList _AHDCTracks) { - this._AHDCTracks = _AHDCTracks; + public void set_AHDCTrackCandidates(ArrayList _AHDCTrackCandidates) { + this._AHDCTrackCandidates = _AHDCTrackCandidates; } } diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/DocaCluster/DocaClusterRefiner.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/DocaCluster/DocaClusterRefiner.java index 1491f0a96e..897f598028 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/DocaCluster/DocaClusterRefiner.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/DocaCluster/DocaClusterRefiner.java @@ -2,7 +2,7 @@ import org.jlab.rec.ahdc.Hit.Hit; import org.jlab.rec.ahdc.PreCluster.PreCluster; -import org.jlab.rec.ahdc.Cluster.Cluster; +import org.jlab.rec.ahdc.AHDCCluster.AHDCCluster; import org.jlab.geom.prim.Line3D; import org.jlab.geom.prim.Point3D; @@ -14,7 +14,7 @@ /** * Build refined cluster space points using DOCA circles. * - * For each original Cluster (2 PreClusters with stereo angle), + * For each original AHDCCluster (2 PreClusters with stereo angle), * if the hit multiplicity is (1,1), (1,2)/(2,1) or (2,2), we * construct new space points from circle-circle tangents. * Otherwise we fall back to the original (x,y,z) with weight 1. @@ -284,18 +284,18 @@ private static Point3D midpointOfCommonPerpendicular(Line3D L1, Line3D L2) { // ===================================================================== /** - * Build a list of DocaCluster objects from the original list of Cluster. - * Each Cluster may generate multiple DocaCluster points. + * Build a list of DocaCluster objects from the original list of AHDCCluster. + * Each AHDCCluster may generate multiple DocaCluster points. */ - public static ArrayList buildRefinedClusters(List clusters) { + public static ArrayList buildRefinedClusters(List clusters) { ArrayList out = new ArrayList<>(); if (clusters == null) return out; for (int idx = 0; idx < clusters.size(); idx++) { - Cluster cl = clusters.get(idx); + AHDCCluster cl = clusters.get(idx); ArrayList pcs = cl.get_PreClusters_list(); if (pcs == null || pcs.size() != 2) { @@ -367,7 +367,7 @@ public static ArrayList buildRefinedClusters(List clusters // (1,1) case // ===================================================================== - private static List refine11(Cluster oldCluster, + private static List refine11(AHDCCluster oldCluster, PreCluster pc1, Hit h1, PreCluster pc2, Hit h2, int clusterIndex) { @@ -460,11 +460,11 @@ private static List refine11(Cluster oldCluster, /** * One precluster has 1 hit (singlePc, singleHits), the other has 2 hits (doublePc, doubleHits). - * singleIsFirst indicates whether in the original Cluster list ordering we had: + * singleIsFirst indicates whether in the original AHDCCluster list ordering we had: * - true : (pc1, pc2) = (singlePc, doublePc) * - false : (pc1, pc2) = (doublePc, singlePc) */ - private static List refine12(Cluster oldCluster, + private static List refine12(AHDCCluster oldCluster, PreCluster singlePc, ArrayList singleHits, PreCluster doublePc, ArrayList doubleHits, boolean singleIsFirst, @@ -637,7 +637,7 @@ private static List refine12(Cluster oldCluster, // (2,2) case // ===================================================================== - private static List refine22(Cluster oldCluster, + private static List refine22(AHDCCluster oldCluster, PreCluster pc1, ArrayList hits1, PreCluster pc2, ArrayList hits2, int clusterIndex) { @@ -867,7 +867,7 @@ private static double mod(double a, double b) { return a - b * Math.floor(a / b); } - /** Compute Z using the same relation as Cluster constructor but with new φ values. */ + /** Compute Z using the same relation as AHDCCluster constructor but with new φ values. */ private static double computeZ(PreCluster pre1, PreCluster pre2, double phi1, double phi2) { diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/HoughTransform/HoughTransform.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/HoughTransform/HoughTransform.java index ed697aecff..8f7402f6d5 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/HoughTransform/HoughTransform.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/HoughTransform/HoughTransform.java @@ -1,21 +1,21 @@ package org.jlab.rec.ahdc.HoughTransform; -import org.jlab.rec.ahdc.Cluster.Cluster; +import org.jlab.rec.ahdc.AHDCCluster.AHDCCluster; import Jama.Matrix; -import org.jlab.rec.ahdc.Track.Track; +import org.jlab.rec.ahdc.Track.TrackCandidate; import java.util.ArrayList; import java.util.List; public class HoughTransform { - private ArrayList _AHDCTracks; + private ArrayList _AHDCTrackCandidates; public HoughTransform(){ - _AHDCTracks = new ArrayList<>(); + _AHDCTrackCandidates = new ArrayList<>(); } - public void find_tracks(List AHDC_Clusters){ + public void find_tracks(List AHDC_Clusters){ int matrix_size = 300; boolean delete = false; ArrayList delete_i_max = new ArrayList<>(); @@ -23,7 +23,7 @@ public void find_tracks(List AHDC_Clusters){ while(true){ Matrix B = new Matrix(matrix_size + 1, matrix_size + 1, 0); - for (Cluster cluster : AHDC_Clusters) { + for (AHDCCluster cluster : AHDC_Clusters) { double new_u = (cluster.get_U() + 0.06) / 0.12; double new_v = (cluster.get_V() + 0.06) / 0.12; B.set((int) (new_u * matrix_size), (int) (new_v * matrix_size), 1); @@ -85,8 +85,8 @@ public void find_tracks(List AHDC_Clusters){ double a = Math.cos(theta)/(2*rho); double b = Math.sin(theta)/(2*rho); - ArrayList possible_cluster_of_track = new ArrayList<>(); - for(Cluster cluster : AHDC_Clusters){ + ArrayList possible_cluster_of_track = new ArrayList<>(); + for(AHDCCluster cluster: AHDC_Clusters){ double distance = Math.abs(Math.sqrt(Math.pow((cluster.get_X() - a),2) + Math.pow((cluster.get_Y() - b),2)) - r); if(distance < 4){ possible_cluster_of_track.add(cluster); @@ -97,10 +97,10 @@ public void find_tracks(List AHDC_Clusters){ double x_0 = possible_cluster_of_track.get(0).get_X(); double y_0 = possible_cluster_of_track.get(0).get_Y(); - ArrayList cluster_track = new ArrayList<>(); - ArrayList cluster_to_remove = new ArrayList<>(); + ArrayList cluster_track = new ArrayList<>(); + ArrayList cluster_to_remove = new ArrayList<>(); - for(Cluster other_cluster : possible_cluster_of_track){ + for(AHDCCluster other_cluster: possible_cluster_of_track){ double distance = Math.sqrt( (other_cluster.get_X() - x_0)*(other_cluster.get_X() - x_0) + (other_cluster.get_Y() - y_0)*(other_cluster.get_Y() - y_0) ); if(distance < 50){ @@ -119,21 +119,21 @@ public void find_tracks(List AHDC_Clusters){ } } - ArrayList cluster_to_remove_without_double = new ArrayList<>(); - for(Cluster cluster : cluster_to_remove){ + ArrayList cluster_to_remove_without_double = new ArrayList<>(); + for(AHDCCluster cluster: cluster_to_remove){ if(!containsCluster(cluster_to_remove_without_double, cluster.get_Phi(), cluster.get_Radius())){ cluster_to_remove_without_double.add(cluster); } } - for(Cluster cluster : cluster_to_remove_without_double){ + for(AHDCCluster cluster: cluster_to_remove_without_double){ cluster_track.remove(cluster); } if(cluster_track.size() > 2){ - Track track = new Track(cluster_track); - _AHDCTracks.add(track); - for(Cluster cluster : cluster_track){AHDC_Clusters.remove(cluster);} + TrackCandidate track = new TrackCandidate(cluster_track); + _AHDCTrackCandidates.add(track); + for(AHDCCluster cluster: cluster_track){AHDC_Clusters.remove(cluster);} } else{ delete = true; @@ -149,15 +149,15 @@ public void find_tracks(List AHDC_Clusters){ } } - public boolean containsCluster(final List list, double phi, double radius){ + public boolean containsCluster(final List list, double phi, double radius){ return list.stream().anyMatch(o -> o.get_Radius() == (radius) && o.get_Phi() == phi); } - public ArrayList get_AHDCTracks() { - return _AHDCTracks; + public ArrayList get_AHDCTrackCandidates() { + return _AHDCTrackCandidates; } - public void set_AHDCTracks(ArrayList _AHDCTracks) { - this._AHDCTracks = _AHDCTracks; + public void set_AHDCTrackCandidates(ArrayList _AHDCTrackCandidates) { + this._AHDCTrackCandidates = _AHDCTrackCandidates; } } diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/ModeTrackFinding.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/ModeTrackFinding.java index ec500c3ad9..72a260ba52 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/ModeTrackFinding.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/ModeTrackFinding.java @@ -1,7 +1,8 @@ package org.jlab.rec.ahdc; public enum ModeTrackFinding { - AI_Track_Finding, + MLP_Track_Finding, CV_Distance, CV_Hough, + GNN_Track_Finding, } diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Track/AtofHitStub.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Track/AtofHitStub.java new file mode 100644 index 0000000000..9a06ce67c4 --- /dev/null +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Track/AtofHitStub.java @@ -0,0 +1,35 @@ +package org.jlab.rec.ahdc.Track; + +/** A minimal stand-in for an ATOF hit, attached to a {@link TrackCandidate} by + * the GNN graph. + * + *

Carries only the fields the GNN graph builder already extracts from the + * {@code ATOF::hits} bank. It is deliberately not + * {@code org.jlab.rec.atof.hit.ATOFHit}: constructing a real {@code ATOFHit} + * requires the ATOF {@code Detector} geometry and a calibration table, neither + * of which is available where the graph is built. Keeping this a plain value + * class also keeps the {@code Track} package free of any dependency on the + * {@code atof} package.

+ */ +public final class AtofHitStub { + + private final int sector; + private final int layer; + private final int component; + private final double x; + private final double y; + + public AtofHitStub(int sector, int layer, int component, double x, double y) { + this.sector = sector; + this.layer = layer; + this.component = component; + this.x = x; + this.y = y; + } + + public int getSector() { return sector; } + public int getLayer() { return layer; } + public int getComponent() { return component; } + public double getX() { return x; } + public double getY() { return y; } +} diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Track/CandidateType.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Track/CandidateType.java new file mode 100644 index 0000000000..198e9d93a6 --- /dev/null +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Track/CandidateType.java @@ -0,0 +1,19 @@ +package org.jlab.rec.ahdc.Track; + +/** Specialization of a {@link TrackCandidate}. The type a track finder assigns + * to a candidate dictates how the candidate is fitted downstream. + * + *
    + *
  • {@code AHDC_ONLY} — AHDC hits only (MLP / Distance / Hough, and GNN + * tracks that include no ATOF graph nodes).
  • + *
  • {@code AHDC_ATOF} — AHDC hits plus attached ATOF hits (GNN tracks whose + * connected component includes ATOF graph nodes).
  • + *
  • {@code AHDC_VERTEX} — reserved for future use (AHDC hits + a vertex + * constraint).
  • + *
+ */ +public enum CandidateType { + AHDC_ONLY, + AHDC_ATOF, + AHDC_VERTEX +} diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Track/Track.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Track/Track.java index 7efc0cc37e..ab07d9c63e 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Track/Track.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Track/Track.java @@ -1,28 +1,24 @@ package org.jlab.rec.ahdc.Track; -import org.apache.commons.math3.linear.RealVector; import org.jlab.rec.ahdc.AI.InterCluster; -import org.jlab.rec.ahdc.Cluster.Cluster; +import org.jlab.rec.ahdc.AHDCCluster.AHDCCluster; import org.jlab.rec.ahdc.HelixFit.HelixFitObject; import org.jlab.rec.ahdc.Hit.Hit; -import org.jlab.rec.ahdc.PreCluster.PreCluster; -import org.jlab.rec.ahdc.PreCluster.PreClusterFinder; -import org.jlab.rec.ahdc.AI.PreClustering; import java.util.ArrayList; import java.util.List; +/** A fitted track: the result of track fitting. + * + *

A {@code Track} is produced by fitting a {@link TrackCandidate} (helix + * fit, then Kalman filter). It composes the candidate it was fitted from — so + * it still exposes the candidate's hits / clusters to the bank writers — and + * adds the fitted vertex, momentum and fit-quality quantities.

+ */ public class Track { - private double _Distance; - private List _Clusters = new ArrayList<>(); - private List _InterClusters = new ArrayList<>(); - private boolean _Used = false; - private final ArrayList hits = new ArrayList<>(); - - private int trackId = -1; ///< id of the track - private int n_hits = 0; ///< number of hits - private int sum_adc = 0; ///< sum of adc (adc) + private final TrackCandidate candidate; + private double sum_residuals = 0; ///< sum of residuals (mm) private double chi2 = 0; ///< sum of residuals^2 (mm^2) // AHDC::track @@ -36,35 +32,14 @@ public class Track { private double p_drift = 0; ///< momentum in the drift region (MeV) private double path = 0; ///< length of the track (mm) - // AHDC::aiprediction - private int predicted_ATOF_sector = -1; - private int predicted_ATOF_layer = -1; - private int predicted_ATOF_wedge = -1; + public Track(TrackCandidate candidate) { + this.candidate = candidate; + } - public Track(List clusters) { - this._Clusters = clusters; - this._Distance = 0; - for (int i = 0; i < clusters.size() - 1; i++) { - this._Distance += Math.sqrt((clusters.get(i).get_X() - clusters.get(i + 1).get_X()) * (clusters.get(i).get_X() - clusters.get(i + 1).get_X()) + (clusters.get(i).get_Y() - clusters.get(i + 1).get_Y()) * (clusters.get(i).get_Y() - clusters.get(i + 1).get_Y())); - } - generateHitList(); - generateInterClusterList(); - } - - public Track(ArrayList hitslist) { - hits.addAll(hitslist); - this.x0 = 0.0; - this.y0 = 0.0; - this.z0 = 0.0; - double p = 150.0;//MeV/c - //take first hit. - Hit hit = hitslist.get(0); - double phi = Math.atan2(hit.getY(), hit.getX()); - //hitslist. - this.px0 = p*Math.sin(phi); - this.py0 = p*Math.cos(phi); - this.pz0 = 0.0; - } + /** The candidate this track was fitted from. */ + public TrackCandidate getCandidate() { + return candidate; + } public void setPositionAndMomentum(HelixFitObject helixFitObject) { this.x0 = helixFitObject.get_X0(); @@ -86,52 +61,9 @@ public void setPositionAndMomentumVec(double[] x) { this.pz0 = x[5]; } - private void generateHitList() { - for (Cluster cluster : _Clusters) { - for (PreCluster preCluster : cluster.get_PreClusters_list()) { - hits.addAll(preCluster.get_hits_list()); - } - } - } - - private void generateInterClusterList() { - // Use hits to generate preclusters - PreClusterFinder preclusterfinder = new PreClusterFinder(); - preclusterfinder.findPreclusters(hits); - ArrayList AHDC_PreClusters = preclusterfinder.get_AHDCPreClusters(); - - // Use preclusters to generate interclusters - PreClustering preClustering = new PreClustering(); - this._InterClusters = preClustering.mergePreclusters(AHDC_PreClusters); - } - - public ArrayList getHits() { - return hits; - } - @Override public String toString() { - return "Track{" + "_Clusters=" + _Clusters + '}'; - } - - public double get_Distance() { - return _Distance; - } - - public List get_Clusters() { - return _Clusters; - } - - public List getInterclusters() { - return _InterClusters; - } - - public boolean is_Used() { - return _Used; - } - - public void set_Used(boolean _Used) { - this._Used = _Used; + return "Track{" + "candidate=" + candidate + '}'; } public double get_X0() { @@ -158,42 +90,8 @@ public double get_pz() { return pz0; } - public void set_trackId(int _trackId) { - trackId = _trackId; - // set trackId for clusters - for(Cluster cluster : this._Clusters) { - cluster.set_trackId(_trackId); - } - // set trackId for interclusters - for(InterCluster interCluster : this._InterClusters) { - interCluster.setTrackId(_trackId); - } - // set trackId for hits - for (Hit hit : this.hits) { - hit.setTrackId(_trackId); - } - } - public void set_n_hits(int _n_hits) { n_hits = _n_hits;} - public void set_sum_adc(int _sum_adc) { sum_adc = _sum_adc;} public void set_chi2(double _chi2) { chi2 = _chi2;} public void set_sum_residuals(double _sum_residuals) { sum_residuals = _sum_residuals;} - public int get_trackId() {return trackId;} - public int get_n_hits() { - if (hits == null) { - return 0; - } - return hits.size(); - } - public int get_sum_adc() { - if (hits == null || hits.isEmpty()) { - return 0; - } - int sum = 0; - for (Hit h : hits) { - sum += (int) Math.round(h.getADC()); - } - return sum; - } public double get_chi2() {return chi2;} public double get_sum_residuals() {return sum_residuals;} // AHDC::track @@ -205,22 +103,41 @@ public double get_dEdx() { dEdx = 0; }else { int sum = 0; - for (Hit h : hits) { + for (Hit h : candidate.getHits()) { sum += (int) Math.round(h.getADC()); } - dEdx = sum/path; + dEdx = sum/path; } return dEdx; } public double get_p_drift() {return p_drift;} public double get_path() {return path;} - // AHDC::aiprediction - public void set_predicted_ATOF_sector(int s) {predicted_ATOF_sector = s;} - public void set_predicted_ATOF_layer(int l) {predicted_ATOF_layer = l;} - public void set_predicted_ATOF_wedge(int w) {predicted_ATOF_wedge = w;} - public int get_predicted_ATOF_sector() {return predicted_ATOF_sector;} - public int get_predicted_ATOF_layer() {return predicted_ATOF_layer;} - public int get_predicted_ATOF_wedge() {return predicted_ATOF_wedge;} - -} \ No newline at end of file + // --- delegated to the underlying candidate ------------------------------- + // The candidate owns the hits / clusters / specialization; Track exposes the + // same accessors so callers can treat a fitted Track as a full handle. + public int get_trackId() { return candidate.get_trackId(); } + public void set_trackId(int id) { candidate.set_trackId(id); } + public int get_n_hits() { return candidate.get_n_hits(); } + public void set_n_hits(int n) { candidate.set_n_hits(n); } + public int get_sum_adc() { return candidate.get_sum_adc(); } + public void set_sum_adc(int s) { candidate.set_sum_adc(s); } + public ArrayList getHits() { return candidate.getHits(); } + public List get_Clusters() { return candidate.get_Clusters(); } + public List getInterclusters() { return candidate.getInterclusters(); } + public double get_Distance() { return candidate.get_Distance(); } + public boolean is_Used() { return candidate.is_Used(); } + public void set_Used(boolean u) { candidate.set_Used(u); } + public CandidateType getType() { return candidate.getType(); } + public void setType(CandidateType t) { candidate.setType(t); } + public List getAtofHits() { return candidate.getAtofHits(); } + public void addAtofHit(AtofHitStub h) { candidate.addAtofHit(h); } + + // AHDC::aiprediction — delegated to the candidate + public void set_predicted_ATOF_sector(int s) { candidate.set_predicted_ATOF_sector(s); } + public void set_predicted_ATOF_layer(int l) { candidate.set_predicted_ATOF_layer(l); } + public void set_predicted_ATOF_wedge(int w) { candidate.set_predicted_ATOF_wedge(w); } + public int get_predicted_ATOF_sector() { return candidate.get_predicted_ATOF_sector(); } + public int get_predicted_ATOF_layer() { return candidate.get_predicted_ATOF_layer(); } + public int get_predicted_ATOF_wedge() { return candidate.get_predicted_ATOF_wedge(); } +} diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Track/TrackCandidate.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Track/TrackCandidate.java new file mode 100644 index 0000000000..7b796bd1b4 --- /dev/null +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Track/TrackCandidate.java @@ -0,0 +1,167 @@ +package org.jlab.rec.ahdc.Track; + +import org.jlab.rec.ahdc.AI.InterCluster; +import org.jlab.rec.ahdc.AHDCCluster.AHDCCluster; +import org.jlab.rec.ahdc.Hit.Hit; +import org.jlab.rec.ahdc.PreCluster.PreCluster; +import org.jlab.rec.ahdc.PreCluster.PreClusterFinder; +import org.jlab.rec.ahdc.AI.PreClustering; + +import java.util.ArrayList; +import java.util.List; + +/** A track candidate: the output of a track finder. + * + *

A candidate is a set of hits grouped into {@link AHDCCluster}s (and the + * derived {@link InterCluster}s). It is not a fitted track — fitting + * consumes a {@code TrackCandidate} and produces a {@link Track}.

+ * + *

A candidate carries a {@link CandidateType} describing its specialization + * (AHDC-only, AHDC+ATOF, ...). The type is what dictates how the candidate is + * fitted downstream.

+ */ +public class TrackCandidate { + + private double _Distance; + private List _Clusters = new ArrayList<>(); + private List _InterClusters = new ArrayList<>(); + private boolean _Used = false; + private final ArrayList hits = new ArrayList<>(); + + private int trackId = -1; ///< id of the track + private int n_hits = 0; ///< number of hits + private int sum_adc = 0; ///< sum of adc (adc) + + /** Candidate specialization — defaults to AHDC-only; finders that build + * AHDC+ATOF candidates (GNN) override it via {@link #setType}. */ + private CandidateType type = CandidateType.AHDC_ONLY; + /** ATOF hits attached to this candidate (non-empty only for AHDC_ATOF). */ + private final List atofHits = new ArrayList<>(); + + // AHDC::aiprediction + private int predicted_ATOF_sector = -1; + private int predicted_ATOF_layer = -1; + private int predicted_ATOF_wedge = -1; + + public TrackCandidate(List clusters) { + this._Clusters = clusters; + this._Distance = 0; + for (int i = 0; i < clusters.size() - 1; i++) { + this._Distance += Math.sqrt((clusters.get(i).get_X() - clusters.get(i + 1).get_X()) * (clusters.get(i).get_X() - clusters.get(i + 1).get_X()) + (clusters.get(i).get_Y() - clusters.get(i + 1).get_Y()) * (clusters.get(i).get_Y() - clusters.get(i + 1).get_Y())); + } + generateHitList(); + generateInterClusterList(); + } + + public TrackCandidate(ArrayList hitslist) { + hits.addAll(hitslist); + } + + private void generateHitList() { + for (AHDCCluster cluster : _Clusters) { + for (PreCluster preCluster : cluster.get_PreClusters_list()) { + hits.addAll(preCluster.get_hits_list()); + } + } + } + + private void generateInterClusterList() { + // Use hits to generate preclusters + PreClusterFinder preclusterfinder = new PreClusterFinder(); + preclusterfinder.findPreclusters(hits); + ArrayList AHDC_PreClusters = preclusterfinder.get_AHDCPreClusters(); + + // Use preclusters to generate interclusters + PreClustering preClustering = new PreClustering(); + this._InterClusters = preClustering.mergePreclusters(AHDC_PreClusters); + } + + public ArrayList getHits() { + return hits; + } + + @Override + public String toString() { + return "TrackCandidate{" + "_Clusters=" + _Clusters + '}'; + } + + public double get_Distance() { + return _Distance; + } + + public List get_Clusters() { + return _Clusters; + } + + public List getInterclusters() { + return _InterClusters; + } + + public boolean is_Used() { + return _Used; + } + + public void set_Used(boolean _Used) { + this._Used = _Used; + } + + public CandidateType getType() { + return type; + } + + public void setType(CandidateType type) { + this.type = type; + } + + public List getAtofHits() { + return atofHits; + } + + public void addAtofHit(AtofHitStub hit) { + atofHits.add(hit); + } + + public void set_trackId(int _trackId) { + trackId = _trackId; + // set trackId for clusters + for(AHDCCluster cluster : this._Clusters) { + cluster.set_trackId(_trackId); + } + // set trackId for interclusters + for(InterCluster interCluster : this._InterClusters) { + interCluster.setTrackId(_trackId); + } + // set trackId for hits + for (Hit hit : this.hits) { + hit.setTrackId(_trackId); + } + } + public void set_n_hits(int _n_hits) { n_hits = _n_hits;} + public void set_sum_adc(int _sum_adc) { sum_adc = _sum_adc;} + public int get_trackId() {return trackId;} + public int get_n_hits() { + if (hits == null) { + return 0; + } + return hits.size(); + } + public int get_sum_adc() { + if (hits == null || hits.isEmpty()) { + return 0; + } + int sum = 0; + for (Hit h : hits) { + sum += (int) Math.round(h.getADC()); + } + return sum; + } + + // AHDC::aiprediction + public void set_predicted_ATOF_sector(int s) {predicted_ATOF_sector = s;} + public void set_predicted_ATOF_layer(int l) {predicted_ATOF_layer = l;} + public void set_predicted_ATOF_wedge(int w) {predicted_ATOF_wedge = w;} + public int get_predicted_ATOF_sector() {return predicted_ATOF_sector;} + public int get_predicted_ATOF_layer() {return predicted_ATOF_layer;} + public int get_predicted_ATOF_wedge() {return predicted_ATOF_wedge;} + +} diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/AITrackFinder.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/AITrackFinder.java new file mode 100644 index 0000000000..087de22023 --- /dev/null +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/AITrackFinder.java @@ -0,0 +1,93 @@ +package org.jlab.rec.ahdc.TrackFinding; + +import org.jlab.rec.ahdc.AI.AIPrediction; +import org.jlab.rec.ahdc.AI.InterCluster; +import org.jlab.rec.ahdc.AI.ModelTrackFinding; +import org.jlab.rec.ahdc.AI.PreClustering; +import org.jlab.rec.ahdc.AI.TrackCandidatesGenerator; +import org.jlab.rec.ahdc.AI.TrackPrediction; +import org.jlab.rec.ahdc.Hit.Hit; +import org.jlab.rec.ahdc.PreCluster.PreCluster; +import org.jlab.rec.ahdc.PreCluster.PreClusterFinder; +import org.jlab.rec.ahdc.Track.TrackCandidate; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.Set; +import java.util.logging.Logger; + +public class AITrackFinder implements TrackFinder { + + private static final Logger LOGGER = Logger.getLogger(AITrackFinder.class.getName()); + + private static final double TRACK_FINDING_AI_THRESHOLD = 0.2; + private static final int MAX_HITS_FOR_AI = 300; + + private final ModelTrackFinding model; + private final TrackFinder fallback; + + public AITrackFinder() { + this.model = new ModelTrackFinding(); + this.fallback = new DistanceTrackFinder(); + } + + @Override + public TrackFinderResult findTracks(ArrayList hits) { + // Safety: too many hits → fall back to conventional track finding for this event + if (hits.size() > MAX_HITS_FOR_AI) { + LOGGER.info("Too many AHDC_Hits in AHDC::adc, rely on conventional track finding for this event"); + return fallback.findTracks(hits); + } + + // Preclustering + PreClusterFinder pcf = new PreClusterFinder(); + pcf.findPreclusters(hits); + ArrayList preclusters = pcf.get_AHDCPreClusters(); + + // 1) Create inter-clusters from pre-clusters + PreClustering preClustering = new PreClustering(); + ArrayList inter_clusters = preClustering.mergePreclusters(preclusters); + + // 2) Create track candidates from inter-clusters + ArrayList> tracks_candidates = new ArrayList<>(); + TrackCandidatesGenerator trackCandidatesGenerator = new TrackCandidatesGenerator(); + boolean success = trackCandidatesGenerator.getAllPossibleTrack(inter_clusters, tracks_candidates); + + if (!success) { + LOGGER.severe("Too many track candidates find by the AI, exiting..."); + return TrackFinderResult.invalid(); + } + + // 3) Use AI model to evaluate track candidates + ArrayList predictions; + try { + AIPrediction aiPrediction = new AIPrediction(); + predictions = aiPrediction.prediction(tracks_candidates, model); + } catch (Exception e) { + throw new RuntimeException(e); + } + + // 4) Select good tracks via greedy non-overlap: sort predictions by score + // descending, accept the highest-scoring prediction, mark its PreClusters + // as claimed, and skip any later prediction that reuses a claimed PreCluster. + // The AI candidate generator routinely emits overlapping predictions (each + // PreCluster can feed several combinations), and because set_trackId mutates + // the shared Hit references in place, a naive "accept all above threshold" + // pass would let later tracks silently steal earlier tracks' hits and leave + // them orphaned in AHDC::hits. Greedy selection enforces one-hit-one-track. + predictions.sort((a, b) -> Float.compare(b.getPrediction(), a.getPrediction())); + Set claimedPreclusters = new HashSet<>(); + ArrayList tracks = new ArrayList<>(); + for (TrackPrediction t : predictions) { + if (t.getPrediction() <= TRACK_FINDING_AI_THRESHOLD) continue; + boolean overlaps = false; + for (PreCluster pc : t.getPreclusters()) { + if (claimedPreclusters.contains(pc)) { overlaps = true; break; } + } + if (overlaps) continue; + claimedPreclusters.addAll(t.getPreclusters()); + tracks.add(new TrackCandidate(t.getClusters())); + } + return TrackFinderResult.ok(tracks); + } +} diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/DistanceTrackFinder.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/DistanceTrackFinder.java new file mode 100644 index 0000000000..ac64612dcc --- /dev/null +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/DistanceTrackFinder.java @@ -0,0 +1,31 @@ +package org.jlab.rec.ahdc.TrackFinding; + +import org.jlab.rec.ahdc.AHDCCluster.AHDCCluster; +import org.jlab.rec.ahdc.AHDCCluster.AHDCClusterFinder; +import org.jlab.rec.ahdc.Distance.Distance; +import org.jlab.rec.ahdc.Hit.Hit; +import org.jlab.rec.ahdc.PreCluster.PreCluster; +import org.jlab.rec.ahdc.PreCluster.PreClusterFinder; +import org.jlab.rec.ahdc.Track.TrackCandidate; + +import java.util.ArrayList; + +public class DistanceTrackFinder implements TrackFinder { + + @Override + public TrackFinderResult findTracks(ArrayList hits) { + PreClusterFinder pcf = new PreClusterFinder(); + pcf.findPreclusters(hits); + ArrayList preclusters = pcf.get_AHDCPreClusters(); + + AHDCClusterFinder cf = new AHDCClusterFinder(); + cf.findCluster(preclusters); + ArrayList clusters = cf.get_AHDCClusters(); + + Distance distance = new Distance(); + distance.find_track(clusters); + ArrayList tracks = distance.get_AHDCTrackCandidates(); + + return TrackFinderResult.ok(tracks); + } +} diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/GNNTrackFinder.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/GNNTrackFinder.java new file mode 100644 index 0000000000..92432c0f1f --- /dev/null +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/GNNTrackFinder.java @@ -0,0 +1,52 @@ +package org.jlab.rec.ahdc.TrackFinding; + +import org.jlab.io.base.DataBank; +import org.jlab.rec.ahdc.AI.GNNPrediction; +import org.jlab.rec.ahdc.AI.ModelTrackFindingGNN; +import org.jlab.rec.ahdc.Hit.Hit; +import org.jlab.rec.ahdc.Track.TrackCandidate; + +import java.util.ArrayList; +import java.util.logging.Logger; + +/** GravNet-based track finder. Builds a per-event hit graph from the AHDC + * hits and (when present) the ATOF::hits bank, runs the exported edge + * scorer, extracts tracks via connected components on edges with score + * ≥ 0.1, and packages each surviving track as a {@link TrackCandidate} + * backed by per-superlayer {@link org.jlab.rec.ahdc.AHDCCluster.AHDCCluster}s. + */ +public class GNNTrackFinder implements TrackFinder { + + private static final Logger LOGGER = Logger.getLogger(GNNTrackFinder.class.getName()); + + /** Above this hit count the graph builder + GNN inference is too slow to + * be useful, so the event is skipped (no GNN tracks produced). */ + private static final int MAX_HITS_FOR_GNN = 500; + + private final ModelTrackFindingGNN model; + private final GNNPrediction predictor; + + public GNNTrackFinder() { + this.model = new ModelTrackFindingGNN(); + this.predictor = new GNNPrediction(); + } + + /** Without an ATOF bank the GNN still runs on AHDC-only graphs. */ + @Override + public TrackFinderResult findTracks(ArrayList hits) { + return findTracks(hits, null); + } + + @Override + public TrackFinderResult findTracks(ArrayList ahdcHits, DataBank atofHitsBank) { + if (ahdcHits == null || ahdcHits.size() > MAX_HITS_FOR_GNN) { + if (ahdcHits != null) { + LOGGER.info("Too many AHDC_Hits in AHDC::hits, skipping GNN track finding for this event"); + } + return TrackFinderResult.ok(new ArrayList<>()); + } + + ArrayList tracks = predictor.prediction(ahdcHits, atofHitsBank, model); + return TrackFinderResult.ok(tracks); + } +} diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/HoughTrackFinder.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/HoughTrackFinder.java new file mode 100644 index 0000000000..13ca914c05 --- /dev/null +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/HoughTrackFinder.java @@ -0,0 +1,31 @@ +package org.jlab.rec.ahdc.TrackFinding; + +import org.jlab.rec.ahdc.AHDCCluster.AHDCCluster; +import org.jlab.rec.ahdc.AHDCCluster.AHDCClusterFinder; +import org.jlab.rec.ahdc.Hit.Hit; +import org.jlab.rec.ahdc.HoughTransform.HoughTransform; +import org.jlab.rec.ahdc.PreCluster.PreCluster; +import org.jlab.rec.ahdc.PreCluster.PreClusterFinder; +import org.jlab.rec.ahdc.Track.TrackCandidate; + +import java.util.ArrayList; + +public class HoughTrackFinder implements TrackFinder { + + @Override + public TrackFinderResult findTracks(ArrayList hits) { + PreClusterFinder pcf = new PreClusterFinder(); + pcf.findPreclusters(hits); + ArrayList preclusters = pcf.get_AHDCPreClusters(); + + AHDCClusterFinder cf = new AHDCClusterFinder(); + cf.findCluster(preclusters); + ArrayList clusters = cf.get_AHDCClusters(); + + HoughTransform hough = new HoughTransform(); + hough.find_tracks(clusters); + ArrayList tracks = hough.get_AHDCTrackCandidates(); + + return TrackFinderResult.ok(tracks); + } +} diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/TrackFinder.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/TrackFinder.java new file mode 100644 index 0000000000..e0d40f803d --- /dev/null +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/TrackFinder.java @@ -0,0 +1,20 @@ +package org.jlab.rec.ahdc.TrackFinding; + +import org.jlab.io.base.DataBank; +import org.jlab.rec.ahdc.Hit.Hit; + +import java.util.ArrayList; + +public interface TrackFinder { + + /** AHDC-only track finding. Implementations that don't need ATOF context + * (MLP / Distance / Hough) only need to override this method. */ + TrackFinderResult findTracks(ArrayList hits); + + /** Track finding with ATOF context (e.g. GNN, which builds a joint + * AHDC + ATOF hit graph). The default delegates to the AHDC-only + * version, ignoring the ATOF bank. */ + default TrackFinderResult findTracks(ArrayList ahdcHits, DataBank atofHitsBank) { + return findTracks(ahdcHits); + } +} diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/TrackFinderResult.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/TrackFinderResult.java new file mode 100644 index 0000000000..0489d5cef0 --- /dev/null +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/TrackFinding/TrackFinderResult.java @@ -0,0 +1,33 @@ +package org.jlab.rec.ahdc.TrackFinding; + +import org.jlab.rec.ahdc.Track.TrackCandidate; + +import java.util.Collections; +import java.util.List; + +public class TrackFinderResult { + + private final List tracks; + private final boolean valid; + + public TrackFinderResult(List tracks, boolean valid) { + this.tracks = tracks; + this.valid = valid; + } + + public static TrackFinderResult ok(List tracks) { + return new TrackFinderResult(tracks, true); + } + + public static TrackFinderResult invalid() { + return new TrackFinderResult(Collections.emptyList(), false); + } + + public List getTracks() { + return tracks; + } + + public boolean isValid() { + return valid; + } +} diff --git a/reconstruction/alert/src/main/java/org/jlab/service/ahdc/AHDCEngine.java b/reconstruction/alert/src/main/java/org/jlab/service/ahdc/AHDCEngine.java index 66b965825c..f5beeb1591 100644 --- a/reconstruction/alert/src/main/java/org/jlab/service/ahdc/AHDCEngine.java +++ b/reconstruction/alert/src/main/java/org/jlab/service/ahdc/AHDCEngine.java @@ -5,23 +5,13 @@ import org.jlab.io.base.DataEvent; import org.jlab.io.hipo.HipoDataSource; import org.jlab.io.hipo.HipoDataSync; -import org.jlab.rec.ahdc.AI.*; import org.jlab.rec.ahdc.Banks.RecoBankWriter; -import org.jlab.rec.ahdc.Cluster.Cluster; -import org.jlab.rec.ahdc.Cluster.ClusterFinder; -import org.jlab.rec.ahdc.DocaCluster.DocaClusterRefiner; -import org.jlab.rec.ahdc.DocaCluster.DocaCluster; -import org.jlab.rec.ahdc.Distance.Distance; -import org.jlab.rec.ahdc.HelixFit.HelixFitJava; import org.jlab.rec.ahdc.Hit.Hit; import org.jlab.rec.ahdc.Hit.HitReader; -import org.jlab.rec.ahdc.HoughTransform.HoughTransform; -import org.jlab.rec.ahdc.PreCluster.PreCluster; -import org.jlab.rec.ahdc.PreCluster.PreClusterFinder; -import org.jlab.rec.ahdc.Track.Track; -import org.jlab.rec.ahdc.ModeTrackFinding; import java.io.File; -import java.util.*; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; import java.util.logging.Logger; import org.jlab.detector.calib.utils.DatabaseConstantProvider; @@ -32,22 +22,14 @@ /** AHDCEngine reconstruction service. * - * AHDC Reconstruction using only AHDC information. - * - * Reconstruction utilizing other detectors (i.e. ATOF) are - * implemented in ALERTEngine. + * Reads AHDC::adc, applies calibration, and writes calibrated AHDC::hits. * + * Track finding (preclustering, AI/CV track finders, DOCA refinement, + * helix fit) lives in ALERTEngine. */ public class AHDCEngine extends ReconstructionEngine { static final Logger LOGGER = Logger.getLogger(AHDCEngine.class.getName()); - private boolean simulation = false; - - private ModelTrackFinding modelTrackFinding; - private ModeTrackFinding modeTrackFinding = ModeTrackFinding.AI_Track_Finding; - static final double TRACK_FINDING_AI_THRESHOLD = 0.2; - static final int MAX_HITS_FOR_AI = 300; - private AlertDCDetector factory = null; private ModeAHDC ahdcExtractor = new ModeAHDC(); @@ -62,20 +44,11 @@ public class AHDCEngine extends ReconstructionEngine { public AHDCEngine() { super("ALERT", "ouillon", "1.0.1"); } - public boolean init(ModeTrackFinding m) { - modeTrackFinding = m; - return init(); - } - @Override public boolean init() { factory = (new AlertDCFactory()).createDetectorCLAS(new DatabaseConstantProvider()); - String modeConfig = this.getEngineConfigString("Mode"); - if (modeConfig != null) modeTrackFinding = ModeTrackFinding.valueOf(modeConfig); - if (modeTrackFinding == ModeTrackFinding.AI_Track_Finding) modelTrackFinding = new ModelTrackFinding(); - Map tableMap = new HashMap<>(); tableMap.put("/calibration/alert/ahdc/time_offsets", 3); tableMap.put("/calibration/alert/ahdc/time_to_distance_wire", 3); @@ -84,8 +57,8 @@ public boolean init() { tableMap.put("/calibration/alert/ahdc/time_over_threshold", 3); requireConstants(tableMap); - this.getConstantsManager().setVariation("default"); - this.registerOutputBank("AHDC::hits","AHDC::preclusters","AHDC::clusters","AHDC::track","AHDC::mc","AHDC::ai:prediction","AHDC::interclusters","AHDC::docaclusters"); + this.getConstantsManager().setVariation("default"); + this.registerOutputBank("AHDC::hits"); return true; } @@ -93,8 +66,6 @@ public boolean init() { @Override public boolean processDataEvent(DataEvent event) { - if(event.hasBank("MC::Particle")) simulation = true; - ahdcExtractor.update(30, null, event, "AHDC::wf", "AHDC::adc"); if (event.hasBank("RUN::config")) { @@ -115,155 +86,15 @@ public boolean processDataEvent(DataEvent event) { } if (event.hasBank("AHDC::adc")) { - // I) Read raw hits + boolean simulation = event.hasBank("MC::Particle"); HitReader hitReader = new HitReader(event, factory, simulation, ahdcRawHitCutsTable, ahdcTimeOffsetsTable, ahdcTimeToDistanceWireTable, ahdcTimeOverThresholdTable, ahdcAdcGainsTable); ArrayList AHDC_Hits = hitReader.get_AHDCHits(); - // II) Create PreClusters - PreClusterFinder preclusterfinder = new PreClusterFinder(); - preclusterfinder.findPreclusters(AHDC_Hits); - ArrayList AHDC_PreClusters = preclusterfinder.get_AHDCPreClusters(); - - - // III) Track Finding: Input = PreClusters, Output = Tracks - // During track finding we build Clusters and InterClusters. Each of these objects must be assigned a Track ID so we can: - // - identify which track they belong to, - // - write them properly into the output banks later, - // - and reuse them downstream in the ALERT Engine. - // - // If using AI-based track finding, tracks are identified using inter-clusters. - // Otherwise, the conventional methods (Hough Transform or distance) use clusters. - - // Safety check: if too many hits, rely on conventional track finding - ModeTrackFinding effectiveMode = modeTrackFinding; - if (AHDC_Hits.size() > MAX_HITS_FOR_AI) { - LOGGER.info("Too many AHDC_Hits in AHDC::adc, rely on conventional track finding for this event"); - effectiveMode = ModeTrackFinding.CV_Distance; - } - - ArrayList AHDC_Tracks = new ArrayList<>(); - - if (effectiveMode == ModeTrackFinding.AI_Track_Finding) { - // 1) Create inter-clusters from pre-clusters - PreClustering preClustering = new PreClustering(); - ArrayList inter_clusters = preClustering.mergePreclusters(AHDC_PreClusters); - - // 2) Create track candidates from inter-clusters - ArrayList> tracks_candidates = new ArrayList<>(); - TrackCandidatesGenerator trackCandidatesGenerator = new TrackCandidatesGenerator(); - boolean success = trackCandidatesGenerator.getAllPossibleTrack(inter_clusters, tracks_candidates); - - if (!success) { - LOGGER.severe("Too many track candidates find by the AI, exiting..."); - return false; - } - - // 3) Use AI model to evaluate track candidates - ArrayList predictions = new ArrayList<>(); - try { - AIPrediction aiPrediction = new AIPrediction(); - predictions = aiPrediction.prediction(tracks_candidates, modelTrackFinding); - } catch (Exception e) { - throw new RuntimeException(e); - } - - // 4) Select good tracks via greedy non-overlap: sort predictions by score - // descending, accept the highest-scoring prediction, mark its PreClusters - // as claimed, and skip any later prediction that reuses a claimed PreCluster. - // The AI candidate generator routinely emits overlapping predictions (each - // PreCluster can feed several combinations), and because set_trackId mutates - // the shared Hit references in place, a naive "accept all above threshold" - // pass would let later tracks silently steal earlier tracks' hits and leave - // them orphaned in AHDC::hits. Greedy selection enforces one-hit-one-track. - predictions.sort((a, b) -> Float.compare(b.getPrediction(), a.getPrediction())); - Set claimedPreclusters = new HashSet<>(); - for (TrackPrediction t : predictions) { - if (t.getPrediction() <= TRACK_FINDING_AI_THRESHOLD) continue; - boolean overlaps = false; - for (PreCluster pc : t.getPreclusters()) { - if (claimedPreclusters.contains(pc)) { overlaps = true; break; } - } - if (overlaps) continue; - claimedPreclusters.addAll(t.getPreclusters()); - AHDC_Tracks.add(new Track(t.getClusters())); - } - } - else { - // Conventional Track Finding: Hough Transform or Distance: use cluster informations to find tracks - // 1) Create clusters from pre-clusters - ClusterFinder clusterfinder = new ClusterFinder(); - clusterfinder.findCluster(AHDC_PreClusters); - ArrayList AHDC_Clusters = clusterfinder.get_AHDCClusters(); - - // 2) Find tracks using the selected conventional method - if (effectiveMode == ModeTrackFinding.CV_Distance) { - Distance distance = new Distance(); - distance.find_track(AHDC_Clusters); - AHDC_Tracks = distance.get_AHDCTracks(); - } - else if (effectiveMode == ModeTrackFinding.CV_Hough) { - HoughTransform houghtransform = new HoughTransform(); - houghtransform.find_tracks(AHDC_Clusters); - AHDC_Tracks = houghtransform.get_AHDCTracks(); - } - } - - - //Temporary track method ONLY for MC with no background; - //AHDC_Tracks.add(new Track(AHDC_Hits)); - - // V) Global fit - int trackid = 0; - ArrayList all_docaClusters = new ArrayList<>(); - AHDC_Tracks.removeIf(track -> track.get_Clusters().size() < 3); - for (Track track : AHDC_Tracks) { - trackid++; - track.set_trackId(trackid); - List originalClusters = track.get_Clusters(); - ArrayList docaClusters = DocaClusterRefiner.buildRefinedClusters(originalClusters); - all_docaClusters.addAll(docaClusters); - if (docaClusters == null || docaClusters.size() < 3 || originalClusters == null || originalClusters.size() < 3) { - // not enough points, skip helix fit - continue; - } - HelixFitJava h = new HelixFitJava(); - track.setPositionAndMomentum(h.helix_fit_with_doca_selection(docaClusters, 1)); - } - - // VII) Write bank RecoBankWriter writer = new RecoBankWriter(); - - DataBank recoHitsBank = writer.fillAHDCHitsBank(event, AHDC_Hits); - DataBank recoPreClusterBank = writer.fillPreClustersBank(event, AHDC_PreClusters); - ArrayList AHDC_Clusters = new ArrayList<>(); - for (Track track : AHDC_Tracks) { - AHDC_Clusters.addAll(track.get_Clusters()); - } - DataBank recoClusterBank = writer.fillClustersBank(event, AHDC_Clusters); - DataBank recoTracksBank = writer.fillAHDCTrackBank(event, AHDC_Tracks); - DataBank clustersDocaBank = writer.fillAHDCDocaClustersBank(event, all_docaClusters); - - ArrayList all_interclusters = new ArrayList<>(); - for (Track track : AHDC_Tracks) { - all_interclusters.addAll(track.getInterclusters()); - } - DataBank recoInterClusterBank = writer.fillInterClusterBank(event, all_interclusters); - - //event.removeBanks("AHDC::hits","AHDC::preclusters","AHDC::clusters","AHDC::track","AHDC::kftrack","AHDC::mc","AHDC::ai:prediction"); - event.appendBank(recoHitsBank); - event.appendBank(recoPreClusterBank); - event.appendBank(recoClusterBank); - event.appendBank(recoTracksBank); - event.appendBank(recoInterClusterBank); - event.appendBank(clustersDocaBank); - - if (simulation) { - DataBank recoMCBank = writer.fillAHDCMCTrackBank(event); - event.appendBank(recoMCBank); - } - + DataBank recoHitsBank = writer.fillAHDCHitsBank(event, AHDC_Hits); + if (recoHitsBank != null) event.appendBank(recoHitsBank); } return true; } @@ -274,7 +105,6 @@ public static void main(String[] args) { int nEvent = 0; int maxEvent = 10; - int myEvent = 3; String inputFile = "output1.hipo"; String outputFile = "output.hipo"; @@ -286,24 +116,17 @@ public static void main(String[] args) { HipoDataSource reader = new HipoDataSource(); - // en.init(); - en.init(ModeTrackFinding.AI_Track_Finding); + en.init(); reader.open(inputFile); - // SchemaFactory factory = reader.getReader().getSchemaFactory(); HipoDataSync writer = new HipoDataSync(); writer.open(outputFile); while (reader.hasEvent() && nEvent < maxEvent) { nEvent++; - // if (nEvent % 100 == 0) System.out.println("nEvent = " + nEvent); DataEvent event = reader.getNextEvent(); System.out.println("Event: " + nEvent); - // if (nEvent != myEvent) continue; - // System.out.println("*********** NEXT EVENT ************"); - // event.show(); - en.processDataEvent(event); writer.writeEvent(event); @@ -312,4 +135,4 @@ public static void main(String[] args) { System.out.println("finished " + (System.nanoTime() - starttime) * Math.pow(10, -9)); } -} \ No newline at end of file +} diff --git a/reconstruction/alert/src/main/java/org/jlab/service/alert/ALERTEngine.java b/reconstruction/alert/src/main/java/org/jlab/service/alert/ALERTEngine.java index a899e20c8a..c77fbff99a 100644 --- a/reconstruction/alert/src/main/java/org/jlab/service/alert/ALERTEngine.java +++ b/reconstruction/alert/src/main/java/org/jlab/service/alert/ALERTEngine.java @@ -24,13 +24,29 @@ import org.jlab.rec.alert.banks.RecoBankWriter; import org.jlab.rec.alert.projections.TrackProjector; import org.jlab.rec.atof.hit.ATOFHit; +import org.jlab.rec.ahdc.AI.InterCluster; +import org.jlab.rec.ahdc.AHDCCluster.AHDCCluster; +import org.jlab.rec.ahdc.DocaCluster.DocaCluster; +import org.jlab.rec.ahdc.DocaCluster.DocaClusterRefiner; +import org.jlab.rec.ahdc.HelixFit.HelixFitJava; import org.jlab.rec.ahdc.KalmanFilter.KalmanFilter; +import org.jlab.rec.ahdc.ModeTrackFinding; +import org.jlab.rec.ahdc.PreCluster.PreCluster; +import org.jlab.rec.ahdc.PreCluster.PreClusterFinder; import org.jlab.rec.ahdc.Hit.Hit; +import org.jlab.rec.ahdc.TrackFinding.AITrackFinder; +import org.jlab.rec.ahdc.TrackFinding.DistanceTrackFinder; +import org.jlab.rec.ahdc.TrackFinding.GNNTrackFinder; +import org.jlab.rec.ahdc.TrackFinding.HoughTrackFinder; +import org.jlab.rec.ahdc.TrackFinding.TrackFinder; +import org.jlab.rec.ahdc.TrackFinding.TrackFinderResult; import org.jlab.geom.detector.alert.AHDC.AlertDCDetector; import org.jlab.geom.detector.alert.AHDC.AlertDCFactory; import org.jlab.rec.ahdc.Track.Track; +import org.jlab.rec.ahdc.Track.TrackCandidate; import org.jlab.clas.pdg.PDGDatabase; import org.jlab.clas.pdg.PDGParticle; +import java.util.List; import java.util.logging.Logger; @@ -75,6 +91,11 @@ public class ALERTEngine extends ReconstructionEngine { private ModelTrackMatching modelTrackMatching; private ModelPrePID modelPrePID; + // AHDC track-finding strategy (driven by ALERT.Mode YAML key) + private TrackFinder trackFinder; + private final org.jlab.rec.ahdc.Banks.RecoBankWriter ahdcWriter + = new org.jlab.rec.ahdc.Banks.RecoBankWriter(); + // AHDC calibration table (refreshed on run change) private IndexedTable ahdcAdcGainsTable; @@ -106,19 +127,33 @@ public boolean init() { modelPrePID = new ModelPrePID(); AlertTOFFactory factory = new AlertTOFFactory(); + + // One CCDB session for both ATOF and AHDC geometry. DatabaseConstantProvider cp = new DatabaseConstantProvider(11, "default"); ATOF = factory.createDetectorCLAS(cp); - AHDC = (new AlertDCFactory()).createDetectorCLAS(new DatabaseConstantProvider()); + AHDC = (new AlertDCFactory()).createDetectorCLAS(cp); Map tableMap = new HashMap<>(); tableMap.put("/calibration/alert/ahdc/gains", 3); requireConstants(tableMap); this.getConstantsManager().setVariation("default"); - if(this.getEngineConfigString("Mode")!=null) { - //if (Objects.equals(this.getEngineConfigString("Mode"), Mode.AI_Track_Finding.name())) - // mode = Mode.AI_Track_Finding; + ModeTrackFinding mode = ModeTrackFinding.MLP_Track_Finding; + String modeConfig = this.getEngineConfigString("Mode"); + if (modeConfig != null) mode = ModeTrackFinding.valueOf(modeConfig); + switch (mode) { + case MLP_Track_Finding: trackFinder = new AITrackFinder(); break; + case CV_Distance: trackFinder = new DistanceTrackFinder(); break; + case CV_Hough: trackFinder = new HoughTrackFinder(); break; + case GNN_Track_Finding: trackFinder = new GNNTrackFinder(); break; } + + this.registerOutputBank( + "AHDC::preclusters", "AHDC::clusters", "AHDC::track", + "AHDC::interclusters", "AHDC::docaclusters", "AHDC::ai:prediction", + "AHDC::mc", "AHDC::kftrack", + "ALERT::projections", "ALERT::ai:projections", "ALERT::prePID"); + return true; } @@ -134,9 +169,7 @@ public boolean init() { @Override public boolean processDataEvent(DataEvent event) { - if (!event.hasBank("AHDC::adc")) - return false; - if (!event.hasBank("ATOF::tdc")) + if (!event.hasBank("AHDC::adc")) return false; if (!event.hasBank("RUN::config")) { @@ -154,7 +187,129 @@ public boolean processDataEvent(DataEvent event) { run.set(newRun); ahdcAdcGainsTable = this.getConstantsManager().getConstants(newRun, "/calibration/alert/ahdc/gains"); } - + + // =========================================================================== + // AHDC track-finding pipeline (preclustering, track finder, DOCA, helix fit) + // Originally lived in AHDCEngine; runs here so AHDCEngine is hits-only. + // Reads AHDC::hits produced by AHDCEngine, mutates Hit.trackId during finding, + // then rewrites AHDC::hits and writes the cluster/track/intercluster banks. + // =========================================================================== + boolean simulation = event.hasBank("MC::Particle"); + + if (event.hasBank("AHDC::hits")) { + + // I) Reconstruct Hit list from AHDC::hits bank + DataBank ahdcHitBank = event.getBank("AHDC::hits"); + ArrayList AHDC_Hits = new ArrayList<>(); + for (int row = 0; row < ahdcHitBank.rows(); row++) { + int id = ahdcHitBank.getShort("id", row); + int superlayer = ahdcHitBank.getByte("superlayer", row); + int layer = ahdcHitBank.getByte("layer", row); + int wire = ahdcHitBank.getInt("wire", row); + int adc = ahdcHitBank.getInt("adc", row); + double doca = ahdcHitBank.getDouble("doca", row); + double time = ahdcHitBank.getDouble("time", row); + double tot = ahdcHitBank.getDouble("timeOverThreshold", row); + Hit hit = new Hit(id, superlayer, layer, wire, doca, adc, time); + hit.setWirePosition(AHDC); + hit.setADC(adc); + hit.setToT(tot); + AHDC_Hits.add(hit); + } + + // II) Track Finding via the strategy selected in init() (ALERT.Mode YAML key). + // The implementation owns its own preclustering, cluster building, and any + // mode-specific safety fallbacks (e.g. AITrackFinder delegates to Distance + // when the hit count exceeds its MAX_HITS_FOR_AI threshold). The ATOF bank + // is passed for finders that build joint AHDC+ATOF graphs (GNN); the + // AHDC-only finders inherit the default and ignore it. + DataBank atofHitsBankForGNN = event.hasBank("ATOF::hits") ? event.getBank("ATOF::hits") : null; + TrackFinderResult trackResult = trackFinder.findTracks(AHDC_Hits, atofHitsBankForGNN); + if (!trackResult.isValid()) { + return false; + } + ArrayList AHDC_Candidates = new ArrayList<>(trackResult.getTracks()); + + // Preclusters are also written to AHDC::preclusters as a diagnostic bank; + // PreClusterFinder is idempotent on Hit.use, so re-running it here is safe. + PreClusterFinder preclusterfinder = new PreClusterFinder(); + preclusterfinder.findPreclusters(AHDC_Hits); + ArrayList AHDC_PreClusters = preclusterfinder.get_AHDCPreClusters(); + + // IV) Global fit: DOCA refinement + helix fit. + // Each surviving TrackCandidate is fitted into a Track; the fit is + // dispatched on the candidate's CandidateType (see the switch below). + int trackid = 0; + ArrayList all_docaClusters = new ArrayList<>(); + AHDC_Candidates.removeIf(cand -> cand.get_Clusters().size() < 3); + ArrayList AHDC_Tracks = new ArrayList<>(); + for (TrackCandidate cand : AHDC_Candidates) { + trackid++; + cand.set_trackId(trackid); + // Every surviving candidate yields an AHDC::track row, even if the + // helix fit below is skipped (its Track keeps zero parameters). + Track track = new Track(cand); + AHDC_Tracks.add(track); + List originalClusters = cand.get_Clusters(); + ArrayList docaClusters = DocaClusterRefiner.buildRefinedClusters(originalClusters); + all_docaClusters.addAll(docaClusters); + if (docaClusters == null || docaClusters.size() < 3 || originalClusters == null || originalClusters.size() < 3) { + // not enough points, skip helix fit + continue; + } + HelixFitJava h = new HelixFitJava(); + switch (cand.getType()) { + case AHDC_ATOF: + // EXTENSION HOOK: a future commit may incorporate + // cand.getAtofHits() as an additional fit constraint here. + // For now AHDC_ATOF fits exactly like AHDC_ONLY — the ATOF + // hits are carried on the candidate, not yet fitted. + case AHDC_ONLY: + case AHDC_VERTEX: + default: + track.setPositionAndMomentum(h.helix_fit_with_doca_selection(docaClusters, 1)); + break; + } + } + + // V) Replace AHDC::hits (now with trackId) and write track-finding output banks + DataBank recoHitsBank = ahdcWriter.fillAHDCHitsBank(event, AHDC_Hits); + DataBank recoPreClusterBank = ahdcWriter.fillPreClustersBank(event, AHDC_PreClusters); + ArrayList AHDC_Clusters = new ArrayList<>(); + for (Track track : AHDC_Tracks) { + AHDC_Clusters.addAll(track.get_Clusters()); + } + DataBank recoClusterBank = ahdcWriter.fillClustersBank(event, AHDC_Clusters); + DataBank recoTracksBank = ahdcWriter.fillAHDCTrackBank(event, AHDC_Tracks); + DataBank clustersDocaBank = ahdcWriter.fillAHDCDocaClustersBank(event, all_docaClusters); + + ArrayList all_interclusters = new ArrayList<>(); + for (Track track : AHDC_Tracks) { + all_interclusters.addAll(track.getInterclusters()); + } + DataBank recoInterClusterBank = ahdcWriter.fillInterClusterBank(event, all_interclusters); + + event.removeBank("AHDC::hits"); + event.appendBank(recoHitsBank); + event.appendBank(recoPreClusterBank); + event.appendBank(recoClusterBank); + event.appendBank(recoTracksBank); + event.appendBank(recoInterClusterBank); + event.appendBank(clustersDocaBank); + + if (simulation) { + DataBank recoMCBank = ahdcWriter.fillAHDCMCTrackBank(event); + event.appendBank(recoMCBank); + } + } + // =========================================================================== + + // ATOF-dependent processing follows. Bail out for events without ATOF::tdc + // so the AHDC track-finding output above stands on its own (matches the + // pre-refactor flow where AHDCEngine ran independently of ATOF presence). + if (!event.hasBank("ATOF::tdc")) + return false; + //Do we need to read the event vx,vy,vz? //If not, this part can be moved in the initialization of the engine. double eventVx=0,eventVy=0,eventVz=0; //They should be in CM @@ -196,50 +351,60 @@ public boolean processDataEvent(DataEvent event) { } if (interClusters.size() != 5) continue; + float[] pred; try { + pred = modelTrackMatching.prediction(interClusters); + } catch (TranslateException ex) { + LOGGER.warning(() -> "Exception in ALERTEngine track matching: " + ex); + continue; + } + int sector_pred = (int) pred[0]; + int layer_pred = (int) pred[1]; + int wedge_pred = (int) pred[2]; + + // The matching model's three argmax heads can land outside the ATOF + // ranges (sectors 0-14, layers 0-3, wedges 0-9) when the input + // interclusters fall outside its training distribution; the ATOFHit + // geometry lookup chain returns null on a miss and would NPE. + if (sector_pred < 0 || sector_pred >= 15 + || layer_pred < 0 || layer_pred >= 4 + || wedge_pred < 0 || wedge_pred >= 10) { + continue; + } - float[] pred = modelTrackMatching.prediction(interClusters); - int sector_pred = (int) pred[0]; - int layer_pred = (int) pred[1]; - int wedge_pred = (int) pred[2]; - - ATOFHit hit_pred = new ATOFHit(sector_pred, layer_pred, wedge_pred, 0, 0, 0, 0f, ATOF, null); - double pred_x = hit_pred.getX(); - double pred_y = hit_pred.getY(); - double pred_z = hit_pred.getZ(); + ATOFHit hit_pred = new ATOFHit(sector_pred, layer_pred, wedge_pred, 0, 0, 0, 0f, ATOF, null); + double pred_x = hit_pred.getX(); + double pred_y = hit_pred.getY(); + double pred_z = hit_pred.getZ(); - double threshold = 20.0; - double minDistanceSquared = threshold * threshold; + double threshold = 20.0; + double minDistanceSquared = threshold * threshold; - ATOFHit matchAtofHit = null; // Could be used later - int matchHitId = -1; + ATOFHit matchAtofHit = null; // Could be used later + int matchHitId = -1; - for (int k = 0; k < bank_ATOFHits.rows(); k++) { - int component = bank_ATOFHits.getInt("component", k); - if (component == 10) continue; + for (int k = 0; k < bank_ATOFHits.rows(); k++) { + int component = bank_ATOFHits.getInt("component", k); + if (component == 10) continue; - int sector = bank_ATOFHits.getInt("sector", k); - int layer = bank_ATOFHits.getInt("layer", k); + int sector = bank_ATOFHits.getInt("sector", k); + int layer = bank_ATOFHits.getInt("layer", k); - ATOFHit hit = new ATOFHit(sector, layer, component, 0, 0, 0, 0f, ATOF, null); + ATOFHit hit = new ATOFHit(sector, layer, component, 0, 0, 0, 0f, ATOF, null); - double dx = pred_x - hit.getX(); - double dy = pred_y - hit.getY(); - double dz = pred_z - hit.getZ(); + double dx = pred_x - hit.getX(); + double dy = pred_y - hit.getY(); + double dz = pred_z - hit.getZ(); - double distanceSquared = dx * dx + dy * dy + dz * dz; + double distanceSquared = dx * dx + dy * dy + dz * dz; - if (distanceSquared < minDistanceSquared) { - minDistanceSquared = distanceSquared; - matchAtofHit = hit; - matchHitId = bank_ATOFHits.getInt("id", k); - } + if (distanceSquared < minDistanceSquared) { + minDistanceSquared = distanceSquared; + matchAtofHit = hit; + matchHitId = bank_ATOFHits.getInt("id", k); } - matched_ATOF_hit_id.add(new Pair<>(track_id, matchHitId)); - - } catch (Exception ex) { - System.out.println("Exception in ALERTEngine processDataEvent: " + ex); // TODO: proper logging } + matched_ATOF_hit_id.add(new Pair<>(track_id, matchHitId)); } rbc.appendTrackMatchingAIBank(event, matched_ATOF_hit_id); @@ -359,15 +524,14 @@ public boolean processDataEvent(DataEvent event) { AHDC_hits.add(hit); } } - // Initialise the position and the momentum using the information of the AHDC::track - // position : mm - // momentum : MeV - // Invariant: AHDC_hits is non-empty. AHDCEngine's AI_Track_Finding path uses greedy + // Rebuild a hits-only TrackCandidate and wrap it in a Track seeded + // with the banked helix-fit position/momentum (mm / MeV). + // Invariant: AHDC_hits is non-empty. AHDCEngine's MLP_Track_Finding path uses greedy // non-overlap selection so each PreCluster (and thus each Hit) belongs to at most one // surviving track, so the set_trackId stamping is unambiguous and every AHDC::track - // row has matching AHDC::hits rows. If this invariant ever flips, the get(0) inside - // Track(ArrayList) fails loudly here, which is the right signal. - Track newTrack = new Track(AHDC_hits); + // row has matching AHDC::hits rows. + TrackCandidate newCandidate = new TrackCandidate(AHDC_hits); + Track newTrack = new Track(newCandidate); double[] vec = { trackBank.getFloat("x", row), trackBank.getFloat("y", row), diff --git a/reconstruction/alert/src/test/java/org/jlab/service/alert/AHDCTest.java b/reconstruction/alert/src/test/java/org/jlab/service/alert/AHDCTest.java index 0c372b919c..640ab7fea4 100644 --- a/reconstruction/alert/src/test/java/org/jlab/service/alert/AHDCTest.java +++ b/reconstruction/alert/src/test/java/org/jlab/service/alert/AHDCTest.java @@ -7,7 +7,6 @@ import org.jlab.detector.base.DetectorType; import org.jlab.analysis.physics.TestEvent; import org.jlab.service.ahdc.AHDCEngine; -import org.jlab.rec.ahdc.ModeTrackFinding; /** * @@ -23,12 +22,11 @@ public void run() { DataEvent event = TestEvent.get(DetectorType.AHDC); AHDCEngine engine = new AHDCEngine(); - engine.init(ModeTrackFinding.AI_Track_Finding); + engine.init(); engine.processDataEvent(event); event.show(); event.getBank("AHDC::hits").show(); - event.getBank("AHDC::clusters").show(); assertEquals(event.hasBank("FAKE::Bank"), false); assertEquals(event.hasBank("AHDC::wf"), true);