2323from umap import UMAP
2424
2525from tdamapper .core import aggregate_graph
26- from tdamapper .cover import BallCover , CubicalCover
27- from tdamapper .learn import MapperAlgorithm
26+ from tdamapper .cover import BallCover , CubicalCover , KNNCover
27+ from tdamapper .learn import MapperAlgorithm , MapperClustering
2828from tdamapper .plot import MapperPlot
2929
# Feature flag read once, at import time, from the LIMITS_ENABLED environment
# variable. Environment values are always strings, so a plain bool() would
# treat "0" or "false" as enabled; instead the common "off" spellings (and an
# unset/empty variable) disable the flag, while any other non-empty value
# still enables it.
LIMITS_ENABLED = os.environ.get("LIMITS_ENABLED", "").strip().lower() not in {
    "",
    "0",
    "false",
    "no",
    "off",
}
6363
6464V_COVER_CUBICAL = "Cubical"
6565
66+ V_COVER_KNN = "KNN"
67+
6668V_CLUSTERING_TRIVIAL = "Trivial"
6769
70+ V_CLUSTERING_COVER = "Cover"
71+
6872V_CLUSTERING_AGGLOMERATIVE = "Agglomerative"
6973
7074V_CLUSTERING_DBSCAN = "DBSCAN"
@@ -198,7 +202,10 @@ def _get_data_summary(df_X, df_y):
198202 }
199203 ).T
200204 df_summary = pd .DataFrame (
201- {V_DATA_SUMMARY_FEAT : df .columns , V_DATA_SUMMARY_HIST : df_hist .values .tolist ()}
205+ {
206+ V_DATA_SUMMARY_FEAT : df .columns ,
207+ V_DATA_SUMMARY_HIST : df_hist .values .tolist (),
208+ }
202209 )
203210 return df_summary
204211
@@ -316,9 +323,10 @@ def mapper_lens_input_section(X):
316323 if pca_n > n_feats :
317324 lens = X
318325 else :
319- lens = PCA (n_components = pca_n , random_state = pca_random_state ).fit_transform (
320- X
321- )
326+ lens = PCA (
327+ n_components = pca_n ,
328+ random_state = pca_random_state ,
329+ ).fit_transform (X )
322330 elif lens_type == V_LENS_UMAP :
323331 umap_n = st .number_input (
324332 "UMAP Components" ,
@@ -343,7 +351,12 @@ def mapper_cover_input_section():
343351 st .header ("🌐 Cover" )
344352 cover_type = st .selectbox (
345353 "Type" ,
346- options = [V_COVER_TRIVIAL , V_COVER_BALL , V_COVER_CUBICAL ],
354+ options = [
355+ V_COVER_TRIVIAL ,
356+ V_COVER_BALL ,
357+ V_COVER_CUBICAL ,
358+ V_COVER_KNN ,
359+ ],
347360 index = 2 ,
348361 )
349362 cover = None
@@ -379,9 +392,79 @@ def mapper_cover_input_section():
379392 "Overlap" , value = 0.25 , min_value = 0.0 , max_value = 1.0
380393 )
381394 cover = CubicalCover (n_intervals = cubical_n , overlap_frac = cubical_p )
395+ elif cover_type == V_COVER_KNN :
396+ knn_k = st .number_input ("Neighbors" , value = 10 , min_value = 1 )
397+ cover = KNNCover (neighbors = knn_k )
382398 return cover
383399
384400
def mapper_clustering_cover():
    """Render the cover widgets for cover-based clustering and build the clusterer.

    Shows a "Type" select box (Cubical is preselected) followed by the
    parameter widgets of the chosen cover. Every widget is given an explicit
    ``mapper_clustering_*`` key.

    Returns:
        A ``MapperClustering`` wrapping the configured cover; the cover is
        ``None`` when the trivial cover is selected.
    """
    selected_type = st.selectbox(
        "Type",
        options=[V_COVER_TRIVIAL, V_COVER_BALL, V_COVER_CUBICAL, V_COVER_KNN],
        index=2,
        key="mapper_clustering_cover_type",
    )

    # The trivial cover needs no parameters: it is represented by ``None``.
    chosen_cover = None

    if selected_type == V_COVER_BALL:
        radius = st.number_input(
            "Radius",
            value=100.0,
            min_value=0.0,
            key="mapper_clustering_radius",
        )
        distance = st.selectbox(
            "Metric",
            options=["euclidean", "chebyshev", "manhattan", "cosine"],
            key="mapper_clustering_cover_metric",
        )
        chosen_cover = BallCover(radius=radius, metric=distance)
    elif selected_type == V_COVER_CUBICAL:
        n_intervals = st.number_input(
            "Intervals",
            value=10,
            min_value=0,
            key="mapper_clustering_cover_intervals",
        )
        use_custom_overlap = st.checkbox(
            "Set overlap",
            value=False,
            help="Uses a dimension-dependant default overlap when unchecked",
            key="mapper_clustering_cover_set_overlap",
        )
        # ``None`` is passed through to CubicalCover unless the user opts in
        # to an explicit overlap fraction.
        overlap_frac = None
        if use_custom_overlap:
            overlap_frac = st.number_input(
                "Overlap",
                value=0.25,
                min_value=0.0,
                max_value=1.0,
                key="mapper_clustering_cover_overlap",
            )
        chosen_cover = CubicalCover(
            n_intervals=n_intervals,
            overlap_frac=overlap_frac,
        )
    elif selected_type == V_COVER_KNN:
        n_neighbors = st.number_input(
            "Neighbors",
            value=10,
            min_value=1,
            key="mapper_clustering_knn_k",
        )
        chosen_cover = KNNCover(neighbors=n_neighbors)

    # NOTE(review): n_jobs=-2 presumably follows the joblib convention
    # ("all cores but one") -- confirm against the MapperClustering docs.
    return MapperClustering(cover=chosen_cover, n_jobs=-2)
