This is a utility package for JAX for handling batched data. Its features include:
- splitting pytrees into batches of fixed size
- recombining a pytree split into batches
- batched scan (scan over batches)
- batched vmap (scan over vmap)
- automatic handling of cases where the size of the data to be split into batches is not divisible by the batch size
  - either pad the data or make a separate, smaller last batch
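
The two remainder strategies in the list above can be illustrated without the library. The following is a minimal plain-JAX sketch for a single array, not batchix's implementation; the helper names `split_with_extra_last_batch` and `split_with_padding` are hypothetical and exist only for this illustration.

```python
import jax.numpy as jnp


def split_with_extra_last_batch(x, batch_size):
    """Split along axis 0 into full batches plus one smaller remainder batch."""
    n_full = x.shape[0] // batch_size
    full = x[: n_full * batch_size].reshape((n_full, batch_size) + x.shape[1:])
    remainder = x[n_full * batch_size:]  # empty if the size divides evenly
    return full, remainder


def split_with_padding(x, batch_size, pad_value=0):
    """Pad axis 0 up to a multiple of batch_size, then split.

    Also returns a boolean mask marking which entries are real data.
    """
    n = x.shape[0]
    n_batches = -(-n // batch_size)  # ceiling division
    n_pad = n_batches * batch_size - n
    pad_width = [(0, n_pad)] + [(0, 0)] * (x.ndim - 1)
    padded = jnp.pad(x, pad_width, constant_values=pad_value)
    batches = padded.reshape((n_batches, batch_size) + x.shape[1:])
    mask = (jnp.arange(n_batches * batch_size) < n).reshape(n_batches, batch_size)
    return batches, mask


x = jnp.arange(15)
full, remainder = split_with_extra_last_batch(x, batch_size=10)
batches, mask = split_with_padding(x, batch_size=10)
```

With 15 elements and a batch size of 10, the first helper yields one full batch and a remainder of 5 elements, while the second yields two batches of 10 whose mask marks the last 5 padded entries as invalid.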

```bash
pip install git+https://github.com/JeyRunner/batchix.git
```

Split array into batches and merge again:
```python
data_batched = pytree_split_in_batches(
    x=jnp.arange(15),
    batch_size=5,
)

# transform data_batched
# e.g. vmap/scan over data_batched to process each batch
y = jax.vmap(lambda batch: batch**2)(data_batched)

data_recombined = pytree_combine_batches(y)
```

Split array into batches and merge again, with support for data whose size is not divisible by the batch size:
```python
data_batched, data_batched_remainder = pytree_split_in_batches_with_remainder(
    x=jnp.arange(15),
    batch_size=10,
    # takes care of the data size not being divisible by the batch size:
    # data_batched_remainder is the last batch, which has a smaller batch size
    batch_remainder_strategy='ExtraLastBatch'
)

# transform data_batched, data_batched_remainder
# e.g. vmap/scan over data_batched and direct processing of the single last batch data_batched_remainder

data_recombined = pytree_combine_batches(data_batched, data_batched_remainder)
```

Split pytree into batches and scan over the batches:
```python
def process_batch(carry, x):
    # for the last batch x has shape (5,), otherwise (10,)
    return carry, x*2

carry, out = scan_batched(
    process_batch,
    x=jnp.arange(15),
    batch_size=10,
    # takes care of the data size not being divisible by the batch size:
    # makes a separate call to process_batch for the last remaining elements
    batch_remainder_stategy='ExtraLastBatch'
)
```

Split pytree into batches and scan over the batches, with manual padding handling in the scan body:
```python
def process_batch(carry, x, valid_x_mask, invalid_last_n_elements_in_x):
    # x always has shape (10,),
    # but for the last batch, the last n elements of x are invalid padded values
    y = x*2
    y = y[valid_x_mask]  # important: just use valid x elements
    return carry, y

carry, out = scan_batched(
    process_batch,
    x=jnp.arange(15),
    batch_size=10,
    # takes care of the data size not being divisible by the batch size:
    # makes a separate call to process_batch for the last remaining elements
    batch_remainder_stategy='PadAndExtraLastBatch'
)
```

When the data size is divisible by the batch size, both of the above examples also work fine.
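
The feature list describes batched vmap as a scan over vmap. As a rough plain-JAX sketch of that idea (not batchix's own function; `vmap_in_batches` is a hypothetical helper, it handles a single array rather than a general pytree, and it assumes the data size is divisible by the batch size), one can scan over the batches and apply `jax.vmap` within each batch:

```python
import jax
import jax.numpy as jnp


def vmap_in_batches(fn, x, batch_size):
    """Apply fn to each element of x: scan over batches, vmap within a batch.

    Hypothetical helper for illustration; assumes x.shape[0] % batch_size == 0.
    """
    n = x.shape[0]
    assert n % batch_size == 0, "sketch only covers the evenly divisible case"
    batches = x.reshape((n // batch_size, batch_size) + x.shape[1:])

    def body(carry, batch):
        # the carry is unused; each scan step processes one batch with vmap
        return carry, jax.vmap(fn)(batch)

    _, out = jax.lax.scan(body, None, batches)
    # merge the (num_batches, batch_size, ...) output back into (n, ...)
    return out.reshape((n,) + out.shape[2:])


y = vmap_in_batches(lambda v: v ** 2, jnp.arange(20), batch_size=5)
```

Compared with a single `jax.vmap` over all 20 elements, this trades some parallelism for lower peak memory, since intermediate values exist for only one batch at a time.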
Install deps:
```bash
pip install ".[dev,test]"
```