I am trying to replace an L2Vertex in an existing, working network with a cosine similarity vertex, using SameDiff (or any other approach).
With the code below, I keep getting errors like this one:

    org.nd4j.linalg.exception.ND4JIllegalStateException: Axis array [-539956880] contains values above array rank (rank=2)

    .addVertex("CosineDiff", new SameDiffLambdaVertex() {
        @Override
        public SDVariable defineVertex(SameDiff sameDiff, VertexInputs inputs) {
            SDVariable input1 = inputs.getInput(0);
            SDVariable input2 = inputs.getInput(1);
            // Cosine similarity along dimension 1, then reshape to [minibatch, 1]
            return sameDiff.math.cosineSimilarity(input1, input2, 1).reshape(input1.getShape()[0], 1);
        }
    }, "Content1", "Content2")

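To make the intended computation concrete, here is a plain-Java sketch of row-wise cosine similarity, which is useful for checking the vertex's output numerically. It does not use ND4J or SameDiff, and the class and method names (`CosineReference`, `rowwiseCosine`) are hypothetical. It assumes both inputs have shape [minibatch][n] and returns one similarity per example as [minibatch][1], the shape the `reshape(..., 1)` call above is aiming for.

```java
import java.util.Arrays;

// Hypothetical plain-Java reference (no ND4J/SameDiff): row-wise cosine similarity.
public class CosineReference {

    // Inputs: [minibatch][n]. Output: [minibatch][1], one similarity per row.
    // A zero row yields NaN (0/0); no epsilon is added in this sketch.
    public static double[][] rowwiseCosine(double[][] a, double[][] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("minibatch sizes differ");
        }
        double[][] out = new double[a.length][1];
        for (int i = 0; i < a.length; i++) {
            if (a[i].length != b[i].length) {
                throw new IllegalArgumentException("row lengths differ at row " + i);
            }
            double dot = 0, normA = 0, normB = 0;
            for (int j = 0; j < a[i].length; j++) {
                dot += a[i][j] * b[i][j];     // reduce along dimension 1
                normA += a[i][j] * a[i][j];
                normB += b[i][j] * b[i][j];
            }
            out[i][0] = dot / (Math.sqrt(normA) * Math.sqrt(normB));
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] a = {{1, 0}, {1, 1}};
        double[][] b = {{0, 1}, {2, 2}};
        // Row 0 is orthogonal (about 0), row 1 is parallel (about 1).
        System.out.println(Arrays.deepToString(rowwiseCosine(a, b)));
    }
}
```

Comparing this against the vertex's activations for the same small inputs shows whether the SameDiff graph reduces over the right dimension and produces the expected [minibatch, 1] shape.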
Am I approaching this the wrong way?