I implemented an example using an LSTM and a ComputationGraph, but an exception was thrown. My Deeplearning4j version is 1.0.0-M2.1.
Can anyone help me? Thank you very much.
Exception in thread "main" java.lang.IllegalStateException: Sequence lengths do not match for RnnOutputLayer input and labels:Arrays should be rank 3 with shape [minibatch, size, sequenceLength] - mismatch on dimension 2 (sequence length) - input=[32, 15, 20] vs. label=[32, 5, 20]
at org.nd4j.common.base.Preconditions.throwStateEx(Preconditions.java:639)
at org.nd4j.common.base.Preconditions.checkState(Preconditions.java:337)
at org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer.backpropGradient(RnnOutputLayer.java:59)
at org.deeplearning4j.nn.graph.vertex.impl.LayerVertex.doBackward(LayerVertex.java:148)
at org.deeplearning4j.nn.graph.ComputationGraph.calcBackpropGradients(ComputationGraph.java:2784)
at org.deeplearning4j.nn.graph.ComputationGraph.computeGradientAndScore(ComputationGraph.java:1393)
at org.deeplearning4j.nn.graph.ComputationGraph.computeGradientAndScore(ComputationGraph.java:1353)
at org.deeplearning4j.optimize.solvers.BaseOptimizer.gradientAndScore(BaseOptimizer.java:174)
at org.deeplearning4j.optimize.solvers.StochasticGradientDescent.optimize(StochasticGradientDescent.java:61)
at org.deeplearning4j.optimize.Solver.optimize(Solver.java:52)
at org.deeplearning4j.nn.graph.ComputationGraph.fitHelper(ComputationGraph.java:1177)
at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1127)
at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1094)
at com.ai.lstm.MultiInputLSTMComputationGraphExample.main(MultiInputLSTMComputationGraphExample.java:63)
The code is shown below:
package com.ai.lstm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
public class MultiInputLSTMComputationGraphExample {
public static void main(String args) {
// 定义计算图配置
ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder()
.weightInit(WeightInit.XAVIER)
.updater(new Adam(0.01))
.graphBuilder()
.addInputs(“input1”, “input2”) // 定义两个输入
.addLayer(“lstm1”, new LSTM.Builder()
.nIn(10) // 输入1的维度
.nOut(20) // LSTM输出维度
.activation(Activation.TANH)
.build(), “input1”)
.addLayer(“lstm2”, new LSTM.Builder()
.nIn(5) // 输入2的维度
.nOut(10) // LSTM输出维度
.activation(Activation.TANH)
.build(), “input2”)
.addLayer(“merge”, new LSTM.Builder()
.nIn(30) // 合并后的输入维度 (20 + 10)
.nOut(15) // 合并后的LSTM输出维度
.activation(Activation.TANH)
// .setInputTypes(InputType.recurrent(30,20,RNNFormat.NCW))
.build(), “lstm1”, “lstm2”)
.addLayer(“output”, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX)
.nIn(15) // 输入维度
.nOut(5) // 输出类别数
.build(), “merge”)
.setOutputs(“output”) // 定义输出
// .setInputTypes(InputType.recurrent(30,20,RNNFormat.NCW))
.build();
// 创建计算图模型
ComputationGraph model = new ComputationGraph(config);
model.init();
// 创建虚拟数据集
MultiDataSetIterator iterator = createDummyData();
// 训练模型
for (int i = 0; i < 10; i++) {
model.fit(iterator);
}
}
private static MultiDataSetIterator createDummyData() {
return new MultiDataSetIterator() {
@Override
public MultiDataSet next(int num) {
// 创建虚拟输入数据
// 输入1的形状: [batchSize, inputSize, timeSeriesLength]
INDArray input1 = Nd4j.rand(new int{num, 10, 20});
// 输入2的形状: [batchSize, inputSize, timeSeriesLength]
INDArray input2 = Nd4j.rand(new int{num, 5, 20});
// 创建虚拟标签数据
INDArray labels = Nd4j.zeros(num, 5, 20); // 输出形状: [batchSize, nOut, timeSeriesLength]
for (int i = 0; i < num; i++) {
for (int j = 0; j < 20; j++) {
labels.putScalar(new int{i, j % 5, j}, 1.0); // 随机分配标签
}
}
return new org.nd4j.linalg.dataset.MultiDataSet(
new INDArray{input1, input2},
new INDArray{labels});
}
@Override
public void reset() {
// 重置迭代器
}
@Override
public boolean hasNext() {
return true;
}
@Override
public MultiDataSet next() {
return next(32); // 默认批量大小为32
}
@Override
public void setPreProcessor(MultiDataSetPreProcessor preProcessor) {
// TODO Auto-generated method stub
}
@Override
public MultiDataSetPreProcessor getPreProcessor() {
// TODO Auto-generated method stub
return null;
}
@Override
public boolean resetSupported() {
// TODO Auto-generated method stub
return false;
}
@Override
public boolean asyncSupported() {
// TODO Auto-generated method stub
return false;
}
};
}
}