NVFP4 cast/transpose without TMA #472
Open: matthiasdiener wants to merge 55 commits into dev from mdiener/fp4-cast-transpose
+208
−34
55 commits
b8a4024 [ROCm] resolve the conflicts in common dir (wangye805)
0519b4b [ROCm] resolve the conflicts on jax side (wangye805)
8f4b04d [ROCm] resolve the conflicts on pytorch side (wangye805)
e60ff21 [ROCm] resolve the conflicts in setup (wangye805)
8bbb162 [ROCm] resolve the cpp gtest (wangye805)
f573b40 [ROCm] resolve pytorch and jax tests (alextmagro)
eaaae94 pytest, example, wheels conflict resolution (alextmagro)
8f94cf6 jax and pytorch bugfix (alextmagro)
bac7993 copyrights and fp8_autocast->autocast fix (alextmagro)
8ae38e8 Enable test_distributed_dense.py (alextmagro)
05a977a address IFU comments (alextmagro)
0385852 _FormatHelperFP8 and missing file add (alextmagro)
46d382d add use_async_d2h_group_size as a test parameter (alextmagro)
15416f1 enable FP4 tests (matthiasdiener)
bac5096 rough initial version (matthiasdiener)
da24223 initial working version (matthiasdiener)
c03b7bb Addressing comments and small fixes (alextmagro)
c453dba various cleanups (matthiasdiener)
4a843ba manually update runner labels (matthiasdiener)
316dffb Comment cleanup (alextmagro)
8a47bc5 Merge remote-tracking branch 'origin/IFU-dev-20251114-v2.10' into mdi… (matthiasdiener)
5c747bd only enable on gfx950 (matthiasdiener)
db56b8f Update jax gemm.py (alextmagro)
b318bda Merge remote-tracking branch 'origin/IFU-dev-20251114-v2.10' into mdi… (matthiasdiener)
62eea94 Revert "only enable on gfx950" (matthiasdiener)
6d459ec reenable in NVTEDType (matthiasdiener)
6eb2707 Fix dev merge conflicts (alextmagro)
8cec975 enable in bwd_helper (matthiasdiener)
c20e0e9 Merge remote-tracking branch 'origin/IFU-dev-20251114-v2.10' into mdi… (matthiasdiener)
ccda439 alignment fixes (matthiasdiener)
4b0fd34 fix merge error (matthiasdiener)
84934c2 minor fixes (matthiasdiener)
e79134a Merge remote-tracking branch 'origin/dev' into mdiener/fp4-cast-trans… (matthiasdiener)
586bd09 Run CI (leo-amd)
4896edf Merge branch 'dev' into mdiener/fp4-cast-transpose (matthiasdiener)
aa18e9a more scales fixing (matthiasdiener)
c918a19 Merge remote-tracking branch 'origin/dev' into mdiener/fp4-cast-trans… (matthiasdiener)
5bd7388 Merge remote-tracking branch 'origin/dev' into mdiener/fp4-cast-trans… (matthiasdiener)
95d0c9f address review comments (matthiasdiener)
6cd6038 adjust error message slightly (matthiasdiener)
55a8c84 simplify via hipify map (matthiasdiener)
10d88bf adjust more error messages (matthiasdiener)
b4caf6f change disabling of header includes (matthiasdiener)
511db61 address review comments (matthiasdiener)
36cf73a implement SR (matthiasdiener)
a85f68f simplify slightly (matthiasdiener)
f4f5ec9 Merge remote-tracking branch 'origin/dev' into mdiener/fp4-cast-trans… (matthiasdiener)
a607feb address review comments (matthiasdiener)
ca2e444 bugfix arch SR support (matthiasdiener)
5a5803c use scale constants (matthiasdiener)
d36ccbd Merge remote-tracking branch 'origin/dev' into mdiener/fp4-cast-trans… (matthiasdiener)
fc5af65 simplify to use __hip_fp4x4_storage_t directly (matthiasdiener)
94a4e5e simplify storage for bit fiddling (matthiasdiener)
82af544 allow null amax in fallback kernel (matthiasdiener)
56fefaf minor cleanup (matthiasdiener)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,6 @@ | ||
| /************************************************************************* | ||
| * This file was modified for portability to AMDGPU | ||
| * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. | ||
| * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| * | ||
| * See LICENSE for license information. | ||
|
|
@@ -30,14 +32,29 @@ enum ActivationType { | |
| SReLU | ||
| }; | ||
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| static constexpr float E2M1_LUT[16] = { | ||
| 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f, | ||
| -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f, | ||
| }; | ||
| #endif | ||
|
|
||
| double2 cvt_fp4x2_to_double2(fp4e2m1x2 fp4_pair) { | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| uint8_t raw = *reinterpret_cast<uint8_t*>(&fp4_pair); | ||
| // Decode manually | ||
| float lo = E2M1_LUT[raw & 0xF]; | ||
| float hi = E2M1_LUT[(raw >> 4) & 0xF]; | ||
| return {static_cast<double>(lo), static_cast<double>(hi)}; | ||
| #else | ||
| const __half2_raw raw_truncated_to_fp4e2m1_pair = | ||
| __nv_cvt_fp4x2_to_halfraw2(*reinterpret_cast<__nv_fp4x2_storage_t*>(&fp4_pair), __NV_E2M1); | ||
|
|
||
| const __half2 truncated_to_fp4e2m1_pair(raw_truncated_to_fp4e2m1_pair); | ||
| const double truncated_to_fp4e2m1_x = static_cast<double>(truncated_to_fp4e2m1_pair.x); | ||
| const double truncated_to_fp4e2m1_y = static_cast<double>(truncated_to_fp4e2m1_pair.y); | ||
| return {truncated_to_fp4e2m1_x, truncated_to_fp4e2m1_y}; | ||
| #endif | ||
| } | ||
|
|
||
| template <typename InputType> | ||
|
|
@@ -567,7 +584,18 @@ void performTest(float (*OP)(const float), | |
| // Set 2nd stage NVFP4 scaling factor | ||
| output.set_scale(amax); | ||
|
|
||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| bool use_2d_quantization = false; | ||
| #else | ||
| // Test both 1D and 2D quantization paths on AMDGPU, | ||
| // as well as stochastic rounding. | ||
| hipDeviceProp_t prop; | ||
| hipGetDeviceProperties(&prop, 0); | ||
| const bool is_gfx950 = std::string(prop.gcnArchName).find("gfx950") != std::string::npos; | ||
|
Review comment (Collaborator): prop.major == 9 && prop.minor == 5
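The two detection styles in this thread (the substring search on `gcnArchName` used in the diff, and the `major`/`minor` comparison from the review comment) can be compared side by side in a small host-only sketch. `DeviceInfo` and the three function names are hypothetical stand-ins introduced here so the code builds without the HIP runtime; the sketch also assumes the runtime reports gfx950 as version 9.5, which is what the comment implies.

```cpp
#include <string>
#include <vector>

// Hypothetical stand-in for the hipDeviceProp_t fields discussed above
// (gcnArchName, major, minor), so the sketch needs no HIP headers.
struct DeviceInfo {
  std::string gcn_arch_name;  // may carry feature suffixes, e.g. "gfx950:sramecc+:xnack-"
  int major;
  int minor;
};

// Style used in the diff: substring search on the architecture name.
bool is_gfx950_by_name(const DeviceInfo& d) {
  return d.gcn_arch_name.find("gfx950") != std::string::npos;
}

// Style suggested in the review comment: compare the version numbers.
// Assumption: gfx950 is reported as major 9, minor 5.
bool is_gfx950_by_version(const DeviceInfo& d) {
  return d.major == 9 && d.minor == 5;
}

// Mirrors the outer loop bounds in the diff: stochastic rounding is only
// exercised when the device is gfx950; otherwise only `false` is tested.
std::vector<bool> stochastic_rounding_modes(bool is_gfx950) {
  return is_gfx950 ? std::vector<bool>{false, true}
                   : std::vector<bool>{false};
}
```

Under these assumptions both predicates agree for a gfx950 device; the version comparison avoids constructing a `std::string` and does not depend on the exact spelling of the architecture name.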
||
| for (bool use_stochastic_rounding : (is_gfx950 ? std::vector<bool>{false, true} | ||
| : std::vector<bool>{false})) { | ||
| for (bool use_2d_quantization : {false, true}) { | ||
| #endif | ||
|
|
||
| compute_ref<InputType>(OP, | ||
| input.rowwise_cpu_dptr<InputType>(), | ||
|
|
@@ -589,7 +617,11 @@ void performTest(float (*OP)(const float), | |
| rng_state.rowwise_cpu_dptr<int64_t>()[0] = 123; // rng_seed | ||
| rng_state.rowwise_cpu_dptr<int64_t>()[1] = 321; // rng_sequence | ||
| rng_state.from_cpu(); | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| quant_config.set_stochastic_rounding(use_stochastic_rounding); | ||
| #else | ||
| quant_config.set_stochastic_rounding(false); | ||
| #endif | ||
| quant_config.set_rng_state(rng_state.data()); | ||
|
|
||
| // Set 2D quantization based on compile-time flag | ||
|
|
@@ -631,15 +663,29 @@ void performTest(float (*OP)(const float), | |
| const fp8e4m3* ref_scales_t_ptr = ref_scales_t.get(); | ||
|
|
||
| size_t scale_mismatches_num = 0; | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| std::vector<size_t> mismatches_scales_indices; | ||
| #endif | ||
|
|
||
| compare_scaling_factors<fp8e4m3>("scales", output.rowwise_cpu_scale_inv_ptr<fp8e4m3>(), | ||
| ref_scales.get(), | ||
| unpadded_blocks_Y, unpadded_blocks_X, scales_stride, | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| mismatches_scales_indices, | ||
| #endif | ||
| scale_mismatches_num); | ||
|
|
||
| compare_scaling_factors<fp8e4m3>("scales_t", output.columnwise_cpu_scale_inv_ptr<fp8e4m3>(), | ||
| ref_scales_t.get(), | ||
| unpadded_blocks_Y_t, unpadded_blocks_X_t, scales_stride_t, | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| mismatches_scales_indices, | ||
| #endif | ||
| scale_mismatches_num); | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| } // for (bool use_2d_quantization : {false, true}) { | ||
| } // for (bool use_stochastic_rounding : (is_gfx950 ? std::vector<bool>{false, true} : std::vector<bool>{false})) { | ||
| #endif | ||
| } | ||
|
|
||
| std::vector<std::vector<size_t>> tensor_dims = { | ||
|
|
@@ -674,10 +720,12 @@ class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam | |
| transformer_engine::DType>> {}; | ||
|
|
||
| TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { | ||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| // Skip tests for pre-Blackwell architectures | ||
| if (getDeviceComputeCapability() < blackwellComputeCapability) { | ||
| GTEST_SKIP(); | ||
| } | ||
| #endif | ||
|
|
||
| using namespace transformer_engine; | ||
| using namespace test; | ||
|
|
||
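The AMD branch of `cvt_fp4x2_to_double2` in the first hunk decodes a packed pair of FP4 values through `E2M1_LUT`. The standalone sketch below reproduces that lookup on the host and cross-checks the table against a decode computed from the E2M1 bit fields (1 sign bit, 2 exponent bits with bias 1, 1 mantissa bit). The names `kE2M1Lut`, `decode_fp4x2`, `decode_e2m1_bits`, and `check_lut_matches_bit_decode` are hypothetical and not part of the PR; the table values are copied from the diff.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <utility>

// Same 16 values as E2M1_LUT in the diff, indexed by the 4-bit code.
static constexpr float kE2M1Lut[16] = {
    0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f,
};

// Hypothetical helper: decode one byte holding two FP4 values.
// As in the diff, the low nibble is returned first, the high nibble second.
std::pair<double, double> decode_fp4x2(std::uint8_t raw) {
  const float lo = kE2M1Lut[raw & 0xF];
  const float hi = kE2M1Lut[(raw >> 4) & 0xF];
  return {static_cast<double>(lo), static_cast<double>(hi)};
}

// Hypothetical helper: decode a 4-bit E2M1 code from its bit fields.
float decode_e2m1_bits(std::uint8_t code) {
  const int sign = (code >> 3) & 0x1;
  const int exponent = (code >> 1) & 0x3;
  const int mantissa = code & 0x1;
  // exponent == 0 is the subnormal range (0 or 0.5); otherwise the value
  // is (1 + mantissa / 2) * 2^(exponent - 1).
  const float magnitude =
      (exponent == 0) ? 0.5f * mantissa
                      : std::ldexp(1.0f + 0.5f * mantissa, exponent - 1);
  return sign ? -magnitude : magnitude;
}

// Verifies that every table entry matches the bit-field decode.
void check_lut_matches_bit_decode() {
  for (int code = 0; code < 16; ++code) {
    assert(kE2M1Lut[code] == decode_e2m1_bits(static_cast<std::uint8_t>(code)));
  }
}
```

For example, `decode_fp4x2(0x72)` yields `(1.0, 6.0)`: the low nibble `0x2` maps to 1.0 and the high nibble `0x7` maps to 6.0. A table lookup keeps the test helper independent of any device-side conversion intrinsic.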