Hello, friends!
I'm trying to develop an RNN based on LSTM layers that will predict the stock market, but so far I have failed.
The problem is that my network's output just reproduces the input sequence.
You may suggest normalization, but I don't know the maximum price in the future, and that's why I use the RELU activation function.
Please tell me what the problem could be.
Model:
/**
 * Builds and initializes a network of two stacked LSTM layers followed by an
 * RNN output layer trained with mean squared error.
 *
 * <p>The LSTM gates use a sigmoid activation. Gates act as soft switches that
 * scale the cell state, so their activation must stay within [0, 1]; an
 * unbounded function such as RELU lets the gates amplify values without limit
 * and defeats the gating mechanism. The range of the predicted price is not
 * restricted by the gates: the output layer uses the identity activation, so
 * the network can still emit arbitrarily large values.
 *
 * @param inputNum  number of input features per time step
 * @param outputNum number of output values per time step
 * @return the initialized network
 */
private MultiLayerNetwork getModel(int inputNum, int outputNum) {
    int lstmLayer1Size = 200;
    int lstmLayer2Size = 200;

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .trainingWorkspaceMode(WorkspaceMode.ENABLED).inferenceWorkspaceMode(WorkspaceMode.ENABLED)
            .seed(123)
            .updater(new Adam(0.005))
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .weightInit(WeightInit.XAVIER)
            .list()
            // Gate activations must be bounded to [0, 1]; see the method comment.
            .layer(0, new LSTM.Builder().nIn(inputNum).nOut(lstmLayer1Size).name("lstm1")
                    .gateActivationFunction(Activation.SIGMOID).build())
            .layer(1, new LSTM.Builder().nIn(lstmLayer1Size).nOut(lstmLayer2Size).name("lstm2")
                    .gateActivationFunction(Activation.SIGMOID).build())
            // Identity output keeps the regression target unbounded.
            .layer(2, new RnnOutputLayer.Builder().nIn(lstmLayer2Size).nOut(outputNum).name("output")
                    .activation(Activation.IDENTITY).lossFunction(LossFunctions.LossFunction.MSE).build())
            .build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    return net;
}
Training iterator:
/**
 * Sliding-window iterator over candle close prices.
 *
 * <p>Each example is a pair of consecutive windows of {@code lengthSequence}
 * close prices: the features are the window starting at the current position
 * and the labels are the window that immediately follows it. After every
 * example the position advances by {@code shift} candles. Both arrays have the
 * shape {@code [examples, 1, lengthSequence]}.
 *
 * <p>Note: the constructor sorts the supplied list in place and keeps a
 * reference to it, so the caller's list is reordered.
 */
class CandleDataSetIterator implements DataSetIterator {
    private final List<CandleDto> candles;
    private final int lengthSequence;
    private final int shift;
    private final int batchSize;
    private DataSetPreProcessor preProcessor;
    private int position = 0;

    /**
     * @param candles        candles to iterate over; sorted in place by their natural order
     * @param lengthSequence number of time steps in the feature window and in the label window
     * @param shift          number of candles the window start advances after each example
     * @param batchSize      number of examples returned by {@link #next()}
     * @throws IllegalArgumentException if {@code lengthSequence}, {@code shift} or
     *                                  {@code batchSize} is not positive
     */
    public CandleDataSetIterator(List<CandleDto> candles, int lengthSequence, int shift, int batchSize) {
        if (lengthSequence <= 0 || shift <= 0 || batchSize <= 0) {
            // A non-positive shift would make hasNext() return true forever.
            throw new IllegalArgumentException("lengthSequence, shift and batchSize must be positive, got "
                    + lengthSequence + ", " + shift + ", " + batchSize);
        }
        this.lengthSequence = lengthSequence;
        this.shift = shift;
        Collections.sort(candles);
        this.candles = candles;
        this.batchSize = batchSize;
    }

    /**
     * Returns up to {@code num} examples. The batch is sized to the number of
     * examples actually available, so it never contains unfilled rows.
     *
     * @param num maximum number of examples to return
     * @return a data set with at most {@code num} examples
     * @throws IllegalArgumentException         if {@code num} is not positive
     * @throws java.util.NoSuchElementException if no complete example is left
     */
    @Override
    public DataSet next(int num) {
        if (num <= 0) {
            throw new IllegalArgumentException("num must be positive, got " + num);
        }
        int count = Math.min(num, remainingExamples());
        if (count == 0) {
            throw new java.util.NoSuchElementException("No complete feature/label window left");
        }

        INDArray features = Nd4j.create(count, 1, lengthSequence);
        INDArray labels = Nd4j.create(count, 1, lengthSequence);
        for (int i = 0; i < count; i++) {
            for (int j = 0; j < lengthSequence; j++) {
                features.putScalar(new int[]{i, 0, j}, candles.get(position + j).getClose());
                labels.putScalar(new int[]{i, 0, j}, candles.get(position + lengthSequence + j).getClose());
            }
            position += shift;
        }

        DataSet dataSet = new DataSet(features, labels);
        if (preProcessor != null) {
            preProcessor.preProcess(dataSet);
        }
        return dataSet;
    }

