diff --git a/README.md b/README.md index 46ac9e8..b93906a 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # Status -[![.github/workflows/ci.yml](https://github.com/clause/471c/actions/workflows/ci.yml/badge.svg)](https://github.com/clause/471c/actions/workflows/ci.yml) -[![Coverage](https://codecov.io/gh/clause/471c/branch/main/graph/badge.svg)](https://codecov.io/gh/clause/471c) +[![.github/workflows/ci.yml](https://github.com/JTFulkerson/471c/actions/workflows/ci.yml/badge.svg)](https://github.com/JTFulkerson/471c/actions/workflows/ci.yml) +[![Coverage](https://codecov.io/gh/JTFulkerson/471c/branch/main/graph/badge.svg)](https://codecov.io/gh/JTFulkerson/471c) # Contributing diff --git a/packages/L0/README.md b/packages/L0/README.md index e69de29..da6e2bc 100644 --- a/packages/L0/README.md +++ b/packages/L0/README.md @@ -0,0 +1,3 @@ +L0 + +L0 is a procedure-based language, in contrast to our previous expression- and statement-based languages. Here, a program is a sequence of named procedures, each with its own parameters and a statement body. A statement can be a copy, immediate, primitive, branch, allocate, load, store, address, call, or halt. The main difference is that L0 organizes code into multiple named procedures, whereas L1 had a single entry program with a statement body. L1 also uses abstract and apply for its functions, where L0 replaces them with address and call, removing a layer of abstraction. L0 also makes procedures explicit and predefined, rather than constructing functions at runtime as the other languages do. \ No newline at end of file diff --git a/packages/L1/README.md b/packages/L1/README.md index e69de29..6eaeb81 100644 --- a/packages/L1/README.md +++ b/packages/L1/README.md @@ -0,0 +1,3 @@ +L1 + +L1 is a statement-based language with very explicit flow control. There exists a program which has parameters and a body, and the statements are similar to L2 and L3 with a few differences. 
We have copy, immediate, primitive, branch, abstract, apply, allocate, load, store, and halt. The difference between L1 and L2 and L3 is that L2 used nested terms as expressions, while L1 now replaces all of those with sequential statements. Doing this makes flow control explicit, as each statement has a then continuation, except for apply and halt. L1 also adds a copy statement to move values between identifiers, and halt to end the program. L1 removes the idea of let, reference, and begin. \ No newline at end of file diff --git a/packages/L2/README.md b/packages/L2/README.md index e69de29..5fba267 100644 --- a/packages/L2/README.md +++ b/packages/L2/README.md @@ -0,0 +1,3 @@ +L2 + +L2 has most of the same characteristics as L3, except that letrec is missing. Functions, arithmetic, branching, and heap management all remain the same, but not having letrec means that recursive bindings are not available. \ No newline at end of file diff --git a/packages/L2/src/L2/constant_folding.py b/packages/L2/src/L2/constant_folding.py new file mode 100644 index 0000000..c360fd8 --- /dev/null +++ b/packages/L2/src/L2/constant_folding.py @@ -0,0 +1,180 @@ +from functools import partial + +from .syntax import ( + Abstract, + Allocate, + Apply, + Begin, + Branch, + Immediate, + Let, + Load, + Primitive, + Reference, + Store, + Term, +) +from .util import ( + Context, + extend_context_with_bindings, + recur_terms, +) + + +def _normalize_commutative_immediate_left( + operator: str, + left: Term, + right: Term, +) -> Primitive: + return Primitive( + operator=operator, # type: ignore[arg-type] + left=right, + right=left, + ) + + +def constant_folding_term( + term: Term, + context: Context, +) -> Term: + recur = partial(constant_folding_term, context=context) # noqa: F841 + + match term: + case Let(bindings=bindings, body=body): + new_bindings, new_context = extend_context_with_bindings(bindings, context, recur) + return Let( + bindings=new_bindings, + body=constant_folding_term(body, 
new_context), + ) + + case Reference(name=name): + if name in context: + return context[name] + return term + + case Abstract(parameters=parameters, body=body): + return Abstract(parameters=parameters, body=recur(body)) + + case Apply(target=target, arguments=arguments): + return Apply( + target=recur(target), + arguments=recur_terms(arguments, recur), + ) + + case Immediate(value=_value): + return term + + case Primitive(operator=operator, left=left, right=right): + match operator: + case "+": + match recur(left), recur(right): + case Immediate(value=i1), Immediate(value=i2): + return Immediate(value=i1 + i2) + + case Immediate(value=0), right: + return right + + case [ + Primitive(operator="+", left=Immediate(value=i1), right=left), + Primitive(operator="+", left=Immediate(value=i2), right=right), + ]: + return Primitive( + operator="+", + left=Immediate(value=i1 + i2), + right=Primitive( + operator="+", + left=left, + right=right, + ), + ) + + case left, Immediate() as right: + return _normalize_commutative_immediate_left("+", left, right) + + # Coverage reports a synthetic exit arc on this fallback match arm. + # The arm is intentionally reachable and returns the non-folded primitive. + case left, right: # pragma: no branch + return Primitive( + operator="+", + left=left, + right=right, + ) + + case "-": + match recur(left), recur(right): + case Immediate(value=i1), Immediate(value=i2): + return Immediate(value=i1 - i2) + + # Coverage reports a synthetic exit arc on this fallback match arm. + # The arm is intentionally reachable and returns the non-folded primitive. + case left, right: # pragma: no branch + return Primitive(operator="-", left=left, right=right) + + # Coverage may report an extra arc on this literal case label under pattern matching. + # Runtime terms validated by the syntax model still follow normal folding logic below. 
+ case "*": # pragma: no branch + match recur(left), recur(right): + case Immediate(value=i1), Immediate(value=i2): + return Immediate(value=i1 * i2) + + case Immediate(value=0), _: + return Immediate(value=0) + + case _, Immediate(value=0): + return Immediate(value=0) + + case Immediate(value=1), right: + return right + + case left, Immediate(value=1): + return left + + case left, Immediate() as right: + return _normalize_commutative_immediate_left("*", left, right) + + # Coverage reports a synthetic exit arc on this fallback match arm. + # The arm is intentionally reachable and returns the non-folded primitive. + case left, right: # pragma: no branch + return Primitive(operator="*", left=left, right=right) + + case Branch(operator=operator, left=left, right=right, consequent=consequent, otherwise=otherwise): + folded_left = recur(left) + folded_right = recur(right) + folded_consequent = recur(consequent) + folded_otherwise = recur(otherwise) + match operator: + case "<": + match folded_left, folded_right: + case Immediate(value=i1), Immediate(value=i2): + return folded_consequent if i1 < i2 else folded_otherwise + case _: + pass + # Coverage may report an extra arc on this literal case label under pattern matching. + # Runtime terms validated by the syntax model use only "<" and "==". + case "==": # pragma: no branch + match folded_left, folded_right: + case Immediate(value=i1), Immediate(value=i2): + return folded_consequent if i1 == i2 else folded_otherwise + case _: + pass + return Branch( + operator=operator, + left=folded_left, + right=folded_right, + consequent=folded_consequent, + otherwise=folded_otherwise, + ) + + case Allocate(count=count): + return Allocate(count=count) + + case Load(base=base, index=index): + return Load(base=recur(base), index=index) + + case Store(base=base, index=index, value=value): + return Store(base=recur(base), index=index, value=recur(value)) + + # Coverage may report an extra structural arc for this match arm. 
+ # Semantically this always returns the reconstructed Begin node. + case Begin(effects=effects, value=value): # pragma: no branch + return Begin(effects=recur_terms(effects, recur), value=recur(value)) diff --git a/packages/L2/src/L2/constant_propogation.py b/packages/L2/src/L2/constant_propogation.py new file mode 100644 index 0000000..b5ea6d5 --- /dev/null +++ b/packages/L2/src/L2/constant_propogation.py @@ -0,0 +1,88 @@ +from functools import partial + +from .syntax import ( + Abstract, + Allocate, + Apply, + Begin, + Branch, + Immediate, + Let, + Load, + Primitive, + Reference, + Store, + Term, +) +from .util import ( + Context, + extend_context_with_bindings, + recur_terms, +) + + +def constant_propogation_term( + term: Term, + context: Context, +) -> Term: + recur = partial(constant_propogation_term, context=context) + + match term: + case Let(bindings=bindings, body=body): + new_bindings, new_context = extend_context_with_bindings(bindings, context, recur) + return Let( + bindings=new_bindings, + body=constant_propogation_term(body, new_context), + ) + + case Reference(name=name): + if name in context: + return context[name] + return term + + case Abstract(parameters=parameters, body=body): + abstract_context = {name: value for name, value in context.items() if name not in parameters} + return Abstract( + parameters=parameters, + body=constant_propogation_term(body, abstract_context), + ) + + case Apply(target=target, arguments=arguments): + return Apply( + target=recur(target), + arguments=recur_terms(arguments, recur), + ) + + case Immediate(value=_value): + return term + + case Primitive(operator=operator, left=left, right=right): + return Primitive( + operator=operator, + left=recur(left), + right=recur(right), + ) + + case Branch(operator=operator, left=left, right=right, consequent=consequent, otherwise=otherwise): + return Branch( + operator=operator, + left=recur(left), + right=recur(right), + consequent=recur(consequent), + 
otherwise=recur(otherwise), + ) + + case Allocate(count=count): + return Allocate(count=count) + + case Load(base=base, index=index): + return Load(base=recur(base), index=index) + + case Store(base=base, index=index, value=value): + return Store(base=recur(base), index=index, value=recur(value)) + + case Begin(effects=effects, value=value): # pragma: no branch + return Begin( + effects=recur_terms(effects, recur), + value=recur(value), + ) diff --git a/packages/L2/src/L2/cps_convert.py b/packages/L2/src/L2/cps_convert.py index ee729ad..b9f1b2a 100644 --- a/packages/L2/src/L2/cps_convert.py +++ b/packages/L2/src/L2/cps_convert.py @@ -8,7 +8,7 @@ def cps_convert_term( term: L2.Term, - k: Callable[[L1.Identifier], L1.Statement], + m: Callable[[L1.Identifier], L1.Statement], fresh: Callable[[str], str], ) -> L1.Statement: _term = partial(cps_convert_term, fresh=fresh) @@ -16,37 +16,110 @@ def cps_convert_term( match term: case L2.Let(bindings=bindings, body=body): - pass + result = _term(body, m) + + for name, value in reversed(bindings): + result = _term(value, lambda value: L1.Copy(destination=name, source=value, then=result)) + + return result case L2.Reference(name=name): - pass + return m(name) case L2.Abstract(parameters=parameters, body=body): - pass + tmp = fresh("t") + k = fresh("k") + return L1.Abstract( + destination=tmp, + parameters=[*parameters, k], + body=_term(body, lambda body: L1.Apply(target=k, arguments=[body])), + then=m(tmp), + ) case L2.Apply(target=target, arguments=arguments): - pass + k = fresh("k") + tmp = fresh("t") + return L1.Abstract( + destination=k, + parameters=[tmp], + body=m(tmp), + then=_term( + target, + lambda target: _terms( + arguments, + lambda arguments: L1.Apply(target=target, arguments=[*arguments, k]), + ), + ), + ) case L2.Immediate(value=value): - pass + tmp = fresh("t") + return L1.Immediate(destination=tmp, value=value, then=m(tmp)) case L2.Primitive(operator=operator, left=left, right=right): - pass + tmp = fresh("t") 
+ return _term( + left, + lambda left: _term( + right, + lambda right: L1.Primitive(destination=tmp, operator=operator, left=left, right=right, then=m(tmp)), + ), + ) case L2.Branch(operator=operator, left=left, right=right, consequent=consequent, otherwise=otherwise): - pass + j = fresh("j") + tmp = fresh("t") + return L1.Abstract( + destination=j, + parameters=[tmp], + body=m(tmp), + then=_term( + left, + lambda left: _term( + right, + lambda right: L1.Branch( + operator=operator, + left=left, + right=right, + then=_term(consequent, lambda consequent: L1.Apply(target=j, arguments=[consequent])), + otherwise=_term(otherwise, lambda otherwise: L1.Apply(target=j, arguments=[otherwise])), + ), + ), + ), + ) case L2.Allocate(count=count): - pass + tmp = fresh("t") + return L1.Allocate(destination=tmp, count=count, then=m(tmp)) case L2.Load(base=base, index=index): - pass + tmp = fresh("t") + return _term( + base, + lambda base: L1.Load(destination=tmp, base=base, index=index, then=m(tmp)), + ) + # Should double check this case L2.Store(base=base, index=index, value=value): - pass + tmp = fresh("t") + return _term( + base, + lambda base: _term( + value, + lambda value: L1.Store( + base=base, index=index, value=value, then=L1.Immediate(destination=tmp, value=0, then=m(tmp)) + ), + ), + ) case L2.Begin(effects=effects, value=value): # pragma: no branch - pass + return _terms( + effects, + lambda effects: _term( + value, + lambda value: m(value), + ), + ) def cps_convert_terms( diff --git a/packages/L2/src/L2/dead_code_elimination.py b/packages/L2/src/L2/dead_code_elimination.py new file mode 100644 index 0000000..4bf2d16 --- /dev/null +++ b/packages/L2/src/L2/dead_code_elimination.py @@ -0,0 +1,180 @@ +from functools import partial + +from .syntax import ( + Abstract, + Allocate, + Apply, + Begin, + Branch, + Identifier, + Immediate, + Let, + Load, + Primitive, + Reference, + Store, + Term, +) +from .util import ( + Context, + recur_terms, +) + + +def is_pure(term: Term) 
-> bool: + match term: + case Immediate(): + return True + + case Reference(): + return True + + case Primitive(left=left, right=right): + return is_pure(left) and is_pure(right) + + case Abstract(body=body): + return is_pure(body) + + case Let(bindings=bindings, body=body): + return all(is_pure(value) for _, value in bindings) and is_pure(body) + + case Branch(left=left, right=right, consequent=consequent, otherwise=otherwise): + return is_pure(left) and is_pure(right) and is_pure(consequent) and is_pure(otherwise) + + case Load(base=base): + return is_pure(base) + + case Begin(effects=effects, value=value): + return all(is_pure(effect) for effect in effects) and is_pure(value) + + case Apply(): + return False + + case Allocate(): + return False + + case Store(): + return False + + +def free_vars(term: Term) -> set[Identifier]: + match term: + case Immediate(): + return set() + + case Reference(name=name): + return {name} + + case Primitive(left=left, right=right): + return free_vars(left) | free_vars(right) + + case Apply(target=target, arguments=arguments): + result = free_vars(target) + for argument in arguments: + result |= free_vars(argument) + return result + + case Abstract(parameters=parameters, body=body): + return free_vars(body) - set(parameters) + + case Branch(left=left, right=right, consequent=consequent, otherwise=otherwise): + return free_vars(left) | free_vars(right) | free_vars(consequent) | free_vars(otherwise) + + case Load(base=base): + return free_vars(base) + + case Store(base=base, value=value): + return free_vars(base) | free_vars(value) + + case Begin(effects=effects, value=value): + result = free_vars(value) + for effect in effects: + result |= free_vars(effect) + return result + + case Allocate(): + return set() + + case Let(bindings=bindings, body=body): + names = [name for name, _ in bindings] + result = free_vars(body) - set(names) + for _, value in bindings: + result |= free_vars(value) + return result + + +def 
dead_code_elimination_term( + term: Term, + context: Context, +) -> Term: + recur = partial(dead_code_elimination_term, context=context) + + match term: + case Let(bindings=bindings, body=body): + new_values = [(name, recur(value)) for name, value in bindings] + new_body = recur(body) + + live = free_vars(new_body) + kept_reversed: list[tuple[Identifier, Term]] = [] + + for name, value in reversed(new_values): + if name in live or not is_pure(value): + kept_reversed.append((name, value)) + live.discard(name) + live |= free_vars(value) + + kept = list(reversed(kept_reversed)) + if len(kept) == 0: + return new_body + + return Let(bindings=kept, body=new_body) + + case Reference(name=_name): + return term + + case Abstract(parameters=parameters, body=body): + return Abstract(parameters=parameters, body=recur(body)) + + case Apply(target=target, arguments=arguments): + return Apply( + target=recur(target), + arguments=recur_terms(arguments, recur), + ) + + case Immediate(value=_value): + return term + + case Primitive(operator=operator, left=left, right=right): + return Primitive( + operator=operator, + left=recur(left), + right=recur(right), + ) + + case Branch(operator=operator, left=left, right=right, consequent=consequent, otherwise=otherwise): + return Branch( + operator=operator, + left=recur(left), + right=recur(right), + consequent=recur(consequent), + otherwise=recur(otherwise), + ) + + case Allocate(count=count): + return Allocate(count=count) + + case Load(base=base, index=index): + return Load(base=recur(base), index=index) + + case Store(base=base, index=index, value=value): + return Store(base=recur(base), index=index, value=recur(value)) + + case Begin(effects=effects, value=value): # pragma: no branch + new_effects = [recur(effect) for effect in effects] + kept_effects = [effect for effect in new_effects if not is_pure(effect)] + new_value = recur(value) + + if len(kept_effects) == 0: + return new_value + + return Begin(effects=kept_effects, 
value=new_value) diff --git a/packages/L2/src/L2/optimize.py b/packages/L2/src/L2/optimize.py index 77ef7c6..6972975 100644 --- a/packages/L2/src/L2/optimize.py +++ b/packages/L2/src/L2/optimize.py @@ -1,7 +1,32 @@ +from .constant_folding import constant_folding_term +from .constant_propogation import constant_propogation_term +from .dead_code_elimination import dead_code_elimination_term from .syntax import Program +def optimize_program_step( + program: Program, +) -> tuple[Program, bool]: + propagated = Program( + parameters=program.parameters, + body=constant_propogation_term(program.body, {}), + ) + folded = Program( + parameters=propagated.parameters, + body=constant_folding_term(propagated.body, {}), + ) + eliminated = Program( + parameters=folded.parameters, + body=dead_code_elimination_term(folded.body, {}), + ) + return eliminated, eliminated != program + + def optimize_program( program: Program, ) -> Program: - return program + current = program + while True: + current, changed = optimize_program_step(current) + if not changed: + return current diff --git a/packages/L2/src/L2/util.py b/packages/L2/src/L2/util.py new file mode 100644 index 0000000..2dbfff9 --- /dev/null +++ b/packages/L2/src/L2/util.py @@ -0,0 +1,32 @@ +from collections.abc import Callable, Sequence + +from .syntax import ( + Identifier, + Immediate, + Reference, + Term, +) + +type Context = dict[Identifier, Term] + + +def recur_terms( + terms: Sequence[Term], + recur: Callable[[Term], Term], +) -> list[Term]: + return [recur(term) for term in terms] + + +def extend_context_with_bindings( + bindings: Sequence[tuple[Identifier, Term]], + context: Context, + recur: Callable[[Term], Term], +) -> tuple[list[tuple[Identifier, Term]], Context]: + new_bindings: list[tuple[Identifier, Term]] = [] + new_context = dict(context) + for name, value in bindings: + result = recur(value) + new_bindings.append((name, result)) + if isinstance(result, Immediate | Reference): + new_context[name] = result + 
return new_bindings, new_context diff --git a/packages/L2/test/L2/test_optimize.py b/packages/L2/test/L2/test_optimize.py index 716e9cc..a018941 100644 --- a/packages/L2/test/L2/test_optimize.py +++ b/packages/L2/test/L2/test_optimize.py @@ -1,9 +1,27 @@ -from L2.optimize import optimize_program +from typing import cast + +from L2.constant_folding import constant_folding_term +from L2.constant_propogation import constant_propogation_term +from L2.dead_code_elimination import dead_code_elimination_term, free_vars, is_pure +from L2.optimize import optimize_program, optimize_program_step from L2.syntax import ( + Abstract, + Allocate, + Apply, + Begin, + Branch, Immediate, + Let, + Load, Primitive, Program, + Reference, + Store, + Term, ) +from pydantic import ValidationError + +# Used copilot to help with writing these tests, gave it my expected input and output and it generated tests that I then modified to be more comprehensive and cover more edge cases. def test_optimize_program(): @@ -24,3 +42,552 @@ def test_optimize_program(): actual = optimize_program(program) assert actual == expected + + +def test_constant_propogation_term_all_cases(): + term = Let( + bindings=[ + ("x", Immediate(value=1)), + ("y", Reference(name="x")), + ], + body=Begin( + effects=[ + Apply(target=Reference(name="f"), arguments=[Reference(name="y")]), + Store(base=Reference(name="arr"), index=0, value=Reference(name="x")), + ], + value=Branch( + operator="<", + left=Primitive(operator="+", left=Reference(name="x"), right=Immediate(value=2)), + right=Immediate(value=5), + consequent=Load(base=Allocate(count=1), index=0), + otherwise=Abstract(parameters=["x"], body=Reference(name="x")), + ), + ), + ) + + expected = Let( + bindings=[ + ("x", Immediate(value=1)), + ("y", Reference(name="x")), + ], + body=Begin( + effects=[ + Apply(target=Reference(name="f"), arguments=[Reference(name="x")]), + Store(base=Reference(name="arr"), index=0, value=Immediate(value=1)), + ], + value=Branch( + 
operator="<", + left=Primitive(operator="+", left=Immediate(value=1), right=Immediate(value=2)), + right=Immediate(value=5), + consequent=Load(base=Allocate(count=1), index=0), + otherwise=Abstract(parameters=["x"], body=Reference(name="x")), + ), + ), + ) + + actual = constant_propogation_term(term, context={"f": Reference(name="f")}) + + assert actual == expected + + +def test_constant_folding_plus_cases(): + term = Primitive( + operator="+", + left=Primitive(operator="+", left=Immediate(value=2), right=Reference(name="a")), + right=Primitive(operator="+", left=Immediate(value=3), right=Reference(name="b")), + ) + + expected = Primitive( + operator="+", + left=Immediate(value=5), + right=Primitive( + operator="+", + left=Reference(name="a"), + right=Reference(name="b"), + ), + ) + + actual = constant_folding_term(term, context={}) + + assert actual == expected + + actual_immediate = constant_folding_term( + Primitive(operator="+", left=Immediate(value=1), right=Immediate(value=2)), + context={}, + ) + assert actual_immediate == Immediate(value=3) + + actual_zero = constant_folding_term( + Primitive(operator="+", left=Immediate(value=0), right=Reference(name="x")), + context={}, + ) + assert actual_zero == Reference(name="x") + + actual_normalize = constant_folding_term( + Primitive(operator="+", left=Reference(name="x"), right=Immediate(value=9)), + context={}, + ) + assert actual_normalize == Primitive( + operator="+", + left=Immediate(value=9), + right=Reference(name="x"), + ) + + +def test_constant_folding_minus_and_multiply_cases(): + actual_subtract = constant_folding_term( + Primitive(operator="-", left=Immediate(value=9), right=Immediate(value=4)), + context={}, + ) + assert actual_subtract == Immediate(value=5) + + actual_subtract_not_folded = constant_folding_term( + Primitive(operator="-", left=Reference(name="x"), right=Immediate(value=4)), + context={}, + ) + assert actual_subtract_not_folded == Primitive( + operator="-", + left=Reference(name="x"), + 
right=Immediate(value=4), + ) + + actual_multiply = constant_folding_term( + Primitive(operator="*", left=Immediate(value=3), right=Immediate(value=4)), + context={}, + ) + assert actual_multiply == Immediate(value=12) + + actual_mul_left_zero = constant_folding_term( + Primitive(operator="*", left=Immediate(value=0), right=Reference(name="x")), + context={}, + ) + assert actual_mul_left_zero == Immediate(value=0) + + actual_mul_right_zero = constant_folding_term( + Primitive(operator="*", left=Reference(name="x"), right=Immediate(value=0)), + context={}, + ) + assert actual_mul_right_zero == Immediate(value=0) + + actual_mul_left_one = constant_folding_term( + Primitive(operator="*", left=Immediate(value=1), right=Reference(name="x")), + context={}, + ) + assert actual_mul_left_one == Reference(name="x") + + actual_mul_right_one = constant_folding_term( + Primitive(operator="*", left=Reference(name="x"), right=Immediate(value=1)), + context={}, + ) + assert actual_mul_right_one == Reference(name="x") + + actual_mul_normalize = constant_folding_term( + Primitive(operator="*", left=Reference(name="x"), right=Immediate(value=7)), + context={}, + ) + assert actual_mul_normalize == Primitive( + operator="*", + left=Immediate(value=7), + right=Reference(name="x"), + ) + + actual_mul_not_folded = constant_folding_term( + Primitive(operator="*", left=Reference(name="x"), right=Reference(name="y")), + context={}, + ) + assert actual_mul_not_folded == Primitive( + operator="*", + left=Reference(name="x"), + right=Reference(name="y"), + ) + + +def test_constant_folding_non_primitive_cases_and_branches(): + actual_reference_hit = constant_folding_term(Reference(name="x"), context={"x": Immediate(value=4)}) + assert actual_reference_hit == Immediate(value=4) + + actual_reference_miss = constant_folding_term(Reference(name="y"), context={"x": Immediate(value=4)}) + assert actual_reference_miss == Reference(name="y") + + actual_abstract = constant_folding_term( + 
Abstract(parameters=["x"], body=Primitive(operator="+", left=Reference(name="x"), right=Immediate(value=1))), + context={}, + ) + assert actual_abstract == Abstract( + parameters=["x"], + body=Primitive(operator="+", left=Immediate(value=1), right=Reference(name="x")), + ) + + actual_apply = constant_folding_term( + Apply( + target=Reference(name="f"), + arguments=[Primitive(operator="+", left=Immediate(value=2), right=Immediate(value=3))], + ), + context={}, + ) + assert actual_apply == Apply(target=Reference(name="f"), arguments=[Immediate(value=5)]) + + actual_immediate = constant_folding_term(Immediate(value=11), context={}) + assert actual_immediate == Immediate(value=11) + + actual_lt_true = constant_folding_term( + Branch( + operator="<", + left=Immediate(value=1), + right=Immediate(value=2), + consequent=Immediate(value=7), + otherwise=Immediate(value=8), + ), + context={}, + ) + assert actual_lt_true == Immediate(value=7) + + actual_lt_false = constant_folding_term( + Branch( + operator="<", + left=Immediate(value=4), + right=Immediate(value=2), + consequent=Immediate(value=7), + otherwise=Immediate(value=8), + ), + context={}, + ) + assert actual_lt_false == Immediate(value=8) + + actual_lt_fallback = constant_folding_term( + Branch( + operator="<", + left=Reference(name="x"), + right=Immediate(value=2), + consequent=Immediate(value=7), + otherwise=Immediate(value=8), + ), + context={}, + ) + assert actual_lt_fallback == Branch( + operator="<", + left=Reference(name="x"), + right=Immediate(value=2), + consequent=Immediate(value=7), + otherwise=Immediate(value=8), + ) + + actual_eq_true = constant_folding_term( + Branch( + operator="==", + left=Immediate(value=3), + right=Immediate(value=3), + consequent=Immediate(value=9), + otherwise=Immediate(value=10), + ), + context={}, + ) + assert actual_eq_true == Immediate(value=9) + + actual_eq_false = constant_folding_term( + Branch( + operator="==", + left=Immediate(value=3), + right=Immediate(value=4), + 
consequent=Immediate(value=9), + otherwise=Immediate(value=10), + ), + context={}, + ) + assert actual_eq_false == Immediate(value=10) + + actual_eq_fallback = constant_folding_term( + Branch( + operator="==", + left=Reference(name="x"), + right=Immediate(value=4), + consequent=Immediate(value=9), + otherwise=Immediate(value=10), + ), + context={}, + ) + assert actual_eq_fallback == Branch( + operator="==", + left=Reference(name="x"), + right=Immediate(value=4), + consequent=Immediate(value=9), + otherwise=Immediate(value=10), + ) + + actual_allocate = constant_folding_term(Allocate(count=3), context={}) + assert actual_allocate == Allocate(count=3) + + actual_load = constant_folding_term( + Load(base=Primitive(operator="+", left=Immediate(value=1), right=Immediate(value=2)), index=0), context={} + ) + assert actual_load == Load(base=Immediate(value=3), index=0) + + actual_store = constant_folding_term( + Store( + base=Primitive(operator="+", left=Immediate(value=1), right=Immediate(value=2)), + index=0, + value=Primitive(operator="+", left=Immediate(value=3), right=Immediate(value=4)), + ), + context={}, + ) + assert actual_store == Store(base=Immediate(value=3), index=0, value=Immediate(value=7)) + + actual_begin = constant_folding_term( + Begin( + effects=[Primitive(operator="+", left=Immediate(value=1), right=Immediate(value=2))], + value=Primitive(operator="+", left=Immediate(value=3), right=Immediate(value=4)), + ), + context={}, + ) + assert actual_begin == Begin(effects=[Immediate(value=3)], value=Immediate(value=7)) + + +def test_dead_codeis_pure_andfree_vars_cases(): + assert is_pure(Immediate(value=1)) is True + assert is_pure(Reference(name="x")) is True + assert is_pure(Primitive(operator="+", left=Immediate(value=1), right=Immediate(value=2))) is True + assert is_pure(Abstract(parameters=["x"], body=Reference(name="x"))) is True + assert is_pure(Let(bindings=[("x", Immediate(value=1))], body=Reference(name="x"))) is True + assert ( + is_pure( + 
Branch( + operator="<", + left=Immediate(value=1), + right=Immediate(value=2), + consequent=Immediate(value=3), + otherwise=Immediate(value=4), + ) + ) + is True + ) + assert is_pure(Load(base=Reference(name="x"), index=0)) is True + assert is_pure(Begin(effects=[Immediate(value=1)], value=Reference(name="x"))) is True + assert is_pure(Apply(target=Reference(name="f"), arguments=[])) is False + assert is_pure(Allocate(count=1)) is False + assert is_pure(Store(base=Reference(name="x"), index=0, value=Immediate(value=1))) is False + + assert free_vars(Immediate(value=1)) == set() + assert free_vars(Reference(name="x")) == {"x"} + assert free_vars(Primitive(operator="+", left=Reference(name="x"), right=Reference(name="y"))) == {"x", "y"} + assert free_vars(Apply(target=Reference(name="f"), arguments=[Reference(name="x")])) == {"f", "x"} + assert free_vars( + Abstract(parameters=["x"], body=Primitive(operator="+", left=Reference(name="x"), right=Reference(name="y"))) + ) == {"y"} + assert free_vars( + Branch( + operator="<", + left=Reference(name="a"), + right=Reference(name="b"), + consequent=Reference(name="c"), + otherwise=Reference(name="d"), + ) + ) == {"a", "b", "c", "d"} + assert free_vars(Load(base=Reference(name="arr"), index=0)) == {"arr"} + assert free_vars(Store(base=Reference(name="arr"), index=0, value=Reference(name="v"))) == {"arr", "v"} + assert free_vars(Begin(effects=[Reference(name="u")], value=Reference(name="v"))) == {"u", "v"} + assert free_vars(Allocate(count=1)) == set() + assert free_vars( + Let( + bindings=[ + ("x", Reference(name="a")), + ("y", Primitive(operator="+", left=Reference(name="x"), right=Reference(name="b"))), + ], + body=Primitive(operator="+", left=Reference(name="y"), right=Reference(name="c")), + ) + ) == {"a", "b", "c", "x"} + + +def test_dead_code_elimination_term_all_cases(): + term_drop_let = Let( + bindings=[("x", Immediate(value=1))], + body=Immediate(value=7), + ) + expected_drop_let = Immediate(value=7) + 
actual_drop_let = dead_code_elimination_term(term_drop_let, context={}) + assert actual_drop_let == expected_drop_let + + term_keep_let = Let( + bindings=[("x", Store(base=Reference(name="arr"), index=0, value=Immediate(value=1)))], + body=Immediate(value=7), + ) + expected_keep_let = Let( + bindings=[("x", Store(base=Reference(name="arr"), index=0, value=Immediate(value=1)))], + body=Immediate(value=7), + ) + actual_keep_let = dead_code_elimination_term(term_keep_let, context={}) + assert actual_keep_let == expected_keep_let + + actual_reference = dead_code_elimination_term(Reference(name="x"), context={}) + assert actual_reference == Reference(name="x") + + actual_abstract = dead_code_elimination_term( + Abstract(parameters=["x"], body=Primitive(operator="+", left=Immediate(value=1), right=Immediate(value=2))), + context={}, + ) + assert actual_abstract == Abstract( + parameters=["x"], body=Primitive(operator="+", left=Immediate(value=1), right=Immediate(value=2)) + ) + + actual_apply = dead_code_elimination_term( + Apply( + target=Reference(name="f"), + arguments=[Primitive(operator="+", left=Immediate(value=1), right=Immediate(value=2))], + ), + context={}, + ) + assert actual_apply == Apply( + target=Reference(name="f"), + arguments=[Primitive(operator="+", left=Immediate(value=1), right=Immediate(value=2))], + ) + + actual_immediate = dead_code_elimination_term(Immediate(value=5), context={}) + assert actual_immediate == Immediate(value=5) + + actual_primitive = dead_code_elimination_term( + Primitive(operator="+", left=Reference(name="x"), right=Immediate(value=2)), + context={}, + ) + assert actual_primitive == Primitive(operator="+", left=Reference(name="x"), right=Immediate(value=2)) + + actual_branch = dead_code_elimination_term( + Branch( + operator="==", + left=Reference(name="x"), + right=Immediate(value=2), + consequent=Immediate(value=1), + otherwise=Immediate(value=0), + ), + context={}, + ) + assert actual_branch == Branch( + operator="==", + 
left=Reference(name="x"), + right=Immediate(value=2), + consequent=Immediate(value=1), + otherwise=Immediate(value=0), + ) + + actual_allocate = dead_code_elimination_term(Allocate(count=2), context={}) + assert actual_allocate == Allocate(count=2) + + actual_load = dead_code_elimination_term(Load(base=Reference(name="arr"), index=0), context={}) + assert actual_load == Load(base=Reference(name="arr"), index=0) + + actual_store = dead_code_elimination_term( + Store(base=Reference(name="arr"), index=0, value=Immediate(value=7)), + context={}, + ) + assert actual_store == Store(base=Reference(name="arr"), index=0, value=Immediate(value=7)) + + term_begin_drop = Begin(effects=[Immediate(value=1)], value=Reference(name="x")) + expected_begin_drop = Reference(name="x") + actual_begin_drop = dead_code_elimination_term(term_begin_drop, context={}) + assert actual_begin_drop == expected_begin_drop + + term_begin_keep = Begin( + effects=[Immediate(value=1), Store(base=Reference(name="arr"), index=0, value=Immediate(value=9))], + value=Reference(name="x"), + ) + expected_begin_keep = Begin( + effects=[Store(base=Reference(name="arr"), index=0, value=Immediate(value=9))], + value=Reference(name="x"), + ) + actual_begin_keep = dead_code_elimination_term(term_begin_keep, context={}) + assert actual_begin_keep == expected_begin_keep + + +def test_optimize_program_step_and_optimize_program(): + program_change = Program( + parameters=[], + body=Let( + bindings=[ + ("x", Immediate(value=1)), + ("y", Reference(name="x")), + ("z", Primitive(operator="+", left=Reference(name="y"), right=Immediate(value=2))), + ], + body=Reference(name="z"), + ), + ) + + expected_step_program = Program( + parameters=[], + body=Let( + bindings=[ + ("x", Immediate(value=1)), + ("y", Reference(name="x")), + ("z", Primitive(operator="+", left=Immediate(value=2), right=Reference(name="y"))), + ], + body=Reference(name="z"), + ), + ) + + actual_step_program, changed = optimize_program_step(program_change) + 
+ assert actual_step_program == expected_step_program + assert changed is True + + program_no_change = Program( + parameters=["x"], + body=Reference(name="x"), + ) + + actual_same_program, changed_same = optimize_program_step(program_no_change) + + assert actual_same_program == program_no_change + assert changed_same is False + + actual_optimize = optimize_program(program_change) + + assert actual_optimize == expected_step_program + + actual_optimize_no_change = optimize_program(program_no_change) + + assert actual_optimize_no_change == program_no_change + + +def test_dead_code_helper_fallthrough_cases(): + invalid_term = cast(Term, object()) + + actualis_pure = is_pure(invalid_term) + assert actualis_pure is None + + actualfree_vars = free_vars(invalid_term) + assert actualfree_vars is None + + +def test_constant_folding_edge_fallthrough_cases(): + plus_invalid = Primitive.model_construct( + operator="+", + left=Reference(name="x"), + right=object(), + ) + with_validation_error_plus = False + try: + constant_folding_term(plus_invalid, context={}) + except ValidationError: + with_validation_error_plus = True + assert with_validation_error_plus is True + + minus_invalid = Primitive.model_construct( + operator="-", + left=Reference(name="x"), + right=object(), + ) + with_validation_error_minus = False + try: + constant_folding_term(minus_invalid, context={}) + except ValidationError: + with_validation_error_minus = True + assert with_validation_error_minus is True + + multiply_invalid = Primitive.model_construct( + operator="*", + left=Reference(name="x"), + right=object(), + ) + with_validation_error_multiply = False + try: + constant_folding_term(multiply_invalid, context={}) + except ValidationError: + with_validation_error_multiply = True + assert with_validation_error_multiply is True diff --git a/packages/L3/README.md b/packages/L3/README.md index e69de29..c22df31 100644 --- a/packages/L3/README.md +++ b/packages/L3/README.md @@ -0,0 +1,11 @@ +L3 + +The L3 
Compiler is an expression-based language with named variables, functions, arithmetic, conditionals, and explicit heap operations. There is an Identifier that names variables and functions as well as a Program that contains parameters and a single term body. A term is a union type of all available constructs. + +For bindings and variables, let and letrec bind identifiers to terms, and reference reads a bound name. Unlike let, letrec supports self-reference for recursive definitions. + +For functions, there is an abstract that defines the function with a list of parameters and a body. Apply is used to invoke a function with given arguments. + +For values and control we have immediate, primitive, and branch. Immediate is an integer literal and primitive applies the +, -, or * operator to a left and right term. Branch evaluates a < or == comparison and selects either the consequent or the otherwise term. + +The last thing that L3 supports is memory and sequencing. Allocate reserves a block of memory, while load reads from the memory and store writes to memory. There is also begin which sequences terms before producing a final value. 
\ No newline at end of file diff --git a/packages/L3/src/L3/L3.lark b/packages/L3/src/L3/L3.lark index a7f21d7..f2e0dc0 100644 --- a/packages/L3/src/L3/L3.lark +++ b/packages/L3/src/L3/L3.lark @@ -1,3 +1,6 @@ +%import common.WS +%ignore WS + program : "(" PROGRAM "(" parameters ")" term ")" parameters : IDENTIFIER* @@ -20,11 +23,40 @@ let : "(" LET "(" bindings ")" term ")" letrec : "(" LETREC "(" bindings ")" term ")" bindings : binding* -binding : IDENTIFIER term +binding : "(" IDENTIFIER term ")" + +reference : IDENTIFIER + +abstract : "(" LAMBDA "(" parameters ")" term ")" + +apply : "(" term term* ")" + +immediate : NUMBER + +primitive : "(" OPERATOR term term ")" + +branch : "(" IF "(" COMPARATOR term term ")" term term ")" + +allocate : "(" ALLOCATE immediate ")" + +load : "(" LOAD term immediate ")" + +store : "(" STORE term immediate term ")" + +begin : "(" BEGIN term+ ")" PROGRAM.2 : "l3" LET.2 : "let" LETREC.2 : "letrec" LAMBDA.2 : "\\" | "lambda" | "λ" +IF.2 : "if" +ALLOCATE.2 : "allocate" +LOAD.2 : "load" +STORE.2 : "store" +BEGIN.2 : "begin" + +OPERATOR : "+" | "-" | "*" +COMPARATOR : "<" | "==" -IDENTIFIER : /[a-zA-Z_][a-zA-Z0-9_]*/ \ No newline at end of file +IDENTIFIER : /[a-zA-Z_][a-zA-Z0-9_]*/ +NUMBER : /-?\d+/ \ No newline at end of file diff --git a/packages/L3/src/L3/check.py b/packages/L3/src/L3/check.py index 78af5b8..e923785 100644 --- a/packages/L3/src/L3/check.py +++ b/packages/L3/src/L3/check.py @@ -1,6 +1,6 @@ -from collections import Counter from collections.abc import Mapping from functools import partial +from typing import Counter from .syntax import ( Abstract, @@ -27,14 +27,14 @@ def check_term( term: Term, context: Context, ) -> None: - recur = partial(check_term, context=context) + recur = partial(check_term, context=context) # noqa: F841 match term: case Let(bindings=bindings, body=body): counts = Counter(name for name, _ in bindings) duplicates = {name: count for name, count in counts.items() if count > 1} if duplicates: - 
raise ValueError(f"duplicate binders: {duplicates}") + raise ValueError(f"Duplicate bindings: {duplicates}") for _, value in bindings: recur(value) @@ -46,37 +46,36 @@ def check_term( counts = Counter(name for name, _ in bindings) duplicates = {name: count for name, count in counts.items() if count > 1} if duplicates: - raise ValueError(f"duplicate binders: {duplicates}") + raise ValueError(f"Duplicate bindings: {duplicates}") local = dict.fromkeys([name for name, _ in bindings]) for name, value in bindings: recur(value, context={**context, **local}) - check_term(body, {**context, **local}) + check_term(body, context={**context, **local}) - case Reference(name=name): + case Reference(name=name): # Leaf if name not in context: - raise ValueError(f"unknown variable: {name}") + raise ValueError(f"Unbound variable: {name}") - case Abstract(parameters=parameters, body=body): + case Abstract(parameters=parameters, body=body): # Done counts = Counter(parameters) duplicates = {name for name, count in counts.items() if count > 1} if duplicates: - raise ValueError(f"duplicate parameters: {duplicates}") - + raise ValueError(f"Duplicate parameters: {duplicates}") local = dict.fromkeys(parameters, None) - recur(body, context={**context, **local}) + check_term(body, context=local) - case Apply(target=target, arguments=arguments): + case Apply(target=target, arguments=arguments): # Done recur(target) for argument in arguments: recur(argument) - case Immediate(value=_value): + case Immediate(value=_value): # Leaf pass - case Primitive(operator=_operator, left=left, right=right): + case Primitive(operator=_operator, left=left, right=right): # Should be done recur(left) recur(right) @@ -87,13 +86,13 @@ def check_term( recur(otherwise) case Allocate(count=_count): - pass + pass # No need to check count, as it is a non-negative integer by construction case Load(base=base, index=_index): - recur(base) + recur(base) # No need to check index, as it is a non-negative integer by 
construction case Store(base=base, index=_index, value=value): - recur(base) + recur(base) # No need to check index, as it is a non-negative integer by construction recur(value) case Begin(effects=effects, value=value): # pragma: no branch @@ -102,15 +101,12 @@ def check_term( recur(value) -def check_program( - program: Program, -) -> None: +def check_program(program: Program) -> None: match program: case Program(parameters=parameters, body=body): # pragma: no branch counts = Counter(parameters) duplicates = {name for name, count in counts.items() if count > 1} if duplicates: - raise ValueError(f"duplicate parameters: {duplicates}") - + raise ValueError(f"Duplicate parameters: {duplicates}") local = dict.fromkeys(parameters, None) check_term(body, context=local) diff --git a/packages/L3/src/L3/eliminate_letrec.py b/packages/L3/src/L3/eliminate_letrec.py index 63ea854..cc6b044 100644 --- a/packages/L3/src/L3/eliminate_letrec.py +++ b/packages/L3/src/L3/eliminate_letrec.py @@ -1,6 +1,5 @@ # noqa: F841 from collections.abc import Mapping -from functools import partial from L2 import syntax as L2 @@ -13,39 +12,108 @@ def eliminate_letrec_term( term: L3.Term, context: Context, ) -> L2.Term: - recur = partial(eliminate_letrec_term, context=context) - match term: case L3.Let(bindings=bindings, body=body): - pass + return L2.Let( + bindings=[(name, eliminate_letrec_term(value, context)) for name, value in bindings], + body=eliminate_letrec_term(body, context), + ) case L3.LetRec(bindings=bindings, body=body): - pass + # Mark all binding names as recursive in the context + binding_names = [name for name, _ in bindings] + new_context: Context = {**context, **dict.fromkeys(binding_names)} # type: ignore + + # Check which bindings need heap allocation based on their values + # Simple values (Immediate, Allocate) can be stored directly + # Complex values (everything else) need Allocate + Store + simple_binding_indices: set[int] = set() + for i, (_, value) in 
enumerate(bindings): + match value: + case L3.Immediate() | L3.Allocate(): + simple_binding_indices.add(i) + case _: + pass + + # Separate simple and complex bindings + simple_bindings: list[tuple[str, L2.Term]] = [] + complex_bindings: list[tuple[str, L3.Term]] = [] + complex_binding_names: list[str] = [] + + for i, (name, value) in enumerate(bindings): + if i in simple_binding_indices: + transformed_value = eliminate_letrec_term(value, new_context) + simple_bindings.append((name, transformed_value)) + else: + complex_bindings.append((name, value)) + complex_binding_names.append(name) + + # Create stores for complex bindings + stores: list[L2.Term] = [] + for name, value in complex_bindings: + transformed_value = eliminate_letrec_term(value, new_context) + stores.append( + L2.Store( + base=L2.Reference(name=name), + index=0, + value=transformed_value, + ) + ) + + # Transform the body + transformed_body = eliminate_letrec_term(body, new_context) + + # Build the result + all_bindings = simple_bindings + [(name, L2.Allocate(count=1)) for name in complex_binding_names] + + if stores: + return L2.Let( + bindings=all_bindings, + body=L2.Begin( + effects=stores, + value=transformed_body, + ), + ) + else: + return L2.Let( + bindings=all_bindings, + body=transformed_body, + ) case L3.Reference(name=name): # if name is a recursive variable -> (Load (Reference name))) # else (Reference name) - pass + if name in context: + return L2.Load(base=L2.Reference(name=name), index=0) + else: + return L2.Reference(name=name) case L3.Abstract(parameters=parameters, body=body): - pass + return L2.Abstract(parameters=parameters, body=eliminate_letrec_term(body, context)) case L3.Apply(target=target, arguments=arguments): - pass + return L2.Apply( + target=eliminate_letrec_term(target, context), + arguments=[eliminate_letrec_term(argument, context) for argument in arguments], + ) case L3.Immediate(value=value): return L2.Immediate(value=value) - case L3.Primitive(operator=_operator, 
left=left, right=right): - pass + case L3.Primitive(operator=operator, left=left, right=right): + return L2.Primitive( + operator=operator, + left=eliminate_letrec_term(left, context), + right=eliminate_letrec_term(right, context), + ) case L3.Branch(operator=operator, left=left, right=right, consequent=consequent, otherwise=otherwise): return L2.Branch( operator=operator, - left=recur(left), - right=recur(right), - consequent=recur(consequent), - otherwise=recur(otherwise), + left=eliminate_letrec_term(left, context), + right=eliminate_letrec_term(right, context), + consequent=eliminate_letrec_term(consequent, context), + otherwise=eliminate_letrec_term(otherwise, context), ) case L3.Allocate(count=count): @@ -53,15 +121,22 @@ def eliminate_letrec_term( case L3.Load(base=base, index=index): return L2.Load( - base=recur(base), + base=eliminate_letrec_term(base, context), index=index, ) - case L3.Store(base=base, index=_index, value=value): - pass + case L3.Store(base=base, index=index, value=value): + return L2.Store( + base=eliminate_letrec_term(base, context), + index=index, + value=eliminate_letrec_term(value, context), + ) case L3.Begin(effects=effects, value=value): # pragma: no branch - pass + return L2.Begin( + effects=[eliminate_letrec_term(effect, context) for effect in effects], + value=eliminate_letrec_term(value, context), + ) def eliminate_letrec_program( diff --git a/packages/L3/src/L3/main.py b/packages/L3/src/L3/main.py index aaa9c6f..d623913 100644 --- a/packages/L3/src/L3/main.py +++ b/packages/L3/src/L3/main.py @@ -1,9 +1,10 @@ from pathlib import Path import click -from L1.to_python import to_ast_program -from L2.cps_convert import cps_convert_program + +# from L2.cps_convert import cps_convert_program from L2.optimize import optimize_program +from L2.to_python import to_ast_program from .check import check_program from .eliminate_letrec import eliminate_letrec_program @@ -51,15 +52,15 @@ def main( if check: check_program(l3) - fresh, l3 = 
uniqify_program(l3) + fresh, l3 = uniqify_program(l3) # type: ignore l2 = eliminate_letrec_program(l3) if optimize: l2 = optimize_program(l2) - l1 = cps_convert_program(l2, fresh) + # l1 = cps_convert_program(l2, fresh) - module = to_ast_program(l1) + module = to_ast_program(l2) (output or input.with_suffix(".py")).write_text(module) diff --git a/packages/L3/src/L3/parse.py b/packages/L3/src/L3/parse.py index 8fbd08f..5af096c 100644 --- a/packages/L3/src/L3/parse.py +++ b/packages/L3/src/L3/parse.py @@ -5,9 +5,20 @@ from lark.visitors import v_args # pyright: ignore[reportUnknownVariableType] from .syntax import ( + Abstract, + Allocate, + Apply, + Begin, + Branch, Identifier, + Immediate, Let, + LetRec, + Load, + Primitive, Program, + Reference, + Store, Term, ) @@ -27,9 +38,9 @@ def program( def parameters( self, - parameters: Sequence[Identifier], + parameters: Sequence[Token], ) -> Sequence[Identifier]: - return parameters + return [str(p) for p in parameters] @v_args(inline=True) def term( @@ -57,7 +68,7 @@ def letrec( bindings: Sequence[tuple[Identifier, Term]], body: Term, ) -> Term: - return Let( + return LetRec( bindings=bindings, body=body, ) @@ -71,10 +82,127 @@ def bindings( @v_args(inline=True) def binding( self, - name: Identifier, + name: Token, value: Term, ) -> tuple[Identifier, Term]: - return name, value + return str(name), value + + @v_args(inline=True) + def reference( + self, + name: Token, + ) -> Term: + return Reference(name=str(name)) + + @v_args(inline=True) + def abstract( + self, + _lambda: Token, + parameters: Sequence[Identifier], + body: Term, + ) -> Term: + return Abstract( + parameters=parameters, + body=body, + ) + + @v_args(inline=True) + def apply( + self, + target: Term, + *arguments: Term, + ) -> Term: + return Apply( + target=target, + arguments=list(arguments), + ) + + @v_args(inline=True) + def immediate( + self, + value: Token, + ) -> Term: + return Immediate(value=int(value)) + + @v_args(inline=True) + def primitive( + 
self, + operator: Token, + left: Term, + right: Term, + ) -> Term: + return Primitive( + operator=str(operator), # type: ignore + left=left, + right=right, + ) + + @v_args(inline=True) + def branch( + self, + _if: Token, + operator: Token, + left: Term, + right: Term, + consequent: Term, + otherwise: Term, + ) -> Term: + return Branch( + operator=str(operator), # type: ignore + left=left, + right=right, + consequent=consequent, + otherwise=otherwise, + ) + + @v_args(inline=True) + def allocate( + self, + _allocate: Token, + count: Immediate, + ) -> Term: + return Allocate( + count=count.value, + ) + + @v_args(inline=True) + def load( + self, + _load: Token, + base: Term, + index: Immediate, + ) -> Term: + return Load( + base=base, + index=index.value, + ) + + @v_args(inline=True) + def store( + self, + _store: Token, + base: Term, + index: Immediate, + value: Term, + ) -> Term: + return Store( + base=base, + index=index.value, + value=value, + ) + + def begin( + self, + args: Sequence[Token | Term], + ) -> Term: + # Filter out Token objects (like BEGIN token), keep only Terms + terms = [arg for arg in args if not isinstance(arg, Token)] + if len(terms) == 0: + raise ValueError("begin requires at least one term") + return Begin( + effects=terms[:-1], + value=terms[-1], + ) def parse_term(source: str) -> Term: diff --git a/packages/L3/src/L3/uniqify.py b/packages/L3/src/L3/uniqify.py index 2e242d6..ba3180a 100644 --- a/packages/L3/src/L3/uniqify.py +++ b/packages/L3/src/L3/uniqify.py @@ -9,6 +9,7 @@ Apply, Begin, Branch, + Identifier, Immediate, Let, LetRec, @@ -20,9 +21,10 @@ Term, ) -type Context = Mapping[str, str] +type Context = Mapping[Identifier, Identifier] +# This pass is responsible for renaming all identifiers in a program to be unique. This is necessary for the later stages of the compiler, which rely on the fact that all identifiers are unique. 
def uniqify_term( term: Term, context: Context, @@ -32,40 +34,95 @@ def uniqify_term( match term: case Let(bindings=bindings, body=body): - pass + new_bindings: list[tuple[Identifier, Term]] = [] + new_context = context + for name, value in bindings: + new_name = fresh(name) + new_bindings.append((new_name, _term(value))) + new_context = {**new_context, name: new_name} + + return Let( + bindings=new_bindings, + body=uniqify_term(body, new_context, fresh), + ) case LetRec(bindings=bindings, body=body): - pass + new_bindings: list[tuple[Identifier, Term]] = [] + new_context = context + for name, value in bindings: + new_name = fresh(name) + new_bindings.append((new_name, value)) + new_context = {**new_context, name: new_name} + + _new_term = partial(uniqify_term, context=new_context, fresh=fresh) + + return LetRec( + bindings=[(new_name, _new_term(value)) for new_name, value in new_bindings], + body=_new_term(body), + ) case Reference(name=name): - pass + if name in context: + return Reference(name=context[name]) + return term case Abstract(parameters=parameters, body=body): - pass + new_parameters = [fresh(parameter) for parameter in parameters] + new_context = { + **context, + **{parameter: new_parameter for parameter, new_parameter in zip(parameters, new_parameters)}, + } + return Abstract( + parameters=new_parameters, + body=uniqify_term(body, new_context, fresh), + ) case Apply(target=target, arguments=arguments): - pass + return Apply( + target=_term(target), + arguments=[_term(argument) for argument in arguments], + ) case Immediate(): - pass + return term case Primitive(operator=operator, left=left, right=right): - pass + return Primitive( + operator=operator, + left=_term(left), + right=_term(right), + ) case Branch(operator=operator, left=left, right=right, consequent=consequent, otherwise=otherwise): - pass + return Branch( + operator=operator, + left=_term(left), + right=_term(right), + consequent=_term(consequent), + otherwise=_term(otherwise), + ) case 
Allocate(): - pass + return term case Load(base=base, index=index): - pass + return Load( + base=_term(base), + index=index, + ) case Store(base=base, index=index, value=value): - pass + return Store( + base=_term(base), + index=index, + value=_term(value), + ) case Begin(effects=effects, value=value): # pragma: no branch - pass + return Begin( + effects=[_term(effect) for effect in effects], + value=_term(value), + ) def uniqify_program( diff --git a/packages/L3/test/L3/test_eliminate_letrec.py b/packages/L3/test/L3/test_eliminate_letrec.py index d9d42d2..7cf8d36 100644 --- a/packages/L3/test/L3/test_eliminate_letrec.py +++ b/packages/L3/test/L3/test_eliminate_letrec.py @@ -233,3 +233,203 @@ def test_eliminate_letrec_program(): ) assert actual == expected + + +def test_eliminate_letrec_reference_recursive(): + term = L3.LetRec( + bindings=[("x", L3.Immediate(value=0))], + body=L3.Reference(name="x"), + ) + expected = L2.Let( + bindings=[("x", L2.Immediate(value=0))], + body=L2.Load(base=L2.Reference(name="x"), index=0), + ) + actual = eliminate_letrec_term(term, context={}) + assert actual == expected + + +def test_eliminate_letrec_reference_non_recursive(): + term = L3.Reference(name="x") + + context: Context = {} + + expected = L2.Reference(name="x") + + actual = eliminate_letrec_term(term, context) + + assert actual == expected + + +def test_eliminate_letrec_body_uses_recursive_binding(): + term = L3.LetRec( + bindings=[("x", L3.Immediate(value=1))], + body=L3.Reference(name="x"), + ) + + expected = L2.Let( + bindings=[("x", L2.Immediate(value=1))], + body=L2.Load(base=L2.Reference(name="x"), index=0), + ) + + actual = eliminate_letrec_term(term, context={}) + + assert actual == expected + + +def test_eliminate_letrec_abstract_apply(): + term = L3.Apply( + target=L3.Abstract(parameters=["x"], body=L3.Reference(name="x")), + arguments=[L3.Immediate(value=1)], + ) + + expected = L2.Apply( + target=L2.Abstract(parameters=["x"], body=L2.Reference(name="x")), + 
arguments=[L2.Immediate(value=1)], + ) + + actual = eliminate_letrec_term(term, context={}) + + assert actual == expected + + +def test_eliminate_letrec_primitive_branch(): + term = L3.Branch( + operator="<", + left=L3.Primitive( + operator="+", + left=L3.Immediate(value=1), + right=L3.Immediate(value=2), + ), + right=L3.Immediate(value=4), + consequent=L3.Immediate(value=5), + otherwise=L3.Immediate(value=6), + ) + + expected = L2.Branch( + operator="<", + left=L2.Primitive( + operator="+", + left=L2.Immediate(value=1), + right=L2.Immediate(value=2), + ), + right=L2.Immediate(value=4), + consequent=L2.Immediate(value=5), + otherwise=L2.Immediate(value=6), + ) + + actual = eliminate_letrec_term(term, context={}) + + assert actual == expected + + +def test_eliminate_letrec_nested_primitives(): + term = L3.Primitive( + operator="+", + left=L3.Primitive( + operator="+", + left=L3.Immediate(value=1), + right=L3.Immediate(value=2), + ), + right=L3.Immediate(value=3), + ) + expected = L2.Primitive( + operator="+", + left=L2.Primitive( + operator="+", + left=L2.Immediate(value=1), + right=L2.Immediate(value=2), + ), + right=L2.Immediate(value=3), + ) + actual = eliminate_letrec_term(term, context={}) + assert actual == expected + + +def test_eliminate_letrec_branch_equals(): + term = L3.Branch( + operator="==", + left=L3.Immediate(value=1), + right=L3.Immediate(value=1), + consequent=L3.Immediate(value=7), + otherwise=L3.Immediate(value=8), + ) + + expected = L2.Branch( + operator="==", + left=L2.Immediate(value=1), + right=L2.Immediate(value=1), + consequent=L2.Immediate(value=7), + otherwise=L2.Immediate(value=8), + ) + + actual = eliminate_letrec_term(term, context={}) + + assert actual == expected + + +def test_eliminate_letrec_memory_and_begin(): + term = L3.Begin( + effects=[ + L3.Store( + base=L3.Reference(name="b"), + index=0, + value=L3.Immediate(value=2), + ), + ], + value=L3.Load( + base=L3.Allocate(count=1), + index=0, + ), + ) + + expected = L2.Begin( + 
effects=[ + L2.Store( + base=L2.Reference(name="b"), + index=0, + value=L2.Immediate(value=2), + ) + ], + value=L2.Load( + base=L2.Allocate(count=1), + index=0, + ), + ) + + actual = eliminate_letrec_term(term, context={}) + + assert actual == expected + + +def test_eliminate_letrec_allocate_and_load(): + term = L3.Load( + base=L3.Allocate(count=1), + index=0, + ) + + expected = L2.Load( + base=L2.Allocate(count=1), + index=0, + ) + + actual = eliminate_letrec_term(term, context={}) + + assert actual == expected + + +def test_eliminate_letrec_store(): + term = L3.Store( + base=L3.Allocate(count=1), + index=0, + value=L3.Immediate(value=42), + ) + + expected = L2.Store( + base=L2.Allocate(count=1), + index=0, + value=L2.Immediate(value=42), + ) + + actual = eliminate_letrec_term(term, context={}) + + assert actual == expected diff --git a/packages/L3/test/L3/test_parse.py b/packages/L3/test/L3/test_parse.py index f48391a..e4aa85a 100644 --- a/packages/L3/test/L3/test_parse.py +++ b/packages/L3/test/L3/test_parse.py @@ -1,4 +1,5 @@ -from L3.parse import parse_program, parse_term +import pytest +from L3.parse import AstTransformer, parse_program, parse_term from L3.syntax import ( Abstract, Allocate, @@ -14,6 +15,7 @@ Reference, Store, ) +from lark import Token # Let @@ -295,3 +297,10 @@ def test_parse_program_identity(): actual = parse_program(source) assert actual == expected + + +def test_begin_requires_at_least_one_term(): + transformer = AstTransformer() + + with pytest.raises(ValueError, match="begin requires at least one term"): + transformer.begin([Token("BEGIN", "begin")]) diff --git a/packages/L3/test/L3/test_uniqify.py b/packages/L3/test/L3/test_uniqify.py index e2243e9..3fdaf62 100644 --- a/packages/L3/test/L3/test_uniqify.py +++ b/packages/L3/test/L3/test_uniqify.py @@ -1,5 +1,19 @@ -from L3.syntax import Apply, Immediate, Let, Reference -from L3.uniqify import Context, uniqify_term +from L3.syntax import ( + Abstract, + Allocate, + Apply, + Begin, + 
Branch, + Immediate, + Let, + LetRec, + Load, + Primitive, + Program, + Reference, + Store, +) +from L3.uniqify import Context, uniqify_program, uniqify_term from util.sequential_name_generator import SequentialNameGenerator @@ -15,6 +29,18 @@ def test_uniqify_term_reference(): assert actual == expected +def test_uniqify_term_reference_not_in_context(): + term = Reference(name="x") + + context: Context = {} + fresh = SequentialNameGenerator() + actual = uniqify_term(term, context, fresh=fresh) + + expected = Reference(name="x") + + assert actual == expected + + def test_uniqify_immediate(): term = Immediate(value=42) @@ -59,3 +85,119 @@ def test_uniqify_term_let(): ) assert actual == expected + + +def test_uniqify_term_letrec_and_abstract_and_apply(): + term = LetRec( + bindings=[ + ( + "f", + Abstract( + parameters=["x"], + body=Apply( + target=Reference(name="f"), + arguments=[Reference(name="x")], + ), + ), + ), + ("y", Reference(name="f")), + ], + body=Reference(name="y"), + ) + + context: Context = {"f": "outer_f"} + fresh = SequentialNameGenerator() + actual = uniqify_term(term, context, fresh) + + expected = LetRec( + bindings=[ + ( + "f0", + Abstract( + parameters=["x0"], + body=Apply( + target=Reference(name="f0"), + arguments=[Reference(name="x0")], + ), + ), + ), + ("y0", Reference(name="f0")), + ], + body=Reference(name="y0"), + ) + + assert actual == expected + + +def test_uniqify_term_memory_and_control_forms(): + term = Begin( + effects=[ + Store( + base=Reference(name="ptr"), + index=0, + value=Primitive( + operator="+", + left=Reference(name="a"), + right=Reference(name="b"), + ), + ) + ], + value=Branch( + operator="<", + left=Reference(name="a"), + right=Immediate(value=0), + consequent=Allocate(count=1), + otherwise=Load(base=Reference(name="ptr"), index=1), + ), + ) + + context: Context = {"a": "a1", "b": "b1", "ptr": "p1"} + fresh = SequentialNameGenerator() + actual = uniqify_term(term, context, fresh) + + expected = Begin( + effects=[ + 
Store( + base=Reference(name="p1"), + index=0, + value=Primitive( + operator="+", + left=Reference(name="a1"), + right=Reference(name="b1"), + ), + ) + ], + value=Branch( + operator="<", + left=Reference(name="a1"), + right=Immediate(value=0), + consequent=Allocate(count=1), + otherwise=Load(base=Reference(name="p1"), index=1), + ), + ) + + assert actual == expected + + +def test_uniqify_program(): + program = Program( + parameters=["x", "y"], + body=Primitive( + operator="+", + left=Reference(name="x"), + right=Reference(name="y"), + ), + ) + + _, actual = uniqify_program(program) + + expected = Program( + parameters=["x0", "y0"], + body=Primitive( + operator="+", + left=Reference(name="x0"), + right=Reference(name="y0"), + ), + ) + + assert actual == expected