import lance
import pyarrow as pa

from dataclasses import dataclass
from typing import Iterator
from pyspark.sql.datasource import DataSource, DataSourceArrowWriter, WriterCommitMessage
from pyspark.sql.pandas.types import to_arrow_schema

class LanceSink(DataSource):
    """
    Write a Spark DataFrame into Lance format: https://lancedb.github.io/lance/index.html

    Note: this requires a Spark master branch nightly build, which adds support for
    `DataSourceArrowWriter`.

    Examples
    --------
    Register the data source:

    >>> from pyspark_datasources import LanceSink
    >>> spark.dataSource.register(LanceSink)

    Create a Spark DataFrame with 2 partitions:

    >>> df = spark.range(0, 3, 1, 2)

    Save the DataFrame in Lance format:

    >>> df.write.format("lance").mode("append").save("/tmp/test_lance")

    The resulting dataset directory `/tmp/test_lance` contains
    `_transactions`, `_versions`, and `data`.

    Then you can use the Lance API to read the dataset:

    >>> import lance
    >>> ds = lance.LanceDataset("/tmp/test_lance")
    >>> ds.to_table().to_pandas()
       id
    0   0
    1   1
    2   2

    Notes
    -----
    - Currently this only works in Spark local mode; cluster mode is not supported.
    """
    @classmethod
    def name(cls) -> str:
        return "lance"

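    # Called once on the driver to create the writer. Only append mode is
    # supported, and the dataset URI must be passed as the save() path.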
    def writer(self, schema, overwrite: bool):
        if overwrite:
            raise Exception("Overwrite mode is not supported")
        if "path" not in self.options:
            raise Exception("Dataset URI must be specified when calling save()")
        return LanceWriter(schema, overwrite, self.options)


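# Per-task commit message: each write task returns the metadata of the Lance
# fragment it created, so the driver can commit all fragments in one transaction.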
@dataclass
class LanceCommitMessage(WriterCommitMessage):
    fragment: lance.FragmentMetadata


class LanceWriter(DataSourceArrowWriter):
    def __init__(self, schema, overwrite, options):
        self.options = options
        self.schema = schema  # Spark schema (pyspark.sql.types.StructType)
        self.arrow_schema = to_arrow_schema(schema)  # Arrow schema (pa.Schema)
        self.uri = options["path"]
        assert not overwrite
        self.read_version = self._get_read_version()

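    # Returns the current version of the dataset at `self.uri`, or None if the
    # dataset does not exist yet; commit() uses this to choose between
    # appending to an existing dataset and creating a new one.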
    def _get_read_version(self):
        try:
            ds = lance.LanceDataset(self.uri)
            return ds.version
        except Exception:
            return None

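    # Runs on executors: each Spark task writes its partition of Arrow record
    # batches as a single Lance fragment, keyed by the task attempt id.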
    def write(self, iterator: Iterator[pa.RecordBatch]):
        from pyspark import TaskContext

        context = TaskContext.get()
        assert context is not None, "Unable to get TaskContext"
        task_id = context.taskAttemptId()

        reader = pa.RecordBatchReader.from_batches(self.arrow_schema, iterator)
        fragment = lance.LanceFragment.create(
            self.uri, reader, fragment_id=task_id, schema=self.arrow_schema
        )
        return LanceCommitMessage(fragment=fragment)

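    # Runs on the driver once all tasks succeed: collect the fragment metadata
    # from every task and commit them as a single Lance transaction.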
    def commit(self, messages):
        fragments = [msg.fragment for msg in messages]
        if self.read_version:
            # The dataset already exists, so append to it.
            op = lance.LanceOperation.Append(fragments)
        else:
            # Create a new dataset.
            schema = to_arrow_schema(self.schema)
            op = lance.LanceOperation.Overwrite(schema, fragments)
        lance.LanceDataset.commit(self.uri, op, read_version=self.read_version)