From 897659609b403e294a0febc9d92f543626f6f85f Mon Sep 17 00:00:00 2001 From: Mike Ledger Date: Mon, 1 May 2017 20:37:19 +1000 Subject: [PATCH] faster traversable sorting with the help of vector-algorithms --- benchmarks/Main.hs | 23 +++++++++----- sort-traversable.cabal | 4 +++ src/Data/Traversable/Sort/Vector.hs | 49 +++++++++++++++++++++++++++++ 3 files changed, 69 insertions(+), 7 deletions(-) create mode 100644 src/Data/Traversable/Sort/Vector.hs diff --git a/benchmarks/Main.hs b/benchmarks/Main.hs index 2186c58..82a14d7 100644 --- a/benchmarks/Main.hs +++ b/benchmarks/Main.hs @@ -1,5 +1,6 @@ module Main where import Data.Traversable.Sort.PairingHeap (sortTraversable) +import qualified Data.Traversable.Sort.Vector as V import Criterion.Main import Data.List (sort) import qualified Data.Sequence as Seq @@ -24,46 +25,54 @@ main = do defaultMain [ bgroup "1000" [ bgroup "list" - [ bench "Data.List" $ nf sort thousand - , bench "HSTrav" $ nf sortTraversable thousand + [ bench "Data.List" $ nf sort thousand + , bench "HSTrav" $ nf sortTraversable thousand + , bench "HSTrav vector" $ nf V.sortTraversable thousand ] , bgroup "sequence" [ bench "sort" $ nf Seq.sort thousand' , bench "unstableSort" $ nf Seq.unstableSort thousand' , bench "HSTrav" $ nf sortTraversable thousand' + , bench "HSTrav vector" $ nf V.sortTraversable thousand' ] ] , bgroup "10000" [ bgroup "list" [ bench "Data.List" $ nf sort tenthousand , bench "HSTrav" $ nf sortTraversable tenthousand + , bench "HSTrav vector" $ nf V.sortTraversable tenthousand ] , bgroup "sequence" [ bench "sort" $ nf Seq.sort tenthousand' , bench "unstableSort" $ nf Seq.unstableSort tenthousand' , bench "HSTrav" $ nf sortTraversable tenthousand' + , bench "HSTrav vector" $ nf V.sortTraversable tenthousand' ] ] , bgroup "100000" [ bgroup "list" [ bench "Data.List" $ nf sort hundredthousand , bench "HSTrav" $ nf sortTraversable hundredthousand + , bench "HSTrav vector" $ nf V.sortTraversable hundredthousand ] , bgroup "sequence" [ bench "sort" $ nf Seq.sort hundredthousand' , bench "unstableSort" $ nf Seq.unstableSort hundredthousand' , bench "HSTrav" $ nf sortTraversable hundredthousand' + , bench "HSTrav vector" $ nf V.sortTraversable hundredthousand' ] ] , bgroup "1000000" [ bgroup "list" - [ bench "Data.List" $ nf sort million - , bench "HSTrav" $ nf sortTraversable million + [ bench "Data.List" $ nf sort million + , bench "HSTrav" $ nf sortTraversable million + , bench "HSTrav vector" $ nf V.sortTraversable million ] , bgroup "sequence" - [ bench "sort" $ nf Seq.sort million' - , bench "unstableSort" $ nf Seq.unstableSort million' - , bench "HSTrav" $ nf sortTraversable million' + [ bench "sort" $ nf Seq.sort million' + , bench "unstableSort" $ nf Seq.unstableSort million' + , bench "HSTrav" $ nf sortTraversable million' + , bench "HSTrav vector" $ nf V.sortTraversable million' ] ] ] diff --git a/sort-traversable.cabal b/sort-traversable.cabal index df4bdda..9213dc9 100644 --- a/sort-traversable.cabal +++ b/sort-traversable.cabal @@ -47,6 +47,7 @@ library Data.Traversable.Sort.PairingHeap Data.Traversable.Sort.PairingHeap.BasicNat Data.Traversable.Sort.PairingHeap.IndexedPairingHeap + Data.Traversable.Sort.Vector -- Modules included in this library but not exported. -- other-modules: @@ -65,6 +66,9 @@ library -- Other library packages from which modules are imported. build-depends: base >=4.8 && <4.10 + , vector + , vector-algorithms + , mtl -- Directories containing source files. hs-source-dirs: src, benchmarks diff --git a/src/Data/Traversable/Sort/Vector.hs b/src/Data/Traversable/Sort/Vector.hs new file mode 100644 index 0000000..ba55c19 --- /dev/null +++ b/src/Data/Traversable/Sort/Vector.hs @@ -0,0 +1,49 @@ +{-# LANGUAGE Rank2Types #-} +{-# LANGUAGE ScopedTypeVariables #-} +module Data.Traversable.Sort.Vector (sortTraversableBy, sortTraversable) where +import Control.Monad.ST.Strict +import Control.Monad.State.Strict +import Data.Foldable +import qualified Data.Vector.Algorithms.Intro as Intro +import qualified Data.Vector.Mutable as VM + +{-# INLINE sortTraversableBy #-} +sortTraversableBy :: (Ord a, Traversable f) + => (forall s. VM.STVector s a -> ST s ()) + -> f a + -> f a +sortTraversableBy sort val = runST (do + vec <- indexed val + sort vec + evalStateT (traverse + (\_ -> StateT + (\i -> do + r <- VM.unsafeRead vec i + return (r, i + 1))) + val) + (0 :: Int)) + +{-# INLINE sortTraversable #-} +-- | Sort a traversable container using introsort from vector-algorithms. +sortTraversable :: (Ord a, Traversable f) => f a -> f a +sortTraversable = sortTraversableBy Intro.sort + +data P s a = P + {-# UNPACK #-} !Int + !(VM.STVector s a -> ST s ()) + +{-# INLINE indexed #-} +indexed :: forall f a s. (Ord a, Foldable f) => f a -> ST s (VM.STVector s a) +indexed x = do + case foldl' + (\(P i f) el -> P + (i + 1) + (\v -> f v >> VM.unsafeWrite v i el)) + (P 0 (\_ -> return ()) :: P s a) + x of + P len initFn -> do + vec <- VM.unsafeNew len + initFn vec + return vec + +