A Transformer can learn dependencies between sequence elements regardless of their distance, a feat traditional RNNs struggle with.

Let’s build a simplified Transformer for sequence-to-sequence tasks, like machine translation, from scratch using PyTorch. We’ll focus on the core components: self-attention, multi-head attention, and the feed-forward network.

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# --- Positional Encoding ---
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch-first sequence.

    Even feature columns hold ``sin(pos * freq)`` and odd columns hold
    ``cos(pos * freq)``, where the frequencies decay geometrically from 1 to
    roughly 1/10000 across the feature dimension.

    Args:
        d_model: Size of the feature dimension of the inputs.
        max_len: Largest sequence length for which encodings are precomputed.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cosine column than sine
        # column, so trim the frequencies to the number of odd columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Kept in the (max_len, 1, d_model) layout; registered as a buffer so
        # it follows the module across devices without being a parameter.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return ``x`` with the encoding of each position added.

        Args:
            x: Tensor whose dimension 1 is the sequence axis and whose last
                dimension is ``d_model`` (i.e. ``(batch, seq_len, d_model)``).

        Raises:
            ValueError: If the sequence is longer than ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(0)}"
            )
        # (seq_len, 1, d_model) -> (1, seq_len, d_model) so that positions line
        # up with the sequence axis of x and broadcast over the batch axis.
        return x + self.pe[:seq_len].transpose(0, 1)

# --- Multi-Head Attention ---
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention computed over several heads in parallel.

    The model dimension is divided evenly between the heads; queries, keys and
    values each get their own linear projection, and the per-head results are
    concatenated and passed through a final linear layer.

    Args:
        d_model: Feature size of the inputs and outputs.
        num_heads: Number of heads; must divide ``d_model``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        # Feature size handled by each individual head.
        self.depth = d_model // num_heads

        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.dense = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        """Reshape to ``(batch_size, num_heads, seq_len, depth)``."""
        per_head = x.view(batch_size, -1, self.num_heads, self.depth)
        return per_head.transpose(1, 2)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Return ``(weighted values, attention weights)``.

        When ``mask`` is given, ``mask * -1e9`` is added to the logits before
        the softmax, so non-zero mask entries mark positions to suppress.
        """
        logits = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(K.size(-1))
        if mask is not None:
            logits += mask * -1e9
        weights = F.softmax(logits, dim=-1)
        return torch.matmul(weights, V), weights

    def forward(self, v, k, q, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Note the argument order: values first, then keys, then queries.

        Returns:
            A tuple ``(output, attention_weights)`` where ``output`` has
            ``d_model`` features per query position.
        """
        batch_size = v.size(0)

        q_heads = self.split_heads(self.wq(q), batch_size)
        k_heads = self.split_heads(self.wk(k), batch_size)
        v_heads = self.split_heads(self.wv(v), batch_size)

        context, attention_weights = self.scaled_dot_product_attention(
            q_heads, k_heads, v_heads, mask
        )

        # Move the head axis back next to the feature axis and merge them.
        merged = context.transpose(1, 2).contiguous().view(
            batch_size, -1, self.d_model
        )
        return self.dense(merged), attention_weights

# --- Feed Forward Network ---
class FeedForwardNetwork(nn.Module):
    """Two linear layers with a ReLU in between, applied to the last dimension.

    Args:
        d_model: Feature size of the input and of the output.
        dff: Width of the hidden layer.
    """

    def __init__(self, d_model, dff):
        super().__init__()
        self.layer1 = nn.Linear(d_model, dff)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(dff, d_model)

    def forward(self, x):
        """Expand to ``dff`` features, apply ReLU, project back to ``d_model``."""
        hidden = self.relu(self.layer1(x))
        return self.layer2(hidden)

