[RFC] NVFP4 Rounding Modes #3264

@syed-ahmed

Description

The current rounding mode for NVFP4 tensors in TorchAO is round-to-nearest. The purpose of this issue is to discuss support for other rounding modes.

What rounding modes are available?

  • Stochastic rounding (RS)
  • Round-to-nearest (RN)
  • Round-toward-zero (RZ)
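
As a quick illustration of how the three modes differ when snapping a value onto a coarser grid (integers here for simplicity; for NVFP4 the grid is the E2M1 value set), a minimal sketch:

```python
import torch

x = torch.tensor([1.3])

rn = torch.round(x)   # round-to-nearest: 1.3 -> 1.0
rz = torch.trunc(x)   # round-toward-zero: 1.3 -> 1.0
# stochastic rounding: round up with probability equal to the fractional
# part, so the result is 2.0 with probability 0.3 and 1.0 otherwise;
# the expectation matches x exactly, which is why RS is attractive for gradients
frac = x - torch.floor(x)
rs = torch.floor(x) + (torch.rand_like(x) < frac).float()
```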

Where do we need different rounding modes?

  • NVFP4 Training Recipe (NVFP4 MoE Training Status torchtitan#1962)
    • RS for gradients
    • RN for weights and activations
  • _AdamW in torchao.optim supports BF16 stochastic rounding:
    ```python
    # a clone of torch.optim.AdamW with extra features
    from torchao.optim import _AdamW

    model = ...
    model_bf16 = model.bfloat16()
    optim = _AdamW(model_bf16.parameters(), bf16_stochastic_round=True)
    ```
  • INT8 quantization has a stochastic rounding mode in TorchAO (a sketch of the idea follows below):
    ```python
    def quantize_int8_rowwise(
        tensor: Tensor, stochastic_rounding: bool = False, eps: float = 1e-12
    ):
    ```
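
For context, a minimal sketch (assumptions mine, not the actual TorchAO kernel) of what row-wise INT8 quantization with optional stochastic rounding can look like:

```python
import torch
from torch import Tensor

def quantize_int8_rowwise_sketch(
    tensor: Tensor, stochastic_rounding: bool = False, eps: float = 1e-12
):
    # per-row scale so the max-magnitude element maps to 127
    scale = tensor.abs().amax(dim=1, keepdim=True).clamp_min(eps) / 127.0
    x = tensor / scale
    if stochastic_rounding:
        # floor(x + U[0, 1)) rounds up with probability equal to the
        # fractional part, making the quantization unbiased in expectation
        x = torch.floor(x + torch.rand_like(x))
    else:
        x = torch.round(x)
    return x.clamp(-128, 127).to(torch.int8), scale.squeeze(1)
```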

Existing RN Kernels

  • Eager path:
    ```python
    def _f32_to_floatx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
        """Convert FP32 numbers to sub-byte floating point numbers with the given
        number of exponent and mantissa bits."""
    ```
  • torch.compile:
    ```python
    def convert_fp32_to_fp4_packed(x_pairs):
        """Convert FP32 pairs to packed FP4 format."""
    ```
    • Uses cvt.rn.satfinite.e2m1x2.f32 inline asm
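
For reference, a minimal eager sketch (illustrative only, not the TorchAO kernel, and ignoring round-to-nearest-even tie-breaking) of what rounding onto the E2M1 value set means, assuming the input has already been scaled into the representable range:

```python
import torch

# the 8 non-negative E2M1 magnitudes; sign is handled separately
_E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def round_to_e2m1_rn(x: torch.Tensor) -> torch.Tensor:
    values = _E2M1_VALUES.to(x.device)
    sign = torch.sign(x)
    mag = x.abs().clamp(max=6.0)  # saturate, mirroring the satfinite variants
    # index of the nearest representable magnitude for every element
    idx = (mag.unsqueeze(-1) - values).abs().argmin(dim=-1)
    return sign * values[idx]
```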

Possible RS Kernel Implementations

  • Emulated implementation from @slayton58. I've quickly written this in Triton syntax, but it probably makes the most sense to write it in PyTorch eager, similar to RN (_f32_to_floatx_unpacked); an eager sketch follows after this list.
    ```python
    @triton.jit
    def float_rs(x, seed, offset):
        """
        Apply stochastic rounding when casting from float32 to NVFP4.
        
        Args:
            x: Input tensor (float32)
            seed: Random seed for the random number generator
            offset: Offset for random number generation (should be unique per element)
        
        Returns:
            Stochastically rounded tensor
        """
        
        # Scale down by 2^(-125) to normalize range
        downscale_factor = tl.math.exp2(-125.0)
        x = x * downscale_factor
        
        # Create a 32-bit pseudorandom value; bitcast to uint32 so the shifts
        # below are logical rather than sign-extending
        rnd = tl.randint(seed, offset).to(tl.uint32, bitcast=True)
        
        # Isolate the lower 22 bits for randomness injection
        # (left-shift by 10, then right-shift by 10)
        rnd_shifted = (rnd << 10) >> 10
        
        # Reinterpret float bits as unsigned integer
        xb = x.to(tl.uint32, bitcast=True)
        
        # Inject randomness into the discarded precision bits
        yb = xb + rnd_shifted
        
        # Clear the lower 22 bits to perform rounding
        yb = (yb >> 22) << 22
        
        # Reinterpret integer bits back as floating point
        y = yb.to(tl.float32, bitcast=True)
        
        # Restore original magnitude by scaling up
        upscale_factor = tl.math.exp2(125.0)
        y = y * upscale_factor
        
        return y
    ```
  • Use an inline-asm Triton kernel with cvt.rs.satfinite.e2m1x4.f32 for stochastic rounding, similar to the RN path.
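
As an eager-mode counterpart to the emulated kernel above, here is a minimal PyTorch sketch, assuming a float32 input and the same bit manipulation (scale down by 2^-125, add random bits to the 22 low mantissa bits, truncate, scale back up); this is an illustration, not an existing TorchAO kernel:

```python
import torch

def float_rs_eager(x: torch.Tensor) -> torch.Tensor:
    # scale down by 2^-125 to normalize the range, as in the Triton version
    x = x * 2.0 ** -125
    xb = x.view(torch.int32)
    # random value restricted to the 22 mantissa bits that will be discarded
    rnd = torch.randint(0, 1 << 22, x.shape, dtype=torch.int32, device=x.device)
    # inject randomness, then truncate: clearing the low 22 bits rounds down,
    # so the earlier addition rounds up with probability proportional to the
    # discarded fraction
    yb = ((xb + rnd) >> 22) << 22
    # reinterpret the integer bits as float and restore the original magnitude
    return yb.view(torch.float32) * 2.0 ** 125
```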

Integration

  • A possible integration point for the NVFP4 training recipe use case is to specify the rounding mode in to_nvfp4 calls (a hypothetical usage example follows below):
    ```python
    class RoundingMode(Enum):
        RN = "round_nearest"
        RS = "round_stochastic"
        RZ = "round_zero"

    def to_nvfp4(
        data_hp: torch.Tensor,
        block_size: int = 16,
        per_tensor_scale: Optional[torch.Tensor] = None,
        act_per_tensor_scale: Optional[torch.Tensor] = None,
        is_swizzled_scales: bool = False,
        use_triton_kernel: bool = False,
        act_quant_kwargs: Optional[QuantizeTensorToNVFP4Kwargs] = None,
        rounding_mode: RoundingMode = RoundingMode.RN,
    ):
        ...
        if use_triton_kernel:
            blockwise_scales, data_lp = triton_quantize_nvfp4(
                data_hp, per_tensor_scale, rounding_mode
            )
        else:
            blockwise_scales, data_lp = nvfp4_quantize(
                data_hp, block_size, per_tensor_scale, rounding_mode
            )
    ```
  • We should discuss whether we need to support rounding modes more generically for other use cases such as _AdamW and INT8 training.
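
Under that proposal, usage in the training recipe could look like the following (hypothetical: the rounding_mode argument does not exist yet, and grad_hp / weight_hp are placeholder names):

```python
# stochastic rounding for gradients, round-to-nearest elsewhere
grad_nvfp4 = to_nvfp4(grad_hp, rounding_mode=RoundingMode.RS)
weight_nvfp4 = to_nvfp4(weight_hp, rounding_mode=RoundingMode.RN)  # default
```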

Test Plan

  • TODO

CC: @slayton58, @ngimel, @supriyar, @Priyadlfw, @ptrblck, @eqy
