Here is my pom.xml:
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>SampleNd4j</groupId>
<artifactId>SampleNd4j</artifactId>
<version>0.0.1-SNAPSHOT</version>
<build>
<sourceDirectory>src</sourceDirectory>
<plugins>
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.8.1</version>
<configuration>
<release>11</release>
</configuration>
</plugin>
</plugins>
</build>
<!-- NOTE(review): none of the nd4j.*/dl4j.*/cuda.redist.* properties below are
     referenced by the dependencies in this POM, and they name CUDA 10.1/10.2 while
     the actual dependencies further down use CUDA 11.2. This looks like stale,
     copied-over configuration - confirm and either remove these properties or
     reference them from the dependency declarations so there is a single source
     of truth for the CUDA version. -->
<properties>
<nd4j.cpu.backend>nd4j-native-platform</nd4j.cpu.backend>
<nd4j.cpu.backend2>nd4j-native</nd4j.cpu.backend2>
<nd4j.gpu.backend>nd4j-cuda-10.2-platform</nd4j.gpu.backend>
<dl4j.gpu.backend>deeplearning4j-cuda-10.2</dl4j.gpu.backend>
<cuda.redist.10.1.version>10.1-7.6-1.5.2</cuda.redist.10.1.version>
<cuda.redist.10.2.version>10.2-7.6-1.5.3</cuda.redist.10.2.version>
<dl4j.version>1.0.0-SNAPSHOT</dl4j.version>
<ffmpeg.version>3.2.1-1.3</ffmpeg.version>
<javacv.version>1.4.1</javacv.version>
<logback.version>1.1.7</logback.version>
<jackson.version>2.9.6</jackson.version>
</properties>
<repositories>
<repository>
<id>snapshots-repo</id>
<url>https://oss.sonatype.org/content/repositories/snapshots</url>
<releases>
<enabled>false</enabled>
</releases>
<snapshots>
<enabled>true</enabled>
<updatePolicy>daily</updatePolicy> <!-- Optional, update daily -->
</snapshots>
</repository>
</repositories>
<dependencies>
<!-- NOTE(review): this pins opencv-platform independently of the DL4J stack.
     datavec-data-image (below) pulls its own JavaCPP/OpenCV presets transitively;
     a mismatched presets version here can cause native-library load failures -
     verify this version matches what ${dl4j.version} expects, or drop the explicit
     pin and let the DL4J dependency tree choose it. -->
