Skip to content

Commit a4ae4f2

Browse files
authored
Merge pull request #222 from lucasimi/general-improvements
Improved code. Minor fixes
2 parents 5578d4a + 7534138 commit a4ae4f2

File tree

9 files changed

+183
-64
lines changed

9 files changed

+183
-64
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,4 @@
2424
.idea
2525
dist/
2626
build/
27+
coverage.xml

Makefile

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Interpreter and installer used by the targets below. Both can be overridden
# on the command line, e.g. `make test PYTHON=python3.12`.
PYTHON = python
PIP = pip

.PHONY: all
all: install

# Editable install of the project with its "dev" extra. The requirement is
# quoted so the shell never treats `[dev]` as a glob character class.
.PHONY: install
install:
	$(PIP) install -e '.[dev]'

# Run the unit tests under coverage, then write the XML report. coverage is
# invoked through $(PYTHON) so that overriding PYTHON selects the matching
# coverage installation, the same way the bench target selects pytest.
.PHONY: test
test:
	$(PYTHON) -m coverage run --source=src -m pytest tests/test_unit_*.py
	$(PYTHON) -m coverage xml

# Run the benchmark tests with live INFO logging and uncaptured output.
.PHONY: bench
bench:
	$(PYTHON) -m pytest tests/test_bench_*.py -s -o log_cli=true --log-level=INFO

# Remove Python bytecode caches.
.PHONY: clean
clean:
	find . -type d -name "__pycache__" -exec rm -r {} +
	find . -type f -name "*.pyc" -delete

app/streamlit_app.py

Lines changed: 101 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
from umap import UMAP
2424

2525
from tdamapper.core import aggregate_graph
26-
from tdamapper.cover import BallCover, CubicalCover
27-
from tdamapper.learn import MapperAlgorithm
26+
from tdamapper.cover import BallCover, CubicalCover, KNNCover
27+
from tdamapper.learn import MapperAlgorithm, MapperClustering
2828
from tdamapper.plot import MapperPlot
2929

3030
# Flag read from the environment. Environment values are always strings, so
# bool() would treat "0" or "false" as enabled; compare against the usual
# "off" spellings instead. Any other non-empty value still enables the flag.
LIMITS_ENABLED = os.environ.get("LIMITS_ENABLED", "").strip().lower() not in {
    "",
    "0",
    "false",
    "no",
    "off",
}
@@ -63,8 +63,12 @@
6363

6464
V_COVER_CUBICAL = "Cubical"
6565

66+
V_COVER_KNN = "KNN"
67+
6668
V_CLUSTERING_TRIVIAL = "Trivial"
6769

70+
V_CLUSTERING_COVER = "Cover"
71+
6872
V_CLUSTERING_AGGLOMERATIVE = "Agglomerative"
6973

7074
V_CLUSTERING_DBSCAN = "DBSCAN"
@@ -198,7 +202,10 @@ def _get_data_summary(df_X, df_y):
198202
}
199203
).T
200204
df_summary = pd.DataFrame(
201-
{V_DATA_SUMMARY_FEAT: df.columns, V_DATA_SUMMARY_HIST: df_hist.values.tolist()}
205+
{
206+
V_DATA_SUMMARY_FEAT: df.columns,
207+
V_DATA_SUMMARY_HIST: df_hist.values.tolist(),
208+
}
202209
)
203210
return df_summary
204211

@@ -316,9 +323,10 @@ def mapper_lens_input_section(X):
316323
if pca_n > n_feats:
317324
lens = X
318325
else:
319-
lens = PCA(n_components=pca_n, random_state=pca_random_state).fit_transform(
320-
X
321-
)
326+
lens = PCA(
327+
n_components=pca_n,
328+
random_state=pca_random_state,
329+
).fit_transform(X)
322330
elif lens_type == V_LENS_UMAP:
323331
umap_n = st.number_input(
324332
"UMAP Components",
@@ -343,7 +351,12 @@ def mapper_cover_input_section():
343351
st.header("🌐 Cover")
344352
cover_type = st.selectbox(
345353
"Type",
346-
options=[V_COVER_TRIVIAL, V_COVER_BALL, V_COVER_CUBICAL],
354+
options=[
355+
V_COVER_TRIVIAL,
356+
V_COVER_BALL,
357+
V_COVER_CUBICAL,
358+
V_COVER_KNN,
359+
],
347360
index=2,
348361
)
349362
cover = None
@@ -379,9 +392,79 @@ def mapper_cover_input_section():
379392
"Overlap", value=0.25, min_value=0.0, max_value=1.0
380393
)
381394
cover = CubicalCover(n_intervals=cubical_n, overlap_frac=cubical_p)
395+
elif cover_type == V_COVER_KNN:
396+
knn_k = st.number_input("Neighbors", value=10, min_value=1)
397+
cover = KNNCover(neighbors=knn_k)
382398
return cover
383399

