Skip to content

Commit 219c555

Browse files
committed
Add autocast for i1 vectors
1 parent 666f92e commit 219c555

File tree

3 files changed

+151
-17
lines changed

3 files changed

+151
-17
lines changed

compiler/rustc_codegen_llvm/src/abi.rs

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -375,26 +375,31 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
375375
return true;
376376
}
377377

378-
// Some LLVM intrinsics return **non-packed** structs, but they can't be mimicked from Rust
379-
// due to auto field-alignment in non-packed structs (packed structs are represented in LLVM
380-
// as, well, packed structs, so they won't match with those either)
381-
if self.type_kind(llvm_ty) == TypeKind::Struct
382-
&& self.type_kind(rust_ty) == TypeKind::Struct
383-
{
384-
let rust_element_tys = self.struct_element_types(rust_ty);
385-
let llvm_element_tys = self.struct_element_types(llvm_ty);
378+
match self.type_kind(llvm_ty) {
379+
// Some LLVM intrinsics return **non-packed** structs, but they can't be mimicked from Rust
380+
// due to auto field-alignment in non-packed structs (packed structs are represented in LLVM
381+
// as, well, packed structs, so they won't match with those either)
382+
TypeKind::Struct if self.type_kind(rust_ty) == TypeKind::Struct => {
383+
let rust_element_tys = self.struct_element_types(rust_ty);
384+
let llvm_element_tys = self.struct_element_types(llvm_ty);
385+
386+
if rust_element_tys.len() != llvm_element_tys.len() {
387+
return false;
388+
}
386389

387-
if rust_element_tys.len() != llvm_element_tys.len() {
388-
return false;
390+
iter::zip(rust_element_tys, llvm_element_tys).all(
391+
|(rust_element_ty, llvm_element_ty)| {
392+
self.equate_ty(rust_element_ty, llvm_element_ty)
393+
},
394+
)
389395
}
396+
TypeKind::Vector if self.element_type(llvm_ty) == self.type_i1() => {
397+
let element_count = self.vector_length(llvm_ty) as u64;
398+
let int_width = element_count.next_power_of_two().max(8);
390399

391-
iter::zip(rust_element_tys, llvm_element_tys).all(
392-
|(rust_element_ty, llvm_element_ty)| {
393-
self.equate_ty(rust_element_ty, llvm_element_ty)
394-
},
395-
)
396-
} else {
397-
false
400+
rust_ty == self.type_ix(int_width)
401+
}
402+
_ => false,
398403
}
399404
}
400405
}

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1680,6 +1680,46 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
16801680
}
16811681
}
16821682
impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
1683+
/// Turns `val` into a value of the i1-vector type `dest_ty`.
///
/// `val` is first bitcast to an i1 vector with `int_width` lanes, where `int_width` is the
/// lane count of `dest_ty` rounded up to a power of two and raised to at least 8. If
/// `dest_ty` has fewer lanes than that, a shuffle keeps only lanes `0..vector_length`.
///
/// NOTE(review): `val` is presumably an integer exactly `int_width` bits wide (the
/// Rust-side mask type) -- confirm against the caller.
fn trunc_int_to_i1_vector(&mut self, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value {
1684+
let vector_length = self.vector_length(dest_ty) as u64;
1685+
// Lane count rounded up to a power of two, never below 8.
let int_width = vector_length.next_power_of_two().max(8);
1686+
1687+
// Reinterpret `val` as `int_width` individual i1 lanes.
let bitcasted = self.bitcast(val, self.type_vector(self.type_i1(), int_width));
1688+
if vector_length == int_width {
1689+
// The destination already has `int_width` lanes; nothing to drop.
bitcasted
1690+
} else {
1691+
// Mask `0, 1, .., vector_length - 1`: keep the leading lanes, discard the rest.
let shuffle_mask =
1692+
(0..vector_length).map(|i| self.const_i32(i as i32)).collect::<Vec<_>>();
1693+
// Both shuffle operands are `bitcasted`; every mask index is below `int_width`.
self.shuffle_vector(bitcasted, bitcasted, self.const_vector(&shuffle_mask))
1694+
}
1695+
}
1696+
1697+
/// Turns the i1 vector `val` (of type `src_ty`) into a value of type `dest_ty`.
///
/// When the lane count of `src_ty` is not already `int_width` (the lane count rounded up
/// to a power of two and raised to at least 8), `val` is first widened to `int_width`
/// lanes by shuffling it with a null vector of type `src_ty`. The result is then bitcast
/// to `dest_ty`.
///
/// Panics (via `unreachable!`) if `src_ty` has zero lanes.
fn zext_i1_vector_to_int(
1698+
&mut self,
1699+
mut val: &'ll Value,
1700+
src_ty: &'ll Type,
1701+
dest_ty: &'ll Type,
1702+
) -> &'ll Value {
1703+
let vector_length = self.vector_length(src_ty) as u64;
1704+
// Lane count rounded up to a power of two, never below 8.
let int_width = vector_length.next_power_of_two().max(8);
1705+
1706+
if vector_length != int_width {
1707+
// In a shufflevector mask, indices `0..vector_length` select lanes of the first
// operand (`val`) and indices `vector_length..2 * vector_length` select lanes of the
// second operand (the null vector below). Each table is therefore the lanes of `val`
// followed by zero lanes.
let shuffle_indices = match vector_length {
1708+
0 => unreachable!("zero length vectors are not allowed"),
1709+
// The tables for 1 to 3 lanes repeat the few null-vector indices available,
// presumably because an index of `2 * vector_length` or more would be out of range.
1 => vec![0, 1, 1, 1, 1, 1, 1, 1],
1710+
2 => vec![0, 1, 2, 3, 2, 3, 2, 3],
1711+
3 => vec![0, 1, 2, 3, 4, 5, 3, 4],
1712+
// From 4 lanes up, `int_width <= 2 * vector_length`, so the plain sequence
// `0..int_width` stays within the two operands.
4.. => (0..int_width as i32).collect(),
1713+
};
1714+
let shuffle_mask =
1715+
shuffle_indices.into_iter().map(|i| self.const_i32(i)).collect::<Vec<_>>();
1716+
val =
1717+
self.shuffle_vector(val, self.const_null(src_ty), self.const_vector(&shuffle_mask));
1718+
}
1719+
1720+
// `val` now has `int_width` lanes (either originally or after the shuffle above).
self.bitcast(val, dest_ty)
1721+
}
1722+
16831723
fn autocast(
16841724
&mut self,
16851725
llfn: &'ll Value,
@@ -1708,6 +1748,13 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
17081748
}
17091749
ret
17101750
}
1751+
TypeKind::Vector if self.element_type(llvm_ty) == self.type_i1() => {
1752+
if is_argument {
1753+
self.trunc_int_to_i1_vector(val, dest_ty)
1754+
} else {
1755+
self.zext_i1_vector_to_int(val, src_ty, dest_ty)
1756+
}
1757+
}
17111758
_ => unreachable!(),
17121759
}
17131760
}