<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>opencv-platform</artifactId>
<version>4.5.1-1.5.5</version>
</dependency>
<!-- Jackson dependencies -->
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>${jackson.version}</version>
</dependency>
<!-- Log dependency -->
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>1.7.25</version>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<version>${logback.version}</version>
</dependency>
<!-- deeplearning4j-core: contains main functionality and neural networks -->
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-core</artifactId>
<version>${dl4j.version}</version>
</dependency>
<!-- NOTE(review): both the CPU backend (nd4j-native) and the GPU backend
     (nd4j-cuda-11.2) are on the classpath at the same time. ND4J picks one
     backend at runtime; keeping both makes it ambiguous which one is actually
     used and inflates the native memory footprint. If GPU training is the goal,
     consider removing nd4j-native (or keep it only in a separate Maven profile). -->
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${dl4j.version}</version>
</dependency>
<!-- NOTE(review): deeplearning4j-cuda-11.2 must stay version-locked with
     nd4j-cuda-11.2 below; both correctly share ${dl4j.version} here. -->
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-cuda-11.2</artifactId>
<version>${dl4j.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-11.2</artifactId>
<version>${dl4j.version}</version>
</dependency>
<!-- ParallelWrapper & ParallelInference live here -->
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-parallel-wrapper</artifactId>
<version>${dl4j.version}</version>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-data-image</artifactId>
<version>${dl4j.version}</version>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-local</artifactId>
<version>${dl4j.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-zoo</artifactId>
<version>${dl4j.version}</version>
</dependency>
</dependencies>
</project>
And here is my test project:
package com.sample.ui;
import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.image.loader.BaseImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
import org.deeplearning4j.earlystopping.saver.LocalFileGraphSaver;
import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator;
import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition;
import org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition;
import org.deeplearning4j.earlystopping.termination.ScoreImprovementEpochTerminationCondition;
import org.deeplearning4j.earlystopping.trainer.EarlyStoppingGraphTrainer;
import org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.zoo.PretrainedType;
import org.deeplearning4j.zoo.ZooModel;
import org.deeplearning4j.zoo.model.VGG16;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.VGG16ImagePreProcessor;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
public class TrainProgram {
public static final long seed = 1234;
public static final Random RAND_NUM_GEN = new Random(seed);
public static final String[] ALLOWED_FORMATS = BaseImageLoader.ALLOWED_FORMATS;
public static ParentPathLabelGenerator LABEL_GENERATOR_MAKER = new ParentPathLabelGenerator();
public static BalancedPathFilter PATH_FILTER = new BalancedPathFilter(RAND_NUM_GEN, ALLOWED_FORMATS, LABEL_GENERATOR_MAKER);
protected static final Logger LOGGER = org.slf4j.LoggerFactory.getLogger(TrainProgram.class);
protected static final int TRAIN_SIZE = 85;
protected static final int BATCH_SIZE = 32;
protected static final int EPOCH = 30;
public static void main(String[] args) throws IOException {
String homePath = System.getProperty("user.home");
LOGGER.info(homePath);
Path datasetPath = Paths.get(homePath, "dataset");
Path trainPath = Paths.get(datasetPath.toString(), "sample-data");
File parentDir = trainPath.toFile();
FileSplit filesInDir = new FileSplit(parentDir, ALLOWED_FORMATS, RAND_NUM_GEN);
InputSplit[] filesInDirSplit = filesInDir.sample(PATH_FILTER, TRAIN_SIZE, 100 - TRAIN_SIZE);
DataSetIterator trainIter = makeIterator(filesInDirSplit[0]);
DataSetIterator testIter = makeIterator(filesInDirSplit[1]);
ZooModel objZooModel = VGG16.builder().workspaceMode(WorkspaceMode.ENABLED).build();
ComputationGraph preTrainedNet = (ComputationGraph) objZooModel.initPretrained(PretrainedType.IMAGENET);
LOGGER.info(preTrainedNet.summary());
FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new Nesterovs(5e-5))
.seed(seed)
.inferenceWorkspaceMode(WorkspaceMode.ENABLED)
.trainingWorkspaceMode(WorkspaceMode.ENABLED)
.build();
String FREEZE_UNTIL_LAYER = "fc2";
String OUTPUT_LAYER = "predictions";
int INPUT_LAYER_PARAM = 4096;
int numClasses = trainIter.getLabels().size();
ComputationGraph vgg16Transfer = new TransferLearning.GraphBuilder(preTrainedNet)
.fineTuneConfiguration(fineTuneConf)
.setFeatureExtractor(FREEZE_UNTIL_LAYER)
.removeVertexKeepConnections(OUTPUT_LAYER)
.setWorkspaceMode(WorkspaceMode.ENABLED)
.addLayer(OUTPUT_LAYER,
new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.nIn(INPUT_LAYER_PARAM).nOut(numClasses)
//.weightInit(WeightInit.XAVIER)
.weightInit(new NormalDistribution(0, 0.2 * (2.0 / (INPUT_LAYER_PARAM + numClasses)))) //This weight init dist gave better results than Xavier
.activation(Activation.SOFTMAX).build(), FREEZE_UNTIL_LAYER)
.build();
vgg16Transfer.setListeners(new ScoreIterationListener(5));
LOGGER.info(vgg16Transfer.summary());
EarlyStoppingConfiguration < ComputationGraph > esConfig = new EarlyStoppingConfiguration.Builder < ComputationGraph > ()
.epochTerminationConditions(new MaxEpochsTerminationCondition(EPOCH),
new ScoreImprovementEpochTerminationCondition(5, 0))
.iterationTerminationConditions(
new MaxTimeIterationTerminationCondition(24, TimeUnit.HOURS)) //new MaxScoreIterationTerminationCondition(150.5)
.scoreCalculator(new DataSetLossCalculator(testIter, true))
.evaluateEveryNEpochs(1)
.modelSaver(new LocalFileGraphSaver(datasetPath.toString()))
.build();
IEarlyStoppingTrainer < ComputationGraph > trainer = new EarlyStoppingGraphTrainer(esConfig, vgg16Transfer, trainIter);
EarlyStoppingResult < ComputationGraph > result = trainer.fit();
//Print out the results:
LOGGER.info("Termination reason: " + result.getTerminationReason());
LOGGER.info("Termination details: " + result.getTerminationDetails());
LOGGER.info("Total epochs: " + result.getTotalEpochs());
LOGGER.info("Best epoch number: " + result.getBestModelEpoch());
LOGGER.info("Score at best epoch: " + result.getBestModelScore());
ComputationGraph bestModel = result.getBestModel();
evalOn(bestModel, testIter, 0);
Path fullPath = Paths.get(datasetPath.toString(), "model.zip");
ModelSerializer.writeModel(bestModel, fullPath.toFile(), false);
LOGGER.info("END");
}
public static DataSetIterator makeIterator(InputSplit split) throws IOException {
int channels = 3;
int width = 224;
int height = 224;
ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, LABEL_GENERATOR_MAKER);
recordReader.initialize(split);
DataSetIterator iter = new RecordReaderDataSetIterator(recordReader, BATCH_SIZE, 1, 1, true);
iter.setPreProcessor(new VGG16ImagePreProcessor());
return iter;
}
public static boolean evalOn(ComputationGraph graph, DataSetIterator testIterator, int iEpoch) throws IOException {
boolean result = true;
try {
LOGGER.info("Evaluate model at iteration " + iEpoch + " ....");
Evaluation eval = graph.evaluate(testIterator);
LOGGER.info(eval.stats());
testIterator.reset();
} catch (OutOfMemoryError e) {
System.gc();
LOGGER.info("Error: ", e);
result = false;
}
return result;
}
}
Each time I run this code I run out of memory. I am using an RTX 3060 with 12 GB of VRAM, and here are my memory settings: -Xms2G -Xmx4G -Dorg.bytedeco.javacpp.maxbytes=6G
Any guidance on how to fix this would be greatly appreciated.