I am trying to teach a network to play Snake. I am using RL4J 1.0.0-beta6.
I have tried:
- Q-learning
- Async Q-learning
- Actor-Critic
None of these approaches has worked. After millions of steps of async Q-learning, my snake has not learned anything and still runs straight into walls.
Here is how my network is set up. I pass in a 256-element double array that represents the Snake board:
- 1 = apple
- 0 = free space
- 0.25 = body
- 0.3 + direction = head (0.1 <= direction <= 0.4)
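As a minimal sketch of that encoding, assume a 16 x 16 board (16 x 16 = 256 inputs), row-major flattening, and hypothetical direction codes of 0.1, 0.2, 0.3, and 0.4 for up, right, down, and left. `BoardEncoder`, `Cell`, and `DIRECTION_CODES` are illustrative names only, not RL4J API.

```java
// Hypothetical sketch: encodes an assumed 16x16 Snake board into the
// 256-element double array described above. All names here are illustrative.
final class BoardEncoder {
    static final int BOARD_SIZE = 16; // assumed: 16 * 16 = 256 inputs

    enum Cell { FREE, APPLE, BODY, HEAD }

    // Assumed direction codes within 0.1 <= direction <= 0.4,
    // indexed by heading: 0 = up, 1 = right, 2 = down, 3 = left.
    static final double[] DIRECTION_CODES = {0.1, 0.2, 0.3, 0.4};

    static double[] encode(Cell[][] board, int headDirection) {
        double[] input = new double[BOARD_SIZE * BOARD_SIZE];
        for (int row = 0; row < BOARD_SIZE; row++) {
            for (int col = 0; col < BOARD_SIZE; col++) {
                double value;
                switch (board[row][col]) {
                    case APPLE: value = 1.0; break;
                    case BODY:  value = 0.25; break;
                    case HEAD:  value = 0.3 + DIRECTION_CODES[headDirection]; break;
                    default:    value = 0.0; // free space
                }
                input[row * BOARD_SIZE + col] = value; // row-major flattening
            }
        }
        return input;
    }
}
```

Under these assumptions the head takes values from 0.4 to 0.7, so it sits between the body (0.25) and the apple (1.0) in the input range.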
My neural network then chooses 1, 2, or 3 (forward, left, or right).
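Because those actions are relative to the snake's current heading, a sketch of turning them into an absolute heading might look like this. The clockwise heading numbering (0 = up, 1 = right, 2 = down, 3 = left) and the `RelativeActions` class are assumptions for illustration.

```java
// Hypothetical sketch: maps a relative action (1 = forward, 2 = left,
// 3 = right) onto an absolute heading numbered clockwise from 0 = up.
final class RelativeActions {
    static int nextHeading(int heading, int action) {
        switch (action) {
            case 1: return heading;           // forward: keep current heading
            case 2: return (heading + 3) % 4; // left: one step counter-clockwise
            case 3: return (heading + 1) % 4; // right: one step clockwise
            default: throw new IllegalArgumentException("unknown action: " + action);
        }
    }
}
```

If the MDP's action space is an RL4J `DiscreteSpace` of size 3, my understanding is that the action indices it produces run from 0 to 2, so a shift of one may be needed before a mapping like this.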
My reward function is:

```java
if (!snake.inGame()) {
    return -1.0; // dies
}
if (snake.gotApple()) {
    return 1.0; // got apple
    //return 5.0 + .37 * (snake.getLength());
}
return 0.0; // doesn't die or get apple this move
```
My async Q-learning configuration and network are set up like so:
```java
import org.deeplearning4j.rl4j.learning.async.nstep.discrete.AsyncNStepQLearningDiscrete;
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
import org.nd4j.linalg.learning.config.Adam;

public static AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration TOY_ASYNC_QL =
        new AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration(
                123,     // random seed
                400,     // max steps per epoch
                5000000, // max steps
                16,      // number of threads
                25,      // t_max
                10,      // target network update frequency (hard update)
                10,      // number of no-op warmup steps
                0.01,    // reward scaling
                0.98,    // gamma (discount factor)
                10.0,    // TD-error clipping
                0.15f,   // min epsilon
                30000    // number of steps for epsilon-greedy annealing
        );

public static DQNFactoryStdDense.Configuration MALMO_NET = DQNFactoryStdDense.Configuration.builder()
        .l2(0.001)
        .updater(new Adam(0.0025))
        .numHiddenNodes(64)
        .numLayer(3)
        .build();
```
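To make a few of these numbers concrete, here is a small sketch of an n-step return and an epsilon schedule. It assumes, rather than quotes from RL4J, that each raw reward is multiplied by the reward-scaling factor (0.01) before discounting with gamma (0.98), and that epsilon falls linearly from 1.0 to the minimum (0.15) over the annealing steps (30000). RL4J's exact formulas may differ, and `AsyncConfigMath` is a hypothetical name.

```java
// Hypothetical sketch (not RL4J source) of two quantities the configuration
// controls, under the assumptions stated above.
final class AsyncConfigMath {

    // n-step return: sum over i of gamma^i * (rewardScale * r_i), plus
    // gamma^n * bootstrapValue unless the episode ended (terminal).
    static double nStepReturn(double[] rewards, double rewardScale, double gamma,
                              boolean terminal, double bootstrapValue) {
        double ret = terminal ? 0.0 : bootstrapValue;
        // Work backwards from the last step so each earlier reward
        // picks up one more factor of gamma.
        for (int i = rewards.length - 1; i >= 0; i--) {
            ret = rewardScale * rewards[i] + gamma * ret;
        }
        return ret;
    }

    // Assumed linear anneal from 1.0 down to minEpsilon over annealSteps,
    // then held at minEpsilon.
    static double epsilon(long step, double minEpsilon, long annealSteps) {
        double eps = 1.0 - (1.0 - minEpsilon) * step / (double) annealSteps;
        return Math.max(minEpsilon, eps);
    }
}
```

Under those assumptions, dying two steps after a state contributes 0.98^2 x (-0.01) = -0.009604 to that state's return, and epsilon reaches its 0.15 floor after 30,000 steps, a small fraction of the 5,000,000-step budget.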
I can’t figure out how to get this to work. Is one of my settings wrong, or is it something else?
Additionally, and probably more frustrating, whenever I try to use A3C, it crashes after roughly 10,000 steps with:
```
Exception in thread "Thread-8" java.lang.RuntimeException: Output from network is not a probability distribution: [[ ?, ?, ?]]
    at org.deeplearning4j.rl4j.policy.ACPolicy.nextAction(ACPolicy.java:82)
    at org.deeplearning4j.rl4j.policy.ACPolicy.nextAction(ACPolicy.java:37)
    at org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete.trainSubEpoch(AsyncThreadDiscrete.java:96)
    at org.deeplearning4j.rl4j.learning.async.AsyncThread.handleTraining(AsyncThread.java:144)
    at org.deeplearning4j.rl4j.learning.async.AsyncThread.run(AsyncThread.java:121)
```
Why is this happening? I have tried tweaking the network and the reward function, but nothing stops this error.
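The `?` entries in that message are plausibly NaN values: some Java number formatters (for example, JDK 8's `DecimalFormat`) print NaN as the replacement character U+FFFD, which often shows up as `?`. The sketch below is an illustration, not RL4J's actual `ACPolicy` check. It shows why NaN outputs fail a "is this a probability distribution" test, and how a naive softmax can produce NaN through overflow. `SoftmaxCheck` and its methods are hypothetical names.

```java
// Illustrative sketch (not RL4J code): a policy output must be finite,
// within [0, 1], and sum to 1. NaN entries fail this check.
final class SoftmaxCheck {

    // Naive softmax: Math.exp overflows to Infinity for large logits,
    // and Infinity / Infinity evaluates to NaN.
    static double[] naiveSoftmax(double[] logits) {
        double sum = 0.0;
        double[] out = new double[logits.length];
        for (int i = 0; i < logits.length; i++) { out[i] = Math.exp(logits[i]); sum += out[i]; }
        for (int i = 0; i < out.length; i++) out[i] /= sum;
        return out;
    }

    // Stable softmax: subtracting the max logit keeps every exp() <= 1.
    // It still returns NaN if any logit is already NaN (e.g. NaN weights).
    static double[] stableSoftmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double l : logits) max = Math.max(max, l);
        double sum = 0.0;
        double[] out = new double[logits.length];
        for (int i = 0; i < logits.length; i++) { out[i] = Math.exp(logits[i] - max); sum += out[i]; }
        for (int i = 0; i < out.length; i++) out[i] /= sum;
        return out;
    }

    static boolean isProbabilityDistribution(double[] p, double tol) {
        double sum = 0.0;
        for (double v : p) {
            if (Double.isNaN(v) || v < 0.0 || v > 1.0) return false;
            sum += v;
        }
        return Math.abs(sum - 1.0) <= tol;
    }
}
```

The stable version only guards against overflow in the softmax itself. If training has already driven the network's weights or logits to NaN, every output entry will be NaN regardless of how the softmax is computed.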
Deep learning seems like a fun thing to explore, but it becomes very frustrating when I can’t even get the library to work.
I would appreciate any help y’all can offer.