From 926ff51037d4bd31c0f3765e89f69a2dc5948345 Mon Sep 17 00:00:00 2001 From: Haining Date: Fri, 25 Apr 2025 18:13:05 +0200 Subject: [PATCH] Generalize threadGrid Signed-off-by: Haining Tong --- .../parsers/program/utils/ProgramBuilder.java | 2 +- .../program/visitors/VisitorLitmusC.java | 3 +- .../program/visitors/VisitorSpirv.java | 10 ++--- .../visitors/spirv/decorations/BuiltIn.java | 2 +- .../dartagnan/program/ScopeHierarchy.java | 3 +- .../dat3m/dartagnan/program/ThreadGrid.java | 40 ++++++++++--------- .../dat3m/dartagnan/program/event/Tag.java | 11 ++--- .../program/processing/ThreadCreation.java | 10 ++++- .../spirv/builders/ProgramBuilderTest.java | 2 +- .../VisitorExtensionClspvReflectionTest.java | 2 +- .../spirv/mocks/MockProgramBuilder.java | 2 +- .../dartagnan/spirv/header/ConfigTest.java | 2 +- 12 files changed, 51 insertions(+), 38 deletions(-) diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/utils/ProgramBuilder.java b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/utils/ProgramBuilder.java index 092bcfae1d..b10ed311ca 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/utils/ProgramBuilder.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/utils/ProgramBuilder.java @@ -301,7 +301,7 @@ public void newScopedThread(Arch arch, String name, int id, int ...scopeIds) { ScopeHierarchy scopeHierarchy = switch (arch) { case PTX -> ScopeHierarchy.ScopeHierarchyForPTX(scopeIds[0], scopeIds[1]); case VULKAN -> ScopeHierarchy.ScopeHierarchyForVulkan(scopeIds[0], scopeIds[1], scopeIds[2]); - case OPENCL -> ScopeHierarchy.ScopeHierarchyForOpenCL(scopeIds[0], scopeIds[1]); + case OPENCL -> ScopeHierarchy.ScopeHierarchyForOpenCL(scopeIds[0], scopeIds[1], scopeIds[2]); default -> throw new UnsupportedOperationException("Unsupported architecture: " + arch); }; diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorLitmusC.java b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorLitmusC.java index 03cd6489ff..03d6c4545c 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorLitmusC.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorLitmusC.java @@ -161,9 +161,10 @@ public Object visitThreadDeclarator(LitmusCParser.ThreadDeclaratorContext ctx) { scope = currentThread = ctx.threadId().id; threadIds.add(currentThread); if (isOpenCL && ctx.threadScope() != null) { + int sgID = 0; // Use subgroup ID 0 as default for OpenCL Litmus int wgID = ctx.threadScope().scopeID(0).id; int devID = ctx.threadScope().scopeID(1).id; - programBuilder.newScopedThread(Arch.OPENCL, currentThread, devID, wgID); + programBuilder.newScopedThread(Arch.OPENCL, currentThread, devID, wgID, sgID); } else { programBuilder.newThread(currentThread); } diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorSpirv.java b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorSpirv.java index 28c4a4612b..24d372893e 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorSpirv.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorSpirv.java @@ -61,7 +61,7 @@ public Program visitOp(SpirvParser.OpContext ctx) { } private ProgramBuilder createBuilder(SpirvParser.SpvContext ctx) { - ThreadGrid grid = new ThreadGrid(1, 1, 1, 1); + ThreadGrid grid = new ThreadGrid(1, 1, 1, 1, 1); boolean hasConfig = false; for (SpirvParser.SpvHeaderContext header : ctx.spvHeaders().spvHeader()) { SpirvParser.ConfigHeaderContext cfgCtx = header.configHeader(); @@ -71,10 +71,10 @@ private ProgramBuilder createBuilder(SpirvParser.SpvContext ctx) { } hasConfig = true; List literals = cfgCtx.literanHeaderUnsignedInteger(); - int sg = Integer.parseInt(literals.get(0).getText()); - int wg = Integer.parseInt(literals.get(1).getText()); - int qf = Integer.parseInt(literals.get(2).getText()); - grid = new ThreadGrid(sg, wg, qf, 1); + int threadCount = Integer.parseInt(literals.get(0).getText()); + int subgroupCount = Integer.parseInt(literals.get(1).getText()); + int workgroupCount = Integer.parseInt(literals.get(2).getText()); + grid = new ThreadGrid(threadCount, subgroupCount, workgroupCount, 1, 1); } } return new ProgramBuilder(grid); diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/decorations/BuiltIn.java b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/decorations/BuiltIn.java index e08015862b..6604aa98b3 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/decorations/BuiltIn.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/decorations/BuiltIn.java @@ -79,7 +79,7 @@ private Expression getDecorationExpressions(String id, Type type) { case "LocalInvocationId" -> makeArray(id, type, tid % grid.wgSize(), 0, 0); case "LocalInvocationIndex" -> makeScalar(id, type, tid % grid.wgSize()); // scalar of LocalInvocationId case "GlobalInvocationId" -> makeArray(id, type, tid % grid.dvSize(), 0, 0); - case "DeviceIndex" -> makeScalar(id, type, 0); + case "DeviceIndex" -> makeScalar(id, type, grid.dvId(tid)); case "SubgroupId" -> makeScalar(id, type, grid.sgId(tid)); case "WorkgroupId" -> makeArray(id, type, grid.wgId(tid), 0, 0); case "SubgroupSize" -> makeScalar(id, type, grid.sgSize()); diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/program/ScopeHierarchy.java b/dartagnan/src/main/java/com/dat3m/dartagnan/program/ScopeHierarchy.java index c257565f77..1af6e55972 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/program/ScopeHierarchy.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/program/ScopeHierarchy.java @@ -33,11 +33,12 @@ public static ScopeHierarchy ScopeHierarchyForPTX(int gpu, int cta) { return scopeHierarchy; } - public static ScopeHierarchy ScopeHierarchyForOpenCL(int dev, int wg) { + public static ScopeHierarchy ScopeHierarchyForOpenCL(int dev, int wg, int sg) { ScopeHierarchy scopeHierarchy = new ScopeHierarchy(); scopeHierarchy.scopeIds.put(Tag.OpenCL.ALL, 0); scopeHierarchy.scopeIds.put(Tag.OpenCL.DEVICE, dev); scopeHierarchy.scopeIds.put(Tag.OpenCL.WORK_GROUP, wg); + scopeHierarchy.scopeIds.put(Tag.OpenCL.SUB_GROUP, sg); return scopeHierarchy; } diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/program/ThreadGrid.java b/dartagnan/src/main/java/com/dat3m/dartagnan/program/ThreadGrid.java index 8b0ecbf21e..ef71a83394 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/program/ThreadGrid.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/program/ThreadGrid.java @@ -6,36 +6,42 @@ public class ThreadGrid { - private final int sg; - private final int wg; - private final int qf; - private final int dv; + private final int thCount; + private final int sgCount; + private final int wgCount; + private final int qfCount; + private final int dvCount; - public ThreadGrid(int sg, int wg, int qf, int dv) { - List elements = List.of(sg, wg, qf, dv); + public ThreadGrid(int thCount, int sgCount, int wgCount, int qfCount, int dvCount) { + List elements = List.of(thCount, sgCount, wgCount, qfCount, dvCount); if (elements.stream().anyMatch(i -> i <= 0)) { throw new ParsingException("Thread grid dimensions must be positive"); } - this.sg = sg; - this.wg = wg; - this.qf = qf; - this.dv = dv; + this.thCount = thCount; + this.sgCount = sgCount; + this.wgCount = wgCount; + this.qfCount = qfCount; + this.dvCount = dvCount; } public int sgSize() { - return sg; + return thCount; } public int wgSize() { - return sg * wg; + return thCount * sgCount; } public int qfSize() { - return sg * wg * qf; + return thCount * sgCount * wgCount; } public int dvSize() { - return sg * wg * qf * dv; + return thCount * sgCount * wgCount * qfCount; + } + + public int sysSize() { // Number of cross-device threads + return thCount * sgCount * wgCount * qfCount * dvCount; } public int thId(int tid) { @@ -55,10 +61,6 @@ public int qfId(int tid) { } public int dvId(int tid) { - return tid / dvSize(); - } - - public ScopeHierarchy getScoreHierarchy(int tid) { - return ScopeHierarchy.ScopeHierarchyForVulkan(qfId(tid), wgId(tid), sgId(tid)); + return (tid % sysSize()) / dvSize(); } } diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/program/event/Tag.java b/dartagnan/src/main/java/com/dat3m/dartagnan/program/event/Tag.java index 3f184fe817..2fcf9bfe2d 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/program/event/Tag.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/program/event/Tag.java @@ -375,6 +375,7 @@ public static String storeMO(String mo) { public static final class OpenCL { // Scopes public static final String WORK_ITEM = "WI"; + public static final String SUB_GROUP = "SG"; public static final String WORK_GROUP = "WG"; public static final String DEVICE = "DV"; public static final String ALL = "ALL"; @@ -388,7 +389,7 @@ public static final class OpenCL { public static final String DEFAULT_WEAK_SCOPE = WORK_ITEM; public static List getScopeTags() { - return List.of(WORK_GROUP, DEVICE, ALL); + return List.of(SUB_GROUP, WORK_GROUP, DEVICE, ALL); } public static List getSpaceTags() { @@ -529,13 +530,13 @@ public static String toOpenCLTag(String tag) { case SEQ_CST -> C11.MO_SC; // Scope - // TODO: OpenCL Kernel supports sub_group, but it's not mentioned in the model + // subgroup is supported in OpenCL Kernel, but it is not mentioned in the model case INVOCATION -> OpenCL.WORK_ITEM; - case SUBGROUP, - WORKGROUP -> OpenCL.WORK_GROUP; + case WORKGROUP -> OpenCL.WORK_GROUP; case DEVICE -> OpenCL.DEVICE; case CROSS_DEVICE -> OpenCL.ALL; - case QUEUE_FAMILY, + case SUBGROUP, + QUEUE_FAMILY, SHADER_CALL -> throw new UnsupportedOperationException( getErrorMsg(model, "scope", tag)); diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/program/processing/ThreadCreation.java b/dartagnan/src/main/java/com/dat3m/dartagnan/program/processing/ThreadCreation.java index 38e03d6a86..9fdc207965 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/program/processing/ThreadCreation.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/program/processing/ThreadCreation.java @@ -415,7 +415,15 @@ private Thread createSPVThreadFromFunction(Function function, int tid, ThreadGri FunctionType type = function.getFunctionType(); List args = Lists.transform(function.getParameterRegisters(), Register::getName); ThreadStart start = EventFactory.newThreadStart(null); - ScopeHierarchy scope = grid.getScoreHierarchy(tid); + Arch arch = function.getProgram().getArch(); + ScopeHierarchy scope; + if (arch == Arch.VULKAN) { + scope = ScopeHierarchy.ScopeHierarchyForVulkan(grid.qfId(tid), grid.wgId(tid), grid.sgId(tid)); + } else if (arch == Arch.OPENCL) { + scope = ScopeHierarchy.ScopeHierarchyForOpenCL(grid.dvId(tid), grid.wgId(tid), grid.sgId(tid)); + } else { + throw new MalformedProgramException("Unsupported architecture for thread creation: " + arch); + } Thread thread = new Thread(name, type, args, tid, start, scope, Set.of()); thread.copyDummyCountFrom(function); Label returnLabel = EventFactory.newLabel("RETURN_OF_T" + thread.getId()); diff --git a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/builders/ProgramBuilderTest.java b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/builders/ProgramBuilderTest.java index dd5096466e..0a5e2b5ab8 100644 --- a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/builders/ProgramBuilderTest.java +++ b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/builders/ProgramBuilderTest.java @@ -19,7 +19,7 @@ public class ProgramBuilderTest { private static final TypeFactory types = TypeFactory.getInstance(); - private final ProgramBuilder builder = new ProgramBuilder(new ThreadGrid(1, 1, 1, 1)); + private final ProgramBuilder builder = new ProgramBuilder(new ThreadGrid(1, 1, 1, 1, 1)); private final ControlFlowBuilder cfBuilder = builder.getControlFlowBuilder(); @Test diff --git a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/extensions/VisitorExtensionClspvReflectionTest.java b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/extensions/VisitorExtensionClspvReflectionTest.java index 1590a29bb0..4551b82a6a 100644 --- a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/extensions/VisitorExtensionClspvReflectionTest.java +++ b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/extensions/VisitorExtensionClspvReflectionTest.java @@ -17,7 +17,7 @@ public class VisitorExtensionClspvReflectionTest { - private final MockProgramBuilder builder = new MockProgramBuilder(new ThreadGrid(2, 3, 4, 1)); + private final MockProgramBuilder builder = new MockProgramBuilder(new ThreadGrid(2, 3, 4, 1, 1)); @Before public void before() { diff --git a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/mocks/MockProgramBuilder.java b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/mocks/MockProgramBuilder.java index 8bd24f6b01..4f07b2f79d 100644 --- a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/mocks/MockProgramBuilder.java +++ b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/mocks/MockProgramBuilder.java @@ -26,7 +26,7 @@ public class MockProgramBuilder extends ProgramBuilder { private static final ExpressionFactory exprFactory = ExpressionFactory.getInstance(); public MockProgramBuilder() { - this(new ThreadGrid(1, 1, 1, 1)); + this(new ThreadGrid(1, 1, 1, 1, 1)); } public MockProgramBuilder(ThreadGrid grid) { diff --git a/dartagnan/src/test/java/com/dat3m/dartagnan/spirv/header/ConfigTest.java b/dartagnan/src/test/java/com/dat3m/dartagnan/spirv/header/ConfigTest.java index a8970684e8..1d2f5e2203 100644 --- a/dartagnan/src/test/java/com/dat3m/dartagnan/spirv/header/ConfigTest.java +++ b/dartagnan/src/test/java/com/dat3m/dartagnan/spirv/header/ConfigTest.java @@ -35,7 +35,7 @@ private void doTestLegalConfig(String input, List scopes) { int wg_size = scopes.get(1) * sg_size; int qf_size = scopes.get(2) * wg_size; for (int i = 0; i < size; i++) { - ScopeHierarchy hierarchy = grid.getScoreHierarchy(i); + ScopeHierarchy hierarchy = ScopeHierarchy.ScopeHierarchyForVulkan(grid.qfId(i), grid.wgId(i), grid.sgId(i)); assertEquals(((i % qf_size) % wg_size) / sg_size, hierarchy.getScopeId(Tag.Vulkan.SUB_GROUP)); assertEquals((i % qf_size) / wg_size, hierarchy.getScopeId(Tag.Vulkan.WORK_GROUP)); assertEquals(i / qf_size, hierarchy.getScopeId(Tag.Vulkan.QUEUE_FAMILY));