Transformers

Juan Vera

January 2025

Implementation in PyTorch here.
More background here.

Introduced in Attention is all you need by Vaswani et al., Transformers originally took the form of an encoder-decoder architecture, with the encoder containing a MultiHead Self Attention block and a position-wise feed-forward network, $\mathcal{F}_{\text{pos}}$, and the decoder containing a Masked MultiHead Self Attention block, a Cross MultiHead Attention block, and another $\mathcal{F}_{\text{pos}}$.

The encoder takes in the input sequence embeddings and extracts their representation through self attention pooling. The decoder, on the other hand, takes in the target sequence embeddings, extracting their representation in the same manner as the encoder, through self attention pooling.

If you recall, multihead attention can be expressed as:

$$
h_i = \sum_{t=1}^{l} \underbrace{\text{softmax}\left(\frac{(W_i^q q)^\top (W_i^k k_t)}{\sqrt{d_{\text{model}}}}\right)}_{\alpha_{i,t}} (W_i^{v} v_t) \in \mathbb{R}^{h_v}
$$

$$
H_{\text{cat}} = \text{Concat}(h_1, \dots, h_n) \in \mathbb{R}^{(n \cdot h_v)}
$$

$$
H_{\text{out}} = W^o H_{\text{cat}} \in \mathbb{R}^{h_e}
$$

Note that this can easily be parallelized as a matrix multiplication as is, or as a batched matrix multiplication (BMM) if we include the batch dimension $b$ and the sequence length dimension $l$.
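As a minimal sketch (not the linked implementation), the per-head attention above can be computed for all heads, positions, and batch elements at once with batched matrix multiplications, assuming the inputs have already been projected into per-head $Q, K, V$ tensors:

import torch

def batched_multihead_attention(Q, K, V, W_o, d_model):
    # Q, K, V: (b, n_heads, l, h_a) -- already projected into the per-head attention space
    # W_o: (n_heads * h_a, h_e) -- output projection
    scores = Q @ K.transpose(-2, -1) / d_model ** 0.5   # (b, n_heads, l, l) scaled dot products
    alpha = torch.softmax(scores, dim=-1)               # attention weights alpha_{i,t}
    heads = alpha @ V                                    # (b, n_heads, l, h_a), one h_i per head
    b, n_heads, l, h_a = heads.shape
    H_cat = heads.transpose(1, 2).reshape(b, l, n_heads * h_a)  # Concat(h_1, ..., h_n)
    return H_cat @ W_o                                   # H_out, (b, l, h_e)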

Multihead attention is explicitly called multihead self-attention when the $Q, K, V$ inputs to the attention mechanism all come from the same source, $\sim X$, where $X$ is the input sequence.

Multihead cross attention refers to the situation in which $Q$ and $K, V$ come from different input sequence sources.

In the transformer architecture, $Q$ comes from the output of the decoder's masked self attention, while $K, V = \text{Encoder}(X)$, i.e. they are derived from the encoder's output.

$\text{MHSA}$ is multihead self attention.

$\mathcal{F}_{\text{pos}}$ is simply a feed-forward neural network which takes in input vectors position-wise, meaning that for a sequence of length $l$, $\mathcal{F}$ provides a single output vector for each token representation $t$.

In the transformer architecture, $\mathcal{F}_{\text{pos}}$ consists of two layers, the first projecting the inputs $\in \mathbb{R}^{d_{\text{model}}}$ into a higher dimensional space, $\mathbb{R}^{4 \cdot d_{\text{model}}}$. The second layer projects the outputs of the first layer back to $\mathbb{R}^{d_{\text{model}}}$.
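A minimal sketch of what the PositionWiseNN module used in the code below might look like, assuming a ReLU activation between the two layers (the activation isn't specified above):

import torch.nn as nn

class PositionWiseNN(nn.Module):
    # Two-layer position-wise feed-forward network: d_model -> 4 * d_model -> d_model,
    # applied independently to every token representation.
    def __init__(self, d_model):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.ReLU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x):
        # x: (b, l, d_model) -> (b, l, d_model); nn.Linear acts on the last dim, i.e. position-wise
        return self.net(x)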

Throughout the architecture we can find residual connections: around the encoder's MHSA, around the decoder's MHSA and MHCA (multihead cross attention), and, in both, around $\mathcal{F}_{\text{pos}}$.

Recall that residual connections consist of $\text{Layer}(X) + I(X)$, where $I$ is the identity transformation.

Encoder

The encoder is structured as:

$$
\begin{aligned}
(1) \quad & X_p = \mathcal{P}(X_e) \in \mathbb{R}^{b \times l \times h_e} \\
(2) \quad & Q = W^q X_p \in \mathbb{R}^{b \times l \times h_q} \\
(3) \quad & K = W^k X_p \in \mathbb{R}^{b \times l \times h_k} \\
(4) \quad & V = W^v X_p \in \mathbb{R}^{b \times l \times h_v} \\
(5) \quad & \hat{Q} = W^{aq} Q \in \mathbb{R}^{b \times l \times h_a} \\
(6) \quad & \hat{K} = W^{ak} K \in \mathbb{R}^{b \times l \times h_a} \\
(7) \quad & \hat{V} = W^{av} V \in \mathbb{R}^{b \times l \times h_a} \\
(8) \quad & h_i = \text{Self Attention}(\hat{Q}, \hat{K}, \hat{V})_i \in \mathbb{R}^{b \times l \times h_a} \\
(9) \quad & H_{\text{cat}} = \text{cat}(h_1, \dots, h_n) \in \mathbb{R}^{b \times l \times (n \cdot h_a)} \\
(10) \quad & H_{\text{out}} = W^o H_{\text{cat}} \in \mathbb{R}^{b \times l \times h_e} \\
(11) \quad & H_{\text{res}} = \text{LayerNorm}(H_{\text{out}} + X_p) \in \mathbb{R}^{b \times l \times h_e} \\
(12) \quad & \hat{Z} = \mathcal{F}_{\text{pos}}(H_{\text{res}}) + H_{\text{res}} \in \mathbb{R}^{b \times l \times h_e} \\
(13) \quad & \hat{Z}_{\text{enc}} = \text{LayerNorm}(\hat{Z}) \in \mathbb{R}^{b \times l \times h_e}
\end{aligned}
$$

(1) performs positional encoding using $\sin$ and $\cos$ functions, (2 - 7) project the inputs into the attention space, (8 - 10) perform $\text{MHSA}$, and finally we feed through $\mathcal{F}_{\text{pos}}$ for the final output of the encoder.
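For reference, here is a minimal sketch of a sinusoidal PositionalEncoding module. It is written so that it returns only the (dropout-regularized) positional encodings, matching the x + self.pe(x) usage in the Transformer class further below; the exact form of the module in the linked implementation may differ.

import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    # Sinusoidal encodings from Vaswani et al. (assumes d_model is even):
    # PE(pos, 2i) = sin(pos / 10000^(2i / d_model)), PE(pos, 2i + 1) = cos(pos / 10000^(2i / d_model))
    def __init__(self, dropout_p, d_model, max_len):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_p)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        # x: (b, l, d_model) -> returns encodings of shape (1, l, d_model), broadcastable against x
        return self.dropout(self.pe[:, :x.size(1)])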

(11, 13) perform $\text{LayerNorm}$, computed over the embedding dimension of $H_{\text{out}}$, $h_e$.

Unlike $\text{BatchNorm}$, we don't use shared statistics over different samples, and unlike $\text{LayerNorm}$ over a $2$-dimensional matrix $\mathcal{X}$ (assuming $\mathcal{X} \in \mathbb{R}^{10 \times 784} = \mathbb{R}^{b \times l}$), we don't compute statistics over the entire sequence length $l$.

The $\mu$ and $\sigma^2$ statistics are simply computed for each token representation $t$ in the sequence.

Note that (11, 13) embed the residual connection within the $\text{LayerNorm}(\cdot)$ function.
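A quick sanity check (illustrative, not from the original post) that nn.LayerNorm(d_model) on a (b, l, d_model) tensor normalizes each token representation independently, using per-token statistics over the embedding dimension only:

import torch
import torch.nn as nn

b, l, d_model = 2, 5, 8
x = torch.randn(b, l, d_model)

ln = nn.LayerNorm(d_model)                     # normalized_shape = embedding dim only
out = ln(x)

# Manual per-token statistics over the last (embedding) dimension
mu = x.mean(dim=-1, keepdim=True)              # (b, l, 1): one mean per token
var = x.var(dim=-1, unbiased=False, keepdim=True)
manual = (x - mu) / torch.sqrt(var + ln.eps)

print(torch.allclose(out, manual, atol=1e-6))  # True at initialization (weight=1, bias=0)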

Coded out, the encoder looks as:

class Encoder(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.qkv = QKV(d_model=d_model)
        self.multihead_attention = nn.MultiheadAttention(
            embed_dim=d_model, num_heads=num_heads, batch_first=True)
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.positionNN = PositionWiseNN(d_model=d_model)

    def forward(self, x, pad_msk):
        """Input: (BatchSize, SequenceLen, d_model) → Output: (BatchSize, SequenceLen, d_model)"""
        qa, ka, va = self.qkv(x)                                    # project x into Q, K, V, eqs. (2 - 4)
        h_self_attn, _ = self.multihead_attention(
            qa, ka, va, key_padding_mask=pad_msk)                   # MHSA, eqs. (5 - 10)
        h_res = self.layer_norm1(h_self_attn + x)                   # Add & Norm, eq. (11)
        enc_out = self.layer_norm2(self.positionNN(h_res) + h_res)  # F_pos + Add & Norm, eqs. (12 - 13)
        return enc_out

Decoder

The decoder is structured as:

$$
\begin{aligned}
(14) \quad & Y_p = \mathcal{P}(Y_e) \in \mathbb{R}^{b \times l \times h_e} \\
(15) \quad & Q = W^q Y_p \in \mathbb{R}^{b \times l \times h_q} \\
(16) \quad & K = W^k Y_p \in \mathbb{R}^{b \times l \times h_k} \\
(17) \quad & V = W^v Y_p \in \mathbb{R}^{b \times l \times h_v} \\
(18) \quad & \hat{Q} = W^{aq} Q \in \mathbb{R}^{b \times l \times h_a} \\
(19) \quad & \hat{K} = W^{ak} K \in \mathbb{R}^{b \times l \times h_a} \\
(20) \quad & \hat{V} = W^{av} V \in \mathbb{R}^{b \times l \times h_a} \\
(21) \quad & h_i = \text{Masked MultiHead Self Attention}(\hat{Q}, \hat{K}, \hat{V})_i \in \mathbb{R}^{b \times l \times h_a} \\
(22) \quad & H_{\text{res}} = \text{ConcatAddNorm}(h_1, \dots, h_n) \in \mathbb{R}^{b \times l \times h_e} \\
(23) \quad & h_i = \text{MultiHead Cross Attention}(W^q H_{\text{res}}, W^k \hat{Z}_{\text{enc}}, W^v \hat{Z}_{\text{enc}})_i \in \mathbb{R}^{b \times l \times h_e} \\
(24) \quad & H_{\text{res}}^{(2)} = \text{ConcatAddNorm}(h_1, \dots, h_n) \in \mathbb{R}^{b \times l \times h_e} \\
(25) \quad & H_{\mathcal{F}} = \mathcal{F}_{\text{pos}}(H_{\text{res}}^{(2)}) \in \mathbb{R}^{b \times l \times h_e} \\
(26) \quad & H_{\text{res}}^{(3)} = \text{LayerNorm}(H_{\mathcal{F}} + H_{\text{res}}^{(2)}) \in \mathbb{R}^{b \times l \times h_e} \\
(27) \quad & Z = W^{z} H_{\text{res}}^{(3)} \in \mathbb{R}^{b \times l \times h_{\hat{y}}} \\
(28) \quad & \hat{Y} = \text{softmax}(Z) \in \mathbb{R}^{b \times l \times h_{\hat{y}}}
\end{aligned}
$$

Note that this time, the inputs to the decoder come from the target sequence embeddings, $Y_e$.

Just as in the encoder, (15 - 20) project the inputs into the attention space. (21) computes masked multihead self attention, where masking is applied to prevent the token at time $t$ from attending to any future token at position $t + n < l$, $n \geq 1$.

We also need to create what PyTorch defines as the key_padding_mask, essentially enabling the transformer to ignore the special <pad> tokens, which should contribute nothing to the training of the model.

Coded out, we can create masks as:

def causal_mask(in_seq_len, target_seq_len):
    # nn.MultiheadAttention's boolean attn_mask marks positions that are NOT allowed to attend
    # with True, so we mask everything above the main diagonal (the future tokens).
    x = torch.ones(size=(target_seq_len, in_seq_len))
    return torch.triu(input=x, diagonal=1).bool()

def padding_mask(tokenized_seqs, pad_token_id):
    # key_padding_mask ignores the keys marked True, so the <pad> positions must be True.
    return (tokenized_seqs == pad_token_id)

which will be fed into nn.MultiheadAttention during the forward pass.
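For a concrete picture (illustrative only, using the functions above and 0 as a hypothetical <pad> token id), True marks positions that attention will ignore:

import torch

toy_seqs = torch.tensor([[5, 9, 2, 0],   # second sequence has two trailing <pad> tokens
                         [7, 3, 0, 0]])

print(causal_mask(in_seq_len=4, target_seq_len=4))
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])

print(padding_mask(tokenized_seqs=toy_seqs, pad_token_id=0))
# tensor([[False, False, False,  True],
#         [False, False,  True,  True]])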

In (23), we perform MHCA, and finally through (25 - 26) we compute $\mathcal{F}_{\text{pos}}$ and normalization, with (27 - 28) producing the output logits and softmax probabilities.

Coded out, the decoder looks as:

class Decoder(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.qkv1 = QKV(d_model=d_model)
        self.masked_multihead_attention = nn.MultiheadAttention(
            embed_dim=d_model, num_heads=num_heads, batch_first=True)
        self.layernorm1 = nn.LayerNorm(normalized_shape=d_model)
        self.qkv2 = QKV(d_model=d_model, in_=False)
        self.cross_multihead_attention = nn.MultiheadAttention(
            embed_dim=d_model, num_heads=num_heads, batch_first=True)
        self.layernorm2 = nn.LayerNorm(normalized_shape=d_model)
        self.positionNN = PositionWiseNN(d_model=d_model)
        self.layernorm3 = nn.LayerNorm(normalized_shape=d_model)

    def forward(self, y, enc_out, csl_msk, pad_msk, enc_pad_msk):
        # Masked multihead self attention over the target sequence + Add & Norm, eqs. (15 - 22)
        qa, ka, va = self.qkv1(y)
        h_masked_attn, _ = self.masked_multihead_attention(
            qa, ka, va, attn_mask=csl_msk, key_padding_mask=pad_msk)
        h_res = self.layernorm1(h_masked_attn + y)

        # Multihead cross attention: Q from the decoder, K and V from the encoder output, eqs. (23 - 24)
        qa, ka, va = self.qkv2(x=(h_res, enc_out, enc_out))
        h_cross_attn, _ = self.cross_multihead_attention(
            qa, ka, va, key_padding_mask=enc_pad_msk)
        h_res2 = self.layernorm2(h_cross_attn + h_res)

        # Position-wise feed-forward network + Add & Norm, eqs. (25 - 26)
        h_pos = self.positionNN(h_res2)
        dec_out = self.layernorm3(h_pos + h_res2)
        return dec_out
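Both blocks rely on a QKV helper that isn't shown in this post. Purely as an assumption about its interface, here is a minimal sketch consistent with how it is called above: a single tensor for self attention, and a (query, key, value) source tuple when in_=False for cross attention.

import torch.nn as nn

class QKV(nn.Module):
    # Hypothetical sketch: three linear projections producing Q, K, V.
    def __init__(self, d_model, in_=True):
        super().__init__()
        self.in_ = in_  # True: self attention (one source); False: cross attention (tuple of sources)
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)

    def forward(self, x):
        # Self attention: Q, K, V all come from x. Cross attention: x = (query_src, key_src, value_src).
        q_src, k_src, v_src = (x, x, x) if self.in_ else x
        return self.w_q(q_src), self.w_k(k_src), self.w_v(v_src)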

Architecture

The final architecture is simply a composition of the above Encoder and Decoder into a single forward pipeline, with the addition of a final linear layer and softmax activation for the prediction of the next token.

I won't be writing out the entire mathematical flow (LaTeX takes extremely long to write out, esp. with all the superscripts / subscripts), as it's essentially the same as the above, just pieced together.

import torch
import torch.nn as nn

from blocks import PositionalEncoding, Encoder, Decoder
from ops import causal_mask, padding_mask

class Transformer(nn.Module):
    def __init__(self, dropout_p, d_model, max_len, num_heads, Y_tokenized_seqs,
                 X_tokenized_seqs, pad_token_id, n,
                 device=('cuda' if torch.cuda.is_available() else 'mps')):
        super().__init__()
        self.device = device
        self.Y_tokenized_seqs = Y_tokenized_seqs
        self.X_tokenized_seqs = X_tokenized_seqs
        self.pad_token_id = pad_token_id
        self.n = n  # num of total encoder:decoder blocks
        self.pe = PositionalEncoding(dropout_p=dropout_p, d_model=d_model, max_len=max_len).to(device)
        self.encoders = nn.ModuleList([Encoder(d_model=d_model, num_heads=num_heads) for _ in range(n)]).to(device)
        self.decoders = nn.ModuleList([Decoder(d_model=d_model, num_heads=num_heads) for _ in range(n)]).to(device)

    def forward(self, x, y):
        target_seq_len = y.size(1)
        in_seq_len = target_seq_len  # as masked softmax is for self attention, not cross attention.

        csl_msk = causal_mask(in_seq_len=in_seq_len, target_seq_len=target_seq_len).to(self.device)
        dec_pad_msk = padding_mask(tokenized_seqs=self.Y_tokenized_seqs, pad_token_id=self.pad_token_id).to(self.device)
        enc_pad_msk = padding_mask(tokenized_seqs=self.X_tokenized_seqs, pad_token_id=self.pad_token_id).to(self.device)

        x_p = x + self.pe(x).to(self.device)  # encoder input
        y_p = y + self.pe(y).to(self.device)  # decoder input

        # Chain the encoder and decoder stacks: each block consumes the previous block's output.
        enc_out, dec_out = x_p, y_p
        for i in range(self.n):
            enc_out = self.encoders[i](enc_out, enc_pad_msk)
            dec_out = self.decoders[i](dec_out, enc_out, csl_msk, dec_pad_msk, enc_pad_msk)

        return dec_out
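As mentioned above, the final linear layer and softmax aren't part of the Transformer class as written; a hypothetical end-to-end usage sketch (all dimensions, the embedding layer, and the pad token id of 0 are assumptions for illustration) might look like:

import torch
import torch.nn as nn

d_model, vocab_size, num_heads, n_blocks = 512, 10000, 8, 6
b, l = 2, 16

X_tok = torch.randint(1, vocab_size, (b, l))   # tokenized source sequences (no <pad> tokens here)
Y_tok = torch.randint(1, vocab_size, (b, l))   # tokenized target sequences
embed = nn.Embedding(vocab_size, d_model)      # shared token embedding, assumed for this sketch

model = Transformer(dropout_p=0.1, d_model=d_model, max_len=l, num_heads=num_heads,
                    Y_tokenized_seqs=Y_tok, X_tokenized_seqs=X_tok,
                    pad_token_id=0, n=n_blocks, device='cpu')

dec_out = model(embed(X_tok), embed(Y_tok))        # (b, l, d_model)
to_vocab = nn.Linear(d_model, vocab_size)          # final linear layer, eq. (27)
Y_hat = torch.softmax(to_vocab(dec_out), dim=-1)   # (b, l, vocab_size), eq. (28)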