Transformers
Juan Vera
January 2025
Introduced in "Attention Is All You Need" by Vaswani et al., Transformers originally took the form of an encoder-decoder architecture, with the encoder containing a MultiHead Self Attention block and a position-wise feed-forward network, and the decoder containing a Masked MultiHead Self Attention block, a Cross MultiHead Attention block, and another position-wise feed-forward network.
The encoder takes in the input sequence embeddings and extracts their representation through self attention pooling. The decoder, on the other hand, takes in the target sequence embeddings and extracts their representation in the same manner, through self attention pooling.
If you recall, multihead attention can be expressed as:

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)W^O, \qquad \mathrm{head}_i = \mathrm{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$

where

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$
Note that this can easily be parallelized as a matrix multiplication as is, or as a batched matrix multiplication (BMM) if we include the batch dimension and the sequence length dimension, i.e., inputs of shape $(B, T, d_{\mathrm{model}})$.
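As a quick illustration (the function and variable names here are my own, not from the code later in this post), the attention inside each head reduces to a pair of batched matrix multiplications once the batch and head dimensions are stacked in front:

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    # Q, K, V: (B, num_heads, T, d_k) — the batch and head dimensions ride along
    # for free through the batched matrix multiplications.
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5         # (B, num_heads, T, T)
    if mask is not None:
        scores = scores.masked_fill(mask, float('-inf'))  # True = not allowed to attend
    return F.softmax(scores, dim=-1) @ V                  # (B, num_heads, T, d_k)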
Multihead attention is explicitly called multihead self-attention when the inputs to the attention mechanism all come from the same source, i.e., $Q$, $K$, and $V$ are all projections of the same input sequence.
Multihead cross attention refers to the situation in which we take $Q$ and $K$, $V$ to come from different input sequence sources.
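As a quick demonstration with PyTorch's nn.MultiheadAttention (shapes chosen arbitrarily), the distinction is simply which tensors are passed as query, key, and value:

import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
x = torch.randn(4, 10, 512)     # one sequence:     (B, T_x, d_model)
y = torch.randn(4, 7, 512)      # another sequence: (B, T_y, d_model)

self_attn, _ = mha(x, x, x)     # self attention:  Q, K, V all come from x
cross_attn, _ = mha(y, x, x)    # cross attention: Q from y, K and V from x → (4, 7, 512)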
In the transformer architecture, the two recurring sub-layers are multihead self attention and the position-wise feed-forward network.
The position-wise feed-forward network is simply a feed forward neural network which takes in input vectors position-wise, meaning that for a sequence of length $T$, it is applied to each token representation independently and provides a single output vector per token.
In the transformer architecture, the position-wise feed-forward network consists of two layers: the first layer projects the inputs from $d_{\mathrm{model}}$ into a higher-dimensional space of width $d_{\mathrm{ff}}$ (2048 in the original paper, with $d_{\mathrm{model}} = 512$). The second layer projects the outputs of the first layer back to $d_{\mathrm{model}}$.
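The PositionWiseNN module used in the code below isn't written out in this post; a minimal sketch consistent with this description might look like the following (the default $d_{\mathrm{ff}} = 4 \cdot d_{\mathrm{model}}$ is my assumption):

import torch.nn as nn

class PositionWiseNN(nn.Module):
    def __init__(self, d_model, d_ff=None):
        super().__init__()
        d_ff = d_ff or 4 * d_model
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),    # project up into the higher-dimensional space
            nn.ReLU(),
            nn.Linear(d_ff, d_model),    # project back down to d_model
        )

    def forward(self, x):
        # x: (B, T, d_model) — the same two layers are applied to every position independently
        return self.net(x)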
Throughout the architecture we can find residual connections: connecting the input and output of the encoder's MHSA, of the decoder's MHSA and MHCA (multihead cross attention), and, in both the encoder and decoder, connecting the input and output of the position-wise feed-forward network.
Recall that a residual connection consists of $y = f(x) + \mathrm{id}(x)$, where $\mathrm{id}$ is the identity transformation.
Encoder
The encoder is structured as follows: positional encoding is first applied to the input embeddings using $\sin$ and $\cos$ functions, a set of linear projections maps the inputs into the attention space (producing $Q$, $K$, and $V$), multihead self attention is computed and the result is added to the input and layer-normalized, and finally we feed the result through the position-wise feed-forward network (again followed by a residual connection and layer normalization) for the final output of the encoder.
The LayerNorm layers perform layer normalization, computed over the embedding dimension of each token representation, i.e., over $d_{\mathrm{model}}$.
Unlike BatchNorm, we don't use shared statistics over different samples, and unlike layer normalization over the full $T \times d_{\mathrm{model}}$ matrix of a sequence, we don't compute statistics over the entire sequence length (assuming inputs of shape $(B, T, d_{\mathrm{model}})$).
The mean and variance statistics, $\mu$ and $\sigma^2$, are simply computed for each token representation in the sequence.
Note that the layer normalization sub-layers embed the residual connection within them, as $\mathrm{LayerNorm}(x + \mathrm{SubLayer}(x))$.
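A small sanity check (shapes arbitrary) confirms that nn.LayerNorm(d_model) computes its statistics per token, over the embedding dimension only:

import torch
import torch.nn as nn

B, T, d_model = 2, 5, 8
x = torch.randn(B, T, d_model)

ln = nn.LayerNorm(d_model)                         # normalizes over the last dimension only
mu = x.mean(dim=-1, keepdim=True)                  # (B, T, 1): one mean per token
var = x.var(dim=-1, unbiased=False, keepdim=True)  # (B, T, 1): one variance per token
manual = (x - mu) / torch.sqrt(var + ln.eps)

print(torch.allclose(ln(x), manual, atol=1e-5))    # True (with the default affine params γ=1, β=0)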
Coded out, the encoder looks as follows:
import torch.nn as nn

# QKV and PositionWiseNN are helper modules defined elsewhere in the codebase:
# QKV produces the query / key / value projections, and PositionWiseNN is the
# two-layer position-wise feed-forward network.
class Encoder(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.qkv = QKV(d_model=d_model)
        self.multihead_attention = nn.MultiheadAttention(
            embed_dim=d_model, num_heads=num_heads, batch_first=True)
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.positionNN = PositionWiseNN(d_model=d_model)

    def forward(self, x, pad_msk):
        """Input: (BatchSize, SequenceLen, d_model) → Output: (BatchSize, SequenceLen, d_model)"""
        qa, ka, va = self.qkv(x)                                    # project into the attention space
        h_self_attn, _ = self.multihead_attention(qa, ka, va, key_padding_mask=pad_msk)
        h_res = self.layer_norm1(h_self_attn + x)                   # residual connection + LayerNorm
        enc_out = self.layer_norm2(self.positionNN(h_res) + h_res)  # FFN + residual + LayerNorm
        return enc_out
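The QKV helper isn't written out here either. A hypothetical sketch consistent with how it's called (a single tensor for self attention, and, as the decoder below shows, a (query source, key source, value source) tuple when constructed with in_=False) could be:

import torch.nn as nn

class QKV(nn.Module):
    # Hypothetical sketch — the real implementation may differ. Three linear maps
    # that project their inputs into the attention space.
    def __init__(self, d_model, in_=True):
        super().__init__()
        self.in_ = in_  # True: one source (self attention); False: a tuple of sources (cross attention)
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

    def forward(self, x):
        q_src, k_src, v_src = (x, x, x) if self.in_ else x
        return self.W_q(q_src), self.W_k(k_src), self.W_v(v_src)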
Decoder
The decoder is structured similarly to the encoder, with an additional cross attention sub-layer between the masked self attention and the feed-forward network.
Note that this time, the inputs to the decoder come from the target sequence embeddings.
Just as in the encoder, a set of linear projections maps the inputs into the attention space. The first sub-layer computes masked multihead self attention, where masking is applied to prevent the token at time step $t$ from attending to any token at a later time step $t' > t$.
We also need to create what PyTorch defines as a key_padding_mask, essentially enabling the transformer to ignore the special padding tokens, which should contribute nothing to the training of the model.
Coded out, we can create the masks as follows:
import torch

def causal_mask(in_seq_len, target_seq_len):
    # nn.MultiheadAttention expects True at positions that are NOT allowed to attend,
    # so we invert the lower-triangular matrix to mask out future positions.
    x = torch.ones(size=(target_seq_len, in_seq_len))
    return ~torch.tril(input=x, diagonal=0).bool()

def padding_mask(tokenized_seqs, pad_token_id):
    # key_padding_mask expects True at the positions that should be ignored (the padding tokens).
    return (tokenized_seqs == pad_token_id).bool()
These masks will be fed into nn.MultiheadAttention during the forward pass (attn_mask for the causal mask, key_padding_mask for the padding mask).
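A quick usage example (the pad token id of 0 is just an assumption for the demo):

import torch

seqs = torch.tensor([[5, 7, 9, 0, 0]])                   # one tokenized sequence, padded with 0s
csl = causal_mask(in_seq_len=5, target_seq_len=5)        # (target_seq_len, in_seq_len)
pad = padding_mask(tokenized_seqs=seqs, pad_token_id=0)  # (B, seq_len)

print(csl.shape, pad.shape)  # torch.Size([5, 5]) torch.Size([1, 5])
print(pad)                   # tensor([[False, False, False,  True,  True]]) — True = ignore this key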
In the second sub-layer, we perform MHCA, where the queries come from the decoder's masked self attention output and the keys and values come from the encoder output; finally, we feed the result through the position-wise feed-forward network, followed by residual addition and layer normalization.
Coded out, the decoder looks as follows:
class Decoder(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        # masked multihead self attention over the target sequence
        self.qkv1 = QKV(d_model=d_model)
        self.masked_multihead_attention = nn.MultiheadAttention(
            embed_dim=d_model, num_heads=num_heads, batch_first=True)
        self.layernorm1 = nn.LayerNorm(normalized_shape=d_model)
        # multihead cross attention over the encoder output
        self.qkv2 = QKV(d_model=d_model, in_=False)
        self.cross_multihead_attention = nn.MultiheadAttention(
            embed_dim=d_model, num_heads=num_heads, batch_first=True)
        self.layernorm2 = nn.LayerNorm(normalized_shape=d_model)
        # position-wise feed-forward network
        self.positionNN = PositionWiseNN(d_model=d_model)
        self.layernorm3 = nn.LayerNorm(normalized_shape=d_model)

    def forward(self, y, enc_out, csl_msk, pad_msk, enc_pad_msk):
        """Input: (BatchSize, TargetSeqLen, d_model) → Output: (BatchSize, TargetSeqLen, d_model)"""
        # masked multihead self attention + residual + LayerNorm
        qa, ka, va = self.qkv1(y)
        h_masked_attn, _ = self.masked_multihead_attention(
            qa, ka, va, attn_mask=csl_msk, key_padding_mask=pad_msk)
        h_res = self.layernorm1(h_masked_attn + y)
        # cross attention: queries from the decoder, keys / values from the encoder output
        qa, ka, va = self.qkv2(x=(h_res, enc_out, enc_out))
        h_cross_attn, _ = self.cross_multihead_attention(
            qa, ka, va, key_padding_mask=enc_pad_msk)
        h_res2 = self.layernorm2(h_cross_attn + h_res)
        # feed-forward network + residual + LayerNorm
        h_pos = self.positionNN(h_res2)
        dec_out = self.layernorm3(h_pos + h_res2)
        return dec_out
Architecture
The final architecture is simply a composition of the above Encoder and Decoder blocks into a single forward pipeline, with the addition of a final Linear layer and Softmax activation for the prediction of the next token.
I won't be writing out the entire mathematical flow (LaTeX takes extremely long to write out, especially with all the superscripts / subscripts), as it's essentially the same as the above, just pieced together.
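One block that hasn't been written out is PositionalEncoding. A minimal sinusoidal sketch, assuming the block returns just the (dropout-regularized) encoding so the caller can compute x + self.pe(x) as in the forward pass below, might be:

import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    # Hypothetical sketch (assumes an even d_model); the real blocks.py version may differ.
    def __init__(self, dropout_p, d_model, max_len):
        super().__init__()
        self.dropout = nn.Dropout(dropout_p)
        pos = torch.arange(max_len).unsqueeze(1)                 # (max_len, 1)
        div = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(pos * div)                       # even dimensions: sin
        pe[:, 1::2] = torch.cos(pos * div)                       # odd dimensions:  cos
        self.register_buffer('pe', pe.unsqueeze(0))              # (1, max_len, d_model)

    def forward(self, x):
        # x: (B, T, d_model) → returns the encoding only; the caller adds it to x
        return self.dropout(self.pe[:, :x.size(1)])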
import torch
import torch.nn as nn

from blocks import PositionalEncoding, Encoder, Decoder
from ops import causal_mask, padding_mask

class Transformer(nn.Module):
    def __init__(self, dropout_p, d_model, max_len, num_heads, Y_tokenized_seqs,
                 X_tokenized_seqs, pad_token_id, n,
                 device=('cuda' if torch.cuda.is_available() else 'mps')):
        super().__init__()
        self.device = device
        self.Y_tokenized_seqs = Y_tokenized_seqs
        self.X_tokenized_seqs = X_tokenized_seqs
        self.pad_token_id = pad_token_id
        self.n = n  # number of stacked encoder / decoder blocks
        self.pe = PositionalEncoding(dropout_p=dropout_p, d_model=d_model, max_len=max_len).to(device)
        self.encoders = nn.ModuleList(
            [Encoder(d_model=d_model, num_heads=num_heads) for _ in range(n)]).to(device)
        self.decoders = nn.ModuleList(
            [Decoder(d_model=d_model, num_heads=num_heads) for _ in range(n)]).to(device)

    def forward(self, x, y):
        target_seq_len = y.size(1)
        in_seq_len = target_seq_len  # the causal mask is for self attention, not cross attention
        csl_msk = causal_mask(in_seq_len=in_seq_len, target_seq_len=target_seq_len).to(self.device)
        dec_pad_msk = padding_mask(tokenized_seqs=self.Y_tokenized_seqs, pad_token_id=self.pad_token_id).to(self.device)
        enc_pad_msk = padding_mask(tokenized_seqs=self.X_tokenized_seqs, pad_token_id=self.pad_token_id).to(self.device)
        x_p = x + self.pe(x).to(self.device)  # encoder input
        y_p = y + self.pe(y).to(self.device)  # decoder input
        # Stack the n encoder / decoder blocks, chaining each block's output into the next.
        # Here decoder block i attends to encoder block i's output; in the original paper,
        # every decoder layer attends to the final encoder output.
        for i in range(self.n):
            enc_out = self.encoders[i](x_p if i == 0 else enc_out, enc_pad_msk)
            dec_out = self.decoders[i](y_p if i == 0 else dec_out, enc_out, csl_msk, dec_pad_msk, enc_pad_msk)
        # dec_out would then go through the final linear layer + softmax for next-token prediction.
        return dec_out
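Finally, a hypothetical usage sketch, assuming the helper modules behave as described above. Token ids are embedded outside the model (forward() expects (B, T, d_model) embeddings), and the batch size, sequence length, and vocabulary size below are arbitrary:

import torch
import torch.nn as nn

B, T, d_model, vocab_size = 4, 16, 512, 1000
X_ids = torch.randint(1, vocab_size, (B, T))   # source token ids (0 reserved for the pad token)
Y_ids = torch.randint(1, vocab_size, (B, T))   # target token ids

embed = nn.Embedding(vocab_size, d_model)
model = Transformer(dropout_p=0.1, d_model=d_model, max_len=T, num_heads=8,
                    Y_tokenized_seqs=Y_ids, X_tokenized_seqs=X_ids,
                    pad_token_id=0, n=6, device='cpu')

dec_out = model(embed(X_ids), embed(Y_ids))    # → (B, T, d_model)
to_vocab = nn.Linear(d_model, vocab_size)      # the final linear layer described above
logits = to_vocab(dec_out)                     # softmax / cross-entropy would follow during training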