Skip to content
147 changes: 99 additions & 48 deletions source/slang/slang-ir-transform-params-to-constref.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,79 +43,130 @@ struct TransformParamsToConstRefContext

/// Rewrite all uses of `newAddrInst` so that they are valid now that the
/// instruction yields an address (pointer to value) instead of a value.
///
/// Every existing use is first routed through a freshly emitted load. Loads
/// that only feed field/element extraction are then pushed "outwards" so the
/// extraction happens on the address and only the final leaf is loaded.
///
/// @param newAddrInst The instruction (a former value parameter) that now
///                    produces a pointer to the value.
void rewriteValueUsesToAddrUses(IRInst* newAddrInst)
{
    // The overall strategy here is as follows:
    //
    // - First, insert an IRLoad() in front of every use of newAddrInst (a
    //   value parameter turned into a pointer to value) and add all of these
    //   IRLoad() instructions to the work list.
    // - Then, for every IRLoad() in the work list, search for the patterns
    //   IRFieldExtract(IRLoad(ptr), ...) and IRGetElement(IRLoad(ptr), ...),
    //   transform them to IRLoad(IRFieldAddress(ptr, ...)) and
    //   IRLoad(IRGetElementPtr(ptr, ...)), and add the new IRLoad()
    //   instructions to the work list.
    // - Also remove stores to write-once temporary variables that are
    //   immediately passed into a constref location in a call (see below).
    // - If all uses of an inserted IRLoad() were translated, remove that
    //   IRLoad() to keep the output of this pass clean.

    List<IRLoad*> workList;

    traverseUses(
        newAddrInst,
        [&](IRUse* use)
        {
            auto user = use->getUser();
            builder.setInsertBefore(user);
            // NOTE(review): the as<IRLoad>() result is used unchecked below;
            // this assumes emitLoad() always yields an IRLoad here - confirm.
            IRLoad* loadInst = as<IRLoad>(builder.emitLoad(newAddrInst));
            use->set(loadInst);

            workList.add(loadInst);
        });

    // Index-based iteration is required: the loop body appends to workList.
    for (Index i = 0; i < workList.getCount(); i++)
    {
        IRLoad* loadInst = workList[i];
        bool allUsesTranslated = true;

        traverseUses(
            loadInst,
            [&](IRUse* use)
            {
                IRInst* userInst = use->getUser();
                bool useTranslated = false;

                switch (userInst->getOp())
                {
                case kIROp_FieldExtract:
                    {
                        if (isUseBaseAddrOperand(use, userInst))
                        {
                            // Transform IRFieldExtract(IRLoad(ptr), x)
                            // ==>
                            // IRLoad(IRFieldAddress(ptr, x))

                            auto fieldExtract = as<IRFieldExtract>(userInst);
                            builder.setInsertBefore(fieldExtract);
                            auto fieldAddr = builder.emitFieldAddress(
                                loadInst->getPtr(),
                                fieldExtract->getField());
                            auto loadFieldAddr = as<IRLoad>(builder.emitLoad(fieldAddr));
                            fieldExtract->replaceUsesWith(loadFieldAddr);
                            fieldExtract->removeAndDeallocate();

                            // The new load may itself feed further extractions.
                            workList.add(loadFieldAddr);
                            useTranslated = true;
                        }
                        break;
                    }

                case kIROp_GetElement:
                    {
                        if (isUseBaseAddrOperand(use, userInst))
                        {
                            // Transform IRGetElement(IRLoad(ptr), x)
                            // ==>
                            // IRLoad(IRGetElementPtr(ptr, x))

                            auto getElement = as<IRGetElement>(userInst);
                            builder.setInsertBefore(getElement);
                            auto getElementPtr = builder.emitElementAddress(
                                loadInst->getPtr(),
                                getElement->getIndex());
                            auto loadElementPtr = as<IRLoad>(builder.emitLoad(getElementPtr));
                            getElement->replaceUsesWith(loadElementPtr);
                            getElement->removeAndDeallocate();

                            // The new load may itself feed further extractions.
                            workList.add(loadElementPtr);
                            useTranslated = true;
                        }
                        break;
                    }

                case kIROp_Store:
                    {
                        // If the loaded value is being stored into a write-once temp var
                        // that is immediately passed into a constref location in a call,
                        // we can get rid of the temp var and use the address being loaded
                        // from directly.
                        // (such temp var can be introduced during `updateCallSites` when we
                        // were processing the callee.)
                        //
                        // Transform, for a storeDest decorated with
                        // TempCallArgImmutableVarDecoration:
                        //
                        //   IRStore(storeDest, IRLoad(ptr)); ...uses of storeDest...
                        //   ==>
                        //   ...uses of ptr...

                        auto storeInst = as<IRStore>(userInst);
                        auto storeDest = storeInst->getPtr();

                        // Only applies when this use is the stored value operand of
                        // the store, not some other operand.
                        if (storeInst->getValUse() == use &&
                            storeDest->findDecorationImpl(
                                kIROp_TempCallArgImmutableVarDecoration))
                        {
                            storeDest->replaceUsesWith(loadInst->getPtr());
                            userInst->removeAndDeallocate();
                            storeDest->removeAndDeallocate();
                            useTranslated = true;
                        }
                        break;
                    }
                }

                allUsesTranslated = allUsesTranslated && useTranslated;
            });

        // A load whose every use was rewritten above has no remaining purpose.
        if (allUsesTranslated)
            loadInst->removeAndDeallocate();
    }
}

