The most surprising thing about PyTorch’s Dataset and DataLoader is how little they actually do for you by default. They’re primarily organizational tools built on Python’s indexing (__getitem__, __len__) and iteration protocols, and they make data handling feel more declarative.

Let’s see this in action. Imagine you have a directory of image files and a directory of corresponding text files, with each pair sharing the same base name (e.g., sample_001.png and sample_001.txt). You want to load these pairs for a PyTorch model.

import os
import shutil

import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class ImageTextDataset(Dataset):
    def __init__(self, image_dir, text_dir, transform=None):
        self.image_dir = image_dir
        self.text_dir = text_dir
        self.transform = transform
        # Get a sorted list of all .png files (sorting keeps the order consistent).
        # Each image is assumed to have a same-named .txt file in text_dir.
        self.image_files = sorted(f for f in os.listdir(image_dir) if f.endswith('.png'))

    def __len__(self):
        # The total number of samples is the number of image files
        return len(self.image_files)

    def __getitem__(self, idx):
        # Get the base name for the current index
        img_filename = self.image_files[idx]
        base_name = os.path.splitext(img_filename)[0]

        # Construct full paths; the text file shares the image's base name
        img_path = os.path.join(self.image_dir, img_filename)
        txt_path = os.path.join(self.text_dir, base_name + '.txt')

        # Load the image
        image = Image.open(img_path).convert('RGB')  # Ensure 3 channels

        # Load the text
        with open(txt_path, 'r', encoding='utf-8') as f:
            text = f.read()

        # Apply transformations if any
        if self.transform:
            image = self.transform(image)

        # Return the image and text. For a model, you'd typically
        # tokenize and convert text to a tensor here.
        return image, text

# Example usage: create dummy 'images/' and 'texts/' directories for demonstration
os.makedirs('images', exist_ok=True)
os.makedirs('texts', exist_ok=True)

# Create dummy image and text files that share a base name (sample_000, sample_001, ...)
for i in range(5):
    img = Image.new('RGB', (60, 30), color=(i * 50, i * 30, i * 20))
    img.save(f'images/sample_{i:03d}.png')
    with open(f'texts/sample_{i:03d}.txt', 'w', encoding='utf-8') as f:
        f.write(f'This is text for sample {i}.')

# Instantiate the dataset. ToTensor converts each PIL image to a float tensor of
# shape (C, H, W). Without it, the default collate function cannot batch PIL
# images and iterating the DataLoader raises a TypeError.
# You'd typically chain more transforms here with transforms.Compose.
my_dataset = ImageTextDataset(image_dir='images', text_dir='texts',
                              transform=transforms.ToTensor())

# Instantiate the DataLoader
# batch_size=2 means we'll get 2 samples at a time
# shuffle=True means the order will be randomized each epoch
# num_workers=0 means data loading happens in the main process
data_loader = DataLoader(my_dataset, batch_size=2, shuffle=True, num_workers=0)

# Iterate through the DataLoader
print("Iterating through DataLoader:")
for i, (images, texts) in enumerate(data_loader):
    print(f"\n--- Batch {i+1} ---")
    print(f"Image batch shape: {images.shape}")  # (batch_size, C, H, W) = (2, 3, 30, 60)
    print(f"Text batch: {texts}")  # strings stay plain Python strings, not tensors
    if i == 1:  # Just show a couple of batches
        break

# Clean up dummy files
shutil.rmtree('images')
shutil.rmtree('texts')

This ImageTextDataset class is a blueprint. It tells PyTorch how to get a single data sample (__getitem__) and how many samples there are in total (__len__). The DataLoader then takes this blueprint and wraps it with powerful features like batching, shuffling, and parallel loading.
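
To make that contract concrete, here is a minimal sketch. It shows that a Dataset needs nothing beyond __len__ and __getitem__. It also shows that DataLoader will even batch a plain Python list, since a list supports both. The RangeSquares class is a hypothetical toy, not part of PyTorch.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class RangeSquares(Dataset):
    """Hypothetical minimal map-style dataset: sample i is (x, x**2) with x = i."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # How many samples exist
        return self.n

    def __getitem__(self, idx):
        # How to build exactly one sample
        x = torch.tensor([float(idx)])
        return x, x ** 2


ds = RangeSquares(5)
batches = list(DataLoader(ds, batch_size=2))
print([tuple(xb.shape) for xb, _ in batches])  # [(2, 1), (2, 1), (1, 1)]

# DataLoader only relies on len() and indexing, so a plain list also works
plain = list(DataLoader([10, 20, 30, 40, 50], batch_size=2))
print(plain)  # [tensor([10, 20]), tensor([30, 40]), tensor([50])]
```

Because drop_last defaults to False, the last batch simply holds the one leftover sample.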

The core problem this solves is abstracting away the boilerplate of data loading. Instead of writing loops that load individual files and stack them into batches, you define the "what" (how to get one item). The DataLoader handles the "how": fetching many items, optionally in parallel, and batching them.
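
The "how" that DataLoader takes over is roughly the hand-written loop below. This sketch uses in-memory tensors wrapped in torch.utils.data.TensorDataset rather than the file-based dataset above. It compares the manual batches against DataLoader with shuffle=False. The manual_batches function is a hypothetical helper for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

features = torch.arange(10, dtype=torch.float32).reshape(5, 2)
labels = torch.tensor([0, 1, 0, 1, 1])
dataset = TensorDataset(features, labels)


def manual_batches(dataset, batch_size):
    """Hand-written loop: fetch items one at a time, then stack them into batches."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        items = [dataset[i] for i in range(start, stop)]
        xs, ys = zip(*items)  # transpose a list of (x, y) pairs
        yield torch.stack(xs), torch.stack(ys)


by_hand = list(manual_batches(dataset, batch_size=2))
by_loader = list(DataLoader(dataset, batch_size=2, shuffle=False))

# Both approaches produce identical batches
for (xh, yh), (xl, yl) in zip(by_hand, by_loader):
    assert torch.equal(xh, xl) and torch.equal(yh, yl)
print(len(by_hand), len(by_loader))
```

The DataLoader version also gets shuffling, worker processes, and pinned memory by changing a single argument each. The manual loop would have to grow to support any of them.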

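Internally, DataLoader is an iterable rather than an iterator. Calling iter(data_loader), which a for loop does implicitly, returns a fresh iterator object, and each next() call on that iterator produces one batch. To build a batch, the DataLoader draws a list of indices from its sampler, calls your Dataset’s __getitem__ for each index, and passes the resulting list of samples to a collate function. The default collate function stacks tensors into a single batched tensor and recurses into tuples and lists. In our example, assuming a transform such as ToTensor turns each image into a tensor, a batch is a two-element list: an image tensor of shape (batch_size, C, H, W) and the batch’s texts, which stay plain Python strings rather than tensors.

If num_workers > 0, this work moves to worker processes. Each worker has its own copy of the Dataset. It fetches and collates whole batches for the indices the main process assigns it, then sends the finished batches back. The shuffle=True parameter replaces the default sequential sampler with a random one. That sampler draws a new permutation of the indices every time an iterator is created, so each epoch sees the data in a different order.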
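
The sketch below makes that loop visible, using a hypothetical IndexDataset. It steps the iterator by hand and prints the order of two shuffled epochs. It also reimplements one single-process epoch from the loader’s own batch_sampler, dataset, and collate_fn attributes. This is a simplification: the real iterator also handles workers, pinned memory, and prefetching. The seeded torch.Generator is only there to make runs reproducible.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class IndexDataset(Dataset):
    """Hypothetical toy dataset: sample i is the pair (i, i * i) as tensors."""

    def __len__(self):
        return 6

    def __getitem__(self, idx):
        return torch.tensor(idx), torch.tensor(idx * idx)


# Seeded generator so the shuffled orders are reproducible between runs
g = torch.Generator().manual_seed(0)
loader = DataLoader(IndexDataset(), batch_size=4, shuffle=True, generator=g)

# A DataLoader is an iterable: iter() returns a fresh iterator, i.e. a new epoch
it = iter(loader)
idx_batch, sq_batch = next(it)  # what a for loop does on every step
print(idx_batch, sq_batch)


def manual_epoch(loader):
    """Simplified single-process epoch: sample indices, fetch samples, collate."""
    for indices in loader.batch_sampler:  # a list of up to batch_size indices
        samples = [loader.dataset[i] for i in indices]  # one __getitem__ per index
        yield loader.collate_fn(samples)  # default collate stacks the tensors


# Each pass over the loader draws a new permutation of the indices
for epoch in range(2):
    order = [idx.tolist() for idx, _ in loader]
    print(f"epoch {epoch}: {order}")

manual = list(manual_epoch(loader))
```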

The __getitem__ method is where per-sample preprocessing and data augmentation happen. You can pass a torchvision.transforms.Compose object to your Dataset’s __init__ (the transform argument in our example). This lets you chain operations such as random cropping, resizing, normalization, and converting PIL Images to torch.Tensors. Because random augmentations are re-drawn every time an item is loaded, the model sees slightly different versions of each image across epochs, which helps it generalize. For text, the equivalent steps are tokenization, numericalization (mapping tokens to integer IDs), padding to a fixed length, and converting the result to a torch.LongTensor.
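
For the text side, a minimal sketch of those steps inside __getitem__ might look like this. The whitespace tokenizer, the build_vocab helper, and the PAD/UNK ids are hypothetical stand-ins for a real tokenizer and vocabulary. The point is that every sample comes out as a fixed-length torch.LongTensor, so the default collate function can stack them into a batch.

```python
import torch
from torch.utils.data import DataLoader, Dataset

PAD, UNK = 0, 1  # hypothetical ids for the padding and unknown-token entries


def build_vocab(texts):
    """Toy numericalization table: lowercase whitespace tokens -> integer ids."""
    vocab = {"<pad>": PAD, "<unk>": UNK}
    for text in texts:
        for tok in text.lower().split():
            vocab.setdefault(tok, len(vocab))
    return vocab


class TokenizedTextDataset(Dataset):
    def __init__(self, texts, vocab, max_len):
        self.texts = texts
        self.vocab = vocab
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        tokens = self.texts[idx].lower().split()  # tokenization (whitespace split)
        ids = [self.vocab.get(t, UNK) for t in tokens]  # numericalization
        ids = ids[: self.max_len]  # truncate long texts
        ids += [PAD] * (self.max_len - len(ids))  # pad short texts to max_len
        return torch.tensor(ids, dtype=torch.long)  # LongTensor


corpus = ["This is text for sample 0.", "Short one", "A much longer piece of text for sample two"]
vocab = build_vocab(corpus)
text_ds = TokenizedTextDataset(corpus, vocab, max_len=6)
batch = next(iter(DataLoader(text_ds, batch_size=3)))
print(batch.shape, batch.dtype)  # torch.Size([3, 6]) torch.int64
```

Padding every sample to one global max_len keeps the default collate function working, at the cost of wasted positions for short texts.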

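What most people don’t realize is how DataLoader handles the num_workers setting. When num_workers > 0, each worker is a separate process with its own Python interpreter, so it needs its own copy of your Dataset. How that copy is made depends on the multiprocessing start method:

With fork, workers inherit a copy of the Dataset from the parent process.
With spawn (the default on Windows and macOS), the Dataset object is pickled and sent to each worker.

This means your Dataset’s __init__ should avoid storing large or unpicklable objects, such as open file handles or database connections. Instead, create them lazily inside each worker, for example on the first __getitem__ call or in a worker_init_fn. If your Dataset relies on global state or other shared resources, changes made in one process are not visible to the others. Finished batches travel back to the main process through multiprocessing queues. Tensors are passed via shared memory, but everything else (such as the text strings in our example) is pickled and unpickled, which can become a bottleneck if your samples contain large non-tensor objects.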
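
One way to follow that advice is sketched below for a hypothetical LazyLineDataset, which reads one line of a text file per sample. Only picklable data (the path and each line’s byte offset) is built in __init__. The file handle is opened on first access, so each process opens its own, and __getstate__ drops any open handle before pickling. The demo uses num_workers=0 so it runs as a plain script. With workers, the same class should work under the usual if __name__ == "__main__": guard, but treat that as an untested assumption here.

```python
import os
import pickle
import tempfile

from torch.utils.data import DataLoader, Dataset, get_worker_info


class LazyLineDataset(Dataset):
    """Hypothetical dataset that returns one line of a text file per sample."""

    def __init__(self, path):
        # Only plain, picklable data here: the path and each line's byte offset
        self.path = path
        self.offsets = []
        pos = 0
        with open(path, "rb") as f:
            for line in f:
                self.offsets.append(pos)
                pos += len(line)
        self._fh = None  # the file handle is opened lazily, per process

    def __len__(self):
        return len(self.offsets)

    def __getitem__(self, idx):
        if self._fh is None:
            # Runs in whichever process first calls __getitem__ (main or worker)
            self._fh = open(self.path, "rb")
        self._fh.seek(self.offsets[idx])
        return self._fh.readline().decode("utf-8").rstrip("\n")

    def __getstate__(self):
        # Open file objects cannot be pickled; drop the handle so the dataset
        # can still be sent to spawned workers after the main process used it
        state = self.__dict__.copy()
        state["_fh"] = None
        return state


tmp = tempfile.NamedTemporaryFile("wb", suffix=".txt", delete=False)
tmp.write(b"first line\nsecond line\nthird line\n")
tmp.close()

ds = LazyLineDataset(tmp.name)
second = ds[1]  # opens the handle in this (main) process
clone = pickle.loads(pickle.dumps(ds))  # roughly what a spawned worker receives
third = clone[2]  # the copy opens its own handle
batches = [list(b) for b in DataLoader(ds, batch_size=2, num_workers=0)]
worker_info = get_worker_info()  # None outside a DataLoader worker
print(second, third, batches, worker_info)

# Close both handles before deleting the temporary file
for d in (ds, clone):
    if d._fh is not None:
        d._fh.close()
os.remove(tmp.name)
```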

The next concept you’ll likely encounter is how to handle variable-length sequences in text data, which requires custom collate functions within the DataLoader.
