@agibsonccc Thank you for your response.
So my program is intended for users who don’t write code. Here is a snippet.
All they should have to do is select a model from the Load Model dropdown. That dropdown is populated from a models folder in the root of the program directory, and the selected model then feeds into an inference node, which should spit out the predictions.
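Here is a minimal, dependency-free sketch of how such a dropdown could be filled. It assumes (hypothetically) that each model is saved as one `.zip` file directly inside the models folder, and that the file name without the extension is the label shown to the user. `ModelFolder` and `listModelNames` are illustrative names, not part of any library:

```java
import java.io.IOException;
import java.nio.file.DirectoryStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

class ModelFolder {

    /**
     * Returns dropdown labels for the saved models in a folder, e.g. "Organoids"
     * for "Organoids.zip". Assumes every model is a single .zip file in that folder.
     */
    static List<String> listModelNames(Path modelsDir) throws IOException {
        List<String> names = new ArrayList<>();
        if (!Files.isDirectory(modelsDir)) {
            return names; // no models folder yet: show an empty dropdown
        }
        // The glob only matches entries whose names end in ".zip"
        try (DirectoryStream<Path> stream = Files.newDirectoryStream(modelsDir, "*.zip")) {
            for (Path p : stream) {
                String file = p.getFileName().toString();
                names.add(file.substring(0, file.length() - ".zip".length()));
            }
        }
        Collections.sort(names); // stable alphabetical order for the UI
        return names;
    }
}
```

The chosen label maps back to `modelsDir.resolve(label + ".zip").toFile()`, which could then be loaded with DL4J's `ModelSerializer.restoreComputationGraph(File)`.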
These models are built and trained in the program.
This all currently works fine, except that I have hard-coded the model input width, height and channels.
```java
import java.awt.image.BufferedImage;
import java.io.IOException;

import ij.IJ;
import ij.ImagePlus;
import ij.process.ImageProcessor;
import net.imagej.ImgPlus;
import net.imagej.axis.Axes;
import net.imagej.axis.AxisType;
import net.imglib2.img.Img;
import net.imglib2.img.display.imagej.ImageJFunctions;
import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
// ImageType and the log field come from elsewhere in my own code.

// ... inside the enclosing class:
public static ImgPlus[] inferImg(ComputationGraph model, ImgPlus image) throws IOException {
    log.info("Call to infer image " + image.getName());
    ImagePlus imp = ImageJFunctions.wrap(image, image.getName());
    System.out.println("wrapping Image");
    int imageWidth = (int) image.dimension(0);
    int imageHeight = (int) image.dimension(1);
    // Hard-coded for now; these should really come from the selected model
    int modelWidth = 512;
    int modelHeight = 512;
    long modelChannels = 1;
    int classes = 3;
    if (modelWidth >= imageWidth && modelHeight >= imageHeight) {
        System.out.println("Standard Inference");
        NativeImageLoader loader = new NativeImageLoader(modelHeight, modelWidth, modelChannels);
        ImageType imgType = new ImageType();
        BufferedImage bufferedBGR = imgType.getBGRBufferedImage(imp.getProcessor().getBufferedImage());
        INDArray imageNative = loader.asMatrix(bufferedBGR);
        System.out.println("Loading Image");
        // [batch, channels, height, width], scaled from 0..255 to 0..1
        imageNative = imageNative.reshape(1, modelChannels, modelHeight, modelWidth);
        imageNative = imageNative.divi(255f);
        System.out.println("About to call Model Output.... ");
        // Run inference once; any exception propagates to the caller
        INDArray[] output = model.output(imageNative);
        System.out.println("Getting Model Output.... ");
        ImagePlus imp2 = IJ.createImage(image.getName(), "32-bit", modelWidth, modelHeight, classes, 1, 1);
        // Only the first output is used: one probability plane per class, [class, row (y), column (x)]
        INDArray out = output[0].reshape(classes, modelHeight, modelWidth);
        for (int c = 0; c < classes; c++) {
            ImageProcessor ip = imp2.getStack().getProcessor(c + 1);
            for (int i = 0; i < modelHeight; i++) {      // rows (y)
                for (int j = 0; j < modelWidth; j++) {   // columns (x)
                    float f = out.getFloat(new int[] {c, i, j});
                    ip.putPixelValue(j, i, f * 255);
                }
            }
        }
        Img probImg = ImageJFunctions.wrap(imp2);
        AxisType[] axisTypes = new AxisType[] {Axes.X, Axes.Y, Axes.CHANNEL};
        ImgPlus probabilityImgPlus = new ImgPlus(probImg, image.getName(), axisTypes);
        System.out.println("About to return images");
        return new ImgPlus[] {probabilityImgPlus};
    } else {
        // <Code here just handles breaking the input image up into chunks to process and stitch back together.>
    }
}
```
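In the snippet, `out.getFloat(c, i, j)` indexes `[class, row, column]`, so the row index has to run over the height and the column index over the width. With 512 x 512 this makes no difference, but for non-square models swapped bounds would transpose or overrun the planes. The small sketch below reproduces ND4J's default `'c'` (row-major) layout for a `[classes, height, width]` array with plain Java arrays. `ChwIndex` is a hypothetical helper for illustration only:

```java
class ChwIndex {

    /** Flat offset of element (c, y, x) in a row-major classes x height x width buffer. */
    static int offset(int c, int y, int x, int height, int width) {
        return (c * height + y) * width + x;
    }

    /** Copies one class plane into a [height][width] grid: first index is the row (y). */
    static float[][] plane(float[] chw, int c, int height, int width) {
        float[][] px = new float[height][width];
        for (int y = 0; y < height; y++) {        // rows run over the height
            for (int x = 0; x < width; x++) {     // columns run over the width
                px[y][x] = chw[offset(c, y, x, height, width)];
            }
        }
        return px;
    }
}
```

For a 2 x 3 (height x width) model with two classes, the second class plane starts at flat offset 6, and its bottom-right pixel is element 11.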
At the moment I am just typing the modelWidth, modelHeight and modelChannels values in while in debug mode. I was hoping there was a way of getting them from the ComputationGraph object, so I don’t need to pass them in during the method call. The users of this program might not know anything about the model, just that they have to pick a model for “Organoids” or “Phase Segmentation”. I am also hoping the builder part of the program will let non-coding users build new models by dragging in layers, so the range of models this needs to run inference on could be quite large, and not known at compile time.
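One option is sketched below under stated assumptions. Because the models are built and saved inside the program, the builder could write the input shape next to each model when it saves it, and the inference node could read it back. DL4J's `ComputationGraphConfiguration` may already record input types via `getNetworkInputTypes()` when the graph was built with `setInputTypes(...)`, but that information can be missing for graphs built without it, and fully convolutional graphs may not have a fixed width and height at all. The sidecar `.properties` file, the `ModelInputSpec` class and the `input.*`/`output.*` key names are all hypothetical, not part of DL4J:

```java
import java.io.IOException;
import java.io.Reader;
import java.io.Writer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Properties;

/** Hypothetical description of what a saved model expects as input and produces. */
class ModelInputSpec {
    final int width;
    final int height;
    final long channels;
    final int classes;

    ModelInputSpec(int width, int height, long channels, int classes) {
        this.width = width;
        this.height = height;
        this.channels = channels;
        this.classes = classes;
    }

    /** Hypothetical convention: "Organoids.zip" is described by "Organoids.properties". */
    static Path sidecarFor(Path modelZip) {
        String name = modelZip.getFileName().toString();
        int dot = name.lastIndexOf('.');
        String base = dot >= 0 ? name.substring(0, dot) : name;
        return modelZip.resolveSibling(base + ".properties");
    }

    /** Would be called by the builder right after it saves the model. */
    void save(Path modelZip) throws IOException {
        Properties p = new Properties();
        p.setProperty("input.width", Integer.toString(width));
        p.setProperty("input.height", Integer.toString(height));
        p.setProperty("input.channels", Long.toString(channels));
        p.setProperty("output.classes", Integer.toString(classes));
        try (Writer w = Files.newBufferedWriter(sidecarFor(modelZip))) {
            p.store(w, "Input shape for " + modelZip.getFileName());
        }
    }

    /** Would be called by the inference node instead of hard-coding 512 x 512 x 1. */
    static ModelInputSpec load(Path modelZip) throws IOException {
        Properties p = new Properties();
        try (Reader r = Files.newBufferedReader(sidecarFor(modelZip))) {
            p.load(r);
        }
        // parseInt/parseLong throw NumberFormatException if a key is missing
        return new ModelInputSpec(
                Integer.parseInt(p.getProperty("input.width")),
                Integer.parseInt(p.getProperty("input.height")),
                Long.parseLong(p.getProperty("input.channels")),
                Integer.parseInt(p.getProperty("output.classes")));
    }
}
```

In `inferImg`, the four hard-coded values would then come from `ModelInputSpec.load(selectedModelZip)`, where `selectedModelZip` (hypothetical) is the file picked in the dropdown.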