tests/codegen-llvm/inject-autocast.rs

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
//@ compile-flags: -C opt-level=0
2+
//@ only-x86_64
3+
4+
#![feature(link_llvm_intrinsics, abi_unadjusted, repr_simd, simd_ffi, portable_simd, f16)]
5+
#![crate_type = "lib"]
6+
7+
use std::simd::i64x2;
8+
9+
// `repr(simd)` wrapper around 1024 `i8` lanes. No function in this part of the file uses it.
#[repr(simd)]
10+
pub struct Tile([i8; 1024]);
11+
12+
// Packed struct with one `u32` and six `i64x2` fields; `struct_autocast` below uses it as
// the Rust-side return type of `llvm.x86.encodekey128`. The directive after the declaration
// pins its LLVM type as a packed struct.
#[repr(C, packed)]
13+
pub struct Bar(u32, i64x2, i64x2, i64x2, i64x2, i64x2, i64x2);
14+
// CHECK: %Bar = type <{ i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> }>
15+
16+
// `repr(simd)` vector of eight `f16` lanes; argument type of `i1_vector_autocast` below.
#[repr(simd)]
17+
pub struct f16x8([f16; 8]);
18+
19+
// CHECK-LABEL: @struct_with_i1_vector_autocast
20+
#[no_mangle]
21+
// Calls `llvm.x86.avx512.vp2intersect.q.128` through a declaration that returns `(u8, u8)`.
// The FileCheck directives in the body expect the call to produce `{ <2 x i1>, <2 x i1> }`,
// and each member to be widened to `<8 x i1>` by a shufflevector against `zeroinitializer`
// and then bitcast to `i8`.
pub unsafe fn struct_with_i1_vector_autocast(a: i64x2, b: i64x2) -> (u8, u8) {
22+
// Rust-side declaration: the mask results are spelled as `u8`, not as i1 vectors.
extern "unadjusted" {
23+
#[link_name = "llvm.x86.avx512.vp2intersect.q.128"]
24+
fn foo(a: i64x2, b: i64x2) -> (u8, u8);
25+
}
26+
27+
// CHECK: %2 = call { <2 x i1>, <2 x i1> } @llvm.x86.avx512.vp2intersect.q.128(<2 x i64> %0, <2 x i64> %1)
28+
// CHECK-NEXT: %3 = extractvalue { <2 x i1>, <2 x i1> } %2, 0
29+
// CHECK-NEXT: %4 = shufflevector <2 x i1> %3, <2 x i1> zeroinitializer, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 2, i32 3, i32 2, i32 3>
30+
// CHECK-NEXT: %5 = bitcast <8 x i1> %4 to i8
31+
// CHECK-NEXT: %6 = insertvalue { i8, i8 } poison, i8 %5, 0
32+
// CHECK-NEXT: %7 = extractvalue { <2 x i1>, <2 x i1> } %2, 1
33+
// CHECK-NEXT: %8 = shufflevector <2 x i1> %7, <2 x i1> zeroinitializer, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 2, i32 3, i32 2, i32 3>
34+
// CHECK-NEXT: %9 = bitcast <8 x i1> %8 to i8
35+
// CHECK-NEXT: %10 = insertvalue { i8, i8 } %6, i8 %9, 1
36+
foo(a, b)
37+
}
38+
39+
// CHECK-LABEL: @struct_autocast
40+
#[no_mangle]
41+
// Calls `llvm.x86.encodekey128` through a declaration that returns the packed struct `Bar`.
// The FileCheck directives in the body expect the call to produce a non-packed seven-field
// struct, which is then copied field by field (extractvalue + insertvalue) into `%Bar`.
pub unsafe fn struct_autocast(key_metadata: u32, key: i64x2) -> Bar {
42+
// Rust-side declaration: the return type is the packed `Bar`.
extern "unadjusted" {
43+
#[link_name = "llvm.x86.encodekey128"]
44+
fn foo(key_metadata: u32, key: i64x2) -> Bar;
45+
}
46+
47+
// CHECK: %1 = call { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } @llvm.x86.encodekey128(i32 %key_metadata, <2 x i64> %0)
48+
// CHECK-NEXT: %2 = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } %1, 0
49+
// CHECK-NEXT: %3 = insertvalue %Bar poison, i32 %2, 0
50+
// CHECK-NEXT: %4 = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } %1, 1
51+
// CHECK-NEXT: %5 = insertvalue %Bar %3, <2 x i64> %4, 1
52+
// CHECK-NEXT: %6 = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } %1, 2
53+
// CHECK-NEXT: %7 = insertvalue %Bar %5, <2 x i64> %6, 2
54+
// CHECK-NEXT: %8 = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } %1, 3
55+
// CHECK-NEXT: %9 = insertvalue %Bar %7, <2 x i64> %8, 3
56+
// CHECK-NEXT: %10 = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } %1, 4
57+
// CHECK-NEXT: %11 = insertvalue %Bar %9, <2 x i64> %10, 4
58+
// CHECK-NEXT: %12 = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } %1, 5
59+
// CHECK-NEXT: %13 = insertvalue %Bar %11, <2 x i64> %12, 5
60+
// CHECK-NEXT: %14 = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } %1, 6
61+
// CHECK-NEXT: %15 = insertvalue %Bar %13, <2 x i64> %14, 6
62+
foo(key_metadata, key)
63+
}
64+
65+
// CHECK-LABEL: @i1_vector_autocast
66+
#[no_mangle]
67+
// Calls `llvm.x86.avx512fp16.fpclass.ph.128` through a declaration that returns `u8`, passing
// the constant 1 as the second argument. The FileCheck directives in the body expect the call
// to produce `<8 x i1>`, which is bitcast straight to `i8` with no shufflevector in between.
pub unsafe fn i1_vector_autocast(a: f16x8) -> u8 {
68+
// Rust-side declaration: the mask result is spelled as `u8`, not as an i1 vector.
extern "unadjusted" {
69+
#[link_name = "llvm.x86.avx512fp16.fpclass.ph.128"]
70+
fn foo(a: f16x8, b: i32) -> u8;
71+
}
72+
73+
// CHECK: %1 = call <8 x i1> @llvm.x86.avx512fp16.fpclass.ph.128(<8 x half> %0, i32 1)
74+
// CHECK-NEXT: %_0 = bitcast <8 x i1> %1 to i8
75+
foo(a, 1)
76+
}
77+
78+
// CHECK: declare { <2 x i1>, <2 x i1> } @llvm.x86.avx512.vp2intersect.q.128(<2 x i64>, <2 x i64>)
79+
80+
// CHECK: declare { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } @llvm.x86.encodekey128(i32, <2 x i64>)
81+
82+
// CHECK: declare <8 x i1> @llvm.x86.avx512fp16.fpclass.ph.128(<8 x half>, i32 immarg)

0 commit comments

Comments
 (0)