diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/BasicBlock.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/BasicBlock.java index ccb8b09..51f7d51 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/BasicBlock.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/BasicBlock.java @@ -213,6 +213,14 @@ public StringBuilder toDot(StringBuilder sb, boolean verbose) { sb.append(""); return sb; } + public StringBuilder listDomFrontiers(StringBuilder sb) { + sb.append("L").append(this.bid).append(":"); + for (BasicBlock bb: dominationFrontier) { + sb.append(" L").append(bb.bid); + } + sb.append("\n"); + return sb; + } @Override public boolean equals(Object o) { if (this == o) return true; diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java index 48526e6..501e7d5 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java @@ -426,7 +426,7 @@ private boolean compileArrayStoreExpr(AST.ArrayStoreExpr arrayStoreExpr) { } private boolean compileNewExpr(AST.NewExpr newExpr) { - codeNew(newExpr.type); + codeNew(newExpr.type,newExpr.len,newExpr.initValue); return false; } @@ -644,20 +644,44 @@ else if (indexed instanceof Operand.LoadFieldOperand loadFieldOperand) codeMove(value, indexed); } - private void codeNew(Type type) { + private void codeNew(Type type, AST.Expr len, AST.Expr initVal) { if (type instanceof Type.TypeArray typeArray) - codeNewArray(typeArray); + codeNewArray(typeArray, len, initVal); else if (type instanceof Type.TypeStruct typeStruct) codeNewStruct(typeStruct); else throw new CompilerException("Unexpected type: " + type); } - private void codeNewArray(Type.TypeArray typeArray) { + private void codeNewArray(Type.TypeArray typeArray, AST.Expr len, AST.Expr initVal) { var temp = createTemp(typeArray); + Operand lenOperand = null; + Operand initValOperand = null; + if (len != null) { + boolean indexed = compileExpr(len); + if (indexed) + codeIndexedLoad(); + if (initVal != null) { + indexed = compileExpr(initVal); + if (indexed) + codeIndexedLoad(); + initValOperand = pop(); + } + lenOperand = pop(); + } + Instruction insn; var target = (Operand.RegisterOperand) issa.write(temp); - var insn = new Instruction.NewArray(typeArray, target); + if (lenOperand != null) { + if (initValOperand != null) + insn = new Instruction.NewArray(typeArray, target, issa.read(lenOperand), issa.read(initValOperand)); + else + insn = new Instruction.NewArray(typeArray, target, issa.read(lenOperand)); + } + else + insn = new Instruction.NewArray(typeArray, target); issa.recordDef(target, insn); + if (lenOperand != null) issa.recordUse(lenOperand,insn); + if (initValOperand != null) issa.recordUse(initValOperand,insn); code(insn); } @@ -941,6 +965,32 @@ private Register addPhiOperands(Register variable, Instruction.Phi phi) { return tryRemovingPhi(phi); } + // The Phi's def is dead so we need to remove + // all occurrences of this def from the memoized defs + // per Basic Block + private void clearDefs(Instruction.Phi phi) { + // TODO rethink the data structure for currentDef + var def = phi.value(); + var defs = currentDef.get(def.nonSSAId()); + // Make a list of block/reg that we need to delete + var bbList = new ArrayList(); + var regList = new ArrayList(); + for (var entries : defs.entrySet()) { + var bb = entries.getKey(); + var reg = entries.getValue(); + if (reg.equals(def)) { + bbList.add(bb); + regList.add(reg); + } + } + // Now delete them + for (int i = 0; i < bbList.size(); i++) { + var bb = bbList.get(i); + var reg = regList.get(i); + defs.remove(bb, reg); + } + } + private Register tryRemovingPhi(Instruction.Phi phi) { Register same = null; // Check if phi has distinct inputs @@ -967,6 +1017,9 @@ private Register tryRemovingPhi(Instruction.Phi phi) { // remove all uses of phi to same and remove phi replacePhiValueAndUsers(phi, same); phi.block.deleteInstruction(phi); + // Since the phi is dead any references to its def + // must be removed; this is not mentioned in the paper + clearDefs(phi); // try to recursively remove all phi users, which might have become trivial for (var use: users) { if (use instanceof Instruction.Phi phiuser) @@ -979,25 +1032,28 @@ private Register tryRemovingPhi(Instruction.Phi phi) { * Reroute all uses of phi to new value */ private void replacePhiValueAndUsers(Instruction.Phi phi, Register newValue) { - var oldDefUseChain = ssaDefUses.get(phi.value()); + var oldValue = phi.value(); + var oldDefUseChain = ssaDefUses.get(oldValue); var newDefUseChain = ssaDefUses.get(newValue); if (newDefUseChain == null) { - // Can be null because this may be existing def - newDefUseChain = SSAEdges.addDef(ssaDefUses, newValue, phi); + throw new CompilerException("Expected error: undefined var " + newValue); } if (oldDefUseChain != null) { for (Instruction instruction: oldDefUseChain.useList) { + boolean replaced; if (instruction instanceof Instruction.Phi somePhi) { - somePhi.replaceInput(phi.value(), newValue); + replaced = somePhi.replaceInput(oldValue, newValue); } else { - instruction.replaceUse(phi.value(), newValue); + replaced = instruction.replaceUse(oldValue, newValue); + } + if (!replaced) { + throw new CompilerException("Discrepancy between var use list and var definition"); } } // Users of phi old value become users of the new value newDefUseChain.useList.addAll(oldDefUseChain.useList); oldDefUseChain.useList.clear(); - // FIXME remove old def from def-use chains } } diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/DominatorTree.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/DominatorTree.java index ea3dc2f..2a6b120 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/DominatorTree.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/DominatorTree.java @@ -1,10 +1,8 @@ package com.compilerprogramming.ezlang.compiler; -import java.util.ArrayList; import java.util.Comparator; import java.util.HashSet; import java.util.List; -import java.util.function.Consumer; /** * The dominator tree construction algorithm is based on figure 9.24, @@ -34,6 +32,7 @@ public DominatorTree(BasicBlock entry) { populateTree(); setDepth(); calculateDominanceFrontiers(); + //calculateDominanceFrontiersMethod2(); } private void calculateDominatorTree() { @@ -41,7 +40,7 @@ private void calculateDominatorTree() { annotateBlocksWithRPO(); sortBlocksByRPO(); - // Set IDom entry for root to itself + // Set IDom entry for root to itself (see note below) entry.idom = entry; boolean changed = true; while (changed) { @@ -68,6 +67,11 @@ private void calculateDominatorTree() { } } } + // There is a contradiction between what is described in the book + // and the algo - as the book says the IDOM for root is undefined + // But we set this to itself when calculating. So after calculations + // are done, set this to null. This is technically more correct IMO. + entry.idom = null; } private void resetDomInfo() { @@ -146,7 +150,7 @@ private BasicBlock findFirstPredecessorWithIdom(BasicBlock n) { private void populateTree() { for (BasicBlock block : blocks) { BasicBlock idom = block.idom; - if (idom == block) // root + if (idom == null) // root continue; // add edge from idom to n idom.dominatedChildren.add(block); @@ -166,12 +170,13 @@ private void setDepth() { */ private void setDepth_(BasicBlock block) { BasicBlock idom = block.idom; - if (idom != block) { + if (idom != null) { assert idom.domDepth > 0; block.domDepth = idom.domDepth + 1; - } else { - assert idom.domDepth == 1; - assert idom.domDepth == block.domDepth; + } + else { + // root (entry) block's idom is null + assert block.domDepth == 1; } for (BasicBlock child : block.dominatedChildren) setDepth_(child); @@ -189,6 +194,8 @@ private void calculateDominanceFrontiers() { // while runner != doms[b] // add b to runner’s dominance frontier set // runner = doms[runner] + for (BasicBlock b: blocks) + b.dominationFrontier.clear(); // empty set for (BasicBlock b : blocks) { if (b.predecessors.size() >= 2) { for (BasicBlock p : b.predecessors) { @@ -204,6 +211,37 @@ private void calculateDominanceFrontiers() { } } + // We have an alternative approach to calculating DOM Frontiers to + // allow us to validate above + private void calculateDominanceFrontiersMethod2() + { + for (BasicBlock b: blocks) + b.dominationFrontier.clear(); // empty set + computeDF(entry); + } + + // Implementation based on description in pg 440 of + // Modern Compiler implementation in C + // Appel + private void computeDF(BasicBlock n) { + var S = new HashSet(); + for (BasicBlock y: n.successors) { + if (y.idom != n) + S.add(y); + } + for (BasicBlock c: n.dominatedChildren) { + computeDF(c); + for (BasicBlock w: c.dominationFrontier) { + // Note that the printed book has an error below + // and errata gives the correct version + if (!n.dominates(w) || n==w) { + S.add(w); + } + } + } + n.dominationFrontier.addAll(S); + } + public String generateDotOutput() { StringBuilder sb = new StringBuilder(); sb.append("digraph DomTree {\n"); @@ -212,10 +250,18 @@ public String generateDotOutput() { } for (BasicBlock n : blocks) { BasicBlock idom = n.idom; - if (idom == n) continue; + if (idom == null) continue; sb.append(idom.uniqueName()).append("->").append(n.uniqueName()).append(";\n"); } sb.append("}\n"); return sb.toString(); } + + public String listDomFrontiers() { + StringBuilder sb = new StringBuilder(); + for (BasicBlock n : blocks) { + n.listDomFrontiers(sb); + } + return sb.toString(); + } } diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/EnterSSA.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/EnterSSA.java index 83e7c24..143b123 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/EnterSSA.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/EnterSSA.java @@ -41,9 +41,17 @@ public EnterSSA(CompiledFunction bytecodeFunction, EnumSet options) { System.out.println("Pre SSA Dominator Tree"); System.out.println(domTree.generateDotOutput()); } + if (options.contains(Options.DUMP_PRE_SSA_DOMFRONTIERS)) { + System.out.println("Pre SSA Dominance Frontiers"); + System.out.println(domTree.listDomFrontiers()); + } this.blocks = domTree.blocks; // the blocks are ordered reverse post order findNonLocalNames(); new Liveness(bytecodeFunction); // EWe require liveness info to construct pruned ssa + if (options.contains(Options.DUMP_PRE_SSA_LIVENESS)) { + System.out.println("Pre SSA Liveness"); + System.out.println(bytecodeFunction.toStr(new StringBuilder(), true)); + } insertPhis(); renameVars(); bytecodeFunction.isSSA = true; @@ -177,8 +185,9 @@ void search(BasicBlock block) { } // Pop stacks for defs for (Instruction i: block.instructions) { - if (i.definesVar()) { - var reg = i.def(); + // Phis don't answer to definesVar() or def() + if (i.definesVar() || i instanceof Instruction.Phi) { + var reg = i instanceof Instruction.Phi phi ? phi.value() : i.def(); stacks[reg.nonSSAId()].pop(); } } diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Instruction.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Instruction.java index 0ab1203..0aa50ff 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Instruction.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Instruction.java @@ -124,14 +124,28 @@ public NewArray(Type.TypeArray type, Operand.RegisterOperand destOperand) { super(I_NEW_ARRAY, destOperand); this.type = type; } + public NewArray(Type.TypeArray type, Operand.RegisterOperand destOperand, Operand len) { + super(I_NEW_ARRAY, destOperand, len); + this.type = type; + } + public NewArray(Type.TypeArray type, Operand.RegisterOperand destOperand, Operand len, Operand initValue) { + super(I_NEW_ARRAY, destOperand, len, initValue); + this.type = type; + } + public Operand len() { return uses.length > 0 ? uses[0] : null; } + public Operand initValue() { return uses.length > 1 ? uses[1] : null; } public Operand.RegisterOperand destOperand() { return def; } @Override public StringBuilder toStr(StringBuilder sb) { - return sb.append(def) + sb.append(def) .append(" = ") .append("New(") - .append(type) - .append(")"); + .append(type); + if (len() != null) + sb.append(", len=").append(len()); + if (initValue() != null) + sb.append(", initValue=").append(initValue()); + return sb.append(")"); } } @@ -417,15 +431,18 @@ public void addInput(Register register) { newUses[newUses.length-1] = new Operand.RegisterOperand(register); this.uses = newUses; } - public void replaceInput(Register oldReg, Register newReg) { + public boolean replaceInput(Register oldReg, Register newReg) { + boolean replaced = false; for (int i = 0; i < numInputs(); i++) { if (isRegisterInput(i)) { Register in = inputAsRegister(i); if (in.equals(oldReg)) { replaceInput(i, newReg); + replaced = true; } } } + return replaced; } } diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Options.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Options.java index e71988c..b53c8cd 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Options.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Options.java @@ -10,6 +10,8 @@ public enum Options { REGALLOC, DUMP_INITIAL_IR, DUMP_PRE_SSA_DOMTREE, + DUMP_PRE_SSA_DOMFRONTIERS, + DUMP_PRE_SSA_LIVENESS, DUMP_SSA_IR, DUMP_SCCP_PREAPPLY, DUMP_SCCP_POSTAPPLY, diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Register.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Register.java index 3376e61..42895b9 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Register.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Register.java @@ -40,7 +40,7 @@ public class Register { * * Not unique. */ - private final String name; + protected final String name; /** * The type of the register */ @@ -50,7 +50,7 @@ public class Register { * of the executing function. Multiple registers may share the same * frame slot because of different non-overlapping life times. */ - private int frameSlot; + protected int frameSlot; public Register(int id, String name, Type type) { this(id,name,type,id); // Initially frame slot is set to the unique ID @@ -107,7 +107,15 @@ public SSARegister(Register original, int id, int version) { public int nonSSAId() { return originalRegNumber; } - } + @Override + public String toString() { + return "SSARegister{name=" + name + ", id=" + id + ", frameSlot=" + frameSlot + ", ssaVersion=" + ssaVersion + ", originalRegNumber=" + originalRegNumber + '}'; + } + } + @Override + public String toString() { + return "Register{name=" + name + ", id=" + id + ", frameSlot=" + frameSlot + "}"; + } } diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/SSAEdges.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/SSAEdges.java index 4e81e91..41ebb8b 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/SSAEdges.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/SSAEdges.java @@ -89,6 +89,8 @@ private static void recordUses(Map defUseChains, Register[] in public static void recordUse(Map defUseChains, Instruction instruction, Register register) { SSADef def = defUseChains.get(register); + if (def == null) + throw new CompilerException("No def found for " + register); def.useList.add(instruction); } diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java b/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java index c85fa24..654d2df 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java @@ -200,7 +200,19 @@ else if (binaryInst.right() instanceof Operand.RegisterOperand registerOperand) } } case Instruction.NewArray newArrayInst -> { - execStack.stack[base + newArrayInst.destOperand().frameSlot()] = new Value.ArrayValue(newArrayInst.type); + long size = 0; + Value initValue = null; + if (newArrayInst.len() instanceof Operand.ConstantOperand constantOperand) + size = constantOperand.value; + else if (newArrayInst.len() instanceof Operand.RegisterOperand registerOperand) { + Value.IntegerValue indexValue = (Value.IntegerValue) execStack.stack[base + registerOperand.frameSlot()]; + size = (long) indexValue.value; + } + if (newArrayInst.initValue() instanceof Operand.ConstantOperand constantOperand) + initValue = new Value.IntegerValue(constantOperand.value); + else if (newArrayInst.initValue() instanceof Operand.RegisterOperand registerOperand) + initValue = execStack.stack[base + registerOperand.frameSlot()]; + execStack.stack[base + newArrayInst.destOperand().frameSlot()] = new Value.ArrayValue(newArrayInst.type, size, initValue); } case Instruction.NewStruct newStructInst -> { execStack.stack[base + newStructInst.destOperand().frameSlot()] = new Value.StructValue(newStructInst.type); diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/Value.java b/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/Value.java index 7f914a6..6717daa 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/Value.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/Value.java @@ -17,9 +17,12 @@ public NullValue() {} static public class ArrayValue extends Value { public final Type.TypeArray arrayType; public final ArrayList values; - public ArrayValue(Type.TypeArray arrayType) { + public ArrayValue(Type.TypeArray arrayType, long len, Value initValue) { this.arrayType = arrayType; values = new ArrayList<>(); + for (long i = 0; i < len; i++) { + values.add(initValue); + } } } static public class StructValue extends Value { diff --git a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestCompiler.java b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestCompiler.java index b6873ec..11b46eb 100644 --- a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestCompiler.java +++ b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestCompiler.java @@ -214,7 +214,7 @@ func foo()->[Int] { String result = compileSrc(src); Assert.assertEquals(""" L0: - %t0 = New([Int]) + %t0 = New([Int], len=3) %t0[0] = 1 %t0[1] = 2 %t0[2] = 3 @@ -235,7 +235,7 @@ func foo(n: Int) -> [Int] { Assert.assertEquals(""" L0: arg n - %t1 = New([Int]) + %t1 = New([Int], len=1) %t1[0] = n ret %t1 goto L1 diff --git a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestDominators.java b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestDominators.java index 4a74ded..a58843b 100644 --- a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestDominators.java +++ b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestDominators.java @@ -12,6 +12,8 @@ BasicBlock add(List nodes, BasicBlock node) { return node; } + // This CFG example is taken from page 473 of + // Engineering a Compiler 3rd ed BasicBlock makeGraph(List nodes) { BasicBlock r0 = add(nodes, new BasicBlock(1)); BasicBlock r1 = add(nodes, new BasicBlock(2, r0)); @@ -28,15 +30,34 @@ BasicBlock makeGraph(List nodes) { return r0; } + // This DOM Tree example is taken from page 473 of + // Engineering a Compiler 3rd ed @Test public void testDominatorTree() { List nodes = new ArrayList<>(); BasicBlock root = makeGraph(nodes); DominatorTree tree = new DominatorTree(root); System.out.println(tree.generateDotOutput()); - long[] expectedIdoms = {0,1,1,2,2,4,2,6,6,6}; + // expected {_,_,0,1,1,3,1,5,5,5} + // Note first entry is not used + // Note we set idom of root to itself so the second entry + // does not match example + long[] expectedIdoms = {-1,0,1,2,2,4,2,6,6,6}; for (BasicBlock n: nodes) { - Assert.assertEquals(expectedIdoms[(int)n.bid], n.idom.bid); + if (expectedIdoms[(int)n.bid] == 0) + Assert.assertNull(n.idom); + else + Assert.assertEquals(expectedIdoms[(int)n.bid], n.idom.bid); + } + // -1 means empty set + long[] expectedDF = {0,-1,2,4,2,-1,4,8,4,8}; + for (BasicBlock n: nodes) { + if (expectedDF[(int)n.bid] == -1) { + Assert.assertTrue(n.dominationFrontier.isEmpty()); + } + else { + Assert.assertEquals(1,n.dominationFrontier.size()); + } } } diff --git a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestIncrementalSSA.java b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestIncrementalSSA.java index e4f9b93..e3a5272 100644 --- a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestIncrementalSSA.java +++ b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestIncrementalSSA.java @@ -183,188 +183,189 @@ func example14_66(p: Int, q: Int, r: Int, s: Int, t: Int) { """; String result = compileSrc(src); Assert.assertEquals(""" - func print(a: Int,b: Int,c: Int,d: Int) - Reg #0 a 0 - Reg #1 b 1 - Reg #2 c 2 - Reg #3 d 3 - L0: - arg a - arg b - arg c - arg d - goto L1 - L1: - func example14_66(p: Int,q: Int,r: Int,s: Int,t: Int) - Reg #0 p 0 - Reg #1 q 1 - Reg #2 r 2 - Reg #3 s 3 - Reg #4 t 4 - Reg #5 i 5 - Reg #6 j 6 - Reg #7 k 7 - Reg #8 l 8 - Reg #9 p_1 0 - Reg #10 i_1 5 - Reg #11 j_1 6 - Reg #12 q_1 1 - Reg #13 l_1 8 - Reg #14 l_2 8 - Reg #15 %t15 15 - Reg #16 k_1 7 - Reg #17 k_2 7 - Reg #18 k_3 7 - Reg #19 %t19 19 - Reg #20 k_4 7 - Reg #21 %t21 21 - Reg #22 i_2 5 - Reg #23 i_3 5 - Reg #24 %t24 24 - Reg #25 j_2 6 - Reg #26 j_3 6 - Reg #27 j_4 6 - Reg #28 %t28 28 - Reg #29 k_5 7 - Reg #30 %t30 30 - Reg #31 l_3 8 - Reg #32 l_4 8 - Reg #33 l_5 8 - Reg #34 r_1 2 - Reg #35 %t35 35 - Reg #36 l_6 8 - Reg #37 l_7 8 - Reg #38 %t38 38 - Reg #39 s_1 3 - Reg #40 s_2 3 - Reg #41 r_2 2 - Reg #42 r_3 2 - Reg #43 r_4 2 - Reg #44 r_5 2 - Reg #45 s_3 3 - Reg #46 s_4 3 - Reg #47 s_5 3 - Reg #48 l_8 8 - Reg #49 %t49 49 - Reg #50 i_4 5 - Reg #51 i_5 5 - Reg #52 i_6 5 - Reg #53 i_7 5 - Reg #54 %t54 54 - Reg #55 t_1 4 - Reg #56 t_2 4 - Reg #57 t_3 4 - Reg #58 t_4 4 - Reg #59 t_5 4 - Reg #60 t_6 4 - Reg #61 p_2 0 - Reg #62 p_3 0 - Reg #63 p_4 0 - Reg #64 p_5 0 - Reg #65 p_6 0 - Reg #66 q_2 1 - Reg #67 q_3 1 - Reg #68 q_4 1 - Reg #69 q_5 1 - Reg #70 q_6 1 - Reg #71 r_6 2 - Reg #72 s_6 3 - Reg #73 j_5 6 - Reg #74 j_6 6 - Reg #75 j_7 6 - Reg #76 k_6 7 - Reg #77 k_7 7 - Reg #78 k_8 7 - Reg #79 l_9 8 - L0: - arg p - arg q - arg r - arg s - arg t - i = 1 - j = 1 - k = 1 - l = 1 - goto L2 - L2: - t_5 = phi(t, t_1) - s_5 = phi(s, s_2) - r_4 = phi(r, r_1) - l_5 = phi(l, l_9) - j_4 = phi(j, j_5) - k_2 = phi(k, k_6) - q_1 = phi(q, q_2) - i_1 = phi(i, i_7) - p_1 = phi(p, p_2) - if 1 goto L3 else goto L4 - L3: - if p_1 goto L5 else goto L6 - L5: - j_1 = i_1 - if q_1 goto L8 else goto L9 - L8: - l_1 = 2 - goto L10 - L10: - l_4 = phi(l_1, l_2) - %t15 = k_2+1 - k_3 = %t15 - goto L7 - L7: - l_3 = phi(l_4, l_5) - k_5 = phi(k_3, k_4) - j_2 = phi(j_1, j_4) - %t21 = i_1 - %t24 = j_2 - %t28 = k_5 - %t30 = l_3 - call print params %t21, %t24, %t28, %t30 - goto L11 - L11: - l_6 = phi(l_3, l_8) - if 1 goto L12 else goto L13 - L12: - if r_4 goto L14 else goto L15 - L14: - %t35 = l_6+4 - l_7 = %t35 - goto L15 - L15: - l_8 = phi(l_6, l_7) - %t38 = !s_5 - if %t38 goto L16 else goto L17 - L16: - goto L13 - L13: - l_9 = phi(l_6, l_8) - k_6 = phi(k_5, k_7) - j_5 = phi(j_2, j_6) - q_2 = phi(q_1, q_3) - p_2 = phi(p_1, p_3) - t_1 = phi(t_5, t_2) - i_4 = phi(i_1, i_5) - %t49 = i_4+6 - i_7 = %t49 - %t54 = !t_1 - if %t54 goto L18 else goto L19 - L18: - goto L4 - L4: - goto L1 - L1: - L19: - goto L2 - L17: - goto L11 - L9: - l_2 = 3 - goto L10 - L6: - %t19 = k_2+2 - k_4 = %t19 - goto L7 - """, result); +func print(a: Int,b: Int,c: Int,d: Int) +Reg #0 a 0 +Reg #1 b 1 +Reg #2 c 2 +Reg #3 d 3 +L0: + arg a + arg b + arg c + arg d + goto L1 +L1: +func example14_66(p: Int,q: Int,r: Int,s: Int,t: Int) +Reg #0 p 0 +Reg #1 q 1 +Reg #2 r 2 +Reg #3 s 3 +Reg #4 t 4 +Reg #5 i 5 +Reg #6 j 6 +Reg #7 k 7 +Reg #8 l 8 +Reg #9 p_1 0 +Reg #10 i_1 5 +Reg #11 j_1 6 +Reg #12 q_1 1 +Reg #13 l_1 8 +Reg #14 l_2 8 +Reg #15 %t15 15 +Reg #16 k_1 7 +Reg #17 k_2 7 +Reg #18 k_3 7 +Reg #19 %t19 19 +Reg #20 k_4 7 +Reg #21 %t21 21 +Reg #22 i_2 5 +Reg #23 i_3 5 +Reg #24 %t24 24 +Reg #25 j_2 6 +Reg #26 j_3 6 +Reg #27 j_4 6 +Reg #28 %t28 28 +Reg #29 k_5 7 +Reg #30 %t30 30 +Reg #31 l_3 8 +Reg #32 l_4 8 +Reg #33 l_5 8 +Reg #34 r_1 2 +Reg #35 %t35 35 +Reg #36 l_6 8 +Reg #37 l_7 8 +Reg #38 %t38 38 +Reg #39 s_1 3 +Reg #40 s_2 3 +Reg #41 r_2 2 +Reg #42 r_3 2 +Reg #43 r_4 2 +Reg #44 r_5 2 +Reg #45 s_3 3 +Reg #46 s_4 3 +Reg #47 s_5 3 +Reg #48 l_8 8 +Reg #49 %t49 49 +Reg #50 i_4 5 +Reg #51 i_5 5 +Reg #52 i_6 5 +Reg #53 i_7 5 +Reg #54 i_8 5 +Reg #55 %t55 55 +Reg #56 t_1 4 +Reg #57 t_2 4 +Reg #58 t_3 4 +Reg #59 t_4 4 +Reg #60 t_5 4 +Reg #61 t_6 4 +Reg #62 t_7 4 +Reg #63 p_2 0 +Reg #64 p_3 0 +Reg #65 p_4 0 +Reg #66 p_5 0 +Reg #67 p_6 0 +Reg #68 p_7 0 +Reg #69 q_2 1 +Reg #70 q_3 1 +Reg #71 q_4 1 +Reg #72 q_5 1 +Reg #73 q_6 1 +Reg #74 q_7 1 +Reg #75 r_6 2 +Reg #76 r_7 2 +Reg #77 r_8 2 +Reg #78 r_9 2 +Reg #79 s_6 3 +Reg #80 s_7 3 +Reg #81 s_8 3 +Reg #82 s_9 3 +Reg #83 j_5 6 +Reg #84 j_6 6 +Reg #85 j_7 6 +Reg #86 j_8 6 +Reg #87 k_6 7 +Reg #88 k_7 7 +Reg #89 k_8 7 +Reg #90 k_9 7 +Reg #91 l_9 8 +L0: + arg p + arg q + arg r + arg s + arg t + i = 1 + j = 1 + k = 1 + l = 1 + goto L2 +L2: + l_5 = phi(l, l_9) + j_4 = phi(j, j_2) + k_2 = phi(k, k_5) + i_1 = phi(i, i_8) + if 1 goto L3 else goto L4 +L3: + if p goto L5 else goto L6 +L5: + j_1 = i_1 + if q goto L8 else goto L9 +L8: + l_1 = 2 + goto L10 +L10: + l_4 = phi(l_1, l_2) + %t15 = k_2+1 + k_3 = %t15 + goto L7 +L7: + l_3 = phi(l_4, l_5) + k_5 = phi(k_3, k_4) + j_2 = phi(j_1, j_4) + %t21 = i_1 + %t24 = j_2 + %t28 = k_5 + %t30 = l_3 + call print params %t21, %t24, %t28, %t30 + goto L11 +L11: + l_6 = phi(l_3, l_8) + if 1 goto L12 else goto L13 +L12: + if r goto L14 else goto L15 +L14: + %t35 = l_6+4 + l_7 = %t35 + goto L15 +L15: + l_8 = phi(l_6, l_7) + %t38 = !s + if %t38 goto L16 else goto L17 +L16: + goto L13 +L13: + l_9 = phi(l_6, l_8) + %t49 = i_1+6 + i_8 = %t49 + %t55 = !t + if %t55 goto L18 else goto L19 +L18: + goto L4 +L4: + goto L1 +L1: +L19: + goto L2 +L17: + goto L11 +L9: + l_2 = 3 + goto L10 +L6: + %t19 = k_2+2 + k_4 = %t19 + goto L7 +""", result); } @Test @@ -606,6 +607,114 @@ func main()->Int """, result); } + @Test + public void testSSA11() { + String src = """ + func main()->Int { + var a = 0 + var i = 0 + var j = 0 + while (i < 3) { + j = 0 + while (j < 3) { + if (i == j) + a = a + i + j + else if (i > j) + a = a - 1 + j = j + 1 + } + i = i + 1 + } + return a + } +"""; + String result = compileSrc(src); + Assert.assertEquals(""" +func main()->Int +Reg #0 a 0 +Reg #1 i 1 +Reg #2 j 2 +Reg #3 %t3 3 +Reg #4 i_1 1 +Reg #5 j_1 2 +Reg #6 %t6 6 +Reg #7 j_2 2 +Reg #8 %t8 8 +Reg #9 i_2 1 +Reg #10 %t10 10 +Reg #11 a_1 0 +Reg #12 %t12 12 +Reg #13 a_2 0 +Reg #14 %t14 14 +Reg #15 %t15 15 +Reg #16 a_3 0 +Reg #17 %t17 17 +Reg #18 j_3 2 +Reg #19 j_4 2 +Reg #20 j_5 2 +Reg #21 a_4 0 +Reg #22 a_5 0 +Reg #23 a_6 0 +Reg #24 i_3 1 +Reg #25 i_4 1 +Reg #26 %t26 26 +Reg #27 i_5 1 +Reg #28 i_6 1 +Reg #29 i_7 1 +Reg #30 i_8 1 +L0: + a = 0 + i = 0 + j = 0 + goto L2 +L2: + a_4 = phi(a, a_1) + i_1 = phi(i, i_8) + %t3 = i_1<3 + if %t3 goto L3 else goto L4 +L3: + j_1 = 0 + goto L5 +L5: + a_1 = phi(a_4, a_5) + j_2 = phi(j_1, j_5) + %t6 = j_2<3 + if %t6 goto L6 else goto L7 +L6: + %t8 = i_1==j_2 + if %t8 goto L8 else goto L9 +L8: + %t10 = a_1+i_1 + %t12 = %t10+j_2 + a_2 = %t12 + goto L10 +L10: + a_5 = phi(a_2, a_6) + %t17 = j_2+1 + j_5 = %t17 + goto L5 +L9: + %t14 = i_1>j_2 + if %t14 goto L11 else goto L12 +L11: + %t15 = a_1-1 + a_3 = %t15 + goto L12 +L12: + a_6 = phi(a_1, a_3) + goto L10 +L7: + %t26 = i_1+1 + i_8 = %t26 + goto L2 +L4: + ret a_4 + goto L1 +L1: +""", result); + } + + @Test public void testSSA17() { String src = """ @@ -670,4 +779,446 @@ func merge(begin: Int,middle: Int,end: Int) """, result); } + @Test + public void testSSA18() + { + String src = """ + func foo(len: Int, val: Int, x: Int, y: Int)->[Int] { + if (x > y) { + len=len+x + val=val+x + } + return new [Int]{len=len,value=val} + } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +func foo(len: Int,val: Int,x: Int,y: Int)->[Int] +Reg #0 len 0 +Reg #1 val 1 +Reg #2 x 2 +Reg #3 y 3 +Reg #4 %t4 4 +Reg #5 %t5 5 +Reg #6 len_1 0 +Reg #7 %t7 7 +Reg #8 val_1 1 +Reg #9 %t9 9 +Reg #10 len_2 0 +Reg #11 val_2 1 +L0: + arg len + arg val + arg x + arg y + %t4 = x>y + if %t4 goto L2 else goto L3 +L2: + %t5 = len+x + len_1 = %t5 + %t7 = val+x + val_1 = %t7 + goto L3 +L3: + val_2 = phi(val, val_1) + len_2 = phi(len, len_1) + %t9 = New([Int], len=len_2, initValue=val_2) + ret %t9 + goto L1 +L1: +""", result); + } + + @Test + public void testSSA19() + { + String src = """ +func bug(N: Int) +{ + var p=2 + while( p < N ) { + if (p) { + p = p + 1 + } + } + while ( p < N ) { + p = p + 1 + } +} + """; + String result = compileSrc(src); + Assert.assertEquals(""" +func bug(N: Int) +Reg #0 N 0 +Reg #1 p 1 +Reg #2 %t2 2 +Reg #3 p_1 1 +Reg #4 N_1 0 +Reg #5 %t5 5 +Reg #6 p_2 1 +Reg #7 N_2 0 +Reg #8 p_3 1 +Reg #9 %t9 9 +Reg #10 p_4 1 +Reg #11 N_3 0 +Reg #12 %t12 12 +Reg #13 p_5 1 +Reg #14 N_4 0 +Reg #15 N_5 0 +L0: + arg N + p = 2 + goto L2 +L2: + p_1 = phi(p, p_3) + %t2 = p_1[Int] +{ + // The main Sieve array + var ary = new [Int]{len=N,value=0} + // The primes less than N + var primes = new [Int]{len=N/2,value=0} + // Number of primes so far, searching at index p + var nprimes = 0 + var p=2 + // Find primes while p^2 < N + while( p*p < N ) { + // skip marked non-primes + while( ary[p] ) { + p = p + 1 + } + // p is now a prime + primes[nprimes] = p + nprimes = nprimes+1 + // Mark out the rest non-primes + var i = p + p + while( i < N ) { + ary[i] = 1 + i = i + p + } + p = p + 1 + } + + // Now just collect the remaining primes, no more marking + while ( p < N ) { + if( !ary[p] ) { + primes[nprimes] = p + nprimes = nprimes + 1 + } + p = p + 1 + } + + // Copy/shrink the result array + var rez = new [Int]{len=nprimes,value=0} + var j = 0 + while( j < nprimes ) { + rez[j] = primes[j] + j = j + 1 + } + return rez +} +func eq(a: [Int], b: [Int], n: Int)->Int +{ + var result = 1 + var i = 0 + while (i < n) + { + if (a[i] != b[i]) + { + result = 0 + break + } + i = i + 1 + } + return result +} + +func main()->Int +{ + var rez = sieve(20) + var expected = new [Int]{2,3,5,7,11,13,17,19} + return eq(rez,expected,8) +} +"""; + String result = compileSrc(src); + Assert.assertEquals(""" +func sieve(N: Int)->[Int] +Reg #0 N 0 +Reg #1 ary 1 +Reg #2 primes 2 +Reg #3 nprimes 3 +Reg #4 p 4 +Reg #5 rez 5 +Reg #6 j 6 +Reg #7 i 7 +Reg #8 %t8 8 +Reg #9 %t9 9 +Reg #10 %t10 10 +Reg #11 %t11 11 +Reg #12 p_1 4 +Reg #13 %t13 13 +Reg #14 N_1 0 +Reg #15 %t15 15 +Reg #16 ary_1 1 +Reg #17 p_2 4 +Reg #18 %t18 18 +Reg #19 p_3 4 +Reg #20 ary_2 1 +Reg #21 primes_1 2 +Reg #22 primes_2 2 +Reg #23 nprimes_1 3 +Reg #24 nprimes_2 3 +Reg #25 %t25 25 +Reg #26 nprimes_3 3 +Reg #27 %t27 27 +Reg #28 %t28 28 +Reg #29 i_1 7 +Reg #30 N_2 0 +Reg #31 ary_3 1 +Reg #32 %t32 32 +Reg #33 p_4 4 +Reg #34 i_2 7 +Reg #35 N_3 0 +Reg #36 ary_4 1 +Reg #37 %t37 37 +Reg #38 p_5 4 +Reg #39 p_6 4 +Reg #40 N_4 0 +Reg #41 ary_5 1 +Reg #42 primes_3 2 +Reg #43 nprimes_4 3 +Reg #44 %t44 44 +Reg #45 p_7 4 +Reg #46 N_5 0 +Reg #47 %t47 47 +Reg #48 ary_6 1 +Reg #49 %t49 49 +Reg #50 primes_4 2 +Reg #51 nprimes_5 3 +Reg #52 %t52 52 +Reg #53 nprimes_6 3 +Reg #54 %t54 54 +Reg #55 p_8 4 +Reg #56 p_9 4 +Reg #57 N_6 0 +Reg #58 N_7 0 +Reg #59 N_8 0 +Reg #60 N_9 0 +Reg #61 ary_7 1 +Reg #62 ary_8 1 +Reg #63 ary_9 1 +Reg #64 ary_10 1 +Reg #65 primes_5 2 +Reg #66 primes_6 2 +Reg #67 primes_7 2 +Reg #68 primes_8 2 +Reg #69 nprimes_7 3 +Reg #70 %t70 70 +Reg #71 %t71 71 +Reg #72 j_1 6 +Reg #73 nprimes_8 3 +Reg #74 %t74 74 +Reg #75 primes_9 2 +Reg #76 rez_1 5 +Reg #77 %t77 77 +Reg #78 j_2 6 +Reg #79 primes_10 2 +Reg #80 primes_11 2 +Reg #81 rez_2 5 +L0: + arg N + %t8 = New([Int], len=N, initValue=0) + ary = %t8 + %t10 = N/2 + %t9 = New([Int], len=%t10, initValue=0) + primes = %t9 + nprimes = 0 + p = 2 + goto L2 +L2: + nprimes_2 = phi(nprimes, nprimes_3) + p_1 = phi(p, p_6) + %t11 = p_1*p_1 + %t13 = %t11Int +Reg #0 a 0 +Reg #1 b 1 +Reg #2 n 2 +Reg #3 result 3 +Reg #4 i 4 +Reg #5 %t5 5 +Reg #6 i_1 4 +Reg #7 n_1 2 +Reg #8 %t8 8 +Reg #9 a_1 0 +Reg #10 %t10 10 +Reg #11 b_1 1 +Reg #12 %t12 12 +Reg #13 result_1 3 +Reg #14 %t14 14 +Reg #15 i_2 4 +Reg #16 result_2 3 +Reg #17 result_3 3 +L0: + arg a + arg b + arg n + result = 1 + i = 0 + goto L2 +L2: + i_1 = phi(i, i_2) + %t5 = i_1Int +Reg #0 rez 0 +Reg #1 expected 1 +Reg #2 %t2 2 +Reg #3 %t3 3 +Reg #4 %t4 4 +Reg #5 %t5 5 +Reg #6 %t6 6 +Reg #7 %t7 7 +Reg #8 %t8 8 +L0: + %t2 = 20 + %t3 = call sieve params %t2 + rez = %t3 + %t4 = New([Int], len=8) + %t4[0] = 2 + %t4[1] = 3 + %t4[2] = 5 + %t4[3] = 7 + %t4[4] = 11 + %t4[5] = 13 + %t4[6] = 17 + %t4[7] = 19 + expected = %t4 + %t5 = rez + %t6 = expected + %t7 = 8 + %t8 = call eq params %t5, %t6, %t7 + ret %t8 + goto L1 +L1: +""", result); + } } \ No newline at end of file diff --git a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java index c8b5d6a..02aeb01 100644 --- a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java +++ b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java @@ -24,10 +24,13 @@ String compileSrc(String src) { sb.append("Before SSA\n"); sb.append("==========\n"); BasicBlock.toStr(sb, functionBuilder.entry, new BitSet(), false); + //functionBuilder.toDot(sb,false); new EnterSSA(functionBuilder, Options.NONE); + //new EnterSSA(functionBuilder, EnumSet.of(DUMP_PRE_SSA_DOMFRONTIERS)); sb.append("After SSA\n"); sb.append("=========\n"); BasicBlock.toStr(sb, functionBuilder.entry, new BitSet(), false); + //functionBuilder.toDot(sb,false); new ExitSSA(functionBuilder, Options.NONE); sb.append("After exiting SSA\n"); sb.append("=================\n"); @@ -1662,7 +1665,7 @@ func foo()->Int { Before SSA ========== L0: - %t2 = New([Int]) + %t2 = New([Int], len=2) %t2[0] = 1 %t2[1] = 2 arr = %t2 @@ -1675,7 +1678,7 @@ func foo()->Int { After SSA ========= L0: - %t2_0 = New([Int]) + %t2_0 = New([Int], len=2) %t2_0[0] = 1 %t2_0[1] = 2 arr_0 = %t2_0 @@ -1688,7 +1691,7 @@ func foo()->Int { After exiting SSA ================= L0: - %t2_0 = New([Int]) + %t2_0 = New([Int], len=2) %t2_0[0] = 1 %t2_0[1] = 2 arr_0 = %t2_0 @@ -2276,7 +2279,7 @@ else if (i > j) %t5_0 = i_1==j_2 if %t5_0 goto L8 else goto L9 L8: - %t6_0 = a_4+i_1 + %t6_0 = a_2+i_1 %t7_0 = %t6_0+j_2 a_5 = %t7_0 goto L10 @@ -2327,7 +2330,7 @@ else if (i > j) %t5_0 = i_1==j_2 if %t5_0 goto L8 else goto L9 L8: - %t6_0 = a_4+i_1 + %t6_0 = a_2+i_1 %t7_0 = %t6_0+j_2 a_5 = %t7_0 a_6 = a_5 @@ -2781,7 +2784,7 @@ func foo()->Int Before SSA ========== L0: - %t1 = New([Foo?]) + %t1 = New([Foo?], len=2) %t2 = New(Foo) %t2.i = 1 %t1[0] = %t2 @@ -2805,7 +2808,7 @@ func foo()->Int After SSA ========= L0: - %t1_0 = New([Foo?]) + %t1_0 = New([Foo?], len=2) %t2_0 = New(Foo) %t2_0.i = 1 %t1_0[0] = %t2_0 @@ -2830,7 +2833,7 @@ func foo()->Int After exiting SSA ================= L0: - %t1_0 = New([Foo?]) + %t1_0 = New([Foo?], len=2) %t2_0 = New(Foo) %t2_0.i = 1 %t1_0[0] = %t2_0 @@ -2973,6 +2976,712 @@ func merge(begin: Int, middle: Int, end: Int) L3: goto L1 L1: +""", result); + } + + @Test + public void testSSA18() + { + String src = """ + func foo(len: Int, val: Int, x: Int, y: Int)->[Int] { + if (x > y) { + len=len+x + val=val+x + } + return new [Int]{len=len,value=val} + } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +func foo +Before SSA +========== +L0: + arg len + arg val + arg x + arg y + %t4 = x>y + if %t4 goto L2 else goto L3 +L2: + %t5 = len+x + len = %t5 + %t6 = val+x + val = %t6 + goto L3 +L3: + %t7 = New([Int], len=len, initValue=val) + ret %t7 + goto L1 +L1: +After SSA +========= +L0: + arg len_0 + arg val_0 + arg x_0 + arg y_0 + %t4_0 = x_0>y_0 + if %t4_0 goto L2 else goto L3 +L2: + %t5_0 = len_0+x_0 + len_1 = %t5_0 + %t6_0 = val_0+x_0 + val_1 = %t6_0 + goto L3 +L3: + val_2 = phi(val_0, val_1) + len_2 = phi(len_0, len_1) + %t7_0 = New([Int], len=len_2, initValue=val_2) + ret %t7_0 + goto L1 +L1: +After exiting SSA +================= +L0: + arg len_0 + arg val_0 + arg x_0 + arg y_0 + %t4_0 = x_0>y_0 + val_2 = val_0 + len_2 = len_0 + if %t4_0 goto L2 else goto L3 +L2: + %t5_0 = len_0+x_0 + len_1 = %t5_0 + %t6_0 = val_0+x_0 + val_1 = %t6_0 + val_2 = val_1 + len_2 = len_1 + goto L3 +L3: + %t7_0 = New([Int], len=len_2, initValue=val_2) + ret %t7_0 + goto L1 +L1: +""", result); + } + + @Test + public void testSSA19() + { + String src = """ +func bug(N: Int) +{ + var p=2 + while( p < N ) { + if (p) { + p = p + 1 + } + } + while ( p < N ) { + p = p + 1 + } +} + """; + String result = compileSrc(src); + Assert.assertEquals(""" +func bug +Before SSA +========== +L0: + arg N + p = 2 + goto L2 +L2: + %t2 = p[Int] +{ + // The main Sieve array + var ary = new [Int]{len=N,value=0} + // The primes less than N + var primes = new [Int]{len=N/2,value=0} + // Number of primes so far, searching at index p + var nprimes = 0 + var p=2 + // Find primes while p^2 < N + while( p*p < N ) { + // skip marked non-primes + while( ary[p] ) { + p = p + 1 + } + // p is now a prime + primes[nprimes] = p + nprimes = nprimes+1 + // Mark out the rest non-primes + var i = p + p + while( i < N ) { + ary[i] = 1 + i = i + p + } + p = p + 1 + } + + // Now just collect the remaining primes, no more marking + while ( p < N ) { + if( !ary[p] ) { + primes[nprimes] = p + nprimes = nprimes + 1 + } + p = p + 1 + } + + // Copy/shrink the result array + var rez = new [Int]{len=nprimes,value=0} + var j = 0 + while( j < nprimes ) { + rez[j] = primes[j] + j = j + 1 + } + return rez +} +func eq(a: [Int], b: [Int], n: Int)->Int +{ + var result = 1 + var i = 0 + while (i < n) + { + if (a[i] != b[i]) + { + result = 0 + break + } + i = i + 1 + } + return result +} + +func main()->Int +{ + var rez = sieve(20) + var expected = new [Int]{2,3,5,7,11,13,17,19} + return eq(rez,expected,8) +} +"""; + String result = compileSrc(src); + Assert.assertEquals(""" +func sieve +Before SSA +========== +L0: + arg N + %t8 = New([Int], len=N, initValue=0) + ary = %t8 + %t10 = N/2 + %t9 = New([Int], len=%t10, initValue=0) + primes = %t9 + nprimes = 0 + p = 2 + goto L2 +L2: + %t11 = p*p + %t12 = %t11Int integerValue.value == 1); } + @Test + public void testFunction108() { + String src = """ + func make(len: Int, val: Int)->[Int] + { + return new [Int]{len=len, value=val} + } + func main()->Int + { + var arr = make(3,3); + var i = 0 + while (i < 3) { + if (arr[i] != 3) + return 1 + i = i + 1 + } + return 0 + } + """; + var value = compileAndRun(src, "main", Options.OPT); + Assert.assertNotNull(value); + Assert.assertTrue(value instanceof Value.IntegerValue integerValue && + integerValue.value == 0); + } + + @Test + public void testFunction109() { + String src = """ +func sieve(N: Int)->[Int] +{ + // The main Sieve array + var ary = new [Int]{len=N,value=0} + // The primes less than N + var primes = new [Int]{len=N/2,value=0} + // Number of primes so far, searching at index p + var nprimes = 0 + var p=2 + // Find primes while p^2 < N + while( p*p < N ) { + // skip marked non-primes + while( ary[p] ) { + p = p + 1 + } + // p is now a prime + primes[nprimes] = p + nprimes = nprimes+1 + // Mark out the rest non-primes + var i = p + p + while( i < N ) { + ary[i] = 1 + i = i + p + } + p = p + 1 + } + + // Now just collect the remaining primes, no more marking + while ( p < N ) { + if( !ary[p] ) { + primes[nprimes] = p + nprimes = nprimes + 1 + } + p = p + 1 + } + + // Copy/shrink the result array + var rez = new [Int]{len=nprimes,value=0} + var j = 0 + while( j < nprimes ) { + rez[j] = primes[j] + j = j + 1 + } + return rez +} +func eq(a: [Int], b: [Int], n: Int)->Int +{ + var result = 1 + var i = 0 + while (i < n) + { + if (a[i] != b[i]) + { + result = 0 + break + } + i = i + 1 + } + return result +} + +func main()->Int +{ + var rez = sieve(20) + var expected = new [Int]{2,3,5,7,11,13,17,19} + return eq(rez,expected,8) +} +"""; + var value = compileAndRun(src, "main"); + Assert.assertNotNull(value); + Assert.assertTrue(value instanceof Value.IntegerValue integerValue && + integerValue.value == 1); + } } diff --git a/parser/src/main/java/com/compilerprogramming/ezlang/parser/AST.java b/parser/src/main/java/com/compilerprogramming/ezlang/parser/AST.java index 46bcf78..162dd29 100644 --- a/parser/src/main/java/com/compilerprogramming/ezlang/parser/AST.java +++ b/parser/src/main/java/com/compilerprogramming/ezlang/parser/AST.java @@ -545,14 +545,17 @@ public void accept(ASTVisitor visitor) { */ public static class NewExpr extends Expr { public final TypeExpr typeExpr; - public final long len; // temp hack, as this needs to be an expression rather than constant + public final Expr len; + public final Expr initValue; public NewExpr(TypeExpr typeExpr) { this.typeExpr = typeExpr; - this.len = 0; + this.len = null; + this.initValue = null; } - public NewExpr(TypeExpr typeExpr, long len) { + public NewExpr(TypeExpr typeExpr, Expr len, Expr initValue) { this.typeExpr = typeExpr; this.len = len; + this.initValue = initValue; } @Override public StringBuilder toStr(StringBuilder sb) { @@ -566,6 +569,11 @@ public void accept(ASTVisitor visitor) { if (visitor == null) return; typeExpr.accept(visitor); + if (len != null) { + len.accept(visitor); + if (initValue != null) + initValue.accept(visitor); + } visitor.visit(this, false); } } @@ -579,12 +587,7 @@ public static class InitExpr extends Expr { public final List initExprList; public InitExpr(NewExpr newExpr, List initExprList) { this.initExprList = initExprList; - // For arrays we compute length based on number of elements - // This is not actually correct - see https://github.com/CompilerProgramming/ez-lang/issues/47 - if (initExprList.size() != newExpr.len) - this.newExpr = new NewExpr(newExpr.typeExpr,initExprList.size()); - else - this.newExpr = newExpr; + this.newExpr = newExpr; } @Override public StringBuilder toStr(StringBuilder sb) { diff --git a/parser/src/main/java/com/compilerprogramming/ezlang/parser/Parser.java b/parser/src/main/java/com/compilerprogramming/ezlang/parser/Parser.java index 821ce6c..423040b 100644 --- a/parser/src/main/java/com/compilerprogramming/ezlang/parser/Parser.java +++ b/parser/src/main/java/com/compilerprogramming/ezlang/parser/Parser.java @@ -333,6 +333,8 @@ private AST.Expr parseNew(Lexer lexer) { matchIdentifier(lexer, "new"); AST.TypeExpr resultType = parseTypeExpr(lexer); var newExpr = new AST.NewExpr(resultType); + AST.Expr lenExpr = null; + AST.Expr initValueExpr = null; List initExpr = new ArrayList<>(); int index = 0; if (testPunctuation(lexer, "{")) { @@ -343,10 +345,14 @@ private AST.Expr parseNew(Lexer lexer) { matchPunctuation(lexer, "="); AST.Expr value = parseBool(lexer); initExpr.add(new AST.InitFieldExpr(newExpr, fieldname, value)); + if (fieldname.equals("len")) + lenExpr = value; + else if (fieldname.equals("value")) + initValueExpr = value; } else { var indexLit = Integer.valueOf(index++); - var indexExpr = new AST.LiteralExpr(Token.newNum(indexLit,indexLit.toString(),0)); + var indexExpr = new AST.LiteralExpr(Token.newNum(indexLit,indexLit.toString(),currentToken.lineNumber)); initExpr.add(new AST.ArrayInitExpr(newExpr, indexExpr, parseBool(lexer))); } if (isToken(currentToken, ",")) @@ -355,6 +361,12 @@ private AST.Expr parseNew(Lexer lexer) { } } matchPunctuation(lexer, "}"); + if (initExpr.size() > 0 && lenExpr == null) { + var sizeLit = Integer.valueOf(initExpr.size()); + lenExpr = new AST.LiteralExpr(Token.newNum(sizeLit,sizeLit.toString(),currentToken.lineNumber)); + } + if (lenExpr != null) + return new AST.InitExpr(new AST.NewExpr(newExpr.typeExpr, lenExpr, initValueExpr), initExpr); return new AST.InitExpr(newExpr, initExpr); } diff --git a/parser/src/test/java/com/compilerprogramming/ezlang/parser/TestParser.java b/parser/src/test/java/com/compilerprogramming/ezlang/parser/TestParser.java index 1c60dc0..b4966c3 100644 --- a/parser/src/test/java/com/compilerprogramming/ezlang/parser/TestParser.java +++ b/parser/src/test/java/com/compilerprogramming/ezlang/parser/TestParser.java @@ -36,7 +36,7 @@ func bar() -> Test { func main() { var m = 42 var t: Tree - var array = new [Int] {size=10,1,2,3} + var array = new [Int] {len=10,1,2,3} array[1] = 42 t.left = null if (m < 1) diff --git a/registervm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java b/registervm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java index 9433d5c..3071909 100644 --- a/registervm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java +++ b/registervm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java @@ -383,7 +383,7 @@ private boolean compileArrayStoreExpr(AST.ArrayStoreExpr arrayStoreExpr) { } private boolean compileNewExpr(AST.NewExpr newExpr) { - codeNew(newExpr.type); + codeNew(newExpr.type,newExpr.len,newExpr.initValue); return false; } @@ -603,13 +603,42 @@ else if (indexed instanceof Operand.LoadFieldOperand loadFieldOperand) code(new Instruction.Move(value, indexed)); } - private void codeNew(Type type) { - var temp = createTemp(type); + private void codeNew(Type type, AST.Expr len, AST.Expr initVal) { if (type instanceof Type.TypeArray typeArray) - code(new Instruction.NewArray(typeArray, temp)); - else if (type instanceof Type.TypeStruct typeStruct) + codeNewArray(typeArray, len, initVal); + else if (type instanceof Type.TypeStruct typeStruct) { + var temp = createTemp(type); code(new Instruction.NewStruct(typeStruct, temp)); + } else throw new CompilerException("Unexpected type: " + type); } + + private void codeNewArray(Type.TypeArray typeArray, AST.Expr len, AST.Expr initVal) { + var temp = createTemp(typeArray); + Operand lenOperand = null; + Operand initValOperand = null; + if (len != null) { + boolean indexed = compileExpr(len); + if (indexed) + codeIndexedLoad(); + if (initVal != null) { + indexed = compileExpr(initVal); + if (indexed) + codeIndexedLoad(); + initValOperand = pop(); + } + lenOperand = pop(); + } + Instruction insn; + if (lenOperand != null) { + if (initValOperand != null) + insn = new Instruction.NewArray(typeArray, temp, lenOperand, initValOperand); + else + insn = new Instruction.NewArray(typeArray, temp, lenOperand); + } + else + insn = new Instruction.NewArray(typeArray, temp); + code(insn); + } } diff --git a/registervm/src/main/java/com/compilerprogramming/ezlang/compiler/Instruction.java b/registervm/src/main/java/com/compilerprogramming/ezlang/compiler/Instruction.java index a80169a..73b5673 100644 --- a/registervm/src/main/java/com/compilerprogramming/ezlang/compiler/Instruction.java +++ b/registervm/src/main/java/com/compilerprogramming/ezlang/compiler/Instruction.java @@ -62,14 +62,28 @@ public NewArray(Type.TypeArray type, Operand.RegisterOperand destOperand) { super(I_NEW_ARRAY, destOperand); this.type = type; } + public NewArray(Type.TypeArray type, Operand.RegisterOperand destOperand, Operand len) { + super(I_NEW_ARRAY, destOperand, len); + this.type = type; + } + public NewArray(Type.TypeArray type, Operand.RegisterOperand destOperand, Operand len, Operand initValue) { + super(I_NEW_ARRAY, destOperand, len, initValue); + this.type = type; + } + public Operand len() { return uses.length > 0 ? uses[0] : null; } + public Operand initValue() { return uses.length > 1 ? uses[1] : null; } public Operand.RegisterOperand destOperand() { return def; } @Override public StringBuilder toStr(StringBuilder sb) { - return sb.append(def) + sb.append(def) .append(" = ") .append("New(") - .append(type) - .append(")"); + .append(type); + if (len() != null) + sb.append(", len=").append(len()); + if (initValue() != null) + sb.append(", initValue=").append(initValue()); + return sb.append(")"); } } diff --git a/registervm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java b/registervm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java index 919a82e..f441455 100644 --- a/registervm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java +++ b/registervm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java @@ -200,7 +200,19 @@ else if (binaryInst.right() instanceof Operand.RegisterOperand registerOperand) } } case Instruction.NewArray newArrayInst -> { - execStack.stack[base + newArrayInst.destOperand().frameSlot()] = new Value.ArrayValue(newArrayInst.type); + long size = 0; + Value initValue = null; + if (newArrayInst.len() instanceof Operand.ConstantOperand constantOperand) + size = constantOperand.value; + else if (newArrayInst.len() instanceof Operand.RegisterOperand registerOperand) { + Value.IntegerValue indexValue = (Value.IntegerValue) execStack.stack[base + registerOperand.frameSlot()]; + size = (long) indexValue.value; + } + if (newArrayInst.initValue() instanceof Operand.ConstantOperand constantOperand) + initValue = new Value.IntegerValue(constantOperand.value); + else if (newArrayInst.initValue() instanceof Operand.RegisterOperand registerOperand) + initValue = execStack.stack[base + registerOperand.frameSlot()]; + execStack.stack[base + newArrayInst.destOperand().frameSlot()] = new Value.ArrayValue(newArrayInst.type, size, initValue); } case Instruction.NewStruct newStructInst -> { execStack.stack[base + newStructInst.destOperand().frameSlot()] = new Value.StructValue(newStructInst.type); diff --git a/registervm/src/main/java/com/compilerprogramming/ezlang/interpreter/Value.java b/registervm/src/main/java/com/compilerprogramming/ezlang/interpreter/Value.java index 7f914a6..6717daa 100644 --- a/registervm/src/main/java/com/compilerprogramming/ezlang/interpreter/Value.java +++ b/registervm/src/main/java/com/compilerprogramming/ezlang/interpreter/Value.java @@ -17,9 +17,12 @@ public NullValue() {} static public class ArrayValue extends Value { public final Type.TypeArray arrayType; public final ArrayList values; - public ArrayValue(Type.TypeArray arrayType) { + public ArrayValue(Type.TypeArray arrayType, long len, Value initValue) { this.arrayType = arrayType; values = new ArrayList<>(); + for (long i = 0; i < len; i++) { + values.add(initValue); + } } } static public class StructValue extends Value { diff --git a/registervm/src/test/java/com/compilerprogramming/ezlang/compiler/TestCompiler.java b/registervm/src/test/java/com/compilerprogramming/ezlang/compiler/TestCompiler.java index 142b8b6..1461658 100644 --- a/registervm/src/test/java/com/compilerprogramming/ezlang/compiler/TestCompiler.java +++ b/registervm/src/test/java/com/compilerprogramming/ezlang/compiler/TestCompiler.java @@ -203,7 +203,7 @@ func foo()->[Int] { String result = compileSrc(src); Assert.assertEquals(""" L0: - %t0 = New([Int]) + %t0 = New([Int], len=3) %t0[0] = 1 %t0[1] = 2 %t0[2] = 3 @@ -223,7 +223,7 @@ func foo(n: Int) -> [Int] { String result = compileSrc(src); Assert.assertEquals(""" L0: - %t1 = New([Int]) + %t1 = New([Int], len=1) %t1[0] = n ret %t1 goto L1 @@ -577,6 +577,89 @@ public void testFunction29() { ret null goto L1 L1: +""", result); + } + + @Test + public void testFunction30() { + String src = """ + func foo()->[Int] { + return new [Int]{len=10,value=0} + } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +L0: + %t0 = New([Int], len=10, initValue=0) + ret %t0 + goto L1 +L1: +""", result); + } + + @Test + public void testFunction31() { + String src = """ + func foo(len: Int, val: Int)->[Int] { + return new [Int]{len=len,value=val} + } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +L0: + %t2 = New([Int], len=len, initValue=val) + ret %t2 + goto L1 +L1: +""", result); + } + + @Test + public void testFunction32() { + String src = """ + func foo(len: Int, val: Int, x: Int)->[Int] { + return new [Int]{len=len+x,value=val+x} + } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +L0: + %t4 = len+x + %t5 = val+x + %t3 = New([Int], len=%t4, initValue=%t5) + ret %t3 + goto L1 +L1: +""", result); + } + + @Test + public void testFunction33() { + String src = """ + func foo(len: Int, val: Int, x: Int, y: Int)->[Int] { + if (x > y) { + len=len+x + val=val+x + } + return new [Int]{len=len,value=val} + } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +L0: + %t4 = x>y + if %t4 goto L2 else goto L3 +L2: + %t4 = len+x + len = %t4 + %t4 = val+x + val = %t4 + goto L3 +L3: + %t4 = New([Int], len=len, initValue=val) + ret %t4 + goto L1 +L1: """, result); } diff --git a/registervm/src/test/java/com/compilerprogramming/ezlang/interpreter/TestInterpreter.java b/registervm/src/test/java/com/compilerprogramming/ezlang/interpreter/TestInterpreter.java index 49f8450..3667caa 100644 --- a/registervm/src/test/java/com/compilerprogramming/ezlang/interpreter/TestInterpreter.java +++ b/registervm/src/test/java/com/compilerprogramming/ezlang/interpreter/TestInterpreter.java @@ -326,6 +326,108 @@ func main()->Int merge_sort(a, b, 10) return eq(a,expect,10) } +"""; + var value = compileAndRun(src, "main"); + Assert.assertNotNull(value); + Assert.assertTrue(value instanceof Value.IntegerValue integerValue && + integerValue.value == 1); + } + + @Test + public void testFunction108() { + String src = """ + func make(len: Int, val: Int)->[Int] + { + return new [Int]{len=len, value=val} + } + func main()->Int + { + var arr = make(3,3); + var i = 0 + while (i < 3) { + if (arr[i] != 3) + return 1 + i = i + 1 + } + return 0 + } + """; + var value = compileAndRun(src, "main"); + Assert.assertNotNull(value); + Assert.assertTrue(value instanceof Value.IntegerValue integerValue && + integerValue.value == 0); + } + + @Test + public void testFunction109() { + String src = """ +func sieve(N: Int)->[Int] +{ + // The main Sieve array + var ary = new [Int]{len=N,value=0} + // The primes less than N + var primes = new [Int]{len=N/2,value=0} + // Number of primes so far, searching at index p + var nprimes = 0 + var p=2 + // Find primes while p^2 < N + while( p*p < N ) { + // skip marked non-primes + while( ary[p] ) { + p = p + 1 + } + // p is now a prime + primes[nprimes] = p + nprimes = nprimes+1 + // Mark out the rest non-primes + var i = p + p + while( i < N ) { + ary[i] = 1 + i = i + p + } + p = p + 1 + } + + // Now just collect the remaining primes, no more marking + while ( p < N ) { + if( !ary[p] ) { + primes[nprimes] = p + nprimes = nprimes + 1 + } + p = p + 1 + } + + // Copy/shrink the result array + var rez = new [Int]{len=nprimes,value=0} + var j = 0 + while( j < nprimes ) { + rez[j] = primes[j] + j = j + 1 + } + return rez +} +func eq(a: [Int], b: [Int], n: Int)->Int +{ + var result = 1 + var i = 0 + while (i < n) + { + if (a[i] != b[i]) + { + result = 0 + break + } + i = i + 1 + } + return result +} + +func main()->Int +{ + var rez = sieve(20) + var expected = new [Int]{2,3,5,7,11,13,17,19} + return eq(rez,expected,8) +} """; var value = compileAndRun(src, "main"); Assert.assertNotNull(value); diff --git a/seaofnodes/src/main/java/com/compilerprogramming/ezlang/compiler/Compiler.java b/seaofnodes/src/main/java/com/compilerprogramming/ezlang/compiler/Compiler.java index 46805d5..3d7e546 100644 --- a/seaofnodes/src/main/java/com/compilerprogramming/ezlang/compiler/Compiler.java +++ b/seaofnodes/src/main/java/com/compilerprogramming/ezlang/compiler/Compiler.java @@ -554,7 +554,7 @@ private Node compileNewExpr(AST.NewExpr newExpr) { Type type = newExpr.type; if (type instanceof Type.TypeArray typeArray) { SONTypeMemPtr tarray = (SONTypeMemPtr) TYPES.get(typeArray.name()); - return newArray(tarray._obj,newExpr.len==0?ZERO:con(newExpr.len)); + return newArray(tarray._obj,newExpr.len==null?ZERO:compileExpr(newExpr.len)); } else if (type instanceof Type.TypeStruct typeStruct) { SONTypeMemPtr tptr = (SONTypeMemPtr) TYPES.get(typeStruct.name()); diff --git a/seaofnodes/src/test/cases/mergsort/sort.ez b/seaofnodes/src/test/cases/mergsort/sort.ez index 8ad506d..b5a5ae4 100644 --- a/seaofnodes/src/test/cases/mergsort/sort.ez +++ b/seaofnodes/src/test/cases/mergsort/sort.ez @@ -6,14 +6,15 @@ func merge_sort(a: [Int], b: [Int], n: Int) split_merge(a, 0, n, b) } -func split_merge(b: [Int], begin: Int, end: Int, a: [Int]) +func split_merge(b: [Int], begin: Int, end: Int, a: [Int])->Int { if (end - begin <= 1) - return; + return 0 var middle = (end + begin) / 2 split_merge(a, begin, middle, b) split_merge(a, middle, end, b) merge(b, begin, middle, end, a) + return 0 } func merge(b: [Int], begin: Int, middle: Int, end: Int, a: [Int]) diff --git a/seaofnodes/src/test/cases/mergsort/sort.smp b/seaofnodes/src/test/cases/mergsort/sort.smp index b82b3f4..d0aeaa4 100644 --- a/seaofnodes/src/test/cases/mergsort/sort.smp +++ b/seaofnodes/src/test/cases/mergsort/sort.smp @@ -17,23 +17,34 @@ val split_merge = { int[] b, int begin, int end, int[] a -> }; val merge = { int[] b, int begin, int middle, int end, int[] a -> - int i = begin, j = middle; - - for (int k = begin; k < end; k++) { + int i = begin; + int j = middle; + int k = begin; + while (k < end) { // && and || bool cond = false; if (i < middle) { if (j >= end) cond = true; else if (a[i] <= a[j]) cond = true; } - if (cond) b[k] = a[i++]; - else b[k] = a[j++]; + if (cond) { + b[k] = a[i]; + i = i + 1; + } + else { + b[k] = a[j]; + j = j + 1; + } + k = k + 1; } }; val copy_array = { int[] a, int begin, int end, int[] b -> - for (int k = begin; k < end; k++) + int k = begin; + while (k < end) { b[k] = a[k]; + k = k + 1; + } }; val eq = { int[] a, int[] b, int n -> diff --git a/seaofnodes/src/test/cases/sieve/sieve.ez b/seaofnodes/src/test/cases/sieve/sieve.ez new file mode 100644 index 0000000..8f4aa4a --- /dev/null +++ b/seaofnodes/src/test/cases/sieve/sieve.ez @@ -0,0 +1,67 @@ +func sieve(N: Int)->[Int] +{ + // The main Sieve array + var ary = new [Int]{len=N,value=0} + // The primes less than N + var primes = new [Int]{len=N/2,value=0} + // Number of primes so far, searching at index p + var nprimes = 0 + var p=2 + // Find primes while p^2 < N + while( p*p < N ) { + // skip marked non-primes + while( ary[p] ) { + p = p + 1 + } + // p is now a prime + primes[nprimes] = p + nprimes = nprimes+1 + // Mark out the rest non-primes + var i = p + p + while( i < N ) { + ary[i] = 1 + i = i + p + } + p = p + 1 + } + + // Now just collect the remaining primes, no more marking + while ( p < N ) { + if( !ary[p] ) { + primes[nprimes] = p + nprimes = nprimes + 1 + } + p = p + 1 + } + + // Copy/shrink the result array + var rez = new [Int]{len=nprimes,value=0} + var j = 0 + while( j < nprimes ) { + rez[j] = primes[j] + j = j + 1 + } + return rez +} +func eq(a: [Int], b: [Int], n: Int)->Int +{ + var result = 1 + var i = 0 + while (i < n) + { + if (a[i] != b[i]) + { + result = 0 + break + } + i = i + 1 + } + return result +} + +func main()->Int +{ + var rez = sieve(20) + var expected = new [Int]{2,3,5,7,11,13,17,19} + return eq(rez,expected,8) +} \ No newline at end of file diff --git a/seaofnodes/src/test/cases/sieve/sieve.smp b/seaofnodes/src/test/cases/sieve/sieve.smp new file mode 100644 index 0000000..f0509a8 --- /dev/null +++ b/seaofnodes/src/test/cases/sieve/sieve.smp @@ -0,0 +1,31 @@ +// -*- mode: java; -*- +val sieve = { int N -> + // The main Sieve array + bool[] !ary = new bool[N]; + // The primes less than N + u32[] !primes = new u32[N>>1]; + // Number of primes so far, searching at index p + int nprimes = 0, p=2; + // Find primes while p^2 < N + while( p*p < N ) { + // skip marked non-primes + while( ary[p] ) p++; + // p is now a prime + primes[nprimes++] = p; + // Mark out the rest non-primes + for( int i = p + p; i < ary#; i+= p ) + ary[i] = true; + p++; + } + + // Now just collect the remaining primes, no more marking + for( ; p < N; p++ ) + if( !ary[p] ) + primes[nprimes++] = p; + + // Copy/shrink the result array + u32[] !rez = new u32[nprimes]; + for( int j=0; j < nprimes; j++ ) + rez[j] = primes[j]; + return rez; +}; \ No newline at end of file diff --git a/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaAssignTypes.java b/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaAssignTypes.java index 0fa930d..743ac68 100644 --- a/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaAssignTypes.java +++ b/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaAssignTypes.java @@ -67,6 +67,8 @@ public ASTVisitor visit(AST.VarDecl varDecl, boolean enter) { @Override public ASTVisitor visit(AST.BinaryExpr binaryExpr, boolean enter) { if (!enter) { + if (binaryExpr.type != null) + return this; validType(binaryExpr.expr1.type, true); validType(binaryExpr.expr2.type, true); if (binaryExpr.expr1.type instanceof Type.TypeInteger && @@ -92,7 +94,7 @@ else if (((binaryExpr.expr1.type instanceof Type.TypeNull && @Override public ASTVisitor visit(AST.UnaryExpr unaryExpr, boolean enter) { - if (enter) { + if (enter || unaryExpr.type != null) { return this; } validType(unaryExpr.expr.type, false); @@ -109,6 +111,8 @@ public ASTVisitor visit(AST.UnaryExpr unaryExpr, boolean enter) { public ASTVisitor visit(AST.GetFieldExpr fieldExpr, boolean enter) { if (enter) return this; + if (fieldExpr.type != null) + return this; validType(fieldExpr.object.type, false); Type.TypeStruct structType = null; if (fieldExpr.object.type instanceof Type.TypeStruct ts) { @@ -131,6 +135,8 @@ else if (fieldExpr.object.type instanceof Type.TypeNullable ptr && public ASTVisitor visit(AST.SetFieldExpr fieldExpr, boolean enter) { if (enter) return this; + if (fieldExpr.type != null) + return this; validType(fieldExpr.object.type, true); Type.TypeStruct structType = null; if (fieldExpr.object.type instanceof Type.TypeStruct ts) { @@ -140,6 +146,16 @@ else if (fieldExpr.object.type instanceof Type.TypeNullable ptr && ptr.baseType instanceof Type.TypeStruct ts) { structType = ts; } + else if (fieldExpr.object.type instanceof Type.TypeArray typeArray) { + if (fieldExpr.fieldName.equals("len")) + checkAssignmentCompatible(typeDictionary.INT,fieldExpr.value.type); + else if (fieldExpr.fieldName.equals("value")) + checkAssignmentCompatible(typeArray.getElementType(),fieldExpr.value.type); + else + throw new CompilerException("Unexpected array initializer " + fieldExpr.fieldName); + fieldExpr.type = fieldExpr.value.type; + return this; + } else throw new CompilerException("Unexpected struct type " + fieldExpr.object.type); var fieldType = structType.getField(fieldExpr.fieldName); @@ -155,6 +171,8 @@ else if (fieldExpr.object.type instanceof Type.TypeNullable ptr && @Override public ASTVisitor visit(AST.CallExpr callExpr, boolean enter) { if (!enter) { + if (callExpr.type != null) + return this; validType(callExpr.callee.type, false); if (callExpr.callee.type instanceof Type.TypeFunction f) { callExpr.type = f.returnType; @@ -192,6 +210,8 @@ public ASTVisitor visit(AST.ReturnTypeExpr returnTypeExpr, boolean enter) { @Override public ASTVisitor visit(AST.LiteralExpr literalExpr, boolean enter) { if (enter) { + if (literalExpr.type != null) + return this; if (literalExpr.value.kind == Token.Kind.NUM) { literalExpr.type = typeDictionary.INT; } @@ -209,6 +229,8 @@ else if (literalExpr.value.kind == Token.Kind.IDENT @Override public ASTVisitor visit(AST.ArrayLoadExpr arrayIndexExpr, boolean enter) { if (!enter) { + if (arrayIndexExpr.type != null) + return this; validType(arrayIndexExpr.array.type, false); Type.TypeArray arrayType = null; if (arrayIndexExpr.array.type instanceof Type.TypeArray ta) { @@ -231,6 +253,8 @@ else if (arrayIndexExpr.array.type instanceof Type.TypeNullable ptr && @Override public ASTVisitor visit(AST.ArrayStoreExpr arrayIndexExpr, boolean enter) { if (!enter) { + if (arrayIndexExpr.type != null) + return this; validType(arrayIndexExpr.array.type, false); Type.TypeArray arrayType = null; if (arrayIndexExpr.array.type instanceof Type.TypeArray ta) { @@ -256,6 +280,8 @@ else if (arrayIndexExpr.array.type instanceof Type.TypeNullable ptr && public ASTVisitor visit(AST.NewExpr newExpr, boolean enter) { if (enter) return this; + if (newExpr.type != null) + return this; if (newExpr.typeExpr.type == null) throw new CompilerException("Unresolved type in new expression"); validType(newExpr.typeExpr.type, false); @@ -266,6 +292,14 @@ public ASTVisitor visit(AST.NewExpr newExpr, boolean enter) { } else if (newExpr.typeExpr.type instanceof Type.TypeArray arrayType) { newExpr.type = newExpr.typeExpr.type; + if (newExpr.len != null) { + if (!(newExpr.len.type instanceof Type.TypeInteger)) + throw new CompilerException("Array len must be integer type"); + if (newExpr.initValue != null) { + if (!arrayType.getElementType().isAssignable(newExpr.initValue.type)) + throw new CompilerException("Array init value must be assignable to array element type"); + } + } } else throw new CompilerException("Unsupported type in new expression"); @@ -278,6 +312,8 @@ public ASTVisitor visit(AST.InitExpr initExpr, boolean enter) { return this; if (initExpr.newExpr.type == null) throw new CompilerException("Unresolved type in new expression"); + if (initExpr.type != null) + return this; validType(initExpr.newExpr.type, false); if (initExpr.newExpr.type instanceof Type.TypeNullable) throw new CompilerException("new cannot be used to create a Nullable type"); @@ -290,6 +326,8 @@ public ASTVisitor visit(AST.InitExpr initExpr, boolean enter) { } } else if (initExpr.newExpr.type instanceof Type.TypeArray arrayType) { + if (initExpr.initExprList.size() > 0) + initExpr.initExprList.removeIf(e->e instanceof AST.InitFieldExpr); for (AST.Expr expr: initExpr.initExprList) { checkAssignmentCompatible(arrayType.getElementType(), expr.type); } @@ -304,6 +342,8 @@ else if (initExpr.newExpr.type instanceof Type.TypeArray arrayType) { public ASTVisitor visit(AST.NameExpr nameExpr, boolean enter) { if (!enter) return this; + if (nameExpr.type != null) + return this; var symbol = currentScope.lookup(nameExpr.name); if (symbol == null) { throw new CompilerException("Unknown symbol " + nameExpr.name); diff --git a/types/src/main/java/com/compilerprogramming/ezlang/types/Type.java b/types/src/main/java/com/compilerprogramming/ezlang/types/Type.java index 05cbca6..1c8514c 100644 --- a/types/src/main/java/com/compilerprogramming/ezlang/types/Type.java +++ b/types/src/main/java/com/compilerprogramming/ezlang/types/Type.java @@ -53,7 +53,7 @@ public String toString() { public String name() { return name; } public boolean isAssignable(Type other) { - if (other instanceof TypeVoid || other instanceof TypeUnknown) + if (other == null || other instanceof TypeVoid || other instanceof TypeUnknown) return false; if (this == other || equals(other)) return true; if (this instanceof TypeNullable nullable) {