Attention

Juan Vera

January 2025

Let's define $q, K, V$ as a query, a set of keys, and a set of values respectively:

$$
q = \begin{pmatrix}\{q_1, q_2, \dots, q_l \}\end{pmatrix} \\[7mm]
K = \begin{pmatrix}\{k^{(1)}_1, \dots, k^{(1)}_l \}\\ \vdots \hspace{.25in} \vdots \hspace{.25in} \vdots \\ \{k^{(m)}_1, \dots, k^{(m)}_l \} \end{pmatrix} \\[7mm]
V = \begin{pmatrix}\{v^{(1)}_1, \dots, v^{(1)}_l \}\\ \vdots \hspace{.25in} \vdots \hspace{.25in} \vdots \\ \{v^{(m)}_1, \dots, v^{(m)}_l \} \end{pmatrix} \\[7mm]
K, V \in \mathbb{R}^{m \times l \times h}
$$

where $m$ is the batch size and $l$ is the sequence length.

We can define attention as:

$$
(1) \hspace{.25in} \text{Attention}(q, \hat{K}, \hat{V}) \coloneqq \sum_{t = 1}^{l} \alpha(q,\hat{K})_t \odot \hat{V}_t \;\in\; \mathbb{R}^{h} \\[3mm]
\hat{Q}, \hat{K}, \hat{V} \in \mathbb{R}^{l \times h} \\[3mm]
q, \hat{K}_t, \hat{V}_t \in \mathbb{R}^h
$$

where $h$ is the dimensionality of the value space and $\alpha$ is an arbitrary weighting function, generating scalar values as the attention scores.

Attention, in essence, takes a query $q$ and a set of keys $K$, computes a similarity score via the function $\alpha(\cdot) \rightarrow \mathbb{R}$, and multiplies it element-wise with the value vector $v$.

$$
(2) \hspace{.25in} \alpha(q, \hat{K})_t = \frac{\exp(a(q, \hat{K}_t))}{\sum_{j=1}^l\exp(a(q, \hat{K}_j))} \;\in\; \mathbb{R}
$$

where $\alpha$ is constructed as the softmax of an arbitrary similarity function $a$, used to compute the similarity of a given $q$ and $\hat{K}_t$.

We're essentially computing a similarity $a(\cdot)$ and normalizing it to lie in $[0, 1]$, so that, taken together with the scores $\alpha(\cdot)_t$ for all other $t$, it can be interpreted as a probability distribution.

The sum $\sum_{t=1}^{l}$ shows that for sequence length $l$, we pool all $l$ attention weights by multiplying element-wise with $\hat{V}_t$ (broadcasting the scalar $\alpha(q, \hat{K})_t$ to a vector $\in \mathbb{R}^h$) and then summing over $l$. You might as well express this as a matrix product: stacking the attention weights for all $l$ queries gives a matrix $\in \mathbb{R}^{l \times l}$, which is multiplied with the matrix of value vectors $\hat{V} \in \mathbb{R}^{l \times h}$, where $l$ is the sequence length. This is essentially how the attention mechanism becomes parallelizable in large-scale training of neural networks, as you can easily distribute the matrix operations amongst different H200s.
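As a quick sanity check, the element-wise weight-and-sum pooling above collapses to a single matrix product. A minimal sketch with arbitrary shapes:

import torch

l, h = 5, 8                                   # sequence length, value dimensionality
alpha = torch.softmax(torch.randn(l), dim=0)  # attention weights for a single query, sum to 1
V_hat = torch.randn(l, h)                     # one value vector per position

# loop form: broadcast each scalar weight over V_hat[t] and sum over the sequence
pooled_loop = sum(alpha[t] * V_hat[t] for t in range(l))

# matrix form: a single (1 x l) @ (l x h) product
pooled_matmul = alpha @ V_hat

print(torch.allclose(pooled_loop, pooled_matmul, atol=1e-6))  # True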

Another property of the summation (or the matrix multiplication, if you follow the above intuition) is that each weighted term has a magnitude that is only a fraction of the original $|\hat{V}_t|$, as $\alpha(q, \hat{K})$ is normalized by the $\text{softmax}(\cdot)$. This turns the operation into a convex combination, ensuring that the output stays within the convex hull of $\mathcal{V}$, which proves to be useful for stabilizing training.

Intuitively, imagine summing a set of vectors where each vector isn't normalized by a given attention weight, and instead has unbounded magnitude which increases as $l$ grows. This could destabilize training, as gradients can become very large. The normalization also forces the model to express the attention output under a constraint, which acts as a form of "forcing" the model to learn the right values, as it must distribute attention weights properly amongst the different $\hat{V}_t$.
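To make this concrete, here's a small sketch (random vectors, made-up dimensions) comparing the norm of an unweighted sum of value vectors against a softmax-weighted convex combination as $l$ grows:

import torch

h = 16
for l in [4, 64, 1024]:
    V_hat = torch.randn(l, h)
    weights = torch.softmax(torch.randn(l), dim=0)   # convex weights, sum to 1

    unweighted = V_hat.sum(dim=0)                    # magnitude grows roughly with sqrt(l)
    weighted = weights @ V_hat                       # stays within the convex hull of the rows

    print(f"l={l:5d}  |sum|={unweighted.norm().item():7.2f}  |convex combo|={weighted.norm().item():5.2f}")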

REMARK
In a Seq2Seq model, $K$ and $V$ are the set of all hidden states of the encoder (at a given layer) over the source sequence of length $l$. $q$ is the hidden state of the decoder at the previous time step (at the corresponding layer).

Interpreting attention (as described below) can become more intuitive if $q, K, V$ are seen this way: the decoder hidden state at $t' - 1$, as $q$, is the query against the set of encoder hidden states as keys $\hat{K}_t$, "extracting" the similarity score between the two. This computes how much attention should be paid to each encoder hidden state (the value vector $\hat{V}_t$) given the query $q$, for any $t \in [1, l]$, where $l$ is the source sequence length.

For each time step in the forward pass of a given RNN cell, the decoder hidden state ($q$) at $t' - 1$ varies, so the attention weights over $K$ and $V$ change at every $t'$, denoting how much attention should be paid to each encoder state in the forward pass at time step $t'$.

Note that there can be two interpretations of time here ($t$ and $t'$). In the recurrent forward pass of the RNN, $t'$ is the current time step, i.e., the token in the sequence which the RNN is processing. In the summation for attention, $t$ is an intermediate position $\in [1, l]$.

Now, in the definition of $\text{Attention}$, we iterate over the sequence of length $l$, extracting a different $\hat{V}_t$ and $\hat{K}_t$ at each $t$ to compute how "important" the value vector at $t$ is, i.e., how much it should be attended to for a given $q$ (given by the element-wise multiplication of the $t$th attention score, $\alpha(q, \hat{K})_t$, with the value vector at $t$, $\hat{V}_t$). We then sum over the sequence to encode, in a final $\text{Attention}$ vector, how much "attention" should be paid to each $\hat{V}_t$ given $q$.

The most straightforward way to compute the similarity $a(\cdot)$ is the dot product (or the scaled dot product, as in scaled dot-product attention), as its scalar output denotes a measure of similarity.

Reminder that $\vec{a}^\top\vec{b} = |\vec{a}| |\vec{b}| \cos \theta$; thereby, the larger and more positive the dot product of the two vectors, the smaller $\theta$ is and the more similar the vectors are.

The inverse is true for large negative values, while a dot product of $0$ denotes orthogonality, indicating no similarity but also no notion of dissimilarity.
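A tiny illustration with made-up vectors:

import torch

a = torch.tensor([1.0, 0.0])
similar = torch.tensor([0.9, 0.1])     # small angle   -> large positive dot product
orthogonal = torch.tensor([0.0, 1.0])  # 90 degrees    -> zero, no (dis)similarity
opposite = torch.tensor([-1.0, 0.0])   # 180 degrees   -> negative dot product

print(a @ similar, a @ orthogonal, a @ opposite)  # tensor(0.9000) tensor(0.) tensor(-1.)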

Bahdanau Attention

Neural Machine Translation by Jointly Learning to Align and Translate

Bahdanau attention is defined similarly to the above, with the difference being that $a(\cdot)$ is defined as:

$$
(3) \hspace{.25in} a(\mathbf q, \mathbf k_t) = w^\top \tanh( W_q q + W_k k_t) \in \mathbb{R},
$$

where $W_q$ and $W_k$ are transformation matrices mapping the vectors $q$ and $k_t$ from their respective dimensionalities (query space and key space) to the dimensionality of the attention space.

$w^\top$ is the vector which projects the output of $\tanh(\cdot)$ onto a final scalar value, which serves as the attention score logit prior to the $\text{softmax}$ in eq. $2$.

Note that the query space and the key space don't need to be of the same dimensionality, as $W_q$ and $W_k$ transform both vectors into the same (attention space) dimensionality.

It's difficult to know precisely what $W_q$ and $W_k$ are really doing to $q$ and $k$; neural networks are black boxes. More generally, we can see both matrices as encoding the relevant information from $q$ and $k$ that enables the final attention score to be effectively computed. It all pieces together when we factor in how gradient descent finds the $W$ that minimizes the loss $\mathcal{L}$.

Again, noting the dot product as a metric for similarity between two vectors, perhaps we can interpret it as follows: the more similar $q$ or $k$ are to the respective row vectors of $W_q$ and $W_k$, the higher the magnitude of the input to $\tanh$, and therefore the larger the attention score will be, provided $w$ is also similar in direction to the output of $\tanh$.

Coded out,

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

def bahdanau(X, query_space, key_space, value_space, attention_space):
    # X: (batch_size, embed_dim, seq_len)
    batch_size, embed_dim = X.size(0), X.size(1)

    # first projection: embedding space -> query / key / value spaces
    W1_q, W1_k, W1_v = torch.randn(query_space, embed_dim), torch.randn(key_space, embed_dim), torch.randn(value_space, embed_dim)
    # second projection: query / key spaces -> attention space, plus the scoring vector w
    W2_q, W2_k, w2_v = torch.randn(attention_space, query_space), torch.randn(attention_space, key_space), torch.randn(attention_space, 1)

    # broadcast the shared weights across the batch for bmm
    w1_q_mul, w1_k_mul, w1_v_mul = W1_q.unsqueeze(0).expand(batch_size, -1, -1), W1_k.unsqueeze(0).expand(batch_size, -1, -1), W1_v.unsqueeze(0).expand(batch_size, -1, -1)
    w2_q_mul, w2_k_mul, w2_v_mul = W2_q.unsqueeze(0).expand(batch_size, -1, -1), W2_k.unsqueeze(0).expand(batch_size, -1, -1), w2_v.T.unsqueeze(0).expand(batch_size, -1, -1)

    Q, K, V = torch.bmm(w1_q_mul, X), torch.bmm(w1_k_mul, X), torch.bmm(w1_v_mul, X)  # (batch, *_space, seq_len)
    A_q, A_k = torch.bmm(w2_q_mul, Q), torch.bmm(w2_k_mul, K)                         # (batch, attention_space, seq_len)
    attention_logits = torch.bmm(w2_v_mul, torch.tanh(A_q + A_k))                     # w^T tanh(W_q q + W_k k_t) -> (batch, 1, seq_len)

    attention_scores = F.softmax(attention_logits, dim=2)   # normalize over the sequence
    attention = torch.mul(attention_scores, V)               # weight each value vector, (batch, value_space, seq_len)
    attention_pool = torch.sum(attention, dim=2)             # pool over the sequence, (batch, value_space)

    return attention, attention_scores, attention_pool

if __name__ == "__main__":
    batch_size, hidden_neurons, seq_len = 2, 10, 15
    query_space, value_space, key_space, attention_space = 18, 20, 19, 30

    X = torch.randn(batch_size, hidden_neurons, seq_len)
    attention, attention_scores, attention_pool = bahdanau(X, query_space, key_space, value_space, attention_space)

    plt.imshow(attention_scores.view(batch_size, seq_len), cmap='hot', interpolation='nearest')
    plt.title('Attention. All we need.', {'size': 30})
    plt.show()

Of course, now in modern LLMs, we typically use the scaled dot product, $\frac{q^\top k}{\sqrt{d}}$, where $d$ is the dimensionality of $k$ and $q$.
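A minimal sketch of scaled dot-product attention for a single set of queries, keys, and values (the shapes here are arbitrary assumptions for the example):

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    # Q: (l_q, d), K: (l_k, d), V: (l_k, h)
    d = Q.size(-1)
    scores = Q @ K.T / d ** 0.5           # (l_q, l_k) raw similarity logits
    weights = F.softmax(scores, dim=-1)   # each row is a distribution over the l_k keys
    return weights @ V                    # (l_q, h) attention-pooled values

Q, K, V = torch.randn(6, 32), torch.randn(6, 32), torch.randn(6, 64)
out = scaled_dot_product_attention(Q, K, V)
print(out.shape)  # torch.Size([6, 64])

Recent PyTorch versions also provide this as a built-in, torch.nn.functional.scaled_dot_product_attention.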

Multihead Attention

Given a set of queries, keys, and values, it's a good idea to consider a model that can compute attention pooling in multiple different manners, each capturing different dependencies in a sequence (long-range vs. short-range, for instance). This is akin to having multiple kernels $\mathcal{K}$ in a convolutional layer, which are optimized to extract different features of the input feature map. Each set of parameters happens to capture different patterns: in the ConvNet, perhaps different edge patterns; in the attention head, perhaps different semantic meaning.

Therefore, it's beneficial to allow a model to jointly use different representation spaces of the query space, value space, and key space.

Instead of performing a single attention pooling operation, we can compute $n$ independent attention pooling operations.

Each of the $n$ attention pooling outputs is called a "head", hence, multi-head attention.

Assuming a single token.

Given a query $q \in \mathbb{R}^{h_q}$, a key $k \in \mathbb{R}^{h_k}$, and a value $v \in \mathbb{R}^{h_v}$, we can compute a single attention head as,

$$
h_i = f(W_i^q q, W_i^k k, W_i^v v) \in \mathbb{R}^{h_a}
$$

where $f$ is the attention pooling function, $W_i^q \in \mathbb{R}^{h_a \times h_q}$, $W_i^k \in \mathbb{R}^{h_a \times h_k}$, $W_i^v \in \mathbb{R}^{h_a \times h_v}$, and $h_a$ is the dimension of the attention space.

Using, for instance, the dot product (no scaling; it's simpler notation, less LaTeX to type out),

$$
h_i = \sum_{t=1}^{l}\underbrace{\frac{\exp\left((W_i^q q)^\top (W_i^k k_t)\right)}{\sum_{j=1}^{l} \exp\left((W_i^q q)^\top (W_i^k k_j)\right)}}_{\alpha_{i,t}}\,(W_i^{v}v_t) \;\in\; \mathbb{R}^{h_a}
$$

where $\alpha_{i, t}$ is the attention weight for the $i$th head at token $t$.

After computing all $h_i$, we concatenate all $n$ heads into a single vector,

$$
H_{\text{cat}} = \text{Concat}(h_1, h_2, \dots, h_n) \in \mathbb{R}^{n \cdot h_a}
$$

where $n$ is the number of attention heads.

Finally, we pass the concatenation through a fully connected layer,

$$
H_{\text{out}} = W^o H_{\text{cat}} \in \mathbb{R}^{h_e}
$$

where $h_e$ is the dimensionality of the original embedding space and $W^o \in \mathbb{R}^{h_e \times (n \cdot h_a)}$.
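Coded out as a rough sketch (dimension names follow the equations above; packing the $n$ per-head projections into single linear layers and using scaled dot-product scores are choices made for the example, not part of the derivation above):

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, h_e: int, h_a: int, n_heads: int):
        super().__init__()
        # one set of projections per head, packed into single linear layers
        self.W_q = nn.Linear(h_e, n_heads * h_a, bias=False)
        self.W_k = nn.Linear(h_e, n_heads * h_a, bias=False)
        self.W_v = nn.Linear(h_e, n_heads * h_a, bias=False)
        self.W_o = nn.Linear(n_heads * h_a, h_e, bias=False)  # maps the concatenated heads back to h_e
        self.n_heads, self.h_a = n_heads, h_a

    def forward(self, X):
        # X: (batch, l, h_e)
        B, l, _ = X.shape
        # project, then split the last dim into (n_heads, h_a) and move heads next to batch
        Q = self.W_q(X).view(B, l, self.n_heads, self.h_a).transpose(1, 2)  # (B, n, l, h_a)
        K = self.W_k(X).view(B, l, self.n_heads, self.h_a).transpose(1, 2)
        V = self.W_v(X).view(B, l, self.n_heads, self.h_a).transpose(1, 2)

        scores = Q @ K.transpose(-2, -1) / self.h_a ** 0.5   # (B, n, l, l)
        A = F.softmax(scores, dim=-1)                        # attention weights per head
        heads = A @ V                                        # (B, n, l, h_a)

        H_cat = heads.transpose(1, 2).reshape(B, l, self.n_heads * self.h_a)  # concatenate heads
        return self.W_o(H_cat)                               # (B, l, h_e)

X = torch.randn(2, 10, 64)
mha = MultiHeadAttention(h_e=64, h_a=16, n_heads=4)
print(mha(X).shape)  # torch.Size([2, 10, 64])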

Positional Encoding

Assuming we're using attention in a transformer, and since transformers, unlike RNNs, have no recurrent connections, we want a way to embed the positional information of a token.

Otherwise, the model wouldn't be able to tell which token comes first, second, $\dots$, last, and knowing which $W_i^v v_t$ to pay the most attention to would become difficult, as there would be no notion of which word comes first, second, etc. in the sequence.

Assume input embeddings of size $h_e$; then we can create positional vectors of the same size $h_e$,

$$
p_{i, 2j} = \sin\left(\frac{i}{10000^{2j/d}}\right) \\[3mm]
p_{i, 2j + 1} = \cos\left(\frac{i}{10000^{2j/d}}\right)
$$

where $i$ is the $i$th row of the input matrix $X \in \mathbb{R}^{l \times h_e}$ ($h_e$ is the embedding dimension, with $d = h_e$ above, and $l$ is the sequence length), and $j$ indexes the column pairs $(2j, 2j+1)$ of the input matrix.

So $\sin$ is computed for every row and every even column, while $\cos$ is computed for every row and every odd column.

Ultimately we get a matrix $\mathcal{P} \in \mathbb{R}^{l \times h_e}$ holding the positional encodings for a given input sequence.

This allows us to effectively encode position, as we introduce deterministic values for each position through the deterministic computation of $\sin(\cdot)$ and $\cos(\cdot)$. Mathematically, this introduces (approximately, for larger $d$) orthogonality between the positional vectors, giving linear independence, which means each row vector of $\mathcal{P}$ captures a unique direction in the row space.

As a simple example, consider $a = \begin{bmatrix} 1 \\ 0 \end{bmatrix}, b = \begin{bmatrix} 0 \\ 1 \end{bmatrix}$.

These two vectors are orthogonal (orthonormal, in fact), and this guarantees linear independence: both vectors capture unique information in different directions (along $\hat{i}$ and $\hat{j}$ respectively), and you can't express $a$ in terms of $b$.

Positional encodings approximate orthogonality for larger $d$ (a blessing, rather than a curse, of dimensionality), and therefore allow us to capture unique information, in this context with regard to position.

Coded out,

import math

import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):

    def __init__(self, dropout_p: float, d_model: int, max_len: int):
        super().__init__()

        self.dropout = nn.Dropout(p=dropout_p)

        position = torch.arange(max_len).unsqueeze(1)                                       # (max_len, 1) positions i
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))  # 1 / 10000^(2j / d_model)

        pe = torch.zeros(size=(max_len, d_model))
        pe[:, 0::2] = torch.sin(position * div_term)   # even columns
        pe[:, 1::2] = torch.cos(position * div_term)   # odd columns

        pe = pe.unsqueeze(0)                           # (1, max_len, d_model), broadcastable over the batch
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch_size, seq_len, d_model) token embeddings
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
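For example (hyperparameters are arbitrary), adding the encodings to a batch of embeddings with the module defined above:

pos_enc = PositionalEncoding(dropout_p=0.1, d_model=64, max_len=512)
x = torch.randn(2, 10, 64)   # (batch, seq_len, d_model) token embeddings
print(pos_enc(x).shape)      # torch.Size([2, 10, 64])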

Masked Softmax

Assume we compute the raw attention scores $a(\cdot)$ in matrix form, prior to the softmax, as:

$$
Z_{a} = \frac{QK^\top}{\sqrt{d}}
$$

where $Q \in \mathbb{R}^{l \times h_q}$ and $K \in \mathbb{R}^{l \times h_k}$, and $Z_a \in \mathbb{R}^{l \times l}$, where $l$ is the sequence length and $h_q = h_k$.

We transpose $K$ so that the matrix multiplication yields an $l \times l$ matrix, giving us the raw attention scores with respect to positions in the sequence rather than the (for our purposes) arbitrary feature dimension $h_q$.

Assume $l = 4$; then $\text{softmax}(Z_a)$:

$$
Z_a = \begin{bmatrix} .2 & .5 & .9 & .6 \\ .4 & .4 & .9 & .3 \\ .1 & .9 & .3 & .6 \\ .7 & .2 & .9 & .2 \end{bmatrix}
\\[10mm]
A = \begin{bmatrix} 0.1708 & 0.2305 & 0.3439 & 0.2548 \\ 0.2196 & 0.2196 & 0.3621 & 0.1987 \\ 0.1641 & 0.3651 & 0.2004 & 0.2705 \\ 0.2912 & 0.1766 & 0.3556 & 0.1766 \end{bmatrix}
$$

$A$ is of size $l \times l$, which can be interpreted as follows. Renaming our variables in the context of $A$, where $A \in \mathbb{R}^{i \times j}, i = j$: for every $i$th row vector, its $j$th (column) entry denotes how much "attention" weighting the $i$th token in the sequence pays to the $j$th token, across all $l$ tokens in the sequence (the softmax is taken row-wise, so each row sums to $1$).
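The numbers above can be reproduced directly; note the row-wise softmax:

import torch
import torch.nn.functional as F

Z_a = torch.tensor([[.2, .5, .9, .6],
                    [.4, .4, .9, .3],
                    [.1, .9, .3, .6],
                    [.7, .2, .9, .2]])

A = F.softmax(Z_a, dim=-1)   # softmax over each row
print(A)
print(A.sum(dim=-1))         # each row sums to 1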

Note that once we compute the matrix of attention weights, we compute a matrix multiplication with $W^v V \in \mathbb{R}^{h_a \times l}$, where $W^v \in \mathbb{R}^{h_a \times h_v}$ and $V$ in this case is $\in \mathbb{R}^{h_v \times l}$:

$$
A(W^vV)^\top \in \mathbb{R}^{l \times h_a}
$$

Remember that in $W^v V$ each column vector corresponds to a single token $t$ in the sequence, given the column count $l$.

Therefore, in the matrix multiplication $A(W^vV)^\top \in \mathbb{R}^{l \times h_a}$, we're weighting each token's column vector (row, if we account for the transpose) in the attention space. Note that at time $t$ we have $(W^vV)_t$, but we don't expect the token at $t$ to have seen any future tokens at positions $t + n \le l$ for $n > 0$.

Therefore, for the attention weights in $A$, we need to zero out the portion corresponding to future positions (the upper triangle), typically by setting the corresponding logits to $-\infty$ before the softmax.
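A minimal sketch of such a masked softmax, assuming a causal (upper-triangular) mask applied to the raw scores before the softmax:

import torch
import torch.nn.functional as F

def masked_softmax(Z_a):
    # Z_a: (l, l) raw attention scores; position t may only attend to positions <= t
    l = Z_a.size(-1)
    causal_mask = torch.triu(torch.ones(l, l, dtype=torch.bool), diagonal=1)  # True above the diagonal
    Z_a = Z_a.masked_fill(causal_mask, float('-inf'))  # -inf logits become exactly 0 after the softmax
    return F.softmax(Z_a, dim=-1)

Z_a = torch.randn(4, 4)
print(masked_softmax(Z_a))  # upper triangle is zero, each row still sums to 1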

Transformers

Following up on the iterations that led to the formulation of the attention mechanism, Vaswani et al. formulated the Transformer architecture in their seminal paper, Attention Is All You Need.

I explain this in depth here.