Nd4j.scatterUpdate slower than simple CPU implementation

I’m using the EmbeddingLayer and EmbeddingSequenceLayer, but I noticed they were slow.

I tracked it down to the call to Nd4j.scatterUpdate.

It felt unreasonably slow, so I tried doing the update in plain Java.

From this:

        INDArray weightGradients = this.gradientViews.get("W");
        INDArray indices = Nd4j.createFromArray(this.indexes);
        Nd4j.scatterUpdate(ScatterUpdate.UpdateOp.ASSIGN, weightGradients, indices, epsilon, WEIGHT_DIM);

To this:

        INDArray weightGradients = this.gradientViews.get("W");

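        // Accumulate the updates in plain Java arrays instead of calling scatterUpdate.
        float[][] weightGradientUpdates = new float[(int) this.layerConf().getDictionarySize()][(int) nOut];
        float[][] eps = new float[(int) nOut][(int) epsilon.size(0)];
        for (int j = 0; j < nOut; j++) {
            INDArray column = epsilon.getColumn(j);
            eps[j] = column.data().asFloat();
        }

        for (int i = 0; i < indexes.length; i++) {
            for (int j = 0; j < nOut; j++) {
                // Assumption: copy each epsilon row into the row selected by its index (mirrors ASSIGN).
                weightGradientUpdates[(int) indexes[i]][j] = eps[j][i];
            }
        }

        INDArray reshape = Nd4j.create(weightGradientUpdates).reshape(this.layerConf().getDictionarySize(), 1);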

This makes the code 100x faster.

The size of weightGradients does matter in the CPU workaround: if it gets too large (>1M), RAM usage grows too much and the call to Nd4j.create becomes slow. So this CPU workaround doesn’t scale.
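
To make the workaround and its memory cost concrete, here is a minimal standalone sketch of the same idea using only plain arrays (no ND4J). The class name `ScatterSketch`, the helper names, the sample indexes, and the `nOut` of 128 are all hypothetical, chosen for illustration. It assumes ASSIGN semantics where a later duplicate index overwrites an earlier one; the thread does not confirm how Nd4j.scatterUpdate orders duplicates.

```java
import java.util.Arrays;

final class ScatterSketch {

    // Scatter-assign: copy row i of epsilon into row indexes[i] of a dense output.
    // With duplicate indexes, the last matching row wins in this sketch.
    static float[][] scatterAssign(int[] indexes, float[][] epsilon, int dictionarySize) {
        int nOut = epsilon.length == 0 ? 0 : epsilon[0].length;
        float[][] out = new float[dictionarySize][nOut]; // zero-initialised by Java
        for (int i = 0; i < indexes.length; i++) {
            System.arraycopy(epsilon[i], 0, out[indexes[i]], 0, nOut);
        }
        return out;
    }

    // Bytes needed for the dense float data alone (4 bytes per float),
    // ignoring per-row object overhead on the JVM.
    static long denseBytes(long dictionarySize, long nOut) {
        return dictionarySize * nOut * Float.BYTES;
    }

    public static void main(String[] args) {
        int[] indexes = {2, 0, 2};                          // hypothetical token ids
        float[][] epsilon = {{1f, 2f}, {3f, 4f}, {5f, 6f}}; // one row per index
        System.out.println(Arrays.deepToString(scatterAssign(indexes, epsilon, 4)));
        // prints [[3.0, 4.0], [0.0, 0.0], [5.0, 6.0], [0.0, 0.0]]
        System.out.println(denseBytes(1_000_000L, 128L));   // prints 512000000
    }
}
```

The dense buffer is allocated for every dictionary row even when only a few rows are touched, which is consistent with the memory growth described above: with one million rows and an assumed nOut of 128, the float data alone is roughly 512 MB per call, before Nd4j.create copies it.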

These are my backend dependencies, and I have cuDNN set up:


@ebeaufay can you set up a reproducer for me and file a GitHub issue?

Yes, I added an issue: "Nd4j.ScatterUpdates has a large overhead" (Issue #10029, deeplearning4j/deeplearning4j).

@ebeaufay thanks, I’ll take a look!