Keras import for tf2

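```python
import os

import tensorflow as tf

# gat_v2, const and _get_test_bidi_cycle_graph come from the TF-GNN GATv2
# unit tests; their import paths depend on the TF-GNN version in use.

gt_input = _get_test_bidi_cycle_graph(tf.constant(
    # Node states have dimension 4.
    # The first three dimensions one-hot encode the node_id i.
    # The fourth dimension holds a distinct payload value i+1.
    [[1., 0., 0., 1.],
     [0., 1., 0., 2.],
     [0., 0., 1., 3.]]))

conv = gat_v2.GATv2Convolution(
    num_heads=1,
    per_head_channels=4,
    receiver_tag=const.TARGET,
    attention_activation="relu")  # Let's keep it simple.

inputs = tf.keras.layers.Input(type_spec=gt_input.spec)
outputs = conv(inputs, edge_set_name="edges")
model = tf.keras.Model(inputs, outputs)
export_dir = os.path.join("/home/sidney", "edge-input-model")
# No ".h5" suffix, so TF 2.x (Keras 2) writes a SavedModel directory
# containing saved_model.pb plus a variables/ folder.
model.save(export_dir)
```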

Why does the code above save the model as a .pb file?
Can it be saved as .h5 instead and imported as a Keras model for training?

This is a question for Google. Here you go.

```python
# Calling `save('my_h5_model.h5')` creates an HDF5 file `my_h5_model.h5`.
model.save("my_h5_model.h5")
```
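For reference, `model.save()` picks the format from the path: in TF 2.x with the bundled Keras 2, a path ending in `.h5` writes a single HDF5 file, while a path without that suffix (like `export_dir` in the original code) writes a SavedModel directory containing `saved_model.pb`, which is why that code produced a pb. Below is a minimal sketch of an HDF5 round trip followed by further training. It uses a small hypothetical `Sequential` model and random placeholder data rather than the GATv2 model; whether that model, with its custom layers and GraphTensor input, survives HDF5 saving isn't verified here.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in model (not the GATv2 model from the question).
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

# The ".h5" suffix selects the HDF5 format instead of a SavedModel directory.
model.save("my_h5_model.h5")

# Reload as a regular Keras model; custom layers would need custom_objects=...
restored = tf.keras.models.load_model("my_h5_model.h5")
restored.compile(optimizer="adam", loss="mse")  # recompile explicitly

# Continue training on random placeholder data.
x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")
history = restored.fit(x, y, epochs=1, verbose=0)
```

If the model contains custom layers, `load_model` typically needs them passed via `custom_objects` to rebuild the model from the HDF5 file.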

@mdebeer @SidneyLann Of note here: we also have some tools for staging models for import:

Note that Keras's .h5 file format only works with DL4J's Keras import, not with SameDiff. These days it's usually easier to export a .pb file than an .h5 file.
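Assuming the SameDiff route expects a frozen graph (a GraphDef with the weights folded in as constants), here's a sketch of producing one from a Keras model in TF 2.x. The model and the file name `frozen_model.pb` are placeholders, and `convert_variables_to_constants_v2` lives in TensorFlow's internal `tensorflow.python` namespace rather than the public API, so treat its location as version-dependent. This is a generic TensorFlow recipe, not one of the staging tools mentioned above.

```python
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)

# Hypothetical stand-in model; substitute your own Keras model.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(2),
])

# Trace the model into a concrete function with a fixed input signature.
@tf.function(input_signature=[tf.TensorSpec([None, 4], tf.float32)])
def serve(x):
    return model(x)

concrete = serve.get_concrete_function()

# Fold the variables into constants, producing a self-contained GraphDef.
frozen = convert_variables_to_constants_v2(concrete)
graph_def = frozen.graph.as_graph_def()

# Write a binary (as_text=False) .pb file; returns the path of the file.
pb_path = tf.io.write_graph(graph_def, ".", "frozen_model.pb", as_text=False)
```

Because freezing turns the weights into constants, a frozen pb carries no trainable variables as-is; fine-tuning one means turning those constants back into variables on the importing side.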

A pb file can't be trained, right?

@SidneyLann I don't see why not. One of your previous GitHub issues was about me fixing a problem related to exactly that, and on top of that we've been training/fine-tuning PB models for years now. That's a bit broad, don't you think?

Unless you're referring to a specific file? If so, be more specific and say "my specific model here (link) has issues", not "all pb files are broken". Those are two very different conversations.

That's a BERT pb file; its output is my model's input, and I'm not training it. So I'm not sure whether a tf.keras model can be saved as a pb file and then trained.
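On the TF side, a Keras model saved as a SavedModel (the pb directory written by the first snippet) is not inference-only: in TF 2.x with Keras 2, `tf.keras.models.load_model` revives it with its weights as variables, so it can be recompiled and trained further. Below is a minimal sketch with a hypothetical stand-in model and random data. It assumes a TF 2.x release before 2.16, since Keras 3 (the default `tf.keras` from TF 2.16 on) no longer accepts a bare directory path in `model.save()`.

```python
import os

import numpy as np
import tensorflow as tf

# Hypothetical stand-in model; assumes TF 2.x where tf.keras is Keras 2.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="sgd", loss="mse")

# No .h5 suffix: Keras 2 writes a SavedModel directory (saved_model.pb + variables/).
export_dir = "savedmodel_demo"
model.save(export_dir)
has_pb = os.path.isfile(os.path.join(export_dir, "saved_model.pb"))

# load_model revives a Keras model whose weights are still trainable variables.
restored = tf.keras.models.load_model(export_dir)
restored.compile(optimizer="sgd", loss="mse")

x = np.random.rand(16, 4).astype("float32")
y = np.random.rand(16, 1).astype("float32")
before = [w.copy() for w in restored.get_weights()]
restored.fit(x, y, epochs=1, verbose=0)
after = restored.get_weights()

# True if at least one weight changed, i.e. the reloaded model actually trained.
changed = any(not np.array_equal(b, a) for b, a in zip(before, after))
```

If the BERT pb is a frozen graph, it has no variables left to train, which is fine when it only serves as a fixed feature extractor feeding your own model.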