Pass float in Java and get tf.float64 in Python

Python4j:
I pass a float from Java and get tf.float64 on the TF2 side. How do I get it as tf.float32?

@Booker I would just cast it using tf.cast. If you are passing in a tensor, ensure it’s the correct datatype.
ND4J's default datatype is float32.
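A minimal sketch of the tf.cast approach. It assumes the values reach TF2 as a float64 NumPy array; the array contents are made-up sample data.

```python
import numpy as np
import tensorflow as tf

# Python floats are double precision, so an array built from them
# without an explicit dtype is float64 (sample data, not from Java).
incoming = np.array([[0.5, 1.5], [2.5, 3.5]])

# convert_to_tensor keeps the array's dtype, giving tf.float64.
t64 = tf.convert_to_tensor(incoming)
print(t64.dtype.name)  # float64

# tf.cast returns a new tensor with the requested dtype.
t32 = tf.cast(t64, tf.float32)
print(t32.dtype.name)  # float32
```

tf.cast does not modify t64 in place; it creates a new float32 tensor, so use t32 from that point on.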

If you don't specify a dtype, NumPy defaults to float64 for floating-point data.
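A quick check of NumPy's default, plus converting an array that has already been created. The values are arbitrary sample floats.

```python
import numpy as np

values = [0.1, 0.2, 0.3]  # plain Python floats (double precision)

# No dtype given: NumPy infers float64 from Python floats.
default_arr = np.array(values)
print(default_arr.dtype)  # float64

# For an existing array, astype returns a float32 copy.
as32 = default_arr.astype(np.float32)
print(as32.dtype)  # float32
```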

I would double-check the inputs/outputs and what their default datatypes are.

Do you have more code to post?

I pass a List<List> and call train(data: Tuple) in Python. When I create the tensor in Python, using np.array(data, dtype=tf.float32.as_numpy_dtype) gives me float32 now, thanks.
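A sketch of that fix, assuming the Java List<List> arrives in Python as nested sequences of floats. build_float32_tensor is a hypothetical helper name, and the rows are made-up sample data.

```python
import numpy as np
import tensorflow as tf

def build_float32_tensor(data):
    """Hypothetical helper: turn nested Python floats (e.g. a converted
    Java List<List>) into a float32 tensor."""
    # tf.float32.as_numpy_dtype is np.float32, so NumPy stores single precision.
    arr = np.array(data, dtype=tf.float32.as_numpy_dtype)
    # convert_to_tensor keeps the array's dtype, giving tf.float32.
    return tf.convert_to_tensor(arr)

data = ([0.5, 1.5], [2.5, 3.5])  # made-up sample rows, passed as a tuple
tensor = build_float32_tensor(data)
print(tensor.dtype.name)  # float32
```

Setting the dtype when the array is created avoids building a float64 copy first and then casting it.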
