Skip to content

Commit ae337d0

Browse files
committed
Improved docstrings
1 parent 867a761 commit ae337d0

File tree

1 file changed

+199
-14
lines changed

1 file changed

+199
-14
lines changed

src/tdamapper/app.py

Lines changed: 199 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
"""
2+
This module provides a web app for visualizing Mapper graphs.
3+
"""
4+
15
import logging
26
import os
37
from dataclasses import asdict, dataclass
@@ -31,6 +35,18 @@
3135

3236
LOGO_URL = f"{GIT_REPO_URL}/raw/main/docs/source/logos/tda-mapper-logo-horizontal.png"
3337

38+
ABOUT_TEXT = """
39+
### About
40+
41+
**tda-mapper** is a Python library built around the Mapper algorithm, a core
42+
technique in Topological Data Analysis (TDA) for extracting topological
43+
structure from complex data. Designed for computational efficiency and
44+
scalability, it leverages optimized spatial search methods to support
45+
high-dimensional datasets. You can find further details in the
46+
[documentation](https://tda-mapper.readthedocs.io/en/main/)
47+
and in the
48+
[paper](https://openreview.net/pdf?id=lTX4bYREAZ).
49+
"""
3450

3551
LOAD_EXAMPLE = "Example"
3652
LOAD_EXAMPLE_DIGITS = "Digits"
@@ -77,6 +93,32 @@
7793

7894
@dataclass
7995
class MapperConfig:
96+
"""
97+
Configuration for the Mapper algorithm.
98+
99+
:param lens_type: Type of lens to use for dimensionality reduction.
100+
:param cover_scale_data: Whether to scale the data before covering.
101+
:param cover_type: Type of cover to use for the Mapper algorithm.
102+
:param clustering_scale_data: Whether to scale the data before clustering.
103+
:param clustering_type: Type of clustering algorithm to use.
104+
:param lens_pca_n_components: Number of components for PCA lens.
105+
:param lens_umap_n_components: Number of components for UMAP lens.
106+
:param cover_cubical_n_intervals: Number of intervals for cubical cover.
107+
:param cover_cubical_overlap_frac: Overlap fraction for cubical cover.
108+
:param cover_ball_radius: Radius for ball cover.
109+
:param cover_knn_neighbors: Number of neighbors for KNN cover.
110+
:param clustering_kmeans_n_clusters: Number of clusters for KMeans
111+
clustering.
112+
:param clustering_dbscan_eps: Epsilon parameter for DBSCAN clustering.
113+
:param clustering_dbscan_min_samples: Minimum samples for DBSCAN
114+
clustering.
115+
:param clustering_agglomerative_n_clusters: Number of clusters for
116+
Agglomerative clustering.
117+
:param plot_dimensions: Number of dimensions for the plot (2D or 3D).
118+
:param plot_iterations: Number of iterations for the plot.
119+
:param plot_seed: Random seed for reproducibility.
120+
"""
121+
80122
lens_type: str = LENS_PCA
81123
cover_scale_data: bool = COVER_SCALE_DATA
82124
cover_type: str = COVER_CUBICAL
@@ -98,6 +140,14 @@ class MapperConfig:
98140

99141

100142
def fix_data(data: pd.DataFrame) -> pd.DataFrame:
143+
"""
144+
Fixes the input data by selecting numeric columns, dropping empty columns,
145+
and filling NaN values with the mean of each column.
146+
147+
:param data: Input DataFrame to be fixed.
148+
:return: Fixed DataFrame with numeric columns, no empty columns, and NaN
149+
values filled with column means.
150+
"""
101151
df = pd.DataFrame(data)
102152
df = df.select_dtypes(include="number")
103153
df.dropna(axis=1, how="all", inplace=True)
@@ -106,10 +156,25 @@ def fix_data(data: pd.DataFrame) -> pd.DataFrame:
106156

107157

108158
def lens_identity(X: NDArray[np.float_]) -> NDArray[np.float_]:
159+
"""
160+
Identity lens function that returns the input data as is.
161+
162+
:param X: Input data as a NumPy array.
163+
:return: The same input data as a NumPy array.
164+
"""
109165
return X
110166

111167

112168
def lens_pca(n_components: int) -> Callable[[NDArray[np.float_]], NDArray[np.float_]]:
169+
"""
170+
Creates a lens function that reduces the dimensionality of the input data.
171+
This function applies PCA to the input data and returns the transformed
172+
data.
173+
174+
:param n_components: Number of components to keep after PCA.
175+
:return: A function that applies PCA to the input data and returns the
176+
transformed data.
177+
"""
113178

114179
def _pca(X: NDArray[np.float_]) -> NDArray[np.float_]:
115180
pca_model = PCA(n_components=n_components, random_state=RANDOM_SEED)
@@ -119,6 +184,15 @@ def _pca(X: NDArray[np.float_]) -> NDArray[np.float_]:
119184

120185

121186
def lens_umap(n_components: int) -> Callable[[NDArray[np.float_]], NDArray[np.float_]]:
187+
"""
188+
Creates a lens function that reduces the dimensionality of the input data.
189+
This function applies UMAP to the input data and returns the transformed
190+
data.
191+
192+
:param n_components: Number of components to keep after UMAP.
193+
:return: A function that applies UMAP to the input data and returns the
194+
transformed data.
195+
"""
122196

123197
def _umap(X: NDArray[np.float_]) -> NDArray[np.float_]:
124198
um = UMAP(n_components=n_components, random_state=RANDOM_SEED)
@@ -130,6 +204,15 @@ def _umap(X: NDArray[np.float_]) -> NDArray[np.float_]:
130204
def run_mapper(
131205
df: pd.DataFrame, **kwargs: dict[str, Any]
132206
) -> Optional[tuple[nx.Graph, pd.DataFrame]]:
207+
"""
208+
Runs the Mapper algorithm on the provided DataFrame and returns the Mapper
209+
graph and the transformed DataFrame.
210+
211+
:param df: Input DataFrame containing the data to be processed.
212+
:param kwargs: Additional parameters for the Mapper configuration.
213+
:return: A tuple containing the Mapper graph and the transformed DataFrame,
214+
or None if the computation fails.
215+
"""
133216
logger.info("Mapper computation started...")
134217
if df is None or df.empty:
135218
error = "Mapper computation failed: no data found, please load data first."
@@ -220,6 +303,16 @@ def create_mapper_figure(
220303
mapper_graph: nx.Graph,
221304
**kwargs: dict[str, Any],
222305
) -> go.Figure:
306+
"""
307+
Renders the Mapper graph as a Plotly figure.
308+
309+
:param df_X: DataFrame containing the input data.
310+
:param df_y: DataFrame containing the lens-transformed data.
311+
:param df_target: DataFrame containing the target labels.
312+
:param mapper_graph: The Mapper graph to be visualized.
313+
:param kwargs: Additional parameters for the Mapper configuration.
314+
:return: A Plotly figure representing the Mapper graph.
315+
"""
223316
logger.info("Mapper rendering started...")
224317
df_colors = pd.concat([df_target, df_y, df_X], axis=1)
225318
params: dict[str, Any] = kwargs
@@ -257,6 +350,39 @@ def create_mapper_figure(
257350

258351

259352
class App:
353+
"""
354+
Main application class for the Mapper web application.
355+
356+
This class initializes the user interface, handles data loading, and runs
357+
the Mapper algorithm.
358+
359+
:param storage: Dictionary to store application state and data.
360+
:param draw_area: Optional draw area for rendering Mapper graphs.
361+
:param plot_container: Container for the plot area.
362+
:param left_drawer: Drawer for the left sidebar containing controls and
363+
settings.
364+
:param lens_type: Type of lens to use for dimensionality reduction.
365+
:param cover_type: Type of cover to use for the Mapper algorithm.
366+
:param clustering_type: Type of clustering algorithm to use.
367+
:param lens_pca_n_components: Number of components for PCA lens.
368+
:param lens_umap_n_components: Number of components for UMAP lens.
369+
:param cover_cubical_n_intervals: Number of intervals for cubical cover.
370+
:param cover_cubical_overlap_frac: Overlap fraction for cubical cover.
371+
:param cover_ball_radius: Radius for ball cover.
372+
:param cover_knn_neighbors: Number of neighbors for KNN cover.
373+
:param clustering_kmeans_n_clusters: Number of clusters for KMeans
374+
clustering.
375+
:param clustering_dbscan_eps: Epsilon parameter for DBSCAN clustering.
376+
:param clustering_dbscan_min_samples: Minimum samples for DBSCAN
377+
clustering.
378+
:param clustering_agglomerative_n_clusters: Number of clusters for
379+
Agglomerative clustering.
380+
:param plot_dimensions: Number of dimensions for the plot (2D or 3D).
381+
:param plot_iterations: Number of iterations for the plot.
382+
:param plot_seed: Random seed for reproducibility.
383+
:param load_type: Type of data loading (example or CSV).
384+
:param load_example: Example dataset to load if using example data.
385+
"""
260386

261387
lens_type: Any
262388
cover_type: Any
@@ -341,20 +467,7 @@ def __init__(self, storage: dict[str, Any]) -> None:
341467

342468
def _init_about(self) -> None:
343469
with ui.dialog() as dialog, ui.card():
344-
ui.markdown(
345-
"""
346-
### About
347-
348-
**tda-mapper** is a Python library built around the Mapper algorithm, a core
349-
technique in Topological Data Analysis (TDA) for extracting topological
350-
structure from complex data. Designed for computational efficiency and
351-
scalability, it leverages optimized spatial search methods to support
352-
high-dimensional datasets. You can find further details in the
353-
[documentation](https://tda-mapper.readthedocs.io/en/main/)
354-
and in the
355-
[paper](https://openreview.net/pdf?id=lTX4bYREAZ).
356-
"""
357-
)
470+
ui.markdown(ABOUT_TEXT)
358471
ui.link(
359472
text="If you like this project, please consider giving it a ⭐ on GitHub!",
360473
target=GIT_REPO_URL,
@@ -596,6 +709,11 @@ def _toggle_drawer() -> None:
596709
).props("fab color=themedark")
597710

598711
def get_mapper_config(self) -> MapperConfig:
712+
"""
713+
Retrieves the current configuration settings for the Mapper algorithm.
714+
715+
:return: A MapperConfig object containing the current settings.
716+
"""
599717
plot_dim = int(self.plot_dimensions.value)
600718
plot_dimensions: Literal[2, 3]
601719
if plot_dim == 2:
@@ -676,6 +794,13 @@ def get_mapper_config(self) -> MapperConfig:
676794
)
677795

678796
def upload_file(self, file: Any) -> None:
797+
"""
798+
Handles the file upload event, reads the CSV file,
799+
and stores the data in the application storage.
800+
801+
:param file: The uploaded file object.
802+
:return: None
803+
"""
679804
if file is not None:
680805
df = pd.read_csv(file.content)
681806
self.storage["df"] = fix_data(df)
@@ -689,6 +814,15 @@ def upload_file(self, file: Any) -> None:
689814
ui.notify(error, type="warning")
690815

691816
def load_data(self) -> None:
817+
"""
818+
Loads example datasets or CSV files based on the selected load type.
819+
820+
If the load type is set to "Example", it loads either the Digits or
821+
Iris dataset. If the load type is set to "CSV", it checks if a
822+
DataFrame is already stored in the application storage and uses it.
823+
824+
:return: None
825+
"""
692826
if self.load_type.value == LOAD_EXAMPLE:
693827
if self.load_example.value == LOAD_EXAMPLE_DIGITS:
694828
df, labels = load_digits(as_frame=True, return_X_y=True)
@@ -724,6 +858,14 @@ def load_data(self) -> None:
724858
ui.notify(error, type="warning")
725859

726860
def notification_running_start(self, message: str) -> Any:
861+
"""
862+
Starts a notification to indicate that a long-running operation is in
863+
progress.
864+
865+
:param message: The message to display in the notification.
866+
:return: A notification object that can be used to update the message
867+
and status.
868+
"""
727869
notification = ui.notification(timeout=None, type="ongoing")
728870
notification.message = message
729871
notification.spinner = True
@@ -732,13 +874,31 @@ def notification_running_start(self, message: str) -> Any:
732874
def notification_running_stop(
733875
self, notification: Any, message: str, type: Optional[str] = None
734876
) -> None:
877+
"""
878+
Stops the notification and updates it with the final message and type.
879+
880+
:param notification: The notification object to update.
881+
:param message: The final message to display in the notification.
882+
:param type: The type of notification.
883+
:return: None
884+
"""
735885
if type is not None:
736886
notification.type = type
737887
notification.message = message
738888
notification.timeout = 5.0
739889
notification.spinner = False
740890

741891
async def async_run_mapper(self) -> None:
892+
"""
893+
Runs the Mapper algorithm on the loaded data and updates the storage
894+
with the Mapper graph and transformed DataFrame.
895+
896+
This method retrieves the input DataFrame from storage, applies the
897+
Mapper algorithm, and stores the resulting Mapper graph and transformed
898+
DataFrame back into storage.
899+
900+
:return: None
901+
"""
742902
notification = self.notification_running_start("Running Mapper...")
743903
df_X = self.storage.get("df", pd.DataFrame())
744904
if df_X is None or df_X.empty:
@@ -764,6 +924,15 @@ async def async_run_mapper(self) -> None:
764924
await self.async_draw_mapper()
765925

766926
async def async_draw_mapper(self) -> None:
927+
"""
928+
Draws the Mapper graph using the stored graph and input data.
929+
930+
This method retrieves the Mapper graph and input DataFrame from
931+
storage, creates a Plotly figure representing the Mapper graph, and
932+
updates the draw area in the user interface with the new figure.
933+
934+
:return: None
935+
"""
767936
notification = self.notification_running_start("Drawing Mapper...")
768937

769938
mapper_config = self.get_mapper_config()
@@ -808,14 +977,30 @@ async def async_draw_mapper(self) -> None:
808977

809978

810979
def startup() -> None:
980+
"""
981+
Initializes the NiceGUI app and sets up the main page.
982+
983+
:return: None
984+
"""
985+
811986
@ui.page("/")
812987
def main_page() -> None:
988+
"""
989+
Main page of the application.
990+
991+
:return: None
992+
"""
813993
ui.query(".nicegui-content").classes("p-0")
814994
storage = app.storage.client
815995
App(storage=storage)
816996

817997

818998
def main() -> None:
999+
"""
1000+
Main entry point for the Mapper web application.
1001+
1002+
:return: None
1003+
"""
8191004
port = os.getenv("PORT", "8080")
8201005
host = os.getenv("HOST", "0.0.0.0")
8211006
production = os.getenv("PRODUCTION", "false").lower() == "true"

0 commit comments

Comments
 (0)