Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
339 changes: 336 additions & 3 deletions crates/rustc_codegen_spirv/src/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@
use crate::codegen_cx::CodegenCx;
use crate::symbols::Symbols;
use rspirv::spirv::{BuiltIn, ExecutionMode, ExecutionModel, StorageClass};
use rustc_ast::{LitKind, MetaItemInner, MetaItemLit};
use rustc_hir as hir;
use rustc_hir::def_id::LocalModDefId;
use rustc_hir::intravisit::{self, Visitor};
use rustc_hir::{Attribute, CRATE_HIR_ID, HirId, MethodKind, Target};
use rustc_middle::hir::nested_filter;
use rustc_middle::query::Providers;
use rustc_middle::ty::TyCtxt;
use rustc_span::{Span, Symbol};
use rustc_span::{Ident, Span, Symbol};
use smallvec::SmallVec;
use std::rc::Rc;

// FIXME(eddyb) replace with `ArrayVec<[Word; 3]>`.
Expand Down Expand Up @@ -152,7 +154,7 @@ impl AggregatedSpirvAttributes {

// NOTE(eddyb) `span_delayed_bug` ensures that if attribute checking fails
// to see an attribute error, it will cause an ICE instead.
for parse_attr_result in crate::symbols::parse_attrs_for_checking(&cx.sym, attrs) {
for parse_attr_result in parse_attrs_for_checking(&cx.sym, attrs) {
let (span, parsed_attr) = match parse_attr_result {
Ok(span_and_parsed_attr) => span_and_parsed_attr,
Err((span, msg)) => {
Expand Down Expand Up @@ -278,7 +280,7 @@ impl CheckSpirvAttrVisitor<'_> {
fn check_spirv_attributes(&self, hir_id: HirId, target: Target) {
let mut aggregated_attrs = AggregatedSpirvAttributes::default();

let parse_attrs = |attrs| crate::symbols::parse_attrs_for_checking(&self.sym, attrs);
let parse_attrs = |attrs| parse_attrs_for_checking(&self.sym, attrs);

let attrs = self.tcx.hir_attrs(hir_id);
for parse_attr_result in parse_attrs(attrs) {
Expand Down Expand Up @@ -512,3 +514,334 @@ pub(crate) fn provide(providers: &mut Providers) {
..*providers
};
}

// FIXME(eddyb) find something nicer for the error type.
type ParseAttrError = (Span, String);

#[allow(clippy::get_first)]
fn parse_attrs_for_checking<'a>(
sym: &'a Symbols,
attrs: &'a [Attribute],
) -> impl Iterator<Item = Result<(Span, SpirvAttribute), ParseAttrError>> + 'a {
attrs
.iter()
.map(move |attr| {
// parse the #[rust_gpu::spirv(...)] attr and return the inner list
match attr {
Attribute::Unparsed(item) => {
// #[...]
let s = &item.path.segments;
if let Some(rust_gpu) = s.get(0) && rust_gpu.name == sym.rust_gpu {
// #[rust_gpu ...]
match s.get(1) {
Some(command) if command.name == sym.spirv_attr_with_version => {
// #[rust_gpu::spirv ...]
if let Some(args) = attr.meta_item_list() {
// #[rust_gpu::spirv(...)]
Ok(parse_spirv_attr(sym, args.iter()))
} else {
// #[rust_gpu::spirv]
Err((
attr.span(),
"#[spirv(..)] attribute must have at least one argument"
.to_string(),
))
}
}
_ => {
// #[rust_gpu::...] but not a know version
let spirv = sym.spirv_attr_with_version.as_str();
Err((
attr.span(),
format!("unknown `rust_gpu` attribute, expected `rust_gpu::{spirv}`. \
Do the versions of `spirv-std` and `rustc_codegen_spirv` match?"),
))
}
}
} else {
// #[...] but not #[rust_gpu ...]
Ok(Default::default())
}
}
Attribute::Parsed(_) => Ok(Default::default()),
}
})
.flat_map(|result| {
result
.unwrap_or_else(|err| SmallVec::from_iter([Err(err)]))
.into_iter()
})
}

fn parse_spirv_attr<'a>(
sym: &Symbols,
iter: impl Iterator<Item = &'a MetaItemInner>,
) -> SmallVec<[Result<(Span, SpirvAttribute), ParseAttrError>; 4]> {
iter.map(|arg| {
let span = arg.span();
let parsed_attr =
if arg.has_name(sym.descriptor_set) {
SpirvAttribute::DescriptorSet(parse_attr_int_value(arg)?)
} else if arg.has_name(sym.binding) {
SpirvAttribute::Binding(parse_attr_int_value(arg)?)
} else if arg.has_name(sym.input_attachment_index) {
SpirvAttribute::InputAttachmentIndex(parse_attr_int_value(arg)?)
} else if arg.has_name(sym.spec_constant) {
SpirvAttribute::SpecConstant(parse_spec_constant_attr(sym, arg)?)
} else {
let name = match arg.ident() {
Some(i) => i,
None => {
return Err((
span,
"#[spirv(..)] attribute argument must be single identifier".to_string(),
));
}
};
sym.attributes.get(&name.name).map_or_else(
|| Err((name.span, "unknown argument to spirv attribute".to_string())),
|a| {
Ok(match a {
SpirvAttribute::Entry(entry) => SpirvAttribute::Entry(
parse_entry_attrs(sym, arg, &name, entry.execution_model)?,
),
_ => a.clone(),
})
},
)?
};
Ok((span, parsed_attr))
})
.collect()
}

fn parse_spec_constant_attr(
sym: &Symbols,
arg: &MetaItemInner,
) -> Result<SpecConstant, ParseAttrError> {
let mut id = None;
let mut default = None;

if let Some(attrs) = arg.meta_item_list() {
for attr in attrs {
if attr.has_name(sym.id) {
if id.is_none() {
id = Some(parse_attr_int_value(attr)?);
} else {
return Err((attr.span(), "`id` may only be specified once".into()));
}
} else if attr.has_name(sym.default) {
if default.is_none() {
default = Some(parse_attr_int_value(attr)?);
} else {
return Err((attr.span(), "`default` may only be specified once".into()));
}
} else {
return Err((attr.span(), "expected `id = ...` or `default = ...`".into()));
}
}
}
Ok(SpecConstant {
id: id.ok_or_else(|| (arg.span(), "expected `spec_constant(id = ...)`".into()))?,
default,
})
}

fn parse_attr_int_value(arg: &MetaItemInner) -> Result<u32, ParseAttrError> {
let arg = match arg.meta_item() {
Some(arg) => arg,
None => return Err((arg.span(), "attribute must have value".to_string())),
};
match arg.name_value_literal() {
Some(&MetaItemLit {
kind: LitKind::Int(x, ..),
..
}) if x <= u32::MAX as u128 => Ok(x.get() as u32),
_ => Err((arg.span, "attribute value must be integer".to_string())),
}
}

fn parse_local_size_attr(arg: &MetaItemInner) -> Result<[u32; 3], ParseAttrError> {
let arg = match arg.meta_item() {
Some(arg) => arg,
None => return Err((arg.span(), "attribute must have value".to_string())),
};
match arg.meta_item_list() {
Some(tuple) if !tuple.is_empty() && tuple.len() < 4 => {
let mut local_size = [1; 3];
for (idx, lit) in tuple.iter().enumerate() {
match lit {
MetaItemInner::Lit(MetaItemLit {
kind: LitKind::Int(x, ..),
..
}) if *x <= u32::MAX as u128 => local_size[idx] = x.get() as u32,
_ => return Err((lit.span(), "must be a u32 literal".to_string())),
}
}
Ok(local_size)
}
Some([]) => Err((
arg.span,
"#[spirv(compute(threads(x, y, z)))] must have the x dimension specified, trailing ones may be elided".to_string(),
)),
Some(tuple) if tuple.len() > 3 => Err((
arg.span,
"#[spirv(compute(threads(x, y, z)))] is three dimensional".to_string(),
)),
_ => Err((
arg.span,
"#[spirv(compute(threads(x, y, z)))] must have 1 to 3 parameters, trailing ones may be elided".to_string(),
)),
}
}

// for a given entry, gather up the additional attributes
// in this case ExecutionMode's, some have extra arguments
// others are specified with x, y, or z components
// ie #[spirv(fragment(origin_lower_left))] or #[spirv(gl_compute(local_size_x=64, local_size_y=8))]
fn parse_entry_attrs(
sym: &Symbols,
arg: &MetaItemInner,
name: &Ident,
execution_model: ExecutionModel,
) -> Result<Entry, ParseAttrError> {
use ExecutionMode::*;
use ExecutionModel::*;
let mut entry = Entry::from(execution_model);
let mut origin_mode: Option<ExecutionMode> = None;
let mut local_size: Option<[u32; 3]> = None;
let mut local_size_hint: Option<[u32; 3]> = None;
// Reserved
//let mut max_workgroup_size_intel: Option<[u32; 3]> = None;
if let Some(attrs) = arg.meta_item_list() {
for attr in attrs {
if let Some(attr_name) = attr.ident() {
if let Some((execution_mode, extra_dim)) = sym.execution_modes.get(&attr_name.name)
{
use crate::symbols::ExecutionModeExtraDim::*;
let val = match extra_dim {
None | Tuple => Option::None,
_ => Some(parse_attr_int_value(attr)?),
};
match execution_mode {
OriginUpperLeft | OriginLowerLeft => {
origin_mode.replace(*execution_mode);
}
LocalSize => {
if local_size.is_none() {
local_size.replace(parse_local_size_attr(attr)?);
} else {
return Err((
attr_name.span,
String::from(
"`#[spirv(compute(threads))]` may only be specified once",
),
));
}
}
LocalSizeHint => {
let val = val.unwrap();
if local_size_hint.is_none() {
local_size_hint.replace([1, 1, 1]);
}
let local_size_hint = local_size_hint.as_mut().unwrap();
match extra_dim {
X => {
local_size_hint[0] = val;
}
Y => {
local_size_hint[1] = val;
}
Z => {
local_size_hint[2] = val;
}
_ => unreachable!(),
}
}
// Reserved
/*MaxWorkgroupSizeINTEL => {
let val = val.unwrap();
if max_workgroup_size_intel.is_none() {
max_workgroup_size_intel.replace([1, 1, 1]);
}
let max_workgroup_size_intel = max_workgroup_size_intel.as_mut()
.unwrap();
match extra_dim {
X => {
max_workgroup_size_intel[0] = val;
},
Y => {
max_workgroup_size_intel[1] = val;
},
Z => {
max_workgroup_size_intel[2] = val;
},
_ => unreachable!(),
}
},*/
_ => {
if let Some(val) = val {
entry
.execution_modes
.push((*execution_mode, ExecutionModeExtra::new([val])));
} else {
entry
.execution_modes
.push((*execution_mode, ExecutionModeExtra::new([])));
}
}
}
} else if attr_name.name == sym.entry_point_name {
match attr.value_str() {
Some(sym) => {
entry.name = Some(sym);
}
None => {
return Err((
attr_name.span,
format!(
"#[spirv({name}(..))] unknown attribute argument {attr_name}"
),
));
}
}
} else {
return Err((
attr_name.span,
format!("#[spirv({name}(..))] unknown attribute argument {attr_name}",),
));
}
} else {
return Err((
arg.span(),
format!("#[spirv({name}(..))] attribute argument must be single identifier"),
));
}
}
}
match entry.execution_model {
Fragment => {
let origin_mode = origin_mode.unwrap_or(OriginUpperLeft);
entry
.execution_modes
.push((origin_mode, ExecutionModeExtra::new([])));
}
GLCompute | MeshNV | TaskNV | TaskEXT | MeshEXT => {
if let Some(local_size) = local_size {
entry
.execution_modes
.push((LocalSize, ExecutionModeExtra::new(local_size)));
} else {
return Err((
arg.span(),
String::from(
"The `threads` argument must be specified when using `#[spirv(compute)]`, `#[spirv(mesh_nv)]`, `#[spirv(task_nv)]`, `#[spirv(task_ext)]` or `#[spirv(mesh_ext)]`",
),
));
}
}
//TODO: Cover more defaults
_ => {}
}
Ok(entry)
}
2 changes: 2 additions & 0 deletions crates/rustc_codegen_spirv/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ mod custom_decorations;
mod custom_insts;
mod link;
mod linker;
#[path = "../../spirv_attr_version.rs"]
mod spirv_attr_version;
mod spirv_type;
mod spirv_type_constraints;
mod symbols;
Expand Down
Loading