Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ public void newScopedThread(Arch arch, String name, int id, int ...scopeIds) {
ScopeHierarchy scopeHierarchy = switch (arch) {
case PTX -> ScopeHierarchy.ScopeHierarchyForPTX(scopeIds[0], scopeIds[1]);
case VULKAN -> ScopeHierarchy.ScopeHierarchyForVulkan(scopeIds[0], scopeIds[1], scopeIds[2]);
case OPENCL -> ScopeHierarchy.ScopeHierarchyForOpenCL(scopeIds[0], scopeIds[1]);
case OPENCL -> ScopeHierarchy.ScopeHierarchyForOpenCL(scopeIds[0], scopeIds[1], scopeIds[2]);
default -> throw new UnsupportedOperationException("Unsupported architecture: " + arch);
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,10 @@ public Object visitThreadDeclarator(LitmusCParser.ThreadDeclaratorContext ctx) {
scope = currentThread = ctx.threadId().id;
threadIds.add(currentThread);
if (isOpenCL && ctx.threadScope() != null) {
int sgID = 0; // Use subgroup ID 0 as default for OpenCL Litmus
int wgID = ctx.threadScope().scopeID(0).id;
int devID = ctx.threadScope().scopeID(1).id;
programBuilder.newScopedThread(Arch.OPENCL, currentThread, devID, wgID);
programBuilder.newScopedThread(Arch.OPENCL, currentThread, devID, wgID, sgID);
} else {
programBuilder.newThread(currentThread);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public Program visitOp(SpirvParser.OpContext ctx) {
}

private ProgramBuilder createBuilder(SpirvParser.SpvContext ctx) {
ThreadGrid grid = new ThreadGrid(1, 1, 1, 1);
ThreadGrid grid = new ThreadGrid(1, 1, 1, 1, 1);
boolean hasConfig = false;
for (SpirvParser.SpvHeaderContext header : ctx.spvHeaders().spvHeader()) {
SpirvParser.ConfigHeaderContext cfgCtx = header.configHeader();
Expand All @@ -71,10 +71,10 @@ private ProgramBuilder createBuilder(SpirvParser.SpvContext ctx) {
}
hasConfig = true;
List<SpirvParser.LiteranHeaderUnsignedIntegerContext> literals = cfgCtx.literanHeaderUnsignedInteger();
int sg = Integer.parseInt(literals.get(0).getText());
int wg = Integer.parseInt(literals.get(1).getText());
int qf = Integer.parseInt(literals.get(2).getText());
grid = new ThreadGrid(sg, wg, qf, 1);
int threadCount = Integer.parseInt(literals.get(0).getText());
int subgroupCount = Integer.parseInt(literals.get(1).getText());
int workgroupCount = Integer.parseInt(literals.get(2).getText());
grid = new ThreadGrid(threadCount, subgroupCount, workgroupCount, 1, 1);
}
}
return new ProgramBuilder(grid);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ private Expression getDecorationExpressions(String id, Type type) {
case "LocalInvocationId" -> makeArray(id, type, tid % grid.wgSize(), 0, 0);
case "LocalInvocationIndex" -> makeScalar(id, type, tid % grid.wgSize()); // scalar of LocalInvocationId
case "GlobalInvocationId" -> makeArray(id, type, tid % grid.dvSize(), 0, 0);
case "DeviceIndex" -> makeScalar(id, type, 0);
case "DeviceIndex" -> makeScalar(id, type, grid.dvId(tid));
case "SubgroupId" -> makeScalar(id, type, grid.sgId(tid));
case "WorkgroupId" -> makeArray(id, type, grid.wgId(tid), 0, 0);
case "SubgroupSize" -> makeScalar(id, type, grid.sgSize());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,12 @@ public static ScopeHierarchy ScopeHierarchyForPTX(int gpu, int cta) {
return scopeHierarchy;
}

public static ScopeHierarchy ScopeHierarchyForOpenCL(int dev, int wg) {
public static ScopeHierarchy ScopeHierarchyForOpenCL(int dev, int wg, int sg) {
ScopeHierarchy scopeHierarchy = new ScopeHierarchy();
scopeHierarchy.scopeIds.put(Tag.OpenCL.ALL, 0);
scopeHierarchy.scopeIds.put(Tag.OpenCL.DEVICE, dev);
scopeHierarchy.scopeIds.put(Tag.OpenCL.WORK_GROUP, wg);
scopeHierarchy.scopeIds.put(Tag.OpenCL.SUB_GROUP, sg);
return scopeHierarchy;
}

Expand Down
40 changes: 21 additions & 19 deletions dartagnan/src/main/java/com/dat3m/dartagnan/program/ThreadGrid.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,36 +6,42 @@

public class ThreadGrid {

private final int sg;
private final int wg;
private final int qf;
private final int dv;
private final int thCount;
private final int sgCount;
private final int wgCount;
private final int qfCount;
private final int dvCount;

public ThreadGrid(int sg, int wg, int qf, int dv) {
List<Integer> elements = List.of(sg, wg, qf, dv);
public ThreadGrid(int thCount, int sgCount, int wgCount, int qfCount, int dvCount) {
List<Integer> elements = List.of(thCount, sgCount, wgCount, qfCount, dvCount);
if (elements.stream().anyMatch(i -> i <= 0)) {
throw new ParsingException("Thread grid dimensions must be positive");
}
this.sg = sg;
this.wg = wg;
this.qf = qf;
this.dv = dv;
this.thCount = thCount;
this.sgCount = sgCount;
this.wgCount = wgCount;
this.qfCount = qfCount;
this.dvCount = dvCount;
}

public int sgSize() {
return sg;
return thCount;
}

public int wgSize() {
return sg * wg;
return thCount * sgCount;
}

public int qfSize() {
return sg * wg * qf;
return thCount * sgCount * wgCount;
}

public int dvSize() {
return sg * wg * qf * dv;
return thCount * sgCount * wgCount * qfCount;
}

public int sysSize() { // Number of cross-device threads
return thCount * sgCount * wgCount * qfCount * dvCount;
}

public int thId(int tid) {
Expand All @@ -55,10 +61,6 @@ public int qfId(int tid) {
}

public int dvId(int tid) {
return tid / dvSize();
}

public ScopeHierarchy getScoreHierarchy(int tid) {
return ScopeHierarchy.ScopeHierarchyForVulkan(qfId(tid), wgId(tid), sgId(tid));
return (tid % sysSize()) / dvSize();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,7 @@ public static String storeMO(String mo) {
public static final class OpenCL {
// Scopes
public static final String WORK_ITEM = "WI";
public static final String SUB_GROUP = "SG";
public static final String WORK_GROUP = "WG";
public static final String DEVICE = "DV";
public static final String ALL = "ALL";
Expand All @@ -388,7 +389,7 @@ public static final class OpenCL {
public static final String DEFAULT_WEAK_SCOPE = WORK_ITEM;

public static List<String> getScopeTags() {
return List.of(WORK_GROUP, DEVICE, ALL);
return List.of(SUB_GROUP, WORK_GROUP, DEVICE, ALL);
}

public static List<String> getSpaceTags() {
Expand Down Expand Up @@ -529,13 +530,13 @@ public static String toOpenCLTag(String tag) {
case SEQ_CST -> C11.MO_SC;

// Scope
// TODO: OpenCL Kernel supports sub_group, but it's not mentioned in the model
// subgroup is supported in OpenCL Kernel, but it is not mentioned in the model
case INVOCATION -> OpenCL.WORK_ITEM;
case SUBGROUP,
WORKGROUP -> OpenCL.WORK_GROUP;
case WORKGROUP -> OpenCL.WORK_GROUP;
case DEVICE -> OpenCL.DEVICE;
case CROSS_DEVICE -> OpenCL.ALL;
case QUEUE_FAMILY,
case SUBGROUP,
QUEUE_FAMILY,
SHADER_CALL -> throw new UnsupportedOperationException(
getErrorMsg(model, "scope", tag));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,15 @@ private Thread createSPVThreadFromFunction(Function function, int tid, ThreadGri
FunctionType type = function.getFunctionType();
List<String> args = Lists.transform(function.getParameterRegisters(), Register::getName);
ThreadStart start = EventFactory.newThreadStart(null);
ScopeHierarchy scope = grid.getScoreHierarchy(tid);
Arch arch = function.getProgram().getArch();
ScopeHierarchy scope;
if (arch == Arch.VULKAN) {
scope = ScopeHierarchy.ScopeHierarchyForVulkan(grid.qfId(tid), grid.wgId(tid), grid.sgId(tid));
} else if (arch == Arch.OPENCL) {
scope = ScopeHierarchy.ScopeHierarchyForOpenCL(grid.dvId(tid), grid.wgId(tid), grid.sgId(tid));
} else {
throw new MalformedProgramException("Unsupported architecture for thread creation: " + arch);
}
Thread thread = new Thread(name, type, args, tid, start, scope, Set.of());
thread.copyDummyCountFrom(function);
Label returnLabel = EventFactory.newLabel("RETURN_OF_T" + thread.getId());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public class ProgramBuilderTest {

private static final TypeFactory types = TypeFactory.getInstance();

private final ProgramBuilder builder = new ProgramBuilder(new ThreadGrid(1, 1, 1, 1));
private final ProgramBuilder builder = new ProgramBuilder(new ThreadGrid(1, 1, 1, 1, 1));
private final ControlFlowBuilder cfBuilder = builder.getControlFlowBuilder();

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

public class VisitorExtensionClspvReflectionTest {

private final MockProgramBuilder builder = new MockProgramBuilder(new ThreadGrid(2, 3, 4, 1));
private final MockProgramBuilder builder = new MockProgramBuilder(new ThreadGrid(2, 3, 4, 1, 1));

@Before
public void before() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public class MockProgramBuilder extends ProgramBuilder {
private static final ExpressionFactory exprFactory = ExpressionFactory.getInstance();

public MockProgramBuilder() {
this(new ThreadGrid(1, 1, 1, 1));
this(new ThreadGrid(1, 1, 1, 1, 1));
}

public MockProgramBuilder(ThreadGrid grid) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ private void doTestLegalConfig(String input, List<Integer> scopes) {
int wg_size = scopes.get(1) * sg_size;
int qf_size = scopes.get(2) * wg_size;
for (int i = 0; i < size; i++) {
ScopeHierarchy hierarchy = grid.getScoreHierarchy(i);
ScopeHierarchy hierarchy = ScopeHierarchy.ScopeHierarchyForVulkan(grid.qfId(i), grid.wgId(i), grid.sgId(i));
assertEquals(((i % qf_size) % wg_size) / sg_size, hierarchy.getScopeId(Tag.Vulkan.SUB_GROUP));
assertEquals((i % qf_size) / wg_size, hierarchy.getScopeId(Tag.Vulkan.WORK_GROUP));
assertEquals(i / qf_size, hierarchy.getScopeId(Tag.Vulkan.QUEUE_FAMILY));
Expand Down