@@ -50,6 +50,7 @@ def read_csv(
             max_result_size=None,
             header="infer",
             names=None,
+            usecols=None,
             dtype=None,
             sep=",",
             lineterminator="\n",
@@ -71,6 +72,7 @@ def read_csv(
         :param max_result_size: Max number of bytes on each request to S3
         :param header: Same as pandas.read_csv()
         :param names: Same as pandas.read_csv()
+        :param usecols: Same as pandas.read_csv()
         :param dtype: Same as pandas.read_csv()
         :param sep: Same as pandas.read_csv()
         :param lineterminator: Same as pandas.read_csv()
@@ -96,6 +98,7 @@ def read_csv(
                 max_result_size=max_result_size,
                 header=header,
                 names=names,
+                usecols=usecols,
                 dtype=dtype,
                 sep=sep,
                 lineterminator=lineterminator,
@@ -113,6 +116,7 @@ def read_csv(
                 key_path=key_path,
                 header=header,
                 names=names,
+                usecols=usecols,
                 dtype=dtype,
                 sep=sep,
                 lineterminator=lineterminator,
@@ -133,6 +137,7 @@ def _read_csv_iterator(
             max_result_size=200_000_000,  # 200 MB
             header="infer",
             names=None,
+            usecols=None,
             dtype=None,
             sep=",",
             lineterminator="\n",
@@ -155,6 +160,7 @@ def _read_csv_iterator(
         :param max_result_size: Max number of bytes on each request to S3
         :param header: Same as pandas.read_csv()
         :param names: Same as pandas.read_csv()
+        :param usecols: Same as pandas.read_csv()
         :param dtype: Same as pandas.read_csv()
         :param sep: Same as pandas.read_csv()
         :param lineterminator: Same as pandas.read_csv()
@@ -182,6 +188,7 @@ def _read_csv_iterator(
                 key_path=key_path,
                 header=header,
                 names=names,
+                usecols=usecols,
                 dtype=dtype,
                 sep=sep,
                 lineterminator=lineterminator,
@@ -235,6 +242,7 @@ def _read_csv_iterator(
                     StringIO(body[:last_char].decode("utf-8")),
                     header=header,
                     names=names,
+                    usecols=usecols,
                     sep=sep,
                     quotechar=quotechar,
                     quoting=quoting,
@@ -353,6 +361,7 @@ def _read_csv_once(
             key_path,
             header="infer",
             names=None,
+            usecols=None,
             dtype=None,
             sep=",",
             lineterminator="\n",
@@ -374,6 +383,7 @@ def _read_csv_once(
         :param key_path: S3 key path (W/o bucket)
         :param header: Same as pandas.read_csv()
         :param names: Same as pandas.read_csv()
+        :param usecols: Same as pandas.read_csv()
         :param dtype: Same as pandas.read_csv()
         :param sep: Same as pandas.read_csv()
         :param lineterminator: Same as pandas.read_csv()
@@ -395,6 +405,7 @@ def _read_csv_once(
             buff,
             header=header,
             names=names,
+            usecols=usecols,
             sep=sep,
             quotechar=quotechar,
             quoting=quoting,
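The hunks above only thread the new usecols argument through read_csv, _read_csv_iterator, and _read_csv_once and hand it to pandas.read_csv unchanged, as the docstring entries state. A minimal usage sketch, assuming the session-based awswrangler.Session().pandas.read_csv entry point and hypothetical bucket, key, and column names:

import awswrangler

session = awswrangler.Session()

# usecols is forwarded as-is to pandas.read_csv, so only the listed
# columns are parsed out of the CSV object fetched from S3.
df = session.pandas.read_csv(
    path="s3://my-bucket/my-key.csv",  # hypothetical path
    usecols=["id", "name"],            # hypothetical column names
    sep=",",
)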
@@ -714,7 +725,8 @@ def _data_to_s3_dataset_writer(dataframe,
                                    session_primitives,
                                    file_format,
                                    cast_columns=None,
-                                   extra_args=None):
+                                   extra_args=None,
+                                   isolated_dataframe=False):
         objects_paths = []
         if not partition_cols:
             object_path = Pandas._data_to_s3_object_writer(
@@ -725,7 +737,8 @@ def _data_to_s3_dataset_writer(dataframe,
                 session_primitives=session_primitives,
                 file_format=file_format,
                 cast_columns=cast_columns,
-                extra_args=extra_args)
+                extra_args=extra_args,
+                isolated_dataframe=isolated_dataframe)
             objects_paths.append(object_path)
         else:
             for keys, subgroup in dataframe.groupby(partition_cols):
@@ -744,7 +757,8 @@ def _data_to_s3_dataset_writer(dataframe,
                     session_primitives=session_primitives,
                     file_format=file_format,
                     cast_columns=cast_columns,
-                    extra_args=extra_args)
+                    extra_args=extra_args,
+                    isolated_dataframe=True)
                 objects_paths.append(object_path)
         return objects_paths
 
@@ -769,7 +783,8 @@ def _data_to_s3_dataset_writer_remote(send_pipe,
                 session_primitives=session_primitives,
                 file_format=file_format,
                 cast_columns=cast_columns,
-                extra_args=extra_args))
+                extra_args=extra_args,
+                isolated_dataframe=True))
         send_pipe.close()
 
     @staticmethod
@@ -780,7 +795,8 @@ def _data_to_s3_object_writer(dataframe,
                                   session_primitives,
                                   file_format,
                                   cast_columns=None,
-                                  extra_args=None):
+                                  extra_args=None,
+                                  isolated_dataframe=False):
         fs = s3.get_fs(session_primitives=session_primitives)
         fs = pyarrow.filesystem._ensure_filesystem(fs)
         s3.mkdir_if_not_exists(fs, path)
@@ -803,12 +819,14 @@ def _data_to_s3_object_writer(dataframe,
             raise UnsupportedFileFormat(file_format)
         object_path = "/".join([path, outfile])
         if file_format == "parquet":
-            Pandas.write_parquet_dataframe(dataframe=dataframe,
-                                           path=object_path,
-                                           preserve_index=preserve_index,
-                                           compression=compression,
-                                           fs=fs,
-                                           cast_columns=cast_columns)
+            Pandas.write_parquet_dataframe(
+                dataframe=dataframe,
+                path=object_path,
+                preserve_index=preserve_index,
+                compression=compression,
+                fs=fs,
+                cast_columns=cast_columns,
+                isolated_dataframe=isolated_dataframe)
         elif file_format == "csv":
             Pandas.write_csv_dataframe(dataframe=dataframe,
                                        path=object_path,
@@ -848,15 +866,17 @@ def write_csv_dataframe(dataframe,
 
     @staticmethod
     def write_parquet_dataframe(dataframe, path, preserve_index, compression,
-                                fs, cast_columns):
+                                fs, cast_columns, isolated_dataframe):
         if not cast_columns:
             cast_columns = {}
 
         # Casting on Pandas
+        casted_in_pandas = []
         dtypes = copy.deepcopy(dataframe.dtypes.to_dict())
         for name, dtype in dtypes.items():
             if str(dtype) == "Int64":
                 dataframe[name] = dataframe[name].astype("float64")
+                casted_in_pandas.append(name)
                 cast_columns[name] = "bigint"
                 logger.debug(f"Casting column {name} Int64 to float64")
 
@@ -885,6 +905,11 @@ def write_parquet_dataframe(dataframe, path, preserve_index, compression,
                             coerce_timestamps="ms",
                             flavor="spark")
 
+        # Casting back on Pandas if necessary
+        if isolated_dataframe is False:
+            for col in casted_in_pandas:
+                dataframe[col] = dataframe[col].astype("Int64")
+
     def to_redshift(
             self,
             dataframe,
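The second group of hunks adds the isolated_dataframe flag so write_parquet_dataframe knows whether it is mutating the caller's own dataframe or a throwaway copy (a partition subgroup or a dataframe already handed to a worker process, where the flag is passed as True). Nullable Int64 columns are downcast to float64 so pyarrow can serialize them, recorded in cast_columns as "bigint", and restored afterwards only when the dataframe is not isolated, keeping the Parquet write free of side effects for the caller. A minimal standalone pandas sketch of that round-trip (illustrative only, no S3 or pyarrow involved):

import pandas as pd

df = pd.DataFrame({"col0": pd.array([1, None, 3], dtype="Int64")})

casted_in_pandas = []
for name, dtype in df.dtypes.to_dict().items():
    if str(dtype) == "Int64":
        # pyarrow-friendly representation used while the Parquet file is written
        df[name] = df[name].astype("float64")
        casted_in_pandas.append(name)

# ... the Parquet object would be written to S3 at this point ...

# isolated_dataframe=False: the caller handed over its own dataframe,
# so restore the nullable Int64 dtype before returning.
for col in casted_in_pandas:
    df[col] = df[col].astype("Int64")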
0 commit comments