# Imports assumed by this snippet (module paths as in the TF-GNN codebase;
# adjust to your install). `_get_test_bidi_cycle_graph` is a local helper,
# defined elsewhere, that builds a small bidirectional cycle GraphTensor
# from the given node states.
import os
import tensorflow as tf
from tensorflow_gnn.graph import graph_constants as const
from tensorflow_gnn.models import gat_v2

gt_input = _get_test_bidi_cycle_graph(tf.constant(
    # Node states have dimension 4.
    # The first three dimensions one-hot encode the node_id i.
    # The fourth dimension holds a distinct payload value i+1.
    [[1., 0., 0., 1.],
     [0., 1., 0., 2.],
     [0., 0., 1., 3.]]))
conv = gat_v2.GATv2Convolution(
num_heads=1,
per_head_channels=4,
receiver_tag=const.TARGET,
attention_activation="relu") # Let's keep it simple.
inputs = tf.keras.layers.Input(type_spec=gt_input.spec)
outputs = conv(inputs, edge_set_name="edges")
model = tf.keras.Model(inputs, outputs)
export_dir = os.path.join("/home/sidney", "edge-input-model")
model.save(export_dir)
Why does the above code save the model in the .pb (SavedModel) format?
Can it be saved as an .h5 file instead and loaded back as a Keras model for further training?
This is really a question for Google, but here you go:
# Passing a filename that ends in `.h5` makes `save()` write a single
# HDF5 file (`my_h5_model.h5`) instead of a SavedModel directory.
model.save("my_h5_model.h5")
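And a model saved that way can be loaded back as a regular Keras model and trained further. A rough sketch, assuming the .h5 save succeeds for this GraphTensor-input model; `train_ds`, the optimizer, the loss, and the `custom_objects` key are placeholders, not something from the snippet above:

# Load the HDF5 file back. Custom layers like GATv2Convolution may need to be
# passed via `custom_objects` unless your TF-GNN version already registers
# them as Keras serializables.
restored = tf.keras.models.load_model(
    "my_h5_model.h5",
    custom_objects={"GATv2Convolution": gat_v2.GATv2Convolution})

# From here it is an ordinary Keras model: compile and keep training.
restored.compile(optimizer="adam", loss="mse")
restored.fit(train_ds, epochs=1)  # `train_ds` is a placeholder dataset.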
@mdebeer @SidneyLann Of note here: we also have some tools for staging models for import:
master/contrib/omnihub/src/omnihub/frameworks
Note that Keras’s .h5 file format only works with DL4J, not SameDiff. These days it’s usually easier to export a .pb file than an .h5 file.
A .pb file can’t be trained, right?
@SidneyLann I don’t see why not. One of your previous GitHub issues was about me fixing an issue related to exactly that, and on top of that we’ve been training/fine-tuning .pb models for years now. That’s a bit broad, don’t you think?
Unless you’re referring to a specific file? Then try to be more specific and say “my specific model here (link) has issues” rather than “all pb files are broken”. Those are two very different conversations.
That’s a BERT .pb file; its output is my model’s input, so I’m not training on it. What I’m still not sure about is whether a tf.keras model can be saved as a .pb file and then trained?
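For what it’s worth, on the TensorFlow side the SavedModel export from the first snippet (the directory containing saved_model.pb) can normally be loaded back with tf.keras and fine-tuned, since its variables stay trainable. A minimal sketch, with the same caveat that custom TF-GNN layers may need `custom_objects` and that the training data here is hypothetical:

# `load_model` also accepts a SavedModel directory and returns a Keras model
# whose variables remain trainable.
restored = tf.keras.models.load_model("/home/sidney/edge-input-model")
# Then compile() and fit() exactly as in the .h5 example above.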