I have a basic seq2seq model in Keras, defined as follows:
from keras.layers import Input, LSTM, Dense, Dropout
from keras.models import Model


def gen_models(dim_hidden=128):
    # training encoder
    encoder_input = Input(shape=(None, 1), name='encoder_input')
    encoder = LSTM(dim_hidden, return_state=True)
    encoder_output, encoder_h, encoder_c = encoder(encoder_input)
    encoded_state = [encoder_h, encoder_c]

    # training decoder, initialized with the encoder's final states
    decoder_input = Input(shape=(None, 1), name='decoder_input')
    decoder = LSTM(dim_hidden, return_sequences=True, return_state=True)
    decoder_output, _, _ = decoder(decoder_input, initial_state=encoded_state)
    output_layer_dropout = Dropout(0.2)  # created but not applied in the graph below
    output_layer_dense1 = Dense(64, activation='tanh')
    output_layer_dense2 = Dense(1, name='output')
    output = output_layer_dense2(output_layer_dense1(decoder_output))
    model = Model([encoder_input, decoder_input], output)

    # inference encoder: maps an input sequence to the states [h, c]
    infer_encoder = Model(encoder_input, encoded_state)

    # inference decoder: the states are 2D tensors of shape (batch, dim_hidden)
    infer_init_h = Input(shape=(dim_hidden,), name='h_input')
    infer_init_c = Input(shape=(dim_hidden,), name='c_input')
    infer_init_state = [infer_init_h, infer_init_c]
    infer_output, infer_h, infer_c = decoder(decoder_input, initial_state=infer_init_state)
    infer_state = [infer_h, infer_c]
    infer_outputs = output_layer_dense2(output_layer_dense1(infer_output))
    infer_decoder = Model([decoder_input] + infer_init_state, [infer_outputs] + infer_state)

    return model, infer_encoder, infer_decoder
The inference decoder takes three inputs: the decoder input, which is a 3D tensor, and the two initial LSTM states (hidden state h and cell state c), each of which is a 2D matrix.
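For reference, here is a minimal sketch of how these three inputs are fed during step-by-step decoding. The function name `greedy_decode`, the zero start value, and feeding each prediction back as the next input are assumptions for illustration, not part of the model above; the function only needs objects with a Keras-style `predict`.

```python
import numpy as np


def greedy_decode(infer_encoder, infer_decoder, source, n_steps):
    """Decode n_steps values for one source sequence of shape (timesteps, 1)."""
    # Encoder input is 3D: (batch=1, timesteps, 1). It returns the states [h, c].
    h, c = infer_encoder.predict(source[np.newaxis, :, :])
    # First decoder input is 3D: (batch=1, 1 timestep, 1 feature).
    # Starting from zero is an assumption made for this sketch.
    step_input = np.zeros((1, 1, 1))
    outputs = []
    for _ in range(n_steps):
        # Three inputs: the 3D step input plus two 2D states of shape (1, dim_hidden).
        y, h, c = infer_decoder.predict([step_input, h, c])
        outputs.append(float(y[0, -1, 0]))
        # Feed the latest prediction back in as the next decoder input.
        step_input = y[:, -1:, :]
    return outputs
```

With the models from `gen_models`, this would be called as `greedy_decode(infer_encoder, infer_decoder, source, n_steps)`, where `source` is a NumPy array of shape `(timesteps, 1)`.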
Here's the problem: when I import the inference decoder into Java with DL4J, I get a MergeVertex error. The exception is:
org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException: Invalid input: MergeVertex cannot merge activations of different types: first type = RNN, input type 2 = FF
    at org.deeplearning4j.nn.conf.graph.MergeVertex.getOutputType(MergeVertex.java:139)
    at org.deeplearning4j.nn.conf.ComputationGraphConfiguration.getLayerActivationTypes(ComputationGraphConfiguration.java:537)
    at org.deeplearning4j.nn.conf.ComputationGraphConfiguration.addPreProcessors(ComputationGraphConfiguration.java:450)
    at org.deeplearning4j.nn.conf.ComputationGraphConfiguration$GraphBuilder.build(ComputationGraphConfiguration.java:1202)
    at org.deeplearning4j.nn.modelimport.keras.KerasModel.getComputationGraphConfiguration(KerasModel.java:394)
    at org.deeplearning4j.nn.modelimport.keras.KerasModel.getComputationGraph(KerasModel.java:415)
    at org.deeplearning4j.nn.modelimport.keras.KerasModel.getComputationGraph(KerasModel.java:404)
    at org.deeplearning4j.nn.modelimport.keras.KerasModelImport.importKerasModelAndWeights(KerasModelImport.java:173)
    at com.didi.woqu.scheduling.util.Seq2SeqModel.loadModel(Seq2SeqModel.java:163)
I then traced through the source code and found that when DL4J builds the computation graph, it adds a 'merge' vertex whenever a layer vertex has multiple inputs and maxVertexInputs() = 1. The decoder LSTM layer satisfies this condition.
So my guess is that DL4J treats the LSTM as having only one input and ignores the initial states. But in a seq2seq model, the initial states have to be passed to the layer as inputs.
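To make that guess concrete, here is a toy Python model of the rule as described. It is not DL4J's actual code: the function names and the string type labels are hypothetical, and only the condition (multiple inputs, maxVertexInputs() = 1) and the wording of the message come from the trace and description above.

```python
# Toy model of the described graph-building rule; NOT DL4J's implementation.
# merge_output_type and resolve_layer_inputs are hypothetical names.

def merge_output_type(input_types):
    """Mimic a merge vertex: every input must have the same activation type."""
    first = input_types[0]
    for position, current in enumerate(input_types[1:], start=2):
        if current != first:
            raise ValueError(
                "cannot merge activations of different types: "
                f"first type = {first}, input type {position} = {current}"
            )
    return first


def resolve_layer_inputs(input_types, max_vertex_inputs=1):
    """Return the activation types the layer itself ends up receiving."""
    if len(input_types) > 1 and max_vertex_inputs == 1:
        # A single-input layer with several inputs gets a merge vertex in front.
        return [merge_output_type(input_types)]
    return list(input_types)


if __name__ == "__main__":
    # decoder_input is a sequence (RNN); h_input and c_input are 2D (FF).
    try:
        resolve_layer_inputs(["RNN", "FF", "FF"])
    except ValueError as err:
        print(err)
```

In this toy model the merge fails on the second input, which matches the "first type = RNN, input type 2 = FF" part of the exception. If the layer were allowed three inputs, no merge vertex would be inserted at all.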
Is this a bug in DL4J's Keras import? If not, how can I import this model successfully?