I imported the BERT model and added a softmax layer on top to perform an NER task, but when I train the model I get a NullPointerException.
The following are my code and the necessary files:
vocab.txt: vocab.txt - Google Drive
bert model:output-bert.zip - Google Drive
train-data: segMaster.corpus - Google Drive
pom:pom.xml - Google Drive
error log:
Exception in thread "main" java.lang.NullPointerException
at org.nd4j.autodiff.samediff.internal.AbstractSession.getExecStepForVar(AbstractSession.java:680)
at org.nd4j.autodiff.samediff.internal.AbstractSession.addDependenciesForOp(AbstractSession.java:645)
at org.nd4j.autodiff.samediff.internal.AbstractSession.updateDescendantDeps(AbstractSession.java:581)
at org.nd4j.autodiff.samediff.internal.AbstractSession.output(AbstractSession.java:468)
at org.nd4j.autodiff.samediff.internal.TrainingSession.trainingIteration(TrainingSession.java:107)
at org.nd4j.autodiff.samediff.SameDiff.fitHelper(SameDiff.java:1735)
at org.nd4j.autodiff.samediff.SameDiff.fit(SameDiff.java:1591)
at org.nd4j.autodiff.samediff.SameDiff.fit(SameDiff.java:1474)
at com.wt.utils.TestMaster.train(TestMaster.java:126)
at com.wt.utils.TestMaster.main(TestMaster.java:104)
code:
package com.wt.utils;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory;
import org.nd4j.autodiff.listeners.impl.ScoreListener;
import org.nd4j.autodiff.samediff.NameScope;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.weightinit.impl.UniformInitScheme;
import org.nd4j.weightinit.impl.XavierInitScheme;
import java.io.File;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.*;
public class TestMaster {
    static Integer numOutput = 13;

    // id -> BIO label
    static HashMap<Integer, String> idToLabel = new HashMap<Integer, String>() {{
        put(0, "O");
        put(1, "B-malware");
        put(2, "I-malware");
        put(3, "B-identity");
        put(4, "I-identity");
        put(5, "B-location");
        put(6, "I-location");
        put(7, "B-software");
        put(8, "I-software");
        put(9, "B-threatActor");
        put(10, "I-threatActor");
        put(11, "B-vulnerability_cve");
        put(12, "I-vulnerability_cve");
    }};

    // BIO label -> id
    static HashMap<String, Integer> labelMap = new HashMap<String, Integer>() {{
        put("O", 0);
        put("B-malware", 1);
        put("I-malware", 2);
        put("B-identity", 3);
        put("I-identity", 4);
        put("B-location", 5);
        put("I-location", 6);
        put("B-software", 7);
        put("I-software", 8);
        put("B-threatActor", 9);
        put("I-threatActor", 10);
        put("B-vulnerability_cve", 11);
        put("I-vulnerability_cve", 12);
    }};

    static File wordPieceTokens = new File("D:\\deeplearning4j-examples\\mydownload\\output-bert\\vocab.txt");
    static BertWordPieceTokenizerFactory t;

    static {
        try {
            t = new BertWordPieceTokenizerFactory(wordPieceTokens, true, true, StandardCharsets.UTF_8);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
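    // Note (my addition, not in the original post): if the vocab file fails to
    // load, the catch above only prints the stack trace and `t` stays null, so
    // the t.getVocab() call below fails during class initialization. A minimal
    // fail-fast guard could be:
    // static { Objects.requireNonNull(t, "tokenizer factory failed to initialize"); }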
    static Map<String, Integer> vocab = t.getVocab();
    static List<MultiDataSet> iter;
    static Map<Integer, String> idToChar = new HashMap<>();

    static {
        try {
            iter = getDataIter("D:\\deeplearning4j-examples\\mydownload\\dl4j-learn\\src\\main\\resources\\segMaster.corpus", vocab);
        } catch (IOException e) {
            e.printStackTrace();
        }
        for (Map.Entry<String, Integer> e : vocab.entrySet()) {
            idToChar.put(e.getValue(), e.getKey());
        }
    }

    public TestMaster() throws IOException {
    }

    public static void main(String[] args) throws IOException {
        train();
    }
    public static void train() throws IOException {
        SameDiff sd = getBertModel();
        TrainingConfig c = TrainingConfig.builder()
                .updater(new Adam(0.01))
                .l2(1e-5)
                .dataSetFeatureMapping("tokenIdxs", "mask", "sentenceIdx")
                .dataSetLabelMapping("label")
                .build();
        sd.setTrainingConfig(c);

        long start = System.currentTimeMillis();
        System.out.println("Start Training...");
        MultiDataSet multiDataset;
        for (int numEpoch = 0; numEpoch < 10; numEpoch++) {
            for (int i = 0; i < iter.size(); ++i) {
                multiDataset = iter.get(i);
                printBertInput(multiDataset);
                sd.fit(multiDataset);   // the NPE is thrown from this call (TestMaster.java:126 in the stack trace)
            }
        }
        long end = System.currentTimeMillis();
        System.out.println("Total Time Cost: " + (end - start) + "ms");
        System.out.println("End Training...");

        File saveFileForInference = new File("mySaveModel.fb");
        sd.asFlatFile(saveFileForInference);
    }

    public static void printBertInput(MultiDataSet multiDataset) {
        System.out.println("printBertInput----------------------------");
        System.out.println("tokenIdxs   : " + multiDataset.getFeatures(0).get(NDArrayIndex.interval(0, 128)));
        System.out.println("mask        : " + multiDataset.getFeatures(1).get(NDArrayIndex.interval(0, 128)));
        System.out.println("sentenceIdx : " + multiDataset.getFeatures(2).get(NDArrayIndex.interval(0, 128)));
        System.out.println("truthLabel  : " + multiDataset.getLabels(0).get(NDArrayIndex.interval(0, 128)));
    }
    private static SameDiff getBertModel() throws IOException {
        File f = new File("D:\\deeplearning4j-examples\\mydownload\\output-bert\\output-bert.fb");
        SameDiff sd = SameDiff.fromFlatFile(f);

        // Rename the imported input placeholders to friendlier names
        sd.renameVariable("IteratorGetNext", "tokenIdxs");
        sd.renameVariable("IteratorGetNext:1", "mask");
        sd.renameVariable("IteratorGetNext:4", "sentenceIdx");

        SDVariable labels = sd.placeHolder("label", DataType.INT, 1, 128);

        // Take the output of the last encoder layer and add a linear + softmax head
        NameScope my_transfer = sd.withNameScope("mine");
        SDVariable input = sd.getVariable("bert/encoder/layer_11/output/LayerNorm/batchnorm/add_1");
        SDVariable my_flatten_weights = sd.var("flatten_weights", new XavierInitScheme('c', 768, numOutput), DataType.FLOAT, 768, numOutput);
        SDVariable my_flatten_bias = sd.var("flatten_bias", new UniformInitScheme('c', numOutput), DataType.FLOAT, numOutput);
        SDVariable linear_output = input.mmul(my_flatten_weights).add("linear_output", my_flatten_bias);
        SDVariable softmax_output = sd.nn().softmax("softmax", linear_output);
        SDVariable loss = sd.loss().sparseSoftmaxCrossEntropy("Loss", softmax_output, labels);
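        // Aside (my addition, not in the original post): sparseSoftmaxCrossEntropy
        // takes raw logits and applies softmax internally, so feeding it
        // softmax_output applies softmax twice; passing linear_output may be
        // what is intended. This is separate from the NPE question.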
        my_transfer.close();

        sd.setLossVariables(loss);
        sd.addListeners(new ScoreListener(1));
        System.out.println(sd.summary());
        return sd;
    }
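    // A debugging sketch (my addition, not in the original post): before
    // training, it can help to confirm that every name the TrainingConfig
    // references actually exists in the imported graph, since a misnamed or
    // missing variable is one way the session can end up dereferencing a
    // null entry.
    private static void checkVariables(SameDiff sd, String... names) {
        for (String name : names) {
            System.out.println(name + " exists: " + sd.hasVariable(name));
        }
    }
    // Usage: checkVariables(sd, "tokenIdxs", "mask", "sentenceIdx", "label");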
    private static List<MultiDataSet> getDataIter(String fileName, Map<String, Integer> vocab) throws IOException {
        List<String> lines = FileUtils.readLines(new File(fileName), StandardCharsets.UTF_8);
        List<MultiDataSet> datasets = new LinkedList<>();
        List<Integer> idxsLst = new ArrayList<>(128);
        List<Integer> maskLst = new ArrayList<>(128);
        List<Integer> labelLst = new ArrayList<>(128);

        for (String line : lines) {
            String[] tokens = line.split("\t");
            // Sentence start: [CLS] token (id 101), label "O", mask 1
            idxsLst.add(101);
            labelLst.add(0);
            maskLst.add(1);

            for (String token : tokens) {
                token = token.replace(" ", "");
                if (token.equals("")) {
                    continue;
                }
                // Each token is "word/label"; take the last '/'-separated segment
                // as the label, since a word may itself contain '/'
                String[] wordAndLabel = token.split("/");
                String word = wordAndLabel[0];
                String label = wordAndLabel[wordAndLabel.length - 1];
                idxsLst.add(vocab.getOrDefault(word, 100));   // 100 = [UNK]
                maskLst.add(1);
                labelLst.add(labelMap.get(label));
            }

            // Sentence end: [SEP] token (id 102), then pad everything to length 128
            idxsLst.add(102);
            idxsLst.addAll(Collections.nCopies(128 - idxsLst.size(), 0));
            labelLst.add(0);
            labelLst.addAll(Collections.nCopies(128 - labelLst.size(), 0));
            maskLst.add(1);
            maskLst.addAll(Collections.nCopies(128 - maskLst.size(), 0));

            INDArray idxs = Nd4j.create(idxsLst);
            INDArray mask = Nd4j.create(maskLst);
            INDArray segmentIdxs = Nd4j.zeros(128);
            INDArray labelArr = Nd4j.create(labelLst);
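            // Shape note (my addition, not in the original post): Nd4j.create(list)
            // yields rank-1 arrays of shape [128], while the "label" placeholder
            // above is declared as [1, 128]. If the placeholders require a leading
            // batch dimension, a reshape may be needed, e.g.:
            // labelArr = labelArr.reshape(1, 128);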
            MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(
                    new INDArray[]{idxs, mask, segmentIdxs}, new INDArray[]{labelArr});
            datasets.add(mds);

            idxsLst.clear();
            maskLst.clear();
            labelLst.clear();
        }
        return datasets;
    }
}