Skip to content

Commit 4d8969d

Browse files
feat: add shuffle option in read_files
1 parent 8537fe1 commit 4d8969d

File tree

3 files changed

+21
-2
lines changed

3 files changed

+21
-2
lines changed

pipd/functions/read_files.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import glob
33
import os
4+
import random
45
from typing import Iterable, Iterator, Optional, Sequence, TypeVar
56

67
from pipd import Function, Pipe
@@ -29,10 +30,14 @@ def watchdir(
2930

3031
class ReadFiles(Function):
3132
def __init__(
32-
self, cache_filepath: Optional[str] = None, watch: bool = False
33+
self,
34+
cache_filepath: Optional[str] = None,
35+
watch: bool = False,
36+
shuffle: bool = False,
3337
) -> None:
3438
self.cache_filepath = cache_filepath
3539
self.watch = watch
40+
self.shuffle = shuffle
3641

3742
def __call__(self, items: Iterable[str]) -> Iterator[str]:
3843
for filepath in items:
@@ -46,6 +51,9 @@ def __call__(self, items: Iterable[str]) -> Iterator[str]:
4651
else:
4752
files = glob.glob(filepath)
4853

54+
if self.shuffle:
55+
random.shuffle(files)
56+
4957
for file in files:
5058
yield file
5159
if self.watch:

pipd/tests/test_functions.py

+11
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,14 @@ def test_read_files():
158158
pipe = Pipe([f"{d}/*.txt"]).read_files(cache_filepath=f"{d}/cache.txt")
159159
assert list(pipe) == [f1.name, f2.name]
160160
os.remove(f"{d}/cache.txt")
161+
162+
# Test shuffle
163+
with tempfile.TemporaryDirectory() as d:
164+
with open(os.path.join(d, "f1.txt"), "w") as f1:
165+
f1.write("a")
166+
with open(os.path.join(d, "f2.txt"), "w") as f2:
167+
f2.write("b")
168+
pipe = Pipe([f"{d}/*.txt"]).read_files(shuffle=True)
169+
assert sorted(list(pipe)) == sorted([f1.name, f2.name])
170+
os.remove(f1.name)
171+
os.remove(f2.name)

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="pipd",
55
packages=find_packages(exclude=[]),
6-
version="0.1.3",
6+
version="0.1.4",
77
description="Utility functions for python data pipelines.",
88
long_description_content_type="text/markdown",
99
author="ElevenLabs",

0 commit comments

Comments
 (0)