LSTM Training stops working in Snapshot

Hello,

i tried some lstm model to solve some nlp tasks like pos tagging. i created an multidatasetiterator which gives on size batches. the sequences have variable length. Everything worked under beta7 and after an update to SNAPSHOT i get the error:

java.lang.IllegalStateException: Sequence lengths do not match for RnnOutputLayer input and labels:Arrays should be rank 3 with shape [minibatch, size, sequenceLength] - mismatch on dimension 2 (sequence length) - input=[16, 800, 63] vs. label=[16, 56, 63]
	at org.nd4j.common.base.Preconditions.throwStateEx(Preconditions.java:638)
	at org.nd4j.common.base.Preconditions.checkState(Preconditions.java:337)
	at org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer.backpropGradient(RnnOutputLayer.java:59)
	at org.deeplearning4j.nn.graph.vertex.impl.LayerVertex.doBackward(LayerVertex.java:148)
	at org.deeplearning4j.nn.graph.ComputationGraph.calcBackpropGradients(ComputationGraph.java:2772)
	at org.deeplearning4j.nn.graph.ComputationGraph.computeGradientAndScore(ComputationGraph.java:1381)
	at org.deeplearning4j.nn.graph.ComputationGraph.computeGradientAndScore(ComputationGraph.java:1341)
	at org.deeplearning4j.optimize.solvers.BaseOptimizer.gradientAndScore(BaseOptimizer.java:174)
	at org.deeplearning4j.optimize.solvers.StochasticGradientDescent.optimize(StochasticGradientDescent.java:61)
	at org.deeplearning4j.optimize.Solver.optimize(Solver.java:52)
	at org.deeplearning4j.nn.graph.ComputationGraph.fitHelper(ComputationGraph.java:1165)
	at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1115)
	at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1082)

it mentioned the sequencelength do not match. but for me it looks fine with the shapeinfo given also by the exception text. can there be a problem with with the variable length. i wrapped my iterator to use batches inside the iteratormultidatasetiterator.

Would appreciate any comments. Thanks in advance.

Best regards
Thomas

My Model:

		ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
				.seed(System.currentTimeMillis())
				//.updater(new Adam.Builder().learningRate(0.01).beta1(0.9).beta2(0.95).build())
				.updater(new Sgd(new ExponentialSchedule(ScheduleType.EPOCH, 0.01, 0.95)))
				.weightInit(WeightInit.XAVIER)
				.l2(1e-4)
				.graphBuilder()
					.addInputs("word")
					.addLayer("bi_lstm", new Bidirectional(new LSTM.Builder()
						.nIn(inputs)
						.nOut(size)
						.activation(Activation.SOFTSIGN)
						.build()), "word")
					.addLayer("lstm_hide1", new LSTM.Builder()
						.nIn(size*2)
						.nOut(size*2)
						.activation(Activation.SOFTSIGN)
						.build(), "bi_lstm")
					.addLayer("rnn_out", new RnnOutputLayer.Builder()
							.nIn(size*2)
							.nOut(outputs)
							.lossFunction(LossFunction.MCXENT)
							.build() , "lstm_hide1")
					.setOutputs("rnn_out")
					.build();

@thomas mind giving me a reproducer I can just run with a feed forward so I can compare on beta7 and SNAPSHOTS? Intended input data in ndarray form with a .output call is fine.

That should help separate out the multi dataset iterator from troubleshooting the network.

No Problem:

TestTrainer to run the Training

import java.io.IOException;

import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.nd4j.linalg.schedule.ExponentialSchedule;
import org.nd4j.linalg.schedule.ScheduleType;

public class TestTrainer {
	
	private ComputationGraph createGraphModel(int inputs,int outputs) {
		
		int size = 400;
		
		ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
				.seed(System.currentTimeMillis())
				.updater(new Sgd(new ExponentialSchedule(ScheduleType.EPOCH, 0.01, 0.95)))
				.weightInit(WeightInit.XAVIER)
				.l2(1e-4)
				.graphBuilder()
					.addInputs("word")
					.addLayer("bi_lstm", new Bidirectional(new LSTM.Builder()
						.nIn(inputs)
						.nOut(size)
						.activation(Activation.SOFTSIGN)
						.build()), "word")
					.addLayer("lstm_hide1", new LSTM.Builder()
						.nIn(size*2)
						.nOut(size*2)
						.activation(Activation.SOFTSIGN)
						.build(), "bi_lstm")
					.addLayer("rnn_out", new RnnOutputLayer.Builder()
							.nIn(size*2)
							.nOut(outputs)
							.lossFunction(LossFunction.MCXENT)
							.build() , "lstm_hide1")
					.setOutputs("rnn_out")
					.build();
		
		ComputationGraph graph = new ComputationGraph(conf);
		graph.init();
		
		return graph;
		
	}
	
