Tiny YOLO PredictedObjects NaN

Hi guys,

I'm trying to run the code from this page: "How to build a custom object detector using Yolo". It is an object detector for a Rubik's cube.

Below is the code; I made a few small changes:

    package com.dl4j.yolo.sample;   
 
    import java.io.File;
    import java.io.IOException;
    import java.io.Serializable;
    import java.net.URI;
    import java.util.List;
    import java.util.Random;
    
    import org.bytedeco.opencv.opencv_java;
    import org.datavec.api.io.filters.BalancedPathFilter;
    import org.datavec.api.io.labels.ParentPathLabelGenerator;
    import org.datavec.api.records.metadata.RecordMetaDataImageURI;
    import org.datavec.api.split.FileSplit;
    import org.datavec.api.split.InputSplit;
    import org.datavec.image.loader.NativeImageLoader;
    import org.datavec.image.recordreader.objdetect.ObjectDetectionRecordReader;
    import org.datavec.image.recordreader.objdetect.impl.VocLabelProvider;
    import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
    import org.deeplearning4j.nn.api.OptimizationAlgorithm;
    import org.deeplearning4j.nn.conf.ConvolutionMode;
    import org.deeplearning4j.nn.conf.GradientNormalization;
    import org.deeplearning4j.nn.conf.WorkspaceMode;
    import org.deeplearning4j.nn.conf.inputs.InputType;
    import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
    import org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer;
    import org.deeplearning4j.nn.graph.ComputationGraph;
    import org.deeplearning4j.nn.layers.objdetect.DetectedObject;
    import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
    import org.deeplearning4j.nn.transferlearning.TransferLearning;
    import org.deeplearning4j.nn.weights.WeightInit;
    import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
    import org.deeplearning4j.util.ModelSerializer;
    import org.deeplearning4j.zoo.model.TinyYOLO;
    import org.nd4j.linalg.activations.Activation;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.dataset.DataSet;
    import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.learning.config.RmsProp;
    import org.opencv.core.Mat;
    import org.opencv.core.Point;
    import org.opencv.core.Scalar;
    import org.opencv.imgcodecs.Imgcodecs;
    import org.opencv.imgproc.Imgproc;
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;
    
    public class YOLOTrainer {
    	 private static final Logger log = LoggerFactory.getLogger(YOLOTrainer.class);
    
    	    private static final int INPUT_WIDTH = 416;
    	    private static final int INPUT_HEIGHT = 416;
    	    private static final int CHANNELS = 3;
    
    	    private static final int GRID_WIDTH = 13;
    	    private static final int GRID_HEIGHT = 13;
    	    private static final int CLASSES_NUMBER = 1;
    	    private static final int BOXES_NUMBER = 5;
    	    private static final double[][] PRIOR_BOXES = {{1.5, 1.5}, {2, 2}, {3, 3}, {3.5, 8}, {4, 9}};
    
    	    private static final int BATCH_SIZE = 4;
    	    private static final int EPOCHS = 50;
    	    private static final double LEARNIGN_RATE = 0.0001;
    	    private static final int SEED = 7854;
    
    	    /*parent Dataset folder "DATA_DIR" contains two subfolder "images" and "annotations" */
    	    private static final String DATA_DIR = "C:\\Java\\Dataset";
    
    	    /* Yolo loss function prameters for more info
    	    https://stats.stackexchange.com/questions/287486/yolo-loss-function-explanation*/
    	    private static final double LAMDBA_COORD = 1.0;
    	    private static final double LAMDBA_NO_OBJECT = 0.5;
    
    	    public static void main(String[] args) throws IOException, InterruptedException {
    
    	        Random rng = new Random(SEED);
    
    	        //Initialize the user interface backend, it is just as tensorboard.
    	        //it starts at http://localhost:9000
    	        //UIServer uiServer = UIServer.getInstance();
    
    	        //Configure where the network information (gradients, score vs. time etc) is to be stored. Here: store in memory.
    	        //StatsStorage statsStorage = new InMemoryStatsStorage();
    
    	        //Attach the StatsStorage instance to the UI: this allows the contents of the StatsStorage to be visualized
    	        //uiServer.attach(statsStorage);
    
    	        File imageDir = new File(DATA_DIR, "images");
    
    	        log.info("Load data...");
    	        
    	        ParentPathLabelGenerator LABEL_GENERATOR_MAKER = new ParentPathLabelGenerator();
    	        BalancedPathFilter PATH_FILTER = new BalancedPathFilter(rng, NativeImageLoader.ALLOWED_FORMATS, LABEL_GENERATOR_MAKER);
    
    	        InputSplit[] data = new FileSplit(imageDir, NativeImageLoader.ALLOWED_FORMATS, rng).sample(PATH_FILTER, 85, 15);
    	        InputSplit trainData = data[0];
    	        InputSplit testData = data[1];
    
    	        ObjectDetectionRecordReader recordReaderTrain = new ObjectDetectionRecordReader(INPUT_HEIGHT, INPUT_WIDTH, CHANNELS,
    	                GRID_HEIGHT, GRID_WIDTH, new VocLabelProvider(DATA_DIR));
    	        recordReaderTrain.initialize(trainData);
    
    	        ObjectDetectionRecordReader recordReaderTest = new ObjectDetectionRecordReader(INPUT_HEIGHT, INPUT_WIDTH, CHANNELS,
    	                GRID_HEIGHT, GRID_WIDTH, new VocLabelProvider(DATA_DIR));
    	        recordReaderTest.initialize(testData);
    
    	        RecordReaderDataSetIterator train = new RecordReaderDataSetIterator(recordReaderTrain, BATCH_SIZE, 1, 1, true);
    	        train.setPreProcessor(new ImagePreProcessingScaler(0, 1));
    
    	        RecordReaderDataSetIterator test = new RecordReaderDataSetIterator(recordReaderTest, BATCH_SIZE, 1, 1, true);
    	        test.setPreProcessor(new ImagePreProcessingScaler(0, 1));
    
    	        /*
    	        ComputationGraph pretrained = (ComputationGraph) TinyYOLO.builder().build().initPretrained();
    
    	        INDArray priors = Nd4j.create(PRIOR_BOXES);
    	        FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
    	                .seed(SEED)
    	                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
    	                .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
    	                .gradientNormalizationThreshold(1.0)
    	                .updater(new RmsProp(LEARNIGN_RATE))
    	                .activation(Activation.IDENTITY).miniBatch(true)
    	                .trainingWorkspaceMode(WorkspaceMode.ENABLED)
    	                .build();
    
    	        ComputationGraph model = new TransferLearning.GraphBuilder(pretrained)
    	                .fineTuneConfiguration(fineTuneConf)
    	                .setInputTypes(InputType.convolutional(INPUT_HEIGHT, INPUT_WIDTH, CHANNELS))
    	                .removeVertexKeepConnections("conv2d_9")
    	                .removeVertexKeepConnections("outputs")
    	                .addLayer("convolution2d_9",
    	                        new ConvolutionLayer.Builder(1, 1)
    	                                .nIn(1024)
    	                                .nOut(BOXES_NUMBER * (5 + CLASSES_NUMBER))
    	                                .stride(1, 1)
    	                                .convolutionMode(ConvolutionMode.Same)
    	                                .weightInit(WeightInit.UNIFORM)
    	                                .hasBias(false)
    	                                .activation(Activation.IDENTITY)
    	                                .build(), "leaky_re_lu_8")
    	                .addLayer("outputs",
    	                        new Yolo2OutputLayer.Builder()
    	                                .lambdaNoObj(LAMDBA_NO_OBJECT)
    	                                .lambdaCoord(LAMDBA_COORD)
    	                                .boundingBoxPriors(priors)
    	                                .build(), "convolution2d_9")
    	                .setOutputs("outputs")
    	                .build();
    
    	        log.info("\n Model Summary \n" + model.summary());
    
    	        log.info("Train model...");
    	        model.setListeners(new ScoreIterationListener(1));//print score after each iteration on stout 
    	        //model.setListeners(new StatsListener(statsStorage));// visit http://localhost:9000 to track the training process
    	        for (int i = 0; i < EPOCHS; i++) {
    	            train.reset();
    	            while (train.hasNext()) {
    	                model.fit(train.next());
    	            }
    	            log.info("*** Completed epoch {} ***", i);
    	        }
    
    	        log.info("*** Saving Model ***");
    	        ModelSerializer.writeModel(model, "C:\\Java\\model.data", true);
    	        log.info("*** Training Done ***");
    	           	        
    	        
    	        URI[] loc = testData.locations();
    	        for (int i = 0; i < loc.length; i++) {
    				URI uri = loc[i];
    				Mat image = Imgcodecs.imread(uri.getPath().substring(1));
    				
    				List<DetectedObject> objs = detect(image, model);
    	        	boolean found = addRects(image, objs);
    	        	String name = String.format("NF_%s.jpg", i);
    	        	
    	        	if(found) {
    	        		name = String.format("F_%s.jpg", i);
    	        	}
    	        	
    	        	Imgcodecs.imwrite("C:\\Java\\test\\" + name, image);
    			}	       
    	    }
    	    
    	    public static List<DetectedObject> detect(Mat image, ComputationGraph model) throws IOException {
    	    	org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer yout = (org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer) model.getOutputLayer(0);
    	    	 
    	    	 NativeImageLoader loader = new NativeImageLoader(INPUT_HEIGHT, INPUT_WIDTH, CHANNELS);
    	         INDArray ds = loader.asMatrix(image);        
    	         ImagePreProcessingScaler scaler = new ImagePreProcessingScaler(0, 1);
    	         scaler.transform(ds);
    	         
    	         INDArray results = model.outputSingle(ds);
    	         List<DetectedObject> objs = yout.getPredictedObjects(results, 0.4);	         
    	         
    	         return objs;
    	    }
    	    
    	    public static boolean addRects(Mat image, List<DetectedObject> objs) {
    	    	boolean result = false;
    	    	Scalar color = new Scalar(0, 0, 255);
    	    	for (int i = 0; i < objs.size(); i++) {
    				DetectedObject obj = objs.get(i);
    				
    				int imgW = image.width();
    				int imgH = image.height();
    				
    				double[] xy1 = obj.getTopLeftXY();
    				double[] xy2 = obj.getBottomRightXY();
    				
    				int x1 = (int) Math.round(imgW * xy1[0] / GRID_WIDTH);
    				int y1 = (int) Math.round(imgH * xy1[1] / GRID_HEIGHT);
    				int x2 = (int) Math.round(imgW * xy2[0] / GRID_WIDTH);
    				int y2 = (int) Math.round(imgH * xy2[1] / GRID_HEIGHT);
    				
    				if(x1 == 0 && y1 == 0 && x2 == 0 && y2 == 0) {
    					continue;
    				}
    				
    				result = true;
    				Imgproc.rectangle(image, new Point(x1, y1), new Point(x2, y2), color);			
    			}
    	    	
    	    	return result;
    	    }
    }

Dataset can be downloaded from here.

The problem is that when I try to test the model, all the detected objects contain NaN values.
Capture_NaN

Any hints on this topic would be very helpful.
Thanks.

@lquintero07 I have no idea, but this is a cool project. I tried the basic YOLO example, but I think I'll try this one soon. Let me know if you get it to work.

Can you add this:

     Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder()
                .checkForINF(true)
                .checkElapsedTime(true)
                .checkLocality(true)
                .checkWorkspaces(true)
                .build());

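NaNs are generally an indicator of a bad dataset or bad tuning. I'd be curious to know at what point the NaNs appear. (The snippet above only checks for Inf values; also add `.checkForNAN(true)` to the builder to catch NaNs.)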

Hey hi,

Thanks for your response.

I deleted the previous model, so I trained it again with the lines you suggested added at the start of the main method.

This time I don't get NaN values, but the results don't seem good:

F_3

F_5

When I was testing, I saw the NaN values coming from this line: `INDArray results = model.outputSingle(ds);`

Hi, sorry to bump this old topic, but I got the same error with my dataset after training. It also occurs when I use the TinyYoloHouseNumberDetection example provided here: https://github.com/eclipse/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/advanced/modelling/objectdetection/TinyYoloHouseNumberDetection.java

I didn't change anything in the example code and ran it as is.
I then added the profiling code suggested above to the example

and got this exception:
Exception in thread "main" org.nd4j.linalg.exception.ND4JOpProfilerException: P.A.N.I.C.! Op.Z() contains 43264 Inf value(s) at org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil.checkForInf(OpExecutionerUtil.java:94) at org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil.checkForInf(OpExecutionerUtil.java:129) at org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner.profilingConfigurableHookOut(DefaultOpExecutioner.java:558) at org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner.exec(NativeOpExecutioner.java:1993) at org.nd4j.linalg.factory.Nd4j.exec(Nd4j.java:6575) at org.deeplearning4j.nn.layers.mkldnn.MKLDNNConvHelper.preOutput(MKLDNNConvHelper.java:166) at org.deeplearning4j.nn.layers.convolution.ConvolutionLayer.preOutput(ConvolutionLayer.java:401) at org.deeplearning4j.nn.layers.convolution.ConvolutionLayer.activate(ConvolutionLayer.java:489) at org.deeplearning4j.nn.graph.vertex.impl.LayerVertex.doForward(LayerVertex.java:111) at org.deeplearning4j.nn.graph.ComputationGraph.outputOfLayersDetached(ComputationGraph.java:2380) at org.deeplearning4j.nn.graph.ComputationGraph.output(ComputationGraph.java:1793) at org.deeplearning4j.nn.graph.ComputationGraph.outputSingle(ComputationGraph.java:1775) at org.deeplearning4j.nn.graph.ComputationGraph.outputSingle(ComputationGraph.java:1761) at org.deeplearning4j.nn.graph.ComputationGraph.outputSingle(ComputationGraph.java:1639) at TinyYoloHouseNumberDetection.main(TinyYoloHouseNumberDetection.java:215)

FYI, I was able to use the TinyYolo example successfully before, even with my own dataset. But now I can't seem to get it right: the training runs and everything seems fine, but the model still outputs NaN after training.
It all started after I installed CUDA to speed up training. I haven't actually used it yet, since I haven't finished installing cuDNN and haven't changed the backend in my pom file.

Also, an ND4J debug message from AsyncMultiDataSetIterator, "Manually destroying AMDSI Workspace", started appearing during training out of nowhere, and my model started outputting NaN.

I don't know if I'm being stupid, but I've been debugging this for quite some time now and still can't get it working. If anyone knows what causes this, please enlighten me; I'd really appreciate it. Thank you :pray:

Hello everyone, sorry to bump this old thread again.
I just want to say that I've got it working now. Since I can no longer edit my previous reply, I'll post a new one here:

In my case, the problem was related to the ND4J debug message from AsyncMultiDataSetIterator ("Manually destroying AMDSI"). I simply replaced `model.fit(train, nEpochs)` with

for (int i = 0; i < nEpochs; i++) {
            while (train.hasNext()) {
                DataSet d = train.next();
                model.fit(d);
            }
            train.reset();
        }

and everything runs fine again. Maybe the preloading corrupted the data in my case, but I don't know for sure.
Hopefully this reply helps anyone with the same problem as mine :pray:t2:

Note: I'm using the newest version available (beta7).

@artinmare thanks for elaborating on this. Sorry, I'm just doing QA for the release, so anything that requires a lot of digging will have to wait until later. Could you elaborate a bit on what you think your issue was, so others won't repeat it? I'm wondering if there's a bug in the AMDSI.

@agibsonccc
No problem, happy to help.

I think the issue might be related to the small amount of RAM available for training. Perhaps it has to do with buffering images in memory for faster training? I'm training the model on a laptop with just 8 GB of RAM, so maybe Windows frees the memory or writes other data to the same address, corrupting the data?

For anyone with a limited amount of RAM, I suggest not using the preloading method for now, and using the workaround above if you run into the same problem.
For reference, the authors of https://gist.github.com/saudet/fb8a4d9544dc3c411b302ccd6bbf87e7 and of "Emaraic - How to build a custom object detector using Yolo" also use the same manual training loop instead of the preloading method. The Emaraic author also states that he uses an old computer, so that might be relevant too.

I don't know much about how the JVM handles memory, so I can't give more precise information about the issue here (I'll try to read more about it).
But I don't think it's the AMDSI code itself, since I didn't find anything wrong with it; maybe it's the workspace? (I've yet to read that part of the code.) I'll reread the code and test it with a different setup. If I find anything, I'll reply as soon as possible with a more detailed answer. Thank you :pray:t2:

UPDATE:
As I guessed before, I've found that memory is what causes the problem. My guess is that the JVM frees up some memory and accidentally corrupts the training data that has been preloaded into memory.
I tried limiting memory usage by following this guide: https://deeplearning4j.konduit.ai/config/config-memory, and everything now works as intended.
I guess letting the JVM allocate memory automatically is not the best practice in the first place.

I've yet to try overfitting the network (I only trained for 200 epochs), so I don't know if this is really the answer, but I hope it won't cause any more NaNs even if I overfit it. The Inf exception no longer occurs either, which is good news.

Training is going really well, with 85% average confidence, so I'm calling it a success. I'll try other configurations to make sure everything is really fine, and I'll be back with more updates.

Could you elaborate a bit more? I'm doing the same thing (same example, same dataset) and running into the same issue. After adding your code (plus a checkForNaN() call) I get this:

[main] INFO com.emaraic.rubikcubedetector.YoloTrainer - Train model...
Exception in thread "main" org.nd4j.linalg.exception.ND4JOpProfilerException: P.A.N.I.C.! Op.Z() contains 8400 NaN value(s)
	at org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil.checkForNaN(OpExecutionerUtil.java:65)
	at org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil.checkForNaN(OpExecutionerUtil.java:145)
	at org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner.profilingConfigurableHookOut(DefaultOpExecutioner.java:554)
	at org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner.exec(NativeOpExecutioner.java:1971)
	at org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner.exec(NativeOpExecutioner.java:1540)
	at org.nd4j.linalg.factory.Nd4j.exec(Nd4j.java:6545)
	at org.nd4j.linalg.api.ndarray.BaseNDArray.divi(BaseNDArray.java:3236)
	at org.nd4j.linalg.api.ndarray.BaseNDArray.div(BaseNDArray.java:3053)
	at org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer.calculateIOULabelPredicted(Yolo2OutputLayer.java:480)
	at org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer.computeBackpropGradientAndScore(Yolo2OutputLayer.java:164)
	at org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer.backpropGradient(Yolo2OutputLayer.java:77)
	at org.deeplearning4j.nn.graph.vertex.impl.LayerVertex.doBackward(LayerVertex.java:148)
	at org.deeplearning4j.nn.graph.ComputationGraph.calcBackpropGradients(ComputationGraph.java:2784)
	at org.deeplearning4j.nn.graph.ComputationGraph.computeGradientAndScore(ComputationGraph.java:1393)
	at org.deeplearning4j.nn.graph.ComputationGraph.computeGradientAndScore(ComputationGraph.java:1353)
	at org.deeplearning4j.optimize.solvers.BaseOptimizer.gradientAndScore(BaseOptimizer.java:174)
	at org.deeplearning4j.optimize.solvers.StochasticGradientDescent.optimize(StochasticGradientDescent.java:61)
	at org.deeplearning4j.optimize.Solver.optimize(Solver.java:52)
	at org.deeplearning4j.nn.graph.ComputationGraph.fitHelper(ComputationGraph.java:1177)
	at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1127)
	at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1114)
	at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:996)
	at com.emaraic.rubikcubedetector.YoloTrainer.main(YoloTrainer.java:202)

I don't think I have any NaNs in my training data, so I have no idea what the issue might be here.

@b005t3r that is usually some sort of tuning issue that triggers numerical underflow. Try tuning your learning rate and similar hyperparameters.

I tried changing the learning rate and the updater, but nothing works. I used the same parameters you used in one of your examples (house numbers) and tried others from a different example, but nothing helps. The issue seems to happen right after the first training batch.

Could you elaborate, what did you do to solve the memory issue you mentioned and what the issue actually was?

I tested the training params, nothing helps.

But then I had a closer look at the training data and noticed that the features are correctly scaled between 0.0 and 1.0, while the labels contain much higher values:

===========INPUT===================
[[[[    0.7569,    0.7569,    0.7686,  ...    0.3765,    0.2431,    0.2235], 
   [    0.7569,    0.7569,    0.7686,  ...    0.4235,    0.3647,    0.3647], 
   [    0.7569,    0.7569,    0.7686,  ...    0.4314,    0.4431,    0.4353], 
    ..., 
   [    0.2627,    0.2667,    0.2706,  ...    0.3059,    0.2980,    0.2941], 
   [    0.2667,    0.2667,    0.2667,  ...    0.3059,    0.3059,    0.3020], 
   [    0.2706,    0.2627,    0.2627,  ...    0.3020,    0.3020,    0.3020]], 

  [[    0.7725,    0.7725,    0.7765,  ...    0.4157,    0.2314,    0.1961], 
   [    0.7725,    0.7725,    0.7765,  ...    0.4431,    0.3490,    0.3333], 
   [    0.7725,    0.7725,    0.7765,  ...    0.4157,    0.4196,    0.4118], 
    ..., 
   [    0.2745,    0.2784,    0.2824,  ...    0.3216,    0.3176,    0.3137], 
   [    0.2824,    0.2824,    0.2784,  ...    0.3255,    0.3255,    0.3216], 
   [    0.2863,    0.2784,    0.2745,  ...    0.3216,    0.3216,    0.3216]], 

  [[    0.7255,    0.7255,    0.7294,  ...    0.3922,    0.2667,    0.2471], 
   [    0.7255,    0.7255,    0.7294,  ...    0.4353,    0.3686,    0.3725], 
   [    0.7255,    0.7255,    0.7294,  ...    0.4118,    0.4235,    0.4157], 
    ..., 
   [    0.2667,    0.2706,    0.2745,  ...    0.3255,    0.3216,    0.3176], 
   [    0.2627,    0.2627,    0.2706,  ...    0.3294,    0.3294,    0.3255], 
   [    0.2667,    0.2588,    0.2667,  ...    0.3255,    0.3255,    0.3255]]], 


 [[[    0.7608,    0.7647,    0.7647,  ...    0.4078,    0.2863,    0.2863], 
   [    0.7608,    0.7647,    0.7647,  ...    0.4314,    0.4431,    0.4118], 
   [    0.7608,    0.7647,    0.7647,  ...    0.3529,    0.4000,    0.4078], 
    ..., 
   [    0.2784,    0.2824,    0.2824,  ...    0.2745,    0.2745,    0.2784], 
   [    0.2824,    0.2784,    0.2745,  ...    0.2745,    0.2902,    0.2863], 
   [    0.2784,    0.2706,    0.2667,  ...    0.2784,    0.2941,    0.2824]], 

  [[    0.7725,    0.7765,    0.7765,  ...    0.4431,    0.2667,    0.2471], 
   [    0.7725,    0.7765,    0.7765,  ...    0.4431,    0.4275,    0.3843], 
   [    0.7725,    0.7765,    0.7765,  ...    0.3333,    0.3804,    0.3922], 
    ..., 
   [    0.2667,    0.2706,    0.2745,  ...    0.2824,    0.2824,    0.2863], 
   [    0.2706,    0.2667,    0.2667,  ...    0.2824,    0.2863,    0.2824], 
   [    0.2667,    0.2588,    0.2588,  ...    0.2863,    0.2902,    0.2784]], 

  [[    0.7137,    0.7176,    0.7176,  ...    0.3922,    0.2706,    0.2706], 
   [    0.7137,    0.7176,    0.7176,  ...    0.4078,    0.4235,    0.3961], 
   [    0.7137,    0.7176,    0.7176,  ...    0.3294,    0.3765,    0.3882], 
    ..., 
   [    0.2745,    0.2784,    0.2745,  ...    0.2863,    0.2863,    0.2902], 
   [    0.2784,    0.2745,    0.2667,  ...    0.2863,    0.2941,    0.2902], 
   [    0.2745,    0.2667,    0.2588,  ...    0.2902,    0.2980,    0.2863]]], 


 [[[    0.7725,    0.7765,    0.7804,  ...    0.3804,    0.3804,    0.3804], 
   [    0.7725,    0.7804,    0.7804,  ...    0.3843,    0.3882,    0.3843], 
   [    0.7765,    0.7804,    0.7804,  ...    0.3843,    0.3843,    0.3765], 
    ..., 
   [    0.3137,    0.3059,    0.3059,  ...    0.2824,    0.2745,    0.2667], 
   [    0.3098,    0.3098,    0.3020,  ...    0.2824,    0.2667,    0.2627], 
   [    0.3137,    0.3176,    0.3137,  ...    0.2667,    0.2627,    0.2706]], 

  [[    0.7882,    0.7922,    0.7961,  ...    0.3725,    0.3725,    0.3725], 
   [    0.7882,    0.7961,    0.7961,  ...    0.3725,    0.3765,    0.3725], 
   [    0.7922,    0.7961,    0.7961,  ...    0.3608,    0.3608,    0.3529], 
    ..., 
   [    0.3137,    0.3059,    0.3059,  ...    0.2941,    0.2941,    0.2863], 
   [    0.3020,    0.3020,    0.3020,  ...    0.2941,    0.2863,    0.2824], 
   [    0.3059,    0.3098,    0.3137,  ...    0.2784,    0.2824,    0.2902]], 

  [[    0.7451,    0.7490,    0.7529,  ...    0.3412,    0.3412,    0.3412], 
   [    0.7451,    0.7529,    0.7529,  ...    0.3412,    0.3451,    0.3412], 
   [    0.7490,    0.7529,    0.7529,  ...    0.3412,    0.3333,    0.3255], 
    ..., 
   [    0.3137,    0.3059,    0.3059,  ...    0.3098,    0.3059,    0.2980], 
   [    0.3020,    0.3020,    0.3020,  ...    0.3098,    0.2980,    0.2941], 
   [    0.3059,    0.3098,    0.3137,  ...    0.2941,    0.2941,    0.3020]]], 


 [[[    0.7176,    0.7176,    0.7098,  ...    0.4980,    0.4627,    0.3961], 
   [    0.7255,    0.7176,    0.7176,  ...    0.3176,    0.3137,    0.3608], 
   [    0.7255,    0.7216,    0.7176,  ...    0.4039,    0.4627,    0.4824], 
    ..., 
   [    0.2314,    0.2275,    0.2275,  ...    0.3765,    0.3843,    0.4000], 
   [    0.2275,    0.2275,    0.2275,  ...    0.4039,    0.4078,    0.4235], 
   [    0.2275,    0.2275,    0.2235,  ...    0.4196,    0.4275,    0.4392]], 

  [[    0.7294,    0.7294,    0.7294,  ...    0.5255,    0.4431,    0.3569], 
   [    0.7255,    0.7294,    0.7294,  ...    0.3451,    0.3020,    0.3333], 
   [    0.7216,    0.7216,    0.7294,  ...    0.4353,    0.4667,    0.4667], 
    ..., 
   [    0.2392,    0.2353,    0.2353,  ...    0.3922,    0.4000,    0.4157], 
   [    0.2353,    0.2353,    0.2314,  ...    0.4196,    0.4235,    0.4392], 
   [    0.2353,    0.2353,    0.2275,  ...    0.4353,    0.4431,    0.4549]], 

  [[    0.6706,    0.6706,    0.6706,  ...    0.4980,    0.4392,    0.3569], 
   [    0.6706,    0.6706,    0.6706,  ...    0.3176,    0.2863,    0.3216], 
   [    0.6667,    0.6667,    0.6667,  ...    0.3922,    0.4314,    0.4471], 
    ..., 
   [    0.2431,    0.2392,    0.2392,  ...    0.3725,    0.3804,    0.3961], 
   [    0.2392,    0.2392,    0.2471,  ...    0.4000,    0.4039,    0.4196], 
   [    0.2392,    0.2392,    0.2431,  ...    0.4157,    0.4235,    0.4353]]]]
=================OUTPUT==================
[[[[         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
    ..., 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0]], 

  [[         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
    ..., 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0]], 

  [[         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
    ..., 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0]], 

  [[         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
    ..., 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0]], 

  [[         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
    ..., 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0]]], 


 [[[         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
    ..., 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0]], 

  [[         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
    ..., 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0]], 

  [[         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
    ..., 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0]], 

  [[         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
    ..., 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0]], 

  [[         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
    ..., 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0]]], 


 [[[         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,    9.5312,         0], 
    ..., 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0]], 

  [[         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,    0.6562,         0], 
    ..., 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0]], 

  [[         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,   12.6250,         0], 
    ..., 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0]], 

  [[         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,    4.7188,         0], 
    ..., 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0]], 

  [[         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,    1.0000,         0], 
    ..., 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0]]], 


 [[[         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
    ..., 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0]], 

  [[         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
    ..., 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0]], 

  [[         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
    ..., 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0]], 

  [[         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
    ..., 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0]], 

  [[         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
    ..., 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0], 
   [         0,         0,         0,  ...         0,         0,         0]]]]

Is that correct? Could it be the cause of my issues with NaNs?

edit:
I have exactly the same issue with your house number detection example:

@b005t3r could you check if it’s a precision issue? Try setting the data type to double and seeing what that does.

I’d also wonder how large the updates get. Are you doing any normalization or anything like that?

@agibsonccc changing the data type didn’t change anything.

I did some debugging and noticed that at one point during training there’s a “div” operation that takes two zero-filled arrays as arguments, which obviously produces an array filled with NaNs (0/0). Later, however, there’s a NaN check that replaces all of those NaNs with zeros. Given that, I figured NaNs during training might not be an issue after all, so I disabled the profiler’s NaN check.

Now, after training finishes, I get the following exception when I try to use the trained network:

Exception in thread "main" org.nd4j.linalg.exception.ND4JOpProfilerException: P.A.N.I.C.! Op.Z() contains 524 Inf value(s)
	at org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil.checkForInf(OpExecutionerUtil.java:94)
	at org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil.checkForInf(OpExecutionerUtil.java:114)
	at org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner.profilingConfigurableHookOut(DefaultOpExecutioner.java:525)
	at org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner.exec(NativeOpExecutioner.java:888)
	at org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner.exec(NativeOpExecutioner.java:151)
	at org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner.exec(NativeOpExecutioner.java:139)
	at org.nd4j.linalg.ops.transforms.Transforms.exec(Transforms.java:1140)
	at org.nd4j.linalg.ops.transforms.Transforms.exp(Transforms.java:972)
	at org.deeplearning4j.nn.layers.objdetect.YoloUtils.activate(YoloUtils.java:66)
	at org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer.activate(Yolo2OutputLayer.java:389)
	at org.deeplearning4j.nn.graph.vertex.impl.LayerVertex.doForward(LayerVertex.java:111)
	at org.deeplearning4j.nn.graph.ComputationGraph.outputOfLayersDetached(ComputationGraph.java:2380)
	at org.deeplearning4j.nn.graph.ComputationGraph.output(ComputationGraph.java:1793)
	at org.deeplearning4j.nn.graph.ComputationGraph.outputSingle(ComputationGraph.java:1775)
	at org.deeplearning4j.nn.graph.ComputationGraph.outputSingle(ComputationGraph.java:1761)
	at org.deeplearning4j.nn.graph.ComputationGraph.outputSingle(ComputationGraph.java:1639)
	at com.emaraic.utils.YoloModel.detectRubixCube(YoloModel.java:88)
	at com.emaraic.rubikcubedetector.YoloTrainer.main(YoloTrainer.java:242)

I’m attaching the source code for the training and testing steps; let me know if you notice anything incorrect:

package com.emaraic.rubikcubedetector;

import com.emaraic.utils.YoloModel;
import java.io.File;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Random;
import static org.bytedeco.opencv.global.opencv_imgcodecs.*;

import org.bytedeco.javacpp.PointerScope;
import org.bytedeco.javacv.CanvasFrame;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.bytedeco.opencv.opencv_core.Mat;
import org.datavec.api.io.filters.RandomPathFilter;
import org.datavec.api.records.metadata.RecordMetaDataImageURI;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.objdetect.ObjectDetectionRecordReader;
import org.datavec.image.recordreader.objdetect.impl.VocLabelProvider;
import org.deeplearning4j.core.storage.StatsStorage;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.model.stats.StatsListener;
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.zoo.model.TinyYOLO;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.Nadam;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.profiler.ProfilerConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Created on Jul 4, 2018 , 11:07:52 AM
 *
 * @author Taha Emara 
 * Email : taha@emaraic.com 
 * Website: http://www.emaraic.com
 */
public class YoloTrainer {

    private static final Logger log = LoggerFactory.getLogger(YoloTrainer.class);

    private static final int INPUT_WIDTH = 416;
    private static final int INPUT_HEIGHT = 416;
    private static final int CHANNELS = 3;

    private static final int GRID_WIDTH = 13;
    private static final int GRID_HEIGHT = 13;
    private static final int CLASSES_NUMBER = 1;
    private static final int BOXES_NUMBER = 5;
    private static final double[][] PRIOR_BOXES = {{1.5, 1.5}, {2, 2}, {3, 3}, {3.5, 8}, {4, 9}};

    private static final int BATCH_SIZE = 4;
    private static final int EPOCHS = 3; // 50;
    private static final double LEARNING_RATE = 0.001; // 0.0001;
    private static final int SEED = 12345;

    /*parent Dataset folder "DATA_DIR" contains two subfolder "images" and "annotations" */
    private static final String DATA_DIR = "Dataset";

    /* Yolo loss function prameters for more info
    https://stats.stackexchange.com/questions/287486/yolo-loss-function-explanation*/
    private static final double LAMDBA_COORD = 1.0f; // 1.0;
    private static final double LAMDBA_NO_OBJECT = 0.5;

    public static void main(String[] args) throws IOException, InterruptedException {
//        System.setProperty("org.bytedeco.javacpp.noPointerGC", "true");
//        System.setProperty("org.bytedeco.javacpp.maxBytes", "12G");
//        System.setProperty("org.bytedeco.javacpp.maxPhysicalBytes", "20G");

        try(PointerScope scope = new PointerScope()) {
            Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);

            Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder()
                    .checkForINF(true)
                    //.checkForNAN(true)
                    .checkElapsedTime(true)
                    .checkLocality(true)
                    .checkWorkspaces(true)
                    .build());

            Random rng = new Random(SEED);

/*
            //Initialize the user interface backend, it is just as tensorboard.
            //it starts at http://localhost:9000
            UIServer uiServer = UIServer.getInstance();

            //Configure where the network information (gradients, score vs. time etc) is to be stored. Here: store in memory.
            StatsStorage statsStorage = new InMemoryStatsStorage();

            //Attach the StatsStorage instance to the UI: this allows the contents of the StatsStorage to be visualized
            uiServer.attach(statsStorage);
*/

            File imageDir = new File(DATA_DIR, "images");

            log.info("Load data...");
            RandomPathFilter pathFilter = new RandomPathFilter(rng) {
                @Override
                protected boolean accept(String name) {
                    name = name.replace("/images/", "/annotations/").replace(".jpg", ".xml");
                    //System.out.println("Name " + name);
                    try {
                        return new File(new URI(name)).exists();
                    } catch (URISyntaxException ex) {
                        throw new RuntimeException(ex);
                    }
                }
            };

            InputSplit[] data = new FileSplit(imageDir, NativeImageLoader.ALLOWED_FORMATS, rng).sample(pathFilter, 0.9, 0.1);
            InputSplit trainData = data[0];
            InputSplit testData = data[1];

            ObjectDetectionRecordReader recordReaderTrain = new ObjectDetectionRecordReader(INPUT_HEIGHT, INPUT_WIDTH, CHANNELS,
                    GRID_HEIGHT, GRID_WIDTH, new VocLabelProvider(DATA_DIR));
            recordReaderTrain.initialize(trainData);

            ObjectDetectionRecordReader recordReaderTest = new ObjectDetectionRecordReader(INPUT_HEIGHT, INPUT_WIDTH, CHANNELS,
                    GRID_HEIGHT, GRID_WIDTH, new VocLabelProvider(DATA_DIR));
            recordReaderTest.initialize(testData);

            ImagePreProcessingScaler scaler = new ImagePreProcessingScaler(0, 1, 8);
            //scaler.fitLabel(true);

            RecordReaderDataSetIterator train = new RecordReaderDataSetIterator(recordReaderTrain, BATCH_SIZE, 1, 1, true);
            train.setPreProcessor(scaler);

            RecordReaderDataSetIterator test = new RecordReaderDataSetIterator(recordReaderTest, 1, 1, 1, true);
            test.setPreProcessor(scaler);

        ComputationGraph pretrained = (ComputationGraph) TinyYOLO.builder().build().initPretrained();
        INDArray priors = Nd4j.create(PRIOR_BOXES);

        FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
                .seed(SEED)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                .gradientNormalizationThreshold(1.0)
                .updater(new Adam.Builder().learningRate(LEARNING_RATE).build())
                //.updater(new Nesterovs.Builder().learningRate(learningRate).momentum(lrMomentum).build())
                .activation(Activation.IDENTITY)
                .trainingWorkspaceMode(WorkspaceMode.ENABLED)
                .inferenceWorkspaceMode(WorkspaceMode.ENABLED)
                .build();

        ComputationGraph model = new TransferLearning.GraphBuilder(pretrained)
                .fineTuneConfiguration(fineTuneConf)
                .removeVertexKeepConnections("conv2d_9")
                .removeVertexKeepConnections("outputs")
                .addLayer("convolution2d_9",
                        new ConvolutionLayer.Builder(1,1)
                                .nIn(1024)
                                .nOut(BOXES_NUMBER * (5 + CLASSES_NUMBER))
                                .stride(1,1)
                                .convolutionMode(ConvolutionMode.Same)
                                .weightInit(WeightInit.UNIFORM)
                                .hasBias(false)
                                .activation(Activation.IDENTITY)
                                .build(),
                        "leaky_re_lu_8")
                .addLayer("outputs",
                        new Yolo2OutputLayer.Builder()
                                .lambdaNoObj(LAMDBA_NO_OBJECT)
                                .lambdaCoord(LAMDBA_COORD)
                                .boundingBoxPriors(priors)
                                .build(),
                        "convolution2d_9")
                .setOutputs("outputs")
                .build();

        System.out.println(model.summary(InputType.convolutional(INPUT_HEIGHT, INPUT_WIDTH, CHANNELS)));

            log.info("Train model...");
            model.setListeners(new ScoreIterationListener(1));//print score after each iteration on stout
            //model.setListeners(new StatsListener(statsStorage));// visit http://localhost:9000 to track the training process
            for (int i = 0; i < EPOCHS; i++) {
                train.reset();
                while (train.hasNext()) {
                    DataSet ds = train.next();

/*
                    if (ds.toString().contains("NaN")) {
                        System.err.println("NaN present");

                        System.out.println(ds);
                    }
*/

                    model.fit(ds);
                }
                log.info("*** Completed epoch {} ***", i);
            }
            //model.fit(train, EPOCHS);

            log.info("*** Saving Model ***");
            ModelSerializer.writeModel(model, "model.data", true);
            log.info("*** Training Done ***");


            //visualize results on the test set, Just hit any key in your keyboard to iterate the test set.
            log.info("*** Visualizing model on test data ***");
            YoloModel detector = new YoloModel();
            CanvasFrame mainframe = new CanvasFrame("Rubix Cube");
            mainframe.setDefaultCloseOperation(javax.swing.JFrame.EXIT_ON_CLOSE);
            mainframe.setCanvasSize(600, 600);
            mainframe.setLocationRelativeTo(null);
            mainframe.setVisible(true);

            OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
            test.setCollectMetaData(true);
            while (test.hasNext() && mainframe.isVisible()) {
                org.nd4j.linalg.dataset.DataSet ds = test.next();
                RecordMetaDataImageURI metadata = (RecordMetaDataImageURI) ds.getExampleMetaData().get(0);
                Mat image = imread(metadata.getURI().getPath());
                //System.out.println("Path: " +metadata.getURI().getPath());
                //detector.detectRubixCube(ds.getFeatures(), image, 0.4);
                detector.detectRubixCube(image, 0.1);
                mainframe.setTitle(new File(metadata.getURI()).getName());
                mainframe.showImage(converter.convert(image));
                mainframe.waitKey();
            }
            mainframe.dispose();
            System.exit(0);
        }
    }
}