Hi. First of all, thanks for sharing the code!
I'm currently trying to train the model with pseudo captions. GPU memory is being allocated, but nothing seems to be computed on the GPUs. For reference, I am training with 2 A100s.
There is a warning in the log that reads:
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
which isn't too surprising since only the CPU version of JAX was installed in the script.
I am trying to figure out what could be wrong with the training. Is there a particular reason why JAX with GPU support isn't used? Thanks.
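For anyone hitting the same warning, here is a quick way to confirm which backend JAX actually picked up in the training environment. This is a minimal sketch that only assumes `jax` is importable there. With a working GPU install, I'd expect it to list two GPU devices for 2 A100s. If it reports only CPU, the installed jaxlib probably has no CUDA support and would need to be swapped for a CUDA-enabled build, following the official JAX installation instructions for the local CUDA version.

```python
import jax

# Backend JAX selected at startup: one of "cpu", "gpu", or "tpu".
backend = jax.default_backend()
# Devices visible to that backend (a GPU build on 2 A100s should show two).
devices = jax.devices()

print("JAX version:", jax.__version__)
print("default backend:", backend)
for d in devices:
    # platform is the backend name; device_kind is the hardware description.
    print(" ", d.platform, d.device_kind)

if backend == "cpu":
    # Matches the "No GPU/TPU found, falling back to CPU" warning in the log.
    print("Only CPU devices found: the installed jaxlib likely lacks CUDA support.")
```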