Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/rustc_codegen_spirv/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,7 @@ pub fn scalar_pair_element_backend_type<'tcx>(
ty: TyAndLayout<'tcx>,
index: usize,
) -> Word {
let [a, b] = match ty.layout.backend_repr() {
let [a, b] = match ty.backend_repr {
BackendRepr::ScalarPair(a, b) => [a, b],
other => span_bug!(
span,
Expand Down
120 changes: 80 additions & 40 deletions crates/rustc_codegen_spirv/src/codegen_cx/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ use rspirv::dr::Operand;
use rspirv::spirv::{
Capability, Decoration, Dim, ExecutionModel, FunctionControl, StorageClass, Word,
};
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods, MiscCodegenMethods as _};
use rustc_codegen_ssa::traits::{
BaseTypeCodegenMethods, BuilderMethods, ConstCodegenMethods, LayoutTypeCodegenMethods,
MiscCodegenMethods as _,
};
use rustc_data_structures::fx::FxHashMap;
use rustc_errors::MultiSpan;
use rustc_hir as hir;
Expand Down Expand Up @@ -86,22 +89,7 @@ impl<'tcx> CodegenCx<'tcx> {
};
for (arg_abi, hir_param) in fn_abi.args.iter().zip(hir_params) {
match arg_abi.mode {
PassMode::Direct(_) => {}
PassMode::Pair(..) => {
// FIXME(eddyb) implement `ScalarPair` `Input`s, or change
// the `FnAbi` readjustment to only use `PassMode::Pair` for
// pointers to `!Sized` types, but not other `ScalarPair`s.
if !matches!(arg_abi.layout.ty.kind(), ty::Ref(..)) {
self.tcx.dcx().span_err(
hir_param.ty_span,
format!(
"entry point parameter type not yet supported \
(`{}` has `ScalarPair` ABI but is not a `&T`)",
arg_abi.layout.ty
),
);
}
}
PassMode::Direct(_) | PassMode::Pair(..) => {}
// FIXME(eddyb) support these (by just ignoring them) - if there
// is any validation concern, it should be done on the types.
PassMode::Ignore => self.tcx.dcx().span_fatal(
Expand Down Expand Up @@ -442,6 +430,33 @@ impl<'tcx> CodegenCx<'tcx> {
} = self.entry_param_deduce_from_rust_ref_or_value(entry_arg_abi.layout, hir_param, &attrs);
let value_spirv_type = value_layout.spirv_type(hir_param.ty_span, self);

// In compute shaders, user-provided data must come from buffers or push
// constants, i.e. by-reference parameters.
if execution_model == ExecutionModel::GLCompute
&& matches!(entry_arg_abi.mode, PassMode::Direct(_) | PassMode::Pair(..))
&& !matches!(entry_arg_abi.layout.ty.kind(), ty::Ref(..))
&& attrs.builtin.is_none()
{
let param_name = if let hir::PatKind::Binding(_, _, ident, _) = &hir_param.pat.kind {
ident.name.to_string()
} else {
"parameter".to_string()
};
self.tcx
.dcx()
.struct_span_err(
hir_param.ty_span,
format!("compute entry parameter `{param_name}` must be by-reference",),
)
.with_help(format!(
"consider changing the type to `&{}`",
entry_arg_abi.layout.ty
))
.emit();
// Keep this a hard error to stop compilation after emitting help.
self.tcx.dcx().abort_if_errors();
}

let (var_id, spec_const_id) = match storage_class {
// Pre-allocate the module-scoped `OpVariable` *Result* ID.
Ok(_) => (
Expand Down Expand Up @@ -491,14 +506,6 @@ impl<'tcx> CodegenCx<'tcx> {
vs layout:\n{value_layout:#?}",
entry_arg_abi.layout.ty
);
if is_pair && !is_unsized {
// If PassMode is Pair, then we need to fill in the second part of the pair with a
// value. We currently only do that with unsized types, so if a type is a pair for some
// other reason (e.g. a tuple), we bail.
self.tcx
.dcx()
.span_fatal(hir_param.ty_span, "pair type not supported yet")
}
// FIXME(eddyb) should this talk about "typed buffers" instead of "interface blocks"?
// FIXME(eddyb) should we talk about "descriptor indexing" or
// actually use more reasonable terms like "resource arrays"?
Expand Down Expand Up @@ -621,8 +628,8 @@ impl<'tcx> CodegenCx<'tcx> {
}
}

let value_len = if is_pair {
// We've already emitted an error, fill in a placeholder value
let value_len = if is_pair && is_unsized {
// For wide references (e.g., slices), the second component is a length.
Some(bx.undef(self.type_isize()))
} else {
None
Expand All @@ -645,21 +652,54 @@ impl<'tcx> CodegenCx<'tcx> {
_ => unreachable!(),
}
} else {
assert_matches!(entry_arg_abi.mode, PassMode::Direct(_));

let value = match storage_class {
Ok(_) => {
match entry_arg_abi.mode {
PassMode::Direct(_) => {
let value = match storage_class {
Ok(_) => {
assert_eq!(storage_class, Ok(StorageClass::Input));
bx.load(
entry_arg_abi.layout.spirv_type(hir_param.ty_span, bx),
value_ptr.unwrap(),
entry_arg_abi.layout.align.abi,
)
}
Err(SpecConstant { .. }) => {
spec_const_id.unwrap().with_type(value_spirv_type)
}
};
call_args.push(value);
assert_eq!(value_len, None);
}
PassMode::Pair(..) => {
// Load both elements of the scalar pair from the input variable.
assert_eq!(storage_class, Ok(StorageClass::Input));
bx.load(
entry_arg_abi.layout.spirv_type(hir_param.ty_span, bx),
value_ptr.unwrap(),
entry_arg_abi.layout.align.abi,
)
let layout = entry_arg_abi.layout;
let (a, b) = match layout.backend_repr {
rustc_abi::BackendRepr::ScalarPair(a, b) => (a, b),
other => span_bug!(
hir_param.ty_span,
"ScalarPair expected for entry param, found {other:?}"
),
};
let b_offset = a
.primitive()
.size(self)
.align_to(b.primitive().align(self).abi);

let elem0_ty = self.scalar_pair_element_backend_type(layout, 0, false);
let elem1_ty = self.scalar_pair_element_backend_type(layout, 1, false);

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@eddyb I am not sure this section is right, can you review?

let base_ptr = value_ptr.unwrap();
let ptr1 = bx.inbounds_ptradd(base_ptr, self.const_usize(b_offset.bytes()));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why use inbounds_ptradd here, when it just forwards to inbounds_gep?

fn inbounds_ptradd(&mut self, ptr: Self::Value, offset: Self::Value) -> Self::Value {
    self.inbounds_gep(self.cx().type_i8(), ptr, &[offset])
}

I'm also not so sure you want the resulting type to implicitly be an i8?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this is what I am not sure about, doing pointer arithmetic is bad and I was going to switch it to gep but wanted eddyb to chime in.


let v0 = bx.load(elem0_ty, base_ptr, layout.align.abi);
let v1 = bx.load(elem1_ty, ptr1, layout.align.restrict_for_offset(b_offset));
call_args.push(v0);
call_args.push(v1);
assert_eq!(value_len, None);
}
Err(SpecConstant { .. }) => spec_const_id.unwrap().with_type(value_spirv_type),
};
call_args.push(value);
assert_eq!(value_len, None);
_ => unreachable!(),
}
}

// FIXME(eddyb) check whether the storage class is compatible with the
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// compile-fail
#![no_std]

use spirv_std::spirv;

#[spirv(compute(threads(1)))]
pub fn main(
w: (u32, u32),
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut [u32],
) {
out[0] = w.0 + w.1;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
error: compute entry parameter `w` must be by-reference
--> $DIR/compute_value_pair_fail.rs:8:8
|
8 | w: (u32, u32),
| ^^^^^^^^^^
|
= help: consider changing the type to `&(u32, u32)`

error: aborting due to 1 previous error

12 changes: 12 additions & 0 deletions tests/compiletests/ui/lang/abi/scalar_pair/f32_u32.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// build-pass
#![no_std]

use spirv_std::spirv;

#[spirv(compute(threads(1)))]
pub fn main(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut [u32],
#[spirv(uniform, descriptor_set = 0, binding = 1)] w: &(f32, u32),
) {
out[0] = w.0.to_bits() ^ w.1;
}
15 changes: 15 additions & 0 deletions tests/compiletests/ui/lang/abi/scalar_pair/i32_i32.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// build-pass
// compile-flags: -C target-feature=+Int64
#![no_std]

use spirv_std::spirv;

#[spirv(compute(threads(1)))]
pub fn main(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut [u32],
#[spirv(uniform, descriptor_set = 0, binding = 1)] w: &(i32, i32),
) {
// Sum and reinterpret as u32 for output
let s = (w.0 as i64 + w.1 as i64) as i32;
out[0] = s as u32;
}
14 changes: 14 additions & 0 deletions tests/compiletests/ui/lang/abi/scalar_pair/mixed.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// build-pass
#![no_std]

use spirv_std::spirv;

#[spirv(compute(threads(1)))]
pub fn main(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut [u32],
#[spirv(uniform, descriptor_set = 0, binding = 1)] w: &(u32, f32),
) {
let a = w.0;
let b_bits = w.1.to_bits();
out[0] = a ^ b_bits;
}
17 changes: 17 additions & 0 deletions tests/compiletests/ui/lang/abi/scalar_pair/newtype.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// build-pass
#![no_std]

use spirv_std::spirv;

#[repr(transparent)]
pub struct Wrap((u32, u32));

#[spirv(compute(threads(1)))]
pub fn main(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut [u32],
#[spirv(uniform, descriptor_set = 0, binding = 1)] w: &Wrap,
) {
let a = (w.0).0;
let b = (w.0).1;
out[0] = a + b;
}
27 changes: 27 additions & 0 deletions tests/compiletests/ui/lang/abi/scalar_pair/newtype_fn.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// build-pass
#![no_std]

use spirv_std::spirv;

#[repr(transparent)]
pub struct Inner((u32, u32));

#[repr(transparent)]
pub struct Outer(
core::mem::ManuallyDrop<Inner>,
core::marker::PhantomData<()>,
);

#[inline(never)]
fn sum_outer(o: Outer) -> u32 {
// SAFETY: repr(transparent) guarantees same layout as `Inner`.
let i: Inner = unsafe { core::mem::ManuallyDrop::into_inner((o.0)) };
(i.0).0 + (i.0).1
}

#[spirv(compute(threads(1)))]
pub fn main(#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut [u32]) {
let i = Inner((5, 7));
let o = Outer(core::mem::ManuallyDrop::new(i), core::marker::PhantomData);
out[0] = sum_outer(o);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// build-pass
#![no_std]

use spirv_std::spirv;

#[spirv(fragment)]
pub fn main(
#[spirv(flat)] pi: (u32, u32),
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] out: &mut [u32],
) {
out[0] = pi.0.wrapping_add(pi.1);
}
12 changes: 12 additions & 0 deletions tests/compiletests/ui/lang/abi/scalar_pair/pair_input_vertex.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// build-pass
#![no_std]

use spirv_std::spirv;

#[spirv(vertex)]
pub fn main(
p: (u32, u32),
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] out: &mut [u32],
) {
out[0] = p.0.wrapping_add(p.1);
}
12 changes: 12 additions & 0 deletions tests/compiletests/ui/lang/abi/scalar_pair/tuple.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// build-pass
#![no_std]

use spirv_std::spirv;

#[spirv(compute(threads(1)))]
pub fn main(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut [u32],
#[spirv(uniform, descriptor_set = 0, binding = 1)] w: &(u32, u32),
) {
out[0] = w.0 + w.1;
}
15 changes: 15 additions & 0 deletions tests/compiletests/ui/lang/abi/scalar_pair/u32_u64.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// build-pass
// compile-flags: -C target-feature=+Int64
#![no_std]

use spirv_std::spirv;

#[spirv(compute(threads(1)))]
pub fn main(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut [u32],
#[spirv(uniform, descriptor_set = 0, binding = 1)] w: &(u32, u64),
) {
let hi = (w.1 >> 32) as u32;
let lo = (w.1 & 0xFFFF_FFFF) as u32;
out[0] = w.0 ^ hi ^ lo;
}
16 changes: 16 additions & 0 deletions tests/compiletests/ui/lang/abi/scalar_pair/u64_u32.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// build-pass
// compile-flags: -C target-feature=+Int64
#![no_std]

use spirv_std::spirv;

#[spirv(compute(threads(1)))]
pub fn main(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut [u32],
#[spirv(uniform, descriptor_set = 0, binding = 1)] w: &(u64, u32),
) {
// Fold 64-bit into 32-bit deterministically
let hi = (w.0 >> 32) as u32;
let lo = (w.0 & 0xFFFF_FFFF) as u32;
out[0] = hi ^ lo ^ w.1;
}
1 change: 0 additions & 1 deletion tests/compiletests/ui/lang/core/intrinsics/black_box.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
use core::hint::black_box;
use spirv_std::spirv;

// Minimal kernel that writes the disassembly function result to a buffer
#[spirv(compute(threads(1)))]
pub fn main(#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut [u32]) {
let r = disassemble();
Expand Down
8 changes: 4 additions & 4 deletions tests/compiletests/ui/lang/core/intrinsics/black_box.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@ warning: black_box intrinsic does not prevent optimization in Rust GPU

%1 = OpFunction %2 DontInline %3
%4 = OpLabel
OpLine %5 32 17
OpLine %5 31 17
%6 = OpIAdd %7 %8 %9
OpLine %5 41 19
OpLine %5 40 19
%10 = OpIAdd %7 %11 %12
OpLine %5 47 8
OpLine %5 46 8
%13 = OpBitcast %7 %14
OpLine %15 1092 17
%16 = OpBitcast %7 %17
OpLine %5 46 4
OpLine %5 45 4
%18 = OpCompositeConstruct %2 %13 %16 %19 %20 %21 %22 %6 %23 %10 %24 %24 %24
OpNoLine
OpReturnValue %18
Expand Down
Loading