I tried to set up a single first-layer definition from my split classes for the different tasks; maybe it helps somebody (WARNING: not tested, not sure if it works). It is only a first step, weight init is not finished and nothing is debugged, so use at your own risk. A rough usage sketch follows after the class:
import java.util.Map;

import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
/**
* first try to implement a complete transformer in one layer
*
* @author mrrobot
*
*/
public class TransformerLayer extends SameDiffLayer {
/**
* serialization id
*/
private static final long serialVersionUID = 5974498113062600619L;
// param names
private static final String WEIGHT_KEY_QUERY_PROJECTION = "Wq";
private static final String WEIGHT_KEY_KEY_PROJECTION = "Wk";
private static final String WEIGHT_KEY_VALUE_PROJECTION = "Wv";
private static final String WEIGHT_KEY_OUT_PROJECTION = "Wo";
private static final String GAIN_KEY_ADD1 = "Gadd1";
private static final String GAIN_KEY_ADD2 = "Gadd2";
private static final String WEIGHT_KEY_FFN = "Wffn";
// TODO add bias parameter
private int embWidth;
private int seqLen;
private int heads;
private int headSize;
/**
* builder constructor
*
* @param builder
*/
public TransformerLayer(Builder builder) {
embWidth = builder.embWidth;
seqLen = builder.seqLen;
heads = builder.heads;
headSize = embWidth / heads; // size of a single attention head
weightInit = WeightInit.XAVIER; // default init for now, the builder does not expose weight init yet
}
@Override
public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map<String, SDVariable> paramTable, SDVariable mask) {
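// layerInput arrives as [minibatch, embWidth, seqLen] (NCW), enforced by the input preprocessor below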
SDVariable result;
SDVariable attention;
SDVariable Wffn = paramTable.get(WEIGHT_KEY_FFN);
SDVariable gain1 = paramTable.get(GAIN_KEY_ADD1);
SDVariable gain2 = paramTable.get(GAIN_KEY_ADD2);
// self-attention part: multi head dot product attention if heads > 1, queries, keys and values are all the layer input
if(heads > 1){
SDVariable Wq = paramTable.get(WEIGHT_KEY_QUERY_PROJECTION);
SDVariable Wk = paramTable.get(WEIGHT_KEY_KEY_PROJECTION);
SDVariable Wv = paramTable.get(WEIGHT_KEY_VALUE_PROJECTION);
SDVariable Wo = paramTable.get(WEIGHT_KEY_OUT_PROJECTION);
attention = sameDiff.nn.multiHeadDotProductAttention(getLayerName(), layerInput, layerInput, layerInput, Wq, Wk, Wv, Wo, mask, true);
}else{
attention = sameDiff.nn.dotProductAttention(getLayerName(), layerInput, layerInput, layerInput, mask, true);
}
// residual connection ("add") and layer normalization over the embedding dimension (dim 1 in NCW)
SDVariable add1 = attention.add(layerInput);
SDVariable norm1 = sameDiff.nn.layerNorm(add1, gain1, false, new int[] {1});
// position-wise ffn part: the same linear projection is applied to every time step of norm1
// (a complete transformer ffn would use two projections with a nonlinearity in between)
SDVariable[] inputSlices = sameDiff.unstack(norm1, 2, seqLen);
int timeSteps = inputSlices.length;
SDVariable[] outputSlices = new SDVariable[timeSteps];
for (int i = 0; i < timeSteps; i++) {
outputSlices[i] = inputSlices[i].mmul(Wffn);
outputSlices[i] = sameDiff.expandDims(outputSlices[i], 2);
}
SDVariable ffnout = sameDiff.concat(2, outputSlices);
// add and norm
SDVariable add2 = ffnout.add(norm1);
result = sameDiff.nn.layerNorm(add2, gain2, false, new int[] {1});
return result;
}
@Override
public void defineParameters(SDLayerParams params) {
params.clear();
// check for multi head attention parameter
if(heads > 1){
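// projection weights per head: [heads, headSize, embWidth], the shape multiHeadDotProductAttention expects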
params.addWeightParam(WEIGHT_KEY_QUERY_PROJECTION, heads, headSize, embWidth);
params.addWeightParam(WEIGHT_KEY_KEY_PROJECTION, heads, headSize, embWidth);
params.addWeightParam(WEIGHT_KEY_VALUE_PROJECTION, heads, headSize, embWidth);
params.addWeightParam(WEIGHT_KEY_OUT_PROJECTION, heads * headSize, embWidth);
}
// ffn parameters
params.addWeightParam(WEIGHT_KEY_FFN, embWidth, embWidth);
// layer normalization gains, one value per embedding dimension (layerNorm runs over dim 1 / embWidth)
params.addWeightParam(GAIN_KEY_ADD1, embWidth);
params.addWeightParam(GAIN_KEY_ADD2, embWidth);
}
@Override
public void initializeParameters(Map<String, INDArray> params) {
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
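// run the weight init outside of any workspace, same pattern as the built-in DL4J SameDiff layers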
for (Map.Entry<String, INDArray> e : params.entrySet()) {
if(e.getKey().equals(WEIGHT_KEY_OUT_PROJECTION)){
WeightInitUtil.initWeights(embWidth, headSize, e.getValue().shape(), weightInit, null, 'c', e.getValue());
} else if (!e.getKey().startsWith("Gadd")){
WeightInitUtil.initWeights(heads * headSize, embWidth, e.getValue().shape(), weightInit, null, 'c', e.getValue());
}
}
}
params.get(GAIN_KEY_ADD1).assign(1.0);
params.get(GAIN_KEY_ADD2).assign(1.0);
}
/**
* ensure NCW input format so the input is compatible with the attention implementation
*/
@Override
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW, getLayerName());
}
/**
* output type: recurrent (NCW) with embWidth size and seqLen time steps
*/
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType == null || inputType.getType() != InputType.Type.RNN) {
throw new IllegalStateException("Invalid input for transformer layer (layer index = " + layerIndex
+ ", layer name = \"" + getLayerName() + "\"): expect RNN input type with size > 0. Got: "
+ inputType);
}
return InputType.recurrent(embWidth, seqLen);
}
/**
* builder class for the custom transformer layer
*/
public static class Builder {
public int embWidth;
public int heads;
public int seqLen;
public Builder embWidth(int width) {
embWidth = width;
return this;
}
public Builder nHeads(int heads) {
this.heads = heads;
return this;
}
public Builder seqLen(int len) {
seqLen = len;
return this;
}
public TransformerLayer build() {
return new TransformerLayer(this);
}
}
}
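For completeness, a rough usage sketch of how the layer could be plugged into a network config. Also untested; the class name, the Adam updater, the RnnOutputLayer on top and all sizes (embWidth 64, 4 heads, seqLen 128, 10 classes) are just placeholders for the example:

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class TransformerLayerExample {
    public static void main(String[] args) {
        int embWidth = 64;   // embedding size, placeholder value
        int seqLen = 128;    // sequence length, placeholder value
        int nClasses = 10;   // output classes, placeholder value

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .updater(new Adam(1e-3))
                .list()
                // the custom transformer layer defined above
                .layer(new TransformerLayer.Builder()
                        .embWidth(embWidth)
                        .nHeads(4)
                        .seqLen(seqLen)
                        .build())
                // placeholder output layer for a sequence classification setup
                .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX)
                        .nIn(embWidth)
                        .nOut(nClasses)
                        .build())
                .setInputType(InputType.recurrent(embWidth, seqLen))
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        System.out.println(net.summary());
    }
}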