I used ND4J to build a model to make predictions on certain data, and I wanted to implement the same functionality with SameDiff, but it didn’t seem to work as expected. I don’t know if it’s my usage that’s wrong, and if it’s wrong, how can I modify it?The preprocessing and plotting parts of the data are simplified here
Below is the model built by ND4J
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import java.util.ArrayList;
import java.util.List;
public class HourDataForecastWithND4J {
    static int batch_size;
    public static void main(String[] args) throws Exception {
        // ====================================prep====================================
        INDArray data = Nd4j.rand(17379, 59);
        int total_row = (int) data.shape()[0];
        int val_size = total_row - 21 * 24;
        int train_size = val_size - 60 * 24;
        int casualIndex = (int) data.shape()[1] - 3;
        int cntIndex = (int) data.shape()[1] - 1;
        // Normalization
        INDArray tmp = data.get(NDArrayIndex.all(), NDArrayIndex.interval(cntIndex, cntIndex + 1));
        double mean = tmp.meanNumber().doubleValue();
        double sigma = tmp.stdNumber().doubleValue();
        tmp.subi(mean).divi(sigma);
        INDArrayIndex trainIndex = NDArrayIndex.interval(0, train_size);
        INDArrayIndex valIndex = NDArrayIndex.interval(train_size, val_size);
        INDArrayIndex testIndex = NDArrayIndex.interval(val_size, total_row);
        INDArrayIndex featuresIndex = NDArrayIndex.interval(0, casualIndex);
        INDArrayIndex targetsIndex = NDArrayIndex.interval(casualIndex, cntIndex + 1);
        INDArray train_features = data.get(trainIndex, featuresIndex);
        INDArray train_targets = data.get(trainIndex, targetsIndex);
        INDArray val_features = data.get(valIndex, featuresIndex);
        INDArray val_targets = data.get(valIndex, targetsIndex);
        INDArray test_features = data.get(testIndex, featuresIndex);
        INDArray test_targets = data.get(testIndex, targetsIndex);
        // ====================================param====================================
        // The setting of hyperparameters
        batch_size = 4;
        int seed = 1234;
        int iterations = 6000;
        double learning_rate = 0.5;
        int hidden_nodes = 12;
        int input_nodes = (int) train_features.shape()[1];
        int output_nodes = 1;
        // Initialize the weights
        Random rng = Nd4j.getRandom();
        rng.setSeed(seed);
        INDArray weights_input_to_hidden = Nd4j.rand(0, Math.pow(input_nodes, -0.5), rng, input_nodes, hidden_nodes);
        INDArray weights_hidden_to_output = Nd4j.rand(0, Math.pow(hidden_nodes, -0.5), rng, hidden_nodes, output_nodes);
        // ====================================train====================================
        List<Double> xSeries = new ArrayList<>();
        List<Double> ySeries = new ArrayList<>();
        // Each time, 128 records are randomly extracted from the training dataset for training
        java.util.Random rand = new java.util.Random();
        int[] batchIndexes = new int[batch_size];
        for (int i = 0; i < iterations; i++) {
            for (int j = 0; j < batch_size; j++) {
                batchIndexes[j] = rand.nextInt(train_features.rows());
            }
            // Get training records and goals
            INDArray X = train_features.getRows(batchIndexes);
            INDArray y = train_targets.getRows(batchIndexes).getColumn(2);
            train(X, y, learning_rate, weights_input_to_hidden, weights_hidden_to_output);
            // Print out the training process
            double train_loss = MSE(run(train_features, weights_input_to_hidden, weights_hidden_to_output).transpose(), train_targets.getColumn(2));
            double val_loss = MSE(run(val_features, weights_input_to_hidden, weights_hidden_to_output).transpose(), val_targets.getColumn(2));
            System.out.print("\rProgress: " + String.format("%.1f", 100 * i / (float) iterations)
                                     + "% ... Training loss: " + String.format("%.5f", train_loss)
                                     + " ... Validation loss: " + String.format("%.5f", val_loss));
            System.out.flush();
            xSeries.add(train_loss);
            ySeries.add(val_loss);
        }
        plot(xSeries, ySeries);
        // ====================================test====================================
        // Originally, normalization was done by subtracting the mean and dividing by the variance, and now it is necessary to multiply the variance plus the mean
        INDArray predictions = run(test_features, weights_input_to_hidden, weights_hidden_to_output).transpose().mul(sigma).add(mean);
        INDArray actuals = test_targets.getColumn(2).mul(sigma).add(mean);
        plot(predictions, actuals);
    }
    public static void train(INDArray features, INDArray targets, double lr, INDArray weights_input_to_hidden, INDArray weights_hidden_to_output) {
        // From here, do the forward operation 128X56 to multiply the matrix by a 56X8
        INDArray hidden_inputs = Nd4j.matmul(features, weights_input_to_hidden);// Hide layer input
        INDArray hidden_outputs = activation_function(hidden_inputs);// Hide layer output
        // From here is the hidden layer to the output layer operation
        INDArray final_inputs = Nd4j.matmul(hidden_outputs, weights_hidden_to_output);// The final output layer input
        INDArray final_outputs = final_inputs;// Output of the final output layer (128, 1)
        // Start deploying backpropagation
        INDArray error = Nd4j.math.sub(final_outputs, targets.reshape(batch_size, -1));// Output layer error = actual value minus predicted value
        // Partial derivation of J to Z
        INDArray delta_output = error;
        // The output layer activation function f'(a)=1 has passed
        // Calculate the respective contribution of hidden layers to the error
        INDArray delta_hidden_outputs = Nd4j.matmul(delta_output, weights_hidden_to_output.transpose());
        // Backpropagation of error terms hidden_outputs=f(a)
        INDArray delta_hidden_inputs = hidden_outputs.mul(hidden_outputs.rsub(1.0).mul(delta_hidden_outputs));
        // Weight gradient update (input layer to hidden layer)
        INDArray delta_weights_i_h = features.transpose().mmul(delta_hidden_inputs);
        // Weight gradient update (hidden layer to output layer)
        INDArray delta_weights_h_o = hidden_outputs.transpose().mmul(delta_output);
        // Update the weights
        int n_records = (int) features.shape()[0];
        weights_hidden_to_output.subi(delta_weights_h_o.mul(lr).divi(n_records));// Update using gradient descent hidden-to-output weights
        weights_input_to_hidden.subi(delta_weights_i_h.mul(lr).divi(n_records));// Update using gradient descent input-to-hidden weights
    }
    public static INDArray run(INDArray features, INDArray weights_input_to_hidden, INDArray weights_hidden_to_output) {
        // Hides the inputs and outputs of layers
        INDArray hidden_inputs = Nd4j.matmul(features, weights_input_to_hidden);// Hide layer input
        INDArray hidden_outputs = activation_function(hidden_inputs);// Hide layer output
        // The input and output of the final output layer
        INDArray final_inputs = Nd4j.matmul(hidden_outputs, weights_hidden_to_output);// The input to the final output layer
        INDArray final_outputs = final_inputs;// the output of the final output layer
        return final_outputs;
    }
    public static INDArray activation_function(INDArray hidden_inputs) {
        return Nd4j.nn.sigmoid(hidden_inputs);
    }
    public static double MSE(INDArray y, INDArray Y) {
        INDArray diff = y.sub(Y);
        diff.muli(diff);
        return diff.meanNumber().doubleValue();
    }
    public static void plot(INDArray s1, INDArray s2) {
    }
    public static void plot(List<Double> s1, List<Double> s2) {
    }
}
Below is a model built using sameDiff
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.weightinit.impl.LecunUniformInitScheme;
import java.util.*;
public class HourDataForecastWithSameDiff {
    static int batch_size = 4;
    public static void main(String[] args) throws Exception {
        // ====================================prep====================================
        INDArray data = Nd4j.rand(17379, 59);
        int total_row = (int) data.shape()[0];
        int val_size = total_row - 21 * 24;
        int train_size = val_size - 60 * 24;
        int casualIndex = (int) data.shape()[1] - 3;
        int cntIndex = (int) data.shape()[1] - 1;
        // Normalization
        INDArray tmp = data.get(NDArrayIndex.all(), NDArrayIndex.interval(cntIndex, cntIndex + 1));
        double mean = tmp.meanNumber().doubleValue();
        double sigma = tmp.stdNumber().doubleValue();
        tmp.subi(mean).divi(sigma);
        INDArrayIndex trainIndex = NDArrayIndex.interval(0, train_size);
        INDArrayIndex valIndex = NDArrayIndex.interval(train_size, val_size);
        INDArrayIndex testIndex = NDArrayIndex.interval(val_size, total_row);
        INDArrayIndex featuresIndex = NDArrayIndex.interval(0, casualIndex);
        INDArrayIndex targetsIndex = NDArrayIndex.interval(casualIndex, cntIndex + 1);
        INDArray train_features = data.get(trainIndex, featuresIndex);
        INDArray train_targets = data.get(trainIndex, targetsIndex);
        INDArray val_features = data.get(valIndex, featuresIndex);
        INDArray val_targets = data.get(valIndex, targetsIndex);
        INDArray test_features = data.get(testIndex, featuresIndex);
        INDArray test_targets = data.get(testIndex, targetsIndex);
        // ====================================param====================================
        batch_size = 4;
        int seed = 1234;
        int iterations = 6000;
        double learning_rate = 0.5;
        int hidden_nodes = 12;
        int input_nodes = (int) train_features.shape()[1];
        int output_nodes = 1;
        // Create a neural network graph
        SameDiff sd = SameDiff.create();
        // The first step is to define the features and labels
        SDVariable train_x = sd.placeHolder("input", DataType.FLOAT, -1, input_nodes);
        SDVariable train_y = sd.placeHolder("label", DataType.FLOAT, -1, output_nodes);
        // Initialize the weights
        SDVariable weights_input_to_hidden = sd.var("weights_input_to_hidden", new LecunUniformInitScheme('c', input_nodes), DataType.FLOAT, input_nodes, hidden_nodes);
        SDVariable weights_hidden_to_output = sd.var("weights_hidden_to_output", new LecunUniformInitScheme('c', hidden_nodes), DataType.FLOAT, hidden_nodes, output_nodes);
        // The second step is to build the diagram
        SDVariable hidden_inputs = sd.mmul("hidden_inputs", train_x, weights_input_to_hidden);
        SDVariable hidden_outputs = sd.nn.sigmoid("hidden_outputs", hidden_inputs);
        SDVariable final_outputs = sd.mmul("final_outputs", hidden_outputs, weights_hidden_to_output);
        // The third step is to calculate the loss and define the training operation
        SDVariable loss = sd.math.squaredDifference(final_outputs, train_y).mean("loss");
        sd.setLossVariables(loss);
        TrainingConfig config = TrainingConfig.builder()
                                              .dataSetFeatureMapping("input")
                                              .dataSetLabelMapping("label")
                                              .minimize(true)
                                              .minimize("loss")
                                              .updater(new Sgd(learning_rate))
                                              .build();
        sd.setTrainingConfig(config);
        // ====================================train====================================
        List<Double> xSeries = new ArrayList<>();
        List<Double> ySeries = new ArrayList<>();
        Map<String, INDArray> placeholderData = new HashMap<>();
        Random rand = new Random();
        // Each time, 128 records are randomly extracted from the training dataset for training
        int[] batchIndexes = new int[batch_size];
        DataSet dataSet = new DataSet();
        for (int i = 0; i < iterations; i++) {
            // Get training records and goals
            for (int j = 0; j < batch_size; j++) {
                batchIndexes[j] = rand.nextInt(train_features.rows());
            }
            INDArray x = train_features.getRows(batchIndexes);
            INDArray y = train_targets.getRows(batchIndexes).getColumn(2);
            dataSet.setFeatures(x);
            dataSet.setLabels(y);
            sd.fit(dataSet);
            placeholderData.put("input", x);
            placeholderData.put("label", y);
            INDArray train_loss = sd.outputSingle(placeholderData, "loss");
            placeholderData.put("input", val_features);
            placeholderData.put("label", val_targets);
            INDArray val_loss = sd.outputSingle(placeholderData, "loss");
            System.out.print("\rtrain_loss: " + train_loss + "...val_loss: " + val_loss);
            xSeries.add(Double.parseDouble(train_loss.toString()));
            xSeries.add(Double.parseDouble(val_loss.toString()));
        }
        plot(xSeries, ySeries);
        // ====================================test====================================
        placeholderData.put("input", test_features);
        placeholderData.put("label", test_targets);
        // Originally, normalization was done by subtracting the mean and dividing by the variance, and now it is necessary to multiply the variance plus the mean
        INDArray predictions = sd.outputSingle(placeholderData, "final_outputs");
        INDArray predictN = predictions.mul(sigma).addi(mean);
        INDArray actualsN = test_targets.getColumn(2).mul(sigma).add(mean);
        plot(predictN, actualsN);
    }
    public static void plot(INDArray s1, INDArray s2) {
    }
    public static void plot(List<Double> s1, List<Double> s2) {
    }
}