My question is how can I train an integer (STOCK_FK) instead of the classification (.classification(finalSchema.getIndexOfColumn(“STOCK_FK”), 3))?..
Schema schema = new Schema.Builder()
.addColumnsInteger("ID", "HZW_FK")
.addColumnInteger("STOCK_FK")
.addColumnsInteger("DISPO_MONTH", "DISPO_YEAR")
.addColumnDouble("DISPO_AMOUNT")
.build();
DataAnalysis analysis = AnalyzeLocal.analyze(schema, recordReader);
HtmlAnalysis.createHtmlAnalysisFile(analysis, new File("C:/Disposition/analysis.html"));
TransformProcess transformProcess = new TransformProcess.Builder(schema)
.removeColumns("ID", "DISPO_MONTH", "DISPO_YEAR")
.normalize("HZW_FK", Normalize.MinMax, analysis)
.normalize("STOCK_FK", Normalize.MinMax, analysis)
.normalize("DISPO_AMOUNT", Normalize.Log2MeanExcludingMin, analysis)
.build();
Schema finalSchema = transformProcess.getFinalSchema();
TransformProcessRecordReader trainRecordReader = new TransformProcessRecordReader(new CSVRecordReader(), transformProcess);
trainRecordReader.initialize(inputSplit);
int batchSize = 30;
RecordReaderDataSetIterator trainIterator = new RecordReaderDataSetIterator.Builder(trainRecordReader, batchSize)
.classification(finalSchema.getIndexOfColumn("STOCK_FK"), 3)
.build();
MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
.seed(0xC0FFEE)
.weightInit(WeightInit.XAVIER)
.activation(Activation.TANH)
.updater(new Adam.Builder().learningRate(0.001).build())
.l2(0.0000316)
.list(
new DenseLayer.Builder().nOut(25).build(),
new DenseLayer.Builder().nOut(25).build(),
new DenseLayer.Builder().nOut(25).build(),
new DenseLayer.Builder().nOut(25).build(),
new DenseLayer.Builder().nOut(25).build(),
new OutputLayer.Builder(new LossMCXENT()).nOut(3).activation(Activation.SOFTMAX).build()
)
.setInputType(InputType.feedForward(finalSchema.numColumns() - 1))
.build();
MultiLayerNetwork model = new MultiLayerNetwork(config);
model.init();
UIServer uiServer = UIServer.getInstance();
StatsStorage statsStorage = new InMemoryStatsStorage();
uiServer.attach(statsStorage);
model.addListeners(new ScoreIterationListener(1));
model.addListeners(new StatsListener(statsStorage, 250));
model.fit(trainIterator, 59);
TransformProcessRecordReader testRecordReader = new TransformProcessRecordReader(new CSVRecordReader(), transformProcess);
testRecordReader.initialize( new FileSplit(new File("C:/Disposition/Test/")));
RecordReaderDataSetIterator testIterator = new RecordReaderDataSetIterator.Builder(testRecordReader, batchSize)
.classification(finalSchema.getIndexOfColumn("STOCK_FK"), 3)
.build();
Evaluation evaluate = model.evaluate(testIterator);
System.out.println(evaluate.stats());
System.out.println("MCC: "+evaluate.matthewsCorrelation(EvaluationAveraging.Macro));
File modelSave = new File("C:/Disposition/model.bin");
model.save(modelSave);
ModelSerializer.addObjectToFile(modelSave, "dataanalysis", analysis.toJson());
ModelSerializer.addObjectToFile(modelSave, "schema", finalSchema.toJson());
}
}