Skip to content

Commit 257377c

Browse files
committed
Allow parsing of #[spirv(...)] attribute with Iterators
1 parent c8b87f0 commit 257377c

File tree

5 files changed

+24
-25
lines changed

5 files changed

+24
-25
lines changed

crates/rustc_codegen_spirv/src/abi.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -546,11 +546,8 @@ impl<'tcx> ConvSpirvType<'tcx> for TyAndLayout<'tcx> {
546546
span = cx.tcx.def_span(adt.did());
547547
}
548548

549-
let attrs = AggregatedSpirvAttributes::parse(
550-
cx.tcx,
551-
&cx.sym,
552-
cx.tcx.get_attrs_unchecked(adt.did()),
553-
);
549+
let attrs =
550+
AggregatedSpirvAttributes::parse(cx.tcx, &cx.sym, cx.tcx.get_all_attrs(adt.did()));
554551

555552
if let Some(intrinsic_type_attr) = attrs.intrinsic_type.map(|attr| attr.value)
556553
&& let Ok(spirv_type) =

crates/rustc_codegen_spirv/src/attr.rs

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
//!
33
//! The attribute-checking parts of this try to follow `rustc_passes::check_attr`.
44
5-
use crate::symbols::Symbols;
5+
use crate::symbols::{Symbols, parse_attrs_for_checking};
66
use rspirv::spirv::{BuiltIn, ExecutionMode, ExecutionModel, StorageClass};
77
use rustc_hir as hir;
88
use rustc_hir::def_id::LocalModDefId;
@@ -146,12 +146,16 @@ impl AggregatedSpirvAttributes {
146146
///
147147
/// Any errors for malformed/duplicate attributes will have been reported
148148
/// prior to codegen, by the `attr` check pass.
149-
pub fn parse<'tcx>(tcx: TyCtxt<'tcx>, sym: &Symbols, attrs: &'tcx [Attribute]) -> Self {
149+
pub fn parse<'tcx>(
150+
tcx: TyCtxt<'tcx>,
151+
sym: &Symbols,
152+
attrs: impl Iterator<Item = &'tcx Attribute>,
153+
) -> Self {
150154
let mut aggregated_attrs = Self::default();
151155

152156
// NOTE(eddyb) `span_delayed_bug` ensures that if attribute checking fails
153157
// to see an attribute error, it will cause an ICE instead.
154-
for parse_attr_result in crate::symbols::parse_attrs_for_checking(sym, attrs) {
158+
for parse_attr_result in parse_attrs_for_checking(sym, attrs) {
155159
let (span, parsed_attr) = match parse_attr_result {
156160
Ok(span_and_parsed_attr) => span_and_parsed_attr,
157161
Err((span, msg)) => {
@@ -276,10 +280,8 @@ impl CheckSpirvAttrVisitor<'_> {
276280
fn check_spirv_attributes(&self, hir_id: HirId, target: Target) {
277281
let mut aggregated_attrs = AggregatedSpirvAttributes::default();
278282

279-
let parse_attrs = |attrs| crate::symbols::parse_attrs_for_checking(&self.sym, attrs);
280-
281283
let attrs = self.tcx.hir_attrs(hir_id);
282-
for parse_attr_result in parse_attrs(attrs) {
284+
for parse_attr_result in parse_attrs_for_checking(&self.sym, attrs.into_iter()) {
283285
let (span, parsed_attr) = match parse_attr_result {
284286
Ok(span_and_parsed_attr) => span_and_parsed_attr,
285287
Err((span, msg)) => {
@@ -324,9 +326,12 @@ impl CheckSpirvAttrVisitor<'_> {
324326
| SpirvAttribute::SpecConstant(_) => match target {
325327
Target::Param => {
326328
let parent_hir_id = self.tcx.parent_hir_id(hir_id);
327-
let parent_is_entry_point = parse_attrs(self.tcx.hir_attrs(parent_hir_id))
328-
.filter_map(|r| r.ok())
329-
.any(|(_, attr)| matches!(attr, SpirvAttribute::Entry(_)));
329+
let parent_is_entry_point = parse_attrs_for_checking(
330+
&self.sym,
331+
self.tcx.hir_attrs(parent_hir_id).iter(),
332+
)
333+
.filter_map(|r| r.ok())
334+
.any(|(_, attr)| matches!(attr, SpirvAttribute::Entry(_)));
330335
if !parent_is_entry_point {
331336
self.tcx.dcx().span_err(
332337
span,

crates/rustc_codegen_spirv/src/codegen_cx/declare.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,11 +133,8 @@ impl<'tcx> CodegenCx<'tcx> {
133133
self.set_linkage(fn_id, symbol_name.to_owned(), linkage);
134134
}
135135

136-
let attrs = AggregatedSpirvAttributes::parse(
137-
self.tcx,
138-
&self.sym,
139-
self.tcx.get_attrs_unchecked(def_id),
140-
);
136+
let attrs =
137+
AggregatedSpirvAttributes::parse(self.tcx, &self.sym, self.tcx.get_all_attrs(def_id));
141138
if let Some(entry) = attrs.entry.map(|attr| attr.value) {
142139
// HACK(eddyb) early insert to let `shader_entry_stub` call this
143140
// very function via `get_fn_addr`.

crates/rustc_codegen_spirv/src/codegen_cx/entry.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ impl<'tcx> CodegenCx<'tcx> {
436436
let attrs = AggregatedSpirvAttributes::parse(
437437
self.tcx,
438438
&self.sym,
439-
self.tcx.hir_attrs(hir_param.hir_id),
439+
self.tcx.hir_attrs(hir_param.hir_id).iter(),
440440
);
441441

442442
let EntryParamDeducedFromRustRefOrValue {

crates/rustc_codegen_spirv/src/symbols.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -441,11 +441,11 @@ impl Symbols {
441441
type ParseAttrError = (Span, String);
442442

443443
// FIXME(eddyb) maybe move this to `attr`?
444-
pub(crate) fn parse_attrs_for_checking<'a>(
445-
sym: &'a Symbols,
446-
attrs: &'a [Attribute],
447-
) -> impl Iterator<Item = Result<(Span, SpirvAttribute), ParseAttrError>> + 'a {
448-
attrs.iter().flat_map(move |attr| {
444+
pub(crate) fn parse_attrs_for_checking<'a, 'b>(
445+
sym: &'b Symbols,
446+
attrs: impl Iterator<Item = &'a Attribute> + 'b,
447+
) -> impl Iterator<Item = Result<(Span, SpirvAttribute), ParseAttrError>> + 'b {
448+
attrs.flat_map(move |attr| {
449449
let (whole_attr_error, args) = match attr {
450450
Attribute::Unparsed(item) => {
451451
// #[...]

0 commit comments

Comments
 (0)