The Loss of the SameDiff converges to the same value

@agibsonccc I implemented an sLSTM ( xLSTM) model with residual network using SameDiff’s. However, the loss always converges to the same value, and the graph page in uiServer does not display any content. Can you help me identify the issue in my code? The official documentation has very few complex examples of SameDiff. Could you simplify my case into an example for beginners to learn from? Thank you!
How do I upload the code file?

public class SLSTMCell {
    private SameDiff sd;
    private SDVariable w_i, w_f, w_o, w_z;
    private SDVariable r_i, r_f, r_o, r_z;
    private SDVariable b_i, b_f, b_o, b_z;
    private SDVariable hiddenState, cellState, nState, mState;
    private int inputDim, hiddenDim;
    private long weightsCount, biasCount;
    private Float zeroValue = Float.valueOf(0f);

    public SLSTMCell(SameDiff sd, int layerId, int inputDim, int hiddenDim, int batchSize) {
        this.sd = sd;
        this.inputDim = inputDim;
        this.hiddenDim = hiddenDim;
        initializeParameters(layerId);
        initializeStates(layerId, batchSize);
    }

    private void initializeParameters(int layerId) {
        Nd4j.getRandom().setSeed(System.currentTimeMillis());
        long[] wShape = new long[]{inputDim, hiddenDim};        
        double wSqrtScale = Math.sqrt(2.0 / (inputDim + hiddenDim));        
        INDArray xavierInit = Nd4j.randn(wShape).muli(wSqrtScale);        
        w_i = sd.var(layerId + "w_i", xavierInit); // hiddenDim x inputDim
        xavierInit = Nd4j.randn(wShape).muli(wSqrtScale);        
        w_f = sd.var(layerId + "w_f", xavierInit); 
        xavierInit = Nd4j.randn(wShape).muli(wSqrtScale);        
        w_o = sd.var(layerId + "w_o",  xavierInit); 
        xavierInit = Nd4j.randn(wShape).muli(wSqrtScale);        
        w_z = sd.var(layerId + "w_z", xavierInit); 
        
        long[] rShape = new long[]{hiddenDim, hiddenDim};        
        double rSqrtScale = Math.sqrt(2.0 / (hiddenDim + hiddenDim));        
        xavierInit = Nd4j.randn(rShape).muli(rSqrtScale);        
        r_i = sd.var(layerId + "r_i", xavierInit); // hiddenDim x hiddenDim
        xavierInit = Nd4j.randn(rShape).muli(rSqrtScale);
        r_f = sd.var(layerId + "r_f", xavierInit);
        xavierInit = Nd4j.randn(rShape).muli(rSqrtScale);
        r_o = sd.var(layerId + "r_o", xavierInit);
        xavierInit = Nd4j.randn(rShape).muli(rSqrtScale);
        r_z = sd.var(layerId + "r_z", xavierInit);

        weightsCount = inputDim * hiddenDim * 8;

        b_i = sd.var(layerId + "b_i", Nd4j.randn(1, hiddenDim).mul(0.01));
        b_f = sd.var(layerId + "b_f", Nd4j.randn(1, hiddenDim).mul(0.01));
        b_o = sd.var(layerId + "b_o", Nd4j.randn(1, hiddenDim).mul(0.01));
        b_z = sd.var(layerId + "b_z", Nd4j.randn(1, hiddenDim).mul(0.01));

        biasCount = hiddenDim * 4;
    }

    private void initializeStates(int layerId, int batchSize) {
        hiddenState = sd.constant(layerId + "hs", Nd4j.zeros(batchSize, hiddenDim));
        cellState = sd.constant(layerId + "cs", Nd4j.zeros(batchSize, hiddenDim));
        nState = sd.constant(layerId + "ns", Nd4j.zeros(batchSize, hiddenDim));
        mState = sd.constant(layerId + "ms", Nd4j.zeros(batchSize, hiddenDim));
    }

