Problem modifying the BERT graph in SameDiff

Hi all, I'm trying to train a sequence classification/labeling model by importing a pretrained BERT model into SameDiff. I downloaded the pretrained BERT model and its vocabulary .txt file, then converted the model to a frozen graph (.pb file) using the Python script at this link (https://github.com/KonduitAI/dl4j-dev-tools/tree/master/import-tests/model_zoo/bert). I checked the graph in TensorBoard and got the image below, so I think the conversion to a .pb file succeeded. Then I imported the .pb file into SameDiff without any op replacements and added some basic operations so that the pretrained model can be used for transfer learning. With a mini-batch size of 1, it runs and trains successfully. I think this is because the three placeholders/inputs have all been set to the shape [1, 128].


My question: if I want to use a mini-batch size of 4/8/16/32 in SameDiff, how should I modify the graph? I don't think changing only the placeholder shapes is enough.
It could probably be done beforehand by running some of the Python scripts from the link above, but I'd like to know how to do it in SameDiff alone.

Take a look at this: https://github.com/eclipse/deeplearning4j/blob/master/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java

That test also imports a .pb file generated with the script you linked, and it applies some additional graph modifications so the BERT model can reasonably be used for inference.

I tested that file before. My understanding is that the .pb model it imports has already been preprocessed following the steps at this link (https://github.com/KonduitAI/dl4j-dev-tools/tree/master/import-tests/model_zoo/bert), especially STEP 3. That step runs the "run_classifier.py" script, which modifies the graph, including some subgraph structure and the shapes of placeholders/variables/arrays. In other words, although "BERTGraphTest.java" makes a few modifications (removing dropout, reshaping placeholders), I think most variables' shapes are still fixed. So if I change the batch size, the imported graph no longer works. I'd like to know how to modify the graph in SameDiff so that it supports a variable batch size (perhaps some variables' shapes need to be reset as well) without running those Python scripts.
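A plain-Java sketch of why widening the placeholders alone may not be enough, assuming the graph follows Google's BERT `modeling.py`: there, `transpose_for_scores` reshapes to `[batch_size, seq_length, num_attention_heads, width]`, and when the input shape is static that batch size is a Python int, so freezing with batch 1 can bake a constant target such as `[1, 128, 12, 64]` into the graph. The sketch only models shapes as `long[]`; the map of "Reshape shape constants", its entry names and `relaxBatchDim` are hypothetical stand-ins. Reading and overwriting the actual constants in an imported SameDiff graph is not shown and would need checking against the SameDiff API.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;

/** Sketch only: shapes are plain long[]; no SameDiff calls are made. */
public class BatchDimSketch {

    /** Resolves a reshape target (at most one -1) against an element count, like TF's Reshape. */
    public static long[] resolveReshape(long inputElements, long[] target) {
        long known = 1;
        int inferred = -1;
        for (int i = 0; i < target.length; i++) {
            if (target[i] == -1) {
                if (inferred != -1) {
                    throw new IllegalArgumentException("more than one -1 in " + Arrays.toString(target));
                }
                inferred = i;
            } else {
                known *= target[i];
            }
        }
        long[] out = target.clone();
        if (inferred >= 0) {
            // The -1 dimension absorbs whatever is left, so it must divide evenly
            if (known == 0 || inputElements % known != 0) {
                throw new IllegalArgumentException("cannot infer -1 for " + inputElements
                        + " elements into " + Arrays.toString(target));
            }
            out[inferred] = inputElements / known;
        } else if (known != inputElements) {
            // A fully fixed target only accepts exactly its own element count
            throw new IllegalArgumentException("cannot reshape " + inputElements
                    + " elements into " + Arrays.toString(target));
        }
        return out;
    }

    /** Hypothetical fix: replace a leading dim equal to the frozen batch size with -1. */
    public static Map<String, long[]> relaxBatchDim(Map<String, long[]> reshapeShapes, long frozenBatch) {
        Map<String, long[]> out = new LinkedHashMap<>();
        for (Map.Entry<String, long[]> e : reshapeShapes.entrySet()) {
            long[] s = e.getValue().clone();
            // Reshape allows only one -1, so leave targets that already have one untouched
            boolean alreadyDynamic = Arrays.stream(s).anyMatch(d -> d == -1);
            if (s.length > 0 && s[0] == frozenBatch && !alreadyDynamic) {
                s[0] = -1;
            }
            out.put(e.getKey(), s);
        }
        return out;
    }

    public static void main(String[] args) {
        long batch = 4, seq = 128, hidden = 768;
        long elements = batch * seq * hidden; // elements in one [4, 128, 768] activation

        // Hypothetical names/shapes standing in for Reshape shape constants in the frozen graph
        Map<String, long[]> frozen = new LinkedHashMap<>();
        frozen.put("attention_heads_reshape", new long[]{1, 128, 12, 64}); // batch baked in as 1
        frozen.put("to_matrix_reshape", new long[]{-1, 768});             // already batch-agnostic

        for (Map.Entry<String, long[]> e : relaxBatchDim(frozen, 1).entrySet()) {
            try {
                resolveReshape(elements, frozen.get(e.getKey()));
                System.out.println(e.getKey() + ": frozen shape still works");
            } catch (IllegalArgumentException ex) {
                System.out.println(e.getKey() + ": frozen shape fails (" + ex.getMessage() + ")");
            }
            System.out.println(e.getKey() + ": relaxed -> "
                    + Arrays.toString(resolveReshape(elements, e.getValue())));
        }
    }
}
```

The sketch rewrites only the leading dimension because BERT puts the batch first; with a frozen batch of 1 that test is ambiguous (a leading 1 that is not the batch would also become -1), and ops whose output shape may be fixed some other way during freezing, such as a constant-shaped `ones`/`Fill` for the attention mask, are not covered.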