I tried to import a Keras model that consists of a Masking layer, an LSTM(64) layer, and then two Dense layers.
It takes one-hot encoded vectors as input, which are pre-padded to a fixed length.
In Java, I simply used the following to import the model from the .h5 file saved by Keras:
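For reference, here is a minimal pure-Python sketch of one-hot encoding followed by pre-padding. It assumes the inputs are integer token indices. `one_hot` and `encode_and_prepad` are hypothetical helper names, not the code actually used in this model. The padding timesteps are all zeros so that a `Masking(mask_value=0.0)` layer skips them, because Keras masks a timestep only when every feature at that timestep equals the mask value.

```python
def one_hot(index, depth):
    """Return a one-hot vector of length `depth` with 1.0 at `index`."""
    vec = [0.0] * depth
    vec[index] = 1.0
    return vec


def encode_and_prepad(sequence, depth, max_len):
    """Hypothetical helper: one-hot encode a list of token indices and
    pre-pad it with all-zero timesteps to exactly `max_len` rows.

    Longer sequences keep their last `max_len` steps, mirroring the
    padding='pre', truncating='pre' defaults of Keras' pad_sequences.
    """
    steps = [one_hot(i, depth) for i in sequence]
    if len(steps) > max_len:
        steps = steps[-max_len:]
    padding = [[0.0] * depth for _ in range(max_len - len(steps))]
    return padding + steps


# Two tokens, vocabulary of 5, padded to 3 timesteps -> shape (3, 5)
example = encode_and_prepad([2, 0], depth=5, max_len=3)
```

The leading all-zero row in `example` is the padded timestep that the Masking layer should ignore.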
MultiLayerNetwork model = KerasModelImport.importKerasSequentialModelAndWeights(model_weight_struc_path);
After importing the model, I checked the Dense layers: their weights and biases match the Python version. However, the LSTM layer weights don't match.
If I run the same input through the network and get the LSTM layer output with the feedForward().get() method, it doesn't match the LSTM output of the Python model for the same input.
I assume there is some difference because in Keras the input shape is (1, timesteps, features), while in DL4J it is (1, features, timesteps).
So I created a toy model:
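The input layout difference is just a per-sample transpose. Below is a hypothetical pure-Python helper, `keras_to_dl4j_layout`, which assumes the input is a nested list shaped (batch, timesteps, features) and returns it as (batch, features, timesteps):

```python
def keras_to_dl4j_layout(batch):
    """Hypothetical helper: convert (batch, timesteps, features) into
    (batch, features, timesteps) by transposing each sample."""
    # zip(*sample) groups the i-th feature of every timestep together
    return [[list(feature) for feature in zip(*sample)] for sample in batch]


# One sample, 3 timesteps, 5 features, matching input_shape=(3, 5)
keras_input = [[[1, 2, 3, 4, 5],
                [6, 7, 8, 9, 10],
                [11, 12, 13, 14, 15]]]
dl4j_input = keras_to_dl4j_layout(keras_input)  # shape (1, 5, 3)
```

In Java, the equivalent on an ND4J `INDArray` would typically be `input.permute(0, 2, 1)`. This only affects the input layout. It cannot explain a mismatch inside the recurrent weight matrix itself, which has no time axis.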
from keras.models import Sequential
from keras.layers import Masking, LSTM, Dense
model = Sequential()
model.add(Masking(mask_value=0.0, input_shape=(3, 5)))
model.add(LSTM(6))
model.add(Dense(1, activation='sigmoid'))
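As a sanity check on the shapes: with the default `use_bias=True`, a Keras LSTM layer's `get_weights()` returns the input kernel, the recurrent kernel, and the bias. The four gates are stacked side by side along the last axis. For this toy model (5 features, 6 units), the arithmetic works out as follows (`keras_lstm_weight_shapes` is a hypothetical helper):

```python
def keras_lstm_weight_shapes(input_dim, units):
    """Shapes of [kernel, recurrent_kernel, bias] for a Keras LSTM with a
    bias; each has 4 * units columns, one block of `units` per gate."""
    return [(input_dim, 4 * units), (units, 4 * units), (4 * units,)]


print(keras_lstm_weight_shapes(5, 6))  # [(5, 24), (6, 24), (24,)]
```

This is why index `[1]` yields the (6, 24) recurrent kernel.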
Then in Python, model.layers[1].get_weights()[1] (the LSTM layer's recurrent kernel)
gives me a (6,24) matrix as below:
But in Java, System.out.println(model.getParam("1_RW"))
gives me this (6,24) matrix:
[[ 0.0088, -0.2417, -0.2305, 0.2706, 0.1590, 0.3007, -0.0555, 0.0270, 0.3155, -0.0984, -0.1430, -0.0205, -0.3000, 0.1479, -0.0466, 0.2184, 0.1548, 0.0750, -0.0230, -0.3926, -0.4117, 0.0292, 0.0271, -0.2201],
[ 0.1818, 0.1398, -0.0671, -0.0146, -0.2206, -0.3618, -0.0043, 0.5145, 0.0989, 0.1930, -0.3135, 0.2240, -0.0753, -0.1024, 0.0655, -0.0538, 0.2332, -0.0445, 0.0382, -0.0611, -0.2177, -0.1036, -0.3807, 0.1219],
[ -0.2991, 0.1280, 0.5051, 0.5081, 0.2659, -0.0082, -0.2817, 0.2300, -0.0058, -0.1414, -0.0094, 0.0928, -0.1039, 0.0323, 0.0075, 0.0445, 0.0116, -0.0360, -0.0148, 0.0854, 0.1633, 0.2460, -0.1953, 0.0297],
[ 0.3264, 0.2795, -0.0577, -0.3209, 0.3719, 0.3687, -0.0624, 0.3096, -0.0187, 0.0080, 0.1646, 0.1298, 0.2549, 0.2799, -0.0596, -0.0176, 0.1987, -0.0776, 0.0796, -0.0164, 0.0126, 0.3000, 0.0046, 0.0355],
[ -0.0691, 0.2839, 0.0199, -0.0236, 0.0436, -0.0590, -0.1854, -0.3150, 0.2103, -0.0140, 0.2179, 0.1249, 0.0680, 0.3275, -0.2206, 0.1803, -0.1980, -0.0579, 0.0450, -0.0770, -0.2042, -0.3926, -0.3038, 0.3681],
[ 0.0067, -0.0010, -0.0336, 0.1233, -0.1479, 0.2217, -0.1695, -0.1380, -0.2195, 0.1326, -0.0877, 0.4220, -0.0741, 0.1814, 0.2704, 0.1903, 0.3147, 0.1108, -0.1511, 0.4744, -0.1200, -0.1763, 0.2374, 0.0676]]
The ordering is weird: the first element of the first row in ND4J (0.0088) appears in the first row of the Python matrix, but in the middle of it. The values then continue in order up to some point and jump back to the front of the row.
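One plausible explanation is that both libraries store the four LSTM gates as side-by-side blocks of 6 columns, but in a different order. Keras 2 uses [input, forget, cell, output]. The DL4J order in the sketch, [cell, forget, output, input], is an assumption based on how the DL4J Keras importer is believed to rearrange the blocks. Check it against the `KerasLSTM` importer source for your DL4J version. If that assumption holds, DL4J column 0 would come from Keras column 12, which is the middle of the row. That matches the observation described here. The helper names are hypothetical.

```python
# Keras 2 gate block order: input, forget, cell, output.
KERAS_GATE_ORDER = ["i", "f", "c", "o"]
# ASSUMPTION: gate block order in DL4J's LSTM after Keras import.
# Verify against the KerasLSTM importer in your DL4J version.
DL4J_GATE_ORDER = ["c", "f", "o", "i"]


def reorder_gate_columns(row, units, src_order, dst_order):
    """Rearrange one row of a (rows, 4 * units) LSTM weight matrix from the
    gate order `src_order` to `dst_order`; each gate spans `units` columns."""
    if len(row) != 4 * units:
        raise ValueError("expected %d columns, got %d" % (4 * units, len(row)))
    blocks = {g: row[k * units:(k + 1) * units] for k, g in enumerate(src_order)}
    out = []
    for g in dst_order:
        out.extend(blocks[g])
    return out


def keras_to_dl4j_columns(matrix, units):
    """Apply the (assumed) Keras -> DL4J gate reordering to every row."""
    return [reorder_gate_columns(r, units, KERAS_GATE_ORDER, DL4J_GATE_ORDER)
            for r in matrix]


# Label each Keras column by its index to see where it ends up in DL4J.
units = 6
keras_row = list(range(4 * units))  # Keras column indices 0..23
dl4j_row = reorder_gate_columns(keras_row, units,
                                KERAS_GATE_ORDER, DL4J_GATE_ORDER)
print(dl4j_row[:6])  # [12, 13, 14, 15, 16, 17]
```

To compare the two matrices, apply `keras_to_dl4j_columns` to the Keras recurrent kernel and compare the result element-wise with the printed `1_RW` array. If the assumed order is correct, the same column permutation should presumably apply to the input kernel and the bias as well.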
Can anyone take a look at this?
Edit: Formatted Code and output text for better readability.