LSTM sequence prediction fails

Hello, friends!

I'm trying to develop an RNN based on LSTM that predicts the stock market, but it isn't working.
The problem is that my network's output simply reproduces the input sequence.

You might suggest normalization, but I don't know the maximum future price, which is why I use the ReLU activation function.
Could you please tell me what the problem might be?

Model:

    /**
     * Builds a two-layer LSTM network for per-time-step regression.
     *
     * <p>The LSTM gates use SIGMOID activation. Gates act as multiplicative masks and must stay
     * in [0, 1]. The previous RELU gate activation was unbounded, which lets the cell state grow
     * without limit. Unbounded target values (e.g. prices with no known maximum) do not require
     * ReLU anywhere: the output layer already uses IDENTITY activation with MSE loss. Scaling
     * inputs and labels, and inverting the scaling on predictions, is the usual way to keep
     * training well conditioned.
     *
     * @param inputNum  number of input features per time step
     * @param outputNum number of regression outputs per time step
     * @return an initialized network
     */
    private MultiLayerNetwork getModel(int inputNum, int outputNum) {
        int lstmLayer1Size = 200;
        int lstmLayer2Size = 200;
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .trainingWorkspaceMode(WorkspaceMode.ENABLED).inferenceWorkspaceMode(WorkspaceMode.ENABLED)
                .seed(123)
                .updater(new Adam(0.005))
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .weightInit(WeightInit.XAVIER)
                .list()
                // Gates must be bounded to [0, 1]; tanh keeps the cell output bounded too.
                .layer(0, new LSTM.Builder().nIn(inputNum).nOut(lstmLayer1Size).name("lstm1")
                        .activation(Activation.TANH)
                        .gateActivationFunction(Activation.SIGMOID)
                        .build())
                .layer(1, new LSTM.Builder().nIn(lstmLayer1Size).nOut(lstmLayer2Size).name("lstm2")
                        .activation(Activation.TANH)
                        .gateActivationFunction(Activation.SIGMOID)
                        .build())
                // IDENTITY + MSE: unbounded real-valued regression output.
                .layer(2, new RnnOutputLayer.Builder().nIn(lstmLayer2Size).nOut(outputNum).name("output")
                        .activation(Activation.IDENTITY)
                        .lossFunction(LossFunctions.LossFunction.MSE)
                        .build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        return net;
    }

Training iterator:

/**
 * Iterates over a candle series as (input window, future window) pairs for sequence regression
 * on close prices.
 *
 * <p>Each example has features of shape {@code [1, lengthSequence]} holding the closes of candles
 * {@code position .. position + lengthSequence - 1}. Its labels have the same shape and hold the
 * closes {@code lengthSequence} candles later, so label step {@code j} is the close
 * {@code lengthSequence} steps after feature step {@code j}. Consecutive examples start
 * {@code shift} candles apart.
 *
 * <p>Raw prices are emitted. Attach a normalizer via {@link #setPreProcessor} to scale features
 * and labels; it is applied to every returned {@link DataSet}.
 */
class CandleDataSetIterator implements DataSetIterator {
    private final List<CandleDto> candles;

    private final int lengthSequence;

    private final int shift;

    private final int batchSize;

    private int position = 0;

    private DataSetPreProcessor preProcessor;

    /**
     * @param candles        candle series; a sorted copy is taken, so the caller's list is not modified
     * @param lengthSequence number of time steps per feature window (and per label window)
     * @param shift          offset in candles between the starts of consecutive examples
     * @param batchSize      number of examples returned by {@link #next()}
     * @throws IllegalArgumentException if any size argument is not positive
     */
    public CandleDataSetIterator(List<CandleDto> candles, int lengthSequence, int shift, int batchSize) {
        if (lengthSequence <= 0) {
            throw new IllegalArgumentException("lengthSequence must be positive: " + lengthSequence);
        }
        if (shift <= 0) {
            // A non-positive shift would never advance the position, so iteration would not end.
            throw new IllegalArgumentException("shift must be positive: " + shift);
        }
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive: " + batchSize);
        }
        // Copy before sorting so the caller's (possibly unmodifiable) list is left untouched.
        List<CandleDto> sorted = new java.util.ArrayList<>(candles);
        Collections.sort(sorted);
        this.candles = sorted;
        this.lengthSequence = lengthSequence;
        this.shift = shift;
        this.batchSize = batchSize;
    }

    /**
     * Returns up to {@code num} examples.
     *
     * <p>Fewer examples are returned if the series runs out first. No zero-padded rows are
     * produced.
     *
     * @throws IllegalArgumentException if {@code num} is not positive
     * @throws java.util.NoSuchElementException if no complete window remains
     */
    @Override
    public DataSet next(int num) {
        if (num <= 0) {
            throw new IllegalArgumentException("num must be positive: " + num);
        }
        if (!hasNext()) {
            throw new java.util.NoSuchElementException("No more candle windows available");
        }
        int examples = Math.min(num, remainingExamples());
        INDArray features = Nd4j.create(examples, 1, lengthSequence);
        INDArray labels = Nd4j.create(examples, 1, lengthSequence);

        for (int i = 0; i < examples; i++) {
            for (int j = 0; j < lengthSequence; j++) {
                features.putScalar(new int[]{i, 0, j}, candles.get(position + j).getClose());
                labels.putScalar(new int[]{i, 0, j}, candles.get(position + lengthSequence + j).getClose());
            }
            position += shift;
        }

        DataSet dataSet = new DataSet(features, labels);
        if (preProcessor != null) {
            preProcessor.preProcess(dataSet);
        }
        return dataSet;
    }

    /** Counts the complete (feature, label) windows that can still be produced from {@code position}. */
    private int remainingExamples() {
        int slack = candles.size() - 2 * lengthSequence - position;
        return slack < 0 ? 0 : slack / shift + 1;
    }

    @Override
    public int inputColumns() {
        return 1;
    }

    @Override
    public int totalOutcomes() {
        return 1;
    }

    @Override
    public boolean resetSupported() {
        return true;
    }

    @Override
    public boolean asyncSupported() {
        return false;
    }

    @Override
    public void reset() {
        position = 0;
    }

    @Override
    public int batch() {
        return batchSize;
    }

    /** Stores a preprocessor (e.g. a fitted normalizer) that is applied to every returned DataSet. */
    @Override
    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    /** Returns the preprocessor set via {@link #setPreProcessor}, or {@code null} if none. */
    @Override
    public DataSetPreProcessor getPreProcessor() {
        return preProcessor;
    }

    /** Regression iterator: there are no class labels. */
    @Override
    public List<String> getLabels() {
        return null;
    }

    /** True while at least one complete feature window plus label window fits in the series. */
    @Override
    public boolean hasNext() {
        return (position + lengthSequence * 2) <= candles.size();
    }

    /**
     * Returns the next mini-batch of up to {@link #batch()} examples.
     *
     * @throws java.util.NoSuchElementException if no complete window remains
     */
    @Override
    public DataSet next() {
        return next(batchSize);
    }
}

Prediction attempt (fails):
test_fail

Predict method:

/**
 * Loads the saved model, primes its recurrent state on one window of evaluation candles,
 * predicts the following window, and plots real vs. predicted closes.
 *
 * <p>The last {@code 3 * LENGTH_SEQUENCE} candles before the final candle are split into three
 * windows:
 * <ol>
 *   <li>a warm-up window, fed to {@code rnnTimeStep} only to set the state;</li>
 *   <li>an input window whose outputs are the predictions;</li>
 *   <li>the window the predictions are plotted against.</li>
 * </ol>
 * The predictions start at x = {@code 2 * LENGTH_SEQUENCE}, matching the iterator's labels,
 * which lie {@code LENGTH_SEQUENCE} steps ahead of the features.
 *
 * @throws IllegalStateException if there are too few evaluation candles
 * @throws RuntimeException      wrapping I/O or plotting failures
 */
public void predict() {
    // NOTE(review): the final candle is excluded, as in the original; presumably it may be
    // incomplete. Confirm this is intended.
    int end = evalCandles.size() - 1;
    int start = end - LENGTH_SEQUENCE * 3;
    if (start < 0) {
        throw new IllegalStateException("Need at least " + (LENGTH_SEQUENCE * 3 + 1)
                + " evaluation candles, got " + evalCandles.size());
    }
    List<CandleDto> testCandles = evalCandles.subList(start, end);

    INDArray warmUp = Nd4j.create(1, 1, LENGTH_SEQUENCE);
    INDArray in = Nd4j.create(1, 1, LENGTH_SEQUENCE);
    for (int i = 0; i < LENGTH_SEQUENCE; i++) {
        warmUp.putScalar(new int[]{0, 0, i}, testCandles.get(i).getClose());
        in.putScalar(new int[]{0, 0, i}, testCandles.get(LENGTH_SEQUENCE + i).getClose());
    }

    MultiLayerNetwork model = null;
    try {
        model = MultiLayerNetwork.load(new File(MODEL_PATH_FILE), false);
        model.rnnClearPreviousState();
        model.rnnTimeStep(warmUp);
        INDArray predicted = model.rnnTimeStep(in);

        List<Double> p = getListFromINDArray(predicted);

        Plot plot = Plot.create();
        plot.plot()
                .add(IntStream.range(0, testCandles.size()).boxed().collect(Collectors.toList()),
                        testCandles.stream().mapToDouble(CandleDto::getClose).boxed().collect(Collectors.toList()))
                .label("real");
        plot.plot()
                .add(IntStream.range(LENGTH_SEQUENCE * 2, LENGTH_SEQUENCE * 2 + p.size()).boxed().collect(Collectors.toList()), p)
                .label("predicted");
        plot.legend();
        plot.show();
    } catch (IOException | PythonExecutionException e) {
        throw new RuntimeException(e);
    } finally {
        // Release the network's native memory even if prediction or plotting fails.
        if (model != null) {
            model.close();
        }
    }
}

@Dimitry82 Of note: this is a hard problem. Most of the time, predicting the price directly doesn’t work because of market volatility and seasonality. I highly recommend doing some more reading before pursuing this further.

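Regardless, let me wish you luck and answer your question. Usually, one thing you’ll want to do is normalize your labels as well. This does not require knowing future prices. Fit the normalization statistics (e.g. mean and standard deviation) on the training data, apply them to both inputs and labels, and invert the transform on the network’s predictions. I would also experiment with different forms of input normalization (min-max scaling to 0–1, or zero mean and unit variance).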