Here’s one possible workaround: subclass BatchNorm so that doDiff() builds the backprop op through a subclass of BatchNormDerivative that reports the correct number of outputs and output data types:
import org.nd4j.autodiff.samediff.SDVariable
import org.nd4j.autodiff.samediff.SameDiff
import org.nd4j.linalg.api.buffer.DataType
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm
import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNormDerivative

// Forward op: identical to BatchNorm except that doDiff() wires up the fixed backprop op below.
class BatchNormFixed(
    sameDiff: SameDiff?, input: SDVariable?, mean: SDVariable?, variance: SDVariable?,
    gamma: SDVariable?, beta: SDVariable?, epsilon: Double, vararg axis: Int
) : BatchNorm(sameDiff, input, mean, variance, gamma, beta, epsilon, axis) {

    override fun doDiff(f1: MutableList<SDVariable>?): MutableList<SDVariable> {
        // The backprop op takes the forward op's inputs plus the gradient w.r.t. the output
        val list = args().toMutableList()
        list.add(f1!![0])
        return BatchNormFixedBp(sameDiff, list.toTypedArray(), null, null, false,
            isApplyGamma, isApplyBeta, epsilon, axis.toIntArray()).outputs()
    }
}
// Backprop op: reports the correct number of outputs and their data types
// (3 gradients, plus one each for gamma and beta when they are applied).
class BatchNormFixedBp(
    sameDiff: SameDiff?, inputFunctions: Array<out SDVariable>?, inputArrays: Array<out INDArray>?,
    outputArrays: Array<out INDArray>?, inPlace: Boolean, applyGamma: Boolean, applyBeta: Boolean,
    epsilon: Double, axis: IntArray?
) : BatchNormDerivative(sameDiff, inputFunctions, inputArrays, outputArrays, inPlace,
    applyGamma, applyBeta, epsilon, axis) {

    override fun calculateOutputDataTypes(dataTypes: MutableList<DataType>?): MutableList<DataType> {
        // Gradients w.r.t. input, mean and variance are always produced
        val types = mutableListOf(dataTypes!![0], dataTypes[1], dataTypes[2])
        if (isApplyBeta) {
            types.add(dataTypes[3])
        }
        if (isApplyGamma) {
            types.add(dataTypes[4])
        }
        return types
    }

    override fun getNumOutputs(): Int {
        var out = 3
        if (isApplyGamma) {
            out++
        }
        if (isApplyBeta) {
            out++
        }
        return out
    }
}
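For completeness, here's a rough, untested usage sketch: the fixed op is constructed directly and its first output is used in place of the built-in batchNorm helper. The variable names, shapes, and the NCHW/axis-1 choice below are purely illustrative assumptions.

import org.nd4j.autodiff.samediff.SameDiff
import org.nd4j.linalg.api.buffer.DataType
import org.nd4j.linalg.factory.Nd4j

fun main() {
    val sd = SameDiff.create()

    // Illustrative NCHW input with 3 channels
    val input = sd.placeHolder("input", DataType.FLOAT, -1, 3, 8, 8)
    val mean = sd.`var`("mean", Nd4j.zeros(DataType.FLOAT, 3))
    val variance = sd.`var`("variance", Nd4j.ones(DataType.FLOAT, 3))
    val gamma = sd.`var`("gamma", Nd4j.ones(DataType.FLOAT, 3))
    val beta = sd.`var`("beta", Nd4j.zeros(DataType.FLOAT, 3))

    // Construct the fixed op directly; axis 1 is the channel dimension for NCHW data
    val bnOut = BatchNormFixed(sd, input, mean, variance, gamma, beta, 1e-5, 1).outputs()[0]

    // Any backprop through bnOut (training, sd.calculateGradients(...)) should now go
    // through BatchNormFixed.doDiff() and hence BatchNormFixedBp
    println(bnOut)
}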