I try to educate binary classifier. I have 178200 records and i checked that in this data i have a lot of 1 and 0 as a result.
But i get
here is my code:
public class Network {
private static final int CLASSES_COUNT = 2;
private static final int FEATURES_COUNT = 1000;
public void train() throws IOException, InterruptedException {
int seed = 123456;
int numInputs = 1000;
int numOutputs = 2;
int numHiddenNodes = 2 * numInputs + numOutputs;
double learningRate = 0.005;
RecordReader recordReader = new CSVRecordReader(0, ',');
recordReader.initialize(new FileSplit(new ClassPathResource("data.csv").getFile()));
DataSet allData;
System.out.println("start to read data");
DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, 17820, FEATURES_COUNT, CLASSES_COUNT);
allData = iterator.next();
//System.out.println("read all data");
allData.shuffle();
SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);
DataSet trainingData = testAndTrain.getTrain();
DataSet testData = testAndTrain.getTest();
//DataSet testData = iterator.next();
System.out.println("splitted data start to build configuration");
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(seed)
.biasInit(1)
.l2(1e-4)
.updater(new Nesterovs(learningRate, 0.9))
.list()
.layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
.weightInit(WeightInit.XAVIER)
.activation(Activation.RELU)
.build())
.layer(1, new DenseLayer.Builder().nIn(numHiddenNodes).nOut(numHiddenNodes)
.weightInit(WeightInit.XAVIER)
.activation(Activation.RELU)
.build())
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.XENT)
.weightInit(WeightInit.XAVIER)
.activation(Activation.SIGMOID)
.nIn(numHiddenNodes).nOut(numOutputs).build())
.build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
System.out.println("compiled configuration");
System.out.println("start to fit model");
/*DataSet allData2;
DataSetIterator iterator2 = new RecordReaderDataSetIterator(recordReader, 1700, FEATURES_COUNT, CLASSES_COUNT);
*/
//for(int i = 0; i < 13; i++) {
//iterator.reset();
while (iterator.hasNext()) {
//allData = iterator.next();
//allData.shuffle();
model.fit(trainingData.getFeatures(), trainingData.getLabels());
allData = iterator.next();
}
//}
//System.out.println("ended epoch number : " + i);
//}
System.out.println("fit ended start evaluating");
INDArray output = model.output(testData.getFeatures());
Evaluation eval = new Evaluation(CLASSES_COUNT);
//Evaluation eval = new EvaluationBinary();
eval.eval(testData.getLabels(), output);
//testData.getLabels();
//EvaluationBinary eval = new EvaluationBinary(testData.getLabels());
System.out.println(eval.stats());
//File locationToSave = new File("");
model.save(new File("mynet.zip"));
recordReader.close();
}
}