From 93c887b4a5e7b9d433ffc9d6714278d996310429 Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Tue, 25 Mar 2025 19:09:34 +0000 Subject: [PATCH 1/8] handle simple polymorphic functions --- granule-compiler.cabal | 1 + src/Language/Granule/Codegen/Compile.hs | 7 +- src/Language/Granule/Codegen/Monomorphise.hs | 232 +++++++++++++++++++ tests/golden/positive/poly-simple.golden | 1 + tests/golden/positive/poly-simple.gr | 5 + 5 files changed, 242 insertions(+), 4 deletions(-) create mode 100644 src/Language/Granule/Codegen/Monomorphise.hs create mode 100644 tests/golden/positive/poly-simple.golden create mode 100644 tests/golden/positive/poly-simple.gr diff --git a/granule-compiler.cabal b/granule-compiler.cabal index 85f4afd..79e98bf 100644 --- a/granule-compiler.cabal +++ b/granule-compiler.cabal @@ -44,6 +44,7 @@ library Language.Granule.Codegen.Emit.Names Language.Granule.Codegen.Emit.Primitives Language.Granule.Codegen.Emit.Types + Language.Granule.Codegen.Monomorphise Paths_granule_compiler hs-source-dirs: src diff --git a/src/Language/Granule/Codegen/Compile.hs b/src/Language/Granule/Codegen/Compile.hs index b747701..2c48da2 100644 --- a/src/Language/Granule/Codegen/Compile.hs +++ b/src/Language/Granule/Codegen/Compile.hs @@ -1,4 +1,3 @@ -{-# LANGUAGE ImplicitParams #-} {-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-} module Language.Granule.Codegen.Compile where @@ -9,13 +8,13 @@ import Language.Granule.Codegen.TopsortDefinitions import Language.Granule.Codegen.ConvertClosures import Language.Granule.Codegen.Emit.EmitLLVM import Language.Granule.Codegen.MarkGlobals +import Language.Granule.Codegen.Monomorphise import qualified LLVM.AST as IR ---import Language.Granule.Syntax.Pretty ---import Debug.Trace compile :: String -> AST () Type -> Either String IR.Module compile moduleName typedAST = - let normalised = {-trace (show typedAST)-} (normaliseDefinitions typedAST) + let monomorphised = monomorphiseAST typedAST + normalised = normaliseDefinitions monomorphised markedGlobals = markGlobals normalised (Ok topsorted) = topologicallySortDefinitions markedGlobals closureFree = convertClosures topsorted diff --git a/src/Language/Granule/Codegen/Monomorphise.hs b/src/Language/Granule/Codegen/Monomorphise.hs new file mode 100644 index 0000000..4b3ed3b --- /dev/null +++ b/src/Language/Granule/Codegen/Monomorphise.hs @@ -0,0 +1,232 @@ +module Language.Granule.Codegen.Monomorphise (monomorphiseAST) where + +import Control.Monad.Identity (runIdentity) +import Data.Bifunctor (Bifunctor (bimap), second) +import Data.Char (isAlphaNum, toLower) +import qualified Data.Map as Map +import Language.Granule.Syntax.Annotated (annotation) +import Language.Granule.Syntax.Def +import Language.Granule.Syntax.Expr hiding (subst) +import Language.Granule.Syntax.Identifiers +import Language.Granule.Syntax.Pattern +import Language.Granule.Syntax.Type + +-- polymorphic id -> [monomorphic id, [(ty var, ty subst)]] +type PolyInstances = Map.Map Id [(Id, [(Id, Type)])] + +-- polymorphic id -> [ty var] +type PolyFuncs = Map.Map Id [Id] + +-- TODO: support tyvar in any argument position and support more than 1 tyvar per function +-- currently only supports single tyvars in first argument + +-- create monomorphic versions for each required instance of polymorphic function and rewrite ast +monomorphiseAST :: AST ev Type -> AST ev Type +monomorphiseAST ast = + let polymorphicFuncs = getPolymorphicFunctions ast + env = collectInstances ast polymorphicFuncs + monoDefs = makeMonoDefs ast env + rewritten = rewriteCalls ast env + in rewritten {definitions = filter (not . isPolymorphic) (definitions rewritten) ++ monoDefs} + +isPolymorphic :: Def ev Type -> Bool +isPolymorphic def = + case defTypeScheme def of + Forall _ ((_, Type 0) : _) _ _ -> True + _ -> False + +-- e.g. id : a -> a when a is int becomes __id_int +makeMonoId :: Id -> [Type] -> Id +makeMonoId (Id internal source) types = + let typeSuffix = concatMap (\ty -> "_" ++ typeToSafeString ty) types + name = "__" ++ internal ++ typeSuffix + in Id name name + where + -- TODO: needs work so we don't make extralong names + typeToSafeString (TyCon (Id _ id)) = map sanitiseChar id + typeToSafeString t = map sanitiseChar (show t) + sanitiseChar c + | isAlphaNum c = toLower c + | otherwise = '_' + +-- create map of polymorphic function id to its ty vars +getPolymorphicFunctions :: AST ev Type -> PolyFuncs +getPolymorphicFunctions ast = + Map.fromList $ map getPolyInfo $ filter isPolymorphic $ definitions ast + where + getPolyInfo :: Def ev Type -> (Id, [Id]) + getPolyInfo def = + case defTypeScheme def of + Forall _ bindings _ _ -> + let tyVars = map fst bindings + in (defId def, tyVars) + +-- collect all insts of polymorphic functions with their concrete type substitutions +collectInstances :: AST ev Type -> PolyFuncs -> PolyInstances +collectInstances ast fns = + foldl collectDef Map.empty (definitions ast) + where + collectDef env def = + let defInstances = foldl collectEquation Map.empty (equations $ defEquations def) + in Map.unionWith (++) env defInstances + + collectEquation env eq = collectExpr (equationBody eq) + + collectExpr :: Expr ev Type -> PolyInstances + collectExpr (App _ _ _ e1 e2) = + let inst = case getPolymorphicCall fns e1 e2 of + Just (id, tyVarSubsts) -> + Map.singleton id [(makeMonoId id (map snd tyVarSubsts), tyVarSubsts)] + Nothing -> Map.empty + in Map.unionWith (++) (collectExprs [e1, e2]) inst + collectExpr (Val _ _ _ val) = collectVal val + collectExpr (Binop _ _ _ _ e1 e2) = collectExprs [e1, e2] + collectExpr (Case _ _ _ e bs) = collectExprs (e : map snd bs) + collectExpr (AppTy _ _ _ e _) = collectExpr e + collectExpr (LetDiamond _ _ _ _ _ e1 e2) = collectExprs [e1, e2] + collectExpr (TryCatch _ _ _ e1 _ _ e2 e3) = collectExprs [e1, e2, e3] + collectExpr (Unpack _ _ _ _ _ e1 e2) = collectExprs [e1, e2] + collectExpr (Hole {}) = Map.empty + + collectVal :: Value ev Type -> PolyInstances + collectVal (Abs _ _ _ body) = collectExpr body + collectVal (Constr _ _ vals) = foldr (Map.unionWith (++) . collectVal) Map.empty vals + collectVal (Pure _ e) = collectExpr e + collectVal (Promote _ e) = collectExpr e + collectVal (Nec _ e) = collectExpr e + collectVal (Ref _ e) = collectExpr e + collectVal (Pack _ _ _ e _ _ _) = collectExpr e + collectVal (TyAbs _ _ e) = collectExpr e + collectVal _ = Map.empty + + -- combine results from multiple expressions + collectExprs :: [Expr ev Type] -> PolyInstances + collectExprs = foldr (Map.unionWith (++) . collectExpr) Map.empty + +-- identify polymorphic function calls and get substitution info +getPolymorphicCall :: PolyFuncs -> Expr ev Type -> Expr ev Type -> Maybe (Id, [(Id, Type)]) +getPolymorphicCall fns (Val _ _ _ (Var _ id)) arg = + case Map.lookup id fns of + Just tyVars -> + let argType = annotation arg + substitutions = case argType of + TyVar _ -> [] + _ -> [(head tyVars, argType)] + in Just (id, substitutions) + Nothing -> Nothing +getPolymorphicCall _ _ _ = Nothing + +-- create monomorphised definitions for all polymorphic function insts +makeMonoDefs :: AST ev Type -> PolyInstances -> [Def ev Type] +makeMonoDefs ast env = concatMap (monoDefsForFunc ast) (Map.toList env) + +monoDefsForFunc :: AST ev Type -> (Id, [(Id, [(Id, Type)])]) -> [Def ev Type] +monoDefsForFunc ast (id, instances) = + let og = head (filter (\def -> defId def == id) (definitions ast)) + in map (monoDef og) instances + +monoDef :: Def ev Type -> (Id, [(Id, Type)]) -> Def ev Type +monoDef (Def s _ r spec eqs ts) (id', typeSubsts) = + let subs = Map.fromList typeSubsts + subst = substTy subs + eqs' = monoEqList eqs id' subs subst + ts' = substTypeScheme ts subs subst + in Def s id' r spec eqs' ts' + +monoEqList :: EquationList ev Type -> Id -> Map.Map Id Type -> (Type -> Type) -> EquationList ev Type +monoEqList (EquationList s _ r eqs) id' subs applySubst = + let eqs' = map (monoEq id' subs applySubst) eqs + in EquationList s id' r eqs' + +monoEq :: Id -> Map.Map Id Type -> (Type -> Type) -> Equation ev Type -> Equation ev Type +monoEq id' subs applySubst (Equation s id a r ps b) = + let a' = applySubst a + ps' = map (substPat subs) ps + b' = substExpr subs b + in Equation s id' a' r ps' b' + +substTypeScheme :: TypeScheme -> Map.Map Id Type -> (Type -> Type) -> TypeScheme +substTypeScheme (Forall s bs cs ty) subs applySubst = + let bs' = filter (\(tyVar, _) -> not (Map.member tyVar subs)) bs + ty' = applySubst ty + in Forall s bs' cs ty' + +-- use typeFold with our substitution map +substTy :: Map.Map Id Type -> Type -> Type +substTy subs ty = + runIdentity $ typeFoldM (baseTypeFold {tfTyVar = substVar}) ty + where + substVar id = return $ Map.findWithDefault (TyVar id) id subs + +substPat :: Map.Map Id Type -> Pattern Type -> Pattern Type +substPat subs = + patternFold + (\s ty r id -> PVar s (substTy subs ty) r id) + (\s ty r -> PWild s (substTy subs ty) r) + (\s ty r pat -> PBox s (substTy subs ty) r pat) + (\s ty r i -> PInt s (substTy subs ty) r i) + (\s ty r f -> PFloat s (substTy subs ty) r f) + (\s ty r id ids pats -> PConstr s (substTy subs ty) r id ids pats) + +substExpr :: Map.Map Id Type -> Expr ev Type -> Expr ev Type +substExpr subs expr = + case expr of + App s ty r f arg -> App s (apply ty) r (subExp f) (subExp arg) + Val s ty r val -> Val s (apply ty) r (fmap apply val) + Binop s ty r op e1 e2 -> Binop s (apply ty) r op (subExp e1) (subExp e2) + Case s ty r e ps -> Case s (apply ty) r (subExp e) (map (bimap (substPat subs) subExp) ps) + Hole s ty r ids hs -> Hole s (apply ty) r ids hs + AppTy s ty r e t -> AppTy s (apply ty) r (subExp e) t + TryCatch s ty r e p mt e1 e2 -> TryCatch s (apply ty) r e p mt (subExp e1) (subExp e2) + Unpack s ty r tyVar var e1 e2 -> Unpack s (apply ty) r tyVar var (subExp e1) (subExp e2) + LetDiamond s ty r ps mt e1 e2 -> LetDiamond s (apply ty) r ps mt (subExp e1) (subExp e2) + where + apply = substTy subs + subExp = substExpr subs + +-- rewrite polymorphic function calls to use the monomorphised versions +rewriteCalls :: AST ev Type -> PolyInstances -> AST ev Type +rewriteCalls ast env = ast {definitions = map rewriteDef (definitions ast)} + where + rewriteDef def = def {defEquations = rewriteEqList (defEquations def)} + rewriteEqList eqs = eqs {equations = map rewriteEq (equations eqs)} + rewriteEq eq = eq {equationBody = rewriteExpr (equationBody eq)} + + rewriteExpr :: Expr ev Type -> Expr ev Type + rewriteExpr expr@(App s ty r f arg) = + let rewrittenF = rewriteExpr f + rewrittenArg = rewriteExpr arg + newF = case rewrittenF of + Val s' t' r' (Var vt id) -> + -- Only rewrite if this is a polymorphic function in our map + if Map.member id env + then + let argTy = annotation rewrittenArg + ty' = FunTy Nothing Nothing argTy ty + in Val s' ty' r' (Var ty' (makeMonoId id [argTy])) + else rewrittenF + _ -> rewrittenF + in App s ty r newF rewrittenArg + rewriteExpr (Val s ty r val) = Val s ty r (rewriteVal val) + rewriteExpr (Binop s ty r op e1 e2) = Binop s ty r op (rewriteExpr e1) (rewriteExpr e2) + rewriteExpr (Case s ty r e ps) = Case s ty r (rewriteExpr e) (map (second rewriteExpr) ps) + rewriteExpr (Hole s a b ids hs) = Hole s a b ids hs + rewriteExpr (AppTy s a b e t) = AppTy s a b (rewriteExpr e) t + rewriteExpr (TryCatch s a b e p mt e1 e2) = TryCatch s a b (rewriteExpr e) p mt (rewriteExpr e1) (rewriteExpr e2) + rewriteExpr (Unpack s a rf tyVar var e1 e2) = Unpack s a rf tyVar var (rewriteExpr e1) (rewriteExpr e2) + rewriteExpr (LetDiamond s a b ps mt e1 e2) = LetDiamond s a b ps mt (rewriteExpr e1) (rewriteExpr e2) + + rewriteVal :: Value ev Type -> Value ev Type + rewriteVal (Abs a pat mt e) = Abs a pat mt (rewriteExpr e) + rewriteVal (Constr a idv vals) = Constr a idv (map rewriteVal vals) + rewriteVal (Promote a e) = Promote a (rewriteExpr e) + rewriteVal (Pure a e) = Pure a (rewriteExpr e) + rewriteVal (Nec a e) = Nec a (rewriteExpr e) + rewriteVal (Pack s a ty e v k ty') = Pack s a ty (rewriteExpr e) v k ty' + rewriteVal (TyAbs a v e) = TyAbs a v (rewriteExpr e) + rewriteVal (NumInt n) = NumInt n + rewriteVal (NumFloat n) = NumFloat n + rewriteVal (CharLiteral ch) = CharLiteral ch + rewriteVal (StringLiteral str) = StringLiteral str + rewriteVal (Ext a ev) = Ext a ev + rewriteVal (Var a id) = Var a id diff --git a/tests/golden/positive/poly-simple.golden b/tests/golden/positive/poly-simple.golden new file mode 100644 index 0000000..d81cc07 --- /dev/null +++ b/tests/golden/positive/poly-simple.golden @@ -0,0 +1 @@ +42 diff --git a/tests/golden/positive/poly-simple.gr b/tests/golden/positive/poly-simple.gr new file mode 100644 index 0000000..9c2104f --- /dev/null +++ b/tests/golden/positive/poly-simple.gr @@ -0,0 +1,5 @@ +id : forall {a : Type} . a -> a +id x = x + +main : Int +main = id 42 From 7fb6938c3a0e794b9ed617feedcb0764bf5f7e38 Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Tue, 25 Mar 2025 19:30:26 +0000 Subject: [PATCH 2/8] saner names --- src/Language/Granule/Codegen/Monomorphise.hs | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/src/Language/Granule/Codegen/Monomorphise.hs b/src/Language/Granule/Codegen/Monomorphise.hs index 4b3ed3b..67c5f04 100644 --- a/src/Language/Granule/Codegen/Monomorphise.hs +++ b/src/Language/Granule/Codegen/Monomorphise.hs @@ -2,7 +2,6 @@ module Language.Granule.Codegen.Monomorphise (monomorphiseAST) where import Control.Monad.Identity (runIdentity) import Data.Bifunctor (Bifunctor (bimap), second) -import Data.Char (isAlphaNum, toLower) import qualified Data.Map as Map import Language.Granule.Syntax.Annotated (annotation) import Language.Granule.Syntax.Def @@ -35,19 +34,12 @@ isPolymorphic def = Forall _ ((_, Type 0) : _) _ _ -> True _ -> False --- e.g. id : a -> a when a is int becomes __id_int +-- e.g. id -> __id_3856 makeMonoId :: Id -> [Type] -> Id -makeMonoId (Id internal source) types = - let typeSuffix = concatMap (\ty -> "_" ++ typeToSafeString ty) types - name = "__" ++ internal ++ typeSuffix +makeMonoId (Id id _) types = + let hash = abs $ sum $ map fromEnum (show types) + name = "__" ++ id ++ "_" ++ show hash in Id name name - where - -- TODO: needs work so we don't make extralong names - typeToSafeString (TyCon (Id _ id)) = map sanitiseChar id - typeToSafeString t = map sanitiseChar (show t) - sanitiseChar c - | isAlphaNum c = toLower c - | otherwise = '_' -- create map of polymorphic function id to its ty vars getPolymorphicFunctions :: AST ev Type -> PolyFuncs From c07c4a24cb3f41046e506eddb52daf6bf45801d4 Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Thu, 27 Mar 2025 03:33:32 +0000 Subject: [PATCH 3/8] better support (see test/poly-curry.gr) --- granule-compiler.cabal | 4 + package.yaml | 1 + src/Language/Granule/Codegen/Emit/EmitLLVM.hs | 4 +- src/Language/Granule/Codegen/Monomorphise.hs | 73 ++++++++++++++----- tests/golden/positive/poly-curry.golden | 1 + tests/golden/positive/poly-curry.gr | 27 +++++++ 6 files changed, 89 insertions(+), 21 deletions(-) create mode 100644 tests/golden/positive/poly-curry.golden create mode 100644 tests/golden/positive/poly-curry.gr diff --git a/granule-compiler.cabal b/granule-compiler.cabal index 79e98bf..4430ef8 100644 --- a/granule-compiler.cabal +++ b/granule-compiler.cabal @@ -59,6 +59,7 @@ library base >=4.10 && <5 , containers , granule-frontend + , hashable , llvm-hs ==12.* , llvm-hs-pure ==12.* , mtl @@ -88,6 +89,7 @@ executable grlc , gitrev , granule-compiler , granule-frontend + , hashable , llvm-hs ==12.* , llvm-hs-pure ==12.* , optparse-applicative @@ -120,6 +122,7 @@ test-suite compiler-spec , filemanip , granule-compiler , granule-frontend + , hashable , hspec , mtl , process @@ -146,6 +149,7 @@ test-suite golden , filepath , granule-compiler , granule-frontend + , hashable , llvm-hs ==12.* , llvm-hs-pure ==12.* , process diff --git a/package.yaml b/package.yaml index 784c838..df5eb50 100644 --- a/package.yaml +++ b/package.yaml @@ -8,6 +8,7 @@ github: granule-project/granule-compiler-llvm dependencies: - base >=4.10 && <5 - process + - hashable default-extensions: - LambdaCase - RecordWildCards diff --git a/src/Language/Granule/Codegen/Emit/EmitLLVM.hs b/src/Language/Granule/Codegen/Emit/EmitLLVM.hs index 8d324bf..4aa5b38 100644 --- a/src/Language/Granule/Codegen/Emit/EmitLLVM.hs +++ b/src/Language/Granule/Codegen/Emit/EmitLLVM.hs @@ -123,7 +123,9 @@ emitFunction _ _ _ _ _ = error "cannot emit function with non function type" paramName :: Pattern GrType -> ParameterName paramName (PConstr _ _ _ (Id "," _) _ _) = parameterNameFromId $ mkId "pair" paramName (PConstr _ _ _ (Id "()" _) _ _) = parameterNameFromId $ mkId "unit" -paramName pat = parameterNameFromId $ head $ boundVars pat +paramName pat = case boundVars pat of + [] -> parameterNameFromId $ mkId "wildcard" + (var : _) -> parameterNameFromId var emitArg :: (MonadState EmitterState m, MonadModuleBuilder m, MonadIRBuilder m) => Pattern GrType diff --git a/src/Language/Granule/Codegen/Monomorphise.hs b/src/Language/Granule/Codegen/Monomorphise.hs index 67c5f04..7cff7a5 100644 --- a/src/Language/Granule/Codegen/Monomorphise.hs +++ b/src/Language/Granule/Codegen/Monomorphise.hs @@ -2,7 +2,10 @@ module Language.Granule.Codegen.Monomorphise (monomorphiseAST) where import Control.Monad.Identity (runIdentity) import Data.Bifunctor (Bifunctor (bimap), second) +import Data.Foldable (find) +import Data.Hashable (hash) import qualified Data.Map as Map +import Data.Maybe (fromMaybe) import Language.Granule.Syntax.Annotated (annotation) import Language.Granule.Syntax.Def import Language.Granule.Syntax.Expr hiding (subst) @@ -16,29 +19,31 @@ type PolyInstances = Map.Map Id [(Id, [(Id, Type)])] -- polymorphic id -> [ty var] type PolyFuncs = Map.Map Id [Id] --- TODO: support tyvar in any argument position and support more than 1 tyvar per function --- currently only supports single tyvars in first argument +-- TODO: +-- ensure fixed point +-- more tests -- create monomorphic versions for each required instance of polymorphic function and rewrite ast monomorphiseAST :: AST ev Type -> AST ev Type monomorphiseAST ast = let polymorphicFuncs = getPolymorphicFunctions ast env = collectInstances ast polymorphicFuncs - monoDefs = makeMonoDefs ast env - rewritten = rewriteCalls ast env - in rewritten {definitions = filter (not . isPolymorphic) (definitions rewritten) ++ monoDefs} + in if null env + then ast {definitions = filter (not . isPolymorphic) (definitions ast)} + else + let monoDefs = makeMonoDefs ast env + rewritten = rewriteCalls ast env + in monomorphiseAST (rewritten {definitions = definitions rewritten ++ monoDefs}) isPolymorphic :: Def ev Type -> Bool isPolymorphic def = case defTypeScheme def of - Forall _ ((_, Type 0) : _) _ _ -> True - _ -> False + Forall _ bindings _ _ -> any (\(_, t) -> t == Type 0) bindings -- e.g. id -> __id_3856 -makeMonoId :: Id -> [Type] -> Id -makeMonoId (Id id _) types = - let hash = abs $ sum $ map fromEnum (show types) - name = "__" ++ id ++ "_" ++ show hash +makeMonoId :: Id -> Type -> Id +makeMonoId (Id id _) ty = + let name = "__" ++ id ++ "_" ++ show (abs $ hash $ show ty) in Id name name -- create map of polymorphic function id to its ty vars @@ -68,7 +73,7 @@ collectInstances ast fns = collectExpr (App _ _ _ e1 e2) = let inst = case getPolymorphicCall fns e1 e2 of Just (id, tyVarSubsts) -> - Map.singleton id [(makeMonoId id (map snd tyVarSubsts), tyVarSubsts)] + Map.singleton id [(makeMonoId id (annotation e2), tyVarSubsts)] Nothing -> Map.empty in Map.unionWith (++) (collectExprs [e1, e2]) inst collectExpr (Val _ _ _ val) = collectVal val @@ -97,17 +102,45 @@ collectInstances ast fns = -- identify polymorphic function calls and get substitution info getPolymorphicCall :: PolyFuncs -> Expr ev Type -> Expr ev Type -> Maybe (Id, [(Id, Type)]) -getPolymorphicCall fns (Val _ _ _ (Var _ id)) arg = +getPolymorphicCall fns (Val _ ty1 _ (Var ty2 id)) _ = case Map.lookup id fns of - Just tyVars -> - let argType = annotation arg - substitutions = case argType of - TyVar _ -> [] - _ -> [(head tyVars, argType)] - in Just (id, substitutions) + Just ids -> + let substs = matchTyVars ids ty2 ty1 + in if null substs + then Nothing + else Just (id, substs) Nothing -> Nothing getPolymorphicCall _ _ _ = Nothing +matchTyVars :: [Id] -> Type -> Type -> [(Id, Type)] +matchTyVars actualIds t1 t2 = + fixIds (match t1 t2) actualIds + where + fixIds subs vars = [(findVar id vars, typ) | (id, typ) <- subs] + findVar id vars = fromMaybe id $ find (matchId id) vars + -- e.g. (Id a.1 a.1) == (Id a a`0) + matchId id var = takeWhile (/= '.') (internalName id) == sourceName var + + match t1 t2 = case (t1, t2) of + (TyVar _, TyVar _) -> [] + (TyVar id, ty) -> [(id, ty)] + (FunTy _ _ a b, FunTy _ _ a' b') -> match2 a a' b b' + (TyApp a b, TyApp a' b') -> match2 a a' b b' + (Box _ a, Box _ a') -> match a a' + (Diamond _ a, Diamond _ a') -> match a a' + (Star _ a, Star _ a') -> match a a' + (Borrow _ a, Borrow _ a') -> match a a' + (TySig a _, TySig a' _) -> match a a' + (TyInfix _ a b, TyInfix _ a' b') -> match2 a a' b b' + (TyExists _ _ a, TyExists _ _ a') -> match a a' + (TyForall _ _ a, TyForall _ _ a') -> match a a' + (TyCase a as, TyCase a' as') -> match a a' ++ concat [match2 a a' b b' | ((a, b), (a', b')) <- zip as as'] + (TySet _ ts, TySet _ ts') -> concat (zipWith match ts ts') + (TyGrade (Just t) _, TyGrade (Just t') _) -> match t t' + _ -> [] + + match2 a a' b b' = match a a' ++ match b b' + -- create monomorphised definitions for all polymorphic function insts makeMonoDefs :: AST ev Type -> PolyInstances -> [Def ev Type] makeMonoDefs ast env = concatMap (monoDefsForFunc ast) (Map.toList env) @@ -195,7 +228,7 @@ rewriteCalls ast env = ast {definitions = map rewriteDef (definitions ast)} then let argTy = annotation rewrittenArg ty' = FunTy Nothing Nothing argTy ty - in Val s' ty' r' (Var ty' (makeMonoId id [argTy])) + in Val s' ty' r' (Var ty' (makeMonoId id argTy)) else rewrittenF _ -> rewrittenF in App s ty r newF rewrittenArg diff --git a/tests/golden/positive/poly-curry.golden b/tests/golden/positive/poly-curry.golden new file mode 100644 index 0000000..7093511 --- /dev/null +++ b/tests/golden/positive/poly-curry.golden @@ -0,0 +1 @@ +(100, 42.000000) diff --git a/tests/golden/positive/poly-curry.gr b/tests/golden/positive/poly-curry.gr new file mode 100644 index 0000000..e3c8a8e --- /dev/null +++ b/tests/golden/positive/poly-curry.gr @@ -0,0 +1,27 @@ +curry : forall {a : Type, b : Type, c : Type} . + (a × b -> c) -> a -> b -> c +curry f x y = f (x, y) + +uncurry : forall {a : Type, b : Type, c : Type} . + (a -> b -> c) -> (a × b -> c) +uncurry f (x, y) = f x y + +addInt : (Int, Int) -> Int +addInt (x, y) = x + y + +addFloat : (Float, Float) -> Float +addFloat (x, y) = x + y + +intDrop : forall {a : Type} . Int -> a [0] -> Int +intDrop x [_] = x + +floatDrop : forall {a : Type} . a [0] -> Float -> Float +floatDrop [_] x = x + +swap : forall {a : Type, b : Type} . (a, b) -> (b, a) +swap (x, y) = (y, x) + +main : (Int, Float) +main = let addInt4 = curry addInt 4; + addFloat4 = curry addFloat 4.0 in + swap (swap (addInt4 (intDrop 96 [2000]), addFloat4 (floatDrop [1000] 38.0))) From b43db5b64dc5d0a3befc9666a7085892dd61771e Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Sat, 12 Apr 2025 20:47:00 +0100 Subject: [PATCH 4/8] improved and separated tyvar substitution --- granule-compiler.cabal | 1 + src/Language/Granule/Codegen/Compile.hs | 4 +- src/Language/Granule/Codegen/Monomorphise.hs | 75 +++-- src/Language/Granule/Codegen/RewriteAST.hs | 108 +------- .../Granule/Codegen/SubstituteTypes.hs | 261 ++++++++++++++++++ 5 files changed, 302 insertions(+), 147 deletions(-) create mode 100644 src/Language/Granule/Codegen/SubstituteTypes.hs diff --git a/granule-compiler.cabal b/granule-compiler.cabal index aeb4001..dcc6c42 100644 --- a/granule-compiler.cabal +++ b/granule-compiler.cabal @@ -49,6 +49,7 @@ library Language.Granule.Codegen.Emit.Types Language.Granule.Codegen.Monomorphise Language.Granule.Codegen.RewriteAST + Language.Granule.Codegen.SubstituteTypes Paths_granule_compiler hs-source-dirs: src diff --git a/src/Language/Granule/Codegen/Compile.hs b/src/Language/Granule/Codegen/Compile.hs index 1fa368c..f76aa65 100644 --- a/src/Language/Granule/Codegen/Compile.hs +++ b/src/Language/Granule/Codegen/Compile.hs @@ -10,12 +10,14 @@ import Language.Granule.Codegen.Emit.EmitLLVM import Language.Granule.Codegen.MarkGlobals import Language.Granule.Codegen.Monomorphise import Language.Granule.Codegen.RewriteAST +import Language.Granule.Codegen.SubstituteTypes import qualified LLVM.AST as IR compile :: String -> AST () Type -> Either String IR.Module compile moduleName typedAST = - let rewritten = rewriteAST typedAST + let substituted = substituteTypes typedAST + rewritten = rewriteAST substituted monomorphised = monomorphiseAST rewritten normalised = normaliseDefinitions monomorphised markedGlobals = markGlobals normalised diff --git a/src/Language/Granule/Codegen/Monomorphise.hs b/src/Language/Granule/Codegen/Monomorphise.hs index 7cff7a5..2101784 100644 --- a/src/Language/Granule/Codegen/Monomorphise.hs +++ b/src/Language/Granule/Codegen/Monomorphise.hs @@ -2,10 +2,8 @@ module Language.Granule.Codegen.Monomorphise (monomorphiseAST) where import Control.Monad.Identity (runIdentity) import Data.Bifunctor (Bifunctor (bimap), second) -import Data.Foldable (find) import Data.Hashable (hash) import qualified Data.Map as Map -import Data.Maybe (fromMaybe) import Language.Granule.Syntax.Annotated (annotation) import Language.Granule.Syntax.Def import Language.Granule.Syntax.Expr hiding (subst) @@ -16,8 +14,8 @@ import Language.Granule.Syntax.Type -- polymorphic id -> [monomorphic id, [(ty var, ty subst)]] type PolyInstances = Map.Map Id [(Id, [(Id, Type)])] --- polymorphic id -> [ty var] -type PolyFuncs = Map.Map Id [Id] +-- polymorphic id -> ty +type PolyFuncs = Map.Map Id Type -- TODO: -- ensure fixed point @@ -51,12 +49,10 @@ getPolymorphicFunctions :: AST ev Type -> PolyFuncs getPolymorphicFunctions ast = Map.fromList $ map getPolyInfo $ filter isPolymorphic $ definitions ast where - getPolyInfo :: Def ev Type -> (Id, [Id]) + getPolyInfo :: Def ev Type -> (Id, Type) getPolyInfo def = case defTypeScheme def of - Forall _ bindings _ _ -> - let tyVars = map fst bindings - in (defId def, tyVars) + Forall _ _ _ ty -> (defId def, ty) -- collect all insts of polymorphic functions with their concrete type substitutions collectInstances :: AST ev Type -> PolyFuncs -> PolyInstances @@ -71,7 +67,7 @@ collectInstances ast fns = collectExpr :: Expr ev Type -> PolyInstances collectExpr (App _ _ _ e1 e2) = - let inst = case getPolymorphicCall fns e1 e2 of + let inst = case getPolymorphicCall fns e1 of Just (id, tyVarSubsts) -> Map.singleton id [(makeMonoId id (annotation e2), tyVarSubsts)] Nothing -> Map.empty @@ -101,45 +97,38 @@ collectInstances ast fns = collectExprs = foldr (Map.unionWith (++) . collectExpr) Map.empty -- identify polymorphic function calls and get substitution info -getPolymorphicCall :: PolyFuncs -> Expr ev Type -> Expr ev Type -> Maybe (Id, [(Id, Type)]) -getPolymorphicCall fns (Val _ ty1 _ (Var ty2 id)) _ = +getPolymorphicCall :: PolyFuncs -> Expr ev Type -> Maybe (Id, [(Id, Type)]) +getPolymorphicCall fns (Val _ _ _ (Var ty id)) = case Map.lookup id fns of - Just ids -> - let substs = matchTyVars ids ty2 ty1 + Just param -> + let substs = match param ty in if null substs then Nothing else Just (id, substs) Nothing -> Nothing -getPolymorphicCall _ _ _ = Nothing - -matchTyVars :: [Id] -> Type -> Type -> [(Id, Type)] -matchTyVars actualIds t1 t2 = - fixIds (match t1 t2) actualIds - where - fixIds subs vars = [(findVar id vars, typ) | (id, typ) <- subs] - findVar id vars = fromMaybe id $ find (matchId id) vars - -- e.g. (Id a.1 a.1) == (Id a a`0) - matchId id var = takeWhile (/= '.') (internalName id) == sourceName var - - match t1 t2 = case (t1, t2) of - (TyVar _, TyVar _) -> [] - (TyVar id, ty) -> [(id, ty)] - (FunTy _ _ a b, FunTy _ _ a' b') -> match2 a a' b b' - (TyApp a b, TyApp a' b') -> match2 a a' b b' - (Box _ a, Box _ a') -> match a a' - (Diamond _ a, Diamond _ a') -> match a a' - (Star _ a, Star _ a') -> match a a' - (Borrow _ a, Borrow _ a') -> match a a' - (TySig a _, TySig a' _) -> match a a' - (TyInfix _ a b, TyInfix _ a' b') -> match2 a a' b b' - (TyExists _ _ a, TyExists _ _ a') -> match a a' - (TyForall _ _ a, TyForall _ _ a') -> match a a' - (TyCase a as, TyCase a' as') -> match a a' ++ concat [match2 a a' b b' | ((a, b), (a', b')) <- zip as as'] - (TySet _ ts, TySet _ ts') -> concat (zipWith match ts ts') - (TyGrade (Just t) _, TyGrade (Just t') _) -> match t t' - _ -> [] - - match2 a a' b b' = match a a' ++ match b b' +getPolymorphicCall _ _ = Nothing + +match :: Type -> Type -> [(Id, Type)] +match param arg = case (param, arg) of + (TyVar _, TyVar _) -> [] + (TyVar id, ty) -> [(id, ty)] + (FunTy _ _ a b, FunTy _ _ a' b') -> match2 a a' b b' + (TyApp a b, TyApp a' b') -> match2 a a' b b' + (Box _ a, Box _ a') -> match a a' + (Diamond _ a, Diamond _ a') -> match a a' + (Star _ a, Star _ a') -> match a a' + (Borrow _ a, Borrow _ a') -> match a a' + (TySig a _, TySig a' _) -> match a a' + (TyInfix _ a b, TyInfix _ a' b') -> match2 a a' b b' + (TyExists _ _ a, TyExists _ _ a') -> match a a' + (TyForall _ _ a, TyForall _ _ a') -> match a a' + (TyCase a as, TyCase a' as') -> match a a' ++ concat [match2 a a' b b' | ((a, b), (a', b')) <- zip as as'] + (TySet _ ts, TySet _ ts') -> concat (zipWith match ts ts') + (TyGrade (Just t) _, TyGrade (Just t') _) -> match t t' + _ -> [] + +match2 :: Type -> Type -> Type -> Type -> [(Id, Type)] +match2 a a' b b' = match a a' ++ match b b' -- create monomorphised definitions for all polymorphic function insts makeMonoDefs :: AST ev Type -> PolyInstances -> [Def ev Type] diff --git a/src/Language/Granule/Codegen/RewriteAST.hs b/src/Language/Granule/Codegen/RewriteAST.hs index 32f710b..48d75be 100644 --- a/src/Language/Granule/Codegen/RewriteAST.hs +++ b/src/Language/Granule/Codegen/RewriteAST.hs @@ -1,17 +1,13 @@ module Language.Granule.Codegen.RewriteAST where -import Data.Bifunctor (bimap) -import Data.List (mapAccumL) -import qualified Data.Map as Map -import Data.Maybe (fromMaybe) import Language.Granule.Syntax.Def import Language.Granule.Syntax.Expr -import Language.Granule.Syntax.Identifiers (Id) import Language.Granule.Syntax.Pattern import Language.Granule.Syntax.Type -- Rewrite Unpack ASTs into App Abs ASTs which our --- compiler already knows how to handle. WIP. +-- compiler already knows how to handle. +-- TODO: handle unpack in compile rewriteAST :: AST ev Type -> AST ev Type rewriteAST ast = ast {definitions = map rewriteDef (definitions ast)} @@ -20,15 +16,12 @@ rewriteAST ast = ast {definitions = map rewriteDef (definitions ast)} rewriteEquationList eqs = eqs {equations = map rewriteEquation (equations eqs)} rewriteEquation eq = eq {equationBody = rewriteExpr (equationBody eq)} +-- TODO: handle not top level rewriteExpr :: Expr ev Type -> Expr ev Type rewriteExpr (Unpack s retTy b tyVar var e1 e2) = - let e1' = e1 - e1Ty = exprTy e1' - e2' = e2 + let e1Ty = exprTy e1 absTy = FunTy Nothing Nothing e1Ty retTy - in fixTypes (App s retTy b (Val s absTy b (Abs absTy (PVar s e1Ty b var) Nothing e2')) e1') - where - fixTypes expr = snd $ substExpr emptyEnv expr + in App s retTy b (Val s absTy b (Abs absTy (PVar s e1Ty b var) Nothing e2)) e1 rewriteExpr exp = exp exprTy :: Expr ev Type -> Type @@ -41,94 +34,3 @@ exprTy (Hole _ ty _ _ _) = ty exprTy (AppTy _ ty _ _ _) = ty exprTy (TryCatch _ ty _ _ _ _ _ _) = ty exprTy (Unpack _ ty _ _ _ _ _) = ty - --- `let (x, y) = ` inside of an Unpack seems to leave TyVars in the AST, and these --- are not already handled by the compiler. Here we find the correct types and substitute --- the TyVars. WIP. - --- val var -> Type, type var -> Type -type Env = (Map.Map Id Type, Map.Map Id Type) - -emptyEnv :: Env -emptyEnv = (Map.empty, Map.empty) - -insertEnv :: Env -> Either Id Id -> Type -> Env -insertEnv (vals, tys) (Left id) ty = (Map.insert id ty vals, tys) -insertEnv (vals, tys) (Right id) ty = (vals, Map.insert id ty tys) - -lookupEnv :: Env -> Either Id Id -> Maybe Type -lookupEnv (vals, tys) (Left id) = Map.lookup id vals -lookupEnv (vals, tys) (Right id) = Map.lookup id tys - -substExpr :: Env -> Expr ev Type -> (Env, Expr ev Type) -substExpr env (App s ty b e1 e2) = - let (env', e2') = substExpr env e2 - (env'', e1') = substExpr env' e1 - ty' = substTy env ty - in (env'', App s ty' b e1' e2') -substExpr env (Val s ty b v) = - let (env', v') = substVal env v - ty' = substTy env' ty - in (env', Val s ty' b v') -substExpr env exp = error "TODO expr" - -substVal :: Env -> Value ev Type -> (Env, Value ev Type) -substVal env (Var (TyVar id) var) = - -- see if we already have it - case lookupEnv env (Right id) of - Just ty -> (env, Var ty var) - Nothing -> - -- see if the value variable has it - case lookupEnv env (Left var) of - -- and update - Just ty -> (insertEnv env (Right id) ty, Var ty var) - -- we wont always win - Nothing -> (env, Var (TyVar id) var) -substVal env (Var ty var) = (insertEnv env (Left var) ty, Var ty var) -substVal env (Abs ty p mt e) = - let (env', p') = substPat env p - (env'', e') = substExpr env' e - ty' = substTy env'' ty - in (env'', Abs ty' p' mt e') -substVal env (Constr ty id vals) = - let (env', vals') = mapAccumL substVal env vals - ty' = substTy env' ty - in (env', Constr ty' id vals') -substVal env (NumInt v) = (env, NumInt v) -substVal env (NumFloat v) = (env, NumFloat v) -substVal env (Promote t v) = (env, Promote t v) -substVal env val = error "TODO val" - -substPat :: Env -> Pattern Type -> (Env, Pattern Type) -substPat env (PVar s (TyVar id) b var) = - case lookupEnv env (Right id) of - Just ty -> (env, PVar s ty b var) - Nothing -> - case lookupEnv env (Left var) of - Just ty -> (insertEnv env (Right id) ty, PVar s ty b var) - Nothing -> (env, PVar s (TyVar id) b var) -substPat env (PVar s ty b var) = (insertEnv env (Left var) ty, PVar s ty b var) -substPat env (PConstr s ty b id ids ps) = - let (env', ps') = mapAccumL substPat env ps - ty' = substTy env' ty - in (env', PConstr s ty' b id ids ps') -substPat env p = error "TODO pat" - -substTy :: Env -> Type -> Type -substTy env (TyVar id) = fromMaybe (TyVar id) (lookupEnv env (Right id)) -substTy env (Type i) = Type i -substTy env (FunTy id mc arg ret) = FunTy id mc (substTy env arg) (substTy env ret) -substTy env (TyCon id) = TyCon id -substTy env (Box c t) = substTy env t -substTy env (Diamond e t) = Diamond (substTy env e) (substTy env t) -substTy env (Star g t) = substTy env t -substTy env (Borrow p t) = substTy env t -substTy env (TyApp t1 t2) = TyApp (substTy env t1) (substTy env t2) -substTy env (TyGrade mt i) = TyGrade mt i -substTy env (TyInfix op t1 t2) = TyInfix op (substTy env t1) (substTy env t2) -substTy env (TySet p ts) = TySet p (map (substTy env) ts) -substTy env (TyCase t tps) = TyCase (substTy env t) (map (bimap (substTy env) (substTy env)) tps) -substTy env (TySig t k) = TySig (substTy env t) (substTy env k) -substTy env (TyExists id k t) = substTy env t -substTy env (TyForall id k t) = substTy env t -substTy env t = t diff --git a/src/Language/Granule/Codegen/SubstituteTypes.hs b/src/Language/Granule/Codegen/SubstituteTypes.hs new file mode 100644 index 0000000..225b248 --- /dev/null +++ b/src/Language/Granule/Codegen/SubstituteTypes.hs @@ -0,0 +1,261 @@ +module Language.Granule.Codegen.SubstituteTypes where + +import Data.Bifunctor (bimap) +import qualified Data.Map as Map +import Debug.Trace +import Language.Granule.Syntax.Def +import Language.Granule.Syntax.Expr +import Language.Granule.Syntax.Identifiers +import Language.Granule.Syntax.Pattern +import Language.Granule.Syntax.Type + +-- Many Typed ASTs contain unnecessary TyVars +-- There are a few strategies to substitue these +-- 1. From parent / sibling node i.e. +-- Val (concrete) (Var (variable)), +-- App (variable) (Val (variable -> concrete)) (Val (concrete)) +-- etc. +-- 2. From value bindings +-- i.e. +-- Somewhere: Var (concrete) x +-- Somewhere else: Var (variable) x +-- 3. From type bindings +-- i.e. +-- Somewhere: Val (concrete) (Var (x)) (using ) +-- Somewhere else: Val (x) (Var (variable)) () +-- We do 1 and 3 here +-- +-- TODO: clean up, do in 1 pass if possible +substituteTypes :: AST ev Type -> AST ev Type +substituteTypes ast = ast {definitions = map retypeDef (definitions ast)} + where + retypeDef def = def {defEquations = retypeEquationList (defEquations def)} + retypeEquationList eqs = eqs {equations = map retypeEquation (equations eqs)} + retypeEquation eq = + let expr = equationBody eq + substs = collect expr + expr' = replace substs expr + in eq {equationBody = expr'} + +type VMap = Map.Map Id Type + +diff :: Map.Map Id Type -> Type -> Type -> Map.Map Id Type +diff env t1 t2 = case (t1, t2) of + (TyVar v1, TyVar v2) -> + case (Map.lookup v1 env, Map.lookup v2 env) of + (Nothing, Nothing) -> env + (Nothing, Just t2') -> Map.insert v1 t2' env + (Just t1', Nothing) -> Map.insert v2 t1' env + (Just t1', Just t2') -> diff env t1' t2' + (TyVar v1, t2) -> + case Map.lookup v1 env of + Nothing -> Map.insert v1 t2 env + Just t1 -> diff env t1 t2 + (t1, TyVar v2) -> + case Map.lookup v2 env of + Nothing -> Map.insert v2 t1 env + Just t2 -> diff env t1 t2 + (FunTy _ _ a b, FunTy _ _ a' b') -> diff (diff env b b') a a' + (TyApp a b, TyApp a' b') -> diff (diff env b b') a a' + (Box _ a, Box _ a') -> diff env a a' + (Diamond _ a, Diamond _ a') -> diff env a a' + (Star _ a, Star _ a') -> diff env a a' + (Borrow _ a, Borrow _ a') -> diff env a a' + (TySig a _, TySig a' _) -> diff env a a' + (TyInfix _ a b, TyInfix _ a' b') -> diff (diff env b b') a a' + (TyExists _ _ a, TyExists _ _ a') -> diff env a a' + (TyForall _ _ a, TyForall _ _ a') -> diff env a a' + (TyCase a as, TyCase a' as') -> + foldl + (\e ((p1, r1), (p2, r2)) -> diff (diff e p1 p2) r1 r2) + (diff env a a') + (zip as as') + (TySet _ ts, TySet _ ts') -> + foldl + (\e (t, t') -> diff e t t') + env + (zip ts ts') + (TyGrade (Just t) _, TyGrade (Just t') _) -> diff env t t' + _ -> env + +collect :: Expr ev Type -> Map.Map Id Type +collect expr = + let (env', _) = inExpr Map.empty expr + in let (env'', _) = inExpr env' expr + in env'' + where + inExpr :: Map.Map Id Type -> Expr ev Type -> (Map.Map Id Type, Type) + inExpr env expr = + case expr of + App s retTy r f arg -> + let (env', argTy) = inExpr env arg + (env'', fTy) = inExpr env' f + env''' = diff env'' (FunTy Nothing Nothing argTy retTy) fTy + in (env''', retTy) + Val s ty r val -> + let (env', ty') = inVal env val + env'' = diff env' ty ty' + in (env'', ty) + Binop s ty r op e1 e2 -> (fst $ inExpr (fst $ inExpr env e2) e1, ty) + Case s ty r e ps -> + ( foldl + (\env (p, e) -> fst $ inExpr (fst $ inPat env p) e) + (fst $ inExpr env e) + ps, + ty + ) + AppTy s ty r e t -> error "TODO: AppTy" + LetDiamond s ty r ps mt e1 e2 -> error "TODO: LetDiamond" + TryCatch s ty r e p mt e1 e2 -> error "TODO: TryCatch" + Unpack s ty r tyVar var e1 e2 -> (fst $ inExpr (fst $ inExpr env e1) e2, ty) + Hole s ty r ids hs -> error "TODO: Hole" + + inVal :: Map.Map Id Type -> Value ev Type -> (Map.Map Id Type, Type) + inVal env val = + case val of + Abs funTy p mt e -> + let (env', argTy) = inPat env p + (env'', retTy) = inExpr env' e + env''' = diff env'' (FunTy Nothing Nothing argTy retTy) funTy + in (env''', funTy) + Constr a id vs -> + (foldl (\env v -> fst $ inVal env v) env vs, a) + Promote a e -> (fst $ inExpr env e, a) + Pure a e -> (fst $ inExpr env e, a) + Nec a e -> error "TODO: Nec" + Pack s a ty e v k ty' -> + let (env', eTy) = inExpr env e + + env'' = diff env' (TyExists v k eTy) a + in trace + (show a) + trace + (show (TyExists v k eTy)) + trace + "" + (env'', a) + TyAbs a v e -> error "TODO: TyAbs" + NumInt n -> (env, TyCon (Id "Int" "Int")) + NumFloat n -> (env, TyCon (Id "Float" "Float")) + CharLiteral ch -> (env, TyCon (Id "Char" "Char")) + StringLiteral str -> (env, TyCon (Id "String" "String")) + Ext a ev -> error "TODO: Ext" + Var a id -> (env, a) + + inPat :: Map.Map Id Type -> Pattern Type -> (Map.Map Id Type, Type) + inPat env pat = + case pat of + PVar s ty r id -> (env, ty) + PWild s ty r -> (env, ty) + PBox s ty r pat' -> (fst $ inPat env pat', ty) + PInt s ty r i -> (env, ty) + PFloat s ty r f -> (env, ty) + PConstr s ty r id ids ps -> (foldl (\env p -> fst $ inPat env p) env ps, ty) + +replace :: Map.Map Id Type -> Expr ev Type -> Expr ev Type +replace env expr = inExpr env expr + where + inExpr :: Map.Map Id Type -> Expr ev Type -> Expr ev Type + inExpr env expr = + case expr of + App s ty r f arg -> + let ty' = inTy env ty + arg' = inExpr env arg + f' = inExpr env f + in App s ty' r f' arg' + Val s ty r val -> + let ty' = inTy env ty + val' = inVal env val + in Val s ty' r val' + Binop s ty r op e1 e2 -> + let ty' = inTy env ty + e2' = inExpr env e2 + e1' = inExpr env e1 + in Binop s ty' r op e1' e2' + Case s ty r e ps -> + let ty' = inTy env ty + e' = inExpr env e + ps' = map (bimap (inPat env) (inExpr env)) ps + in Case s ty' r e' ps' + AppTy s ty r e t -> error "TODO: AppTy" + LetDiamond s ty r ps mt e1 e2 -> error "TODO: LetDiamond" + TryCatch s ty r e p mt e1 e2 -> error "TODO: TryCatch" + Unpack s ty r tyVar var e1 e2 -> + let ty' = inTy env ty + e1' = inExpr env e1 + e2' = inExpr env e2 + in Unpack s ty' r tyVar var e1' e2' + Hole s ty r ids hs -> error "TODO: Hole" + + inVal :: Map.Map Id Type -> Value ev Type -> Value ev Type + inVal env val = + case val of + Abs a pat mt e -> + let a' = inTy env a + pat' = inPat env pat + e' = inExpr env e + in Abs a' pat' mt e' + Constr a id vs -> + let a' = inTy env a + vs' = map (inVal env) vs + in Constr a' id vs' + Promote a e -> + let a' = inTy env a + e' = inExpr env e + in Promote a' e' + Pure a e -> + let a' = inTy env a + e' = inExpr env e + in Pure a' e' + Nec a e -> error "TODO: Nec" + Pack s a t1 e v k t2 -> + let a' = inTy env a + t1' = inTy env t1 + e' = inExpr env e + t2' = inTy env t2 + in Pack s a' t1' e' v k t2' + TyAbs a v e -> error "TODO: TyAbs" + NumInt n -> val + NumFloat n -> val + CharLiteral ch -> val + StringLiteral str -> val + Ext a ev -> error "TODO: Ext" + Var a id -> Var (inTy env a) id + + inPat :: Map.Map Id Type -> Pattern Type -> Pattern Type + inPat env pat = + case pat of + PVar s ty r id -> PVar s (inTy env ty) r id + PWild s ty r -> PWild s (inTy env ty) r + PBox s ty r p -> PBox s (inTy env ty) r (inPat env p) + PInt s ty r i -> PInt s (inTy env ty) r i + PFloat s ty r f -> PFloat s (inTy env ty) r f + PConstr s ty r id ids ps -> PConstr s (inTy env ty) r id ids (map (inPat env) ps) + + inTy :: Map.Map Id Type -> Type -> Type + inTy env ty = + case ty of + TyVar id -> + case Map.lookup id env of + Nothing -> ty + Just ty' -> inTy env ty' + Type i -> Type i + FunTy id mc arg ret -> FunTy id mc (inTy env arg) (inTy env ret) + Box c t -> Box c (inTy env t) + Diamond e t -> Diamond e (inTy env t) + Star g t -> Star g (inTy env t) + Borrow p t -> Borrow p (inTy env t) + TyApp t1 t2 -> TyApp (inTy env t1) (inTy env t2) + TyInfix op t1 t2 -> TyInfix op (inTy env t1) (inTy env t2) + TyCase t tys -> TyCase (inTy env t) (map (bimap (inTy env) (inTy env)) tys) + TySig t k -> TySig (inTy env t) k + TyExists id k t -> TyExists id k (inTy env t) + TyForall id k t -> TyForall id k (inTy env t) + TyGrade (Just t) i -> TyGrade (Just (inTy env t)) i + TySet p ts -> TySet p (map (inTy env) ts) + TyCon {} -> ty + TyInt {} -> ty + TyRational {} -> ty + TyFraction {} -> ty + TyName {} -> ty + TyGrade {} -> ty From c6f3abb9eab6226a38599b681ea70ff42b44f30b Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Sat, 12 Apr 2025 22:53:28 +0100 Subject: [PATCH 5/8] only emit builtins which are used --- src/Language/Granule/Codegen/ClosureFreeDef.hs | 2 -- src/Language/Granule/Codegen/ConvertClosures.hs | 3 --- src/Language/Granule/Codegen/Emit/EmitBuiltins.hs | 5 +++-- src/Language/Granule/Codegen/Emit/EmitLLVM.hs | 6 ++++-- src/Language/Granule/Codegen/Emit/EmitterState.hs | 10 +++++++++- src/Language/Granule/Codegen/Emit/LowerClosure.hs | 5 ----- src/Language/Granule/Codegen/Emit/LowerExpression.hs | 9 ++++++--- 7 files changed, 22 insertions(+), 18 deletions(-) diff --git a/src/Language/Granule/Codegen/ClosureFreeDef.hs b/src/Language/Granule/Codegen/ClosureFreeDef.hs index 0626e0b..ffd109c 100644 --- a/src/Language/Granule/Codegen/ClosureFreeDef.hs +++ b/src/Language/Granule/Codegen/ClosureFreeDef.hs @@ -54,7 +54,6 @@ data ClosureMarker = CapturedVar Type Id Int | MakeClosure Id ClosureEnvironmentInit | MakeTrivialClosure Id - | MakeBuiltinClosure Id deriving (Show, Eq) data ClosureFreeAST = @@ -84,4 +83,3 @@ instance Pretty ClosureMarker where "env(ident = \"" ++ envName ++ "\", " ++ intercalate ", " (map prettyEnvVar varInits) ++ ")" in "make-closure(" ++ pretty ident ++ ", " ++ prettyEnv env ++ ")" pretty (MakeTrivialClosure ident) = pretty ident - pretty (MakeBuiltinClosure ident) = pretty ident diff --git a/src/Language/Granule/Codegen/ConvertClosures.hs b/src/Language/Granule/Codegen/ConvertClosures.hs index 6842b6f..f221252 100644 --- a/src/Language/Granule/Codegen/ConvertClosures.hs +++ b/src/Language/Granule/Codegen/ConvertClosures.hs @@ -167,8 +167,5 @@ convertClosuresFromValue (_, maybeCurrentEnv, locals) (VarF ty ident) ++ sourceName ident ++ " in environment." in return $ Ext ty (Right (CapturedVar ty ident indexInEnv)) -convertClosuresFromValue (_, maybeCurrentEnv, _) (ExtF ty (BuiltinVar _ ident)) = - return $ Ext ty $ Right $ MakeBuiltinClosure ident - convertClosuresFromValue _ other = return $ fixMapExtValue (\ty gv -> Ext ty $ Left gv) other diff --git a/src/Language/Granule/Codegen/Emit/EmitBuiltins.hs b/src/Language/Granule/Codegen/Emit/EmitBuiltins.hs index 728990d..dcf1f4c 100644 --- a/src/Language/Granule/Codegen/Emit/EmitBuiltins.hs +++ b/src/Language/Granule/Codegen/Emit/EmitBuiltins.hs @@ -16,9 +16,10 @@ import Language.Granule.Codegen.Builtins.Shared import Language.Granule.Codegen.Emit.LLVMHelpers import Language.Granule.Codegen.Emit.LowerClosure (mallocEnvironment) import Language.Granule.Codegen.Emit.LowerType (llvmType, llvmTypeForClosure, llvmTypeForFunction) +import Language.Granule.Syntax.Identifiers -emitBuiltins :: (MonadModuleBuilder m) => m [Operand] -emitBuiltins = mapM emitBuiltin builtins +emitBuiltins :: (MonadModuleBuilder m) => [Id] -> m [Operand] +emitBuiltins ids = mapM emitBuiltin (filter (\b -> any (\(Id s _) -> builtinId b == s) ids) builtins) emitBuiltin :: (MonadModuleBuilder m) => Builtin -> m Operand emitBuiltin builtin = diff --git a/src/Language/Granule/Codegen/Emit/EmitLLVM.hs b/src/Language/Granule/Codegen/Emit/EmitLLVM.hs index 672917d..10fa195 100644 --- a/src/Language/Granule/Codegen/Emit/EmitLLVM.hs +++ b/src/Language/Granule/Codegen/Emit/EmitLLVM.hs @@ -33,6 +33,7 @@ import Language.Granule.Syntax.Type hiding (Type) import Data.String (fromString) import qualified Data.Map.Strict as Map +import qualified Data.Set as Set import Control.Monad.Fix import Control.Monad.State.Strict hiding (void) @@ -40,20 +41,21 @@ import LLVM.IRBuilder (int32) emitLLVM :: String -> ClosureFreeAST -> Either String IR.Module emitLLVM moduleName (ClosureFreeAST dataDecls functionDefs valueDefs) = - let buildModule name m = evalState (buildModuleT name m) (EmitterState { localSymbols = Map.empty }) + let buildModule name m = evalState (buildModuleT name m) (EmitterState { localSymbols = Map.empty, builtins = Set.empty }) in Right $ buildModule (fromString moduleName) $ do _ <- extern (mkName "malloc") [i64] (ptr i8) _ <- extern (mkName "abort") [] void _ <- externVarArgs (mkName "printf") [ptr i8] i32 _ <- extern (mkName "llvm.memcpy.p0.p0.i32") [ptr i8, ptr i8, i32, i1] void _ <- extern (mkName "free") [ptr i8] void - _ <- emitBuiltins let mainTy = findMainReturnType valueDefs _ <- emitMainOut mainTy mapM_ emitDataDecl dataDecls mapM_ emitEnvironmentType functionDefs mapM_ emitFunctionDef functionDefs valueInitPairs <- mapM emitValueDef valueDefs + builtins <- usedBuiltins + _ <- emitBuiltins builtins emitGlobalInitializer valueInitPairs mainTy emitGlobalInitializer :: (MonadModuleBuilder m) => [(Operand, Operand)] -> GrType -> m Operand diff --git a/src/Language/Granule/Codegen/Emit/EmitterState.hs b/src/Language/Granule/Codegen/Emit/EmitterState.hs index c01c268..4ac0bd1 100644 --- a/src/Language/Granule/Codegen/Emit/EmitterState.hs +++ b/src/Language/Granule/Codegen/Emit/EmitterState.hs @@ -8,8 +8,10 @@ import LLVM.AST (Operand) import Data.Map (Map, insertWith) import qualified Data.Map as Map +import Data.Set (Set) +import qualified Data.Set as Set -data EmitterState = EmitterState { localSymbols :: Map Id Operand } +data EmitterState = EmitterState { localSymbols :: Map Id Operand, builtins :: Set Id } addLocal :: (MonadState EmitterState m) => Id @@ -40,3 +42,9 @@ local name = case local of Just op -> return op Nothing -> error $ internalName name ++ "not registered as a local, missing call to addLocal?\n" + +useBuiltin :: (MonadState EmitterState m) => Id -> m () +useBuiltin id = modify $ \s -> s { builtins = Set.insert id (builtins s) } + +usedBuiltins :: (MonadState EmitterState m) => m [Id] +usedBuiltins = Set.toList <$> gets builtins diff --git a/src/Language/Granule/Codegen/Emit/LowerClosure.hs b/src/Language/Granule/Codegen/Emit/LowerClosure.hs index f7730e5..9dcd8f8 100644 --- a/src/Language/Granule/Codegen/Emit/LowerClosure.hs +++ b/src/Language/Granule/Codegen/Emit/LowerClosure.hs @@ -58,11 +58,6 @@ emitClosureMarker ty maybeParentEnv (MakeClosure ident initializer) = emitClosureMarker ty _ (MakeTrivialClosure identifier) = return $ ConstantOperand $ makeTrivialClosure identifier ty -emitClosureMarker ty _ (MakeBuiltinClosure ident) = do - let functionPtr = ConstantOperand $ C.GlobalReference (ptr $ llvmTopLevelType ty) (functionNameFromId ident) - closure <- insertValue (ConstantOperand $ C.Undef (llvmType ty)) functionPtr [0] - insertValue closure (ConstantOperand $ C.Null (ptr i8)) [1] - emitEnvironmentInit :: (MonadModuleBuilder m, MonadIRBuilder m, MonadState EmitterState m) => [ClosureVariableInit] -> Operand diff --git a/src/Language/Granule/Codegen/Emit/LowerExpression.hs b/src/Language/Granule/Codegen/Emit/LowerExpression.hs index 15295c7..2349501 100644 --- a/src/Language/Granule/Codegen/Emit/LowerExpression.hs +++ b/src/Language/Granule/Codegen/Emit/LowerExpression.hs @@ -5,7 +5,7 @@ module Language.Granule.Codegen.Emit.LowerExpression where import Language.Granule.Codegen.ClosureFreeDef (ClosureMarker) import Language.Granule.Codegen.MarkGlobals (GlobalMarker, GlobalMarker(..)) import Language.Granule.Codegen.Emit.LowerOperator -import Language.Granule.Codegen.Emit.LowerType (llvmType) +import Language.Granule.Codegen.Emit.LowerType (llvmType, llvmTopLevelType) import Language.Granule.Codegen.Emit.EmitableDef import Language.Granule.Codegen.Emit.LowerPatterns (emitCaseArm) import Language.Granule.Codegen.Emit.LowerClosure (emitClosureMarker) @@ -29,7 +29,7 @@ import LLVM.IRBuilder.Monad import LLVM.IRBuilder.Instruction import LLVM.AST (Operand) -import LLVM.AST.Type (ptr) +import LLVM.AST.Type (ptr, i8) import LLVM.AST.Constant as C import qualified LLVM.IRBuilder.Constant as IC import qualified LLVM.AST as IR @@ -137,7 +137,10 @@ emitValue _ (ExtF a (Left (GlobalVar ty ident))) = do let ref = IR.ConstantOperand $ C.GlobalReference (ptr (llvmType ty)) (definitionNameFromId ident) load ref 4 emitValue _ (ExtF a (Left (BuiltinVar ty ident))) = do - error "TODO?" + useBuiltin ident + let functionPtr = IR.ConstantOperand $ C.GlobalReference (ptr $ llvmTopLevelType ty) (IR.mkName $ "fn." ++ sourceName ident) + closure <- insertValue (IR.ConstantOperand $ C.Undef (llvmType ty)) functionPtr [0] + insertValue closure (IR.ConstantOperand $ C.Null (ptr i8)) [1] emitValue environment (ExtF ty (Right cm)) = emitClosureMarker ty environment cm {- TODO: Support tagged unions, also affects Case. From e17339ffa90fceea11a8851c44414b3e08958926 Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Sun, 13 Apr 2025 02:04:14 +0100 Subject: [PATCH 6/8] start support for polymorphic builtins --- .../Granule/Codegen/Builtins/Builtins.hs | 6 +++++- .../Granule/Codegen/Builtins/Extras.hs | 9 ++++++++ .../Granule/Codegen/Emit/EmitBuiltins.hs | 21 ++++++++++++++++--- src/Language/Granule/Codegen/Emit/EmitLLVM.hs | 3 +-- .../Granule/Codegen/Emit/EmitterState.hs | 13 ++++++------ .../Granule/Codegen/Emit/LowerExpression.hs | 4 ++-- src/Language/Granule/Codegen/Emit/MainOut.hs | 1 + src/Language/Granule/Codegen/MarkGlobals.hs | 3 ++- src/Language/Granule/Codegen/Monomorphise.hs | 12 +++++++---- tests/golden/positive/poly-builtins.golden | 1 + tests/golden/positive/poly-builtins.gr | 11 ++++++++++ 11 files changed, 64 insertions(+), 20 deletions(-) create mode 100644 tests/golden/positive/poly-builtins.golden create mode 100644 tests/golden/positive/poly-builtins.gr diff --git a/src/Language/Granule/Codegen/Builtins/Builtins.hs b/src/Language/Granule/Codegen/Builtins/Builtins.hs index f4aba6e..4c52cdb 100644 --- a/src/Language/Granule/Codegen/Builtins/Builtins.hs +++ b/src/Language/Granule/Codegen/Builtins/Builtins.hs @@ -17,8 +17,12 @@ builtins = readFloatArrayDef, writeFloatArrayDef, lengthFloatArrayDef, - deleteFloatArrayDef + deleteFloatArrayDef, + useDef ] builtinIds :: [Id] builtinIds = map (mkId . builtinId) builtins + +polyBuiltinIds :: [Id] +polyBuiltinIds = map (mkId . builtinId) [useDef] diff --git a/src/Language/Granule/Codegen/Builtins/Extras.hs b/src/Language/Granule/Codegen/Builtins/Extras.hs index 63e96b8..e278fbd 100644 --- a/src/Language/Granule/Codegen/Builtins/Extras.hs +++ b/src/Language/Granule/Codegen/Builtins/Extras.hs @@ -25,3 +25,12 @@ divDef = args = [TyCon (Id "Int" "Int"), TyCon (Id "Int" "Int")] ret = TyCon (Id "Int" "Int") impl [x, y] = sdiv x y + +-- use :: a -> a [1] +useDef :: Builtin +useDef = + Builtin "use" args ret impl + where + args = [Gr.tyVar "a"] + ret = Box (TyGrade Nothing 1) (Gr.tyVar "a") + impl [val] = return val diff --git a/src/Language/Granule/Codegen/Emit/EmitBuiltins.hs b/src/Language/Granule/Codegen/Emit/EmitBuiltins.hs index dcf1f4c..3256879 100644 --- a/src/Language/Granule/Codegen/Emit/EmitBuiltins.hs +++ b/src/Language/Granule/Codegen/Emit/EmitBuiltins.hs @@ -6,7 +6,7 @@ import Control.Monad (forM) import LLVM.AST (Operand) import qualified LLVM.AST as IR import qualified LLVM.AST.Constant as C -import LLVM.AST.Type hiding (Type) +import LLVM.AST.Type hiding (resultType, Type) import LLVM.IRBuilder.Constant (int32) import LLVM.IRBuilder.Instruction import LLVM.IRBuilder.Module @@ -17,9 +17,24 @@ import Language.Granule.Codegen.Emit.LLVMHelpers import Language.Granule.Codegen.Emit.LowerClosure (mallocEnvironment) import Language.Granule.Codegen.Emit.LowerType (llvmType, llvmTypeForClosure, llvmTypeForFunction) import Language.Granule.Syntax.Identifiers +import Language.Granule.Syntax.Type -emitBuiltins :: (MonadModuleBuilder m) => [Id] -> m [Operand] -emitBuiltins ids = mapM emitBuiltin (filter (\b -> any (\(Id s _) -> builtinId b == s) ids) builtins) + +emitBuiltins :: (MonadModuleBuilder m) => [(Id, Type)] -> m () +emitBuiltins uses = do + let instances = foldMap (\(id, ty) -> [(b, id, ty) | b <- builtins, builtinId b == sourceName id]) uses + mapM_ emitInstance instances + where + emitInstance (builtin, Id s i, ty) + | s == i = emitBuiltin builtin + | otherwise = -- we change the internal name for specialisations + emitBuiltin + ( builtin + { builtinId = i, + builtinArgTys = parameterTypes ty, + builtinRetTy = resultType ty + } + ) emitBuiltin :: (MonadModuleBuilder m) => Builtin -> m Operand emitBuiltin builtin = diff --git a/src/Language/Granule/Codegen/Emit/EmitLLVM.hs b/src/Language/Granule/Codegen/Emit/EmitLLVM.hs index 10fa195..0f1ccc5 100644 --- a/src/Language/Granule/Codegen/Emit/EmitLLVM.hs +++ b/src/Language/Granule/Codegen/Emit/EmitLLVM.hs @@ -33,7 +33,6 @@ import Language.Granule.Syntax.Type hiding (Type) import Data.String (fromString) import qualified Data.Map.Strict as Map -import qualified Data.Set as Set import Control.Monad.Fix import Control.Monad.State.Strict hiding (void) @@ -41,7 +40,7 @@ import LLVM.IRBuilder (int32) emitLLVM :: String -> ClosureFreeAST -> Either String IR.Module emitLLVM moduleName (ClosureFreeAST dataDecls functionDefs valueDefs) = - let buildModule name m = evalState (buildModuleT name m) (EmitterState { localSymbols = Map.empty, builtins = Set.empty }) + let buildModule name m = evalState (buildModuleT name m) (EmitterState { localSymbols = Map.empty, builtins = Map.empty }) in Right $ buildModule (fromString moduleName) $ do _ <- extern (mkName "malloc") [i64] (ptr i8) _ <- extern (mkName "abort") [] void diff --git a/src/Language/Granule/Codegen/Emit/EmitterState.hs b/src/Language/Granule/Codegen/Emit/EmitterState.hs index 4ac0bd1..2e669bc 100644 --- a/src/Language/Granule/Codegen/Emit/EmitterState.hs +++ b/src/Language/Granule/Codegen/Emit/EmitterState.hs @@ -2,16 +2,15 @@ module Language.Granule.Codegen.Emit.EmitterState where import Language.Granule.Syntax.Identifiers (Id, internalName) +import Language.Granule.Syntax.Type import Control.Monad.State.Strict hiding (void) import LLVM.AST (Operand) import Data.Map (Map, insertWith) import qualified Data.Map as Map -import Data.Set (Set) -import qualified Data.Set as Set -data EmitterState = EmitterState { localSymbols :: Map Id Operand, builtins :: Set Id } +data EmitterState = EmitterState { localSymbols :: Map Id Operand, builtins :: Map Id Type } addLocal :: (MonadState EmitterState m) => Id @@ -43,8 +42,8 @@ local name = Just op -> return op Nothing -> error $ internalName name ++ "not registered as a local, missing call to addLocal?\n" -useBuiltin :: (MonadState EmitterState m) => Id -> m () -useBuiltin id = modify $ \s -> s { builtins = Set.insert id (builtins s) } +useBuiltin :: (MonadState EmitterState m) => Id -> Type -> m () +useBuiltin id ty = modify $ \s -> s { builtins = Map.insert id ty (builtins s) } -usedBuiltins :: (MonadState EmitterState m) => m [Id] -usedBuiltins = Set.toList <$> gets builtins +usedBuiltins :: (MonadState EmitterState m) => m [(Id, Type)] +usedBuiltins = Map.toList <$> gets builtins diff --git a/src/Language/Granule/Codegen/Emit/LowerExpression.hs b/src/Language/Granule/Codegen/Emit/LowerExpression.hs index 2349501..9f3f6fa 100644 --- a/src/Language/Granule/Codegen/Emit/LowerExpression.hs +++ b/src/Language/Granule/Codegen/Emit/LowerExpression.hs @@ -137,8 +137,8 @@ emitValue _ (ExtF a (Left (GlobalVar ty ident))) = do let ref = IR.ConstantOperand $ C.GlobalReference (ptr (llvmType ty)) (definitionNameFromId ident) load ref 4 emitValue _ (ExtF a (Left (BuiltinVar ty ident))) = do - useBuiltin ident - let functionPtr = IR.ConstantOperand $ C.GlobalReference (ptr $ llvmTopLevelType ty) (IR.mkName $ "fn." ++ sourceName ident) + useBuiltin ident ty + let functionPtr = IR.ConstantOperand $ C.GlobalReference (ptr $ llvmTopLevelType ty) (IR.mkName $ "fn." ++ internalName ident) closure <- insertValue (IR.ConstantOperand $ C.Undef (llvmType ty)) functionPtr [0] insertValue closure (IR.ConstantOperand $ C.Null (ptr i8)) [1] emitValue environment (ExtF ty (Right cm)) = diff --git a/src/Language/Granule/Codegen/Emit/MainOut.hs b/src/Language/Granule/Codegen/Emit/MainOut.hs index e79efb3..b8a51c1 100644 --- a/src/Language/Granule/Codegen/Emit/MainOut.hs +++ b/src/Language/Granule/Codegen/Emit/MainOut.hs @@ -65,4 +65,5 @@ fmtStrForTy x = (TyApp (TyCon (Id "FloatArray" _)) _) -> "" (TyCon (Id "()" _)) -> "()" (TyExists _ _ (Borrow _ ty)) -> "*" ++ fmtStrForTy ty + (Box _ ty) -> "[" ++ fmtStrForTy ty ++ "]" _ -> error ("Unsupported Main type: " ++ show x) diff --git a/src/Language/Granule/Codegen/MarkGlobals.hs b/src/Language/Granule/Codegen/MarkGlobals.hs index 64dc0bc..4bdc303 100644 --- a/src/Language/Granule/Codegen/MarkGlobals.hs +++ b/src/Language/Granule/Codegen/MarkGlobals.hs @@ -35,7 +35,8 @@ markGlobalsInExpr :: [Id] -> Expr () Type -> Expr GlobalMarker Type markGlobalsInExpr globals = bicata fixMapExtExpr markInValue where markInValue (VarF ty ident) - | ident `elem` builtinIds = Ext ty (BuiltinVar ty ident) + | any (\id -> sourceName ident == sourceName id) builtinIds = + Ext ty (BuiltinVar ty ident) | ident `elem` globals = Ext ty (GlobalVar ty ident) | otherwise = Var ty ident markInValue other = diff --git a/src/Language/Granule/Codegen/Monomorphise.hs b/src/Language/Granule/Codegen/Monomorphise.hs index 2101784..b539256 100644 --- a/src/Language/Granule/Codegen/Monomorphise.hs +++ b/src/Language/Granule/Codegen/Monomorphise.hs @@ -4,6 +4,7 @@ import Control.Monad.Identity (runIdentity) import Data.Bifunctor (Bifunctor (bimap), second) import Data.Hashable (hash) import qualified Data.Map as Map +import Language.Granule.Codegen.Builtins.Builtins (polyBuiltinIds) import Language.Granule.Syntax.Annotated (annotation) import Language.Granule.Syntax.Def import Language.Granule.Syntax.Expr hiding (subst) @@ -27,7 +28,9 @@ monomorphiseAST ast = let polymorphicFuncs = getPolymorphicFunctions ast env = collectInstances ast polymorphicFuncs in if null env - then ast {definitions = filter (not . isPolymorphic) (definitions ast)} + then -- we still need to rewrite builtins + let rewritten = rewriteCalls ast Map.empty + in rewritten {definitions = filter (not . isPolymorphic) (definitions rewritten)} else let monoDefs = makeMonoDefs ast env rewritten = rewriteCalls ast env @@ -41,8 +44,8 @@ isPolymorphic def = -- e.g. id -> __id_3856 makeMonoId :: Id -> Type -> Id makeMonoId (Id id _) ty = - let name = "__" ++ id ++ "_" ++ show (abs $ hash $ show ty) - in Id name name + let monoId = "__" ++ id ++ "_" ++ show (abs $ hash $ show ty) + in Id id monoId -- create map of polymorphic function id to its ty vars getPolymorphicFunctions :: AST ev Type -> PolyFuncs @@ -213,7 +216,8 @@ rewriteCalls ast env = ast {definitions = map rewriteDef (definitions ast)} newF = case rewrittenF of Val s' t' r' (Var vt id) -> -- Only rewrite if this is a polymorphic function in our map - if Map.member id env + -- or polymorphic builtin + if Map.member id env || id `elem` polyBuiltinIds then let argTy = annotation rewrittenArg ty' = FunTy Nothing Nothing argTy ty diff --git a/tests/golden/positive/poly-builtins.golden b/tests/golden/positive/poly-builtins.golden new file mode 100644 index 0000000..8aa4ae5 --- /dev/null +++ b/tests/golden/positive/poly-builtins.golden @@ -0,0 +1 @@ +([[42.000000]], [[100]]) diff --git a/tests/golden/positive/poly-builtins.gr b/tests/golden/positive/poly-builtins.gr new file mode 100644 index 0000000..11dcbec --- /dev/null +++ b/tests/golden/positive/poly-builtins.gr @@ -0,0 +1,11 @@ +use' : forall {a : Type} . a -> a [1] +use' x = use x + +use2 : forall {a b : Type} . (a, b) -> (a [1], b [1]) +use2 (x, y) = (use x, use y) + +main : ((Float [1]) [1], (Int [1]) [1]) +main = + let a = use' 42.0; + b = use' 100 in + use2 (a, b) From 0d755727788f8172b415350f13a746041e5ec1e5 Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Sun, 13 Apr 2025 11:05:09 +0100 Subject: [PATCH 7/8] capture more tyvars --- src/Language/Granule/Codegen/SubstituteTypes.hs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/Language/Granule/Codegen/SubstituteTypes.hs b/src/Language/Granule/Codegen/SubstituteTypes.hs index 225b248..220d031 100644 --- a/src/Language/Granule/Codegen/SubstituteTypes.hs +++ b/src/Language/Granule/Codegen/SubstituteTypes.hs @@ -107,7 +107,11 @@ collect expr = AppTy s ty r e t -> error "TODO: AppTy" LetDiamond s ty r ps mt e1 e2 -> error "TODO: LetDiamond" TryCatch s ty r e p mt e1 e2 -> error "TODO: TryCatch" - Unpack s ty r tyVar var e1 e2 -> (fst $ inExpr (fst $ inExpr env e1) e2, ty) + Unpack s ty r tyVar var e1 e2 -> + let (env', _) = inExpr env e1 + (env'', retTy) = inExpr env' e2 + env''' = diff env'' ty retTy + in (env''', ty) Hole s ty r ids hs -> error "TODO: Hole" inVal :: Map.Map Id Type -> Value ev Type -> (Map.Map Id Type, Type) From 576b5f36db921a9ed5141d60a03c0bc4fb278b5f Mon Sep 17 00:00:00 2001 From: Jacob Pake Date: Sun, 13 Apr 2025 12:40:12 +0100 Subject: [PATCH 8/8] add "specialisables" --- .../Granule/Codegen/Builtins/Builtins.hs | 17 ++++++++++++----- .../Granule/Codegen/Builtins/Extras.hs | 8 +++----- .../Granule/Codegen/Builtins/Shared.hs | 13 +++++++++++++ .../Granule/Codegen/Emit/EmitBuiltins.hs | 18 +++++------------- 4 files changed, 33 insertions(+), 23 deletions(-) diff --git a/src/Language/Granule/Codegen/Builtins/Builtins.hs b/src/Language/Granule/Codegen/Builtins/Builtins.hs index 4c52cdb..91064e8 100644 --- a/src/Language/Granule/Codegen/Builtins/Builtins.hs +++ b/src/Language/Granule/Codegen/Builtins/Builtins.hs @@ -17,12 +17,19 @@ builtins = readFloatArrayDef, writeFloatArrayDef, lengthFloatArrayDef, - deleteFloatArrayDef, - useDef + deleteFloatArrayDef ] -builtinIds :: [Id] -builtinIds = map (mkId . builtinId) builtins +specialisable :: [Specialisable] +specialisable = + [ useDef + ] + +monoBuiltinIds :: [Id] +monoBuiltinIds = map (mkId . builtinId) builtins polyBuiltinIds :: [Id] -polyBuiltinIds = map (mkId . builtinId) [useDef] +polyBuiltinIds = map (mkId . specialisableId) specialisable + +builtinIds :: [Id] +builtinIds = monoBuiltinIds ++ polyBuiltinIds diff --git a/src/Language/Granule/Codegen/Builtins/Extras.hs b/src/Language/Granule/Codegen/Builtins/Extras.hs index e278fbd..9c09fb7 100644 --- a/src/Language/Granule/Codegen/Builtins/Extras.hs +++ b/src/Language/Granule/Codegen/Builtins/Extras.hs @@ -27,10 +27,8 @@ divDef = impl [x, y] = sdiv x y -- use :: a -> a [1] -useDef :: Builtin +useDef :: Specialisable useDef = - Builtin "use" args ret impl + Specialisable "use" impl where - args = [Gr.tyVar "a"] - ret = Box (TyGrade Nothing 1) (Gr.tyVar "a") - impl [val] = return val + impl _ [val] = return val diff --git a/src/Language/Granule/Codegen/Builtins/Shared.hs b/src/Language/Granule/Codegen/Builtins/Shared.hs index 29c52c5..9e72691 100644 --- a/src/Language/Granule/Codegen/Builtins/Shared.hs +++ b/src/Language/Granule/Codegen/Builtins/Shared.hs @@ -1,5 +1,6 @@ {-# LANGUAGE RankNTypes #-} {-# OPTIONS_GHC -Wno-incomplete-patterns #-} +{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} module Language.Granule.Codegen.Builtins.Shared where @@ -21,6 +22,18 @@ data Builtin = Builtin { builtinRetTy :: Gr.Type, builtinImpl :: forall m. (MonadModuleBuilder m, MonadIRBuilder m) => [Operand] -> m Operand} +data Specialisable = Specialisable { + specialisableId :: String, + specialisableImpl :: [Gr.Type] -> forall m. (MonadModuleBuilder m, MonadIRBuilder m) => [Operand] -> m Operand} + +specialise :: Specialisable -> String -> Gr.Type -> Builtin +specialise builtin id ty = Builtin id args ret impl + where + args = Gr.parameterTypes ty + ret = Gr.resultType ty + {-# HLINT ignore "Eta reduce" #-} -- has to be lazy or we need IRBuilder early + impl xs = specialisableImpl builtin args xs + -- LLVM helpers allocate :: (MonadIRBuilder m, MonadModuleBuilder m) => Operand -> IR.Type -> m Operand diff --git a/src/Language/Granule/Codegen/Emit/EmitBuiltins.hs b/src/Language/Granule/Codegen/Emit/EmitBuiltins.hs index 3256879..10ae8f6 100644 --- a/src/Language/Granule/Codegen/Emit/EmitBuiltins.hs +++ b/src/Language/Granule/Codegen/Emit/EmitBuiltins.hs @@ -21,20 +21,12 @@ import Language.Granule.Syntax.Type emitBuiltins :: (MonadModuleBuilder m) => [(Id, Type)] -> m () -emitBuiltins uses = do - let instances = foldMap (\(id, ty) -> [(b, id, ty) | b <- builtins, builtinId b == sourceName id]) uses - mapM_ emitInstance instances +emitBuiltins uses = mapM_ emitBuiltin (monos ++ polys) where - emitInstance (builtin, Id s i, ty) - | s == i = emitBuiltin builtin - | otherwise = -- we change the internal name for specialisations - emitBuiltin - ( builtin - { builtinId = i, - builtinArgTys = parameterTypes ty, - builtinRetTy = resultType ty - } - ) + monos = + [b | (id, _) <- uses, b <- builtins, builtinId b == sourceName id] + polys = + [specialise b (internalName id) ty | (id, ty) <- uses, b <- specialisable, specialisableId b == sourceName id] emitBuiltin :: (MonadModuleBuilder m) => Builtin -> m Operand emitBuiltin builtin =