Multi-Class Convolutional Neural Network for Segmentation

Sorry, I have been going around in circles and can’t find a suitable example.

I have previously used a UNet to segment single-channel images with a binary classification output.

I am now trying to adapt this to work with a 3-channel input and a multi-class output.

  1. Although there is a numClasses variable in the builder, it does not seem to be used anywhere in the model. I presume it was just left in from a copy-and-paste of another model.

  2. I figured that, since the output activation function would need to change from SIGMOID to SOFTMAX and the loss function from XENT to MCXENT, perhaps this is why the numClasses variable is unused: it would not be as simple as just changing that value.

  3. So I copied the UNet class into my own class and altered the last two layers from

.addLayer("conv10", new ConvolutionLayer.Builder(1,1).stride(1,1).nOut(1)
                        .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                        .activation(Activation.IDENTITY).build(), "conv9-3")
.addLayer("output", new CnnLossLayer.Builder(LossFunctions.LossFunction.XENT)
                        .activation(Activation.SIGMOID).build(), "conv10")

to

 .addLayer("conv10", new ConvolutionLayer.Builder(1,1).stride(1,1).nOut(numClasses)
                .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                .activation(Activation.IDENTITY).build(), "conv9-3")
.addLayer("output", new CnnLossLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX).build(), "conv10")

I am not sure whether this is correct, as I am failing at the next step, but I have included it in case I have messed something up there.
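As far as I can tell, the only other change needed for the 3-channel input is the input type on the graph builder. A minimal sketch of how I am setting it (the updater and learning rate are placeholders, and the layer list is elided):

ComputationGraphConfiguration.GraphBuilder graph = new NeuralNetConfiguration.Builder()
        .updater(new Adam(1e-3))        // placeholder updater
        .graphBuilder()
        .addInputs("input")
        // ... all of the UNet layers, ending with the "conv10" and "output" layers above ...
        .setOutputs("output")
        // 3-channel (RGB) input at 512 x 512
        .setInputTypes(InputType.convolutional(512, 512, 3));

ComputationGraph model = new ComputationGraph(graph.build());
model.init();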

Here is the model summary:

=======================================================================================================
VertexName (VertexType) nIn,nOut TotalParams ParamsShape Vertex Inputs

input (InputVertex) -,- - - -
conv1-1 (ConvolutionLayer) 3,64 1,792 W:{64,3,3,3}, b:{1,64} [input]
conv1-2 (ConvolutionLayer) 64,64 36,928 W:{64,64,3,3}, b:{1,64} [conv1-1]
pool1 (SubsamplingLayer) -,- 0 - [conv1-2]
conv2-1 (ConvolutionLayer) 64,128 73,856 W:{128,64,3,3}, b:{1,128} [pool1]
conv2-2 (ConvolutionLayer) 128,128 147,584 W:{128,128,3,3}, b:{1,128} [conv2-1]
pool2 (SubsamplingLayer) -,- 0 - [conv2-2]
conv3-1 (ConvolutionLayer) 128,256 295,168 W:{256,128,3,3}, b:{1,256} [pool2]
conv3-2 (ConvolutionLayer) 256,256 590,080 W:{256,256,3,3}, b:{1,256} [conv3-1]
pool3 (SubsamplingLayer) -,- 0 - [conv3-2]
conv4-1 (ConvolutionLayer) 256,512 1,180,160 W:{512,256,3,3}, b:{1,512} [pool3]
conv4-2 (ConvolutionLayer) 512,512 2,359,808 W:{512,512,3,3}, b:{1,512} [conv4-1]
drop4 (DropoutLayer) -,- 0 - [conv4-2]
pool4 (SubsamplingLayer) -,- 0 - [drop4]
conv5-1 (ConvolutionLayer) 512,1024 4,719,616 W:{1024,512,3,3}, b:{1,1024} [pool4]
conv5-2 (ConvolutionLayer) 1024,1024 9,438,208 W:{1024,1024,3,3}, b:{1,1024} [conv5-1]
drop5 (DropoutLayer) -,- 0 - [conv5-2]
up6-1 (Upsampling2D) -,- 0 - [drop5]
up6-2 (ConvolutionLayer) 1024,512 2,097,664 W:{512,1024,2,2}, b:{1,512} [up6-1]
merge6 (MergeVertex) -,- - - [drop4, up6-2]
conv6-1 (ConvolutionLayer) 1024,512 4,719,104 W:{512,1024,3,3}, b:{1,512} [merge6]
conv6-2 (ConvolutionLayer) 512,512 2,359,808 W:{512,512,3,3}, b:{1,512} [conv6-1]
up7-1 (Upsampling2D) -,- 0 - [conv6-2]
up7-2 (ConvolutionLayer) 512,256 524,544 W:{256,512,2,2}, b:{1,256} [up7-1]
merge7 (MergeVertex) -,- - - [conv3-2, up7-2]
conv7-1 (ConvolutionLayer) 512,256 1,179,904 W:{256,512,3,3}, b:{1,256} [merge7]
conv7-2 (ConvolutionLayer) 256,256 590,080 W:{256,256,3,3}, b:{1,256} [conv7-1]
up8-1 (Upsampling2D) -,- 0 - [conv7-2]
up8-2 (ConvolutionLayer) 256,128 131,200 W:{128,256,2,2}, b:{1,128} [up8-1]
merge8 (MergeVertex) -,- - - [conv2-2, up8-2]
conv8-1 (ConvolutionLayer) 256,128 295,040 W:{128,256,3,3}, b:{1,128} [merge8]
conv8-2 (ConvolutionLayer) 128,128 147,584 W:{128,128,3,3}, b:{1,128} [conv8-1]
up9-1 (Upsampling2D) -,- 0 - [conv8-2]
up9-2 (ConvolutionLayer) 128,64 32,832 W:{64,128,2,2}, b:{1,64} [up9-1]
merge9 (MergeVertex) -,- - - [conv1-2, up9-2]
conv9-1 (ConvolutionLayer) 128,64 73,792 W:{64,128,3,3}, b:{1,64} [merge9]
conv9-2 (ConvolutionLayer) 64,64 36,928 W:{64,64,3,3}, b:{1,64} [conv9-1]
conv9-3 (ConvolutionLayer) 64,2 1,154 W:{2,64,3,3}, b:{1,2} [conv9-2]
conv10 (ConvolutionLayer) 2,5 15 W:{5,2,1,1}, b:{1,5} [conv9-3]
output (CnnLossLayer) -,- 0 - [conv10]

        Total Parameters:  31,032,849
    Trainable Parameters:  31,032,849
       Frozen Parameters:  0

=======================================================================================================

I am struggling to understand how to load the labels for a multi-class dataset. My input data are RGB images, and the labels are 8-bit grayscale images in which each pixel value gives the class that pixel belongs to: 0, 1, 2, 3 or 4 (5 classes). I don’t know how to load these, but from reading I am fairly sure they need to be one-hot encoded, so in my PathLabelGenerator I open the label image and create a new NDArray, using the channel index to identify the output class each pixel belongs to:

private Writable oneHotEncode(INDArray in) {

    // "in" is the grayscale mask with shape [1, 1, height, width];
    // each pixel value is the class index (0 to labelNum - 1)
    System.out.println("Input shape: " + in.shapeInfoToString());

    float[][][][] oneHotEncoded = new float[1][labelNum][height][width];
    for (int w = 0; w < width; w++) {
        for (int h = 0; h < height; h++) {
            int value = (int) in.getFloat(new int[] {0, 0, h, w});
            oneHotEncoded[0][value][h][w] = 1f;
        }
    }

    // Result shape is [1, labelNum, height, width]
    INDArray out = Nd4j.create(oneHotEncoded);
    System.out.println("Output shape: " + out.shapeInfoToString());
    return new NDArrayWritable(out);
}
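For context, this method is called from my PathLabelGenerator implementation, which stripped down looks roughly like this (SegmentationLabelGenerator is my class name; NativeImageLoader loads the mask as a [1, 1, height, width] INDArray):

import java.io.File;
import java.io.IOException;
import java.net.URI;

import org.datavec.api.io.labels.PathLabelGenerator;
import org.datavec.api.writable.Writable;
import org.datavec.image.loader.NativeImageLoader;

public class SegmentationLabelGenerator implements PathLabelGenerator {

    private static final int height = 512, width = 512, labelNum = 5;

    // Loads the 8-bit grayscale mask as a single-channel [1, 1, height, width] array
    private final NativeImageLoader maskLoader = new NativeImageLoader(height, width, 1);

    @Override
    public Writable getLabelForPath(String path) {
        try {
            return oneHotEncode(maskLoader.asMatrix(new File(path)));
        } catch (IOException e) {
            throw new RuntimeException("Could not load label mask: " + path, e);
        }
    }

    @Override
    public Writable getLabelForPath(URI uri) {
        return getLabelForPath(new File(uri).getAbsolutePath());
    }

    @Override
    public boolean inferLabelClasses() {
        // Labels are NDArrayWritable, not class-index strings, so nothing to infer
        return false;
    }

    // ... oneHotEncode(INDArray) as shown above ...
}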

I am guessing this is not the correct way, as I am getting the following error:

Exception in thread "Thread-7" java.lang.IllegalStateException: Input and label arrays do not have same shape: [4, 5, 512, 512] vs. [4, 1, 512, 512]
	at org.nd4j.common.base.Preconditions.throwStateEx(Preconditions.java:638)
	at org.nd4j.common.base.Preconditions.checkState(Preconditions.java:337)
	at org.deeplearning4j.nn.layers.convolution.CnnLossLayer.backpropGradient(CnnLossLayer.java:67)
	at org.deeplearning4j.nn.graph.vertex.impl.LayerVertex.doBackward(LayerVertex.java:148)
	at org.deeplearning4j.nn.graph.ComputationGraph.calcBackpropGradients(ComputationGraph.java:2772)
	at org.deeplearning4j.nn.graph.ComputationGraph.computeGradientAndScore(ComputationGraph.java:1381)
	at org.deeplearning4j.nn.graph.ComputationGraph.computeGradientAndScore(ComputationGraph.java:1341)
	at org.deeplearning4j.optimize.solvers.BaseOptimizer.gradientAndScore(BaseOptimizer.java:174)
	at org.deeplearning4j.optimize.solvers.StochasticGradientDescent.optimize(StochasticGradientDescent.java:61)
	at org.deeplearning4j.optimize.Solver.optimize(Solver.java:52)
	at org.deeplearning4j.nn.graph.ComputationGraph.fitHelper(ComputationGraph.java:1165)
	at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1115)
	at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1082)
	at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1018)
	at org.cascade.ai.DeepLearningTrainerWorkspace.startTraining(DeepLearningTrainerWorkspace.java:247)
	at org.cascade.ai.DeepLearningTrainerWorkspace.lambda$3(DeepLearningTrainerWorkspace.java:157)
	at java.base/java.lang.Thread.run(Thread.java:833)
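Reading the shapes in the error, the network output is [minibatch, numClasses, height, width] = [4, 5, 512, 512], while the labels arrive as [4, 1, 512, 512], so it looks like my one-hot encoded array never actually reaches the network. To narrow this down I am checking the shapes straight out of the iterator with something like the following (trainIter is just a placeholder name for my DataSetIterator):

// Print the shapes the iterator produces before they reach fit();
// if the one-hot encoding were applied, labels should be [minibatch, 5, 512, 512]
DataSet ds = trainIter.next();
System.out.println("Features: " + Arrays.toString(ds.getFeatures().shape()));
System.out.println("Labels:   " + Arrays.toString(ds.getLabels().shape()));
trainIter.reset();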

If someone can point out what I am doing wrong, or point me to an example of a convolutional neural network with a multi-class CnnLossLayer output, then I can start trying to figure out how to put back all the hair I have just pulled out.