466+
467+
385468def mapper_clustering_kmeans ():
386469 clust_num = st .number_input (
387470 "Clusters" ,
@@ -485,17 +568,20 @@ def mapper_clustering_input_section():
485568 "Type" ,
486569 options = [
487570 V_CLUSTERING_TRIVIAL ,
571+ V_CLUSTERING_COVER ,
488572 V_CLUSTERING_KMEANS ,
489573 V_CLUSTERING_AGGLOMERATIVE ,
490574 V_CLUSTERING_DBSCAN ,
491575 V_CLUSTERING_HDBSCAN ,
492576 V_CLUSTERING_AFFINITY_PROPAGATION ,
493577 ],
494- index = 1 ,
578+ index = 0 ,
495579 )
496580 clustering = None
497581 if clustering_type == V_CLUSTERING_TRIVIAL :
498582 clustering = None
583+ elif clustering_type == V_CLUSTERING_COVER :
584+ clustering = mapper_clustering_cover ()
499585 elif clustering_type == V_CLUSTERING_AGGLOMERATIVE :
500586 clustering = mapper_clustering_agglomerative ()
501587 elif clustering_type == V_CLUSTERING_KMEANS :
@@ -625,7 +711,13 @@ def compute_mapper_fig(mapper_plot, colors, node_size, cmap, _agg, agg_name):
625711 logger .info ("Generating Mapper figure" )
626712 mapper_fig = mapper_plot .plot_plotly (
627713 colors ,
628- node_size = node_size ,
714+ node_size = [
715+ 0.0 ,
716+ node_size / 2.0 ,
717+ node_size ,
718+ node_size * 1.5 ,
719+ node_size * 2.0 ,
720+ ],
629721 agg = _agg ,
630722 title = [f"{ c } " for c in colors .columns ],
631723 cmap = cmap ,
0 commit comments