Hello all, I'm trying to import a PyTorch-generated GRU ONNX model into dl4j and I'm getting errors.
Below are the steps I took and my guesses about the cause of each error.
(1) Generate GRU onnx model with PyTorch
Following the PyTorch documentation for GRU and ONNX export, I wrote a simple script.
import torch
rnn = torch.nn.GRU(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
torch.onnx.export(rnn, (input, h0), "Single_GRU.onnx")
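For reference, the shapes involved in the script above can be worked out by hand. The sketch below is plain Python (no torch needed) and follows the shape rules in the PyTorch GRU documentation for batch_first=False. The helper name gru_output_shapes is made up for illustration and is not part of PyTorch.

```python
def gru_output_shapes(input_shape, hidden_size, num_layers, bidirectional=False):
    """Return (output_shape, h_n_shape) for a GRU with batch_first=False.

    input_shape is (seq_len, batch, input_size), as in the script above.
    """
    seq_len, batch, _input_size = input_shape
    num_directions = 2 if bidirectional else 1
    # output holds the top layer's hidden state for every time step
    output_shape = (seq_len, batch, num_directions * hidden_size)
    # h_n holds the final hidden state of every layer (and direction)
    h_n_shape = (num_layers * num_directions, batch, hidden_size)
    return output_shape, h_n_shape


# Same sizes as torch.nn.GRU(10, 20, 2) with input (5, 3, 10)
output_shape, h_n_shape = gru_output_shapes((5, 3, 10), hidden_size=20, num_layers=2)
print(output_shape)  # (5, 3, 20)
print(h_n_shape)     # (2, 3, 20)
```

So the exported model takes two rank-3 inputs (input and h0) and produces two rank-3 outputs, which matters for the errors described in the later steps.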
(2) Import into dl4j, got import error
OnnxFrameworkImporter onnxFrameworkImporter = new OnnxFrameworkImporter();
SameDiff graph = onnxFrameworkImporter.runImport(f.getAbsolutePath(), Collections.emptyMap(), true);
Exception in thread "main" java.lang.IndexOutOfBoundsException: Index 2 out of bounds for length 2
at java.base/jdk.internal.util.Preconditions.outOfBounds(Preconditions.java:64)
(3) Modify onnx file, got another error
I found that the IndexOutOfBoundsException occurs while the importer accesses the outputs of the GRU ONNX node, so I wrote a Python script that adds extra outputs to the GRU node. Note that this produces a non-standard ONNX file.
import onnx
model = onnx.load(r"Single_GRU.onnx")
graph = model.graph
node = graph.node
for i in range(len(node)):
    if node[i].op_type == "GRU":
        for j in range(2,4):
            node[i].output.insert(j, node[i].name + "_output_placeholder_" + str(j))
        print(node[i])
        print("=====")
onnx.save(model, r"Single_GRU_addoutput.onnx")
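To make the effect of that script concrete, here is a small plain-Python sketch of the same padding logic applied to an ordinary list instead of an ONNX node. The ONNX GRU operator defines only two outputs, Y and Y_h, so the padded entries are pure placeholders. The function name pad_outputs and the output names "Y" and "Y_h" used as list entries are assumptions for illustration; the real exported node will have its own output names.

```python
def pad_outputs(outputs, total, node_name):
    """Return a copy of `outputs` padded with placeholder names up to `total` entries."""
    padded = list(outputs)
    for j in range(len(padded), total):
        # Same naming scheme as the script above
        padded.append(node_name + "_output_placeholder_" + str(j))
    return padded


# "/GRU" is the node name that appears in the dl4j error message
print(pad_outputs(["Y", "Y_h"], 4, "/GRU"))
# ['Y', 'Y_h', '/GRU_output_placeholder_2', '/GRU_output_placeholder_3']
```

The placeholder names do not correspond to any tensor the GRU operator actually produces, which is why the resulting file is off-standard.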
Then I got another error.
Error at [D:/a/deeplearning4j/deeplearning4j/libnd4j/include/ops/declarable/generic/nn/recurrent/gruCell.cpp:97:0]:
gruCell: Input ranks must be 2 for inputs 0 and 1 (x, hLast) - got 3, 3
Exception in thread "main" java.lang.RuntimeException: Op gruCell with name /GRU failed to execute. Here is the error from c++: Op validation failed
at org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner.calculateOutputShape(NativeOpExecutioner.java:1486)
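The rank mismatch in this error can be illustrated without dl4j. The ONNX GRU node works on a whole sequence, so its X and initial_h inputs are rank 3, while the message says gruCell wants rank-2 x and hLast, which looks like a single time step. The sketch below only shows the shapes a per-step cell would see if the sequence were unrolled. It is an illustration of the mismatch, not a description of how the dl4j importer works, and the helper name unroll_to_cell_shapes is hypothetical. The h0 shape (1, 3, 20) assumes one layer's slice of the (2, 3, 20) initial state.

```python
def unroll_to_cell_shapes(x_shape, h0_shape):
    """Return the (x_t, h) shapes a per-step cell would see for each time step.

    x_shape  is (seq_len, batch, input_size)         -> rank 3
    h0_shape is (num_directions, batch, hidden_size) -> rank 3
    """
    if len(x_shape) != 3 or len(h0_shape) != 3:
        raise ValueError("expected rank-3 x and h0")
    seq_len, batch, input_size = x_shape
    _num_directions, h_batch, hidden_size = h0_shape
    if h_batch != batch:
        raise ValueError("batch size of x and h0 must match")
    x_t_shape = (batch, input_size)   # one time step of x: rank 2
    h_shape = (batch, hidden_size)    # hidden state of one direction: rank 2
    return [(x_t_shape, h_shape) for _ in range(seq_len)]


steps = unroll_to_cell_shapes((5, 3, 10), (1, 3, 20))
print(len(steps))  # 5
print(steps[0])    # ((3, 10), (3, 20))
```

In other words, the rank-3 tensors from the ONNX graph seem to be passed to an op that expects one rank-2 slice per call, which would match the "got 3, 3" in the message.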