Failed to import Keras seq2seq model into DL4J: MergeVertex error

I have a basic seq2seq model in Keras:

from keras.layers import Input, LSTM, Dropout, Dense
from keras.models import Model

def gen_models(dim_hidden=128):
    # training encoder
    encoder_input = Input(shape=(None, 1), name='encoder_input')
    encoder = LSTM(dim_hidden, return_state=True)
    encoder_output, encoder_h, encoder_c = encoder(encoder_input)
    encoded_state = [encoder_h, encoder_c]
    # training decoder
    decoder_input = Input(shape=(None, 1), name='decoder_input')
    decoder = LSTM(dim_hidden, return_sequences=True, return_state=True)
    decoder_output, _, _ = decoder(decoder_input, initial_state=encoded_state)
    # note: this Dropout layer is created but never applied in the graph below
    output_layer_dropout = Dropout(0.2)
    output_layer_dense1 = Dense(64, activation='tanh')
    output_layer_dense2 = Dense(1, name='output')
    output = output_layer_dense2(output_layer_dense1(decoder_output))

    model = Model([encoder_input, decoder_input], output)
    # inference encoder
    infer_encoder = Model(encoder_input, encoded_state)
    # inference decoder
    # each initial state is 2D including the batch axis: (batch, dim_hidden)
    infer_init_h = Input(shape=(dim_hidden,), name='h_input')
    infer_init_c = Input(shape=(dim_hidden,), name='c_input')
    infer_init_state = [infer_init_h, infer_init_c]
    infer_output, infer_h, infer_c = decoder(decoder_input, initial_state=infer_init_state)
    infer_state = [infer_h, infer_c]
    infer_outputs = output_layer_dense2(output_layer_dense1(infer_output))
    infer_decoder = Model([decoder_input] + infer_init_state, [infer_outputs] + infer_state)
    return model, infer_encoder, infer_decoder

The inference decoder requires three inputs: the decoder input sequence, which is a 3D tensor, and the two LSTM states (hidden state h and cell state c), which are 2D matrices.
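To make the shapes concrete, here is a minimal NumPy-only sketch of the three decoder inputs and of one LSTM layer consuming them. It is an illustration, not the Keras implementation: `lstm_step`, the random weights `W`, `U`, `b`, and the batch size and sequence length are all assumptions made up for the example.

```python
import numpy as np

batch, timesteps, dim_hidden = 2, 5, 128  # batch and timesteps are arbitrary

# Decoder input: 3D (batch, timesteps, features), as in Input(shape=(None, 1)).
decoder_in = np.zeros((batch, timesteps, 1), dtype=np.float32)
# LSTM states: one 2D matrix (batch, dim_hidden) each for h and c.
h0 = np.zeros((batch, dim_hidden), dtype=np.float32)
c0 = np.zeros((batch, dim_hidden), dtype=np.float32)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_step(x_t, h, c, W, U, b):
    """One LSTM time step (hypothetical helper).

    x_t: (batch, features); h, c: (batch, units).
    """
    z = x_t @ W + h @ U + b              # (batch, 4 * units)
    i, f, g, o = np.split(z, 4, axis=1)  # input, forget, candidate, output
    c_new = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    h_new = sigmoid(o) * np.tanh(c_new)
    return h_new, c_new


# Made-up weights, only so the sketch runs end to end.
rng = np.random.default_rng(0)
W = rng.normal(scale=0.1, size=(1, 4 * dim_hidden))
U = rng.normal(scale=0.1, size=(dim_hidden, 4 * dim_hidden))
b = np.zeros(4 * dim_hidden)

# The layer takes the sequence AND the two initial states as inputs.
h, c = h0, c0
outputs = []
for t in range(timesteps):
    h, c = lstm_step(decoder_in[:, t, :], h, c, W, U, b)
    outputs.append(h)
seq = np.stack(outputs, axis=1)  # 3D: (batch, timesteps, dim_hidden)
```

The point of the sketch is only that the sequence input and the state inputs have different ranks, so they cannot be concatenated into a single activation.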

When I import the decoder into Java with DL4J, I get a MergeVertex error. The exception is as follows:
org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException: Invalid input: MergeVertex cannot merge activations of different types: first type = RNN, input type 2 = FF
    at org.deeplearning4j.nn.conf.graph.MergeVertex.getOutputType(MergeVertex.java:139)
    at org.deeplearning4j.nn.conf.ComputationGraphConfiguration.getLayerActivationTypes(ComputationGraphConfiguration.java:537)
    at org.deeplearning4j.nn.conf.ComputationGraphConfiguration.addPreProcessors(ComputationGraphConfiguration.java:450)
    at org.deeplearning4j.nn.conf.ComputationGraphConfiguration$GraphBuilder.build(ComputationGraphConfiguration.java:1202)
    at org.deeplearning4j.nn.modelimport.keras.KerasModel.getComputationGraphConfiguration(KerasModel.java:394)
    at org.deeplearning4j.nn.modelimport.keras.KerasModel.getComputationGraph(KerasModel.java:415)
    at org.deeplearning4j.nn.modelimport.keras.KerasModel.getComputationGraph(KerasModel.java:404)
    at org.deeplearning4j.nn.modelimport.keras.KerasModelImport.importKerasModelAndWeights(KerasModelImport.java:173)
    at com.didi.woqu.scheduling.util.Seq2SeqModel.loadModel(Seq2SeqModel.java:163)

I then traced through the source code and found that, when DL4J builds the computation graph, it adds a 'merge' vertex in front of any layer vertex that has multiple inputs and for which maxVertexInputs() = 1. The decoder LSTM layer satisfies this condition.
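The rule described above can be sketched as a toy Python model. This is not DL4J's code: `Vertex`, `needs_merge`, `merge_output_type`, and the vertex name `decoder_lstm` are hypothetical names invented for illustration, and the type labels "RNN" and "FF" are taken from the exception message.

```python
from dataclasses import dataclass
from typing import Dict, List


class InvalidInputType(Exception):
    """Stand-in for the exception seen in the stack trace."""


@dataclass
class Vertex:
    name: str
    inputs: List[str]
    max_vertex_inputs: int = 1  # mirrors the maxVertexInputs() idea


def needs_merge(vertex: Vertex) -> bool:
    # A merge vertex is inserted when a single-input layer gets several inputs.
    return len(vertex.inputs) > 1 and vertex.max_vertex_inputs == 1


def merge_output_type(types: List[str]) -> str:
    # Merging only works if every incoming activation has the same type.
    first = types[0]
    for idx, t in enumerate(types[1:], start=2):
        if t != first:
            raise InvalidInputType(
                "MergeVertex cannot merge activations of different types: "
                f"first type = {first}, input type {idx} = {t}"
            )
    return first


# Activation types of the three inference-decoder inputs.
activation_types: Dict[str, str] = {
    "decoder_input": "RNN",  # 3D sequence input
    "h_input": "FF",         # 2D state input
    "c_input": "FF",         # 2D state input
}
decoder_lstm = Vertex("decoder_lstm", ["decoder_input", "h_input", "c_input"])

error_message = None
if needs_merge(decoder_lstm):
    try:
        merge_output_type([activation_types[n] for n in decoder_lstm.inputs])
    except InvalidInputType as exc:
        error_message = str(exc)
print(error_message)
```

Under these assumptions the sequence input and the first state input are compared first, which would explain why the message names "first type = RNN" and "input type 2 = FF".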

So my guess is that DL4J treats the LSTM as having only one input and ignores the initial states, which is why it tries to merge the RNN-typed decoder input with the FF-typed state inputs. But in seq2seq, the initial states have to be inputs to the layer.

Is this a bug? If not, how can I import the model successfully?