UNet predictions: poor performance

Hello all,
I am trying the UNet architecture in DL4J, but my predictions come out almost entirely white (with just one thin line of black pixels).
Here is how I load the data, normalize it and assign the ground truth:

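        // trainData, rng, height, width, channels and batchSize are defined elsewhere in the script (not shown here)
        File testData = new File("G:/ASPIJ/deep_learning/val_img");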
        LabelGenerator labelMakerTrain = new LabelGenerator("G:/ASPIJ/deep_learning/train_label");
        LabelGenerator labelMakerTest = new LabelGenerator("G:/ASPIJ/deep_learning/val_label");

        FileSplit train = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, rng);
        FileSplit test = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, rng);

        System.out.println("Checking values : h : "+height+" w: "+width+" c: "+channels);
        ImageRecordReader rrTrain = new ImageRecordReader(height, width, channels, labelMakerTrain);
        System.out.println("Train shape filesplit : "+train.length());
        rrTrain.initialize(train, null);

        ImageRecordReader rrTest = new ImageRecordReader(height, width, channels, labelMakerTest);
        rrTest.initialize(test, null);

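        // record layout: index 0 = image, index 1 = label from LabelGenerator (used as a regression target)
        int labelIndex = 1;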

        DataSetIterator dataTrainIter = new RecordReaderDataSetIterator(rrTrain, batchSize, labelIndex, labelIndex, true);
        DataSetIterator dataTestIter = new RecordReaderDataSetIterator(rrTest, 1, labelIndex, labelIndex, true);
        NormalizerMinMaxScaler scaler = new NormalizerMinMaxScaler(0, 1);
        //DataNormalization scaler = new ImagePreProcessingScaler();
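        scaler.fitLabel(true);                  // also rescale the labels
        scaler.fit(dataTrainIter);              // min/max statistics from the training set
        dataTrainIter.setPreProcessor(scaler);
        scaler.fit(dataTestIter);               // re-fits the same scaler object on the test set, replacing the training statistics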
        dataTestIter.setPreProcessor(scaler);
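
A minimal plain-Java sketch of the fit-once pattern for min-max scaling, for comparison with the code above: the minimum and maximum are learned from the training data only, the same statistics are applied to the test data, and revert() maps values back to the original pixel scale. MinMaxSketch is a hypothetical stand-in, not DL4J's NormalizerMinMaxScaler; it assumes a single global min/max, whereas the DL4J class may keep separate statistics per feature or channel.

```java
import java.util.Arrays;

public class MinMaxSketch {
    private double min;
    private double max;

    // Learn a single global min and max from the given data (hypothetical helper, not DL4J).
    public void fit(double[] data) {
        min = Arrays.stream(data).min().orElse(0.0);
        max = Arrays.stream(data).max().orElse(1.0);
    }

    // Scale values into [0, 1] with the statistics learned by fit().
    public double[] transform(double[] x) {
        double range = (max > min) ? (max - min) : 1.0; // avoid division by zero
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = (x[i] - min) / range;
        }
        return out;
    }

    // Undo transform(): map [0, 1] values back to the original scale.
    public double[] revert(double[] x) {
        double range = (max > min) ? (max - min) : 1.0;
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = x[i] * range + min;
        }
        return out;
    }

    public static void main(String[] args) {
        double[] trainPixels = {0, 64, 128, 255};
        double[] testPixels = {10, 200};

        MinMaxSketch scaler = new MinMaxSketch();
        scaler.fit(trainPixels); // fit once, on the training data only

        // The test data is transformed with the training statistics; it is not re-fitted.
        System.out.println(Arrays.toString(scaler.transform(testPixels)));

        // Re-fitting on the test data overwrites min/max for every user of this object.
        scaler.fit(testPixels);
        System.out.println(Arrays.toString(scaler.transform(testPixels))); // now [0.0, 1.0]
    }
}
```

In the code above, the second scaler.fit(dataTestIter) plays the role of the last fit call in this sketch; since dataTrainIter holds the same scaler object, it would also pick up the test-set statistics.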

I get very high score values from the ScoreIterationListener:

    INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 0 is 149587.20166618665
    INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 1 is 744774.2926322176
    13:40:47.122 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 2 is 667952.7301255261

I am not sure whether my model is learning properly.
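
A rough way to read the magnitude of these scores, under two assumptions I have not verified in the DL4J source: the output layer uses a sigmoid with binary cross-entropy (XENT), and the score is the per-pixel loss summed over all 512 × 512 = 262,144 output pixels of an image (then averaged over the minibatch). Dividing by the pixel count then gives an average per-pixel loss. The class and method names below are hypothetical.

```java
public class ScorePerPixel {
    static final int PIXELS = 512 * 512; // assumed output size per image: 512x512, one channel

    // Binary cross-entropy for one pixel: label y in [0, 1], predicted probability p in (0, 1).
    static double bce(double y, double p) {
        return -(y * Math.log(p) + (1 - y) * Math.log(1 - p));
    }

    // Average per-pixel loss, ASSUMING the score is summed over pixels and averaged over the batch.
    static double perPixel(double score) {
        return score / PIXELS;
    }

    public static void main(String[] args) {
        double[] scores = {149587.20166618665, 744774.2926322176, 667952.7301255261};
        for (double s : scores) {
            System.out.printf("score %.1f -> %.4f per pixel%n", s, perPixel(s));
        }
        // Reference point: a constant prediction of 0.5 costs ln(2) per pixel.
        System.out.printf("bce(1, 0.5) = %.4f%n", bce(1.0, 0.5));
    }
}
```

Under those assumptions the three scores correspond to about 0.57, 2.84 and 2.55 per pixel, so the large absolute numbers would mostly reflect the image size; a per-pixel loss that rises well above ln(2) ≈ 0.693 after the first iteration would still suggest that training is not converging.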

My evaluation stats and confusion matrix:

     # of classes:    2
     Accuracy:        0.8385
     Precision:       0.9978
     Recall:          0.8400
     F1 Score:        0.9122
    Precision, recall & F1: reported for positive class (class 1 - "1") only


    =========================Confusion Matrix=========================
          0      1
    ---------------
         34    477 | 0 = 0
      41851 219782 | 1 = 1

    Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
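
The summary numbers can be recomputed from the matrix, which is a useful check on how the matrix is being read. A plain-Java sketch, assuming the matrix is indexed m[actual][predicted] as the printout states and that class 1 is the positive class; the class and method names are hypothetical.

```java
public class ConfusionCheck {
    // m is indexed [actual][predicted]; class 1 is treated as the positive class.
    static long total(long[][] m) {
        return m[0][0] + m[0][1] + m[1][0] + m[1][1];
    }

    static double accuracy(long[][] m) {
        return (double) (m[0][0] + m[1][1]) / total(m);
    }

    // TP / (TP + FP): of the pixels predicted 1, how many are actually 1.
    static double precision(long[][] m) {
        return (double) m[1][1] / (m[1][1] + m[0][1]);
    }

    // TP / (TP + FN): of the pixels that are actually 1, how many were predicted 1.
    static double recall(long[][] m) {
        return (double) m[1][1] / (m[1][1] + m[1][0]);
    }

    // F1 = 2TP / (2TP + FP + FN), equivalent to the harmonic mean of precision and recall.
    static double f1(long[][] m) {
        return 2.0 * m[1][1] / (2.0 * m[1][1] + m[0][1] + m[1][0]);
    }

    public static void main(String[] args) {
        long[][] m = {{34, 477}, {41851, 219782}}; // values from the evaluation output above
        System.out.printf("total=%d accuracy=%.4f precision=%.4f recall=%.4f f1=%.4f%n",
                total(m), accuracy(m), precision(m), recall(m), f1(m));
        System.out.println("row 0 sum = " + (m[0][0] + m[0][1]));
    }
}
```

This reproduces 0.8385, 0.9978, 0.8400 and 0.9122 over 262,144 = 512 × 512 pixels. The row for class 0 sums to 34 + 477 = 511 pixels, about one row of a 512 × 512 image. As far as I know, DL4J's Evaluation.eval(labels, networkPredictions) takes the ground truth first, while the test code passes pred first; if so, the rows here are really the predictions, and those 511 pixels would be the thin black line in the output image.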

Here is the test part:

        int j = 0; // index used in the output file names
        while (dataTestIter.hasNext()) {
            DataSet t = dataTestIter.next();
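            // revert() undoes the normalization of the features (and of the labels, since fitLabel(true)),
            // so cp.output() below receives un-normalized pixel values
            scaler.revert(t);
            INDArray[] predicted = cp.output(t.getFeatures());
            INDArray input = t.getFeatures();
            // output for one image reshaped to 512x512 (test batch size is 1, one channel)
            INDArray pred = predicted[0].reshape(new int[]{512, 512});
            Evaluation eval = new Evaluation();
            System.out.println(t.getLabels());
            System.out.println("Prediction array");
            System.out.println(pred);
            // note: Evaluation.eval(labels, networkPredictions) expects the ground truth first; here pred is passed first
            eval.eval(pred.dup().reshape(512 * 512, 1), t.getLabels().dup().reshape(512 * 512, 1));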
            System.out.println(eval.stats());
            DataBuffer dataBuffer = pred.data();
            System.out.println(dataBuffer);
            double[] classificationResult = dataBuffer.asDouble();
            System.out.println(classificationResult);
            ImageProcessor classifiedSliceProcessor = new FloatProcessor(512, 512, classificationResult);

            //segmented image instance
            ImagePlus classifiedImage = new ImagePlus("pred" + j, classifiedSliceProcessor);
            IJ.save(classifiedImage, "G:/ASPIJ/deep_learning/predict/pred-" + j + ".png");


            j++;
        }
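
A plain-Java sketch of what a pixel-wise binary evaluation counts, to make the argument order explicit: ground truth first, predictions second, with counts indexed [actual][predicted]. It assumes sigmoid outputs thresholded at 0.5 (which I believe is what Evaluation does for a single output column) and labels already in {0, 1}; PixelConfusion and its method are hypothetical names.

```java
import java.util.Arrays;

public class PixelConfusion {
    // Returns counts indexed [actual][predicted] for a binary mask.
    // labels: ground truth in {0, 1}; probs: predicted probabilities in [0, 1].
    static long[][] confusion(double[] labels, double[] probs, double threshold) {
        if (labels.length != probs.length) {
            throw new IllegalArgumentException("labels and probs must have the same length");
        }
        long[][] m = new long[2][2];
        for (int i = 0; i < labels.length; i++) {
            int actual = labels[i] >= 0.5 ? 1 : 0;          // ground truth, binarized
            int predicted = probs[i] >= threshold ? 1 : 0;  // prediction, thresholded
            m[actual][predicted]++;
        }
        return m;
    }

    public static void main(String[] args) {
        double[] labels = {0, 0, 1, 1, 1};
        double[] probs = {0.2, 0.7, 0.9, 0.4, 0.6};
        long[][] m = confusion(labels, probs, 0.5);
        System.out.println(Arrays.deepToString(m)); // prints [[1, 1], [1, 2]]
    }
}
```

With a 0.5 threshold, swapping the two arrays transposes the matrix, so rows and columns trade meaning.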

Whole UNet script: UNet1.java · GitHub
Note: I am training on a very small dataset (4-5 images) because of limited computational resources. The blog I am following says that results with so few images are poor but still obtainable.
Many of the documentation links I try to follow return "page not found", e.g. https://deeplearning4j.konduit.ai/datavec/overview#reading-records-iterating-over-data. With so many links from previous discussions no longer working, it is difficult to find solutions. Or are they only broken for me?