from dataclasses import dataclass
from typing import Dict, Optional

from pyspark.sql.datasource import DataSource, DataSourceReader
from pyspark.sql.types import StringType, StructField, StructType


@dataclass
class Sheet:
    """
    A dataclass to identify a Google Sheets document.

    Attributes
    ----------
    spreadsheet_id : str
        The ID of the Google Sheets document.
    sheet_id : str, optional
        The ID of the worksheet within the document.
    """

    spreadsheet_id: str
    sheet_id: Optional[str] = None  # if None, the first sheet is used

    @classmethod
    def from_url(cls, url: str) -> "Sheet":
        """
        Converts a Google Sheets URL to a Sheet object.
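
        For example (a doctest sketch; `abc123` is a made-up placeholder
        spreadsheet ID, not a real document):

        >>> Sheet.from_url("https://docs.google.com/spreadsheets/d/abc123/edit?gid=0")
        Sheet(spreadsheet_id='abc123', sheet_id='0')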
        """
        from urllib.parse import parse_qs, urlparse

        parsed = urlparse(url)
        if parsed.netloc != "docs.google.com" or not parsed.path.startswith(
            "/spreadsheets/d/"
        ):
            raise ValueError("URL is not a Google Sheets URL")
        qs = parse_qs(parsed.query)
        spreadsheet_id = parsed.path.split("/")[3]
        if "gid" in qs:
            sheet_id = qs["gid"][0]
        else:
            sheet_id = None
        return cls(spreadsheet_id, sheet_id)

    def get_query_url(self, query: Optional[str] = None) -> str:
        """
        Gets the query URL that returns the results of the query as a CSV file.

        If no query is provided, returns the entire sheet.
        If sheet ID is None, uses the first sheet.

        See https://developers.google.com/chart/interactive/docs/querylanguage
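
        For example, with placeholder IDs (this only builds the URL string;
        no request is made):

        >>> Sheet("abc123", "0").get_query_url()
        'https://docs.google.com/spreadsheets/d/abc123/gviz/tq?tqx=out%3Acsv&gid=0'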
        """
        from urllib.parse import urlencode

        path = f"https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/gviz/tq"
        url_query = {"tqx": "out:csv"}
        if self.sheet_id:
            url_query["gid"] = self.sheet_id
        if query:
            url_query["tq"] = query
        return f"{path}?{urlencode(url_query)}"


@dataclass
class Parameters:
    sheet: Sheet
    has_header: bool


class GoogleSheetsDataSource(DataSource):
    """
    A DataSource for reading tables from public Google Sheets.

    Name: `googlesheets`

    Schema: By default, all columns are treated as strings and the header row defines the column names.

    Options
    -------
    - `url`: The URL of the Google Sheets document.
    - `path`: The ID of the Google Sheets document.
    - `sheet_id`: The ID of the worksheet within the document.
    - `has_header`: Whether the sheet has a header row. Default is `true`.

    Either `url` or `path` must be specified, but not both.

    Examples
    --------
    Register the data source.

    >>> from pyspark_datasources import GoogleSheetsDataSource
    >>> spark.dataSource.register(GoogleSheetsDataSource)

    Load data from a public Google Sheets document using `path` and optional `sheet_id`.

    >>> spreadsheet_id = "10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0"
    >>> spark.read.format("googlesheets").options(sheet_id="0").load(spreadsheet_id).show()
    +-------+---------+---------+-------+
    |country| latitude|longitude|   name|
    +-------+---------+---------+-------+
    |     AD|42.546245| 1.601554|Andorra|
    |    ...|      ...|      ...|    ...|
    +-------+---------+---------+-------+

    Load data from a public Google Sheets document using `url`.

    >>> url = "https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/edit?gid=0#gid=0"
    >>> spark.read.format("googlesheets").options(url=url).load().show()
    +-------+---------+---------+-------+
    |country| latitude|longitude|   name|
    +-------+---------+---------+-------+
    |     AD|42.546245| 1.601554|Andorra|
    |    ...|      ...|      ...|    ...|
    +-------+---------+---------+-------+

    Specify a custom schema.

    >>> schema = "id string, lat double, long double, name string"
    >>> spark.read.format("googlesheets").schema(schema).options(url=url).load().show()
    +---+---------+--------+-------+
    | id|      lat|    long|   name|
    +---+---------+--------+-------+
    | AD|42.546245|1.601554|Andorra|
    |...|      ...|     ...|    ...|
    +---+---------+--------+-------+

    Treat the first row as data instead of a header.

    >>> schema = "c1 string, c2 string, c3 string, c4 string"
    >>> spark.read.format("googlesheets").schema(schema).options(url=url, has_header="false").load().show()
    +-------+---------+---------+-------+
    |     c1|       c2|       c3|     c4|
    +-------+---------+---------+-------+
    |country| latitude|longitude|   name|
    |     AD|42.546245| 1.601554|Andorra|
    |    ...|      ...|      ...|    ...|
    +-------+---------+---------+-------+
    """

    @classmethod
    def name(cls):
        return "googlesheets"

    def __init__(self, options: Dict[str, str]):
        if "url" in options:
            sheet = Sheet.from_url(options.pop("url"))
        elif "path" in options:
            sheet = Sheet(options.pop("path"), options.pop("sheet_id", None))
        else:
            raise ValueError(
                "You must specify either `url` or `path` (spreadsheet ID)."
            )
        has_header = options.pop("has_header", "true").lower() == "true"
        self.parameters = Parameters(sheet, has_header)

    def schema(self) -> StructType:
        if not self.parameters.has_header:
            raise ValueError("Custom schema is required when `has_header` is false")

        import pandas as pd

        # Read the schema from the first row of the sheet
        df = pd.read_csv(self.parameters.sheet.get_query_url("select * limit 1"))
        return StructType([StructField(col, StringType()) for col in df.columns])

    def reader(self, schema: StructType) -> DataSourceReader:
        return GoogleSheetsReader(self.parameters, schema)


class GoogleSheetsReader(DataSourceReader):

    def __init__(self, parameters: Parameters, schema: StructType):
        self.parameters = parameters
        self.schema = schema

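    # Note: `partitions()` is not overridden, so Spark falls back to the
    # default behavior of reading the whole sheet as a single partition.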
    def read(self, partition):
        from urllib.request import urlopen

        from pyarrow import csv
        from pyspark.sql.pandas.types import to_arrow_schema

        # Specify column types based on the schema
        convert_options = csv.ConvertOptions(
            column_types=to_arrow_schema(self.schema),
        )
        read_options = csv.ReadOptions(
            column_names=self.schema.fieldNames(),  # Rename columns
            skip_rows=(
                1 if self.parameters.has_header else 0  # Skip header row if present
            ),
        )
        with urlopen(self.parameters.sheet.get_query_url()) as file:
            yield from csv.read_csv(
                file, read_options=read_options, convert_options=convert_options
            ).to_batches()
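

if __name__ == "__main__":
    # A minimal smoke-test sketch: assumes a local PySpark 4.0+ installation
    # (Python data source API) and network access to Google Sheets. The
    # spreadsheet ID is the public demo sheet from the class docstring above.
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    spark.dataSource.register(GoogleSheetsDataSource)
    spark.read.format("googlesheets").load(
        "10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0"
    ).show()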