Model Persistence with Custom Lambda Layers

What is the proper way of saving a ComputationGraph with lambda (custom) layers? I’m currently using the save method from the ComputationGraph class. When loading the saved model using ComputationGraph.load() I’m getting the following error:

Exception in thread "main" java.lang.RuntimeException: Error deserializing JSON ComputationGraphConfiguration. Saved model JSON is not a valid ComputationGraphConfiguration
at org.deeplearning4j.util.ModelSerializer.restoreComputationGraphHelper(ModelSerializer.java:560)
at org.deeplearning4j.util.ModelSerializer.restoreComputationGraph(ModelSerializer.java:462)
at org.deeplearning4j.util.ModelSerializer.restoreComputationGraph(ModelSerializer.java:647)
at org.deeplearning4j.nn.graph.ComputationGraph.load(ComputationGraph.java:4618)

Is there anything special that has to be done when saving a ComputationGraph with lambda layers?

@joel-a thanks for posting! You should just need to save the model and load it, making sure the exact custom layer class is included in the application you’re loading it from. Try to follow the samples here and let me know if you have any problems:

@agibsonccc Thanks for your prompt response. I went through the example you pointed me to and I don’t see what I should be doing differently.

Instead of custom layers, I’m loading a Keras model with lambda layers and registering each one with KerasLayer.registerLambdaLayer(name, layer), where layer is an instance of a SameDiffLambdaLayer subclass. The model loads correctly and I’m able to run inference with it. When it comes time to save, I call save(file, false) on the model and I get no errors.

Now, if I try to load the DL4J saved model (.zip), not my original .h5, with ComputationGraph.load(name, false) I get the error I posted above. Why is the save() method not saving the JSON as a valid ComputationGraphConfiguration?

Thanks for your continued help. We are making great progress and hope to deploy with DL4J.

Are you loading it in the same application?

@treo I’m not completely sure I understand your question. I’m doing it all in the IDE (NetBeans). I first convert the model from .h5 to DL4J’s zip format by running the main method of a Converter class. Then I run a different main method, defined in a different class, which loads the zip file. It all happens in the same development environment and IDE project. Thank you!

I’m asking because you have plenty of custom layers in the config, e.g.:

edu.mit.ll.seamnet.layers.STDAxis1

When those layers are not present on the class path when you are loading the model, it will complain that it cannot deserialize the configuration.

So, are those layer classes available when you are loading the model?
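One way to answer that question is to ask the JVM whether the class name stored in the saved configuration can be resolved from the application that loads the model. The sketch below is only an illustration: `ClassPathCheck` and `isOnClassPath` are hypothetical names, not part of DL4J, and it assumes the layers would be loaded by the same class loader as the calling code.

```java
public class ClassPathCheck {

    // Returns true if the named class can be located by this class's loader.
    // The class is not initialized, so no static initializers are run.
    public static boolean isOnClassPath(String className) {
        try {
            Class.forName(className, false, ClassPathCheck.class.getClassLoader());
            return true;
        } catch (ClassNotFoundException | LinkageError e) {
            return false;
        }
    }

    public static void main(String[] args) {
        // The first name is taken from the saved configuration quoted above;
        // the second is a JDK class that is always present.
        String[] names = {
            "edu.mit.ll.seamnet.layers.STDAxis1",
            "java.util.ArrayList"
        };
        for (String name : names) {
            System.out.println(name + " -> " + isOnClassPath(name));
        }
    }
}
```

If the check returns false for a name that appears in the saved configuration, the JAR or module containing that layer is missing from the loading application. A true result only rules out the class path; it does not guarantee that the class can be deserialized.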

@treo @agibsonccc
I get it now. Yes, the custom layer classes should be present on the class path when loading the model. Here’s a full example (inference is not shown, but I can run inference successfully on the model imported from the .h5 file):

        String h5Path = Paths.get("testModel.h5").toAbsolutePath().toString();
        System.out.println("Loading h5 model from: " + h5Path);
        
        //=========== CONVERT MODEL FROM H5 TO DL4J FORMAT ===================//
        KerasLayer.registerLambdaLayer("log_lambda_layer", new LogLambda());
        KerasLayer.registerLambdaLayer("mean_lambda_layer", new MeanAxis1());
        KerasLayer.registerLambdaLayer("div_lambda_layer", new STDAxis1());
        KerasLayer.registerLambdaLayer("downsample0_lambda_layer", new DownSample0());
        KerasLayer.registerLambdaLayer("downsample1_lambda_layer", new DownSample1());
        KerasLayer.registerLambdaLayer("shift_forward_lambda_layer", new ShiftForward(10));
        KerasLayer.registerLambdaLayer("shift_backward_lambda_layer", new ShiftBackward(10));
        KerasLayer.registerLambdaLayer("slice_lambda_layer", new SliceLayer());
        KerasLayer.registerLambdaLayer("permute_dim_lambda_layer", new PermuteTensorDims());
        
        System.out.println("Loading model from h5 ...");
        ComputationGraph decoder = KerasModelImport.importKerasModelAndWeights(h5Path);
       
        //Running inference successfully here with model in 'decoder'

        //Saving to DL4J format
        File modelLocation = new File(Paths.get("model.zip").toAbsolutePath().toString());
        System.out.println("Saving file to : " + modelLocation.toString());
        decoder.save(modelLocation, false);
        
        
        //================ LOADING SAVED MODEL ========================//
        System.out.println("Loading SAVED model ...");
        var restored = ComputationGraph.load(modelLocation,false);

This code results in the following error:

Exception in thread "main" java.lang.RuntimeException: Error deserializing JSON ComputationGraphConfiguration. Saved model JSON is not a valid ComputationGraphConfiguration
	at org.deeplearning4j.util.ModelSerializer.restoreComputationGraphHelper(ModelSerializer.java:560)
	at org.deeplearning4j.util.ModelSerializer.restoreComputationGraph(ModelSerializer.java:462)
	at org.deeplearning4j.util.ModelSerializer.restoreComputationGraph(ModelSerializer.java:647)
	at org.deeplearning4j.nn.graph.ComputationGraph.load(ComputationGraph.java:4618)
	at edu.mit.ll.seamnet.utils.KerasToDL4JConverter.main(KerasToDL4JConverter.java:120)
Caused by: java.lang.RuntimeException: org.nd4j.shade.jackson.databind.exc.MismatchedInputException: Cannot construct instance of `edu.mit.ll.seamnet.layers.ShiftForward` (although at least one Creator exists): cannot deserialize from Object value (no delegate- or property-based Creator)
 at [Source: (String)"{
  "backpropType" : "Standard",
  "cacheMode" : "NONE",
  "dataType" : "FLOAT",
  "defaultConfiguration" : {
    "cacheMode" : "NONE",
    "dataType" : "FLOAT",
    "epochCount" : 0,
    "iterationCount" : 0,
    "layer" : null,
    "maxNumLineSearchIterations" : 5,
    "miniBatch" : true,
    "minimize" : true,
    "optimizationAlgo" : "STOCHASTIC_GRADIENT_DESCENT",
    "seed" : 1625158628700,
    "stepFunction" : null,
    "variables" : [ "conv1d_14_W" ]
  },
  "epochCount" : 0,
  "inferenceW"[truncated 11346 chars]; line: 218, column: 11] (through reference chain: org.deeplearning4j.nn.conf.ComputationGraphConfiguration["vertices"]->java.util.LinkedHashMap["shift_forward_lambda_layer"]->org.deeplearning4j.nn.conf.graph.LayerVertex["layerConf"]->org.deeplearning4j.nn.conf.NeuralNetConfiguration["layer"])
	at org.deeplearning4j.nn.conf.ComputationGraphConfiguration.fromJson(ComputationGraphConfiguration.java:196)
	at org.deeplearning4j.util.ModelSerializer.restoreComputationGraphHelper(ModelSerializer.java:547)
	... 4 more
Caused by: org.nd4j.shade.jackson.databind.exc.MismatchedInputException: Cannot construct instance of `edu.mit.ll.seamnet.layers.ShiftForward` (although at least one Creator exists): cannot deserialize from Object value (no delegate- or property-based Creator)
 at [Source: (String)"{
  "backpropType" : "Standard",
  "cacheMode" : "NONE",
  "dataType" : "FLOAT",
  "defaultConfiguration" : {
    "cacheMode" : "NONE",
    "dataType" : "FLOAT",
    "epochCount" : 0,
    "iterationCount" : 0,
    "layer" : null,
    "maxNumLineSearchIterations" : 5,
    "miniBatch" : true,
    "minimize" : true,
    "optimizationAlgo" : "STOCHASTIC_GRADIENT_DESCENT",
    "seed" : 1625158628700,
    "stepFunction" : null,
    "variables" : [ "conv1d_14_W" ]
  },
  "epochCount" : 0,
  "inferenceW"[truncated 11346 chars]; line: 218, column: 11] (through reference chain: org.deeplearning4j.nn.conf.ComputationGraphConfiguration["vertices"]->java.util.LinkedHashMap["shift_forward_lambda_layer"]->org.deeplearning4j.nn.conf.graph.LayerVertex["layerConf"]->org.deeplearning4j.nn.conf.NeuralNetConfiguration["layer"])
	at org.nd4j.shade.jackson.databind.exc.MismatchedInputException.from(MismatchedInputException.java:63)
	at org.nd4j.shade.jackson.databind.DeserializationContext.reportInputMismatch(DeserializationContext.java:1588)
	at org.nd4j.shade.jackson.databind.DeserializationContext.handleMissingInstantiator(DeserializationContext.java:1213)
	at org.nd4j.shade.jackson.databind.deser.BeanDeserializerBase.deserializeFromObjectUsingNonDefault(BeanDeserializerBase.java:1415)
	at org.nd4j.shade.jackson.databind.deser.BeanDeserializer.deserializeFromObject(BeanDeserializer.java:362)
	at org.nd4j.shade.jackson.databind.deser.BeanDeserializer._deserializeOther(BeanDeserializer.java:230)
	at org.nd4j.shade.jackson.databind.deser.BeanDeserializer.deserialize(BeanDeserializer.java:197)
	at org.nd4j.shade.jackson.databind.jsontype.impl.AsPropertyTypeDeserializer._deserializeTypedForId(AsPropertyTypeDeserializer.java:137)
	at org.nd4j.shade.jackson.databind.jsontype.impl.AsPropertyTypeDeserializer.deserializeTypedFromObject(AsPropertyTypeDeserializer.java:107)
	at org.nd4j.shade.jackson.databind.deser.AbstractDeserializer.deserializeWithType(AbstractDeserializer.java:263)
	at org.nd4j.shade.jackson.databind.deser.impl.MethodProperty.deserializeAndSet(MethodProperty.java:138)
	at org.nd4j.shade.jackson.databind.deser.BeanDeserializer.vanillaDeserialize(BeanDeserializer.java:324)
	at org.nd4j.shade.jackson.databind.deser.BeanDeserializer.deserialize(BeanDeserializer.java:187)
	at org.nd4j.shade.jackson.databind.deser.impl.MethodProperty.deserializeAndSet(MethodProperty.java:129)
	at org.nd4j.shade.jackson.databind.deser.BeanDeserializer.vanillaDeserialize(BeanDeserializer.java:324)
	at org.nd4j.shade.jackson.databind.deser.BeanDeserializer._deserializeOther(BeanDeserializer.java:225)
	at org.nd4j.shade.jackson.databind.deser.BeanDeserializer.deserialize(BeanDeserializer.java:197)
	at org.nd4j.shade.jackson.databind.jsontype.impl.AsPropertyTypeDeserializer._deserializeTypedForId(AsPropertyTypeDeserializer.java:137)
	at org.nd4j.shade.jackson.databind.jsontype.impl.AsPropertyTypeDeserializer.deserializeTypedFromObject(AsPropertyTypeDeserializer.java:107)
	at org.nd4j.shade.jackson.databind.deser.AbstractDeserializer.deserializeWithType(AbstractDeserializer.java:263)
	at org.nd4j.shade.jackson.databind.deser.std.MapDeserializer._readAndBindStringKeyMap(MapDeserializer.java:611)
	at org.nd4j.shade.jackson.databind.deser.std.MapDeserializer.deserialize(MapDeserializer.java:437)
	at org.nd4j.shade.jackson.databind.deser.std.MapDeserializer.deserialize(MapDeserializer.java:32)
	at org.nd4j.shade.jackson.databind.deser.impl.MethodProperty.deserializeAndSet(MethodProperty.java:129)
	at org.nd4j.shade.jackson.databind.deser.BeanDeserializer.vanillaDeserialize(BeanDeserializer.java:324)
	at org.nd4j.shade.jackson.databind.deser.BeanDeserializer.deserialize(BeanDeserializer.java:187)
	at org.deeplearning4j.nn.conf.serde.ComputationGraphConfigurationDeserializer.deserialize(ComputationGraphConfigurationDeserializer.java:61)
	at org.deeplearning4j.nn.conf.serde.ComputationGraphConfigurationDeserializer.deserialize(ComputationGraphConfigurationDeserializer.java:51)
	at org.nd4j.shade.jackson.databind.deser.DefaultDeserializationContext.readRootValue(DefaultDeserializationContext.java:322)
	at org.nd4j.shade.jackson.databind.ObjectMapper._readMapAndClose(ObjectMapper.java:4593)
	at org.nd4j.shade.jackson.databind.ObjectMapper.readValue(ObjectMapper.java:3548)
	at org.nd4j.shade.jackson.databind.ObjectMapper.readValue(ObjectMapper.java:3516)
	at org.deeplearning4j.nn.conf.ComputationGraphConfiguration.fromJson(ComputationGraphConfiguration.java:167)
	... 5 more
Command execution failed.
org.apache.commons.exec.ExecuteException: Process exited with an error: 1 (Exit value: 1)
    at org.apache.commons.exec.DefaultExecutor.executeInternal (DefaultExecutor.java:404)
    at org.apache.commons.exec.DefaultExecutor.execute (DefaultExecutor.java:166)
    at org.codehaus.mojo.exec.ExecMojo.executeCommandLine (ExecMojo.java:982)
    at org.codehaus.mojo.exec.ExecMojo.executeCommandLine (ExecMojo.java:929)
    at org.codehaus.mojo.exec.ExecMojo.execute (ExecMojo.java:457)
    at org.apache.maven.plugin.DefaultBuildPluginManager.executeMojo (DefaultBuildPluginManager.java:137)
    at org.apache.maven.lifecycle.internal.MojoExecutor.execute (MojoExecutor.java:210)
    at org.apache.maven.lifecycle.internal.MojoExecutor.execute (MojoExecutor.java:156)
    at org.apache.maven.lifecycle.internal.MojoExecutor.execute (MojoExecutor.java:148)
    at org.apache.maven.lifecycle.internal.LifecycleModuleBuilder.buildProject (LifecycleModuleBuilder.java:117)
    at org.apache.maven.lifecycle.internal.LifecycleModuleBuilder.buildProject (LifecycleModuleBuilder.java:81)
    at org.apache.maven.lifecycle.internal.builder.singlethreaded.SingleThreadedBuilder.build (SingleThreadedBuilder.java:56)
    at org.apache.maven.lifecycle.internal.LifecycleStarter.execute (LifecycleStarter.java:128)
    at org.apache.maven.DefaultMaven.doExecute (DefaultMaven.java:305)
    at org.apache.maven.DefaultMaven.doExecute (DefaultMaven.java:192)
    at org.apache.maven.DefaultMaven.execute (DefaultMaven.java:105)
    at org.apache.maven.cli.MavenCli.execute (MavenCli.java:957)
    at org.apache.maven.cli.MavenCli.doMain (MavenCli.java:289)
    at org.apache.maven.cli.MavenCli.main (MavenCli.java:193)
    at jdk.internal.reflect.NativeMethodAccessorImpl.invoke0 (Native Method)
    at jdk.internal.reflect.NativeMethodAccessorImpl.invoke (NativeMethodAccessorImpl.java:64)
    at jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke (DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke (Method.java:564)
    at org.codehaus.plexus.classworlds.launcher.Launcher.launchEnhanced (Launcher.java:282)
    at org.codehaus.plexus.classworlds.launcher.Launcher.launch (Launcher.java:225)
    at org.codehaus.plexus.classworlds.launcher.Launcher.mainWithExitCode (Launcher.java:406)
    at org.codehaus.plexus.classworlds.launcher.Launcher.main (Launcher.java:347)

So it is still failing to reload because of the custom layers.

How did you define that layer in particular?

@treo The issue was that some of the custom classes (such as ShiftForward, which only had a constructor taking an argument) needed default, no-argument constructors. I implemented default constructors for all custom classes and it is now working. Thanks a lot for your help!
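For readers hitting the same MismatchedInputException: the message "no delegate- or property-based Creator" generally means the deserializer found constructors on the class but none it could call on its own, which is typically the case when a class declares only a parameterized constructor. Here is a minimal sketch in plain Java, using reflection as a rough stand-in for what a bean deserializer does. It is not DL4J or Jackson code, and `DefaultConstructorDemo`, `ShiftWithoutDefault`, and `ShiftWithDefault` are hypothetical names.

```java
import java.lang.reflect.Constructor;

public class DefaultConstructorDemo {

    // Hypothetical layer-like class with only a parameterized constructor.
    public static class ShiftWithoutDefault {
        private int shift;
        public ShiftWithoutDefault(int shift) { this.shift = shift; }
        public int getShift() { return shift; }
        public void setShift(int shift) { this.shift = shift; }
    }

    // Same class with an added no-argument constructor. A deserializer can
    // create an empty instance first and then fill in 'shift' via the setter.
    public static class ShiftWithDefault {
        private int shift;
        public ShiftWithDefault() { }
        public ShiftWithDefault(int shift) { this.shift = shift; }
        public int getShift() { return shift; }
        public void setShift(int shift) { this.shift = shift; }
    }

    // Tries to build an instance the way a simple bean deserializer would:
    // look up the no-argument constructor and invoke it.
    public static boolean canInstantiateWithoutArgs(Class<?> type) {
        try {
            Constructor<?> ctor = type.getDeclaredConstructor();
            ctor.setAccessible(true);
            ctor.newInstance();
            return true;
        } catch (ReflectiveOperationException e) {
            return false;
        }
    }

    public static void main(String[] args) {
        System.out.println(canInstantiateWithoutArgs(ShiftWithoutDefault.class)); // false
        System.out.println(canInstantiateWithoutArgs(ShiftWithDefault.class));    // true
    }
}
```

The no-argument constructor can coexist with the parameterized one, so code such as new ShiftForward(10) in the conversion step would not need to change. Any state passed through the constructor still has to be exposed as a property (for example through a getter and setter) for it to survive a save and load cycle.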