3131from __future__ import annotations
3232
3333import logging
34- from typing import Any , Callable , Iterator , Optional , Protocol
34+ from typing import Any , Callable , Generic , Iterator , Optional , TypeVar
3535
3636import networkx as nx
3737from joblib import Parallel , delayed
3838
39- from tdamapper ._common import Array , ParamsMixin , clone
39+ from tdamapper ._common import ParamsMixin , clone
40+ from tdamapper .protocols import ArrayRead , Clustering , Cover , SpatialSearch
4041from tdamapper .utils .unionfind import UnionFind
4142
4243ATTR_IDS = "ids"
5354 handlers = [logging .StreamHandler ()],
5455)
5556
57+ S = TypeVar ("S" )
58+
59+ T = TypeVar ("T" )
60+
5661
5762def mapper_labels (
58- X : Array [Any ], y : Array [Any ], cover : Cover , clustering : Clustering , n_jobs : int = 1
63+ X : ArrayRead [S ],
64+ y : ArrayRead [T ],
65+ cover : Cover [T ],
66+ clustering : Clustering [S ],
67+ n_jobs : int = 1 ,
5968) -> list [list [int ]]:
6069 """
6170 Identify the nodes of the Mapper graph.
@@ -85,7 +94,7 @@ def mapper_labels(
8594 """
8695
8796 def _run_clustering (
88- local_ids : list [int ], X_local : Array [ Any ], clust : Clustering
97+ local_ids : list [int ], X_local : ArrayRead [ S ], clust : Clustering [ S ]
8998 ) -> tuple [list [int ], list [int ]]:
9099 local_lbls = clust .fit (X_local ).labels_
91100 return local_ids , local_lbls
@@ -110,7 +119,11 @@ def _run_clustering(
110119
111120
112121def mapper_connected_components (
113- X : Array [Any ], y : Array [Any ], cover : Cover , clustering : Clustering , n_jobs : int = 1
122+ X : ArrayRead [S ],
123+ y : ArrayRead [T ],
124+ cover : Cover [T ],
125+ clustering : Clustering [S ],
126+ n_jobs : int = 1 ,
114127) -> list [int ]:
115128 """
116129 Identify the connected components of the Mapper graph.
@@ -155,7 +168,11 @@ def mapper_connected_components(
155168
156169
157170def mapper_graph (
158- X : Array [Any ], y : Array [Any ], cover : Cover , clustering : Clustering , n_jobs : int = 1
171+ X : ArrayRead [S ],
172+ y : ArrayRead [T ],
173+ cover : Cover [T ],
174+ clustering : Clustering [S ],
175+ n_jobs : int = 1 ,
159176) -> nx .Graph :
160177 """
161178 Create the Mapper graph.
@@ -201,7 +218,7 @@ def mapper_graph(
201218
202219
203220def aggregate_graph (
204- X : Array [ Any ], graph : nx .Graph , agg : Callable [..., Any ]
221+ X : ArrayRead [ S ], graph : nx .Graph , agg : Callable [..., Any ]
205222) -> dict [int , Any ]:
206223 """
207224 Apply an aggregation function to the nodes of a graph.
@@ -229,81 +246,7 @@ def aggregate_graph(
229246 return agg_values
230247
231248
232- class Cover (Protocol ):
233- """
234- Abstract interface for cover algorithms.
235-
236- This is a naive implementation. Subclasses should override the methods of
237- this class to implement more meaningful cover algorithms.
238- """
239-
240- def apply (self , X : Array [Any ]) -> Iterator [list [int ]]:
241- """
242- Covers the dataset with a single open set.
243-
244- This is a naive implementation that returns a generator producing a
245- single list containing all the ids if the original dataset. This
246- method should be overridden by subclasses to implement more meaningful
247- cover algorithms.
248-
249- :param X: A dataset of n points.
250- :return: A generator of lists of ids.
251- """
252-
253-
254- class Clustering (Protocol ):
255- """
256- Abstract interface for clustering algorithms.
257-
258- A clustering algorithm is a method for grouping data points into clusters.
259- Each cluster is represented by a unique integer label, and the labels are
260- assigned to the points in the dataset. The labels are typically non-negative
261- integers, starting from zero. The labels are assigned such that the points
262- in the same cluster have the same label, and the points in different clusters
263- have different labels. The labels are not necessarily contiguous, and there
264- may be gaps in the sequence of labels.
265- """
266-
267- labels_ : list [int ]
268-
269- def fit (self , X : Array [Any ], y : Optional [Array [Any ]] = None ) -> Clustering :
270- """
271- Fit the clustering algorithm to the data.
272-
273- :param X: A dataset of n points.
274- :param y: A dataset of targets. Typically ignored and present for
275- compatibility with scikit-learn's clustering interface.
276- :return: The fitted clustering object.
277- """
278-
279-
280- class SpatialSearch (Protocol ):
281- """
282- Abstract interface for search algorithms.
283-
284- A spatial search algorithm is a method for finding neighbors of a
285- query point in a dataset.
286- """
287-
288- def fit (self , X : Array [Any ]) -> SpatialSearch :
289- """
290- Train internal parameters.
291-
292- :param X: A dataset of n points.
293- :return: The object itself.
294- """
295-
296- def search (self , x : Any ) -> list [int ]:
297- """
298- Return a list of neighbors for the query point.
299-
300- :param x: A query point for which we want to find neighbors.
301- :return: A list containing all the indices of the points in the
302- dataset.
303- """
304-
305-
306- def proximity_net (search : SpatialSearch , X : Array [Any ]) -> Iterator [list [int ]]:
249+ def proximity_net (search : SpatialSearch [S ], X : ArrayRead [S ]) -> Iterator [list [int ]]:
307250 """
308251 Covers the dataset using proximity-net.
309252
@@ -331,7 +274,7 @@ def proximity_net(search: SpatialSearch, X: Array[Any]) -> Iterator[list[int]]:
331274 yield neigh_ids
332275
333276
334- class TrivialCover (ParamsMixin ):
277+ class TrivialCover (ParamsMixin , Generic [ T ] ):
335278 """
336279 Cover algorithm that covers data with a single subset containing the whole
337280 dataset.
@@ -340,7 +283,7 @@ class TrivialCover(ParamsMixin):
340283 dataset.
341284 """
342285
343- def apply (self , X : Array [ Any ]) -> Iterator [list [int ]]:
286+ def apply (self , X : ArrayRead [ T ]) -> Iterator [list [int ]]:
344287 """
345288 Covers the dataset with a single open set.
346289
@@ -350,7 +293,7 @@ def apply(self, X: Array[Any]) -> Iterator[list[int]]:
350293 yield list (range (0 , len (X )))
351294
352295
353- class FailSafeClustering (ParamsMixin ):
296+ class FailSafeClustering (ParamsMixin , Generic [ T ] ):
354297 """
355298 A delegating clustering algorithm that prevents failure.
356299
@@ -364,17 +307,19 @@ class FailSafeClustering(ParamsMixin):
364307 enable logging, or False to suppress it. Defaults to True.
365308 """
366309
367- _clustering : Optional [Clustering ]
310+ _clustering : Optional [Clustering [ T ] ]
368311 _verbose : bool
369312 labels_ : list [int ]
370313
371314 def __init__ (
372- self , clustering : Optional [Clustering ] = None , verbose : bool = True
315+ self , clustering : Optional [Clustering [ T ] ] = None , verbose : bool = True
373316 ) -> None :
374317 self .clustering = clustering
375318 self .verbose = verbose
376319
377- def fit (self , X : Array [Any ], y : Optional [Array [Any ]] = None ) -> FailSafeClustering :
320+ def fit (
321+ self , X : ArrayRead [T ], y : Optional [ArrayRead [T ]] = None
322+ ) -> FailSafeClustering [T ]:
378323 self ._clustering = (
379324 TrivialClustering () if self .clustering is None else self .clustering
380325 )
@@ -389,7 +334,7 @@ def fit(self, X: Array[Any], y: Optional[Array[Any]] = None) -> FailSafeClusteri
389334 return self
390335
391336
392- class TrivialClustering (ParamsMixin ):
337+ class TrivialClustering (ParamsMixin , Generic [ T ] ):
393338 """
394339 A clustering algorithm that returns a single cluster.
395340
@@ -404,7 +349,9 @@ class TrivialClustering(ParamsMixin):
404349 def __init__ (self ) -> None :
405350 pass
406351
407- def fit (self , X : Array [Any ], _y : Optional [Array [Any ]] = None ) -> TrivialClustering :
352+ def fit (
353+ self , X : ArrayRead [T ], _y : Optional [ArrayRead [T ]] = None
354+ ) -> TrivialClustering [T ]:
408355 """
409356 Fit the clustering algorithm to the data.
410357
0 commit comments