diff --git a/tests/.gitignore b/tests/.gitignore
new file mode 100644
index 000000000..818ad71b8
--- /dev/null
+++ b/tests/.gitignore
@@ -0,0 +1,3 @@
+inputs/
+pbbsbench/
+bin/
diff --git a/tests/Makefile b/tests/Makefile
new file mode 100644
index 000000000..13ea052be
--- /dev/null
+++ b/tests/Makefile
@@ -0,0 +1,176 @@
+PBBS_DIR ?= pbbsbench
+INPUT_DIR := inputs
+BIN_DIR := bin
+
+# SMLC/COMPILER_NAME are defined up here (not in the binaries section) because
+# the graphio.$(COMPILER_NAME) prerequisite below is expanded as soon as that
+# rule is read; defining them later would leave the prerequisite empty.
+SMLC ?= mlton
+COMPILER_NAME := $(notdir $(SMLC))
+
+#=================inputs=================
+RAND_PTS := $(PBBS_DIR)/testData/geometryData/randPoints
+RMAT_GRAPH := $(PBBS_DIR)/testData/graphData/rMatGraph
+
+UC_INPUTS := $(INPUT_DIR)/uniform-circle-1M $(INPUT_DIR)/uniform-circle-20M
+RMAT_INPUTS := $(INPUT_DIR)/rmat-1M-symm $(INPUT_DIR)/rmat-10M-symm
+RMAT_BIN_INPUTS := $(patsubst %,%-bin,$(RMAT_INPUTS))
+TEXT_INPUTS := $(INPUT_DIR)/words-8 $(INPUT_DIR)/words-32
+BIN_INPUTS := $(INPUT_DIR)/mangore-waltz.wav $(INPUT_DIR)/pano.ppm $(INPUT_DIR)/moon-landing.wav
+ALL_INPUTS := $(UC_INPUTS) $(RMAT_INPUTS) $(RMAT_BIN_INPUTS) $(TEXT_INPUTS) $(BIN_INPUTS)
+
+.PHONY: all input test clean deepclean
+
+# NOTE(review): `all` previously depended on `tests`, which has no rule
+# anywhere in this file; the runner target defined below is `test`.
+all: input test
+
+input: $(ALL_INPUTS)
+
+$(PBBS_DIR):
+	git clone https://github.com/cmuparlay/pbbsbench $@
+	git -C $@ submodule update --init --recursive
+
+# Order-only: the clone must exist, but later mtime changes on the repo
+# directory must not force the input generators to rebuild.
+$(RAND_PTS) $(RMAT_GRAPH): | $(PBBS_DIR)
+	$(MAKE) -C $(@D) $(@F)
+
+$(BIN_INPUTS):
+	curl -L https://raw.githubusercontent.com/MPLLang/parallel-ml-bench/refs/heads/main/inputs/$(notdir $@) -o $@
+
+$(INPUT_DIR)/uniform-circle-%: $(RAND_PTS)
+	$< -s $(subst M,000000,$(*)) $@
+
+$(INPUT_DIR)/rmat-%-symm: $(RMAT_GRAPH)
+	$< -s 15210 -o -j $(subst M,000000,$(*)) $@
+
+# graphio was hard-coded as graphio.mlton, which breaks whenever SMLC is not
+# mlton. The directory prerequisite must be order-only and spelled with the
+# trailing slash used by the mkdir rule at the bottom, otherwise make has no
+# rule for it on a fresh checkout.
+$(INPUT_DIR)/rmat-%-symm-bin: $(BIN_DIR)/graphio.$(COMPILER_NAME) $(INPUT_DIR)/rmat-%-symm | $(INPUT_DIR)/
+	$< $(word 2, $^) -outfile $@
+
+# Needs the directory before curl writes into it; this target is not in
+# ALL_INPUTS, so it does not inherit the order-only prerequisite below.
+$(INPUT_DIR)/words: | $(INPUT_DIR)/
+	curl -L -o $@ https://raw.githubusercontent.com/dwyl/english-words/refs/heads/master/words_alpha.txt
+
+$(INPUT_DIR)/words-%: $(INPUT_DIR)/words
+# Evil shell magic:
+# 1. count the lines of the original words file
+# 2. multiply that by n in words-n
+# 3. shuffle with repetition that many times
+#
+# You can technically create words-n for any n, not just 8 or 32
+	shuf -n $$(( $$(wc -l < $<) * $(*) )) -o $@ --repeat $<
+
+$(ALL_INPUTS): | $(INPUT_DIR)/
+
+#=================binaries=================
+DEFAULT_FLAGS ?= -default-type int64 -default-type word64
+
+TESTS_NAME := $(sort $(basename $(notdir $(wildcard bench/*/*.mlb))))
+# NOTE(review): BINARIES previously lacked the .$(COMPILER_NAME) suffix and so
+# named files this Makefile never produces.
+BINARIES := $(patsubst %,$(BIN_DIR)/%.$(COMPILER_NAME),$(TESTS_NAME))
+
+# bin/ is order-only: its mtime changes with every build, and must not force
+# recompiles of every binary.
+$(BIN_DIR)/%.$(COMPILER_NAME): bench/%/*.mlb | $(BIN_DIR)/
+	$(SMLC) -output $@ -mlb-path-var 'COMPAT $(COMPILER_NAME)' $(DEFAULT_FLAGS) $<
+
+#=================default parameters for tests=================
+primes_N := 100000000
+dense-matmul_N := 1024
+msort_N := 20000000
+suffix-array_N := 1000000
+palindrome_N := 1000000
+nqueens_N := 13
+linefit-opt_N := 500000000
+linearrec_N := 200000000
+bignum-add-opt_N := 500000000
+integrate-opt_N := 500000000
+sparse-mxv-opt_N := 200000000
+mcss-opt_N := 500000000
+ocaml-lu-decomp_N := 1024
+ocaml-binarytrees5_N := 19
+
+dedup_W := $(INPUT_DIR)/words-32
+grep_W := $(INPUT_DIR)/words-32
+tokens_W := $(INPUT_DIR)/words-32
+msort-strings_W := $(INPUT_DIR)/words-8
+
+delaunay_C := 1M
+nearest-nbrs_C := 1M
+quickhull_C := 20M
+
+dedup_PRE := --verbose --no-output
+grep_PRE := EE
+tokens_PRE := --verbose --no-output
+bfs_PRE := --no-dir-opt
+
+NUMERICAL_TESTS := primes dense-matmul msort suffix-array palindrome \
+  nqueens linefit-opt linearrec bignum-add-opt integrate-opt \
+  sparse-mxv-opt mcss-opt ocaml-lu-decomp ocaml-binarytrees5
+WORDS_TESTS := dedup grep tokens msort-strings
+CIRC_TESTS := delaunay nearest-nbrs quickhull
+RMAT_TESTS := bfs centrality low-d-decomp max-indep-set triangle-count wc-opt
+# NOTE(review): game-of-life had a run recipe but was missing from every test
+# list, so it never got its binary dependency and was never run by `test`.
+MISC_TESTS := tinykaboom reverb seam-carve range-tree raytracer ocaml-nbody-imm game-of-life
+
+ALL_TESTS := $(NUMERICAL_TESTS) $(WORDS_TESTS) $(CIRC_TESTS) $(RMAT_TESTS) $(MISC_TESTS)
+# Run targets are commands, not files; without .PHONY a stray file named after
+# a test would silently disable it.
+.PHONY: test $(ALL_TESTS)
+$(ALL_TESTS): %: $(BIN_DIR)/%.$(COMPILER_NAME)
+
+test: $(ALL_TESTS)
+
+$(NUMERICAL_TESTS):
+ifdef N
+	$(BIN_DIR)/$@.$(COMPILER_NAME) -n $(N)
+else
+	$(BIN_DIR)/$@.$(COMPILER_NAME) -n $($@_N)
+endif
+
+.SECONDEXPANSION:
+
+ifdef WORDS
+$(WORDS_TESTS):
+	$(BIN_DIR)/$@.$(COMPILER_NAME) $($@_PRE) $(WORDS)
+else
+$(WORDS_TESTS): $$($$@_W)
+	$(BIN_DIR)/$@.$(COMPILER_NAME) $($@_PRE) $<
+endif
+
+
+$(CIRC_TESTS): $(INPUT_DIR)/uniform-circle-$$($$@_C)
+	$(BIN_DIR)/$@.$(COMPILER_NAME) -input $<
+
+$(RMAT_TESTS): $(INPUT_DIR)/rmat-10M-symm-bin
+	$(BIN_DIR)/$@.$(COMPILER_NAME) $($@_PRE) $<
+
+tinykaboom:
+	$(BIN_DIR)/$@.$(COMPILER_NAME) -width 100 -height 100 -frames 10 -fps 1
+
+reverb: $(INPUT_DIR)/mangore-waltz.wav
+	$(BIN_DIR)/$@.$(COMPILER_NAME) $(INPUT_DIR)/mangore-waltz.wav
+
+seam-carve: $(INPUT_DIR)/pano.ppm
+	$(BIN_DIR)/$@.$(COMPILER_NAME) $(INPUT_DIR)/pano.ppm -num-seams 100
+
+range-tree:
+	$(BIN_DIR)/$@.$(COMPILER_NAME) -n 1000000 -q 1000000
+
+raytracer:
+	$(BIN_DIR)/$@.$(COMPILER_NAME) -n 1000 -m 1000
+
+game-of-life:
+	$(BIN_DIR)/$@.$(COMPILER_NAME) -n_times 100 -board_size 1024
+
+ocaml-nbody-imm:
+	$(BIN_DIR)/$@.$(COMPILER_NAME) -n 500 -num_bodies 1024
+
+clean:
+	rm -rf $(INPUT_DIR)
+	rm -rf $(BIN_DIR)
+
+deepclean: clean
+	-$(MAKE) -C $(PBBS_DIR) clean 2>/dev/null || true
+	rm -rf $(PBBS_DIR)
+
+$(INPUT_DIR)/ $(BIN_DIR)/:
+	mkdir -p $@
diff --git a/tests/README.md b/tests/README.md
new file mode 100644
index 000000000..dadce24cb
--- /dev/null
+++ b/tests/README.md
@@ -0,0 +1,2 @@
+The `mpllib` directory is a copy of the parallel utilities in `github.com/mpllang/mpllib`.
+All other directories hold tests.
diff --git a/tests/bench/bfs-delayed/MkBFS.sml b/tests/bench/bfs-delayed/MkBFS.sml new file mode 100644 index 000000000..e564bafd3 --- /dev/null +++ b/tests/bench/bfs-delayed/MkBFS.sml @@ -0,0 +1,47 @@ +functor MkBFS (Seq: SEQUENCE) = +struct + + structure G = AdjacencyGraph(Int) + + fun bfs graph source = + let + val N = G.numVertices graph + val M = G.numEdges graph + + fun outEdges u = + Seq.map (fn v => (u, v)) (Seq.fromArraySeq (G.neighbors graph u)) + + val parents = ForkJoin.alloc N + val _ = ForkJoin.parfor 10000 (0, N) (fn i => + Array.update (parents, i, ~1)) + + fun isVisited v = + Array.sub (parents, v) <> ~1 + + fun visit (u, v) = + if not (isVisited v) andalso + (~1 = Concurrency.casArray (parents, v) (~1, u)) + then + SOME v + else + NONE + + fun loop frontier totalVisited = + if Seq.length frontier = 0 then + totalVisited + else + let + val allNeighbors = Seq.flatten (Seq.map outEdges frontier) + val nextFrontier = Seq.mapOption visit allNeighbors + in + loop nextFrontier (totalVisited + Seq.length nextFrontier) + end + + val _ = Array.update (parents, source, source) + val initFrontier = Seq.singleton source + val numVisited = loop initFrontier 1 + in + ArraySlice.full parents + end + +end diff --git a/tests/bench/bfs-delayed/SerialBFS.sml b/tests/bench/bfs-delayed/SerialBFS.sml new file mode 100644 index 000000000..347559a3e --- /dev/null +++ b/tests/bench/bfs-delayed/SerialBFS.sml @@ -0,0 +1,39 @@ +structure SerialBFS = +struct + + structure Seq = ArraySequence + structure G = AdjacencyGraph(Int) + + fun bfs g s = + let + fun neighbors v = G.neighbors g v + fun degree v = G.degree g v + + val n = G.numVertices g + val m = G.numEdges g + + val queue = ForkJoin.alloc (m+1) + val parents = Array.array (n, ~1) + + fun search (lo, hi) = + if lo >= hi then lo else + let + val v = Array.sub (queue, lo) + fun visit (hi', u) = + if Array.sub (parents, u) >= 0 then hi' + else ( Array.update (parents, u, v) + ; Array.update (queue, hi', u) + ; hi'+1 + ) + 
in + search (lo+1, Seq.iterate visit hi (neighbors v)) + end + + val _ = Array.update (parents, s, s) + val _ = Array.update (queue, 0, s) + val numVisited = search (0, 1) + in + ArraySlice.full parents + end + +end diff --git a/tests/bench/bfs-delayed/bfs-delayed.mlb b/tests/bench/bfs-delayed/bfs-delayed.mlb new file mode 100644 index 000000000..e71267437 --- /dev/null +++ b/tests/bench/bfs-delayed/bfs-delayed.mlb @@ -0,0 +1,4 @@ +../../mpllib/sources.$(COMPAT).mlb +SerialBFS.sml +MkBFS.sml +main.sml diff --git a/tests/bench/bfs-delayed/main.sml b/tests/bench/bfs-delayed/main.sml new file mode 100644 index 000000000..ef63bf1b8 --- /dev/null +++ b/tests/bench/bfs-delayed/main.sml @@ -0,0 +1,72 @@ +structure CLA = CommandLineArgs +structure Seq = ArraySequence +structure G = AdjacencyGraph(Int) + +(* Set by subdirectory *) +structure BFS = MkBFS(OldDelayedSeq) + +(* Generate an input + * If -infile is given, then will load file. + * Otherwise, uses -n -d to generate a random graph. *) +val filename = CLA.parseString "infile" "" +val t0 = Time.now () +val (graphspec, input) = + if filename <> "" then + (filename, G.parseFile filename) + else + let + val n = CLA.parseInt "n" 1000000 + val d = CLA.parseInt "d" 10 + in + ("random(" ^ Int.toString n ^ "," ^ Int.toString d ^ ")", + G.randSymmGraph n d) + end +val t1 = Time.now () +val _ = print ("loaded graph in " ^ Time.fmt 4 (Time.- (t1, t0)) ^ "s\n") + +val n = G.numVertices input +val source = CLA.parseInt "source" 0 +val doCheck = CLA.parseFlag "check" + +val _ = print ("graph " ^ graphspec ^ "\n") +val _ = print ("num-verts " ^ Int.toString n ^ "\n") +val _ = print ("num-edges " ^ Int.toString (G.numEdges input) ^ "\n") +val _ = print ("source " ^ Int.toString source ^ "\n") +val _ = print ("check " ^ (if doCheck then "true" else "false") ^ "\n") + +fun task () = + BFS.bfs input source + +val P = Benchmark.run "running bfs" task + +val numVisited = + SeqBasis.reduce 10000 op+ 0 (0, Seq.length P) + (fn i => if 
Seq.nth P i >= 0 then 1 else 0) +val _ = print ("visited " ^ Int.toString numVisited ^ "\n") + +fun numHops P hops v = + if hops > Seq.length P then ~2 + else if Seq.nth P v = ~1 then ~1 + else if Seq.nth P v = v then hops + else numHops P (hops+1) (Seq.nth P v) + +val maxHops = + SeqBasis.reduce 100 Int.max ~3 (0, G.numVertices input) (numHops P 0) +val _ = print ("max dist " ^ Int.toString maxHops ^ "\n") + +fun check () = + let + val (P', serialTime) = + Util.getTime (fn _ => SerialBFS.bfs input source) + + val correct = + Seq.length P = Seq.length P' + andalso + SeqBasis.reduce 10000 (fn (a, b) => a andalso b) true (0, Seq.length P) + (fn i => numHops P 0 i = numHops P' 0 i) + in + print ("serial finished in " ^ Time.fmt 4 serialTime ^ "s\n"); + print ("correct? " ^ (if correct then "yes" else "no") ^ "\n") + end + +val _ = if doCheck then check () else () diff --git a/tests/bench/bfs-det-dedup/Dedup.sml b/tests/bench/bfs-det-dedup/Dedup.sml new file mode 100644 index 000000000..2c6375ebd --- /dev/null +++ b/tests/bench/bfs-det-dedup/Dedup.sml @@ -0,0 +1,159 @@ +structure Dedup: +sig + val dedup: ('k * 'k -> bool) (* equality check *) + -> ('k -> Word64.word) (* first hash function *) + -> ('k -> Word64.word) (* second hash function *) + -> 'k Seq.t (* input (with duplicates) *) + -> 'k Seq.t (* deduplicated (not sorted!) 
*) +end = +struct + + structure A = Array + structure AS = ArraySlice + val update = Array.update + val sub = Array.sub + + fun chunkedfor chunkSize (flo, fhi) f = + let + val n = fhi - flo + val numChunks = (n-1) div chunkSize + 1 + in + Util.for (0, numChunks) (fn i => + let + val clo = flo + i*chunkSize + val chi = if i = numChunks - 1 then fhi else flo + (i+1)*chunkSize + in + Util.for (clo, chi) f + end) + end + + fun chunkedloop chunkSize (flo, fhi) init f = + let + val n = fhi - flo + val numChunks = (n-1) div chunkSize + 1 + in + Util.loop (0, numChunks) init (fn (b, i) => + let + val clo = flo + i*chunkSize + val chi = if i = numChunks - 1 then fhi else flo + (i+1)*chunkSize + val b' = Util.loop (clo, chi) b f + in + b' + end) + end + + datatype 'a bucketTree = + Leaf of 'a array + | Node of int * 'a bucketTree * 'a bucketTree + + fun count t = + case t of + Leaf a => A.length a + | Node (c, _, _) => c + + fun bucketTree n (f : int -> 'a array) = + let + fun tree (lo, hi) = + case hi - lo of + 0 => Leaf (ForkJoin.alloc 0) + | 1 => Leaf (f lo) + | n => let val mid = lo + n div 2 + val (l, r) = ForkJoin.par (fn _ => tree (lo, mid), fn _ => tree (mid, hi)) + in Node (count l + count r, l, r) + end + in + tree (0, n) + end + + fun indexApp chunkSize (f : (int * 'a) -> unit) (t : 'a bucketTree) = + let + fun app offset t = + case t of + Leaf a => chunkedfor chunkSize (0, A.length a) (fn i => f (offset+i, sub (a, i))) + | Node (_, l, r) => + (ForkJoin.par (fn _ => app offset l, fn _ => app (offset + count l) r); + ()) + in + app 0 t + end + + fun compactFilter chunkSize (s : 'a option array) count = + let + val t = ForkJoin.alloc count + val _ = chunkedloop chunkSize (0, A.length s) 0 (fn (ti, si) => + case sub (s, si) of + NONE => ti + | SOME x => (update (t, ti, x); ti+1)) + in + t + end + + fun serialHistogram eq hash s = + let + val n = AS.length s + val tn = Util.boundPow2 n + val tmask = Word64.fromInt (tn - 1) + val t = Array.array (tn, NONE) + + fun 
insert k = + let + fun probe i = + case sub (t, i) of + NONE => (update (t, i, SOME k); true) + | SOME k' => + if eq (k', k) then + false + else if i+1 = tn then + probe 0 + else + probe (i+1) + val h = Word64.toInt (Word64.andb (hash k, tmask)) + in + probe h + end + + val (sa, slo, sn) = AS.base s + val shi = slo+sn + val count = chunkedloop 1024 (slo, shi) 0 (fn (c, i) => + if insert (sub (sa, i)) + then c+1 + else c) + in + compactFilter 1024 t count + end + + + (* val dedup : ('k * 'k -> bool) equality check + -> ('k -> Word64.word) first hash function + -> ('k -> Word64.word) second hash function + -> 'k seq input (with duplicates) + -> 'k seq deduplicated (not sorted!) + *) + fun dedup eq hash hash' keys = + if AS.length keys = 0 then Seq.empty () else + let + val n = AS.length keys + val bucketBits = + if n < Util.pow2 27 + then (Util.log2 n - 7) div 2 + else Util.log2 n - 17 + val numBuckets = Util.pow2 (bucketBits + 1) + val bucketMask = Word64.fromInt (numBuckets - 1) + fun getBucket k = Word64.toInt (Word64.andb (hash k, bucketMask)) + fun ithKeyBucket i = getBucket (Seq.nth keys i) + val (bucketed, offsets) = CountingSort.sort keys ithKeyBucket numBuckets + fun offset i = Seq.nth offsets i + val tree = bucketTree numBuckets (fn i => + let + val bucketks = Seq.subseq bucketed (offset i, offset (i+1) - offset i) + in + serialHistogram eq hash' bucketks + end) + + val result = ForkJoin.alloc (count tree) + val _ = indexApp 1024 (fn (i, x) => update (result, i, x)) tree + in + AS.full result + end + +end \ No newline at end of file diff --git a/tests/bench/bfs-det-dedup/DedupBFS.sml b/tests/bench/bfs-det-dedup/DedupBFS.sml new file mode 100644 index 000000000..eb91438df --- /dev/null +++ b/tests/bench/bfs-det-dedup/DedupBFS.sml @@ -0,0 +1,147 @@ +structure DedupBFS = +struct + type 'a seq = 'a Seq.t + + (* structure DS = DelayedSeq *) + structure G = AdjacencyGraph(Int) + structure V = G.Vertex + + type vertex = G.vertex + + val sub = Array.sub + val upd = 
Array.update + + val vtoi = V.toInt + val itov = V.fromInt + + (* fun ASsub s = + let val (a, i, _) = ArraySlice.base s + in sub (a, i+s) + end *) + + val GRAIN = 10000 + + fun strip s = + let val (s', start, _) = ArraySlice.base s + in if start = 0 then s' else raise Fail "strip base <> 0" + end + + fun bfs {diropt: bool} (g : G.graph) (s : vertex) = + let + val n = G.numVertices g + val parent = strip (Seq.tabulate (fn _ => ~1) n) + + (* Choose method of filtering the frontier: either frontier always + * only consists of valid vertex ids, or it allows invalid vertices and + * pretends that these vertices are isolated. *) + fun degree v = G.degree g v + fun filterFrontier s = Seq.filter (fn x => x <> itov (~1)) s + (* + fun degree v = if v < 0 then 0 else Graph.degree g v + fun filterFrontier s = s + *) + + val denseThreshold = G.numEdges g div 20 + + fun sumOfOutDegrees frontier = + SeqBasis.reduce 10000 op+ 0 (0, Seq.length frontier) (degree o Seq.nth frontier) + (* DS.reduce op+ 0 (DS.map degree (DS.fromArraySeq frontier)) *) + + fun shouldProcessDense frontier = + diropt andalso + let + val n = Seq.length frontier + val m = sumOfOutDegrees frontier + in + n + m > denseThreshold + end + + fun bottomUp frontier = + raise Fail "DedupBFS: direction optimization not implemented yet" + + fun topDown frontier = + let + val nf = Seq.length frontier + val offsets = SeqBasis.scan GRAIN op+ 0 (0, nf) (degree o Seq.nth frontier) + val mf = sub (offsets, nf) + val outNbrs: (vertex * vertex) array = ForkJoin.alloc mf + + fun visitNeighbors offset v nghs = + Util.for (0, Seq.length nghs) (fn i => + let val u = Seq.nth nghs i + in upd (outNbrs, offset+i, (v, u)) + end) + + fun visitMany offlo lo hi = + if lo = hi then () else + let + val v = Seq.nth frontier offlo + val voffset = sub (offsets, offlo) + val k = Int.min (hi - lo, sub (offsets, offlo+1) - lo) + in + if k = 0 then visitMany (offlo+1) lo hi + else ( visitNeighbors lo v (Seq.subseq (G.neighbors g v) (lo - voffset, 
k)) + ; visitMany (offlo+1) (lo+k) hi + ) + end + + fun parVisitMany (offlo, offhi) (lo, hi) = + if hi - lo <= GRAIN then + visitMany offlo lo hi + else + let + val mid = lo + (hi - lo) div 2 + val (i, j) = OffsetSearch.search mid offsets (offlo, offhi) + val _ = ForkJoin.par + ( fn _ => parVisitMany (offlo, i) (lo, mid) + , fn _ => parVisitMany (j-1, offhi) (mid, hi) + ) + in + () + end + + val vtow = Word64.fromInt o vtoi + fun h1 w = Word64.>> (Util.hash64_2 w, 0w32) + fun h2 w = Util.hash64_2 w + + (* populates outNbrs *) + val _ = parVisitMany (0, nf + 1) (0, mf) + val outNbrs = ArraySlice.full outNbrs + val unvisited = Seq.filter (fn (_, u) => sub (parent, u) = ~1) outNbrs + val deduped = Dedup.dedup + (fn ((_, u1), (_, u2)) => u1 = u2) + (fn (_, u) => h1 (vtow u)) + (fn (_, u) => h2 (vtow u)) + unvisited + + val nextFrontier = + Seq.map (fn (v, u) => (upd (parent, vtoi u, v); u)) deduped + in + nextFrontier + end + + fun search frontier = + if Seq.length frontier = 0 then + () + else if shouldProcessDense frontier then + let + val (nextFrontier, tm) = Util.getTime (fn _ => bottomUp frontier) + in + print ("dense " ^ Time.fmt 4 tm ^ "\n"); + search nextFrontier + end + else + let + val (nextFrontier, tm) = Util.getTime (fn _ => topDown frontier) + in + print ("sparse " ^ Time.fmt 4 tm ^ "\n"); + search nextFrontier + end + + val _ = upd (parent, vtoi s, s) + val _ = search (Seq.fromList [s]) + in + ArraySlice.full parent + end + +end diff --git a/tests/bench/bfs-det-dedup/OffsetSearch.sml b/tests/bench/bfs-det-dedup/OffsetSearch.sml new file mode 100644 index 000000000..7e17febb8 --- /dev/null +++ b/tests/bench/bfs-det-dedup/OffsetSearch.sml @@ -0,0 +1,54 @@ +structure OffsetSearch :> +sig + (* `search x xs (lo, hi)` searches the sorted array `xs` between indices `lo` + * and `hi`, returning `(i, j)` where `i-lo` is the number of elements that + * are strictly less than `x`, and `j-i` is the number of elements which are + * equal to `x`. 
*) + val search : int -> int array -> int * int -> int * int +end = +struct + + val sub = Array.sub + val upd = Array.update + + fun lowSearch x xs (lo, hi) = + case hi - lo of + 0 => lo + | n => let val mid = lo + n div 2 + in case Int.compare (x, sub (xs, mid)) of + LESS => lowSearch x xs (lo, mid) + | GREATER => lowSearch x xs (mid + 1, hi) + | EQUAL => lowSearchEq x xs (lo, mid) + end + + and lowSearchEq x xs (lo, mid) = + if (mid = 0) orelse (x > sub (xs, mid-1)) + then mid + else lowSearch x xs (lo, mid) + + and highSearch x xs (lo, hi) = + case hi - lo of + 0 => lo + | n => let val mid = lo + n div 2 + in case Int.compare (x, sub (xs, mid)) of + LESS => highSearch x xs (lo, mid) + | GREATER => highSearch x xs (mid + 1, hi) + | EQUAL => highSearchEq x xs (mid, hi) + end + + and highSearchEq x xs (mid, hi) = + if (mid = Array.length xs - 1) orelse (x < sub (xs, mid + 1)) + then mid + 1 + else highSearch x xs (mid + 1, hi) + + and search (x : int) (xs : int array) (lo, hi) : int * int = + case hi - lo of + 0 => (lo, lo) + | n => let val mid = lo + n div 2 + in case Int.compare (x, sub (xs, mid)) of + LESS => search x xs (lo, mid) + | GREATER => search x xs (mid + 1, hi) + | EQUAL => (lowSearchEq x xs (lo, mid), highSearchEq x xs (mid, hi)) + end + +end diff --git a/tests/bench/bfs-det-dedup/SerialBFS.sml b/tests/bench/bfs-det-dedup/SerialBFS.sml new file mode 100644 index 000000000..60a010386 --- /dev/null +++ b/tests/bench/bfs-det-dedup/SerialBFS.sml @@ -0,0 +1,38 @@ +structure SerialBFS = +struct + + structure G = AdjacencyGraph(Int) + + fun bfs g s = + let + fun neighbors v = G.neighbors g v + fun degree v = G.degree g v + + val n = G.numVertices g + val m = G.numEdges g + + val queue = ForkJoin.alloc (m+1) + val parents = Array.array (n, ~1) + + fun search (lo, hi) = + if lo >= hi then lo else + let + val v = Array.sub (queue, lo) + fun visit (hi', u) = + if Array.sub (parents, u) >= 0 then hi' + else ( Array.update (parents, u, v) + ; Array.update (queue, 
hi', u) + ; hi'+1 + ) + in + search (lo+1, Seq.iterate visit hi (neighbors v)) + end + + val _ = Array.update (parents, s, s) + val _ = Array.update (queue, 0, s) + val numVisited = search (0, 1) + in + ArraySlice.full parents + end + +end diff --git a/tests/bench/bfs-det-dedup/bfs-det-dedup.mlb b/tests/bench/bfs-det-dedup/bfs-det-dedup.mlb new file mode 100644 index 000000000..0a8df185f --- /dev/null +++ b/tests/bench/bfs-det-dedup/bfs-det-dedup.mlb @@ -0,0 +1,6 @@ +../../mpllib/sources.$(COMPAT).mlb +Dedup.sml +SerialBFS.sml +OffsetSearch.sml +DedupBFS.sml +main.sml diff --git a/tests/bench/bfs-det-dedup/main.sml b/tests/bench/bfs-det-dedup/main.sml new file mode 100644 index 000000000..e137dd218 --- /dev/null +++ b/tests/bench/bfs-det-dedup/main.sml @@ -0,0 +1,77 @@ +structure CLA = CommandLineArgs +structure BFS = DedupBFS +structure G = BFS.G + +val dontDirOpt = CLA.parseFlag "no-dir-opt" +val diropt = not dontDirOpt + +val source = CLA.parseInt "source" 0 +val doCheck = CLA.parseFlag "check" + +(* +val N = CLA.parseInt "N" 10000000 +val D = CLA.parseInt "D" 10 + +val (graph, tm) = Util.getTime (fn _ => G.randSymmGraph N D) +val _ = print ("generated graph in " ^ Time.fmt 4 tm ^ "s\n") +val _ = print ("num vertices: " ^ Int.toString (G.numVertices graph) ^ "\n") +val _ = print ("num edges: " ^ Int.toString (G.numEdges graph) ^ "\n") +*) + +val filename = + case CLA.positional () of + [x] => x + | _ => Util.die "missing filename" + +val (graph, tm) = Util.getTime (fn _ => G.parseFile filename) +val _ = print ("num vertices: " ^ Int.toString (G.numVertices graph) ^ "\n") +val _ = print ("num edges: " ^ Int.toString (G.numEdges graph) ^ "\n") + +val _ = + if not diropt then () else + let + val (_, tm) = Util.getTime (fn _ => + if G.parityCheck graph then () + else TextIO.output (TextIO.stdErr, + "WARNING: parity check failed; graph might not be symmetric " ^ + "or might have duplicate- or self-edges\n")) + val _ = print ("parity check in " ^ Time.fmt 4 tm ^ 
"s\n") + in + () + end + +val P = Benchmark.run "running bfs" + (fn _ => BFS.bfs {diropt = diropt} graph source) + +val numVisited = + SeqBasis.reduce 10000 op+ 0 (0, Seq.length P) + (fn i => if Seq.nth P i >= 0 then 1 else 0) +val _ = print ("visited " ^ Int.toString numVisited ^ "\n") + +fun numHops P hops v = + if hops > Seq.length P then ~2 + else if Seq.nth P v = ~1 then ~1 + else if Seq.nth P v = v then hops + else numHops P (hops+1) (Seq.nth P v) + +val maxHops = + SeqBasis.reduce 100 Int.max ~3 (0, G.numVertices graph) (numHops P 0) +val _ = print ("max dist " ^ Int.toString maxHops ^ "\n") + +fun check () = + let + val (P', serialTime) = + Util.getTime (fn _ => SerialBFS.bfs graph source) + + val correct = + Seq.length P = Seq.length P' + andalso + SeqBasis.reduce 10000 (fn (a, b) => a andalso b) true (0, Seq.length P) + (fn i => numHops P 0 i = numHops P' 0 i) + in + print ("serial finished in " ^ Time.fmt 4 serialTime ^ "s\n"); + print ("correct? " ^ (if correct then "yes" else "no") ^ "\n") + end + +val _ = if doCheck then check () else () + diff --git a/tests/bench/bfs-det-priority/OffsetSearch.sml b/tests/bench/bfs-det-priority/OffsetSearch.sml new file mode 100644 index 000000000..7e17febb8 --- /dev/null +++ b/tests/bench/bfs-det-priority/OffsetSearch.sml @@ -0,0 +1,54 @@ +structure OffsetSearch :> +sig + (* `search x xs (lo, hi)` searches the sorted array `xs` between indices `lo` + * and `hi`, returning `(i, j)` where `i-lo` is the number of elements that + * are strictly less than `x`, and `j-i` is the number of elements which are + * equal to `x`. 
*) + val search : int -> int array -> int * int -> int * int +end = +struct + + val sub = Array.sub + val upd = Array.update + + fun lowSearch x xs (lo, hi) = + case hi - lo of + 0 => lo + | n => let val mid = lo + n div 2 + in case Int.compare (x, sub (xs, mid)) of + LESS => lowSearch x xs (lo, mid) + | GREATER => lowSearch x xs (mid + 1, hi) + | EQUAL => lowSearchEq x xs (lo, mid) + end + + and lowSearchEq x xs (lo, mid) = + if (mid = 0) orelse (x > sub (xs, mid-1)) + then mid + else lowSearch x xs (lo, mid) + + and highSearch x xs (lo, hi) = + case hi - lo of + 0 => lo + | n => let val mid = lo + n div 2 + in case Int.compare (x, sub (xs, mid)) of + LESS => highSearch x xs (lo, mid) + | GREATER => highSearch x xs (mid + 1, hi) + | EQUAL => highSearchEq x xs (mid, hi) + end + + and highSearchEq x xs (mid, hi) = + if (mid = Array.length xs - 1) orelse (x < sub (xs, mid + 1)) + then mid + 1 + else highSearch x xs (mid + 1, hi) + + and search (x : int) (xs : int array) (lo, hi) : int * int = + case hi - lo of + 0 => (lo, lo) + | n => let val mid = lo + n div 2 + in case Int.compare (x, sub (xs, mid)) of + LESS => search x xs (lo, mid) + | GREATER => search x xs (mid + 1, hi) + | EQUAL => (lowSearchEq x xs (lo, mid), highSearchEq x xs (mid, hi)) + end + +end diff --git a/tests/bench/bfs-det-priority/PriorityBFS.sml b/tests/bench/bfs-det-priority/PriorityBFS.sml new file mode 100644 index 000000000..6fdc79801 --- /dev/null +++ b/tests/bench/bfs-det-priority/PriorityBFS.sml @@ -0,0 +1,168 @@ +structure PriorityBFS = +struct + type 'a seq = 'a Seq.t + + (* structure DS = DelayedSeq *) + structure G = AdjacencyGraph(Int) + structure V = G.Vertex + + type vertex = G.vertex + + val sub = Array.sub + val upd = Array.update + + val vtoi = V.toInt + val itov = V.fromInt + + (* fun ASsub s = + let val (a, i, _) = ArraySlice.base s + in sub (a, i+s) + end *) + + val GRAIN = 10000 + + fun strip s = + let val (s', start, _) = ArraySlice.base s + in if start = 0 then s' else raise 
Fail "strip base <> 0" + end + + fun bfs {diropt: bool} (g : G.graph) (s : vertex) = + let + val n = G.numVertices g + val parent = strip (Seq.tabulate (fn _ => ~1) n) + val visited = strip (Seq.tabulate (fn _ => false) n) + + (* Choose method of filtering the frontier: either frontier always + * only consists of valid vertex ids, or it allows invalid vertices and + * pretends that these vertices are isolated. *) + fun degree v = G.degree g v + fun filterFrontier s = Seq.filter (fn x => x <> itov (~1)) s + (* + fun degree v = if v < 0 then 0 else Graph.degree g v + fun filterFrontier s = s + *) + + val denseThreshold = G.numEdges g div 20 + + fun sumOfOutDegrees frontier = + SeqBasis.reduce 10000 op+ 0 (0, Seq.length frontier) (degree o Seq.nth frontier) + (* DS.reduce op+ 0 (DS.map degree (DS.fromArraySeq frontier)) *) + + fun shouldProcessDense frontier = + diropt andalso + let + val n = Seq.length frontier + val m = sumOfOutDegrees frontier + in + n + m > denseThreshold + end + + fun bottomUp frontier = + raise Fail "PriorityBFS: bottom up not implemented yet" + + fun topDown frontier = + let + val nf = Seq.length frontier + val offsets = SeqBasis.scan GRAIN op+ 0 (0, nf) (degree o Seq.nth frontier) + val mf = sub (offsets, nf) + val outNbrs = ForkJoin.alloc mf + + (* Priority update, attempt set v as parent of u. Returns true only + * if it is the first visit, to ensure that u appears at most once + * in next frontier. 
*) + fun tryVisit (u, v) = + if sub (visited, u) then false else + let + val old = sub (parent, u) + val isFirstVisit = (old = ~1) + in + if v <= old then + false + else if old = Concurrency.casArray (parent, u) (old, v) then + isFirstVisit + else + tryVisit (u, v) + end + + fun visitNeighbors offset v nghs = + Util.for (0, Seq.length nghs) (fn i => + let val u = Seq.nth nghs i + in if not (tryVisit (vtoi u, vtoi v)) + then upd (outNbrs, offset + i, itov (~1)) + else upd (outNbrs, offset + i, u) + end) + + fun visitMany offlo lo hi = + if lo = hi then () else + let + val v = Seq.nth frontier offlo + val voffset = sub (offsets, offlo) + val k = Int.min (hi - lo, sub (offsets, offlo+1) - lo) + in + if k = 0 then visitMany (offlo+1) lo hi + else ( visitNeighbors lo v (Seq.subseq (G.neighbors g v) (lo - voffset, k)) + ; visitMany (offlo+1) (lo+k) hi + ) + end + + fun parVisitMany (offlo, offhi) (lo, hi) = + if hi - lo <= GRAIN then + visitMany offlo lo hi + else + let + val mid = lo + (hi - lo) div 2 + val (i, j) = OffsetSearch.search mid offsets (offlo, offhi) + val _ = ForkJoin.par + ( fn _ => parVisitMany (offlo, i) (lo, mid) + , fn _ => parVisitMany (j-1, offhi) (mid, hi) + ) + in + () + end + + (* Either one of the following is correct, but the second one has + * significantly better granularity control for graphs that have a + * small number of vertices with huge degree. 
*) + + (* val _ = ParUtil.parfor 100 (0, nf) (fn i => + visitMany i (sub (offsets, i)) (sub (offsets, i+1))) *) + + val _ = parVisitMany (0, nf + 1) (0, mf) + val nextFrontier = filterFrontier (ArraySlice.full outNbrs) + in + ForkJoin.parfor 5000 (0, Seq.length nextFrontier) (fn i => + let + val v = Seq.nth nextFrontier i + in + upd (visited, v, true) + end); + + nextFrontier + end + + fun search frontier = + if Seq.length frontier = 0 then + () + else if shouldProcessDense frontier then + let + val (nextFrontier, tm) = Util.getTime (fn _ => bottomUp frontier) + in + print ("dense " ^ Time.fmt 4 tm ^ "\n"); + search nextFrontier + end + else + let + val (nextFrontier, tm) = Util.getTime (fn _ => topDown frontier) + in + print ("sparse " ^ Time.fmt 4 tm ^ "\n"); + search nextFrontier + end + + val _ = upd (parent, vtoi s, s) + val _ = upd (visited, vtoi s, true) + val _ = search (Seq.fromList [s]) + in + ArraySlice.full parent + end + +end diff --git a/tests/bench/bfs-det-priority/SerialBFS.sml b/tests/bench/bfs-det-priority/SerialBFS.sml new file mode 100644 index 000000000..60a010386 --- /dev/null +++ b/tests/bench/bfs-det-priority/SerialBFS.sml @@ -0,0 +1,38 @@ +structure SerialBFS = +struct + + structure G = AdjacencyGraph(Int) + + fun bfs g s = + let + fun neighbors v = G.neighbors g v + fun degree v = G.degree g v + + val n = G.numVertices g + val m = G.numEdges g + + val queue = ForkJoin.alloc (m+1) + val parents = Array.array (n, ~1) + + fun search (lo, hi) = + if lo >= hi then lo else + let + val v = Array.sub (queue, lo) + fun visit (hi', u) = + if Array.sub (parents, u) >= 0 then hi' + else ( Array.update (parents, u, v) + ; Array.update (queue, hi', u) + ; hi'+1 + ) + in + search (lo+1, Seq.iterate visit hi (neighbors v)) + end + + val _ = Array.update (parents, s, s) + val _ = Array.update (queue, 0, s) + val numVisited = search (0, 1) + in + ArraySlice.full parents + end + +end diff --git a/tests/bench/bfs-det-priority/bfs-det-priority.mlb 
b/tests/bench/bfs-det-priority/bfs-det-priority.mlb new file mode 100644 index 000000000..c7b968587 --- /dev/null +++ b/tests/bench/bfs-det-priority/bfs-det-priority.mlb @@ -0,0 +1,5 @@ +../../mpllib/sources.$(COMPAT).mlb +SerialBFS.sml +OffsetSearch.sml +PriorityBFS.sml +main.sml diff --git a/tests/bench/bfs-det-priority/main.sml b/tests/bench/bfs-det-priority/main.sml new file mode 100644 index 000000000..8f59fb7fa --- /dev/null +++ b/tests/bench/bfs-det-priority/main.sml @@ -0,0 +1,77 @@ +structure CLA = CommandLineArgs +structure BFS = PriorityBFS +structure G = BFS.G + +val dontDirOpt = CLA.parseFlag "no-dir-opt" +val diropt = not dontDirOpt + +val source = CLA.parseInt "source" 0 +val doCheck = CLA.parseFlag "check" + +(* +val N = CLA.parseInt "N" 10000000 +val D = CLA.parseInt "D" 10 + +val (graph, tm) = Util.getTime (fn _ => G.randSymmGraph N D) +val _ = print ("generated graph in " ^ Time.fmt 4 tm ^ "s\n") +val _ = print ("num vertices: " ^ Int.toString (G.numVertices graph) ^ "\n") +val _ = print ("num edges: " ^ Int.toString (G.numEdges graph) ^ "\n") +*) + +val filename = + case CLA.positional () of + [x] => x + | _ => Util.die "missing filename" + +val (graph, tm) = Util.getTime (fn _ => G.parseFile filename) +val _ = print ("num vertices: " ^ Int.toString (G.numVertices graph) ^ "\n") +val _ = print ("num edges: " ^ Int.toString (G.numEdges graph) ^ "\n") + +val _ = + if not diropt then () else + let + val (_, tm) = Util.getTime (fn _ => + if G.parityCheck graph then () + else TextIO.output (TextIO.stdErr, + "WARNING: parity check failed; graph might not be symmetric " ^ + "or might have duplicate- or self-edges\n")) + val _ = print ("parity check in " ^ Time.fmt 4 tm ^ "s\n") + in + () + end + +val P = Benchmark.run "running bfs" + (fn _ => BFS.bfs {diropt = diropt} graph source) + +val numVisited = + SeqBasis.reduce 10000 op+ 0 (0, Seq.length P) + (fn i => if Seq.nth P i >= 0 then 1 else 0) +val _ = print ("visited " ^ Int.toString numVisited ^ 
"\n") + +fun numHops P hops v = + if hops > Seq.length P then ~2 + else if Seq.nth P v = ~1 then ~1 + else if Seq.nth P v = v then hops + else numHops P (hops+1) (Seq.nth P v) + +val maxHops = + SeqBasis.reduce 100 Int.max ~3 (0, G.numVertices graph) (numHops P 0) +val _ = print ("max dist " ^ Int.toString maxHops ^ "\n") + +fun check () = + let + val (P', serialTime) = + Util.getTime (fn _ => SerialBFS.bfs graph source) + + val correct = + Seq.length P = Seq.length P' + andalso + SeqBasis.reduce 10000 (fn (a, b) => a andalso b) true (0, Seq.length P) + (fn i => numHops P 0 i = numHops P' 0 i) + in + print ("serial finished in " ^ Time.fmt 4 serialTime ^ "s\n"); + print ("correct? " ^ (if correct then "yes" else "no") ^ "\n") + end + +val _ = if doCheck then check () else () + diff --git a/tests/bench/bfs-tree-entangled-fixed/NondetBFS.sml b/tests/bench/bfs-tree-entangled-fixed/NondetBFS.sml new file mode 100644 index 000000000..2c2effc2b --- /dev/null +++ b/tests/bench/bfs-tree-entangled-fixed/NondetBFS.sml @@ -0,0 +1,206 @@ +(* nondeterministic direction-optimized BFS, using CAS on outneighbors to + * construct next frontier. 
*) +structure NondetBFS = +struct + type 'a seq = 'a Seq.t + + (* structure DS = DelayedSeq *) + structure G = AdjacencyGraph(Int) + structure V = G.Vertex + + type vertex = G.vertex + + val sub = Array.sub + val upd = Array.update + + val vtoi = V.toInt + val itov = V.fromInt + + (* fun ASsub s = + let val (a, i, _) = ArraySlice.base s + in sub (a, i+s) + end *) + + val GRAIN = 10000 + + fun strip s = + let val (s', start, _) = ArraySlice.base s + in if start = 0 then s' else raise Fail "strip base <> 0" + end + + fun tryUpdateSome (xs: 'a option array, i: int, old: 'a option, new: 'a option) = + let + val result = Concurrency.casArray (xs, i) (old, new) + in + if MLton.eq (old, result) then + true + else if Option.isSome result then + false + else + tryUpdateSome (xs, i, result, new) + end + + fun bfs (g : G.graph) (s : vertex) = + let + val n = G.numVertices g + val isVisited = strip (Seq.tabulate (fn _ => 0w0: Word8.word) n) + val parent = strip (Seq.tabulate (fn _ => NONE) n) + + (* Choose method of filtering the frontier: either frontier always + * only consists of valid vertex ids, or it allows invalid vertices and + * pretends that these vertices are isolated. 
*) + fun degree v = G.degree g v + fun filterFrontier s = Seq.filter (fn x => x <> itov (~1)) s + (* + fun degree v = if v < 0 then 0 else Graph.degree g v + fun filterFrontier s = s + *) + + val denseThreshold = G.numEdges g div 20 + + fun sumOfOutDegrees frontier = + SeqBasis.reduce 10000 op+ 0 (0, Seq.length frontier) (degree o Seq.nth frontier) + (* DS.reduce op+ 0 (DS.map degree (DS.fromArraySeq frontier)) *) + + fun shouldProcessDense frontier = false + (* let + val n = Seq.length frontier + val m = sumOfOutDegrees frontier + in + n + m > denseThreshold + end *) + + fun bottomUp frontier = + let + val flags = Seq.tabulate (fn _ => false) n + val _ = Seq.foreach frontier (fn (_, v) => + ArraySlice.update (flags, v, true)) + fun inFrontier v = Seq.nth flags (vtoi v) + + fun processVertex v = + case sub (parent, v) of + SOME _ => NONE + | NONE => + let + val nbrs = G.neighbors g (itov v) + val deg = ArraySlice.length nbrs + fun loop i = + if i >= deg then + NONE + else + let + val u = Seq.nth nbrs i + in + if inFrontier u then + let + val parentList = Option.valOf (sub (parent, u)) + in + upd (isVisited, v, 0w1); + upd (parent, v, SOME (u :: parentList)); + SOME v + end + else + loop (i+1) + end + in + loop 0 + end + in + ArraySlice.full (SeqBasis.tabFilter 1000 (0, n) processVertex) + end + + fun topDown frontier = + let + val nf = Seq.length frontier + val offsets = SeqBasis.scan GRAIN op+ 0 (0, nf) (degree o Seq.nth frontier) + val mf = sub (offsets, nf) + val outNbrs = ForkJoin.alloc mf + + fun claim u = + sub (isVisited, u) = 0w0 + andalso + 0w0 = Concurrency.casArray (isVisited, u) (0w0, 0w1) + + fun visitNeighbors offset v nghs = + Util.for (0, Seq.length nghs) (fn i => + let + val u = Seq.nth nghs i + in + if not (claim (vtoi u)) then + upd (outNbrs, offset + i, itov (~1)) + else + let + val parentList = Option.valOf (sub (parent, vtoi v)) + val parentList' = SOME (v :: parentList) + in + upd (parent, vtoi u, parentList'); + upd (outNbrs, offset + i, u) 
+ end + end) + + fun visitMany offlo lo hi = + if lo = hi then () else + let + val v = Seq.nth frontier offlo + val voffset = sub (offsets, offlo) + val k = Int.min (hi - lo, sub (offsets, offlo+1) - lo) + in + if k = 0 then visitMany (offlo+1) lo hi + else ( visitNeighbors lo v (Seq.subseq (G.neighbors g v) (lo - voffset, k)) + ; visitMany (offlo+1) (lo+k) hi + ) + end + + fun parVisitMany (offlo, offhi) (lo, hi) = + if hi - lo <= GRAIN then + visitMany offlo lo hi + else + let + val mid = lo + (hi - lo) div 2 + val (i, j) = OffsetSearch.search mid offsets (offlo, offhi) + val _ = ForkJoin.par + ( fn _ => parVisitMany (offlo, i) (lo, mid) + , fn _ => parVisitMany (j-1, offhi) (mid, hi) + ) + in + () + end + + (* Either one of the following is correct, but the second one has + * significantly better granularity control for graphs that have a + * small number of vertices with huge degree. *) + + (* val _ = ParUtil.parfor 100 (0, nf) (fn i => + visitMany i (sub (offsets, i)) (sub (offsets, i+1))) *) + + val _ = parVisitMany (0, nf + 1) (0, mf) + in + filterFrontier (ArraySlice.full outNbrs) + end + + fun search frontier = + if Seq.length frontier = 0 then + () + else if shouldProcessDense frontier then + let + val (nextFrontier, tm) = Util.getTime (fn _ => bottomUp frontier) + in + print ("dense " ^ Time.fmt 4 tm ^ "\n"); + search nextFrontier + end + else + let + val (nextFrontier, tm) = Util.getTime (fn _ => topDown frontier) + in + print ("sparse " ^ Time.fmt 4 tm ^ "\n"); + search nextFrontier + end + + val _ = upd (parent, vtoi s, SOME []) + val _ = upd (isVisited, vtoi s, 0w1) + val _ = search (Seq.fromList [s]) + in + ArraySlice.full parent + end + +end diff --git a/tests/bench/bfs-tree-entangled-fixed/OffsetSearch.sml b/tests/bench/bfs-tree-entangled-fixed/OffsetSearch.sml new file mode 100644 index 000000000..7e17febb8 --- /dev/null +++ b/tests/bench/bfs-tree-entangled-fixed/OffsetSearch.sml @@ -0,0 +1,54 @@ +structure OffsetSearch :> +sig + (* `search x xs 
(lo, hi)` searches the sorted array `xs` between indices `lo` + * and `hi`, returning `(i, j)` where `i-lo` is the number of elements that + * are strictly less than `x`, and `j-i` is the number of elements which are + * equal to `x`. *) + val search : int -> int array -> int * int -> int * int +end = +struct + + val sub = Array.sub + val upd = Array.update + + fun lowSearch x xs (lo, hi) = + case hi - lo of + 0 => lo + | n => let val mid = lo + n div 2 + in case Int.compare (x, sub (xs, mid)) of + LESS => lowSearch x xs (lo, mid) + | GREATER => lowSearch x xs (mid + 1, hi) + | EQUAL => lowSearchEq x xs (lo, mid) + end + + and lowSearchEq x xs (lo, mid) = + if (mid = 0) orelse (x > sub (xs, mid-1)) + then mid + else lowSearch x xs (lo, mid) + + and highSearch x xs (lo, hi) = + case hi - lo of + 0 => lo + | n => let val mid = lo + n div 2 + in case Int.compare (x, sub (xs, mid)) of + LESS => highSearch x xs (lo, mid) + | GREATER => highSearch x xs (mid + 1, hi) + | EQUAL => highSearchEq x xs (mid, hi) + end + + and highSearchEq x xs (mid, hi) = + if (mid = Array.length xs - 1) orelse (x < sub (xs, mid + 1)) + then mid + 1 + else highSearch x xs (mid + 1, hi) + + and search (x : int) (xs : int array) (lo, hi) : int * int = + case hi - lo of + 0 => (lo, lo) + | n => let val mid = lo + n div 2 + in case Int.compare (x, sub (xs, mid)) of + LESS => search x xs (lo, mid) + | GREATER => search x xs (mid + 1, hi) + | EQUAL => (lowSearchEq x xs (lo, mid), highSearchEq x xs (mid, hi)) + end + +end diff --git a/tests/bench/bfs-tree-entangled-fixed/SerialBFS.sml b/tests/bench/bfs-tree-entangled-fixed/SerialBFS.sml new file mode 100644 index 000000000..fb06b8ba9 --- /dev/null +++ b/tests/bench/bfs-tree-entangled-fixed/SerialBFS.sml @@ -0,0 +1,42 @@ +structure SerialBFS = +struct + + structure G = AdjacencyGraph(Int) + + fun bfs g s = + let + fun neighbors v = G.neighbors g v + fun degree v = G.degree g v + + val n = G.numVertices g + val m = G.numEdges g + + val queue = 
ForkJoin.alloc (m+1) + val parents = Array.array (n, NONE) + + fun search (lo, hi) = + if lo >= hi then lo else + let + val v = Array.sub (queue, lo) + val parentList = Option.valOf (Array.sub (parents, v)) + + fun visit (hi', u) = + case Array.sub (parents, u) of + SOME _ => hi' + | NONE => + ( Array.update (parents, u, SOME (v :: parentList)) + ; Array.update (queue, hi', u) + ; hi'+1 + ) + in + search (lo+1, Seq.iterate visit hi (neighbors v)) + end + + val _ = Array.update (parents, s, SOME []) + val _ = Array.update (queue, 0, s) + val numVisited = search (0, 1) + in + ArraySlice.full parents + end + +end diff --git a/tests/bench/bfs-tree-entangled-fixed/bfs-tree-entangled-fixed.mlb b/tests/bench/bfs-tree-entangled-fixed/bfs-tree-entangled-fixed.mlb new file mode 100644 index 000000000..37e86d899 --- /dev/null +++ b/tests/bench/bfs-tree-entangled-fixed/bfs-tree-entangled-fixed.mlb @@ -0,0 +1,5 @@ +../../mpllib/sources.$(COMPAT).mlb +SerialBFS.sml +OffsetSearch.sml +NondetBFS.sml +main.sml diff --git a/tests/bench/bfs-tree-entangled-fixed/main.sml b/tests/bench/bfs-tree-entangled-fixed/main.sml new file mode 100644 index 000000000..5e19b9261 --- /dev/null +++ b/tests/bench/bfs-tree-entangled-fixed/main.sml @@ -0,0 +1,66 @@ +structure CLA = CommandLineArgs +structure BFS = NondetBFS +structure G = BFS.G + +val source = CLA.parseInt "source" 0 +val doCheck = CLA.parseFlag "check" + +(* +val N = CLA.parseInt "N" 10000000 +val D = CLA.parseInt "D" 10 + +val (graph, tm) = Util.getTime (fn _ => G.randSymmGraph N D) +val _ = print ("generated graph in " ^ Time.fmt 4 tm ^ "s\n") +val _ = print ("num vertices: " ^ Int.toString (G.numVertices graph) ^ "\n") +val _ = print ("num edges: " ^ Int.toString (G.numEdges graph) ^ "\n") +*) + +val filename = + case CLA.positional () of + [x] => x + | _ => Util.die "missing filename" + +val (graph, tm) = Util.getTime (fn _ => G.parseFile filename) +val _ = print ("num vertices: " ^ Int.toString (G.numVertices graph) ^ "\n") +val _ 
= print ("num edges: " ^ Int.toString (G.numEdges graph) ^ "\n") + +(* val (_, tm) = Util.getTime (fn _ => + if G.parityCheck graph then () + else TextIO.output (TextIO.stdErr, + "WARNING: parity check failed; graph might not be symmetric " ^ + "or might have duplicate- or self-edges\n")) +val _ = print ("parity check in " ^ Time.fmt 4 tm ^ "s\n") *) + +val P = Benchmark.run "running bfs" (fn _ => BFS.bfs graph source) + +val numVisited = + SeqBasis.reduce 10000 op+ 0 (0, Seq.length P) + (fn i => if Option.isSome (Seq.nth P i) then 1 else 0) +val _ = print ("visited " ^ Int.toString numVisited ^ "\n") + +fun numHops P v = + case Seq.nth P v of + NONE => ~1 + | SOME parents => List.length parents + +val maxHops = + SeqBasis.reduce 100 Int.max ~3 (0, G.numVertices graph) (numHops P) +val _ = print ("max dist " ^ Int.toString maxHops ^ "\n") + +fun check () = + let + val (P', serialTime) = + Util.getTime (fn _ => SerialBFS.bfs graph source) + + val correct = + Seq.length P = Seq.length P' + andalso + SeqBasis.reduce 10000 (fn (a, b) => a andalso b) true (0, Seq.length P) + (fn i => numHops P i = numHops P' i) + in + print ("serial finished in " ^ Time.fmt 4 serialTime ^ "s\n"); + print ("correct? " ^ (if correct then "yes" else "no") ^ "\n") + end + +val _ = if doCheck then check () else () + diff --git a/tests/bench/bfs-tree-entangled/NondetBFS.sml b/tests/bench/bfs-tree-entangled/NondetBFS.sml new file mode 100644 index 000000000..993ab05b7 --- /dev/null +++ b/tests/bench/bfs-tree-entangled/NondetBFS.sml @@ -0,0 +1,212 @@ +(* nondeterministic direction-optimized BFS, using CAS on outneighbors to + * construct next frontier. 
*) +structure NondetBFS = +struct + type 'a seq = 'a Seq.t + + (* structure DS = DelayedSeq *) + structure G = AdjacencyGraph(Int) + structure V = G.Vertex + + type vertex = G.vertex + + val sub = Array.sub + val upd = Array.update + + val vtoi = V.toInt + val itov = V.fromInt + + (* fun ASsub s = + let val (a, i, _) = ArraySlice.base s + in sub (a, i+s) + end *) + + val GRAIN = 10000 + + fun strip s = + let val (s', start, _) = ArraySlice.base s + in if start = 0 then s' else raise Fail "strip base <> 0" + end + + fun tryUpdateSome (xs: 'a option array, i: int, old: 'a option, new: 'a option) = + let + val result = Concurrency.casArray (xs, i) (old, new) + in + if MLton.eq (old, result) then + true + else if Option.isSome result then + false + else + tryUpdateSome (xs, i, result, new) + end + + fun bfs (g : G.graph) (s : vertex) = + let + val n = G.numVertices g + val parent = strip (Seq.tabulate (fn _ => NONE) n) + + (* Choose method of filtering the frontier: either frontier always + * only consists of valid vertex ids, or it allows invalid vertices and + * pretends that these vertices are isolated. 
*) + fun degree v = G.degree g v + fun filterFrontier s = Seq.filter (fn x => x <> itov (~1)) s + (* + fun degree v = if v < 0 then 0 else Graph.degree g v + fun filterFrontier s = s + *) + + val denseThreshold = G.numEdges g div 20 + + fun sumOfOutDegrees frontier = + SeqBasis.reduce 10000 op+ 0 (0, Seq.length frontier) (degree o Seq.nth frontier) + (* DS.reduce op+ 0 (DS.map degree (DS.fromArraySeq frontier)) *) + + fun shouldProcessDense frontier = false + (* let + val n = Seq.length frontier + val m = sumOfOutDegrees frontier + in + n + m > denseThreshold + end *) + + fun bottomUp frontier = + let + val flags = Seq.tabulate (fn _ => false) n + val _ = Seq.foreach frontier (fn (_, v) => + ArraySlice.update (flags, v, true)) + fun inFrontier v = Seq.nth flags (vtoi v) + + fun processVertex v = + case sub (parent, v) of + SOME _ => NONE + | NONE => + let + val nbrs = G.neighbors g (itov v) + val deg = ArraySlice.length nbrs + fun loop i = + if i >= deg then + NONE + else + let + val u = Seq.nth nbrs i + in + if inFrontier u then + let + val parentList = Option.valOf (sub (parent, u)) + in + upd (parent, v, SOME (u :: parentList)); + SOME v + end + else + loop (i+1) + end + in + loop 0 + end + in + ArraySlice.full (SeqBasis.tabFilter 1000 (0, n) processVertex) + end + + fun topDown frontier = + let + val nf = Seq.length frontier + val offsets = SeqBasis.scan GRAIN op+ 0 (0, nf) (degree o Seq.nth frontier) + val mf = sub (offsets, nf) + val outNbrs = ForkJoin.alloc mf + + (* attempt to claim parent of u as v *) + (* fun claim (u, v) = + sub (parent, u) = ~1 + andalso + ~1 = Concurrency.casArray (parent, u) (~1, v) *) + + fun visitNeighbors offset v nghs = + Util.for (0, Seq.length nghs) (fn i => + let + val u = Seq.nth nghs i + in + case sub (parent, vtoi u) of + SOME _ => upd (outNbrs, offset + i, itov (~1)) + | old as NONE => + let + val parentList = Option.valOf (sub (parent, vtoi v)) + val parentList' = SOME (v :: parentList) + in + if tryUpdateSome (parent, 
vtoi u, old, parentList') then + upd (outNbrs, offset + i, u) + else + upd (outNbrs, offset + i, itov (~1)) + end + end) + + (* let val u = Seq.nth nghs i + in if not (claim (vtoi u, vtoi v)) + then upd (outNbrs, offset + i, itov (~1)) + else upd (outNbrs, offset + i, u) + end) *) + + fun visitMany offlo lo hi = + if lo = hi then () else + let + val v = Seq.nth frontier offlo + val voffset = sub (offsets, offlo) + val k = Int.min (hi - lo, sub (offsets, offlo+1) - lo) + in + if k = 0 then visitMany (offlo+1) lo hi + else ( visitNeighbors lo v (Seq.subseq (G.neighbors g v) (lo - voffset, k)) + ; visitMany (offlo+1) (lo+k) hi + ) + end + + fun parVisitMany (offlo, offhi) (lo, hi) = + if hi - lo <= GRAIN then + visitMany offlo lo hi + else + let + val mid = lo + (hi - lo) div 2 + val (i, j) = OffsetSearch.search mid offsets (offlo, offhi) + val _ = ForkJoin.par + ( fn _ => parVisitMany (offlo, i) (lo, mid) + , fn _ => parVisitMany (j-1, offhi) (mid, hi) + ) + in + () + end + + (* Either one of the following is correct, but the second one has + * significantly better granularity control for graphs that have a + * small number of vertices with huge degree. 
*) + + (* val _ = ParUtil.parfor 100 (0, nf) (fn i => + visitMany i (sub (offsets, i)) (sub (offsets, i+1))) *) + + val _ = parVisitMany (0, nf + 1) (0, mf) + in + filterFrontier (ArraySlice.full outNbrs) + end + + fun search frontier = + if Seq.length frontier = 0 then + () + else if shouldProcessDense frontier then + let + val (nextFrontier, tm) = Util.getTime (fn _ => bottomUp frontier) + in + print ("dense " ^ Time.fmt 4 tm ^ "\n"); + search nextFrontier + end + else + let + val (nextFrontier, tm) = Util.getTime (fn _ => topDown frontier) + in + print ("sparse " ^ Time.fmt 4 tm ^ "\n"); + search nextFrontier + end + + val _ = upd (parent, vtoi s, SOME []) + val _ = search (Seq.fromList [s]) + in + ArraySlice.full parent + end + +end diff --git a/tests/bench/bfs-tree-entangled/OffsetSearch.sml b/tests/bench/bfs-tree-entangled/OffsetSearch.sml new file mode 100644 index 000000000..7e17febb8 --- /dev/null +++ b/tests/bench/bfs-tree-entangled/OffsetSearch.sml @@ -0,0 +1,54 @@ +structure OffsetSearch :> +sig + (* `search x xs (lo, hi)` searches the sorted array `xs` between indices `lo` + * and `hi`, returning `(i, j)` where `i-lo` is the number of elements that + * are strictly less than `x`, and `j-i` is the number of elements which are + * equal to `x`. 
*) + val search : int -> int array -> int * int -> int * int +end = +struct + + val sub = Array.sub + val upd = Array.update + + fun lowSearch x xs (lo, hi) = + case hi - lo of + 0 => lo + | n => let val mid = lo + n div 2 + in case Int.compare (x, sub (xs, mid)) of + LESS => lowSearch x xs (lo, mid) + | GREATER => lowSearch x xs (mid + 1, hi) + | EQUAL => lowSearchEq x xs (lo, mid) + end + + and lowSearchEq x xs (lo, mid) = + if (mid = 0) orelse (x > sub (xs, mid-1)) + then mid + else lowSearch x xs (lo, mid) + + and highSearch x xs (lo, hi) = + case hi - lo of + 0 => lo + | n => let val mid = lo + n div 2 + in case Int.compare (x, sub (xs, mid)) of + LESS => highSearch x xs (lo, mid) + | GREATER => highSearch x xs (mid + 1, hi) + | EQUAL => highSearchEq x xs (mid, hi) + end + + and highSearchEq x xs (mid, hi) = + if (mid = Array.length xs - 1) orelse (x < sub (xs, mid + 1)) + then mid + 1 + else highSearch x xs (mid + 1, hi) + + and search (x : int) (xs : int array) (lo, hi) : int * int = + case hi - lo of + 0 => (lo, lo) + | n => let val mid = lo + n div 2 + in case Int.compare (x, sub (xs, mid)) of + LESS => search x xs (lo, mid) + | GREATER => search x xs (mid + 1, hi) + | EQUAL => (lowSearchEq x xs (lo, mid), highSearchEq x xs (mid, hi)) + end + +end diff --git a/tests/bench/bfs-tree-entangled/SerialBFS.sml b/tests/bench/bfs-tree-entangled/SerialBFS.sml new file mode 100644 index 000000000..fb06b8ba9 --- /dev/null +++ b/tests/bench/bfs-tree-entangled/SerialBFS.sml @@ -0,0 +1,42 @@ +structure SerialBFS = +struct + + structure G = AdjacencyGraph(Int) + + fun bfs g s = + let + fun neighbors v = G.neighbors g v + fun degree v = G.degree g v + + val n = G.numVertices g + val m = G.numEdges g + + val queue = ForkJoin.alloc (m+1) + val parents = Array.array (n, NONE) + + fun search (lo, hi) = + if lo >= hi then lo else + let + val v = Array.sub (queue, lo) + val parentList = Option.valOf (Array.sub (parents, v)) + + fun visit (hi', u) = + case Array.sub (parents, u) 
of + SOME _ => hi' + | NONE => + ( Array.update (parents, u, SOME (v :: parentList)) + ; Array.update (queue, hi', u) + ; hi'+1 + ) + in + search (lo+1, Seq.iterate visit hi (neighbors v)) + end + + val _ = Array.update (parents, s, SOME []) + val _ = Array.update (queue, 0, s) + val numVisited = search (0, 1) + in + ArraySlice.full parents + end + +end diff --git a/tests/bench/bfs-tree-entangled/bfs-tree-entangled.mlb b/tests/bench/bfs-tree-entangled/bfs-tree-entangled.mlb new file mode 100644 index 000000000..37e86d899 --- /dev/null +++ b/tests/bench/bfs-tree-entangled/bfs-tree-entangled.mlb @@ -0,0 +1,5 @@ +../../mpllib/sources.$(COMPAT).mlb +SerialBFS.sml +OffsetSearch.sml +NondetBFS.sml +main.sml diff --git a/tests/bench/bfs-tree-entangled/main.sml b/tests/bench/bfs-tree-entangled/main.sml new file mode 100644 index 000000000..5e19b9261 --- /dev/null +++ b/tests/bench/bfs-tree-entangled/main.sml @@ -0,0 +1,66 @@ +structure CLA = CommandLineArgs +structure BFS = NondetBFS +structure G = BFS.G + +val source = CLA.parseInt "source" 0 +val doCheck = CLA.parseFlag "check" + +(* +val N = CLA.parseInt "N" 10000000 +val D = CLA.parseInt "D" 10 + +val (graph, tm) = Util.getTime (fn _ => G.randSymmGraph N D) +val _ = print ("generated graph in " ^ Time.fmt 4 tm ^ "s\n") +val _ = print ("num vertices: " ^ Int.toString (G.numVertices graph) ^ "\n") +val _ = print ("num edges: " ^ Int.toString (G.numEdges graph) ^ "\n") +*) + +val filename = + case CLA.positional () of + [x] => x + | _ => Util.die "missing filename" + +val (graph, tm) = Util.getTime (fn _ => G.parseFile filename) +val _ = print ("num vertices: " ^ Int.toString (G.numVertices graph) ^ "\n") +val _ = print ("num edges: " ^ Int.toString (G.numEdges graph) ^ "\n") + +(* val (_, tm) = Util.getTime (fn _ => + if G.parityCheck graph then () + else TextIO.output (TextIO.stdErr, + "WARNING: parity check failed; graph might not be symmetric " ^ + "or might have duplicate- or self-edges\n")) +val _ = print ("parity 
check in " ^ Time.fmt 4 tm ^ "s\n") *) + +val P = Benchmark.run "running bfs" (fn _ => BFS.bfs graph source) + +val numVisited = + SeqBasis.reduce 10000 op+ 0 (0, Seq.length P) + (fn i => if Option.isSome (Seq.nth P i) then 1 else 0) +val _ = print ("visited " ^ Int.toString numVisited ^ "\n") + +fun numHops P v = + case Seq.nth P v of + NONE => ~1 + | SOME parents => List.length parents + +val maxHops = + SeqBasis.reduce 100 Int.max ~3 (0, G.numVertices graph) (numHops P) +val _ = print ("max dist " ^ Int.toString maxHops ^ "\n") + +fun check () = + let + val (P', serialTime) = + Util.getTime (fn _ => SerialBFS.bfs graph source) + + val correct = + Seq.length P = Seq.length P' + andalso + SeqBasis.reduce 10000 (fn (a, b) => a andalso b) true (0, Seq.length P) + (fn i => numHops P i = numHops P' i) + in + print ("serial finished in " ^ Time.fmt 4 serialTime ^ "s\n"); + print ("correct? " ^ (if correct then "yes" else "no") ^ "\n") + end + +val _ = if doCheck then check () else () + diff --git a/tests/bench/bfs/NondetBFS.sml b/tests/bench/bfs/NondetBFS.sml new file mode 100644 index 000000000..0522dd85b --- /dev/null +++ b/tests/bench/bfs/NondetBFS.sml @@ -0,0 +1,177 @@ +(* nondeterministic direction-optimized BFS, using CAS on outneighbors to + * construct next frontier. 
*) +structure NondetBFS = +struct + type 'a seq = 'a Seq.t + + (* structure DS = DelayedSeq *) + structure G = AdjacencyGraph(Int) + structure V = G.Vertex + + type vertex = G.vertex + + val sub = Array.sub + val upd = Array.update + + val vtoi = V.toInt + val itov = V.fromInt + + (* fun ASsub s = + let val (a, i, _) = ArraySlice.base s + in sub (a, i+s) + end *) + + val GRAIN = 10000 + + fun strip s = + let val (s', start, _) = ArraySlice.base s + in if start = 0 then s' else raise Fail "strip base <> 0" + end + + fun bfs {diropt: bool} (g : G.graph) (s : vertex) = + let + val n = G.numVertices g + val parent = strip (Seq.tabulate (fn _ => ~1) n) + + (* Choose method of filtering the frontier: either frontier always + * only consists of valid vertex ids, or it allows invalid vertices and + * pretends that these vertices are isolated. *) + fun degree v = G.degree g v + fun filterFrontier s = Seq.filter (fn x => x <> itov (~1)) s + (* + fun degree v = if v < 0 then 0 else Graph.degree g v + fun filterFrontier s = s + *) + + val denseThreshold = G.numEdges g div 20 + + fun sumOfOutDegrees frontier = + SeqBasis.reduce 10000 op+ 0 (0, Seq.length frontier) (degree o Seq.nth frontier) + (* DS.reduce op+ 0 (DS.map degree (DS.fromArraySeq frontier)) *) + + fun shouldProcessDense frontier = + diropt andalso + let + val n = Seq.length frontier + val m = sumOfOutDegrees frontier + in + n + m > denseThreshold + end + + fun bottomUp frontier = + let + val flags = Seq.tabulate (fn _ => false) n + val _ = Seq.foreach frontier (fn (_, v) => + ArraySlice.update (flags, v, true)) + fun inFrontier v = Seq.nth flags (vtoi v) + + fun processVertex v = + if sub (parent, v) <> ~1 then NONE else + let + val nbrs = G.neighbors g (itov v) + val deg = ArraySlice.length nbrs + fun loop i = + if i >= deg then + NONE + else + let + val u = Seq.nth nbrs i + in + if inFrontier u then + (upd (parent, v, u); SOME v) + else + loop (i+1) + end + in + loop 0 + end + in + ArraySlice.full 
(SeqBasis.tabFilter 1000 (0, n) processVertex) + end + + fun topDown frontier = + let + val nf = Seq.length frontier + val offsets = SeqBasis.scan GRAIN op+ 0 (0, nf) (degree o Seq.nth frontier) + val mf = sub (offsets, nf) + val outNbrs = ForkJoin.alloc mf + + (* attempt to claim parent of u as v *) + fun claim (u, v) = + sub (parent, u) = ~1 + andalso + ~1 = Concurrency.casArray (parent, u) (~1, v) + + fun visitNeighbors offset v nghs = + Util.for (0, Seq.length nghs) (fn i => + let val u = Seq.nth nghs i + in if not (claim (vtoi u, vtoi v)) + then upd (outNbrs, offset + i, itov (~1)) + else upd (outNbrs, offset + i, u) + end) + + fun visitMany offlo lo hi = + if lo = hi then () else + let + val v = Seq.nth frontier offlo + val voffset = sub (offsets, offlo) + val k = Int.min (hi - lo, sub (offsets, offlo+1) - lo) + in + if k = 0 then visitMany (offlo+1) lo hi + else ( visitNeighbors lo v (Seq.subseq (G.neighbors g v) (lo - voffset, k)) + ; visitMany (offlo+1) (lo+k) hi + ) + end + + fun parVisitMany (offlo, offhi) (lo, hi) = + if hi - lo <= GRAIN then + visitMany offlo lo hi + else + let + val mid = lo + (hi - lo) div 2 + val (i, j) = OffsetSearch.search mid offsets (offlo, offhi) + val _ = ForkJoin.par + ( fn _ => parVisitMany (offlo, i) (lo, mid) + , fn _ => parVisitMany (j-1, offhi) (mid, hi) + ) + in + () + end + + (* Either one of the following is correct, but the second one has + * significantly better granularity control for graphs that have a + * small number of vertices with huge degree. 
*) + + (* val _ = ParUtil.parfor 100 (0, nf) (fn i => + visitMany i (sub (offsets, i)) (sub (offsets, i+1))) *) + + val _ = parVisitMany (0, nf + 1) (0, mf) + in + filterFrontier (ArraySlice.full outNbrs) + end + + fun search frontier = + if Seq.length frontier = 0 then + () + else if shouldProcessDense frontier then + let + val (nextFrontier, tm) = Util.getTime (fn _ => bottomUp frontier) + in + print ("dense " ^ Time.fmt 4 tm ^ "\n"); + search nextFrontier + end + else + let + val (nextFrontier, tm) = Util.getTime (fn _ => topDown frontier) + in + print ("sparse " ^ Time.fmt 4 tm ^ "\n"); + search nextFrontier + end + + val _ = upd (parent, vtoi s, s) + val _ = search (Seq.fromList [s]) + in + ArraySlice.full parent + end + +end diff --git a/tests/bench/bfs/OffsetSearch.sml b/tests/bench/bfs/OffsetSearch.sml new file mode 100644 index 000000000..7e17febb8 --- /dev/null +++ b/tests/bench/bfs/OffsetSearch.sml @@ -0,0 +1,54 @@ +structure OffsetSearch :> +sig + (* `search x xs (lo, hi)` searches the sorted array `xs` between indices `lo` + * and `hi`, returning `(i, j)` where `i-lo` is the number of elements that + * are strictly less than `x`, and `j-i` is the number of elements which are + * equal to `x`. 
*) + val search : int -> int array -> int * int -> int * int +end = +struct + + val sub = Array.sub + val upd = Array.update + + fun lowSearch x xs (lo, hi) = + case hi - lo of + 0 => lo + | n => let val mid = lo + n div 2 + in case Int.compare (x, sub (xs, mid)) of + LESS => lowSearch x xs (lo, mid) + | GREATER => lowSearch x xs (mid + 1, hi) + | EQUAL => lowSearchEq x xs (lo, mid) + end + + and lowSearchEq x xs (lo, mid) = + if (mid = 0) orelse (x > sub (xs, mid-1)) + then mid + else lowSearch x xs (lo, mid) + + and highSearch x xs (lo, hi) = + case hi - lo of + 0 => lo + | n => let val mid = lo + n div 2 + in case Int.compare (x, sub (xs, mid)) of + LESS => highSearch x xs (lo, mid) + | GREATER => highSearch x xs (mid + 1, hi) + | EQUAL => highSearchEq x xs (mid, hi) + end + + and highSearchEq x xs (mid, hi) = + if (mid = Array.length xs - 1) orelse (x < sub (xs, mid + 1)) + then mid + 1 + else highSearch x xs (mid + 1, hi) + + and search (x : int) (xs : int array) (lo, hi) : int * int = + case hi - lo of + 0 => (lo, lo) + | n => let val mid = lo + n div 2 + in case Int.compare (x, sub (xs, mid)) of + LESS => search x xs (lo, mid) + | GREATER => search x xs (mid + 1, hi) + | EQUAL => (lowSearchEq x xs (lo, mid), highSearchEq x xs (mid, hi)) + end + +end diff --git a/tests/bench/bfs/SerialBFS.sml b/tests/bench/bfs/SerialBFS.sml new file mode 100644 index 000000000..60a010386 --- /dev/null +++ b/tests/bench/bfs/SerialBFS.sml @@ -0,0 +1,38 @@ +structure SerialBFS = +struct + + structure G = AdjacencyGraph(Int) + + fun bfs g s = + let + fun neighbors v = G.neighbors g v + fun degree v = G.degree g v + + val n = G.numVertices g + val m = G.numEdges g + + val queue = ForkJoin.alloc (m+1) + val parents = Array.array (n, ~1) + + fun search (lo, hi) = + if lo >= hi then lo else + let + val v = Array.sub (queue, lo) + fun visit (hi', u) = + if Array.sub (parents, u) >= 0 then hi' + else ( Array.update (parents, u, v) + ; Array.update (queue, hi', u) + ; hi'+1 + ) + in + 
search (lo+1, Seq.iterate visit hi (neighbors v)) + end + + val _ = Array.update (parents, s, s) + val _ = Array.update (queue, 0, s) + val numVisited = search (0, 1) + in + ArraySlice.full parents + end + +end diff --git a/tests/bench/bfs/bfs.mlb b/tests/bench/bfs/bfs.mlb new file mode 100644 index 000000000..37e86d899 --- /dev/null +++ b/tests/bench/bfs/bfs.mlb @@ -0,0 +1,5 @@ +../../mpllib/sources.$(COMPAT).mlb +SerialBFS.sml +OffsetSearch.sml +NondetBFS.sml +main.sml diff --git a/tests/bench/bfs/main.sml b/tests/bench/bfs/main.sml new file mode 100644 index 000000000..4778c2ae9 --- /dev/null +++ b/tests/bench/bfs/main.sml @@ -0,0 +1,77 @@ +structure CLA = CommandLineArgs +structure BFS = NondetBFS +structure G = BFS.G + +val dontDirOpt = CLA.parseFlag "no-dir-opt" +val diropt = not dontDirOpt + +val source = CLA.parseInt "source" 0 +val doCheck = CLA.parseFlag "check" + +(* +val N = CLA.parseInt "N" 10000000 +val D = CLA.parseInt "D" 10 + +val (graph, tm) = Util.getTime (fn _ => G.randSymmGraph N D) +val _ = print ("generated graph in " ^ Time.fmt 4 tm ^ "s\n") +val _ = print ("num vertices: " ^ Int.toString (G.numVertices graph) ^ "\n") +val _ = print ("num edges: " ^ Int.toString (G.numEdges graph) ^ "\n") +*) + +val filename = + case CLA.positional () of + [x] => x + | _ => Util.die "missing filename" + +val (graph, tm) = Util.getTime (fn _ => G.parseFile filename) +val _ = print ("num vertices: " ^ Int.toString (G.numVertices graph) ^ "\n") +val _ = print ("num edges: " ^ Int.toString (G.numEdges graph) ^ "\n") + +val _ = + if not diropt then () else + let + val (_, tm) = Util.getTime (fn _ => + if G.parityCheck graph then () + else TextIO.output (TextIO.stdErr, + "WARNING: parity check failed; graph might not be symmetric " ^ + "or might have duplicate- or self-edges\n")) + val _ = print ("parity check in " ^ Time.fmt 4 tm ^ "s\n") + in + () + end + +val P = Benchmark.run "running bfs" + (fn _ => BFS.bfs {diropt = diropt} graph source) + +val numVisited = 
+ SeqBasis.reduce 10000 op+ 0 (0, Seq.length P) + (fn i => if Seq.nth P i >= 0 then 1 else 0) +val _ = print ("visited " ^ Int.toString numVisited ^ "\n") + +fun numHops P hops v = + if hops > Seq.length P then ~2 + else if Seq.nth P v = ~1 then ~1 + else if Seq.nth P v = v then hops + else numHops P (hops+1) (Seq.nth P v) + +val maxHops = + SeqBasis.reduce 100 Int.max ~3 (0, G.numVertices graph) (numHops P 0) +val _ = print ("max dist " ^ Int.toString maxHops ^ "\n") + +fun check () = + let + val (P', serialTime) = + Util.getTime (fn _ => SerialBFS.bfs graph source) + + val correct = + Seq.length P = Seq.length P' + andalso + SeqBasis.reduce 10000 (fn (a, b) => a andalso b) true (0, Seq.length P) + (fn i => numHops P 0 i = numHops P' 0 i) + in + print ("serial finished in " ^ Time.fmt 4 serialTime ^ "s\n"); + print ("correct? " ^ (if correct then "yes" else "no") ^ "\n") + end + +val _ = if doCheck then check () else () + diff --git a/tests/bench/bignum-add-opt/Add.sml b/tests/bench/bignum-add-opt/Add.sml new file mode 100644 index 000000000..c6e40794f --- /dev/null +++ b/tests/bench/bignum-add-opt/Add.sml @@ -0,0 +1,78 @@ +structure Add = +struct + structure Seq = ArraySequence + + type byte = Word8.word + type bignum = byte Seq.t + + fun init (b1, b2) = + Word8.+ (b1, b2) + + fun copy (a, b) = + if b = 0w127 then a else b + + + fun add (x, y) = + let + val nx = Seq.length x + val ny = Seq.length y + val n = Int.max (nx, ny) + + fun nthx i = if i < nx then Seq.nth x i else 0w0 + fun nthy i = if i < ny then Seq.nth y i else 0w0 + + val blockSize = 10000 + val numBlocks = 1 + ((n-1) div blockSize) + + val blockCarries = + SeqBasis.tabulate 1 (0, numBlocks) (fn blockIdx => + let + val lo = blockIdx * blockSize + val hi = Int.min (lo + blockSize, n) + fun loop acc i = + if i >= hi then + acc + else + loop (copy (acc, init (nthx i, nthy i))) (i+1) + in + loop 0w0 lo + end) + + val blockPartials = + SeqBasis.scan 5000 copy 0w0 (0, numBlocks) + (fn i => Array.sub 
(blockCarries, i)) + + val lastCarry = Array.sub (blockPartials, numBlocks) + + val result = ForkJoin.alloc (n+1) + + val _ = + ForkJoin.parfor 1 (0, numBlocks) (fn blockIdx => + let + val lo = blockIdx * blockSize + val hi = Int.min (lo + blockSize, n) + + fun loop acc i = + if i >= hi then + () + else + let + val sum = init (nthx i, nthy i) + val acc' = copy (acc, sum) + val thisByte = + Word8.andb (Word8.+ (Word8.>> (acc, 0w7), sum), 0wx7F) + in + Array.update (result, i, thisByte); + loop acc' (i+1) + end + in + loop (Array.sub (blockPartials, blockIdx)) lo + end) + + in + if lastCarry > 0w127 then + (Array.update (result, n, 0w1); ArraySlice.full result) + else + (ArraySlice.slice (result, 0, SOME n)) + end +end diff --git a/tests/bench/bignum-add-opt/bignum-add-opt.mlb b/tests/bench/bignum-add-opt/bignum-add-opt.mlb new file mode 100644 index 000000000..b01243dc3 --- /dev/null +++ b/tests/bench/bignum-add-opt/bignum-add-opt.mlb @@ -0,0 +1,5 @@ +../../mpllib/sources.$(COMPAT).mlb +../bignum-add/Bignum.sml +Add.sml +../bignum-add/SequentialAdd.sml +main.sml diff --git a/tests/bench/bignum-add-opt/main.sml b/tests/bench/bignum-add-opt/main.sml new file mode 100644 index 000000000..0e29ab546 --- /dev/null +++ b/tests/bench/bignum-add-opt/main.sml @@ -0,0 +1,36 @@ +structure CLA = CommandLineArgs +structure Seq = ArraySequence + +structure Add = Add + +val n = CLA.parseInt "n" (1000 * 1000 * 100) +val seed = CLA.parseInt "seed" 15210 +val doCheck = CLA.parseFlag "check" + +val _ = print ("n " ^ Int.toString n ^ "\n") + +val input1 = Bignum.generate n seed +val input2 = Bignum.generate n (seed + n) + +fun task () = + Add.add (input1, input2) + +fun check result = + if not doCheck then () else + let + val (correctResult, tm) = + Util.getTime (fn _ => SequentialAdd.add (input1, input2)) + val _ = print ("sequential " ^ Time.fmt 4 tm ^ "s\n") + val correct = + Seq.equal op= (result, correctResult) + in + if correct then + print ("correct? 
yes\n") + else + print ("correct? no\n") + end + +val result = Benchmark.run "bignum add" task +val _ = check result + +(* val _ = print ("result " ^ IntInf.toString (Bignum.toIntInf result) ^ "\n") *) diff --git a/tests/bench/bignum-add/Bignum.sml b/tests/bench/bignum-add/Bignum.sml new file mode 100644 index 000000000..b7ed3e688 --- /dev/null +++ b/tests/bench/bignum-add/Bignum.sml @@ -0,0 +1,57 @@ +structure Bignum = +struct + + structure Seq = ArraySequence + + type byte = Word8.word + + (* radix 128 representation *) + type t = byte Seq.t + + type bignum = t + + fun properlyFormatted x = + Seq.length x = 0 orelse Seq.nth x (Seq.length x - 1) <> 0w0 + + fun fromIntInf (x: IntInf.int): bignum = + if x < 0 then + raise Fail "bignums can't be negative" + else + let + fun toList (x: IntInf.int) : byte list = + if x = 0 then [] + else Word8.fromInt (IntInf.toInt (x mod 128)) :: toList (x div 128) + in + Seq.fromList (toList x) + end + + fun toIntInf (n: bignum): IntInf.int = + if not (properlyFormatted n) then + raise Fail "invalid bignum" + else + let + val n' = Seq.map (IntInf.fromInt o Word8.toInt) n + in + Seq.iterate (fn (x, d) => 128 * x + d) (0: IntInf.int) (Seq.rev n') + end + + fun generate n seed = + let + fun hash seed = Util.hash32_2 (Word32.fromInt seed) + fun w32to8 w = Word8.fromInt (Word32.toInt (Word32.andb (w, 0wx7F))) + fun genByte seed = w32to8 (hash seed) + fun genNonZeroByte seed = + w32to8 (Word32.+ (0w1, Word32.mod (hash seed, 0w127))) + in + Seq.tabulate (fn i => + if i < n-1 then + genByte (seed+i) + else + genNonZeroByte (seed+i)) + n + end + + fun toString x = + Seq.toString Word8.toString x + +end diff --git a/tests/bench/bignum-add/MkAdd.sml b/tests/bench/bignum-add/MkAdd.sml new file mode 100644 index 000000000..d55f8a885 --- /dev/null +++ b/tests/bench/bignum-add/MkAdd.sml @@ -0,0 +1,39 @@ +functor MkAdd (Seq: SEQUENCE) = +struct + + structure ASeq = ArraySequence + + type byte = Word8.word + type bignum = byte ASeq.t + + fun nth' s i 
= + if i < Seq.length s then Seq.nth s i else (0w0: Word8.word) + + fun add (x, y) = + let + val x = Seq.fromArraySeq x + val y = Seq.fromArraySeq y + + val maxlen = Int.max (Seq.length x, Seq.length y) + val sums = Seq.tabulate (fn i => Word8.+ (nth' x i, nth' y i)) (maxlen+1) + + fun propagate (a, b) = + if b = 0w127 then a else b + val (carries, _) = Seq.scan propagate 0w0 sums + + fun f (carry, sum) = + Word8.andb (Word8.+ (Word8.>> (carry, 0w7), sum), 0wx7F) + + val result = + Seq.force (Seq.zipWith f (carries, sums)) + + val r = Seq.toArraySeq result + in + (* [r] might have a trailing 0. Cut it off. *) + if ASeq.length r = 0 orelse (ASeq.nth r (ASeq.length r - 1) > 0w0) then + r + else + ASeq.take r (ASeq.length r - 1) + end + +end diff --git a/tests/bench/bignum-add/SequentialAdd.sml b/tests/bench/bignum-add/SequentialAdd.sml new file mode 100644 index 000000000..f5a127c53 --- /dev/null +++ b/tests/bench/bignum-add/SequentialAdd.sml @@ -0,0 +1,76 @@ +structure SequentialAdd = +struct + structure A = Array + structure AS = ArraySlice + structure Seq = ArraySequence + + (* radix 128 *) + type byte = Word8.word + type bignum = byte Seq.t + + fun addWithCarry3 (c, b1, b2) = + let + val x = Word8.+ (b1, Word8.+ (b2, c)) + in + {result = Word8.andb (x, 0wx7F), carry = Word8.>> (x, 0w7)} + end + + fun addWithCarry2 (b1, b2) = + addWithCarry3 (0w0, b1, b2) + + fun add (s1, s2) = + let + val n1 = Seq.length s1 + val n2 = Seq.length s2 + val n = Int.max (n1, n2) + + val r = ForkJoin.alloc (1 + n) + + fun finish1 i carry = + if i = n1 then + (A.update (r, i, carry); carry) + else + let + val {result, carry=carry'} = addWithCarry2 (Seq.nth s1 i, carry) + in + A.update (r, i, result); + finish1 (i+1) carry' + end + + fun finish2 i carry = + if i = n2 then + (A.update (r, i, carry); carry) + else + let + val {result, carry=carry'} = addWithCarry2 (Seq.nth s2 i, carry) + in + A.update (r, i, result); + finish2 (i+1) carry' + end + + fun loop i carry = + if i = n1 then + 
finish2 i carry + else if i = n2 then + finish1 i carry + else + let + val {result, carry=carry'} = + addWithCarry3 (Seq.nth s1 i, Seq.nth s2 i, carry) + in + A.update (r, i, result); + loop (i+1) carry' + end + in + (** Run the loop, and inspect the last carry value. + * If it is 1, then the output is well-formed. + * If it is 0, we need to trim. + *) + case loop 0 0w0 of + 0w0 => + AS.slice (r, 0, SOME n) + | _ => + AS.full r + end + +end diff --git a/tests/bench/bignum-add/bignum-add.mlb b/tests/bench/bignum-add/bignum-add.mlb new file mode 100644 index 000000000..773234ca1 --- /dev/null +++ b/tests/bench/bignum-add/bignum-add.mlb @@ -0,0 +1,5 @@ +../../mpllib/sources.$(COMPAT).mlb +Bignum.sml +MkAdd.sml +SequentialAdd.sml +main.sml diff --git a/tests/bench/bignum-add/main.sml b/tests/bench/bignum-add/main.sml new file mode 100644 index 000000000..9e7cebf74 --- /dev/null +++ b/tests/bench/bignum-add/main.sml @@ -0,0 +1,36 @@ +structure CLA = CommandLineArgs +structure Seq = ArraySequence + +structure Add = MkAdd(DelayedSeq) + +val n = CLA.parseInt "n" (1000 * 1000 * 100) +val seed = CLA.parseInt "seed" 15210 +val doCheck = CLA.parseFlag "check" + +val _ = print ("n " ^ Int.toString n ^ "\n") + +val input1 = Bignum.generate n seed +val input2 = Bignum.generate n (seed + n) + +fun task () = + Add.add (input1, input2) + +fun check result = + if not doCheck then () else + let + val (correctResult, tm) = + Util.getTime (fn _ => SequentialAdd.add (input1, input2)) + val _ = print ("sequential " ^ Time.fmt 4 tm ^ "s\n") + val correct = + Seq.equal op= (result, correctResult) + in + if correct then + print ("correct? yes\n") + else + print ("correct? 
no\n") + end + +val result = Benchmark.run "bignum add" task +val _ = check result + +(* val _ = print ("result " ^ IntInf.toString (Bignum.toIntInf result) ^ "\n") *) diff --git a/tests/bench/centrality/BC.sml b/tests/bench/centrality/BC.sml new file mode 100644 index 000000000..1351e4638 --- /dev/null +++ b/tests/bench/centrality/BC.sml @@ -0,0 +1,333 @@ +structure BC = +struct + structure G = AdjacencyGraph(Int) + structure V = G.Vertex + + type vertex = G.vertex + type graph = G.graph + + val sub = Array.sub + val upd = Array.update + + val vtoi = V.toInt + val itov = V.fromInt + + (* fun ASsub s = + let val (a, i, _) = ArraySlice.base s + in sub (a, i+s) + end *) + + val GRAIN = 10000 + + fun sumOfOutDegrees g frontier = + SeqBasis.reduce 10000 op+ 0 (0, Seq.length frontier) (G.degree g o Seq.nth frontier) + (* DS.reduce op+ 0 (DS.map degree (DS.fromArraySeq frontier)) *) + + fun shouldProcessDense g frontier = + let + val n = Seq.length frontier + val m = sumOfOutDegrees g frontier + in + n + m > (G.numEdges g) div 20 + end + + fun edgeMapDense + { cond : (vertex * 'a) -> bool + , func : (vertex * vertex * 'a * 'a) -> 'a option + , eq : 'a * 'a -> bool + , frontier + , visitedRoundNums + , state + , roundNum + , graph = g + , shouldOutput + } = + let + val N = G.numVertices g + val M = G.numEdges g + + (* val visitedRoundNums = ForkJoin.alloc N + val _ = ForkJoin.parfor 10000 (0, N) (fn i => upd (visitedRoundNums, i, ~1)) *) + fun visitedRound v = sub (visitedRoundNums, v) + fun isVisited v = visitedRound v <> ~1 + fun setVisited v r = upd (visitedRoundNums, v, r) + + (* val state : 'a array = ForkJoin.alloc N *) + (* val _ = ForkJoin.parfor 10000 (0, N) (fn i => upd (state, i, initialState i)) *) + fun getState v = sub (state, v) + fun setState v s = upd (state, v, s) + + val flags = Seq.tabulate (fn _ => false) N + val _ = Seq.foreach frontier (fn (_, v) => + ArraySlice.update (flags, v, true)) + fun inFrontier v = Seq.nth flags (vtoi v) + + fun 
processVertex v = + if isVisited v then NONE else + let + val nbrs = G.neighbors g (itov v) + val deg = ArraySlice.length nbrs + fun loop visited sv i = + if i >= deg then (visited, sv) else + let + val u = Seq.nth nbrs i + in + if not (inFrontier u) then loop visited sv (i+1) else + case func (u, v, getState u, sv) of + NONE => loop true sv (i+1) + | SOME sv' => if cond (v, sv') then (true, sv') else loop true sv' (i+1) + end + val sv = getState v + val (visited, sv') = loop false sv 0 + in + if eq (sv, sv') then () else setState v sv'; + if not visited then NONE else (setVisited v roundNum; SOME v) + end + in + if shouldOutput then + ArraySlice.full (SeqBasis.tabFilter 1000 (0, N) processVertex) + (* G.tabFilter 1000 (0, N) processVertex *) + else + (ForkJoin.parfor 1000 (0, N) (ignore o processVertex); Seq.empty ()) + end + + fun edgeMapSparse + { cond : (vertex * 'a) -> bool + , func : (vertex * vertex * 'a * 'a) -> 'a option + , eq : 'a * 'a -> bool + , frontier + , visitedRoundNums + , roundNum + , state + , graph = g + , shouldOutput + } = + let + val N = G.numVertices g + val M = G.numEdges g + + fun degree v = G.degree g v + fun filterFrontier s = Seq.filter (fn x => x <> itov (~1)) s + + (* val visitedRoundNums = ForkJoin.alloc N + val _ = ForkJoin.parfor 10000 (0, N) (fn i => upd (visitedRoundNums, i, ~1)) *) + fun visitedRound v = sub (visitedRoundNums, v) + fun isVisited v = visitedRound v <> ~1 + fun setVisited v r = upd (visitedRoundNums, v, r) + fun claimVisited v r = + ~1 = Concurrency.casArray (visitedRoundNums, v) (~1, r) + + (* val state : 'a array = ForkJoin.alloc N *) + (* val _ = ForkJoin.parfor 10000 (0, N) (fn i => upd (state, i, initialState i)) *) + fun getState v = sub (state, v) + fun setState v s = upd (state, v, s) + + (* repeatedly try to update the state of v by CASing the result of + * computing func (u, v, su, sv) *) + fun tryPushUpdateState (u, v, su, sv) = + if cond (v, sv) then () else + case func (u, v, su, sv) of + NONE => () 
+ | SOME desired => + let + val sv' = Concurrency.casArray (state, v) (sv, desired) + in + if eq (sv', sv) + then () + else tryPushUpdateState (u, v, su, sv') + end + + val nf = Seq.length frontier + val offsets = SeqBasis.scan GRAIN op+ 0 (0, nf) (degree o Seq.nth frontier) + val mf = sub (offsets, nf) + val outNbrs = + if shouldOutput + then ForkJoin.alloc mf + else Array.fromList [] + + fun writeOut i x = upd (outNbrs, i, x) + fun checkWriteOut i x = + if shouldOutput then writeOut i x else () + + fun visitNeighbors offset u nghs = + Util.for (0, Seq.length nghs) (fn i => + let + val v = Seq.nth nghs i + val r = visitedRound v + in + if 0 <= r andalso r < roundNum then + (* v was visited on a previous round, so ignore it. *) + checkWriteOut (offset+i) (itov (~1)) + else + ( if shouldOutput then + (if r = ~1 andalso claimVisited v roundNum then + writeOut (offset+i) v + else + writeOut (offset+i) (itov (~1))) + else + (if r = ~1 then setVisited v roundNum else ()) + + (* regardless, we need to check for updating state. *) + ; tryPushUpdateState (u, v, getState u, getState v) + ) + end) + + fun visitMany offlo lo hi = + if lo = hi then () else + let + val u = Seq.nth frontier offlo + val voffset = sub (offsets, offlo) + val k = Int.min (hi - lo, sub (offsets, offlo+1) - lo) + in + if k = 0 then visitMany (offlo+1) lo hi + else ( visitNeighbors lo u (Seq.subseq (G.neighbors g u) (lo - voffset, k)) + ; visitMany (offlo+1) (lo+k) hi + ) + end + + fun parVisitMany (offlo, offhi) (lo, hi) = + if hi - lo <= GRAIN then + visitMany offlo lo hi + else + let + val mid = lo + (hi - lo) div 2 + val (i, j) = OffsetSearch.search mid offsets (offlo, offhi) + val _ = ForkJoin.par + ( fn _ => parVisitMany (offlo, i) (lo, mid) + , fn _ => parVisitMany (j-1, offhi) (mid, hi) + ) + in + () + end + + (* Either one of the following is correct, but the second one has + * significantly better granularity control for graphs that have a + * small number of vertices with huge degree. 
*) + + (* val _ = ForkJoin.parfor 100 (0, nf) (fn i => + visitMany i (sub (offsets, i)) (sub (offsets, i+1))) *) + + val _ = parVisitMany (0, nf + 1) (0, mf) + in + filterFrontier (ArraySlice.full outNbrs) + end + + fun edgeMap X = + if shouldProcessDense (#graph X) (#frontier X) then + let + val (nextFrontier, tm) = Util.getTime (fn _ => edgeMapDense X) + in + print ("dense " ^ Time.fmt 4 tm ^ "\n"); + nextFrontier + end + else + let + val (nextFrontier, tm) = Util.getTime (fn _ => edgeMapSparse X) + in + print ("sparse " ^ Time.fmt 4 tm ^ "\n"); + nextFrontier + end + + fun bc graph source = + let + val g = graph + + val N = G.numVertices g + val M = G.numEdges g + + fun initialNumPaths v = + if v = source then 1 else 0 + + val visitedRoundNums = ForkJoin.alloc N + val _ = ForkJoin.parfor 10000 (0, N) (fn i => upd (visitedRoundNums, i, ~1)) + + val numPathsArr = ForkJoin.alloc N + val _ = ForkJoin.parfor 10000 (0, N) (fn i => upd (numPathsArr, i, initialNumPaths i)) + + (* accumulate number of paths through v *) + fun edgeFunc (u, v, uNumPaths, vNumPaths) = + uNumPaths + vNumPaths + + fun forwardsEdgeMap roundNum frontier = + edgeMap + { cond = (fn _ => false) (* accumulate all edges *) + , func = SOME o edgeFunc + , eq = op= + , frontier = frontier + , visitedRoundNums = visitedRoundNums + , roundNum = roundNum + , state = numPathsArr + , graph = g + , shouldOutput = true + } + + fun forwardsLoop pastFrontiers roundNum frontier = + if Seq.length frontier = 0 then + pastFrontiers + else + let + val nextFrontier = forwardsEdgeMap roundNum frontier + in + forwardsLoop (frontier :: pastFrontiers) (roundNum+1) nextFrontier + end + + val _ = upd (visitedRoundNums, source, 0) + val frontiers = forwardsLoop [] 1 (Seq.fromList [source]) + + val numPaths = ArraySlice.full numPathsArr + + (* val mnp = DS.reduce Int.max 0 (DS.fromArraySeq numPaths) *) + val mnp = SeqBasis.reduce 10000 Int.max 0 (0, N) (Seq.nth numPaths) + val _ = print ("max-num-paths " ^ Int.toString mnp ^ 
"\n") + + val lastFrontier = List.hd frontiers + + (* ===================================================================== + * second phase: search in reverse. + *) + + val invNumPaths = ForkJoin.alloc N + val _ = ForkJoin.parfor 10000 (0, N) (fn i => + upd (invNumPaths, i, 1.0 / Real.fromInt (sub (numPathsArr, i)))) + + val visitedRoundNums = ForkJoin.alloc N + val _ = ForkJoin.parfor 10000 (0, N) (fn i => upd (visitedRoundNums, i, ~1)) + + val deps = ForkJoin.alloc N + val _ = ForkJoin.parfor 10000 (0, N) (fn i => upd (deps, i, 0.0)) + + fun edgeFunc (u, v, uDep, vDep) = + vDep + uDep + sub (invNumPaths, u) + + fun backwardsEdgeMap roundNum frontier = + edgeMap + { cond = (fn _ => false) (* accumulate all edges *) + , func = SOME o edgeFunc + , eq = Real.== + , frontier = frontier + , visitedRoundNums = visitedRoundNums + , roundNum = roundNum + , state = deps + , graph = g + , shouldOutput = false + } + + fun backwardsLoop frontiers roundNum = + case frontiers of + [] => () + | frontier :: frontiers' => + let + val _ = Seq.foreach frontier (fn (_, v) => + upd (visitedRoundNums, v, roundNum)) + + val _ = backwardsEdgeMap (roundNum+1) frontier + in + backwardsLoop frontiers' (roundNum+1) + end + + val _ = backwardsLoop frontiers 0 + in + Seq.tabulate (fn i => sub (deps, i) / sub (invNumPaths, i)) N + end + +end diff --git a/tests/bench/centrality/OffsetSearch.sml b/tests/bench/centrality/OffsetSearch.sml new file mode 100644 index 000000000..7e17febb8 --- /dev/null +++ b/tests/bench/centrality/OffsetSearch.sml @@ -0,0 +1,54 @@ +structure OffsetSearch :> +sig + (* `search x xs (lo, hi)` searches the sorted array `xs` between indices `lo` + * and `hi`, returning `(i, j)` where `i-lo` is the number of elements that + * are strictly less than `x`, and `j-i` is the number of elements which are + * equal to `x`. 
*) + val search : int -> int array -> int * int -> int * int +end = +struct + + val sub = Array.sub + val upd = Array.update + + fun lowSearch x xs (lo, hi) = + case hi - lo of + 0 => lo + | n => let val mid = lo + n div 2 + in case Int.compare (x, sub (xs, mid)) of + LESS => lowSearch x xs (lo, mid) + | GREATER => lowSearch x xs (mid + 1, hi) + | EQUAL => lowSearchEq x xs (lo, mid) + end + + and lowSearchEq x xs (lo, mid) = + if (mid = 0) orelse (x > sub (xs, mid-1)) + then mid + else lowSearch x xs (lo, mid) + + and highSearch x xs (lo, hi) = + case hi - lo of + 0 => lo + | n => let val mid = lo + n div 2 + in case Int.compare (x, sub (xs, mid)) of + LESS => highSearch x xs (lo, mid) + | GREATER => highSearch x xs (mid + 1, hi) + | EQUAL => highSearchEq x xs (mid, hi) + end + + and highSearchEq x xs (mid, hi) = + if (mid = Array.length xs - 1) orelse (x < sub (xs, mid + 1)) + then mid + 1 + else highSearch x xs (mid + 1, hi) + + and search (x : int) (xs : int array) (lo, hi) : int * int = + case hi - lo of + 0 => (lo, lo) + | n => let val mid = lo + n div 2 + in case Int.compare (x, sub (xs, mid)) of + LESS => search x xs (lo, mid) + | GREATER => search x xs (mid + 1, hi) + | EQUAL => (lowSearchEq x xs (lo, mid), highSearchEq x xs (mid, hi)) + end + +end diff --git a/tests/bench/centrality/centrality.mlb b/tests/bench/centrality/centrality.mlb new file mode 100644 index 000000000..8bddac8b8 --- /dev/null +++ b/tests/bench/centrality/centrality.mlb @@ -0,0 +1,4 @@ +../../mpllib/sources.$(COMPAT).mlb +OffsetSearch.sml +BC.sml +main.sml diff --git a/tests/bench/centrality/main.sml b/tests/bench/centrality/main.sml new file mode 100644 index 000000000..d34146707 --- /dev/null +++ b/tests/bench/centrality/main.sml @@ -0,0 +1,57 @@ +structure CLA = CommandLineArgs +structure BC = BC +structure G = BC.G + +val filename = + case CLA.positional () of + [x] => x + | _ => Util.die "missing filename" + +val source = CLA.parseInt "source" 0 + +val (graph, tm) = Util.getTime 
(fn _ => G.parseFile filename) +val _ = print ("num vertices: " ^ Int.toString (G.numVertices graph) ^ "\n") +val _ = print ("num edges: " ^ Int.toString (G.numEdges graph) ^ "\n") + +val (_, tm) = Util.getTime (fn _ => + if G.parityCheck graph then () + else TextIO.output (TextIO.stdErr, + "WARNING: parity check failed; graph might not be symmetric " ^ + "or might have duplicate- or self-edges\n")) +val _ = print ("parity check in " ^ Time.fmt 4 tm ^ "s\n") + +val result = Benchmark.run "running centrality" (fn _ => BC.bc graph source) + +val maxDep = + SeqBasis.reduce 10000 Real.max 0.0 (0, Seq.length result) (Seq.nth result) + (* DS.reduce Real.max 0.0 (DS.fromArraySeq result) *) +val _ = print ("maxdep " ^ Real.toString maxDep ^ "\n") + +val totDep = + SeqBasis.reduce 10000 op+ 0.0 (0, Seq.length result) (Seq.nth result) + (* DS.reduce op+ 0.0 (DS.fromArraySeq result) *) +val _ = print ("avgdep " ^ Real.toString (totDep / Real.fromInt (Seq.length result)) ^ "\n") + +val numVisited = + SeqBasis.reduce 10000 op+ 0 (0, Seq.length result) + (fn i => if Seq.nth result i < 0.0 then 0 else 1) +val _ = print ("visited " ^ Int.toString numVisited ^ "\n") + +val outfile = CLA.parseString "outfile" "" +val _ = + if outfile = "" then + print ("use -outfile XXX to see result\n") + else + let + val n = Seq.length result + val file = TextIO.openOut outfile + fun dump i = + if i >= n then () + else (TextIO.output (file, Real.toString (Seq.nth result i)); + TextIO.output (file, "\n"); + dump (i+1)) + in + dump 0; + TextIO.closeOut file + end + diff --git a/tests/bench/collect/CollectHash.sml b/tests/bench/collect/CollectHash.sml new file mode 100644 index 000000000..9309c2ef0 --- /dev/null +++ b/tests/bench/collect/CollectHash.sml @@ -0,0 +1,30 @@ +functor CollectHash (structure K: KEY structure V: VALUE): +sig + val collect: (K.t * V.t) Seq.t -> (K.t * V.t) Seq.t +end = +struct + + structure T = HashTable (structure K = K structure V = V) + + + fun collect kvs = + let + val t 
= T.make {capacity = Seq.length kvs} (* very rough upper bound *) + + val _ = ForkJoin.parfor 100 (0, Seq.length kvs) (fn i => + T.insert_combine t (Seq.nth kvs i)) + + val contents = T.unsafe_view_contents t + val results = + ArraySlice.full + (SeqBasis.filter 1000 (0, DelayedSeq.length contents) + (fn i => valOf (DelayedSeq.nth contents i)) + (fn i => Option.isSome (DelayedSeq.nth contents i))) + + val sorted = + Mergesort.sort (fn ((k1, v1), (k2, v2)) => K.cmp (k1, k2)) results + in + sorted + end + +end diff --git a/tests/bench/collect/CollectSort.sml b/tests/bench/collect/CollectSort.sml new file mode 100644 index 000000000..276b656f4 --- /dev/null +++ b/tests/bench/collect/CollectSort.sml @@ -0,0 +1,45 @@ +functor CollectSort (structure K: KEY structure V: VALUE): +sig + val collect: (K.t * V.t) Seq.t -> (K.t * V.t) Seq.t +end = +struct + + structure Key = Int + + type key = int + + fun collect kvs = + let + val n = Seq.length kvs + + fun key (k, v) = k + fun value (k, v) = v + + fun key_cmp (kv1, kv2) = + K.cmp (key kv1, key kv2) + + val sorted = Mergesort.sort key_cmp kvs + + val boundaries = + ArraySlice.full + (SeqBasis.filter 1000 (0, Seq.length sorted) (fn i => i) (fn i => + i = 0 + orelse key_cmp (Seq.nth sorted (i - 1), Seq.nth sorted i) <> EQUAL)) + + fun make i = + let + val start = Seq.nth boundaries i + val stop = + if i + 1 = Seq.length boundaries then n + else Seq.nth (boundaries) (i + 1) + val k = key (Seq.nth sorted start) + val v = SeqBasis.reduce 1000 V.combine V.zero (start, stop) (fn j => + value (Seq.nth sorted j)) + in + (k, v) + end + in + Seq.tabulate make (Seq.length boundaries) + end + +end diff --git a/tests/bench/collect/HashTable.sml b/tests/bench/collect/HashTable.sml new file mode 100644 index 000000000..0eadfa6df --- /dev/null +++ b/tests/bench/collect/HashTable.sml @@ -0,0 +1,101 @@ +functor HashTable (structure K: KEY structure V: VALUE) = +struct + + datatype t = T of {keys: K.t array, values: V.t array} + + exception Full + 
exception DuplicateKey + + type table = t + + + fun make {capacity} = + let + val keys = SeqBasis.tabulate 5000 (0, capacity) (fn _ => K.empty) + val values = SeqBasis.tabulate 5000 (0, capacity) (fn _ => V.zero) + in + T {keys = keys, values = values} + end + + + fun capacity (T {keys, ...}) = Array.length keys + + + fun size (T {keys, ...}) = + SeqBasis.reduce 10000 op+ 0 (0, Array.length keys) (fn i => + if K.equal (Array.sub (keys, i), K.empty) then 0 else 1) + + + fun unsafe_view_contents (tab as T {keys, values}) = + let + val capacity = Array.length keys + + fun elem i = + let + val k = Array.sub (keys, i) + in + if K.equal (k, K.empty) then NONE else SOME (k, Array.sub (values, i)) + end + in + DelayedSeq.tabulate elem (Array.length keys) + end + + + fun bcas (arr, i, old, new) = + MLton.eq (old, Concurrency.casArray (arr, i) (old, new)) + + + fun atomic_combine_with (f: 'a * 'a -> 'a) (arr: 'a array, i) (x: 'a) = + let + fun loop current = + let + val desired = f (current, x) + in + if MLton.eq (desired, current) then + () + else + let + val current' = + Concurrency.casArray (arr, i) (current, desired) + in + if MLton.eq (current', current) then () else loop current' + end + end + in + loop (Array.sub (arr, i)) + end + + + fun insert_combine (input as T {keys, values}) (x, v) = + let + val n = Array.length keys + val tolerance = n + + fun claim_slot_at i = bcas (keys, i, K.empty, x) + + fun put_value_at i = + atomic_combine_with V.combine (values, i) v + + fun loop i probes = + if probes >= tolerance then + raise Full + else if i >= n then + loop 0 probes + else + let + val k = Array.sub (keys, i) + in + if K.equal (k, K.empty) then + if claim_slot_at i then put_value_at i else loop i probes + else if K.equal (k, x) then + put_value_at i + else + loop (i + 1) (probes + 1) + end + + val start = (K.hash x) mod (Array.length keys) + in + loop start 0 + end + +end diff --git a/tests/bench/collect/KEY.sml b/tests/bench/collect/KEY.sml new file mode 100644 index 
000000000..c27654cfa --- /dev/null +++ b/tests/bench/collect/KEY.sml @@ -0,0 +1,8 @@ +signature KEY = +sig + type t + val equal: t * t -> bool + val cmp: t * t -> order + val empty: t + val hash: t -> int +end diff --git a/tests/bench/collect/VALUE.sml b/tests/bench/collect/VALUE.sml new file mode 100644 index 000000000..2b0b19065 --- /dev/null +++ b/tests/bench/collect/VALUE.sml @@ -0,0 +1,6 @@ +signature VALUE = +sig + type t + val zero: t + val combine: t * t -> t +end diff --git a/tests/bench/collect/collect.mlb b/tests/bench/collect/collect.mlb new file mode 100644 index 000000000..3532e2cdc --- /dev/null +++ b/tests/bench/collect/collect.mlb @@ -0,0 +1,7 @@ +../../mpllib/sources.$(COMPAT).mlb +KEY.sml +VALUE.sml +HashTable.sml +CollectHash.sml +CollectSort.sml +main.sml \ No newline at end of file diff --git a/tests/bench/collect/main.sml b/tests/bench/collect/main.sml new file mode 100644 index 000000000..a45b76a4f --- /dev/null +++ b/tests/bench/collect/main.sml @@ -0,0 +1,49 @@ +structure CLA = CommandLineArgs +val n = CLA.parseInt "n" 10000000 +val k = CLA.parseInt "k" 1000 +val seed = CLA.parseInt "seed" 15210 +val impl = CLA.parseString "impl" "sort" + +val _ = print ("n " ^ Int.toString n ^ "\n") +val _ = print ("k " ^ Int.toString k ^ "\n") +val _ = print ("seed " ^ Int.toString seed ^ "\n") +val _ = print ("impl " ^ impl ^ "\n") + +fun gen_real seed = + Real.fromInt (Util.hash seed mod 100000000) / 100000000.0 + +fun gen_elem i = + (Util.hash (seed + 2 * i) mod k, gen_real (seed + 2 * i + 1)) + +val kvs = Seq.tabulate gen_elem n + +structure K = +struct + type t = int + fun equal (x: t, y: t) = (x = y) + fun cmp (x, y) = Int.compare (x, y) + val empty = ~1 + val hash = Util.hash +end + +structure V = struct type t = real val zero = 0.0 val combine = Real.+ end + + +structure CollectSort = CollectSort (structure K = K structure V = V) +structure CollectHash = CollectHash (structure K = K structure V = V) + +fun bench () = + case impl of + "sort" => 
CollectSort.collect kvs + | "hash" => CollectHash.collect kvs + | _ => Util.die "unknown impl" + +val result = Benchmark.run "collect" bench + +val _ = print ("num unique keys: " ^ Int.toString (Seq.length result) ^ "\n") +val _ = print + ("result: " + ^ + Util.summarizeArraySlice 10 + (fn (k, v) => "(" ^ Int.toString k ^ "," ^ Real.toString v ^ ")") result + ^ "\n") diff --git a/tests/bench/dedup-entangled-fixed/NondetDedup.sml b/tests/bench/dedup-entangled-fixed/NondetDedup.sml new file mode 100644 index 000000000..73a98e925 --- /dev/null +++ b/tests/bench/dedup-entangled-fixed/NondetDedup.sml @@ -0,0 +1,143 @@ +(* Phase-concurrent hash table based deduplication. + * See https://people.csail.mit.edu/jshun/hash.pdf. + * + * This entangled benchmark deduplicates a sequence of 64-bit integers using a + * phase-concurrent hash table implementation. The basic idea is that the hash + * table stores int options, which are heap-allocated by the thread inserting + * into the hash table. Hence, when a thread probes the table during an + * insertion, it may CAS and load an allocation made by a concurrent thread, + * thereby tripping the entanglement checker. + *) + +structure A = Array +structure AS = ArraySlice +val update = Array.update +val sub = Array.sub + +structure Hashtbl = struct + type 'a t = 'a option array * ('a -> int) * (('a * 'a) -> order) + + val gran = 10000 + + fun create hash cmp n = + let + val t = ForkJoin.alloc n + val () = ForkJoin.parfor gran (0, n) (fn i => update (t, i, NONE)) + in + (t, hash, cmp) + end + + fun insert (t, hash, cmp) xx = + let + val x = case xx of SOME x => x | NONE => raise Fail "impossible!" 
+ val n = A.length t + fun nextIndex i = + if i = n - 1 then 0 + else i + 1 + fun hash' x = + let + val y = hash x + in + if y < 0 then ~y mod n + else y mod n + end + fun cmp' (x, y) = + case (x, y) of + (NONE, NONE) => EQUAL + | (NONE, SOME _) => LESS + | (SOME _, NONE) => GREATER + | (SOME x', SOME y') => cmp (x', y') + fun probe (i, x) = + if not (Option.isSome x) then () else + let + val y = sub (t, i) + in + case cmp' (x, y) of + EQUAL => () + | LESS => probe (nextIndex i, x) + | GREATER => + let + val z = Concurrency.casArray (t, i) (y, x) + in + if MLton.eq (y, z) then probe (nextIndex i, y) + else probe (i, x) + end + end + in + probe (hash' x, xx) + end + + fun keys (t, _, _) = + let + val n = A.length t + val t' = SeqBasis.tabFilter gran (0, n) (fn i => sub (t, i)) + in + AS.full t' + end +end + +(* val dedup : ('k -> int) hash function + -> (('k, 'k) -> order) comparison function + -> 'k seq input (with duplicates) + -> 'k seq deduplicated (not sorted!) +*) +fun dedup hash cmp keys = + if AS.length keys = 0 then Seq.empty () else + let + val n = AS.length keys + val tbl = Hashtbl.create hash cmp (4 * n) + val keys' = Seq.map SOME keys + val () = + ForkJoin.parfor 100 (0, n) (fn i => Hashtbl.insert tbl (Seq.nth keys' i)) + in + Hashtbl.keys tbl + end + +(* ========================================================================== + * now the main bit *) + +structure CLA = CommandLineArgs + +fun usage () = + let + val msg = + "usage: dedup [--verbose] [--no-output] [-N]\n" + in + TextIO.output (TextIO.stdErr, msg); + OS.Process.exit OS.Process.failure + end + +val n = CLA.parseInt "N" (100 * 1000 * 1000) + +val beVerbose = CommandLineArgs.parseFlag "verbose" +val noOutput = CommandLineArgs.parseFlag "no-output" +val rep = case (Int.fromString (CLA.parseString "repeat" "1")) of + SOME(a) => a + | NONE => 1 + +fun vprint str = + if not beVerbose then () + else TextIO.output (TextIO.stdErr, str) + +val input = Seq.tabulate Util.hash n + +fun dedupEx () = + 
dedup Util.hash Int.compare input + +val result = Benchmark.run "running dedup" dedupEx + +fun put c = TextIO.output1 (TextIO.stdOut, c) +val _ = + if noOutput then () + else + let + val (_, tm) = Util.getTime (fn _ => + ArraySlice.app (fn x => (print (Int.toString x); put #"\n")) result) + in + vprint ("output in " ^ Time.fmt 4 tm ^ "s\n") + end + +val _ = + if not beVerbose then () + else GCStats.report () + diff --git a/tests/bench/dedup-entangled-fixed/dedup-entangled-fixed.mlb b/tests/bench/dedup-entangled-fixed/dedup-entangled-fixed.mlb new file mode 100644 index 000000000..20bb8b2f1 --- /dev/null +++ b/tests/bench/dedup-entangled-fixed/dedup-entangled-fixed.mlb @@ -0,0 +1,3 @@ +../../mpllib/sources.$(COMPAT).mlb +$(SML_LIB)/basis/mlton.mlb +NondetDedup.sml diff --git a/tests/bench/dedup-entangled/NondetDedup.sml b/tests/bench/dedup-entangled/NondetDedup.sml new file mode 100644 index 000000000..e883904cc --- /dev/null +++ b/tests/bench/dedup-entangled/NondetDedup.sml @@ -0,0 +1,141 @@ +(* Phase-concurrent hash table based deduplication. + * See https://people.csail.mit.edu/jshun/hash.pdf. + * + * This entangled benchmark deduplicates a sequence of 64-bit integers using a + * phase-concurrent hash table implementation. The basic idea is that the hash + * table stores int options, which are heap-allocated by the thread inserting + * into the hash table. Hence, when a thread probes the table during an + * insertion, it may CAS and load an allocation made by a concurrent thread, + * thereby tripping the entanglement checker. 
+ *)
+
+structure A = Array
+structure AS = ArraySlice
+val update = Array.update
+val sub = Array.sub
+
+structure Hashtbl = struct (* phase-concurrent set: concurrent inserts first, then a read-only keys phase *)
+  type 'a t = 'a option array * ('a -> int) * (('a * 'a) -> order) (* (slots, hash, compare); NONE marks an empty slot *)
+
+  val gran = 10000 (* parallel grain size for table-wide loops *)
+
+  fun create hash cmp n = (* new table with n slots, all cleared to NONE in parallel *)
+    let
+      val t = ForkJoin.alloc n (* uninitialized; filled just below *)
+      val () = ForkJoin.parfor gran (0, n) (fn i => update (t, i, NONE))
+    in
+      (t, hash, cmp)
+    end
+
+  fun insert (t, hash, cmp) x = (* linear-probing insert; duplicates (cmp' = EQUAL) are dropped *)
+    let
+      val n = A.length t
+      fun nextIndex i = (* advance one slot, wrapping around the end of the array *)
+        if i = n - 1 then 0
+        else i + 1
+      fun hash' x = (* clamp the user hash into [0, n); ~y handles negative hash values *)
+        let
+          val y = hash x
+        in
+          if y < 0 then ~y mod n
+          else y mod n
+        end
+      fun cmp' (x, y) = (* order lifted to options: NONE sorts below SOME, so empty slots always yield *)
+        case (x, y) of
+          (NONE, NONE) => EQUAL
+        | (NONE, SOME _) => LESS
+        | (SOME _, NONE) => GREATER
+        | (SOME x', SOME y') => cmp (x', y')
+      fun probe (i, x) =
+        if not (Option.isSome x) then () else (* nothing left to place *)
+        let
+          val y = sub (t, i) (* may observe a SOME allocated by a concurrent insert — the entanglement source described in the header comment *)
+        in
+          case cmp' (x, y) of
+            EQUAL => () (* key already present: done *)
+          | LESS => probe (nextIndex i, x) (* occupant is larger; keep probing with x *)
+          | GREATER => (* x is larger: try to claim this slot, evicting the occupant *)
+              let
+                val z = Concurrency.casArray (t, i) (y, x)
+              in
+                if MLton.eq (y, z) then probe (nextIndex i, y) (* CAS won: continue by re-inserting the evicted value y *)
+                else probe (i, x) (* CAS lost a race: retry this same slot with x *)
+              end
+        end
+    in
+      probe (hash' x, SOME x) (* this SOME is heap-allocated here, by design of the benchmark *)
+    end
+
+  fun keys (t, _, _) = (* collect all present keys; only meaningful once all inserts have finished *)
+    let
+      val n = A.length t
+      val t' = SeqBasis.tabFilter gran (0, n) (fn i => sub (t, i))
+    in
+      AS.full t'
+    end
+end
+
+(* val dedup : ('k -> int)             hash function
+            -> (('k, 'k) -> order)     comparison function
+            -> 'k seq                  input (with duplicates)
+            -> 'k seq                  deduplicated (not sorted!)
+*) +fun dedup hash cmp keys = + if AS.length keys = 0 then Seq.empty () else + let + val n = AS.length keys + val tbl = Hashtbl.create hash cmp (4 * n) + val () = + ForkJoin.parfor 100 (0, n) (fn i => Hashtbl.insert tbl (Seq.nth keys i)) + in + Hashtbl.keys tbl + end + +(* ========================================================================== + * now the main bit *) + +structure CLA = CommandLineArgs + +fun usage () = + let + val msg = + "usage: dedup [--verbose] [--no-output] [-N]\n" + in + TextIO.output (TextIO.stdErr, msg); + OS.Process.exit OS.Process.failure + end + +val n = CLA.parseInt "N" (100 * 1000 * 1000) + +val beVerbose = CommandLineArgs.parseFlag "verbose" +val noOutput = CommandLineArgs.parseFlag "no-output" +val rep = case (Int.fromString (CLA.parseString "repeat" "1")) of + SOME(a) => a + | NONE => 1 + +fun vprint str = + if not beVerbose then () + else TextIO.output (TextIO.stdErr, str) + +val input = Seq.tabulate Util.hash n + +fun dedupEx () = + dedup Util.hash Int.compare input + +val result = Benchmark.run "running dedup" dedupEx + +fun put c = TextIO.output1 (TextIO.stdOut, c) +val _ = + if noOutput then () + else + let + val (_, tm) = Util.getTime (fn _ => + ArraySlice.app (fn x => (print (Int.toString x); put #"\n")) result) + in + vprint ("output in " ^ Time.fmt 4 tm ^ "s\n") + end + +val _ = + if not beVerbose then () + else GCStats.report () + diff --git a/tests/bench/dedup-entangled/dedup-entangled.mlb b/tests/bench/dedup-entangled/dedup-entangled.mlb new file mode 100644 index 000000000..20bb8b2f1 --- /dev/null +++ b/tests/bench/dedup-entangled/dedup-entangled.mlb @@ -0,0 +1,3 @@ +../../mpllib/sources.$(COMPAT).mlb +$(SML_LIB)/basis/mlton.mlb +NondetDedup.sml diff --git a/tests/bench/dedup/dedup.mlb b/tests/bench/dedup/dedup.mlb new file mode 100644 index 000000000..a1801a450 --- /dev/null +++ b/tests/bench/dedup/dedup.mlb @@ -0,0 +1,2 @@ +../../mpllib/sources.$(COMPAT).mlb +dedup.sml diff --git a/tests/bench/dedup/dedup.sml 
b/tests/bench/dedup/dedup.sml
new file mode 100644
index 000000000..9173dadfb
--- /dev/null
+++ b/tests/bench/dedup/dedup.sml
@@ -0,0 +1,216 @@
+structure A = Array
+structure AS = ArraySlice
+val update = Array.update
+val sub = Array.sub
+
+fun chunkedfor chunkSize (flo, fhi) f = (* serial for-loop over [flo, fhi), processed in chunkSize-sized chunks *)
+  let
+    val n = fhi - flo
+    val numChunks = (n-1) div chunkSize + 1 (* ceiling division *)
+  in
+    Util.for (0, numChunks) (fn i =>
+      let
+        val clo = flo + i*chunkSize
+        val chi = if i = numChunks - 1 then fhi else flo + (i+1)*chunkSize (* final chunk may be short *)
+      in
+        Util.for (clo, chi) f
+      end)
+  end
+
+fun chunkedloop chunkSize (flo, fhi) init f = (* like chunkedfor, but threads an accumulator through f (a serial fold) *)
+  let
+    val n = fhi - flo
+    val numChunks = (n-1) div chunkSize + 1
+  in
+    Util.loop (0, numChunks) init (fn (b, i) =>
+      let
+        val clo = flo + i*chunkSize
+        val chi = if i = numChunks - 1 then fhi else flo + (i+1)*chunkSize
+        val b' = Util.loop (clo, chi) b f
+      in
+        b'
+      end)
+  end
+
+datatype 'a bucketTree = (* balanced tree of per-bucket result arrays; each Node caches its subtree's element count *)
+  Leaf of 'a array
+| Node of int * 'a bucketTree * 'a bucketTree
+
+fun count t = (* total number of elements under t; O(1) at a Node thanks to the cached count *)
+  case t of
+    Leaf a => A.length a
+  | Node (c, _, _) => c
+
+fun bucketTree n (f : int -> 'a array) = (* build the tree over buckets [0, n), computing each leaf f i with fork-join parallelism *)
+  let
+    fun tree (lo, hi) =
+      case hi - lo of
+        0 => Leaf (ForkJoin.alloc 0) (* empty range: zero-length leaf *)
+      | 1 => Leaf (f lo)
+      | n => let val mid = lo + n div 2 (* note: this n shadows the outer n *)
+                 val (l, r) = ForkJoin.par (fn _ => tree (lo, mid), fn _ => tree (mid, hi))
+             in Node (count l + count r, l, r)
+             end
+  in
+    tree (0, n)
+  end
+
+fun indexApp chunkSize (f : (int * 'a) -> unit) (t : 'a bucketTree) = (* apply f to (globalIndex, elem) for every element, parallel across subtrees *)
+  let
+    fun app offset t = (* offset = number of elements in leaves to the left of t *)
+      case t of
+        Leaf a => chunkedfor chunkSize (0, A.length a) (fn i => f (offset+i, sub (a, i)))
+      | Node (_, l, r) =>
+          (ForkJoin.par (fn _ => app offset l, fn _ => app (offset + count l) r);
+           ())
+  in
+    app 0 t
+  end
+
+fun compactFilter chunkSize (s : 'a option array) count = (* pack the `count` SOME elements of s into a fresh array, preserving order; caller must supply the exact count *)
+  let
+    val t = ForkJoin.alloc count
+    val _ = chunkedloop chunkSize (0, A.length s) 0 (fn (ti, si) => (* ti = next write position in t *)
+      case sub (s, si) of
+        NONE => ti
+      | SOME x => (update (t, ti, x); ti+1))
+  in
+    t
+  end
+
+fun serialHistogram eq hash s = (* sequential open-addressing dedup of slice s; body continues in the next hunk span *)
+  let
+    val n = AS.length s
+    val tn = Util.boundPow2 n (* table size: a power of two derived from n — TODO confirm Util.boundPow2's exact slack *)
+    val tmask = Word64.fromInt (tn - 1) (* power-of-two size lets us mask instead of mod *)
+    val t = Array.array (tn, NONE)
+
+    fun insert k = (* returns true iff k was not already present; linear probing with wraparound *)
+      let
+        fun probe i =
+          case sub (t, i) of
+            NONE => (update (t, i, SOME k); true) (* empty slot: claim it *)
+          | SOME k' =>
+              if eq (k', k) then
+                false (* duplicate: already in the table *)
+              else if i+1 = tn then
+                probe 0 (* wrap to the start of the table *)
+              else
+                probe (i+1)
+        val h = Word64.toInt (Word64.andb (hash k, tmask))
+      in
+        probe h
+      end
+
+    val (sa, slo, sn) = AS.base s (* unpack the slice to index the raw array directly *)
+    val shi = slo+sn
+    val count = chunkedloop 1024 (slo, shi) 0 (fn (c, i) => (* count distinct keys while filling t *)
+      if insert (sub (sa, i))
+      then c+1
+      else c)
+  in
+    compactFilter 1024 t count (* pack the distinct keys out of the sparse table *)
+  end
+
+
+(* val dedup : ('k * 'k -> bool)       equality check
+            -> ('k -> Word64.word)     first hash function (bucketing)
+            -> ('k -> Word64.word)     second hash function (within-bucket table)
+            -> 'k seq                  input (with duplicates)
+            -> 'k seq                  deduplicated (not sorted!)
+*)
+fun dedup eq hash hash' keys = (* two-level parallel dedup: counting-sort keys into buckets, then dedup each bucket serially *)
+  if AS.length keys = 0 then Seq.empty () else
+  let
+    val n = AS.length keys
+    val bucketBits = (* heuristic: fewer, larger buckets for small n *)
+      if n < Util.pow2 27
+      then (Util.log2 n - 7) div 2
+      else Util.log2 n - 17
+    val numBuckets = Util.pow2 (bucketBits + 1)
+    val bucketMask = Word64.fromInt (numBuckets - 1)
+    fun getBucket k = Word64.toInt (Word64.andb (hash k, bucketMask))
+    fun ithKeyBucket i = getBucket (Seq.nth keys i)
+    val (bucketed, offsets) = CountingSort.sort keys ithKeyBucket numBuckets (* group keys by bucket id *)
+    fun offset i = Seq.nth offsets i (* start index of bucket i within `bucketed` *)
+    val tree = bucketTree numBuckets (fn i => (* dedup each bucket independently, in parallel *)
+      let
+        val bucketks = Seq.subseq bucketed (offset i, offset (i+1) - offset i)
+      in
+        serialHistogram eq hash' bucketks
+      end)
+
+    val result = ForkJoin.alloc (count tree)
+    val _ = indexApp 1024 (fn (i, x) => update (result, i, x)) tree (* flatten the tree of per-bucket arrays *)
+  in
+    AS.full result
+  end
+
+(* ==========================================================================
+ * now the main bit *)
+
+structure CLA = CommandLineArgs
+
+fun usage () = (* print usage to stderr and exit with failure status *)
+  let
+    val msg =
+      "usage: dedup [--verbose] [--no-output] FILE\n"
+  in
+    TextIO.output (TextIO.stdErr, msg);
+    OS.Process.exit OS.Process.failure
+  end
+
+val filename =
+  case CLA.positional () of
+    [x]
=> x + | _ => usage () + +val beVerbose = CommandLineArgs.parseFlag "verbose" +val noOutput = CommandLineArgs.parseFlag "no-output" +val rep = case (Int.fromString (CLA.parseString "repeat" "1")) of + SOME(a) => a + | NONE => 1 + +fun toWord str = + let + (* just cap at 32 for long strings *) + val n = Int.min (32, String.size str) + fun c i = Word64.fromInt (Char.ord (String.sub (str, i))) + fun loop h i = + if i >= n then h + else loop (Word64.+ (Word64.* (h, 0w31), c i)) (i+1) + in + loop 0w7 0 + end + +fun hash1 str = Util.hash64 (toWord str) +fun hash2 str = Util.hash64 (toWord str + 0w1111111) + +fun vprint str = + if not beVerbose then () + else TextIO.output (TextIO.stdErr, str) + +val (contents, tm) = Util.getTime (fn _ => ReadFile.contentsSeq filename) +val _ = vprint ("read file in " ^ Time.fmt 4 tm ^ "s\n") +val (tokens, tm) = Util.getTime (fn _ => Tokenize.tokens Char.isSpace contents) +val _ = vprint ("tokenized in " ^ Time.fmt 4 tm ^ "s\n") + +fun dedupEx() = + dedup op= hash1 hash2 tokens + +val result = Benchmark.run "running dedup" dedupEx + +fun put c = TextIO.output1 (TextIO.stdOut, c) +val _ = + if noOutput then () + else + let + val (_, tm) = Util.getTime (fn _ => + ArraySlice.app (fn token => (print token; put #"\n")) result) + in + vprint ("output in " ^ Time.fmt 4 tm ^ "s\n") + end + +val _ = + if not beVerbose then () + else GCStats.report () diff --git a/tests/bench/delaunay-animation/DelaunayTriangulation.sml b/tests/bench/delaunay-animation/DelaunayTriangulation.sml new file mode 100644 index 000000000..bdfd5543e --- /dev/null +++ b/tests/bench/delaunay-animation/DelaunayTriangulation.sml @@ -0,0 +1,334 @@ +structure DelaunayTriangulation : +sig + type step = + { mesh: Topology2D.mesh + , updates: (Geometry2D.point * Topology2D.cavity) Seq.t + } + + val triangulate: Geometry2D.point Seq.t -> step Seq.t * Topology2D.mesh +end = +struct + + structure CLA = CommandLineArgs + + val showDelaunayRoundStats = CLA.parseFlag 
"show-delaunay-round-stats" + val maxBatchDiv = CLA.parseInt "max-batch-divisor" 10 + val reserveGrain = CLA.parseInt "reserve-grain" 20 + val ripAndTentGrain = CLA.parseInt "rip-and-tent-grain" 20 + val initialThreshold = CLA.parseInt "init-threshold" 10000 + val nnRebuildFactor = CLA.parseReal "nn-rebuild-factor" 10.0 + val batchSizeFrac = CLA.parseReal "batch-frac" 0.035 + + val reportTimes = false + + structure G = Geometry2D + structure T = Topology2D + structure NN = NearestNeighbors + structure A = Array + structure AS = ArraySlice + structure DSeq = DelayedSeq + + type vertex = T.vertex + type simplex = T.simplex + + type step = {mesh: T.mesh, updates: (G.point * T.cavity) Seq.t} + + val BOUNDARY_SIZE = 10 + + fun generateBoundary pts = + let + val p0 = Seq.nth pts 0 + val minCorner = Seq.reduce G.Point.minCoords p0 pts + val maxCorner = Seq.reduce G.Point.maxCoords p0 pts + val diagonal = G.Point.sub (maxCorner, minCorner) + val size = G.Vector.length diagonal + val stretch = 10.0 + val radius = stretch*size + val center = G.Vector.add (minCorner, G.Vector.scaleBy 0.5 diagonal) + + val vertexInfo = + { numVertices = Seq.length pts + BOUNDARY_SIZE + , numBoundaryVertices = BOUNDARY_SIZE + } + + val circleInfo = + {center=center, radius=radius} + in + T.initialMeshWithBoundaryCircle vertexInfo circleInfo + end + + + (* fun initialMesh pts = + let + val mesh = generateBoundary pts + + val totalNumVertices = T.numVertices boundaryMesh + Seq.length pts + val totalNumTriangles = T.numTriangles boundaryMesh + 2 * (Seq.length pts) + + val mesh = + T.new {numVertices = totalNumVertices, numTriangles = totalNumTriangles} + in + T.copyData {src = boundaryMesh, dst = mesh}; + + (mesh, T.numVertices boundaryMesh, T.numTriangles boundaryMesh) + end *) + + + fun writeMax a i x = + let + fun loop old = + if x <= old then () else + let + val old' = Concurrency.casArray (a, i) (old, x) + in + if old' = old then () + else loop old' + end + in + loop (A.sub (a, i)) + end + + 
+ fun dsAppend (s, t) = + DSeq.tabulate (fn i => + if i < Seq.length s then + Seq.nth s i + else + Seq.nth t (i - Seq.length s)) + (Seq.length s + Seq.length t) + + + type nn = (Geometry2D.point -> vertex) + + + fun triangulate inputPts = + let + val t0 = Time.now () + + val maxBatch = Util.ceilDiv (Seq.length inputPts) maxBatchDiv + val mesh = generateBoundary inputPts + val totalNumVertices = T.numVertices mesh + + val reserved = + SeqBasis.tabulate 10000 (0, totalNumVertices) (fn _ => ~1) + + val allVertices = Seq.tabulate (fn i => i) (Seq.length inputPts) + + fun nearestSimplex nn pt = + (T.triangleOfVertex mesh (nn pt), 0) + + fun singleInsert start (id, pt) = + let + val center = #1 (T.findPoint mesh pt start) + in + T.ripAndTentCavity mesh center (id, pt) (2*id, 2*id+1) + end + + fun singleInsertLookupStart nn id = + let + val pt = Seq.nth inputPts id + in + singleInsert (nearestSimplex nn pt) (id, pt) + end + + fun batchInsert (nn: nn) (vertsToInsert: vertex DSeq.t) = + let + val m = DSeq.length vertsToInsert + + val centers = + AS.full (SeqBasis.tabulate reserveGrain (0, m) (fn i => + let + val id = DSeq.nth vertsToInsert i + val pt = Seq.nth inputPts id + val center = + #1 (T.findPoint mesh pt (nearestSimplex nn pt)) + val _ = + T.loopPerimeter mesh center pt () + (fn (_, v) => writeMax reserved v id) + in + center + end)) + + val winnerFlags = + AS.full (SeqBasis.tabulate ripAndTentGrain (0, m) (fn i => + let + val id = DSeq.nth vertsToInsert i + val pt = Seq.nth inputPts id + val center = Seq.nth centers i + val isWinner = + T.loopPerimeter mesh center pt true + (fn (allMine, v) => + if A.sub (reserved, v) = id then + (A.update (reserved, v, ~1); allMine) + else + false) + in + isWinner + end)) + + val winnerCavities = + AS.full (SeqBasis.tabFilter ripAndTentGrain (0, m) (fn i => + if not (Seq.nth winnerFlags i) then NONE else + let + val id = DSeq.nth vertsToInsert i + val pt = Seq.nth inputPts id + val center = Seq.nth centers i + in + SOME 
(T.findCavity mesh center pt) + end)) + + val () = + ForkJoin.parfor ripAndTentGrain (0, m) (fn i => + let + val id = DSeq.nth vertsToInsert i + val pt = Seq.nth inputPts id + val center = Seq.nth centers i + val isWinner = Seq.nth winnerFlags i + in + if not isWinner then () else + (** rip-and-tent needs to create 1 new vertex and 2 new + * triangles. The new vertex is `id`, and the new triangles + * are respectively `2*id` and `2*id+1`. This ensures unique + * names. + *) + T.ripAndTentCavity mesh center (id, pt) (2*id, 2*id+1) + end) + + val {true=winners, false=losers} = + Split.split vertsToInsert (DSeq.fromArraySeq winnerFlags) + in + (winners, losers, winnerCavities) + end + + fun shouldRebuild numNextRebuild numDone = + let + val n = Seq.length inputPts + in + numDone >= numNextRebuild + andalso + numDone <= Real.floor (Real.fromInt n / nnRebuildFactor) + end + + fun buildNN (done: vertex Seq.t) = + let + val pts = Seq.map (Seq.nth inputPts) done + val tree = NN.makeTree 16 pts + in + (fn pt => Seq.nth done (NN.nearestNeighbor tree pt)) + end + + fun doRebuildNN numNextRebuild doneVertices = + let + val nn = buildNN doneVertices + val numNextRebuild = + Real.ceil (Real.fromInt numNextRebuild * nnRebuildFactor) + in + if not showDelaunayRoundStats then () else + print ("rebuilt nn; next rebuild at " ^ Int.toString numNextRebuild ^ "\n"); + + (nn, numNextRebuild) + end + + + (** start by inserting points one-by-one until mesh is large enough *) + fun smallLoop numDone (nn, numNextRebuild) remaining = + if numDone >= initialThreshold orelse Seq.length remaining = 0 then + (numDone, nn, numNextRebuild, remaining) + else + let + val (id, remaining) = + (Seq.nth remaining 0, Seq.drop remaining 1) + val _ = singleInsertLookupStart nn id + val numDone = numDone+1 + + val (nn, numNextRebuild) = + if not (shouldRebuild numNextRebuild numDone) then + (nn, numNextRebuild) + else + doRebuildNN numNextRebuild (Seq.take allVertices numDone) + in + smallLoop numDone (nn, 
numNextRebuild) remaining + end + + + fun loop numRounds steps (done, numDone) (nn, numNextRebuild) losers remaining = + if numDone = Seq.length inputPts then + (numRounds, Seq.fromList (List.rev steps)) + else + let + val startMesh = T.copy mesh + + val numRetry = Seq.length losers + val totalRemaining = numRetry + Seq.length remaining + (* val numDone = Seq.length inputPts - totalRemaining *) + val desiredSize = + Int.min (maxBatch, Int.min (totalRemaining, + 1 + Real.round (Real.fromInt numDone * batchSizeFrac))) + val numAdditional = + Int.max (0, Int.min (desiredSize - numRetry, Seq.length remaining)) + val thisBatchSize = numAdditional + numRetry + + val newcomers = Seq.take remaining numAdditional + val remaining = Seq.drop remaining numAdditional + val (winners, losers, winnerCavities) = + batchInsert nn (dsAppend (losers, newcomers)) + + val thisStep = + { mesh = startMesh + , updates = + Seq.zip (Seq.map (Seq.nth inputPts) winners, winnerCavities) + } + val steps = thisStep :: steps + + val numSucceeded = thisBatchSize - Seq.length losers + val numDone = numDone + numSucceeded + val done = winners :: done + + val rate = Real.fromInt numSucceeded / Real.fromInt thisBatchSize + val pcRate = Real.round (100.0 * rate) + + val _ = + if not showDelaunayRoundStats then () else + print ("round " ^ Int.toString numRounds + ^ "\tdone " ^ Int.toString numDone + ^ "\tremaining " ^ Int.toString totalRemaining + ^ "\tdesired " ^ Int.toString desiredSize + ^ "\tretrying " ^ Int.toString numRetry + ^ "\tfresh " ^ Int.toString numAdditional + ^ "\tsuccess-rate " ^ Int.toString pcRate ^ "%\n") + + val (done, (nn, numNextRebuild)) = + if not (shouldRebuild numNextRebuild numDone) then + (done, (nn, numNextRebuild)) + else + let + val done = Seq.flatten (Seq.fromList done) + in + ([done], doRebuildNN numNextRebuild done) + end + in + loop (numRounds+1) steps (done, numDone) (nn, numNextRebuild) losers remaining + end + + val start: simplex = (2 * Seq.length inputPts, 0) + 
val _ = singleInsert start (0, Seq.nth inputPts 0) + val done = Seq.singleton 0 + val remaining = Seq.drop allVertices 1 + val numDone = 1 + + val nn = buildNN done + val numNextRebuild = 10 + + val (numDone, nn, numNextRebuild, remaining) = + smallLoop numDone (nn, numNextRebuild) remaining + + val done = [Seq.take allVertices numDone] + + val (numRounds, steps) = + loop 0 [] (done, numDone) (nn, numNextRebuild) (Seq.empty()) remaining + + in + (steps, mesh) + end + +end diff --git a/tests/bench/delaunay-animation/Split.sml b/tests/bench/delaunay-animation/Split.sml new file mode 100644 index 000000000..61f0c0f22 --- /dev/null +++ b/tests/bench/delaunay-animation/Split.sml @@ -0,0 +1,73 @@ +structure Split: +sig + type 'a dseq + type 'a seq + val split: 'a dseq -> bool dseq -> {true: 'a seq, false: 'a seq} +end = +struct + + structure A = Array + structure AS = ArraySlice + + structure DS = DelayedSeq + type 'a dseq = 'a DS.t + type 'a seq = 'a AS.slice + + fun split s flags = + let + val n = DS.length s + val blockSize = 2000 + val numBlocks = 1 + (n-1) div blockSize + + (* the later scan(s) appears to be faster when split into two separate + * scans, rather than doing a single scan on tuples. 
*) + + (* val counts = Primitives.alloc numBlocks *) + val countl = ForkJoin.alloc numBlocks + val countr = ForkJoin.alloc numBlocks + + val _ = ForkJoin.parfor 1 (0, numBlocks) (fn b => + let + val lo = b * blockSize + val hi = Int.min (lo + blockSize, n) + fun loop (cl, cr) i = + if i >= hi then + (* A.update (counts, b, (cl, cr)) *) + (A.update (countl, b, cl); A.update (countr, b, cr)) + else if DS.nth flags i then + loop (cl+1, cr) (i+1) + else + loop (cl, cr+1) (i+1) + in + loop (0, 0) lo + end) + + (* val (offsets, (totl, totr)) = + Seq.scan (fn ((a,b),(c,d)) => (a+c,b+d)) (0,0) (ArraySlice.full counts) *) + val (offsetsl, totl) = Seq.scan op+ 0 (AS.full countl) + val (offsetsr, totr) = Seq.scan op+ 0 (AS.full countr) + + val left = ForkJoin.alloc totl + val right = ForkJoin.alloc totr + + val _ = ForkJoin.parfor 1 (0, numBlocks) (fn b => + let + val lo = b * blockSize + val hi = Int.min (lo + blockSize, n) + (* val (offsetl, offsetr) = Seq.nth offsets b *) + val offsetl = Seq.nth offsetsl b + val offsetr = Seq.nth offsetsr b + fun loop (cl, cr) i = + if i >= hi then () + else if DS.nth flags i then + (A.update (left, offsetl+cl, DS.nth s i); loop (cl+1, cr) (i+1)) + else + (A.update (right, offsetr+cr, DS.nth s i); loop (cl, cr+1) (i+1)) + in + loop (0, 0) lo + end) + in + {true = AS.full left, false = AS.full right} + end + +end diff --git a/tests/bench/delaunay-animation/delaunay-animation.mlb b/tests/bench/delaunay-animation/delaunay-animation.mlb new file mode 100644 index 000000000..67d2cfe4e --- /dev/null +++ b/tests/bench/delaunay-animation/delaunay-animation.mlb @@ -0,0 +1,4 @@ +../../mpllib/sources.$(COMPAT).mlb +Split.sml +DelaunayTriangulation.sml +main.sml diff --git a/tests/bench/delaunay-animation/main.sml b/tests/bench/delaunay-animation/main.sml new file mode 100644 index 000000000..bfbd4d3ab --- /dev/null +++ b/tests/bench/delaunay-animation/main.sml @@ -0,0 +1,222 @@ +structure CLA = CommandLineArgs +structure T = Topology2D +structure DT 
= DelaunayTriangulation + +val n = CLA.parseInt "n" (1000 * 1000) +val seed = CLA.parseInt "seed" 15210 +val filename = CLA.parseString "input" "" + +fun generateInputPoints () = + let + fun genReal i = + let + val x = Word64.fromInt (seed + i) + in + Real.fromInt (Word64.toInt (Word64.mod (Util.hash64 x, 0w1000000))) + / 1000000.0 + end + + fun genPoint i = (genReal (2*i), genReal (2*i + 1)) + + val (points, tm) = Util.getTime (fn _ => Seq.tabulate genPoint n) + val _ = print ("generated input in " ^ Time.fmt 4 tm ^ "s\n") + in + points + end + +fun parseInputFile () = + let + val (points, tm) = Util.getTime (fn _ => + ParseFile.readSequencePoint2d filename) + in + print ("parsed input points in " ^ Time.fmt 4 tm ^ "s\n"); + points + end + + +val input = + case filename of + "" => generateInputPoints () + | _ => parseInputFile () + + +val (steps, mesh) = Benchmark.run "delaunay" (fn _ => DT.triangulate input) +val _ = print ("num rounds " ^ Int.toString (Seq.length steps) ^ "\n") + +(* val _ = + print ("\n" ^ T.toString mesh ^ "\n") *) + + +(* ========================================================================== + * output result image + * only works if all input points are in range [0,1) + *) + +val filename = CLA.parseString "output" "" +val _ = + if filename <> "" then () + else ( print ("\nto see output, use -output, -resolution, and -fps arguments\n" ^ + "for example: delaunay -n 1000 -output result.gif -resolution 1000 -fps 10.0\n") + ; OS.Process.exit OS.Process.success + ) + +val t0 = Time.now () + +val resolution = CLA.parseInt "resolution" 1000 +val fadeEdgeMargin = CLA.parseInt "fade-edge" 0 +val borderEdgeMargin = CLA.parseInt "border-edge" 0 +val fps = CLA.parseReal "fps" 10.0 + +val niceBackground: Color.color = + let + val c = Real.fromInt 0xFD / 255.0 + in + {red = c, green = c, blue = c, alpha = 1.0} + end +val niceBackgroundPx: Color.pixel = Color.colorToPixel niceBackground + +(* val bg = #fdfdfd *) + +(* val image = MeshToImage.toImage 
{mesh = mesh, resolution = resolution} *) + +(* fun makeRed b = + let + val c = Word8.fromInt (Real.ceil (255.0 * (1.0 - b))) + val color = {blue = c, green = c, red = 0w255} + in + color + end + +fun makeGray b = + Color.colorToPixel ({red = 0.5, blue = 0.5, green = 0.5, alpha = b}) *) + +(* val niceGray = Color.hsv {h=0.0, s=0.0, v=0.88} *) +(* val niceRed = Color.colorToPixel + (Color.hsva {h = 0.0, s = 0.55, v = 0.95, a = 0.8}) *) +(* val colors = [Color.white, Color.black, Color.red] +val palette = + GIF.Palette.summarizeBySampling colors 103 + (fn i => + if i < 50 then + (* 50 shades of gray *) + makeGray (Real.fromInt (1 + i mod 50) / 50.0) + else + (* ... and 50 shades of red *) + makeRed (Real.fromInt (1 + i mod 50) / 50.0)) *) + +fun fadeEdges margin (img as {width, height, data}) = + let + fun clampedDistFromZero x = + Real.fromInt (margin - (Int.max (0, margin - x))) / Real.fromInt margin + fun hbrightness col = + Real.min (clampedDistFromZero col, clampedDistFromZero (width - col - 1)) + fun vbrightness row = + Real.min (clampedDistFromZero row, clampedDistFromZero (height - row - 1)) + fun brightness (row, col) = + Real.max (0.0, Real.min (hbrightness col, vbrightness row)) + fun update (row, col) = + if 0 <= row andalso row < height andalso 0 <= col andalso col < width then + let + val px = Seq.nth data (row*width + col) + val {red, green, blue, ...} = Color.pixelToColor px + val alpha = brightness (row, col) + val px' = Color.colorToPixel (Color.overlayColor + { fg = {red=red, green=green, blue=blue, alpha=alpha} + , bg = niceBackground + }) + in + ArraySlice.update (data, row*width + col, px') + end + else () + + fun updateBox {topleft=(row0, col0), botright=(row1, col1)} = + Util.for (row0, row1) (fn row => + Util.for (col0, col1) (fn col => + update (row, col) + )) + in + updateBox {topleft = (0, 0), botright = (margin, width)}; + updateBox {topleft = (margin, 0), botright = (height-margin, margin)}; + updateBox {topleft = (margin, width-margin), 
botright = (height-margin, width)}; + updateBox {topleft = (height-margin, 0), botright = (height, width)}; + + img + end + +fun drawBorder margin (img as {width, height, data}) = + let + fun update (row, col) = + if 0 <= row andalso row < height andalso 0 <= col andalso col < width then + let + in + ArraySlice.update (data, row*width + col, Color.black) + end + else () + + fun updateBox {topleft=(row0, col0), botright=(row1, col1)} = + Util.for (row0, row1) (fn row => + Util.for (col0, col1) (fn col => + update (row, col) + )) + in + updateBox {topleft = (0, 0), botright = (margin, width)}; + updateBox {topleft = (margin, 0), botright = (height-margin, margin)}; + updateBox {topleft = (margin, width-margin), botright = (height-margin, width)}; + updateBox {topleft = (height-margin, 0), botright = (height, width)}; + + img + end + +(* val numImages = 2 * (Seq.length steps) + 1 *) +(* val numImages = 3 *) +val numImages = Seq.length steps + 1 +val images = SeqBasis.tabulate 1 (0, numImages) (fn i => + let + val j = i (*div 2*) + val mesh = + if j < Seq.length steps then #mesh (Seq.nth steps j) else mesh + val cavs = + if (*i mod 2 = 1 andalso*) j < Seq.length steps then + SOME (#updates (Seq.nth steps j)) + else + NONE + val img = + MeshToImage.toImage {mesh = mesh, resolution = resolution, cavities = cavs, background = niceBackground} + in + if fadeEdgeMargin > 0 then + fadeEdges fadeEdgeMargin img + else if borderEdgeMargin > 0 then + drawBorder borderEdgeMargin img + else + img + end) + +val t1 = Time.now () + +val _ = print ("generated frames in " ^ Time.fmt 4 (Time.- (t1, t0)) ^ "s\n") + +val _ = print ("writing to " ^ filename ^"...\n") + +val palette = + GIF.Palette.summarizeBySampling [niceBackgroundPx, Color.black, Color.red] 256 + (fn i => + let + val j1 = Util.hash (2*i) mod (Array.length images) + val img = Array.sub (images, j1) + val j2 = Util.hash (2*i + 1) mod (Seq.length (#data img)) + in + Seq.nth (#data img) j2 + end) + +val msBetween = Real.round 
((1.0 / fps) * 100.0) +val (_, tm) = Util.getTime (fn _ => + GIF.writeMany filename msBetween palette + { width = resolution + , height = resolution + , numImages = Array.length images + , getImage = fn i => #remap palette (Array.sub (images, i)) + }) +val _ = print ("wrote all frames in " ^ Time.fmt 4 tm ^ "s\n") + +(* val (_, tm) = Util.getTime (fn _ => PPM.write "first.ppm" (Array.sub (images, 1))) +val _ = print ("wrote to " ^ "first.ppm" ^ " in " ^ Time.fmt 4 tm ^ "s\n") *) diff --git a/tests/bench/delaunay-animation/test.sml b/tests/bench/delaunay-animation/test.sml new file mode 100644 index 000000000..a82638d33 --- /dev/null +++ b/tests/bench/delaunay-animation/test.sml @@ -0,0 +1,125 @@ +structure CLA = CommandLineArgs +structure T = Topology2D +structure DT = DelaunayTriangulation + +val (filename, testPtStr) = + case CLA.positional () of + [x, y] => (x, y) + | _ => Util.die "usage: ./foo " + +val testType = CLA.parseString "test" "split" + +val testPoint = + case List.mapPartial Real.fromString (String.tokens (fn c => c = #",") testPtStr) of + [x,y] => (x,y) + | _ => Util.die ("bad test point") + +val (mesh, tm) = Util.getTime (fn _ => T.parseFile filename) +val _ = print ("num vertices: " ^ Int.toString (T.numVertices mesh) ^ "\n") +val _ = print ("num edges: " ^ Int.toString (T.numTriangles mesh) ^ "\n") + +val _ = print ("\n" ^ T.toString mesh ^ "\n\n") + +val start: T.simplex = (0, 0) + +fun simpToString (t, i) = + "triangle " ^ Int.toString t ^ " orientation " ^ Int.toString i ^ ": " ^ + let + val T.Tri {vertices=(a,b,c), ...} = T.tdata mesh t + in + String.concatWith " " (List.map Int.toString + (if i = 0 then [a,b,c] + else if i = 1 then [b,c,a] + else [c,a,b])) + end + +fun triToString t = + "triangle " ^ Int.toString t ^ ": " ^ + let + val T.Tri {vertices=(a,b,c), ...} = T.tdata mesh t + in + String.concatWith " " (List.map Int.toString [a,b,c]) + end + +val _ = Util.for (0, T.numVertices mesh) (fn v => + let + val s = T.find mesh v start + in + 
print ("found " ^ Int.toString v ^ ": " ^ simpToString s ^ "\n") + end) + + +(* ======================================================================== *) + +fun testSplit () = + let + val _ = print ("===================================\nTESTING SPLIT\n") + + val ((center, tris), verts) = T.findCavityAndPerimeter mesh start testPoint + val _ = + print ("CAVITY CENTER IS:\n " ^ triToString center ^ "\n") + val _ = + print ("CAVITY MEMBERS ARE:\n" + ^ String.concatWith "\n" (List.map (fn x => " " ^ simpToString x) tris) ^ "\n") + val _ = + print ("CAVITY PERIMETER VERTICES ARE:\n " + ^ String.concatWith " " (List.map (fn x => Int.toString x) verts) ^ "\n") + + + val mesh' = T.split mesh center testPoint + val _ = + print ("===================================\nAFTER SPLIT:\n" ^ T.toString mesh' ^ "\n") + in + () + end + +(* ======================================================================== *) + +fun testFlip () = + let + val _ = print ("===================================\nTESTING FLIP\n") + + val simp = T.findPoint mesh testPoint start + val _ = + print ("SIMPLEX CONTAINING POINT:\n " ^ simpToString simp ^ "\n") + + val mesh' = T.flip mesh simp + val _ = + print ("===================================\nAFTER FLIP:\n" ^ T.toString mesh' ^ "\n") + in + () + end + +(* ======================================================================== *) + +fun testRipAndTent () = + let + val _ = print ("===================================\nTESTING RIP-AND-TENT\n") + + val (cavity as (center, tris), verts) = + T.findCavityAndPerimeter mesh start testPoint + + val _ = + print ("CAVITY CENTER IS:\n " ^ triToString center ^ "\n") + val _ = + print ("CAVITY MEMBERS ARE:\n" + ^ String.concatWith "\n" (List.map (fn x => " " ^ simpToString x) tris) ^ "\n") + val _ = + print ("CAVITY PERIMETER VERTICES ARE:\n " + ^ String.concatWith " " (List.map (fn x => Int.toString x) verts) ^ "\n") + + val mesh' = T.ripAndTentOne (cavity, testPoint) mesh + val _ = + print 
("===================================\nAFTER FLIP:\n" ^ T.toString mesh' ^ "\n") + in + () + end + +(* ======================================================================== *) + +val _ = + case testType of + "split" => testSplit () + | "flip" => testFlip () + | "rip-and-tent" => testRipAndTent () + | _ => Util.die ("unknown test type") diff --git a/tests/bench/delaunay-top-down/DelaunayTriangulationTopDown.sml b/tests/bench/delaunay-top-down/DelaunayTriangulationTopDown.sml new file mode 100644 index 000000000..5d56035bd --- /dev/null +++ b/tests/bench/delaunay-top-down/DelaunayTriangulationTopDown.sml @@ -0,0 +1,191 @@ +functor DelaunayTriangulationTopDown (structure R: REAL structure I: INTEGER) = +struct + + fun par3 (f1, f2, f3) = + let val (a, (b, c)) = ForkJoin.par (f1, fn _ => ForkJoin.par (f2, f3)) + in (a, b, c) + end + + type id = I.int + + fun r (x: real) : R.real = + R.fromLarge IEEEReal.TO_NEAREST (Real.toLarge x) + + fun ii x = I.fromInt x + + structure Point = + struct + datatype t = T of {id: id, x: R.real, y: R.real} + + fun x (T p) = #x p + fun y (T p) = #y p + fun id (T p) = #id p + fun id_cmp (p1, p2) = + I.compare (id p1, id p2) + fun id_less_than (p1, p2) = + I.< (id p1, id p2) + + type vec = LargeReal.real * LargeReal.real * LargeReal.real + + fun cross ((x1, y1, z1): vec, (x2, y2, z2): vec) = + (y1 * z2 - z1 * y2, z1 * x2 - x1 * z2, x1 * y2 - y1 * x2) + + fun dot ((x1, y1, z1): vec, (x2, y2, z2): vec) = + x1 * x2 + y1 * y2 + z1 * z2 + + fun project (d: t) (p: t) = + let + val px = R.toLarge (R.- (x p, x d)) + val py = R.toLarge (R.- (y p, y d)) + in + (px, py, LargeReal.+ (LargeReal.* (px, px), LargeReal.* (py, py))) + end + + fun in_circle (a: t, b: t, d: t) = + let val cp = cross (project d a, project d b) + in fn c => dot (cp, project d c) > 0.0 + end + end + + type tri = id * id * id + val dummy_tri = (ii ~1, ii ~1, ii ~1) + + type edge = id * id + val dummy_edge = (ii ~1, ii ~1) + + fun itow64 i = + Word64.fromLarge 
(LargeWord.fromInt (I.toInt i)) + + structure EH = + HashTable + (struct + type t = edge + val equal = op= + val default = dummy_edge + fun hash (i1, i2) = + Word64.toIntX (Word64.xorb + (Util.hash64 (itow64 i1), Util.hash64_2 (itow64 i2))) + end) + + structure TH = + HashTable + (struct + type t = tri + val equal = op= + val default = dummy_tri + fun hash (i1, i2, i3) = + Word64.toIntX + (Word64.xorb + ( Util.hash64 (itow64 i1) + , Word64.xorb (Util.hash64_2 (itow64 i2), Util.hash64_2 + (Util.hash64_2 (itow64 i3))) + )) + end) + + + type triangle = {tri: tri, conflicts: Point.t Seq.t} + + + fun filter_points points (t1: triangle, t2: triangle, t: tri) = + let + val a = Merge.merge Point.id_cmp (#conflicts t1, #conflicts t2) + val an = Seq.length a + + fun lookup_point id = + Seq.nth points (I.toInt id) + + val is_in_circle = Point.in_circle + (lookup_point (#1 t), lookup_point (#2 t), lookup_point (#3 t)) + + fun same_id i j = + Point.id_cmp (Seq.nth a i, Seq.nth a j) = EQUAL + + fun keep i = + (i <> 0) andalso not (same_id i (i - 1)) + andalso + ((i + 1 < an andalso same_id i (i + 1)) + orelse is_in_circle (Seq.nth a i)) + in + ArraySlice.full (SeqBasis.filter 2000 (0, an) (Seq.nth a) keep) + end + + + (* ~2/3 max load with a bit of slop *) + fun sloppy_capacity max_expected_elems = + 100 + (max_expected_elems * 3) div 2 + + + fun triangulate (points: Point.t Seq.t) = + let + val n = Seq.length points + + fun earliest ({tri, conflicts}: triangle) = + if Seq.length conflicts = 0 then ii n + else Point.id (Seq.nth conflicts 0) + + val edges = EH.make + { default = {tri = dummy_tri, conflicts = Seq.fromList []} + , capacity = sloppy_capacity (6 * n) + } + + val mesh = TH.make {default = (), capacity = sloppy_capacity (2 * n)} + + (* enclosing triangle *) + val p0 = Point.T {id = ii n, x = r 0.0, y = r 100.0} + val p1 = Point.T {id = ii (n + 1), x = r 100.0, y = r ~100.0} + val p2 = Point.T {id = ii (n + 2), x = r ~100.0, y = r ~100.0} + val enclosing_t = + {tri = 
(ii n, ii (n + 1), ii (n + 2)), conflicts = points} + + val all_points = Seq.append (points, Seq.fromList [p0, p1, p2]) + + fun process_edge (t1: triangle, e: edge, t2: triangle) = + if Seq.length (#conflicts t1) = 0 andalso Seq.length (#conflicts t2) = 0 then + (TH.insert mesh (#tri t1, ()); TH.insert mesh (#tri t2, ()); ()) + else if earliest t1 = earliest t2 then + () + else + let + val (t1, e, t2) = + if I.<= (earliest t1, earliest t2) then (t1, e, t2) + else (t2, (#2 e, #1 e), t1) + + val p = earliest t1 + val t = (#1 e, #2 e, p) + val t1 = {tri = t, conflicts = filter_points all_points (t1, t2, t)} + in + par3 + ( fn _ => check_edge ((p, #1 e), t1) + , fn _ => check_edge ((#2 e, p), t1) + , fn _ => process_edge (t1, e, t2) + ); + () + end + + + and check_edge (e: edge, tp: triangle) = + let + val key = if I.< (#1 e, #2 e) then e else (#2 e, #1 e) + in + if EH.insert edges (key, tp) then + () + else + case EH.remove edges key of + NONE => raise Fail "impossible?" + | SOME tt => process_edge (tp, e, tt) + end + + + val t = enclosing_t + val te = {tri = dummy_tri, conflicts = Seq.empty ()} + val _ = + par3 + ( fn _ => process_edge (t, (ii n, ii (n + 1)), te) + , fn _ => process_edge (t, (ii (n + 1), ii (n + 2)), te) + , fn _ => process_edge (t, (ii (n + 2), ii n), te) + ) + in + TH.keys mesh + end + +end diff --git a/tests/bench/delaunay-top-down/HashTable.sml b/tests/bench/delaunay-top-down/HashTable.sml new file mode 100644 index 000000000..26c162d1b --- /dev/null +++ b/tests/bench/delaunay-top-down/HashTable.sml @@ -0,0 +1,190 @@ +functor HashTable + (K: + sig + type t + val equal: t * t -> bool + val default: t + val hash: t -> int + end) = +struct + + structure Status = + struct + type t = Word8.word + val empty: t = 0w0 + val full: t = 0w1 + val locked: t = 0w2 + val tomb: t = 0w3 + end + + + datatype 'a entry = E of {status: Status.t ref, key: K.t ref, value: 'a ref} + + datatype 'a t = T of 'a entry array + type 'a table = 'a t + + exception Full + + fun 
make {default: 'a, capacity} : 'a table = + let + fun default_entry _ = + E {status = ref Status.empty, key = ref K.default, value = ref default} + val entries = SeqBasis.tabulate 5000 (0, capacity) default_entry + in + T entries + end + + + fun capacity (T entries) = Array.length entries + + fun status (T entries) i = + let val E {status, ...} = Array.sub (entries, i) + in status + end + + + fun size (t as T entries) = + SeqBasis.reduce 5000 op+ 0 (0, Array.length entries) (fn i => + if !(status t i) = Status.full then 0 else 1) + + + fun keys (t as T entries) = + let + fun key_at i = + let val E {key, ...} = Array.sub (entries, i) + in !key + end + + fun keep_at i = + !(status t i) = Status.full + in + ArraySlice.full + (SeqBasis.filter 2000 (0, Array.length entries) key_at keep_at) + end + + + (* fun unsafe_view_contents (tab as T {keys, values}) = + let + val capacity = Array.length keys + + fun elem i = + let + val k = Array.sub (keys, i) + in + if K.equal (k, K.default) then NONE + else SOME (k, Array.sub (values, i)) + end + in + DelayedSeq.tabulate elem (Array.length keys) + end *) + + + fun bcas (r, old, new) = + MLton.eq (old, Concurrency.cas r (old, new)) + + + (* fun atomic_combine_with (f: 'a * 'a -> 'a) (arr: 'a array, i) (x: 'a) = + let + fun loop current = + let + val desired = f (current, x) + in + if MLton.eq (desired, current) then + () + else + let + val current' = + MLton.Parallel.arrayCompareAndSwap (arr, i) (current, desired) + in + if MLton.eq (current', current) then () else loop current' + end + end + in + loop (Array.sub (arr, i)) + end *) + + + fun insert (t as T entries) (k, v) = + let + val n = Array.length entries + val tolerance = n + + fun claim_slot_at i expected = + bcas (status t i, expected, Status.locked) + + fun put_kv_at i = + let val E {status, key, value} = Array.sub (entries, i) + in key := k; value := v; bcas (status, Status.locked, Status.full) + end + + fun put_v_at i = + let val E {status, key, value} = Array.sub 
(entries, i) + in value := v; bcas (status, Status.locked, Status.full) + end + + fun loop i probes = + if probes >= tolerance then + raise Full + else if i >= n then + loop 0 probes + else + let + val e as E {status, key, value} = Array.sub (entries, i) + val s = !status + in + if s = Status.full orelse s = Status.tomb then + if K.equal (!key, k) then + if s = Status.full then false + else if claim_slot_at i s then (put_v_at i; true) + else loop i probes + else + loop (i + 1) (probes + 1) + else if s = Status.empty then + if claim_slot_at i s then (put_kv_at i; true) else loop i probes + else + loop (i + 1) (probes + 1) + end + + val start = (K.hash k) mod (Array.length entries) + in + loop start 0 + end + + + fun remove (t as T entries) k = + let + val n = Array.length entries + val tolerance = n + + fun release_slot_at i = + bcas (status t i, Status.full, Status.tomb) + + fun loop i probes = + if probes >= tolerance then + raise Full + else if i >= n then + loop 0 probes + else + let + val e as E {status, key, value} = Array.sub (entries, i) + val s = !status + in + if s = Status.empty orelse s = Status.locked then + NONE + else if K.equal (!key, k) then + if s = Status.full then + let val v = !value + in release_slot_at i; SOME v + end + else + NONE + else + loop (i + 1) (probes + 1) + end + + val start = (K.hash k) mod (Array.length entries) + in + loop start 0 + end + +end diff --git a/tests/bench/delaunay-top-down/README.md b/tests/bench/delaunay-top-down/README.md new file mode 100644 index 000000000..404e93dcb --- /dev/null +++ b/tests/bench/delaunay-top-down/README.md @@ -0,0 +1,28 @@ +This "top down" Delaunay implementation is a direct translation of this file: +https://github.com/cmuparlay/parlaylib/blob/e1f1dc0ccf930492a2723f7fbef8510d35bf57f5/examples/delaunay.h + +It is interesting algorithmically, but not +especially fast. I would be curious to see how well it performs in comparison +to the original C++ code. 
+ +In comparison to the [MaPLe PBBS `delaunay`](https://github.com/MPLLang/parallel-ml-bench/tree/main/mpl/bench/delaunay), it is significantly slower +and less space efficient. I believe this is partly due to the use of two +(somewhat unoptimized) global hash tables. One stores the mesh of triangles, +and the other stores a set of outstanding edges that may need to be processed. +Note that the use of hash tables in this way results in memory entanglement, +which is another source of overhead. Disentangling the hash table access +patterns would be interesting to look into; it's not immediately clear to me if +this is possible. + +Some basic tests have been run but more testing is required. + +```bash +[parallel-ml-bench/mpl]$ make delaunay-top-down.mpl.bin +[parallel-ml-bench/mpl]$ bin/delaunay-top-down.mpl.bin @mpl procs 4 -- -n 10000 -output result.ppm -resolution 2000 + +# ... see image in result.ppm +``` + +**NOTE (9/2/25)**: Currently, there is an unresolved race condition which +sometimes results in portions of the triangulation being dropped. It may be a +bug in the hash table implementation. 
\ No newline at end of file diff --git a/tests/bench/delaunay-top-down/delaunay-top-down.mlb b/tests/bench/delaunay-top-down/delaunay-top-down.mlb new file mode 100644 index 000000000..abed8a2d9 --- /dev/null +++ b/tests/bench/delaunay-top-down/delaunay-top-down.mlb @@ -0,0 +1,4 @@ +../../mpllib/sources.$(COMPAT).mlb +HashTable.sml +DelaunayTriangulationTopDown.sml +main.sml diff --git a/tests/bench/delaunay-top-down/main.sml b/tests/bench/delaunay-top-down/main.sml new file mode 100644 index 000000000..779d173fd --- /dev/null +++ b/tests/bench/delaunay-top-down/main.sml @@ -0,0 +1,231 @@ +structure CLA = CommandLineArgs +structure T = Topology2D +structure R = Real32 +structure I = Int32 +structure DT = DelaunayTriangulationTopDown (structure R = R structure I = I) + +val n = CLA.parseInt "n" (1000 * 1000) +val seed = CLA.parseInt "seed" 15210 +val filename = CLA.parseString "input" "" + +fun generateInputPoints () = + let + fun genReal i = + let + val x = Word64.fromInt (seed + i) + in + Real.fromInt (Word64.toInt (Word64.mod (Util.hash64 x, 0w1000000))) + / 1000000.0 + end + + fun genPoint i = + (genReal (2 * i), genReal (2 * i + 1)) + + val (points, tm) = Util.getTime (fn _ => Seq.tabulate genPoint n) + val _ = print ("generated input in " ^ Time.fmt 4 tm ^ "s\n") + in + points + end + +(* This silly thing helps ensure good placement, by + * forcing points to be reallocated more adjacent. + * It's a no-op, but gives us as much as 2x time + * improvement (!) 
+ *) +fun swap pts = + Seq.map (fn (x, y) => (y, x)) pts +fun compactify pts = + swap (swap pts) + +fun parseInputFile () = + let + val (points, tm) = Util.getTime (fn _ => + compactify (ParseFile.readSequencePoint2d filename)) + in + print ("parsed input points in " ^ Time.fmt 4 tm ^ "s\n"); + points + end + + +val points = + case filename of + "" => generateInputPoints () + | _ => parseInputFile () + +val input = + Seq.mapIdx + (fn (i, (x, y)) => DT.Point.T {id = I.fromInt i, x = DT.r x, y = DT.r y}) + points + +val mesh = Benchmark.run "delaunay-top-down" (fn _ => DT.triangulate input) +val _ = print ("num triangles " ^ Int.toString (Seq.length mesh) ^ "\n") + +(* ===================================================================== *) + +val filename = CLA.parseString "output" "" +val _ = + if filename <> "" then + () + else + ( print + ("\nto see output, use -output and -resolution arguments\n" + ^ + "for example: delaunay-top-down -n 1000 -output result.ppm -resolution 1000\n") + ; OS.Process.exit OS.Process.success + ) + +val t0 = Time.now () + +val resolution = CLA.parseInt "resolution" 1000 +val width = resolution +val height = resolution + +val image = + { width = width + , height = height + , data = Seq.tabulate (fn _ => Color.white) (width * height) + } + +fun set (i, j) x = + if 0 <= i andalso i < height andalso 0 <= j andalso j < width then + ArraySlice.update (#data image, i * width + j, x) + else + () + +fun setxy (x, y) z = + set (resolution - y - 1, x) z + +val r = Real.fromInt resolution +fun px x = + Real.floor (x * r + 0.5) + +fun ipart x = Real.floor x +fun fpart x = x - Real.realFloor x +fun rfpart x = 1.0 - fpart x + +(** input points should be in range [0,1] *) +fun aaLine (x0, y0) (x1, y1) = + if x1 < x0 then + aaLine (x1, y1) (x0, y0) + else + let + (** scale to resolution *) + val (x0, y0, x1, y1) = + (r * x0 + 0.5, r * y0 + 0.5, r * x1 + 0.5, r * y1 + 0.5) + + fun plot (x, y, c) = + let + val c = Word8.fromInt (Real.ceil (255.0 * (1.0 - 
c))) + val color = {blue = c, green = c, red = 0w255} + in + (* print (Int.toString x ^ " " ^ Int.toString y ^ "\n"); *) + setxy (x, y) color + end + + val dx = x1 - x0 + val dy = y1 - y0 + val yxSlope = dy / dx + val xySlope = dx / dy + (* val xhop = Real.fromInt (Real.sign dx) *) + (* val yhop = Real.fromInt (Real.sign dy) *) + + (* fun y x = x0 + (x-x0) * slope *) + + (** (x,y) = current point on the line *) + fun normalLoop (x, y) = + if x > x1 then + () + else + ( plot (ipart x, ipart y, rfpart y) + ; plot (ipart x, ipart y + 1, fpart y) + ; normalLoop (x + 1.0, y + yxSlope) + ) + + fun steepUpLoop (x, y) = + if y > y1 then + () + else + ( plot (ipart x, ipart y, rfpart x) + ; plot (ipart x + 1, ipart y, fpart x) + ; steepUpLoop (x + xySlope, y + 1.0) + ) + + fun steepDownLoop (x, y) = + if y < y1 then + () + else + ( plot (ipart x, ipart y, rfpart x) + ; plot (ipart x + 1, ipart y, fpart x) + ; steepDownLoop (x - xySlope, y - 1.0) + ) + in + if Real.abs dx > Real.abs dy then normalLoop (x0, y0) + else if y1 > y0 then steepUpLoop (x0, y0) + else steepDownLoop (x0, y0) + end + +(* draw all triangle edges as straight red lines *) +val _ = ForkJoin.parfor 1000 (0, Seq.length mesh) (fn i => + let + (* val _ = print ("triangle number " ^ Int.toString i ^ "\n") *) + (** cut off anything that is outside the image (not important other than + * a little faster this way). 
+ *) + fun constrain (x, y) = + (Real.min (1.0, Real.max (0.0, x)), Real.min (1.0, Real.max (0.0, y))) + + (* fun vpos v = constrain (T.vdata mesh v) + fun doLineIf b (u, v) = + if b then aaLine (vpos u) (vpos v) else () *) + + fun doLineIf _ (id1, id2) = + let + val id1 = I.toInt id1 + val id2 = I.toInt id2 + in + if + id1 >= Seq.length points orelse id2 >= Seq.length points + orelse id1 < 0 orelse id2 < 0 + then + () + else + aaLine (constrain (Seq.nth points id1)) (constrain + (Seq.nth points id2)) + end + + val (u, v, w) = Seq.nth mesh i + (* val () = print + ("drawing " ^ I.toString u ^ " " ^ I.toString v ^ " " ^ I.toString w + ^ "\n") *) + in + (** TODO: ensure each line segment is only drawn once? There is overlap + * here with adjacent triangles. + *) + doLineIf true (w, u); + doLineIf true (u, v); + doLineIf true (v, w) + end + handle e => (print ("error at " ^ Int.toString i ^ "\n"); raise e)) + +val _ = print ("drew all triangles\n") + +(* mark input points as a pixel *) +val _ = ForkJoin.parfor 10000 (0, Seq.length points) (fn i => + let + val (x, y) = Seq.nth points i + val (x, y) = (px x, px y) + fun b spot = setxy spot Color.black + in + b (x - 1, y); + b (x, y - 1); + b (x, y); + b (x, y + 1); + b (x + 1, y) + end) + +val t1 = Time.now () + +val _ = print ("generated image in " ^ Time.fmt 4 (Time.- (t1, t0)) ^ "s\n") + +val (_, tm) = Util.getTime (fn _ => PPM.write filename image) +val _ = print ("wrote to " ^ filename ^ " in " ^ Time.fmt 4 tm ^ "s\n") diff --git a/tests/bench/delaunay/DelaunayTriangulation.sml b/tests/bench/delaunay/DelaunayTriangulation.sml new file mode 100644 index 000000000..e02d6f011 --- /dev/null +++ b/tests/bench/delaunay/DelaunayTriangulation.sml @@ -0,0 +1,368 @@ +structure DelaunayTriangulation : +sig + val triangulate: Geometry2D.point Seq.t -> int * Topology2D.mesh +end = +struct + + structure CLA = CommandLineArgs + + val showDelaunayRoundStats = CLA.parseFlag "show-delaunay-round-stats" + val maxBatchDiv = 
CLA.parseInt "max-batch-divisor" 10 + val reserveGrain = CLA.parseInt "reserve-grain" 20 + val ripAndTentGrain = CLA.parseInt "rip-and-tent-grain" 20 + val initialThreshold = CLA.parseInt "init-threshold" 10000 + val nnRebuildFactor = CLA.parseReal "nn-rebuild-factor" 10.0 + val batchSizeFrac = CLA.parseReal "batch-frac" 0.035 + + val reportTimes = CLA.parseFlag "report-delaunay-times" + + structure G = Geometry2D + structure T = Topology2D + structure NN = NearestNeighbors + structure A = Array + structure AS = ArraySlice + structure DSeq = DelayedSeq + + type vertex = T.vertex + type simplex = T.simplex + + val BOUNDARY_SIZE = 10 + + fun generateBoundary pts = + let + val p0 = Seq.nth pts 0 + val minCorner = Seq.reduce G.Point.minCoords p0 pts + val maxCorner = Seq.reduce G.Point.maxCoords p0 pts + val diagonal = G.Point.sub (maxCorner, minCorner) + val size = G.Vector.length diagonal + val stretch = 10.0 + val radius = stretch*size + val center = G.Vector.add (minCorner, G.Vector.scaleBy 0.5 diagonal) + + val vertexInfo = + { numVertices = Seq.length pts + BOUNDARY_SIZE + , numBoundaryVertices = BOUNDARY_SIZE + } + + val circleInfo = + {center=center, radius=radius} + in + T.initialMeshWithBoundaryCircle vertexInfo circleInfo + end + + + (* fun initialMesh pts = + let + val mesh = generateBoundary pts + + val totalNumVertices = T.numVertices boundaryMesh + Seq.length pts + val totalNumTriangles = T.numTriangles boundaryMesh + 2 * (Seq.length pts) + + val mesh = + T.new {numVertices = totalNumVertices, numTriangles = totalNumTriangles} + in + T.copyData {src = boundaryMesh, dst = mesh}; + + (mesh, T.numVertices boundaryMesh, T.numTriangles boundaryMesh) + end *) + + + fun writeMax a i x = + let + fun loop old = + if x <= old then () else + let + val old' = Concurrency.casArray (a, i) (old, x) + in + if old' = old then () + else loop old' + end + in + loop (A.sub (a, i)) + end + + + fun dsAppend (s, t) = + DSeq.tabulate (fn i => + if i < Seq.length s then + 
Seq.nth s i + else + Seq.nth t (i - Seq.length s)) + (Seq.length s + Seq.length t) + + + type timers = (string * real ref) list + val timers = + [ ("reserve", ref 0.0) + , ("rip-and-tent", ref 0.0) + , ("split", ref 0.0) + , ("nn-rebuild", ref 0.0) + ] + + fun clearTimers () = + List.app (fn (_, t) => t := 0.0) timers + + fun updateTimer timers (name, tm) = + case timers of + [] => () + | (s, t) :: rest => + if name = s then + t := (!t + Time.toReal tm) + else + updateTimer rest (name, tm) + val updateTimer = updateTimer timers + + fun timed name f = + if not reportTimes then f () else + let + val (result, tm) = Util.getTime f + in + updateTimer (name, tm); + result + end + + fun getTimer timers name = + case timers of + [] => 0.0 + | (s, t) :: rest => if s = name then !t else getTimer rest name + val getTimer = getTimer timers + + fun timerSum () = + List.foldl (fn ((_, tm), t) => t + !tm) 0.0 timers + + + type nn = (Geometry2D.point -> vertex) + + + fun triangulate inputPts = + let + val t0 = Time.now () + + val maxBatch = Util.ceilDiv (Seq.length inputPts) maxBatchDiv + val mesh = generateBoundary inputPts + val totalNumVertices = T.numVertices mesh + + val reserved = + SeqBasis.tabulate 10000 (0, totalNumVertices) (fn _ => ~1) + + val allVertices = Seq.tabulate (fn i => i) (Seq.length inputPts) + + fun nearestSimplex nn pt = + (T.triangleOfVertex mesh (nn pt), 0) + + fun singleInsert start (id, pt) = + let + val center = #1 (T.findPoint mesh pt start) + in + T.ripAndTentCavity mesh center (id, pt) (2*id, 2*id+1) + end + + fun singleInsertLookupStart nn id = + let + val pt = Seq.nth inputPts id + in + singleInsert (nearestSimplex nn pt) (id, pt) + end + + fun batchInsert (nn: nn) (vertsToInsert: vertex DSeq.t) = + let + val m = DSeq.length vertsToInsert + + val centers = timed "reserve" (fn _ => + AS.full (SeqBasis.tabulate reserveGrain (0, m) (fn i => + let + val id = DSeq.nth vertsToInsert i + val pt = Seq.nth inputPts id + val center = + #1 (T.findPoint mesh 
pt (nearestSimplex nn pt)) + val _ = + T.loopPerimeter mesh center pt () + (fn (_, v) => writeMax reserved v id) + in + center + end))) + + val winnerFlags = timed "rip-and-tent" (fn _ => + AS.full (SeqBasis.tabulate ripAndTentGrain (0, m) (fn i => + let + val id = DSeq.nth vertsToInsert i + val pt = Seq.nth inputPts id + val center = Seq.nth centers i + val isWinner = + T.loopPerimeter mesh center pt true + (fn (allMine, v) => + if A.sub (reserved, v) = id then + (A.update (reserved, v, ~1); allMine) + else + false) + in + if not isWinner then () else + (** rip-and-tent needs to create 1 new vertex and 2 new + * triangles. The new vertex is `id`, and the new triangles + * are respectively `2*id` and `2*id+1`. This ensures unique + * names. + *) + T.ripAndTentCavity mesh center (id, pt) (2*id, 2*id+1); + + isWinner + end))) + + val {true=winners, false=losers} = timed "split" (fn _ => + Split.split vertsToInsert (DSeq.fromArraySeq winnerFlags)) + in + (winners, losers) + end + + fun shouldRebuild numNextRebuild numDone = + let + val n = Seq.length inputPts + in + numDone >= numNextRebuild + andalso + numDone <= Real.floor (Real.fromInt n / nnRebuildFactor) + end + + fun buildNN (done: vertex Seq.t) = + let + val pts = Seq.map (Seq.nth inputPts) done + val tree = NN.makeTree 16 pts + in + (fn pt => Seq.nth done (NN.nearestNeighbor tree pt)) + end + + fun doRebuildNN numNextRebuild doneVertices = + let + val nn = timed "nn-rebuild" (fn _ => buildNN doneVertices) + val numNextRebuild = + Real.ceil (Real.fromInt numNextRebuild * nnRebuildFactor) + in + if not showDelaunayRoundStats then () else + print ("rebuilt nn; next rebuild at " ^ Int.toString numNextRebuild ^ "\n"); + + (nn, numNextRebuild) + end + + + (** start by inserting points one-by-one until mesh is large enough *) + fun smallLoop numDone (nn, numNextRebuild) remaining = + if numDone >= initialThreshold orelse Seq.length remaining = 0 then + (numDone, nn, numNextRebuild, remaining) + else + let + val (id, 
remaining) = + (Seq.nth remaining 0, Seq.drop remaining 1) + val _ = singleInsertLookupStart nn id + val numDone = numDone+1 + + val (nn, numNextRebuild) = + if not (shouldRebuild numNextRebuild numDone) then + (nn, numNextRebuild) + else + doRebuildNN numNextRebuild (Seq.take allVertices numDone) + in + smallLoop numDone (nn, numNextRebuild) remaining + end + + + fun loop numRounds (done, numDone) (nn, numNextRebuild) losers remaining = + if numDone = Seq.length inputPts then + numRounds + else + let + val numRetry = Seq.length losers + val totalRemaining = numRetry + Seq.length remaining + (* val numDone = Seq.length inputPts - totalRemaining *) + val desiredSize = + Int.min (maxBatch, Int.min (totalRemaining, + 1 + Real.round (Real.fromInt numDone * batchSizeFrac))) + val numAdditional = + Int.max (0, Int.min (desiredSize - numRetry, Seq.length remaining)) + val thisBatchSize = numAdditional + numRetry + + val newcomers = Seq.take remaining numAdditional + val remaining = Seq.drop remaining numAdditional + val (winners, losers) = + batchInsert nn (dsAppend (losers, newcomers)) + + val numSucceeded = thisBatchSize - Seq.length losers + val numDone = numDone + numSucceeded + val done = winners :: done + + val rate = Real.fromInt numSucceeded / Real.fromInt thisBatchSize + val pcRate = Real.round (100.0 * rate) + + val _ = + if not showDelaunayRoundStats then () else + print ("round " ^ Int.toString numRounds + ^ "\tdone " ^ Int.toString numDone + ^ "\tremaining " ^ Int.toString totalRemaining + ^ "\tdesired " ^ Int.toString desiredSize + ^ "\tretrying " ^ Int.toString numRetry + ^ "\tfresh " ^ Int.toString numAdditional + ^ "\tsuccess-rate " ^ Int.toString pcRate ^ "%\n") + + val (done, (nn, numNextRebuild)) = + if not (shouldRebuild numNextRebuild numDone) then + (done, (nn, numNextRebuild)) + else + let + val done = Seq.flatten (Seq.fromList done) + in + ([done], doRebuildNN numNextRebuild done) + end + in + loop (numRounds+1) (done, numDone) (nn, 
numNextRebuild) losers remaining + end + + val start: simplex = (2 * Seq.length inputPts, 0) + val _ = singleInsert start (0, Seq.nth inputPts 0) + val done = Seq.singleton 0 + val remaining = Seq.drop allVertices 1 + val numDone = 1 + + (* val _ = print "inserted first\n" *) + + val nn = buildNN done + val numNextRebuild = 10 + + (* val _ = print ("built initial nn\n") *) + + val (numDone, nn, numNextRebuild, remaining) = + smallLoop numDone (nn, numNextRebuild) remaining + + (* val _ = print ("finished small loop\n") *) + + val done = [Seq.take allVertices numDone] + + val numRounds = + loop 0 (done, numDone) (nn, numNextRebuild) (Seq.empty()) remaining + + val t1 = Time.now () + val elapsed = Time.toReal (Time.- (t1, t0)) + + fun percent x = Real.round (100.0 * (x / elapsed)) + fun rtos x = Real.fmt (StringCvt.FIX (SOME 4)) x + fun stuff x = + rtos x ^ "s (" ^ Int.toString (percent x) ^ "%)" + + val _ = + if not reportTimes then () else + let + val _ = print ("----\n") + (* val _ = print ("find-centers " ^ stuff (getTimer "find-centers") ^ "\n") *) + val _ = print ("reserve " ^ stuff (getTimer "reserve") ^ "\n") + val _ = print ("rip-and-tent " ^ stuff (getTimer "rip-and-tent") ^ "\n") + val _ = print ("split " ^ stuff (getTimer "split") ^ "\n") + val _ = print ("nn-rebuild " ^ stuff (getTimer "nn-rebuild") ^ "\n") + val _ = print ("other " ^ stuff (elapsed - timerSum ()) ^ "\n") + + val _ = clearTimers () + in + () + end + + in + (numRounds, mesh) + end + +end diff --git a/tests/bench/delaunay/Split.sml b/tests/bench/delaunay/Split.sml new file mode 100644 index 000000000..61f0c0f22 --- /dev/null +++ b/tests/bench/delaunay/Split.sml @@ -0,0 +1,73 @@ +structure Split: +sig + type 'a dseq + type 'a seq + val split: 'a dseq -> bool dseq -> {true: 'a seq, false: 'a seq} +end = +struct + + structure A = Array + structure AS = ArraySlice + + structure DS = DelayedSeq + type 'a dseq = 'a DS.t + type 'a seq = 'a AS.slice + + fun split s flags = + let + val n = 
DS.length s + val blockSize = 2000 + val numBlocks = 1 + (n-1) div blockSize + + (* the later scan(s) appears to be faster when split into two separate + * scans, rather than doing a single scan on tuples. *) + + (* val counts = Primitives.alloc numBlocks *) + val countl = ForkJoin.alloc numBlocks + val countr = ForkJoin.alloc numBlocks + + val _ = ForkJoin.parfor 1 (0, numBlocks) (fn b => + let + val lo = b * blockSize + val hi = Int.min (lo + blockSize, n) + fun loop (cl, cr) i = + if i >= hi then + (* A.update (counts, b, (cl, cr)) *) + (A.update (countl, b, cl); A.update (countr, b, cr)) + else if DS.nth flags i then + loop (cl+1, cr) (i+1) + else + loop (cl, cr+1) (i+1) + in + loop (0, 0) lo + end) + + (* val (offsets, (totl, totr)) = + Seq.scan (fn ((a,b),(c,d)) => (a+c,b+d)) (0,0) (ArraySlice.full counts) *) + val (offsetsl, totl) = Seq.scan op+ 0 (AS.full countl) + val (offsetsr, totr) = Seq.scan op+ 0 (AS.full countr) + + val left = ForkJoin.alloc totl + val right = ForkJoin.alloc totr + + val _ = ForkJoin.parfor 1 (0, numBlocks) (fn b => + let + val lo = b * blockSize + val hi = Int.min (lo + blockSize, n) + (* val (offsetl, offsetr) = Seq.nth offsets b *) + val offsetl = Seq.nth offsetsl b + val offsetr = Seq.nth offsetsr b + fun loop (cl, cr) i = + if i >= hi then () + else if DS.nth flags i then + (A.update (left, offsetl+cl, DS.nth s i); loop (cl+1, cr) (i+1)) + else + (A.update (right, offsetr+cr, DS.nth s i); loop (cl, cr+1) (i+1)) + in + loop (0, 0) lo + end) + in + {true = AS.full left, false = AS.full right} + end + +end diff --git a/tests/bench/delaunay/delaunay.mlb b/tests/bench/delaunay/delaunay.mlb new file mode 100644 index 000000000..67d2cfe4e --- /dev/null +++ b/tests/bench/delaunay/delaunay.mlb @@ -0,0 +1,4 @@ +../../mpllib/sources.$(COMPAT).mlb +Split.sml +DelaunayTriangulation.sml +main.sml diff --git a/tests/bench/delaunay/main.sml b/tests/bench/delaunay/main.sml new file mode 100644 index 000000000..128684434 --- /dev/null +++ 
b/tests/bench/delaunay/main.sml @@ -0,0 +1,207 @@ +structure CLA = CommandLineArgs +structure T = Topology2D +structure DT = DelaunayTriangulation + +val n = CLA.parseInt "n" (1000 * 1000) +val seed = CLA.parseInt "seed" 15210 +val filename = CLA.parseString "input" "" + +fun generateInputPoints () = + let + fun genReal i = + let + val x = Word64.fromInt (seed + i) + in + Real.fromInt (Word64.toInt (Word64.mod (Util.hash64 x, 0w1000000))) + / 1000000.0 + end + + fun genPoint i = (genReal (2*i), genReal (2*i + 1)) + + val (points, tm) = Util.getTime (fn _ => Seq.tabulate genPoint n) + val _ = print ("generated input in " ^ Time.fmt 4 tm ^ "s\n") + in + points + end + +(* This silly thing helps ensure good placement, by + * forcing points to be reallocated more adjacent. + * It's a no-op, but gives us as much as 2x time + * improvement (!) + *) +fun swap pts = Seq.map (fn (x, y) => (y, x)) pts +fun compactify pts = swap (swap pts) + +fun parseInputFile () = + let + val (points, tm) = Util.getTime (fn _ => + compactify (ParseFile.readSequencePoint2d filename)) + in + print ("parsed input points in " ^ Time.fmt 4 tm ^ "s\n"); + points + end + + +val input = + case filename of + "" => generateInputPoints () + | _ => parseInputFile () + + +val (numRounds, mesh) = Benchmark.run "delaunay" (fn _ => DT.triangulate input) +val _ = print ("num rounds " ^ Int.toString numRounds ^ "\n") + + +(* val _ = + print ("\n" ^ T.toString mesh ^ "\n") *) + + +(* ========================================================================== + * output result image + * only works if all input points are in range [0,1) + *) + +val filename = CLA.parseString "output" "" +val _ = + if filename <> "" then () + else ( print ("\nto see output, use -output and -resolution arguments\n" ^ + "for example: delaunay -n 1000 -output result.ppm -resolution 1000\n") + ; OS.Process.exit OS.Process.success + ) + +val t0 = Time.now () + +val resolution = CLA.parseInt "resolution" 1000 +val width = resolution 
+val height = resolution + +val image = + { width = width + , height = height + , data = Seq.tabulate (fn _ => Color.white) (width*height) + } + +fun set (i, j) x = + if 0 <= i andalso i < height andalso + 0 <= j andalso j < width + then ArraySlice.update (#data image, i*width + j, x) + else () + +fun setxy (x, y) z = + set (resolution - y - 1, x) z + +val r = Real.fromInt resolution +fun px x = Real.floor (x * r + 0.5) + +fun ipart x = Real.floor x +fun fpart x = x - Real.realFloor x +fun rfpart x = 1.0 - fpart x + +(** input points should be in range [0,1] *) +fun aaLine (x0, y0) (x1, y1) = + if x1 < x0 then aaLine (x1, y1) (x0, y0) else + let + (** scale to resolution *) + val (x0, y0, x1, y1) = (r*x0 + 0.5, r*y0 + 0.5, r*x1 + 0.5, r*y1 + 0.5) + + fun plot (x, y, c) = + let + val c = Word8.fromInt (Real.ceil (255.0 * (1.0 - c))) + val color = {blue = c, green = c, red = 0w255} + in + (* print (Int.toString x ^ " " ^ Int.toString y ^ "\n"); *) + setxy (x, y) color + end + + val dx = x1-x0 + val dy = y1-y0 + val yxSlope = dy / dx + val xySlope = dx / dy + (* val xhop = Real.fromInt (Real.sign dx) *) + (* val yhop = Real.fromInt (Real.sign dy) *) + + (* fun y x = x0 + (x-x0) * slope *) + + (** (x,y) = current point on the line *) + fun normalLoop (x, y) = + if x > x1 then () else + ( plot (ipart x, ipart y , rfpart y) + ; plot (ipart x, ipart y + 1, fpart y) + ; normalLoop (x + 1.0, y + yxSlope) + ) + + fun steepUpLoop (x, y) = + if y > y1 then () else + ( plot (ipart x , ipart y, rfpart x) + ; plot (ipart x + 1, ipart y, fpart x) + ; steepUpLoop (x + xySlope, y + 1.0) + ) + + fun steepDownLoop (x, y) = + if y < y1 then () else + ( plot (ipart x , ipart y, rfpart x) + ; plot (ipart x + 1, ipart y, fpart x) + ; steepDownLoop (x - xySlope, y - 1.0) + ) + in + if Real.abs dx > Real.abs dy then + normalLoop (x0, y0) + else if y1 > y0 then + steepUpLoop (x0, y0) + else + steepDownLoop (x0, y0) + end + +(* +val _ = aaLine (0.25, 0.5) (0.25, 0.9) (* vertical *) +val _ = 
aaLine (0.25, 0.5) (0.5, 0.9) (* steep up *) +val _ = aaLine (0.25, 0.5) (0.5, 0.6) (* normal up *) +val _ = aaLine (0.25, 0.5) (0.5, 0.5) (* horizontal *) +val _ = aaLine (0.25, 0.5) (0.5, 0.4) (* normal down *) +val _ = aaLine (0.25, 0.5) (0.5, 0.1) (* steep down *) +val _ = aaLine (0.25, 0.5) (0.25, 0.1) (* vertical down *) +*) + +(* draw all triangle edges as straight red lines *) +val _ = ForkJoin.parfor 1000 (0, T.numTriangles mesh) (fn i => + let + (** cut off anything that is outside the image (not important other than + * a little faster this way). + *) + fun constrain (x, y) = + (Real.min (1.0, Real.max (0.0, x)), Real.min (1.0, Real.max (0.0, y))) + fun vpos v = constrain (T.vdata mesh v) + fun doLineIf b (u, v) = + if b then aaLine (vpos u) (vpos v) else () + + val T.Tri {vertices=(u,v,w), neighbors=(a,b,c)} = T.tdata mesh i + in + (** This ensures that each line segment is only drawn once. The person + * responsible for drawing it is the triangle with larger id. + *) + doLineIf (i > a) (w, u); + doLineIf (i > b) (u, v); + doLineIf (i > c) (v, w) + end) + +(* mark input points as a pixel *) +val _ = + ForkJoin.parfor 10000 (0, Seq.length input) (fn i => + let + val (x, y) = Seq.nth input i + val (x, y) = (px x, px y) + fun b spot = setxy spot Color.black + in + b (x-1, y); + b (x, y-1); + b (x, y); + b (x, y+1); + b (x+1, y) + end) + +val t1 = Time.now () + +val _ = print ("generated image in " ^ Time.fmt 4 (Time.- (t1, t0)) ^ "s\n") + +val (_, tm) = Util.getTime (fn _ => PPM.write filename image) +val _ = print ("wrote to " ^ filename ^ " in " ^ Time.fmt 4 tm ^ "s\n") diff --git a/tests/bench/delaunay/test.sml b/tests/bench/delaunay/test.sml new file mode 100644 index 000000000..a82638d33 --- /dev/null +++ b/tests/bench/delaunay/test.sml @@ -0,0 +1,125 @@ +structure CLA = CommandLineArgs +structure T = Topology2D +structure DT = DelaunayTriangulation + +val (filename, testPtStr) = + case CLA.positional () of + [x, y] => (x, y) + | _ => Util.die "usage: 
./foo " + +val testType = CLA.parseString "test" "split" + +val testPoint = + case List.mapPartial Real.fromString (String.tokens (fn c => c = #",") testPtStr) of + [x,y] => (x,y) + | _ => Util.die ("bad test point") + +val (mesh, tm) = Util.getTime (fn _ => T.parseFile filename) +val _ = print ("num vertices: " ^ Int.toString (T.numVertices mesh) ^ "\n") +val _ = print ("num edges: " ^ Int.toString (T.numTriangles mesh) ^ "\n") + +val _ = print ("\n" ^ T.toString mesh ^ "\n\n") + +val start: T.simplex = (0, 0) + +fun simpToString (t, i) = + "triangle " ^ Int.toString t ^ " orientation " ^ Int.toString i ^ ": " ^ + let + val T.Tri {vertices=(a,b,c), ...} = T.tdata mesh t + in + String.concatWith " " (List.map Int.toString + (if i = 0 then [a,b,c] + else if i = 1 then [b,c,a] + else [c,a,b])) + end + +fun triToString t = + "triangle " ^ Int.toString t ^ ": " ^ + let + val T.Tri {vertices=(a,b,c), ...} = T.tdata mesh t + in + String.concatWith " " (List.map Int.toString [a,b,c]) + end + +val _ = Util.for (0, T.numVertices mesh) (fn v => + let + val s = T.find mesh v start + in + print ("found " ^ Int.toString v ^ ": " ^ simpToString s ^ "\n") + end) + + +(* ======================================================================== *) + +fun testSplit () = + let + val _ = print ("===================================\nTESTING SPLIT\n") + + val ((center, tris), verts) = T.findCavityAndPerimeter mesh start testPoint + val _ = + print ("CAVITY CENTER IS:\n " ^ triToString center ^ "\n") + val _ = + print ("CAVITY MEMBERS ARE:\n" + ^ String.concatWith "\n" (List.map (fn x => " " ^ simpToString x) tris) ^ "\n") + val _ = + print ("CAVITY PERIMETER VERTICES ARE:\n " + ^ String.concatWith " " (List.map (fn x => Int.toString x) verts) ^ "\n") + + + val mesh' = T.split mesh center testPoint + val _ = + print ("===================================\nAFTER SPLIT:\n" ^ T.toString mesh' ^ "\n") + in + () + end + +(* 
======================================================================== *) + +fun testFlip () = + let + val _ = print ("===================================\nTESTING FLIP\n") + + val simp = T.findPoint mesh testPoint start + val _ = + print ("SIMPLEX CONTAINING POINT:\n " ^ simpToString simp ^ "\n") + + val mesh' = T.flip mesh simp + val _ = + print ("===================================\nAFTER FLIP:\n" ^ T.toString mesh' ^ "\n") + in + () + end + +(* ======================================================================== *) + +fun testRipAndTent () = + let + val _ = print ("===================================\nTESTING RIP-AND-TENT\n") + + val (cavity as (center, tris), verts) = + T.findCavityAndPerimeter mesh start testPoint + + val _ = + print ("CAVITY CENTER IS:\n " ^ triToString center ^ "\n") + val _ = + print ("CAVITY MEMBERS ARE:\n" + ^ String.concatWith "\n" (List.map (fn x => " " ^ simpToString x) tris) ^ "\n") + val _ = + print ("CAVITY PERIMETER VERTICES ARE:\n " + ^ String.concatWith " " (List.map (fn x => Int.toString x) verts) ^ "\n") + + val mesh' = T.ripAndTentOne (cavity, testPoint) mesh + val _ = + print ("===================================\nAFTER FLIP:\n" ^ T.toString mesh' ^ "\n") + in + () + end + +(* ======================================================================== *) + +val _ = + case testType of + "split" => testSplit () + | "flip" => testFlip () + | "rip-and-tent" => testRipAndTent () + | _ => Util.die ("unknown test type") diff --git a/tests/bench/dense-matmul/dense-matmul.mlb b/tests/bench/dense-matmul/dense-matmul.mlb new file mode 100644 index 000000000..5f06290fb --- /dev/null +++ b/tests/bench/dense-matmul/dense-matmul.mlb @@ -0,0 +1,2 @@ +../../mpllib/sources.$(COMPAT).mlb +main.sml diff --git a/tests/bench/dense-matmul/main.sml b/tests/bench/dense-matmul/main.sml new file mode 100644 index 000000000..05e906b65 --- /dev/null +++ b/tests/bench/dense-matmul/main.sml @@ -0,0 +1,29 @@ +structure CLA = CommandLineArgs + +val n = 
CLA.parseInt "n" 1024 +val _ = + if Util.boundPow2 n = n then () + else Util.die "sidelength N must be a power of two" + +val _ = print ("generating matrices of sidelength " ^ Int.toString n ^ "\n") +val input = TreeMatrix.tabulate n (fn (i, j) => 1.0) + +val result = + Benchmark.run "multiplying" (fn _ => TreeMatrix.multiply (input, input)) + +val doCheck = CLA.parseFlag "check" +val _ = + if not doCheck then () else + let + val stuff = TreeMatrix.flatten result + val correct = + Array.length stuff = n * n + andalso + SeqBasis.reduce 1000 (fn (a, b) => a andalso b) true (0, Array.length stuff) + (fn i => Util.closeEnough (Array.sub (stuff, i), Real.fromInt n)) + in + print ("correct? "); + if correct then print "yes" else print "no"; + print "\n" + end + diff --git a/tests/bench/fib/fib.mlb b/tests/bench/fib/fib.mlb new file mode 100644 index 000000000..9435124ee --- /dev/null +++ b/tests/bench/fib/fib.mlb @@ -0,0 +1,2 @@ +../../mpllib/sources.$(COMPAT).mlb +fib.sml diff --git a/tests/bench/fib/fib.sml b/tests/bench/fib/fib.sml new file mode 100644 index 000000000..c989ac710 --- /dev/null +++ b/tests/bench/fib/fib.sml @@ -0,0 +1,41 @@ +structure CLA = CommandLineArgs + +val grain = CLA.parseInt "grain" 20 + +fun sfib n = + if n <= 1 then n else sfib (n-1) + sfib (n-2) + +fun fib n = + if n <= grain then sfib n + else + let + val (x,y) = ForkJoin.par (fn _ => fib (n-1), fn _ => fib (n-2)) + in + x + y + end + +fun fully_par_fib n = + if n < 2 then n + else op+ (ForkJoin.par (fn _ => fully_par_fib (n-1), fn _ => fully_par_fib (n-2))) + +val no_gran_control = CLA.parseFlag "no-gran-control" +val n = CLA.parseInt "N" 39 +val _ = print ("N " ^ Int.toString n ^ "\n") + +val result = Benchmark.run "running fib" (fn _ => + if no_gran_control then + fully_par_fib n + else + fib n) + +val _ = print ("result " ^ Int.toString result ^ "\n") + +val doCheck = CLA.parseFlag "check" +val _ = + if not doCheck then + print ("do --check to check correctness\n") + else if result = 
sfib n then + print ("correct? yes\n") + else + print ("correct? no\n") + diff --git a/tests/bench/flatten/AllBSFlatten.sml b/tests/bench/flatten/AllBSFlatten.sml new file mode 100644 index 000000000..b271ea8ef --- /dev/null +++ b/tests/bench/flatten/AllBSFlatten.sml @@ -0,0 +1,19 @@ +structure AllBSFlatten = +struct + + fun flatten s = + let + val (offsets, total) = Seq.scan op+ 0 (Seq.map Seq.length s) + + fun getElem i = + let + val segIdx = (BinarySearch.numLeq offsets i) - 1 + val segOff = Seq.nth offsets segIdx + in + Seq.nth (Seq.nth s segIdx) (i - segOff) + end + in + Seq.tabulate getElem total + end + +end diff --git a/tests/bench/flatten/BinarySearch.sml b/tests/bench/flatten/BinarySearch.sml new file mode 100644 index 000000000..17d81e011 --- /dev/null +++ b/tests/bench/flatten/BinarySearch.sml @@ -0,0 +1,96 @@ +structure BinarySearch: +sig + val numLt: int Seq.t -> int -> int + val numLeq: int Seq.t -> int -> int + + (** Compute both of the above. + * Should be slightly faster than two individual calls. 
+ *) + val numLtAndLeq: int Seq.t -> int -> (int * int) +end = +struct + + fun lowSearch x xs (lo, hi) = + case hi - lo of + 0 => lo + | n => let val mid = lo + n div 2 + in case Int.compare (x, Seq.nth xs mid) of + LESS => lowSearch x xs (lo, mid) + | GREATER => lowSearch x xs (mid + 1, hi) + | EQUAL => lowSearchEq x xs (lo, mid) + end + + and lowSearchEq x xs (lo, mid) = + if (mid = 0) orelse (x > Seq.nth xs (mid-1)) + then mid + else lowSearch x xs (lo, mid) + + and highSearch x xs (lo, hi) = + case hi - lo of + 0 => lo + | n => let val mid = lo + n div 2 + in case Int.compare (x, Seq.nth xs mid) of + LESS => highSearch x xs (lo, mid) + | GREATER => highSearch x xs (mid + 1, hi) + | EQUAL => highSearchEq x xs (mid, hi) + end + + and highSearchEq x xs (mid, hi) = + if (mid = Seq.length xs - 1) orelse (x < Seq.nth xs (mid + 1)) + then mid + 1 + else highSearch x xs (mid + 1, hi) + + and search (x : int) (xs : int Seq.t) (lo, hi) : int * int = + case hi - lo of + 0 => (lo, lo) + | n => let val mid = lo + n div 2 + in case Int.compare (x, Seq.nth xs mid) of + LESS => search x xs (lo, mid) + | GREATER => search x xs (mid + 1, hi) + | EQUAL => (lowSearchEq x xs (lo, mid), highSearchEq x xs (mid, hi)) + end + + fun numLtAndLeq s x = search x s (0, Seq.length s) + + + fun numLeq (s: int Seq.t) (x: int) = + let + (* val _ = + print ("numLeq (" ^ Seq.toString Int.toString s ^ ") " ^ Int.toString x ^ "\n") *) + fun loop lo hi = + case hi - lo of + 0 => lo + | 1 => (if x < Seq.nth s lo then lo else hi) + | n => + let + val mid = lo + n div 2 + in + if x < Seq.nth s mid then + loop lo mid + else + loop (mid+1) hi + end + in + loop 0 (Seq.length s) + end + + fun numLt (s: int Seq.t) (x: int) = + let + fun loop lo hi = + case hi - lo of + 0 => lo + | 1 => (if x <= Seq.nth s lo then lo else hi) + | n => + let + val mid = lo + n div 2 + in + if x <= Seq.nth s mid then + loop lo mid + else + loop (mid+1) hi + end + in + loop 0 (Seq.length s) + end + +end diff --git 
a/tests/bench/flatten/BlockedAllBSFlatten.sml b/tests/bench/flatten/BlockedAllBSFlatten.sml new file mode 100644 index 000000000..e50504b19 --- /dev/null +++ b/tests/bench/flatten/BlockedAllBSFlatten.sml @@ -0,0 +1,71 @@ +structure BlockedAllBSFlatten = +struct + + val blockSize = CommandLineArgs.parseInt "block-size" 16 + + fun flatten s = + let + val (offsets, total) = Seq.scan op+ 0 (Seq.map Seq.length s) + + (* val _ = print ("offsets " ^ Seq.toString Int.toString offsets ^ "\n") *) + + val numBlocks = Util.ceilDiv total blockSize + + (* fun blockOffsetLo b = + BinarySearch.numLeq offsets (b * blockSize) - 1 *) + (* fun blockOffsetHi b = + BinarySearch.numLt offsets ((b+1) * blockSize) *) + + (* val (blockOffsetLos, tm1) = Util.getTime (fn _ => + Seq.tabulate blockOffsetLo numBlocks + ) *) + (* val (blockOffsetHis, tm2) = Util.getTime (fn _ => + Seq.tabulate blockOffsetHi numBlocks + ) *) + + (* val _ = print ("offlos " ^ Time.fmt 4 tm1 ^ "\n") + val _ = print ("offhis " ^ Time.fmt 4 tm2 ^ "\n") *) + + fun boundary b = + BinarySearch.numLtAndLeq offsets (b * blockSize) + + val (boundaries, tm) = Util.getTime (fn _ => + Seq.tabulate boundary (numBlocks+1) + ) + val _ = print ("boundaries " ^ Time.fmt 4 tm ^ "\n") + + fun getElem i = + let + val blockNum = i div blockSize + (* val offlo = Seq.nth blockOffsetLos blockNum + val offhi = Seq.nth blockOffsetHis blockNum *) + val offlo = #2 (Seq.nth boundaries blockNum) - 1 + val offhi = #1 (Seq.nth boundaries (blockNum+1)) + val segIdx = + offlo + (BinarySearch.numLeq (Seq.subseq offsets (offlo, offhi-offlo))) i - 1 + (* val _ = + print ("getElem " ^ Int.toString i ^ + " blockNum " ^ Int.toString blockNum ^ + " offlo " ^ Int.toString offlo ^ + " offhi " ^ Int.toString offhi ^ + " segIdx " ^ Int.toString segIdx ^ + "\n") *) + val segOff = Seq.nth offsets segIdx + in + Seq.nth (Seq.nth s segIdx) (i - segOff) + end + (* handle Subscript => + ( print ("getElem " ^ Int.toString i ^ "\n") + ; raise Subscript + ) *) + + val 
(result, tm3) = Util.getTime (fn _ => + Seq.tabulate getElem total + ) + + val _ = print ("tabulate " ^ Time.fmt 4 tm3 ^ "\n") + in + result + end + +end diff --git a/tests/bench/flatten/ExpandFlatten.sml b/tests/bench/flatten/ExpandFlatten.sml new file mode 100644 index 000000000..30ff1e945 --- /dev/null +++ b/tests/bench/flatten/ExpandFlatten.sml @@ -0,0 +1,102 @@ +structure ExpandFlatten = +struct + + val expansionFactor = CommandLineArgs.parseInt "f" 8 + val targetOffsetsPerElem = CommandLineArgs.parseInt "off-per-elem" 8 + val targetBlockSize = CommandLineArgs.parseInt "min-block-size" 64 + + fun reportTime msg f = + let + val (result, tm) = Util.getTime f + in + print (msg ^ " " ^ Time.fmt 4 tm ^ "\n"); + result + end + + fun seqNumLeq (xs: int Seq.t) x = + let + val n = Seq.length xs + fun loop i = + if i < n andalso Seq.nth xs i <= x then + loop (i+1) + else + i + in + loop 0 + end + + + fun tabulateG grain f n = + ArraySlice.full (SeqBasis.tabulate grain (0, n) f) + + + fun flatten s = + let + val (offsets, total) = Seq.scan op+ 0 (Seq.map Seq.length s) + + fun expand (prevBoundaries, prevBlockSize) = + let + val newBlockSize = prevBlockSize div expansionFactor + + fun newBoundary b = + let + val blockNum = (b * newBlockSize) div prevBlockSize + val offlo = #2 (Seq.nth prevBoundaries blockNum) - 1 + val offhi = + if blockNum+1 < Seq.length prevBoundaries then + #1 (Seq.nth prevBoundaries (blockNum+1)) + else + Seq.length offsets + + val (nlt, nleq) = + BinarySearch.numLtAndLeq + (Seq.subseq offsets (offlo, offhi-offlo)) + (b * newBlockSize) + in + (offlo + nlt, offlo + nleq) + end + + val newNumBlocks = Util.ceilDiv total newBlockSize + val bs = + reportTime "boundaries" (fn _ => + Seq.tabulate newBoundary newNumBlocks) + in + (bs, newBlockSize) + end + + fun expansionLoop (boundaries, blockSize) = + if + Seq.length boundaries >= Seq.length offsets div targetOffsetsPerElem + orelse blockSize <= targetBlockSize + then + (boundaries, blockSize) + else + 
expansionLoop (expand (boundaries, blockSize)) + + (** Initial block is the whole array, but we need to compute the proper + * boundary for the start of that block. Then we expand until we + * hit the target. + *) + val init = (Seq.fromList [BinarySearch.numLtAndLeq offsets 0], total) + val (boundaries, blockSize) = expansionLoop init + + (** Finally, pull individual elements *) + fun getElem i = + let + val blockNum = i div blockSize + val offlo = #2 (Seq.nth boundaries blockNum) - 1 + val segIdx = + offlo - 1 + + (seqNumLeq (Seq.drop offsets offlo) i) + val segOff = Seq.nth offsets segIdx + in + Seq.nth (Seq.nth s segIdx) (i - segOff) + end + + val result = reportTime "tabulate" (fn _ => + Seq.tabulate getElem total) + in + result + end + +end diff --git a/tests/bench/flatten/FullExpandPow2Flatten.sml b/tests/bench/flatten/FullExpandPow2Flatten.sml new file mode 100644 index 000000000..45cd4be2f --- /dev/null +++ b/tests/bench/flatten/FullExpandPow2Flatten.sml @@ -0,0 +1,132 @@ +structure FullExpandPow2Flatten = +struct + + fun biggestPow2LessOrEqualTo x = + let + fun loop y = if 2*y > x then y else loop (2*y) + in + loop 1 + end + + (** Choose segment indices for n elements based on the given segment offsets. 
+ * + * EXAMPLE INPUT: + * offsets = [0,3,4,4,4,4,7,8] + * n = 10 + * OUTPUT: + * [ 0,0,0, 1, 5,5,5, 6, 7,7 ] + * + *) + fun pickSegments (offsets, n) = + let + fun offmids half lo (offlo, offhi) = + let + val mid = lo + half + val (offmidlo, offmidhi) = + BinarySearch.numLtAndLeq (Seq.subseq offsets (offlo, offhi-offlo)) mid + in + (offlo + offmidlo, offlo + offmidhi - 1) + end + + fun loop toWidth width results = + if width = toWidth then + results + else + let + fun getOffmids i = + offmids (width div 2) (i * width) (Seq.nth results i) + + val (O, tm) = Util.getTime (fn _ => + Seq.tabulate getOffmids (Seq.length results) + ) + val _ = + if width > 64 then () + else print ("offsets " ^ Int.toString width ^ ": " ^ Time.fmt 4 tm ^ "\n") + + fun get i = + let + val i' = i div 2 + val lo = i' * width + val hi = lo + width + val (offlo, offhi) = Seq.nth results i' + val (offmidlo, offmidhi) = Seq.nth O i' + val mid = lo + (hi - lo) div 2 + in + if i mod 2 = 0 then + (offlo, offmidlo) (* "left" *) + else + (offmidhi, offhi) (* "right" *) + end + val (results', tm) = Util.getTime (fn _ => + Seq.tabulate get (2 * Seq.length results) + ) + val _ = + if width > 64 then () + else print ("expand " ^ Int.toString width ^ ": " ^ Time.fmt 4 tm ^ "\n") + in + loop toWidth (width div 2) results' + end + + (** (width, results) represents current prefix that we've finished + * processing: it's been decomposed into some number of subsequences + * each of the given width. + * + * (offlo, lo) is the remaining suffix, where lo is the starting index + * of the suffix and offlo is the segment index for lo. 
+ *) + fun handleNonPow2Loop (width, results) (offlo, lo) = + if lo >= n then + loop 1 width results + else + let + val remainingSize = n - lo + val targetWidth = biggestPow2LessOrEqualTo remainingSize + val results' = loop targetWidth width results + + val (offmidlo, offmidhi) = offmids targetWidth lo (offlo, Seq.length offsets) + val new = (offlo, offmidlo) + val results'' = Seq.append (results', Seq.fromList [new]) + in + handleNonPow2Loop (targetWidth, results'') (offmidhi, lo+targetWidth) + end + + val targetWidth = biggestPow2LessOrEqualTo n + val offlo = (BinarySearch.numLeq offsets 0) - 1 + val (offmidlo, offmidhi) = offmids targetWidth 0 (0, Seq.length offsets) + val init = Seq.fromList [(offlo, offmidlo)] + in + Seq.map #1 (handleNonPow2Loop (targetWidth, init) (offmidhi, targetWidth)) + end + + fun flatten s = + let + val n = Seq.length s + val ((offsets, total), tm1) = Util.getTime (fn _ => + Seq.scan op+ 0 (Seq.map Seq.length s) + ) + + val (segIdxs, tm2) = Util.getTime (fn _ => + pickSegments (offsets, total) + ) + + fun getElem i = + let + val segIdx = Seq.nth segIdxs i + val segOff = Seq.nth offsets segIdx + val j = i - segOff + in + Seq.nth (Seq.nth s segIdx) j + end + + val (result, tm3) = Util.getTime (fn _ => + Seq.tabulate getElem total + ) + + val _ = print ("scan: " ^ Time.fmt 4 tm1 ^ "\n") + val _ = print ("pickSegments: " ^ Time.fmt 4 tm2 ^ "\n") + val _ = print ("tabulate: " ^ Time.fmt 4 tm3 ^ "\n") + in + result + end + +end diff --git a/tests/bench/flatten/MultiBlockedBSFlatten.sml b/tests/bench/flatten/MultiBlockedBSFlatten.sml new file mode 100644 index 000000000..6e43caba1 --- /dev/null +++ b/tests/bench/flatten/MultiBlockedBSFlatten.sml @@ -0,0 +1,105 @@ +structure MultiBlockedBSFlatten = +struct + + val blockSizesStr = CommandLineArgs.parseString "bs" "1000,50" + val doReport = CommandLineArgs.parseFlag "report-times" + + val blockSizes = + List.map (valOf o Int.fromString) + (String.tokens (fn c => c = #",") blockSizesStr) + handle 
_ => raise Fail ("error parsing block sizes '" ^ blockSizesStr ^ "'") + + fun reportTime msg f = + let + val (result, tm) = Util.getTime f + in + if not doReport then () else print (msg ^ " " ^ Time.fmt 4 tm ^ "\n"); + result + end + + fun seqNumLeq (xs: int Seq.t) x = + let + val n = Seq.length xs + fun loop i = + if i < n andalso Seq.nth xs i <= x then + loop (i+1) + else + i + in + loop 0 + end + + + fun tabulateG grain f n = + ArraySlice.full (SeqBasis.tabulate grain (0, n) f) + + fun flatten s = + let + val (offsets, total) = Seq.scan op+ 0 (Seq.map Seq.length s) + + fun expand (prevBoundaries, prevBlockSize) newBlockSize = + let + fun newBoundary b = + let + val blockNum = (b * newBlockSize) div prevBlockSize + val offlo = #2 (Seq.nth prevBoundaries blockNum) - 1 + val offhi = + if blockNum+1 < Seq.length prevBoundaries then + #1 (Seq.nth prevBoundaries (blockNum+1)) + else + Seq.length offsets + + val (nlt, nleq) = + BinarySearch.numLtAndLeq + (Seq.subseq offsets (offlo, offhi-offlo)) + (b * newBlockSize) + in + (offlo + nlt, offlo + nleq) + end + + val newNumBlocks = Util.ceilDiv total newBlockSize + val bs = + reportTime "boundaries" (fn _ => + Seq.tabulate newBoundary newNumBlocks) + in + (bs, newBlockSize) + end + + (** compute initial boundary *) + val blockSize = List.hd blockSizes + val numBlocks1 = Util.ceilDiv total blockSize + fun boundary1 b = BinarySearch.numLtAndLeq offsets (b * blockSize) + val boundaries = + reportTime "boundaries" (fn _ => + Seq.tabulate boundary1 numBlocks1) + + fun expansionLoop (boundaries, blockSize) bs = + case bs of + [] => (boundaries, blockSize) + | b :: bs' => + expansionLoop (expand (boundaries, blockSize) b) bs' + + (** expand boundaries a few times *) + val (boundaries, blockSize) = + expansionLoop (boundaries, blockSize) (List.tl blockSizes) + + (** pull individual elements *) + fun getElem i = + let + val blockNum = i div blockSize + val offlo = #2 (Seq.nth boundaries blockNum) - 1 + val segIdx = + offlo - 1 + + 
(seqNumLeq (Seq.drop offsets offlo) i) + val segOff = Seq.nth offsets segIdx + in + Seq.nth (Seq.nth s segIdx) (i - segOff) + end + + val result = reportTime "tabulate" (fn _ => + Seq.tabulate getElem total) + in + result + end + +end diff --git a/tests/bench/flatten/SimpleBlockedFlatten.sml b/tests/bench/flatten/SimpleBlockedFlatten.sml new file mode 100644 index 000000000..ee0f1d1e7 --- /dev/null +++ b/tests/bench/flatten/SimpleBlockedFlatten.sml @@ -0,0 +1,66 @@ +structure SimpleBlockedFlatten = +struct + + val blockSize = CommandLineArgs.parseInt "block-size" 4096 + val doReport = CommandLineArgs.parseFlag "report-times" + + fun reportTime msg f = + let + val (result, tm) = Util.getTime f + in + if not doReport then () else print (msg ^ " " ^ Time.fmt 4 tm ^ "\n"); + result + end + + fun tabulateG grain f n = + ArraySlice.full (SeqBasis.tabulate grain (0, n) f) + + fun flatten s = + let + val (offsets, total) = reportTime "scan" (fn _ => + Seq.scan op+ 0 (Seq.map Seq.length s)) + + val numBlocks = Util.ceilDiv total blockSize + + fun getBlock bidx = + let + val lo = bidx * blockSize + val hi = (if bidx+1 = numBlocks then total else lo + blockSize) + val size = hi - lo + + fun loop (count, elems) (segIdx, i) = + if count < size andalso i < Seq.length (Seq.nth s segIdx) then + loop (count+1, Seq.nth (Seq.nth s segIdx) i :: elems) (segIdx, i+1) + else if count >= size then + elems + else + loop (count, elems) (segIdx+1, 0) + + val segIdx = BinarySearch.numLeq offsets lo - 1 + val segOff = Seq.nth offsets segIdx + val elems = loop (0, []) (segIdx, lo - segOff) + in + Vector.fromList elems + end + + val blocks = reportTime "tab blocks" (fn _ => + tabulateG 1 getBlock numBlocks) + + fun getElem i = + let + val bidx = i div blockSize + val blo = bidx * blockSize + val block = Seq.nth blocks bidx + val blen = Vector.length block + in + (** The vector is reversed, because built from list. 
*) + Vector.sub (block, blen - 1 - (i - blo)) + end + + val result = reportTime "tabulate" (fn _ => + Seq.tabulate getElem total) + in + result + end + +end diff --git a/tests/bench/flatten/SimpleExpandFlatten.sml b/tests/bench/flatten/SimpleExpandFlatten.sml new file mode 100644 index 000000000..a5a188803 --- /dev/null +++ b/tests/bench/flatten/SimpleExpandFlatten.sml @@ -0,0 +1,58 @@ +(** This is the algorithm I describe in my blog post. Asymptotically efficient + * and simple to understand, but not very well optimized. + *) +structure SimpleExpandFlatten = +struct + + fun flatten s = + let + val (offsets, total) = Seq.scan op+ 0 (Seq.map Seq.length s) + + fun expand (prevBoundaries, prevBlockSize) = + let + val newBlockSize = prevBlockSize div 2 + + fun newBoundary b = + let + val blockNum = (b * newBlockSize) div prevBlockSize + val offlo = #2 (Seq.nth prevBoundaries blockNum) - 1 + val offhi = + if blockNum+1 < Seq.length prevBoundaries then + #1 (Seq.nth prevBoundaries (blockNum+1)) + else + Seq.length offsets + + val (nlt, nleq) = + BinarySearch.numLtAndLeq + (Seq.subseq offsets (offlo, offhi-offlo)) + (b * newBlockSize) + in + (offlo + nlt, offlo + nleq) + end + + val newNumBlocks = Util.ceilDiv total newBlockSize + in + (Seq.tabulate newBoundary newNumBlocks, newBlockSize) + end + + fun expansionLoop (boundaries, blockSize) = + if blockSize = 1 then + boundaries + else + expansionLoop (expand (boundaries, blockSize)) + + val boundaries = + expansionLoop (Seq.fromList [BinarySearch.numLtAndLeq offsets 0], total) + + fun getElem i = + let + val segIdx = #2 (Seq.nth boundaries i) - 1 + val segOff = Seq.nth offsets segIdx + in + Seq.nth (Seq.nth s segIdx) (i - segOff) + end + in + Seq.tabulate getElem total + end + +end diff --git a/tests/bench/flatten/flatten.mlb b/tests/bench/flatten/flatten.mlb new file mode 100644 index 000000000..67c9017df --- /dev/null +++ b/tests/bench/flatten/flatten.mlb @@ -0,0 +1,10 @@ +../../mpllib/sources.$(COMPAT).mlb 
+BinarySearch.sml +FullExpandPow2Flatten.sml +AllBSFlatten.sml +BlockedAllBSFlatten.sml +MultiBlockedBSFlatten.sml +ExpandFlatten.sml +SimpleBlockedFlatten.sml +SimpleExpandFlatten.sml +flatten.sml diff --git a/tests/bench/flatten/flatten.sml b/tests/bench/flatten/flatten.sml new file mode 100644 index 000000000..4183668a3 --- /dev/null +++ b/tests/bench/flatten/flatten.sml @@ -0,0 +1,43 @@ +structure CLA = CommandLineArgs + +(* ========================================================================== + * parse command-line arguments and run + *) + +val numElems = CLA.parseInt "num-elems" (1000 * 1000 * 10) +val numSeqs = CLA.parseInt "num-seqs" (numElems div 5) +val impl = CLA.parseString "impl" "lib" +val _ = print ("num-elems " ^ Int.toString numElems ^ "\n") +val _ = print ("num-seqs " ^ Int.toString numSeqs ^ "\n") + +val doit = + case impl of + "lib" => Seq.flatten + | "full-expand-pow2" => FullExpandPow2Flatten.flatten + | "all-bs" => AllBSFlatten.flatten + | "blocked-all-bs" => BlockedAllBSFlatten.flatten + | "multi-blocked-bs" => MultiBlockedBSFlatten.flatten + | "expand" => ExpandFlatten.flatten + | "simple-blocked" => SimpleBlockedFlatten.flatten + | "simple-expand" => SimpleExpandFlatten.flatten + | _ => raise Fail ("unknown impl '" ^ impl ^ "'") + +val _ = print ("impl " ^ impl ^ "\n") + +val offsets = + Mergesort.sort Int.compare + (Seq.tabulate (fn i => Util.hash i mod numElems) numSeqs) +fun O i = + if i = 0 then 0 + else if i >= numSeqs then numElems + else Seq.nth offsets i +val elems = Seq.tabulate (fn i => i) numElems +val input = Seq.tabulate (fn i => Seq.subseq elems (O i, O (i+1) - O i)) numSeqs +val _ = print ("generated input\n") + +val result = Benchmark.run "flatten" (fn _ => doit input) + +val correct = Seq.equal op= (elems, result) +val _ = print ("correct? 
" ^ (if correct then "yes" else "no") ^ "\n") +val _ = print ("result " ^ Util.summarizeArraySlice 10 Int.toString result ^ "\n") + diff --git a/tests/bench/gif-encode/main.sml b/tests/bench/gif-encode/main.sml new file mode 100644 index 000000000..a3e979a4f --- /dev/null +++ b/tests/bench/gif-encode/main.sml @@ -0,0 +1,55 @@ +structure CLA = CommandLineArgs + +local +open GIF +in + +fun encode palette {width, height, numImages, getImage} = + if numImages <= 0 then + err "Must be at least one image" + else + let + val width16 = checkToWord16 "width" width + val height16 = checkToWord16 "height" height + + val numberOfColors = Seq.length (#colors palette) + + val _ = + if numberOfColors <= 256 then () + else err "Must have at most 256 colors in the palette" + + val imageData = + AS.full (SeqBasis.tabulate 1 (0, numImages) (fn i => + let + val img = getImage i + in + if Seq.length img <> height * width then + err "Not all images are the right dimensions" + else + LZW.packCodeStream numberOfColors + (LZW.codeStream numberOfColors img) + end)) + in + imageData + end + +end + +val width = CLA.parseInt "width" +val height = CLA.parseInt "height" + +fun pixel (i, j) = + Color.hsv + { h = 90.0 + (Real.fromInt i / Real.fromInt width) * 135.0 + , s = 0.5 + (Real.fromInt j / Real.fromInt height) * 0.5 + , v = 0.8 + } + +val image = + { height = height + , width = width + , data = Seq.tabulate (fn i => (i div width, i mod width)) (width * height) + } + +val imageData = + diff --git a/tests/bench/graphio/graphio.mlb b/tests/bench/graphio/graphio.mlb new file mode 100644 index 000000000..5f06290fb --- /dev/null +++ b/tests/bench/graphio/graphio.mlb @@ -0,0 +1,2 @@ +../../mpllib/sources.$(COMPAT).mlb +main.sml diff --git a/tests/bench/graphio/main.sml b/tests/bench/graphio/main.sml new file mode 100644 index 000000000..35c9c685b --- /dev/null +++ b/tests/bench/graphio/main.sml @@ -0,0 +1,21 @@ +structure CLA = CommandLineArgs +structure G = AdjacencyGraph(Int) + +val filename = + 
case CLA.positional () of + [x] => x + | _ => Util.die "missing filename" + +val graph = G.parseFile filename +val _ = print ("num vertices: " ^ Int.toString (G.numVertices graph) ^ "\n") +val _ = print ("num edges: " ^ Int.toString (G.numEdges graph) ^ "\n") + +val outfile = CLA.parseString "outfile" "" + +val _ = + if outfile = "" then () else + let + val (_, tm) = Util.getTime (fn _ => G.writeAsBinaryFormat graph outfile) + in + print ("wrote graph (binary format) to " ^ outfile ^ " in " ^ Time.fmt 4 tm ^ "s\n") + end diff --git a/tests/bench/grep-old/Grep.sml b/tests/bench/grep-old/Grep.sml new file mode 100644 index 000000000..3c64e2303 --- /dev/null +++ b/tests/bench/grep-old/Grep.sml @@ -0,0 +1,129 @@ +structure Grep: +sig + (* returns number of matching lines, and output that unix grep would show *) + val grep : char Seq.t (* pattern *) + -> char Seq.t (* source text *) + -> int * (char Seq.t) +end = +struct + + structure A = Array + structure AS = ArraySlice + + type 'a seq = 'a Seq.t + + (* fun lines (s : char seq) : (char seq) Seq.seq = + let + val n = ASeq.length s + val indices = Seq.tabulate (fn i => i) n + fun isNewline i = (ASeq.nth s i = #"\n") + val locs = Seq.filter isNewline indices + val m = Seq.length locs + + fun line i = + let + val lo = (if i = 0 then 0 else 1 + Seq.nth locs (i-1)) + val hi = (if i = m then n else Seq.nth locs i) + in + ASeq.subseq s (lo, hi-lo) + end + in + Seq.tabulate line (m+1) + end *) + + (* check if line[i..] matches the pattern *) + fun checkMatch pattern line i = + (i + Seq.length pattern <= Seq.length line) andalso + let + val m = Seq.length pattern + (* pattern[j..] matches line[i+j..] 
*) + fun matchesFrom j = + (j >= m) orelse + ((Seq.nth line (i+j) = Seq.nth pattern j) andalso matchesFrom (j+1)) + in + matchesFrom 0 + end + + (* fun grep pat source = + let + val granularity = CommandLineArgs.parseOrDefaultInt "granularity" 1000 + val ff = FindFirst.findFirst granularity + (* val ff = FindFirst.findFirstSerial *) + fun containsPat line = + case ff (0, ASeq.length line) (checkMatch pat line) of + NONE => false + | SOME _ => true + + val linesWithPat = Seq.filter containsPat (lines source) + val newln = Seq.singleton #"\n" + + fun choose i = + if Util.even i + then Seq.fromArraySeq (Seq.nth linesWithPat (i div 2)) + else newln + in + Seq.toArraySeq (Seq.flatten (Seq.tabulate choose (2 * Seq.length linesWithPat))) + end *) + + val ffGrain = CommandLineArgs.parseInt "ff-grain" 1000 + val findFirst = FindFirst.findFirst ffGrain + + fun grep pat source = + let + fun isNewline i = (Seq.nth source i = #"\n") + + val nlPos = + AS.full (SeqBasis.filter 10000 (0, Seq.length source) (fn i => i) isNewline) + val numLines = Seq.length nlPos + 1 + fun lineStart i = + if i = 0 then 0 else 1 + Seq.nth nlPos (i-1) + fun lineEnd i = + if i = Seq.length nlPos then Seq.length source else Seq.nth nlPos i + fun line i = Seq.subseq source (lineStart i, lineEnd i - lineStart i) + + (* val _ = print ("got newline positions\n") *) + + (* compute whether or not each line contains the pattern *) + val hasPatFlags = AS.full (SeqBasis.tabulate 1000 (0, numLines) (fn i => + let + val ln = line i + in + case findFirst (0, Seq.length ln) (checkMatch pat ln) of + SOME _ => true + | NONE => false + end)) + + (* val _ = print ("found the patterns\n") *) + + val linesWithPat = + AS.full (SeqBasis.filter 4096 (0, numLines) (fn i => i) (Seq.nth hasPatFlags)) + val numLinesOutput = Seq.length linesWithPat + + (* val _ = print ("filtered the lines\n") *) + (* val _ = print ("num lines: " ^ Int.toString numLinesOutput ^ "\n") *) + + val outputOffsets = + AS.full (SeqBasis.scan 4096 op+ 0 
(0, numLinesOutput) + (* +1 to include newline *) + (fn i => 1 + Seq.length (line (Seq.nth linesWithPat i)))) + + val outputLen = Seq.nth outputOffsets numLinesOutput + + (* val _ = print ("computed line offsets\n") *) + + val output = ForkJoin.alloc outputLen + fun put i c = A.update (output, i, c) + in + ForkJoin.parfor 1000 (0, numLinesOutput) (fn i => + let + val ln = line (Seq.nth linesWithPat i) + val off = Seq.nth outputOffsets i + in + Seq.foreach ln (fn (j, c) => put (off+j) c); + put (off + Seq.length ln) #"\n" + end); + + (numLinesOutput, AS.full output) + end + +end diff --git a/tests/bench/grep-old/grep-old.mlb b/tests/bench/grep-old/grep-old.mlb new file mode 100644 index 000000000..a4cfd0cab --- /dev/null +++ b/tests/bench/grep-old/grep-old.mlb @@ -0,0 +1,3 @@ +../../mpllib/sources.$(COMPAT).mlb +Grep.sml +main.sml diff --git a/tests/bench/grep-old/main.sml b/tests/bench/grep-old/main.sml new file mode 100644 index 000000000..1dbc2d345 --- /dev/null +++ b/tests/bench/grep-old/main.sml @@ -0,0 +1,32 @@ +structure CLA = CommandLineArgs + +val (pat, file) = + case CLA.positional () of + [pat, file] => (pat, file) + | _ => Util.die ("[ERR] usage: grep PATTERN FILE") + +val benchmark = CLA.parseFlag "benchmark" + +fun bprint str = + if not benchmark then () + else print str + +val (source, tm) = Util.getTime (fn _ => ReadFile.contentsSeq file) +val _ = bprint ("read file in " ^ Time.fmt 4 tm ^ "s\n") + +val pattern = Seq.tabulate (fn i => String.sub (pat, i)) (String.size pat) + +val (matches, output) = + Benchmark.run "running grep" (fn _ => Grep.grep pattern source) + +val _ = bprint ("number of matched lines: " ^ Int.toString matches ^ "\n") +val _ = bprint ("length of output: " ^ Int.toString (Seq.length output) ^ "\n") + +val _ = + if benchmark then () + else + ArraySlice.app (fn c => TextIO.output1 (TextIO.stdOut, c)) output + +val _ = + if benchmark then GCStats.report () + else () diff --git a/tests/bench/grep/grep.mlb b/tests/bench/grep/grep.mlb 
new file mode 100644 index 000000000..5f06290fb --- /dev/null +++ b/tests/bench/grep/grep.mlb @@ -0,0 +1,2 @@ +../../mpllib/sources.$(COMPAT).mlb +main.sml diff --git a/tests/bench/grep/main.sml b/tests/bench/grep/main.sml new file mode 100644 index 000000000..75394c1a8 --- /dev/null +++ b/tests/bench/grep/main.sml @@ -0,0 +1,53 @@ +structure CLA = CommandLineArgs +structure Seq = ArraySequence + +(* chosen by subdirectory *) +structure Grep = MkGrep(OldDelayedSeq) + +(* +val pattern = CLA.parseString "pattern" "" +val filePath = CLA.parseString "infile" "" + +val input = + let + val (source, tm) = Util.getTime (fn _ => ReadFile.contentsSeq filePath) + val _ = print ("loadtime " ^ Time.fmt 4 tm ^ "s\n") + in + source + end + +val pattern = + Seq.tabulate (fn i => String.sub (pattern, i)) (String.size pattern) +*) + +val (pat, file) = + case CLA.positional () of + [pat, file] => (pat, file) + | _ => Util.die ("[ERR] usage: grep PATTERN FILE") + +val (input, tm) = Util.getTime (fn _ => ReadFile.contentsSeq file) +val _ = print ("read file in " ^ Time.fmt 4 tm ^ "s\n") + +val pattern = Seq.tabulate (fn i => String.sub (pat, i)) (String.size pat) + +val n = Seq.length input +val _ = print ("n " ^ Int.toString n ^ "\n") + +fun task () = + Grep.grep pattern input + +val result = Benchmark.run "running grep" task +val _ = print ("num matching lines " ^ Int.toString (Seq.length result) ^ "\n") + +(* fun dumpLoop i = + if i >= Seq.length result then () else + let + val (s, e) = Seq.nth result i + val tok = CharVector.tabulate (e-s, fn k => Seq.nth input (s+k)) + in + print tok; + print "\n"; + dumpLoop (i+1) + end + +val _ = dumpLoop 0 *) diff --git a/tests/bench/high-frag/high-frag.mlb b/tests/bench/high-frag/high-frag.mlb new file mode 100644 index 000000000..5f06290fb --- /dev/null +++ b/tests/bench/high-frag/high-frag.mlb @@ -0,0 +1,2 @@ +../../mpllib/sources.$(COMPAT).mlb +main.sml diff --git a/tests/bench/high-frag/main.sml b/tests/bench/high-frag/main.sml new file 
mode 100644 index 000000000..5800d238a --- /dev/null +++ b/tests/bench/high-frag/main.sml @@ -0,0 +1,51 @@ +structure CLA = CommandLineArgs +val nodes = CLA.parseInt "nodes-per-tree" 1000000 +val numTrees = CLA.parseInt "num-trees" 100 +val grain = CLA.parseInt "grain" 1 +val _ = print ("nodes-per-tree " ^ Int.toString nodes ^ "\n") +val _ = print ("num-trees " ^ Int.toString numTrees ^ "\n") + +datatype tree = Leaf of int | Node of tree * tree + + +fun go (i, j) = + case j-i of + 1 => Leaf (Util.hash i) + | n => + if n <= grain then + Node (go (i, i + n div 2), go (i + n div 2, j)) + else + Node (ForkJoin.par (fn _ => go (i, i + n div 2), + fn _ => go (i + n div 2, j))) + +fun benchmark () = + List.tabulate (numTrees, fn i => go (i*nodes, (i+1)*nodes)) + +val results = + Benchmark.run "high-fragmentation tree" benchmark + + +(** ====================================================================== + * Now do some arbitrary computation on the result to make sure it's not + * optimized out. 
+ *) + +fun reduce f t = + let + fun loop depth t = + case t of + Leaf x => x + | Node (a, b) => + if depth > 10 then + f (loop (depth+1) a, loop (depth+1) b) + else + f (ForkJoin.par (fn _ => loop (depth+1) a, + fn _ => loop (depth+1) b)) + in + loop 0 t + end + +val foo = + List.foldl Int.min (valOf Int.maxInt) (List.map (reduce Int.max) results) + +val _ = print ("foo " ^ Int.toString foo ^ "\n") diff --git a/tests/bench/integrate-opt/Integrate.sml b/tests/bench/integrate-opt/Integrate.sml new file mode 100644 index 000000000..e77efffaf --- /dev/null +++ b/tests/bench/integrate-opt/Integrate.sml @@ -0,0 +1,15 @@ +structure Integrate = +struct + + fun integrate f (s, e) n = + let + val delta = (e - s) / (Real.fromInt n) + val s' = s + delta / 2.0 + in + delta + * + SeqBasis.reduce 5000 op+ 0.0 (0, n) (fn i => + f (s' + (Real.fromInt i) * delta)) + end + +end diff --git a/tests/bench/integrate-opt/integrate-opt.mlb b/tests/bench/integrate-opt/integrate-opt.mlb new file mode 100644 index 000000000..30864afc1 --- /dev/null +++ b/tests/bench/integrate-opt/integrate-opt.mlb @@ -0,0 +1,3 @@ +../../mpllib/sources.$(COMPAT).mlb +Integrate.sml +main.sml diff --git a/tests/bench/integrate-opt/main.sml b/tests/bench/integrate-opt/main.sml new file mode 100644 index 000000000..eb3fbac85 --- /dev/null +++ b/tests/bench/integrate-opt/main.sml @@ -0,0 +1,26 @@ +structure CLA = CommandLineArgs + +val n = CLA.parseInt "n" (1000 * 1000 * 100) +val doCheck = CLA.parseFlag "check" + +val range = (1.0, 1000.0) + +val (f, correctAnswer) = + (fn x => Math.sqrt (1.0 / x), 61.245553203367586639977870888654371) + (* (fn x => 1.0 / x, 6.9077552789821370520539743640530926) *) + (* (fn x => Math.sin (1.0 / x), 6.8264726355070070694576392250122662) *) + +fun task () = + Integrate.integrate f range n + +fun check result = + if Util.closeEnough (result, correctAnswer) then + print ("correct? yes\n") + else + print ("correct? 
no (error = " + ^ Real.toString (Real.abs (result - correctAnswer)) + ^ ")\n") + +val result = Benchmark.run "integrate" task +val _ = print ("result " ^ Real.toString result ^ "\n") +val _ = check result diff --git a/tests/bench/integrate/MkIntegrate.sml b/tests/bench/integrate/MkIntegrate.sml new file mode 100644 index 000000000..014c1c96f --- /dev/null +++ b/tests/bench/integrate/MkIntegrate.sml @@ -0,0 +1,13 @@ +functor MkIntegrate(Seq: SEQUENCE) = +struct + + fun integrate (f: real -> real) (s: real, e: real) (n: int) = + let + val delta = (e - s) / (Real.fromInt n) + val s' = s + delta / 2.0 + val X = Seq.tabulate (fn i => f (s' + (Real.fromInt i) * delta)) n + in + (Seq.reduce op+ 0.0 X) * delta + end + +end diff --git a/tests/bench/integrate/integrate.mlb b/tests/bench/integrate/integrate.mlb new file mode 100644 index 000000000..f20a233f9 --- /dev/null +++ b/tests/bench/integrate/integrate.mlb @@ -0,0 +1,3 @@ +../../mpllib/sources.$(COMPAT).mlb +MkIntegrate.sml +main.sml diff --git a/tests/bench/integrate/main.sml b/tests/bench/integrate/main.sml new file mode 100644 index 000000000..23bf1020f --- /dev/null +++ b/tests/bench/integrate/main.sml @@ -0,0 +1,39 @@ +structure CLA = CommandLineArgs + +structure IntegrateAS = MkIntegrate(Seq) +structure IntegrateDS = MkIntegrate(DelayedSeq) + +val n = CLA.parseInt "n" (1000 * 1000 * 100) +val impl = CLA.parseString "impl" "delayed-seq" +val doCheck = CLA.parseFlag "check" + +val _ = print ("n " ^ Int.toString n ^ "\n") +val _ = print ("impl " ^ impl ^ "\n") +val _ = print ("check? 
" ^ (if doCheck then "yes" else "no") ^ "\n") + +val range = (1.0, 1000.0) + +val (f, correctAnswer) = + (fn x => Math.sqrt (1.0 / x), 61.245553203367586639977870888654371) +(* (fn x => 1.0 / x, 6.9077552789821370520539743640530926) *) +(* (fn x => Math.sin (1.0 / x), 6.8264726355070070694576392250122662) *) + +val task = + case impl of + "delayed-seq" => (fn () => IntegrateDS.integrate f range n) + | "array-seq" => (fn () => IntegrateAS.integrate f range n) + | _ => + Util.die + ("unknown impl: " ^ impl ^ "; options are: delayed-seq, array-seq") + +fun check result = + if Util.closeEnough (result, correctAnswer) then + print ("correct? yes\n") + else + print + ("correct? no (error = " + ^ Real.toString (Real.abs (result - correctAnswer)) ^ ")\n") + +val result = Benchmark.run "integrate" task +val _ = print ("result " ^ Real.toString result ^ "\n") +val _ = check result diff --git a/tests/bench/interval-tree/IntervalTree.sml b/tests/bench/interval-tree/IntervalTree.sml new file mode 100644 index 000000000..60c69a2c6 --- /dev/null +++ b/tests/bench/interval-tree/IntervalTree.sml @@ -0,0 +1,47 @@ +structure ITree : Aug = +struct + type key = int + type value = int + type aug = int + val compare = Int.compare + val g = fn (x, y) => y + val f = fn (x, y) => Int.max (x, y) + val id = ~1073741824 + val balance = WB 0.28 + fun debug (k, v, a) = Int.toString k ^ ", " ^ Int.toString v ^ ", " ^ Int.toString a +end + +signature INTERVAL_MAP = +sig + type point + type interval = point * point + type imap + + val interval_map : interval Seq.t -> int -> imap + val multi_insert : imap -> interval Seq.t -> imap + val stab : imap -> point -> bool + val report_all : imap -> point -> imap + val size : imap -> int + val print : imap -> unit +end + +structure IntervalMap : INTERVAL_MAP = +struct + structure amap = PAM(ITree) + type point = ITree.key + type interval = point * point + type imap = amap.am + + fun interval_map s n = amap.build s 0 n + + fun multi_insert im s = 
amap.multi_insert im s (Int.max) + + fun stab im p = (amap.aug_left im p) > p + + fun report_all im p = amap.aug_filter (amap.up_to im p) (fn q => q > p) + + fun size im = (amap.size im) + + fun print im = amap.print_tree im "" +end + diff --git a/tests/bench/interval-tree/interval-tree.mlb b/tests/bench/interval-tree/interval-tree.mlb new file mode 100644 index 000000000..83f4bffdb --- /dev/null +++ b/tests/bench/interval-tree/interval-tree.mlb @@ -0,0 +1,3 @@ +../../mpllib/sources.$(COMPAT).mlb +IntervalTree.sml +new-main.sml diff --git a/tests/bench/interval-tree/main.sml b/tests/bench/interval-tree/main.sml new file mode 100644 index 000000000..3418eb0f7 --- /dev/null +++ b/tests/bench/interval-tree/main.sml @@ -0,0 +1,87 @@ +fun randRange i j = + i + Word64.toInt (Word64.mod (Util.hash64 (Word64.fromInt i), Word64.fromInt (j - i))) + + +fun uniform_input n w shuffle = + let + fun g i = (randRange 1 w, i) + val pairs = Seq.tabulate (fn i => g i) n + in + pairs + end + + +fun eval_build_im n = +let + val max_size = 2147483647 + val ilint = uniform_input n max_size false + val v = Seq.map (fn (i, j) => (i, (randRange (LargeInt.fromInt i) max_size))) ilint + val t0 = Time.now () + val i = IntervalMap.interval_map v n + val t1 = Time.now () +in + (t0, t1, i) +end + +fun eval_queries_im im q = + let + val max_size = 2147483647 + val queries = Seq.map (fn i => (#1 i)) (uniform_input q max_size false) + val t0 = Time.now() + val r = Seq.map (IntervalMap.stab im) queries + val t1 = Time.now() + in + (t0, t1, r) + end + +fun eval_multi_insert_im im n = + let + val max_size = 2147483647 + val ilint = uniform_input n max_size false + val v = Seq.map (fn (i, j) => (i, (randRange (LargeInt.fromInt i) max_size))) ilint + val t0 = Time.now () + val i = IntervalMap.multi_insert im v + val t1 = Time.now () + in + (t0, t1, i) + end + + + + +fun run_rounds f r = + let + fun round_rec i diff = + if i = 0 then diff + else + let + val (t0, t1, _) = f() + val new_diff = Time.- (t1, 
t0) + val _ = print ("round " ^ (Int.toString (r - i + 1)) ^ " in " ^ Time.fmt 4 (new_diff) ^ "s\n") + in + round_rec (i - 1) (Time.+ (diff, new_diff)) + end + in + round_rec r Time.zeroTime + end + +val query_size = CommandLineArgs.parseInt "q" 100000000 +val size = CommandLineArgs.parseInt "n" 100000000 +val rep = CommandLineArgs.parseInt "repeat" 1 + +val diff = + if query_size = 0 then + run_rounds (fn _ => eval_build_im size) rep + else + let + val c = eval_build_im size + val curr = eval_queries_im (#3 c) + in + run_rounds (fn _ => curr query_size) rep + end + +val _ = print ("total " ^ Time.fmt 4 diff ^ "s\n") +val avg = Time.toReal diff / (Real.fromInt rep) +val _ = print ("average " ^ Real.fmt (StringCvt.FIX (SOME 4)) avg ^ "s\n") + + diff --git a/tests/bench/interval-tree/new-main.sml b/tests/bench/interval-tree/new-main.sml new file mode 100644 index 000000000..e8499065b --- /dev/null +++ b/tests/bench/interval-tree/new-main.sml @@ -0,0 +1,65 @@ +structure CLA = CommandLineArgs + +val q = CLA.parseInt "q" 100 +val n = CLA.parseInt "n" 1000000 + +val _ = print ("n " ^ Int.toString n ^ "\n") +val _ = print ("q " ^ Int.toString q ^ "\n") + +val max_size = 1000000000 +(* val gap_size = 100 *) + +fun randRange i j seed = + i + Word64.toInt + (Word64.mod (Util.hash64 (Word64.fromInt seed), Word64.fromInt (j - i))) + +(* +fun randSeg seed = + let + val p = gap_size * (randRange 1 (max_size div gap_size) seed) + val hi = Int.min (p + gap_size, max_size) + in + (p, randRange p hi (seed+1)) + end +*) + +fun randSeg seed = + let + val p = randRange 1 max_size seed + val space = max_size - p + val hi = p + 1 + space div 100 + in + (p, randRange p hi (seed+1)) + end + +(* fun query seed = + IntervalMap.stab tree (randRange 1 max_size seed) *) + +fun query tree seed = + IntervalMap.size (IntervalMap.report_all tree (randRange 1 max_size seed)) + +fun bench () = + let + val (tree, tm) = Util.getTime (fn _ => + IntervalMap.interval_map (Seq.tabulate (fn i => randSeg 
(2*i)) n) n) + val _ = print ("generated tree in " ^ Time.fmt 4 tm ^ "s\n") + in + ArraySlice.full (SeqBasis.tabulate 1 (0, q) (fn i => query tree (2*n + i))) + end + +val result = Benchmark.run "generating and stabbing intervals..." bench + +(* val numHits = Seq.reduce op+ 0 (Seq.map (fn true => 1 | _ => 0) result) +val _ = print ("hits " ^ Int.toString numHits ^ "\n") + +val hitrate = Real.round (100.0 * (Real.fromInt numHits / Real.fromInt q)) +val _ = print ("hitrate " ^ Int.toString hitrate ^ "%\n") *) + +val numHits = Seq.reduce op+ 0 result +val minHits = Seq.reduce Int.min (valOf Int.maxInt) result +val maxHits = Seq.reduce Int.max 0 result +val avgHits = Real.round (Real.fromInt numHits / Real.fromInt q) +val _ = print ("hits " ^ Int.toString numHits ^ "\n") +val _ = print ("min " ^ Int.toString minHits ^ "\n") +val _ = print ("avg " ^ Int.toString avgHits ^ "\n") +val _ = print ("max " ^ Int.toString maxHits ^ "\n") diff --git a/tests/bench/linearrec-opt/LinearRec.sml b/tests/bench/linearrec-opt/LinearRec.sml new file mode 100644 index 000000000..b6023f1b2 --- /dev/null +++ b/tests/bench/linearrec-opt/LinearRec.sml @@ -0,0 +1,66 @@ +structure LinearRec = +struct + + structure A = Array + structure AS = ArraySlice + + fun upd a i x = A.update (a, i, x) + fun nth a i = A.sub (a, i) + + val parfor = ForkJoin.parfor + val par = ForkJoin.par + val allocate = ForkJoin.alloc + + type elem = real * real + + fun scanMap grain (g: elem * elem -> elem) b (lo, hi) (f : int -> elem) (out: elem -> 'a) = + if hi - lo <= grain then + let + val n = hi - lo + val result = allocate (n+1) + fun bump ((j,b),x) = (upd result j (out b); (j+1, g (b, x))) + val (_, total) = SeqBasis.foldl bump (0, b) (lo, hi) f + in + upd result n (out total); + result + end + else + let + val n = hi - lo + val k = grain + val m = 1 + (n-1) div k (* number of blocks *) + val sums = SeqBasis.tabulate 1 (0, m) (fn i => + let val start = lo + i*k + in SeqBasis.foldl g b (start, Int.min (start+k, 
hi)) f + end) + val partials = SeqBasis.scan grain g b (0, m) (nth sums) + val result = allocate (n+1) + in + parfor 1 (0, m) (fn i => + let + fun bump ((j,b),x) = (upd result j (out b); (j+1, g (b, x))) + val start = lo + i*k + in + SeqBasis.foldl bump (i*k, nth partials i) (start, Int.min (start+k, hi)) f; + () + end); + upd result n (out (nth partials m)); + result + end + + (* ====================================================================== *) + + fun combine ((x1, y1), (x2, y2)) = + (x1 * x2, y1 * x2 + y2) + + val id = (1.0, 0.0) + + fun linearRec s = + let + val result = + scanMap 5000 combine id (0, Seq.length s) (Seq.nth s) #2 + in + Seq.drop (AS.full result) 1 + end + +end diff --git a/tests/bench/linearrec-opt/linearrec-opt.mlb b/tests/bench/linearrec-opt/linearrec-opt.mlb new file mode 100644 index 000000000..cbef997d4 --- /dev/null +++ b/tests/bench/linearrec-opt/linearrec-opt.mlb @@ -0,0 +1,3 @@ +../../mpllib/sources.$(COMPAT).mlb +LinearRec.sml +main.sml diff --git a/tests/bench/linearrec-opt/main.sml b/tests/bench/linearrec-opt/main.sml new file mode 100644 index 000000000..55a07b315 --- /dev/null +++ b/tests/bench/linearrec-opt/main.sml @@ -0,0 +1,21 @@ +structure CLA = CommandLineArgs +structure Seq = ArraySequence + +structure L = LinearRec + +val n = CLA.parseInt "n" (1000 * 1000 * 100) + +val _ = print ("n " ^ Int.toString n ^ "\n") + +(* fun gen i = + Real.fromInt ((Util.hash i) mod 1000 - 500) / 500.0 *) + +val input = + Seq.tabulate (fn i => (1.0, 1.0)) n + +fun task () = + L.linearRec input + +val result = Benchmark.run "linear recurrence" task +val x = Seq.nth result (n-1) +val _ = print ("result " ^ Real.toString x ^ "\n") diff --git a/tests/bench/linearrec/MkLinearRec.sml b/tests/bench/linearrec/MkLinearRec.sml new file mode 100644 index 000000000..e57ebb22e --- /dev/null +++ b/tests/bench/linearrec/MkLinearRec.sml @@ -0,0 +1,12 @@ +functor MkLinearRec (Seq: SEQUENCE) = +struct + + fun combine ((x1, y1), (x2, y2)) = + (x1 * x2, y1 * 
x2 + y2) + + val id = (1.0, 0.0) + + fun linearRec s = + Seq.toArraySeq (Seq.map #2 (Seq.scanIncl combine id (Seq.fromArraySeq s))) + +end diff --git a/tests/bench/linearrec/linearrec.mlb b/tests/bench/linearrec/linearrec.mlb new file mode 100644 index 000000000..38ea89521 --- /dev/null +++ b/tests/bench/linearrec/linearrec.mlb @@ -0,0 +1,3 @@ +../../mpllib/sources.$(COMPAT).mlb +MkLinearRec.sml +main.sml diff --git a/tests/bench/linearrec/main.sml b/tests/bench/linearrec/main.sml new file mode 100644 index 000000000..4eb449dfb --- /dev/null +++ b/tests/bench/linearrec/main.sml @@ -0,0 +1,21 @@ +structure CLA = CommandLineArgs +structure Seq = ArraySequence + +structure L = MkLinearRec(OldDelayedSeq) + +val n = CLA.parseInt "n" (1000 * 1000 * 100) + +val _ = print ("n " ^ Int.toString n ^ "\n") + +(* fun gen i = + Real.fromInt ((Util.hash i) mod 1000 - 500) / 500.0 *) + +val input = + Seq.tabulate (fn i => (1.0, 1.0)) n + +fun task () = + L.linearRec input + +val result = Benchmark.run "linear recurrence" task +val x = Seq.nth result (n-1) +val _ = print ("result " ^ Real.toString x ^ "\n") diff --git a/tests/bench/linefit-opt/LineFit.sml b/tests/bench/linefit-opt/LineFit.sml new file mode 100644 index 000000000..94e8bfd68 --- /dev/null +++ b/tests/bench/linefit-opt/LineFit.sml @@ -0,0 +1,25 @@ +structure LineFit = +struct + + fun linefit (points : (real * real) Seq.t) = + let + + val n = Real.fromInt (Seq.length points) + fun sumPair((x1,y1),(x2,y2)) = (x1 + x2, y1 + y2) + fun sum f = + SeqBasis.reduce 5000 sumPair (0.0, 0.0) + (0, Seq.length points) + (f o Seq.nth points) + (* Seq.reduce sumPair (0.0, 0.0) (Seq.map f points) *) + + fun square x = x * x + val (xsum, ysum) = sum (fn (x,y) => (x,y)) + val (xa, ya) = (xsum/n, ysum/n) + val (Stt, bb) = sum (fn (x,y) => (square(x - xa), (x - xa) * y)) + val b = bb / Stt + val a = ya - xa * b + in + (a, b) + end + +end diff --git a/tests/bench/linefit-opt/linefit-opt.mlb b/tests/bench/linefit-opt/linefit-opt.mlb new 
file mode 100644 index 000000000..1b44cc1c4 --- /dev/null +++ b/tests/bench/linefit-opt/linefit-opt.mlb @@ -0,0 +1,3 @@ +../../mpllib/sources.$(COMPAT).mlb +LineFit.sml +main.sml diff --git a/tests/bench/linefit-opt/main.sml b/tests/bench/linefit-opt/main.sml new file mode 100644 index 000000000..974964ac8 --- /dev/null +++ b/tests/bench/linefit-opt/main.sml @@ -0,0 +1,32 @@ +structure CLA = CommandLineArgs +structure Seq = ArraySequence +structure LF = LineFit + +val n = CLA.parseInt "n" (1000 * 1000 * 100) +val doCheck = CLA.parseFlag "check" + +fun gen i = + (Real.fromInt i, Real.fromInt i) + +val input = + Seq.tabulate gen n + +fun task () = + LF.linefit input + +fun check result = + if not doCheck then () else + let + val (a, b) = result + val correct = + Real.< (Real.abs a , 0.000001) andalso + Real.< (Real.abs (b-1.0), 0.000001) + in + if correct then + print ("correct? yes\n") + else + print ("correct? no\n") + end + +val result = Benchmark.run "linefit" task +val _ = check result diff --git a/tests/bench/linefit/MkLineFit.sml b/tests/bench/linefit/MkLineFit.sml new file mode 100644 index 000000000..d4c858371 --- /dev/null +++ b/tests/bench/linefit/MkLineFit.sml @@ -0,0 +1,21 @@ +functor MkLineFit (Seq : SEQUENCE) = +struct + + fun linefit (points : (real * real) ArraySequence.t) = + let + val points = Seq.fromArraySeq points + + val n = Real.fromInt (Seq.length points) + fun sumPair((x1,y1),(x2,y2)) = (x1 + x2, y1 + y2) + fun sum f = Seq.reduce sumPair (0.0, 0.0) (Seq.map f points) + fun square x = x * x + val (xsum, ysum) = sum (fn (x,y) => (x,y)) + val (xa, ya) = (xsum/n, ysum/n) + val (Stt, bb) = sum (fn (x,y) => (square(x - xa), (x - xa) * y)) + val b = bb / Stt + val a = ya - xa * b + in + (a, b) + end + +end diff --git a/tests/bench/linefit/linefit.mlb b/tests/bench/linefit/linefit.mlb new file mode 100644 index 000000000..1fdda8cbd --- /dev/null +++ b/tests/bench/linefit/linefit.mlb @@ -0,0 +1,3 @@ +../../mpllib/sources.$(COMPAT).mlb +MkLineFit.sml 
+main.sml diff --git a/tests/bench/linefit/main.sml b/tests/bench/linefit/main.sml new file mode 100644 index 000000000..06170ddd2 --- /dev/null +++ b/tests/bench/linefit/main.sml @@ -0,0 +1,34 @@ +structure CLA = CommandLineArgs +structure Seq = ArraySequence + +(* chosen by subdirectory *) +structure LF = MkLineFit(OldDelayedSeq) + +val n = CLA.parseInt "n" (1000 * 1000 * 100) +val doCheck = CLA.parseFlag "check" + +fun gen i = + (Real.fromInt i, Real.fromInt i) + +val input = + Seq.tabulate gen n + +fun task () = + LF.linefit input + +fun check result = + if not doCheck then () else + let + val (a, b) = result + val correct = + Real.< (Real.abs a , 0.000001) andalso + Real.< (Real.abs (b-1.0), 0.000001) + in + if correct then + print ("correct? yes\n") + else + print ("correct? no\n") + end + +val result = Benchmark.run "linefit" task +val _ = check result diff --git a/tests/bench/low-d-decomp/LDD.sml b/tests/bench/low-d-decomp/LDD.sml new file mode 100644 index 000000000..32c1b67b1 --- /dev/null +++ b/tests/bench/low-d-decomp/LDD.sml @@ -0,0 +1,172 @@ +structure LDD = +struct + type 'a seq = 'a Seq.t + structure G = AdjacencyGraph(Int) + structure V = G.Vertex + structure AS = ArraySlice + structure DS = DelayedSeq + + type vertex = G.vertex + + fun strip s = + let val (s', start, _) = AS.base s + in if start = 0 then s' else raise Fail "strip base <> 0" + end + + fun vertex_map g f h = + AS.full (SeqBasis.tabFilter 10000 (0, G.numVertices g) (fn i => if h(i) then f(i) else NONE)) + + (* inplace Knuth shuffle [l, r) *) + fun seq_random_shuffle s l r seed = + let + fun item i = AS.sub (s, i) + fun set (i, v) = AS.update (s, i, v) + (* get a random idx in [l, i] *) + fun rand_idx i = Int.mod (Util.hash (seed + i), i - l + 1) + l + fun swap (i,j) = + let + val tmp = item i + in + set(i, item j); set(j, tmp) + end + fun shuffle_helper li = + if r - li < 2 then () + else (swap (li, rand_idx li); shuffle_helper (li + 1)) + in + shuffle_helper l + end + + fun log2_up 
n = Real.ceil (Math.log10 (Real.fromInt n) / (Math.log10 2.0)) + + fun bit_and (n, mask) = Word.toInt (Word.andb (Word.fromInt n, mask)) + fun range_check s n = Seq.length (Seq.filter (fn i => i >= n) s) = 0 + + fun shuffle s (n : int) seed = + if n < 1000 then + let + val cs = Seq.map (fn i => i) s + val _ = seq_random_shuffle cs 0 n 0 + in + cs + end + else + let + val l = log2_up n + val bits = if n < Real.floor (Math.pow (2.0, 27.0)) then Int.div ((l - 7), 2) + else l - 17 + val num_buckets = Real.floor (Math.pow (2.0, Real.fromInt bits)) + val mask = Word.fromInt (num_buckets - 1) + fun rand_pos i = bit_and (Util.hash (seed + i), mask) + (* size of bucket_offsets = num_buckets + 1 *) + val (s', bucket_offsets) = CountingSort.sort s rand_pos num_buckets + fun bucket_shuffle i = seq_random_shuffle s' (Seq.nth bucket_offsets i) (Seq.nth bucket_offsets (i + 1)) seed + val _ = ForkJoin.parfor 1 (0, num_buckets) bucket_shuffle + in + s' + end + + fun partition n b = + let + val s = shuffle (Seq.tabulate (fn i => i) n) n 0 + fun subseq s (i, j) = Seq.subseq s (i, j - i) + fun distribute w acc = + let + val i = Real.fromInt (List.length acc) + val bi = Real.floor (Math.pow (Math.e, b*i)) + val r = w + bi + in + if r >= n then (subseq s (w, n))::acc + else distribute r ((subseq s (w, r))::acc) + end + in + Seq.rev (Seq.fromList (distribute 0 [])) + end + + fun ldd g b = + let + val n = G.numVertices g + val (pr, tm) = Util.getTime (fn _ => partition n b) + val _ = print ("partition: " ^ Time.fmt 4 tm ^ "\n") + val pr_len = Seq.length pr + val visited = strip (Seq.tabulate (fn i => false) n) + val cluster = Seq.tabulate (fn i => n + 1) n + val parent = Seq.tabulate (fn i => n + 1) n + + fun item s i = AS.sub (s, i) + fun set s i v = AS.update (s, i, v) + + fun initialize_cluster u = + if (Seq.nth cluster u) > n then + (set cluster u u; Array.update (visited, u, true); SOME(u)) + else NONE + + fun update (s, d) = + if not (Concurrency.casArray (visited, d) (false, true)) 
then + (set cluster d (item cluster s); set parent d s; (SOME d)) + else NONE + + fun cond u = not (Array.sub (visited, u)) + + fun ldd_helper fr i = + if i >= pr_len then () + else + let + val (fr', tm) = Util.getTime (fn _ => + let + val pri = Seq.nth pr i + val new_clusters = SeqBasis.tabFilter 1000 (0, Seq.length pri) (fn i => initialize_cluster (Seq.nth pri i)) + (* val new_clusters = Seq.filter (fn v => (Seq.nth cluster v) > n) (Seq.nth pr i) *) + (* val _ = Seq.foreach new_clusters initialize_cluster *) + val fr_len = Seq.length fr + val nc = AS.full new_clusters + val app_frontier = fn i => if (i < fr_len) then Seq.nth fr i else Seq.nth nc (i - fr_len) + (* val fr' = Seq.append (fr, nc) *) + in + (app_frontier, fr_len + (Seq.length nc)) + end) + val _ = print ("round " ^ Int.toString i ^ ": new_clusters: " ^ Time.fmt 4 tm ^ "\n") + val (fr'', tm) = Util.getTime (fn _ => AdjInt.edge_map g fr' update cond) + val _ = print ("round " ^ Int.toString i ^ ": edge_map: " ^ Time.fmt 4 tm ^ "\n") + in + ldd_helper fr'' (i + 1) + end + in + (ldd_helper (Seq.empty ()) 0; + (cluster, parent)) + end + + fun check_ldd g c p = + let + val m = G.numEdges g + val n = G.numVertices g + val arr_set = strip (Seq.tabulate (fn i => false) n) + val tups = Seq.tabulate (fn i => (i, Seq.nth c i)) (Seq.length c) + fun outgoing cid = + let + val s = Seq.map (fn (i, j) => i) (Seq.filter (fn (i, j) => j = cid) tups) + val _ = Seq.foreach s (fn (_ ,v) => Array.update (arr_set, v, true)) + fun greater_neighbors v = Seq.filter (fn u => v < u) (G.neighbors g v) + val grt_neighbors = Seq.flatten (Seq.map greater_neighbors s) + val grt_out_neighbors = Seq.filter (fn v => not (Array.sub(arr_set, v))) grt_neighbors + val _ = Seq.foreach s (fn (_ ,v) => Array.update (arr_set, v, false)) + in + if (Seq.length s) = 0 then ~1 + else Seq.length grt_out_neighbors + end + + fun check_helper i cc ce = + if i >= n then (cc, ce) + else + let + val num_outgoing = outgoing i + in + if num_outgoing = ~1 
then check_helper (i + 1) cc ce + else (check_helper (i + 1) (cc + 1) (ce + num_outgoing)) + end + val (cc, ce) = check_helper 0 0 0 + in + print ("num clusters = " ^ (Int.toString cc) ^ ", num edges = " ^ (Int.toString m) ^ ", inter-edges = " ^ (Int.toString ce) ^ "\n") + end + fun slts s = " " ^ Int.toString (Seq.length s) ^ " " + fun print_seq si s se = (print (si ^ " "); Seq.foreach s (fn (_,v) => print ((Int.toString v)^ " ") ) ; print (" " ^ se ^ "\n")) +end diff --git a/tests/bench/low-d-decomp/ldd-alt.sml b/tests/bench/low-d-decomp/ldd-alt.sml new file mode 100644 index 000000000..c32a54cb7 --- /dev/null +++ b/tests/bench/low-d-decomp/ldd-alt.sml @@ -0,0 +1,132 @@ +structure LDD = +struct + type 'a seq = 'a Seq.t + structure G = AdjacencyGraph(Int) + (* structure VS = G.VertexSubset *) + structure V = G.Vertex + structure AS = ArraySlice + + type vertex = G.vertex + + fun strip s = + let val (s', start, _) = AS.base s + in if start = 0 then s' else raise Fail "strip base <> 0" + end + + fun partition n b = + let + val s = Shuffle.shuffle (Seq.tabulate (fn i => i) n) 0 + fun subseq s (i, j) = Seq.subseq s (i, j - i) + fun distribute w acc = + let + val i = Real.fromInt (List.length acc) + val bi = Real.floor (Math.pow (Math.e, b*i)) + val r = w + bi + in + if r >= n then (subseq s (w, n))::acc + else distribute r ((subseq s (w, r))::acc) + end + in + Seq.rev (Seq.fromList (distribute 0 [])) + end + + fun ldd g b = + let + val n = G.numVertices g + val (pr, tm) = Util.getTime (fn _ => partition n b) + val _ = print ("partition: " ^ Time.fmt 4 tm ^ "\n") + val pr_len = Seq.length pr + val visited = strip (Seq.tabulate (fn i => false) n) + val cluster = Seq.tabulate (fn i => n + 1) n + val parent = Seq.tabulate (fn i => n + 1) n + + fun item s i = AS.sub (s, i) + fun set s i v = AS.update (s, i, v) + + fun initialize_cluster u = + if (Seq.nth cluster u) > n then + (set cluster u u; Array.update (visited, u, true); SOME(u)) + else NONE + + fun updateseq (s, d) = 
+ if (Array.sub (visited, d)) then NONE + else + (Array.update(visited, d, true); set cluster d (item cluster s); set parent d s; (SOME d)) + + fun updatepar (s, d) = + if not (Concurrency.casArray (visited, d) (false, true)) then + (set cluster d (item cluster s); set parent d s; (SOME d)) + else NONE + + val update = (updatepar, updateseq) + + fun cond u = not (Array.sub (visited, u)) + val deg = Int.div (G.numEdges g, G.numVertices g) + val denseThreshold = G.numEdges g div (20*(1 + deg)) + + fun ldd_helper fr i = + if i >= pr_len then () + else + let + val (fr', tm) = Util.getTime (fn _ => + let + val pri = Seq.nth pr i + val new_clusters = SeqBasis.tabFilter 1000 (0, Seq.length pri) (fn i => initialize_cluster (Seq.nth pri i)) + (* val new_clusters = Seq.filter (fn v => (Seq.nth cluster v) > n) (Seq.nth pr i) *) + (* val _ = Seq.foreach new_clusters initialize_cluster *) + (* val fr_len = Seq.length fr *) + val nc = AS.full new_clusters + val app_frontier = AdjInt.append fr nc n + (* fn i => if (i < fr_len) then Seq.nth fr i else Seq.nth nc (i - fr_len) *) + (* val fr' = Seq.append (fr, nc) *) + in + (* (app_frontier, fr_len + (Seq.length nc)) *) + app_frontier + end) + val _ = print ("round " ^ Int.toString i ^ ": new_clusters: " ^ Time.fmt 4 tm ^ "\n") + val (fr'', tm) = Util.getTime (fn _ => AdjInt.edge_map g fr' update cond) + (* val b = if (should_process_sparse g fr') then "sparse" else "dense" *) + val _ = print ("round " ^ Int.toString i ^ " edge_map: " ^ Time.fmt 4 tm ^ "\n") + in + ldd_helper fr'' (i + 1) + end + in + (ldd_helper (AdjInt.empty (denseThreshold)) 0; + (cluster, parent)) + end + + fun check_ldd g c p = + let + val m = G.numEdges g + val n = G.numVertices g + val arr_set = strip (Seq.tabulate (fn i => false) n) + val tups = Seq.tabulate (fn i => (i, Seq.nth c i)) (Seq.length c) + fun outgoing cid = + let + val s = Seq.map (fn (i, j) => i) (Seq.filter (fn (i, j) => j = cid) tups) + val _ = Seq.foreach s (fn (_ ,v) => Array.update 
(arr_set, v, true)) + fun greater_neighbors v = Seq.filter (fn u => v < u) (G.neighbors g v) + val grt_neighbors = Seq.flatten (Seq.map greater_neighbors s) + val grt_out_neighbors = Seq.filter (fn v => not (Array.sub(arr_set, v))) grt_neighbors + val _ = Seq.foreach s (fn (_ ,v) => Array.update (arr_set, v, false)) + in + if (Seq.length s) = 0 then ~1 + else Seq.length grt_out_neighbors + end + + fun check_helper i cc ce = + if i >= n then (cc, ce) + else + let + val num_outgoing = outgoing i + in + if num_outgoing = ~1 then check_helper (i + 1) cc ce + else (check_helper (i + 1) (cc + 1) (ce + num_outgoing)) + end + val (cc, ce) = check_helper 0 0 0 + in + print ("num clusters = " ^ (Int.toString cc) ^ ", num edges = " ^ (Int.toString m) ^ ", inter-edges = " ^ (Int.toString ce) ^ "\n") + end + fun slts s = " " ^ Int.toString (Seq.length s) ^ " " + fun print_seq si s se = (print (si ^ " "); Seq.foreach s (fn (_,v) => print ((Int.toString v)^ " ") ) ; print (" " ^ se ^ "\n")) +end diff --git a/tests/bench/low-d-decomp/low-d-decomp.mlb b/tests/bench/low-d-decomp/low-d-decomp.mlb new file mode 100644 index 000000000..1116e5502 --- /dev/null +++ b/tests/bench/low-d-decomp/low-d-decomp.mlb @@ -0,0 +1,4 @@ +../../mpllib/sources.$(COMPAT).mlb +ldd-alt.sml +main.sml + diff --git a/tests/bench/low-d-decomp/main.sml b/tests/bench/low-d-decomp/main.sml new file mode 100644 index 000000000..ebf7d2163 --- /dev/null +++ b/tests/bench/low-d-decomp/main.sml @@ -0,0 +1,80 @@ +structure CLA = CommandLineArgs +structure G = AdjacencyGraph(Int) + +val source = CLA.parseInt "source" 0 +val doCheck = CLA.parseFlag "check" + +(* +val N = CLA.parseInt "N" 10000000 +val D = CLA.parseInt "D" 10 + +val (graph, tm) = Util.getTime (fn _ => G.randSymmGraph N D) +val _ = print ("generated graph in " ^ Time.fmt 4 tm ^ "s\n") +val _ = print ("num vertices: " ^ Int.toString (G.numVertices graph) ^ "\n") +val _ = print ("num edges: " ^ Int.toString (G.numEdges graph) ^ "\n") +*) + +val filename = + 
case CLA.positional () of + [x] => x + | _ => Util.die "missing filename" + +val (graph, tm) = Util.getTime (fn _ => G.parseFile filename) +val _ = print ("num vertices: " ^ Int.toString (G.numVertices graph) ^ "\n") +val _ = print ("num edges: " ^ Int.toString (G.numEdges graph) ^ "\n") + +val (_, tm) = Util.getTime (fn _ => + if G.parityCheck graph then () + else TextIO.output (TextIO.stdErr, + "WARNING: parity check failed; graph might not be symmetric " ^ + "or might have duplicate- or self-edges\n")) +val _ = print ("parity check in " ^ Time.fmt 4 tm ^ "s\n") + + +val b = (CommandLineArgs.parseReal "b" 0.3) + +val (cluster, parent) = + Benchmark.run "running ldd: " (fn _ => LDD.ldd graph b) + +(* val numClusters = + SeqBasis.reduce 10000 op+ 0 (0, Seq.length cluster) (fn i => + if Seq.nth cluster i = i then 1 else 0) *) +(* val _ = print ("num clusters " ^ Int.toString numClusters ^ "\n") *) + +(* val _ = print ("num-triangles = " ^ (Int.toString P) ^ "\n") *) +(* val _ = LDD.check_ldd graph cluster parent *) +(* val _ = Benchmark.run "running connectivity" (fn _ => LDD.connectivity graph b) *) +(* +val numVisited = + SeqBasis.reduce 10000 op+ 0 (0, Seq.length P) + (fn i => if Seq.nth P i >= 0 then 1 else 0) +val _ = print ("visited " ^ Int.toString numVisited ^ "\n") + +fun numHops P hops v = + if hops > Seq.length P then ~2 + else if Seq.nth P v = ~1 then ~1 + else if Seq.nth P v = v then hops + else numHops P (hops+1) (Seq.nth P v) + +val maxHops = + SeqBasis.reduce 100 Int.max ~3 (0, G.numVertices graph) (numHops P 0) +val _ = print ("max dist " ^ Int.toString maxHops ^ "\n") + +fun check () = + let + val (P', serialTime) = + Util.getTime (fn _ => SerialBFS.bfs graph source) + + val correct = + Seq.length P = Seq.length P' + andalso + SeqBasis.reduce 10000 (fn (a, b) => a andalso b) true (0, Seq.length P) + (fn i => numHops P 0 i = numHops P' 0 i) + in + print ("serial finished in " ^ Time.fmt 4 serialTime ^ "s\n"); + print ("correct? 
" ^ (if correct then "yes" else "no") ^ "\n") + end + +val _ = if doCheck then check () else () + +val _ = GCStats.report () *) diff --git a/tests/bench/max-indep-set/MIS.sml b/tests/bench/max-indep-set/MIS.sml new file mode 100644 index 000000000..adaf94935 --- /dev/null +++ b/tests/bench/max-indep-set/MIS.sml @@ -0,0 +1,105 @@ +structure MIS = +struct + + structure G = AdjacencyGraph(Int) + structure V = G.Vertex + structure AS = ArraySlice + + fun zero_out_neighbors g roots pr = + let + fun updateseq (s, d) = + if (Array.sub (pr, d) = 0) then NONE + else + (Array.update (pr, d, 0); SOME d) + + fun updatepar (s, d) = + let + val v = Array.sub (pr, d) + in + if (v = 0) then NONE + else if (Concurrency.casArray (pr, d) (v, 0) = v) then (SOME d) + else NONE + end + + val update = (updatepar, updatepar) + val removed = AdjInt.edge_map g roots update (fn d => Array.sub (pr, d) > 0) + in + removed + end + + fun strip s = + let val (s', start, _) = AS.base s + in if start = 0 then s' else raise Fail "strip base <> 0" + end + + fun mis g = + let + val (n, m) = (G.numVertices g, G.numEdges g) + val deg = Int.div (m, n) + val P = Shuffle.shuffle (Seq.tabulate (fn i => i) n) 0 + val pr = SeqBasis.tabulate 100 (0, n) + (fn i => let + val my_prio = Seq.nth P i + in + Seq.iterate (fn (acc, j) => acc + (if (Seq.nth P j) < my_prio then 1 else 0)) 0 (G.neighbors g i) + end) + val pr_seq = AS.full pr + + val denseThreshold = G.numEdges g div (20*(1 + deg)) + val sparse_rep = AS.full (SeqBasis.filter 10000 (0, n) (fn i => i) (fn i => Seq.nth pr_seq i = 0)) + val ind_set = Seq.tabulate (fn _ => false) n + + val roots = AdjInt.from_sparse_rep sparse_rep denseThreshold n + fun loop_roots finished roots = + if finished < n then + let + val _ = AdjInt.vertex_foreach g roots (fn u => if (Seq.nth ind_set u) then () else AS.update (ind_set, u, true)) + val removed = zero_out_neighbors g roots pr + fun decrement_priority_par (s, d) = + let + val p = Seq.nth P s + val q = Seq.nth P d + in + 
if p >= q then NONE + else if (Concurrency.faaArray (pr, d) (~1) = 1) then SOME d + else NONE + end + fun decrement_priority_seq (s, d) = + let + val p = Seq.nth P s + val q = Seq.nth P d + val prio = Array.sub (pr, d) + val _ = if prio <= 0 orelse p >= q then () + else Array.update (pr, d, prio-1) + in + if (prio = 1) then SOME d + else NONE + end + val dec = (decrement_priority_par, decrement_priority_seq) + val new_roots = AdjInt.edge_map g removed dec (fn d => Array.sub (pr, d) > 0) + in + loop_roots (finished + (AdjInt.size roots) + (AdjInt.size removed)) new_roots + end + else () + in + loop_roots 0 roots; + ind_set + end + + fun verify_mis g ind_set = + let + val (n, m) = (G.numVertices g, G.numEdges g) + val int_ind = Seq.tabulate (fn i => 0) n + fun ok_f u = + let + val count = Seq.iterate (fn (acc, b) => if (Seq.nth ind_set b) then acc + 1 else acc) 0 (G.neighbors g u) + in + (Seq.nth ind_set u) orelse (not (count = 0)) + end + val bool_ok = Seq.tabulate ok_f n + val all_ok = Seq.reduce (fn (b, acc) => b andalso acc) true bool_ok + in + if all_ok then () + else print ("Invalid Independent Set\n") + end +end diff --git a/tests/bench/max-indep-set/faa.mlton.sml b/tests/bench/max-indep-set/faa.mlton.sml new file mode 100644 index 000000000..fa2992415 --- /dev/null +++ b/tests/bench/max-indep-set/faa.mlton.sml @@ -0,0 +1,11 @@ +structure Concurrency = +struct + open Concurrency + + fun faaArray (a, i) x = + let + val rx = Array.sub (a, i) + in + (Array.update (a, i, rx + x); rx) + end +end diff --git a/tests/bench/max-indep-set/faa.mpl.sml b/tests/bench/max-indep-set/faa.mpl.sml new file mode 100644 index 000000000..3056feee4 --- /dev/null +++ b/tests/bench/max-indep-set/faa.mpl.sml @@ -0,0 +1,5 @@ +structure Concurrency = +struct + open Concurrency + val faaArray = MLton.Parallel.arrayFetchAndAdd +end diff --git a/tests/bench/max-indep-set/main.sml b/tests/bench/max-indep-set/main.sml new file mode 100644 index 000000000..804e23b2d --- /dev/null +++ 
b/tests/bench/max-indep-set/main.sml @@ -0,0 +1,86 @@ +structure CLA = CommandLineArgs +structure G = AdjacencyGraph(Int) + +val source = CLA.parseInt "source" 0 +val doCheck = CLA.parseFlag "check" + +(* +val N = CLA.parseInt "N" 10000000 +val D = CLA.parseInt "D" 10 + +val (graph, tm) = Util.getTime (fn _ => G.randSymmGraph N D) +val _ = print ("generated graph in " ^ Time.fmt 4 tm ^ "s\n") +val _ = print ("num vertices: " ^ Int.toString (G.numVertices graph) ^ "\n") +val _ = print ("num edges: " ^ Int.toString (G.numEdges graph) ^ "\n") +*) + +val filename = + case CLA.positional () of + [x] => x + | _ => Util.die "missing filename" + +val (graph, tm) = Util.getTime (fn _ => G.parseFile filename) +val _ = print ("num vertices: " ^ Int.toString (G.numVertices graph) ^ "\n") +val _ = print ("num edges: " ^ Int.toString (G.numEdges graph) ^ "\n") + +val (_, tm) = Util.getTime (fn _ => + if G.parityCheck graph then () + else TextIO.output (TextIO.stdErr, + "WARNING: parity check failed; graph might not be symmetric " ^ + "or might have duplicate- or self-edges\n")) +val _ = print ("parity check in " ^ Time.fmt 4 tm ^ "s\n") + + +val b = (CommandLineArgs.parseReal "b" 0.3) + +val ind_set = + Benchmark.run "running independent set: " (fn _ => MIS.mis graph) + +val c = Seq.reduce op+ 0 (Seq.map (fn i => if i then 1 else 0) ind_set) +val _ = print ("num elements = " ^ (Int.toString c) ^ "\n") + +val _ = MIS.verify_mis graph ind_set + + +(* val numClusters = + SeqBasis.reduce 10000 op+ 0 (0, Seq.length cluster) (fn i => + if Seq.nth cluster i = i then 1 else 0) *) +(* val _ = print ("num clusters " ^ Int.toString numClusters ^ "\n") *) + +(* val _ = print ("num-triangles = " ^ (Int.toString P) ^ "\n") *) +(* val _ = LDD.check_ldd graph cluster parent *) +(* val _ = Benchmark.run "running connectivity" (fn _ => LDD.connectivity graph b) *) +(* +val numVisited = + SeqBasis.reduce 10000 op+ 0 (0, Seq.length P) + (fn i => if Seq.nth P i >= 0 then 1 else 0) +val _ = print 
("visited " ^ Int.toString numVisited ^ "\n")
+
+fun numHops P hops v =
+  if hops > Seq.length P then ~2
+  else if Seq.nth P v = ~1 then ~1
+  else if Seq.nth P v = v then hops
+  else numHops P (hops+1) (Seq.nth P v)
+
+val maxHops =
+  SeqBasis.reduce 100 Int.max ~3 (0, G.numVertices graph) (numHops P 0)
+val _ = print ("max dist " ^ Int.toString maxHops ^ "\n")
+
+fun check () =
+  let
+    val (P', serialTime) =
+      Util.getTime (fn _ => SerialBFS.bfs graph source)
+
+    val correct =
+      Seq.length P = Seq.length P'
+      andalso
+      SeqBasis.reduce 10000 (fn (a, b) => a andalso b) true (0, Seq.length P)
+        (fn i => numHops P 0 i = numHops P' 0 i)
+  in
+    print ("serial finished in " ^ Time.fmt 4 serialTime ^ "s\n");
+    print ("correct? " ^ (if correct then "yes" else "no") ^ "\n")
+  end
+
+val _ = if doCheck then check () else ()
+
+val _ = GCStats.report () *)
diff --git a/tests/bench/max-indep-set/max-indep-set.mlb b/tests/bench/max-indep-set/max-indep-set.mlb
new file mode 100644
index 000000000..24e29325c
--- /dev/null
+++ b/tests/bench/max-indep-set/max-indep-set.mlb
@@ -0,0 +1,5 @@
+../../mpllib/sources.$(COMPAT).mlb
+faa.$(COMPAT).sml
+MIS.sml
+main.sml
+
diff --git a/tests/bench/mcss-opt/MCSS.sml b/tests/bench/mcss-opt/MCSS.sml
new file mode 100644
index 000000000..4ed408de2
--- /dev/null
+++ b/tests/bench/mcss-opt/MCSS.sml
@@ -0,0 +1,30 @@
+(* Maximum contiguous subsequence sum via a parallel reduce.
+ * Each reduction element is a 4-tuple (l, r, b, t) summarizing a segment:
+ *   l = best sum of a prefix of the segment (clamped at 0),
+ *   r = best sum of a suffix of the segment (clamped at 0),
+ *   b = best sum of any contiguous subsequence within the segment,
+ *   t = total sum of the segment.
+ *)
+structure MCSS =
+struct
+
+  val max = Real.max
+
+  (* Merge the summaries of two adjacent segments: the best prefix may
+   * extend through segment 1 into segment 2 (t1+l2), the best suffix
+   * symmetrically, and the best subsequence may straddle the seam (r1+l2). *)
+  fun combine((l1,r1,b1,t1),(l2,r2,b2,t2)) =
+    (max(l1, t1+l2),
+     max(r2, r1+t2),
+     max(r1+l2, max(b1,b2)),
+     t1+t2)
+
+  (* Identity for combine: the empty segment. *)
+  val id = (0.0, 0.0, 0.0, 0.0)
+
+  (* Summary of a one-element segment holding v; prefix/suffix/best are
+   * clamped at 0 (the empty subsequence counts as sum 0). *)
+  fun singleton v =
+    let
+      val vp = max (v, 0.0)
+    in
+      (vp, vp, vp, v)
+    end
+
+  (* MCSS of s, computed with SeqBasis.reduce at grain 5000. Because
+   * singleton clamps at 0, the result is always >= 0. *)
+  fun mcss (s : real Seq.t) : real =
+    let
+      val (_,_,b,_) =
+        SeqBasis.reduce 5000 combine id (0, Seq.length s)
+          (fn i => singleton (Seq.nth s i))
+    in
+      b
+    end
+
+end
diff --git a/tests/bench/mcss-opt/main.sml b/tests/bench/mcss-opt/main.sml
new file mode 100644
index 000000000..845f60dc3
--- /dev/null
+++ b/tests/bench/mcss-opt/main.sml
@@ -0,0 +1,20
@@ +structure CLA = CommandLineArgs +structure Seq = ArraySequence + +structure M = MCSS + +val n = CLA.parseInt "n" (1000 * 1000 * 100) + +fun gen i = + Real.fromInt (Word64.toInt (Word64.mod (Util.hash64 (Word64.fromInt i), 0w1000)) - 500) / 500.0 + +val input = + Seq.tabulate gen n + +fun task () = + M.mcss input + +val result = Benchmark.run "mcss" task +val _ = print ("result " ^ Real.toString result ^ "\n") + +val _ = print ("input " ^ Util.summarizeArraySlice 12 Real.toString input ^ "\n") diff --git a/tests/bench/mcss-opt/mcss-opt.mlb b/tests/bench/mcss-opt/mcss-opt.mlb new file mode 100644 index 000000000..9a97da9bb --- /dev/null +++ b/tests/bench/mcss-opt/mcss-opt.mlb @@ -0,0 +1,3 @@ +../../mpllib/sources.$(COMPAT).mlb +MCSS.sml +main.sml diff --git a/tests/bench/mcss/MkMapReduceMCSS.sml b/tests/bench/mcss/MkMapReduceMCSS.sml new file mode 100644 index 000000000..7a1af36a1 --- /dev/null +++ b/tests/bench/mcss/MkMapReduceMCSS.sml @@ -0,0 +1,29 @@ +functor MkMapReduceMCSS (Seq : SEQUENCE) = +struct + + val max = Real.max + + fun combine((l1,r1,b1,t1),(l2,r2,b2,t2)) = + (max(l1, t1+l2), + max(r2, r1+t2), + max(r1+l2, max(b1,b2)), + t1+t2) + + val id = (0.0, 0.0, 0.0, 0.0) + + fun singleton v = + let + val vp = max (v, 0.0) + in + (vp, vp, vp, v) + end + + fun mcss (s : real ArraySequence.t) : real = + let + val (_,_,b,_) = + Seq.reduce combine id (Seq.map singleton (Seq.fromArraySeq s)) + in + b + end + +end diff --git a/tests/bench/mcss/MkScanMCSS.sml b/tests/bench/mcss/MkScanMCSS.sml new file mode 100644 index 000000000..f037c7c44 --- /dev/null +++ b/tests/bench/mcss/MkScanMCSS.sml @@ -0,0 +1,24 @@ +functor MkScanMCSS (Seq: SEQUENCE) = +struct + + fun mcss (s: real ArraySequence.t) : real = + let + val s = Seq.fromArraySeq s + val t = Util.startTiming () + + val p = Seq.scanIncl op+ 0.0 s + val t = Util.tick t "plus scan" + + val (m, _) = Seq.scan Real.min Real.posInf p + val t = Util.tick t "min scan" + + val b = Seq.zipWith op- (p, m) + val t = Util.tick 
t "zipWith" + + val result = Seq.reduce Real.max Real.negInf b + val t = Util.tick t "reduce" + in + result + end + +end diff --git a/tests/bench/mcss/main.sml b/tests/bench/mcss/main.sml new file mode 100644 index 000000000..43dd22a26 --- /dev/null +++ b/tests/bench/mcss/main.sml @@ -0,0 +1,18 @@ +structure CLA = CommandLineArgs +structure Seq = ArraySequence + +structure M = MkMapReduceMCSS (DelayedSeq) + +val n = CLA.parseInt "n" (1000 * 1000 * 100) + +fun gen i = + Real.fromInt ((Util.hash i) mod 1000 - 500) / 500.0 + +val input = + Seq.tabulate gen n + +fun task () = + M.mcss input + +val result = Benchmark.run "mcss" task +val _ = print ("result " ^ Real.toString result ^ "\n") diff --git a/tests/bench/mcss/mcss.mlb b/tests/bench/mcss/mcss.mlb new file mode 100644 index 000000000..1d654745f --- /dev/null +++ b/tests/bench/mcss/mcss.mlb @@ -0,0 +1,3 @@ +../../mpllib/sources.$(COMPAT).mlb +MkMapReduceMCSS.sml +main.sml diff --git a/tests/bench/msort-int32/msort-int32.mlb b/tests/bench/msort-int32/msort-int32.mlb new file mode 100644 index 000000000..e05edb8f7 --- /dev/null +++ b/tests/bench/msort-int32/msort-int32.mlb @@ -0,0 +1,2 @@ +../../mpllib/sources.$(COMPAT).mlb +msort.sml diff --git a/tests/bench/msort-int32/msort.sml b/tests/bench/msort-int32/msort.sml new file mode 100644 index 000000000..c88cd5aca --- /dev/null +++ b/tests/bench/msort-int32/msort.sml @@ -0,0 +1,16 @@ +structure CLA = CommandLineArgs + +val n = CLA.parseInt "N" (100 * 1000 * 1000) +val _ = print ("N " ^ Int.toString n ^ "\n") + +val _ = print ("generating " ^ Int.toString n ^ " random integers\n") + +fun elem i = + Int32.fromInt (Word64.toInt (Word64.mod (Util.hash64 (Word64.fromInt i), Word64.fromInt n))) +val input = ArraySlice.full (SeqBasis.tabulate 10000 (0, n) elem) + +val result = + Benchmark.run "running mergesort" (fn _ => Mergesort.sort Int32.compare input) + +val _ = print ("result " ^ Util.summarizeArraySlice 8 Int32.toString result ^ "\n") + diff --git 
a/tests/bench/msort-strings/msort-strings.mlb b/tests/bench/msort-strings/msort-strings.mlb
new file mode 100644
index 000000000..e05edb8f7
--- /dev/null
+++ b/tests/bench/msort-strings/msort-strings.mlb
@@ -0,0 +1,2 @@
+../../mpllib/sources.$(COMPAT).mlb
+msort.sml
diff --git a/tests/bench/msort-strings/msort.sml b/tests/bench/msort-strings/msort.sml
new file mode 100644
index 000000000..6986e6ab2
--- /dev/null
+++ b/tests/bench/msort-strings/msort.sml
@@ -0,0 +1,34 @@
+(* Benchmark: mergesort the whitespace-separated tokens of FILE by
+ * String.compare. *)
+structure CLA = CommandLineArgs
+
+(* Print usage to stderr and exit with failure status. *)
+fun usage () =
+  let
+    val msg =
+      "usage: msort-strings FILE\n"
+  in
+    TextIO.output (TextIO.stdErr, msg);
+    OS.Process.exit OS.Process.failure
+  end
+
+(* The input file is the single positional argument. *)
+val filename =
+  case CLA.positional () of
+    [x] => x
+  | _ => usage ()
+
+(* -long flag: prepend a shared 32-character prefix to every token, so
+ * comparisons cannot be decided within the first 32 characters. *)
+val makeLong = CLA.parseFlag "long"
+
+val (contents, tm) = Util.getTime (fn _ => ReadFile.contentsSeq filename)
+val _ = print ("read file in " ^ Time.fmt 4 tm ^ "s\n")
+val (tokens, tm) = Util.getTime (fn _ => Tokenize.tokens Char.isSpace contents)
+val _ = print ("tokenized in " ^ Time.fmt 4 tm ^ "s\n")
+
+(* 32 copies of #"a", shared by all tokens when -long is set. *)
+val prefix = CharVector.tabulate (32, fn _ => #"a")
+
+val tokens =
+  if not makeLong then tokens
+  else Seq.map (fn str => prefix ^ str) tokens
+
+val result =
+  Benchmark.run "running mergesort" (fn _ => Mergesort.sort String.compare tokens)
+
+val _ = print ("result " ^ Util.summarizeArraySlice 8 (fn x => x) result ^ "\n")
+
diff --git a/tests/bench/msort/msort.mlb b/tests/bench/msort/msort.mlb
new file mode 100644
index 000000000..e05edb8f7
--- /dev/null
+++ b/tests/bench/msort/msort.mlb
@@ -0,0 +1,2 @@
+../../mpllib/sources.$(COMPAT).mlb
+msort.sml
diff --git a/tests/bench/msort/msort.sml b/tests/bench/msort/msort.sml
new file mode 100644
index 000000000..7f3ab9e0d
--- /dev/null
+++ b/tests/bench/msort/msort.sml
@@ -0,0 +1,17 @@
+structure CLA = CommandLineArgs
+
+val n = CLA.parseInt "n" (100 * 1000 * 1000)
+val _ = print ("N " ^ Int.toString n ^ "\n")
+
+val _ = print ("generating " ^ Int.toString n ^ " random integers\n")
+
+fun elem i = + Word64.toInt (Word64.mod (Util.hash64 (Word64.fromInt i), Word64.fromInt n)) +val input = ArraySlice.full (SeqBasis.tabulate 10000 (0, n) elem) + +val result = + Benchmark.run "running mergesort" (fn _ => Mergesort.sort Int.compare input) + +val _ = print ("input " ^ Util.summarizeArraySlice 8 Int.toString input ^ "\n") +val _ = print ("result " ^ Util.summarizeArraySlice 8 Int.toString result ^ "\n") + diff --git a/tests/bench/nearest-nbrs/ParseFile.sml b/tests/bench/nearest-nbrs/ParseFile.sml new file mode 100644 index 000000000..c7eb4e101 --- /dev/null +++ b/tests/bench/nearest-nbrs/ParseFile.sml @@ -0,0 +1,190 @@ +(** SAM_NOTE: copy/pasted... some repetition here with Parse. *) +structure ParseFile = +struct + + structure RF = ReadFile + structure Seq = ArraySequence + structure DS = OldDelayedSeq + + fun tokens (f: char -> bool) (cs: char Seq.t) : (char DS.t) DS.t = + let + val n = Seq.length cs + val s = DS.tabulate (Seq.nth cs) n + val indices = DS.tabulate (fn i => i) (n+1) + fun check i = + if (i = n) then not (f(DS.nth s (n-1))) + else if (i = 0) then not (f(DS.nth s 0)) + else let val i1 = f (DS.nth s i) + val i2 = f (DS.nth s (i-1)) + in (i1 andalso not i2) orelse (i2 andalso not i1) end + val ids = DS.filter check indices + val res = DS.tabulate (fn i => + let val (start, e) = (DS.nth ids (2*i), DS.nth ids (2*i+1)) + in DS.tabulate (fn i => Seq.nth cs (start+i)) (e - start) + end) + ((DS.length ids) div 2) + in + res + end + + fun eqStr str (chars : char DS.t) = + let + val n = String.size str + fun checkFrom i = + i >= n orelse + (String.sub (str, i) = DS.nth chars i andalso checkFrom (i+1)) + in + DS.length chars = n + andalso + checkFrom 0 + end + + fun parseDigit char = + let + val code = Char.ord char + val code0 = Char.ord #"0" + val code9 = Char.ord #"9" + in + if code < code0 orelse code9 < code then + NONE + else + SOME (code - code0) + end + + (* This implementation doesn't work with mpl :( + * Need to fix the basis library... 
*) + (* + fun parseReal chars = + let + val str = CharVector.tabulate (DS.length chars, DS.nth chars) + in + Real.fromString str + end + *) + + fun parseInt (chars : char DS.t) = + let + val n = DS.length chars + fun c i = DS.nth chars i + + fun build x i = + if i >= n then SOME x else + case c i of + #"," => build x (i+1) + | #"_" => build x (i+1) + | cc => + case parseDigit cc of + NONE => NONE + | SOME dig => build (x * 10 + dig) (i+1) + in + if n = 0 then NONE + else if (c 0 = #"-" orelse c 0 = #"~") then + Option.map (fn x => x * ~1) (build 0 1) + else if (c 0 = #"+") then + build 0 1 + else + build 0 0 + end + + fun parseReal (chars : char DS.t) = + let + val n = DS.length chars + fun c i = DS.nth chars i + + fun buildAfterE x i = + let + val chars' = DS.subseq chars (i, n-i) + in + Option.map (fn e => x * Math.pow (10.0, Real.fromInt e)) + (parseInt chars') + end + + fun buildAfterPoint m x i = + if i >= n then SOME x else + case c i of + #"," => buildAfterPoint m x (i+1) + | #"_" => buildAfterPoint m x (i+1) + | #"." => NONE + | #"e" => buildAfterE x (i+1) + | #"E" => buildAfterE x (i+1) + | cc => + case parseDigit cc of + NONE => NONE + | SOME dig => buildAfterPoint (m * 0.1) (x + m * (Real.fromInt dig)) (i+1) + + fun buildBeforePoint x i = + if i >= n then SOME x else + case c i of + #"," => buildBeforePoint x (i+1) + | #"_" => buildBeforePoint x (i+1) + | #"." 
=> buildAfterPoint 0.1 x (i+1) + | #"e" => buildAfterE x (i+1) + | #"E" => buildAfterE x (i+1) + | cc => + case parseDigit cc of + NONE => NONE + | SOME dig => buildBeforePoint (x * 10.0 + Real.fromInt dig) (i+1) + in + if n = 0 then NONE + else if (c 0 = #"-" orelse c 0 = #"~") then + Option.map (fn x => x * ~1.0) (buildBeforePoint 0.0 1) + else + buildBeforePoint 0.0 0 + end + + fun readSequencePoint2d filename = + let + val toks = tokens Char.isSpace (RF.contentsSeq filename) + fun tok i = DS.nth toks i + val _ = + if eqStr "pbbs_sequencePoint2d" (tok 0) then () + else raise Fail (filename ^ " wrong file type") + + val n = DS.length toks - 1 + + fun r i = Option.valOf (parseReal (tok (1 + i))) + + fun pt i = + (r (2*i), r (2*i+1)) + handle e => raise Fail ("error parsing point " ^ Int.toString i ^ " (" ^ exnMessage e ^ ")") + + val result = Seq.tabulate pt (n div 2) + in + result + end + + fun readSequenceInt filename = + let + val toks = tokens Char.isSpace (RF.contentsSeq filename) + fun tok i = DS.nth toks i + val _ = + if eqStr "pbbs_sequenceInt" (tok 0) then () + else raise Fail (filename ^ " wrong file type") + + val n = DS.length toks - 1 + + fun p i = + Option.valOf (parseInt (tok (1 + i))) + handle e => raise Fail ("error parsing integer " ^ Int.toString i) + in + Seq.tabulate p n + end + + fun readSequenceReal filename = + let + val toks = tokens Char.isSpace (RF.contentsSeq filename) + fun tok i = DS.nth toks i + val _ = + if eqStr "pbbs_sequenceDouble" (tok 0) then () + else raise Fail (filename ^ " wrong file type") + + val n = DS.length toks - 1 + + fun p i = + Option.valOf (parseReal (tok (1 + i))) + handle e => raise Fail ("error parsing double value " ^ Int.toString i) + in + Seq.tabulate p n + end + +end diff --git a/tests/bench/nearest-nbrs/main.sml b/tests/bench/nearest-nbrs/main.sml new file mode 100644 index 000000000..b358805ec --- /dev/null +++ b/tests/bench/nearest-nbrs/main.sml @@ -0,0 +1,160 @@ +structure NN = NearestNeighbors 
+structure CLA = CommandLineArgs + +val n = CLA.parseInt "N" 1000000 +val inputFile = CLA.parseString "input" "" +val leafSize = CLA.parseInt "leafSize" 50 +val grain = CLA.parseInt "grain" 1000 +val seed = CLA.parseInt "seed" 15210 + +fun genReal i = + let + val x = Word64.fromInt (seed + i) + in + Real.fromInt (Word64.toInt (Word64.mod (Util.hash64 x, 0w1000000))) + / 1000000.0 + end + +fun genPoint i = (genReal (2*i), genReal (2*i + 1)) + +(* This silly thing helps ensure good placement, by + * forcing points to be reallocated more adjacent. + * It's a no-op, but gives us as much as 2x time + * improvement (!) + *) +fun swap pts = Seq.map (fn (x, y) => (y, x)) pts +fun compactify pts = swap (swap pts) + +val input = + case inputFile of + "" => + let + val (input, tm) = Util.getTime (fn _ => Seq.tabulate genPoint n) + in + print ("generated input in " ^ Time.fmt 4 tm ^ "s\n"); + input + end + + | filename => + let + val (points, tm) = Util.getTime (fn _ => + compactify (ParseFile.readSequencePoint2d filename)) + in + print ("parsed input points in " ^ Time.fmt 4 tm ^ "s\n"); + points + end + +fun nnEx() = + let + val (tree, tm) = Util.getTime (fn _ => NN.makeTree leafSize input) + val _ = print ("built quadtree in " ^ Time.fmt 4 tm ^ "s\n") + + val (nbrs, tm) = Util.getTime (fn _ => NN.allNearestNeighbors grain tree) + val _ = print ("found all neighbors in " ^ Time.fmt 4 tm ^ "s\n") + in + (tree, nbrs) + end + +val (tree, nbrs) = Benchmark.run "running nearest neighbors" nnEx +val _ = + print ("result " ^ Util.summarizeArraySlice 12 Int.toString nbrs ^ "\n") + +(* now input[nbrs[i]] is the closest point to input[i] *) + +(* ========================================================================== + * write image to output + * this only works if all input points are within [0,1) *) + +val filename = CLA.parseString "output" "" +val _ = + if filename <> "" then () + else ( print ("to see output, use -output and -resolution arguments\n" ^ + "for example: nn -N 
10000 -output result.ppm -resolution 1000\n") + ; GCStats.report () + ; OS.Process.exit OS.Process.success + ) + +val t0 = Time.now () + +val resolution = CLA.parseInt "resolution" 1000 +val width = resolution +val height = resolution + +val image = + { width = width + , height = height + , data = Seq.tabulate (fn _ => Color.white) (width*height) + } + +fun set (i, j) x = + if 0 <= i andalso i < height andalso + 0 <= j andalso j < width + then ArraySlice.update (#data image, i*width + j, x) + else () + +val r = Real.fromInt resolution +fun px x = Real.floor (x * r) +fun pos (x, y) = (resolution - px x - 1, px y) + +fun horizontalLine i (j0, j1) = + if j1 < j0 then horizontalLine i (j1, j0) + else Util.for (j0, j1) (fn j => set (i, j) Color.red) + +fun sign xx = + case Int.compare (xx, 0) of LESS => ~1 | EQUAL => 0 | GREATER => 1 + +(* Bresenham's line algorithm *) +fun line (x1, y1) (x2, y2) = + let + val w = x2 - x1 + val h = y2 - y1 + val dx1 = sign w + val dy1 = sign h + val (longest, shortest, dx2, dy2) = + if Int.abs w > Int.abs h then + (Int.abs w, Int.abs h, dx1, 0) + else + (Int.abs h, Int.abs w, 0, dy1) + + fun loop i numerator x y = + if i > longest then () else + let + val numerator = numerator + shortest; + in + set (x, y) Color.red; + if numerator >= longest then + loop (i+1) (numerator-longest) (x+dx1) (y+dy1) + else + loop (i+1) numerator (x+dx2) (y+dy2) + end + in + loop 0 (longest div 2) x1 y1 + end + +(* mark all nearest neighbors with straight red lines *) +val t0 = Time.now () + +val _ = ForkJoin.parfor 10000 (0, Seq.length input) (fn i => + line (pos (Seq.nth input i)) (pos (Seq.nth input (Seq.nth nbrs i)))) + +(* mark input points as a pixel *) +val _ = + ForkJoin.parfor 10000 (0, Seq.length input) (fn i => + let + val (x, y) = pos (Seq.nth input i) + fun b spot = set spot Color.black + in + b (x-1, y); + b (x, y-1); + b (x, y); + b (x, y+1); + b (x+1, y) + end) + +val t1 = Time.now () + +val _ = print ("generated image in " ^ Time.fmt 4 
(Time.- (t1, t0)) ^ "s\n")
+
+val (_, tm) = Util.getTime (fn _ => PPM.write filename image)
+val _ = print ("wrote to " ^ filename ^ " in " ^ Time.fmt 4 tm ^ "s\n")
+
diff --git a/tests/bench/nearest-nbrs/nearest-nbrs.mlb b/tests/bench/nearest-nbrs/nearest-nbrs.mlb
new file mode 100644
index 000000000..2c6fe21d1
--- /dev/null
+++ b/tests/bench/nearest-nbrs/nearest-nbrs.mlb
@@ -0,0 +1,3 @@
+../../mpllib/sources.$(COMPAT).mlb
+ParseFile.sml
+main.sml
diff --git a/tests/bench/nqueens-simple/main.sml b/tests/bench/nqueens-simple/main.sml
new file mode 100644
index 000000000..4f05b84ce
--- /dev/null
+++ b/tests/bench/nqueens-simple/main.sml
@@ -0,0 +1,104 @@
+(* ==========================================================================
+ * VERSION 1: NO GRANULARITY CONTROL
+ * ==========================================================================
+ *)
+
+
+(* compute the sum of f(lo), f(lo+1), ..., f(hi-1) *)
+(* Parallel divide-and-conquer: split the range at the midpoint and
+ * evaluate the two halves with ForkJoin.par, all the way down to
+ * single elements (no sequential cutoff). *)
+fun sum (lo, hi, f) =
+  if lo >= hi then
+    0
+  else if lo + 1 = hi then
+    f lo
+  else
+    let
+      val mid = lo + (hi - lo) div 2
+      val (left, right) = ForkJoin.par (fn () => sum (lo, mid, f), fn () =>
+        sum (mid, hi, f))
+    in
+      left + right
+    end
+
+
+(* queens at positions (row, col) *)
+type locations = (int * int) list
+
+
+(* True iff a queen at (i, j) shares a row, column, or either diagonal
+ * with any queen in other_queens. *)
+fun queen_is_threatened (i, j) (other_queens: locations) =
+  List.exists
+    (fn (x, y) => i = x orelse j = y orelse i - j = x - y orelse i + j = x + y)
+    other_queens
+
+
+(* Count complete placements of n queens on an n x n board. Searches row
+ * by row; the candidate columns of each row are tried in parallel via sum. *)
+fun nqueens_count_solutions n =
+  let
+    fun search i queens =
+      if i >= n then
+        1
+      else
+        let
+          fun do_column j =
+            if queen_is_threatened (i, j) queens then 0
+            else search (i + 1) ((i, j) :: queens)
+        in
+          sum (0, n, do_column)
+        end
+  in
+    search 0 []
+  end
+
+
+(* ==========================================================================
+ * VERSION 2: MANUAL GRANULARITY CONTROL
+ * ==========================================================================
+ *)
+
+
+(* sequential alternative *)
+fun sum_serial (lo, hi, f) =
+  Util.loop (lo, hi) 0 (fn (acc, i) => acc +
f i)
+
+
+(* Same search as nqueens_count_solutions, but with a manual granularity
+ * cutoff: once the search is a few rows deep (i >= 3), the per-column
+ * loop runs sequentially instead of forking. *)
+fun nqueens_count_solutions_manual_gran_control n =
+  let
+    fun search i queens =
+      if i >= n then
+        1
+      else
+        let
+          fun do_column j =
+            if queen_is_threatened (i, j) queens then 0
+            else search (i + 1) ((i, j) :: queens)
+        in
+          if i >= 3 then
+            (* simple heuristic for granularity control: switch to sequential
+             * algorithm after getting a few levels deep.
+             *)
+            sum_serial (0, n, do_column)
+          else
+            sum (0, n, do_column)
+        end
+  in
+    search 0 []
+  end
+
+
+(* ==========================================================================
+ * parse command-line arguments and run
+ * ==========================================================================
+ *)
+
+val n = CommandLineArgs.parseInt "N" 13
+val do_gran_control = CommandLineArgs.parseFlag "do-gran-control"
+val _ = print ("N " ^ Int.toString n ^ "\n")
+val _ = print
+  ("do-gran-control? " ^ (if do_gran_control then "yes" else "no") ^ "\n")
+
+val msg =
+  "counting number of " ^ Int.toString n ^ "x" ^ Int.toString n ^ " solutions"
+
+val result = Benchmark.run msg (fn _ =>
+  if do_gran_control then nqueens_count_solutions_manual_gran_control n
+  else nqueens_count_solutions n)
+
+val _ = print ("result " ^ Int.toString result ^ "\n")
diff --git a/tests/bench/nqueens-simple/nqueens-simple.mlb b/tests/bench/nqueens-simple/nqueens-simple.mlb
new file mode 100644
index 000000000..5f06290fb
--- /dev/null
+++ b/tests/bench/nqueens-simple/nqueens-simple.mlb
@@ -0,0 +1,2 @@
+../../mpllib/sources.$(COMPAT).mlb
+main.sml
diff --git a/tests/bench/nqueens/nqueens.mlb b/tests/bench/nqueens/nqueens.mlb
new file mode 100644
index 000000000..889d11f07
--- /dev/null
+++ b/tests/bench/nqueens/nqueens.mlb
@@ -0,0 +1,2 @@
+../../mpllib/sources.$(COMPAT).mlb
+nqueens.sml
diff --git a/tests/bench/nqueens/nqueens.sml b/tests/bench/nqueens/nqueens.sml
new file mode 100644
index 000000000..765b0c13c
--- /dev/null
+++ b/tests/bench/nqueens/nqueens.sml
@@ -0,0 +1,39 @@
+structure CLA = CommandLineArgs
+
+type board = (int *
int) list + +fun threatened (i,j) [] = false + | threatened (i,j) ((x,y)::Q) = + i = x orelse j = y orelse i-j = x-y orelse i+j = x+y + orelse threatened (i,j) Q + +structure Seq = FuncSequence + +fun countSol n = + let + fun search i b = + if i >= n then 1 else + let + fun tryCol j = + if threatened (i, j) b then 0 else search (i+1) ((i,j)::b) + in + if i >= 3 then + (* if we're already a few levels deep, then just go sequential *) + Seq.iterate op+ 0 (Seq.tabulate tryCol n) + else + Seq.reduce op+ 0 (Seq.tabulate tryCol n) + end + in + search 0 [] + end + +val n = CommandLineArgs.parseInt "n" 13 +val _ = print ("n " ^ Int.toString n ^ "\n") + +val msg = + "counting number of " ^ Int.toString n ^ "x" ^ Int.toString n ^ " solutions" + +val result = Benchmark.run msg (fn _ => countSol n) + +val _ = print ("result " ^ Int.toString result ^ "\n") + diff --git a/tests/bench/ocaml-binarytrees5/main.sml b/tests/bench/ocaml-binarytrees5/main.sml new file mode 100644 index 000000000..cd328b38f --- /dev/null +++ b/tests/bench/ocaml-binarytrees5/main.sml @@ -0,0 +1,90 @@ +structure CLA = CommandLineArgs + +val max_depth = CLA.parseInt "max_depth" 10 +val num_domains = Concurrency.numberOfProcessors + +val _ = print ("max_depth " ^ Int.toString max_depth ^ "\n") + +datatype tree = Empty | Node of tree * tree + +fun make d = + if d = 0 then Node (Empty, Empty) + else let val d = d-1 + in Node (make d, make d) + end + +fun check t = + case t of + Empty => 0 + | Node (l, r) => 1 + check l + check r + +val min_depth = 4 +val max_depth = Int.max (min_depth + 2, max_depth) +val stretch_depth = max_depth + 1 + +val _ = check (make stretch_depth) + +val long_lived_tree = make max_depth + +val values = + Array.array (num_domains, 0) + +fun calculate d st en ind = + let + val c = ref 0 + in + Util.for (st, en+1) (fn _ => + c := !c + check (make d) + ); + Array.update (values, ind, !c) + end + +fun calculate d st en ind = + let + val c = + Util.loop (st, en+1) 0 (fn (c, _) => + c + 
check (make d) + ) + in + Array.update (values, ind, c) + end + +fun parfor g (i, j) f = + if j-i <= 1 then + (** MPL relies on `par` for its GC policy. This particular benchmark + * happens to not call par on 1 processor, so let's fix that. + *) + (ForkJoin.par (fn _ => ForkJoin.parfor g (i, j) f, fn () => ()); ()) + else + ForkJoin.parfor g (i, j) f + + +fun loop_depths d = + Util.for (0, (max_depth - d) div 2 + 1) (fn i => + let + val d = d + i * 2 + val niter = Util.pow2 (max_depth - d + min_depth) + in + (* ocaml source does N-way async/await loop, but this is just + * a parallel for. *) + parfor 1 (0, num_domains) (fn i => + calculate d + (i * niter div num_domains) + (((i + 1) * niter div num_domains) - 1) + i); + + Array.foldl op+ 0 values; + + () + end) + +val result = Benchmark.run "running binary trees" (fn _ => + let + in + loop_depths min_depth; + check long_lived_tree + end) + +val _ = print ("result " ^ Int.toString result ^ "\n") +val _ = print ("values " ^ Util.summarizeArray 10 Int.toString values ^ "\n") + diff --git a/tests/bench/ocaml-binarytrees5/ocaml-binarytrees5.mlb b/tests/bench/ocaml-binarytrees5/ocaml-binarytrees5.mlb new file mode 100644 index 000000000..5f06290fb --- /dev/null +++ b/tests/bench/ocaml-binarytrees5/ocaml-binarytrees5.mlb @@ -0,0 +1,2 @@ +../../mpllib/sources.$(COMPAT).mlb +main.sml diff --git a/tests/bench/ocaml-binarytrees5/ocaml-source.ml b/tests/bench/ocaml-binarytrees5/ocaml-source.ml new file mode 100644 index 000000000..031bcd38b --- /dev/null +++ b/tests/bench/ocaml-binarytrees5/ocaml-source.ml @@ -0,0 +1,74 @@ +(* Copied from + * https://github.com/ocaml-bench/sandmark + * file benchmarks/multicore-numerical/binarytrees5_multicore.ml + * commit 861f568d869c95bc9aa8fc1fd90a13ab6cbe7afb + *) + +module T = Domainslib.Task + +let num_domains = try int_of_string Sys.argv.(1) with _ -> 1 +let max_depth = try int_of_string Sys.argv.(2) with _ -> 10 +let pool = T.setup_pool ~num_domains:(num_domains - 1) + +type 'a tree 
= Empty | Node of 'a tree * 'a tree + +let rec make d = +(* if d = 0 then Empty *) + if d = 0 then Node(Empty, Empty) + else let d = d - 1 in Node(make d, make d) + +let rec check t = + Domain.Sync.poll (); + match t with + | Empty -> 0 + | Node(l, r) -> 1 + check l + check r + +let min_depth = 4 +let max_depth = max (min_depth + 2) max_depth +let stretch_depth = max_depth + 1 + +let () = + (* Gc.set { (Gc.get()) with Gc.minor_heap_size = 1024 * 1024; max_overhead = -1; }; *) + let _ = check (make stretch_depth) in + () + (* Printf.printf "stretch tree of depth %i\t check: %i\n" stretch_depth c *) + +let long_lived_tree = make max_depth + +let values = Array.make num_domains 0 + +let calculate d st en ind = + (* Printf.printf "st = %d en = %d\n" st en; *) + let c = ref 0 in + for _ = st to en do + c := !c + check (make d) + done; + (* Printf.printf "ind = %d\n" ind; *) + values.(ind) <- !c + +let loop_depths d = + for i = 0 to ((max_depth - d) / 2 + 1) - 1 do + let d = d + i * 2 in + let niter = 1 lsl (max_depth - d + min_depth) in + let rec loop acc i num_domains = + if i = num_domains then begin + List.rev acc |> List.iter (fun pr -> T.await pool pr) + end else begin + loop + ((T.async pool (fun _ -> + calculate d (i * niter / num_domains) (((i + 1) * niter / num_domains) - 1) i)) :: acc) + (i + 1) + num_domains + end in + + loop [] 0 num_domains; + let _ = Array.fold_left (+) 0 values in + () + (* Printf.printf "%i\t trees of depth %i\t check: %i\n" niter d sum *) + done + +let () = + loop_depths min_depth; + let _ = max_depth in + let _ = (check long_lived_tree) in + T.teardown_pool pool diff --git a/tests/bench/ocaml-game-of-life-pure/main.sml b/tests/bench/ocaml-game-of-life-pure/main.sml new file mode 100644 index 000000000..956aa4c6a --- /dev/null +++ b/tests/bench/ocaml-game-of-life-pure/main.sml @@ -0,0 +1,114 @@ +structure CLA = CommandLineArgs +val n_times = CLA.parseInt "n_times" 2 +val board_size = CLA.parseInt "board_size" 1024 +val _ = print 
("n_times " ^ Int.toString n_times ^ "\n") +val _ = print ("board_size " ^ Int.toString board_size ^ "\n") + +fun randBool pos = + Util.hash pos mod 2 + +val bs = board_size + +val g = + PureSeq.tabulate (fn i => PureSeq.tabulate (fn j => randBool (i*bs + j)) bs) bs + +fun get g x y = + PureSeq.nth (PureSeq.nth g x) y + handle _ => 0 + +fun neighbourhood g x y = + (get g (x-1) (y-1)) + + (get g (x-1) (y )) + + (get g (x-1) (y+1)) + + (get g (x ) (y-1)) + + (get g (x ) (y+1)) + + (get g (x+1) (y-1)) + + (get g (x+1) (y )) + + (get g (x+1) (y+1)) + +fun next_cell g x y = + let + val n = neighbourhood g x y + in + (* Why not just write it like this?? + case (Seq.nth (Seq.nth g x) y, n) of + (1, 2) => 1 (* lives *) + | (1, 3) => 1 (* lives *) + | (0, 3) => 1 (* get birth *) + | _ => 0 + *) + + (* I could enable MLton or-patterns, but whatever *) + case (PureSeq.nth (PureSeq.nth g x) y, n) of + (1, 0) => 0 (* lonely *) + | (1, 1) => 0 (* lonely *) + | (1, 4) => 0 (* overcrowded *) + | (1, 5) => 0 (* overcrowded *) + | (1, 6) => 0 (* overcrowded *) + | (1, 7) => 0 (* overcrowded *) + | (1, 8) => 0 (* overcrowded *) + | (1, 2) => 1 (* lives *) + | (1, 3) => 1 (* lives *) + | (0, 3) => 1 (* get birth *) + | _ (* 0, (0|1|2|4|5|6|7|8) *) => 0 (* barren *) + + (* With or-patterns it would look like: + case (Seq.nth (Seq.nth g x) y, n) of + (1, 0) | (1, 1) => 0 (* lonely *) + | (1, 4) | (1, 5) | (1, 6) | (1, 7) | (1, 8) => 0 (* overcrowded *) + | (1, 2) | (1, 3) => 1 (* lives *) + | (0, 3) => 1 (* get birth *) + | _ (* 0, (0|1|2|4|5|6|7|8) *) => 0 (* barren *) + *) + end + +fun loop curr remaining = + if remaining <= 0 then + curr + else + let + (* SAM_NOTE: this is my own granularity control. The ocaml source does + * static partitioning based on num_domains, but this is unnecessary. + * Just choose a static granularity that is reasonable, and then it will + * work decently for any number of processors. 
*) + val target_granularity = 5000 + val chunk_size = Int.max (1, target_granularity div board_size) + + val next = + PureSeq.tabulateG chunk_size (fn x => + PureSeq.tabulate (fn y => next_cell curr x y) board_size) + board_size + in + loop next (remaining-1) + end + +val msg = "doing " ^ Int.toString n_times ^ " iterations" +val result = Benchmark.run msg (fn _ => loop g n_times) + +(* =========================================================================== + * SAM_NOTE: rest is my stuff. Just outputting the result. + *) + +val output = CLA.parseString "output" "" +val _ = + if output = "" then + print ("use -output XXX.ppm to see result\n") + else + let + val g = result + + fun color 0 = Color.white + | color _ = Color.black + + val image = + { height = board_size + , width = board_size + , data = Seq.tabulate (fn k => + color (get g (k div board_size) (k mod board_size))) + (board_size * board_size) + } + val (_, tm) = Util.getTime (fn _ => PPM.write output image) + in + print ("wrote output in " ^ Time.fmt 4 tm ^ "s\n") + end + diff --git a/tests/bench/ocaml-game-of-life-pure/ocaml-game-of-life-pure.mlb b/tests/bench/ocaml-game-of-life-pure/ocaml-game-of-life-pure.mlb new file mode 100644 index 000000000..5f06290fb --- /dev/null +++ b/tests/bench/ocaml-game-of-life-pure/ocaml-game-of-life-pure.mlb @@ -0,0 +1,2 @@ +../../mpllib/sources.$(COMPAT).mlb +main.sml diff --git a/tests/bench/ocaml-game-of-life/main.sml b/tests/bench/ocaml-game-of-life/main.sml new file mode 100644 index 000000000..25da1ee9b --- /dev/null +++ b/tests/bench/ocaml-game-of-life/main.sml @@ -0,0 +1,231 @@ +(* ocaml source: + * +## let num_domains = try int_of_string Sys.argv.(1) with _ -> 1 +## let n_times = try int_of_string Sys.argv.(2) with _ -> 2 +## let board_size = try int_of_string Sys.argv.(3) with _ -> 1024 + *) +structure CLA = CommandLineArgs +val n_times = CLA.parseInt "n_times" 2 +val board_size = CLA.parseInt "board_size" 1024 +val _ = print ("n_times " ^ Int.toString 
n_times ^ "\n") +val _ = print ("board_size " ^ Int.toString board_size ^ "\n") + +(* ocaml source: + * +## module T = Domainslib.Task +## +## let rg = +## ref (Array.init board_size (fun _ -> +## Array.init board_size (fun _ -> Random.int 2))) +## let rg' = +## ref (Array.init board_size (fun _ -> +## Array.init board_size (fun _ -> Random.int 2))) +## let buf = Bytes.create board_size + * + * The buf is not used. + * + * Most obvious adaptation is just nested sequences. We'll use + * a hash function to seed the initial state. + *) + +fun randBool pos = + Util.hash pos mod 2 + +val bs = board_size + +fun vtab f n = Vector.tabulate (n, f) +fun vnth v i = Vector.sub (v, i) + +fun atab f n = Array.tabulate (n, f) +fun anth a i = Array.sub (a, i) + +val rg = + (*ref*) (vtab (fn i => atab (fn j => randBool (i*bs + j)) bs) bs) +val rg' = + (*ref*) (vtab (fn i => atab (fn j => randBool (bs*bs + i*bs + j)) bs) bs) + +(* ocaml source: + * +## let get g x y = +## try g.(x).(y) +## with _ -> 0 +## +## let neighbourhood g x y = +## (get g (x-1) (y-1)) + +## (get g (x-1) (y )) + +## (get g (x-1) (y+1)) + +## (get g (x ) (y-1)) + +## (get g (x ) (y+1)) + +## (get g (x+1) (y-1)) + +## (get g (x+1) (y )) + +## (get g (x+1) (y+1)) +## +## let next_cell g x y = +## let n = neighbourhood g x y in +## match g.(x).(y), n with +## | 1, 0 | 1, 1 -> 0 (* lonely *) +## | 1, 4 | 1, 5 | 1, 6 | 1, 7 | 1, 8 -> 0 (* overcrowded *) +## | 1, 2 | 1, 3 -> 1 (* lives *) +## | 0, 3 -> 1 (* get birth *) +## | _ (* 0, (0|1|2|4|5|6|7|8) *) -> 0 (* barren *) + *) + +fun get g x y = + anth (vnth g x) y + handle _ => 0 + +(* fun neighbourhood g x y = + (get g (x-1) (y-1)) + + (get g (x-1) (y )) + + (get g (x-1) (y+1)) + + (get g (x ) (y-1)) + + (get g (x ) (y+1)) + + (get g (x+1) (y-1)) + + (get g (x+1) (y )) + + (get g (x+1) (y+1)) *) + +fun neighbourhood g x y = + let + fun get_element s y = + anth s y + handle _ => 0 + fun sum_row(x, y) = + let + val gx = vnth g x + in + (get_element gx (y-1)) + 
(get_element gx y) + (get_element gx (y+1)) + end + handle _ => 0 + in + sum_row(x-1, y) + sum_row(x, y) + sum_row(x + 1, y) + end + +fun next_cell g x y = + let + val n = neighbourhood g x y + in + (* Why not just write it like this?? + case (Seq.nth (Seq.nth g x) y, n) of + (1, 2) => 1 (* lives *) + | (1, 3) => 1 (* lives *) + | (0, 3) => 1 (* get birth *) + | _ => 0 + *) + + (* I could enable MLton or-patterns, but whatever *) + case (anth (vnth g x) y, n) of + (1, 0) => 0 (* lonely *) + | (1, 1) => 0 (* lonely *) + | (1, 4) => 0 (* overcrowded *) + | (1, 5) => 0 (* overcrowded *) + | (1, 6) => 0 (* overcrowded *) + | (1, 7) => 0 (* overcrowded *) + | (1, 8) => 0 (* overcrowded *) + | (1, 2) => 1 (* lives *) + | (1, 3) => 1 (* lives *) + | (0, 3) => 1 (* get birth *) + | _ (* 0, (0|1|2|4|5|6|7|8) *) => 0 (* barren *) + + (* With or-patterns it would look like: + case (Seq.nth (Seq.nth g x) y, n) of + (1, 0) | (1, 1) => 0 (* lonely *) + | (1, 4) | (1, 5) | (1, 6) | (1, 7) | (1, 8) => 0 (* overcrowded *) + | (1, 2) | (1, 3) => 1 (* lives *) + | (0, 3) => 1 (* get birth *) + | _ (* 0, (0|1|2|4|5|6|7|8) *) => 0 (* barren *) + *) + end + +(* ocaml source: + * +## (* let print g = +## for x = 0 to board_size - 1 do +## for y = 0 to board_size - 1 do +## if g.(x).(y) = 0 +## then Bytes.set buf y '.' 
+## else Bytes.set buf y 'o' +## done; +## print_endline (Bytes.unsafe_to_string buf) +## done; +## print_endline "" *) +## +## let next pool = +## let g = !rg in +## let new_g = !rg' in +## T.parallel_for pool ~chunk_size:(board_size/num_domains) ~start:0 +## ~finish:(board_size - 1) ~body:(fun x -> +## for y = 0 to board_size - 1 do +## new_g.(x).(y) <- next_cell g x y +## done); +## rg := new_g; +## rg' := g +## +## +## let rec repeat pool n = +## match n with +## | 0-> () +## | _-> next pool; repeat pool (n-1) +## +## let ()= +## let pool = T.setup_pool ~num_domains:(num_domains - 1) in +## (* print !rg; *) +## repeat pool n_times; +## (* print !rg; *) +## T.teardown_pool pool + *) + +fun next (g, new_g) = + let + (* val g = !rg + val new_g = !rg' *) + + (* SAM_NOTE: this is my own granularity control. The ocaml source does + * static partitioning based on num_domains, but this is unnecessary. + * Just choose a static granularity that is reasonable, and then it will + * work decently for any number of processors. *) + val target_granularity = 10000 + val chunk_size = Int.max (1, target_granularity div board_size) + in + ForkJoin.parfor chunk_size (0, board_size) (fn x => + Util.for (0, board_size) (fn y => + Array.update (vnth new_g x, y, next_cell g x y))); + + (new_g, g) + end + +fun repeat state n = + case n of + 0 => state + | _ => repeat (next state) (n-1) + +val msg = "doing " ^ Int.toString n_times ^ " iterations" +val (result, _) = Benchmark.run msg (fn _ => repeat (rg, rg') n_times) + +(* =========================================================================== + * SAM_NOTE: rest is my stuff. Just outputting the result. 
+ *) + +val output = CLA.parseString "output" "" +val _ = + if output = "" then + print ("use -output XXX to see result\n") + else + let + (* val g = !rg *) + val g = result + + fun color 0 = Color.white + | color _ = Color.black + + val image = + { height = board_size + , width = board_size + , data = Seq.tabulate (fn k => + color (get g (k div board_size) (k mod board_size))) + (board_size * board_size) + } + val (_, tm) = Util.getTime (fn _ => PPM.write output image) + in + print ("wrote output in " ^ Time.fmt 4 tm ^ "s\n") + end + diff --git a/tests/bench/ocaml-game-of-life/ocaml-game-of-life.mlb b/tests/bench/ocaml-game-of-life/ocaml-game-of-life.mlb new file mode 100644 index 000000000..5f06290fb --- /dev/null +++ b/tests/bench/ocaml-game-of-life/ocaml-game-of-life.mlb @@ -0,0 +1,2 @@ +../../mpllib/sources.$(COMPAT).mlb +main.sml diff --git a/tests/bench/ocaml-game-of-life/ocaml-source.ml b/tests/bench/ocaml-game-of-life/ocaml-source.ml new file mode 100644 index 000000000..3c9df39d9 --- /dev/null +++ b/tests/bench/ocaml-game-of-life/ocaml-source.ml @@ -0,0 +1,75 @@ +(* Copied from + * https://github.com/ocaml-bench/sandmark + * directory benchmarks/multicore-numerical/game_of_life_multicore.ml + * commit 23a73b28c803763cfb7c8282d2a847261cb7d4a9 + *) + +let num_domains = try int_of_string Sys.argv.(1) with _ -> 1 +let n_times = try int_of_string Sys.argv.(2) with _ -> 2 +let board_size = try int_of_string Sys.argv.(3) with _ -> 1024 + +module T = Domainslib.Task + +let rg = + ref (Array.init board_size (fun _ -> Array.init board_size (fun _ -> Random.int 2))) +let rg' = + ref (Array.init board_size (fun _ -> Array.init board_size (fun _ -> Random.int 2))) +let buf = Bytes.create board_size + +let get g x y = + try g.(x).(y) + with _ -> 0 + +let neighbourhood g x y = + (get g (x-1) (y-1)) + + (get g (x-1) (y )) + + (get g (x-1) (y+1)) + + (get g (x ) (y-1)) + + (get g (x ) (y+1)) + + (get g (x+1) (y-1)) + + (get g (x+1) (y )) + + (get g (x+1) (y+1)) + +let 
next_cell g x y = + let n = neighbourhood g x y in + match g.(x).(y), n with + | 1, 0 | 1, 1 -> 0 (* lonely *) + | 1, 4 | 1, 5 | 1, 6 | 1, 7 | 1, 8 -> 0 (* overcrowded *) + | 1, 2 | 1, 3 -> 1 (* lives *) + | 0, 3 -> 1 (* get birth *) + | _ (* 0, (0|1|2|4|5|6|7|8) *) -> 0 (* barren *) + +(* let print g = + for x = 0 to board_size - 1 do + for y = 0 to board_size - 1 do + if g.(x).(y) = 0 + then Bytes.set buf y '.' + else Bytes.set buf y 'o' + done; + print_endline (Bytes.unsafe_to_string buf) + done; + print_endline "" *) + +let next pool = + let g = !rg in + let new_g = !rg' in + T.parallel_for pool ~chunk_size:(board_size/num_domains) ~start:0 + ~finish:(board_size - 1) ~body:(fun x -> + for y = 0 to board_size - 1 do + new_g.(x).(y) <- next_cell g x y + done); + rg := new_g; + rg' := g + + +let rec repeat pool n = + match n with + | 0-> () + | _-> next pool; repeat pool (n-1) + +let ()= + let pool = T.setup_pool ~num_domains:(num_domains - 1) in + (* print !rg; *) + repeat pool n_times; + (* print !rg; *) + T.teardown_pool pool diff --git a/tests/bench/ocaml-lu-decomp/main.sml b/tests/bench/ocaml-lu-decomp/main.sml new file mode 100644 index 000000000..72728bbe7 --- /dev/null +++ b/tests/bench/ocaml-lu-decomp/main.sml @@ -0,0 +1,212 @@ +(* ocaml source: + * +## module T = Domainslib.Task +## let num_domains = try int_of_string Sys.argv.(1) with _ -> 1 +## let mat_size = try int_of_string Sys.argv.(2) with _ -> 1200 +## let chunk_size = try int_of_string Sys.argv.(3) with _ -> 16 + *) + +structure CLA = CommandLineArgs +val mat_size = CLA.parseInt "n" 1200 +val chunk_size = CLA.parseInt "chunk_size" 16 +val _ = print ("n" ^ Int.toString mat_size ^ "\n") +val _ = print ("chunk_size " ^ Int.toString chunk_size ^ "\n") + +(* ocaml source: + * +## +## module SquareMatrix = struct +## +## let create f : float array = +## let fa = Array.create_float (mat_size * mat_size) in +## for i = 0 to mat_size * mat_size - 1 do +## fa.(i) <- f (i / mat_size) (i mod mat_size) +## 
done; +## fa +## let parallel_create pool f : float array = +## let fa = Array.create_float (mat_size * mat_size) in +## T.parallel_for pool ~chunk_size:(mat_size * mat_size / num_domains) ~start:0 +## ~finish:( mat_size * mat_size - 1) ~body:(fun i -> +## fa.(i) <- f (i / mat_size) (i mod mat_size)); +## fa +## +## let get (m : float array) r c = m.(r * mat_size + c) +## let set (m : float array) r c v = m.(r * mat_size + c) <- v +## let parallel_copy pool a = +## let n = Array.length a in +## let copy_part a b i = +## let s = (i * n / num_domains) in +## let e = (i+1) * n / num_domains - 1 in +## Array.blit a s b s (e - s + 1) in +## let b = Array.create_float n in +## let rec aux acc num_domains i = +## if (i = num_domains) then +## (List.iter (fun e -> T.await pool e) acc) +## else begin +## aux ((T.async pool (fun _ -> copy_part a b i))::acc) num_domains (i+1) +## end +## in +## aux [] num_domains 0; +## b +## end +*) + +structure SquareMatrix = +struct + fun create f: real array = + let + val fa = ForkJoin.alloc (mat_size * mat_size) + in + Util.for (0, mat_size * mat_size) (fn i => + Array.update (fa, i, f (i div mat_size, i mod mat_size))); + fa + end + + fun parallel_create f: real array = + let + val fa = ForkJoin.alloc (mat_size * mat_size) + in + ForkJoin.parfor 10000 (0, mat_size * mat_size) (fn i => + Array.update (fa, i, f (i div mat_size, i mod mat_size))); + fa + end + + fun get m r c = Array.sub (m, r * mat_size + c) + fun set m r c v = Array.update (m, r * mat_size + c, v) + + (* SAM_NOTE: This function is a bit overengineered in the ocaml source in + * a way that is probably negatively impacting performance. The easiest way + * to do it would just be this: + * + * fun parallel_copy a = + * let + * val n = Array.length a + * val b = allocate n + * in + * parfor GRAIN (0, n) (fn i => Array.update (b, i, Array.sub (a, i))); + * b + * end + * + * But, rather than implement it like this, I'll try to be faithful to the + * original ocaml code here. 
+ * + * The ocaml code for this function uses futures (async/await), which MPL + * doesn't support. But using futures is unnecessary, because the ocaml code + * just does uses them in fork-join style anyway! Also, the ocaml code + * relies upon knowing the number of processors to do granularity control, + * but this is unnecessary. Instead we can choose a static GRAIN and then + * divide the array into ceil(n/GRAIN) parts. + * + * The ocaml source seems like it has a bug, if n is not divisible + * by num_domains. Easy fix is below, when calculating the end of the + * part (variable e below). + *) + fun parallel_copy a = + let + val n = Array.length a + val GRAIN = 10000 (* same role as n/num_domains in ocaml source *) + fun copy_part a b i = + let + val s = i * GRAIN + val e = Int.min (n, (i+1) * GRAIN) (* fixed bug! take min. *) + in + Util.for (s, e) (fn j => Array.update (b, j, Array.sub (a, j))) + end + val num_parts = Util.ceilDiv n GRAIN + val b = ForkJoin.alloc n + in + ForkJoin.parfor 1 (0, num_parts) (copy_part a b); + b + end +end + +(* ocaml source: + * +## +## open SquareMatrix +## +## let lup pool (a0 : float array) = +## let a = parallel_copy pool a0 in +## for k = 0 to (mat_size - 2) do +## T.parallel_for pool ~chunk_size:chunk_size ~start:(k + 1) ~finish:(mat_size -1) +## ~body:(fun row -> +## let factor = get a row k /. get a k k in +## for col = k + 1 to mat_size-1 do +## set a row col (get a row col -. factor *. 
(get a k col)) +## done; +## set a row k factor ) +## done ; +## a + *) + +open SquareMatrix + +fun lup a0 = + let + val a = parallel_copy a0 + in + Util.for (0, mat_size-1) (fn k => + ForkJoin.parfor chunk_size (k+1, mat_size) (fn row => + let + val factor = get a row k / get a k k + in + Util.for (k+1, mat_size) (fn col => + set a row col (get a row col - factor * (get a k col))); + set a row k factor + end)); + a + end + +(* ocaml source: + * +## let () = +## let pool = T.setup_pool ~num_domains:(num_domains - 1) in +## let a = create (fun _ _ -> (Random.float 100.0) +. 1.0 ) in +## let lu = lup pool a in +## let _l = parallel_create pool (fun i j -> if i > j then get lu i j else if i = j then 1.0 else 0.0) in +## let _u = parallel_create pool (fun i j -> if i <= j then get lu i j else 0.0) in +## T.teardown_pool pool + *) + +(* SAM_NOTE: It seems like the ocaml source chose to initialize sequentially + * because of the stateful RNG. We'll do a PRNG based on a hash function, to + * be safe for parallelism, and initialize in parallel. 
+ *) +(* +val rand = Random.rand (15, 210) (* seed the generator *) +fun randReal bound = + bound * Random.randReal rand +*) + +fun randReal bound seed = + bound * (Real.fromInt (Util.hash seed mod 1000001) / 1000000.0) + +val (a, tm) = Util.getTime (fn _ => + create (fn (i, j) => 1.0 + randReal 100.0 (i * mat_size + j))) +val _ = print ("generated input in " ^ Time.fmt 4 tm ^ "s\n") + +fun ludEx () = + let + val (lu, tm) = Util.getTime (fn _ => lup a) + val _ = print ("main decomposition in " ^ Time.fmt 4 tm ^ "s\n") + + val ((l, u), tm) = Util.getTime (fn _ => + let + val l = + parallel_create (fn (i, j) => + if i > j then get lu i j + else if i = j then 1.0 + else 0.0) + val u = + parallel_create (fn (i, j) => + if i <= j then get lu i j else 0.0) + in + (l, u) + end) + val _ = print ("extracted L and U in " ^ Time.fmt 4 tm ^ "s\n") + in + (l, u) + end + +val _ = Benchmark.run "running LU decomposition" ludEx + diff --git a/tests/bench/ocaml-lu-decomp/ocaml-lu-decomp.mlb b/tests/bench/ocaml-lu-decomp/ocaml-lu-decomp.mlb new file mode 100644 index 000000000..4942dcdd5 --- /dev/null +++ b/tests/bench/ocaml-lu-decomp/ocaml-lu-decomp.mlb @@ -0,0 +1,7 @@ +../../mpllib/sources.$(COMPAT).mlb +(*local + $(SML_LIB)/smlnj-lib/Util/smlnj-lib.mlb +in + structure Random +end*) +main.sml diff --git a/tests/bench/ocaml-lu-decomp/ocaml-source.ml b/tests/bench/ocaml-lu-decomp/ocaml-source.ml new file mode 100644 index 000000000..d5bda24cc --- /dev/null +++ b/tests/bench/ocaml-lu-decomp/ocaml-source.ml @@ -0,0 +1,62 @@ +module T = Domainslib.Task +let num_domains = try int_of_string Sys.argv.(1) with _ -> 1 +let mat_size = try int_of_string Sys.argv.(2) with _ -> 1200 +let chunk_size = try int_of_string Sys.argv.(3) with _ -> 16 + +module SquareMatrix = struct + + let create f : float array = + let fa = Array.create_float (mat_size * mat_size) in + for i = 0 to mat_size * mat_size - 1 do + fa.(i) <- f (i / mat_size) (i mod mat_size) + done; + fa + let parallel_create pool f : 
float array = + let fa = Array.create_float (mat_size * mat_size) in + T.parallel_for pool ~chunk_size:(mat_size * mat_size / num_domains) ~start:0 + ~finish:( mat_size * mat_size - 1) ~body:(fun i -> + fa.(i) <- f (i / mat_size) (i mod mat_size)); + fa + + let get (m : float array) r c = m.(r * mat_size + c) + let set (m : float array) r c v = m.(r * mat_size + c) <- v + let parallel_copy pool a = + let n = Array.length a in + let copy_part a b i = + let s = (i * n / num_domains) in + let e = (i+1) * n / num_domains - 1 in + Array.blit a s b s (e - s + 1) in + let b = Array.create_float n in + let rec aux acc num_domains i = + if (i = num_domains) then + (List.iter (fun e -> T.await pool e) acc) + else begin + aux ((T.async pool (fun _ -> copy_part a b i))::acc) num_domains (i+1) + end + in + aux [] num_domains 0; + b +end + +open SquareMatrix + +let lup pool (a0 : float array) = + let a = parallel_copy pool a0 in + for k = 0 to (mat_size - 2) do + T.parallel_for pool ~chunk_size:chunk_size ~start:(k + 1) ~finish:(mat_size -1) + ~body:(fun row -> + let factor = get a row k /. get a k k in + for col = k + 1 to mat_size-1 do + set a row col (get a row col -. factor *. (get a k col)) + done; + set a row k factor ) + done ; + a + +let () = + let pool = T.setup_pool ~num_domains:(num_domains - 1) in + let a = create (fun _ _ -> (Random.float 100.0) +. 1.0 ) in + let lu = lup pool a in + let _l = parallel_create pool (fun i j -> if i > j then get lu i j else if i = j then 1.0 else 0.0) in + let _u = parallel_create pool (fun i j -> if i <= j then get lu i j else 0.0) in + T.teardown_pool pool diff --git a/tests/bench/ocaml-mandelbrot/main.sml b/tests/bench/ocaml-mandelbrot/main.sml new file mode 100644 index 000000000..87d14732b --- /dev/null +++ b/tests/bench/ocaml-mandelbrot/main.sml @@ -0,0 +1,187 @@ +structure CLA = CommandLineArgs + +(* ocaml source: + * +## let niter = 50 +## let limit = 4. 
+## +## let num_domains = int_of_string (Array.get Sys.argv 1) +## let w = int_of_string (Array.get Sys.argv 2) + * + *) + +val niter = 50 +val limit = 4.0 + +val w = CLA.parseInt "w" 16000 +val _ = print ("w " ^ Int.toString w ^ "\n") + +(* ocaml source: + * +## +## let worker w h_lo h_hi = +## let buf = +## Bytes.create ((w / 8 + (if w mod 8 > 0 then 1 else 0)) * (h_hi - h_lo)) +## and ptr = ref 0 in +## let fw = float w /. 2. in +## let fh = fw in +## let red_w = w - 1 and red_h_hi = h_hi - 1 and byte = ref 0 in +## for y = h_lo to red_h_hi do +## let ci = float y /. fh -. 1. in +## for x = 0 to red_w do +## let cr = float x /. fw -. 1.5 +## and zr = ref 0. and zi = ref 0. and trmti = ref 0. and n = ref 0 in +## begin try +## while true do +## Domain.Sync.poll (); +## zi := 2. *. !zr *. !zi +. ci; +## zr := !trmti +. cr; +## let tr = !zr *. !zr and ti = !zi *. !zi in +## if tr +. ti > limit then begin +## byte := !byte lsl 1; +## raise Exit +## end else if incr n; !n = niter then begin +## byte := (!byte lsl 1) lor 0x01; +## raise Exit +## end else +## trmti := tr -. 
ti +## done +## with Exit -> () +## end; +## if x mod 8 = 7 then begin +## Bytes.set buf !ptr (Char.chr !byte); +## incr ptr; +## byte := 0 +## end +## done; +## let rem = w mod 8 in +## if rem != 0 then begin +## Bytes.set buf !ptr (Char.chr (!byte lsl (8 - rem))); +## incr ptr; +## byte := 0 +## end +## done; +## buf + * + *) + +fun incr x = (x := !x + 1) + +fun worker w h_lo h_hi = + let + val buf = ForkJoin.alloc ((w div 8 + (if w mod 8 > 0 then 1 else 0)) * (h_hi - h_lo)) + val ptr = ref 0 + val fw = Real.fromInt w / 2.0 + val fh = fw + val byte = ref 0w0 + in + Util.for (h_lo, h_hi) (fn y => + let + val ci = Real.fromInt y / fh - 1.0 + in + (* print ("y=" ^ Int.toString y ^ "\n"); *) + Util.for (0, w) (fn x => + let + val cr = Real.fromInt x / fw - 1.5 + val zr = ref 0.0 + val zi = ref 0.0 + val trmti = ref 0.0 + val n = ref 0 + + fun loop () = + ( zi := 2.0 * !zr * !zi + ci + ; zr := !trmti + cr + ; let + val tr = !zr * !zr + val ti = !zi * !zi + in + if tr + ti > limit then + (byte := Word.<< (!byte, 0w1)) + else if (incr n; !n = niter) then + (byte := Word.orb (Word.<< (!byte, 0w1), 0wx01)) + else + (trmti := tr - ti; loop ()) + end + ) + in + (* print ("x=" ^ Int.toString x ^ "\n"); *) + loop (); + if x mod 8 = 7 then + ( Array.update (buf, !ptr, Char.chr (Word.toInt (!byte))) + ; incr ptr + ; byte := 0w0 + ) + else () + end); + + let + val rem = w mod 8 + in + if rem <> 0 then + ( Array.update (buf, !ptr, + Char.chr (Word.toInt (Word.<< (!byte, Word.fromInt (8 - rem))))) + ; incr ptr + ; byte := 0w0 + ) + else () + end + + end); + + buf + end + +(* ocaml source: + * +## let _ = +## let pool = T.setup_pool ~num_domains:(num_domains - 1) in +## let rows = w / num_domains and rem = w mod num_domains in +## Printf.printf "P4\n%i %i\n%!" 
w w; +## let work i () = +## worker w (i * rows + min i rem) ((i+1) * rows + min (i+1) rem) +## in +## let doms = +## Array.init (num_domains - 1) (fun i -> T.async pool (work i)) in +## let r = work (num_domains-1) () in +## Array.iter (fun d -> Printf.printf "%a%!" output_bytes (T.await pool d)) doms; +## Printf.printf "%a%!" output_bytes r; +## T.teardown_pool pool + *) + +(* GRAIN is the target work for one worker (in terms of number of pixels); + * this should be big enough to amortize the cost of parallelism. To match up + * with the original code, pick the maximum "number of domains" so that each + * domain has approximately at least GRAIN work to do (but cap this at some + * reasonably large number...) + *) +val GRAIN = 1000 +val num_domains = Int.min (500, Int.min (w, Util.ceilDiv (w * w) GRAIN)) + +val rows = w div num_domains +val rem = w mod num_domains + +fun work i = + worker w (i * rows + Int.min (i, rem)) ((i+1) * rows + Int.min (i+1, rem)) + +val results = + Benchmark.run "running mandelbrot" (fn _ => + SeqBasis.tabulate 1 (0, num_domains) work) + +val outfile = CLA.parseString "output" "" + +val _ = + if outfile = "" then + print ("use -output XXX to see result\n") + else + let + val file = TextIO.openOut outfile + fun dump1 c = TextIO.output1 (file, c) + fun dump str = TextIO.output (file, str) + in + ( dump "P4\n" + ; dump (Int.toString w ^ " " ^ Int.toString w ^ "\n") + ; Array.app (Array.app dump1) results + ; TextIO.closeOut file + ) + end + diff --git a/tests/bench/ocaml-mandelbrot/ocaml-mandelbrot.mlb b/tests/bench/ocaml-mandelbrot/ocaml-mandelbrot.mlb new file mode 100644 index 000000000..5f06290fb --- /dev/null +++ b/tests/bench/ocaml-mandelbrot/ocaml-mandelbrot.mlb @@ -0,0 +1,2 @@ +../../mpllib/sources.$(COMPAT).mlb +main.sml diff --git a/tests/bench/ocaml-mandelbrot/ocaml-source.ml b/tests/bench/ocaml-mandelbrot/ocaml-source.ml new file mode 100644 index 000000000..52536421e --- /dev/null +++ 
b/tests/bench/ocaml-mandelbrot/ocaml-source.ml @@ -0,0 +1,80 @@ +(* + * The Computer Language Benchmarks Game + * https://salsa.debian.org/benchmarksgame-team/benchmarksgame/ + * + * Contributed by Paolo Ribeca + * + * (Very loosely based on previous version Ocaml #3, + * which had been contributed by + * Christophe TROESTLER + * and enhanced by + * Christian Szegedy and Yaron Minsky) + * + * fix compile errors by using Bytes instead of String, by Tony Tavener + *) + +module T = Domainslib.Task + +let niter = 50 +let limit = 4. + +let num_domains = int_of_string (Array.get Sys.argv 1) +let w = int_of_string (Array.get Sys.argv 2) + +let worker w h_lo h_hi = + let buf = + Bytes.create ((w / 8 + (if w mod 8 > 0 then 1 else 0)) * (h_hi - h_lo)) + and ptr = ref 0 in + let fw = float w /. 2. in + let fh = fw in + let red_w = w - 1 and red_h_hi = h_hi - 1 and byte = ref 0 in + for y = h_lo to red_h_hi do + let ci = float y /. fh -. 1. in + for x = 0 to red_w do + let cr = float x /. fw -. 1.5 + and zr = ref 0. and zi = ref 0. and trmti = ref 0. and n = ref 0 in + begin try + while true do + Domain.Sync.poll (); + zi := 2. *. !zr *. !zi +. ci; + zr := !trmti +. cr; + let tr = !zr *. !zr and ti = !zi *. !zi in + if tr +. ti > limit then begin + byte := !byte lsl 1; + raise Exit + end else if incr n; !n = niter then begin + byte := (!byte lsl 1) lor 0x01; + raise Exit + end else + trmti := tr -. ti + done + with Exit -> () + end; + if x mod 8 = 7 then begin + Bytes.set buf !ptr (Char.chr !byte); + incr ptr; + byte := 0 + end + done; + let rem = w mod 8 in + if rem != 0 then begin + Bytes.set buf !ptr (Char.chr (!byte lsl (8 - rem))); + incr ptr; + byte := 0 + end + done; + buf + +let _ = + let pool = T.setup_pool ~num_domains:(num_domains - 1) in + let rows = w / num_domains and rem = w mod num_domains in + Printf.printf "P4\n%i %i\n%!" 
w w; + let work i () = + worker w (i * rows + min i rem) ((i+1) * rows + min (i+1) rem) + in + let doms = + Array.init (num_domains - 1) (fun i -> T.async pool (work i)) in + let r = work (num_domains-1) () in + Array.iter (fun d -> Printf.printf "%a%!" output_bytes (T.await pool d)) doms; + Printf.printf "%a%!" output_bytes r; + T.teardown_pool pool diff --git a/tests/bench/ocaml-nbody-imm/README b/tests/bench/ocaml-nbody-imm/README new file mode 100644 index 000000000..029bf5032 --- /dev/null +++ b/tests/bench/ocaml-nbody-imm/README @@ -0,0 +1,3 @@ +Difference from ocaml-nbody: + - switch from mutable fields in planet to fully immutable record + - switch from imperative main loop to purely functional diff --git a/tests/bench/ocaml-nbody-imm/main.sml b/tests/bench/ocaml-nbody-imm/main.sml new file mode 100644 index 000000000..b8ac55e4f --- /dev/null +++ b/tests/bench/ocaml-nbody-imm/main.sml @@ -0,0 +1,148 @@ +structure CLA = CommandLineArgs + +val num_domains = Concurrency.numberOfProcessors + +val n = CLA.parseInt "n" 500 +val num_bodies = CLA.parseInt "num_bodies" 1024 +val gran = CLA.parseInt "gran" 20 + +val pi = 3.141592653589793 +val solar_mass = 4.0 * pi * pi +val days_per_year = 365.24 + +type planet = + { x: real, y: real, z: real + , vx: real, vy: real, vz: real + , mass: real + } + +fun advance bodies dt = + let + fun velocity i = + let + val b = Array.sub (bodies, i) + val (vx, vy, vz) = (#vx b, #vy b, #vz b) + in + Util.loop (0, Array.length bodies) (vx, vy, vz) (fn ((vx, vy, vz), j) => + let + val b' = Array.sub (bodies, j) + in + if i <> j then + let + val dx = #x b - #x b' + val dy = #y b - #y b' + val dz = #z b - #z b' + val dist2 = dx * dx + dy * dy + dz * dz + val mag = dt / (dist2 * Math.sqrt(dist2)) + in + ( vx - dx * #mass b' * mag + , vy - dy * #mass b' * mag + , vz - dz * #mass b' * mag + ) + end + else (vx, vy, vz) + end) + end + + val velocities = PureSeq.tabulateG gran velocity num_bodies + in + Util.for (0, num_bodies) (fn i => + let + 
val b = Array.sub (bodies, i) + val (vx, vy, vz) = PureSeq.nth velocities i + in + Array.update (bodies, i, + { x = #x b + dt * vx + , y = #y b + dt * vy + , z = #z b + dt * vz + , vx = vx + , vy = vy + , vz = vz + , mass = #mass b + }) + end) + end + + +fun energy bodies = + let + in + SeqBasis.reduce 1 op+ 0.0 (0, Array.length bodies) (fn i => + let + val b = Array.sub (bodies, i) + val e = ref 0.0 + in + e := !e + 0.5 * #mass b * + (#vx b * #vx b + #vy b * #vy b + #vz b * #vz b); + + Util.for (i+1, Array.length bodies) (fn j => + let + val b' = Array.sub (bodies, j) + val dx = #x b - #x b' + val dy = #y b - #y b' + val dz = #z b - #z b' + val distance = Math.sqrt (dx * dx + dy * dy + dz * dz) + in + e := !e - (#mass b * #mass b') / distance + end); + + !e + end) + end + +fun offset_momentum bodies = + let + val px = ref 0.0 + val py = ref 0.0 + val pz = ref 0.0 + val b0 = Array.sub (bodies, 0) + in + Util.for (0, Array.length bodies) (fn i => + let + val b = Array.sub (bodies, i) + in + px := !px + #vx b * #mass b; + py := !py + #vy b * #mass b; + pz := !pz + #vz b * #mass b + end); + Array.update (bodies, 0, + { x = #x b0 + , y = #y b0 + , z = #z b0 + , vx = ~ (!px) / solar_mass + , vy = ~ (!py) / solar_mass + , vz = ~ (!pz) / solar_mass + , mass = #mass b0 + }) + end + +val seed = Random.rand (42, 15210) +fun randFloat _ bound = + bound * (Random.randReal seed) + +(* fun randFloat seed bound = + bound * (Real.fromInt (Util.hash seed mod 1000000000) / 1000000000.0) *) + +val bodies = + Array.tabulate (num_bodies, fn i => + let + val seed = 7*i + in + { x = randFloat seed 10.0 + , y = randFloat (seed+1) 10.0 + , z = randFloat (seed+2) 10.0 + , vx = randFloat (seed+3) 5.0 * days_per_year + , vy = randFloat (seed+4) 4.0 * days_per_year + , vz = randFloat (seed+5) 5.0 * days_per_year + , mass = randFloat (seed+6) 10.0 * solar_mass + } + end) + +val _ = offset_momentum bodies +val _ = print ("initial energy: " ^ Real.toString (energy bodies) ^ "\n") + +val _ = 
Benchmark.run "running simulation" (fn _ => + Util.for (0, n) (fn _ => advance bodies 0.01)) + +val _ = print ("final energy: " ^ Real.toString (energy bodies) ^ "\n") + diff --git a/tests/bench/ocaml-nbody-imm/ocaml-nbody-imm.mlb b/tests/bench/ocaml-nbody-imm/ocaml-nbody-imm.mlb new file mode 100644 index 000000000..4c9eb5f6a --- /dev/null +++ b/tests/bench/ocaml-nbody-imm/ocaml-nbody-imm.mlb @@ -0,0 +1,7 @@ +../../mpllib/sources.$(COMPAT).mlb +local + $(SML_LIB)/smlnj-lib/Util/smlnj-lib.mlb +in + structure Random +end +main.sml diff --git a/tests/bench/ocaml-nbody-packed/README b/tests/bench/ocaml-nbody-packed/README new file mode 100644 index 000000000..9e16289f4 --- /dev/null +++ b/tests/bench/ocaml-nbody-packed/README @@ -0,0 +1,5 @@ +difference from ocaml-nbody: + - explicit packing of planet type into a single array of floats + (this guarantees that flattening occurred, but after measuring + no performance improvement suggests that the other implementations + do actually manage to flatten.) diff --git a/tests/bench/ocaml-nbody-packed/main.sml b/tests/bench/ocaml-nbody-packed/main.sml new file mode 100644 index 000000000..0d7b08425 --- /dev/null +++ b/tests/bench/ocaml-nbody-packed/main.sml @@ -0,0 +1,156 @@ +structure CLA = CommandLineArgs + +val _ = print ("hello\n") + +val num_domains = Concurrency.numberOfProcessors + +val n = CLA.parseInt "n" 500 +val num_bodies = CLA.parseInt "num_bodies" 1024 + +val pi = 3.141592653589793 +val solar_mass = 4.0 * pi * pi +val days_per_year = 365.24 + +type planet = + { x: real ref, y: real ref, z: real ref + , vx: real ref, vy: real ref, vz: real ref + , mass: real + } + +(** manually packed. 
size = 7 * num bodies *) +type bodies = real array +fun get_x bodies i = Array.sub (bodies, 7*i) +fun get_y bodies i = Array.sub (bodies, 7*i + 1) +fun get_z bodies i = Array.sub (bodies, 7*i + 2) +fun get_vx bodies i = Array.sub (bodies, 7*i + 3) +fun get_vy bodies i = Array.sub (bodies, 7*i + 4) +fun get_vz bodies i = Array.sub (bodies, 7*i + 5) +fun get_mass bodies i = Array.sub (bodies, 7*i + 6) +fun set_x bodies i x = Array.update (bodies, 7*i, x) +fun set_y bodies i x = Array.update (bodies, 7*i + 1, x) +fun set_z bodies i x = Array.update (bodies, 7*i + 2, x) +fun set_vx bodies i x = Array.update (bodies, 7*i + 3, x) +fun set_vy bodies i x = Array.update (bodies, 7*i + 4, x) +fun set_vz bodies i x = Array.update (bodies, 7*i + 5, x) +fun set_mass bodies i x = Array.update (bodies, 7*i + 6, x) + + +fun advance bodies dt = + let + in + ForkJoin.parfor 1 (0, num_bodies) (fn i => + let + val (vx, vy, vz) = (get_vx bodies i, get_vy bodies i, get_vz bodies i) + val (vx, vy, vz) = + Util.loop (0, num_bodies) (vx, vy, vz) (fn ((vx, vy, vz), j) => + let + (* val b' = Array.sub (bodies, j) *) + in + if i <> j then + let + val dx = get_x bodies i - get_x bodies j + val dy = get_y bodies i - get_y bodies j + val dz = get_z bodies i - get_z bodies j + val dist2 = dx * dx + dy * dy + dz * dz + val mag = dt / (dist2 * Math.sqrt(dist2)) + in + ( vx - dx * get_mass bodies j * mag + , vy - dy * get_mass bodies j * mag + , vz - dz * get_mass bodies j * mag + ) + end + else (vx, vy, vz) + end); + in + set_vx bodies i vx; + set_vy bodies i vy; + set_vz bodies i vz + end); + + Util.for (0, num_bodies) (fn i => + let + (* val b = Array.sub (bodies, i) *) + in + set_x bodies i (get_x bodies i + dt * get_vx bodies i); + set_y bodies i (get_y bodies i + dt * get_vy bodies i); + set_z bodies i (get_z bodies i + dt * get_vz bodies i) + end) + end + +fun energy bodies = + let + in + SeqBasis.reduce 1 op+ 0.0 (0, num_bodies) (fn i => + let + (* val b = Array.sub (bodies, i) *) + val e = 
ref 0.0 + in + e := !e + 0.5 * get_mass bodies i * + (get_vx bodies i * get_vx bodies i + + get_vy bodies i * get_vy bodies i + + get_vz bodies i * get_vz bodies i); + + Util.for (i+1, num_bodies) (fn j => + let + (* val b' = Array.sub (bodies, j) *) + val dx = get_x bodies i - get_x bodies j + val dy = get_y bodies i - get_y bodies j + val dz = get_z bodies i - get_z bodies j + val distance = Math.sqrt (dx * dx + dy * dy + dz * dz) + in + e := !e - (get_mass bodies i * get_mass bodies j) / distance + end); + + !e + end) + end + +fun offset_momentum bodies = + let + val px = ref 0.0 + val py = ref 0.0 + val pz = ref 0.0 + in + Util.for (0, num_bodies) (fn i => + let + (* val b = Array.sub (bodies, i) *) + in + px := !px + get_vx bodies i * get_mass bodies i; + py := !py + get_vy bodies i * get_mass bodies i; + pz := !pz + get_vz bodies i * get_mass bodies i + end); + set_vx bodies 0 (~ (!px) / solar_mass); + set_vy bodies 0 (~ (!py) / solar_mass); + set_vz bodies 0 (~ (!pz) / solar_mass) + end + +val seed = Random.rand (42, 15210) +fun randFloat bound = + bound * (Random.randReal seed) + +val _ = print ("initializing bodies...\n") + +val bodies = Array.array (num_bodies * 7, 0.0); +val _ = + Util.for (0, num_bodies) (fn i => + ( set_x bodies i (randFloat 10.0) + ; set_y bodies i (randFloat 10.0) + ; set_z bodies i (randFloat 10.0) + ; set_vx bodies i (randFloat 5.0 * days_per_year) + ; set_vy bodies i (randFloat 4.0 * days_per_year) + ; set_vz bodies i (randFloat 5.0 * days_per_year) + ; set_mass bodies i (randFloat 10.0 * solar_mass) + ) + ) + +val _ = print ("offset momentum...\n"); +val _ = offset_momentum bodies + +val _ = print ("calculating initial energy...\n"); +val _ = print ("initial energy: " ^ Real.toString (energy bodies) ^ "\n") + +val _ = Benchmark.run "running simulation" (fn _ => + Util.for (0, n) (fn _ => advance bodies 0.01)) + +val _ = print ("final energy: " ^ Real.toString (energy bodies) ^ "\n") + diff --git 
a/tests/bench/ocaml-nbody-packed/ocaml-nbody-packed.mlb b/tests/bench/ocaml-nbody-packed/ocaml-nbody-packed.mlb new file mode 100644 index 000000000..4c9eb5f6a --- /dev/null +++ b/tests/bench/ocaml-nbody-packed/ocaml-nbody-packed.mlb @@ -0,0 +1,7 @@ +../../mpllib/sources.$(COMPAT).mlb +local + $(SML_LIB)/smlnj-lib/Util/smlnj-lib.mlb +in + structure Random +end +main.sml diff --git a/tests/bench/ocaml-nbody/main.sml b/tests/bench/ocaml-nbody/main.sml new file mode 100644 index 000000000..589b33051 --- /dev/null +++ b/tests/bench/ocaml-nbody/main.sml @@ -0,0 +1,128 @@ +structure CLA = CommandLineArgs + +val num_domains = Concurrency.numberOfProcessors + +val n = CLA.parseInt "n" 500 +val num_bodies = CLA.parseInt "num_bodies" 1024 + +val pi = 3.141592653589793 +val solar_mass = 4.0 * pi * pi +val days_per_year = 365.24 + +type planet = + { x: real ref, y: real ref, z: real ref + , vx: real ref, vy: real ref, vz: real ref + , mass: real + } + +fun advance bodies dt = + let + in + ForkJoin.parfor 20 (0, num_bodies) (fn i => + let + val b = Array.sub (bodies, i) + val (vx, vy, vz) = (ref (!(#vx b)), ref (!(#vy b)), ref (!(#vz b))) + in + Util.for (0, Array.length bodies) (fn j => + let + val b' = Array.sub (bodies, j) + in + if i <> j then + let + val dx = !(#x b) - !(#x b') + val dy = !(#y b) - !(#y b') + val dz = !(#z b) - !(#z b') + val dist2 = dx * dx + dy * dy + dz * dz + val mag = dt / (dist2 * Math.sqrt(dist2)) + in + vx := !vx - dx * #mass b' * mag; + vy := !vy - dy * #mass b' * mag; + vz := !vz - dz * #mass b' * mag + end + else () + end); + + #vx b := !vx; + #vy b := !vy; + #vz b := !vz + end); + + Util.for (0, num_bodies) (fn i => + let + val b = Array.sub (bodies, i) + in + #x b := !(#x b) + dt * !(#vx b); + #y b := !(#y b) + dt * !(#vy b); + #z b := !(#z b) + dt * !(#vz b) + end) + + end + +fun energy bodies = + let + in + SeqBasis.reduce 1 op+ 0.0 (0, Array.length bodies) (fn i => + let + val b = Array.sub (bodies, i) + val e = ref 0.0 + in + e := !e + 0.5 
* #mass b * + (!(#vx b) * !(#vx b) + !(#vy b) * !(#vy b) + !(#vz b) * !(#vz b)); + + Util.for (i+1, Array.length bodies) (fn j => + let + val b' = Array.sub (bodies, j) + val dx = !(#x b) - !(#x b') + val dy = !(#y b) - !(#y b') + val dz = !(#z b) - !(#z b') + val distance = Math.sqrt (dx * dx + dy * dy + dz * dz) + in + e := !e - (#mass b * #mass b') / distance + end); + + !e + end) + end + +fun offset_momentum bodies = + let + val px = ref 0.0 + val py = ref 0.0 + val pz = ref 0.0 + in + Util.for (0, Array.length bodies) (fn i => + let + val b = Array.sub (bodies, i) + in + px := !px + !(#vx b) * #mass b; + py := !py + !(#vy b) * #mass b; + pz := !pz + !(#vz b) * #mass b + end); + #vx (Array.sub (bodies, 0)) := ~ (!px) / solar_mass; + #vy (Array.sub (bodies, 0)) := ~ (!py) / solar_mass; + #vz (Array.sub (bodies, 0)) := ~ (!pz) / solar_mass + end + +val seed = Random.rand (42, 15210) +fun randFloat bound = + bound * (Random.randReal seed) + +val bodies = + Array.tabulate (num_bodies, fn _ => + { x = ref (randFloat 10.0) + , y = ref (randFloat 10.0) + , z = ref (randFloat 10.0) + , vx = ref (randFloat 5.0 * days_per_year) + , vy = ref (randFloat 4.0 * days_per_year) + , vz = ref (randFloat 5.0 * days_per_year) + , mass = randFloat 10.0 * solar_mass + }) + +val _ = offset_momentum bodies +val _ = print ("initial energy: " ^ Real.toString (energy bodies) ^ "\n") + +val _ = Benchmark.run "running simulation" (fn _ => + Util.for (0, n) (fn _ => advance bodies 0.01)) + +val _ = print ("final energy: " ^ Real.toString (energy bodies) ^ "\n") + diff --git a/tests/bench/ocaml-nbody/ocaml-nbody.mlb b/tests/bench/ocaml-nbody/ocaml-nbody.mlb new file mode 100644 index 000000000..4c9eb5f6a --- /dev/null +++ b/tests/bench/ocaml-nbody/ocaml-nbody.mlb @@ -0,0 +1,7 @@ +../../mpllib/sources.$(COMPAT).mlb +local + $(SML_LIB)/smlnj-lib/Util/smlnj-lib.mlb +in + structure Random +end +main.sml diff --git a/tests/bench/ocaml-nbody/ocaml-source.ml 
b/tests/bench/ocaml-nbody/ocaml-source.ml new file mode 100644 index 000000000..fa0af0d9e --- /dev/null +++ b/tests/bench/ocaml-nbody/ocaml-source.ml @@ -0,0 +1,101 @@ +(* Copied from + * https://github.com/ocaml-bench/sandmark + * file benchmarks/multicore-numerical/nbody_multicore.ml + * commit fc1d270db57db643031deb66a25dba9147904a05 + *) + +(* The Computer Language Benchmarks Game + * https://salsa.debian.org/benchmarksgame-team/benchmarksgame/ + * + * Contributed by Troestler Christophe + *) + +module T = Domainslib.Task + +let num_domains = try int_of_string Sys.argv.(1) with _ -> 1 +let n = try int_of_string Sys.argv.(2) with _ -> 500 +let num_bodies = try int_of_string Sys.argv.(3) with _ -> 1024 + +let pi = 3.141592653589793 +let solar_mass = 4. *. pi *. pi +let days_per_year = 365.24 + +type planet = { mutable x : float; mutable y : float; mutable z : float; + mutable vx: float; mutable vy: float; mutable vz: float; + mass : float } + +let advance pool bodies dt = + T.parallel_for pool + ~start:0 + ~finish:(num_bodies - 1) + ~body:(fun i -> + let b = bodies.(i) in + let vx, vy, vz = ref b.vx, ref b.vy, ref b.vz in + for j = 0 to Array.length bodies - 1 do + Domain.Sync.poll(); + let b' = bodies.(j) in + if (i!=j) then begin + let dx = b.x -. b'.x and dy = b.y -. b'.y and dz = b.z -. b'.z in + let dist2 = dx *. dx +. dy *. dy +. dz *. dz in + let mag = dt /. (dist2 *. sqrt(dist2)) in + vx := !vx -. dx *. b'.mass *. mag; + vy := !vy -. dy *. b'.mass *. mag; + vz := !vz -. dz *. b'.mass *. mag; + end + done; + b.vx <- !vx; + b.vy <- !vy; + b.vz <- !vz); + for i = 0 to num_bodies - 1 do + Domain.Sync.poll(); + let b = bodies.(i) in + b.x <- b.x +. dt *. b.vx; + b.y <- b.y +. dt *. b.vy; + b.z <- b.z +. dt *. b.vz; + done + +let energy pool bodies = + T.parallel_for_reduce pool (+.) 0. + ~start:0 + ~finish:(Array.length bodies -1) + ~body:(fun i -> + let b = bodies.(i) and e = ref 0. in + e := !e +. 0.5 *. b.mass *. (b.vx *. b.vx +. b.vy *. b.vy +. b.vz *. 
b.vz); + for j = i+1 to Array.length bodies - 1 do + let b' = bodies.(j) in + let dx = b.x -. b'.x and dy = b.y -. b'.y and dz = b.z -. b'.z in + let distance = sqrt(dx *. dx +. dy *. dy +. dz *. dz) in + e := !e -. (b.mass *. b'.mass) /. distance; + Domain.Sync.poll () + done; + !e) + +let offset_momentum bodies = + let px = ref 0. and py = ref 0. and pz = ref 0. in + for i = 0 to Array.length bodies - 1 do + let b = bodies.(i) in + px := !px +. b.vx *. b.mass; + py := !py +. b.vy *. b.mass; + pz := !pz +. b.vz *. b.mass; + done; + bodies.(0).vx <- -. !px /. solar_mass; + bodies.(0).vy <- -. !py /. solar_mass; + bodies.(0).vz <- -. !pz /. solar_mass + +let bodies = + Array.init num_bodies (fun _ -> + { x = (Random.float 10.); + y = (Random.float 10.); + z = (Random.float 10.); + vx= (Random.float 5.) *. days_per_year; + vy= (Random.float 4.) *. days_per_year; + vz= (Random.float 5.) *. days_per_year; + mass=(Random.float 10.) *. solar_mass; }) + +let () = + let pool = T.setup_pool ~num_additional_domains:(num_domains - 1) in + offset_momentum bodies; + Printf.printf "%.9f\n" (energy pool bodies); + for _i = 1 to n do advance pool bodies 0.01 done; + Printf.printf "%.9f\n" (energy pool bodies); + T.teardown_pool pool diff --git a/tests/bench/palindrome/Pal.sml b/tests/bench/palindrome/Pal.sml new file mode 100644 index 000000000..318aa20ce --- /dev/null +++ b/tests/bench/palindrome/Pal.sml @@ -0,0 +1,107 @@ +structure Pal: +sig + val longest: char Seq.t -> int * int +end = +struct + + structure AS = ArraySlice + + val p = 1045678717: Int64.int + val base = 500000000: Int64.int + val length = Seq.length + fun sub (str, i) = Seq.nth str i + fun toStr str = CharVector.tabulate (length str, (fn i => sub (str, i))) + + fun check str = + let + val n = length str + fun fromChar c = Int64.fromInt (Char.ord c) + fun modp v = v mod p + fun mul (a, b) = modp (Int64.* (a, b)) + fun add (a, b) = modp (Int64.+ (a, b)) + fun charMul (c, b) = mul (fromChar c, b) + + (* val 
(basePowers, _) = Seq.scan mul 1 (Seq.tabulate (fn _ => base) n) *) + val basePowers = AS.full (SeqBasis.scan 10000 mul 1 (0, n) (fn _ => base)) + + fun genHash getChar = + let + val P = AS.full (SeqBasis.scan 10000 add 0 (0, n) + (fn i => charMul (getChar i, Seq.nth basePowers i))) + (* val (P, total) = Seq.scan add 0 (Seq.zipWith charMul (str, basePowers)) *) + fun H (i, j) = + let + val last = Seq.nth P j + (* val last = if (j = n) then total else Seq.nth P j *) + val first = Seq.nth P i + val offset = Seq.nth basePowers (n-i-1) + open Int64 + in + modp ((last - first) * offset) + end + in + H + end + + val forwadHash = genHash (Seq.nth str) + val backHash = genHash (fn i => Seq.nth str (n-i-1)) + + fun perhaps (i, j) = (forwadHash(i, j) = backHash(n-j, n-i)) + in + perhaps + end + + (* Verifies that str[i:j] is a palindrome. *) + fun verify str (i, j) = + SeqBasis.reduce 10000 (fn (a, b) => a andalso b) true (i, i + (j-i) div 2) + (fn k => Seq.nth str k = Seq.nth str (i+(j-k-1))) + + fun binarySearch check = + let + fun bs (i, j) = + if j - i = 1 then i + else + let + val mid = (i+j) div 2 + in + if check mid then bs (mid, j) + else bs (i, mid) + end + fun double i = if check (2*i) then double (2*i) else bs (i, 2*i) + in + if check 1 then double 1 else 0 + end + + (* Generates a polynomial hash of the string and reversed string. Then for + * each index in the string, binary search on the longest palindrome whose + * middle is that index. 
+ *) + fun longest str = + let + val n = length str + (* val isPalinrome = Util.printTime "build" (fn () => check str) *) + val isPalinrome = check str + + fun maxval ((ia, la), (ib, lb)) = if la > lb then (ia, la) else (ib, lb) + + (* fun getMax f = Seq.reduce maxval (0, 0) (Seq.tabulate f n) *) + + fun getMax f = SeqBasis.reduce 1000 maxval (0, 0) (0, n) f + + fun checkOdd i j = + (i - j >= 0) andalso (i + j < n) andalso isPalinrome (i - j, i + j + 1) + val (io, lo) = getMax (fn i => (i, binarySearch (checkOdd i))) + + fun checkEven i j = + (i - j + 1 >= 0) andalso (i + j < n) + andalso isPalinrome(i - j + 1, i + j + 1) + val (ie, le) = getMax (fn i => (i, binarySearch (checkEven i))) + + val (i, l) = if le > lo then (ie-le+1, 2*le) else (io-lo, 2*lo+1) + val _ = if not(verify str (i, i + l)) then print("Failed!\n") + else () (* print(toStr(Seq.subseq str (i,l)) ^ "\n") *) + in + (i, l) + end + +end diff --git a/tests/bench/palindrome/main.sml b/tests/bench/palindrome/main.sml new file mode 100644 index 000000000..c49a3c8cc --- /dev/null +++ b/tests/bench/palindrome/main.sml @@ -0,0 +1,26 @@ +structure CLA = CommandLineArgs +structure P = Pal + +val n = CLA.parseInt "n" (1000 * 1000) + +(* makes the sequence `ababab...` *) +fun gen i = if i mod 2 = 0 then #"a" else #"b" +val (input, tm) = Util.getTime (fn _ => Seq.tabulate gen n) +val _ = print ("generated input in " ^ Time.fmt 4 tm ^ "s\n") + +val result = + Benchmark.run "finding longest palindrome" (fn _ => Pal.longest input) + +val _ = print ("found longest palindrome in " ^ Time.fmt 4 tm ^ "s\n") + +val correct = + if n mod 2 = 0 + then result = (1, n-1) + else result = (0, n) + +val _ = + if correct then + print ("correct? yes\n") + else + print ("correct? 
no\n") + diff --git a/tests/bench/palindrome/palindrome.mlb b/tests/bench/palindrome/palindrome.mlb new file mode 100644 index 000000000..0c8ac80dd --- /dev/null +++ b/tests/bench/palindrome/palindrome.mlb @@ -0,0 +1,3 @@ +../../mpllib/sources.$(COMPAT).mlb +Pal.sml +main.sml diff --git a/tests/bench/parens/MkParens.sml b/tests/bench/parens/MkParens.sml new file mode 100644 index 000000000..bd868e94a --- /dev/null +++ b/tests/bench/parens/MkParens.sml @@ -0,0 +1,31 @@ +functor MkParens (Seq: SEQUENCE) : +sig + datatype paren = Left | Right + val parenMatch: (paren ArraySequence.t) -> bool +end = +struct + + datatype paren = Left | Right + + fun combine ((l1,r1),(l2,r2)) = + if r1 > l2 then + (l1,r1+r2-l2) + else + (l1+l2-r1, r2) + + val id = (0, 0) + + fun singleton v = + case v of + Left => (0,1) + | Right => (1,0) + + fun parenMatch s = + let + val (l, r) = + Seq.reduce combine id (Seq.map singleton (Seq.fromArraySeq s)) + in + l = 0 andalso r = 0 + end + +end diff --git a/tests/bench/parens/main.sml b/tests/bench/parens/main.sml new file mode 100644 index 000000000..b0170542d --- /dev/null +++ b/tests/bench/parens/main.sml @@ -0,0 +1,33 @@ +structure CLA = CommandLineArgs +structure Seq = ArraySequence + +structure Parens = MkParens(DelayedSeq) + +val n = CLA.parseInt "n" (1000 * 1000 * 100) +val doCheck = CLA.parseFlag "check" + +(* makes the sequence `()()()...` *) +fun gen i = + if i mod 2 = 0 then Parens.Left else Parens.Right + +val input = Seq.tabulate gen n + +fun task () = + Parens.parenMatch input + +fun check result = + if not doCheck then () else + let + val correct = + (n mod 2 = 0 andalso result) + orelse + (n mod 2 = 1 andalso not result) + in + if correct then + print ("correct? yes\n") + else + print ("correct? 
no\n") + end + +val result = Benchmark.run "parens" task +val _ = check result diff --git a/tests/bench/parens/parens.mlb b/tests/bench/parens/parens.mlb new file mode 100644 index 000000000..5f79a1b5e --- /dev/null +++ b/tests/bench/parens/parens.mlb @@ -0,0 +1,3 @@ +../../mpllib/sources.$(COMPAT).mlb +MkParens.sml +main.sml diff --git a/tests/bench/primes-blocked/primes-blocked.mlb b/tests/bench/primes-blocked/primes-blocked.mlb new file mode 100644 index 000000000..51fd277a8 --- /dev/null +++ b/tests/bench/primes-blocked/primes-blocked.mlb @@ -0,0 +1,2 @@ +../../mpllib/sources.$(COMPAT).mlb +primes-blocked.sml diff --git a/tests/bench/primes-blocked/primes-blocked.sml b/tests/bench/primes-blocked/primes-blocked.sml new file mode 100644 index 000000000..ad51aa8d2 --- /dev/null +++ b/tests/bench/primes-blocked/primes-blocked.sml @@ -0,0 +1,61 @@ +structure CLA = CommandLineArgs + +(* primes: int -> int array + * generate all primes up to (and including) n *) +fun blockedPrimes n = + if n < 2 then + ForkJoin.alloc 0 + else + let + val sqrtN = Real.floor (Math.sqrt (Real.fromInt n)) + val sqrtPrimes = blockedPrimes sqrtN + + val flags = ForkJoin.alloc (n + 1) : Word8.word array + fun mark i = Array.update (flags, i, 0w0) + fun unmark i = Array.update (flags, i, 0w1) + fun isMarked i = + Array.sub (flags, i) = 0w0 + val _ = ForkJoin.parfor 10000 (0, n + 1) mark + + val blockSize = Int.max (sqrtN, 1000) + val numBlocks = Util.ceilDiv (n + 1) blockSize + + val _ = ForkJoin.parfor 1 (0, numBlocks) (fn b => + let + val lo = b * blockSize + val hi = Int.min (lo + blockSize, n + 1) + + fun loop i = + if i >= Array.length sqrtPrimes then + () + else if 2 * Array.sub (sqrtPrimes, i) >= hi then + () + else + let + val p = Array.sub (sqrtPrimes, i) + val lom = Int.max (2, Util.ceilDiv lo p) + val him = Util.ceilDiv hi p + in + Util.for (lom, him) (fn m => unmark (m * p)); + loop (i + 1) + end + in + loop 0 + end) + in + SeqBasis.filter 4096 (2, n + 1) (fn i => i) isMarked + 
end + +(* ========================================================================== + * parse command-line arguments and run + *) + +val n = CLA.parseInt "N" (100 * 1000 * 1000) + +val msg = "generating primes up to " ^ Int.toString n +val result = Benchmark.run msg (fn _ => blockedPrimes n) + +val numPrimes = Array.length result +val _ = print ("number of primes " ^ Int.toString numPrimes ^ "\n") +val _ = print ("result " ^ Util.summarizeArray 8 Int.toString result ^ "\n") + diff --git a/tests/bench/primes-segmented/SegmentedPrimes.sml b/tests/bench/primes-segmented/SegmentedPrimes.sml new file mode 100644 index 000000000..a26c3718b --- /dev/null +++ b/tests/bench/primes-segmented/SegmentedPrimes.sml @@ -0,0 +1,104 @@ +functor SegmentedPrimes + (I: + sig + type t + val from_int: int -> t + val to_int: t -> int + end): +sig + val primes: int -> I.t Seq.t + val primes_with_params: {block_size_factor: real, report_times: bool} + -> int + -> I.t Seq.t +end = +struct + + (* The block size used in the algorithm is approximately + * sqrt(n)*block_size_factor + * + * Increasing block_size_factor will use larger blocks, which has all of the + * following effects on performance: + * (1) decreased theoretical work + * (2) less data locality (?) 
+ * (3) less parallelism + *) + fun primes_with_params (params as {block_size_factor, report_times}) n : + I.t Seq.t = + if n < 2 then + Seq.empty () + else + let + val sqrt_n = Real.floor (Math.sqrt (Real.fromInt n)) + + val sqrt_primes = primes_with_params params sqrt_n + + (* Split the range [2,n+1) into blocks *) + val block_size = Real.ceil (Real.fromInt sqrt_n * block_size_factor) + val block_size = Int.max (block_size, 1000) + val num_blocks = Util.ceilDiv ((n + 1) - 2) block_size + + val (block_results, tm) = Util.getTime (fn _ => + SeqBasis.reduce 1 TreeSeq.append (TreeSeq.empty ()) (0, num_blocks) + (fn b => + let + val lo = 2 + b * block_size + val hi = Int.min (lo + block_size, n + 1) + + val flags = Array.array (hi - lo, 0w1 : Word8.word) + fun unmark i = + Array.update (flags, i - lo, 0w0) + + fun loop i = + if i >= Seq.length sqrt_primes then + () + else if 2 * I.to_int (Seq.nth sqrt_primes i) >= hi then + () + else + let + val p = I.to_int (Seq.nth sqrt_primes i) + val lom = Int.max (2, Util.ceilDiv lo p) + val him = Util.ceilDiv hi p + in + Util.for (lom, him) (fn m => unmark (m * p)); + loop (i + 1) + end + + val _ = loop 0 + + val numPrimes = Util.loop (0, hi - lo) 0 (fn (count, i) => + if Array.sub (flags, i) = 0w0 then count else count + 1) + + val output = ForkJoin.alloc numPrimes + + val _ = Util.loop (lo, hi) 0 (fn (outi, i) => + if Array.sub (flags, i - lo) = 0w0 then outi + else (Array.update (output, outi, I.from_int i); outi + 1)) + in + TreeSeq.from_array_seq (ArraySlice.full output) + end)) + + val _ = + if not report_times then + () + else + print + ("sieve (n = " ^ Int.toString n ^ "): " ^ Time.fmt 4 tm ^ "s\n") + + val (result, tm) = Util.getTime (fn _ => + TreeSeq.to_array_seq block_results) + + val _ = + if not report_times then + () + else + print + ("flatten (n = " ^ Int.toString n ^ "): " ^ Time.fmt 4 tm ^ "s\n") + in + result + end + + + fun primes n = + primes_with_params {block_size_factor = 8.0, report_times = false} n + 
+end diff --git a/tests/bench/primes-segmented/TreeSeq.sml b/tests/bench/primes-segmented/TreeSeq.sml new file mode 100644 index 000000000..0c8322569 --- /dev/null +++ b/tests/bench/primes-segmented/TreeSeq.sml @@ -0,0 +1,83 @@ +structure TreeSeq = +struct + datatype 'a t = + Leaf + | Elem of 'a + | Flat of 'a Seq.t + | Node of {num_elems: int, num_blocks: int, left: 'a t, right: 'a t} + + type 'a seq = 'a t + + fun length Leaf = 0 + | length (Elem _) = 1 + | length (Flat s) = Seq.length s + | length (Node {num_elems = n, ...}) = n + + fun num_blocks Leaf = 0 + | num_blocks (Elem _) = 1 + | num_blocks (Flat _) = 1 + | num_blocks (Node {num_blocks = nb, ...}) = nb + + + fun append (t1, t2) = + Node + { num_elems = length t1 + length t2 + , num_blocks = num_blocks t1 + num_blocks t2 + , left = t1 + , right = t2 + } + + + fun to_blocks (t: 'a t) : 'a Seq.t Seq.t = + let + val blocks = ForkJoin.alloc (num_blocks t) + + fun putBlocks offset t = + case t of + Leaf => () + | Elem x => Array.update (blocks, offset, Seq.singleton x) + | Flat s => Array.update (blocks, offset, s) + | Node {num_blocks = nb, left = l, right = r, ...} => + let + fun left () = putBlocks offset l + fun right () = + putBlocks (offset + num_blocks l) r + in + if nb <= 1000 then (left (); right ()) + else (ForkJoin.par (left, right); ()) + end + in + putBlocks 0 t; + ArraySlice.full blocks + end + + + fun to_array_seq t = + let + val a = ForkJoin.alloc (length t) + fun put offset t = + case t of + Leaf => () + | Elem x => Array.update (a, offset, x) + | Flat s => Seq.foreach s (fn (i, x) => Array.update (a, offset + i, x)) + | Node {num_elems = n, left = l, right = r, ...} => + let + fun left () = put offset l + fun right () = + put (offset + length l) r + in + if n <= 4096 then (left (); right ()) + else (ForkJoin.par (left, right); ()) + end + in + put 0 t; + ArraySlice.full a + end + + fun from_array_seq a = Flat a + + fun empty () = Leaf + fun singleton x = Elem x + val $ = singleton + +end diff 
--git a/tests/bench/primes-segmented/main.sml b/tests/bench/primes-segmented/main.sml new file mode 100644 index 000000000..c1b3aad51 --- /dev/null +++ b/tests/bench/primes-segmented/main.sml @@ -0,0 +1,64 @@ +structure CLA = CommandLineArgs + +val n = CLA.parseInt "N" (100 * 1000 * 1000) +val block_size_factor = CLA.parseReal "block-size-factor" 16.0 +val bits = CLA.parseInt "bits" 64 +val report_times = CLA.parseFlag "report-times" + +val _ = print ("N " ^ Int.toString n ^ "\n") +val _ = print ("block-size-factor " ^ Real.toString block_size_factor ^ "\n") +val _ = print ("bits " ^ Int.toString bits ^ "\n") +val _ = print ("report-times? " ^ (if report_times then "yes" else "no") ^ "\n") + +functor Main + (I: + sig + type t + val from_int: int -> t + val to_int: t -> int + val to_string: t -> string + end) = +struct + structure Primes = SegmentedPrimes(I) + + fun main () = + let + val params = + {block_size_factor = block_size_factor, report_times = report_times} + val msg = "generating primes up to " ^ Int.toString n + + val result = Benchmark.run msg (fn _ => + Primes.primes_with_params params n) + + val numPrimes = Seq.length result + val _ = print ("number of primes " ^ Int.toString numPrimes ^ "\n") + val _ = print + ("result " ^ Util.summarizeArraySlice 8 I.to_string result ^ "\n") + in + () + end +end + +structure Main32 = + Main + (struct + type t = Int32.int + val from_int = Int32.fromInt + val to_int = Int32.toInt + val to_string = Int32.toString + end) + +structure Main64 = + Main + (struct + type t = Int64.int + val from_int = Int64.fromInt + val to_int = Int64.toInt + val to_string = Int64.toString + end) + +val _ = + case bits of + 64 => Main64.main () + | 32 => Main32.main () + | _ => Util.die ("unknown -bits " ^ Int.toString bits ^ ": must be 32 or 64") diff --git a/tests/bench/primes-segmented/primes-segmented.mlb b/tests/bench/primes-segmented/primes-segmented.mlb new file mode 100644 index 000000000..e6746318a --- /dev/null +++ 
b/tests/bench/primes-segmented/primes-segmented.mlb @@ -0,0 +1,4 @@ +../../mpllib/sources.$(COMPAT).mlb +TreeSeq.sml +SegmentedPrimes.sml +main.sml diff --git a/tests/bench/primes/check-stack-dir/check-stack b/tests/bench/primes/check-stack-dir/check-stack new file mode 100755 index 000000000..ae2f945fb Binary files /dev/null and b/tests/bench/primes/check-stack-dir/check-stack differ diff --git a/tests/bench/primes/check-stack-dir/primes.mlb b/tests/bench/primes/check-stack-dir/primes.mlb new file mode 100644 index 000000000..b0c9af4bd --- /dev/null +++ b/tests/bench/primes/check-stack-dir/primes.mlb @@ -0,0 +1,5 @@ +$(SML_LIB)/basis/basis.mlb +$(SML_LIB)/basis/fork-join.mlb + +../../mpllib/sources.$(COMPAT).mlb +primes.sml diff --git a/tests/bench/primes/check-stack-dir/primes.sml b/tests/bench/primes/check-stack-dir/primes.sml new file mode 100644 index 000000000..2165ba8ac --- /dev/null +++ b/tests/bench/primes/check-stack-dir/primes.sml @@ -0,0 +1,37 @@ +datatype linkedList = Node of (linkedList ref * int) + | End + +(*val x = node(ref (nil2) * ref(nil2) * 4)*) +exception IndexError + +val n = 10000 + +fun elem i = i + +val y = SeqBasis.foldl (fn (lis, x) => Node(ref(lis), x)) End (0, n) elem + +(*fun printList lis = + case lis of + End => print "end\n" + | Node(a, b) => (print (Int.toString b ^ " "); printList (!a)) + +fun delete (idx:int) (lis:linkedList) = + case lis of + End => raise IndexError + | Node(a, b) => + if idx = 0 then + (!a, Node(a, b)) + else if idx = 1 then + let + val delElement = !a + val delNext = case delElement of + End => raise IndexError + | Node(c, d) => !c + val _ = (a := delNext) + in + (Node(a, b), delElement) + end + else + delete (idx - 1) (!a)*) + +(*val x = delete (n-1) y*) diff --git a/tests/bench/primes/primes.mlb b/tests/bench/primes/primes.mlb new file mode 100644 index 000000000..180294849 --- /dev/null +++ b/tests/bench/primes/primes.mlb @@ -0,0 +1,2 @@ +../../mpllib/sources.$(COMPAT).mlb +primes.sml diff --git 
a/tests/bench/primes/primes.sml b/tests/bench/primes/primes.sml new file mode 100644 index 000000000..e2dd7cc5f --- /dev/null +++ b/tests/bench/primes/primes.sml @@ -0,0 +1,46 @@ +structure CLA = CommandLineArgs + +(* primes: int -> int array + * generate all primes up to (and including) n *) +fun primes n = + if n < 2 then ForkJoin.alloc 0 else + let + (* all primes up to sqrt(n) *) + val sqrtPrimes = primes (Real.floor (Math.sqrt (Real.fromInt n))) + + (* allocate array of flags to mark primes. *) + val flags = ForkJoin.alloc (n+1) : Word8.word array + fun mark i = Array.update (flags, i, 0w0) + fun unmark i = Array.update (flags, i, 0w1) + fun isMarked i = Array.sub (flags, i) = 0w0 + + (* initially, mark every number *) + val _ = ForkJoin.parfor 10000 (0, n+1) mark + + (* unmark every multiple of every prime in sqrtPrimes *) + val _ = + ForkJoin.parfor 1 (0, Array.length sqrtPrimes) (fn i => + let + val p = Array.sub (sqrtPrimes, i) + val numMultiples = n div p - 1 + in + ForkJoin.parfor 4096 (0, numMultiples) (fn j => unmark ((j+2) * p)) + end) + in + (* for every i in 2 <= i <= n, filter those that are still marked *) + SeqBasis.filter 4096 (2, n+1) (fn i => i) isMarked + end + +(* ========================================================================== + * parse command-line arguments and run + *) + +val n = CLA.parseInt "n" (100 * 1000 * 1000) + +val msg = "generating primes up to " ^ Int.toString n +val result = Benchmark.run msg (fn _ => primes n) + +val numPrimes = Array.length result +val _ = print ("number of primes " ^ Int.toString numPrimes ^ "\n") +val _ = print ("result " ^ Util.summarizeArray 8 Int.toString result ^ "\n") + diff --git a/tests/bench/primes/safe/primes.mlb b/tests/bench/primes/safe/primes.mlb new file mode 100644 index 000000000..180294849 --- /dev/null +++ b/tests/bench/primes/safe/primes.mlb @@ -0,0 +1,2 @@ +../../mpllib/sources.$(COMPAT).mlb +primes.sml diff --git a/tests/bench/primes/safe/primes.sml 
b/tests/bench/primes/safe/primes.sml new file mode 100644 index 000000000..4db9b2354 --- /dev/null +++ b/tests/bench/primes/safe/primes.sml @@ -0,0 +1,47 @@ +(* primes: int -> int array + * generate all primes up to (and including) n *) +fun primes n = + if n < 2 then ForkJoin.alloc 0 else + let + (* all primes up to sqrt(n) *) + val sqrtPrimes = primes (Real.floor (Math.sqrt (Real.fromInt n))) + + (* allocate array of flags to mark primes. *) + val flags = ForkJoin.alloc (n+1) : Word8.word array + fun mark i = Array.update (flags, i, 0w0) + fun unmark i = Array.update (flags, i, 0w1) + fun isMarked i = Array.sub (flags, i) = 0w0 + + (* initially, mark every number *) + val _ = ForkJoin.parfor 10000 (0, n+1) mark + + (* unmark every multiple of every prime in sqrtPrimes *) + val _ = + ForkJoin.parfor 1 (0, Array.length sqrtPrimes) (fn i => + let + val p = Array.sub (sqrtPrimes, i) + val numMultiples = n div p - 1 + in + ForkJoin.parfor 4096 (0, numMultiples) (fn j => unmark ((j+2) * p)) + end) + in + (* for every i in 2 <= i <= n, filter those that are still marked *) + SeqBasis.filter 4096 (2, n+1) (fn i => i) isMarked + end + +(* ========================================================================== + * parse command-line arguments and run + *) + +val n = CommandLineArgs.parseInt "N" (100 * 1000 * 1000) +val _ = print ("generating primes up to " ^ Int.toString n ^ "\n") + +val t0 = Time.now () +val result = primes n +val t1 = Time.now () + +val _ = print ("finished in " ^ Time.fmt 4 (Time.- (t1, t0)) ^ "s\n") + +val numPrimes = Array.length result +val _ = print ("number of primes " ^ Int.toString numPrimes ^ "\n") +val _ = print ("result " ^ Util.summarizeArray 8 Int.toString result ^ "\n") diff --git a/tests/bench/pure-msort-int32/msort.sml b/tests/bench/pure-msort-int32/msort.sml new file mode 100644 index 000000000..064b82c95 --- /dev/null +++ b/tests/bench/pure-msort-int32/msort.sml @@ -0,0 +1,37 @@ +structure CLA = CommandLineArgs + +val n = 
CLA.parseInt "N" (100 * 1000 * 1000) +val quicksortGrain = CLA.parseInt "quicksort" 1024 +val grain = CLA.parseInt "grain" 1024 +val _ = print ("N " ^ Int.toString n ^ "\n") + +val _ = print ("generating " ^ Int.toString n ^ " random integers\n") + +val max32 = Word64.fromLargeInt (Int32.toLarge (valOf Int32.maxInt)) + +fun elem i = + Int32.fromInt (Word64.toInt (Word64.mod (Util.hash64 (Word64.fromInt i), Word64.fromInt n))) +val input = PureSeq.tabulate elem n + +fun sort cmp xs = + if PureSeq.length xs <= quicksortGrain then + PureSeq.quicksort cmp xs + else + let + val n = PureSeq.length xs + val l = PureSeq.take xs (n div 2) + val r = PureSeq.drop xs (n div 2) + val (l', r') = + if n <= grain then + (sort cmp l, sort cmp r) + else + ForkJoin.par (fn _ => sort cmp l, fn _ => sort cmp r) + in + PureSeq.merge cmp (l', r') + end + +val result = + Benchmark.run "running mergesort" (fn _ => sort Int32.compare input) + +(* val _ = print ("result " ^ Util.summarizeArraySlice 8 Int.toString result ^ "\n") *) + diff --git a/tests/bench/pure-msort-int32/pure-msort-int32.mlb b/tests/bench/pure-msort-int32/pure-msort-int32.mlb new file mode 100644 index 000000000..e05edb8f7 --- /dev/null +++ b/tests/bench/pure-msort-int32/pure-msort-int32.mlb @@ -0,0 +1,2 @@ +../../mpllib/sources.$(COMPAT).mlb +msort.sml diff --git a/tests/bench/pure-msort-strings/msort.sml b/tests/bench/pure-msort-strings/msort.sml new file mode 100644 index 000000000..01d170e31 --- /dev/null +++ b/tests/bench/pure-msort-strings/msort.sml @@ -0,0 +1,55 @@ +structure CLA = CommandLineArgs + +fun usage () = + let + val msg = + "usage: msort-strings FILE [-grain ...] 
[--long] \n" + in + TextIO.output (TextIO.stdErr, msg); + OS.Process.exit OS.Process.failure + end + +val filename = + case CLA.positional () of + [x] => x + | _ => usage () + +val makeLong = CLA.parseFlag "long" +val quicksortGrain = CLA.parseInt "quicksort" 1024 +val grain = CLA.parseInt "grain" 1024 + +val (contents, tm) = Util.getTime (fn _ => ReadFile.contentsSeq filename) +val _ = print ("read file in " ^ Time.fmt 4 tm ^ "s\n") +val (tokens, tm) = Util.getTime (fn _ => Tokenize.tokens Char.isSpace contents) +val _ = print ("tokenized in " ^ Time.fmt 4 tm ^ "s\n") + +val prefix = CharVector.tabulate (32, fn _ => #"a") + +val tokens = + if not makeLong then tokens + else Seq.map (fn str => prefix ^ str) tokens + +val tokens = PureSeq.fromSeq tokens + +fun sort cmp xs = + if PureSeq.length xs <= quicksortGrain then + PureSeq.quicksort cmp xs + else + let + val n = PureSeq.length xs + val l = PureSeq.take xs (n div 2) + val r = PureSeq.drop xs (n div 2) + val (l', r') = + if n <= grain then + (sort cmp l, sort cmp r) + else + ForkJoin.par (fn _ => sort cmp l, fn _ => sort cmp r) + in + PureSeq.merge cmp (l', r') + end + +val result = + Benchmark.run "running mergesort" (fn _ => sort String.compare tokens) + +(* val _ = print ("result " ^ Util.summarizeArraySlice 8 (fn x => x) result ^ "\n") *) + diff --git a/tests/bench/pure-msort-strings/pure-msort-strings.mlb b/tests/bench/pure-msort-strings/pure-msort-strings.mlb new file mode 100644 index 000000000..e05edb8f7 --- /dev/null +++ b/tests/bench/pure-msort-strings/pure-msort-strings.mlb @@ -0,0 +1,2 @@ +../../mpllib/sources.$(COMPAT).mlb +msort.sml diff --git a/tests/bench/pure-msort/msort.sml b/tests/bench/pure-msort/msort.sml new file mode 100644 index 000000000..c61b9baed --- /dev/null +++ b/tests/bench/pure-msort/msort.sml @@ -0,0 +1,35 @@ +structure CLA = CommandLineArgs + +val n = CLA.parseInt "N" (100 * 1000 * 1000) +val quicksortGrain = CLA.parseInt "quicksort" 1024 +val grain = CLA.parseInt "grain" 1024 
+val _ = print ("N " ^ Int.toString n ^ "\n") + +val _ = print ("generating " ^ Int.toString n ^ " random integers\n") + +fun elem i = + Word64.toInt (Word64.mod (Util.hash64 (Word64.fromInt i), Word64.fromInt n)) +val input = PureSeq.tabulate elem n + +fun sort cmp xs = + if PureSeq.length xs <= quicksortGrain then + PureSeq.quicksort cmp xs + else + let + val n = PureSeq.length xs + val l = PureSeq.take xs (n div 2) + val r = PureSeq.drop xs (n div 2) + val (l', r') = + if n <= grain then + (sort cmp l, sort cmp r) + else + ForkJoin.par (fn _ => sort cmp l, fn _ => sort cmp r) + in + PureSeq.merge cmp (l', r') + end + +val result = + Benchmark.run "running mergesort" (fn _ => sort Int.compare input) + +val _ = print ("result " ^ PureSeq.summarize 10 Int.toString result ^ "\n") + diff --git a/tests/bench/pure-msort/pure-msort.mlb b/tests/bench/pure-msort/pure-msort.mlb new file mode 100644 index 000000000..e05edb8f7 --- /dev/null +++ b/tests/bench/pure-msort/pure-msort.mlb @@ -0,0 +1,2 @@ +../../mpllib/sources.$(COMPAT).mlb +msort.sml diff --git a/tests/bench/pure-nn/nn.sml b/tests/bench/pure-nn/nn.sml new file mode 100644 index 000000000..4053a4d5d --- /dev/null +++ b/tests/bench/pure-nn/nn.sml @@ -0,0 +1,420 @@ +structure NN : +sig + type t + type 'a seq = 'a PureSeq.t + + type point = Geometry2D.point + + (* makeTree leafSize points *) + val makeTree : int -> point seq -> t + + (* allNearestNeighbors grain quadtree *) + val allNearestNeighbors : int -> t -> int seq +end = +struct + + structure A = Array + structure AS = ArraySlice + structure V = Vector + structure VS = VectorSlice + val unsafeCast: 'a array -> 'a vector = VectorExtra.unsafeFromArray + + type 'a seq = 'a VS.slice + structure G = Geometry2D + type point = G.point + + fun par4 (a, b, c, d) = + let + val ((ar, br), (cr, dr)) = + ForkJoin.par (fn _ => ForkJoin.par (a, b), + fn _ => ForkJoin.par (c, d)) + in + (ar, br, cr, dr) + end + + datatype tree = + Leaf of { anchor : point + , width : real + , 
vertices : int seq (* indices of original point seq *) + } + | Node of { anchor : point + , width : real + , count : int + , children : tree seq + } + + type t = tree * point seq + + fun count t = + case t of + Leaf {vertices, ...} => PureSeq.length vertices + | Node {count, ...} => count + + fun anchor t = + case t of + Leaf {anchor, ...} => anchor + | Node {anchor, ...} => anchor + + fun width t = + case t of + Leaf {width, ...} => width + | Node {width, ...} => width + + fun boxOf t = + case t of + Leaf {anchor=(x,y), width, ...} => (x, y, x+width, y+width) + | Node {anchor=(x,y), width, ...} => (x, y, x+width, y+width) + + fun indexApp grain f t = + let + fun downSweep offset t = + case t of + Leaf {vertices, ...} => + VS.appi (fn (i, v) => f (offset + i, v)) vertices + | Node {children, ...} => + let + fun q i = VS.sub (children, i) + fun qCount i = count (q i) + val offset0 = offset + val offset1 = offset0 + qCount 0 + val offset2 = offset1 + qCount 1 + val offset3 = offset2 + qCount 2 + in + if count t <= grain then + ( downSweep offset0 (q 0) + ; downSweep offset1 (q 1) + ; downSweep offset2 (q 2) + ; downSweep offset3 (q 3) + ) + else + ( par4 + ( fn _ => downSweep offset0 (q 0) + , fn _ => downSweep offset1 (q 1) + , fn _ => downSweep offset2 (q 2) + , fn _ => downSweep offset3 (q 3) + ) + ; () + ) + end + in + downSweep 0 t + end + + fun indexMap grain f t = + let + val result = ForkJoin.alloc (count t) + val _ = indexApp grain (fn (i, v) => A.update (result, i, f (i, v))) t + in + VS.full (unsafeCast result) + end + + fun flatten grain t = indexMap grain (fn (_, v) => v) t + + (* val lowerTime = ref Time.zeroTime + val upperTime = ref Time.zeroTime + fun addTm r t = + if Primitives.numberOfProcessors = 1 then + r := Time.+ (!r, t) + else () + fun clearAndReport r name = + (print (name ^ " " ^ Time.fmt 4 (!r) ^ "\n"); r := Time.zeroTime) *) + + (* Make a tree where all points are in the specified bounding box. 
*) + fun makeTreeBounded leafSize (verts : point seq) (idx : int Seq.t) ((xLeft, yBot) : G.point) width = + if Seq.length idx <= leafSize then + Leaf { anchor = (xLeft, yBot) + , width = width + , vertices = PureSeq.fromSeq idx + } + else let + val qw = width/2.0 (* quadrant width *) + val center = (xLeft + qw, yBot + qw) + + val ((sorted, offsets), tm) = Util.getTime (fn () => + CountingSort.sort idx (fn i => + G.quadrant center (PureSeq.nth verts (Seq.nth idx i))) 4) + + (* val _ = + if AS.length idx >= 4 * leafSize then + addTm upperTime tm + else + addTm lowerTime tm *) + + fun quadrant i = + let + val start = AS.sub (offsets, i) + val len = AS.sub (offsets, i+1) - start + val childIdx = AS.subslice (sorted, start, SOME len) + val qAnchor = + case i of + 0 => (xLeft + qw, yBot + qw) + | 1 => (xLeft, yBot + qw) + | 2 => (xLeft, yBot) + | _ => (xLeft + qw, yBot) + in + makeTreeBounded leafSize verts childIdx qAnchor qw + end + + (* val children = Seq.tabulate (Perf.grain 1) quadrant 4 *) + val (a, b, c, d) = + if AS.length idx <= 100 then + (quadrant 0, quadrant 1, quadrant 2, quadrant 3) + else + par4 + ( fn _ => quadrant 0 + , fn _ => quadrant 1 + , fn _ => quadrant 2 + , fn _ => quadrant 3 ) + val children = PureSeq.fromList [a,b,c,d] + in + Node { anchor = (xLeft, yBot) + , width = width + , count = AS.length idx + , children = children + } + end + + fun loop (lo, hi) b f = + if (lo >= hi) then b else loop (lo+1, hi) (f (b, lo)) f + + fun reduce grain f b (get, lo, hi) = + if hi - lo <= grain then + loop (lo, hi) b (fn (b, i) => f (b, get i)) + else let + val mid = lo + (hi-lo) div 2 + val (l,r) = ForkJoin.par + ( fn _ => reduce grain f b (get, lo, mid) + , fn _ => reduce grain f b (get, mid, hi) + ) + in + f (l, r) + end + + fun makeTree leafSize (verts : point seq) = + if PureSeq.length verts = 0 then raise Fail "makeTree with 0 points" else + let + (* calculate the bounding box *) + fun maxPt ((x1,y1),(x2,y2)) = (Real.max (x1, x2), Real.max (y1, y2)) + fun 
minPt ((x1,y1),(x2,y2)) = (Real.min (x1, x2), Real.min (y1, y2)) + fun getPt i = PureSeq.nth verts i + val (xLeft,yBot) = reduce 10000 minPt (Real.posInf, Real.posInf) (getPt, 0, VS.length verts) + val (xRight,yTop) = reduce 10000 maxPt (Real.negInf, Real.negInf) (getPt, 0, VS.length verts) + val width = Real.max (xRight-xLeft, yTop-yBot) + + val idx = Seq.tabulate (fn i => i) (PureSeq.length verts) + val result = makeTreeBounded leafSize verts idx (xLeft, yBot) width + in + (* clearAndReport upperTime "upper sort time"; *) + (* clearAndReport lowerTime "lower sort time"; *) + (result, verts) + end + + (* ======================================================================== *) + + fun constrain (x : real) (lo, hi) = + if x < lo then lo + else if x > hi then hi + else x + + fun distanceToBox (x,y) (xLeft, yBot, xRight, yTop) = + G.distance (x,y) (constrain x (xLeft, xRight), constrain y (yBot, yTop)) + + val dummyBest = (~1, Real.posInf) + + fun nearestNeighbor (t : tree, pts) (pi : int) = + let + fun pt i = PureSeq.nth pts i + + val p = pt pi + + fun refineNearest (qi, (bestPt, bestDist)) = + if pi = qi then (bestPt, bestDist) else + let + val qDist = G.distance p (pt qi) + in + if qDist < bestDist + then (qi, qDist) + else (bestPt, bestDist) + end + + fun search (best as (_, bestDist : real)) t = + if distanceToBox p (boxOf t) > bestDist then best else + case t of + Leaf {vertices, ...} => + VS.foldl refineNearest best vertices + | Node {anchor=(x,y), width, children, ...} => + let + val qw = width/2.0 + val center = (x+qw, y+qw) + + (* search the quadrant that p is in first *) + val heuristicOrder = + case G.quadrant center p of + 0 => [0,1,2,3] + | 1 => [1,0,2,3] + | 2 => [2,1,3,0] + | _ => [3,0,2,1] + + fun child i = VS.sub (children, i) + fun refine (i, best) = search best (child i) + in + List.foldl refine best heuristicOrder + end + + val (best, _) = search dummyBest t + in + best + end + + fun allNearestNeighbors grain (t, pts) = + let + val n = 
PureSeq.length pts + val idxs = flatten 10000 t + val nn = ForkJoin.alloc n + in + ForkJoin.parfor grain (0, n) (fn i => + let + val j = PureSeq.nth idxs i + in + A.update (nn, j, nearestNeighbor (t, pts) j) + end); + VS.full (unsafeCast nn) + end + +end + +(* ========================================================================== + * Now the main bit + *) + +structure CLA = CommandLineArgs + +val n = CLA.parseInt "N" 1000000 +val leafSize = CLA.parseInt "leafSize" 16 +val grain = CLA.parseInt "grain" 100 +val seed = CLA.parseInt "seed" 15210 + +fun genReal i = + let + val x = Word64.fromInt (seed + i) + in + Real.fromInt (Word64.toInt (Word64.mod (Util.hash64 x, 0w1000000))) + / 1000000.0 + end + +fun genPoint i = (genReal (2*i), genReal (2*i + 1)) +val (input, tm) = Util.getTime (fn _ => PureSeq.tabulate genPoint n) +val _ = print ("generated input in " ^ Time.fmt 4 tm ^ "s\n") + +fun nnEx() = + let + val (tree, tm) = Util.getTime (fn _ => NN.makeTree leafSize input) + val _ = print ("built quadtree in " ^ Time.fmt 4 tm ^ "s\n") + + val (nbrs, tm) = Util.getTime (fn _ => NN.allNearestNeighbors grain tree) + val _ = print ("found all neighbors in " ^ Time.fmt 4 tm ^ "s\n") + in + (tree, nbrs) + end + +val (tree, nbrs) = Benchmark.run "running nearest neighbors" nnEx + +(* now input[nbrs[i]] is the closest point to input[i] *) + +(* ========================================================================== + * write image to output + * this only works if all input points are within [0,1) *) + +val filename = CLA.parseString "output" "" +val _ = + if filename <> "" then () + else ( print ("to see output, use -output and -resolution arguments\n" ^ + "for example: nn -N 10000 -output result.ppm -resolution 1000\n") + ; GCStats.report () + ; OS.Process.exit OS.Process.success + ) + +val t0 = Time.now () + +val resolution = CLA.parseInt "resolution" 1000 +val width = resolution +val height = resolution + +val image = + { width = width + , height = height + , data = 
Seq.tabulate (fn _ => Color.white) (width*height) + } + +fun set (i, j) x = + if 0 <= i andalso i < height andalso + 0 <= j andalso j < width + then ArraySlice.update (#data image, i*width + j, x) + else () + +val r = Real.fromInt resolution +fun px x = Real.floor (x * r) +fun pos (x, y) = (resolution - px x - 1, px y) + +fun horizontalLine i (j0, j1) = + if j1 < j0 then horizontalLine i (j1, j0) + else Util.for (j0, j1) (fn j => set (i, j) Color.red) + +fun sign xx = + case Int.compare (xx, 0) of LESS => ~1 | EQUAL => 0 | GREATER => 1 + +(* Bresenham's line algorithm *) +fun line (x1, y1) (x2, y2) = + let + val w = x2 - x1 + val h = y2 - y1 + val dx1 = sign w + val dy1 = sign h + val (longest, shortest, dx2, dy2) = + if Int.abs w > Int.abs h then + (Int.abs w, Int.abs h, dx1, 0) + else + (Int.abs h, Int.abs w, 0, dy1) + + fun loop i numerator x y = + if i > longest then () else + let + val numerator = numerator + shortest; + in + set (x, y) Color.red; + if numerator >= longest then + loop (i+1) (numerator-longest) (x+dx1) (y+dy1) + else + loop (i+1) numerator (x+dx2) (y+dy2) + end + in + loop 0 (longest div 2) x1 y1 + end + +(* mark all nearest neighbors with straight red lines *) +val t0 = Time.now () + +val _ = ForkJoin.parfor 10000 (0, PureSeq.length input) (fn i => + line (pos (PureSeq.nth input i)) (pos (PureSeq.nth input (PureSeq.nth nbrs i)))) + +(* mark input points as a pixel *) +val _ = + ForkJoin.parfor 10000 (0, PureSeq.length input) (fn i => + let + val (x, y) = pos (PureSeq.nth input i) + fun b spot = set spot Color.black + in + b (x-1, y); + b (x, y-1); + b (x, y); + b (x, y+1); + b (x+1, y) + end) + +val t1 = Time.now () + +val _ = print ("generated image in " ^ Time.fmt 4 (Time.- (t1, t0)) ^ "s\n") + +val (_, tm) = Util.getTime (fn _ => PPM.write filename image) +val _ = print ("wrote to " ^ filename ^ " in " ^ Time.fmt 4 tm ^ "s\n") + diff --git a/tests/bench/pure-nn/pure-nn.mlb b/tests/bench/pure-nn/pure-nn.mlb new file mode 100644 index 
000000000..bfa54aec7 --- /dev/null +++ b/tests/bench/pure-nn/pure-nn.mlb @@ -0,0 +1,2 @@ +../../mpllib/sources.$(COMPAT).mlb +nn.sml diff --git a/tests/bench/pure-quickhull/Quickhull.sml b/tests/bench/pure-quickhull/Quickhull.sml new file mode 100644 index 000000000..a5e97b415 --- /dev/null +++ b/tests/bench/pure-quickhull/Quickhull.sml @@ -0,0 +1,120 @@ +structure Quickhull : +sig + val hull : (real * real) PureSeq.t -> int PureSeq.t +end = +struct + + structure AS = ArraySlice + structure G = Geometry2D + structure Tree = TreeSeq + + fun hull pts = + let + fun pt i = PureSeq.nth pts i + fun dist p q i = G.Point.triArea (p, q, pt i) + fun max ((i, di), (j, dj)) = + if di > dj then (i, di) else (j, dj) + fun x i = #1 (pt i) + + fun aboveLine p q i = (dist p q i > 0.0) + + fun parHull idxs l r = + if PureSeq.length idxs < 2 then + Tree.fromPureSeq idxs + (* if DS.length idxs <= 2048 then + seqHull idxs l r *) + else + let + val lp = pt l + val rp = pt r + fun d i = dist lp rp i + + (* val idxs = DS.fromArraySeq idxs *) + + val (mid, _) = SeqBasis.reduce 10000 max (~1, Real.negInf) + (0, PureSeq.length idxs) + (fn i => (PureSeq.nth idxs i, d (PureSeq.nth idxs i))) + (* val distances = DS.map (fn i => (i, d i)) idxs + val (mid, _) = DS.reduce max (~1, Real.negInf) distances *) + + val midp = pt mid + + fun flag i = + if aboveLine lp midp i then Split.Left + else if aboveLine midp rp i then Split.Right + else Split.Throwaway + val (left, right) = + Split.parSplit idxs (PureSeq.map flag idxs) + (* (DS.force (DS.map flag idxs)) *) + + fun doLeft () = parHull left l mid + fun doRight () = parHull right mid r + val (leftHull, rightHull) = + if PureSeq.length left + PureSeq.length right <= 2048 + then (doLeft (), doRight ()) + else ForkJoin.par (doLeft, doRight) + in + Tree.append (leftHull, + (Tree.append (Tree.$ mid, rightHull))) + end + + (* val tm = Util.startTiming () *) + + (* val allIdx = DS.tabulate (fn i => i) (Seq.length pts) *) + + (* This is faster than doing 
two reduces *) + (* val (l, r) = DS.reduce + (fn ((l1, r1), (l2, r2)) => + (if x l1 < x l2 then l1 else l2, + if x r1 > x r2 then r1 else r2)) + (0, 0) + (DS.map (fn i => (i, i)) allIdx) *) + + val (l, r) = SeqBasis.reduce 10000 + (fn ((l1, r1), (l2, r2)) => + (if x l1 < x l2 then l1 else l2, + if x r1 > x r2 then r1 else r2)) + (0, 0) + (0, PureSeq.length pts) + (fn i => (i, i)) + + (* val tm = Util.tick tm "endpoints" *) + + val lp = pt l + val rp = pt r + + fun flag i = + let + val d = dist lp rp i + in + if d > 0.0 then Split.Left + else if d < 0.0 then Split.Right + else Split.Throwaway + end + val (above, below) = + (* Split.parSplit allIdx (DS.force (DS.map flag allIdx)) *) + Split.parSplit + (PureSeq.tabulate (fn i => i) (PureSeq.length pts)) + (PureSeq.tabulate flag (PureSeq.length pts)) + + (* val tm = Util.tick tm "above/below filter" *) + + val (above, below) = ForkJoin.par + (fn _ => parHull above l r, + fn _ => parHull below r l) + + (* val tm = Util.tick tm "quickhull" *) + + val hullt = + Tree.append + (Tree.append (Tree.$ l, above), + Tree.append (Tree.$ r, below)) + + val result = Tree.toPureSeq hullt + + (* val tm = Util.tick tm "flatten" *) + in + result + end + +end diff --git a/tests/bench/pure-quickhull/Split.sml b/tests/bench/pure-quickhull/Split.sml new file mode 100644 index 000000000..1de0a3646 --- /dev/null +++ b/tests/bench/pure-quickhull/Split.sml @@ -0,0 +1,126 @@ +structure Split : +sig + type 'a seq + + (* val inPlace : 'a seq -> ('a -> bool) -> ('a -> bool) -> (int * int) *) + + datatype flag = Left | Right | Throwaway + val parSplit : 'a seq -> flag seq -> 'a seq * 'a seq +end = +struct + + structure A = Array + structure AS = ArraySlice + structure V = Vector + structure VS = VectorSlice + + val unsafeCast: 'a array -> 'a vector = VectorExtra.unsafeFromArray + + type 'a seq = 'a PureSeq.t +(* + fun inPlace s putLeft putRight = + let + val (a, start, n) = AS.base s + fun item i = A.sub (a, i) + fun set i x = A.update (a, i, x) + + 
fun growLeft ll lm rm rr = + if lm >= rm then (ll, rr) else + let + val x = item lm + in + if putRight x then + growRight ll lm rm rr + else if not (putLeft x) then + growLeft ll (lm+1) rm rr + else + (set ll x; growLeft (ll+1) (lm+1) rm rr) + end + + and growRight ll lm rm rr = + if lm >= rm then (ll, rr) else + let + val x = item (rm-1) + in + if putLeft x then + swapThenContinue ll lm rm rr + else if not (putRight x) then + growRight ll lm (rm-1) rr + else + (set (rr-1) x; growRight ll lm (rm-1) (rr-1)) + end + + and swapThenContinue ll lm rm rr = + let + val tmp = item lm + in + set ll (item (rm-1)); + set (rr-1) tmp; + growLeft (ll+1) (lm+1) (rm-1) (rr-1) + end + + val (ll, rr) = growLeft start start (start+n) (start+n) + in + (ll-start, (start+n)-rr) + end +*) + datatype flag = Left | Right | Throwaway + + fun parSplit s flags = + let + val n = PureSeq.length s + val blockSize = 10000 + val numBlocks = 1 + (n-1) div blockSize + + (* the later scan(s) appears to be faster when split into two separate + * scans, rather than doing a single scan on tuples. 
*) + + (* val counts = Primitives.alloc numBlocks *) + val countl = ForkJoin.alloc numBlocks + val countr = ForkJoin.alloc numBlocks + + val _ = ForkJoin.parfor 1 (0, numBlocks) (fn b => + let + val lo = b * blockSize + val hi = Int.min (lo + blockSize, n) + fun loop (cl, cr) i = + if i >= hi then + (* A.update (counts, b, (cl, cr)) *) + (A.update (countl, b, cl); A.update (countr, b, cr)) + else case PureSeq.nth flags i of + Left => loop (cl+1, cr) (i+1) + | Right => loop (cl, cr+1) (i+1) + | _ => loop (cl, cr) (i+1) + in + loop (0, 0) lo + end) + + (* val (offsets, (totl, totr)) = + Seq.scan (fn ((a,b),(c,d)) => (a+c,b+d)) (0,0) (ArraySlice.full counts) *) + val (offsetsl, totl) = PureSeq.scan op+ 0 (VS.full (unsafeCast countl)) + val (offsetsr, totr) = PureSeq.scan op+ 0 (VS.full (unsafeCast countr)) + + val left = ForkJoin.alloc totl + val right = ForkJoin.alloc totr + + val _ = ForkJoin.parfor 1 (0, numBlocks) (fn b => + let + val lo = b * blockSize + val hi = Int.min (lo + blockSize, n) + (* val (offsetl, offsetr) = Seq.nth offsets b *) + val offsetl = PureSeq.nth offsetsl b + val offsetr = PureSeq.nth offsetsr b + fun loop (cl, cr) i = + if i >= hi then () else + case PureSeq.nth flags i of + Left => (A.update (left, offsetl+cl, PureSeq.nth s i); loop (cl+1, cr) (i+1)) + | Right => (A.update (right, offsetr+cr, PureSeq.nth s i); loop (cl, cr+1) (i+1)) + | _ => loop (cl, cr) (i+1) + in + loop (0, 0) lo + end) + in + (VS.full (unsafeCast left), VS.full (unsafeCast right)) + end + +end diff --git a/tests/bench/pure-quickhull/TreeSeq.sml b/tests/bench/pure-quickhull/TreeSeq.sml new file mode 100644 index 000000000..6c0346f71 --- /dev/null +++ b/tests/bench/pure-quickhull/TreeSeq.sml @@ -0,0 +1,54 @@ +structure TreeSeq = +struct + datatype 'a t = + Leaf + | Elem of 'a + | Flat of 'a PureSeq.t + | Node of int * 'a t * 'a t + + type 'a seq = 'a t + type 'a ord = 'a * 'a -> order + datatype 'a listview = NIL | CONS of 'a * 'a seq + datatype 'a treeview = EMPTY | ONE 
of 'a | PAIR of 'a seq * 'a seq + + exception Range + exception Size + exception NYI + + fun length Leaf = 0 + | length (Elem _) = 1 + | length (Flat s) = PureSeq.length s + | length (Node (n, _, _)) = n + + fun append (t1, t2) = Node (length t1 + length t2, t1, t2) + + fun toPureSeq t = + let + val a = ForkJoin.alloc (length t) + fun put offset t = + case t of + Leaf => () + | Elem x => Array.update (a, offset, x) + | Flat s => PureSeq.foreach s (fn (i, x) => Array.update (a, offset+i, x)) + | Node (n, l, r) => + let + fun left () = put offset l + fun right () = put (offset + length l) r + in + if n <= 4096 then + (left (); right ()) + else + (ForkJoin.par (left, right); ()) + end + in + put 0 t; + PureSeq.fromSeq (ArraySlice.full a) + end + + fun fromPureSeq v = Flat v + + fun empty () = Leaf + fun singleton x = Elem x + val $ = singleton + +end diff --git a/tests/bench/pure-quickhull/main.sml b/tests/bench/pure-quickhull/main.sml new file mode 100644 index 000000000..80ccca51c --- /dev/null +++ b/tests/bench/pure-quickhull/main.sml @@ -0,0 +1,58 @@ +structure CLA = CommandLineArgs + +val resolution = 1000000 +fun randReal seed = + Real.fromInt (Util.hash seed mod resolution) / Real.fromInt resolution + +fun randPt seed = + let + val r = Math.sqrt (randReal (2*seed)) + val theta = randReal (2*seed+1) * 2.0 * Math.pi + in + (1.0 + r * Math.cos(theta), 1.0 + r * Math.sin(theta)) + end + +(* val filename = CLA.parseString "infile" "" *) +val outfile = CLA.parseString "outfile" "" +val n = CLA.parseInt "N" (1000 * 1000 * 100) + +val _ = print ("input size " ^ Int.toString n ^ "\n") + +val (inputPts, tm) = Util.getTime (fn _ => PureSeq.tabulate randPt n) +val _ = print ("generated input in " ^ Time.fmt 4 tm ^ "s\n") + +val result = Benchmark.run "running quickhull" (fn _ => Quickhull.hull inputPts) + +val _ = print ("hull size " ^ Int.toString (PureSeq.length result) ^ "\n") + +fun rtos x = + if x < 0.0 then "-" ^ rtos (~x) + else Real.fmt (StringCvt.FIX (SOME 3)) x 
+fun pttos (x,y) = + String.concat ["(", rtos x, ",", rtos y, ")"] + +(* fun check result = + let + val correct = Checkhull.check inputPts result + in + print ("correct? " ^ + Checkhull.report inputPts result (Checkhull.check inputPts result) + ^ "\n") + end *) + +val _ = + if outfile = "" then + print ("use -outfile XXX to see result\n") + else + let + val out = TextIO.openOut outfile + fun writeln str = TextIO.output (out, str ^ "\n") + fun dump i = + if i >= PureSeq.length result then () + else (writeln (Int.toString (PureSeq.nth result i)); dump (i+1)) + in + writeln "pbbs_sequenceInt"; + dump 0; + TextIO.closeOut out + end + diff --git a/tests/bench/pure-quickhull/pure-quickhull.mlb b/tests/bench/pure-quickhull/pure-quickhull.mlb new file mode 100644 index 000000000..7c74996b1 --- /dev/null +++ b/tests/bench/pure-quickhull/pure-quickhull.mlb @@ -0,0 +1,5 @@ +../../mpllib/sources.$(COMPAT).mlb +TreeSeq.sml +Split.sml +Quickhull.sml +main.sml diff --git a/tests/bench/pure-skyline/CityGen.sml b/tests/bench/pure-skyline/CityGen.sml new file mode 100644 index 000000000..70d0c15a2 --- /dev/null +++ b/tests/bench/pure-skyline/CityGen.sml @@ -0,0 +1,77 @@ +structure CityGen: +sig + + (* (city n x) produces a sequence of n random buildings, seeded by x (any + * integer will do). *) + val city : int -> int -> (int * int * int) PureSeq.t + + (* (cities m n x) produces m cities, each is a sequence of at most n random + * buildings, seeded by x (any integer will do). 
*) + val cities : int -> int -> int -> (int * int * int) PureSeq.t PureSeq.t +end = +struct + + structure R = FastHashRand + + (* Fisher-Yates shuffle aka Knuth shuffle *) + fun shuffle s r = + let + val n = PureSeq.length s + val data = Array.tabulate (n, PureSeq.nth s) + + fun swapLoop (r, i) = + if i >= n then r + else let + val j = R.boundedInt (i, n) r + val (x, y) = (Array.sub (data, i), Array.sub (data, j)) + in + Array.update (data, i, y); + Array.update (data, j, x); + swapLoop (R.next r, i+1) + end + + val r' = swapLoop (r, 0) + in + (r', PureSeq.tabulate (fn i => Array.sub (data, i)) n) + end + + fun citySeeded n r0 = + let + val (r1, xs) = shuffle (PureSeq.tabulate (fn i => i) (2*n)) r0 + val (_, seeds) = R.splitTab (r1, n) + fun pow b e = if e <= 0 then 1 else b * pow b (e-1) + + fun makeBuilding i = + let + val xpair = (PureSeq.nth xs (2*i), PureSeq.nth xs (2*i + 1)) + val lo = Int.min xpair + val hi = Int.max xpair + val width = hi-lo + val maxHeight = Int.max (1, 2*n div width) + val maxHeight = + if maxHeight >= n then + 1 + pow (Util.log2 maxHeight) 2 + else + maxHeight + pow (Util.log2 maxHeight) 2 + val heightRange = (Int.max (1, maxHeight-(n div 100)), maxHeight+1) + val height = R.boundedInt heightRange (seeds i) + in + (lo, height, hi) + end + in + PureSeq.tabulate makeBuilding n + end + + fun city n x = citySeeded n (R.fromInt x) + + fun cities m n x = + let + val (_, rs) = R.splitTab (R.fromInt x, m) + fun ithCity i = + let val r = rs i + in citySeeded (R.boundedInt (0, n+1) r) (R.next r) + end + in PureSeq.tabulate ithCity m + end + +end diff --git a/tests/bench/pure-skyline/FastHashRand.sml b/tests/bench/pure-skyline/FastHashRand.sml new file mode 100644 index 000000000..205a7737d --- /dev/null +++ b/tests/bench/pure-skyline/FastHashRand.sml @@ -0,0 +1,66 @@ +(* MUCH faster random number generation than DotMix. + * I wonder how good its randomness is? 
*) +structure FastHashRand = +struct + type rand = Word64.word + + val maxWord = 0wxFFFFFFFFFFFFFFFF : Word64.word + + exception FastHashRand + + fun hashWord w = + let + open Word64 + infix 2 >> infix 2 << infix 2 xorb infix 2 andb + val v = w * 0w3935559000370003845 + 0w2691343689449507681 + val v = v xorb (v >> 0w21) + val v = v xorb (v << 0w37) + val v = v xorb (v >> 0w4) + val v = v * 0w4768777513237032717 + val v = v xorb (v << 0w20) + val v = v xorb (v >> 0w41) + val v = v xorb (v << 0w5) + in + v + end + + fun fromInt x = hashWord (Word64.fromInt x) + + fun next r = hashWord r + + fun split r = (hashWord r, (hashWord (r+0w1), hashWord (r+0w2))) + + fun biasedBool (h, t) r = + let + val scaleFactor = Word64.div (maxWord, Word64.fromInt (h+t)) + in + Word64.<= (r, Word64.* (Word64.fromInt h, scaleFactor)) + end + + fun split3 _ = raise FastHashRand + fun splitTab (r, n) = + (hashWord r, fn i => hashWord (r + Word64.fromInt (i+1))) + + val intp = + case Int.precision of + SOME n => n + | NONE => (print "[ERR] int precision\n"; OS.Process.exit OS.Process.failure) + + val mask = Word64.<< (0w1, Word.fromInt (intp-1)) + + fun int r = + Word64.toIntX (Word64.andb (r, mask) - 0w1) + + fun int r = + Word64.toIntX (Word64.>> (r, Word.fromInt (64-intp+1))) + + fun boundedInt (a, b) r = a + ((int r) mod (b-a)) + + fun bool _ = raise FastHashRand + + fun biasedInt _ _ = raise FastHashRand + fun real _ = raise FastHashRand + fun boundedReal _ _ = raise FastHashRand + fun char _ = raise FastHashRand + fun boundedChar _ _ = raise FastHashRand +end diff --git a/tests/bench/pure-skyline/Skyline.sml b/tests/bench/pure-skyline/Skyline.sml new file mode 100644 index 000000000..0b10b41a1 --- /dev/null +++ b/tests/bench/pure-skyline/Skyline.sml @@ -0,0 +1,56 @@ +structure Skyline = +struct + type 'a seq = 'a PureSeq.t + type skyline = (int * int) PureSeq.t + + fun singleton (l, h, r) = PureSeq.fromList [(l, h), (r, 0)] + + fun combine (sky1, sky2) = + let + val lMarked = 
PureSeq.map (fn (x, y) => (x, SOME y, NONE)) sky1 + val rMarked = PureSeq.map (fn (x, y) => (x, NONE, SOME y)) sky2 + + fun cmp ((x1, _, _), (x2, _, _)) = Int.compare (x1, x2) + val merged = PureSeq.merge cmp (lMarked, rMarked) + + fun copy (a, b) = case b of SOME _ => b | NONE => a + fun copyFused ((x1, yl1, yr1), (x2, yl2, yr2)) = + (x2, copy (yl1, yl2), copy (yr1, yr2)) + + val allHeights = PureSeq.scanIncl copyFused (0,NONE,NONE) merged + + fun squish (x, y1, y2) = + (x, Int.max (Option.getOpt (y1, 0), Option.getOpt (y2, 0))) + val sky = PureSeq.map squish allHeights + in + sky + end + + fun skyline g bs = + let + fun skyline' bs = + case PureSeq.length bs of + 0 => PureSeq.empty () + | 1 => singleton (PureSeq.nth bs 0) + | n => + let + val half = n div 2 + val sfL = fn _ => skyline' (PureSeq.take bs half) + val sfR = fn _ => skyline' (PureSeq.drop bs half) + in + if PureSeq.length bs <= g then + combine (sfL (), sfR ()) + else + combine (ForkJoin.par (sfL, sfR)) + end + + val sky = skyline' bs + + fun isUnique (i, (x, h)) = + i = 0 orelse let val (_, prevh) = PureSeq.nth sky (i-1) in h <> prevh end + val sky = PureSeq.filterIdx isUnique sky + in + sky + end + +end diff --git a/tests/bench/pure-skyline/main.sml b/tests/bench/pure-skyline/main.sml new file mode 100644 index 000000000..b205cf8e1 --- /dev/null +++ b/tests/bench/pure-skyline/main.sml @@ -0,0 +1,102 @@ +structure CLA = CommandLineArgs +structure Gen = CityGen + +(* +functor S (Sky : SKYLINE where type skyline = (int * int) Seq.t) = +struct + open Sky + fun skyline bs = + case Seq.splitMid bs of + Seq.EMPTY => Seq.empty () + | Seq.ONE b => singleton b + | Seq.PAIR (l, r) => + let + fun sl _ = skyline l + fun sr _ = skyline r + val (l', r') = + if Seq.length bs <= 1000 + then (sl (), sr ()) + else Primitives.par (sl, sr) + in + combine (l', r') + end +end + +structure Stu = S (MkSkyline (structure Seq = Seq)) +structure Ref = S (MkRefSkyline (structure Seq = Seq)) +*) + +fun pairEq ((x1, y1), (x2, 
y2)) = (x1 = x2 andalso y1 = y2) + +fun skylinesEq (s1, s2) = + PureSeq.length s1 = PureSeq.length s2 andalso + PureSeq.reduce (fn (a,b) => a andalso b) true + (PureSeq.tabulate (fn i => pairEq (PureSeq.nth s1 i, PureSeq.nth s2 i)) (PureSeq.length s1)) + +val size = CLA.parseInt "size" 1000000 +val seed = CLA.parseInt "seed" 15210 +val grain = CLA.parseInt "grain" 1000 +val output = CLA.parseString "output" "" + +(* ensure newline at end of string *) +fun println s = + let + val needsNewline = + String.size s = 0 orelse String.sub (s, String.size s - 1) <> #"\n" + in + print (if needsNewline then s ^ "\n" else s) + end + +val _ = println ("size " ^ Int.toString size) +val _ = println ("seed " ^ Int.toString seed) +val _ = println ("grain " ^ Int.toString grain) + +val (input, tm) = Util.getTime (fn _ => Gen.city size seed) +val _ = println ("generated input in " ^ Time.fmt 4 tm ^ "s\n") + +val sky = Benchmark.run "skyline" (fn _ => Skyline.skyline grain input) +val _ = print ("result-len " ^ Int.toString (PureSeq.length sky) ^ "\n") + +val _ = + if output = "" then + print ("use -output XXX.ppm to see result\n") + else + let + val (xMin, _) = PureSeq.nth sky 0 + val (xMax, _) = PureSeq.nth sky (PureSeq.length sky - 1) + val yMax = PureSeq.reduce Int.max 0 (PureSeq.map (fn (_,y) => y) sky) + val _ = print ("xMin " ^ Int.toString xMin ^ "\n") + val _ = print ("xMax " ^ Int.toString xMax ^ "\n") + val _ = print ("yMax " ^ Int.toString yMax ^ "\n") + + val width = 1000 + val height = 250 + + val padding = 20 + + fun col x = + padding + width * (x - xMin) div (1 + xMax - xMin) + fun row y = + padding + height - 1 - (height * y div (1 + yMax)) + + val width' = 2*padding + width + val height' = padding + height + val image = Seq.tabulate (fn _ => Color.white) (width' * height') + + val _ = PureSeq.foreach sky (fn (idx, (x, y)) => + if idx >= PureSeq.length sky - 1 then () else + let + val (x', _) = PureSeq.nth sky (idx+1) + + val ihi = row y + val jlo = col x + val jhi = 
Int.max (col x + 1, col x') + in + Util.for (ihi, height') (fn i => + Util.for (jlo, jhi) (fn j => + ArraySlice.update (image, i*width' + j, Color.black))) + end) + in + PPM.write output {width=width', height=height', data=image}; + print ("wrote output to " ^ output ^ "\n") + end diff --git a/tests/bench/pure-skyline/pure-skyline.mlb b/tests/bench/pure-skyline/pure-skyline.mlb new file mode 100644 index 000000000..20d3615dd --- /dev/null +++ b/tests/bench/pure-skyline/pure-skyline.mlb @@ -0,0 +1,6 @@ +../../mpllib/sources.$(COMPAT).mlb +FastHashRand.sml +CityGen.sml +Skyline.sml +main.sml + diff --git a/tests/bench/quickhull/MkOptSplit.sml b/tests/bench/quickhull/MkOptSplit.sml new file mode 100644 index 000000000..0a68264ca --- /dev/null +++ b/tests/bench/quickhull/MkOptSplit.sml @@ -0,0 +1,123 @@ +functor MkSplit (Seq: SEQUENCE) : +sig + type 'a seq = 'a Seq.t + type 'a aseq = 'a ArraySequence.t + + datatype flag = Left | Right | Throwaway + val parSplit: 'a seq -> flag seq -> 'a aseq * 'a aseq +end = +struct + + structure A = Array + structure AS = ArraySlice + structure ASeq = ArraySequence + + type 'a seq = 'a Seq.t + type 'a aseq = 'a ASeq.t + + fun inPlace (s: 'a ASeq.t) putLeft putRight = + let + val (a, start, n) = AS.base s + fun item i = Array.sub (a, i) + fun set i x = Array.update (a, i, x) + + fun growLeft ll lm rm rr = + if lm >= rm then (ll, rr) else + let + val x = item lm + in + if putRight x then + growRight ll lm rm rr + else if not (putLeft x) then + growLeft ll (lm+1) rm rr + else + (set ll x; growLeft (ll+1) (lm+1) rm rr) + end + + and growRight ll lm rm rr = + if lm >= rm then (ll, rr) else + let + val x = item (rm-1) + in + if putLeft x then + swapThenContinue ll lm rm rr + else if not (putRight x) then + growRight ll lm (rm-1) rr + else + (set (rr-1) x; growRight ll lm (rm-1) (rr-1)) + end + + and swapThenContinue ll lm rm rr = + let + val tmp = item lm + in + set ll (item (rm-1)); + set (rr-1) tmp; + growLeft (ll+1) (lm+1) (rm-1) (rr-1) 
+ end + + val (ll, rr) = growLeft start start (start+n) (start+n) + in + (ll-start, (start+n)-rr) + end + + datatype flag = Left | Right | Throwaway + + fun parSplit s flags = + let + val n = Seq.length s + val blockSize = 10000 + val numBlocks = 1 + (n-1) div blockSize + + (* the later scan(s) appears to be faster when split into two separate + * scans, rather than doing a single scan on tuples. *) + + (* val counts = Primitives.alloc numBlocks *) + val countl = ForkJoin.alloc numBlocks + val countr = ForkJoin.alloc numBlocks + + val _ = ForkJoin.parfor 1 (0, numBlocks) (fn b => + let + val lo = b * blockSize + val hi = Int.min (lo + blockSize, n) + fun loop (cl, cr) i = + if i >= hi then + (* Array.update (counts, b, (cl, cr)) *) + (Array.update (countl, b, cl); Array.update (countr, b, cr)) + else case Seq.nth flags i of + Left => loop (cl+1, cr) (i+1) + | Right => loop (cl, cr+1) (i+1) + | _ => loop (cl, cr) (i+1) + in + loop (0, 0) lo + end) + + (* val (offsets, (totl, totr)) = + Seq.scan (fn ((a,b),(c,d)) => (a+c,b+d)) (0,0) (ArraySlice.full counts) *) + val (offsetsl, totl) = ASeq.scan op+ 0 (ArraySlice.full countl) + val (offsetsr, totr) = ASeq.scan op+ 0 (ArraySlice.full countr) + + val left = ForkJoin.alloc totl + val right = ForkJoin.alloc totr + + val _ = ForkJoin.parfor 1 (0, numBlocks) (fn b => + let + val lo = b * blockSize + val hi = Int.min (lo + blockSize, n) + (* val (offsetl, offsetr) = Seq.nth offsets b *) + val offsetl = ASeq.nth offsetsl b + val offsetr = ASeq.nth offsetsr b + fun loop (cl, cr) i = + if i >= hi then () else + case Seq.nth flags i of + Left => (Array.update (left, offsetl+cl, Seq.nth s i); loop (cl+1, cr) (i+1)) + | Right => (Array.update (right, offsetr+cr, Seq.nth s i); loop (cl, cr+1) (i+1)) + | _ => loop (cl, cr) (i+1) + in + loop (0, 0) lo + end) + in + (ArraySlice.full left, ArraySlice.full right) + end + +end diff --git a/tests/bench/quickhull/MkPurishSplit.sml b/tests/bench/quickhull/MkPurishSplit.sml new file mode 
100644 index 000000000..00fabf674 --- /dev/null +++ b/tests/bench/quickhull/MkPurishSplit.sml @@ -0,0 +1,46 @@ +functor MkPurishSplit (Seq: SEQUENCE) : +sig + type 'a seq = 'a Seq.t + type 'a aseq = 'a ArraySequence.t + + datatype flag = Left | Right | Throwaway + val parSplit: 'a seq -> flag seq -> 'a aseq * 'a aseq +end = +struct + + structure A = Array + structure AS = ArraySlice + structure ASeq = ArraySequence + + type 'a seq = 'a Seq.t + type 'a aseq = 'a ASeq.t + + datatype flag = Left | Right | Throwaway + + fun parSplit s flags = + let + fun countFlag i = + case Seq.nth flags i of + Left => (1, 0) + | Right => (0, 1) + | Throwaway => (0, 0) + + fun add ((a, b), (c, d)) = (a+c, b+d) + + val n = Seq.length s + val (offsets, (tl, tr)) = Seq.scan add (0, 0) (Seq.tabulate countFlag n) + + val left = ForkJoin.alloc tl + val right = ForkJoin.alloc tr + in + Seq.applyIdx offsets (fn (i, (offl, offr)) => + case Seq.nth flags i of + Left => Array.update (left, offl, Seq.nth s i) + | Right => Array.update (right, offr, Seq.nth s i) + | _ => () + ); + + (AS.full left, AS.full right) + end + +end diff --git a/tests/bench/quickhull/MkQuickhull.sml b/tests/bench/quickhull/MkQuickhull.sml new file mode 100644 index 000000000..c1bccd7a3 --- /dev/null +++ b/tests/bench/quickhull/MkQuickhull.sml @@ -0,0 +1,109 @@ +functor MkQuickhull (Seq: SEQUENCE): +sig + type 'a aseq = 'a ArraySequence.t + val hull: (real * real) aseq -> int aseq +end = +struct + + structure AS = ArraySlice + structure ASeq = ArraySequence + type 'a aseq = 'a ASeq.t + + structure G = Geometry2D + structure Tree = TreeSeq + + structure Split = MkSplit (Seq) + + fun hull pts = + let + fun pt i = ASeq.nth pts i + fun dist p q i = G.Point.triArea (p, q, pt i) + fun max ((i, di), (j, dj)) = + if di > dj then (i, di) else (j, dj) + fun x i = #1 (pt i) + + fun aboveLine p q i = (dist p q i > 0.0) + + fun parHull idxs l r = + if ASeq.length idxs < 2 then + Tree.fromArraySeq idxs + else + let + val lp = pt l + val 
rp = pt r + fun d i = dist lp rp i + + val idxs = Seq.fromArraySeq idxs + + val distances = Seq.map (fn i => (i, d i)) idxs + val (mid, _) = Seq.reduce max (~1, Real.negInf) distances + + val midp = pt mid + + fun flag i = + if aboveLine lp midp i then Split.Left + else if aboveLine midp rp i then Split.Right + else Split.Throwaway + val (left, right) = + Split.parSplit idxs (Seq.force (Seq.map flag idxs)) + + fun doLeft () = parHull left l mid + fun doRight () = parHull right mid r + val (leftHull, rightHull) = + if ASeq.length left + ASeq.length right <= 2048 + then (doLeft (), doRight ()) + else ForkJoin.par (doLeft, doRight) + in + Tree.append (leftHull, + (Tree.append (Tree.$ mid, rightHull))) + end + + (* val tm = Util.startTiming () *) + + val allIdx = Seq.tabulate (fn i => i) (ASeq.length pts) + + (* This is faster than doing two reduces *) + val (l, r) = Seq.reduce + (fn ((l1, r1), (l2, r2)) => + (if x l1 < x l2 then l1 else l2, + if x r1 > x r2 then r1 else r2)) + (0, 0) + (Seq.map (fn i => (i, i)) allIdx) + + (* val tm = Util.tick tm "endpoints" *) + + val lp = pt l + val rp = pt r + + fun flag i = + let + val d = dist lp rp i + in + if d > 0.0 then Split.Left + else if d < 0.0 then Split.Right + else Split.Throwaway + end + val (above, below) = + Split.parSplit allIdx (Seq.force (Seq.map flag allIdx)) + + (* val tm = Util.tick tm "above/below filter" *) + + val (above, below) = ForkJoin.par + (fn _ => parHull above l r, + fn _ => parHull below r l) + + (* val tm = Util.tick tm "quickhull" *) + + val hullt = + Tree.append + (Tree.append (Tree.$ l, above), + Tree.append (Tree.$ r, below)) + + val result = Tree.toArraySeq hullt + + (* val tm = Util.tick tm "flatten" *) + in + result + end + +end diff --git a/tests/bench/quickhull/ParseFile.sml b/tests/bench/quickhull/ParseFile.sml new file mode 100644 index 000000000..c7eb4e101 --- /dev/null +++ b/tests/bench/quickhull/ParseFile.sml @@ -0,0 +1,190 @@ +(** SAM_NOTE: copy/pasted... 
some repetition here with Parse. *) +structure ParseFile = +struct + + structure RF = ReadFile + structure Seq = ArraySequence + structure DS = OldDelayedSeq + + fun tokens (f: char -> bool) (cs: char Seq.t) : (char DS.t) DS.t = + let + val n = Seq.length cs + val s = DS.tabulate (Seq.nth cs) n + val indices = DS.tabulate (fn i => i) (n+1) + fun check i = + if (i = n) then not (f(DS.nth s (n-1))) + else if (i = 0) then not (f(DS.nth s 0)) + else let val i1 = f (DS.nth s i) + val i2 = f (DS.nth s (i-1)) + in (i1 andalso not i2) orelse (i2 andalso not i1) end + val ids = DS.filter check indices + val res = DS.tabulate (fn i => + let val (start, e) = (DS.nth ids (2*i), DS.nth ids (2*i+1)) + in DS.tabulate (fn i => Seq.nth cs (start+i)) (e - start) + end) + ((DS.length ids) div 2) + in + res + end + + fun eqStr str (chars : char DS.t) = + let + val n = String.size str + fun checkFrom i = + i >= n orelse + (String.sub (str, i) = DS.nth chars i andalso checkFrom (i+1)) + in + DS.length chars = n + andalso + checkFrom 0 + end + + fun parseDigit char = + let + val code = Char.ord char + val code0 = Char.ord #"0" + val code9 = Char.ord #"9" + in + if code < code0 orelse code9 < code then + NONE + else + SOME (code - code0) + end + + (* This implementation doesn't work with mpl :( + * Need to fix the basis library... 
*) + (* + fun parseReal chars = + let + val str = CharVector.tabulate (DS.length chars, DS.nth chars) + in + Real.fromString str + end + *) + + fun parseInt (chars : char DS.t) = + let + val n = DS.length chars + fun c i = DS.nth chars i + + fun build x i = + if i >= n then SOME x else + case c i of + #"," => build x (i+1) + | #"_" => build x (i+1) + | cc => + case parseDigit cc of + NONE => NONE + | SOME dig => build (x * 10 + dig) (i+1) + in + if n = 0 then NONE + else if (c 0 = #"-" orelse c 0 = #"~") then + Option.map (fn x => x * ~1) (build 0 1) + else if (c 0 = #"+") then + build 0 1 + else + build 0 0 + end + + fun parseReal (chars : char DS.t) = + let + val n = DS.length chars + fun c i = DS.nth chars i + + fun buildAfterE x i = + let + val chars' = DS.subseq chars (i, n-i) + in + Option.map (fn e => x * Math.pow (10.0, Real.fromInt e)) + (parseInt chars') + end + + fun buildAfterPoint m x i = + if i >= n then SOME x else + case c i of + #"," => buildAfterPoint m x (i+1) + | #"_" => buildAfterPoint m x (i+1) + | #"." => NONE + | #"e" => buildAfterE x (i+1) + | #"E" => buildAfterE x (i+1) + | cc => + case parseDigit cc of + NONE => NONE + | SOME dig => buildAfterPoint (m * 0.1) (x + m * (Real.fromInt dig)) (i+1) + + fun buildBeforePoint x i = + if i >= n then SOME x else + case c i of + #"," => buildBeforePoint x (i+1) + | #"_" => buildBeforePoint x (i+1) + | #"." 
=> buildAfterPoint 0.1 x (i+1) + | #"e" => buildAfterE x (i+1) + | #"E" => buildAfterE x (i+1) + | cc => + case parseDigit cc of + NONE => NONE + | SOME dig => buildBeforePoint (x * 10.0 + Real.fromInt dig) (i+1) + in + if n = 0 then NONE + else if (c 0 = #"-" orelse c 0 = #"~") then + Option.map (fn x => x * ~1.0) (buildBeforePoint 0.0 1) + else + buildBeforePoint 0.0 0 + end + + fun readSequencePoint2d filename = + let + val toks = tokens Char.isSpace (RF.contentsSeq filename) + fun tok i = DS.nth toks i + val _ = + if eqStr "pbbs_sequencePoint2d" (tok 0) then () + else raise Fail (filename ^ " wrong file type") + + val n = DS.length toks - 1 + + fun r i = Option.valOf (parseReal (tok (1 + i))) + + fun pt i = + (r (2*i), r (2*i+1)) + handle e => raise Fail ("error parsing point " ^ Int.toString i ^ " (" ^ exnMessage e ^ ")") + + val result = Seq.tabulate pt (n div 2) + in + result + end + + fun readSequenceInt filename = + let + val toks = tokens Char.isSpace (RF.contentsSeq filename) + fun tok i = DS.nth toks i + val _ = + if eqStr "pbbs_sequenceInt" (tok 0) then () + else raise Fail (filename ^ " wrong file type") + + val n = DS.length toks - 1 + + fun p i = + Option.valOf (parseInt (tok (1 + i))) + handle e => raise Fail ("error parsing integer " ^ Int.toString i) + in + Seq.tabulate p n + end + + fun readSequenceReal filename = + let + val toks = tokens Char.isSpace (RF.contentsSeq filename) + fun tok i = DS.nth toks i + val _ = + if eqStr "pbbs_sequenceDouble" (tok 0) then () + else raise Fail (filename ^ " wrong file type") + + val n = DS.length toks - 1 + + fun p i = + Option.valOf (parseReal (tok (1 + i))) + handle e => raise Fail ("error parsing double value " ^ Int.toString i) + in + Seq.tabulate p n + end + +end diff --git a/tests/bench/quickhull/Quickhull.sml b/tests/bench/quickhull/Quickhull.sml new file mode 100644 index 000000000..fbe8f83d0 --- /dev/null +++ b/tests/bench/quickhull/Quickhull.sml @@ -0,0 +1,119 @@ +structure Quickhull : +sig + val 
hull : (real * real) Seq.t -> int Seq.t +end = +struct + + structure AS = ArraySlice + structure G = Geometry2D + structure Tree = TreeSeq + + fun hull pts = + let + fun pt i = Seq.nth pts i + fun dist p q i = G.Point.triArea (p, q, pt i) + fun max ((i, di), (j, dj)) = + if di > dj then (i, di) else (j, dj) + fun x i = #1 (pt i) + + fun aboveLine p q i = (dist p q i > 0.0) + + fun parHull idxs l r = + if Seq.length idxs < 2 then + Tree.fromArraySeq idxs + (* if DS.length idxs <= 2048 then + seqHull idxs l r *) + else + let + val lp = pt l + val rp = pt r + fun d i = dist lp rp i + + (* val idxs = DS.fromArraySeq idxs *) + + val (mid, _) = SeqBasis.reduce 10000 max (~1, Real.negInf) + (0, Seq.length idxs) (fn i => (Seq.nth idxs i, d (Seq.nth idxs i))) + (* val distances = DS.map (fn i => (i, d i)) idxs + val (mid, _) = DS.reduce max (~1, Real.negInf) distances *) + + val midp = pt mid + + fun flag i = + if aboveLine lp midp i then Split.Left + else if aboveLine midp rp i then Split.Right + else Split.Throwaway + val (left, right) = + Split.parSplit idxs (Seq.map flag idxs) + (* (DS.force (DS.map flag idxs)) *) + + fun doLeft () = parHull left l mid + fun doRight () = parHull right mid r + val (leftHull, rightHull) = + if Seq.length left + Seq.length right <= 2048 + then (doLeft (), doRight ()) + else ForkJoin.par (doLeft, doRight) + in + Tree.append (leftHull, + (Tree.append (Tree.$ mid, rightHull))) + end + + (* val tm = Util.startTiming () *) + + (* val allIdx = DS.tabulate (fn i => i) (Seq.length pts) *) + + (* This is faster than doing two reduces *) + (* val (l, r) = DS.reduce + (fn ((l1, r1), (l2, r2)) => + (if x l1 < x l2 then l1 else l2, + if x r1 > x r2 then r1 else r2)) + (0, 0) + (DS.map (fn i => (i, i)) allIdx) *) + + val (l, r) = SeqBasis.reduce 10000 + (fn ((l1, r1), (l2, r2)) => + (if x l1 < x l2 then l1 else l2, + if x r1 > x r2 then r1 else r2)) + (0, 0) + (0, Seq.length pts) + (fn i => (i, i)) + + (* val tm = Util.tick tm "endpoints" *) + + val lp 
= pt l + val rp = pt r + + fun flag i = + let + val d = dist lp rp i + in + if d > 0.0 then Split.Left + else if d < 0.0 then Split.Right + else Split.Throwaway + end + val (above, below) = + (* Split.parSplit allIdx (DS.force (DS.map flag allIdx)) *) + Split.parSplit + (Seq.tabulate (fn i => i) (Seq.length pts)) + (Seq.tabulate flag (Seq.length pts)) + + (* val tm = Util.tick tm "above/below filter" *) + + val (above, below) = ForkJoin.par + (fn _ => parHull above l r, + fn _ => parHull below r l) + + (* val tm = Util.tick tm "quickhull" *) + + val hullt = + Tree.append + (Tree.append (Tree.$ l, above), + Tree.append (Tree.$ r, below)) + + val result = Tree.toArraySeq hullt + + (* val tm = Util.tick tm "flatten" *) + in + result + end + +end diff --git a/tests/bench/quickhull/Split.sml b/tests/bench/quickhull/Split.sml new file mode 100644 index 000000000..0d3efec8e --- /dev/null +++ b/tests/bench/quickhull/Split.sml @@ -0,0 +1,122 @@ +structure Split : +sig + type 'a seq + + val inPlace : 'a seq -> ('a -> bool) -> ('a -> bool) -> (int * int) + + datatype flag = Left | Right | Throwaway + val parSplit : 'a seq -> flag seq -> 'a seq * 'a seq +end = +struct + + structure A = Array + structure AS = ArraySlice + + type 'a seq = 'a Seq.t + + fun inPlace s putLeft putRight = + let + val (a, start, n) = AS.base s + fun item i = A.sub (a, i) + fun set i x = A.update (a, i, x) + + fun growLeft ll lm rm rr = + if lm >= rm then (ll, rr) else + let + val x = item lm + in + if putRight x then + growRight ll lm rm rr + else if not (putLeft x) then + growLeft ll (lm+1) rm rr + else + (set ll x; growLeft (ll+1) (lm+1) rm rr) + end + + and growRight ll lm rm rr = + if lm >= rm then (ll, rr) else + let + val x = item (rm-1) + in + if putLeft x then + swapThenContinue ll lm rm rr + else if not (putRight x) then + growRight ll lm (rm-1) rr + else + (set (rr-1) x; growRight ll lm (rm-1) (rr-1)) + end + + and swapThenContinue ll lm rm rr = + let + val tmp = item lm + in + set ll (item 
(rm-1)); + set (rr-1) tmp; + growLeft (ll+1) (lm+1) (rm-1) (rr-1) + end + + val (ll, rr) = growLeft start start (start+n) (start+n) + in + (ll-start, (start+n)-rr) + end + + datatype flag = Left | Right | Throwaway + + fun parSplit s flags = + let + val n = Seq.length s + val blockSize = 10000 + val numBlocks = 1 + (n-1) div blockSize + + (* the later scan(s) appears to be faster when split into two separate + * scans, rather than doing a single scan on tuples. *) + + (* val counts = Primitives.alloc numBlocks *) + val countl = ForkJoin.alloc numBlocks + val countr = ForkJoin.alloc numBlocks + + val _ = ForkJoin.parfor 1 (0, numBlocks) (fn b => + let + val lo = b * blockSize + val hi = Int.min (lo + blockSize, n) + fun loop (cl, cr) i = + if i >= hi then + (* A.update (counts, b, (cl, cr)) *) + (A.update (countl, b, cl); A.update (countr, b, cr)) + else case Seq.nth flags i of + Left => loop (cl+1, cr) (i+1) + | Right => loop (cl, cr+1) (i+1) + | _ => loop (cl, cr) (i+1) + in + loop (0, 0) lo + end) + + (* val (offsets, (totl, totr)) = + Seq.scan (fn ((a,b),(c,d)) => (a+c,b+d)) (0,0) (ArraySlice.full counts) *) + val (offsetsl, totl) = Seq.scan op+ 0 (AS.full countl) + val (offsetsr, totr) = Seq.scan op+ 0 (AS.full countr) + + val left = ForkJoin.alloc totl + val right = ForkJoin.alloc totr + + val _ = ForkJoin.parfor 1 (0, numBlocks) (fn b => + let + val lo = b * blockSize + val hi = Int.min (lo + blockSize, n) + (* val (offsetl, offsetr) = Seq.nth offsets b *) + val offsetl = Seq.nth offsetsl b + val offsetr = Seq.nth offsetsr b + fun loop (cl, cr) i = + if i >= hi then () else + case Seq.nth flags i of + Left => (A.update (left, offsetl+cl, Seq.nth s i); loop (cl+1, cr) (i+1)) + | Right => (A.update (right, offsetr+cr, Seq.nth s i); loop (cl, cr+1) (i+1)) + | _ => loop (cl, cr) (i+1) + in + loop (0, 0) lo + end) + in + (AS.full left, AS.full right) + end + +end diff --git a/tests/bench/quickhull/TreeSeq.sml b/tests/bench/quickhull/TreeSeq.sml new file mode 
100644
index 000000000..766667e71
--- /dev/null
+++ b/tests/bench/quickhull/TreeSeq.sml
@@ -0,0 +1,54 @@
+(* Append-tree sequences: append is O(1) (build a Node), and the tree is
+ * flattened into a contiguous array once at the end, in parallel. *)
+structure TreeSeq =
+struct
+  datatype 'a t =
+    Leaf
+  | Elem of 'a
+  | Flat of 'a Seq.t
+  | Node of int * 'a t * 'a t   (* cached total length, left, right *)
+
+  type 'a seq = 'a t
+  type 'a ord = 'a * 'a -> order
+  (* listview/treeview and the exceptions below are presumably here to match
+   * the project's sequence interface; they are not used in this file. *)
+  datatype 'a listview = NIL | CONS of 'a * 'a seq
+  datatype 'a treeview = EMPTY | ONE of 'a | PAIR of 'a seq * 'a seq
+
+  exception Range
+  exception Size
+  exception NYI
+
+  (* O(1) for Leaf/Elem/Node (length cached at each Node). *)
+  fun length Leaf = 0
+    | length (Elem _) = 1
+    | length (Flat s) = Seq.length s
+    | length (Node (n, _, _)) = n
+
+  (* O(1): record the combined length at a fresh root; no copying. *)
+  fun append (t1, t2) = Node (length t1 + length t2, t1, t2)
+
+  (* Flatten into a freshly allocated array slice.  Subtrees of more than
+   * 4096 elements are written by parallel recursive calls; the two sides
+   * write disjoint ranges [offset, offset+|l|) and [offset+|l|, offset+n). *)
+  fun toArraySeq t =
+    let
+      val a = ForkJoin.alloc (length t)
+      (* Write tree t into a, starting at index offset. *)
+      fun put offset t =
+        case t of
+          Leaf => ()
+        | Elem x => Array.update (a, offset, x)
+        | Flat s => Seq.foreach s (fn (i, x) => Array.update (a, offset+i, x))
+        | Node (n, l, r) =>
+            let
+              fun left () = put offset l
+              fun right () = put (offset + length l) r
+            in
+              if n <= 4096 then
+                (left (); right ())
+              else
+                (ForkJoin.par (left, right); ())
+            end
+    in
+      put 0 t;
+      ArraySlice.full a
+    end
+
+  (* O(1): wrap an existing array sequence without copying. *)
+  fun fromArraySeq a = Flat a
+
+  fun empty () = Leaf
+  fun singleton x = Elem x
+  val $ = singleton
+
+end
diff --git a/tests/bench/quickhull/main.sml b/tests/bench/quickhull/main.sml
new file mode 100644
index 000000000..fa2edac0a
--- /dev/null
+++ b/tests/bench/quickhull/main.sml
@@ -0,0 +1,86 @@
+(* Quickhull benchmark driver: points come either from a pbbs input file
+ * (-input) or are generated pseudorandomly. *)
+structure CLA = CommandLineArgs
+structure Seq = ArraySequence
+
+structure Quickhull = MkQuickhull(OldDelayedSeq)
+
+(* Pseudorandom real in [0, 1) with fixed resolution, keyed by seed. *)
+val resolution = 1000000
+fun randReal seed =
+  Real.fromInt (Util.hash seed mod resolution) / Real.fromInt resolution
+
+(* Point in the unit disk centered at (1,1): sqrt of a uniform draw for the
+ * radius, angle uniform in [0, 2pi). *)
+fun randPt seed =
+  let
+    val r = Math.sqrt (randReal (2*seed))
+    val theta = randReal (2*seed+1) * 2.0 * Math.pi
+  in
+    (1.0 + r * Math.cos(theta), 1.0 + r * Math.sin(theta))
+  end
+
+val filename = CLA.parseString "input" ""
+val outfile = CLA.parseString "outfile" ""
+val n = CLA.parseInt "n" (1000 * 1000 * 100)
+val doCheck = CLA.parseFlag "check"
+
+(* 
This silly thing helps ensure good placement, by + * forcing points to be reallocated more adjacent. + * It's a no-op, but gives us as much as 2x time + * improvement (!) + *) +fun swap pts = Seq.map (fn (x, y) => (y, x)) pts +fun compactify pts = swap (swap pts) + +val inputPts = + case filename of + "" => Seq.tabulate randPt n + | _ => compactify (ParseFile.readSequencePoint2d filename) + +val n = Seq.length inputPts + +fun task () = + Quickhull.hull inputPts + +fun rtos x = + if x < 0.0 then "-" ^ rtos (~x) + else Real.fmt (StringCvt.FIX (SOME 3)) x +fun pttos (x,y) = + String.concat ["(", rtos x, ",", rtos y, ")"] + +(* +fun check result = + if not doCheck then () else + let + val correct = Checkhull.check inputPts result + in + print ("correct? " ^ + Checkhull.report inputPts result (Checkhull.check inputPts result) + ^ "\n") + end +*) + +(* val _ = + (writeln "pbbs_sequencePoint2d"; dump inputPts 0; OS.Process.exit OS.Process.success) *) + +val result = Benchmark.run "quickhull" task +val _ = print ("hull size " ^ Int.toString (Seq.length result) ^ "\n") +(* val _ = check result *) + +val _ = + if outfile = "" then () else + let + val out = TextIO.openOut outfile + fun writeln str = TextIO.output (out, str ^ "\n") + fun dump i = + if i >= Seq.length result then () + else (writeln (Int.toString (Seq.nth result i)); dump (i+1)) + in + writeln "pbbs_sequenceInt"; + dump 0; + TextIO.closeOut out + end + + +(* fun dumpPt (x, y) = writeln (rtos x ^ " " ^ rtos y) *) +(* fun dump pts i = + if i >= Seq.length pts then () + else (dumpPt (Seq.nth pts i); dump pts (i+1)) *) +(* val hullPts = Seq.map (Seq.nth inputPts) result *) +(* dump hullPts 0 *) diff --git a/tests/bench/quickhull/quickhull.mlb b/tests/bench/quickhull/quickhull.mlb new file mode 100644 index 000000000..599e6986b --- /dev/null +++ b/tests/bench/quickhull/quickhull.mlb @@ -0,0 +1,18 @@ +../../mpllib/sources.$(COMPAT).mlb +ParseFile.sml +TreeSeq.sml + +(* +Split.sml +Quickhull.sml +*) + + +local + 
MkPurishSplit.sml
+in
+  functor MkSplit = MkPurishSplit
+end
+MkQuickhull.sml
+
+main.sml
diff --git a/tests/bench/random/random.mlb b/tests/bench/random/random.mlb
new file mode 100644
index 000000000..00d163b4c
--- /dev/null
+++ b/tests/bench/random/random.mlb
@@ -0,0 +1,2 @@
+../../mpllib/sources.$(COMPAT).mlb
+random.sml
diff --git a/tests/bench/random/random.sml b/tests/bench/random/random.sml
new file mode 100644
index 000000000..117694042
--- /dev/null
+++ b/tests/bench/random/random.sml
@@ -0,0 +1,33 @@
+(* Benchmark: fill an array with N pseudorandom 64-bit words in parallel. *)
+structure CLA = CommandLineArgs
+
+(* Parallel-for grain size; tunable with -grain. *)
+val grain = CLA.parseInt "grain" 10000
+
+(* build an array in parallel with elements f(i) for each 0 <= i < n *)
+fun tabulate (n, f) =
+  let
+    val arr = ForkJoin.alloc n
+  in
+    ForkJoin.parfor grain (0, n) (fn i => Array.update (arr, i, f i));
+    arr
+  end
+
+(* generate the ith element with a hash function *)
+fun gen seed i = Util.hash64 (Word64.xorb (Word64.fromInt i, seed))
+
+(* ==========================================================================
+ * parse command-line arguments and run
+ *)
+
+val n = CLA.parseInt "N" (1000 * 1000 * 1000)
+val seed = CLA.parseInt "seed" 0
+
+val _ = print ("tabulate " ^ Int.toString n ^ " pseudo-random 64-bit words\n")
+val _ = print ("seed " ^ Int.toString seed ^ "\n")
+
+(* Hash the user-provided seed once before it is mixed into each element. *)
+val seed' = Util.hash64 (Word64.fromInt seed)
+
+val result = Benchmark.run "tabulating" (fn _ => tabulate (n, gen seed'))
+
+(* Print a short hex summary of the result array. *)
+fun str x = Word64.fmt StringCvt.HEX x
+val _ = print ("result " ^ Util.summarizeArray 3 str result ^ "\n")
+
diff --git a/tests/bench/range-tree/RangeTree.sml b/tests/bench/range-tree/RangeTree.sml
new file mode 100644
index 000000000..4ee1dc749
--- /dev/null
+++ b/tests/bench/range-tree/RangeTree.sml
@@ -0,0 +1,59 @@
+(* 2-d range queries over weighted integer points. *)
+signature RANGE_TREE =
+sig
+  type rt
+  type point = int * int
+  type weight = int
+
+  (* build pts n: construct the tree from n weighted points. *)
+  val build : (point * weight) Seq.t -> int -> rt
+  (* query t p1 p2: aggregate weight of the points inside the rectangle
+   * spanned by p1 and p2 (presumably p1 <= p2 componentwise -- TODO
+   * confirm against callers). *)
+  val query : rt -> point -> point -> weight
+  (* Debug dump of the tree structure. *)
+  val print : rt -> unit
+end
+
+structure RangeTree : RANGE_TREE =
+struct
+  type point
= int * int + type weight = int + + structure IRTree : Aug = + struct + type key = point + type value = weight + type aug = weight + val compare = fn (p1 : point, p2 : point) => Int.compare (#2 p1, #2 p2) + val g = fn (x, y) => y + val f = fn (x, y) => x + y + val id = 0 + val balance = WB 0.28 + fun debug (k, v, a) = " " + end + + structure InnerRangeTree = PAM(IRTree) + + structure ORTree : Aug = + struct + type key = point + type value = weight + type aug = InnerRangeTree.am + val compare = fn (p1 : key, p2 : key) => Int.compare (#1 p1, #1 p2) + val g = fn (x, y) => InnerRangeTree.singleton x y + val f = fn (x, y) => InnerRangeTree.union x y (Int.+) + val id = InnerRangeTree.empty () + val balance = WB 0.28 + fun debug (k, v, a) = (InnerRangeTree.print_tree a ""; Int.toString v) + end + + structure OuterRangeTree = PAM(ORTree) + type rt = OuterRangeTree.am + + fun build s n = OuterRangeTree.build s 0 n + + fun query r p1 p2 = + let + fun g' ri = InnerRangeTree.aug_range ri p1 p2 + in + OuterRangeTree.aug_project g' (Int.+) r p1 p2 + end + + fun print r = OuterRangeTree.print_tree r " " +end + diff --git a/tests/bench/range-tree/main.sml b/tests/bench/range-tree/main.sml new file mode 100644 index 000000000..bfdf82065 --- /dev/null +++ b/tests/bench/range-tree/main.sml @@ -0,0 +1,103 @@ +structure CLA = CommandLineArgs + +val q = CLA.parseInt "q" 10000000 +val n = CLA.parseInt "n" 10000000 + +val max_size = 2147483647 + +fun randRange i j seed = + i + Word64.toInt + (Word64.mod (Util.hash64 (Word64.fromInt seed), Word64.fromInt (j - i))) + +fun randPt seed = + let + val p = randRange 0 max_size seed + in + (randRange 0 max_size seed, randRange 0 max_size (seed+1)) + end + +(* copied from PAM: range_utils.h *) +fun generate_points n seed = + let + fun rand_coordinate i = randRange 0 max_size i + val rand_numbers = Seq.tabulate rand_coordinate (3*n) + val get = Seq.nth rand_numbers + val points = Seq.tabulate (fn i => ((get i, get (i + n)), get (i + 2*n))) n + in + 
points + end + +val (tree, tm) = Util.getTime (fn _ => + RangeTree.build (generate_points n 0) n) +val _ = print ("generated input in " ^ Time.fmt 4 tm ^ "s\n") + +fun query i = + let + val p1 = randPt (4*i) + val p2 = randPt (4*i + 2) + in + RangeTree.query tree p1 p2 + end + +fun bench () = SeqBasis.tabulate 100 (0, q) query +val result = Benchmark.run "querying range tree" bench +val _ = Util.summarizeArray 10 Int.toString result + +(* + +fun run_rounds f r = + let + fun round_rec i diff = + if i = 0 then diff + else + let + val (t0, t1, _) = f() + val new_diff = Time.- (t1, t0) + val _ = print ("round " ^ (Int.toString (r - i + 1)) ^ " in " ^ Time.fmt 4 (new_diff) ^ "s\n") + in + round_rec (i - 1) (Time.+ (diff, new_diff)) + end + in + round_rec r Time.zeroTime + end + +fun eval_build_range_tree n = + let + val points = generate_points n 0 + val t0 = Time.now () + val rt = RangeTree.build points n + val t1 = Time.now () + in + (t0, t1, rt) + end + +fun eval_queries_range_tree rt q = + let + val max_size = 2147483647 + val pl = generate_points q 0 + val pr = generate_points q 0 + val t0 = Time.now() + val r = SeqBasis.tabulate 100 (0, q) (fn i => RangeTree.query rt (#1 (Seq.nth pl i)) (#1 (Seq.nth pr i))) + val t1 = Time.now() + in + (t0, t1, 0) + end + +val query_size = CommandLineArgs.parseInt "q" 10000000 +val size = CommandLineArgs.parseInt "n" 10000000 +val rep = CommandLineArgs.parseInt "repeat" 1 + +val diff = + if query_size = 0 then + run_rounds (fn _ => eval_build_range_tree size) rep + else + let + val curr = eval_queries_range_tree (#3 (eval_build_range_tree size)) + in + run_rounds (fn _ => curr query_size) rep + end + +val _ = print ("total " ^ Time.fmt 4 diff ^ "s\n") +val avg = Time.toReal diff / (Real.fromInt rep) +val _ = print ("average " ^ Real.fmt (StringCvt.FIX (SOME 4)) avg ^ "s\n") +*) diff --git a/tests/bench/range-tree/range-tree.mlb b/tests/bench/range-tree/range-tree.mlb new file mode 100644 index 000000000..cb2f76121 --- /dev/null 
+++ b/tests/bench/range-tree/range-tree.mlb @@ -0,0 +1,3 @@ +../../mpllib/sources.$(COMPAT).mlb +RangeTree.sml +main.sml diff --git a/tests/bench/raytracer/main.sml b/tests/bench/raytracer/main.sml new file mode 100644 index 000000000..0e0dba2a4 --- /dev/null +++ b/tests/bench/raytracer/main.sml @@ -0,0 +1,429 @@ +(* Author: Troels Henriksen, https://sigkill.dk/ *) + +(* A ray tracer that fires one ray per pixel and only supports +coloured, reflective spheres. It parallelises two things + + 0. The construction of a BVH for accelerating ray lookups + (divide-and-conquer task parallelism) + + 1. The parallel loop across all of the pixels to be computed (data + parallelism, albeit potentially poorly load balanced) + +*) + +type vec3 = {x: real, y: real, z: real} + +local + fun vf f (v1: vec3) (v2: vec3) = + {x= f (#x v1, #x v2), + y= f (#y v1, #y v2), + z= f (#z v1, #z v2)} +in + +val vec_add = vf (op+) +val vec_sub = vf (op-) +val vec_mul = vf (op* ) +val vec_div = vf (op/) + +fun scale s {x,y,z} = {x=s*x, y=s*y, z=s*z} : vec3 + +fun dot (v1: vec3) (v2: vec3) = + let val v3 = vec_mul v1 v2 + in #x v3 + #y v3 + #z v3 end + +fun norm v = Math.sqrt (dot v v) + +fun normalise v = scale (1.0 / norm v) v + +fun cross {x=x1, y=y1, z=z1} {x=x2, y=y2, z=z2} = + {x=y1*z2-z1*y2, y=z1*x2-x1*z2, z=x1*y2-y1*x2} : vec3 + +end + +type aabb = { min: vec3, max: vec3 } + +fun min x y : real = + if x < y then x else y + +fun max x y : real = + if x < y then y else x + +fun enclosing (box0: aabb) (box1: aabb) = + let val small = { x = min (#x (#min box0)) (#x (#min box1)) + , y = min (#y (#min box0)) (#y (#min box1)) + , z = min (#z (#min box0)) (#z (#min box1)) + } + val big = { x = max (#x (#max box0)) (#x (#max box1)) + , y = max (#y (#max box0)) (#y (#max box1)) + , z = max (#z (#max box0)) (#z (#max box1)) + } + in {min=small, max=big} end + +fun centre (aabb: aabb) = + { x = (#x (#min aabb) + (#x (#max aabb) - #x (#min aabb))), + y = (#y (#min aabb) + (#y (#max aabb) - #y (#min 
aabb))), + z = (#z (#min aabb) + (#z (#max aabb) - #z (#min aabb))) + } + +datatype 'a bvh = bvh_leaf of aabb * 'a + | bvh_split of aabb * 'a bvh * 'a bvh + +fun bvh_aabb (bvh_leaf (box, _)) = box + | bvh_aabb (bvh_split (box, _, _)) = box + +(* Couldn't find a sorting function in MLtons stdlib - this is from Rosetta Code. *) +local + fun merge cmp ([], ys) = ys + | merge cmp (xs, []) = xs + | merge cmp (xs as x::xs', ys as y::ys') = + case cmp (x, y) of + GREATER => y :: merge cmp (xs, ys') + | _ => x :: merge cmp (xs', ys) + fun sort cmp [] = [] + | sort cmp [x] = [x] + | sort cmp xs = + let + val ys = List.take (xs, length xs div 2) + val zs = List.drop (xs, length xs div 2) + in + merge cmp (sort cmp ys, sort cmp zs) + end +in +fun mk_bvh f all_objs = + let fun mk _ _ [] = raise Fail "mk_bvh: no nodes" + | mk _ _ [x] = bvh_leaf(f x, x) + | mk d n xs = + let val axis = case d mod 3 of 0 => #x + | 1 => #y + | _ => #z + fun cmp (x, y) = + Real.compare(axis(centre(f x)), + axis(centre(f y))) + val xs_sorted = sort cmp xs + val xs_left = List.take(xs_sorted, n div 2) + val xs_right = List.drop(xs_sorted, n div 2) + fun do_left () = mk (d+1) (n div 2) xs_left + fun do_right () = mk (d+1) (n-(n div 2)) xs_right + val (left, right) = + if n < 100 + then (do_left(), do_right()) + else ForkJoin.par (do_left, do_right) + val box = enclosing (bvh_aabb left) (bvh_aabb right) + in bvh_split (box, left, right) end + in mk 0 (length all_objs) all_objs end +end + +type pos = vec3 +type dir = vec3 +type colour = vec3 + +val black : vec3 = {x=0.0, y=0.0, z=0.0} +val white : vec3 = {x=1.0, y=1.0, z=1.0} + +type ray = {origin: pos, dir: dir} + +fun point_at_param (ray: ray) t = + vec_add (#origin ray) (scale t (#dir ray)) + +type hit = { t: real + , p: pos + , normal: dir + , colour: colour + } + +type sphere = { pos: pos + , colour: colour + , radius: real + } + +fun sphere_aabb {pos, colour=_, radius} = + {min = vec_sub pos {x=radius, y=radius, z=radius}, + max = vec_add pos 
{x=radius, y=radius, z=radius}} + +fun sphere_hit {pos, colour, radius} r t_min t_max : hit option = + let val oc = vec_sub (#origin r) pos + val a = dot (#dir r) (#dir r) + val b = dot oc (#dir r) + val c = dot oc oc - radius*radius + val discriminant = b*b - a*c + fun try temp = + if temp < t_max andalso temp > t_min + then SOME { t = temp + , p = point_at_param r temp + , normal = scale (1.0/radius) + (vec_sub (point_at_param r temp) pos) + , colour = colour + } + else NONE + in if discriminant <= 0.0 + then NONE + else case try ((~b - Math.sqrt(b*b-a*c))/a) of + SOME hit => SOME hit + | NONE => try ((~b + Math.sqrt(b*b-a*c))/a) + end + +fun aabb_hit aabb ({origin, dir}: ray) tmin0 tmax0 = + let fun iter min' max' origin' dir' tmin' tmax' = + let val invD = 1.0 / dir' + val t0 = (min' - origin') * invD + val t1 = (max' - origin') * invD + val (t0', t1') = if invD < 0.0 then (t1, t0) else (t0, t1) + val tmin'' = max t0' tmin' + val tmax'' = min t1' tmax' + in (tmin'', tmax'') end + val (tmin1, tmax1) = + iter + (#x (#min aabb)) (#x (#max aabb)) + (#x origin) (#x dir) + tmin0 tmax0 + in if tmax1 <= tmin1 then false + else let val (tmin2, tmax2) = + iter (#y (#min aabb)) (#y (#max aabb)) + (#y origin) (#y dir) + tmin1 tmax1 + in if tmax2 <= tmin2 then false + else let val (tmin3, tmax3) = + iter (#z (#min aabb)) (#z (#max aabb)) + (#z origin) (#z dir) + tmin2 tmax2 + in not (tmax3 <= tmin3) end + end + end + +type objs = sphere bvh + +fun objs_hit (bvh_leaf (_, s)) r t_min t_max = + sphere_hit s r t_min t_max + | objs_hit (bvh_split (box, left, right)) r t_min t_max = + if not (aabb_hit box r t_min t_max) + then NONE + else case objs_hit left r t_min t_max of + SOME h => (case objs_hit right r t_min (#t h) of + NONE => SOME h + | SOME h' => SOME h') + | NONE => objs_hit right r t_min t_max + +type camera = { origin: pos + , llc: pos + , horizontal: dir + , vertical: dir + } + +fun camera lookfrom lookat vup vfov aspect = + let val theta = vfov * Math.pi / 180.0 + 
val half_height = Math.tan (theta / 2.0) + val half_width = aspect * half_height + val origin = lookfrom + val w = normalise (vec_sub lookfrom lookat) + val u = normalise (cross vup w) + val v = cross w u + in { origin = lookfrom + , llc = vec_sub + (vec_sub (vec_sub origin (scale half_width u)) + (scale half_height v)) w + , horizontal = scale (2.0*half_width) u + , vertical = scale (2.0*half_height) v + } + end + +fun get_ray (cam: camera) s t : ray= + { origin = #origin cam + , dir = vec_sub (vec_add (vec_add (#llc cam) (scale s (#horizontal cam))) + (scale t (#vertical cam))) + (#origin cam) + } + +fun reflect v n = + vec_sub v (scale (2.0 * dot v n) n) + +fun scatter (r: ray) (hit: hit) = + let val reflected = + reflect (normalise (#dir r)) (#normal hit) + val scattered = {origin = #p hit, dir = reflected} + in if dot (#dir scattered) (#normal hit) > 0.0 + then SOME (scattered, #colour hit) + else NONE + end + +fun ray_colour objs r depth = + case objs_hit objs r 0.001 1000000000.0 of + SOME hit => (case scatter r hit of + SOME (scattered, attenuation) => + if depth < 50 + then vec_mul attenuation (ray_colour objs scattered (depth+1)) + else black + | NONE => black) + | NONE => let val unit_dir = normalise (#dir r) + val t = 0.5 * (#y unit_dir + 1.0) + val bg = {x=0.5, y=0.7, z=1.0} + in vec_add (scale (1.0-t) white) (scale t bg) + end + +fun trace_ray objs width height cam j i : colour = + let val u = real i / real width + val v = real j / real height + val ray = get_ray cam u v + in ray_colour objs ray 0 end + +type pixel = int * int * int + +fun colour_to_pixel {x=r,y=g,z=b} = + let val ir = trunc (255.99 * r) + val ig = trunc (255.99 * g) + val ib = trunc (255.99 * b) + in (ir, ig, ib) end + +type image = { pixels: pixel Array.array + , height: int + , width: int} + +fun image2ppm out ({pixels, height, width}: image) = + let fun onPixel (r,g,b) = + TextIO.output(out, + Int.toString r ^ " " ^ + Int.toString g ^ " " ^ + Int.toString b ^ "\n") + in 
TextIO.output(out, + "P3\n" ^ + Int.toString width ^ " " ^ Int.toString height ^ "\n" ^ + "255\n") + before Array.app onPixel pixels + end + +fun image2ppm6 out ({pixels, height, width}: image) = + let + fun onPixel (r,g,b) = + TextIO.output(out, String.implode (List.map Char.chr [r,g,b])) + in TextIO.output(out, + "P6\n" ^ + Int.toString width ^ " " ^ Int.toString height ^ "\n" ^ + "255\n") + before Array.app onPixel pixels + end + +fun render objs width height cam : image = + let val pixels = ForkJoin.alloc (height*width) + fun pixel l = + let val i = l mod width + val j = height - l div width + in Array.update (pixels, + l, + colour_to_pixel (trace_ray objs width height cam j i)) end + val _ = ForkJoin.parfor 256 (0,height*width) pixel + in {width = width, + height = height, + pixels = pixels + } + end + +type scene = { camLookFrom: pos + , camLookAt: pos + , camFov: real + , spheres: sphere list + } + +fun from_scene width height (scene: scene) : objs * camera = + (mk_bvh sphere_aabb (#spheres scene), + camera (#camLookFrom scene) (#camLookAt scene) {x=0.0, y=1.0, z=0.0} + (#camFov scene) (real width/real height)) + +fun tabulate_2d m n f = + List.concat (List.tabulate (m, fn j => List.tabulate (n, fn i => f (j, i)))) + +val rgbbox : scene = + let val n = 10 + val k = 60.0 + + val leftwall = + tabulate_2d n n (fn (y, z) => + { pos={x=(~k/2.0), + y=(~k/2.0 + (k/real n) * real y), + z=(~k/2.0 + (k/real n) * real z)} + , colour={x=1.0, y=0.0, z=0.0} + , radius = (k/(real n*2.0)) + }) + + val midwall = + tabulate_2d n n (fn (x,y) => + { pos={x=(~k/2.0 + (k/real n) * real x), + y=(~k/2.0 + (k/real n) * real y), + z=(~k/2.0)} + , colour={x=1.0, y=1.0, z=0.0} + , radius = (k/(real n*2.0))}) + + val rightwall = + tabulate_2d n n (fn (y,z) => + { pos={x=(k/2.0), + y=(~k/2.0 + (k/real n) * real y), + z=(~k/2.0 + (k/real n) * real z)} + , colour={x=0.0, y=0.0, z=1.0} + , radius = (k/(real n*2.0)) + }) + + + val bottom = + tabulate_2d n n (fn (x,z) => + { pos={x=(~k/2.0 + 
(k/real n) * real x), + y=(~k/2.0), + z=(~k/2.0 + (k/real n) * real z)} + , colour={x=1.0, y=1.0, z=1.0} + , radius = (k/(real n*2.0)) + }) + + + in { spheres = leftwall @ midwall @ rightwall @ bottom + , camLookFrom = {x=0.0, y=30.0, z=30.0} + , camLookAt = {x=0.0, y= ~1.0, z= ~1.0} + , camFov = 75.0 + } + end + +val irreg : scene = + let val n = 100 + val k = 600.0 + val bottom = + tabulate_2d n n (fn (x,z) => + { pos={x=(~k/2.0 + (k/real n) * real x), + y=0.0, + z=(~k/2.0 + (k/real n) * real z)} + , colour = white + , radius = k/(real n * 2.0) + }) + in { spheres = bottom + , camLookFrom = {x=0.0, y=12.0, z=30.0} + , camLookAt = {x=0.0, y=10.0, z= ~1.0} + , camFov = 75.0 } + end + +structure CLA = CommandLineArgs + +val height = CLA.parseInt "m" 200 +val width = CLA.parseInt "n" 200 +val f = CLA.parseString "f" "" +val dop6 = CLA.parseFlag "ppm6" +val scene_name = CLA.parseString "s" "rgbbox" +val scene = case scene_name of + "rgbbox" => rgbbox + | "irreg" => irreg + | s => raise Fail ("No such scene: " ^ s) +val rep = case (Int.fromString (CLA.parseString "repeat" "1")) of + SOME(a) => a + | NONE => 1 + +val _ = print ("Using scene '" ^ scene_name ^ "' (-s to switch)\n") + +val ((objs, cam), tm1) = Util.getTime(fn _ => from_scene width height scene) +val _ = print ("Scene BVH construction in " ^ Time.fmt 4 tm1 ^ "s\n") + +val result = Benchmark.run "rendering" (fn _ => render objs width height cam) + +val writeImage = if dop6 then image2ppm6 else image2ppm + +val _ = if f <> "" then + let val out = TextIO.openOut f + in print ("Writing image to " ^ f ^ ".\n") + before writeImage out (result) + before TextIO.closeOut out + end + else print ("-f not passed, so not writing image to file.\n") + diff --git a/tests/bench/raytracer/raytracer.mlb b/tests/bench/raytracer/raytracer.mlb new file mode 100644 index 000000000..5f06290fb --- /dev/null +++ b/tests/bench/raytracer/raytracer.mlb @@ -0,0 +1,2 @@ +../../mpllib/sources.$(COMPAT).mlb +main.sml diff --git 
a/tests/bench/reverb/main.sml b/tests/bench/reverb/main.sml new file mode 100644 index 000000000..aad622a52 --- /dev/null +++ b/tests/bench/reverb/main.sml @@ -0,0 +1,26 @@ +structure CLA = CommandLineArgs + + + +val infile = + case CLA.positional () of + [x] => x + | _ => Util.die ("[ERR] usage: reverb INPUT_FILE [-output OUTPUT_FILE]\n") + +val outfile = CLA.parseString "output" "" + +val (snd, tm) = Util.getTime (fn _ => NewWaveIO.readSound infile) +val _ = print ("read sound in " ^ Time.fmt 4 tm ^ "s\n") + +val rsnd = Benchmark.run "reverberating" (fn _ => Signal.reverb snd) + +val _ = + if outfile = "" then + print ("use -output file.wav to hear results\n") + else + let + val (_, tm) = Util.getTime (fn _ => NewWaveIO.writeSound rsnd outfile) + in + print ("wrote output in " ^ Time.fmt 4 tm ^ "s\n") + end + diff --git a/tests/bench/reverb/reverb.mlb b/tests/bench/reverb/reverb.mlb new file mode 100644 index 000000000..5f06290fb --- /dev/null +++ b/tests/bench/reverb/reverb.mlb @@ -0,0 +1,2 @@ +../../mpllib/sources.$(COMPAT).mlb +main.sml diff --git a/tests/bench/samplesort/main.sml b/tests/bench/samplesort/main.sml new file mode 100644 index 000000000..b59ed73a1 --- /dev/null +++ b/tests/bench/samplesort/main.sml @@ -0,0 +1,16 @@ +structure CLA = CommandLineArgs + +val n = CLA.parseInt "N" (100 * 1000 * 1000) +val _ = print ("N " ^ Int.toString n ^ "\n") + +val _ = print ("generating " ^ Int.toString n ^ " random integers\n") + +fun elem i = + Word64.toInt (Word64.mod (Util.hash64 (Word64.fromInt i), Word64.fromInt n)) +val input = ArraySlice.full (SeqBasis.tabulate 10000 (0, n) elem) + +val result = + Benchmark.run "running samplesort" (fn _ => SampleSort.sort Int.compare input) + +val _ = print ("result " ^ Util.summarizeArraySlice 8 Int.toString result ^ "\n") + diff --git a/tests/bench/samplesort/samplesort.mlb b/tests/bench/samplesort/samplesort.mlb new file mode 100644 index 000000000..5f06290fb --- /dev/null +++ b/tests/bench/samplesort/samplesort.mlb @@ 
-0,0 +1,2 @@ +../../mpllib/sources.$(COMPAT).mlb +main.sml diff --git a/tests/bench/seam-carve-index/README b/tests/bench/seam-carve-index/README new file mode 100644 index 000000000..4a6f232cc --- /dev/null +++ b/tests/bench/seam-carve-index/README @@ -0,0 +1,7 @@ +This benchmark computes a seam-carving index, which is the order in which +pixels are removed. This allows us to efficiently generate any intermediate +image. + +So, whereas the "seam-carve" benchmark really only removes one seam +(and this can be iterated to remove many seams), this benchmark is designed +to carve out many seams. diff --git a/tests/bench/seam-carve-index/SCI.sml b/tests/bench/seam-carve-index/SCI.sml new file mode 100644 index 000000000..a0ce09346 --- /dev/null +++ b/tests/bench/seam-carve-index/SCI.sml @@ -0,0 +1,192 @@ +structure SCI: +sig + type image = PPM.image + type seam = int Seq.t + + (* `makeSeamCarveIndex n img` removes `n` seams and returns + * a mapping X that indicates the order in which pixels are removed. + * + * For a pixel at (i,j): + * - if not removed, then X[i*width + j] = -1 + * - otherwise, removed in seam number X[i*width + j] + * + * So, for an image of height H, there will be H pixels that are marked 0 + * and H other pixels that are marked 1, etc. + *) + val makeSeamCarveIndex: int -> image -> int Seq.t +end = +struct + + type image = PPM.image + type seam = int Seq.t + + val blockWidth = CommandLineArgs.parseInt "block-width" 80 + val _ = print ("block-width " ^ Int.toString blockWidth ^ "\n") + val _ = + if blockWidth mod 2 = 0 then () + else Util.die ("block-width must be even!") + + (* This is copied/adapted from ../seam-carve/SC.sml. See that file for + * explanation of the algorithm. 
*) + fun triangularBlockedWriteAllMinSeams width height energy minSeamEnergies = + let + fun M (i, j) = + if j < 0 orelse j >= width then Real.posInf + else Array.sub (minSeamEnergies, i*width + j) + fun setM (i, j) = + let + val x = + if i = 0 then 0.0 + else energy (i, j) + + Real.min (M (i-1, j), Real.min (M (i-1, j-1), M (i-1, j+1))) + in + Array.update (minSeamEnergies, i*width + j, x) + end + + val blockHeight = blockWidth div 2 + val numBlocks = 1 + (width - 1) div blockWidth + + fun upperTriangle i jMid = + Util.for (0, Int.min (height-i, blockHeight)) (fn k => + let + val lo = Int.max (0, jMid-blockHeight+k) + val hi = Int.min (width, jMid+blockHeight-k) + in + Util.for (lo, hi) (fn j => setM (i+k, j)) + end) + + fun lowerTriangle i jMid = + Util.for (0, Int.min (height-i, blockHeight)) (fn k => + let + val lo = Int.max (0, jMid-k-1) + val hi = Int.min (width, jMid+k+1) + in + Util.for (lo, hi) (fn j => setM (i+k, j)) + end) + + fun setStripStartingAt i = + ( ForkJoin.parfor 1 (0, numBlocks) (fn b => + upperTriangle i (b * blockWidth + blockHeight)) + ; ForkJoin.parfor 1 (0, numBlocks+1) (fn b => + lowerTriangle (i+1) (b * blockWidth)) + ) + + fun loop i = + if i >= height then () else + ( setStripStartingAt i + ; loop (i + blockHeight + 1) + ) + in + loop 0 + end + + (* ====================================================================== *) + + fun isolateMinSeam width height M = + let + fun idxMin2 ((j1, m1), (j2, m2)) = + if m1 > m2 then (j2, m2) else (j1, m1) + fun idxMin3 (a, b, c) = idxMin2 (a, idxMin2 (b, c)) + + (* the index of the minimum seam in the last row *) + val (jMin, _) = + SeqBasis.reduce 1000 + idxMin2 (~1, Real.posInf) (0, width) (fn j => (j, M (height-1, j))) + + val seam = ForkJoin.alloc height + + fun computeSeamBackwards (i, j) = + if i = 0 then + Array.update (seam, 0, j) + else + let + val (j', _) = idxMin3 + ( (j, M (i-1, j )) + , (j-1, M (i-1, j-1)) + , (j+1, M (i-1, j+1)) + ) + in + Array.update (seam, i, j); + 
computeSeamBackwards (i-1, j') + end + in + computeSeamBackwards (height-1, jMin); + ArraySlice.full seam + end + + (* ====================================================================== *) + + structure VSIM = VerticalSeamIndexMap + + fun makeSeamCarveIndex numSeamsToRemove image = + let + val N = #width image * #height image + + (* This buffer will be reused throughout *) + val minSeamEnergies = ForkJoin.alloc N + + fun pixel idx (i, j) = PPM.elem image (VSIM.remap idx (i, j)) + + (* =========================================== + * computing the energy of all pixels + * (gradient values) + *) + + fun d p1 p2 = Color.distance (p1, p2) + + fun energy idx (i, j) = + let + val (h, w) = VSIM.domain idx + in + if j = w-1 then Real.posInf + else if i = h - 1 then 0.0 + else let + val p = pixel idx (i, j) + val dx = d p (pixel idx (i, j+1)) + val dy = d p (pixel idx (i+1, j)) + in + Math.sqrt (dx + dy) + end + end + + (* ============================================ + * loop to remove seams + *) + + val X = ForkJoin.alloc N + val _ = ForkJoin.parfor 4000 (0, N) (fn i => Array.update (X, i, ~1)) + fun setX (i, j) x = Array.update (X, i*(#width image) + j, x) + + val idx = VSIM.new (#height image, #width image) + + fun loop numSeamsRemoved = + if numSeamsRemoved >= numSeamsToRemove then () else + let + val currentWidth = #width image - numSeamsRemoved + val _ = triangularBlockedWriteAllMinSeams + currentWidth + (#height image) + (energy idx) + minSeamEnergies (* results written here *) + + fun M (i, j) = + if j < 0 orelse j >= currentWidth then Real.posInf + else Array.sub (minSeamEnergies, i*currentWidth + j) + + val seam = isolateMinSeam currentWidth (#height image) M + in + Seq.foreach seam (fn (i, j) => + setX (VSIM.remap idx (i, j)) numSeamsRemoved); + + VSIM.carve idx seam; + + loop (numSeamsRemoved+1) + end + + in + loop 0; + + ArraySlice.full X + end + +end diff --git a/tests/bench/seam-carve-index/VerticalSeamIndexMap.sml 
b/tests/bench/seam-carve-index/VerticalSeamIndexMap.sml new file mode 100644 index 000000000..2b331c655 --- /dev/null +++ b/tests/bench/seam-carve-index/VerticalSeamIndexMap.sml @@ -0,0 +1,80 @@ +structure VerticalSeamIndexMap :> +sig + type t + type seam = int Seq.t + + (* `new (height, width)` *) + val new: (int * int) -> t + + (* Remaps from some (H, W) into (H', W'). For vertical seams, we always + * have H = H'. This signature could be reused for horizontal seams, though. + *) + val domain: t -> (int * int) + val range: t -> (int * int) + + (* Remap an index (i, j) to some (i', j'), where i is a row index and j + * is a column index. With vertical seams, i = i'. + *) + val remap: t -> (int * int) -> (int * int) + + (* Remove the given seam. + * Causes all (i, j) on the right of the seam to be remapped to (i, j+1). + *) + val carve: t -> seam -> unit +end = +struct + + structure AS = ArraySlice + + type t = {displacement: int Seq.t, domain: (int * int) ref, range: int * int} + type seam = int Seq.t + + fun new (height, width) = + { displacement = AS.full (SeqBasis.tabulate 4000 (0, width*height) (fn _ => 0)) + , domain = ref (height, width) + , range = (height, width) + } + + fun domain ({domain = ref d, ...}: t) = d + fun range ({range=r, ...}: t) = r + + fun remap ({displacement, range=(h, w), ...}: t) (i, j) = + (i, j + Seq.nth displacement (i*w + j)) + + (* fun carve ({displacement, domain=(h, w), range=r}: t) seam = + let + in + { domain = (h, w-1) + , range = r + , displacement = displacement + } + end *) + + fun carve ({displacement, domain=(d as ref (h, w)), range=(_, w0)}: t) seam = + ( d := (h, w-1) + ; ForkJoin.parfor 1 (0, h) (fn i => + let + val s = Seq.nth seam i + in + Util.for (s+1, w) (fn j => + ArraySlice.update (displacement, i*w0 + j - 1, + 1 + Seq.nth displacement (i*w0 + j))) + end) + ) + + (* { domain = (h, w-1) + , range = (h, w0) + , displacement = AS.full (SeqBasis.tabulate 1000 (0, w0*h) (fn k => + let + val i = k div w0 + val j = k 
mod w0 + val s = Seq.nth seam i + in + if j < s then + Seq.nth displacement (i*w0 + j) + else + 1 + Seq.nth displacement (i*w0 + j) + end)) + } *) + +end diff --git a/tests/bench/seam-carve-index/main.sml b/tests/bench/seam-carve-index/main.sml new file mode 100644 index 000000000..d9efb5142 --- /dev/null +++ b/tests/bench/seam-carve-index/main.sml @@ -0,0 +1,122 @@ +structure CLA = CommandLineArgs + +val filename = + case CLA.positional () of + [x] => x + | _ => Util.die "missing filename" + +val numSeams = CLA.parseInt "num-seams" 100 +val _ = print ("num-seams " ^ Int.toString numSeams ^ "\n") + +val (image, tm) = Util.getTime (fn _ => PPM.read filename) +val _ = print ("read image in " ^ Time.fmt 4 tm ^ "s\n") + +val w = #width image +val h = #height image + +val _ = print ("height " ^ Int.toString h ^ "\n") +val _ = print ("width " ^ Int.toString w ^ "\n") + +val _ = + if numSeams >= 0 andalso numSeams <= w then () + else + Util.die ("cannot remove " ^ Int.toString numSeams + ^ " seams from image of width " ^ Int.toString w ^ "\n") + +val X = Benchmark.run "seam carving" + (fn _ => SCI.makeSeamCarveIndex numSeams image) + +val outfile = CLA.parseString "output" "" + +val _ = + if outfile = "" then + print ("use -output XXX.gif to see result\n") + else + let + val ((palette, indices), tm) = Util.getTime (fn _ => + let + val palette = GIF.Palette.summarize [Color.black, Color.red] 128 image + in + (palette, #remap palette image) + end) + + val _ = print ("remapped color palette in " ^ Time.fmt 4 tm ^ "s\n") + fun getIdx (i, j) = Seq.nth indices (i*w + j) + + val redIdx = GIF.Palette.remapColor palette Color.red + val blackIdx = GIF.Palette.remapColor palette Color.black + + fun removeSeams count = + let + val data = ForkJoin.alloc (w * h) + fun set (i, j) x = Array.update (data, i*w + j, x) + + (* compact row i from index j, writing the result at index k *) + fun compactRow i j k = + if j >= w then + Util.for (k, w) (fn kk => set (i, kk) blackIdx) + else + let + 
val xx = Seq.nth X (i*w + j) + in + if xx = ~1 orelse xx > count then + ( set (i, k) (getIdx (i, j)) + ; compactRow i (j+1) (k+1) + ) + else if xx = count then + ( set (i, k) redIdx + ; compactRow i (j+1) (k+1) + ) + else + compactRow i (j+1) k + end + in + ForkJoin.parfor 1 (0, h) (fn i => compactRow i 0 0); + ArraySlice.full data + end + + val (images, tm) = Util.getTime (fn _ => + ArraySlice.full (SeqBasis.tabulate 1 (0, numSeams+1) removeSeams)) + val _ = print ("generated images in " ^ Time.fmt 4 tm ^ "s\n") + + val (_, tm) = Util.getTime (fn _ => + GIF.writeMany outfile 10 palette + { width = w + , height = h + , numImages = numSeams+1 + , getImage = Seq.nth images + }) + in + print ("wrote to " ^ outfile ^ " in " ^ Time.fmt 4 tm ^ "s\n") + end + +(* val _ = + if outfile = "" then + print ("use -output XXX to see result\n") + else + let + fun colorSeam i = + Color.hsv { h = 100.0 * (Real.fromInt i / Real.fromInt numSeams) + , s = 1.0 + , v = 1.0 + } + + val carved = + { width = w + , height = h + , data = ArraySlice.full (SeqBasis.tabulate 4000 (0, w * h) (fn k => + let + val i = k div w + val j = k mod w + in + if Seq.nth X k < 0 then + PPM.elem image (i, j) + else + colorSeam (Seq.nth X k) + end)) + } + val (_, tm) = Util.getTime (fn _ => PPM.write outfile carved) + in + print ("wrote output in " ^ Time.fmt 4 tm ^ "s\n") + end *) + diff --git a/tests/bench/seam-carve-index/seam-carve-index.mlb b/tests/bench/seam-carve-index/seam-carve-index.mlb new file mode 100644 index 000000000..215968970 --- /dev/null +++ b/tests/bench/seam-carve-index/seam-carve-index.mlb @@ -0,0 +1,4 @@ +../../mpllib/sources.$(COMPAT).mlb +VerticalSeamIndexMap.sml +SCI.sml +main.sml diff --git a/tests/bench/seam-carve/SC.sml b/tests/bench/seam-carve/SC.sml new file mode 100644 index 000000000..993f82487 --- /dev/null +++ b/tests/bench/seam-carve/SC.sml @@ -0,0 +1,287 @@ +structure SC: +sig + type image = PPM.image + type seam = int Seq.t + + val minSeam: image -> seam + val paintSeam: 
image -> seam -> PPM.pixel -> image + + val carve: image -> seam -> image + + val removeSeams: int -> image -> image +end = +struct + + structure AS = ArraySlice + + type image = PPM.image + type seam = int Seq.t + + (* ====================================================================== + * seam finding algorithms to solve the equation + * M(i,j) = E(i,j) + min(M(i-1,j), M(i-1,j-1), M(i-1,j+1)) + *) + + (* row-wise bottom-up DP goes by increasing i *) + fun rowWiseAllMinSeams width height (energy: int -> int -> real): real Seq.t = + let + val minSeamEnergies = ForkJoin.alloc (width * height) + fun M i j = + if j < 0 orelse j >= width then Real.posInf + else Array.sub (minSeamEnergies, i*width + j) + fun setM i j = + let + val x = + if i = 0 then 0.0 + else energy i j + + Real.min (M (i-1) j, Real.min (M (i-1) (j-1), M (i-1) (j+1))) + in + Array.update (minSeamEnergies, i*width + j, x) + end + + fun computeMinSeamEnergies i = + if i >= height then () + else ( ForkJoin.parfor 1000 (0, width) (setM i) + ; computeMinSeamEnergies (i+1) + ) + + val _ = computeMinSeamEnergies 0 + in + AS.full minSeamEnergies + end + + (* I tuned this a little bit. For an image approximately 1000 pixels + * wide, this only gives us about 12x possible speedup. But any smaller + * and the grains are too small! *) + val blockWidth = CommandLineArgs.parseInt "block-width" 80 + val _ = print ("block-width " ^ Int.toString blockWidth ^ "\n") + val _ = + if blockWidth mod 2 = 0 then () + else Util.die ("block-width must be even!") + + (* Triangular-blocked bottom-up DP does fancy triangular strategy, to + * improve granularity. + * + * Imaging breaking up the image into a bunch of strips, where each strip + * is then divided into triangles. 
If each triangle is processed sequentially + * row-wise from top to bottom, then we can compute a strip in parallel + * by first doing all of upper triangles (#), and then doing all of the lower + * triangles (.): + * + * +------------------------------------+ + * strip 1 -> |\####/\####/\####/\####/\####/\####/| + * |.\##/..\##/..\##/..\##/..\##/..\##/.| + * |..\/....\/....\/....\/....\/....\/..| + * strip 2 -> |\####/\####/\####/\####/\####/\####/| + * |.\##/..\##/..\##/..\##/..\##/..\##/.| + * |..\/....\/....\/....\/....\/....\/..| + * strip 3 -> | | + *) + fun triangularBlockedAllMinSeams width height energy = + let + val minSeamEnergies = ForkJoin.alloc (width * height) + fun M i j = + if j < 0 orelse j >= width then Real.posInf + else Array.sub (minSeamEnergies, i*width + j) + fun setM i j = + let + val x = + if i = 0 then 0.0 + else energy i j + + Real.min (M (i-1) j, Real.min (M (i-1) (j-1), M (i-1) (j+1))) + in + Array.update (minSeamEnergies, i*width + j, x) + end + + val blockHeight = blockWidth div 2 + val numBlocks = 1 + (width - 1) div blockWidth + + (* Fill in a triangle starting at row i, centered at jMid, with + * the fat end at top and small end at bottom. + * + * For example with blockWidth 6: + * jMid + * | + * i -- X X X X X X + * X X X X + * X X + *) + fun upperTriangle i jMid = + Util.for (0, Int.min (height-i, blockHeight)) (fn k => + let + val lo = Int.max (0, jMid-blockHeight+k) + val hi = Int.min (width, jMid+blockHeight-k) + in + Util.for (lo, hi) (fn j => setM (i+k) j) + end) + + (* The other way around. For example with blockWidth 6: + * jMid + * | + * i -- X X + * X X X X + * X X X X X X + *) + fun lowerTriangle i jMid = + Util.for (0, Int.min (height-i, blockHeight)) (fn k => + let + val lo = Int.max (0, jMid-k-1) + val hi = Int.min (width, jMid+k+1) + in + Util.for (lo, hi) (fn j => setM (i+k) j) + end) + + (* This sets rows [i, i + blockHeight]. + * Note that this includes the row i+blockHeight; i.e. 
the number of + * rows set is blockHeight+1 + *) + fun setStripStartingAt i = + ( ForkJoin.parfor 1 (0, numBlocks) (fn b => + upperTriangle i (b * blockWidth + blockHeight)) + ; ForkJoin.parfor 1 (0, numBlocks+1) (fn b => + lowerTriangle (i+1) (b * blockWidth)) + ) + + fun computeMinSeamEnergies i = + if i >= height then () else + ( setStripStartingAt i + ; computeMinSeamEnergies (i + blockHeight + 1) + ) + + val _ = computeMinSeamEnergies 0 + in + AS.full minSeamEnergies + end + + (* ====================================================================== + * find the min seam + * can choose for the allMinSeams algorithm: + * 1. rowWiseAllMinSeams + * 2. triangularBlockedAllMinSeams + *) + + fun minSeam' allMinSeams image = + let + val height = #height image + val width = #width image + fun pixel i j = PPM.elem image (i, j) + + (* =========================================== + * compute the energy of all pixels + * (gradient values) + *) + + fun c x = Real.fromInt (Word8.toInt x) / 255.0 + fun sq (x: real) = x * x + fun d {red=r1, green=g1, blue=b1} {red=r2, green=g2, blue=b2} = + sq (c r2 - c r1) + sq (c g2 - c g1) + sq (c b2 - c b1) + + fun computeEnergy i j = + if j = width-1 then Real.posInf + else if i = height - 1 then 0.0 + else let + val dx = d (pixel i j) (pixel i (j+1)) + val dy = d (pixel i j) (pixel (i+1) j) + in + Math.sqrt (dx + dy) + end + + val energies = + AS.full (SeqBasis.tabulate 4000 (0, width*height) (fn k => + computeEnergy (k div width) (k mod width))) + fun energy i j = + Seq.nth energies (i*width + j) + + (* =========================================== + * compute the min seam energies + *) + + val MM = allMinSeams width height energy + fun M i j = + if j < 0 orelse j >= width then Real.posInf + else Seq.nth MM (i*width + j) + + (* =========================================== + * isolate the minimum seam + *) + + fun idxMin2 ((j1, m1), (j2, m2)) = + if m1 > m2 then (j2, m2) else (j1, m1) + fun idxMin3 (a, b, c) = idxMin2 (a, idxMin2 (b, c)) + 
+ (* the index of the minimum seam in the last row *) + val (jMin, _) = + SeqBasis.reduce 1000 + idxMin2 (~1, Real.posInf) (0, width) (fn j => (j, M (height-1) j)) + + fun computeSeamBackwards seam (i, j) = + if i = 0 then j::seam else + let + val (j', _) = idxMin3 + ( (j, M (i-1) j ) + , (j-1, M (i-1) (j-1)) + , (j+1, M (i-1) (j+1)) + ) + in + computeSeamBackwards (j::seam) (i-1, j') + end + in + Seq.fromList (computeSeamBackwards [] (height-1, jMin)) + end + + (* val minSeam = minSeam' rowWiseAllMinSeams *) + val minSeam = minSeam' triangularBlockedAllMinSeams + + (* ====================================================================== + * utilities: carving, painting seams, etc. + *) + + fun carve image seam = + let + val height = #height image + val width = #width image + fun pixel i j = PPM.elem image (i, j) + + fun newElem k = + let + val i = k div (width-1) + val j = k mod (width-1) + val r = Seq.nth seam i + in + if j < r then pixel i j else pixel i (j+1) + end + in + { width = width-1 + , height = height + , data = AS.full (SeqBasis.tabulate 4000 (0, (width-1)*height) newElem) + } + end + + fun paintSeam image seam seamColor = + let + val height = #height image + val width = #width image + fun pixel i j = PPM.elem image (i, j) + + fun newElem k = + let + val i = k div width + val j = k mod width + val r = Seq.nth seam i + in + if j = r then seamColor else pixel i j + end + in + { width = width + , height = height + , data = Seq.tabulate newElem (height * width) + } + end + + fun removeSeams n image = + if n = 0 then + image + else + removeSeams (n-1) (carve image (minSeam image)) + +end diff --git a/tests/bench/seam-carve/main.sml b/tests/bench/seam-carve/main.sml new file mode 100644 index 000000000..ff28b8e0e --- /dev/null +++ b/tests/bench/seam-carve/main.sml @@ -0,0 +1,29 @@ +structure CLA = CommandLineArgs + +val filename = + case CLA.positional () of + [x] => x + | _ => Util.die "missing filename" + +val numSeams = CLA.parseInt "num-seams" 100 +val 
_ = print ("num-seams " ^ Int.toString numSeams ^ "\n") + +val (image, tm) = Util.getTime (fn _ => PPM.read filename) +val _ = print ("read image in " ^ Time.fmt 4 tm ^ "s\n") + +val carved = Benchmark.run "seam carving" (fn _ => SC.removeSeams numSeams image) + +val outfile = CLA.parseString "output" "" +val _ = + if outfile = "" then + print ("use -output XXX to see result\n") + else + let + (* val red = {red=0w255, green=0w0, blue=0w0} + val (_, tm) = Util.getTime (fn _ => + PPM.write outfile (SC.paintSeam image seam red)) *) + val (_, tm) = Util.getTime (fn _ => PPM.write outfile carved) + in + print ("wrote output in " ^ Time.fmt 4 tm ^ "s\n") + end + diff --git a/tests/bench/seam-carve/seam-carve.mlb b/tests/bench/seam-carve/seam-carve.mlb new file mode 100644 index 000000000..8f9c5c603 --- /dev/null +++ b/tests/bench/seam-carve/seam-carve.mlb @@ -0,0 +1,3 @@ +../../mpllib/sources.$(COMPAT).mlb +SC.sml +main.sml diff --git a/tests/bench/shuf/main.sml b/tests/bench/shuf/main.sml new file mode 100644 index 000000000..ec9a2ef4b --- /dev/null +++ b/tests/bench/shuf/main.sml @@ -0,0 +1,47 @@ +structure CLA = CommandLineArgs +val seed = CLA.parseInt "seed" 15210 +val outfile = CLA.parseString "o" "" + +val filename = + case CLA.positional () of + [f] => f + | _ => Util.die ("usage: shuf [-o OUTPUT_FILE] [-seed N] INPUT_FILE") + +val (contents, tm) = Util.getTime (fn _ => ReadFile.contentsSeq filename) +val _ = print ("read file in " ^ Time.fmt 4 tm ^ "s\n") + + +fun shuf () = + let + val (lines, tm) = Util.getTime (fn _ => + Tokenize.tokens (fn c => c = #"\n") contents) + val _ = print ("tokenized in " ^ Time.fmt 4 tm ^ "s\n") + + val indices = Seq.tabulate (fn i => i) (Seq.length lines) + val perm = Shuffle.shuffle indices seed + in + Seq.tabulate (fn i => Seq.nth lines (Seq.nth perm i)) (Seq.length lines) + end + +val result = Benchmark.run "shuffle" shuf + +fun dump () = + let + val f = TextIO.openOut outfile + in + Util.for (0, Seq.length result) (fn i => + ( 
TextIO.output (f, Seq.nth result i) + ; TextIO.output1 (f, #"\n") + )); + TextIO.closeOut f + end + +val _ = + if outfile = "" then + print ("no output specified; use -o OUTPUT_FILE to see results\n") + else + let + val ((), tm) = Util.getTime dump + in + print ("wrote to " ^ outfile ^ " in " ^ Time.fmt 4 tm ^ "s\n") + end diff --git a/tests/bench/shuf/shuf.mlb b/tests/bench/shuf/shuf.mlb new file mode 100644 index 000000000..5f06290fb --- /dev/null +++ b/tests/bench/shuf/shuf.mlb @@ -0,0 +1,2 @@ +../../mpllib/sources.$(COMPAT).mlb +main.sml diff --git a/tests/bench/skyline/CityGen.sml b/tests/bench/skyline/CityGen.sml new file mode 100644 index 000000000..d88c2285d --- /dev/null +++ b/tests/bench/skyline/CityGen.sml @@ -0,0 +1,77 @@ +structure CityGen: +sig + + (* (city n x) produces a sequence of n random buildings, seeded by x (any + * integer will do). *) + val city : int -> int -> (int * int * int) Seq.t + + (* (cities m n x) produces m cities, each is a sequence of at most n random + * buildings, seeded by x (any integer will do). 
*) + val cities : int -> int -> int -> (int * int * int) Seq.t Seq.t +end = +struct + + structure R = FastHashRand + + (* Fisher-Yates shuffle aka Knuth shuffle *) + fun shuffle s r = + let + val n = Seq.length s + val data = Array.tabulate (n, Seq.nth s) + + fun swapLoop (r, i) = + if i >= n then r + else let + val j = R.boundedInt (i, n) r + val (x, y) = (Array.sub (data, i), Array.sub (data, j)) + in + Array.update (data, i, y); + Array.update (data, j, x); + swapLoop (R.next r, i+1) + end + + val r' = swapLoop (r, 0) + in + (r', Seq.tabulate (fn i => Array.sub (data, i)) n) + end + + fun citySeeded n r0 = + let + val (r1, xs) = shuffle (Seq.tabulate (fn i => i) (2*n)) r0 + val (_, seeds) = R.splitTab (r1, n) + fun pow b e = if e <= 0 then 1 else b * pow b (e-1) + + fun makeBuilding i = + let + val xpair = (Seq.nth xs (2*i), Seq.nth xs (2*i + 1)) + val lo = Int.min xpair + val hi = Int.max xpair + val width = hi-lo + val maxHeight = Int.max (1, 2*n div width) + val maxHeight = + if maxHeight >= n then + 1 + pow (Util.log2 maxHeight) 2 + else + maxHeight + pow (Util.log2 maxHeight) 2 + val heightRange = (Int.max (1, maxHeight-(n div 100)), maxHeight+1) + val height = R.boundedInt heightRange (seeds i) + in + (lo, height, hi) + end + in + Seq.tabulate makeBuilding n + end + + fun city n x = citySeeded n (R.fromInt x) + + fun cities m n x = + let + val (_, rs) = R.splitTab (R.fromInt x, m) + fun ithCity i = + let val r = rs i + in citySeeded (R.boundedInt (0, n+1) r) (R.next r) + end + in Seq.tabulate ithCity m + end + +end diff --git a/tests/bench/skyline/FastHashRand.sml b/tests/bench/skyline/FastHashRand.sml new file mode 100644 index 000000000..205a7737d --- /dev/null +++ b/tests/bench/skyline/FastHashRand.sml @@ -0,0 +1,66 @@ +(* MUCH faster random number generation than DotMix. + * I wonder how good its randomness is? 
*) +structure FastHashRand = +struct + type rand = Word64.word + + val maxWord = 0wxFFFFFFFFFFFFFFFF : Word64.word + + exception FastHashRand + + fun hashWord w = + let + open Word64 + infix 2 >> infix 2 << infix 2 xorb infix 2 andb + val v = w * 0w3935559000370003845 + 0w2691343689449507681 + val v = v xorb (v >> 0w21) + val v = v xorb (v << 0w37) + val v = v xorb (v >> 0w4) + val v = v * 0w4768777513237032717 + val v = v xorb (v << 0w20) + val v = v xorb (v >> 0w41) + val v = v xorb (v << 0w5) + in + v + end + + fun fromInt x = hashWord (Word64.fromInt x) + + fun next r = hashWord r + + fun split r = (hashWord r, (hashWord (r+0w1), hashWord (r+0w2))) + + fun biasedBool (h, t) r = + let + val scaleFactor = Word64.div (maxWord, Word64.fromInt (h+t)) + in + Word64.<= (r, Word64.* (Word64.fromInt h, scaleFactor)) + end + + fun split3 _ = raise FastHashRand + fun splitTab (r, n) = + (hashWord r, fn i => hashWord (r + Word64.fromInt (i+1))) + + val intp = + case Int.precision of + SOME n => n + | NONE => (print "[ERR] int precision\n"; OS.Process.exit OS.Process.failure) + + val mask = Word64.<< (0w1, Word.fromInt (intp-1)) + + fun int r = + Word64.toIntX (Word64.andb (r, mask) - 0w1) + + fun int r = + Word64.toIntX (Word64.>> (r, Word.fromInt (64-intp+1))) + + fun boundedInt (a, b) r = a + ((int r) mod (b-a)) + + fun bool _ = raise FastHashRand + + fun biasedInt _ _ = raise FastHashRand + fun real _ = raise FastHashRand + fun boundedReal _ _ = raise FastHashRand + fun char _ = raise FastHashRand + fun boundedChar _ _ = raise FastHashRand +end diff --git a/tests/bench/skyline/Skyline.sml b/tests/bench/skyline/Skyline.sml new file mode 100644 index 000000000..d871c1382 --- /dev/null +++ b/tests/bench/skyline/Skyline.sml @@ -0,0 +1,60 @@ +structure Skyline = +struct + type 'a seq = 'a Seq.t + type skyline = (int * int) Seq.t + + fun singleton (l, h, r) = Seq.fromList [(l, h), (r, 0)] + + fun combine (sky1, sky2) = + let + val lMarked = Seq.map (fn (x, y) => (x, SOME y, 
NONE)) sky1 + val rMarked = Seq.map (fn (x, y) => (x, NONE, SOME y)) sky2 + + fun cmp ((x1, _, _), (x2, _, _)) = Int.compare (x1, x2) + val merged = Merge.merge cmp (lMarked, rMarked) + + fun copy (a, b) = case b of SOME _ => b | NONE => a + fun copyFused ((x1, yl1, yr1), (x2, yl2, yr2)) = + (x2, copy (yl1, yl2), copy (yr1, yr2)) + + val allHeights = Seq.scanIncl copyFused (0,NONE,NONE) merged + + fun squish (x, y1, y2) = + (x, Int.max (Option.getOpt (y1, 0), Option.getOpt (y2, 0))) + val sky = Seq.map squish allHeights + + (*fun isUnique (i, (x, h)) = + i = 0 orelse let val (_, prevh) = Seq.nth sky (i-1) in h <> prevh end*) + (*val sky = Seq.filterIdx isUnique sky*) + in + sky + end + + fun skyline g bs = + let + fun skyline' bs = + case Seq.length bs of + 0 => Seq.empty () + | 1 => singleton (Seq.nth bs 0) + | n => + let + val half = n div 2 + val sfL = fn _ => skyline' (Seq.take bs half) + val sfR = fn _ => skyline' (Seq.drop bs half) + in + if Seq.length bs <= g then + combine (sfL (), sfR ()) + else + combine (ForkJoin.par (sfL, sfR)) + end + + val sky = skyline' bs + + fun isUnique (i, (x, h)) = + i = 0 orelse let val (_, prevh) = Seq.nth sky (i-1) in h <> prevh end + val sky = Seq.filterIdx isUnique sky + in + sky + end + +end diff --git a/tests/bench/skyline/main.sml b/tests/bench/skyline/main.sml new file mode 100644 index 000000000..13810c918 --- /dev/null +++ b/tests/bench/skyline/main.sml @@ -0,0 +1,102 @@ +structure CLA = CommandLineArgs +structure Gen = CityGen + +(* +functor S (Sky : SKYLINE where type skyline = (int * int) Seq.t) = +struct + open Sky + fun skyline bs = + case Seq.splitMid bs of + Seq.EMPTY => Seq.empty () + | Seq.ONE b => singleton b + | Seq.PAIR (l, r) => + let + fun sl _ = skyline l + fun sr _ = skyline r + val (l', r') = + if Seq.length bs <= 1000 + then (sl (), sr ()) + else Primitives.par (sl, sr) + in + combine (l', r') + end +end + +structure Stu = S (MkSkyline (structure Seq = Seq)) +structure Ref = S (MkRefSkyline 
(structure Seq = Seq)) +*) + +fun pairEq ((x1, y1), (x2, y2)) = (x1 = x2 andalso y1 = y2) + +fun skylinesEq (s1, s2) = + Seq.length s1 = Seq.length s2 andalso + Seq.reduce (fn (a,b) => a andalso b) true + (Seq.tabulate (fn i => pairEq (Seq.nth s1 i, Seq.nth s2 i)) (Seq.length s1)) + +val size = CLA.parseInt "size" 1000000 +val seed = CLA.parseInt "seed" 15210 +val grain = CLA.parseInt "grain" 1000 +val output = CLA.parseString "output" "" + +(* ensure newline at end of string *) +fun println s = + let + val needsNewline = + String.size s = 0 orelse String.sub (s, String.size s - 1) <> #"\n" + in + print (if needsNewline then s ^ "\n" else s) + end + +val _ = println ("size " ^ Int.toString size) +val _ = println ("seed " ^ Int.toString seed) +val _ = println ("grain " ^ Int.toString grain) + +val (input, tm) = Util.getTime (fn _ => Gen.city size seed) +val _ = println ("generated input in " ^ Time.fmt 4 tm ^ "s\n") + +val sky = Benchmark.run "skyline" (fn _ => Skyline.skyline grain input) +val _ = print ("result-len " ^ Int.toString (Seq.length sky) ^ "\n") + +val _ = + if output = "" then + print ("use -output XXX.ppm to see result\n") + else + let + val (xMin, _) = Seq.nth sky 0 + val (xMax, _) = Seq.nth sky (Seq.length sky - 1) + val yMax = Seq.reduce Int.max 0 (Seq.map (fn (_,y) => y) sky) + val _ = print ("xMin " ^ Int.toString xMin ^ "\n") + val _ = print ("xMax " ^ Int.toString xMax ^ "\n") + val _ = print ("yMax " ^ Int.toString yMax ^ "\n") + + val width = 1000 + val height = 250 + + val padding = 20 + + fun col x = + padding + width * (x - xMin) div (1 + xMax - xMin) + fun row y = + padding + height - 1 - (height * y div (1 + yMax)) + + val width' = 2*padding + width + val height' = padding + height + val image = Seq.tabulate (fn _ => Color.white) (width' * height') + + val _ = Seq.foreach sky (fn (idx, (x, y)) => + if idx >= Seq.length sky - 1 then () else + let + val (x', _) = Seq.nth sky (idx+1) + + val ihi = row y + val jlo = col x + val jhi = Int.max 
(col x + 1, col x') + in + Util.for (ihi, height') (fn i => + Util.for (jlo, jhi) (fn j => + ArraySlice.update (image, i*width' + j, Color.black))) + end) + in + PPM.write output {width=width', height=height', data=image}; + print ("wrote output to " ^ output ^ "\n") + end diff --git a/tests/bench/skyline/skyline.mlb b/tests/bench/skyline/skyline.mlb new file mode 100644 index 000000000..20d3615dd --- /dev/null +++ b/tests/bench/skyline/skyline.mlb @@ -0,0 +1,6 @@ +../../mpllib/sources.$(COMPAT).mlb +FastHashRand.sml +CityGen.sml +Skyline.sml +main.sml + diff --git a/tests/bench/sparse-mxv-opt/SparseMxV.sml b/tests/bench/sparse-mxv-opt/SparseMxV.sml new file mode 100644 index 000000000..c6b277163 --- /dev/null +++ b/tests/bench/sparse-mxv-opt/SparseMxV.sml @@ -0,0 +1,14 @@ +structure SparseMxV = +struct + + fun sparseMxV (mat: (int * real) Seq.t Seq.t) (vec: real Seq.t) = + let + fun f (i,x) = (Seq.nth vec i) * x + fun rowSum r = + SeqBasis.reduce 5000 op+ 0.0 (0, Seq.length r) (fn i => f(Seq.nth r i)) + in + ArraySlice.full (SeqBasis.tabulate 100 (0, Seq.length mat) (fn i => + rowSum (Seq.nth mat i))) + end + +end diff --git a/tests/bench/sparse-mxv-opt/main.sml b/tests/bench/sparse-mxv-opt/main.sml new file mode 100644 index 000000000..64e971937 --- /dev/null +++ b/tests/bench/sparse-mxv-opt/main.sml @@ -0,0 +1,37 @@ +structure CLA = CommandLineArgs +structure Seq = ArraySequence +structure DS = DelayedSeq + +structure M = SparseMxV + +val n = CLA.parseInt "n" (1000 * 1000 * 100) +val doCheck = CLA.parseFlag "check" + +val rowLen = 100 +val numRows = n div rowLen +val vec = Seq.tabulate (fn i => 1.0) numRows +fun gen i j = + ((Util.hash (i * rowLen + j) mod numRows), 1.0) +val mat = Seq.tabulate (fn i => Seq.tabulate (gen i) rowLen) numRows + +fun task () = + M.sparseMxV mat vec + +fun check result = + if not doCheck then () else + let + fun closeEnough (a, b) = Real.< (Real.abs (a - b), 0.000001) + val correct = + DS.reduce (fn (a, b) => a andalso b) true + 
(DS.tabulate + (fn i => closeEnough (Seq.nth result i, Real.fromInt rowLen)) + numRows) + in + if correct then + print ("correct? yes\n") + else + print ("correct? no\n") + end + +val result = Benchmark.run "sparse-mxv" task +val _ = check result diff --git a/tests/bench/sparse-mxv-opt/sparse-mxv-opt.mlb b/tests/bench/sparse-mxv-opt/sparse-mxv-opt.mlb new file mode 100644 index 000000000..551a98a1c --- /dev/null +++ b/tests/bench/sparse-mxv-opt/sparse-mxv-opt.mlb @@ -0,0 +1,3 @@ +../../mpllib/sources.$(COMPAT).mlb +SparseMxV.sml +main.sml diff --git a/tests/bench/sparse-mxv/MkMXV.sml b/tests/bench/sparse-mxv/MkMXV.sml new file mode 100644 index 000000000..d69d4af94 --- /dev/null +++ b/tests/bench/sparse-mxv/MkMXV.sml @@ -0,0 +1,16 @@ +functor MkMXV (Seq : SEQUENCE) = +struct + + structure ASeq = ArraySequence + type 'a seq = 'a ASeq.t + + fun sparseMxV (mat : (int * real) seq seq) (vec : real seq) = + let + fun f (i,x) = (ASeq.nth vec i) * x + fun rowSum r = + Seq.reduce op+ 0.0 (Seq.map f (Seq.fromArraySeq r)) + in + Seq.toArraySeq (Seq.map rowSum (Seq.fromArraySeq mat)) + end + +end diff --git a/tests/bench/sparse-mxv/main.sml b/tests/bench/sparse-mxv/main.sml new file mode 100644 index 000000000..2c2c514b4 --- /dev/null +++ b/tests/bench/sparse-mxv/main.sml @@ -0,0 +1,37 @@ +structure CLA = CommandLineArgs +structure Seq = ArraySequence +structure DS = DelayedSeq + +structure M = MkMXV (DelayedSeq) + +val n = CLA.parseInt "n" (1000 * 1000 * 100) +val doCheck = CLA.parseFlag "check" + +val rowLen = 100 +val numRows = n div rowLen +val vec = Seq.tabulate (fn i => 1.0) numRows +fun gen i j = + ((Util.hash (i * rowLen + j) mod numRows), 1.0) +val mat = Seq.tabulate (fn i => Seq.tabulate (gen i) rowLen) numRows + +fun task () = + M.sparseMxV mat vec + +fun check result = + if not doCheck then () else + let + fun closeEnough (a, b) = Real.< (Real.abs (a - b), 0.000001) + val correct = + DS.reduce (fn (a, b) => a andalso b) true + (DS.tabulate + (fn i => closeEnough 
(Seq.nth result i, Real.fromInt rowLen)) + numRows) + in + if correct then + print ("correct? yes\n") + else + print ("correct? no\n") + end + +val result = Benchmark.run "sparse-mxv" task +val _ = check result diff --git a/tests/bench/sparse-mxv/sparse-mxv.mlb b/tests/bench/sparse-mxv/sparse-mxv.mlb new file mode 100644 index 000000000..ed7e69797 --- /dev/null +++ b/tests/bench/sparse-mxv/sparse-mxv.mlb @@ -0,0 +1,3 @@ +../../mpllib/sources.$(COMPAT).mlb +MkMXV.sml +main.sml diff --git a/tests/bench/subset-sum/SubsetSumTiled.sml b/tests/bench/subset-sum/SubsetSumTiled.sml new file mode 100644 index 000000000..199e6a2d3 --- /dev/null +++ b/tests/bench/subset-sum/SubsetSumTiled.sml @@ -0,0 +1,100 @@ +structure SubsetSumTiled: +sig + val subset_sum: {unsafe_skip_table_set: bool} + -> int Seq.t * int + -> int Seq.t option +end = +struct + + structure Table: + sig + type t + val new: {unsafe_skip_table_set: bool} -> int * int -> t + val set: t -> int * int -> bool -> unit + val get: t -> int * int -> bool + end = + struct + datatype t = T of {num_rows: int, num_cols: int, data: Word8.word array} + + fun new {unsafe_skip_table_set} (num_rows, num_cols) = + let + val data = ForkJoin.alloc (num_rows * num_cols) + in + if unsafe_skip_table_set then + () + else + ForkJoin.parfor 1000 (0, num_rows * num_cols) (fn i => + Array.update (data, i, 0w0 : Word8.word)); + T {num_rows = num_rows, num_cols = num_cols, data = data} + end + + fun set (T {num_rows, num_cols, data}) (r, c) b = + Array.update (data, r * num_cols + c, if b then 0w1 else 0w0) + + fun get (T {num_rows, num_cols, data}) (r, c) = + Array.sub (data, r * num_cols + c) = 0w1 + end + + + fun subset_sum {unsafe_skip_table_set} (bag: int Seq.t, goal: int) : + int Seq.t option = + let + val n = Seq.length bag + + val table = + Table.new {unsafe_skip_table_set = unsafe_skip_table_set} + (1 + n, 1 + goal) + fun get (r, c) = Table.get table (r, c) + fun set (r, c) b = + Table.set table (r, c) b + + fun do_node (i, j) = + 
if j = 0 then set (i, j) true + else if i >= n then set (i, j) false + else if Seq.nth bag i > j then set (i, j) (get (i + 1, j)) + else set (i, j) (get (i + 1, j) orelse get (i + 1, j - Seq.nth bag i)) + + fun do_tile (i_lo, i_hi, j_lo, j_hi) = + let + val i_sz = i_hi - i_lo + val j_sz = j_hi - j_lo + in + if i_sz * j_sz <= 1000 then + Util.forBackwards (i_lo, i_hi) (fn i => + Util.for (j_lo, j_hi) (fn j => do_node (i, j))) + else if i_sz = 1 then + ForkJoin.parfor 1000 (j_lo, j_hi) (fn j => do_node (i_lo, j)) + else if j_sz = 1 then + (* no parallelism is possible within a single column *) + Util.forBackwards (i_lo, i_hi) (fn i => do_node (i, j_lo)) + else + let + val i_mid = i_lo + i_sz div 2 + val j_mid = j_lo + j_sz div 2 + in + do_tile (i_mid, i_hi, j_lo, j_mid); + ForkJoin.par + ( fn () => do_tile (i_lo, i_mid, j_lo, j_mid) + , fn () => do_tile (i_mid, i_hi, j_mid, j_hi) + ); + do_tile (i_lo, i_mid, j_mid, j_hi) + end + end + + fun reconstruct_path acc (i, j) = + if j = 0 then + Seq.fromRevList acc + else + let + val x = Seq.nth bag i + in + if get (i + 1, j) then reconstruct_path acc (i + 1, j) + else reconstruct_path (x :: acc) (i + 1, j - x) + end + in + do_tile (0, n + 1, 0, goal + 1); + + if get (0, goal) then SOME (reconstruct_path [] (0, goal)) else NONE + end + +end diff --git a/tests/bench/subset-sum/main.sml b/tests/bench/subset-sum/main.sml new file mode 100644 index 000000000..e0dc55038 --- /dev/null +++ b/tests/bench/subset-sum/main.sml @@ -0,0 +1,36 @@ +structure CLA = CommandLineArgs + +val bag_str = CLA.parseString "bag" "3,2,3,1,2,1,5,10,10000000,30,10000" +val goal = CLA.parseInt "goal" 10010021 +val unsafe_skip_table_set = CLA.parseFlag "unsafe_skip_table_set" + +val bag = + Seq.fromList (List.map (valOf o Int.fromString) + (String.tokens (fn c => c = #",") bag_str)) + handle _ => Util.die ("parsing -bag ... 
failed") + +val _ = + if Util.all (0, Seq.length bag) (fn i => Seq.nth bag i > 0) then () + else Util.die ("bag elements must be all >0") + +val _ = if goal >= 0 then () else Util.die ("goal must be >=0") + +val bag_str = let val s = Seq.toString Int.toString bag + in String.substring (s, 1, String.size s - 2) + end +val _ = print ("bag " ^ bag_str ^ "\n") +val _ = print ("goal " ^ Int.toString goal ^ "\n") +val _ = print + ("unsafe_skip_table_set? " ^ (if unsafe_skip_table_set then "yes" else "no") + ^ "\n") + +val result = Benchmark.run "subset-sum" (fn () => + SubsetSumTiled.subset_sum {unsafe_skip_table_set = unsafe_skip_table_set} + (bag, goal)) + +val out_str = + case result of + NONE => "NONE" + | SOME x => "SOME " ^ Seq.toString Int.toString x + +val _ = print ("result " ^ out_str ^ "\n") diff --git a/tests/bench/subset-sum/subset-sum.mlb b/tests/bench/subset-sum/subset-sum.mlb new file mode 100644 index 000000000..cba3e06a8 --- /dev/null +++ b/tests/bench/subset-sum/subset-sum.mlb @@ -0,0 +1,3 @@ +../../mpllib/sources.$(COMPAT).mlb +SubsetSumTiled.sml +main.sml \ No newline at end of file diff --git a/tests/bench/suffix-array/AS.sml b/tests/bench/suffix-array/AS.sml new file mode 100644 index 000000000..b3a4705ce --- /dev/null +++ b/tests/bench/suffix-array/AS.sml @@ -0,0 +1,8 @@ +structure AS = +struct + open ArraySlice + open Seq + + val GRAIN = 4096 + val ASupdate = ArraySlice.update +end diff --git a/tests/bench/suffix-array/BruteForce.sml b/tests/bench/suffix-array/BruteForce.sml new file mode 100644 index 000000000..6352310b5 --- /dev/null +++ b/tests/bench/suffix-array/BruteForce.sml @@ -0,0 +1,29 @@ +(* Author: Lawrence Wang (lawrenc2@andrew.cmu.edu, github.com/larry98) + *) + +structure BruteForceSuffixArray :> +sig + val makeSuffixArray : string -> int Seq.t +end = +struct + + fun makeSuffixArray str = + let + val n = String.size str + val sa = AS.tabulate (fn i => i) n + fun cmp k (i, j) = + if i = j then EQUAL + else if i + k >= n then LESS + 
else if j + k >= n then GREATER + else + let + val c1 = String.sub (str, i + k) + val c2 = String.sub (str, j + k) + in + Char.compare (c1, c2) + end + in + RadixSort.quicksort sa cmp n + end + +end diff --git a/tests/bench/suffix-array/PrefixDoubling.sml b/tests/bench/suffix-array/PrefixDoubling.sml new file mode 100644 index 000000000..686f28606 --- /dev/null +++ b/tests/bench/suffix-array/PrefixDoubling.sml @@ -0,0 +1,256 @@ +(* Author: Lawrence Wang (lawrenc2@andrew.cmu.edu, github.com/larry98) + * + * Based on the prefix doubling approach used by Manber & Myers and Larsson & + * Sadakane. The general idea is to repeatedly sort the suffixes by their first + * k characters for k = 1, 2, 4, 8, ... This can be done efficiently by keeping + * track of the "ranks" (or groups, buckets, inverse suffix array, etc) of each + * suffix on each round. + * + * Our algorithm maintains sets of "active groups", which are groups of + * suffixes whose k-prefixes are equal. Each of the active groups is processed + * in parallel, and the processing consists of sorting the suffixes in each + * group by their 2k-prefixes, which may create more active groups for the + * next round. We only maintain active groups of size greater than 1, which + * is an optimization made in Larsson & Sadakane. This effectively skips over + * the parts of the suffix array which are already in their sorted positions. + * + * Larsson & Sadakane also perform additional optimizations, such as using a + * modified version of 3-way pivot quicksort that updates the group numbers + * as part of the sorting routine (essentially it assigns the group number + * when processing the EQUAL partition). We do not implement this optimization + * as our algorithm uses DualPivotQuickSort. 
+ *) +structure PrefixDoublingSuffixArray :> +sig + val makeSuffixArray : string -> int Seq.t +end = +struct + + val GRAIN = 10000 + + structure AS = + struct + open ArraySlice + open Seq + val ASupdate = ArraySlice.update + + fun filter p s = + full (SeqBasis.filter GRAIN (0, length s) (nth s) (p o nth s)) + + fun scanIncl f b s = + let + val a = SeqBasis.scan GRAIN f b (0, length s) (nth s) + in + slice (a, 1, NONE) + end + end + + fun appG grain f s = + ForkJoin.parfor grain (0, AS.length s) (fn i => f (AS.nth s i)) + + fun initCountingSort str = + let + val n = String.size str + + fun bucket i = Char.ord (String.sub (str, i)) + val (sa, offsets) = + CountingSort.sort (AS.tabulate (fn i => i) n) bucket 256 + val offsets = + AS.full (SeqBasis.filter GRAIN (0, AS.length offsets) (AS.nth offsets) + (fn i => i = 0 orelse AS.nth offsets i <> AS.nth offsets (i-1))) + val numGroups = AS.length offsets + val groupNum = ArraySlice.full (ForkJoin.alloc n) + + fun makeGroup offsets n i = + let + val off = AS.nth offsets i + val len = if i = AS.length offsets - 1 then n - off + else AS.nth offsets (i + 1) - off + in + (off, len) + end + + fun updateGroupNum (off, len) = + ForkJoin.parfor GRAIN (off, off + len) (fn j => + AS.ASupdate (groupNum, AS.nth sa j, off) + ) + + val groups = AS.tabulate (makeGroup offsets n) numGroups + val () = appG 1 updateGroupNum groups + in + (sa, groupNum, groups, 1) + end + + fun initSampleSort str = + let + val n = String.size str + + fun pack i = + let + val charToWord = Word64.fromInt o Char.ord + fun getChar i = if i >= n then Char.minChar else String.sub (str, i) + val orb = Word64.orb + val << = Word64.<< + infix 2 << infix 2 orb + val v = charToWord (String.sub (str, i)) << 0w56 + val v = v orb (charToWord (getChar (i + 1)) << 0w48) + val v = v orb (charToWord (getChar (i + 2)) << 0w40) + val v = v orb (charToWord (getChar (i + 3)) << 0w32) + in + v orb (Word64.fromInt i) + end + val words = SampleSort.sort Word64.compare (AS.tabulate 
pack n) + val idxMask = 0w4294967295 + + val sa = AS.map (fn w => Word64.toInt (Word64.andb (w, idxMask))) words + val groupNum = ArraySlice.full (ForkJoin.alloc n) + + fun eq w1 w2 = Word64.>> (w1, 0w32) = Word64.>> (w2, 0w32) + fun f i = + if i > 0 andalso eq (AS.nth words i) (AS.nth words (i - 1)) + then 0 + else i + val offsets = AS.scanIncl Int.max 0 (AS.tabulate f n) + val maxOffset = AS.nth offsets (n - 1) + val groups = AS.tabulate (fn i => (0, 0)) n + val () = ForkJoin.parfor GRAIN (1, n) (fn i => + let + val off1 = AS.nth offsets (i - 1) + val off2 = AS.nth offsets i + in + AS.ASupdate (groupNum, AS.nth sa i, off2); + if off1 = off2 then () + else AS.ASupdate (groups, i - 1, (off1, off2 - off1)) + end + ) + val () = AS.ASupdate (groupNum, AS.nth sa 0, AS.nth offsets 0) + val () = AS.ASupdate (groups, n - 1, (maxOffset, n - maxOffset)) + in + (sa, groupNum, groups, 4) + end + + fun makeSuffixArray str = + let + val n = String.size str + val (sa, groupNum, groups, k) = initSampleSort str + + fun isActive (off, len) = len > 1 + val activeGroups = AS.filter isActive groups + + val ranks = ArraySlice.full (ForkJoin.alloc n) + val aux = ArraySlice.full (ForkJoin.alloc n) + + fun loop activeGroups k = + if AS.length activeGroups = 0 then () + else if k > n then () + else + let + fun cmp (i, j) = + let + val x = AS.nth ranks i + val y = AS.nth ranks j + in + if x = ~1 andalso y = ~1 then EQUAL + else if x = ~1 then LESS + else if y = ~1 then GREATER + else Int.compare (x, y) + end + + fun sortGroup group = + Quicksort.sortInPlaceG n cmp (AS.subseq sa group) + + fun expandGroupSeq s (off, len) = + let + fun loop i numGroups start = + if i = off + len then ( + AS.ASupdate (s, numGroups, (start, i - start)); + numGroups + 1 + ) else if cmp (AS.nth sa i, AS.nth sa (i - 1)) <> EQUAL then ( + AS.ASupdate (groupNum, AS.nth sa i, i); + AS.ASupdate (s, numGroups, (start, i - start)); + loop (i + 1) (numGroups + 1) i + ) else ( + AS.ASupdate (groupNum, AS.nth sa i, start); 
+ loop (i + 1) numGroups start + ) + val numGroups = loop (off + 1) 0 off + val () = AS.ASupdate (groupNum, AS.nth sa off, off) + in + Util.for (numGroups, len) (fn i => + AS.ASupdate (s, i, (0, 0)) + ) + end + + fun expandGroupPar s (off, len) = + let + fun f i = + if i = 0 then off + else + case cmp (AS.nth sa (off + i), AS.nth sa (off + i - 1)) of + EQUAL => 0 + | _ => off + i + val names = AS.scanIncl Int.max 0 (AS.tabulate f len) + val maxName = AS.nth names (len - 1) + in + ( + ForkJoin.parfor GRAIN (1, len) (fn i => ( + let + val name = AS.nth names i + val name' = AS.nth names (i - 1) + in + AS.ASupdate (groupNum, AS.nth sa (off + i), name); + if AS.nth names i = off + i andalso i > 0 then + AS.ASupdate (s, i - 1, (name', off + i - name')) + else AS.ASupdate (s, i - 1, (0, 0)) + end + )); + AS.ASupdate (groupNum, AS.nth sa off, AS.nth names 0); + AS.ASupdate (s, len - 1, (maxName, off + len - maxName)) + ) + end + + val seqExpand = + AS.length activeGroups > Concurrency.numberOfProcessors + fun expandGroup s (off, len) = + if len <= GRAIN orelse seqExpand then expandGroupSeq s (off, len) + else expandGroupPar s (off, len) + + val groupLens = AS.map #2 activeGroups + val (groupStarts, maxGroups) = AS.scan (op +) 0 groupLens + val avgLen = maxGroups div (AS.length activeGroups) + (* TODO: tune grain size *) + val grain = if avgLen >= 256 then 1 + else if avgLen >= 64 then 32 + else if avgLen >= 16 then 64 + else 4096 + + val () = print ("grain is " ^ (Int.toString grain) ^ "\n") + val () = print ("avgLen is " ^ (Int.toString avgLen) ^ "\n") + val () = print ("numGroups is " ^ (Int.toString (AS.length activeGroups)) ^ "\n") + + (* Its faster to copy all of the ranks instead of just those in + active groups *) + val () = ForkJoin.parfor GRAIN (0, n) (fn i => + let val x = if i + k >= n then ~1 else AS.nth groupNum (i + k) + in AS.ASupdate (ranks, i, x) end + ) + + val newGroups = AS.take aux maxGroups + val () = ForkJoin.parfor grain (0, AS.length activeGroups) 
(fn i => + let + val group = AS.nth activeGroups i + val start = AS.nth groupStarts i + val s = AS.subseq newGroups (start, #2 group) + in + (sortGroup group; expandGroup s group) + end + ) + in + loop (AS.filter isActive newGroups) (2 * k) + end + val () = loop activeGroups k + in + sa + end + +end diff --git a/tests/bench/suffix-array/main.sml b/tests/bench/suffix-array/main.sml new file mode 100644 index 000000000..17e17ab4d --- /dev/null +++ b/tests/bench/suffix-array/main.sml @@ -0,0 +1,104 @@ +structure CLA = CommandLineArgs +(* structure Seq = ArraySequence *) + +val str = CLA.parseString "str" "" +(* val algo = CLA.parseString "algo" "" *) +val check = CLA.parseFlag "check" +val benchmark = CLA.parseFlag "benchmark" +val benchSize = CLA.parseInt "n" 10000000 +val printResult = CLA.parseFlag "print" +val filename = CLA.parseString "file" "" +val rep = case (Int.fromString (CLA.parseString "repeat" "1")) of + SOME(a) => a + | NONE => 1 + +fun load filename = + ReadFile.contents filename + (* let val str = Util.readFile filename + in CharVector.tabulate (Array.length str, fn i => Array.sub (str, i)) end *) + +val str = if filename <> "" then load filename else str + +val maker = + PrefixDoublingSuffixArray.makeSuffixArray + (*if algo = "DC3" then DC3SuffixArray.makeSuffixArray + else if algo = "PD" then PrefixDoublingSuffixArray.makeSuffixArray + else if algo = "BF" then BruteForceSuffixArray.makeSuffixArray + else Util.exit "Unknown algorithm" *) + +(* val _ = MLton.Rusage.measureGC true *) + +(* fun totalGCTime () = + let + val n = Primitives.numberOfProcessors + val time = ref Time.zeroTime + val () = Primitives.for (0, n) (fn i => + (time := Time.+ (!time, Primitives.localGCTimeOfProc i); + time := Time.+ (!time, Primitives.promoTimeOfProc i)) + ) + in + !time + end *) + +fun runTrial str = + let + (* val _ = MLton.GC.collect () *) + (* val gcTime0 = totalGCTime () *) + val t0 = Time.now () + val result = maker str + val t1 = Time.now () + (* val gcTime1 = 
totalGCTime () *) + val elapsed = Time.toMilliseconds (Time.- (t1, t0)) + (* val gcTimeTotal = Time.toMilliseconds (Time.- (gcTime1, gcTime0)) *) + val gcTimeTotal = 0 + val () = print ("GC: " ^ LargeInt.toString gcTimeTotal ^ " ms\t" + ^ "Total: " ^ LargeInt.toString elapsed ^ " ms\n") + in + result + end + +val result = if str <> "" then runTrial str else Seq.empty () + +val () = + if printResult then + Util.for (0, Seq.length result) (fn i => + print (Int.toString (Seq.nth result i) ^ "\n") + ) + else () + +fun checker str result = + let + val answer = BruteForceSuffixArray.makeSuffixArray str + in + if Seq.equal (op =) (result, answer) + then print "Correct\n" + else print "Incorrect\n" + end + +val _ = if str <> "" andalso check then checker str result else () + +fun runBenchmark () = + let + fun randChar seed = Char.chr (Util.hash seed mod 256) + fun randString n = CharVector.tabulate (n, randChar) + val _ = print ("N " ^ Int.toString benchSize ^ "\n") + val (str, tm1) = Util.getTime (fn _ => randString benchSize) + val _ = print ("generated input in " ^ Time.fmt 4 tm1 ^ "s\n") + + val result = Benchmark.run "running suffix array" (fn _ => maker str) + + val _ = + if not check then () else + let val (_, tm) = + Util.getTime (fn _ => if check then checker str result else ()) + in print ("checking took " ^ Time.fmt 4 tm ^ "s\n") + end + in + () + end + +val _ = if benchmark then runBenchmark () else () + +val _ = + if benchmark then GCStats.report () + else () diff --git a/tests/bench/suffix-array/suffix-array.mlb b/tests/bench/suffix-array/suffix-array.mlb new file mode 100644 index 000000000..49037ac80 --- /dev/null +++ b/tests/bench/suffix-array/suffix-array.mlb @@ -0,0 +1,5 @@ +../../mpllib/sources.$(COMPAT).mlb +AS.sml +BruteForce.sml +PrefixDoubling.sml +main.sml diff --git a/tests/bench/tape-delay/main.sml b/tests/bench/tape-delay/main.sml new file mode 100644 index 000000000..2018ae0ec --- /dev/null +++ b/tests/bench/tape-delay/main.sml @@ -0,0 +1,34 @@ 
+structure CLA = CommandLineArgs + +val infile = + case CLA.positional () of + [x] => x + | _ => Util.die ("[ERR] usage: tape-delay INPUT_FILE [-output OUTPUT_FILE]\n") + +val outfile = CLA.parseString "output" "" + +val delayTime = CLA.parseReal "delay" 0.5 +val decayFactor = CLA.parseReal "decay" 0.2 + +val decaydB = Real.round (20.0 * Math.log10 decayFactor) + +val _ = print ("delay " ^ Real.toString delayTime ^ "s\n") +val _ = print ("decay " ^ Real.toString decayFactor ^ " (" + ^ Int.toString decaydB ^ "dB)\n") + +val (snd, tm) = Util.getTime (fn _ => NewWaveIO.readSound infile) +val _ = print ("read sound in " ^ Time.fmt 4 tm ^ "s\n") + +val esnd = + Benchmark.run "echoing" (fn _ => Signal.delay delayTime decayFactor snd) + +val _ = + if outfile = "" then + print ("use -output file.wav to hear results\n") + else + let + val (_, tm) = Util.getTime (fn _ => NewWaveIO.writeSound esnd outfile) + in + print ("wrote output in " ^ Time.fmt 4 tm ^ "s\n") + end + diff --git a/tests/bench/tape-delay/tape-delay.mlb b/tests/bench/tape-delay/tape-delay.mlb new file mode 100644 index 000000000..5f06290fb --- /dev/null +++ b/tests/bench/tape-delay/tape-delay.mlb @@ -0,0 +1,2 @@ +../../mpllib/sources.$(COMPAT).mlb +main.sml diff --git a/tests/bench/tinykaboom/TinyKaboom.sml b/tests/bench/tinykaboom/TinyKaboom.sml new file mode 100644 index 000000000..2ca7616e2 --- /dev/null +++ b/tests/bench/tinykaboom/TinyKaboom.sml @@ -0,0 +1,142 @@ +structure TinyKaboom = +struct + +type f32 = f32.real + +val sphere_radius: f32 = 1.5 +val noise_amplitude: f32 = 1.0 + +fun hash n = + let + val x = f32.sin(n)*43758.5453 + in + x-f32.realFloor(x) + end + +type vec3 = vec3.vector + +fun vec3f (x, y, z): vec3 = {x=x,y=y,z=z} + +fun lerp (v0, v1, t) = + v0 + (v1-v0) * f32.max 0.0 (f32.min 1.0 t) + +fun vlerp (v0, v1, t) = + vec3.map2 (fn x => fn y => lerp (x, y, t)) v0 v1 + +fun noise (x: vec3) = + let + val p = {x = f32.realFloor(#x x), y = f32.realFloor(#y x), z = f32.realFloor(#z x)} + val f 
= {x = #x x - #x p, y = #y x - #y p, z = #z x - #z p} + val f = vec3.scale (vec3.dot (f, vec3.sub ({x=3.0,y=3.0,z=3.0}, vec3.scale 2.0 f))) f + val n = vec3.dot (p, {x=1.0, y=57.0, z=113.0}) + in lerp(lerp(lerp(hash(n + 0.0), hash(n + 1.0), #x f), + lerp(hash(n + 57.0), hash(n + 58.0), #x f), #y f), + lerp(lerp(hash(n + 113.0), hash(n + 114.0), #x f), + lerp(hash(n + 170.0), hash(n + 171.0), #x f), #y f), #z f) + end + +fun rotate v = + vec3f(vec3.dot (vec3f(0.00, 0.80, 0.60), v), + vec3.dot (vec3f(~0.80, 0.36, ~0.48), v), + vec3.dot (vec3f(~0.60, ~0.48, 0.64), v)) + +fun fractal_brownian_motion (x: vec3) = let + val p = rotate x + val f = 0.0 + val f = f + 0.5000*noise p + val p = vec3.scale 2.32 p + val f = f + 0.2500*noise p + val p = vec3.scale 3.03 p + val f = f + 0.1250*noise p + val p = vec3.scale 2.61 p + val f = f + 0.0625*noise p + in f / 0.9375 + end + +fun palette_fire (d: f32): vec3 = let + val yellow = vec3f (1.7, 1.3, 1.0) + val orange = vec3f (1.0, 0.6, 0.0) + val red = vec3f (1.0, 0.0, 0.0) + val darkgray = vec3f (0.2, 0.2, 0.2) + val gray = vec3f (0.4, 0.4, 0.4) + + val x = f32.max 0.0 (f32.min 1.0 d) + in if x < 0.25 then vlerp(gray, darkgray, x*4.0) + else if x < 0.5 then vlerp(darkgray, red, x*4.0-1.0) + else if x < 0.75 then vlerp(red, orange, x*4.0-2.0) + else vlerp(orange, yellow, x*4.0-3.0) + end + +fun signed_distance t p = let + val displacement = ~(fractal_brownian_motion(vec3.scale 3.4 p)) * noise_amplitude + in vec3.norm p - (sphere_radius * f32.sin(t*0.25) + displacement) + end + +fun sq (x: f32) = x * x + +fun loop state continue f = + if continue state then + loop (f state) continue f + else + state + +fun sphere_trace t (orig: vec3, dir: vec3) : (bool * vec3) = let + fun check (i, hit) = (i = 1337, hit) in + if (vec3.dot (orig, orig)) - sq (vec3.dot (orig, dir)) > sq sphere_radius + then (false, orig) + else + check ( + loop (0, orig) (fn (i, _) => i < 64) (fn (i, pos) => + let val d = signed_distance t pos + in if d < 0.0 + then 
(1337, pos) + else (i + 1, vec3.add (pos, vec3.scale (f32.max (d*0.1) 0.1) dir)) + end) + ) + end + +fun distance_field_normal t pos = let + val eps = 0.1 + val d = signed_distance t pos + val nx = signed_distance t (vec3.add (pos, vec3f(eps, 0.0, 0.0))) - d + val ny = signed_distance t (vec3.add (pos, vec3f(0.0, eps, 0.0))) - d + val nz = signed_distance t (vec3.add (pos, vec3f(0.0, 0.0, eps))) - d + in vec3.normalise (vec3f(nx, ny, nz)) + end + +fun rgb (r: f32) (g: f32) (b: f32) = + let + fun clamp x = f32.min 1.0 (f32.max 0.0 x) + fun ch x = Word8.fromInt (f32.round (clamp x * 255.0)) + in + {red = ch r, green = ch g, blue = ch b} + end + +fun frame (t: f32) (width: int) (height: int): Color.pixel array = let + val fov = f32.pi / 3.0 + fun f j i = let + val dir_x = (f32.fromInt i + 0.5) - f32.fromInt width / 2.0 + val dir_y = ~(f32.fromInt j + 0.5) + f32.fromInt height / 2.0 + val dir_z = ~(f32.fromInt height)/(2.0*f32.tan(fov/2.0)) + val (is_hit, hit) = + sphere_trace t (vec3f(0.0, 0.0, 3.0), + vec3.normalise (vec3f(dir_x, dir_y, dir_z))) + in + if is_hit then let + val noise_level = (sphere_radius - vec3.norm hit)/noise_amplitude + val light_dir = vec3.normalise (vec3.sub (vec3f(10.0, 10.0, 10.0), hit)) + val light_intensity = + f32.max 0.4 (vec3.dot (light_dir, distance_field_normal t hit)) + val {x, y, z} = + vec3.scale light_intensity (palette_fire((noise_level - 0.2)*2.0)) + in rgb x y z + end + else + rgb 0.2 0.7 0.8 + end + in + SeqBasis.tabulate 10 (0, width*height) + (fn k => f (k div width) (k mod width)) + end + +end diff --git a/tests/bench/tinykaboom/f32.sml b/tests/bench/tinykaboom/f32.sml new file mode 100644 index 000000000..5360ef194 --- /dev/null +++ b/tests/bench/tinykaboom/f32.sml @@ -0,0 +1,7 @@ +structure f32 = +struct + open Real32 + open Real32.Math + fun max a b = Real32.max (a, b) + fun min a b = Real32.min (a, b) +end diff --git a/tests/bench/tinykaboom/main.sml b/tests/bench/tinykaboom/main.sml new file mode 100644 index 
000000000..9083f54cd --- /dev/null +++ b/tests/bench/tinykaboom/main.sml @@ -0,0 +1,102 @@ +structure CLA = CommandLineArgs + +val fps = CLA.parseInt "fps" 60 +val width = CLA.parseInt "width" 640 +val height = CLA.parseInt "height" 480 +val frames = CLA.parseInt "frames" (10 * fps) +val outfile = CLA.parseString "outfile" "" +(* val frame = CLA.parseInt "frame" 100 *) + +val _ = print ("width " ^ Int.toString width ^ "\n") +val _ = print ("height " ^ Int.toString height ^ "\n") +(* val _ = print ("frame " ^ Int.toString frame ^ "\n") *) +val _ = print ("fps " ^ Int.toString fps ^ "\n") +val _ = print ("frames " ^ Int.toString frames ^ "\n") + +val duration = Real.fromInt frames / Real.fromInt fps + +val _ = print ("(" ^ Real.fmt (StringCvt.FIX (SOME 2)) duration ^ " seconds)\n") + +fun bench () = + let + val _ = print ("generating frames...\n") + val (images, tm) = Util.getTime (fn _ => + SeqBasis.tabulate 1 (0, frames) (fn frame => + { width = width + , height = height + , data = + ArraySlice.full + (TinyKaboom.frame (f32.fromInt frame / f32.fromInt fps) width height) + })) + val _ = print ("generated all frames in " ^ Time.fmt 4 tm ^ "s\n") + val perFrame = Time.fromReal (Time.toReal tm / Real.fromInt frames) + val _ = print ("average time per frame: " ^ Time.fmt 4 perFrame ^ "s\n") + in + images + end + +val images = Benchmark.run "tinykaboom" bench + +val _ = + if outfile = "" then + print ("no output file specified; use -outfile XXX.gif to see result\n") + else + let + val _ = print ("generating palette...\n") + (* val palette = GIF.Palette.summarize [Color.white, Color.black] 256 + { width = width + , height = height + , data = ArraySlice.full (TinyKaboom.frame 5.1667 640 480) + } *) + + fun sampleColor i = + let + val k = Util.hash i + val frame = (k div (width*height)) mod frames + val idx = k mod (width*height) + in + Seq.nth (#data (Array.sub (images, frame))) idx + end + + val palette = GIF.Palette.summarizeBySampling [Color.white, Color.black] 256 + 
sampleColor + + val blowUpFactor = CLA.parseInt "blowup" 1 + val _ = print ("blowup " ^ Int.toString blowUpFactor ^ "\n") + + fun blowUpImage (image as {width, height, data}) = + if blowUpFactor = 1 then image else + let + val width' = blowUpFactor * width + val height' = blowUpFactor * height + val output = ForkJoin.alloc (width' * height') + val _ = + ForkJoin.parfor 1 (0, height) (fn i => + ForkJoin.parfor (1000 div blowUpFactor) (0, width) (fn j => + let + val c = Seq.nth data (i*width + j) + in + Util.for (0, blowUpFactor) (fn di => + Util.for (0, blowUpFactor) (fn dj => + Array.update (output, (i*blowUpFactor+di)*width' + (j*blowUpFactor+dj), c))) + end)) + in + { width = width' + , height = height' + , data = ArraySlice.full output + } + end + + val _ = print ("writing to " ^ outfile ^"...\n") + val msBetween = Real.round ((1.0 / Real.fromInt fps) * 100.0) + val (_, tm) = Util.getTime (fn _ => + GIF.writeMany outfile msBetween palette + { width = blowUpFactor * width + , height = blowUpFactor * height + , numImages = frames + , getImage = fn i => #remap palette (blowUpImage (Array.sub (images, i))) + }) + val _ = print ("wrote all frames in " ^ Time.fmt 4 tm ^ "s\n") + in + () + end diff --git a/tests/bench/tinykaboom/tinykaboom.mlb b/tests/bench/tinykaboom/tinykaboom.mlb new file mode 100644 index 000000000..6a0e4cd9d --- /dev/null +++ b/tests/bench/tinykaboom/tinykaboom.mlb @@ -0,0 +1,5 @@ +../../mpllib/sources.$(COMPAT).mlb +f32.sml +vec3.sml +TinyKaboom.sml +main.sml diff --git a/tests/bench/tinykaboom/vec3.sml b/tests/bench/tinykaboom/vec3.sml new file mode 100644 index 000000000..5c6399baf --- /dev/null +++ b/tests/bench/tinykaboom/vec3.sml @@ -0,0 +1,20 @@ +structure vec3 = +struct + type f32 = f32.real + + type vector = {x: f32, y: f32, z: f32} + fun add (a: vector, b: vector) = + {x = #x a + #x b, y = #y a + #y b, z = #z a + #z b} + fun sub (a: vector, b: vector) = + {x = #x a - #x b, y = #y a - #y b, z = #z a - #z b} + fun dot (a: vector, b: 
vector) = + (#x a * #x b) + (#y a * #y b) + (#z a * #z b) + fun scale s ({x,y,z}: vector) = + {x = s*x, y = s*y, z = s*z} + fun map2 f (a: vector) (b: vector) = + {x = f (#x a) (#x b), y = f (#y a) (#y b), z = f (#z a) (#z b)} + fun norm a = + f32.sqrt (dot (a, a)) + fun normalise (v: vector): vector = + scale (1.0 / norm v) v +end diff --git a/tests/bench/to-gif/main.sml b/tests/bench/to-gif/main.sml new file mode 100644 index 000000000..5bf573277 --- /dev/null +++ b/tests/bench/to-gif/main.sml @@ -0,0 +1,39 @@ +structure CLA = CommandLineArgs + +val (input, output) = + case CLA.positional () of + [input, output] => (input, output) + | _ => Util.die "missing filename" + +val (image, tm) = Util.getTime (fn _ => PPM.read input) +val _ = print ("read image in " ^ Time.fmt 4 tm ^ "s\n") + +(* +val w = #width image +val h = #height image + +fun noisy i = + let + val data = Seq.map (fn x => x) (#data image) + in + (* spit on 10% of all pixels *) + Util.for (0, Seq.length data div 10) (fn j => + let + val k = Util.hash (i * Seq.length data + j) mod Seq.length data + in + ArraySlice.update (data, k, Color.red) + end); + {width = w, height = h, data = data} + end + +val (_, tm) = Util.getTime (fn _ => + GIF.writeMany output { width = w + , height = h + , numImages = 10 + , getImage = noisy + }) +*) + +val (_, tm) = Util.getTime (fn _ => GIF.write output image) + +val _ = print ("wrote " ^ output ^ " in " ^ Time.fmt 4 tm ^ "s\n") diff --git a/tests/bench/to-gif/to-gif.mlb b/tests/bench/to-gif/to-gif.mlb new file mode 100644 index 000000000..5f06290fb --- /dev/null +++ b/tests/bench/to-gif/to-gif.mlb @@ -0,0 +1,2 @@ +../../mpllib/sources.$(COMPAT).mlb +main.sml diff --git a/tests/bench/tokens/tokens.mlb b/tests/bench/tokens/tokens.mlb new file mode 100644 index 000000000..01809d9b0 --- /dev/null +++ b/tests/bench/tokens/tokens.mlb @@ -0,0 +1,2 @@ +../../mpllib/sources.$(COMPAT).mlb +tokens.sml diff --git a/tests/bench/tokens/tokens.sml b/tests/bench/tokens/tokens.sml new 
file mode 100644 index 000000000..b1afcabf8 --- /dev/null +++ b/tests/bench/tokens/tokens.sml @@ -0,0 +1,48 @@ +structure CLA = CommandLineArgs + +fun usage () = + let + val msg = + "usage: tokens [--verbose] [--no-output] FILE\n" + in + TextIO.output (TextIO.stdErr, msg); + OS.Process.exit OS.Process.failure + end + +val filename = + case CLA.positional () of + [x] => x + | _ => usage () + +val beVerbose = CLA.parseFlag "verbose" +val noOutput = CLA.parseFlag "no-output" + +fun vprint str = + if not beVerbose then () + else TextIO.output (TextIO.stdErr, str) + +val (contents, tm) = Util.getTime (fn _ => ReadFile.contentsSeq filename) +val _ = vprint ("read file in " ^ Time.fmt 4 tm ^ "s\n") + +val tokens = + Benchmark.run "tokenizing" (fn _ => Tokenize.tokensSeq Char.isSpace contents) + +val _ = vprint ("number of tokens " ^ Int.toString (Seq.length tokens) ^ "\n") + +fun put c = TextIO.output1 (TextIO.stdOut, c) +fun putToken token = + Util.for (0, Seq.length token) (put o Seq.nth token) + +val _ = + if noOutput then () + else + let + val (_, tm) = Util.getTime (fn _ => + ArraySlice.app (fn token => (putToken token; put #"\n")) tokens) + in + vprint ("output in " ^ Time.fmt 4 tm ^ "s\n") + end + +val _ = + if beVerbose then GCStats.report () + else () diff --git a/tests/bench/triangle-count/TriangleCount.sml b/tests/bench/triangle-count/TriangleCount.sml new file mode 100644 index 000000000..cf9bd7482 --- /dev/null +++ b/tests/bench/triangle-count/TriangleCount.sml @@ -0,0 +1,125 @@ +structure TriangleCount = +struct + type 'a seq = 'a Seq.t + + structure G = AdjacencyGraph(Int) + structure V = G.Vertex + structure AS = ArraySlice + + type vertex = G.vertex + exception Assert + + (* assumes l <= r and all elements accessible in [l, r) *) + (* ensures that all elements left of the returned index are lesser *) + fun bin_search k s (l, r) = + if l = r then (l, false) + else if (r - l) = 1 then (l, (Seq.nth s l) = k) + else + let + val mid = l + Int.div (r - l - 1, 
2) + in + case Int.compare (k, Seq.nth s mid) of + EQUAL => (mid, true) + | LESS => bin_search k s (l, mid) + | GREATER => bin_search k s (mid + 1, r) + end + + fun intersection_count s s' gran = + let + fun countseq1 s1 s2 = + let + val (n1, n2) = (Seq.length s1, Seq.length s2) + fun helper l1 l2 acc = + if (n1 <= l1) orelse (n2 <= l2) then acc + else + case Int.compare (Seq.nth s1 l1, Seq.nth s2 l2) of + EQUAL => helper (l1 + 1) (l2 + 1) (acc + 1) + | LESS => helper (l1 + 1) l2 acc + | GREATER => helper l1 (l2 + 1) acc + in + helper 0 0 0 + end + + fun countseq2 s1 s2 = + let + val (n1, n2) = (Seq.length s1, Seq.length s2) + fun helper l acc = + if l >= n1 then acc + else + let + val k = Seq.nth s1 l + val (idx, found) = bin_search k s2 (0, n2) + val bump = if found then 1 else 0 + in + helper (l + 1) (acc + bump) + end + in + if n2 = 0 then 0 + else helper 0 0 + end + + fun subs s i j = Seq.subseq s (i, j - i) + fun countpar s1 s2 = + let + val (n1, n2) = (Seq.length s1, Seq.length s2) + val nR = n1 + n2 + in + if nR < gran then countseq1 s1 s2 + else if n2 < n1 then countpar s2 s1 + else if n1 < Int.div (gran, 64) then countseq2 s1 s2 + else + let + val mid1 = Int.div (n1, 2) + val k1 = Seq.nth s1 mid1 + val (mid2, found) = bin_search k1 s2 (0, n2) + val bump = if found then 1 else 0 + val (l, r) = ForkJoin.par (fn _ => countpar (subs s1 0 mid1) (subs s2 0 (mid2 + 1 - bump)), + fn _ => countpar (subs s1 (mid1 + 1) n1) (subs s2 (mid2 + bump) n2)) + in + l + bump + r + end + end + val r = countpar s s' + in + r + end + + (* get common vertices greater than min_elt *) + fun intersection_count_thresh s s' gran min_elt = + let + val (n1, n2) = (Seq.length s, Seq.length s') + val (k1, _) = bin_search min_elt s (0, n1) + val (k2, _) = bin_search min_elt s' (0, n2) + fun subs s i j = Seq.subseq s (i, j - i) + in + if (k1 = n1) orelse (k2 = n2) then 0 + else intersection_count (subs s k1 n1) (subs s' k2 n2) gran + end + + fun triangle_count g = + let + fun count u = + 
let + val ngbrs = G.neighbors g u + val num_ngbrs = Seq.length ngbrs + val (idx, _) = bin_search u ngbrs (0, num_ngbrs) + val ngbrs = Seq.subseq ngbrs (idx, num_ngbrs - idx) + val num_ngbrs = Seq.length ngbrs + fun helpi i = + let + val v = Seq.nth ngbrs i + in + if u < v then + intersection_count_thresh ngbrs (G.neighbors g v) 10000 v + else + 0 + end + val r = SeqBasis.reduce 100 Int.+ 0 (0, num_ngbrs) helpi + in + r + end + (* val tr_counts = Seq.tabulate count (G.numVertices g) *) + in + SeqBasis.reduce 100 op+ 0 (0, G.numVertices g) count + end +end diff --git a/tests/bench/triangle-count/main.sml b/tests/bench/triangle-count/main.sml new file mode 100644 index 000000000..479b92471 --- /dev/null +++ b/tests/bench/triangle-count/main.sml @@ -0,0 +1,71 @@ +structure CLA = CommandLineArgs +structure G = AdjacencyGraph(Int) + +val source = CLA.parseInt "source" 0 +val doCheck = CLA.parseFlag "check" + +(* +val N = CLA.parseInt "N" 10000000 +val D = CLA.parseInt "D" 10 + +val (graph, tm) = Util.getTime (fn _ => G.randSymmGraph N D) +val _ = print ("generated graph in " ^ Time.fmt 4 tm ^ "s\n") +val _ = print ("num vertices: " ^ Int.toString (G.numVertices graph) ^ "\n") +val _ = print ("num edges: " ^ Int.toString (G.numEdges graph) ^ "\n") +*) + +val filename = + case CLA.positional () of + [x] => x + | _ => Util.die "missing filename" + +val (graph, tm) = Util.getTime (fn _ => G.parseFile filename) +val _ = print ("num vertices: " ^ Int.toString (G.numVertices graph) ^ "\n") +val _ = print ("num edges: " ^ Int.toString (G.numEdges graph) ^ "\n") + +val (_, tm) = Util.getTime (fn _ => + if G.parityCheck graph then () + else TextIO.output (TextIO.stdErr, + "WARNING: parity check failed; graph might not be symmetric " ^ + "or might have duplicate- or self-edges\n")) +val _ = print ("parity check in " ^ Time.fmt 4 tm ^ "s\n") + +val P = Benchmark.run "running tc: " (fn _ => TriangleCount.triangle_count graph) +val _ = print ("num-triangles = " ^ (Int.toString P) ^ 
"\n") + +(* val _ = LDD.check_ldd graph (#1 P) (#2 P) *) +(* val _ = Benchmark.run "running connectivity" (fn _ => LDD.connectivity graph b) *) +(* +val numVisited = + SeqBasis.reduce 10000 op+ 0 (0, Seq.length P) + (fn i => if Seq.nth P i >= 0 then 1 else 0) +val _ = print ("visited " ^ Int.toString numVisited ^ "\n") + +fun numHops P hops v = + if hops > Seq.length P then ~2 + else if Seq.nth P v = ~1 then ~1 + else if Seq.nth P v = v then hops + else numHops P (hops+1) (Seq.nth P v) + +val maxHops = + SeqBasis.reduce 100 Int.max ~3 (0, G.numVertices graph) (numHops P 0) +val _ = print ("max dist " ^ Int.toString maxHops ^ "\n") + +fun check () = + let + val (P', serialTime) = + Util.getTime (fn _ => SerialBFS.bfs graph source) + + val correct = + Seq.length P = Seq.length P' + andalso + SeqBasis.reduce 10000 (fn (a, b) => a andalso b) true (0, Seq.length P) + (fn i => numHops P 0 i = numHops P' 0 i) + in + print ("serial finished in " ^ Time.fmt 4 serialTime ^ "s\n"); + print ("correct? 
" ^ (if correct then "yes" else "no") ^ "\n") + end + +val _ = if doCheck then check () else () + +val _ = GCStats.report () *) diff --git a/tests/bench/triangle-count/triangle-count.mlb b/tests/bench/triangle-count/triangle-count.mlb new file mode 100644 index 000000000..e3cc8eb11 --- /dev/null +++ b/tests/bench/triangle-count/triangle-count.mlb @@ -0,0 +1,3 @@ +../../mpllib/sources.$(COMPAT).mlb +TriangleCount.sml +main.sml diff --git a/tests/bench/wc-opt/WC.sml b/tests/bench/wc-opt/WC.sml new file mode 100644 index 000000000..dc7931f30 --- /dev/null +++ b/tests/bench/wc-opt/WC.sml @@ -0,0 +1,51 @@ +structure WC : +sig + type 'a seq = 'a ArraySequence.t + + (* returns (num lines, num words, num characters) *) + val wc: char seq -> (int * int * int) +end = +struct + + structure ASeq = ArraySequence + type 'a seq = 'a ASeq.t + + fun wc seq = + let + (* + val (a, i, n) = ArraySlice.base seq + val _ = if i = 0 then () else raise Fail "uh oh" + fun nth i = Array.sub (a, i) + *) + fun nth i = ASeq.nth seq i + (* Create a delayed sequence of pairs of integers: + * the first is 1 if it is line break, 0 otherwise; + * the second is 1 if the start of a word, 0 otherwise. 
+ *) + fun isSpace a = (a = #"\n" orelse a = #"\t" orelse a = #" ") + (*val isSpace = Char.isSpace*) + fun f i = + let + val si = nth i + val wordStart = + if (i = 0 orelse isSpace (nth (i-1))) andalso + not (isSpace si) + then 1 else 0 + val lineBreak = if si = #"\n" then 1 else 0 + in + (lineBreak, wordStart) + end + (* val x = Seq.tabulate f (ASeq.length seq) + val (lines, words) = + Seq.reduce (fn ((lb1, ws1), (lb2, ws2)) => (lb1 + lb2, ws1 + ws2)) (0, 0) x *) + val (lines, words) = + SeqBasis.reduce 5000 + (fn ((lb1, ws1), (lb2, ws2)) => (lb1 + lb2, ws1 + ws2)) + (0, 0) + (0, ASeq.length seq) + f + in + (lines, words, ASeq.length seq) + end + +end diff --git a/tests/bench/wc-opt/main.sml b/tests/bench/wc-opt/main.sml new file mode 100644 index 000000000..51fb6c080 --- /dev/null +++ b/tests/bench/wc-opt/main.sml @@ -0,0 +1,29 @@ +structure CLA = CommandLineArgs +structure Seq = ArraySequence + +val n = CLA.parseInt "n" (1000 * 1000 * 100) +val source = + case CLA.positional () of + [filePath] => + let + val (source, tm) = Util.getTime (fn _ => ReadFile.contentsSeq filePath) + val _ = print ("loadtime " ^ Time.fmt 3 tm ^ "\n") + in + source + end + | _ => + Seq.tabulate (fn i => Char.chr (Util.hash i mod 255)) n + +fun task () = + WC.wc source + +fun check (lines, words, bytes) = + let + in + print ("correct? 
checker for wc not implemented yet\n") + end + +val (nl, nw, nb) = Benchmark.run "wc" task +val _ = print ("lines " ^ Int.toString nl ^ "\n") +val _ = print ("words " ^ Int.toString nw ^ "\n") +val _ = print ("chars " ^ Int.toString nb ^ "\n") diff --git a/tests/bench/wc-opt/wc-opt.mlb b/tests/bench/wc-opt/wc-opt.mlb new file mode 100644 index 000000000..568dd94ff --- /dev/null +++ b/tests/bench/wc-opt/wc-opt.mlb @@ -0,0 +1,3 @@ +../../mpllib/sources.$(COMPAT).mlb +WC.sml +main.sml diff --git a/tests/bench/wc/MkWC.sml b/tests/bench/wc/MkWC.sml new file mode 100644 index 000000000..609c4066e --- /dev/null +++ b/tests/bench/wc/MkWC.sml @@ -0,0 +1,45 @@ +functor MkWC (Seq : SEQUENCE) : +sig + type 'a seq = 'a ArraySequence.t + + (* returns (num lines, num words, num characters) *) + val wc: char seq -> (int * int * int) +end = +struct + + structure ASeq = ArraySequence + type 'a seq = 'a ASeq.t + + fun wc seq = + let + (* + val (a, i, n) = ArraySlice.base seq + val _ = if i = 0 then () else raise Fail "uh oh" + fun nth i = Array.sub (a, i) + *) + fun nth i = ASeq.nth seq i + (* Create a delayed sequence of pairs of integers: + * the first is 1 if it is line break, 0 otherwise; + * the second is 1 if the start of a word, 0 otherwise. 
+ *) + fun isSpace a = (a = #"\n" orelse a = #"\t" orelse a = #" ") + (*val isSpace = Char.isSpace*) + fun f i = + let + val si = nth i + val wordStart = + if (i = 0 orelse isSpace (nth (i-1))) andalso + not (isSpace si) + then 1 else 0 + val lineBreak = if si = #"\n" then 1 else 0 + in + (lineBreak, wordStart) + end + val x = Seq.tabulate f (ASeq.length seq) + val (lines, words) = + Seq.reduce (fn ((lb1, ws1), (lb2, ws2)) => (lb1 + lb2, ws1 + ws2)) (0, 0) x + in + (lines, words, ASeq.length seq) + end + +end diff --git a/tests/bench/wc/main.sml b/tests/bench/wc/main.sml new file mode 100644 index 000000000..4edddfe4a --- /dev/null +++ b/tests/bench/wc/main.sml @@ -0,0 +1,33 @@ +structure CLA = CommandLineArgs +structure Seq = ArraySequence + +structure WC = MkWC(DelayedSeq) + +val n = CLA.parseInt "n" (1000 * 1000 * 100) +val filePath = CLA.parseString "infile" "" + +val source = + if filePath = "" then + (*Seq.tabulate (fn _ => #" ") n*) + Seq.tabulate (fn i => Char.chr (Util.hash i mod 255)) n + else + let + val (source, tm) = Util.getTime (fn _ => ReadFile.contentsSeq filePath) + val _ = print ("loadtime " ^ Time.fmt 3 tm ^ "\n") + in + source + end + +fun task () = + WC.wc source + +fun check (lines, words, bytes) = + let + in + print ("correct? 
checker for wc not implemented yet\n") + end + +val (nl, nw, nb) = Benchmark.run "wc" task +val _ = print ("lines " ^ Int.toString nl ^ "\n") +val _ = print ("words " ^ Int.toString nw ^ "\n") +val _ = print ("chars " ^ Int.toString nb ^ "\n") diff --git a/tests/bench/wc/wc.mlb b/tests/bench/wc/wc.mlb new file mode 100644 index 000000000..bc7303f32 --- /dev/null +++ b/tests/bench/wc/wc.mlb @@ -0,0 +1,3 @@ +../../mpllib/sources.$(COMPAT).mlb +MkWC.sml +main.sml diff --git a/tests/compile.py b/tests/compile.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/mpllib/AdjacencyGraph.sml b/tests/mpllib/AdjacencyGraph.sml new file mode 100644 index 000000000..2e77447ab --- /dev/null +++ b/tests/mpllib/AdjacencyGraph.sml @@ -0,0 +1,390 @@ +functor AdjacencyGraph (Vertex: INTEGER) = +struct + + structure A = Array + structure AS = ArraySlice + + structure Vertex = + struct + type t = Vertex.int + open Vertex + val maxVal = toInt (valOf maxInt) + end + + structure VertexSubset = + struct + datatype h = SPARSE of Vertex.t Seq.t | DENSE of int Seq.t + type t = h * int + exception BadRep + + fun empty thresh = (SPARSE (Seq.empty()), thresh) + + fun size (vs, thresh) = + case vs of + SPARSE s => Seq.length s + | DENSE s => Seq.reduce op+ 0 s + + fun plugOnes s positions = + (Seq.foreach positions (fn (i, v) => AS.update (s, Vertex.toInt v, 1))) + + fun append (vs, threshold) s n = + case vs of + SPARSE ss => + if (Seq.length ss) + (Seq.length s) > threshold then + let + val dense_rep = Seq.tabulate (fn x => 0) n + val _ = plugOnes dense_rep ss + val _ = plugOnes dense_rep s + in + (DENSE (dense_rep), threshold) + end + else (SPARSE(Seq.append (ss, s)), threshold) + | DENSE ss => (plugOnes ss s; (DENSE ss, threshold)) + + fun sparse_to_dense vs n = + case vs of + SPARSE s => + let + val dense_rep = Seq.tabulate (fn x => 0) n + val _ = Seq.foreach s (fn (i, v) => AS.update (dense_rep, Vertex.toInt v, 1)) + in + DENSE (dense_rep) + end + | DENSE _ => raise BadRep 
+ + fun dense_to_sparse vs = + case vs of + SPARSE _ => raise BadRep + | DENSE s => + let + val (offsets, total) = Seq.scan op+ 0 s + val sparse = ForkJoin.alloc total + val _ = Seq.foreach s (fn (i, v) => + if (v=1) then A.update (sparse, Seq.nth offsets i, Vertex.fromInt i) + else if (v = 0) then () + else raise BadRep + ) + in + SPARSE (AS.full sparse) + end + + fun from_sparse_rep s threshold n = + if (Seq.length s) < threshold then (SPARSE (s), threshold) + else (sparse_to_dense (SPARSE (s)) n, threshold) + + fun from_dense_rep s countopt threshold = + let + val count = + case countopt of + SOME x => x + | NONE => Seq.reduce op+ 0 s + val d = DENSE(s) + in + if count < threshold then (dense_to_sparse(d), threshold) + else (d, threshold) + end + end + + type vertex = Vertex.t + fun vertexNth s v = Seq.nth s (Vertex.toInt v) + fun vToWord v = Word64.fromInt (Vertex.toInt v) + + (* offsets, degrees, compact neighbors *) + type graph = (int Seq.t) * (int Seq.t) * (vertex Seq.t) + + fun degree G v = + let val (offsets, degrees, _) = G + in (vertexNth degrees v) + end + + fun neighbors G v = + let + val (offsets, _, nbrs) = G + in + Seq.subseq nbrs (vertexNth offsets v, degree G v) + end + + fun numVertices G = + let val (_, degrees, _) = G + in Seq.length degrees + end + + fun numEdges G = + let val (_, _, nbrs) = G + in Seq.length nbrs + end + + fun computeDegrees (N, M, offsets) = + AS.full (SeqBasis.tabulate 10000 (0, N) (fn i => + let + val off = Seq.nth offsets i + val nextOff = if i+1 < N then Seq.nth offsets (i+1) else M + val deg = nextOff - off + in + if deg < 0 then + raise Fail ("AdjacencyGraph.computeDegrees: vertex " ^ Int.toString i + ^ " has negative degree") + else + deg + end)) + + fun parse chars = + let + fun isNewline i = (Seq.nth chars i = #"\n") + + (* Computing newline positions takes up about half the time of parsing... + * Can we do this faster? 
*) + val nlPos = + AS.full (SeqBasis.filter 10000 (0, Seq.length chars) (fn i => i) isNewline) + val numLines = Seq.length nlPos + 1 + fun lineStart i = + if i = 0 then 0 else 1 + Seq.nth nlPos (i-1) + fun lineEnd i = + if i = Seq.length nlPos then Seq.length chars else Seq.nth nlPos i + fun line i = Seq.subseq chars (lineStart i, lineEnd i - lineStart i) + + val _ = + if numLines >= 3 then () + else raise Fail ("AdjacencyGraph: missing or incomplete header") + + val _ = + if Parse.parseString (line 0) = "AdjacencyGraph" then () + else raise Fail ("expected AdjacencyGraph header") + + fun tryParse thing lineNum = + let + fun whoops () = + raise Fail ("AdjacencyGraph: line " + ^ Int.toString (lineNum+1) + ^ ": error while parsing " ^ thing) + in + case (Parse.parseInt (line lineNum) handle _ => whoops ()) of + SOME x => if x >= 0 then x else whoops () + | NONE => whoops () + end + + val numVertices = tryParse "num vertices" 1 + val numEdges = tryParse "num edges" 2 + + val _ = + if numLines >= numVertices + numEdges + 3 then () + else raise Fail ("AdjacencyGraph: not enough offsets and/or edges to parse") + + val offsets = AS.full (SeqBasis.tabulate 1000 (0, numVertices) + (fn i => tryParse "edge offset" (3+i))) + + val neighbors = AS.full (SeqBasis.tabulate 1000 (0, numEdges) + (fn i => Vertex.fromInt (tryParse "neighbor" (3+numVertices+i)))) + in + (offsets, computeDegrees (numVertices, numEdges, offsets), neighbors) + end + + fun writeAsBinaryFormat g filename = + let + val (offsets, _, nbrs) = g + + val file = TextIO.openOut filename + val _ = TextIO.output (file, "AdjacencyGraphBin\n") + val _ = TextIO.closeOut file + + val file = BinIO.openAppend filename + fun w8 (w: Word8.word) = BinIO.output1 (file, w) + fun w64 (w: Word64.word) = + let + open Word64 + infix 2 >> andb + in + (* this will only work if Word64 = LargeWord, which is good. 
*) + w8 (Word8.fromLarge (w >> 0w56)); + w8 (Word8.fromLarge (w >> 0w48)); + w8 (Word8.fromLarge (w >> 0w40)); + w8 (Word8.fromLarge (w >> 0w32)); + w8 (Word8.fromLarge (w >> 0w24)); + w8 (Word8.fromLarge (w >> 0w16)); + w8 (Word8.fromLarge (w >> 0w8)); + w8 (Word8.fromLarge w) + end + fun wi (x: int) = w64 (Word64.fromInt x) + fun wv (v: vertex) = w64 (vToWord v) + in + wi (numVertices g); + wi (numEdges g); + Util.for (0, numVertices g) (fn i => wi (Seq.nth offsets i)); + Util.for (0, numEdges g) (fn i => wv (Seq.nth nbrs i)); + BinIO.closeOut file + end + + fun parseBin bytes = + let + val header = "AdjacencyGraphBin\n" + val header' = + if Seq.length bytes < String.size header then + raise Fail ("AdjacencyGraphBin: missing or incomplete header") + else + CharVector.tabulate (String.size header, fn i => + Char.chr (Word8.toInt (Seq.nth bytes i))) + val _ = + if header = header' then () + else raise Fail ("expected AdjacencyGraphBin header") + + val bytes = Seq.drop bytes (String.size header) + + (* this will only work if Word64 = LargeWord, which is good. 
*) + fun r64 i = + let + infix 2 << orb + val op<< = Word64.<< + val op orb = Word64.orb + + val off = i*8 + val w = Word8.toLarge (Seq.nth bytes off) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (off+1))) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (off+2))) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (off+3))) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (off+4))) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (off+5))) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (off+6))) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (off+7))) + in + w + end + + fun ri i = Word64.toInt (r64 i) + fun rv i = Vertex.fromInt (ri i) + + val numVertices = ri 0 + val numEdges = ri 1 + + val offsets = + AS.full (SeqBasis.tabulate 10000 (0, numVertices) (fn i => ri (i+2))) + val nbrs = + AS.full (SeqBasis.tabulate 10000 (0, numEdges) (fn i => rv (i+2+numVertices))) + in + (offsets, computeDegrees (numVertices, numEdges, offsets), nbrs) + end + + fun parseFile path = + let + val file = TextIO.openIn path + + val h1 = "AdjacencyGraph\n" + val h2 = "AdjacencyGraphBin\n" + + val actualHeader = + TextIO.inputN (file, Int.max (String.size h1, String.size h2)) + in + TextIO.closeIn file; + + if String.isPrefix h1 actualHeader then + let + val (c, tm) = Util.getTime (fn _ => ReadFile.contentsSeq path) + val _ = print ("read file in " ^ Time.fmt 4 tm ^ "s\n") + val (graph, tm) = Util.getTime (fn _ => parse c) + val _ = print ("parsed graph in " ^ Time.fmt 4 tm ^ "s\n") + in + graph + end + else if String.isPrefix h2 actualHeader then + let + val (c, tm) = Util.getTime (fn _ => ReadFile.contentsBinSeq path) + val _ = print ("read file in " ^ Time.fmt 4 tm ^ "s\n") + val (graph, tm) = Util.getTime (fn _ => parseBin c) + val _ = print ("parsed graph in " ^ Time.fmt 4 tm ^ "s\n") + in + graph + end + else + raise Fail ("unknown header " ^ actualHeader) + end + + (* Useful as a sanity check for symmetrized graphs -- + * (every 
symmetrized graph has edge parity 0, but not all graphs with + * edge parity 0 are symmetrized!) *) + fun parityCheck g = + let + val (offsets, _, _) = g + val n = numVertices g + + fun canonical (u, v) = + if Vertex.< (u, v) then (u, v) else (v, u) + fun xorEdges ((u1, v1), (u2, v2)) = + (Word64.xorb (u1, u2), Word64.xorb (v1, v2)) + fun packEdge (u, v) = (vToWord u, vToWord v) + + val (p1, p2) = SeqBasis.reduce 100 xorEdges (0w0, 0w0) (0, n) (fn i => + let + val u = Vertex.fromInt i + val offset = Seq.nth offsets i + in + SeqBasis.reduce 1000 xorEdges (0w0, 0w0) (0, degree g u) (fn j => + packEdge (canonical (u, Seq.nth (neighbors g u) j))) + end) + + in + p1 = 0w0 andalso p2 = 0w0 + end + + fun fromSortedEdges sorted = + let + fun edgeInts (u, v) = (Vertex.toInt u, Vertex.toInt v) + val m = Seq.length sorted + val n = + 1 + SeqBasis.reduce 10000 Int.max ~1 (0, m) + (Int.max o edgeInts o Seq.nth sorted) + + fun k i = Vertex.toInt (#1 (Seq.nth sorted i)) + + val ends = Seq.tabulate (fn i => if i = n then m else 0) (n+1) + val _ = ForkJoin.parfor 10000 (0, m) (fn i => + if i = m-1 then + AS.update (ends, k i, m) + else if k i <> k (i+1) then + AS.update (ends, k i, i+1) + else ()) + val (offsets, _) = Seq.scan Int.max 0 ends + + fun off i = Seq.nth offsets (i+1) - Seq.nth offsets i + val degrees = Seq.tabulate off n + + val nbrs = Seq.map #2 sorted + in + (offsets, degrees, nbrs) + end + + fun dedupEdges edges = + let + val sorted = + Mergesort.sort (fn ((u1,v1), (u2,v2)) => + case Vertex.compare (u1, u2) of + EQUAL => Vertex.compare (v1, v2) + | other => other) edges + in + AS.full (SeqBasis.filter 5000 (0, Seq.length sorted) (Seq.nth sorted) + (fn i => i = 0 orelse Seq.nth sorted (i-1) <> Seq.nth sorted i)) + end + + fun randSymmGraph n d = + let + val m = Real.ceil (Real.fromInt n * Real.fromInt d / 2.0) + + fun makeEdge i = + let + val u = (2 * i) div d + val v = Util.hash i mod (n-1) + in + (Vertex.fromInt u, Vertex.fromInt (if v < u then v else v+1)) + end + 
+ val bothWays = ForkJoin.alloc (2*m) + val _ = ForkJoin.parfor 1000 (0, m) (fn i => + let + val (u, v) = makeEdge i + in + A.update (bothWays, 2*i, (u,v)); + A.update (bothWays, 2*i+1, (v,u)) + end) + in + fromSortedEdges (dedupEdges (AS.full bothWays)) + end + +end diff --git a/tests/mpllib/AdjacencyInt.sml b/tests/mpllib/AdjacencyInt.sml new file mode 100644 index 000000000..936410d77 --- /dev/null +++ b/tests/mpllib/AdjacencyInt.sml @@ -0,0 +1,157 @@ +structure AdjInt = +struct + type 'a seq = 'a Seq.t + + structure G = AdjacencyGraph(Int) + structure AS = ArraySlice + open G.VertexSubset + + fun to_seq g (vs, threshold) = + case vs of + SPARSE s => s + | DENSE s => to_seq g (G.VertexSubset.dense_to_sparse vs, threshold) + + fun should_process_sparse g n = + let + val denseThreshold = G.numEdges g div 20 + val deg = Int.div (G.numEdges g, G.numVertices g) + val count = (1 + deg) * n + in + count <= denseThreshold + end + + fun edge_map_dense g vertices f h = + let + val inFrontier = vertices + val n = Seq.length vertices + val res = Seq.tabulate (fn _ => 0) n + + fun processVertex v = + if not (h v) then 0 + else + let + val neighbors = G.neighbors g v + fun loop i = + if i >= Seq.length neighbors then 0 else + let val u = Seq.nth neighbors i + in + if not (Seq.nth inFrontier u = 1) then + loop (i+1) + else + case f (u, v) of + NONE => loop (i+1) + | SOME x => (AS.update (res, x, 1); 1) + end + in + loop 0 + end + val count = SeqBasis.reduce 1000 op+ 0 (0, n) processVertex + in + (res, count) + end + + fun edge_map_sparse g vertices f h = + let + val n = Seq.length vertices + fun ui uidx = Seq.nth vertices uidx + val r = + SeqBasis.scan 1000 op+ 0 (0, n) (G.degree g o ui) + val (offsets, totalOutDegree) = (AS.full r, Array.sub (r, n)) + val store = ForkJoin.alloc totalOutDegree + val k = 100 + val numBlocks = 1 + (totalOutDegree-1) div k + fun map_block i = + let + val lo = i*k + val hi = Int.min((i+1)*k, totalOutDegree) + val ulo = + let + val a = 
BinarySearch.search (Int.compare) offsets lo + in + if (Seq.nth offsets a) > lo then a - 1 + else a + end + fun map_seq idx (u, uidx) count = + if idx >= hi then count + else if idx >= (Seq.nth offsets (uidx + 1)) then map_seq idx (ui (uidx + 1), uidx + 1) count + else + let + val v = Seq.nth (G.neighbors g u) (idx - (Seq.nth offsets uidx)) + in + if (h v) then + case f (u, v) of + SOME x => (Array.update (store, lo + count, x); map_seq (idx + 1) (u, uidx) (count + 1)) + | NONE => (map_seq (idx + 1) (u, uidx) count) + else + (map_seq (idx + 1) (u, uidx) count) + end + in + map_seq lo (ui ulo, ulo) 0 + end + val counts = SeqBasis.tabulate 1 (0, numBlocks) map_block + val outOff = SeqBasis.scan 10000 op+ 0 (0, numBlocks) (fn i => Array.sub (counts, i)) + val outSize = Array.sub (outOff, numBlocks) + val result = ForkJoin.alloc outSize + in + ForkJoin.parfor (totalOutDegree div (Int.max (outSize, 1))) (0, numBlocks) (fn i => + let + val soff = i * k + val doff = Array.sub (outOff, i) + val size = Array.sub (outOff, i+1) - doff + in + Util.for (0, size) (fn j => + Array.update (result, doff+j, Array.sub (store, soff+j))) + end); + (AS.full result) + end + + fun edge_map g (vs, threshold) (fpar, f) h = + case vs of + SPARSE s => + from_sparse_rep (edge_map_sparse g s fpar h) threshold (G.numVertices g) + | DENSE s => + let + val (res, count) = edge_map_dense g s f h + in + from_dense_rep res (SOME count) threshold + end + + fun vertex_foreach g (vs, threshold) f = + case vs of + SPARSE s => + Seq.foreach s (fn (i, u) => f u) + | DENSE s => + Seq.foreach s (fn (i, b) => if (b = 1) then (f i) else ()) + + fun vertex_map_ g (vs, threshold) f = + case vs of + SPARSE s => + let + val s' = + AS.full (SeqBasis.tabFilter 1000 (0, Seq.length s) + (fn i => + let + val u = Seq.nth s i + val b = f u + in + if b then SOME u + else NONE + end + )) + in + (from_sparse_rep s' threshold (G.numVertices g)) + end + | DENSE s => + let + val res = + Seq.map (fn i => if (Seq.nth s i = 1) 
andalso f i then 1 else 0) s + in + from_dense_rep res NONE threshold + end + + + fun vertex_map g vs f needOut = + if needOut then vertex_map_ g vs f + else (vertex_foreach g vs; vs) + +end diff --git a/tests/mpllib/ArraySequence.sml b/tests/mpllib/ArraySequence.sml new file mode 100644 index 000000000..952e5b55a --- /dev/null +++ b/tests/mpllib/ArraySequence.sml @@ -0,0 +1,281 @@ +structure ArraySequence = +struct + + val for = Util.for + val par = ForkJoin.par + val parfor = ForkJoin.parfor + val alloc = ForkJoin.alloc + + + val GRAN = 5000 + + + structure A = + struct + open Array + type 'a t = 'a array + fun nth a i = sub (a, i) + end + + + structure AS = + struct + open ArraySlice + type 'a t = 'a slice + fun nth a i = sub (a, i) + end + + + type 'a t = 'a AS.t + type 'a seq = 'a t + + (* for compatibility across all sequence implementations *) + fun fromArraySeq s = s + fun toArraySeq s = s + fun force s = s + + + val nth = AS.nth + val length = AS.length + + fun empty () = AS.full (A.fromList []) + fun singleton x = AS.full (A.array (1, x)) + val $ = singleton + fun toString f s = + "<" ^ String.concatWith "," (List.tabulate (length s, f o nth s)) ^ ">" + + + fun fromArray a = AS.full a + + + fun fromList l = AS.full (A.fromList l) + val % = fromList + fun toList s = + SeqBasis.foldl (fn (list, x) => x :: list) [] + (0, length s) (fn i => nth s (length s - i - 1)) + + + fun subseq s (i, k) = + AS.subslice (s, i, SOME k) + fun take s n = subseq s (0, n) + fun drop s n = subseq s (n, length s - n) + fun first s = nth s 0 + fun last s = nth s (length s - 1) + + fun tabulate f n = + AS.full (SeqBasis.tabulate GRAN (0, n) f) + + + fun map f s = + tabulate (f o nth s) (length s) + + + fun mapIdx f s = + tabulate (fn i => f (i, nth s i)) (length s) + + + fun enum s = + mapIdx (fn xx => xx) s + + + fun zipWith f (s, t) = + tabulate + (fn i => f (nth s i, nth t i)) + (Int.min (length s, length t)) + + + fun zipWith3 f (s1, s2, s3) = + tabulate + (fn i => f (nth s1 
i, nth s2 i, nth s3 i)) + (Int.min (length s1, Int.min (length s2, length s3))) + + + fun zip (s, t) = + zipWith (fn xx => xx) (s, t) + + + fun rev s = + tabulate (fn i => nth s (length s - 1 - i)) (length s) + + + (** TODO: make faster *) + fun fromRevList list = rev (fromList list) + + + fun append (s, t) = + let + val (ns, nt) = (length s, length t) + fun ith i = if i < ns then nth s i else nth t (i-ns) + in + tabulate ith (ns + nt) + end + + + fun append3 (a, b, c) = + let + val (na, nb, nc) = (length a, length b, length c) + fun ith i = + if i < na then + nth a i + else if i < na + nb then + nth b (i - na) + else + nth c (i - na - nb) + in + tabulate ith (na + nb + nc) + end + + + fun foldl f b s = + SeqBasis.foldl f b (0, length s) (nth s) + + fun foldr f b s = + SeqBasis.foldr f b (0, length s) (nth s) + + fun iterate f b s = + SeqBasis.foldl f b (0, length s) (nth s) + + + fun iteratePrefixes f b s = + let + val prefixes = alloc (length s) + fun g ((i, b), a) = + let + val _ = A.update (prefixes, i, b) + in + (i+1, f (b, a)) + end + val (_, r) = iterate g (0, b) s + in + (AS.full prefixes, r) + end + + + fun reduce f b s = + SeqBasis.reduce GRAN f b (0, length s) (nth s) + + + fun scan f b s = + let + val p = AS.full (SeqBasis.scan GRAN f b (0, length s) (nth s)) + in + (take p (length s), nth p (length s)) + end + + fun scanWithTotal f b s = + AS.full (SeqBasis.scan GRAN f b (0, length s) (nth s)) + + fun scanIncl f b s = + let + val p = AS.full (SeqBasis.scan GRAN f b (0, length s) (nth s)) + in + drop p 1 + end + + + fun filter p s = + (* Assumes that the predicate p is pure *) + AS.full (SeqBasis.filter GRAN (0, length s) (nth s) (p o nth s)) + + fun filterSafe p s = + (* Does not assume that the predicate p is pure *) + AS.full (SeqBasis.tabFilter GRAN (0, length s) (fn i => if p (nth s i) then SOME (nth s i) else NONE)) + + fun filterIdx p s = + AS.full (SeqBasis.filter GRAN (0, length s) (nth s) (fn i => p (i, nth s i))) + + fun filtermap (p: 'a -> 
bool) (f:'a -> 'b) (s: 'a t): 'b t = + AS.full (SeqBasis.filter GRAN (0, length s) (fn i => f (nth s i)) (p o nth s)) + + + fun mapOption f s = + AS.full (SeqBasis.tabFilter GRAN (0, length s) (f o nth s)) + + + fun equal eq (s, t) = + length s = length t andalso + SeqBasis.reduce GRAN (fn (a, b) => a andalso b) true (0, length s) + (fn i => eq (nth s i, nth t i)) + + + fun inject (s, updates) = + let + val result = map (fn x => x) s + in + parfor GRAN (0, length updates) (fn i => + let + val (idx, r) = nth updates i + in + AS.update (result, idx, r) + end); + + result + end + + + fun applyIdx s f = + parfor GRAN (0, length s) (fn i => f (i, nth s i)) + + + fun foreach s f = applyIdx s f + + + fun indexSearch (start, stop, offset: int -> int) k = + case stop-start of + 0 => + raise Fail "ArraySequence.indexSearch: should not have hit 0" + | 1 => + start + | n => + let + val mid = start + (n div 2) + in + if k < offset mid then + indexSearch (start, mid, offset) k + else + indexSearch (mid, stop, offset) k + end + + + fun flatten s = + let + val offsets = SeqBasis.scan GRAN op+ 0 (0, length s) (length o nth s) + fun offset i = A.nth offsets i + val total = offset (length s) + val result = alloc total + + val blockSize = GRAN + val numBlocks = Util.ceilDiv total blockSize + in + parfor 1 (0, numBlocks) (fn blockIdx => + let + val lo = blockIdx * blockSize + val hi = Int.min (lo + blockSize, total) + + val firstOuterIdx = indexSearch (0, length s, offset) lo + val firstInnerIdx = lo - offset firstOuterIdx + + (** i = outer index + * j = inner index + * k = output index, ranges from [lo] to [hi] + *) + fun loop i j k = + if k >= hi then () else + let + val inner = nth s i + val numAvailableHere = length inner - j + val numRemainingInBlock = hi - k + val numHere = Int.min (numAvailableHere, numRemainingInBlock) + in + for (0, numHere) (fn z => A.update (result, k+z, nth inner (j+z))); + loop (i+1) 0 (k+numHere) + end + in + loop firstOuterIdx firstInnerIdx lo + end); + 
+ AS.full result + end + + +end diff --git a/tests/mpllib/AugMap.sml b/tests/mpllib/AugMap.sml new file mode 100644 index 000000000..78b907c1f --- /dev/null +++ b/tests/mpllib/AugMap.sml @@ -0,0 +1,509 @@ +datatype scheme = WB of real + +signature Aug = +sig + type key + type value (* can make this polymorhpic *) + type aug + val compare : key * key -> order + val g : key * value -> aug + val f : aug * aug -> aug + val id : aug + val balance : scheme + val debug : key * value * aug -> string +end + +signature AugMap = +sig + exception Assert + structure T : Aug + datatype am = Leaf | Node of {l : am, k : T.key, v : T.value, a : T.aug, r : am, size : int} + val empty : unit -> am + val size : am -> int + val find : am -> T.key -> T.value option + val union : am -> am -> (T.value * T.value -> T.value) -> am + val insert : am -> T.key -> T.value -> (T.value * T.value -> T.value) -> am + val multi_insert : am -> ((T.key * T.value) Seq.t) -> (T.value * T.value -> T.value) -> am + val mapReduce : am -> (T.key * T.value -> 'b) -> ('b * 'b -> 'b)-> 'b -> 'b + val join : am -> T.key -> T.value -> am -> am + val filter : (T.key * T.value -> bool) -> am -> am + val build : ((T.key * T.value) Seq.t) -> int -> int -> am + val aug_left : am -> T.key -> T.aug + val aug_filter : am -> (T.aug -> bool) -> am + val aug_range : am -> T.key -> T.key -> T.aug + val aug_project : (T.aug -> 'a) -> ('a * 'a -> 'a) -> am -> T.key -> T.key -> 'a + val up_to : am -> T.key -> am + val print_tree : am -> string -> unit + val singleton : T.key -> T.value -> am +end + +functor PAM (T: Aug) : AugMap = +struct + type key = T.key + type value = T.value + type aug = T.aug + + (* how to add the metric for balancing? 
*) + datatype am = Leaf | Node of {l : am, k : key, v : value, a : aug, r : am, size : int} + exception Assert + structure T = T + (* | FatLeaf (array ) in order traversal gran + leaves*) + (* joinG that takes the grain and join which does something itself *) + (* maybe not store augmented values for thin leaves*) + (* fat leaves trick -- use leaves of arrays *) + val gran = 100 + + fun weight m = + case m of + Leaf => 0 + | Node n => #size n + + fun size m = + case m of + Leaf => 0 + | Node n => #size n + + fun aug_val m = + case m of + Leaf => T.id + | Node {a, ...} => a + + structure WBST = + struct + + fun leaf_weight () = 0 + + fun singleton_weight () = 1 + + val ratio = + case T.balance of + WB (x) => x / (1.0 - x) + + fun size_heavy s1 s2 = + ratio * (Real.fromInt s1) > (Real.fromInt s2) + + fun heavy (m1 : am, m2 : am) = size_heavy (weight m1) (weight m2) + + fun like s1 s2 = not (size_heavy s1 s2) andalso not (size_heavy s2 s1) + + fun compose m1 k v m2 = + let + val new_size = (size m1) + (size m2) + 1 + val a = T.f ((aug_val m1), T.f (T.g(k, v), (aug_val m2))) + in + Node {l = m1, k = k, v = v, a = a, r = m2, size = new_size} + end + + fun rotateLeft m = + case m of + Leaf => m + | Node {l, k, v, a, r, size} => + case r of + Leaf => m + | Node {l = rl, k = rk, v = rv, r = rr, ...} => + let + val left = compose l k v rl + in + compose left rk rv rr + end + + fun rotateRight m = + case m of + Leaf => m + | Node {l, k, v, a, r, size} => + case l of + Leaf => m + | Node {l = ll, k = lk, v = lv, r = lr, ...} => + let + val right = compose lr k v r + in + compose ll lk lv right + end + + fun joinLeft (m1 : am) k v (m2 : am) = + let + val w1 = weight m1 + val w2 = weight m2 + in + if (like w1 w2) then compose m1 k v m2 + else + case m2 of + Leaf => compose m1 k v m2 + | Node {l, k = kr, v = vr, r, size, ...} => + let + val t' = joinLeft m1 k v l + val (wlt', wrt') = case t' of + Leaf => raise Assert + | Node {l, r, ...} => ((weight l), (weight r)) + val wr = 
weight r + in + if like (weight t') wr then compose t' kr vr r + else if like wrt' wr andalso like wlt' (wrt' + wr) then + rotateRight (compose t' kr vr r) + else rotateRight (compose (rotateLeft t') kr vr r) + end + end + + fun joinRight m1 k v m2 = + let + val w1 = weight m1 + val w2 = weight m2 + in + if like w1 w2 then compose m1 k v m2 + else + case m1 of + Leaf => compose m1 k v m2 + | Node {l, k = kl, v = vl, r, size, ...} => + let + val t' = joinRight r k v m2 + val (wlt', wrt') = case t' of + Leaf => raise Assert + | Node n => ((weight (#l n)), (weight (#r n))) + val wl = weight l + in + if like wl (weight t') then compose l kl vl t' + else if like wl wlt' andalso like (wl + wlt') wrt' then + rotateLeft (compose l kl vl t') + else rotateLeft (compose l kl vl (rotateRight t')) + end + end + + fun join m1 k v m2 = + if heavy(m1, m2) then joinRight m1 k v m2 + else if heavy(m2, m1) then joinLeft m1 k v m2 + else compose m1 k v m2 + + end + + fun par (f1, f2) = ForkJoin.par (f1, f2) + + fun eval (inpar, f1, f2) = + if inpar then ForkJoin.par (f1, f2) + else (f1(), f2()) + + fun join m1 k v m2 = + case T.balance of + WB _ => WBST.join m1 k v m2 + + fun join2 m1 m2 = + let + fun splitLast {l, k, v, a, r, size, ...} = + case r of + Leaf => (l, k, v) + | Node n => + let + val (m', k', v') = splitLast n + in + (join l k v m', k', v') + end + in + case m1 of + Leaf => m2 + | Node n => + let + val (m', k', v') = splitLast n (*get the greatest element in m1*) + in + join m' k' v' m2 + end + end + + fun empty () = Leaf + + fun singleton k v = + Node {l = Leaf, k = k, v = v, a = T.g (k, v), r = Leaf, size = 1} + + fun build_sorted s i j = + if i = j then empty() + else if i + 1 = j then singleton (#1 (Seq.nth s i)) (#2 (Seq.nth s i)) + else + let + val m = i + Int.div ((j - i), 2) + val (l, r) = if (j - i) > 1000 then eval (true, fn _ => build_sorted s i m, fn _ => build_sorted s (m + 1) j) + else (build_sorted s i m, build_sorted s (m + 1) j) + val (x, y) = Seq.nth s m 
+    in
+      join l x y r
+    end
+
+  (* Standard BST lookup by key; NONE if k is absent. *)
+  fun find m k =
+    case m of
+      Leaf => NONE
+    | Node ({l, k = kr, v, r, ...}) =>
+        case T.compare (k, kr) of
+          LESS => find l k
+        | EQUAL => SOME v
+        | GREATER => find r k
+
+  (* Insert (k, v); if k is already present, the stored value becomes
+   * h (v, old). Rebalancing is delegated to join. *)
+  fun insert m k v h =
+    case m of
+      Leaf => singleton k v
+    | Node {l, k = kr, v = vr, r, ...} =>
+        case T.compare (k, kr) of
+          EQUAL => join l k (h (v, vr)) r
+        | LESS => join (insert l k v h) kr vr r
+        | GREATER => join l kr vr (insert r k v h)
+
+  fun key_equal k1 k2 = T.compare (k1, k2) = EQUAL
+
+  (* Bulk-insert the pairs of s into m, combining values of equal keys with h.
+   * Sorts s by key once, then recursively partitions the sorted range around
+   * the current root key and recurses into both subtrees in parallel. *)
+  fun multi_insert m s h =
+    let
+      val ss = Mergesort.sort (fn (i, j) => T.compare (#1 i, #1 j)) s
+      fun insert_helper m' i j =
+        if i >= j then
+          m'
+        else if (j - i) < gran then
+          (* small range: sequential inserts are cheaper than partitioning *)
+          SeqBasis.foldl (fn (m'', (k, v)) => insert m'' k v h) m' (i, j) (Seq.nth ss)
+        else
+          case m' of
+            Leaf =>
+              (* BUGFIX: i and j are indices into the *sorted* sequence ss
+               * (see insert_helper's other branches and bin_search below),
+               * and build_sorted requires sorted input; the original indexed
+               * the unsorted s here, producing an out-of-order tree.
+               * NOTE(review): duplicate keys within [i, j) are not combined
+               * with h on this path — confirm callers deduplicate keys. *)
+              build_sorted ss i j
+          | Node {l, k = kr, v = vr, r, ...} =>
+              let
+                fun bin_search k i j =
+                  (* inv: i < j, all occurrences of k inside [i, j) *)
+                  (* returns [lk, rk): every element in the range has key = k *)
+                  if (i + 1 = j) then
+                    if key_equal k (#1 (Seq.nth ss i)) then (i, j)
+                    else (i + 1, j)
+                  else
+                    let
+                      val mid = i + Int.div ((j - i), 2) (* i < mid < j *)
+                      val mid_val = Seq.nth ss mid
+                      (* step index i by f while e holds at ss[i], clamped to [b1, b2) *)
+                      fun until_boundary b1 b2 e i f =
+                        if i < b1 then b1
+                        else if i >= b2 then b2 - 1
+                        else if (e (Seq.nth ss i)) then
+                          until_boundary b1 b2 e (f i) f
+                        else i
+                    in
+                      case T.compare (k, #1 mid_val) of
+                        LESS => bin_search k i mid
+                      | EQUAL =>
+                          let
+                            val bound_func = until_boundary i j (fn t => key_equal k (#1 t)) mid
+                          in
+                            (bound_func (fn i => i - 1), 1 + bound_func (fn i => i + 1))
+                          end
+                      | GREATER => bin_search k mid j
+                    end
+                val (lk, rk) = bin_search kr i j
+                val (l, r) = par (fn _ => insert_helper l i lk,
+                                  fn _ => insert_helper r rk j)
+                (* fold every value whose key equals the root key into the root *)
+                val nvr = SeqBasis.foldl (fn (v', (k, v)) => h (v, v')) vr (lk, rk) (Seq.nth ss)
+              in
+                join l kr nvr r
+              end
+    in
+      insert_helper m 0 (Seq.length ss)
+    end
+
+  fun split m k =
+    case m of
+      Leaf =>
(Leaf, NONE, Leaf) + | Node {l, k = kr, v = vr, r, ...} => + case T.compare(k, kr) of + EQUAL => (l, SOME vr, r) + | LESS => + let + val (ll, so, lr) = split l k + in + (ll, so, join lr kr vr r) + end + | GREATER => + let + val (rl, so, rr) = split r k + in + (join l kr vr rl, so, rr) + end + + (* this is not efficient because each join is more expensive *) + (* fun split_tail_rec m k = + let + fun split_helper m accl accr = + case m of + Leaf => (accl, NONE, accr) + | Node {l, k = kr, v = vr, r, ...} => + case T.compare(k, kr) of + EQUAL => (join2 l accl, SOME vr, join2 r accr) + | LESS => split_helper l accl (join2 (singleton kr vr) (join2 r accr)) + | GREATER => split_helper r (join2 (singleton kr vr) (join2 accl l)) accr + in + split_helper m (empty()) (empty()) + end *) + + fun union m1 m2 h = + case (m1, m2) of + (Leaf, _) => m2 + | (_, Leaf) => m1 + | (Node {l = l1, k = k1, v = v1, r = r1, size, ...}, m2) => + let + val (l2, so, r2) = split m2 k1 + val new_val = case so of + NONE => v1 + | SOME v' => h(v1, v') + + val (ul, ur) = eval(size > gran, fn _ => union l1 l2 h, fn _ => union r1 r2 h) + in + join ul k1 new_val ur + end + + fun filter h m = + case m of + Leaf => m + | Node {l, k, v, r, ...} => + let + val (l', r') = par(fn _ => filter h l, fn _ => filter h r) + in + if h(k, v) then join l' k v r' + else join2 l' r' + end + + fun mapReduce m g f id = + case m of + Leaf => id + | Node {l, k, v, r, size, ...} => + let + val (l', r') = eval(size > gran, fn _ => mapReduce l g f id, fn _ => mapReduce r g f id) + in + f(f(l', g(k, v)), r') + end + + fun build ss i j = + let + + val t0 = Time.now () + (* val _ = Mergesort.sortInPlace (fn (i, j) => T.compare(#1 i, #1 j)) ss *) + + val idx = Seq.tabulate (fn i => i) (Seq.length ss) + fun cmp (x, y) = T.compare (#1 (Seq.nth ss x), #1 (Seq.nth ss y)) + + val idx' = idx + (* val idx' = Dedup.dedup (fn (i, j) => cmp (i, j) = EQUAL) (fn i => Util.hash64 (Word64.fromInt i)) (fn i =>Util.hash64 (Word64.fromInt (i + 
1))) idx *) + + val _ = Mergesort.sortInPlace cmp idx' + + val ss' = Seq.map (Seq.nth ss) idx' + + val t1 = Time.now() + val _ = print ("sorting time = " ^ Time.fmt 4 (Time.-(t1, t0)) ^ "s\n") + val t2 = Time.now() + val r = build_sorted ss' i (Seq.length ss') + val t3 = Time.now() + val _ = print ("from sorted time = " ^ Time.fmt 4 (Time.-(t3, t2)) ^ "s\n") + in + r + end + + fun up_to m k = + case m of + Leaf => m + | Node {l, k = k', r, v, ...} => + case T.compare (k, k') of + LESS => up_to l k + | EQUAL => join2 l (singleton k' v) + | GREATER => join l k' v (up_to r k) + + fun aug_filter m h = + case m of + Leaf => m + | Node {l, k, v, r, a, size, ...} => + let + val (l', r') = eval (size > gran, fn _ => aug_filter l h, fn _ => aug_filter r h) + in + if h (T.g (k, v)) then join l' k v r' + else join2 l' r' + end + + (* case m of + Leaf => T.id + | Node {l, k, v, r, size, a, ...} => + case (T.compare (k, k1), T.compare (k, k2)) of + (LESS, _) => aug_range r k1 k2 + | (EQUAL, _) => T.f ((aug_range r k1 k2), T.g (k, v)) + | (_, EQUAL) => T.f ((aug_range l k1 k2), T.g (k, v)) + | (_, GREATER) => aug_range l k1 k2 + | (GREATER, LESS) => + let + val (lval, rval) = eval (size > gran, fn _ => aug_range l k1 k2, fn _ => aug_range r k1 k2) + in + T.f (lval, T.f (T.g (k, v), rval)) + end *) + + fun aug_project (ga : T.aug -> 'a) fa m (k1: T.key) (k2 : T.key) = + let + val default_val : 'a = ga T.id + fun until_root_in_range m k1 k2 = + case m of + Leaf => m + | Node {l, k, r, ...} => + case (T.compare (k, k1), T.compare (k, k2)) of + (LESS, _) => until_root_in_range r k1 k2 + | (_, GREATER) => until_root_in_range l k1 k2 + | _ => m + + fun compose_map_kv m k v = fa (ga (aug_val m), ga (T.g(k, v))) + + fun compose_kv_map k v m = fa (ga (T.g(k, v)), ga (aug_val m)) + + fun aug_proj_left m k acc = + case m of + Leaf => acc + | Node {l, k = k', v, r, ...} => + case T.compare (k, k') of + LESS => aug_proj_left l k acc + | EQUAL => fa (acc, compose_map_kv l k v) + | GREATER => 
aug_proj_left r k (fa (acc, compose_map_kv l k v)) + + fun aug_proj_right m k acc = + case m of + Leaf => acc + | Node {l, k = k', v, r, ...} => + case T.compare (k, k') of + LESS => aug_proj_right l k (fa (acc, compose_kv_map k v r)) + | EQUAL => fa (acc, compose_kv_map k v r) + | GREATER => aug_proj_right r k acc + + fun aug_proj m = + let + val sm = until_root_in_range m k1 k2 + in + case sm of + Leaf => default_val + | Node {l, k, v, r, ...} => + let + val ra = aug_proj_left r k2 default_val + val la = aug_proj_right l k1 default_val + val ka = ga (T.g (k, v)) + in + fa (fa (la, ka), ra) + end + end + in + aug_proj m + end + + fun aug_range m k1 k2 = aug_project (fn x => x) (T.f) m k1 k2 + + fun aug_left m k = + case m of + Leaf => T.id + | Node {l, k = k', v, r, ...} => + case T.compare(k, k') of + LESS => aug_left l k + | EQUAL => T.f (aug_val l, T.g(k, v)) + | GREATER => T.f (aug_val l, T.f (T.g(k, v), aug_left r k)) + + fun print_tree m indent = + case m of + Leaf => print (indent ^ "Leaf") + | Node {l, k, v, r, a, size} => + let + val _ = print "(" + val _ = print_tree l (indent^"") + val _ = print (", " ^(T.debug (k, v, a))) + val _ = print (", weight = " ^ (Int.toString size) ^ ",") + val _ = print_tree r (indent^"") + in + print ")" + end + +end diff --git a/tests/mpllib/Benchmark.sml b/tests/mpllib/Benchmark.sml new file mode 100644 index 000000000..e32a05f55 --- /dev/null +++ b/tests/mpllib/Benchmark.sml @@ -0,0 +1,83 @@ +structure Benchmark = +struct + + fun getTimes msg n f = + let + fun loop tms n = + let + val (result, tm) = Util.getTime f + in + print (msg ^ " " ^ Time.fmt 4 tm ^ "s\n"); + + if n <= 1 then (result, List.rev (tm :: tms)) + else loop (tm :: tms) (n - 1) + end + in + loop [] n + end + + fun run msg f = + let + val warmup = Time.fromReal (CommandLineArgs.parseReal "warmup" 0.0) + val rep = CommandLineArgs.parseInt "repeat" 1 + val _ = if rep >= 1 then () else Util.die "-repeat N must be at least 1" + + val _ = print ("warmup " ^ 
Time.fmt 4 warmup ^ "\n") + val _ = print ("repeat " ^ Int.toString rep ^ "\n") + + fun warmupLoop startTime = + if Time.>= (Time.- (Time.now (), startTime), warmup) then + () (* warmup done! *) + else + let val (_, tm) = Util.getTime f + in print ("warmup_run " ^ Time.fmt 4 tm ^ "s\n"); warmupLoop startTime + end + + val _ = + if Time.<= (warmup, Time.zeroTime) then + () + else + ( print ("====== WARMUP ======\n" ^ msg ^ "\n") + ; warmupLoop (Time.now ()) + ; print ("==== END WARMUP ====\n") + ) + + val _ = print (msg ^ "\n") + val s0 = RuntimeStats.get () + val t0 = Time.now () + val (result, tms) = getTimes "time" rep f + val t1 = Time.now () + val s1 = RuntimeStats.get () + val endToEnd = Time.- (t1, t0) + + fun stdev rtms avg = + let + val SS = List.foldr (fn (a, b) => (a - avg) * (a - avg) + b) 0.0 rtms + val sample = Real.fromInt (List.length rtms - 1) + in + Math.sqrt (SS / sample) + end + + val rtms = List.map Time.toReal tms + val total = List.foldl Time.+ Time.zeroTime tms + val avg = Time.toReal total / (Real.fromInt rep) + val std = if rep > 1 then stdev rtms avg else 0.0 + val tmax = Time.toReal + (List.foldl (fn (a, M) => if Time.< (a, M) then M else a) (List.hd tms) + (List.tl tms)) + val tmin = Time.toReal + (List.foldl (fn (a, m) => if Time.< (a, m) then a else m) (List.hd tms) + (List.tl tms)) + in + print "\n"; + print ("average " ^ Real.fmt (StringCvt.FIX (SOME 4)) avg ^ "s\n"); + print ("minimum " ^ Real.fmt (StringCvt.FIX (SOME 4)) tmin ^ "s\n"); + print ("maximum " ^ Real.fmt (StringCvt.FIX (SOME 4)) tmax ^ "s\n"); + print ("std dev " ^ Real.fmt (StringCvt.FIX (SOME 4)) std ^ "s\n"); + print ("total " ^ Time.fmt 4 total ^ "s\n"); + print ("end-to-end " ^ Time.fmt 4 endToEnd ^ "s\n"); + RuntimeStats.benchReport {before = s0, after = s1}; + result + end + +end diff --git a/tests/mpllib/BinarySearch.sml b/tests/mpllib/BinarySearch.sml new file mode 100644 index 000000000..6b07ef8b5 --- /dev/null +++ b/tests/mpllib/BinarySearch.sml @@ -0,0 +1,81 
@@
+structure BinarySearch:
+sig
+  type 'a seq = 'a ArraySlice.slice
+  val search: ('a * 'a -> order) -> 'a seq -> 'a -> int
+
+  (* count the number of elements strictly less than the target *)
+  val countLess: ('a * 'a -> order) -> 'a seq -> 'a -> int
+
+  (** Sometimes, you aren't looking for a particular element, but instead just
+    * some position in the sequence. The function ('a -> order) is used here to
+    * point towards the target position.
+    *
+    * Note that this is more general than the plain `search` function, because
+    * we can implement `search` in terms of `searchPosition`:
+    *   fun search cmp s x = searchPosition s (fn y => cmp (x, y))
+    *)
+  val searchPosition: 'a seq -> ('a -> order) -> int
+end =
+struct
+
+  type 'a seq = 'a ArraySlice.slice
+
+  (* Classic binary search on the sorted slice s: index of an element equal
+   * to x, or the insertion point for x if no such element exists. *)
+  fun search cmp s x =
+    let
+      fun loop lo hi =
+        case hi - lo of
+          0 => lo
+        | n =>
+            let
+              val mid = lo + n div 2
+              val pivot = ArraySlice.sub (s, mid)
+            in
+              case cmp (x, pivot) of
+                LESS => loop lo mid
+              | EQUAL => mid
+              | GREATER => loop (mid+1) hi
+            end
+    in
+      loop 0 (ArraySlice.length s)
+    end
+
+  (* Like search, but EQUAL keeps moving left, so the result counts the
+   * elements strictly less than x. *)
+  fun countLess cmp s x =
+    let
+      fun loop lo hi =
+        case hi - lo of
+          0 => lo
+        | n =>
+            let
+              val mid = lo + n div 2
+              val pivot = ArraySlice.sub (s, mid)
+            in
+              case cmp (x, pivot) of
+                GREATER => loop (mid+1) hi
+              | _ => loop lo mid
+            end
+    in
+      loop 0 (ArraySlice.length s)
+    end
+
+  fun searchPosition s compareTargetAgainst =
+    let
+      fun loop lo hi =
+        case hi - lo of
+          0 => lo
+        | n =>
+            let
+              val mid = lo + n div 2
+              val pivot = ArraySlice.sub (s, mid)
+            in
+              case compareTargetAgainst pivot of
+                LESS => loop lo mid
+              | EQUAL => mid
+              | GREATER => loop (mid+1) hi
+            end
+    in
+      loop 0 (ArraySlice.length s)
+    end
+
+end
diff --git a/tests/mpllib/CheckSort.sml b/tests/mpllib/CheckSort.sml
new file mode 100644
index 000000000..ba9cc112d
--- /dev/null
+++ b/tests/mpllib/CheckSort.sml
@@ -0,0 +1,77 @@
+functor CheckSort
+  (val sort_func: ('a * 'a -> order) -> 'a Seq.t -> 'a Seq.t):
+sig
+  datatype 'a error =
+    LengthChange (* output length differs from input *)
+  | MissingElem of int (* index of missing input element *)
+  | Inversion of int * int (* indices of two elements not in order in output *)
+  | Unstable of int * int (* indices of two equal swapped elements *)
+
+  val check:
+    { input: 'a Seq.t
+    , compare: 'a * 'a -> order
+    , check_stable: bool
+    }
+    -> 'a error option (* NONE if correct *)
+end =
+struct
+
+  type 'a seq = 'a Seq.t
+
+  datatype 'a error =
+    LengthChange (* output length differs from input *)
+  | MissingElem of int (* index of missing input element *)
+  | Inversion of int * int (* indices of two elements not in order in output *)
+  | Unstable of int * int (* indices of two equal swapped elements *)
+
+  type 'a check_input =
+    { input: 'a seq
+    , compare: 'a * 'a -> order
+    , check_stable: bool
+    }
+
+  (* Sort a copy of the input tagged with original indices, then report the
+   * first problem found: a length change, a missing element, an inversion,
+   * or (when check_stable) an instability. NONE means the output is correct. *)
+  fun check ({input, compare, check_stable}: 'a check_input) =
+    let
+      val n = Seq.length input
+      val input' = Seq.mapIdx (fn (i, x) => (i, x)) input
+      fun compare' ((i1, k1), (i2, k2)) = compare (k1, k2)
+
+      val result = sort_func compare' input'
+
+      (* present[i] = true iff original index i occurs in the output.
+       * Concurrent writes all store `true`, so races are benign. *)
+      val present = Array.array (n, false)
+      val _ = Seq.foreach result (fn (_, (i, _)) =>
+        if 0 <= i andalso i < n then Array.update (present, i, true) else ())
+
+      fun adjacentPairProblem ((i1, k1), (i2, k2)) =
+        case compare (k1, k2) of
+          LESS => NONE
+        | GREATER => SOME (Inversion (i1, i2))
+        | EQUAL =>
+            if check_stable andalso i1 > i2 then
+              SOME (Unstable (i1, i2))
+            else
+              NONE
+
+      fun problemAt i =
+        adjacentPairProblem (Seq.nth result i, Seq.nth result (i+1))
+      fun isProblem i = Option.isSome (problemAt i)
+    in
+      (* BUGFIX: the original computed element presence (noElemsMissing) but
+       * never consulted it, so LengthChange and MissingElem were unreachable
+       * despite being documented error cases. *)
+      if Seq.length result <> n then
+        SOME LengthChange
+      else
+        case FindFirst.findFirst 1000 (0, n) (fn i => not (Array.sub (present, i))) of
+          SOME i => SOME (MissingElem i)
+        | NONE =>
+            case FindFirst.findFirst 1000 (0, n-1) isProblem of
+              NONE => NONE
+            | SOME i => problemAt i
+    end
+
+end
diff --git a/tests/mpllib/ChunkedTreap.sml b/tests/mpllib/ChunkedTreap.sml
new file mode 100644
index 000000000..8b7088ae0
--- /dev/null
+++ b/tests/mpllib/ChunkedTreap.sml
@@ -0,0 +1,186 @@
+functor ChunkedTreap(
+  structure A:
sig + type 'a t + type 'a array = 'a t + val tabulate: int * (int -> 'a) -> 'a t + val sub: 'a t * int -> 'a + val subseq: 'a t -> {start: int, len: int} -> 'a t + val update: 'a t * int * 'a -> 'a t + val length: 'a t -> int + end + + structure Key: + sig + type t + type key = t + val comparePriority: key * key -> order + val compare: key * key -> order + val toString: key -> string + end + + val leafSize: int +) : +sig + type 'a t + type 'a bst = 'a t + type key = Key.t + + val empty: unit -> 'a bst + val singleton: key * 'a -> 'a bst + + val toString: 'a bst -> string + + val size: 'a bst -> int + val join: 'a bst * 'a bst -> 'a bst + val updateKey: 'a bst -> (key * 'a) -> 'a bst + val lookup: 'a bst -> key -> 'a option +end = +struct + + type key = Key.t + + datatype 'a t = + Empty + | Chunk of (key * 'a) A.t + | Node of {left: 'a t, right: 'a t, key: key, value: 'a, size: int} + + type 'a bst = 'a t + + fun toString t = + case t of + Empty => "()" + | Chunk arr => + "(" ^ String.concatWith " " + (List.tabulate (A.length arr, fn i => Key.toString (#1 (A.sub (arr, i))))) + ^ ")" + | Node {left, key, right, ...} => + "(" ^ toString left ^ " " ^ Key.toString key ^ " " ^ toString right ^ ")" + + fun empty () = Empty + + fun singleton (k, v) = + Chunk (A.tabulate (1, fn _ => (k, v))) + + fun size Empty = 0 + | size (Chunk a) = A.length a + | size (Node {size=n, ...}) = n + + fun makeChunk arr {start, len} = + if len = 0 then + Empty + else + Chunk (A.subseq arr {start = start, len = len}) + + fun expose Empty = NONE + + | expose (Node {key, value, left, right, ...}) = + SOME (left, key, value, right) + + | expose (Chunk arr) = + let + val n = A.length arr + val half = n div 2 + val left = makeChunk arr {start = 0, len = half} + val right = makeChunk arr {start = half+1, len = n-half-1} + val (k, v) = A.sub (arr, half) + in + SOME (left, k, v, right) + end + + + fun nth t i = + case t of + Empty => raise Fail "ChunkedTreap.nth Empty" + | Chunk arr => A.sub (arr, i) + 
| Node {left, right, key, value, ...} => + if i < size left then + nth left i + else if i = size left then + (key, value) + else + nth right (i - size left - 1) + + + fun makeNode left (k, v) right = + let + val n = size left + size right + 1 + val node = Node {left=left, right=right, key=k, value=v, size=n} + in + if n <= leafSize then + Chunk (A.tabulate (n, nth node)) + else + node + end + + + fun join (t1, t2) = + if size t1 + size t2 = 0 then + Empty + else if size t1 + size t2 <= leafSize then + Chunk (A.tabulate (size t1 + size t2, fn i => + if i < size t1 then nth t1 i else nth t2 (i - size t1))) + else + case (expose t1, expose t2) of + (NONE, _) => t2 + | (_, NONE) => t1 + | (SOME (l1, k1, v1, r1), SOME (l2, k2, v2, r2)) => + case Key.comparePriority (k1, k2) of + GREATER => makeNode l1 (k1, v1) (join (r1, t2)) + | _ => makeNode (join (t1, l2)) (k2, v2) r2 + + local + fun fail k = + raise Fail ("ChunkedTreap.updateKey: key not found: " ^ Key.toString k) + in + fun updateKey t (k, v) = ((*print ("updateKey " ^ toString t ^ " (" ^ Key.toString k ^ ", ...)" ^ "\n");*) + case t of + Empty => fail k + + | Chunk arr => + let + val n = A.length arr + fun loop i = + if i >= n then + fail k + else case Key.compare (k, #1 (A.sub (arr, i))) of + EQUAL => A.update (arr, i, (k, v)) + | GREATER => loop (i+1) + | LESS => fail k + in + Chunk (loop 0) + end + + | Node {left, right, key, value, ...} => + case Key.compare (k, key) of + LESS => makeNode (updateKey left (k, v)) (key, value) right + | GREATER => makeNode left (key, value) (updateKey right (k, v)) + | EQUAL => makeNode left (k, v) right) + end + + + fun lookup t k = + case t of + Empty => NONE + + | Chunk arr => + let + val n = A.length arr + fun loop i = + if i >= n then + NONE + else case Key.compare (k, #1 (A.sub (arr, i))) of + EQUAL => SOME (#2 (A.sub (arr, i))) + | GREATER => loop (i+1) + | LESS => NONE + in + loop 0 + end + + | Node {left, right, key, value, ...} => + case Key.compare (k, key) of + LESS 
=> lookup left k + | GREATER => lookup right k + | EQUAL => SOME value + +end diff --git a/tests/mpllib/Color.sml b/tests/mpllib/Color.sml new file mode 100644 index 000000000..b3b0b3992 --- /dev/null +++ b/tests/mpllib/Color.sml @@ -0,0 +1,176 @@ +structure Color = +struct + type channel = Word8.word + type pixel = {red: channel, green: channel, blue: channel} + + val white: pixel = {red=0w255, green=0w255, blue=0w255} + val black: pixel = {red=0w0, green=0w0, blue=0w0} + val red: pixel = {red=0w255, green=0w0, blue=0w0} + val blue: pixel = {red=0w0, green=0w0, blue=0w255} + + fun packInt ({red, green, blue}: pixel) = + let + fun b x = Word8.toInt x + in + (65536 * b red) + (256 * b green) + b blue + end + + fun compare (p1, p2) = Int.compare (packInt p1, packInt p2) + + fun equal (p1, p2) = (compare (p1, p2) = EQUAL) + + (* Based on the "low-cost approximation" given at + * https://www.compuphase.com/cmetric.htm *) + fun approxHumanPerceptionDistance + ({red=r1, green=g1, blue=b1}, {red=r2, green=g2, blue=b2}) = + let + fun c x = Word8.toInt x + val (r1, g1, b1, r2, g2, b2) = (c r1, c g1, c b1, c r2, c g2, c b2) + + val rmean = (r1 + r2) div 2 + val r = r1 - r2 + val g = g1 - g2 + val b = b1 - b2 + in + Math.sqrt (Real.fromInt + ((((512+rmean)*r*r) div 256) + 4*g*g + (((767-rmean)*b*b) div 256))) + end + + local + fun c x = Word8.toInt x + fun sq x = x * x + in + fun sqDistance ({red=r1, green=g1, blue=b1}, {red=r2, green=g2, blue=b2}) = + sq (c r2 - c r1) + sq (c g2 - c g1) + sq (c b2 - c b1) + end + + fun distance (p1, p2) = Math.sqrt (Real.fromInt (sqDistance (p1, p2))) + + fun to256 rchannel = + Word8.fromInt (Real.ceil (rchannel * 255.0)) + + fun from256 channel = + Real.fromInt (Word8.toInt channel) / 255.0 + + + (** hue in range [0,360) + * 0-------60-------120-------180-------240-------300-------360 + * red yellow green cyan blue purple red + * + * saturation in range [0,1] + * 0--------------1 + * grayscale vibrant + * + * value in range [0,1] + * 
0--------------1 + * dark light + *) + fun hsv {h: real, s: real, v: real}: pixel = + let + val H = h + val S = s + val V = v + + (* from https://en.wikipedia.org/wiki/HSL_and_HSV#HSV_to_RGB *) + val C = V * S + val H' = H / 60.0 + val X = C * (1.0 - Real.abs (Real.rem (H', 2.0) - 1.0)) + + val (R1, G1, B1) = + if H' < 1.0 then (C, X, 0.0) + else if H' < 2.0 then (X, C, 0.0) + else if H' < 3.0 then (0.0, C, X) + else if H' < 4.0 then (0.0, X, C) + else if H' < 5.0 then (X, 0.0, C) + else (C, 0.0, X) + + val m = V - C + in + {red = to256 (R1 + m), green = to256 (G1 + m), blue = to256 (B1 + m)} + end + + (* ======================================================================= *) + + type color = {red: real, green: real, blue: real, alpha: real} + + (** hue in range [0,360) + * 0-------60-------120-------180-------240-------300-------360 + * red yellow green cyan blue purple red + * + * saturation in range [0,1] + * 0--------------1 + * grayscale vibrant + * + * value in range [0,1] + * 0--------------1 + * dark light + * + * alpha in range [0,1] + * 0--------------1 + * transparent opaque + *) + fun hsva {h: real, s: real, v: real, a: real}: color = + let + val H = h + val S = s + val V = v + + (* from https://en.wikipedia.org/wiki/HSL_and_HSV#HSV_to_RGB *) + val C = V * S + val H' = H / 60.0 + val X = C * (1.0 - Real.abs (Real.rem (H', 2.0) - 1.0)) + + val (R1, G1, B1) = + if H' < 1.0 then (C, X, 0.0) + else if H' < 2.0 then (X, C, 0.0) + else if H' < 3.0 then (0.0, C, X) + else if H' < 4.0 then (0.0, X, C) + else if H' < 5.0 then (X, 0.0, C) + else (C, 0.0, X) + + val m = V - C + in + {red = R1 + m, green = G1 + m, blue = B1 + m, alpha = a} + end + + fun overlayColor {fg: color, bg: color} = + let + val alpha = 1.0 - (1.0 - #alpha fg) * (1.0 - #alpha bg) + in + if alpha < 1e~6 then + (* essentially fully transparent, color doesn't matter *) + { red = 0.0, green = 0.0, blue = 0.0, alpha = alpha } + else + let + val red = + #red fg * #alpha fg / alpha + + #red bg * 
#alpha bg * (1.0 - #alpha fg) / alpha + + val green = + #green fg * #alpha fg / alpha + + #green bg * #alpha bg * (1.0 - #alpha fg) / alpha + + val blue = + #blue fg * #alpha fg / alpha + + #blue bg * #alpha bg * (1.0 - #alpha fg) / alpha + in + {red=red, green=green, blue=blue, alpha=alpha} + end + end + + (** Converts a color with transparency to a concrete pixel. Assumes white + * background. + *) + fun colorToPixel (color: color) : pixel = + let + val white = {red = 1.0, blue = 1.0, green = 1.0, alpha = 1.0} + (* alpha will be 1 *) + val {red, green, blue, ...} = overlayColor {fg = color, bg = white} + in + { red = to256 red, green = to256 green, blue = to256 blue} + end + + fun pixelToColor {red, green, blue} = + {red = from256 red, green = from256 green, blue = from256 blue, alpha = 1.0} + +end diff --git a/tests/mpllib/CommandLineArgs.sml b/tests/mpllib/CommandLineArgs.sml new file mode 100644 index 000000000..c89f89052 --- /dev/null +++ b/tests/mpllib/CommandLineArgs.sml @@ -0,0 +1,104 @@ +structure CommandLineArgs : +sig + (* each takes a key K and a default value D, looks for -K V in the + * command-line arguments, and returns V if it finds it, or D otherwise. *) + val parseString: string -> string -> string + val parseInt: string -> int -> int + val parseReal: string -> real -> real + val parseBool: string -> bool -> bool + + (** Look for every instance of -K V and return seq of the Vs. 
+ * For example, if this is given on the commandline: + * -arg a -arg b -arg c -arg d + * then + * parseStrings "arg" ==> ["a", "b", "c", "d"] + *) + val parseStrings: string -> string list + + (* parseFlag K returns true if --K given on command-line *) + val parseFlag: string -> bool + + val positional: unit -> string list +end = +struct + + fun die msg = + ( TextIO.output (TextIO.stdErr, msg ^ "\n") + ; TextIO.flushOut TextIO.stdErr + ; OS.Process.exit OS.Process.failure + ) + + fun positional () = + let + fun loop found rest = + case rest of + [] => List.rev found + | [x] => List.rev (if not (String.isPrefix "-" x) then x::found else found) + | x::y::rest' => + if not (String.isPrefix "-" x) then + loop (x::found) (y::rest') + else if String.isPrefix "--" x then + loop found (y::rest') + else + loop found rest' + in + loop [] (CommandLine.arguments ()) + end + + fun search key args = + case args of + [] => NONE + | x :: args' => + if key = x + then SOME args' + else search key args' + + fun parseString key default = + case search ("-" ^ key) (CommandLine.arguments ()) of + NONE => default + | SOME [] => die ("Missing argument of \"-" ^ key ^ "\" ") + | SOME (s :: _) => s + + fun parseStrings key = + let + fun loop args = + case search ("-" ^ key) args of + NONE => [] + | SOME [] => die ("Missing argument of \"-" ^ key ^ "\"") + | SOME (v :: args') => v :: loop args' + in + loop (CommandLine.arguments ()) + end + + fun parseInt key default = + case search ("-" ^ key) (CommandLine.arguments ()) of + NONE => default + | SOME [] => die ("Missing argument of \"-" ^ key ^ "\" ") + | SOME (s :: _) => + case Int.fromString s of + NONE => die ("Cannot parse integer from \"-" ^ key ^ " " ^ s ^ "\"") + | SOME x => x + + fun parseReal key default = + case search ("-" ^ key) (CommandLine.arguments ()) of + NONE => default + | SOME [] => die ("Missing argument of \"-" ^ key ^ "\" ") + | SOME (s :: _) => + case Real.fromString s of + NONE => die ("Cannot parse real from \"-" ^ 
key ^ " " ^ s ^ "\"") + | SOME x => x + + fun parseBool key default = + case search ("-" ^ key) (CommandLine.arguments ()) of + NONE => default + | SOME [] => die ("Missing argument of \"-" ^ key ^ "\" ") + | SOME ("true" :: _) => true + | SOME ("false" :: _) => false + | SOME (s :: _) => die ("Cannot parse bool from \"-" ^ key ^ " " ^ s ^ "\"") + + fun parseFlag key = + case search ("--" ^ key) (CommandLine.arguments ()) of + NONE => false + | SOME _ => true + +end diff --git a/tests/mpllib/CountingSort.sml b/tests/mpllib/CountingSort.sml new file mode 100644 index 000000000..71719a7da --- /dev/null +++ b/tests/mpllib/CountingSort.sml @@ -0,0 +1,129 @@ +structure CountingSort :> +sig + type 'a seq = 'a ArraySlice.slice + + val sort : 'a seq + -> (int -> int) (* bucket id of ith element *) + -> int (* number of buckets *) + -> 'a seq * int seq (* sorted, bucket offsets *) +end = +struct + + structure A = Array + structure AS = ArraySlice + + type 'a seq = 'a ArraySlice.slice + + val for = Util.for + val loop = Util.loop + val forBackwards = Util.forBackwards + + fun seqSortInternal In Out Keys Counts genOffsets = + let + val n = AS.length In + val m = AS.length Counts + (* val _ = print ("seqSortInternal n=" ^ Int.toString n ^ " m=" ^ Int.toString m ^ "\n") *) + val sub = AS.sub + val update = AS.update + in + for (0, m) (fn i => update (Counts,i,0)); + + for (0, n) (fn i => + let + val j = Keys i + (* val _ = print ("update " ^ Int.toString j ^ "\n") *) + in + update (Counts, j, sub(Counts,j) + 1) + end); + + (* print ("counts: " ^ Seq.toString Int.toString Counts ^ "\n"); *) + + loop (0, m) 0 (fn (s,i) => + let + val t = sub(Counts, i) + in + update(Counts, i, s); + s + t + end); + + (* print ("counts: " ^ Seq.toString Int.toString Counts ^ "\n"); *) + + for (0, n) (fn i => + let + val j = Keys(i) + val k = sub(Counts, j) + in + update(Counts, j, k+1); + update(Out, k, sub(In, i)) + end); + + if genOffsets then + (forBackwards (0,m-1) (fn i => + 
update(Counts,i+1,sub(Counts,i))); + update(Counts,0,0); 0) + else + loop (0, m) 0 (fn (s,i) => + let + val t = sub(Counts, i) + in + (update(Counts, i, t - s); t) + end) + end + + fun seqSort(In, Keys, numBuckets) = + let + val Counts = AS.full(ForkJoin.alloc (numBuckets+1)) + val Out = AS.full(ForkJoin.alloc (AS.length In)) + in + seqSortInternal In Out Keys (Seq.subseq Counts (0,numBuckets)) true; + AS.update(Counts, numBuckets, AS.length In); + (Out, Counts) + end + + fun sort In Keys numBuckets = + let + val SeqThreshold = 8192 + val BlockFactor = 32 + val n = AS.length In + (* pad to avoid false sharing *) + val numBucketsPad = Int.max(numBuckets, 16) + val sqrt = Real.floor(Math.sqrt(Real.fromInt n)) + val numBlocks = n div (numBuckets * BlockFactor) + in + if (numBlocks <= 1 orelse n < SeqThreshold) then + seqSort(In, Keys, numBuckets) + else let + val blockSize = ((n-1) div numBlocks) + 1; + val m = numBlocks * numBucketsPad + val B = AS.full(ForkJoin.alloc(AS.length In)) + val Counts = AS.full(ForkJoin.alloc(m)) + val _ = ForkJoin.parfor 1 (0, numBlocks) (fn i => + let + val start = Int.min(i * blockSize, n) + val len = Int.min((i+1)* blockSize, n) - start + in + seqSortInternal + (AS.subslice(In, start, SOME(len))) + (AS.subslice(B, start, SOME(len))) + (fn i => Keys(i+start)) + (AS.subslice(Counts,i*numBucketsPad,SOME(numBucketsPad))) + false; + () + end) + val (sourceOffsets, _) = Seq.scan op+ 0 Counts + val transCounts = SampleSort.transpose(Counts, numBlocks, + numBucketsPad) + val (destOffsets, _) = Seq.scan op+ 0 transCounts + val C = SampleSort.transposeBlocks(B, sourceOffsets, destOffsets, + Counts, numBlocks, numBucketsPad, n) + val bucketOffsets = + Seq.tabulate (fn i => + if (i = numBuckets) then n + else AS.sub (destOffsets, i * numBlocks)) + (numBuckets+1) + in + (C, bucketOffsets) + end + end + +end diff --git a/tests/mpllib/DelayedSeq.sml b/tests/mpllib/DelayedSeq.sml new file mode 100644 index 000000000..901235208 --- /dev/null +++ 
b/tests/mpllib/DelayedSeq.sml @@ -0,0 +1,414 @@ +functor MkDelayedSeq (Stream: STREAM) : SEQUENCE = +struct + + exception NYI + exception Range + exception Size + + (* structure Stream = DelayedStream *) + + val for = Util.for + val par = ForkJoin.par + val parfor = ForkJoin.parfor + val alloc = ForkJoin.alloc + + val gran = 5000 + val blockSize = 5000 + fun numBlocks n = Util.ceilDiv n blockSize + + fun blockStart b n = b * blockSize + fun blockEnd b n = Int.min (n, (b+1) * blockSize) + fun getBlockSize b n = blockEnd b n - blockStart b n + fun convertToBlockIdx i n = + (i div blockSize, i mod blockSize) + + structure A = + struct + open Array + type 'a t = 'a array + fun nth a i = sub (a, i) + end + + structure AS = + struct + open ArraySlice + type 'a t = 'a slice + fun nth a i = sub (a, i) + end + + + type 'a rad = int * int * (int -> 'a) + type 'a bid = int * (int -> 'a Stream.t) + datatype 'a seq = + Full of 'a AS.t + | Rad of 'a rad + | Bid of 'a bid + + type 'a t = 'a seq + + + fun radlength (start, stop, _) = stop-start + fun radnth (start, _, f) i = f (start+i) + + + fun length s = + case s of + Full slice => AS.length slice + | Rad rad => radlength rad + | Bid (n, _) => n + + + fun nth s i = + case s of + Full slice => AS.nth slice i + | Rad rad => radnth rad i + | Bid (n, getBlock) => + let + val (outer, inner) = convertToBlockIdx i n + in + Stream.nth (getBlock outer) inner + end + + + fun bidify (s: 'a seq) : 'a bid = + let + fun block start nth b = + Stream.tabulate (fn i => nth (start + b * blockSize + i)) + in + case s of + Full slice => + let + val (a, start, n) = AS.base slice + in + (n, block start (A.nth a)) + end + + | Rad (start, stop, nth) => + (stop-start, block start nth) + + | Bid xx => xx + end + + + fun applyIdx (s: 'a seq) (g: int * 'a -> unit) = + let + val (n, getBlock) = bidify s + in + parfor 1 (0, numBlocks n) (fn b => + let + val lo = blockStart b n + in + Stream.applyIdx (getBlockSize b n, getBlock b) (fn (j, x) => g (lo+j, x)) 
+ end) + end + + + fun apply (s: 'a seq) (g: 'a -> unit) = + applyIdx s (fn (_, x) => g x) + + + fun reify s = + let + val a = alloc (length s) + in + applyIdx s (fn (i, x) => A.update (a, i, x)); + AS.full a + end + + + fun force s = Full (reify s) + + + fun radify s = + case s of + Full slice => + let + val (a, i, n) = AS.base slice + in + (i, i+n, A.nth a) + end + + | Rad xx => xx + + | Bid (n, blocks) => + radify (force s) + + + fun tabulate f n = + Rad (0, n, f) + + + fun fromList xs = + Full (AS.full (Array.fromList xs)) + + + fun % xs = + fromList xs + + + fun singleton x = + Rad (0, 1, fn _ => x) + + + fun $ x = + singleton x + + + fun empty () = + fromList [] + + + fun fromArraySeq a = + Full a + + + fun range (i, j) = + Rad (i, j, fn k => k) + + + fun toArraySeq s = + case s of + Full x => x + | _ => reify s + + + fun map f s = + case s of + Full _ => map f (Rad (radify s)) + | Rad (i, j, g) => Rad (i, j, f o g) + | Bid (n, getBlock) => Bid (n, Stream.map f o getBlock) + + + fun mapIdx f s = + case s of + Full _ => mapIdx f (Rad (radify s)) + | Rad (i, j, g) => Rad (0, j-i, fn k => f (k, g (i+k))) + | Bid (n, getBlock) => + Bid (n, fn b => + Stream.mapIdx (fn (i, x) => f (b*blockSize + i, x)) (getBlock b)) + + + fun enum s = + mapIdx (fn (i,x) => (i,x)) s + + + fun flatten (ss: 'a seq seq) : 'a seq = + let + val numChildren = length ss + val children: 'a rad AS.t = reify (map radify ss) + val offsets = + SeqBasis.scan gran op+ 0 (0, numChildren) (radlength o AS.nth children) + val totalLen = A.nth offsets numChildren + fun offset i = A.nth offsets i + + val getBlock = + Stream.makeBlockStreams + { blockSize = blockSize + , numChildren = numChildren + , offset = offset + , getElem = (fn i => fn j => radnth (AS.nth children i) j) + } + in + Bid (totalLen, getBlock) + end + + + fun mapOption (f: 'a -> 'b option) (s: 'a seq) = + let + val (n, getBlock) = bidify s + val nb = numBlocks n + val packed: 'b rad array = + SeqBasis.tabulate 1 (0, nb) (fn b => + 
radify (Full (Stream.pack f (getBlockSize b n, getBlock b))) + ) + val offsets = + SeqBasis.scan gran op+ 0 (0, nb) (radlength o A.nth packed) + val totalLen = A.nth offsets nb + fun offset i = A.nth offsets i + + val getBlock' = + Stream.makeBlockStreams + { blockSize = blockSize + , numChildren = nb + , offset = offset + , getElem = (fn i => fn j => radnth (A.nth packed i) j) + } + in + Bid (totalLen, getBlock') + end + + + fun filter p s = + mapOption (fn x => if p x then SOME x else NONE) s + + + fun inject (s, u) = + let + val a = reify s + val (base, i, _) = AS.base a + in + apply u (fn (j, x) => Array.update (base, i+j, x)); + Full a + end + + + fun bidZipWith f (s1, s2) = + let + val (n, getBlock1) = bidify s1 + val (_, getBlock2) = bidify s2 + in + Bid (n, fn b => Stream.zipWith f (getBlock1 b, getBlock2 b)) + end + + fun radZipWith f (s1, s2) = + let + val (lo1, hi1, nth1) = radify s1 + val (lo2, _, nth2) = radify s2 + in + Rad (0, hi1-lo1, fn i => f (nth1 (lo1+i), nth2 (lo2+i))) + end + + fun zipWith f (s1, s2) = + if length s1 <> length s2 then raise Size else + case (s1, s2) of + (Bid _, _) => bidZipWith f (s1, s2) + | (_, Bid _) => bidZipWith f (s1, s2) + | _ => radZipWith f (s1, s2) + + fun zip (s1, s2) = + zipWith (fn (x, y) => (x, y)) (s1, s2) + + + fun scan f z s = + let + val (n, getBlock) = bidify s + val nb = numBlocks n + val blockSums = + SeqBasis.tabulate 1 (0, nb) (fn b => + Stream.iterate f z (getBlockSize b n, getBlock b) + ) + val p = SeqBasis.scan gran f z (0, nb) (A.nth blockSums) + val t = A.nth p nb + val r = Bid (n, fn b => Stream.iteratePrefixes f (A.nth p b) (getBlock b)) + in + (r, t) + end + + + fun scanIncl f z s = + let + val (n, getBlock) = bidify s + val nb = numBlocks n + val blockSums = + SeqBasis.tabulate 1 (0, nb) (fn b => + Stream.iterate f z (getBlockSize b n, getBlock b) + ) + val p = SeqBasis.scan gran f z (0, nb) (A.nth blockSums) + in + Bid (n, fn b => + Stream.iteratePrefixesIncl f (A.nth p b) (getBlock b)) + end 
+ + + fun reduce f z s = + case s of + Full xx => SeqBasis.reduce gran f z (0, length s) (AS.nth xx) + | Rad xx => SeqBasis.reduce gran f z (0, length s) (radnth xx) + | Bid (n, getBlock) => + let + val nb = numBlocks n + in + SeqBasis.reduce gran f z (0, nb) (fn b => + Stream.iterate f z (getBlockSize b n, getBlock b) + ) + end + + + fun iterate f z s = + case s of + Full xx => SeqBasis.foldl f z (0, length s) (AS.nth xx) + | Rad xx => SeqBasis.foldl f z (0, length s) (radnth xx) + | Bid (n, getBlock) => + Util.loop (0, numBlocks n) z (fn (z, b) => + Stream.iterate f z (getBlockSize b n, getBlock b)) + + + fun rev s = + let + val n = length s + val rads = radify s + in + tabulate (fn i => radnth rads (n-i-1)) n + end + + + fun append (s, t) = + let + val n = length s + val m = length t + + val rads = radify s + val radt = radify t + + fun elem i = if i < n then radnth rads i else radnth radt (i-n) + in + tabulate elem (n+m) + end + + + fun subseq s (i, len) = + if i < 0 orelse len < 0 orelse i+len > length s then + raise Subscript + else + let + val n = length s + val (start, stop, nth) = radify s + in + Rad (start+i, start+i+len, nth) + end + + + fun take s n = subseq s (0, n) + fun drop s n = subseq s (n, length s - n) + + + fun toList s = + List.rev (iterate (fn (elems, x) => x :: elems) [] s) + + fun toString f s = + "[" ^ String.concatWith "," (toList (map f s)) ^ "]" + + (* ===================================================================== *) + + datatype 'a listview = NIL | CONS of 'a * 'a seq + datatype 'a treeview = EMPTY | ONE of 'a | PAIR of 'a seq * 'a seq + + type 'a ord = 'a * 'a -> order + type 'a t = 'a seq + + fun filterIdx x = raise NYI + fun iterateIdx x = raise NYI + + fun argmax x = raise NYI + fun collate x = raise NYI + fun collect x = raise NYI + fun equal x = raise NYI + fun iteratePrefixes x = raise NYI + fun iteratePrefixesIncl x = raise NYI + fun merge x = raise NYI + fun sort x = raise NYI + fun splitHead x = raise NYI + fun 
splitMid x = raise NYI + fun update x = raise NYI + fun zipWith3 x = raise NYI + + fun filterSome x = raise NYI + fun foreach x = raise NYI + fun foreachG x = raise NYI + +end + + + +structure DelayedSeq = MkDelayedSeq (DelayedStream) +(* structure DelayedSeq = MkDelayedSeq (RecursiveStream) *) diff --git a/tests/mpllib/DelayedStream.sml b/tests/mpllib/DelayedStream.sml new file mode 100644 index 000000000..4ffc74569 --- /dev/null +++ b/tests/mpllib/DelayedStream.sml @@ -0,0 +1,225 @@ +structure DelayedStream :> STREAM = +struct + + (** A stream is a generator for a stateful trickle function: + * trickle = stream () + * x0 = trickle 0 + * x1 = trickle 1 + * x2 = trickle 2 + * ... + * + * The integer argument is just an optimization (it could be packaged + * up into the state of the trickle function, but doing it this + * way is more efficient). Requires passing `i` on the ith call + * to trickle. + *) + type 'a t = unit -> int -> 'a + type 'a stream = 'a t + + + fun nth stream i = + let + val trickle = stream () + + fun loop j = + let + val x = trickle j + in + if j = i then x else loop (j+1) + end + in + loop 0 + end + + + fun tabulate f = + fn () => f + + + fun map g stream = + fn () => + let + val trickle = stream () + in + g o trickle + end + + + fun mapIdx g stream = + fn () => + let + val trickle = stream () + in + fn idx => g (idx, trickle idx) + end + + + fun applyIdx (length, stream) g = + let + val trickle = stream () + fun loop i = + if i >= length then () else (g (i, trickle i); loop (i+1)) + in + loop 0 + end + + + fun resize arr = + let + val newCapacity = 2 * Array.length arr + val dst = ForkJoin.alloc newCapacity + in + Array.copy {src = arr, dst = dst, di = 0}; + dst + end + + + (** simple but less efficient: accumulate in list *) + (*fun pack f (length, stream) = + let + val trickle = stream () + + fun loop (data, count) i = + if i < length then + case f (trickle i) of + SOME y => + loop (y :: data, count+1) (i+1) + | NONE => + loop (data, count) 
(i+1) + else + (data, count) + + val (data, count) = loop ([], 0) 0 + in + ArraySlice.full (Array.fromList (List.rev data)) + (* ArraySlice.slice (data, 0, SOME count) *) + end*) + + + (** more efficient: accumulate in dynamic resizing array *) + fun pack f (length, stream) = + let + val trickle = stream () + + fun loop (data, next) i = + if i < length andalso next < Array.length data then + case f (trickle i) of + SOME y => + ( Array.update (data, next, y) + ; loop (data, next+1) (i+1) + ) + | NONE => + loop (data, next) (i+1) + + else if next >= Array.length data then + loop (resize data, next) i + + else + (data, next) + + val (data, count) = loop (ForkJoin.alloc 10, 0) 0 + in + ArraySlice.slice (data, 0, SOME count) + end + + + fun iterate g b (length, stream) = + let + val trickle = stream () + fun loop b i = + if i >= length then b else loop (g (b, trickle i)) (i+1) + in + loop b 0 + end + + + fun iteratePrefixes g b stream = + fn () => + let + val trickle = stream () + val stuff = ref b + in + fn idx => + let + val acc = !stuff + val elem = trickle idx + val acc' = g (acc, elem) + in + stuff := acc'; + acc + end + end + + + fun iteratePrefixesIncl g b stream = + fn () => + let + val trickle = stream () + val stuff = ref b + in + fn idx => + let + val acc = !stuff + val elem = trickle idx + val acc' = g (acc, elem) + in + stuff := acc'; + acc' + end + end + + + fun zipWith g (s1, s2) = + fn () => + let + val trickle1 = s1 () + val trickle2 = s2 () + in + fn idx => g (trickle1 idx, trickle2 idx) + end + + + fun makeBlockStreams + { blockSize: int + , numChildren: int + , offset: int -> int + , getElem: int -> int -> 'a + } = + let + fun getBlock blockIdx = + let + fun advanceUntilNonEmpty i = + if i >= numChildren orelse offset i <> offset (i+1) then + i + else + advanceUntilNonEmpty (i+1) + in + fn () => + let + val lo = blockIdx * blockSize + val firstOuterIdx = + OffsetSearch.indexSearch (0, numChildren, offset) lo + val outerIdx = ref firstOuterIdx + in + 
fn idx => + (let + val i = !outerIdx + val j = lo + idx - offset i + (* val j = !innerIdx *) + val elem = getElem i j + in + if offset i + j + 1 < offset (i+1) then + () + else + outerIdx := advanceUntilNonEmpty (i+1); + elem + end) + end + end + + in + getBlock + end + + +end diff --git a/tests/mpllib/DoubleBinarySearch.sml b/tests/mpllib/DoubleBinarySearch.sml new file mode 100644 index 000000000..e96160ad1 --- /dev/null +++ b/tests/mpllib/DoubleBinarySearch.sml @@ -0,0 +1,123 @@ +(* The function `split_count` splits sequences s and t into (s1, s2) and + * (t1, t2) such that the largest items of s1 and t1 are smaller than the + * smallest items of s2 and t2. The desired output size |s1|+|t1| is given + * as a parameter. + * + * Specifically, `split_count cmp (s, t) k` returns `(m, n)` where: + * (s1, s2) = (s[..m], s[m..]) + * (t1, t2) = (t[..n], t[n..]) + * m+n = k + * max(s1) <= min(t2) + * max(t1) <= min(s2) + * + * Note that there are many possible solutions, so we also mandate that `m` + * should be minimized. 
+ * + * Work: O(log(|s|+|t|)) + * Span: O(log(|s|+|t|)) + *) +structure DoubleBinarySearch: +sig + type 'a seq = {lo: int, hi: int, get: int -> 'a} + + val split_count: ('a * 'a -> order) -> 'a seq * 'a seq -> int -> (int * int) + + val split_count_slice: ('a * 'a -> order) + -> 'a ArraySlice.slice * 'a ArraySlice.slice + -> int + -> (int * int) +end = +struct + + type 'a seq = {lo: int, hi: int, get: int -> 'a} + + fun leq cmp (x, y) = + case cmp (x, y) of + GREATER => false + | _ => true + + fun geq cmp (x, y) = + case cmp (x, y) of + LESS => false + | _ => true + + fun split_count cmp (s: 'a seq, t: 'a seq) k = + let + fun normalize_then_loop (slo, shi) (tlo, thi) count = + let + val slo_orig = slo + val tlo_orig = tlo + + (* maybe count is small *) + val shi = Int.min (shi, slo + count) + val thi = Int.min (thi, tlo + count) + + (* maybe count is large *) + val slack = (shi - slo) + (thi - tlo) - count + val slack = Int.min (slack, shi - slo) + val slack = Int.min (slack, thi - tlo) + + val slo = Int.max (slo, shi - slack) + val tlo = Int.max (tlo, thi - slack) + + val count = count - (slo - slo_orig) - (tlo - tlo_orig) + in + loop (slo, shi) (tlo, thi) count + end + + + and loop (slo, shi) (tlo, thi) count = + if shi - slo <= 0 then + (slo, tlo + count) + + else if thi - tlo <= 0 then + (slo + count, tlo) + + else if count = 1 then + if geq cmp (#get s slo, #get t tlo) then (slo, tlo + 1) + else (slo + 1, tlo) + + else + let + val m = count div 2 + val n = count - m + + (* |------|x|-------| + * ^ ^ ^ + * slo slo+m shi + * + * |------|y|-------| + * ^ ^ ^ + * tlo tlo+n thi + *) + + val leq_y_x = + n = 0 orelse slo + m >= shi + orelse leq cmp (#get t (tlo + n - 1), #get s (slo + m)) + in + if leq_y_x then + normalize_then_loop (slo, shi) (tlo + n, thi) (count - n) + else + normalize_then_loop (slo, shi) (tlo, tlo + n) count + end + + + val {lo = slo, hi = shi, ...} = s + val {lo = tlo, hi = thi, ...} = t + + val (m, n) = normalize_then_loop (slo, shi) (tlo, 
thi) k + in + (m - slo, n - tlo) + end + + + fun fromslice s = + let val (sarr, slo, slen) = ArraySlice.base s + in {lo = slo, hi = slo + slen, get = fn i => Array.sub (sarr, i)} + end + + + fun split_count_slice cmp (s, t) k = + split_count cmp (fromslice s, fromslice t) k + +end diff --git a/tests/mpllib/ExtraBinIO.sml b/tests/mpllib/ExtraBinIO.sml new file mode 100644 index 000000000..1e0a7a131 --- /dev/null +++ b/tests/mpllib/ExtraBinIO.sml @@ -0,0 +1,73 @@ +structure ExtraBinIO = +struct + + fun w8 file (w: Word8.word) = BinIO.output1 (file, w) + + fun w64b file (w: Word64.word) = + let + val w8 = w8 file + open Word64 + infix 2 >> andb + in + w8 (Word8.fromLarge (w >> 0w56)); + w8 (Word8.fromLarge (w >> 0w48)); + w8 (Word8.fromLarge (w >> 0w40)); + w8 (Word8.fromLarge (w >> 0w32)); + w8 (Word8.fromLarge (w >> 0w24)); + w8 (Word8.fromLarge (w >> 0w16)); + w8 (Word8.fromLarge (w >> 0w8)); + w8 (Word8.fromLarge w) + end + + fun w32b file (w: Word32.word) = + let + val w8 = w8 file + val w = Word32.toLarge w + open Word64 + infix 2 >> andb + in + w8 (Word8.fromLarge (w >> 0w24)); + w8 (Word8.fromLarge (w >> 0w16)); + w8 (Word8.fromLarge (w >> 0w8)); + w8 (Word8.fromLarge w) + end + + fun w32l file (w: Word32.word) = + let + val w8 = w8 file + val w = Word32.toLarge w + open Word64 + infix 2 >> andb + in + w8 (Word8.fromLarge w); + w8 (Word8.fromLarge (w >> 0w8)); + w8 (Word8.fromLarge (w >> 0w16)); + w8 (Word8.fromLarge (w >> 0w24)) + end + + fun w16b file (w: Word16.word) = + let + val w8 = w8 file + val w = Word16.toLarge w + open Word64 + infix 2 >> andb + in + w8 (Word8.fromLarge (w >> 0w8)); + w8 (Word8.fromLarge w) + end + + fun w16l file (w: Word16.word) = + let + val w8 = w8 file + val w = Word16.toLarge w + open Word64 + infix 2 >> andb + in + w8 (Word8.fromLarge w); + w8 (Word8.fromLarge (w >> 0w8)) + end + + fun wrgb file ({red, green, blue}: Color.pixel) = + ( w8 file red; w8 file green; w8 file blue ) + +end diff --git a/tests/mpllib/FindFirst.sml 
b/tests/mpllib/FindFirst.sml new file mode 100644 index 000000000..8302b8824 --- /dev/null +++ b/tests/mpllib/FindFirst.sml @@ -0,0 +1,39 @@ +structure FindFirst : +sig + val findFirstSerial : (int * int) -> (int -> bool) -> int option + + (* findFirst granularity (start, end) predicate *) + val findFirst : int -> (int * int) -> (int -> bool) -> int option +end = +struct + + fun findFirstSerial (i, j) p = + if i >= j then NONE + else if p i then SOME i + else findFirstSerial (i+1, j) p + + fun optMin (a, b) = + case (a, b) of + (SOME x, SOME y) => (SOME (Int.min (x, y))) + | (NONE, _) => b + | (_, NONE) => a + + fun findFirst grain (lo, hi) p = + let + fun try (i, j) = + if j - i <= grain then + findFirstSerial (i, j) p + else + SeqBasis.reduce grain optMin NONE (i, j) + (fn k => if p k then SOME k else NONE) + + fun loop (i, j) = + if i >= j then NONE else + case try (i, j) of + NONE => loop (j, Int.min (j + 2*(j-i), hi)) + | SOME x => SOME x + in + loop (lo, Int.min (lo+grain, hi)) + end + +end diff --git a/tests/mpllib/FlattenMerge.sml b/tests/mpllib/FlattenMerge.sml new file mode 100644 index 000000000..abe54390a --- /dev/null +++ b/tests/mpllib/FlattenMerge.sml @@ -0,0 +1,38 @@ +structure FlattenMerge: +sig + val merge: ('a * 'a -> order) -> 'a Seq.t * 'a Seq.t -> 'a Seq.t +end = +struct + + val serialGrain = CommandLineArgs.parseInt "MPLLib_Merge_serialGrain" 4000 + + fun merge_loop cmp (s1, s2) = + if Seq.length s1 = 0 then + TFlatten.leaf s2 + else if Seq.length s1 + Seq.length s2 <= serialGrain then + TFlatten.leaf (Merge.mergeSerial cmp (s1, s2)) + else + let + val n1 = Seq.length s1 + val n2 = Seq.length s2 + val mid1 = n1 div 2 + val pivot = Seq.nth s1 mid1 + val mid2 = BinarySearch.search cmp s2 pivot + + val l1 = Seq.take s1 mid1 + val r1 = Seq.drop s1 (mid1 + 1) + val l2 = Seq.take s2 mid2 + val r2 = Seq.drop s2 mid2 + + val (outl, outr) = + ForkJoin.par (fn _ => merge_loop cmp (l1, l2), fn _ => + merge_loop cmp (r1, r2)) + in + TFlatten.node + 
(TFlatten.node (outl, TFlatten.leaf (Seq.singleton pivot)), outr) + end + + fun merge cmp (s1, s2) = + TFlatten.flatten (merge_loop cmp (s1, s2)) + +end diff --git a/tests/mpllib/FuncSequence.sml b/tests/mpllib/FuncSequence.sml new file mode 100644 index 000000000..8e1434136 --- /dev/null +++ b/tests/mpllib/FuncSequence.sml @@ -0,0 +1,46 @@ +structure FuncSequence: +sig + type 'a t + type 'a seq = 'a t + + val empty: unit -> 'a seq + val length: 'a seq -> int + val take: 'a seq -> int -> 'a seq + val drop: 'a seq -> int -> 'a seq + val nth: 'a seq -> int -> 'a + val iterate: ('b * 'a -> 'b) -> 'b -> 'a seq -> 'b + val tabulate: (int -> 'a) -> int -> 'a seq + val reduce: ('a * 'a -> 'a) -> 'a -> 'a seq -> 'a +end = +struct + (* (i, j, f) defines the sequence [ f(k) : i <= k < j ] *) + type 'a t = int * int * (int -> 'a) + type 'a seq = 'a t + + fun empty () = (0, 0, fn _ => raise Subscript) + fun length (i, j, _) = j - i + fun nth (i, j, f) k = f (i+k) + fun take (i, j, f) k = (i, i+k, f) + fun drop (i, j, f) k = (i+k, j, f) + + fun tabulate f n = (0, n, f) + + fun iterate f b s = + if length s = 0 then b + else iterate f (f (b, nth s 0)) (drop s 1) + + fun reduce f b s = + case length s of + 0 => b + | 1 => nth s 0 + | n => let + val half = n div 2 + val (l, r) = + ForkJoin.par (fn _ => reduce f b (take s half), + fn _ => reduce f b (drop s half)) + in + f (l, r) + end +end + + diff --git a/tests/mpllib/GIF.sml b/tests/mpllib/GIF.sml new file mode 100644 index 000000000..b97836dcd --- /dev/null +++ b/tests/mpllib/GIF.sml @@ -0,0 +1,722 @@ +structure GIF: +sig + type pixel = Color.pixel + type image = {height: int, width: int, data: pixel Seq.t} + + (* A GIF color palette is a table of up to 256 colors, and + * function for remapping the colors of an image. *) + structure Palette: + sig + type t = {colors: pixel Seq.t, remap: image -> int Seq.t} + + (* Selects a set of "well-spaced" colors sampled from the image. 
+ * The first argument is a list of required colors, that must be + * included in the palette. Second number is desired palette size. + *) + val summarize: pixel list -> int -> image -> t + + (* Same as `summarize`, except now just an arbitrary sampling function + * is given instead of an image. + *) + val summarizeBySampling: pixel list -> int -> (int -> pixel) -> t + + (* Selects a quantized color palette. *) + val quantized: (int * int * int) -> t + + val remapColor: t -> pixel -> int + end + + structure LZW: + sig + (* First step of compression. Remap an image with the given color + * palette, and then generate the LZW-compressed code stream. + * This inserts clear- and EOI codes. + * + * arguments are + * 1. number of colors, and + * 2. color indices (from palette remap) + *) + val codeStream: int -> int Seq.t -> int Seq.t + + (* Second step of compression: pack the code stream into bits with + * flexible bit-lengths. This step also inserts sub-block sizes. + * The first argument is the number of colors. *) + val packCodeStream: int -> int Seq.t -> Word8.word Seq.t + end + + (* Write many images as an animation. All images must be the same dimension. 
*) + val writeMany: string (* output path *) + -> int (* microsecond delay between images *) + -> Palette.t + -> {width: int, height: int, numImages: int, getImage: int -> int Seq.t} + -> unit + + val write: string -> image -> unit +end = +struct + + structure AS = ArraySlice + + type pixel = Color.pixel + type image = {height: int, width: int, data: pixel Seq.t} + + fun err msg = + raise Fail ("GIF: " ^ msg) + + fun stderr msg = + (TextIO.output (TextIO.stdErr, msg); TextIO.output (TextIO.stdErr, "\n")) + + fun ceilLog2 n = + if n <= 0 then + err "ceilLog2: expected input at least 1" + else + (* Util.log2(x) computes 1 + floor(log_2(x)) *) + Util.log2 (n-1) + + structure Palette = + struct + + type t = {colors: pixel Seq.t, remap: image -> int Seq.t} + + fun remapColor ({remap, ...}: t) px = + Seq.nth (remap {width=1, height=1, data=Seq.fromList [px]}) 0 + + fun makeQuantized (rqq, gqq, bqq) = + let + fun bucketSize numBuckets = + Real.floor (1.0 + 255.0 / Real.fromInt numBuckets) + fun bucketShift numBuckets = + Word8.fromInt ((255 - (numBuckets-1)*(bucketSize numBuckets)) div 2) + + fun qi nb = fn c => Word8.toInt c div (bucketSize nb) + fun qc nb = fn i => bucketShift nb + Word8.fromInt (i * bucketSize nb) + + fun makeQ nb = + { numBuckets = nb + , channelToIdx = qi nb + , channelFromIdx = qc nb + } + + val R = makeQ rqq + val G = makeQ gqq + val B = makeQ bqq + val numQuantized = (* this should be at most 256 *) + List.foldl op* 1 (List.map #numBuckets [R, G, B]) + + fun channelIndices {red, green, blue} = + (#channelToIdx R red, #channelToIdx G green, #channelToIdx B blue) + + fun packChannelIndices (r, g, b) = + b + + g * (#numBuckets B) + + r * (#numBuckets B) * (#numBuckets G) + + fun colorOfQuantizeIdx i = + let + val b = i mod (#numBuckets B) + val g = (i div (#numBuckets B)) mod (#numBuckets G) + val r = (i div (#numBuckets B) div (#numBuckets G)) mod (#numBuckets R) + in + { red = #channelFromIdx R r + , green = #channelFromIdx G g + , blue = 
#channelFromIdx B b + } + end + in + (numQuantized, channelIndices, packChannelIndices, colorOfQuantizeIdx) + end + + fun quantized qpackage = + let + val (numQuantized, channelIndices, pack, colorOfQuantizeIdx) = + makeQuantized qpackage + in + { colors = Seq.tabulate colorOfQuantizeIdx numQuantized + , remap = fn ({data, ...}: image) => + AS.full (SeqBasis.tabulate 1000 (0, Seq.length data) (fn i => + pack (channelIndices (Seq.nth data i)))) + } + end + + fun summarizeBySampling requiredColors paletteSize (sample: int -> Color.pixel) = + if paletteSize <= 0 then + err "summarize: palette size must be at least 1" + else if paletteSize > 256 then + err "summarize: max palette size is 256" + else if List.length requiredColors > paletteSize then + err "summarize: Too many required colors" + else + let + val dist = Color.sqDistance + + val dimBits = 3 + val dim = Util.pow2 dimBits + val numBuckets = dim*dim*dim + fun chanIdx chan = + Word8.toInt (Word8.>> (chan, Word.fromInt (8 - dimBits))) + fun chanIdxs {red, green, blue} = + (chanIdx red, chanIdx green, chanIdx blue) + fun pack (r, g, b) = (dim*dim)*r + dim*g + b + + val table = Array.array (numBuckets, []) + val sz = ref 0 + fun tableSize () = !sz + + fun insert color = + let + val i = pack (chanIdxs color) + val inBucket = Array.sub (table, i) + in + Array.update (table, i, color :: inBucket); + sz := !sz + 1 + end + + fun bounds x = (Int.max (0, x-1), Int.min (dim, x+2)) + + fun minDist color = + let + val (r, g, b) = chanIdxs color + in + Util.loop (bounds r) (valOf Int.maxInt) (fn (md, r) => + Util.loop (bounds g) md (fn (md, g) => + Util.loop (bounds b) md (fn (md, b) => + List.foldl (fn (c, md) => Int.min (md, dist (c, color))) + md + (Array.sub (table, pack (r, g, b)) )))) + end + + val candidatesSize = 20 + + fun chooseColorsLoop i = + if tableSize () = paletteSize then () else + let + fun minDist' px = (px, minDist px) + fun chooseMax ((c1, dc1), (c2, dc2)) = + if dc1 > dc2 then (c1, dc1) else (c2, dc2) + 
val (c, _) = + Util.loop (0, candidatesSize) (Color.black, ~1) + (fn (best, j) => chooseMax (best, minDist' (sample (i+j)))) + in + insert c; + chooseColorsLoop (i + candidatesSize) + end + + (* First, demand that there are a few simple colors in there! *) + val _ = List.app insert requiredColors + + (* Now, loop until full *) + val _ = chooseColorsLoop 0 + + (* Compact the table *) + val buckets = AS.full table + val bucketSizes = Seq.map List.length buckets + val bucketOffsets = + AS.full (SeqBasis.scan 100 op+ 0 (0, numBuckets) (Seq.nth bucketSizes)) + val palette = ForkJoin.alloc paletteSize + val _ = + Util.for (0, numBuckets) (fn i => + ignore (Util.copyListIntoArray + (Seq.nth buckets i) + palette + (Seq.nth bucketOffsets i))) + val palette = AS.full palette + + (* remap by lookup into compacted table *) + fun remapOne color = + let + val (r, g, b) = chanIdxs color + + fun chooseMin ((c1, dc1), (c2, dc2)) = + if dc1 <= dc2 then (c1, dc1) else (c2, dc2) + + val (i, _) = + Util.loop (bounds r) (~1, valOf Int.maxInt) (fn (md, r) => + Util.loop (bounds g) md (fn (md, g) => + Util.loop (bounds b) md (fn (md, b) => + let + val bucketIdx = pack (r, g, b) + val bStart = Seq.nth bucketOffsets bucketIdx + val bEnd = Seq.nth bucketOffsets (bucketIdx+1) + in + Util.loop (bStart, bEnd) md (fn (md, i) => + chooseMin (md, (i, dist (color, Seq.nth palette i)))) + end))) + in + Int.max (0, i) + end + + fun remap {width, height, data} = + AS.full (SeqBasis.tabulate 100 (0, Seq.length data) + (remapOne o Seq.nth data)) + in + {colors = palette, remap = remap} + end + + fun summarize cs sz ({data, width, height}: image) = + let + val n = Seq.length data + fun sample i = Seq.nth data (Util.hash i mod n) + in + summarizeBySampling cs sz sample + end + + end + + (* =================================================================== *) + + structure CodeTable: + sig + type t + type idx = int + type code = int + + val new: int -> t (* `new numColors` *) + val clear: t -> unit + val 
insert: (code * idx) -> t -> bool (* returns false when full *) + val maybeLookup: (code * idx) -> t -> code option + end = + struct + type idx = int + type code = int + + exception Invalid + + type t = + { nextCode: int ref + , numColors: int + , table: (idx * code) list array + } + + fun new numColors = + { nextCode = ref (Util.boundPow2 numColors + 2) + , numColors = numColors + , table = Array.array (4096, []) + } + + fun clear {nextCode, numColors, table} = + ( Util.for (0, Array.length table) (fn i => Array.update (table, i, [])) + ; nextCode := Util.boundPow2 numColors + 2 + ) + + fun insert (code, idx) ({nextCode, numColors, table}: t) = + if !nextCode = 4095 then + false (* GIF limits the maximum code number to 4095 *) + else + ( Array.update (table, code, (idx, !nextCode) :: Array.sub (table, code)) + ; nextCode := !nextCode + 1 + ; true + ) + + fun maybeLookup (code, idx) ({table, numColors, ...}: t) = + case List.find (fn (i, c) => i = idx) (Array.sub (table, code)) of + SOME (_, c) => SOME c + | NONE => NONE + + end + + (* =================================================================== *) + + structure DynArray = + struct + type 'a t = 'a array * int + + fun new () = + (ForkJoin.alloc 100, 0) + + fun push x (data, nextIdx) = + if nextIdx < Array.length data then + (Array.update (data, nextIdx, x); (data, nextIdx+1)) + else + let + val data' = ForkJoin.alloc (2 * Array.length data) + in + Util.for (0, Array.length data) (fn i => + Array.update (data', i, Array.sub (data, i))); + Array.update (data', nextIdx, x); + (data', nextIdx+1) + end + + fun toSeq (data, nextIdx) = + AS.slice (data, 0, SOME nextIdx) + end + +(* + structure DynArrayList = + struct + type 'a t = int * 'a array * ('a array list) + + val chunkSize = 256 + + fun new () = + (0, ForkJoin.alloc chunkSize, []) + + fun push x (nextIdx, chunk, tail) = + ( Array.update (chunk, nextIdx, x) + ; if nextIdx+1 < chunkSize then + (nextIdx+1, chunk, tail) + else + (0, ForkJoin.alloc chunkSize, 
chunk :: tail) + ) + + fun toSeq (nextIdx, chunk, tail) = + let + val totalSize = nextIdx + (chunkSize * List.length tail) + val result = ForkJoin.alloc totalSize + + fun writeChunks cs i = + case cs of + [] => () + | c :: cs' => + ( Array.copy {src = c, dst = result, di = i - Array.length c} + ; writeChunks cs' (i - Array.length c) + ) + in + Util.for (0, nextIdx) (fn i => + Array.update (result, totalSize - nextIdx + i, Array.sub (chunk, i))); + writeChunks tail (totalSize - nextIdx); + AS.full result + end + end +*) + +(* + structure DynList = + struct + type 'a t = 'a list + fun new () = [] + fun push x list = x :: list + fun toSeq xs = Seq.rev (Seq.fromList xs) + end +*) + + (* =================================================================== *) + + + structure LZW = + struct + + structure T = CodeTable + structure DS = DynArray + + fun codeStream numColors colorIndices = + let + fun colorIdx i = Seq.nth colorIndices i + + val clear = Util.boundPow2 numColors + val eoi = clear + 1 + + val table = T.new numColors + + fun finish stream = + DS.toSeq (DS.push eoi stream) + + (* The buffer is implicit, represented instead by currentCode + * i is the next index into `colorIndices` *) + fun loop stream currentCode i = + if i >= Seq.length colorIndices then + finish (DS.push currentCode stream) + else + case T.maybeLookup (currentCode, colorIdx i) table of + SOME code => loop stream code (i+1) + | NONE => + if T.insert (currentCode, colorIdx i) table then + loop (DS.push currentCode stream) (colorIdx i) (i+1) + else + ( T.clear table + ; loop (DS.push clear (DS.push currentCode stream)) + (colorIdx i) (i+1) + ) + in + if Seq.length colorIndices = 0 then + err "empty color index sequence" + else + loop (DS.push clear (DS.new ())) (colorIdx 0) 1 + end + + (* val codeStream = fn image => fn palette => + let + val (result, tm) = Util.getTime (fn _ => codeStream image palette) + in + print ("computed codeStream in " ^ Time.fmt 4 tm ^ "s\n"); + result + end *) + + fun 
packCodeStream numColors codes = + let + val n = Seq.length codes + fun code i = Seq.nth codes i + val clear = Util.boundPow2 numColors + val eoi = clear+1 + val minCodeSize = ceilLog2 numColors + val firstCodeWidth = minCodeSize+1 + + (* Begin by calculating the bit width of each code. Since we know bit + * widths are reset at each clear code, we can parallelize by splitting + * the codestream into segments delimited by clear codes and processing + * the segments in parallel. + * + * Within a segment, the width is increased every time we generated + * a new code with power-of-two width. Every symbol in the code stream + * corresponds to a newly generated code. + *) + + val clears = + AS.full (SeqBasis.filter 2000 (0, n) (fn i => i) (fn i => code i = clear)) + val numClears = Seq.length clears + + val widths = ForkJoin.alloc n + val _ = Array.update (widths, 0, firstCodeWidth) + val _ = ForkJoin.parfor 1 (0, numClears) (fn c => + let + val i = 1 + Seq.nth clears c + val j = if c = numClears-1 then n else 1 + Seq.nth clears (c+1) + + (* max code in table, up to (but not including) index k *) + fun currentMaxCode k = + k - i (* num outputs since the table was cleared *) + + eoi (* the max code immediately after clearing the table *) + in + Util.loop (i, j) (firstCodeWidth, Util.pow2 firstCodeWidth) + (fn ((currWidth, cwPow2), k) => + ( Array.update (widths, k, currWidth) + ; if currentMaxCode (k+1) <> cwPow2 then + (currWidth, cwPow2) + else + (currWidth+1, cwPow2*2) + )); + () + end) + val widths = AS.full widths + + val totalBitWidth = Seq.reduce op+ 0 widths + val packedSize = Util.ceilDiv totalBitWidth 8 + + val packed = ForkJoin.alloc packedSize + + fun flushBuffer (oi, buffer, used) = + if used < 8 then + (oi, buffer, used) + else + ( Array.update (packed, oi, Word8.fromLarge buffer) + ; flushBuffer (oi+1, LargeWord.>> (buffer, 0w8), used-8) + ) + + (* Input index range [ci,cj) + * Output index range [oi, oj) + * `buffer` is a partially filled byte that has not 
yet been written + * to the packed. `used` (0 to 7) is how much of that byte is + * used. *) + fun pack (oi, oj) (ci, cj) (buffer: LargeWord.word) (used: int) = + if ci >= cj then + (if oi >= oj then + () + else if oi = oj-1 then + Array.update (packed, oi, Word8.fromLarge buffer) + else + err "cannot fill rest of packed region") + else + let + val thisCode = code ci + val thisWidth = Seq.nth widths ci + val buffer' = + LargeWord.orb (buffer, + LargeWord.<< (LargeWord.fromInt thisCode, Word.fromInt used)) + val used' = used + thisWidth + val (oi', buffer'', used'') = flushBuffer (oi, buffer', used') + in + pack (oi', oj) (ci+1, cj) buffer'' used'' + end + + val _ = pack (0, packedSize) (0, n) 0w0 0 + val packed = AS.full packed + val numBlocks = Util.ceilDiv packedSize 255 + val output = ForkJoin.alloc (packedSize + numBlocks + 1) + in + ForkJoin.parfor 10 (0, numBlocks) (fn i => + let + val size = if i < numBlocks-1 then 255 else packedSize - 255*i + in + Array.update (output, 256*i, Word8.fromInt size); + Util.for (0, size) (fn j => + Array.update (output, 256*i + 1 + j, Seq.nth packed (255*i + j))) + end); + + Array.update (output, packedSize + numBlocks, 0w0); + + AS.full output + end + end + + (* ====================================================================== *) + + fun checkToWord16 thing x = + if x >= 0 andalso x <= 65535 then + Word16.fromInt x + else + err (thing ^ " must be non-negative and less than 2^16"); + + fun packScreenDescriptorByte + { colorTableFlag: bool + , colorResolution: int + , sortFlag: bool + , colorTableSize: int + } = + let + open Word8 + infix 2 << orb andb + in + ((if colorTableFlag then 0w1 else 0w0) << 0w7) + orb + ((fromInt colorResolution andb 0wx7) << 0w4) + orb + ((if sortFlag then 0w1 else 0w0) << 0w3) + orb + (fromInt colorTableSize andb 0wx7) + end + + fun writeMany path delay palette {width, height, numImages, getImage} = + if numImages <= 0 then + err "Must be at least one image" + else + let + val width16 = 
checkToWord16 "width" width + val height16 = checkToWord16 "height" height + + val numberOfColors = Seq.length (#colors palette) + + val _ = + if numberOfColors <= 256 then () + else err "Must have at most 256 colors in the palette" + + val (imageData, tm) = Util.getTime (fn _ => + AS.full (SeqBasis.tabulate 1 (0, numImages) (fn i => + let + val img = getImage i + in + if Seq.length img <> height * width then + err "Not all images are the right dimensions" + else + LZW.packCodeStream numberOfColors + (LZW.codeStream numberOfColors img) + end))) + + (* val _ = print ("compressed image data in " ^ Time.fmt 4 tm ^ "s\n") *) + + val file = BinIO.openOut path + val w8 = ExtraBinIO.w8 file + val w32b = ExtraBinIO.w32b file + val w32l = ExtraBinIO.w32l file + val w16l = ExtraBinIO.w16l file + val wrgb = ExtraBinIO.wrgb file + in + (* ========================== + * "GIF89a" header: 6 bytes + *) + + List.app (w8 o Word8.fromInt) [0x47, 0x49, 0x46, 0x38, 0x39, 0x61]; + + (* =================================== + * logical screen descriptor: 7 bytes + *) + + w16l width16; + w16l height16; + + w8 (packScreenDescriptorByte + { colorTableFlag = true + , colorResolution = 1 + , sortFlag = false + , colorTableSize = (ceilLog2 numberOfColors) - 1 + }); + + w8 0w0; (* background color index. just use 0 for now. *) + + w8 0w0; (* pixel aspect ratio ?? *) + + (* =================================== + * global color table + *) + + Util.for (0, numberOfColors) (fn i => + wrgb (Seq.nth (#colors palette) i)); + + Util.for (numberOfColors, Util.boundPow2 numberOfColors) (fn i => + wrgb Color.black); + + (* ================================== + * application extension, for looping + * OPTIONAL. skip it if there is only one image. 
+ *) + + if numImages = 1 then () else + List.app (w8 o Word8.fromInt) + [ 0x21, 0xFF, 0x0B, 0x4E, 0x45, 0x54, 0x53, 0x43, 0x41, 0x50, 0x45, 0x32 + , 0x2E, 0x30, 0x03, 0x01, 0x00, 0x00, 0x00 + ]; + + (* ================================== + * IMAGE DATA + *) + + Util.for (0, numImages) (fn i => + let + val bytes = Seq.nth imageData i + in + (* ========================== + * graphics control extension. + * OPTIONAL. only needed if + * doing animation. + *) + + if numImages = 1 then () else + ( List.app (w8 o Word8.fromInt) [ 0x21, 0xF9, 0x04, 0x04 ] + ; w16l (Word16.fromInt delay) + ; w8 0w0 + ; w8 0w0 + ); + + (* ========================== + * image descriptor + *) + + w8 0wx2C; (* image separator *) + + w16l 0w0; (* image left *) + w16l 0w0; (* image top *) + + w16l width16; (* image width *) + w16l height16; (* image height *) + + w8 0w0; (* packed local color table descriptor (NONE FOR NOW) *) + + (* =========================== + * compressed image data + *) + + w8 (Word8.fromInt (ceilLog2 numberOfColors)); + Util.for (0, Seq.length bytes) (fn i => + w8 (Seq.nth bytes i)) + end); + + (* ================================ + * trailer + *) + + w8 0wx3B; + + BinIO.closeOut file + end + + fun write path img = + let + val palette = Palette.summarize [] 128 img + val img' = #remap palette img + in + writeMany path 0 palette + { width = #width img + , height = #height img + , numImages = 1 + , getImage = (fn _ => img') + } + end +end diff --git a/tests/mpllib/Geometry2D.sml b/tests/mpllib/Geometry2D.sml new file mode 100644 index 000000000..782d00a56 --- /dev/null +++ b/tests/mpllib/Geometry2D.sml @@ -0,0 +1,88 @@ +structure Geometry2D = +struct + + type point = real * real + + fun rtos x = Real.toString x + + fun toString (x, y) = + String.concat ["(", rtos x, ",", rtos y, ")"] + + fun samePoint (x1, y1) (x2, y2) = + Real.== (x1, x2) andalso Real.== (y1, y2) + + fun sq (x : real) = x*x + + fun distance ((x1,y1) : point) ((x2,y2) : point) = + Math.sqrt (sq (x2-x1) + sq 
(y2-y1)) + + fun quadrant ((cx, cy) : point) (x, y) = + if y < cy + then (if x < cx then 2 else 3) + else (if x < cx then 1 else 0) + (* *) + + structure Vector = + struct + type t = real * real + + val toString = toString + + fun add ((x1, y1), (x2, y2)) : t = (x1+x2, y1+y2) + fun sub ((x1, y1), (x2, y2)) : t = (x1-x2, y1-y2) + + fun dot ((x1, y1), (x2, y2)) : real = x1*x2 + y1*y2 + fun cross ((x1, y1), (x2, y2)) : real = x1*y2 - y1*x2 + + fun scaleBy a (x, y) : t = (a*x, a*y) + + fun length (x, y) = Math.sqrt (x*x + y*y) + + fun angle (u, v) = Math.atan2 (cross (u, v), dot (u, v)) + end + + structure Point = + struct + type t = point + + val toString = toString + + val add = Vector.add + val sub = Vector.sub + + fun minCoords ((a,b),(c,d)) = + (Real.min (a,c), Real.min (b,d)) + + fun maxCoords ((a,b),(c,d)) = + (Real.max (a,c), Real.max (b,d)) + + fun triArea (a, b, c) = + Vector.cross (sub (b, a), sub (c, a)) + + fun counterClockwise (a, b, c) = + triArea (a, b, c) > 0.0 + + (* Returns angle `r` inside the triangle formed by three points: + * b + * / \ + * / r \ + * a c + *) + fun triAngle (a, b, c) = + Vector.angle (sub (a, b), sub (c, b)) + + + fun inCircle (a, b, c) d = + let + fun onParabola ((x, y): point) : Geometry3D.Vector.t = + (x, y, x*x + y*y) + val ad = onParabola (Vector.sub (a, d)) + val bd = onParabola (Vector.sub (b, d)) + val cd = onParabola (Vector.sub (c, d)) + in + Geometry3D.Vector.dot (Geometry3D.Vector.cross (ad, bd), cd) > 0.0 + end + + end + +end diff --git a/tests/mpllib/Geometry3D.sml b/tests/mpllib/Geometry3D.sml new file mode 100644 index 000000000..08e47cfbe --- /dev/null +++ b/tests/mpllib/Geometry3D.sml @@ -0,0 +1,23 @@ +structure Geometry3D = +struct + + type point = real * real * real + + structure Vector = + struct + type t = real * real * real + + fun dot ((a1, a2, a3), (b1, b2, b3)) : real = + a1*b1 + a2*b2 + a3*b3 + + fun cross ((a1, a2, a3), (b1, b2, b3)) : t = + (a2*b3 - a3*b2, a3*b1 - a1*b3, a1*b2 - a2*b1) + + fun add 
((a1, a2, a3), (b1, b2, b3)) : t = + (a1+b1, a2+b2, a3+b3) + + fun sub ((a1, a2, a3), (b1, b2, b3)) : t = + (a1-b1, a2-b2, a3-b3) + end + +end diff --git a/tests/mpllib/Hashset.sml b/tests/mpllib/Hashset.sml new file mode 100644 index 000000000..57f2d93cd --- /dev/null +++ b/tests/mpllib/Hashset.sml @@ -0,0 +1,111 @@ +structure Hashset: +sig + type 'a t + type 'a hashset = 'a t + + exception Full + + val make: + { hash: 'a -> int + , eq: 'a * 'a -> bool + , capacity: int + , maxload: real} -> 'a t + + val size: 'a t -> int + val capacity: 'a t -> int + val resize: 'a t -> 'a t + val insert: 'a t -> 'a -> bool + val to_list: 'a t -> 'a list +end = +struct + + +datatype 'a t = + S of + { data: 'a option array + , hash: 'a -> int + , eq: 'a * 'a -> bool + , maxload: real + } + + exception Full + + type 'a hashset = 'a t + + fun make {hash, eq, capacity, maxload} = + let + val data = SeqBasis.tabulate 5000 (0, capacity) (fn _ => NONE) + in + S {data=data, hash=hash, eq=eq, maxload = maxload} + end + + fun bcas (arr, i) (old, new) = + MLton.eq (old, Concurrency.casArray (arr, i) (old, new)) + + fun size (S {data, ...}) = + SeqBasis.reduce 10000 op+ 0 (0, Array.length data) (fn i => + if Option.isSome (Array.sub (data, i)) then 1 else 0) + + fun capacity (S {data, ...}) = Array.length data + + fun insert' (input as S {data, hash, eq, maxload}) x force = + let + val n = Array.length data + val start = (hash x) mod (Array.length data) + + val tolerance = + 2 * Real.ceil (1.0 / (1.0 - maxload)) + + fun loop i probes = + if not force andalso probes >= tolerance then + raise Full + else if i >= n then + loop 0 probes + else + let + val current = Array.sub (data, i) + in + case current of + SOME y => if eq (x, y) then false else loop (i+1) (probes+1) + | NONE => + if bcas (data, i) (current, SOME x) then + (* (Concurrency.faa sz 1; true) *) + true + else + loop i probes + end + + val start = (hash x) mod (Array.length data) + in + loop start 0 + end + + + fun insert s x = 
insert' s x false + + + fun resize (input as S {data, hash, eq, maxload}) = + let + val newcap = 2 * capacity input + val new = make {hash = hash, eq = eq, capacity = newcap, maxload = maxload} + in + ForkJoin.parfor 1000 (0, Array.length data) (fn i => + case Array.sub (data, i) of + NONE => () + | SOME x => (insert' new x true; ())); + + new + end + + + fun to_list (S {data, hash, eq, ...}) = + let + fun pushSome (elem, xs) = + case elem of + SOME x => x :: xs + | NONE => xs + in + Array.foldr pushSome [] data + end + +end diff --git a/tests/mpllib/Hashtable.sml b/tests/mpllib/Hashtable.sml new file mode 100644 index 000000000..19582a0b5 --- /dev/null +++ b/tests/mpllib/Hashtable.sml @@ -0,0 +1,127 @@ +structure Hashtable: +sig + type ('a, 'b) t + type ('a, 'b) hashtable = ('a, 'b) t + + val make: {hash: 'a -> int, eq: 'a * 'a -> bool, capacity: int} -> ('a, 'b) t + val insert: ('a, 'b) t -> ('a * 'b) -> unit + val insert_if_absent: ('a, 'b) t -> ('a * 'b) -> unit + + val lookup: ('a, 'b) t -> 'a -> 'b option + val to_list: ('a, 'b) t -> ('a * 'b) list + val keys_to_arr: ('a, 'b) t -> 'a array +end = +struct + + datatype ('a, 'b) t = + S of + { data: ('a * 'b) option array + , hash: 'a -> int + , eq: 'a * 'a -> bool + } + + type ('a, 'b) hashtable = ('a, 'b) t + + fun make {hash, eq, capacity} = + let + val data = SeqBasis.tabulate 5000 (0, capacity) (fn _ => NONE) + in + S {data=data, hash=hash, eq=eq} + end + + fun bcas (arr, i) (old, new) = + MLton.eq (old, Concurrency.casArray (arr, i) (old, new)) + + fun insert (S {data, hash, eq}) (k, v) = + let + val n = Array.length data + + fun loop i = + if i >= n then loop 0 else + let + val current = Array.sub (data, i) + val rightPlace = + case current of + NONE => true + | SOME (k', _) => eq (k, k') + in + if not rightPlace then + loop (i+1) + else if bcas (data, i) (current, SOME (k, v)) then + () + else + loop i + end + + val start = (hash k) mod (Array.length data) + in + loop start + end + + (* This function 
differs from the above in the case + where the key k is already in the hashtable. + If so, the function does not update the key's value + and is thus more efficient (saves cas). + *) + fun insert_if_absent (S {data, hash, eq}) (k, v) = + let + val n = Array.length data + + fun loop i = + if i >= n then loop 0 else + let + val current = Array.sub (data, i) + in + case current of + NONE => + if (bcas (data, i) (current, SOME (k, v))) then () + else loop i + | SOME (k', _) => + if eq (k, k') then () + else loop (i + 1) + end + + val start = (hash k) mod (Array.length data) + in + loop start + end + + fun lookup (S {data, hash, eq}) k = + let + val n = Array.length data + + fun loop i = + if i >= n then loop 0 else + case Array.sub (data, i) of + SOME (k', v) => if eq (k, k') then SOME v else loop (i+1) + | NONE => NONE + + val start = (hash k) mod (Array.length data) + in + loop start + end + + fun keys_to_arr (S{data, hash, eq}) = + let + val n = Array.length data + val gran = 10000 + val keys = SeqBasis.tabFilter gran (0, Array.length data) + (fn i => + case Array.sub (data, i) of + NONE => NONE + | SOME (k, _) => SOME k) + in + keys + end + + fun to_list (S {data, hash, eq}) = + let + fun pushSome (elem, xs) = + case elem of + SOME x => x :: xs + | NONE => xs + in + Array.foldr pushSome [] data + end + +end diff --git a/tests/mpllib/MatCOO.sml b/tests/mpllib/MatCOO.sml new file mode 100644 index 000000000..acfdc3d1f --- /dev/null +++ b/tests/mpllib/MatCOO.sml @@ -0,0 +1,681 @@ +(* MatCOO(I, R): + * - indices of type I.int + * - values of type R.real + * + * For example, the following defines matrices where the row and column indices + * will be arrays of C type int32_t*, and values of C type float* + * + * structure M = MatCOO(structure I = Int32 + * structure R = Real32) + * + * We can also use int64_t and doubles (or any other desired combination): + * + * structure M = MatCOO(structure I = Int64 + * structure R = Real64) + *) + +functor MatCOO + (structure I: 
INTEGER + structure R: + sig + include REAL + structure W: WORD + val castFromWord: W.word -> real + val castToWord: real -> W.word + end) = +struct + structure I = I + structure R = R + + (* SoA format for storing every nonzero value = mat[row,column], i.e.: + * (row, column, value) = (row_indices[i], col_indices[i], values[i]) + * + * so, number of nonzeros (nnz) + * = length(row_indices) + * = length(col_indices) + * = length(values) + * + * we assume row_indices is sorted + *) + datatype mat = + Mat of + { width: int + , height: int + , row_indices: I.int Seq.t + , col_indices: I.int Seq.t + , values: R.real Seq.t + } + + type t = mat + + fun width (Mat m) = #width m + fun height (Mat m) = #height m + + fun nnz (Mat m) = + Seq.length (#row_indices m) + + fun row_lo (mat as Mat m) = + if nnz mat = 0 then I.fromInt 0 else Seq.first (#row_indices m) + + fun row_hi (mat as Mat m) = + if nnz mat = 0 then I.fromInt 0 + else I.+ (I.fromInt 1, Seq.last (#row_indices m)) + + fun row_spread mat = + I.toInt (row_hi mat) - I.toInt (row_lo mat) + + fun split_seq s k = + (Seq.take s k, Seq.drop s k) + + fun undo_split_seq s1 s2 = + let + val (a1, i1, n1) = ArraySlice.base s1 + val (a2, i2, n2) = ArraySlice.base s2 + in + if MLton.eq (a1, a2) andalso i1 + n1 = i2 then + Seq.subseq (ArraySlice.full a1) (i1, n1 + n2) + else + raise Fail + ("undo_split_seq: arguments are not adjacent: " ^ Int.toString i1 + ^ " " ^ Int.toString n1 ^ " " ^ Int.toString i2 ^ " " + ^ Int.toString n2) + end + + + fun split_nnz k (mat as Mat m) = + let + val (r1, r2) = split_seq (#row_indices m) k + val (c1, c2) = split_seq (#col_indices m) k + val (v1, v2) = split_seq (#values m) k + + val m1 = Mat + { width = #width m + , height = #height m + , row_indices = r1 + , col_indices = c1 + , values = v1 + } + + val m2 = Mat + { width = #width m + , height = #height m + , row_indices = r2 + , col_indices = c2 + , values = v2 + } + in + (m1, m2) + end + + + (* split m -> (m1, m2) + * where m = m1 + m2 + * 
and nnz(m1) ~= frac * nnz(m) + * and nnz(m2) ~= (1-frac) * nnz(m) + *) + fun split frac mat = + let + val half = Real.floor (frac * Real.fromInt (nnz mat)) + val half = + if nnz mat <= 1 then half + else if half = 0 then half + 1 + else if half = nnz mat then half - 1 + else half + in + split_nnz half mat + end + + + (* might fail if the matrices were not created by a split *) + fun undo_split mat1 mat2 = + if nnz mat1 = 0 then + mat2 + else if nnz mat2 = 0 then + mat1 + else + let + val Mat m1 = mat1 + val Mat m2 = mat2 + in + if #width m1 <> #width m2 orelse #height m1 <> #height m2 then + raise Fail "undo_split: dimension mismatch" + else + Mat + { width = #width m1 + , height = #height m1 + , row_indices = undo_split_seq (#row_indices m1) (#row_indices m2) + , col_indices = undo_split_seq (#col_indices m1) (#col_indices m2) + , values = undo_split_seq (#values m1) (#values m2) + } + end + + + fun dump_info msg (Mat m) = + let + val info = String.concatWith " " + [ msg + , Seq.toString I.toString (#row_indices m) + , Seq.toString I.toString (#col_indices m) + , Seq.toString (R.fmt (StringCvt.FIX (SOME 1))) (#values m) + ] + in + print (info ^ "\n") + end + + + fun upd (a, i, x) = + ( (*print ("upd " ^ Int.toString i ^ "\n");*)Array.update (a, i, x)) + + + (* ======================================================================= + * write_mxv: serial and parallel versions + * + * The first and last row of the input `mat` might overlap with other + * parallel tasks, so these need to be returned separately and combined. 
+ * + * All middle rows are "owned" by the call to write_mxv + *) + + datatype write_mxv_result = + SingleRowValue of R.real + | FirstLastRowValue of R.real * R.real + + + (* requires `row_lo mat < row_hi mat`, i.e., at least one row *) + fun write_mxv_serial mat vec output : write_mxv_result = + let + val Mat {row_indices, col_indices, values, ...} = mat + val n = nnz mat + in + if row_lo mat = I.- (row_hi mat, I.fromInt 1) then + (* no writes to output; only a single result row value *) + let + (* val _ = dump_info "write_mxv_serial (single row)" mat *) + val result = Util.loop (0, n) (R.fromInt 0) (fn (acc, i) => + R.+ (acc, R.* + (Seq.nth vec (I.toInt (Seq.nth col_indices i)), Seq.nth values i))) + in + SingleRowValue result + end + else + let + (* val _ = dump_info "write_mxv_serial (multi rows)" mat *) + fun single_row_loop r (i, acc) = + if i >= n orelse Seq.nth row_indices i <> r then + (i, acc) + else + let + val acc' = R.+ (acc, R.* + ( Seq.nth vec (I.toInt (Seq.nth col_indices i)) + , Seq.nth values i + )) + in + single_row_loop r (i + 1, acc') + end + + val last_row = I.- (row_hi mat, I.fromInt 1) + + fun advance_to_next_nonempty_row i' r = + let + val next_r = Seq.nth row_indices i' + in + Util.for (1 + I.toInt r, I.toInt next_r) (fn rr => + upd (output, rr, R.fromInt 0)); + next_r + end + + fun middle_loop r i = + if r = last_row then + i + else + let + val (i', row_value) = single_row_loop r (i, R.fromInt 0) + in + upd (output, I.toInt r, row_value); + middle_loop (advance_to_next_nonempty_row i' r) i' + end + + val (i, first_row_value) = + single_row_loop (row_lo mat) (0, R.fromInt 0) + val i = middle_loop (advance_to_next_nonempty_row i (row_lo mat)) i + val (_, last_row_value) = single_row_loop last_row (i, R.fromInt 0) + in + FirstLastRowValue (first_row_value, last_row_value) + end + end + + val nnzGrain = 5000 + + + fun write_mxv_combine_results (m1, m2) (result1, result2) output = + if I.- (row_hi m1, I.fromInt 1) = row_lo m2 then + case 
(result1, result2) of + (SingleRowValue r1, SingleRowValue r2) => SingleRowValue (R.+ (r1, r2)) + | (SingleRowValue r1, FirstLastRowValue (f2, l2)) => + FirstLastRowValue (R.+ (r1, f2), l2) + | (FirstLastRowValue (f1, l1), SingleRowValue r2) => + FirstLastRowValue (f1, R.+ (l1, r2)) + | (FirstLastRowValue (f1, l1), FirstLastRowValue (f2, l2)) => + (* overlap *) + ( (*print "fill in middle overlap\n" + ;*) upd (output, I.toInt (row_lo m2), R.+ (l1, f2)) + ; FirstLastRowValue (f1, l2) + ) + else + let + fun finish_l1 v = + upd (output, I.toInt (row_hi m1) - 1, v) + fun finish_f2 v = + upd (output, I.toInt (row_lo m2), v) + fun fill_middle () = + ForkJoin.parfor 5000 (I.toInt (row_hi m1), I.toInt (row_lo m2)) + (fn r => upd (output, r, R.fromInt 0)) + in + (* print "fill in middle, no overlap\n"; *) + case (result1, result2) of + (SingleRowValue r1, SingleRowValue r2) => + (fill_middle (); FirstLastRowValue (r1, r2)) + + | (SingleRowValue r1, FirstLastRowValue (f2, l2)) => + (fill_middle (); finish_f2 f2; FirstLastRowValue (r1, l2)) + + | (FirstLastRowValue (f1, l1), SingleRowValue r2) => + (finish_l1 l1; fill_middle (); FirstLastRowValue (f1, r2)) + + | (FirstLastRowValue (f1, l1), FirstLastRowValue (f2, l2)) => + ( finish_l1 l1 + ; fill_middle () + ; finish_f2 f2 + ; FirstLastRowValue (f1, l2) + ) + end + + + (* requires `row_lo mat < row_hi mat`, i.e., at least one row *) + fun write_mxv mat vec output : write_mxv_result = + if nnz mat <= nnzGrain then + write_mxv_serial mat vec output + else + let + val (m1, m2) = split 0.5 mat + val (result1, result2) = + ForkJoin.par (fn _ => write_mxv m1 vec output, fn _ => + write_mxv m2 vec output) + + in + write_mxv_combine_results (m1, m2) (result1, result2) output + end + + + fun mxv (mat: mat) (vec: R.real Seq.t) = + if nnz mat = 0 then + Seq.tabulate (fn _ => R.fromInt 0) (Seq.length vec) + else + let + val output: R.real array = ForkJoin.alloc (Seq.length vec) + val result = write_mxv mat vec output + in + (* print "top 
level: fill in front\n"; *) + ForkJoin.parfor 5000 (0, I.toInt (row_lo mat)) (fn r => + upd (output, r, R.fromInt 0)); + (* print "top-level: fill in middle\n"; *) + case result of + SingleRowValue r => upd (output, I.toInt (row_lo mat), r) + | FirstLastRowValue (f, l) => + ( upd (output, I.toInt (row_lo mat), f) + ; upd (output, I.toInt (row_hi mat) - 1, l) + ); + (* print "top-level: fill in back\n"; *) + ForkJoin.parfor 5000 (I.toInt (row_hi mat), Seq.length vec) (fn r => + upd (output, r, R.fromInt 0)); + ArraySlice.full output + end + + (* ======================================================================= *) + (* ======================================================================= *) + (* ======================================================================= + * to/from file + *) + + structure DS = DelayedSeq + + fun fromMatrixMarketFile path chars = + let + val lines = ParseFile.tokens (fn c => c = #"\n") chars + val numLines = DS.length lines + fun line i : char DS.t = DS.nth lines i + fun lineStr i : string = + let val ln = line i + in CharVector.tabulate (DS.length ln, DS.nth ln) + end + + fun lineIsComment i = + let val ln = line i + in DS.length ln > 0 andalso DS.nth ln 0 = #"%" + end + + val _ = + if + numLines > 0 + andalso + ParseFile.eqStr "%%MatrixMarket matrix coordinate real general" + (line 0) + then () + else raise Fail ("MatCOO.fromFile: not sure how to parse file: " ^ path) + + val firstNonCommentLineNum = + case FindFirst.findFirst 1000 (1, numLines) (not o lineIsComment) of + SOME i => i + | NONE => raise Fail ("MatCOO.fromFile: missing contents?") + + fun fail () = + raise Fail + ("MatCOO.fromFile: error parsing line " + ^ Int.toString (1 + firstNonCommentLineNum) + ^ ": expected ") + val (numRows, numCols, numValues) = + let + val lineNum = firstNonCommentLineNum + val lnChars = DS.toArraySeq (line lineNum) + val toks = ParseFile.tokens Char.isSpace lnChars + val nr = valOf (ParseFile.parseInt (DS.nth toks 0)) + val nc = valOf 
(ParseFile.parseInt (DS.nth toks 1)) + val nv = valOf (ParseFile.parseInt (DS.nth toks 2)) + in + (nr, nc, nv) + end + handle _ => fail () + + val _ = print ("num rows " ^ Int.toString numRows ^ "\n") + val _ = print ("num cols " ^ Int.toString numCols ^ "\n") + val _ = print ("num nonzero " ^ Int.toString numValues ^ "\n") + val _ = print ("parsing elements (may take a while...)\n") + + val row_indices = ForkJoin.alloc numValues + val col_indices = ForkJoin.alloc numValues + val values = ForkJoin.alloc numValues + + fun fail lineNum = + raise Fail + ("MatCOO.fromFile: error parsing line " ^ Int.toString (1 + lineNum) + ^ ": expected ") + + (* TODO: this is very slow *) + fun parseValue i = + let + val lineNum = firstNonCommentLineNum + 1 + i + val lnChars = DS.toArraySeq (line lineNum) + val toks = ParseFile.tokens Char.isSpace lnChars + val r = I.fromInt (valOf (ParseFile.parseInt (DS.nth toks 0))) + val c = I.fromInt (valOf (ParseFile.parseInt (DS.nth toks 1))) + val v = R.fromLarge IEEEReal.TO_NEAREST (valOf (ParseFile.parseReal + (DS.nth toks 2))) + + (* val ln = line lineNum + val chars = CharVector.tabulate (DS.length ln, DS.nth ln) + val toks = String.tokens Char.isSpace chars + val r = I.fromInt (valOf (Int.fromString (List.nth (toks, 0)))) + val c = I.fromInt (valOf (Int.fromString (List.nth (toks, 1)))) + val v = R.fromLarge IEEEReal.TO_NEAREST (valOf (Real.fromString + (List.nth (toks, 2)))) *) + in + (* if i mod 500000 = 0 then + print ("finished row " ^ Int.toString i ^ "\n") + else + (); *) + + (* coordinates are stored in .mtx files using 1-indexing, but we + * want 0-indexing + *) + Array.update (row_indices, i, I.- (r, I.fromInt 1)); + Array.update (col_indices, i, I.- (c, I.fromInt 1)); + Array.update (values, i, v) + end + handle _ => fail (firstNonCommentLineNum + 1 + i) + + val _ = ForkJoin.parfor 1000 (0, numValues) parseValue + val _ = print ("finished parsing elements\n") + val _ = print ("formatting...\n") + + val getSorted = + let + val 
idx = Seq.tabulate (fn i => i) numValues + in + StableSort.sortInPlace + (fn (i, j) => + I.compare + (Array.sub (row_indices, i), Array.sub (row_indices, j))) idx; + fn i => Seq.nth idx i + end + + val row_indices = + Seq.tabulate (fn i => Array.sub (row_indices, getSorted i)) numValues + val col_indices = + Seq.tabulate (fn i => Array.sub (col_indices, getSorted i)) numValues + val values = + Seq.tabulate (fn i => Array.sub (values, getSorted i)) numValues + + val _ = print ("done parsing " ^ path ^ "\n") + in + Mat + { width = numCols + , height = numRows + , row_indices = row_indices + , col_indices = col_indices + , values = values + } + end + + + (* MatrixCoordinateRealBin\n + * [8 bits unsigned: real value precision, either 32 or 64] + * [64 bits unsigned: number of rows] + * [64 bits unsigned: number of columns] + * [64 bits unsigned: number of elements] + * [element] + * [element] + * ... + * + * each element is as follows, where X is the real value precision (32 or 64) + * [64 bits unsigned: row index][64 bits unsigned: col index][X bits: value] + *) + fun writeBinFile mat path = + let + val file = TextIO.openOut path + val _ = TextIO.output (file, "MatrixCoordinateRealBin\n") + val _ = TextIO.closeOut file + + val file = BinIO.openAppend path + + fun w8 (w: Word8.word) = BinIO.output1 (file, w) + + fun w32 (w: Word64.word) = + let + open Word64 + infix 2 >> andb + in + w8 (Word8.fromLarge (w >> 0w24)); + w8 (Word8.fromLarge (w >> 0w16)); + w8 (Word8.fromLarge (w >> 0w8)); + w8 (Word8.fromLarge w) + end + + fun w64 (w: Word64.word) = + let + open Word64 + infix 2 >> andb + in + (* this will only work if Word64 = LargeWord, which is good. 
*) + w8 (Word8.fromLarge (w >> 0w56)); + w8 (Word8.fromLarge (w >> 0w48)); + w8 (Word8.fromLarge (w >> 0w40)); + w8 (Word8.fromLarge (w >> 0w32)); + w8 (Word8.fromLarge (w >> 0w24)); + w8 (Word8.fromLarge (w >> 0w16)); + w8 (Word8.fromLarge (w >> 0w8)); + w8 (Word8.fromLarge w) + end + + fun wr64 (r: R.real) = + w64 (R.W.toLarge (R.castToWord r)) + fun wr32 (r: R.real) = + w32 (R.W.toLarge (R.castToWord r)) + + val (wr, rsize) = + case R.W.wordSize of + 32 => (wr32, 0w32) + | 64 => (wr64, 0w64) + | _ => + raise Fail + "MatCOO.writeBinFile: only 32-bit and 64-bit reals supported" + in + w8 rsize; + w64 (Word64.fromInt (height mat)); + w64 (Word64.fromInt (width mat)); + w64 (Word64.fromInt (nnz mat)); + Util.for (0, nnz mat) (fn i => + let + val Mat {row_indices, col_indices, values, ...} = mat + val r = Seq.nth row_indices i + val c = Seq.nth col_indices i + val v = Seq.nth values i + in + w64 (Word64.fromInt (I.toInt r)); + w64 (Word64.fromInt (I.toInt c)); + wr v + end); + BinIO.closeOut file + end + + + fun fromBinFile path bytes = + let + val header = "MatrixCoordinateRealBin\n" + val header' = + if Seq.length bytes < String.size header then + raise Fail ("MatCOO.fromBinFile: missing or incomplete header") + else + CharVector.tabulate (String.size header, fn i => + Char.chr (Word8.toInt (Seq.nth bytes i))) + val _ = + if header = header' then + () + else + raise Fail + ("MatCOO.fromBinFile: expected MatrixCoordinateRealBin header") + + val bytes = Seq.drop bytes (String.size header) + + fun r64 off = + let + infix 2 << orb + val op<< = Word64.<< + val op orb = Word64.orb + + val w = Word8.toLarge (Seq.nth bytes off) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (off + 1))) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (off + 2))) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (off + 3))) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (off + 4))) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (off + 5))) + val w = (w << 
0w8) orb (Word8.toLarge (Seq.nth bytes (off + 6))) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (off + 7))) + in + w + end + + fun r32 off = + let + infix 2 << orb + val op<< = Word64.<< + val op orb = Word64.orb + + val w = Word8.toLarge (Seq.nth bytes off) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (off + 1))) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (off + 2))) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (off + 3))) + in + w + end + + fun r8 off = Seq.nth bytes off + fun rr32 off = + R.castFromWord (R.W.fromLarge (r32 off)) + fun rr64 off = + R.castFromWord (R.W.fromLarge (r64 off)) + + (* ==================================================================== + * parse binary contents + *) + + val rsize = Word8.toInt (r8 0) + + fun rsizeFail () = + raise Fail + ("MatCOO.fromBinFile: found " ^ Int.toString rsize + ^ "-bit reals, but expected " ^ Int.toString R.W.wordSize ^ "-bit") + + val (rr, rbytes) = + if rsize = R.W.wordSize then + case rsize of + 32 => (rr32, 4) + | 64 => (rr64, 8) + | _ => rsizeFail () + else + rsizeFail () + + val elemSize = 8 + 8 + rbytes + val elemStartOff = 1 + 8 + 8 + 8 + + val height = Word64.toInt (r64 (1 + 0)) + val width = Word64.toInt (r64 (1 + 8)) + val numValues = Word64.toInt (r64 (1 + 8 + 8)) + + val row_indices = ForkJoin.alloc numValues + val col_indices = ForkJoin.alloc numValues + val values = ForkJoin.alloc numValues + in + ForkJoin.parfor 5000 (0, numValues) (fn i => + let + val off = elemStartOff + i * elemSize + val r = I.fromInt (Word64.toInt (r64 off)) + val c = I.fromInt (Word64.toInt (r64 (off + 8))) + val v = rr (off + 8 + 8) + in + Array.update (row_indices, i, r); + Array.update (col_indices, i, c); + Array.update (values, i, v) + end); + + Mat + { width = width + , height = height + , row_indices = ArraySlice.full row_indices + , col_indices = ArraySlice.full col_indices + , values = ArraySlice.full values + } + end + + + fun fromFile path = + let + val file = 
TextIO.openIn path + val _ = print ("loading " ^ path ^ "\n") + + val h1 = "%%MatrixMarket" + val h2 = "MatrixCoordinateRealBin\n" + + val actualHeader = TextIO.inputN + (file, Int.max (String.size h1, String.size h2)) + in + TextIO.closeIn file; + + if String.isPrefix h1 actualHeader then + fromMatrixMarketFile path (ReadFile.contentsSeq path) + else if String.isPrefix h2 actualHeader then + fromBinFile path (ReadFile.contentsBinSeq path) + else + raise Fail ("unknown header " ^ actualHeader) + end + +end diff --git a/tests/mpllib/Merge.sml b/tests/mpllib/Merge.sml new file mode 100644 index 000000000..a2f23fcdc --- /dev/null +++ b/tests/mpllib/Merge.sml @@ -0,0 +1,98 @@ +structure Merge: +sig + type 'a seq = 'a ArraySlice.slice + + val writeMergeSerial: ('a * 'a -> order) (* compare *) + -> 'a seq * 'a seq (* (sorted) sequences to merge *) + -> 'a seq (* output *) + -> unit + + val writeMerge: ('a * 'a -> order) (* compare *) + -> 'a seq * 'a seq (* (sorted) sequences to merge *) + -> 'a seq (* output *) + -> unit + + val mergeSerial: ('a * 'a -> order) -> 'a seq * 'a seq -> 'a seq + val merge: ('a * 'a -> order) -> 'a seq * 'a seq -> 'a seq +end = +struct + + structure AS = ArraySlice + type 'a seq = 'a AS.slice + + val for = Util.for + val parfor = ForkJoin.parfor + val par = ForkJoin.par + val allocate = ForkJoin.alloc + + val serialGrain = CommandLineArgs.parseInt "MPLLib_Merge_serialGrain" 4000 + + fun sliceIdxs s i j = + AS.subslice (s, i, SOME (j - i)) + + fun writeMergeSerial cmp (s1, s2) t = + let + fun write i x = AS.update (t, i, x) + + val n1 = AS.length s1 + val n2 = AS.length s2 + + (* i1 index into s1 + * i2 index into s2 + * j index into output *) + fun loop i1 i2 j = + if i1 = n1 then + Util.foreach (sliceIdxs s2 i2 n2) (fn (i, x) => write (i + j) x) + else if i2 = n2 then + Util.foreach (sliceIdxs s1 i1 n1) (fn (i, x) => write (i + j) x) + else + let + val x1 = AS.sub (s1, i1) + val x2 = AS.sub (s2, i2) + in + case cmp (x1, x2) of + LESS => 
(write j x1; loop (i1 + 1) i2 (j + 1)) + | _ => (write j x2; loop i1 (i2 + 1) (j + 1)) + end + in + loop 0 0 0 + end + + fun mergeSerial cmp (s1, s2) = + let val out = AS.full (allocate (AS.length s1 + AS.length s2)) + in writeMergeSerial cmp (s1, s2) out; out + end + + fun writeMerge cmp (s1, s2) t = + if AS.length t <= serialGrain then + writeMergeSerial cmp (s1, s2) t + else if AS.length s1 = 0 then + Util.foreach s2 (fn (i, x) => AS.update (t, i, x)) + else + let + val n1 = AS.length s1 + val n2 = AS.length s2 + val mid1 = n1 div 2 + val pivot = AS.sub (s1, mid1) + val mid2 = BinarySearch.search cmp s2 pivot + + val l1 = sliceIdxs s1 0 mid1 + val r1 = sliceIdxs s1 (mid1 + 1) n1 + val l2 = sliceIdxs s2 0 mid2 + val r2 = sliceIdxs s2 mid2 n2 + + val _ = AS.update (t, mid1 + mid2, pivot) + val tl = sliceIdxs t 0 (mid1 + mid2) + val tr = sliceIdxs t (mid1 + mid2 + 1) (AS.length t) + in + par (fn _ => writeMerge cmp (l1, l2) tl, fn _ => + writeMerge cmp (r1, r2) tr); + () + end + + fun merge cmp (s1, s2) = + let val out = AS.full (allocate (AS.length s1 + AS.length s2)) + in writeMerge cmp (s1, s2) out; out + end + +end diff --git a/tests/mpllib/Mergesort.sml b/tests/mpllib/Mergesort.sml new file mode 100644 index 000000000..fc2a1438c --- /dev/null +++ b/tests/mpllib/Mergesort.sml @@ -0,0 +1,69 @@ +structure Mergesort: +sig + type 'a seq = 'a ArraySlice.slice + val sortInPlace: ('a * 'a -> order) -> 'a seq -> unit + val sort: ('a * 'a -> order) -> 'a seq -> 'a seq +end = +struct + + type 'a seq = 'a ArraySlice.slice + + structure AS = ArraySlice + + fun take s n = AS.subslice (s, 0, SOME n) + fun drop s n = AS.subslice (s, n, NONE) + + val par = ForkJoin.par + val allocate = ForkJoin.alloc + + (* in-place sort s, using t as a temporary array if needed *) + fun sortInPlace' cmp s t = + if AS.length s <= 1024 then + Quicksort.sortInPlace cmp s + else let + val half = AS.length s div 2 + val (sl, sr) = (take s half, drop s half) + val (tl, tr) = (take t half, drop t 
half) + in + (* recursively sort, writing result into t *) + par (fn _ => writeSort cmp sl tl, fn _ => writeSort cmp sr tr); + (* merge back from t into s *) + Merge.writeMerge cmp (tl, tr) s; + () + end + + (* destructively sort s, writing the result in t *) + and writeSort cmp s t = + if AS.length s <= 1024 then + ( Util.foreach s (fn (i, x) => AS.update (t, i, x)) + ; Quicksort.sortInPlace cmp t + ) + else let + val half = AS.length s div 2 + val (sl, sr) = (take s half, drop s half) + val (tl, tr) = (take t half, drop t half) + in + (* recursively in-place sort sl and sr *) + par (fn _ => sortInPlace' cmp sl tl, fn _ => sortInPlace' cmp sr tr); + (* merge into t *) + Merge.writeMerge cmp (sl, sr) t; + () + end + + fun sortInPlace cmp s = + let + val t = AS.full (allocate (AS.length s)) + in + sortInPlace' cmp s t + end + + fun sort cmp s = + let + val result = AS.full (allocate (AS.length s)) + in + Util.foreach s (fn (i, x) => AS.update (result, i, x)); + sortInPlace cmp result; + result + end + +end diff --git a/tests/mpllib/MeshToImage.sml b/tests/mpllib/MeshToImage.sml new file mode 100644 index 000000000..a4c373d76 --- /dev/null +++ b/tests/mpllib/MeshToImage.sml @@ -0,0 +1,278 @@ +structure MeshToImage: +sig + val toImage: + { mesh: Topology2D.mesh + , resolution: int + , cavities: (Geometry2D.point * Topology2D.cavity) Seq.t option + , background: Color.color + } + -> PPM.image +end = +struct + + structure T = Topology2D + structure G = Geometry2D + + fun inRange (a, b) x = + Real.min (a, b) <= x andalso x <= Real.max (a, b) + + fun xIntercept (x0,y0) (x1,y1) y = + if not (inRange (y0, y1) y) orelse Real.== (y0, y1) then + NONE + else + let + val x = x0 + (y - y0) * ((x1-x0)/(y1-y0)) + in + if inRange (x0, x1) x then SOME x else NONE + end + + fun rocmp (xo, yo) = + case (xo, yo) of + (SOME x, SOME y) => Real.compare (x, y) + | (NONE, SOME _) => GREATER + | (SOME _, NONE) => LESS + | _ => EQUAL + + fun sort3 cmp (a, b, c) = + let + fun lt (x, y) = case 
cmp (x, y) of LESS => true | _ => false + val (a, b) = if lt (a, b) then (a, b) else (b, a) + val (a, c) = if lt (a, c) then (a, c) else (c, a) + val (b, c) = if lt (b, c) then (b, c) else (c, b) + in + (a, b, c) + end + + fun toImage {mesh, resolution, cavities, background} = + let + val points = T.getPoints mesh + + val width = resolution + val height = resolution + + val niceGray = Color.hsva {h=0.0, s=0.0, v=0.88, a=1.0} + (* val white = Color.hsva {h=0.0, s=0.0, v=1.0, a=1.0} *) + val black = Color.hsva {h=0.0, s=0.0, v=0.0, a=1.0} + val red = Color.hsva {h=0.0, s=1.0, v=1.0, a=1.0} + + val niceRed = Color.hsva {h = 0.0, s = 0.55, v = 0.95, a = 0.55} + val niceBlue = Color.hsva {h = 240.0, s = 0.55, v = 0.95, a = 0.55} + + fun alphaGray a = + {red = 0.5, blue = 0.5, green = 0.5, alpha = a} + (* Color.hsva {h = 0.0, s = 0.0, v = 0.7, a = a} *) + + fun alphaRed a = + {red = 1.0, blue = 0.0, green = 0.0, alpha = a} + + val image = + { width = width + , height = height + , data = Seq.tabulate (fn _ => background) (width*height) + } + + fun set (i, j) x = + if 0 <= i andalso i < height andalso + 0 <= j andalso j < width + then ArraySlice.update (#data image, i*width + j, x) + else () + + fun setxy (x, y) z = + set (resolution - y - 1, x) z + + fun modify (i, j) f = + if 0 <= i andalso i < height andalso + 0 <= j andalso j < width + then + let + val k = i*width + j + val a = #data image + in + ArraySlice.update (a, k, f (ArraySlice.sub (a, k))) + end + else () + + fun modifyxy (x, y) f = + modify (resolution - y - 1, x) f + + fun overlay (x, y) color = + modifyxy (x, y) (fn bg => Color.overlayColor {fg = color, bg = bg}) + + val r = Real.fromInt resolution + fun px x = Real.floor (x * r + 0.5) + + fun vpos v = T.vdata mesh v + + fun ipart x = Real.floor x + fun fpart x = x - Real.realFloor x + fun rfpart x = 1.0 - fpart x + + (** input points should be in range [0,1] *) + fun aaLine colorFn (x0, y0) (x1, y1) = + if x1 < x0 then aaLine colorFn (x1, y1) (x0, y0) else 
+ let + (** scale to resolution *) + val (x0, y0, x1, y1) = (r*x0 + 0.5, r*y0 + 0.5, r*x1 + 0.5, r*y1 + 0.5) + + fun plot (x, y, c) = + overlay (x, y) (colorFn c) + + val dx = x1-x0 + val dy = y1-y0 + val yxSlope = dy / dx + val xySlope = dx / dy + (* val xhop = Real.fromInt (Real.sign dx) *) + (* val yhop = Real.fromInt (Real.sign dy) *) + + (* fun y x = x0 + (x-x0) * slope *) + + (** (x,y) = current point on the line *) + fun normalLoop (x, y) = + if x > x1 then () else + ( plot (ipart x, ipart y , rfpart y) + ; plot (ipart x, ipart y + 1, fpart y) + ; normalLoop (x + 1.0, y + yxSlope) + ) + + fun steepUpLoop (x, y) = + if y > y1 then () else + ( plot (ipart x , ipart y, rfpart x) + ; plot (ipart x + 1, ipart y, fpart x) + ; steepUpLoop (x + xySlope, y + 1.0) + ) + + fun steepDownLoop (x, y) = + if y < y1 then () else + ( plot (ipart x , ipart y, rfpart x) + ; plot (ipart x + 1, ipart y, fpart x) + ; steepDownLoop (x - xySlope, y - 1.0) + ) + in + if Real.abs dx > Real.abs dy then + normalLoop (x0, y0) + else if y1 > y0 then + steepUpLoop (x0, y0) + else + steepDownLoop (x0, y0) + end + + fun adjust (x, y) = (r*x + 0.5, r*y + 0.5) + + fun fillTriangle color (p0, p1, p2) = + let + val (p0, p1, p2) = (adjust p0, adjust p1, adjust p2) + + (** min and max corners of bounding box *) + val (xlo, ylo) = List.foldl G.Point.minCoords p0 [p1, p2] + val (xhi, yhi) = List.foldl G.Point.maxCoords p0 [p1, p2] + + fun horizontalIntersect y = + let + val xa = xIntercept p0 p1 y + val xb = xIntercept p1 p2 y + val xc = xIntercept p0 p2 y + in + case sort3 rocmp (xa, xb, xc) of + (SOME xa, SOME xb, NONE) => SOME (xa, xb) + | _ => NONE + (* | (SOME xa, NONE, NONE) => (xa, xa) + | _ => raise Fail "MeshToImage.horizontalIntersect bug" *) + end + + fun loop y = + if y >= yhi then () else + let + val yy = ipart y + in + (case horizontalIntersect y of + SOME (xleft, xright) => + Util.for (ipart xleft, ipart xright + 1) + (fn xx => overlay (xx, yy) color) + | NONE => ()); + + loop 
(y+1.0) + end + in + loop (Real.realCeil ylo) + end + + in + (* draw all triangle edges as straight red lines *) + ForkJoin.parfor 1000 (0, T.numTriangles mesh) (fn i => + let + (** cut off anything that is outside the image (not important other than + * a little faster this way). + *) + (* fun constrain (x, y) = + (Real.min (1.0, Real.max (0.0, x)), Real.min (1.0, Real.max (0.0, y))) *) + (* fun vpos v = constrain (T.vdata mesh v) *) + + fun doLineIf b (u, v) = + if b then aaLine alphaGray (vpos u) (vpos v) else () + + val T.Tri {vertices=(u,v,w), neighbors=(a,b,c)} = T.tdata mesh i + in + (* skip "invalid" triangles *) + if u < 0 orelse v < 0 orelse w < 0 then () + else + (** This ensures that each line segment is only drawn once. The person + * responsible for drawing it is the triangle with larger id. + *) + ( doLineIf (i > a) (w, u) + ; doLineIf (i > b) (u, v) + ; doLineIf (i > c) (v, w) + ) + end); + + (* maybe fill in cavities *) + case cavities of NONE => () | SOME cavs => + ForkJoin.parfor 100 (0, Seq.length cavs) (fn i => + let + val (pt, (center, simps)) = Seq.nth cavs i + val triangles = center :: List.map (fn (t, _) => t) simps + + val perimeter = + List.map (T.vdata mesh) + ((let val (u,v,w) = T.verticesOfTriangle mesh center + in [u,v,w] + end) + @ + (List.map (fn s => T.firstVertex mesh (T.rotateClockwise s)) simps)) + + fun fillTri t = + let + val (v0,v1,v2) = T.verticesOfTriangle mesh t + val (p0,p1,p2) = (T.vdata mesh v0, T.vdata mesh v1, T.vdata mesh v2) + in + fillTriangle niceBlue (p0, p1, p2) + end + in + List.app fillTri triangles; + List.app (aaLine alphaRed pt) perimeter + end); + + (* mark input points as a pixel *) + ForkJoin.parfor 10000 (0, Seq.length points) (fn i => + let + val (x, y) = Seq.nth points i + val (x, y) = (px x, px y) + fun b spot = setxy spot black + in + (* skip "invalid" vertices *) + if T.triangleOfVertex mesh i < 0 then () + else + ( b (x-1, y) + ; b (x, y-1) + ; b (x, y) + ; b (x, y+1) + ; b (x+1, y) + ) + end); + 
+ { width = #width image + , height = #height image + , data = Seq.map Color.colorToPixel (#data image) + } + end + +end diff --git a/tests/mpllib/MkComplex.sml b/tests/mpllib/MkComplex.sml new file mode 100644 index 000000000..2ff3b9df2 --- /dev/null +++ b/tests/mpllib/MkComplex.sml @@ -0,0 +1,111 @@ +signature COMPLEX = +sig + structure R: REAL + type r = R.real + + type t + + val toString: t -> string + + val make: (r * r) -> t + val view: t -> (r * r) + + val defaultReal: Real.real -> t + val defaultImag: Real.real -> t + + val real: r -> t + val imag: r -> t + val rotateBy: r -> t (* rotateBy x = e^(ix) *) + + val zeroThreshold: r + val realIsZero: r -> bool + + val isZero: t -> bool + val isNonZero: t -> bool + + val zero: t + val i: t + + val magnitude: t -> r + + val ~ : t -> t + val - : t * t -> t + val + : t * t -> t + val * : t * t -> t + + val scale: r * t -> t +end + + +functor MkComplex(R: REAL): COMPLEX = +struct + structure R = R + open R + type r = real + + val fromLarge = fromLarge IEEEReal.TO_NEAREST + + datatype t = C of {re: real, im: real} + + val rtos = fmt (StringCvt.FIX (SOME 8)) + + fun toString (C {re, im}) = + let + val (front, re) = if Int.< (R.sign re, 0) then ("-", R.~ re) else ("", re) + val (middle, im) = + if Int.< (R.sign im, 0) then ("-", R.~ im) else ("+", im) + in + front ^ rtos re ^ middle ^ rtos im ^ "i" + end + + fun make (re, im) = C {re = re, im = im} + + fun view (C {re, im}) = (re, im) + + val zeroThreshold = fromLarge 0.00000001 + fun realIsZero x = R.abs x < zeroThreshold + + fun magnitude (C {re, im}) = + R.Math.sqrt (R.+ (R.* (re, re), R.* (im, im))) + + fun isZero (C {re, im}) = realIsZero re andalso realIsZero im + + fun isNonZero c = + not (isZero c) + + fun rotateBy r = + C {re = Math.cos r, im = Math.sin r} + + fun real r = + C {re = r, im = fromLarge 0.0} + fun imag i = + C {re = fromLarge 0.0, im = i} + + fun defaultReal r = + real (fromLarge r) + fun defaultImag r = + imag (fromLarge r) + + val zero = C {re = 
fromLarge 0.0, im = fromLarge 0.0} + val i = C {re = fromLarge 0.0, im = fromLarge 1.0} + + fun neg (C {re, im}) = + C {re = ~re, im = ~im} + + fun add (C x, C y) = + C {re = #re x + #re y, im = #im x + #im y} + + fun sub (C x, C y) = + C {re = #re x - #re y, im = #im x - #im y} + + fun mul (C {re = a, im = b}, C {re = c, im = d}) = + C {re = a * c - b * d, im = a * d + b * c} + + fun scale (r, C {re, im}) = + C {re = r * re, im = r * im} + + val ~ = neg + val op- = sub + val op+ = add + val op* = mul +end diff --git a/tests/mpllib/MkGrep.sml b/tests/mpllib/MkGrep.sml new file mode 100644 index 000000000..d1438ac8e --- /dev/null +++ b/tests/mpllib/MkGrep.sml @@ -0,0 +1,105 @@ +functor MkGrep (Seq: SEQUENCE) : +sig + val grep: char ArraySequence.t (* pattern *) + -> char ArraySequence.t (* source text *) + -> (int * int) ArraySequence.t (* output line ranges *) +end = +struct + + structure ASeq = ArraySequence + + type 'a seq = 'a ASeq.t + +(* + fun lines (s: char seq) : (char seq) Seq.seq = + let + val n = ASeq.length s + val indices = Seq.tabulate (fn i => i) n + fun isNewline i = (ASeq.nth s i = #"\n") + val locs = Seq.filter isNewline indices + val m = Seq.length locs + + fun line i = + let + val lo = (if i = 0 then 0 else 1 + Seq.nth locs (i-1)) + val hi = (if i = m then n else Seq.nth locs i) + in + ASeq.subseq s (lo, hi-lo) + end + in + Seq.tabulate line (m+1) + end +*) + + (* check if line[i..] matches the pattern *) + fun checkMatch pattern line i = + (i + ASeq.length pattern <= ASeq.length line) andalso + let + val m = ASeq.length pattern + (* pattern[j..] matches line[i+j..] 
*) + fun matchesFrom j = + (j >= m) orelse + ((ASeq.nth line (i+j) = ASeq.nth pattern j) andalso matchesFrom (j+1)) + in + matchesFrom 0 + end + +(* + fun grep pat source = + let + val granularity = CommandLineArgs.parseOrDefaultInt "granularity" 1000 + (* val ff = FindFirst.findFirst granularity *) + val ff = FindFirst.findFirstSerial + fun containsPat line = + case ff (0, ASeq.length line) (checkMatch pat line) of + NONE => false + | SOME _ => true + + val linesWithPat = Seq.filter containsPat (lines source) + val newln = Seq.singleton #"\n" + + fun choose i = + if Util.even i + then Seq.fromArraySeq (Seq.nth linesWithPat (i div 2)) + else newln + in + Seq.toArraySeq (Seq.flatten (Seq.tabulate choose (2 * Seq.length linesWithPat))) + end +*) + + fun isNewline c = (c = #"\n") + val ff = FindFirst.findFirst 1000 + + fun grep pat s = + let + fun makeLine (start, stop) = ASeq.subseq s (start, stop-start) + fun containsPat (start, stop) = + case ff (0, stop-start) (checkMatch pat (makeLine (start, stop))) of + NONE => NONE + | SOME _ => SOME (start, stop) + + val s = Seq.fromArraySeq s + val n = Seq.length s + + val idx = Seq.filter (isNewline o Seq.nth s) (Seq.tabulate (fn i => i) n) + (* val idx = + Seq.mapOption + (fn i => if isNewline (Seq.nth s i) then SOME i else NONE) + (Seq.tabulate (fn i => i) n) *) + + val m = Seq.length idx + + fun line i = + let + val start = if i = 0 then 0 else Seq.nth idx (i-1) + val stop = if i = m then n else Seq.nth idx i + in + (start, stop) + end + + in + (* Seq.toArraySeq (Seq.filter containsPat (Seq.tabulate line (m+1))) *) + Seq.toArraySeq (Seq.mapOption containsPat (Seq.tabulate line (m+1))) + end + +end diff --git a/tests/mpllib/NearestNeighbors.sml b/tests/mpllib/NearestNeighbors.sml new file mode 100644 index 000000000..46889ed7f --- /dev/null +++ b/tests/mpllib/NearestNeighbors.sml @@ -0,0 +1,298 @@ +structure NearestNeighbors: +sig + type point = Geometry2D.point + type 'a seq = 'a ArraySlice.slice + + type tree + type t = 
tree * point seq + + (* makeTree leafSize points *) + val makeTree: int -> point seq -> t + + val nearestNeighbor: t -> point -> int (* id of nearest neighbor *) + val nearestNeighborOfId: t -> int -> int + + (* allNearestNeighbors grain quadtree *) + val allNearestNeighbors: int -> t -> int seq +end = +struct + + structure A = Array + structure AS = ArraySlice + + type 'a seq = 'a ArraySlice.slice + structure G = Geometry2D + type point = G.point + + fun par4 (a, b, c, d) = + let + val ((ar, br), (cr, dr)) = + ForkJoin.par (fn _ => ForkJoin.par (a, b), + fn _ => ForkJoin.par (c, d)) + in + (ar, br, cr, dr) + end + + datatype tree = + Leaf of { anchor : point + , width : real + , vertices : int seq (* indices of original point seq *) + } + | Node of { anchor : point + , width : real + , count : int + , children : tree seq + } + + type t = tree * point seq + + fun count t = + case t of + Leaf {vertices, ...} => AS.length vertices + | Node {count, ...} => count + + fun anchor t = + case t of + Leaf {anchor, ...} => anchor + | Node {anchor, ...} => anchor + + fun width t = + case t of + Leaf {width, ...} => width + | Node {width, ...} => width + + fun boxOf t = + case t of + Leaf {anchor=(x,y), width, ...} => (x, y, x+width, y+width) + | Node {anchor=(x,y), width, ...} => (x, y, x+width, y+width) + + fun indexApp grain f t = + let + fun downSweep offset t = + case t of + Leaf {vertices, ...} => + AS.appi (fn (i, v) => f (offset + i, v)) vertices + | Node {children, ...} => + let + fun q i = AS.sub (children, i) + fun qCount i = count (q i) + val offset0 = offset + val offset1 = offset0 + qCount 0 + val offset2 = offset1 + qCount 1 + val offset3 = offset2 + qCount 2 + in + if count t <= grain then + ( downSweep offset0 (q 0) + ; downSweep offset1 (q 1) + ; downSweep offset2 (q 2) + ; downSweep offset3 (q 3) + ) + else + ( par4 + ( fn _ => downSweep offset0 (q 0) + , fn _ => downSweep offset1 (q 1) + , fn _ => downSweep offset2 (q 2) + , fn _ => downSweep offset3 (q 3) 
+ ) + ; () + ) + end + in + downSweep 0 t + end + + fun indexMap grain f t = + let + val result = ForkJoin.alloc (count t) + val _ = indexApp grain (fn (i, v) => A.update (result, i, f (i, v))) t + in + AS.full result + end + + fun flatten grain t = indexMap grain (fn (_, v) => v) t + + (* val lowerTime = ref Time.zeroTime + val upperTime = ref Time.zeroTime + fun addTm r t = + if Primitives.numberOfProcessors = 1 then + r := Time.+ (!r, t) + else () + fun clearAndReport r name = + (print (name ^ " " ^ Time.fmt 4 (!r) ^ "\n"); r := Time.zeroTime) *) + + (* Make a tree where all points are in the specified bounding box. *) + fun makeTreeBounded leafSize (verts : point seq) (idx : int seq) ((xLeft, yBot) : G.point) width = + if AS.length idx <= leafSize then + Leaf { anchor = (xLeft, yBot) + , width = width + , vertices = idx + } + else let + val qw = width/2.0 (* quadrant width *) + val center = (xLeft + qw, yBot + qw) + + val ((sorted, offsets), tm) = Util.getTime (fn () => + CountingSort.sort idx (fn i => + G.quadrant center (Seq.nth verts (Seq.nth idx i))) 4) + + (* val _ = + if AS.length idx >= 4 * leafSize then + addTm upperTime tm + else + addTm lowerTime tm *) + + fun quadrant i = + let + val start = AS.sub (offsets, i) + val len = AS.sub (offsets, i+1) - start + val childIdx = AS.subslice (sorted, start, SOME len) + val qAnchor = + case i of + 0 => (xLeft + qw, yBot + qw) + | 1 => (xLeft, yBot + qw) + | 2 => (xLeft, yBot) + | _ => (xLeft + qw, yBot) + in + makeTreeBounded leafSize verts childIdx qAnchor qw + end + + (* val children = Seq.tabulate (Perf.grain 1) quadrant 4 *) + val (a, b, c, d) = + if AS.length idx <= 100 then + (quadrant 0, quadrant 1, quadrant 2, quadrant 3) + else + par4 + ( fn _ => quadrant 0 + , fn _ => quadrant 1 + , fn _ => quadrant 2 + , fn _ => quadrant 3 ) + val children = AS.full (Array.fromList [a,b,c,d]) + in + Node { anchor = (xLeft, yBot) + , width = width + , count = AS.length idx + , children = children + } + end + + fun loop 
(lo, hi) b f = + if (lo >= hi) then b else loop (lo+1, hi) (f (b, lo)) f + + fun reduce grain f b (get, lo, hi) = + if hi - lo <= grain then + loop (lo, hi) b (fn (b, i) => f (b, get i)) + else let + val mid = lo + (hi-lo) div 2 + val (l,r) = ForkJoin.par + ( fn _ => reduce grain f b (get, lo, mid) + , fn _ => reduce grain f b (get, mid, hi) + ) + in + f (l, r) + end + + fun makeTree leafSize (verts : point seq) = + if AS.length verts = 0 then raise Fail "makeTree with 0 points" else + let + (* calculate the bounding box *) + fun maxPt ((x1,y1),(x2,y2)) = (Real.max (x1, x2), Real.max (y1, y2)) + fun minPt ((x1,y1),(x2,y2)) = (Real.min (x1, x2), Real.min (y1, y2)) + fun getPt i = Seq.nth verts i + val (xLeft,yBot) = reduce 10000 minPt (Real.posInf, Real.posInf) (getPt, 0, AS.length verts) + val (xRight,yTop) = reduce 10000 maxPt (Real.negInf, Real.negInf) (getPt, 0, AS.length verts) + val width = Real.max (xRight-xLeft, yTop-yBot) + + val idx = Seq.tabulate (fn i => i) (Seq.length verts) + val result = makeTreeBounded leafSize verts idx (xLeft, yBot) width + in + (* clearAndReport upperTime "upper sort time"; *) + (* clearAndReport lowerTime "lower sort time"; *) + (result, verts) + end + + (* ======================================================================== *) + + fun constrain (x : real) (lo, hi) = + if x < lo then lo + else if x > hi then hi + else x + + fun distanceToBox (x,y) (xLeft, yBot, xRight, yTop) = + G.distance (x,y) (constrain x (xLeft, xRight), constrain y (yBot, yTop)) + + val dummyBest = (~1, Real.posInf) + + + (** The function isSamePoint given as argument indicates whether or not some + * other index is the same as the input point p. This is important for + * querying nearest neighbors of points already in the set, for example + * nearestNeighborOfId below. For query points outside of the set, + * isSamePoint can always return false. 
+ *) + fun nearestNeighbor_ (t : tree, pts) (p: G.point, isSamePoint: int -> bool) = + let + fun pt i = Seq.nth pts i + + fun refineNearest (qi, (bestPt, bestDist)) = + if isSamePoint qi then (bestPt, bestDist) else + let + val qDist = G.distance p (pt qi) + in + if qDist < bestDist + then (qi, qDist) + else (bestPt, bestDist) + end + + fun search (best as (_, bestDist : real)) t = + if distanceToBox p (boxOf t) > bestDist then best else + case t of + Leaf {vertices, ...} => + AS.foldl refineNearest best vertices + | Node {anchor=(x,y), width, children, ...} => + let + val qw = width/2.0 + val center = (x+qw, y+qw) + + (* search the quadrant that p is in first *) + val heuristicOrder = + case G.quadrant center p of + 0 => [0,1,2,3] + | 1 => [1,0,2,3] + | 2 => [2,1,3,0] + | _ => [3,0,2,1] + + fun child i = AS.sub (children, i) + fun refine (i, best) = search best (child i) + in + List.foldl refine best heuristicOrder + end + + val (best, _) = search dummyBest t + in + best + end + + + fun nearestNeighborOfId (tree, pts) pi = + nearestNeighbor_ (tree, pts) (Seq.nth pts pi, fn qi => pi = qi) + + fun nearestNeighbor (tree, pts) p = + nearestNeighbor_ (tree, pts) (p, fn _ => false) + + + fun allNearestNeighbors grain (t, pts) = + let + val n = Seq.length pts + val idxs = flatten 10000 t + val nn = ForkJoin.alloc n + in + ForkJoin.parfor grain (0, n) (fn i => + let + val j = Seq.nth idxs i + in + A.update (nn, j, nearestNeighborOfId (t, pts) j) + end); + AS.full nn + end + +end diff --git a/tests/mpllib/NewWaveIO.sml b/tests/mpllib/NewWaveIO.sml new file mode 100644 index 000000000..d066c246d --- /dev/null +++ b/tests/mpllib/NewWaveIO.sml @@ -0,0 +1,262 @@ +structure NewWaveIO: +sig + (* A sound is a sequence of samples at the given + * sample rate, sr (measured in Hz). + * Each sample is in range [-1.0, +1.0]. *) + type sound = {sr: int, data: real Seq.t} + + val readSound: string -> sound + val writeSound: sound -> string -> unit + + (* Essentially mu-law compression. 
Normalizes to [-1,+1] and compresses + * the dynamic range slightly. The boost parameter should be >= 1. *) + val compress: real -> sound -> sound +end = +struct + + type sound = {sr: int, data: real Seq.t} + + structure AS = ArraySlice + + fun err msg = + raise Fail ("NewWaveIO: " ^ msg) + + fun compress boost (snd as {sr, data}: sound) = + if boost < 1.0 then + err "Compression boost parameter must be at least 1" + else + let + (* maximum amplitude *) + val maxA = + SeqBasis.reduce 10000 Real.max 1.0 (0, Seq.length data) + (fn i => Real.abs (Seq.nth data i)) + + (* a little buffer of intensity to avoid distortion *) + val maxA' = 1.05 * maxA + + val scale = Math.ln (1.0 + boost) + + fun transfer x = + let + (* normalized *) + val x' = Real.abs (x / maxA') + in + (* compressed *) + Real.copySign (Math.ln (1.0 + boost * x') / scale, x) + end + in + { sr = sr + , data = Seq.map transfer data + } + end + + fun readSound path = + let + val bytes = ReadFile.contentsBinSeq path + + fun findChunk chunkId offset = + if offset > (Seq.length bytes - 8) then + err "unexpected end of file" + else if Parse.r32b bytes offset = chunkId then + offset (* found it! 
*) + else + let + val chunkSize = Word32.toInt (Parse.r32l bytes (offset+4)) + val chunkName = + CharVector.tabulate (4, fn i => + Char.chr (Word8.toInt (Seq.nth bytes (offset+i)))) + in + if chunkSize < 0 then + err ("error parsing chunk size of '" ^ chunkName ^ "' chunk") + else + findChunk chunkId (offset + 8 + chunkSize) + end + + (* ======================================================= + * RIFF header, 12 bytes + *) + + val _ = + if Seq.length bytes >= 12 then () + else err "not enough bytes for RIFF header" + + val offset = 0 + val riff = 0wx52494646 (* ascii "RIFF", big endian *) + val _ = + if Parse.r32b bytes offset = riff then () + else err "expected 'RIFF' chunk ID" + + (* the chunkSize should be the size of the "rest of the file" *) + val offset = 4 + val chunkSize = Word32.toInt (Parse.r32l bytes offset) + + val totalFileSize = 8 + chunkSize + val _ = + if Seq.length bytes >= totalFileSize then () + else err ("expected " ^ Int.toString totalFileSize ^ + " bytes but the file is only " ^ + Int.toString (Seq.length bytes)) + + val offset = 8 + val wave = 0wx57415645 (* ascii "WAVE" big endian *) + val _ = + if Parse.r32b bytes offset = wave then () + else err "expected 'WAVE' format" + + val offset = 12 + + (* ======================================================= + * fmt subchunk, should be at least 8+16 bytes total for PCM + *) + + val fmtId = 0wx666d7420 (* ascii "fmt " big endian *) + val fmtChunkStart = findChunk fmtId offset + val offset = fmtChunkStart + + val _ = + if Parse.r32b bytes offset = fmtId then () + else err "expected 'fmt ' chunk ID" + + val offset = offset+4 + val fmtChunkSize = Word32.toInt (Parse.r32l bytes offset) + val _ = + if fmtChunkSize >= 16 then () + else err "expected 'fmt ' chunk to be at least 16 bytes" + + val offset = offset+4 + val audioFormat = Word16.toInt (Parse.r16l bytes offset) + val _ = + if audioFormat = 1 then () + else err ("expected PCM audio format, but found 0x" + ^ Int.fmt StringCvt.HEX audioFormat) + + 
val offset = offset+2 + val numChannels = Word16.toInt (Parse.r16l bytes offset) + + val offset = offset+2 + val sampleRate = Word32.toInt (Parse.r32l bytes offset) + + val offset = offset+4 + val byteRate = Word32.toInt (Parse.r32l bytes offset) + + val offset = offset+4 + val blockAlign = Word16.toInt (Parse.r16l bytes offset) + + val offset = offset+2 + val bitsPerSample = Word16.toInt (Parse.r16l bytes offset) + val bytesPerSample = bitsPerSample div 8 + + val offset = fmtChunkStart+8+fmtChunkSize + + (* ======================================================= + * data subchunk, should be the rest of the file + *) + + val dataId = 0wx64617461 (* ascii "data" big endian *) + val dataChunkStart = findChunk dataId offset + val offset = dataChunkStart + + val _ = + if Parse.r32b bytes offset = dataId then () + else err "expected 'data' chunk ID" + + val offset = offset + 4 + val dataSize = Word32.toInt (Parse.r32l bytes offset) + val _ = + if dataChunkStart + 8 + dataSize <= totalFileSize then () + else err ("badly formatted data chunk: unexpected end-of-file") + + val dataStart = dataChunkStart + 8 + + val numSamples = (dataSize div numChannels) div bytesPerSample + + fun readSample8 pos = + Real.fromInt (Word8.toInt (Seq.nth bytes pos) - 128) / 256.0 + fun readSample16 pos = + Real.fromInt (Word16.toIntX (Parse.r16l bytes pos)) / 32768.0 + + val readSample = + case bytesPerSample of + 1 => readSample8 + | 2 => readSample16 + | _ => err "only 8-bit and 16-bit samples supported at the moment" + + (* jth sample of ith channel *) + fun readChannel i j = + readSample (dataStart + j * (numChannels * bytesPerSample) + + i * bytesPerSample) + + val rawData = + AS.full (SeqBasis.tabulate 1000 (0, numSamples) (fn j => + Util.loop (0, numChannels) 0.0 (fn (s, i) => s + readChannel i j))) + + val rawResult = {sr = sampleRate, data = rawData} + in + if numChannels = 1 then + rawResult + else + ( TextIO.output (TextIO.stdErr, + "[WARN] mixing " ^ Int.toString numChannels + ^ " 
channels down to mono\n") + ; compress 1.0 rawResult + ) + end + + (* ====================================================================== *) + + fun writeSound ({sr, data}: sound) path = + let + val srw = Word32.fromInt sr + + val file = BinIO.openOut path + + val w32b = ExtraBinIO.w32b file + val w32l = ExtraBinIO.w32l file + val w16l = ExtraBinIO.w16l file + + val totalBytes = + 44 + (Seq.length data * 2) + + val riffId = 0wx52494646 (* ascii "RIFF", big endian *) + val fmtId = 0wx666d7420 (* ascii "fmt " big endian *) + val wave = 0wx57415645 (* ascii "WAVE" big endian *) + val dataId = 0wx64617461 (* ascii "data" big endian *) + in + (* ============================ + * RIFF header, 12 bytes *) + w32b riffId; + w32l (Word32.fromInt (totalBytes - 8)); + w32b wave; + + (* ============================ + * fmt subchunk, 24 bytes *) + w32b fmtId; + w32l 0w16; (* 16 remaining bytes in subchunk *) + w16l 0w1; (* audio format PCM = 1 *) + w16l 0w1; (* 1 channel (mono) *) + w32l srw; (* sample rate *) + w32l (srw * 0w2); (* "byte rate" = sampleRate * numChannels * bytesPerSample *) + w16l 0w2; (* "block align" = numChannels * bytesPerSample *) + w16l 0w16; (* bits per sample *) + + (* ============================ + * data subchunk: rest of file *) + w32b dataId; + w32l (Word32.fromInt (2 * Seq.length data)); (* number of data bytes *) + + Util.for (0, Seq.length data) (fn i => + let + val s = Seq.nth data i + val s = + if s < ~1.0 then ~1.0 + else if s > 1.0 then 1.0 + else s + val s = Real.round (s * 32767.0) + val s = if s < 0 then s + 65536 else s + in + w16l (Word16.fromInt s) + end); + + BinIO.closeOut file + end +end diff --git a/tests/mpllib/OffsetSearch.sml b/tests/mpllib/OffsetSearch.sml new file mode 100644 index 000000000..32b56c9a2 --- /dev/null +++ b/tests/mpllib/OffsetSearch.sml @@ -0,0 +1,26 @@ +structure OffsetSearch: +sig + (** `indexSearch (start, stop, offsetFn) k` returns which inner sequence + * contains index `k`. 
The tuple arg defines a sequence of offsets. + *) + val indexSearch: int * int * (int -> int) -> int -> int +end = +struct + + fun indexSearch (start, stop, offset: int -> int) k = + case stop-start of + 0 => + raise Fail "OffsetSearch.indexSearch: should not have hit 0" + | 1 => + start + | n => + let + val mid = start + (n div 2) + in + if k < offset mid then + indexSearch (start, mid, offset) k + else + indexSearch (mid, stop, offset) k + end + +end diff --git a/tests/mpllib/OldDelayedSeq.sml b/tests/mpllib/OldDelayedSeq.sml new file mode 100644 index 000000000..e78ec6304 --- /dev/null +++ b/tests/mpllib/OldDelayedSeq.sml @@ -0,0 +1,922 @@ +structure OldDelayedSeq = +struct + + val for = Util.for + + val par = ForkJoin.par + val parfor = ForkJoin.parfor + val alloc = ForkJoin.alloc + + val gran = 5000 + + + val blockSize = 10000 + fun numBlocks n = Util.ceilDiv n blockSize + + + structure A = + struct + open Array + type 'a t = 'a array + fun nth a i = sub (a, i) + end + + + structure AS = + struct + open ArraySlice + type 'a t = 'a slice + fun nth a i = sub (a, i) + end + + + (* Using given offsets, find which inner sequence contains index [k] *) + fun indexSearch (start, stop, offset: int -> int) k = + case stop-start of + 0 => + raise Fail "OldDelayedSeq.indexSearch: should not have hit 0" + | 1 => + start + | n => + let + val mid = start + (n div 2) + in + if k < offset mid then + indexSearch (start, mid, offset) k + else + indexSearch (mid, stop, offset) k + end + + + (* ======================================================================= *) + + + structure Stream:> + sig + type 'a t + type 'a stream = 'a t + + val tabulate: (int -> 'a) -> 'a stream + val map: ('a -> 'b) -> 'a stream -> 'b stream + val mapIdx: (int * 'a -> 'b) -> 'a stream -> 'b stream + val zipWith: ('a * 'b -> 'c) -> 'a stream * 'b stream -> 'c stream + val iteratePrefixes: ('b * 'a -> 'b) -> 'b -> 'a stream -> 'b stream + val iteratePrefixesIncl: ('b * 'a -> 'b) -> 'b -> 'a stream -> 
'b stream + + val applyIdx: int * 'a stream -> (int * 'a -> unit) -> unit + val iterate: ('b * 'a -> 'b) -> 'b -> int * 'a stream -> 'b + + val makeBlockStreams: + { numChildren: int + , offset: int -> int + , getElem: int -> int -> 'a + } + -> (int -> 'a stream) + + end = + struct + + (** A stream is a generator for a stateful trickle function: + * trickle = stream () + * x0 = trickle 0 + * x1 = trickle 1 + * x2 = trickle 2 + * ... + * + * The integer argument is just an optimization (it could be packaged + * up into the state of the trickle function, but doing it this + * way is more efficient). Requires passing `i` on the ith call + * to trickle. + *) + type 'a t = unit -> int -> 'a + type 'a stream = 'a t + + + fun tabulate f = + fn () => f + + + fun map g stream = + fn () => + let + val trickle = stream () + in + g o trickle + end + + + fun mapIdx g stream = + fn () => + let + val trickle = stream () + in + fn idx => g (idx, trickle idx) + end + + + fun applyIdx (length, stream) g = + let + val trickle = stream () + fun loop i = + if i >= length then () else (g (i, trickle i); loop (i+1)) + in + loop 0 + end + + + fun iterate g b (length, stream) = + let + val trickle = stream () + fun loop b i = + if i >= length then b else loop (g (b, trickle i)) (i+1) + in + loop b 0 + end + + + fun iteratePrefixes g b stream = + fn () => + let + val trickle = stream () + val stuff = ref b + in + fn idx => + let + val acc = !stuff + val elem = trickle idx + val acc' = g (acc, elem) + in + stuff := acc'; + acc + end + end + + + fun iteratePrefixesIncl g b stream = + fn () => + let + val trickle = stream () + val stuff = ref b + in + fn idx => + let + val acc = !stuff + val elem = trickle idx + val acc' = g (acc, elem) + in + stuff := acc'; + acc' + end + end + + + fun zipWith g (s1, s2) = + fn () => + let + val trickle1 = s1 () + val trickle2 = s2 () + in + fn idx => g (trickle1 idx, trickle2 idx) + end + + + fun makeBlockStreams + { numChildren: int + , offset: int -> int + 
, getElem: int -> int -> 'a + } = + let + fun getBlock blockIdx = + let + val lo = blockIdx * blockSize + val firstOuterIdx = indexSearch (0, numChildren, offset) lo + (* val firstInnerIdx = lo - offset firstOuterIdx *) + + fun advanceUntilNonEmpty i = + if i >= numChildren orelse offset i <> offset (i+1) then + i + else + advanceUntilNonEmpty (i+1) + in + fn () => + let + val outerIdx = ref firstOuterIdx + (* val innerIdx = ref firstInnerIdx *) + in + fn idx => + let + val i = !outerIdx + val j = lo + idx - offset i + (* val j = !innerIdx *) + val elem = getElem i j + in + if offset i + j + 1 < offset (i+1) then + (* innerIdx := j+1 *) () + else + ( outerIdx := advanceUntilNonEmpty (i+1) + (* ; innerIdx := 0 *) + ); + + elem + end + end + end + + in + getBlock + end + + + end + + + (* ======================================================================= *) + + + datatype 'a flat = + Full of 'a AS.t + + (** [Delay (start, stop, lookup)] + * start index is inclusive, stop is exclusive + * length is [stop-start] + * the ith element lives at [lookup (start+i)] + *) + | Delay of int * int * (int -> 'a) + + + datatype 'a seq = + Flat of 'a flat + + (** [Nest (length, getBlock)] + * The block sizes are implicit: [gran] + * The number of block is implicit: ceil(length / gran) + *) + | Nest of int * (int -> 'a Stream.t) + + + fun makeBlocks s = + case s of + Flat (Full slice) => + let + fun blockStream b = + Stream.tabulate (fn i => AS.nth slice (b*blockSize + i)) + in + blockStream + end + | Flat (Delay (start, _, f)) => + let + fun blockStream b = + Stream.tabulate (fn i => f (start + b*blockSize + i)) + in + blockStream + end + | Nest (_, blockStream) => + blockStream + + + fun subseq s (i, k) = + case s of + Flat (Delay (start, stop, f)) => Flat (Delay (start+i, start+i+k, f)) + | Flat (Full slice) => Flat (Full (AS.subslice (slice, i, SOME k))) + | _ => raise Fail "delay subseq (Nest) not implemented yet" + + + fun flatNth (s: 'a flat) k = + case s of + Full slice 
=> + AS.nth slice k + | Delay (i, j, f) => + f (i+k) + + + fun nth (s: 'a seq) k = + case s of + Flat xs => + flatNth xs k + | Nest _ => + raise Fail "delay nth (Nest) not implemented yet" + + + fun flatLength s = + case s of + Full slice => + AS.length slice + | Delay (i, j, f) => + j-i + + + fun length s = + case s of + Flat xs => + flatLength xs + | Nest (n, _) => + n + + + fun flatSubseqIdxs s (i, j) = + case s of + Full slice => + Full (AS.subslice (slice, i, SOME (j-i))) + | Delay (start, stop, f) => + Delay (start+i, start+j, f) + + + fun flatIterateIdx (g: 'b * (int * 'a) -> 'b) (b: 'b) s = + case s of + Full slice => + SeqBasis.foldl g b (0, AS.length slice) (fn i => (i, AS.nth slice i)) + | Delay (i, j, f) => + SeqBasis.foldl g b (i, j) (fn k => (k-i, f k)) + + + fun iterateIdx g b s = + case s of + Flat xs => + flatIterateIdx g b xs + | Nest _ => + raise Fail "delay iterateIdx (Nest) NYI" + + + fun flatIterate g b s = + flatIterateIdx (fn (b, (_, x)) => g (b, x)) b s + + + fun iterate g b s = + iterateIdx (fn (b, (_, x)) => g (b, x)) b s + + + fun applyIdx (s: 'a seq) (g: int * 'a -> unit) = + case s of + Flat (Full slice) => + parfor gran (0, AS.length slice) (fn i => g (i, AS.nth slice i)) + | Flat (Delay (i, j, f)) => + parfor gran (0, j-i) (fn k => g (k, f (i+k))) + | Nest (n, getBlock) => + parfor 1 (0, numBlocks n) (fn i => + let + val lo = i*blockSize + val hi = Int.min (lo+blockSize, n) + in + Stream.applyIdx (hi-lo, getBlock i) (fn (j, x) => g (lo+j, x)) + end) + + + fun apply (s: 'a seq) (g: 'a -> unit) = + applyIdx s (fn (_, x) => g x) + + + fun unravelAndCopy (s: 'a seq): 'a array = + let + val n = length s + val result = alloc n + in + applyIdx s (fn (i, x) => A.update (result, i, x)); + result + end + + + fun force s = + case s of + Flat (Full _) => s + | _ => Flat (Full (AS.full (unravelAndCopy s))) + + + fun forceFlat s = + case s of + Flat xx => xx + | _ => Full (AS.full (unravelAndCopy s)) + + + fun tabulate f n = + Flat (Delay (0, n, 
f))
+
+
+  fun fromList xs =
+    Flat (Full (AS.full (Array.fromList xs)))
+
+
+  fun % xs =
+    fromList xs
+
+
+  fun singleton x =
+    Flat (Delay (0, 1, fn _ => x))
+
+
+  fun $ x =
+    singleton x
+
+
+  fun empty () =
+    fromList []
+
+
+  fun fromArraySeq a =
+    Flat (Full a)
+
+
+  fun range (i, j) =
+    Flat (Delay (i, j, fn k => k))
+
+
+  fun toArraySeq s =
+    case s of
+      Flat (Full x) => x
+    | _ => AS.full (unravelAndCopy s)
+
+
+  fun flatMap g s =
+    case s of
+      Full slice =>
+        Delay (0, AS.length slice, g o AS.nth slice)
+    | Delay (i, j, f) =>
+        Delay (i, j, g o f)
+
+
+  fun map g s =
+    case s of
+      Flat xs =>
+        Flat (flatMap g xs)
+    | Nest (n, getBlock) =>
+        Nest (n, Stream.map g o getBlock)
+
+
+  fun flatMapIdx g s =
+    case s of
+      Full slice =>
+        Delay (0, AS.length slice, fn i => g (i, AS.nth slice i))
+    | Delay (i, j, f) =>
+        Delay (i, j, fn k => g (k-i, f k)) (* BUG FIX: was `f i`, which applied f to the constant start index for every element; element at k must be `f k`, cf. flatIterateIdx *)
+
+
+  fun mapIdx g s =
+    case s of
+      Flat xs =>
+        Flat (flatMapIdx g xs)
+    | Nest (n, getBlock) =>
+        Nest (n, fn i =>
+          Stream.mapIdx (fn (j, x) => g (i*blockSize + j, x)) (getBlock i))
+
+
+  fun enum s =
+    mapIdx (fn (i,x) => (i,x)) s
+
+
+  fun flatten (ss: 'a seq seq) =
+    let
+      val numChildren = length ss
+      val children: 'a flat array = unravelAndCopy (map forceFlat ss)
+      val offsets =
+        SeqBasis.scan gran op+ 0 (0, numChildren) (flatLength o A.nth children)
+      val totalLen = A.nth offsets numChildren
+      fun offset i = A.nth offsets i
+      val getBlock =
+        Stream.makeBlockStreams
+          { numChildren = numChildren
+          , offset = offset
+          , getElem = (fn i => fn j => flatNth (A.nth children i) j)
+          }
+    in
+      Nest (totalLen, getBlock)
+    end
+
+
+  fun flatRev s =
+    case s of
+      Full slice =>
+        let
+          val n = AS.length slice
+        in
+          Delay (0, n, fn i => AS.nth slice (n-i-1))
+        end
+    | Delay (i, j, f) =>
+        Delay (i, j, fn k => f (j-k-1))
+
+
+  fun rev s =
+    case s of
+      Flat xs =>
+        Flat (flatRev xs)
+    | Nest _ =>
+        raise Fail "delay rev (Nest) NYI"
+
+
+  fun reduceG newGran g b s =
+    case s of
+      Flat (Full slice) =>
+        
SeqBasis.reduce newGran g b (0, AS.length slice) (AS.nth slice) + | Flat (Delay (i, j, f)) => + SeqBasis.reduce newGran g b (i, j) f + | Nest (n, getBlock) => + let + val nb = numBlocks n + fun len i = + if i < nb-1 then blockSize else n - i*blockSize + in + SeqBasis.reduce 1 g b (0, nb) (fn i => + Stream.iterate g b (len i, getBlock i)) + end + + + fun reduce g b s = + reduceG gran g b s + + + (** ======================================================================= + * mapOption implementations + * + * first one delays output (like flattening) + * second one has eager output + *) + + fun mapOption1 (f: 'a -> 'b option) (s: 'a seq) : 'b seq = + let + val n = length s + val nb = numBlocks n + val getBlock: int -> 'a Stream.t = + makeBlocks s + + val results: 'b array = alloc n + + fun packBlock b = + let + val start = b*blockSize + val stop = Int.min (start+blockSize, n) + val size = stop-start + + fun doNext (off, x) = + case f x of + NONE => off + | SOME x' => (A.update (results, off, x'); off+1) + + val lastOffset = + Stream.iterate doNext start (size, getBlock b) + in + lastOffset - start + end + + val counts = SeqBasis.tabulate 1 (0, nb) packBlock + val offsets = SeqBasis.scan 10000 op+ 0 (0, A.length counts) (A.nth counts) + + val totalLen = A.nth offsets nb + fun offset i = A.nth offsets i + val getBlock = + Stream.makeBlockStreams + { numChildren = nb + , offset = offset + , getElem = (fn i => fn j => A.nth results (i*blockSize + j)) + } + in + Nest (totalLen, getBlock) + end + + + fun mapOption2 (f: 'a -> 'b option) (s: 'a seq) = + let + val n = length s + val nb = numBlocks n + val getBlock: int -> 'a Stream.t = + makeBlocks s + + val results: 'b array = alloc n + + fun packBlock b = + let + val start = b*blockSize + val stop = Int.min (start+blockSize, n) + val size = stop-start + + fun doNext (off, x) = + case f x of + NONE => off + | SOME x' => (A.update (results, off, x'); off+1) + + val lastOffset = + Stream.iterate doNext start (size, getBlock b) 
+ in + lastOffset - start + end + + val counts = SeqBasis.tabulate 1 (0, nb) packBlock + val outOff = SeqBasis.scan 10000 op+ 0 (0, A.length counts) (A.nth counts) + val outSize = A.sub (outOff, nb) + + val result = alloc outSize + in + parfor (n div (Int.max (outSize, 1))) (0, nb) (fn i => + let + val soff = i * blockSize + val doff = A.sub (outOff, i) + val size = A.sub (outOff, i+1) - doff + in + Util.for (0, size) (fn j => + A.update (result, doff+j, A.sub (results, soff+j))) + end); + + Flat (Full (ArraySlice.full result)) + end + + + fun mapOption f s = + (* mapOption1 f s *) + mapOption2 f s + + + fun filterIdx p s = + case s of + Flat (Full slice) => + Flat (Full (AS.full (SeqBasis.filter gran + (0, AS.length slice) (* range *) + (AS.nth slice) (* index lookup *) + (fn k => p (k, AS.nth slice k)) (* index predicate *) + ))) + + | Flat (Delay (i, j, f)) => + Flat (Full (AS.full (SeqBasis.filter gran + (i, j) + f + (fn k => p (k, f k)) + ))) + + | _ => + filterIdx p (force s) + + + fun filter p s = + filterIdx (fn (_, x) => p x) s + + + fun inject (s, u) = + let + val base = unravelAndCopy s + in + apply u (fn (i, x) => A.update (base, i, x)); + Flat (Full (AS.full base)) + end + + + fun injectG _ (s, u) = + let + val base = unravelAndCopy s + in + apply u (fn (i, x) => A.update (base, i, x)); + Flat (Full (AS.full base)) + end + + + fun toList s = + iterate (fn (xs, x) => x :: xs) [] (rev s) + + + fun toString f s = + "<" ^ + String.concatWith "," + (iterate (fn (strs, next) => next :: strs) [] (map f (rev s))) + ^ ">" + + + fun append (s, t) = + flatten (tabulate (fn 0 => s | _ => t) 2) + + + (* Do the scan on a flat delayed sequence [f(i), f(i+1), ..., f(j-1)] *) + fun scanDelay g b (i, j, f) = + let + val n = j-i + val nb = numBlocks n + + val blockSums = + SeqBasis.tabulate 1 (0, nb) (fn blockIdx => + let + val blockStart = i + blockIdx*blockSize + val blockEnd = Int.min (j, blockStart + blockSize) + in + SeqBasis.foldl g b (blockStart, blockEnd) f + 
end) + + val partials = + SeqBasis.scan gran g b (0, nb) (A.nth blockSums) + + val total = A.nth partials nb + + fun getChild blockIdx = + let + val firstElem = A.nth partials blockIdx + val blockStart = i + blockIdx*blockSize + in + Stream.iteratePrefixes g firstElem + (Stream.tabulate (fn k => f (blockStart+k))) + end + in + ( Nest (n, getChild) + , total + ) + end + + + fun scanInclDelay g b (i, j, f) = + let + val n = j-i + val nb = numBlocks n + + val blockSums = + SeqBasis.tabulate 1 (0, nb) (fn blockIdx => + let + val blockStart = i + blockIdx*blockSize + val blockEnd = Int.min (j, blockStart + blockSize) + in + SeqBasis.foldl g b (blockStart, blockEnd) f + end) + + val partials = + SeqBasis.scan gran g b (0, nb) (A.nth blockSums) + + fun getChild blockIdx = + let + val firstElem = A.nth partials blockIdx + val blockStart = i + blockIdx*blockSize + in + Stream.iteratePrefixesIncl g firstElem + (Stream.tabulate (fn k => f (blockStart+k))) + end + in + Nest (n, getChild) + end + + + fun scanScan g b (n, getChild: int -> 'a Stream.t) = + let + val numChildren = Util.ceilDiv n blockSize + fun len i = + if i < numChildren-1 then blockSize else n - i*blockSize + + val childSums = + SeqBasis.tabulate 1 (0, numChildren) (fn childIdx => + Stream.iterate g b (len childIdx, getChild childIdx) + ) + + val partials = + SeqBasis.scan gran g b (0, numChildren) (A.nth childSums) + val total = A.nth partials numChildren + + fun getChild' childIdx = + let + val childStream = getChild childIdx + val initial = A.nth partials childIdx + in + Stream.iteratePrefixes g initial childStream + end + in + ( Nest (n, getChild') + , total + ) + end + + + fun scanScanIncl g b (n, getChild) = + let + val numChildren = Util.ceilDiv n blockSize + fun len i = + if i < numChildren-1 then blockSize else n - i*blockSize + + val childSums = + SeqBasis.tabulate 1 (0, numChildren) (fn childIdx => + Stream.iterate g b (len childIdx, getChild childIdx) + ) + + val partials = + SeqBasis.scan gran g b 
(0, numChildren) (A.nth childSums) + + fun getChild' childIdx = + let + val childStream = getChild childIdx + val initial = A.nth partials childIdx + in + Stream.iteratePrefixesIncl g initial childStream + end + in + Nest (n, getChild') + end + + + fun scan g b s = + case s of + Flat (Full slice) => + scanDelay g b (0, AS.length slice, AS.nth slice) + | Flat (Delay (i, j, f)) => + scanDelay g b (i, j, f) + | Nest (n, getChild) => + scanScan g b (n, getChild) + + + fun scanIncl g b s = + case s of + Flat (Full slice) => + scanInclDelay g b (0, AS.length slice, AS.nth slice) + | Flat (Delay (i, j, f)) => + scanInclDelay g b (i, j, f) + | Nest (n, getChild) => + scanScanIncl g b (n, getChild) + + + fun zipWithBothFlat g (i1, j1, f1) (i2, j2, f2) = + let + val n1 = j1-i1 + val n2 = j2-i2 + val n = Int.min (n1, n2) + in + Flat (Delay (0, n, fn i => g (f1 (i1+i), f2 (i2+i)))) + end + + + fun zipWithOneNest g (n1, getChild1) (i, j, f) = + let + val n2 = j-i + val _ = + if n1 = n2 then () else + raise Fail "OldDelayedSeq.zipWith lengths don't match" + + fun getChild childIdx = + let + val child1 = getChild1 childIdx + val lo = i + childIdx * blockSize + in + Stream.mapIdx (fn (k, x) => g (x, f (lo+k))) child1 + end + in + Nest (n1, getChild) + end + + + fun zipWithBothNest g (n1, getChild1) (n2, getChild2) = + let + val _ = + if n1 = n2 then () else + raise Fail "OldDelayedSeq.zipWith lengths don't match" + + fun getChild childIdx = + Stream.zipWith g (getChild1 childIdx, getChild2 childIdx) + in + Nest (n1, getChild) + end + + fun flip g (a, b) = g (b, a) + + fun zipWith g (s1, s2) = + case (s1, s2) of + (Flat (Delay xx), Flat (Delay yy)) => + zipWithBothFlat g xx yy + | (Flat (Full slice), Flat (Delay yy)) => + zipWithBothFlat g (0, AS.length slice, AS.nth slice) yy + | (Flat (Delay xx), Flat (Full slice)) => + zipWithBothFlat g xx (0, AS.length slice, AS.nth slice) + | (Flat (Full slice1), Flat (Full slice2)) => + zipWithBothFlat g + (0, AS.length slice1, AS.nth slice1) 
+ (0, AS.length slice2, AS.nth slice2) + | (Nest xx, Flat (Full slice)) => + zipWithOneNest g xx (0, AS.length slice, AS.nth slice) + | (Nest xx, Flat (Delay yy)) => + zipWithOneNest g xx yy + | (Flat (Full slice), Nest xx) => + zipWithOneNest (flip g) xx (0, AS.length slice, AS.nth slice) + | (Flat (Delay yy), Nest xx) => + zipWithOneNest (flip g) xx yy + | (Nest xx, Nest yy) => + zipWithBothNest g xx yy + + + fun zip (a, b) = zipWith (fn x => x) (a, b) + + (* ===================================================================== *) + + exception NYI + exception Range + exception Size + + datatype 'a listview = NIL | CONS of 'a * 'a seq + datatype 'a treeview = EMPTY | ONE of 'a | PAIR of 'a seq * 'a seq + + type 'a ord = 'a * 'a -> order + type 'a t = 'a seq + + fun argmax x = raise NYI + fun collate x = raise NYI + fun collect x = raise NYI + fun drop x = raise NYI + fun equal x = raise NYI + fun iteratePrefixes x = raise NYI + fun iteratePrefixesIncl x = raise NYI + fun merge x = raise NYI + fun sort x = raise NYI + fun splitHead x = raise NYI + fun splitMid x = raise NYI + fun take x = raise NYI + fun update x = raise NYI + fun zipWith3 x = raise NYI + + fun filterSome x = raise NYI + fun foreach x = raise NYI + fun foreachG x = raise NYI + +end diff --git a/tests/mpllib/PPM.sml b/tests/mpllib/PPM.sml new file mode 100644 index 000000000..ba2481858 --- /dev/null +++ b/tests/mpllib/PPM.sml @@ -0,0 +1,280 @@ +(* Basic support for the netpbm .ppm file format. 
*) +structure PPM: +sig + type channel = Color.channel + type pixel = Color.pixel + + (* flat sequence; pixel (i, j) is at data[i*width + j] *) + type image = {height: int, width: int, data: pixel Seq.t} + type box = {topleft: int * int, botright: int * int} + + val elem: image -> (int * int) -> pixel + + val subimage: box -> image -> image + + (* `replace box image subimage` copies subimage into the image at the + * specified box *) + val replace: box -> image -> image -> image + + (* read the given .ppm file *) + val read: string -> image + + (* output this image to the given .ppm file *) + val write: string -> image -> unit + +end = +struct + + type 'a seq = 'a Seq.t + + type channel = Color.channel + type pixel = Color.pixel + type image = {height: int, width: int, data: pixel Seq.t} + type box = {topleft: int * int, botright: int * int} + + fun elem ({height, width, data}: image) (i, j) = + if i < 0 orelse i >= height orelse j < 0 orelse j >= width then + raise Subscript + else + Seq.nth data (i*width + j) + + fun subimage {topleft=(i1,j1), botright=(i2,j2)} image = + let + val w = j2-j1 + val h = i2-i1 + + fun newElem k = + let + val i = k div w + val j = k mod w + in + elem image (i1 + i, j1 + j) + end + in + { width = w + , height = h + , data = Seq.tabulate newElem (w * h) + } + end + + fun replace {topleft=(i1,j1), botright=(i2,j2)} image subimage = + let + fun newElem k = + let + val i = k div (#width image) + val j = k mod (#width image) + in + if i1 <= i andalso i < i2 andalso + j1 <= j andalso j < j2 + then elem subimage (i-i1, j-j1) + else elem image (i, j) + end + in + { width = #width image + , height = #height image + , data = Seq.tabulate newElem (#width image * #height image) + } + end + + (* utilities... *) + + fun niceify str = + if String.size str <= 10 then str + else String.substring (str, 0, 7) ^ "..." 
+ + (* ============================== P3 format ============================== *) + + fun parse3 contents = + let + (* val tokens = Seq.fromList (String.tokens Char.isSpace contents) *) + (* val numToks = Seq.length tokens *) + val (numToks, tokRange) = Tokenize.tokenRanges Char.isSpace contents + fun tok i = + let + val (lo, hi) = tokRange i + in + Seq.subseq contents (lo, hi-lo) + end + fun strTok i = + Parse.parseString (tok i) + + val filetype = strTok 0 + val _ = + if filetype = "P3" then () + else raise Fail "should not happen" + + fun intTok thingName i = + let + fun err () = + raise Fail ("error parsing .ppm file: cannot parse " + ^ thingName ^ " from '" + ^ niceify (strTok i) ^ "'") + in + case Parse.parseInt (tok i) of + NONE => err () + | SOME x => if x >= 0 then x else err () + end + + val width = intTok "width" 1 + val height = intTok "height" 2 + val resolution = intTok "max color value" 3 + + val numPixels = width * height + val numChannels = 3 * width * height + + val _ = + if numToks = numChannels + 4 then () + else raise Fail ("error parsing .ppm file: too few color channels") + + fun normalize (c : int) = + Real.ceil ((Real.fromInt c / Real.fromInt resolution) * 255.0) + + fun chan i = + let + val c = intTok "channel" (i + 4) + val _ = + if c <= resolution then () + else raise Fail ("error parsing .ppm file: channel value " + ^ Int.toString c ^ " greater than resolution " + ^ Int.toString resolution) + in + Word8.fromInt (normalize c) + end + + fun pixel i = + {red = chan (3*i), green = chan (3*i + 1), blue = chan (3*i + 2)} + in + { width = width + , height = height + , data = Seq.tabulate pixel (width * height) + } + end + + (* ============================== P6 format ============================== *) + + fun parse6 contents = + let + val filetype = Parse.parseString (Seq.subseq contents (0, 2)) + val _ = + if filetype = "P6" then () + else raise Fail "should not happen" + + fun findFirst p i = + if i >= Seq.length contents then + NONE + else 
if p (Seq.nth contents i) then + SOME i + else + findFirst p (i+1) + + fun findToken start = + case findFirst (not o Char.isSpace) start of + NONE => NONE + | SOME i => + case findFirst Char.isSpace i of + NONE => SOME (i, Seq.length contents) + | SOME j => SOME (i, j) + + (* start must be on a space *) + fun chompToken start = + case findToken start of + NONE => NONE + | SOME (i, j) => SOME (Seq.subseq contents (i, j-i), j) + + fun chompInt thingName i = + case chompToken i of + NONE => raise Fail ("error parsing .ppm file: missing " ^ thingName) + | SOME (s, j) => + case Parse.parseInt s of + NONE => raise Fail ("error parsing .ppm file: cannot parse " + ^ thingName ^ " from '" + ^ niceify (Parse.parseString s) ^ "'") + | SOME x => + if x >= 0 then (x, j) + else raise Fail ("error parsing .ppm file: cannot parse " + ^ thingName ^ " from '" + ^ niceify (Parse.parseString s) ^ "'") + + val cursor = 2 + val _ = + if Seq.length contents > 2 andalso + Char.isSpace (Seq.nth contents 2) + then () + else raise Fail "error parsing .ppm file: unexpected format" + + val (width, cursor) = chompInt "width" cursor + val (height, cursor) = chompInt "height" cursor + val (resolution, cursor) = chompInt "max color value" cursor + + val numChannels = 3 * width * height + + val _ = + if resolution = 255 then () + else raise Fail "error parsing .ppm file: P6 max color value must be 255" + + val cursor = + case findFirst (not o Char.isSpace) cursor of + SOME i => i + | NONE => raise Fail "error parsing .ppm file: missing contents" + + val _ = + if Seq.length contents - cursor >= numChannels then () + else raise Fail "error parsing .ppm file: too few color channels" + + fun chan i = + Word8.fromInt (Char.ord (Seq.nth contents (cursor + i))) + + fun pixel i = + {red = chan (3*i), green = chan (3*i + 1), blue = chan (3*i + 2)} + in + { width = width + , height = height + , data = Seq.tabulate pixel (width * height) + } + end + + (* ================================= read 
================================= *) + + fun read filepath = + let + val contents = ReadFile.contentsSeq filepath + in + case Parse.parseString (Seq.subseq contents (0, 2)) of + "P3" => parse3 contents + | "P6" => parse6 contents + | _ => raise Fail "error parsing .ppm file: unknown or unsupported format" + end + + (* ================================ write ================================ *) + (* for now, only writes to format P6 *) + + fun write filepath image = + let + val file = TextIO.openOut filepath + + fun dump str = TextIO.output (file, str) + fun dumpChan c = TextIO.output1 (file, Char.chr (Word8.toInt c)) + fun dumpPx i j = + let + val {red, green, blue} = elem image (i, j) + in + (dumpChan red; + dumpChan green; + dumpChan blue) + end + + fun dumpLoop i j = + if i >= #height image then () + else if j >= #width image then + dumpLoop (i+1) 0 + else + (dumpPx i j; dumpLoop i (j+1)) + in + dump "P6 "; + dump (Int.toString (#width image) ^ " "); + dump (Int.toString (#height image) ^ " "); + dump "255 "; + dumpLoop 0 0 + end + +end diff --git a/tests/mpllib/ParFuncArray.sml b/tests/mpllib/ParFuncArray.sml new file mode 100644 index 000000000..d85954562 --- /dev/null +++ b/tests/mpllib/ParFuncArray.sml @@ -0,0 +1,223 @@ +structure ParFuncArray: +sig + type 'a t + type 'a farray = 'a t + + val alloc: int -> 'a farray + val length: 'a farray -> int + val tabulate: int * (int -> 'a) -> 'a farray + val sub: 'a farray * int -> 'a + val update: 'a farray * int * 'a -> 'a farray +end = +struct + + (** ======================================================================== + * Log data structure + *) + + type version = int + + structure Log: + sig + type 'a t + type 'a log = 'a t + type capacity = int + + val new: capacity -> 'a log + + val getVersion: 'a log -> version -> 'a option + + (** Pushing onto a log might make it grow, in which case the old version + * should no longer be used. 
This is a performance optimization: we can + * store logs in an array and only update a log when it grows. (The + * alternative would be to wrap a log in ref, but this is an unnecessary + * indirection. + *) + val push: 'a log -> version * 'a -> 'a log + end = + struct + datatype 'a t = L of {data: (version * 'a) array, size: int} + type 'a log = 'a t + type capacity = int + + + fun new cap = + L {data = ForkJoin.alloc cap, size = 0} + (* L {data = SeqBasis.tabulate 10000 (0, cap) (fn _ => NONE), size = 0} *) + + + fun push (L {data, size}) (v, x) = + let + val i = size + in + if i < Array.length data then + (* ( Array.update (data, i, SOME (v,x)) *) + ( Array.update (data, i, (v,x)) + ; L {data = data, size = size+1} + ) + else + let + val data' = ForkJoin.alloc (2 * Array.length data) + in + ForkJoin.parfor 10000 (0, Array.length data) (fn j => + Array.update (data', j, Array.sub (data, j)) + ); +(* + ForkJoin.parfor 10000 (Array.length data, Array.length data') (fn j => + Array.update (data', j, NONE) + ); +*) + (* Array.update (data', i, SOME(v,x)); *) + Array.update (data', i, (v,x)); + (* print ("grown! 
" ^ Util.intToString (Array.length data') ^ "\n"); *) + L {data = data', size = i+1} + end + end + + + (* fun push (logs, logIdx) (v, x) = + case push' (Array.sub (logs, logIdx)) of + NONE => () + | SOME newlog => Array.update (logs, logIdx, newlog) *) + + + fun getVersion (L {data, size}) v = + if size = 0 then + NONE + else + let + val n = size + val _ = + if n <= Array.length data then () + else print ("getVersion: data length: " ^ Int.toString (Array.length data) ^ ", current size: " ^ Int.toString n ^ "\n") + +(* + fun loop i = + if i >= n then n + else if #1 (valOf (Array.sub (data, i))) <= v then + loop (i+1) + else + i + *) + in + if #1 ((*valOf*) (Array.sub (data, n-1))) < v then NONE else + let + val slice = ArraySlice.slice (data, 0, SOME n) + val idx = + BinarySearch.searchPosition slice + (fn (*SOME*) (v', _) => Int.compare (v, v') + (* | NONE => raise Fail "ParFuncArray.getVersion found empty slot" *) + ) + in + SOME (#2 ((*valOf*) (Array.sub (data, idx)))) + end + end + handle e => (print ("error during getVersion: " ^ exnMessage e ^ "\n"); raise e) + + end + + (** ======================================================================== + * Main functions + *) + + datatype 'a array_data = + AD of {vr: version ref, data: 'a array, logs: 'a Log.t array} + + type 'a t = version * 'a array_data + type 'a farray = 'a t + + + fun alloc n = + let + val version = 0 + val data = ForkJoin.alloc n + val logs = SeqBasis.tabulate 5000 (0, n) (fn _ => Log.new 1) + in + (version, AD {vr = ref version, data = data, logs = logs}) + end + + + fun length (_, AD {data, ...}) = Array.length data + + + fun sub (farr, i) = + if i < 0 orelse i >= length farr then + raise Subscript + else + let + val (v, AD {vr, data, logs}) = farr + val guess = Array.sub (data, i) + in + if v = !vr then + guess + else + case Log.getVersion (Array.sub (logs, i)) v of + NONE => guess + | SOME x => x + end + + + fun bcas r (old, new) = + old = Concurrency.cas r (old, new) + + + fun updateLog 
(logs, i) (v, x) =
    let
      val oldLog = Array.sub (logs, i)
      val newLog = Log.push oldLog (v, x)
      (* casArray returns the value observed before the swap; the CAS
       * succeeded iff that is the very object we read above. *)
      val oldLog' = Concurrency.casArray (logs, i) (oldLog, newLog)
    in
      if MLton.eq (oldLog, oldLog') then ()
      else raise Fail "updateLog failed somehow!"
    end


  (* Functional update: returns a new farray equal to farr except index i
   * holds x. Fast path: if we hold the newest version, CAS the version
   * counter, log the displaced value, and write the data array in place.
   * Slow path: rebuild fresh data/logs arrays (O(n)). *)
  fun update (farr, i, x) =
    (* Bounds check matches `sub`: i = length farr is rejected here.
     * (Previously `i > length farr`, which let i = length slip through to
     * Array.sub inside updateLog AFTER the version CAS had already bumped
     * vr, raising Subscript from the wrong place and burning a version.) *)
    if i < 0 orelse i >= length farr then
      raise Subscript
    else
      let
        val (v, ad as AD {vr, data, logs}) = farr
        val currv = !vr
      in
        if
          currv = v andalso
          currv < Array.length data andalso
          bcas vr (v, v+1)
        then
          (* We successfully claimed access for updating the data *)
          ( updateLog (logs, i) (v, Array.sub (data, i))
          ; Array.update (data, i, x)
          ; (v+1, ad)
          )
        else (* We have to rebuild *)
          let
            val n = Array.length data
            (* val _ = print ("rebuilding " ^ Util.intToString n ^ "\n") *)
            val data' = SeqBasis.tabulate 1000 (0, n) (fn j => sub (farr, j))
            val logs' = SeqBasis.tabulate 5000 (0, n) (fn _ => Log.new 1)
          in
            Array.update (data', i, x);
            (0, AD {vr = ref 0, data = data', logs = logs'})
          end
      end


  (* Build a farray of length n with element i = f i, filled in parallel. *)
  fun tabulate (n, f) =
    let
      val version = 0
      val data = ForkJoin.alloc n
      val logs = SeqBasis.tabulate 5000 (0, n) (fn _ => Log.new 1)
    in
      ForkJoin.parfor 1000 (0, n) (fn i =>
        let
          val x = f i
        in
          (* reuse x rather than evaluating f a second time *)
          Array.update (data, i, x)
          (* updateLog (logs, i) (version, x) *)
        end);

      (version, AD {vr = ref version, data = data, logs = logs})
    end

end
diff --git a/tests/mpllib/Parse.sml b/tests/mpllib/Parse.sml
new file mode 100644
index 000000000..65e47afec
--- /dev/null
+++ b/tests/mpllib/Parse.sml
@@ -0,0 +1,179 @@
structure Parse =
struct

  (* SOME d for an ASCII digit character #"0"-#"9", NONE otherwise. *)
  fun parseDigit char =
    let
      val code = Char.ord char
      val code0 = Char.ord #"0"
      val code9 = Char.ord #"9"
    in
      if code < code0 orelse code9 < code then
        NONE
      else
        SOME (code - code0)
    end

  (* Parse a (possibly signed: '-', '~', or '+') integer; ',' and '_'
   * separators are skipped. NONE on empty input or any other non-digit. *)
  fun parseInt s =
    let
      val n = Seq.length s
      fun c i = Seq.nth s i

      fun build x i =
        if i >= n then SOME x else
        case c i of
          #"," => build x (i+1)
        | #"_" => build x (i+1)
        | cc
=>
            case parseDigit cc of
              NONE => NONE
            | SOME dig => build (x * 10 + dig) (i+1)
    in
      if n = 0 then NONE
      else if (c 0 = #"-" orelse c 0 = #"~") then
        Option.map (fn x => x * ~1) (build 0 1)
      else if (c 0 = #"+") then
        build 0 1
      else
        build 0 0
    end

  (* Parse a real number. Accepts an optional sign, ',' / '_' separators,
   * at most one '.', and scientific notation via 'e' / 'E'. *)
  fun parseReal s =
    let
      val n = Seq.length s
      fun c i = Seq.nth s i

      (* everything after 'e'/'E' is an integer exponent *)
      fun buildAfterE x i =
        Option.map (fn e => x * Math.pow (10.0, Real.fromInt e))
          (parseInt (Seq.subseq s (i, n-i)))

      (* m is the place value of the next fractional digit *)
      fun buildAfterPoint m x i =
        if i >= n then SOME x else
        case c i of
          #"," => buildAfterPoint m x (i+1)
        | #"_" => buildAfterPoint m x (i+1)
        | #"." => NONE
        | #"e" => buildAfterE x (i+1)
        | #"E" => buildAfterE x (i+1)
        | cc =>
            case parseDigit cc of
              NONE => NONE
            | SOME dig => buildAfterPoint (m * 0.1) (x + m * (Real.fromInt dig)) (i+1)

      fun buildBeforePoint x i =
        if i >= n then SOME x else
        case c i of
          #"," => buildBeforePoint x (i+1)
        | #"_" => buildBeforePoint x (i+1)
        | #"." => buildAfterPoint 0.1 x (i+1)
        | #"e" => buildAfterE x (i+1)
        | #"E" => buildAfterE x (i+1)
        | cc =>
            case parseDigit cc of
              NONE => NONE
            | SOME dig => buildBeforePoint (x * 10.0 + Real.fromInt dig) (i+1)
    in
      if n = 0 then NONE
      else if (c 0 = #"-" orelse c 0 = #"~") then
        Option.map (fn x => x * ~1.0) (buildBeforePoint 0.0 1)
      else if (c 0 = #"+") then
        (* accept a leading '+', consistent with parseInt above (previously
         * a leading '+' made parseReal return NONE) *)
        buildBeforePoint 0.0 1
      else
        buildBeforePoint 0.0 0
    end

  fun parseString s =
    CharVector.tabulate (Seq.length s, Seq.nth s)

  (* read a Word16, big endian, starting at index i *)
  fun r16b bytes i =
    let
      infix 2 << orb
      val op<< = Word64.<<
      val op orb = Word64.orb

      val w = Word8.toLarge (Seq.nth bytes i)
      val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (i+1)))
    in
      Word16.fromLarge w
    end

  (* read a Word32, big endian, starting at index i *)
  fun r32b bytes i =
    let
      infix 2 << orb
      val op<< = Word64.<<
      val op orb = Word64.orb

      val w = Word8.toLarge (Seq.nth bytes i)
      val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (i+1)))
      val w = (w << 0w8) orb
(Word8.toLarge (Seq.nth bytes (i+2))) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (i+3))) + in + Word32.fromLarge w + end + + (* read a Word64, big endian, starting at index i *) + fun r64b bytes i = + let + infix 2 << orb + val op<< = Word64.<< + val op orb = Word64.orb + + val w = Word8.toLarge (Seq.nth bytes i) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (i+1))) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (i+2))) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (i+3))) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (i+4))) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (i+5))) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (i+6))) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (i+7))) + in + w + end + + (* read a Word16, little endian, starting at index i *) + fun r16l bytes i = + let + infix 2 << orb + val op<< = Word64.<< + val op orb = Word64.orb + + val w = Word8.toLarge (Seq.nth bytes (i+1)) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes i)) + in + Word16.fromLarge w + end + + (* read a Word32, little endian, starting at index i *) + fun r32l bytes i = + let + infix 2 << orb + val op<< = Word64.<< + val op orb = Word64.orb + + val w = Word8.toLarge (Seq.nth bytes (i+3)) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (i+2))) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (i+1))) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes i)) + in + Word32.fromLarge w + end + + (* read a Word64, little endian, starting at index i *) + fun r64l bytes i = + let + infix 2 << orb + val op<< = Word64.<< + val op orb = Word64.orb + + val w = Word8.toLarge (Seq.nth bytes (i+7)) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (i+6))) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (i+5))) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (i+4))) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (i+3))) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth 
bytes (i+2))) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes (i+1))) + val w = (w << 0w8) orb (Word8.toLarge (Seq.nth bytes i)) + in + w + end + +end diff --git a/tests/mpllib/ParseFile.sml b/tests/mpllib/ParseFile.sml new file mode 100644 index 000000000..cee33c0b1 --- /dev/null +++ b/tests/mpllib/ParseFile.sml @@ -0,0 +1,190 @@ +(** SAM_NOTE: copy/pasted... some repetition here with Parse. *) +structure ParseFile = +struct + + structure RF = ReadFile + structure Seq = ArraySequence + structure DS = DelayedSeq + + fun tokens (f: char -> bool) (cs: char Seq.t) : (char DS.t) DS.t = + let + val n = Seq.length cs + val s = DS.tabulate (Seq.nth cs) n + val indices = DS.tabulate (fn i => i) (n+1) + fun check i = + if (i = n) then not (f(DS.nth s (n-1))) + else if (i = 0) then not (f(DS.nth s 0)) + else let val i1 = f (DS.nth s i) + val i2 = f (DS.nth s (i-1)) + in (i1 andalso not i2) orelse (i2 andalso not i1) end + val ids = DS.filter check indices + val res = DS.tabulate (fn i => + let val (start, e) = (DS.nth ids (2*i), DS.nth ids (2*i+1)) + in DS.tabulate (fn i => Seq.nth cs (start+i)) (e - start) + end) + ((DS.length ids) div 2) + in + res + end + + fun eqStr str (chars : char DS.t) = + let + val n = String.size str + fun checkFrom i = + i >= n orelse + (String.sub (str, i) = DS.nth chars i andalso checkFrom (i+1)) + in + DS.length chars = n + andalso + checkFrom 0 + end + + fun parseDigit char = + let + val code = Char.ord char + val code0 = Char.ord #"0" + val code9 = Char.ord #"9" + in + if code < code0 orelse code9 < code then + NONE + else + SOME (code - code0) + end + + (* This implementation doesn't work with mpl :( + * Need to fix the basis library... 
*) + (* + fun parseReal chars = + let + val str = CharVector.tabulate (DS.length chars, DS.nth chars) + in + Real.fromString str + end + *) + + fun parseInt (chars : char DS.t) = + let + val n = DS.length chars + fun c i = DS.nth chars i + + fun build x i = + if i >= n then SOME x else + case c i of + #"," => build x (i+1) + | #"_" => build x (i+1) + | cc => + case parseDigit cc of + NONE => NONE + | SOME dig => build (x * 10 + dig) (i+1) + in + if n = 0 then NONE + else if (c 0 = #"-" orelse c 0 = #"~") then + Option.map (fn x => x * ~1) (build 0 1) + else if (c 0 = #"+") then + build 0 1 + else + build 0 0 + end + + fun parseReal (chars : char DS.t) = + let + val n = DS.length chars + fun c i = DS.nth chars i + + fun buildAfterE x i = + let + val chars' = DS.subseq chars (i, n-i) + in + Option.map (fn e => x * Math.pow (10.0, Real.fromInt e)) + (parseInt chars') + end + + fun buildAfterPoint m x i = + if i >= n then SOME x else + case c i of + #"," => buildAfterPoint m x (i+1) + | #"_" => buildAfterPoint m x (i+1) + | #"." => NONE + | #"e" => buildAfterE x (i+1) + | #"E" => buildAfterE x (i+1) + | cc => + case parseDigit cc of + NONE => NONE + | SOME dig => buildAfterPoint (m * 0.1) (x + m * (Real.fromInt dig)) (i+1) + + fun buildBeforePoint x i = + if i >= n then SOME x else + case c i of + #"," => buildBeforePoint x (i+1) + | #"_" => buildBeforePoint x (i+1) + | #"." 
=> buildAfterPoint 0.1 x (i+1) + | #"e" => buildAfterE x (i+1) + | #"E" => buildAfterE x (i+1) + | cc => + case parseDigit cc of + NONE => NONE + | SOME dig => buildBeforePoint (x * 10.0 + Real.fromInt dig) (i+1) + in + if n = 0 then NONE + else if (c 0 = #"-" orelse c 0 = #"~") then + Option.map (fn x => x * ~1.0) (buildBeforePoint 0.0 1) + else + buildBeforePoint 0.0 0 + end + + fun readSequencePoint2d filename = + let + val toks = tokens Char.isSpace (RF.contentsSeq filename) + fun tok i = DS.nth toks i + val _ = + if eqStr "pbbs_sequencePoint2d" (tok 0) then () + else raise Fail (filename ^ " wrong file type") + + val n = DS.length toks - 1 + + fun r i = Option.valOf (parseReal (tok (1 + i))) + + fun pt i = + (r (2*i), r (2*i+1)) + handle e => raise Fail ("error parsing point " ^ Int.toString i ^ " (" ^ exnMessage e ^ ")") + + val result = Seq.tabulate pt (n div 2) + in + result + end + + fun readSequenceInt filename = + let + val toks = tokens Char.isSpace (RF.contentsSeq filename) + fun tok i = DS.nth toks i + val _ = + if eqStr "pbbs_sequenceInt" (tok 0) then () + else raise Fail (filename ^ " wrong file type") + + val n = DS.length toks - 1 + + fun p i = + Option.valOf (parseInt (tok (1 + i))) + handle e => raise Fail ("error parsing integer " ^ Int.toString i) + in + Seq.tabulate p n + end + + fun readSequenceReal filename = + let + val toks = tokens Char.isSpace (RF.contentsSeq filename) + fun tok i = DS.nth toks i + val _ = + if eqStr "pbbs_sequenceDouble" (tok 0) then () + else raise Fail (filename ^ " wrong file type") + + val n = DS.length toks - 1 + + fun p i = + Option.valOf (parseReal (tok (1 + i))) + handle e => raise Fail ("error parsing double value " ^ Int.toString i) + in + Seq.tabulate p n + end + +end diff --git a/tests/mpllib/PureSeq.sml b/tests/mpllib/PureSeq.sml new file mode 100644 index 000000000..633b8938c --- /dev/null +++ b/tests/mpllib/PureSeq.sml @@ -0,0 +1,222 @@ +structure PureSeq :> +sig + type 'a seq = 'a VectorSlice.slice + 
type 'a t = 'a seq + + val nth: 'a seq -> int -> 'a + val length: 'a seq -> int + + val empty: unit -> 'a seq + val fromList: 'a list -> 'a seq + val fromSeq: 'a Seq.t -> 'a seq + + val tabulate: (int -> 'a) -> int -> 'a seq + val tabulateG: int -> (int -> 'a) -> int -> 'a seq + val map: ('a -> 'b) -> 'a seq -> 'b seq + + val filter: ('a -> bool) -> 'a seq -> 'a seq + val filterIdx: (int * 'a -> bool) -> 'a seq -> 'a seq + + val reduce: ('a * 'a -> 'a) -> 'a -> 'a seq -> 'a + val scan: ('a * 'a -> 'a) -> 'a -> 'a seq -> 'a seq * 'a + val scanIncl: ('a * 'a -> 'a) -> 'a -> 'a seq -> 'a seq + + val subseq: 'a seq -> int * int -> 'a seq + val take: 'a seq -> int -> 'a seq + val drop: 'a seq -> int -> 'a seq + + val merge: ('a * 'a -> order) -> 'a seq * 'a seq -> 'a seq + val quicksort: ('a * 'a -> order) -> 'a seq -> 'a seq + + val summarize: int -> ('a -> string) -> 'a seq -> string + + val foreach: 'a seq -> (int * 'a -> unit) -> unit +end = +struct + + structure A = Array + structure AS = ArraySlice + structure V = Vector + structure VS = VectorSlice + + val gran = 5000 + + type 'a seq = 'a VS.slice + type 'a t = 'a seq + + val unsafeCast: 'a array -> 'a vector = VectorExtra.unsafeFromArray + + fun nth s i = VS.sub (s, i) + fun length s = VS.length s + fun empty () = VS.full (V.fromList []) + fun fromList xs = VS.full (V.fromList xs) + + fun subseq s (i, n) = VS.subslice (s, i, SOME n) + fun take s k = subseq s (0, k) + fun drop s k = subseq s (k, length s - k) + + fun tabulate f n = + VS.full (unsafeCast (SeqBasis.tabulate gran (0, n) f)) + + fun tabulateG gran f n = + VS.full (unsafeCast (SeqBasis.tabulate gran (0, n) f)) + + fun map f s = + tabulate (f o nth s) (length s) + + fun filter p s = + VS.full (unsafeCast (SeqBasis.filter gran (0, length s) (nth s) (p o nth s))) + + fun filterIdx p s = + VS.full (unsafeCast (SeqBasis.filter gran (0, length s) (nth s) + (fn i => p (i, nth s i)) + )) + + fun foreach s f = + ForkJoin.parfor gran (0, length s) (fn i => f 
(i, nth s i)) + + fun fromSeq xs = + tabulate (Seq.nth xs) (Seq.length xs) + + fun reduce f b s = + SeqBasis.reduce gran f b (0, length s) (nth s) + + fun scan f b s = + let + val n = length s + val v = VS.full (unsafeCast (SeqBasis.scan gran f b (0, n) (nth s))) + in + (take v n, nth v n) + end + + fun scanIncl f b s = + let + val n = length s + val v = VS.full (unsafeCast (SeqBasis.scan gran f b (0, n) (nth s))) + in + drop v 1 + end + + (** ======================================================================== + * Merge + * + * This is copied from the Merge.sml implementation and modified slightly + * to make it work with vectors. Really should just parameterize the + * Merge.sml by a Seq implementation... or do the func-sequence trick + * (pass the input sequences by length/nth functions) + *) + + fun sliceIdxs s i j = + VS.subslice (s, i, SOME (j-i)) + + fun arraySliceIdxs s i j = + AS.subslice (s, i, SOME (j-i)) + + fun search cmp s x = + let + fun loop lo hi = + case hi - lo of + 0 => lo + | n => + let + val mid = lo + n div 2 + val pivot = nth s mid + in + case cmp (x, pivot) of + LESS => loop lo mid + | EQUAL => mid + | GREATER => loop (mid+1) hi + end + in + loop 0 (length s) + end + + fun writeMergeSerial cmp (s1, s2) t = + let + fun write i x = AS.update (t, i, x) + + val n1 = length s1 + val n2 = length s2 + + (* i1 index into s1 + * i2 index into s2 + * j index into output *) + fun loop i1 i2 j = + if i1 = n1 then + foreach (sliceIdxs s2 i2 n2) (fn (i, x) => write (i+j) x) + else if i2 = n2 then + foreach (sliceIdxs s1 i1 n1) (fn (i, x) => write (i+j) x) + else + let + val x1 = nth s1 i1 + val x2 = nth s2 i2 + in + case cmp (x1, x2) of + LESS => (write j x1; loop (i1+1) i2 (j+1)) + | _ => (write j x2; loop i1 (i2+1) (j+1)) + end + in + loop 0 0 0 + end + + fun writeMerge cmp (s1: 'a seq, s2: 'a seq) (t: 'a AS.slice) = + if AS.length t <= gran then + writeMergeSerial cmp (s1, s2) t + else if length s1 = 0 then + foreach s2 (fn (i, x) => AS.update (t, 
i, x)) + else + let + val n1 = length s1 + val n2 = length s2 + val mid1 = n1 div 2 + val pivot = nth s1 mid1 + val mid2 = search cmp s2 pivot + + val l1 = sliceIdxs s1 0 mid1 + val r1 = sliceIdxs s1 (mid1+1) n1 + val l2 = sliceIdxs s2 0 mid2 + val r2 = sliceIdxs s2 mid2 n2 + + val _ = AS.update (t, mid1+mid2, pivot) + val tl = arraySliceIdxs t 0 (mid1+mid2) + val tr = arraySliceIdxs t (mid1+mid2+1) (AS.length t) + in + ForkJoin.par + (fn _ => writeMerge cmp (l1, l2) tl, + fn _ => writeMerge cmp (r1, r2) tr); + () + end + + fun merge cmp (s1, s2) = + let + val out = ForkJoin.alloc (length s1 + length s2) + in + writeMerge cmp (s1, s2) (AS.full out); + VS.full (unsafeCast out) + end + + fun quicksort cmp s = + let + val out = SeqBasis.tabulate gran (0, length s) (nth s) + in + Quicksort.sortInPlace cmp (AS.full out); + VS.full (unsafeCast out) + end + + fun summarize count toString xs = + let + val n = length xs + fun elem i = nth xs i + + val strs = + if count <= 0 then raise Fail "PureSeq.summarize needs count > 0" + else if count <= 2 orelse n <= count then + List.tabulate (n, toString o elem) + else + List.tabulate (count-1, toString o elem) @ + ["...", toString (elem (n-1))] + in + "[" ^ (String.concatWith ", " strs) ^ "]" + end + +end diff --git a/tests/mpllib/Quicksort.sml b/tests/mpllib/Quicksort.sml new file mode 100644 index 000000000..f00322190 --- /dev/null +++ b/tests/mpllib/Quicksort.sml @@ -0,0 +1,133 @@ +(* Author: The 210 Team + * + * Uses dual-pivot quicksort from: + * + * Dual-Pivot Quicksort Algorithm + * Vladimir Yaroslavskiy + * http://codeblab.com/wp-content/uploads/2009/09/DualPivotQuicksort.pdf + * 2009 + * + * Insertion sort is taken from the SML library ArraySort + *) + +structure Quicksort: +sig + type 'a seq = 'a ArraySlice.slice + val sortInPlaceG : int -> ('a * 'a -> order) -> 'a seq -> unit + val sortInPlace : ('a * 'a -> order) -> 'a seq -> unit + val sortG : int -> ('a * 'a -> order) -> 'a seq -> 'a seq + val sort : ('a * 'a -> 
order) -> 'a seq -> 'a seq +end = +struct + + type 'a seq = 'a ArraySlice.slice + + structure A = Array + structure AS = ArraySlice + + fun sortRange grainsize (array, start, n, compare) = + let + val sub = A.sub + val update = A.update + + fun item i = sub(array,i) + fun set(i,v) = update(array,i,v) + fun cmp(i,j) = compare(item i, item j) + + fun swap (i,j) = + let val tmp = item i + in set(i, item j); set(j, tmp) end + + (* same as swap(j,k); swap(i,j) *) + fun rotate(i,j,k) = + let val tmp = item k + in set(k, item j); set(j, item i); set(i, tmp) end + + fun insertSort (start, n) = + let val limit = start+n + fun outer i = + if i >= limit then () + else let fun inner j = + if j = start then outer(i+1) + else let val j' = j - 1 + in if cmp(j', j) = GREATER + then (swap(j,j'); inner j') + else outer(i+1) + end + in inner i end + in outer (start+1) end + + (* puts lesser pivot at start and larger at end *) + fun twoPivots(a, n) = + let fun sortToFront(size) = + let val m = n div (size + 1) + fun toFront(i) = + if (i < size) then (swap(a + i, a + m*(i+1)); toFront(i+1)) + else () + in (toFront(0); insertSort(a,size)) end + in if (n < 80) then + (if cmp(a, a+n-1) = GREATER then swap(a,a+n-1) else ()) + else (sortToFront(5); swap(a+1,a); swap(a+3,a+n-1)) + end + + (* splits based on two pivots (p1 and p2) into 3 parts: + less than p1, greater than p2, and the rest in the middle. + The pivots themselves end up at the two ends. + If the pivots are the same, returns a false flag to indicate middle + need not be sorted. 
*) + fun split3 (a, n) = + let + val (p1,p2) = (twoPivots(a,n); (a, a+n-1)) + fun right(r) = if cmp(r, p2) = GREATER then right(r-1) else r + fun loop(l,m,r) = + if (m > r) then (l,m) + else if cmp(m, p1) = LESS then (swap(m,l); loop(l+1, m+1, r)) + else (if cmp(m, p2) = GREATER then + (if cmp(r, p1) = LESS + then (rotate(l,m,r); loop(l+1, m+1, right(r-1))) + else (swap(m,r); loop(l, m+1, right(r-1)))) + else loop(l, m+1, r)) + val (l,m) = loop(a + 1, a + 1, right(a + n - 2)) + in (l, m, cmp(p1, p2) = LESS) end + + (* makes recursive calls in parallel if big enough *) + fun qsort (a, n) = + if (n < 16) then insertSort(a, n) + else let + val (l, m, doMid) = split3(a,n) + in if (n <= grainsize) then + (qsort (a, l-a); + (if doMid then qsort(l, m-l) else ()); + qsort (m, a+n-m)) + else let val par = ForkJoin.par + val left = (fn () => qsort (a, l-a)) + val mid = (fn () => qsort (l, m-l)) + val right = (fn () => qsort (m, a+n-m)) + val maybeMid = if doMid then (fn () => (par(mid,right);())) + else right + in par(left,maybeMid);() end + end + + in qsort (start,n) end + + (* sorts an array slice in place *) + fun sortInPlaceG grainsize compare aslice = + let val (a, i, n) = AS.base aslice + in sortRange grainsize (a, i, n, compare) + end + + fun sortG grainsize compare aslice = + let + val result = AS.full (ForkJoin.alloc (AS.length aslice)) + in + Util.foreach aslice (fn (i, x) => AS.update (result, i, x)); + sortInPlaceG grainsize compare result; + result + end + + val grainsize = 8192 + + fun sortInPlace c s = sortInPlaceG grainsize c s + fun sort c s = sortG grainsize c s + +end diff --git a/tests/mpllib/RadixSort.sml b/tests/mpllib/RadixSort.sml new file mode 100644 index 000000000..26a36a275 --- /dev/null +++ b/tests/mpllib/RadixSort.sml @@ -0,0 +1,145 @@ +(* Author: Lawrence Wang (lawrenc2@andrew.cmu.edu, github.com/larry98) + * + * The lsdSort and msdSort functions take the following arguments: + * - s : 'a ArraySequence.t + * the array (of strings) to sort + * - 
bucket : 'a ArraySequence.t -> int -> int -> int + * bucket s k i specifies which bucket the k'th digit of the i'th element + * of s should map to + * - numPasses : int + * the number of counting sort passes to make i.e. the number of digits + * in the strings of the array + * - numBuckets : int + * the number of buckets used in counting sort + * + * The quicksort function implements 3-way radix quicksort and takes the + * following arguments: + * - s : 'a ArraySequence.t + * the array (of strings) to sort + * - cmp : int -> 'a * 'a -> order + * cmp k (x, y) specifies the comparison between the kth digit of x with + * the kth digit of y + * - numPasses : int + * the maximum number of quicksort passes to make (commonly the maximum + * length of the strings being sorted) + *) +structure RadixSort :> +sig + val lsdSort : 'a Seq.t -> ('a Seq.t -> int -> int -> int) + -> int -> int -> 'a Seq.t + val msdSort : 'a Seq.t -> ('a Seq.t -> int -> int -> int) + -> int -> int -> 'a Seq.t + val quicksort : 'a Seq.t -> (int -> 'a * 'a -> order) -> int + -> 'a Seq.t +end = +struct + + structure AS = + struct + open ArraySlice + open Seq + + val GRAIN = 4096 + val ASupdate = ArraySlice.update + val alloc = ForkJoin.alloc + end + + fun lsdSort s bucket numPasses numBuckets = + let + fun loop s i = + if i < 0 then s + else loop (#1 (CountingSort.sort s (bucket s i) numBuckets)) (i - 1) + in + loop s (numPasses - 1) + end + + fun msdSort s bucket numPasses numBuckets = + let + val n = AS.length s + val result = ArraySlice.full (AS.alloc n) + fun msdSort' s pass lo hi = + if pass = numPasses then + ForkJoin.parfor AS.GRAIN (0, hi - lo) (fn i => + AS.ASupdate (result, lo + i, AS.nth s i) + ) + else + let + val (s', offsets) = CountingSort.sort s (bucket s pass) numBuckets + in + ForkJoin.parfor AS.GRAIN (0, numBuckets) (fn i => + let + val start = AS.nth offsets i + val len = if i = numBuckets - 1 then AS.length s' - start + else AS.nth offsets (i + 1) - start + val s'' = AS.subseq s' 
(start, len) + in + if len = 0 then () + else if len = 1 then + AS.ASupdate (result, lo + start, AS.nth s'' 0) + else + msdSort' s'' (pass + 1) (lo + start) (lo + start + len) + end + ) + end + val () = msdSort' s 0 0 n + in + result + end + + fun par3 (a, b, c) = + let + val ((ar, br), cr) = ForkJoin.par (fn _ => ForkJoin.par (a, b), c) + in + (ar, br, cr) + end + + fun quicksort s cmp numPasses = + let + val n = AS.length s + val result = ArraySlice.full (AS.alloc n) + (* TODO: Change to insertion sort if size of array is small *) + fun quicksort' s digit lo hi seed = + if hi = lo then () + else if hi - lo = 1 then AS.ASupdate (result, lo, AS.nth s 0) + else if digit = numPasses then + ForkJoin.parfor AS.GRAIN (0, hi - lo) (fn i => + AS.ASupdate (result, lo + i, AS.nth s i) + ) + else + let + val n' = hi - lo + val pivot = AS.nth s (seed mod n') + fun bucket i = + case cmp digit (AS.nth s i, pivot) of + LESS => 0 + | EQUAL => 1 + | GREATER => 2 + val (s', offsets) = CountingSort.sort s bucket 3 + val mid1 = AS.nth offsets 1 + val mid2 = AS.nth offsets 2 + val seed1 = Util.hash (seed + 1) + val seed2 = Util.hash (seed + 2) + val seed3 = Util.hash (seed + 3) + val s1 = AS.subseq s' (0, mid1) + val s2 = AS.subseq s' (mid1, mid2 - mid1) + val s3 = AS.subseq s' (mid2, n' - mid2) + val () = if hi - lo < 1024 then ( + quicksort' s1 digit lo (lo + mid1) seed1; + quicksort' s2 (digit + 1) (lo + mid1) (lo + mid2) seed2; + quicksort' s3 digit (lo + mid2) hi seed3 + ) else ( + let val ((), (), ()) = par3 ( + fn () => quicksort' s1 digit lo (lo + mid1) seed1, + fn () => quicksort' s2 (digit + 1) (lo + mid1) (lo + mid2) seed2, + fn () => quicksort' s3 digit (lo + mid2) hi seed3 + ) in () end + ) + in + () + end + val () = quicksort' s 0 0 n (Util.hash 0) + in + result + end + +end diff --git a/tests/mpllib/Rat.sml b/tests/mpllib/Rat.sml new file mode 100644 index 000000000..e8db17721 --- /dev/null +++ b/tests/mpllib/Rat.sml @@ -0,0 +1,145 @@ +structure Rat :> +sig + type t + 
type i = IntInf.int + + (* make(n, d) ~> n/d *) + val make: i * i -> t + val view: t -> i * i + + val normalize: t -> t + + val * : t * t -> t + val - : t * t -> t + val + : t * t -> t + val div: t * t -> t + + val max: t * t -> t + + val sign: t -> int + val compare: t * t -> order + + val approx: t -> Real64.real + + val toString: t -> string +end = +struct + + type i = IntInf.int + type t = i * i + + fun make (n, d) = (n, d) + fun view (n, d) = (n, d) + + fun gcd (a, b) = + if b = 0 then a else gcd (b, IntInf.mod (a, b)) + + fun normalize (n, d) = + if n = 0 then + (0, 1) + else + let + val same = IntInf.sameSign (n, d) + + val na = IntInf.abs n + val da = IntInf.abs d + + val g = gcd (na, da) + + val n' = IntInf.div (na, g) + val d' = IntInf.div (da, g) + in + if same then (n', d') else (IntInf.~ n', d') + end + + fun mul ((a, b): t, (c, d)) = (a * c, b * d) + + fun add ((a, b): t, (c, d)) = + (a * d + b * c, b * d) + + fun sub ((a, b): t, (c, d)) = + (a * d - b * c, b * d) + + fun divv ((a, b): t, (c, d)) = (a * d, b * c) + + + fun sign (n, d) = + if n = 0 then 0 else if IntInf.sameSign (n, d) then 1 else ~1 + + + fun compare (r1, r2) = + let + val diff = sub (r1, r2) + val s = sign diff + in + if s < 0 then LESS else if s = 0 then EQUAL else GREATER + end + + + fun max (r1, r2) = + case compare (r1, r2) of + LESS => r2 + | _ => r1 + + + fun itor x = + Real64.fromLargeInt (IntInf.toLarge x) + + + (* ========================================================================= + * approximate a rational with Real64 + * + * TODO: not sure what the best way to do this is. Kinda just threw something + * together. It's probably kinda messed up in a subtle way. + *) + + local + fun loopApprox acc (r, d) = + let + val r' = itor r + val d' = itor d + in + if Real64.isFinite r' andalso Real64.isFinite d' then + acc + r' / d' + else + let + val d2 = d div 2 + in + if r > d2 then + loopApprox (acc + 0.5) (normalize (r - d2, d)) + else + (* no idea how good this is... 
*) + loopApprox acc (normalize (r div 2, d div 2)) + end + end + in + fun approx (n, d) = + let + val (n, d) = normalize (n, d) + val s = Real64.fromInt (IntInf.sign n) + val n = IntInf.abs n + + (* abs(n/d) = m + r/d + * where: m is a natural number + * and: r/d is a proper fraction + *) + val (m, r) = IntInf.divMod (n, d) + val m = itor m + in + if not (Real64.isFinite m) then s * m + else s * (m + loopApprox 0.0 (normalize (r, d))) + end + end + + (* ======================================================================= *) + + fun toString (n, d) = + IntInf.toString n ^ "/" ^ IntInf.toString d + + (* ======================================================================= *) + + val op* = mul + val op+ = add + val op- = sub + val op div = divv +end diff --git a/tests/mpllib/RecursiveStream.sml b/tests/mpllib/RecursiveStream.sml new file mode 100644 index 000000000..8073f6d7c --- /dev/null +++ b/tests/mpllib/RecursiveStream.sml @@ -0,0 +1,181 @@ +structure RecursiveStream :> STREAM = +struct + + (** Intended use: when the stream is consumed, the indices are provided + * as input. + * + * For example, if `stream` represents elements x0,x1,... + * val S a = stream + * val (x0, S b) = a 0 + * val (x1, S c) = b 1 + * val (x2, S d) = c 2 + * ... + * + * Requiring the index in this way is just an optimization: we know that all + * streams require at least an index as state, so we can save storing one + * piece of state by instead providing this only as needed. 
+ *) + datatype 'a stream = S of int -> 'a * ('a stream) + type 'a t = 'a stream + + fun nth stream i = + let + fun loop j (S g) = + let + val (first, rest) = g j + in + if j >= i then first else loop (j+1) rest + end + in + loop 0 stream + end + + fun tabulate f = S (fn i => (f i, tabulate f)) + + fun map f (S g) = + S (fn i => + let + val (first, rest) = g i + in + (f first, map f rest) + end) + + fun mapIdx f (S g) = + S (fn i => + let + val (first, rest) = g i + in + (f (i, first), mapIdx f rest) + end) + + fun zipWith f (S g, S h) = + S (fn i => + let + val (gfirst, grest) = g i + val (hfirst, hrest) = h i + in + (f (gfirst, hfirst), zipWith f (grest, hrest)) + end) + + fun iteratePrefixes f z (S g) = + S (fn i => + let + val (first, rest) = g i + val z' = f (z, first) + in + (z, iteratePrefixes f z' rest) + end) + + fun iteratePrefixesIncl f z (S g) = + S (fn i => + let + val (first, rest) = g i + val z' = f (z, first) + in + (z', iteratePrefixesIncl f z' rest) + end) + + fun applyIdx (n, stream) f = + let + fun loop i (S g) = + if i >= n then () else + let + val (first, rest) = g i + in + f (i, first); + loop (i+1) rest + end + in + loop 0 stream + end + + fun iterate f z (n, stream) = + let + fun loop acc i (S g) = + if i >= n then acc else + let + val (first, rest) = g i + val acc' = f (acc, first) + in + loop acc' (i+1) rest + end + in + loop z 0 stream + end + + fun resize arr = + let + val newCapacity = 2 * Array.length arr + val dst = ForkJoin.alloc newCapacity + in + Array.copy {src = arr, dst = dst, di = 0}; + dst + end + + fun pack f (length, stream) = + let + fun loop (data, next) i (S g) = + if i < length andalso next < Array.length data then + let + val (first, rest) = g i + in + case f first of + SOME y => + ( Array.update (data, next, y) + ; loop (data, next+1) (i+1) rest + ) + | NONE => + loop (data, next) (i+1) rest + end + + else if next >= Array.length data then + loop (resize data, next) i (S g) + + else + (data, next) + + val (data, 
count) = loop (ForkJoin.alloc 10, 0) 0 stream + in + ArraySlice.slice (data, 0, SOME count) + end + + + fun makeBlockStreams + { blockSize: int + , numChildren: int + , offset: int -> int + , getElem: int -> int -> 'a + } = + let + fun getBlock blockIdx = + let + fun advanceUntilNonEmpty i = + if i >= numChildren orelse offset i <> offset (i+1) then + i + else + advanceUntilNonEmpty (i+1) + + val lo = blockIdx * blockSize + + fun walk i = + S (fn idx => + let + val j = lo + idx - offset i + val elem = getElem i j + in + if offset i + j + 1 < offset (i+1) then + (elem, walk i) + else + (elem, walk (advanceUntilNonEmpty (i+1))) + end) + + val firstOuterIdx = + OffsetSearch.indexSearch (0, numChildren, offset) lo + in + walk firstOuterIdx + end + in + getBlock + end + +end diff --git a/tests/mpllib/SEQUENCE.sml b/tests/mpllib/SEQUENCE.sml new file mode 100644 index 000000000..4f7a2699e --- /dev/null +++ b/tests/mpllib/SEQUENCE.sml @@ -0,0 +1,73 @@ +signature SEQUENCE = +sig + type 'a t + type 'a seq = 'a t + (* type 'a ord = 'a * 'a -> order + datatype 'a listview = NIL | CONS of 'a * 'a seq + datatype 'a treeview = EMPTY | ONE of 'a | PAIR of 'a seq * 'a seq *) + + (* exception Range + exception Size *) + + val nth: 'a seq -> int -> 'a + val length: 'a seq -> int + val toList: 'a seq -> 'a list + val toString: ('a -> string) -> 'a seq -> string + val equal: ('a * 'a -> bool) -> 'a seq * 'a seq -> bool + + val empty: unit -> 'a seq + val singleton: 'a -> 'a seq + val tabulate: (int -> 'a) -> int -> 'a seq + val fromList: 'a list -> 'a seq + + val rev: 'a seq -> 'a seq + val append: 'a seq * 'a seq -> 'a seq + val flatten: 'a seq seq -> 'a seq + + val map: ('a -> 'b) -> 'a seq -> 'b seq + val mapOption: ('a -> 'b option) -> 'a seq -> 'b seq + val zip: 'a seq * 'b seq -> ('a * 'b) seq + val zipWith: ('a * 'b -> 'c) -> 'a seq * 'b seq -> 'c seq + val zipWith3: ('a * 'b * 'c -> 'd) -> 'a seq * 'b seq * 'c seq -> 'd seq + + val filter: ('a -> bool) -> 'a seq -> 'a seq + (* 
val filterSome: 'a option seq -> 'a seq *)
  val filterIdx: (int * 'a -> bool) -> 'a seq -> 'a seq

  val enum: 'a seq -> (int * 'a) seq
  val mapIdx: (int * 'a -> 'b) -> 'a seq -> 'b seq
  (* val update: 'a seq * (int * 'a) -> 'a seq *)
  val inject: 'a seq * (int * 'a) seq -> 'a seq

  val subseq: 'a seq -> int * int -> 'a seq
  val take: 'a seq -> int -> 'a seq
  val drop: 'a seq -> int -> 'a seq
  (* val splitHead: 'a seq -> 'a listview *)
  (* val splitMid: 'a seq -> 'a treeview *)

  val iterate: ('b * 'a -> 'b) -> 'b -> 'a seq -> 'b
  (* val iteratePrefixes: ('b * 'a -> 'b) -> 'b -> 'a seq -> 'b seq * 'b *)
  (* val iteratePrefixesIncl: ('b * 'a -> 'b) -> 'b -> 'a seq -> 'b seq *)
  val reduce: ('a * 'a -> 'a) -> 'a -> 'a seq -> 'a
  val scan: ('a * 'a -> 'a) -> 'a -> 'a seq -> 'a seq * 'a
  val scanIncl: ('a * 'a -> 'a) -> 'a -> 'a seq -> 'a seq

  (* val sort: 'a ord -> 'a seq -> 'a seq
  val merge: 'a ord -> 'a seq * 'a seq -> 'a seq
  val collect: 'a ord -> ('a * 'b) seq -> ('a * 'b seq) seq
  val collate: 'a ord -> 'a seq ord
  val argmax: 'a ord -> 'a seq -> int *)

  val $ : 'a -> 'a seq
  val % : 'a list -> 'a seq

  val fromArraySeq: 'a ArraySlice.slice -> 'a seq
  val toArraySeq: 'a seq -> 'a ArraySlice.slice

  val force: 'a seq -> 'a seq

  val applyIdx: 'a seq -> (int * 'a -> unit) -> unit

  (* val foreach: 'a seq -> (int * 'a -> unit) -> unit
  val foreachG: int -> 'a seq -> (int * 'a -> unit) -> unit *)
end

(* ===== tests/mpllib/STREAM.sml ===== *)
signature STREAM =
sig
  type 'a t
  type 'a stream = 'a t

  val nth: 'a stream -> int -> 'a

  val tabulate: (int -> 'a) -> 'a stream
  val map: ('a -> 'b) -> 'a stream -> 'b stream
  val mapIdx: (int * 'a -> 'b) -> 'a stream -> 'b stream
  val zipWith: ('a * 'b -> 'c) -> 'a stream * 'b stream -> 'c stream
  val iteratePrefixes: ('b * 'a -> 'b) -> 'b ->
'a stream -> 'b stream
  val iteratePrefixesIncl: ('b * 'a -> 'b) -> 'b -> 'a stream -> 'b stream

  val applyIdx: int * 'a stream -> (int * 'a -> unit) -> unit
  val iterate: ('b * 'a -> 'b) -> 'b -> int * 'a stream -> 'b

  val pack: ('a -> 'b option) -> (int * 'a stream) -> 'b ArraySlice.slice

  val makeBlockStreams:
    { blockSize: int
    , numChildren: int
    , offset: int -> int
    , getElem: int -> int -> 'a
    }
    -> (int -> 'a stream)
end

(* ===== tests/mpllib/SampleSort.sml ===== *)
(* Author: Guy Blelloch
 *
 * This file is basically the cache-oblivious sorting algorithm from:
 *
 *   Low depth cache-oblivious algorithms.
 *   Guy E. Blelloch, Phillip B. Gibbons and Harsha Vardhan Simhadri.
 *   Proc. ACM symposium on Parallelism in algorithms and architectures
 *   (SPAA), 2010
 *
 * The main difference is that it does not recurse (using quicksort instead)
 * and the merging with samples is sequential.
+ *) + +structure SampleSort :> +sig + type 'a seq = 'a ArraySlice.slice + + (* transpose (matrix, numRows, numCols) *) + val transpose : 'a seq * int * int -> 'a seq + + (* transposeBlocks (blockMatrix, srcOffsets, dstOffsets, counts, numRows, numCols, n) *) + val transposeBlocks : 'a seq * int seq * int seq * int seq * int * int * int -> 'a seq + + val sort : ('a * 'a -> order) -> 'a seq -> 'a seq +end = +struct + type 'a seq = 'a Seq.t + + structure A = Array + structure AS = ArraySlice + + val sortInPlace = Quicksort.sortInPlace + + val sub = A.sub + val update = A.update + + val par = ForkJoin.par + val for = Util.for + val parallelFor = ForkJoin.parfor + + fun for_l (lo, len) f = for (lo, lo + len) f + + fun matrixDandC baseCase (threshold, num_rows, num_cols) = + let fun r(rs, rl, cs, cl) = + if (rl*cl < threshold) then baseCase(rs, rl, cs, cl) + else if (cl > rl) then + (par (fn () => r(rs, rl, cs, cl div 2), + fn () => r(rs, rl, cs + (cl div 2), cl - (cl div 2))); ()) + else + (par (fn () => r(rs, rl div 2, cs, cl), + fn () => r(rs + (rl div 2), rl - (rl div 2), cs, cl)); ()) + in r(0, num_rows, 0, num_cols) end + + (* transposes a matrix *) + fun transpose(S, num_rows, num_cols) = + let + val seq_threshold = 8000 + val (SS, offset, n) = AS.base S + val _ = if (AS.length S) <> (num_rows * num_cols) then raise Size else () + val R = ForkJoin.alloc (num_rows * num_cols) + fun baseCase(row_start, row_len, col_start, col_len) = + for_l (row_start, row_len) (fn i => + for_l (col_start, col_len) (fn j => + update(R, j * num_rows + i, sub(SS,(i*num_cols + j + offset))))) + in (matrixDandC baseCase (seq_threshold, num_rows, num_cols); + AS.full(R)) + end + + (* transposes a matrix of blocks given source and destination pairs *) + fun transposeBlocks(S, source_offsets, dest_offsets, counts, num_rows, num_cols, n) = + let + val seq_threshold = 500 + val (SS, offset, n) = AS.base S + val R = ForkJoin.alloc n + fun baseCase(row_start, row_len, col_start, col_len) = + 
for (row_start, row_start + row_len) (fn i => + for (col_start, col_start + col_len) (fn j => let + val pa = offset + AS.sub (source_offsets, i*num_cols + j) + val pb = Seq.nth dest_offsets (j*num_rows + i) + val l = Seq.nth counts (i*num_cols + j) + in for (0,l) (fn k => update(R,pb+k,sub(SS, pa + k))) end)) + in (matrixDandC baseCase (seq_threshold, num_rows, num_cols); + AS.full(R)) + end + + (* merges a sequence of elements A with the samples S, putting counts in C *) + fun mergeWithSamples cmp (A, S, C) = + let + val num_samples = AS.length S + val n = AS.length A + fun merge(i,j) = + if (j = num_samples) then AS.update(C,j,n-i) + else + let fun merge'(i) = if (i < n andalso cmp(AS.sub (A, i), AS.sub (S, j)) = LESS) + then merge'(i+1) + else i + val k = merge'(i) + val _ = AS.update(C, j, k-i) + in merge(k,j+1) end + in merge(0,0) end + + fun sort cmp A = + let + val n = AS.length A + + (* parameters used in algorithm *) + val bucket_quotient = 3 + val block_quotient = 2 + val sqrt = Real.floor(Math.sqrt(Real.fromInt n)) + val num_blocks = sqrt div block_quotient + val block_size = ((n-1) div num_blocks) + 1 + val num_buckets = (sqrt div bucket_quotient) + 1 + val over_sample = 1 + ((n div num_buckets) div 500) + val sample_size = num_buckets * over_sample + val sample_stride = n div sample_size + val m = num_blocks*num_buckets + + (* val _ = print ("num_blocks " ^ Int.toString num_blocks ^ "\n") + val _ = print ("num_buckets " ^ Int.toString num_buckets ^ "\n") + val _ = print ("sample_size " ^ Int.toString sample_size ^ "\n") + val _ = print ("over_sample " ^ Int.toString over_sample ^ "\n") + val _ = print ("m " ^ Int.toString m ^ "\n") *) + + (* val t0 = Time.now () *) + + (* sort a sample of keys *) + val sample = Seq.tabulate (fn i => AS.sub (A, i*sample_stride)) sample_size + val _ = sortInPlace cmp sample + + (* val t1 = Time.now () + val _ = print ("sorted sample " ^ Time.fmt 4 (Time.- (t1, t0)) ^ "\n") *) + + (* take a subsample *) + val sub_sample = 
Seq.tabulate (fn i => AS.sub (sample, (i+1)*over_sample)) (num_buckets-1) + + (* val t2 = Time.now () + val _ = print ("subsample " ^ Time.fmt 4 (Time.- (t2, t1)) ^ "\n") *) + + val counts = AS.full (ForkJoin.alloc m) + val B = AS.full (ForkJoin.alloc n) + + (* sort each block and merge with the pivots, giving a count of the number + of keys between each pivot in each block *) + val _ = + parallelFor 1 (0,num_blocks) (fn i => + let + val start = i * block_size + val len = Int.min((i+1)* block_size,n) - start + (* copy into B to avoid changing A *) + val _ = for (start, start+len) (fn j => AS.update(B, j, AS.sub (A, j))) + val B' = Seq.subseq B (start, len) + val _ = sortInPlace cmp B' + val counts' = Seq.subseq counts (i*num_buckets, num_buckets) + val _ = mergeWithSamples cmp (B', sub_sample, counts') + in () end) + + (* val t3 = Time.now () + val _ = print ("sort blocks " ^ Time.fmt 4 (Time.- (t3, t2)) ^ "\n") *) + + (* scan across the counts to get offset of each source bucket within each block *) + val (source_offsets,_) = Seq.scan op+ 0 counts + + (* transpose and scan across the counts to get offset of each + destination within each bucket *) + val tcounts = transpose(counts,num_blocks,num_buckets) + val (dest_offsets,_) = Seq.scan op+ 0 tcounts + + (* move data to correct destination *) + val C = transposeBlocks(B, source_offsets, dest_offsets, + counts, num_blocks, num_buckets, n) + + (* val t4 = Time.now () + val _ = print ("transpose data " ^ Time.fmt 4 (Time.- (t4, t3)) ^ "\n") *) + + (* get the start location of each bucket *) + fun bucket_offset i = + if (i = num_buckets) then n + else AS.sub (dest_offsets, i * num_blocks) + + (* sort the buckets *) + val _ = + parallelFor 1 (0, num_buckets) (fn i => + let + val start = bucket_offset i + val len = bucket_offset (i+1) - start + (* val start = AS.sub (bucket_offsets, i) + val len = (AS.sub (bucket_offsets, i+1)) - start *) + val _ = sortInPlace cmp (Seq.subseq C (start,len)) + in () end) + + (* val t5 = 
Time.now () + val _ = print ("sort buckets " ^ Time.fmt 4 (Time.- (t5, t4)) ^ "\n") *) + + in C end +end + diff --git a/tests/mpllib/SeqBasis.sml b/tests/mpllib/SeqBasis.sml new file mode 100644 index 000000000..1739c1f8c --- /dev/null +++ b/tests/mpllib/SeqBasis.sml @@ -0,0 +1,214 @@ +structure SeqBasis: +sig + type grain = int + + val tabulate: grain -> (int * int) -> (int -> 'a) -> 'a array + + val foldl: ('b * 'a -> 'b) + -> 'b + -> (int * int) + -> (int -> 'a) + -> 'b + + val foldr: ('b * 'a -> 'b) + -> 'b + -> (int * int) + -> (int -> 'a) + -> 'b + + val reduce: grain + -> ('a * 'a -> 'a) + -> 'a + -> (int * int) + -> (int -> 'a) + -> 'a + + val scan: grain + -> ('a * 'a -> 'a) + -> 'a + -> (int * int) + -> (int -> 'a) + -> 'a array (* length N+1, for both inclusive and exclusive scan *) + + val filter: grain + -> (int * int) + -> (int -> 'a) + -> (int -> bool) + -> 'a array + + val tabFilter: grain + -> (int * int) + -> (int -> 'a option) + -> 'a array +end = +struct + + type grain = int + + structure A = Array + structure AS = ArraySlice + + (* + fun upd a i x = Unsafe.Array.update (a, i, x) + fun nth a i = Unsafe.Array.sub (a, i) + *) + + fun upd a i x = A.update (a, i, x) + fun nth a i = A.sub (a, i) + + val parfor = ForkJoin.parfor + val par = ForkJoin.par + val allocate = ForkJoin.alloc + + fun tabulate grain (lo, hi) f = + let + val n = hi-lo + val result = allocate n + in + if lo = 0 then + parfor grain (0, n) (fn i => upd result i (f i)) + else + parfor grain (0, n) (fn i => upd result i (f (lo+i))); + + result + end + + fun foldl g b (lo, hi) f = + if lo >= hi then b else + let + val b' = g (b, f lo) + in + foldl g b' (lo+1, hi) f + end + + fun foldr g b (lo, hi) f = + if lo >= hi then b else + let + val hi' = hi-1 + val b' = g (b, f hi') + in + foldr g b' (lo, hi') f + end + + fun reduce grain g b (lo, hi) f = + if hi - lo <= grain then + foldl g b (lo, hi) f + else + let + val n = hi - lo + val k = grain + val m = 1 + (n-1) div k (* number of 
blocks *)

        fun red i j =
          case j - i of
            0 => b
          | 1 => foldl g b (lo + i*k, Int.min (lo + (i+1)*k, hi)) f
          | n => let val mid = i + (j-i) div 2
                 in g (par (fn _ => red i mid, fn _ => red mid j))
                 end
      in
        red 0 m
      end

  (* Parallel scan producing an (N+1)-element array: exclusive prefixes in
   * slots [0, N) and the grand total in slot N. *)
  fun scan grain g b (lo, hi) (f : int -> 'a) =
    if hi - lo <= grain then
      let
        val n = hi - lo
        val result = allocate (n+1)
        fun bump ((j, b), x) = (upd result j b; (j+1, g (b, x)))
        val (_, total) = foldl bump (0, b) (lo, hi) f
      in
        upd result n total;
        result
      end
    else
      let
        val n = hi - lo
        val k = grain
        val m = 1 + (n-1) div k (* number of blocks *)
        (* per-block totals *)
        val sums = tabulate 1 (0, m) (fn i =>
          let val start = lo + i*k
          in foldl g b (start, Int.min (start+k, hi)) f
          end)
        (* prefix sums of the block totals *)
        val partials = scan grain g b (0, m) (nth sums)
        val result = allocate (n+1)
      in
        parfor 1 (0, m) (fn i =>
          let
            fun bump ((j, b), x) = (upd result j b; (j+1, g (b, x)))
            val start = lo + i*k
          in
            foldl bump (i*k, nth partials i) (start, Int.min (start+k, hi)) f;
            ()
          end);
        upd result n (nth partials m);
        result
      end

  (* Parallel filter: count survivors per block, scan the counts for write
   * offsets, then compact each block into its slot. *)
  fun filter grain (lo, hi) f g =
    let
      val n = hi - lo
      val k = grain
      val m = 1 + (n-1) div k (* number of blocks *)
      fun count (i, j) c =
        if i >= j then c
        else if g i then count (i+1, j) (c+1)
        else count (i+1, j) c
      val counts = tabulate 1 (0, m) (fn i =>
        let val start = lo + i*k
        in count (start, Int.min (start+k, hi)) 0
        end)
      val offsets = scan grain op+ 0 (0, m) (nth counts)
      val result = allocate (nth offsets m)
      fun filterSeq (i, j) c =
        if i >= j then ()
        else if g i then (upd result c (f i); filterSeq (i+1, j) (c+1))
        else filterSeq (i+1, j) c
    in
      parfor 1 (0, m) (fn i =>
        let val start = lo + i*k
        in filterSeq (start, Int.min (start+k, hi)) (nth offsets i)
        end);
      result
    end

  (* Fused tabulate+filter: evaluate f once per index, keep the SOMEs. *)
  fun tabFilter grain (lo, hi) (f : int -> 'a option) =
    let
      val n = hi - lo
      val k = grain
      val m = 1 + (n-1) div k (* number of blocks *)
      val tmp = allocate n

      fun filterSeq (i, j, k) =
        if (i >= j) then k
        else case f i of
               NONE => filterSeq (i+1, j, k)
             | SOME v => (A.update (tmp, k, v); filterSeq (i+1, j, k+1))

      val counts = tabulate 1 (0, m) (fn i =>
        let val last = filterSeq (lo + i*k, lo + Int.min ((i+1)*k, n), i*k)
        in last - i*k
        end)

      val outOff = scan grain op+ 0 (0, m) (fn i => A.sub (counts, i))
      val outSize = A.sub (outOff, m)

      val result = allocate outSize
    in
      (* Choosing grain = n/outSize assumes that the blocks are all
       * approximately the same amount full.  We could do something more
       * complex here, e.g. binary search to recursively split up the
       * range into small pieces of all the same size. *)
      parfor (n div (Int.max (outSize, 1))) (0, m) (fn i =>
        let
          val soff = i * k
          val doff = A.sub (outOff, i)
          val size = A.sub (outOff, i+1) - doff
        in
          Util.for (0, size) (fn j =>
            A.update (result, doff+j, A.sub (tmp, soff+j)))
        end);
      result
    end

end

(* ===== tests/mpllib/SeqifiedMerge.sml ===== *)
structure SeqifiedMerge:
sig
  val merge: ('a * 'a -> order) -> 'a Seq.t * 'a Seq.t -> 'a Seq.t
end =
struct

  val serialGrain = CommandLineArgs.parseInt "MPLLib_Merge_serialGrain" 4000

  val unsafe_at_leaves =
    CommandLineArgs.parseFlag "MPLLib_SeqifiedMerge_unsafe_at_leaves"

  fun merge_loop cmp (s1, s2) out =
    if Seq.length s1 = 0 then
      Seqifier.put (out, s2)
    else if Seqifier.length out <= serialGrain then
      if unsafe_at_leaves then
        (* this is semantically safe (it does not violate any of the internal
         * invariants of the Seqifier library), but of course it appears to
         * be syntactically unsafe from the perspective of the library
         * interface.  We are careful here to make sure we don't do anything
         * really bad.
+ *) + ( Merge.writeMergeSerial cmp (s1, s2) + (Seqifier.unsafe_view_contents out) + ; Seqifier.unsafe_mark_put out + ) + else + (* this is completely safe, at a small performance cost, due to the + * need to create an intermediate sequence for the result of the + * `mergeSerial`. + *) + Seqifier.put (out, Merge.mergeSerial cmp (s1, s2)) + else + let + val n1 = Seq.length s1 + val n2 = Seq.length s2 + val mid1 = n1 div 2 + val pivot = Seq.nth s1 mid1 + val mid2 = BinarySearch.search cmp s2 pivot + + val (outl, out_tail) = Seqifier.split_at (out, mid1 + mid2) + val (outm, outr) = Seqifier.split_at (out_tail, 1) + val outm = Seqifier.put (outm, Seq.subseq s1 (mid1, 1)) + val l1 = Seq.take s1 mid1 + val r1 = Seq.drop s1 (mid1 + 1) + val l2 = Seq.take s2 mid2 + val r2 = Seq.drop s2 mid2 + val (outl, outr) = + ForkJoin.par (fn _ => merge_loop cmp (l1, l2) outl, fn _ => + merge_loop cmp (r1, r2) outr) + in + Seqifier.append (outl, Seqifier.append (outm, outr)) + end + + fun merge cmp (s1, s2) = + let val out = Seqifier.init_expect_length (Seq.length s1 + Seq.length s2) + in Seqifier.finalize (merge_loop cmp (s1, s2) out) + end + +end diff --git a/tests/mpllib/Seqifier.sml b/tests/mpllib/Seqifier.sml new file mode 100644 index 000000000..622f0d01f --- /dev/null +++ b/tests/mpllib/Seqifier.sml @@ -0,0 +1,206 @@ +(* Implements a parallel "sequence builder" data structure: + * type 'a seqifier + * type 'a t = 'a seqifier + * + * These can be used to write purely functional algorithms that directly write + * to an underlying (mutable) array, but the mutable array is hidden behind + * the interface and not exposed to the programmer. + * + * Example usage is: + * val sb = init_expect_length n (* O(1) *) + * val (sb1, sb2) = split_at (sb, i) (* O(1) *) + * val (sb1', sb2') = + * ForkJoin.par + * (fn () => put (sb1, X), (* O(|X|) work, O(log|X|) span *) + * fn () => put (sb2, Y)) (* O(|Y|) work, O(log|Y|) span *) + * ... 
+ * val sb' = append (sb1', sb2') (* O(1) *) + * val result = finalize sb' (* O(1) *) + * + * The value semantics of these functions can be described in terms of a + * purely functional sequence with elements of type 'a option. Initially, + * every element is NONE. Calling `put (...)` returns a seqifier that is + * full of SOME(x) elements. Calling `finalize` checks that there are no + * NONEs, and returns a sequence of just the elements themselves. + * + * To achieve good cost bounds, seqifiers can only be "used" at most once. + * "Using" a seqifier means passing it as argument to one of the following + * functions: + * split_at + * put + * append + * finalize + * + * (Note that the function `length: 'a seqifier -> int` is read-only and does + * not constitute a "use"; this function is safe to call in any context.) + * + * Every time a seqifier is used, it is immediately invalidated. Any call to + * one of the functions above that receives an invalid seqifier as input + * will raise the exception UsedTwice. + * + * When appending two seqifiers, it is essential that they "came from" the + * same original seqifier and are physically adjacent to each other. Calling + * append will raise NonAdjacent otherwise. + * + * For example, this is okay: + * val (l, r) = split_at (x, i) + * val (l1, l2) = split_at (l, j) + * ... + * val foo = append (l1, append (l2, r)) + * + * But this would raise NonAdjacent: + * val (l, r) = split_at (x, i) + * val (l1, l2) = split_at (l, j) + * ... + * val foo = append (l1, r) + * + * When finalizing a seqifier, you may get the exception MaybeMissingPut. This + * occurs if one of the components of the seqifier was never covered by the + * result of a `put`. + * + * For example, this would raise MaybeMissingPut, because we never called + * `put` on the segment `r`. + * val x = init_expect_length n + * val (l, r) = split_at (x, i) + * val l' = put (x, ...) 
+ * val result = finalize (append (l', r)) + * + * The MaybeMissingPut issue is checked conservatively by keeping track, for + * every seqifier, whether or not that seqifier has been fully put. This is set + * to `true` on the output of a `put`, and at each `append` we check if both + * sides are marked as fully put. Calling `split_at` just copies the boolean + * to both of the results. This approach is conservative because we don't + * individually track every index. (Specifically, at `append`, if at least + * one index has not yet been put, we mark the whole result as not put.) + *) +structure Seqifier: +sig + exception UsedTwice + exception NonAdjacent + exception MaybeMissingPut + + type 'a seqifier + type 'a t = 'a seqifier + + val length: 'a t -> int + + val init_expect_length: int -> 'a t + val split_at: 'a t * int -> 'a t * 'a t + val append: 'a t * 'a t -> 'a t + val put: 'a t * 'a Seq.t -> 'a t + val finalize: 'a t -> 'a Seq.t + + (* ======================================================================== + * UNSAFE FUNCTIONS + * These give direct access to the internals of the seqifier. + * Don't use unless you know what you are doing! 
+ *) + + val unsafe_mark_put: 'a t -> 'a t + val unsafe_view_contents: 'a t -> 'a ArraySlice.slice + +end = +struct + + + datatype 'a t = + T of + { output: 'a array + , offset: int + , len: int + , fully_put: bool + , valid: Word8.word ref + } + + type 'a seqifier = 'a t + + + exception UsedTwice + exception NonAdjacent + exception MaybeMissingPut + + + fun unpack_and_mark_used (T {output, offset, len, fully_put, valid}) = + if !valid = 0w0 orelse Concurrency.cas valid (0w1, 0w0) = 0w0 then + raise UsedTwice + else + (output, offset, len, fully_put) + + + fun pack (output, offset, len, fully_put) = + T { output = output + , offset = offset + , len = len + , fully_put = fully_put + , valid = ref 0w1 + } + + + fun init_expect_length n = + pack (ForkJoin.alloc n, 0, n, false) + + + fun length (T {len, ...}) = len + + + fun split_at (t, i) = + let + val (output, offset, len, fp) = unpack_and_mark_used t + in + if i < 0 orelse i > len then + raise Subscript + else + (pack (output, offset, i, fp), pack (output, offset + i, len - i, fp)) + end + + + fun append (l, r) = + let + val (output1, offset1, len1, fp1) = unpack_and_mark_used l + val (output2, offset2, len2, fp2) = unpack_and_mark_used r + in + if not (MLton.eq (output1, output2) andalso offset1 + len1 = offset2) then + raise NonAdjacent + else + pack (output1, offset1, len1 + len2, fp1 andalso fp2) + end + + + fun put (t, s) = + let + val (output, offset, len, _) = unpack_and_mark_used t + in + if Seq.length s <> len then + raise Size + else + ( Seq.foreach s (fn (i, x) => Array.update (output, offset + i, x)) + ; pack (output, offset, len, true) + ) + end + + + fun finalize t = + let + val (output, offset, len, fp) = unpack_and_mark_used t + in + if not fp then raise MaybeMissingPut + else ArraySlice.slice (output, offset, SOME len) + end + + + (* ======================================================================= + * UNSAFE FUNCTIONS BELOW + *) + + fun unsafe_mark_put (T {output, offset, len, valid, ...}) = 
+ T { output = output + , offset = offset + , len = len + , valid = valid + , fully_put = true + } + + fun unsafe_view_contents (T {output, offset, len, ...}) = + ArraySlice.slice (output, offset, SOME len) + +end diff --git a/tests/mpllib/Shuffle.sml b/tests/mpllib/Shuffle.sml new file mode 100644 index 000000000..f6a4bd227 --- /dev/null +++ b/tests/mpllib/Shuffle.sml @@ -0,0 +1,64 @@ +structure Shuffle :> +sig + type 'a seq = 'a ArraySlice.slice + val shuffle: 'a seq -> int -> 'a seq +end = +struct + open Seq + type 'a seq = 'a ArraySlice.slice + + (* inplace Knuth shuffle [l, r) *) + fun inplace_seq_shuffle s l r seed = + let + fun item i = AS.sub (s, i) + fun set (i, v) = AS.update (s, i, v) + (* get a random idx in [l, i] *) + fun rand_idx i = Int.mod (Util.hash (seed + i), i - l + 1) + l + fun swap (i,j) = + let + val tmp = item i + in + set(i, item j); set(j, tmp) + end + fun shuffle_helper li = + if r - li < 2 then () + else (swap (li, rand_idx li); shuffle_helper (li + 1)) + in + shuffle_helper l + end + + fun bucket_shuffle s seed = + let + fun log2_up n = Real.ceil (Math.log10 (Real.fromInt n) / (Math.log10 2.0)) + fun bit_and (n, mask) = Word.toInt (Word.andb (Word.fromInt n, mask)) + + val n = length s + val l = log2_up n + val bits = if n < Real.floor (Math.pow (2.0, 27.0)) then Int.div ((l - 7), 2) + else l - 17 + val num_buckets = Real.floor (Math.pow (2.0, Real.fromInt bits)) + val mask = Word.fromInt (num_buckets - 1) + fun rand_pos i = bit_and (Util.hash (seed + i), mask) + (* size of bucket_offsets = num_buckets + 1 *) + val (s', bucket_offsets) = CountingSort.sort s rand_pos num_buckets + fun bucket_shuffle i = inplace_seq_shuffle s' (nth bucket_offsets i) (nth bucket_offsets (i + 1)) seed + val _ = ForkJoin.parfor 1 (0, num_buckets) bucket_shuffle + in + s' + end + + fun shuffle s seed = + let + val n = length s + in + if n < 1000 then + let + val s' = (Seq.tabulate (Seq.nth s) n) + val _ = inplace_seq_shuffle s' 0 n seed + in + s' + end + 
else + bucket_shuffle s seed + end +end diff --git a/tests/mpllib/Signal.sml b/tests/mpllib/Signal.sml new file mode 100644 index 000000000..0a2fee511 --- /dev/null +++ b/tests/mpllib/Signal.sml @@ -0,0 +1,290 @@ +structure Signal: +sig + type sound = NewWaveIO.sound + val delay: real -> real -> sound -> sound + val allPass: real -> real -> sound -> sound + val reverb: sound -> sound +end = +struct + + type sound = NewWaveIO.sound + + structure A = Array + structure AS = ArraySlice + +(* + structure A = + struct + open A + val update = Unsafe.Array.update + val sub = Unsafe.Array.sub + end + + structure AS = + struct + open AS + fun update (s, i, x) = + let val (a, start, _) = base s + in A.update (a, start+i, x) + end + fun sub (s, i) = + let val (a, start, _) = base s + in A.sub (a, start+i) + end + end +*) + + fun delaySequential D a data = + let + val n = Seq.length data + val output = ForkJoin.alloc n + in + Util.for (0, n) (fn i => + if i < D then + A.update (output, i, Seq.nth data i) + else + A.update (output, i, Seq.nth data i + a * A.sub (output, i - D)) + ); + + AS.full output + end + + fun pow (a: real) n = + if n <= 1 then + a + else if n mod 2 = 0 then + pow (a*a) (n div 2) + else + a * pow (a*a) (n div 2) + + (* Granularity parameters *) + val blockWidth = CommandLineArgs.parseInt "comb-width" 600 + val blockHeight = CommandLineArgs.parseInt "comb-height" 50 + val combGran = CommandLineArgs.parseInt "comb-threshold" 10000 + (*val _ = print ("comb-width " ^ Int.toString blockWidth ^ "\n") + val _ = print ("comb-height " ^ Int.toString blockHeight ^ "\n") + val _ = print ("comb-threshold " ^ Int.toString combGran ^ "\n")*) + + (* Imagine laying out the data as a matrix, where sample s[i*D + j] is + * at row i, column j. 
+ *) + fun delay' D alpha data = + if Seq.length data <= combGran then + delaySequential D alpha data + else + let + val n = Seq.length data + (* val _ = print ("delay' " ^ Int.toString D ^ " " ^ Int.toString n ^ "\n") *) + val output = ForkJoin.alloc n + + val numCols = D + val numRows = Util.ceilDiv n D + + fun getOutput i j = + A.sub (output, i*numCols + j) + + fun setOutput i j x = + let val idx = i*numCols + j + in if idx < n then A.update (output, idx, x) else () + end + + fun input i j = + let val idx = i*numCols + j + in if idx >= n then 0.0 else AS.sub (data, idx) + end + + val powAlpha = pow alpha blockHeight + + val numColumnStrips = Util.ceilDiv numCols blockWidth + val numRowStrips = Util.ceilDiv numRows blockHeight + + fun doColumnStrip c = + let + val jlo = blockWidth * c + val jhi = Int.min (numCols, jlo + blockWidth) + val width = jhi - jlo + val summaries = + AS.full (ForkJoin.alloc (width * numRowStrips)) + + fun doBlock b = + let + val ilo = blockHeight * b + val ihi = Int.min (numRows, ilo + blockHeight) + val ss = Seq.subseq summaries (width * b, width) + in + Util.for (0, width) (fn j => AS.update (ss, j, input ilo (jlo+j))); + + Util.for (ilo+1, ihi) (fn i => + Util.for (0, width) (fn j => + AS.update (ss, j, input i (jlo+j) + alpha * AS.sub (ss, j)) + ) + ) + end + + val _ = ForkJoin.parfor 1 (0, numRowStrips) doBlock + val summaries' = delay' width powAlpha summaries + + fun fillOutputBlock b = + let + val ilo = blockHeight * b + val ihi = Int.min (numRows, ilo + blockHeight) + in + if b = 0 then + Util.for (jlo, jhi) (fn j => setOutput 0 j (input 0 j)) + else + let + val ss = Seq.subseq summaries' (width * (b-1), width) + in + Util.for (0, width) (fn j => + setOutput ilo (jlo+j) (input ilo (jlo+j) + alpha * AS.sub (ss, j))) + end; + + Util.for (ilo+1, ihi) (fn i => + Util.for (jlo, jhi) (fn j => + setOutput i j (input i j + alpha * getOutput (i-1) j) + ) + ) + end + in + ForkJoin.parfor 1 (0, numRowStrips) fillOutputBlock + end + in + 
ForkJoin.parfor 1 (0, numColumnStrips) doColumnStrip; + + AS.full output + end + + fun delay ds alpha ({sr, data}: sound) = + let + val D = Real.round (ds * Real.fromInt sr) + in + {sr = sr, data = delay' D alpha data} + end + + fun allPass' D a data = + let + val combed = delay' D a data + + fun output j = + let + val k = j - D + in + (1.0 - a*a) * (if k < 0 then 0.0 else Seq.nth combed k) + - (a * Seq.nth data j) + end + in + Seq.tabulate output (Seq.length data) + end + + fun allPass ds a (snd as {sr, data}: sound) = + let + (* convert to samples *) + val D = Real.round (ds * Real.fromInt sr) + in + { sr = sr + , data = allPass' D a data + } + end + + val par = ForkJoin.par + + fun par4 (a, b, c, d) = + let + val ((ar, br), (cr, dr)) = + par (fn _ => par (a, b), fn _ => par (c, d)) + in + (ar, br, cr, dr) + end + + fun shiftBy n s i = + if i < n then + 0.0 + else if i < Seq.length s + n then + Seq.nth s (i-n) + else + 0.0 + + fun reverb ({sr, data=dry}: sound) = + let + val N = Seq.length dry + + (* Originally, I tuned the comb and allPass parameters + * based on numbers of samples at 44.1 kHz, which I chose + * to be relatively prime to one another. But now, to + * handle any sample rate, we need to convert these numbers + * of samples. Does it really matter if the sample delays are + * relatively prime? I'm not sure. For sample rates other + * than 44.1 kHz, they almost certainly won't be now. 
*) + + val srr = Real.fromInt sr + fun secondsToSamples sec = Real.round (sec * srr) + fun secondsAt441 samples = Real.fromInt samples / 44100.0 + fun adjust x = + if sr = 44100 then x else secondsToSamples (secondsAt441 x) + + val D1 = adjust 1931 + val D2 = adjust 2213 + val D3 = adjust 1747 + val D4 = adjust 1559 + + val DA1 = adjust 167 + val DA2 = adjust 191 + + val DE1 = adjust 1013 + val DE2 = adjust 1102 + val DE3 = adjust 1300 + + val DF = adjust 1500 + + (* ========================================== + * Fused reflections near 50ms + * (at 44.1kHz, 50ms is about 2200 samples) + * + * The basic design is taken from + * The Computer Music Tutorial (1996), page 481 + * Author: Curtis Roads + * + * The basic design is 4 comb filters (in parallel) + * which are then fed into two allpass filters, in series. + *) + + val (c1, c2, c3, c4) = + par4 (fn _ => delay' D1 0.7 dry, + fn _ => delay' D2 0.7 dry, + fn _ => delay' D3 0.7 dry, + fn _ => delay' D4 0.7 dry) + + fun combs i = + (Seq.nth c1 i + + Seq.nth c2 i + + Seq.nth c3 i + + Seq.nth c4 i) + + val fused = Seq.tabulate combs N + val fused = allPass' DA1 0.6 fused + val fused = allPass' DA2 0.6 fused + + (* ========================================== + * wet signal = dry + early + fused + * + * early reflections are single echos of + * the dry sound that occur after around + * 25ms delay + * + * the fused reflections start emerging after + * approximately 35ms + *) + + val wet = Seq.tabulate (fn i => + shiftBy 0 dry i + + 0.6 * (shiftBy DE1 dry i) + + 0.5 * (shiftBy DE2 dry i) + + 0.4 * (shiftBy DE3 dry i) + + 0.75 * (shiftBy DF fused i)) + (N + DF) + + in + NewWaveIO.compress 2.0 {sr=sr, data=wet} + end + +end diff --git a/tests/mpllib/StableMerge.sml b/tests/mpllib/StableMerge.sml new file mode 100644 index 000000000..d64933159 --- /dev/null +++ b/tests/mpllib/StableMerge.sml @@ -0,0 +1,100 @@ +structure StableMerge: +sig + type 'a seq = 'a ArraySlice.slice + + val writeMergeSerial: ('a * 'a -> order) (* 
compare *) + -> 'a seq * 'a seq (* (sorted) sequences to merge *) + -> 'a seq (* output *) + -> unit + + val writeMerge: ('a * 'a -> order) (* compare *) + -> 'a seq * 'a seq (* (sorted) sequences to merge *) + -> 'a seq (* output *) + -> unit + + val mergeSerial: ('a * 'a -> order) -> 'a seq * 'a seq -> 'a seq + val merge: ('a * 'a -> order) -> 'a seq * 'a seq -> 'a seq +end = +struct + + structure AS = ArraySlice + type 'a seq = 'a AS.slice + + val for = Util.for + val parfor = ForkJoin.parfor + val par = ForkJoin.par + val allocate = ForkJoin.alloc + + val serialGrain = + CommandLineArgs.parseInt "MPLLib_StableMerge_serialGrain" 4000 + + fun sliceIdxs s i j = + AS.subslice (s, i, SOME (j - i)) + + fun writeMergeSerial cmp (s1, s2) t = + let + fun write i x = AS.update (t, i, x) + + val n1 = AS.length s1 + val n2 = AS.length s2 + + (* i1 index into s1 + * i2 index into s2 + * j index into output *) + fun loop i1 i2 j = + if i1 = n1 then + Util.foreach (sliceIdxs s2 i2 n2) (fn (i, x) => write (i + j) x) + else if i2 = n2 then + Util.foreach (sliceIdxs s1 i1 n1) (fn (i, x) => write (i + j) x) + else + let + val x1 = AS.sub (s1, i1) + val x2 = AS.sub (s2, i2) + in + (* NOTE: this is stable *) + case cmp (x1, x2) of + GREATER => (write j x2; loop i1 (i2 + 1) (j + 1)) + | _ => (write j x1; loop (i1 + 1) i2 (j + 1)) + end + in + loop 0 0 0 + end + + fun mergeSerial cmp (s1, s2) = + let val out = AS.full (allocate (AS.length s1 + AS.length s2)) + in writeMergeSerial cmp (s1, s2) out; out + end + + fun writeMerge cmp (s1, s2) t = + if AS.length t <= serialGrain then + writeMergeSerial cmp (s1, s2) t + else if AS.length s1 = 0 then + Util.foreach s2 (fn (i, x) => AS.update (t, i, x)) + else + let + val n1 = AS.length s1 + val n2 = AS.length s2 + val mid1 = n1 div 2 + val pivot = AS.sub (s1, mid1) + val mid2 = BinarySearch.countLess cmp s2 pivot + + val l1 = sliceIdxs s1 0 mid1 + val r1 = sliceIdxs s1 (mid1 + 1) n1 + val l2 = sliceIdxs s2 0 mid2 + val r2 = sliceIdxs s2 
mid2 n2 + + val _ = AS.update (t, mid1 + mid2, pivot) + val tl = sliceIdxs t 0 (mid1 + mid2) + val tr = sliceIdxs t (mid1 + mid2 + 1) (AS.length t) + in + par (fn _ => writeMerge cmp (l1, l2) tl, fn _ => + writeMerge cmp (r1, r2) tr); + () + end + + fun merge cmp (s1, s2) = + let val out = AS.full (allocate (AS.length s1 + AS.length s2)) + in writeMerge cmp (s1, s2) out; out + end + +end diff --git a/tests/mpllib/StableMergeLowSpan.sml b/tests/mpllib/StableMergeLowSpan.sml new file mode 100644 index 000000000..a24c2c0a4 --- /dev/null +++ b/tests/mpllib/StableMergeLowSpan.sml @@ -0,0 +1,69 @@ +structure StableMergeLowSpan: +sig + type 'a seq = 'a ArraySlice.slice + + val writeMerge: ('a * 'a -> order) (* compare *) + -> 'a seq * 'a seq (* (sorted) sequences to merge *) + -> 'a seq (* output *) + -> unit + + val merge: ('a * 'a -> order) -> 'a seq * 'a seq -> 'a seq +end = +struct + + structure AS = ArraySlice + type 'a seq = 'a AS.slice + fun slice_idxs s (i, j) = + AS.subslice (s, i, SOME (j - i)) + + + (* DoubleBinarySearch guarantees that it takes the _minimum_ number of + * elements from the first argument. For stability, we want to take the + * _maximum_ number of elements from s1; this is equivalent to taking the + * minimum from s2. So, we can just swap the order of the arguments we + * give to the search. 
+ *) + fun split_count_take_max_left cmp (s1, s2) k = + let val (i2, i1) = DoubleBinarySearch.split_count_slice cmp (s2, s1) k + in (i1, i2) + end + + + val blockSizeFactor = + CommandLineArgs.parseReal "MPLLib_StableMergeLowSpan_blockSizeFactor" 1000.0 + + + fun log2 x = + Real64.Math.log10 (Real64.fromInt x) / Real64.Math.log10 2.0 + + + fun writeMerge cmp (s1, s2) output = + let + val n = AS.length s1 + AS.length s2 + val logn = if n <= 2 then 1.0 else log2 n + val blockSize = Real64.ceil (blockSizeFactor * logn) + val numBlocks = Util.ceilDiv n blockSize + in + ForkJoin.parfor 1 (0, numBlocks) (fn b => + let + val start = blockSize * b + val stop = Int.min (n, start + blockSize) + + val (i1, i2) = split_count_take_max_left cmp (s1, s2) start + val (j1, j2) = split_count_take_max_left cmp (s1, s2) stop + + val piece1 = slice_idxs s1 (i1, j1) + val piece2 = slice_idxs s2 (i2, j2) + val piece_output = slice_idxs output (start, stop) + in + StableMerge.writeMerge cmp (piece1, piece2) piece_output + end) + end + + + fun merge cmp (s1, s2) = + let val out = AS.full (ForkJoin.alloc (AS.length s1 + AS.length s2)) + in writeMerge cmp (s1, s2) out; out + end + +end diff --git a/tests/mpllib/StableSort.sml b/tests/mpllib/StableSort.sml new file mode 100644 index 000000000..a9429e788 --- /dev/null +++ b/tests/mpllib/StableSort.sml @@ -0,0 +1,81 @@ +structure StableSort: +sig + type 'a seq = 'a ArraySlice.slice + val sortInPlace: ('a * 'a -> order) -> 'a seq -> unit + val sort: ('a * 'a -> order) -> 'a seq -> 'a seq +end = +struct + + type 'a seq = 'a ArraySlice.slice + + structure AS = ArraySlice + + fun take s n = AS.subslice (s, 0, SOME n) + fun drop s n = AS.subslice (s, n, NONE) + + val par = ForkJoin.par + val allocate = ForkJoin.alloc + + (* in-place sort s, using t as a temporary array if needed *) + fun sortInPlace' cmp s t = + if AS.length s <= 1 then + () + else let + val half = AS.length s div 2 + val (sl, sr) = (take s half, drop s half) + val (tl, tr) = (take 
t half, drop t half) + in + (* recursively sort, writing result into t *) + if AS.length s <= 1024 then + (writeSort cmp sl tl; writeSort cmp sr tr) + else + ( par (fn _ => writeSort cmp sl tl, fn _ => writeSort cmp sr tr) + ; () + ); + + (* merge back from t into s *) + StableMerge.writeMerge cmp (tl, tr) s; + + () + end + + (* destructively sort s, writing the result in t *) + and writeSort cmp s t = + if AS.length s <= 1 then + Util.foreach s (fn (i, x) => AS.update (t, i, x)) + else let + val half = AS.length s div 2 + val (sl, sr) = (take s half, drop s half) + val (tl, tr) = (take t half, drop t half) + in + (* recursively in-place sort sl and sr *) + if AS.length s <= 1024 then + (sortInPlace' cmp sl tl; sortInPlace' cmp sr tr) + else + ( par (fn _ => sortInPlace' cmp sl tl, fn _ => sortInPlace' cmp sr tr) + ; () + ); + + (* merge into t *) + StableMerge.writeMerge cmp (sl, sr) t; + + () + end + + fun sortInPlace cmp s = + let + val t = AS.full (allocate (AS.length s)) + in + sortInPlace' cmp s t + end + + fun sort cmp s = + let + val result = AS.full (allocate (AS.length s)) + in + Util.foreach s (fn (i, x) => AS.update (result, i, x)); + sortInPlace cmp result; + result + end + +end diff --git a/tests/mpllib/TFlatten.sml b/tests/mpllib/TFlatten.sml new file mode 100644 index 000000000..a5f552bc4 --- /dev/null +++ b/tests/mpllib/TFlatten.sml @@ -0,0 +1,51 @@ +structure TFlatten: +sig + type 'a tree + type 'a t = 'a tree + + datatype 'a view = Leaf of 'a Seq.t | Node of 'a t * 'a t + + val size: 'a t -> int + val leaf: 'a Seq.t -> 'a t + val node: 'a t * 'a t -> 'a t + val view: 'a t -> 'a view + val flatten: 'a t -> 'a Seq.t +end = +struct + + datatype 'a tree = Leaf_ of 'a Seq.t | Node_ of int * 'a tree * 'a tree + type 'a t = 'a tree + + datatype 'a view = Leaf of 'a Seq.t | Node of 'a t * 'a t + + fun size (Leaf_ s) = Seq.length s + | size (Node_ (n, _, _)) = n + + fun leaf s = Leaf_ s + + fun node (l, r) = + Node_ (size l + size r, l, r) + + fun view 
(Leaf_ s) = Leaf s
    | view (Node_ (_, l, r)) = Node (l, r)
+
+
+  fun flatten t =
+    let
+      val output = ForkJoin.alloc (size t)
+      fun traverse (c, offset) =
+        case c of
+          Leaf_ s =>
+            ForkJoin.parfor 100 (0, Seq.length s) (fn i =>
+              Array.update (output, offset + i, Seq.nth s i))
+        | Node_ (_, l, r) =>
+            ( ForkJoin.par (fn () => traverse (l, offset), fn () =>
+                traverse (r, offset + size l))
+            ; ()
+            )
+    in
+      traverse (t, 0);
+      ArraySlice.full output
+    end
+
+end
diff --git a/tests/mpllib/TabFilterTree.sml b/tests/mpllib/TabFilterTree.sml
new file mode 100644
index 000000000..0f1d98c15
--- /dev/null
+++ b/tests/mpllib/TabFilterTree.sml
@@ -0,0 +1,76 @@
+structure TabFilterTree =
+struct
+
+  structure A = Array
+  structure AS = ArraySlice
+
+  structure ChunkList =
+  struct
+    val chunkSize = 256
+    type 'a t = 'a array list * 'a array * int
+    fun new () = ([], ForkJoin.alloc chunkSize, 0)
+
+    fun push ((elems, chunk, pos): 'a t) (x: 'a) =
+      if pos >= chunkSize then
+        push (chunk :: elems, ForkJoin.alloc chunkSize, 0) x
+      else
+        ( A.update (chunk, pos, x); (elems, chunk, pos+1) )
+
+    fun finish ((elems, chunk, pos): 'a t) =
+      (List.rev elems, chunk, pos)
+
+    fun foreach offset (elems, lastChunk, lastLen) f =
+      case elems of
+        [] => AS.appi (fn (i, x) => f (offset+i, x)) (AS.slice (lastChunk, 0, SOME lastLen))
+      | (chunk' :: elems') =>
+          ( A.appi (fn (i, x) => f (offset+i, x)) chunk'
+          ; foreach (offset+chunkSize) (elems', lastChunk, lastLen) f
+          )
+  end
+
+  datatype 'a tree =
+    Leaf of int * 'a ChunkList.t
+  | Node of int * 'a tree * 'a tree
+
+  fun size (Leaf (n, _)) = n
+    | size (Node (n, _, _)) = n
+
+  fun tabFilter grain (lo, hi) (f: int -> 'a option) =
+    let
+      fun filterSeq (count, elems) (i, j) =
+        if i >= j then
+          (count, ChunkList.finish elems)
+        else
+          case f i of
+            NONE => filterSeq (count, elems) (i+1, j) (* must advance past a filtered-out index; (i, j) would loop forever *)
+          | SOME x => filterSeq (count+1, ChunkList.push elems x) (i+1, j)
+
+      fun t i j =
+        if j-i <= grain then
+          Leaf (filterSeq (0, ChunkList.new ()) (i, j))
else + let + val mid = i + (j-i) div 2 + val (l, r) = ForkJoin.par (fn _ => t i mid, fn _ => t mid j) + in + Node (size l + size r, l, r) + end + in + t lo hi + end + + fun foreach (t: 'a tree) (f: (int * 'a) -> unit) = + let + fun doit offset t = + case t of + Leaf (_, elems) => ChunkList.foreach offset elems f + | Node (_, l, r) => + ( ForkJoin.par (fn _ => doit offset l, + fn _ => doit (offset + size l) r) + ; () + ) + in + doit 0 t + end + +end diff --git a/tests/mpllib/Tokenize.sml b/tests/mpllib/Tokenize.sml new file mode 100644 index 000000000..6f35af15d --- /dev/null +++ b/tests/mpllib/Tokenize.sml @@ -0,0 +1,53 @@ +structure Tokenize: +sig + val tokenRanges: (char -> bool) -> char Seq.t -> int * (int -> (int * int)) + + val tokensSeq: (char -> bool) -> char Seq.t -> (char Seq.t) Seq.t + + val tokens: (char -> bool) -> char Seq.t -> string Seq.t +end = +struct + + fun tokenRanges f s = + let + val n = Seq.length s + fun check i = + if (i = n) then not (f(Seq.nth s (n-1))) + else if (i = 0) then not (f(Seq.nth s 0)) + else let val i1 = f (Seq.nth s i) + val i2 = f (Seq.nth s (i-1)) + in (i1 andalso not i2) orelse (i2 andalso not i1) end + val ids = ArraySlice.full + (SeqBasis.filter 10000 (0, n+1) (fn i => i) check) + val count = (Seq.length ids) div 2 + in + (count, fn i => (Seq.nth ids (2*i), Seq.nth ids (2*i+1))) + end + + fun tokensSeq f s = + let + val (n, g) = tokenRanges f s + fun token i = + let + val (lo, hi) = g i + in + Seq.subseq s (lo, hi-lo) + end + in + Seq.tabulate token n + end + + fun tokens f s = + let + val (n, g) = tokenRanges f s + fun token i = + let + val (lo, hi) = g i + val chars = Seq.subseq s (lo, hi-lo) + in + CharVector.tabulate (Seq.length chars, Seq.nth chars) + end + in + ArraySlice.full (SeqBasis.tabulate 1024 (0, n) token) + end +end diff --git a/tests/mpllib/Topology2D.sml b/tests/mpllib/Topology2D.sml new file mode 100644 index 000000000..98a45ac47 --- /dev/null +++ b/tests/mpllib/Topology2D.sml @@ -0,0 +1,1088 @@ 
+structure Topology2D: +sig + type vertex = int + type vertex_data = Geometry2D.point + + type triangle = int + datatype triangle_data = + Tri of + { vertices: vertex * vertex * vertex + , neighbors: triangle * triangle * triangle + } + + type mesh + val parseFile: string -> mesh + val numVertices: mesh -> int + val numTriangles: mesh -> int + val toString: mesh -> string + + val initialMeshWithBoundaryCircle + : {numVertices: int, numBoundaryVertices: int} + -> {center: Geometry2D.point, radius: real} + -> mesh + + val vdata: mesh -> vertex -> vertex_data + val tdata: mesh -> triangle -> triangle_data + val verticesOfTriangle: mesh -> triangle -> vertex * vertex * vertex + val neighborsOfTriangle: mesh -> triangle -> triangle * triangle * triangle + val triangleOfVertex: mesh -> vertex -> triangle + val getPoints: mesh -> Geometry2D.point Seq.t + (* val neighbor: triangle_data -> int -> triangle option *) + (* val locate: triangle_data -> triangle -> int option *) + + type simplex + + val find: mesh -> vertex -> simplex -> simplex + val findPoint: mesh -> Geometry2D.point -> simplex -> simplex + + val across: mesh -> simplex -> simplex option + val rotateClockwise: simplex -> simplex + val outside: mesh -> simplex -> vertex -> bool + val pointOutside: mesh -> simplex -> Geometry2D.point -> bool + val inCircle: mesh -> simplex -> vertex -> bool + val pointInCircle: mesh -> simplex -> Geometry2D.point -> bool + val firstVertex: mesh -> simplex -> vertex + + val split: mesh -> triangle -> Geometry2D.point -> mesh + val flip: mesh -> simplex -> mesh + + (** A cavity is a center triangle and a set of nearby connected simplices. + * The order of the nearby simplices is important: these must emanate + * from the center triangle. 
+ *) + type cavity = triangle * (simplex list) + + val findCavityAndPerimeter: mesh + -> simplex (** where to start search *) + -> Geometry2D.point (** center of the cavity *) + -> cavity * (vertex list) + + val loopPerimeter: mesh + -> triangle (* triangle containing center point *) + -> Geometry2D.point (* center of the cavity *) + -> 'a + -> ('a * vertex -> 'a) + -> 'a + + val findCavity: mesh + -> triangle (* triangle containing center point *) + -> Geometry2D.point (* center point of the cavity *) + -> cavity + + val ripAndTentCavity: mesh + -> triangle (* center triangle *) + -> (vertex * Geometry2D.point) (* center of cavity and vertex id to use *) + -> triangle * triangle (* two new triangles to use *) + -> unit + + (** For each (c, p), replace cavity c with a tent using p as the center + * point. The center triangle of of the cavity must contain p. + *) + val ripAndTent: (cavity * Geometry2D.point) Seq.t -> mesh -> mesh + val ripAndTentOne: cavity * Geometry2D.point -> mesh -> mesh + + + (** The following are for imperative algorithms on meshes. *) + + val new: {numVertices: int, numTriangles: int} -> mesh + + val doSplit: mesh + -> triangle (* triangle to split *) + -> vertex * Geometry2D.point (* point inside triangle, and vertex identifier to use *) + -> triangle * triangle (* two new triangle identifiers to create *) + -> unit + + val doFlip: mesh -> simplex -> unit + + val copyData: {src: mesh, dst: mesh} -> unit + val copy: mesh -> mesh + +end = +struct + + structure AS = ArraySlice + structure G = Geometry2D + + fun upd s i x = AS.update (s, i, x) + fun nth s i = AS.sub (s, i) + + (** vertex and triangle identifiers are indices into a mesh *) + type vertex = int + type triangle = int + + val INVALID_ID = ~1 + + type vertex_data = G.point + + (** Triangles with vertices (u,v,w) and neighbors (a,b,c) must be in + * counter-clockwise order. + * + * u + * | \ --> a + * b <-- | w + * | / --> c + * v + * + * This is equivalent to any rotation, e.g. 
[(v,w,u),(b,c,a)]. But CCW + * order must be preserved. + *) + datatype triangle_data = + Tri of + { vertices: vertex * vertex * vertex + , neighbors: triangle * triangle * triangle + } + + + val dummyPt = (0.0, 0.0) + val dummyTriple = (INVALID_ID, INVALID_ID, INVALID_ID) + val dummyTri = + Tri {vertices = dummyTriple, neighbors = dummyTriple} + + + datatype mesh = + Mesh of + { vdata: vertex_data Seq.t + , verticesOfTriangle: (vertex * vertex * vertex) Seq.t + , neighborsOfTriangle: (triangle * triangle * triangle) Seq.t + , triangleOfVertex: triangle Seq.t + } + + fun new {numVertices, numTriangles} = + Mesh { vdata = Seq.tabulate (fn _ => dummyPt) numVertices + , triangleOfVertex = Seq.tabulate (fn _ => INVALID_ID) numVertices + , verticesOfTriangle = Seq.tabulate (fn _ => dummyTriple) numTriangles + , neighborsOfTriangle = Seq.tabulate (fn _ => dummyTriple) numTriangles + } + + fun copyData {src = Mesh src, dst = Mesh dst} = + let + val len = Seq.length + val lengthsOkay = + len (#vdata src) <= len (#vdata dst) andalso + len (#triangleOfVertex src) <= len (#triangleOfVertex dst) andalso + len (#verticesOfTriangle src) <= len (#verticesOfTriangle dst) andalso + len (#neighborsOfTriangle src) <= len (#neighborsOfTriangle dst) + val _ = + if lengthsOkay then () + else raise Fail "Topology2D.copyData: dst smaller than src" + in + ForkJoin.parfor 10000 (0, len (#vdata src)) (fn i => + upd (#vdata dst) i (nth (#vdata src) i)); + + ForkJoin.parfor 10000 (0, len (#triangleOfVertex src)) (fn i => + upd (#triangleOfVertex dst) i (nth (#triangleOfVertex src) i)); + + ForkJoin.parfor 10000 (0, len (#verticesOfTriangle src)) (fn i => + upd (#verticesOfTriangle dst) i (nth (#verticesOfTriangle src) i)); + + ForkJoin.parfor 10000 (0, len (#neighborsOfTriangle src)) (fn i => + upd (#neighborsOfTriangle dst) i (nth (#neighborsOfTriangle src) i)) + end + + fun tdata (Mesh mesh) t = + Tri { vertices = nth (#verticesOfTriangle mesh) t + , neighbors = nth (#neighborsOfTriangle 
mesh) t + } + + fun verticesOfTriangle (Mesh mesh) t = + nth (#verticesOfTriangle mesh) t + + fun neighborsOfTriangle (Mesh mesh) t = + nth (#neighborsOfTriangle mesh) t + + fun vdata (Mesh mesh) t = nth (#vdata mesh) t + + fun triangleOfVertex (Mesh mesh) v = nth (#triangleOfVertex mesh) v + + fun getPoints (Mesh {vdata, ...}) = vdata + + fun numVertices (Mesh {vdata, ...}) = Seq.length vdata + fun numTriangles (Mesh {verticesOfTriangle, ...}) = + Seq.length verticesOfTriangle + + fun copy mesh = + let + val n = numVertices mesh + val m = numTriangles mesh + val vdata = AS.full (ForkJoin.alloc n) + val triangleOfVertex = AS.full (ForkJoin.alloc n) + val verticesOfTriangle = AS.full (ForkJoin.alloc m) + val neighborsOfTriangle = AS.full (ForkJoin.alloc m) + + val mesh' = + Mesh { vdata = vdata + , triangleOfVertex = triangleOfVertex + , verticesOfTriangle = verticesOfTriangle + , neighborsOfTriangle = neighborsOfTriangle + } + in + copyData {src = mesh, dst = mesh'}; + mesh' + end + + + fun vertex (vertices as (a,b,c)) i = + case i of + 0 => a + | 1 => b + | _ => c + + + fun neighbor (neighbors as (a,b,c)) i = + let + val t' = + case i of + 0 => a + | 1 => b + | _ => c + in + if t' < 0 then NONE else SOME t' + end + + + fun locate (neighbors as (a,b,c)) (t: triangle) = + if a = t then SOME 0 + else if b = t then SOME 1 + else if c = t then SOME 2 + else NONE + + + fun hasEdge (vertices as (a,b,c)) (u,v) = + (u = a orelse u = b orelse u = c) + andalso + (v = a orelse v = b orelse v = c) + + + fun sortTriangleCCW mesh (Tri {vertices=(v1,v2,v3), neighbors=(t1,t2,t3)}) = + let + fun p v = vdata mesh v + val (v2, v3) = + if G.Point.counterClockwise (p v1, p v2, p v3) then + (v2, v3) + else + (v3, v2) + + fun checkHasEdge (u,v) t = + t <> INVALID_ID andalso hasEdge (verticesOfTriangle mesh t) (u,v) + + val (t1,t2,t3) = + if checkHasEdge (v1,v3) t1 then + (t1,t2,t3) + else if checkHasEdge (v1,v3) t2 then + (t2,t1,t3) + else + (t3,t1,t2) + + val (t2,t3) = + if checkHasEdge 
(v2,v1) t2 then + (t2,t3) + else + (t3,t2) + + in + Tri {vertices=(v1,v2,v3), neighbors=(t1,t2,t3)} + end + + + (** A simplex is an oriented triangle, which essentially just selects an + * edge of the triangle (the integer indicates which edge with the value + * 0, 1, or 2). The orientation allows us to define operations such as + * "across" which returns the simplex on the other side of the + * distinguished edge. + *) + type simplex = triangle * int + + + fun triangleOfSimplex ((t, _): simplex) = t + + + fun orientedTriangleData mesh (t, i) = + let + val Tri {vertices=(a,b,c), neighbors=(d,e,f)} = + tdata mesh t + in + case i of + 0 => Tri {vertices=(a,b,c), neighbors=(d,e,f)} + | 1 => Tri {vertices=(b,c,a), neighbors=(e,f,d)} + | _ => Tri {vertices=(c,a,b), neighbors=(f,d,e)} + end + + + fun across mesh ((t, i): simplex) : simplex option = + case neighbor (neighborsOfTriangle mesh t) i of + SOME t' => + (case locate (neighborsOfTriangle mesh t') t of + SOME i' => SOME (t', i') + | NONE => NONE) + | NONE => NONE + + + fun fastNeighbor (neighbors as (a,b,c)) i = + case i of + 0 => a + | 1 => b + | _ => c + + fun fastLocate (neighbors as (a,b,c)) (t: triangle) = + if a = t then 0 + else if b = t then 1 + else 2 + + fun fastAcross mesh ((t, i): simplex) = + let + val t' = fastNeighbor (neighborsOfTriangle mesh t) i + in + (t', fastLocate (neighborsOfTriangle mesh t') t) + end + + fun mod3 i = + if i > 2 then i-3 else i + + fun rotateClockwise ((t, i): simplex) : simplex = + (t, mod3 (i+1)) + + fun pointOutside mesh ((t, i): simplex) pt = + let + val vs = verticesOfTriangle mesh t + val p1 = vdata mesh (vertex vs (mod3 (i+2))) + val p2 = pt + val p3 = vdata mesh (vertex vs i) + in + G.Point.counterClockwise (p1, p2, p3) + end + + fun outside mesh (simp: simplex) v = + pointOutside mesh simp (vdata mesh v) + + fun pointInCircle mesh ((t, _): simplex) pt = + let + val (a,b,c) = verticesOfTriangle mesh t + val p1 = vdata mesh a + val p2 = vdata mesh b + val p3 = vdata 
mesh c + in + G.Point.inCircle (p1, p2, p3) pt + end + + fun inCircle mesh simp v = + pointInCircle mesh simp (vdata mesh v) + + fun firstVertex mesh ((t, i): simplex) = + vertex (verticesOfTriangle mesh t) i + + + (** ======================================================================== + * traversal and cavities + *) + + fun findPoint mesh pt current = + if pointOutside mesh current pt then + findPoint mesh pt (fastAcross mesh current) + else + let val current = rotateClockwise current in + if pointOutside mesh current pt then + findPoint mesh pt (fastAcross mesh current) + else + let val current = rotateClockwise current in + if pointOutside mesh current pt then + findPoint mesh pt (fastAcross mesh current) + else + current + end end + + + (** find: mesh -> vertex -> simplex -> simplex *) + fun find mesh v current = + findPoint mesh (vdata mesh v) current + + + type cavity = triangle * (simplex list) + + + fun loopPerimeter mesh center pt (b: 'a) (f: 'a * vertex -> 'a) = + let + fun loop b t = + if not (pointInCircle mesh t pt) then + b + else + let + val t = rotateClockwise t + val b = loopAcross b t + val b = f (b, firstVertex mesh t) + val t = rotateClockwise t + val b = loopAcross b t + in + b + end + + and loopAcross b t = + case across mesh t of + SOME t' => loop b t' + | NONE => b + + (* val center = findPoint mesh pt findStart *) + + val t = (center, 0) + + val b = f (b, firstVertex mesh t) + val b = loopAcross b t + + val t = rotateClockwise t + val b = f (b, firstVertex mesh t) + val b = loopAcross b t + + val t = rotateClockwise t + val b = f (b, firstVertex mesh t) + val b = loopAcross b t + in + b + end + + + fun findCavityAndPerimeter mesh findStart (pt: Geometry2D.point) = + let + fun loop (simps, verts) t = + if not (pointInCircle mesh t pt) then + (simps, verts) + else + let + val simps = t :: simps + val t = rotateClockwise t + val (simps, verts) = loopAcross (simps, verts) t + val verts = firstVertex mesh t :: verts + val t = 
rotateClockwise t + val (simps, verts) = loopAcross (simps, verts) t + in + (simps, verts) + end + + and loopAcross (simps, verts) t = + case across mesh t of + SOME t' => loop (simps, verts) t' + | NONE => (simps, verts) + + val center = findPoint mesh pt findStart + + val t = center + val (simps, verts) = ([], []) + + val verts = firstVertex mesh t :: verts + val (simps, verts) = loopAcross (simps, verts) t + + val t = rotateClockwise t + val verts = firstVertex mesh t :: verts + val (simps, verts) = loopAcross (simps, verts) t + + val t = rotateClockwise t + val verts = firstVertex mesh t :: verts + val (simps, verts) = loopAcross (simps, verts) t + + val cavity = (triangleOfSimplex center, List.rev simps) + in + (cavity, verts) + end + + + fun findCavity mesh center (pt: Geometry2D.point) = + let + fun loop simps t = + if not (pointInCircle mesh t pt) then + simps + else + let + val simps = t :: simps + val t = rotateClockwise t + val simps = loopAcross simps t + val t = rotateClockwise t + val simps = loopAcross simps t + in + simps + end + + and loopAcross simps t = + case across mesh t of + SOME t' => loop simps t' + | NONE => simps + + (* val center = findPoint mesh pt findStart *) + + val t = (center, 0) + val simps = [] + + val simps = loopAcross simps t + val t = rotateClockwise t + val simps = loopAcross simps t + val t = rotateClockwise t + val simps = loopAcross simps t + + val cavity = (center, List.rev simps) + in + cavity + end + + + (** ======================================================================== + * split triangle + *) + + exception FailedReplaceNeighbor + + fun replaceNeighbor (Mesh {neighborsOfTriangle, ...}) t (old, new) = + if t < 0 then () else + let + val (a,b,c) = nth neighborsOfTriangle t + val newNeighbors = + if old = a then (new, b, c) + else if old = b then (a, new, c) + else if old = c then (a, b, new) + else raise FailedReplaceNeighbor + in + upd neighborsOfTriangle t newNeighbors + end + + + fun updateTriangle + (Mesh 
{verticesOfTriangle, neighborsOfTriangle, ...}) t (Tri {vertices, neighbors}) + = + ( upd verticesOfTriangle t vertices + ; upd neighborsOfTriangle t neighbors + ) + + + (** split triangle t by putting a new vertex v at point p inside. This creates + * two new triangles, which will have ids t1 and t2. This function modifies + * the mesh by editing triangle data (t, ta0, ta1) and vertex data (v). + * + * BEFORE: AFTER: + * v1 v1 + * |\ |\\ + * | \ t1 | \ \ t1 + * | \ | \ \ + * | \ | \ t \ + * t2 | t v3 t2 |ta0 v --- v3 + * | / | / ta1 / + * | / | / / + * | / t3 | / / t3 + * |/ |// + * v2 v2 + *) + fun doSplit (mesh as Mesh {vdata, triangleOfVertex, ...}) t (v, p) (ta0, ta1) = + let + val Tri {vertices=(v1,v2,v3), neighbors=(t1,t2,t3)} = tdata mesh t + val newdata_t = + Tri {vertices=(v1,v,v3), neighbors=(t1,ta0,ta1)} + val newdata_ta0 = + Tri {vertices=(v2,v,v1), neighbors=(t2,ta1,t)} + val newdata_ta1 = + Tri {vertices=(v3,v,v2), neighbors=(t3,t,ta0)} + in + upd vdata v p; + upd triangleOfVertex v t; + if nth triangleOfVertex v2 <> t then () + else upd triangleOfVertex v2 ta0; + updateTriangle mesh t newdata_t; + updateTriangle mesh ta0 newdata_ta0; + updateTriangle mesh ta1 newdata_ta1; + replaceNeighbor mesh t2 (t,ta0); + replaceNeighbor mesh t3 (t,ta1) + end + + + fun split (Mesh {verticesOfTriangle, neighborsOfTriangle, vdata, triangleOfVertex}) (t: triangle) p = + let + val n = Seq.length vdata + val m = Seq.length neighborsOfTriangle + + (** allocate new with dummy values *) + val vdata' = Seq.append (vdata, Seq.singleton (nth vdata 0)) + val verticesOfTriangle' = Seq.append (verticesOfTriangle, Seq.fromList [dummyTriple, dummyTriple]) + val neighborsOfTriangle' = Seq.append (neighborsOfTriangle, Seq.fromList [dummyTriple, dummyTriple]) + val triangleOfVertex' = Seq.append (triangleOfVertex, Seq.singleton t) + val mesh' = + Mesh { vdata=vdata' + , verticesOfTriangle=verticesOfTriangle' + , neighborsOfTriangle=neighborsOfTriangle' + , 
triangleOfVertex=triangleOfVertex' + } + in + (* print ("splitting " ^ Int.toString t ^ " into: " ^ String.concatWith " " (List.map Int.toString [t,m,m+1]) ^ "\n"); *) + doSplit mesh' t (n, p) (m, m+1); + mesh' + end + + + (** ======================================================================== + * flip simplex + *) + + + (** Flip the shared edge identified by the simplex. + * + * BEFORE: AFTER: + * v3 v3 + * /|\ / \ + * t4 / | \ t3 t4 / \ t3 + * / | \ / t \ + * v4 t1 | t v2 v4 --------- v2 + * \ | / \ t1 / + * t5 \ | / t2 t5 \ / t2 + * \|/ \ / + * v1 v1 + *) + fun doFlip (mesh as Mesh {triangleOfVertex, ...}) (simp: simplex) = + let + val Tri {vertices=(v1,v2,v3), neighbors=(t1,t2,t3)} = + orientedTriangleData mesh simp + val Tri {vertices=(v3_,v4,v1_), neighbors=(t,t4,t5)} = + orientedTriangleData mesh (fastAcross mesh simp) + + (* val _ = + print ("flipping " ^ Int.toString t ^ " and " ^ Int.toString t1 ^ "\n") *) + + (** sanity check *) + (* val _ = + if v3 = v3_ andalso v1 = v1_ andalso t = triangleOfSimplex simp then () + else raise Fail "effed up flip" *) + + val newdata_t = + Tri {vertices=(v2,v3,v4), neighbors=(t1,t3,t4)} + + val newdata_t1 = + Tri {vertices=(v1,v2,v4), neighbors=(t5,t2,t)} + + fun replaceTriangleOfVertex v (old, new) = + if nth triangleOfVertex v <> old then () + else upd triangleOfVertex v new + in + updateTriangle mesh t newdata_t; + updateTriangle mesh t1 newdata_t1; + replaceTriangleOfVertex v1 (t, t1); + replaceTriangleOfVertex v3 (t1, t); + replaceNeighbor mesh t2 (t, t1); + replaceNeighbor mesh t4 (t1, t) + end + + fun flip (Mesh {verticesOfTriangle,neighborsOfTriangle,vdata,triangleOfVertex}) simp = + let + (** make a copy *) + val mesh' = + Mesh {verticesOfTriangle = Seq.map (fn x => x) verticesOfTriangle, + neighborsOfTriangle = Seq.map (fn x => x) neighborsOfTriangle, + vdata = Seq.map (fn x => x) vdata, + triangleOfVertex = Seq.map (fn x => x) triangleOfVertex} + in + doFlip mesh' simp; + mesh' + end + + (** 
======================================================================== + * loop to find cavity, and do rip-and-tent + *) + +(* + fun ripAndTentCavity mesh center (v, pt) (ta0, ta1) = + let + val (_, simps) = findCavity mesh center pt + in + doSplit mesh center (v, pt) (ta0, ta1); + List.app (doFlip mesh) simps + end +*) + + + fun ripAndTentCavity mesh center (v, pt) (ta0, ta1) = + let + fun loop t = + if not (pointInCircle mesh t pt) then + () + else + let + val t1 = across mesh (rotateClockwise t) + val t2 = across mesh (rotateClockwise (rotateClockwise t)) + in + doFlip mesh t; + maybeLoop t1; + maybeLoop t2 + end + + and maybeLoop t = + case t of + SOME t => loop t + | NONE => () + + val t = (center, 0) + val t1 = across mesh t + val t2 = across mesh (rotateClockwise t) + val t3 = across mesh (rotateClockwise (rotateClockwise t)) + in + doSplit mesh center (v, pt) (ta0, ta1); + maybeLoop t1; + maybeLoop t2; + maybeLoop t3 + end + + +(* + fun ripAndTentCavity mesh center (v, pt) (ta0, ta1) = + let + fun maybePush x xs = + case x of + SOME x => (if pointInCircle mesh x pt then x :: xs else xs) + | NONE => xs + + fun loop ts = + case ts of + [] => () + | t :: ts => + let + val t1 = across mesh (rotateClockwise t) + val t2 = across mesh (rotateClockwise (rotateClockwise t)) + in + doFlip mesh t; + loop (maybePush t1 (maybePush t2 ts)) + end + + val t = (center, 0) + val t1 = across mesh t + val t2 = across mesh (rotateClockwise t) + val t3 = across mesh (rotateClockwise (rotateClockwise t)) + in + doSplit mesh center (v, pt) (ta0, ta1); + loop (maybePush t1 (maybePush t2 (maybePush t3 []))) + end +*) + + (** ======================================================================== + * purely functional rip-and-tent on cavities (returns new mesh) + *) + + + fun ripAndTentOne ((t, simps): cavity, pt: G.point) mesh = + (* List.foldl (fn (s: simplex, m: mesh) => flip m s) (split mesh t pt) simps *) + let + val mesh' = split mesh t pt + in + List.app (doFlip mesh') 
simps; + mesh' + end + + + fun ripAndTent cavities (Mesh {verticesOfTriangle, neighborsOfTriangle, vdata, triangleOfVertex}) = + let + val numVerts = Seq.length vdata + val numTriangles = Seq.length verticesOfTriangle + val numNewVerts = Seq.length cavities + val numNewTriangles = 2 * numNewVerts + + val vdata' = Seq.tabulate (fn i => + if i < numVerts then nth vdata i else dummyPt) + (numVerts + numNewVerts) + val verticesOfTriangle' = Seq.tabulate (fn i => + if i < numTriangles then nth verticesOfTriangle i else dummyTriple) + (numTriangles + numNewTriangles) + val neighborsOfTriangle' = Seq.tabulate (fn i => + if i < numTriangles then nth neighborsOfTriangle i else dummyTriple) + (numTriangles + numNewTriangles) + val triangleOfVertex' = Seq.tabulate (fn i => + if i < numVerts then nth triangleOfVertex i else INVALID_ID) + (numVerts + numNewVerts) + val mesh' = + Mesh { vdata=vdata' + , verticesOfTriangle=verticesOfTriangle' + , neighborsOfTriangle=neighborsOfTriangle' + , triangleOfVertex=triangleOfVertex' + } + in + ForkJoin.parfor 100 (0, Seq.length cavities) (fn i => + let + val (cavity as (center, simps), pt) = nth cavities i + val ta0 = numTriangles + 2*i + val ta1 = ta0 + 1 + in + doSplit mesh' center (numVerts+i, pt) (ta0, ta1); + List.app (doFlip mesh') simps + end); + + mesh' + end + + + (** ======================================================================== + * generating an initial boundary + *) + + fun initialMeshWithBoundaryCircle + {numVertices, numBoundaryVertices} + {center, radius} + = + let + val pi = Math.pi + val n = Real.fromInt numBoundaryVertices + + val numNonBoundaryVertices = numVertices - numBoundaryVertices + val numNonBoundaryTriangles = 2 * numNonBoundaryVertices + val numBoundaryTriangles = numBoundaryVertices-2 + val numTriangles = numNonBoundaryTriangles + numBoundaryTriangles + + fun boundaryPoint i = + let + val ri = Real.fromInt i + val x = radius * Math.cos (2.0 * pi * (ri / n)) + val y = radius * Math.sin (2.0 * pi * 
(ri / n)) + val offset: G.point = (x,y) + in + G.Vector.add (center, offset) + end + + fun vertexPoint i = + if i < numNonBoundaryVertices then (0.0, 0.0) + else boundaryPoint (i - numNonBoundaryVertices) + + fun verticesOfTriangle i = + if i < numNonBoundaryTriangles then + (INVALID_ID, INVALID_ID, INVALID_ID) + else + let + val j = i - numNonBoundaryTriangles + numNonBoundaryVertices + in + (j+1, j+2, numNonBoundaryVertices) + end + + fun neighborsOfTriangle i = + if i < numNonBoundaryTriangles then + (INVALID_ID, INVALID_ID, INVALID_ID) + else + ( if i = numNonBoundaryTriangles then INVALID_ID else i-1 + , INVALID_ID + , if i = numTriangles-1 then INVALID_ID else i+1 + ) + + fun triangleOfVertex i = + if i < numNonBoundaryVertices then + INVALID_ID + else + case (i-numNonBoundaryVertices) of + 0 => numNonBoundaryTriangles + | 1 => numNonBoundaryTriangles + | j => numNonBoundaryTriangles + j - 2 + in + Mesh + { vdata = Seq.tabulate vertexPoint numVertices + , triangleOfVertex = Seq.tabulate triangleOfVertex numVertices + , verticesOfTriangle = Seq.tabulate verticesOfTriangle numTriangles + , neighborsOfTriangle = Seq.tabulate neighborsOfTriangle numTriangles + } + end + + (** ======================================================================== + * parsing from file and string representation + *) + + fun sortMeshCCW (mesh as Mesh {vdata, triangleOfVertex, ...}) = + let + val numTriangles = numTriangles mesh + val verticesOfTriangle = AS.full (ForkJoin.alloc numTriangles) + val neighborsOfTriangle = AS.full (ForkJoin.alloc numTriangles) + val mesh' = + Mesh { vdata = vdata + , verticesOfTriangle = verticesOfTriangle + , neighborsOfTriangle = neighborsOfTriangle + , triangleOfVertex = triangleOfVertex + } + in + ForkJoin.parfor 1000 (0, numTriangles) (fn i => + updateTriangle mesh' i (sortTriangleCCW mesh (tdata mesh i))); + + mesh + end + + fun writeMax a i x = + let + fun loop old = + if x <= old then () else + let + val old' = Concurrency.casArray (a, i) 
(old, x) + in + if old' = old then () + else loop old' + end + in + loop (Array.sub (a, i)) + end + + + fun parseFile filename = + let + val chars = ReadFile.contentsSeq filename + + fun isNewline i = (Seq.nth chars i = #"\n") + + val nlPos = + AS.full (SeqBasis.filter 10000 (0, Seq.length chars) (fn i => i) isNewline) + val numLines = Seq.length nlPos + 1 + fun lineStart i = + if i = 0 then 0 else 1 + Seq.nth nlPos (i-1) + fun lineEnd i = + if i = Seq.length nlPos then Seq.length chars else Seq.nth nlPos i + fun line i = Seq.subseq chars (lineStart i, lineEnd i - lineStart i) + + val _ = + if numLines >= 3 then () + else raise Fail ("Topology2D: read mesh: missing or incomplete header") + + val _ = + if Parse.parseString (line 0) = "Mesh" then () + else raise Fail ("expected Mesh header") + + fun tryParse parser test thing lineNum = + let + fun whoops msg = + raise Fail ("Topology2D: line " + ^ Int.toString (lineNum+1) + ^ ": error while parsing " ^ thing + ^ (case msg of NONE => "" | SOME msg => ": " ^ msg)) + in + case (parser (line lineNum) handle exn => whoops (SOME (exnMessage exn))) of + SOME x => if test x then x else whoops (SOME "test failed") + | NONE => whoops NONE + end + + fun tryParseInt thing lineNum = + tryParse Parse.parseInt (fn x => x >= 0) thing lineNum + fun tryParseReal thing lineNum = + tryParse Parse.parseReal (fn x => true) thing lineNum + + + val numVertices = tryParseInt "num vertices" 1 + val numTriangles = tryParseInt "num triangles" 2 + + fun validVid x = (0 <= x andalso x < numVertices) + fun validTid x = (x = INVALID_ID orelse (0 <= x andalso x< numTriangles)) + fun validTriangle (Tri {vertices=(a,b,c), neighbors=(d,e,f)}) = + validVid a andalso validVid b andalso validVid c + andalso + validTid d andalso validTid e andalso validTid f + + + fun ff range test = FindFirst.findFirstSerial range test + fun ss x (i, j) = Seq.subseq x (i, j-i) + + fun vertexParser line = + let + fun isSpace i = Char.isSpace (Seq.nth line i) + val spaceIdx 
= valOf (ff (0, Seq.length line) isSpace) + in + SOME + ( valOf (Parse.parseReal (Seq.take line spaceIdx)) + handle Option => raise Fail "bad first value" + , valOf (Parse.parseReal (Seq.drop line (spaceIdx+1))) + handle Option => raise Fail "bad second value" + ) + end + + fun neighborsParser restOfLine = + let + (* val _ = print ("parsing neighbors: " ^ Parse.parseString restOfLine ^ "\n") *) + fun isSpace i = Char.isSpace (Seq.nth restOfLine i) + val n = Seq.length restOfLine + val spPos = AS.full (SeqBasis.filter 10000 (0, n) (fn i => i) isSpace) + val numNbrs = Seq.length spPos + 1 + (* val _ = print ("num neighbors: " ^ Int.toString numNbrs ^ "\n") *) + fun nbrStart i = + if i = 0 then 0 else 1 + Seq.nth spPos (i-1) + fun nbrEnd i = + if i = Seq.length spPos then Seq.length restOfLine else Seq.nth spPos i + fun nbr i = + if i >= numNbrs then INVALID_ID else + ((*print ("nbr " ^ Int.toString i ^ " start " ^ Int.toString (nbrStart i) ^ " end " ^ Int.toString (nbrEnd i) ^ "\n"); + print ("nbrstring: \"" ^ Parse.parseString (ss restOfLine (nbrStart i, nbrEnd i)) ^ "\"\n");*) + valOf (Parse.parseInt (ss restOfLine (nbrStart i, nbrEnd i))) + handle Option => raise Fail ("bad neighbor")) + in + (nbr 0, nbr 1, nbr 2) + end + + fun triangleParser line = + let + fun isSpace i = Char.isSpace (Seq.nth line i) + val n = Seq.length line + val sp1 = valOf (ff (0, n) isSpace) + val sp2 = valOf (ff (sp1+1, n) isSpace) + val sp3 = Option.getOpt (ff (sp2+1, n) isSpace, n) + val verts = + ( valOf (Parse.parseInt (ss line (0, sp1))) + handle Option => raise Fail "bad first vertex" + , valOf (Parse.parseInt (ss line (sp1+1, sp2))) + handle Option => raise Fail "bad second vertex" + , valOf (Parse.parseInt (ss line (sp2+1, sp3))) + handle Option => raise Fail "bad third vertex" + ) + val nbrs = + if sp3 = n then (INVALID_ID,INVALID_ID,INVALID_ID) + else neighborsParser (ss line (sp3+1, n)) + in + SOME (Tri {vertices=verts, neighbors=nbrs}) + end + + fun tryParseVertex lineNum = + 
tryParse vertexParser (fn _ => true) "vertex" lineNum + + fun tryParseTriangle lineNum = + tryParse triangleParser validTriangle "triangle" lineNum + + val _ = + if numLines >= numVertices + numTriangles + 3 then () + else raise Fail ("Topology2D: not enough vertices and/or triangles to parse") + + val vertices = AS.full (SeqBasis.tabulate 1000 (0, numVertices) + (fn i => tryParseVertex (3+i))) + + val verticesOfTriangle = AS.full (ForkJoin.alloc numTriangles) + val neighborsOfTriangle = AS.full (ForkJoin.alloc numTriangles) + val triangleOfVertex = ForkJoin.alloc numVertices + val _ = ForkJoin.parfor 10000 (0, numVertices) + (fn i => Array.update (triangleOfVertex, i, ~1)) + + val _ = ForkJoin.parfor 1000 (0, numTriangles) (fn i => + let + val Tri {vertices=(a,b,c), neighbors} = + tryParseTriangle (3+numVertices+i) + in + upd verticesOfTriangle i (a,b,c); + upd neighborsOfTriangle i neighbors; + writeMax triangleOfVertex a i; + writeMax triangleOfVertex b i; + writeMax triangleOfVertex c i + end) + in + sortMeshCCW (Mesh + { vdata = vertices + , verticesOfTriangle = verticesOfTriangle + , neighborsOfTriangle = neighborsOfTriangle + , triangleOfVertex = AS.full triangleOfVertex + }) + end + + + fun toString (mesh as Mesh {vdata=vertices, ...}) = + let + val nv = numVertices mesh + val nt = numTriangles mesh + + fun ptos (x,y) = Real.toString x ^ " " ^ Real.toString y + + fun ttos (Tri {vertices=(a,b,c), neighbors=(d,e,f)}) = + String.concatWith " " (List.map Int.toString [a,b,c,d,e,f]) + in + String.concatWith "\n" + ([ "Mesh" + , Int.toString nv + , Int.toString nt + ] + @ + List.tabulate (nv, ptos o Seq.nth vertices) + @ + List.tabulate (nt, ttos o tdata mesh)) + end + +end diff --git a/tests/mpllib/TreeMatrix.sml b/tests/mpllib/TreeMatrix.sml new file mode 100644 index 000000000..e6a94f7dd --- /dev/null +++ b/tests/mpllib/TreeMatrix.sml @@ -0,0 +1,148 @@ +structure TreeMatrix: +sig + (* square matrices of sidelength 2^n matrices only! 
*) + datatype matrix = + Node of int * matrix * matrix * matrix * matrix + | Leaf of int * real Array.array + + val tabulate: int -> (int * int -> real) -> matrix + val flatten: matrix -> real array + val sidelength: matrix -> int + val multiply: matrix * matrix -> matrix +end = +struct + + val par = ForkJoin.par + + fun par4 (a, b, c, d) = + let + val ((ar, br), (cr, dr)) = par (fn _ => par (a, b), fn _ => par (c, d)) + in + (ar, br, cr, dr) + end + + datatype matrix = + Node of int * matrix * matrix * matrix * matrix + | Leaf of int * real Array.array + + exception MatrixFormat + + fun sidelength mat = + case mat of + Leaf (n, s) => n + | Node (n, _, _, _, _) => n + + fun tabulate sidelen f = + let + fun tab n (row, col) = + if n <= 64 then + Leaf (n, Array.tabulate (n * n, fn i => f (row + i div n, col + i mod n))) + else + let + val half = n div 2 + val (m11, m12, m21, m22) = + par4 ( fn _ => tab half (row, col) + , fn _ => tab half (row, col + half) + , fn _ => tab half (row + half, col) + , fn _ => tab half (row + half, col + half) + ) + in + Node (n, m11, m12, m21, m22) + end + in + tab sidelen (0, 0) + end + + val upd = Array.update + + fun writeFlatten (result, start, rowskip) m = + case m of + Leaf (n, s) => + let fun idx i = start + (i div n)*rowskip + (i mod n) + in Array.appi (fn (i, x) => upd (result, idx i, x)) s + end + | Node (n, m11, m12, m21, m22) => + ( par4 ( fn _ => writeFlatten (result, start, rowskip) m11 + , fn _ => writeFlatten (result, start + n div 2, rowskip) m12 + , fn _ => writeFlatten (result, start + (n div 2) * rowskip, rowskip) m21 + , fn _ => writeFlatten (result, start + (n div 2) * (rowskip + 1), rowskip) m22 + ) + ; () + ) + + fun flatten m = + let + val n = sidelength m + val result = ForkJoin.alloc (n * n) + in + writeFlatten (result, 0, n) m; + result + end + + fun flatmultiply n (s, t, output) = + let + val sub = Array.sub + val a = s + val b = t + val aStart = 0 + val bStart = 0 + (* assume our lengths are good *) + (* 
loop with accumulator to compute dot product. r is an index into + * vector a (the row index) and c is an index into b (the col index) *) + fun loop rowStop acc r c = + if r = rowStop then acc + else let val acc' = acc + (sub (a, r) * sub (b, c)) + val r' = r + 1 + val c' = c + n + in loop rowStop acc' r' c' + end + fun cell c = + let + val (i, j) = (c div n, c mod n) + val rowStart = aStart + i*n + val rowStop = rowStart + n + val colStart = bStart + j + in + loop rowStop 0.0 rowStart colStart + end + fun update i = + let + val newv = cell i + val old = sub (output, i) + in + Array.update (output, i, newv + old) + end + fun loopi i hi = + if i >= hi then () else (update i; loopi (i + 1) hi) + in + loopi 0 (n * n) + end + + fun multiply' (a, b, c) = + case (a, b, c) of + (Leaf (n, s), Leaf (_, t), Leaf (_, c)) => flatmultiply n (s, t, c) + | (Node (n, a11, a12, a21, a22), + Node (_, b11, b12, b21, b22), + Node (_, c11, c12, c21, c22)) => + let + fun block (m1, m2, m3, m4, c) = + (multiply' (m1, m2, c); multiply' (m3, m4, c)) + in + par4 ( fn _ => block (a11, b11, a12, b21, c11) + , fn _ => block (a11, b12, a12, b22, c12) + , fn _ => block (a21, b11, a22, b21, c21) + , fn _ => block (a21, b12, a22, b22, c22) + ); + () + end + | _ => raise MatrixFormat + + fun multiply (a, b) = + let + val c = tabulate (sidelength a) (fn _ => 0.0) + in + multiply' (a, b, c); + c + end + +end diff --git a/tests/mpllib/Util.sml b/tests/mpllib/Util.sml new file mode 100644 index 000000000..97b9e02f8 --- /dev/null +++ b/tests/mpllib/Util.sml @@ -0,0 +1,354 @@ +structure Util: +sig + val getTime: (unit -> 'a) -> ('a * Time.time) + val reportTime: (unit -> 'a) -> 'a + + val closeEnough: real * real -> bool + + val die: string -> 'a + + val repeat: int * (unit -> 'a) -> ('a) + + val hash64: Word64.word -> Word64.word + val hash64_2: Word64.word -> Word64.word + val hash32: Word32.word -> Word32.word + val hash32_2: Word32.word -> Word32.word + val hash32_3: Word32.word -> Word32.word + val 
hash: int -> int + + val ceilDiv: int -> int -> int + + val pow2: int -> int + + (* this actually computes 1 + floor(log_2(n)), i.e. the number of + * bits required to represent n in binary *) + val log2: int -> int + + (* boundPow2 n == smallest power of 2 that is less-or-equal-to n *) + val boundPow2: int -> int + + val foreach: 'a ArraySlice.slice -> (int * 'a -> unit) -> unit + + (* if the array is short, then convert it to a string. otherwise only + * show the first few elements and the last element *) + val summarizeArray: int -> ('a -> string) -> 'a array -> string + val summarizeArraySlice: int -> ('a -> string) -> 'a ArraySlice.slice -> string + + (* `for (lo, hi) f` do f(i) sequentially for each lo <= i < hi + * forBackwards goes from hi-1 down to lo *) + val for: (int * int) -> (int -> unit) -> unit + val forBackwards: (int * int) -> (int -> unit) -> unit + + (* `loop (lo, hi) b f` + * for lo <= i < hi, iteratively do b = f (b, i) *) + val loop: (int * int) -> 'a -> ('a * int -> 'a) -> 'a + + val all: (int * int) -> (int -> bool) -> bool + val exists: (int * int) -> (int -> bool) -> bool + + val copyListIntoArray: 'a list -> 'a array -> int -> int + + val revMap: ('a -> 'b) -> 'a list -> 'b list + + val intToString: int -> string + + val equalLists: ('a * 'a -> bool) -> 'a list * 'a list -> bool + +end = +struct + + fun ceilDiv n k = 1 + (n-1) div k + + fun digitToChar d = Char.chr (d + 48) + + fun intToString x = + let + (** For binary precision p, number of decimal digits needed is upper + * bounded by: + * 1 + log_{10}(2^p) = 1 + p * log_{10}(2) + * ~= 1 + p * 0.30103 + * < 1 + p * 0.33333 + * = 1 + p / 3 + * Just for a little extra sanity, we'll do ceiling-div. 
+ *) + val maxNumChars = 1 + ceilDiv (valOf Int.precision) 3 + val buf = ForkJoin.alloc maxNumChars + + val orig = x + + fun loop q i = + let + val i = i-1 + val d = ~(Int.rem (q, 10)) + val _ = Array.update (buf, i, digitToChar d) + val q = Int.quot (q, 10) + in + if q <> 0 then + loop q i + else if orig < 0 then + (Array.update (buf, i-1, #"~"); i-1) + else + i + end + + val start = loop (if orig < 0 then orig else ~orig) maxNumChars + in + CharVector.tabulate (maxNumChars-start, fn i => Array.sub (buf, start+i)) + end + + fun die msg = + ( TextIO.output (TextIO.stdErr, msg ^ "\n") + ; TextIO.flushOut TextIO.stdErr + ; OS.Process.exit OS.Process.failure + ) + + fun getTime f = + let + val t0 = Time.now () + val result = f () + val t1 = Time.now () + in + (result, Time.- (t1, t0)) + end + + fun reportTime f = + let + val (result, tm) = getTime f + in + print ("time " ^ Time.fmt 4 tm ^ "s\n"); + result + end + + fun closeEnough (x, y) = + Real.abs (x - y) <= 0.000001 + + (* NOTE: this actually computes 1 + floor(log_2(n)), i.e. 
the number of + * bits required to represent n in binary *) + fun log2 n = if (n < 1) then 0 else 1 + log2(n div 2) + + fun pow2 i = if (i<1) then 1 else 2*pow2(i-1) + + fun searchPow2 n m = if m >= n then m else searchPow2 n (2*m) + fun boundPow2 n = searchPow2 n 1 + + fun loop (lo, hi) b f = + if lo >= hi then b else loop (lo+1, hi) (f (b, lo)) f + + fun forBackwards (i, j) f = + if i >= j then () else (f (j-1); forBackwards (i, j-1) f) + + fun for (lo, hi) f = + if lo >= hi then () else (f lo; for (lo+1, hi) f) + + fun foreach s f = + ForkJoin.parfor 4096 (0, ArraySlice.length s) + (fn i => f (i, ArraySlice.sub (s, i))) + + fun all (lo, hi) f = + let + fun allFrom i = + (i >= hi) orelse (f i andalso allFrom (i+1)) + in + allFrom lo + end + + fun exists (lo, hi) f = + let + fun existsFrom i = + i < hi andalso (f i orelse existsFrom (i+1)) + in + existsFrom lo + end + + fun copyListIntoArray xs arr i = + case xs of + [] => i + | x :: xs => + ( Array.update (arr, i, x) + ; copyListIntoArray xs arr (i+1) + ) + + fun repeat (n, f) = + let + fun rep_help 1 = f() + | rep_help n = ((rep_help (n-1)); f()) + + val ns = if (n>0) then n else 1 + in + rep_help ns + end + + fun summarizeArraySlice count toString xs = + let + val n = ArraySlice.length xs + fun elem i = ArraySlice.sub (xs, i) + + val strs = + if count <= 0 then raise Fail "summarizeArray needs count > 0" + else if count <= 2 orelse n <= count then + List.tabulate (n, toString o elem) + else + List.tabulate (count-1, toString o elem) @ + ["...", toString (elem (n-1))] + in + "[" ^ (String.concatWith ", " strs) ^ "]" + end + + fun summarizeArray count toString xs = + summarizeArraySlice count toString (ArraySlice.full xs) + + fun revMap f xs = + let + fun loop acc xs = + case xs of + [] => acc + | x :: rest => loop (f x :: acc) rest + in + loop [] xs + end + + (* // from numerical recipes + * uint64_t hash64(uint64_t u) + * { + * uint64_t v = u * 3935559000370003845ul + 2691343689449507681ul; + * v ^= v >> 21; + 
* v ^= v << 37; + * v ^= v >> 4; + * v *= 4768777513237032717ul; + * v ^= v << 20; + * v ^= v >> 41; + * v ^= v << 5; + * return v; + * } + *) + + fun hash64 u = + let + open Word64 + infix 2 >> << xorb andb + val v = u * 0w3935559000370003845 + 0w2691343689449507681 + val v = v xorb (v >> 0w21) + val v = v xorb (v << 0w37) + val v = v xorb (v >> 0w4) + val v = v * 0w4768777513237032717 + val v = v xorb (v << 0w20) + val v = v xorb (v >> 0w41) + val v = v xorb (v << 0w5) + in + v + end + + (* uint32_t hash32(uint32_t a) { + * a = (a+0x7ed55d16) + (a<<12); + * a = (a^0xc761c23c) ^ (a>>19); + * a = (a+0x165667b1) + (a<<5); + * a = (a+0xd3a2646c) ^ (a<<9); + * a = (a+0xfd7046c5) + (a<<3); + * a = (a^0xb55a4f09) ^ (a>>16); + * return a; + * } + *) + + fun hash32 a = + let + open Word32 + infix 2 >> << xorb + val a = (a + 0wx7ed55d16) + (a << 0w12) + val a = (a xorb 0wxc761c23c) xorb (a >> 0w19) + val a = (a + 0wx165667b1) + (a << 0w5) + val a = (a + 0wxd3a2646c) xorb (a << 0w9) + val a = (a + 0wxfd7046c5) + (a << 0w3) + val a = (a xorb 0wxb55a4f09) xorb (a >> 0w16) + in + a + end + + (* uint32_t hash32_2(uint32_t a) { + * uint32_t z = (a + 0x6D2B79F5UL); + * z = (z ^ (z >> 15)) * (z | 1UL); + * z ^= z + (z ^ (z >> 7)) * (z | 61UL); + * return z ^ (z >> 14); + * } + *) + + fun hash32_2 a = + let + open Word32 + infix 2 >> << xorb orb + val z = (a + 0wx6D2B79F5) + val z = (z xorb (z >> 0w15)) * (z orb 0w1) + val z = z xorb (z + (z xorb (z >> 0w7)) * (z orb 0w61)) + in + z xorb (z >> 0w14) + end + + (* inline uint32_t hash32_3(uint32_t a) { + * uint32_t z = a + 0x9e3779b9; + * z ^= z >> 15; // 16 for murmur3 + * z *= 0x85ebca6b; + * z ^= z >> 13; + * z *= 0xc2b2ae3d; // 0xc2b2ae35 for murmur3 + * return z ^= z >> 16; + * } + *) + + fun hash32_3 a = + let + open Word32 + infix 2 >> << xorb orb + val z = a + 0wx9e3779b9 + val z = z xorb (z >> 0w15) (* 16 for murmur3 *) + val z = z * 0wx85ebca6b + val z = z xorb (z >> 0w13) + val z = z * 0wxc2b2ae3d (* 0wxc2b2ae35 for 
murmur3 *) + val z = z xorb (z >> 0w16) + in + z + end + + (* // a slightly cheaper, but possibly not as good version + * // based on splitmix64 + * inline uint64_t hash64_2(uint64_t x) { + * x = (x ^ (x >> 30)) * UINT64_C(0xbf58476d1ce4e5b9); + * x = (x ^ (x >> 27)) * UINT64_C(0x94d049bb133111eb); + * x = x ^ (x >> 31); + * return x; + * } + *) + + fun hash64_2 x = + let + open Word64 + infix 2 >> << xorb orb + val x = (x xorb (x >> 0w30)) * 0wxbf58476d1ce4e5b9 + val x = (x xorb (x >> 0w27)) * 0wx94d049bb133111eb + val x = x xorb (x >> 0w31) + in + x + end + + (* This chooses which hash function to use for generic integers, since + * integers are configurable at compile time. *) + val hash : int -> int = + case Int.precision of + NONE => (Word64.toInt o hash64 o Word64.fromInt) + | SOME 32 => (Word32.toIntX o hash32 o Word32.fromInt) + | SOME 64 => (Word64.toIntX o hash64 o Word64.fromInt) + | SOME p => (fn x => + let + val wp1 = Word.fromInt (p-1) + open Word64 + infix 2 >> << andb + val v = hash64 (fromInt x) + val v = v andb ((0w1 << wp1) - 0w1) + in + toInt v + end) + + + fun equalLists eq ([], []) = true + | equalLists eq (x :: xs, y :: ys) = + eq (x, y) andalso equalLists eq (xs, ys) + | equalLists _ _ = false + +end diff --git a/tests/mpllib/compat/PosixReadFile.sml b/tests/mpllib/compat/PosixReadFile.sml new file mode 100644 index 000000000..7f6e73199 --- /dev/null +++ b/tests/mpllib/compat/PosixReadFile.sml @@ -0,0 +1,60 @@ +structure PosixReadFile = +struct + + fun contentsSeq' (readByte: Word8.word -> 'a) path = + let + val (file, length) = + let + open Posix.FileSys + val file = openf (path, O_RDONLY, O.fromWord 0w0) + in + (file, Position.toInt (ST.size (fstat file))) + end + + open Posix.IO + + val bufferSize = 100000 + val buffer = Word8ArrayExtra.alloc bufferSize + val result = ArrayExtra.alloc length + (* val result = Word8ArrayExtra.alloc length *) + + (* fun copyToResult i n = + Word8ArraySlice.copy + { src = Word8ArraySlice.slice (buffer, 0, 
SOME n) + , dst = result + , di = i + } *) + + fun copyToResult i n = + Word8ArraySlice.appi (fn (j, b) => + Unsafe.Array.update (result, i+j, readByte b)) + (Word8ArraySlice.slice (buffer, 0, SOME n)) + + fun dumpFrom i = + if i >= length then () else + let + val bytesRead = readArr (file, Word8ArraySlice.full buffer) + in + copyToResult i bytesRead; + dumpFrom (i + bytesRead) + end + in + dumpFrom 0; + close file; + ArraySlice.full result + end + + fun contentsSeq path = + contentsSeq' (Char.chr o Word8.toInt) path + + fun contentsBinSeq path = + contentsSeq' (fn w => w) path + + fun contents filename = + let + val chars = contentsSeq filename + in + CharVector.tabulate (ArraySlice.length chars, + fn i => ArraySlice.sub (chars, i)) + end +end diff --git a/tests/mpllib/compat/PosixWriteFile.sml b/tests/mpllib/compat/PosixWriteFile.sml new file mode 100644 index 000000000..bceb2f1ad --- /dev/null +++ b/tests/mpllib/compat/PosixWriteFile.sml @@ -0,0 +1,13 @@ +structure PosixWriteFile = +struct + fun dump (filename: string, contents: string) = + let + open Posix.FileSys + val f = creat (filename, S.flags [S.iwusr, S.irusr, S.irgrp, S.iroth]) + val contentslice = (Word8VectorSlice.full (Byte.stringToBytes contents)) + in + (Posix.IO.writeVec (f, contentslice); + Posix.IO.close f; + ()) + end +end diff --git a/tests/mpllib/compat/mlton.mlb b/tests/mpllib/compat/mlton.mlb new file mode 100644 index 000000000..e0e996a1c --- /dev/null +++ b/tests/mpllib/compat/mlton.mlb @@ -0,0 +1,25 @@ +local + $(SML_LIB)/basis/basis.mlb + $(SML_LIB)/basis/mlton.mlb + $(SML_LIB)/basis/unsafe.mlb + local + $(SML_LIB)/basis/build/sources.mlb + in + structure ArrayExtra = Array + structure VectorExtra = Vector + structure Word8ArrayExtra = Word8Array + end + + PosixReadFile.sml + PosixWriteFile.sml + mlton.sml +in + structure ForkJoin + structure Concurrency + structure ReadFile + structure WriteFile + structure GCStats + structure MLton + structure VectorExtra + structure RuntimeStats +end 
diff --git a/tests/mpllib/compat/mlton.sml b/tests/mpllib/compat/mlton.sml new file mode 100644 index 000000000..b5a3b243c --- /dev/null +++ b/tests/mpllib/compat/mlton.sml @@ -0,0 +1,68 @@ +structure ForkJoin: +sig + val par: (unit -> 'a) * (unit -> 'b) -> 'a * 'b + val parfor: int -> int * int -> (int -> unit) -> unit + val alloc: int -> 'a array +end = +struct + fun par (f, g) = (f (), g ()) + fun parfor (g: int) (lo, hi) (f: int -> unit) = + if lo >= hi then () else (f lo; parfor g (lo + 1, hi) f) + fun alloc n = ArrayExtra.alloc n +end + +structure VectorExtra: +sig + val unsafeFromArray: 'a array -> 'a vector +end = +struct open VectorExtra end + +structure Concurrency = +struct + val numberOfProcessors = 1 + + fun cas r (x, y) = + let val current = !r + in if MLton.eq (x, current) then r := y else (); current + end + + fun casArray (a, i) (x, y) = + let val current = Array.sub (a, i) + in if MLton.eq (x, current) then Array.update (a, i, y) else (); current + end +end + +structure ReadFile = PosixReadFile + +structure WriteFile = PosixWriteFile + +structure GCStats: +sig + val report: unit -> unit +end = +struct + + fun p name thing = + print (name ^ ": " ^ thing () ^ "\n") + + fun report () = + let in print ("======== GC Stats ========\n"); print "none yet...\n" + end + +end + +structure RuntimeStats: +sig + type t + val get: unit -> t + val benchReport: {before: t, after: t} -> unit +end = +struct + type t = unit + fun get () = () + fun benchReport _ = + ( print ("======== Runtime Stats ========\n") + ; print ("none yet...\n") + ; print ("====== End Runtime Stats ======\n") + ) +end diff --git a/tests/mpllib/compat/mpl-old.mlb b/tests/mpllib/compat/mpl-old.mlb new file mode 100644 index 000000000..acf9bcf69 --- /dev/null +++ b/tests/mpllib/compat/mpl-old.mlb @@ -0,0 +1,25 @@ +local + $(SML_LIB)/basis/basis.mlb + $(SML_LIB)/basis/mlton.mlb + $(SML_LIB)/basis/unsafe.mlb + $(SML_LIB)/basis/fork-join.mlb + + local + $(SML_LIB)/basis/build/sources.mlb + in + 
structure ArrayExtra = Array + structure VectorExtra = Vector + structure Word8ArrayExtra = Word8Array + end + + PosixReadFile.sml + mpl-old.sml +in + structure ForkJoin + structure Concurrency + structure ReadFile + structure GCStats + structure MLton + structure VectorExtra + structure RuntimeStats +end diff --git a/tests/mpllib/compat/mpl-old.sml b/tests/mpllib/compat/mpl-old.sml new file mode 100644 index 000000000..1b5124c9c --- /dev/null +++ b/tests/mpllib/compat/mpl-old.sml @@ -0,0 +1,49 @@ +(* already provided by the compiler *) +structure ForkJoin = ForkJoin + +structure Concurrency = +struct + val numberOfProcessors = MLton.Parallel.numberOfProcessors + val cas = MLton.Parallel.compareAndSwap + val casArray = MLton.Parallel.arrayCompareAndSwap +end + +structure VectorExtra: +sig + val unsafeFromArray: 'a array -> 'a vector +end = +struct open VectorExtra end + +structure ReadFile = PosixReadFile + +structure GCStats: +sig + val report: unit -> unit +end = +struct + + fun p name thing = + print (name ^ ": " ^ thing () ^ "\n") + + fun report () = + let in print ("======== GC Stats ========\n"); print "none yet...\n" + end + +end + + +structure RuntimeStats: +sig + type t + val get: unit -> t + val benchReport: {before: t, after: t} -> unit +end = +struct + type t = unit + fun get () = () + fun benchReport _ = + ( print ("======== Runtime Stats ========\n") + ; print ("none yet...\n") + ; print ("====== End Runtime Stats ======\n") + ) +end diff --git a/tests/mpllib/compat/mpl.mlb b/tests/mpllib/compat/mpl.mlb new file mode 100644 index 000000000..9e56c0ad6 --- /dev/null +++ b/tests/mpllib/compat/mpl.mlb @@ -0,0 +1,23 @@ +local + $(SML_LIB)/basis/basis.mlb + $(SML_LIB)/basis/mlton.mlb + $(SML_LIB)/basis/mpl.mlb + $(SML_LIB)/basis/fork-join.mlb + + local + $(SML_LIB)/basis/build/sources.mlb + in + structure VectorExtra = Vector + end + PosixWriteFile.sml + mpl.sml +in + structure ForkJoin + structure Concurrency + structure ReadFile + structure WriteFile + 
structure GCStats
+  structure RuntimeStats
+  structure MLton
+  structure VectorExtra
+end
diff --git a/tests/mpllib/compat/mpl.sml b/tests/mpllib/compat/mpl.sml
new file mode 100644
index 000000000..6d3183c59
--- /dev/null
+++ b/tests/mpllib/compat/mpl.sml
@@ -0,0 +1,236 @@
+(* already provided by the compiler *)
+structure ForkJoin = ForkJoin
+
+structure Concurrency =
+struct
+  val numberOfProcessors = MLton.Parallel.numberOfProcessors
+  val cas = MLton.Parallel.compareAndSwap  (* atomic CAS from the MPL runtime *)
+  val casArray = MLton.Parallel.arrayCompareAndSwap
+end
+
+structure VectorExtra:
+sig
+  val unsafeFromArray: 'a array -> 'a vector  (* reinterpret without copying *)
+end =
+struct open VectorExtra end
+
+structure ReadFile =
+struct
+
+  (* Read a whole file into a slice by issuing chunked reads in parallel;
+     'reader' fills a destination slice from a given file offset. *)
+  fun contentsSeq' reader filename =
+    let
+      val file = MPL.File.openFile filename
+      val n = MPL.File.size file
+      val arr = ForkJoin.alloc n
+      val k = 10000  (* chunk size, in elements *)
+      val m = 1 + (n - 1) div k  (* number of chunks: ceiling of n/k *)
+    in
+      ForkJoin.parfor 1 (0, m) (fn i =>
+        let
+          val lo = i * k
+          val hi = Int.min ((i + 1) * k, n)  (* last chunk may be short *)
+        in
+          reader file lo (ArraySlice.slice (arr, lo, SOME (hi - lo)))
+        end);
+      MPL.File.closeFile file;
+      ArraySlice.full arr
+    end
+
+  fun contentsSeq filename = contentsSeq' MPL.File.readChars filename
+
+  fun contentsBinSeq filename = contentsSeq' MPL.File.readWord8s filename
+
+  (* Same as contentsSeq, but copied into an immutable string. *)
+  fun contents filename =
+    let
+      val chars = contentsSeq filename
+    in
+      CharVector.tabulate (ArraySlice.length chars, fn i =>
+        ArraySlice.sub (chars, i))
+    end
+
+end
+
+structure WriteFile = PosixWriteFile
+
+structure GCStats:
+sig
+  val report: unit -> unit
+end =
+struct
+
+  (* Print "name: value" where value is computed lazily by 'thing'. *)
+  fun p name thing =
+    print (name ^ ": " ^ thing () ^ "\n")
+
+  (* Dump cumulative MPL GC counters to stdout. *)
+  fun report () =
+    let in
+      print ("======== GC Stats ========\n");
+      p "local reclaimed" (LargeInt.toString o MPL.GC.localBytesReclaimed);
+      p "num local" (LargeInt.toString o MPL.GC.numLocalGCs);
+      p "local gc time"
+        (LargeInt.toString o Time.toMilliseconds o MPL.GC.localGCTime);
+      p "promo time"
+        (LargeInt.toString o Time.toMilliseconds o MPL.GC.promoTime);
+      p "internal reclaimed"
+        (LargeInt.toString o MPL.GC.internalBytesReclaimed)
+    end
+
+end
+
+
+structure RuntimeStats:
+sig
+  type t
+  val get: unit -> t
+  val benchReport: {before: t, after: t} -> unit
+end =
+struct
+
+  (* Snapshot of cumulative runtime/GC/scheduler counters. *)
+  type stats =
+    { lgcCount: int
+    , lgcBytesReclaimed: int
+    , lgcBytesInScope: int
+    , lgcTracingTime: Time.time
+    , lgcPromoTime: Time.time
+    , cgcCount: int
+    , cgcBytesReclaimed: int
+    , cgcBytesInScope: int
+    , cgcTime: Time.time
+    , schedWorkTime: Time.time
+    , schedIdleTime: Time.time
+    , susMarks: int
+    , deChecks: int
+    , entanglements: int
+    , bytesPinnedEntangled: int
+    , bytesPinnedEntangledWatermark: int
+    , numSpawns: int
+    , numEagerSpawns: int
+    , numHeartbeats: int
+    , numSkippedHeartbeats: int
+    , numSpawns: int  (* see record above; field order follows MPL API *)
+    } (* NOTE(review): transcription of field list — verify against MPL.GC/ForkJoin API *)
+
+  datatype t = Stats of stats
+
+  (* Read all counters now; values are cumulative since program start. *)
+  fun get () =
+    Stats
+      { lgcCount = LargeInt.toInt (MPL.GC.numLocalGCs ())
+      , lgcBytesReclaimed = LargeInt.toInt (MPL.GC.localBytesReclaimed ())
+      , lgcBytesInScope = LargeInt.toInt (MPL.GC.bytesInScopeForLocal ())
+      , lgcTracingTime = MPL.GC.localGCTime ()
+      , lgcPromoTime = MPL.GC.promoTime ()
+      , cgcCount = LargeInt.toInt (MPL.GC.numCCs ())
+      , cgcBytesReclaimed = LargeInt.toInt (MPL.GC.ccBytesReclaimed ())
+      , cgcBytesInScope = LargeInt.toInt (MPL.GC.bytesInScopeForCC ())
+      , cgcTime = MPL.GC.ccTime ()
+      , schedWorkTime = ForkJoin.workTimeSoFar ()
+      , schedIdleTime = ForkJoin.idleTimeSoFar ()
+      , susMarks = LargeInt.toInt (MPL.GC.numberSuspectsMarked ())
+      , deChecks = LargeInt.toInt (MPL.GC.numberDisentanglementChecks ())
+      , entanglements = LargeInt.toInt (MPL.GC.numberEntanglements ())
+      , bytesPinnedEntangled = LargeInt.toInt (MPL.GC.bytesPinnedEntangled ())
+      , bytesPinnedEntangledWatermark = LargeInt.toInt
+          (MPL.GC.bytesPinnedEntangledWatermark ())
+      , numSpawns = ForkJoin.numSpawnsSoFar ()
+      , numEagerSpawns = ForkJoin.numEagerSpawnsSoFar ()
+      , numHeartbeats = ForkJoin.numHeartbeatsSoFar ()
+      , numSkippedHeartbeats = ForkJoin.numSkippedHeartbeatsSoFar ()
+      , numSteals = ForkJoin.numStealsSoFar ()
+      , maxHeartbeatStackSize = IntInf.toInt
+          (MPL.GC.maxStackSizeForHeartbeat ())
+      , maxHeartbeatStackWalk = IntInf.toInt
+          (MPL.GC.maxStackFramesWalkedForHeartbeat ())
+      }
+
+  (* Integer percentage a/b; 0 when b = 0 (Div/Overflow caught). *)
+  fun pct a b =
+    Real.round (100.0 * (Real.fromInt a / Real.fromInt b))
+    handle _ => 0
+
+  val itos = Int.toString
+  val rtos = Real.fmt (StringCvt.FIX (SOME 2))
+
+  (* Print the deltas (after - before) of all counters, plus a few
+     absolute values (watermark, max heartbeat stack walk/size). *)
+  fun benchReport {before = Stats b, after = Stats a} =
+    let
+      val numSpawns = #numSpawns a - #numSpawns b
+      val numEagerSpawns = #numEagerSpawns a - #numEagerSpawns b
+      val numHeartbeatSpawns = numSpawns - numEagerSpawns
+      val numHeartbeats = #numHeartbeats a - #numHeartbeats b
+      val numSkippedHeartbeats =
+        #numSkippedHeartbeats a - #numSkippedHeartbeats b
+      val numSteals = #numSteals a - #numSteals b
+
+      val eagerp = pct numEagerSpawns numSpawns
+      val hbp = pct numHeartbeatSpawns numSpawns
+      val skipp = pct numSkippedHeartbeats numHeartbeats
+
+      val spawnsPerHb = Real.fromInt numSpawns / Real.fromInt numHeartbeats
+                        handle _ => 0.0
+
+      val eagerSpawnsPerHb =
+        Real.fromInt numEagerSpawns / Real.fromInt numHeartbeats
+        handle _ => 0.0
+
+      val hbSpawnsPerHb =
+        Real.fromInt numHeartbeatSpawns / Real.fromInt numHeartbeats
+        handle _ => 0.0
+
+      (* Print one "name value" line where value = differ (after, before);
+         pass op- / Time.- for deltas, or #1 to show the 'after' value. *)
+      fun p name (selector: stats -> 'a) (differ: 'a * 'a -> 'a)
+            (stringer: 'a -> string) : unit =
+        print (name ^ " " ^ stringer (differ (selector a, selector b)) ^ "\n")
+    in
+      print ("======== Runtime Stats ========\n");
+      print ("num spawns " ^ itos numSpawns ^ "\n");
+      print
+        ("  eager " ^ itos numEagerSpawns ^ " (" ^ itos eagerp
+         ^ "%)\n");
+      print
+        ("  at heartbeat " ^ itos numHeartbeatSpawns ^ " (" ^ itos hbp
+         ^ "%)\n");
+      print "\n";
+      print ("num heartbeats " ^ itos numHeartbeats ^ "\n");
+      print
+        ("  skipped " ^ itos numSkippedHeartbeats ^ " (" ^ itos skipp
+         ^ "%)\n");
+      print "\n";
+      print ("spawns / hb " ^ rtos spawnsPerHb ^ "\n");
+      print ("  eager " ^ rtos eagerSpawnsPerHb ^ "\n");
+      print ("  at heartbeat " ^ rtos hbSpawnsPerHb ^ "\n");
+      print "\n";
+      print ("num steals " ^ itos numSteals ^ "\n");
+      print "\n";
+      print ("max hb stack walk " ^ itos (#maxHeartbeatStackWalk a) ^ "\n");  (* absolute, not a delta *)
+      print ("max hb stack size " ^ itos (#maxHeartbeatStackSize a) ^ "\n");  (* absolute, not a delta *)
+      print "\n";
+
+      p "sus marks" #susMarks op- Int.toString;
+      p "de checks" #deChecks op- Int.toString;
+      p "entanglements" #entanglements op- Int.toString;
+      p "bytes pinned entangled" #bytesPinnedEntangled op- Int.toString;
+      p "bytes pinned entangled watermark" #bytesPinnedEntangledWatermark #1
+        Int.toString;  (* #1 = report the 'after' value, not a delta *)
+      print "\n";
+      p "lgc count" #lgcCount op- Int.toString;
+      p "lgc bytes reclaimed" #lgcBytesReclaimed op- Int.toString;
+      p "lgc bytes in scope " #lgcBytesInScope op- Int.toString;
+      p "lgc trace time(ms) " #lgcTracingTime Time.-
+        (LargeInt.toString o Time.toMilliseconds);
+      p "lgc promo time(ms) " #lgcPromoTime Time.-
+        (LargeInt.toString o Time.toMilliseconds);
+      p "lgc total time(ms) "
+        (fn x => Time.+ (#lgcTracingTime x, #lgcPromoTime x)) Time.-
+        (LargeInt.toString o Time.toMilliseconds);
+      print "\n";
+      p "cgc count" #cgcCount op- Int.toString;
+      p "cgc bytes reclaimed" #cgcBytesReclaimed op- Int.toString;
+      p "cgc bytes in scope " #cgcBytesInScope op- Int.toString;
+      p "cgc time(ms)" #cgcTime Time.- (LargeInt.toString o Time.toMilliseconds);
+      print "\n";
+      p "work time(ms)" #schedWorkTime Time.-
+        (LargeInt.toString o Time.toMilliseconds);
+      p "idle time(ms)" #schedIdleTime Time.-
+        (LargeInt.toString o Time.toMilliseconds);
+      print ("====== End Runtime Stats ======\n");
+      ()
+    end
+
+end
diff --git a/tests/mpllib/sources.mlton.mlb b/tests/mpllib/sources.mlton.mlb
new file mode 100644
index 000000000..67ebb64e1
--- /dev/null
+++ b/tests/mpllib/sources.mlton.mlb
@@ -0,0 +1,85 @@
+$(SML_LIB)/basis/basis.mlb
+compat/mlton.mlb
+
+CommandLineArgs.sml
+Util.sml
+SeqBasis.sml
+
+SEQUENCE.sml
+local
+  ArraySequence.sml
+in
+  structure Seq = ArraySequence
+  structure ArraySequence
+end
+
+Seqifier.sml +TFlatten.sml + +OffsetSearch.sml +STREAM.sml +DelayedStream.sml +RecursiveStream.sml +DelayedSeq.sml +OldDelayedSeq.sml + +FuncSequence.sml +BinarySearch.sml +Merge.sml +SeqifiedMerge.sml +FlattenMerge.sml +StableMerge.sml +DoubleBinarySearch.sml +StableMergeLowSpan.sml +StableSort.sml +Quicksort.sml +Mergesort.sml +TreeMatrix.sml +AugMap.sml + +PureSeq.sml + +SampleSort.sml +CountingSort.sml +RadixSort.sml +Shuffle.sml + +(* ReadFile.sml *) +Tokenize.sml +FindFirst.sml +Parse.sml +ParseFile.sml +(* TabFilterTree.sml *) +AdjacencyGraph.sml +AdjacencyInt.sml + +MatCOO.sml + +CheckSort.sml + +MkComplex.sml +Rat.sml +Geometry3D.sml +Geometry2D.sml +Topology2D.sml + +NearestNeighbors.sml + +Color.sml +PPM.sml +ExtraBinIO.sml +GIF.sml +NewWaveIO.sml +Signal.sml + +MeshToImage.sml + +MkGrep.sml + +Hashset.sml +Hashtable.sml + +ParFuncArray.sml +ChunkedTreap.sml + +Benchmark.sml diff --git a/tests/mpllib/sources.mpl.mlb b/tests/mpllib/sources.mpl.mlb new file mode 100644 index 000000000..a1da67fc5 --- /dev/null +++ b/tests/mpllib/sources.mpl.mlb @@ -0,0 +1,85 @@ +$(SML_LIB)/basis/basis.mlb +compat/mpl.mlb + +CommandLineArgs.sml +Util.sml +SeqBasis.sml + +SEQUENCE.sml +local + ArraySequence.sml +in + structure Seq = ArraySequence + structure ArraySequence +end + +Seqifier.sml +TFlatten.sml + +OffsetSearch.sml +STREAM.sml +DelayedStream.sml +RecursiveStream.sml +DelayedSeq.sml +OldDelayedSeq.sml + +FuncSequence.sml +BinarySearch.sml +Merge.sml +SeqifiedMerge.sml +FlattenMerge.sml +StableMerge.sml +DoubleBinarySearch.sml +StableMergeLowSpan.sml +StableSort.sml +Quicksort.sml +Mergesort.sml +TreeMatrix.sml +AugMap.sml + +PureSeq.sml + +SampleSort.sml +CountingSort.sml +RadixSort.sml +Shuffle.sml + +(* ReadFile.sml *) +Tokenize.sml +FindFirst.sml +Parse.sml +ParseFile.sml +(* TabFilterTree.sml *) +AdjacencyGraph.sml +AdjacencyInt.sml + +MatCOO.sml + +CheckSort.sml + +MkComplex.sml +Rat.sml +Geometry3D.sml +Geometry2D.sml +Topology2D.sml + +NearestNeighbors.sml + 
+Color.sml +PPM.sml +ExtraBinIO.sml +GIF.sml +NewWaveIO.sml +Signal.sml + +MeshToImage.sml + +MkGrep.sml + +Hashset.sml +Hashtable.sml + +ParFuncArray.sml +ChunkedTreap.sml + +Benchmark.sml