kingoflolz / mesh-transformer-jax

The parallelism scheme is similar to the original Megatron-LM, which is efficient on TPUs due to the high-speed 2D mesh network.
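As a rough illustration (not the library's actual implementation), here is a minimal sketch of Megatron-style tensor parallelism on a named 2D device mesh using JAX's sharding API; the axis names "dp"/"mp", the toy shapes, and the mesh factorization are assumptions for the example:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange the available devices into a 2D (data, model) mesh. On a TPU pod
# slice you would pick a factorization that maps onto the physical 2D torus,
# e.g. (8, 8) on a v3-64; here all devices go on the model axis for simplicity.
devices = np.array(jax.devices()).reshape(1, -1)
mesh = Mesh(devices, axis_names=("dp", "mp"))

# Shard a feedforward weight column-wise over the model axis (Megatron-style)
# and the activations over the data axis.
w = jax.device_put(jnp.zeros((4096, 16384)), NamedSharding(mesh, P(None, "mp")))
x = jax.device_put(jnp.ones((8, 4096)), NamedSharding(mesh, P("dp", None)))

@jax.jit
def ffn_first_matmul(x, w):
    # Each model-parallel shard computes its own slice of the output features;
    # XLA inserts any required collectives automatically.
    return x @ w

y = ffn_first_matmul(x, w)
print(y.sharding)  # output features remain sharded over the "mp" axis
```

On a 2D mesh, the second matmul of each feedforward block (and the attention output projection) would reduce the partial results across the "mp" axis, which is where the fast mesh interconnect matters.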

This library is designed for scalability up to approximately 20B parameters on TPUv3s, beyond which different parallelism strategies should be used. See other implementations such as GPT-NeoX or DeepSpeed for that.

One future direction for research is integrating this codebase with swarm-jax, to achieve further scalability with pipeline parallelism.

This project would not have been possible without compute generously provided by the TPU Research Cloud with assistance from EleutherAI.

The model consists of 28 layers with a model dimension of 4096 and a feedforward dimension of 16384. The model dimension is split into 16 heads, each with a dimension of 256. Rotary position embeddings (RoPE) are applied to 64 dimensions of each head. The model is trained with a tokenization vocabulary of 50257, using the same set of BPEs as GPT-2/GPT-3.
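For reference, a minimal sketch of these hyperparameters as a plain Python dict, together with a back-of-the-envelope parameter count; the key names and the untied input/output embedding assumption are mine, not the repo's config schema:

```python
cfg = dict(layers=28, d_model=4096, d_ff=16384, n_heads=16, d_head=256,
           rotary_dims=64, vocab=50257)

# Rough weight-only parameter count (biases and layer norms ignored).
attn  = 4 * cfg["d_model"] * cfg["n_heads"] * cfg["d_head"]  # q, k, v, out projections
mlp   = 2 * cfg["d_model"] * cfg["d_ff"]                     # up + down projections
embed = cfg["vocab"] * cfg["d_model"]

# Assuming an untied input embedding and output head (two vocab-sized matrices).
total = cfg["layers"] * (attn + mlp) + 2 * embed
print(f"~{total / 1e9:.2f}B parameters")  # ≈ 6.05B
```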

* represents evaluation numbers reported by their respective authors; all other numbers were obtained by running the lm-evaluation-harness either with the released weights or with API access. Due to subtle implementation differences as well as different zero-shot task framing, these might not be directly comparable. See this blog post for more details.
