@@ -45,6 +45,8 @@ pub enum SpirvType<'tcx> {
45
45
element : Word ,
46
46
/// Note: vector count is literal.
47
47
count : u32 ,
48
+ size : Size ,
49
+ align : Align ,
48
50
} ,
49
51
Matrix {
50
52
element : Word ,
@@ -131,7 +133,9 @@ impl SpirvType<'_> {
131
133
}
132
134
result
133
135
}
134
- Self :: Vector { element, count } => cx. emit_global ( ) . type_vector_id ( id, element, count) ,
136
+ Self :: Vector { element, count, .. } => {
137
+ cx. emit_global ( ) . type_vector_id ( id, element, count)
138
+ }
135
139
Self :: Matrix { element, count } => cx. emit_global ( ) . type_matrix_id ( id, element, count) ,
136
140
Self :: Array { element, count } => {
137
141
let result = cx
@@ -280,9 +284,7 @@ impl SpirvType<'_> {
280
284
Self :: Bool => Size :: from_bytes ( 1 ) ,
281
285
Self :: Integer ( width, _) | Self :: Float ( width) => Size :: from_bits ( width) ,
282
286
Self :: Adt { size, .. } => size?,
283
- Self :: Vector { element, count } => {
284
- cx. lookup_type ( element) . sizeof ( cx) ? * count. next_power_of_two ( ) as u64
285
- }
287
+ Self :: Vector { size, .. } => size,
286
288
Self :: Matrix { element, count } => cx. lookup_type ( element) . sizeof ( cx) ? * count as u64 ,
287
289
Self :: Array { element, count } => {
288
290
cx. lookup_type ( element) . sizeof ( cx) ?
@@ -310,14 +312,7 @@ impl SpirvType<'_> {
310
312
311
313
Self :: Bool => Align :: from_bytes ( 1 ) . unwrap ( ) ,
312
314
Self :: Integer ( width, _) | Self :: Float ( width) => Align :: from_bits ( width as u64 ) . unwrap ( ) ,
313
- Self :: Adt { align, .. } => align,
314
- // Vectors have size==align
315
- Self :: Vector { .. } => Align :: from_bytes (
316
- self . sizeof ( cx)
317
- . expect ( "alignof: Vectors must be sized" )
318
- . bytes ( ) ,
319
- )
320
- . expect ( "alignof: Vectors must have power-of-2 size" ) ,
315
+ Self :: Adt { align, .. } | Self :: Vector { align, .. } => align,
321
316
Self :: Array { element, .. }
322
317
| Self :: RuntimeArray { element }
323
318
| Self :: Matrix { element, .. } => cx. lookup_type ( element) . alignof ( cx) ,
@@ -382,7 +377,17 @@ impl SpirvType<'_> {
382
377
SpirvType :: Bool => SpirvType :: Bool ,
383
378
SpirvType :: Integer ( width, signedness) => SpirvType :: Integer ( width, signedness) ,
384
379
SpirvType :: Float ( width) => SpirvType :: Float ( width) ,
385
- SpirvType :: Vector { element, count } => SpirvType :: Vector { element, count } ,
380
+ SpirvType :: Vector {
381
+ element,
382
+ count,
383
+ size,
384
+ align,
385
+ } => SpirvType :: Vector {
386
+ element,
387
+ count,
388
+ size,
389
+ align,
390
+ } ,
386
391
SpirvType :: Matrix { element, count } => SpirvType :: Matrix { element, count } ,
387
392
SpirvType :: Array { element, count } => SpirvType :: Array { element, count } ,
388
393
SpirvType :: RuntimeArray { element } => SpirvType :: RuntimeArray { element } ,
@@ -435,6 +440,15 @@ impl SpirvType<'_> {
435
440
} ,
436
441
}
437
442
}
443
+
444
+ pub fn simd_vector ( cx : & CodegenCx < ' _ > , span : Span , element : SpirvType < ' _ > , count : u32 ) -> Self {
445
+ Self :: Vector {
446
+ element : element. def ( span, cx) ,
447
+ count,
448
+ size : element. sizeof ( cx) . unwrap ( ) * count as u64 ,
449
+ align : element. alignof ( cx) ,
450
+ }
451
+ }
438
452
}
439
453
440
454
impl < ' a > SpirvType < ' a > {
@@ -501,11 +515,18 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_> {
501
515
. field ( "field_names" , & field_names)
502
516
. finish ( )
503
517
}
504
- SpirvType :: Vector { element, count } => f
518
+ SpirvType :: Vector {
519
+ element,
520
+ count,
521
+ size,
522
+ align,
523
+ } => f
505
524
. debug_struct ( "Vector" )
506
525
. field ( "id" , & self . id )
507
526
. field ( "element" , & self . cx . debug_type ( element) )
508
527
. field ( "count" , & count)
528
+ . field ( "size" , & size)
529
+ . field ( "align" , & align)
509
530
. finish ( ) ,
510
531
SpirvType :: Matrix { element, count } => f
511
532
. debug_struct ( "Matrix" )
@@ -668,7 +689,7 @@ impl SpirvTypePrinter<'_, '_> {
668
689
}
669
690
f. write_str ( " }" )
670
691
}
671
- SpirvType :: Vector { element, count } | SpirvType :: Matrix { element, count } => {
692
+ SpirvType :: Vector { element, count, .. } | SpirvType :: Matrix { element, count } => {
672
693
ty ( self . cx , stack, f, element) ?;
673
694
write ! ( f, "x{count}" )
674
695
}
0 commit comments