diff --git a/regression-tests/tests/1opt-par1.out b/regression-tests/tests/1opt-par1.out new file mode 100644 index 00000000..7ed6ff82 --- /dev/null +++ b/regression-tests/tests/1opt-par1.out @@ -0,0 +1 @@ +5 diff --git a/regression-tests/tests/1opt-par1.ssl.failing b/regression-tests/tests/1opt-par1.ssl.failing new file mode 100644 index 00000000..a0e00409 --- /dev/null +++ b/regression-tests/tests/1opt-par1.ssl.failing @@ -0,0 +1,20 @@ +// once non-application expressions are allowed inside a par, +// this test should pass. + +// check that optimize par pass does NOT rewrite par expressions +// that have at least one non-instantaneous argument + +type Pair2 a b + Pair2 a b + +add a b = a + b + +main cin cout = + let q = par 3+4 // instantaneous + add 3 4 // functions are not necessarily instantaneous + + // ^should NOT rewrite q as a tuple + after 1, cout <- 5 + 48 // Should print 5 + wait cout + after 1, cout <- 10 + wait cout diff --git a/regression-tests/tests/1opt-par2.out b/regression-tests/tests/1opt-par2.out new file mode 100644 index 00000000..7ed6ff82 --- /dev/null +++ b/regression-tests/tests/1opt-par2.out @@ -0,0 +1 @@ +5 diff --git a/regression-tests/tests/1opt-par2.ssl b/regression-tests/tests/1opt-par2.ssl new file mode 100644 index 00000000..7471880c --- /dev/null +++ b/regression-tests/tests/1opt-par2.ssl @@ -0,0 +1,17 @@ +// check that optimize par pass rewrites par expressions +// that have all instantaneous arguemnts as tuples + +type Pair2 a b + Pair2 a b + +add a b = a + b + +main cin cout = + let q = par 2+3 // instantaneous + 3+4 // instantaneous + + // ^should rewrite q as a tuple + after 1, cout <- 5 + 48 // Should print 5 + wait cout + after 1, cout <- 10 + wait cout diff --git a/regression-tests/tests/1opt-par3.out b/regression-tests/tests/1opt-par3.out new file mode 100644 index 00000000..7ed6ff82 --- /dev/null +++ b/regression-tests/tests/1opt-par3.out @@ -0,0 +1 @@ +5 diff --git a/regression-tests/tests/1opt-par3.ssl b/regression-tests/tests/1opt-par3.ssl new file mode 100644 index 00000000..f9565e56 --- /dev/null +++ b/regression-tests/tests/1opt-par3.ssl @@ -0,0 +1,16 @@ +// check that optimize par pass does NOT rewrite par expressions +// that have at least one non-instantaneous argument + +type Pair2 a b + Pair2 a b + +add a b = a + b + +main cin cout = + let q = par add 2 3 // functions are not necessarily instantaneous + add 3 4 // functions are not necessarily instantaneous + // ^should NOT rewrite q as a tuple + after 1, cout <- 5 + 48 // Should print 5 + wait cout + after 1, cout <- 10 + wait cout diff --git a/regression-tests/tests/1opt-par4.out b/regression-tests/tests/1opt-par4.out new file mode 100644 index 00000000..f70f10e4 --- /dev/null +++ b/regression-tests/tests/1opt-par4.out @@ -0,0 +1 @@ +A diff --git a/regression-tests/tests/1opt-par4.ssl.failing b/regression-tests/tests/1opt-par4.ssl.failing new file mode 100644 index 00000000..3ec1fe99 --- /dev/null +++ b/regression-tests/tests/1opt-par4.ssl.failing @@ -0,0 +1,23 @@ +// once codegen supports return value of tuples, +// this test should pass + +// check that optimize par pass does NOT rewrite par expressions +// that have at least one non-instantaneous argument + +type Pair2 a b + Pair2 a b + +add a b = a + b + +main cin cout = + let x = 5 + y = 60 + let r = par add x y // functions are not necessarily instantaneous + add y x // functions are not necessarily instantaneous + // ^should NOT rewrite r as a tuple + match r + (0,0) = () + (a1,a2) = after 1, cout <- a2 + wait cout + after 1, cout <- 10 + wait cout \ No newline at end of file diff --git a/regression-tests/tests/1opt-par5.out b/regression-tests/tests/1opt-par5.out new file mode 100644 index 00000000..7ed6ff82 --- /dev/null +++ b/regression-tests/tests/1opt-par5.out @@ -0,0 +1 @@ +5 diff --git a/regression-tests/tests/1opt-par5.ssl.failing b/regression-tests/tests/1opt-par5.ssl.failing new file mode 100644 index 00000000..ba66dfc7 --- /dev/null +++ b/regression-tests/tests/1opt-par5.ssl.failing @@ -0,0 +1,17 @@ +// once non-application expressions are allowed inside a par, +// this test should pass. + +// check that optimize par pass does NOT rewrite par expressions +// that have at least one non-instantaneous argument + +add a b = a + b + +main cin cout = + let q = par 2+3 // instantaneous + 3+4 // instantaneous + wait cout // NOT instantaneous + // ^should NOT rewrite q as a tuple + after 1, cout <- 5 + 48 // Should print 5 + wait cout + after 1, cout <- 10 + wait cout diff --git a/regression-tests/tests/1opt-par6.out b/regression-tests/tests/1opt-par6.out new file mode 100644 index 00000000..7ed6ff82 --- /dev/null +++ b/regression-tests/tests/1opt-par6.out @@ -0,0 +1 @@ +5 diff --git a/regression-tests/tests/1opt-par6.ssl.failing b/regression-tests/tests/1opt-par6.ssl.failing new file mode 100644 index 00000000..d14df711 --- /dev/null +++ b/regression-tests/tests/1opt-par6.ssl.failing @@ -0,0 +1,16 @@ +// once non-application expressions are allowed inside a par, +// this test should pass. + +// check that optimize par pass does NOT rewrite par expressions +// that have at least one non-instantaneous argument + +add a b = a + b +main cin cout = + let q = par wait cout // NOT instantaneous + wait cout // NOT instantaneous + wait cout // NOT instantaneous + // ^should NOT rewrite q as a tuple + after 1, cout <- 5 + 48 // Should print 5 + wait cout + after 1, cout <- 10 + wait cout diff --git a/regression-tests/tests/1opt-par7.failing b/regression-tests/tests/1opt-par7.failing new file mode 100644 index 00000000..5b1662b4 --- /dev/null +++ b/regression-tests/tests/1opt-par7.failing @@ -0,0 +1,41 @@ +// once codegen supports return value of tuples, +// this test should pass + +type Pair2 a b + Pair2 a b + +add a b = a + b + +printCharTuple putc p = + match p + (x,y) = putc x + putc 32 + putc y + +main cin cout = + let putc c = after 1, cout <- c + wait cout + let putnl _ = putc 10 + + let x = 66 + let y = 67 + + // tuple with arguments evaluated sequentially + let w = (add x 0, add y 0) + // let's print it out + printCharTuple putc w // this is okay + putnl () + + // tuple with arguments already evaluated + let r = (x,y) + // let's print it out + printCharTuple putc r // this is okay + putnl () + + // code below causes segmentation fault + // par with arguments evaluated at the same time + let q = par add x 0 + add y 0 + // let's print it out + printCharTuple putc q // this is okay + putnl () \ No newline at end of file diff --git a/regression-tests/tests/1opt-par7.out b/regression-tests/tests/1opt-par7.out new file mode 100644 index 00000000..e7dda80d --- /dev/null +++ b/regression-tests/tests/1opt-par7.out @@ -0,0 +1,3 @@ +B C +B C +B C diff --git a/src/IR.hs b/src/IR.hs index 09399d32..7fb45075 100644 --- a/src/IR.hs +++ b/src/IR.hs @@ -40,6 +40,7 @@ import Text.Show.Pretty -} data Mode = Continue + | DumpOptPar | DumpIR | DumpIRAnnotated | DumpIRConstraints @@ -70,6 +71,11 @@ options = ["dump-ir"] (NoArg $ setMode DumpIRAnnotated) "Print the IR immediately after lowering" + , Option + "" + ["dumpOptPar"] + (NoArg $ setMode DumpOptPar) + "Print the IR immediately after opt par pass" , Option "" ["dump-ir-annotated"] @@ -153,6 +159,7 @@ transform opt p = do p <- dConToFunc p p <- externToCall p p <- optimizePar p + when (mode opt == DumpOptPar) $ (throwError . Dump . ppShow) p p <- liftProgramLambdas p p <- segmentLets p when (mode opt == DumpIRLifted) $ dump p diff --git a/src/IR/OptimizePar.hs b/src/IR/OptimizePar.hs index 6cf7f1cb..85dab248 100644 --- a/src/IR/OptimizePar.hs +++ b/src/IR/OptimizePar.hs @@ -1,5 +1,8 @@ {-# LANGUAGE DerivingVia #-} {-# LANGUAGE OverloadedStrings #-} +--make draft, add 2 test cases + +--to debug, look at the isBad {- | Remove unnecessary Par expressions from the IR @@ -14,13 +17,15 @@ import qualified Common.Compiler as Compiler import Control.Monad.State.Lazy ( MonadState, StateT (..), - evalStateT, - gets, - modify, + evalStateT ) -import IR.IR (Literal (LitIntegral)) +-- import IR.IR (Literal (LitIntegral)) import qualified IR.IR as I - +-- import Data.Bifunctor +import Common.Identifiers(Identifier (Identifier), TVarId (..)) +import IR.Types.Type +import Data.Generics.Aliases ( mkM ) +import Data.Generics.Schemes ( everywhereM ) -- | Optimization Environment data OptParCtx = OptParCtx @@ -32,7 +37,7 @@ data OptParCtx = OptParCtx -- | OptPar Monad -newtype OptParFn a = LiftFn (StateT OptParCtx Compiler.Pass a) +newtype OptParFn a = OptParFn (StateT OptParCtx Compiler.Pass a) deriving (Functor) via (StateT OptParCtx Compiler.Pass) deriving (Applicative) via (StateT OptParCtx Compiler.Pass) deriving (Monad) via (StateT OptParCtx Compiler.Pass) @@ -42,19 +47,19 @@ newtype OptParFn a = LiftFn (StateT OptParCtx Compiler.Pass a) -- | Example func to delete later! Demonstrates how to extract a value from the OptParFn Monad -getNumberOfPars :: OptParFn Int -getNumberOfPars = gets numPars +-- getNumberOfPars :: OptParFn Int +-- getNumberOfPars = gets numPars -- | Example func to delete later! Demonstrates how to modify a value in the OptParFn Monad -updateNumberOfPars :: Int -> OptParFn () -updateNumberOfPars num = do - modify $ \st -> st{numPars = num} +-- updateNumberOfPars :: Int -> OptParFn () +-- updateNumberOfPars num = do +-- modify $ \st -> st{numPars = num} --- | Run a LiftFn computation. -runLiftFn :: OptParFn a -> Compiler.Pass a -runLiftFn (LiftFn m) = +-- | Run a OptParFn computation. +runOptParFn :: OptParFn a -> Compiler.Pass a +runOptParFn (OptParFn m) = evalStateT m OptParCtx @@ -63,54 +68,129 @@ runLiftFn (LiftFn m) = } +--traversing the ir replaced with everywhere + +--rewrite of case1, transorm into tuple is operational + +--isbad is working + +-- run on all regression testss + +--check the types + +--can only take ut the instantenous expression if they occur before + + + +-- prepare for case 2 and casse 3 + {- | Entry-point to Par Optimization. Maps over top level definitions, removing unnecessary pars. -} optimizePar :: I.Program I.Type -> Compiler.Pass (I.Program I.Type) -optimizePar p = runLiftFn $ do - optimizedDefs <- mapM optimizeParTop $ I.programDefs p - return $ p{I.programDefs = optimizedDefs} +optimizePar p = runOptParFn $ do + defs' <- everywhereM (mkM findFixBadPar) $ I.programDefs p + return $ p{I.programDefs = defs'} + -- optimizedDefs <- mapM optimizeParTop $ I.programDefs p + -- fail ("Number of Bad Par Exprs in " ++ show (map fst (map tupleMatch1 optimizedDefs)) ++ ": " ++ (show (map tupleMatch optimizedDefs))) + -- return $ p{I.programDefs = map tupleMatch1 optimizedDefs} -- | Given a top-level definition, detect + replace unnecessary par expressions -optimizeParTop :: (I.Binder I.Type, I.Expr I.Type) -> OptParFn (I.Binder I.Type, I.Expr I.Type) -optimizeParTop (nm, rhs) = do - rhs' <- detectReplaceBadPar rhs - (rhs'', _) <- countPars rhs' -- calling this so we don't get an "unused" warning - (rhs''', _) <- countBadPars rhs'' -- calling this so we don't get an "unused" warning - -- uncomment the line below to test countPars - -- (_, result) <- countPars rhs - -- _ <- fail (show nm ++ ": Number of Par Exprs: " ++ show result) - -- uncomment the two lines below to test countBadPars - -- (_, result') <- countBadPars rhs - -- _ <- fail (show nm ++ ": Number of Bad Par Exprs: " ++ show result') - return (nm, rhs''') +-- optimizeParTop :: (I.Binder I.Type, I.Expr I.Type) -> OptParFn (I.Binder I.Type, I.Expr I.Type) +-- optimizeParTop (nm, rhs) = do +-- rhs' <- detectReplaceBadPar rhs +-- (rhs'', _) <- countPars rhs' -- calling this so we don't get an "unused" warning +-- (rhs''', _) <- countBadPars rhs'' -- calling this so we don't get an "unused" warning +-- -- uncomment the line below to test countPars +-- -- (_, result) <- countPars rhs +-- -- _ <- fail (show nm ++ ": Number of Par Exprs: " ++ show result) +-- -- uncomment the two lines below to test countBadPars +-- -- (_, result') <- countBadPars rhs +-- -- _ <- fail (show nm ++ ": Number of Bad Par Exprs: " ++ show result') +-- return (nm, rhs''') +{- | Given an Expr as input, if it turns out to be a bad Par expr, rewrite it. --- | Detect Unnecessary Par Expressions + Replace With Equivalent Sequential Expression -detectReplaceBadPar :: I.Expr I.Type -> OptParFn (I.Expr I.Type) -detectReplaceBadPar e = do - pure e -- for now, just return the same thing (don't do anyting) +Otherwise, leave the expression alone. +checks for case1: -{- | 1) Count Par Nodes +case 1: +par 5 + 1 + 3 + 2 +-} -Practice Exercise to Delete Later! +--import foldApp from IR.IR +-- import tempTupleId from IR.Types.Type +findFixBadPar :: I.Expr I.Type -> OptParFn (I.Expr I.Type) +findFixBadPar e@__ = if isBad e then rewrite e else pure e + where rewrite :: I.Expr I.Type -> OptParFn (I.Expr I.Type) + -- structure of IR + rewrite (I.Prim I.Par exprlist _) = pure x + where dataConstructorName = "Pair2" + t = TVar $ TVarId (Identifier "PINEAPPLE") --TODO: put in actual type here!!! + --construct type that is tuple of arguments + x = I.foldApp dConNode argsToTuple + dConNode = I.Data (I.DConId (Identifier dataConstructorName)) t + argsToTuple = zip exprlist (repeat t) -Traverse the IR representation of the body of a top level defintion, -and count the number of par expressions present. -Return the body unchanged, as well as the count numPars. --} -countPars :: I.Expr I.Type -> OptParFn (I.Expr I.Type, Int) -countPars e = do - -- currently a stub - -- PUT YOUR IMPLEMENTATION HERE - x <- getNumberOfPars - updateNumberOfPars (x + 0) -- calling this so we don't get an "unused" warning - return (e, x) + -- pure dummy --TODO: rewrite the bad par as good one + + rewrite _ = fail "rewrite should only be called on a Par IR node!" + -- dummy = I.Var (I.VarId (Identifier "PINEAPPLE")) (TVar $ TVarId (Identifier "dummy")) +{- +case 1: +par 5 + 1 + 3 + 2 +^ we agree this par expr is bad +^ this par returns the value (5+1,3+2) + +We want to rewrite case 1 into +(5+1,3+2) +which really desugars into +(Pair2 5+1 3+2) +which as an IR node is +(I.App (I.App (I.DCon DConId "Pair2") (I.Prim (I.PrimOp PrimAdd) [(I.Lit 5), (I.Lit 1)])) (I.Prim (I.PrimOp PrimAdd) [(I.Lit 3), (I.Lit 1)])) + +What does par 5 + 1 look like as an IR node? + 3 + 2 +Prim I.Par [I.Prim (I.PrimOp PrimAdd) [(I.Lit 5), (I.Lit 1)], I.Prim (I.PrimOp PrimAdd) [(I.Lit 3), (I.Lit 1)]] t + +let arg1 = I.Prim (I.PrimOp PrimAdd) [(I.Lit 5), (I.Lit 1)] +let arg2 = I.Prim (I.PrimOp PrimAdd) [(I.Lit 3), (I.Lit 1)] +foldApp (I.Dcon I.DConId "Pair2") [arg1, arg2] +^foldApp returns a nested application such that "Pair2" is applied to a list of arguments + +(I.Prim I.Par exprlist _) + | + | + v +let tupleDataConstructorName = tempTupleId (length exprlist) // returns "Pair2" or "Pair3", or whatever you need +foldApp tupleDataConstructorName exprlist + + +--Case on type of Prime, whether has a Wait or a PrimOp +what is 5+1 as an IR node? +I.Prim (I.PrimOp PrimAdd) [(I.Lit 5), (I.Lit 1)] +^ do you agree with this? + +what is 3+2 as an IR node? +I.Prim (I.PrimOp PrimAdd) [(I.Lit 3), (I.Lit 1)] + + +What is (Pair2 5+1 3+2) as an IR node? +We know for reference: add 5 1 as an IR node is (I.App (I.App (I.Var VarId "add") (I.Lit 5)) (I.Lit 1)) +so we know that pair2 applied to its two arguments will look like in the IR as +(I.App (I.App (I.DCon DConId "Pair2") (I.Prim (I.PrimOp PrimAdd) [(I.Lit 5), (I.Lit 1)])) (I.Prim (I.PrimOp PrimAdd) [(I.Lit 3), (I.Lit 1)])) + +We have a library function called foldApp that takes a function name and a list of arguments, +and wraps them up in application. +-} + {- | 1.5) Implement IsBad Predicate Suggested by John during Monday Meeting. @@ -118,8 +198,46 @@ Returns true if par expr contains only instantaneous expressions as arguments. False otherwise. Useful for exercise 2. -} + +--helper function +--bad par expr: par nodes lisit of arguments contains a wait +-- we assume that all function calls are blocking +--variable are non blocking +--prim operatiions that are not wait are non blocking +--literals are non blockiing +-- function calls are application nodes App (Expr t) (Expr t) t +-- nested function application, +-- add x 0 -> App(App(add,x),0) +-- a@(App _ _ ) +--whenever there any any non blocking calls, reqrite the expression so that the non blocking calls +--outside with a let, and blocking call remain with the par +-- if par has just 0 or 1 blocking call, no need for par +-- take our non blocking, append to tuple in correct position + +isNotWait :: I.Expr I.Type -> Bool +isNotWait (I.Prim I.Wait _ _) = False +isNotWait _ = True + +isNotFunction :: I.Expr I.Type -> Bool +isNotFunction I.App {} = False +isNotFunction _ = True +-- isNotFunction (I.App expr1 expr2 t) = False +-- isNotFunction ( _ ) = True + + isBad :: I.Expr I.Type -> Bool -isBad _ = False -- currently a stub +--isBad theExpr = False -- currently a stub +isBad (I.Prim I.Par exprlist _) = do + let left = all isNotWait exprlist + let right = all isNotFunction exprlist + left && right +isBad _ = False + + + +--isBad look at arguments to par if there is a wait priumitive, its bad + + {- | 2) Count Bad Par Nodes @@ -131,8 +249,11 @@ and count the number of BAD par expressions present. Use the helper predicate "isBad" in your implementation. Return the body unchanged, as well as the count numBadPars. -} -countBadPars :: I.Expr I.Type -> OptParFn (I.Expr I.Type, Int) -countBadPars e = do - -- currently a stub - let y = isBad (I.Lit (LitIntegral 5) (I.extract e)) -- calling this so we don't get an "unused" warning - return (e, fromEnum y) + +--not using this + +-- countBadPars :: I.Expr I.Type -> OptParFn (I.Expr I.Type, Int) +-- countBadPars e = do +-- -- currently a stub +-- let y = isBad (I.Lit (LitIntegral 5) (I.extract e)) -- calling this so we don't get an "unused" warning +-- return (e, fromEnum y)