    @Override
    public int inputColumns() {
        return 1;
    }

    @Override
    public int totalOutcomes() {
        return 1;
    }

    @Override
    public boolean resetSupported() {
        return true;
    }

    @Override
    public boolean asyncSupported() {
        return false;
    }

    /** Rewinds the iterator to the first window. */
    @Override
    public void reset() {
        position = 0;
    }

    @Override
    public int batch() {
        return batchSize;
    }

    /** Stores the pre-processor; it is applied to every batch produced afterwards. */
    @Override
    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    /** @return the pre-processor set earlier, or {@code null} if none was set */
    @Override
    public DataSetPreProcessor getPreProcessor() {
        return preProcessor;
    }

    @Override
    public List<String> getLabels() {
        return null;
    }

    /** @return {@code true} while a full feature window plus a full label window fits from the current position */
    @Override
    public boolean hasNext() {
        return (position + lengthSequence * 2) <= candles.size();
    }

    /**
     * Returns the next batch of up to {@code batchSize} examples, consistent
     * with the size reported by {@link #batch()}.
     *
     * @throws java.util.NoSuchElementException if no complete example is left
     */
    @Override
    public DataSet next() {
        return next(batchSize);
    }

    /** Counts the complete examples that can still be produced from the current position. */
    private int remainingExamples() {
        int lastStart = candles.size() - lengthSequence * 2;
        if (position > lastStart) {
            return 0;
        }
        return (lastStart - position) / shift + 1;
    }
}
Prediction attempt (faulty result):

Predict method:
/**
 * Loads the saved model, predicts one window of close prices and plots the
 * prediction against the real prices.
 *
 * <p>The evaluation data used is the {@code LENGTH_SEQUENCE * 3} candles that
 * precede the last candle (the last candle itself is excluded by the sub-list
 * bounds). The first window primes the recurrent state, the second window is
 * the actual input, and the prediction is drawn over the third window, i.e.
 * starting at x = {@code LENGTH_SEQUENCE * 2}.
 *
 * @throws IndexOutOfBoundsException if fewer than {@code LENGTH_SEQUENCE * 3 + 1}
 *                                   evaluation candles are available
 * @throws RuntimeException          wrapping any I/O or plotting failure
 */
public void predict() {
    int toIndex = evalCandles.size() - 1;
    int fromIndex = toIndex - LENGTH_SEQUENCE * 3;
    if (fromIndex < 0) {
        // Same exception type subList would raise, but with an actionable message.
        throw new IndexOutOfBoundsException("Need at least " + (LENGTH_SEQUENCE * 3 + 1)
                + " evaluation candles, got " + evalCandles.size());
    }

    try {
        List<CandleDto> testCandles = evalCandles.subList(fromIndex, toIndex);
        INDArray warmUp = Nd4j.create(1, 1, LENGTH_SEQUENCE);
        INDArray in = Nd4j.create(1, 1, LENGTH_SEQUENCE);
        for (int i = 0; i < LENGTH_SEQUENCE; i++) {
            warmUp.putScalar(new int[]{0, 0, i}, testCandles.get(i).getClose());
            in.putScalar(new int[]{0, 0, i}, testCandles.get(LENGTH_SEQUENCE + i).getClose());
        }

        MultiLayerNetwork model = MultiLayerNetwork.load(new File(MODEL_PATH_FILE), false);
        try {
            // Start from a clean recurrent state, prime it with the warm-up
            // window, then feed the real input window.
            model.rnnClearPreviousState();
            model.rnnTimeStep(warmUp);
            INDArray predicted = model.rnnTimeStep(in);
            List<Double> p = getListFromINDArray(predicted);

            List<Integer> realX = IntStream.range(0, testCandles.size()).boxed().collect(Collectors.toList());
            List<Double> realY = testCandles.stream().mapToDouble(CandleDto::getClose).boxed().collect(Collectors.toList());
            List<Integer> predictedX = IntStream.range(LENGTH_SEQUENCE * 2, LENGTH_SEQUENCE * 2 + p.size())
                    .boxed().collect(Collectors.toList());

            Plot plot = Plot.create();
            plot.plot()
                    .add(realX, realY)
                    .label("real");
            plot.plot()
                    .add(predictedX, p)
                    .label("predicted");
            plot.legend();
            plot.show();
        } finally {
            // Release the model even when inference or plotting fails.
            model.close();
        }
    } catch (IOException | PythonExecutionException e) {
        throw new RuntimeException(e);
    }
}