    public SDVariable forward(SDVariable input) {
        SDVariable i_tilda = sd.math.add(sd.mmul(input, w_i), sd.mmul(hiddenState, r_i)).add(b_i);
        SDVariable f_tilda = sd.math.add(sd.mmul(input, w_f), sd.mmul(hiddenState, r_f)).add(b_f);
        SDVariable o_tilda = sd.math.add(sd.mmul(input, w_o), sd.mmul(hiddenState, r_o)).add(b_o);
        SDVariable z_tilda = sd.math.add(sd.mmul(input, w_z), sd.mmul(hiddenState, r_z)).add(b_z);

        SDVariable i_t = sd.math.exp(i_tilda);
        SDVariable f_t = sd.nn.sigmoid(f_tilda);
        SDVariable newMState = sd.math.max(sd.math.log(f_t).add(mState), sd.math.log(i_t));

        sd.assign(mState, newMState);

        SDVariable i_prime = sd.math.exp(sd.math.log(i_t).sub(mState));
        SDVariable f_prime = sd.math.exp(sd.math.log(f_t).add(mState).sub(mState));

        sd.assign(cellState, f_prime.mul(cellState).add(i_prime.mul(sd.nn.tanh(z_tilda))));
        sd.assign(nState, f_prime.mul(nState).add(i_prime));

        SDVariable c_hat = cellState.div(nState);
        sd.assign(hiddenState, sd.nn.sigmoid(o_tilda).mul(sd.nn.tanh(c_hat)));

        return hiddenState;
    }

    public void resetStates(int batchSize) {
        // Reset hidden states to zero  
        hiddenState.assign(zeroValue);
        cellState.assign(zeroValue);
        nState.assign(zeroValue);
        mState.assign(zeroValue);
    }

    public long getWeightsCount() {
        return weightsCount;
    }

    public void setWeightsCount(long weightsCount) {
        this.weightsCount = weightsCount;
    }

    public long getBiasCount() {
        return biasCount;
    }

