Transposable linear#31
Conversation
When block_scale_2d=True, the backward pass can obtain the transposed weight via a cheap nibble shuffle instead of full re-quantization. 16x16 tile scales are invariant to transposition, making this lossless. New files: - quantize/transpose.py: transpose_quantized_tensor utility - model/modules/transposable_linear.py: TransposableFourOverSixLinear - tests/test_transposable_linear.py: 14 tests covering correctness
Accelerate transpose_quantized_tensor with GPU kernels: - Triton @triton.jit kernel: fused unpack-transpose-repack using even/odd row decomposition into four (HALF_M x HALF_N) sub-matrix transposes, avoiding any intermediate full-size buffer. - CUDA kernel: register-only transpose — each thread loads from source position (k,j) and writes to dest position (j,k), with nibble repack. No shared memory needed. Registered via TORCH_LIBRARY_IMPL. - Backend dispatch: auto-selects Triton > CUDA > PyTorch fallback via TransposeBackend enum on transpose_quantized_tensor(). - Tests parametrized over all available backends with bitwise-identity verification. Standalone JIT-compiled CUDA test for environments where TORCH_LIBRARY registration doesn't fire.
|
Thank you for this huge contribution! I've been super busy lately so my apologies about the delayed review. My only major comment is that I'd prefer not to create a separate linear layer for this: especially if we have tests confirming that this kernel is identical to the existing |
|
Yes of course, my bad I meant to say let's make it the default when both transpose and block_scale_2d are set to True. |
…transpose When both block_scale_2d=True and transpose=True are set, pseudo-quantization now applies RMS diagonal preconditioning followed by a fixed 16x16 Hadamard rotation before quantizing, and undoes both transforms on the result. This is the default (precondition_2d=True in QuantizationConfig) per team request, matching the 'kitchen sink' strategy identified in benchmarks. The preprocessing reduces within-tile magnitude heterogeneity in W.T and improves reconstruction quality on structured model weights: - Weights with row magnitude outliers: ~27% MSE reduction - Transformer-like weights (row + col outliers): ~67% MSE reduction - Uniform/random weights: negligible overhead Only affects the pseudo_quantize=True path (used in STE training). The real-FP4 quantized_matmul path is unchanged. Implementation note: pseudo_quantize(transpose=True) returns W.T layout [in, out], so preconditioning operates on W.T explicitly and calls the backend with transpose=False to avoid double-transposing.
…ixLinear Per review feedback: instead of a separate class, integrate the fast nibble transpose directly into FourOverSixLinear's backward pass. When weight_scale_2d=True (the common training config), the dgrad path now: 1. Quantizes W once (row-major, no transpose flag) 2. Calls transpose_quantized_tensor() for the cheap nibble shuffle 3. Uses the transposed QuantizedTensor in quantized_matmul This replaces the previous re-quantization with transpose=True, which was redundant because 16x16 tile scales are invariant to transposition. The Triton, CUDA, and PyTorch transpose kernels are kept — they provide the actual performance win. Only the separate linear module class is removed. Removed: - TransposableFourOverSixLinear class and test - Exports from model/__init__.py and model/modules/__init__.py Modified: - FourOverSixLinearFunction.backward: added fast transpose path - linear.py: added transpose_quantized_tensor import
|
Done, pushed two new commits: Removed 652 tests pass, 0 new failures. |
The 2D preconditioning (RMS + Hadamard) improvement deserves its own PR with proper forward+backward coverage, not just the transpose path.

