Skip to content

Commit 75d9135

Browse files
committed
abi layout: give Vector a dynamic size and alignment
1 parent 08be456 commit 75d9135

File tree

7 files changed

+138
-59
lines changed

7 files changed

+138
-59
lines changed

crates/rustc_codegen_spirv/src/abi.rs

Lines changed: 78 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,8 @@ impl<'tcx> ConvSpirvType<'tcx> for TyAndLayout<'tcx> {
639639
SpirvType::Vector {
640640
element: elem_spirv,
641641
count: count as u32,
642+
size: self.size,
643+
align: self.align.abi,
642644
}
643645
.def(span, cx)
644646
}
@@ -1220,43 +1222,92 @@ fn trans_intrinsic_type<'tcx>(
12201222
}
12211223
}
12221224
IntrinsicType::Matrix => {
1223-
let span = def_id_for_spirv_type_adt(ty)
1224-
.map(|did| cx.tcx.def_span(did))
1225-
.expect("#[spirv(matrix)] must be added to a type which has DefId");
1226-
1227-
let field_types = (0..ty.fields.count())
1228-
.map(|i| ty.field(cx, i).spirv_type(span, cx))
1229-
.collect::<Vec<_>>();
1230-
if field_types.len() < 2 {
1231-
return Err(cx
1232-
.tcx
1233-
.dcx()
1234-
.span_err(span, "#[spirv(matrix)] type must have at least two fields"));
1235-
}
1236-
let elem_type = field_types[0];
1237-
if !field_types.iter().all(|&ty| ty == elem_type) {
1238-
return Err(cx.tcx.dcx().span_err(
1239-
span,
1240-
"#[spirv(matrix)] type fields must all be the same type",
1241-
));
1242-
}
1243-
match cx.lookup_type(elem_type) {
1225+
let (element, count) =
1226+
trans_glam_like_struct(cx, span, ty, args, "`#[spirv(matrix)]`")?;
1227+
match cx.lookup_type(element) {
12441228
SpirvType::Vector { .. } => (),
12451229
ty => {
12461230
return Err(cx
12471231
.tcx
12481232
.dcx()
1249-
.struct_span_err(span, "#[spirv(matrix)] type fields must all be vectors")
1250-
.with_note(format!("field type is {}", ty.debug(elem_type, cx)))
1233+
.struct_span_err(span, "`#[spirv(matrix)]` type fields must all be vectors")
1234+
.with_note(format!("field type is {}", ty.debug(element, cx)))
12511235
.emit());
12521236
}
12531237
}
1254-
1255-
Ok(SpirvType::Matrix {
1256-
element: elem_type,
1257-
count: field_types.len() as u32,
1238+
Ok(SpirvType::Matrix { element, count }.def(span, cx))
1239+
}
1240+
IntrinsicType::Vector => {
1241+
let (element, count) =
1242+
trans_glam_like_struct(cx, span, ty, args, "`#[spirv(vector)]`")?;
1243+
match cx.lookup_type(element) {
1244+
SpirvType::Float { .. } | SpirvType::Integer { .. } => (),
1245+
ty => {
1246+
return Err(cx
1247+
.tcx
1248+
.dcx()
1249+
.struct_span_err(
1250+
span,
1251+
"`#[spirv(vector)]` type fields must all be floats or integers",
1252+
)
1253+
.with_note(format!("field type is {}", ty.debug(element, cx)))
1254+
.emit());
1255+
}
1256+
}
1257+
Ok(SpirvType::Vector {
1258+
element,
1259+
count,
1260+
size: ty.size,
1261+
align: ty.align.abi,
12581262
}
12591263
.def(span, cx))
12601264
}
12611265
}
12621266
}
1267+
1268+
/// A struct with multiple fields of the same kind
1269+
/// Used for `#[spirv(vector)]` and `#[spirv(matrix)]`
1270+
fn trans_glam_like_struct<'tcx>(
1271+
cx: &CodegenCx<'tcx>,
1272+
span: Span,
1273+
ty: TyAndLayout<'tcx>,
1274+
args: GenericArgsRef<'tcx>,
1275+
err_attr_name: &str,
1276+
) -> Result<(Word, u32), ErrorGuaranteed> {
1277+
let tcx = cx.tcx;
1278+
if let Some(adt) = ty.ty.ty_adt_def()
1279+
&& adt.is_struct()
1280+
{
1281+
let (count, element) = adt
1282+
.non_enum_variant()
1283+
.fields
1284+
.iter()
1285+
.map(|f| f.ty(tcx, args))
1286+
.dedup_with_count()
1287+
.exactly_one()
1288+
.map_err(|_e| {
1289+
tcx.dcx().span_err(
1290+
span,
1291+
format!("{err_attr_name} member types must all be the same"),
1292+
)
1293+
})?;
1294+
1295+
let element = cx.layout_of(element);
1296+
let element_word = element.spirv_type(span, cx);
1297+
let count = u32::try_from(count)
1298+
.ok()
1299+
.filter(|count| *count >= 2)
1300+
.ok_or_else(|| {
1301+
tcx.dcx().span_err(
1302+
span,
1303+
format!("{err_attr_name} must have at least 2 members"),
1304+
)
1305+
})?;
1306+
1307+
Ok((element_word, count))
1308+
} else {
1309+
Err(tcx
1310+
.dcx()
1311+
.span_err(span, "#[spirv(vector)] type must be a struct"))
1312+
}
1313+
}

