BERT fails to train on NER task

I imported the BERT model and added a softmax layer on top to perform the NER task. But when I train the model, I get a NullPointerException.

The following are my code and the necessary files:

vocab.txt: vocab.txt - Google Drive
bert model:output-bert.zip - Google Drive
train-data: segMaster.corpus - Google Drive
pom:pom.xml - Google Drive

error log:

Exception in thread "main" java.lang.NullPointerException
	at org.nd4j.autodiff.samediff.internal.AbstractSession.getExecStepForVar(AbstractSession.java:680)
	at org.nd4j.autodiff.samediff.internal.AbstractSession.addDependenciesForOp(AbstractSession.java:645)
	at org.nd4j.autodiff.samediff.internal.AbstractSession.updateDescendantDeps(AbstractSession.java:581)
	at org.nd4j.autodiff.samediff.internal.AbstractSession.output(AbstractSession.java:468)
	at org.nd4j.autodiff.samediff.internal.TrainingSession.trainingIteration(TrainingSession.java:107)
	at org.nd4j.autodiff.samediff.SameDiff.fitHelper(SameDiff.java:1735)
	at org.nd4j.autodiff.samediff.SameDiff.fit(SameDiff.java:1591)
	at org.nd4j.autodiff.samediff.SameDiff.fit(SameDiff.java:1474)
	at com.wt.utils.TestMaster.train(TestMaster.java:126)
	at com.wt.utils.TestMaster.main(TestMaster.java:104)

code:

package com.wt.utils;

import org.apache.commons.io.FileUtils;
import org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory;
import org.nd4j.autodiff.listeners.impl.ScoreListener;
import org.nd4j.autodiff.samediff.NameScope;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;

import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.weightinit.impl.UniformInitScheme;
import org.nd4j.weightinit.impl.XavierInitScheme;

import java.io.File;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.*;

public class TestMaster {
    static Integer numOutput = 13;   // number of NER labels: "O" plus B/I tags for 6 entity types
    static HashMap<Integer,String> idToLabel=new HashMap<Integer, String>(){{
        put(0,"O");
        put(1,"B-malware");
        put(2,"I-malware");
        put(3,"B-identity");
        put(4,"I-identity");
        put(5,"B-location");
        put(6,"I-location");
        put(7,"B-software");
        put(8,"I-software");
        put(9,"B-threatActor");
        put(10,"I-threatActor");
        put(11,"B-vulnerability_cve");
        put(12,"I-vulnerability_cve");
    }};
    static HashMap<String,Integer> labelMap=new HashMap<String, Integer>(){{
        put("O",0);
        put("B-malware",1);
        put("I-malware",2);
        put("B-identity",3);
        put("I-identity",4);
        put("B-location",5);
        put("I-location",6);
        put("B-software",7);
        put("I-software",8);
        put("B-threatActor",9);
        put("I-threatActor",10);
        put("B-vulnerability_cve",11);
        put("I-vulnerability_cve",12);
    }};

    // BERT WordPiece vocabulary file and tokenizer factory
    static File wordPieceTokens = new File("D:\\deeplearning4j-examples\\mydownload\\output-bert\\vocab.txt");
    static BertWordPieceTokenizerFactory t;

    static {
        try {
            t = new BertWordPieceTokenizerFactory(wordPieceTokens, true, true, StandardCharsets.UTF_8);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    static Map<String,Integer> vocab = t.getVocab();
    static List<MultiDataSet> iter;
    static Map<Integer,String> idToChar = new HashMap<>();   // reverse vocabulary: token id -> token

    static {
        try {
            iter = getDataIter("D:\\deeplearning4j-examples\\mydownload\\dl4j-learn\\src\\main\\resources\\segMaster.corpus",vocab);
        } catch (IOException e) {
            e.printStackTrace();
        }

        for(Map.Entry<String,Integer> e:vocab.entrySet())
        {
            idToChar.put(e.getValue(),e.getKey());
        }
    }

    public TestMaster() throws IOException {

    }

    public static void main(String[] args) throws IOException {
        train();
    }
    public static void train() throws IOException {
        SameDiff sd = getBertModel();
        TrainingConfig c = TrainingConfig.builder()
            .updater(new Adam(0.01))
            .l2(1e-5)
            .dataSetFeatureMapping("tokenIdxs", "mask","sentenceIdx")
            .dataSetLabelMapping("label")
            .build();
        sd.setTrainingConfig(c);

        long start = System.currentTimeMillis();
        System.out.println("Start Training...");
        MultiDataSet multiDataset;
        for( int numEpoch = 0; numEpoch < 10; numEpoch++ ){
            for( int i = 0; i < iter.size(); ++i ){
                multiDataset = iter.get(i);
                printBertInput(multiDataset);
                sd.fit(multiDataset);
        }

        long end = System.currentTimeMillis();
        System.out.println("Total Time Cost: " + (end - start) + "ms");
        System.out.println("End Training...");

        File saveFileForInference = new File("mySaveModel.fb");
        sd.asFlatFile(saveFileForInference);
    }



    public static void printBertInput(MultiDataSet multiDataset){
        System.out.println("printBertInput----------------------------");
        System.out.println("tokenIdxs  : "+multiDataset.getFeatures(0).get(NDArrayIndex.interval(0,128)));
        System.out.println("mask  :  "+multiDataset.getFeatures(1).get(NDArrayIndex.interval(0,128)));
        System.out.println("sentenceIdx : "+multiDataset.getFeatures(2).get(NDArrayIndex.interval(0,128)));
        System.out.println("truthLabel  : "+multiDataset.getLabels(0).get(NDArrayIndex.interval(0,128)));
    }

    private static SameDiff getBertModel() throws IOException {
        File f=new File("D:\\deeplearning4j-examples\\mydownload\\output-bert\\output-bert.fb");
        SameDiff sd = SameDiff.fromFlatFile(f);

        // Rename the imported TF input placeholders to match the feature mapping above
        sd.renameVariable("IteratorGetNext", "tokenIdxs");
        sd.renameVariable("IteratorGetNext:1", "mask");
        sd.renameVariable("IteratorGetNext:4", "sentenceIdx");


        SDVariable labels = sd.placeHolder("label", DataType.INT, 1, 128);
        NameScope my_transfer = sd.withNameScope("mine");
        // Output of the last BERT encoder layer
        SDVariable input = sd.getVariable("bert/encoder/layer_11/output/LayerNorm/batchnorm/add_1");

        // Token-classification head: a linear layer plus softmax over the numOutput labels
        SDVariable my_flatten_weights = sd.var("flatten_weights", new XavierInitScheme('c', 768, numOutput), DataType.FLOAT, 768, numOutput);
        SDVariable my_flatten_bias = sd.var("flatten_bias", new UniformInitScheme('c', numOutput), DataType.FLOAT, numOutput);
        SDVariable linear_output = input.mmul(my_flatten_weights).add("linear_output", my_flatten_bias);
        SDVariable softmax_output = sd.nn().softmax("softmax", linear_output);

        SDVariable loss = sd.loss().sparseSoftmaxCrossEntropy("Loss", softmax_output, labels);
        my_transfer.close();

        sd.setLossVariables(loss);
        sd.addListeners(new ScoreListener(1));

        System.out.println(sd.summary());
        return sd;
    }

    private static List<MultiDataSet> getDataIter(String fileName, Map<String,Integer> vocab) throws IOException{
        List<String> lines = FileUtils.readLines(new File(fileName), Charset.forName("utf-8"));
        List<MultiDataSet> datasets = new LinkedList<>();
        List<Integer> idxsLst = new ArrayList<>(128);
        List<Integer> maskLst = new ArrayList<>(128);
        List<Integer> labelLst = new ArrayList<>(128);

        for( String line : lines ){
            String[] tokens = line.split("\t");
            // Every sequence starts with the [CLS] token (id 101), labeled "O"
            idxsLst.add(101);
            labelLst.add(0);
            maskLst.add(1);
            for( String token : tokens ){
                token=token.replace(" ","");
                if(token.equals("")){
                    continue;
                }
                // Each token is "word/label"; take the part after the last '/' as the label
                String[] wordAndLabel = token.split("/");
                String word = wordAndLabel[0];
                String label = wordAndLabel[wordAndLabel.length - 1];
                idxsLst.add(vocab.getOrDefault(word, 100));   // 100 = [UNK] for out-of-vocabulary words
                maskLst.add(1);
                labelLst.add(labelMap.get(label));
            }
            // End with the [SEP] token (id 102), then pad all three sequences to length 128
            idxsLst.add(102);
            idxsLst.addAll(Collections.nCopies(128 - idxsLst.size(), 0));
            labelLst.add(0);
            labelLst.addAll(Collections.nCopies(128 - labelLst.size(), 0));
            maskLst.add(1);
            maskLst.addAll(Collections.nCopies(128 - maskLst.size(), 0));

            // Assemble the NDArrays for this sentence
            INDArray idxs = Nd4j.create(idxsLst);
            INDArray mask = Nd4j.create(maskLst);
            INDArray segmentIdxs = Nd4j.zeros(128);   // single-sentence input: all segment ids are 0
            INDArray labelArr = Nd4j.create(labelLst);

            MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[]{idxs, mask, segmentIdxs}, new INDArray[]{labelArr});
            datasets.add(mds);
            idxsLst.clear();
            maskLst.clear();
            labelLst.clear();
        }
        return datasets;
    }

}

This looks like it couldn’t properly import the model.

As you are importing a SameDiff serialized model, we need to know how you created that model in the first place.

If you imported the model from somewhere and then serialized it for further use, please try to re-import it with the current SNAPSHOT version.
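
For reference, re-importing the frozen TensorFlow graph and re-serializing it would look roughly like this (a minimal sketch; the .pb file name here is an assumption):

import java.io.File;

import org.nd4j.autodiff.samediff.SameDiff;

public class ReimportBert {
    public static void main(String[] args) throws Exception {
        // Import the frozen TensorFlow graph (file name is hypothetical)
        SameDiff sd = SameDiff.importFrozenTF(new File("bert_frozen.pb"));

        // Re-serialize in SameDiff's FlatBuffers format, as loaded by getBertModel() above
        sd.asFlatFile(new File("output-bert.fb"));
    }
}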

Thank you for your reply. I got this model with the help of Adam Gibson. There was an error when I imported the BERT model, so Adam Gibson helped me finish it. This is the issue that fixed the "Missing tensorflow BERT ops filters", and I got this model from his Google Drive. I would like to try importing the model myself, but I can't compile the newest master code.
Some error happened when I importFrozenTF by SameDiffI · Issue #9420 · eclipse/deeplearning4j · GitHub

Oh, I see, you even have a GitHub issue about that particular problem.

I know that @agibsonccc is currently working on a few more import issues, so I guess you’ll need to wait a bit longer, and then ask for a reimport of the model again in your GitHub issue.

Thanks for your advice, I will contact him in a few days.

@benbenwt @treo FYI, I replied on the issue there. Regarding the newest master, you should be able to try snapshots.
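
For anyone following along: using snapshots generally means adding the Sonatype snapshots repository to the pom.xml and pointing the ND4J/DL4J artifacts at the snapshot version, roughly like this (a sketch; verify the exact coordinates against the current DL4J snapshots documentation):

<repositories>
    <repository>
        <id>snapshots-repo</id>
        <url>https://oss.sonatype.org/content/repositories/snapshots</url>
        <releases><enabled>false</enabled></releases>
        <snapshots><enabled>true</enabled></snapshots>
    </repository>
</repositories>

<!-- and for each nd4j/deeplearning4j dependency: -->
<version>1.0.0-SNAPSHOT</version>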