I’m trying to translate some PyTorch into DL4J, and I need to be able to use gather on some INDArrays. I’ve tried INDArray.gather, Nd4j.gather(), and Nd4j.baseOps.gather(), but none of them seem to be methods I can access. Do I need to import anything special to be able to use gather() in DL4J?
Are you using beta7? If yes, what exactly isn’t working? How are you trying to use it?
Yeah, sorry, that wasn’t a very specific question.
I am indeed using beta7. Here’s the code:
// batchsize * 1
INDArray rewards_t = Nd4j.create(rewards);
// batchsize * 1
INDArray dones_t = Nd4j.create(dones);
// batchsize * action_size
INDArray qVals = this.evaluateBatch(states);
// want the java equivalent of:
// qvals = qvals.gather(1, actions_t.unsqueeze(1)).squeeze(1)
So I want to be able to gather on this qVals INDArray, but I’m not sure how to do that in DL4J. My next problem after sorting out gather will be squeeze and unsqueeze, so if anyone knows how to do those as well I’d be over the moon.
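For reference, the PyTorch line chains three steps: unsqueeze(1) turns the [B] action vector into a [B, 1] column, gather(1, ...) picks qvals[i][actions[i]] for each row i, and squeeze(1) drops the size-1 column to give a [B] vector. The sketch below spells those steps out on plain Java arrays (no ND4J), so it can serve as a test oracle when checking a DL4J translation. The class and method names are hypothetical, and qVals is assumed to be a [batch][actions] array with one int action index per row.

```java
import java.util.Arrays;

// Hypothetical plain-Java reference for qvals.gather(1, actions_t.unsqueeze(1)).squeeze(1).
// No ND4J involved; useful only for checking the semantics of a translation.
class TorchGatherSketch {

    // Mirrors actions_t.unsqueeze(1): a [B] index vector becomes a [B, 1] column.
    static int[][] unsqueezeDim1(int[] actions) {
        int[][] col = new int[actions.length][1];
        for (int i = 0; i < actions.length; i++) {
            col[i][0] = actions[i];
        }
        return col;
    }

    // Mirrors q.gather(1, idx): out[i][j] = q[i][idx[i][j]]; out has idx's shape.
    static double[][] gatherDim1(double[][] q, int[][] idx) {
        double[][] out = new double[idx.length][];
        for (int i = 0; i < idx.length; i++) {
            out[i] = new double[idx[i].length];
            for (int j = 0; j < idx[i].length; j++) {
                out[i][j] = q[i][idx[i][j]];
            }
        }
        return out;
    }

    // Mirrors .squeeze(1) for the [B, 1] case used here: [B, 1] -> [B].
    // (PyTorch would leave a non-size-1 dimension alone; this sketch rejects it instead.)
    static double[] squeezeDim1(double[][] x) {
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            if (x[i].length != 1) {
                throw new IllegalArgumentException("dimension 1 must have size 1");
            }
            out[i] = x[i][0];
        }
        return out;
    }

    // The whole expression: one Q-value per row, taken at that row's action.
    static double[] gatherSqueeze(double[][] qVals, int[] actions) {
        return squeezeDim1(gatherDim1(qVals, unsqueezeDim1(actions)));
    }

    public static void main(String[] args) {
        double[][] qVals = {{0.1, 0.9, 0.3}, {0.7, 0.2, 0.5}};
        int[] actions = {1, 2};
        System.out.println(Arrays.toString(gatherSqueeze(qVals, actions))); // prints [0.9, 0.5]
    }
}
```

Whatever the DL4J version ends up looking like, it should return the same [B] vector as gatherSqueeze for the same inputs.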
It looks like the getter for this op factory was accidentally left out of the Nd4j class.
As a workaround, you can instantiate the factory (org.nd4j.linalg.factory.ops.NDBase) directly with NDBase base = new NDBase(); and then you’ll have access to gather as well as squeeze. The equivalent of PyTorch’s unsqueeze is expandDims, which is also available on that class.
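One caveat worth checking before dropping gather into the translation: as far as I understand, ND4J’s gather follows TensorFlow-style slice semantics (the output keeps the input’s shape except along the gather axis, which takes the number of indices), not PyTorch’s per-element gather. Gathering a [B, A] qVals along axis 1 with a [B] action vector would then give a [B, B] result rather than one value per row. The plain-Java sketch below (no ND4J; class and method names are hypothetical) models that difference and two ways back to the per-row values: the diagonal of the slice result, and a one-hot mask multiplied in and summed over dimension 1.

```java
import java.util.Arrays;

// Hypothetical plain-Java model of slice-style (tf.gather-like) gather along axis 1,
// contrasted with the per-row pick that torch's gather(1, ...) performs.
class SliceGatherSketch {

    // Slice-style gather along axis 1: out[i][j] = q[i][indices[j]].
    // Every row receives ALL requested columns, so the shape is [B, indices.length].
    static double[][] sliceGatherDim1(double[][] q, int[] indices) {
        double[][] out = new double[q.length][indices.length];
        for (int i = 0; i < q.length; i++) {
            for (int j = 0; j < indices.length; j++) {
                out[i][j] = q[i][indices[j]];
            }
        }
        return out;
    }

    // With one index per row (indices.length == B), the diagonal of the [B, B]
    // result is out[i][i] = q[i][indices[i]], i.e. the per-row values torch returns.
    static double[] diagonal(double[][] m) {
        double[] d = new double[m.length];
        for (int i = 0; i < m.length; i++) {
            d[i] = m[i][i];
        }
        return d;
    }

    // Same per-row pick via a one-hot mask: out[i] = sum over j of q[i][j] * mask[i][j],
    // i.e. an element-wise multiply followed by a sum over dimension 1.
    static double[] maskAndSum(double[][] q, int[] actions) {
        double[] out = new double[q.length];
        for (int i = 0; i < q.length; i++) {
            for (int j = 0; j < q[i].length; j++) {
                double mask = (j == actions[i]) ? 1.0 : 0.0;
                out[i] += q[i][j] * mask;
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] qVals = {{0.1, 0.9, 0.3}, {0.7, 0.2, 0.5}};
        int[] actions = {1, 2};
        double[][] full = sliceGatherDim1(qVals, actions);
        System.out.println(Arrays.deepToString(full));                  // prints [[0.9, 0.3], [0.2, 0.5]]
        System.out.println(Arrays.toString(diagonal(full)));            // prints [0.9, 0.5]
        System.out.println(Arrays.toString(maskAndSum(qVals, actions))); // prints [0.9, 0.5]
    }
}
```

The diagonal route materializes a [B, B] intermediate, which grows quickly with batch size. The mask route stays at [B, A] and, in ND4J terms, would roughly correspond to qVals.mul(oneHotMask).sum(1), assuming you build the one-hot mask from the actions yourself.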