Hi all, I am trying to build a workflow to train a UNET model for image segmentation. I have looked at the different DataSetIterator and ImageRecordReader classes in dl4j.examples and on other online forums, but I am still confused about how to correctly load the training data and labels.
I have two folders. imagesPath contains the training data, which is 100 3D images. labelPath contains the labels, which are also 100 3D volumes. The images and labels are paired by file name, e.g. image1.bin with label1.bin and image2.bin with label2.bin. Every image and every label has dimensions 64x64x64.
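To make the pairing concrete, here is a minimal plain-Java sketch of how I expect each image to be matched to its label by file name. It uses no DL4J calls. The class `VolumePairing` and its methods are hypothetical names made up for illustration, and the sketch assumes the imageN.bin / labelN.bin naming described above.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical helper, not part of DL4J: pairs imageN.bin with labelN.bin by file name.
final class VolumePairing {

    // Maps "image12.bin" to "label12.bin"; assumes the naming scheme from the post.
    static String labelNameFor(String imageName) {
        if (!imageName.startsWith("image")) {
            throw new IllegalArgumentException("Unexpected image file name: " + imageName);
        }
        return "label" + imageName.substring("image".length());
    }

    // Returns image -> label pairs sorted by image name; fails if a label is missing.
    static Map<String, String> pair(List<String> imageNames, List<String> labelNames) {
        Map<String, String> pairs = new TreeMap<>();
        for (String image : imageNames) {
            String label = labelNameFor(image);
            if (!labelNames.contains(label)) {
                throw new IllegalStateException("No label found for " + image);
            }
            pairs.put(image, label);
        }
        return pairs;
    }

    public static void main(String[] args) {
        List<String> images = Arrays.asList("image1.bin", "image2.bin");
        List<String> labels = Arrays.asList("label2.bin", "label1.bin");
        pair(images, labels).forEach((img, lbl) -> System.out.println(img + " -> " + lbl));
    }
}
```

The point of the sketch is that the label for each sample is a whole volume found through the image's file name, not a class name taken from the parent directory.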
LOG.info("Loading training data & label and creating DataSetIterators");
int seed = 1234;
Random randomNumber = new Random(seed);
//1. Is this the correct way to load the 3D labels?
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
labelMaker.getLabelForPath(labelPath);
ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, labelMaker);
File imageDir = new File(imagesPath);
FileSplit trainingData = new FileSplit(imageDir);
try {
    recordReader.initialize(trainingData);
} catch (IOException e1) {
    // TODO Auto-generated catch block
    e1.printStackTrace();
}
//2. Normalization
DataNormalization dataNormalization = new ImagePreProcessingScaler(0, 1);
DataSetIterator dataSetIterator = new RecordReaderDataSetIterator(recordReader, batchSize, 1, 1, true);
dataNormalization.fit(dataSetIterator);
dataSetIterator.setPreProcessor(dataNormalization);
//3. Build model
ComputationGraph model = createTrainingModel(); // "createTrainingModel" creates a UNET model.
model.init();
model.setListeners(new ScoreIterationListener(10));
//4. Start training
LOG.info("Starting to train the machine learning model...");
model.fit(dataSetIterator, epochs);
dataSetIterator.reset();
//5. Save model to path
try {
    boolean saveUpdater = true;
    ModelSerializer.writeModel(model, modelPath, saveUpdater);
} catch (IOException e) {
    // TODO Auto-generated catch block
    e.printStackTrace();
}
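For reference, this is what I expect the normalization in step 2 to do to a single volume, written as a plain-Java sketch without DL4J. It rests on assumptions about the file format: each .bin file is taken to hold 64*64*64 voxels stored as one unsigned byte each, with x varying fastest. The class `RawVolumeSketch` and its methods are hypothetical names for illustration only.

```java
// Hypothetical helper, not part of DL4J.
// ASSUMPTION: a .bin file holds 64*64*64 voxels, one unsigned byte per voxel,
// stored with x varying fastest, then y, then z.
final class RawVolumeSketch {
    static final int SIZE = 64;
    static final int VOXELS = SIZE * SIZE * SIZE;

    // Scales each voxel from the 0..255 range to the 0..1 range.
    static float[] scale(byte[] raw) {
        if (raw.length != VOXELS) {
            throw new IllegalArgumentException(
                "Expected " + VOXELS + " bytes but got " + raw.length);
        }
        float[] out = new float[VOXELS];
        for (int i = 0; i < VOXELS; i++) {
            out[i] = (raw[i] & 0xFF) / 255.0f; // & 0xFF reads the byte as unsigned
        }
        return out;
    }

    // Flat array index of voxel (x, y, z) under the assumed x-fastest layout.
    static int index(int x, int y, int z) {
        return (z * SIZE + y) * SIZE + x;
    }

    public static void main(String[] args) {
        // Synthetic volume; a real file could be loaded with
        // java.nio.file.Files.readAllBytes(...) instead.
        byte[] fake = new byte[VOXELS];
        fake[index(1, 2, 3)] = (byte) 255;
        float[] scaled = scale(fake);
        System.out.println(scaled[index(1, 2, 3)]);
    }
}
```

If the files actually store floats or 16-bit values, the byte-to-float conversion above would need to change. For a segmentation label volume, the scaling would presumably be skipped so that the class values stay intact.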