Hi there. I need to update a 4D array with the same constant value at the indices returned by argmax() over the last dimension of the original values. I tried different techniques directly with INDArrays and eventually found a working (though definitely not efficient) solution: first reshaping, then calling putScalar() within a loop. INDArray.put(indices…) didn’t work for me because my indices are at least a matrix. The problem comes with SameDiff, because there it’s really hard to reproduce the same logic (loops etc.). The only relevant solution seems to be scatterNdUpdate() or scatterNdAdd(), but I haven’t found any unit tests that explain how they work. Those operations require a specific updates array, and I have no idea what it should contain. Analyzing the C++ code of those ops didn’t give me an answer, unfortunately. Thanks in advance!
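To make the question about the updates array concrete, here is a minimal plain-Java sketch of scatter-ND update semantics under the TensorFlow-style convention: each row of the indices array is one full coordinate, and the updates array holds one scalar per row. Whether ND4J's scatterNdUpdate() follows exactly this convention is an assumption not confirmed in this thread, and the class and method names below are hypothetical illustrations, not ND4J API.

```java
import java.util.Arrays;

// Plain-Java illustration only; NOT the ND4J implementation.
// ScatterNdSketch and scatterNdUpdate are hypothetical names.
class ScatterNdSketch {

    // ref:     row-major flattened array with the given shape
    // indices: one row per update, each row a full coordinate (length == shape.length)
    // updates: one scalar per row of indices
    static double[] scatterNdUpdate(double[] ref, int[] shape, int[][] indices, double[] updates) {
        double[] out = Arrays.copyOf(ref, ref.length);
        for (int n = 0; n < indices.length; n++) {
            // Convert the N-dimensional coordinate to a row-major flat offset.
            int flat = 0;
            for (int d = 0; d < shape.length; d++) {
                flat = flat * shape[d] + indices[n][d];
            }
            out[flat] = updates[n];
        }
        return out;
    }

    public static void main(String[] args) {
        int[] shape = {2, 3};
        double[] ref = new double[6];          // all zeros
        int[][] indices = {{0, 2}, {1, 0}};    // e.g. argmax positions per row
        double[] updates = {5.0, 5.0};         // the same constant for every index
        System.out.println(Arrays.toString(scatterNdUpdate(ref, shape, indices, updates)));
        // prints [0.0, 0.0, 5.0, 5.0, 0.0, 0.0]
    }
}
```

Under this convention, writing a constant at every argmax position means the updates array is simply that constant repeated once per index row.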
@partarstu do you mind posting some samples of your code? I can try to rewrite it or propose a solution based on that. In the meantime I’ll work on some scatter examples. I do think you have the right idea. You’ll find the C++ tests are a good place to start:
Just imagine NDArrayFactory as Nd4j.create and NDArray indices(.,…) or what have you as new NDArray instead.
If you can’t find anything in Java, please try searching for the op name in C++. The op tests are fairly readable and may help you find otherwise missing context.
Apologies for the docs; some work needs to be done there for sure.
I think some automatic extraction could be done to come up with better usage examples there…let me think on that.
@agibsonccc, thanks for the reference to the tests, I’ll take a look at them to figure it out.
I can find them quickly in the GitHub repo, so that shouldn’t be an issue.
Not a problem. Usually if the docs are not present, static code analysis is quite enough for me when it comes to Java. With C++ it’s a bit difficult, but still possible.
I think posting an example would be quite confusing because it’s still in a very raw state. Describing my specific use case is easier to understand. The idea is to take the max values along the last dimension and zero out all the others (think of it as sparse one-hot activations). The requirement is that there is at most one non-zero value in the last dimension. The simplest solution would be to get the max values, subtract them from all the original ones, then replace all zeros with the required value and all negative values with zeros. The problem is that the chance of having the same max value more than once is high, and that automatically breaks the requirement. So the only option is to use argmax() to retrieve the indices that need to be non-zero and then somehow set the values at these indices within SameDiff, so that back-prop works for my UDF.
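The difference between the two approaches described above can be sketched for a single slice of the last dimension in plain Java. This is only an illustration of the logic, not SameDiff code; the class and method names are hypothetical, the slice is assumed to be non-empty, and "first maximum wins" is an assumed tie-breaking rule.

```java
import java.util.Arrays;

// Plain-Java illustration only; OneHotMaskSketch and its methods are hypothetical names.
class OneHotMaskSketch {

    // Approach 1: subtract the max, map zeros to `value` and negatives to 0.
    // Every element equal to the max becomes non-zero, so ties break the requirement.
    static double[] subtractMaxMask(double[] row, double value) {
        double max = Arrays.stream(row).max().getAsDouble(); // assumes row is non-empty
        double[] out = new double[row.length];
        for (int i = 0; i < row.length; i++) {
            out[i] = (row[i] - max == 0.0) ? value : 0.0;
        }
        return out;
    }

    // Approach 2: argmax picks a single index (the first maximum here),
    // so exactly one element is non-zero even when there are ties.
    static double[] argmaxMask(double[] row, double value) {
        int best = 0;
        for (int i = 1; i < row.length; i++) {
            if (row[i] > row[best]) {
                best = i;
            }
        }
        double[] out = new double[row.length];
        out[best] = value;
        return out;
    }

    public static void main(String[] args) {
        double[] row = {0.2, 0.9, 0.9, 0.1}; // two equal maxima
        System.out.println(Arrays.toString(subtractMaxMask(row, 1.0))); // prints [0.0, 1.0, 1.0, 0.0]
        System.out.println(Arrays.toString(argmaxMask(row, 1.0)));      // prints [0.0, 1.0, 0.0, 0.0]
    }
}
```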
@agibsonccc, I finally got the idea of the scatterAdd op. The C++ tests didn’t shed any light on my questions, but then I read the official PyTorch docs and finally understood how this op should work. I also understood it’s not the one I need to solve my problem. I needed an op that would allow updating (or adding to) the values at specific indices without any grouping/segmentation etc. So my workaround was to reshape the original arrays into vectors and call SDVariable.put(…) to take the updates from one array and put them into the array I needed. It’s quite slow, but I haven’t found any other workarounds.
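The reshape-to-vector workaround can be sketched in plain Java as follows. It assumes a row-major layout, so a 4D array of shape [a, b, c, d] flattens into a*b*c contiguous slices of length d, and the argmax of each slice maps to a single flat index. The names are hypothetical and nothing here is ND4J or SameDiff API; it only shows the index arithmetic behind the workaround.

```java
import java.util.Arrays;

// Plain-Java illustration only; FlatArgmaxUpdateSketch and its methods are hypothetical names.
class FlatArgmaxUpdateSketch {

    // data:    row-major flattened 4D array
    // lastDim: size of the last dimension (d)
    // Returns one flat index per slice, pointing at that slice's first maximum.
    static int[] flatArgmaxIndices(double[] data, int lastDim) {
        int slices = data.length / lastDim;
        int[] flat = new int[slices];
        for (int s = 0; s < slices; s++) {
            int base = s * lastDim;
            int best = 0;
            for (int k = 1; k < lastDim; k++) {
                if (data[base + k] > data[base + best]) {
                    best = k;
                }
            }
            flat[s] = base + best;
        }
        return flat;
    }

    // Builds a zero vector of the given length and writes `value` at each flat index.
    static double[] putAt(int length, int[] flatIndices, double value) {
        double[] out = new double[length];
        for (int idx : flatIndices) {
            out[idx] = value;
        }
        return out;
    }

    public static void main(String[] args) {
        // Shape [1, 1, 2, 3] flattened; the second slice contains a tie.
        double[] data = {0.1, 0.7, 0.2, 0.5, 0.5, 0.3};
        int[] idx = flatArgmaxIndices(data, 3);
        System.out.println(Arrays.toString(idx));                         // prints [1, 3]
        System.out.println(Arrays.toString(putAt(data.length, idx, 1.0))); // prints [0.0, 1.0, 0.0, 1.0, 0.0, 0.0]
    }
}
```

The resulting vector would then be reshaped back to the original 4D shape.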
@partarstu Internally I’ve started using create_view for that. Underneath that’s what put(…) does.
scatterUpdate should do that for you. Did you try that? There’s a number of scatter ops like add/subtract/…