2323from  umap  import  UMAP 
2424
2525from  tdamapper .core  import  aggregate_graph 
26- from  tdamapper .cover  import  BallCover , CubicalCover 
27- from  tdamapper .learn  import  MapperAlgorithm 
26+ from  tdamapper .cover  import  BallCover , CubicalCover ,  KNNCover 
27+ from  tdamapper .learn  import  MapperAlgorithm ,  MapperClustering 
2828from  tdamapper .plot  import  MapperPlot 
2929
3030LIMITS_ENABLED  =  bool (os .environ .get ("LIMITS_ENABLED" , False ))
6363
6464V_COVER_CUBICAL  =  "Cubical" 
6565
66+ V_COVER_KNN  =  "KNN" 
67+ 
6668V_CLUSTERING_TRIVIAL  =  "Trivial" 
6769
70+ V_CLUSTERING_COVER  =  "Cover" 
71+ 
6872V_CLUSTERING_AGGLOMERATIVE  =  "Agglomerative" 
6973
7074V_CLUSTERING_DBSCAN  =  "DBSCAN" 
@@ -198,7 +202,10 @@ def _get_data_summary(df_X, df_y):
198202 }
199203 ).T 
200204 df_summary  =  pd .DataFrame (
201-  {V_DATA_SUMMARY_FEAT : df .columns , V_DATA_SUMMARY_HIST : df_hist .values .tolist ()}
205+  {
206+  V_DATA_SUMMARY_FEAT : df .columns ,
207+  V_DATA_SUMMARY_HIST : df_hist .values .tolist (),
208+  }
202209 )
203210 return  df_summary 
204211
@@ -316,9 +323,10 @@ def mapper_lens_input_section(X):
316323 if  pca_n  >  n_feats :
317324 lens  =  X 
318325 else :
319-  lens  =  PCA (n_components = pca_n , random_state = pca_random_state ).fit_transform (
320-  X 
321-  )
326+  lens  =  PCA (
327+  n_components = pca_n ,
328+  random_state = pca_random_state ,
329+  ).fit_transform (X )
322330 elif  lens_type  ==  V_LENS_UMAP :
323331 umap_n  =  st .number_input (
324332 "UMAP Components" ,
@@ -343,7 +351,12 @@ def mapper_cover_input_section():
343351 st .header ("🌐 Cover" )
344352 cover_type  =  st .selectbox (
345353 "Type" ,
346-  options = [V_COVER_TRIVIAL , V_COVER_BALL , V_COVER_CUBICAL ],
354+  options = [
355+  V_COVER_TRIVIAL ,
356+  V_COVER_BALL ,
357+  V_COVER_CUBICAL ,
358+  V_COVER_KNN ,
359+  ],
347360 index = 2 ,
348361 )
349362 cover  =  None 
@@ -379,9 +392,79 @@ def mapper_cover_input_section():
379392 "Overlap" , value = 0.25 , min_value = 0.0 , max_value = 1.0 
380393 )
381394 cover  =  CubicalCover (n_intervals = cubical_n , overlap_frac = cubical_p )
395+  elif  cover_type  ==  V_COVER_KNN :
396+  knn_k  =  st .number_input ("Neighbors" , value = 10 , min_value = 1 )
397+  cover  =  KNNCover (neighbors = knn_k )
382398 return  cover 
383399
384400
401+ def  mapper_clustering_cover ():
402+  cover_type  =  st .selectbox (
403+  "Type" ,
404+  options = [
405+  V_COVER_TRIVIAL ,
406+  V_COVER_BALL ,
407+  V_COVER_CUBICAL ,
408+  V_COVER_KNN ,
409+  ],
410+  index = 2 ,
411+  key = "mapper_clustering_cover_type" ,
412+  )
413+  cover  =  None 
414+  if  cover_type  ==  V_COVER_TRIVIAL :
415+  cover  =  None 
416+  elif  cover_type  ==  V_COVER_BALL :
417+  ball_r  =  st .number_input (
418+  "Radius" ,
419+  value = 100.0 ,
420+  min_value = 0.0 ,
421+  key = "mapper_clustering_radius" ,
422+  )
423+  metric  =  st .selectbox (
424+  "Metric" ,
425+  options = [
426+  "euclidean" ,
427+  "chebyshev" ,
428+  "manhattan" ,
429+  "cosine" ,
430+  ],
431+  key = "mapper_clustering_cover_metric" ,
432+  )
433+  cover  =  BallCover (radius = ball_r , metric = metric )
434+  elif  cover_type  ==  V_COVER_CUBICAL :
435+  cubical_n  =  st .number_input (
436+  "Intervals" ,
437+  value = 10 ,
438+  min_value = 0 ,
439+  key = "mapper_clustering_cover_intervals" ,
440+  )
441+  cubical_overlap  =  st .checkbox (
442+  "Set overlap" ,
443+  value = False ,
444+  help = "Uses a dimension-dependant default overlap when unchecked" ,
445+  key = "mapper_clustering_cover_set_overlap" ,
446+  )
447+  cubical_p  =  None 
448+  if  cubical_overlap :
449+  cubical_p  =  st .number_input (
450+  "Overlap" ,
451+  value = 0.25 ,
452+  min_value = 0.0 ,
453+  max_value = 1.0 ,
454+  key = "mapper_clustering_cover_overlap" ,
455+  )
456+  cover  =  CubicalCover (n_intervals = cubical_n , overlap_frac = cubical_p )
457+  elif  cover_type  ==  V_COVER_KNN :
458+  knn_k  =  st .number_input (
459+  "Neighbors" ,
460+  value = 10 ,
461+  min_value = 1 ,
462+  key = "mapper_clustering_knn_k" ,
463+  )
464+  cover  =  KNNCover (neighbors = knn_k )
465+  return  MapperClustering (cover = cover , n_jobs = - 2 )
466+ 
467+ 
385468def  mapper_clustering_kmeans ():
386469 clust_num  =  st .number_input (
387470 "Clusters" ,
@@ -485,17 +568,20 @@ def mapper_clustering_input_section():
485568 "Type" ,
486569 options = [
487570 V_CLUSTERING_TRIVIAL ,
571+  V_CLUSTERING_COVER ,
488572 V_CLUSTERING_KMEANS ,
489573 V_CLUSTERING_AGGLOMERATIVE ,
490574 V_CLUSTERING_DBSCAN ,
491575 V_CLUSTERING_HDBSCAN ,
492576 V_CLUSTERING_AFFINITY_PROPAGATION ,
493577 ],
494-  index = 1 ,
578+  index = 0 ,
495579 )
496580 clustering  =  None 
497581 if  clustering_type  ==  V_CLUSTERING_TRIVIAL :
498582 clustering  =  None 
583+  elif  clustering_type  ==  V_CLUSTERING_COVER :
584+  clustering  =  mapper_clustering_cover ()
499585 elif  clustering_type  ==  V_CLUSTERING_AGGLOMERATIVE :
500586 clustering  =  mapper_clustering_agglomerative ()
501587 elif  clustering_type  ==  V_CLUSTERING_KMEANS :
@@ -625,7 +711,13 @@ def compute_mapper_fig(mapper_plot, colors, node_size, cmap, _agg, agg_name):
625711 logger .info ("Generating Mapper figure" )
626712 mapper_fig  =  mapper_plot .plot_plotly (
627713 colors ,
628-  node_size = node_size ,
714+  node_size = [
715+  0.0 ,
716+  node_size  /  2.0 ,
717+  node_size ,
718+  node_size  *  1.5 ,
719+  node_size  *  2.0 ,
720+  ],
629721 agg = _agg ,
630722 title = [f"{ c }  "  for  c  in  colors .columns ],
631723 cmap = cmap ,
0 commit comments