Hello,
I tried to build my first own transformer encoder block with the attention layers that DL4J provides. I read most of the forum entries and some online articles, and so far I have managed to connect nearly all of the basic layers. The final add & normalize step is still missing, but before I even get there I am stuck with an exception.
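For context, this is roughly how I intend to wire the missing add & normalize step once the rest works. It is only a sketch: I am assuming that an ElementWiseVertex with Op.Add is the right way to build the residual connection, and that the layer-norm option on DenseLayer (hasLayerNorm) can stand in for the normalization; I have not verified either yet.

// sketch only -- not part of the failing configuration below
// explicit merge of the two embeddings, so the residual has a named input
// ("encode_attention" would then take "encode_merge" as its single input)
builder.addVertex("encode_merge", new MergeVertex(),
        "embedding", "posembedding");

// residual connection: add the attention output back onto its input
builder.addVertex("encode_add",
        new ElementWiseVertex(ElementWiseVertex.Op.Add),
        "encode_attention", "encode_merge");

// normalization, assuming DenseLayer's layer norm is usable here
builder.addLayer("encode_norm", new DenseLayer.Builder()
        .nIn(embsize * 2)
        .nOut(embsize * 2)
        .activation(Activation.IDENTITY)
        .hasLayerNorm(true)
        .build(), "encode_add");

Here is my current configuration: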
ComputationGraphConfiguration.GraphBuilder builder = new NeuralNetConfiguration.Builder()
        .weightInit(WeightInit.XAVIER)
        .updater(new Adam(0.001))
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .seed(System.currentTimeMillis())
        .graphBuilder();

// ------- INPUT LAYER -------------------
// encoder / decoder inputs
builder.setInputTypes(
        InputType.recurrent(vocabsize, MAX_SEQ),  // tokens
        InputType.recurrent(MAX_SEQ, MAX_SEQ));   // token positions
builder.addInputs("token", "encodePos"); // ,"decodeIn","decodePos");

// ------- ENCODER BLOCK -----------------
// token embedding
builder.addLayer("embedding", new EmbeddingSequenceLayer.Builder()
        .nIn(vocabsize)
        .nOut(embsize)
        .build(), "token");

// positional embedding
builder.addLayer("posembedding", new EmbeddingSequenceLayer.Builder()
        .activation(Activation.IDENTITY)
        .nIn(MAX_SEQ)
        .nOut(embsize)
        .build(), "encodePos");

// self-attention over the (implicitly merged) token + position embeddings
builder.addLayer("encode_attention", new SelfAttentionLayer.Builder()
        .nIn(embsize * 2)
        .nOut(embsize * 2)
        .nHeads(1)
        .build(), "embedding", "posembedding");

// pool the sequence down to a fixed-size vector
builder.addLayer("encode_subsample", new GlobalPoolingLayer.Builder(PoolingType.AVG)
        .build(), "encode_attention", "posembedding", "embedding");

// feed-forward part of the encoder block
builder.addLayer("encode_ff", new DenseLayer.Builder()
        .nOut(hiddenNodes)
        .activation(Activation.RELU)
        .build(), "encode_subsample");

// ------- OUTPUT LAYER ------------------
builder.addLayer("output", new OutputLayer.Builder()
        .activation(Activation.SOFTMAX)
        .lossFunction(LossFunction.MCXENT)
        .nOut(pos.getLabels().size())
        .build(), "encode_ff", "encode_subsample"); // original: decode_sample
If I use only "encode_ff" as the input to "output", there is no problem. But as soon as I merge "encode_ff" and "encode_subsample" as inputs to "output", I get an exception during backpropagation:
java.lang.IllegalStateException: Cannot perform operation "addi" - shapes are not equal and are not broadcastable.first.shape=[8, 400], second.shape=[8, 400, 1]
at org.nd4j.common.base.Preconditions.throwStateEx(Preconditions.java:638)
at org.nd4j.common.base.Preconditions.checkState(Preconditions.java:373)
at org.nd4j.linalg.api.shape.Shape.assertBroadcastable(Shape.java:263)
at org.nd4j.linalg.api.ndarray.BaseNDArray.addi(BaseNDArray.java:3270)
at org.nd4j.linalg.api.ndarray.BaseNDArray.addi(BaseNDArray.java:3264)
at org.deeplearning4j.nn.graph.ComputationGraph.calcBackpropGradients(ComputationGraph.java:2800)
at org.deeplearning4j.nn.graph.ComputationGraph.computeGradientAndScore(ComputationGraph.java:1381)
at org.deeplearning4j.nn.graph.ComputationGraph.computeGradientAndScore(ComputationGraph.java:1341)
at org.deeplearning4j.optimize.solvers.BaseOptimizer.gradientAndScore(BaseOptimizer.java:174)
at org.deeplearning4j.optimize.solvers.StochasticGradientDescent.optimize(StochasticGradientDescent.java:61)
at org.deeplearning4j.optimize.Solver.optimize(Solver.java:52)
at org.deeplearning4j.nn.graph.ComputationGraph.fitHelper(ComputationGraph.java:1165)
at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1115)
at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1082)
at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1044)
Looking at the message, the two gradients being accumulated seem to differ in rank ([8, 400] vs. [8, 400, 1]), but I don't see where the trailing dimension comes from. I tried to insert a ReshapeVertex to get rid of it (roughly as sketched below), but without success. Can anybody tell me what my mistake is here?
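From memory, the attempt looked approximately like this; the hard-coded shape was only for testing and guessed from the exception message:

// try to squeeze the trailing singleton dimension so that both inputs
// to "output" are rank 2 ([minibatch, features]); minibatch of 8 is
// hard-coded here just to test
builder.addVertex("encode_squeeze",
        new ReshapeVertex(8, embsize * 2),
        "encode_subsample");

// ...and then "encode_squeeze" instead of "encode_subsample" as the
// second input to "output"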
Thanks in advance,
Thomas