"Attention Is All You Need" model implementation using dl4j

Hi,

I searched and have been unable to find examples of a transformer/attention model using dl4j.

I have looked at examples such as

and have implemented a network like the one in the example mentioned above. However, it is not quite the same as the transformer/attention model described in “Attention Is All You Need”.

I also searched for “attention” in the deeplearning4j repository on GitHub.

Has anyone seen/done such an implementation? If you have, could you share it with me?
Thanks,
Alex Donnini

@adonnini Unfortunately, the easiest route is model import. I’m working on finishing a pull request that improves the attention op as we have it in dl4j, and I plan on publishing some examples afterwards. With the interest in language models, I also plan on adding some optimized kernels.

Otherwise, for now, we’ve imported some models like BERT/GPT-2 from TensorFlow, and those work. Most people implement attention manually, and the TensorFlow models have the ops defined manually, like matrix multiply (which we have).

We also have the BERT WordPiece tokenizer.

Your best bet if you want to experiment is to look in the tests. Otherwise I’ll let you know when I’ve published the improved documentation/examples.
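
As a rough sketch of the import route (the .pb path is a placeholder, not a shipped file; assume the graph was frozen/exported from TensorFlow beforehand):

```java
import java.io.File;
import org.nd4j.autodiff.samediff.SameDiff;

public class ImportExample {
    public static void main(String[] args) {
        // Import a frozen TensorFlow graph (e.g. an exported BERT model) into SameDiff.
        // "bert_frozen.pb" is a placeholder path; supply your own exported graph.
        SameDiff sd = SameDiff.importFrozenTF(new File("bert_frozen.pb"));

        // Inspect what was imported: variables, placeholders, and ops.
        System.out.println(sd.summary());
    }
}
```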

Thanks for the prompt response Adam. I know you are extremely busy.

I will look through the tests again. I have not considered importing
models so far. I will try it out.

@adonnini you can definitely implement it manually. @partarstu did that, I believe. Note that it’s mainly used in the SameDiff framework, though, not dl4j. I will be implementing dl4j support mainly for Keras import, but regardless, it will be there for you.

The biggest hurdle I am encountering in implementing a transformer/attention model is the configuration of the dual inputs.

I thought it might be helpful if I started by implementing the model described in

Unfortunately, I am not at all familiar with Scala. I tried to “translate/adapt” the code into Java.

Below, you will find the latest iteration of my attempt.

It is not working. Execution fails with

Exception in thread “main” java.lang.IllegalArgumentException: Invalid output array: network has 1 outputs, but array is of length 0
at org.deeplearning4j.nn.graph.ComputationGraph.setLabels(ComputationGraph.java:417)
at org.deeplearning4j.nn.graph.ComputationGraph.fitHelper(ComputationGraph.java:1144)
at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1127)
at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1094)

In the code below, the failure occurs here:

net2.fit(trainData2);

What am I doing wrong? I am guessing the issue may have to do with the way my input data is structured. I don’t understand which output array and which outputs the error message refers to.

Any thoughts/ideas as to what is going on would be very helpful.

Thanks.

====================================================================

    int miniBatchSize = 48;


    // ----- Load the training data -----
    SequenceRecordReader trainFeatures = new CSVSequenceRecordReader();
    trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath() + "/%d.csv", 0, lastTrainCount));
    SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
    trainLabels.initialize(new NumberedFileInputSplit(labelsDirTrain.getAbsolutePath() + "/%d.csv", 0, lastTrainCount));

    MultiDataSetIterator trainData2 = new RecordReaderMultiDataSetIterator.Builder(miniBatchSize)
            .addSequenceReader("trainFeatures", trainFeatures)
            .addSequenceReader("trainLabels", trainLabels)
            .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END)
            .addInput("trainFeatures")
            .addInput("trainLabels")
            .build();

    // ----- Load the test data -----
    //Same process as for the training data.
    SequenceRecordReader testFeatures = new CSVSequenceRecordReader();
    testFeatures.initialize(new NumberedFileInputSplit(featuresDirTest.getAbsolutePath() + "/%d.csv", 0, lastTestCount));
    SequenceRecordReader testLabels = new CSVSequenceRecordReader();
    testLabels.initialize(new NumberedFileInputSplit(labelsDirTest.getAbsolutePath() + "/%d.csv", 0, lastTestCount));

    MultiDataSetIterator testData2 = new RecordReaderMultiDataSetIterator.Builder(miniBatchSize)
            .addSequenceReader("testFeatures", testFeatures)
            .addSequenceReader("testLabels", testLabels)
            .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END)
            .addInput("testFeatures")
            .addInput("testLabels")
            .build();

    log.info(" Printing traindata dataset shape");
    MultiDataSet data = trainData2.next();
    System.out.println(java.util.Arrays.toString(data.getFeatures()));

    log.info(" Printing testdata dataset shape");
    MultiDataSet data2 = testData2.next();
    System.out.println(java.util.Arrays.toString(data2.getFeatures()));


    INDArray predicted2 = null;

    //NETWORK CONFIGURATION SET-UP AND NETWORK INIT - START - [=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=]

    MultiNormalizerStandardize normalizer2 = new MultiNormalizerStandardize();

    normalizer2.fitLabel(true);
    normalizer2.fit(trainData2);           //Collect the statistics (mean/stdev) from the training data. This does not modify the input data

    trainData2.reset();

    while(trainData2.hasNext()) {
        normalizer2.transform(trainData2.next());     //Apply normalization to the training data
    }

    while(testData2.hasNext()) {
        normalizer2.transform(testData2.next());         //Apply normalization to the test data. This is using statistics calculated from the *training* set
    }

    trainData2.setPreProcessor(normalizer2);
    testData2.setPreProcessor(normalizer2);

    ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder()
            .updater(new Nesterovs(0.001, 0.9))
            .seed(12345)
            .l2(0.001)
            .weightInit(WeightInit.XAVIER)
            .inferenceWorkspaceMode(WorkspaceMode.SINGLE)
            .trainingWorkspaceMode(WorkspaceMode.SINGLE)
            .graphBuilder()
            .addInputs("encoderInput","decoderInput")
            .setInputTypes(InputType.recurrent(2), InputType.recurrent(3))
            .addLayer("encoder", new LSTM.Builder().nIn(6).nOut(96).activation(Activation.TANH).build(), "encoderInput")
            .addLayer("encoder2", new LSTM.Builder().nIn(6).nOut(48).activation(Activation.TANH).build(), "encoder")
            .addVertex("laststep", new LastTimeStepVertex("encoderInput"), "encoder2")
            .addVertex("dup", new DuplicateToTimeSeriesVertex("decoderInput"), "laststep")
            .addLayer("decoder", new LSTM.Builder().nIn(51).nOut(48).activation(Activation.TANH).build(), "decoderInput", "dup")
            .addLayer("decoder2", new LSTM.Builder().nIn(48).nOut(96).activation(Activation.TANH).build(), "decoder")
            .addLayer("output", new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).activation(Activation.SIGMOID).nIn(96).nOut(2).build(), "decoder2")
            .setOutputs("output")
            .build();

    ComputationGraph net2 = new ComputationGraph(conf2);
    net2.setListeners(new ScoreIterationListener(1));
    net2.init();

    //NETWORK CONFIGURATION SET-UP AND NETWORK INIT - END - [=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=][=]
    //ADD VISUALIZATION CODE HERE - START - <><><><><><><><><><><><><><><><><><><><><><><><>
    //Initialize the user interface backend
    UIServer uiServer = UIServer.getInstance();

    //Configure where the network information (gradients, score vs. time etc) is to be stored. Here: store in memory.
    StatsStorage statsStorage = new InMemoryStatsStorage();         //Alternative: new FileStatsStorage(File), for saving and loading later

    //Attach the StatsStorage instance to the UI: this allows the contents of the StatsStorage to be visualized
    uiServer.attach(statsStorage);

    //Then add the StatsListener to collect this information from the network, as it trains
    int listenerFrequency = 1;
    net2.setListeners(new StatsListener(statsStorage, listenerFrequency));
    //ADD VISUALIZATION CODE HERE - END - <*><*><*><*><*><*><*><*><*><*><*><*><*><*><*><*><*><*><*><*><*><*><*><*>

    // ----- Train the network, evaluating the test set performance at each epoch -----

    int nEpochs = 50;

    log.info(" - net2.toStringFull() - "+net2.summary());

    for (int i = 0; i < nEpochs; i++) {
        net2.fit(trainData2);
        trainData2.reset();
        log.info("Epoch " + i + " complete. Time series evaluation:");

        //Run regression evaluation on our single column input
        RegressionEvaluation evaluation = new RegressionEvaluation(2);

        testData2.reset();
    }

    String pathToSavedNetwork = "src/main/assets/location_next_neural_network_v6_07.zip";
    File savedNetwork = new File(pathToSavedNetwork);

Here is the summary of the network I created in the code I included in the message above:

===================================================================================================================
VertexName (VertexType) nIn,nOut TotalParams ParamsShape Vertex Inputs

encoderInput (InputVertex) -,- - - -
decoderInput (InputVertex) -,- - - -
encoder (LSTM) 2,96 38,016 W:{2,384}, RW:{96,384}, b:{384} [encoderInput]
encoder2 (LSTM) 96,48 27,840 W:{96,192}, RW:{48,192}, b:{192} [encoder]
laststep (LastTimeStepVertex) -,- - - [encoder2]
dup (DuplicateToTimeSeriesVertex) -,- - - [laststep]
decoder-merge (MergeVertex) -,- - - [decoderInput, dup]
decoder (LSTM) 51,48 19,200 W:{51,192}, RW:{48,192}, b:{192} [decoder-merge]
decoder2 (LSTM) 48,96 55,680 W:{48,384}, RW:{96,384}, b:{384} [decoder]
output (RnnOutputLayer) 96,2 194 W:{96,2}, b:{2} [decoder2]

        Total Parameters:  140,930
    Trainable Parameters:  140,930
       Frozen Parameters:  0

===================================================================================================================
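
One detail worth noting when reading the stack trace above: the RecordReaderMultiDataSetIterator is built with two addInput(...) calls and no addOutput(...), so each MultiDataSet it produces carries an empty labels array, which is consistent with the setLabels exception. A sketched correction, assuming the labels reader should feed the network’s single output (the two-input graph would then still need its second input mapped separately):

```java
// Sketch: declare the labels reader as an output, not a second input,
// so the MultiDataSet carries a labels array for the "output" layer.
MultiDataSetIterator trainData2 = new RecordReaderMultiDataSetIterator.Builder(miniBatchSize)
        .addSequenceReader("trainFeatures", trainFeatures)
        .addSequenceReader("trainLabels", trainLabels)
        .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END)
        .addInput("trainFeatures")
        .addOutput("trainLabels")
        .build();
```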

@adonnini you don’t need this architecture. Not even transformers are multi-input. Multi-input literally means a network that looks like this:

input1 >
merge/concat → next node
input2 >
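
In ComputationGraph terms, that shape is just a MergeVertex joining two inputs. A minimal sketch (nIn1, nIn2, and nOut are placeholders, not values from the code above):

```java
// Minimal two-input graph: the inputs are concatenated by a MergeVertex,
// then fed to a single output layer. nIn1, nIn2, nOut are placeholders.
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
        .graphBuilder()
        .addInputs("input1", "input2")
        .addVertex("merge", new MergeVertex(), "input1", "input2")
        .addLayer("out", new OutputLayer.Builder()
                .nIn(nIn1 + nIn2)   // merge concatenates along the feature dimension
                .nOut(nOut)
                .lossFunction(LossFunctions.LossFunction.MSE)
                .build(), "merge")
        .setOutputs("out")
        .build();
```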

Please stick to simpler problems. As I mentioned, the main work and examples for transformers are not even in the old dl4j API; they are in the newer SameDiff framework. If you want to spend time learning something, spend time there. The same concepts, like updaters and dataset iterators, still apply.

I started the implementation of the network based on SameDiff. I use

as my starting point.

I have a few questions, I hope you don’t mind (too much):

  1. With regards to dataset iterators can I still use
    SequenceRecordReaderDataSetIterator to create them?

  2. When creating input and label variables, I noticed that DataType is
    set to FLOAT, and shape to -1. Are these the default values? How do I
    determine what they should be set to for my network?

  3. Similarly, the variable layerSize0 used when defining the hidden
    layer is set to 128. Is this the “default”? How do I determine what it
    should be set to for my network?

  4. When defining layers, how do I specify the type of layer I want to
    create? (e.g. encoder, decoder, selfAttention, LSTM, GlobalPooling etc.,
    etc.), or is this not done when using SameDiff? I have a feeling I am
    missing something important here

  5. Similarly, I do not see anywhere in the code where a normalizer is
    specified/used. Do I simply normalize the data as I have been doing
    (using NormalizerStandardize) before training the network?

I may have other questions later. I’ll try and find answers on my own
before asking you.

Thanks.

  1. Yes, and you can call fit directly with the iterator.
  2. Yes, FLOAT is the default data type, but all variables will require you to specify a data type anyway. -1 is used for unknown shape dimensions.
  3. Layers in general do not have defaults, and you shouldn’t expect them. The same is true for the dl4j API.
  4. Create a samediff instance then just declare the ops you want to use eg:
SameDiff sd = SameDiff.create();
SDVariable var1 = sd.var(...);
SDVariable var2 = sd.var(...);
SDVariable result = sd.math().add(var1, var2);

There are different namespaces you can use to separate out all the various ops.
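
To make that concrete, here is a minimal end-to-end training sketch in the style of the linear regression quick start example. It is only a sketch: nIn and nOut are placeholders, trainData is assumed to be a DataSetIterator, and the placeholder names must match the feature/label mappings:

```java
SameDiff sd = SameDiff.create();

// Placeholders: -1 marks the unknown (minibatch) dimension.
SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, nIn);
SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, nOut);

// Parameters declared as variables.
SDVariable w = sd.var("w", new XavierInitScheme('c', nIn, nOut), DataType.FLOAT, nIn, nOut);
SDVariable b = sd.zero("b", 1, nOut);

// Forward pass and MSE loss, built directly from ops.
SDVariable out = in.mmul(w).add(b);
SDVariable diff = label.sub(out);
SDVariable mse = diff.mul(diff).mean();
sd.setLossVariables(mse);

// Updaters and iterators work the same way as in the dl4j API.
TrainingConfig config = new TrainingConfig.Builder()
        .updater(new Adam(1e-3))
        .dataSetFeatureMapping("input")   // iterator features -> "input" placeholder
        .dataSetLabelMapping("label")     // iterator labels   -> "label" placeholder
        .build();
sd.setTrainingConfig(config);
sd.fit(trainData, 1);                     // fit directly with the iterator, as in point 1
```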

Thanks very much!

So, I can still use NormalizerStandardize?

A)
I understand the general structure of the SameDiff network. It is
very clear. What I still do not understand is how I would specify an
LSTM layer, for example. Are you saying that I would need to specify
SDVariable variables that would define an LSTM?

In the quick start example, when it comes to the layer definition, the
code is

    SDVariable w0 = sd.var("w0", new XavierInitScheme('c', nIn, layerSize0), DataType.FLOAT, nIn, layerSize0);
    SDVariable b0 = sd.zero("b0", 1, layerSize0);
    SDVariable activations0 = sd.nn().tanh(in.mmul(w0).add(b0));

which just specifies the activation, weight initialization, and shape, nothing specific
to the type of layer. How would I define an LSTM layer using SameDiff?
In dl4j there is a specific LSTM builder. Am I supposed to look at the
LSTM Builder code and extract values for the SDVariable variables?

B)
What does layerSize0 refer to?

Thanks

@adonnini

A) Look under sd.nn().lstmLayer(…). The layers you want are more than likely ops.

B) layerSize0 is just the first output size. It’s equivalent to nOut. We’re specifying a weight matrix as a variable.
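
Tying the two answers together, the shapes in the quick start’s weight-as-variable pattern line up as follows (a fragment, not a full program; the comments assume [batch, features] inputs):

```java
// layerSize0 plays the role of nOut: w0 maps nIn features to layerSize0 outputs.
SDVariable in  = sd.placeHolder("in", DataType.FLOAT, -1, nIn);      // [batch, nIn]
SDVariable w0  = sd.var("w0", new XavierInitScheme('c', nIn, layerSize0),
                        DataType.FLOAT, nIn, layerSize0);            // [nIn, layerSize0]
SDVariable b0  = sd.zero("b0", 1, layerSize0);                       // [1, layerSize0], broadcast over the batch
SDVariable act = sd.nn().tanh(in.mmul(w0).add(b0));                  // [batch, layerSize0]
```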

@adonnini , you could PM me if you want some details regarding attention implementation using SameDiff. I got it working, but unfortunately I have no time now to prepare the project for open-sourcing.

Hi Adam,

I have spent the past few days trying to understand why

doesn’t work when I use my dataset as input.

As a reminder, I define my input dataset as follows:

    trainFeatures = new CSVSequenceRecordReader();
    trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath() + "/%d.csv", 0, lastTrainCount));
    trainLabels = new CSVSequenceRecordReader();
    trainLabels.initialize(new NumberedFileInputSplit(labelsDirTrain.getAbsolutePath() + "/%d.csv", 0, lastTrainCount));

    trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize, numLabelClasses,
            true, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

with numLabelClasses set to -1

The only potential explanation I have come up with is that my input dataset combines features and labels, while TrainingConfig expects separate feature and label data structures. This is just a guess on my part. I confess that I am lost.

Could you help me?

By the way, here is the error Ex2_LinearRegression.java fails with when I use my dataset as input:

ShapeUtils::evalShapeForMatmul static method: input shapes are inconsistent: xDim 28 != yDim 6
Exception in thread “main” java.lang.RuntimeException: Op matmul with name matmul failed to execute. Here is the error from c++:
at org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner.calculateOutputShape(NativeOpExecutioner.java:1672)
at org.nd4j.linalg.api.ops.DynamicCustomOp.calculateOutputShape(DynamicCustomOp.java:696)
at org.nd4j.autodiff.samediff.internal.InferenceSession.getAndParameterizeOp(InferenceSession.java:1363)
at org.nd4j.autodiff.samediff.internal.InferenceSession.getAndParameterizeOp(InferenceSession.java:68)
at org.nd4j.autodiff.samediff.internal.AbstractSession.output(AbstractSession.java:531)
at org.nd4j.autodiff.samediff.SameDiff.directExecHelper(SameDiff.java:2927)
at org.nd4j.autodiff.samediff.SameDiff.directExecHelper(SameDiff.java:2890)
at org.nd4j.autodiff.samediff.SameDiff.outputHelper(SameDiff.java:2683)
at org.nd4j.autodiff.samediff.SameDiff.output(SameDiff.java:2518)
at org.nd4j.autodiff.samediff.config.OutputConfig.exec(OutputConfig.java:132)
at org.nd4j.autodiff.samediff.SameDiff.output(SameDiff.java:2473)
at org.deeplearning4j.examples.quickstart.modeling.recurrent.LocationNextNeuralNetworkV6.sameDiff2(LocationNextNeuralNetworkV6.java:933)
at org.deeplearning4j.examples.quickstart.modeling.recurrent.LocationNextNeuralNetworkV6.main(LocationNextNeuralNetworkV6.java:193)

@adonnini mind posting your samediff.summary()?

Here it is:

— Summary —
Variables: 9 (2 with arrays)
Functions: 5
SameDiff Function Defs: 0
Loss function variables:

— Variables —

  • Name - - Array Shape - - Variable Type - - Data Type- - Output Of Function - - Inputs To Functions -
    add - ARRAY FLOAT add(add) [subtract]
    bias VARIABLE FLOAT [add]
    input [32, 6, -1] PLACEHOLDER FLOAT [matmul]
    label [32, 2, 1] PLACEHOLDER FLOAT [subtract]
    matmul - ARRAY FLOAT matmul(matmul) [add]
    mse - ARRAY FLOAT reduce_mean(reduce_mean)
    square - ARRAY FLOAT square(square) [reduce_mean]
    subtract - ARRAY FLOAT subtract(subtract) [square]
    weights [32, 6, 1] VARIABLE FLOAT [matmul]

Also, here is the shape of a dataset in the trainData iterator:

Printing traindata dataset shape - 1
[32, 6, 57]

I don’t know if this is any help. I changed my code to loop through the trainData iterator and process one dataset at a time, ensuring that the SDVariable weights and the trainData dataset being processed have exactly the same shape.

The code fails with the same exact error. Here is the execution output

Printing traindata dataset shape
[32, 6, 28]
dim2 - 28
======================================================= -
weights.getShapeDescriptor().toString() - 1 - [3,32, 6, 28,168, 28, 1,8192,1,c]
Printing sd information
SameDiff(nVars=9,nOps=5)
— Summary —
Variables: 9 (2 with arrays)
Functions: 5
SameDiff Function Defs: 0
Loss function variables:

— Variables —

  • Name - - Array Shape - - Variable Type - - Data Type- - Output Of Function - - Inputs To Functions -
    add - ARRAY FLOAT add(add) [subtract]
    bias VARIABLE FLOAT [add]
    input [32, 6, -1] PLACEHOLDER FLOAT [matmul]
    label [32, 2, 1] PLACEHOLDER FLOAT [subtract]
    matmul - ARRAY FLOAT matmul(matmul) [add]
    mse - ARRAY FLOAT reduce_mean(reduce_mean)
    square - ARRAY FLOAT square(square) [reduce_mean]
    subtract - ARRAY FLOAT subtract(subtract) [square]
    weights [32, 6, 28] VARIABLE FLOAT [matmul]

— Functions —
- Function Name - - Op - - Inputs - - Outputs -
0 matmul Mmul [input, weights] [matmul]
1 add AddOp [matmul, bias] [add]
2 subtract SubOp [label, add] [subtract]
3 square Square [subtract] [square]
4 reduce_mean Mean [square] [mse]

ShapeUtils::evalShapeForMatmul static method: input shapes are inconsistent: xDim 28 != yDim 6
Exception in thread “main” java.lang.RuntimeException: Op matmul with name matmul failed to execute. Here is the error from c++:
at org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner.calculateOutputShape(NativeOpExecutioner.java:1672)
at org.nd4j.linalg.api.ops.DynamicCustomOp.calculateOutputShape(DynamicCustomOp.java:696)
at org.nd4j.autodiff.samediff.internal.InferenceSession.getAndParameterizeOp(InferenceSession.java:1363)
at org.nd4j.autodiff.samediff.internal.InferenceSession.getAndParameterizeOp(InferenceSession.java:68)
at org.nd4j.autodiff.samediff.internal.AbstractSession.output(AbstractSession.java:531)
at org.nd4j.autodiff.samediff.SameDiff.directExecHelper(SameDiff.java:2927)
at org.nd4j.autodiff.samediff.SameDiff.batchOutputHelper(SameDiff.java:2870)
at org.nd4j.autodiff.samediff.SameDiff.output(SameDiff.java:2835)
at org.nd4j.autodiff.samediff.SameDiff.output(SameDiff.java:2808)
at org.nd4j.autodiff.samediff.config.BatchOutputConfig.output(BatchOutputConfig.java:183)
at org.nd4j.autodiff.samediff.SameDiff.output(SameDiff.java:2764)
at org.deeplearning4j.examples.quickstart.modeling.recurrent.LocationNextNeuralNetworkV6.sameDiff2(LocationNextNeuralNetworkV6.java:974)
at org.deeplearning4j.examples.quickstart.modeling.recurrent.LocationNextNeuralNetworkV6.main(LocationNextNeuralNetworkV6.java:193)

@adonnini your dimensions are still clearly off. Could you turn on the executioner debug and verbose with:
Nd4j.getExecutioner().enableVerboseMode(true);
Nd4j.getExecutioner().enableDebugMode(true);

Matrix multiplies require very specific inputs. Columns must equal rows. Otherwise it throws that error. You’re either defining the graph wrong or your input data is wrong.
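
The rule itself can be sketched in plain Java, independent of ND4J (matmulShape is a hypothetical helper, not a library API; the second call reproduces the exact mismatch from the stack trace):

```java
import java.util.Arrays;

public class MatmulShape {
    // For C = A x B with A of shape [m, k] and B of shape [k2, n],
    // the inner dimensions must match: k == k2. ND4J's
    // ShapeUtils::evalShapeForMatmul enforces the same rule.
    static int[] matmulShape(int[] a, int[] b) {
        if (a[1] != b[0]) {
            throw new IllegalArgumentException("xDim " + a[1] + " != yDim " + b[0]);
        }
        return new int[]{a[0], b[1]};
    }

    public static void main(String[] args) {
        // [32, 6] x [6, 1] -> [32, 1]: inner dimensions agree.
        System.out.println(Arrays.toString(matmulShape(new int[]{32, 6}, new int[]{6, 1})));
        // [6, 28] x [6, 28]: inner dimensions 28 vs 6 disagree,
        // the same "xDim 28 != yDim 6" as in the error above.
        try {
            matmulShape(new int[]{6, 28}, new int[]{6, 28});
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}
```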

Thanks. I added the executioner debug and verbose flags. I am not sure how to interpret the output; I do not see any entries relevant to the definition/execution of the network. Should I send you a partial or the entire debug/verbose output?

As you say, matrix multiplies require that columns equal rows. I thought that the purpose of SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END in the line below was, to some extent, to do just that:

    trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize, numLabelClasses,
            true, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

My feature files all have different numbers of rows, which are not necessarily equal to the number of columns (features).

More specifically, in my case, which matrices need to have matching rows and columns for the matrix multiply?

As far as I know, graph definition is correct as the code I am using comes straight from

I am sorry for these questions, which show my ignorance and take up your time. I have looked at the code for many (all) of the examples, and frankly I do not understand how shapes are set (e.g. MNISTCNN.java in the deeplearning4j-examples repository on GitHub).

Thanks

@adonnini No, just show me the output. It’ll show me the shapes and where your model fails.