Skip to content

Commit 08a229a

Browse files
committed
Make List.zipWith as lazy as expected
1 parent c566b0b commit 08a229a

File tree

1 file changed

+26
-9
lines changed

1 file changed

+26
-9
lines changed

src/Data/List/Linear.hs

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
{-# LANGUAGE BangPatterns #-}
12
{-# LANGUAGE LinearTypes #-}
23
{-# LANGUAGE ScopedTypeVariables #-}
34
{-# LANGUAGE NoImplicitPrelude #-}
@@ -353,18 +354,34 @@ zip3 :: (Consumable a, Consumable b, Consumable c) => [a] %1 -> [b] %1 -> [c] %1
353354
zip3 = zipWith3 (,,)
354355

355356
zipWith :: (Consumable a, Consumable b) => (a %1 -> b %1 -> c) -> [a] %1 -> [b] %1 -> [c]
356-
zipWith f xs ys =
357-
zipWith' f xs ys & \(ret, leftovers) ->
358-
leftovers `lseq` ret
357+
zipWith f =
358+
zipWithk f (:) [] consume2 consume2
359+
where
360+
consume2 :: forall x y z. (Consumable x, Consumable y) => x %1 -> y %1 -> [z]
361+
consume2 x y = x `lseq` y `lseq` []
359362

360363
-- | Same as 'zipWith', but returns the leftovers instead of consuming them.
364+
-- Because the leftovers are returned at toplevel, @zipWith'@ is pretty strict:
365+
-- forcing the first cons cell of the returned list forces all the recursive
366+
-- calls.
361367
zipWith' :: (a %1 -> b %1 -> c) -> [a] %1 -> [b] %1 -> ([c], Maybe (Either (NonEmpty a) (NonEmpty b)))
362-
zipWith' _ [] [] = ([], Nothing)
363-
zipWith' _ (a : as) [] = ([], Just (Left (a :| as)))
364-
zipWith' _ [] (b : bs) = ([], Just (Right (b :| bs)))
365-
zipWith' f (a : as) (b : bs) =
366-
case zipWith' f as bs of
367-
(cs, rest) -> (f a b : cs, rest)
368+
zipWith' f =
369+
zipWithk
370+
f
371+
(\c !(cs, rest) -> ((c : cs), rest))
372+
([], Nothing)
373+
(\a as -> ([], Just (Left (a :| as))))
374+
(\b bs -> ([], Just (Right (b :| bs))))
375+
376+
zipWithk :: forall r a b c. (a %1 -> b %1 -> c) -> (c %1 -> r %1 -> r) -> r -> (a %1 -> [a] %1 -> r) -> (b %1 -> [b] %1 -> r) -> [a] %1 -> [b] %1 -> r
377+
zipWithk f cons nil lefta leftb =
378+
go
379+
where
380+
go :: [a] %1 -> [b] %1 -> r
381+
go [] [] = nil
382+
go (a : as) [] = lefta a as
383+
go [] (b : bs) = leftb b bs
384+
go (a : as) (b : bs) = cons (f a b) (go as bs)
368385

369386
zipWith3 :: forall a b c d. (Consumable a, Consumable b, Consumable c) => (a %1 -> b %1 -> c %1 -> d) -> [a] %1 -> [b] %1 -> [c] %1 -> [d]
370387
zipWith3 _ [] ys zs = (ys, zs) `lseq` []

0 commit comments

Comments
 (0)