The magic of Transformer attention is that it doesn’t just look at the current word; it can look at any word in the input sequence, no matter how far away, and decide how relevant it is.

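Let’s see it in action. Imagine we’re translating "The cat sat on the mat" into French. When the model is generating the French word for "mat," it needs to know what "mat" refers to in context. The snippet below shows self-attention, the mechanism that lets each English word gather that context from every other word in the sentence.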

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # fix the random embeddings so runs are reproducible

# Example inputs
batch_size = 1
sequence_length = 6  # one position per word: "The", "cat", "sat", "on", "the", "mat"
embedding_dim = 8

# Simulate input embeddings for "The cat sat on the mat".
# These are random vectors, so the weights printed below illustrate the
# mechanics only; they do not reflect real linguistic relationships.
# Shape: (batch_size, sequence_length, embedding_dim)
input_embeddings = torch.randn(batch_size, sequence_length, embedding_dim)

# In a real Transformer, these would come from learned linear layers (W_Q, W_K, W_V)
# For simplicity, we'll use the input embeddings directly as Q, K, V
query = input_embeddings
key = input_embeddings
value = input_embeddings

# Calculate attention scores
# Dot product of every query with every key.
# Shape: (batch_size, sequence_length, sequence_length)
attention_scores = torch.matmul(query, key.transpose(-2, -1))

# Scale the scores by sqrt(d_k); with a single head, d_k == embedding_dim
scale_factor = embedding_dim ** 0.5
scaled_attention_scores = attention_scores / scale_factor

# Apply softmax over the key axis to get attention weights
# These weights sum to 1 for each query position.
# Shape: (batch_size, sequence_length, sequence_length)
attention_weights = F.softmax(scaled_attention_scores, dim=-1)

# Apply weights to the Value
# This is the output of the self-attention mechanism.
# Each output vector is a weighted sum of the input value vectors.
# Shape: (batch_size, sequence_length, embedding_dim)
attention_output = torch.matmul(attention_weights, value)

print("Input Embeddings Shape:", input_embeddings.shape)
print("Attention Weights Shape:", attention_weights.shape)
print("Attention Output Shape:", attention_output.shape)

# Let's look at the attention weights for the last word ("mat")
# This shows how much attention the model pays to each word when processing "mat"
print("\nAttention weights for 'mat' (last word):")
print(attention_weights[0, -1, :])
```

This code snippet demonstrates the core of self-attention. We have query, key, and value vectors (derived from the input embeddings). The process involves:

  1. Scoring: Calculating the dot product between each query vector and all key vectors. This gives us a raw score indicating how "compatible" a query is with each key.
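  2. Scaling: Dividing these scores by the square root of embedding_dim (the key dimension, d_k, since this example uses a single head). When vector components have roughly unit variance, the variance of a dot product grows with the dimension, and large scores push the softmax toward a nearly one-hot output whose gradients are vanishingly small. Scaling keeps the scores in a range where the softmax still trains well.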
  3. Weighting: Applying a softmax function to the scaled scores. This turns the scores into probability-like weights that sum to 1. These are the attention weights. A high weight means a particular input word is highly relevant to the current word being processed.
  4. Output: Multiplying these attention weights by the value vectors, which amounts to a weighted sum of the value vectors for each position. The result is a context-aware representation of each position that incorporates information from the rest of the sequence in proportion to its relevance.
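
The four steps can be wrapped in a small function and checked against PyTorch’s built-in `torch.nn.functional.scaled_dot_product_attention` (available in PyTorch 2.0 and later, with a default scale of 1/sqrt(d_k)). This is a minimal sketch: `manual_attention` is a hypothetical helper name, the random tensors and the choice d_k = 64 are illustrative, and the last lines rerun the function with `scale=1.0` to show how skipping step 2 sharpens the softmax.

```python
import torch
import torch.nn.functional as F

def manual_attention(query, key, value, scale=None):
    # 1. Scoring: dot product of every query with every key.
    scores = torch.matmul(query, key.transpose(-2, -1))
    # 2. Scaling: default to 1 / sqrt(d_k), where d_k is the last dimension.
    if scale is None:
        scale = 1.0 / query.size(-1) ** 0.5
    scores = scores * scale
    # 3. Weighting: softmax over the key axis, so each query's weights sum to 1.
    weights = F.softmax(scores, dim=-1)
    # 4. Output: weighted sum of the value vectors.
    return torch.matmul(weights, value), weights

torch.manual_seed(0)
d_k = 64
q = torch.randn(1, 6, d_k)
k = torch.randn(1, 6, d_k)
v = torch.randn(1, 6, d_k)

out, weights = manual_attention(q, k, v)

# PyTorch 2.0+ exposes the same computation as a single (possibly fused) call.
reference = F.scaled_dot_product_attention(q, k, v)
print("matches built-in:", torch.allclose(out, reference, atol=1e-5))
print("rows sum to 1:   ", torch.allclose(weights.sum(dim=-1), torch.ones(1, 6)))

# Skipping the scaling step (scale=1.0) pushes the softmax toward one-hot.
_, unscaled = manual_attention(q, k, v, scale=1.0)
print("largest weight, scaled:  ", weights.max().item())
print("largest weight, unscaled:", unscaled.max().item())
```

Because multiplying all scores by a factor larger than 1 can only increase the largest softmax weight in each row, the unscaled run always reports a larger maximum weight than the scaled one.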

Attention addresses limitations of traditional Recurrent Neural Networks (RNNs) and Convolutional Neural Networks (CNNs) on long sequences. RNNs struggle with long-range dependencies because gradients vanish or explode as they propagate through many time steps, and their step-by-step processing cannot be parallelized across the sequence. A CNN layer sees only a fixed-size window, so relating distant positions requires stacking many layers. Attention lets the model directly access and weigh information from any position in the input sequence, regardless of distance, in a single layer, which makes complex long-range relationships much easier to capture.
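
One way to see the receptive-field difference is to backpropagate from the first output position and check which input positions receive a nonzero gradient. This sketch assumes a single `nn.Conv1d` layer with kernel size 3 as a stand-in for "a CNN", and reuses the simplified Q = K = V self-attention from above; the helper names (`self_attention`, `conv_layer`, `influence_on_first_position`) are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
sequence_length, embedding_dim = 10, 8
x = torch.randn(1, sequence_length, embedding_dim, requires_grad=True)

def self_attention(inputs):
    # Simplified single-head self-attention (Q = K = V = inputs), as in the snippet above.
    scores = torch.matmul(inputs, inputs.transpose(-2, -1)) / embedding_dim ** 0.5
    return torch.matmul(F.softmax(scores, dim=-1), inputs)

# One 1-D convolution with kernel size 3; padding=1 keeps the sequence length unchanged.
conv = nn.Conv1d(embedding_dim, embedding_dim, kernel_size=3, padding=1)

def conv_layer(inputs):
    # Conv1d expects (batch, channels, length), so swap the last two axes and back.
    return conv(inputs.transpose(1, 2)).transpose(1, 2)

def influence_on_first_position(layer):
    # Backpropagate from output position 0 and return, per input position,
    # the gradient norm. A zero norm means that input cannot affect position 0.
    x.grad = None
    layer(x)[0, 0].sum().backward()
    return x.grad[0].norm(dim=-1)

attention_reach = influence_on_first_position(self_attention)
conv_reach = influence_on_first_position(conv_layer)
print("attention:", (attention_reach > 0).tolist())
print("conv:     ", (conv_reach > 0).tolist())
```

Every input position influences the first attention output, while the convolution’s first output depends only on positions 0 and 1; each additional kernel-size-3 layer would widen that window by one position on each side.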

The main levers you control are the embedding dimension and the number of attention heads (in multi-head attention, a common extension). A larger embedding_dim gives each token a richer representation at a higher compute cost. More heads let the model attend to different aspects of the input simultaneously, capturing diverse relationships; in the original Transformer design, embedding_dim is split evenly across the heads, so each head works in embedding_dim / num_heads dimensions.
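
PyTorch’s `torch.nn.MultiheadAttention` illustrates both levers. The sketch below assumes the same sizes as the earlier snippet plus an illustrative `num_heads = 2`; unlike that snippet, this module applies its own learned W_Q, W_K, W_V and an output projection. Passing `average_attn_weights=False` returns one attention map per head instead of their average.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, sequence_length, embedding_dim, num_heads = 1, 6, 8, 2

# embed_dim must be divisible by num_heads; each head gets embedding_dim // num_heads dims.
mha = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=num_heads, batch_first=True)
print(mha.head_dim)  # 4

x = torch.randn(batch_size, sequence_length, embedding_dim)
# Self-attention: the same tensor serves as query, key and value.
output, weights = mha(x, x, x, average_attn_weights=False)

print(output.shape)   # torch.Size([1, 6, 8])
print(weights.shape)  # torch.Size([1, 2, 6, 6]), one attention map per head
```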

What most people don’t realize is that the key and value matrices don’t have to be derived from the same source as the query. In cross-attention, for instance, the query comes from one sequence (e.g., the decoder’s current state), while the key and value come from another (e.g., the encoder’s output). This allows one sequence to selectively attend to information in another.
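
Cross-attention needs only one change to the earlier code: the queries come from one tensor, and the keys and values from another. In this sketch, `cross_attention` is a hypothetical helper, the learned projections are omitted as in the original snippet, and the lengths (4 decoder positions, 6 encoder positions) are arbitrary. The two sequences may differ in length, so the weight matrix is no longer square.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
embedding_dim = 8
decoder_states = torch.randn(1, 4, embedding_dim)  # e.g. 4 target-side positions so far
encoder_output = torch.randn(1, 6, embedding_dim)  # e.g. 6 encoded source words

def cross_attention(query_source, kv_source):
    # Queries from one sequence; keys and values from the other.
    query = query_source
    key = value = kv_source
    scores = torch.matmul(query, key.transpose(-2, -1)) / embedding_dim ** 0.5
    weights = F.softmax(scores, dim=-1)  # (1, query_len, kv_len)
    return torch.matmul(weights, value), weights

output, weights = cross_attention(decoder_states, encoder_output)
print(weights.shape)  # torch.Size([1, 4, 6]): each decoder position weighs 6 encoder positions
print(output.shape)   # torch.Size([1, 4, 8]): one output per decoder position
```

With `nn.MultiheadAttention`, the same pattern is `mha(decoder_states, encoder_output, encoder_output)`.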

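The next concept you’ll likely grapple with is efficiency. Standard attention computes a score for every query-key pair, so its time and memory costs grow quadratically with sequence length. Making the mechanism practical for very long sequences leads to architectures like Longformer or Reformer.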
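
A quick calculation shows why that quadratic growth matters. This sketch assumes the full score matrix is materialized in float32 (4 bytes per entry), as in the snippets above, and counts a single head, batch element, and layer; `attention_matrix_mib` is a hypothetical helper.

```python
def attention_matrix_mib(sequence_length, bytes_per_element=4):
    # One score per (query, key) pair: sequence_length ** 2 entries
    # for a single head, batch element, and layer.
    entries = sequence_length * sequence_length
    return entries * bytes_per_element / 2**20

for n in (512, 4096, 32768):
    print(f"{n:>6} tokens -> {attention_matrix_mib(n):>6.0f} MiB")
```

Going from 512 to 32,768 tokens (64 times longer) raises the matrix from 1 MiB to 4,096 MiB, a 4,096-fold increase.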
