Skip to content

Commit d8b1b46

Browse files
committed
WIP
1 parent 257377c commit d8b1b46

File tree

3 files changed

+137
-77
lines changed

3 files changed

+137
-77
lines changed

crates/rustc_codegen_spirv/src/abi.rs

Lines changed: 132 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
use crate::attr::{AggregatedSpirvAttributes, IntrinsicType};
55
use crate::codegen_cx::CodegenCx;
66
use crate::spirv_type::SpirvType;
7+
use crate::symbols::Symbols;
78
use itertools::Itertools;
89
use rspirv::spirv::{Dim, ImageFormat, StorageClass, Word};
910
use rustc_abi::{AbiAlign, ExternAbi as Abi};
@@ -16,10 +17,10 @@ use rustc_errors::ErrorGuaranteed;
1617
use rustc_hashes::Hash64;
1718
use rustc_index::Idx;
1819
use rustc_middle::query::Providers;
19-
use rustc_middle::ty::layout::{FnAbiOf, LayoutOf, TyAndLayout};
20+
use rustc_middle::ty::layout::{FnAbiOf, LayoutError, LayoutOf, TyAndLayout};
2021
use rustc_middle::ty::{
21-
self, AdtDef, Const, CoroutineArgs, CoroutineArgsExt as _, FloatTy, IntTy, PolyFnSig, Ty,
22-
TyCtxt, TyKind, UintTy,
22+
self, AdtDef, Const, CoroutineArgs, CoroutineArgsExt as _, FloatTy, GenericArgs, IntTy,
23+
PolyFnSig, Ty, TyCtxt, TyKind, TypingEnv, UintTy,
2324
};
2425
use rustc_middle::ty::{GenericArgsRef, ScalarInt};
2526
use rustc_middle::{bug, span_bug};
@@ -169,85 +170,14 @@ pub(crate) fn provide(providers: &mut Providers) {
169170
fn layout_of<'tcx>(
170171
tcx: TyCtxt<'tcx>,
171172
key: ty::PseudoCanonicalInput<'tcx, Ty<'tcx>>,
172-
) -> Result<TyAndLayout<'tcx>, &'tcx ty::layout::LayoutError<'tcx>> {
173+
) -> Result<TyAndLayout<'tcx>, &'tcx LayoutError<'tcx>> {
173174
// HACK(eddyb) to special-case any types at all, they must be normalized,
174175
// but when normalization would be needed, `layout_of`'s default provider
175176
// recurses (supposedly for caching reasons), i.e. its calls `layout_of`
176177
// w/ the normalized type in input, which once again reaches this hook,
177178
// without ever needing any explicit normalization here.
178-
let ty = key.value;
179-
180-
// HACK(eddyb) bypassing upstream `#[repr(simd)]` changes (see also
181-
// the later comment above `check_well_formed`, for more details).
182-
let reimplement_old_style_repr_simd: Option<(&AdtDef<'tcx>, Ty<'tcx>, u64)> = match ty
183-
.kind()
184-
{
185-
ty::Adt(def, args) if def.repr().simd() && !def.repr().packed() && def.is_struct() => {
186-
Some(def.non_enum_variant()).and_then(|v| {
187-
let (count, e_ty) = v
188-
.fields
189-
.iter()
190-
.map(|f| f.ty(tcx, args))
191-
.dedup_with_count()
192-
.exactly_one()
193-
.ok()?;
194-
let e_len = u64::try_from(count).ok().filter(|&e_len| e_len > 1)?;
195-
Some((def, e_ty, e_len))
196-
})
197-
}
198-
_ => None,
199-
};
200-
201-
// HACK(eddyb) tweaked copy of the old upstream logic for `#[repr(simd)]`:
202-
// https://github.com/rust-lang/rust/blob/1.86.0/compiler/rustc_ty_utils/src/layout.rs#L464-L590
203-
if let Some((adt_def, e_ty, e_len)) = reimplement_old_style_repr_simd {
204-
let cx = rustc_middle::ty::layout::LayoutCx::new(
205-
tcx,
206-
key.typing_env.with_post_analysis_normalized(tcx),
207-
);
208-
209-
// Compute the ABI of the element type:
210-
let e_ly: TyAndLayout<'_> = cx.layout_of(e_ty)?;
211-
let BackendRepr::Scalar(e_repr) = e_ly.backend_repr else {
212-
// This error isn't caught in typeck, e.g., if
213-
// the element type of the vector is generic.
214-
tcx.dcx().span_fatal(
215-
tcx.def_span(adt_def.did()),
216-
format!(
217-
"SIMD type `{ty}` with a non-primitive-scalar \
218-
(integer/float/pointer) element type `{}`",
219-
e_ly.ty
220-
),
221-
);
222-
};
223-
224-
// Compute the size and alignment of the vector:
225-
let size = e_ly.size.checked_mul(e_len, &cx).unwrap();
226-
let align = adt_def.repr().align.unwrap_or(e_ly.align.abi);
227-
let size = size.align_to(align);
228-
229-
let layout = tcx.mk_layout(LayoutData {
230-
variants: Variants::Single {
231-
index: rustc_abi::FIRST_VARIANT,
232-
},
233-
fields: FieldsShape::Array {
234-
stride: e_ly.size,
235-
count: e_len,
236-
},
237-
backend_repr: BackendRepr::SimdVector {
238-
element: e_repr,
239-
count: e_len,
240-
},
241-
largest_niche: e_ly.largest_niche,
242-
uninhabited: false,
243-
size,
244-
align: AbiAlign::new(align),
245-
max_repr_align: None,
246-
unadjusted_abi_align: align,
247-
randomization_seed: e_ly.randomization_seed.wrapping_add(Hash64::new(e_len)),
248-
});
249-
250-
return Ok(TyAndLayout { ty, layout });
179+
if let Some(layout) = layout_of_spirv_attr_special(tcx, key)? {
180+
return Ok(layout);
251181
}
252182

253183
let TyAndLayout { ty, mut layout } =
@@ -276,6 +206,128 @@ pub(crate) fn provide(providers: &mut Providers) {
276206
Ok(TyAndLayout { ty, layout })
277207
}
278208

209+
fn layout_of_spirv_attr_special<'tcx>(
210+
tcx: TyCtxt<'tcx>,
211+
key: ty::PseudoCanonicalInput<'tcx, Ty<'tcx>>,
212+
) -> Result<Option<TyAndLayout<'tcx>>, &'tcx LayoutError<'tcx>> {
213+
let ty::PseudoCanonicalInput {
214+
typing_env,
215+
value: ty,
216+
} = key;
217+
218+
match ty.kind() {
219+
ty::Adt(def, args) => {
220+
let def: &AdtDef<'tcx> = def;
221+
let args: &'tcx GenericArgs<'tcx> = args;
222+
let attrs = AggregatedSpirvAttributes::parse(
223+
tcx,
224+
&Symbols::get(),
225+
tcx.get_all_attrs(def.did()),
226+
);
227+
228+
// add spirv-attr special layouts here
229+
if let Some(layout) =
230+
layout_of_spirv_vector(tcx, typing_env, ty, def, args, &attrs)?
231+
{
232+
return Ok(Some(layout));
233+
}
234+
}
235+
_ => {}
236+
}
237+
Ok(None)
238+
}
239+
240+
fn layout_of_spirv_vector<'tcx>(
241+
tcx: TyCtxt<'tcx>,
242+
typing_env: TypingEnv<'tcx>,
243+
ty: Ty<'tcx>,
244+
def: &AdtDef<'tcx>,
245+
args: &'tcx GenericArgs<'tcx>,
246+
attrs: &AggregatedSpirvAttributes,
247+
) -> Result<Option<TyAndLayout<'tcx>>, &'tcx LayoutError<'tcx>> {
248+
let layout_err = |msg| {
249+
&*tcx.arena.alloc(LayoutError::ReferencesError(
250+
tcx.dcx().span_err(tcx.def_span(def.did()), msg),
251+
))
252+
};
253+
254+
let has_spirv_vector_attr = attrs
255+
.intrinsic_type
256+
.as_ref()
257+
.map_or(false, |attr| matches!(attr.value, IntrinsicType::Vector));
258+
let has_repr_simd = def.repr().simd() && !def.repr().packed();
259+
if !has_spirv_vector_attr && !has_repr_simd {
260+
return Ok(None);
261+
}
262+
263+
if !def.is_struct() {
264+
return Err(layout_err(format!(
265+
"spirv vector type `{ty}` must be a struct"
266+
)));
267+
}
268+
let (count, e_ty) = def
269+
.non_enum_variant()
270+
.fields
271+
.iter()
272+
.map(|f| f.ty(tcx, args))
273+
.dedup_with_count()
274+
.exactly_one()
275+
.map_err(|_| {
276+
layout_err(format!(
277+
"spirv vector type `{ty}` must have a single element type"
278+
))
279+
})?;
280+
let e_len = u64::try_from(count)
281+
.ok()
282+
.filter(|&e_len| e_len >= 2)
283+
.ok_or_else(|| {
284+
layout_err(format!(
285+
"spirv vector type `{ty}` to have at least 2 elements"
286+
))
287+
})?;
288+
289+
let lcx = ty::layout::LayoutCx::new(tcx, typing_env.with_post_analysis_normalized(tcx));
290+
291+
// Compute the ABI of the element type:
292+
let e_ly: TyAndLayout<'_> = lcx.layout_of(e_ty)?;
293+
let BackendRepr::Scalar(e_repr) = e_ly.backend_repr else {
294+
// This error isn't caught in typeck, e.g., if
295+
// the element type of the vector is generic.
296+
return Err(layout_err(format!(
297+
"spirv vector type `{ty}` must have a non-primitive-scalar (integer/float/pointer) element type, got `{}`",
298+
e_ly.ty
299+
)));
300+
};
301+
302+
// Compute the size and alignment of the vector:
303+
let size = e_ly.size.checked_mul(e_len, &lcx).unwrap();
304+
let align = def.repr().align.unwrap_or(e_ly.align.abi);
305+
let size = size.align_to(align);
306+
307+
let layout = tcx.mk_layout(LayoutData {
308+
variants: Variants::Single {
309+
index: rustc_abi::FIRST_VARIANT,
310+
},
311+
fields: FieldsShape::Array {
312+
stride: e_ly.size,
313+
count: e_len,
314+
},
315+
backend_repr: BackendRepr::SimdVector {
316+
element: e_repr,
317+
count: e_len,
318+
},
319+
largest_niche: e_ly.largest_niche,
320+
uninhabited: false,
321+
size,
322+
align: AbiAlign::new(align),
323+
max_repr_align: None,
324+
unadjusted_abi_align: align,
325+
randomization_seed: e_ly.randomization_seed.wrapping_add(Hash64::new(e_len)),
326+
});
327+
328+
Ok(Some(TyAndLayout { ty, layout }))
329+
}
330+
279331
// HACK(eddyb) work around https://github.com/rust-lang/rust/pull/129403
280332
// banning "struct-style" `#[repr(simd)]` (in favor of "array-newtype-style"),
281333
// by simply bypassing "type definition WF checks" for affected types, which:
@@ -1265,5 +1317,8 @@ fn trans_intrinsic_type<'tcx>(
12651317
}
12661318
.def(span, cx))
12671319
}
1320+
IntrinsicType::Vector => {
1321+
todo!()
1322+
}
12681323
}
12691324
}

crates/rustc_codegen_spirv/src/attr.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ pub enum IntrinsicType {
6565
RuntimeArray,
6666
TypedBuffer,
6767
Matrix,
68+
Vector,
6869
}
6970

7071
#[derive(Copy, Clone, Debug, PartialEq, Eq)]

crates/rustc_codegen_spirv/src/symbols.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,10 @@ impl Symbols {
373373
"matrix",
374374
SpirvAttribute::IntrinsicType(IntrinsicType::Matrix),
375375
),
376+
(
377+
"vector",
378+
SpirvAttribute::IntrinsicType(IntrinsicType::Vector),
379+
),
376380
("buffer_load_intrinsic", SpirvAttribute::BufferLoadIntrinsic),
377381
(
378382
"buffer_store_intrinsic",

0 commit comments

Comments
 (0)