Hello,
I don't understand how to get the output of an LSTM network (or, more generally, of any network).
I have a set of CSV files in the following format, and every file has 200 rows. There are 10 “attributes”, and the label can only take the values 0 or 1, so I guess there are 2 labels (classes):
0/1 (label),double,double,double…
0/1 (label),double,double,double…
and so on…
So I split the files arbitrarily: files 0-50 are used for training and 51-60 for testing.
The DataSetIterators are created like this:
SequenceRecordReader trainReader = new CSVSequenceRecordReader(0, ",");
trainReader.initialize(new NumberedFileInputSplit("data/csv/%d.csv", 0, 50));
// minibatch size 100, 2 classes, label in column 0, classification (regression = false)
DataSetIterator trainData = new SequenceRecordReaderDataSetIterator(trainReader, 100, 2, 0, false);

SequenceRecordReader testReader = new CSVSequenceRecordReader(0, ",");
testReader.initialize(new NumberedFileInputSplit("data/csv/%d.csv", 51, 60));
DataSetIterator testData = new SequenceRecordReaderDataSetIterator(testReader, 100, 2, 0, false);
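A plain-Java sketch (no DL4J dependency) of how one CSV row is split with the settings above: as far as I understand SequenceRecordReaderDataSetIterator with labelIndex = 0 and numPossibleLabels = 2, column 0 is removed from the features and turned into a one-hot label, and each row becomes one time step of the sequence stored in that file. The class and method names are hypothetical and only illustrate the layout; the real iterator does this internally.

```java
import java.util.Arrays;

// Hypothetical illustration of the per-row layout: label in column 0, features in the rest.
final class RowLayoutSketch {

    /** Features: every column except the label column (index 0). */
    static double[] features(String row) {
        String[] cols = row.split(",");
        double[] f = new double[cols.length - 1];
        for (int i = 1; i < cols.length; i++) {
            f[i - 1] = Double.parseDouble(cols[i].trim());
        }
        return f;
    }

    /** Label as a one-hot vector of length numClasses, e.g. label 1 with 2 classes -> [0, 1]. */
    static double[] oneHotLabel(String row, int numClasses) {
        int label = Integer.parseInt(row.split(",")[0].trim());
        double[] oneHot = new double[numClasses];
        oneHot[label] = 1.0;
        return oneHot;
    }

    public static void main(String[] args) {
        String row = "1,0.5,0.25,0.75"; // hypothetical row with only 3 attributes for brevity
        System.out.println(Arrays.toString(features(row)));       // [0.5, 0.25, 0.75]
        System.out.println(Arrays.toString(oneHotLabel(row, 2))); // [0.0, 1.0]
    }
}
```

With 10 attributes per row and 200 rows per file, a minibatch of features should therefore have shape [miniBatchSize, 10, 200] and the labels [miniBatchSize, 2, 200].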
The same applies to the test set.
The network is created with:
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(123) //Random number generator seed for improved repeatability. Optional.
.weightInit(WeightInit.XAVIER)
.updater(new Adam(0.005))
.list()
.layer(0, new LSTM.Builder().activation(Activation.TANH).nIn(numFeatures).nOut(HID_LAY).build())
.layer(1, new LSTM.Builder().activation(Activation.TANH).nIn(HID_LAY).nOut(HID_LAY).build())
.layer(2, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(HID_LAY).nOut(2).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); // parameters must be initialized before fit/output
HID_LAY is 200.
Now I train several times with the following (optionally I also use net.evaluate(testData) to check progress, but that doesn't matter at this point):
net.fit(trainData);
trainData.reset();
Then I try to extract the predictions. My idea was to do something like:
INDArray output = net.output(testData);
but at this point I don't understand what information I can extract from output.
How do I get the “next” prediction for testData?
How do I get the prediction for testData at position N? (Suppose again that testData is loaded from a set of CSV files (9x) with 200 rows each: how can I get the prediction for index YYYY?)
Thanks for your help, and I hope my English is clear enough.
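A plain-Java sketch of how the output of an RnnOutputLayer can be read. In DL4J it is a 3D array of shape [miniBatchSize, nOut, timeSeriesLength], here [sequences, 2, 200], holding one softmax probability per class at every time step. The sketch uses double[][][] as a stand-in for that INDArray (an assumption made only to keep it dependency-free); with a real INDArray the same value should be readable with output.getDouble(example, cls, t). The class and method names are hypothetical.

```java
// Hypothetical stand-in for the RNN output: output[example][class][timeStep].
final class RnnOutputSketch {

    /** Softmax probability of class cls for sequence (file) `example` at time step t. */
    static double probability(double[][][] output, int example, int cls, int t) {
        return output[example][cls][t];
    }

    /** Predicted class at time step t: the class with the largest probability (argmax over dim 1). */
    static int predictedClass(double[][][] output, int example, int t) {
        int best = 0;
        for (int c = 1; c < output[example].length; c++) {
            if (output[example][c][t] > output[example][best][t]) {
                best = c;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Hypothetical output: 1 sequence, 2 classes, 3 time steps (each column sums to 1).
        double[][][] output = {
            { {0.9, 0.3, 0.4},    // P(label = 0) at t = 0, 1, 2
              {0.1, 0.7, 0.6} }   // P(label = 1) at t = 0, 1, 2
        };
        for (int t = 0; t < 3; t++) {
            System.out.println("t=" + t + " -> class " + predictedClass(output, 0, t)
                    + ", P(1)=" + probability(output, 0, 1, t));
        }
    }
}
```

So every row of every test file gets its own pair of class probabilities, and the predicted label for that row is the class with the larger one.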
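A sketch of mapping a global row index YYYY (counting rows across the test files in order, starting at 0 for the first row of file 51) to a position in the output. It assumes that every file has exactly 200 rows and that the iterator reads the files in index order without shuffling; both are assumptions, not something confirmed here. With a minibatch size of 100, all 10 test files should fit in a single batch. The helper name is hypothetical.

```java
// Hypothetical helper: global row index -> {file number, batch, example in batch, time step}.
final class RowIndexMapper {

    static int[] locate(int globalRow, int firstFile, int rowsPerFile, int miniBatchSize) {
        int sequence = globalRow / rowsPerFile;         // which test file, 0-based within the split
        int timeStep = globalRow % rowsPerFile;         // row inside that file = time step
        int batch = sequence / miniBatchSize;           // which DataSet the iterator returns
        int exampleInBatch = sequence % miniBatchSize;  // first dimension of that batch's output
        return new int[] {firstFile + sequence, batch, exampleInBatch, timeStep};
    }

    public static void main(String[] args) {
        // Values from the question: test files 51..60, 200 rows each, minibatch size 100.
        int[] pos = locate(1234, 51, 200, 100);
        System.out.printf("file %d.csv, batch %d, example %d, time step %d%n",
                pos[0], pos[1], pos[2], pos[3]);
    }
}
```

For YYYY = 1234 this gives file 57, example 6 and time step 34, so the class probabilities are the values at output[6][*][34]. Note that with this CSV layout the label at each row belongs to that same row, so the output at step t predicts the label of row t, not of row t+1; a “next” prediction would need labels shifted by one row in the data. For feeding one step at a time, MultiLayerNetwork also provides rnnTimeStep(INDArray), which keeps the LSTM state between calls (rnnClearPreviousState() resets it).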