# --- Encoder Layer ---
class EncoderLayer(nn.Module):
    """One encoder block: self-attention, then a feed-forward network.

    Each sub-layer's output goes through dropout, is added back to that
    sub-layer's input (residual connection) and is then layer-normalized.

    Args:
        d_model: Feature size of the inputs and outputs.
        num_heads: Number of attention heads.
        dff: Hidden width of the feed-forward network.
        rate: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super().__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForwardNetwork(d_model, dff)

        self.layernorm1 = nn.LayerNorm(d_model)
        self.layernorm2 = nn.LayerNorm(d_model)

        self.dropout1 = nn.Dropout(rate)
        self.dropout2 = nn.Dropout(rate)

    def forward(self, x, mask=None):
        """Run the block on ``x``; ``mask`` is forwarded to the attention layer."""
        # Self-attention: x serves as values, keys and queries alike.
        attended, _ = self.mha(x, x, x, mask)
        after_attention = self.layernorm1(x + self.dropout1(attended))

        transformed = self.ffn(after_attention)
        return self.layernorm2(after_attention + self.dropout2(transformed))

# --- Decoder Layer ---
class DecoderLayer(nn.Module):
    """One decoder block: self-attention, encoder-decoder attention, feed-forward.

    Each of the three sub-layers is followed by dropout, a residual connection
    to that sub-layer's input, and layer normalization.

    Args:
        d_model: Feature size of the inputs and outputs.
        num_heads: Number of attention heads in both attention layers.
        dff: Hidden width of the feed-forward network.
        rate: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super().__init__()
        # mha1 attends over the decoder's own input; mha2 attends over the
        # encoder output.
        self.mha1 = MultiHeadAttention(d_model, num_heads)
        self.mha2 = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForwardNetwork(d_model, dff)

        self.layernorm1 = nn.LayerNorm(d_model)
        self.layernorm2 = nn.LayerNorm(d_model)
        self.layernorm3 = nn.LayerNorm(d_model)

        self.dropout1 = nn.Dropout(rate)
        self.dropout2 = nn.Dropout(rate)
        self.dropout3 = nn.Dropout(rate)

    def forward(self, x, enc_output, look_ahead_mask=None, padding_mask=None):
        """Run the block.

        Args:
            x: Decoder-side input.
            enc_output: Encoder output used as keys and values in the second
                attention layer.
            look_ahead_mask: Mask passed to the decoder self-attention.
            padding_mask: Mask passed to the encoder-decoder attention.
        """
        # Sub-layer 1: self-attention over the decoder input.
        self_attended, _ = self.mha1(x, x, x, look_ahead_mask)
        after_self = self.layernorm1(x + self.dropout1(self_attended))

        # Sub-layer 2: queries come from the decoder, keys/values from the encoder.
        cross_attended, _ = self.mha2(enc_output, enc_output, after_self, padding_mask)
        after_cross = self.layernorm2(after_self + self.dropout2(cross_attended))

        # Sub-layer 3: position-wise feed-forward network.
        transformed = self.ffn(after_cross)
        return self.layernorm3(after_cross + self.dropout3(transformed))

