The Mamba in the Llama: Distilling and Accelerating Hybrid Models

submitted by
Style Pass
2024-09-09 22:30:14

The evolution of large language models (LLMs) has been largely driven by the success of Transformer architectures. Despite their impressive capabilities, Transformers are inefficient on long sequences: self-attention has a compute cost that grows quadratically with sequence length, and inference carries heavy memory requirements. These challenges have spurred interest in alternative architectures that offer similar or even better performance with greater efficiency.

One promising direction is the use of linear Recurrent Neural Networks (linear RNNs), specifically the Mamba and Mamba2 architectures. Mamba and its variants have demonstrated performance competitive with Transformers while offering significant advantages in inference speed. Mamba can be trained in parallel and needs only constant memory during inference. But can we bridge the gap between these architectures and harness the strengths of both? The answer lies in distilling large-scale Transformer models into hybrid linear RNNs and accelerating their inference, combining the best of both worlds.

The self-attention mechanism is vital to Transformers, enabling models to weigh the importance of different tokens in a sequence. However, it comes at the cost of computational and memory inefficiencies, particularly for long sequences. For example, during inference, Transformers store the key and value vectors for every token they encounter in the KV-cache. For large models and long sequences, the KV-cache causes a large memory overhead: it slows inference down and occupies a lot of GPU memory. In contrast, linear RNNs like Mamba enjoy linear-time scaling during training and constant memory cost during inference, because the entire state is summarized in a fixed-size tensor. As a consequence, Mamba offers up to 5× higher throughput in inference tasks. It therefore makes sense to take a pretrained Transformer and distill its capabilities into a Mamba model, so as to exploit the inference efficiency of linear RNNs while preserving the generation quality of Transformer LLMs.
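To make the memory argument concrete, here is a rough back-of-the-envelope sketch in Python. The model dimensions (32 layers, 8 key/value heads of size 128, an inner width of 8192 with a state size of 16, 16-bit values) are illustrative assumptions, not the configuration of any model discussed here, and the recurrent-state formula is a simplification that ignores smaller buffers such as convolution state. The point is only the scaling: the KV-cache grows linearly with the number of tokens, while the recurrent state stays the same size.

```python
def kv_cache_bytes(seq_len, n_layers, n_kv_heads, head_dim,
                   bytes_per_elem=2, batch=1):
    """Approximate KV-cache size: one key and one value vector
    per token, per layer, per key/value head."""
    return 2 * batch * seq_len * n_layers * n_kv_heads * head_dim * bytes_per_elem


def recurrent_state_bytes(n_layers, d_inner, d_state,
                          bytes_per_elem=2, batch=1):
    """Approximate size of a fixed recurrent state (simplified):
    one d_inner x d_state tensor per layer, independent of sequence length."""
    return batch * n_layers * d_inner * d_state * bytes_per_elem


if __name__ == "__main__":
    # Hypothetical dimensions chosen for illustration only.
    layers, kv_heads, head_dim = 32, 8, 128
    d_inner, d_state = 8192, 16

    state = recurrent_state_bytes(layers, d_inner, d_state)
    for seq_len in (1_024, 8_192, 65_536):
        cache = kv_cache_bytes(seq_len, layers, kv_heads, head_dim)
        print(f"{seq_len:>6} tokens: KV-cache {cache / 2**20:8.1f} MiB, "
              f"recurrent state {state / 2**20:5.1f} MiB")
```

Under these assumed dimensions the cache costs 128 KiB per token, so it reaches 1 GiB at 8,192 tokens per sequence, while the recurrent state stays at 8 MiB regardless of length.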
