Creating a SameDiff op purely in Java

Hi. After analyzing the SameDiff workflow I concluded that all the operations are implemented in C++, and I couldn't find any option for using a pure Java one. This idea is inspired by the Vector and Foreign-Memory Access incubator APIs released in JDK 16. I'd like to run some experiments to see whether applying those APIs could speed up huge transformer-based models built in SameDiff (taking into account the fact that SameDiff works mainly with the CPU, with no GPU support). It would be great to get an example or some tips on how one could run an op in SameDiff purely in Java, avoiding the standard call to the C++ implementation.

Thanks in advance!

Could anyone help? :disappointed:

SameDiff has had GPU support for quite a while now. If an op doesn’t run on the GPU, then it is considered to be a bug.

There is your answer, if you want to implement an “Op” for SameDiff, the easiest way to do it is to implement it in C++.

So much for the "easy" answer, but in principle there is nothing that prevents you from creating an SDVariable subclass that implements some special behavior. It will be brittle, though, as this is not a supported use case.

Thanks a lot for a reply!

Sorry, my bad. I must have read somewhere that SameDiff doesn't support GPU (probably in the JavaDoc). I'll give it a try and see if all the ops in my graph work.

As far as I understand, at some point any SDVariable needs to use some op, which should be an implementation of DifferentialFunction, which in turn relies on the C++ implementation. Or do you mean that this new SDVariable subclass should contain its own "DifferentialFunction" logic based purely on Java?

By the way, a little off topic: I can't fully utilize my CPU with the current SD model. Does SameDiff have anything similar to ParallelWrapper from DL4J?

@partarstu no, it's a separate API and everything. ParallelWrapper is basically a thread-safe way of training neural nets via parameter averaging. It runs one copy of the model on each GPU and aggregates the results, saving one model at the end.
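For reference, here is a minimal sketch of how ParallelWrapper is usually wired up on the DL4J side, assuming you already have a `MultiLayerNetwork` `net` and a `DataSetIterator` `trainIter`; the worker count, prefetch buffer and averaging frequency below are just illustrative values:

```java
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.parallelism.ParallelWrapper;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

public class ParallelTrainingSketch {

    // Trains several copies of the network in parallel and averages
    // their parameters periodically (DL4J networks only, not SameDiff).
    static void trainInParallel(MultiLayerNetwork net, DataSetIterator trainIter) {
        ParallelWrapper wrapper = new ParallelWrapper.Builder(net)
                .workers(4)                     // number of model copies trained in parallel
                .prefetchBuffer(24)             // async prefetch of minibatches
                .averagingFrequency(3)          // average parameters every 3 iterations
                .reportScoreAfterAveraging(true)
                .build();

        wrapper.fit(trainIter);
    }
}
```

As the answer above says, this is DL4J-specific and doesn't apply to a standalone SameDiff graph.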

@agibsonccc, does that mean this API works only for DL4J? Are there currently any thread-safe ways to parallelize SameDiff training at the Java level?

@partarstu try using a samediff layer: https://github.com/eclipse/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java

Are you trying to do multi gpu or just training?

@agibsonccc, I tried the SameDiff layer but I needed more flexibility, e.g. access to some intermediate variables, custom transformation logic etc. Pure SameDiff works really well for those purposes (especially because I modify the graph quite often and use the previous version's weights in order not to lose the accumulated training results). It's just that the CPU utilization isn't maximal, so I was looking for ways to improve that using concurrency.
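To illustrate what that flexibility looks like, here is a rough sketch of accessing a named intermediate variable in plain SameDiff; the variable names, ops and shapes are purely illustrative:

```java
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.util.Collections;
import java.util.Map;

public class IntermediateAccessSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 8);
        SDVariable w  = sd.var("w", Nd4j.randn(DataType.FLOAT, 8, 4));

        // Giving the intermediate an explicit name lets us request it later
        SDVariable hidden = sd.math().tanh("hidden", in.mmul(w));
        SDVariable out    = sd.nn().sigmoid("out", hidden);

        Map<String, INDArray> placeholders =
                Collections.singletonMap("in", Nd4j.randn(DataType.FLOAT, 2, 8));

        // Request both the final output and the named intermediate in one pass
        Map<String, INDArray> results = sd.output(placeholders, "hidden", "out");
        System.out.println(results.get("hidden"));
    }
}
```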

Just training, but my model is a transformer-based one (which obviously takes a lot of time to train). Using a GPU (single or multiple) is one alternative for accelerating the training, but for now I've only run the model on the CPU (including in Docker), as my machine doesn't have an Nvidia card.

I created this topic exactly because I wanted to try the newest JDK, hoping that some of the new incubator initiatives would let me get C++-like performance purely in Java while also giving me the ability to handle concurrency on my own. Based on @treo's answers, that's currently not really possible in SameDiff.

If you have hyperthreading / SMT enabled, you don’t actually want it to be utilized 100%, as that means that you are losing time on needless context switching.

At some point it will call SDVariable.getArr(), which will return an INDArray, and there will be some setup in the background for the gradient, which at some point will also be turned into an INDArray.

Those are the points you would need to hook into, but we can’t really support you in building that, as this wasn’t ever intended to be used that way.

So that leaves you with needing to work with INDArray.
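To make that concrete, here is a rough sketch of the kind of INDArray-to-Java bridge you would end up writing; the element-wise kernel is just a placeholder, and the copies it does are the overhead such a pure-Java op would carry:

```java
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class IndArrayBridgeSketch {

    // Pulls the INDArray's data onto the Java heap, applies a pure-Java
    // "kernel" element-wise, and writes the result into a fresh INDArray.
    // The copies (off-heap -> heap -> off-heap) are exactly the overhead
    // a pure-Java op would have to pay on every call.
    static INDArray applyJavaKernel(INDArray input) {
        float[] data = input.dup().data().asFloat(); // contiguous copy, then heap copy
        for (int i = 0; i < data.length; i++) {
            data[i] = data[i] * data[i];             // stand-in for the real logic
        }
        return Nd4j.createFromArray(data).reshape(input.shape());
    }

    public static void main(String[] args) {
        INDArray x = Nd4j.rand(2, 3);
        System.out.println(applyJavaKernel(x));
    }
}
```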

Honestly, you will not get great performance from this. As far as I know, the new APIs do not support GPUs, and you still need to implement things in the best possible way to get the full improvement. For example see: https://github.com/lessthanoptimal/VectorPerformance

That shows just under a 2x improvement from Java. It is an improvement, but with all the work you need to do, I don't think you will gain a lot compared to our C++ implementation.

Currently it's 55-60% on average; I'd gladly have at least 80%.

That's actually the problem - I wouldn't like to invest much time into a workflow which is not supported. I'd rather try other alternatives to see if the performance there is better.

The example in that code uses add and multiply vector operations directly. I actually tried testing FloatVector using fma() directly, and it was indeed fast (almost 10 times faster than vanilla Math.fma()) on huge arrays for a single dot product. However, the performance on huge arrays during matrix multiplication was no match for ND4J's (probably because I didn't use loop unrolling).
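A FloatVector + fma() dot-product kernel along those lines looks roughly like this (JDK 16 incubator API, so it needs `--add-modules jdk.incubator.vector` at compile and run time; this is only a sketch, not the exact benchmark code):

```java
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;

public class VectorDotSketch {

    private static final VectorSpecies<Float> SPECIES = FloatVector.SPECIES_PREFERRED;

    // Dot product of two equally sized float arrays using the Vector API.
    static float dot(float[] a, float[] b) {
        FloatVector acc = FloatVector.zero(SPECIES);
        int i = 0;
        int bound = SPECIES.loopBound(a.length);
        for (; i < bound; i += SPECIES.length()) {
            FloatVector va = FloatVector.fromArray(SPECIES, a, i);
            FloatVector vb = FloatVector.fromArray(SPECIES, b, i);
            acc = va.fma(vb, acc);              // lane-wise a * b + acc
        }
        float sum = acc.reduceLanes(VectorOperators.ADD);
        for (; i < a.length; i++) {             // scalar tail for the remainder
            sum = Math.fma(a[i], b[i], sum);
        }
        return sum;
    }
}
```

The scalar tail matters: SPECIES.loopBound() only covers the part of the array that fits into whole vectors.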

Here you're definitely right. Taking into account your previous answers regarding direct hooks into SameDiff and unsupported use cases, I'd rather look into a custom op's C++ implementation (although I'm not that strong in C++ coding) in order to improve the performance.

That is exactly the utilization you want to see when hyperthreading is enabled. It means that every single one of your cores is being used to the max. With hyperthreading, every core can hold the state of two threads, and when one of the threads needs to wait for IO, the other thread can run. So it is something that exists purely to improve the utilization of your cores in situations where IO happens.

With math-heavy code there is little to no IO, and that means that using more threads than you have actual cores will result in context switching, which will at best do nothing and at worst reduce your performance and slow down your code.

Set the environment variable OMP_NUM_THREADS to the number of threads your system supports, and if your calculations are big enough, it will use more resources.
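If you want to sanity-check what the JVM actually sees, a trivial check at startup is enough (just a sketch):

```java
public class ThreadEnvCheck {
    public static void main(String[] args) {
        // Launch with e.g.:  OMP_NUM_THREADS=16 java ThreadEnvCheck
        String ompThreads = System.getenv("OMP_NUM_THREADS");
        int logicalCores = Runtime.getRuntime().availableProcessors();
        System.out.println("OMP_NUM_THREADS = " + ompThreads
                + ", logical cores seen by the JVM: " + logicalCores);
    }
}
```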

In the example I used, it went from 55% utilization to 70% utilization, and after I increased the model size it even went to 100% utilization.

At the small size, I lost about 10% performance from doing that, and at the large size it was about equal, but the CPU still used 20W more power.

Maybe for your example it will be different :slight_smile: If all you want to do is get 100% utilization in your task manager, then you just need to set that environment variable.

I tried that a couple of times with different values. I have 16 logical cores and 8 physical ones. The default value of OMP_NUM_THREADS is the number of physical cores. After I increased it to 16, I got not only a ~20% increase in CPU usage but roughly the same increase in performance as well (I didn't calculate the exact value, but I saw approximately the same improvement based on the epoch duration). But when I set this variable to a value higher than 16, it changes nothing (as you already stated). My model doesn't seem to be small (during training it consumes ~120GB at the peak), but doubling the batch size still doesn't seem to change the CPU load.

Maybe if I implement my custom logic as a single op in C++ it will bring some additional performance. Currently I have many operations in my graph which could be aggregated into a single pipeline, thus saving time on the calls to/from the native implementation.

It appears that you are looking for a black cat in a dark room that doesn’t contain a black cat.

Calls to native are very fast. The biggest improvement you could get from going down to a custom op is that you may be able to fuse some of your ops, but then you also need to take care of calculating the gradient in the same way.

If you really think there is a bottleneck in the way things are done currently, you should attach a good profiler to your application and try to figure out where it spends most of its time.

Got it. Thanks a lot for explanations! I appreciate that!