diff --git a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs index 046501d08c482..843c2ca7f81eb 100644 --- a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs +++ b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs @@ -2,13 +2,14 @@ use std::ffi::CString; use llvm::Linkage::*; use rustc_abi::Align; +use rustc_codegen_ssa::common::TypeKind; use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods}; use rustc_middle::ty::offload_meta::OffloadMetadata; use crate::builder::Builder; use crate::common::CodegenCx; use crate::llvm::AttributePlace::Function; -use crate::llvm::{self, Linkage, Type, Value}; +use crate::llvm::{self, Linkage, Type, Value, get_value_name}; use crate::{SimpleCx, attributes}; // LLVM kernel-independent globals required for offloading @@ -303,7 +304,6 @@ pub(crate) fn add_global<'ll>( pub(crate) fn gen_define_handling<'ll>( cx: &CodegenCx<'ll, '_>, metadata: &[OffloadMetadata], - types: &[&'ll Type], symbol: String, offload_globals: &OffloadGlobals<'ll>, ) -> OffloadKernelGlobals<'ll> { @@ -313,25 +313,18 @@ pub(crate) fn gen_define_handling<'ll>( let offload_entry_ty = offload_globals.offload_entry_ty; - // It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or - // reference) types. - let ptr_meta = types.iter().zip(metadata).filter_map(|(&x, meta)| match cx.type_kind(x) { - rustc_codegen_ssa::common::TypeKind::Pointer => Some(meta), - _ => None, - }); - // FIXME(Sa4dUs): add `OMP_MAP_TARGET_PARAM = 0x20` only if necessary - let (ptr_sizes, ptr_transfer): (Vec<_>, Vec<_>) = - ptr_meta.map(|m| (m.payload_size, m.mode.bits() | 0x20)).unzip(); + let (sizes, transfer): (Vec<_>, Vec<_>) = + metadata.iter().map(|m| (m.payload_size, m.mode.bits() | 0x20)).unzip(); - let offload_sizes = add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{symbol}"), &ptr_sizes); + let offload_sizes = add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{symbol}"), &sizes); // Here we figure out whether something needs to be copied to the gpu (=1), from the gpu (=2), // or both to and from the gpu (=3). Other values shouldn't affect us for now. // A non-mutable reference or pointer will be 1, an array that's not read, but fully overwritten // will be 2. For now, everything is 3, until we have our frontend set up. // 1+2+32: 1 (MapTo), 2 (MapFrom), 32 (Add one extra input ptr per function, to be used later). let memtransfer_types = - add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{symbol}"), &ptr_transfer); + add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{symbol}"), &transfer); // Next: For each function, generate these three entries. A weak constant, // the llvm.rodata entry name, and the llvm_offload_entries value @@ -477,8 +470,27 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>( let mut geps = vec![]; let i32_0 = cx.get_const_i32(0); for &v in args { - let gep = builder.inbounds_gep(cx.type_f32(), v, &[i32_0]); - vals.push(v); + let name = String::from_utf8(get_value_name(v)).unwrap(); + let (base_val, gep_base) = match cx.type_kind(cx.val_ty(v)) { + TypeKind::Pointer => (v, v), + _ => { + let addr = + builder.direct_alloca(cx.val_ty(v), Align::EIGHT, &format!("{}.addr", name)); + let casted = + builder.direct_alloca(cx.type_i64(), Align::EIGHT, &format!("{}.casted", name)); + builder.store(v, addr, Align::EIGHT); + + let loaded = builder.load(cx.val_ty(v), addr, Align::EIGHT); + builder.store(loaded, casted, Align::EIGHT); + + let casted_val = builder.load(cx.type_i64(), casted, Align::EIGHT); + (casted_val, casted) + } + }; + + let gep = builder.inbounds_gep(cx.type_f32(), gep_base, &[i32_0]); + + vals.push(base_val); geps.push(gep); } diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 3bc890310cc87..d44d3afc0829a 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -1310,7 +1310,7 @@ fn codegen_offload<'ll, 'tcx>( return; } }; - let offload_data = gen_define_handling(&cx, &metadata, &types, target_symbol, offload_globals); + let offload_data = gen_define_handling(&cx, &metadata, target_symbol, offload_globals); gen_call_handling(bx, &offload_data, &args, &types, &metadata, offload_globals); } diff --git a/compiler/rustc_middle/src/ty/offload_meta.rs b/compiler/rustc_middle/src/ty/offload_meta.rs index 04a7cd2c75f28..67c00765ed57b 100644 --- a/compiler/rustc_middle/src/ty/offload_meta.rs +++ b/compiler/rustc_middle/src/ty/offload_meta.rs @@ -78,16 +78,13 @@ impl MappingFlags { use rustc_ast::Mutability::*; match ty.kind() { - ty::Bool - | ty::Char - | ty::Int(_) - | ty::Uint(_) - | ty::Float(_) - | ty::Adt(_, _) - | ty::Tuple(_) - | ty::Array(_, _) - | ty::Alias(_, _) - | ty::Param(_) => MappingFlags::TO, + ty::Bool | ty::Char | ty::Int(_) | ty::Uint(_) | ty::Float(_) => { + MappingFlags::LITERAL | MappingFlags::IMPLICIT + } + + ty::Adt(_, _) | ty::Tuple(_) | ty::Array(_, _) | ty::Alias(_, _) | ty::Param(_) => { + MappingFlags::TO + } ty::RawPtr(_, Not) | ty::Ref(_, _, Not) => MappingFlags::TO,