The biggest hurdle I am encountering in implementing a transformer/attention model is configuring the dual inputs.
I thought it might be helpful to start by implementing the model described in a Scala example I found. Unfortunately, I am not at all familiar with Scala, so I tried to translate/adapt the code to Java.
Below, you will find the latest iteration of my attempt.
It is not working. Execution fails with:
Exception in thread "main" java.lang.IllegalArgumentException: Invalid output array: network has 1 outputs, but array is of length 0
at org.deeplearning4j.nn.graph.ComputationGraph.setLabels(ComputationGraph.java:417)
at org.deeplearning4j.nn.graph.ComputationGraph.fitHelper(ComputationGraph.java:1144)
at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1127)
at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1094)
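Reading the stack trace, setLabels appears to validate that the number of label arrays matches the number of configured outputs. Below is my own self-contained paraphrase of what I think that check does (my guess from the error message, not the actual DL4J source); if that reading is right, my iterator is handing the network zero label arrays.

```java
public class SetLabelsCheck {
    // Hypothetical paraphrase of the validation I believe ComputationGraph.setLabels
    // performs: the number of label arrays supplied must equal the number of
    // configured outputs, otherwise it throws IllegalArgumentException.
    static void setLabels(int numOutputs, double[][]... labels) {
        if (labels.length != numOutputs) {
            throw new IllegalArgumentException("Invalid output array: network has "
                    + numOutputs + " outputs, but array is of length " + labels.length);
        }
    }

    public static void main(String[] args) {
        try {
            // Passing no label arrays at all, with one configured output,
            // reproduces the message I am seeing:
            setLabels(1);
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}
```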
In the code below, the failure occurs here:
net2.fit(trainData2);
What am I doing wrong? My guess is that the issue has to do with the way my input data is structured, but I don't understand which output array, and which outputs, the error message is referring to.
Any thoughts/ideas as to what is going on would be very helpful.
Thanks.
====================================================================
int miniBatchSize = 48;
// ----- Load the training data -----
SequenceRecordReader trainFeatures = new CSVSequenceRecordReader();
trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath() + "/%d.csv", 0, lastTrainCount));
SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
trainLabels.initialize(new NumberedFileInputSplit(labelsDirTrain.getAbsolutePath() + "/%d.csv", 0, lastTrainCount));
MultiDataSetIterator trainData2 = new RecordReaderMultiDataSetIterator.Builder(miniBatchSize)
        .addSequenceReader("trainFeatures", trainFeatures)
        .addSequenceReader("trainLabels", trainLabels)
        .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END)
        .addInput("trainFeatures")
        .addInput("trainLabels")
        .build();
// ----- Load the test data -----
//Same process as for the training data.
SequenceRecordReader testFeatures = new CSVSequenceRecordReader();
testFeatures.initialize(new NumberedFileInputSplit(featuresDirTest.getAbsolutePath() + "/%d.csv", 0, lastTestCount));
SequenceRecordReader testLabels = new CSVSequenceRecordReader();
testLabels.initialize(new NumberedFileInputSplit(labelsDirTest.getAbsolutePath() + "/%d.csv", 0, lastTestCount));
MultiDataSetIterator testData2 = new RecordReaderMultiDataSetIterator.Builder(miniBatchSize)
        .addSequenceReader("testFeatures", testFeatures)
        .addSequenceReader("testLabels", testLabels)
        .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END)
        .addInput("testFeatures")
        .addInput("testLabels")
        .build();
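For what it's worth, my mental model of ALIGN_END is that shorter sequences are aligned so their last time steps coincide, i.e. they get left-padded (and would be masked) at the start. A hand-rolled illustration of that idea (my own sketch, not DL4J code):

```java
public class AlignEndSketch {
    // Align a sequence of length <= maxLen at its END: the shorter sequence is
    // left-padded with zeros (which would be masked out) so final steps line up.
    static int[] alignEnd(int[] seq, int maxLen) {
        int[] out = new int[maxLen];          // zeros act as padding
        int offset = maxLen - seq.length;     // amount of left-padding
        for (int i = 0; i < seq.length; i++) {
            out[offset + i] = seq[i];
        }
        return out;
    }

    public static void main(String[] args) {
        int[] labels = {7, 8};                // shorter label sequence
        int[] aligned = alignEnd(labels, 5);
        System.out.println(java.util.Arrays.toString(aligned)); // [0, 0, 0, 7, 8]
    }
}
```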
log.info("Printing train data features:");
MultiDataSet data = trainData2.next();
System.out.println(java.util.Arrays.toString(data.getFeatures()));
log.info("Printing test data features:");
MultiDataSet data2 = testData2.next();
System.out.println(java.util.Arrays.toString(data2.getFeatures()));
INDArray predicted2 = null;
//NETWORK CONFIGURATION SET-UP AND NETWORK INIT - START - [=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=]
MultiNormalizerStandardize normalizer2 = new MultiNormalizerStandardize();
normalizer2.fitLabel(true);
normalizer2.fit(trainData2); //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
trainData2.reset();
while (trainData2.hasNext()) {
    normalizer2.transform(trainData2.next()); //Apply normalization to the training data
}
while (testData2.hasNext()) {
    normalizer2.transform(testData2.next()); //Apply normalization to the test data, using statistics calculated from the *training* set
}
trainData2.setPreProcessor(normalizer2);
testData2.setPreProcessor(normalizer2);
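To document my understanding of what the standardizing normalizer does (a plain-Java paraphrase of z-score standardization, not the DL4J implementation): the mean and standard deviation come from the training data only, and the same statistics are then applied to both train and test values.

```java
public class StandardizeSketch {
    // z-score standardization: stats are fit on the TRAINING values only,
    // then (x - mean) / std is applied to whatever data we transform.
    static double[] fitApply(double[] train, double[] apply) {
        double mean = 0;
        for (double v : train) mean += v;
        mean /= train.length;
        double var = 0;
        for (double v : train) var += (v - mean) * (v - mean);
        double std = Math.sqrt(var / train.length);
        double[] out = new double[apply.length];
        for (int i = 0; i < apply.length; i++) {
            out[i] = (apply[i] - mean) / std;
        }
        return out;
    }

    public static void main(String[] args) {
        double[] train = {2, 4, 6, 8};    // mean 5, std sqrt(5)
        double[] test = {5, 10};          // transformed with the TRAIN stats
        double[] z = fitApply(train, test);
        System.out.printf("%.4f %.4f%n", z[0], z[1]);
    }
}
```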
ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder()
        .updater(new Nesterovs(0.001, 0.9))
        .seed(12345)
        .l2(0.001)
        .weightInit(WeightInit.XAVIER)
        .inferenceWorkspaceMode(WorkspaceMode.SINGLE)
        .trainingWorkspaceMode(WorkspaceMode.SINGLE)
        .graphBuilder()
        .addInputs("encoderInput", "decoderInput")
        .setInputTypes(InputType.recurrent(2), InputType.recurrent(3))
        .addLayer("encoder", new LSTM.Builder().nIn(6).nOut(96).activation(Activation.TANH).build(), "encoderInput")
        .addLayer("encoder2", new LSTM.Builder().nIn(6).nOut(48).activation(Activation.TANH).build(), "encoder")
        .addVertex("laststep", new LastTimeStepVertex("encoderInput"), "encoder2")
        .addVertex("dup", new DuplicateToTimeSeriesVertex("decoderInput"), "laststep")
        .addLayer("decoder", new LSTM.Builder().nIn(51).nOut(48).activation(Activation.TANH).build(), "decoderInput", "dup")
        .addLayer("decoder2", new LSTM.Builder().nIn(48).nOut(96).activation(Activation.TANH).build(), "decoder")
        .addLayer("output", new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).activation(Activation.SIGMOID).nIn(96).nOut(2).build(), "decoder2")
        .setOutputs("output")
        .build();
ComputationGraph net2 = new ComputationGraph(conf2);
net2.setListeners(new ScoreIterationListener(1));
net2.init();
// //NETWORK CONFIGURATION SET-UP AND NETWORK INIT - END - [=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=]
//ADD VISUALIZATION CODE HERE - START - <><><><><><><><><><><><><><><><><><><><><><><><>
//Initialize the user interface backend
UIServer uiServer = UIServer.getInstance();
//Configure where the network information (gradients, score vs. time etc) is to be stored. Here: store in memory.
StatsStorage statsStorage = new InMemoryStatsStorage(); //Alternative: new FileStatsStorage(File), for saving and loading later
//Attach the StatsStorage instance to the UI: this allows the contents of the StatsStorage to be visualized
uiServer.attach(statsStorage);
//Then add the StatsListener to collect this information from the network, as it trains
int listenerFrequency = 1;
net2.setListeners(new StatsListener(statsStorage, listenerFrequency));
//ADD VISUALIZATION CODE HERE - END - <*><*><*><*><*><*><*><*><*><*><*><*><*><*><*><*><*><*><*><*><*><*><*><*>
// ----- Train the network, evaluating the test set performance at each epoch -----
int nEpochs = 50;
log.info(" - net2.toStringFull() - "+net2.summary());
for (int i = 0; i < nEpochs; i++) {
    net2.fit(trainData2);
    trainData2.reset();
    log.info("Epoch " + i + " complete. Time series evaluation:");
    //Run regression evaluation on our single column input
    RegressionEvaluation evaluation = new RegressionEvaluation(2);
    testData2.reset();
}
String pathToSavedNetwork = "src/main/assets/location_next_neural_network_v6_07.zip";
File savedNetwork = new File(pathToSavedNetwork);