A good method of preventing overfitting using ParallelWrapper (multi-GPU)

@agibsonccc If "mutilGPUWrapper = new ParallelWrapper.Builder(net)...build()" is placed outside of "for (int i = 0; i < nEpochs; i++) {}", the UIServer works fine, but an exception is thrown: https://community.konduit.ai/t/how-to-customize-a-dataset-iterator-that-supports-multiple-gpus/3163
Another question: in a single-GPU environment, using ParallelWrapper nearly doubles training speed, but it does not provide the anti-overfitting effect it has in a multi-GPU environment. Can a similar effect be obtained on a single GPU by configuring multi-threaded training?
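
To make the two placements concrete, here is a minimal sketch of the variant that eventually throws for me: the wrapper is built once outside the epoch loop (the full code below rebuilds it inside the loop instead). The shutdown() call at the end is my assumption about how the wrapper should be released, not something the exception trace confirms:

ParallelWrapper wrapper = new ParallelWrapper.Builder(net)
        .prefetchBuffer(prefetchBufferMutilGPU)
        .workers(workersMutilGPU)
        .averagingFrequency(avgFrequencyMutilGPU)
        .reportScoreAfterAveraging(true)
        .build();
for (int i = 0; i < nEpochs; i++) {
    trainIterator.reset();
    wrapper.fit(trainIterator); // throws after an uncertain number of epochs
}
wrapper.shutdown(); // release the worker threads once training is done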

// ---------- CnnLstmPredictModelTestUIServer.java ----------
package com.cq.aifocusstocks.train;

import java.nio.charset.Charset;
import java.nio.file.Path;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.deeplearning4j.core.storage.StatsStorage;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;

import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;

import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Convolution1D;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.Subsampling1DLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.parallelism.ParallelWrapper;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.model.stats.StatsListener;
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.lossfunctions.impl.LossMSE;
import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.linalg.schedule.ScheduleType;
import org.nd4j.linalg.schedule.StepSchedule;

/**
 *
 * @author cqiao
 */
public class CnnLstmPredictModelTestUIServer {

    protected int featuresCount = 24;
    protected int timeStep = 30;
    protected int nEpochs = 100;
    protected int startTrainResultReportEpoch = 3;  // start reporting training results from this epoch
    protected int trainResultReportStep = 2;
    protected int batchSize = 64;
    private int samplesTotal = 100000;
    protected double l1 = 0;
    protected double l2 = 0.0001;
    protected float dropOut = 0.5f;

    protected ISchedule rnnLrSchedule;
    protected ISchedule outLrSchedule;
    protected Path modelFileNamesFilePath;
    protected final Charset CHARSET = Charset.forName("UTF-8");

    protected boolean mutilGPU = true;
    protected int prefetchBufferMutilGPU = 24;
    protected int workersMutilGPU = 4;
    protected int avgFrequencyMutilGPU = 2;
    protected float gradientNormalizationThreshold = 1;  // default
    protected float rnnGradientNormalizationThreshold = 0.5f;
    protected boolean hasPoolingLayer = false;
    protected ISchedule cnnLrSchedule;
    protected int[] cnnStrides = {1, 1, 1, 1}; // strides for each CNN layer
    protected int[] cnnNeurons = {32, 64}; // number of neurons in each CNN layer
    protected int[] rnnNeurons = {64, 32}; // number of neurons in each RNN layer
    int[] cnnKernelSizes = {3, 3, 3, 3}; // kernel sizes for each CNN layer

    public MultiLayerConfiguration getNetConf() {
        double startLR = 0.001;
        double endLR = 0.00001;
        long iterationsTotal = samplesTotal / batchSize * nEpochs;
        long step = 100;
        double decayRate = computeDecayRate(startLR, endLR, iterationsTotal, step);
        cnnLrSchedule = new StepSchedule(ScheduleType.ITERATION, startLR, decayRate, step); // decay by decayRate every step iterations
        rnnLrSchedule = cnnLrSchedule;
        outLrSchedule = cnnLrSchedule;

        DataType dataType = DataType.FLOAT;
        NeuralNetConfiguration.Builder nncBuilder = new NeuralNetConfiguration.Builder()
                .seed(System.currentTimeMillis())
                .weightInit(WeightInit.XAVIER)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                //                .updater(new RmsProp(rnnLrSchedule))
                .gradientNormalization(GradientNormalization.ClipL2PerLayer)
                .gradientNormalizationThreshold(gradientNormalizationThreshold)
                .dataType(dataType);

        nncBuilder.l1(l1);
        nncBuilder.l2(l2);

        NeuralNetConfiguration.ListBuilder listBuilder = nncBuilder.list();
        int nIn = featuresCount;
        int layerIndex = 0;

        listBuilder.setInputType(InputType.recurrent(nIn));

        // Add CNN layers
        if (cnnNeurons != null) {
            final int cnnLayerCount = cnnNeurons.length;
            final Adam adam = new Adam(cnnLrSchedule);
            for (int i = 0; i < cnnLayerCount; i++) {

                listBuilder.layer(layerIndex, new Convolution1D.Builder()
                        .dropOut(dropOut)
                        .kernelSize(cnnKernelSizes[i])
                        .stride(cnnStrides[i])
                        .convolutionMode(ConvolutionMode.Same)
                        //                    .padding(cnnPadding)   
                        .updater(adam)
                        .nIn(nIn)
                        .nOut(cnnNeurons[i])
                        .activation(Activation.TANH)
                        .build());

                nIn = cnnNeurons[i];
                ++layerIndex;
                if (hasPoolingLayer) {
                    listBuilder.layer(layerIndex, new Subsampling1DLayer.Builder()
                            .kernelSize(cnnKernelSizes[i])
                            .stride(cnnStrides[i])
                            .convolutionMode(ConvolutionMode.Same)
                            .poolingType(SubsamplingLayer.PoolingType.MAX)
                            .build());

                    ++layerIndex;
                }

//            listBuilder.layer(layerIndex, new BatchNormalization.Builder().nOut(nIn).build());//an exception is thrown
//            ++layerIndex;
            }
        }

        // Add RNN layers
        final RmsProp rmsProp = new RmsProp(rnnLrSchedule);
        for (int i = 0; i < this.rnnNeurons.length; ++i) {
            listBuilder.layer(layerIndex, new LSTM.Builder()
                    .dropOut(dropOut)
                    .activation(Activation.TANH)
                    .updater(rmsProp)
                    .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                    .gradientNormalizationThreshold(rnnGradientNormalizationThreshold)
                    .nIn(nIn)
                    .nOut(rnnNeurons[i])
                    .build());

            nIn = rnnNeurons[i];
            ++layerIndex;

        }
//            listBuilder.layer(layerIndex, new BatchNormalization.Builder().nOut(nIn).build());//an exception is thrown
//            ++layerIndex;

        listBuilder.layer(layerIndex,
                new RnnOutputLayer.Builder(new LossMSE()).updater(new RmsProp(outLrSchedule))
                        .activation(Activation.IDENTITY).nIn(nIn).nOut(1).dataFormat(RNNFormat.NCW).build());
//        listBuilder.setInputType(InputType.recurrent(featuresCount)); 
        MultiLayerConfiguration conf = listBuilder.build();
        return conf;
    }

    private double computeDecayRate(double startLr, double endLr, long iterationsTotal, long step) {
        return Math.pow(endLr / startLr, (double) step / iterationsTotal);
    }
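
    // Quick sanity check of the decay math with the defaults above (my numbers, for
    // illustration): iterationsTotal = 100000 / 64 * 100 = 156200 and step = 100, so
    // decayRate = (1e-5 / 1e-3)^(100 / 156200) ≈ 0.99706, i.e. the learning rate
    // shrinks by about 0.3% every 100 iterations and ends near 1e-5 after the last epoch.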

    public void trainModel() {
        System.out.println("start train: " + LocalDateTime.now());
        TimeSeriesListDataSetIterator trainIterator = generateIterator();
        MultiLayerNetwork net = new MultiLayerNetwork(getNetConf());
        net.init();

        UIServer uiServer = uiMonitor(net);
//        modelFileNamesFilePath = Paths.get(modelSaveFileName + LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyyMMddHHmmss")) + ".txt");
        ParallelWrapper mutilGPUWrapper = null;
//        if (mutilGPU) {
//            mutilGPUWrapper = new ParallelWrapper.Builder(net)
//                    .prefetchBuffer(prefetchBufferMutilGPU)
//                    .workers(workersMutilGPU)
//                    .averagingFrequency(avgFrequencyMutilGPU)
//                    .reportScoreAfterAveraging(true)
//                    .build();
//        }        

        for (int i = 0; i < nEpochs; i++) {
            trainIterator.reset();
            if (mutilGPU) {
                // If this statement is placed outside the loop, an exception is thrown
                // after it has executed for several epochs; how many epochs complete
                // before it fails is uncertain. My iterator is custom-built and I don't
                // know whether it is the cause.
                mutilGPUWrapper = new ParallelWrapper.Builder(net)
                        .prefetchBuffer(prefetchBufferMutilGPU)
                        .workers(workersMutilGPU)
                        .averagingFrequency(avgFrequencyMutilGPU)
                        .reportScoreAfterAveraging(true)
                        .build();
                mutilGPUWrapper.fit(trainIterator);

            } else {
                net.fit(trainIterator);
            }
            System.out.println("==No." + i + " nEpochs, " //
                    + LocalTime.now() + ", model.score=" + net.score());
//            if ((i == startTrainResultReportEpoch || (i > startTrainResultReportEpoch && (i - startTrainResultReportEpoch) % trainResultReportStep == 0)) && i != nEpochs - 1) {
//                if (!mutilGPU) {
//                    RegressionEvaluation eval = new RegressionEvaluation();
//                    test(eval, model, validateDataSetIterator);
//                }
//
//                String modelId = getNeuronsStr() + "-" + LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyyMMddHHmmss"));
//                this.saveModel(model, modelSaveFileName, modelId);
//            }
        }
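
        // Note: none of the per-epoch wrappers built above is ever shut down. If your
        // ParallelWrapper version exposes shutdown(), calling it at the end of each
        // epoch might avoid accumulating worker threads; this is an assumption on my
        // part, not something I have verified.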

        try {
            if (uiServer != null) {
                uiServer.stop();
            }
        } catch (InterruptedException ex) {
            Logger.getLogger(CnnLstmPredictModelTestUIServer.class.getName()).log(Level.SEVERE, null, ex);
        }

    }

    private TimeSeriesListDataSetIterator generateIterator() {
        List<DataSet> dataSetList = new ArrayList<>();
        int dataSetCount = samplesTotal / batchSize;
        System.out.println("the iterator count of each Epoch: "+dataSetCount);
        for (int i = 0; i < dataSetCount; ++i) {
            INDArray features3D = Nd4j.randn(new int[]{batchSize, featuresCount, timeStep}).muli(2).subi(1);
            INDArray labels3D = Nd4j.randn(new int[]{batchSize, 1,timeStep}).muli(2).subi(1);
             dataSetList.add(new DataSet(features3D,labels3D));
        }
        return new TimeSeriesListDataSetIterator(dataSetList, true);
    }

    public UIServer uiMonitor(MultiLayerNetwork model) {
        // Monitor the network's learning process.
        // Initialize the user interface backend.
        System.setProperty("org.deeplearning4j.ui.port", "9001");
        UIServer uiServer = UIServer.getInstance();
        StringBuilder sb = new StringBuilder("http://localhost:").append(UIServer.getInstance().getPort()).append("/");
        System.out.println("UIServer url:" + sb.toString());

        // Configure where network information (gradients, scores, etc. over time) is stored.
        // Here it is kept in memory.
        StatsStorage statsStorage = new InMemoryStatsStorage(); // alternative: new FileStatsStorage(File), for later saving and loading

        // Attach the StatsStorage instance to the UI so that its contents can be visualized.
        uiServer.attach(statsStorage);

        // Then add a StatsListener to collect this information while the network trains.
        model.setListeners(new StatsListener(statsStorage));
        return uiServer;
    }

    public void setCnnStrides(int[] cnnStrides) {
        this.cnnStrides = cnnStrides;
    }

    public int[] getCnnNeurons() {
        return cnnNeurons;
    }

    public void setCnnNeurons(int[] cnnNeurons) {
        this.cnnNeurons = cnnNeurons;
    }

    public int[] getCnnKernelSizes() {
        return cnnKernelSizes;
    }

    public void setCnnKernelSizes(int[] cnnKernelSizes) {
        this.cnnKernelSizes = cnnKernelSizes;
    }

    public ISchedule getCnnLrSchedule() {
        return cnnLrSchedule;
    }

    public void setCnnLrSchedule(ISchedule cnnLrSchedule) {
        this.cnnLrSchedule = cnnLrSchedule;
    }

    public float getGradientNormalizationThreshold() {
        return gradientNormalizationThreshold;
    }

    public void setGradientNormalizationThreshold(float gradientNormalizationThreshold) {
        this.gradientNormalizationThreshold = gradientNormalizationThreshold;
    }

    public float getRnnGradientNormalizationThreshold() {
        return rnnGradientNormalizationThreshold;
    }

    public void setRnnGradientNormalizationThreshold(float rnnGradientNormalizationThreshold) {
        this.rnnGradientNormalizationThreshold = rnnGradientNormalizationThreshold;
    }

    public boolean isHasPoolingLayer() {
        return hasPoolingLayer;
    }

    public void setHasPoolingLayer(boolean hasPoolingLayer) {
        this.hasPoolingLayer = hasPoolingLayer;
    }
    
    public static void main(String[] args) {
        CnnLstmPredictModelTestUIServer testUI = new CnnLstmPredictModelTestUIServer();
        testUI.trainModel();
    }

}

// ---------- TimeSeriesListDataSetIterator.java ----------
package com.cq.aifocusstocks.train;

import java.util.List;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;

/**
 *
 * @author cqiao
 */
public class TimeSeriesListDataSetIterator implements DataSetIterator {

    private List<DataSet> dataSetList;

    protected int inputColumns;
    protected int outputColumns;
    private int totalSamples;
    private int totalBatch;
    private int batchCursor;
    protected int batchSize;
    protected int sampleStep;
    private INDArray batchsizeLabelMask;
    private boolean needLabelMask = true;

//    public TimeSeriesListDataSetIterator(List<DataSet> dataSetList) {
//        this(dataSetList, true);
//    }
   
    /**
     *
     * @param dataSetList each DataSet holds the same number of samples, as 3D arrays
     *                    in NCW format: [batchSize, channels, timeSteps]
     * @param needLabelMask
     */
    public TimeSeriesListDataSetIterator(List<DataSet> dataSetList, boolean needLabelMask) {
        this.needLabelMask = needLabelMask;
        this.dataSetList = dataSetList;
        totalBatch = dataSetList.size();

        long[] featuresShape = dataSetList.get(0).getFeatures().shape();
        long[] labelsShape = dataSetList.get(0).getLabels().shape();
        this.batchSize = (int) labelsShape[0];
        this.inputColumns = (int) featuresShape[1];
        this.outputColumns = (int) labelsShape[1];
        this.sampleStep = (int) featuresShape[2];
        totalSamples = totalBatch * batchSize;
        if (needLabelMask) {
            if (!dataSetList.get(0).hasMaskArrays()) {
                batchsizeLabelMask = generateLabelsMask(batchSize);
                for (DataSet dataSet : this.dataSetList) {
                    dataSet.setLabelsMaskArray(batchsizeLabelMask);
                }
            }
        }
    }

    @Override
    public synchronized DataSet next(int num) {
        // num is ignored: each stored DataSet is already a full pre-built batch.
        // The cursor wraps around if next() is called again after the last batch.
        if (batchCursor == totalBatch) {
            batchCursor = 0;
        }

        DataSet dataSet = this.dataSetList.get(batchCursor);
        ++batchCursor;
        return dataSet;
    }

    private INDArray generateLabelsMask(int batchSize) {
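        // The mask is zero everywhere except the last time step of each sequence, so
        // only the final prediction of each sequence contributes to the loss and score.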
        INDArray mask = Nd4j.create(new int[]{batchSize, sampleStep}, 'f');
        for (int j = 0; j < batchSize; ++j) {
            mask.putScalar(j, sampleStep - 1, 1);
        }

        return mask;
    }

    @Override
    public boolean resetSupported() {
        return true;
    }

    @Override
    public boolean asyncSupported() {
        return true;
    }

    @Override
    public synchronized void reset() {
        batchCursor = 0;
    }

    @Override
    public synchronized boolean hasNext() {
        return this.batchCursor < this.totalBatch;
    }

    @Override
    public DataSet next() {
        return next(0);
    }

    public int getTotalExamples() {
        return totalSamples;
    }

    @Override
    public int inputColumns() {
        return this.inputColumns;
    }

    @Override
    public int totalOutcomes() {
        return this.outputColumns;
    }

    @Override
    public int batch() {
        return batchSize;
    }

    public boolean isNeedLabelMask() {
        return needLabelMask;
    }

    public void setNeedLabelMask(boolean needLabelMask) {
        this.needLabelMask = needLabelMask;
    }
    
    

    @Override
    public void setPreProcessor(DataSetPreProcessor dspp) {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    @Override
    public DataSetPreProcessor getPreProcessor() {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    @Override
    public List<String> getLabels() {
        throw new UnsupportedOperationException("Not supported yet.");
    }
}