Hello, I recently got started with DL4J (I'm a beginner). I thought it would be cool to teach an AI to play Snake using A2C (maybe it could work better than DQN?), and I've got the feed-forward pass and everything else set up.
However, the loss function and the training itself are what's throwing me off. To start, I take in three inputs for calculating the loss, based on the network's output for the old state and the new state:
- The values of the old and new states (`oldvalue` and `newvalue`)
- The probabilities of the network taking each possible action in the old state (`probs`)
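Written out, the loss that the training code below builds from these inputs is as follows, with $r$ = `reward`, $\gamma$ = `gamma`, $d = 1$ if the episode is `done` and $0$ otherwise, $\pi(a \mid s)$ = `probs[a]` where $a$ is the argmax of `act`, $V(s)$ = `oldvalue` and $V(s')$ = `newvalue`:

$$
\delta = r + \gamma\,(1 - d)\,V(s') - V(s)
$$

$$
L = \underbrace{-\log \pi(a \mid s)\,\delta}_{\text{actor}} + \underbrace{\delta^{2}}_{\text{critic}}
$$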
Calculating the loss from these works fine, but calculating the gradients from it is where the errors come in.
I looked into `SameDiffLoss`, but I figured I couldn't feed all my inputs into it. I also looked at external errors (which is roughly what I'm going for) and at the linear regression example from the SameDiff examples on GitHub, but doing something similar doesn't seem to work.
I might have done something stupid, so here is my code for the training part:
*I didn't add the backpropagation, because just getting the `gradMap` already throws an error.
```java
import java.util.HashMap;
import java.util.Map;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

// `model` (a ComputationGraph whose outputs are [action probabilities, state value]) and `gamma` are fields of the agent
void learn(INDArray oldstate, INDArray newstate, double[] act, double reward, boolean done) {
    SameDiff sd = SameDiff.create();
    // output[0] = action probabilities, output[1] = value of the old state
    INDArray[] output = model.output(oldstate);

    // create vars: the taken action (one-hot) plus placeholders for the network outputs
    SDVariable action = sd.var(Nd4j.create(act));
    SDVariable probs = sd.placeHolder("probs", DataType.FLOAT, 1, 3);
    SDVariable oldvalue = sd.placeHolder("oldvalue", DataType.FLOAT, 1, 1);
    SDVariable newvalue = sd.placeHolder("newvalue", DataType.FLOAT, 1, 1);
    action = sd.squeeze(action, 0);
    probs = sd.squeeze(probs, 0);
    oldvalue = sd.squeeze(oldvalue, 0);
    newvalue = sd.squeeze(newvalue, 0);

    // calculate total loss
    SDVariable logprobs = sd.math().log(probs);
    SDVariable logprob = logprobs.get(sd.argmax(action, 0));        // log-probability of the taken action
    SDVariable mask = sd.constant(1).sub(sd.constant(done ? 1 : 0)); // 0 if the episode ended, otherwise 1
    SDVariable delta = newvalue.mul(gamma).mul(mask).sub(oldvalue).add(reward); // TD error
    SDVariable actorloss = logprob.mul(delta).neg();
    SDVariable criticloss = sd.math().pow(delta, 2);
    SDVariable totalloss = actorloss.add(criticloss);
    totalloss.rename("loss");
    sd.addLossVariable(totalloss);

    // calculate gradients
    Map<String, INDArray> placeholderData = new HashMap<>();
    INDArray probsArr = output[0];
    INDArray oldvalueArr = output[1];
    INDArray newvalueArr = model.output(newstate)[1];
    placeholderData.put("probs", probsArr);
    placeholderData.put("oldvalue", oldvalueArr);
    placeholderData.put("newvalue", newvalueArr);
    INDArray lossData = sd.output(placeholderData, "loss").get("loss"); // this works
    Map<String, INDArray> gradMap = sd.calculateGradients(placeholderData, "probs"); // this throws the error
}
```
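For reference, the gradients that `calculateGradients` is being asked for can also be worked out by hand, which is the kind of "external error" one could feed back into the network. The plain-Java sketch below is a hypothetical helper (`A2cLossSketch` is not part of DL4J) that mirrors the loss above for a single step. It assumes `act` is one-hot, so the taken action is just an index, and it differentiates through δ in both terms, as automatic differentiation of that graph does.

```java
import java.util.Arrays;

/**
 * Hypothetical plain-Java sketch (not a DL4J API) of the single-step A2C loss
 * from the question and its gradients with respect to the network outputs.
 */
public class A2cLossSketch {

    /** TD error: reward + gamma * (1 - done) * V(s') - V(s). */
    public static double delta(double oldValue, double newValue, double reward,
                               double gamma, boolean done) {
        double mask = done ? 0.0 : 1.0;
        return newValue * gamma * mask - oldValue + reward;
    }

    /** Total loss = -log(pi(a|s)) * delta + delta^2. */
    public static double totalLoss(double[] probs, int action, double oldValue,
                                   double newValue, double reward, double gamma, boolean done) {
        double d = delta(oldValue, newValue, reward, gamma, done);
        double actorLoss = -Math.log(probs[action]) * d;
        double criticLoss = d * d;
        return actorLoss + criticLoss;
    }

    /** dLoss/dProbs: only the entry of the taken action is non-zero. */
    public static double[] gradProbs(double[] probs, int action, double oldValue,
                                     double newValue, double reward, double gamma, boolean done) {
        double d = delta(oldValue, newValue, reward, gamma, done);
        double[] g = new double[probs.length]; // all zeros by default
        g[action] = -d / probs[action];        // d/dp of -log(p) * delta
        return g;
    }

    /** dLoss/dV(s), differentiating through delta in both the actor and critic terms. */
    public static double gradOldValue(double[] probs, int action, double oldValue,
                                      double newValue, double reward, double gamma, boolean done) {
        double d = delta(oldValue, newValue, reward, gamma, done);
        // dLoss/d(delta) = -log(pi(a|s)) + 2 * delta, and d(delta)/d(oldValue) = -1
        return Math.log(probs[action]) - 2.0 * d;
    }

    public static void main(String[] args) {
        double[] probs = {0.2, 0.5, 0.3};
        int action = 1;
        System.out.println("loss = " + totalLoss(probs, action, 0.4, 0.6, 1.0, 0.9, false));
        System.out.println("dL/dprobs = "
                + Arrays.toString(gradProbs(probs, action, 0.4, 0.6, 1.0, 0.9, false)));
        System.out.println("dL/dV(s) = " + gradOldValue(probs, action, 0.4, 0.6, 1.0, 0.9, false));
    }
}
```

Because δ also depends on V(s), the actor term contributes log π(a|s) to dL/dV(s). Many A2C implementations treat δ as a constant inside the actor term, in which case that contribution drops out and dL/dV(s) reduces to -2δ.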