Commit b73089f

Make {add,sub,mul,div}_sh functions const

1 parent 3b7dd43 commit b73089f

1 file changed: +104 −24 lines changed

crates/core_arch/src/x86/avx512fp16.rs

Lines changed: 104 additions & 24 deletions
@@ -1683,8 +1683,9 @@ pub fn _mm_maskz_add_round_sh<const ROUNDING: i32>(k: __mmask8, a: __m128h, b: _
 #[target_feature(enable = "avx512fp16")]
 #[cfg_attr(test, assert_instr(vaddsh))]
 #[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
-pub fn _mm_add_sh(a: __m128h, b: __m128h) -> __m128h {
-    _mm_add_round_sh::<_MM_FROUND_CUR_DIRECTION>(a, b)
+#[rustc_const_unstable(feature = "stdarch_const_intrinsics", issue = "none")]
+pub const fn _mm_add_sh(a: __m128h, b: __m128h) -> __m128h {
+    unsafe { simd_insert!(a, 0, _mm_cvtsh_h(a) + _mm_cvtsh_h(b)) }
 }

 /// Add the lower half-precision (16-bit) floating-point elements in a and b, store the result in the
@@ -1696,8 +1697,18 @@ pub fn _mm_add_sh(a: __m128h, b: __m128h) -> __m128h {
 #[target_feature(enable = "avx512fp16")]
 #[cfg_attr(test, assert_instr(vaddsh))]
 #[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
-pub fn _mm_mask_add_sh(src: __m128h, k: __mmask8, a: __m128h, b: __m128h) -> __m128h {
-    _mm_mask_add_round_sh::<_MM_FROUND_CUR_DIRECTION>(src, k, a, b)
+#[rustc_const_unstable(feature = "stdarch_const_intrinsics", issue = "none")]
+pub const fn _mm_mask_add_sh(src: __m128h, k: __mmask8, a: __m128h, b: __m128h) -> __m128h {
+    unsafe {
+        let extractsrc: f16 = simd_extract!(src, 0);
+        let mut add: f16 = extractsrc;
+        if (k & 0b00000001) != 0 {
+            let extracta: f16 = simd_extract!(a, 0);
+            let extractb: f16 = simd_extract!(b, 0);
+            add = extracta + extractb;
+        }
+        simd_insert!(a, 0, add)
+    }
 }

 /// Add the lower half-precision (16-bit) floating-point elements in a and b, store the result in the
@@ -1709,8 +1720,17 @@ pub fn _mm_mask_add_sh(src: __m128h, k: __mmask8, a: __m128h, b: __m128h) -> __m
 #[target_feature(enable = "avx512fp16")]
 #[cfg_attr(test, assert_instr(vaddsh))]
 #[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
-pub fn _mm_maskz_add_sh(k: __mmask8, a: __m128h, b: __m128h) -> __m128h {
-    _mm_maskz_add_round_sh::<_MM_FROUND_CUR_DIRECTION>(k, a, b)
+#[rustc_const_unstable(feature = "stdarch_const_intrinsics", issue = "none")]
+pub const fn _mm_maskz_add_sh(k: __mmask8, a: __m128h, b: __m128h) -> __m128h {
+    unsafe {
+        let mut add: f16 = 0.;
+        if (k & 0b00000001) != 0 {
+            let extracta: f16 = simd_extract!(a, 0);
+            let extractb: f16 = simd_extract!(b, 0);
+            add = extracta + extractb;
+        }
+        simd_insert!(a, 0, add)
+    }
 }

 /// Subtract packed half-precision (16-bit) floating-point elements in b from a, and store the results in dst.
@@ -2004,8 +2024,9 @@ pub fn _mm_maskz_sub_round_sh<const ROUNDING: i32>(k: __mmask8, a: __m128h, b: _
 #[target_feature(enable = "avx512fp16")]
 #[cfg_attr(test, assert_instr(vsubsh))]
 #[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
-pub fn _mm_sub_sh(a: __m128h, b: __m128h) -> __m128h {
-    _mm_sub_round_sh::<_MM_FROUND_CUR_DIRECTION>(a, b)
+#[rustc_const_unstable(feature = "stdarch_const_intrinsics", issue = "none")]
+pub const fn _mm_sub_sh(a: __m128h, b: __m128h) -> __m128h {
+    unsafe { simd_insert!(a, 0, _mm_cvtsh_h(a) - _mm_cvtsh_h(b)) }
 }

 /// Subtract the lower half-precision (16-bit) floating-point elements in b from a, store the result in the
@@ -2017,8 +2038,18 @@ pub fn _mm_sub_sh(a: __m128h, b: __m128h) -> __m128h {
 #[target_feature(enable = "avx512fp16")]
 #[cfg_attr(test, assert_instr(vsubsh))]
 #[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
-pub fn _mm_mask_sub_sh(src: __m128h, k: __mmask8, a: __m128h, b: __m128h) -> __m128h {
-    _mm_mask_sub_round_sh::<_MM_FROUND_CUR_DIRECTION>(src, k, a, b)
+#[rustc_const_unstable(feature = "stdarch_const_intrinsics", issue = "none")]
+pub const fn _mm_mask_sub_sh(src: __m128h, k: __mmask8, a: __m128h, b: __m128h) -> __m128h {
+    unsafe {
+        let extractsrc: f16 = simd_extract!(src, 0);
+        let mut add: f16 = extractsrc;
+        if (k & 0b00000001) != 0 {
+            let extracta: f16 = simd_extract!(a, 0);
+            let extractb: f16 = simd_extract!(b, 0);
+            add = extracta - extractb;
+        }
+        simd_insert!(a, 0, add)
+    }
 }

 /// Subtract the lower half-precision (16-bit) floating-point elements in b from a, store the result in the
@@ -2030,8 +2061,17 @@ pub fn _mm_mask_sub_sh(src: __m128h, k: __mmask8, a: __m128h, b: __m128h) -> __m
 #[target_feature(enable = "avx512fp16")]
 #[cfg_attr(test, assert_instr(vsubsh))]
 #[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
-pub fn _mm_maskz_sub_sh(k: __mmask8, a: __m128h, b: __m128h) -> __m128h {
-    _mm_maskz_sub_round_sh::<_MM_FROUND_CUR_DIRECTION>(k, a, b)
+#[rustc_const_unstable(feature = "stdarch_const_intrinsics", issue = "none")]
+pub const fn _mm_maskz_sub_sh(k: __mmask8, a: __m128h, b: __m128h) -> __m128h {
+    unsafe {
+        let mut add: f16 = 0.;
+        if (k & 0b00000001) != 0 {
+            let extracta: f16 = simd_extract!(a, 0);
+            let extractb: f16 = simd_extract!(b, 0);
+            add = extracta - extractb;
+        }
+        simd_insert!(a, 0, add)
+    }
 }

 /// Multiply packed half-precision (16-bit) floating-point elements in a and b, and store the results in dst.
@@ -2325,8 +2365,9 @@ pub fn _mm_maskz_mul_round_sh<const ROUNDING: i32>(k: __mmask8, a: __m128h, b: _
 #[target_feature(enable = "avx512fp16")]
 #[cfg_attr(test, assert_instr(vmulsh))]
 #[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
-pub fn _mm_mul_sh(a: __m128h, b: __m128h) -> __m128h {
-    _mm_mul_round_sh::<_MM_FROUND_CUR_DIRECTION>(a, b)
+#[rustc_const_unstable(feature = "stdarch_const_intrinsics", issue = "none")]
+pub const fn _mm_mul_sh(a: __m128h, b: __m128h) -> __m128h {
+    unsafe { simd_insert!(a, 0, _mm_cvtsh_h(a) * _mm_cvtsh_h(b)) }
 }

 /// Multiply the lower half-precision (16-bit) floating-point elements in a and b, store the result in the
@@ -2338,8 +2379,18 @@ pub fn _mm_mul_sh(a: __m128h, b: __m128h) -> __m128h {
 #[target_feature(enable = "avx512fp16")]
 #[cfg_attr(test, assert_instr(vmulsh))]
 #[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
-pub fn _mm_mask_mul_sh(src: __m128h, k: __mmask8, a: __m128h, b: __m128h) -> __m128h {
-    _mm_mask_mul_round_sh::<_MM_FROUND_CUR_DIRECTION>(src, k, a, b)
+#[rustc_const_unstable(feature = "stdarch_const_intrinsics", issue = "none")]
+pub const fn _mm_mask_mul_sh(src: __m128h, k: __mmask8, a: __m128h, b: __m128h) -> __m128h {
+    unsafe {
+        let extractsrc: f16 = simd_extract!(src, 0);
+        let mut add: f16 = extractsrc;
+        if (k & 0b00000001) != 0 {
+            let extracta: f16 = simd_extract!(a, 0);
+            let extractb: f16 = simd_extract!(b, 0);
+            add = extracta * extractb;
+        }
+        simd_insert!(a, 0, add)
+    }
 }

 /// Multiply the lower half-precision (16-bit) floating-point elements in a and b, store the result in the
@@ -2351,8 +2402,17 @@ pub fn _mm_mask_mul_sh(src: __m128h, k: __mmask8, a: __m128h, b: __m128h) -> __m
 #[target_feature(enable = "avx512fp16")]
 #[cfg_attr(test, assert_instr(vmulsh))]
 #[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
-pub fn _mm_maskz_mul_sh(k: __mmask8, a: __m128h, b: __m128h) -> __m128h {
-    _mm_maskz_mul_round_sh::<_MM_FROUND_CUR_DIRECTION>(k, a, b)
+#[rustc_const_unstable(feature = "stdarch_const_intrinsics", issue = "none")]
+pub const fn _mm_maskz_mul_sh(k: __mmask8, a: __m128h, b: __m128h) -> __m128h {
+    unsafe {
+        let mut add: f16 = 0.;
+        if (k & 0b00000001) != 0 {
+            let extracta: f16 = simd_extract!(a, 0);
+            let extractb: f16 = simd_extract!(b, 0);
+            add = extracta * extractb;
+        }
+        simd_insert!(a, 0, add)
+    }
 }

 /// Divide packed half-precision (16-bit) floating-point elements in a by b, and store the results in dst.
@@ -2646,8 +2706,9 @@ pub fn _mm_maskz_div_round_sh<const ROUNDING: i32>(k: __mmask8, a: __m128h, b: _
 #[target_feature(enable = "avx512fp16")]
 #[cfg_attr(test, assert_instr(vdivsh))]
 #[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
-pub fn _mm_div_sh(a: __m128h, b: __m128h) -> __m128h {
-    _mm_div_round_sh::<_MM_FROUND_CUR_DIRECTION>(a, b)
+#[rustc_const_unstable(feature = "stdarch_const_intrinsics", issue = "none")]
+pub const fn _mm_div_sh(a: __m128h, b: __m128h) -> __m128h {
+    unsafe { simd_insert!(a, 0, _mm_cvtsh_h(a) / _mm_cvtsh_h(b)) }
 }

 /// Divide the lower half-precision (16-bit) floating-point elements in a by b, store the result in the
@@ -2659,8 +2720,18 @@ pub fn _mm_div_sh(a: __m128h, b: __m128h) -> __m128h {
 #[target_feature(enable = "avx512fp16")]
 #[cfg_attr(test, assert_instr(vdivsh))]
 #[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
-pub fn _mm_mask_div_sh(src: __m128h, k: __mmask8, a: __m128h, b: __m128h) -> __m128h {
-    _mm_mask_div_round_sh::<_MM_FROUND_CUR_DIRECTION>(src, k, a, b)
+#[rustc_const_unstable(feature = "stdarch_const_intrinsics", issue = "none")]
+pub const fn _mm_mask_div_sh(src: __m128h, k: __mmask8, a: __m128h, b: __m128h) -> __m128h {
+    unsafe {
+        let extractsrc: f16 = simd_extract!(src, 0);
+        let mut add: f16 = extractsrc;
+        if (k & 0b00000001) != 0 {
+            let extracta: f16 = simd_extract!(a, 0);
+            let extractb: f16 = simd_extract!(b, 0);
+            add = extracta / extractb;
+        }
+        simd_insert!(a, 0, add)
+    }
 }

 /// Divide the lower half-precision (16-bit) floating-point elements in a by b, store the result in the
@@ -2672,8 +2743,17 @@ pub fn _mm_mask_div_sh(src: __m128h, k: __mmask8, a: __m128h, b: __m128h) -> __m
 #[target_feature(enable = "avx512fp16")]
 #[cfg_attr(test, assert_instr(vdivsh))]
 #[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
-pub fn _mm_maskz_div_sh(k: __mmask8, a: __m128h, b: __m128h) -> __m128h {
-    _mm_maskz_div_round_sh::<_MM_FROUND_CUR_DIRECTION>(k, a, b)
+#[rustc_const_unstable(feature = "stdarch_const_intrinsics", issue = "none")]
+pub const fn _mm_maskz_div_sh(k: __mmask8, a: __m128h, b: __m128h) -> __m128h {
+    unsafe {
+        let mut add: f16 = 0.;
+        if (k & 0b00000001) != 0 {
+            let extracta: f16 = simd_extract!(a, 0);
+            let extractb: f16 = simd_extract!(b, 0);
+            add = extracta / extractb;
+        }
+        simd_insert!(a, 0, add)
+    }
 }

 /// Multiply packed complex numbers in a and b, and store the results in dst. Each complex number is
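Every hunk applies the same pattern: drop the call to the `_round` helper (which cannot run at compile time) and instead extract lane 0, apply the scalar operation when mask bit 0 is set, and insert the result back, so the body becomes const-evaluable. Below is a minimal plain-Rust sketch of that masked lane-0 pattern, not the stdarch code itself: `[f32; 8]` and `f32` stand in for `__m128h` and `f16`, ordinary indexing stands in for the internal `simd_extract!`/`simd_insert!` macros, and the function names are hypothetical. It assumes a toolchain recent enough to allow floating-point arithmetic in const fn.

// Sketch of the masked lane-0 add, with plain arrays standing in for __m128h.
type Vec8 = [f32; 8]; // hypothetical stand-in for an 8-lane half-precision vector

// Analogue of the new `_mm_mask_add_sh` body: lane 0 is a[0] + b[0] when mask
// bit 0 is set, otherwise src[0]; the remaining lanes are copied from `a`.
const fn mask_add_lane0(src: Vec8, k: u8, a: Vec8, b: Vec8) -> Vec8 {
    let mut add = src[0];
    if (k & 0b0000_0001) != 0 {
        add = a[0] + b[0];
    }
    let mut dst = a;
    dst[0] = add;
    dst
}

// Analogue of `_mm_maskz_add_sh`: lane 0 is zeroed when mask bit 0 is clear.
const fn maskz_add_lane0(k: u8, a: Vec8, b: Vec8) -> Vec8 {
    let mut add = 0.0;
    if (k & 0b0000_0001) != 0 {
        add = a[0] + b[0];
    }
    let mut dst = a;
    dst[0] = add;
    dst
}

// Because the body is built from plain extract/compute/insert steps rather than
// a rounding-mode helper, it can be evaluated in a const item:
const EXAMPLE: Vec8 = maskz_add_lane0(1, [1.0; 8], [2.0; 8]); // lane 0 == 3.0

The unmasked variants in the diff are the degenerate case of this pattern: they unconditionally compute `_mm_cvtsh_h(a) op _mm_cvtsh_h(b)` and insert it into lane 0 of `a`.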
