
Commit c9bdbb2

Add fake streaming source (#8)
1 parent 3e7564d commit c9bdbb2

File tree

pyspark_datasources/fake.py
tests/test_data_sources.py

2 files changed, +100 -26 lines

pyspark_datasources/fake.py

Lines changed: 85 additions & 26 deletions
@@ -1,5 +1,35 @@
-from pyspark.sql.datasource import DataSource, DataSourceReader
-from pyspark.sql.types import StructType, StringType
+from typing import List
+
+from pyspark.sql.datasource import (
+    DataSource,
+    DataSourceReader,
+    DataSourceStreamReader,
+    InputPartition,
+)
+from pyspark.sql.types import StringType, StructType
+
+
+def _validate_faker_schema(schema):
+    # Verify the library is installed correctly.
+    try:
+        from faker import Faker
+    except ImportError:
+        raise Exception("You need to install `faker` to use the fake datasource.")
+
+    fake = Faker()
+    for field in schema.fields:
+        try:
+            getattr(fake, field.name)()
+        except AttributeError:
+            raise Exception(
+                f"Unable to find a method called `{field.name}` in faker. "
+                f"Please check Faker's documentation to see supported methods."
+            )
+        if field.dataType != StringType():
+            raise Exception(
+                f"Field `{field.name}` is not a StringType. "
+                f"Only StringType is supported in the fake datasource."
+            )
 
 
 class FakeDataSource(DataSource):
@@ -19,6 +49,7 @@ class FakeDataSource(DataSource):
     - The fake data source relies on the `faker` library. Make sure it is installed and accessible.
     - Only string type fields are supported, and each field name must correspond to a method name in
       the `faker` library.
+    - When using the stream reader, `numRows` is the number of rows per microbatch.
 
     Examples
     --------
@@ -61,6 +92,21 @@ class FakeDataSource(DataSource):
     |  Caitlin Reed|1983-06-22|  89813|Pennsylvania|
     | Douglas James|2007-01-18|  46226|     Alabama|
     +--------------+----------+-------+------------+
+
+    Streaming fake data:
+
+    >>> stream = spark.readStream.format("fake").load().writeStream.format("console").start()
+    Batch: 0
+    +--------------+----------+-------+------------+
+    |          name|      date|zipcode|       state|
+    +--------------+----------+-------+------------+
+    |    Tommy Diaz|1976-11-17|  27627|South Dakota|
+    |Jonathan Perez|1986-02-23|  81307|Rhode Island|
+    |  Julia Farmer|1990-10-10|  40482|    Virginia|
+    +--------------+----------+-------+------------+
+    Batch: 1
+    ...
+    >>> stream.stop()
     """
 
     @classmethod
@@ -70,40 +116,24 @@ def name(cls):
     def schema(self):
         return "name string, date string, zipcode string, state string"
 
-    def reader(self, schema: StructType):
-        # Verify the library is installed correctly.
-        try:
-            from faker import Faker
-        except ImportError:
-            raise Exception("You need to install `faker` to use the fake datasource.")
-
-        # Check the schema is valid before proceed to reading.
-        fake = Faker()
-        for field in schema.fields:
-            try:
-                getattr(fake, field.name)()
-            except AttributeError:
-                raise Exception(
-                    f"Unable to find a method called `{field.name}` in faker. "
-                    f"Please check Faker's documentation to see supported methods."
-                )
-            if field.dataType != StringType():
-                raise Exception(
-                    f"Field `{field.name}` is not a StringType. "
-                    f"Only StringType is supported in the fake datasource."
-                )
-
+    def reader(self, schema: StructType) -> "FakeDataSourceReader":
+        _validate_faker_schema(schema)
         return FakeDataSourceReader(schema, self.options)
 
+    def streamReader(self, schema) -> "FakeDataSourceStreamReader":
+        _validate_faker_schema(schema)
+        return FakeDataSourceStreamReader(schema, self.options)
+
 
 class FakeDataSourceReader(DataSourceReader):
 
-    def __init__(self, schema, options):
+    def __init__(self, schema, options) -> None:
         self.schema: StructType = schema
         self.options = options
 
     def read(self, partition):
         from faker import Faker
+
         fake = Faker()
         # Note: every value in this `self.options` dictionary is a string.
         num_rows = int(self.options.get("numRows", 3))
@@ -113,3 +143,32 @@ def read(self, partition):
                 value = getattr(fake, field.name)()
                 row.append(value)
             yield tuple(row)
+
+
+class FakeDataSourceStreamReader(DataSourceStreamReader):
+    def __init__(self, schema, options) -> None:
+        self.schema: StructType = schema
+        self.rows_per_microbatch = int(options.get("numRows", 3))
+        self.options = options
+        self.offset = 0
+
+    def initialOffset(self) -> dict:
+        return {"offset": 0}
+
+    def latestOffset(self) -> dict:
+        self.offset += self.rows_per_microbatch
+        return {"offset": self.offset}
+
+    def partitions(self, start, end) -> List[InputPartition]:
+        return [InputPartition(end["offset"] - start["offset"])]
+
+    def read(self, partition):
+        from faker import Faker
+
+        fake = Faker()
+        for _ in range(partition.value):
+            row = []
+            for field in self.schema.fields:
+                value = getattr(fake, field.name)()
+                row.append(value)
+            yield tuple(row)
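
The offset contract introduced above can be exercised by hand, without a running stream. A minimal sketch, assuming `faker` is installed, a PySpark build that ships the Python data source API, and that the module is importable as `pyspark_datasources.fake`:

# Drive one microbatch of the stream reader manually.
from pyspark.sql.types import StringType, StructField, StructType

from pyspark_datasources.fake import FakeDataSourceStreamReader

schema = StructType([StructField("name", StringType())])
reader = FakeDataSourceStreamReader(schema, {"numRows": "3"})

start = reader.initialOffset()   # {"offset": 0}
end = reader.latestOffset()      # {"offset": 3}: advanced by rows_per_microbatch
[partition] = reader.partitions(start, end)

rows = list(reader.read(partition))   # three single-column tuples of fake names
assert len(rows) == 3

The offsets are plain dictionaries, so the difference end["offset"] - start["offset"] is exactly the number of rows the single partition must produce.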

tests/test_data_sources.py

Lines changed: 15 additions & 0 deletions
@@ -17,6 +17,21 @@ def test_github_datasource(spark):
     assert len(prs) > 0
 
 
+def test_fake_datasource_stream(spark):
+    spark.dataSource.register(FakeDataSource)
+    (
+        spark.readStream.format("fake")
+        .load()
+        .writeStream.format("memory")
+        .queryName("result")
+        .trigger(once=True)
+        .start()
+        .awaitTermination()
+    )
+    spark.sql("SELECT * FROM result").show()
+    assert spark.sql("SELECT * FROM result").count() == 3
+
+
 def test_fake_datasource(spark):
     spark.dataSource.register(FakeDataSource)
     df = spark.read.format("fake").load()
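
A variant of the test above could also pin down the rows-per-microbatch behavior documented in the docstring by overriding `numRows`. A sketch, not part of this commit, assuming the same pytest `spark` fixture and the `FakeDataSource` import already present in this file (the test and query names here are hypothetical):

def test_fake_datasource_stream_custom_rows(spark):
    spark.dataSource.register(FakeDataSource)
    (
        spark.readStream.format("fake")
        .option("numRows", "5")  # five rows per microbatch instead of the default 3
        .load()
        .writeStream.format("memory")
        .queryName("result_custom")
        .trigger(once=True)
        .start()
        .awaitTermination()
    )
    # trigger(once=True) runs exactly one microbatch, so numRows rows reach the sink.
    assert spark.sql("SELECT * FROM result_custom").count() == 5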
