
Commit fb8c482

[SPARK-54305][PySpark][Streaming] Add admission control support for Python streaming data sources
1 parent: e09c999

9 files changed: +255 -101 lines changed

.gitignore

Lines changed: 5 additions & 0 deletions
@@ -129,3 +129,8 @@ sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/gen/
 tpcds-sf-1/
 tpcds-sf-1-text/
 tpcds-kit/
+
+# Cursor AI configuration files (local development only)
+.cursorrules*
+PR_SPARK_GUIDELINES.MD
+SPARK_WORKFLOW.md
Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+.. See also ``pyspark.sql.sources.DataSource.streamReader``.
+
+The parameter `read_limit` in `latestOffset` provides the read limit for the current batch.
+The implementation can use this information to cap the number of rows returned in the batch.
+For example, if the `read_limit` is `{"maxRows": 1000}`, the data source should not return
+more than 1000 rows. The available read limit types are:
+
+* `maxRows`: the maximum number of rows to return in a batch.
+* `minRows`: the minimum number of rows to return in a batch.
+* `maxBytes`: the maximum size in bytes to return in a batch.
+* `minBytes`: the minimum size in bytes to return in a batch.
+* `allAvailable`: return all available data in a batch.
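
To make the read-limit contract above concrete, here is a brief editor-added sketch (not part of this commit) of how a reader might decide how far a batch may advance. The integer offsets, the assumption that `allAvailable` arrives as a key of the `read_limit` dict, and the `next_offset` helper itself are all illustrative, not APIs introduced by this change.

def next_offset(current: int, available_rows: int, read_limit: dict) -> int:
    # Decide how far a micro-batch may advance under the given read limit.
    if "allAvailable" in read_limit:
        return current + available_rows
    # Cap the batch so that no more than maxRows rows are planned; if no cap is
    # given, take everything the source currently has.
    capped = min(available_rows, read_limit.get("maxRows", available_rows))
    return current + capped

# Example: 5000 rows are available but the limit is {"maxRows": 1000}.
assert next_offset(100, available_rows=5000, read_limit={"maxRows": 1000}) == 1100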

python/pyspark/sql/datasource.py

Lines changed: 9 additions & 42 deletions
@@ -300,7 +300,7 @@ class Filter(ABC):
     +---------------------+--------------------------------------------+
     | SQL filter          | Representation                             |
-    +---------------------+--------------------------------------------+
+    +---------------------+---------------------------------------------+
     | `a.b.c = 1`         | `EqualTo(("a", "b", "c"), 1)`              |
     | `a = 1`             | `EqualTo(("a",), 1)`                       |
     | `a = 'hi'`          | `EqualTo(("a",), "hi")`                    |
@@ -685,56 +685,23 @@ def read(self, partition: InputPartition) -> Union[Iterator[Tuple], Iterator["Re
 
 class DataSourceStreamReader(ABC):
     """
-    A base class for streaming data source readers. Data source stream readers are responsible
-    for outputting data from a streaming data source.
+    An interface for streaming data source.
 
     .. versionadded: 4.0.0
     """
 
-    def initialOffset(self) -> dict:
-        """
-        Return the initial offset of the streaming data source.
-        A new streaming query starts reading data from the initial offset.
-        If Spark is restarting an existing query, it will restart from the check-pointed offset
-        rather than the initial one.
-
-        Returns
-        -------
-        dict
-            A dict or recursive dict whose key and value are primitive types, which includes
-            Integer, String and Boolean.
-
-        Examples
-        --------
-        >>> def initialOffset(self):
-        ...     return {"parititon-1": {"index": 3, "closed": True}, "partition-2": {"index": 5}}
-        """
-        raise PySparkNotImplementedError(
-            errorClass="NOT_IMPLEMENTED",
-            messageParameters={"feature": "initialOffset"},
-        )
+    def initialOffset(self) -> str:
+        pass
 
-    def latestOffset(self) -> dict:
+    @abstractmethod
+    def latestOffset(self) -> str:
         """
         Returns the most recent offset available.
-
-        Returns
-        -------
-        dict
-            A dict or recursive dict whose key and value are primitive types, which includes
-            Integer, String and Boolean.
-
-        Examples
-        --------
-        >>> def latestOffset(self):
-        ...     return {"parititon-1": {"index": 3, "closed": True}, "partition-2": {"index": 5}}
         """
-        raise PySparkNotImplementedError(
-            errorClass="NOT_IMPLEMENTED",
-            messageParameters={"feature": "latestOffset"},
-        )
+        pass
 
-    def partitions(self, start: dict, end: dict) -> Sequence[InputPartition]:
+    @abstractmethod
+    def partitions(self, start: str, end: str) -> List[bytes]:
        """
        Returns a list of InputPartition given the start and end offsets. Each InputPartition
        represents a data split that can be processed by one Spark task. This may be called with

python/pyspark/sql/datasource_internal.py

Lines changed: 15 additions & 18 deletions
@@ -19,7 +19,7 @@
 import json
 import copy
 from itertools import chain
-from typing import Iterator, List, Optional, Sequence, Tuple
+from typing import Iterator, List, Optional, Sequence, Tuple, TYPE_CHECKING
 
 from pyspark.sql.datasource import (
     DataSource,
@@ -77,25 +77,22 @@ class _SimpleStreamReaderWrapper(DataSourceStreamReader):
     replayed by reading data between start and end offset through readBetweenOffsets(start, end).
     """
 
-    def __init__(self, simple_reader: SimpleDataSourceStreamReader):
-        self.simple_reader = simple_reader
-        self.initial_offset: Optional[dict] = None
-        self.current_offset: Optional[dict] = None
-        self.cache: List[PrefetchedCacheEntry] = []
+    def __init__(self, reader: "DataSourceStreamReader"):
+        self.reader = reader
 
     def initialOffset(self) -> dict:
-        if self.initial_offset is None:
-            self.initial_offset = self.simple_reader.initialOffset()
-        return self.initial_offset
-
-    def latestOffset(self) -> dict:
-        # when query start for the first time, use initial offset as the start offset.
-        if self.current_offset is None:
-            self.current_offset = self.initialOffset()
-        (iter, end) = self.simple_reader.read(self.current_offset)
-        self.cache.append(PrefetchedCacheEntry(self.current_offset, end, iter))
-        self.current_offset = end
-        return end
+        return self.reader.initialOffset()
+
+    def latestOffset(self, start: Optional[dict], read_limit: Dict) -> dict:
+        # For backward compatibility, `latestOffset` with two arguments is not an abstract method.
+        # If the user-defined stream reader does not implement that, it will fall back to
+        # the `latestOffset` with no argument.
+        if hasattr(self.reader, "latestOffset") and not isinstance(
+            self.reader, SimpleDataSourceStreamReader
+        ):
+            return self.reader.latestOffset(start, read_limit)
+        else:
+            return self.reader.latestOffset()
 
     def commit(self, end: dict) -> None:
         if self.current_offset is None:
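
As an editor-added illustration of the fallback described in the comment above (not code from this commit), a user-defined reader can tolerate both calling conventions by giving the extra parameters defaults. The counter-based state, dict-shaped offsets, and the omission of partitions/read are illustrative assumptions.

from pyspark.sql.datasource import DataSourceStreamReader

class DualModeStreamReader(DataSourceStreamReader):
    """Editor sketch: works whether the engine passes (start, read_limit) or nothing."""

    def __init__(self, rows_available=10_000):
        self._rows_available = rows_available  # pretend source size, for illustration only

    def initialOffset(self) -> dict:
        return {"offset": 0}

    def latestOffset(self, start=None, read_limit=None) -> dict:
        # Newer engines call latestOffset(start, read_limit); older engines call
        # latestOffset() with no arguments, leaving both parameters as None.
        latest = self._rows_available
        if start is not None and read_limit and "maxRows" in read_limit:
            latest = min(latest, start["offset"] + read_limit["maxRows"])
        return {"offset": latest}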

python/pyspark/sql/streaming/python_streaming_source_runner.py

Lines changed: 9 additions & 1 deletion
@@ -27,6 +27,8 @@
     write_int,
     write_with_length,
     SpecialLengths,
+    read_bool,
+    read_with_length,
 )
 from pyspark.sql.datasource import DataSource, DataSourceStreamReader
 from pyspark.sql.datasource_internal import _SimpleStreamReaderWrapper, _streamReader
@@ -169,7 +171,13 @@ def main(infile: IO, outfile: IO) -> None:
             if func_id == INITIAL_OFFSET_FUNC_ID:
                 initial_offset_func(reader, outfile)
             elif func_id == LATEST_OFFSET_FUNC_ID:
-                latest_offset_func(reader, outfile)
+                has_start = read_bool(infile)
+                if has_start:
+                    start = read_with_length(infile)
+                else:
+                    start = None
+                read_limit = json.loads(read_with_length(infile))
+                write_with_length(reader.latestOffset(start, read_limit))
             elif func_id == PARTITIONS_FUNC_ID:
                 partitions_func(
                     reader, data_source, schema, max_arrow_batch_size, infile, outfile
python/pyspark/sql/tests/streaming/test_streaming_datasource_admission_control.py

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+import time
+import unittest
+
+from pyspark.sql.datasource import DataSource, DataSourceStreamReader
+from pyspark.sql.functions import F
+from pyspark.sql.streaming import StreamTest
+from pyspark.sql.types import StructType, StructField, StringType, IntegerType
+
+
+class RateLimitStreamReader(DataSourceStreamReader):
+    def __init__(self, start, max_rows_per_batch):
+        self._start = start
+        self._max_rows_per_batch = max_rows_per_batch
+        self._next_offset = start
+
+    def initialOffset(self):
+        return str(self._start)
+
+    def latestOffset(self, start, read_limit):
+        max_rows = read_limit.get("maxRows", self._max_rows_per_batch)
+        self._next_offset += max_rows
+        return str(self._next_offset)
+
+    def partitions(self, start, end):
+        return [str(i).encode("utf-8") for i in range(int(start), int(end))]
+
+
+class RateLimitDataSource(DataSource):
+    def __init__(self, options):
+        self._max_rows_per_batch = int(options.get("maxRowsPerBatch", "100"))
+
+    def streamReader(self, schema):
+        return RateLimitStreamReader(0, self._max_rows_per_batch)
+
+
+class BackwardCompatibilityStreamReader(DataSourceStreamReader):
+    def __init__(self, start):
+        self._start = start
+        self._next_offset = start
+
+    def initialOffset(self):
+        return str(self._start)
+
+    def latestOffset(self):
+        self._next_offset += 1
+        return str(self._next_offset)
+
+    def partitions(self, start, end):
+        return [str(i).encode("utf-8") for i in range(int(start), int(end))]
+
+
+class BackwardCompatibilityDataSource(DataSource):
+    def streamReader(self, schema):
+        return BackwardCompatibilityStreamReader(0)
+
+
+class StreamingDataSourceAdmissionControlTests(StreamTest):
+    def test_backward_compatibility(self):
+        df = (
+            self.spark.readStream.format(
+                "org.apache.spark.sql.streaming.test.BackwardCompatibilityDataSource"
+            )
+            .option("includeTimestamp", "true")
+            .load()
+        )
+        self.assertTrue(df.isStreaming)
+
+        q = df.writeStream.queryName("test").format("memory").start()
+        try:
+            time.sleep(5)
+            self.assertTrue(self.spark.table("test").count() > 0)
+        finally:
+            q.stop()
+
+    def test_rate_limit(self):
+        df = (
+            self.spark.readStream.format("org.apache.spark.sql.streaming.test.RateLimitDataSource")
+            .option("maxRowsPerBatch", "5")
+            .load()
+        )
+        self.assertTrue(df.isStreaming)
+
+        q = df.writeStream.queryName("test_rate_limit").format("memory").start()
+        try:
+            time.sleep(5)
+            # The exact count can vary, but it should be a multiple of 5.
+            count = self.spark.table("test_rate_limit").count()
+            self.assertTrue(count > 0)
+            self.assertEqual(count % 5, 0)
+        finally:
+            q.stop()
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.streaming.test_streaming_datasource_admission_control import *
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
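
For context, and purely as an editor-added usage sketch (not from this commit): a Python data source such as RateLimitDataSource above is normally registered with the session and consumed via its short name, which defaults to the class name. A complete source would typically also implement DataSource.schema() or be given a user-provided schema; the snippet assumes that has been done and that `spark` is an active SparkSession.

spark.dataSource.register(RateLimitDataSource)  # short name defaults to the class name

stream_df = (
    spark.readStream.format("RateLimitDataSource")
    .option("maxRowsPerBatch", "5")
    .load()
)
# The reader above caps each offset advance using read_limit.get("maxRows", ...),
# so every planned micro-batch respects the admission-control limit.
query = stream_df.writeStream.format("console").start()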

sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala

Whitespace-only changes.

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala

Lines changed: 16 additions & 10 deletions
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2.python
 import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory}
-import org.apache.spark.sql.connector.read.streaming.{AcceptsLatestSeenOffset, MicroBatchStream, Offset}
+import org.apache.spark.sql.connector.read.streaming.{AcceptsLatestSeenOffset, MicroBatchStream, Offset, ReadLimit, SupportsAdmissionControl}
 import org.apache.spark.sql.execution.datasources.v2.python.PythonMicroBatchStream.nextStreamId
 import org.apache.spark.sql.execution.python.streaming.PythonStreamingSourceRunner
 import org.apache.spark.sql.types.StructType
@@ -32,11 +32,11 @@ class PythonMicroBatchStream(
     ds: PythonDataSourceV2,
     shortName: String,
     outputSchema: StructType,
-    options: CaseInsensitiveStringMap
-  )
-  extends MicroBatchStream
-  with Logging
-  with AcceptsLatestSeenOffset {
+    options: CaseInsensitiveStringMap)
+  extends MicroBatchStream
+  with Logging
+  with AcceptsLatestSeenOffset
+  with SupportsAdmissionControl {
   private def createDataSourceFunc =
     ds.source.createPythonFunction(
       ds.getOrCreateDataSourceInPython(shortName, options, Some(outputSchema)).dataSource)
@@ -55,7 +55,11 @@ class PythonMicroBatchStream(
 
   override def initialOffset(): Offset = PythonStreamingSourceOffset(runner.initialOffset())
 
-  override def latestOffset(): Offset = PythonStreamingSourceOffset(runner.latestOffset())
+  override def latestOffset(): Offset = PythonStreamingSourceOffset(runner.latestOffset(None))
+
+  override def latestOffset(start: Offset, limit: ReadLimit): Offset = {
+    PythonStreamingSourceOffset(runner.latestOffset(Some(start), Some(limit)))
+  }
 
   override def planInputPartitions(start: Offset, end: Offset): Array[InputPartition] = {
     val startOffsetJson = start.asInstanceOf[PythonStreamingSourceOffset].json
@@ -72,7 +76,10 @@ class PythonMicroBatchStream(
       nextBlockId = nextBlockId + 1
       val blockId = PythonStreamBlockId(streamId, nextBlockId)
       SparkEnv.get.blockManager.putIterator(
-        blockId, rows.get, StorageLevel.MEMORY_AND_DISK_SER, true)
+        blockId,
+        rows.get,
+        StorageLevel.MEMORY_AND_DISK_SER,
+        true)
       val partition = PythonStreamingInputPartition(0, partitions.head, Some(blockId))
       cachedInputPartition.foreach(_._3.dropCache())
       cachedInputPartition = Some((startOffsetJson, endOffsetJson, partition))
@@ -94,8 +101,7 @@ class PythonMicroBatchStream(
   }
 
   override def createReaderFactory(): PartitionReaderFactory = {
-    new PythonStreamingPartitionReaderFactory(
-      ds.source, readInfo.func, outputSchema, None, None)
+    new PythonStreamingPartitionReaderFactory(ds.source, readInfo.func, outputSchema, None, None)
   }
 
   override def commit(end: Offset): Unit = {