-from typing import TYPE_CHECKING, Iterator, Optional, Union
+from typing import TYPE_CHECKING, Iterator, List, Optional, Union
 
 import pyspark
-from packaging import version
 
 
-if version.parse(pyspark.__version__) >= version.parse("4.0.0.dev2"):
+if int(pyspark.__version__.split(".")[0]) >= 4 and ("dev0" not in pyspark.__version__ and "dev1" not in pyspark.__version__):
     from pyspark.sql.datasource import DataSource, DataSourceArrowWriter, DataSourceReader, DataSourceWriter, InputPartition, WriterCommitMessage
 else:
     class DataSource:
@@ -18,72 +17,81 @@ class DataSourceReader:
         ...
 
     class DataSourceWriter:
-        ...
-
+        def __init__(self, options):
+            self.options = options
+
     class InputPartition:
         ...
 
     class WriterCommitMessage:
         ...
 
 
-import os
 import logging
+import os
+import pickle
 
-from pyspark.sql.readwriter import DataFrameReader as _DataFrameReader
+from pyspark.sql.readwriter import DataFrameReader as _DataFrameReader, DataFrameWriter as _DataFrameWriter
 
 if TYPE_CHECKING:
-    import pyarrow as pa
+    from pyarrow import RecordBatch
     from pyspark.sql.dataframe import DataFrame
     from pyspark.sql.readwriter import PathOrPaths
     from pyspark.sql.types import StructType
     from pyspark.sql._typing import OptionalPrimitiveType
 
 
-_orig_format = _DataFrameReader.format
+class _ArrowPickler:
+
+    def __init__(self, key: str):
+        from pyspark.sql.types import StructType, StructField, BinaryType
+
+        self.key = key
+        self.schema = StructType([StructField(self.key, BinaryType(), True)])
+
+    def dumps(self, obj):
+        return {self.key: pickle.dumps(obj)}
+
+    def loads(self, obj):
+        return pickle.loads(obj[self.key])
+
+# Reader
+
+def _read_in_arrow(batches: Iterator["RecordBatch"], arrow_pickler, hf_reader) -> Iterator["RecordBatch"]:
+    for batch in batches:
+        for record in batch.to_pylist():
+            partition = arrow_pickler.loads(record)
+            yield from hf_reader.read(partition)
+
+_orig_reader_format = _DataFrameReader.format
 
 def _new_format(self: _DataFrameReader, source: str) -> _DataFrameReader:
     self._format = source
-    return _orig_format(self, source)
+    return _orig_reader_format(self, source)
 
 _DataFrameReader.format = _new_format
 
-_orig_option = _DataFrameReader.option
+_orig_reader_option = _DataFrameReader.option
 
 def _new_option(self: _DataFrameReader, key, value) -> _DataFrameReader:
     if not hasattr(self, "_options"):
         self._options = {}
     self._options[key] = value
-    return _orig_option(self, key, value)
+    return _orig_reader_option(self, key, value)
 
-_DataFrameReader.option = _orig_option
+_DataFrameReader.option = _new_option
 
-_orig_options = _DataFrameReader.options
+_orig_reader_options = _DataFrameReader.options
 
 def _new_options(self: _DataFrameReader, **options) -> _DataFrameReader:
     if not hasattr(self, "_options"):
         self._options = {}
     self._options.update(options)
-    return _orig_options(self, **options)
+    return _orig_reader_options(self, **options)
 
-_DataFrameReader.options = _orig_options
+_DataFrameReader.options = _new_options
 
-_orig_load = _DataFrameReader.load
-
-class _unpack_dict(dict):
-    ...
-
-class _ArrowPipe:
-
-    def __init__(self, *fns):
-        self.fns = fns
-
-    def __call__(self, iterator: Iterator["pa.RecordBatch"]):
-        for record_batch in iterator:
-            for data in record_batch.to_pylist():
-                for fn in self.fns:
-                    data = fn(**data) if isinstance(data, _unpack_dict) else fn(data)
-                yield from data
+_orig_reader_load = _DataFrameReader.load
 
 def _new_load(
     self: _DataFrameReader,
@@ -93,21 +101,87 @@ def _new_load(
     **options: "OptionalPrimitiveType",
 ) -> "DataFrame":
     if (format or getattr(self, "_format", None)) == "huggingface":
-        from dataclasses import asdict
+        from functools import partial
         from pyspark_huggingface.huggingface import HuggingFaceDatasets
 
         source = HuggingFaceDatasets(options={**getattr(self, "_options", {}), **options, "path": path}).get_source()
         schema = schema or source.schema()
-        reader = source.reader(schema)
-        partitions = reader.partitions()
-        partition_cls = type(partitions[0])
-        rdd = self._spark.sparkContext.parallelize([asdict(partition) for partition in partitions], len(partitions))
+        hf_reader = source.reader(schema)
+        partitions = hf_reader.partitions()
+        arrow_pickler = _ArrowPickler("partition")
+        rdd = self._spark.sparkContext.parallelize([arrow_pickler.dumps(partition) for partition in partitions], len(partitions))
         df = self._spark.createDataFrame(rdd)
-        return df.mapInArrow(_ArrowPipe(_unpack_dict, partition_cls, reader.read), schema)
+        return df.mapInArrow(partial(_read_in_arrow, arrow_pickler=arrow_pickler, hf_reader=hf_reader), schema)
 
-    return _orig_load(self, path=path, format=format, schema=schema, **options)
+    return _orig_reader_load(self, path=path, format=format, schema=schema, **options)
 
 _DataFrameReader.load = _new_load
 
+# Writer
+
+def _write_in_arrow(batches: Iterator["RecordBatch"], arrow_pickler, hf_writer) -> Iterator["RecordBatch"]:
+    from pyarrow import RecordBatch
+
+    commit_message = hf_writer.write(batches)
+    yield RecordBatch.from_pylist([arrow_pickler.dumps(commit_message)])
+
+_orig_writer_format = _DataFrameWriter.format
+
+def _new_format(self: _DataFrameWriter, source: str) -> _DataFrameWriter:
+    self._format = source
+    return _orig_writer_format(self, source)
+
+_DataFrameWriter.format = _new_format
+
+_orig_writer_option = _DataFrameWriter.option
+
+def _new_option(self: _DataFrameWriter, key, value) -> _DataFrameWriter:
+    if not hasattr(self, "_options"):
+        self._options = {}
+    self._options[key] = value
+    return _orig_writer_option(self, key, value)
+
+_DataFrameWriter.option = _new_option
+
+_orig_writer_options = _DataFrameWriter.options
+
+def _new_options(self: _DataFrameWriter, **options) -> _DataFrameWriter:
+    if not hasattr(self, "_options"):
+        self._options = {}
+    self._options.update(options)
+    return _orig_writer_options(self, **options)
+
+_DataFrameWriter.options = _new_options
+
+_orig_writer_save = _DataFrameWriter.save
+
+def _new_save(
+    self: _DataFrameWriter,
+    path: Optional["PathOrPaths"] = None,
+    format: Optional[str] = None,
+    mode: Optional[str] = None,
+    partitionBy: Optional[Union[str, List[str]]] = None,
+    **options: "OptionalPrimitiveType",
+) -> None:
+    if (format or getattr(self, "_format", None)) == "huggingface":
+        from functools import partial
+        from pyspark_huggingface.huggingface import HuggingFaceDatasets
+
+        sink = HuggingFaceDatasets(options={**getattr(self, "_options", {}), **options, "path": path}).get_sink()
+        schema = self._df.schema
+        mode = mode or options.pop("mode", None)
+        hf_writer = sink.writer(schema, overwrite=(mode == "overwrite"))
+        arrow_pickler = _ArrowPickler("commit_message")
+        commit_messages = self._df.mapInArrow(partial(_write_in_arrow, arrow_pickler=arrow_pickler, hf_writer=hf_writer), arrow_pickler.schema).collect()
+        commit_messages = [arrow_pickler.loads(commit_message) for commit_message in commit_messages]
+        hf_writer.commit(commit_messages)
+        return
+
+    return _orig_writer_save(self, path=path, format=format, mode=mode, partitionBy=partitionBy, **options)
+
+_DataFrameWriter.save = _new_save
+
+# Log only in driver
+
 if not os.environ.get("SPARK_ENV_LOADED"):
     logging.getLogger(__name__).warning(f"huggingface datasource enabled for pyspark {pyspark.__version__} (backport from pyspark 4)")
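
For context, a minimal usage sketch of the patched entry points. It assumes this backport module has been imported (for example via `import pyspark_huggingface`) so the monkeypatches above are active; the repo ids and app name below are placeholders, not values from this patch.

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("hf-backport-demo").getOrCreate()

# Read path: _new_load pickles the source partitions into a one-column binary
# DataFrame and streams the actual rows back as Arrow batches via mapInArrow.
df = spark.read.format("huggingface").load("username/my_dataset")
df.show(5)

# Write path: _new_save runs the sink writer inside mapInArrow, collects the
# pickled commit messages on the driver, and commits them in a single call.
df.write.format("huggingface").save("username/my_new_dataset", mode="overwrite")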
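
The trick shared by the reader and writer paths is _ArrowPickler: arbitrary Python objects (input partitions, commit messages) are pickled into a single binary column so they can travel through a DataFrame and mapInArrow as ordinary Arrow data. A standalone sketch of that round trip, with an illustrative payload rather than a real InputPartition:

import pickle

from pyarrow import RecordBatch

key = "partition"
payload = {"split": "train", "shard": 3}             # stand-in for an InputPartition
record = {key: pickle.dumps(payload)}                # shape produced by _ArrowPickler.dumps
batch = RecordBatch.from_pylist([record])            # same call used in _write_in_arrow
restored = pickle.loads(batch.to_pylist()[0][key])   # what _ArrowPickler.loads does
assert restored == payload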