Skip to content

Commit a293e89

Browse files
committed
added vortexmap 2d functionality
Constructing MultiContrast class with a 2d array in the form of a nested list now works to produce a 2d vortexmap.
1 parent 9ca8aa1 commit a293e89

File tree

3 files changed

+344
-20
lines changed

3 files changed

+344
-20
lines changed

dabest/multi.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def _parse_contrast_structure(contrasts, labels=None):
9595
# Handle 2D labels
9696
if labels and isinstance(labels[0], (list, tuple)):
9797
row_labels = [labels[i][0] for i in range(n_rows)]
98-
col_labels = labels[0]
98+
col_labels = labels[0][1:]
9999
else:
100100
row_labels = [f"Row {i+1}" for i in range(n_rows)]
101101
col_labels = [f"Col {j+1}" for j in range(n_cols)]
@@ -188,8 +188,8 @@ def _spiralize(fill, m, n):
188188
return array
189189

190190
# %% ../nbs/API/multi.ipynb 12
191-
def vortexmap(multi_contrast, n=21, sort_by=None, vmax=3, vmin=-3,
192-
reverse_neg=True, abs_rank=False, chop_tail=0, ax=None, **kwargs):
191+
def vortexmap(multi_contrast, n=21, sort_by=None, cmap = 'vlag', vmax = None, vmin = None,
192+
reverse_neg=True, abs_rank=False, chop_tail=0, ax=None,fig_size=None, **kwargs):
193193
"""
194194
Create a vortexmap visualization of multiple contrasts.
195195
@@ -201,7 +201,7 @@ def vortexmap(multi_contrast, n=21, sort_by=None, vmax=3, vmin=-3,
201201
Size of each spiral (n x n grid per contrast)
202202
sort_by : list, optional
203203
Order to sort contrasts by
204-
vmax, vmin : float, default 3, -3
204+
vmax, vmin : float, default None, None
205205
Color scale limits
206206
reverse_neg : bool, default True
207207
Whether to reverse negative values
@@ -227,6 +227,7 @@ def vortexmap(multi_contrast, n=21, sort_by=None, vmax=3, vmin=-3,
227227
contrasts_2d = structure['contrasts_2d']
228228

229229
spirals = pd.DataFrame(np.zeros((n_rows * n, n_cols * n)))
230+
230231
mean_delta = pd.DataFrame(np.zeros((n_rows, n_cols)),
231232
columns=col_labels,
232233
index=row_labels)
@@ -251,19 +252,40 @@ def vortexmap(multi_contrast, n=21, sort_by=None, vmax=3, vmin=-3,
251252
f, a = plt.subplots(1, 1)
252253
else:
253254
a = ax
255+
if vmax is None:
256+
vmax = np.max(spirals.values)
257+
if vmin is None:
258+
vmin = np.min(spirals.values)
259+
if structure['was_1d']:
260+
cbar_orientation = 'horizontal'
261+
cbar_location = 'top'
262+
else:
263+
cbar_orientation = 'vertical'
264+
cbar_location = 'right'
254265

255-
sns.heatmap(spirals, cmap='vlag', cbar_kws={"shrink": 0.2, 'pad': .17},
256-
ax=a, vmax=vmax, vmin=vmin)
266+
# Create heatmap
267+
sns.heatmap(spirals, cmap=cmap, cbar_kws={"shrink": 1, "pad": .17, "orientation": cbar_orientation, "location": cbar_location},
268+
ax=a, center = 0, vmax=vmax, vmin=vmin, **kwargs)
257269

258270
# Set labels
259271
a.set_xticks(np.linspace(n/2, n_cols*n-n/2, n_cols))
260272
a.set_xticklabels(col_labels, rotation=45, ha='right')
261-
a.set_yticks(np.linspace(n/2, n_rows*n-n/2, n_rows))
262-
a.set_yticklabels(row_labels, ha='right', rotation=0)
263-
273+
274+
if structure['was_1d']:
275+
a.set_xlabel('Contrasts')
276+
a.set_ylabel(' ')
277+
a.set_yticks([])
278+
a.set_yticklabels([])
279+
else:
280+
a.set_yticks(np.linspace(n/2, n_rows*n-n/2, n_rows))
281+
a.set_yticklabels(row_labels, ha='right', rotation=0)
282+
264283
if ax is None:
265284
f.gca().set_aspect('equal')
266-
f.set_size_inches(n_cols/3, n_rows/3)
285+
if fig_size is None:
286+
f.set_size_inches(n_cols/3, n_rows/3)
287+
else:
288+
f.set_size_inches(fig_size)
267289
return f, a, mean_delta
268290
else:
269291
return a, mean_delta

nbs/API/multi.ipynb

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@
178178
" # Handle 2D labels\n",
179179
" if labels and isinstance(labels[0], (list, tuple)):\n",
180180
" row_labels = [labels[i][0] for i in range(n_rows)]\n",
181-
" col_labels = labels[0]\n",
181+
" col_labels = labels[0][1:]\n",
182182
" else:\n",
183183
" row_labels = [f\"Row {i+1}\" for i in range(n_rows)]\n",
184184
" col_labels = [f\"Col {j+1}\" for j in range(n_cols)]\n",
@@ -305,8 +305,8 @@
305305
"outputs": [],
306306
"source": [
307307
"#| export\n",
308-
"def vortexmap(multi_contrast, n=21, sort_by=None, vmax=3, vmin=-3, \n",
309-
" reverse_neg=True, abs_rank=False, chop_tail=0, ax=None, **kwargs):\n",
308+
"def vortexmap(multi_contrast, n=21, sort_by=None, cmap = 'vlag', vmax = None, vmin = None, \n",
309+
" reverse_neg=True, abs_rank=False, chop_tail=0, ax=None,fig_size=None, **kwargs):\n",
310310
" \"\"\"\n",
311311
" Create a vortexmap visualization of multiple contrasts.\n",
312312
" \n",
@@ -318,7 +318,7 @@
318318
" Size of each spiral (n x n grid per contrast)\n",
319319
" sort_by : list, optional\n",
320320
" Order to sort contrasts by\n",
321-
" vmax, vmin : float, default 3, -3\n",
321+
" vmax, vmin : float, default None, None\n",
322322
" Color scale limits\n",
323323
" reverse_neg : bool, default True\n",
324324
" Whether to reverse negative values\n",
@@ -344,6 +344,7 @@
344344
" contrasts_2d = structure['contrasts_2d']\n",
345345
"\n",
346346
" spirals = pd.DataFrame(np.zeros((n_rows * n, n_cols * n)))\n",
347+
" \n",
347348
" mean_delta = pd.DataFrame(np.zeros((n_rows, n_cols)), \n",
348349
" columns=col_labels, \n",
349350
" index=row_labels)\n",
@@ -368,19 +369,40 @@
368369
" f, a = plt.subplots(1, 1)\n",
369370
" else:\n",
370371
" a = ax\n",
372+
" if vmax is None:\n",
373+
" vmax = np.max(spirals.values)\n",
374+
" if vmin is None:\n",
375+
" vmin = np.min(spirals.values)\n",
376+
" if structure['was_1d']:\n",
377+
" cbar_orientation = 'horizontal'\n",
378+
" cbar_location = 'top'\n",
379+
" else:\n",
380+
" cbar_orientation = 'vertical'\n",
381+
" cbar_location = 'right'\n",
371382
" \n",
372-
" sns.heatmap(spirals, cmap='vlag', cbar_kws={\"shrink\": 0.2, 'pad': .17}, \n",
373-
" ax=a, vmax=vmax, vmin=vmin)\n",
383+
" # Create heatmap\n",
384+
" sns.heatmap(spirals, cmap=cmap, cbar_kws={\"shrink\": 1, \"pad\": .17, \"orientation\": cbar_orientation, \"location\": cbar_location}, \n",
385+
" ax=a, center = 0, vmax=vmax, vmin=vmin, **kwargs)\n",
374386
" \n",
375387
" # Set labels\n",
376388
" a.set_xticks(np.linspace(n/2, n_cols*n-n/2, n_cols))\n",
377389
" a.set_xticklabels(col_labels, rotation=45, ha='right')\n",
378-
" a.set_yticks(np.linspace(n/2, n_rows*n-n/2, n_rows))\n",
379-
" a.set_yticklabels(row_labels, ha='right', rotation=0)\n",
380-
" \n",
390+
"\n",
391+
" if structure['was_1d']:\n",
392+
" a.set_xlabel('Contrasts')\n",
393+
" a.set_ylabel(' ')\n",
394+
" a.set_yticks([])\n",
395+
" a.set_yticklabels([])\n",
396+
" else:\n",
397+
" a.set_yticks(np.linspace(n/2, n_rows*n-n/2, n_rows))\n",
398+
" a.set_yticklabels(row_labels, ha='right', rotation=0)\n",
399+
"\n",
381400
" if ax is None:\n",
382401
" f.gca().set_aspect('equal')\n",
383-
" f.set_size_inches(n_cols/3, n_rows/3)\n",
402+
" if fig_size is None:\n",
403+
" f.set_size_inches(n_cols/3, n_rows/3)\n",
404+
" else:\n",
405+
" f.set_size_inches(fig_size)\n",
384406
" return f, a, mean_delta\n",
385407
" else:\n",
386408
" return a, mean_delta\n",

0 commit comments

Comments
 (0)