|
178 | 178 | " # Handle 2D labels\n", |
179 | 179 | " if labels and isinstance(labels[0], (list, tuple)):\n", |
180 | 180 | " row_labels = [labels[i][0] for i in range(n_rows)]\n", |
181 | | - " col_labels = labels[0]\n", |
| 181 | + " col_labels = labels[0][1:]\n", |
182 | 182 | " else:\n", |
183 | 183 | " row_labels = [f\"Row {i+1}\" for i in range(n_rows)]\n", |
184 | 184 | " col_labels = [f\"Col {j+1}\" for j in range(n_cols)]\n", |
|
305 | 305 | "outputs": [], |
306 | 306 | "source": [ |
307 | 307 | "#| export\n", |
308 | | - "def vortexmap(multi_contrast, n=21, sort_by=None, vmax=3, vmin=-3, \n", |
309 | | - " reverse_neg=True, abs_rank=False, chop_tail=0, ax=None, **kwargs):\n", |
| 308 | + "def vortexmap(multi_contrast, n=21, sort_by=None, cmap = 'vlag', vmax = None, vmin = None, \n", |
| 309 | + " reverse_neg=True, abs_rank=False, chop_tail=0, ax=None,fig_size=None, **kwargs):\n", |
310 | 310 | " \"\"\"\n", |
311 | 311 | " Create a vortexmap visualization of multiple contrasts.\n", |
312 | 312 | " \n", |
|
318 | 318 | " Size of each spiral (n x n grid per contrast)\n", |
319 | 319 | " sort_by : list, optional\n", |
320 | 320 | " Order to sort contrasts by\n", |
321 | | - " vmax, vmin : float, default 3, -3\n", |
| 321 | + " vmax, vmin : float, default None, None\n", |
322 | 322 | " Color scale limits\n", |
323 | 323 | " reverse_neg : bool, default True\n", |
324 | 324 | " Whether to reverse negative values\n", |
|
344 | 344 | " contrasts_2d = structure['contrasts_2d']\n", |
345 | 345 | "\n", |
346 | 346 | " spirals = pd.DataFrame(np.zeros((n_rows * n, n_cols * n)))\n", |
| 347 | + " \n", |
347 | 348 | " mean_delta = pd.DataFrame(np.zeros((n_rows, n_cols)), \n", |
348 | 349 | " columns=col_labels, \n", |
349 | 350 | " index=row_labels)\n", |
|
368 | 369 | " f, a = plt.subplots(1, 1)\n", |
369 | 370 | " else:\n", |
370 | 371 | " a = ax\n", |
| 372 | + " if vmax is None:\n", |
| 373 | + " vmax = np.max(spirals.values)\n", |
| 374 | + " if vmin is None:\n", |
| 375 | + " vmin = np.min(spirals.values)\n", |
| 376 | + " if structure['was_1d']:\n", |
| 377 | + " cbar_orientation = 'horizontal'\n", |
| 378 | + " cbar_location = 'top'\n", |
| 379 | + " else:\n", |
| 380 | + " cbar_orientation = 'vertical'\n", |
| 381 | + " cbar_location = 'right'\n", |
371 | 382 | " \n", |
372 | | - " sns.heatmap(spirals, cmap='vlag', cbar_kws={\"shrink\": 0.2, 'pad': .17}, \n", |
373 | | - " ax=a, vmax=vmax, vmin=vmin)\n", |
| 383 | + " # Create heatmap\n", |
| 384 | + " sns.heatmap(spirals, cmap=cmap, cbar_kws={\"shrink\": 1, \"pad\": .17, \"orientation\": cbar_orientation, \"location\": cbar_location}, \n", |
| 385 | + " ax=a, center = 0, vmax=vmax, vmin=vmin, **kwargs)\n", |
374 | 386 | " \n", |
375 | 387 | " # Set labels\n", |
376 | 388 | " a.set_xticks(np.linspace(n/2, n_cols*n-n/2, n_cols))\n", |
377 | 389 | " a.set_xticklabels(col_labels, rotation=45, ha='right')\n", |
378 | | - " a.set_yticks(np.linspace(n/2, n_rows*n-n/2, n_rows))\n", |
379 | | - " a.set_yticklabels(row_labels, ha='right', rotation=0)\n", |
380 | | - " \n", |
| 390 | + "\n", |
| 391 | + " if structure['was_1d']:\n", |
| 392 | + " a.set_xlabel('Contrasts')\n", |
| 393 | + " a.set_ylabel(' ')\n", |
| 394 | + " a.set_yticks([])\n", |
| 395 | + " a.set_yticklabels([])\n", |
| 396 | + " else:\n", |
| 397 | + " a.set_yticks(np.linspace(n/2, n_rows*n-n/2, n_rows))\n", |
| 398 | + " a.set_yticklabels(row_labels, ha='right', rotation=0)\n", |
| 399 | + "\n", |
381 | 400 | " if ax is None:\n", |
382 | 401 | " f.gca().set_aspect('equal')\n", |
383 | | - " f.set_size_inches(n_cols/3, n_rows/3)\n", |
| 402 | + " if fig_size is None:\n", |
| 403 | + " f.set_size_inches(n_cols/3, n_rows/3)\n", |
| 404 | + " else:\n", |
| 405 | + " f.set_size_inches(fig_size)\n", |
384 | 406 | " return f, a, mean_delta\n", |
385 | 407 | " else:\n", |
386 | 408 | " return a, mean_delta\n", |
|
0 commit comments