Scalax is a collection of utilities that help developers easily scale up JAX-based machine learning models. The main idea of scalax is simple: users write model and training code for a single GPU/TPU, and rely on scalax to automatically scale it up to hundreds of GPUs/TPUs. This is made possible by the JAX JIT compiler; scalax provides a set of utilities that help users obtain the sharding annotations the JIT compiler requires. Because scalax wraps around the JIT compiler, existing JAX code can be scaled up with scalax with minimal changes.
We are running an unofficial Discord community (unaffiliated with Google) for discussions related to training large models in JAX. Follow this link to join the Discord server. We have a dedicated channel for scalax.
This works fine for a single GPU/TPU, but to scale up to multiple GPUs/TPUs, we need to partition the data or the model in order to parallelize training across devices. Fortunately, JAX JIT already provides a way to handle these partitions through sharding annotations. For example, if we have sharding annotations for the train_state and batch pytrees, we can simply JIT compile the train_step function with these sharding annotations:
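As a rough illustration of what this looks like in plain JAX (not the scalax API), the sketch below compiles a toy train_step with jax.jit, passing explicit in_shardings and out_shardings. The linear-regression train_step, the dict layout of train_state and batch, and the single mesh axis named "dp" are all hypothetical choices made for this example. The train state is replicated on every device, while the batch is split along its leading (example) dimension.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


# Hypothetical single-device train step: linear regression with plain SGD.
def train_step(train_state, batch):
    def loss_fn(params):
        preds = batch["x"] @ params["w"] + params["b"]
        return jnp.mean((preds - batch["y"]) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(train_state["params"])
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - 0.1 * g, train_state["params"], grads
    )
    return {"params": new_params, "step": train_state["step"] + 1}, loss


# A 1-D device mesh over all available devices, with one hypothetical
# data-parallel axis named "dp".
mesh = Mesh(np.array(jax.devices()), axis_names=("dp",))

# Fully replicated: every device holds a complete copy.
replicated = NamedSharding(mesh, P())
# Leading (batch) dimension split across the "dp" axis.
data_sharded = NamedSharding(mesh, P("dp"))

# Sharding annotations mirroring the structure of each pytree.
train_state_sharding = {
    "params": {"w": replicated, "b": replicated},
    "step": replicated,
}
batch_sharding = {"x": data_sharded, "y": data_sharded}

# JIT compile the unchanged train_step with the sharding annotations;
# the compiler inserts whatever cross-device communication is needed.
sharded_train_step = jax.jit(
    train_step,
    in_shardings=(train_state_sharding, batch_sharding),
    out_shardings=(train_state_sharding, replicated),
)

# Example usage with a toy train state and batch.
num_examples = 4 * jax.device_count()  # divisible by the size of the "dp" axis
train_state = {
    "params": {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))},
    "step": jnp.array(0),
}
batch = {"x": jnp.ones((num_examples, 3)), "y": jnp.ones((num_examples, 1))}
train_state, loss = sharded_train_step(train_state, batch)
```

Note that train_step itself is identical to what one would write for a single device; only the jax.jit call carries the partitioning information. Producing these annotation pytrees by hand becomes tedious for large models, which is the kind of work scalax aims to automate.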