How to use sd.nn.batchNorm(…) in Deeplearning4j?
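The op lives in the SDNN namespace, so the call itself looks roughly like the sketch below (parameter order input, mean, variance, gamma, beta, epsilon, axis; the shapes and names are only illustrative, so verify the signature against the ND4J version you use):

import org.nd4j.autodiff.samediff.SameDiff
import org.nd4j.linalg.api.buffer.DataType
import org.nd4j.linalg.factory.Nd4j

val sd = SameDiff.create()
// placeholder for a [minibatch, features] input
val input = sd.placeHolder("input", DataType.FLOAT, -1, 10)
val mean = sd.`var`("mean", Nd4j.zeros(DataType.FLOAT, 10))
val variance = sd.`var`("variance", Nd4j.ones(DataType.FLOAT, 10))
val gamma = sd.`var`("gamma", Nd4j.ones(DataType.FLOAT, 10))
val beta = sd.`var`("beta", Nd4j.zeros(DataType.FLOAT, 10))
// axis 1 = the feature dimension
val output = sd.nn.batchNorm(input, mean, variance, gamma, beta, 1e-5, 1)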

If the built-in op gives you trouble during backprop, here's one possible workaround: subclass it and override doDiff so that gradients flow through a corrected version of BatchNormDerivative, one that reports the right number of outputs and output data types:

import org.nd4j.autodiff.samediff.SDVariable
import org.nd4j.autodiff.samediff.SameDiff
import org.nd4j.linalg.api.buffer.DataType
import org.nd4j.linalg.api.ndarray.INDArray

// Drop-in replacement for the built-in BatchNorm op: doDiff is overridden so that
// backprop goes through BatchNormFixedBp below instead of the stock BatchNormDerivative.
class BatchNormFixed(
    sameDiff: SameDiff?, input: SDVariable?, mean: SDVariable?, variance: SDVariable?,
    gamma: SDVariable?, beta: SDVariable?, epsilon: Double, vararg axis: Int
) : org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm(sameDiff, input, mean, variance, gamma, beta, epsilon, axis) {

    override fun doDiff(f1: MutableList<SDVariable>?): MutableList<SDVariable> {
        // The backprop op takes the forward op's inputs plus the incoming gradient.
        val list = args().toMutableList()
        list.add(f1!!.get(0))
        return BatchNormFixedBp(sameDiff, list.toTypedArray(), null, null, false,
                isApplyGamma, isApplyBeta, epsilon, axis.toIntArray()).outputs()
    }
}

// Backprop op for BatchNormFixed. It extends BatchNormDerivative and fixes the
// reported number of outputs and their data types.
class BatchNormFixedBp(
    sameDiff: SameDiff?, inputFunctions: Array<out SDVariable>?, inputArrays: Array<out INDArray>?,
    outputArrays: Array<out INDArray>?, inPlace: Boolean, applyGamma: Boolean, applyBeta: Boolean,
    epsilon: Double, axis: IntArray?
) : org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNormDerivative(sameDiff, inputFunctions, inputArrays, outputArrays, inPlace, applyGamma, applyBeta, epsilon, axis) {

    override fun calculateOutputDataTypes(dataTypes: MutableList<DataType>?): MutableList<DataType> {
        // Gradient data types mirror those of the forward inputs: input, mean and
        // variance always, plus gamma/beta when those parameters are applied.
        val types = mutableListOf(dataTypes!![0], dataTypes[1], dataTypes[2])
        if (isApplyBeta) {
            types.add(dataTypes[3])
        }
        if (isApplyGamma) {
            types.add(dataTypes[4])
        }
        return types
    }

    override fun getNumOutputs(): Int {
        // Three gradients (input, mean, variance), plus one each for gamma/beta when applied.
        var out = 3
        if (isApplyGamma) {
            out++
        }
        if (isApplyBeta) {
            out++
        }
        return out
    }
}
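
With both classes on the classpath, construct the fixed op directly against the SameDiff instance instead of calling sd.nn.batchNorm; building the op registers it in the graph, and outputVariable() returns the normalized output. A minimal sketch, reusing the graph variables from the first snippet above:

// same epsilon and axis as before; BatchNormFixed is the workaround class defined above
val output = BatchNormFixed(sd, input, mean, variance, gamma, beta, 1e-5, 1).outputVariable()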