
Commit 3e7564d

Add Kaggle datasets (#9)
* Add Kaggle datasets
* fix
* cache in temp directory
* add dependencies
* fix dependencies
* fix dependencies
* fix
* update docs
* lock
1 parent 7495ae7 commit 3e7564d

File tree

8 files changed: +318 −14 lines changed


docs/datasources/kaggle.md

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+# KaggleDataSource
+
+> Requires the [`kagglehub`](https://github.com/Kaggle/kagglehub) library.
+
+::: pyspark_datasources.kaggle.KaggleDataSource
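For orientation, a minimal end-to-end sketch of the data source this page documents, mirroring the docstring examples in `pyspark_datasources/kaggle.py` below (assumes a SparkSession with Python data source support):

```python
from pyspark.sql import SparkSession
from pyspark_datasources import KaggleDataSource

spark = SparkSession.builder.getOrCreate()
spark.dataSource.register(KaggleDataSource)  # registers under the short name "kaggle"

# `handle` selects the Kaggle dataset; the load() path picks a file inside it.
df = (
    spark.read.format("kaggle")
    .options(handle="yasserh/titanic-dataset")
    .load("Titanic-Dataset.csv")
)
df.show()
```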

docs/index.md

Lines changed: 9 additions & 8 deletions
@@ -28,11 +28,12 @@ spark.read.format("github").load("apache/spark").show()
 
 ## Data Sources
 
-| Data Source | Short Name | Description | Dependencies |
-| ------------------------------------------------------- | -------------- | --------------------------------------------- | ---------------- |
-| [GithubDataSource](./datasources/github.md) | `github` | Read pull requests from a Github repository | None |
-| [FakeDataSource](./datasources/fake.md) | `fake` | Generate fake data using the `Faker` library | `faker` |
-| [HuggingFaceDatasets](./datasources/huggingface.md) | `huggingface` | Read datasets from the HuggingFace Hub | `datasets` |
-| [StockDataSource](./datasources/stock.md) | `stock` | Read stock data from Alpha Vantage | None |
-| [SimpleJsonDataSource](./datasources/simplejson.md) | `simplejson` | Read JSON data from a file | `databricks-sdk` |
-| [GoogleSheetsDataSource](./datasources/googlesheets.md) | `googlesheets` | Read table from public Google Sheets document | None |
+| Data Source | Short Name | Description | Dependencies |
+| ------------------------------------------------------- | -------------- | --------------------------------------------- | --------------------- |
+| [GithubDataSource](./datasources/github.md) | `github` | Read pull requests from a Github repository | None |
+| [FakeDataSource](./datasources/fake.md) | `fake` | Generate fake data using the `Faker` library | `faker` |
+| [HuggingFaceDatasets](./datasources/huggingface.md) | `huggingface` | Read datasets from the HuggingFace Hub | `datasets` |
+| [StockDataSource](./datasources/stock.md) | `stock` | Read stock data from Alpha Vantage | None |
+| [SimpleJsonDataSource](./datasources/simplejson.md) | `simplejson` | Read JSON data from a file | `databricks-sdk` |
+| [GoogleSheetsDataSource](./datasources/googlesheets.md) | `googlesheets` | Read table from public Google Sheets document | None |
+| [KaggleDataSource](./datasources/kaggle.md) | `kaggle` | Read datasets from Kaggle | `kagglehub`, `pandas` |

mkdocs.yml

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ nav:
   - datasources/stock.md
   - datasources/simplejson.md
   - datasources/googlesheets.md
+  - datasources/kaggle.md
 
 markdown_extensions:
   - pymdownx.highlight:

poetry.lock

Lines changed: 178 additions & 4 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 4 additions & 2 deletions
@@ -11,25 +11,27 @@ packages = [
 
 [tool.poetry.dependencies]
 python = ">=3.9,<=3.12"
+pyarrow = ">=11.0.0"
 requests = "^2.31.0"
 faker = {version = "^23.1.0", optional = true}
 mkdocstrings = {extras = ["python"], version = "^0.24.0"}
 datasets = {version = "^2.17.0", optional = true}
 databricks-sdk = {version = "^0.28.0", optional = true}
+kagglehub = {extras = ["pandas-datasets"], version = "^0.3.10", optional = true}
 
 [tool.poetry.extras]
 faker = ["faker"]
 datasets = ["datasets"]
 databricks = ["databricks-sdk"]
+kaggle = ["kagglehub"]
 lance = ["pylance"]
-all = ["faker", "datasets", "databricks"]
+all = ["faker", "datasets", "databricks-sdk", "kagglehub"]
 
 [tool.poetry.group.dev.dependencies]
 pytest = "^8.0.0"
 grpcio = "^1.60.1"
 grpcio-status = "^1.60.1"
 pandas = "^2.2.0"
-pyarrow = "^15.0.0"
 mkdocs-material = "^9.5.9"
 
 [build-system]
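Since `kagglehub` sits behind the optional `kaggle` extra, importing `KaggleDataSource` without it only fails later, at read time. A minimal guard sketch (the PyPI package name in the install hint is an assumption, not shown in this diff):

```python
import importlib.util

# Hypothetical install hint: pip install "pyspark-data-sources[kaggle]"
# (package name assumed; check the project's README/PyPI page).
if importlib.util.find_spec("kagglehub") is None:
    raise ImportError("KaggleDataSource requires kagglehub; install the 'kaggle' extra.")
```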

pyspark_datasources/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,5 +2,6 @@
 from .github import GithubDataSource
 from .googlesheets import GoogleSheetsDataSource
 from .huggingface import HuggingFaceDatasets
+from .kaggle import KaggleDataSource
 from .simplejson import SimpleJsonDataSource
 from .stock import StockDataSource

pyspark_datasources/kaggle.py

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
+import tempfile
+from functools import cached_property
+from typing import TYPE_CHECKING, Iterator
+
+from pyspark.sql.datasource import DataSource, DataSourceReader
+from pyspark.sql.pandas.types import from_arrow_schema
+from pyspark.sql.types import StructType
+
+if TYPE_CHECKING:
+    import pyarrow as pa
+
+
+class KaggleDataSource(DataSource):
+    """
+    A DataSource for reading Kaggle datasets in Spark.
+
+    This data source allows reading datasets from Kaggle directly into Spark DataFrames.
+
+    Name: `kaggle`
+
+    Options
+    -------
+    - `handle`: The dataset handle on Kaggle, in the form of `{owner_slug}/{dataset_slug}`
+      or `{owner_slug}/{dataset_slug}/versions/{version_number}`
+    - `path`: The path to a file within the dataset.
+    - `username`: The Kaggle username for authentication.
+    - `key`: The Kaggle API key for authentication.
+
+    Notes
+    -----
+    - The `kagglehub` library is required to use this data source. Make sure it is installed.
+    - To read private datasets or datasets that require user authentication, `username` and `key` must be provided.
+    - Currently all data is read from a single partition.
+
+    Examples
+    --------
+    Register the data source.
+
+    >>> from pyspark_datasources import KaggleDataSource
+    >>> spark.dataSource.register(KaggleDataSource)
+
+    Load a public dataset from Kaggle.
+
+    >>> spark.read.format("kaggle").options(handle="yasserh/titanic-dataset").load("Titanic-Dataset.csv").select("Name").show()
+    +--------------------+
+    |                Name|
+    +--------------------+
+    |Braund, Mr. Owen ...|
+    |Cumings, Mrs. Joh...|
+    |...                 |
+    +--------------------+
+
+    Load a private dataset with authentication.
+
+    >>> spark.read.format("kaggle").options(
+    ...     username="myaccount",
+    ...     key="<token>",
+    ...     handle="myaccount/my-private-dataset",
+    ... ).load("file.csv").show()
+    """
+
+    @classmethod
+    def name(cls) -> str:
+        return "kaggle"
+
+    @cached_property
+    def _data(self) -> "pa.Table":
+        import ast
+        import os
+
+        import pyarrow as pa
+
+        handle = self.options.pop("handle")
+        path = self.options.pop("path")
+        username = self.options.pop("username", None)
+        key = self.options.pop("key", None)
+        if username or key:
+            if not (username and key):
+                raise ValueError(
+                    "Both username and key must be provided to authenticate."
+                )
+            os.environ["KAGGLE_USERNAME"] = username
+            os.environ["KAGGLE_KEY"] = key
+
+        kwargs = {k: ast.literal_eval(v) for k, v in self.options.items()}
+
+        # Cache in a temporary directory to avoid writing to ~, which may be read-only.
+        with tempfile.TemporaryDirectory() as tmpdir:
+            os.environ["KAGGLEHUB_CACHE"] = tmpdir
+            import kagglehub
+
+            df = kagglehub.dataset_load(
+                kagglehub.KaggleDatasetAdapter.PANDAS,
+                handle,
+                path,
+                **kwargs,
+            )
+        return pa.Table.from_pandas(df)
+
+    def schema(self) -> StructType:
+        return from_arrow_schema(self._data.schema)
+
+    def reader(self, schema: StructType) -> "KaggleDataReader":
+        return KaggleDataReader(self)
+
+
+class KaggleDataReader(DataSourceReader):
+    def __init__(self, source: KaggleDataSource):
+        self.source = source
+
+    def read(self, partition) -> Iterator["pa.RecordBatch"]:
+        yield from self.source._data.to_batches()
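Note how `_data` pops the recognized options and forwards everything left over through `ast.literal_eval` as keyword arguments to `kagglehub.dataset_load`. A hedged sketch of what that enables — the `pandas_kwargs` parameter is an assumption about kagglehub's pandas adapter, not something this diff exercises:

```python
# Extra reader options arrive as strings; ast.literal_eval turns them into
# Python literals before they reach kagglehub.dataset_load. Assuming the
# pandas adapter accepts a pandas_kwargs dict (unverified here):
df = (
    spark.read.format("kaggle")
    .options(handle="yasserh/titanic-dataset")
    .option("pandas_kwargs", "{'nrows': 100}")  # literal-eval'd into a dict
    .load("Titanic-Dataset.csv")
)
```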

tests/test_data_sources.py

Lines changed: 8 additions & 0 deletions
@@ -23,3 +23,11 @@ def test_fake_datasource(spark):
     df.show()
     assert df.count() == 3
     assert len(df.columns) == 4
+
+
+def test_kaggle_datasource(spark):
+    spark.dataSource.register(KaggleDataSource)
+    df = spark.read.format("kaggle").options(handle="yasserh/titanic-dataset").load("Titanic-Dataset.csv")
+    df.show()
+    assert df.count() == 891
+    assert len(df.columns) == 12
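This test downloads the real Titanic dataset, so it needs network access plus the optional dependency. One way to make it skip gracefully when `kagglehub` is absent (a sketch, not part of the diff):

```python
import pytest

# Skip the test cleanly instead of erroring when the optional dep is missing.
kagglehub = pytest.importorskip("kagglehub")
```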
