4545import ot
4646import networkx
4747from networkx .generators .community import stochastic_block_model as sbm
48- # %%
49- # =============================================================================
48+
49+ #############################################################################
50+ #
5051# Generate a dataset composed of graphs following Stochastic Block models of 1, 2 and 3 clusters.
51- # =============================================================================
52+ # ---------------------------------------------
5253
5354np .random .seed (42 )
5455
@@ -109,10 +110,10 @@ def plot_graph(x, C, binary=True, color='C0', s=None):
109110pl .tight_layout ()
110111pl .show ()
111112
112- # %%
113- # =============================================================================
113+ #############################################################################
114+ #
114115# Estimate the gromov-wasserstein dictionary from the dataset
115- # =============================================================================
116+ # ---------------------------------------------
116117
117118
118119np .random .seed (0 )
@@ -140,10 +141,10 @@ def plot_graph(x, C, binary=True, color='C0', s=None):
140141pl .tight_layout ()
141142pl .show ()
142143
143- # %%
144- # =============================================================================
144+ #############################################################################
145+ #
145146# Visualization of the estimated dictionary atoms
146- # =============================================================================
147+ # ---------------------------------------------
147148
148149
149150# Continuous connections between nodes of the atoms are colored in shades of grey (1: dark / 2: white)
@@ -164,10 +165,11 @@ def plot_graph(x, C, binary=True, color='C0', s=None):
164165 pl .axis ("off" )
165166pl .tight_layout ()
166167pl .show ()
167- #%%
168- # =============================================================================
168+
169+ #############################################################################
170+ #
169171# Visualization of the embedding space
170- # =============================================================================
172+ # ---------------------------------------------
171173
172174unmixings = []
173175reconstruction_errors = []
@@ -211,11 +213,11 @@ def plot_graph(x, C, binary=True, color='C0', s=None):
211213pl .legend (fontsize = 11 )
212214pl .tight_layout ()
213215pl .show ()
214- # %%
215- # =============================================================================
216- # Endow the dataset with node features
217- # =============================================================================
218216
217+ #############################################################################
218+ #
219+ # Endow the dataset with node features
220+ # ---------------------------------------------
219221# We follow this feature assignment on all nodes of a graph depending on its label/number of clusters
220222# 1 cluster --> 0 as nodes feature
221223# 2 clusters --> 1 as nodes feature
@@ -251,10 +253,11 @@ def plot_graph(x, C, binary=True, color='C0', s=None):
251253 pl .axis ("off" )
252254pl .tight_layout ()
253255pl .show ()
254- # %%
255- # =============================================================================
256+
257+ #############################################################################
258+ #
256259# Estimate a Fused Gromov-Wasserstein dictionary from the dataset of attributed graphs
257- # =============================================================================
260+ # ---------------------------------------------
258261np .random .seed (0 )
259262ps = [ot .unif (C .shape [0 ]) for C in dataset ]
260263D = 3 # 6 atoms instead of 3
@@ -280,10 +283,10 @@ def plot_graph(x, C, binary=True, color='C0', s=None):
280283pl .tight_layout ()
281284pl .show ()
282285
283- # %%
284- # =============================================================================
286+ #############################################################################
287+ #
285288# Visualization of the estimated dictionary atoms
286- # =============================================================================
289+ # ---------------------------------------------
287290
288291pl .figure (7 , (12 , 8 ))
289292pl .clf ()
@@ -307,10 +310,10 @@ def plot_graph(x, C, binary=True, color='C0', s=None):
307310pl .tight_layout ()
308311pl .show ()
309312
310- # %%
311- # =============================================================================
313+ #############################################################################
314+ #
312315# Visualization of the embedding space
313- # =============================================================================
316+ # ---------------------------------------------
314317
315318unmixings = []
316319reconstruction_errors = []
0 commit comments