I’m trying to create a text classification model with SameDiff module. The model is embedding+LSTM. When testing the sd.math.embeddingLookup op, I found that single input can successfully and correctly fetch the result but batch input return an error. Please help to check the code.
1. Batch Input Code Below:
public class TestEmbedding {
private SameDiff sd = null;
@Before
public void before() {
sd = SameDiff.create();
}
@Test
public void test() {
SDVariable input = sd.placeHolder("input", DataType.INT32, -1, 2);
SDVariable lookUpDict = sd.var("lookUpDict",new XavierInitScheme('c', 4, 8), DataType.FLOAT, 4, 8);
sd.math.embeddingLookup("embeddingResult", lookUpDict, input, PartitionMode.MOD);
//
Map<String,INDArray> map = new HashMap<>();
INDArray inputArr = Nd4j.createFromArray(new int[][] {{0, 3},{2, 3}});
System.out.println(inputArr.shapeInfoToString());
map.put("input", inputArr);
//forward
Map<String,INDArray> result = sd.output(map, "embeddingResult");
System.out.println(result);
}
@After
public void destroy() {
}
}
2. Single Input Code Below:
public class TestEmbedding {
private SameDiff sd = null;
@Before
public void before() {
sd = SameDiff.create();
}
@Test
public void test() {
SDVariable input = sd.placeHolder("input", DataType.INT32, 2);
SDVariable lookUpDict = sd.var("lookUpDict",new XavierInitScheme('c', 4, 8), DataType.FLOAT, 4, 8);
sd.math.embeddingLookup("embeddingResult", lookUpDict, input, PartitionMode.MOD);
//
Map<String,INDArray> map = new HashMap<>();
INDArray inputArr = Nd4j.createFromArray(new int[]{2, 3});
System.out.println(inputArr.shapeInfoToString());
map.put("input", inputArr);
//forward
Map<String,INDArray> result = sd.output(map, "embeddingResult");
System.out.println(result);
}
@After
public void destroy() {
}
}