The Importance of Attention
Language modeling tasks such as question answering and document classification are now commonly built with transformer networks. The attention building block is a central component of the transformer. Recently, transformer-based computer vision models have attained state-of-the-art results, further underscoring the importance of attention. This write-up provides an intuitive understanding of the attention block at the heart of transformers.
I assume the reader has a basic understanding of neural-network-based spatial and sequence learning (i.e., CNNs and LSTMs).
Prior to transformers (with attention), recurrent nets (LSTMs) were the natural way to tackle sequence learning tasks such as language translation. Their main disadvantage, however, was their sequential processing flow: past words had to pass through the LSTM “memory” cell one at a time to establish a context for the present word. Transformers, on the other hand, process the entire sentence in parallel, allowing for higher-quality long-distance context.
Found in Translation
Pun aside, the roots of attention are in language translation: the source language is first transformed into an encoded form, and words are then decoded step-by-step into the destination language. For the decoder to predict the next word vector so that it contains the most relevant information from all of the encoded hidden states, we compute a weighted average over the encoded words. How much each encoded state contributes to the weighted average is determined by an alignment score between that encoded state and the previous hidden state of the decoder. We can consider the previous decoder state as the query vector, and the encoder hidden states as key and value vectors (shown below).
The output is a weighted average of the value vectors, where the weights are determined by the compatibility (or similarity) function between the query and the keys. This output is used to predict the next word.
When building a language model, we’re interested in the similarity of a word with respect to every other word in the source sequence. This can be formulated as retrieval (from a database) with a query q of a value v(i) stored at the key k(i).
Attention can be defined as matching a query against a set of keys and using those matches to retrieve the values associated with the keys (pictured above).
To make it concrete, the query is the word we’re operating on, and the keys and values are words we’ve generated in the past and have available in memory. Keys and values can be the same thing.
To summarize, in the attention operation, we take the query, find the most similar key using a similarity metric, and retrieve the value that corresponds to that key.
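The retrieval analogy can be sketched in plain Python. A hard database lookup returns exactly one value; attention relaxes this into a weighted blend over all values, with weights derived from query–key similarity. This is a minimal illustrative sketch, with function names and toy vectors of my own choosing, not code from any particular library.

```python
import math

def softmax(scores):
    # Exponentiate and normalize so the weights sum to 1.
    exps = [math.exp(s) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def soft_retrieve(query, keys, values):
    # Similarity of the query with each key (dot product here).
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    weights = softmax(scores)
    # Weighted average of the values instead of a single hard match.
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(dim)]

# Toy example: the query is most similar to the second key,
# so the output leans toward the second value.
keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[10.0, 0.0], [0.0, 10.0]]
out = soft_retrieve([0.1, 0.9], keys, values)
```

Because the softmax weights always sum to 1, the output stays within the span of the values; a sharper similarity distribution moves it closer to a hard lookup.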
Note that for the attention parameters to be learned via back-propagation of errors, the mathematical formulation of attention has to be differentiable.
As shown below, we compute the dot products of the query vector with each of the keys and apply a scaling factor: the dot products are divided by the square root of the dimension of the keys (√dₖ). Next, we apply a softmax function to the scaled dot-product values, which normalizes the outputs to sum to 1 so they can be read as probability scores. Finally, these weights are used to form a linear combination of the value vectors, producing the attention output.
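The scaled dot-product computation just described can be sketched in a few lines of pure Python. Real implementations operate on batched matrices rather than single vectors, and the names below are mine; this is only a sketch of the arithmetic.

```python
import math

def softmax(xs):
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def scaled_dot_product_attention(query, keys, values):
    d_k = len(query)
    # Dot product of the query with each key, scaled by sqrt(d_k).
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d_k)
              for key in keys]
    # Softmax turns the scaled scores into weights that sum to 1.
    weights = softmax(scores)
    # Linearly combine the value vectors using those weights.
    return [sum(w * v[i] for w, v in zip(weights, values))
            for i in range(len(values[0]))]

query = [1.0, 0.0, 1.0, 0.0]
keys = [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]]
values = [[1.0, 2.0], [3.0, 4.0]]
out = scaled_dot_product_attention(query, keys, values)
```

The √dₖ scaling keeps the dot products from growing with the key dimension, which would otherwise push the softmax into regions with vanishingly small gradients.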
As an analogy, in convolutional neural nets we use several filters to produce several different feature maps. Multi-headed attention serves a similar purpose. The queries, keys, and values are projected into different spaces multiple times (h times) with different, learned linear projections, and each projected set is fed into its own attention “head” (denoted as h heads below). The outputs of all the heads are concatenated and passed through a linear layer to produce the multi-head attention output.
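A bare-bones sketch of this project–attend–concatenate–project pipeline is below. The per-head projection matrices and the output projection would normally be learned; here they are fixed toy matrices standing in for learned parameters, and all names are my own.

```python
import math

def softmax(xs):
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def matvec(mat, vec):
    # Multiply a matrix (list of rows) by a vector.
    return [sum(m * v for m, v in zip(row, vec)) for row in mat]

def attention(query, keys, values):
    d_k = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d_k)
              for key in keys]
    weights = softmax(scores)
    return [sum(w * v[i] for w, v in zip(weights, values))
            for i in range(len(values[0]))]

def multi_head_attention(query, keys, values, head_projs, out_proj):
    head_outputs = []
    for w_q, w_k, w_v in head_projs:
        # Project query, keys, and values into this head's subspace.
        q = matvec(w_q, query)
        ks = [matvec(w_k, k) for k in keys]
        vs = [matvec(w_v, v) for v in values]
        head_outputs.append(attention(q, ks, vs))
    # Concatenate the head outputs and apply the final linear projection.
    concat = [x for head in head_outputs for x in head]
    return matvec(out_proj, concat)

# Two 1-dimensional heads over a 2-dimensional model: head 1 looks
# only at dimension 0, head 2 only at dimension 1.
head_projs = [
    ([[1.0, 0.0]], [[1.0, 0.0]], [[1.0, 0.0]]),
    ([[0.0, 1.0]], [[0.0, 1.0]], [[0.0, 1.0]]),
]
out_proj = [[1.0, 0.0], [0.0, 1.0]]  # identity, for simplicity
query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[5.0, 1.0], [1.0, 5.0]]
out = multi_head_attention(query, keys, values, head_projs, out_proj)
```

Because each head sees a different projection, the heads can attend to different aspects of the same positions, much as different convolution filters respond to different patterns.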
The computational complexity of self-attention is O(n² · d), where n is the sequence length and d is the representation (embedding) dimensionality. The quadratic term comes from the fact that every position (word) has to be compared with every other word to get a similarity measure. On the plus side, the long-range dependencies between words are all of unit distance (unlike in LSTMs).
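The source of the quadratic term can be made concrete by building the full score matrix directly (a toy illustration, not a benchmark):

```python
def attention_scores(seq):
    # seq: a list of n embedding vectors, each of dimension d.
    # Each dot product costs O(d), and there are n * n of them,
    # which is where the O(n^2 * d) total comes from.
    return [[sum(a * b for a, b in zip(x, y)) for y in seq] for x in seq]

# Three positions with 2-dimensional embeddings yield a 3 x 3 score matrix.
seq = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
scores = attention_scores(seq)
```

Doubling the sequence length quadruples the number of entries in this matrix, while doubling the embedding dimension only doubles the cost of each entry.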
The redeeming aspect of this n-squared compute complexity is that, since we look at every pair of words simultaneously, the entire attention operation can be parallelized, making the cost less of an issue on hardware that supports parallel compute.