How to use sd.nn.batchNorm(…) in Deeplearning4j?

    SDVariable mean = sd.var("mean", new XavierFanInInitScheme('c', NerUtil.MAX_SENTENCE_LENGTH * 6), NerUtil.MAX_SENTENCE_LENGTH * 6);
    SDVariable variance = sd.var("variance", new DistributionInitScheme('c', new UniformDistribution(0, 1)), NerUtil.MAX_SENTENCE_LENGTH * 6);
    SDVariable gamma = sd.var("gamma", new XavierFanInInitScheme('c', NerUtil.MAX_SENTENCE_LENGTH * 6), DataType.FLOAT, NerUtil.MAX_SENTENCE_LENGTH * 6);
    SDVariable beta = sd.var("beta", new XavierFanInInitScheme('c', NerUtil.MAX_SENTENCE_LENGTH * 6), DataType.FLOAT, NerUtil.MAX_SENTENCE_LENGTH * 6);

    SDVariable batchNorm1 = sd.nn.batchNorm("batchNorm1", bertOutput, mean, variance, gamma, beta, true, true, 1e-8, 2);
    SDVariable tanh1 = sd.nn.tanh("tanh1", batchNorm1);
    SDVariable mmul1 = sd.mmul("mmul1", tanh1, wOut);

Should mean, variance, gamma, and beta be variables, constants, or something else? I just don't know how to use this method.

Usually, beta is 0 and gamma is 1. Here is a working example of 1D batch normalization:

    // f is the input SDVariable (a length-4 vector here), defined earlier
    val m1 = sd.`var`("m1", 1)
    val m2 = sd.`var`("m2", 1)
    val gamma = sd.`var`("m3", Nd4j.ones(1))
    val beta = sd.`var`("m4", Nd4j.zeros(1))
    sd.nn.batchNorm("batch", f.reshape(1, 1, 4), m1, m2, gamma, beta, 0.00001, 1)
    println(sd.output(null, "batch").get("batch"))
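
By the way, gamma and beta do not have to be trainable variables: if you don't want them trained, you can pass them in as constants of the right shape. Below is a minimal sketch of that variant (the names "mean"/"variance"/"gamma"/"beta" are just illustrative, and the input is a fresh length-4 vector, so it is not meant to reproduce the numbers above):

    // Sketch: same 1D batch norm, but with mean/variance/gamma/beta as constants
    // (nothing here is trained); names are illustrative only.
    val sd = SameDiff.create()
    val f = sd.`var`("f", Nd4j.create(floatArrayOf(1f, 2f, 3f, 4f)))
    val mean = sd.constant("mean", Nd4j.zeros(DataType.FLOAT, 1))
    val variance = sd.constant("variance", Nd4j.ones(DataType.FLOAT, 1))
    val gamma = sd.constant("gamma", Nd4j.ones(DataType.FLOAT, 1))
    val beta = sd.constant("beta", Nd4j.zeros(DataType.FLOAT, 1))
    sd.nn.batchNorm("batch", f.reshape(1, 1, 4), mean, variance, gamma, beta, 1e-5, 1)
    println(sd.output(null, "batch").get("batch"))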

Because doDiff() is not implemented correctly, I cannot get gradients.
sd.calculateGradients(null, "s1", "s2") produces an error.

Are you using an old version? It is implemented in beta7:

Thank you for the reply. I am using beta7.
Here is the actual program in Kotlin; I get the same result in Java.

46 val vv = Nd4j.create(arrayOf(floatArrayOf(1.0F, 2.0F, 3.0F, 4.0F)))
47 var sd = SameDiff.create()
48 val s2 = sd.`var`("s1", vv)
49 val m1 = sd.`var`("m1", 1)
50 val m2 = sd.`var`("m2", 1)
51 val m3 = sd.`var`("m3", Nd4j.ones(1))
52 val m4 = sd.`var`("m4", Nd4j.zeros(1))
53 sd.nn.batchNorm("batch", s2.reshape(1,1,4), m1, m2, m3, m4, 0.00001, 1)
54 sd.setLossVariables("batch")
55 println(sd.output(null, "batch").get("batch"))
56 sd.calculateGradients(null, "s1")

And I get the following result. I am investigating what is going on.

[[[ 316.2278, 632.4556, 948.6833, 1264.9111]]]
Exception in thread "main" java.lang.IllegalStateException: Expected 3 to 5 input datatypes for class org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNormDerivative, got [FLOAT, FLOAT, FLOAT, FLOAT, FLOAT, FLOAT]
at org.nd4j.common.base.Preconditions.throwStateEx(Preconditions.java:641)
at org.nd4j.common.base.Preconditions.checkState(Preconditions.java:340)
at org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm.calculateOutputDataTypes(BatchNorm.java:199)
at org.nd4j.autodiff.samediff.SameDiff.generateOutputVariableForOp(SameDiff.java:3807)
at org.nd4j.linalg.api.ops.DynamicCustomOp.outputVariables(DynamicCustomOp.java:230)
at org.nd4j.linalg.api.ops.DynamicCustomOp.outputVariables(DynamicCustomOp.java:213)
at org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm.doDiff(BatchNorm.java:193)
at org.nd4j.autodiff.functions.DifferentialFunction.diff(DifferentialFunction.java:559)
at org.nd4j.autodiff.samediff.SameDiff$1.define(SameDiff.java:4427)
at org.nd4j.autodiff.samediff.SameDiff.defineFunction(SameDiff.java:3971)
at org.nd4j.autodiff.samediff.SameDiff.defineFunction(SameDiff.java:3956)
at org.nd4j.autodiff.samediff.SameDiff.createGradFunction(SameDiff.java:4167)
at org.nd4j.autodiff.samediff.SameDiff.createGradFunction(SameDiff.java:4074)
at org.nd4j.autodiff.samediff.SameDiff.calculateGradientsAndOutputs(SameDiff.java:4012)
at org.nd4j.autodiff.samediff.SameDiff.calculateGradients(SameDiff.java:3994)
at org.nd4j.autodiff.samediff.SameDiff.calculateGradients(SameDiff.java:3982)
at com.fujitsu.labs.neuralnetwork.Ex2_LinearRegression.main(Ex2_LinearRegression.kt:56)

That does indeed look like a bug to me, and I’ve opened an issue here to track the progress of resolving it:

Workaround Suggestion

As the underlying ops themselves work properly, you can work around the issue, but doing so requires a bit of insight into how ops are defined for SameDiff.

In principle, you’d create a subclass of each of the BatchNorm op and the BatchNormDerivative op. The BatchNorm subclass is only there so that it can reference your fixed BatchNormDerivative class.

The fixed BatchNormDerivative will require a bit more fixing than just the input type check.

If you feel up to it, you can check the other Derivative classes to see how they should work. If not, you can wait for us to fix it, and then either copy that fix into your own project, upgrade to SNAPSHOTS, or wait for the next release.

Here’s one possible workaround:

import org.nd4j.autodiff.samediff.SDVariable
import org.nd4j.autodiff.samediff.SameDiff
import org.nd4j.linalg.api.buffer.DataType
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm
import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNormDerivative

// Subclass of BatchNorm whose only purpose is to route doDiff() through the
// fixed backprop op below.
class BatchNormFixed(
    sameDiff: SameDiff?, input: SDVariable?, mean: SDVariable?, variance: SDVariable?,
    gamma: SDVariable?, beta: SDVariable?, epsilon: Double, vararg axis: Int
) : BatchNorm(sameDiff, input, mean, variance, gamma, beta, epsilon, axis) {

    override fun doDiff(f1: MutableList<SDVariable>?): MutableList<SDVariable> {
        // Backprop inputs are the forward inputs plus the incoming gradient
        val list = args().toMutableList()
        list.add(f1!!.get(0))
        return BatchNormFixedBp(
            sameDiff, list.toTypedArray(), null, null, false,
            isApplyGamma, isApplyBeta, epsilon, axis.toIntArray()
        ).outputs()
    }
}

// BatchNormDerivative with the input-datatype check and the output count fixed.
class BatchNormFixedBp(
    sameDiff: SameDiff?, inputFunctions: Array<out SDVariable>?, inputArrays: Array<out INDArray>?,
    outputArrays: Array<out INDArray>?, inPlace: Boolean, applyGamma: Boolean, applyBeta: Boolean,
    epsilon: Double, axis: IntArray?
) : BatchNormDerivative(sameDiff, inputFunctions, inputArrays, outputArrays, inPlace, applyGamma, applyBeta, epsilon, axis) {

    override fun calculateOutputDataTypes(dataTypes: MutableList<DataType>?): MutableList<DataType> {
        // Outputs: gradients w.r.t. input, mean and variance, plus beta/gamma if they are applied
        val types = mutableListOf(dataTypes!![0], dataTypes[1], dataTypes[2])
        if (isApplyBeta) {
            types.add(dataTypes[3])
        }
        if (isApplyGamma) {
            types.add(dataTypes[4])
        }
        return types
    }

    override fun getNumOutputs(): Int {
        var out = 3
        if (isApplyGamma) {
            out++
        }
        if (isApplyBeta) {
            out++
        }
        return out
    }
}
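
To actually use it, you would instantiate BatchNormFixed directly instead of going through sd.nn.batchNorm, and take its forward output via outputVariables(). Roughly, based on the repro program above (a sketch I haven’t run end to end; it assumes SDVariable.name() and DynamicCustomOp.outputVariables() as in beta7):

    // Sketch only: wiring the fixed op into the earlier repro program
    val vv = Nd4j.create(arrayOf(floatArrayOf(1.0F, 2.0F, 3.0F, 4.0F)))
    val sd = SameDiff.create()
    val s1 = sd.`var`("s1", vv)
    val m1 = sd.`var`("m1", 1)                  // mean
    val m2 = sd.`var`("m2", 1)                  // variance
    val m3 = sd.`var`("m3", Nd4j.ones(1))       // gamma
    val m4 = sd.`var`("m4", Nd4j.zeros(1))      // beta

    // Instantiate the subclass directly instead of calling sd.nn.batchNorm(...)
    val batch = BatchNormFixed(sd, s1.reshape(1, 1, 4), m1, m2, m3, m4, 0.00001, 1)
        .outputVariables()[0]

    sd.setLossVariables(batch.name())
    println(sd.output(null, batch.name()).get(batch.name()))
    sd.calculateGradients(null, "s1")           // should no longer trip the datatype check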

Thank you for opening the issue and for the workaround suggestion.
Based on it, I made my own BatchNormDerivative and BatchNorm subclasses.
They seem to be working.