Beating NumPy's matrix multiplication in 150 lines of C code


TL;DR: The code from this tutorial is available at matmul.c. This blog post is the result of my attempt to implement high-performance matrix multiplication on the CPU while keeping the code simple, portable, and scalable. The implementation follows the BLIS design, works for arbitrary matrix sizes, and, when fine-tuned for an AMD Ryzen 7700 (8 cores), outperforms NumPy (backed by OpenBLAS), achieving over 1 TFLOPS of peak performance across a wide range of matrix sizes.
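As a rough sanity check on that figure (my own back-of-the-envelope arithmetic, not a claim from the original benchmark): each Zen 4 core can issue two 256-bit FMAs per cycle, i.e. 8 FP32 lanes × 2 FLOPs × 2 units = 32 FLOPs per cycle. Assuming an all-core clock of roughly 4.5 GHz, 8 cores × 4.5 GHz × 32 FLOPs/cycle ≈ 1.15 TFLOPS of theoretical FP32 peak, so exceeding 1 TFLOPS means running close to the hardware limit.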

The code is parallelized with just 3 lines of OpenMP directives, which keeps it both scalable and easy to understand (a minimal sketch of this kind of parallelization follows at the end of this introduction). The implementation hasn't been tested on other CPUs, so I would appreciate feedback on its performance on your hardware.

Although the code is portable and targets CPUs with FMA3 and AVX instructions (i.e., all modern Intel Core and AMD Zen CPUs), please don't expect peak performance without fine-tuning the hyperparameters, such as the number of threads and the kernel and block sizes, unless you are running it on a Ryzen 7700(X). Additionally, on some Intel CPUs, the OpenBLAS implementation might be notably faster thanks to AVX-512 instructions, which were intentionally omitted here to support a broader range of processors.

Throughout this tutorial, we'll implement matrix multiplication from scratch, learning how to optimize and parallelize C code along the way. This is my first time writing a blog post. If you enjoy it, please subscribe and share it! I would be happy to hear feedback from all of you. This is the first part of a planned two-part series; in the second part, we will learn how to optimize matrix multiplication on GPUs. Stay tuned!
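To make the "3 lines of OpenMP" point concrete, here is a minimal sketch of how a single directive can parallelize even the naive triple loop. This is illustrative only: the function name is mine, and the actual matmul.c applies its directives to BLIS-style blocked loops rather than to this naive version.

```c
/* Illustrative sketch, not the code from matmul.c: one OpenMP directive
   parallelizes the outer loop over rows of C. Each thread writes a
   disjoint set of rows, so no synchronization is needed. */
void matmul_naive_parallel(const float *A, const float *B, float *C,
                           const int M, const int N, const int K) {
    #pragma omp parallel for num_threads(8) /* e.g. one thread per core */
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            float acc = 0.0f;
            for (int p = 0; p < K; p++) {
                acc += A[i * K + p] * B[p * N + j];
            }
            C[i * N + j] = acc;
        }
    }
}
```

Compile with something like `gcc -O3 -march=native -fopenmp` to enable the directive.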

Matrix multiplication is an essential part of nearly all modern neural networks. For example, most of the time spent during inference in Transformers is actually taken up by matrix multiplications. Despite using matmul daily in PyTorch, NumPy, or JAX, I had never really thought about how it is designed and implemented to maximize hardware efficiency. To achieve such speeds, NumPy, for instance, relies on external BLAS (Basic Linear Algebra Subprograms) libraries. These libraries implement highly optimized common linear algebra operations such as dot products, matrix multiplication, vector addition, and scalar multiplication. Examples of BLAS implementations include:

- OpenBLAS, the open-source implementation that NumPy builds commonly ship with
- Intel MKL, Intel's math kernel library
- BLIS, the BLAS-like framework whose design this implementation follows
- ATLAS, the Automatically Tuned Linear Algebra Software
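To make "NumPy relies on BLAS" concrete, here is a minimal sketch of the interface these libraries expose. `cblas_sgemm` is the standard single-precision GEMM routine; which library actually serves the call depends on how NumPy was built, and this standalone example is mine, not code from the post.

```c
/* Minimal sketch: single-precision GEMM through the standard CBLAS
   interface (provided by OpenBLAS, MKL, BLIS, ...).
   Computes C = alpha * A @ B + beta * C for row-major matrices. */
#include <cblas.h>
#include <stdio.h>

int main(void) {
    enum { M = 2, N = 2, K = 2 };
    float A[M * K] = {1, 2, 3, 4};
    float B[K * N] = {5, 6, 7, 8};
    float C[M * N] = {0};

    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
                M, N, K,
                1.0f, A, K,   /* lda = K: row stride of A */
                B, N,         /* ldb = N: row stride of B */
                0.0f, C, N);  /* ldc = N: row stride of C */

    printf("%.0f %.0f\n%.0f %.0f\n", C[0], C[1], C[2], C[3]);
    /* Expected: 19 22 / 43 50 */
    return 0;
}
```

Linking against any BLAS (e.g. `gcc sgemm_demo.c -lopenblas`) runs the same code path NumPy ultimately dispatches to for a float32 matmul.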
