diff --git a/pyproject.toml b/pyproject.toml
index b269f67..a1c6dbf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,8 +11,9 @@ license = {text = "Apache License 2.0"}
 readme = "README.md"
 requires-python = ">=3.9"
 dependencies = [
-    "datasets>=3.2",
-    "huggingface-hub>=0.27.1",
+    "datasets>=4.0",
+    "huggingface-hub>=0.34.4",
+    "pyarrow>=21.0.0",
 ]
 
 [dependency-groups]
diff --git a/pyspark_huggingface/huggingface_sink.py b/pyspark_huggingface/huggingface_sink.py
index edef431..dea3157 100644
--- a/pyspark_huggingface/huggingface_sink.py
+++ b/pyspark_huggingface/huggingface_sink.py
@@ -21,6 +21,7 @@
 
 logger = logging.getLogger(__name__)
 
+
 class HuggingFaceSink(DataSource):
     """
     A DataSource for writing Spark DataFrames to HuggingFace Datasets.
@@ -125,8 +126,9 @@ def __init__(
         token: str,
         endpoint: Optional[str] = None,
         row_group_size: Optional[int] = None,
-        max_bytes_per_file=500_000_000,
-        max_operations_per_commit=100,
+        max_bytes_per_file: int = 500_000_000,
+        max_operations_per_commit: int = 100,
+        use_content_defined_chunking: bool = True,
         **kwargs,
     ):
         import uuid
@@ -144,6 +146,7 @@ def __init__(
         self.row_group_size = row_group_size
         self.max_bytes_per_file = max_bytes_per_file
         self.max_operations_per_commit = max_operations_per_commit
+        self.use_content_defined_chunking = use_content_defined_chunking
         self.kwargs = kwargs
 
         # Use a unique filename prefix to avoid conflicts with existing files
@@ -210,10 +213,9 @@ def flush(writer: pq.ParquetWriter):
             name = (
                 f"{self.prefix}-{self.uuid}-part-{partition_id}-{num_files}.parquet"
             )
             num_files += 1
-            parquet.seek(0)
             addition = CommitOperationAdd(
-                path_in_repo=name, path_or_fileobj=parquet
+                path_in_repo=name, path_or_fileobj=parquet.getvalue()
             )
             api.preupload_lfs_files(
                 repo_id=self.repo_id,
@@ -232,7 +234,14 @@ def flush(writer: pq.ParquetWriter):
         Limiting the size is necessary because we are writing them in memory.
         """
         while True:
-            with pq.ParquetWriter(parquet, schema, **self.kwargs) as writer:
+            with pq.ParquetWriter(
+                parquet,
+                schema=schema,
+                **{
+                    "use_content_defined_chunking": self.use_content_defined_chunking,
+                    **self.kwargs
+                }
+            ) as writer:
                 num_batches = 0
                 for batch in iterator:  # Start iterating from where we left off
                     writer.write_batch(batch, row_group_size=self.row_group_size)
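
Note: the use_content_defined_chunking option the patch forwards to pq.ParquetWriter is why pyproject.toml bumps to pyarrow>=21.0.0, the first release where ParquetWriter accepts that keyword. A minimal standalone sketch of the same write pattern, assuming pyarrow>=21.0.0; the names table and buffer are illustrative and not part of the patch:

    # Sketch: write a Parquet file into an in-memory buffer with
    # content-defined chunking enabled, mirroring the writer change above.
    # Assumes pyarrow>=21.0.0 (per the dependency bump in this patch).
    import io

    import pyarrow as pa
    import pyarrow.parquet as pq

    # Illustrative data, not from the patch.
    table = pa.table({"id": list(range(1000)), "value": [str(i) for i in range(1000)]})

    buffer = io.BytesIO()
    with pq.ParquetWriter(
        buffer,
        schema=table.schema,
        use_content_defined_chunking=True,
    ) as writer:
        writer.write_table(table)

    # buffer.getvalue() returns the finished Parquet bytes, analogous to the
    # parquet.getvalue() payload handed to CommitOperationAdd above.
    data = buffer.getvalue()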