When I fine-tune a BERT model in SameDiff on GPU, the following exception occurs:
Caused by: java.lang.RuntimeException: [DEVICE] allocation failed; Error code:
    at org.nd4j.linalg.jcublas.ops.executioner.CudaExecutioner.exec(CudaExecutioner.java:2103)
    at org.nd4j.linalg.factory.Nd4j.exec(Nd4j.java:6575)
    at org.nd4j.autodiff.samediff.internal.InferenceSession.doExec(InferenceSession.java:495)
    at org.nd4j.autodiff.samediff.internal.InferenceSession.getOutputs(InferenceSession.java:222)
    at org.nd4j.autodiff.samediff.internal.TrainingSession.getOutputs(TrainingSession.java:149)
    at org.nd4j.autodiff.samediff.internal.TrainingSession.getOutputs(TrainingSession.java:31)
    at org.nd4j.autodiff.samediff.internal.AbstractSession.output(AbstractSession.java:391)
    at org.nd4j.autodiff.samediff.internal.TrainingSession.trainingIteration(TrainingSession.java:115)
    at org.nd4j.autodiff.samediff.SameDiff.fitHelper(SameDiff.java:1713)
    at org.nd4j.autodiff.samediff.SameDiff.fit(SameDiff.java:1569)
    at org.nd4j.autodiff.samediff.SameDiff.fit(SameDiff.java:1509)
    at org.nd4j.autodiff.samediff.config.FitConfig.exec(FitConfig.java:173)
    at org.nd4j.autodiff.samediff.SameDiff.fit(SameDiff.java:1524)
I use a batch size of 1 (!), the maximum sentence length is 128, and the model itself is a pre-trained BERT-base model. My GPU is a TITAN RTX with 24GB of VRAM.
From this configuration alone, I would not expect a GPU OOM exception to occur. But with “nvidia-smi” I can indeed see the memory grow very quickly and reach the maximum of around 22-24GB when the error occurs.
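To put that expectation into numbers, here is a back-of-the-envelope sketch of the static training memory for BERT-base. It assumes the standard BERT-base configuration (30,522-token vocabulary, hidden size 768, 12 layers, feed-forward size 3,072, 512 positions, 2 token types), fp32 weights, and an Adam updater that keeps two moment buffers per parameter. The updater is not stated above, and the class and method names are hypothetical.

```java
// Rough fp32 memory estimate for BERT-base training (assumed config, see lead-in).
public class BertMemoryEstimate {
    // Assumed BERT-base hyperparameters; not taken from the post.
    static final long VOCAB = 30_522, HIDDEN = 768, LAYERS = 12,
            INTERMEDIATE = 3_072, MAX_POS = 512, TYPE_VOCAB = 2;

    // Dense layer: weight matrix plus bias vector.
    static long dense(long in, long out) { return in * out + out; }

    // LayerNorm: gamma and beta vectors.
    static long layerNorm(long dim) { return 2 * dim; }

    static long bertBaseParams() {
        // Token, position and token-type embeddings, followed by a LayerNorm.
        long embeddings = (VOCAB + MAX_POS + TYPE_VOCAB) * HIDDEN + layerNorm(HIDDEN);
        long perLayer = 4 * dense(HIDDEN, HIDDEN)   // Q, K, V and attention output projections
                + layerNorm(HIDDEN)
                + dense(HIDDEN, INTERMEDIATE)       // feed-forward up-projection
                + dense(INTERMEDIATE, HIDDEN)       // feed-forward down-projection
                + layerNorm(HIDDEN);
        long pooler = dense(HIDDEN, HIDDEN);
        return embeddings + LAYERS * perLayer + pooler;
    }

    // Weights + gradients + two Adam moment buffers, 4 bytes (fp32) each.
    static long trainingStateBytes(long params) { return params * 4L * 4L; }

    public static void main(String[] args) {
        long p = bertBaseParams();
        System.out.printf("parameters: %,d%n", p);
        System.out.printf("weights+grads+Adam state: %.2f GB%n", trainingStateBytes(p) / 1e9);
    }
}
```

Under these assumptions the model has 109,482,240 parameters, and weights, gradients and Adam state together take about 1.75 GB, far below 24GB; activations for a single sequence of length 128 add comparatively little. The estimate ignores the CUDA context and any memory a library keeps cached, so it is a lower bound rather than a prediction of what nvidia-smi will show.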
This happens when I use roughly 50 sentences. However, when I have only 10 sentences in my training set, everything works fine, although even then the memory grows to ~12GB.
I wonder why the GPU memory usage should depend on the training set size. In my understanding, it makes sense for it to depend on the batch size, but not on the number of batches. It seems as if some memory is not being freed… But I do not understand ND4J memory management well enough, especially in combination with GPUs, to really analyze that.
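The intuition that memory should depend on the batch size but not on the number of batches can be checked directly: with a fixed batch size, memory usage should level off after the first few iterations, so a persistently positive trend across iterations suggests something is being retained. A minimal sketch, assuming one memory reading (in MiB) is recorded after each iteration; the helper below is hypothetical and simply fits a least-squares slope to those readings.

```java
import java.util.List;

// Hypothetical helper: estimates how much GPU memory is retained per training iteration.
public class MemoryGrowth {
    // Least-squares slope of usedMiB[i] against iteration index i, in MiB per iteration.
    static double slopeMiBPerIteration(List<Double> usedMiB) {
        int n = usedMiB.size();
        if (n < 2) throw new IllegalArgumentException("need at least two samples");
        double meanX = (n - 1) / 2.0;
        double meanY = 0;
        for (double y : usedMiB) meanY += y;
        meanY /= n;
        double num = 0, den = 0;
        for (int i = 0; i < n; i++) {
            double dx = i - meanX;
            num += dx * (usedMiB.get(i) - meanY);
            den += dx * dx;
        }
        return num / den;
    }

    public static void main(String[] args) {
        // Illustrative numbers only, not measurements from this setup.
        List<Double> growing = List.of(1000.0, 1500.0, 2000.0, 2500.0);
        List<Double> flat = List.of(3000.0, 3000.0, 3000.0, 3000.0);
        System.out.println(slopeMiBPerIteration(growing)); // 500.0
        System.out.println(slopeMiBPerIteration(flat));    // 0.0
    }
}
```

Skipping the first few readings before fitting avoids mistaking one-time warm-up allocations for steady growth.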
Do you have any suggestions as to what the reason could be, or how I could analyze this?
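One way to analyze this is to log GPU memory from inside the training program instead of watching nvidia-smi by hand. The sketch below runs `nvidia-smi --query-gpu=memory.used,memory.total --format=csv,noheader,nounits` and parses the first line of its output; the class and method names are hypothetical, and it assumes nvidia-smi is on the PATH and that the first GPU listed is the one used for training.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;

// Hypothetical helper: reads used/total memory of the first GPU via nvidia-smi.
public class GpuMemorySampler {
    // Parses one line of "--format=csv,noheader,nounits" output, e.g. "12034, 24220".
    static long[] parseUsedTotalMiB(String line) {
        String[] parts = line.split(",");
        if (parts.length != 2) throw new IllegalArgumentException("unexpected line: " + line);
        return new long[] { Long.parseLong(parts[0].trim()), Long.parseLong(parts[1].trim()) };
    }

    // Runs nvidia-smi once and returns {usedMiB, totalMiB} for the first GPU listed.
    static long[] sample() throws IOException, InterruptedException {
        Process p = new ProcessBuilder("nvidia-smi",
                "--query-gpu=memory.used,memory.total",
                "--format=csv,noheader,nounits")
                .redirectErrorStream(true) // surface nvidia-smi errors in the parsed line
                .start();
        String first;
        try (BufferedReader r = new BufferedReader(
                new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
            first = r.readLine();
        }
        p.waitFor();
        if (first == null) throw new IOException("nvidia-smi produced no output");
        return parseUsedTotalMiB(first);
    }

    public static void main(String[] args) throws Exception {
        long[] m = sample();
        System.out.printf("GPU 0: %d / %d MiB used%n", m[0], m[1]);
    }
}
```

Calling `sample()` once per training iteration and logging the used value shows whether memory climbs with every batch or levels off. Note that nvidia-smi reports what the process holds, which can include memory a library keeps cached for reuse, so it is an upper bound on the live allocations.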