Nd4j.scatterUpdate is slower than a simple CPU implementation

I’m using EmbeddingLayer and EmbeddingSequenceLayer, but I noticed they were slow.

I tracked the slowdown down to the call to Nd4j.scatterUpdate in the backward pass.

It felt unreasonably slow, so I tried doing the same update in plain Java.

From this:

        INDArray weightGradients = this.gradientViews.get("W");
        weightGradients.assign(0);
        INDArray indices = Nd4j.createFromArray(this.indexes);
        Nd4j.scatterUpdate(ScatterUpdate.UpdateOp.ASSIGN, weightGradients, indices, epsilon, WEIGHT_DIM);

To this:

        INDArray weightGradients = this.gradientViews.get("W");
        weightGradients.assign(0);

        int dictionarySize = (int) this.layerConf().getDictionarySize();
        int embeddingSize = (int) nOut;

        // Dense heap buffer with the same [dictionarySize, nOut] layout as W
        float[][] weightGradientUpdates = new float[dictionarySize][embeddingSize];
        // Copy epsilon ([indexes.length, nOut]) into a Java array in one call;
        // reading column views through data().asFloat() only lines up when nOut == 1
        float[][] eps = epsilon.toFloatMatrix();

        // Scatter-add: row indexes[i] accumulates row i of epsilon
        for (int i = 0; i < indexes.length; i++) {
            for (int j = 0; j < embeddingSize; j++) {
                weightGradientUpdates[indexes[i]][j] += eps[i][j];
            }
        }

        // Nd4j.create(float[][]) is already [dictionarySize, nOut], so no reshape is needed
        weightGradients.addi(Nd4j.create(weightGradientUpdates));

This made the code 100x faster.
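The Java loop above is a scatter-add: row `indexes[i]` of the weight gradient accumulates row `i` of `epsilon`, so an index that appears several times in a batch receives the sum of its rows. Here is a self-contained sketch of that operation on plain arrays (the class and method names are made up for illustration), next to an overwrite-style scatter for comparison. One thing worth double-checking: the original call passes `UpdateOp.ASSIGN`. Assuming ASSIGN overwrites rows rather than accumulating, as the name suggests, it would keep only one row per repeated index, so the two snippets would only agree when `indexes` has no duplicates.

```java
import java.util.Arrays;

public class ScatterAddSketch {

    // Scatter-add: row indexes[i] of target accumulates row i of updates.
    // Repeated indices are summed, like the += loop in the workaround.
    public static void scatterAdd(float[][] target, int[] indexes, float[][] updates) {
        if (indexes.length != updates.length) {
            throw new IllegalArgumentException("need exactly one update row per index");
        }
        for (int i = 0; i < indexes.length; i++) {
            float[] dst = target[indexes[i]];
            float[] src = updates[i];
            for (int j = 0; j < src.length; j++) {
                dst[j] += src[j];
            }
        }
    }

    // Scatter-assign for comparison: each write replaces the whole row,
    // so for a repeated index only the last row survives in this sequential version.
    public static void scatterAssign(float[][] target, int[] indexes, float[][] updates) {
        if (indexes.length != updates.length) {
            throw new IllegalArgumentException("need exactly one update row per index");
        }
        for (int i = 0; i < indexes.length; i++) {
            System.arraycopy(updates[i], 0, target[indexes[i]], 0, updates[i].length);
        }
    }

    public static void main(String[] args) {
        int[] indexes = {2, 0, 2};                      // index 2 appears twice in the batch
        float[][] eps = {{1f, 1f}, {5f, 5f}, {3f, 4f}}; // one gradient row per index

        float[][] added = new float[3][2];
        scatterAdd(added, indexes, eps);

        float[][] assigned = new float[3][2];
        scatterAssign(assigned, indexes, eps);

        System.out.println(Arrays.deepToString(added));    // [[5.0, 5.0], [0.0, 0.0], [4.0, 5.0]]
        System.out.println(Arrays.deepToString(assigned)); // [[5.0, 5.0], [0.0, 0.0], [3.0, 4.0]]
    }
}
```

In the sequential assign variant the surviving row depends on write order; a parallel GPU kernel need not give the same guarantee, which is another reason to prefer an accumulating op for gradients.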

However, the CPU workaround is sensitive to the size of the weight gradient: once it gets too large (>1M), RAM usage grows too much and the call to Nd4j.create becomes slow. So this CPU workaround doesn’t scale.
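If the number of distinct indices in a batch is much smaller than the dictionary size, one way around the memory growth is to accumulate only the rows that are actually touched, so the buffer is `uniqueIndexes x nOut` instead of `dictionarySize x nOut`. A minimal sketch in plain Java, assuming that sparsity holds (the class and method names are hypothetical):

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;

public class SparseGradientSketch {

    // Accumulates one gradient row per distinct index, so memory is
    // O(uniqueIndexes * nOut) rather than O(dictionarySize * nOut).
    // Map keys keep first-seen order; repeated indices are summed.
    public static Map<Integer, float[]> accumulateRows(int[] indexes, float[][] eps) {
        if (indexes.length != eps.length) {
            throw new IllegalArgumentException("need exactly one epsilon row per index");
        }
        Map<Integer, float[]> rows = new LinkedHashMap<>();
        for (int i = 0; i < indexes.length; i++) {
            float[] src = eps[i];
            float[] dst = rows.computeIfAbsent(indexes[i], k -> new float[src.length]);
            for (int j = 0; j < src.length; j++) {
                dst[j] += src[j];
            }
        }
        return rows;
    }

    public static void main(String[] args) {
        // Index 1_000_000 is touched, but no dictionary-sized array is allocated
        int[] indexes = {7, 1_000_000, 7};
        float[][] eps = {{1f, 2f}, {0.5f, 0.5f}, {3f, 4f}};

        Map<Integer, float[]> rows = accumulateRows(indexes, eps);
        for (Map.Entry<Integer, float[]> e : rows.entrySet()) {
            System.out.println(e.getKey() + " -> " + Arrays.toString(e.getValue()));
        }
        // prints:
        // 7 -> [4.0, 6.0]
        // 1000000 -> [0.5, 0.5]
    }
}
```

Writing the compact rows back would still need an ND4J step, for example something like `weightGradients.getRow(index).addi(Nd4j.create(row))` per entry. That part is untested here, and issuing one small op per row has its own overhead, so it may or may not beat `Nd4j.scatterUpdate` on the GPU.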

These are my backend dependencies, and I have cuDNN set up:

        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-cuda-11.6</artifactId>
            <version>1.0.0-M2.1</version>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-cuda-11.6</artifactId>
            <version>1.0.0-M2.1</version>
            <classifier>windows-x86_64-cudnn</classifier>
        </dependency>

@ebeaufay can you set up a reproducer for me and file a GitHub issue?

Yes, I added an issue: "Nd4j.ScatterUpdates has a large overhead" (Issue #10029, deeplearning4j/deeplearning4j on GitHub).

@ebeaufay Thanks, I’ll take a look!