Commit dcd0761

backport writer
1 parent 02e5e9c commit dcd0761

3 files changed: 124 additions & 45 deletions
pyspark_huggingface/compat/datasource.py

Lines changed: 113 additions & 39 deletions
@@ -1,10 +1,9 @@
-from typing import TYPE_CHECKING, Iterator, Optional, Union
+from typing import TYPE_CHECKING, Iterator, List, Optional, Union
 
 import pyspark
-from packaging import version
 
 
-if version.parse(pyspark.__version__) >= version.parse("4.0.0.dev2"):
+if int(pyspark.__version__.split(".")[0]) >= 4 and ("dev0" not in pyspark.__version__ and "dev1" not in pyspark.__version__):
     from pyspark.sql.datasource import DataSource, DataSourceArrowWriter, DataSourceReader, DataSourceWriter, InputPartition, WriterCommitMessage
 else:
     class DataSource:
@@ -18,72 +17,81 @@ class DataSourceReader:
         ...
 
     class DataSourceWriter:
-        ...
-
+        def __init__(self, options):
+            self.options = options
+
     class InputPartition:
         ...
 
     class WriterCommitMessage:
         ...
 
 
-import os
 import logging
+import os
+import pickle
 
-from pyspark.sql.readwriter import DataFrameReader as _DataFrameReader
+from pyspark.sql.readwriter import DataFrameReader as _DataFrameReader, DataFrameWriter as _DataFrameWriter
 
 if TYPE_CHECKING:
-    import pyarrow as pa
+    from pyarrow import RecordBatch
     from pyspark.sql.dataframe import DataFrame
     from pyspark.sql.readwriter import PathOrPaths
     from pyspark.sql.types import StructType
     from pyspark.sql._typing import OptionalPrimitiveType
 
 
-_orig_format = _DataFrameReader.format
+class _ArrowPickler:
+
+    def __init__(self, key: str):
+        from pyspark.sql.types import StructType, StructField, BinaryType
+
+        self.key = key
+        self.schema = StructType([StructField(self.key, BinaryType(), True)])
+
+    def dumps(self, obj):
+        return {self.key: pickle.dumps(obj)}
+
+    def loads(self, obj):
+        return pickle.loads(obj[self.key])
+
+# Reader
+
+def _read_in_arrow(batches: Iterator["RecordBatch"], arrow_pickler, hf_reader) -> Iterator["RecordBatch"]:
+    for batch in batches:
+        for record in batch.to_pylist():
+            partition = arrow_pickler.loads(record)
+            yield from hf_reader.read(partition)
+
+_orig_reader_format = _DataFrameReader.format
 
 def _new_format(self: _DataFrameReader, source: str) -> _DataFrameReader:
     self._format = source
-    return _orig_format(self, source)
+    return _orig_reader_format(self, source)
 
 _DataFrameReader.format = _new_format
 
-_orig_option = _DataFrameReader.option
+_orig_reader_option = _DataFrameReader.option
 
 def _new_option(self: _DataFrameReader, key, value) -> _DataFrameReader:
     if not hasattr(self, "_options"):
         self._options = {}
     self._options[key] = value
-    return _orig_option(self, key, value)
+    return _orig_reader_option(self, key, value)
 
-_DataFrameReader.option = _orig_option
+_DataFrameReader.option = _new_option
 
-_orig_options = _DataFrameReader.options
+_orig_reader_options = _DataFrameReader.options
 
 def _new_options(self: _DataFrameReader, **options) -> _DataFrameReader:
     if not hasattr(self, "_options"):
         self._options = {}
     self._options.update(options)
-    return _orig_options(self, **options)
+    return _orig_reader_options(self, **options)
 
-_DataFrameReader.options = _orig_options
+_DataFrameReader.options = _new_options
 
-_orig_load = _DataFrameReader.load
-
-class _unpack_dict(dict):
-    ...
-
-class _ArrowPipe:
-
-    def __init__(self, *fns):
-        self.fns = fns
-
-    def __call__(self, iterator: Iterator["pa.RecordBatch"]):
-        for record_batch in iterator:
-            for data in record_batch.to_pylist():
-                for fn in self.fns:
-                    data = fn(**data) if isinstance(data, _unpack_dict) else fn(data)
-                yield from data
+_orig_reader_load = _DataFrameReader.load
 
 def _new_load(
     self: _DataFrameReader,
@@ -93,21 +101,87 @@ def _new_load(
     **options: "OptionalPrimitiveType",
 ) -> "DataFrame":
     if (format or getattr(self, "_format", None)) == "huggingface":
-        from dataclasses import asdict
+        from functools import partial
         from pyspark_huggingface.huggingface import HuggingFaceDatasets
 
         source = HuggingFaceDatasets(options={**getattr(self, "_options", {}), **options, "path": path}).get_source()
         schema = schema or source.schema()
-        reader = source.reader(schema)
-        partitions = reader.partitions()
-        partition_cls = type(partitions[0])
-        rdd = self._spark.sparkContext.parallelize([asdict(partition) for partition in partitions], len(partitions))
+        hf_reader = source.reader(schema)
+        partitions = hf_reader.partitions()
+        arrow_pickler = _ArrowPickler("partition")
+        rdd = self._spark.sparkContext.parallelize([arrow_pickler.dumps(partition) for partition in partitions], len(partitions))
         df = self._spark.createDataFrame(rdd)
-        return df.mapInArrow(_ArrowPipe(_unpack_dict, partition_cls, reader.read), schema)
+        return df.mapInArrow(partial(_read_in_arrow, arrow_pickler=arrow_pickler, hf_reader=hf_reader), schema)
 
-    return _orig_load(self, path=path, format=format, schema=schema, **options)
+    return _orig_reader_load(self, path=path, format=format, schema=schema, **options)
 
 _DataFrameReader.load = _new_load
 
+# Writer
+
+def _write_in_arrow(batches: Iterator["RecordBatch"], arrow_pickler, hf_writer) -> Iterator["RecordBatch"]:
+    from pyarrow import RecordBatch
+
+    commit_message = hf_writer.write(batches)
+    yield RecordBatch.from_pylist([arrow_pickler.dumps(commit_message)])
+
+_orig_writer_format = _DataFrameWriter.format
+
+def _new_format(self: _DataFrameWriter, source: str) -> _DataFrameWriter:
+    self._format = source
+    return _orig_writer_format(self, source)
+
+_DataFrameWriter.format = _new_format
+
+_orig_writer_option = _DataFrameWriter.option
+
+def _new_option(self: _DataFrameWriter, key, value) -> _DataFrameWriter:
+    if not hasattr(self, "_options"):
+        self._options = {}
+    self._options[key] = value
+    return _orig_writer_option(self, key, value)
+
+_DataFrameWriter.option = _new_option
+
+_orig_writer_options = _DataFrameWriter.options
+
+def _new_options(self: _DataFrameWriter, **options) -> _DataFrameWriter:
+    if not hasattr(self, "_options"):
+        self._options = {}
+    self._options.update(options)
+    return _orig_writer_options(self, **options)
+
+_DataFrameWriter.options = _new_options
+
+_orig_writer_save = _DataFrameWriter.save
+
+def _new_save(
+    self: _DataFrameWriter,
+    path: Optional["PathOrPaths"] = None,
+    format: Optional[str] = None,
+    mode: Optional[Union["StructType", str]] = None,
+    partitionBy: Optional[Union[str, List[str]]] = None,
+    **options: "OptionalPrimitiveType",
+) -> "DataFrame":
+    if (format or getattr(self, "_format", None)) == "huggingface":
+        from functools import partial
+        from pyspark_huggingface.huggingface import HuggingFaceDatasets
+
+        sink = HuggingFaceDatasets(options={**getattr(self, "_options", {}), **options, "path": path}).get_sink()
+        schema = self._df.schema
+        mode = options.pop("mode", None)
+        hf_writer = sink.writer(schema, overwrite=(mode == "overwrite"))
+        arrow_pickler = _ArrowPickler("commit_message")
+        commit_messages = self._df.mapInArrow(partial(_write_in_arrow, arrow_pickler=arrow_pickler, hf_writer=hf_writer), arrow_pickler.schema).collect()
+        commit_messages = [arrow_pickler.loads(commit_message) for commit_message in commit_messages]
+        hf_writer.commit(commit_messages)
+        return
+
+    return _orig_writer_save(self, path=path, format=format, schema=schema, **options)
+
+_DataFrameWriter.save = _new_save
+
+# Log only in driver
+
 if not os.environ.get("SPARK_ENV_LOADED"):
     logging.getLogger(__name__).warning(f"huggingface datasource enabled for pyspark {pyspark.__version__} (backport from pyspark 4)")

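For context, the _ArrowPickler added above is the serialization shim that lets arbitrary partition and commit-message objects travel through Spark as a single BinaryType column. A minimal sketch of the same round trip outside Spark, with _DemoPartition invented for the demo:

import pickle
from dataclasses import dataclass

@dataclass
class _DemoPartition:
    # Hypothetical stand-in for the partition objects the HF reader yields.
    shard: int

key = "partition"
record = {key: pickle.dumps(_DemoPartition(shard=0))}  # what _ArrowPickler.dumps builds
restored = pickle.loads(record[key])                   # what _ArrowPickler.loads recovers
assert restored == _DemoPartition(shard=0)

And a sketch of how the patched reader and writer entry points would be driven on PySpark 3.x, assuming that importing the package applies these patches and that a session is available; the repo ids are placeholders:

from pyspark.sql import SparkSession

import pyspark_huggingface  # noqa: F401  # assumed to apply the backport patches on PySpark < 4

spark = SparkSession.builder.appName("hf-backport-demo").getOrCreate()

# Read: _new_load pickles each partition, parallelizes them into an RDD,
# and mapInArrow streams Arrow record batches from the HF reader.
df = spark.read.format("huggingface").load("user/some-dataset")

# Write: _new_save runs the HF writer on executors via mapInArrow, collects
# the pickled commit messages, and commits them once on the driver.
df.write.format("huggingface").save("user/another-dataset")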
pyspark_huggingface/huggingface_sink.py

Lines changed: 6 additions & 4 deletions
@@ -3,12 +3,12 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Iterator, List, Optional, Union
 
-from pyspark.sql.datasource import (
+from pyspark.sql.types import StructType
+from pyspark_huggingface.compat.datasource import (
     DataSource,
     DataSourceArrowWriter,
     WriterCommitMessage,
 )
-from pyspark.sql.types import StructType
 
 if TYPE_CHECKING:
     from huggingface_hub import (
@@ -66,12 +66,14 @@ def __init__(self, options):
         if "path" not in options or not options["path"]:
             raise Exception("You must specify a dataset name.")
 
+        from huggingface_hub import get_token
+
         kwargs = dict(self.options)
-        self.token = kwargs.pop("token")
         self.repo_id = kwargs.pop("path")
         self.path_in_repo = kwargs.pop("path_in_repo", None)
         self.split = kwargs.pop("split", None)
         self.revision = kwargs.pop("revision", None)
+        self.token = kwargs.pop("token", None) or get_token()
         self.endpoint = kwargs.pop("endpoint", None)
         for arg in kwargs:
             if kwargs[arg].lower() == "true":
@@ -89,7 +91,7 @@ def __init__(self, options):
     def name(cls):
         return "huggingfacesink"
 
-    def writer(self, schema: StructType, overwrite: bool) -> DataSourceArrowWriter:
+    def writer(self, schema: StructType, overwrite: bool) -> "HuggingFaceDatasetsWriter":
         return HuggingFaceDatasetsWriter(
             repo_id=self.repo_id,
             path_in_repo=self.path_in_repo,

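Note the token handling in the new constructor code: an explicit token option wins, otherwise huggingface_hub.get_token() supplies the ambient credential (the HF_TOKEN environment variable or the token saved by huggingface-cli login). A minimal sketch of the same fallback, with resolve_token as an invented helper name:

from typing import Optional

from huggingface_hub import get_token

def resolve_token(options: dict) -> Optional[str]:
    # Mirrors kwargs.pop("token", None) or get_token() in the sink and source.
    return options.pop("token", None) or get_token()

# An explicit option always takes precedence over the ambient credential.
assert resolve_token({"token": "hf_explicit"}) == "hf_explicit"
# Without an option this returns whatever get_token() finds, possibly None.
print(resolve_token({}))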
pyspark_huggingface/huggingface_source.py

Lines changed: 5 additions & 2 deletions
@@ -82,13 +82,16 @@ def __init__(self, options):
         if "path" not in options or not options["path"]:
             raise Exception("You must specify a dataset name.")
 
+        from huggingface_hub import get_token
+
         kwargs = dict(self.options)
         self.dataset_name = kwargs.pop("path")
         self.config_name = kwargs.pop("config", None)
         self.split = kwargs.pop("split", self.DEFAULT_SPLIT)
         self.revision = kwargs.pop("revision", None)
         self.streaming = kwargs.pop("streaming", "true").lower() == "true"
-        self.token = kwargs.pop("token", None)
+        self.token = kwargs.pop("token", None) or get_token()
+        self.endpoint = kwargs.pop("endpoint", None)
         for arg in kwargs:
             if kwargs[arg].lower() == "true":
                 kwargs[arg] = True
@@ -116,7 +119,7 @@ def __init__(self, options):
     def _get_api(self):
         from huggingface_hub import HfApi
 
-        return HfApi(token=self.token, library_name="pyspark_huggingface")
+        return HfApi(token=self.token, endpoint=self.endpoint, library_name="pyspark_huggingface")
 
     @classmethod
     def name(cls):

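With _get_api now forwarding endpoint to HfApi, the source can target a non-default Hub deployment. A sketch of passing the option, assuming a live spark session with the patches applied; the endpoint URL, token, and repo id are placeholders:

# Hypothetical: read a dataset from a self-hosted Hub deployment.
df = (
    spark.read.format("huggingface")
    .option("endpoint", "https://hub.example.com")  # placeholder endpoint URL
    .option("token", "hf_xxx")                      # placeholder; omit to fall back to get_token()
    .load("user/some-dataset")
)
df.show()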