Merged
23 commits
b64806c
wip(tf32): add TF32 TensorCore GEMM kernel (correctness bug)
m96-chan Dec 13, 2025
914b58f
fix(tf32): WMMA store_matrix_sync alignment bug for N % 8 != 0
m96-chan Dec 13, 2025
baa31de
docs(CLAUDE.md): add kernel development workflow
m96-chan Dec 13, 2025
6e06319
wip(tf32): 2-stage cp.async pipeline with transposed B (correctness bug)
m96-chan Dec 13, 2025
2eb35cc
wip(tf32): BK=32 kernel with manual epilogue fix
m96-chan Dec 13, 2025
b10fa7e
wip(tf32): BK=16 with Bs[BN][BK] col-major layout
m96-chan Dec 13, 2025
8ee0e82
wip(tf32): user rewrite v3 - 51KB smem, 0 spills
m96-chan Dec 13, 2025
0be67dc
wip(tf32): G3 kernel - 40KB smem, BK=16
m96-chan Dec 13, 2025
10967f8
wip(tf32): cp.async for both A and B, row_major fragments (6 TFLOPS)
m96-chan Dec 13, 2025
0b1345e
wip(tf32): revert to simplified kernel structure
m96-chan Dec 13, 2025
a2a70dd
wip(tf32): add launch function (correctness broken)
m96-chan Dec 13, 2025
007f732
wip(tf32): restore 32 TFLOPS kernel (correctness bug)
m96-chan Dec 13, 2025
b256ecb
wip(tf32): 44 TFLOPS kernel (correctness still broken)
m96-chan Dec 13, 2025
2fe874a
wip(tf32): WMMA row_major×row_major verified working
m96-chan Dec 13, 2025
20b78b1
docs(tf32): add WMMA 16x16x8 fragment mapping from dump_fragments
m96-chan Dec 13, 2025
7d01fb9
fix(tf32): correct C fragment mapping for PTX mma.sync m16n8k8
m96-chan Dec 13, 2025
1d69de4
feat(tf32): correct cp.async pipeline achieving 27 TFLOPS
m96-chan Dec 13, 2025
8c79f98
docs(readme): add v0.2.3 TF32 benchmark comparison table
m96-chan Dec 13, 2025
db89f94
docs(readme): fix benchmark numbers with actual cuBLAS data
m96-chan Dec 13, 2025
da41bf7
perf(tf32): optimize kernel with A fragment hoisting (+1.35 TFLOPS)
m96-chan Dec 14, 2025
d00b446
docs: add TF32 kernel optimization summary
m96-chan Dec 14, 2025
ea3700f
docs: remove optimization summary md (moving to Issue #41)
m96-chan Dec 14, 2025
a1c8f3c
fix(lint): remove extraneous f-string prefixes
m96-chan Dec 14, 2025
274 changes: 271 additions & 3 deletions CLAUDE.md
@@ -385,11 +385,15 @@ mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32

### 6. Benchmark Expectations (Target)

| GPU | FP32 naive-opt | FP32 MMA | Notes |
|-----|---------------|----------|-------|
| RTX 3090 | 2.1–2.3 TFLOPS | 9+ TFLOPS | TF32 or FP16 |
| GPU | FP32 naive-opt | TF32 TensorCore | Notes |
|-----|---------------|-----------------|-------|
| RTX 3090 Ti | 18 TFLOPS | 27+ TFLOPS | Achieved with cp.async pipeline |
| A100 | 5.5+ TFLOPS | 156 TFLOPS | tensor cores |

**Achieved Results (v0.2.3)**:
- TF32 on RTX 3090 Ti: **27.38 TFLOPS** (8192×8192×8192)
- Correctness: ~3-5% relative error (expected for TF32 precision)
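
For reference, GEMM throughput figures such as the one above are conventionally derived by counting 2·M·N·K floating-point operations and dividing by the elapsed time. The sketch below assumes that convention (the source does not state how its numbers were measured), and `gemm_tflops` is a hypothetical helper name, not part of the project.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: throughput of an M x N x K GEMM in TFLOPS,
// counting one multiply and one add per inner-product term (2*M*N*K FLOPs).
double gemm_tflops(std::int64_t m, std::int64_t n, std::int64_t k, double seconds) {
    assert(seconds > 0.0);
    const double flops = 2.0 * static_cast<double>(m) *
                         static_cast<double>(n) * static_cast<double>(k);
    return flops / seconds / 1e12;
}
```

Under this convention, 27.38 TFLOPS at 8192×8192×8192 would correspond to roughly 40 ms per GEMM call.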

If performance regresses from naive baseline, re-profile.

### 7. CMake Compilation Flags
@@ -402,6 +406,89 @@ If performance regresses from naive baseline, re-profile.

For portability: allow runtime switch to sm_89, sm_90.

### 8. PTX mma.sync Fragment Mapping (VERIFIED)

**CRITICAL**: PTX inline assembly `mma.sync` has DIFFERENT fragment layouts than WMMA API.
The following mappings were verified empirically using `dump_c_fragment.cu`.

#### PTX `mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32`

Each thread in a warp (lane 0-31) holds:
- **A fragment**: 4 registers (16×8 matrix, row-major)
- **B fragment**: 2 registers (8×8 matrix, col-major)
- **C fragment**: 4 registers (16×8 matrix)

```
A fragment (16×8):
a[0] = A[lane/4][lane%4] // rows 0-7, cols 0-3
a[1] = A[lane/4 + 8][lane%4] // rows 8-15, cols 0-3
a[2] = A[lane/4][lane%4 + 4] // rows 0-7, cols 4-7
a[3] = A[lane/4 + 8][lane%4 + 4] // rows 8-15, cols 4-7

B fragment (8×8):
b[0] = B[lane%4][lane/4] // rows 0-3, cols 0-7
b[1] = B[lane%4 + 4][lane/4] // rows 4-7, cols 0-7

C fragment (16×8) - KEY DIFFERENCE FROM WMMA:
c[0] = C[lane/4][(lane%4)*2] // rows 0-7, cols 0,2,4,6
c[1] = C[lane/4][(lane%4)*2 + 1] // rows 0-7, cols 1,3,5,7
c[2] = C[lane/4 + 8][(lane%4)*2] // rows 8-15, cols 0,2,4,6
c[3] = C[lane/4 + 8][(lane%4)*2 + 1] // rows 8-15, cols 1,3,5,7
```

#### Common Mistakes

1. **C fragment column stride**: PTX uses `(lane%4)*2` (stride 2), NOT `lane%4` (stride 1)
2. **C fragment pairs**: c[0],c[1] are adjacent columns; c[2],c[3] are +8 rows
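
The stride mistake can be caught on the host before touching the kernel. The following is a minimal sketch, assuming the C fragment mapping listed above. It counts how many (lane, register) pairs land on each element of the 16×8 tile. The function names are hypothetical and are not taken from `dump_c_fragment.cu`.

```cpp
#include <array>
#include <cassert>

// Count how many (lane, register) pairs map to each element of the 16x8 C tile.
// col_stride = 2 is the mapping listed above; col_stride = 1 is the common mistake.
std::array<int, 16 * 8> c_fragment_coverage(int col_stride) {
    assert(col_stride == 1 || col_stride == 2);
    std::array<int, 16 * 8> hits{};  // zero-initialized
    for (int lane = 0; lane < 32; ++lane) {
        for (int reg = 0; reg < 4; ++reg) {
            const int row = lane / 4 + (reg >= 2 ? 8 : 0);        // c[2], c[3] are +8 rows
            const int col = (lane % 4) * col_stride + (reg & 1);  // c[1], c[3] are +1 column
            ++hits[row * 8 + col];
        }
    }
    return hits;
}

// True when every element of the tile is owned by exactly one (lane, register) pair.
bool covers_exactly_once(const std::array<int, 16 * 8>& hits) {
    for (int h : hits) {
        if (h != 1) return false;
    }
    return true;
}
```

With stride 2 every element is owned exactly once. With stride 1, columns 5-7 are never written and columns 1-3 are written twice, which is one way such a bug shows up as wrong results rather than a crash.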

#### WMMA API vs PTX Inline ASM

| Aspect | WMMA API | PTX mma.sync |
|--------|----------|--------------|
| Fragment types | `wmma::fragment<>` | Raw registers |
| Layout | Opaque (compiler-managed) | Must match PTX spec exactly |
| Flexibility | Limited shapes | Full control |
| Performance | Good | Potentially better |

**Recommendation**: Use PTX for maximum performance, but VERIFY fragment mappings with test code.

### 9. cp.async Double-Buffering Pipeline (CRITICAL)

**Common Bug**: Prefetching into the wrong stage.

#### WRONG (causes correctness bug):
```cpp
// Prefetch kt+2 into stage (kt+2)&1: WRONG!
// On kt=0, this prefetches into stage 0 while READING from stage 0
for (int kt = 0; kt < num_k_tiles; ++kt) {
    int curr = kt & 1;
    if (kt + 2 < num_k_tiles) {
        load_async((kt+2) & 1, kt + 2);  // BUG: overwrites current!
    }
    process(curr);
}
```

#### CORRECT (simple double-buffering):
```cpp
// Prefetch kt+1 into the OTHER stage
load_async(0, 0);
cp_async_wait_0();

for (int kt = 0; kt < num_k_tiles; ++kt) {
    int curr = kt & 1;
    int next = curr ^ 1;  // OTHER stage

    if (kt + 1 < num_k_tiles) {
        load_async(next, kt + 1);  // Prefetch into OTHER buffer
    }
    process(curr);  // Read from current buffer
    cp_async_wait_0();
}
```

**Key Insight**: Always prefetch into the stage you're NOT currently reading from.
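
The difference between the two loops can be checked without a GPU. The sketch below is a host-only model of the stage indexing, not kernel code. `count_stage_conflicts` and its `prefetch_distance` parameter are hypothetical names introduced here for illustration.

```cpp
#include <cassert>

// Host-side model of the two-stage pipeline indexing.
// prefetch_distance = 1 models the CORRECT loop (prefetch tile kt+1),
// prefetch_distance = 2 models the WRONG loop (prefetch tile kt+2).
// Returns the number of iterations in which the prefetch targets the
// same stage that the iteration is reading from.
int count_stage_conflicts(int num_k_tiles, int prefetch_distance) {
    assert(num_k_tiles > 0);
    assert(prefetch_distance == 1 || prefetch_distance == 2);
    int conflicts = 0;
    for (int kt = 0; kt < num_k_tiles; ++kt) {
        const int curr = kt & 1;  // stage being read
        const int target_tile = kt + prefetch_distance;
        if (target_tile < num_k_tiles) {
            const int target_stage = target_tile & 1;  // stage being written
            if (target_stage == curr) ++conflicts;
        }
    }
    return conflicts;
}
```

With two stages, a prefetch distance of 2 always wraps back onto the stage being read, so every iteration that issues a prefetch conflicts. A distance of 1 never does.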

---

## Build System
@@ -453,3 +540,184 @@ For portability: allow runtime switch to sm_89, sm_90.
### Python Components (Orchestration Only)
8. Python API wrappers for Rust scheduler (thin wrappers only)
9. Python API wrappers for Rust memory pool (thin wrappers only)

---

## Kernel Development Workflow (MANDATORY)

When developing kernels, you **must** follow the workflow below:

### 1. Development Cycle

```
Edit → Build → Validate → Benchmark → Commit
```

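**Always commit once Validation and Benchmark are complete, whatever the result.**

### 2. Commit Rules

- Commit once Validation/Benchmark finishes, **regardless of the result**
- Always include the benchmark results in the commit message

### 3. Commit Message Format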

```
wip(tf32): <summary of changes>

Benchmark results (RTX 3090 Ti):
- 2048x2048: XX.XX TFLOPS
- 4096x4096: XX.XX TFLOPS
- 8192x8192: XX.XX TFLOPS

Correctness: <PASS/FAIL>
```

### 4. Rationale

- Prevents losing the ability to return to a version that was fast
- Makes performance changes trackable
- Preserves the history of trial and error

---

## Commit Enforcement Rules (ABSOLUTE)

YOU MUST perform a git commit immediately under ANY of the following conditions:

### 1. Benchmark Improvement

If benchmark results improve in ANY matrix size:
- 2048, 4096, or 8192 shows higher TFLOPS than all previous runs
- Improvement = ANY positive increase (even +0.01 TFLOPS)

### 2. Correctness Achievement

If correctness becomes PASS for all tested sizes:
- relative error < 1e-3 for all matrices

### 3. After EVERY Benchmark Execution

- EVEN IF results are worse
- EVEN IF no improvement is observed
- You MUST create a commit with message: `bench: results logged (no improvement)`

### 4. Commit Before Proceeding

- You MUST NOT proceed to next kernel edit UNTIL the commit is complete

### 5. Never Overwrite Without Commit

- You MUST NEVER overwrite a working kernel without committing it first

### 6. Revert on Regression

If performance or correctness DEGRADES:
- You MUST revert to the previous commit BEFORE continuing

**These rules are absolute. No exceptions.**

---

## TF32 TensorCore GEMM Development Notes

### WMMA vs PTX mma.sync

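**Key findings (2024-12):**

1. **WMMA API** (`nvcuda::wmma`) is verified working
- The `row_major` A + `row_major` B combination works correctly
- `row_major` A + `col_major` B **fails because the memory layout is interpreted differently**

2. The correct mapping for **PTX mma.sync** is still being determined
- The m16n8k8 fragment layout is complex
- The actual mapping can be checked with WMMA's `debug_dump_fragments`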

### Verified Working Kernel

```cpp
// WMMA row_major × row_major (PASS)
fragment<matrix_a, 16, 16, 8, precision::tf32, row_major> a_frag;
fragment<matrix_b, 16, 16, 8, precision::tf32, row_major> b_frag;
fragment<accumulator, 16, 16, 8, float> c_frag;

load_matrix_sync(a_frag, A + k, K); // ldA = K
load_matrix_sync(b_frag, B + k * N, N); // ldB = N (row-major storage)
mma_sync(c_frag, a_frag, b_frag, c_frag);
store_matrix_sync(C, c_frag, N, mem_row_major);
```

### Test Results (WMMA row_row)

| M | N | K | max_err | rel_err | Status |
|---|---|---|---------|---------|--------|
| 16 | 16 | 8 | 0.0055 | 0.05% | PASS |
| 16 | 16 | 16 | 0.0089 | 0.07% | PASS |
| 16 | 16 | 32 | 0.0094 | 0.06% | PASS |
| 16 | 16 | 64 | 0.0205 | 0.10% | PASS |
| 16 | 16 | 128 | 0.0247 | 0.08% | PASS |
| 16 | 16 | 256 | 0.0373 | 0.08% | PASS |
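
The table does not define `max_err` and `rel_err`. One plausible reading, sketched below purely as an assumption, takes `max_err` as the largest absolute difference against a reference result and `rel_err` as that value divided by the largest reference magnitude. `ErrorStats` and `compare_to_reference` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

struct ErrorStats {
    double max_err;  // largest absolute difference
    double rel_err;  // max_err divided by the largest reference magnitude
};

// Assumed metric definitions; the source table does not specify them.
ErrorStats compare_to_reference(const std::vector<float>& out,
                                const std::vector<float>& ref) {
    assert(out.size() == ref.size() && !ref.empty());
    double max_err = 0.0;
    double max_ref = 0.0;
    for (std::size_t i = 0; i < ref.size(); ++i) {
        const double o = static_cast<double>(out[i]);
        const double r = static_cast<double>(ref[i]);
        max_err = std::max(max_err, std::fabs(o - r));
        max_ref = std::max(max_ref, std::fabs(r));
    }
    const double rel = (max_ref > 0.0) ? max_err / max_ref : 0.0;
    return {max_err, rel};
}
```

Other definitions (per-element relative error, Frobenius-norm ratio) would give different numbers, so any PASS threshold should be read together with the metric actually used.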

### WMMA 16×16×8 Fragment Mapping (Measured)

Results measured with `dump_fragments.cu`:

#### A fragment (16×8 matrix_a, row_major)
```cpp
// Thread t (0-31):
int a_row = t / 4; // 0-7
int a_col = t % 4; // 0-3

a[0] = A[a_row][a_col] // rows 0-7, cols 0-3
a[1] = A[a_row + 8][a_col] // rows 8-15, cols 0-3
a[2] = A[a_row][a_col + 4] // rows 0-7, cols 4-7
a[3] = A[a_row + 8][a_col + 4] // rows 8-15, cols 4-7
```

#### B fragment (8×16 matrix_b, row_major)
```cpp
// Thread t (0-31):
int b_row = t % 4; // 0-3
int b_col = t / 4; // 0-7

b[0] = B[b_row][b_col] // rows 0-3, cols 0-7
b[1] = B[b_row + 4][b_col] // rows 4-7, cols 0-7
b[2] = B[b_row][b_col + 8] // rows 0-3, cols 8-15
b[3] = B[b_row + 4][b_col + 8] // rows 4-7, cols 8-15
```

#### Size Differences
| API | A | B | C |
|-----|---|---|---|
| WMMA 16×16×8 | 16×8 | 8×16 | 16×16 |
| PTX m16n8k8 | 16×8 | 8×8 | 16×8 |

PTX m16n8k8 uses only the **left half of WMMA's B/C** (cols 0-7).

#### C Fragment Mapping (measured: dump_c_fragment.cu)
```cpp
int c_row = t / 4; // 0-7
int c_col = (t % 4) * 2; // 0, 2, 4, 6
c[0] = C[c_row][c_col] // rows 0-7, cols even
c[1] = C[c_row][c_col + 1] // rows 0-7, cols odd
c[2] = C[c_row + 8][c_col] // rows 8-15, cols even
c[3] = C[c_row + 8][c_col + 1] // rows 8-15, cols odd
```
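
Taking the A, B (cols 0-7), and C mappings listed in this section as given for PTX m16n8k8, the sketch below emulates one tile product on the host using only per-lane registers. It is an illustration of how the forward and inverse mappings fit together, not the project's kernel. `LaneFragments`, `load_fragments`, and `emulate_mma` are hypothetical names.

```cpp
#include <cassert>

// Per-lane registers for one m16n8k8 tile, following the mappings listed above.
struct LaneFragments {
    float a[4];
    float b[2];
    float c[4];
};

// Scatter row-major A (16x8) and B (8x8, indexed B[k][n]) into the 32 lanes.
void load_fragments(const float A[16][8], const float B[8][8], LaneFragments frag[32]) {
    for (int lane = 0; lane < 32; ++lane) {
        const int g = lane / 4;  // group id: 0-7
        const int t = lane % 4;  // thread within group: 0-3
        frag[lane].a[0] = A[g][t];
        frag[lane].a[1] = A[g + 8][t];
        frag[lane].a[2] = A[g][t + 4];
        frag[lane].a[3] = A[g + 8][t + 4];
        frag[lane].b[0] = B[t][g];
        frag[lane].b[1] = B[t + 4][g];
    }
}

// Emulate C = A * B, reading operands only from lane registers.
void emulate_mma(LaneFragments frag[32]) {
    for (int lane = 0; lane < 32; ++lane) {
        for (int reg = 0; reg < 4; ++reg) {
            const int row = lane / 4 + (reg >= 2 ? 8 : 0);
            const int col = (lane % 4) * 2 + (reg & 1);
            float acc = 0.0f;
            for (int k = 0; k < 8; ++k) {
                // Inverse mappings: which lane/register holds A[row][k] and B[k][col].
                const float a = frag[(row % 8) * 4 + (k % 4)].a[(row / 8) + 2 * (k / 4)];
                const float b = frag[col * 4 + (k % 4)].b[k / 4];
                acc += a * b;
            }
            frag[lane].c[reg] = acc;
        }
    }
}
```

Comparing each `c[reg]` against a plain triple-loop reference at the same `(row, col)` gives a quick consistency check of the three mappings before debugging on the device.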

### Correctness Tests (after C fragment fix) - All PASS
- 256³ to 4096³: rel_err ≈ 8e-4 (0.08%)
- Determinism (100 runs): PASS

### Next Steps

1. ✅ Confirmed the correct WMMA fragment mapping with `dump_fragments`
2. ✅ Confirmed and fixed the C fragment mapping with `dump_c_fragment`
3. ✅ All correctness tests PASS
4. Performance optimization (currently 11-18 TFLOPS → target 22-35 TFLOPS)

### File Layout

- `native/ops/matmul_f32_tf32.cuh` - TF32 kernel
- `native/ops/basic.cu` - dispatch logic (lines 848-854)
- `dump_fragments.cu` - for checking fragment mappings
- Enabled with the environment variable `PYGPUKIT_ALLOW_TF32=1`
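
The dispatch code in `native/ops/basic.cu` is not shown in this diff. As a rough sketch of how an opt-in environment variable of this kind might be read, assuming the flag is enabled only when set to exactly `1`, consider the following. `tf32_allowed` is a hypothetical name.

```cpp
#include <cstdlib>
#include <cstring>

// Hypothetical check: true only when PYGPUKIT_ALLOW_TF32 is set to exactly "1".
// The real dispatch logic may parse the variable differently.
bool tf32_allowed() {
    const char* value = std::getenv("PYGPUKIT_ALLOW_TF32");
    return value != nullptr && std::strcmp(value, "1") == 0;
}
```

Keeping TF32 opt-in this way means the default FP32 path stays numerically unchanged unless the user explicitly accepts reduced precision.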
45 changes: 31 additions & 14 deletions README.md
@@ -20,22 +20,32 @@ PyGPUkit aims to be the "micro-runtime for GPU computing": small, fast, and idea

---

## v0.2.2 Features (NEW)
## v0.2.3 Features (NEW)

### Ampere-Optimized SGEMM
### TF32 TensorCore GEMM
| Feature | Description |
|---------|-------------|
| **cp.async Pipeline** | 4-stage software pipeline with async memory transfers |
| **Vectorized Loads** | float4 (16-byte) loads for A and B matrices |
| **Shared Memory Tiling** | BM=128, BN=128, BK=16 with 8x8 thread tiles |
| **PTX mma.sync** | Direct TensorCore access via inline PTX assembly |
| **cp.async Pipeline** | Double-buffered async memory transfers |
| **TF32 Precision** | 19-bit format with a 10-bit mantissa (vs FP32's 23-bit), ~0.1% per-op error |
| **SM 80+ Required** | Ampere architecture (RTX 30XX+) required |
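
To make the per-operation error figure concrete: TF32 keeps 1 sign bit, 8 exponent bits, and 10 mantissa bits. The sketch below emulates that on the host by clearing the 13 low mantissa bits of an FP32 value. It assumes plain truncation, whereas hardware conversion may round instead, and the function names are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Emulate TF32 storage by clearing the 13 low mantissa bits of an FP32 value,
// leaving 1 sign + 8 exponent + 10 mantissa bits.
// Assumption: plain truncation; hardware conversion may round instead.
float truncate_to_tf32(float x) {
    std::uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    bits &= 0xFFFFE000u;
    float y;
    std::memcpy(&y, &bits, sizeof(y));
    return y;
}

// Relative error introduced by the truncation for a finite, non-zero input.
double tf32_relative_error(float x) {
    assert(x != 0.0f && std::isfinite(x));
    const double exact = static_cast<double>(x);
    return std::fabs((static_cast<double>(truncate_to_tf32(x)) - exact) / exact);
}
```

For normal (non-denormal) inputs the truncation error stays below 2^-10, about 0.098%, which is consistent with the ~0.1% per-operation figure in the table.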

### Performance (RTX 3090 Ti)
| Matrix Size | TFLOPS | Efficiency | vs NumPy |
|-------------|--------|------------|----------|
| 2048x2048 | 7.6 | 19% | 10x |
| 4096x4096 | 13.2 | 33% | 16x |
| 8192x8192 | **18.2** | 46% | **22x** |
### Benchmark Comparison (RTX 3090 Ti, 8192×8192×8192)

| Library | FP32 | TF32 | Notes |
|---------|------|------|-------|
| **NumPy** (OpenBLAS) | ~0.8 TFLOPS | — | CPU baseline |
| **cuBLAS** | ~21 TFLOPS | ~59 TFLOPS | [NVIDIA benchmark](https://forums.developer.nvidia.com/t/a40-and-3090-gemm-performance-test-data/249424) |
| **PyGPUkit** | 18 TFLOPS (86%) | 27 TFLOPS (46%) | Custom kernels |

> FP32 is near cuBLAS level. TF32 optimization ongoing.

### PyGPUkit Performance by Size
| Matrix Size | FP32 | TF32 |
|-------------|------|------|
| 2048×2048 | 7.6 TFLOPS | 10.2 TFLOPS |
| 4096×4096 | 13.2 TFLOPS | 19.5 TFLOPS |
| 8192×8192 | 18.2 TFLOPS | **27.5 TFLOPS** |

### Core Infrastructure (Rust)
| Feature | Description |
@@ -338,18 +348,25 @@ PyGPUkit/
- [x] 18.2 TFLOPS on RTX 3090 Ti (46% efficiency)
- [x] SM 80+ (Ampere) architecture requirement

### **v0.2.3 — Reliability Phase**
### **v0.2.3 — TF32 TensorCore Phase (Released)**
- [x] TF32 TensorCore GEMM with PTX mma.sync
- [x] cp.async double-buffered pipeline
- [x] 27.5 TFLOPS on RTX 3090 Ti
- [x] PTX fragment mapping documentation

### **v0.2.4 — Benchmark & Reliability Phase**
- [ ] Actual PyTorch/NumPy comparison benchmarks
- [ ] Kernel cache LRU completion
- [ ] Driver-only mode stabilization
- [ ] Windows/Linux full support
- [ ] Large GPU memory test (16GB continuous alloc/free)

### **v0.2.4 — Distributed Phase**
### **v0.2.5 — Distributed Phase**
- [ ] Multi-GPU Detection
- [ ] NCCL / peer-to-peer preliminary support
- [ ] Scheduler multi-device support

### **v0.2.5 — Pre-v0.3 Finalization**
### **v0.2.6 — Pre-v0.3 Finalization**
- [ ] Full API review
- [ ] Backward compatibility policy
- [ ] JIT build options, safety measures, env vars cleanup