Fine-tune an ONNX model

Can I retrain some layers of an ONNX model with SameDiff? Is there an example?

@Booker yes, we support that. You would need to add an updater and a loss function. Most ONNX models only contain the feedforward graph. You can load the model as is and add these afterwards. E.g., the loss and updater could be:

SDVariable loss = sd.loss.logLoss("loss", label, out);

//Also set the training configuration:
sd.setTrainingConfig(TrainingConfig.builder()
        .updater(new Adam(0.01))
        .weightDecay(1e-3, true)
        .dataSetFeatureMapping("in")            //features[0] -> "in" placeholder
        .dataSetLabelMapping("label")           //labels[0]   -> "label" placeholder
        .build());

Here, sd is the SameDiff instance you imported using the OnnxFrameworkImporter.
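For readers unsure what the log loss above measures, here is a minimal plain-Java sketch of the standard binary log loss. It is an illustration only and not the ND4J implementation: the class name, the clamping approach, and the eps value are assumptions, and ND4J's own epsilon and reduction defaults may differ.

```java
// Plain-Java illustration of binary log loss; NOT the ND4J implementation.
final class LogLossSketch {

    // Mean binary log loss over all examples.
    // eps keeps log() away from zero; clamping and the eps value are assumptions.
    static double logLoss(double[] labels, double[] predictions, double eps) {
        if (labels.length != predictions.length || labels.length == 0) {
            throw new IllegalArgumentException(
                    "labels and predictions must be non-empty and of equal length");
        }
        double sum = 0.0;
        for (int i = 0; i < labels.length; i++) {
            double y = labels[i];
            // Clamp the prediction into [eps, 1 - eps] before taking logs.
            double p = Math.min(Math.max(predictions[i], eps), 1.0 - eps);
            sum += -(y * Math.log(p) + (1.0 - y) * Math.log(1.0 - p));
        }
        return sum / labels.length;
    }

    public static void main(String[] args) {
        double[] labels = {1.0, 0.0, 1.0};
        double[] good = {0.9, 0.1, 0.8};   // confident and mostly right
        double[] bad = {0.2, 0.7, 0.4};    // mostly wrong
        System.out.println("good predictions: " + logLoss(labels, good, 1e-7));
        System.out.println("bad predictions:  " + logLoss(labels, bad, 1e-7));
    }
}
```

Predictions closer to the labels give a smaller value, which is what the updater then tries to minimize during training.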

Thanks. Should I retrain the entire model or just one layer to get better results with my small dataset? How do I freeze the entire model and set a specific layer to train?

@Booker there aren’t really “layers” in SameDiff, just ops and variables. For specific variables, you can do:

SameDiff sd = …;
sd = sd.freeze(true);

That will freeze all variables.

To make only certain variables trainable while the rest stay frozen, just convert those back to variables. You can see the code for that here:

By default, everything should be a constant, since ONNX imports as feedforward only.
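As a rough mental model of what freezing means here, the sketch below uses plain Java and deliberately avoids the SameDiff API. It treats frozen entries as constants that a training step leaves untouched and updates only the names marked trainable. The class, method, and parameter names are all hypothetical.

```java
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;

// Toy illustration only: plain Java, NOT the SameDiff API.
final class FreezeSketch {

    // One plain gradient-descent step that updates only the names in `trainable`.
    // Every other entry is treated as a constant (frozen) and copied unchanged.
    static Map<String, Double> sgdStep(Map<String, Double> params,
                                       Map<String, Double> grads,
                                       Set<String> trainable,
                                       double learningRate) {
        Map<String, Double> updated = new LinkedHashMap<>();
        for (Map.Entry<String, Double> e : params.entrySet()) {
            String name = e.getKey();
            double value = e.getValue();
            if (trainable.contains(name)) {
                value -= learningRate * grads.getOrDefault(name, 0.0);
            }
            updated.put(name, value);
        }
        return updated;
    }

    public static void main(String[] args) {
        // Hypothetical parameter names standing in for imported ONNX weights.
        Map<String, Double> params = new LinkedHashMap<>();
        params.put("backbone.weight", 1.0);
        params.put("head.weight", 2.0);

        Map<String, Double> grads = new LinkedHashMap<>();
        grads.put("backbone.weight", 0.5);
        grads.put("head.weight", 0.5);

        // Only the head is trainable; the backbone stays frozen.
        Set<String> trainable = Collections.singleton("head.weight");
        System.out.println(sgdStep(params, grads, trainable, 0.1));
    }
}
```

With a small dataset, keeping most of the imported weights frozen and training only a few variables near the output is a common starting point, though whether it beats retraining everything depends on the data.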

Powerful! The next release has been pending for a year. When will it be released?

I had paid work that held that up for quite a while. Now that that’s done, I’m wrapping up the CUDA testing. I’ve been cleaning up technical debt along the way. Don’t worry, that’s the main item I’m working on atm.

Unable to resolve attribute for name auto_pad for node Conv for op type Conv
Unable to resolve attribute for name dilations for node MaxPool for op type MaxPool
Skipping input B on node /model.22/dfl/conv/Conv

These console logs appear when I import yolov8x.onnx. Is this normal, and can they be ignored? Thanks.

@Booker can you file an issue and link to the model so I can reproduce this and verify that it works in the upcoming release? Thanks!

done.