	public void trainModel() throws IOException {
		ComputationGraph model = createGraphModel(200, 56);
		
		TestTokenIterator           iter = new TestTokenIterator();

		model.addListeners(new ScoreIterationListener(1));
		
		model.fit(iter);
	}
	
	public static void main(String[] args) {
		TestTrainer trainer = new TestTrainer();
		
		try {
			trainer.trainModel();
		} catch (Exception e) {
			e.printStackTrace();
		}
	}
	

}

and the test iterator class:

import java.io.FileNotFoundException;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;

public class TestTokenIterator implements MultiDataSetIterator {

	private static final long serialVersionUID = -8725261332489394391L;

	private int OUTPUTS;
	
	private int INPUTS;

	private int count;

	private List<String> labels;	
	
	private MultiDataSetPreProcessor pre;
	
	private int maxTrain;
	
	private INDArray stop;	
	
	private Random rand;
	
	public TestTokenIterator() throws FileNotFoundException {
		this.count = 0;
		this.maxTrain = 1500;
		
		this.stop     = Nd4j.ones(1,200);
		
		labels = new LinkedList<>();
		for (int i=0;i<56;i++) {
			labels.add("label" + i);
		}
		
		this.rand = new Random(System.currentTimeMillis());
		
		this.OUTPUTS = labels.size();
		this.INPUTS  = 200;
		
	}
	
	@Override
	public boolean hasNext() {
		return count < maxTrain;
	}

	@Override
	public MultiDataSet next() {
		return next(1);
	}

	@Override
	public MultiDataSet next(int num) {
		int mySHIFT,TIMESTEPS;
		int mynum = 1;
		
		INDArray[] featuresList = new INDArray[mynum];
		INDArray[] labelsList   = new INDArray[mynum];
		INDArray[] featureMasks = new INDArray[mynum];
		INDArray[] labelsMasks  = new INDArray[mynum];
		
		for (int i=0;i<mynum;i++) {

			int sent_size = 2 + rand.nextInt(15);
			
			System.out.println("seq len: " + sent_size);
			mySHIFT   = sent_size;
			TIMESTEPS = sent_size;
			
			INDArray featureSteps = Nd4j.create(INPUTS, TIMESTEPS+mySHIFT+1);
			INDArray labelSteps   = Nd4j.create(OUTPUTS,TIMESTEPS+mySHIFT+1);
			INDArray featureMask  = Nd4j.zeros(TIMESTEPS+mySHIFT+1); // expand direct included 
			INDArray labelMask    = Nd4j.zeros(TIMESTEPS+mySHIFT+1); // expand direct include 

			// create example 
			for (int e=0;e<TIMESTEPS;e++) {
				int wv = rand.nextInt(200);
				
				INDArray wvec = Nd4j.zeros(1,200);
				INDArray lvec = Nd4j.zeros(labels.size());
				
				wvec.putScalar(new int[] {0, wv}, 1.0);
				lvec.putScalar(new int[] { (wv % labels.size() ) }, 1.0);

				featureSteps.putColumn(e,wvec);				
				labelSteps.putColumn(e + mySHIFT + 1,lvec);

				featureMask.putScalar(new int[] { e }, 1.0);
				labelMask.putScalar(new int[] { e + mySHIFT + 1 }, 1.0);
			}
			
			// add stop step
			featureSteps.putColumn(TIMESTEPS, stop);
			featureMask.putScalar(new int[] { TIMESTEPS }, 1.0);

			featuresList[i] = Nd4j.expandDims(featureSteps, 0);
			labelsList[i]   = Nd4j.expandDims(labelSteps, 0);
			featureMasks[i] = Nd4j.expandDims(featureMask, 0);
			labelsMasks[i]  = Nd4j.expandDims(labelMask, 0);
			
			count++;
		}
		
		return new org.nd4j.linalg.dataset.MultiDataSet(featuresList, labelsList,featureMasks,labelsMasks);
	}

	@Override
	public void setPreProcessor(MultiDataSetPreProcessor preProcessor) {
		this.pre = preProcessor;
	}

	@Override
	public MultiDataSetPreProcessor getPreProcessor() {
		return pre;
	}

	@Override
	public boolean resetSupported() {
		return true;
	}

	@Override
	public boolean asyncSupported() {
		return false;
	}

	@Override
	public void reset() {
		count = 0;
	}

}

Hope it helps.

Best regards

Thomas

@thomas this got it working:

Please use setInputTypes(…) rather than manually setting the in/outs yourself.

Thanks a lot for your help. I don’t know if i would ever looked at this parameter. Tested and worked here.
:+1: