Skip to content

Commit ca6aeee

Browse files
author
Divesh Otwani
authored
Merge pull request #291 from tweag/dest-arr/no-newtype
Changed destination type to have less unsafe linear casts
2 parents 313142d + 6f81e64 commit ca6aeee

File tree

1 file changed

+27
-38
lines changed

1 file changed

+27
-38
lines changed

src/Data/Array/Destination.hs

Lines changed: 27 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
{-# LANGUAGE GADTs #-}
12
{-# LANGUAGE LinearTypes #-}
23
{-# LANGUAGE NoImplicitPrelude #-}
34
{-# LANGUAGE ScopedTypeVariables #-}
@@ -127,41 +128,38 @@ module Data.Array.Destination
127128
)
128129
where
129130

130-
import Control.Exception (evaluate)
131131
import Data.Vector (Vector, (!))
132132
import qualified Data.Vector as Vector
133133
import Data.Vector.Mutable (MVector)
134134
import qualified Data.Vector.Mutable as MVector
135135
import GHC.Exts (RealWorld)
136136
import qualified Prelude as Prelude
137-
import Prelude.Linear hiding (replicate)
138-
import System.IO.Unsafe
137+
import System.IO.Unsafe (unsafeDupablePerformIO)
139138
import GHC.Stack
139+
import Data.Unrestricted.Linear
140+
import Prelude.Linear hiding (replicate)
140141
import qualified Unsafe.Linear as Unsafe
141142

142143
-- | A destination array, or @DArray@, is a write-only array that is filled
143144
-- by some computation which ultimately returns an array.
144-
newtype DArray a = DArray (MVector RealWorld a)
145+
data DArray a where
146+
DArray :: MVector RealWorld a -> DArray a
145147

146148
-- XXX: use of Vector in types is temporary. I will probably move away from
147149
-- vectors and implement most stuff in terms of Array# and MutableArray#
148150
-- eventually, anyway. This would allow to move the MutableArray logic to
149151
-- linear IO, possibly, and segregate the unsafe casts to the Linear IO
150152
-- module. @`alloc` n k@ must be called with a non-negative value of @n@.
151153
alloc :: Int -> (DArray a %1-> ()) %1-> Vector a
152-
alloc n = Unsafe.toLinear unsafeAlloc
153-
where
154-
unsafeAlloc :: (DArray a %1-> ()) -> Vector a
155-
unsafeAlloc build = unsafeDupablePerformIO Prelude.$ do
156-
dest <- MVector.unsafeNew n
157-
evaluate (build (DArray dest))
158-
Vector.unsafeFreeze dest
154+
alloc n writer = (\(Ur dest, vec) -> writer (DArray dest) `lseq` vec) $
155+
unsafeDupablePerformIO Prelude.$ do
156+
destArray <- MVector.unsafeNew n
157+
vec <- Vector.unsafeFreeze destArray
158+
Prelude.return (Ur destArray, vec)
159159

160160
-- | Get the size of a destination array.
161161
size :: DArray a %1-> (Ur Int, DArray a)
162-
size (DArray vec) = Unsafe.toLinear go vec
163-
where
164-
go vec' = (Ur (MVector.length vec'), DArray vec')
162+
size (DArray mvec) = (Ur (MVector.length mvec), DArray mvec)
165163

166164
-- | Fill a destination array with a constant
167165
replicate :: a -> DArray a %1-> ()
@@ -170,34 +168,25 @@ replicate a = fromFunction (const a)
170168
-- | @fill a dest@ fills a singleton destination array.
171169
-- Caution, @'fill' a dest@ will fail is @dest@ isn't of length exactly one.
172170
fill :: HasCallStack => a %1-> DArray a %1-> ()
173-
fill = Unsafe.toLinear2 unsafeFill
174-
-- XXX: we will probably be able to spare this unsafe cast given a
175-
-- (linear) length function on destination.
176-
where
177-
unsafeFill a (DArray ds) =
178-
if MVector.length ds /= 1 then
179-
error "Destination.fill: requires a destination of size 1"
180-
else
181-
unsafeDupablePerformIO Prelude.$ MVector.write ds 0 a
171+
fill a (DArray mvec) =
172+
if MVector.length mvec /= 1
173+
then error "Destination.fill: requires a destination of size 1" $ a
174+
else a &
175+
Unsafe.toLinear (\x -> unsafeDupablePerformIO (MVector.write mvec 0 x))
182176

183177
-- | @dropEmpty dest@ consumes and empty array and fails otherwise.
184178
dropEmpty :: HasCallStack => DArray a %1-> ()
185-
dropEmpty = Unsafe.toLinear unsafeDrop where
186-
unsafeDrop :: DArray a -> ()
187-
unsafeDrop (DArray ds)
188-
| MVector.length ds > 0 = error "Destination.dropEmpty on non-empty array."
189-
| otherwise = ds `seq` ()
179+
dropEmpty (DArray mvec)
180+
| MVector.length mvec > 0 = error "Destination.dropEmpty on non-empty array."
181+
| otherwise = mvec `seq` ()
190182

191183
-- | @'split' n dest = (destl, destr)@ such as @destl@ has length @n@.
192184
--
193185
-- 'split' is total: if @n@ is larger than the length of @dest@, then
194186
-- @destr@ is empty.
195187
split :: Int -> DArray a %1-> (DArray a, DArray a)
196-
split n = Unsafe.toLinear unsafeSplit
197-
where
198-
unsafeSplit (DArray ds) =
199-
let (dsl, dsr) = MVector.splitAt n ds in
200-
(DArray dsl, DArray dsr)
188+
split n (DArray mvec) | (ml, mr) <- MVector.splitAt n mvec =
189+
(DArray ml, DArray mr)
201190

202191
-- | Fills the destination array with the contents of given vector.
203192
--
@@ -211,10 +200,10 @@ mirror v f arr =
211200

212201
-- | Fill a destination array using the given index-to-value function.
213202
fromFunction :: (Int -> b) -> DArray b %1-> ()
214-
fromFunction f = Unsafe.toLinear unsafeFromFunction
215-
where unsafeFromFunction (DArray ds) = unsafeDupablePerformIO Prelude.$ do
216-
let n = MVector.length ds
217-
Prelude.sequence_ [MVector.unsafeWrite ds m (f m) | m <- [0..n-1]]
218-
-- The unsafe cast here is actually safe, since getting the length does not
203+
fromFunction f (DArray mvec) = unsafeDupablePerformIO Prelude.$ do
204+
let n = MVector.length mvec
205+
Prelude.sequence_ [MVector.unsafeWrite mvec m (f m) | m <- [0..n-1]]
206+
-- The use of the mutable array is linear, since getting the length does not
219207
-- touch any elements, and each write fills in exactly one slot, so
220208
-- each slot of the destination array is filled.
209+

0 commit comments

Comments
 (0)