Keras import model gave different prediction results

I am trying to use KerasModelImport.importKerasSequentialModelAndWeights() to load a Keras model created with Python Keras, but I cannot get the same results from Python and Java.

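Python code:

```python
import numpy as np
from keras.models import Sequential, load_model
from keras.layers import Embedding, Dense
from keras.optimizers import SGD

input_length = 3
model = Sequential()
model.add(Embedding(7, 4, input_length=input_length))
model.add(Dense(1, activation='relu'))
model.compile(optimizer=SGD(lr=0.01, momentum=0.9), loss='mean_squared_error')

# The model defined above is replaced by the trained model loaded from disk
model = load_model('.\\data\\simple_mlp_4.h5')

data = [[1, 2, 3], [4, 3, 6]]
data = np.asarray(data, dtype=np.int32)

output = model.predict(data)
print(output)
```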
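Java code:

```java
final String modelFile = new File(dataLocalPath, "simple_mlp_4.h5").getAbsolutePath();
// Second argument: enforce the Keras training configuration
MultiLayerNetwork model = KerasModelImport.importKerasSequentialModelAndWeights(modelFile, true);

// Same two index sequences as in the Python code, shape [2, 3]
INDArray input = Nd4j.create(new float[]{1, 2, 3, 4, 3, 6}, new int[]{2, 3}).castTo(DataType.INT);
INDArray output = model.output(input);
System.out.println(output);
```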


The Python code prints [[0.04062647] [0.01502968]], while the Java code prints [0.0604, 0.0385]. Both load the same model and weights, yet they produce different results. Can anyone help?

Unfortunately, you’ve run into a known issue (see …).

Depending on the complexity of your actual model, you may be able to work around it with something like the following, which replaces the preprocessor that handles the Flatten layer after the model has been imported:

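```java
// Replace the automatically created preprocessor with one more appropriate for your case.
// Index 1 is the preprocessor applied to the input of layer 1, the layer after the Embedding.
model.getLayerWiseConfigurations().getInputPreProcessors()
        .put(1, new KerasFlattenPermutingingRnnPreprocessor(
                (KerasFlattenRnnPreprocessor) model.getLayerWiseConfigurations().getInputPreProcess(1)));
```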

with KerasFlattenPermutingingRnnPreprocessor defined like this:

```java
public class KerasFlattenPermutingingRnnPreprocessor extends KerasFlattenRnnPreprocessor {
    public KerasFlattenPermutingingRnnPreprocessor(KerasFlattenRnnPreprocessor wrapped) {
        // Reuse the depth and time-series length of the preprocessor created by the importer
        super(wrapped.getDepth(), wrapped.getTsLength());
    }

    @Override
    public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
        // Swap the last two dimensions before flattening so the element order matches Keras
        return super.preProcess(input.permute(0, 2, 1), miniBatchSize, workspaceMgr);
    }
}
```
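To illustrate why the element order matters, here is a small plain-Java sketch (no DL4J or ND4J involved). It assumes that DL4J holds the embedding output of one example as [features][timeSteps] (its default RNN layout), while Keras holds it as (timeSteps, features) before Flatten. Row-major flattening of those two layouts puts the same values at different positions, so the imported Dense weights end up multiplied against the wrong inputs; transposing first, which is what `input.permute(0, 2, 1)` does in the workaround, yields the Keras order. The class and method names are hypothetical and only serve this illustration.

```java
import java.util.Arrays;

// Hypothetical demo class: compares the flattening order of the two layouts
public class FlattenOrderDemo {

    // Row-major flatten: element [i][j] lands at index i * cols + j
    public static double[] rowMajorFlatten(double[][] a) {
        int rows = a.length, cols = a[0].length;
        double[] out = new double[rows * cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                out[i * cols + j] = a[i][j];
            }
        }
        return out;
    }

    // Swap the two axes, analogous to permute(0, 2, 1) applied to a single example
    public static double[][] transpose(double[][] a) {
        int rows = a.length, cols = a[0].length;
        double[][] t = new double[cols][rows];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                t[j][i] = a[i][j];
            }
        }
        return t;
    }

    public static void main(String[] args) {
        // One example: 4 embedding features x 3 time steps, assumed [features][timeSteps] layout.
        // Each value encodes its position as 10 * feature + timeStep.
        double[][] act = new double[4][3];
        for (int f = 0; f < 4; f++) {
            for (int t = 0; t < 3; t++) {
                act[f][t] = 10 * f + t;
            }
        }

        // Flattening [features][timeSteps] directly (presumably what happens without the workaround)
        System.out.println(Arrays.toString(rowMajorFlatten(act)));
        // prints [0.0, 1.0, 2.0, 10.0, 11.0, 12.0, 20.0, 21.0, 22.0, 30.0, 31.0, 32.0]

        // Transposing first gives the Keras (timeSteps, features) flatten order
        System.out.println(Arrays.toString(rowMajorFlatten(transpose(act))));
        // prints [0.0, 10.0, 20.0, 30.0, 1.0, 11.0, 21.0, 31.0, 2.0, 12.0, 22.0, 32.0]
    }
}
```

Both arrays contain the same twelve values, but a Dense layer with twelve inputs weights each position separately, so only the second ordering lines up with weights trained in Keras.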