Sure, here’s the class I’m running:
public class App
{
UIServer uiServer;
File trainData, testData;
public App() throws IOException
{
//configure training data and normalizer
trainData = new File("D:\\cam29\\Downloads\\100-bird-species\\180\\train");
FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, new Random());
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
ImageRecordReader rr = new ImageRecordReader(224,224,3, labelMaker);
rr.initialize(trainSplit);
DataSetIterator trainIterator = new RecordReaderDataSetIterator(rr,64,1,180);
DataNormalization imageScaler = new ImagePreProcessingScaler();
imageScaler.fit(trainIterator);
trainIterator.setPreProcessor(imageScaler);
//configure testing data
testData = new File("D:\\cam29\\Downloads\\100-bird-species\\180\\test");
FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, new Random());
ImageRecordReader rrTest = new ImageRecordReader(224,224,3, labelMaker);
rrTest.initialize(testSplit);
DataSetIterator testIterator = new RecordReaderDataSetIterator(rrTest, 64, 1, 180);
testIterator.setPreProcessor(imageScaler);
ConvolutionLayer layer0 = new ConvolutionLayer.Builder(5,5)
.nIn(3)
.nOut(16)
.stride(1,1)
.padding(2,2)
.weightInit(WeightInit.XAVIER)
.name("First convolution layer")
.activation(Activation.RELU)
.build();
SubsamplingLayer layer1 = new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(2,2)
.stride(2,2)
.name("First subsampling layer")
.build();
ConvolutionLayer layer2 = new ConvolutionLayer.Builder(5,5)
.nOut(20)
.stride(1,1)
.padding(2,2)
.weightInit(WeightInit.XAVIER)
.name("Second convolution layer")
.activation(Activation.RELU)
.build();
SubsamplingLayer layer3 = new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(2,2)
.stride(2,2)
.name("Second subsampling layer")
.build();
ConvolutionLayer layer4 = new ConvolutionLayer.Builder(5,5)
.nOut(20)
.stride(1,1)
.padding(2,2)
.weightInit(WeightInit.XAVIER)
.name("Third convolution layer")
.activation(Activation.RELU)
.build();
SubsamplingLayer layer5 = new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(2,2)
.stride(2,2)
.name("Third subsampling layer")
.build();
OutputLayer layer6 = new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.activation(Activation.SOFTMAX)
.weightInit(WeightInit.XAVIER)
.name("Output")
.nOut(180)
.build();
MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
.seed(System.currentTimeMillis())
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
//.regularization(true)
//.l1(0.0001)
//.l2(0.0002)
.l2(0.0004)
//.dropOut(0.8)
.updater(new Nesterovs(0.001, 0.9))
.list()
.layer(0, layer0)
.layer(1, layer1)
.layer(2, layer2)
.layer(3, layer3)
.layer(4, layer4)
.layer(5, layer5)
.layer(6, layer6)
//.pretrain(false)
//.backprop(true)
.setInputType(InputType.convolutional(224,224,3))
.build();
MultiLayerNetwork network = new MultiLayerNetwork(configuration);
network.init();
network.setListeners(new ScoreIterationListener(10));
attachUI(network);
double start_time = System.currentTimeMillis();
network.fit(trainIterator, 50);
//network.evaluateROCMultiClass(testIterator);
Evaluation evaluation = network.evaluate(testIterator);
double end_time = System.currentTimeMillis();
if (evaluation.accuracy() > 0.15)
{
System.out.println("This run took " + (end_time - start_time)/1000/60 + " minutes.");
System.out.println(evaluation.stats());
}
else
{
System.out.println("Accuracy low at " + evaluation.accuracy());
System.out.println("Repeating...");
new App();
}
uiServer.stop();
}
public static void main(String[] args) {
try {
App a = new App();
} catch (IOException e) {
e.printStackTrace();
}
}
public void attachUI(MultiLayerNetwork mln)
{
uiServer = UIServer.getInstance();
StatsStorage statsStorage = new FileStatsStorage(new File(System.getProperty("java.io.tmpdir"), "ui-stats.dl4j"));
uiServer.detach(statsStorage);
int listenerFrequency = 20;
uiServer.attach(statsStorage);
mln.setListeners(new StatsListener(statsStorage, listenerFrequency));
}
}