    public void setBiasCount(long biasCount) {
        this.biasCount = biasCount;
    }
}
//-----------------------------------The core part of XLstmResidualNetwork class-------------------------
    public XLstmResidualNetwork(int inputDim, int[] hiddenDims, boolean useMSLSTM) {
        this.sd = SameDiff.create();
        this.featuresCount = inputDim;
        this.rnnNeurons = hiddenDims;
        this.useMSLSTM = useMSLSTM;
    }

    public SDVariable buildNetwork(SDVariable input) {
        Nd4j.getRandom().setSeed(System.currentTimeMillis());
        SDVariable currentHiddenState = null;
        int currentInputDim = featuresCount;

        xLstmLayers = new ArrayList<>(rnnNeurons.length);
        for (int i = 0; i < rnnNeurons.length; i++) {
            int hiddenDim = rnnNeurons[i];
            SLSTMCell cell = new SLSTMCell(sd, i, currentInputDim, hiddenDim, batchSize);
            xLstmLayers.add(cell);
            currentInputDim = hiddenDim;
        }
        SDVariable x = input;
        long timeStep = input.getShape()[2]; // Assuming input shape is [batchSize, inputDim, timeStep]

        for (int t = 0; t < timeStep; t++) {
            SDVariable x_t = input.get(SDIndex.all(), SDIndex.all(), SDIndex.point(t));

            for (int i = 0; i < rnnNeurons.length; i++) {
                SLSTMCell cell = xLstmLayers.get(i);
                currentHiddenState = cell.forward(x_t);
                // Add residual connection
                if (i > 0) {
                    //if rnnNeurons[i]!=rnnNeurons[i-1],add denseLayer
                    x_t = sd.math.add(currentHiddenState, x_t);
                } else {
                    // The input for the next layer is the output of the current layer
                    x_t = currentHiddenState;
                }
            }
        }

        // Add dense layer
        SDVariable output = dense(currentHiddenState, currentInputDim, 1);

        return output;
    }

    private SDVariable dense(SDVariable input, int inputSize, int outputSize) {
        // Xavier初始化
        INDArray xavierInit = Nd4j.randn(new long[]{inputSize, outputSize}).muli(Math.sqrt(2.0 / (inputSize + outputSize)));
        SDVariable w = sd.var("w", xavierInit);
        SDVariable b = sd.var("b", new XavierInitScheme('f', 1, outputSize));
        SDVariable output = sd.mmul(input, w).add("output", b);
        return output;
    }

    public void train(DataSetIterator trainIter, DataSetIterator validateDataSetIterator, String modelSaveFileName) {
        modelFileNamesFilePath = Paths.get(modelSaveFileName + LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyyMMddHHmmss")) + ".txt");

        // Define placeholders for input and labels
        SDVariable input = sd.placeHolder("input", DataType.FLOAT, -1, featuresCount, timeStep);  //-1表示不确定
        SDVariable labels = sd.placeHolder("labels", DataType.FLOAT, -1, 1);

        // Build the network
        SDVariable predictions = buildNetwork(input);
        System.out.println("TotalNumberOfParameters = " + getTotalNumberOfParameters());

        // Define loss function (mean squared error)
        SDVariable loss = sd.loss.meanSquaredError("mse_loss", labels, predictions, null);
        sd.setLossVariables(loss);

        System.out.println(sd.summary());
        // Define the optimizer
        DataType dataType = DataType.FLOAT;
        sd.setTrainingConfig(TrainingConfig.builder()
                .updater(new RmsProp(lrSchedule))
                .l2(l2)
                //                .weightDecay(1e-5, false)
                .dataSetFeatureMapping("input")
                .dataSetLabelMapping("labels")
                .build());

        UIServer uiServer = null;
        if (uiServerRun) {
            uiServer = uiMonitor();
        }

        // Wrap model with MultiGpuTrainingWrapper for multi-GPU training
        ParallelWrapper mutilGPUWrapper = null;
//        if (mutilGPU) {
//            mutilGPUWrapper = new ParallelWrapper.Builder(new MultiLayerNetwork(sd))
//                    .prefetchBuffer(prefetchBufferMutilGPU)
//                    .workers(workersMutilGPU)
//                    .averagingFrequency(avgFrequencyMutilGPU)
//                    .reportScoreAfterAveraging(true)
//                    .trainingMode(ParallelWrapper.TrainingMode.SHARED_GRADIENTS)
//                    .build();
//        }

        // Training loop
        for (int epoch = 0; epoch < nEpochs; epoch++) {
            DataSet batch = null;
            while (trainIter.hasNext()) {
                batch = trainIter.next();
                resetHiddenStates(batchSize);
//                if (mutilGPU) {
//                    mutilGPUWrapper.fit(batch);
//                } else {
                sd.fit(batch);
//                }
            }

            // Calculate and print the loss
            INDArray batchLoss = sd.output(batch, "mse_loss").get("mse_loss");
            System.out.println("Epoch: " + epoch + ", Loss: " + batchLoss);
            trainIter.reset();            
        }
        try {
            if (uiServer != null) {
                uiServer.stop();
            }
        } catch (InterruptedException ex) {
            Logger.getLogger(XLstmRegPredictor.class.getName()).log(Level.SEVERE, null, ex);
        }
    }

SameDiff is a very powerful and useful library, but unfortunately, there are not many learning resources available.

@cqiaoYc yes I’m aware of the samediff ui. That will improve substantially now that the refactoring is about done (see the PRs currently on the branch)

In terms of the UI, to be honest it needs some work.

It’s why I recommended just wrapping it in a multilayernetwork samediff layer instead since that UI is more mature.
Could you clarify why that’s not working for you?

I don’t know why it’s not working. If the log file recorded by the UIListener class is uploaded on the page, the following graph can be seen. It converges to the same value, which is likely caused by gradient explosion.