|
| 1 | +package org.jlab.rec.ai.dcHBTrackState; |
| 2 | + |
| 3 | +import ai.djl.MalformedModelException; |
| 4 | +import ai.djl.ModelException; |
| 5 | +import ai.djl.inference.Predictor; |
| 6 | +import ai.djl.ndarray.*; |
| 7 | +import ai.djl.ndarray.types.Shape; |
| 8 | +import ai.djl.repository.zoo.*; |
| 9 | +import ai.djl.training.util.ProgressBar; |
| 10 | +import ai.djl.translate.*; |
| 11 | + |
| 12 | +import java.io.IOException; |
| 13 | +import java.nio.file.Paths; |
| 14 | +import java.util.concurrent.*; |
| 15 | +import java.util.logging.Level; |
| 16 | +import java.util.logging.Logger; |
| 17 | + |
| 18 | +import org.jlab.clas.reco.ReconstructionEngine; |
| 19 | +import org.jlab.io.base.*; |
| 20 | +import org.jlab.utils.system.ClasUtilsFile; |
| 21 | +import org.jlab.service.ai.PredictorPool; |
| 22 | + |
| 23 | + |
| 24 | +public class HBTrackStateEstimator{ |
| 25 | + // ---------------- Configuration ---------------- |
| 26 | + private String modelFile; |
| 27 | + |
| 28 | + ZooModel<float[][], float[]> model; |
| 29 | + PredictorPool predictors; |
| 30 | + |
| 31 | + // ---------------- Statistics for normalization of inputs and outputs of training samples ---------------- |
| 32 | + //// Note: Statistics of hits and track states depends on training samples, so need to be renewed when training samples change!!! |
| 33 | + // Statistics of hits: doca, xm, xr, yr, z |
| 34 | + private float[] HIT_MEAN; |
| 35 | + private float[] HIT_STD; |
| 36 | + |
| 37 | + // Statistics of track state: x, y, tx, ty, Q at z = 229 cm in the tilted sector frame |
| 38 | + private float[] STATE_MEAN; |
| 39 | + private float[] STATE_STD; |
| 40 | + |
| 41 | + public HBTrackStateEstimator(String modelFile){ |
| 42 | + this.modelFile = modelFile; |
| 43 | + |
| 44 | + if(modelFile.contains("inbending")){ |
| 45 | + HIT_MEAN = new float[]{0.52949071f, -45.771999f, -45.744694f, 57.336819f, 373.046356f}; |
| 46 | + HIT_STD = new float[]{0.40272677f, 47.928203f, 48.379021f, 32.645191f, 111.54994f}; |
| 47 | + STATE_MEAN = new float[]{-33.564308f, 0.010787425f, -0.15567796f, 0.0017755219f, 0.317530721f}; |
| 48 | + STATE_STD = new float[]{28.667490f, 17.761129f, 0.11940812f, 0.074460238f, 0.74185127f}; |
| 49 | + } |
| 50 | + else if(modelFile.contains("outbending")){ |
| 51 | + HIT_MEAN = new float[]{0.53385729f, -59.236504f, -59.200584f, 50.136387f, 372.057922f}; |
| 52 | + HIT_STD = new float[]{0.40085429f, 51.385536f, 51.840462f, 31.498201f, 111.50029f}; |
| 53 | + STATE_MEAN = new float[]{-39.446106f, 0.17583229f, -0.18047817f, 0.0014163271f, -0.082320645f}; |
| 54 | + STATE_STD = new float[]{33.733425f, 17.226780f, 0.14071095f, 0.072449364f, 0.72273886f}; |
| 55 | + } |
| 56 | + else{ |
| 57 | + Logger.getLogger(getClass().getName()).log(Level.SEVERE, "Name of model file does not include inbending or outbending"); |
| 58 | + } |
| 59 | + |
| 60 | + System.setProperty("ai.djl.pytorch.num_interop_threads", "1"); |
| 61 | + System.setProperty("ai.djl.pytorch.num_threads", "1"); |
| 62 | + System.setProperty("ai.djl.pytorch.graph_optimizer", "false"); |
| 63 | + try { |
| 64 | + String modelPath = ClasUtilsFile.getResourceDir("CLAS12DIR", "etc/data/nnet/hbTSE/" + modelFile); |
| 65 | + |
| 66 | + Criteria<float[][], float[]> criteria = Criteria.builder() |
| 67 | + .setTypes(float[][].class, float[].class) |
| 68 | + .optModelPath(Paths.get(modelPath)) |
| 69 | + .optEngine("PyTorch") |
| 70 | + .optTranslator(getTranslator()) |
| 71 | + .optProgress(new ProgressBar()) |
| 72 | + .build(); |
| 73 | + |
| 74 | + model = criteria.loadModel(); |
| 75 | + |
| 76 | + int threads = 64; |
| 77 | + predictors = new PredictorPool(threads, model); |
| 78 | + |
| 79 | + |
| 80 | + } catch (IOException | ModelException e) { |
| 81 | + Logger.getLogger(getClass().getName()).log(Level.SEVERE, null, e); |
| 82 | + } |
| 83 | + } |
| 84 | + |
| 85 | + // ---------------- Translator ---------------- |
| 86 | + private Translator<float[][], float[]> getTranslator() { |
| 87 | + return new Translator<float[][], float[]>() { |
| 88 | + |
| 89 | + @Override |
| 90 | + public NDList processInput(TranslatorContext ctx, float[][] hits) { |
| 91 | + NDManager manager = ctx.getNDManager(); |
| 92 | + int n = hits.length; |
| 93 | + |
| 94 | + float[][] norm = new float[n][5]; |
| 95 | + for (int i = 0; i < n; i++) |
| 96 | + for (int j = 0; j < 5; j++) |
| 97 | + norm[i][j] = (hits[i][j] - HIT_MEAN[j]) / HIT_STD[j]; |
| 98 | + |
| 99 | + NDArray x = manager.create(norm); |
| 100 | + x = x.reshape(1, n, 5); |
| 101 | + return new NDList(x); |
| 102 | + } |
| 103 | + |
| 104 | + @Override |
| 105 | + public float[] processOutput(TranslatorContext ctx, NDList list) { |
| 106 | + NDArray out = list.get(0); // [1,5] |
| 107 | + float[] y = out.toFloatArray(); |
| 108 | + |
| 109 | + for (int i = 0; i < 5; i++) |
| 110 | + y[i] = y[i] * STATE_STD[i] + STATE_MEAN[i]; |
| 111 | + |
| 112 | + return y; |
| 113 | + } |
| 114 | + |
| 115 | + @Override |
| 116 | + public Batchifier getBatchifier() { |
| 117 | + return null; |
| 118 | + } |
| 119 | + }; |
| 120 | + } |
| 121 | + |
| 122 | + |
| 123 | + public float[] predict(float[][] hits) { |
| 124 | + if (hits == null) return null; |
| 125 | + |
| 126 | + if (hits.length == 0) { |
| 127 | + throw new IllegalArgumentException("HBInitialStateEstimator: empty hits"); |
| 128 | + } |
| 129 | + |
| 130 | + for (int i = 0; i < hits.length; i++) { |
| 131 | + if (hits[i].length != 5) { |
| 132 | + throw new IllegalArgumentException( |
| 133 | + "Expect 5 features per hit, got " + hits[i].length |
| 134 | + ); |
| 135 | + } |
| 136 | + } |
| 137 | + |
| 138 | + try { |
| 139 | + Predictor<float[][], float[]> predictor = predictors.take(); |
| 140 | + try { |
| 141 | + return predictor.predict(hits); |
| 142 | + } finally { |
| 143 | + predictors.put(predictor); |
| 144 | + } |
| 145 | + } catch (TranslateException | InterruptedException e) { |
| 146 | + throw new RuntimeException(e); |
| 147 | + } |
| 148 | + } |
| 149 | +} |
| 150 | + |
0 commit comments