Failing to get batch output from sd.math.embeddingLookup

I’m trying to create a text classification model (embedding + LSTM) with the SameDiff module. While testing the sd.math.embeddingLookup op, I found that a single input (a rank-1 index vector) fetches the correct result, but a batch input (a [batch, 2] index matrix) returns an error. Please help me check the code.
1. Batch input code:

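```java
import java.util.HashMap;
import java.util.Map;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.enums.PartitionMode;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.weightinit.impl.XavierInitScheme;

public class TestEmbedding {
    private SameDiff sd = null;

    @Before
    public void before() {
        sd = SameDiff.create();
    }

    @Test
    public void test() {
        // Index placeholder of shape [batch, 2]; -1 leaves the batch size open
        SDVariable input = sd.placeHolder("input", DataType.INT32, -1, 2);
        // Embedding table: 4 entries, each of dimension 8
        SDVariable lookUpDict = sd.var("lookUpDict", new XavierInitScheme('c', 4, 8), DataType.FLOAT, 4, 8);
        sd.math.embeddingLookup("embeddingResult", lookUpDict, input, PartitionMode.MOD);

        // Feed a [2, 2] batch of indices
        Map<String, INDArray> map = new HashMap<>();
        INDArray inputArr = Nd4j.createFromArray(new int[][] {{0, 3}, {2, 3}});
        System.out.println(inputArr.shapeInfoToString());
        map.put("input", inputArr);
        // Forward pass (reported to fail for this batch input)
        Map<String, INDArray> result = sd.output(map, "embeddingResult");
        System.out.println(result);
    }

    @After
    public void destroy() {

    }
}
```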

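2. Single input code:

```java
import java.util.HashMap;
import java.util.Map;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.enums.PartitionMode;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.weightinit.impl.XavierInitScheme;

public class TestEmbedding {
    private SameDiff sd = null;

    @Before
    public void before() {
        sd = SameDiff.create();
    }

    @Test
    public void test() {
        // Index placeholder of shape [2]: one rank-1 index vector
        SDVariable input = sd.placeHolder("input", DataType.INT32, 2);
        // Embedding table: 4 entries, each of dimension 8
        SDVariable lookUpDict = sd.var("lookUpDict", new XavierInitScheme('c', 4, 8), DataType.FLOAT, 4, 8);
        sd.math.embeddingLookup("embeddingResult", lookUpDict, input, PartitionMode.MOD);

        // Feed a single index vector of length 2
        Map<String, INDArray> map = new HashMap<>();
        INDArray inputArr = Nd4j.createFromArray(new int[] {2, 3});
        System.out.println(inputArr.shapeInfoToString());
        map.put("input", inputArr);
        // Forward pass (works for this single input)
        Map<String, INDArray> result = sd.output(map, "embeddingResult");
        System.out.println(result);
    }

    @After
    public void destroy() {

    }
}
```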
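For reference, a plain-Java sketch (no ND4J, hypothetical class and method names) of what the batch call is expected to return: every index in the [batch, seqLen] input selects one row of the [4, 8] lookUpDict, so the [2, 2] index matrix should yield a [2, 2, 8] result. The expected shape is an assumption by analogy with TensorFlow's embedding_lookup, where the output shape is the ids shape followed by the table's row shape; it illustrates the arithmetic only, not how SameDiff implements the op.

```java
import java.util.Arrays;

/**
 * Hypothetical illustration, not the SameDiff API: a batched embedding lookup.
 * table has shape [vocabSize, embDim]; ids has shape [batch, seqLen];
 * the result has shape [batch, seqLen, embDim].
 */
public class BatchLookupSketch {

    static float[][][] lookupBatch(float[][] table, int[][] ids) {
        float[][][] out = new float[ids.length][][];
        for (int b = 0; b < ids.length; b++) {
            out[b] = new float[ids[b].length][];
            for (int t = 0; t < ids[b].length; t++) {
                int id = ids[b][t];
                if (id < 0 || id >= table.length) {
                    throw new IndexOutOfBoundsException(
                            "index " + id + " outside vocabulary of size " + table.length);
                }
                out[b][t] = table[id].clone(); // copy row `id` of the table
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // Hypothetical 4 x 8 table standing in for lookUpDict: every entry of row r equals r
        float[][] table = new float[4][8];
        for (int r = 0; r < table.length; r++) {
            Arrays.fill(table[r], r);
        }
        // Same index matrix as inputArr in the batch test: shape [2, 2]
        float[][][] result = lookupBatch(table, new int[][] {{0, 3}, {2, 3}});
        System.out.println(result.length + " x " + result[0].length + " x " + result[0][0].length); // 2 x 2 x 8
        System.out.println(Arrays.toString(result[1][0])); // row 2 of the table
    }
}
```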

@AllenWGX Fixed in deeplearning4j/deeplearning4j Pull Request #9779, “Fix issue reported on forums for multiple inputs” (by agibsonccc).
It turns out that instead of a matrix, you pass multiple inputs (an SDVariable[]) to embeddingLookup.
This worked for me:

```java
SameDiff sd = SameDiff.create();
SDVariable input = sd.placeHolder("input", DataType.INT32, -1, 2);
SDVariable input2 = sd.placeHolder("input2", DataType.INT32, 2, 2);
SDVariable lookUpDict = sd.var("lookUpDict", new XavierInitScheme('c', 4, 8), DataType.FLOAT, 4, 8);
// After PR #9779 the third argument is an SDVariable[] rather than a single SDVariable
SDVariable embeddingResult = sd.math.embeddingLookup("embeddingResult", lookUpDict, new SDVariable[]{input}, PartitionMode.MOD);

Map<String, INDArray> map = new HashMap<>();
INDArray inputArr = Nd4j.createFromArray(0, 3);
INDArray inputArr2 = Nd4j.createFromArray(2, 3);
System.out.println(inputArr.shapeInfoToString());
map.put("input", inputArr);
// Note: input2 is fed here but is not passed to embeddingLookup above
map.put("input2", inputArr2);
// Forward pass
Map<String, INDArray> result = sd.output(map, "embeddingResult");
System.out.println(result);
```
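Both snippets pass PartitionMode.MOD. In TensorFlow's embedding_lookup, the "mod" strategy assigns id p to shard p % n at position p / n when the table is split into n shards. Assuming SameDiff's MOD follows the same rule (not confirmed in this thread), it has no visible effect with a single lookUpDict, because n = 1 maps every id to itself. A plain-Java sketch with hypothetical names:

```java
import java.util.Arrays;

/**
 * Hypothetical sketch of the "mod" partition strategy as documented for
 * TensorFlow's embedding_lookup. Assumption: SameDiff's PartitionMode.MOD
 * uses the same id -> (shard, position) mapping.
 */
public class ModPartitionSketch {

    // Split a full [vocab, dim] table into n shards: row id goes to shard id % n, position id / n
    static float[][][] shardMod(float[][] table, int n) {
        int vocab = table.length;
        float[][][] shards = new float[n][][];
        for (int s = 0; s < n; s++) {
            // Number of ids p in [0, vocab) with p % n == s
            int count = (vocab - s + n - 1) / n;
            shards[s] = new float[count][];
        }
        for (int id = 0; id < vocab; id++) {
            shards[id % n][id / n] = table[id];
        }
        return shards;
    }

    // Look up one id in mod-sharded tables
    static float[] lookupMod(float[][][] shards, int id) {
        int n = shards.length;
        return shards[id % n][id / n];
    }

    public static void main(String[] args) {
        // Hypothetical 4 x 8 table: every entry of row r equals r
        float[][] table = new float[4][8];
        for (int r = 0; r < table.length; r++) {
            Arrays.fill(table[r], r);
        }
        float[][][] two = shardMod(table, 2); // shard 0 holds rows {0, 2}, shard 1 holds rows {1, 3}
        System.out.println(lookupMod(two, 3)[0]); // 3.0 (shard 1, position 1)
        float[][][] one = shardMod(table, 1); // a single shard: position == id
        System.out.println(lookupMod(one, 3)[0]); // 3.0
    }
}
```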

Next time you see something like this, feel free to file an issue.

Edit: Note that part of my PR changed this interface slightly; the generated code for this op was slightly off.

Following your commit, my understanding is that the interface of embeddingLookup has changed to take multiple inputs. In other words, can mini-batch training/prediction be supported by using multiple inputs?

@AllenWGX Yes. What you were trying to do was valid; the input just wasn’t interpreted correctly. The input is now passed as in the snippet from my previous reply, with the indices given as an SDVariable[].
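On the mini-batch question: the sketch below does not show how the op interprets the new SDVariable[] argument. It only checks, in plain Java with hypothetical names, the arithmetic the question relies on: a single gather over a [batch, seqLen] index matrix gives the same result as looking up each row as its own input and stacking the results.

```java
import java.util.Arrays;

/**
 * Hypothetical plain-Java check, not the SameDiff API: a mini-batch lookup equals
 * the per-sample lookups stacked along a new leading dimension.
 * Assumes a rectangular index matrix (all rows have the same length).
 */
public class SplitBatchSketch {

    // One gather over a rank-1 index vector: [seqLen] -> [seqLen, embDim]
    static float[][] lookup(float[][] table, int[] ids) {
        float[][] out = new float[ids.length][];
        for (int t = 0; t < ids.length; t++) {
            out[t] = table[ids[t]].clone();
        }
        return out;
    }

    // Whole batch at once: flatten the indices, gather, then reshape to [batch, seqLen, embDim]
    static float[][][] lookupFlattened(float[][] table, int[][] ids) {
        int batch = ids.length;
        int seqLen = ids[0].length;
        int[] flat = new int[batch * seqLen];
        for (int b = 0; b < batch; b++) {
            System.arraycopy(ids[b], 0, flat, b * seqLen, seqLen);
        }
        float[][] rows = lookup(table, flat); // [batch * seqLen, embDim]
        float[][][] out = new float[batch][seqLen][];
        for (int i = 0; i < rows.length; i++) {
            out[i / seqLen][i % seqLen] = rows[i];
        }
        return out;
    }

    // Each row of ids looked up as a separate input, results stacked
    static float[][][] lookupPerSample(float[][] table, int[][] ids) {
        float[][][] out = new float[ids.length][][];
        for (int b = 0; b < ids.length; b++) {
            out[b] = lookup(table, ids[b]);
        }
        return out;
    }

    public static void main(String[] args) {
        // Hypothetical 4 x 8 table with distinct entries: table[r][c] = 10 * r + c
        float[][] table = new float[4][8];
        for (int r = 0; r < table.length; r++) {
            for (int c = 0; c < table[r].length; c++) {
                table[r][c] = r * 10 + c;
            }
        }
        int[][] ids = {{0, 3}, {2, 3}};
        boolean same = Arrays.deepEquals(lookupFlattened(table, ids), lookupPerSample(table, ids));
        System.out.println(same); // true
    }
}
```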