Skip to content

Commit 635f7d7

Browse files
committed
wrap methods
1 parent dcd0761 commit 635f7d7

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

pyspark_huggingface/compat/datasource.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class WriterCommitMessage:
3030
import logging
3131
import os
3232
import pickle
33+
from functools import wraps
3334

3435
from pyspark.sql.readwriter import DataFrameReader as _DataFrameReader, DataFrameWriter as _DataFrameWriter
3536

@@ -65,6 +66,7 @@ def _read_in_arrow(batches: Iterator["RecordBatch"], arrow_pickler, hf_reader) -
6566

6667
_orig_reader_format = _DataFrameReader.format
6768

69+
@wraps(_orig_reader_format)
6870
def _new_format(self: _DataFrameReader, source: str) -> _DataFrameReader:
6971
self._format = source
7072
return _orig_reader_format(self, source)
@@ -73,6 +75,7 @@ def _new_format(self: _DataFrameReader, source: str) -> _DataFrameReader:
7375

7476
_orig_reader_option = _DataFrameReader.option
7577

78+
@wraps(_orig_reader_option)
7679
def _new_option(self: _DataFrameReader, key, value) -> _DataFrameReader:
7780
if not hasattr(self, "_options"):
7881
self._options = {}
@@ -83,6 +86,7 @@ def _new_option(self: _DataFrameReader, key, value) -> _DataFrameReader:
8386

8487
_orig_reader_options = _DataFrameReader.options
8588

89+
@wraps(_orig_reader_options)
8690
def _new_options(self: _DataFrameReader, **options) -> _DataFrameReader:
8791
if not hasattr(self, "_options"):
8892
self._options = {}
@@ -93,6 +97,7 @@ def _new_options(self: _DataFrameReader, **options) -> _DataFrameReader:
9397

9498
_orig_reader_load = _DataFrameReader.load
9599

100+
@wraps(_orig_reader_load)
96101
def _new_load(
97102
self: _DataFrameReader,
98103
path: Optional["PathOrPaths"] = None,
@@ -127,6 +132,7 @@ def _write_in_arrow(batches: Iterator["RecordBatch"], arrow_pickler, hf_writer)
127132

128133
_orig_writer_format = _DataFrameWriter.format
129134

135+
@wraps(_orig_writer_format)
130136
def _new_format(self: _DataFrameWriter, source: str) -> _DataFrameWriter:
131137
self._format = source
132138
return _orig_writer_format(self, source)
@@ -135,6 +141,7 @@ def _new_format(self: _DataFrameWriter, source: str) -> _DataFrameWriter:
135141

136142
_orig_writer_option = _DataFrameWriter.option
137143

144+
@wraps(_orig_writer_option)
138145
def _new_option(self: _DataFrameWriter, key, value) -> _DataFrameWriter:
139146
if not hasattr(self, "_options"):
140147
self._options = {}
@@ -145,6 +152,7 @@ def _new_option(self: _DataFrameWriter, key, value) -> _DataFrameWriter:
145152

146153
_orig_writer_options = _DataFrameWriter.options
147154

155+
@wraps(_orig_writer_options)
148156
def _new_options(self: _DataFrameWriter, **options) -> _DataFrameWriter:
149157
if not hasattr(self, "_options"):
150158
self._options = {}
@@ -155,6 +163,7 @@ def _new_options(self: _DataFrameWriter, **options) -> _DataFrameWriter:
155163

156164
_orig_writer_save = _DataFrameWriter.save
157165

166+
@wraps(_orig_writer_save)
158167
def _new_save(
159168
self: _DataFrameWriter,
160169
path: Optional["PathOrPaths"] = None,

0 commit comments

Comments
 (0)