Custom Loss Function and Gradient

Hello, I recently got started with dl4j (I'm a beginner), and I thought it would be cool to teach an AI to play Snake using A2C (maybe it could work better than DQN?). I have the feed-forward network and everything else set up.
However, the actual loss function and training step are what's throwing me off. To start, I take in 3 inputs for calculating the loss, based on the network output for a state and the new state:

  • The values of the old and new states (oldvalue and newvalue)
  • The probabilities of the network taking each possible action in the old state (probs)

I calculate the loss from those, which works fine, but calculating the gradients from it is where the errors come in.
I looked into SameDiffLoss, but I figured I couldn't feed all of my inputs into it together with external errors (which is roughly what I'm going for). I also looked at the linear regression example from the SameDiff examples on GitHub, but doing something like that doesn't seem to work here.
I might have done something stupid, so here is my code for the training part.
*I didn't add the backpropagation, because just getting the gradMap is already throwing an error:
	void learn(INDArray oldstate, INDArray newstate, double[] act, double reward, boolean done) {
		SameDiff sd = SameDiff.create();
		INDArray[] output = model.output(oldstate);
		// create vars; the one-hot action is data, not a trainable variable, so use a constant
		SDVariable action = sd.constant(Nd4j.create(act).reshape(1, 3).castTo(DataType.FLOAT));
		SDVariable probs = sd.placeHolder("probs", DataType.FLOAT, 1, 3);
		SDVariable oldvalue = sd.placeHolder("oldvalue", DataType.FLOAT, 1, 1);
		SDVariable newvalue = sd.placeHolder("newvalue", DataType.FLOAT, 1, 1);
		action = sd.squeeze(action, 0);
		probs = sd.squeeze(probs, 0);
		oldvalue = sd.squeeze(oldvalue, 0);
		newvalue = sd.squeeze(newvalue, 0);
		// calculate total loss
		SDVariable logprobs = sd.math().log(probs);
		// select the log-probability of the taken action: multiplying by the
		// one-hot action vector and summing avoids indexing by a dynamic argmax
		SDVariable logprob = logprobs.mul(action).sum();
		SDVariable mask = sd.constant(1.0f).sub(sd.constant(done ? 1.0f : 0.0f));
		// TD error: delta = reward + gamma * V(s') * (1 - done) - V(s)
		SDVariable delta = newvalue.mul(gamma).mul(mask).sub(oldvalue).add(reward);
		SDVariable actorloss = logprob.mul(delta).neg();
		SDVariable criticloss = sd.math().pow(delta, 2);
		// name the variable "loss" so it can be requested from sd.output(...) below
		SDVariable totalloss = actorloss.add("loss", criticloss);
		// calculate gradients
		Map<String, INDArray> placeholderData = new HashMap<>();
		INDArray probsArr = output[0];
		INDArray oldvalueArr = output[1];
		INDArray newvalueArr = model.output(newstate)[1];
		placeholderData.put("probs", probsArr);
		placeholderData.put("oldvalue", oldvalueArr);
		placeholderData.put("newvalue", newvalueArr);
		INDArray lossData = sd.output(placeholderData, "loss").get("loss");
		Map<String, INDArray> gradMap = sd.calculateGradients(placeholderData, "probs");
	}
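For what it's worth, the scalar loss the graph above is meant to compute reduces to a few lines of plain Java. This is just a sketch with made-up numbers (gamma = 0.99, a non-terminal step, and a hypothetical 3-action policy), independent of ND4J:

```java
// Plain-Java sketch of the A2C loss above, with made-up example numbers.
public class A2CLossSketch {
    public static void main(String[] args) {
        double gamma = 0.99;
        double reward = 1.0;
        boolean done = false;

        double oldValue = 0.5;            // V(s) from the critic
        double newValue = 0.8;            // V(s') from the critic
        double[] probs = {0.2, 0.5, 0.3}; // policy output for the old state
        int action = 1;                   // index of the action taken

        double mask = done ? 0.0 : 1.0;
        // TD error: delta = reward + gamma * V(s') * (1 - done) - V(s)
        double delta = reward + gamma * newValue * mask - oldValue;
        double actorLoss = -Math.log(probs[action]) * delta;
        double criticLoss = delta * delta;
        double totalLoss = actorLoss + criticLoss;

        System.out.println(totalLoss);
    }
}
```

If the SameDiff loss value from sd.output(...) doesn't match a hand calculation like this for a single transition, the graph is wired wrong before gradients even enter the picture.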

@SNICKRS we actually do have external errors — could you clarify whether you looked at that? It would look something like this:

        INDArray externalGrad = Nd4j.linspace(1, 12, 12).reshape(3, 4);

        SameDiff sd = SameDiff.create();
        SDVariable var = sd.var("var", externalGrad);
        SDVariable out = var.mul("out", 0.5);

        // register "out" as having an externally supplied gradient
        ExternalErrorsFunction fn = SameDiffUtils.externalErrors(sd, null, out);

        // the external gradient is fed in through the generated "out-grad" placeholder
        Map<String, INDArray> m = new HashMap<>();
        m.put("out-grad", externalGrad);
        Map<String, INDArray> grads = sd.calculateGradients(m, sd.getVariables().keySet());

        INDArray gradVar = grads.get("var");

That’s what we used with our old reinforcement learning library and it worked pretty well.
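As a sanity check on what calculateGradients returns in that example: since out = var * 0.5, the chain rule gives dL/dvar = externalGrad * 0.5, elementwise. In plain Java, on one row of that 3x4 gradient:

```java
// Chain-rule check for the external-errors example: out = var * 0.5,
// so the gradient w.r.t. var is the external gradient scaled by 0.5.
public class ExternalGradCheck {
    public static void main(String[] args) {
        double[] externalGrad = {1, 2, 3, 4}; // first row of the 3x4 external gradient
        double[] gradVar = new double[externalGrad.length];
        for (int i = 0; i < externalGrad.length; i++) {
            gradVar[i] = externalGrad[i] * 0.5;
        }
        System.out.println(java.util.Arrays.toString(gradVar));
        // -> [0.5, 1.0, 1.5, 2.0]
    }
}
```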