Skip to content

Commit c31aec0

Browse files
update lance
1 parent e8f68cb commit c31aec0

File tree

1 file changed

+6
-12
lines changed

1 file changed

+6
-12
lines changed

pyspark_datasources/lance.py

Lines changed: 6 additions & 12 deletions
Original file line number · Diff line number · Diff line change
@@ -2,7 +2,7 @@
 import pyarrow as pa
 
 from dataclasses import dataclass
-from typing import Iterator
+from typing import Iterator, List
 from pyspark.sql.datasource import DataSource, DataSourceArrowWriter, WriterCommitMessage
 from pyspark.sql.pandas.types import to_arrow_schema
 

@@ -22,7 +22,7 @@ class LanceSink(DataSource):
 
     Create a Spark dataframe with 2 partitions:
 
-    >>> df = spark.range(0, 3, 1, 2)
+    >>> df = spark.createDataFrame([(1, "a"), (2, "b"), (3, "c")], schema="id int, value string")
 
     Save the dataframe in lance format:
 
@@ -58,7 +58,7 @@ def writer(self, schema, overwrite: bool):
 
 @dataclass
 class LanceCommitMessage(WriterCommitMessage):
-    fragment: lance.FragmentMetadata
+    fragments: List[lance.FragmentMetadata]
 
 
 class LanceWriter(DataSourceArrowWriter):
@@ -78,18 +78,12 @@ def _get_read_version(self):
         return None
 
     def write(self, iterator: Iterator[pa.RecordBatch]):
-        from pyspark import TaskContext
-
-        context = TaskContext.get()
-        assert context is not None, "Unable to get TaskContext"
-        task_id = context.taskAttemptId()
-
         reader = pa.RecordBatchReader.from_batches(self.arrow_schema, iterator)
-        fragment = lance.LanceFragment.create(self.uri, reader, fragment_id=task_id, schema=self.arrow_schema)
-        return LanceCommitMessage(fragment=fragment)
+        fragments = lance.fragment.write_fragments(reader, self.uri, schema=self.arrow_schema)
+        return LanceCommitMessage(fragments=fragments)
 
     def commit(self, messages):
-        fragments = [msg.fragment for msg in messages]
+        fragments = [fragment for msg in messages for fragment in msg.fragments ]
         if self.read_version:
             # This means the dataset already exists.
             op = lance.LanceOperation.Append(fragments)

0 commit comments

Comments
 (0)