Skip to content

Commit 666bced

Browse files
committed
backport to 3.2
1 parent 635f7d7 commit 666bced

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

pyspark_huggingface/compat/datasource.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,15 +107,17 @@ def _new_load(
107107
) -> "DataFrame":
108108
if (format or getattr(self, "_format", None)) == "huggingface":
109109
from functools import partial
110+
from pyspark.sql import SparkSession
110111
from pyspark_huggingface.huggingface import HuggingFaceDatasets
111112

112113
source = HuggingFaceDatasets(options={**getattr(self, "_options", {}), **options, "path": path}).get_source()
113114
schema = schema or source.schema()
114115
hf_reader = source.reader(schema)
115116
partitions = hf_reader.partitions()
116117
arrow_pickler = _ArrowPickler("partition")
117-
rdd = self._spark.sparkContext.parallelize([arrow_pickler.dumps(partition) for partition in partitions], len(partitions))
118-
df = self._spark.createDataFrame(rdd)
118+
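spark = self._spark if isinstance(self._spark, SparkSession) else self.spark._sc # self._spark is a SQLContext rather than a SparkSession on older versions. NOTE(review): the fallback reads self.spark (not self._spark) and takes ._sc, yet the result is used below via .sparkContext and .createDataFrame; confirm the fallback is correct (possibly self._spark.sparkSession was intended)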
119+
rdd = spark.sparkContext.parallelize([arrow_pickler.dumps(partition) for partition in partitions], len(partitions))
120+
df = spark.createDataFrame(rdd)
119121
return df.mapInArrow(partial(_read_in_arrow, arrow_pickler=arrow_pickler, hf_reader=hf_reader), schema)
120122

121123
return _orig_reader_load(self, path=path, format=format, schema=schema, **options)

0 commit comments

Comments
 (0)