# --- Transformer ---
class Transformer(nn.Module):
    """Encoder-decoder Transformer operating on integer token ids.

    Token id 0 is treated as padding. Masks follow the convention used by the
    attention layers in this file, which add ``mask * -1e9`` to the attention
    logits: a value of 1 marks a position that must NOT be attended to.

    Args:
        num_layers: Number of encoder layers and of decoder layers.
        d_model: Feature size used throughout the model.
        num_heads: Number of attention heads per attention layer.
        dff: Hidden width of the feed-forward networks.
        input_vocab_size: Number of distinct source token ids.
        target_vocab_size: Number of distinct target token ids.
        pe_max_len: Longest sequence supported by the positional encoding.
        rate: Dropout probability.
    """

    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
                 target_vocab_size, pe_max_len, rate=0.1):
        super().__init__()
        self.d_model = d_model

        # Token ids have to be mapped to d_model-sized vectors before the
        # positional encoding and the attention layers can consume them.
        self.input_embedding = nn.Embedding(input_vocab_size, d_model)
        self.target_embedding = nn.Embedding(target_vocab_size, d_model)

        self.encoder = nn.ModuleList([EncoderLayer(d_model, num_heads, dff, rate)
                                     for _ in range(num_layers)])
        self.decoder = nn.ModuleList([DecoderLayer(d_model, num_heads, dff, rate)
                                     for _ in range(num_layers)])

        self.pos_encoding = PositionalEncoding(d_model, pe_max_len)
        self.dropout = nn.Dropout(rate)

        self.final_layer = nn.Linear(d_model, target_vocab_size)

    def create_padding_mask(self, seq):
        """Return a float mask of shape ``(batch, 1, 1, seq_len)``.

        Entries are 1.0 where ``seq`` equals 0 (padding) and 0.0 elsewhere.
        The two singleton axes let the mask broadcast over heads and over
        query positions.
        """
        return torch.eq(seq, 0).unsqueeze(1).unsqueeze(2).float()

    def create_look_ahead_mask(self, size, device=None):
        """Return a ``(size, size)`` mask with 1.0 strictly above the diagonal.

        Row ``i`` therefore blocks every column ``j > i``, i.e. all later
        positions.

        Args:
            size: Target sequence length.
            device: Device for the mask; defaults to the device holding the
                model's parameters.
        """
        if device is None:
            device = self.final_layer.weight.device
        return torch.triu(torch.ones((size, size), device=device), diagonal=1)

    def forward(self, inp, tar):
        """Compute logits over the target vocabulary.

        Args:
            inp: Source token ids, shape ``(batch_size, input_seq_len)``.
            tar: Target token ids, shape ``(batch_size, target_seq_len)``.

        Returns:
            Tensor of shape ``(batch_size, target_seq_len, target_vocab_size)``.
        """
        # Masks are built from the raw ids, before they are embedded.
        enc_padding_mask = self.create_padding_mask(inp)
        # In encoder-decoder attention the keys are encoder positions, so the
        # padding to hide is the padding of the source sequence.
        dec_padding_mask = enc_padding_mask

        look_ahead_mask = self.create_look_ahead_mask(tar.size(1), device=tar.device)
        dec_target_padding_mask = self.create_padding_mask(tar)
        # A position is blocked if it is in the future OR it is padding;
        # broadcasting yields (batch, 1, target_seq_len, target_seq_len).
        combined_mask = torch.maximum(look_ahead_mask, dec_target_padding_mask)

        # Scale embeddings by sqrt(d_model), as in the original Transformer
        # paper, before adding the positional encoding.
        scale = math.sqrt(self.d_model)

        # Encoder
        enc_output = self.dropout(self.pos_encoding(self.input_embedding(inp) * scale))
        for layer in self.encoder:
            enc_output = layer(enc_output, enc_padding_mask)

        # Decoder
        dec_output = self.dropout(self.pos_encoding(self.target_embedding(tar) * scale))
        for layer in self.decoder:
            dec_output = layer(dec_output, enc_output, combined_mask, dec_padding_mask)

        # Project to vocabulary logits.
        return self.final_layer(dec_output)

The core idea is "attention is all you need." Instead of recurrently processing tokens, we allow every token to "attend" to every other token in the sequence.

Self-attention, the heart of the Transformer, works by calculating Query, Key, and Value vectors for each input token. The attention score between two tokens is the dot product of their Query and Key vectors, scaled. These scores determine how much "attention" each token pays to others when generating its new representation.

When we use "multi-head attention," we’re essentially performing self-attention multiple times in parallel with different learned linear projections for Q, K, and V. This allows the model to jointly attend to information from different representation subspaces at different positions. The outputs from these "heads" are concatenated and then linearly projected.

The positional encoding is crucial because self-attention by itself has no notion of order: it is permutation-equivariant, meaning that shuffling the input tokens merely shuffles the outputs in the same way. Without positional information, the model wouldn’t know the order of tokens. We inject information about the relative or absolute position of tokens by adding a unique positional vector to each input embedding. These vectors are typically generated using sine and cosine functions of different frequencies.

The encoder block processes the input sequence. It consists of a multi-head self-attention layer followed by a position-wise feed-forward network. Crucially, both sub-layers use residual connections and layer normalization. The decoder block is similar but has an additional multi-head attention layer that attends to the output of the encoder stack. It also uses a masked self-attention layer to prevent positions from attending to subsequent positions, ensuring that predictions for position i can depend only on the known outputs at positions less than i.

The specific implementation detail that most people overlook is how the masks are constructed and applied. The look_ahead_mask prevents the decoder from "cheating" by looking at future tokens during training, ensuring a causal prediction. The padding_mask is used to ignore padding tokens in both the encoder and decoder, so they don’t influence the attention calculations. These masks are applied before the softmax in the scaled dot-product attention: a large negative value (-1e9 in the code above) is added to the attention scores of masked positions, which acts like negative infinity and drives their softmax weights to effectively zero.

The next concept to explore is how to train this Transformer, including the choice of loss function, optimizer, and learning rate scheduling.

Want structured learning?

Take the full PyTorch course →