Commit e8f68cb

add lance sink example
1 parent 3e466fb commit e8f68cb

File tree

3 files changed: +105 −3 lines changed


README.md

Lines changed: 4 additions & 3 deletions

@@ -15,13 +15,13 @@ pip install pyspark-data-sources[all]
 
 ## Usage
 
-Install the pyspark 4.0 preview version: https://pypi.org/project/pyspark/4.0.0.dev1/
+Install the pyspark 4.0 [preview version](https://pypi.org/project/pyspark/4.0.0.dev2/)
 
 ```
-pip install "pyspark[connect]==4.0.0.dev1"
+pip install "pyspark[connect]==4.0.0.dev2"
 ```
 
-Or use Databricks Runtime 15.2 or above.
+Or use Databricks Runtime 15.4 LTS or above.
 
 Try the data sources!
 
@@ -48,6 +48,7 @@ We welcome and appreciate any contributions to enhance and expand the custom dat
 ## Development
 
 ```
+poetry install
 poetry shell
 ```
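Taken together, the updated instructions amount to the following minimal usage sketch, assuming `pyspark==4.0.0.dev2` and this package are installed; the local-mode session setup is an assumption for illustration, and `LanceSink` is the data source added in this commit:

```python
# Minimal usage sketch: install pyspark 4.0.0.dev2 as above, then register
# a custom data source and write with it. Local mode only (see lance.py notes).
from pyspark.sql import SparkSession
from pyspark_datasources import LanceSink

spark = SparkSession.builder.master("local[*]").getOrCreate()
spark.dataSource.register(LanceSink)  # registered under the short name "lance"

df = spark.range(0, 3, 1, 2)          # 3 rows spread over 2 partitions
df.write.format("lance").mode("append").save("/tmp/test_lance")
```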

pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -21,6 +21,7 @@ databricks-sdk = {version = "^0.28.0", optional = true}
 faker = ["faker"]
 datasets = ["datasets"]
 databricks = ["databricks-sdk"]
+lance = ["pylance"]
 all = ["faker", "datasets", "databricks"]
 
 [tool.poetry.group.dev.dependencies]
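The new `lance` extra maps to the `pylance` package (note that the `all` extra above does not include it). Assuming the extra ships in a published release, it would be installed like the other extras:

```
pip install "pyspark-data-sources[lance]"
```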

pyspark_datasources/lance.py

Lines changed: 100 additions & 0 deletions

@@ -0,0 +1,100 @@
import lance
import pyarrow as pa

from dataclasses import dataclass
from typing import Iterator
from pyspark.sql.datasource import DataSource, DataSourceArrowWriter, WriterCommitMessage
from pyspark.sql.pandas.types import to_arrow_schema


class LanceSink(DataSource):
    """
    Write a Spark DataFrame into Lance format: https://lancedb.github.io/lance/index.html

    Note: this requires a Spark master branch nightly build, which supports
    `DataSourceArrowWriter`.

    Examples
    --------
    Register the data source:

    >>> from pyspark_datasources import LanceSink
    >>> spark.dataSource.register(LanceSink)

    Create a Spark dataframe with 2 partitions:

    >>> df = spark.range(0, 3, 1, 2)

    Save the dataframe in Lance format:

    >>> df.write.format("lance").mode("append").save("/tmp/test_lance")

    This creates a Lance dataset directory:

        /tmp/test_lance
        _transactions  _versions  data

    Then you can use the Lance API to read the dataset:

    >>> import lance
    >>> ds = lance.LanceDataset("/tmp/test_lance")
    >>> ds.to_table().to_pandas()
       id
    0   0
    1   1
    2   2

    Notes
    -----
    - Currently this only works with Spark local mode. Cluster mode is not supported.
    """

    @classmethod
    def name(cls) -> str:
        return "lance"

    def writer(self, schema, overwrite: bool):
        if overwrite:
            raise Exception("Overwrite mode is not supported")
        if "path" not in self.options:
            raise Exception("Dataset URI must be specified when calling save()")
        return LanceWriter(schema, overwrite, self.options)


@dataclass
class LanceCommitMessage(WriterCommitMessage):
    fragment: lance.FragmentMetadata


class LanceWriter(DataSourceArrowWriter):
    def __init__(self, schema, overwrite, options):
        self.options = options
        self.schema = schema  # Spark schema (pyspark.sql.types.StructType)
        self.arrow_schema = to_arrow_schema(schema)  # Arrow schema (pa.Schema)
        self.uri = options["path"]
        assert not overwrite
        self.read_version = self._get_read_version()

    def _get_read_version(self):
        # Return the current dataset version, or None if the dataset
        # does not exist yet.
        try:
            ds = lance.LanceDataset(self.uri)
            return ds.version
        except Exception:
            return None

    def write(self, iterator: Iterator[pa.RecordBatch]):
        # Runs on each executor: write this partition's batches as a
        # Lance fragment, keyed by the task attempt ID.
        from pyspark import TaskContext

        context = TaskContext.get()
        assert context is not None, "Unable to get TaskContext"
        task_id = context.taskAttemptId()

        reader = pa.RecordBatchReader.from_batches(self.arrow_schema, iterator)
        fragment = lance.LanceFragment.create(
            self.uri, reader, fragment_id=task_id, schema=self.arrow_schema
        )
        return LanceCommitMessage(fragment=fragment)

    def commit(self, messages):
        # Runs on the driver: fold all fragments into a single operation.
        fragments = [msg.fragment for msg in messages]
        if self.read_version:
            # The dataset already exists; append the new fragments.
            op = lance.LanceOperation.Append(fragments)
        else:
            # Create a new dataset.
            schema = to_arrow_schema(self.schema)
            op = lance.LanceOperation.Overwrite(schema, fragments)
        lance.LanceDataset.commit(self.uri, op, read_version=self.read_version)
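The writer maps Spark's two-phase write protocol onto Lance's fragment API: each task's `write` turns its partition into a fragment and ships the fragment metadata back in a commit message, and the driver's `commit` folds all fragments into one atomic dataset operation. Below is a minimal sketch of that same fragment-then-commit flow without Spark, assuming `pylance` is installed; the path and data are made up for illustration:

```python
# Fragment-then-commit sketch using the same Lance calls as LanceWriter above.
import lance
import pyarrow as pa

uri = "/tmp/test_lance_manual"  # hypothetical path, illustration only
table = pa.table({"id": pa.array([0, 1, 2], type=pa.int64())})

# Step 1 (what each Spark task does): write data as a fragment.
# Nothing is visible to readers yet; only fragment metadata is returned.
fragment = lance.LanceFragment.create(uri, table, schema=table.schema)

# Step 2 (what the driver does): commit all fragments in one operation.
# Overwrite creates the dataset here because none exists; appending to an
# existing dataset would use lance.LanceOperation.Append(fragments) with
# read_version set to the version observed before the write.
op = lance.LanceOperation.Overwrite(table.schema, [fragment])
ds = lance.LanceDataset.commit(uri, op, read_version=None)
print(ds.to_table().to_pandas())
```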
