-
Notifications
You must be signed in to change notification settings - Fork 72
Support scalar pair ABI #381
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,7 +11,10 @@ use rspirv::dr::Operand; | |
use rspirv::spirv::{ | ||
Capability, Decoration, Dim, ExecutionModel, FunctionControl, StorageClass, Word, | ||
}; | ||
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods, MiscCodegenMethods as _}; | ||
use rustc_codegen_ssa::traits::{ | ||
BaseTypeCodegenMethods, BuilderMethods, ConstCodegenMethods, LayoutTypeCodegenMethods, | ||
MiscCodegenMethods as _, | ||
}; | ||
use rustc_data_structures::fx::FxHashMap; | ||
use rustc_errors::MultiSpan; | ||
use rustc_hir as hir; | ||
|
@@ -86,22 +89,7 @@ impl<'tcx> CodegenCx<'tcx> { | |
}; | ||
for (arg_abi, hir_param) in fn_abi.args.iter().zip(hir_params) { | ||
match arg_abi.mode { | ||
PassMode::Direct(_) => {} | ||
PassMode::Pair(..) => { | ||
// FIXME(eddyb) implement `ScalarPair` `Input`s, or change | ||
// the `FnAbi` readjustment to only use `PassMode::Pair` for | ||
// pointers to `!Sized` types, but not other `ScalarPair`s. | ||
if !matches!(arg_abi.layout.ty.kind(), ty::Ref(..)) { | ||
self.tcx.dcx().span_err( | ||
hir_param.ty_span, | ||
format!( | ||
"entry point parameter type not yet supported \ | ||
(`{}` has `ScalarPair` ABI but is not a `&T`)", | ||
arg_abi.layout.ty | ||
), | ||
); | ||
} | ||
} | ||
PassMode::Direct(_) | PassMode::Pair(..) => {} | ||
// FIXME(eddyb) support these (by just ignoring them) - if there | ||
// is any validation concern, it should be done on the types. | ||
PassMode::Ignore => self.tcx.dcx().span_fatal( | ||
|
@@ -442,6 +430,33 @@ impl<'tcx> CodegenCx<'tcx> { | |
} = self.entry_param_deduce_from_rust_ref_or_value(entry_arg_abi.layout, hir_param, &attrs); | ||
let value_spirv_type = value_layout.spirv_type(hir_param.ty_span, self); | ||
|
||
// In compute shaders, user-provided data must come from buffers or push | ||
// constants, i.e. by-reference parameters. | ||
if execution_model == ExecutionModel::GLCompute | ||
&& matches!(entry_arg_abi.mode, PassMode::Direct(_) | PassMode::Pair(..)) | ||
&& !matches!(entry_arg_abi.layout.ty.kind(), ty::Ref(..)) | ||
&& attrs.builtin.is_none() | ||
{ | ||
let param_name = if let hir::PatKind::Binding(_, _, ident, _) = &hir_param.pat.kind { | ||
ident.name.to_string() | ||
} else { | ||
"parameter".to_string() | ||
}; | ||
self.tcx | ||
.dcx() | ||
.struct_span_err( | ||
hir_param.ty_span, | ||
format!("compute entry parameter `{param_name}` must be by-reference",), | ||
) | ||
.with_help(format!( | ||
"consider changing the type to `&{}`", | ||
entry_arg_abi.layout.ty | ||
)) | ||
.emit(); | ||
// Keep this a hard error to stop compilation after emitting help. | ||
self.tcx.dcx().abort_if_errors(); | ||
} | ||
|
||
let (var_id, spec_const_id) = match storage_class { | ||
// Pre-allocate the module-scoped `OpVariable` *Result* ID. | ||
Ok(_) => ( | ||
|
@@ -491,14 +506,6 @@ impl<'tcx> CodegenCx<'tcx> { | |
vs layout:\n{value_layout:#?}", | ||
entry_arg_abi.layout.ty | ||
); | ||
if is_pair && !is_unsized { | ||
// If PassMode is Pair, then we need to fill in the second part of the pair with a | ||
// value. We currently only do that with unsized types, so if a type is a pair for some | ||
// other reason (e.g. a tuple), we bail. | ||
self.tcx | ||
.dcx() | ||
.span_fatal(hir_param.ty_span, "pair type not supported yet") | ||
} | ||
// FIXME(eddyb) should this talk about "typed buffers" instead of "interface blocks"? | ||
// FIXME(eddyb) should we talk about "descriptor indexing" or | ||
// actually use more reasonable terms like "resource arrays"? | ||
|
@@ -621,8 +628,8 @@ impl<'tcx> CodegenCx<'tcx> { | |
} | ||
} | ||
|
||
let value_len = if is_pair { | ||
// We've already emitted an error, fill in a placeholder value | ||
let value_len = if is_pair && is_unsized { | ||
// For wide references (e.g., slices), the second component is a length. | ||
Some(bx.undef(self.type_isize())) | ||
} else { | ||
None | ||
|
@@ -645,21 +652,54 @@ impl<'tcx> CodegenCx<'tcx> { | |
_ => unreachable!(), | ||
} | ||
} else { | ||
assert_matches!(entry_arg_abi.mode, PassMode::Direct(_)); | ||
|
||
let value = match storage_class { | ||
Ok(_) => { | ||
match entry_arg_abi.mode { | ||
PassMode::Direct(_) => { | ||
let value = match storage_class { | ||
Ok(_) => { | ||
assert_eq!(storage_class, Ok(StorageClass::Input)); | ||
bx.load( | ||
entry_arg_abi.layout.spirv_type(hir_param.ty_span, bx), | ||
value_ptr.unwrap(), | ||
entry_arg_abi.layout.align.abi, | ||
) | ||
} | ||
Err(SpecConstant { .. }) => { | ||
spec_const_id.unwrap().with_type(value_spirv_type) | ||
} | ||
}; | ||
call_args.push(value); | ||
assert_eq!(value_len, None); | ||
} | ||
PassMode::Pair(..) => { | ||
// Load both elements of the scalar pair from the input variable. | ||
assert_eq!(storage_class, Ok(StorageClass::Input)); | ||
bx.load( | ||
entry_arg_abi.layout.spirv_type(hir_param.ty_span, bx), | ||
value_ptr.unwrap(), | ||
entry_arg_abi.layout.align.abi, | ||
) | ||
let layout = entry_arg_abi.layout; | ||
let (a, b) = match layout.backend_repr { | ||
rustc_abi::BackendRepr::ScalarPair(a, b) => (a, b), | ||
other => span_bug!( | ||
hir_param.ty_span, | ||
"ScalarPair expected for entry param, found {other:?}" | ||
), | ||
}; | ||
let b_offset = a | ||
.primitive() | ||
.size(self) | ||
.align_to(b.primitive().align(self).abi); | ||
|
||
let elem0_ty = self.scalar_pair_element_backend_type(layout, 0, false); | ||
let elem1_ty = self.scalar_pair_element_backend_type(layout, 1, false); | ||
|
||
let base_ptr = value_ptr.unwrap(); | ||
let ptr1 = bx.inbounds_ptradd(base_ptr, self.const_usize(b_offset.bytes())); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why use fn inbounds_ptradd(&mut self, ptr: Self::Value, offset: Self::Value) -> Self::Value {
self.inbounds_gep(self.cx().type_i8(), ptr, &[offset])
} I'm also not so sure you want the resulting type to implicitly be an There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, this is what I am not sure about, doing pointer arithmetic is bad and I was going to switch it to gep but wanted eddyb to chime in. |
||
|
||
let v0 = bx.load(elem0_ty, base_ptr, layout.align.abi); | ||
let v1 = bx.load(elem1_ty, ptr1, layout.align.restrict_for_offset(b_offset)); | ||
call_args.push(v0); | ||
call_args.push(v1); | ||
assert_eq!(value_len, None); | ||
} | ||
Err(SpecConstant { .. }) => spec_const_id.unwrap().with_type(value_spirv_type), | ||
}; | ||
call_args.push(value); | ||
assert_eq!(value_len, None); | ||
_ => unreachable!(), | ||
} | ||
} | ||
|
||
// FIXME(eddyb) check whether the storage class is compatible with the | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
// compile-fail | ||
#![no_std] | ||
|
||
use spirv_std::spirv; | ||
|
||
#[spirv(compute(threads(1)))] | ||
pub fn main( | ||
w: (u32, u32), | ||
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut [u32], | ||
) { | ||
out[0] = w.0 + w.1; | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
error: compute entry parameter `w` must be by-reference | ||
--> $DIR/compute_value_pair_fail.rs:8:8 | ||
| | ||
8 | w: (u32, u32), | ||
| ^^^^^^^^^^ | ||
| | ||
= help: consider changing the type to `&(u32, u32)` | ||
|
||
error: aborting due to 1 previous error | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
// build-pass | ||
#![no_std] | ||
|
||
use spirv_std::spirv; | ||
|
||
#[spirv(compute(threads(1)))] | ||
pub fn main( | ||
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut [u32], | ||
#[spirv(uniform, descriptor_set = 0, binding = 1)] w: &(f32, u32), | ||
) { | ||
out[0] = w.0.to_bits() ^ w.1; | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
// build-pass | ||
// compile-flags: -C target-feature=+Int64 | ||
#![no_std] | ||
|
||
use spirv_std::spirv; | ||
|
||
#[spirv(compute(threads(1)))] | ||
pub fn main( | ||
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut [u32], | ||
#[spirv(uniform, descriptor_set = 0, binding = 1)] w: &(i32, i32), | ||
) { | ||
// Sum and reinterpret as u32 for output | ||
let s = (w.0 as i64 + w.1 as i64) as i32; | ||
out[0] = s as u32; | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
// build-pass | ||
#![no_std] | ||
|
||
use spirv_std::spirv; | ||
|
||
#[spirv(compute(threads(1)))] | ||
pub fn main( | ||
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut [u32], | ||
#[spirv(uniform, descriptor_set = 0, binding = 1)] w: &(u32, f32), | ||
) { | ||
let a = w.0; | ||
let b_bits = w.1.to_bits(); | ||
out[0] = a ^ b_bits; | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
// build-pass | ||
#![no_std] | ||
|
||
use spirv_std::spirv; | ||
|
||
#[repr(transparent)] | ||
pub struct Wrap((u32, u32)); | ||
|
||
#[spirv(compute(threads(1)))] | ||
pub fn main( | ||
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut [u32], | ||
#[spirv(uniform, descriptor_set = 0, binding = 1)] w: &Wrap, | ||
) { | ||
let a = (w.0).0; | ||
let b = (w.0).1; | ||
out[0] = a + b; | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
// build-pass | ||
#![no_std] | ||
|
||
use spirv_std::spirv; | ||
|
||
#[repr(transparent)] | ||
pub struct Inner((u32, u32)); | ||
|
||
#[repr(transparent)] | ||
pub struct Outer( | ||
core::mem::ManuallyDrop<Inner>, | ||
core::marker::PhantomData<()>, | ||
); | ||
|
||
#[inline(never)] | ||
fn sum_outer(o: Outer) -> u32 { | ||
// SAFETY: repr(transparent) guarantees same layout as `Inner`. | ||
let i: Inner = unsafe { core::mem::ManuallyDrop::into_inner((o.0)) }; | ||
(i.0).0 + (i.0).1 | ||
} | ||
|
||
#[spirv(compute(threads(1)))] | ||
pub fn main(#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut [u32]) { | ||
let i = Inner((5, 7)); | ||
let o = Outer(core::mem::ManuallyDrop::new(i), core::marker::PhantomData); | ||
out[0] = sum_outer(o); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
// build-pass | ||
#![no_std] | ||
|
||
use spirv_std::spirv; | ||
|
||
#[spirv(fragment)] | ||
pub fn main( | ||
#[spirv(flat)] pi: (u32, u32), | ||
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] out: &mut [u32], | ||
) { | ||
out[0] = pi.0.wrapping_add(pi.1); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
// build-pass | ||
#![no_std] | ||
|
||
use spirv_std::spirv; | ||
|
||
#[spirv(vertex)] | ||
pub fn main( | ||
p: (u32, u32), | ||
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] out: &mut [u32], | ||
) { | ||
out[0] = p.0.wrapping_add(p.1); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
// build-pass | ||
#![no_std] | ||
|
||
use spirv_std::spirv; | ||
|
||
#[spirv(compute(threads(1)))] | ||
pub fn main( | ||
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut [u32], | ||
#[spirv(uniform, descriptor_set = 0, binding = 1)] w: &(u32, u32), | ||
) { | ||
out[0] = w.0 + w.1; | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
// build-pass | ||
// compile-flags: -C target-feature=+Int64 | ||
#![no_std] | ||
|
||
use spirv_std::spirv; | ||
|
||
#[spirv(compute(threads(1)))] | ||
pub fn main( | ||
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut [u32], | ||
#[spirv(uniform, descriptor_set = 0, binding = 1)] w: &(u32, u64), | ||
) { | ||
let hi = (w.1 >> 32) as u32; | ||
let lo = (w.1 & 0xFFFF_FFFF) as u32; | ||
out[0] = w.0 ^ hi ^ lo; | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
// build-pass | ||
// compile-flags: -C target-feature=+Int64 | ||
#![no_std] | ||
|
||
use spirv_std::spirv; | ||
|
||
#[spirv(compute(threads(1)))] | ||
pub fn main( | ||
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut [u32], | ||
#[spirv(uniform, descriptor_set = 0, binding = 1)] w: &(u64, u32), | ||
) { | ||
// Fold 64-bit into 32-bit deterministically | ||
let hi = (w.0 >> 32) as u32; | ||
let lo = (w.0 & 0xFFFF_FFFF) as u32; | ||
out[0] = hi ^ lo ^ w.1; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@eddyb I am not sure this section is right, can you review?