diff --git a/basis-library/arrays-and-vectors/array-aos-slice.sig b/basis-library/arrays-and-vectors/array-aos-slice.sig new file mode 100644 index 000000000..d3ed41dd0 --- /dev/null +++ b/basis-library/arrays-and-vectors/array-aos-slice.sig @@ -0,0 +1,48 @@ +signature ARRAY_AOS_SLICE = +sig + type 'a slice + + val all: ('a -> bool) -> 'a slice -> bool + val app : ('a -> unit) -> 'a slice -> unit + val appi: (int * 'a -> unit) -> 'a slice -> unit + val base: 'a slice -> 'a ArrayAos.t * int * int + val collate: ('a * 'a -> order) -> 'a slice * 'a slice -> order + val copy: {dst: 'a ArrayAos.t, di: int, src: 'a slice} -> unit + val copyVec: {dst: 'a ArrayAos.t, di: int, src: 'a VectorAosSlice.slice} -> unit + val exists: ('a -> bool) -> 'a slice -> bool + val find: ('a -> bool) -> 'a slice -> 'a option + val findi: (int * 'a -> bool) -> 'a slice -> (int * 'a) option + val foldl: ('a * 'b -> 'b) -> 'b -> 'a slice -> 'b + val foldli: (int * 'a * 'b -> 'b) -> 'b -> 'a slice -> 'b + val foldr: ('a * 'b -> 'b) -> 'b -> 'a slice -> 'b + val foldri: (int * 'a * 'b -> 'b) -> 'b -> 'a slice -> 'b + val full: 'a ArrayAos.t -> 'a slice + val getItem: 'a slice -> ('a * 'a slice) option + val isEmpty: 'a slice -> bool + val length: 'a slice -> int + val modify : ('a -> 'a) -> 'a slice -> unit + val modifyi: (int * 'a -> 'a) -> 'a slice -> unit + val slice: 'a ArrayAos.t * int * int option -> 'a slice + val sub: 'a slice * int -> 'a + val subslice: 'a slice * int * int option -> 'a slice + val update: 'a slice * int * 'a -> unit + val vector: 'a slice -> 'a VectorAos.vector +end + +signature ARRAY_AOS_SLICE_EXTRA = +sig + include ARRAY_AOS_SLICE + + val uninitIsNop: 'a slice -> bool + val uninit: 'a slice * int -> unit + val unsafeSub: 'a slice * int -> 'a + val unsafeCopy: {dst: 'a ArrayAos.t, di: int, src: 'a slice} -> unit + val unsafeCopyVec: {dst: 'a ArrayAos.t, di: int, src: 'a VectorAosSlice.slice} -> unit + val unsafeSlice: 'a ArrayAos.t * int * int option -> 'a slice + val unsafeSubslice: 'a slice * int * int option -> 'a slice + val unsafeUninit: 'a slice * int -> unit + val unsafeUpdate: 'a slice * int * 'a -> unit + + val concat: 'a slice list -> 'a ArrayAos.t + val toList: 'a slice -> 'a list +end diff --git a/basis-library/arrays-and-vectors/array-aos.sig b/basis-library/arrays-and-vectors/array-aos.sig new file mode 100644 index 000000000..e7d7228a7 --- /dev/null +++ b/basis-library/arrays-and-vectors/array-aos.sig @@ -0,0 +1,65 @@ +signature ARRAY_FLAT = +sig + type 'a array = 'a ArrayAos.t + type 'a vector = 'a VectorAos.vector + + val all: ('a -> bool) -> 'a array -> bool + val app: ('a -> unit) -> 'a array -> unit + val appi: (int * 'a -> unit) -> 'a array -> unit + val array: int * 'a -> 'a array + val collate: ('a * 'a -> order) -> 'a array * 'a array -> order + val copy: {src: 'a array, dst: 'a array, di: int} -> unit + val copyVec: {src: 'a vector, dst: 'a array, di: int} -> unit + val exists: ('a -> bool) -> 'a array -> bool + val find: ('a -> bool) -> 'a array -> 'a option + val findi: (int * 'a -> bool) -> 'a array -> (int * 'a) option + val foldl: ('a * 'b -> 'b) -> 'b -> 'a array -> 'b + val foldli: (int * 'a * 'b -> 'b) -> 'b -> 'a array -> 'b + val foldr: ('a * 'b -> 'b) -> 'b -> 'a array -> 'b + val foldri: (int * 'a * 'b -> 'b) -> 'b -> 'a array -> 'b + val fromList: 'a list -> 'a array + val length: 'a array -> int + val maxLen: int + val modify: ('a -> 'a) -> 'a array -> unit + val modifyi: (int * 'a -> 'a) -> 'a array -> unit + val sub: 'a array * int -> 'a + val tabulate: int * (int -> 'a) -> 'a array + val update: 'a array * int * 'a -> unit + val vector: 'a array -> 'a vector +end + +signature ARRAY_AOS_EXTRA = +sig + include ARRAY_FLAT + + structure ArraySlice: ARRAY_AOS_SLICE_EXTRA + + val alloc: int -> 'a array + val uninitIsNop: 'a array -> bool + val uninit: 'a array * int -> unit + val unsafeAlloc: int -> 'a array + val unsafeArray: int * 'a -> 'a array + val unsafeCopy: {dst: 'a array, di: int, src: 'a array} -> unit + val unsafeCopyVec: {dst: 'a array, di: int, src: 'a vector} -> unit + val unsafeSub: 'a array * int -> 'a + val unsafeUninit: 'a array * int -> unit + val unsafeUpdate: 'a array * int * 'a -> unit + + val concat: 'a array list -> 'a array + val duplicate: 'a array -> 'a array + val toList: 'a array -> 'a list + val unfoldi: int * 'b * (int * 'b -> 'a * 'b) -> 'a array * 'b + val unfold: int * 'b * ('b -> 'a * 'b) -> 'a array * 'b + + structure Raw: + sig + type 'a rawarr + val alloc: int -> 'a rawarr + val length: 'a rawarr -> int + val uninit: 'a rawarr * int -> unit + val uninitIsNop: 'a rawarr -> bool + val unsafeAlloc: int -> 'a rawarr + val unsafeToArray: 'a rawarr -> 'a array + val unsafeUninit: 'a rawarr * int -> unit + end +end diff --git a/basis-library/arrays-and-vectors/array-aos.sml b/basis-library/arrays-and-vectors/array-aos.sml new file mode 100644 index 000000000..dcdbee7d6 --- /dev/null +++ b/basis-library/arrays-and-vectors/array-aos.sml @@ -0,0 +1,63 @@ +structure ArrayAos: ARRAY_AOS_EXTRA = +struct + structure A = Sequence (Primitive.ArrayAos) + open A + + val op < = Int.< + val op <= = Int.<= + + fun wrap2 f = fn (i, x) => f (SeqIndex.toIntUnsafe i, x) + + type 'a array = 'a ArrayAos.t + type 'a vector = 'a VectorAos.vector + + structure ArraySlice = + struct + open Slice + val vector = Primitive.ArrayAos.Slice.vector + val copyVec = VectorAos.VectorSlice.copy + val unsafeCopyVec = VectorAos.VectorSlice.unsafeCopy + fun modifyi f sl = Primitive.ArrayAos.Slice.modifyi (wrap2 f) sl + val modify = Primitive.ArrayAos.Slice.modify + end + + val array = new + val unsafeArray = unsafeNew + val vector = Primitive.ArrayAos.vector + val copyVec = VectorAos.copy + val unsafeCopyVec = VectorAos.unsafeCopy + fun modifyi f sl = Primitive.ArrayAos.modifyi (wrap2 f) sl + val modify = Primitive.ArrayAos.modify + structure Raw = Primitive.ArrayAos.Raw + structure Raw = + struct + type 'a rawarr = 'a Raw.rawarr + + fun length a = + if Primitive.Controls.safe + then (SeqIndex.toInt (Raw.length a)) + handle Overflow => raise Fail "Raw.length" + else SeqIndex.toIntUnsafe (Raw.length a) + + fun alloc n = Raw.alloc (SeqIndex.fromIntForLength n) + fun unsafeAlloc n = Raw.unsafeAlloc (SeqIndex.fromIntUnsafe n) + + val uninitIsNop = Raw.uninitIsNop + fun unsafeUninit (a, i) = + Raw.unsafeUninit (a, SeqIndex.fromIntUnsafe i) + fun uninit (a, i) = + if Primitive.Controls.safe + then let + val i = + (SeqIndex.fromInt i) + handle Overflow => raise Subscript + in + Raw.uninit (a, i) + end + else unsafeUninit (a, i) + + val unsafeToArray = Primitive.ArrayAos.Raw.unsafeToArray + end +end + +structure ArrayAosSlice: ARRAY_AOS_SLICE_EXTRA = ArrayAos.ArraySlice \ No newline at end of file diff --git a/basis-library/arrays-and-vectors/sequence.fun b/basis-library/arrays-and-vectors/sequence.fun index 4340e4e0b..ab1ef89a0 100644 --- a/basis-library/arrays-and-vectors/sequence.fun +++ b/basis-library/arrays-and-vectors/sequence.fun @@ -66,7 +66,7 @@ structure SeqIndex = else fromIntUnsafe n end -functor Sequence (S: PRIM_SEQUENCE): SEQUENCE = +functor Sequence (S: PRIM_SEQUENCE): SEQUENCE where type 'a prim_array = 'a S.prim_array = struct val op +! = SeqIndex.+! val op +$ = SeqIndex.+$ @@ -83,6 +83,7 @@ functor Sequence (S: PRIM_SEQUENCE): SEQUENCE = type 'a sequence = 'a S.sequence type 'a elt = 'a S.elt + type 'a prim_array = 'a S.prim_array (* S.maxLen must be representable as an Int.int already *) val maxLen = SeqIndex.toInt S.maxLen @@ -130,6 +131,7 @@ functor Sequence (S: PRIM_SEQUENCE): SEQUENCE = type 'a sequence = 'a S.Slice.sequence type 'a elt = 'a S.Slice.elt type 'a slice = 'a S.Slice.slice + type 'a prim_array = 'a S.Slice.prim_array fun length sl = if Primitive.Controls.safe @@ -250,7 +252,7 @@ functor Sequence (S: PRIM_SEQUENCE): SEQUENCE = handle Overflow => raise Size) else (fn (x, s) => s +! S.Slice.length (toSlice x)) val n = List.foldl add 0 xs - val a = Primitive.Array.alloc n + val a = S.unsafeArrayAlloc n fun loop (di, xs) = case xs of [] => S.unsafeFromArray a @@ -281,7 +283,7 @@ functor Sequence (S: PRIM_SEQUENCE): SEQUENCE = else (fn (x, s) => (s +! sepn +! S.Slice.length (toSlice x))) val n = List.foldl add (S.Slice.length (toSlice x)) xs - val a = Primitive.Array.alloc n + val a = S.unsafeArrayAlloc n fun loop (di, xs) = case xs of [] => raise Fail "Sequence.Slice.concatWithGen" diff --git a/basis-library/arrays-and-vectors/sequence.sig b/basis-library/arrays-and-vectors/sequence.sig index c32c52afb..657fbe72b 100644 --- a/basis-library/arrays-and-vectors/sequence.sig +++ b/basis-library/arrays-and-vectors/sequence.sig @@ -12,9 +12,11 @@ signature SEQUENCE = sig type 'a sequence type 'a elt + type 'a prim_array structure Slice : SLICE where type 'a sequence = 'a sequence and type 'a elt = 'a elt + and type 'a prim_array = 'a prim_array val maxLen: int val length: 'a sequence -> int @@ -25,8 +27,8 @@ signature SEQUENCE = val uninitIsNop: 'a sequence -> bool val uninit: 'a sequence * int -> unit val unsafeUninit: 'a sequence * int -> unit - val copy: {dst: 'a elt Array.array, di: int, src: 'a sequence} -> unit - val unsafeCopy: {dst: 'a elt Array.array, di: int, src: 'a sequence} -> unit + val copy: {dst: 'a elt prim_array, di: int, src: 'a sequence} -> unit + val unsafeCopy: {dst: 'a elt prim_array, di: int, src: 'a sequence} -> unit val tabulate: int * (int -> 'a elt) -> 'a sequence val appi: (int * 'a elt -> unit) -> 'a sequence -> unit val app: ('a elt -> unit) -> 'a sequence -> unit diff --git a/basis-library/arrays-and-vectors/sequence0.sig b/basis-library/arrays-and-vectors/sequence0.sig index 53489cc46..954a2467b 100644 --- a/basis-library/arrays-and-vectors/sequence0.sig +++ b/basis-library/arrays-and-vectors/sequence0.sig @@ -12,9 +12,12 @@ signature PRIM_SEQUENCE = sig type 'a sequence type 'a elt + type 'a prim_array + type 'a prim_vector structure Slice: PRIM_SLICE where type 'a sequence = 'a sequence and type 'a elt = 'a elt + and type 'a prim_array = 'a prim_array val maxLen: SeqIndex.int (* Must also be representable as an Int.int *) val length: 'a sequence -> SeqIndex.int @@ -25,8 +28,8 @@ signature PRIM_SEQUENCE = val unsafeUninit: 'a sequence * SeqIndex.int -> unit val update: 'a sequence * SeqIndex.int * 'a elt -> unit val unsafeUpdate: 'a sequence * SeqIndex.int * 'a elt -> unit - val copy: {dst: 'a elt array, di: SeqIndex.int, src: 'a sequence} -> unit - val unsafeCopy: {dst: 'a elt array, di: SeqIndex.int, src: 'a sequence} -> unit + val copy: {dst: 'a elt prim_array, di: SeqIndex.int, src: 'a sequence} -> unit + val unsafeCopy: {dst: 'a elt prim_array, di: SeqIndex.int, src: 'a sequence} -> unit val tabulate: SeqIndex.int * (SeqIndex.int -> 'a elt) -> 'a sequence val appi: (SeqIndex.int * 'a elt -> unit) -> 'a sequence -> unit val app: ('a elt -> unit) -> 'a sequence -> unit @@ -59,5 +62,6 @@ signature PRIM_SEQUENCE = val unfoldi: SeqIndex.int * 'b * (SeqIndex.int * 'b -> 'a elt * 'b) -> 'a sequence * 'b val unfold: SeqIndex.int * 'b * ('b -> 'a elt * 'b) -> 'a sequence * 'b val unsafeAlloc: SeqIndex.int -> 'a sequence - val unsafeFromArray: 'a elt array -> 'a sequence + val unsafeFromArray: 'a elt prim_array -> 'a sequence + val unsafeArrayAlloc: SeqIndex.int -> 'a prim_array end diff --git a/basis-library/arrays-and-vectors/sequence0.sml b/basis-library/arrays-and-vectors/sequence0.sml index b5de28262..3c8552d3e 100644 --- a/basis-library/arrays-and-vectors/sequence0.sml +++ b/basis-library/arrays-and-vectors/sequence0.sml @@ -9,24 +9,38 @@ * See the file MLton-LICENSE for details. *) -functor PrimSequence (S: sig - type 'a sequence - type 'a elt - val copyUnsafe: 'a elt array * SeqIndex.int * 'a sequence * SeqIndex.int * SeqIndex.int -> unit - (* fromArray should be constant time. *) - val fromArray: 'a elt array -> 'a sequence - val isMutable: bool - val length: 'a sequence -> SeqIndex.int - val sameArray: 'a elt array * 'a sequence -> bool - val subUnsafe: 'a sequence * SeqIndex.int -> 'a elt - val uninitIsNop: 'a sequence -> bool - val uninitUnsafe: 'a sequence * SeqIndex.int -> unit - val updateUnsafe: 'a sequence * SeqIndex.int * 'a elt -> unit - end) :> PRIM_SEQUENCE where type 'a sequence = 'a S.sequence - where type 'a elt = 'a S.elt = - struct - structure Array = Primitive.Array +signature PRIM_SEQUENCE_PARAMS = +sig + type 'a sequence + type 'a elt + type 'a prim_array + type 'a prim_vector + val copyUnsafe: 'a elt prim_array * SeqIndex.int * 'a sequence * SeqIndex.int * SeqIndex.int -> unit + (* fromArray should be constant time. *) + val fromArray: 'a elt prim_array -> 'a sequence + val isMutable: bool + val length: 'a sequence -> SeqIndex.int + val sameArray: 'a elt prim_array * 'a sequence -> bool + val subUnsafe: 'a sequence * SeqIndex.int -> 'a elt + val uninitIsNop: 'a sequence -> bool + val uninitUnsafe: 'a sequence * SeqIndex.int -> unit + val updateUnsafe: 'a sequence * SeqIndex.int * 'a elt -> unit + + val primArrayAllocUnsafe: SeqIndex.int -> 'a prim_array + val primArraySubUnsafe: 'a prim_array * SeqIndex.int -> 'a + val primArrayUpdateUnsafe: 'a prim_array * SeqIndex.int * 'a -> unit + val primArrayLength: 'a prim_array -> SeqIndex.int + val primVectorFromArrayUnsafe: 'a prim_array -> 'a prim_vector +end + + +functor PrimSequence (S: PRIM_SEQUENCE_PARAMS) + :> PRIM_SEQUENCE where type 'a sequence = 'a S.sequence + where type 'a elt = 'a S.elt + where type 'a prim_array = 'a S.prim_array + where type 'a prim_vector = 'a S.prim_vector = + struct val op +! = SeqIndex.+! val op + = SeqIndex.+ val op -! = SeqIndex.-! @@ -42,6 +56,8 @@ functor PrimSequence (S: sig type 'a sequence = 'a S.sequence type 'a elt = 'a S.elt + type 'a prim_array = 'a S.prim_array + type 'a prim_vector = 'a S.prim_vector local fun valOf x: Primitive.Int32.int = case x of SOME y => y | NONE => 0 @@ -71,7 +87,7 @@ functor PrimSequence (S: sig fun length s = S.length s - fun unsafeArrayAlloc n = Array.allocUnsafe n + fun unsafeArrayAlloc n = S.primArrayAllocUnsafe n fun arrayAlloc n = if Primitive.Controls.safe andalso gtu (n, maxLen) @@ -90,18 +106,18 @@ functor PrimSequence (S: sig if Primitive.Controls.safe andalso geu (i, !subLim) then raise Subscript else - Array.subUnsafe (a, i) + S.primArraySubUnsafe (a, i) val updateLim : SeqIndex.t ref = ref 0 fun update (i, x) = if Primitive.Controls.safe andalso geu (i, !updateLim) then if i = !updateLim andalso i < n then - (Array.updateUnsafe (a, i, x); + (S.primArrayUpdateUnsafe (a, i, x); subLim := i + 1; updateLim := i + 1) else raise Subscript else - Array.updateUnsafe (a, i, x) + S.primArrayUpdateUnsafe (a, i, x) val gotIt = ref false fun done () = if !gotIt then @@ -128,7 +144,7 @@ functor PrimSequence (S: sig else let val (x, b) = f (i, b) - val () = Array.updateUnsafe (a, i, x) + val () = S.primArrayUpdateUnsafe (a, i, x) in loop (i +! 1, b) end @@ -148,6 +164,7 @@ functor PrimSequence (S: sig struct type 'a sequence = 'a sequence type 'a elt = 'a elt + type 'a prim_array = 'a prim_array datatype 'a t = T of {seq: 'a sequence, start: SeqIndex.int, len: SeqIndex.int} type 'a slice = 'a t @@ -174,12 +191,12 @@ functor PrimSequence (S: sig then raise Subscript else unsafeUninit (sl, i) local - fun smallCopy {dst: 'a elt array, di: SeqIndex.int, + fun smallCopy {dst: 'a elt prim_array, di: SeqIndex.int, src: 'a sequence, si: SeqIndex.int, len: SeqIndex.int, overlap: unit -> bool} = let - fun move i = Array.updateUnsafe (dst, di +! i, S.subUnsafe (src, si +! i)) + fun move i = S.primArrayUpdateUnsafe (dst, di +! i, S.subUnsafe (src, si +! i)) val len = len -! 1 in if overlap () @@ -201,7 +218,7 @@ functor PrimSequence (S: sig end end val smallCopyLimit = 5 - fun maybeSmallCopy {dst: 'a elt array, di: SeqIndex.int, + fun maybeSmallCopy {dst: 'a elt prim_array, di: SeqIndex.int, src: 'a sequence, si: SeqIndex.int, len: SeqIndex.int, overlap: unit -> bool} = @@ -212,17 +229,17 @@ functor PrimSequence (S: sig overlap = overlap} else S.copyUnsafe (dst, di, src, si, len) in - fun unsafeCopy {dst: 'a elt array, di: SeqIndex.int, + fun unsafeCopy {dst: 'a elt prim_array, di: SeqIndex.int, src = T {seq = src, start = si, len}} = maybeSmallCopy {dst = dst, di = di, src = src, si = si, len = len, overlap = fn () => false} - fun copy {dst: 'a elt array, di: SeqIndex.int, + fun copy {dst: 'a elt prim_array, di: SeqIndex.int, src = T {seq = src, start = si, len}} = if Primitive.Controls.safe - andalso (gtu (di, Array.length dst) - orelse gtu (di +! len, Array.length dst)) + andalso (gtu (di, S.primArrayLength dst) + orelse gtu (di +! len, S.primArrayLength dst)) then raise Subscript else let fun overlap () = @@ -474,104 +491,224 @@ functor PrimSequence (S: sig end -structure Primitive = struct -open Primitive -structure Array = - struct - local - structure P = PrimSequence (type 'a sequence = 'a array - type 'a elt = 'a - val sameArray = op = - val copyUnsafe = Primitive.Array.copyArrayUnsafe - val fromArray = fn a => a - val isMutable = true - val length = Primitive.Array.length - val subUnsafe = Primitive.Array.subUnsafe - val uninitIsNop = Primitive.Array.uninitIsNop - val uninitUnsafe = Primitive.Array.uninitUnsafe - val updateUnsafe = Primitive.Array.updateUnsafe) - in - open P - type 'a array = 'a array - structure Slice = - struct - open Slice - fun vector sl = - let - val a = unsafeAlloc (length sl) - val () = unsafeCopy {dst = a, di = 0, src = sl} - in - Vector.fromArrayUnsafe a - end - fun modifyi f sl = - appi (fn (i, x) => unsafeUpdate (sl, i, f (i, x))) sl - fun modify f sl = modifyi (fn (_, x) => f x) sl +signature PRIM_RAW_ARRAY = +sig + type 'a array + type 'a rawarr + val allocUnsafe: SeqIndex.int -> 'a rawarr + val length: 'a rawarr -> SeqIndex.int + val toArrayUnsafe: 'a rawarr -> 'a array + val uninitIsNop: 'a rawarr -> bool + val uninitUnsafe: 'a rawarr * SeqIndex.int -> unit +end + + +functor WrapRawArray(S: sig + include PRIM_RAW_ARRAY + val maxLen: SeqIndex.int + end) = +struct + type 'a rawarr = 'a S.rawarr + + val length = S.length + + val unsafeAlloc = S.allocUnsafe + fun alloc n = + if Primitive.Controls.safe + andalso SeqIndex.gtu (n, S.maxLen) + then raise Size + else unsafeAlloc n + + val unsafeToArray = S.toArrayUnsafe + + val uninitIsNop = S.uninitIsNop + val unsafeUninit = S.uninitUnsafe + fun uninit (a, i) = + if Primitive.Controls.safe andalso SeqIndex.geu (i, length a) + then raise Subscript + else unsafeUninit (a, i) +end + + +functor MakeArrayPrimSequence (S: sig + include PRIM_SEQUENCE_PARAMS + structure Raw: PRIM_RAW_ARRAY + where type 'a array = 'a prim_array + end) = +struct + + local + structure P = PrimSequence(S) + in + open P + type 'a array = 'a prim_array + type 'a t = 'a prim_array + type 'a vector = 'a prim_vector + structure Slice = + struct + open Slice + fun vector sl = + let + val a = unsafeArrayAlloc (length sl) + val () = unsafeCopy {dst = a, di = 0, src = sl} + in + S.primVectorFromArrayUnsafe a end - fun vector s = Slice.vector (Slice.full s) - fun modifyi f s = Slice.modifyi f (Slice.full s) - fun modify f s = Slice.modify f (Slice.full s) + fun modifyi f sl = + appi (fn (i, x) => unsafeUpdate (sl, i, f (i, x))) sl + fun modify f sl = modifyi (fn (_, x) => f x) sl end - structure Raw = - struct - type 'a rawarr = 'a Primitive.Array.Raw.rawarr + fun vector s = Slice.vector (Slice.full s) + fun modifyi f s = Slice.modifyi f (Slice.full s) + fun modify f s = Slice.modify f (Slice.full s) + end - val length = Primitive.Array.Raw.length + structure Raw = WrapRawArray(open S.Raw val maxLen = maxLen) - val unsafeAlloc = Primitive.Array.Raw.allocUnsafe - fun alloc n = - if Primitive.Controls.safe - andalso SeqIndex.gtu (n, maxLen) - then raise Size - else unsafeAlloc n +end - val unsafeToArray = Primitive.Array.Raw.toArrayUnsafe - val uninitIsNop = Primitive.Array.Raw.uninitIsNop - val unsafeUninit = Primitive.Array.Raw.uninitUnsafe - fun uninit (a, i) = - if Primitive.Controls.safe andalso SeqIndex.geu (i, length a) - then raise Subscript - else unsafeUninit (a, i) - end - end +functor MakeVectorPrimSequence (S: PRIM_SEQUENCE_PARAMS) = +struct + local + structure P = PrimSequence(S) + in + open P + type 'a vector = 'a prim_vector + type 'a t = 'a prim_vector + type 'a array = 'a prim_array + fun updateVector (v, i, x) = + if Primitive.Controls.safe andalso SeqIndex.geu (i, length v) + then raise Subscript + else let + val a = S.primArrayAllocUnsafe (length v) + val () = copy {dst = a, di = 0, src = v} + val () = S.primArrayUpdateUnsafe (a, i, x) + in + S.primVectorFromArrayUnsafe a + end + end +end -structure Vector = - struct - local - exception Vector_uninitIsNop - exception Vector_uninitUnsafe - exception Vector_updateUnsafe - structure P = PrimSequence (type 'a sequence = 'a vector - type 'a elt = 'a - val copyUnsafe = Primitive.Array.copyVectorUnsafe - val fromArray = Primitive.Vector.fromArrayUnsafe - val isMutable = false - val length = Vector.length - val sameArray = fn _ => false - val subUnsafe = Primitive.Vector.subUnsafe - val uninitIsNop = fn _ => - raise Vector_uninitIsNop - val uninitUnsafe = fn _ => - raise Vector_uninitUnsafe - val updateUnsafe = fn _ => - raise Vector_updateUnsafe) - in - open P - type 'a vector = 'a vector - fun updateVector (v, i, x) = - if Primitive.Controls.safe andalso SeqIndex.geu (i, length v) - then raise Subscript - else let - val a = Array.unsafeAlloc (length v) - val () = copy {dst = a, di = 0, src = v} - val () = Array.unsafeUpdate (a, i, x) - in - unsafeFromArray a - end - end - end + +structure Primitive = struct +open Primitive + +structure Array = MakeArrayPrimSequence( + type 'a sequence = 'a array + type 'a elt = 'a + type 'a prim_array = 'a array + type 'a prim_vector = 'a vector + val sameArray = op = + val copyUnsafe = Primitive.Array.copyArrayUnsafe + val fromArray = fn a => a + val isMutable = true + val length = Primitive.Array.length + val subUnsafe = Primitive.Array.subUnsafe + val uninitIsNop = Primitive.Array.uninitIsNop + val uninitUnsafe = Primitive.Array.uninitUnsafe + val updateUnsafe = Primitive.Array.updateUnsafe + + val primArrayAllocUnsafe = Primitive.Array.allocUnsafe + val primArraySubUnsafe = Primitive.Array.subUnsafe + val primArrayUpdateUnsafe = Primitive.Array.updateUnsafe + val primArrayLength = Primitive.Array.length + val primVectorFromArrayUnsafe = Primitive.Vector.fromArrayUnsafe + + structure Raw = + struct + type 'a array = 'a array + open Primitive.Array.Raw + end +) + +structure ArrayAos = MakeArrayPrimSequence( + type 'a sequence = 'a ArrayAos.t + type 'a elt = 'a + type 'a prim_array = 'a ArrayAos.t + type 'a prim_vector = 'a VectorAos.t + val sameArray = op = + val copyUnsafe = Primitive.ArrayAos.copyArrayUnsafe + val fromArray = fn a => a + val isMutable = true + val length = Primitive.ArrayAos.length + val subUnsafe = Primitive.ArrayAos.subUnsafe + val uninitIsNop = Primitive.ArrayAos.uninitIsNop + val uninitUnsafe = Primitive.ArrayAos.uninitUnsafe + val updateUnsafe = Primitive.ArrayAos.updateUnsafe + + val primArrayAllocUnsafe = Primitive.ArrayAos.allocUnsafe + val primArraySubUnsafe = Primitive.ArrayAos.subUnsafe + val primArrayUpdateUnsafe = Primitive.ArrayAos.updateUnsafe + val primArrayLength = Primitive.ArrayAos.length + val primVectorFromArrayUnsafe = Primitive.VectorAos.fromArrayUnsafe + + structure Raw = + struct + type 'a array = 'a ArrayAos.t + open Primitive.ArrayAos.Raw + end +) + + +structure Vector = MakeVectorPrimSequence( + exception Vector_uninitIsNop + exception Vector_uninitUnsafe + exception Vector_updateUnsafe + type 'a sequence = 'a vector + type 'a elt = 'a + type 'a prim_array = 'a array + type 'a prim_vector = 'a vector + val copyUnsafe = Primitive.Array.copyVectorUnsafe + val fromArray = Primitive.Vector.fromArrayUnsafe + val isMutable = false + val length = Vector.length + val sameArray = fn _ => false + val subUnsafe = Primitive.Vector.subUnsafe + val uninitIsNop = fn _ => + raise Vector_uninitIsNop + val uninitUnsafe = fn _ => + raise Vector_uninitUnsafe + val updateUnsafe = fn _ => + raise Vector_updateUnsafe + + val primArrayAllocUnsafe = Primitive.Array.allocUnsafe + val primArraySubUnsafe = Primitive.Array.subUnsafe + val primArrayUpdateUnsafe = Primitive.Array.updateUnsafe + val primArrayLength = Primitive.Array.length + val primVectorFromArrayUnsafe = Primitive.Vector.fromArrayUnsafe +) + + +structure VectorAos = MakeVectorPrimSequence( + exception VectorAos_uninitIsNop + exception VectorAos_uninitUnsafe + exception VectorAos_updateUnsafe + type 'a sequence = 'a VectorAos.t + type 'a elt = 'a + type 'a prim_array = 'a ArrayAos.t + type 'a prim_vector = 'a VectorAos.t + val copyUnsafe = Primitive.ArrayAos.copyVectorUnsafe + val fromArray = Primitive.VectorAos.fromArrayUnsafe + val isMutable = false + val length = VectorAos.length + val sameArray = fn _ => false + val subUnsafe = Primitive.VectorAos.subUnsafe + val uninitIsNop = fn _ => + raise VectorAos_uninitIsNop + val uninitUnsafe = fn _ => + raise VectorAos_uninitUnsafe + val updateUnsafe = fn _ => + raise VectorAos_updateUnsafe + + val primArrayAllocUnsafe = Primitive.ArrayAos.allocUnsafe + val primArraySubUnsafe = Primitive.ArrayAos.subUnsafe + val primArrayUpdateUnsafe = Primitive.ArrayAos.updateUnsafe + val primArrayLength = Primitive.ArrayAos.length + val primVectorFromArrayUnsafe = Primitive.VectorAos.fromArrayUnsafe +) end @@ -583,3 +720,12 @@ structure Vector = struct type 'a vector = 'a vector end + +structure ArrayAos = + struct + type 'a t = 'a Primitive.ArrayAos.t + end +structure VectorAos = + struct + type 'a t = 'a Primitive.VectorAos.t + end \ No newline at end of file diff --git a/basis-library/arrays-and-vectors/slice.sig b/basis-library/arrays-and-vectors/slice.sig index 65529fd55..8d387ba5e 100644 --- a/basis-library/arrays-and-vectors/slice.sig +++ b/basis-library/arrays-and-vectors/slice.sig @@ -13,6 +13,7 @@ signature SLICE = type 'a sequence type 'a elt type 'a slice + type 'a prim_array val length: 'a slice -> int val sub: 'a slice * int -> 'a elt val unsafeSub: 'a slice * int -> 'a elt @@ -21,8 +22,8 @@ signature SLICE = val uninitIsNop: 'a slice -> bool val uninit: 'a slice * int -> unit val unsafeUninit: 'a slice * int -> unit - val copy: {dst: 'a elt Array.array, di: int, src: 'a slice} -> unit - val unsafeCopy: {dst: 'a elt Array.array, di: int, src: 'a slice} -> unit + val copy: {dst: 'a elt prim_array, di: int, src: 'a slice} -> unit + val unsafeCopy: {dst: 'a elt prim_array, di: int, src: 'a slice} -> unit val full: 'a sequence -> 'a slice val slice: 'a sequence * int * int option -> 'a slice val unsafeSlice: 'a sequence * int * int option -> 'a slice diff --git a/basis-library/arrays-and-vectors/slice0.sig b/basis-library/arrays-and-vectors/slice0.sig index 92d687660..744b6343c 100644 --- a/basis-library/arrays-and-vectors/slice0.sig +++ b/basis-library/arrays-and-vectors/slice0.sig @@ -13,6 +13,7 @@ signature PRIM_SLICE = type 'a sequence type 'a elt type 'a slice + type 'a prim_array val length: 'a slice -> SeqIndex.int val sub: 'a slice * SeqIndex.int -> 'a elt val unsafeSub: 'a slice * SeqIndex.int -> 'a elt @@ -21,8 +22,8 @@ signature PRIM_SLICE = val uninitIsNop: 'a slice -> bool val uninit: 'a slice * SeqIndex.int -> unit val unsafeUninit: 'a slice * SeqIndex.int -> unit - val copy: {dst: 'a elt array, di: SeqIndex.int, src: 'a slice} -> unit - val unsafeCopy: {dst: 'a elt array, di: SeqIndex.int, src: 'a slice} -> unit + val copy: {dst: 'a elt prim_array, di: SeqIndex.int, src: 'a slice} -> unit + val unsafeCopy: {dst: 'a elt prim_array, di: SeqIndex.int, src: 'a slice} -> unit val full: 'a sequence -> 'a slice val slice: 'a sequence * SeqIndex.int * SeqIndex.int option -> 'a slice val unsafeSlice: 'a sequence * SeqIndex.int * SeqIndex.int option -> 'a slice diff --git a/basis-library/arrays-and-vectors/vector-aos-slice.sig b/basis-library/arrays-and-vectors/vector-aos-slice.sig new file mode 100644 index 000000000..3dee44dfb --- /dev/null +++ b/basis-library/arrays-and-vectors/vector-aos-slice.sig @@ -0,0 +1,65 @@ +signature VECTOR_AOS_SLICE = +sig + type 'a slice + + val length: 'a slice -> int + val sub: 'a slice * int -> 'a + val full: 'a VectorAos.t -> 'a slice + val slice: 'a VectorAos.t * int * int option -> 'a slice + val subslice: 'a slice * int * int option -> 'a slice + val base: 'a slice -> 'a VectorAos.t * int * int + val vector: 'a slice -> 'a VectorAos.t + val concat: 'a slice list -> 'a VectorAos.t + val isEmpty: 'a slice -> bool + val getItem: 'a slice -> ('a * 'a slice) option + val appi: (int * 'a -> unit) -> 'a slice -> unit + val app: ('a -> unit) -> 'a slice -> unit + val mapi: (int * 'a -> 'b) -> 'a slice -> 'b VectorAos.t + val map: ('a -> 'b) -> 'a slice -> 'b VectorAos.t + val foldli: (int * 'a * 'b -> 'b) -> 'b -> 'a slice -> 'b + val foldl: ('a * 'b -> 'b) -> 'b -> 'a slice -> 'b + val foldri: (int * 'a * 'b -> 'b) -> 'b -> 'a slice -> 'b + val foldr: ('a * 'b -> 'b) -> 'b -> 'a slice -> 'b + val findi: (int * 'a -> bool) -> 'a slice -> (int * 'a) option + val find: ('a -> bool) -> 'a slice -> 'a option + val exists: ('a -> bool) -> 'a slice -> bool + val all: ('a -> bool) -> 'a slice -> bool + val collate: ('a * 'a -> order) -> 'a slice * 'a slice -> order +end + +signature VECTOR_AOS_SLICE_EXTRA = +sig + include VECTOR_AOS_SLICE + + val copy: {dst: 'a ArrayAos.t, di: int, src: 'a slice} -> unit + + val unsafeSub: 'a slice * int -> 'a + val unsafeCopy: {dst: 'a ArrayAos.t, di: int, src: 'a slice} -> unit + val unsafeSlice: 'a VectorAos.t * int * int option -> 'a slice + val unsafeSubslice: 'a slice * int * int option -> 'a slice + + (* Used to implement Substring/String functions *) + val concatWith: 'a VectorAos.t -> 'a slice list -> 'a VectorAos.t + val triml: int -> 'a slice -> 'a slice + val trimr: int -> 'a slice -> 'a slice + val isPrefix: ('a * 'a -> bool) -> 'a VectorAos.t -> 'a slice -> bool + val isSubvector: ('a * 'a -> bool) -> 'a VectorAos.t -> 'a slice -> bool + val isSuffix: ('a * 'a -> bool) -> 'a VectorAos.t -> 'a slice -> bool + val splitl: ('a -> bool) -> 'a slice -> 'a slice * 'a slice + val splitr: ('a -> bool) -> 'a slice -> 'a slice * 'a slice + val splitAt: 'a slice * int -> 'a slice * 'a slice + val dropl: ('a -> bool) -> 'a slice -> 'a slice + val dropr: ('a -> bool) -> 'a slice -> 'a slice + val takel: ('a -> bool) -> 'a slice -> 'a slice + val taker: ('a -> bool) -> 'a slice -> 'a slice + val position: ('a * 'a -> bool) + -> 'a VectorAos.t + -> 'a slice + -> 'a slice * 'a slice + val span: ''a slice * ''a slice -> ''a slice + val translate: ('a -> 'b VectorAos.t) -> 'a slice -> 'b VectorAos.t + val tokens: ('a -> bool) -> 'a slice -> 'a slice list + val fields: ('a -> bool) -> 'a slice -> 'a slice list + + val toList: 'a slice -> 'a list +end diff --git a/basis-library/arrays-and-vectors/vector-aos.sig b/basis-library/arrays-and-vectors/vector-aos.sig new file mode 100644 index 000000000..f2ea00d98 --- /dev/null +++ b/basis-library/arrays-and-vectors/vector-aos.sig @@ -0,0 +1,56 @@ +signature VECTOR_FLAT = +sig + type 'a vector = 'a VectorAos.t + + val maxLen: int + val fromList: 'a list -> 'a vector + val tabulate: int * (int -> 'a) -> 'a vector + val length: 'a vector -> int + val sub: 'a vector * int -> 'a + val update: 'a vector * int * 'a -> 'a vector + val concat: 'a vector list -> 'a vector + val appi: (int * 'a -> unit) -> 'a vector -> unit + val app: ('a -> unit) -> 'a vector -> unit + val mapi: (int * 'a -> 'b) -> 'a vector -> 'b vector + val map: ('a -> 'b) -> 'a vector -> 'b vector + val foldli: (int * 'a * 'b -> 'b) -> 'b -> 'a vector -> 'b + val foldri: (int * 'a * 'b -> 'b) -> 'b -> 'a vector -> 'b + val foldl: ('a * 'b -> 'b) -> 'b -> 'a vector -> 'b + val foldr: ('a * 'b -> 'b) -> 'b -> 'a vector -> 'b + val findi: (int * 'a -> bool) -> 'a vector -> (int * 'a) option + val find: ('a -> bool) -> 'a vector -> 'a option + val exists: ('a -> bool) -> 'a vector -> bool + val all: ('a -> bool) -> 'a vector -> bool + val collate: ('a * 'a -> order) -> 'a vector * 'a vector -> order +end + +signature VECTOR_AOS_EXTRA = +sig + include VECTOR_FLAT + structure VectorSlice: VECTOR_AOS_SLICE_EXTRA + + val copy: {dst: 'a ArrayAos.t, di: int, src: 'a vector} -> unit + + val unsafeFromArray: 'a ArrayAos.t -> 'a vector + val unsafeSub: 'a vector * int -> 'a + val unsafeCopy: {dst: 'a ArrayAos.t, di: int, src: 'a vector} -> unit + + (* Used to implement Substring/String functions *) + val concatWith: 'a vector -> 'a vector list -> 'a vector + val isPrefix: ('a * 'a -> bool) -> 'a vector -> 'a vector -> bool + val isSubvector: ('a * 'a -> bool) -> 'a vector -> 'a vector -> bool + val isSuffix: ('a * 'a -> bool) -> 'a vector -> 'a vector -> bool + val translate: ('a -> 'b vector) -> 'a vector -> 'b vector + val tokens: ('a -> bool) -> 'a vector -> 'a vector list + val fields: ('a -> bool) -> 'a vector -> 'a vector list + + val append: 'a vector * 'a vector -> 'a vector + val create: + int + -> {done: unit -> 'a vector, sub: int -> 'a, update: int * 'a -> unit} + val duplicate: 'a vector -> 'a vector + val toList: 'a vector -> 'a list + val unfoldi: int * 'b * (int * 'b -> 'a * 'b) -> 'a vector * 'b + val unfold: int * 'b * ('b -> 'a * 'b) -> 'a vector * 'b + val vector: int * 'a -> 'a vector +end diff --git a/basis-library/arrays-and-vectors/vector-aos.sml b/basis-library/arrays-and-vectors/vector-aos.sml new file mode 100644 index 000000000..438b8a0fe --- /dev/null +++ b/basis-library/arrays-and-vectors/vector-aos.sml @@ -0,0 +1,31 @@ +structure VectorAos: VECTOR_AOS_EXTRA = +struct + structure V = Sequence(Primitive.VectorAos) + open V + + type 'a vector = 'a VectorAos.t + + structure VectorSlice = + struct + open Slice + type 'a vector = 'a vector + val vector = sequence + + val isSubvector = isSubsequence + val span = fn (sl, sl') => + Primitive.VectorAos.Slice.span (op= : ''a vector * ''a vector -> bool) + (sl, sl') + end + + fun update (v, i, x) = + (Primitive.VectorAos.updateVector (v, SeqIndex.fromInt i, x)) + handle Overflow => raise Subscript + + val isSubvector = isSubsequence + + val unsafeFromArray = Primitive.VectorAos.unsafeFromArray + + val vector = new +end + +structure VectorAosSlice: VECTOR_AOS_SLICE_EXTRA = VectorAos.VectorSlice diff --git a/basis-library/build/sources.mlb b/basis-library/build/sources.mlb index 99e8a5dc3..251a2c36d 100644 --- a/basis-library/build/sources.mlb +++ b/basis-library/build/sources.mlb @@ -105,9 +105,15 @@ in ../arrays-and-vectors/vector-slice.sig ../arrays-and-vectors/vector.sig ../arrays-and-vectors/vector.sml + ../arrays-and-vectors/vector-aos-slice.sig + ../arrays-and-vectors/vector-aos.sig + ../arrays-and-vectors/vector-aos.sml ../arrays-and-vectors/array-slice.sig ../arrays-and-vectors/array.sig ../arrays-and-vectors/array.sml + ../arrays-and-vectors/array-aos-slice.sig + ../arrays-and-vectors/array-aos.sig + ../arrays-and-vectors/array-aos.sml ../arrays-and-vectors/array2.sig ../arrays-and-vectors/array2.sml ../arrays-and-vectors/mono-vector-slice.sig diff --git a/basis-library/mpl/mpl.sig b/basis-library/mpl/mpl.sig index 19fa6b560..dde1c384e 100644 --- a/basis-library/mpl/mpl.sig +++ b/basis-library/mpl/mpl.sig @@ -8,4 +8,10 @@ signature MPL = sig structure File: MPL_FILE structure GC: MPL_GC + + structure ArrayAos: ARRAY_AOS_EXTRA + structure ArrayAosSlice: ARRAY_AOS_SLICE_EXTRA + + structure VectorAos: VECTOR_AOS_EXTRA + structure VectorAosSlice: VECTOR_AOS_SLICE_EXTRA end diff --git a/basis-library/mpl/mpl.sml b/basis-library/mpl/mpl.sml index f7ef98161..3fc8cb609 100644 --- a/basis-library/mpl/mpl.sml +++ b/basis-library/mpl/mpl.sml @@ -4,8 +4,14 @@ * See the file MLton-LICENSE for details. *) -structure MPL :> MPL = +structure MPL: MPL = struct structure File = MPLFile structure GC = MPLGC + + structure ArrayAos = ArrayAos + structure ArrayAosSlice = ArrayAosSlice + + structure VectorAos = VectorAos + structure VectorAosSlice = VectorAosSlice end diff --git a/basis-library/primitive/prim-basis.sml b/basis-library/primitive/prim-basis.sml index f67e65ab0..004f45232 100644 --- a/basis-library/primitive/prim-basis.sml +++ b/basis-library/primitive/prim-basis.sml @@ -49,6 +49,17 @@ structure Vector = type 'a vector = 'a t end +structure ArrayAos = + struct + type 'a t = 'a array_aos + type 'a array_aos = 'a t + end +structure VectorAos = + struct + type 'a t = 'a vector_aos + type 'a vector_aos = 'a t + end + (* Primitive Basis (Primitive Types) *) structure Char8 = struct diff --git a/basis-library/primitive/prim-seq.sml b/basis-library/primitive/prim-seq.sml index 0c4e527b1..94808ff61 100644 --- a/basis-library/primitive/prim-seq.sml +++ b/basis-library/primitive/prim-seq.sml @@ -57,4 +57,63 @@ structure Vector = val vector0 = _prim "Vector_vector": unit -> 'a vector; end +(* ---------------------------------------------------------------------------- + * Flattened representations of arrays and vectors + * + * SAM_NOTE: plenty of duplication with the Array and Vector structures above. + * Personally I prefer it this way for clarity, because in general we don't + * want to assume that the {array, vector} types support the same operations as + * {array_aos, vector_aos}. + * + * SAM_NOTE: there is a bit of ad-hoc polymorphism going on here: many of + * these array primitives are essentially polymorphic w.r.t. array vs array_aos + * (and vector vs vector_aos). You'll see that the names of the primitives + * are identical in almost all cases... + * ---------------------------------------------------------------------------- + *) + +structure ArrayAos = + struct + open ArrayAos + val allocUnsafe = _prim "Array_allocAos": SeqIndex.int -> 'a array_aos; + val copyArrayUnsafe = _prim "Array_copyArray": 'a array_aos * SeqIndex.int * 'a array_aos * SeqIndex.int * SeqIndex.int -> unit; + val copyVectorUnsafe = _prim "Array_copyVector": 'a array_aos * SeqIndex.int * 'a VectorAos.t * SeqIndex.int * SeqIndex.int -> unit; + val length = _prim "Array_length": 'a array_aos -> SeqIndex.int; + (* There is no maximum length on arrays, so maxLen' = SeqIndex.maxInt'. *) + (* val maxLen': SeqIndex.int = SeqIndex.maxInt' *) + val subUnsafe = _prim "Array_sub": 'a array_aos * SeqIndex.int -> 'a; + val uninitIsNop = _prim "Array_uninitIsNop": 'a array_aos -> bool; + val uninitUnsafe = _prim "Array_uninit": 'a array_aos * SeqIndex.int -> unit; + val updateUnsafe = _prim "Array_update": 'a array_aos * SeqIndex.int * 'a -> unit; + + structure Raw :> sig + type 'a rawarr + val allocUnsafe: SeqIndex.int -> 'a rawarr + val length: 'a rawarr -> SeqIndex.int + val toArrayUnsafe: 'a rawarr -> 'a array_aos + val uninitIsNop: 'a rawarr -> bool + val uninitUnsafe: 'a rawarr * SeqIndex.int -> unit + end = + struct + type 'a rawarr = 'a array_aos + val allocUnsafe = _prim "Array_allocRawAos": SeqIndex.int -> 'a rawarr; + val length = length + val toArrayUnsafe = _prim "Array_toArray": 'a rawarr -> 'a array_aos; + val uninitIsNop = uninitIsNop + val uninitUnsafe = uninitUnsafe + end + end + +structure VectorAos = + struct + open VectorAos + (* Don't mutate the array after you apply fromArray, because vectors + * are supposed to be immutable and the optimizer depends on this. + *) + val fromArrayUnsafe = _prim "Array_toVector": 'a ArrayAos.t -> 'a vector_aos; + val length = _prim "Vector_length": 'a vector_aos -> SeqIndex.int; + val subUnsafe = _prim "Vector_sub": 'a vector_aos * SeqIndex.int -> 'a; + val vector0 = _prim "Vector_vectorAos": unit -> 'a vector_aos; + end + end diff --git a/basis-library/schedulers/spork/ForkJoin.sml b/basis-library/schedulers/spork/ForkJoin.sml index 16623aa21..0d4fac268 100644 --- a/basis-library/schedulers/spork/ForkJoin.sml +++ b/basis-library/schedulers/spork/ForkJoin.sml @@ -40,6 +40,16 @@ struct ArrayExtra.Raw.unsafeToArray a end + fun alloc_aos n = + let + val a = ArrayAosExtra.Raw.alloc n + val _ = + if ArrayAosExtra.Raw.uninitIsNop a then () + else parfor 10000 (0, n) (fn i => ArrayAosExtra.Raw.unsafeUninit (a, i)) + in + ArrayAosExtra.Raw.unsafeToArray a + end + val maxForkDepthSoFar = Scheduler.maxForkDepthSoFar val numSpawnsSoFar = Scheduler.numSpawnsSoFar val numEagerSpawnsSoFar = Scheduler.numEagerSpawnsSoFar @@ -225,6 +235,7 @@ sig val parfor: int -> (int * int) -> (int -> unit) -> unit val alloc: int -> 'a array + val alloc_aos: int -> 'a MPL.ArrayAos.array val idleTimeSoFar: unit -> Time.time val workTimeSoFar: unit -> Time.time diff --git a/basis-library/schedulers/spork/sources.mlb b/basis-library/schedulers/spork/sources.mlb index 62d3e4520..4ef89895a 100644 --- a/basis-library/schedulers/spork/sources.mlb +++ b/basis-library/schedulers/spork/sources.mlb @@ -9,8 +9,12 @@ local in signature ARRAY_EXTRA signature ARRAY_SLICE_EXTRA + signature ARRAY_AOS_EXTRA + signature ARRAY_AOS_SLICE_EXTRA structure ArrayExtra = Array structure ArraySliceExtra = ArraySlice + structure ArrayAosExtra = ArrayAos + structure ArrayAosSliceExtra = ArrayAosSlice structure Primitive functor Int_ChooseFromInt diff --git a/examples/Makefile b/examples/Makefile index f7eea4d0f..523926195 100644 --- a/examples/Makefile +++ b/examples/Makefile @@ -13,6 +13,7 @@ PROGRAMS= \ ray \ tokens \ nn \ + nn-flat \ dedup \ nqueens \ reverb \ diff --git a/examples/src/nn-flat/NN.sml b/examples/src/nn-flat/NN.sml new file mode 100644 index 000000000..2db4957c7 --- /dev/null +++ b/examples/src/nn-flat/NN.sml @@ -0,0 +1,284 @@ +structure NN : +sig + type t + type 'a seq = 'a ArraySlice.slice + type 'a seq_flat = 'a MPL.ArrayAosSlice.slice + + type point = Geometry2D.point + + (* makeTree leafSize points *) + val makeTree : int -> point seq_flat -> t + + (* allNearestNeighbors grain quadtree *) + val allNearestNeighbors : int -> t -> int seq +end = +struct + + structure A = Array + structure AS = ArraySlice + structure AFS = MPL.ArrayAosSlice + + type 'a seq = 'a ArraySlice.slice + type 'a seq_flat = 'a MPL.ArrayAosSlice.slice + structure G = Geometry2D + type point = G.point + + fun par4 (a, b, c, d) = + let + val ((ar, br), (cr, dr)) = + ForkJoin.par (fn _ => ForkJoin.par (a, b), + fn _ => ForkJoin.par (c, d)) + in + (ar, br, cr, dr) + end + + datatype tree = + Leaf of { anchor : point + , width : real + , vertices : int seq (* indices of original point seq *) + } + | Node of { anchor : point + , width : real + , count : int + , children : tree seq + } + + type t = tree * point seq_flat + + fun count t = + case t of + Leaf {vertices, ...} => AS.length vertices + | Node {count, ...} => count + + fun anchor t = + case t of + Leaf {anchor, ...} => anchor + | Node {anchor, ...} => anchor + + fun width t = + case t of + Leaf {width, ...} => width + | Node {width, ...} => width + + fun boxOf t = + case t of + Leaf {anchor=(x,y), width, ...} => (x, y, x+width, y+width) + | Node {anchor=(x,y), width, ...} => (x, y, x+width, y+width) + + fun indexApp grain f t = + let + fun downSweep offset t = + case t of + Leaf {vertices, ...} => + AS.appi (fn (i, v) => f (offset + i, v)) vertices + | Node {children, ...} => + let + fun q i = AS.sub (children, i) + fun qCount i = count (q i) + val offset0 = offset + val offset1 = offset0 + qCount 0 + val offset2 = offset1 + qCount 1 + val offset3 = offset2 + qCount 2 + in + if count t <= grain then + ( downSweep offset0 (q 0) + ; downSweep offset1 (q 1) + ; downSweep offset2 (q 2) + ; downSweep offset3 (q 3) + ) + else + ( par4 + ( fn _ => downSweep offset0 (q 0) + , fn _ => downSweep offset1 (q 1) + , fn _ => downSweep offset2 (q 2) + , fn _ => downSweep offset3 (q 3) + ) + ; () + ) + end + in + downSweep 0 t + end + + fun indexMap grain f t = + let + val result = ForkJoin.alloc (count t) + val _ = indexApp grain (fn (i, v) => A.update (result, i, f (i, v))) t + in + AS.full result + end + + fun flatten grain t = indexMap grain (fn (_, v) => v) t + + (* val lowerTime = ref Time.zeroTime + val upperTime = ref Time.zeroTime + fun addTm r t = + if Primitives.numberOfProcessors = 1 then + r := Time.+ (!r, t) + else () + fun clearAndReport r name = + (print (name ^ " " ^ Time.fmt 4 (!r) ^ "\n"); r := Time.zeroTime) *) + + (* Make a tree where all points are in the specified bounding box. *) + fun makeTreeBounded leafSize (verts : point seq_flat) (idx : int seq) ((xLeft, yBot) : G.point) width = + if AS.length idx <= leafSize then + Leaf { anchor = (xLeft, yBot) + , width = width + , vertices = idx + } + else let + val qw = width/2.0 (* quadrant width *) + val center = (xLeft + qw, yBot + qw) + + val ((sorted, offsets), tm) = Util.getTime (fn () => + CountingSort.sort idx (fn i => + G.quadrant center (SeqFlat.nth verts (Seq.nth idx i))) 4) + + (* val _ = + if AS.length idx >= 4 * leafSize then + addTm upperTime tm + else + addTm lowerTime tm *) + + fun quadrant i = + let + val start = AS.sub (offsets, i) + val len = AS.sub (offsets, i+1) - start + val childIdx = AS.subslice (sorted, start, SOME len) + val qAnchor = + case i of + 0 => (xLeft + qw, yBot + qw) + | 1 => (xLeft, yBot + qw) + | 2 => (xLeft, yBot) + | _ => (xLeft + qw, yBot) + in + makeTreeBounded leafSize verts childIdx qAnchor qw + end + + (* val children = Seq.tabulate (Perf.grain 1) quadrant 4 *) + val (a, b, c, d) = + if AS.length idx <= 100 then + (quadrant 0, quadrant 1, quadrant 2, quadrant 3) + else + par4 + ( fn _ => quadrant 0 + , fn _ => quadrant 1 + , fn _ => quadrant 2 + , fn _ => quadrant 3 ) + val children = AS.full (Array.fromList [a,b,c,d]) + in + Node { anchor = (xLeft, yBot) + , width = width + , count = AS.length idx + , children = children + } + end + + fun loop (lo, hi) b f = + if (lo >= hi) then b else loop (lo+1, hi) (f (b, lo)) f + + fun reduce grain f b (get, lo, hi) = + if hi - lo <= grain then + loop (lo, hi) b (fn (b, i) => f (b, get i)) + else let + val mid = lo + (hi-lo) div 2 + val (l,r) = ForkJoin.par + ( fn _ => reduce grain f b (get, lo, mid) + , fn _ => reduce grain f b (get, mid, hi) + ) + in + f (l, r) + end + + fun makeTree leafSize (verts : point seq_flat) = + if AFS.length verts = 0 then raise Fail "makeTree with 0 points" else + let + (* calculate the bounding box *) + fun maxPt ((x1,y1),(x2,y2)) = (Real.max (x1, x2), Real.max (y1, y2)) + fun minPt ((x1,y1),(x2,y2)) = (Real.min (x1, x2), Real.min (y1, y2)) + fun getPt i = AFS.sub (verts, i) + val (xLeft,yBot) = reduce 10000 minPt (Real.posInf, Real.posInf) (getPt, 0, AFS.length verts) + val (xRight,yTop) = reduce 10000 maxPt (Real.negInf, Real.negInf) (getPt, 0, AFS.length verts) + val width = Real.max (xRight-xLeft, yTop-yBot) + + val idx = Seq.tabulate (fn i => i) (AFS.length verts) + val result = makeTreeBounded leafSize verts idx (xLeft, yBot) width + in + (* clearAndReport upperTime "upper sort time"; *) + (* clearAndReport lowerTime "lower sort time"; *) + (result, verts) + end + + (* ======================================================================== *) + + fun constrain (x : real) (lo, hi) = + if x < lo then lo + else if x > hi then hi + else x + + fun distanceToBox (x,y) (xLeft, yBot, xRight, yTop) = + G.distance (x,y) (constrain x (xLeft, xRight), constrain y (yBot, yTop)) + + val dummyBest = (~1, Real.posInf) + + fun nearestNeighbor (t : tree, pts) (pi : int) = + let + fun pt i = SeqFlat.nth pts i + + val p = pt pi + + fun refineNearest (qi, (bestPt, bestDist)) = + if pi = qi then (bestPt, bestDist) else + let + val qDist = G.distance p (pt qi) + in + if qDist < bestDist + then (qi, qDist) + else (bestPt, bestDist) + end + + fun search (best as (_, bestDist : real)) t = + if distanceToBox p (boxOf t) > bestDist then best else + case t of + Leaf {vertices, ...} => + AS.foldl refineNearest best vertices + | Node {anchor=(x,y), width, children, ...} => + let + val qw = width/2.0 + val center = (x+qw, y+qw) + + (* search the quadrant that p is in first *) + val heuristicOrder = + case G.quadrant center p of + 0 => [0,1,2,3] + | 1 => [1,0,2,3] + | 2 => [2,1,3,0] + | _ => [3,0,2,1] + + fun child i = AS.sub (children, i) + fun refine (i, best) = search best (child i) + in + List.foldl refine best heuristicOrder + end + + val (best, _) = search dummyBest t + in + best + end + + fun allNearestNeighbors grain (t, pts) = + let + val n = SeqFlat.length pts + val idxs = flatten 10000 t + val nn = ForkJoin.alloc n + in + ForkJoin.parfor grain (0, n) (fn i => + let + val j = Seq.nth idxs i + in + A.update (nn, j, nearestNeighbor (t, pts) j) + end); + AS.full nn + end + +end diff --git a/examples/src/nn-flat/SeqFlat.sml b/examples/src/nn-flat/SeqFlat.sml new file mode 100644 index 000000000..6ac762c6b --- /dev/null +++ b/examples/src/nn-flat/SeqFlat.sml @@ -0,0 +1,24 @@ +structure SeqFlat = +struct + + structure AF = MPL.ArrayAos + structure AFS = MPL.ArrayAosSlice + + type 'a seq = 'a AFS.slice + type 'a t = 'a seq + + fun length (s : 'a seq) : int = + AFS.length s + + fun nth s i = + AFS.sub (s, i) + + fun tabulate f n : 'a seq = + let + val arr = ForkJoin.alloc_aos n + in + ForkJoin.parform (0, n) (fn i => AF.update (arr, i, f i)); + AFS.full arr + end + +end \ No newline at end of file diff --git a/examples/src/nn-flat/main.sml b/examples/src/nn-flat/main.sml new file mode 100644 index 000000000..860ccc79f --- /dev/null +++ b/examples/src/nn-flat/main.sml @@ -0,0 +1,120 @@ +structure CLA = CommandLineArgs + +val n = CLA.parseInt "N" 1000000 +val leafSize = CLA.parseInt "leafSize" 16 +val grain = CLA.parseInt "grain" 100 +val seed = CLA.parseInt "seed" 15210 + +val _ = print ("N " ^ Int.toString n ^ "\n") + +fun genReal i = + let + val x = Word64.fromInt (seed + i) + in + Real.fromInt (Word64.toInt (Word64.mod (Util.hash64 x, 0w1000000))) + / 1000000.0 + end + +fun genPoint i = (genReal (2*i), genReal (2*i + 1)) + + +val (input, tm) = Util.getTime (fn _ => SeqFlat.tabulate genPoint n) +val _ = print ("generated input in " ^ Time.fmt 4 tm ^ "s\n") + +val (tree, tm) = Util.getTime (fn _ => NN.makeTree leafSize input) +val _ = print ("built quadtree in " ^ Time.fmt 4 tm ^ "s\n") + +val (nbrs, tm) = Util.getTime (fn _ => NN.allNearestNeighbors grain tree) +val _ = print ("found all neighbors in " ^ Time.fmt 4 tm ^ "s\n") + +(* now input[nbrs[i]] is the closest point to input[i] *) + +(* ========================================================================== + * write image to output + * this only works if all input points are within [0,1) *) + +val filename = CLA.parseString "output" "" +val _ = + if filename <> "" then () + else ( print ("to see output, use -output and -resolution arguments\n" ^ + "for example: nn -N 10000 -output result.ppm -resolution 1000\n") + ; OS.Process.exit OS.Process.success + ) + +val t0 = Time.now () + +val resolution = CLA.parseInt "resolution" 1000 +val width = resolution +val height = resolution + +val image = Seq.tabulate (fn i => Color.white) (width * height) + +fun set (i, j) x = + if 0 <= i andalso i < height andalso + 0 <= j andalso j < width + then ArraySlice.update (image, i*width + j, x) + else () + +val r = Real.fromInt resolution +fun px x = Real.floor (x * r) +fun pos (x, y) = (resolution - px x - 1, px y) + +fun horizontalLine i (j0, j1) = + if j1 < j0 then horizontalLine i (j1, j0) + else Util.for (j0, j1) (fn j => set (i, j) Color.red) + +fun sign xx = + case Int.compare (xx, 0) of LESS => ~1 | EQUAL => 0 | GREATER => 1 + +(* Bresenham's line algorithm *) +fun line (x1, y1) (x2, y2) = + let + val w = x2 - x1 + val h = y2 - y1 + val dx1 = sign w + val dy1 = sign h + val (longest, shortest, dx2, dy2) = + if Int.abs w > Int.abs h then + (Int.abs w, Int.abs h, dx1, 0) + else + (Int.abs h, Int.abs w, 0, dy1) + + fun loop i numerator x y = + if i > longest then () else + let + val numerator = numerator + shortest; + in + set (x, y) Color.red; + if numerator >= longest then + loop (i+1) (numerator-longest) (x+dx1) (y+dy1) + else + loop (i+1) numerator (x+dx2) (y+dy2) + end + in + loop 0 (longest div 2) x1 y1 + end + +(* mark all nearest neighbors with straight red lines *) +val _ = ForkJoin.parfor 10000 (0, SeqFlat.length input) (fn i => + line (pos (SeqFlat.nth input i)) (pos (SeqFlat.nth input (Seq.nth nbrs i)))) + +(* mark input points as a pixel *) +val _ = + ForkJoin.parfor 10000 (0, SeqFlat.length input) (fn i => + let + val (x, y) = pos (SeqFlat.nth input i) + fun b spot = set spot Color.black + in + b (x-1, y); + b (x, y-1); + b (x, y); + b (x, y+1); + b (x+1, y) + end) + +val t1 = Time.now () +val _ = print ("generated image in " ^ Time.fmt 4 (Time.- (t1, t0)) ^ "s\n") + +val image = {width = width, height = height, data = image} +val (_, tm) = Util.getTime (fn _ => PPM.write filename image) +val _ = print ("wrote to " ^ filename ^ " in " ^ Time.fmt 4 tm ^ "s\n") diff --git a/examples/src/nn-flat/sources.mlb b/examples/src/nn-flat/sources.mlb new file mode 100644 index 000000000..7db3b0468 --- /dev/null +++ b/examples/src/nn-flat/sources.mlb @@ -0,0 +1,4 @@ +../../lib/sources.mlb +SeqFlat.sml +NN.sml +main.sml diff --git a/lib/mlton/basic/option.sig b/lib/mlton/basic/option.sig index d7d4a73f6..40c099714 100644 --- a/lib/mlton/basic/option.sig +++ b/lib/mlton/basic/option.sig @@ -21,4 +21,6 @@ signature OPTION = val map: 'a t * ('a -> 'b) -> 'b t val toString: ('a -> string) -> 'a t -> string val valOf: 'a t -> 'a + + val andThen: 'a t * ('a -> 'b t) -> 'b t end diff --git a/lib/mlton/basic/option.sml b/lib/mlton/basic/option.sml index e9f7d7407..cae2e4c7b 100644 --- a/lib/mlton/basic/option.sml +++ b/lib/mlton/basic/option.sml @@ -39,6 +39,11 @@ fun map (opt, f) = NONE => NONE | SOME x => SOME (f x) +fun andThen (opt, f) = + case opt of + NONE => NONE + | SOME x => f x + fun equals (o1, o2, eq) = case (o1, o2) of (NONE, NONE) => true diff --git a/mlton/Makefile b/mlton/Makefile index 0c8cdf240..0e6ebcd53 100644 --- a/mlton/Makefile +++ b/mlton/Makefile @@ -78,6 +78,7 @@ $(MLTON_OUTPUT): $(MLTON_SOURCES) "$(RUN_MLTON)" \ @MLton $(RUN_MLTON_RUNTIME_XARGS) $(RUN_MLTON_RUNTIME_ARGS) gc-summary -- \ $(RUN_MLTON_COMPILE_XARGS) -verbose 2 $(RUN_MLTON_COMPILE_ARGS) \ + -const 'Exn.keepHistory true' \ -target $(TARGET) -output $(MLTON_OUTPUT) \ $(MLTON_MLB) diff --git a/mlton/atoms/array-layout.sml b/mlton/atoms/array-layout.sml new file mode 100644 index 000000000..50e32e4a1 --- /dev/null +++ b/mlton/atoms/array-layout.sml @@ -0,0 +1,29 @@ +(* Memory layouts for arrays. *) +structure ArrayLayout :> +sig + datatype t = Default | Aos + + val equals: t * t -> bool + + (* confusing naming: layout here means to produce a Layout.t + * representation (for pretty printing) of the ArrayLayout.t value + *) + val layout: t -> Layout.t + + val toString: t -> string +end = +struct + datatype t = Default | Aos + + fun equals (Default, Default) = true + | equals (Aos, Aos) = true + | equals _ = false + + fun toString lay = + case lay of + Default => "Default" + | Aos => "Aos" + + fun layout lay = Layout.str (toString lay) + +end \ No newline at end of file diff --git a/mlton/atoms/hash-type.fun b/mlton/atoms/hash-type.fun index 4224c45c7..1d60ff151 100644 --- a/mlton/atoms/hash-type.fun +++ b/mlton/atoms/hash-type.fun @@ -145,7 +145,7 @@ fun ofConst c = | Null => cpointer | Real r => real (RealX.size r) | Word w => word (WordX.size w) - | WordVector v => vector (word (WordXVector.elementSize v)) + | WordVector v => vector ArrayLayout.Default (word (WordXVector.elementSize v)) end fun isUnit t = diff --git a/mlton/atoms/prim-tycons.fun b/mlton/atoms/prim-tycons.fun index c8961e8fa..4454ce0f0 100644 --- a/mlton/atoms/prim-tycons.fun +++ b/mlton/atoms/prim-tycons.fun @@ -19,7 +19,13 @@ type tycon = t local fun make s = (s, fromString s) in - val array = make "array" + val arrayDefault = make "array" + val arrayAos = make "array_aos" + val array = fn lay => + case lay of + ArrayLayout.Default => arrayDefault + | ArrayLayout.Aos => arrayAos + val arrow = make "arrow" val bool = make "bool" val cpointer = make "cpointer" @@ -29,7 +35,23 @@ in val reff = make "ref" val thread = make "thread" val tuple = make "tuple" - val vector = make "vector" + + fun vector lay = + let + val name = + case lay of + ArrayLayout.Default => "vector" + | ArrayLayout.Aos => "vector_aos" + in + make name + end + val vectorDefault = make "vector" + val vectorAos = make "vector_aos" + val vector = fn lay => + case lay of + ArrayLayout.Default => vectorDefault + | ArrayLayout.Aos => vectorAos + val weak = make "weak" end @@ -102,7 +124,8 @@ in end val prims = - List.map ([(array, Arity 1, Always), + List.map ([(array ArrayLayout.Default, Arity 1, Always), + (array ArrayLayout.Aos, Arity 1, Always), (arrow, Arity 2, Never), (bool, Arity 0, Sometimes), (cpointer, Arity 0, Always), @@ -112,7 +135,8 @@ val prims = (reff, Arity 1, Always), (thread, Arity 0, Never), (tuple, Nary, Sometimes), - (vector, Arity 1, Sometimes), + (vector ArrayLayout.Default, Arity 1, Sometimes), + (vector ArrayLayout.Aos, Arity 1, Sometimes), (weak, Arity 1, Never)], fn ((name, tycon), kind, admitsEquality) => {admitsEquality = admitsEquality, @@ -121,7 +145,7 @@ val prims = tycon = tycon}) @ primChars @ primInts @ primReals @ primWords -val array = #2 array +val array = #2 o array val arrow = #2 arrow val bool = #2 bool val cpointer = #2 cpointer @@ -131,7 +155,7 @@ val list = #2 list val reff = #2 reff val thread = #2 thread val tuple = #2 tuple -val vector = #2 vector +val vector = #2 o vector val weak = #2 weak val defaultChar = fn () => @@ -164,6 +188,28 @@ val isCPointer = fn c => equals (c, cpointer) val isIntX = fn c => equals (c, intInf) orelse isIntX c val deIntX = fn c => if equals (c, intInf) then NONE else SOME (deIntX c) +val isArray = fn c => + equals (c, array ArrayLayout.Default) + orelse equals (c, array ArrayLayout.Aos) + +val isVector = fn c => + equals (c, vector ArrayLayout.Default) + orelse equals (c, vector ArrayLayout.Aos) + +fun deArrayLayout c = + if equals (c, array ArrayLayout.Default) + then ArrayLayout.Default + else if equals (c, array ArrayLayout.Aos) + then ArrayLayout.Aos + else Error.bug "PrimTycons.deArrayLayout" + +fun deVectorLayout c = + if equals (c, vector ArrayLayout.Default) + then ArrayLayout.Default + else if equals (c, vector ArrayLayout.Aos) + then ArrayLayout.Aos + else Error.bug "PrimTycons.deVectorLayout" + local local open Layout @@ -216,7 +262,7 @@ in else (mayAlign (Layout.separateLeft (Vector.toListMap (args, maybe TupleElem), "* ")), ({isChar = false}, Tuple)) - else if equals (c, vector) + else if equals (c, vector ArrayLayout.Default) then if #isChar (#1 (#2 (Vector.first args))) then LayoutPretty.simple (str "string") else normal (layoutPretty c, args, {isChar = false}) diff --git a/mlton/atoms/prim-tycons.sig b/mlton/atoms/prim-tycons.sig index 83deb82cb..1d14f0d4e 100644 --- a/mlton/atoms/prim-tycons.sig +++ b/mlton/atoms/prim-tycons.sig @@ -34,11 +34,12 @@ signature PRIM_TYCONS = type tycon - val array: tycon + val array: ArrayLayout.t -> tycon val arrow: tycon val bool: tycon val char: CharSize.t -> tycon val cpointer: tycon + val deArrayLayout: tycon -> ArrayLayout.t val deCharX: tycon -> CharSize.t val defaultChar: unit -> tycon val defaultInt: unit -> tycon @@ -46,11 +47,14 @@ signature PRIM_TYCONS = val defaultWord: unit -> tycon val deIntX: tycon -> IntSize.t option val deRealX: tycon -> RealSize.t + val deVectorLayout: tycon -> ArrayLayout.t val deWordX: tycon -> WordSize.t val exn: tycon val int: IntSize.t -> tycon val ints: (tycon * IntSize.t) vector val intInf: tycon + val isArray: tycon -> bool + val isVector: tycon -> bool val isBool: tycon -> bool val isCharX: tycon -> bool val isCPointer: tycon -> bool @@ -77,7 +81,7 @@ signature PRIM_TYCONS = val reff: tycon val thread: tycon val tuple: tycon - val vector: tycon + val vector: ArrayLayout.t -> tycon val weak: tycon val word: WordSize.t -> tycon val words: (tycon * WordSize.t) vector diff --git a/mlton/atoms/prim.fun b/mlton/atoms/prim.fun index 8e5288965..8e9e99b55 100644 --- a/mlton/atoms/prim.fun +++ b/mlton/atoms/prim.fun @@ -34,8 +34,8 @@ structure Kind = end datatype 'a t = - Array_alloc of {raw: bool} (* to rssa (as runtime C fn) *) - | Array_array (* to ssa2 *) + Array_alloc of {raw: bool, layout: ArrayLayout.t} (* to rssa (as runtime C fn) *) + | Array_array of ArrayLayout.t (* to ssa2 *) | Array_cas of CType.t option (* codegen *) | Array_copyArray (* to rssa (as runtime C fn) *) | Array_copyVector (* to rssa (as runtime C fn) *) @@ -166,7 +166,7 @@ datatype 'a t = | TopLevel_setSuffix (* implement suffix *) | Vector_length (* to ssa2 *) | Vector_sub (* to ssa2 *) - | Vector_vector (* to ssa2 *) + | Vector_vector of ArrayLayout.t (* to ssa2 *) | Weak_canGet (* to rssa (as runtime C fn) *) | Weak_get (* to rssa (as runtime C fn) *) | Weak_new (* to rssa (as runtime C fn) *) @@ -227,8 +227,24 @@ fun toString (n: 'a t): string = fun cpointerSet (ty, s) = concat ["CPointer_set", ty, s] in case n of - Array_alloc {raw} => if raw then "Array_allocRaw" else "Array_alloc" - | Array_array => "Array_array" + Array_alloc {raw, layout} => + let + val name = "Array_alloc" + val name = if raw then name ^ "Raw" else name + val name = + case layout of + ArrayLayout.Default => name + | ArrayLayout.Aos => name ^ "Aos" + in + (* Array_alloc + * Array_allocRaw + * Array_allocAos + * Array_allocRawAos + *) + name + end + | Array_array ArrayLayout.Default => "Array_array" + | Array_array ArrayLayout.Aos => "Array_arrayAos" | Array_cas NONE => "Array_cas" | Array_cas (SOME ctype) => concat ["Array", CType.name ctype, "_cas"] | Array_copyArray => "Array_copyArray" @@ -350,7 +366,8 @@ fun toString (n: 'a t): string = | TopLevel_setSuffix => "TopLevel_setSuffix" | Vector_length => "Vector_length" | Vector_sub => "Vector_sub" - | Vector_vector => "Vector_vector" + | Vector_vector ArrayLayout.Default => "Vector_vector" + | Vector_vector ArrayLayout.Aos => "Vector_vectorAos" | Weak_canGet => "Weak_canGet" | Weak_get => "Weak_get" | Weak_new => "Weak_new" @@ -397,8 +414,11 @@ fun layoutFull (p, layoutX) = | p => layout p val equals: 'a t * 'a t -> bool = - fn (Array_alloc {raw = r}, Array_alloc {raw = r'}) => Bool.equals (r, r') - | (Array_array, Array_array) => true + fn (p1, p2) => + case (p1, p2) of + (Array_alloc {raw = r, layout = l}, Array_alloc {raw = r', layout = l'}) => + Bool.equals (r, r') andalso ArrayLayout.equals (l, l') + | (Array_array l1, Array_array l2) => ArrayLayout.equals (l1, l2) | (Array_cas NONE, Array_cas NONE) => true | (Array_cas (SOME ctype1), Array_cas (SOME ctype2)) => CType.equals (ctype1, ctype2) | (Array_copyArray, Array_copyArray) => true @@ -518,7 +538,7 @@ val equals: 'a t * 'a t -> bool = | (TopLevel_setSuffix, TopLevel_setSuffix) => true | (Vector_length, Vector_length) => true | (Vector_sub, Vector_sub) => true - | (Vector_vector, Vector_vector) => true + | (Vector_vector l1, Vector_vector l2) => ArrayLayout.equals (l1, l2) | (Weak_canGet, Weak_canGet) => true | (Weak_get, Weak_get) => true | (Weak_new, Weak_new) => true @@ -583,8 +603,8 @@ val equals: 'a t * 'a t -> bool = val map: 'a t * ('a -> 'b) -> 'b t = fn (p, f) => case p of - Array_alloc {raw} => Array_alloc {raw = raw} - | Array_array => Array_array + Array_alloc {raw, layout} => Array_alloc {raw = raw, layout = layout} + | Array_array l => Array_array l | Array_cas cty => Array_cas cty | Array_copyArray => Array_copyArray | Array_copyVector => Array_copyVector @@ -696,7 +716,7 @@ val map: 'a t * ('a -> 'b) -> 'b t = | TopLevel_setSuffix => TopLevel_setSuffix | Vector_length => Vector_length | Vector_sub => Vector_sub - | Vector_vector => Vector_vector + | Vector_vector l => Vector_vector l | Weak_canGet => Weak_canGet | Weak_get => Weak_get | Weak_new => Weak_new @@ -797,7 +817,7 @@ val kind: 'a t -> Kind.t = in case p of Array_alloc _ => Moveable - | Array_array => Moveable + | Array_array _ => Moveable | Array_cas _ => SideEffect | Array_copyArray => SideEffect | Array_copyVector => SideEffect @@ -912,7 +932,7 @@ val kind: 'a t -> Kind.t = | TopLevel_setSuffix => SideEffect | Vector_length => Functional | Vector_sub => Functional - | Vector_vector => Functional + | Vector_vector _ => Functional | Weak_canGet => DependsOnState | Weak_get => DependsOnState | Weak_new => Moveable @@ -1012,9 +1032,12 @@ local @ wordSigns (s, false) in val all: unit t list = - [Array_alloc {raw = false}, - Array_alloc {raw = true}, - Array_array, + [Array_alloc {raw = false, layout = ArrayLayout.Default}, + Array_alloc {raw = true, layout = ArrayLayout.Default}, + Array_alloc {raw = false, layout = ArrayLayout.Aos}, + Array_alloc {raw = true, layout = ArrayLayout.Aos}, + Array_array ArrayLayout.Default, + Array_array ArrayLayout.Aos, Array_cas NONE, Array_copyArray, Array_copyVector, @@ -1100,7 +1123,8 @@ in TopLevel_setSuffix, Vector_length, Vector_sub, - Vector_vector, + Vector_vector ArrayLayout.Default, + Vector_vector ArrayLayout.Aos, Weak_canGet, Weak_get, Weak_new, @@ -1206,7 +1230,7 @@ fun 'a checkApp (prim: 'a t, {args: 'a vector, result: 'a, targs: 'a vector, - typeOps = {array: 'a -> 'a, + typeOps = {array: ArrayLayout.t -> 'a -> 'a, arrow: 'a * 'a -> 'a, tuple: 'a vector -> 'a, bool: 'a, @@ -1218,7 +1242,7 @@ fun 'a checkApp (prim: 'a t, reff: 'a -> 'a, thread: 'a, unit: 'a, - vector: 'a -> 'a, + vector: ArrayLayout.t -> 'a -> 'a, weak: 'a -> 'a, word: WordSize.t -> 'a}}): bool = let @@ -1333,30 +1357,54 @@ fun 'a checkApp (prim: 'a t, noTargs (fn () => (twoArgs (intInf, csize), intInf)) fun realTernary s = noTargs (fn () => (threeArgs (real s, real s, real s), real s)) - fun wordArray seqSize = array (word seqSize) + fun wordArray seqSize = array ArrayLayout.Default (word seqSize) fun wordShift s = noTargs (fn () => (twoArgs (word s, shiftArg), word s)) - val word8Vector = vector word8 - fun wordVector seqSize = vector (word seqSize) + val word8Vector = vector ArrayLayout.Default word8 + fun wordVector seqSize = vector ArrayLayout.Default (word seqSize) val string = word8Vector + + (* For ad-hoc polymorphism over arrays of different memory layouts. + * (Many array primitives are overloaded. Notably, CAS is not.) + *) + fun anyArrayLayout (f: ArrayLayout.t -> bool) : bool = + List.exists ([ArrayLayout.Default, ArrayLayout.Aos], f) in case prim of - Array_alloc _ => oneTarg (fn targ => (oneArg seqIndex, array targ)) - | Array_array => oneTarg (fn targ => (nArgs (Vector.map (args, fn _ => targ)), array targ)) + Array_alloc {layout, ...} => + oneTarg (fn targ => (oneArg seqIndex, array layout targ)) + | Array_array layout => + oneTarg (fn targ => (nArgs (Vector.map (args, fn _ => targ)), array layout targ)) | Array_cas _ => - oneTarg (fn t => (fourArgs (array t, seqIndex, t, t), t)) - | Array_copyArray => oneTarg (fn t => (fiveArgs (array t, seqIndex, array t, seqIndex, seqIndex), unit)) - | Array_copyVector => oneTarg (fn t => (fiveArgs (array t, seqIndex, vector t, seqIndex, seqIndex), unit)) - | Array_length => oneTarg (fn t => (oneArg (array t), seqIndex)) - | Array_sub _ => oneTarg (fn t => (twoArgs (array t, seqIndex), t)) - | Array_toArray => oneTarg (fn t => (oneArg (array t), array t)) - | Array_toVector => oneTarg (fn t => (oneArg (array t), vector t)) + (* only valid over default arrays, not flattened *) + oneTarg (fn t => (fourArgs (array ArrayLayout.Default t, seqIndex, t, t), t)) + | Array_copyArray => + anyArrayLayout (fn lay => oneTarg (fn t => + (fiveArgs (array lay t, seqIndex, array lay t, seqIndex, seqIndex), unit))) + | Array_copyVector => + anyArrayLayout (fn lay => oneTarg (fn t => + (fiveArgs (array lay t, seqIndex, vector lay t, seqIndex, seqIndex), unit))) + | Array_length => + anyArrayLayout (fn lay => oneTarg (fn t => + (oneArg (array lay t), seqIndex))) + | Array_sub _ => + anyArrayLayout (fn lay => oneTarg (fn t => + (twoArgs (array lay t, seqIndex), t))) + | Array_toArray => + anyArrayLayout (fn lay => oneTarg (fn t => + (oneArg (array lay t), array lay t))) + | Array_toVector => + anyArrayLayout (fn lay => oneTarg (fn t => + (oneArg (array lay t), vector lay t))) | Array_uninit => - oneTarg (fn t => (twoArgs (array t, seqIndex), unit)) + anyArrayLayout (fn lay => oneTarg (fn t => + (twoArgs (array lay t, seqIndex), unit))) | Array_uninitIsNop => - oneTarg (fn t => (oneArg (array t), bool)) + anyArrayLayout (fn lay => oneTarg (fn t => + (oneArg (array lay t), bool))) | Array_update _ => - oneTarg (fn t => (threeArgs (array t, seqIndex, t), unit)) + anyArrayLayout (fn lay => oneTarg (fn t => + (threeArgs (array lay t, seqIndex, t), unit))) | CFunction f => noTargs (fn () => (nArgs (CFunction.args f), CFunction.return f)) | CPointer_add => @@ -1411,7 +1459,7 @@ fun 'a checkApp (prim: 'a t, | IntInf_toString => noTargs (fn () => (threeArgs (intInf, word32, csize), string)) | IntInf_toVector => - noTargs (fn () => (oneArg intInf, vector bigIntInfWord)) + noTargs (fn () => (oneArg intInf, vector ArrayLayout.Default bigIntInfWord)) | IntInf_toWord => noTargs (fn () => (oneArg intInf, smallIntInfWord)) | IntInf_xorb => intInfBinary () | MLton_bogus => oneTarg (fn t => (noArgs, t)) @@ -1493,9 +1541,14 @@ fun 'a checkApp (prim: 'a t, noTargs (fn () => (oneArg (arrow (unit, unit)), unit)) | String_toWord8Vector => noTargs (fn () => (oneArg string, word8Vector)) - | Vector_length => oneTarg (fn t => (oneArg (vector t), seqIndex)) - | Vector_sub => oneTarg (fn t => (twoArgs (vector t, seqIndex), t)) - | Vector_vector => oneTarg (fn targ => (nArgs (Vector.map (args, fn _ => targ)), vector targ)) + | Vector_length => + anyArrayLayout (fn lay => oneTarg (fn t => + (oneArg (vector lay t), seqIndex))) + | Vector_sub => + anyArrayLayout (fn lay => oneTarg (fn t => + (twoArgs (vector lay t, seqIndex), t))) + | Vector_vector layout => + oneTarg (fn targ => (nArgs (Vector.map (args, fn _ => targ)), vector layout targ)) | Weak_canGet => oneTarg (fn t => (oneArg (weak t), bool)) | Weak_get => oneTarg (fn t => (oneArg (weak t), t)) | Weak_new => oneTarg (fn t => (oneArg t, weak t)) @@ -1508,7 +1561,7 @@ fun 'a checkApp (prim: 'a t, | Word8Vector_toString => noTargs (fn () => (oneArg (word8Vector), string)) | WordVector_toIntInf => - noTargs (fn () => (oneArg (vector bigIntInfWord), intInf)) + noTargs (fn () => (oneArg (vector ArrayLayout.Default bigIntInfWord), intInf)) | Word_add s => wordBinary s | Word_addCheckP (s, _) => wordBinaryP s | Word_andb s => wordBinary s @@ -1560,7 +1613,7 @@ fun ('a, 'b) extractTargs (prim: 'b t, in case prim of Array_alloc _ => one (deArray result) - | Array_array => one (deArray result) + | Array_array _ => one (deArray result) | Array_cas _ => one (deArray (arg 0)) | Array_copyArray => one (deArray (arg 0)) | Array_copyVector => one (deArray (arg 0)) @@ -1604,7 +1657,7 @@ fun ('a, 'b) extractTargs (prim: 'b t, | Ref_ref => one (deRef result) | Vector_length => one (deVector (arg 0)) | Vector_sub => one (deVector (arg 0)) - | Vector_vector => one (deVector result) + | Vector_vector _ => one (deVector result) | Weak_canGet => one (deWeak (arg 0)) | Weak_get => one result | Weak_new => one (arg 0) @@ -1850,7 +1903,7 @@ fun ('a, 'b) apply (p: 'a t, | (Real_lt _, [Real r1, Real r2]) => boolOpt (RealX.lt (r1, r2)) | (Real_qequal _, [Real r1, Real r2]) => boolOpt (RealX.qequal (r1, r2)) | (Real_castToWord _, [Real r]) => wordOpt (RealX.castToWord r) - | (Vector_vector, (Word w)::_) => + | (Vector_vector _, (Word w)::_) => (wordVector o WordXVector.fromList) ({elementSize = WordX.size w}, List.map (cs, Const.deWord)) diff --git a/mlton/atoms/prim.sig b/mlton/atoms/prim.sig index 99cb3fcd6..7a5e32797 100644 --- a/mlton/atoms/prim.sig +++ b/mlton/atoms/prim.sig @@ -25,8 +25,8 @@ signature PRIM = include PRIM_STRUCTS datatype 'a t = - Array_alloc of {raw: bool} (* to rssa (as runtime C fn) *) - | Array_array (* to ssa2 *) + Array_alloc of {raw: bool, layout: ArrayLayout.t} (* to rssa (as runtime C fn) *) + | Array_array of ArrayLayout.t (* to ssa2 *) | Array_cas of CType.t option (* codegen *) | Array_copyArray (* to rssa (as runtime C fn) *) | Array_copyVector (* to rssa (as runtime C fn) *) @@ -157,7 +157,7 @@ signature PRIM = | TopLevel_setSuffix (* implement suffix *) | Vector_length (* to ssa2 *) | Vector_sub (* to ssa2 *) - | Vector_vector (* to ssa2 *) + | Vector_vector of ArrayLayout.t (* to ssa2 *) | Weak_canGet (* to rssa (as runtime C fn) *) | Weak_get (* to rssa (as runtime C fn) *) | Weak_new (* to rssa (as runtime C fn) *) @@ -219,7 +219,7 @@ signature PRIM = val checkApp: 'a t * {args: 'a vector, result: 'a, targs: 'a vector, - typeOps: {array: 'a -> 'a, + typeOps: {array: ArrayLayout.t -> 'a -> 'a, arrow: 'a * 'a -> 'a, tuple: 'a vector -> 'a, bool: 'a, @@ -231,7 +231,7 @@ signature PRIM = reff: 'a -> 'a, thread: 'a, unit: 'a, - vector: 'a -> 'a, + vector: ArrayLayout.t -> 'a -> 'a, weak: 'a -> 'a, word: WordSize.t -> 'a}} -> bool val cpointerGet: CType.t -> 'a t diff --git a/mlton/atoms/sources.cm b/mlton/atoms/sources.cm index ef7635799..89340fa21 100644 --- a/mlton/atoms/sources.cm +++ b/mlton/atoms/sources.cm @@ -46,6 +46,8 @@ functor GenericScheme functor HashType functor TypeOps +structure ArrayLayout + is ../../lib/mlton/sources.cm @@ -109,6 +111,7 @@ ffi.sig ffi.fun cases.sig cases.fun +array-layout.sml prim.sig prim.fun exn-dec-elab.sig diff --git a/mlton/atoms/sources.mlb b/mlton/atoms/sources.mlb index 1c4195e82..864b8128b 100644 --- a/mlton/atoms/sources.mlb +++ b/mlton/atoms/sources.mlb @@ -51,6 +51,7 @@ local admits-equality.fun tycon-kind.sig tycon-kind.fun + array-layout.sml prim-tycons.sig prim-tycons.fun tycon.sig diff --git a/mlton/atoms/type-ops.fun b/mlton/atoms/type-ops.fun index ec2a929eb..9d4f28cff 100644 --- a/mlton/atoms/type-ops.fun +++ b/mlton/atoms/type-ops.fun @@ -36,15 +36,15 @@ end local fun unary tycon t = con (tycon, Vector.new1 t) in - val array = unary Tycon.array + fun array (lay: ArrayLayout.t) = unary (Tycon.array lay) val list = unary Tycon.list val reff = unary Tycon.reff - val vector = unary Tycon.vector + fun vector (lay: ArrayLayout.t) = unary (Tycon.vector lay) val weak = unary Tycon.weak end val word8 = word WordSize.word8 -val word8Vector = vector word8 +val word8Vector = vector ArrayLayout.Default word8 val word32 = word WordSize.word32 local @@ -68,11 +68,41 @@ fun deUnary tycon t = SOME t => t | NONE => Error.bug "TypeOps.deUnary" -val deArray = deUnary Tycon.array val deRef = deUnary Tycon.reff -val deVector = deUnary Tycon.vector val deWeak = deUnary Tycon.weak +fun deArrayOpt t = + case deConOpt t of + SOME (c, ts) => + if Tycon.isArray c then SOME (Vector.first ts) else NONE + | _ => NONE + +fun deVectorOpt t = + case deConOpt t of + SOME (c, ts) => + if Tycon.isVector c then SOME (Vector.first ts) else NONE + | _ => NONE + +val deArray = fn t => + case deArrayOpt t of + SOME x => x + | NONE => Error.bug "TypeOps.deArray" + +val deVector = fn t => + case deVectorOpt t of + SOME x => x + | NONE => Error.bug "TypeOps.deVector" + +fun deArrayLayout t = + case deConOpt t of + SOME (c, _) => Tycon.deArrayLayout c + | NONE => Error.bug "TypeOps.deArrayLayout" + +fun deVectorLayout t = + case deConOpt t of + SOME (c, _) => Tycon.deVectorLayout c + | NONE => Error.bug "TypeOps.deVectorLayout" + fun tuple ts = if 1 = Vector.length ts then Vector.first ts diff --git a/mlton/atoms/type-ops.sig b/mlton/atoms/type-ops.sig index 748e111a0..354524f74 100644 --- a/mlton/atoms/type-ops.sig +++ b/mlton/atoms/type-ops.sig @@ -28,12 +28,13 @@ signature TYPE_OPS = type wordSize type t - val array: t -> t + val array: ArrayLayout.t -> t -> t val arrow: t * t -> t val bool: t val con: tycon * t vector -> t val cpointer: t val deArray: t -> t + val deArrayLayout: t -> ArrayLayout.t val deArrow: t -> t * t val deArrowOpt: t -> (t * t) option val deConOpt: t -> (tycon * t vector) option @@ -41,6 +42,7 @@ signature TYPE_OPS = val deTuple: t -> t vector val deTupleOpt: t -> t vector option val deVector: t -> t + val deVectorLayout: t -> ArrayLayout.t val deWeak: t -> t val exn: t val intInf: t @@ -52,7 +54,7 @@ signature TYPE_OPS = val tuple: t vector -> t val unit: t val unitRef: t - val vector: t -> t + val vector: ArrayLayout.t -> t -> t val weak: t -> t val word: wordSize -> t val word8: t diff --git a/mlton/backend/packed-representation.fun b/mlton/backend/packed-representation.fun index 7976678eb..aa6b965e5 100644 --- a/mlton/backend/packed-representation.fun +++ b/mlton/backend/packed-representation.fun @@ -2551,7 +2551,7 @@ fun compute (program as Ssa2.Program.T {datatypes, ...}) = in r end - | ObjectCon.Sequence => + | ObjectCon.Sequence _ => let val hasIdentity = Prod.someIsMutable args val args = Prod.dest args @@ -2700,7 +2700,7 @@ fun compute (program as Ssa2.Program.T {datatypes, ...}) = | Word s => nonObjptr (Type.word s) end)) val () = typeRepRef := typeRep - val _ = typeRep (S.Type.vector1 (S.Type.word WordSize.byte)) + val _ = typeRep (S.Type.vector1 ArrayLayout.Default (S.Type.word WordSize.byte)) (* Establish dependence between constructor argument type representations * and tycon representations. *) @@ -2804,7 +2804,7 @@ fun compute (program as Ssa2.Program.T {datatypes, ...}) = ConRep.ShiftAndTag {selects, ...} => (selects, NONE) | ConRep.Tuple tr => (TupleRep.selects tr, NONE) | _ => Error.bug "PackedRepresentation.getSelects: Con,non-select") - | Sequence => + | Sequence _ => (case sequenceRep objectTy of tr as TupleRep.Indirect pr => (TupleRep.selects tr, diff --git a/mlton/backend/ssa2-to-rssa.fun b/mlton/backend/ssa2-to-rssa.fun index a2827a5dc..f4dcb0e34 100644 --- a/mlton/backend/ssa2-to-rssa.fun +++ b/mlton/backend/ssa2-to-rssa.fun @@ -1475,7 +1475,7 @@ fun convert (program as S.Program.T {functions, globals, main, ...}, return = SOME l})) end end) - | Prim.Array_alloc {raw} => + | Prim.Array_alloc {raw, ...} => let val allocOpt = fn () => let diff --git a/mlton/closure-convert/abstract-value.fun b/mlton/closure-convert/abstract-value.fun index 610f1ddee..ce5892863 100644 --- a/mlton/closure-convert/abstract-value.fun +++ b/mlton/closure-convert/abstract-value.fun @@ -157,12 +157,16 @@ structure LambdaNode: structure UnaryTycon = struct - datatype t = Array | Ref | Vector | Weak + datatype t = + Array of ArrayLayout.t + | Ref + | Vector of ArrayLayout.t + | Weak val toString = - fn Array => "Array" + fn Array lay => "Array(" ^ ArrayLayout.toString lay ^ ")" | Ref => "Ref" - | Vector => "Vector" + | Vector lay => "Vector(" ^ ArrayLayout.toString lay ^ ")" | Weak => "Weak" val equals: t * t -> bool = op = @@ -253,10 +257,14 @@ local end in if Tycon.equals (tycon, Tycon.reff) then mutable UnaryTycon.Ref - else if Tycon.equals (tycon, Tycon.array) - then mutable UnaryTycon.Array - else if Tycon.equals (tycon, Tycon.vector) - then mutable UnaryTycon.Vector + else if Tycon.equals (tycon, Tycon.array ArrayLayout.Default) + then mutable (UnaryTycon.Array ArrayLayout.Default) + else if Tycon.equals (tycon, Tycon.array ArrayLayout.Aos) + then mutable (UnaryTycon.Array ArrayLayout.Aos) + else if Tycon.equals (tycon, Tycon.vector ArrayLayout.Default) + then mutable (UnaryTycon.Vector ArrayLayout.Default) + else if Tycon.equals (tycon, Tycon.vector ArrayLayout.Aos) + then mutable (UnaryTycon.Vector ArrayLayout.Aos) else if Tycon.equals (tycon, Tycon.weak) then mutable UnaryTycon.Weak else if Tycon.equals (tycon, Tycon.tuple) @@ -359,12 +367,12 @@ val coerce = Trace.trace ("AbstractValue.coerce", structure Dest = struct datatype dest = - Array of t + Array of {elem: t, layout: ArrayLayout.t} | Lambdas of Lambdas.t | Ref of t | Tuple of t vector | Type of Type.t - | Vector of t + | Vector of {elem: t, layout: ArrayLayout.t} | Weak of t end @@ -372,9 +380,9 @@ fun dest v = case tree v of Type t => Dest.Type t | Unify (mt, v) => (case mt of - UnaryTycon.Array => Dest.Array v + UnaryTycon.Array lay => Dest.Array {elem=v, layout=lay} | UnaryTycon.Ref => Dest.Ref v - | UnaryTycon.Vector => Dest.Vector v + | UnaryTycon.Vector lay => Dest.Vector {elem=v, layout=lay} | UnaryTycon.Weak => Dest.Weak v) | Tuple vs => Dest.Tuple vs | Lambdas l => Dest.Lambdas (LambdaNode.toSet l) @@ -426,12 +434,14 @@ fun primApply {prim: Type.t Prim.t, args: t vector, resultTy: Type.t}: t = else Error.bug "AbstractValue.primApply.fiveArgs" in case prim of - Prim.Array_array => + Prim.Array_array lay => let val r = result () val _ = case dest r of - Array x => Vector.foreach (args, fn arg => coerce {from = arg, to = x}) + Array {elem = x, layout = lay'} => + (* SAM_NOTE: could do a sanity check here that lay = lay' *) + Vector.foreach (args, fn arg => coerce {from = arg, to = x}) | Type _ => () | _ => typeError () in @@ -442,14 +452,14 @@ fun primApply {prim: Type.t Prim.t, args: t vector, resultTy: Type.t}: t = val (a, _, x, y) = fourArgs () in (case dest a of - Array v => (unify (y, v); unify (x, v); v) + Array {elem=v, ...} => (unify (y, v); unify (x, v); v) | Type _ => result () | _ => typeError ()) end | Prim.Array_copyArray => let val (da, _, sa, _, _) = fiveArgs () in (case (dest da, dest sa) of - (Array dx, Array sx) => unify (dx, sx) + (Array {elem=dx, ...}, Array {elem=sx, ...}) => unify (dx, sx) | (Type _, Type _) => () | _ => typeError () ; result ()) @@ -457,7 +467,7 @@ fun primApply {prim: Type.t Prim.t, args: t vector, resultTy: Type.t}: t = | Prim.Array_copyVector => let val (da, _, sa, _, _) = fiveArgs () in (case (dest da, dest sa) of - (Array dx, Vector sx) => unify (dx, sx) + (Array {elem=dx, ...}, Vector {elem=sx, ...}) => unify (dx, sx) | (Type _, Type _) => () | _ => typeError () ; result ()) @@ -466,7 +476,7 @@ fun primApply {prim: Type.t Prim.t, args: t vector, resultTy: Type.t}: t = let val r = result () in (case (dest (oneArg ()), dest r) of (Type _, Type _) => () - | (Array x, Array y) => + | (Array {elem=x, ...}, Array {elem=y, ...}) => (* Can't do a coercion here because that would imply * walking over each element of the array and coercing it. *) @@ -478,7 +488,7 @@ fun primApply {prim: Type.t Prim.t, args: t vector, resultTy: Type.t}: t = let val r = result () in (case (dest (oneArg ()), dest r) of (Type _, Type _) => () - | (Array x, Vector y) => + | (Array {elem=x, ...}, Vector {elem=y, ...}) => (* Can't do a coercion here because that would imply * walking over each element of the array and coercing it. *) @@ -488,13 +498,13 @@ fun primApply {prim: Type.t Prim.t, args: t vector, resultTy: Type.t}: t = end | Prim.Array_sub _ => (case dest (#1 (twoArgs ())) of - Array x => x + Array {elem, ...} => elem | Type _ => result () | _ => typeError ()) | Prim.Array_update _ => let val (a, _, x) = threeArgs () in (case dest a of - Array x' => coerce {from = x, to = x'} (* unify (x, x') *) + Array {elem=x', ...} => coerce {from = x, to = x'} (* unify (x, x') *) | Type _ => () | _ => typeError ()) ; result () @@ -545,15 +555,17 @@ fun primApply {prim: Type.t Prim.t, args: t vector, resultTy: Type.t}: t = end | Prim.Vector_sub => (case dest (#1 (twoArgs ())) of - Vector x => x + Vector {elem, ...} => elem | Type _ => result () | _ => typeError ()) - | Prim.Vector_vector => + | Prim.Vector_vector lay => let val r = result () val _ = case dest r of - Vector x => Vector.foreach (args, fn arg => coerce {from = arg, to = x}) + Vector {elem = x, layout = lay'} => + (* SAM_NOTE: could do a sanity check here that lay = lay' *) + Vector.foreach (args, fn arg => coerce {from = arg, to = x}) | Type _ => () | _ => typeError () in diff --git a/mlton/closure-convert/abstract-value.sig b/mlton/closure-convert/abstract-value.sig index c56756d2b..9e7ee76ce 100644 --- a/mlton/closure-convert/abstract-value.sig +++ b/mlton/closure-convert/abstract-value.sig @@ -37,12 +37,12 @@ signature ABSTRACT_VALUE = type t datatype dest = - Array of t + Array of {elem: t, layout: ArrayLayout.t} | Lambdas of Lambdas.t | Ref of t | Tuple of t vector | Type of Sxml.Type.t (* type doesn't contain any arrows *) - | Vector of t + | Vector of {elem: t, layout: ArrayLayout.t} | Weak of t val addHandler: t * (Lambda.t -> unit) -> unit diff --git a/mlton/closure-convert/closure-convert.fun b/mlton/closure-convert/closure-convert.fun index 8b4b8a35b..d26aa6ede 100644 --- a/mlton/closure-convert/closure-convert.fun +++ b/mlton/closure-convert/closure-convert.fun @@ -568,13 +568,15 @@ fun closureConvert else Error.bug "ClosureConvert.convertType.unary: bogus application of unary tycon" val tycons = [(Tycon.arrow, fn _ => Error.bug "ClosureConvert.convertType.array"), - (Tycon.array, unary Type.array), + (Tycon.array ArrayLayout.Default, unary (Type.array ArrayLayout.Default)), + (Tycon.array ArrayLayout.Aos, unary (Type.array ArrayLayout.Aos)), (Tycon.cpointer, nullary Type.cpointer), (Tycon.intInf, nullary Type.intInf), (Tycon.reff, unary Type.reff), (Tycon.thread, nullary Type.thread), (Tycon.tuple, Type.tuple), - (Tycon.vector, unary Type.vector), + (Tycon.vector ArrayLayout.Default, unary (Type.vector ArrayLayout.Default)), + (Tycon.vector ArrayLayout.Aos, unary (Type.vector ArrayLayout.Aos)), (Tycon.weak, unary Type.weak)] @ Vector.toListMap (Tycon.reals, fn (t, s) => (t, nullary (Type.real s))) @ Vector.toListMap (Tycon.words, fn (t, s) => (t, nullary (Type.word s))) @@ -608,13 +610,13 @@ fun closureConvert let val t = case Value.dest v of - Value.Array v => Type.array (valueType v) + Value.Array {elem=v, layout} => Type.array layout (valueType v) | Value.Lambdas ls => #ty (lambdasInfo ls) | Value.Ref v => Type.reff (valueType v) | Value.Type t => convertType t | Value.Tuple vs => Type.tuple (Vector.map (vs, valueType)) - | Value.Vector v => Type.vector (valueType v) + | Value.Vector {elem=v, layout} => Type.vector layout (valueType v) | Value.Weak v => Type.weak (valueType v) in r := SOME t; t end @@ -1270,7 +1272,7 @@ fun closureConvert in simple (case prim of - Prim.Array_array => + Prim.Array_array _ => let val ys = Vector.map (args, varExpInfo) val v = Value.deArray v @@ -1371,7 +1373,7 @@ fun closureConvert v1 (coerce (convertVarInfo y, VarInfo.value y, v))) end - | Prim.Vector_vector => + | Prim.Vector_vector _ => let val ys = Vector.map (args, varExpInfo) val v = Value.deVector v diff --git a/mlton/closure-convert/globalize.fun b/mlton/closure-convert/globalize.fun index 9af62793e..256aa0d58 100644 --- a/mlton/closure-convert/globalize.fun +++ b/mlton/closure-convert/globalize.fun @@ -27,9 +27,11 @@ fun globalize {program = Program.T {datatypes, body, ...}, Property.destGetSetOnce (Tycon.plist, Property.initConst false) fun makeBig tycon = set (tycon, true) val _ = (Vector.foreach (datatypes, makeBig o #tycon) - ; makeBig Tycon.array + ; makeBig (Tycon.array ArrayLayout.Default) + ; makeBig (Tycon.array ArrayLayout.Aos) ; makeBig Tycon.arrow - ; makeBig Tycon.vector) + ; makeBig (Tycon.vector ArrayLayout.Default) + ; makeBig (Tycon.vector ArrayLayout.Aos)) in val tyconIsBig = get val destroyTycon = destroy diff --git a/mlton/defunctorize/defunctorize.fun b/mlton/defunctorize/defunctorize.fun index 5764dbf2d..15373b415 100644 --- a/mlton/defunctorize/defunctorize.fun +++ b/mlton/defunctorize/defunctorize.fun @@ -1091,7 +1091,7 @@ fun defunctorize (CoreML.Program.T {decs}) = var = var ()} | Vector es => Xexp.primApp {args = Vector.map (es, #1 o loopExp), - prim = Prim.Vector_vector, + prim = Prim.Vector_vector (Xtype.deVectorLayout ty), targs = Vector.new1 (Xtype.deVector ty), ty = ty} in diff --git a/mlton/elaborate/elaborate-core.fun b/mlton/elaborate/elaborate-core.fun index d7d0b7bca..bbf213f7e 100644 --- a/mlton/elaborate/elaborate-core.fun +++ b/mlton/elaborate/elaborate-core.fun @@ -532,7 +532,7 @@ fun lookConst {default: string option, expandedTy, name, region}: unit -> Const. then realConstFromString (Tycon.deRealX c) else if Tycon.isWordX c then wordConstFromString (Tycon.deWordX c) - else if Tycon.equals (c, Tycon.vector) + else if Tycon.equals (c, Tycon.vector ArrayLayout.Default) andalso 1 = Vector.length ts andalso (case Type.deConOpt (Vector.first ts) of NONE => false @@ -583,8 +583,8 @@ fun unifySeq (seqTy, seqStr, in fun unifyList (trs: (Type.t * Region.t) vector, unify): Type.t = unifySeq (Type.list, "list", trs, unify) -fun unifyVector (trs: (Type.t * Region.t) vector, unify): Type.t = - unifySeq (Type.vector, "vector", trs, unify) +fun unifyVector (lay: ArrayLayout.t) (trs: (Type.t * Region.t) vector, unify): Type.t = + unifySeq (Type.vector lay, "vector", trs, unify) end val elabPatInfo = Trace.info "ElaborateCore.elabPat" @@ -1006,7 +1006,7 @@ val elaboratePat: val ps' = Vector.map (ps, loop) in Cpat.make (Cpat.Vector ps', - unifyVector + unifyVector ArrayLayout.Default (Vector.map2 (ps, ps', fn (p, p') => (Cpat.ty p', Apat.region p)), unify)) @@ -1116,7 +1116,11 @@ structure Type = {ctype = ctype, name = name, tycon = tycon}) val unary: Tycon.t list = - [Tycon.array, Tycon.reff, Tycon.vector] + [Tycon.array ArrayLayout.Default, + Tycon.array ArrayLayout.Aos, + Tycon.reff, + Tycon.vector ArrayLayout.Default, + Tycon.vector ArrayLayout.Aos] fun toNullaryCType (t: t): {ctype: CType.t, name: string} option = case deConOpt t of @@ -3845,7 +3849,7 @@ fun elaborateDec (d, {env = E, nest}) = val es' = Vector.map (es, elab) in Cexp.make (Cexp.Vector es', - unifyVector + unifyVector ArrayLayout.Default (Vector.map2 (es, es', fn (e, e') => (Cexp.ty e', Aexp.region e)), unify)) diff --git a/mlton/elaborate/type-env.fun b/mlton/elaborate/type-env.fun index 702041cd7..5953870ca 100644 --- a/mlton/elaborate/type-env.fun +++ b/mlton/elaborate/type-env.fun @@ -1018,7 +1018,7 @@ structure Type = val unresolvedWord = make Word end - fun unresolvedString () = vector (unresolvedChar ()) + fun unresolvedString () = vector ArrayLayout.Default (unresolvedChar ()) val traceCanUnify = Trace.trace2 diff --git a/mlton/ssa/analyze2.fun b/mlton/ssa/analyze2.fun index dea9e994a..e99ca9391 100644 --- a/mlton/ssa/analyze2.fun +++ b/mlton/ssa/analyze2.fun @@ -257,7 +257,7 @@ fun 'a analyze let val args = case Type.dest ty of - Type.Object {args = ts, con = ObjectCon.Sequence} => + Type.Object {args = ts, con = ObjectCon.Sequence _} => Vector.map (args, fn args => Prod.make diff --git a/mlton/ssa/constant-propagation.fun b/mlton/ssa/constant-propagation.fun index afde18619..28bc8e5b4 100644 --- a/mlton/ssa/constant-propagation.fun +++ b/mlton/ssa/constant-propagation.fun @@ -159,21 +159,22 @@ structure Value = structure ArrayInit = struct datatype 'a t = - Alloc of {raw: bool} - | Array of {args: 'a vector} + Alloc of {raw: bool, layout: ArrayLayout.t} + | Array of {args: 'a vector, layout: ArrayLayout.t} fun layout layoutA ai = let open Layout in case ai of - Alloc {raw} => + Alloc {raw, layout=lay} => seq [str "Alloc ", - record [("raw", Bool.layout raw)]] - | Array {args} => + record [("raw", Bool.layout raw), + ("layout", ArrayLayout.layout lay)]] + | Array {args, layout=lay} => seq [str "Array ", - record [("args", - Vector.layout layoutA args)]] + record [("args", Vector.layout layoutA args), + ("layout", ArrayLayout.layout lay)]] end end structure ArrayBirth = @@ -428,7 +429,7 @@ structure Value = | Ref of {arg: t, birth: t RefBirth.t} | Tuple of t vector - | Vector of {sequence: t Sequence.t} + | Vector of {sequence: t Sequence.t, layout: ArrayLayout.t} | Weak of t local @@ -463,9 +464,10 @@ structure Value = tuple [layout arg, RefBirth.layout layout birth]] | Tuple vs => Vector.layout layout vs - | Vector {sequence, ...} => + | Vector {sequence, layout=lay} => seq [str "vector ", - tuple [Sequence.layout layout sequence]] + tuple [Sequence.layout layout sequence, + ArrayLayout.layout lay]] | Weak v => seq [str "weak ", layout v] end in @@ -582,11 +584,11 @@ structure Value = fun loop (t: Type.t): t = new (case Type.dest t of - Type.Array t => Array {birth = arrayBirth (), sequence = sequence loop t} + Type.Array {elem=t, layout=lay} => Array {birth = arrayBirth lay, sequence = sequence loop t} | Type.Datatype _ => Datatype (data ()) | Type.Ref t => Ref {arg = loop t, birth = refBirth ()} | Type.Tuple ts => Tuple (Vector.map (ts, loop)) - | Type.Vector t => Vector {sequence = sequence loop t} + | Type.Vector {elem=t, layout=lay} => Vector {sequence = sequence loop t, layout = lay} | Type.Weak t => Weak (loop t) | _ => Const (const ()), t) @@ -595,7 +597,7 @@ structure Value = in val mkFromType = fn {clone, coerce, unify} => - make {arrayBirth = ArrayBirth.undefined, + make {arrayBirth = (fn lay => ArrayBirth.undefined ()), const = Const.undefined, data = Data.undefined, refBirth = RefBirth.undefined, @@ -606,7 +608,7 @@ structure Value = undefined = undefined, unify = unify}} val unknown = - make {arrayBirth = ArrayBirth.unknown, + make {arrayBirth = (fn lay => ArrayBirth.unknown ()), const = Const.unknown, data = Data.unknown, refBirth = RefBirth.unknown, @@ -698,8 +700,8 @@ structure Value = ; unify (argFrom, argTo)) | (Tuple froms, Tuple tos) => coerces {froms = froms, tos = tos} - | (Vector {sequence = sequenceFrom}, - Vector {sequence = sequenceTo}) => + | (Vector {sequence = sequenceFrom, ...}, + Vector {sequence = sequenceTo, ...}) => sequenceCoerce {from = sequenceFrom, to = sequenceTo} | (Weak from, Weak to) => unify (from, to) | (_, _) => error () @@ -761,8 +763,8 @@ structure Value = (RefBirth.unify (birth1, birth2) ; unify (arg1, arg2)) | (Tuple vs1, Tuple vs2) => Vector.foreach2 (vs1, vs2, unify) - | (Vector {sequence = sequence1}, - Vector {sequence = sequence2}) => + | (Vector {sequence = sequence1, ...}, + Vector {sequence = sequence2, ...}) => sequenceUnify (sequence1, sequence2) | (Weak v1, Weak v2) => unify (v1, v2) | _ => error () @@ -793,7 +795,7 @@ structure Value = | Datatype d => Data.makeUnknown d | Ref {arg, ...} => makeUnknown arg | Tuple vs => Vector.foreach (vs, makeUnknown) - | Vector {sequence} => Sequence.makeUnknown makeUnknown sequence + | Vector {sequence, ...} => Sequence.makeUnknown makeUnknown sequence | Weak v => makeUnknown v fun sideEffect (v: t): unit = @@ -887,11 +889,11 @@ structure Value = S.Const.WordVector v => let val eltTy = Type.word (WordXVector.elementSize v) - val vecTy = Type.vector eltTy + val vecTy = Type.vector ArrayLayout.Default eltTy val args = WordXVector.toVectorMap (v, const o S.Const.word) val seq = Sequence.make (args, eltTy) in - new (Vector {sequence = seq}, vecTy) + new (Vector {sequence = seq, layout = ArrayLayout.Default}, vecTy) end | _ => const c @@ -984,21 +986,21 @@ structure Value = (birth, fn ab => if isSmallType ty then (case ab of - ArrayInit.Alloc {raw} => + ArrayInit.Alloc {raw, layout} => (case global length of NONE => NONE | SOME (length, _) => SOME (Exp.PrimApp {args = Vector.new1 length, - prim = Prim.Array_alloc {raw = raw}, + prim = Prim.Array_alloc {raw = raw, layout = layout}, targs = Vector.new1 eltTy})) - | ArrayInit.Array {args} => + | ArrayInit.Array {args, layout} => (case globals args of NONE => NONE | SOME args => SOME (Exp.PrimApp {args = Vector.map (args, #1), - prim = Prim.Array_array, + prim = Prim.Array_array layout, targs = Vector.new1 eltTy}))) else NONE) end @@ -1046,11 +1048,23 @@ structure Value = NONE => No | SOME xts => yes (Exp.Tuple (Vector.map (xts, #1)))) - | Vector {sequence} => + | Vector {sequence, layout=lay} => (case Sequence.Elts.getElts (Sequence.elts sequence) of NONE => No | SOME elts => let + (* presumably, lay should be the same as Type.deVectorLayout ty ? + * sanity check... *) + val () = + if not (ArrayLayout.equals (lay, Type.deVectorLayout ty)) + then + Error.bug + (concat ["ConstantPropagation.Value.global: vector layout mismatch: ", + Layout.toString (ArrayLayout.layout lay), + " vs ", + Layout.toString (ArrayLayout.layout (Type.deVectorLayout ty))]) + else () + val eltTy = Type.deVector ty fun vector () = case globals elts of @@ -1058,7 +1072,7 @@ structure Value = | SOME args => yes (Exp.PrimApp {args = Vector.map (args, #1), - prim = Prim.Vector_vector, + prim = Prim.Vector_vector lay, targs = Vector.new1 eltTy}) fun wordxvector elementSize = Exn.withEscape @@ -1096,11 +1110,15 @@ structure Value = | _ => Error.bug "ConstantPropagation.Value.arrayToArray" fun arrayToVector (v: t): t = - case value v of - Array {sequence, ...} => - new (Vector {sequence = sequence}, - Type.vector (Type.deArray (ty v))) - | _ => Error.bug "ConstantPropagation.Value.arrayToVector" + let + val lay = Type.deArrayLayout (ty v) + in + case value v of + Array {sequence, ...} => + new (Vector {sequence = sequence, layout = lay}, + Type.vector lay (Type.deArray (ty v))) + | _ => Error.bug "ConstantPropagation.Value.arrayToVector" + end end (* ------------------------------------------------- *) @@ -1343,17 +1361,17 @@ fun transform (program: Program.t): Program.t = end) in case prim of - Prim.Array_alloc {raw} => + Prim.Array_alloc {raw, layout} => let - val birth = bear (ArrayInit.Alloc {raw = raw}) + val birth = bear (ArrayInit.Alloc {raw = raw, layout = layout}) val sequence = Sequence.undefined (Type.deArray resultType) val _ = coerce {from = arg 0, to = Sequence.length sequence} in new (Array {birth = birth, sequence = sequence}, resultType) end - | Prim.Array_array => + | Prim.Array_array layout => let - val birth = bear (ArrayInit.Array {args = args}) + val birth = bear (ArrayInit.Array {args = args, layout = layout}) val sequence = Sequence.make (args, Type.deArray resultType) in new (Array {birth = birth, sequence = sequence}, resultType) @@ -1394,11 +1412,12 @@ fun transform (program: Program.t): Program.t = end | Prim.Vector_length => vectorLength (arg 0) | Prim.Vector_sub => sequenceSub vectorSequence - | Prim.Vector_vector => + | Prim.Vector_vector layout => let val sequence = Sequence.make (args, Type.deVector resultType) in - new (Vector {sequence = sequence}, resultType) + new (Vector {sequence = sequence, layout = layout}, + resultType) end | Prim.Weak_get => weakArg (arg 0) | Prim.Weak_new => @@ -1609,13 +1628,13 @@ fun transform (program: Program.t): Program.t = Property.initRec (fn (t, dependsOn) => case Type.dest t of - Array t => dependsOn t + Array {elem=t, ...} => dependsOn t | Datatype tc => (ignore o Graph.addEdge) (graph, {from = n, to = tyconNode tc}) | Ref t => dependsOn t | Tuple ts => Vector.foreach (ts, dependsOn) - | Vector t => dependsOn t + | Vector {elem=t, ...} => dependsOn t | _ => ())) val () = Vector.foreach diff --git a/mlton/ssa/deep-flatten.fun b/mlton/ssa/deep-flatten.fun index 924014e7a..892fa36f2 100644 --- a/mlton/ssa/deep-flatten.fun +++ b/mlton/ssa/deep-flatten.fun @@ -411,7 +411,7 @@ structure Value = andalso Prod.allAreImmutable args andalso (case con of ObjectCon.Con _ => false - | ObjectCon.Sequence => false + | ObjectCon.Sequence _ => false | ObjectCon.Tuple => true) fun objectFields {args, con} = @@ -423,7 +423,7 @@ structure Value = if (case con of ObjectCon.Con _ => true | ObjectCon.Tuple => true - | ObjectCon.Sequence => false) + | ObjectCon.Sequence _=> false) then Vector.foreach (Prod.dest args, fn {elt, isMutable} => if isMutable then () @@ -622,7 +622,7 @@ fun transform2 (program as Program.T {datatypes, functions, globals, main}) = (conValue c, fn () => makeValue (doit ()))) | Tuple => doit () - | Sequence => doit () + | Sequence _ => doit () end | Weak t => (case makeTypeValue t of diff --git a/mlton/ssa/duplicate-globals.fun b/mlton/ssa/duplicate-globals.fun index e9427f8d5..e4a469e9c 100644 --- a/mlton/ssa/duplicate-globals.fun +++ b/mlton/ssa/duplicate-globals.fun @@ -35,7 +35,7 @@ struct (case prim of (* we might want to duplicate this due to the targ *) Prim.MLton_bogus => duplicatable var - | Prim.Vector_vector => duplicatable var + | Prim.Vector_vector _ => duplicatable var | _ => ()) | _ => () in diff --git a/mlton/ssa/flatten-into-sequences.fun b/mlton/ssa/flatten-into-sequences.fun new file mode 100644 index 000000000..d4b468625 --- /dev/null +++ b/mlton/ssa/flatten-into-sequences.fun @@ -0,0 +1,641 @@ +functor FlattenIntoSequences(S: SSA2_TRANSFORM_STRUCTS): SSA2_TRANSFORM = +struct + open S + + + (* ======================================================================== + * just some quick utilities + *) + + fun vector_iterate_prefixes (f: 'b * 'a -> 'b) (b: 'b) (v: 'a Vector.t) = + let + fun loop prev_accs acc i = + if i >= Vector.length v then (Vector.fromListRev prev_accs, acc) + else loop (acc :: prev_accs) (f (acc, Vector.sub (v, i))) (i + 1) + in + loop [] b 0 + end + + + (* ======================================================================== + * type rewrites + *) + + + (* returns NONE if flattening has no change *) + fun try_flatten_tuples (x as {elt: Type.t, isMutable: bool}) : + {elt: Type.t, isMutable: bool} vector option = + case Type.dest elt of + Type.Object {con = ObjectCon.Tuple, args} => + if Prod.someIsMutable args then + NONE + else + let + val flattened = Vector.concatV (Vector.map (Prod.dest args, fn x => + case try_flatten_tuples x of + NONE => Vector.new1 x + | SOME elements => elements)) + val flat_with_mutability_propagated = + Vector.map (flattened, fn {elt, isMutable = isMutable'} => + {elt = elt, isMutable = isMutable orelse isMutable'}) + in + SOME flat_with_mutability_propagated + end + + | _ => NONE + + + (* all types can be locally rewritten without any context *) + fun try_rewrite_type (ty: Type.t) : Type.t option = + case Type.dest ty of + (* Aos-layout sequences get their tuples flattened and unboxed *) + Type.Object + {con = ObjectCon.Sequence ArrayLayout.Aos, args: Type.t Prod.t} => + if + Vector.forall (Prod.dest args, fn {elt, isMutable} => + Option.isNone (Option.andThen (try_rewrite_type elt, fn elt' => + try_flatten_tuples {elt = elt', isMutable = isMutable}))) + then NONE + else SOME (rewrite_sequence_aos_type args) + + (* Default-layout sequences potentially need their element types rewritten, + * but aren't flattened here. Note that deep flattening may still occur, + * but isn't mandated. *) + | Type.Object + {con = ObjectCon.Sequence ArrayLayout.Default, args: Type.t Prod.t} => + if + Vector.forall (Prod.dest args, fn {elt, ...} => + Option.isNone (try_rewrite_type elt)) + then + NONE + else + SOME (Type.sequence ArrayLayout.Default + (Prod.map (args, rewrite_type))) + + | Type.Object {con, args} => + if + Vector.forall (Prod.dest args, fn {elt, ...} => + Option.isNone (try_rewrite_type elt)) + then + NONE + else + SOME (Type.object {con = con, args = Prod.map (args, rewrite_type)}) + + | Type.Weak ty' => Option.map (try_rewrite_type ty', Type.weak) + | Type.CPointer => NONE + | Type.IntInf => NONE + | Type.Thread => NONE + | Type.Datatype tycon => NONE + | Type.Real real_size => NONE + | Type.Word word_size => NONE + + + and rewrite_sequence_aos_type args = + let + val rewritten = Prod.map (args, rewrite_type) + val flat_and_rewritten = + Prod.make (Vector.concatV (Vector.map (Prod.dest rewritten, fn x => + case try_flatten_tuples x of + NONE => Vector.new1 x + | SOME elements => elements))) + in + Type.sequence ArrayLayout.Default flat_and_rewritten + end + + + and rewrite_sequence_soa_type args = + let + val rewritten = Prod.map (args, rewrite_type) + val flat_and_rewritten = + Prod.make (Vector.concatV (Vector.map (Prod.dest rewritten, fn x => + case try_flatten_tuples x of + NONE => Vector.new1 x + | SOME elements => elements))) + + fun make_one_sequence_component {elt, isMutable} = + { elt = Type.sequence ArrayLayout.Default (Prod.make + (Vector.new1 {elt = elt, isMutable = isMutable})) + , isMutable = false + } + val soa = Type.tuple (Prod.make + (Vector.map (Prod.dest flat_and_rewritten, make_one_sequence_component))) + in + soa + end + + + and rewrite_type ty = + case try_rewrite_type ty of + NONE => ty + | SOME ty' => ty' + + + (* A "ground" type is where flattening stops. *) + fun is_ground_type ty = + case Type.dest ty of + Type.Object {con = ObjectCon.Tuple, args} => Prod.someIsMutable args + | _ => true + + + fun remap_offset sequence_ty offset = + case Type.dest sequence_ty of + Type.Object {con = ObjectCon.Sequence ArrayLayout.Aos, args: Type.t Prod.t} => + let + val lens = Vector.map (Prod.dest args, fn x => + case try_flatten_tuples x of + NONE => 1 + | SOME elts => Vector.length elts) + val (new_offsets, _) = vector_iterate_prefixes op+ 0 lens + + val (new_offset, count) = + (Vector.sub (new_offsets, offset), Vector.sub (lens, offset)) + + val () = Control.diagnostics (fn show => + let + open Layout + in + show (seq + [ str "remap_offset " + , Type.layout sequence_ty + , str " " + , Int.layout offset + , str "; lens = " + , Vector.layout Int.layout lens + , str "; new_offsets = " + , Vector.layout Int.layout lens + , str "; result = " + , Int.layout new_offset + , str " " + , Int.layout count + ]) + end) + in + (new_offset, count) + end + + | _ => + Error.bug + ("FlattenIntoSequences.remap_offset: expected flattened-layout sequence argument, but got " + ^ Layout.toString (Type.layout sequence_ty)) + + + (* ======================================================================== + * rewriting expressions, statements, blocks, transfers + *) + + + (* reconstruct var:ty from the flattened ground elements in ground_vs_tys + * for example: + * to reconstruct x:(int*(real*bool)) + * from [i:int, r:real, b:bool] + * we generate the following statements: + * x_inner = Object.Tuple(r, b) + * x = Object.Tuple(i, x_inner) + *) + fun make_pack_statements (var, ty) ground_vs_tys = + if is_ground_type ty then + let + val () = + if Vector.length ground_vs_tys = 1 then + () + else + Error.bug + ("FlattenIntoSequences.make_pack_statements: ground mismatch") + val (var_src, ty_src) = Vector.sub (ground_vs_tys, 0) + (* val () = + if same_type (rewrite_type ty, ty_src) then + () + else + Error.bug + ("FlattenIntoSequences.make_pack_statements: type mismatch: " + ^ + Layout.toString (Layout.seq + [ Type.layout (rewrite_type ty) + , Layout.str " " + , Type.layout ty_src + ])) *) + in + Vector.new1 + (Statement.Bind + {var = SOME var, ty = rewrite_type ty, exp = Exp.Var var_src}) + end + else + case Type.dest ty of + Type.Object {con = ObjectCon.Tuple, args} => + let + val lens = Vector.map (Prod.dest args, fn x => + case try_flatten_tuples x of + NONE => 1 + | SOME elts => Vector.length elts) + val (ground_starts, _) = vector_iterate_prefixes op+ 0 lens + val component_vs_tys = Vector.map (Prod.dest args, fn {elt, ...} => + (Var.newNoname (), rewrite_type elt)) + val packs = + Vector.concatV + (Vector.mapi (component_vs_tys, fn (i, (v', ty')) => + let + val ground_start = Vector.sub (ground_starts, i) + val ground_len = Vector.sub (lens, i) + val grounds = Vector.tabulate (ground_len, fn j => + Vector.sub (ground_vs_tys, ground_start + j)) + in + make_pack_statements (v', ty') grounds + end)) + val final = Statement.Bind + { var = SOME var + , ty = rewrite_type ty + , exp = Exp.Object + {con = NONE, args = Vector.map (component_vs_tys, #1)} + } + in + Vector.concat [packs, Vector.new1 final] + end + | _ => + Error.bug + ("FlattenIntoSequences.make_pack_statements: attempting to pack non-tuple") + + + fun make_load_statements (base, offset, ground_vs_tys, readBarrier) = + Vector.mapi (ground_vs_tys, fn (idx, (v, ty)) => + Statement.Bind + { var = SOME v + , ty = ty + , exp = + Exp.Select + {base = base, offset = offset + idx, readBarrier = readBarrier} + }) + + + fun try_transform_select get_var_type (var, ty, base, offset, readBarrier) = + case base of + Base.Object _ => NONE + | Base.SequenceSub {index, sequence} => + case try_rewrite_type (get_var_type sequence) of + NONE => NONE + | SOME new_type => + let + val (new_offset, ground_count) = + remap_offset (get_var_type sequence) offset + val ground_vs = Vector.tabulate (ground_count, fn _ => + Var.newNoname ()) + val ground_tys = + case Type.dest new_type of + Type.Object {con = ObjectCon.Sequence ArrayLayout.Aos, args} => + let + val args = Prod.dest args + in + Vector.tabulate (ground_count, fn i => + #elt (Vector.sub (args, new_offset + i))) + end + | _ => + Error.bug + ("FlattenIntoSequences.try_transform_select: bug!") + + val () = Control.diagnostics (fn show => + let + open Layout + in + show (seq + [ str "try_transform_select " + , Type.layout (get_var_type sequence) + , str " " + , Int.layout offset + , str " -> " + , Type.layout new_type + , str " " + , Int.layout new_offset + , str " " + , Int.layout ground_count + , str "; ground_tys = " + , Vector.layout Type.layout ground_tys + ]) + end) + + val () = + (* sanity check *) + if Vector.length ground_tys = ground_count then + () + else + Error.bug + ("FlattenIntoSequences.try_transform_select: ground mismatch") + + val ground_vs_tys = Vector.zip (ground_vs, ground_tys) + val loads = + make_load_statements + (base, new_offset, ground_vs_tys, readBarrier) + + val packs = make_pack_statements (var, ty) ground_vs_tys + in + SOME (Vector.concat [loads, packs]) + end + + + fun transform_bind get_var_type {exp, ty, var} = + let + fun no_change () = + Vector.new1 + (Statement.Bind {exp = exp, ty = rewrite_type ty, var = var}) + in + case exp of + Exp.Select {base, offset, readBarrier} => + (case + try_transform_select get_var_type + (Option.valOf var, ty, base, offset, readBarrier) + of + NONE => no_change () + | SOME ss => ss) + + | _ => no_change () + end + + + (* v:ty must be a tuple of all immutable fields, which might recursively + * contain immutable tuples in some positions. Here, we are unpacking + * its contents in preparation for a flattened store in a sequence. + * + * The idea is to replace + * S[i] := v + * with something like this: + * // unpacking part, handling nested flattens as necessary + * v0 = #0 v + * v1 = #1 v + * ... + * vn = ... + * // a bunch of stores + * S[i][0] := v0 + * S[i][1] := v1 + * ... + * S[i][n] := vn + * + * This function returns (ss, vs) where: + * - ss is all of the unpacking statements, and + * - vs is all of the final unpacked vars, in the correct order. + * (There will be exactly one store statement generated per v in vs.) + * + * Note that `ss` could be larger than `vs` due to nesting. For example, + * if we unpack a tuple `x = (1, (2, 3))` then we get the following `ss`: + * x0 = #0 x + * x_inner = #1 x + * x1 = #0 x_inner + * x2 = #1 x_inner + * but only three `vs`, one for each ground component: + * [ x0, x1, x2 ] + *) + fun make_unpack_statements (v: Var.t, ty: Type.t) : + Statement.t vector * Var.t vector = + let + val () = Control.diagnostics (fn show => + let + open Layout + in + show + (seq + [ str "make_unpack_statements " + , Var.layout v + , str " " + , Type.layout ty + ]) + end) + + fun error msg = + Error.bug + ("FlattenIntoSequences.make_unpack_statements: " ^ msg ^ ": " + ^ Layout.toString (Var.layout v) ^ " of type " + ^ Layout.toString (Type.layout ty)) + + fun unpack_one (v', ty') idx = + Statement.Bind + { var = SOME v' + , ty = rewrite_type ty' + , exp = + Exp.Select + {base = Base.Object v, offset = idx, readBarrier = false} + } + + fun unpack_component_at_idx (idx, {elt = component_ty, isMutable}) = + if isMutable then + error + ("trying to unpack mutable component at tuple index " + ^ Int.toString idx) + else + let + val component_var = Var.newNoname () + val unpack_here = unpack_one (component_var, component_ty) idx + in + (* stop recursively unpacking when we get to the bottom *) + if is_ground_type component_ty then + (Vector.new1 unpack_here, Vector.new1 component_var) + else + let + val (nested_unpacks, nested_grounds) = + make_unpack_statements (component_var, component_ty) + in + ( Vector.concat [Vector.new1 unpack_here, nested_unpacks] + , nested_grounds + ) + end + end + in + case Type.dest ty of + Type.Object {con = ObjectCon.Tuple, args} => + let + val (unpacks, grounds) = Vector.unzip + (Vector.mapi (Prod.dest args, unpack_component_at_idx)) + in + (Vector.concatV unpacks, Vector.concatV grounds) + end + | _ => error "trying to unpack non-tuple" + end + + + fun make_store_statements (base, offset, ground_vs, writeBarrier) = + Vector.mapi (ground_vs, fn (idx, v) => + Statement.Update + { base = base + , offset = offset + idx + , value = v + , writeBarrier = writeBarrier + }) + + + fun try_transform_update get_var_type {base, offset, value, writeBarrier} = + case base of + Base.Object _ => NONE + | Base.SequenceSub {index, sequence} => + case try_rewrite_type (get_var_type sequence) of + NONE => NONE + | SOME new_type => + let + val old_type = get_var_type sequence + val (new_offset, ground_count) = + remap_offset (get_var_type sequence) offset + + val () = Control.diagnostics (fn show => + let + open Layout + in + show (seq + [ str "try_transform_update " + , Type.layout old_type + , str " " + , Int.layout offset + , str " -> " + , Type.layout new_type + , str " " + , Int.layout new_offset + , str " " + , Int.layout ground_count + ]) + end) + + val (unpacks, ground_vs) = + if is_ground_type (get_var_type value) then + (Vector.new0 (), Vector.new1 value) + else + make_unpack_statements (value, get_var_type value) + + val () = + (* sanity check *) + if ground_count = Vector.length ground_vs then + () + else + Error.bug + ("FlattenIntoSequences.try_transform_update: ground mismatch") + val stores = + make_store_statements + (base, new_offset, ground_vs, writeBarrier) + in + SOME (Vector.concat [unpacks, stores]) + end + + + fun transform_statement get_var_type (s: Statement.t) : Statement.t vector = + case s of + Statement.Bind (xx as {exp: Exp.t, ty: Type.t, var: Var.t option}) => + transform_bind get_var_type xx + | Statement.Profile _ => Vector.new1 s + | Statement.Update + (xx as + {base: Var.t Base.t, offset: int, value: Var.t, writeBarrier: bool}) => + case try_transform_update get_var_type xx of + NONE => Vector.new1 s + | SOME ss => ss + + + fun transform_transfer get_var_type t = + case t of + Transfer.Runtime {args, prim, return} => + Transfer.Runtime + {args = args, prim = Prim.map (prim, rewrite_type), return = return} + | _ => t + + + fun transform_block get_var_type block = + let + val + Block.T + { args: (Var.t * Type.t) vector + , label: Label.t + , statements: Statement.t vector + , transfer: Transfer.t + } = block + + val args = Vector.map (args, fn (var, ty) => (var, rewrite_type ty)) + val statements = Vector.concatV + (Vector.map (statements, transform_statement get_var_type)) + val transfer = transform_transfer get_var_type transfer + in + Block.T + { args = args + , label = label + , statements = statements + , transfer = transfer + } + end + + + fun transform_function get_var_type (func: Function.t) : Function.t = + let + val + { args: (Var.t * Type.t) vector + , blocks: Block.t vector + , inline: InlineAttr.t + , name: Func.t + , raises: Type.t vector option + , returns: Type.t vector option + , start: Label.t + } = Function.dest func + + val args = Vector.map (args, fn (var, ty) => (var, rewrite_type ty)) + val raises = Option.map (raises, fn ts => Vector.map (ts, rewrite_type)) + val returns = Option.map (returns, fn ts => Vector.map (ts, rewrite_type)) + val blocks = Vector.map (blocks, transform_block get_var_type) + in + Function.new + { args = args + , blocks = blocks + , inline = inline + , name = name + , raises = raises + , returns = returns + , start = start + } + end + + + fun transform_datatype get_var_type (Datatype.T {cons, tycon: Tycon.t}) = + let + val cons: {args: Type.t Prod.t, con: Con.t} vector = cons + in + Datatype.T + { cons = Vector.map (cons, fn {args, con} => + {con = con, args = Prod.map (args, rewrite_type)}) + , tycon = tycon + } + end + + + (* ======================================================================== + * main entrypoint for this pass + *) + + + (* would want to keep this turned on in practice, but can disable for + * debugging. (flattenIntoSequences relies on shrinking to avoid + * unnecessary intermediate allocations, but shrinking obscures what + * the pass did.) + *) + val do_shrink = true + + + fun transform2 (program as Program.T {datatypes, functions, globals, main}) = + let + val {get = get_var_type: Var.t -> Type.t, set = set_var_type, ...} = + Property.getSetOnce + (Var.plist, Property.initRaise ("varType", Var.layout)) + + val () = Program.foreachVar (program, set_var_type) + + val datatypes = Vector.map (datatypes, transform_datatype get_var_type) + + val functions = + if do_shrink then + List.revMap (functions, transform_function get_var_type) + else + List.map (functions, transform_function get_var_type) + + val globals = Vector.concatV + (Vector.map (globals, transform_statement get_var_type)) + + val program = Program.T + { datatypes = datatypes + , functions = functions + , globals = globals + , main = main + } + + val () = Program.clear program + in + if do_shrink then shrink program else program + end +end diff --git a/mlton/ssa/global.fun b/mlton/ssa/global.fun index cde9945f6..29e525e9a 100644 --- a/mlton/ssa/global.fun +++ b/mlton/ssa/global.fun @@ -22,9 +22,10 @@ val expEquals = | (PrimApp {prim = p, targs = ts, args = xs}, PrimApp {prim = p', targs = ts', args = xs'}) => (case (p, p') of - (Prim.Vector_vector, Prim.Vector_vector) => + (Prim.Vector_vector l, Prim.Vector_vector l') => Vector.equals (ts, ts', Type.equals) andalso equalss (xs, xs') + andalso ArrayLayout.equals (l, l') | _ => false) | (Tuple xs, Tuple xs') => equalss (xs, xs') | _ => false diff --git a/mlton/ssa/poly-equal.fun b/mlton/ssa/poly-equal.fun index bd5d3911d..26346a670 100644 --- a/mlton/ssa/poly-equal.fun +++ b/mlton/ssa/poly-equal.fun @@ -201,13 +201,15 @@ fun transform (Program.T {datatypes, globals, functions, main}) = name end and mkVectorEqualFunc {name: Func.t, - ty: Type.t, doEq: bool}: unit = + ty: Type.t, + doEq: bool, + layout: ArrayLayout.t}: unit = let val loop = Func.newString (Func.originalName name ^ "Loop") (* Build two functions, one that checks the lengths and the * other that loops. *) - val vty = Type.vector ty + val vty = Type.vector layout ty local val vec1 = (Var.newNoname (), vty) val vec2 = (Var.newNoname (), vty) @@ -306,14 +308,14 @@ fun transform (Program.T {datatypes, globals, functions, main}) = in () end - and vectorEqualFunc (ty: Type.t): Func.t = + and vectorEqualFunc (layout: ArrayLayout.t) (ty: Type.t): Func.t = case getVectorEqualFunc ty of SOME f => f | NONE => let val name = Func.newString "vectorEqual" val _ = setVectorEqualFunc (ty, SOME name) - val () = mkVectorEqualFunc {name = name, ty = ty, doEq = true} + val () = mkVectorEqualFunc {name = name, ty = ty, doEq = true, layout = layout} in name end @@ -331,7 +333,8 @@ fun transform (Program.T {datatypes, globals, functions, main}) = val bigIntInfEqual = Func.newString "bigIntInfEqual" val () = mkVectorEqualFunc {name = bigIntInfEqual, ty = Type.word bws, - doEq = false} + doEq = false, + layout = ArrayLayout.Default} local val arg1 = (Var.newNoname (), Type.intInf) @@ -350,7 +353,7 @@ fun transform (Program.T {datatypes, globals, functions, main}) = {prim = Prim.IntInf_toVector, targs = Vector.new0 (), args = Vector.new1 dx, - ty = Type.vector (Type.word bws)} + ty = Type.vector ArrayLayout.Default (Type.word bws)} val one = Dexp.word (WordX.one sws) val body = Dexp.disjoin @@ -453,8 +456,8 @@ fun transform (Program.T {datatypes, globals, functions, main}) = in loop 0 end - | Type.Vector ty => - Dexp.call {func = vectorEqualFunc ty, + | Type.Vector {elem=ty, layout} => + Dexp.call {func = vectorEqualFunc layout ty, args = Vector.new2 (dx1, dx2), inline = InlineAttr.Auto, ty = Type.bool} diff --git a/mlton/ssa/poly-hash.fun b/mlton/ssa/poly-hash.fun index 161c91c54..70e873a96 100644 --- a/mlton/ssa/poly-hash.fun +++ b/mlton/ssa/poly-hash.fun @@ -404,7 +404,7 @@ fun transform (Program.T {datatypes, globals, functions, main}) = in name end - and vectorHashFunc (ty: Type.t): Func.t = + and vectorHashFunc (lay: ArrayLayout.t) (ty: Type.t): Func.t = case getVectorHashFunc ty of SOME f => f | NONE => @@ -415,7 +415,7 @@ fun transform (Program.T {datatypes, globals, functions, main}) = val name = Func.newString "vectorHash" val _ = setVectorHashFunc (ty, SOME name) val loop = Func.newString "vectorHashLoop" - val vty = Type.vector ty + val vty = Type.vector lay ty local val st = (Var.newNoname (), Hash.stateTy) val vec = (Var.newNoname (), vty) @@ -555,7 +555,7 @@ fun transform (Program.T {datatypes, globals, functions, main}) = {prim = Prim.IntInf_toVector, targs = Vector.new0 (), args = Vector.new1 dx, - ty = Type.vector (Type.word bws)} + ty = Type.vector ArrayLayout.Default (Type.word bws)} val w = Var.newNoname () val dw = Dexp.var (w, Type.word sws) val one = Dexp.word (WordX.one sws) @@ -575,7 +575,7 @@ fun transform (Program.T {datatypes, globals, functions, main}) = {con = Con.falsee, args = Vector.new0 (), body = - Dexp.call {func = vectorHashFunc (Type.word bws), + Dexp.call {func = vectorHashFunc ArrayLayout.Default (Type.word bws), args = Vector.new2 (dst, toVector), inline = InlineAttr.Auto, ty = Hash.stateTy}})}} @@ -615,8 +615,8 @@ fun transform (Program.T {datatypes, globals, functions, main}) = in loop (0, dst) end - | Type.Vector ty => - Dexp.call {func = vectorHashFunc ty, + | Type.Vector {elem=ty, layout=lay} => + Dexp.call {func = vectorHashFunc lay ty, args = Vector.new2 (dst, dx), inline = InlineAttr.Auto, ty = Hash.stateTy} diff --git a/mlton/ssa/ref-flatten.fun b/mlton/ssa/ref-flatten.fun index 46f7fc5fb..a56770be2 100644 --- a/mlton/ssa/ref-flatten.fun +++ b/mlton/ssa/ref-flatten.fun @@ -360,7 +360,7 @@ fun transform2 (program as Program.T {datatypes, functions, globals, main}) = in v end)) - | Sequence => doit () + | Sequence _ => doit () | Tuple => doit () end | Weak t => diff --git a/mlton/ssa/remove-unused.fun b/mlton/ssa/remove-unused.fun index e46881413..e334e7a2d 100644 --- a/mlton/ssa/remove-unused.fun +++ b/mlton/ssa/remove-unused.fun @@ -342,11 +342,11 @@ fun transform (Program.T {datatypes, globals, functions, main}) = datatype z = datatype Type.dest val () = case Type.dest ty of - Array ty => visitType ty + Array {elem=ty, ...} => visitType ty | Datatype tycon => visitTycon tycon | Ref ty => visitType ty | Tuple tys => Vector.foreach (tys, visitType) - | Vector ty => visitType ty + | Vector {elem=ty, ...} => visitType ty | Weak ty => visitType ty | _ => () in @@ -413,7 +413,7 @@ fun transform (Program.T {datatypes, globals, functions, main}) = (TyconInfo.cons (tyconInfo t), fn con => deconCon con) | Tuple ts => Vector.foreach (ts, deconType) - | Vector t => deconType t + | Vector {elem=t, ...} => deconType t | _ => () in () @@ -917,10 +917,10 @@ fun transform (Program.T {datatypes, globals, functions, main}) = datatype z = datatype Type.dest val ty = case Type.dest ty of - Array ty => Type.array (simplifyType ty) + Array {elem=ty, layout=lay} => Type.array lay (simplifyType ty) | Ref ty => Type.reff (simplifyType ty) | Tuple tys => Type.tuple (Vector.map (tys, simplifyType)) - | Vector ty => Type.vector (simplifyType ty) + | Vector {elem=ty, layout=lay} => Type.vector lay (simplifyType ty) | Weak ty => Type.weak (simplifyType ty) | _ => ty in diff --git a/mlton/ssa/remove-unused2.fun b/mlton/ssa/remove-unused2.fun index d6066ea27..e8d38159e 100644 --- a/mlton/ssa/remove-unused2.fun +++ b/mlton/ssa/remove-unused2.fun @@ -364,7 +364,7 @@ fun transform2 (Program.T {datatypes, globals, functions, main}) = val () = case con of Con con => visitCon con - | Sequence => () + | Sequence _ => () | Tuple => () in () @@ -460,7 +460,7 @@ fun transform2 (Program.T {datatypes, globals, functions, main}) = val () = case con of Con con => deconCon con - | Sequence => default () + | Sequence _ => default () | Tuple => default () in () @@ -530,7 +530,7 @@ fun transform2 (Program.T {datatypes, globals, functions, main}) = in () end - | Sequence => Error.bug "RemoveUnused2.visitExp: Select:non-Con|Tuple" + | Sequence _ => Error.bug "RemoveUnused2.visitExp: Select:non-Con|Tuple" | Tuple => ()) | _ => Error.bug "RemovUnused2.visitExp: Select:non-Object" in @@ -578,7 +578,7 @@ fun transform2 (Program.T {datatypes, globals, functions, main}) = ; visitVar base ; visitVar value)) end - | Sequence => Error.bug "RemoveUnused2.visitStatement: Update:non-Con|Tuple" + | Sequence _ => Error.bug "RemoveUnused2.visitStatement: Update:non-Con|Tuple" | Tuple => (visitVar base ; visitVar value)) @@ -1169,7 +1169,7 @@ fun transform2 (Program.T {datatypes, globals, functions, main}) = offset = offset, readBarrier = readBarrier} end - | Sequence => Error.bug "RemoveUnused2.simplifyExp: Update:non-Con|Tuple" + | Sequence _ => Error.bug "RemoveUnused2.simplifyExp: Update:non-Con|Tuple" | Tuple => e) | _ => Error.bug "RemoveUnused2.simplifyExp:Select:non-Object" end @@ -1240,7 +1240,7 @@ fun transform2 (Program.T {datatypes, globals, functions, main}) = end else NONE end - | Sequence => Error.bug "RemoveUnused2.simplifyStatement: Update:non-Con|Tuple" + | Sequence _ => Error.bug "RemoveUnused2.simplifyStatement: Update:non-Con|Tuple" | Tuple => SOME s) | _ => Error.bug "RemoveUnused2.simplifyStatement: Select:non-Object" end diff --git a/mlton/ssa/share-zero-vec.fun b/mlton/ssa/share-zero-vec.fun index 95c9be6f3..a2fd07e33 100644 --- a/mlton/ssa/share-zero-vec.fun +++ b/mlton/ssa/share-zero-vec.fun @@ -43,7 +43,7 @@ fun transform (Program.T {datatypes, globals, functions, main}) = val hs: (Type.t, Var.t) HashTable.t = HashTable.new {hash = Type.hash, equals = Type.equals} in - fun getZeroArrVar (ty: Type.t): Var.t = + fun getZeroArrVar (lay: ArrayLayout.t) (ty: Type.t): Var.t = HashTable.lookupOrInsert (hs, ty, fn () => @@ -52,10 +52,10 @@ fun transform (Program.T {datatypes, globals, functions, main}) = val statement = Statement.T {var = SOME zeroArrVar, - ty = Type.array ty, + ty = Type.array lay ty, exp = PrimApp {args = Vector.new0 (), - prim = Prim.Array_array, + prim = Prim.Array_array lay, targs = Vector.new1 ty}} val () = List.push (newGlobals, statement) in @@ -76,7 +76,7 @@ fun transform (Program.T {datatypes, globals, functions, main}) = case exp of PrimApp ({prim, args, targs}) => (case (var, prim) of - (SOME var, Prim.Array_alloc {raw = false}) => + (SOME var, Prim.Array_alloc {raw = false, ...}) => if List.contains (arrVars, var, Var.equals) then SOME (var, ty, Vector.first targs, @@ -101,6 +101,8 @@ fun transform (Program.T {datatypes, globals, functions, main}) = val ifNonZeroLab = Label.newString "L_nonZeroLen" val joinLab = Label.newString "L_join" + val arrLayout = Type.deArrayLayout arrTy + (* new block up to Array_alloc match *) val preBlock = let @@ -134,7 +136,7 @@ fun transform (Program.T {datatypes, globals, functions, main}) = let val transfer = Transfer.Goto - {args = Vector.new1 (getZeroArrVar eltTy), + {args = Vector.new1 (getZeroArrVar arrLayout eltTy), dst = joinLab} in Block.T {label = ifZeroLab, @@ -147,6 +149,7 @@ fun transform (Program.T {datatypes, globals, functions, main}) = val ifNonZeroBlock = let val arrVar' = Var.new arrVar + val statements = Vector.new1 (Statement.T @@ -154,7 +157,7 @@ fun transform (Program.T {datatypes, globals, functions, main}) = ty = arrTy, exp = PrimApp {args = Vector.new1 lenVar, - prim = Prim.Array_alloc {raw = false}, + prim = Prim.Array_alloc {raw = false, layout = arrLayout}, targs = Vector.new1 eltTy}}) val transfer = Transfer.Goto diff --git a/mlton/ssa/simplify-types.fun b/mlton/ssa/simplify-types.fun index 31c59698b..e14a8ceea 100644 --- a/mlton/ssa/simplify-types.fun +++ b/mlton/ssa/simplify-types.fun @@ -282,11 +282,11 @@ fun transform (Program.T {datatypes, globals, functions, main}) = let fun deepSetFFI t = case Type.dest t of - Type.Array t => deepSetFFI t + Type.Array {elem=t, ...} => deepSetFFI t | Type.Datatype tycon => tyconFFI tycon () | Type.Ref t => deepSetFFI t | Type.Tuple tv => Vector.foreach(tv, deepSetFFI) - | Type.Vector t => deepSetFFI t + | Type.Vector {elem=t, ...} => deepSetFFI t | Type.Weak t => deepSetFFI t | _ => () in @@ -352,7 +352,7 @@ fun transform (Program.T {datatypes, globals, functions, main}) = | Ref t => ptrCard t | Thread => Cardinality.many | Tuple ts => tupleCard ts - | Vector t => vecCard t + | Vector {elem=t, ...} => vecCard t | Weak t => ptrCard t | Word _ => Cardinality.many end)) @@ -536,14 +536,14 @@ fun transform (Program.T {datatypes, globals, functions, main}) = Property.initRec (fn (t, containsTycon) => case Type.dest t of - Array t => containsTycon t + Array {elem=t, ...} => containsTycon t | Datatype tyc' => (case tyconReplacement tyc' of NONE => Tycon.equals (tyc, tyc') | SOME t => containsTycon t) | Tuple ts => Vector.exists (ts, containsTycon) | Ref t => containsTycon t - | Vector t => containsTycon t + | Vector {elem=t, ...} => containsTycon t | Weak t => containsTycon t | _ => false)) val res = containsTycon ty @@ -658,7 +658,7 @@ fun transform (Program.T {datatypes, globals, functions, main}) = open Type in case dest t of - Array t => array (simplifyType t) + Array {elem=t, layout=l} => array l (simplifyType t) | Datatype tycon => (case tyconReplacement tycon of SOME t => @@ -674,7 +674,7 @@ fun transform (Program.T {datatypes, globals, functions, main}) = (case simplifyUsefulTypesOpt ts of NONE => typeVoid | SOME ts => Type.tuple ts) - | Vector t => vector (simplifyType t) + | Vector {elem=t, layout=l} => vector l (simplifyType t) | Weak t => doitPtr (weak, t) | _ => t end)) diff --git a/mlton/ssa/simplify.fun b/mlton/ssa/simplify.fun index 845c349d1..125b682c2 100644 --- a/mlton/ssa/simplify.fun +++ b/mlton/ssa/simplify.fun @@ -37,7 +37,7 @@ structure RedundantTests = RedundantTests (S) structure RemoveUnused = RemoveUnused (S) structure ShareZeroVec = ShareZeroVec (S) structure SimplifyTypes = SimplifyTypes (S) -structure SplitTypes = SplitTypes (S) +(* structure SplitTypes = SplitTypes (S) *) structure Useless = Useless (S) type pass = {name: string, @@ -59,7 +59,7 @@ val ssaPassesDefault = (* SAM_NOTE: disabling splitTypes1 because it does not yet support primitive * polymorphic CAS. We should update the pass and then re-enable. *) - {name = "splitTypes1", doit = SplitTypes.transform, execute = false} :: + (* {name = "splitTypes1", doit = SplitTypes.transform, execute = false} :: *) (* useless should run * - after constant propagation because constant propagation makes * slots of tuples that are constant useless @@ -74,7 +74,7 @@ val ssaPassesDefault = (* SAM_NOTE: disabling splitTypes2 because it does not yet support primitive * polymorphic CAS. We should update the pass and then re-enable. *) - {name = "splitTypes2", doit = SplitTypes.transform, execute = false} :: + (* {name = "splitTypes2", doit = SplitTypes.transform, execute = false} :: *) {name = "simplifyTypes", doit = SimplifyTypes.transform, execute = true} :: (* polyEqual should run * - after types are simplified so that many equals are turned into eqs @@ -249,7 +249,7 @@ local ("removeUnused", RemoveUnused.transform), ("shareZeroVec", ShareZeroVec.transform), ("simplifyTypes", SimplifyTypes.transform), - ("splitTypes", SplitTypes.transform), + (* ("splitTypes", SplitTypes.transform), *) ("useless", Useless.transform), ("ssaAddProfile", Profile.addProfile), ("ssaDropSpork", DropSpork.transform), diff --git a/mlton/ssa/simplify2.fun b/mlton/ssa/simplify2.fun index f5f26ac76..ec0d2f9d8 100644 --- a/mlton/ssa/simplify2.fun +++ b/mlton/ssa/simplify2.fun @@ -13,6 +13,7 @@ struct open S structure DeepFlatten = DeepFlatten (S) +structure FlattenIntoSequences = FlattenIntoSequences (S) structure DropSpork2 = DropSpork2 (S) structure Profile2 = Profile2 (S) structure RefFlatten = RefFlatten (S) @@ -24,6 +25,7 @@ type pass = {name: string, execute: bool} val ssa2PassesDefault = + {name = "flattenIntoSequences", doit = FlattenIntoSequences.transform2, execute = true} :: {name = "deepFlatten", doit = DeepFlatten.transform2, execute = true} :: {name = "refFlatten", doit = RefFlatten.transform2, execute = true} :: {name = "removeUnused5", doit = RemoveUnused2.transform2, execute = true} :: @@ -51,7 +53,8 @@ local val passGens = - List.map([("deepFlatten", DeepFlatten.transform2), + List.map([("flattenIntoSequences", FlattenIntoSequences.transform2), + ("deepFlatten", DeepFlatten.transform2), ("refFlatten", RefFlatten.transform2), ("removeUnused", RemoveUnused2.transform2), ("zone", Zone.transform2), diff --git a/mlton/ssa/sources.cm b/mlton/ssa/sources.cm index fed7201e1..afdc76957 100644 --- a/mlton/ssa/sources.cm +++ b/mlton/ssa/sources.cm @@ -62,6 +62,7 @@ combine-conversions.fun constant-propagation.fun contify.fun deep-flatten.fun +flatten-into-sequences.fun drop-spork.fun drop-spork2.fun duplicate-globals.fun @@ -88,7 +89,7 @@ remove-unused.fun remove-unused2.fun share-zero-vec.fun simplify-types.fun -split-types.fun +(* split-types.fun *) useless.fun zone.fun simplify.sig diff --git a/mlton/ssa/sources.mlb b/mlton/ssa/sources.mlb index 9ca4df17d..cfd9028e4 100644 --- a/mlton/ssa/sources.mlb +++ b/mlton/ssa/sources.mlb @@ -52,6 +52,7 @@ local constant-propagation.fun contify.fun deep-flatten.fun + flatten-into-sequences.fun drop-spork.fun drop-spork2.fun duplicate-globals.fun @@ -78,7 +79,7 @@ local remove-unused2.fun share-zero-vec.fun simplify-types.fun - split-types.fun + (* split-types.fun *) useless.fun zone.fun simplify.sig diff --git a/mlton/ssa/ssa-to-ssa2.fun b/mlton/ssa/ssa-to-ssa2.fun index 527c91fb6..fbf65a968 100644 --- a/mlton/ssa/ssa-to-ssa2.fun +++ b/mlton/ssa/ssa-to-ssa2.fun @@ -38,7 +38,7 @@ fun convert (S.Program.T {datatypes, functions, globals, main}) = Property.initRec (fn (t, convertType) => case S.Type.dest t of - S.Type.Array t => S2.Type.array1 (convertType t) + S.Type.Array {elem, layout} => S2.Type.array1 layout (convertType elem) | S.Type.CPointer => S2.Type.cpointer | S.Type.Datatype tycon => S2.Type.datatypee tycon | S.Type.IntInf => S2.Type.intInf @@ -50,7 +50,7 @@ fun convert (S.Program.T {datatypes, functions, globals, main}) = (Vector.map (ts, fn t => {elt = convertType t, isMutable = false}))) - | S.Type.Vector t => S2.Type.vector1 (convertType t) + | S.Type.Vector {elem, layout} => S2.Type.vector1 layout (convertType elem) | S.Type.Weak t => S2.Type.weak (convertType t) | S.Type.Word s => S2.Type.word s)) fun convertTypes ts = Vector.map (ts, convertType) @@ -125,7 +125,7 @@ fun convert (S.Program.T {datatypes, functions, globals, main}) = readBarrier = false}) in case prim of - Prim.Array_array => sequence () + Prim.Array_array _ => sequence () | Prim.Array_sub {readBarrier} => simple (S2.Exp.Select @@ -159,7 +159,7 @@ fun convert (S.Program.T {datatypes, functions, globals, main}) = simple (S2.Exp.PrimApp {args = args, prim = Prim.Array_length}) | Prim.Vector_sub => sub () - | Prim.Vector_vector => sequence () + | Prim.Vector_vector _ => sequence () | _ => simple (S2.Exp.PrimApp {args = args, prim = convertPrim prim}) diff --git a/mlton/ssa/ssa-tree.fun b/mlton/ssa/ssa-tree.fun index 0f71de359..c17ef0fbf 100644 --- a/mlton/ssa/ssa-tree.fun +++ b/mlton/ssa/ssa-tree.fun @@ -24,7 +24,7 @@ structure Type = plist: PropertyList.t, tree: tree} and tree = - Array of t + Array of {elem: t, layout: ArrayLayout.t} | CPointer | Datatype of Tycon.t | IntInf @@ -32,7 +32,7 @@ structure Type = | Ref of t | Thread | Tuple of t vector - | Vector of t + | Vector of {elem: t, layout: ArrayLayout.t} | Weak of t | Word of WordSize.t @@ -60,18 +60,21 @@ structure Type = (deOpt, de, is) end in - val (_,deArray,_) = make (fn Array t => SOME t | _ => NONE) + val (_,deArray,_) = make (fn Array {elem, ...} => SOME elem | _ => NONE) + val (_,deArrayLayout,_) = make (fn Array {layout, ...} => SOME layout | _ => NONE) val (_,deDatatype,_) = make (fn Datatype tyc => SOME tyc | _ => NONE) val (_,deRef,_) = make (fn Ref t => SOME t | _ => NONE) val (deTupleOpt,deTuple,isTuple) = make (fn Tuple ts => SOME ts | _ => NONE) - val (_,deVector,_) = make (fn Vector t => SOME t | _ => NONE) + val (_,deVector,_) = make (fn Vector {elem, ...} => SOME elem | _ => NONE) + val (_,deVectorLayout,_) = make (fn Vector {layout, ...} => SOME layout | _ => NONE) val (_,deWeak,_) = make (fn Weak t => SOME t | _ => NONE) val (deWordOpt,deWord,_) = make (fn Word ws => SOME ws | _ => NONE) end local val same: tree * tree -> bool = - fn (Array t1, Array t2) => equals (t1, t2) + fn (Array {elem = t1, layout = l1}, Array {elem = t2, layout = l2}) => + equals (t1, t2) andalso ArrayLayout.equals (l1, l2) | (CPointer, CPointer) => true | (Datatype t1, Datatype t2) => Tycon.equals (t1, t2) | (IntInf, IntInf) => true @@ -79,7 +82,8 @@ structure Type = | (Ref t1, Ref t2) => equals (t1, t2) | (Thread, Thread) => true | (Tuple ts1, Tuple ts2) => Vector.equals (ts1, ts2, equals) - | (Vector t1, Vector t2) => equals (t1, t2) + | (Vector {elem = t1, layout = l1}, Vector {elem = t2, layout = l2}) => + equals (t1, t2) andalso ArrayLayout.equals (l1, l2) | (Weak t1, Weak t2) => equals (t1, t2) | (Word s1, Word s2) => WordSize.equals (s1, s2) | _ => false @@ -110,9 +114,22 @@ structure Type = fn t => lookup (Hash.combine (w, hash t), f t) end in - val array = make Array + val arrayDefault = make (fn t => Array {elem = t, layout = ArrayLayout.Default}) + val arrayAos = make (fn t => Array {elem = t, layout = ArrayLayout.Aos}) + fun array (layout: ArrayLayout.t) elem = + case layout of + ArrayLayout.Default => arrayDefault elem + | ArrayLayout.Aos => arrayAos elem + val reff = make Ref - val vector = make Vector + + val vectorDefault = make (fn t => Vector {elem = t, layout = ArrayLayout.Default}) + val vectorAos = make (fn t => Vector {elem = t, layout = ArrayLayout.Aos}) + fun vector (layout: ArrayLayout.t) elem = + case layout of + ArrayLayout.Default => vectorDefault elem + | ArrayLayout.Aos => vectorAos elem + val weak = make Weak end @@ -155,7 +172,7 @@ structure Type = | Null => cpointer | Real r => real (RealX.size r) | Word w => word (WordX.size w) - | WordVector v => vector (word (WordXVector.elementSize v)) + | WordVector v => vector ArrayLayout.Default (word (WordXVector.elementSize v)) end val unit: t = tuple (Vector.new0 ()) @@ -179,7 +196,15 @@ structure Type = seq [paren (layout t), str " ", str tc] in case dest t of - Array t => unary (t, "array") + Array {elem, layout} => + let + val name = + case layout of + ArrayLayout.Default => "array" + | ArrayLayout.Aos => "array_aos" + in + unary (elem, name) + end | CPointer => str "cpointer" | Datatype t => Tycon.layout t | IntInf => str "intInf" @@ -193,7 +218,15 @@ structure Type = (mayAlign o separateRight) (Vector.toListMap (ts, layout), ","), str ") tuple"] - | Vector t => unary (t, "vector") + | Vector {elem, layout} => + let + val name = + case layout of + ArrayLayout.Default => "vector" + | ArrayLayout.Aos => "vector_aos" + in + unary (elem, name) + end | Weak t => unary (t, "weak") | Word s => str (concat ["word", WordSize.toString s]) end)) @@ -212,10 +245,12 @@ structure Type = List.map (WordSize.all, fn ws => ("word" ^ WordSize.toString ws, word ws)) @ List.map (RealSize.all, fn rs => ("real" ^ RealSize.toString rs, real rs))) val unary = - [array <$ P.kw "array", + [array ArrayLayout.Default <$ P.kw "array", + array ArrayLayout.Aos <$ P.kw "array_aos", reff <$ P.kw "ref", (tuple o Vector.new1) <$ P.kw "tuple", - vector <$ P.kw "vector", + vector ArrayLayout.Default <$ P.kw "vector", + vector ArrayLayout.Aos <$ P.kw "vector_aos", weak <$ P.kw "weak"] in fun parse () = @@ -1918,7 +1953,7 @@ structure Program = datatype z = datatype Type.dest val _ = case Type.dest t of - Array t => countType t + Array {elem, layout} => countType elem | CPointer => () | Datatype _ => () | IntInf => () @@ -1926,7 +1961,7 @@ structure Program = | Ref t => countType t | Thread => () | Tuple ts => Vector.foreach (ts, countType) - | Vector t => countType t + | Vector {elem, layout} => countType elem | Weak t => countType t | Word _ => () val _ = Int.inc numTypes diff --git a/mlton/ssa/ssa-tree.sig b/mlton/ssa/ssa-tree.sig index 1125dbae0..19c380396 100644 --- a/mlton/ssa/ssa-tree.sig +++ b/mlton/ssa/ssa-tree.sig @@ -21,7 +21,7 @@ signature SSA_TREE = type t datatype dest = - Array of t + Array of {elem: t, layout: ArrayLayout.t} | CPointer | Datatype of Tycon.t | IntInf @@ -29,11 +29,11 @@ signature SSA_TREE = | Ref of t | Thread | Tuple of t vector - | Vector of t + | Vector of {elem: t, layout: ArrayLayout.t} | Weak of t | Word of WordSize.t - val array: t -> t + val array: ArrayLayout.t -> t -> t val bool: t val checkPrimApp: {targs: t vector, args: t vector, @@ -43,11 +43,13 @@ signature SSA_TREE = val datatypee: Tycon.t -> t val dest: t -> dest val deArray: t -> t + val deArrayLayout: t -> ArrayLayout.t val deDatatype: t -> Tycon.t val deRef: t -> t val deTuple: t -> t vector val deTupleOpt: t -> t vector option val deVector: t -> t + val deVectorLayout: t -> ArrayLayout.t val deWeak: t -> t val deWord: t -> WordSize.t val deWordOpt: t -> WordSize.t option @@ -63,7 +65,7 @@ signature SSA_TREE = val reff: t -> t val thread: t val tuple: t vector -> t - val vector: t -> t + val vector: ArrayLayout.t -> t -> t val weak: t -> t val word: WordSize.t -> t val unit: t diff --git a/mlton/ssa/ssa-tree2.fun b/mlton/ssa/ssa-tree2.fun index e5c53c487..fde78ed27 100644 --- a/mlton/ssa/ssa-tree2.fun +++ b/mlton/ssa/ssa-tree2.fun @@ -21,17 +21,17 @@ structure ObjectCon = struct datatype t = Con of Con.t - | Sequence + | Sequence of ArrayLayout.t | Tuple val equals: t * t -> bool = fn (Con c, Con c') => Con.equals (c, c') - | (Sequence, Sequence) => true + | (Sequence l, Sequence l') => ArrayLayout.equals (l, l') | (Tuple, Tuple) => true | _ => false val isSequence: t -> bool = - fn Sequence => true + fn Sequence _ => true | _ => false val layout: t -> Layout.t = @@ -41,15 +41,19 @@ structure ObjectCon = in case oc of Con c => Con.layout c - | Sequence => str "sequence" + | Sequence ArrayLayout.Default => str "sequence" + | Sequence ArrayLayout.Aos => str "sequence_aos" | Tuple => str "tuple" end local - val conAlts = Vector.fromList [("sequence", Sequence), ("tuple", Tuple)] + val conAlts = Vector.fromList + [("sequence", Sequence ArrayLayout.Default), + ("sequence_aos", Sequence ArrayLayout.Aos), + ("tuple", Tuple)] in val parse = Con.parseAs (conAlts, Con) - end + end end datatype z = datatype ObjectCon.t @@ -88,9 +92,15 @@ structure Type = val deSequenceOpt: t -> t Prod.t option = fn t => case dest t of - Object {args, con = Sequence} => SOME args + Object {args, con = Sequence _} => SOME args | _ => NONE + val deSequenceLayout: t -> ArrayLayout.t = + fn t => + case dest t of + Object {con = Sequence l, ...} => l + | _ => Error.bug "SsaTree2.Type.deSequenceLayout" + val deSequence1: t -> t = fn t => case deSequenceOpt t of @@ -197,7 +207,8 @@ structure Type = local val tuple = newHash () - val sequence = newHash () + val sequenceDefault = newHash () + val sequenceFlat = newHash () fun hashProd (p, base) = Hash.combine (base, Hash.vectorMap (Prod.dest p, fn {elt, ...} => hash elt)) in @@ -206,7 +217,8 @@ structure Type = val base = case con of Con c => Con.hash c - | Sequence => sequence + | Sequence ArrayLayout.Default => sequenceDefault + | Sequence ArrayLayout.Aos => sequenceFlat | Tuple => tuple val hash = hashProd (args, base) in @@ -214,10 +226,10 @@ structure Type = end end - fun sequence p = object {args = p, con = Sequence} + fun sequence lay p = object {args = p, con = Sequence lay} - fun array1 ty = sequence (Prod.new1Mutable ty) - fun vector1 ty = sequence (Prod.new1Immutable ty) + fun array1 lay ty = sequence lay (Prod.new1Mutable ty) + fun vector1 lay ty = sequence lay (Prod.new1Immutable ty) fun ofConst c = let @@ -229,7 +241,7 @@ structure Type = | Null => cpointer | Real r => real (RealX.size r) | Word w => word (WordX.size w) - | WordVector v => vector1 (word (WordXVector.elementSize v)) + | WordVector v => vector1 ArrayLayout.Default (word (WordXVector.elementSize v)) end fun conApp (con, args) = object {args = args, con = Con con} @@ -285,7 +297,8 @@ structure Type = List.map (WordSize.all, fn ws => ("word" ^ WordSize.toString ws, word ws)) @ List.map (RealSize.all, fn rs => ("real" ^ RealSize.toString rs, real rs))) val unary = - Con.parseAs (Vector.new3 (("sequence", sequence o Prod.new1Immutable), + Con.parseAs (Vector.new4 (("sequence", sequence ArrayLayout.Default o Prod.new1Immutable), + ("sequence_aos", sequence ArrayLayout.Aos o Prod.new1Immutable), ("tuple", tuple o Prod.new1Immutable), ("weak", weak)), fn con => fn ty => @@ -357,13 +370,14 @@ structure Type = val seqIndex = word (WordSize.seqIndex ()) in case prim of - Prim.Array_alloc _ => + Prim.Array_alloc {layout, ...} => oneArg (fn n => case deSequenceOpt result of SOME resp => Prod.allAreMutable resp andalso equals (n, seqIndex) + andalso ArrayLayout.equals (deSequenceLayout result, layout) | _ => false) | Prim.Array_copyArray => fiveArgs diff --git a/mlton/ssa/ssa-tree2.sig b/mlton/ssa/ssa-tree2.sig index 5f7776e0b..6c8b9328c 100644 --- a/mlton/ssa/ssa-tree2.sig +++ b/mlton/ssa/ssa-tree2.sig @@ -20,7 +20,7 @@ signature SSA_TREE2 = sig datatype t = Con of Con.t - | Sequence + | Sequence of ArrayLayout.t | Tuple val isSequence: t -> bool @@ -42,7 +42,7 @@ signature SSA_TREE2 = | Weak of t | Word of WordSize.t - val array1: t -> t + val array1: ArrayLayout.t -> t -> t val bool: t val conApp: Con.t * t Prod.t -> t val checkPrimApp: {args: t vector, @@ -53,6 +53,7 @@ signature SSA_TREE2 = val dest: t -> dest val deSequence1: t -> t val deSequenceOpt: t -> t Prod.t option + val deSequenceLayout: t -> ArrayLayout.t val deRef1Opt : t -> t option val deRef1 : t -> t val equals: t * t -> bool @@ -65,10 +66,10 @@ signature SSA_TREE2 = val plist: t -> PropertyList.t val real: RealSize.t -> t val reff1: t -> t - val sequence: t Prod.t -> t + val sequence: ArrayLayout.t -> t Prod.t -> t val thread: t val tuple: t Prod.t -> t - val vector1: t -> t + val vector1: ArrayLayout.t -> t -> t val weak: t -> t val word: WordSize.t -> t val unit: t diff --git a/mlton/ssa/type-check.fun b/mlton/ssa/type-check.fun index 20a1aec7b..4da03a673 100644 --- a/mlton/ssa/type-check.fun +++ b/mlton/ssa/type-check.fun @@ -64,7 +64,7 @@ fun checkScopes (program as datatype z = datatype Type.dest val _ = case Type.dest ty of - Array ty => loopType ty + Array {elem, layout} => loopType elem | CPointer => () | Datatype tycon => getTycon tycon | IntInf => () @@ -72,7 +72,7 @@ fun checkScopes (program as | Ref ty => loopType ty | Thread => () | Tuple tys => Vector.foreach (tys, loopType) - | Vector ty => loopType ty + | Vector {elem, layout} => loopType elem | Weak ty => loopType ty | Word _ => () in diff --git a/mlton/ssa/type-check2.fun b/mlton/ssa/type-check2.fun index fda0cf427..97ab1f0a4 100644 --- a/mlton/ssa/type-check2.fun +++ b/mlton/ssa/type-check2.fun @@ -62,7 +62,7 @@ fun checkScopes (program as val _ = case oc of Con con => getCon con - | Sequence => () + | Sequence _ => () | Tuple => () in () @@ -648,7 +648,7 @@ fun typeCheck (program as Program.T {datatypes, ...}): unit = fun err () = Error.bug "Ssa2.TypeCheck2.sequence (bad sequence)" in case Type.dest resultType of - Type.Object {args = args', con = ObjectCon.Sequence} => + Type.Object {args = args', con = ObjectCon.Sequence _} => (if (Vector.foreach (args, fn args => Vector.foreach2 diff --git a/mlton/ssa/useless.fun b/mlton/ssa/useless.fun index 6478eab24..dfb285174 100644 --- a/mlton/ssa/useless.fun +++ b/mlton/ssa/useless.fun @@ -284,14 +284,14 @@ structure Value = val loop = fn t => loop (t, es) val value = case Type.dest t of - Type.Array t => + Type.Array {elem=t, ...} => Array {useful = useful (), length = loop (Type.word (WordSize.seqIndex ())), elt = slot t} | Type.Ref t => Ref {arg = slot t, useful = useful ()} | Type.Tuple ts => Tuple (Vector.map (ts, slot)) - | Type.Vector t => + | Type.Vector {elem=t, ...} => Vector {length = loop (Type.word (WordSize.seqIndex ())), elt = slot t} | Type.Weak t => Weak {arg = slot t, @@ -437,12 +437,16 @@ structure Value = in case value of Array arg => - (case arrayRep arg of - ArrayRep.Array (ty, u) => (Type.array ty, u) + let + val lay = Type.deArrayLayout ty + in + case arrayRep arg of + ArrayRep.Array (ty, u) => (Type.array lay ty, u) | ArrayRep.Length => (Type.word (WordSize.seqIndex ()), true) | ArrayRep.LengthRef => (Type.reff (Type.word (WordSize.seqIndex ())), true) | ArrayRep.UnitRef => (Type.reff Type.unit, true) - | ArrayRep.Unit => (Type.unit, false)) + | ArrayRep.Unit => (Type.unit, false) + end | Ground u => (ty, Useful.isUseful u) | Ref {arg, useful, ...} => orU (wrap (arg, Type.reff), useful) @@ -459,10 +463,14 @@ structure Value = (Type.tuple ts, b) end | Vector arg => - (case vectorRep arg of - VectorRep.Vector (ty, u) => (Type.vector ty, u) + let + val lay = Type.deVectorLayout ty + in + case vectorRep arg of + VectorRep.Vector (ty, u) => (Type.vector lay ty, u) | VectorRep.Length => (Type.word (WordSize.seqIndex ()), true) - | VectorRep.Unit => (Type.unit, false)) + | VectorRep.Unit => (Type.unit, false) + end | Weak {arg, useful} => orU (wrap (arg, Type.weak), useful) end) @@ -696,7 +704,7 @@ fun transform (program: Program.t): Program.t = Exists.whenExists (#2 (arrayEltSlot result), fn () => Useful.makeUseful (deground (arg 0))) - | Prim.Array_array => seq arrayElt + | Prim.Array_array _ => seq arrayElt (* SAM_NOTE: unification is certainly "correct" but we should * investigate whether coercions are possible. *) | Prim.Array_cas _ => (arg 1 dependsOn (arrayElt (arg 0)) @@ -753,7 +761,7 @@ fun transform (program: Program.t): Program.t = ; unify (result, deref (arg 0))) | Prim.Vector_length => length vectorLength | Prim.Vector_sub => sub vectorElt - | Prim.Vector_vector => seq vectorElt + | Prim.Vector_vector _ => seq vectorElt | Prim.Weak_canGet => Useful.whenUseful (deground result, fn () => @@ -998,7 +1006,7 @@ fun transform (program: Program.t): Program.t = targs = Vector.new1 Type.unit, args = Vector.new1 unitVar}) | Value.ArrayRep.Unit => simple (Var unitVar)) - | Prim.Array_array => + | Prim.Array_array _ => (case Value.arrayRep (Value.arrayArg resultValue) of Value.ArrayRep.Array (eltTy, _) => makeSeq eltTy | Value.ArrayRep.Length => @@ -1150,7 +1158,7 @@ fun transform (program: Program.t): Program.t = | Value.VectorRep.Length => simple (Var (arg 0)) | Value.VectorRep.Unit => Error.bug "Useless.doitPrim: Vector_length/VectorRep.Unit") - | Prim.Vector_vector => + | Prim.Vector_vector _ => (case Value.vectorRep (Value.vectorArg resultValue) of Value.VectorRep.Vector (eltTy, _) => makeSeq eltTy | Value.VectorRep.Length => @@ -1201,7 +1209,7 @@ fun transform (program: Program.t): Program.t = Value.VectorRep.Vector (ty, _) => if Type.isUnit ty then simple (PrimApp - {prim = Prim.Vector_vector, + {prim = Prim.Vector_vector (Type.deVectorLayout resultType), targs = Vector.new1 Type.unit, args = WordXVector.toVectorMap (ws, fn _ => unitVar)}) else simple e