Failing to get batch output from sd.math.embeddingLookup

I’m trying to build a text classification model (embedding + LSTM) with the SameDiff module. While testing the sd.math.embeddingLookup op, I found that a single input returns the correct result, but a batch input returns an error. Could someone please check the code?
1. Batch Input Code Below:

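import java.util.HashMap;
import java.util.Map;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.enums.PartitionMode;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.weightinit.impl.XavierInitScheme;

public class TestEmbedding {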
	private SameDiff sd = null;
	
	@Before
	public void before() {
		sd = SameDiff.create();
	}
	
	@Test
	public void test() {
		SDVariable input = sd.placeHolder("input", DataType.INT32, -1, 2);
		SDVariable lookUpDict = sd.var("lookUpDict",new XavierInitScheme('c', 4, 8), DataType.FLOAT, 4, 8);
		sd.math.embeddingLookup("embeddingResult", lookUpDict, input, PartitionMode.MOD);
		//
		Map<String,INDArray> map = new HashMap<>();
		INDArray inputArr = Nd4j.createFromArray(new int[][] {{0, 3},{2, 3}});
		System.out.println(inputArr.shapeInfoToString());
		map.put("input", inputArr);
		//forward
		Map<String,INDArray> result = sd.output(map, "embeddingResult");
		System.out.println(result);
	}
	
	@After
	public void destroy() {
		
	}
}

2. Single Input Code Below:

public class TestEmbedding {
	private SameDiff sd = null;
	
	@Before
	public void before() {
		sd = SameDiff.create();
	}
	
	@Test
	public void test() {
		SDVariable input = sd.placeHolder("input", DataType.INT32, 2);
		SDVariable lookUpDict = sd.var("lookUpDict",new XavierInitScheme('c', 4, 8), DataType.FLOAT, 4, 8);
		sd.math.embeddingLookup("embeddingResult", lookUpDict, input, PartitionMode.MOD);
		//
		Map<String,INDArray> map = new HashMap<>();
		INDArray inputArr = Nd4j.createFromArray(new int[]{2, 3});
		System.out.println(inputArr.shapeInfoToString());
		map.put("input", inputArr);
		//forward
		Map<String,INDArray> result = sd.output(map, "embeddingResult");
		System.out.println(result);
	}
	
	@After
	public void destroy() {
		
	}
}
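For comparison with the two tests above, here is a plain-Java sketch (no ND4J) of what an embedding lookup is generally expected to compute, assuming the usual row-gather semantics of a TensorFlow-style embedding_lookup over a single table: each index selects one row of the table. Under that assumption, the 2 x 2 index matrix from the batch test against the 4 x 8 lookUpDict should produce a 2 x 2 x 8 result. The class and method names are hypothetical, and this is not the DL4J implementation.

```java
import java.util.Arrays;

/**
 * Plain-Java reference for a single-table embedding lookup (hypothetical helper).
 * A [batch][seqLen] index matrix yields a [batch][seqLen][embDim] result.
 */
public class EmbeddingLookupSketch {

    // Gathers rows of `table` (shape [vocab][embDim]) for every index in `ids`.
    static float[][][] lookup(float[][] table, int[][] ids) {
        float[][][] out = new float[ids.length][][];
        for (int b = 0; b < ids.length; b++) {
            out[b] = new float[ids[b].length][];
            for (int t = 0; t < ids[b].length; t++) {
                int id = ids[b][t];
                if (id < 0 || id >= table.length) {
                    throw new IndexOutOfBoundsException(
                            "index " + id + " outside vocabulary of size " + table.length);
                }
                // Copy the row so callers cannot mutate the table through the result.
                out[b][t] = Arrays.copyOf(table[id], table[id].length);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // 4-word vocabulary, embedding size 8, like the 4 x 8 lookUpDict in the post.
        float[][] table = new float[4][8];
        for (int i = 0; i < 4; i++) {
            Arrays.fill(table[i], i); // row i is filled with the value i
        }
        int[][] batch = {{0, 3}, {2, 3}}; // same indices as the batch test
        float[][][] result = lookup(table, batch);
        System.out.println(result.length + " x " + result[0].length + " x " + result[0][0].length); // 2 x 2 x 8
        System.out.println(result[1][0][0]); // 2.0 (row 2)
    }
}
```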

@AllenWGX See Pull Request #9779 on deeplearning4j/deeplearning4j (GitHub): "Fix issue reported on forums for multiple inputs" by agibsonccc.
It turns out that, instead of a matrix, you pass multiple inputs.
This worked for me:

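import java.util.HashMap;
import java.util.Map;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.enums.PartitionMode;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.weightinit.impl.XavierInitScheme;

SameDiff sd = SameDiff.create();
SDVariable input = sd.placeHolder("input", DataType.INT32, -1, 2);
SDVariable input2 = sd.placeHolder("input2", DataType.INT32, 2, 2);
SDVariable lookUpDict = sd.var("lookUpDict", new XavierInitScheme('c', 4, 8), DataType.FLOAT, 4, 8);
// indices are now passed as an SDVariable[] (interface changed in the PR)
SDVariable embeddingResult = sd.math.embeddingLookup("embeddingResult", lookUpDict, new SDVariable[]{input}, PartitionMode.MOD);
// placeholder values
Map<String,INDArray> map = new HashMap<>();
INDArray inputArr = Nd4j.createFromArray(0, 3);
INDArray inputArr2 = Nd4j.createFromArray(2, 3);
System.out.println(inputArr.shapeInfoToString());
map.put("input", inputArr);
map.put("input2", inputArr2);
// forward pass
Map<String,INDArray> result = sd.output(map, "embeddingResult");
System.out.println(result);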
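The answer above replaces the index matrix with separate inputs. A minimal plain-Java sketch of how a [batch][n] index matrix could be split into one index array per row, keyed by placeholder names such as "input1", "input2" (the naming scheme used in the follow-up test in this thread). The helper name and the prefix parameter are hypothetical; in the actual SameDiff code each row would presumably be wrapped with Nd4j.createFromArray(row) before being put into the placeholder map.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;

/**
 * Hypothetical helper: turns a [batch][n] index matrix into one int[] per row,
 * keyed by placeholder names prefix + "1", prefix + "2", ...
 */
public class SplitBatchSketch {

    static Map<String, int[]> splitRows(int[][] ids, String prefix) {
        Map<String, int[]> byName = new LinkedHashMap<>(); // preserves row order
        for (int row = 0; row < ids.length; row++) {
            byName.put(prefix + (row + 1), ids[row].clone()); // defensive copy of each row
        }
        return byName;
    }

    public static void main(String[] args) {
        Map<String, int[]> feeds = splitRows(new int[][]{{2, 3}, {0, 3}}, "input");
        feeds.forEach((name, row) -> System.out.println(name + " -> " + Arrays.toString(row)));
        // input1 -> [2, 3]
        // input2 -> [0, 3]
    }
}
```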

Next time you see something like this, feel free to file an issue.

Edit: Note that part of my PR changed this interface slightly. The generated code was a bit off.

Following your commit, my understanding is that the embeddingLookup interface now takes multiple inputs. In other words, can mini-batch training/prediction be supported by using multiple inputs?

@AllenWGX Yes. What you were trying to do was valid; it just wasn't interpreted correctly. The indices are now passed as an array of variables (new SDVariable[]{input}), as in the code from my previous reply.
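All of the examples in this thread pass PartitionMode.MOD. For context, TensorFlow's tf.nn.embedding_lookup documents its "mod" partition strategy for the case where the embedding table is split across several tensors: id p lives in partition p % n, at row p / n within that partition. The plain-Java sketch below reproduces that documented TensorFlow convention. Whether DL4J's PartitionMode.MOD and the multi-input overload follow the same layout (that is, whether multiple inputs are table shards rather than mini-batch rows) is an assumption worth checking against the libnd4j op. The class and method names are hypothetical.

```java
import java.util.Arrays;

/**
 * Sketch of TensorFlow's documented "mod" partition strategy (hypothetical helper):
 * with n shards, id p is stored in shard p % n at row p / n.
 * That DL4J's PartitionMode.MOD uses the same layout is an assumption.
 */
public class ModPartitionSketch {

    // Splits a full [vocab][dim] table into n shards using the mod strategy.
    // Rows are shared with the original table, not copied.
    static float[][][] shardMod(float[][] table, int n) {
        float[][][] shards = new float[n][][];
        for (int p = 0; p < n; p++) {
            int rows = (table.length - p + n - 1) / n; // number of ids p, p+n, p+2n, ... below vocab size
            shards[p] = new float[rows][];
            for (int r = 0; r < rows; r++) {
                shards[p][r] = table[p + r * n];
            }
        }
        return shards;
    }

    // Looks up one id in a table that was sharded with the mod strategy.
    static float[] lookupMod(float[][][] shards, int id) {
        int n = shards.length;
        return shards[id % n][id / n];
    }

    public static void main(String[] args) {
        float[][] table = new float[4][8];
        for (int i = 0; i < 4; i++) {
            Arrays.fill(table[i], i); // row i is filled with the value i
        }
        float[][][] shards = shardMod(table, 2);
        System.out.println(shards[1].length);        // 2 (ids 1 and 3)
        System.out.println(lookupMod(shards, 3)[0]); // 3.0
    }
}
```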

OK, I’ll try it using 1.0.0-SNAPSHOT.

@agibsonccc I just ran the code below with 1.0.0-SNAPSHOT, but I still get an error. If the input array contains only one variable (as in your code), it works. But I'm testing mini-batch input, so I pass multiple inputs (an array of input variables), and that fails. Could you please take a look?

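import java.util.HashMap;
import java.util.Map;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.enums.PartitionMode;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.weightinit.impl.XavierInitScheme;

public class TestEmbedding {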
	private SameDiff sd = null;
	
	@Before
	public void before() {
		sd = SameDiff.create();
	}
	
	@Test
	public void test2() {
		SDVariable input1 = sd.placeHolder("input1", DataType.INT32, 1, 2);
		SDVariable input2 = sd.placeHolder("input2", DataType.INT32, 1, 2);
		SDVariable[] input = new SDVariable[] {input1, input2};	//if this array only has "input1" variable, that works.
		SDVariable lookUpDict = sd.var("lookUpDict",new XavierInitScheme('c', 4, 8), DataType.FLOAT, 4, 8);
		sd.math.embeddingLookup("embeddingResult", lookUpDict, input, PartitionMode.MOD);
		//mock data
		Map<String,INDArray> map = new HashMap<>();
		INDArray inputArr1 = Nd4j.createFromArray(new int[] {2, 3});
		INDArray inputArr2 = Nd4j.createFromArray(new int[] {0, 3});
		System.out.println(inputArr1.shapeInfoToString());
		map.put("input1", inputArr1);
		map.put("input2", inputArr2);
		//forward pass
		Map<String,INDArray> result = sd.output(map, "embeddingResult");
		System.out.println(result);
	}
	
	@After
	public void destroy() {
		
	}
}