Class ModelTrackFindingGNN

java.lang.Object
org.jlab.rec.alert.AI.ModelTrackFindingGNN

public class ModelTrackFindingGNN extends Object
DJL wrapper around the GravNet TorchScript model exported from track-finding/export_torchscript.py. Runs per-event edge scoring.

Exported forward signature (see SingleGraphEdgeScorer):

  forward(x: float32[N, 10], edge_index: int64[2, E], edge_attr: float32[E, 9]) -> float32[E]

The returned tensor holds sigmoid edge scores in [0, 1].

DJL (Deep Java Library) is the inference engine used here. See the DJL documentation for the Criteria, ZooModel, Predictor, and NDArray APIs that this class relies on.

  • Constructor Details

    • ModelTrackFindingGNN

      public ModelTrackFindingGNN()
  • Method Details

    • predictEdgeScores

      public float[] predictEdgeScores(float[][] nodeFeatures, long[][] edgeIndex, float[][] edgeAttr) throws Exception
      Score every edge in the input graph.
      Parameters:
      nodeFeatures - shape [N, 10] — see GNNConstants.NODE_FEAT_DIM
      edgeIndex - shape [2, E] — int64 source / destination node ids
      edgeAttr - shape [E, 9] — see GNNConstants.EDGE_FEAT_DIM
      Returns:
      float[E] of sigmoid edge scores in [0, 1]
      Throws:
      Exception
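
The shape contract above can be checked on the caller side before the arrays reach the model. The following is a minimal sketch, not part of this class: GraphInputCheck and its validate method are hypothetical names, the two constants are assumed to mirror GNNConstants.NODE_FEAT_DIM and GNNConstants.EDGE_FEAT_DIM, and row 0 of edgeIndex is assumed to hold source ids with row 1 holding destination ids.

```java
/** Hypothetical caller-side helper; not part of org.jlab.rec.alert.AI. */
final class GraphInputCheck {
    // Assumed to mirror GNNConstants.NODE_FEAT_DIM and GNNConstants.EDGE_FEAT_DIM.
    static final int NODE_FEAT_DIM = 10;
    static final int EDGE_FEAT_DIM = 9;

    /** Checks the documented shapes and returns the edge count E. */
    static int validate(float[][] nodeFeatures, long[][] edgeIndex, float[][] edgeAttr) {
        // nodeFeatures must be [N, 10].
        for (float[] row : nodeFeatures) {
            if (row.length != NODE_FEAT_DIM) {
                throw new IllegalArgumentException("nodeFeatures rows must have length " + NODE_FEAT_DIM);
            }
        }
        // edgeIndex must be [2, E]: one row of sources, one row of destinations (assumed order).
        if (edgeIndex.length != 2 || edgeIndex[0].length != edgeIndex[1].length) {
            throw new IllegalArgumentException("edgeIndex must have shape [2, E]");
        }
        int numNodes = nodeFeatures.length;
        int numEdges = edgeIndex[0].length;
        // edgeAttr must be [E, 9], with one row per edge.
        if (edgeAttr.length != numEdges) {
            throw new IllegalArgumentException("edgeAttr must have one row per edge");
        }
        for (float[] row : edgeAttr) {
            if (row.length != EDGE_FEAT_DIM) {
                throw new IllegalArgumentException("edgeAttr rows must have length " + EDGE_FEAT_DIM);
            }
        }
        // Every node id must refer to an existing row of nodeFeatures.
        for (long[] endpoints : edgeIndex) {
            for (long id : endpoints) {
                if (id < 0 || id >= numNodes) {
                    throw new IllegalArgumentException("node id out of range: " + id);
                }
            }
        }
        return numEdges;
    }

    public static void main(String[] args) {
        // Toy graph: 3 nodes, 2 edges (0 -> 1 and 1 -> 2), all features zero.
        float[][] nodeFeatures = new float[3][NODE_FEAT_DIM];
        long[][] edgeIndex = { {0, 1}, {1, 2} };
        float[][] edgeAttr = new float[2][EDGE_FEAT_DIM];

        int numEdges = validate(nodeFeatures, edgeIndex, edgeAttr);
        System.out.println("edges: " + numEdges);

        // With the model available, the validated arrays would then be scored:
        // float[] scores = new ModelTrackFindingGNN()
        //         .predictEdgeScores(nodeFeatures, edgeIndex, edgeAttr);
        // scores.length is documented to equal numEdges.
    }
}
```

Checking shapes up front turns a mismatch into an IllegalArgumentException that names the offending argument, rather than whatever error the inference engine reports.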