diff --git a/nvshmem4py/nvshmem/core/collective.py b/nvshmem4py/nvshmem/core/collective.py index f678561..ad2630f 100644 --- a/nvshmem4py/nvshmem/core/collective.py +++ b/nvshmem4py/nvshmem/core/collective.py @@ -43,6 +43,13 @@ # Collectives valid_collectives = ["reduce", "reducescatter", "alltoall", "fcollect", "broadcast"] +# Mapping from user-facing dtype aliases to NVSHMEM binding dtype names +# This allows users to use common shorthand like "fp16" and "bf16" +dtype_aliases = { + "fp16": "half", + "bf16": "bfloat16", +} + # Mapping from Cupy/Torch dtypes to NVSHMEM dtype names external_to_nvshmem_dtypes = { # -------------------- @@ -178,9 +185,12 @@ def collective_on_buffer(coll: str, team: Teams, dest: Buffer, src: Buffer, dtyp else: size_elem = max(1, size // (n_pes() * dtype_nbytes(dtype))) + func_name = "" if dtype: - func_name += f"{dtype}_" + # Normalize dtype alias (e.g., fp16 -> half, bf16 -> bfloat16) + normalized_dtype = dtype_aliases.get(dtype, dtype) + func_name += f"{normalized_dtype}_" if op: func_name += f"{op}_" func_name += f"{coll}_on_stream"