Skip to content

JeyRunner/batchix

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

22 Commits
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Batchix

This is a utility package for jax for handing batched data. Its features include:

  • splitting pytrees in to batches of fixed size
  • recombining a pytree split into batches
  • batched scan (scan over batches)
  • batched vmap (scan over vmap)
  • automatic handling of cases where the data to be split into batches is not dividable by the batch size
    • either pad or make a separate last smaller batch

Install

pip install git+https://github.com/JeyRunner/batchix.git

Usage

Split and Merge Batches

Split array into batches and merge again:

data_batched = pytree_split_in_batches( x=jnp.arange(15), batch_size=5, ) # transform data_batched # e.g vmap/scan over data_batched to process each batch y = jax.vmap(lambda batch: batch**2)(data_batched) data_recombined = pytree_combine_batches(y)

Split array into batches and merge again with support for data not being dividable by the batch size:

data_batched, data_batched_remainder = pytree_split_in_batches_with_remainder( x=jnp.arange(15), batch_size=10, # takes care of the data not being dividable by the batch size # data_batched_remainder is last batch that has smaller batch size batch_remainder_strategy='ExtraLastBatch' ) # transform data_batched, data_batched_remainder # e.g vmap/scan over data_batched and direct processing of last single batch data_batched_remainder data_recombined = pytree_combine_batches(data_batched, data_batched_remainder)

Scan over Batches

Split pytree in batches and scan over the batches:

def process_batch(carry, x): # for last batch x has shape (5,) otherwise (10,) return carry, x*2 carry, out = scan_batched( process_batch, x=jnp.arange(15), batch_size=10, # takes care of the data not being dividable by the batch size # makes separate call to process_batch for the last remaining elements batch_remainder_stategy='ExtraLastBatch' )

Split pytree in batches and scan over the batches with manual padding handling in scan body:

def process_batch(carry, x, valid_x_mask, invalid_last_n_elements_in_x): # x has allways shape (10,) # but for the last batch, last n elements or x are invalid padded values y = x*2 y = y[valid_x_mask] # important: just use valid x elements return carry, y carry, out = scan_batched( process_batch, x=jnp.arange(15), batch_size=10, # takes care of the data not being dividable by the batch size # makes separate call to process_batch for the last remaining elements batch_remainder_stategy='PadAndExtraLastBatch' )

When the data size is dividable by the batch size both of the above example will also work fine.

Dev setup

Install deps:

pip install .[dev, test]

About

Jax batching utility library

Topics

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages