Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 27 additions & 15 deletions compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ use std::ffi::CString;

use llvm::Linkage::*;
use rustc_abi::Align;
use rustc_codegen_ssa::common::TypeKind;
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
use rustc_middle::ty::offload_meta::OffloadMetadata;

use crate::builder::Builder;
use crate::common::CodegenCx;
use crate::llvm::AttributePlace::Function;
use crate::llvm::{self, Linkage, Type, Value};
use crate::llvm::{self, Linkage, Type, Value, get_value_name};
use crate::{SimpleCx, attributes};

// LLVM kernel-independent globals required for offloading
Expand Down Expand Up @@ -303,7 +304,6 @@ pub(crate) fn add_global<'ll>(
pub(crate) fn gen_define_handling<'ll>(
cx: &CodegenCx<'ll, '_>,
metadata: &[OffloadMetadata],
types: &[&'ll Type],
symbol: String,
offload_globals: &OffloadGlobals<'ll>,
) -> OffloadKernelGlobals<'ll> {
Expand All @@ -313,25 +313,18 @@ pub(crate) fn gen_define_handling<'ll>(

let offload_entry_ty = offload_globals.offload_entry_ty;

// It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
// reference) types.
let ptr_meta = types.iter().zip(metadata).filter_map(|(&x, meta)| match cx.type_kind(x) {
rustc_codegen_ssa::common::TypeKind::Pointer => Some(meta),
_ => None,
});

// FIXME(Sa4dUs): add `OMP_MAP_TARGET_PARAM = 0x20` only if necessary
let (ptr_sizes, ptr_transfer): (Vec<_>, Vec<_>) =
ptr_meta.map(|m| (m.payload_size, m.mode.bits() | 0x20)).unzip();
let (sizes, transfer): (Vec<_>, Vec<_>) =
metadata.iter().map(|m| (m.payload_size, m.mode.bits() | 0x20)).unzip();

let offload_sizes = add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{symbol}"), &ptr_sizes);
let offload_sizes = add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{symbol}"), &sizes);
// Here we figure out whether something needs to be copied to the gpu (=1), from the gpu (=2),
// or both to and from the gpu (=3). Other values shouldn't affect us for now.
// A non-mutable reference or pointer will be 1, an array that's not read, but fully overwritten
// will be 2. For now, everything is 3, until we have our frontend set up.
// 1+2+32: 1 (MapTo), 2 (MapFrom), 32 (Add one extra input ptr per function, to be used later).
let memtransfer_types =
add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{symbol}"), &ptr_transfer);
add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{symbol}"), &transfer);

// Next: For each function, generate these three entries. A weak constant,
// the llvm.rodata entry name, and the llvm_offload_entries value
Expand Down Expand Up @@ -477,8 +470,27 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
let mut geps = vec![];
let i32_0 = cx.get_const_i32(0);
for &v in args {
let gep = builder.inbounds_gep(cx.type_f32(), v, &[i32_0]);
vals.push(v);
let name = String::from_utf8(get_value_name(v)).unwrap();
let (base_val, gep_base) = match cx.type_kind(cx.val_ty(v)) {
TypeKind::Pointer => (v, v),
_ => {
let addr =
builder.direct_alloca(cx.val_ty(v), Align::EIGHT, &format!("{}.addr", name));
let casted =
builder.direct_alloca(cx.type_i64(), Align::EIGHT, &format!("{}.casted", name));
builder.store(v, addr, Align::EIGHT);

let loaded = builder.load(cx.val_ty(v), addr, Align::EIGHT);
builder.store(loaded, casted, Align::EIGHT);

let casted_val = builder.load(cx.type_i64(), casted, Align::EIGHT);
(casted_val, casted)
}
};

let gep = builder.inbounds_gep(cx.type_f32(), gep_base, &[i32_0]);

vals.push(base_val);
geps.push(gep);
}

Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_llvm/src/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1310,7 +1310,7 @@ fn codegen_offload<'ll, 'tcx>(
return;
}
};
let offload_data = gen_define_handling(&cx, &metadata, &types, target_symbol, offload_globals);
let offload_data = gen_define_handling(&cx, &metadata, target_symbol, offload_globals);
gen_call_handling(bx, &offload_data, &args, &types, &metadata, offload_globals);
}

Expand Down
17 changes: 7 additions & 10 deletions compiler/rustc_middle/src/ty/offload_meta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,13 @@ impl MappingFlags {
use rustc_ast::Mutability::*;

match ty.kind() {
ty::Bool
| ty::Char
| ty::Int(_)
| ty::Uint(_)
| ty::Float(_)
| ty::Adt(_, _)
| ty::Tuple(_)
| ty::Array(_, _)
| ty::Alias(_, _)
| ty::Param(_) => MappingFlags::TO,
ty::Bool | ty::Char | ty::Int(_) | ty::Uint(_) | ty::Float(_) => {
MappingFlags::LITERAL | MappingFlags::IMPLICIT
}

ty::Adt(_, _) | ty::Tuple(_) | ty::Array(_, _) | ty::Alias(_, _) | ty::Param(_) => {
MappingFlags::TO
}

ty::RawPtr(_, Not) | ty::Ref(_, _, Not) => MappingFlags::TO,

Expand Down
Loading