384400

401+
def mapper_clustering_cover():
    """Render the cover widgets for cover-based clustering and build the estimator.

    A select box chooses the cover type, then the parameter widgets for that
    type are shown. Every widget is given an explicit ``mapper_clustering_*``
    key.

    Returns:
        A ``MapperClustering`` built on the configured cover; the cover is
        ``None`` when the trivial option is selected.
    """
    available_covers = [
        V_COVER_TRIVIAL,
        V_COVER_BALL,
        V_COVER_CUBICAL,
        V_COVER_KNN,
    ]
    # index=2 preselects the third entry of the list above (the cubical cover).
    selected = st.selectbox(
        "Type",
        options=available_covers,
        index=2,
        key="mapper_clustering_cover_type",
    )

    chosen_cover = None
    if selected == V_COVER_TRIVIAL:
        # Trivial cover: nothing to configure, the cover stays None.
        pass
    elif selected == V_COVER_BALL:
        radius = st.number_input(
            "Radius",
            value=100.0,
            min_value=0.0,
            key="mapper_clustering_radius",
        )
        distance = st.selectbox(
            "Metric",
            options=["euclidean", "chebyshev", "manhattan", "cosine"],
            key="mapper_clustering_cover_metric",
        )
        chosen_cover = BallCover(radius=radius, metric=distance)
    elif selected == V_COVER_CUBICAL:
        intervals = st.number_input(
            "Intervals",
            value=10,
            min_value=0,
            key="mapper_clustering_cover_intervals",
        )
        wants_overlap = st.checkbox(
            "Set overlap",
            value=False,
            help="Uses a dimension-dependant default overlap when unchecked",
            key="mapper_clustering_cover_set_overlap",
        )
        # The overlap widget only appears once the checkbox is ticked;
        # otherwise None is passed through to CubicalCover.
        overlap = None
        if wants_overlap:
            overlap = st.number_input(
                "Overlap",
                value=0.25,
                min_value=0.0,
                max_value=1.0,
                key="mapper_clustering_cover_overlap",
            )
        chosen_cover = CubicalCover(n_intervals=intervals, overlap_frac=overlap)
    elif selected == V_COVER_KNN:
        neighbors = st.number_input(
            "Neighbors",
            value=10,
            min_value=1,
            key="mapper_clustering_knn_k",
        )
        chosen_cover = KNNCover(neighbors=neighbors)

    # NOTE(review): n_jobs=-2 presumably follows the joblib convention of
    # "all CPUs but one" -- confirm against MapperClustering.
    return MapperClustering(cover=chosen_cover, n_jobs=-2)
466+
467+
385468
def mapper_clustering_kmeans():
386469
clust_num = st.number_input(
387470
"Clusters",
@@ -485,17 +568,20 @@ def mapper_clustering_input_section():
485568
"Type",
486569
options=[
487570
V_CLUSTERING_TRIVIAL,
571+
V_CLUSTERING_COVER,
488572
V_CLUSTERING_KMEANS,
489573
V_CLUSTERING_AGGLOMERATIVE,
490574
V_CLUSTERING_DBSCAN,
491575
V_CLUSTERING_HDBSCAN,
492576
V_CLUSTERING_AFFINITY_PROPAGATION,
493577
],
494-
index=1,
578+
index=0,
495579
)
496580
clustering = None
497581
if clustering_type == V_CLUSTERING_TRIVIAL:
498582
clustering = None
583+
elif clustering_type == V_CLUSTERING_COVER:
584+
clustering = mapper_clustering_cover()
499585
elif clustering_type == V_CLUSTERING_AGGLOMERATIVE:
500586
clustering = mapper_clustering_agglomerative()
501587
elif clustering_type == V_CLUSTERING_KMEANS:
@@ -625,7 +711,13 @@ def compute_mapper_fig(mapper_plot, colors, node_size, cmap, _agg, agg_name):
625711
logger.info("Generating Mapper figure")
626712
mapper_fig = mapper_plot.plot_plotly(
627713
colors,
628-
node_size=node_size,
714+
node_size=[
715+
0.0,
716+
node_size / 2.0,
717+
node_size,
718+
node_size * 1.5,
719+
node_size * 2.0,
720+
],
629721
agg=_agg,
630722
title=[f"{c}" for c in colors.columns],
631723
cmap=cmap,

src/tdamapper/_plot_plotly.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def plot_plotly(
7373
titles = [title for _ in range(colors_num)]
7474
elif isinstance(title, list) and len(title) == colors_num:
7575
titles = title
76-
node_sizes = [node_size] if isinstance(node_size, int) else node_size
76+
node_sizes = [node_size] if isinstance(node_size, (int, float)) else node_size
7777
fig = _figure(mapper_plot, width, height, node_sizes, colors, titles, agg, cmaps)
7878
_add_ui_to_layout(mapper_plot, fig, colors, titles, node_sizes, agg, cmaps)
7979
return fig

src/tdamapper/clustering.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def __init__(self, cover=None, clustering=None, n_jobs=1):
4343
self.n_jobs = n_jobs
4444

4545
def fit(self, X, y=None):
46+
y = X if y is None else y
4647
X, y = self._validate_X_y(X, y)
4748
cover = TrivialCover() if self.cover is None else self.cover
4849
cover = clone(cover)
@@ -53,7 +54,6 @@ def fit(self, X, y=None):
5354
)
5455
clustering = clone(clustering)
5556
n_jobs = self.n_jobs
56-
y = X if y is None else y
5757
itm_lbls = mapper_connected_components(
5858
X,
5959
y,

src/tdamapper/utils/heap.py

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -13,91 +13,91 @@ def _parent(i):
1313
class _HeapNode:
    """Key/value pair stored in the heap; only the key takes part in ordering.

    The right-hand side of a comparison may be another ``_HeapNode`` (its key
    is used) or a bare key value.
    """

    def __init__(self, key, value):
        self._key = key
        self._value = value

    def get(self):
        """Return the ``(key, value)`` pair held by this node."""
        return self._key, self._value

    @staticmethod
    def _key_of(other):
        # Unwrap nodes so node-vs-node and node-vs-key comparisons both work.
        return other._key if isinstance(other, _HeapNode) else other

    def __lt__(self, other):
        return self._key < self._key_of(other)

    def __le__(self, other):
        return self._key <= self._key_of(other)

    def __gt__(self, other):
        return self._key > self._key_of(other)

    def __ge__(self, other):
        return self._key >= self._key_of(other)
3333

3434

3535
class MaxHeap:
3636

3737
def __init__(self):
38-
self.__heap = []
39-
self.__iter = None
38+
self._heap = []
39+
self._iter = None
4040

4141
def __iter__(self):
42-
self.__iter = iter(self.__heap)
42+
self._iter = iter(self._heap)
4343
return self
4444

4545
def __next__(self):
46-
node = next(self.__iter)
46+
node = next(self._iter)
4747
return node.get()
4848

4949
def __len__(self):
50-
return len(self.__heap)
50+
return len(self._heap)
5151

5252
def top(self):
53-
if not self.__heap:
53+
if not self._heap:
5454
return (None, None)
55-
return self.__heap[0].get()
55+
return self._heap[0].get()
5656

5757
def pop(self):
58-
if not self.__heap:
58+
if not self._heap:
5959
return
60-
max_val = self.__heap[0]
61-
self.__heap[0] = self.__heap[-1]
62-
self.__heap.pop()
60+
max_val = self._heap[0]
61+
self._heap[0] = self._heap[-1]
62+
self._heap.pop()
6363
self._bubble_down()
6464
return max_val.get()
6565

6666
def add(self, key, val):
67-
self.__heap.append(_HeapNode(key, val))
67+
self._heap.append(_HeapNode(key, val))
6868
self._bubble_up()
6969

7070
def _get_local_max(self, i):
71-
heap_len = len(self.__heap)
71+
heap_len = len(self._heap)
7272
left = _left(i)
7373
right = _right(i)
7474
if left >= heap_len:
7575
return i
7676
if right >= heap_len:
77-
if self.__heap[i] < self.__heap[left]:
77+
if self._heap[i] < self._heap[left]:
7878
return left
7979
return i
8080
max_child = left
81-
if self.__heap[left] < self.__heap[right]:
81+
if self._heap[left] < self._heap[right]:
8282
max_child = right
83-
if self.__heap[i] < self.__heap[max_child]:
83+
if self._heap[i] < self._heap[max_child]:
8484
return max_child
8585
return i
8686

8787
def _fix_down(self, i):
8888
local_max = self._get_local_max(i)
8989
if i < local_max:
90-
self.__heap[i], self.__heap[local_max] = (
91-
self.__heap[local_max],
92-
self.__heap[i],
90+
self._heap[i], self._heap[local_max] = (
91+
self._heap[local_max],
92+
self._heap[i],
9393
)
9494
return local_max
9595
return i
9696

9797
def _fix_up(self, i):
9898
parent = _parent(i)
99-
if self.__heap[parent] < self.__heap[i]:
100-
self.__heap[i], self.__heap[parent] = self.__heap[parent], self.__heap[i]
99+
if self._heap[parent] < self._heap[i]:
100+
self._heap[i], self._heap[parent] = self._heap[parent], self._heap[i]
101101
return parent
102102
return i
103103

@@ -110,7 +110,7 @@ def _bubble_down(self):
110110
current = local_max
111111

112112
def _bubble_up(self):
113-
current = len(self.__heap) - 1
113+
current = len(self._heap) - 1
114114
done = False
115115
while not done:
116116
local_max = self._fix_up(current)

0 commit comments

Comments
 (0)