Multi-GPU: Exception when storing Model using SameDiff

Hello 🙂

We fine-tune a BERT model based on a TensorFlow frozen graph using SameDiff. This happens on a system with two GPUs.

After fine-tuning has completed successfully, we try to persist the resulting model using …, false);

This fails with an unsupported-data-type error and the following stack trace:

java.lang.UnsupportedOperationException: Unsupported data type used: UTF8

The cause is that, within org.nd4j.linalg.util.DeviceLocalNDArray.get(), sourceId and deviceId are sometimes unequal: sourceId (to my understanding, the device on which the variable contents reside in memory) is 0, while deviceId (the device on which the current thread runs) is 1. By itself this should not be a problem, because the purpose of this method seems to be exactly to make the data available on the device local to the currently running thread.

However, at least one data item is a UTF-8 scalar. It was loaded as part of the TensorFlow model (a protocol buffer *.pb file generated with the … from here, based on the “BERT-Base, Multilingual Cased” model downloadable from Google’s GitHub), and it cannot be transferred to the other device this way because UTF-8 scalars are not supported.
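To make the failure mode concrete, here is a minimal plain-Java sketch of the behaviour described above, assuming a lazily populated per-device cache keyed by device ID. DeviceLocalSketch, HostArray, DType and replicate are hypothetical names, not the actual internals of ND4J's DeviceLocalNDArray; only the exception message is taken from the stack trace. It shows that a UTF8 array is fine on its source device and only fails once a thread on another device asks for it.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical model of the reported behaviour; NOT ND4J's actual DeviceLocalNDArray code.
class DeviceLocalSketch {
    enum DType { FLOAT, INT, UTF8 }

    // One copy of an array, tagged with the device whose memory holds it.
    static final class HostArray {
        final DType type;
        final Object data;
        final int deviceId;

        HostArray(DType type, Object data, int deviceId) {
            this.type = type;
            this.data = data;
            this.deviceId = deviceId;
        }
    }

    private final int sourceId; // device holding the original copy
    private final Map<Integer, HostArray> perDevice = new HashMap<>();

    DeviceLocalSketch(HostArray source) {
        this.sourceId = source.deviceId;
        perDevice.put(sourceId, source);
    }

    // Returns the copy for the caller's device, replicating lazily on first use.
    synchronized HostArray get(int deviceId) {
        HostArray local = perDevice.get(deviceId);
        if (local != null) {
            return local; // deviceId == sourceId, or already replicated earlier
        }
        HostArray replica = replicate(perDevice.get(sourceId), deviceId);
        perDevice.put(deviceId, replica);
        return replica;
    }

    // The failing step in the report: replication only handles numeric types,
    // so a UTF8 scalar can be read on device 0 but not copied to device 1.
    static HostArray replicate(HostArray src, int targetDevice) {
        switch (src.type) {
            case FLOAT:
            case INT:
                // A real implementation would copy device memory; the sketch reuses the data.
                return new HostArray(src.type, src.data, targetDevice);
            default:
                throw new UnsupportedOperationException("Unsupported data type used: " + src.type);
        }
    }
}
```

For example, with a UTF8 array whose source is device 0, get(0) returns the original, while get(1) throws the same UnsupportedOperationException seen in the stack trace.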

Not supporting this data type is, taken on its own, just an implementation decision about what to support. In the overall context, however, it looks to me like a design bug in how the presence of multiple GPU devices is handled.

I can supply the protobuf file if it helps.

Also, does anyone have an idea for a workaround? We tried limiting ND4J to a single GPU by setting ND4J_CUDA_FORCE_SINGLE_GPU=true, but this does not help.
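Since the flag reportedly has no effect, one thing worth ruling out is that it never reached the JVM: a process inherits its environment when it starts, so a value exported in another shell or missing from an IDE run configuration is simply not visible. The following plain-Java sketch checks what the running JVM actually sees. SingleGpuCheck is a hypothetical helper, and treating only a case-insensitive "true" as enabled is an assumption about how ND4J parses the value.

```java
import java.util.Map;

// Hypothetical diagnostic helper, not part of ND4J.
class SingleGpuCheck {
    static final String FLAG = "ND4J_CUDA_FORCE_SINGLE_GPU";

    // Assumption: the flag counts as enabled only for a case-insensitive "true".
    static boolean forceSingleGpuRequested(Map<String, String> env) {
        String value = env.get(FLAG);
        return value != null && Boolean.parseBoolean(value.trim());
    }

    // Reads the environment this JVM was launched with; it cannot be changed afterwards.
    static boolean forceSingleGpuRequested() {
        return forceSingleGpuRequested(System.getenv());
    }
}
```

Logging SingleGpuCheck.forceSingleGpuRequested() at startup, before the first ND4J call, shows whether the single-GPU setting was actually in effect for that run; if it reports false, the variable was not set for this process.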

Regards, Crispy

@Crispy The issue looks simpler than what you’re describing: the code at … appears to be trying to create a scalar of type string. Could you file an issue? I don’t see why we shouldn’t allow that. Beyond that, anything you can give me to help reproduce it (preferably code, plus the model over DMs if needed) would be appreciated.
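A minimal sketch of what "allowing that" could look like in a type-dispatching scalar factory, assuming scalar creation switches on the data type and previously fell through to the unsupported-type error for strings. ScalarFactory, Scalar, DType and scalarOf are hypothetical names chosen for illustration, not ND4J classes or methods.

```java
// Hypothetical scalar factory, not ND4J's API.
class ScalarFactory {
    enum DType { DOUBLE, LONG, BOOL, UTF8 }

    static final class Scalar {
        final DType type;
        final Object value;

        Scalar(DType type, Object value) {
            this.type = type;
            this.value = value;
        }
    }

    // Creates a scalar of the requested type holding the given value.
    static Scalar scalarOf(DType type, Object value) {
        switch (type) {
            case DOUBLE:
                return new Scalar(type, ((Number) value).doubleValue());
            case LONG:
                return new Scalar(type, ((Number) value).longValue());
            case BOOL:
                return new Scalar(type, (Boolean) value);
            case UTF8:
                // The added branch: a string scalar simply carries its text.
                return new Scalar(type, value.toString());
            default:
                throw new UnsupportedOperationException("Unsupported data type used: " + type);
        }
    }
}
```

Keeping the default branch means any type that is still unhandled keeps failing loudly with the same message, rather than silently producing a wrong value.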

Thanks, @agibsonccc. I’ll create an issue.

Created here