Skip to content

Commit babf285

Browse files
committed
fixed extract_data
1 parent 5a1eeed commit babf285

File tree

5 files changed

+175
-75
lines changed

5 files changed

+175
-75
lines changed

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# DABEST-Python
22

3-
43
<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->
54

65
[![minimal Python

dabest/multi.py

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -285,11 +285,11 @@ def _validate_individual_dabest_obj(self, dabest_obj, position: int):
285285

286286
def _extract_data(self) -> Tuple[List, List, List, List]:
287287
"""
288-
Extract bootstrap, effect sizes, and CI data.
288+
Extract bootstrap, effect sizes, CI low bounds and CI high bounds.
289289
Handles mixed contrast types for vortexmap.
290290
"""
291291
if self._bootstrap_data is not None:
292-
return self._bootstrap_data, self._effect_data, self._ci_data
292+
return self._bootstrap_data, self._effect_data, self._ci_lows, self._ci_highs
293293

294294
# Process effect size attribute name
295295
effect_attr = "hedges_g" if self.effect_size == 'delta_g' else self.effect_size
@@ -334,7 +334,8 @@ def _extract_data(self) -> Tuple[List, List, List, List]:
334334
# Cache results
335335
self._bootstrap_data = bootstraps
336336
self._effect_data = differences
337-
self._ci_data = (ci_lows, ci_highs)
337+
self._ci_lows = ci_lows
338+
self._ci_highs = ci_highs
338339

339340
return bootstraps, differences, ci_lows, ci_highs
340341

@@ -383,7 +384,7 @@ def confidence_intervals(self) -> Tuple[List, List]:
383384
_, _, ci_lows, ci_highs = self._extract_data()
384385
return ci_lows, ci_highs
385386

386-
def forest_plot(self, **kwargs):
387+
def forest_plot(self, **forest_plot_kwargs):
387388
"""
388389
Create forest plot using validated data.
389390
@@ -414,23 +415,22 @@ def forest_plot(self, **kwargs):
414415
'ci_type': self.ci_type,
415416
'labels': self.structure['col_labels'],
416417
}
417-
forest_kwargs.update(kwargs) # kwargs can override defaults
418+
forest_kwargs.update(forest_plot_kwargs) # kwargs can override defaults
418419

419420
# Call existing forest_plot with validated dabest objects
420421
return forest_plot(data=all_dabest_objs, **forest_kwargs)
421422

422-
def vortexmap(self, **kwargs):
423+
def vortexmap(self, **heatmap_kwargs):
423424
"""
424425
Create vortexmap using validated data.
425426
426-
This uses the enhanced vortexmap that can handle both homogeneous
427+
This uses the vortexmap that can handle both homogeneous
427428
and mixed contrast types.
428429
"""
429-
# Import here to avoid circular imports
430430
from .multi import vortexmap
431431

432-
# Call enhanced vortexmap with self as the multi_contrast object
433-
return vortexmap(multi_contrast=self, **kwargs)
432+
# Call vortexmap with self as the multi_contrast object
433+
return vortexmap(multi_contrast=self, **heatmap_kwargs)
434434
def get_bootstrap_by_position(self, row: int, col: int):
435435
"""
436436
Get bootstrap data for a specific position in the grid.
@@ -611,7 +611,7 @@ def _spiralize(fill, m, n):
611611

612612
# %% ../nbs/API/multi.ipynb 12
613613
def vortexmap(multi_contrast, n=21, sort_by=None, cmap = 'vlag', vmax = None, vmin = None,
614-
reverse_neg=True, abs_rank=False, chop_tail=0, ax=None,fig_size=None, **kwargs):
614+
reverse_neg=True, abs_rank=False, chop_tail=0, ax=None,fig_size=None, heatmap_kwargs=None):
615615
"""
616616
Create a vortexmap visualization of multiple contrasts.
617617
@@ -633,18 +633,44 @@ def vortexmap(multi_contrast, n=21, sort_by=None, cmap = 'vlag', vmax = None, vm
633633
Percentage of extreme values to exclude
634634
ax : matplotlib.Axes, optional
635635
Existing axes to plot on
636-
636+
fig_size : tuple, optional
637+
Figure size (width, height) in inches
638+
heatmap_kwargs : dict, optional
639+
Additional keyword arguments passed to sns.heatmap().
640+
Common options include:
641+
- 'cmap': colormap (overrides direct cmap parameter)
642+
- 'vmin', 'vmax': color scale limits (override direct parameters)
643+
- 'center': center value for colormap
644+
- 'annot': whether to annotate cells with values
645+
- 'fmt': format string for annotations
646+
- 'linewidths': width of lines between cells
647+
- 'linecolor': color of lines between cells
648+
- 'cbar': whether to show colorbar
649+
- 'cbar_kws': colorbar customization dict
650+
- 'square': whether to make cells square
651+
- 'xticklabels', 'yticklabels': tick label control
652+
- 'mask': boolean array to mask cells
653+
plot_kwargs : dict, optional
654+
Additional keyword arguments for plot styling and layout.
655+
Available options:
656+
- 'title': plot title
657+
- 'xlabel', 'ylabel': axis labels
658+
- 'xlabel_rotation', 'ylabel_rotation': label rotation angles
659+
- 'grid': whether to show grid
637660
Returns
638661
-------
639662
tuple
640663
(figure, axes, mean_delta_dataframe) if ax is None,
641664
else (axes, mean_delta_dataframe)
642665
"""
666+
667+
if heatmap_kwargs is None:
668+
heatmap_kwargs = {}
643669
structure = multi_contrast.structure
644670

645671
n_rows = structure['n_rows']
646672
n_cols = structure['n_cols']
647-
col_labels = structure['col_labels']
673+
col_labels = structure['col_labels']
648674
row_labels = structure['row_labels']
649675
was_1d = (structure['type'] == '1D')
650676

@@ -694,10 +720,13 @@ def vortexmap(multi_contrast, n=21, sort_by=None, cmap = 'vlag', vmax = None, vm
694720
else:
695721
cbar_orientation = 'vertical'
696722
cbar_location = 'right'
697-
723+
heatmap_kwargs.setdefault('cmap', cmap)
724+
heatmap_kwargs.setdefault('vmax', vmax)
725+
heatmap_kwargs.setdefault('vmin', vmin)
726+
heatmap_kwargs.setdefault('center', 0)
698727
# Create heatmap
699-
sns.heatmap(spirals, cmap=cmap, cbar_kws={"shrink": 1, "pad": .17, "orientation": cbar_orientation, "location": cbar_location},
700-
ax=a, center = 0, vmax=vmax, vmin=vmin, **kwargs)
728+
sns.heatmap(spirals, cbar_kws={"shrink": 1, "pad": .17, "orientation": cbar_orientation, "location": cbar_location},
729+
ax=a, **heatmap_kwargs)
701730

702731
# Set labels
703732
a.set_xticks(np.linspace(n/2, n_cols*n-n/2, n_cols))

nbs/API/multi.ipynb

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -351,11 +351,11 @@
351351
" \n",
352352
" def _extract_data(self) -> Tuple[List, List, List, List]:\n",
353353
" \"\"\"\n",
354-
" Extract bootstrap, effect sizes, and CI data.\n",
354+
" Extract bootstrap, effect sizes, CI low bounds and CI high bounds.\n",
355355
" Handles mixed contrast types for vortexmap.\n",
356356
" \"\"\"\n",
357357
" if self._bootstrap_data is not None:\n",
358-
" return self._bootstrap_data, self._effect_data, self._ci_data\n",
358+
" return self._bootstrap_data, self._effect_data, self._ci_lows, self._ci_highs\n",
359359
" \n",
360360
" # Process effect size attribute name\n",
361361
" effect_attr = \"hedges_g\" if self.effect_size == 'delta_g' else self.effect_size\n",
@@ -400,7 +400,8 @@
400400
" # Cache results\n",
401401
" self._bootstrap_data = bootstraps\n",
402402
" self._effect_data = differences\n",
403-
" self._ci_data = (ci_lows, ci_highs)\n",
403+
" self._ci_lows = ci_lows\n",
404+
" self._ci_highs = ci_highs\n",
404405
" \n",
405406
" return bootstraps, differences, ci_lows, ci_highs\n",
406407
" \n",
@@ -449,7 +450,7 @@
449450
" _, _, ci_lows, ci_highs = self._extract_data()\n",
450451
" return ci_lows, ci_highs\n",
451452
" \n",
452-
" def forest_plot(self, **kwargs):\n",
453+
" def forest_plot(self, **forest_plot_kwargs):\n",
453454
" \"\"\"\n",
454455
" Create forest plot using validated data.\n",
455456
" \n",
@@ -480,23 +481,22 @@
480481
" 'ci_type': self.ci_type,\n",
481482
" 'labels': self.structure['col_labels'],\n",
482483
" }\n",
483-
" forest_kwargs.update(kwargs) # kwargs can override defaults\n",
484+
" forest_kwargs.update(forest_plot_kwargs) # kwargs can override defaults\n",
484485
" \n",
485486
" # Call existing forest_plot with validated dabest objects\n",
486487
" return forest_plot(data=all_dabest_objs, **forest_kwargs)\n",
487488
"\n",
488-
" def vortexmap(self, **kwargs):\n",
489+
" def vortexmap(self, **heatmap_kwargs):\n",
489490
" \"\"\"\n",
490491
" Create vortexmap using validated data.\n",
491492
" \n",
492-
" This uses the enhanced vortexmap that can handle both homogeneous\n",
493+
" This uses the vortexmap that can handle both homogeneous\n",
493494
" and mixed contrast types.\n",
494495
" \"\"\"\n",
495-
" # Import here to avoid circular imports \n",
496496
" from .multi import vortexmap\n",
497497
" \n",
498-
" # Call enhanced vortexmap with self as the multi_contrast object\n",
499-
" return vortexmap(multi_contrast=self, **kwargs) \n",
498+
" # Call vortexmap with self as the multi_contrast object\n",
499+
" return vortexmap(multi_contrast=self, **heatmap_kwargs) \n",
500500
" def get_bootstrap_by_position(self, row: int, col: int):\n",
501501
" \"\"\"\n",
502502
" Get bootstrap data for a specific position in the grid.\n",
@@ -727,7 +727,7 @@
727727
"source": [
728728
"#| export\n",
729729
"def vortexmap(multi_contrast, n=21, sort_by=None, cmap = 'vlag', vmax = None, vmin = None, \n",
730-
" reverse_neg=True, abs_rank=False, chop_tail=0, ax=None,fig_size=None, **kwargs):\n",
730+
" reverse_neg=True, abs_rank=False, chop_tail=0, ax=None,fig_size=None, heatmap_kwargs=None):\n",
731731
" \"\"\"\n",
732732
" Create a vortexmap visualization of multiple contrasts.\n",
733733
" \n",
@@ -749,18 +749,44 @@
749749
" Percentage of extreme values to exclude\n",
750750
" ax : matplotlib.Axes, optional\n",
751751
" Existing axes to plot on\n",
752-
" \n",
752+
" fig_size : tuple, optional\n",
753+
" Figure size (width, height) in inches\n",
754+
" heatmap_kwargs : dict, optional\n",
755+
" Additional keyword arguments passed to sns.heatmap().\n",
756+
" Common options include:\n",
757+
" - 'cmap': colormap (overrides direct cmap parameter)\n",
758+
" - 'vmin', 'vmax': color scale limits (override direct parameters)\n",
759+
" - 'center': center value for colormap\n",
760+
" - 'annot': whether to annotate cells with values\n",
761+
" - 'fmt': format string for annotations\n",
762+
" - 'linewidths': width of lines between cells\n",
763+
" - 'linecolor': color of lines between cells\n",
764+
" - 'cbar': whether to show colorbar\n",
765+
" - 'cbar_kws': colorbar customization dict\n",
766+
" - 'square': whether to make cells square\n",
767+
" - 'xticklabels', 'yticklabels': tick label control\n",
768+
" - 'mask': boolean array to mask cells\n",
769+
" plot_kwargs : dict, optional\n",
770+
" Additional keyword arguments for plot styling and layout.\n",
771+
" Available options:\n",
772+
" - 'title': plot title\n",
773+
" - 'xlabel', 'ylabel': axis labels\n",
774+
" - 'xlabel_rotation', 'ylabel_rotation': label rotation angles\n",
775+
" - 'grid': whether to show grid \n",
753776
" Returns\n",
754777
" -------\n",
755778
" tuple\n",
756779
" (figure, axes, mean_delta_dataframe) if ax is None, \n",
757780
" else (axes, mean_delta_dataframe)\n",
758781
" \"\"\"\n",
782+
"\n",
783+
" if heatmap_kwargs is None:\n",
784+
" heatmap_kwargs = {}\n",
759785
" structure = multi_contrast.structure\n",
760786
"\n",
761787
" n_rows = structure['n_rows']\n",
762788
" n_cols = structure['n_cols']\n",
763-
" col_labels = structure['col_labels']\n",
789+
" col_labels = structure['col_labels'] \n",
764790
" row_labels = structure['row_labels']\n",
765791
" was_1d = (structure['type'] == '1D')\n",
766792
"\n",
@@ -810,10 +836,13 @@
810836
" else:\n",
811837
" cbar_orientation = 'vertical'\n",
812838
" cbar_location = 'right'\n",
813-
" \n",
839+
" heatmap_kwargs.setdefault('cmap', cmap)\n",
840+
" heatmap_kwargs.setdefault('vmax', vmax)\n",
841+
" heatmap_kwargs.setdefault('vmin', vmin)\n",
842+
" heatmap_kwargs.setdefault('center', 0)\n",
814843
" # Create heatmap\n",
815-
" sns.heatmap(spirals, cmap=cmap, cbar_kws={\"shrink\": 1, \"pad\": .17, \"orientation\": cbar_orientation, \"location\": cbar_location}, \n",
816-
" ax=a, center = 0, vmax=vmax, vmin=vmin, **kwargs)\n",
844+
" sns.heatmap(spirals, cbar_kws={\"shrink\": 1, \"pad\": .17, \"orientation\": cbar_orientation, \"location\": cbar_location}, \n",
845+
" ax=a, **heatmap_kwargs)\n",
817846
" \n",
818847
" # Set labels\n",
819848
" a.set_xticks(np.linspace(n/2, n_cols*n-n/2, n_cols))\n",

nbs/_quarto.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@ website:
1616
style: floating
1717
contents:
1818
- auto: "/0*.ipynb"
19-
- auto: "tutorials/0*.ipynb" # Autogenerate a section of tutorial notebooks
20-
- auto: "tutorials/1*.ipynb" # Autogenerate a section of tutorial notebooks
19+
- auto: "tutorials/[012]*.ipynb" # Autogenerate a section of tutorial notebooks
2120
- section: API
2221
contents: API/*
2322
favicon: images/Favicon-3-outline.svg

nbs/tutorials/10-multicontrast.ipynb

Lines changed: 84 additions & 40 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)