Attention
Juan Vera
January 2025
Let's define $q$, $\{k_1, \dots, k_T\}$, and $\{v_1, \dots, v_T\}$ as a query, set of keys, and set of values respectively:
$$q \in \mathbb{R}^{B \times d_q}, \qquad k_i \in \mathbb{R}^{B \times d_k}, \qquad v_i \in \mathbb{R}^{B \times d_v}, \qquad i = 1, \dots, T$$
where $B$ is batch size and $T$ is sequence length.
We can define attention as:
$$\mathrm{Attention}(q, \{k_i\}, \{v_i\}) = \sum_{i=1}^{T} \alpha(q, k_i)\, v_i \in \mathbb{R}^{d_v}$$
where $d_v$ is the dimensionality of the value space and $\alpha$ is an arbitrary weighting function, generating scalar values as the attention scores.
Attention, in essence, is taking a query, $q$, and a set of keys, $\{k_i\}$, computing a similarity score denoted by the function $\alpha(q, k_i)$, which is then multiplied element-wise by the value vector, $v_i$:
$$\alpha(q, k_i) = \mathrm{softmax}_i\big(a(q, k_i)\big) = \frac{\exp\big(a(q, k_i)\big)}{\sum_{j=1}^{T} \exp\big(a(q, k_j)\big)}$$
where $\alpha$ is constructed as the softmax of an arbitrary similarity function $a$, used to compute the similarity of a given $q$ and $k_i$.
We're essentially computing a similarity $a(q, k_i)$, and normalizing it to be $\in [0, 1]$, to be interpretable as a probability distribution if we consider all other similarity scores for all other $k_j$, $j \neq i$.
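As a quick worked example (the raw scores are made up purely for illustration), suppose the similarity function produces $a(q, k_1) = 2.0$, $a(q, k_2) = 1.0$, $a(q, k_3) = 0.1$ for a sequence of length $T = 3$. The softmax then yields
$$\alpha_1 = \frac{e^{2.0}}{e^{2.0} + e^{1.0} + e^{0.1}} \approx 0.66, \qquad \alpha_2 \approx 0.24, \qquad \alpha_3 \approx 0.10,$$
which sum to $1$ and can be read as a probability distribution over the three value vectors.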
The sum shows that for sequence length $T$, we're pooling all attention weights by multiplying element-wise with $v_i$ (as we broadcast $\alpha(q, k_i)$ to a vector $\in \mathbb{R}^{d_v}$), and then summing over $i$. You might as well express this as a matrix product between a matrix of attention weights $A \in \mathbb{R}^{T \times T}$ (one row of weights per query) and a matrix of value vectors $V \in \mathbb{R}^{T \times d_v}$, where we get $AV \in \mathbb{R}^{T \times d_v}$, where $T$ is the sequence length. This is essentially how the attention mechanism becomes parallelizable in large-scale training of neural networks, as you can easily distribute the matrix operations amongst different H200s.
Another property of the summation (or matrix multiplication, if you follow the above intuition) is that the resulting attention vector corresponding to the value vector at $i$ will have a magnitude that is a fraction of the original, as $\alpha(q, k_i)$ is normalized by the softmax, which turns the operation into a convex combination, ensuring that the output stays within the convex hull of $\{v_1, \dots, v_T\}$, which proves to be useful for stabilizing training.
Intuitively, imagine summing a set of vectors where each vector isn't normalized by a given attention weight, and instead has unbounded magnitude which increases as $T$ grows. This could destabilize training as gradients can become very large. Normalization also forces the model to express the attention vectors under a constraint, which can act as a form of "forcing" the model to learn the right values, as it must focus on distributing attention weights properly amongst the different $v_i$.
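To make the pooling concrete, here's a minimal sketch in PyTorch for a single query; the dimensions and the plain dot-product similarity are arbitrary choices for illustration, not something fixed by the definitions above.
import torch
import torch.nn.functional as F

T, d_k, d_v = 5, 8, 8                        # sequence length, key dim, value dim (arbitrary)
q = torch.randn(d_k)                         # a single query
K = torch.randn(T, d_k)                      # T keys, one per position
V = torch.randn(T, d_v)                      # T values, one per position

scores = K @ q                               # a(q, k_i) via the dot product, shape (T,)
alpha = F.softmax(scores, dim=0)             # attention weights: non-negative, sum to 1
out = (alpha.unsqueeze(1) * V).sum(dim=0)    # sum_i alpha_i * v_i, shape (d_v,)

# Equivalent matrix form: the weighted sum is just a vector-matrix product.
assert torch.allclose(out, alpha @ V)
Because the weights are non-negative and sum to one, out is a convex combination of the rows of V, which is exactly the stabilizing property discussed above.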
REMARK
In a Seq2Seq model, $k_i$ and $v_i$ are the set of all hidden states of the encoder (at a given layer) across the source sequence, and $q$ is the hidden state of the decoder (at the equivalent layer) at the current time $t$. Interpreting attention (as described below) can become more intuitive if $q$, $k_i$, and $v_i$ are seen as such, where the decoder hidden state at $t$, as $q_t$, is the query for the set of encoder hidden states as keys $k_i$, to "extract" the similarity score between the two, helping compute how much attention should be paid to each encoder hidden state (value vector) $v_i$ given a query, $q_t$, at any $i = 1, \dots, T$, where $T$ is the length of the source sequence.
For each time step in the forward pass of a given RNN cell, the decoder hidden state ($q_t$) varies at every $t$, allowing $q_t$ and $\alpha(q_t, k_i)$ to denote how much attention should be paid to $v_i$ in the forward pass at time step $t$.
Note that there can be two interpretations for time here ($t$ and $i$). In the recurrent forward pass for the RNN, $t$ is defined as the current time step or token in the sequence which the RNN is processing. In the summation for attention, $i$ is defined as an intermediate time step $i \in \{1, \dots, T\}$.
Now, in the definition for $\mathrm{Attention}(q, \{k_i\}, \{v_i\})$, we iterate over the sequence of length $T$, extracting different $k_i$ and $v_i$ at each $i$ to compute how "important" or how much the value vector at $i$ should be attended to for a given $q$ (given by the element-wise multiplication of the $i$th attention score, $\alpha(q, k_i)$, and the value vector at $i$, $v_i$) and then sum over the sequence, to encode in a final vector how much "attention" should be paid to the $v_i$ given a $q$.
The most straightforward way to compute the similarity $a(q, k_i)$ is the dot product, $q^\top k_i$ (or the scaled dot product, $\frac{q^\top k_i}{\sqrt{d_k}}$, as in scaled dot product attention), as its scalar output denotes a measure for similarity.
Reminder that $q^\top k = \lVert q \rVert \lVert k \rVert \cos\theta$, thereby the larger (positive) the dot product of the two vectors is, the smaller $\theta$ is and thereby, the more similar the vectors are.
The inverse is true for large negative dot products, and a dot product equal to $0$ denotes orthogonality, indicating no similarity but also no notion of dissimilarity.
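As a toy numeric example (vectors chosen purely for illustration): with $q = [1, 0]^\top$, the key $k_1 = [1, 0]^\top$ gives $q^\top k_1 = 1$ ($\theta = 0$, maximally similar), $k_2 = [0, 1]^\top$ gives $q^\top k_2 = 0$ ($\theta = \tfrac{\pi}{2}$, orthogonal), and $k_3 = [-1, 0]^\top$ gives $q^\top k_3 = -1$ ($\theta = \pi$, pointing in the opposite direction).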
Bahdanau Attention
Neural Machine Translation by Jointly Learning to Align and Translate
Bahdanau Attention is defined similarly as above, with the difference being that $a(q, k_i)$ is defined as:
$$a(q, k_i) = w_v^\top \tanh\big(W_q q + W_k k_i\big)$$
where $W_q \in \mathbb{R}^{d_a \times d_q}$ and $W_k \in \mathbb{R}^{d_a \times d_k}$ are transformation matrices transforming the vectors $q$ and $k_i$ from their respective dimensionality (query space and key space) to the dimensionality of the attention space, $d_a$.
$w_v \in \mathbb{R}^{d_a}$ is the vector which encodes $\tanh(W_q q + W_k k_i)$ into a final scalar value, which serves as the attention score logit, prior to the softmax.
Note that the query space and the key space don't need to be of the same dimensionality, as $W_q$ and $W_k$ transform both vectors into the same dimensionality (the attention space).
It's difficult to precisely know what $W_q$ and $W_k$ are really doing to both $q$ and $k_i$; neural networks are black boxes. But more generally, we can see both matrices as encoding the relevant information from both $q$ and $k_i$ which enables the final attention score to be effectively computed. It all really pieces together when we factor in how gradient descent finds the optimal $W_q$, $W_k$, and $w_v$ to minimize the loss $\mathcal{L}$.
Again, noting the dot product as a metric for similarity between two vectors, perhaps we can interpret it as follows: the more similar $q$ or $k_i$ are to the respective row vectors in $W_q$ and $W_k$, the higher the magnitude of the input to $\tanh$ will be, and therefore, the larger the attention score will be if $w_v$ is also similar in direction to the output of $\tanh$.
Coded out,
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

def bahdanau(X, query_space, key_space, value_space, attention_space):
    batch_size, embed_dim = X.size(0), X.size(1)
    # First-stage projections: embedding space -> query / key / value spaces
    W1_q, W1_k, W1_v = torch.randn(query_space, embed_dim), torch.randn(key_space, embed_dim), torch.randn(value_space, embed_dim)
    # Second-stage projections: query / key spaces -> attention space, plus the scoring vector w_v
    W2_q, W2_k, w2_v = torch.randn(attention_space, query_space), torch.randn(attention_space, key_space), torch.randn(attention_space, 1)
    # Broadcast the weight matrices across the batch for bmm
    w1_q_mul, w1_k_mul, w1_v_mul = W1_q.unsqueeze(0).expand(batch_size, -1, -1), W1_k.unsqueeze(0).expand(batch_size, -1, -1), W1_v.unsqueeze(0).expand(batch_size, -1, -1)
    w2_q_mul, w2_k_mul, w2_v_mul = W2_q.unsqueeze(0).expand(batch_size, -1, -1), W2_k.unsqueeze(0).expand(batch_size, -1, -1), w2_v.T.unsqueeze(0).expand(batch_size, -1, -1)
    Q, K, V = torch.bmm(w1_q_mul, X), torch.bmm(w1_k_mul, X), torch.bmm(w1_v_mul, X)
    # Additive attention: w_v^T tanh(W_q q + W_k k) for every position in the sequence
    A_q, A_k = torch.bmm(w2_q_mul, Q), torch.bmm(w2_k_mul, K)
    attention_logits = torch.bmm(w2_v_mul, torch.tanh(A_q + A_k))   # (batch, 1, seq_len)
    attention_scores = F.softmax(attention_logits, dim=2)           # normalize over the sequence
    attention = torch.mul(attention_scores, V)                      # weight each value vector
    attention_pool = torch.sum(attention, dim=2)                    # pool over the sequence
    return attention, attention_scores, attention_pool

if __name__ == "__main__":
    batch_size, hidden_neurons, seq_len = 2, 10, 15
    query_space, value_space, key_space, attention_space = 18, 20, 19, 30
    X = torch.randn(batch_size, hidden_neurons, seq_len)
    attention, attention_scores, attention_pool = bahdanau(X, query_space, key_space, value_space, attention_space)
    plt.imshow(attention_scores.view(batch_size, seq_len), cmap='hot', interpolation='nearest')
    plt.title('Attention. All we need.', {'size': 30})
    plt.show()
Of course, now in modern LLMs, we typically use the scaled dot product, $a(q, k_i) = \frac{q^\top k_i}{\sqrt{d_k}}$, where $d_k$ is the dimensionality of $q$ and $k_i$.
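As a rough sketch of this in PyTorch (shapes are arbitrary; the fused F.scaled_dot_product_attention call assumes a reasonably recent PyTorch, 2.0 or later):
import math
import torch
import torch.nn.functional as F

B, T, d_k, d_v = 2, 15, 19, 20
Q = torch.randn(B, T, d_k)
K = torch.randn(B, T, d_k)
V = torch.randn(B, T, d_v)

scores = Q @ K.transpose(1, 2) / math.sqrt(d_k)   # q^T k / sqrt(d_k) for all pairs, (B, T, T)
alpha = F.softmax(scores, dim=-1)                 # normalize over the key positions
out = alpha @ V                                   # (B, T, d_v)

# PyTorch >= 2.0 also ships a fused equivalent:
# out = F.scaled_dot_product_attention(Q, K, V)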
Multihead Attention
Given a set of queries, keys, and values, it's a good idea to consider a model that can compute attention pooling in multiple different manners which capture different dependencies in a sequence (long-range vs. short-range, for instance). This is akin to having multiple kernels in a convolutional layer, each optimized to extract different features from the input feature map. Each set of parameters happens to capture different patterns: in the ConvNet, perhaps different edge patterns; in the attention head, perhaps different semantic meaning.
Therefore, it's beneficial to allow a model to jointly use different representation subspaces for the queries, keys, and values.
Instead of performing a single attention pooling operation, we can compute $H$ independent attention pooling operations.
Each attention pooling output is called a "head", hence, multi-head attention.
Assuming a single token.
Given a query $q \in \mathbb{R}^{d_q}$, keys $k_i \in \mathbb{R}^{d_k}$, and values $v_i \in \mathbb{R}^{d_v}$, we can compute a single attention head as,
$$h_j = f\big(W_j^{(q)} q,\; W_j^{(k)} k_i,\; W_j^{(v)} v_i\big) \in \mathbb{R}^{d_a}$$
where $f$ is attention pooling built on a similarity function, $W_j^{(q)} \in \mathbb{R}^{d_a \times d_q}$, $W_j^{(k)} \in \mathbb{R}^{d_a \times d_k}$, $W_j^{(v)} \in \mathbb{R}^{d_a \times d_v}$, and $d_a$ is the dimension of the attention space.
Including something such as the dot product operation (no scaling, it's simpler notation, less LaTeX to type out),
$$h_j = \sum_{i=1}^{T} \underbrace{\mathrm{softmax}_i\Big(\big(W_j^{(q)} q\big)^\top \big(W_j^{(k)} k_i\big)\Big)}_{\alpha_{j,i}}\; W_j^{(v)} v_i$$
where $\alpha_{j,i}$ is the attention weight for the $j$th head at token $i$.
After computing all $h_j$, we concatenate all heads into a single vector,
$$h = \begin{bmatrix} h_1 \\ \vdots \\ h_H \end{bmatrix} \in \mathbb{R}^{H d_a}$$
where $H$ is the count of attention heads.
Finally, we compute the output through a fully connected layer,
$$o = W_o\, h \in \mathbb{R}^{d}$$
where $d$ is the dimensionality of the original embedding space and $W_o \in \mathbb{R}^{d \times H d_a}$.
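A minimal sketch of the above, assuming dot-product pooling per head; the class name, the layer sizes, and the use of nn.Linear for the projection matrices are my own choices for illustration, not something prescribed here.
import torch
import torch.nn as nn
import torch.nn.functional as F

class NaiveMultiHead(nn.Module):
    # One independent (W_q, W_k, W_v) triple per head, dot-product pooling,
    # then concatenation and a final linear layer W_o, as in the equations above.
    def __init__(self, d_model: int, d_attn: int, num_heads: int):
        super().__init__()
        self.heads = nn.ModuleList([
            nn.ModuleDict({
                'W_q': nn.Linear(d_model, d_attn, bias=False),
                'W_k': nn.Linear(d_model, d_attn, bias=False),
                'W_v': nn.Linear(d_model, d_attn, bias=False),
            }) for _ in range(num_heads)
        ])
        self.W_o = nn.Linear(num_heads * d_attn, d_model, bias=False)

    def forward(self, x):                                      # x: (B, T, d_model)
        outs = []
        for head in self.heads:
            Q, K, V = head['W_q'](x), head['W_k'](x), head['W_v'](x)
            alpha = F.softmax(Q @ K.transpose(1, 2), dim=-1)   # no scaling, as above
            outs.append(alpha @ V)                             # (B, T, d_attn)
        return self.W_o(torch.cat(outs, dim=-1))               # (B, T, d_model)

x = torch.randn(2, 15, 32)
print(NaiveMultiHead(d_model=32, d_attn=8, num_heads=4)(x).shape)  # torch.Size([2, 15, 32])
In practice, implementations typically fuse the per-head projections into single larger matrices and reshape, rather than looping over heads as done here for clarity.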
Positional Encoding
Assuming we're using attention in a transformer, and since transformers don't have recurrent connections like RNNs do, we want a way to embed the positional information of a token.
Otherwise, the model wouldn't be able to tell which token comes first, second, $\dots$, last, and knowing which token to pay attention to the most would become difficult, as there is no notion of which word comes first, second, etc. in the sequence.
Assume input embeddings of size $d$; then we can create positional vectors of the same size $d$,
$$P_{i,\,2j} = \sin\!\left(\frac{i}{10000^{2j/d}}\right), \qquad P_{i,\,2j+1} = \cos\!\left(\frac{i}{10000^{2j/d}}\right)$$
where $i$ is the $i$th row of the input matrix $X \in \mathbb{R}^{T \times d}$ ($d$ is embedding dimension and $T$ is sequence length), and $2j$ / $2j+1$ index the columns of the input matrix.
So $\sin$ is computed for every row, every even column, while $\cos$ is computed for every row, every odd column.
Ultimately we get a $T \times d$ matrix $P$, holding the positional encodings for a given input sequence.
This allows us to effectively encode positional embeddings as we introduce deterministic values for position, given the deterministic computation of $\sin$ and $\cos$. Mathematically, this introduces (approximately, for larger $d$) orthogonality between each vector, guaranteeing linear independence, which indicates that each row vector in $P$ is able to capture a unique direction in the row-space.
As a simple example, consider $e_1 = \begin{bmatrix} 1 & 0 \end{bmatrix}^\top$ and $e_2 = \begin{bmatrix} 0 & 1 \end{bmatrix}^\top$.
These two vectors are orthogonal (orthonormal in fact), and this guarantees linear independence as both vectors capture unique information in different directions (in the $x$ and $y$ axes respectively) and you can't express $e_1$ in terms of $e_2$.
Positional encodings approximate orthogonality for larger $d$ (blessing, rather than curse, of dimensionality), and therefore allow us to capture unique information, in this context with regards to position.
Coded out,
import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, dropout_p: float, d_model: int, max_len: int):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_p)
        position = torch.arange(max_len).unsqueeze(1)                 # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(size=(max_len, d_model))
        pe[:, 0::2] = torch.sin(position * div_term)                  # even columns
        pe[:, 1::2] = torch.cos(position * div_term)                  # odd columns
        pe = pe.unsqueeze(0)                                          # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model); slice the encodings to the sequence length
        pe = self.pe[:, :x.size(1)]
        return self.dropout(x + pe)
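A quick usage sketch (the sizes are arbitrary, and this assumes the class above together with torch imported):
d_model, max_len, batch_size, seq_len = 32, 512, 2, 15
pos_enc = PositionalEncoding(dropout_p=0.1, d_model=d_model, max_len=max_len)
x = torch.randn(batch_size, seq_len, d_model)   # token embeddings, (B, T, d)
out = pos_enc(x)                                # embeddings + positional encodings, same shape
print(out.shape)                                # torch.Size([2, 15, 32])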
Masked Softmax
Assume we compute the raw attention scores in matrix form, prior to the softmax, as:
$$S = Q^\top K$$
where $Q \in \mathbb{R}^{d \times T}$ and $K \in \mathbb{R}^{d \times T}$, and where $T$ is sequence length and $d$ is the dimensionality of the query/key space.
We transpose $Q$, in $Q^\top K$, such that the matrix multiplication yields a $T \times T$ matrix and we get the raw attention scores with respect to the sequence (length) and not some meaningless arbitrary dimension (with respect to our task).
Assume $Q, K \in \mathbb{R}^{d \times T}$ as above, then:
$S = Q^\top K$ is size $T \times T$, which can be interpreted as follows. Let's work in the context of self-attention, where the keys come from the same sequence as the queries. For every $j$th column vector, its $i$th (row) value denotes how much "attention" weighting the $i$th token in the sequence pays to the $j$th token (or equivalently, to every other token) in the sequence.
Note that once we compute the $T \times T$ matrix of attention weights, we compute a matrix multiplication with $V$, where $V \in \mathbb{R}^{d_v \times T}$ and $d_v$ in this case is the dimensionality of the value space.
Remember that in $V$, each column vector corresponds to a single token in the sequence, given the column count $T$.
Therefore, in the matrix multiplication of the attention weights with $V$, we're multiplying each value column vector (or row, if we account for the transpose) in the attention space. Note that at time $t$ we have attention weights over all $T$ tokens, but at $t$, we don't expect the given token to have seen any other tokens at $t' > t$.
Therefore, for the attention weights in $S$ at positions $t' > t$, we need to zero out a portion of the values; in practice this is done by setting those logits to $-\infty$ (or a very large negative number) before the softmax, so their weights become zero after it.
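A minimal sketch of this masking, keeping the column-per-token convention used above (the sizes are arbitrary):
import torch
import torch.nn.functional as F

T, d = 4, 8
Q, K, V = torch.randn(d, T), torch.randn(d, T), torch.randn(d, T)   # columns are tokens

S = Q.T @ K                                   # raw attention scores, (T, T)
causal = torch.ones(T, T).tril().bool()       # True where token i may attend to token j (j <= i)
S = S.masked_fill(~causal, float('-inf'))     # block attention to future tokens
A = F.softmax(S, dim=-1)                      # masked weights; each row sums to 1
out = A @ V.T                                 # (T, d): one attended vector per token

print(A)                                      # the upper triangle is exactly zero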
Transformers
Following up on the iterations that led to the formulation of the attention mechanism, Vaswani et al. formulated the Transformer architecture in their seminal paper, Attention Is All You Need.
I explain this in depth here.