Implementing Formal Algorithms for Transformers
Machine learning by doing. Writing a pedagogical implementation of multi-head attention from scratch using pseudocode from DeepMind.
The transformer architecture was introduced in the 2017 paper Attention is All You Need. Since then, hundreds of variations on the central theme of contextualizing token representations via attention have been introduced (for recent wide-angle views, see A Survey of Transformers and Efficient Methods for Natural Language Processing: A Survey).
Large language models are probably the most widely known types of transformers, but the architecture has been rapidly diffusing into every corner of machine learning. Lucas Beyer recently published a nice set of slides that cover applications in text, vision, speech, and reinforcement learning while highlighting the “unification of communities” around the transformer architecture.
Whether you are just getting to know transformers or modifying them to work on brand new types of data, it always helps to have a solid understanding of the fundamentals. In this post, we will examine the 2022 paper by DeepMind’s Mary Phuong and Marcus Hutter, Formal Algorithms for Transformers (PH22). Then we will use the pseudocode algorithms they present to implement the attention components of the transformer in PyTorch.
This post assumes some familiarity with the transformer architecture. If you are brand new to the transformer world, consider checking out some of the material in the Further Resources section at the end of this post.
Formal Algorithms for Transformers
The main contribution of the PH22 paper is a set of concise pseudocode descriptions of the transformer algorithms. The authors motivate their work by saying,
Transformers are deep feed-forward artificial neural networks with a (self)attention mechanism. They have been tremendously successful in natural language processing tasks and other domains. Since their inception 5 years ago [VSP+17], many variants have been suggested [LWLQ21]. Descriptions are usually graphical, verbal, partial, or incremental. Despite their popularity, it seems no pseudocode has ever been published for any variant. Contrast this to other fields of computer science …
The essentially complete pseudocode is about 50 lines, compared to thousands of lines of actual real source code. We believe these formal algorithms will be useful for theoreticians who require compact, complete, and precise formulations, experimental researchers interested in implementing a Transformer from scratch, and encourage authors to augment their paper or text book with formal Transformer algorithms (Section 2).
This unified and self-contained approach with unambiguous algorithms is very helpful for writing an implementation from scratch. The concise pseudocode descriptions also fill a gap that existed in the transformers literature. Below, we will see that the central attention algorithms can be expressed in relatively little code.
Some Implementation Notes
Matrix Multiplication Convention
In their Notation section, the authors describe their convention when writing matrix multiplications.
We use matrix × column vector convention more common in mathematics, compared to the default row vector × matrix in the transformer literature, i.e. our matrices are transposed.
The implementations presented here use the “row vector × matrix” convention. I found this convention more natural in PyTorch, but either convention works. Fortunately, the authors uploaded their LaTeX files to arXiv, and I was able to create transposed versions of their pseudocode algorithms in Overleaf.
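To make the difference between the two conventions concrete, here is a small check in PyTorch with arbitrary, made-up dimensions. A weight matrix written for the column-vector convention (as in PH22) gives the same result when a row vector is multiplied by its transpose. This is also how torch.nn.Linear works: it stores weight with shape (out_features, in_features) and computes x @ weight.T + bias.

```python
import torch

torch.manual_seed(0)
d_in, d_out, seq_len = 4, 3, 5

# PH22-style weight: maps a column vector of size d_in to one of size d_out (y = W x)
W = torch.randn(d_out, d_in)
x = torch.randn(d_in)

y_column = W @ x        # matrix × column vector
y_row = x @ W.T         # row vector × transposed matrix
assert torch.allclose(y_column, y_row, atol=1e-6)

# For a whole sequence, PH22 stores tokens as columns (d_in, seq_len),
# while the row convention stores them as rows (seq_len, d_in).
X_rows = torch.randn(seq_len, d_in)
Y_rows = X_rows @ W.T   # (seq_len, d_out)
Y_cols = W @ X_rows.T   # (d_out, seq_len)
assert torch.allclose(Y_rows, Y_cols.T, atol=1e-6)
```

The two layouts are transposes of each other, so switching conventions amounts to transposing every weight matrix and every sequence tensor.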
Treatment of Batched Input
The pseudocode algorithms do not deal with the batch dimension of inputs. This makes the algorithms cleaner and easier to read, but it is something to keep in mind when comparing them to the code snippets in this post and the full implementations on GitHub.
Padding and Masking
Handling the batch dimension also means handling padding tokens, which must be masked. This is distinct from the unidirectional or “causal” masking used in autoregressive language models (e.g. GPT), but the same mask tensor can be used to accomplish both.
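One way to build such a combined mask is sketched below. The helper build_attention_mask is hypothetical (it is not taken from the repo), and it assumes boolean masks where True means "token t_x may attend to token t_z", right-padded sequences, and aligned positions for the causal part (self-attention, l_x == l_z). Only the context side is masked for padding, which is one reasonable design choice: a padded primary token still attends to real tokens, so no softmax row ends up entirely at -inf (which would produce NaNs), and its output is simply ignored downstream.

```python
import torch

def build_attention_mask(is_real_z, l_x, causal=False):
    """Hypothetical helper: combine padding and causal masking into one tensor.

    is_real_z: (b, l_z) bool, True for real context tokens, False for padding.
    Returns a (b, l_x, l_z) bool mask where True means "t_x may attend to t_z".
    """
    b, l_z = is_real_z.shape
    # Padding: no primary token may attend to a padded context token.
    mask = is_real_z[:, None, :].expand(b, l_x, l_z)
    if causal:
        # Causal: token t_x may only attend to context positions t_z <= t_x.
        causal_mask = torch.tril(torch.ones(l_x, l_z, dtype=torch.bool))
        mask = mask & causal_mask  # broadcasts over the batch dimension
    return mask

# Two right-padded sequences of length 3 (the second has one padding token).
is_real = torch.tensor([[True, True, True], [True, True, False]])
mask = build_attention_mask(is_real, l_x=3, causal=True)
# rows of mask[1]: [True, False, False], [True, True, False], [True, True, False]
print(mask[1])
```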
Pseudocode and Code
This post will cover the single-query, single-head, and multi-head attention algorithms. For each, we will examine the transposed version of the algorithm from PH22 along with a code snippet. The code snippets are only sketches, but the full implementation is available in the galtay/formal-algos-transformers GitHub repo. By placing them side by side, I hope to convey the expressiveness of the pseudocode PH22 created.
Single-Query Attention
The authors present a full picture of transformers, from embeddings through training and inference, and I suggest reading through the entire paper to benefit from their self-contained presentation. However, this post focuses on the attention mechanism, so we will jump ahead to the first attention-based algorithm: single-query attention.
In this algorithm, we take a single primary token embedding and create a new embedding by attending to a sequence of context embeddings. In many transformer implementations, the primary, context, and output embedding dimensions (d_x, d_z, and d_out) are equal to each other, but this is not a requirement.
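A minimal sketch of single-query attention in the row-vector convention follows. It is my reading of the algorithm, not the code from the galtay/formal-algos-transformers repo; the function name and argument order are assumptions. The context embeddings are passed as a Python list to mirror the pseudocode, and d_attn (the query/key dimension) is allowed to differ from d_x, d_z, and d_out.

```python
import math
import torch

def single_query_attention(x, zs, W_q, b_q, W_k, b_k, W_v, b_v):
    """Single-query attention, row-vector convention (a sketch).

    x:  (d_x,) primary token embedding
    zs: list of l_z context embeddings, each (d_z,)
    W_q: (d_x, d_attn), W_k: (d_z, d_attn), W_v: (d_z, d_out); biases are 1-D.
    Returns a (d_out,) vector.
    """
    q = x @ W_q + b_q                              # query, (d_attn,)
    ks = torch.stack([z @ W_k + b_k for z in zs])  # keys, (l_z, d_attn)
    vs = torch.stack([z @ W_v + b_v for z in zs])  # values, (l_z, d_out)
    scores = ks @ q / math.sqrt(q.shape[-1])       # scaled dot products, (l_z,)
    alphas = torch.softmax(scores, dim=0)          # attention weights, sum to 1
    return alphas @ vs                             # weighted sum of values, (d_out,)

torch.manual_seed(0)
d_x, d_z, d_attn, d_out, l_z = 4, 6, 8, 5, 3
params = [torch.randn(d_x, d_attn), torch.randn(d_attn),
          torch.randn(d_z, d_attn), torch.randn(d_attn),
          torch.randn(d_z, d_out), torch.randn(d_out)]
x = torch.randn(d_x)
zs = [torch.randn(d_z) for _ in range(l_z)]
print(single_query_attention(x, zs, *params).shape)  # torch.Size([5])
```

With d_x = 4 and d_z = 6 the example also shows that the primary and context dimensions need not match.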
Single-Head Attention
In the single-head attention algorithm, we expand on the single-query implementation in a few ways.
- We allow a sequence of primary token embeddings instead of a single primary token embedding
- We still have a sequence of context token embeddings, but we collect them into a single tensor instead of a list of tensors
- We introduce a batch dimension in the code snippet
- We introduce a mask tensor to handle padding and unidirectional attention
The attention algorithm is applied independently to each example in the batch. The mask tensor is used to prevent token t_x in X from attending to token t_z in Z. In PH22, the mask is used in autoregressive models to prevent tokens from attending to other tokens ahead of them in a sequence. Batching introduces the need for padding tokens, and we can use the same mask to prevent the attention mechanism from contextualizing these. In our implementation, the user passes in a mask tensor with shape (b, l_x, l_z), i.e. (batch size, number of primary tokens, number of context tokens).
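Below is a sketch of batched single-head attention under the conventions described above. It assumes a boolean mask where True means "t_x may attend to t_z", which mirrors PH22 setting scores to -inf wherever the mask is 0. It also assumes every mask row has at least one True entry, since a softmax over a row that is entirely -inf returns NaN. The function name and argument order are illustrative, not the repo's API.

```python
import math
import torch

def single_head_attention(X, Z, mask, W_q, b_q, W_k, b_k, W_v, b_v):
    """Batched single-head attention, row-vector convention (a sketch).

    X: (b, l_x, d_x) primary embeddings, Z: (b, l_z, d_z) context embeddings
    mask: (b, l_x, l_z) bool, True where token t_x may attend to token t_z.
    """
    Q = X @ W_q + b_q                                    # (b, l_x, d_attn)
    K = Z @ W_k + b_k                                    # (b, l_z, d_attn)
    V = Z @ W_v + b_v                                    # (b, l_z, d_out)
    S = torch.einsum("bxd,bzd->bxz", Q, K) / math.sqrt(Q.shape[-1])
    S = S.masked_fill(~mask, float("-inf"))              # block disallowed pairs
    A = torch.softmax(S, dim=-1)                         # normalize over context
    return torch.einsum("bxz,bzo->bxo", A, V)            # (b, l_x, d_out)

torch.manual_seed(0)
b, l, d, d_attn, d_out = 2, 4, 6, 8, 5
X = torch.randn(b, l, d)
causal = torch.tril(torch.ones(l, l, dtype=torch.bool)).expand(b, l, l)
params = [torch.randn(d, d_attn), torch.randn(d_attn),
          torch.randn(d, d_attn), torch.randn(d_attn),
          torch.randn(d, d_out), torch.randn(d_out)]
Y = single_head_attention(X, X, causal, *params)  # self-attention: Z = X
print(Y.shape)  # torch.Size([2, 4, 5])
```

With a causal mask, the first token can only attend to itself, so its output is exactly its own value vector; that makes a handy sanity check.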
We take advantage of the torch.einsum function to do batched matrix multiplications. Reading and writing these expressions can take some getting used to, but the same subscript strings can be used across multiple frameworks (e.g. numpy.einsum and tf.einsum). Tim Rocktäschel’s tutorial is a great place to start learning.
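As a small worked example of reading an einsum string, the sketch below computes attention scores with the subscript string "bxd,bzd->bxz" and checks it against an explicit batched matrix multiplication. The dimensions are arbitrary; the same subscript string could be passed unchanged to numpy.einsum.

```python
import torch

torch.manual_seed(0)
b, l_x, l_z, d = 2, 3, 4, 5
Q = torch.randn(b, l_x, d)
K = torch.randn(b, l_z, d)

# "bxd,bzd->bxz": for each batch element b, take the dot product of every
# query row x with every key row z, summing over the repeated index d.
S_einsum = torch.einsum("bxd,bzd->bxz", Q, K)

# The same computation written with an explicit batched matmul and transpose.
S_bmm = torch.bmm(Q, K.transpose(1, 2))

print(S_einsum.shape)                              # torch.Size([2, 3, 4])
print(torch.allclose(S_einsum, S_bmm, atol=1e-6))  # True
```

Indices that appear in the inputs but not in the output (here d) are summed over, while indices that appear in the output (b, x, z) are kept.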
Multi-Head Attention
In the multi-head attention algorithm, we expand on the single-head implementation in a few ways.
- The primary and context embeddings are passed through H independent attention mechanisms.
- The H attention outputs are concatenated and then passed through another affine transformation parameterized by W_o and b_o.
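A sketch of multi-head attention built on a compact single-head function is below. For readability it loops over heads explicitly; production implementations usually fuse all heads into one tensor with a head dimension instead. Each head here produces d_mid features, so W_o has shape (H * d_mid, d_out). The function names and the tuple-per-head parameter layout are assumptions made for this illustration.

```python
import math
import torch

def attention(X, Z, mask, W_q, b_q, W_k, b_k, W_v, b_v):
    """Single-head attention, row-vector convention; mask True = may attend."""
    Q, K, V = X @ W_q + b_q, Z @ W_k + b_k, Z @ W_v + b_v
    S = torch.einsum("bxd,bzd->bxz", Q, K) / math.sqrt(Q.shape[-1])
    A = torch.softmax(S.masked_fill(~mask, float("-inf")), dim=-1)
    return torch.einsum("bxz,bzo->bxo", A, V)

def multi_head_attention(X, Z, mask, heads, W_o, b_o):
    """heads: list of H tuples (W_q, b_q, W_k, b_k, W_v, b_v), one per head.

    Each head outputs d_mid features; W_o has shape (H * d_mid, d_out).
    """
    Ys = [attention(X, Z, mask, *p) for p in heads]  # H tensors, (b, l_x, d_mid)
    Y = torch.cat(Ys, dim=-1)                        # concatenate: (b, l_x, H * d_mid)
    return Y @ W_o + b_o                             # final affine map: (b, l_x, d_out)

torch.manual_seed(0)
b, l, d, d_attn, d_mid, d_out, H = 2, 4, 6, 8, 3, 6, 2

def random_head():
    return (torch.randn(d, d_attn), torch.randn(d_attn),
            torch.randn(d, d_attn), torch.randn(d_attn),
            torch.randn(d, d_mid), torch.randn(d_mid))

heads = [random_head() for _ in range(H)]
W_o, b_o = torch.randn(H * d_mid, d_out), torch.randn(d_out)
X = torch.randn(b, l, d)
mask = torch.ones(b, l, l, dtype=torch.bool)  # no masking in this example
out = multi_head_attention(X, X, mask, heads, W_o, b_o)
print(out.shape)  # torch.Size([2, 4, 6])
```

With a single head, an identity W_o, and a zero b_o, the function reduces to plain single-head attention.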
Unit Tests With Stateless Functional Calls
The GitHub repo linked in this post comes with a set of unit tests. The testing strategy relies on a gold implementation of single-query attention: if users trust the single-query implementation, the tests let them trust the single-head and multi-head implementations as well.
These tests make use of a feature introduced in PyTorch 1.12 that provides stateless functional calls. This feature allows for the temporary replacement of torch.nn.Module parameters and buffers with user-provided ones. In practice, this gives us a way to ensure that two instances of the attention modules above use the same set of weights and biases for a given computation. You can see a concise example in PyTorch’s LinkedIn post.
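A minimal sketch of the idea: TinyAttention is a hypothetical single-head module written only for this illustration (it is not the repo's module). functional_call runs m2's forward pass with m1's parameters without modifying m2. In PyTorch 2.x the function lives at torch.func.functional_call, and the PyTorch 1.12 location torch.nn.utils.stateless.functional_call is deprecated, so the import below tries the newer path first.

```python
import torch
from torch import nn

try:
    from torch.func import functional_call  # PyTorch >= 2.0
except ImportError:
    from torch.nn.utils.stateless import functional_call  # PyTorch 1.12 / 1.13

class TinyAttention(nn.Module):
    """Hypothetical single-head attention module (no masking) for illustration."""

    def __init__(self, d_x, d_z, d_attn, d_out):
        super().__init__()
        self.q = nn.Linear(d_x, d_attn)
        self.k = nn.Linear(d_z, d_attn)
        self.v = nn.Linear(d_z, d_out)

    def forward(self, X, Z):
        Q, K, V = self.q(X), self.k(Z), self.v(Z)
        S = Q @ K.transpose(-2, -1) / Q.shape[-1] ** 0.5
        return torch.softmax(S, dim=-1) @ V

torch.manual_seed(0)
m1 = TinyAttention(8, 8, 4, 6)
m2 = TinyAttention(8, 8, 4, 6)  # different random initialization

X = torch.randn(2, 3, 8)
Z = torch.randn(2, 5, 8)

out1 = m1(X, Z)
# Run m2 with m1's parameters swapped in for this one call only.
out2 = functional_call(m2, dict(m1.named_parameters()), (X, Z))
print(torch.allclose(out1, out2))  # True
```

Because the swap only lasts for the call, m2 keeps its own weights afterwards, which is what makes this convenient for comparing two implementations on identical parameters.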
Further Resources
Implementations
- The Annotated Transformer
- Annotated Implementations from labml.ai
- The Implementations of Phil Wang (lucidrains)
- minGPT from Andrej Karpathy
Lectures / Tutorials
- Peter Bloem’s Lectures from Deep Learning at VU Amsterdam
- Hugging Face Course
- Jay Alammar’s Illustrated Transformer
- Lucas Beyer’s slides on Transformers
Einsum
- The implementations here make use of einsum. This provides a framework-agnostic way to write tensor operations (i.e. you can use the same subscript string in NumPy, TensorFlow, or PyTorch). You can read more about einsum in Tim Rocktäschel’s tutorial.
Code in Medium
- The code snippets in this post were generated with https://carbon.now.sh/