Expand Down
67 changes: 67 additions & 0 deletions tests/bugs/gh-9073-spirv-pointer-double-dereference.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
//TEST:SIMPLE(filecheck=POSITIVE): -entry main -stage closesthit -target spirv

// This is a minimal repro for issue 9073

// Payload type used by the repro: an index member plus a two-element array.
// Both member functions index `volumeStack` with a value that is itself
// read from a member of the same struct.
struct Payload
{
uint volumeStackIndex;
int volumeStack[2];

// Returns the element of `volumeStack` selected by `volumeStackIndex`.
int GetVolume()
{
// This function exercises the following transformation in
// transformParamsToConstRef():
//
// IRGetElement(
// IRFieldExtract(...),
// IRFieldExtract(...))
//
// =>
//
// IRGetElementPtr(
// IRFieldAddress(...),
// IRLoad(IRFieldAddress(...)))
//
// That is, it checks that an IRLoad() is correctly inserted for the
// index operand of IRGetElementPtr() when the IRFieldExtract()
// producing that index is transformed to IRFieldAddress().

return volumeStack[volumeStackIndex];
}

// Returns the element of `volumeStack` selected by another element of
// `volumeStack`, which in turn is selected by `volumeStackIndex`.
int GetVolume2()
{
// This function exercises the following transformation in
// transformParamsToConstRef():
//
// IRGetElement(
// IRFieldExtract(...),
// IRGetElement(...))
//
// =>
//
// IRGetElementPtr(
// IRFieldAddress(...),
// IRLoad(IRGetElementPtr(...)))
//
// That is, it checks that an IRLoad() is correctly inserted for the
// index operand of IRGetElementPtr() when the IRGetElement()
// producing that index is transformed to IRGetElementPtr().

return volumeStack[volumeStack[volumeStackIndex]];
}
}

// Buffer that receives the values returned by the two Payload member
// functions called from main().
RWStructuredBuffer<int> outputBuffer;

// Entry point: calls both Payload accessors on the incoming payload and
// stores their results in outputBuffer[0] and outputBuffer[1].
[shader("closesthit")]
void main(inout Payload payload)
{
// The three check directives below use the prefix selected on the test
// line at the top of the file. They require a "result code = 0" line with
// no line matching an error/warning diagnostic before or after it.
// POSITIVE-NOT: {{(error|warning).*}}:
// POSITIVE: result code = 0
// POSITIVE-NOT: {{(error|warning).*}}:

outputBuffer[0] = payload.GetVolume();

outputBuffer[1] = payload.GetVolume2();
}