Add support for float atomics and generic dtypes in shared memory on Vulkan and Apple Metal. #432
Conversation
Force-pushed af08498 to 02fb882
I was assisted by Claude Opus in writing this PR. I have read and reviewed every line added in this PR. I take full responsibility for the lines added and removed in this PR. I won't blame any issue on Claude Opus.
Force-pushed c7cbb8b to 4a42abe
| auto elem_num = tensor_type->get_num_elements();
| spirv::SType elem_type =
|     ir_->get_primitive_type(tensor_type->get_element_type());
| DataType elem_dt = tensor_type->get_element_type();
`elem_dt` and `elem_type` are very confusing. Could we either give them more intuitive names, or at least add a comment on what the difference between them is?
It should be better now.
| // float atomics).
| if (alloca->is_shared && is_real(elem_dt)) {
|   elem_type =
|       ir_->get_primitive_type(ir_->get_quadrants_uint_type(elem_dt));
It's not clear to me from the name what `get_quadrants_uint_type` does, specifically around the number of bits. Could we add a comment to clarify what is happening in this line?
It should be better now.
| spirv::Value offset_val = ir_->query_value(stmt->offset->raw_name());
| auto dt = stmt->element_type().ptr_removed();
| // Flatten nested tensor types to scalar (e.g., vec3 to f32)
| if (auto nested = dt->cast<TensorType>()) {
This seems very similar to what happens above. Could this be factored out into a helper function?
| ir_->get_primitive_type(dt), origin_val.stype.storage_class);
| auto elem_type = ir_->get_primitive_type(dt);
| if (shared_float_retyped_.count(stmt->origin)) {
|   elem_type = ir_->get_primitive_type(ir_->get_quadrants_uint_type(dt));
Ditto for the question about a helper function.
| std::unordered_map<int, GetRootStmt *>
|     root_stmts_;  // maps root id to get root stmt
| std::unordered_map<const Stmt *, BufferInfo> ptr_to_buffers_;
| // Shared float arrays retyped to uint (Metal lacks threadgroup float atomics)
Can we give a bit more detail about what we are storing here, and why?
| spirv::SType ptr_type = ir_->get_pointer_type(
|     ir_->get_primitive_type(dt), origin_val.stype.storage_class);
| auto elem_type = ir_->get_primitive_type(dt);
| if (shared_float_retyped_.count(stmt->origin)) {
Can we add a comment about what this if statement is checking for, intuitively?
| spirv::Value offset_bytes = ir_->mul(dt_bytes, offset_val);
| ptr_val = ir_->add(origin_val, offset_bytes);
| ptr_to_buffers_[stmt] = ptr_to_buffers_[stmt->origin];
| } else if (origin_val.stype.flag == TypeKind::kPtr) {
Can we add a comment about what this new else-if block is checking for?
| stmt->op_type == AtomicOpType::add) {
|   addr_ptr = at_buffer(stmt->dest, dt);
| } else {
|   addr_ptr = dest_is_ptr
This is just refactoring, right? Seems like a nice refactoring, if I've understood correctly.
| }
| // Shared float arrays are retyped to uint, so native float atomics
| // (which require a float pointer) cannot be used on them.
I'm not sure I follow. I thought the purpose of changing the backing type to uint was to enable the spirv atomics? Could you give a little more clarification (in the comments) about this point please.
Also, how do we know if we are dealing with a shared array here?
Added some comments on the PR itself to clarify this.
tests/python/test_shared_array.py
Outdated
| @test_utils.test(arch=[qd.cuda])
| @pytest.mark.parametrize("op", ["add", "sub", "min", "max"])
| @test_utils.test(arch=[qd.cuda, qd.vulkan, qd.metal, qd.amdgpu])
Do we have a simpler way of doing this, conceptually something like:
- not cpu, or
- gpu?
tests/python/test_shared_array.py
Outdated
| def test_shared_array_float_atomics(op):
|     N = 256
|     block_dim = 32
|     total = block_dim * (block_dim - 1) / 2.0
total what? total_threads? Why are we dividing by 2.0? Oh, perhaps we are doing some kind of arithmetic progression or similar, and this is the expected_sum of that progression? Could we update the name to make the meaning more intuitive please. By the way, this calculation could be done using ints. Could we make this something that needs actual floats? Like, e.g. multiply each term in the progression by 0.333, which is pretty incompatible with binary representation.
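The reviewer's suggestion above can be sketched as a small Python snippet. Note the function name `expected_atomic_add_sum` and the scale constant `0.333` are hypothetical illustrations of the suggestion, not code from the PR:

```python
# Sketch of the review suggestion: make the expected value of the
# shared-array atomic-add test require real float arithmetic by scaling
# each term of the arithmetic progression 0..block_dim-1 by a fractional
# constant with no exact binary representation.
SCALE = 0.333  # hypothetical constant, per the review suggestion


def expected_atomic_add_sum(block_dim: int, scale: float = SCALE) -> float:
    # Each thread tid contributes scale * tid, so the expected total is
    # scale * sum(0..block_dim-1) = scale * block_dim * (block_dim - 1) / 2.
    return scale * block_dim * (block_dim - 1) / 2.0


# With block_dim = 32 the integer series sums to 496, so the scaled
# expectation is 0.333 * 496.
print(expected_atomic_add_sum(32))
```

With `scale = 1.0` this reduces to the original `block_dim * (block_dim - 1) / 2.0`, which is exactly representable; the fractional scale is what forces genuine float behavior.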
tests/python/test_shared_array.py
Outdated
| sharr = qd.simt.block.SharedArray((block_dim,), qd.f32)
| sharr[tid] = qd.f32(tid)
| qd.simt.block.sync()
| atomic_fn(sharr[0], qd.f32(tid))
Let's multiply qd.f32(tid) by some fractional float, like 0.3333, or maybe something arbitrary like 0.1523f.
tests/python/test_shared_array.py
Outdated
| @test_utils.test(arch=[qd.cuda], debug=True)
| @test_utils.test(arch=[qd.cuda, qd.vulkan, qd.metal])
Why is this excluding amdgpu? Again, can we use something simpler:
- 'not cpu', or
- 'gpu'?
Opus review: Thoughts What's good: Things I'd flag:
Overall, this is a well-structured branch. The Metal/Vulkan shared memory work is the kind of backend plumbing that's easy to get wrong, but the approach here is principled and
From the AI review, please could we address at least:
(So the AI and I concur about the ambiguity over what get_quadrants_uint_type does.)
Force-pushed 4a42abe to 4b0ea62
| // type
| DataType get_quadrants_uint_type(const DataType &dt) const;
| // Return the SPIR-V uint type with the same bit-width as dt (e.g. f32->u32).
| SType get_bitcast_uint_stype(const DataType &dt) const;
stype does at least give some clue, but I would prefer the more explicit _spirv_dtype, I feel (or _spirv_dt is OK for me too, or potentially _spv_dt if you want it really short).
stype is already used in many places before this PR. I think it is better to be consistent with the existing naming conventions.
| // Convert a value from float dt to shared-memory uint backing.
| Value float_to_shared_uint(Value val, const DataType &dt);
| // Get the pointer type that points to value_type
| SType get_storage_pointer_type(const SType &value_type);
How does the SType here relate to the stype name above?
SType is the SPIR-V type struct (defined in spirv_ir_builder.h:48). The _stype suffix in get_bitcast_uint_stype indicates it returns an SType, as opposed to _dtype which returns a Quadrants DataType. get_storage_pointer_type is pre-existing code, not part of this PR.
tests/python/test_shared_array.py
Outdated
| @test_utils.test(arch=[qd.cuda, qd.vulkan, qd.amdgpu])
| @test_utils.test(arch=[qd.cuda, qd.vulkan, qd.metal, qd.amdgpu])
My question about whether we can simply exclude cpu, or say to run only on gpu, seems not to have been addressed? (Or perhaps I didn't see the response.)
tests/python/test_shared_array.py
Outdated
| @test_utils.test(arch=[qd.cuda])
| @pytest.mark.parametrize("op", ["add", "sub", "min", "max"])
| @pytest.mark.parametrize("dtype", [qd.f16, qd.f32])
f64? (seems more common than f16 tbh)
(I mean, we can test both.)
Oh, I guess f64 doesn't work on Metal, right?
It does not, but it is easy to skip if unsupported in a clean way.
tests/python/test_shared_array.py
Outdated
| rtol = 1e-3 if dtype == qd.f16 else 1e-6
| arr = qd.ndarray(qd.f32, (N))
| make_kernel(atomic_op)(arr)
| qd.sync()
Why do we need a sync? Is this the existing Metal bug you've mentioned recently?
We do not actually. I was being extra cautious for no reason.
tests/python/test_shared_array.py
Outdated
| make_kernel(atomic_op)(arr)
| qd.sync()
| assert arr[0] == test_utils.approx(expected[op], rel=rtol)
| assert arr[32] == test_utils.approx(expected[op], rel=rtol)
Can we also check 31 and 255?
For the record, the original SPIRV-Cross fix makes SPIRV-Cross emit atomic_uint instead of atomic_float for threadgroup pointers when the CAS result type is integer but the pointee is float. This would let us keep f32 shared arrays as float-typed and avoid the uint retyping + bitcast overhead on every load/store. However, it only partially helps:
So the SPIRV-Cross fix could simplify the f32 case, but the uint retyping approach is still needed for f16 and is more robust overall. Given that the pre-scan now limits the retyping to only arrays with atomics, the overhead is minimal.
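The uint retyping plus CAS emulation being discussed can be modeled in plain Python. This is a minimal sketch of the idea (bitcast f32 to u32 for storage, emulate the float atomic add as a compare-and-swap loop over the u32 bits), not the actual codegen; the helper names are hypothetical:

```python
import struct


def f32_to_u32(f: float) -> int:
    # Bitcast: reinterpret the 32-bit float's bit pattern as a u32.
    return struct.unpack("<I", struct.pack("<f", f))[0]


def u32_to_f32(u: int) -> float:
    return struct.unpack("<f", struct.pack("<I", u))[0]


def cas(mem, idx, expected, new):
    # Model of OpAtomicCompareExchange: swap if the stored value matches
    # `expected`, and return the value that was actually loaded.
    old = mem[idx]
    if old == expected:
        mem[idx] = new
    return old


def atomic_add_f32_via_u32(mem, idx, operand: float):
    # CAS loop: load the bits, add in float space, try to swap the new
    # bits in; retry if another "thread" won the race.
    while True:
        old_bits = mem[idx]  # OpAtomicLoad u32
        new_bits = f32_to_u32(u32_to_f32(old_bits) + operand)
        if cas(mem, idx, old_bits, new_bits) == old_bits:
            return


# Shared array backed by u32 bits, initialized to 0.0f.
shared = [f32_to_u32(0.0)]
for tid in range(4):
    atomic_add_f32_via_u32(shared, 0, float(tid))
# u32_to_f32(shared[0]) is now 0 + 1 + 2 + 3 = 6.0
```

The bitcasts are lossless round-trips, which is why the array can stay u32-typed while loads, stores, and the add itself still operate on float values.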
| // Propagated from shared_atomic_allocs_ to derived MatrixPtrStmt nodes
| // during codegen, so that load/store/atomic visitors know to bitcast.
| // Example: if `sharr` (AllocaStmt) is in shared_atomic_allocs_, then
| // `sharr[0]` (MatrixPtrStmt) is added here during visit(MatrixPtrStmt).
Nice explanation thanks. It would reduce my cognitive load if you could give me a case study showing some code (perhaps LLVM IR, or python, no strong preference), and how that maps to these sets, and to the resulting spir-v instructions. Doesn't have to be as comments in the code; could be in the PR description, or in some slides or similar.
Python kernel:
sharr = qd.simt.block.SharedArray((32,), qd.f32)
sharr[tid] = qd.f32(tid)
qd.simt.block.sync()
qd.atomic_add(sharr[0], qd.f32(tid))
qd.simt.block.sync()
out[i] = sharr[0]
IR statements and set membership:
%sharr = AllocaStmt(shared, array<f32, 32>) # in shared_atomic_allocs_ (pre-scan found atomic_add targets it)
# in shared_float_retyped_ (visit(AllocaStmt) retypes to u32)
%ptr0 = MatrixPtrStmt(%sharr, tid) # in shared_float_retyped_ (propagated from %sharr)
%ptr1 = MatrixPtrStmt(%sharr, 0) # in shared_float_retyped_ (propagated from %sharr)
LocalStoreStmt(%ptr0, f32(tid)) # sees %ptr0 in shared_float_retyped_ -> float_to_shared_uint
AtomicOpStmt(add, %ptr1, f32(tid)) # dest_is_ptr=true -> CAS with u32 atomics
%val = LocalLoadStmt(%ptr1) # sees %ptr1 in shared_float_retyped_ -> shared_uint_to_float
Generated SPIR-V (simplified):
; Allocation: u32 array instead of f32 (retyped)
%sharr = OpVariable Workgroup array<u32, 32>
; Store: f32 -> bitcast to u32 -> store
%u_tid = OpBitcast u32 %f_tid
OpStore %sharr[tid] %u_tid
; Atomic add (CAS loop):
%old = OpAtomicLoad u32 %sharr[0]
%old_f = OpBitcast f32 %old ; u32 -> f32
%new_f = OpFAdd f32 %old_f %f_tid ; float add
%new = OpBitcast u32 %new_f ; f32 -> u32
%loaded = OpAtomicCompareExchange u32 %sharr[0] %new %old
; (loop until %loaded == %old)
; Load: load u32 -> bitcast to f32
%raw = OpLoad u32 %sharr[0]
%val = OpBitcast f32 %raw
For f16, the only difference is the array is still u32-backed (not u16, since Metal/Vulkan lack 16-bit atomics), with OpUConvert inserted between the bitcast and the atomic:
; Store f16: bitcast f16->u16, widen u16->u32, store
; Load u32: narrow u32->u16, bitcast u16->f16
; CAS: OpAtomicLoad u32, narrow->bitcast->FAdd->bitcast->widen, OpAtomicCompareExchange u32
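The f16 widening path described above can be modeled the same way in Python, using struct's half-precision format. This is an illustrative sketch of the bit-level steps (bitcast f16 to u16, widen to the u32 backing word, and back), with hypothetical helper names:

```python
import struct


def f16_to_u16(f: float) -> int:
    # Bitcast f16 -> u16 ('<e' is IEEE-754 half precision).
    return struct.unpack("<H", struct.pack("<e", f))[0]


def u16_to_f16(u: int) -> float:
    return struct.unpack("<e", struct.pack("<H", u))[0]


def store_f16_as_u32(f: float) -> int:
    # Store path: bitcast f16 -> u16 (OpBitcast), then zero-extend the
    # 16-bit pattern into the u32 backing word (OpUConvert).
    return f16_to_u16(f) & 0xFFFFFFFF


def load_f16_from_u32(u: int) -> float:
    # Load path: narrow u32 -> u16, then bitcast back to f16.
    return u16_to_f16(u & 0xFFFF)


# 1.5 is exactly representable in half precision, so it round-trips:
word = store_f16_as_u32(1.5)
assert load_f16_from_u32(word) == 1.5
```

This mirrors why the backing array stays u32 even for f16 data: the atomics operate on the 32-bit word, and the narrow/widen conversions sit between the bitcast and the CAS.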
| ir_->register_value(const_stmt->raw_name(), val);
| }
| const AllocaStmt *TaskCodegen::trace_to_alloca(const Stmt *s) {
Could we add a comment to this function saying what it does, and providing an example?
| return nullptr;
| }
| void TaskCodegen::scan_shared_atomic_allocs(Block *block) {
So this scans a single block, but writes the results to a class-level collection? I wonder if it would be more intuitive/reusable/testable if we passed that collection in as a function parameter somehow?
Are there tests for this function? Could we add some?
How do we know that this function is complete? It seems fairly complex.
So this scans a single block, but writes the results to a class-level collection? I wonder if it would be more intuitive/reusable/testable if we passed that collection in as a function parameter somehow?
Done.
Are there tests for this function? Could we add some?
A hole would cause a shader compilation error, not a silent bug. If the pre-scan misses an atomic target, the array stays float-typed, but the CAS emulation expects a uint pointer. The resulting type mismatch (OpAtomicLoad(u32, ptr_to_f32)) is invalid SPIR-V and gets rejected at compile time. So I don't think more testing is necessary at this point.
Force-pushed 2890be3 to a806667
Brief Summary
Accompanying PR: Genesis-Embodied-AI/SPIRV-Cross#1