Inference in M2.1 is not idempotent and gives wrong results compared to M1.1

Hi. I was just testing my SameDiff generative model which was trained in M2.1. I hadn’t noticed any specific issues during training, but while testing the model I saw very strange results and at first decided that the architecture was wrong. While debugging, however, I found out that the same model provides different inference results for the same input data when calling SameDiff.getVariable(MODEL_OUTPUT_VAR_NAME).eval() for the first and the second time. The outputs are quite different and, most importantly, the results remain stable from the 2nd call onwards (3rd, 4th etc.). So the first inference is quite different from the subsequent ones (using the same inputs). Moreover, the same model, once fit() has been called at least once, produces yet other inference results, which differ quite a bit from the ones produced by direct inference (without any fit()), both on the first and on subsequent runs.
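Roughly, the check I’m doing looks like this (a minimal sketch; the model file, the input map and the output variable name are just stand-ins for my real ones, and the placeholder wiring is simplified):

import java.io.File;
import java.util.Map;

import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;

public class InferenceIdempotencyCheck {

    public static void check(File modelFile, Map<String, INDArray> inputs, String modelOutputVarName) {
        SameDiff sd = SameDiff.load(modelFile, true);
        // associate the same input arrays with the placeholders for every run
        inputs.forEach((name, arr) -> sd.associateArrayWithVariable(arr, name));

        INDArray first = sd.getVariable(modelOutputVarName).eval();
        INDArray second = sd.getVariable(modelOutputVarName).eval();
        INDArray third = sd.getVariable(modelOutputVarName).eval();

        // Expected: identical inputs give identical outputs on every call.
        // Observed on M2.1: the first result differs, while the 2nd and 3rd match each other.
        System.out.println("1st == 2nd: " + first.equals(second));
        System.out.println("2nd == 3rd: " + second.equals(third));
    }
}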

This problem doesn’t happen when running inference and training under M1.1.

Is this a known issue?

@partarstu mind giving a reproducer and filing an issue? The only thing I can think of that might be affecting it is somehow the new cache changes. Could you try calling setEnableCache(false) on the SameDiff instance?

Beyond that, it’s kind of hard to tell. Nothing significant was done to the graph execution.

Either way, I’ll make sure the tests cover that. We didn’t have any issues in the tests and it didn’t really come up in the internal testing.

@agibsonccc , it was a straight jackpot! Just one line of code (sd.setEnableCache(false)) got everything back to the normal state. I can now see the same behavior as with M1.1, which means I won’t have to revert to it and re-train my model from scratch (a model saved in M2.1 is not backwards-compatible with M1.1). Thanks a lot !!!
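Just for the record, the whole workaround on my side boils down to this (the file name is just an example):

SameDiff sd = SameDiff.load(new File("model-m2.1.fb"), true);
sd.setEnableCache(false);   // disable the new M2.1 array cache for this instance
// ... run inference / fit() as usual; the results now match M1.1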

Do you think it’s reasonable to turn off the caching during training as well? I noticed while training another model under M2.1 that with every new iteration the training duration became shorter and shorter, which made me think some sort of caching was involved. I noticed this because I use a preemptible VM in GCP, so the JVM is restarted every 24h, and it’s quite visible how the duration keeps decreasing with every new training step.

@partarstu the grad samediff instance will inherit the same behavior so that should be fine.

Thanks for confirming that! So apparently the cache was interfering there. I’ll look into it and cut a release ASAP.

@agibsonccc Thanks! Unfortunately my SameDiff graph has 197 variables and 141 operations, which makes it definitely not the best candidate for a reproducer. I’m sure you’ll find the root cause without it. If it turns out not to be that easy, however, I could try to fiddle around with my graph to find a smaller version.

@partarstu could you do a graph.summary()? Posting that as a gist is fine. I mainly want to see the ops you’re using.

@agibsonccc Sure, here it is: Generative model SD summary · GitHub

@partarstu thanks. Looking at your graph, stack, strided slice, or gather could be problems… usually the bugs come from anything relating to a view. I tried to ensure that anything involving a view isn’t cached. Let me double-check the behavior there.

@agibsonccc Got it, thanks again!

@partarstu do you know whether this happens only during training, or during inference as well? We use a “sessions” concept when setting everything up. It tracks the state of feed-forward and training. We also have a training subclass of our feed-forward session.

I was wondering if there was something missing in the training.

@agibsonccc , I haven’t clearly noticed that issue during training, only during inference. Interestingly, I even tried to run inference on the model after a single training iteration, but I still didn’t notice anything suspicious about the inference results. The issue shows up only if I run inference directly, without any warmup/fitting.

I did notice, however, that after loading the saved model following a JVM restart and resuming training, the accuracy at the beginning of training was noticeably lower than before saving. But it went back to the normal level within 30-50 iterations, so I didn’t pay much attention to it. This behavior is not present with caching turned off.

@partarstu could you DM me your network? The summary gives me a hint but I’d like to see the behavior for myself. I don’t mind doing the work to narrow down the problem. I appreciate just being aware of this. It’s a bit difficult to track down where the cache might be going wrong without tracing each op.

@agibsonccc I guess you need not only the network itself but also some data to feed into it, right? It’s not a problem for me to send the SameDiff model, but preparing the data needed to run inference/training would not be so easy. Or do you intend to use some dummy data?

@partarstu dummy data. Mainly need to see the inconsistent outputs and track the results.

@partarstu done. Thanks!

@agibsonccc You’re welcome!

So I tried to build a reproducer based on your output, and at a first quick glance it seems to work:

import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.io.File;
import java.util.HashMap;
import java.util.Map;

public class TestCacheIssue {

    public static void main(String... args) {
        // Load the saved SameDiff model (including updater state)
        SameDiff sameDiff = SameDiff.load(new File("test-cache-issue.fb"), true);

        // Dummy placeholder data with the shapes the graph expects
        Map<String, INDArray> placeHolders = new HashMap<>();
        placeHolders.put("FLAT_TARGET_BATCH_POSITIONS", Nd4j.ones(DataType.INT, 100));
        placeHolders.put("decoderInputTokenVocabIndices", Nd4j.zeros(DataType.INT, 100, 256));
        placeHolders.put("FLAT_TARGET_TOKEN_VOCAB_INDICES", Nd4j.zeros(DataType.INT, 768));
        placeHolders.put("selfAttentionCausalMasks", Nd4j.ones(DataType.FLOAT, 100, 1, 256, 256));

        System.out.println(sameDiff.summary());

        // Run inference repeatedly with the same inputs - the outputs should be identical
        for (int i = 0; i < 10; i++) {
            System.out.println(sameDiff.output(placeHolders, "predictionTokenEmbeddings"));
        }
    }
}

Could you detail what else could be causing this? You said you saw it more during inference. Did you have to train the model first before updating it?

@agibsonccc , I haven’t noticed this issue directly during training. It’s only, as I mentioned before, that after restarting the JVM and loading the model, the model’s accuracy was noticeably lower (and its loss higher) during the first 20-50 iterations than before the JVM was shut down. I don’t see this behavior with sd.setEnableCache(false).

Regarding the reproducer itself: FLAT_TARGET_BATCH_POSITIONS and FLAT_TARGET_TOKEN_VOCAB_INDICES should have the same shape, because they represent the same labels.
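I.e. in the reproducer above, those two placeholders would need matching lengths, something like this (100 is just an example, the real length depends on the batch):

// both refer to the same flattened targets, so the lengths must match
placeHolders.put("FLAT_TARGET_BATCH_POSITIONS", Nd4j.ones(DataType.INT, 100));
placeHolders.put("FLAT_TARGET_TOKEN_VOCAB_INDICES", Nd4j.zeros(DataType.INT, 100));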

The reason I noticed at all that inference doesn’t work is that I periodically run inference during training (roughly every 20 or 30 iterations). During training I always call fit() with only 1 epoch inside the loop and do the saving/inference periodically. After some amount of training I decided to test the model with custom test data, i.e. manually and not during training. That’s where I saw that the inference results don’t make much sense. Only after feeding into the network the same input sequences that I use for regular inference during training did I see that the results are quite different. Honestly, I have no idea why, because running “cold” inference from scratch and running inference after at least one iteration of fit() shouldn’t be that different. The same logic is used in both cases; the only difference is calling fit() before inference.
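In code, that training loop looks roughly like this (a simplified sketch; the iterator and the periodic inference helper stand in for my actual pipeline):

import java.io.File;

import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

public class TrainingLoopSketch {

    // trainingIterator and runInferenceOnValidationInputs(...) are placeholders for my real code
    public static void train(File modelFile, MultiDataSetIterator trainingIterator, int totalIterations) {
        SameDiff sd = SameDiff.load(modelFile, true);
        for (int iteration = 1; iteration <= totalIterations; iteration++) {
            sd.fit(trainingIterator, 1);             // fit() with a single epoch per loop iteration
            if (iteration % 20 == 0) {               // roughly every 20-30 iterations
                sd.save(modelFile, true);            // save the model, incl. updater state
                runInferenceOnValidationInputs(sd);  // same inputs as the regular inference during training
            }
        }
    }

    private static void runInferenceOnValidationInputs(SameDiff sd) {
        // stand-in for my actual periodic inference / logging
    }
}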

I could, however, track some additional info using a listener if you need it in order to identify the root cause of this issue.

@partarstu thanks. Could you clarify this bit? So you’re saying that reusing arrays more than once causes this issue? In my reproducer I just did inference multiple times in a row with the same arrays to see if it would be different…

One thing that could be going on is that the arrays returned from the cache are modified elsewhere (since they’re views) and then reused. What’s odd is that the array cache itself checks for views and purposely avoids caching those.

We would need to identify a scenario where the cache fails to detect views, so that arrays which should not be reused end up being reused somewhere.

This can potentially happen with anything involving:

  1. create_view: this was used in the gather op to allow for faster sparse updates
  2. Training: Training has a child samediff instance where arrays are passed in from the parent. This can affect training.
  3. Strided slice: this can return slices of arrays that might be affected by the above.

I would need something strange to look into, like an array that should be a view but isn’t, or a case I’m missing that the cache shouldn’t be handling. I don’t quite know what that is from your description here yet. Quite a bit of testing went into the cache to make sure this behavior wasn’t there, though.
It’s likely a combination of things in your network that is causing this.

Could you give me an overview of the loop you typically run, e.g.:

  1. Train model
  2. Inference x times

Then I can try to reproduce something similar with your model. Just doing feed forward multiple times didn’t appear to be an issue. It’s probably a more complicated case to handle.

It seems like that, because the results after the first “cold” inference (with no fit() called on this SameDiff instance beforehand) differ from those of the subsequent runs. Also, while doing “cold” inference, the returned results do not match the ones returned after fit() has been run at least once on this SameDiff instance. It doesn’t matter whether it was the first “cold” inference round, the second or the third - they were always different from the ones provided by the same model after the first fit().

The issue with “cold” inference happens all the time: I can see that the output of the first inference differs from the output of the subsequent ones.

To check it with training involved: load the model, run fit() once with batchSize=1, do the inference with this SD instance and output the results (output #1). After that, load the same model again, but this time run the “cold” inference, i.e. call eval() immediately. Output the results after the first inference (output #2) and after the second one (output #3).
Expected: all 3 outputs are the same (allowing for the deviations introduced by back-prop after fit(), which should be minor with batchSize=1), and outputs #2 and #3 are identical.
Actual: all 3 outputs differ quite significantly (back-prop can’t be the reason for such huge deviations).
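As code, the check would look roughly like this (a sketch; the single-example batch and the placeholder map stand in for my real data preparation, and the output variable is the same as in your reproducer):

import java.io.File;
import java.util.Map;

import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;

public class ColdVsWarmInference {

    public static void compare(File modelFile, MultiDataSet singleExampleBatch,
                               Map<String, INDArray> placeHolders, String outputVar) {
        // Run 1: fit() once with batchSize=1, then inference
        SameDiff warm = SameDiff.load(modelFile, true);
        warm.fit(singleExampleBatch);
        INDArray output1 = warm.output(placeHolders, outputVar).get(outputVar);

        // Run 2: fresh instance, "cold" inference twice, no fit() at all
        SameDiff cold = SameDiff.load(modelFile, true);
        INDArray output2 = cold.output(placeHolders, outputVar).get(outputVar);
        INDArray output3 = cold.output(placeHolders, outputVar).get(outputVar);

        // Expected: #2 == #3, and #1 close to them (only a minor deviation from one batchSize=1 update)
        // Actual on M2.1 with caching enabled: all three differ significantly
        System.out.println("#2 equals #3: " + output2.equals(output3));
        System.out.println("#1 equals #2: " + output1.equals(output2));
    }
}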