|
@@ -1,7 +1,7 @@
 # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/API/load.ipynb.
 
 # %% auto 0
-__all__ = ['load']
+__all__ = ['load', 'prop_dataset']
 
 # %% ../nbs/API/load.ipynb 4
 def load(data, idx=None, x=None, y=None, paired=None, id_col=None,
@@ -77,3 +77,54 @@ def load(data, idx=None, x=None, y=None, paired=None, id_col=None,
     return Dabest(data, idx, x, y, paired, id_col, ci, resamples, random_seed, proportional, delta2, experiment, experiment_label, x1_level, mini_meta)
 
 
+
+# %% ../nbs/API/load.ipynb 5
+import numpy as np
+from typing import Union, Optional
+
+def prop_dataset(group: Union[list, tuple, np.ndarray, dict], # Accepts a dict, or a list, tuple, or numpy ndarray of numeric counts.
+                 group_names: Optional[list] = None):
+    '''
+    Convenience function to generate a dataframe of binary data.
+    '''
+    import pandas as pd
+
+    if isinstance(group, dict):
+        # If group_names is not provided, use the keys of the dict as group_names
+        if group_names is None:
+            group_names = list(group.keys())
+        elif not set(group_names) == set(group.keys()):
+            # Check that the provided group_names match the keys of the dict
+            raise ValueError('group_names must be the same as the keys of the dict.')
+        # Check that each value in the dict is a list, tuple, or numpy ndarray
+        if not all([isinstance(group[name], (list, tuple, np.ndarray)) for name in group_names]):
+            raise ValueError('group must be a dict of lists, tuples, or numpy ndarrays of numeric types.')
+        # Check that each key in the dict maps to exactly two elements
+        if not all([len(group[name]) == 2 for name in group_names]):
+            raise ValueError('Each key must map to exactly two elements.')
+        group_val = group
+
+    else:
+        if group_names is None:
+            raise ValueError('group_names must be provided if group is not a dict.')
+        # Check that the length of group is twice the length of group_names
+        if not len(group) == 2 * len(group_names):
+            raise ValueError('The length of group must be twice the length of group_names.')
+        group_val = {group_names[i]: [group[i*2], group[i*2+1]] for i in range(len(group_names))}
+
+    # Check that the values under each key sum to the same total
+    if not all([sum(group_val[name]) == sum(group_val[group_names[0]]) for name in group_val.keys()]):
+        raise ValueError('The sum of values under each key must be the same.')
+
+    id_col = pd.Series(range(1, sum(group_val[group_names[0]]) + 1))
+
+    final_df = pd.DataFrame()
+
+    for name in group_val.keys():
+        col = np.repeat(0, group_val[name][0]).tolist() + np.repeat(1, group_val[name][1]).tolist()
+        df = pd.DataFrame({name: col})
+        final_df = pd.concat([final_df, df], axis=1)
+
+    final_df['ID'] = id_col
+
+    return final_df
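
A minimal usage sketch for the new prop_dataset helper, assuming this diff is applied and the function is importable from the dabest package (adjust the import path to wherever the exported module lives); the group names and counts below are made-up illustration values, not taken from the source.

    from dabest import prop_dataset  # assumed import path, per the updated __all__

    # Dict form: each key maps to [count of 0s, count of 1s]; totals must match across groups.
    df1 = prop_dataset({'Control': [7, 3], 'Test': [2, 8]})

    # Flat-sequence form: consecutive pairs of counts, labelled via group_names.
    df2 = prop_dataset([7, 3, 2, 8], group_names=['Control', 'Test'])

    # Either call returns a DataFrame with one binary column per group plus an 'ID' column.
    print(df1.head())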