Use TRAX/JAX model in DL4J

I have a Trax Transformer model that I want to use in DL4J.
It looks like I need to re-implement the model in TensorFlow 2 and then follow an approach like this BERT example. Is that the best approach? Has anyone done something similar?

@craig88 if you can find an implementation that can be imported from Keras, TensorFlow, or PyTorch (via ONNX import), I would start there. We'll work on expanding the model zoo, but it will take some time.

Here is a recipe for exporting a Trax model to Keras using the TensorFlow NumPy backend: "Using Trax with TensorFlow NumPy and Keras" (Trax documentation).

@craig88 looks great! I also want to look into using the new model import framework to import directly from JAX. From the looks of it, this doesn't use the Keras H5 file format that DL4J uses; it would actually produce the SavedModel format. We can mostly handle that, but we still need to do some work to make it seamless.

With what's already done (SavedModel is still mostly the same PB format), it shouldn't be too bad to load this.
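To make the format distinction concrete: a Keras .h5 export is a single HDF5 file, while a TensorFlow SavedModel is a directory whose graph is serialized as a protobuf in saved_model.pb (the "PB format" mentioned above). Below is a minimal Python sketch, assuming you just want to check which kind of artifact an export step produced before picking an import path. The function name `detect_model_format` is hypothetical and not part of DL4J, TensorFlow, or Trax.

```python
import sys
from pathlib import Path

# The 8-byte signature that starts an HDF5 file (the container Keras .h5 models use).
HDF5_SIGNATURE = b"\x89HDF\r\n\x1a\n"


def detect_model_format(path: str) -> str:
    """Hypothetical helper: return "savedmodel", "hdf5", or "unknown"."""
    p = Path(path)
    if p.is_dir():
        # A SavedModel is a directory holding a serialized protobuf graph,
        # either binary (saved_model.pb) or text (saved_model.pbtxt).
        if (p / "saved_model.pb").is_file() or (p / "saved_model.pbtxt").is_file():
            return "savedmodel"
        return "unknown"
    if p.is_file():
        with p.open("rb") as f:
            # HDF5 also permits the signature at offsets 512, 1024, ...;
            # this sketch only checks offset 0, where Keras-written files put it.
            if f.read(len(HDF5_SIGNATURE)) == HDF5_SIGNATURE:
                return "hdf5"
    return "unknown"


if __name__ == "__main__":
    for arg in sys.argv[1:]:
        print(f"{arg}: {detect_model_format(arg)}")
```

Note that "hdf5" only means the file is an HDF5 container; it does not verify that it holds a Keras model.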