
Commit 0b49783

Add Google Sheets data source (#6)
* add Google Sheets data source
* support specifying schema
* add google sheets schema test
* add more comments
* add docs
* add failure tests
* add tests for custom schema errors
* update API to support skipping header and pass spreadsheet id as path
1 parent c31aec0 commit 0b49783

File tree

6 files changed (+312 −8 lines)

docs/datasources/googlesheets.md

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
# GoogleSheetsDataSource

::: pyspark_datasources.googlesheets.GoogleSheetsDataSource

docs/index.md

Lines changed: 8 additions & 7 deletions
@@ -28,10 +28,11 @@ spark.read.format("github").load("apache/spark").show()

 ## Data Sources

-| Data Source | Short Name | Description | Dependencies |
-|-----------------------------------------------------|---------------|---------------------------------------------|-----------------|
-| [GithubDataSource](./datasources/github.md) | `github` | Read pull requests from a Github repository | None |
-| [FakeDataSource](./datasources/fake.md) | `fake` | Generate fake data using the `Faker` library | `faker` |
-| [HuggingFaceDatasets](./datasources/huggingface.md) | `huggingface` | Read datasets from the HuggingFace Hub | `datasets` |
-| [StockDataSource](./datasources/stock.md) | `stock` | Read stock data from Alpha Vantage | None |
-| [SimpleJsonDataSource](./datasources/simplejson.md) | `simplejson` | Read JSON data from a file | `databricks-sdk`|
+| Data Source                                              | Short Name     | Description                                       | Dependencies     |
+| -------------------------------------------------------- | -------------- | ------------------------------------------------- | ---------------- |
+| [GithubDataSource](./datasources/github.md)              | `github`       | Read pull requests from a Github repository       | None             |
+| [FakeDataSource](./datasources/fake.md)                  | `fake`         | Generate fake data using the `Faker` library      | `faker`          |
+| [HuggingFaceDatasets](./datasources/huggingface.md)      | `huggingface`  | Read datasets from the HuggingFace Hub            | `datasets`       |
+| [StockDataSource](./datasources/stock.md)                | `stock`        | Read stock data from Alpha Vantage                | None             |
+| [SimpleJsonDataSource](./datasources/simplejson.md)      | `simplejson`   | Read JSON data from a file                        | `databricks-sdk` |
+| [GoogleSheetsDataSource](./datasources/googlesheets.md)  | `googlesheets` | Read a table from a public Google Sheets document | None             |
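The new row maps the `googlesheets` short name to the data source. As a quick sketch (not part of this diff), reading through that short name would look like the following, mirroring the `github` example referenced in the hunk header above; it assumes a live SparkSession and uses the public spreadsheet ID from the docstring examples below.

# Sketch only: register the data source, then read by spreadsheet ID.
from pyspark.sql import SparkSession
from pyspark_datasources import GoogleSheetsDataSource

spark = SparkSession.builder.getOrCreate()
spark.dataSource.register(GoogleSheetsDataSource)
spark.read.format("googlesheets").load("10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0").show()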

mkdocs.yml

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@ nav:
   - datasources/huggingface.md
   - datasources/stock.md
   - datasources/simplejson.md
+  - datasources/googlesheets.md

 markdown_extensions:
   - pymdownx.highlight:

pyspark_datasources/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,5 +1,6 @@
 from .fake import FakeDataSource
 from .github import GithubDataSource
+from .googlesheets import GoogleSheetsDataSource
 from .huggingface import HuggingFaceDatasets
-from .stock import StockDataSource
 from .simplejson import SimpleJsonDataSource
+from .stock import StockDataSource
pyspark_datasources/googlesheets.py

Lines changed: 195 additions & 0 deletions
@@ -0,0 +1,195 @@
from dataclasses import dataclass
from typing import Dict, Optional

from pyspark.sql.datasource import DataSource, DataSourceReader
from pyspark.sql.types import StringType, StructField, StructType


@dataclass
class Sheet:
    """
    A dataclass to identify a Google Sheets document.

    Attributes
    ----------
    spreadsheet_id : str
        The ID of the Google Sheets document.
    sheet_id : str, optional
        The ID of the worksheet within the document.
    """

    spreadsheet_id: str
    sheet_id: Optional[str] = None  # if None, the first sheet is used

    @classmethod
    def from_url(cls, url: str) -> "Sheet":
        """
        Converts a Google Sheets URL to a Sheet object.
        """
        from urllib.parse import parse_qs, urlparse

        parsed = urlparse(url)
        if parsed.netloc != "docs.google.com" or not parsed.path.startswith(
            "/spreadsheets/d/"
        ):
            raise ValueError("URL is not a Google Sheets URL")
        qs = parse_qs(parsed.query)
        spreadsheet_id = parsed.path.split("/")[3]
        if "gid" in qs:
            sheet_id = qs["gid"][0]
        else:
            sheet_id = None
        return cls(spreadsheet_id, sheet_id)

    def get_query_url(self, query: Optional[str] = None):
        """
        Gets the query URL that returns the results of the query as a CSV file.

        If no query is provided, returns the entire sheet.
        If sheet ID is None, uses the first sheet.

        See https://developers.google.com/chart/interactive/docs/querylanguage
        """
        from urllib.parse import urlencode

        path = f"https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/gviz/tq"
        url_query = {"tqx": "out:csv"}
        if self.sheet_id:
            url_query["gid"] = self.sheet_id
        if query:
            url_query["tq"] = query
        return f"{path}?{urlencode(url_query)}"


@dataclass
class Parameters:
    sheet: Sheet
    has_header: bool


class GoogleSheetsDataSource(DataSource):
    """
    A DataSource for reading tables from public Google Sheets documents.

    Name: `googlesheets`

    Schema: By default, all columns are treated as strings and the header row defines the column names.

    Options
    -------
    - `url`: The URL of the Google Sheets document.
    - `path`: The ID of the Google Sheets document.
    - `sheet_id`: The ID of the worksheet within the document.
    - `has_header`: Whether the sheet has a header row. Default is `true`.

    Either `url` or `path` must be specified, but not both.

    Examples
    --------
    Register the data source.

    >>> from pyspark_datasources import GoogleSheetsDataSource
    >>> spark.dataSource.register(GoogleSheetsDataSource)

    Load data from a public Google Sheets document using `path` and an optional `sheet_id`.

    >>> spreadsheet_id = "10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0"
    >>> spark.read.format("googlesheets").options(sheet_id="0").load(spreadsheet_id).show()
    +-------+---------+---------+-------+
    |country| latitude|longitude|   name|
    +-------+---------+---------+-------+
    |     AD|42.546245| 1.601554|Andorra|
    |    ...|      ...|      ...|    ...|
    +-------+---------+---------+-------+

    Load data from a public Google Sheets document using `url`.

    >>> url = "https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/edit?gid=0#gid=0"
    >>> spark.read.format("googlesheets").options(url=url).load().show()
    +-------+---------+---------+-------+
    |country| latitude|longitude|   name|
    +-------+---------+---------+-------+
    |     AD|42.546245| 1.601554|Andorra|
    |    ...|      ...|      ...|    ...|
    +-------+---------+---------+-------+

    Specify a custom schema.

    >>> schema = "id string, lat double, long double, name string"
    >>> spark.read.format("googlesheets").schema(schema).options(url=url).load().show()
    +---+---------+--------+-------+
    | id|      lat|    long|   name|
    +---+---------+--------+-------+
    | AD|42.546245|1.601554|Andorra|
    |...|      ...|     ...|    ...|
    +---+---------+--------+-------+

    Treat the first row as data instead of a header.

    >>> schema = "c1 string, c2 string, c3 string, c4 string"
    >>> spark.read.format("googlesheets").schema(schema).options(url=url, has_header="false").load().show()
    +-------+---------+---------+-------+
    |     c1|       c2|       c3|     c4|
    +-------+---------+---------+-------+
    |country| latitude|longitude|   name|
    |     AD|42.546245| 1.601554|Andorra|
    |    ...|      ...|      ...|    ...|
    +-------+---------+---------+-------+
    """

    @classmethod
    def name(cls):
        return "googlesheets"

    def __init__(self, options: Dict[str, str]):
        if "url" in options:
            sheet = Sheet.from_url(options.pop("url"))
        elif "path" in options:
            sheet = Sheet(options.pop("path"), options.pop("sheet_id", None))
        else:
            raise ValueError(
                "You must specify either `url` or `path` (spreadsheet ID)."
            )
        has_header = options.pop("has_header", "true").lower() == "true"
        self.parameters = Parameters(sheet, has_header)

    def schema(self) -> StructType:
        if not self.parameters.has_header:
            raise ValueError("Custom schema is required when `has_header` is false")

        import pandas as pd

        # Read the column names from the first row of the sheet
        df = pd.read_csv(self.parameters.sheet.get_query_url("select * limit 1"))
        return StructType([StructField(col, StringType()) for col in df.columns])

    def reader(self, schema: StructType) -> DataSourceReader:
        return GoogleSheetsReader(self.parameters, schema)


class GoogleSheetsReader(DataSourceReader):

    def __init__(self, parameters: Parameters, schema: StructType):
        self.parameters = parameters
        self.schema = schema

    def read(self, partition):
        from urllib.request import urlopen

        from pyarrow import csv
        from pyspark.sql.pandas.types import to_arrow_schema

        # Specify column types based on the schema
        convert_options = csv.ConvertOptions(
            column_types=to_arrow_schema(self.schema),
        )
        read_options = csv.ReadOptions(
            column_names=self.schema.fieldNames(),  # Rename columns
            skip_rows=(
                1 if self.parameters.has_header else 0  # Skip header row if present
            ),
        )
        with urlopen(self.parameters.sheet.get_query_url()) as file:
            yield from csv.read_csv(
                file, read_options=read_options, convert_options=convert_options
            ).to_batches()
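To make the gviz endpoint concrete, here is a small sketch (not part of the commit) of what `Sheet.from_url` and `get_query_url` produce for the public spreadsheet used in the docstring examples; note that `urlencode` percent-encodes the colon in `out:csv` and the `*` in the query.

# Illustration only; the spreadsheet ID comes from the examples above.
from pyspark_datasources.googlesheets import Sheet

sheet = Sheet.from_url(
    "https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/edit?gid=0"
)
print(sheet.spreadsheet_id)  # 10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0
print(sheet.sheet_id)        # 0
print(sheet.get_query_url("select * limit 1"))
# https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/gviz/tq?tqx=out%3Acsv&gid=0&tq=select+%2A+limit+1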

tests/test_google_sheets.py

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
import pytest
from pyspark.errors.exceptions.captured import AnalysisException, PythonException
from pyspark.sql import SparkSession

from pyspark_datasources import GoogleSheetsDataSource


@pytest.fixture(scope="module")
def spark():
    spark = SparkSession.builder.getOrCreate()
    spark.dataSource.register(GoogleSheetsDataSource)
    yield spark


def test_url(spark):
    url = "https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/edit?gid=846122797#gid=846122797"
    df = spark.read.format("googlesheets").options(url=url).load()
    df.show()
    assert df.count() == 2
    assert len(df.columns) == 2
    assert df.schema.simpleString() == "struct<num:string,name:string>"


def test_spreadsheet_id(spark):
    df = spark.read.format("googlesheets").load(
        "10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0"
    )
    df.show()
    assert df.count() == 2
    assert len(df.columns) == 2


def test_missing_options(spark):
    with pytest.raises(AnalysisException) as excinfo:
        spark.read.format("googlesheets").load()
    assert "ValueError" in str(excinfo.value)


def test_mutual_exclusive_options(spark):
    with pytest.raises(AnalysisException) as excinfo:
        spark.read.format("googlesheets").options(
            url="a",
            path="b",
        ).load()
    assert "ValueError" in str(excinfo.value)


def test_custom_schema(spark):
    url = "https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/edit?gid=846122797#gid=846122797"
    df = (
        spark.read.format("googlesheets")
        .options(url=url)
        .schema("a double, b string")
        .load()
    )
    df.show()
    assert df.count() == 2
    assert len(df.columns) == 2
    assert df.schema.simpleString() == "struct<a:double,b:string>"


def test_custom_schema_mismatch_count(spark):
    url = "https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/edit?gid=846122797#gid=846122797"
    df = spark.read.format("googlesheets").options(url=url).schema("a double").load()
    with pytest.raises(PythonException) as excinfo:
        df.show()
    assert "CSV parse error" in str(excinfo.value)


def test_unnamed_column(spark):
    url = "https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/edit?gid=1579451727#gid=1579451727"
    df = spark.read.format("googlesheets").options(url=url).load()
    df.show()
    assert df.count() == 1
    assert df.columns == ["Unnamed: 0", "1", "Unnamed: 2"]


def test_duplicate_column(spark):
    url = "https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/edit?gid=1875209731#gid=1875209731"
    df = spark.read.format("googlesheets").options(url=url).load()
    df.show()
    assert df.count() == 1
    assert df.columns == ["a", "a.1"]


def test_no_header_row(spark):
    url = "https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/edit?gid=1579451727#gid=1579451727"
    df = (
        spark.read.format("googlesheets")
        .schema("a int, b int, c int")
        .options(url=url, has_header="false")
        .load()
    )
    df.show()
    assert df.count() == 2
    assert len(df.columns) == 3


def test_empty(spark):
    url = "https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/edit?gid=2123944555#gid=2123944555"
    with pytest.raises(AnalysisException) as excinfo:
        spark.read.format("googlesheets").options(url=url).load()
    assert "EmptyDataError" in str(excinfo.value)
