N.B.: Code is implemented pedagogically, at the expense of some repetition. Each model is purposefully self-contained in a single file, with no inter-file dependencies.




Developing and training transformer-based models is typically resource-intensive and time-consuming, and AI/ML experts frequently need to build smaller-scale versions of these models for specific problems. JAX, a low-resource yet powerful framework, accelerates the development of neural networks, but existing resources for transformer development in JAX are limited. NanoDL addresses this challenge with the following features:

There are experimental features (like the MAMBA architecture and RLHF) in the repo that are not yet available via the package, pending tests. Feedback on any of our discussion, issue, and pull request threads is welcome! Please report any feature requests, issues, questions, or concerns on Discord, or just let us know what you're working on!

You will need Python 3.9 or later and working installations of JAX, Flax, and Optax (with GPU support if you want to run training; without it, you can only create models). Models can be designed and tested on CPUs, but the trainers are all distributed data-parallel and require 1 to N GPUs/TPUs. For the CPU-only version of JAX:
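The install commands did not survive in this snippet; as a minimal sketch, a standard CPU-only setup (JAX plus Flax and Optax from PyPI) would look something like:

pip install --upgrade pip
pip install --upgrade jax flax optax   # CPU-only JAX, plus Flax and Optax

For GPU or TPU training, install the JAX build that matches your CUDA/TPU setup (per the official JAX installation instructions) before installing Flax and Optax.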
