Class ModelTrackFindingGNN
java.lang.Object
org.jlab.rec.alert.AI.ModelTrackFindingGNN
DJL wrapper around the GravNet TorchScript model exported from
track-finding/export_torchscript.py. Runs per-event edge scoring.
Exported forward signature (see SingleGraphEdgeScorer):
forward(x: float32[N, 10], edge_index: int64[2, E], edge_attr: float32[E, 9])
-> float32[E] (sigmoid edge scores in [0, 1])
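The forward signature fixes three shapes that must agree with each other. x has N rows of 10 features, edge_index has two rows of E node ids, and edge_attr has E rows of 9 features. The sketch below shows a minimal pre-flight check, assuming the caller holds the data as Java 2D arrays. The class GraphShapeCheck and its check method are hypothetical and not part of this API. The source points to GNNConstants.NODE_FEAT_DIM and GNNConstants.EDGE_FEAT_DIM for the two widths; they are hard-coded here so the sketch compiles on its own. The range check on node ids is an assumption: an id outside [0, N) cannot name a real node row.

```java
/** Hypothetical helper: checks Java arrays against the exported forward() signature. */
class GraphShapeCheck {
    // Widths copied from the forward signature; the real code presumably reads
    // GNNConstants.NODE_FEAT_DIM / GNNConstants.EDGE_FEAT_DIM instead.
    static final int NODE_FEAT_DIM = 10;
    static final int EDGE_FEAT_DIM = 9;

    /** Throws IllegalArgumentException unless shapes are [N, 10], [2, E], [E, 9]. */
    static void check(float[][] x, long[][] edgeIndex, float[][] edgeAttr) {
        int n = x.length;
        for (float[] row : x) {
            if (row.length != NODE_FEAT_DIM)
                throw new IllegalArgumentException("x must be [N, " + NODE_FEAT_DIM + "]");
        }
        // Row 0 holds source ids, row 1 holds destination ids; both have length E.
        if (edgeIndex.length != 2 || edgeIndex[0].length != edgeIndex[1].length)
            throw new IllegalArgumentException("edge_index must be [2, E]");
        int e = edgeIndex[0].length;
        if (edgeAttr.length != e)
            throw new IllegalArgumentException("edge_attr must have E = " + e + " rows");
        for (float[] row : edgeAttr) {
            if (row.length != EDGE_FEAT_DIM)
                throw new IllegalArgumentException("edge_attr must be [E, " + EDGE_FEAT_DIM + "]");
        }
        // Every source / destination id must refer to an existing node row.
        for (long[] ids : edgeIndex) {
            for (long id : ids) {
                if (id < 0 || id >= n)
                    throw new IllegalArgumentException("node id " + id + " out of range [0, " + n + ")");
            }
        }
    }

    public static void main(String[] args) {
        float[][] x = new float[3][NODE_FEAT_DIM];      // N = 3 nodes
        long[][] edgeIndex = { {0, 1}, {1, 2} };         // E = 2 edges: 0->1, 1->2
        float[][] edgeAttr = new float[2][EDGE_FEAT_DIM];
        check(x, edgeIndex, edgeAttr);                    // consistent shapes: no exception
        try {
            check(x, new long[][] { {0}, {5} }, new float[1][EDGE_FEAT_DIM]);
        } catch (IllegalArgumentException ex) {
            System.out.println(ex.getMessage());          // node id 5 out of range [0, 3)
        }
    }
}
```

Failing fast in Java gives a clearer message than a shape error raised inside the TorchScript graph.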
DJL (Deep Java Library) is the inference engine used here. See its documentation for the Criteria / ZooModel / Predictor / NDArray APIs used by this class:
- DJL docs (master)
- Loading a model (drives the Criteria/ZooModel block)
- PyTorch engine (TorchScript inference)
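DJL's NDManager.create(float[] data, Shape shape) and create(long[] data, Shape shape) overloads take a flat buffer plus a shape. Java 2D arrays therefore have to be flattened before they become NDArrays. The sketch below shows row-major flattening, assuming this is how the class marshals its inputs. RowMajor is a hypothetical helper, not code from this repository, and the DJL call itself is left out so the example runs without the library.

```java
import java.util.Arrays;

/** Hypothetical helper: flattens rectangular 2D arrays into row-major 1D buffers. */
class RowMajor {
    /** Element (i, j) of a float[rows][cols] lands at index i * cols + j. */
    static float[] flatten(float[][] a) {
        int rows = a.length;
        int cols = rows == 0 ? 0 : a[0].length;
        float[] out = new float[rows * cols];
        for (int i = 0; i < rows; i++) {
            if (a[i].length != cols) throw new IllegalArgumentException("ragged row " + i);
            System.arraycopy(a[i], 0, out, i * cols, cols);
        }
        return out;
    }

    /** Same layout for long[2][E] edge indices: all sources first, then all destinations. */
    static long[] flatten(long[][] a) {
        int rows = a.length;
        int cols = rows == 0 ? 0 : a[0].length;
        long[] out = new long[rows * cols];
        for (int i = 0; i < rows; i++) {
            if (a[i].length != cols) throw new IllegalArgumentException("ragged row " + i);
            System.arraycopy(a[i], 0, out, i * cols, cols);
        }
        return out;
    }

    public static void main(String[] args) {
        float[][] x = { {1f, 2f, 3f}, {4f, 5f, 6f} };
        System.out.println(Arrays.toString(flatten(x)));          // [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        long[][] edgeIndex = { {0, 1}, {1, 2} };
        System.out.println(Arrays.toString(flatten(edgeIndex)));  // [0, 1, 1, 2]
    }
}
```

The caller would then pair each buffer with its known shape, for example new Shape(N, 10) for node features. It should not infer that shape from the array, because an empty graph gives no row to measure.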
Constructor Summary
Constructors
ModelTrackFindingGNN()
Method Summary
Modifier and Type | Method | Description
float[] | predictEdgeScores(float[][] nodeFeatures, long[][] edgeIndex, float[][] edgeAttr) | Score every edge in the input graph.
Constructor Details
ModelTrackFindingGNN
public ModelTrackFindingGNN()
Method Details
predictEdgeScores
public float[] predictEdgeScores(float[][] nodeFeatures, long[][] edgeIndex, float[][] edgeAttr) throws Exception
Score every edge in the input graph.
Parameters:
nodeFeatures - shape [N, 10]; see GNNConstants.NODE_FEAT_DIM
edgeIndex - shape [2, E]; int64 source / destination node ids
edgeAttr - shape [E, 9]; see GNNConstants.EDGE_FEAT_DIM
Returns:
float[E] of sigmoid edge scores in [0, 1]
Throws:
Exception
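predictEdgeScores returns one score per edge. The sketch below assumes score e belongs to column e of edgeIndex, which is consistent with the [E] output shape. A downstream step might keep only edges whose score clears a cut. The 0.5 threshold, the stand-in scores, and the EdgeSelection helper are illustrative assumptions, not values or code taken from this model.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical post-processing of predictEdgeScores() output. */
class EdgeSelection {
    /** Returns the edge indices e (columns of edgeIndex) with scores[e] >= threshold. */
    static List<Integer> selectEdges(float[] scores, float threshold) {
        List<Integer> kept = new ArrayList<>();
        for (int e = 0; e < scores.length; e++) {
            if (scores[e] >= threshold) kept.add(e);
        }
        return kept;
    }

    public static void main(String[] args) {
        long[][] edgeIndex = { {0, 1, 0}, {1, 2, 2} };  // edges 0->1, 1->2, 0->2
        float[] scores = { 0.91f, 0.75f, 0.12f };        // stand-in for predictEdgeScores(...)
        for (int e : selectEdges(scores, 0.5f)) {
            // Assumed ordering: score e belongs to column e of edgeIndex.
            System.out.println(edgeIndex[0][e] + " -> " + edgeIndex[1][e] + " (" + scores[e] + ")");
        }
    }
}
```

Since the scores come out of a sigmoid, they lie in [0, 1]. A fixed cut is the simplest choice, but the working point would normally be tuned on labelled events.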