|
351 | 351 | " \n", |
352 | 352 | " def _extract_data(self) -> Tuple[List, List, List, List]:\n", |
353 | 353 | " \"\"\"\n", |
354 | | - " Extract bootstrap, effect sizes, and CI data.\n", |
| 354 | + " Extract bootstrap, effect sizes, CI low bounds and CI high bounds.\n", |
355 | 355 | " Handles mixed contrast types for vortexmap.\n", |
356 | 356 | " \"\"\"\n", |
357 | 357 | " if self._bootstrap_data is not None:\n", |
358 | | - " return self._bootstrap_data, self._effect_data, self._ci_data\n", |
| 358 | + " return self._bootstrap_data, self._effect_data, self._ci_lows, self._ci_highs\n", |
359 | 359 | " \n", |
360 | 360 | " # Process effect size attribute name\n", |
361 | 361 | " effect_attr = \"hedges_g\" if self.effect_size == 'delta_g' else self.effect_size\n", |
|
400 | 400 | " # Cache results\n", |
401 | 401 | " self._bootstrap_data = bootstraps\n", |
402 | 402 | " self._effect_data = differences\n", |
403 | | - " self._ci_data = (ci_lows, ci_highs)\n", |
| 403 | + " self._ci_lows = ci_lows\n", |
| 404 | + " self._ci_highs = ci_highs\n", |
404 | 405 | " \n", |
405 | 406 | " return bootstraps, differences, ci_lows, ci_highs\n", |
406 | 407 | " \n", |
|
449 | 450 | " _, _, ci_lows, ci_highs = self._extract_data()\n", |
450 | 451 | " return ci_lows, ci_highs\n", |
451 | 452 | " \n", |
452 | | - " def forest_plot(self, **kwargs):\n", |
| 453 | + " def forest_plot(self, **forest_plot_kwargs):\n", |
453 | 454 | " \"\"\"\n", |
454 | 455 | " Create forest plot using validated data.\n", |
455 | 456 | " \n", |
|
480 | 481 | " 'ci_type': self.ci_type,\n", |
481 | 482 | " 'labels': self.structure['col_labels'],\n", |
482 | 483 | " }\n", |
483 | | - " forest_kwargs.update(kwargs) # kwargs can override defaults\n", |
| 484 | + " forest_kwargs.update(forest_plot_kwargs) # kwargs can override defaults\n", |
484 | 485 | " \n", |
485 | 486 | " # Call existing forest_plot with validated dabest objects\n", |
486 | 487 | " return forest_plot(data=all_dabest_objs, **forest_kwargs)\n", |
487 | 488 | "\n", |
488 | | - " def vortexmap(self, **kwargs):\n", |
| 489 | + " def vortexmap(self, **heatmap_kwargs):\n", |
489 | 490 | " \"\"\"\n", |
490 | 491 | " Create vortexmap using validated data.\n", |
491 | 492 | " \n", |
492 | | - " This uses the enhanced vortexmap that can handle both homogeneous\n", |
| 493 | + " This uses the vortexmap that can handle both homogeneous\n", |
493 | 494 | " and mixed contrast types.\n", |
494 | 495 | " \"\"\"\n", |
495 | | - " # Import here to avoid circular imports \n", |
496 | 496 | " from .multi import vortexmap\n", |
497 | 497 | " \n", |
498 | | - " # Call enhanced vortexmap with self as the multi_contrast object\n", |
499 | | - " return vortexmap(multi_contrast=self, **kwargs) \n", |
| 498 | + " # Call vortexmap with self as the multi_contrast object\n", |
| 499 | + " return vortexmap(multi_contrast=self, **heatmap_kwargs) \n", |
500 | 500 | " def get_bootstrap_by_position(self, row: int, col: int):\n", |
501 | 501 | " \"\"\"\n", |
502 | 502 | " Get bootstrap data for a specific position in the grid.\n", |
|
727 | 727 | "source": [ |
728 | 728 | "#| export\n", |
729 | 729 | "def vortexmap(multi_contrast, n=21, sort_by=None, cmap = 'vlag', vmax = None, vmin = None, \n", |
730 | | - " reverse_neg=True, abs_rank=False, chop_tail=0, ax=None,fig_size=None, **kwargs):\n", |
| 730 | + " reverse_neg=True, abs_rank=False, chop_tail=0, ax=None,fig_size=None, heatmap_kwargs=None):\n", |
731 | 731 | " \"\"\"\n", |
732 | 732 | " Create a vortexmap visualization of multiple contrasts.\n", |
733 | 733 | " \n", |
|
749 | 749 | " Percentage of extreme values to exclude\n", |
750 | 750 | " ax : matplotlib.Axes, optional\n", |
751 | 751 | " Existing axes to plot on\n", |
752 | | - " \n", |
| 752 | + " fig_size : tuple, optional\n", |
| 753 | + " Figure size (width, height) in inches\n", |
| 754 | + " heatmap_kwargs : dict, optional\n", |
| 755 | + " Additional keyword arguments passed to sns.heatmap().\n", |
| 756 | + " Common options include:\n", |
| 757 | + " - 'cmap': colormap (overrides direct cmap parameter)\n", |
| 758 | + " - 'vmin', 'vmax': color scale limits (override direct parameters)\n", |
| 759 | + " - 'center': center value for colormap\n", |
| 760 | + " - 'annot': whether to annotate cells with values\n", |
| 761 | + " - 'fmt': format string for annotations\n", |
| 762 | + " - 'linewidths': width of lines between cells\n", |
| 763 | + " - 'linecolor': color of lines between cells\n", |
| 764 | + " - 'cbar': whether to show colorbar\n", |
| 765 | + " - 'cbar_kws': colorbar customization dict\n", |
| 766 | + " - 'square': whether to make cells square\n", |
| 767 | + " - 'xticklabels', 'yticklabels': tick label control\n", |
| 768 | + " - 'mask': boolean array to mask cells\n", |
| 769 | + " plot_kwargs : dict, optional\n", |
| 770 | + " Additional keyword arguments for plot styling and layout.\n", |
| 771 | + " Available options:\n", |
| 772 | + " - 'title': plot title\n", |
| 773 | + " - 'xlabel', 'ylabel': axis labels\n", |
| 774 | + " - 'xlabel_rotation', 'ylabel_rotation': label rotation angles\n", |
| 775 | + " - 'grid': whether to show grid \n", |
753 | 776 | " Returns\n", |
754 | 777 | " -------\n", |
755 | 778 | " tuple\n", |
756 | 779 | " (figure, axes, mean_delta_dataframe) if ax is None, \n", |
757 | 780 | " else (axes, mean_delta_dataframe)\n", |
758 | 781 | " \"\"\"\n", |
| 782 | + "\n", |
| 783 | + " if heatmap_kwargs is None:\n", |
| 784 | + " heatmap_kwargs = {}\n", |
759 | 785 | " structure = multi_contrast.structure\n", |
760 | 786 | "\n", |
761 | 787 | " n_rows = structure['n_rows']\n", |
762 | 788 | " n_cols = structure['n_cols']\n", |
763 | | - " col_labels = structure['col_labels']\n", |
| 789 | + " col_labels = structure['col_labels'] \n", |
764 | 790 | " row_labels = structure['row_labels']\n", |
765 | 791 | " was_1d = (structure['type'] == '1D')\n", |
766 | 792 | "\n", |
|
810 | 836 | " else:\n", |
811 | 837 | " cbar_orientation = 'vertical'\n", |
812 | 838 | " cbar_location = 'right'\n", |
813 | | - " \n", |
| 839 | + " heatmap_kwargs.setdefault('cmap', cmap)\n", |
| 840 | + " heatmap_kwargs.setdefault('vmax', vmax)\n", |
| 841 | + " heatmap_kwargs.setdefault('vmin', vmin)\n", |
| 842 | + " heatmap_kwargs.setdefault('center', 0)\n", |
814 | 843 | " # Create heatmap\n", |
815 | | - " sns.heatmap(spirals, cmap=cmap, cbar_kws={\"shrink\": 1, \"pad\": .17, \"orientation\": cbar_orientation, \"location\": cbar_location}, \n", |
816 | | - " ax=a, center = 0, vmax=vmax, vmin=vmin, **kwargs)\n", |
| 844 | + " sns.heatmap(spirals, cbar_kws={\"shrink\": 1, \"pad\": .17, \"orientation\": cbar_orientation, \"location\": cbar_location}, \n", |
| 845 | + " ax=a, **heatmap_kwargs)\n", |
817 | 846 | " \n", |
818 | 847 | " # Set labels\n", |
819 | 848 | " a.set_xticks(np.linspace(n/2, n_cols*n-n/2, n_cols))\n", |
|
0 commit comments