GPU memory usage for BERT in SameDiff is extremely high and grows with the size of the training set

Hi there

When I fine-tune a BERT model in SameDiff on GPU, the following exception occurs:

Caused by: java.lang.RuntimeException: [DEVICE] allocation failed; Error code: [2]                        
        at org.nd4j.linalg.jcublas.ops.executioner.CudaExecutioner.exec(        
        at org.nd4j.linalg.factory.Nd4j.exec(                                              
        at org.nd4j.autodiff.samediff.internal.InferenceSession.doExec(         
        at org.nd4j.autodiff.samediff.internal.InferenceSession.getOutputs(     
        at org.nd4j.autodiff.samediff.internal.TrainingSession.getOutputs(       
        at org.nd4j.autodiff.samediff.internal.TrainingSession.getOutputs(        
        at org.nd4j.autodiff.samediff.internal.AbstractSession.output(           
        at org.nd4j.autodiff.samediff.internal.TrainingSession.trainingIteration(
        at org.nd4j.autodiff.samediff.SameDiff.fitHelper(                              
        at org.nd4j.autodiff.samediff.config.FitConfig.exec(                           

I use a batch size of 1 (!), the max sentence length is 128, and the model itself is a pre-trained BERT base model. I have a TITAN RTX with 24GB VRAM.

From this configuration alone, I would not expect a GPU OOM exception to be possible. But with “nvidia-smi” I can indeed see the memory grow very quickly and reach the maximum of around 22-24GB when the error occurs.

This happens when I use roughly 50 sentences. When I only have 10 sentences in my training set, everything works fine, although even then the memory grows to ~12GB.

I wonder why the GPU memory usage should depend on the training set size. In my understanding, it makes sense that it depends on the batch size, but not on the number of batches. It seems like some memory is not being freed… But I do not understand ND4J memory management well enough to really analyze that, especially in combination with GPUs.

Do you have any suggestions what the reason could be or how I could analyze this?

Thank you!
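One way to put a number on "grows with the number of batches" is to record the used GPU memory once per training iteration (for example with `nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits`) and fit a line through the readings. The following is only a minimal sketch of that idea: the class name, the method name, and the readings are all hypothetical, and it is not part of ND4J. A slope that stays clearly positive over many iterations would point to memory that is not being released between batches, while a slope near zero would point to a fixed per-batch cost.

```java
public class GpuMemoryTrend {

    /**
     * Least-squares slope of the readings against the iteration index,
     * i.e. the average growth in MiB per iteration.
     */
    static double slopePerIteration(double[] usedMiB) {
        int n = usedMiB.length;
        if (n < 2) {
            throw new IllegalArgumentException("need at least two readings");
        }
        double meanX = (n - 1) / 2.0; // mean of the indices 0..n-1
        double meanY = 0;
        for (double v : usedMiB) {
            meanY += v;
        }
        meanY /= n;

        double num = 0;
        double den = 0;
        for (int i = 0; i < n; i++) {
            num += (i - meanX) * (usedMiB[i] - meanY);
            den += (i - meanX) * (i - meanX);
        }
        return num / den;
    }

    public static void main(String[] args) {
        // Hypothetical readings in MiB, one per training iteration.
        double[] readings = {1000, 1200, 1400, 1600};
        System.out.printf("growth: %.1f MiB per iteration%n", slopePerIteration(readings));
    }
}
```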

Can you share any additional details? It shouldn’t be using that much memory, and in particular the memory usage shouldn’t keep growing.

BERT, and transformer-based models in general, have a memory requirement that is O(n^2) in the sequence length, but at a sequence length of 128 tokens that shouldn’t be an issue yet.
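To make the quadratic term concrete, here is a rough back-of-the-envelope sketch. It assumes the standard BERT base shape (12 layers, 12 attention heads) and float32 values, and it counts only the raw attention score matrices for a single sample. Real training also keeps other activations, gradients, and updater state, so this is an estimate of one component and not a prediction of total usage. The class and method names are made up for illustration.

```java
public class AttentionMemoryEstimate {

    /**
     * Bytes needed to hold one seqLen x seqLen attention score matrix
     * per head and per layer, for a single sample.
     */
    static long attentionScoreBytes(int layers, int heads, int seqLen, int bytesPerValue) {
        return (long) layers * heads * seqLen * seqLen * bytesPerValue;
    }

    public static void main(String[] args) {
        int layers = 12;       // BERT base
        int heads = 12;        // BERT base
        int bytesPerFloat = 4; // assumption: float32

        for (int seqLen : new int[]{128, 256, 512}) {
            long bytes = attentionScoreBytes(layers, heads, seqLen, bytesPerFloat);
            System.out.printf("seqLen=%d -> %.1f MiB%n", seqLen, bytes / (1024.0 * 1024.0));
        }
    }
}
```

At 128 tokens this comes to about 9 MiB per sample, and each doubling of the sequence length multiplies it by four. Even with generous overhead for the backward pass, that is far from the 12 to 24 GB reported above.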

Please see an extract of the code here. It is not exactly the same code, since the original is corporate code, but it should be mostly equivalent.

In the code you’ll see that we use our own iterator for training and test data. It does not do much more than generate LabelledDocuments with the sentence and the labels from our specific data format.

As the memory consumption on the GPU rises with the size of the training set, could it be a problem with the iterator/sample generation? As far as I know, I do not set any specific memory management configuration, so the samples of each batch should be on the GPU only while the batch is being computed and should be freed right after that?

Here is some additional information about the system and CUDA:

org.nd4j.linalg.factory.Nd4jBackend.load( - Loaded [JCublasBackend] backend
org.nd4j.nativeblas.NativeOpsHolder.<init>( - Number of threads used for linear algebra: 32
org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner.printEnvironmentInformation( - Backend used: [CUDA]; OS: [Linux]
org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner.printEnvironmentInformation( - Cores: [28]; Memory: [64.0GB];
org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner.printEnvironmentInformation( - Blas vendor: [CUBLAS]
org.nd4j.linalg.jcublas.JCublasBackend.logBackendInit( - ND4J CUDA build version: 10.2.89
org.nd4j.linalg.jcublas.JCublasBackend.logBackendInit( - CUDA device 0: [TITAN RTX]; cc: [7.5]; Total memory: [25395462144]
org.nd4j.linalg.jcublas.JCublasBackend.logBackendInit( - CUDA device 1: [TITAN RTX]; cc: [7.5]; Total memory: [25396838400]

Please let me know if I can provide any other information to pinpoint the problem.

Sorry for pushing, but do you have any suggestions or things I could try differently?

I’m having similar (probably the same) issues with a weaker GPU and fewer samples. I’d be interested to know what the solution is here.

@AlexBlack do you know something about this? It would be great if you could take a look. Thanks.