Keras import model gave different prediction results

Unfortunately, you’ve run into a known issue: see “DL4J: Add CNN1D & RNN 'channels last' support (Keras import)”, issue #8441 in the eclipse/deeplearning4j repository on GitHub.

Depending on the complexity of your actual model, you may be able to work around it with something like the following. It replaces the preprocessor that handles the Flatten step after the model has been imported:

// Replace the automatically created preprocessor with one better suited to your case.
// The map is keyed by layer index (1 here; adjust it to match your model).
model.getLayerWiseConfigurations().getInputPreProcessors()
     .put(1, new KerasFlattenPermutingingRnnPreprocessor(
             (KerasFlattenRnnPreprocessor) model.getLayerWiseConfigurations().getInputPreProcess(1)));

with KerasFlattenPermutingingRnnPreprocessor defined like this:

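import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.ndarray.INDArray;

public class KerasFlattenPermutingingRnnPreprocessor extends KerasFlattenRnnPreprocessor {
    // Reuse the depth and time-series length of the preprocessor created by the import
    public KerasFlattenPermutingingRnnPreprocessor(KerasFlattenRnnPreprocessor wrapped) {
        super(wrapped.getDepth(), wrapped.getTsLength());
    }

    @Override
    public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
        // Swap the last two axes (features and time steps) before the regular flatten runs
        return super.preProcess(input.permute(0, 2, 1), miniBatchSize, workspaceMgr);
    }
}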
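To see why the axis order matters, here is a small plain-Java sketch (no DL4J involved) that flattens one example in both layouts. It assumes Keras stores a sequence as [timeSteps, features] ("channels last") while DL4J's recurrent layers use [features, timeSteps]; the class and method names are made up for illustration, and the real preprocessor may differ in its details.

```java
import java.util.Arrays;

// Illustration only: FlattenOrderDemo and its methods are hypothetical helpers,
// not part of DL4J or Keras.
public class FlattenOrderDemo {
    // Flatten a [rows][cols] array in row-major order.
    public static int[] flatten(int[][] a) {
        int rows = a.length, cols = a[0].length;
        int[] out = new int[rows * cols];
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                out[r * cols + c] = a[r][c];
            }
        }
        return out;
    }

    // Swap the two axes: the single-example analogue of permute(0, 2, 1).
    public static int[][] transpose(int[][] a) {
        int rows = a.length, cols = a[0].length;
        int[][] out = new int[cols][rows];
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                out[c][r] = a[r][c];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // One example with 3 time steps and 2 features, in the assumed Keras layout
        int[][] timeMajor = {{1, 2}, {3, 4}, {5, 6}};
        // The same data in the assumed DL4J layout: 2 features x 3 time steps
        int[][] featureMajor = transpose(timeMajor);

        System.out.println(Arrays.toString(flatten(timeMajor)));    // [1, 2, 3, 4, 5, 6]
        System.out.println(Arrays.toString(flatten(featureMajor))); // [1, 3, 5, 2, 4, 6]
    }
}
```

The values are identical but arrive in a different order, so dense-layer weights trained against the first ordering will not line up with the second. Permuting before the flatten is meant to restore the ordering the imported weights expect.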