Update AITER subcommit and refactor internal AITER/CK FA API usage#446
Update AITER subcommit and refactor internal AITER/CK FA API usage#446
Conversation
Micky774
left a comment
There was a problem hiding this comment.
A few comments to hopefully help reviewers.
There was a problem hiding this comment.
This helper script scans semi-hard-coded files wrt TE source-code in order to directly compare AITER's internal API and our attempt at utilizing it. This script is run during setup through setup.py
There was a problem hiding this comment.
You can put this in the comment of this file.
Also don't forget to add the copyright
| int ck_to_aiter_mask_type(mask_enum mask_type, ck_tile::index_t left, ck_tile::index_t right){ | ||
| if( | ||
| mask_type == mask_enum::no_mask || | ||
| mask_type == mask_enum::window_generic | ||
| ) return 0; | ||
| if(left == -1 && right == 0){ | ||
| return mask_type == mask_enum::mask_top_left ? 1 : 2; | ||
| } | ||
| return 3; | ||
| } |
There was a problem hiding this comment.
This is based on their op_tests/cpp/mha/* which isn't the most stable or reliable source, but it's the only concrete location where such a mapping is used or documented.
| void log_bwd_config(const char* func_name, const aiter::mha_bwd_args& fmha_args, bool ck_log_config){ | ||
| if (!ck_log_config) { | ||
| return; | ||
| } | ||
|
|
||
| auto log_value = [](const char* label, const auto& value) { | ||
| std::cout << label << ": " << value << "\n"; | ||
| }; |
There was a problem hiding this comment.
This merely standardizes the logging to make it a bit easier to parse through it at a glance while guaranteeing uniformity. This is similarly implemented across both files. Note the signature has been reduced to account for only what isn't stored in fmha_args.
There was a problem hiding this comment.
Same here. Put those helpful description into the comments
| @@ -499,10 +508,13 @@ hipError_t ck_attn_bwd( | |||
| bool uses_bwd_v3, | |||
| bool is_v3_atomic_fp32, | |||
| int how_v3_bf16_cvt, | |||
| bool is_group_mode, | |||
| const char* func_name, | |||
| bool ck_log_config, | |||
| hipStream_t stream){ | |||
There was a problem hiding this comment.
This abstracted implementation function has a signature that is a superset of the other higher-level API functions in the file, so that both can route directly to this function without affecting the API outside of this file.
| std::pair<const void*, const void*>{philox_seed_ptr, philox_offset_ptr}}; | ||
| }(); | ||
| aiter::mha_bwd_args fmha_args{}; | ||
| fmha_args.mask_type = ck_to_aiter_mask_type(mask_type, left, right); |
There was a problem hiding this comment.
This is the AITER mask type, despite the same argument referring to the CK mask type in the FWD pass
There was a problem hiding this comment.
Note that there is a PR for removing this argument which is still pending.
| } | ||
| return hipSuccess; | ||
| } | ||
| hipError_t ck_attn_bwd( |
There was a problem hiding this comment.
Both ck_attn_{varlen_}bwd are wrappers around _ck_attn_bwd_impl + TE-side post-processing kernels. Those kernels can probably be refactored similarly but that's a bit outside the scope here and of dubious value.
setup.py
Outdated
| try: | ||
| subprocess.run( | ||
| sys.executable + " tools/check_aiter_mha_args_usage.py --mode both", | ||
| shell=True, check=True | ||
| ) | ||
| except subprocess.CalledProcessError: | ||
| print("Error checking AITER mha_args usage.") | ||
| sys.exit(1) |
There was a problem hiding this comment.
Explicitly checks the AITER API usage
There was a problem hiding this comment.
Actually you can put your comment in PR to the comments in the source codes
transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp
Outdated
Show resolved
Hide resolved
transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp
Outdated
Show resolved
Hide resolved
| dladdr((void*)set_aiter_asm_dir, &info); | ||
| auto install_lib_path = std::filesystem::path(info.dli_fname).parent_path() / "aiter"; | ||
| const char* log_ck_config = std::getenv("NVTE_LOG_CK_CONFIG"); | ||
| auto editable_install_path = std::filesystem::path(info.dli_fname).parent_path().parent_path().parent_path() / "3rdparty" / "aiter" / "hsa"; |
There was a problem hiding this comment.
I recall we also copy hsa in editable installation
There was a problem hiding this comment.
I don't think that we do
There was a problem hiding this comment.
Maybe there's a difference in building from source vs downloading the pre-built AITER lib? My current editable install doesn't have the copied aiter directory under transformer_engine/lib.
There was a problem hiding this comment.
Okay, I think now I understand the root cause. Your aiter commit update need to bring back the hsa dir copy which was removed in Dragan's PR: https://github.com/ROCm/TransformerEngine/pull/402/changes#diff-e0641ac80f912d730377ecbf45edc9573b4ad4a07ee875e474a49f3d57605f1f
Otherwise, the non-editable installation will not have hsa dir in transformer_engine installation dir either
There was a problem hiding this comment.
With the hsa dir copy functionality brought back, we don't need to distinguish the install and editable install path, I believe?
|
|
||
| rm -rf "${AITER_DIR}/aiter/jit/build" | ||
| AITER_LOG_MORE=1 \ | ||
| AITER_LOG_MORE=0 \ |
There was a problem hiding this comment.
Is it possible to link pip install -v to AITER_LOG_MORE?
setup.py
Outdated
| try: | ||
| subprocess.run( | ||
| sys.executable + " tools/check_aiter_mha_args_usage.py --mode both", | ||
| shell=True, check=True | ||
| ) | ||
| except subprocess.CalledProcessError: | ||
| print("Error checking AITER mha_args usage.") | ||
| sys.exit(1) |
There was a problem hiding this comment.
Actually you can put your comment in PR to the comments in the source codes
transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
You can put this in the comment of this file.
Also don't forget to add the copyright
| void log_bwd_config(const char* func_name, const aiter::mha_bwd_args& fmha_args, bool ck_log_config){ | ||
| if (!ck_log_config) { | ||
| return; | ||
| } | ||
|
|
||
| auto log_value = [](const char* label, const auto& value) { | ||
| std::cout << label << ": " << value << "\n"; | ||
| }; |
There was a problem hiding this comment.
Same here. Put those helpful description into the comments
| if (!is_group_mode) { | ||
| std::tie(bias_type, bias_shape) = get_ck_bias_type_shape(attn_bias_type, b, h, bias_b, bias_h); | ||
| } | ||
| const ck_tile::index_t batch_stride_bias = [&]() -> ck_tile::index_t { |
There was a problem hiding this comment.
This overlaps with the lambda defined in fwd:
Try to consolidate them into util.cpp?
Also it's better to avoid using lambda with catching every reference. Previously I used it because I was copying from ck example.
| bias_enum bias_type = bias_enum::no_bias; | ||
| BiasShape bias_shape = BiasShape::k11SS; | ||
| if(is_group_mode){ | ||
| seqstart_q_ptr = cu_seqlen_q_padded_ptr==nullptr? cu_seqlen_q_ptr: cu_seqlen_q_padded_ptr; |
There was a problem hiding this comment.
Make it consistent with bwd:
| const ck_tile::index_t stride_q = stride_s_q; | ||
| const ck_tile::index_t stride_k = stride_s_k; | ||
| const ck_tile::index_t stride_v = stride_s_v; | ||
| // bias of shape (bias_b, bias_h, s_q, s_kv) |
There was a problem hiding this comment.
Can you keep this kind of comments so that later we can understand why we set stride_bias/dbias to max_seqlen_k
| const ck_tile::index_t stride_dv_expanded = stride_s_dv_expanded; | ||
| const ck_tile::index_t stride_dq_acc = d_qk; //dq_acc of shape (nsplits, B, H, S, D) | ||
| // dbias is of the same shape as bias | ||
| // but ck only take dbias with BHSS |
There was a problem hiding this comment.
Also this comment is crucial for dbias reduction
| nullptr, nullptr, | ||
| nullptr, nullptr, |
There was a problem hiding this comment.
Add some comments so code readers can know quickly what's those 4 nullptrs are for?
| dladdr((void*)set_aiter_asm_dir, &info); | ||
| auto install_lib_path = std::filesystem::path(info.dli_fname).parent_path() / "aiter"; | ||
| const char* log_ck_config = std::getenv("NVTE_LOG_CK_CONFIG"); | ||
| auto editable_install_path = std::filesystem::path(info.dli_fname).parent_path().parent_path().parent_path() / "3rdparty" / "aiter" / "hsa"; |
ab79b40 to
78f1d69
Compare
|
level 3 run for tracking: https://github.com/ROCm/TransformerEngine/actions/runs/22781278349 |
|
Level 3 run for tracking: https://github.com/ROCm/TransformerEngine/actions/runs/23265047818 (updated for commit |
|
Okay so after patching the race condition w/ Xinya's mutex solution, another bug was revealed through the JAX MGPU tests which is actually a hip launch error due to device ordinal mismatch. This is essentially because the kernel map is keyed by the kernel name only, NOT by device ID, meaning that multiple device threads will map different kernels to the same key, and when loading we end up with the wrong kernels on the wrong devices. I've added a patch to address this issue, with a corresponding AITER PR here: ROCm/aiter#2401 |

Description
Updates AITER subcommit as well as refactors our internal usage of the API for greater clarity and explicitness. TODO: Update AITER commit upon merger of ROCm/aiter#2055
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: