Skip to content

Commit 5578d4a

Browse files
authored
Merge pull request #221 from lucasimi/use-pytest
Refactored tests to use pytest instead of unittest
2 parents 27d51f3 + e13162e commit 5578d4a

30 files changed

+1207
-1190
lines changed

.github/workflows/test-bench.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ jobs:
5353
5454
- name: Run benchmarks
5555
run: |
56-
python -m unittest discover -s tests -p 'test_bench_*.py'
56+
python -m pytest tests/test_bench_*.py -s
5757
5858
test-bench-job:
5959
needs: test-bench-matrix-job

.github/workflows/test-unit.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ jobs:
5353
5454
- name: Run tests and code coverage
5555
run: |
56-
coverage run --source=src -m unittest discover -s tests -p 'test_unit_*.py'
56+
coverage run --source=src -m pytest tests/test_unit_*.py
5757
coverage report -m
5858
5959
- name: Upload coverage reports to Codecov

CONTRIBUTING.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,17 @@ Follow these steps to contribute:
6767
the naming convention `test_bench_*.py`.
6868

6969
4. **Run Tests**.
70-
Ensure your changes pass all tests before committing. We use `unittest` as
70+
Ensure your changes pass all tests before committing. We use `pytest` as the
7171
test framework:
7272

7373
```bash
74-
python -m unittest discover -s tests -p 'test_*.py'
74+
python -m pytest tests/test_*.py
7575
```
7676

7777
Before each commit, make sure to check code coverage:
7878

7979
```bash
80-
coverage run --source=src -m unittest discover -s tests -p 'test_*.py'
80+
coverage run --source=src -m pytest tests/test_*.py
8181
```
8282

8383
5. **Commit and Push Your Changes**.

app/streamlit_app.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from sklearn.decomposition import PCA
2323
from umap import UMAP
2424

25-
from tdamapper._plot_plotly import _marker_size
2625
from tdamapper.core import aggregate_graph
2726
from tdamapper.cover import BallCover, CubicalCover
2827
from tdamapper.learn import MapperAlgorithm
@@ -137,14 +136,14 @@ def _check_limits_mapper_graph(mapper_graph):
137136
if LIMITS_ENABLED:
138137
num_nodes = mapper_graph.number_of_nodes()
139138
if num_nodes > LIMITS_NUM_NODES:
140-
logging.warn("Too many nodes.")
139+
logging.warning("Too many nodes.")
141140
raise ValueError(
142141
"Too many nodes: select different parameters or run the app "
143142
"locally on your machine."
144143
)
145144
num_edges = mapper_graph.number_of_edges()
146145
if num_edges > LIMITS_NUM_EDGES:
147-
logging.warn("Too many edges.")
146+
logging.warning("Too many edges.")
148147
raise ValueError(
149148
"Too many edges: select different parameters or run the app "
150149
"locally on your machine."
@@ -155,14 +154,14 @@ def _check_limits_dataset(df_X, df_y):
155154
if LIMITS_ENABLED:
156155
num_samples = len(df_X)
157156
if num_samples > LIMITS_NUM_SAMPLES:
158-
logging.warn("Dataset too big.")
157+
logging.warning("Dataset too big.")
159158
raise ValueError(
160159
"Dataset too big: select a different dataset or run the app "
161160
"locally on your machine."
162161
)
163162
num_features = len(df_X.columns) + len(df_y.columns)
164163
if num_features > LIMITS_NUM_FEATURES:
165-
logging.warn("Too many features.")
164+
logging.warning("Too many features.")
166165
raise ValueError(
167166
"Too many features: select a different dataset or run the app "
168167
"locally on your machine."
@@ -529,8 +528,8 @@ def mapper_input_section(X):
529528
mapper_algo = MapperAlgorithm(
530529
cover=cover,
531530
clustering=clustering,
532-
verbose=True,
533-
n_jobs=1,
531+
verbose=False,
532+
n_jobs=-2,
534533
)
535534
mapper_graph = compute_mapper(mapper_algo, X, lens)
536535
return mapper_graph
@@ -628,11 +627,12 @@ def compute_mapper_fig(mapper_plot, colors, node_size, cmap, _agg, agg_name):
628627
colors,
629628
node_size=node_size,
630629
agg=_agg,
631-
title=[f"{agg_name} of {c}" for c in colors.columns],
630+
title=[f"{c}" for c in colors.columns],
632631
cmap=cmap,
633632
width=600,
634633
height=600,
635634
)
635+
logger.info("Done")
636636
return mapper_fig
637637

638638

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ dev = [
4646
"coverage[toml]",
4747
"pandas",
4848
"scikit-learn<1.6.0",
49+
"pytest",
4950
"black[jupyter]",
5051
"isort",
5152
"flake8",

src/tdamapper/_plot_plotly.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def plot_plotly(
5757
mapper_plot,
5858
width: int,
5959
height: int,
60-
node_size: int = DEFAULT_NODE_SIZE,
60+
node_size: Optional[Union[int, List[int]]] = DEFAULT_NODE_SIZE,
6161
colors=None,
6262
title: Optional[Union[str, List[str]]] = None,
6363
agg=np.nanmean,
@@ -73,8 +73,9 @@ def plot_plotly(
7373
titles = [title for _ in range(colors_num)]
7474
elif isinstance(title, list) and len(title) == colors_num:
7575
titles = title
76-
fig = _figure(mapper_plot, width, height, node_size, colors, titles, agg, cmaps)
77-
_add_ui_to_layout(mapper_plot, fig, colors, titles, node_size, agg, cmaps)
76+
node_sizes = [node_size] if isinstance(node_size, int) else node_size
77+
fig = _figure(mapper_plot, width, height, node_sizes, colors, titles, agg, cmaps)
78+
_add_ui_to_layout(mapper_plot, fig, colors, titles, node_sizes, agg, cmaps)
7879
return fig
7980

8081

@@ -220,7 +221,7 @@ def _update_layout(fig, width, height):
220221
)
221222

222223

223-
def _figure(mapper_plot, width, height, node_size, colors, titles, agg, cmaps):
224+
def _figure(mapper_plot, width, height, node_sizes, colors, titles, agg, cmaps):
224225
node_pos = mapper_plot.positions
225226
node_pos_arr = _node_pos_array(
226227
mapper_plot.graph,
@@ -239,7 +240,7 @@ def _figure(mapper_plot, width, height, node_size, colors, titles, agg, cmaps):
239240

240241
_set_cmap(mapper_plot, fig, cmaps[0])
241242
_set_colors(mapper_plot, fig, colors[:, 0], agg)
242-
_set_node_size(mapper_plot, fig, node_size)
243+
_set_node_size(mapper_plot, fig, node_sizes[len(node_sizes) // 2])
243244
_set_title(mapper_plot, fig, titles[0])
244245

245246
return fig
@@ -387,7 +388,7 @@ def _layout(width, height):
387388
)
388389

389390

390-
def _add_ui_to_layout(mapper_plot, mapper_fig, colors, titles, node_size, agg, cmaps):
391+
def _add_ui_to_layout(mapper_plot, mapper_fig, colors, titles, node_sizes, agg, cmaps):
391392
cmaps_plotly = [PLOTLY_CMAPS.get(c.lower()) for c in cmaps]
392393
menu_color = _ui_color(mapper_plot, colors, titles, agg)
393394
if menu_color["buttons"]:
@@ -396,7 +397,7 @@ def _add_ui_to_layout(mapper_plot, mapper_fig, colors, titles, node_size, agg, c
396397
menu_color["x"] = -0.25
397398
menu_cmap = _ui_cmap(mapper_plot, cmaps_plotly)
398399
menu_cmap["x"] = menu_color["x"] + 0.25
399-
slider_size = _ui_node_size(mapper_plot, node_size)
400+
slider_size = _ui_node_size(mapper_plot, node_sizes)
400401
mapper_fig.update_layout(
401402
updatemenus=[menu_cmap, menu_color],
402403
sliders=[slider_size],
@@ -441,7 +442,7 @@ def _update_cmap(cmap):
441442
)
442443

443444

444-
def _ui_node_size(mapper_plot, node_size):
445+
def _ui_node_size(mapper_plot, node_sizes):
445446
steps = [
446447
dict(
447448
method="restyle",
@@ -451,7 +452,7 @@ def _ui_node_size(mapper_plot, node_size):
451452
[1],
452453
],
453454
)
454-
for size in [node_size * x / 10.0 for x in range(1, 20)]
455+
for size in node_sizes
455456
]
456457

457458
return dict(

src/tdamapper/utils/vptree_flat/ball_search.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,18 @@ class BallSearch:
55

66
def __init__(self, vpt, point, eps, inclusive=True):
77
self._arr = vpt._get_arr()
8-
self.__distance = vpt._get_distance()
9-
self.__point = point
10-
self.__eps = eps
11-
self.__inclusive = inclusive
8+
self._distance = vpt._get_distance()
9+
self._point = point
10+
self._eps = eps
11+
self._inclusive = inclusive
1212

1313
def search(self):
1414
return self._search_iter()
1515

1616
def _inside(self, dist):
17-
if self.__inclusive:
18-
return dist <= self.__eps
19-
return dist < self.__eps
17+
if self._inclusive:
18+
return dist <= self._eps
19+
return dist < self._eps
2020

2121
def _search_iter(self):
2222
stack = [(0, self._arr.size())]
@@ -28,11 +28,11 @@ def _search_iter(self):
2828
is_terminal = self._arr.is_terminal(start)
2929
if is_terminal:
3030
for x in self._arr.get_points(start, end):
31-
dist = self.__distance(self.__point, x)
31+
dist = self._distance(self._point, x)
3232
if self._inside(dist):
3333
result.append(x)
3434
else:
35-
dist = self.__distance(self.__point, v_point)
35+
dist = self._distance(self._point, v_point)
3636
mid = _mid(start, end)
3737
if self._inside(dist):
3838
result.append(v_point)
@@ -42,7 +42,7 @@ def _search_iter(self):
4242
else:
4343
fst = (mid, end)
4444
snd = (start + 1, mid)
45-
if abs(dist - v_radius) <= self.__eps:
45+
if abs(dist - v_radius) <= self._eps:
4646
stack.append(snd)
4747
stack.append(fst)
4848
return result

src/tdamapper/utils/vptree_flat/builder.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,22 @@ def _mid(start, end):
1212
class Builder:
1313

1414
def __init__(self, vpt, X):
15-
self.__distance = vpt._get_distance()
15+
self._distance = vpt._get_distance()
1616

1717
dataset = [x for x in X]
1818
indices = np.array([i for i in range(len(dataset))])
1919
distances = np.array([0.0 for _ in X])
2020
is_terminal = np.array([False for _ in X])
2121
self._arr = VPArray(dataset, distances, indices, is_terminal)
2222

23-
self.__leaf_capacity = vpt.get_leaf_capacity()
24-
self.__leaf_radius = vpt.get_leaf_radius()
23+
self._leaf_capacity = vpt.get_leaf_capacity()
24+
self._leaf_radius = vpt.get_leaf_radius()
2525
pivoting = vpt.get_pivoting()
26-
self.__pivoting = self._pivoting_disabled
26+
self._pivoting = self._pivoting_disabled
2727
if pivoting == "random":
28-
self.__pivoting = self._pivoting_random
28+
self._pivoting = self._pivoting_random
2929
elif pivoting == "furthest":
30-
self.__pivoting = self._pivoting_furthest
30+
self._pivoting = self._pivoting_furthest
3131

3232
def _pivoting_disabled(self, start, end):
3333
pass
@@ -45,7 +45,7 @@ def _furthest(self, start, end, i):
4545
i_point = self._arr.get_point(i)
4646
for j in range(start, end):
4747
j_point = self._arr.get_point(j)
48-
j_dist = self.__distance(i_point, j_point)
48+
j_dist = self._distance(i_point, j_point)
4949
if j_dist > furthest_dist:
5050
furthest = j
5151
furthest_dist = j_dist
@@ -61,12 +61,12 @@ def _pivoting_furthest(self, start, end):
6161
self._arr.swap(start, furthest)
6262

6363
def _update(self, start, end):
64-
self.__pivoting(start, end)
64+
self._pivoting(start, end)
6565
v_point = self._arr.get_point(start)
6666
is_terminal = self._arr.is_terminal(start)
6767
for i in range(start + 1, end):
6868
point = self._arr.get_point(i)
69-
self._arr.set_distance(i, self.__distance(v_point, point))
69+
self._arr.set_distance(i, self._distance(v_point, point))
7070
self._arr.set_terminal(i, is_terminal)
7171

7272
def build(self):
@@ -81,8 +81,8 @@ def _build_iter(self):
8181
self._update(start, end)
8282
self._arr.partition(start + 1, end, mid)
8383
v_radius = self._arr.get_distance(mid)
84-
if (end - start > 2 * self.__leaf_capacity) and (
85-
v_radius > self.__leaf_radius
84+
if (end - start > 2 * self._leaf_capacity) and (
85+
v_radius > self._leaf_radius
8686
):
8787
self._arr.set_distance(start, v_radius)
8888
self._arr.set_terminal(start, False)
Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,43 @@
11
from tdamapper.utils.heap import MaxHeap
22
from tdamapper.utils.vptree_flat.common import _mid
33

4+
_PRE = 0
5+
_POST = 1
6+
47

58
class KnnSearch:
69

710
def __init__(self, vpt, point, neighbors):
811
self._arr = vpt._get_arr()
9-
self.__distance = vpt._get_distance()
10-
self.__point = point
11-
self.__neighbors = neighbors
12-
self.__radius = float("inf")
13-
self.__result = MaxHeap()
12+
self._distance = vpt._get_distance()
13+
self._point = point
14+
self._neighbors = neighbors
15+
self._radius = float("inf")
16+
self._result = MaxHeap()
1417

1518
def _get_items(self):
16-
while len(self.__result) > self.__neighbors:
17-
self.__result.pop()
18-
return [x for (_, x) in self.__result]
19+
while len(self._result) > self._neighbors:
20+
self._result.pop()
21+
return [x for (_, x) in self._result]
1922

2023
def search(self):
2124
self._search_iter()
2225
return self._get_items()
2326

2427
def _process(self, x):
25-
dist = self.__distance(self.__point, x)
26-
if dist >= self.__radius:
28+
dist = self._distance(self._point, x)
29+
if dist >= self._radius:
2730
return dist
28-
self.__result.add(dist, x)
29-
while len(self.__result) > self.__neighbors:
30-
self.__result.pop()
31-
if len(self.__result) == self.__neighbors:
32-
self.__radius, _ = self.__result.top()
31+
self._result.add(dist, x)
32+
while len(self._result) > self._neighbors:
33+
self._result.pop()
34+
if len(self._result) == self._neighbors:
35+
self._radius, _ = self._result.top()
3336
return dist
3437

3538
def _search_iter(self):
36-
PRE, POST = 0, 1
37-
self.__result = MaxHeap()
38-
stack = [(0, self._arr.size(), 0.0, PRE)]
39+
self._result = MaxHeap()
40+
stack = [(0, self._arr.size(), 0.0, _PRE)]
3941
while stack:
4042
start, end, thr, action = stack.pop()
4143

@@ -47,7 +49,7 @@ def _search_iter(self):
4749
for x in self._arr.get_points(start, end):
4850
self._process(x)
4951
else:
50-
if action == PRE:
52+
if action == _PRE:
5153
mid = _mid(start, end)
5254
dist = self._process(v_point)
5355
if dist <= v_radius:
@@ -56,9 +58,9 @@ def _search_iter(self):
5658
else:
5759
fst_start, fst_end = mid, end
5860
snd_start, snd_end = start + 1, mid
59-
stack.append((snd_start, snd_end, abs(v_radius - dist), POST))
60-
stack.append((fst_start, fst_end, 0.0, PRE))
61-
elif action == POST:
62-
if self.__radius > thr:
63-
stack.append((start, end, 0.0, PRE))
61+
stack.append((snd_start, snd_end, abs(v_radius - dist), _POST))
62+
stack.append((fst_start, fst_end, 0.0, _PRE))
63+
elif action == _POST:
64+
if self._radius > thr:
65+
stack.append((start, end, 0.0, _PRE))
6466
return self._get_items()

0 commit comments

Comments
 (0)