I am trying to train an LSTM network to predict the 2D trajectory of a golf ball. Previously, I successfully trained a simple neural net (dense layers only) to predict a fixed number of consecutive points of such trajectories, using the same training data that I now use for the LSTM training. Here, however, the results are simply not there (the resulting lines continue in a semi-random direction), which I hope is just the effect of a mistake somewhere in my training process.

Here’s my network’s config:

```
public static MultiLayerConfiguration getRNNConfiguration(long seed) {
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .updater(new Nadam())
            .weightInit(WeightInit.XAVIER)
            .list()
            .layer(new LSTM.Builder().activation(Activation.TANH).nIn(2).nOut(100).build())
            .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
                    .activation(Activation.IDENTITY).nIn(100).nOut(2).build())
            .build();
    return conf;
}
```

I tried different updaters and different layer combinations (multiple LSTMs, a single LSTM, different layer sizes, etc.) and always got a very similar effect: lines going somewhat upwards, not resembling where the ball would actually go.
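For example, one of the variants I tried (reconstructed from memory, so the exact updater and layer sizes may differ) stacked two LSTM layers:

```java
// Two stacked LSTMs instead of one; same MSE regression output layer.
// Config fragment only; sizes here are illustrative.
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(seed)
        .updater(new Adam())
        .weightInit(WeightInit.XAVIER)
        .list()
        .layer(new LSTM.Builder().activation(Activation.TANH).nIn(2).nOut(64).build())
        .layer(new LSTM.Builder().activation(Activation.TANH).nIn(64).nOut(64).build())
        .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
                .activation(Activation.IDENTITY).nIn(64).nOut(2).build())
        .build();
```

It behaved essentially the same as the single-layer version.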

Here’s how I create my training data. It is based on measured golf trajectories in screen coordinate space, which I normalise to the (-1.0; 1.0) range before training, with (0.0; 0.0) being the centre of the screen. I used the same method for my other network training (the one without LSTM layers), and its results were very precise:

```
public static List<DataSet> createRNNDataSet(List<LineData> lineData, boolean onlySmoothOutput) {
    List<DataSet> result = new ArrayList<>();
    for (LineData data : lineData) {
        List<Double> points = new ArrayList<>();
        List<Double> smoothPoints = new ArrayList<>();
        data.points.forEach(p -> {
            points.add((double) p.x);
            points.add((double) p.y);
        });
        data.smoothPoints.forEach(p -> {
            smoothPoints.add((double) p.x);
            smoothPoints.add((double) p.y);
        });
        if (!onlySmoothOutput) {
            result.add(toSequenceDataSet(TrainingData.normalizePoints(points, data.width, data.height)));
        }
        result.add(toSequenceDataSet(TrainingData.normalizePoints(smoothPoints, data.width, data.height)));
    }
    return result;
}

// Builds one [1, 2, timeSteps] feature/label pair; the labels are the
// features shifted one point ahead. Each point contributes two doubles,
// so the number of time steps is half the flattened size, minus one point.
private static DataSet toSequenceDataSet(List<Double> normalizedPoints) {
    final int timeSteps = (normalizedPoints.size() - 2) / 2;
    double[][][] in = new double[1][2][timeSteps];
    double[][][] out = new double[1][2][timeSteps];
    for (int i = 0; i < normalizedPoints.size() - 2; i += 2) {
        in[0][0][i / 2] = normalizedPoints.get(i);
        in[0][1][i / 2] = normalizedPoints.get(i + 1);
        out[0][0][i / 2] = normalizedPoints.get(i + 2);
        out[0][1][i / 2] = normalizedPoints.get(i + 3);
    }
    return new DataSet(Nd4j.create(in), Nd4j.create(out));
}
```
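For completeness, here is roughly what the normalisation does. This is a simplified, self-contained sketch, not my exact TrainingData.normalizePoints implementation; it maps flattened (x, y) screen coordinates into the (-1.0; 1.0) range with (0.0; 0.0) at the centre:

```java
import java.util.ArrayList;
import java.util.List;

public class NormalizeSketch {
    // Maps flattened screen-space (x0, y0, x1, y1, ...) pairs into the
    // (-1.0; 1.0) range, with (0.0; 0.0) at the screen centre.
    public static List<Double> normalizePoints(List<Double> points, int width, int height) {
        List<Double> result = new ArrayList<>(points.size());
        for (int i = 0; i < points.size(); i += 2) {
            result.add(points.get(i) / width * 2.0 - 1.0);      // x
            result.add(points.get(i + 1) / height * 2.0 - 1.0); // y
        }
        return result;
    }
}
```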

And finally the training method:

```
public static void trainRNNModel(List<DataSet> dataSets, String namePrefix, File outputDir) throws IOException {
    dataSets = new ArrayList<>(dataSets);
    final int targetIterations = 100000;
    final long seed = 54321;
    final int batchSize = 32;
    final int epochCount = Math.max((int) Math.ceil((float) targetIterations / dataSets.size() * batchSize), 12);
    int iterationCount = Math.round((float) dataSets.size() / batchSize * epochCount);
    System.out.println("Predicted iteration count: " + iterationCount);
    Collections.shuffle(dataSets, new Random(seed));
    DataSetIterator trainingDataIt = new ListDataSetIterator<>(dataSets, batchSize);
    MultiLayerConfiguration conf = getRNNConfiguration(seed);
    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();
    model.setListeners(
            new ScoreIterationListener(10),
            new BaseTrainingListener() {
                @Override
                public void iterationDone(Model model, int iteration, int epoch) {
                    if (iteration % 10 == 0) {
                        try {
                            ModelUtil.saveRNN((MultiLayerNetwork) model, outputDir, namePrefix);
                        } catch (IOException e) {
                            e.printStackTrace();
                        }
                    }
                }
            }
    );
    model.fit(trainingDataIt, epochCount);
    ModelUtil.saveRNN(model, outputDir, namePrefix);
}
```
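One thing that can be checked without ND4J is the one-step-shift indexing itself. Running a toy trajectory through the same array-building pattern (with the time-step count taken as the number of points minus one) shows that the labels are indeed the inputs shifted by one point:

```java
import java.util.Arrays;
import java.util.List;

public class ShiftSanityCheck {
    // Builds a [1][2][timeSteps] input/label array pair from a flattened
    // (x0, y0, x1, y1, ...) trajectory; labels are inputs shifted one point.
    public static double[][][][] buildInOut(List<Double> pts) {
        final int timeSteps = (pts.size() - 2) / 2; // each point is two doubles
        double[][][] in = new double[1][2][timeSteps];
        double[][][] out = new double[1][2][timeSteps];
        for (int i = 0; i < pts.size() - 2; i += 2) {
            in[0][0][i / 2] = pts.get(i);
            in[0][1][i / 2] = pts.get(i + 1);
            out[0][0][i / 2] = pts.get(i + 2);
            out[0][1][i / 2] = pts.get(i + 3);
        }
        return new double[][][][] { in, out };
    }

    public static void main(String[] args) {
        // Toy trajectory of 4 points: (0,0), (0.1,0.2), (0.2,0.35), (0.3,0.45).
        List<Double> pts = Arrays.asList(0.0, 0.0, 0.1, 0.2, 0.2, 0.35, 0.3, 0.45);
        double[][][][] inOut = buildInOut(pts);
        System.out.println(Arrays.toString(inOut[0][0][0])); // x inputs: [0.0, 0.1, 0.2]
        System.out.println(Arrays.toString(inOut[1][0][0])); // x labels: [0.1, 0.2, 0.3]
    }
}
```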

All of this is pretty straightforward, which is exactly why I can’t see where I went wrong. I would really appreciate it if someone could spot the issue and point me in the right direction.