diff --git a/src/Data/List/Linear.hs b/src/Data/List/Linear.hs index e67f9993..ef920406 100644 --- a/src/Data/List/Linear.hs +++ b/src/Data/List/Linear.hs @@ -1,4 +1,4 @@ -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE LinearTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE NoImplicitPrelude #-} @@ -354,18 +354,34 @@ zip3 :: (Consumable a, Consumable b, Consumable c) => [a] %1 -> [b] %1 -> [c] %1 zip3 = zipWith3 (,,) zipWith :: (Consumable a, Consumable b) => (a %1 -> b %1 -> c) -> [a] %1 -> [b] %1 -> [c] -zipWith f xs ys = - zipWith' f xs ys & \(ret, leftovers) -> - leftovers `lseq` ret +zipWith f = + zipWithk f (:) [] consume2 consume2 + where + consume2 :: forall x y z. (Consumable x, Consumable y) => x %1 -> y %1 -> [z] + consume2 x y = x `lseq` y `lseq` [] -- | Same as 'zipWith', but returns the leftovers instead of consuming them. +-- Because the leftovers are returned at toplevel, @zipWith'@ is pretty strict: +-- forcing the first cons cell of the returned list forces all the recursive +-- calls. zipWith' :: (a %1 -> b %1 -> c) -> [a] %1 -> [b] %1 -> ([c], Maybe (Either (NonEmpty a) (NonEmpty b))) -zipWith' _ [] [] = ([], Nothing) -zipWith' _ (a : as) [] = ([], Just (Left (a :| as))) -zipWith' _ [] (b : bs) = ([], Just (Right (b :| bs))) -zipWith' f (a : as) (b : bs) = - case zipWith' f as bs of - (cs, rest) -> (f a b : cs, rest) +zipWith' f = + zipWithk + f + (\c !(cs, rest) -> ((c : cs), rest)) + ([], Nothing) + (\a as -> ([], Just (Left (a :| as)))) + (\b bs -> ([], Just (Right (b :| bs)))) + +zipWithk :: forall r a b c. (a %1 -> b %1 -> c) -> (c %1 -> r %1 -> r) -> r -> (a %1 -> [a] %1 -> r) -> (b %1 -> [b] %1 -> r) -> [a] %1 -> [b] %1 -> r +zipWithk f cons nil lefta leftb = + go + where + go :: [a] %1 -> [b] %1 -> r + go [] [] = nil + go (a : as) [] = lefta a as + go [] (b : bs) = leftb b bs + go (a : as) (b : bs) = cons (f a b) (go as bs) zipWith3 :: forall a b c d. (Consumable a, Consumable b, Consumable c) => (a %1 -> b %1 -> c %1 -> d) -> [a] %1 -> [b] %1 -> [c] %1 -> [d] zipWith3 _ [] ys zs = (ys, zs) `lseq` [] diff --git a/test/Test/Data/List.hs b/test/Test/Data/List.hs index 7d1b679c..db29b401 100644 --- a/test/Test/Data/List.hs +++ b/test/Test/Data/List.hs @@ -4,6 +4,7 @@ module Test.Data.List (listTests) where import qualified Data.List.Linear as List +import qualified Data.Num.Linear as Num import Hedgehog import qualified Hedgehog.Gen as Gen import qualified Hedgehog.Range as Range @@ -17,7 +18,9 @@ listTests = testGroup "List tests" [ testPropertyNamed "take n ++ drop n = id" "take_drop" take_drop, - testPropertyNamed "length . take n = const n" "take_length" take_length + testPropertyNamed "length . take n = const n" "take_length" take_length, + testPropertyNamed "zipWith is lazy" "zipWith_lazy" zipWith_lazy, + testPropertyNamed "zipWith3 is lazy" "zipWith3_lazy" zipWith3_lazy ] take_drop :: Property @@ -41,3 +44,19 @@ take_length = property $ do False -> do annotate "Prelude.length xs < n" Prelude.length (List.take n xs) === Prelude.length xs + +zipWith_lazy :: Property +zipWith_lazy = property $ do + _ <- eval $ Prelude.head xs + Prelude.return () + where + xs :: [Word] + xs = List.zipWith (Num.+) (0 : error "bottom") [0 .. 42] + +zipWith3_lazy :: Property +zipWith3_lazy = property $ do + _ <- eval $ Prelude.head xs + Prelude.return () + where + xs :: [Word] + xs = List.zipWith3 (\x y z -> x Num.+ y Num.+ z) (0 : error "bottom") [0 .. 42] [0 .. 57]