diff --git a/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Execute.hs b/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Execute.hs index 63ee0739a..857d5fa17 100644 --- a/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Execute.hs +++ b/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Execute.hs @@ -53,7 +53,7 @@ import qualified Data.Array.Accelerate.LLVM.Native.Debug as Debug import Control.Concurrent ( myThreadId ) import Control.Concurrent.Extra ( getThreadId ) -import Control.Monad.State ( gets ) +import Control.Monad.Reader ( asks ) import Control.Monad.Trans ( liftIO ) import Data.ByteString.Short ( ShortByteString ) import Data.IORef ( newIORef, readIORef, writeIORef ) @@ -139,7 +139,7 @@ simpleOp simpleOp name repr NativeR{..} gamma aenv sh = do let fun = nativeExecutable !# name param = TupRsingle $ ParamRarray repr - Native{..} <- gets llvmTarget + Native{..} <- asks llvmTarget future <- new result <- allocateRemote repr sh scheduleOp fun gamma aenv (arrayRshape repr) sh param result @@ -167,7 +167,7 @@ mapOp inplace repr tp NativeR{..} gamma aenv input = do shr = arrayRshape repr repr' = ArrayR shr tp param = TupRsingle (ParamRarray repr') `TupRpair` TupRsingle (ParamRarray repr) - Native{..} <- gets llvmTarget + Native{..} <- asks llvmTarget future <- new result <- case inplace of Just Refl -> return input @@ -201,7 +201,7 @@ transformOp -> Par Native (Future (Array sh' b)) transformOp repr repr' NativeR{..} gamma aenv sh' input = do let fun = nativeExecutable !# "transform" - Native{..} <- gets llvmTarget + Native{..} <- asks llvmTarget future <- new result <- allocateRemote repr' sh' let param = TupRsingle (ParamRarray repr') `TupRpair` TupRsingle (ParamRarray repr) @@ -300,7 +300,7 @@ foldAllOp -> Delayed (Vector e) -> Par Native (Future (Scalar e)) foldAllOp tp NativeR{..} gamma aenv arr = do - Native{..} <- gets llvmTarget + Native{..} <- asks llvmTarget future <- new result <- allocateRemote 
(ArrayR dim0 tp) () let @@ -343,7 +343,7 @@ foldDimOp -> Delayed (Array (sh, Int) e) -> Par Native (Future (Array sh e)) foldDimOp repr NativeR{..} gamma aenv arr@(delayedShape -> (sh, _)) = do - Native{..} <- gets llvmTarget + Native{..} <- asks llvmTarget future <- new result <- allocateRemote repr sh let @@ -371,7 +371,7 @@ foldSegOp -> Delayed (Segments i) -> Par Native (Future (Array (sh, Int) e)) foldSegOp iR repr NativeR{..} gamma aenv input@(delayedShape -> (sh, _)) segments@(delayedShape -> ((), ss)) = do - Native{..} <- gets llvmTarget + Native{..} <- asks llvmTarget future <- new let n = ss-1 @@ -428,7 +428,7 @@ scanCore -> Delayed (Array (sh, Int) e) -> Par Native (Future (Array (sh, Int) e)) scanCore repr NativeR{..} gamma aenv m input@(delayedShape -> (sz, n)) = do - Native{..} <- gets llvmTarget + Native{..} <- asks llvmTarget future <- new result <- allocateRemote repr (sz, m) -- @@ -527,7 +527,7 @@ scan'Core repr NativeR{..} gamma aenv input@(delayedShape -> sh@(sz, n)) = do paramA = TupRsingle $ ParamRarray repr paramA' = TupRsingle $ ParamRarray repr' -- - Native{..} <- gets llvmTarget + Native{..} <- asks llvmTarget future <- new result <- allocateRemote repr sh sums <- allocateRemote repr' sz @@ -608,7 +608,7 @@ permuteOp inplace repr shr' NativeR{..} gamma aenv defaults@(shape -> shOut) inp let ArrayR shr tp = repr repr' = ArrayR shr' tp - Native{..} <- gets llvmTarget + Native{..} <- asks llvmTarget future <- new result <- if inplace then Debug.trace Debug.dump_exec "exec: permute/inplace" $ return defaults @@ -701,7 +701,7 @@ stencilCore -> params -> Par Native (Future (Array sh e)) stencilCore repr NativeR{..} gamma aenv halo sh paramsR params = do - Native{..} <- gets llvmTarget + Native{..} <- asks llvmTarget future <- new result <- allocateRemote repr sh let @@ -815,7 +815,7 @@ scheduleOp -> Maybe Action -> Par Native () scheduleOp fun gamma aenv shr sz paramsR params done = do - Native{..} <- gets llvmTarget + Native{..} <- asks 
llvmTarget let splits = numWorkers workers - 1 minsize = case shr of @@ -842,7 +842,7 @@ scheduleOpWith -> Maybe Action -- run after the last piece completes -> Par Native () scheduleOpWith splits minsize fun gamma aenv shr sz paramsR params done = do - Native{..} <- gets llvmTarget + Native{..} <- asks llvmTarget job <- mkJob splits minsize fun gamma aenv shr (empty shr) sz paramsR params done liftIO $ schedule workers job @@ -858,7 +858,7 @@ scheduleOpUsing -> Maybe Action -> Par Native () scheduleOpUsing ranges fun gamma aenv shr paramsR params jobDone = do - Native{..} <- gets llvmTarget + Native{..} <- asks llvmTarget job <- mkJobUsing ranges fun gamma aenv shr paramsR params jobDone liftIO $ schedule workers job @@ -919,7 +919,7 @@ mkTasksUsing -> params -> Par Native (Seq Action) mkTasksUsing ranges (name, f) gamma aenv shr paramsR params = do - arg <- marshalParams' @Native (paramsR `TupRpair` TupRsingle (ParamRenv gamma)) (params, aenv) + (arg, ()) <- marshalParams' @Native (paramsR `TupRpair` TupRsingle (ParamRenv gamma)) (params, aenv) return $ flip fmap ranges $ \(_,u,v) -> do sched (string % " " % parenthesised string % " -> " % parenthesised string) (S8.unpack name) (showShape shr u) (showShape shr v) let argU = marshalShape' @Native shr u @@ -937,7 +937,7 @@ mkTasksUsingIndex -> params -> Par Native (Seq Action) mkTasksUsingIndex ranges (name, f) gamma aenv shr paramsR params = do - arg <- marshalParams' @Native (paramsR `TupRpair` TupRsingle (ParamRenv gamma)) (params, aenv) + (arg, ()) <- marshalParams' @Native (paramsR `TupRpair` TupRsingle (ParamRenv gamma)) (params, aenv) return $ flip fmap ranges $ \(i,u,v) -> do sched (string % " " % parenthesised string % " -> " % parenthesised string) (S8.unpack name) (showShape shr u) (showShape shr v) let argU = marshalShape' @Native shr u diff --git a/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Execute/Async.hs 
b/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Execute/Async.hs index a951d358a..6dceff996 100644 --- a/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Execute/Async.hs +++ b/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Execute/Async.hs @@ -33,7 +33,7 @@ import Data.Array.Accelerate.LLVM.State -- standard library import Control.Concurrent import Control.Monad.Cont -import Control.Monad.State +import Control.Monad.Reader import Data.IORef import Data.Sequence ( Seq ) import qualified Data.Sequence as Seq @@ -78,7 +78,7 @@ data IVar a instance Async Native where type FutureR Native = Future newtype Par Native a = Par { runPar :: ContT () (LLVM Native) a } - deriving ( Functor, Applicative, Monad, MonadIO, MonadCont, MonadState Native ) + deriving ( Functor, Applicative, Monad, MonadIO, MonadCont, MonadReader Native ) {-# INLINE new #-} {-# INLINE newFull #-} @@ -93,7 +93,7 @@ instance Async Native where {-# INLINE get #-} get (Future ref) = callCC $ \k -> do - native <- gets llvmTarget + native <- asks llvmTarget next <- liftIO . atomicModifyIORef' ref $ \case Empty -> (Blocked (Seq.singleton (evalParIO native . k)), reschedule) Blocked ks -> (Blocked (ks Seq.|> evalParIO native . 
k), reschedule) @@ -102,7 +102,7 @@ instance Async Native where {-# INLINE put #-} put future ref = do - Native{..} <- gets llvmTarget + Native{..} <- asks llvmTarget liftIO (putIO workers future ref) {-# INLINE liftPar #-} diff --git a/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Execute/Marshal.hs b/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Execute/Marshal.hs index 3ae68b08c..2dd44dcc2 100644 --- a/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Execute/Marshal.hs +++ b/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Execute/Marshal.hs @@ -5,6 +5,7 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# OPTIONS_GHC -fno-warn-orphans #-} @@ -34,9 +35,10 @@ import qualified Foreign.LibFFI as FFI instance Marshal Native where type ArgR Native = FFI.Arg + type MarshalCleanup Native = () marshalInt = $( case finiteBitSize (undefined::Int) of 32 -> [| FFI.argInt32 . fromIntegral |] 64 -> [| FFI.argInt64 . fromIntegral |] _ -> error "I don't know what architecture I am" ) - marshalScalarData' _ = return . DL.singleton . FFI.argPtr . unsafeUniqueArrayPtr + marshalScalarData' _ = return . (,()) . DL.singleton . FFI.argPtr . 
unsafeUniqueArrayPtr diff --git a/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Link.hs b/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Link.hs index 830c4afb6..60408ed4d 100644 --- a/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Link.hs +++ b/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Link.hs @@ -33,7 +33,7 @@ import Data.Array.Accelerate.LLVM.Native.Link.Cache import Data.Array.Accelerate.LLVM.Native.Link.Object import Data.Array.Accelerate.LLVM.Native.Link.Runtime -import Control.Monad.State +import Control.Monad.Reader import Prelude hiding ( lookup ) @@ -48,7 +48,7 @@ instance Link Native where -- link :: ObjectR Native -> LLVM Native (ExecutableR Native) link (ObjectR uid nms _ so) = do - cache <- gets linkCache + cache <- asks linkCache funs <- liftIO $ dlsym uid cache (loadSharedObject nms so) return $! NativeR funs diff --git a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Array/Data.hs b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Array/Data.hs index 9ec37b611..62b168f6c 100644 --- a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Array/Data.hs +++ b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Array/Data.hs @@ -44,8 +44,8 @@ import qualified Data.Array.Accelerate.LLVM.PTX.Array.Prim as Prim import Control.Applicative import Control.Monad -import Control.Monad.Reader -import Control.Monad.State ( gets ) +import Control.Monad.IO.Class ( liftIO ) +import Control.Monad.Reader ( asks ) import System.IO.Unsafe import Prelude @@ -99,7 +99,7 @@ copyToHostLazy (TupRpair r1 r2) (f1, f2) = do a2 <- copyToHostLazy r2 f2 return (a1, a2) copyToHostLazy (TupRsingle (ArrayR shr tp)) future = do - ptx <- gets llvmTarget + ptx <- asks llvmTarget liftIO $ do Array sh adata <- wait future diff --git a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Array/Prim.hs b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Array/Prim.hs index 2b901dc44..e90a5cbe9 100644 --- 
a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Array/Prim.hs +++ b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Array/Prim.hs @@ -116,7 +116,7 @@ pokeArrayAsync !t !n !ad let !src = CUDA.HostPtr (unsafeUniqueArrayPtr ad) !bytes = n * bytesElt (TupRsingle (SingleScalarType t)) -- - stream <- asks ptxStream + stream <- asksParState ptxStream result <- liftPar $ withLifetime stream $ \st -> withDevicePtr t ad $ \dst -> @@ -150,7 +150,7 @@ indexArrayAsync !n !t !ad_src !i let !bytes = n * bytesElt (TupRsingle (SingleScalarType t)) !dst = CUDA.HostPtr (unsafeUniqueArrayPtr ad_dst) -- - stream <- asks ptxStream + stream <- asksParState ptxStream result <- liftPar $ withLifetime stream $ \st -> withDevicePtr t ad_src $ \src -> @@ -179,7 +179,7 @@ peekArrayAsync !t !n !ad let !bytes = n * bytesElt (TupRsingle (SingleScalarType t)) !dst = CUDA.HostPtr (unsafeUniqueArrayPtr ad) -- - stream <- asks ptxStream + stream <- asksParState ptxStream result <- liftPar $ withLifetime stream $ \st -> withDevicePtr t ad $ \src -> @@ -208,7 +208,7 @@ copyArrayAsync !t !n !ad_src !ad_dst = do let !bytes = n * bytesElt (TupRsingle (SingleScalarType t)) -- - stream <- asks ptxStream + stream <- asksParState ptxStream result <- liftPar $ withLifetime stream $ \st -> withDevicePtr t ad_src $ \src -> @@ -287,7 +287,7 @@ memsetArrayAsync !t !n !v !ad = do let !bytes = n * bytesElt (TupRsingle (SingleScalarType t)) -- - stream <- asks ptxStream + stream <- asksParState ptxStream result <- liftPar $ withLifetime stream $ \st -> withDevicePtr t ad $ \ptr -> @@ -350,7 +350,7 @@ nonblocking !stream !action = do return (Nothing, future) else do - future <- Future <$> liftIO (newIORef (Pending event Nothing result)) + future <- Future <$> liftIO (newIORef (Pending event (return ()) result)) return (Just event, future) {-# INLINE withLifetime #-} diff --git a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Array/Remote.hs 
b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Array/Remote.hs index 1efea998f..446f35a0f 100644 --- a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Array/Remote.hs +++ b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Array/Remote.hs @@ -43,7 +43,7 @@ import qualified Foreign.CUDA.Driver as CUDA import qualified Foreign.CUDA.Driver.Stream as CUDA import Control.Exception -import Control.Monad.State +import Control.Monad.Reader import Data.Text.Lazy.Builder import Formatting hiding ( bytes ) import qualified Formatting as F @@ -63,7 +63,7 @@ instance Remote.RemoteMemory (LLVM PTX) where mallocRemote n | n <= 0 = return (Just CUDA.nullDevPtr) | otherwise = do - name <- gets ptxDeviceName + name <- asks ptxDeviceName liftIO $ do ep <- try (CUDA.mallocArray n) case ep of @@ -114,7 +114,7 @@ malloc -> Bool -> LLVM PTX Bool malloc !tp !ad !n !frozen = do - PTX{..} <- gets llvmTarget + PTX{..} <- asks llvmTarget Remote.malloc ptxMemoryTable tp ad frozen n @@ -127,7 +127,7 @@ withRemote -> (CUDA.DevicePtr (ScalarArrayDataR e) -> LLVM PTX (Maybe Event, r)) -> LLVM PTX (Maybe r) withRemote !tp !ad !f = do - PTX{..} <- gets llvmTarget + PTX{..} <- asks llvmTarget Remote.withRemote ptxMemoryTable tp ad f diff --git a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/Base.hs b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/Base.hs index 20b0795a4..d27e67b49 100644 --- a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/Base.hs +++ b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/Base.hs @@ -90,7 +90,7 @@ import qualified Data.Array.Accelerate.LLVM.Internal.LLVMPretty as LP import Control.Applicative import Control.Monad ( void ) -import Control.Monad.State ( gets ) +import Control.Monad.Reader ( asks ) import Data.Bits import Data.Proxy import Data.String @@ -139,7 +139,7 @@ laneMask_ge = specialPTXReg "llvm.nvvm.read.ptx.sreg.lanemask.ge" -- warpId :: CodeGen PTX (Operands Int32) warpId = do - 
dev <- liftCodeGen $ gets ptxDeviceProperties + dev <- liftCodeGen $ asks ptxDeviceProperties tid <- threadIdx A.quot integralType tid (A.liftInt32 (P.fromIntegral (CUDA.warpSize dev))) @@ -245,7 +245,7 @@ __syncwarp = __syncwarp_mask (liftWord32 0xffffffff) __syncwarp_mask :: HasCallStack => Operands Word32 -> CodeGen PTX () __syncwarp_mask mask = do llvmver <- getLLVMversion - dev <- liftCodeGen $ gets ptxDeviceProperties + dev <- liftCodeGen $ asks ptxDeviceProperties case (computeCapability dev >= Compute 7 0, llvmver >= 6) of (True, True) -> void $ call (Lam primType (op primType mask) (Body VoidType (Just Tail) "llvm.nvvm.bar.warp.sync")) [NoUnwind, NoDuplicate, Convergent] (True, False) -> internalError "LLVM-6.0 or above is required for Volta devices and later" @@ -506,7 +506,7 @@ shfl_op -> Operands a -- value to give -> CodeGen PTX (Operands a) -- value received shfl_op sop t delta val = do - dev <- liftCodeGen $ gets ptxDeviceProperties + dev <- liftCodeGen $ asks ptxDeviceProperties let -- The CUDA __shfl* instruction take an optional final parameter @@ -762,7 +762,7 @@ makeOpenAcc -> CodeGen PTX () -> CodeGen PTX (IROpenAcc PTX aenv a) makeOpenAcc uid name param kernel = do - dev <- liftCodeGen $ gets ptxDeviceProperties + dev <- liftCodeGen $ asks ptxDeviceProperties makeOpenAccWith (simpleLaunchConfig dev) uid name param kernel -- | Create a single kernel program with the given launch analysis information. 
diff --git a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/Fold.hs b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/Fold.hs index abbca01ba..b61fa2aad 100644 --- a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/Fold.hs +++ b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/Fold.hs @@ -46,7 +46,7 @@ import LLVM.AST.Type.Representation import qualified Foreign.CUDA.Analysis as CUDA import Control.Monad ( (>=>) ) -import Control.Monad.State ( gets ) +import Control.Monad.Reader ( asks ) import Data.String ( fromString ) import Data.Bits as P import Prelude as P @@ -105,7 +105,7 @@ mkFoldAll -> MIRDelayed PTX aenv (Vector e) -- ^ input data -> CodeGen PTX (IROpenAcc PTX aenv (Scalar e)) mkFoldAll uid aenv tp combine mseed macc = do - dev <- liftCodeGen $ gets ptxDeviceProperties + dev <- liftCodeGen $ asks ptxDeviceProperties foldr1 (+++) <$> sequence [ mkFoldAllS uid dev aenv tp combine mseed macc , mkFoldAllM1 uid dev aenv tp combine macc , mkFoldAllM2 uid dev aenv tp combine mseed @@ -303,7 +303,7 @@ mkFoldDim -> MIRDelayed PTX aenv (Array (sh, Int) e) -- ^ input data -> CodeGen PTX (IROpenAcc PTX aenv (Array sh e)) mkFoldDim uid aenv repr@(ArrayR shr tp) combine mseed marr = do - dev <- liftCodeGen $ gets ptxDeviceProperties + dev <- liftCodeGen $ asks ptxDeviceProperties -- let (arrOut, paramOut) = mutableArray repr "out" diff --git a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/FoldSeg.hs b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/FoldSeg.hs index 259a49cd3..1b9f34a87 100644 --- a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/FoldSeg.hs +++ b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/FoldSeg.hs @@ -45,7 +45,7 @@ import LLVM.AST.Type.Representation import qualified Foreign.CUDA.Analysis as CUDA import Control.Monad ( void ) -import Control.Monad.State ( gets ) +import Control.Monad.Reader ( asks ) import Data.String ( 
fromString ) import Prelude as P @@ -88,7 +88,7 @@ mkFoldSegP_block -> MIRDelayed PTX aenv (Segments i) -> CodeGen PTX (IROpenAcc PTX aenv (Array (sh, Int) e)) mkFoldSegP_block uid aenv repr@(ArrayR shr tp) intTp combine mseed marr mseg = do - dev <- liftCodeGen $ gets ptxDeviceProperties + dev <- liftCodeGen $ asks ptxDeviceProperties -- let (arrOut, paramOut) = mutableArray repr "out" @@ -283,7 +283,7 @@ mkFoldSegP_warp -> MIRDelayed PTX aenv (Segments i) -> CodeGen PTX (IROpenAcc PTX aenv (Array (sh, Int) e)) mkFoldSegP_warp uid aenv repr@(ArrayR shr tp) intTp combine mseed marr mseg = do - dev <- liftCodeGen $ gets ptxDeviceProperties + dev <- liftCodeGen $ asks ptxDeviceProperties -- let (arrOut, paramOut) = mutableArray repr "out" diff --git a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/Permute.hs b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/Permute.hs index d28b3846e..311ce77aa 100644 --- a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/Permute.hs +++ b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/Permute.hs @@ -55,7 +55,7 @@ import LLVM.AST.Type.Representation import Foreign.CUDA.Analysis import Control.Monad ( void ) -import Control.Monad.State ( gets ) +import Control.Monad.Reader ( asks ) import Prelude @@ -126,7 +126,7 @@ mkPermute_rmw -> MIRDelayed PTX aenv (Array sh e) -> CodeGen PTX (IROpenAcc PTX aenv (Array sh' e)) mkPermute_rmw uid aenv (ArrayR shr tp) shr' rmw update project marr = do - dev <- liftCodeGen $ gets ptxDeviceProperties + dev <- liftCodeGen $ asks ptxDeviceProperties -- let outR = ArrayR shr' tp @@ -267,7 +267,7 @@ atomically -> CodeGen PTX a -> CodeGen PTX a atomically barriers i action = do - dev <- liftCodeGen $ gets ptxDeviceProperties + dev <- liftCodeGen $ asks ptxDeviceProperties if computeCapability dev >= Compute 7 0 then atomically_thread barriers i action else atomically_warp barriers i action diff --git 
a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/Scan.hs b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/Scan.hs index 0318d64aa..6868f4d7b 100644 --- a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/Scan.hs +++ b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/Scan.hs @@ -51,7 +51,7 @@ import qualified Foreign.CUDA.Analysis as CUDA import Control.Applicative import Control.Monad ( (>=>), void ) -import Control.Monad.State ( gets ) +import Control.Monad.Reader ( asks ) import Data.String ( fromString ) import Data.Coerce as Safe import Data.Bits as P @@ -149,7 +149,7 @@ mkScanAllP1 -> MIRDelayed PTX aenv (Vector e) -- ^ input data -> CodeGen PTX (IROpenAcc PTX aenv (Vector e)) mkScanAllP1 dir uid aenv tp combine mseed marr = do - dev <- liftCodeGen $ gets ptxDeviceProperties + dev <- liftCodeGen $ asks ptxDeviceProperties -- let (arrOut, paramOut) = mutableArray (ArrayR dim1 tp) "out" @@ -269,7 +269,7 @@ mkScanAllP2 -> IRFun2 PTX aenv (e -> e -> e) -- ^ combination function -> CodeGen PTX (IROpenAcc PTX aenv (Vector e)) mkScanAllP2 dir uid aenv tp combine = do - dev <- liftCodeGen $ gets ptxDeviceProperties + dev <- liftCodeGen $ asks ptxDeviceProperties -- let (arrTmp, paramTmp) = mutableArray (ArrayR dim1 tp) "tmp" @@ -357,7 +357,7 @@ mkScanAllP3 -> MIRExp PTX aenv e -- ^ seed element, if this is an exclusive scan -> CodeGen PTX (IROpenAcc PTX aenv (Vector e)) mkScanAllP3 dir uid aenv tp combine mseed = do - dev <- liftCodeGen $ gets ptxDeviceProperties + dev <- liftCodeGen $ asks ptxDeviceProperties -- let (arrOut, paramOut) = mutableArray (ArrayR dim1 tp) "out" @@ -456,7 +456,7 @@ mkScan'AllP1 -> MIRDelayed PTX aenv (Vector e) -> CodeGen PTX (IROpenAcc PTX aenv (Vector e, Scalar e)) mkScan'AllP1 dir uid aenv tp combine seed marr = do - dev <- liftCodeGen $ gets ptxDeviceProperties + dev <- liftCodeGen $ asks ptxDeviceProperties -- let (arrOut, paramOut) = mutableArray (ArrayR dim1 tp) "out" @@ 
-569,7 +569,7 @@ mkScan'AllP2 -> IRFun2 PTX aenv (e -> e -> e) -> CodeGen PTX (IROpenAcc PTX aenv (Vector e, Scalar e)) mkScan'AllP2 dir uid aenv tp combine = do - dev <- liftCodeGen $ gets ptxDeviceProperties + dev <- liftCodeGen $ asks ptxDeviceProperties -- let (arrTmp, paramTmp) = mutableArray (ArrayR dim1 tp) "tmp" @@ -666,7 +666,7 @@ mkScan'AllP3 -> IRFun2 PTX aenv (e -> e -> e) -- ^ combination function -> CodeGen PTX (IROpenAcc PTX aenv (Vector e, Scalar e)) mkScan'AllP3 dir uid aenv tp combine = do - dev <- liftCodeGen $ gets ptxDeviceProperties + dev <- liftCodeGen $ asks ptxDeviceProperties -- let (arrOut, paramOut) = mutableArray (ArrayR dim1 tp) "out" @@ -754,7 +754,7 @@ mkScanDim -> MIRDelayed PTX aenv (Array (sh, Int) e) -- ^ input data -> CodeGen PTX (IROpenAcc PTX aenv (Array (sh, Int) e)) mkScanDim dir uid aenv repr@(ArrayR (ShapeRsnoc shr) tp) combine mseed marr = do - dev <- liftCodeGen $ gets ptxDeviceProperties + dev <- liftCodeGen $ asks ptxDeviceProperties -- let (arrOut, paramOut) = mutableArray repr "out" @@ -962,7 +962,7 @@ mkScan'Dim -> MIRDelayed PTX aenv (Array (sh, Int) e) -- ^ input data -> CodeGen PTX (IROpenAcc PTX aenv (Array (sh, Int) e, Array sh e)) mkScan'Dim dir uid aenv repr@(ArrayR (ShapeRsnoc shr) tp) combine seed marr = do - dev <- liftCodeGen $ gets ptxDeviceProperties + dev <- liftCodeGen $ asks ptxDeviceProperties -- let (arrSum, paramSum) = mutableArray (reduceRank repr) "sum" diff --git a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Compile.hs b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Compile.hs index f25188382..ca79e84a1 100644 --- a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Compile.hs +++ b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Compile.hs @@ -49,7 +49,7 @@ import qualified Data.Array.Accelerate.LLVM.Internal.LLVMPretty.PP as LP import qualified Text.PrettyPrint as Pretty import Control.Monad ( when ) -import Control.Monad.State +import Control.Monad.Reader 
import Data.ByteString.Short ( ShortByteString ) import Data.List ( intercalate ) import qualified Data.List.NonEmpty as NE @@ -86,7 +86,7 @@ compile pacc aenv = do -- Generate code for this Acc operation -- - dev <- gets ptxDeviceProperties + dev <- asks ptxDeviceProperties let CUDA.Compute m n = CUDA.computeCapability dev let arch = printf "sm_%d%d" m n (uid, cacheFile) <- cacheOfPreOpenAcc pacc diff --git a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Compile/Cache.hs b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Compile/Cache.hs index 723528adb..2a03b75d1 100644 --- a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Compile/Cache.hs +++ b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Compile/Cache.hs @@ -19,7 +19,7 @@ import Data.Array.Accelerate.LLVM.Compile.Cache import Data.Array.Accelerate.LLVM.PTX.Target import Data.Array.Accelerate.LLVM.Target.ClangInfo ( hostLLVMVersion ) -import Control.Monad.State +import Control.Monad.Reader import Data.Foldable ( toList ) import Data.List ( intercalate ) import Data.Version @@ -33,7 +33,7 @@ import Paths_accelerate_llvm_ptx instance Persistent PTX where targetCacheTemplate = do - Compute m n <- gets (computeCapability . ptxDeviceProperties) + Compute m n <- asks (computeCapability . ptxDeviceProperties) return $ "accelerate-llvm-ptx-" ++ showVersion version "llvmpr-" ++ intercalate "." 
(map show (toList hostLLVMVersion)) S8.unpack ptxTargetTriple diff --git a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Execute.hs b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Execute.hs index 733be48c8..79e40d2bd 100644 --- a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Execute.hs +++ b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Execute.hs @@ -49,8 +49,7 @@ import qualified Data.Array.Accelerate.LLVM.PTX.Execute.Event as Event import qualified Foreign.CUDA.Driver as CUDA -import Control.Monad ( when, forM_ ) -import Control.Monad.Reader ( asks, local ) +import Control.Monad ( forM_ ) import Control.Monad.State ( liftIO ) import Data.ByteString.Short.Char8 ( ShortByteString, unpack ) import Data.List ( find ) @@ -131,8 +130,8 @@ simpleOp name repr exe gamma aenv sh = result <- allocateRemote repr sh -- let paramR = TupRsingle $ ParamRarray repr - executeOp (ptxExecutable !# name) gamma aenv (arrayRshape repr) sh paramR result - put future result + cleanup <- executeOp (ptxExecutable !# name) gamma aenv (arrayRshape repr) sh paramR result + putCleanup future cleanup result return future -- Mapping over an array can ignore the dimensionality of the array and @@ -158,8 +157,8 @@ mapOp inplace repr tp exe gamma aenv input@(shape -> sh) = Nothing -> allocateRemote reprOut sh -- let paramsR = TupRsingle (ParamRarray reprOut) `TupRpair` TupRsingle (ParamRarray repr) - executeOp (ptxExecutable !# "map") gamma aenv (arrayRshape repr) sh paramsR (result, input) - put future result + cleanup <- executeOp (ptxExecutable !# "map") gamma aenv (arrayRshape repr) sh paramsR (result, input) + putCleanup future cleanup result return future {-# INLINE generateOp #-} @@ -189,8 +188,8 @@ transformOp repr repr' exe gamma aenv sh' input = future <- new result <- allocateRemote repr' sh' let paramsR = TupRsingle (ParamRarray repr') `TupRpair` TupRsingle (ParamRarray repr) - executeOp (ptxExecutable !# "transform") gamma aenv (arrayRshape repr') 
sh' paramsR (result, input) - put future result + cleanup <- executeOp (ptxExecutable !# "transform") gamma aenv (arrayRshape repr') sh' paramsR (result, input) + putCleanup future cleanup result return future {-# INLINE backpermuteOp #-} @@ -291,29 +290,30 @@ foldAllOp tp exe gamma aenv input = -- The array is small enough that we can compute it in a single step result <- allocateRemote (ArrayR dim0 tp) () let paramsR = paramsRdim0 `TupRpair` paramsRinput - executeOp ks gamma aenv dim1 sh paramsR (result, manifest input) - put future result + cleanup <- executeOp ks gamma aenv dim1 sh paramsR (result, manifest input) + putCleanup future cleanup result else do -- Multi-kernel reduction to a single element. The first kernel integrates -- any delayed elements, and the second is called recursively until -- reaching a single element. + -- The cleanup function is accumulated. let - rec :: Vector e -> Par PTX () - rec tmp@(Array ((),m) adata) - | m <= 1 = put future (Array () adata) + rec :: Vector e -> IO () -> Par PTX () + rec tmp@(Array ((),m) adata) cleanup + | m <= 1 = putCleanup future cleanup (Array () adata) | otherwise = do let sh' = ((), m `multipleOf` kernelThreadBlockSize km2) out <- allocateRemote (ArrayR dim1 tp) sh' let paramsR2 = paramsRdim1 `TupRpair` paramsRdim1 - executeOp km2 gamma aenv dim1 sh' paramsR2 (tmp, out) - rec out + cleanup2 <- executeOp km2 gamma aenv dim1 sh' paramsR2 (tmp, out) + rec out (cleanup >> cleanup2) -- let sh' = ((), n `multipleOf` kernelThreadBlockSize km1) tmp <- allocateRemote (ArrayR dim1 tp) sh' let paramsR1 = paramsRdim1 `TupRpair` paramsRinput - executeOp km1 gamma aenv dim1 sh' paramsR1 (tmp, manifest input) - rec tmp + cleanup <- executeOp km1 gamma aenv dim1 sh' paramsR1 (tmp, manifest input) + rec tmp cleanup -- return future @@ -335,8 +335,8 @@ foldDimOp repr@(ArrayR shr tp) exe gamma aenv input@(delayedShape -> (sh, sz)) result <- allocateRemote repr sh -- let paramsR = TupRsingle (ParamRarray repr) `TupRpair` 
TupRsingle (ParamRmaybe $ ParamRarray $ ArrayR (ShapeRsnoc shr) tp) - executeOp (ptxExecutable !# "fold") gamma aenv shr sh paramsR (result, manifest input) - put future result + cleanup <- executeOp (ptxExecutable !# "fold") gamma aenv shr sh paramsR (result, manifest input) + putCleanup future cleanup result return future @@ -369,8 +369,8 @@ foldSegOp intTp repr exe gamma aenv input@(delayedShape -> (sh, sz)) segments@(d future <- new result <- allocateRemote repr (sh, n) let paramsR = TupRsingle (ParamRarray repr) `TupRpair` TupRsingle (ParamRmaybe $ ParamRarray repr) `TupRpair` TupRsingle (ParamRmaybe $ ParamRarray reprSeg) - executeOp foldseg gamma aenv dim1 ((), m) paramsR ((result, manifest input), manifest segments) - put future result + cleanup <- executeOp foldseg gamma aenv dim1 ((), m) paramsR ((result, manifest input), manifest segments) + putCleanup future cleanup result return future @@ -451,15 +451,20 @@ scanAllOp tp exe gamma aenv m input@(delayedShape -> ((), n)) = -- which can be computed by a single thread block will require no -- additional work. tmp <- allocateRemote repr ((), s) - executeOp k1 gamma aenv dim1 ((), s) paramsR1 ((tmp, result), manifest input) + cleanup1 <- executeOp k1 gamma aenv dim1 ((), s) paramsR1 ((tmp, result), manifest input) -- Step 2: Multi-block reductions need to compute the per-block prefix, -- then apply those values to the partial results. 
- when (s > 1) $ do - executeOp k2 gamma aenv dim1 ((), s) paramR tmp - executeOp k3 gamma aenv dim1 ((), s-1) paramsR3 ((tmp, result), c) - - put future result + cleanup2 <- + if s > 1 + then do + cleanup2a <- executeOp k2 gamma aenv dim1 ((), s) paramR tmp + cleanup2b <- executeOp k3 gamma aenv dim1 ((), s-1) paramsR3 ((tmp, result), c) + return (cleanup2a >> cleanup2b) + else + return (return ()) + + putCleanup future (cleanup1 >> cleanup2) result return future {-# INLINE scanDimOp #-} @@ -478,8 +483,8 @@ scanDimOp repr exe gamma aenv m input@(delayedShape -> (sz, _)) = future <- new result <- allocateRemote repr (sz, m) let paramsR = TupRsingle (ParamRarray repr) `TupRpair` TupRsingle (ParamRmaybe $ ParamRarray repr) - executeOp (ptxExecutable !# "scan") gamma aenv dim1 ((), size shr' sz) paramsR (result, manifest input) - put future result + cleanup <- executeOp (ptxExecutable !# "scan") gamma aenv dim1 ((), size shr' sz) paramsR (result, manifest input) + putCleanup future cleanup result return future @@ -549,7 +554,7 @@ scan'AllOp tp exe gamma aenv input@(delayedShape -> ((), n)) = -- Step 1: independent thread-block-wide scans. Each block stores its partial -- sum to a temporary array. 
let paramsR1 = paramRdim1 `TupRpair` paramRdim1 `TupRpair` TupRsingle (ParamRmaybe $ ParamRarray repr) - executeOp k1 gamma aenv dim1 ((), s) paramsR1 ((tmp, result), manifest input) + cleanup1 <- executeOp k1 gamma aenv dim1 ((), s) paramsR1 ((tmp, result), manifest input) -- If this was a small array that was processed by a single thread block then -- we are done, otherwise compute the per-block prefix and apply those values @@ -557,15 +562,15 @@ scan'AllOp tp exe gamma aenv input@(delayedShape -> ((), n)) = if s == 1 then case tmp of - Array _ ad -> put future (result, Array () ad) + Array _ ad -> putCleanup future cleanup1 (result, Array () ad) else do sums <- allocateRemote (ArrayR dim0 tp) () let paramsR2 = paramRdim1 `TupRpair` paramRdim0 let paramsR3 = paramRdim1 `TupRpair` paramRdim1 `TupRpair` TupRsingle ParamRint - executeOp k2 gamma aenv dim1 ((), s) paramsR2 (tmp, sums) - executeOp k3 gamma aenv dim1 ((), s-1) paramsR3 ((tmp, result), c) - put future (result, sums) + cleanup2 <- executeOp k2 gamma aenv dim1 ((), s) paramsR2 (tmp, sums) + cleanup3 <- executeOp k3 gamma aenv dim1 ((), s-1) paramsR3 ((tmp, result), c) + putCleanup future (cleanup1 >> cleanup2 >> cleanup3) (result, sums) -- return future @@ -584,8 +589,8 @@ scan'DimOp repr@(ArrayR (ShapeRsnoc shr') _) exe gamma aenv input@(delayedShape result <- allocateRemote repr sh sums <- allocateRemote (reduceRank repr) sz let paramsR = TupRsingle (ParamRarray repr) `TupRpair` TupRsingle (ParamRarray $ reduceRank repr) `TupRpair` TupRsingle (ParamRmaybe $ ParamRarray repr) - executeOp (ptxExecutable !# "scan") gamma aenv dim1 ((), size shr' sz) paramsR ((result, sums), manifest input) - put future (result, sums) + cleanup <- executeOp (ptxExecutable !# "scan") gamma aenv dim1 ((), size shr' sz) paramsR ((result, sums), manifest input) + putCleanup future cleanup (result, sums) return future @@ -622,7 +627,7 @@ permuteOp inplace repr@(ArrayR shr tp) shr' exe gamma aenv defaults@(shape -> sh let 
kernelName' = let kn = kernelName kernel in SE.take (S.length kn - 65) kn - case kernelName' of + cleanup <- case kernelName' of -- execute directly using atomic operations "permute_rmw" -> let paramsR = paramR' `TupRpair` paramR @@ -640,7 +645,7 @@ permuteOp inplace repr@(ArrayR shr tp) shr' exe gamma aenv defaults@(shape -> sh _ -> internalError "unexpected kernel image" -- - put future result + putCleanup future cleanup result return future @@ -708,12 +713,12 @@ stencilCore repr@(ArrayR shr _) exe gamma aenv halo shOut paramsR params = -- future <- new result <- allocateRemote repr shOut - parent <- asks ptxStream + parent <- asksParState ptxStream parentStartPoint <- liftPar (Event.waypoint parent) -- interior (no bounds checking) let paramsRinside = TupRsingle (ParamRshape shr) `TupRpair` TupRsingle (ParamRarray repr) `TupRpair` paramsR - executeOp inside gamma aenv shr shIn paramsRinside ((shIn, result), params) + cleanup1 <- executeOp inside gamma aenv shr shIn paramsRinside ((shIn, result), params) -- halo regions (bounds checking) -- executed in separate streams so that they might overlap the main stencil @@ -722,7 +727,7 @@ stencilCore repr@(ArrayR shr _) exe gamma aenv halo shOut paramsR params = fork $ do -- synchronise with start of stencil computation, so that the arguments -- are available - child <- asks ptxStream + child <- asksParState ptxStream liftIO (Event.after parentStartPoint child) -- launch in a separate stream @@ -730,7 +735,8 @@ stencilCore repr@(ArrayR shr _) exe gamma aenv halo shOut paramsR params = let paramsRborder = TupRsingle (ParamRshape shr) `TupRpair` TupRsingle (ParamRshape shr) `TupRpair` TupRsingle (ParamRarray repr) `TupRpair` paramsR - executeOp border gamma aenv shr sh paramsRborder (((u, sh), result), params) + cleanup2 <- executeOp border gamma aenv shr sh paramsRborder (((u, sh), result), params) + addCleanup future cleanup2 -- make remainder of the parent stream depend on the border results event <- liftPar 
(Event.waypoint child) @@ -738,7 +744,7 @@ stencilCore repr@(ArrayR shr _) exe gamma aenv halo shOut paramsR params = if ready then return () else liftIO (Event.after event parent) - put future result + putCleanup future cleanup1 result return future -- Compute the stencil border regions, where we may need to evaluate the @@ -782,7 +788,7 @@ aforeignOp -> as -> Par PTX (Future bs) aforeignOp name _ _ asm arr = do - stream <- asks ptxStream + stream <- asksParState ptxStream Debug.monitorProcTime query msg (Just (unsafeGetValue stream)) (asm arr) where msg = Debug.traceM Debug.dump_exec ("exec: " % string % " " % Debug.elapsed) name @@ -818,7 +824,7 @@ manifest Delayed{} = Nothing -- withExecutable :: HasCallStack => ExecutableR PTX -> (FunctionTable -> Par PTX b) -> Par PTX b withExecutable PTXR{..} f = - local (\(s,_) -> (s,Just ptxExecutable)) $ do + localParState (\(s,_) -> (s,Just ptxExecutable)) $ do r <- f (unsafeGetValue ptxExecutable) liftIO $ touchLifetime ptxExecutable return r @@ -835,13 +841,17 @@ executeOp -> sh -> ParamsR PTX params -> params - -> Par PTX () + -> Par PTX (IO ()) executeOp kernel gamma aenv shr sh paramsR params = let n = size shr sh - in when (n > 0) $ do - stream <- asks ptxStream - argv <- marshalParams' @PTX (paramsR `TupRpair` TupRsingle (ParamRenv gamma)) (params, aenv) - liftIO $ launch kernel stream n $ DL.toList argv + in if n > 0 + then do + stream <- asksParState ptxStream + (argv, cleanup) <- marshalParams' @PTX (paramsR `TupRpair` TupRsingle (ParamRenv gamma)) (params, aenv) + liftIO $ launch kernel stream n $ DL.toList argv + return cleanup + else + return (return ()) -- Execute a device function with the given thread configuration and function diff --git a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Execute/Async.hs b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Execute/Async.hs index 9a70619b8..8bb94846f 100644 --- a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Execute/Async.hs +++ 
b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Execute/Async.hs @@ -37,7 +37,6 @@ import qualified Data.Array.Accelerate.LLVM.PTX.Execute.Event as Event import qualified Data.Array.Accelerate.LLVM.PTX.Execute.Stream as Stream import Control.Monad.Reader -import Control.Monad.State import Data.IORef @@ -67,32 +66,45 @@ data Future a = Future {-# UNPACK #-} !(IORef (IVar a)) data IVar a = Full !a - | Pending {-# UNPACK #-} !Event !(Maybe (Lifetime FunctionTable)) !a - | Empty + | Pending {-# UNPACK #-} !Event !(IO ()) !a + | Empty !(IO ()) +askParState :: Par PTX ParState +askParState = Par ask + +asksParState :: (ParState -> a) -> Par PTX a +asksParState f = Par (asks f) + +localParState :: (ParState -> ParState) -> Par PTX a -> Par PTX a +localParState f (Par m) = Par (local f m) + +instance MonadReader PTX (Par PTX) where + ask = Par (lift ask) + local f (Par (ReaderT g)) = Par (ReaderT (\parstate -> local f (g parstate))) + instance Async PTX where type FutureR PTX = Future newtype Par PTX a = Par { runPar :: ReaderT ParState (LLVM PTX) a } - deriving ( Functor, Applicative, Monad, MonadIO, MonadReader ParState, MonadState PTX ) + deriving ( Functor, Applicative, Monad, MonadIO ) {-# INLINEABLE new #-} {-# INLINEABLE newFull #-} - new = Future <$> liftIO (newIORef Empty) + new = Future <$> liftIO (newIORef (Empty (return ()))) newFull v = Future <$> liftIO (newIORef (Full v)) {-# INLINEABLE spawn #-} spawn m = do s' <- liftPar Stream.create - r <- local (const (s', Nothing)) m + r <- localParState (const (s', Nothing)) m liftIO (Stream.destroy s') return r {-# INLINEABLE fork #-} fork m = do s' <- liftPar (Stream.create) - () <- local (const (s', Nothing)) m + () <- localParState (const (s', Nothing)) m liftIO (Stream.destroy s') -- When we call 'put' the actual work may not have been evaluated yet; get @@ -101,13 +113,16 @@ instance Async PTX where -- {-# INLINEABLE put #-} put (Future ref) v = do - stream <- asks ptxStream - kernel <- asks ptxKernel + 
stream <- asksParState ptxStream + kernel <- asksParState ptxKernel event <- liftPar (Event.waypoint stream) ready <- liftIO (Event.query event) - liftIO . modifyIORef' ref $ \case - Empty -> if ready then Full v - else Pending event kernel v + let cleanupK = case kernel of + Just k -> touchLifetime k + Nothing -> return () + liftIO . atomicModifyIORef' ref $ \case + Empty cleanup -> if ready then (Full v, ()) + else (Pending event (cleanup >> cleanupK) v, ()) _ -> internalError "multiple put" -- Get the value of Future. Since the actual cross-stream synchronisation @@ -117,23 +132,21 @@ instance Async PTX where -- {-# INLINEABLE get #-} get (Future ref) = do - stream <- asks ptxStream + stream <- asksParState ptxStream liftIO $ do ivar <- readIORef ref case ivar of Full v -> return v - Pending event k v -> do + Pending event cleanup v -> do ready <- Event.query event if ready then do writeIORef ref (Full v) - case k of - Just f -> touchLifetime f - Nothing -> return () + cleanup else Event.after event stream return v - Empty -> internalError "blocked on an IVar" + Empty _ -> internalError "blocked on an IVar" {-# INLINEABLE block #-} block = liftIO . wait @@ -151,12 +164,34 @@ wait (Future ref) = do ivar <- readIORef ref case ivar of Full v -> return v - Pending event k v -> do + Pending event cleanup v -> do Event.block event writeIORef ref (Full v) - case k of - Just f -> touchLifetime f - Nothing -> return () + cleanup return v - Empty -> internalError "blocked on an IVar" + Empty _ -> internalError "blocked on an IVar" + +{-# INLINEABLE putCleanup #-} +putCleanup :: HasCallStack => FutureR PTX a -> IO () -> a -> Par PTX () +putCleanup (Future ref) cleanup v = do + stream <- asksParState ptxStream + kernel <- asksParState ptxKernel + event <- liftPar (Event.waypoint stream) + ready <- liftIO (Event.query event) + let cleanupK = case kernel of + Just k -> touchLifetime k + Nothing -> return () + liftIO . 
atomicModifyIORef' ref $ \case + Empty cleanup2 -> if ready then (Full v, ()) + else (Pending event (cleanup2 >> cleanup >> cleanupK) v, ()) + _ -> internalError "multiple put" + +{-# INLINEABLE addCleanup #-} +addCleanup :: HasCallStack => FutureR PTX a -> IO () -> Par PTX () +addCleanup (Future ref) cleanup = liftIO $ do + toRunNow <- atomicModifyIORef' ref $ \case + Full v -> (Full v, cleanup) + Pending event cleanup2 v -> (Pending event (cleanup2 >> cleanup) v, return ()) + Empty cleanup2 -> (Empty (cleanup2 >> cleanup), return ()) + toRunNow diff --git a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Execute/Event.hs b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Execute/Event.hs index 44145ceea..92a3a4663 100644 --- a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Execute/Event.hs +++ b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Execute/Event.hs @@ -33,7 +33,7 @@ import qualified Foreign.CUDA.Driver.Stream as Stream import Control.Exception import Control.Monad -import Control.Monad.State +import Control.Monad.Reader import Data.Text.Lazy.Builder import Formatting @@ -50,7 +50,7 @@ type Event = Lifetime Event.Event {-# INLINEABLE create #-} create :: LLVM PTX Event create = do - ctx <- gets ptxContext + ctx <- asks ptxContext e <- create' event <- liftIO $ newLifetime e liftIO $ addFinalizer event $ do @@ -61,7 +61,7 @@ create = do create' :: LLVM PTX Event.Event create' = do - PTX{ptxMemoryTable} <- gets llvmTarget + PTX{ptxMemoryTable} <- asks llvmTarget me <- attempt "create/new" (liftIO . 
catchOOM $ Event.create [Event.DisableTiming]) `orElse` do Remote.reclaim ptxMemoryTable diff --git a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Execute/Marshal.hs b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Execute/Marshal.hs index 869863aa5..80f8d469d 100644 --- a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Execute/Marshal.hs +++ b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Execute/Marshal.hs @@ -37,27 +37,72 @@ import Data.Array.Accelerate.Array.Data import qualified Foreign.CUDA.Driver as CUDA +import Control.Concurrent +import Control.Monad.IO.Class (liftIO) import qualified Data.DList as DL instance Marshal PTX where type ArgR PTX = CUDA.FunParam + type MarshalCleanup PTX = IO () marshalInt = CUDA.VArg marshalScalarData' t | SingleArrayDict <- singleArrayDict t - = liftPar . fmap (DL.singleton . CUDA.VArg) . unsafeGetDevicePtr t + = liftPar . fmap (\(ptr, cleanup) -> (DL.singleton (CUDA.VArg ptr), cleanup)) . getCudaDevicePtr t --- TODO FIXME !!! +-- | Return the CUDA device pointer corresponding to the given array, as well +-- as a cleanup IO action that __MUST__ be run once you are done with the +-- pointer (i.e. the GPU kernel has completed). Not calling the cleanup action +-- will result in leaked memory and resources. Calling the action twice is +-- unsupported. -- --- We will probably need to change marshal to be a bracketed function, so that --- the garbage collector does not try to evict the array in the middle of --- a computation. +-- This function is a hack. Prim.withDevicePtr is intended to be a wrapping +-- function that retains the resource while the callback is running and +-- releases it when the callback returns. This is all nice, but since the PTX +-- Accelerate runtime is asynchronous, uses of withDevicePtr would not all be +-- neatly nested: the actual array lifetimes are haphazard intervals during +-- program execution. 
-- -unsafeGetDevicePtr +-- Originally, this function just gave up and extracted the DevicePtr by +-- calling withDevicePtr with a trivial body that simply leaks p; this is +-- unsound (which was acknowledged by a 'fixme' comment...) and appears to have +-- been the cause of silent incorrect results (!) on a GTX 1050 Ti on the +-- adbench-gmmgrad test in accelerate-tests [1]. +-- +-- [1]: https://github.com/tomsmeding/accelerate-tests/blob/master/src/Data/Array/Accelerate/Tests/Prog/ADBenchGMMGrad.hs +-- +-- Fortunately, it turns out that the MemoryTable implementation underlying +-- withDevicePtr does not in fact assume lexical nesting of array usages. Thus +-- we can use a hack to let the callback of withDevicePtr live for the correct +-- amount of time without needing to rearchitect the entire PTX backend: let +-- the call run in a forkIO thread and use MVars to communicate when it should +-- return. This means that we now have the possibility to return a +-- self-contained "cleanup" handler from getCudaDevicePtr that does nothing but +-- signal to the withDevicePtr callback that the array's lifetime has ended and +-- the scope can close. All this is possible because the 'LLVM' monad is just a +-- reader monad over IO, so we can unlift it into IO. +-- +-- As a final, questionable improvement, we let the "cleanup" handler wait +-- until withDevicePtr has properly returned so that we know the array's +-- refcount has been properly decremented and memory has been released if +-- possible. 
+getCudaDevicePtr :: SingleType e -> ArrayData e - -> LLVM PTX (CUDA.DevicePtr (ScalarArrayDataR e)) -unsafeGetDevicePtr !t !ad = - Prim.withDevicePtr t ad (\p -> return (Nothing, p)) + -> LLVM PTX (CUDA.DevicePtr (ScalarArrayDataR e), IO ()) +getCudaDevicePtr !t !ad = do + ptrVar <- liftIO newEmptyMVar + doneVar <- liftIO newEmptyMVar + releasedVar <- liftIO newEmptyMVar + + _ <- unliftIOLLVM $ \inLLVM -> forkIO $ inLLVM $ do + Prim.withDevicePtr t ad $ \p -> liftIO $ do + putMVar ptrVar p + takeMVar doneVar + return (Nothing, ()) + liftIO $ putMVar releasedVar () + + ptr <- liftIO $ readMVar ptrVar + return (ptr, putMVar doneVar () >> readMVar releasedVar) diff --git a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Execute/Stream.hs b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Execute/Stream.hs index e8ff9e704..6192dcde0 100644 --- a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Execute/Stream.hs +++ b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Execute/Stream.hs @@ -35,7 +35,7 @@ import qualified Foreign.CUDA.Driver.Stream as Stream import Control.Exception import Control.Monad -import Control.Monad.State +import Control.Monad.Reader import Data.Text.Lazy.Builder import Formatting @@ -111,7 +111,7 @@ flush !Context{..} !ref = do {-# INLINEABLE create #-} create :: LLVM PTX Stream create = do - PTX{..} <- gets llvmTarget + PTX{..} <- asks llvmTarget s <- create' stream <- liftIO $ newLifetime s liftIO $ addFinalizer stream (RSV.insert ptxStreamReservoir s) @@ -119,7 +119,7 @@ create = do create' :: LLVM PTX Stream.Stream create' = do - PTX{..} <- gets llvmTarget + PTX{..} <- asks llvmTarget ms <- attempt "create/reservoir" (liftIO $ RSV.malloc ptxStreamReservoir) `orElse` attempt "create/new" (liftIO . 
catchOOM $ Stream.create []) diff --git a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Link.hs b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Link.hs index ef2e78396..0b8fc61c2 100644 --- a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Link.hs +++ b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Link.hs @@ -36,7 +36,7 @@ import qualified Data.Array.Accelerate.LLVM.PTX.Debug as Debug import qualified Foreign.CUDA.Analysis as CUDA import qualified Foreign.CUDA.Driver as CUDA -import Control.Monad.State +import Control.Monad.Reader import Data.ByteString.Short.Char8 ( ShortByteString, unpack ) import Formatting import Foreign.Ptr @@ -56,8 +56,8 @@ instance Link PTX where -- link :: ObjectR PTX -> LLVM PTX (ExecutableR PTX) link (ObjectR uid cfg objFname) = do - target <- gets llvmTarget - cache <- gets ptxKernelTable + target <- asks llvmTarget + cache <- asks ptxKernelTable funs <- liftIO $ dlsym uid cache $ do -- Load the SASS object code into the current CUDA context obj <- B.readFile objFname diff --git a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Monad.hs b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Monad.hs index 7eda914dd..7ceba3ac1 100644 --- a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Monad.hs +++ b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Monad.hs @@ -50,7 +50,7 @@ import Data.Array.Accelerate.LLVM.CodeGen.IR import Data.Array.Accelerate.LLVM.CodeGen.Intrinsic import Data.Array.Accelerate.LLVM.CodeGen.Module import Data.Array.Accelerate.LLVM.CodeGen.Sugar ( IROpenAcc(..) 
) -import Data.Array.Accelerate.LLVM.State ( LLVM ) +import Data.Array.Accelerate.LLVM.State ( LLVM, getLLVMVer ) import Data.Array.Accelerate.LLVM.Target import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type @@ -70,7 +70,7 @@ import qualified Data.Array.Accelerate.LLVM.Internal.LLVMPretty.Triple.Parse as import Control.Applicative import Control.Monad -import Control.Monad.Reader ( ReaderT, MonadReader, runReaderT, ask, asks ) +import Control.Monad.Reader ( ReaderT, MonadReader, runReaderT, asks ) import Control.Monad.State import Data.ByteString.Short ( ShortByteString ) import qualified Data.ByteString.Short.Char8 as SBS8 @@ -136,7 +136,7 @@ evalCodeGen => CodeGen arch (IROpenAcc arch aenv a) -> LLVM arch (Module arch aenv a) evalCodeGen ll = do - llvmver <- ask + llvmver <- getLLVMVer let context = CodeGenContext { codegenLLVMversion = llvmver } (IROpenAcc ks, st) <- runStateT (runReaderT (runCodeGen ll) context) diff --git a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/Execute/Marshal.hs b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/Execute/Marshal.hs index c5a949f86..65ee76161 100644 --- a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/Execute/Marshal.hs +++ b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/Execute/Marshal.hs @@ -35,6 +35,7 @@ import Data.Array.Accelerate.LLVM.CodeGen.Environment ( Gamma, Idx'(.. import Data.Array.Accelerate.LLVM.Execute.Environment import Data.Array.Accelerate.LLVM.Execute.Async +import Data.Bifunctor ( first ) import Data.DList ( DList ) import qualified Data.DList as DL import qualified Data.IntMap as IM @@ -42,49 +43,53 @@ import qualified Data.IntMap as IM -- Marshalling arguments -- --------------------- -class Async arch => Marshal arch where +class (Async arch, Monoid (MarshalCleanup arch)) => Marshal arch where -- | A type family that is used to specify a concrete kernel argument and -- stream/context type for a given backend target. 
-- type ArgR arch + -- | A cleanup action for a marshalled argument. On PTX this is an IO action; + -- on Native, no cleanup is necessary. + type MarshalCleanup arch + -- | Used to pass shapes as arguments to kernels. marshalInt :: Int -> ArgR arch -- | Pass arrays to kernels - marshalScalarData' :: SingleType e -> ScalarArrayData e -> Par arch (DList (ArgR arch)) + marshalScalarData' :: SingleType e -> ScalarArrayData e -> Par arch (DList (ArgR arch), MarshalCleanup arch) -- | Convert function arguments into stream a form suitable for function calls -- The functions ending in a prime return a DList, other functions return lists. -- -marshalArrays :: forall arch arrs. Marshal arch => ArraysR arrs -> arrs -> Par arch [ArgR arch] -marshalArrays repr arrs = DL.toList <$> marshalArrays' @arch repr arrs +marshalArrays :: forall arch arrs. Marshal arch => ArraysR arrs -> arrs -> Par arch ([ArgR arch], MarshalCleanup arch) +marshalArrays repr arrs = first DL.toList <$> marshalArrays' @arch repr arrs -marshalArrays' :: forall arch arrs. Marshal arch => ArraysR arrs -> arrs -> Par arch (DList (ArgR arch)) +marshalArrays' :: forall arch arrs. Marshal arch => ArraysR arrs -> arrs -> Par arch (DList (ArgR arch), MarshalCleanup arch) marshalArrays' = marshalTupR' @arch (marshalArray' @arch) -marshalArray' :: forall arch a. Marshal arch => ArrayR a -> a -> Par arch (DList (ArgR arch)) +marshalArray' :: forall arch a. Marshal arch => ArrayR a -> a -> Par arch (DList (ArgR arch), MarshalCleanup arch) marshalArray' (ArrayR shr tp) (Array sh a) = do - arg1 <- marshalArrayData' @arch tp a + (arg1, c1) <- marshalArrayData' @arch tp a let arg2 = marshalShape' @arch shr sh - return $ arg1 `DL.append` arg2 + return (arg1 `DL.append` arg2, c1) -marshalArrayData' :: forall arch t. Marshal arch => TypeR t -> ArrayData t -> Par arch (DList (ArgR arch)) -marshalArrayData' TupRunit () = return DL.empty +marshalArrayData' :: forall arch t. 
Marshal arch => TypeR t -> ArrayData t -> Par arch (DList (ArgR arch), MarshalCleanup arch) +marshalArrayData' TupRunit () = return (DL.empty, mempty) marshalArrayData' (TupRpair t1 t2) (a1, a2) = do - l1 <- marshalArrayData' t1 a1 - l2 <- marshalArrayData' t2 a2 - return $ l1 `DL.append` l2 + (l1, c1) <- marshalArrayData' t1 a1 + (l2, c2) <- marshalArrayData' t2 a2 + return (l1 `DL.append` l2, c1 <> c2) marshalArrayData' (TupRsingle t) ad | ScalarArrayDict _ s <- scalarArrayDict t = marshalScalarData' @arch s ad -marshalEnv :: forall arch aenv. Marshal arch => Gamma aenv -> ValR arch aenv -> Par arch [ArgR arch] -marshalEnv g a = DL.toList <$> marshalEnv' g a +marshalEnv :: forall arch aenv. Marshal arch => Gamma aenv -> ValR arch aenv -> Par arch ([ArgR arch], MarshalCleanup arch) +marshalEnv g a = first DL.toList <$> marshalEnv' g a -marshalEnv' :: forall arch aenv. Marshal arch => Gamma aenv -> ValR arch aenv -> Par arch (DList (ArgR arch)) +marshalEnv' :: forall arch aenv. Marshal arch => Gamma aenv -> ValR arch aenv -> Par arch (DList (ArgR arch), MarshalCleanup arch) marshalEnv' gamma aenv - = fmap DL.concat + = fmap mconcat $ mapM (\(_, Idx' repr idx) -> marshalArray' @arch repr =<< get (prj idx aenv)) (IM.elems gamma) marshalShape :: forall arch sh. Marshal arch => ShapeR sh -> sh -> [ArgR arch] @@ -105,22 +110,22 @@ data ParamR arch a where ParamRshape :: ShapeR sh -> ParamR arch sh ParamRargs :: ParamR arch (DList (ArgR arch)) -marshalParam' :: forall arch a. Marshal arch => ParamR arch a -> a -> Par arch (DList (ArgR arch)) +marshalParam' :: forall arch a. 
Marshal arch => ParamR arch a -> a -> Par arch (DList (ArgR arch), MarshalCleanup arch) marshalParam' (ParamRarray repr) a = marshalArray' repr a -marshalParam' (ParamRmaybe _ ) Nothing = return $ DL.empty +marshalParam' (ParamRmaybe _ ) Nothing = return (DL.empty, mempty) marshalParam' (ParamRmaybe repr) (Just a) = marshalParam' repr a marshalParam' (ParamRfuture repr) future = marshalParam' repr =<< get future marshalParam' (ParamRenv gamma) aenv = marshalEnv' gamma aenv -marshalParam' ParamRint x = return $ DL.singleton $ marshalInt @arch x -marshalParam' (ParamRshape shr) sh = return $ marshalShape' @arch shr sh -marshalParam' ParamRargs args = return args +marshalParam' ParamRint x = return (DL.singleton (marshalInt @arch x), mempty) +marshalParam' (ParamRshape shr) sh = return (marshalShape' @arch shr sh, mempty) +marshalParam' ParamRargs args = return (args, mempty) -marshalParams' :: forall arch a. Marshal arch => ParamsR arch a -> a -> Par arch (DList (ArgR arch)) +marshalParams' :: forall arch a. Marshal arch => ParamsR arch a -> a -> Par arch (DList (ArgR arch), MarshalCleanup arch) marshalParams' = marshalTupR' @arch (marshalParam' @arch) {-# INLINE marshalTupR' #-} -marshalTupR' :: forall arch s a. Marshal arch => (forall b. s b -> b -> Par arch (DList (ArgR arch))) -> TupR s a -> a -> Par arch (DList (ArgR arch)) -marshalTupR' _ TupRunit () = return $ DL.empty +marshalTupR' :: forall arch s a. Marshal arch => (forall b. 
s b -> b -> Par arch (DList (ArgR arch), MarshalCleanup arch)) -> TupR s a -> a -> Par arch (DList (ArgR arch), MarshalCleanup arch) +marshalTupR' _ TupRunit () = return (DL.empty, mempty) marshalTupR' f (TupRsingle t) x = f t x -marshalTupR' f (TupRpair t1 t2) (x1, x2) = DL.append <$> marshalTupR' @arch f t1 x1 <*> marshalTupR' @arch f t2 x2 +marshalTupR' f (TupRpair t1 t2) (x1, x2) = (<>) <$> marshalTupR' @arch f t1 x1 <*> marshalTupR' @arch f t2 x2 diff --git a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/State.hs b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/State.hs index cbd98b4cf..9d0ad2c40 100644 --- a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/State.hs +++ b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/State.hs @@ -1,4 +1,7 @@ +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE RankNTypes #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.LLVM.State @@ -21,9 +24,8 @@ import qualified Data.Array.Accelerate.LLVM.Internal.LLVMPretty.PP as LP -- standard library import Control.Monad.Catch ( MonadCatch, MonadThrow, MonadMask ) -import Control.Monad.Reader ( ReaderT, MonadReader, runReaderT ) -import Control.Monad.State ( StateT, MonadState, evalStateT ) -import Control.Monad.Trans ( MonadIO ) +import Control.Monad.Reader ( ReaderT(..), MonadReader, runReaderT, ask, local ) +import Control.Monad.Trans ( MonadIO, lift ) import Prelude @@ -34,10 +36,15 @@ import Prelude -- for the LLVM execution context as well as the per-execution target specific -- state 'target'. 
-- -newtype LLVM target a = LLVM { runLLVM :: ReaderT LP.LLVMVer (StateT target IO) a } - deriving (Functor, Applicative, Monad, MonadIO, MonadReader LP.LLVMVer, MonadState target, MonadThrow, MonadCatch, MonadMask) +newtype LLVM target a = LLVM { runLLVM :: ReaderT LP.LLVMVer (ReaderT target IO) a } + deriving (Functor, Applicative, Monad, MonadIO, MonadThrow, MonadCatch, MonadMask) --- | Extract the execution state: 'gets llvmTarget' +-- not derived because the LLVMVer reader masks this one +instance MonadReader target (LLVM target) where + ask = LLVM (lift ask) + local f (LLVM (ReaderT g)) = LLVM (ReaderT (local f . g)) + +-- | Extract the execution state: 'asks llvmTarget' -- llvmTarget :: t -> t llvmTarget = id @@ -47,9 +54,21 @@ llvmTarget = id evalLLVM :: t -> LLVM t a -> IO a evalLLVM target acc = case llvmverFromTuple hostLLVMVersion of - Just version -> evalStateT (runReaderT (runLLVM acc) version) target + Just version -> runReaderT (runReaderT (runLLVM acc) version) target Nothing -> fail "accelerate-llvm: Could not determine LLVM version from Clang output" +getLLVMVer :: LLVM target LP.LLVMVer +getLLVMVer = LLVM ask + +-- | This is a valid implementation of @withRunInIO@ in +-- unliftio-core:Control.Monad.IO.Unlift(MonadUnliftIO); it's not an instance +-- to avoid a dependency. +unliftIOLLVM :: ((forall a. LLVM target a -> IO a) -> IO b) -> LLVM target b +unliftIOLLVM f = LLVM (ReaderT (\llvmver -> ReaderT (\target -> f (run llvmver target)))) + where + run :: LP.LLVMVer -> target -> LLVM target a -> IO a + run llvmver target (LLVM m) = runReaderT (runReaderT m llvmver) target + -- -- | Make sure the GC knows that we want to keep this thing alive forever. -- --