Importing GRU onnx model failed

Hello all, I’m trying to import a GRU onnx model exported from PyTorch into dl4j, and the import fails.
Below are the steps I took and my speculation about the cause.

(1) Generate GRU onnx model with PyTorch
Following the PyTorch documentation for GRU and onnx export, I wrote a simple script.

import torch
rnn = torch.nn.GRU(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
torch.onnx.export(rnn, (input, h0), "Single_GRU.onnx")

(2) Import to dl4j, got import error

OnnxFrameworkImporter onnxFrameworkImporter = new OnnxFrameworkImporter();
SameDiff graph = onnxFrameworkImporter.runImport(f.getAbsolutePath(), Collections.emptyMap(), true);
Exception in thread "main" java.lang.IndexOutOfBoundsException: Index 2 out of bounds for length 2
	at java.base/jdk.internal.util.Preconditions.outOfBounds(

(3) Modify the onnx file, got another error
I found that the IndexOutOfBoundsException occurs while accessing the outputs of the GRU onnx node, so I wrote a Python script that adds extra outputs to the GRU node. This produces a non-standard onnx file.

import onnx
model = onnx.load(r"Single_GRU.onnx")
graph = model.graph
node = graph.node

for i in range(len(node)):
    if node[i].op_type == "GRU":
        for j in range(2,4):
            node[i].output.insert(j, node[i].name + "_output_placeholder_" + str(j))
        print("=====")

# Save the modified model under a new file name
onnx.save(model, r"Single_GRU_addoutput.onnx")

Then I got another error.

Error at [D:/a/deeplearning4j/deeplearning4j/libnd4j/include/ops/declarable/generic/nn/recurrent/gruCell.cpp:97:0]:
gruCell: Input ranks must be 2 for inputs 0 and 1 (x, hLast) - got 3, 3
Exception in thread "main" java.lang.RuntimeException: Op gruCell with name /GRU failed to execute. Here is the error from c++: Op validation failed
	at org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner.calculateOutputShape(

@ccen what PyTorch version are you using? Try not to generate random one-off files. If it doesn’t work out of the box then we should fix it. I’ve sometimes found the opset version matters, see our example here:

If you still have issues after updating your call please do file an issue. Whatever it is should be a quick fix. Thanks!

@ccen also of note: GRU was one of the first ones we implemented. See OnnxOpDeclarations.kt at commit e5218991026880a4b1ae09dff0750f149469f80c in deeplearning4j/deeplearning4j on GitHub.

This should have been a fairly straightforward import mapping 1 to 1. The version can really matter here.

As a new user, I can only put 2 links in a post. So my message is divided into segments. :wink:

(4) My speculation about the error
The error message comes from gruCell.cpp#L97.

In the dl4j 1.0.0-beta7 documentation, I found that dl4j has both gru and gruCell classes. gruCell performs a single time step, so it requires rank-2 inputs.

My speculation is that dl4j maps the onnx GRU node to the gruCell operator rather than the gru operator, which would cause the errors in (2) and (3). One piece of evidence: in nd4j, the onnx name registered for gruCell is “GRU”.
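To make the distinction concrete, here is a minimal NumPy sketch of a single-step cell versus a full-sequence GRU. It is not the libnd4j implementation: the function names `gru_cell_step` and `gru_sequence` are made up for illustration, and the gate order (z, r, h) and update rule follow the onnx GRU definition with linear_before_reset=0 for a single direction. The point is only that the cell takes rank-2 inputs, while the sequence version takes a rank-3 input and loops over time steps.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_cell_step(x, h_prev, W, R, b_w, b_r):
    """One time step. x: [batch, input], h_prev: [batch, hidden].

    W: [3*hidden, input], R: [3*hidden, hidden], biases: [3*hidden],
    with gates stacked in (z, r, h) order (hypothetical helper).
    """
    if x.ndim != 2 or h_prev.ndim != 2:
        raise ValueError(f"cell expects rank-2 inputs, got {x.ndim}, {h_prev.ndim}")
    Wz, Wr, Wh = np.split(W, 3, axis=0)
    Rz, Rr, Rh = np.split(R, 3, axis=0)
    bwz, bwr, bwh = np.split(b_w, 3)
    brz, brr, brh = np.split(b_r, 3)
    z = sigmoid(x @ Wz.T + h_prev @ Rz.T + bwz + brz)              # update gate
    r = sigmoid(x @ Wr.T + h_prev @ Rr.T + bwr + brr)              # reset gate
    h_tilde = np.tanh(x @ Wh.T + (r * h_prev) @ Rh.T + bwh + brh)  # candidate state
    return (1.0 - z) * h_tilde + z * h_prev

def gru_sequence(xs, h0, W, R, b_w, b_r):
    """Full sequence. xs: [seq, batch, input], h0: [batch, hidden]."""
    h = h0
    ys = []
    for x_t in xs:  # each x_t is rank 2: [batch, input]
        h = gru_cell_step(x_t, h, W, R, b_w, b_r)
        ys.append(h)
    return np.stack(ys), h

# Same sizes as the export script: seq=5, batch=3, input=10, hidden=20
rng = np.random.default_rng(0)
xs = rng.standard_normal((5, 3, 10))
h0 = rng.standard_normal((3, 20))
W = rng.standard_normal((60, 10))
R = rng.standard_normal((60, 20))
b_w = rng.standard_normal(60)
b_r = rng.standard_normal(60)

ys, h_last = gru_sequence(xs, h0, W, R, b_w, b_r)
print(ys.shape, h_last.shape)
```

Passing the rank-3 sequence tensor straight to `gru_cell_step` raises the ValueError, which is the same kind of rank check that gruCell reports in (3).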

My PyTorch version is 1.13.1, and my dl4j version is 1.0.0-SNAPSHOT (specifically nd4j-cpu-backend-common-1.0.0-20230209.002746-392.jar).

Yes, the opset version often matters, but not in this case. I’ve tried different opset versions, and the script above is just a brief illustration. I’ve filed an issue on GitHub.

I don’t think gruCell should map to the onnx GRU node, because “gruCell does a single time step operation”.
The onnx GRU operator is defined here.
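For reference, the tensor shapes the onnx GRU operator works with can be written down as a small helper. This is only a sketch: `onnx_gru_shapes` is a made-up name, it assumes the default layout (layout=0), and it leaves out the optional sequence_lens input. It shows that the operator takes a rank-3 X and has only two outputs, Y and Y_h.

```python
def onnx_gru_shapes(seq_length, batch_size, input_size, hidden_size, num_directions=1):
    """Expected input/output shapes of one onnx GRU node (hypothetical helper, layout=0)."""
    return {
        # inputs
        "X": (seq_length, batch_size, input_size),
        "W": (num_directions, 3 * hidden_size, input_size),
        "R": (num_directions, 3 * hidden_size, hidden_size),
        "B": (num_directions, 6 * hidden_size),
        "initial_h": (num_directions, batch_size, hidden_size),
        # outputs (only two: Y and Y_h)
        "Y": (seq_length, num_directions, batch_size, hidden_size),
        "Y_h": (num_directions, batch_size, hidden_size),
    }

# Sizes from the export script, for a single unidirectional GRU node
shapes = onnx_gru_shapes(seq_length=5, batch_size=3, input_size=10, hidden_size=20)
for name, shape in shapes.items():
    print(f"{name}: {shape}")
```

With these sizes X has rank 3 and initial_h has rank 3, which matches the "got 3, 3" in the gruCell error message.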

@ccen we have a gru layer as well…let me revisit this.