[Feat] Add support for RAY_EXPERIMENTAL_NOSET_*_VISIBLE_DEVICES (Fix AMD support) #1465

Open
HollowMan6 wants to merge 1 commit into main from the amd branch
Conversation

@HollowMan6 (Contributor) commented May 9, 2025

Checklist Before Starting

  • Search for similar PR(s).

What does this PR do?

Add support for RAY_EXPERIMENTAL_NOSET_*_VISIBLE_DEVICES and fix AMD support.

High-Level Design

The current approach for supporting AMD in verl is fundamentally incorrect and has only been working by luck:

Calls such as torch.cuda.is_available() or torch.cuda.get_device_name() will initialize the CUDA/ROCm environment:
https://github.com/pytorch/pytorch/blob/c65ee728f069ea9544bdcac815eb0825f45d1633/torch/cuda/__init__.py#L342-L392

Setting CUDA/HIP/ROCR_VISIBLE_DEVICES after CUDA/ROCm is initialized will not take effect (please check pytorch/pytorch#141678), which means that all the current code wrapped inside [SUPPORT AMD: torch] is mostly a no-op.

CUDA_VISIBLE_DEVICES also works for AMD, but only because a lot of AMD-migrated software calls torch.cuda.* at import time, e.g.:

  • ROCm/TransformerEngine#183
  • vllm-project/vllm#15246

Meanwhile, ray/vllm manipulate those *_VISIBLE_DEVICES at runtime, so those torch.cuda.* calls poison the current process whenever the CUDA/ROCm environment is initialized before the manipulation happens.

So the solution here is to use a single environment variable (CUDA_VISIBLE_DEVICES) for everything, for consistency and hardware agnosticism, and fold all the other *_VISIBLE_DEVICES into it. Note that we must pay attention when both the HIP/CUDA and the ROCR env vars are set, as they have different meanings: both accept either a list of ints or a list of UUIDs, but the ROCR variable is processed first and reduces the set of GPUs that HIP can then select from (see pytorch/pytorch#144026). To avoid this complexity, we simply raise an error if both are set (which also keeps consistency with Ray's behavior as of 2.45.0).
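
A minimal sketch of that consolidation (illustrative only and simplified compared to the actual PR code; the exact error message is an assumption):

import os

hip_val = os.environ.get("HIP_VISIBLE_DEVICES")
rocr_val = os.environ.get("ROCR_VISIBLE_DEVICES")

if hip_val is not None and rocr_val is not None:
    # ROCR is applied first and narrows what HIP can select from, so having both
    # set is ambiguous; Ray >= 2.45 also rejects this combination.
    raise ValueError("Please set only one of HIP_VISIBLE_DEVICES and ROCR_VISIBLE_DEVICES.")

val = hip_val if hip_val is not None else rocr_val
if val is not None:
    # Fold everything into CUDA_VISIBLE_DEVICES, the single hardware-agnostic variable.
    os.environ["CUDA_VISIBLE_DEVICES"] = val
    os.environ.pop("HIP_VISIBLE_DEVICES", None)
    os.environ.pop("ROCR_VISIBLE_DEVICES", None)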

For the poisoning issue, until those 2 PRs are merged, we will need to ask users to set RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES or RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES, so that Ray no longer manipulates these variables, and so that verl also works when no *_VISIBLE_DEVICES is set.

Note that for the latest Ray (after its switch to HIP_VISIBLE_DEVICES), we also need this patch: ray-project/ray#52794
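
As a stop-gap, opting out roughly looks like this (a hedged example; whether to use the HIP or the ROCR variant depends on the Ray version as noted above, and the flag must be in the environment before Ray starts its workers):

import os

# Must be set before Ray starts, e.g. at the very top of the launcher script.
os.environ["RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES"] = "1"    # newer Ray (HIP_VISIBLE_DEVICES)
# os.environ["RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES"] = "1"  # older Ray (ROCR_VISIBLE_DEVICES)

import ray

ray.init()  # Ray now leaves the *_VISIBLE_DEVICES variables alone in its workers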

Test

Tested manually on both the Megatron and FSDP backends with vLLM.

Additional Info.

  • Issue Number: none
  • Training: both FSDP and Megatron
  • Inference: both vLLM and SGLang

Checklist Before Submitting

  • Read the Contribute Guide.
  • Apply pre-commit checks.
  • Add [BREAKING] to the PR title if it breaks any API.
  • Update the documentation about your changes in the docs.
  • Add CI test(s) if necessary.

@HollowMan6 (Contributor, Author) commented:

cc @yushengsu-thu and @eric-haibin-lin for awareness, as I saw another ongoing PR #1453

@vermouth1992 requested a review from yushengsu-thu on May 9, 2025 12:58
@eric-haibin-lin (Collaborator) left a comment:

Thanks. Is there a unit test that can easily reproduce the issue?

@HollowMan6 (Contributor, Author) commented May 9, 2025

> Thanks. Is there a unit test that can easily reproduce the issue?

For the AMD part, since those poisonings are AMD-specific, we would need AMD devices in CI/CD to easily reproduce the issue. I saw there are some ongoing discussions in #1453 about setting up CI with AMD devices for verl. A quick demo of why the current approach is wrong can be found in pytorch/pytorch#141678:

import os
import torch

# Hide all GPUs, then trigger initialization of the CUDA/ROCm runtime.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
print(f"torch.cuda.is_available(): {torch.cuda.is_available()}")  # expected: False

# Re-exposing GPUs afterwards has no effect: the runtime is already initialized.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
print(f"torch.cuda.is_available(): {torch.cuda.is_available()}")  # still False
print(f"torch.cuda.device_count(): {torch.cuda.device_count()}")  # still 0

For the RAY_EXPERIMENTAL_NOSET_*_VISIBLE_DEVICES part, for the NVIDIA case, on top of the existing test cases we could of course add another full test run with RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1 set, but I'm not sure whether this is necessary and would appreciate your advice.
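
If we do add such a test, a minimal sketch could look like the following (the test name and structure are illustrative, not part of this PR):

import os

import ray


def test_noset_cuda_visible_devices():
    # Opt out of Ray's CUDA_VISIBLE_DEVICES management before the local cluster starts.
    os.environ["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] = "1"
    ray.init(num_gpus=1)

    @ray.remote(num_gpus=1)
    def worker_visible_devices():
        # With the NOSET flag, Ray should not rewrite CUDA_VISIBLE_DEVICES in the worker.
        return os.environ.get("CUDA_VISIBLE_DEVICES")

    # The worker should see whatever the launching environment had, untouched by Ray.
    assert ray.get(worker_visible_devices.remote()) == os.environ.get("CUDA_VISIBLE_DEVICES")
    ray.shutdown()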

@HollowMan6 force-pushed the amd branch 2 times, most recently from df26023 to 1a2791b on May 9, 2025 17:17
@yushengsu-thu self-assigned this on May 9, 2025
@yushengsu-thu (Collaborator) commented May 9, 2025

@HollowMan6, thank you. I am also aware of this issue; my colleague Vicky also sent this patch to Ray. If you're willing, could we arrange a quick online call (my email: yushengsu.thu@gmail.com) to review this PR? It would be helpful for us if you could provide more suggestions for maintenance and modification.

is_ray_noset_visible_devices = ray_noset_visible_devices()

# Prevent use of clashing `{CUDA/HIP/ROCR}_VISIBLE_DEVICES`
rocr_val = os.environ.get("ROCR_VISIBLE_DEVICES", None)

Ray >=2.45 will throw an error if ROCR_VISIBLE_DEVICES is used. In general, it can cause unexpected behavior in pytorch distributed.

@HollowMan6 (Contributor, Author) replied on May 12, 2025:

This is just for compatibility with Ray < 2.45, and I don't think it will cause issues for Ray >= 2.45: Ray already performs that check when ROCR_VISIBLE_DEVICES is set from the beginning, and for Ray < 2.45 we just copy its value into CUDA_VISIBLE_DEVICES and unset it.
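
Roughly, the compatibility path described here looks like this (a simplified sketch, not the exact PR code; using packaging for the version comparison is my own assumption):

import os

import ray
from packaging import version

rocr_val = os.environ.get("ROCR_VISIBLE_DEVICES")
if rocr_val is not None and version.parse(ray.__version__) < version.parse("2.45.0"):
    # Older Ray still manages ROCR_VISIBLE_DEVICES itself; fold its value into
    # CUDA_VISIBLE_DEVICES and unset it so it cannot clash later on.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", rocr_val)
    os.environ.pop("ROCR_VISIBLE_DEVICES")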

@vickytsang commented:

@HollowMan6 can you share any passing example or tests you have done for these changes?

@HollowMan6 (Contributor, Author) replied:

> @HollowMan6 can you share any passing example or tests you have done for these changes?

I just ran end-to-end training with my own use cases on AMD MI250x clusters, with Ray 2.45 and a patch similar to yours (ray-project/ray#52794). My ROCm version is 6.2.4 and my PyTorch version is torch-2.8.0.dev20250417+rocm6.2.4; everything else is at the latest version. I tested both FSDP and Megatron separately, with vLLM, similar to these examples:

I also tried both settings, with and without RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES, with the patches here applied, and all of them work fine:

Without the above 2 patches, we got errors when RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES was not set.

I don't have more resources to test this PR more comprehensively, but if you see any errors, please let me know and I can help check and figure them out.

@yushengsu-thu (Collaborator) replied to the above:

@HollowMan6 thanks,
Vicky and I will test this on our side.

@yushengsu-thu (Collaborator) commented May 13, 2025

Update @HollowMan6: regarding PR 1, I've contacted the AMD Megatron-LM owner and am now waiting for them to merge it (link).

@wenchenvincent commented:

Hi @HollowMan6, I would like to understand the motivation for setting HIP_VISIBLE_DEVICES after ROCm is initialized. Usually these variables are set before the run starts. Is there any particular reason to manipulate them during the run?

@HollowMan6 (Contributor, Author) commented May 14, 2025

@wenchenvincent It's mainly because vLLM needs to manipulate CUDA_VISIBLE_DEVICES after the process starts (so that it has more flexibility), but the manipulation only takes effect if CUDA/ROCm has not yet been initialized.

Similar problems are already well known on the vLLM community side; you can check vllm-project/vllm#6056 for reference.
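
In other words, the ordering is what matters: restricting devices only works if it happens before the first torch.cuda.* call, for example (expected behavior, shown for illustration):

import os

# Restrict visibility *before* anything touches torch.cuda.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch  # importing torch alone does not initialize the CUDA/ROCm runtime

# The first torch.cuda.* call initializes the runtime and honors the restriction above.
print(torch.cuda.device_count())  # expected: 1 on a multi-GPU node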

Commit: [Feat] Add support for RAY_EXPERIMENTAL_NOSET_*_VISIBLE_DEVICES (Fix AMD support)

Signed-off-by: Hollow Man <hollowman@opensuse.org>