
Sage attn integration wip#1

Open
weiyanlin117 wants to merge 21 commits into unstable from sage_attn_integration

Conversation

@weiyanlin117

sanity test passed

```c
ccv_nnc_tensor_free(bt);
}

TEST_CASE("direct sage vs flash attention NHD performance test")
```

Added a perf test:

```shell
make -C test/int/nnc cublas.tests && test/int/nnc/cublas.tests " performance test"
```

=== Direct SageAttention NHD Performance Test ===
Testing performance with memory reuse across 1000 runs

[Trial 0] Config: B=32, R=160, C=128, Hq=8, Hk=8, D=64, causal=0
Full tensor comparison (1342177280 total_Flops):
GPU Execution Time: 0.0175 ms (average of 10000 runs with memory reuse)

[Trial 1] Config: B=12, R=256, C=128, Hq=8, Hk=8, D=128, causal=0
Full tensor comparison (1610612736 total_Flops):
GPU Execution Time: 0.0191 ms (average of 10000 runs with memory reuse)

[Trial 2] Config: B=16, R=128, C=128, Hq=8, Hk=8, D=128, causal=0
Full tensor comparison (1073741824 total_Flops):
GPU Execution Time: 0.0139 ms (average of 10000 runs with memory reuse)

[Trial 3] Config: B=1, R=77, C=128, Hq=8, Hk=8, D=64, causal=0
Full tensor comparison (20185088 total_Flops):
GPU Execution Time: 0.0073 ms (average of 10000 runs with memory reuse)

[Trial 4] Config: B=2, R=77, C=128, Hq=8, Hk=2, D=64, causal=0
Full tensor comparison (40370176 total_Flops):
GPU Execution Time: 0.0073 ms (average of 10000 runs with memory reuse)

[Trial 5] Config: B=1, R=5, C=5, Hq=32, Hk=8, D=128, causal=0
Full tensor comparison (409600 total_Flops):
GPU Execution Time: 0.0071 ms (average of 10000 runs with memory reuse)

=== Direct SageAttention NHD Performance Summary ===
Individual trial times (ms):
Trial 0: 0.0175 ms
Trial 1: 0.0191 ms
Trial 2: 0.0139 ms
Trial 3: 0.0073 ms
Trial 4: 0.0073 ms
Trial 5: 0.0071 ms
Fastest trial: 5 (0.0071 ms)
Slowest trial: 1 (0.0191 ms)

FlashAttention baseline, same configurations:

[Trial 0] Config: B=32, R=160, C=128, Hq=8, Hk=8, D=64, causal=0
Full tensor comparison (1342177280 total_flops):
GPU Execution Time: 0.0945 ms (average of 10000 runs with memory reuse)

[Trial 1] Config: B=12, R=256, C=128, Hq=8, Hk=8, D=128, causal=0
Full tensor comparison (1610612736 total_flops):
GPU Execution Time: 0.0932 ms (average of 10000 runs with memory reuse)

[Trial 2] Config: B=16, R=128, C=128, Hq=8, Hk=8, D=128, causal=0
Full tensor comparison (1073741824 total_flops):
GPU Execution Time: 0.0870 ms (average of 10000 runs with memory reuse)

[Trial 3] Config: B=1, R=77, C=128, Hq=8, Hk=8, D=64, causal=0
Full tensor comparison (20185088 total_flops):
GPU Execution Time: 0.0116 ms (average of 10000 runs with memory reuse)

[Trial 4] Config: B=2, R=77, C=128, Hq=8, Hk=2, D=64, causal=0
Full tensor comparison (40370176 total_flops):
GPU Execution Time: 0.0117 ms (average of 10000 runs with memory reuse)

[Trial 5] Config: B=1, R=5, C=5, Hq=32, Hk=8, D=128, causal=0
Full tensor comparison (409600 total_flops):
GPU Execution Time: 0.0094 ms (average of 10000 runs with memory reuse)

