Hello, friends!
Im trying develop rnn network based on LSTM, which will predict stock market but i
m failed(
The problem is that my network on output shows input sequence.
You can tell me about normalization, but i dont know about maximum price in future and that
s why i use RELU activation function.
Please tell me what could be the problem?
Model:
private MultiLayerNetwork getModel(int inputNum, int outputNum) {
int lstmLayer1Size = 200;
int lstmLayer2Size = 200;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.trainingWorkspaceMode(WorkspaceMode.ENABLED).inferenceWorkspaceMode(WorkspaceMode.ENABLED)
.seed(123)
.updater(new Adam(0.005))
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.weightInit(WeightInit.XAVIER)
.list()
.layer(0, new LSTM.Builder().nIn(inputNum).nOut(lstmLayer1Size).name("lstm1").gateActivationFunction(Activation.RELU).build())
.layer(1, new LSTM.Builder().nIn(lstmLayer1Size).nOut(lstmLayer2Size).name("lstm2").gateActivationFunction(Activation.RELU).build())
.layer(2, new RnnOutputLayer.Builder().nIn(lstmLayer2Size).nOut(outputNum).name("output").activation(Activation.IDENTITY).lossFunction(LossFunctions.LossFunction.MSE).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
return net;
}
Train iterator:
class CandleDataSetIterator implements DataSetIterator {
private final List<CandleDto> candles;
private final int lengthSequence;
private final int shift;
private final int batchSize;
private int position = 0;
public CandleDataSetIterator(List<CandleDto> candles, int lengthSequence, int shift, int batchSize) {
this.lengthSequence = lengthSequence;
this.shift = shift;
Collections.sort(candles);
this.candles = candles;
this.batchSize = batchSize;
}
@Override
public DataSet next(int num) {
INDArray features = Nd4j.create(num - 1, 1, lengthSequence);
INDArray labels = Nd4j.create(num-1, 1, lengthSequence);
for (int i = 0; i < num && hasNext(); i++) {
for (int j = 0; j < lengthSequence; j++) {
features.putScalar(new int[]{i, 0, j}, candles.get(position + j).getClose());
labels.putScalar(new int[]{i, 0, j}, candles.get(position + lengthSequence + j).getClose());
}
position += shift;
}
return new DataSet(features, labels);
}
@Override
public int inputColumns() {
return 1;
}
@Override
public int totalOutcomes() {
return 1;
}
@Override
public boolean resetSupported() {
return true;
}
@Override
public boolean asyncSupported() {
return false;
}
@Override
public void reset() {
position = 0;
}
@Override
public int batch() {
return batchSize;
}
@Override
public void setPreProcessor(DataSetPreProcessor preProcessor) {
}
@Override
public DataSetPreProcessor getPreProcessor() {
throw new UnsupportedOperationException("Not support the function");
}
@Override
public List<String> getLabels() {
return null;
}
@Override
public boolean hasNext() {
return (position + lengthSequence*2) <= candles.size();
}
@Override
public DataSet next() {
INDArray features = Nd4j.create(1, 1, lengthSequence);
INDArray labels = Nd4j.create(1, 1, lengthSequence);
if (hasNext()) {
for (int i = 0; i < lengthSequence; i++) {
features.putScalar(new int[]{0, 0, i}, candles.get(position + i).getClose());
labels.putScalar(new int[]{0, 0, i}, candles.get(position + lengthSequence + i).getClose());
}
}
position += shift;
return new DataSet(features, labels);
}
}
Trying predict (fault):
Predict method:
public void predict() {
try {
List<CandleDto> testCandles = evalCandles.subList(evalCandles.size() - 1 - LENGTH_SEQUENCE*3, evalCandles.size() - 1);
INDArray warmUp = Nd4j.create(1, 1, LENGTH_SEQUENCE);
INDArray in = Nd4j.create(1, 1, LENGTH_SEQUENCE);
for (int i = 0; i < LENGTH_SEQUENCE; i++) {
warmUp.putScalar(new int[]{0, 0, i}, testCandles.get(i).getClose());
in.putScalar(new int[]{0, 0, i}, testCandles.get(LENGTH_SEQUENCE + i).getClose());
}
MultiLayerNetwork model = MultiLayerNetwork.load(new File(MODEL_PATH_FILE), false);
model.rnnClearPreviousState();
model.rnnTimeStep(warmUp);
INDArray predicted = model.rnnTimeStep(in);
List<Double> p = getListFromINDArray(predicted);
Plot plot = Plot.create();
plot.plot()
.add(IntStream.range(0, testCandles.size()).boxed().collect(Collectors.toList()), testCandles.stream().mapToDouble(CandleDto::getClose).boxed().collect(Collectors.toList()))
.label("real");
plot.plot()
.add(IntStream.range(LENGTH_SEQUENCE*2, LENGTH_SEQUENCE*2 + p.size()).boxed().collect(Collectors.toList()), p)
.label("predicted");
plot.legend();
plot.show();
model.close();
} catch (IOException | PythonExecutionException e) {
throw new RuntimeException(e);
}
}