@treo @agibsonccc
I get it now. Yes, the custom layer classes should be present on the classpath when loading the model. Here's a full example (inference isn't shown, but I can run inference successfully on the model imported from the .h5 file):
```java
import java.io.File;
import java.nio.file.Paths;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
// LogLambda, MeanAxis1, STDAxis1, DownSample0, DownSample1, ShiftForward, ShiftBackward,
// SliceLayer and PermuteTensorDims are my custom lambda layer classes (e.g. edu.mit.ll.seamnet.layers.ShiftForward).

String h5Path = Paths.get("testModel.h5").toAbsolutePath().toString();
System.out.println("Loading h5 model from: " + h5Path);
//=========== CONVERT MODEL FROM H5 TO DL4J FORMAT ===================//
// Register each Keras Lambda layer name with its Java implementation before import
KerasLayer.registerLambdaLayer("log_lambda_layer", new LogLambda());
KerasLayer.registerLambdaLayer("mean_lambda_layer", new MeanAxis1());
KerasLayer.registerLambdaLayer("div_lambda_layer", new STDAxis1());
KerasLayer.registerLambdaLayer("downsample0_lambda_layer", new DownSample0());
KerasLayer.registerLambdaLayer("downsample1_lambda_layer", new DownSample1());
KerasLayer.registerLambdaLayer("shift_forward_lambda_layer", new ShiftForward(10));
KerasLayer.registerLambdaLayer("shift_backward_lambda_layer", new ShiftBackward(10));
KerasLayer.registerLambdaLayer("slice_lambda_layer", new SliceLayer());
KerasLayer.registerLambdaLayer("permute_dim_lambda_layer", new PermuteTensorDims());
System.out.println("Loading model from h5 ...");
ComputationGraph decoder = KerasModelImport.importKerasModelAndWeights(h5Path);
// Inference runs successfully here with the model in 'decoder'
// Save in DL4J format (without updater state)
File modelLocation = new File(Paths.get("model.zip").toAbsolutePath().toString());
System.out.println("Saving file to : " + modelLocation.toString());
decoder.save(modelLocation, false);
//================ LOADING SAVED MODEL ========================//
System.out.println("Loading SAVED model ...");
var restored = ComputationGraph.load(modelLocation, false);
```
This code results in the following error:
```
Exception in thread "main" java.lang.RuntimeException: Error deserializing JSON ComputationGraphConfiguration. Saved model JSON is not a valid ComputationGraphConfiguration
at org.deeplearning4j.util.ModelSerializer.restoreComputationGraphHelper(ModelSerializer.java:560)
at org.deeplearning4j.util.ModelSerializer.restoreComputationGraph(ModelSerializer.java:462)
at org.deeplearning4j.util.ModelSerializer.restoreComputationGraph(ModelSerializer.java:647)
at org.deeplearning4j.nn.graph.ComputationGraph.load(ComputationGraph.java:4618)
at edu.mit.ll.seamnet.utils.KerasToDL4JConverter.main(KerasToDL4JConverter.java:120)
Caused by: java.lang.RuntimeException: org.nd4j.shade.jackson.databind.exc.MismatchedInputException: Cannot construct instance of `edu.mit.ll.seamnet.layers.ShiftForward` (although at least one Creator exists): cannot deserialize from Object value (no delegate- or property-based Creator)
at [Source: (String)"{
"backpropType" : "Standard",
"cacheMode" : "NONE",
"dataType" : "FLOAT",
"defaultConfiguration" : {
"cacheMode" : "NONE",
"dataType" : "FLOAT",
"epochCount" : 0,
"iterationCount" : 0,
"layer" : null,
"maxNumLineSearchIterations" : 5,
"miniBatch" : true,
"minimize" : true,
"optimizationAlgo" : "STOCHASTIC_GRADIENT_DESCENT",
"seed" : 1625158628700,
"stepFunction" : null,
"variables" : [ "conv1d_14_W" ]
},
"epochCount" : 0,
"inferenceW"[truncated 11346 chars]; line: 218, column: 11] (through reference chain: org.deeplearning4j.nn.conf.ComputationGraphConfiguration["vertices"]->java.util.LinkedHashMap["shift_forward_lambda_layer"]->org.deeplearning4j.nn.conf.graph.LayerVertex["layerConf"]->org.deeplearning4j.nn.conf.NeuralNetConfiguration["layer"])
at org.deeplearning4j.nn.conf.ComputationGraphConfiguration.fromJson(ComputationGraphConfiguration.java:196)
at org.deeplearning4j.util.ModelSerializer.restoreComputationGraphHelper(ModelSerializer.java:547)
... 4 more
Caused by: org.nd4j.shade.jackson.databind.exc.MismatchedInputException: Cannot construct instance of `edu.mit.ll.seamnet.layers.ShiftForward` (although at least one Creator exists): cannot deserialize from Object value (no delegate- or property-based Creator)
at [Source: (String)"{
"backpropType" : "Standard",
"cacheMode" : "NONE",
"dataType" : "FLOAT",
"defaultConfiguration" : {
"cacheMode" : "NONE",
"dataType" : "FLOAT",
"epochCount" : 0,
"iterationCount" : 0,
"layer" : null,
"maxNumLineSearchIterations" : 5,
"miniBatch" : true,
"minimize" : true,
"optimizationAlgo" : "STOCHASTIC_GRADIENT_DESCENT",
"seed" : 1625158628700,
"stepFunction" : null,
"variables" : [ "conv1d_14_W" ]
},
"epochCount" : 0,
"inferenceW"[truncated 11346 chars]; line: 218, column: 11] (through reference chain: org.deeplearning4j.nn.conf.ComputationGraphConfiguration["vertices"]->java.util.LinkedHashMap["shift_forward_lambda_layer"]->org.deeplearning4j.nn.conf.graph.LayerVertex["layerConf"]->org.deeplearning4j.nn.conf.NeuralNetConfiguration["layer"])
at org.nd4j.shade.jackson.databind.exc.MismatchedInputException.from(MismatchedInputException.java:63)
at org.nd4j.shade.jackson.databind.DeserializationContext.reportInputMismatch(DeserializationContext.java:1588)
at org.nd4j.shade.jackson.databind.DeserializationContext.handleMissingInstantiator(DeserializationContext.java:1213)
at org.nd4j.shade.jackson.databind.deser.BeanDeserializerBase.deserializeFromObjectUsingNonDefault(BeanDeserializerBase.java:1415)
at org.nd4j.shade.jackson.databind.deser.BeanDeserializer.deserializeFromObject(BeanDeserializer.java:362)
at org.nd4j.shade.jackson.databind.deser.BeanDeserializer._deserializeOther(BeanDeserializer.java:230)
at org.nd4j.shade.jackson.databind.deser.BeanDeserializer.deserialize(BeanDeserializer.java:197)
at org.nd4j.shade.jackson.databind.jsontype.impl.AsPropertyTypeDeserializer._deserializeTypedForId(AsPropertyTypeDeserializer.java:137)
at org.nd4j.shade.jackson.databind.jsontype.impl.AsPropertyTypeDeserializer.deserializeTypedFromObject(AsPropertyTypeDeserializer.java:107)
at org.nd4j.shade.jackson.databind.deser.AbstractDeserializer.deserializeWithType(AbstractDeserializer.java:263)
at org.nd4j.shade.jackson.databind.deser.impl.MethodProperty.deserializeAndSet(MethodProperty.java:138)
at org.nd4j.shade.jackson.databind.deser.BeanDeserializer.vanillaDeserialize(BeanDeserializer.java:324)
at org.nd4j.shade.jackson.databind.deser.BeanDeserializer.deserialize(BeanDeserializer.java:187)
at org.nd4j.shade.jackson.databind.deser.impl.MethodProperty.deserializeAndSet(MethodProperty.java:129)
at org.nd4j.shade.jackson.databind.deser.BeanDeserializer.vanillaDeserialize(BeanDeserializer.java:324)
at org.nd4j.shade.jackson.databind.deser.BeanDeserializer._deserializeOther(BeanDeserializer.java:225)
at org.nd4j.shade.jackson.databind.deser.BeanDeserializer.deserialize(BeanDeserializer.java:197)
at org.nd4j.shade.jackson.databind.jsontype.impl.AsPropertyTypeDeserializer._deserializeTypedForId(AsPropertyTypeDeserializer.java:137)
at org.nd4j.shade.jackson.databind.jsontype.impl.AsPropertyTypeDeserializer.deserializeTypedFromObject(AsPropertyTypeDeserializer.java:107)
at org.nd4j.shade.jackson.databind.deser.AbstractDeserializer.deserializeWithType(AbstractDeserializer.java:263)
at org.nd4j.shade.jackson.databind.deser.std.MapDeserializer._readAndBindStringKeyMap(MapDeserializer.java:611)
at org.nd4j.shade.jackson.databind.deser.std.MapDeserializer.deserialize(MapDeserializer.java:437)
at org.nd4j.shade.jackson.databind.deser.std.MapDeserializer.deserialize(MapDeserializer.java:32)
at org.nd4j.shade.jackson.databind.deser.impl.MethodProperty.deserializeAndSet(MethodProperty.java:129)
at org.nd4j.shade.jackson.databind.deser.BeanDeserializer.vanillaDeserialize(BeanDeserializer.java:324)
at org.nd4j.shade.jackson.databind.deser.BeanDeserializer.deserialize(BeanDeserializer.java:187)
at org.deeplearning4j.nn.conf.serde.ComputationGraphConfigurationDeserializer.deserialize(ComputationGraphConfigurationDeserializer.java:61)
at org.deeplearning4j.nn.conf.serde.ComputationGraphConfigurationDeserializer.deserialize(ComputationGraphConfigurationDeserializer.java:51)
at org.nd4j.shade.jackson.databind.deser.DefaultDeserializationContext.readRootValue(DefaultDeserializationContext.java:322)
at org.nd4j.shade.jackson.databind.ObjectMapper._readMapAndClose(ObjectMapper.java:4593)
at org.nd4j.shade.jackson.databind.ObjectMapper.readValue(ObjectMapper.java:3548)
at org.nd4j.shade.jackson.databind.ObjectMapper.readValue(ObjectMapper.java:3516)
at org.deeplearning4j.nn.conf.ComputationGraphConfiguration.fromJson(ComputationGraphConfiguration.java:167)
... 5 more
Command execution failed.
org.apache.commons.exec.ExecuteException: Process exited with an error: 1 (Exit value: 1)
at org.apache.commons.exec.DefaultExecutor.executeInternal (DefaultExecutor.java:404)
at org.apache.commons.exec.DefaultExecutor.execute (DefaultExecutor.java:166)
at org.codehaus.mojo.exec.ExecMojo.executeCommandLine (ExecMojo.java:982)
at org.codehaus.mojo.exec.ExecMojo.executeCommandLine (ExecMojo.java:929)
at org.codehaus.mojo.exec.ExecMojo.execute (ExecMojo.java:457)
at org.apache.maven.plugin.DefaultBuildPluginManager.executeMojo (DefaultBuildPluginManager.java:137)
at org.apache.maven.lifecycle.internal.MojoExecutor.execute (MojoExecutor.java:210)
at org.apache.maven.lifecycle.internal.MojoExecutor.execute (MojoExecutor.java:156)
at org.apache.maven.lifecycle.internal.MojoExecutor.execute (MojoExecutor.java:148)
at org.apache.maven.lifecycle.internal.LifecycleModuleBuilder.buildProject (LifecycleModuleBuilder.java:117)
at org.apache.maven.lifecycle.internal.LifecycleModuleBuilder.buildProject (LifecycleModuleBuilder.java:81)
at org.apache.maven.lifecycle.internal.builder.singlethreaded.SingleThreadedBuilder.build (SingleThreadedBuilder.java:56)
at org.apache.maven.lifecycle.internal.LifecycleStarter.execute (LifecycleStarter.java:128)
at org.apache.maven.DefaultMaven.doExecute (DefaultMaven.java:305)
at org.apache.maven.DefaultMaven.doExecute (DefaultMaven.java:192)
at org.apache.maven.DefaultMaven.execute (DefaultMaven.java:105)
at org.apache.maven.cli.MavenCli.execute (MavenCli.java:957)
at org.apache.maven.cli.MavenCli.doMain (MavenCli.java:289)
at org.apache.maven.cli.MavenCli.main (MavenCli.java:193)
at jdk.internal.reflect.NativeMethodAccessorImpl.invoke0 (Native Method)
at jdk.internal.reflect.NativeMethodAccessorImpl.invoke (NativeMethodAccessorImpl.java:64)
at jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke (DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke (Method.java:564)
at org.codehaus.plexus.classworlds.launcher.Launcher.launchEnhanced (Launcher.java:282)
at org.codehaus.plexus.classworlds.launcher.Launcher.launch (Launcher.java:225)
at org.codehaus.plexus.classworlds.launcher.Launcher.mainWithExitCode (Launcher.java:406)
at org.codehaus.plexus.classworlds.launcher.Launcher.main (Launcher.java:347)
```
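My reading of the trace (an assumption on my part, not something DL4J states explicitly): the class *is* on the classpath, since Jackson names `edu.mit.ll.seamnet.layers.ShiftForward` in the error, but Jackson cannot instantiate it. `ShiftForward` is only ever created with an argument (`new ShiftForward(10)`), and Jackson's default bean deserialization needs a no-arg constructor (or a `@JsonCreator`-annotated one) before it can fill in the saved fields, which is what "no delegate- or property-based Creator" points at. `ShiftBackward(10)` would probably fail the same way. The sketch below mimics that mechanism in plain Java using reflection only (no Jackson or DL4J); `CreatorDemo`, `ArgOnlyLayer`, `BeanStyleLayer` and `rebuild` are hypothetical names, not my real layer classes.

```java
public class CreatorDemo {
    // Hypothetical stand-in for a layer that only has a parameterized constructor,
    // like ShiftForward(10): the compiler generates no default constructor for it.
    static class ArgOnlyLayer {
        private final int shift;
        ArgOnlyLayer(int shift) { this.shift = shift; }
    }

    // Hypothetical bean-style variant: a no-arg constructor plus getter/setter,
    // the shape a bean deserializer can create first and populate afterwards.
    static class BeanStyleLayer {
        private int shift;
        public BeanStyleLayer() { }
        public BeanStyleLayer(int shift) { this.shift = shift; }
        public int getShift() { return shift; }
        public void setShift(int shift) { this.shift = shift; }
    }

    // True if the class declares a constructor that takes no arguments.
    static boolean hasNoArgConstructor(Class<?> cls) {
        try {
            cls.getDeclaredConstructor();
            return true;
        } catch (NoSuchMethodException e) {
            return false;
        }
    }

    // Mimics bean deserialization: instantiate via the no-arg constructor,
    // then restore the saved field value through its setter.
    static BeanStyleLayer rebuild(int savedShift) {
        try {
            BeanStyleLayer layer = BeanStyleLayer.class.getDeclaredConstructor().newInstance();
            layer.setShift(savedShift);
            return layer;
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("no usable no-arg constructor", e);
        }
    }

    public static void main(String[] args) {
        System.out.println("ArgOnlyLayer has no-arg ctor: " + hasNoArgConstructor(ArgOnlyLayer.class));
        System.out.println("BeanStyleLayer has no-arg ctor: " + hasNoArgConstructor(BeanStyleLayer.class));
        System.out.println("Rebuilt shift: " + rebuild(10).getShift());
    }
}
```

If this reading is right, giving the real layer classes a public no-arg constructor plus getters/setters for the shift amount (or annotating the constructor for Jackson) should let `ComputationGraph.load` rebuild them; I haven't verified that yet.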