=== FlashAttention Performance Summary ===
Individual trial times (ms):
Trial 0: 0.0945 ms
Trial 1: 0.0932 ms
Trial 2: 0.0870 ms
Trial 3: 0.0116 ms
Trial 4: 0.0117 ms
Trial 5: 0.0094 ms
Fastest trial: 5 (0.0094 ms)
Slowest trial: 0 (0.0945 ms)

```c
2, // qk_quant_gran: 2=per_warp
scale, // sm_scale
0, // return_lse: false
1, // pv_accum_dtype: 2=FP32
```

Using FP32 (2) for pv_accum_dtype is about 10% slower:

[1/1] [RUN] direct sage vs flash attention NHD performance test ...
=== Direct SageAttention NHD Performance Test ===
Testing performance with memory reuse across 1000 runs

[Trial 0] Config: B=32, R=160, C=128, Hq=8, Hk=8, D=64, causal=0
Full tensor comparison (1342177280 total_Flops):
GPU Execution Time: 0.0199 ms (average of 10000 runs with memory reuse)

[Trial 1] Config: B=12, R=256, C=128, Hq=8, Hk=8, D=128, causal=0
Full tensor comparison (1610612736 total_Flops):
GPU Execution Time: 0.0197 ms (average of 10000 runs with memory reuse)

[Trial 2] Config: B=16, R=128, C=128, Hq=8, Hk=8, D=128, causal=0
Full tensor comparison (1073741824 total_Flops):
GPU Execution Time: 0.0157 ms (average of 10000 runs with memory reuse)

[Trial 3] Config: B=1, R=77, C=128, Hq=8, Hk=8, D=64, causal=0
Full tensor comparison (20185088 total_Flops):
GPU Execution Time: 0.0079 ms (average of 10000 runs with memory reuse)

[Trial 4] Config: B=2, R=77, C=128, Hq=8, Hk=2, D=64, causal=0
Full tensor comparison (40370176 total_Flops):
GPU Execution Time: 0.0079 ms (average of 10000 runs with memory reuse)

[Trial 5] Config: B=1, R=5, C=5, Hq=32, Hk=8, D=128, causal=0
Full tensor comparison (409600 total_Flops):
GPU Execution Time: 0.0079 ms (average of 10000 runs with memory reuse)

=== Direct SageAttention NHD Performance Summary ===
Individual trial times (ms):
Trial 0: 0.0199 ms
Trial 1: 0.0197 ms
Trial 2: 0.0157 ms
Trial 3: 0.0079 ms
Trial 4: 0.0079 ms
Trial 5: 0.0079 ms
Fastest trial: 5 (0.0079 ms)
Slowest trial: 0 (0.0199 ms)

FP16_FP32 mixed:
=== Direct SageAttention NHD Performance Summary ===
Individual trial times (ms):
Trial 0: 0.0175 ms
Trial 1: 0.0193 ms
Trial 2: 0.0140 ms
Trial 3: 0.0072 ms
Trial 4: 0.0072 ms
Trial 5: 0.0070 ms
Fastest trial: 5 (0.0070 ms)
Slowest trial: 1 (0.0193 ms)


    - "fp32": PV accumulation is done in FP32. This is the most accurate option but may be slower than "fp16" due to CUDA core overhead.
    - "fp16+fp32": PV accumulation is done in FP16, but added to an FP32 buffer every few iterations. This offers a balance between speed and accuracy.

@weiyanlin117

sage attn CMD without memory cache
(timed only from immediately before to immediately after the CMD is triggered)

[Trial 0] Config: B=32, R=160, C=128, Hq=8, Hk=8, D=64, causal=0
GPU Execution Time: 0.4437 ms (average of 1000 runs)

[Trial 1] Config: B=12, R=256, C=128, Hq=8, Hk=8, D=128, causal=0
GPU Execution Time: 0.2966 ms (average of 1000 runs)

sage attn CMD with memory cache

[Trial 0] Config: B=32, R=160, C=128, Hq=8, Hk=8, D=64, causal=0
GPU Execution Time: 0.2108 ms (average of 1000 runs)

[Trial 1] Config: B=12, R=256, C=128, Hq=8, Hk=8, D=128, causal=0
GPU Execution Time: 0.1605 ms (average of 1000 runs)
