Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ import qualified Data.Array.Accelerate.LLVM.Native.Debug as Debug

import Control.Concurrent ( myThreadId )
import Control.Concurrent.Extra ( getThreadId )
import Control.Monad.State ( gets )
import Control.Monad.Reader ( asks )
import Control.Monad.Trans ( liftIO )
import Data.ByteString.Short ( ShortByteString )
import Data.IORef ( newIORef, readIORef, writeIORef )
Expand Down Expand Up @@ -139,7 +139,7 @@ simpleOp
simpleOp name repr NativeR{..} gamma aenv sh = do
let fun = nativeExecutable !# name
param = TupRsingle $ ParamRarray repr
Native{..} <- gets llvmTarget
Native{..} <- asks llvmTarget
future <- new
result <- allocateRemote repr sh
scheduleOp fun gamma aenv (arrayRshape repr) sh param result
Expand Down Expand Up @@ -167,7 +167,7 @@ mapOp inplace repr tp NativeR{..} gamma aenv input = do
shr = arrayRshape repr
repr' = ArrayR shr tp
param = TupRsingle (ParamRarray repr') `TupRpair` TupRsingle (ParamRarray repr)
Native{..} <- gets llvmTarget
Native{..} <- asks llvmTarget
future <- new
result <- case inplace of
Just Refl -> return input
Expand Down Expand Up @@ -201,7 +201,7 @@ transformOp
-> Par Native (Future (Array sh' b))
transformOp repr repr' NativeR{..} gamma aenv sh' input = do
let fun = nativeExecutable !# "transform"
Native{..} <- gets llvmTarget
Native{..} <- asks llvmTarget
future <- new
result <- allocateRemote repr' sh'
let param = TupRsingle (ParamRarray repr') `TupRpair` TupRsingle (ParamRarray repr)
Expand Down Expand Up @@ -300,7 +300,7 @@ foldAllOp
-> Delayed (Vector e)
-> Par Native (Future (Scalar e))
foldAllOp tp NativeR{..} gamma aenv arr = do
Native{..} <- gets llvmTarget
Native{..} <- asks llvmTarget
future <- new
result <- allocateRemote (ArrayR dim0 tp) ()
let
Expand Down Expand Up @@ -343,7 +343,7 @@ foldDimOp
-> Delayed (Array (sh, Int) e)
-> Par Native (Future (Array sh e))
foldDimOp repr NativeR{..} gamma aenv arr@(delayedShape -> (sh, _)) = do
Native{..} <- gets llvmTarget
Native{..} <- asks llvmTarget
future <- new
result <- allocateRemote repr sh
let
Expand Down Expand Up @@ -371,7 +371,7 @@ foldSegOp
-> Delayed (Segments i)
-> Par Native (Future (Array (sh, Int) e))
foldSegOp iR repr NativeR{..} gamma aenv input@(delayedShape -> (sh, _)) segments@(delayedShape -> ((), ss)) = do
Native{..} <- gets llvmTarget
Native{..} <- asks llvmTarget
future <- new
let
n = ss-1
Expand Down Expand Up @@ -428,7 +428,7 @@ scanCore
-> Delayed (Array (sh, Int) e)
-> Par Native (Future (Array (sh, Int) e))
scanCore repr NativeR{..} gamma aenv m input@(delayedShape -> (sz, n)) = do
Native{..} <- gets llvmTarget
Native{..} <- asks llvmTarget
future <- new
result <- allocateRemote repr (sz, m)
--
Expand Down Expand Up @@ -527,7 +527,7 @@ scan'Core repr NativeR{..} gamma aenv input@(delayedShape -> sh@(sz, n)) = do
paramA = TupRsingle $ ParamRarray repr
paramA' = TupRsingle $ ParamRarray repr'
--
Native{..} <- gets llvmTarget
Native{..} <- asks llvmTarget
future <- new
result <- allocateRemote repr sh
sums <- allocateRemote repr' sz
Expand Down Expand Up @@ -608,7 +608,7 @@ permuteOp inplace repr shr' NativeR{..} gamma aenv defaults@(shape -> shOut) inp
let
ArrayR shr tp = repr
repr' = ArrayR shr' tp
Native{..} <- gets llvmTarget
Native{..} <- asks llvmTarget
future <- new
result <- if inplace
then Debug.trace Debug.dump_exec "exec: permute/inplace" $ return defaults
Expand Down Expand Up @@ -701,7 +701,7 @@ stencilCore
-> params
-> Par Native (Future (Array sh e))
stencilCore repr NativeR{..} gamma aenv halo sh paramsR params = do
Native{..} <- gets llvmTarget
Native{..} <- asks llvmTarget
future <- new
result <- allocateRemote repr sh
let
Expand Down Expand Up @@ -815,7 +815,7 @@ scheduleOp
-> Maybe Action
-> Par Native ()
scheduleOp fun gamma aenv shr sz paramsR params done = do
Native{..} <- gets llvmTarget
Native{..} <- asks llvmTarget
let
splits = numWorkers workers - 1
minsize = case shr of
Expand All @@ -842,7 +842,7 @@ scheduleOpWith
-> Maybe Action -- run after the last piece completes
-> Par Native ()
scheduleOpWith splits minsize fun gamma aenv shr sz paramsR params done = do
Native{..} <- gets llvmTarget
Native{..} <- asks llvmTarget
job <- mkJob splits minsize fun gamma aenv shr (empty shr) sz paramsR params done
liftIO $ schedule workers job

Expand All @@ -858,7 +858,7 @@ scheduleOpUsing
-> Maybe Action
-> Par Native ()
scheduleOpUsing ranges fun gamma aenv shr paramsR params jobDone = do
Native{..} <- gets llvmTarget
Native{..} <- asks llvmTarget
job <- mkJobUsing ranges fun gamma aenv shr paramsR params jobDone
liftIO $ schedule workers job

Expand Down Expand Up @@ -919,7 +919,7 @@ mkTasksUsing
-> params
-> Par Native (Seq Action)
mkTasksUsing ranges (name, f) gamma aenv shr paramsR params = do
arg <- marshalParams' @Native (paramsR `TupRpair` TupRsingle (ParamRenv gamma)) (params, aenv)
(arg, ()) <- marshalParams' @Native (paramsR `TupRpair` TupRsingle (ParamRenv gamma)) (params, aenv)
return $ flip fmap ranges $ \(_,u,v) -> do
sched (string % " " % parenthesised string % " -> " % parenthesised string) (S8.unpack name) (showShape shr u) (showShape shr v)
let argU = marshalShape' @Native shr u
Expand All @@ -937,7 +937,7 @@ mkTasksUsingIndex
-> params
-> Par Native (Seq Action)
mkTasksUsingIndex ranges (name, f) gamma aenv shr paramsR params = do
arg <- marshalParams' @Native (paramsR `TupRpair` TupRsingle (ParamRenv gamma)) (params, aenv)
(arg, ()) <- marshalParams' @Native (paramsR `TupRpair` TupRsingle (ParamRenv gamma)) (params, aenv)
return $ flip fmap ranges $ \(i,u,v) -> do
sched (string % " " % parenthesised string % " -> " % parenthesised string) (S8.unpack name) (showShape shr u) (showShape shr v)
let argU = marshalShape' @Native shr u
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import Data.Array.Accelerate.LLVM.State
-- standard library
import Control.Concurrent
import Control.Monad.Cont
import Control.Monad.State
import Control.Monad.Reader
import Data.IORef
import Data.Sequence ( Seq )
import qualified Data.Sequence as Seq
Expand Down Expand Up @@ -78,7 +78,7 @@ data IVar a
instance Async Native where
type FutureR Native = Future
newtype Par Native a = Par { runPar :: ContT () (LLVM Native) a }
deriving ( Functor, Applicative, Monad, MonadIO, MonadCont, MonadState Native )
deriving ( Functor, Applicative, Monad, MonadIO, MonadCont, MonadReader Native )

{-# INLINE new #-}
{-# INLINE newFull #-}
Expand All @@ -93,7 +93,7 @@ instance Async Native where
{-# INLINE get #-}
get (Future ref) =
callCC $ \k -> do
native <- gets llvmTarget
native <- asks llvmTarget
next <- liftIO . atomicModifyIORef' ref $ \case
Empty -> (Blocked (Seq.singleton (evalParIO native . k)), reschedule)
Blocked ks -> (Blocked (ks Seq.|> evalParIO native . k), reschedule)
Expand All @@ -102,7 +102,7 @@ instance Async Native where

{-# INLINE put #-}
put future ref = do
Native{..} <- gets llvmTarget
Native{..} <- asks llvmTarget
liftIO (putIO workers future ref)

{-# INLINE liftPar #-}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
Expand Down Expand Up @@ -34,9 +35,10 @@ import qualified Foreign.LibFFI as FFI

instance Marshal Native where
type ArgR Native = FFI.Arg
type MarshalCleanup Native = ()
marshalInt = $( case finiteBitSize (undefined::Int) of
32 -> [| FFI.argInt32 . fromIntegral |]
64 -> [| FFI.argInt64 . fromIntegral |]
_ -> error "I don't know what architecture I am" )
marshalScalarData' _ = return . DL.singleton . FFI.argPtr . unsafeUniqueArrayPtr
marshalScalarData' _ = return . (,()) . DL.singleton . FFI.argPtr . unsafeUniqueArrayPtr

Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import Data.Array.Accelerate.LLVM.Native.Link.Cache
import Data.Array.Accelerate.LLVM.Native.Link.Object
import Data.Array.Accelerate.LLVM.Native.Link.Runtime

import Control.Monad.State
import Control.Monad.Reader
import Prelude hiding ( lookup )


Expand All @@ -48,7 +48,7 @@ instance Link Native where
--
link :: ObjectR Native -> LLVM Native (ExecutableR Native)
link (ObjectR uid nms _ so) = do
cache <- gets linkCache
cache <- asks linkCache
funs <- liftIO $ dlsym uid cache (loadSharedObject nms so)
return $! NativeR funs

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ import qualified Data.Array.Accelerate.LLVM.PTX.Array.Prim as Prim

import Control.Applicative
import Control.Monad
import Control.Monad.Reader
import Control.Monad.State ( gets )
import Control.Monad.IO.Class ( liftIO )
import Control.Monad.Reader ( asks )
import System.IO.Unsafe
import Prelude

Expand Down Expand Up @@ -99,7 +99,7 @@ copyToHostLazy (TupRpair r1 r2) (f1, f2) = do
a2 <- copyToHostLazy r2 f2
return (a1, a2)
copyToHostLazy (TupRsingle (ArrayR shr tp)) future = do
ptx <- gets llvmTarget
ptx <- asks llvmTarget
liftIO $ do
Array sh adata <- wait future

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ pokeArrayAsync !t !n !ad
let !src = CUDA.HostPtr (unsafeUniqueArrayPtr ad)
!bytes = n * bytesElt (TupRsingle (SingleScalarType t))
--
stream <- asks ptxStream
stream <- asksParState ptxStream
result <- liftPar $
withLifetime stream $ \st ->
withDevicePtr t ad $ \dst ->
Expand Down Expand Up @@ -150,7 +150,7 @@ indexArrayAsync !n !t !ad_src !i
let !bytes = n * bytesElt (TupRsingle (SingleScalarType t))
!dst = CUDA.HostPtr (unsafeUniqueArrayPtr ad_dst)
--
stream <- asks ptxStream
stream <- asksParState ptxStream
result <- liftPar $
withLifetime stream $ \st ->
withDevicePtr t ad_src $ \src ->
Expand Down Expand Up @@ -179,7 +179,7 @@ peekArrayAsync !t !n !ad
let !bytes = n * bytesElt (TupRsingle (SingleScalarType t))
!dst = CUDA.HostPtr (unsafeUniqueArrayPtr ad)
--
stream <- asks ptxStream
stream <- asksParState ptxStream
result <- liftPar $
withLifetime stream $ \st ->
withDevicePtr t ad $ \src ->
Expand Down Expand Up @@ -208,7 +208,7 @@ copyArrayAsync !t !n !ad_src !ad_dst
= do
let !bytes = n * bytesElt (TupRsingle (SingleScalarType t))
--
stream <- asks ptxStream
stream <- asksParState ptxStream
result <- liftPar $
withLifetime stream $ \st ->
withDevicePtr t ad_src $ \src ->
Expand Down Expand Up @@ -287,7 +287,7 @@ memsetArrayAsync !t !n !v !ad
= do
let !bytes = n * bytesElt (TupRsingle (SingleScalarType t))
--
stream <- asks ptxStream
stream <- asksParState ptxStream
result <- liftPar $
withLifetime stream $ \st ->
withDevicePtr t ad $ \ptr ->
Expand Down Expand Up @@ -350,7 +350,7 @@ nonblocking !stream !action = do
return (Nothing, future)

else do
future <- Future <$> liftIO (newIORef (Pending event Nothing result))
future <- Future <$> liftIO (newIORef (Pending event (return ()) result))
return (Just event, future)

{-# INLINE withLifetime #-}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ import qualified Foreign.CUDA.Driver as CUDA
import qualified Foreign.CUDA.Driver.Stream as CUDA

import Control.Exception
import Control.Monad.State
import Control.Monad.Reader
import Data.Text.Lazy.Builder
import Formatting hiding ( bytes )
import qualified Formatting as F
Expand All @@ -63,7 +63,7 @@ instance Remote.RemoteMemory (LLVM PTX) where
mallocRemote n
| n <= 0 = return (Just CUDA.nullDevPtr)
| otherwise = do
name <- gets ptxDeviceName
name <- asks ptxDeviceName
liftIO $ do
ep <- try (CUDA.mallocArray n)
case ep of
Expand Down Expand Up @@ -114,7 +114,7 @@ malloc
-> Bool
-> LLVM PTX Bool
malloc !tp !ad !n !frozen = do
PTX{..} <- gets llvmTarget
PTX{..} <- asks llvmTarget
Remote.malloc ptxMemoryTable tp ad frozen n


Expand All @@ -127,7 +127,7 @@ withRemote
-> (CUDA.DevicePtr (ScalarArrayDataR e) -> LLVM PTX (Maybe Event, r))
-> LLVM PTX (Maybe r)
withRemote !tp !ad !f = do
PTX{..} <- gets llvmTarget
PTX{..} <- asks llvmTarget
Remote.withRemote ptxMemoryTable tp ad f


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ import qualified Data.Array.Accelerate.LLVM.Internal.LLVMPretty as LP

import Control.Applicative
import Control.Monad ( void )
import Control.Monad.State ( gets )
import Control.Monad.Reader ( asks )
import Data.Bits
import Data.Proxy
import Data.String
Expand Down Expand Up @@ -139,7 +139,7 @@ laneMask_ge = specialPTXReg "llvm.nvvm.read.ptx.sreg.lanemask.ge"
--
warpId :: CodeGen PTX (Operands Int32)
warpId = do
dev <- liftCodeGen $ gets ptxDeviceProperties
dev <- liftCodeGen $ asks ptxDeviceProperties
tid <- threadIdx
A.quot integralType tid (A.liftInt32 (P.fromIntegral (CUDA.warpSize dev)))

Expand Down Expand Up @@ -245,7 +245,7 @@ __syncwarp = __syncwarp_mask (liftWord32 0xffffffff)
__syncwarp_mask :: HasCallStack => Operands Word32 -> CodeGen PTX ()
__syncwarp_mask mask = do
llvmver <- getLLVMversion
dev <- liftCodeGen $ gets ptxDeviceProperties
dev <- liftCodeGen $ asks ptxDeviceProperties
case (computeCapability dev >= Compute 7 0, llvmver >= 6) of
(True, True) -> void $ call (Lam primType (op primType mask) (Body VoidType (Just Tail) "llvm.nvvm.bar.warp.sync")) [NoUnwind, NoDuplicate, Convergent]
(True, False) -> internalError "LLVM-6.0 or above is required for Volta devices and later"
Expand Down Expand Up @@ -506,7 +506,7 @@ shfl_op
-> Operands a -- value to give
-> CodeGen PTX (Operands a) -- value received
shfl_op sop t delta val = do
dev <- liftCodeGen $ gets ptxDeviceProperties
dev <- liftCodeGen $ asks ptxDeviceProperties

let
-- The CUDA __shfl* instruction take an optional final parameter
Expand Down Expand Up @@ -762,7 +762,7 @@ makeOpenAcc
-> CodeGen PTX ()
-> CodeGen PTX (IROpenAcc PTX aenv a)
makeOpenAcc uid name param kernel = do
dev <- liftCodeGen $ gets ptxDeviceProperties
dev <- liftCodeGen $ asks ptxDeviceProperties
makeOpenAccWith (simpleLaunchConfig dev) uid name param kernel

-- | Create a single kernel program with the given launch analysis information.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ import LLVM.AST.Type.Representation
import qualified Foreign.CUDA.Analysis as CUDA

import Control.Monad ( (>=>) )
import Control.Monad.State ( gets )
import Control.Monad.Reader ( asks )
import Data.String ( fromString )
import Data.Bits as P
import Prelude as P
Expand Down Expand Up @@ -105,7 +105,7 @@ mkFoldAll
-> MIRDelayed PTX aenv (Vector e) -- ^ input data
-> CodeGen PTX (IROpenAcc PTX aenv (Scalar e))
mkFoldAll uid aenv tp combine mseed macc = do
dev <- liftCodeGen $ gets ptxDeviceProperties
dev <- liftCodeGen $ asks ptxDeviceProperties
foldr1 (+++) <$> sequence [ mkFoldAllS uid dev aenv tp combine mseed macc
, mkFoldAllM1 uid dev aenv tp combine macc
, mkFoldAllM2 uid dev aenv tp combine mseed
Expand Down Expand Up @@ -303,7 +303,7 @@ mkFoldDim
-> MIRDelayed PTX aenv (Array (sh, Int) e) -- ^ input data
-> CodeGen PTX (IROpenAcc PTX aenv (Array sh e))
mkFoldDim uid aenv repr@(ArrayR shr tp) combine mseed marr = do
dev <- liftCodeGen $ gets ptxDeviceProperties
dev <- liftCodeGen $ asks ptxDeviceProperties
--
let
(arrOut, paramOut) = mutableArray repr "out"
Expand Down
Loading
Loading