Hi, I saw that ND4J is supposed to support multiple data types. I can create INDArrays with the INT or LONG data type, but I can't multiply them with `mmul`: I get an error saying the operand has an unexpected data type (INT instead of HALF). Following the call chain leads to the `gemm` method:
```java
@Override
public void gemm(INDArray A, INDArray B, INDArray C, boolean transposeA, boolean transposeB, double alpha,
                 double beta) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
        OpProfiler.getInstance().processBlasCall(true, A, B, C);
    GemmParams params = new GemmParams(A, B, C, transposeA, transposeB);
    if (A.data().dataType() == DataType.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataType.DOUBLE, params.getA(), params.getB(), C);
        dgemm(A.ordering(), params.getTransA(), params.getTransB(), params.getM(), params.getN(), params.getK(),
                alpha, params.getA(), params.getLda(), params.getB(), params.getLdb(), beta, C,
                params.getLdc());
    } else if (A.data().dataType() == DataType.FLOAT) {
        DefaultOpExecutioner.validateDataType(DataType.FLOAT, params.getA(), params.getB(), C);
        sgemm(A.ordering(), params.getTransA(), params.getTransB(), params.getM(), params.getN(), params.getK(),
                (float) alpha, params.getA(), params.getLda(), params.getB(), params.getLdb(), (float) beta,
                C, params.getLdc());
    } else {
        DefaultOpExecutioner.validateDataType(DataType.HALF, params.getA(), params.getB(), C);
        hgemm(A.ordering(), params.getTransA(), params.getTransB(), params.getM(), params.getN(), params.getK(),
                (float) alpha, params.getA(), params.getLda(), params.getB(), params.getLdb(), (float) beta,
                C, params.getLdc());
    }
    OpExecutionerUtil.checkForAny(C);
}
```
This method only handles DOUBLE and FLOAT explicitly. Every other data type, including INT and LONG, falls through to the `else` branch, which treats HALF as the expected type, so `validateDataType` throws an exception for INT or LONG INDArrays.
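To make the fall-through concrete, here is a small plain-Java model of the `if / else if / else` chain above. It is only a sketch: `DType` is a hypothetical stand-in for ND4J's `DataType`, and `validate` is a simplified stand-in for `validateDataType`, not the real ND4J implementation.

```java
// Hypothetical, simplified model of the gemm dispatch quoted above.
// DType stands in for ND4J's DataType; it is NOT the real ND4J enum.
public class GemmDispatchSketch {

    enum DType { DOUBLE, FLOAT, HALF, INT, LONG }

    // Mirrors the branch chain in gemm(): returns the type that the
    // validation step would enforce, given the data type of operand A.
    static DType expectedType(DType a) {
        if (a == DType.DOUBLE) {
            return DType.DOUBLE; // dgemm branch
        } else if (a == DType.FLOAT) {
            return DType.FLOAT;  // sgemm branch
        } else {
            return DType.HALF;   // hgemm branch: every other type lands here
        }
    }

    // Simplified stand-in for validateDataType: throws on a mismatch.
    static void validate(DType expected, DType actual) {
        if (expected != actual) {
            throw new IllegalStateException(
                    "Unexpected data type " + actual + " instead of " + expected);
        }
    }

    public static void main(String[] args) {
        for (DType t : DType.values()) {
            try {
                validate(expectedType(t), t);
                System.out.println(t + ": accepted");
            } catch (IllegalStateException e) {
                System.out.println(t + ": " + e.getMessage());
            }
        }
    }
}
```

Running it shows DOUBLE, FLOAT and HALF being accepted, while INT and LONG are rejected because they are validated against HALF.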
So how are we supposed to use INT or LONG INDArrays with `mmul`? Any help would be welcome.
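Since the only BLAS routines reached here are floating-point ones (`dgemm`, `sgemm`, `hgemm`), one possible workaround is to cast the integer arrays to DOUBLE, multiply, and cast the result back. In ND4J that would presumably mean something like `castTo(DataType.DOUBLE)` on both operands before `mmul`; whether `castTo` is available depends on your ND4J version, so treat that as an assumption. The sketch below shows the idea in plain Java without ND4J (class and method names are hypothetical). It is only exact while every value, product and partial sum stays within ±2^53, the range in which a double represents every integer exactly; FLOAT is exact only up to 2^24, which is why DOUBLE is used.

```java
import java.util.Arrays;

// Plain-Java sketch of "cast to double, multiply, cast back".
// Names are hypothetical; this is not ND4J code.
public class IntegerViaDoubleGemm {

    // 2^53: every integer with magnitude up to this is exactly representable as a double.
    static final long MAX_EXACT = 1L << 53;

    // Converts a long matrix to double (the analogue of casting a LONG array to DOUBLE).
    static double[][] toDouble(long[][] x) {
        double[][] d = new double[x.length][];
        for (int i = 0; i < x.length; i++) {
            d[i] = new double[x[i].length];
            for (int j = 0; j < x[i].length; j++) {
                d[i][j] = (double) x[i][j];
            }
        }
        return d;
    }

    // Naive C = A * B in double precision, standing in for a dgemm call.
    // Assumes non-empty matrices with compatible shapes (a[i].length == b.length).
    static double[][] dgemmLike(double[][] a, double[][] b) {
        int m = a.length, k = b.length, n = b[0].length;
        double[][] c = new double[m][n];
        for (int i = 0; i < m; i++) {
            for (int p = 0; p < k; p++) {
                for (int j = 0; j < n; j++) {
                    c[i][j] += a[i][p] * b[p][j];
                }
            }
        }
        return c;
    }

    // Multiplies via double and converts the result back to long.
    // The guard only checks the final values; for an exact result every
    // intermediate product and partial sum must also stay within +/-2^53.
    static long[][] mmulViaDouble(long[][] a, long[][] b) {
        double[][] c = dgemmLike(toDouble(a), toDouble(b));
        long[][] out = new long[c.length][c[0].length];
        for (int i = 0; i < c.length; i++) {
            for (int j = 0; j < c[i].length; j++) {
                if (Math.abs(c[i][j]) > MAX_EXACT) {
                    throw new ArithmeticException("value exceeds 2^53; result may be inexact");
                }
                out[i][j] = (long) c[i][j];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        long[][] a = {{1, 2}, {3, 4}};
        long[][] b = {{5, 6}, {7, 8}};
        System.out.println(Arrays.deepToString(mmulViaDouble(a, b)));
        // prints [[19, 22], [43, 50]]
    }
}
```

For small integer values this round trip is lossless; for large LONG values it is not, and an exact integer product would need a non-BLAS code path.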