@agibsonccc I implemented an sLSTM (xLSTM) model with residual connections using SameDiff. However, the loss always converges to the same value, and the graph page in the UIServer does not display any content. Can you help me identify the issue in my code? The official documentation has very few complex SameDiff examples. Could you simplify my case into an example for beginners to learn from? Thank you!
How do I upload the code file?
public class SLSTMCell {
private SameDiff sd;
private SDVariable w_i, w_f, w_o, w_z;
private SDVariable r_i, r_f, r_o, r_z;
private SDVariable b_i, b_f, b_o, b_z;
private SDVariable hiddenState, cellState, nState, mState;
private int inputDim, hiddenDim;
private long weightsCount, biasCount;
private final float zeroValue = 0f;
public SLSTMCell(SameDiff sd, int layerId, int inputDim, int hiddenDim, int batchSize) {
this.sd = sd;
this.inputDim = inputDim;
this.hiddenDim = hiddenDim;
initializeParameters(layerId);
initializeStates(layerId, batchSize);
}
private void initializeParameters(int layerId) {
Nd4j.getRandom().setSeed(System.currentTimeMillis());
long[] wShape = new long[]{inputDim, hiddenDim};
double wSqrtScale = Math.sqrt(2.0 / (inputDim + hiddenDim));
INDArray xavierInit = Nd4j.randn(wShape).muli(wSqrtScale);
w_i = sd.var(layerId + "w_i", xavierInit); // inputDim x hiddenDim
xavierInit = Nd4j.randn(wShape).muli(wSqrtScale);
w_f = sd.var(layerId + "w_f", xavierInit);
xavierInit = Nd4j.randn(wShape).muli(wSqrtScale);
w_o = sd.var(layerId + "w_o", xavierInit);
xavierInit = Nd4j.randn(wShape).muli(wSqrtScale);
w_z = sd.var(layerId + "w_z", xavierInit);
long[] rShape = new long[]{hiddenDim, hiddenDim};
double rSqrtScale = Math.sqrt(2.0 / (hiddenDim + hiddenDim));
xavierInit = Nd4j.randn(rShape).muli(rSqrtScale);
r_i = sd.var(layerId + "r_i", xavierInit); // hiddenDim x hiddenDim
xavierInit = Nd4j.randn(rShape).muli(rSqrtScale);
r_f = sd.var(layerId + "r_f", xavierInit);
xavierInit = Nd4j.randn(rShape).muli(rSqrtScale);
r_o = sd.var(layerId + "r_o", xavierInit);
xavierInit = Nd4j.randn(rShape).muli(rSqrtScale);
r_z = sd.var(layerId + "r_z", xavierInit);
weightsCount = 4L * inputDim * hiddenDim + 4L * hiddenDim * hiddenDim; // 4 input-weight + 4 recurrent-weight matrices
b_i = sd.var(layerId + "b_i", Nd4j.randn(1, hiddenDim).mul(0.01));
b_f = sd.var(layerId + "b_f", Nd4j.randn(1, hiddenDim).mul(0.01));
b_o = sd.var(layerId + "b_o", Nd4j.randn(1, hiddenDim).mul(0.01));
b_z = sd.var(layerId + "b_z", Nd4j.randn(1, hiddenDim).mul(0.01));
biasCount = hiddenDim * 4;
}
private void initializeStates(int layerId, int batchSize) {
hiddenState = sd.constant(layerId + "hs", Nd4j.zeros(batchSize, hiddenDim));
cellState = sd.constant(layerId + "cs", Nd4j.zeros(batchSize, hiddenDim));
nState = sd.constant(layerId + "ns", Nd4j.zeros(batchSize, hiddenDim));
mState = sd.constant(layerId + "ms", Nd4j.zeros(batchSize, hiddenDim));
}
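// The recurrence implemented in forward(), in xLSTM (sLSTM) notation:
//   m_t  = max(log f_t + m_{t-1}, log i_t)       stabilizer state
//   i'_t = exp(log i_t - m_t)                    stabilized input gate
//   f'_t = exp(log f_t + m_{t-1} - m_t)          stabilized forget gate
//   c_t  = f'_t * c_{t-1} + i'_t * tanh(z_t)     cell state
//   n_t  = f'_t * n_{t-1} + i'_t                 normalizer state
//   h_t  = sigmoid(o_t) * tanh(c_t / n_t)        hidden state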
public SDVariable forward(SDVariable input) {
// Gate pre-activations: W x_t + R h_{t-1} + b for each of the four gates
SDVariable i_tilda = input.mmul(w_i).add(hiddenState.mmul(r_i)).add(b_i);
SDVariable f_tilda = input.mmul(w_f).add(hiddenState.mmul(r_f)).add(b_f);
SDVariable o_tilda = input.mmul(w_o).add(hiddenState.mmul(r_o)).add(b_o);
SDVariable z_tilda = input.mmul(w_z).add(hiddenState.mmul(r_z)).add(b_z);
SDVariable i_t = sd.math.exp(i_tilda); // exponential input gate
SDVariable f_t = sd.nn.sigmoid(f_tilda); // sigmoid forget gate
SDVariable logI = sd.math.log(i_t);
SDVariable logF = sd.math.log(f_t);
// Stabilizer: read the previous m before overwriting it, so that
// f' = exp(log f_t + m_{t-1} - m_t) rather than a self-cancelling term
SDVariable newMState = sd.math.max(logF.add(mState), logI);
SDVariable i_prime = sd.math.exp(logI.sub(newMState));
SDVariable f_prime = sd.math.exp(logF.add(mState).sub(newMState));
sd.assign(mState, newMState);
sd.assign(cellState, f_prime.mul(cellState).add(i_prime.mul(sd.nn.tanh(z_tilda))));
sd.assign(nState, f_prime.mul(nState).add(i_prime));
SDVariable c_hat = cellState.div(nState);
sd.assign(hiddenState, sd.nn.sigmoid(o_tilda).mul(sd.nn.tanh(c_hat)));
return hiddenState;
}
public void resetStates(int batchSize) {
// Reset all recurrent states to zero
hiddenState.assign(zeroValue);
cellState.assign(zeroValue);
nState.assign(zeroValue);
mState.assign(zeroValue);
}
public long getWeightsCount() {
return weightsCount;
}
public void setWeightsCount(long weightsCount) {
this.weightsCount = weightsCount;
}
public long getBiasCount() {
return biasCount;
}
public void setBiasCount(long biasCount) {
this.biasCount = biasCount;
}
}
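For context, this is roughly how I exercise a single cell in isolation (a minimal sketch, assuming the usual ND4J imports; the shapes and the "x" placeholder name are just for illustration):
SameDiff sd = SameDiff.create();
SDVariable x = sd.placeHolder("x", DataType.FLOAT, 32, 16); // [batchSize, inputDim]
SLSTMCell cell = new SLSTMCell(sd, 0, 16, 64, 32);
SDVariable h = cell.forward(x); // [batchSize, hiddenDim]
Map<String, INDArray> out = sd.output(
        Collections.singletonMap("x", Nd4j.randn(DataType.FLOAT, 32, 16)),
        h.name());
System.out.println(out.get(h.name()).shapeInfoToString());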
//-----------------------------------The core part of the XLstmResidualNetwork class-------------------------
public XLstmResidualNetwork(int inputDim, int[] hiddenDims, boolean useMSLSTM) {
this.sd = SameDiff.create();
this.featuresCount = inputDim;
this.rnnNeurons = hiddenDims;
this.useMSLSTM = useMSLSTM;
}
public SDVariable buildNetwork(SDVariable input) {
Nd4j.getRandom().setSeed(System.currentTimeMillis());
SDVariable currentHiddenState = null;
int currentInputDim = featuresCount;
xLstmLayers = new ArrayList<>(rnnNeurons.length);
for (int i = 0; i < rnnNeurons.length; i++) {
int hiddenDim = rnnNeurons[i];
SLSTMCell cell = new SLSTMCell(sd, i, currentInputDim, hiddenDim, batchSize);
xLstmLayers.add(cell);
currentInputDim = hiddenDim;
}
SDVariable x = input;
long timeStep = input.getShape()[2]; // Assuming input shape is [batchSize, inputDim, timeStep]
for (int t = 0; t < timeStep; t++) {
SDVariable x_t = input.get(SDIndex.all(), SDIndex.all(), SDIndex.point(t));
for (int i = 0; i < rnnNeurons.length; i++) {
SLSTMCell cell = xLstmLayers.get(i);
currentHiddenState = cell.forward(x_t);
// Add residual connection
if (i > 0) {
// TODO: if rnnNeurons[i] != rnnNeurons[i-1], a projection (dense) layer is needed before adding
x_t = currentHiddenState.add(x_t);
} else {
// The input for the next layer is the output of the current layer
x_t = currentHiddenState;
}
}
}
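// Note: only the hidden state from the final time step reaches the dense layer below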
// Add dense layer
SDVariable output = dense(currentHiddenState, currentInputDim, 1);
return output;
}
private SDVariable dense(SDVariable input, int inputSize, int outputSize) {
// Xavier initialization
INDArray xavierInit = Nd4j.randn(new long[]{inputSize, outputSize}).muli(Math.sqrt(2.0 / (inputSize + outputSize)));
SDVariable w = sd.var("w", xavierInit);
SDVariable b = sd.var("b", new XavierInitScheme('f', 1, outputSize));
SDVariable output = sd.mmul(input, w).add("output", b);
return output;
}
public void train(DataSetIterator trainIter, DataSetIterator validateDataSetIterator, String modelSaveFileName) {
modelFileNamesFilePath = Paths.get(modelSaveFileName + LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyyMMddHHmmss")) + ".txt");
// Define placeholders for input and labels
SDVariable input = sd.placeHolder("input", DataType.FLOAT, -1, featuresCount, timeStep); // -1 = variable minibatch size
SDVariable labels = sd.placeHolder("labels", DataType.FLOAT, -1, 1);
// Build the network
SDVariable predictions = buildNetwork(input);
System.out.println("TotalNumberOfParameters = " + getTotalNumberOfParameters());
// Define loss function (mean squared error)
SDVariable loss = sd.loss.meanSquaredError("mse_loss", labels, predictions, null);
sd.setLossVariables(loss);
System.out.println(sd.summary());
// Define the optimizer and training configuration
sd.setTrainingConfig(TrainingConfig.builder()
.updater(new RmsProp(lrSchedule))
.l2(l2)
// .weightDecay(1e-5, false)
.dataSetFeatureMapping("input")
.dataSetLabelMapping("labels")
.build());
UIServer uiServer = null;
if (uiServerRun) {
uiServer = uiMonitor();
}
// Optionally wrap the model with ParallelWrapper for multi-GPU training (currently disabled)
ParallelWrapper mutilGPUWrapper = null;
// if (mutilGPU) {
// mutilGPUWrapper = new ParallelWrapper.Builder(new MultiLayerNetwork(sd))
// .prefetchBuffer(prefetchBufferMutilGPU)
// .workers(workersMutilGPU)
// .averagingFrequency(avgFrequencyMutilGPU)
// .reportScoreAfterAveraging(true)
// .trainingMode(ParallelWrapper.TrainingMode.SHARED_GRADIENTS)
// .build();
// }
// Training loop
for (int epoch = 0; epoch < nEpochs; epoch++) {
DataSet batch = null;
while (trainIter.hasNext()) {
batch = trainIter.next();
resetHiddenStates(batchSize);
// if (mutilGPU) {
// mutilGPUWrapper.fit(batch);
// } else {
sd.fit(batch);
// }
}
// Calculate and print the loss
INDArray batchLoss = sd.output(batch, "mse_loss").get("mse_loss");
System.out.println("Epoch: " + epoch + ", Loss: " + batchLoss);
trainIter.reset();
}
try {
if (uiServer != null) {
uiServer.stop();
}
} catch (InterruptedException ex) {
Logger.getLogger(XLstmRegPredictor.class.getName()).log(Level.SEVERE, null, ex);
}
}
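For completeness, this is roughly how I drive the whole thing (simplified; buildTrainIterator and buildValidateIterator stand in for my actual data-loading code, and the dimensions are just examples):
int[] hiddenDims = {64, 64, 64};
XLstmResidualNetwork net = new XLstmResidualNetwork(16, hiddenDims, false);
DataSetIterator trainIter = buildTrainIterator();       // placeholder for my data pipeline
DataSetIterator validateIter = buildValidateIterator(); // placeholder for my data pipeline
net.train(trainIter, validateIter, "xlstm-model-");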