Add
TransposableFourOverSixLinearwith Triton/CUDA kernel supportSummary
This PR adds a new
TransposableFourOverSixLinearmodule that eliminates the need to re-quantize the weight matrix in the backward pass when computing dgrad (dX = dY @ W).The current
FourOverSixLinearcallsquantize(weight, transpose=True)on every backward iteration, which runs the full quantization pipeline (scale computation, fake-quantize, pack, blocked layout conversion) on the transposed weight. Whenblock_scale_2d=True, this re-quantization produces FP4 codes that are identical to simply rearranging the nibbles of the already-quantized weight — because 16x16 tile scales are invariant to transposition.TransposableFourOverSixLinearexploits this property: it quantizes the weight once in the forward pass, then obtains the transposedQuantizedTensorvia a cheap nibble shuffle + scale grid transpose in the backward pass. No FP4 codes are recomputed. The operation is mathematically lossless.The nibble transpose is accelerated with Triton and CUDA kernels, with automatic backend selection (Triton > CUDA > PyTorch fallback).
Motivation
We are GlamLabs — we train LoRAs for image generation models on Blackwell GPUs with NVFP4 quantization. For our workload we use the straightforward
block_scale_2d=Truepath without random Hadamard transforms or Quartet II — just plain 2D-block-scaled FP4. This is the simplest and most practical configuration for LoRA finetuning, where you want fast training without complex quantization overhead.In this setup, we noticed the backward pass re-quantizes W on every iteration just to get the transposed packed layout for dgrad. This is entirely redundant when 2D block scaling is used, because:
quantize(W, block_scale_2d=True)transposed ==quantize(W, block_scale_2d=True, transpose=True)The nibble transpose is O(n) with no floating-point arithmetic — just integer bit manipulation and a memory transpose. The full quantize path involves floating-point scale computation, fake-quantization with rounding, BF16-to-FP4 encoding, and blocked layout conversion.
This work grew out of NVFP4-transpose, a standalone library where we explored and benchmarked different strategies for runtime transposition of NVFP4 matrices (approximate nibble shuffle, exact dequant-requant, 2D tile scaling, and joint rounding). The 2D tile scaling approach proved to be the clear winner — lossless transpose with zero additional error — and integrates naturally into fouroversix's existing
block_scale_2dsupport.What changed
New files
src/fouroversix/quantize/transpose.py— Core transpose utility with backend dispatchtranspose_quantized_tensor(qt, backend=TransposeBackend.auto_select):QuantizedTensorwith swapped shapes and correctly laid-out scale factorsTransposeBackendenum:auto_select,triton,cuda,pytorchsrc/fouroversix/kernels/triton/transpose.py— Triton@triton.jitkernel(BLOCK_M//2 × BLOCK_N//2)sub-matrices viatl.trans, and repacks directlysrc/fouroversix/kernels/triton/ops_transpose.py— Triton host launchertranspose_packed_fp4(values, rows, cols)— allocates output, configures grid, launches kernelsrc/fouroversix/csrc/transpose_fp4.cu— CUDA kernel + torch binding(k, j)and writes to destination position(j, k)— the transpose is implicit in this coordinate swapTORCH_LIBRARY_IMPL(fouroversix, CUDA, m)src/fouroversix/model/modules/transposable_linear.py— New linear moduleTransposableFourOverSixLinearFunction(autograd Function):FourOverSixLinearFunction— quantize weight,quantized_matmul(X, W_q)transpose_quantized_tensor()for W^T — notranspose=Truein the quantize configdisable_dgrad_quantizationmodes for compatibilityTransposableFourOverSixLinear(nn.Linear):weight_scale_2d=Trueat construction (raisesValueErrorotherwise)get_quantized_parametersAPI asFourOverSixLinear— drop-in replacement when 2D scaling is already enabledkeep_master_weights=True(training) andFalse(inference/PTQ)tests/test_transposable_linear.py— 27 teststests/test_cuda_transpose_standalone.py— Standalone CUDA kernel test via JIT compilationModified files
src/fouroversix/csrc/bindings.cpp— Addedtranspose_packed_fp4op schemasrc/fouroversix/kernels/triton/__init__.py— Exportstranspose_packed_fp4src/fouroversix/quantize/__init__.py— Exportstranspose_quantized_tensor,TransposeBackendsrc/fouroversix/model/modules/__init__.py— ExportsTransposableFourOverSixLinearsrc/fouroversix/model/__init__.py— ExportsTransposableFourOverSixLinearHow it works
Current backward (dgrad)
New backward (dgrad)
Kernel design
Both kernels use the same algebraic insight: the full
BLOCK_M × BLOCK_Ncode transpose decomposes into four(BLOCK_M/2 × BLOCK_N/2)sub-matrix transposes by separating even/odd source rows and low/high nibbles.Triton: Uses
tl.trans()on four sub-matrices in registers, strided loads for even/odd rows, strided stores for even/odd destination rows.CUDA: Even simpler — each thread
(k, j)loads one packed byte from source even-row and one from odd-row, extracts four nibbles, repacks them, and writes to destination at(j, k). The transpose is implicit in the coordinate swap; no shared memory needed.Tests
All 27 tests pass:
test_transpose_matches_quantize_with_transpose_flag(16 parametrizations: 4 shapes × 2 scale rules × {pytorch, triton})quantize(W, transpose=True)— bitwise identicaltest_transpose_roundtrip_lossless(4 parametrizations: 2 shapes × {pytorch, triton})test_all_backends_bitwise_identical(3 parametrizations)test_transposable_linear_requires_2d_scalesValueErrorwhenweight_scale_2d=Falsetest_transposable_linear_forward_matches_originalFourOverSixLineartest_transposable_linear_backward_runstest_transposable_linear_backward_matches_originaltest_cuda_transpose_standalone.pyUsage
Controlling the backend
Scope and limitations
This targets the plain
block_scale_2d=Truetraining path — the most common configuration for LoRA finetuning on Blackwell. We intentionally don't touch:pseudo_quantizemode — same forward-path simplification as the existing moduleThe CUDA kernel requires building the C++ extension (
SKIP_CUDA_BUILD=1disables it). The Triton kernel works on any CUDA GPU without a build step.Related work