diff --git a/etc/nnet/dn/cnn_autoenc_sector1_nBlocks2.pt b/etc/nnet/dn/cnn_autoenc_sector1_nBlocks2.pt new file mode 100644 index 0000000000..f7af492f18 Binary files /dev/null and b/etc/nnet/dn/cnn_autoenc_sector1_nBlocks2.pt differ diff --git a/reconstruction/ai/pom.xml b/reconstruction/ai/pom.xml new file mode 100644 index 0000000000..b8e10fcba8 --- /dev/null +++ b/reconstruction/ai/pom.xml @@ -0,0 +1,43 @@ + + + 4.0.0 + + org.jlab.clas12.detector + clas12detector-ai + 13.3.0-SNAPSHOT + jar + + + org.jlab.clas12 + reconstruction + 13.3.0-SNAPSHOT + + + + + + org.jlab.clas + clas-utils + 13.3.0-SNAPSHOT + + + + org.jlab.clas + clas-io + 13.3.0-SNAPSHOT + + + + org.jlab.clas + clas-reco + 13.3.0-SNAPSHOT + + + + ai.djl + api + + + + + diff --git a/reconstruction/ai/src/main/java/org/jlab/service/ai/DCDenoiseEngine.java b/reconstruction/ai/src/main/java/org/jlab/service/ai/DCDenoiseEngine.java new file mode 100644 index 0000000000..12560c7a95 --- /dev/null +++ b/reconstruction/ai/src/main/java/org/jlab/service/ai/DCDenoiseEngine.java @@ -0,0 +1,221 @@ +package org.jlab.service.ai; + +import ai.djl.MalformedModelException; +import java.nio.file.Paths; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.Shape; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.training.util.ProgressBar; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorContext; +import ai.djl.inference.Predictor; +import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.translate.Batchifier; +import ai.djl.translate.TranslateException; +import java.io.IOException; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; + +import org.jlab.clas.reco.ReconstructionEngine; +import org.jlab.io.base.DataBank; +import org.jlab.io.base.DataEvent; +import org.jlab.utils.system.ClasUtilsFile; + +public class DCDenoiseEngine extends ReconstructionEngine { + + final static String[] BANK_NAMES = {"DC::tot","DC::tdc"}; + final static String CONF_THRESHOLD = "threshold"; + final static int LAYERS = 36; + final static int WIRES = 112; + + float threshold = 0.025f; + Criteria criteria; + ZooModel model; + PredictorPool predictors; + + public static class PredictorPool { + final BlockingQueue pool; + public PredictorPool(int size, ZooModel model) { + pool = new LinkedBlockingQueue<>(size); + for (int i=0; i predictor = predictors.get(); + for (int sector=0; sector<6; sector++) { + float[][] input = DCDenoiseEngine.read(bank, sector+1); + float[][] output = predictor.predict(input); + //System.out.println("IN:");show(input); + //System.out.println("OUT:");show(output); + update(bank, threshold, output, sector); + } + predictors.put(predictor); + event.removeBank(BANK_NAMES[i]); + event.appendBank(bank); + } + catch (TranslateException | InterruptedException e) { + throw new RuntimeException(e); + } + break; + } + } + return true; + } + + boolean processFakeEvent() { + try { + Predictor predictor = model.newPredictor(); + float[][] input = getAlmostStraightSlightlyBendingTrack(); + float[][] output = predictor.predict(input); + //System.out.println("IN:");show(input); + //System.out.println("OUT:");show(output); + } + catch (TranslateException e) { + throw new RuntimeException(e); + } + return true; + } + + /** + * Reject sub-threshold hits by modifying the bank's order variable. + * WARNING: This is not a full implementation of OrderType enum and + * all its names, but for now a copy of the subset in C++ DC denoising, see: + * https://code.jlab.org/hallb/clas12/coatjava/denoising/-/blob/main/denoising/code/drift.cc?ref_type=heads#L162-198 + */ + static void update(DataBank b, float threshold, float[][] data, int sector) { + //System.out.println("IN:");b.show(); + for (int row=0; row getTranslator() { + return new Translator() { + @Override + public NDList processInput(TranslatorContext ctx, float[][] input) throws Exception { + NDManager manager = ctx.getNDManager(); + int height = input.length; + int width = input[0].length; + float[] flat = new float[height * width]; + for (int i = 0; i < height; i++) { + System.arraycopy(input[i], 0, flat, i * width, width); + } + NDArray x = manager.create(flat, new Shape(height, width)); + // Add batch and channel dims -> [1,1,36,112] + x = x.expandDims(0).expandDims(0); + return new NDList(x); + } + @Override + public float[][] processOutput(TranslatorContext ctx, NDList list) throws Exception { + NDArray result = list.get(0); + // Remove batch and channel dims -> [36,112] + result = result.squeeze(); + // Convert to 1D float array + float[] flat = result.toFloatArray(); + // Reshape manually into 2D array + long[] shape = result.getShape().getShape(); + int height = (int) shape[0]; + int width = (int) shape[1]; + float[][] output2d = new float[height][width]; + for (int i = 0; i < height; i++) { + System.arraycopy(flat, i * width, output2d[i], 0, width); + } + return output2d; + } + @Override + public Batchifier getBatchifier() { + return null; // no batching + } + }; + } + +} diff --git a/reconstruction/pom.xml b/reconstruction/pom.xml index 2d77b0d614..f62500c9d9 100644 --- a/reconstruction/pom.xml +++ b/reconstruction/pom.xml @@ -14,6 +14,7 @@ + ai dc tof cvt