crates/rustc_codegen_spirv/src/attr.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ pub enum IntrinsicType {
6666
RuntimeArray,
6767
TypedBuffer,
6868
Matrix,
69+
Vector,
6970
}
7071

7172
#[derive(Copy, Clone, Debug, PartialEq, Eq)]

crates/rustc_codegen_spirv/src/builder/builder_methods.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -317,10 +317,12 @@ fn memset_dynamic_scalar(
317317
byte_width: usize,
318318
is_float: bool,
319319
) -> Word {
320-
let composite_type = SpirvType::Vector {
321-
element: SpirvType::Integer(8, false).def(builder.span(), builder),
322-
count: byte_width as u32,
323-
}
320+
let composite_type = SpirvType::simd_vector(
321+
builder,
322+
builder.span(),
323+
SpirvType::Integer(8, false),
324+
byte_width as u32,
325+
)
324326
.def(builder.span(), builder);
325327
let composite = builder
326328
.emit()
@@ -417,7 +419,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
417419
_ => self.fatal(format!("memset on float width {width} not implemented yet")),
418420
},
419421
SpirvType::Adt { .. } => self.fatal("memset on structs not implemented yet"),
420-
SpirvType::Vector { element, count } | SpirvType::Matrix { element, count } => {
422+
SpirvType::Vector { element, count, .. } | SpirvType::Matrix { element, count } => {
421423
let elem_pat = self.memset_const_pattern(&self.lookup_type(element), fill_byte);
422424
self.constant_composite(
423425
ty.def(self.span(), self),
@@ -478,7 +480,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
478480
)
479481
.unwrap()
480482
}
481-
SpirvType::Vector { element, count } | SpirvType::Matrix { element, count } => {
483+
SpirvType::Vector { element, count, .. } | SpirvType::Matrix { element, count } => {
482484
let elem_pat = self.memset_dynamic_pattern(&self.lookup_type(element), fill_var);
483485
self.emit()
484486
.composite_construct(
@@ -2966,11 +2968,9 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
29662968
}
29672969

29682970
fn vector_splat(&mut self, num_elts: usize, elt: Self::Value) -> Self::Value {
2969-
let result_type = SpirvType::Vector {
2970-
element: elt.ty,
2971-
count: num_elts as u32,
2972-
}
2973-
.def(self.span(), self);
2971+
let result_type =
2972+
SpirvType::simd_vector(self, self.span(), self.lookup_type(elt.ty), num_elts as u32)
2973+
.def(self.span(), self);
29742974
if self.builder.lookup_const(elt).is_some() {
29752975
self.constant_composite(result_type, iter::repeat_n(elt.def(self), num_elts))
29762976
} else {

crates/rustc_codegen_spirv/src/builder/byte_addressable_buffer.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
113113
let val = self.load_u32(array, dynamic_word_index, constant_word_offset);
114114
self.bitcast(val, result_type)
115115
}
116-
SpirvType::Vector { element, count } | SpirvType::Matrix { element, count } => self
116+
SpirvType::Vector { element, count, .. } | SpirvType::Matrix { element, count } => self
117117
.load_vec_mat_arr(
118118
original_type,
119119
result_type,
@@ -312,7 +312,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
312312
let value_u32 = self.bitcast(value, u32_ty);
313313
self.store_u32(array, dynamic_word_index, constant_word_offset, value_u32)
314314
}
315-
SpirvType::Vector { element, count } | SpirvType::Matrix { element, count } => self
315+
SpirvType::Vector { element, count, .. } | SpirvType::Matrix { element, count } => self
316316
.store_vec_mat_arr(
317317
original_type,
318318
value,

crates/rustc_codegen_spirv/src/codegen_cx/constant.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,10 +200,12 @@ impl ConstCodegenMethods for CodegenCx<'_> {
200200
self.constant_composite(struct_ty, elts.iter().map(|f| f.def_cx(self)))
201201
}
202202
fn const_vector(&self, elts: &[Self::Value]) -> Self::Value {
203-
let vector_ty = SpirvType::Vector {
204-
element: elts[0].ty,
205-
count: elts.len() as u32,
206-
}
203+
let vector_ty = SpirvType::simd_vector(
204+
self,
205+
DUMMY_SP,
206+
self.lookup_type(elts[0].ty),
207+
elts.len() as u32,
208+
)
207209
.def(DUMMY_SP, self);
208210
self.constant_composite(vector_ty, elts.iter().map(|elt| elt.def_cx(self)))
209211
}

crates/rustc_codegen_spirv/src/spirv_type.rs

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ pub enum SpirvType<'tcx> {
4545
element: Word,
4646
/// Note: vector count is literal.
4747
count: u32,
48+
size: Size,
49+
align: Align,
4850
},
4951
Matrix {
5052
element: Word,
@@ -131,7 +133,9 @@ impl SpirvType<'_> {
131133
}
132134
result
133135
}
134-
Self::Vector { element, count } => cx.emit_global().type_vector_id(id, element, count),
136+
Self::Vector { element, count, .. } => {
137+
cx.emit_global().type_vector_id(id, element, count)
138+
}
135139
Self::Matrix { element, count } => cx.emit_global().type_matrix_id(id, element, count),
136140
Self::Array { element, count } => {
137141
let result = cx
@@ -280,9 +284,7 @@ impl SpirvType<'_> {
280284
Self::Bool => Size::from_bytes(1),
281285
Self::Integer(width, _) | Self::Float(width) => Size::from_bits(width),
282286
Self::Adt { size, .. } => size?,
283-
Self::Vector { element, count } => {
284-
cx.lookup_type(element).sizeof(cx)? * count.next_power_of_two() as u64
285-
}
287+
Self::Vector { size, .. } => size,
286288
Self::Matrix { element, count } => cx.lookup_type(element).sizeof(cx)? * count as u64,
287289
Self::Array { element, count } => {
288290
cx.lookup_type(element).sizeof(cx)?
@@ -310,14 +312,7 @@ impl SpirvType<'_> {
310312

311313
Self::Bool => Align::from_bytes(1).unwrap(),
312314
Self::Integer(width, _) | Self::Float(width) => Align::from_bits(width as u64).unwrap(),
313-
Self::Adt { align, .. } => align,
314-
// Vectors have size==align
315-
Self::Vector { .. } => Align::from_bytes(
316-
self.sizeof(cx)
317-
.expect("alignof: Vectors must be sized")
318-
.bytes(),
319-
)
320-
.expect("alignof: Vectors must have power-of-2 size"),
315+
Self::Adt { align, .. } | Self::Vector { align, .. } => align,
321316
Self::Array { element, .. }
322317
| Self::RuntimeArray { element }
323318
| Self::Matrix { element, .. } => cx.lookup_type(element).alignof(cx),
@@ -382,7 +377,17 @@ impl SpirvType<'_> {
382377
SpirvType::Bool => SpirvType::Bool,
383378
SpirvType::Integer(width, signedness) => SpirvType::Integer(width, signedness),
384379
SpirvType::Float(width) => SpirvType::Float(width),
385-
SpirvType::Vector { element, count } => SpirvType::Vector { element, count },
380+
SpirvType::Vector {
381+
element,
382+
count,
383+
size,
384+
align,
385+
} => SpirvType::Vector {
386+
element,
387+
count,
388+
size,
389+
align,
390+
},
386391
SpirvType::Matrix { element, count } => SpirvType::Matrix { element, count },
387392
SpirvType::Array { element, count } => SpirvType::Array { element, count },
388393
SpirvType::RuntimeArray { element } => SpirvType::RuntimeArray { element },
@@ -435,6 +440,15 @@ impl SpirvType<'_> {
435440
},
436441
}
437442
}
443+
444+
pub fn simd_vector(cx: &CodegenCx<'_>, span: Span, element: SpirvType<'_>, count: u32) -> Self {
445+
Self::Vector {
446+
element: element.def(span, cx),
447+
count,
448+
size: element.sizeof(cx).unwrap() * count as u64,
449+
align: element.alignof(cx),
450+
}
451+
}
438452
}
439453

