
Conversation

jialei777 (Collaborator) commented on Aug 12, 2025

Works locally with 4 chips (fsdp=4): `export PJRT_DEVICE=TPU; export TORCHPRIME_TPU_TYPE=v6e-4 && python torchprime/torch_xla_models/train.py model=flex-qwen-1b`

MFU: 0.21
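
For reference, a minimal sketch of what the `ici_mesh.fsdp` / `ici_mesh.tensor` overrides correspond to in torch_xla SPMD terms. This is an illustration under assumptions, not torchprime's actual wiring; the axis names and sharding roles are inferred from the config keys:

```python
# Minimal sketch (assumption, not torchprime's implementation): map the
# ici_mesh.fsdp / ici_mesh.tensor overrides onto a 2D torch_xla SPMD mesh.
import numpy as np
import torch_xla.runtime as xr
from torch_xla.distributed.spmd import Mesh

xr.use_spmd()  # enable SPMD execution mode

fsdp, tensor = 32, 2  # e.g. ici_mesh.fsdp=32 ici_mesh.tensor=2
num_devices = xr.global_runtime_device_count()
assert fsdp * tensor == num_devices  # the mesh axes must cover every chip

# Device ids laid out as an (fsdp, tensor) grid; conventionally parameters
# and optimizer state shard along 'fsdp', per-layer weights along 'tensor'.
mesh = Mesh(np.arange(num_devices), (fsdp, tensor), ("fsdp", "tensor"))
```

Note that all four configurations below keep fsdp × tensor = 64, the chip count of a v5p-128 slice; only the split between the two axes changes.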

On a v5p-128 cluster, with the command `tp run --name jialei-0812-qwen-fsdp32tensor2 torchprime/torch_xla_models/train.py model=flex-qwen-1b task.global_batch_size=64 ici_mesh.fsdp=32 ici_mesh.tensor=2`, varying the `ici_mesh.fsdp` / `ici_mesh.tensor` axes per run (a rough MFU sanity check follows the list):

  • fsdp=64: hangs (cause not yet identified)
  • fsdp=32, tensor=2: finished, MFU 0.22
  • fsdp=16, tensor=4: finished, MFU 0.19
  • fsdp=8, tensor=8: finished, MFU 0.11
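
As a back-of-envelope sanity check on these numbers, MFU is achieved model FLOP/s over hardware peak FLOP/s. The sketch below assumes the common 6 × params FLOPs-per-token rule of thumb (forward + backward) and ~459 bf16 TFLOP/s per v5p chip; torchprime's own MFU counter may account differently (e.g. including attention FLOPs), so treat this as an approximation:

```python
# Rough MFU estimate: achieved model FLOP/s divided by hardware peak FLOP/s.
# The 6 * n_params FLOPs-per-token approximation and the peak-FLOPs figure
# are assumptions for illustration, not torchprime's exact accounting.
def estimate_mfu(tokens_per_sec: float, n_params: float,
                 n_chips: int, peak_flops_per_chip: float) -> float:
    achieved = 6 * n_params * tokens_per_sec   # model FLOP/s
    peak = n_chips * peak_flops_per_chip       # hardware peak FLOP/s
    return achieved / peak

# Example: a ~1B-parameter model on 64 v5p chips (~459e12 bf16 FLOP/s each)
# needs roughly 1.1M tokens/s to reach MFU 0.22.
print(estimate_mfu(tokens_per_sec=1.08e6, n_params=1e9,
                   n_chips=64, peak_flops_per_chip=459e12))  # ~0.22
```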
