Prediction Makes No Sense

Hey, I’m new to the whole machine learning scene, and I’m playing around with DL4J. I’m trying to get it to classify an Iris flower from its measurements (a very common first machine learning program, I hear).

When I evaluate on the held-out test split (testAndTrain), it reaches roughly 98% accuracy on the Iris Flower Dataset, which is also commonly used for this problem.

The issue is that although it seems to work well, when I actually ask it to predict something, it almost always predicts a “1” Iris (versicolor), even when the input is a “0” Iris (setosa) or a “2” Iris (virginica).

For example, here’s some code:

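```java
// Imports assume a recent DL4J release (1.0.0-M series); some packages
// (e.g. ClassPathResource, Evaluation) lived elsewhere in older versions.
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.File;
import java.io.IOException;

// Class name is a placeholder; the original post showed only the class body.
public class IrisExample {

    private static final int FEATURE_COUNT = 4;
    private static final int CLASS_COUNT = 3;
    private static final int BATCH_SIZE = 150;

    public static final String MODEL_FILE_PATH = "model.txt";

    public static void main(String[] args) throws IOException, InterruptedException {
        // Read the Iris CSV: four feature columns followed by the class index (0, 1 or 2).
        int numLinesToSkip = 0;
        char delimiter = ',';
        RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);
        recordReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));

        // labelIndex = FEATURE_COUNT (4): the label is the fifth column.
        // BATCH_SIZE = 150 loads the whole dataset as a single batch.
        DataSetIterator iterator =
                new RecordReaderDataSetIterator(recordReader, BATCH_SIZE, FEATURE_COUNT, CLASS_COUNT);
        DataSet allData = iterator.next();
        allData.shuffle();
        SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65); // 65% train, 35% test

        DataSet trainingData = testAndTrain.getTrain();
        DataSet testData = testAndTrain.getTest();

        // Standardize features with statistics fitted on the training split only.
        DataNormalization normalizer = new NormalizerStandardize();
        normalizer.fit(trainingData);
        normalizer.transform(trainingData);
        normalizer.transform(testData);

        final int numInputs = 4;
        int outputNum = 3;
        long seed = 6;

        System.out.println("Build model...");

        // 4 -> 3 -> 3 -> 3 network: two tanh hidden layers, softmax output.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .activation(Activation.TANH)
                .weightInit(WeightInit.XAVIER)
                .updater(new Sgd(0.1))
                .l2(1e-4)
                .list()
                .layer(new DenseLayer.Builder().nIn(numInputs).nOut(3)
                        .build())
                .layer(new DenseLayer.Builder().nIn(3).nOut(3)
                        .build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .activation(Activation.SOFTMAX)
                        .nIn(3).nOut(outputNum).build())
                .build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();

        model.setListeners(new ScoreIterationListener(100));

        // 1000 passes over the (full-batch) training split.
        for (int i = 0; i < 1000; i++) {
            model.fit(trainingData);
        }

        // Evaluate on the held-out test split.
        Evaluation eval = new Evaluation(3);
        INDArray output = model.output(testData.getFeatures());
        eval.eval(testData.getLabels(), output);
        System.out.println(eval.stats());

        // Predict a single row from oneIris.txt (four feature columns, no label).
        File inputFile = new ClassPathResource("oneIris.txt").getFile();
        RecordReader inputReader = new CSVRecordReader(numLinesToSkip, delimiter);

        inputReader.initialize(new FileSplit(inputFile));

        DataSetIterator iterator2 =
                new RecordReaderDataSetIterator(inputReader, 1);

        INDArray output2 = model.output(iterator2);
        System.out.println("Prediction: ");
        System.out.println(output2);
    }
}
```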

The output I get is a fairly accurate evaluation on the test data, followed by:

```
Prediction: 
[[    0.0061,    0.9917,    0.0022]]
```

This is not what I expected: oneIris.txt contains only one line (shown below), so I’d expect the highest value at index 0, not index 1.

```
5.1,3.5,1.4,0.2
```

Its features are identical to those of the “0” (setosa) flower on the first line of iris.txt.
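For reference: because the output layer uses SOFTMAX, the printed row holds one probability per class in label order (0 = setosa, 1 = versicolor, 2 = virginica), and the predicted class is the index of the largest value. A minimal plain-Java sketch of that readout (SoftmaxReadout and argMax are hypothetical names for illustration, not DL4J API):

```java
// Hypothetical helper: turns one softmax output row into a class index.
final class SoftmaxReadout {

    // Class names in label order, as described in the question (0, 1, 2).
    static final String[] CLASS_NAMES = {"setosa", "versicolor", "virginica"};

    // Returns the index of the largest probability (the first one wins on ties).
    static int argMax(double[] probabilities) {
        if (probabilities.length == 0) {
            throw new IllegalArgumentException("empty probability row");
        }
        int best = 0;
        for (int i = 1; i < probabilities.length; i++) {
            if (probabilities[i] > probabilities[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // The row printed after "Prediction:" in the question.
        double[] row = {0.0061, 0.9917, 0.0022};
        int predicted = argMax(row);
        System.out.println(predicted + " (" + CLASS_NAMES[predicted] + ")"); // prints "1 (versicolor)"
    }
}
```

So the network is not slightly unsure here; it is confidently choosing versicolor for a row that should be setosa.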

Am I doing something wrong?

@AustinDart your data isn’t being normalized.
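You are passing in an iterator without setting a normalizer, so the model receives raw measurements even though it was trained on standardized ones. Either use a different signature (e.g. model.output(INDArray) with data you have already normalized) or call setPreProcessor(normalizer) on iterator2 so it normalizes the data before passing the result to the underlying predict function. In both cases, use the same normalizer instance that was fitted on the training data.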
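To make this concrete, here is a plain-Java sketch of what a standardizing normalizer does: learn a mean and standard deviation per feature column from the training rows, then map every row through (x - mean) / std. It is not DL4J's NormalizerStandardize code. The class name, the toy two-column training rows, and the use of the population standard deviation are assumptions for illustration, and DL4J's own estimator may differ in such details. The point is that a network trained on standardized values only sees sensible inputs when the exact same fitted statistics are applied to every new row.

```java
import java.util.Arrays;

// Hypothetical stand-in for a standardizing normalizer (not DL4J code).
final class ColumnStandardizer {
    private double[] mean;
    private double[] std;

    // Learn per-column mean and population standard deviation from training rows.
    void fit(double[][] rows) {
        int cols = rows[0].length;
        mean = new double[cols];
        std = new double[cols];
        for (double[] row : rows) {
            for (int j = 0; j < cols; j++) {
                mean[j] += row[j];
            }
        }
        for (int j = 0; j < cols; j++) {
            mean[j] /= rows.length;
        }
        for (double[] row : rows) {
            for (int j = 0; j < cols; j++) {
                double d = row[j] - mean[j];
                std[j] += d * d;
            }
        }
        for (int j = 0; j < cols; j++) {
            std[j] = Math.sqrt(std[j] / rows.length);
            // Guard against constant columns so we never divide by zero.
            if (std[j] == 0.0) {
                std[j] = 1.0;
            }
        }
    }

    // Apply the statistics learned in fit(); returns a new array.
    double[] transform(double[] row) {
        double[] out = new double[row.length];
        for (int j = 0; j < row.length; j++) {
            out[j] = (row[j] - mean[j]) / std[j];
        }
        return out;
    }

    public static void main(String[] args) {
        // Toy training rows (made up, two features only) to show the mechanics.
        double[][] train = {{1.0, 10.0}, {3.0, 30.0}};
        ColumnStandardizer s = new ColumnStandardizer();
        s.fit(train);
        double[] input = {3.0, 30.0};
        // Raw row: a scale the network never saw during training.
        System.out.println("raw:          " + Arrays.toString(input));
        // Standardized with the training statistics: what the network expects.
        System.out.println("standardized: " + Arrays.toString(s.transform(input))); // prints [1.0, 1.0]
    }
}
```

In the question's code, the equivalent of calling transform on the new row is reusing normalizer, the instance already fitted on trainingData, for the row in oneIris.txt. Fitting a fresh normalizer on that single row would not help: every column would have zero spread, so the fitted statistics would carry no information.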