Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion nvshmem4py/nvshmem/core/collective.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@
# Collectives
valid_collectives = ["reduce", "reducescatter", "alltoall", "fcollect", "broadcast"]

# Mapping from user-facing dtype aliases to NVSHMEM binding dtype names
# This allows users to use common shorthand like "fp16" and "bf16"
dtype_aliases = {
"fp16": "half",
"bf16": "bfloat16",
}

# Mapping from Cupy/Torch dtypes to NVSHMEM dtype names
external_to_nvshmem_dtypes = {
# --------------------
Expand Down Expand Up @@ -178,9 +185,12 @@ def collective_on_buffer(coll: str, team: Teams, dest: Buffer, src: Buffer, dtyp
else:
size_elem = max(1, size // (n_pes() * dtype_nbytes(dtype)))


func_name = ""
if dtype:
func_name += f"{dtype}_"
# Normalize dtype alias (e.g., fp16 -> half, bf16 -> bfloat16)
normalized_dtype = dtype_aliases.get(dtype, dtype)
func_name += f"{normalized_dtype}_"
if op:
func_name += f"{op}_"
func_name += f"{coll}_on_stream"
Expand Down