Skip to content

Commit 44f77ea

Browse files
committed
spv: minimal OpConstantFunctionPointerINTEL support.
1 parent eb22896 commit 44f77ea

File tree

8 files changed

+141
-18
lines changed

8 files changed

+141
-18
lines changed

src/lib.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,13 @@ pub struct ConstDef {
594594

595595
#[derive(Clone, PartialEq, Eq, Hash)]
596596
pub enum ConstKind {
597+
// FIXME(eddyb) maybe merge these? however, their connection is somewhat
598+
// tenuous (being one of the LLVM-isms SPIR-V inherited, among other things),
599+
// there's still the need to rename "global variable" post-`Var`-refactor,
600+
// and last but not least, `PtrToFunc` needs `SPV_INTEL_function_pointers`,
601+
// an OpenCL-only extension Intel came up with for their own SPIR-V tooling.
597602
PtrToGlobalVar(GlobalVar),
603+
PtrToFunc(Func),
598604

599605
// HACK(eddyb) this is a fallback case that should become increasingly rare
600606
// (especially wrt recursive consts), `Rc` means it can't bloat `ConstDef`.
@@ -683,7 +689,7 @@ pub struct FuncDecl {
683689
pub def: DeclDef<FuncDefBody>,
684690
}
685691

686-
#[derive(Copy, Clone)]
692+
#[derive(Copy, Clone, PartialEq, Eq)]
687693
pub struct FuncParam {
688694
pub attrs: AttrSet,
689695

src/print/mod.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3348,6 +3348,9 @@ impl Print for ConstDef {
33483348
&ConstKind::PtrToGlobalVar(gv) => {
33493349
pretty::Fragment::new(["&".into(), gv.print(printer)])
33503350
}
3351+
&ConstKind::PtrToFunc(func) => {
3352+
pretty::Fragment::new(["&".into(), func.print(printer)])
3353+
}
33513354
ConstKind::SpvInst { spv_inst_and_const_inputs } => {
33523355
let (spv_inst, const_inputs) = &**spv_inst_and_const_inputs;
33533356
pretty::Fragment::new([
@@ -4130,7 +4133,7 @@ impl Print for FuncAt<'_, DataInst> {
41304133
}
41314134
}
41324135
}
4133-
ConstKind::PtrToGlobalVar(_) => {}
4136+
ConstKind::PtrToGlobalVar(_) | ConstKind::PtrToFunc(_) => {}
41344137
}
41354138
}
41364139
None

src/qptr/lower.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,14 @@ impl<'a> LowerFromSpvPtrs<'a> {
162162
[spv::Imm::Short(_, sc)] => sc,
163163
_ => unreachable!(),
164164
};
165+
166+
// HACK(eddyb) keep function pointers separate, perhaps eventually
167+
// adding an `OpTypeUntypedPointerKHR CodeSectionINTEL` equivalent
168+
// to SPIR-T itself (after `SPV_KHR_untyped_pointers` support).
169+
if sc == self.wk.CodeSectionINTEL {
170+
return None;
171+
}
172+
165173
let pointee = match type_and_const_inputs[..] {
166174
[TypeOrConst::Type(elem_type)] => elem_type,
167175
_ => unreachable!(),

src/spv/lift.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ impl Visitor<'_> for NeedsIdsCollector<'_> {
145145
}
146146
let ct_def = &self.cx[ct];
147147
match ct_def.kind {
148-
ConstKind::PtrToGlobalVar(_) | ConstKind::SpvInst { .. } => {
148+
ConstKind::PtrToGlobalVar(_) | ConstKind::PtrToFunc(_) | ConstKind::SpvInst { .. } => {
149149
self.visit_const_def(ct_def);
150150
self.globals.insert(global);
151151
}
@@ -1051,7 +1051,9 @@ impl LazyInst<'_, '_> {
10511051
};
10521052
(gv_decl.attrs, import)
10531053
}
1054-
ConstKind::SpvInst { .. } => (ct_def.attrs, None),
1054+
ConstKind::PtrToFunc(_) | ConstKind::SpvInst { .. } => {
1055+
(ct_def.attrs, None)
1056+
}
10551057

10561058
// Not inserted into `globals` while visiting.
10571059
ConstKind::SpvStringLiteralForExtInst(_) => unreachable!(),
@@ -1172,6 +1174,13 @@ impl LazyInst<'_, '_> {
11721174
}
11731175
}
11741176

1177+
&ConstKind::PtrToFunc(func) => spv::InstWithIds {
1178+
without_ids: wk.OpConstantFunctionPointerINTEL.into(),
1179+
result_type_id: Some(ids.globals[&Global::Type(ct_def.ty)]),
1180+
result_id,
1181+
ids: [ids.funcs[&func].func_id].into_iter().collect(),
1182+
},
1183+
11751184
ConstKind::SpvInst { spv_inst_and_const_inputs } => {
11761185
let (spv_inst, const_inputs) = &**spv_inst_and_const_inputs;
11771186
spv::InstWithIds {

src/spv/lower.rs

Lines changed: 103 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@ enum IdDef {
2525

2626
Func(Func),
2727

28+
// HACK(eddyb) despite `FuncBody` deferring ID resolution to allow forward
29+
// references *between* functions, function pointer *constants* need a `Func`
30+
// long before any `OpFunction`s, so they're pre-defined as dummy imports.
31+
FuncForwardRef(Func),
32+
2833
SpvExtInstImport(InternedStr),
2934
SpvDebugString(InternedStr),
3035
}
@@ -37,7 +42,7 @@ impl IdDef {
3742
IdDef::Type(_) => "a type".into(),
3843
IdDef::Const(_) => "a constant".into(),
3944

40-
IdDef::Func(_) => "a function".into(),
45+
IdDef::Func(_) | IdDef::FuncForwardRef(_) => "a function".into(),
4146

4247
IdDef::SpvExtInstImport(name) => {
4348
format!("`OpExtInstImport {:?}`", &cx[name])
@@ -114,6 +119,37 @@ impl Module {
114119
// HACK(eddyb) used to quickly check whether an `OpVariable` is global.
115120
let storage_class_function_imm = spv::Imm::Short(wk.StorageClass, wk.Function);
116121

122+
// HACK(eddyb) used as the `FuncDecl` for an `IdDef::FuncForwardRef`.
123+
let dummy_decl_for_func_forward_ref = FuncDecl {
124+
attrs: {
125+
let mut attrs = AttrSet::default();
126+
attrs.push_diag(
127+
&cx,
128+
Diag::err(["function ID used as forward reference but never defined".into()]),
129+
);
130+
attrs
131+
},
132+
// FIXME(eddyb) this gets simpler w/ disaggregation.
133+
ret_type: cx.intern(TypeKind::SpvInst {
134+
spv_inst: wk.OpTypeVoid.into(),
135+
type_and_const_inputs: [].into_iter().collect(),
136+
}),
137+
params: [].into_iter().collect(),
138+
def: DeclDef::Imported(Import::LinkName(cx.intern(""))),
139+
};
140+
// HACK(eddyb) no `PartialEq` on `FuncDecl`.
141+
let assert_is_dummy_decl_for_func_forward_ref = |decl: &FuncDecl| {
142+
let [expected, found] = [&dummy_decl_for_func_forward_ref, decl].map(
143+
|FuncDecl { attrs, ret_type, params, def }| {
144+
let DeclDef::Imported(import) = def else {
145+
unreachable!();
146+
};
147+
(attrs, ret_type, params, import)
148+
},
149+
);
150+
assert!(expected == found);
151+
};
152+
117153
let mut module = {
118154
let [magic, version, generator_magic, id_bound, reserved_inst_schema] = parser.header;
119155

@@ -583,6 +619,38 @@ impl Module {
583619
});
584620
id_defs.insert(id, IdDef::Type(ty));
585621

622+
Seq::TypeConstOrGlobalVar
623+
} else if opcode == wk.OpConstantFunctionPointerINTEL {
624+
use std::collections::hash_map::Entry;
625+
626+
let id = inst.result_id.unwrap();
627+
628+
let func_id = inst.ids[0];
629+
let func = match id_defs.entry(func_id) {
630+
Entry::Occupied(entry) => match entry.get() {
631+
&IdDef::FuncForwardRef(func) => Ok(func),
632+
id_def => Err(id_def.descr(&cx)),
633+
},
634+
Entry::Vacant(entry) => {
635+
let func =
636+
module.funcs.define(&cx, dummy_decl_for_func_forward_ref.clone());
637+
entry.insert(IdDef::FuncForwardRef(func));
638+
Ok(func)
639+
}
640+
}
641+
.map_err(|descr| {
642+
invalid(&format!(
643+
"unsupported use of {descr} as the `OpConstantFunctionPointerINTEL` operand"
644+
))
645+
})?;
646+
647+
let ct = cx.intern(ConstDef {
648+
attrs: mem::take(&mut attrs),
649+
ty: result_type.unwrap(),
650+
kind: ConstKind::PtrToFunc(func),
651+
});
652+
id_defs.insert(id, IdDef::Const(ct));
653+
586654
Seq::TypeConstOrGlobalVar
587655
} else if inst_category == spec::InstructionCategory::Const || opcode == wk.OpUndef {
588656
let id = inst.result_id.unwrap();
@@ -755,19 +823,40 @@ impl Module {
755823
})
756824
}
757825
};
826+
let decl = FuncDecl {
827+
attrs: mem::take(&mut attrs),
828+
ret_type: func_ret_type,
829+
params: func_type_param_types
830+
.map(|ty| FuncParam { attrs: AttrSet::default(), ty })
831+
.collect(),
832+
def,
833+
};
758834

759-
let func = module.funcs.define(
760-
&cx,
761-
FuncDecl {
762-
attrs: mem::take(&mut attrs),
763-
ret_type: func_ret_type,
764-
params: func_type_param_types
765-
.map(|ty| FuncParam { attrs: AttrSet::default(), ty })
766-
.collect(),
767-
def,
768-
},
769-
);
770-
id_defs.insert(func_id, IdDef::Func(func));
835+
let func = {
836+
use std::collections::hash_map::Entry;
837+
838+
match id_defs.entry(func_id) {
839+
Entry::Occupied(mut entry) => match entry.get() {
840+
&IdDef::FuncForwardRef(func) => {
841+
let decl_slot = &mut module.funcs[func];
842+
assert_is_dummy_decl_for_func_forward_ref(decl_slot);
843+
*decl_slot = decl;
844+
845+
entry.insert(IdDef::Func(func));
846+
Ok(func)
847+
}
848+
id_def => Err(id_def.descr(&cx)),
849+
},
850+
Entry::Vacant(entry) => {
851+
let func = module.funcs.define(&cx, decl);
852+
entry.insert(IdDef::Func(func));
853+
Ok(func)
854+
}
855+
}
856+
.map_err(|descr| {
857+
invalid(&format!("invalid redefinition of {descr} as a new function"))
858+
})?
859+
};
771860

772861
current_func_body = Some(FuncBody { func_id, func, insts: vec![] });
773862

@@ -1171,7 +1260,7 @@ impl Module {
11711260
"unsupported use of {} outside `OpExtInst`",
11721261
id_def.descr(&cx),
11731262
))),
1174-
None => local_id_defs
1263+
None | Some(IdDef::FuncForwardRef(_)) => local_id_defs
11751264
.get(&id)
11761265
.copied()
11771266
.ok_or_else(|| invalid(&format!("undefined ID %{id}",))),

src/spv/spec.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def_well_known! {
137137
OpConstantTrue,
138138
OpConstant,
139139
OpUndef,
140+
OpConstantFunctionPointerINTEL,
140141

141142
OpVariable,
142143

@@ -201,6 +202,8 @@ def_well_known! {
201202
HitAttributeKHR,
202203
RayPayloadKHR,
203204
CallableDataKHR,
205+
206+
CodeSectionINTEL,
204207
],
205208
decoration: u32 = [
206209
LinkageAttributes,

src/transform.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,10 @@ impl InnerTransform for ConstDef {
473473
gv -> transformer.transform_global_var_use(*gv),
474474
} => ConstKind::PtrToGlobalVar(gv)),
475475

476+
ConstKind::PtrToFunc(func) => transform!({
477+
func -> transformer.transform_func_use(*func),
478+
} => ConstKind::PtrToFunc(func)),
479+
476480
ConstKind::SpvInst { spv_inst_and_const_inputs } => {
477481
let (spv_inst, const_inputs) = &**spv_inst_and_const_inputs;
478482
Transformed::map_iter(

src/visit.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,7 @@ impl InnerVisit for ConstDef {
344344
visitor.visit_type_use(*ty);
345345
match kind {
346346
&ConstKind::PtrToGlobalVar(gv) => visitor.visit_global_var_use(gv),
347+
&ConstKind::PtrToFunc(func) => visitor.visit_func_use(func),
347348
ConstKind::SpvInst { spv_inst_and_const_inputs } => {
348349
let (_spv_inst, const_inputs) = &**spv_inst_and_const_inputs;
349350
for &ct in const_inputs {

0 commit comments

Comments
 (0)