Shape Rotation 101: An Intro to Einsum and Jax Transformers | sankalp's blog

Submitted by Style Pass, 2024-06-22 08:30:05

This post is divided into two parts. The first part goes through the basics of einsum notation. The second part is about understanding simple transformer code in JAX, which uses a lot of einsum.

From your end, I want some brains (no, I am not a zombie, I just want your attention). I also assume you know NumPy basics, the brute-force matrix multiplication algorithm and, for part 2 only, transformer basics.

In case I spectacularly fail to explain einsum, you can refer to posts 2, 3 and 4 mentioned above. Post 2 covers einsum in PyTorch and builds on 3 and 4. Post 3 goes into the internals, while 4 focuses on the notation with examples.

Einsum is an alternative API for tensor/numerical array manipulation provided by several libraries. NumPy (since v1.6), PyTorch, and other scientific computing libraries offer an einsum function.
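As a quick taste of what this API looks like, here is a minimal sketch using NumPy's `np.einsum`. The arrays `a` and `b` and their shapes are made up for illustration. The point is only that one function call with a subscript string can stand in for more familiar operations such as `a @ b` or `a.T`.

```python
import numpy as np

# Hypothetical inputs, chosen only for illustration.
a = np.arange(6).reshape(2, 3)    # shape (2, 3)
b = np.arange(12).reshape(3, 4)   # shape (3, 4)

# Matrix multiplication: j appears in both inputs but not in the output,
# so it is summed over.
matmul = np.einsum("ij,jk->ik", a, b)

# Transpose: no index is dropped, the output only reorders them.
transposed = np.einsum("ij->ji", a)

print(matmul.shape)      # (2, 4)
print(transposed.shape)  # (3, 2)
```

The letters in the subscript string are arbitrary labels for axes. What matters is which labels repeat and which ones survive on the right-hand side of `->`.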

This function uses Einstein summation notation to simplify complex linear algebra operations on multi-dimensional arrays, namely tensor contractions (more on this later) and summations. The syntax is mostly consistent across NumPy, PyTorch, JAX, etc.
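To make "contractions and summations" a little more concrete, here is a rough sketch in NumPy. The helper `contract_loops` is a hypothetical name written for this example, not a library function. It spells out the brute-force loops that the subscript string `"ij,jk->ik"` describes, and the last two lines show plain summations. I use NumPy here, but `jnp.einsum` and `torch.einsum` should accept the same subscript strings for these simple cases.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3))  # illustrative shapes, not from the post
y = rng.normal(size=(3, 4))


def contract_loops(x, y):
    """Brute-force version of the contraction "ij,jk->ik" (hypothetical helper)."""
    out = np.zeros((x.shape[0], y.shape[1]))
    for i in range(x.shape[0]):          # i is kept in the output
        for k in range(y.shape[1]):      # k is kept in the output
            for j in range(x.shape[1]):  # j is shared and summed away
                out[i, k] += x[i, j] * y[j, k]
    return out


# Contraction: the shared index j disappears from the output.
contracted = np.einsum("ij,jk->ik", x, y)

# Summations: any index missing on the right of "->" is summed over.
row_sums = np.einsum("ij->i", x)  # same as x.sum(axis=1)
total = np.einsum("ij->", x)      # same as x.sum()

print(np.allclose(contracted, contract_loops(x, y)))
```

The three nested loops map one-to-one onto the three index letters, which is why it helps to have the brute-force matrix multiplication algorithm in mind when reading einsum strings.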
