@@ -2,7 +2,7 @@
 import pyarrow as pa
 
 from dataclasses import dataclass
-from typing import Iterator
+from typing import Iterator, List
 
 from pyspark.sql.datasource import DataSource, DataSourceArrowWriter, WriterCommitMessage
 from pyspark.sql.pandas.types import to_arrow_schema
@@ -22,7 +22,7 @@ class LanceSink(DataSource):
 
     Create a Spark dataframe with 2 partitions:
 
-    >>> df = spark.range(0, 3, 1, 2)
+    >>> df = spark.createDataFrame([(1, "a"), (2, "b"), (3, "c")], schema="id int, value string")
 
     Save the dataframe in lance format:
 
@@ -58,7 +58,7 @@ def writer(self, schema, overwrite: bool):
 
 @dataclass
 class LanceCommitMessage(WriterCommitMessage):
-    fragment: lance.FragmentMetadata
+    fragments: List[lance.FragmentMetadata]
 
 
 class LanceWriter(DataSourceArrowWriter):
@@ -78,18 +78,12 @@ def _get_read_version(self):
         return None
 
     def write(self, iterator: Iterator[pa.RecordBatch]):
-        from pyspark import TaskContext
-
-        context = TaskContext.get()
-        assert context is not None, "Unable to get TaskContext"
-        task_id = context.taskAttemptId()
-
         reader = pa.RecordBatchReader.from_batches(self.arrow_schema, iterator)
-        fragment = lance.LanceFragment.create(self.uri, reader, fragment_id=task_id, schema=self.arrow_schema)
-        return LanceCommitMessage(fragment=fragment)
+        fragments = lance.fragment.write_fragments(reader, self.uri, schema=self.arrow_schema)
+        return LanceCommitMessage(fragments=fragments)
 
     def commit(self, messages):
-        fragments = [msg.fragment for msg in messages]
+        fragments = [fragment for msg in messages for fragment in msg.fragments]
         if self.read_version:
             # This means the dataset already exists.
             op = lance.LanceOperation.Append(fragments)
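Note on the change: `lance.fragment.write_fragments` may split one task's input into several fragments and returns a `List[FragmentMetadata]`, which is why the commit message now carries a list and the updated `commit` flattens the per-task lists. Below is a minimal standalone sketch of the same two-phase write/commit pattern outside Spark; the path, toy schema, and `task_write` helper are illustrative assumptions, and the `Overwrite`/`LanceDataset.commit` tail is also an assumption, since the diff is truncated after the `Append` line.

```python
import lance
import pyarrow as pa
from lance.fragment import write_fragments

uri = "/tmp/demo.lance"  # hypothetical dataset location
schema = pa.schema([pa.field("id", pa.int32()), pa.field("value", pa.string())])

# "Executor" side: each task writes its partition into the dataset
# directory and returns the resulting fragment metadata (a list, because
# write_fragments may produce more than one fragment per task).
def task_write(rows):
    table = pa.table({"id": pa.array([r[0] for r in rows], pa.int32()),
                      "value": [r[1] for r in rows]}, schema=schema)
    return write_fragments(table, uri, schema=schema)

# Two simulated tasks, mirroring the two-partition doctest above.
messages = [task_write([(1, "a"), (2, "b")]), task_write([(3, "c")])]

# "Driver" side: flatten the per-task fragment lists, exactly as the
# updated commit() does with its nested comprehension.
fragments = [frag for msg in messages for frag in msg]

# Creating a fresh dataset commits an Overwrite operation; appending to
# an existing dataset would instead use lance.LanceOperation.Append(fragments)
# together with a read_version, as in the diff.
op = lance.LanceOperation.Overwrite(schema, fragments)
ds = lance.LanceDataset.commit(uri, op)
print(ds.count_rows())  # 3
```

Nothing is visible to readers until the single commit on the driver, so a failed or speculatively re-run task leaves only unreferenced data files rather than partial results.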