Overriding the standard Op BP algorithm

I need to implement my own BP algorithm for the CompareAndSet Op. I did it the same way as for the Gather Op, but it didn't work. I struggled with this problem for about 3 hours and had to give up. Unlike Gather, CompareAndSet is a base Op and it's inside the SameDiff's ops field. I tried every way of overriding it I could think of, including replacing this Op with my custom one directly in SameDiff, but I'm stuck in FlatBuffersMapper#fromFlatNode: it creates the Op instance from a hardcoded numerical OpType value, which brings the original version back to life. That instance is cloned incorrectly (I have no idea why cloning is needed) and some props are missing, so execution fails. Overriding the props setup method doesn't help because the original version is hardcoded in LegacyOpMapper. Changing the type to CUSTOM_OP in my custom class doesn't help either, because execution then fails somewhere else (which makes sense, since this Op is one of the built-in SameDiff Ops). Could someone please help me here?

P.S. It would be really great to have the flexibility to do this for all Ops, not only the custom ones, the way it works for the Gather Op.

@partarstu all you need to do is override the doDiff(…) normally.

Most of the time we implement a _bp op in c++ to handle this if possible though.

That involves adding a _bp kernel in c++ and the associated java class to handle delegation to the underlying kernel.

In your case, though, just overriding doDiff in CompareAndSet should be enough.
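To make the "just override doDiff" idea concrete, here is a plain-Java toy model in which a subclass inherits the forward computation unchanged and replaces only the backward pass. Assumptions: ToyCompareAndSet, ToyCompareAndSetCustomBp, forward() and backward() are hypothetical names working on double[] arrays, not ND4J classes. In SameDiff the override point would be doDiff(...) on SDVariables, and the toy default gradient here is not necessarily what ND4J computes for CompareAndSet.

```java
import java.util.Arrays;

// Hypothetical stand-in for a CompareAndSet-style transform (not the ND4J class):
// every element strictly greater than `threshold` is replaced by `setValue`.
class ToyCompareAndSet {
    protected final double threshold;
    protected final double setValue;

    ToyCompareAndSet(double threshold, double setValue) {
        this.threshold = threshold;
        this.setValue = setValue;
    }

    double[] forward(double[] x) {
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = x[i] > threshold ? setValue : x[i];
        }
        return out;
    }

    // Toy default backward pass (an assumption, not ND4J's): replaced elements are
    // constants, so they get zero gradient; the others pass gradOut through.
    double[] backward(double[] x, double[] gradOut) {
        double[] gradIn = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            gradIn[i] = x[i] > threshold ? 0.0 : gradOut[i];
        }
        return gradIn;
    }
}

// Subclass that inherits forward() unchanged and overrides only the backward
// pass; passing gradOut through everywhere is just an illustrative choice.
class ToyCompareAndSetCustomBp extends ToyCompareAndSet {
    ToyCompareAndSetCustomBp(double threshold, double setValue) {
        super(threshold, setValue);
    }

    @Override
    double[] backward(double[] x, double[] gradOut) {
        return Arrays.copyOf(gradOut, gradOut.length);
    }
}
```

Because only backward() is overridden, both classes produce identical forward outputs; they differ only in the gradient they hand back.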

Alternatively, you could just add your own op.

Regarding flexibility, adding your own op is very doable. The main problem is integrating it with the rest of samediff, where we also have the codegen and the like. File an issue and we can discuss it there. As usual, a description of your expectations would be helpful.

My main issue is I’m not sure what you’d expect when it comes to all the custom steps.

My initial hunch here would be to have a registry where we look up user-defined custom ops via annotations. I already implemented something like that for our model import, and I don't see why I couldn't do the same thing here.
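A rough sketch of such an annotation-based registry, assuming a hypothetical @UserOp annotation and OpRegistry class (neither exists in ND4J). For simplicity the candidate classes are passed in explicitly; a real implementation would discover them with a classpath scanner. User classes are registered after the built-ins, so they win on a name collision.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical annotation marking a user class as the implementation of an op name.
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
@interface UserOp {
    String value();
}

// Hypothetical registry: built-in mappings first, then annotated user classes,
// which replace the built-in entry for the same op name.
class OpRegistry {
    private final Map<String, Class<?>> byName = new HashMap<>();

    OpRegistry(Map<String, Class<?>> builtIns, List<Class<?>> candidates) {
        byName.putAll(builtIns);
        for (Class<?> c : candidates) {
            UserOp ann = c.getAnnotation(UserOp.class);
            if (ann != null) {
                byName.put(ann.value(), c); // user override wins
            }
        }
    }

    Class<?> lookup(String opName) {
        Class<?> c = byName.get(opName);
        if (c == null) {
            throw new IllegalArgumentException("No op registered for " + opName);
        }
        return c;
    }
}

// Hypothetical built-in op and a user subclass that only wants to change its BP.
class BuiltInCompareAndSet { }

@UserOp("compare_and_set")
class MyCompareAndSet extends BuiltInCompareAndSet { }
```

Classes without the annotation are simply ignored, so the scanner can be fed a broad set of candidates.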

I did that from the very beginning; it doesn't work. I forgot to mention that my custom Op extends CompareAndSet, because I don't need a new Op: I just want to modify the BP, not the whole Op. I thought that adding my custom Op to the ImportClassMapping would do the job (it does for my custom implementation of the Gather Op BP), but unlike Gather, CompareAndSet is ignored when it's added to the ImportClassMapping. The core issue is that the Op is cloned/created from scratch instead of first looking up any class mapping. I think it has something to do with the OpType: Gather is CUSTOM, CompareAndSet is TRANSFORM.
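To illustrate that diagnosis in plain Java: if ops are recreated from their type and op number, a name-based override is never consulted for non-custom ops. ToyOpFactory and its OpType enum are hypothetical; this is not the FlatBuffersMapper or LegacyOpMapper code, only a sketch of the lookup order described in this thread and of one possible fix.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Supplier;

// Hypothetical, heavily simplified model of recreating an op when a saved graph
// is loaded. It only mirrors the lookup order described in the thread.
class ToyOpFactory {
    enum OpType { CUSTOM, TRANSFORM }

    // Legacy ops: created from a hardcoded opNum table.
    private final Map<Integer, Supplier<Object>> legacyByOpNum = new HashMap<>();
    // Custom ops: created by name, so user entries in this map take effect.
    private final Map<String, Supplier<Object>> byName = new HashMap<>();

    void registerLegacy(int opNum, Supplier<Object> factory) { legacyByOpNum.put(opNum, factory); }
    void registerByName(String opName, Supplier<Object> factory) { byName.put(opName, factory); }

    // Described behaviour: non-custom op types never consult the name map.
    Object create(OpType type, int opNum, String opName) {
        Supplier<Object> factory = type == OpType.CUSTOM ? byName.get(opName) : legacyByOpNum.get(opNum);
        if (factory == null) {
            throw new IllegalStateException("Unknown op: " + opName);
        }
        return factory.get();
    }

    // One possible fix: check name-based overrides first, then fall back.
    Object createPreferringOverrides(OpType type, int opNum, String opName) {
        Supplier<Object> override = byName.get(opName);
        return override != null ? override.get() : create(type, opNum, opName);
    }
}
```

With create(), registering a custom CompareAndSet by name changes nothing for a TRANSFORM op; createPreferringOverrides() picks it up while still falling back to the opNum table for everything else.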

Will do that today. An easy way to integrate custom Ops would definitely be a great feature for DL4J as a whole.

The idea is to bring in something that may be specific to a project but doesn't need to be part of the whole framework. In my case it's only the BP part. The original BP is also implemented directly in Java, so I don't think that's an issue. But if a BP algorithm or the Op execution needs to be tuned for a specific case, it would be great to be able to do that without a huge effort. Right now this isn't taken into account architecturally (e.g. fields in all Ops are private instead of protected). Regarding a c++ implementation: I tried that once, spent the whole day on the environment setup and still didn't end up with a working setup.

That would be a great solution, because the current workaround (modifying the ImportClassMapping OP_NAME_MAP and replacing the Op instance with your own override that has a custom BP) definitely doesn't work for all Ops; otherwise I wouldn't have raised this topic. It would be nice to have a standard, quick way to add custom Ops (at least with BP overrides) to adapt the graph to specific project needs.

Issue created: No way to override the standard Op BP algorithm · Issue #9898 · deeplearning4j/deeplearning4j · GitHub

@agibsonccc, until this issue is fixed, is there any alternative to using CompareAndSet? Otherwise I'm kind of blocked, because I haven't found any activation function in SameDiff that gives me the behavior I need (a binary activation based on a threshold).

@partarstu I'll have user-defined functions (UDFs) done within the next day or two. I mainly need to set up the right abstractions. It'll be simple and have the following:

  1. An op type so the FlatBuffersMapper (used for saving graphs) knows how to save/load UDFs
  2. A base class for having clear overrides
  3. An exec method where the user passes in relevant inputs and will be what gets used by the op executioner (rather than going down to c++)
  4. A hook in samediff with something like:
      1. An annotation or subclass scanner, like we discussed, for registration in relevant areas such as the ImportClassMappings
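A toy sketch of how those pieces could fit together, assuming plain double[] arrays instead of INDArray/SDVariable. ToyUdf, ToyUdfRegistry and ToyAddUdf are hypothetical names, not the classes from the PR: the base class carries an op type tag for serialization, exec() runs the forward pass in Java, doDiff() is the explicit backward override, and the registry stands in for annotation or subclass scanning.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical base class for user-defined ops.
abstract class ToyUdf {
    // Op type tag a serializer could use to recognise UDFs.
    final String opType() { return "UDF"; }

    // Name used for registration and (de)serialization.
    abstract String opName();

    // Forward pass executed in Java rather than in a native kernel.
    abstract double[] exec(List<double[]> inputs);

    // Backward pass: returns one gradient array per input.
    abstract List<double[]> doDiff(List<double[]> inputs, double[] gradOut);
}

// Minimal registry standing in for annotation/subclass scanning.
class ToyUdfRegistry {
    private final Map<String, ToyUdf> udfs = new HashMap<>();

    void register(ToyUdf udf) { udfs.put(udf.opName(), udf); }

    ToyUdf get(String name) { return udfs.get(name); }
}

// Element-wise add of two equally sized inputs.
class ToyAddUdf extends ToyUdf {
    String opName() { return "toy_add"; }

    double[] exec(List<double[]> inputs) {
        double[] x = inputs.get(0), y = inputs.get(1);
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = x[i] + y[i];
        }
        return out;
    }

    List<double[]> doDiff(List<double[]> inputs, double[] gradOut) {
        // d(x + y)/dx = d(x + y)/dy = 1, so both inputs receive gradOut.
        List<double[]> grads = new ArrayList<>();
        grads.add(gradOut.clone());
        grads.add(gradOut.clone());
        return grads;
    }
}
```

Keeping exec() and doDiff() as separate abstract methods means a subclass has to state its backward pass explicitly instead of inheriting one by accident.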

@partarstu here's the baseline: Allow user to define UDFs in samediff/nd4j by agibsonccc · Pull Request #9901 · deeplearning4j/deeplearning4j · GitHub. Your feedback on usability would be great here. I still need to add more tests, mainly covering training to make sure the proper doDiff is used.

@agibsonccc, I've taken a look at your PR and it seems fine at first sight. However, I'm a little confused about all the overrides, specifically the exec() method. Does it mean that in my case I'd use INDArray.replaceWhere() directly inside exec() and write my own BP in doDiff()?

On the other hand, since I haven't used SNAPSHOT builds for a long time (I always had some issues with them), right now I get an exception every time I introduce a constant/variable into the graph. I never saw this in M2.1. Is this a known issue? Without a working SNAPSHOT I can't test your changes.
The stack trace (I'm running Windows 11 with Java 19):

    Exception in thread "main" java.lang.UnsatisfiedLinkError: 'void org.nd4j.linalg.cpu.nativecpu.bindings.Nd4jCpu.setShapeBuffer(org.bytedeco.javacpp.LongPointer, int, org.bytedeco.javacpp.LongPointer, char, int, boolean)'
        at org.nd4j.linalg.cpu.nativecpu.bindings.Nd4jCpu.setShapeBuffer(Native Method)
        at org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner.createShapeInfo(NativeOpExecutioner.java:1800)
        at org.nd4j.linalg.api.shape.Shape.createShapeInformation(Shape.java:3262)
        at org.nd4j.linalg.api.ndarray.BaseShapeInfoProvider.createShapeInformation(BaseShapeInfoProvider.java:68)
        at org.nd4j.linalg.api.ndarray.BaseNDArray.<init>(BaseNDArray.java:169)
        at org.nd4j.linalg.api.ndarray.BaseNDArray.<init>(BaseNDArray.java:310)
        at org.nd4j.linalg.cpu.nativecpu.NDArray.<init>(NDArray.java:170)
        at org.nd4j.linalg.cpu.nativecpu.CpuNDArrayFactory.createUninitialized(CpuNDArrayFactory.java:206)
        at org.nd4j.linalg.factory.Nd4j.createUninitialized(Nd4j.java:4388)
        at org.nd4j.linalg.factory.BaseNDArrayFactory.valueArrayOf(BaseNDArrayFactory.java:792)
        at org.nd4j.linalg.factory.Nd4j.valueArrayOf(Nd4j.java:4503)

@partarstu yes, a UDF overrides exec() and exec(OpContext) (the latter is used when the op doesn't have the input arrays associated with it).
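For the binary threshold activation asked about earlier, the exec()/doDiff() split could look roughly like this toy version. It is a sketch on double[] arrays, not SameDiff code: ToyBinaryThreshold and its window parameter are hypothetical, and the clipped straight-through gradient is one common choice for step functions, not something prescribed in this thread.

```java
// Hypothetical plain-Java analogue of a binary threshold UDF. A real SameDiff UDF
// would operate on INDArray/SDVariable (e.g. replaceWhere inside exec), which is
// not modelled here.
class ToyBinaryThreshold {
    private final double threshold;
    private final double window; // half-width of the band where gradient may flow

    ToyBinaryThreshold(double threshold, double window) {
        this.threshold = threshold;
        this.window = window;
    }

    // Forward: 1.0 where x > threshold, otherwise 0.0.
    double[] exec(double[] x) {
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = x[i] > threshold ? 1.0 : 0.0;
        }
        return out;
    }

    // Backward: a step function's true derivative is 0 almost everywhere, which
    // would block gradient flow to earlier layers, so this passes gradOut through
    // only where |x - threshold| <= window (a clipped straight-through estimator).
    double[] doDiff(double[] x, double[] gradOut) {
        double[] gradIn = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            gradIn[i] = Math.abs(x[i] - threshold) <= window ? gradOut[i] : 0.0;
        }
        return gradIn;
    }
}
```

Widening the window lets more inputs receive gradient; shrinking it toward zero approaches the true (zero) derivative of the step.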

The extra stuff is just for serialization. Ops often have properties that need to be propagated to the underlying arguments that actually get passed to c++. You can see these here:

That is stuff like the tArguments, iArguments, …

Those can live either on the op or on the op context. Ops implemented in c++ often have an associated Java op class. For ease of use, those take the arguments as constructor parameters, which then get passed down to the base class's lists.
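A minimal sketch of that pattern, assuming hypothetical ToyBaseOp and ToyClipOp classes: the op's constructor takes readable parameters and pushes them into flat tArgument/iArgument lists held by the base class. The add* method names loosely echo ND4J's DynamicCustomOp, but this is a standalone model, not its actual implementation.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical base class (not ND4J's) holding the flat argument lists that a
// native kernel would receive.
abstract class ToyBaseOp {
    private final List<Double> tArguments = new ArrayList<>(); // floating-point args
    private final List<Long> iArguments = new ArrayList<>();   // integer args

    protected void addTArgument(double... args) {
        for (double a : args) tArguments.add(a);
    }

    protected void addIArgument(long... args) {
        for (long a : args) iArguments.add(a);
    }

    List<Double> tArgs() { return Collections.unmodifiableList(tArguments); }

    List<Long> iArgs() { return Collections.unmodifiableList(iArguments); }
}

// Hypothetical op: the user-friendly constructor forwards its properties to the
// base-class lists in the fixed order a kernel would expect.
class ToyClipOp extends ToyBaseOp {
    ToyClipOp(double min, double max, long axis) {
        addTArgument(min, max);
        addIArgument(axis);
    }
}
```

Callers only see the named constructor parameters, while the serializer and executioner only need the flat lists.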

You can see a bare bones example here:


@partarstu regarding snapshots, I saw someone else having issues like that as well. I'm not sure why yet (guessing GitHub changed something again…). I'll look into it after I merge this PR.

Thanks a lot! Meanwhile I’ll try to search in the forum for similar issues - maybe someone has already found a solution.

@partarstu no, I'm pretty sure this is recent. I did see someone else have issues with snapshots; I'm just not sure what the source is. Could you set org.bytedeco.javacpp.logger.debug=true and send me the logs?
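For reference, one way to enable that flag from code is shown below. The class name EnableJavaCppDebug is hypothetical. As far as I know, JavaCPP reads the property when its logger initializes, so it has to be set before anything loads the native backend; passing -Dorg.bytedeco.javacpp.logger.debug=true on the java command line should be equivalent.

```java
// Turns on JavaCPP's debug logging so native library loading is traced.
// Set it before any ND4J/JavaCPP class loads native code.
class EnableJavaCppDebug {
    static void enable() {
        System.setProperty("org.bytedeco.javacpp.logger.debug", "true");
    }

    public static void main(String[] args) {
        enable();
        // Trigger backend loading afterwards (e.g. by creating the first ND4J array)
        // and capture the console output to share.
        System.out.println("org.bytedeco.javacpp.logger.debug="
                + System.getProperty("org.bytedeco.javacpp.logger.debug"));
    }
}
```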

@partarstu here's a sample op I used to test a custom UDF that implements add:

A bit more context here:

Example code in the tests, cherry picked here:

    int batchSize = 4;
    int modelDim = 8;

    SameDiff sd = SameDiff.create();

    // Placeholders for a minibatch of features and labels
    SDVariable features = sd.placeHolder("features", FLOAT, batchSize, modelDim);
    SDVariable labels = sd.placeHolder("labels", FLOAT, batchSize, modelDim);
    // Trainable parameters of a single linear layer
    SDVariable weights = sd.var("weights", new XavierInitScheme('c', modelDim, modelDim), FLOAT, modelDim, modelDim);
    SDVariable bias = sd.var("bias", new ZeroInitScheme('c'), FLOAT, modelDim);
    SDVariable predictions = sd.nn.linear("predictions", features, weights, bias);
    // Apply the user-defined add op to the predictions and the constant 1.0
    SDVariable[] sdVariables = sd.doUdf(new TestAddUdf(sd, new SDVariable[]{predictions, sd.constant(1.0)}));
    SDVariable loss = sd.loss.meanSquaredError("loss", labels, sdVariables[0], null);

    // Map the DataSet features/labels onto the placeholders and attach the config
    TrainingConfig config = new TrainingConfig.Builder()
            .updater(new Adam(0.1))
            .dataSetFeatureMapping("features")
            .dataSetLabelMapping("labels")
            .build();
    sd.setTrainingConfig(config);

    // One minibatch of random integer-valued features and labels
    DataSetIterator iterator = new RandomDataSetIterator(1, new long[]{batchSize, modelDim}, new long[]{batchSize, modelDim}, INTEGER_0_10, INTEGER_0_10);

    // Train for 10 epochs
    sd.fit(iterator, 10);

@agibsonccc, you're right, the issue is with OpenBLAS. Maybe it's because of Windows 11, I have no idea. Logs:

@agibsonccc, it's a perfect example, thanks a lot!

I think this UDF feature is one of the most important additions in the new release! It brings a huge amount of flexibility!

@partarstu could you look into using GitHub - lucasg/Dependencies: A rewrite of the old legacy software "depends.exe" in C# for Windows devs to troubleshoot dll load dependencies issues. and figure out where the missing dependencies are? Just analyze the files at the absolute paths it mentions. If it's something specific to Windows 11, I wonder if we need to bundle something we currently don't. Apologies, and thanks for looking into this!

@agibsonccc sure, will do that and let you know my findings.

@agibsonccc, I've analyzed the dependencies of openblas-0.3.21-1.5.8-windows-x86_64.jar and found 2 missing ones:

I'm not sure that's the core of the problem. But after analyzing windows-x86_64-onednn-avx512, I saw the following:

Something seems to be missing. I tried switching the native backend variant from windows-x86_64-onednn-avx512 to the generic windows-x86_64, but it didn't help. I have no idea whether the issue is in javacpp-1.5.8-windows-x86_64 or in the nd4j package.

@partarstu thanks for doing that! Let me do some digging. M2.1 works correctly?

@agibsonccc, with M2.1 I had no issues using the same config and the same environment.