I am trying to put a learning rate ParameterSpace into each updater in a DiscreteParameterSpace, but it's not apparent how to do it.
Second, I get an error from an LSTMLayerSpace if I use an activation function other than TANH. Is this expected behavior?
Here's a code snippet relevant to the question:
ParameterSpace<Double> learnRate = new ContinuousParameterSpace(minLR, maxLR);
DiscreteParameterSpace<Activation> activationSpace = new DiscreteParameterSpace<>(new Activation[]{Activation.TANH}); // TODO: 7/22/2020 multi activations not working
DiscreteParameterSpace<WeightInit> weightSpace = new DiscreteParameterSpace<WeightInit>(new WeightInit[]{WeightInit.XAVIER});
//DiscreteParameterSpace<IUpdater> updaterSpace = new DiscreteParameterSpace<IUpdater>(new IUpdater[]{
// (IUpdater) new AdamSpace(learnRate), (IUpdater) new NesterovsSpace(learnRate), (IUpdater) new AdaGradSpace(learnRate)});
MultiLayerSpace mlSpace = new MultiLayerSpace.Builder()
.weightInit(weightSpace)
.updater(new SgdSpace(learnRate))// TODO: 7/21/2020 multiple updaters not working
.addLayer(new LSTMLayerSpace.Builder().activation(activationSpace).nIn(input).nOut(layerSize).build())
.addLayer(new LSTMLayerSpace.Builder().activation(activationSpace).nIn(layerSize).nOut(layerSize).build(), layerCount)
.addLayer(new RnnOutputLayerSpace.Builder()
.nIn(layerSize)
.nOut(output)
// ...(snippet trimmed; the full class is below)
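On the updater question: as far as I can tell, the commented-out cast compiles but fails at runtime with a ClassCastException, because AdamSpace, NesterovsSpace, and AdaGradSpace implement ParameterSpace<IUpdater> rather than IUpdater, so they can't be elements of a DiscreteParameterSpace<IUpdater>. The only workaround I've come up with is to loop over the updater spaces and run one search per updater family; a sketch, reusing the variables from buildSearch below:

List<ParameterSpace<IUpdater>> updaterSpaces = Arrays.asList(
        new AdamSpace(learnRate),
        new NesterovsSpace(learnRate),
        new AdaGradSpace(learnRate));
for (ParameterSpace<IUpdater> updaterSpace : updaterSpaces) {
    //same builder as above, parameterized on the updater family
    MultiLayerSpace mlSpace = new MultiLayerSpace.Builder()
            .weightInit(weightSpace)
            .updater(updaterSpace)
            .addLayer(new LSTMLayerSpace.Builder().activation(activationSpace).nIn(input).nOut(layerSize).build())
            .addLayer(new LSTMLayerSpace.Builder().activation(activationSpace).nIn(layerSize).nOut(layerSize).build(), layerCount)
            .addLayer(new RnnOutputLayerSpace.Builder().nIn(layerSize).nOut(output)
                    .activation(Activation.SIGMOID).lossFunction(LossFunctions.LossFunction.MSE).build())
            .numEpochs(epochs)
            .build();
    //...run the OptimizationConfiguration for this mlSpace, then compare best scores across the runs
}

On the activation question, this is the kind of space that throws for me (SOFTSIGN and RELU are just example candidates; the single-element TANH version above runs fine):

DiscreteParameterSpace<Activation> multiActivationSpace = new DiscreteParameterSpace<>(
        new Activation[]{Activation.TANH, Activation.SOFTSIGN, Activation.RELU});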
Here is the entire class. I haven't tried the genetic search examples yet; this is my somewhat naive implementation of an iterative search. As you can see, I'm trying to throw every option at the wall to see what sticks.
public class Search {
public static void multipleSearch(int numberOfSearches, FormattedDataRequester.combination combination, String tick, int miniBatch, int searchLoops, double spaceReducer,
int[] layers, int[] neurons, double[] learnRate, int epochs, int maxTime) throws Exception
{
ArrayList<ArrayList<String>> sumResults = new ArrayList<>();
for (int i = 0; i < numberOfSearches; i++) {
ArrayList<String> n = Search.searchForBestNet(combination,tick,miniBatch,searchLoops,spaceReducer,layers,neurons,learnRate, epochs,maxTime);
sumResults.add(n);
}
String filePath = Constants.logPath + "search_result_" + LocalDate.now().toString() + ".csv";
StringBuilder builder = new StringBuilder();
for (ArrayList<String> arr : sumResults) {
for (String s : arr) {
builder.append(s).append("\n");
}
builder.append("\n");
}
FileManager.rewriteFullCSV(filePath, builder.toString());
}
public static ArrayList<String> searchForBestNet(FormattedDataRequester.combination combination, String tick, int miniBatch, int searchLoops, double spaceReducer,
int[] layers, int[] neurons, double[] learnRate, int epochs, int maxTime ) throws Exception
{
ArrayList<String> output = new ArrayList<>();
maxTime = (maxTime/searchLoops); // max minutes allotted to each loop (MaxTimeCondition below uses minutes) TODO: 7/22/2020 this is a crude stopping mechanism that should be adjusted or allow for a stop/restart
int candidate = (int) ((layers[1] * neurons[1] * 10)/(spaceReducer/2)); //search ~20% of the space; LR added as a discrete factor of 10
//int maxTime = (candidate * epochs) * 5; // 5 minutes per candidate <-- this really depends on datasetiterator size
double totalTimeMax = searchLoops * ((candidate * 5) * Math.pow(epochs, searchLoops)); //rough estimate of how long this might run // TODO: 7/22/2020 wrong
System.out.println(totalTimeMax/60 + " hours");
for (int i = 0; i < searchLoops; i++) {
if (i > 0){
List<String> list = FileManager.retreiveFullCSV(Constants.searchModelPath + "\\" + String.valueOf(i-1) + "\\result.csv");
//[0]index [1]score [2]layers [3]neurons [4]learnrate
String[] results = list.get(0).split(",");
layers = minMaxScaler(Integer.valueOf(results[2]), layers, spaceReducer);
neurons = minMaxScaler(Integer.valueOf(results[3]), neurons, spaceReducer);
learnRate = minMaxScaler(Double.valueOf(results[4]), learnRate, spaceReducer);
epochs = (int) (epochs/spaceReducer);
candidate = (layers[1] * neurons[1])/5;
//maxTime = (candidate * epochs) * 5;
}
String s = buildSearch(combination, tick, miniBatch, Constants.searchModelPath + "\\" + String.valueOf(i),
maxTime, candidate, epochs,
layers[0], layers[1], neurons[0], neurons[1],
learnRate[0], learnRate[1]);
output.add(s);
System.out.println("COMPLETED : " + i);
}
System.out.println(" ----- COMPLETED SEARCH ------ ");
return output;
}
private static String buildSearch(FormattedDataRequester.combination combi, String tick, int miniBatch, String directory,
int maxTime, int candidateCount, int epochs,
int minLay, int maxLay, int minLayCount, int maxLayCount,
double minLR, double maxLR) throws Exception
{
DataSetIterator it = FormattedDataRequester.sequenceIterator(combi, tick, miniBatch);
int input = it.inputColumns();
int output = it.totalOutcomes();
//This parameter space is a wide net. We won't test every value within these spaces but
//will perform N tests and use those tests to narrow our search
ParameterSpace<Integer> layerSize = new IntegerParameterSpace(minLayCount, maxLayCount);
ParameterSpace<Integer> layerCount = new IntegerParameterSpace(minLay, maxLay);
ParameterSpace<Double> learnRate = new ContinuousParameterSpace(minLR, maxLR);
DiscreteParameterSpace<Activation> activationSpace = new DiscreteParameterSpace<>(new Activation[]{Activation.TANH}); // TODO: 7/22/2020 multi activations not working
DiscreteParameterSpace<WeightInit> weightSpace = new DiscreteParameterSpace<WeightInit>(new WeightInit[]{WeightInit.XAVIER});
//DiscreteParameterSpace<IUpdater> updaterSpace = new DiscreteParameterSpace<IUpdater>(new IUpdater[]{
// (IUpdater) new AdamSpace(learnRate), (IUpdater) new NesterovsSpace(learnRate), (IUpdater) new AdaGradSpace(learnRate)});
MultiLayerSpace mlSpace = new MultiLayerSpace.Builder()
.weightInit(weightSpace)
.updater(new SgdSpace(learnRate))// TODO: 7/21/2020 multiple updaters not working
.addLayer(new LSTMLayerSpace.Builder().activation(activationSpace).nIn(input).nOut(layerSize).build())
.addLayer(new LSTMLayerSpace.Builder().activation(activationSpace).nIn(layerSize).nOut(layerSize).build(), layerCount)
.addLayer(new RnnOutputLayerSpace.Builder()
.nIn(layerSize)
.nOut(output)
.activation(Activation.SIGMOID)
.lossFunction(LossFunctions.LossFunction.MSE)
.build())
.numEpochs(epochs)
.build();
//coarse search is random
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mlSpace, null);
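//alternatively, GridSearchCandidateGenerator could sweep the space exhaustively, but random sampling fits the narrowing approach better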
Class<? extends DataSource> dataSourceClass = SearchDataSource.class;
Properties dataSourceProperties = new Properties();
dataSourceProperties.setProperty("minibatchSize", String.valueOf(miniBatch));
dataSourceProperties.setProperty("combination", combi.toString());
dataSourceProperties.setProperty("tick", tick);
//model saver, set up per the Arbiter example
File f = new File(directory);
if (f.exists()) f.delete(); // note: delete() silently fails on a non-empty directory
f.mkdir();
ResultSaver modelSaver = new FileModelSaver(directory);
ScoreFunction scoreFunction = new RegressionScoreFunction(RegressionEvaluation.Metric.MSE);
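//stop on whichever comes first: the wall-clock limit or the candidate count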
TerminationCondition[] terminationConditions = {
new MaxTimeCondition(maxTime, TimeUnit.MINUTES),
new MaxCandidatesCondition(candidateCount)};
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
.candidateGenerator(candidateGenerator)
.dataSource(dataSourceClass, dataSourceProperties)
.modelSaver(modelSaver)
.scoreFunction(scoreFunction)
.terminationConditions(terminationConditions)
.build();
IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator());
//StatsStorage ss = new FileStatsStorage(new File(System.getProperty("java.io.tmpdir"), "arbiter.dl4j"));
//runner.addListeners(new ArbiterStatusListener(ss));
//UIServer.getInstance().attach(ss);
runner.execute();
//Print out some basic stats regarding the optimization procedure
String s = "Best score: " + runner.bestScore() + "\n" +
"Index of model with best score: " + runner.bestScoreCandidateIndex() + "\n" +
"Number of configurations evaluated: " + runner.numCandidatesCompleted() + "\n";
System.out.println(s);
StringBuilder sb = new StringBuilder();
sb.append(runner.bestScoreCandidateIndex()).append(",").append(runner.bestScore()).append(",");
//Get all results, and print out details of the best result:
int indexOfBestResult = runner.bestScoreCandidateIndex();
List<ResultReference> allResults = runner.getResults();
OptimizationResult bestResult = allResults.get(indexOfBestResult).getResult();
MultiLayerNetwork bestModel = (MultiLayerNetwork) bestResult.getResultReference().getResultModel();
System.out.println("\n\nConfiguration of best model:\n");
//System.out.println(bestModel.getLayerWiseConfigurations().toJson());
System.out.println(bestModel.summary());
sb.append(bestModel.getLayers().length).append(",");
sb.append(((FeedForwardLayer) bestModel.getLayers()[0].conf().getLayer()).getNOut()).append(",");
sb.append( bestModel.getLearningRate(0));
FileManager.rewriteFullCSV(directory + "\\result.csv", sb.toString());
//Wait a while before exiting
//Thread.sleep(60000);
//UIServer.getInstance().stop();
return sb.toString();
}
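//Note on the scalers below: the new interval is centered on the best value found, with
//width = old range * reducer. Worked example (hypothetical values):
//minMaxScaler(0.01, new double[]{0.001, 0.1}, 0.5) -> space = (0.099 * 0.5)/2 = 0.02475,
//so the new range is {-0.01475, 0.03475}. Nothing clamps the lower bound, so a center
//near old[0] can yield a negative learning rate (or a layer count below 1).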
private static int[] minMaxScaler(int center, int[] old, double reducer){
double space = ((old[1] - old[0]) * reducer)/2;
int margin = (int) Math.round(space);
return new int[]{center - margin, center + margin};
}
private static double[] minMaxScaler(double center, double[] old, double reducer){
double space = ((old[1] - old[0]) * reducer)/2;
return new double[]{center - space, center + space};
}
//Static class that holds the data
public static class SearchDataSource implements DataSource {
private int miniBatch;
private FormattedDataRequester.combination combo;
private String tick;
DataSetIterator trainIter;
DataSetIterator testIter;
public SearchDataSource()
{
}
private void setData(DataSetIterator iter)
{
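//Take the next batch from the source iterator, split it 70/30, fit the min-max
//normalizer (labels included, via fitLabel(true)) on the training portion only, then
//normalize both splits in place before wrapping them as iterators for Arbiter.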
DataSet dd = iter.next();
ArrayList<DataSet> split = FormattedDataRequester.splitDataSet(dd, .7);
DataSet training = split.get(0);
DataSet test = split.get(1);
DataNormalization normalization = new NormalizerMinMaxScaler();
normalization.fitLabel(true);
normalization.fit(training);
normalization.preProcess(training);
normalization.preProcess(test);
trainIter = new ExistingDataSetIterator(training);
testIter = new ExistingDataSetIterator(test);
}
@Override
public void configure(Properties properties) {
this.miniBatch = Integer.parseInt(properties.getProperty("minibatchSize", "1"));
this.combo = FormattedDataRequester.combination.valueOf(properties.getProperty("combination"));
this.tick = properties.getProperty("tick");
try{
DataSetIterator iter = FormattedDataRequester.sequenceIterator(combo, tick, miniBatch);
setData(iter);
}catch (Exception e){
e.printStackTrace();
}
}
@Override
public Object trainData() {
return trainIter;
}
@Override
public Object testData() {
return testIter;
}
@Override
public Class<?> getDataType() {
return DataSetIterator.class;
}
}
}
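For completeness, this is roughly how I kick it off (the tick and combination values are placeholders; the numbers are just what I've been experimenting with):

//2 independent searches, 3 narrowing loops each, halving the space every loop;
//1-5 layers of 10-100 neurons, LR 1e-4 to 1e-1, 20 epochs, 360 minutes per search
Search.multipleSearch(2, FormattedDataRequester.combination.values()[0], "TICK", 32, 3, 0.5,
        new int[]{1, 5}, new int[]{10, 100}, new double[]{1e-4, 1e-1}, 20, 360);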