
Calculating GPT-2’s Inference Speedups


I was recently re-reading Finbarr Timbers’ post on transformer inference optimizations, and I wanted to try implementing each of these techniques¹ in nanoGPT to see how much we could practically speed up inference with GPT-2’s architecture and reduce the computational bottlenecks in the model.

For my benchmark model, I'm going to use the GPT2-XL model weights and load them into nanoGPT. This will let us directly change the modeling code for our optimization tests. All tests will be run on an A100 80GB GPU, and below are the sampling parameters for our experiment:
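A rough sketch of that setup, using nanoGPT’s GPT.from_pretrained and generate APIs, might look like the following. The sampling values here are placeholders for illustration, not necessarily the exact settings behind the benchmark numbers below:

import time
import torch
from model import GPT  # nanoGPT's model.py

# placeholder sampling parameters -- illustrative, not the post's exact settings
batch_size = 1
max_new_tokens = 256
temperature = 0.8
top_k = 200

device = 'cuda'
model = GPT.from_pretrained('gpt2-xl')  # load the GPT2-XL weights into nanoGPT
model.eval().to(device)

# start from a trivial one-token prompt per sequence
idx = torch.zeros((batch_size, 1), dtype=torch.long, device=device)

torch.cuda.synchronize()
t0 = time.time()
out = model.generate(idx, max_new_tokens, temperature=temperature, top_k=top_k)
torch.cuda.synchronize()
dt = time.time() - t0
print(f"{batch_size * max_new_tokens / dt:.1f} tokens/sec")

Timing the generate call this way, with a CUDA synchronize before reading the clock, is one simple way to arrive at tokens-per-second figures like the ones quoted below.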

Our naive model is only generating 62 tokens per second 🙁. To speed up inference, we generally want to focus on the most compute-intensive parts of the model, and one optimization is to add a KV cache that stores our key and value projections in memory, which saves us from recomputing the projections for past tokens in every forward pass of the attention layer.

In very simple terms, within the attention block the query is the token we are currently looking at, the keys are the previous context we want to attend to, and the values are what get combined into a weighted sum over that context². Our attention blocks waste compute by recalculating the q, k, and v projections for every token each time we pass a new one in. When generating the current token, the model needs the N previous tokens as context for the attention formula. This previous context is only needed by attention; the other parts of the model (MLP, embedding layers, classification head) don’t need it, since they operate on each token position independently.
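To make the idea concrete, here is a minimal sketch of a causal self-attention layer with a KV cache, written against PyTorch’s scaled_dot_product_attention. This is not nanoGPT’s actual modeling code; the class, its cache handling, and the use_cache flag are simplified assumptions meant to show the mechanism:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CachedCausalSelfAttention(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.n_embd = n_embd
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)  # fused q, k, v projection
        self.c_proj = nn.Linear(n_embd, n_embd)
        # cached keys/values, shape (B, n_head, T_past, head_dim); reset between prompts
        self.cache_k = None
        self.cache_v = None

    def forward(self, x, use_cache=False):
        B, T, C = x.shape  # during cached decoding, T == 1 after the prompt
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        hd = C // self.n_head
        q = q.view(B, T, self.n_head, hd).transpose(1, 2)
        k = k.view(B, T, self.n_head, hd).transpose(1, 2)
        v = v.view(B, T, self.n_head, hd).transpose(1, 2)

        if use_cache:
            if self.cache_k is not None:
                # reuse the projections of past tokens instead of recomputing them
                k = torch.cat([self.cache_k, k], dim=2)
                v = torch.cat([self.cache_v, v], dim=2)
            self.cache_k, self.cache_v = k, v

        # during single-token decoding the one query may attend to every cached
        # position, so the causal mask is only needed when T > 1 (prompt prefill)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=(T > 1))
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)

With a cache in place, each decoding step only projects the newest token and appends one column of keys and values, so the per-token attention cost grows with the context length instead of re-projecting the entire sequence on every step.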
