Convolutional Autoencoder Example

My goal was to build a CNN autoencoder, since I had never built a CNN before. I must be getting used to the API, because it only took a single day. I used the CelebA dataset: http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html

It works basically as I expected, and surprisingly well for a first attempt. I'm posting it in case you have any corrections I could make, or in case you're looking for an example of a convolutional autoencoder (CNNAE) — I couldn't find one myself.

Edit: I think I formatted the code block wrong, but this is a single class file, even though it looks broken up. You should be able to copy and paste it. Point `source` at any image dataset, instantiate the class, and call `run()`.

CNNAE results

package ML;

import nu.pattern.OpenCV;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.Java2DNativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.fetchers.DataSetType;
import org.deeplearning4j.datasets.iterator.impl.Cifar10DataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.variational.BernoulliReconstructionDistribution;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnnPreProcessor;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.BaseTrainingListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.AdaDelta;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.opencv.core.Mat;
import org.opencv.imgcodecs.Imgcodecs;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.util.ArrayList;
import java.util.List;

public class NewExp {

private static JFrame frame;
private static JLabel label;

JLabel originalLaabel;
JLabel resultLabel;
Java2DNativeImageLoader imageLoader;
Java2DNativeImageLoader imageLoaderTwo;

private static final Logger log = LoggerFactory.getLogger(NewExp.class);
int height = 32;
int width = 32;

// String source = “D:\Downloads\img_align_celeba\output\”;
String source = “D:\Downloads\img_align_celeba\img_align_celeba\”;

int channels = 3;

public NewExp() {

    OpenCV.loadLocally();
    String[] file = new File(source).list();

    Mat src = Imgcodecs.imread(source + file[0]);
    System.out.println("FINAL \n Width: " + src.width() + "\n Height: " + src.height());


    frame = new JFrame("Results");
    frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);


    imageLoader = new Java2DNativeImageLoader();
    imageLoaderTwo = new Java2DNativeImageLoader();
    originalLaabel = new JLabel();
    originalLaabel.setBounds(0,0, width * 2, height * 2);

    resultLabel = new JLabel();
    resultLabel.setBounds(width * 2,0, width * 2, height * 2);

    frame.add(originalLaabel);
    frame.add(resultLabel);
    frame.setSize(width * 3, height * 4 );
    frame.setLayout(null);
    frame.setVisible(true);


}

public void run()
{

    DataSetIterator iter = getDataSetIter();

    MultiLayerNetwork net = getNet();




    for (int i = 0; i < 100; i++) {

        int count = 0;
        while (iter.hasNext()){

            DataSet d = iter.next();

            net.fit(d.getFeatures(), d.getFeatures());

            if ((count % 100) == 0) {

                BufferedImage bf = imageLoader.asBufferedImage(d.getFeatures());

                List<INDArray> a = net.feedForwardToLayer(10, d.getFeatures());

                BufferedImage bfO = imageLoaderTwo.asBufferedImage(a.get(a.size() - 1));

                originalLaabel.setIcon(new ImageIcon(bf));
                resultLabel.setIcon(new ImageIcon(bfO));

                System.out.println("");

            }
            count++;
        }

    }




}


public DataSetIterator getDataSetIter()
{
    try{
        ImageRecordReader recordReader = new ImageRecordReader(height,width, 3);
        recordReader.initialize(new FileSplit(new File(source)));

        DataSetIterator dataSetIterator = new RecordReaderDataSetIterator(recordReader, 1);

        return dataSetIterator;

    }catch (Exception e){
        e.printStackTrace();
    }

    return null;
}







private MultiLayerNetwork getNet()
{
    int rngSeed = 123;
    int dimensions = 16;

    //Neural net configuration
    Nd4j.getRandom().setSeed(rngSeed);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(rngSeed)
            .updater(new RmsProp(1e-3))
            .weightInit(WeightInit.XAVIER)
            .l2(1e-4)
            .list()

            //encode
            .layer(0, new ConvolutionLayer.Builder().kernelSize(3,3).stride(1,1).activation(Activation.RELU).nIn(channels).nOut(32).build())
            //(32-3)/1 + 1 = 32x30x30
            //(64-3)/1 + 1 = 32x62x62
            .layer( new BatchNormalization())
            .layer(1, new SubsamplingLayer.Builder().kernelSize(2,2).stride(2,2).poolingType(SubsamplingLayer.PoolingType.MAX).build())
            //(30-2)/2 + 1 = 32x15x15
            //(62-3)/2 + 1 = 32x30x30
            .layer(2, new ConvolutionLayer.Builder().kernelSize(2,2).stride(1,1).activation(Activation.RELU).nIn(32).nOut(16).build())
            //(15-2)/1 + 1 = 16x14x14
            //(35-2)/1 + 1 = 16x29x29
            .layer(new BatchNormalization())
            .layer(3, new SubsamplingLayer.Builder().kernelSize(2,2).stride(2,2).poolingType(SubsamplingLayer.PoolingType.MAX).build())
            //(14-2)/2 + 1 = 16x7x7
            //(29-2)/2 + 1 = 16x15x15

            //double stuff oreo center
            .layer(4, new DenseLayer.Builder().nIn(784).nOut(dimensions).weightInit(WeightInit.XAVIER).activation(Activation.IDENTITY).build())
            .layer(5, new DenseLayer.Builder().nIn(dimensions).nOut(784).weightInit(WeightInit.XAVIER).activation(Activation.IDENTITY).build())

            //decode
            .layer( 6, new Upsampling2D.Builder().size(2) .build())
            .layer( new BatchNormalization())
            .layer(7, new Deconvolution2D.Builder().kernelSize(2,2) .stride(1,1).nIn(16).nOut(32).activation(Activation.RELU).build())
            .layer(8, new Upsampling2D.Builder().size(2).build())
            .layer( new BatchNormalization())
            .layer(9, new Deconvolution2D.Builder().kernelSize(3,3).stride(1,1).activation(Activation.RELU).nIn(32).nOut(channels).build())

            .layer(10, new CnnLossLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY).build())
            //.layer(10, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(3072).nOut(3072).activation(Activation.IDENTITY).build())

            // say hail marys here
            .inputPreProcessor(4, new CnnToFeedForwardPreProcessor(7,7,16))
            .inputPreProcessor(6, new FeedForwardToCnnPreProcessor(7,7,16))
            //.inputPreProcessor(10, new CnnToFeedForwardPreProcessor(32,32,3))
            .build();


    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    net.setListeners(new ScoreIterationListener(100));

    return net;

}

}