Skip to content

Commit 518dbfc

Browse files
committed
Support scalar pair ABI
(T, U) pairs in entrypoints and regular functions are now supported.
1 parent 7358fae commit 518dbfc

File tree

23 files changed

+495
-45
lines changed

23 files changed

+495
-45
lines changed

crates/rustc_codegen_spirv/src/abi.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -655,7 +655,7 @@ pub fn scalar_pair_element_backend_type<'tcx>(
655655
ty: TyAndLayout<'tcx>,
656656
index: usize,
657657
) -> Word {
658-
let [a, b] = match ty.layout.backend_repr() {
658+
let [a, b] = match ty.backend_repr {
659659
BackendRepr::ScalarPair(a, b) => [a, b],
660660
other => span_bug!(
661661
span,

crates/rustc_codegen_spirv/src/codegen_cx/entry.rs

Lines changed: 83 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@ use rspirv::dr::Operand;
1111
use rspirv::spirv::{
1212
Capability, Decoration, Dim, ExecutionModel, FunctionControl, StorageClass, Word,
1313
};
14-
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods, MiscCodegenMethods as _};
14+
use rustc_codegen_ssa::traits::{
15+
BaseTypeCodegenMethods, BuilderMethods, ConstCodegenMethods, LayoutTypeCodegenMethods,
16+
MiscCodegenMethods as _,
17+
};
1518
use rustc_data_structures::fx::FxHashMap;
1619
use rustc_errors::MultiSpan;
1720
use rustc_hir as hir;
@@ -87,21 +90,7 @@ impl<'tcx> CodegenCx<'tcx> {
8790
for (arg_abi, hir_param) in fn_abi.args.iter().zip(hir_params) {
8891
match arg_abi.mode {
8992
PassMode::Direct(_) => {}
90-
PassMode::Pair(..) => {
91-
// FIXME(eddyb) implement `ScalarPair` `Input`s, or change
92-
// the `FnAbi` readjustment to only use `PassMode::Pair` for
93-
// pointers to `!Sized` types, but not other `ScalarPair`s.
94-
if !matches!(arg_abi.layout.ty.kind(), ty::Ref(..)) {
95-
self.tcx.dcx().span_err(
96-
hir_param.ty_span,
97-
format!(
98-
"entry point parameter type not yet supported \
99-
(`{}` has `ScalarPair` ABI but is not a `&T`)",
100-
arg_abi.layout.ty
101-
),
102-
);
103-
}
104-
}
93+
PassMode::Pair(..) => {}
10594
// FIXME(eddyb) support these (by just ignoring them) - if there
10695
// is any validation concern, it should be done on the types.
10796
PassMode::Ignore => self.tcx.dcx().span_fatal(
@@ -442,6 +431,36 @@ impl<'tcx> CodegenCx<'tcx> {
442431
} = self.entry_param_deduce_from_rust_ref_or_value(entry_arg_abi.layout, hir_param, &attrs);
443432
let value_spirv_type = value_layout.spirv_type(hir_param.ty_span, self);
444433

434+
// In compute shaders, user-provided data must come from buffers or push
435+
// constants, i.e. by-reference parameters.
436+
if execution_model == ExecutionModel::GLCompute
437+
&& matches!(entry_arg_abi.mode, PassMode::Direct(_) | PassMode::Pair(..))
438+
&& !matches!(entry_arg_abi.layout.ty.kind(), ty::Ref(..))
439+
&& attrs.builtin.is_none()
440+
{
441+
let param_name = if let hir::PatKind::Binding(_, _, ident, _) = &hir_param.pat.kind {
442+
ident.name.to_string()
443+
} else {
444+
"parameter".to_string()
445+
};
446+
self.tcx
447+
.dcx()
448+
.struct_span_err(
449+
hir_param.ty_span,
450+
format!(
451+
"compute entry parameter `{}` must be by-reference",
452+
param_name
453+
),
454+
)
455+
.with_help(format!(
456+
"consider changing the type to `&{}`",
457+
entry_arg_abi.layout.ty
458+
))
459+
.emit();
460+
// Keep this a hard error to stop compilation after emitting help.
461+
self.tcx.dcx().abort_if_errors();
462+
}
463+
445464
let (var_id, spec_const_id) = match storage_class {
446465
// Pre-allocate the module-scoped `OpVariable` *Result* ID.
447466
Ok(_) => (
@@ -491,14 +510,6 @@ impl<'tcx> CodegenCx<'tcx> {
491510
vs layout:\n{value_layout:#?}",
492511
entry_arg_abi.layout.ty
493512
);
494-
if is_pair && !is_unsized {
495-
// If PassMode is Pair, then we need to fill in the second part of the pair with a
496-
// value. We currently only do that with unsized types, so if a type is a pair for some
497-
// other reason (e.g. a tuple), we bail.
498-
self.tcx
499-
.dcx()
500-
.span_fatal(hir_param.ty_span, "pair type not supported yet")
501-
}
502513
// FIXME(eddyb) should this talk about "typed buffers" instead of "interface blocks"?
503514
// FIXME(eddyb) should we talk about "descriptor indexing" or
504515
// actually use more reasonable terms like "resource arrays"?
@@ -621,8 +632,8 @@ impl<'tcx> CodegenCx<'tcx> {
621632
}
622633
}
623634

624-
let value_len = if is_pair {
625-
// We've already emitted an error, fill in a placeholder value
635+
let value_len = if is_pair && is_unsized {
636+
// For wide references (e.g., slices), the second component is a length.
626637
Some(bx.undef(self.type_isize()))
627638
} else {
628639
None
@@ -645,21 +656,54 @@ impl<'tcx> CodegenCx<'tcx> {
645656
_ => unreachable!(),
646657
}
647658
} else {
648-
assert_matches!(entry_arg_abi.mode, PassMode::Direct(_));
649-
650-
let value = match storage_class {
651-
Ok(_) => {
659+
match entry_arg_abi.mode {
660+
PassMode::Direct(_) => {
661+
let value = match storage_class {
662+
Ok(_) => {
663+
assert_eq!(storage_class, Ok(StorageClass::Input));
664+
bx.load(
665+
entry_arg_abi.layout.spirv_type(hir_param.ty_span, bx),
666+
value_ptr.unwrap(),
667+
entry_arg_abi.layout.align.abi,
668+
)
669+
}
670+
Err(SpecConstant { .. }) => {
671+
spec_const_id.unwrap().with_type(value_spirv_type)
672+
}
673+
};
674+
call_args.push(value);
675+
assert_eq!(value_len, None);
676+
}
677+
PassMode::Pair(..) => {
678+
// Load both elements of the scalar pair from the input variable.
652679
assert_eq!(storage_class, Ok(StorageClass::Input));
653-
bx.load(
654-
entry_arg_abi.layout.spirv_type(hir_param.ty_span, bx),
655-
value_ptr.unwrap(),
656-
entry_arg_abi.layout.align.abi,
657-
)
680+
let layout = entry_arg_abi.layout;
681+
let (a, b) = match layout.backend_repr {
682+
rustc_abi::BackendRepr::ScalarPair(a, b) => (a, b),
683+
other => span_bug!(
684+
hir_param.ty_span,
685+
"ScalarPair expected for entry param, found {other:?}"
686+
),
687+
};
688+
let b_offset = a
689+
.primitive()
690+
.size(self)
691+
.align_to(b.primitive().align(self).abi);
692+
693+
let elem0_ty = self.scalar_pair_element_backend_type(layout, 0, false);
694+
let elem1_ty = self.scalar_pair_element_backend_type(layout, 1, false);
695+
696+
let base_ptr = value_ptr.unwrap();
697+
let ptr1 = bx.inbounds_ptradd(base_ptr, self.const_usize(b_offset.bytes()));
698+
699+
let v0 = bx.load(elem0_ty, base_ptr, layout.align.abi);
700+
let v1 = bx.load(elem1_ty, ptr1, layout.align.restrict_for_offset(b_offset));
701+
call_args.push(v0);
702+
call_args.push(v1);
703+
assert_eq!(value_len, None);
658704
}
659-
Err(SpecConstant { .. }) => spec_const_id.unwrap().with_type(value_spirv_type),
660-
};
661-
call_args.push(value);
662-
assert_eq!(value_len, None);
705+
_ => unreachable!(),
706+
}
663707
}
664708

665709
// FIXME(eddyb) check whether the storage class is compatible with the
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// compile-fail
2+
#![no_std]
3+
4+
use spirv_std::spirv;
5+
6+
#[spirv(compute(threads(1)))]
7+
pub fn main(
8+
w: (u32, u32),
9+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut [u32],
10+
) {
11+
out[0] = w.0 + w.1;
12+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
error: compute entry parameter `w` must be by-reference
2+
--> $DIR/compute_value_pair_fail.rs:8:8
3+
|
4+
8 | w: (u32, u32),
5+
| ^^^^^^^^^^
6+
|
7+
= help: consider changing the type to `&(u32, u32)`
8+
9+
error: aborting due to 1 previous error
10+
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// build-pass
2+
#![no_std]
3+
4+
use spirv_std::spirv;
5+
6+
#[spirv(compute(threads(1)))]
7+
pub fn main(
8+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut [u32],
9+
#[spirv(uniform, descriptor_set = 0, binding = 1)] w: &(f32, u32),
10+
) {
11+
out[0] = w.0.to_bits() ^ w.1;
12+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// build-pass
2+
// compile-flags: -C target-feature=+Int64
3+
#![no_std]
4+
5+
use spirv_std::spirv;
6+
7+
#[spirv(compute(threads(1)))]
8+
pub fn main(
9+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut [u32],
10+
#[spirv(uniform, descriptor_set = 0, binding = 1)] w: &(i32, i32),
11+
) {
12+
// Sum and reinterpret as u32 for output
13+
let s = (w.0 as i64 + w.1 as i64) as i32;
14+
out[0] = s as u32;
15+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// build-pass
2+
#![no_std]
3+
4+
use spirv_std::spirv;
5+
6+
#[spirv(compute(threads(1)))]
7+
pub fn main(
8+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut [u32],
9+
#[spirv(uniform, descriptor_set = 0, binding = 1)] w: &(u32, f32),
10+
) {
11+
let a = w.0;
12+
let b_bits = w.1.to_bits();
13+
out[0] = a ^ b_bits;
14+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// build-pass
2+
#![no_std]
3+
4+
use spirv_std::spirv;
5+
6+
#[repr(transparent)]
7+
pub struct Wrap((u32, u32));
8+
9+
#[spirv(compute(threads(1)))]
10+
pub fn main(
11+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut [u32],
12+
#[spirv(uniform, descriptor_set = 0, binding = 1)] w: &Wrap,
13+
) {
14+
let a = (w.0).0;
15+
let b = (w.0).1;
16+
out[0] = a + b;
17+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// build-pass
2+
#![no_std]
3+
4+
use spirv_std::spirv;
5+
6+
#[repr(transparent)]
7+
pub struct Inner((u32, u32));
8+
9+
#[repr(transparent)]
10+
pub struct Outer(
11+
core::mem::ManuallyDrop<Inner>,
12+
core::marker::PhantomData<()>,
13+
);
14+
15+
#[inline(never)]
16+
fn sum_outer(o: Outer) -> u32 {
17+
// SAFETY: repr(transparent) guarantees same layout as `Inner`.
18+
let i: Inner = unsafe { core::mem::ManuallyDrop::into_inner((o.0)) };
19+
(i.0).0 + (i.0).1
20+
}
21+
22+
#[spirv(compute(threads(1)))]
23+
pub fn main(#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut [u32]) {
24+
let i = Inner((5, 7));
25+
let o = Outer(core::mem::ManuallyDrop::new(i), core::marker::PhantomData);
26+
out[0] = sum_outer(o);
27+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// build-pass
2+
#![no_std]
3+
4+
use spirv_std::spirv;
5+
6+
#[spirv(fragment)]
7+
pub fn main(
8+
#[spirv(flat)] pi: (u32, u32),
9+
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] out: &mut [u32],
10+
) {
11+
out[0] = pi.0.wrapping_add(pi.1);
12+
}

0 commit comments

Comments
 (0)