Issues with LSTM-based prediction of a 2D trajectory

I'm trying to train an LSTM network to predict the 2D trajectory of a golf ball. Previously, I was able to successfully train a simple neural net (dense layers only) to predict a fixed number of consecutive points of such trajectories, using the same training data that I'm using for the LSTM network. Here, however, the results are simply not there (the predicted lines continue in a semi-random direction), which is hopefully just the effect of a mistake I made somewhere in the training process.

Here’s my network’s config:

public static MultiLayerConfiguration getRNNConfiguration(long seed) {
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .updater(new Nadam())
            .weightInit(WeightInit.XAVIER)
            .list()
            .layer(new LSTM.Builder().activation(Activation.TANH).nIn(2).nOut(100).build())
            .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
                    .activation(Activation.IDENTITY).nIn(100).nOut(2).build())
            .build();

    return conf;
}

I tried different updaters and different layer combinations (multiple LSTMs, single LSTM, different sizes, etc.) and always got very similar effects (lines going somewhat up, not resembling where the ball would actually go).
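
For example, one of the deeper variants I tried looked roughly like this (the exact layer sizes are from memory, so treat them as illustrative):

MultiLayerConfiguration deeperConf = new NeuralNetConfiguration.Builder()
        .seed(seed)
        .updater(new Nadam())
        .weightInit(WeightInit.XAVIER)
        .list()
        // two stacked LSTM layers instead of one
        .layer(new LSTM.Builder().activation(Activation.TANH).nIn(2).nOut(128).build())
        .layer(new LSTM.Builder().activation(Activation.TANH).nIn(128).nOut(64).build())
        .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
                .activation(Activation.IDENTITY).nIn(64).nOut(2).build())
        .build();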

Here’s how I create my training data. It's based on measured golf trajectories in screen coordinate space, which I normalise into the (-1.0;1.0) range before training, with (0.0;0.0) at the center of the screen (a simplified version of the normalisation itself follows after the code). I used the same method for training my other, dense-only network, and those results were very precise:

public static List<DataSet> createRNNDataSet(List<LineData> lineData, boolean onlySmoothOutput) {
    List<DataSet> result = new ArrayList<>();

    for(LineData data : lineData) {
        List<Double> points = new ArrayList<>();
        List<Double> smoothPoints = new ArrayList<>();
        data.points.forEach(p -> {
            points.add((double) p.x);
            points.add((double) p.y);
        });

        data.smoothPoints.forEach(p -> {
            smoothPoints.add((double) p.x);
            smoothPoints.add((double) p.y);
        });

        if (!onlySmoothOutput) {
            List<Double> normalizedPoints = TrainingData.normalizePoints(points, data.width, data.height);
            // Points are stored as interleaved x/y doubles, so there is one
            // sequence step per (x, y) pair: the input at step t is point t,
            // the label is point t + 1, and the last point has no successor
            // (hence the -2, i.e. one x/y pair shorter).
            final int pointCount    = (normalizedPoints.size() - 2) / 2;
            double[][][] in         = new double[1][2][pointCount];
            double[][][] out        = new double[1][2][pointCount];

            for(int i = 0; i < normalizedPoints.size() - 2; i += 2) {
                in[0][0][i / 2]     = normalizedPoints.get(i);
                in[0][1][i / 2]     = normalizedPoints.get(i + 1);
                out[0][0][i / 2]    = normalizedPoints.get(i + 2);
                out[0][1][i / 2]    = normalizedPoints.get(i + 3);
            }

            DataSet resultEntry = new DataSet(Nd4j.create(in), Nd4j.create(out));
            result.add(resultEntry);
        }

        List<Double> normalizedSmoothPoints = TrainingData.normalizePoints(smoothPoints, data.width, data.height);
        final int pointCount    = (normalizedSmoothPoints.size() - 2) / 2;
        double[][][] in         = new double[1][2][pointCount];
        double[][][] out        = new double[1][2][pointCount];

        for(int i = 0; i < normalizedSmoothPoints.size() - 2; i += 2) {
            in[0][0][i / 2]     = normalizedSmoothPoints.get(i);
            in[0][1][i / 2]     = normalizedSmoothPoints.get(i + 1);
            out[0][0][i / 2]    = normalizedSmoothPoints.get(i + 2);
            out[0][1][i / 2]    = normalizedSmoothPoints.get(i + 3);
        }

        DataSet resultEntry = new DataSet(Nd4j.create(in), Nd4j.create(out));

        result.add(resultEntry);
    }

    return result;
}
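
As mentioned above, TrainingData.normalizePoints just maps interleaved x/y screen coordinates into the (-1.0;1.0) range, with (0.0;0.0) at the screen center; a simplified version of it looks like this:

// Simplified version of TrainingData.normalizePoints: maps interleaved
// x/y screen coordinates into (-1.0;1.0), with (0.0;0.0) at the center.
public static List<Double> normalizePoints(List<Double> points, double width, double height) {
    List<Double> result = new ArrayList<>(points.size());

    for (int i = 0; i < points.size(); i += 2) {
        result.add(points.get(i) / width * 2.0 - 1.0);        // x
        result.add(points.get(i + 1) / height * 2.0 - 1.0);   // y
    }

    return result;
}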

And finally the training method:

    public static void trainRNNModel(List<DataSet> dataSets, String namePrefix, File outputDir) throws IOException {
        dataSets = new ArrayList<>(dataSets);

        final int targetIterations  = 100000;
        final long seed             = 54321;
        final int batchSize         = 32;
        // Aim for roughly targetIterations parameter updates in total: an
        // epoch runs dataSets.size() / batchSize iterations, so
        // epochs = targetIterations * batchSize / dataSets.size(), min. 12.
        final int epochCount        = Math.max((int) Math.ceil((float) targetIterations / dataSets.size() * batchSize), 12);
        int iterationCount          = Math.round((float) dataSets.size() / batchSize * epochCount);
        
        System.out.println("Predicted iteration count: " + iterationCount);

        Collections.shuffle(dataSets, new Random(seed));
        DataSetIterator trainingDataIt = new ListDataSetIterator<>(dataSets, batchSize);

        MultiLayerConfiguration conf = getRNNConfiguration(seed);

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(
                new ScoreIterationListener(10),
                new BaseTrainingListener() {
                    @Override
                    public void iterationDone(Model model, int iteration, int epoch) {
                        if (iteration % 10 == 0) {
                            try {
                                ModelUtil.saveRNN((MultiLayerNetwork) model, outputDir, namePrefix);
                            } catch (IOException e) {
                                e.printStackTrace();
                            }
                        }
                    }
                }
        );

        model.fit(trainingDataIt, epochCount);
        ModelUtil.saveRNN(model, outputDir, namePrefix);
    }

All of this is pretty straightforward, which is exactly why I can't see where I went wrong, so I would really appreciate it if someone could spot the issue and point me in the right direction.
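
In case it matters, here's a simplified version of how I generate the predicted line from the trained model (I prime the network with the known points through rnnTimeStep() and then feed each prediction back in):

// Simplified prediction loop: prime the network state with the known
// (normalized) points via rnnTimeStep(), then feed each prediction back
// in to extend the line.
public static List<double[]> predictTrajectory(MultiLayerNetwork model, double[][] knownPoints, int futureSteps) {
    model.rnnClearPreviousState();

    INDArray lastOutput = null;
    for (double[] p : knownPoints) {
        // shape [miniBatch = 1, features = 2, timeSteps = 1]
        INDArray step = Nd4j.create(new double[][][] {{{ p[0] }, { p[1] }}});
        lastOutput = model.rnnTimeStep(step);
    }

    List<double[]> predicted = new ArrayList<>();
    for (int i = 0; i < futureSteps; i++) {
        double x = lastOutput.getDouble(0, 0, 0);
        double y = lastOutput.getDouble(0, 1, 0);
        predicted.add(new double[] { x, y });
        lastOutput = model.rnnTimeStep(Nd4j.create(new double[][][] {{{ x }, { y }}}));
    }

    return predicted;
}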

@b005t3r I thought about this over the last day or so. Your normalization pipeline looks correct. Did you try zero-mean, unit-variance normalization instead? Beyond that, I would suggest adding more data points.
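
In DL4J that's roughly the following (untested sketch; note fitLabel(true), since your labels are coordinates that need normalizing too):

// Zero-mean, unit-variance normalization: fit the statistics on the
// training data, then apply it as a pre-processor on the iterator.
NormalizerStandardize normalizer = new NormalizerStandardize();
normalizer.fitLabel(true);                   // normalize the labels too
normalizer.fit(trainingDataIt);              // collect mean/std over the data
trainingDataIt.reset();
trainingDataIt.setPreProcessor(normalizer);  // applied to every batch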

Hey, thanks for the reply @agibsonccc

No, I haven’t tried that. However, with the current normalisation method I was able to train other regression models with a fixed number of inputs (dense layers only) and got good results (the predicted segments were pretty precise and often followed the actual trajectory very closely), but here I just get a random, semi-straight line going mostly upwards.

Is there a rule of thumb for how many data sets/batches I should need, or for how many hidden layers/neurons I should use? Up to this point all of those params have been pure guesswork for me :slight_smile: