Fine-tuning a pre-trained BERT model for text classification is less about training a model from scratch and more about teaching an already capable language model to focus its existing knowledge on your specific task.
Let’s see this in action. Imagine we have a dataset of movie reviews, labeled positive or negative. We want to train BERT to predict the sentiment of new reviews.
from transformers import BertTokenizer, BertForSequenceClassification
import torch
# Load pre-trained model and tokenizer
model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2) # 2 labels for positive/negative
# Example text and labels
texts = ["This movie was fantastic!", "The acting was terrible."]
labels = [1, 0] # 1 for positive, 0 for negative
# Tokenize the texts
encoded_inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
# Forward pass
outputs = model(**encoded_inputs, labels=torch.tensor(labels))
# The loss is calculated automatically if labels are provided
loss = outputs.loss
logits = outputs.logits
print(f"Loss: {loss.item()}")
print(f"Logits: {logits}")
# To get predictions:
predictions = torch.argmax(logits, dim=1)
print(f"Predictions: {predictions}")
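This snippet shows a single forward pass, which is the core of every training step: tokenize your text, feed it to the model along with the correct labels, and the model returns a loss and logits. Because the classification head has not been trained yet, the loss and predictions printed here are essentially arbitrary. Fine-tuning adds a backward pass and an optimizer step, then repeats over the whole dataset. The BertForSequenceClassification class is key here; it adds a classification head on top of the standard BERT architecture.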
The problem BERT fine-tuning solves is the immense cost and data requirement of training a large language model from scratch. BERT, pre-trained on massive amounts of text, already understands grammar, context, and a vast amount of world knowledge. Fine-tuning leverages this pre-existing knowledge, adapting it to a specific downstream task like sentiment analysis, question answering, or named entity recognition with a relatively small dataset and much less computational power.
Internally, BertForSequenceClassification takes the final hidden state of the [CLS] token, which is designed to aggregate sequence-level information, runs it through BERT’s pooler (a dense layer followed by a tanh activation), applies dropout, and passes the result through a linear layer. This linear layer is the "classification head": its weights are randomly initialized, so it is the part of the model that changes most during fine-tuning. The rest of the BERT model’s weights are updated too. By default every parameter shares the optimizer’s single learning rate, and it is the small size of that rate that lets the pre-trained weights adjust subtly without forgetting their general language understanding.
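The path from encoder output to logits can be sketched in plain PyTorch. This is a rough illustration rather than the library’s actual code: the encoder output is replaced by random numbers, the layer names are hypothetical, and the hidden size of 768 and dropout of 0.1 are assumed to match the bert-base-uncased defaults.

```python
import torch
from torch import nn

torch.manual_seed(0)
batch_size, seq_len, hidden_size, num_labels = 2, 6, 768, 2

# Stand-in for the encoder output (random values, not real BERT activations)
last_hidden_state = torch.randn(batch_size, seq_len, hidden_size)

# Pooler: dense layer + tanh applied to the first ([CLS]) position
pooler_dense = nn.Linear(hidden_size, hidden_size)
# Classification head: dropout followed by a single linear layer
dropout = nn.Dropout(p=0.1)
classifier = nn.Linear(hidden_size, num_labels)

cls_state = last_hidden_state[:, 0]                  # shape: (batch, hidden)
pooled_output = torch.tanh(pooler_dense(cls_state))  # shape: (batch, hidden)
logits = classifier(dropout(pooled_output))          # shape: (batch, num_labels)

print(logits.shape)
```

Only the position-0 vector reaches the head, so the other token representations influence the prediction indirectly, through self-attention inside the encoder.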
The levers you control are primarily:
- Learning Rate: Crucial for fine-tuning. Too high, and you’ll destroy the pre-trained weights. Too low, and training will be slow or get stuck. A common range for BERT fine-tuning is 1e-5 to 5e-5.
- Number of Epochs: Typically very few. 2-4 epochs are often sufficient because the model is already so capable. Overfitting is a significant risk with too many epochs.
- Batch Size: Influences gradient stability. Common sizes are 16 or 32. Larger batches can sometimes be more stable but require more memory.
- Optimizer: AdamW (Adam with weight decay) is the standard choice for transformer models.
- Data Preprocessing: Ensuring your text is cleaned and tokenized consistently with the BERT tokenizer is vital. padding=True and truncation=True are essential for handling variable-length sequences.
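To see how these levers fit together, here is a minimal training-loop sketch. To keep it runnable without downloading weights it makes several assumptions: TinyClassifier is a hypothetical stand-in for BertForSequenceClassification (it is not BERT), the data are random token ids, and collate is a hand-written illustration of what padding and truncation do, not the tokenizer’s own implementation. The values for MAX_LEN and weight_decay are illustrative choices too.

```python
import random
import torch
from torch import nn
from torch.utils.data import DataLoader

random.seed(0)
torch.manual_seed(0)

LEARNING_RATE = 2e-5   # inside the 1e-5 to 5e-5 range
NUM_EPOCHS = 3         # a few epochs only
BATCH_SIZE = 16
MAX_LEN = 8            # hypothetical cap for this toy data
PAD_ID = 0             # id used for padding in this sketch

def collate(batch):
    """Truncate each sequence to MAX_LEN, then pad to the longest in the batch."""
    ids = [seq[:MAX_LEN] for seq, _ in batch]
    longest = max(len(seq) for seq in ids)
    input_ids = torch.tensor([s + [PAD_ID] * (longest - len(s)) for s in ids])
    attention_mask = torch.tensor([[1] * len(s) + [0] * (longest - len(s)) for s in ids])
    labels = torch.tensor([label for _, label in batch])
    return input_ids, attention_mask, labels

class TinyClassifier(nn.Module):
    """Hypothetical stand-in model: embedding, masked mean pooling, linear head."""
    def __init__(self, vocab_size=100, hidden=16, num_labels=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden, padding_idx=PAD_ID)
        self.classifier = nn.Linear(hidden, num_labels)

    def forward(self, input_ids, attention_mask):
        mask = attention_mask.unsqueeze(-1).float()
        # Average over real tokens only, so padding does not affect the result
        pooled = (self.embed(input_ids) * mask).sum(1) / mask.sum(1)
        return self.classifier(pooled)

# Synthetic dataset: 64 variable-length sequences with random binary labels
data = [([random.randint(1, 99) for _ in range(random.randint(3, 12))],
         random.randint(0, 1)) for _ in range(64)]

model = TinyClassifier()
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
loss_fn = nn.CrossEntropyLoss()
loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate)

model.train()
epoch_losses = []
for epoch in range(NUM_EPOCHS):
    total = 0.0
    for input_ids, attention_mask, labels in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(input_ids, attention_mask), labels)
        loss.backward()
        optimizer.step()
        total += loss.item()
    epoch_losses.append(total / len(loader))
    print(f"epoch {epoch + 1}: mean loss {epoch_losses[-1]:.4f}")
```

With a real BERT model the loop keeps the same shape: the tokenizer call replaces collate’s manual padding, and passing labels to the model returns the loss directly, as in the forward-pass snippet earlier.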
A common pitfall is assuming that the base BERT model and the classification head are trained with different learning rates. BertForSequenceClassification does not set learning rates at all; that is the optimizer’s job, and an optimizer built from model.parameters() applies one rate to every weight. If you want a lower rate for the base model than for the head, create the parameter groups yourself when constructing torch.optim.AdamW, which gives you finer control. You can also freeze the lower layers of BERT entirely, for instance to save memory or to limit overfitting on a very small dataset, though this is less common for text classification.
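A sketch of explicit parameter groups and layer freezing, under stated assumptions: StandInModel is a hypothetical toy module whose attributes are named bert and classifier to mirror the real model’s top-level names, and the head learning rate of 1e-3 is an illustrative choice rather than a recommendation.

```python
import torch
from torch import nn

class StandInModel(nn.Module):
    """Hypothetical stand-in: `bert` plays the encoder, `classifier` the head."""
    def __init__(self):
        super().__init__()
        self.bert = nn.Sequential(nn.Linear(16, 16), nn.Tanh(), nn.Linear(16, 16))
        self.classifier = nn.Linear(16, 2)

model = StandInModel()

# Freeze the first encoder layer: frozen parameters receive no gradient updates
for param in model.bert[0].parameters():
    param.requires_grad = False

# Split the trainable parameters by name prefix
encoder_params = [p for name, p in model.named_parameters()
                  if name.startswith("bert.") and p.requires_grad]
head_params = [p for name, p in model.named_parameters()
               if name.startswith("classifier.")]

# One optimizer, two groups: a small rate for the encoder, a larger one for the head
optimizer = torch.optim.AdamW(
    [
        {"params": encoder_params, "lr": 2e-5},
        {"params": head_params, "lr": 1e-3},
    ],
    weight_decay=0.01,
)

for group in optimizer.param_groups:
    n_params = sum(p.numel() for p in group["params"])
    print(f"lr={group['lr']}, parameters={n_params}")
```

On the real model the same prefix split should apply, since its submodules are also named bert and classifier; freezing would then target model.bert.embeddings and the first few entries of model.bert.encoder.layer.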
The next concept you’ll likely grapple with is optimizing the fine-tuning process itself, moving beyond basic accuracy to robust evaluation metrics and hyperparameter searching.