How do I use gather in DL4J?

I’m trying to translate some PyTorch code into DL4J, and I need to use gather on some INDArrays. I’ve tried INDArray.gather, Nd4j.gather(), and Nd4j.baseOps.gather(), but none of them are methods I can access. Do I need to import anything special to use gather() in DL4J?

Are you using beta7? If yes, what exactly isn’t working? How are you trying to use it?

Yeah, sorry, that wasn’t a very specific question.

I am indeed using beta7. Here’s the code:

        // batchsize * 1
        INDArray rewards_t = Nd4j.create(rewards);

        // batchsize * 1
        INDArray dones_t = Nd4j.create(dones);

        // batchsize * action_size
        INDArray qVals = this.evaluateBatch(states);
        // want the java equivalent of:
        // qvals = qvals.gather(1, actions_t.unsqueeze(1)).squeeze(1)

So I want to gather from this qVals INDArray, but I’m not sure how to do that in DL4J. My next problem after sorting out gather will be squeeze and unsqueeze, so if anyone knows how to do those as well I’d be over the moon 🙂
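For reference, here is what that PyTorch line computes, written as plain Java over arrays (no ND4J involved; the gatherRows helper and the sample values are hypothetical). For each row i, gather keeps only the column named by actions[i], and the unsqueeze/squeeze pair just adds and then removes a size-1 column dimension, so the result is one Q-value per batch entry:

```java
import java.util.Arrays;

public class TorchGatherSemantics {

    // Mirrors qvals.gather(1, actions.unsqueeze(1)).squeeze(1) in PyTorch:
    // for every row i, keep only the value in column actions[i].
    static double[] gatherRows(double[][] qVals, int[] actions) {
        if (qVals.length != actions.length) {
            throw new IllegalArgumentException("exactly one action per row is required");
        }
        double[] out = new double[qVals.length];
        for (int i = 0; i < qVals.length; i++) {
            out[i] = qVals[i][actions[i]];
        }
        return out;
    }

    public static void main(String[] args) {
        // batchSize = 2, action_size = 3
        double[][] qVals = {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}};
        int[] actions = {2, 0};
        System.out.println(Arrays.toString(gatherRows(qVals, actions))); // [3.0, 4.0]
    }
}
```

Any ND4J translation should produce the same values for the same inputs, which makes this a handy check.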

It looks like adding the getter for this op factory to the Nd4j class was overlooked.

As a workaround, you can instantiate the factory directly with NDBase base = new NDBase(); (NDBase lives in org.nd4j.linalg.factory.ops), and then you’ll have access to gather as well as squeeze. The equivalent of unsqueeze is expandDims, which is also available on that class.
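A minimal sketch of that workaround, assuming 1.0.0-beta7 with NDBase in org.nd4j.linalg.factory.ops (not verified against that exact release). One caveat worth checking: as far as I know, NDBase.gather follows TensorFlow semantics and returns whole slices along the given axis, so gather(qVals, actions, 1) would return a [batchSize, batchSize] array rather than one value per row. To get PyTorch’s per-row selection, the sketch builds (row, action) index pairs and uses gatherNd instead. The class and method names are hypothetical.

```java
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.ops.NDBase;

public class GatherExample {

    // Hypothetical helper: returns qVals[i, actions[i]] for every row i,
    // i.e. PyTorch's qVals.gather(1, actions.unsqueeze(1)).squeeze(1).
    static INDArray selectActionValues(INDArray qVals, int[] actions) {
        NDBase base = new NDBase();
        // One (row, column) index pair per batch entry: shape [batchSize, 2]
        int[][] pairs = new int[actions.length][];
        for (int i = 0; i < actions.length; i++) {
            pairs[i] = new int[] {i, actions[i]};
        }
        INDArray indices = Nd4j.createFromArray(pairs);
        // gatherNd picks one scalar per index pair, giving shape [batchSize]
        return base.gatherNd(qVals, indices);
    }

    public static void main(String[] args) {
        NDBase base = new NDBase();
        INDArray qVals = Nd4j.createFromArray(new double[][] {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}});
        INDArray chosen = selectActionValues(qVals, new int[] {2, 0});
        System.out.println(chosen);

        // PyTorch unsqueeze(1) maps to expandDims(x, 1); squeeze(1) maps to squeeze(x, 1)
        INDArray column = base.expandDims(chosen, 1); // shape [2, 1]
        INDArray flat = base.squeeze(column, 1);      // shape [2]
        System.out.println(java.util.Arrays.toString(flat.shape()));
    }
}
```

If gatherNd turns out to be unavailable, looping over the rows with qVals.getDouble(i, actions[i]) gives the same result, just without a single vectorized op.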