Skip to content

Commit 8dd7785

Browse files
authored
Merge pull request #243 from lucasimi/parametric-types
Parametric types
2 parents 5ba1c50 + 6fb1f8e commit 8dd7785

File tree

22 files changed

+345
-263
lines changed

22 files changed

+345
-263
lines changed

.github/workflows/test-unit.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ jobs:
4949
run: |
5050
python -m pip install -e .[dev]
5151
52+
- name: Run typechecks
53+
run: |
54+
mypy src tests --ignore-missing-imports
55+
5256
- name: Run tests and code coverage
5357
run: |
5458
coverage run --source=src -m pytest tests/test_unit_*.py

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,12 @@ Homepage = "https://github.com/lucasimi/tda-mapper-python"
7272
Documentation = "https://tda-mapper.readthedocs.io"
7373
Issues = "https://github.com/lucasimi/tda-mapper-python/issues"
7474

75+
[tool.setuptools.packages.find]
76+
where = ["src"]
77+
78+
[tool.setuptools.package-data]
79+
"tdamapper" = ["py.typed"]
80+
7581
[tool.coverage.run]
7682
omit = [
7783
"**/_*.py",

src/tdamapper/_common.py

Lines changed: 26 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -8,37 +8,14 @@
88
import io
99
import pstats
1010
import warnings
11-
from typing import Any, Callable, Iterator, Protocol, TypeVar
11+
from typing import Any, Callable
1212

1313
import numpy as np
1414
from numpy.typing import NDArray
1515

16-
warnings.filterwarnings("default", category=DeprecationWarning, module=r"^tdamapper\.")
17-
18-
T = TypeVar("T")
19-
20-
21-
class Array(Protocol[T]):
22-
23-
def __getitem__(self, index: int) -> T:
24-
"""
25-
Get an item from the array.
26-
"""
27-
28-
def __len__(self) -> int:
29-
"""
30-
Get the length of the array.
31-
"""
32-
33-
def __setitem__(self, index: int, value: T) -> None:
34-
"""
35-
Set an item in the array.
36-
"""
16+
from tdamapper.protocols import Array, ArrayRead
3717

38-
def __iter__(self) -> Iterator[T]:
39-
"""
40-
Iterate over the array.
41-
"""
18+
warnings.filterwarnings("default", category=DeprecationWarning, module=r"^tdamapper\.")
4219

4320

4421
def deprecated(msg: str) -> Callable[..., Any]:
@@ -58,48 +35,52 @@ def warn_user(msg: str) -> None:
5835

5936
class EstimatorMixin:
6037

61-
def _is_sparse(self, X: Array[Any]) -> bool:
38+
def _is_sparse(self, X: ArrayRead[Any]) -> bool:
6239
# simple alternative use scipy.sparse.issparse
6340
return hasattr(X, "toarray")
6441

6542
def _validate_X_y(
66-
self, X: Array[Any], y: Array[Any]
43+
self, X: ArrayRead[Any], y: ArrayRead[Any]
6744
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
6845
if self._is_sparse(X):
6946
raise ValueError("Sparse data not supported.")
7047

71-
X = np.asarray(X)
72-
y = np.asarray(y)
48+
X_ = np.asarray(X)
49+
y_ = np.asarray(y)
7350

74-
if X.size == 0:
75-
msg = f"0 feature(s) (shape={X.shape}) while a minimum of 1 is " "required."
51+
if X_.size == 0:
52+
msg = (
53+
f"0 feature(s) (shape={X_.shape}) while a minimum of 1 is " "required."
54+
)
7655
raise ValueError(msg)
7756

78-
if y.size == 0:
79-
msg = f"0 feature(s) (shape={y.shape}) while a minimum of 1 is " "required."
57+
if y_.size == 0:
58+
msg = (
59+
f"0 feature(s) (shape={y_.shape}) while a minimum of 1 is " "required."
60+
)
8061
raise ValueError(msg)
8162

82-
if X.ndim == 1:
63+
if X_.ndim == 1:
8364
raise ValueError("1d-arrays not supported.")
8465

85-
if np.iscomplexobj(X) or np.iscomplexobj(y):
66+
if np.iscomplexobj(X_) or np.iscomplexobj(y_):
8667
raise ValueError("Complex data not supported.")
8768

88-
if X.dtype == np.object_:
89-
X = np.array(X, dtype=float)
69+
if X_.dtype == np.object_:
70+
X_ = np.array(X_, dtype=float)
9071

91-
if y.dtype == np.object_:
92-
y = np.array(y, dtype=float)
72+
if y_.dtype == np.object_:
73+
y_ = np.array(y_, dtype=float)
9374

9475
if (
95-
np.isnan(X).any()
96-
or np.isinf(X).any()
97-
or np.isnan(y).any()
98-
or np.isinf(y).any()
76+
np.isnan(X_).any()
77+
or np.isinf(X_).any()
78+
or np.isnan(y_).any()
79+
or np.isinf(y_).any()
9980
):
10081
raise ValueError("NaNs or infinite values not supported.")
10182

102-
return X, y
83+
return X_, y_
10384

10485
def _set_n_features_in(self, X: Array[Any]) -> None:
10586
if hasattr(X, "shape"):

src/tdamapper/app.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
1515
from sklearn.preprocessing import StandardScaler
1616
from umap import UMAP
1717

18-
from tdamapper.core import Cover, TrivialClustering
18+
from tdamapper.core import TrivialClustering
1919
from tdamapper.cover import BallCover, CubicalCover, KNNCover
2020
from tdamapper.learn import MapperAlgorithm
2121
from tdamapper.plot import MapperPlot
22+
from tdamapper.protocols import Clustering, Cover
2223

2324
logging.basicConfig(level=logging.INFO)
2425
logger = logging.getLogger(__name__)
@@ -164,7 +165,7 @@ def run_mapper(
164165
elif lens_type == LENS_UMAP:
165166
lens = lens_umap(n_components=lens_umap_n_components)
166167

167-
cover: Cover
168+
cover: Cover[NDArray[np.float_]]
168169
if cover_type == COVER_CUBICAL:
169170
cover = CubicalCover(
170171
n_intervals=cover_cubical_n_intervals,
@@ -178,6 +179,7 @@ def run_mapper(
178179
logger.error(f"Unknown cover type: {cover_type}")
179180
return None
180181

182+
clustering: Clustering[NDArray[np.float_]]
181183
if clustering_type == CLUSTERING_TRIVIAL:
182184
clustering = TrivialClustering()
183185
elif clustering_type == CLUSTERING_KMEANS:

src/tdamapper/core.py

Lines changed: 37 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,13 @@
3131
from __future__ import annotations
3232

3333
import logging
34-
from typing import Any, Callable, Iterator, Optional, Protocol
34+
from typing import Any, Callable, Generic, Iterator, Optional, TypeVar
3535

3636
import networkx as nx
3737
from joblib import Parallel, delayed
3838

39-
from tdamapper._common import Array, ParamsMixin, clone
39+
from tdamapper._common import ParamsMixin, clone
40+
from tdamapper.protocols import ArrayRead, Clustering, Cover, SpatialSearch
4041
from tdamapper.utils.unionfind import UnionFind
4142

4243
ATTR_IDS = "ids"
@@ -53,9 +54,17 @@
5354
handlers=[logging.StreamHandler()],
5455
)
5556

57+
S = TypeVar("S")
58+
59+
T = TypeVar("T")
60+
5661

5762
def mapper_labels(
58-
X: Array[Any], y: Array[Any], cover: Cover, clustering: Clustering, n_jobs: int = 1
63+
X: ArrayRead[S],
64+
y: ArrayRead[T],
65+
cover: Cover[T],
66+
clustering: Clustering[S],
67+
n_jobs: int = 1,
5968
) -> list[list[int]]:
6069
"""
6170
Identify the nodes of the Mapper graph.
@@ -85,7 +94,7 @@ def mapper_labels(
8594
"""
8695

8796
def _run_clustering(
88-
local_ids: list[int], X_local: Array[Any], clust: Clustering
97+
local_ids: list[int], X_local: ArrayRead[S], clust: Clustering[S]
8998
) -> tuple[list[int], list[int]]:
9099
local_lbls = clust.fit(X_local).labels_
91100
return local_ids, local_lbls
@@ -110,7 +119,11 @@ def _run_clustering(
110119

111120

112121
def mapper_connected_components(
113-
X: Array[Any], y: Array[Any], cover: Cover, clustering: Clustering, n_jobs: int = 1
122+
X: ArrayRead[S],
123+
y: ArrayRead[T],
124+
cover: Cover[T],
125+
clustering: Clustering[S],
126+
n_jobs: int = 1,
114127
) -> list[int]:
115128
"""
116129
Identify the connected components of the Mapper graph.
@@ -155,7 +168,11 @@ def mapper_connected_components(
155168

156169

157170
def mapper_graph(
158-
X: Array[Any], y: Array[Any], cover: Cover, clustering: Clustering, n_jobs: int = 1
171+
X: ArrayRead[S],
172+
y: ArrayRead[T],
173+
cover: Cover[T],
174+
clustering: Clustering[S],
175+
n_jobs: int = 1,
159176
) -> nx.Graph:
160177
"""
161178
Create the Mapper graph.
@@ -201,7 +218,7 @@ def mapper_graph(
201218

202219

203220
def aggregate_graph(
204-
X: Array[Any], graph: nx.Graph, agg: Callable[..., Any]
221+
X: ArrayRead[S], graph: nx.Graph, agg: Callable[..., Any]
205222
) -> dict[int, Any]:
206223
"""
207224
Apply an aggregation function to the nodes of a graph.
@@ -229,81 +246,7 @@ def aggregate_graph(
229246
return agg_values
230247

231248

232-
class Cover(Protocol):
233-
"""
234-
Abstract interface for cover algorithms.
235-
236-
This is a naive implementation. Subclasses should override the methods of
237-
this class to implement more meaningful cover algorithms.
238-
"""
239-
240-
def apply(self, X: Array[Any]) -> Iterator[list[int]]:
241-
"""
242-
Covers the dataset with a single open set.
243-
244-
This is a naive implementation that returns a generator producing a
245-
single list containing all the ids if the original dataset. This
246-
method should be overridden by subclasses to implement more meaningful
247-
cover algorithms.
248-
249-
:param X: A dataset of n points.
250-
:return: A generator of lists of ids.
251-
"""
252-
253-
254-
class Clustering(Protocol):
255-
"""
256-
Abstract interface for clustering algorithms.
257-
258-
A clustering algorithm is a method for grouping data points into clusters.
259-
Each cluster is represented by a unique integer label, and the labels are
260-
assigned to the points in the dataset. The labels are typically non-negative
261-
integers, starting from zero. The labels are assigned such that the points
262-
in the same cluster have the same label, and the points in different clusters
263-
have different labels. The labels are not necessarily contiguous, and there
264-
may be gaps in the sequence of labels.
265-
"""
266-
267-
labels_: list[int]
268-
269-
def fit(self, X: Array[Any], y: Optional[Array[Any]] = None) -> Clustering:
270-
"""
271-
Fit the clustering algorithm to the data.
272-
273-
:param X: A dataset of n points.
274-
:param y: A dataset of targets. Typically ignored and present for
275-
compatibility with scikit-learn's clustering interface.
276-
:return: The fitted clustering object.
277-
"""
278-
279-
280-
class SpatialSearch(Protocol):
281-
"""
282-
Abstract interface for search algorithms.
283-
284-
A spatial search algorithm is a method for finding neighbors of a
285-
query point in a dataset.
286-
"""
287-
288-
def fit(self, X: Array[Any]) -> SpatialSearch:
289-
"""
290-
Train internal parameters.
291-
292-
:param X: A dataset of n points.
293-
:return: The object itself.
294-
"""
295-
296-
def search(self, x: Any) -> list[int]:
297-
"""
298-
Return a list of neighbors for the query point.
299-
300-
:param x: A query point for which we want to find neighbors.
301-
:return: A list containing all the indices of the points in the
302-
dataset.
303-
"""
304-
305-
306-
def proximity_net(search: SpatialSearch, X: Array[Any]) -> Iterator[list[int]]:
249+
def proximity_net(search: SpatialSearch[S], X: ArrayRead[S]) -> Iterator[list[int]]:
307250
"""
308251
Covers the dataset using proximity-net.
309252
@@ -331,7 +274,7 @@ def proximity_net(search: SpatialSearch, X: Array[Any]) -> Iterator[list[int]]:
331274
yield neigh_ids
332275

333276

334-
class TrivialCover(ParamsMixin):
277+
class TrivialCover(ParamsMixin, Generic[T]):
335278
"""
336279
Cover algorithm that covers data with a single subset containing the whole
337280
dataset.
@@ -340,7 +283,7 @@ class TrivialCover(ParamsMixin):
340283
dataset.
341284
"""
342285

343-
def apply(self, X: Array[Any]) -> Iterator[list[int]]:
286+
def apply(self, X: ArrayRead[T]) -> Iterator[list[int]]:
344287
"""
345288
Covers the dataset with a single open set.
346289
@@ -350,7 +293,7 @@ def apply(self, X: Array[Any]) -> Iterator[list[int]]:
350293
yield list(range(0, len(X)))
351294

352295

353-
class FailSafeClustering(ParamsMixin):
296+
class FailSafeClustering(ParamsMixin, Generic[T]):
354297
"""
355298
A delegating clustering algorithm that prevents failure.
356299
@@ -364,17 +307,19 @@ class FailSafeClustering(ParamsMixin):
364307
enable logging, or False to suppress it. Defaults to True.
365308
"""
366309

367-
_clustering: Optional[Clustering]
310+
_clustering: Optional[Clustering[T]]
368311
_verbose: bool
369312
labels_: list[int]
370313

371314
def __init__(
372-
self, clustering: Optional[Clustering] = None, verbose: bool = True
315+
self, clustering: Optional[Clustering[T]] = None, verbose: bool = True
373316
) -> None:
374317
self.clustering = clustering
375318
self.verbose = verbose
376319

377-
def fit(self, X: Array[Any], y: Optional[Array[Any]] = None) -> FailSafeClustering:
320+
def fit(
321+
self, X: ArrayRead[T], y: Optional[ArrayRead[T]] = None
322+
) -> FailSafeClustering[T]:
378323
self._clustering = (
379324
TrivialClustering() if self.clustering is None else self.clustering
380325
)
@@ -389,7 +334,7 @@ def fit(self, X: Array[Any], y: Optional[Array[Any]] = None) -> FailSafeClusteri
389334
return self
390335

391336

392-
class TrivialClustering(ParamsMixin):
337+
class TrivialClustering(ParamsMixin, Generic[T]):
393338
"""
394339
A clustering algorithm that returns a single cluster.
395340
@@ -404,7 +349,9 @@ class TrivialClustering(ParamsMixin):
404349
def __init__(self) -> None:
405350
pass
406351

407-
def fit(self, X: Array[Any], _y: Optional[Array[Any]] = None) -> TrivialClustering:
352+
def fit(
353+
self, X: ArrayRead[T], _y: Optional[ArrayRead[T]] = None
354+
) -> TrivialClustering[T]:
408355
"""
409356
Fit the clustering algorithm to the data.
410357

0 commit comments

Comments
 (0)