The data format for BatchNormalization is 2D CNN (NCHW), which is inconsistent with the Convolution1D and LSTM layers (NCW).
@cqiaoYc you can specify the layout. Could you post code or clarify what issues you might be having?
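For reference, a minimal sketch of what specifying the layout could look like, assuming a DL4J version where RNNFormat is available (1.0.0-beta7 or later); the exact overloads may differ between versions:

import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;

// Declare the network input as a recurrent (sequence) type so DL4J can track the
// layout and insert any preprocessors needed between the 1D-CNN and RNN layers.
listBuilder.setInputType(InputType.recurrent(featuresCount));

// Some versions also accept an explicit channels-first format [batch, features, time]
// (an assumption; check the InputType.recurrent overloads in your version):
// listBuilder.setInputType(InputType.recurrent(featuresCount, RNNFormat.NCW));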
The code:
// Add CNN layers
final Adam adam = new Adam(cnnLrSchedule);
for (int i = 0; i < cnnLayerCount; i++) {
    int cnnPadding = cnnPaddings[i];
    listBuilder.layer(layerIndex, new Convolution1D.Builder()
            .kernelSize(cnnKernelSizes[i])
            .stride(cnnStrides[i])
            .padding(cnnPadding)
            .updater(adam)
            .nIn(nIn)
            .nOut(cnnNeurons[i])
            .activation(Activation.TANH)
            .build());
    nIn = cnnNeurons[i];
    ++layerIndex;
    // listBuilder.layer(layerIndex, new BatchNormalization.Builder().nOut(nIn).build());
    // ++layerIndex;
}

// Add RNN layers
final RmsProp rmsProp = new RmsProp(lrSchedule);
for (int i = 0; i < this.rnnNeurons.length; ++i) {
    listBuilder.layer(layerIndex, new LSTM.Builder()
            .dropOut(dropOut)
            .activation(Activation.TANH)
            .updater(rmsProp)
            .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
            .gradientNormalizationThreshold(rnnGradientNormalizationThreshold)
            .nIn(nIn)
            .nOut(rnnNeurons[i])
            .build());
    nIn = rnnNeurons[i];
    ++layerIndex;
    // listBuilder.layer(layerIndex, new BatchNormalization.Builder().nOut(nIn).build());
    // ++layerIndex;
}
If a BatchNormalization layer is added, a data shape mismatch exception is thrown. The dataFormat of BatchNormalization is CNN2DFormat, while Convolution1D and LSTM use RNNFormat.
The Subsampling1DLayer has a similar issue:
listBuilder.layer(layerIndex, new Subsampling1DLayer.Builder()
        .kernelSize(cnnKernelSizes[i])
        .stride(cnnStrides[i])
        .padding(cnnPadding)
        .poolingType(SubsamplingLayer.PoolingType.AVG)
        .build());
++layerIndex;
The exception:
Exception in thread "main" org.nd4j.linalg.exception.ND4JIllegalStateException: New shape length doesn't match original length: [983040] vs [2949120]. Original shape: [256, 64, 60, 3] New Shape: [256, 64, 60]
at org.nd4j.linalg.api.ndarray.BaseNDArray.reshape(BaseNDArray.java:3804)
at org.nd4j.linalg.api.ndarray.BaseNDArray.reshape(BaseNDArray.java:3749)
at org.nd4j.linalg.api.ndarray.BaseNDArray.reshape(BaseNDArray.java:3872)
at org.deeplearning4j.nn.layers.convolution.subsampling.Subsampling1DLayer.activate(Subsampling1DLayer.java:107)
at org.deeplearning4j.nn.layers.AbstractLayer.activate(AbstractLayer.java:262)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.ffToLayerActivationsInWs(MultiLayerNetwork.java:1147)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.computeGradientAndScore(MultiLayerNetwork.java:2798)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.computeGradientAndScore(MultiLayerNetwork.java:2756)
at org.deeplearning4j.optimize.solvers.BaseOptimizer.gradientAndScore(BaseOptimizer.java:174)
at org.deeplearning4j.optimize.solvers.StochasticGradientDescent.optimize(StochasticGradientDescent.java:61)
at org.deeplearning4j.optimize.Solver.optimize(Solver.java:52)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.fitHelper(MultiLayerNetwork.java:1767)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.fit(MultiLayerNetwork.java:1688)
at com.cq.aifocusstocks.train.RnnPredictModel.train(RnnPredictModel.java:176)
at com.cq.aifocusstocks.train.CnnLstmRegPredictor.trainModel(CnnLstmRegPredictor.java:225)
at com.cq.aifocusstocks.train.TrainCnnLstmModel.main(TrainCnnLstmModel.java:15)
The issue with the Subsampling1DLayer has been resolved: replace .padding(cnnPadding) with .convolutionMode(ConvolutionMode.Same), and add listBuilder.setInputType(InputType.recurrent(featuresCount)).
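For clarity, a sketch of the corrected configuration based on that change (variable names follow the code above):

import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.inputs.InputType;

// Same-mode convolution lets DL4J compute the padding from the kernel size and stride,
// keeping the sequence length consistent across the 1D convolution/subsampling layers.
listBuilder.layer(layerIndex, new Subsampling1DLayer.Builder()
        .kernelSize(cnnKernelSizes[i])
        .stride(cnnStrides[i])
        .convolutionMode(ConvolutionMode.Same)
        .poolingType(SubsamplingLayer.PoolingType.AVG)
        .build());
++layerIndex;

// Declare the recurrent input type so DL4J can validate and convert shapes between layers.
listBuilder.setInputType(InputType.recurrent(featuresCount));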
The issue with BatchNormalization still persists, however, and the exception thrown is:
Exception in thread "main" java.lang.IllegalArgumentException: input.size(1) does not match expected input size of 64 - got input array with shape [1, 256, 64, 60]
at org.deeplearning4j.nn.layers.normalization.BatchNormalization.preOutput(BatchNormalization.java:405)
at org.deeplearning4j.nn.layers.normalization.BatchNormalization.activate(BatchNormalization.java:384)
at org.deeplearning4j.nn.layers.AbstractLayer.activate(AbstractLayer.java:262)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.ffToLayerActivationsInWs(MultiLayerNetwork.java:1147)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.computeGradientAndScore(MultiLayerNetwork.java:2798)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.computeGradientAndScore(MultiLayerNetwork.java:2756)
at org.deeplearning4j.optimize.solvers.BaseOptimizer.gradientAndScore(BaseOptimizer.java:174)
at org.deeplearning4j.optimize.solvers.StochasticGradientDescent.optimize(StochasticGradientDescent.java:61)
at org.deeplearning4j.optimize.Solver.optimize(Solver.java:52)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.fitHelper(MultiLayerNetwork.java:1767)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.fit(MultiLayerNetwork.java:1688)
at com.cq.aifocusstocks.train.RnnPredictModel.train(RnnPredictModel.java:176)
at com.cq.aifocusstocks.train.CnnLstmRegPredictor.trainModel(CnnLstmRegPredictor.java:225)
at com.cq.aifocusstocks.train.TrainCnnLstmModel.main(TrainCnnLstmModel.java:15)
Here 256 is the batch size, 64 is the feature count, and 60 is the time step.
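One possible workaround (an untested sketch, not a fix for the underlying format issue) is to wrap the BatchNormalization layer with DL4J's standard preprocessors so it receives 2D feed-forward activations instead of the 3D NCW sequence:

import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor;

// Reshape [miniBatch, features, timeSteps] to [miniBatch * timeSteps, features] before
// BatchNormalization, then reshape back to 3D for the next recurrent/1D-CNN layer.
// Layer indices are illustrative and follow the layerIndex counter used above.
listBuilder.layer(layerIndex, new BatchNormalization.Builder().nOut(nIn).build());
listBuilder.inputPreProcessor(layerIndex, new RnnToFeedForwardPreProcessor());
listBuilder.inputPreProcessor(layerIndex + 1, new FeedForwardToRnnPreProcessor());
++layerIndex;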