I am running the following in a JUnit test.
INDArray conditions = Nd4j.ones(1, 5);
INDArray testWhere = Nd4j.where(conditions, null, null);
But I get the following error:
java.lang.RuntimeException: Data types validation failed
    at org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner.calculateOutputShape(NativeOpExecutioner.java:1700)
    at org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner.calculateOutputShape(NativeOpExecutioner.java:1613)
    at org.nd4j.linalg.factory.Nd4j.where(Nd4j.java:5559)
I also see the following in the logs:
Op [where_np] failed check for input , DataType: [FLOAT]
Failed to calculate output shapes for op ... (the log goes on to list the input arguments and their shapes)
I am using version 1.0.0-beta7 of nd4j-api and nd4j-native-platform, with Java 15 for my project.
I saw a similar post about a shape mismatch issue with concat, but here I was expecting the call to succeed and simply return the indices of the non-zero (true) elements of my conditions array. I'm happy to add further details to help investigate this.
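To make the expected result concrete, here is a plain-Java sketch (no ND4J calls) of what I expected Nd4j.where(conditions, null, null) to return. It assumes the call behaves like NumPy's single-argument np.where, i.e. one index array per dimension listing the positions of the non-zero elements in row-major order. The helper name nonZeroIndices is hypothetical and is not part of the ND4J API.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class WhereIndicesSketch {

    /**
     * Hypothetical helper (not an ND4J method): for a 2-D array, returns one
     * index array per dimension giving the positions of every non-zero
     * element, scanned in row-major order, like NumPy's np.where(condition).
     */
    static long[][] nonZeroIndices(double[][] condition) {
        List<Long> rows = new ArrayList<>();
        List<Long> cols = new ArrayList<>();
        for (int r = 0; r < condition.length; r++) {
            for (int c = 0; c < condition[r].length; c++) {
                // Any non-zero value counts as "true".
                if (condition[r][c] != 0.0) {
                    rows.add((long) r);
                    cols.add((long) c);
                }
            }
        }
        return new long[][] {
            rows.stream().mapToLong(Long::longValue).toArray(),
            cols.stream().mapToLong(Long::longValue).toArray()
        };
    }

    public static void main(String[] args) {
        // Same contents as Nd4j.ones(1, 5): a single row of five ones.
        double[][] conditions = {{1, 1, 1, 1, 1}};
        long[][] idx = nonZeroIndices(conditions);
        System.out.println(Arrays.toString(idx[0])); // row indices: [0, 0, 0, 0, 0]
        System.out.println(Arrays.toString(idx[1])); // column indices: [0, 1, 2, 3, 4]
    }
}
```

For an all-ones 1x5 input, every element is "true", so I expected row indices of all zeros and column indices 0 through 4.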