To run the training, execute the following three commands on an 8xA100 or 8xH100 node. They complete in under 45 minutes on an 8xH100 with a decent internet connection.
This trains a 124M-parameter transformer for 6000 steps on 3.15B tokens of Fineweb [1], reaching a validation loss of ~3.275. For comparison, the default llm.c PyTorch trainer reaches a validation loss of >3.28 only after training on 10B tokens.
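As a rough sanity check on the token count, 6000 steps at about half a million tokens per step works out to roughly 3.15B tokens. The per-step batch size below is a hypothetical illustration, not taken from the training script:

```python
# Hypothetical per-step batch size, assumed only for this arithmetic check.
tokens_per_step = 524_288                 # 2**19 tokens per optimizer step (assumption)
total_tokens = 6000 * tokens_per_step     # 3,145,728,000 ~= 3.15B tokens, matching the figure above
```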
Many of the design choices behind this optimizer were found experimentally through our pursuit of CIFAR-10 speedrunning. In particular, experimentation led us to the following practices:
Our use of a Newton-Schulz iteration for orthogonalization traces to Bernstein & Newhouse (2024), who suggested it as a way to compute Shampoo [5, 6] preconditioners, and theoretically explored Shampoo without preconditioner accumulation. In particular, Jeremy Bernstein @jxbz sent us the draft, which prompted us to experiment with various Newton-Schulz iterations as the orthogonalization method for this optimizer. If we had used SVD instead of a Newton-Schulz iteration, this optimizer would have been too slow to be useful. Bernstein & Newhouse also pointed out that Shampoo without preconditioner accumulation is equivalent to steepest descent in the spectral norm, so Shampoo can be thought of as one way to smooth out spectral steepest descent. The optimizer proposed here can be thought of as a second way of smoothing spectral steepest descent, with a different set of memory and runtime tradeoffs compared to Shampoo.
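For concreteness, here is a minimal sketch of Newton-Schulz orthogonalization in PyTorch. It uses the plain cubic iteration; the function name, step count, and coefficients are illustrative assumptions, and the actual optimizer may use a different, tuned polynomial.

```python
import torch

def newton_schulz_orthogonalize(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Approximate the nearest semi-orthogonal matrix to G (the U @ V.T factor of its SVD)
    using a cubic Newton-Schulz iteration instead of an explicit SVD. Sketch only."""
    assert G.ndim == 2
    # Dividing by the Frobenius norm bounds the spectral norm by 1, which is
    # enough for the iteration below to converge.
    X = G / (G.norm() + 1e-7)
    # Work with the wider orientation so the Gram matrix X @ X.T stays small.
    transposed = X.size(0) > X.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        # Each step maps every singular value s to 1.5*s - 0.5*s**3,
        # pushing all singular values toward 1 while keeping the singular vectors fixed.
        X = 1.5 * X - 0.5 * X @ X.T @ X
    return X.T if transposed else X
```

Since the iteration consists only of matrix multiplications, it maps well onto GPU hardware, which is the practical advantage over an explicit SVD noted above.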
To simplify the code, some features have been removed, including text generation. In addition, to improve training speed, we have diverged from strictly reproducing the GPT-2 paper.