Attention and Pooling Problem with Merge on Backpropagation

Hello,

I tried to build my first own transformer encoder block with the provided attention layers. I read most of the forum entries and some online articles, and so far I have managed to connect nearly all of the basic layers. The last add & normalize step is still missing, but at this point I am stuck with an exception…

        ComputationGraphConfiguration.GraphBuilder builder = new NeuralNetConfiguration.Builder()
                .weightInit(WeightInit.XAVIER)
                .updater(new Adam(0.001))
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .seed(System.currentTimeMillis())
                .graphBuilder();

        // ------- INPUT LAYER -------------------

        // encoder / decoder inputs
        builder.setInputTypes(
                InputType.recurrent(vocabsize, MAX_SEQ), // token
                InputType.recurrent(MAX_SEQ, MAX_SEQ));  // token position

        builder.addInputs("token", "encodePos"); // ,"decodeIn","decodePos");

        builder.addLayer("embedding", new EmbeddingSequenceLayer.Builder()
                .nIn(vocabsize)
                .nOut(embsize)
                .build(), "token");
        builder.addLayer("posembedding", new EmbeddingSequenceLayer.Builder()
                .activation(Activation.IDENTITY)
                .nIn(MAX_SEQ)
                .nOut(embsize)
                .build(), "encodePos");
        builder.addLayer("encode_attention", new SelfAttentionLayer.Builder()
                .nIn(embsize * 2)
                .nOut(embsize * 2)
                .nHeads(1)
                .build(), "embedding", "posembedding");
        builder.addLayer("encode_subsample", new GlobalPoolingLayer.Builder(PoolingType.AVG)
                .build(), "encode_attention", "posembedding", "embedding");
        builder.addLayer("encode_ff", new DenseLayer.Builder()
                .nOut(hiddenNodes)
                .activation(Activation.RELU)
                .build(), "encode_subsample");
        builder.addLayer("output", new OutputLayer.Builder()
                .activation(Activation.SOFTMAX)
                .lossFunction(LossFunction.MCXENT)
                .nOut(pos.getLabels().size())
                .build(), "encode_ff", "encode_subsample"); // original: decode_sample

If I only use “encode_ff” as the input to the output layer there is no problem, but as soon as I merge the inputs from “encode_ff” and “encode_subsample” I get an exception during backpropagation:

java.lang.IllegalStateException: Cannot perform operation "addi" - shapes are not equal and are not broadcastable.first.shape=[8, 400], second.shape=[8, 400, 1]
	at org.nd4j.common.base.Preconditions.throwStateEx(Preconditions.java:638)
	at org.nd4j.common.base.Preconditions.checkState(Preconditions.java:373)
	at org.nd4j.linalg.api.shape.Shape.assertBroadcastable(Shape.java:263)
	at org.nd4j.linalg.api.ndarray.BaseNDArray.addi(BaseNDArray.java:3270)
	at org.nd4j.linalg.api.ndarray.BaseNDArray.addi(BaseNDArray.java:3264)
	at org.deeplearning4j.nn.graph.ComputationGraph.calcBackpropGradients(ComputationGraph.java:2800)
	at org.deeplearning4j.nn.graph.ComputationGraph.computeGradientAndScore(ComputationGraph.java:1381)
	at org.deeplearning4j.nn.graph.ComputationGraph.computeGradientAndScore(ComputationGraph.java:1341)
	at org.deeplearning4j.optimize.solvers.BaseOptimizer.gradientAndScore(BaseOptimizer.java:174)
	at org.deeplearning4j.optimize.solvers.StochasticGradientDescent.optimize(StochasticGradientDescent.java:61)
	at org.deeplearning4j.optimize.Solver.optimize(Solver.java:52)
	at org.deeplearning4j.nn.graph.ComputationGraph.fitHelper(ComputationGraph.java:1165)
	at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1115)
	at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1082)
	at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1044)

I tried to insert a ReshapeVertex, but without success. Can anybody help me figure out what my mistake is here?

thanks in advance
thomas

I think I found my mistake in the architecture and fixed it; with a custom normalization layer it now seems to work. I had to put two input preprocessors before and after the FF dense layer to get it to work with the element-wise add (no preprocessor is allowed on a vertex), so I attached the second one to an identity activation layer.
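
In case it helps others, the workaround looks roughly like this (a simplified sketch, not my exact code: the names "encode_ff_identity" and "encode_residual" are just placeholders, the preprocessors come from org.deeplearning4j.nn.conf.preprocessor and the vertex from org.deeplearning4j.nn.conf.graph):

        builder.addLayer("encode_ff", new DenseLayer.Builder()
                .nOut(embsize * 2)                   // must match the attention output width for the add
                .activation(Activation.RELU)
                .build(),
                new RnnToFeedForwardPreProcessor(),  // 3d [batch, features, time] -> 2d [batch*time, features]
                "encode_attention");
        builder.addLayer("encode_ff_identity", new ActivationLayer.Builder()
                .activation(Activation.IDENTITY)     // identity layer only carries the second preprocessor
                .build(),
                new FeedForwardToRnnPreProcessor(),  // back to 3d so both add inputs have the same shape
                "encode_ff");
        builder.addVertex("encode_residual",
                new ElementWiseVertex(ElementWiseVertex.Op.Add),
                "encode_attention", "encode_ff_identity");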

Hi. If you’re seriously thinking of implementing the Transformer’s encoder/decoder with the ability to make tweaks/modifications or to access internal state (like attention scores, which could later be reused e.g. for filtering/selecting the passages that best fit a QA or generative model), I’d suggest using SameDiff instead of an ordinary neural network model. I also started with an NN config, but after getting some advice from experienced DL4J contributors I decided to switch to SameDiff.

SameDiff also lets you build the graph, but you can define it completely on your own - every variable/element of each layer, etc. Truth is, it will not be as easy as the approach you’re using now, but it brings enough flexibility.
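
For example, a single dense layer defined by hand looks roughly like this (just a sketch with made-up sizes; SameDiff/SDVariable are in org.nd4j.autodiff.samediff, XavierInitScheme in org.nd4j.weightinit.impl):

        // every variable of the "layer" is an explicit SDVariable, so you can
        // inspect or modify any intermediate result later
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 256);  // [batch, nIn]
        SDVariable w = sd.var("w", new XavierInitScheme('c', 256, 128), DataType.FLOAT, 256, 128);
        SDVariable b = sd.zero("b", 1, 128);
        SDVariable out = sd.nn().relu(in.mmul(w).add(b), 0);               // dense + relu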

SameDiff is very powerful, but it still has too many outstanding issues. The DL4J layers have been used for several years and are therefore more stable.

@SidneyLann have you tried the snapshots recently? I’ve been fixing quite a few of the training-related issues you’re mentioning. If you can highlight what specifically you’re referring to, it would help.

One of the bigger issues I’ve found is dealing with frozen models: many imported models require you to unfreeze variables. It varies depending on what you’re doing, but I know of people using SameDiff just fine. It’s definitely going to need some polish in some areas, though.

I will test those outstanding issues once the work on the new import framework is done.

Thanks for your advice. For my first steps I used some DL4J implementations from other projects to get a first feel for it. The positional encoding and the normalization layer are already self-implemented, but I think I will need to learn a bit more about the details of the transformer to implement a complete encoder-decoder example.
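
For reference, a plain sinusoidal encoding (as in the original transformer paper) can simply be precomputed as an ND4J matrix; something like this simplified sketch (not my actual layer code, uses INDArray/Nd4j from org.nd4j.linalg):

        // [maxSeqLen, embSize] matrix with PE(pos, 2i) = sin(pos / 10000^(2i/embSize))
        // and PE(pos, 2i+1) = cos(pos / 10000^(2i/embSize))
        INDArray sinusoidalEncoding(int maxSeqLen, int embSize) {
            INDArray pe = Nd4j.zeros(DataType.FLOAT, maxSeqLen, embSize);
            for (int pos = 0; pos < maxSeqLen; pos++) {
                for (int i = 0; i < embSize; i++) {
                    double angle = pos / Math.pow(10000, (2.0 * (i / 2)) / embSize);
                    pe.putScalar(pos, i, (i % 2 == 0) ? Math.sin(angle) : Math.cos(angle));
                }
            }
            return pe;
        }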

@thomas actually for attention layers we do have samediff.nn().dotProductAttention - you can find it here: https://github.com/eclipse/deeplearning4j/blob/93345cd067dfb287e520c0043de5cb4b1d05a786/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java#L122

Access this from any samediff instance with samediff.nn()
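
A rough usage sketch (placeholder sizes; shapes follow the javadoc, i.e. queries [batch, featureKeys, queryCount], keys [batch, featureKeys, timesteps], values [batch, featureValues, timesteps], mask [batch, timesteps]):

        SameDiff sd = SameDiff.create();
        SDVariable queries = sd.placeHolder("queries", DataType.FLOAT, -1, 64, 20);
        SDVariable keys = sd.placeHolder("keys", DataType.FLOAT, -1, 64, 20);
        SDVariable values = sd.placeHolder("values", DataType.FLOAT, -1, 64, 20);
        SDVariable mask = sd.placeHolder("mask", DataType.FLOAT, -1, 20);
        // scaled = true normalizes the attention scores by sqrt(featureKeys)
        SDVariable attention = sd.nn().dotProductAttention(queries, keys, values, mask, true);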

I’m currently using SameDiff with the M1 version and have no big issues. I don’t use models imported from other frameworks, though.

@partarstu @SidneyLann I actually do know of some issues that can occur depending on how you set up the fine-tuning; some steps are manual. With model import, some models need to be unfrozen in order to work properly. My current PR will add better support for those kinds of use cases, which should help.