I’m using the EmbeddingLayer and SequenceEmbeddingLayer, but I noticed they were slow.
I tracked it down to the call to Nd4j.scatterUpdate.
It felt unreasonably slow, so I tried doing the update in plain Java instead, going
from this:
INDArray weightGradients = this.gradientViews.get("W");
weightGradients.assign(0);
INDArray indices = Nd4j.createFromArray(this.indexes);
Nd4j.scatterUpdate(ScatterUpdate.UpdateOp.ASSIGN, weightGradients, indices, epsilon, WEIGHT_DIM);
to this:
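INDArray weightGradients = this.gradientViews.get("W");
weightGradients.assign(0);
// Dense accumulator with one row per dictionary entry
float[][] weightGradientUpdates = new float[(int) this.layerConf().getDictionarySize()][(int) nOut];
// Copy epsilon to the heap column by column: eps[j][i] is output j of example i
float[][] eps = new float[(int) nOut][(int) epsilon.size(0)];
for (int j = 0; j < nOut; j++) {
    // toFloatVector() copies only this column; data() on a view can expose the whole underlying buffer
    eps[j] = epsilon.getColumn(j).toFloatVector();
}
// Sum each example's epsilon row into the row of its embedding index
for (int i = 0; i < indexes.length; i++) {
    for (int j = 0; j < nOut; j++) {
        weightGradientUpdates[indexes[i]][j] += eps[j][i];
    }
}
// Nd4j.create(float[][]) already has shape [dictionarySize, nOut], matching weightGradients
INDArray updates = Nd4j.create(weightGradientUpdates);
weightGradients.addi(updates);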
This makes the code 100x faster.
The size of weightGradients does matter for this CPU workaround: once it gets too large (>1M), RAM usage grows too much and the call to Nd4j.create becomes slow. So this CPU workaround doesn’t scale.
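For reference, here is a minimal plain-Java sketch of the same accumulation that only allocates rows for indices that actually occur in the batch, instead of a dense dictionarySize x nOut float[][]. It is an illustration under assumptions, not a tested replacement for the layer code: the class and method names are hypothetical, and epsilon is assumed to have already been copied into a heap array laid out as [example][output] (the transpose of the eps array above).

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

// Hypothetical helper, not part of DL4J or ND4J.
public class SparseRowAccumulator {

    /**
     * Sums epsilon rows by embedding index, keeping only the rows that occur.
     *
     * @param indexes embedding index of each example in the minibatch
     * @param epsilon assumed layout [example][output], already copied to the heap
     * @param nOut    number of outputs per embedding row
     * @return map from embedding index to the summed gradient row for that index
     */
    public static Map<Integer, float[]> accumulate(int[] indexes, float[][] epsilon, int nOut) {
        Map<Integer, float[]> rows = new HashMap<>();
        for (int i = 0; i < indexes.length; i++) {
            // Allocate a row the first time an index is seen
            float[] acc = rows.computeIfAbsent(indexes[i], k -> new float[nOut]);
            for (int j = 0; j < nOut; j++) {
                // Repeated indices are summed, as in the dense loop above
                acc[j] += epsilon[i][j];
            }
        }
        return rows;
    }

    public static void main(String[] args) {
        int[] indexes = {2, 0, 2};
        float[][] epsilon = {{1f, 2f}, {3f, 4f}, {5f, 6f}};
        Map<Integer, float[]> rows = accumulate(indexes, epsilon, 2);
        rows.forEach((row, values) -> System.out.println(row + " -> " + Arrays.toString(values)));
    }
}
```

In this example index 2 appears twice, so its row holds the sum of the first and third epsilon rows (6.0 and 8.0), while index 0 holds 3.0 and 4.0. Memory then grows with the number of distinct indices in the batch rather than with the dictionary size. Writing the accumulated rows back into weightGradients is not shown, and the cost of that step on the CUDA backend is not measured here.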
These are my backend dependencies, and I have cuDNN set up:
<dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-cuda-11.6</artifactId>
    <version>1.0.0-M2.1</version>
</dependency>
<dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-cuda-11.6</artifactId>
    <version>1.0.0-M2.1</version>
    <classifier>windows-x86_64-cudnn</classifier>
</dependency>