
Is JAX falling back to CPU the correct behavior during training? #3

@FifthEpoch


Hi. First of all, thanks for sharing the code!

I am currently trying to train the model with pseudo captions, but it looks like GPU memory is allocated while nothing is actually being computed on the GPU. For reference, I am training with two A100s.

There is a warning in the log that reads:

WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

which isn't too surprising, since the setup script installs only the CPU version of JAX.

I am trying to figure out what could be going wrong with training. Is there a particular reason the GPU-enabled build of JAX isn't used? Thanks.
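For anyone hitting the same warning, a minimal way to confirm which backend JAX is using, and to switch to a CUDA-enabled build, might look like the following. This is a sketch, not from the repo's setup script; the exact `pip` extra depends on your JAX version and CUDA version (`cuda12` is the extra used by recent JAX releases):

```shell
# Check which backend JAX picked up (prints "cpu" if it fell back)
python -c "import jax; print(jax.default_backend()); print(jax.devices())"

# If it prints "cpu", replace the CPU-only wheel with a CUDA-enabled one,
# e.g. for recent JAX releases with CUDA 12:
pip install -U "jax[cuda12]"
```

Older JAX releases instead required installing `jaxlib` from the CUDA wheel index (`-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html`), so check the install instructions for the pinned JAX version.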
