Thanks for the help, I think it’s going quite well so far. But I still have the problem that with 25000 records and a batchSize of 128, one epoch still takes about 10 minutes on the CPU. So I wanted to try the CUDA version again and applied your changes, but now a different error shows up:
https://gist.github.com/ForceUpdate1/5eb5083e7e329ef52857ff15b42e0aa3
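For scale, the CPU numbers above work out to roughly 196 minibatches per epoch and about 3 seconds per minibatch, assuming each epoch passes over all 25000 records exactly once. A quick sketch of that arithmetic (the class name is hypothetical, this is not DL4J code):

```java
import java.util.Locale;

/** Back-of-the-envelope numbers for one epoch (hypothetical helper, not DL4J code). */
public class EpochMath {

    /** Number of minibatches needed to cover all records once (the last one may be partial). */
    public static int batchesPerEpoch(int records, int batchSize) {
        return (records + batchSize - 1) / batchSize; // ceiling division
    }

    public static void main(String[] args) {
        int records = 25_000;
        int batchSize = 128;
        int batches = batchesPerEpoch(records, batchSize); // 195 full batches + 1 partial
        int lastBatch = records % batchSize;                // records left over for the partial batch
        double epochSeconds = 10 * 60;                      // "about 10 minutes" per epoch

        // Locale.ROOT so the decimal separator is always '.', regardless of the system locale
        System.out.printf(Locale.ROOT, "batches per epoch: %d (last batch: %d records)%n", batches, lastBatch);
        System.out.printf(Locale.ROOT, "approx. seconds per batch: %.2f%n", epochSeconds / batches);
    }
}
```

The per-batch figure is only an average over the whole epoch; it includes data loading as well as the forward and backward passes.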
My current pom.xml:
<properties>
    <dl4j.version>1.0.0-SNAPSHOT</dl4j.version>
</properties>
<dependencies>
    <dependency>
        <groupId>org.nd4j</groupId>
        <artifactId>nd4j-native</artifactId>
        <version>${dl4j.version}</version>
    </dependency>
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-core</artifactId>
        <version>${dl4j.version}</version>
    </dependency>
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-ui</artifactId>
        <version>${dl4j.version}</version>
    </dependency>
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-cuda-11.2</artifactId>
        <version>${dl4j.version}</version>
    </dependency>
    <dependency>
        <groupId>org.nd4j</groupId>
        <artifactId>nd4j-cuda-11.2</artifactId>
        <version>${dl4j.version}</version>
    </dependency>
    <dependency>
        <groupId>commons-cli</groupId>
        <artifactId>commons-cli</artifactId>
        <version>1.4</version>
    </dependency>
    <dependency>
        <groupId>org.datavec</groupId>
        <artifactId>datavec-local</artifactId>
        <version>${dl4j.version}</version>
        <exclusions>
            <exclusion>
                <groupId>org.datavec</groupId>
                <artifactId>datavec-arrow</artifactId>
            </exclusion>
        </exclusions>
    </dependency>
</dependencies>
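This dependency list declares both nd4j-native and nd4j-cuda-11.2, so both backends end up on the classpath. A minimal sketch for checking which backend classes are actually visible at runtime; the two fully qualified class names are an assumption about where each backend's entry point lives and should be checked against the ND4J version in use:

```java
/**
 * Hypothetical helper: reports which ND4J backend classes are visible on the runtime classpath.
 * The class names in main() are assumptions, not taken from this thread.
 */
public class BackendOnClasspath {

    /** Returns true if a class with this name can be found, without running its static initializers. */
    public static boolean isPresent(String className) {
        try {
            // initialize = false: only locate the class, do not trigger native library loading
            Class.forName(className, false, BackendOnClasspath.class.getClassLoader());
            return true;
        } catch (ClassNotFoundException | LinkageError e) {
            return false;
        }
    }

    public static void main(String[] args) {
        String[] candidates = {
            "org.nd4j.linalg.cpu.nativecpu.CpuBackend", // assumed to come from nd4j-native
            "org.nd4j.linalg.jcublas.JCublasBackend"    // assumed to come from nd4j-cuda-11.2
        };
        for (String name : candidates) {
            System.out.println(name + " on classpath: " + isPresent(name));
        }
    }
}
```

If both report true, ND4J has two backends to choose from at startup.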
Edit:
JProfiler shows that backpropagation takes up a lot of the time, and I came across the following article:
https://deeplearning4j.konduit.ai/models/recurrent#truncated-back-propagation-through-time
Unfortunately, it doesn’t quite work for me:
[main] WARN org.deeplearning4j.nn.multilayer.MultiLayerNetwork - Cannot do truncated BPTT with non-3d inputs or labels. Expect input with shape [miniBatchSize,nIn,timeSeriesLength], got [128, 69, 128] and labels with shape [128, 2]
[main] WARN org.deeplearning4j.nn.multilayer.MultiLayerNetwork - Cannot do truncated BPTT with non-3d inputs or labels. Expect input with shape [miniBatchSize,nIn,timeSeriesLength], got [128, 69, 128] and labels with shape [128, 2]
[main] WARN org.deeplearning4j.nn.multilayer.MultiLayerNetwork - Cannot do truncated BPTT with non-3d inputs or labels. Expect input with shape [miniBatchSize,nIn,timeSeriesLength], got [128, 69, 128] and labels with shape [128, 2]
[main] WARN org.deeplearning4j.nn.multilayer.MultiLayerNetwork - Cannot do truncated BPTT with non-3d inputs or labels. Expect input with shape [miniBatchSize,nIn,timeSeriesLength], got [128, 69, 128] and labels with shape [128, 2]
[main] WARN org.deeplearning4j.nn.multilayer.MultiLayerNetwork - Cannot do truncated BPTT with non-3d inputs or labels. Expect input with shape [miniBatchSize,nIn,timeSeriesLength], got [128, 69, 128] and labels with shape [128, 2]
[main] WARN org.deeplearning4j.nn.multilayer.MultiLayerNetwork - Cannot do truncated BPTT with non-3d inputs or labels. Expect input with shape [miniBatchSize,nIn,timeSeriesLength], got [128, 69, 128] and labels with shape [128, 2]
[main] WARN org.deeplearning4j.nn.multilayer.MultiLayerNetwork - Cannot do truncated BPTT with non-3d inputs or labels. Expect input with shape [miniBatchSize,nIn,timeSeriesLength], got [128, 69, 128] and labels with shape [128, 2]
[main] WARN org.deeplearning4j.nn.multilayer.MultiLayerNetwork - Cannot do truncated BPTT with non-3d inputs or labels. Expect input with shape [miniBatchSize,nIn,timeSeriesLength], got [60, 69, 128] and labels with shape [60, 2]
[main] INFO de.foorcee.chatfilter.trainer.ChatfilterTrainer - Epoch 1 complete in 6598ms (0 min). Starting evaluation:
Exception in thread "main" java.lang.IllegalStateException: Illegal set of indices for array: need at least 2 point/interval/all/specified indices for rank 2 array ([128, 2]), got indices [all(), all(), Interval(b=0,e=20,s=1)]
	at org.nd4j.common.base.Preconditions.throwStateEx(Preconditions.java:641)
	at org.nd4j.common.base.Preconditions.checkState(Preconditions.java:412)
	at org.nd4j.linalg.api.ndarray.BaseNDArray.get(BaseNDArray.java:4140)
	at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.getSubsetsForTbptt(MultiLayerNetwork.java:2112)
	at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.doEvaluationHelper(MultiLayerNetwork.java:3472)
	at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.doEvaluation(MultiLayerNetwork.java:3400)
	at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.evaluate(MultiLayerNetwork.java:3595)
	at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.evaluate(MultiLayerNetwork.java:3505)
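The warning and the exception both come down to array ranks. A small, hypothetical sketch (plain Java, not DL4J's implementation) of the rule the warning states, that input and labels must both be 3d, plus how a time series of length 128 would be cut into truncated-BPTT windows. The window length of 20 is an assumption inferred from `Interval(b=0,e=20,s=1)` in the stack trace:

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper (not part of DL4J) mirroring the shape rule from the TBPTT warning. */
public class TbpttShapeCheck {

    /** True only if input and labels are both rank 3: [miniBatchSize, size, timeSeriesLength]. */
    public static boolean canUseTbptt(long[] inputShape, long[] labelShape) {
        return inputShape.length == 3 && labelShape.length == 3;
    }

    /**
     * Splits [0, timeSeriesLength) into consecutive windows of at most tbpttLength steps,
     * returned as {begin, end} pairs with end exclusive, like Interval(b=0,e=20) in the trace.
     */
    public static List<long[]> segments(long timeSeriesLength, long tbpttLength) {
        List<long[]> out = new ArrayList<>();
        for (long begin = 0; begin < timeSeriesLength; begin += tbpttLength) {
            out.add(new long[] {begin, Math.min(begin + tbpttLength, timeSeriesLength)});
        }
        return out;
    }

    public static void main(String[] args) {
        long[] input = {128, 69, 128}; // input shape reported in the warning
        long[] labels = {128, 2};      // one label per sequence, so rank 2
        System.out.println("TBPTT possible: " + canUseTbptt(input, labels));

        // Assumed window length of 20 (see lead-in); the last window is shorter
        for (long[] s : segments(128, 20)) {
            System.out.println("[" + s[0] + ", " + s[1] + ")");
        }
    }
}
```

With labels of shape [128, 2], that is one label per sequence rather than one per time step, the check returns false, which is what the warning reports. The evaluation exception shows the same mismatch: a three-index slice with an interval on the time axis is applied to the rank 2 label array.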