440454
impl<'a> SpirvType<'a> {
@@ -501,11 +515,18 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_> {
501515
.field("field_names", &field_names)
502516
.finish()
503517
}
504-
SpirvType::Vector { element, count } => f
518+
SpirvType::Vector {
519+
element,
520+
count,
521+
size,
522+
align,
523+
} => f
505524
.debug_struct("Vector")
506525
.field("id", &self.id)
507526
.field("element", &self.cx.debug_type(element))
508527
.field("count", &count)
528+
.field("size", &size)
529+
.field("align", &align)
509530
.finish(),
510531
SpirvType::Matrix { element, count } => f
511532
.debug_struct("Matrix")
@@ -668,7 +689,7 @@ impl SpirvTypePrinter<'_, '_> {
668689
}
669690
f.write_str(" }")
670691
}
671-
SpirvType::Vector { element, count } | SpirvType::Matrix { element, count } => {
692+
SpirvType::Vector { element, count, .. } | SpirvType::Matrix { element, count } => {
672693
ty(self.cx, stack, f, element)?;
673694
write!(f, "x{count}")
674695
}

crates/rustc_codegen_spirv/src/symbols.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,10 @@ impl Symbols {
373373
"matrix",
374374
SpirvAttribute::IntrinsicType(IntrinsicType::Matrix),
375375
),
376+
(
377+
"vector",
378+
SpirvAttribute::IntrinsicType(IntrinsicType::Vector),
379+
),
376380
("buffer_load_intrinsic", SpirvAttribute::BufferLoadIntrinsic),
377381
(
378382
"buffer_store_intrinsic",

0 commit comments

Comments
 (0)