Load 3D image and 3D label for image classification

Hi all, I am trying to build a workflow to train a UNET model for image segmentation. I have read about the different DataSetIterator and ImageRecordReader classes in dl4j.examples and on other online forums, but I am still confused about how to correctly load the training data and labels.

I have two folders: imagesPath contains the training data (100 3D images), and labelPath contains the labels (100 3D label volumes). Each image is paired with a label by file name, e.g. image1.bin with label1.bin, image2.bin with label2.bin. Both images and labels are 64x64x64.
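To make the pairing concrete, here is a minimal sketch in plain Java of how the files are meant to match up by name. The class name `PairByName` and its helpers are hypothetical, and the `image`/`label` prefixes are an assumption taken from the example names above; nothing here is DL4J API.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

public class PairByName {

    // Hypothetical helper: maps "image12.bin" to "label12.bin".
    public static String labelNameFor(String imageName) {
        if (!imageName.startsWith("image")) {
            throw new IllegalArgumentException("Unexpected image file name: " + imageName);
        }
        return "label" + imageName.substring("image".length());
    }

    // Pairs every image name with its label name and fails fast if a label is missing.
    public static Map<String, String> pair(List<String> imageNames, Set<String> labelNames) {
        Map<String, String> pairs = new TreeMap<>();
        for (String image : imageNames) {
            String label = labelNameFor(image);
            if (!labelNames.contains(label)) {
                throw new IllegalStateException("No label file found for " + image);
            }
            pairs.put(image, label);
        }
        return pairs;
    }

    public static void main(String[] args) {
        List<String> images = Arrays.asList("image1.bin", "image2.bin");
        Set<String> labels = new HashSet<>(Arrays.asList("label1.bin", "label2.bin"));
        // prints {image1.bin=label1.bin, image2.bin=label2.bin}
        System.out.println(pair(images, labels));
    }
}
```

In a real run the two name collections would come from listing imagesPath and labelPath.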

LOG.info("Loading training data & label and creating DataSetIterators");
int seed = 1234;
Random randomNumber = new Random(seed);

//1. Is it correct to load 3D label?
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
labelMaker.getLabelForPath(LabelPath);
ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, labelMaker);

File imageDir = new File(imagesPath);
FileSplit trainingData = new FileSplit(imageDir);
try {
    recordReader.initialize(trainingData);
} catch (IOException e1) {
    // TODO Auto-generated catch block
    e1.printStackTrace();
}

//2. Normalization
DataNormalization dataNormalization = new ImagePreProcessingScaler(0, 1);
DataSetIterator dataSetIterator = new RecordReaderDataSetIterator(recordReader, batchSize, 1, 1, true);
dataNormalization.fit(dataSetIterator);
dataSetIterator.setPreProcessor(dataNormalization);

//3. Build model
ComputationGraph model = createTrainingModel();    // "createTrainingModel" creates a UNET model.
model.init();
model.setListeners(new ScoreIterationListener(10));

//4. Start training
LOG.info("Start to train the machine learning model...");
model.fit(dataSetIterator, epochs);
dataSetIterator.reset();

//5. Save model to path
try {
    boolean saveUpdater = true;
    ModelSerializer.writeModel(model, modelPath, saveUpdater);
} catch (IOException e) {
    // TODO Auto-generated catch block
    e.printStackTrace();
}

Using ParentPathLabelGenerator means that you’ll get labels based on the path of the file.
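As a rough illustration of that idea, the sketch below derives a label from the name of a file's parent directory using only `java.nio`. It is not DL4J's implementation; the class `ParentDirLabel` and the example paths are hypothetical. The point is that every file inside a single imagesPath folder would end up with the same label string, not with a 3D label volume.

```java
import java.nio.file.Path;
import java.nio.file.Paths;

public class ParentDirLabel {

    // Returns the name of the directory that directly contains the file.
    public static String labelFromParent(String filePath) {
        Path parent = Paths.get(filePath).getParent();
        if (parent == null || parent.getFileName() == null) {
            throw new IllegalArgumentException("No named parent directory: " + filePath);
        }
        return parent.getFileName().toString();
    }

    public static void main(String[] args) {
        // A classification-style layout: one folder per class.
        System.out.println(labelFromParent("data/cat/img001.png"));   // prints cat
        // The layout from the question: all files share one folder, so one label.
        System.out.println(labelFromParent("imagesPath/image1.bin")); // prints imagesPath
        System.out.println(labelFromParent("imagesPath/image2.bin")); // prints imagesPath
    }
}
```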

I’m really not sure that it will read a 64-channel image correctly. That isn’t a case that comes up very often, so there may be issues.

The rest looks like it may work though.
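If the image reader turns out not to handle the 64x64x64 files, one fallback is to read the raw bytes yourself. The sketch below is plain Java under loudly stated assumptions: the .bin layout is unknown from the post, so it assumes one unsigned byte per voxel stored with x varying fastest, then y, then z. The class `RawVolume` and its methods are hypothetical, and the scaling to [0, 1] only mirrors the intent of the normalization step in the question.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;

public class RawVolume {

    // Converts raw bytes into a [depth][height][width] float volume scaled to [0, 1].
    // ASSUMPTION: one unsigned byte per voxel, x fastest, then y, then z.
    public static float[][][] toVolume(byte[] raw, int depth, int height, int width) {
        if (raw.length != depth * height * width) {
            throw new IllegalArgumentException(
                "Expected " + (depth * height * width) + " bytes but got " + raw.length);
        }
        float[][][] volume = new float[depth][height][width];
        for (int z = 0; z < depth; z++) {
            for (int y = 0; y < height; y++) {
                for (int x = 0; x < width; x++) {
                    int unsigned = raw[(z * height + y) * width + x] & 0xFF;
                    volume[z][y][x] = unsigned / 255.0f;
                }
            }
        }
        return volume;
    }

    // Reads a whole file and converts it with the same assumed layout.
    public static float[][][] readVolume(Path file, int depth, int height, int width)
            throws IOException {
        return toVolume(Files.readAllBytes(file), depth, height, width);
    }

    public static void main(String[] args) {
        // Synthetic 2x2x2 volume standing in for a real 64x64x64 file.
        byte[] raw = new byte[8];
        raw[7] = (byte) 255;
        float[][][] volume = toVolume(raw, 2, 2, 2);
        System.out.println(volume[0][0][0] + " " + volume[1][1][1]); // prints 0.0 1.0
    }
}
```

The same routine could be applied to both the image and the label file of each pair, after which the arrays would still need to be wrapped in whatever tensor type the model expects.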