import io
import json
import time

from dataclasses import dataclass
from typing import Dict, List

from pyspark.sql.types import StructType
from pyspark.sql.datasource import DataSource, DataSourceWriter, WriterCommitMessage


class SimpleJsonDataSource(DataSource):
    """
    A simple JSON writer for writing data to Databricks DBFS.

    Examples
    --------

    >>> import pyspark.sql.functions as sf
    >>> df = spark.range(0, 10, 1, 2).withColumn("value", sf.expr("concat('value_', id)"))

    Register the data source.

    >>> from pyspark_datasources import SimpleJsonDataSource
    >>> spark.dataSource.register(SimpleJsonDataSource)

    Append the DataFrame to a DBFS path as JSON files.

    >>> (
    ...     df.write.format("simplejson")
    ...     .mode("append")
    ...     .option("databricks_url", "https://your-databricks-instance.cloud.databricks.com")
    ...     .option("databricks_token", "your-token")
    ...     .save("/path/to/output")
    ... )

    Overwrite the DataFrame to a DBFS path as JSON files.

    >>> (
    ...     df.write.format("simplejson")
    ...     .mode("overwrite")
    ...     .option("databricks_url", "https://your-databricks-instance.cloud.databricks.com")
    ...     .option("databricks_token", "your-token")
    ...     .save("/path/to/output")
    ... )
    """
    @classmethod
    def name(cls) -> str:
        return "simplejson"

    def writer(self, schema: StructType, overwrite: bool):
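        # Called once on the driver; Spark serializes the returned writer and
        # invokes its `write` method on executors, once per input partition.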
        return SimpleJsonWriter(schema, self.options, overwrite)


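# Each successful task sends one commit message back to the driver; here it
# carries the DBFS path of the file that the task wrote.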
@dataclass
class CommitMessage(WriterCommitMessage):
    output_path: str


class SimpleJsonWriter(DataSourceWriter):
    def __init__(self, schema: StructType, options: Dict, overwrite: bool):
        self.overwrite = overwrite
        self.databricks_url = options.get("databricks_url")
        self.databricks_token = options.get("databricks_token")
        if not self.databricks_url or not self.databricks_token:
            raise ValueError("Databricks URL and token must be specified")
        # `path` is populated from the argument passed to `save(...)`.
        self.path = options.get("path")
        if not self.path:
            raise ValueError("You must specify an output path")

    def write(self, iterator):
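        # Runs on an executor for one partition of the input DataFrame and
        # returns a commit message that Spark collects on the driver.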
        # Important: always import non-serializable libraries inside the
        # `write` method. The writer instance itself must be picklable so it
        # can be shipped to executors, so these clients cannot live on `self`.
        from pyspark import TaskContext
        from databricks.sdk import WorkspaceClient

        # Consume all input rows and dump them as JSON.
        rows = [row.asDict() for row in iterator]
        json_data = json.dumps(rows)
        f = io.BytesIO(json_data.encode("utf-8"))

        # Use the task attempt ID plus a timestamp to build a unique file name,
        # so retried or speculative tasks never collide on the same path.
        context = TaskContext.get()
        task_id = context.taskAttemptId()
        file_path = f"{self.path}/{task_id}_{time.time_ns()}.json"

        # Upload to DBFS.
        w = WorkspaceClient(host=self.databricks_url, token=self.databricks_token)
        w.dbfs.upload(file_path, f)

        return CommitMessage(output_path=file_path)

    def commit(self, messages: List[CommitMessage]):
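        # Runs once on the driver after all write tasks have succeeded.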
        from databricks.sdk import WorkspaceClient

        w = WorkspaceClient(host=self.databricks_url, token=self.databricks_token)
        paths = [message.output_path for message in messages]

        if self.overwrite:
            # Remove all files in the output directory except the newly written ones.
            for file in w.dbfs.list(self.path):
                if file.path not in paths:
                    print(f"[Overwrite] Removing file {file.path}")
                    w.dbfs.delete(file.path)

        # Write a _SUCCESS marker file to signal that the write completed.
        file_path = f"{self.path}/_SUCCESS"
        f = io.BytesIO(b"success")
        w.dbfs.upload(file_path, f, overwrite=True)

    def abort(self, messages: List[CommitMessage]):
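        # Runs on the driver if any write task fails.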
        from databricks.sdk import WorkspaceClient

        w = WorkspaceClient(host=self.databricks_url, token=self.databricks_token)
        # Clean up the newly written files. A message may be None if its task
        # failed before producing a commit message.
        for message in messages:
            if message is not None:
                print(f"[Abort] Removing partially written file: {message.output_path}")
                w.dbfs.delete(message.output_path)