diff --git a/source/slang/slang-ir-transform-params-to-constref.cpp b/source/slang/slang-ir-transform-params-to-constref.cpp index 95ce31327f6..b427504e3f6 100644 --- a/source/slang/slang-ir-transform-params-to-constref.cpp +++ b/source/slang/slang-ir-transform-params-to-constref.cpp @@ -43,55 +43,95 @@ struct TransformParamsToConstRefContext void rewriteValueUsesToAddrUses(IRInst* newAddrInst) { - HashSet workListSet; - workListSet.add(newAddrInst); - List workList; - workList.add(newAddrInst); - auto _addToWorkList = [&](IRInst* inst) - { - if (workListSet.add(inst)) - workList.add(inst); - }; + // The overall strategy here is as follows: + // + // - First, insert IRLoad() in front of all uses of newAddrInst. This + // is a value parameter turned into pointer to value. Add all IRLoad() + // instructions in the working set + // - Then, for every inserted IRLoad() instruction, search for + // IRFieldExtract(IRLoad(ptr), ...) and IRGetElement(IRLoad(ptr), ...) patterns, + // and transform these to IRLoad(IRFieldAddress(ptr, ...)) and + // IRLoad(IRGetElementPtr(ptr, ...)), and insert the new IRLoad() + // instructions in the working set + // - Remove also stores to write-once temporary variables that are + // immediately passed into a constref location in a call (see below) + // - If all uses of the inserted IRLoad() were translated, remove the + // IRLoad() to keep this pass clean + + List workList; + + traverseUses( + newAddrInst, + [&](IRUse* use) + { + auto user = use->getUser(); + builder.setInsertBefore(user); + IRLoad* loadInst = as(builder.emitLoad(newAddrInst)); + use->set(loadInst); + + workList.add(loadInst); + }); + for (Index i = 0; i < workList.getCount(); i++) { - auto inst = workList[i]; + IRLoad* loadInst = workList[i]; + bool allUsesTranslated = true; + traverseUses( - inst, + loadInst, [&](IRUse* use) { - auto user = use->getUser(); - if (workListSet.contains(user)) - return; - switch (user->getOp()) + IRInst* userInst = use->getUser(); + bool useTranslated = false; + + switch (userInst->getOp()) { case kIROp_FieldExtract: { - // Transform the IRFieldExtract into a IRFieldAddress - if (!isUseBaseAddrOperand(use, user)) - break; - auto fieldExtract = as(user); - builder.setInsertBefore(fieldExtract); - auto fieldAddr = builder.emitFieldAddress( - fieldExtract->getBase(), - fieldExtract->getField()); - fieldExtract->replaceUsesWith(fieldAddr); - _addToWorkList(fieldAddr); - return; + if (isUseBaseAddrOperand(use, userInst)) + { + // Transform IRFieldExtract(IRLoad(ptr), x) + // ==> + // IRLoad(IRFieldAddr(ptr), x) + + auto fieldExtract = as(userInst); + builder.setInsertBefore(fieldExtract); + auto fieldAddr = builder.emitFieldAddress( + loadInst->getPtr(), + fieldExtract->getField()); + auto loadFieldAddr = as(builder.emitLoad(fieldAddr)); + fieldExtract->replaceUsesWith(loadFieldAddr); + fieldExtract->removeAndDeallocate(); + + workList.add(loadFieldAddr); + useTranslated = true; + } + break; } + case kIROp_GetElement: { - // Transform the IRGetElement into a IRGetElementPtr - if (!isUseBaseAddrOperand(use, user)) - break; - auto getElement = as(user); - builder.setInsertBefore(getElement); - auto elemAddr = builder.emitElementAddress( - getElement->getBase(), - getElement->getIndex()); - getElement->replaceUsesWith(elemAddr); - _addToWorkList(elemAddr); - return; + if (isUseBaseAddrOperand(use, userInst)) + { + // Transform IRGetElement(IRLoad(ptr), x) + // ==> + // IRLoad(IRGetElementPtr(ptr), x) + + auto getElement = as(userInst); + builder.setInsertBefore(getElement); + auto getElementPtr = builder.emitElementAddress( + loadInst->getPtr(), + getElement->getIndex()); + auto loadElementPtr = as(builder.emitLoad(getElementPtr)); + getElement->replaceUsesWith(loadElementPtr); + getElement->removeAndDeallocate(); + + workList.add(loadElementPtr); + useTranslated = true; + } + break; } + case kIROp_Store: { // If the current value is being stored into a write-once temp var that @@ -99,23 +139,34 @@ struct TransformParamsToConstRefContext // rid of the temp var and replace it with `inst` directly. // (such temp var can be introduced during `updateCallSites` when we // were processing the callee.) - // - auto dest = as(user)->getPtr(); - if (dest->findDecorationImpl(kIROp_TempCallArgImmutableVarDecoration)) + + // Transform IRStore(storeDest, load(ptr)) where storeDest has attribute + // TempCallArgImmutableVarDecoration + // IRInst(storeDest) + // ==> + // IRInst(ptr) + + auto storeInst = as(userInst); + auto storeDest = storeInst->getPtr(); + + if (storeInst->getValUse() == use && + storeDest->findDecorationImpl( + kIROp_TempCallArgImmutableVarDecoration)) { - user->removeAndDeallocate(); - dest->replaceUsesWith(inst); - dest->removeAndDeallocate(); - return; + storeDest->replaceUsesWith(loadInst->getPtr()); + userInst->removeAndDeallocate(); + storeDest->removeAndDeallocate(); + useTranslated = true; } break; } } - // Insert a load before the user and replace the user with the load - builder.setInsertBefore(user); - auto loadInst = builder.emitLoad(inst); - use->set(loadInst); + + allUsesTranslated = allUsesTranslated && useTranslated; }); + + if (allUsesTranslated) + loadInst->removeAndDeallocate(); } } diff --git a/tests/bugs/gh-9073-spirv-pointer-double-dereference.slang b/tests/bugs/gh-9073-spirv-pointer-double-dereference.slang new file mode 100644 index 00000000000..6d03c043d52 --- /dev/null +++ b/tests/bugs/gh-9073-spirv-pointer-double-dereference.slang @@ -0,0 +1,67 @@ +//TEST:SIMPLE(filecheck=POSITIVE): -entry main -stage closesthit -target spirv + +// This is a minimal repro for issue 9073 + +struct Payload +{ + uint volumeStackIndex; + int volumeStack[2]; + + int GetVolume() + { + // This function exercises the following transformation in + // transformParamsToConstRef(): + // + // IRGetElement( + // IRFieldExtract(...), + // IRFieldExtract(...)) + // + // => + // + // IRGetElementPtr( + // IRFieldAddress(...), + // IRLoad(IRFieldAddress(...))) + // + // That is, that IRLoad() is correctly injected to + // IRGetElementPtr() index when IRFieldExtract() is + // transformed to IRFieldAddress() + + return volumeStack[volumeStackIndex]; + } + + int GetVolume2() + { + // This function exercises the following transformation in + // transformParamsToConstRef(): + // + // IRGetElement( + // IRFieldExtract(...), + // IRGetElement(...)) + // + // => + // + // IRGetElementPtr( + // IRFieldAddress(...), + // IRLoad(IRGetElementPtr(...))) + // + // That is, that IRLoad() is correctly injected to + // IRGetElementPtr() index when IRGetElement() is + // transformed to IRGetElementPtr() + + return volumeStack[volumeStack[volumeStackIndex]]; + } +} + +RWStructuredBuffer outputBuffer; + +[shader("closesthit")] +void main(inout Payload payload) +{ +// POSITIVE-NOT: {{(error|warning).*}}: +// POSITIVE: result code = 0 +// POSITIVE-NOT: {{(error|warning).*}}: + + outputBuffer[0] = payload.GetVolume(); + + outputBuffer[1] = payload.GetVolume2(); +}