SelfAttention Token Training Example


I found a lot of examples that classify text with the self-attention layer. Is it already possible to train token-wise, as with BertIterator.Task.UNSUPERVISED? I would appreciate any code examples.

I tried to create a simple learning graph with attention for tokens like this:

    // Layer builders are abbreviated here; nIn/nOut and other settings are omitted.
    final NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
        .updater(new RmsProp(lr));

    final GraphBuilder graphBuilder = builder.graphBuilder()
        .addInputs("inputLine", "inputPos")
        .setInputTypes(InputType.recurrent(vocabSize), InputType.recurrent(maxLen))
        .addLayer("embedding", new EmbeddingSequenceLayer.Builder()
            .build(), "inputLine")
        .addLayer("posEmbedding", new EmbeddingSequenceLayer.Builder()
            .build(), "inputPos")
        .addLayer("attention1", new SelfAttentionLayer.Builder()
            .build(), "embedding", "posEmbedding")
        .addLayer("output", new RnnOutputLayer.Builder()
            .build(), "attention1")
        // a ComputationGraph needs its output layers declared
        .setOutputs("output");

    model = new ComputationGraph(graphBuilder.build());
    model.init();

I created a BertTokenIterator for unsupervised learning, but the gradients constantly explode and I get only NaN results.

Thanks in advance.


@thomas Sorry for the late reply; holidays and all. Let me get back to you with an example. You’ll probably want to use SameDiff, though. You’ll want to use our BertIterator like here: deeplearning4j-examples/ at master · deeplearning4j/deeplearning4j-examples · GitHub

If you need something more specific, could you try to elaborate a bit? Thanks!

@agibsonccc Thanks for your reply. The examples I found before are all for sequence classification, not token classification. My problem is that, across 1000 attempts and parameter setups, my model crashes every time I train with BertIterator and UNSUPERVISED. I have now written my own iterator.

A tip for you: I got similar crashes from time to time with my own iterator until I noticed that it contained something like this:

    // set input vector
    input.put(new INDArrayIndex[] { NDArrayIndex.point(j), NDArrayIndex.point(0), NDArrayIndex.interval(0, inLen) },
        Nd4j.create(ArrayUtils.toPrimitive(masked.toArray(new Double[0]))));

I thought it would be OK for the list (masked) to be longer than the “inLen” variable here, as long as it was not longer than the given vector. But as soon as I cut the list (masked) to exactly inLen elements, the crashes were gone.
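For anyone hitting the same thing, here is a minimal sketch of that fix in plain Java, with no ND4J dependency. The class name MaskedRowExample and the helper toFixedLength are made up for illustration. The padding of shorter lists is an extra assumption on my side; what fixed the crashes for me was only the truncation to inLen.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical helper class, not part of DL4J/ND4J.
class MaskedRowExample {

    /**
     * Returns an array of exactly inLen values.
     * Longer lists are truncated; shorter lists are padded with padValue
     * (the padding is an assumption, the original fix only truncated).
     */
    public static double[] toFixedLength(List<Double> masked, int inLen, double padValue) {
        double[] row = new double[inLen];
        Arrays.fill(row, padValue);
        int n = Math.min(inLen, masked.size());
        for (int i = 0; i < n; i++) {
            row[i] = masked.get(i);
        }
        return row;
    }

    public static void main(String[] args) {
        List<Double> masked = new ArrayList<>(Arrays.asList(5.0, 9.0, 3.0, 7.0, 1.0));
        int inLen = 3;

        // The row now has the same length as NDArrayIndex.interval(0, inLen).
        double[] row = toFixedLength(masked, inLen, 0.0);
        System.out.println(Arrays.toString(row)); // prints [5.0, 9.0, 3.0]
    }
}
```

In the iterator, the resulting array would then take the place of the ArrayUtils.toPrimitive(...) argument passed to Nd4j.create(...), so that the value written by input.put(...) matches the indexed interval exactly.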

best regards