Population Based Training (PBT) lets your hyperparameter search evolve during training, not just before it.
Let’s watch PBT in action. We’re going to simulate training a simple MLP on MNIST, using Ray Tune and PBT to evolve both the learning rate and the number of hidden units.
```python
# Written against the Ray 1.x Tune function API
# (tune.run, tune.report, tune.checkpoint_dir).
import json
import os

import ray
from ray import tune
from ray.tune.schedulers import PopulationBasedTraining


def train_mnist_mlp(config, checkpoint_dir=None):
    # Simulate training a simple MLP on MNIST.
    # In a real scenario, this would involve TensorFlow, PyTorch, etc.
    # (With a real model, restoring a checkpoint and then changing
    # hidden_units would also change the shapes of the weights.)
    num_epochs = 10
    learning_rate = config["lr"]
    hidden_units = config["hidden_units"]
    print(f"Training with lr={learning_rate}, hidden_units={hidden_units}")

    # When PBT clones a better trial into this one, Tune passes in that
    # trial's checkpoint: resume from its epoch instead of starting over.
    start_epoch = 0
    if checkpoint_dir:
        with open(os.path.join(checkpoint_dir, "state.json")) as f:
            start_epoch = json.load(f)["epoch"] + 1

    for epoch in range(start_epoch, num_epochs):
        # Simulated metrics: loss shrinks with more epochs, more hidden units
        # and a larger lr; accuracy rises as loss shrinks.
        loss = 1.0 / (epoch + 1) + (100 / hidden_units) + (0.1 / learning_rate)
        accuracy = 1.0 / (1.0 + loss)

        # Checkpoint every epoch: PBT can only clone trials that have one.
        # A real trainable would save model and optimizer state here too.
        with tune.checkpoint_dir(step=epoch) as ckpt_dir:
            with open(os.path.join(ckpt_dir, "state.json"), "w") as f:
                json.dump({"epoch": epoch}, f)

        # Report the time attribute (epoch) and the optimized metric to Tune.
        tune.report(epoch=epoch, loss=loss, accuracy=accuracy)


if __name__ == "__main__":
    ray.init(ignore_reinit_error=True)

    # Initial search space: each trial starts from a random sample.
    search_space = {
        "lr": tune.uniform(0.001, 0.1),
        "hidden_units": tune.randint(32, 256),
    }

    # Configure the Population Based Training scheduler. The metric and mode
    # are passed to tune.run below; also setting them here (or using the
    # deprecated reward_attr) makes tune.run raise an error.
    pbt_scheduler = PopulationBasedTraining(
        time_attr="epoch",  # Unit of progress reported by the trainable
        hyperparam_mutations={
            "lr": tune.uniform(0.001, 0.1),
            "hidden_units": tune.randint(32, 256),
        },
        # Every 2 epochs each trial is checked: if it is in the bottom
        # quantile, it clones a top-quantile trial and perturbs its config.
        perturbation_interval=2,
        resample_probability=0.25,  # Chance to resample a value instead of nudging it
        quantile_fraction=0.25,  # Bottom 25% copy from the top 25%
    )

    analysis = tune.run(
        train_mnist_mlp,
        metric="accuracy",  # The metric PBT optimizes
        mode="max",
        config=search_space,
        num_samples=8,  # Number of parallel trials (population size)
        scheduler=pbt_scheduler,
        resources_per_trial={"cpu": 1},
        progress_reporter=tune.CLIReporter(
            metric_columns=["epoch", "loss", "accuracy"]
        ),
        # Epochs are 0-indexed, so epoch 9 is the 10th and last one.
        stop={"epoch": 9},
    )

    print("Best config: ", analysis.best_config)
    ray.shutdown()
```
Here’s the mental model: PBT treats your training runs as a population. Periodically, it inspects this population. The "fittest" members (those with the best value of the metric being optimized, in our case accuracy) get to "reproduce." A less fit member discards its own progress, loads a copy of a fit member’s checkpoint (model weights included), and takes over that member’s lr and hidden_units, which are then perturbed slightly or, occasionally, resampled. It resumes training from the copied state with these new, potentially better hyperparameters. This lets the search adapt on the fly, exploring promising regions of the hyperparameter space while abandoning poor ones.
The key levers you control are:
- time_attr: The unit of progress PBT uses to decide when to perturb. Usually epoch or training_iteration; whatever you choose must be reported by the trainable.
- metric and mode (older Ray versions took a single reward_attr, which assumed higher is better): The metric PBT uses to determine "fitness" and whether it should be maximized or minimized. This must be a metric you tune.report() during training.
- hyperparam_mutations: The hyperparameters PBT is allowed to change, and the space it samples from when it resamples them.
- perturbation_interval: How often, in units of time_attr, PBT checks each trial for a possible copy/mutation. A smaller interval means more frequent adaptation, at the cost of more checkpoint saving and restoring.
- resample_probability: When a trial’s copied hyperparameters are explored, the chance that each one is redrawn from hyperparam_mutations instead of being nudged up or down from the donor’s value. This helps avoid getting stuck in local optima.
- quantile_fraction: The size of the top and bottom groups, as a fraction of the population. A quantile_fraction of 0.25 means trials in the bottom 25% copy from trials in the top 25%; it must lie between 0 and 0.5, since the two groups cannot overlap.
- num_samples: The size of your population. A larger population explores more configurations in parallel, but also uses more resources.
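For the values used in the example above (num_samples=8, quantile_fraction=0.25, perturbation_interval=2, epochs reported as 0 through 9), the arithmetic can be sketched as follows. The helper names are hypothetical, and two rules are assumptions rather than a reading of Ray Tune’s source: the group size is rounded up but capped at half the population, and a trial’s first check fires once time_attr has advanced a full interval from 0.

```python
import math


def quantile_size(num_samples: int, quantile_fraction: float) -> int:
    """How many trials sit in the top group (and in the bottom group).

    Assumption: the size is rounded up, then capped at half the population
    so the two groups can never overlap.
    """
    return min(math.ceil(num_samples * quantile_fraction), num_samples // 2)


def perturbation_epochs(num_epochs: int, perturbation_interval: int) -> list[int]:
    """Epochs at which a never-cloned trial is checked for perturbation.

    Assumption: epochs are reported as 0, 1, ..., num_epochs - 1, and a check
    fires once time_attr has advanced a full interval since the last check.
    """
    checks = []
    last_check = 0
    for epoch in range(num_epochs):
        if epoch - last_check >= perturbation_interval:
            checks.append(epoch)
            last_check = epoch
    return checks


print(quantile_size(8, 0.25))      # 2
print(perturbation_epochs(10, 2))  # [2, 4, 6, 8]
```

Under these assumptions, two trials can donate and two can receive at each check, and a trial that is never cloned is checked at epochs 2, 4, 6 and 8.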
PBT works by creating a PopulationBasedTraining scheduler object and passing it to tune.run. Inside the trainable function, each trial must tune.report() the time_attr and the metric being optimized, and it must save checkpoints. The scheduler monitors these reports. Whenever a trial’s time_attr has advanced by perturbation_interval since its last check, the scheduler ranks the trials by their latest reported metric. If that trial falls in the bottom quantile_fraction of the population, it is restored from the checkpoint of a trial in the top quantile_fraction and resumed with that trial’s hyperparameters, after they have been modified by an explore step. The hyperparam_mutations define what that explore step may change and the distributions it resamples from, allowing exploration around the successful configurations.
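The per-result decision described above can be sketched as a small class. It is a simplified model, not Ray Tune’s implementation: PBTBookkeeper and on_result are hypothetical names, the class remembers only each trial’s last check time and latest score, and it sizes the bottom group by rounding up and capping at half of the trials scored so far.

```python
import math
from dataclasses import dataclass, field


@dataclass
class PBTBookkeeper:
    """Hypothetical, simplified model of a PBT scheduler's per-result decision."""

    perturbation_interval: int
    quantile_fraction: float
    last_check: dict[str, int] = field(default_factory=dict)    # trial -> time of last check
    last_score: dict[str, float] = field(default_factory=dict)  # trial -> latest score

    def on_result(self, trial_id: str, time: int, score: float) -> str:
        # Not enough progress since this trial's last check: keep training.
        if time - self.last_check.get(trial_id, 0) < self.perturbation_interval:
            return "CONTINUE"
        self.last_check[trial_id] = time
        self.last_score[trial_id] = score
        # Rank every trial that has been scored so far, worst first.
        ranked = sorted(self.last_score, key=self.last_score.get)
        size = min(math.ceil(len(ranked) * self.quantile_fraction), len(ranked) // 2)
        # Bottom group: this trial should clone a top trial (exploit + explore).
        if trial_id in ranked[:size]:
            return "EXPLOIT"
        return "CONTINUE"


bk = PBTBookkeeper(perturbation_interval=2, quantile_fraction=0.25)
print(bk.on_result("a", 1, 0.9))   # CONTINUE: only 1 epoch since the start
print(bk.on_result("a", 2, 0.9))   # CONTINUE: the only scored trial so far
print(bk.on_result("b", 2, 0.5))   # EXPLOIT: lowest of two scored trials
print(bk.on_result("c", 2, 0.7))   # CONTINUE: middle of three
print(bk.on_result("d", 2, 0.3))   # EXPLOIT: lowest of four
print(bk.on_result("a", 3, 0.95))  # CONTINUE: next check for "a" is at epoch 4
```

Because each trial is checked against its own last check time, trials in this sketch are evaluated independently as their results arrive rather than in lockstep.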
The resample_probability is crucial. During the explore step, each copied hyperparameter is either nudged away from the donor’s value (Ray Tune’s default multiplies a continuous value by 0.8 or 1.2) or, with probability resample_probability, redrawn from its hyperparam_mutations distribution. These occasional full resamples keep the population from becoming too homogeneous and getting stuck in a local optimum: even trials that inherit a winner’s settings may try something quite different.
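The explore step can be sketched as below. This is a simplified model in the spirit of Ray Tune’s default behavior, assuming continuous or integer hyperparameters: each one is resampled with probability resample_probability, otherwise multiplied by 0.8 or 1.2, and integers are truncated back to int (the truncation is an assumption). The explore function and the mutations dict of samplers are hypothetical stand-ins for hyperparam_mutations.

```python
import random
from typing import Callable


def explore(
    config: dict,
    mutations: dict[str, Callable[[random.Random], float]],
    resample_probability: float,
    rng: random.Random,
) -> dict:
    """Return a perturbed copy of a cloned trial's hyperparameters."""
    new_config = dict(config)
    for key, sample in mutations.items():
        if rng.random() < resample_probability:
            # Occasional jump anywhere in the mutation space keeps diversity.
            value = sample(rng)
        else:
            # Otherwise take a small step up or down from the donor's value.
            value = config[key] * rng.choice([0.8, 1.2])
        # Integer hyperparameters such as hidden_units stay integers.
        new_config[key] = int(value) if isinstance(config[key], int) else value
    return new_config


# Hypothetical samplers mirroring the example's ranges; tune.randint(32, 256)
# excludes 256, while random.randint includes both ends, hence 255.
mutations = {
    "lr": lambda r: r.uniform(0.001, 0.1),
    "hidden_units": lambda r: r.randint(32, 255),
}
donor = {"lr": 0.01, "hidden_units": 128}
print(explore(donor, mutations, resample_probability=0.25, rng=random.Random(0)))
```

With resample_probability=0, a donor with hidden_units=128 can only produce 102 or 153 in one step, so values drift slowly; the resample branch is what allows a jump to a distant region.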
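When PBT copies hyperparameters, it doesn’t restart training from scratch. The lagging trial is restored from the donor’s latest checkpoint, so it inherits the donor’s model state along with its hyperparameters, and the new hyperparameters apply to its future training steps. This is how PBT achieves continuous evolution, and it is why the trainable has to save and restore checkpoints: Ray Tune skips the exploit step when the donor trial has no checkpoint to copy.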
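Here is a toy model of that copy. Trial, exploit and the fake weights list are hypothetical; in Ray Tune the state lives in the checkpoint files the trainable writes. The lagging trial drops its own progress, takes a deep copy of the donor’s state, and keeps training from there with new hyperparameters, while the deep copy leaves the donor unaffected.

```python
import copy
from dataclasses import dataclass, field


@dataclass
class Trial:
    """Hypothetical stand-in for a PBT trial: its config plus checkpointed state."""

    name: str
    config: dict
    # What a checkpoint would hold: progress so far and (fake) model weights.
    state: dict = field(default_factory=lambda: {"epoch": 0, "weights": [0.0]})

    def train_one_epoch(self) -> None:
        self.state["epoch"] += 1
        # Record which lr produced this "update" so the history is visible.
        self.state["weights"].append(self.config["lr"])


def exploit(recipient: Trial, donor: Trial, new_config: dict) -> None:
    # Replace the recipient's own progress with a copy of the donor's checkpoint...
    recipient.state = copy.deepcopy(donor.state)
    # ...and continue from there with the explored hyperparameters.
    recipient.config = dict(new_config)


good = Trial("good", {"lr": 0.01})
bad = Trial("bad", {"lr": 0.09})
for _ in range(4):
    good.train_one_epoch()
    bad.train_one_epoch()

exploit(bad, good, new_config={"lr": 0.012})
bad.train_one_epoch()
print(bad.state)   # {'epoch': 5, 'weights': [0.0, 0.01, 0.01, 0.01, 0.01, 0.012]}
print(good.state)  # {'epoch': 4, 'weights': [0.0, 0.01, 0.01, 0.01, 0.01]}
```

The lagging trial’s history now shows the donor’s four updates at lr=0.01 followed by its own update at the explored lr=0.012; its earlier updates at lr=0.09 are gone.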
In more detail: each time a trial reaches a new perturbation_interval milestone of time_attr, the PBT scheduler compares the latest reported metric across the population and uses quantile_fraction to split it into a top group, a bottom group, and everyone in between. Trials in the middle keep training unchanged. A trial in the bottom group picks a donor at random from the top group, copies its checkpoint and hyperparameters, applies the hyperparam_mutations explore step to the copied hyperparameters, and resumes training with the result.
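The donor assignment for one snapshot of the population can be sketched as follows. plan_exploits and the scores are hypothetical; the sketch assumes donors are drawn uniformly at random from the top group and uses a group size rounded up and capped at half the population.

```python
import math
import random


def plan_exploits(
    scores: dict[str, float], quantile_fraction: float, rng: random.Random
) -> dict[str, str]:
    """Map each bottom-group trial to a donor drawn at random from the top group.

    Middle-group trials do not appear in the result: they keep training unchanged.
    """
    ranked = sorted(scores, key=scores.get)  # worst first
    size = min(math.ceil(len(ranked) * quantile_fraction), len(ranked) // 2)
    if size == 0:
        return {}  # too few trials to form separate groups
    bottom, top = ranked[:size], ranked[-size:]
    return {trial: rng.choice(top) for trial in bottom}


scores = {"t0": 0.91, "t1": 0.62, "t2": 0.88, "t3": 0.70,
          "t4": 0.95, "t5": 0.55, "t6": 0.80, "t7": 0.77}
plan = plan_exploits(scores, quantile_fraction=0.25, rng=random.Random(0))
print(plan)  # t5 and t1 (the two lowest) each get a donor from {t4, t0}
```

Two bottom trials may end up with the same donor; each of them then applies its own explore step, so their hyperparameters still diverge.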
The next thing you’ll likely want to explore is how PBT interacts with more complex search spaces and how to properly set the perturbation_interval and quantile_fraction for different problem types.