 import pyarrow as pa

 from dataclasses import dataclass
-from typing import Iterator
+from typing import Iterator, List
 from pyspark.sql.datasource import DataSource, DataSourceArrowWriter, WriterCommitMessage
 from pyspark.sql.pandas.types import to_arrow_schema

@@ -22,7 +22,7 @@ class LanceSink(DataSource):

-    Create a Spark dataframe with 2 partitions:
+    Create a Spark dataframe:

-    >>> df = spark.range(0, 3, 1, 2)
+    >>> df = spark.createDataFrame([(1, "a"), (2, "b"), (3, "c")], schema="id int, value string")

     Save the dataframe in lance format:

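For orientation, here is a minimal end-to-end sketch of how this sink is driven, not part of the diff. It assumes `LanceSink.name()` returns `"lance"` and uses an illustrative save path; `spark.dataSource.register` is the PySpark 4.x Python Data Source registration API.

```python
# Minimal usage sketch; LanceSink is the class defined in this module.
# Assumes LanceSink.name() returns "lance" and pylance is installed.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
spark.dataSource.register(LanceSink)  # PySpark 4.x Python Data Source API

df = spark.createDataFrame([(1, "a"), (2, "b"), (3, "c")], schema="id int, value string")
df.write.format("lance").mode("append").save("/tmp/test_lance")  # path is illustrative
```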
@@ -58,7 +58,7 @@ def writer(self, schema, overwrite: bool):

 @dataclass
 class LanceCommitMessage(WriterCommitMessage):
-    fragment: lance.FragmentMetadata
+    fragments: List[lance.FragmentMetadata]


 class LanceWriter(DataSourceArrowWriter):
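The commit message now carries a list because `lance.fragment.write_fragments` can split one task's batches into several fragments and returns `List[FragmentMetadata]`. A standalone sketch of that call, assuming pylance is installed and using an illustrative URI:

```python
# Standalone sketch of the new write path; the URI is illustrative.
import pyarrow as pa
from lance.fragment import write_fragments

table = pa.table({"id": [1, 2, 3], "value": ["a", "b", "c"]})
# Writes the data as one or more uncommitted fragments and returns
# List[FragmentMetadata]; nothing is visible until a commit happens.
fragments = write_fragments(table, "/tmp/test_lance")
print(fragments)
```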
@@ -78,18 +78,12 @@ def _get_read_version(self):
         return None

     def write(self, iterator: Iterator[pa.RecordBatch]):
-        from pyspark import TaskContext
-
-        context = TaskContext.get()
-        assert context is not None, "Unable to get TaskContext"
-        task_id = context.taskAttemptId()
-
         reader = pa.RecordBatchReader.from_batches(self.arrow_schema, iterator)
-        fragment = lance.LanceFragment.create(self.uri, reader, fragment_id=task_id, schema=self.arrow_schema)
-        return LanceCommitMessage(fragment=fragment)
+        fragments = lance.fragment.write_fragments(reader, self.uri, schema=self.arrow_schema)
+        return LanceCommitMessage(fragments=fragments)

     def commit(self, messages):
-        fragments = [msg.fragment for msg in messages]
+        fragments = [fragment for msg in messages for fragment in msg.fragments]
         if self.read_version:
             # This means the dataset already exists.
             op = lance.LanceOperation.Append(fragments)
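The hunk ends before the actual commit call. A hedged sketch of how the collected fragments would plausibly be committed via `lance.LanceDataset.commit`; the `Overwrite` branch for the create case is an assumption, since the diff truncates here.

```python
import lance
import pyarrow as pa

# Hedged sketch of the commit step the hunk truncates: Append when the
# dataset already existed at write time, Overwrite (which also carries the
# schema) when it is being created. uri/schema/read_version mirror the
# writer's fields; the Overwrite branch is an assumption, not shown above.
def commit_fragments(uri: str, schema: pa.Schema, fragments, read_version):
    if read_version is not None:
        op = lance.LanceOperation.Append(fragments)
        lance.LanceDataset.commit(uri, op, read_version=read_version)
    else:
        op = lance.LanceOperation.Overwrite(schema, fragments)
        lance.LanceDataset.commit(uri, op)
```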