If you’ve ever worked on machine learning projects, you’ll know that training models is just one aspect of the process. Code setup, configuration management, and ensuring reproducibility can also take up a lot of time. I’m a big fan of PyTorch Lightning primarily because it hides most of the boilerplate code you usually need, making your code more modular and readable. It even allows you to train your models on multiple GPUs with ease. All of this comes with the minor trade-off of learning an intuitive API, which can be easily extended to tweak any low-level details for those rare cases where the standard API falls short.
However, despite finding PyTorch Lightning incredibly useful, there’s one aspect that has always bothered me: configuring the model and training hyperparameters in a flexible and reproducible manner. In my view, the best approach is to use configuration files for the various modules involved, files that can be easily overridden at runtime with command-line arguments or environment variables. To achieve this, I developed my own packages, configfile and argParseFromDoc, which facilitate this process.
But now, there’s a tool within the Lightning suite that offers all these features in a seamlessly integrated package. Allow me to introduce you to LightningCLI. This tool streamlines the process of hyperparameter configuration, making it both flexible and reproducible. With LightningCLI, you get the best of both worlds: the power of PyTorch Lightning and a hassle-free setup.
The core idea is to write a config file (or several) that contains the required parameters for the trainer, the model, and the dataset. These are YAML files with the following structure:
trainer:
  logger: true
  ...
model:
  out_dim: 10
  learning_rate: 0.02
data:
  data_dir: ./
  image_size: 256
ckpt_path: null
…
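You don’t have to write this file from scratch. Once the script shown below exists, LightningCLI can print a full config with every available option and its default value, which you can redirect to a file and edit. A minimal example, assuming the entry point is named script.py as in the example that follows:

python script.py fit --print_config > config.yaml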
The YAML fields should correspond to the parameters of the PyTorch Lightning Trainer and of your custom model and data classes, which inherit from LightningModule and LightningDataModule. A full, self-contained example could be:
import lightning.pytorch as pl
from lightning.pytorch.cli import LightningCLI


class MyModel(pl.LightningModule):
    def __init__(self, out_dim: int, learning_rate: float):
        super().__init__()
        self.save_hyperparameters()
        self.out_dim = out_dim
        self.learning_rate = learning_rate
        # create_my_model is a placeholder for your own model-building code
        self.model = create_my_model(out_dim)

    def training_step(self, batch, batch_idx):
        out = self.model(batch.x)
        # compute_loss is a placeholder for your own loss computation
        loss = self.compute_loss(out, batch.y)
        return loss


class MyDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str, image_size: int):
        super().__init__()
        self.data_dir = data_dir
        self.image_size = image_size

    def train_dataloader(self):
        # create_dataloader is a placeholder for your own dataloader code
        return create_dataloader(self.image_size, self.data_dir)


def main():
    # LightningCLI parses the config / command line and wires together
    # the Trainer, the model, and the datamodule
    cli = LightningCLI(model_class=MyModel, datamodule_class=MyDataModule)


if __name__ == "__main__":
    main()
That script can then be run easily as

python script.py fit --config config.yaml
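Any value from the config can also be overridden directly on the command line using dotted keys, and several config files can be passed and merged, with later ones taking precedence. For example, to tweak the learning rate of the example above without touching the YAML (the values and file names here are just illustrative):

python script.py fit --config config.yaml --model.learning_rate 0.05
python script.py fit --config base.yaml --config experiment.yaml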
Even better, you can split the configuration across several files, and the config files can refer to Python classes to be instantiated. This makes the configuration system flexible enough to configure virtually anything you can imagine:
model:
  class_path: model.MyModel2
  init_args:
    learning_rate: 0.2
    loss:
      class_path: torch.nn.CrossEntropyLoss
      init_args:
        reduction: mean
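For the loss entry above to be instantiated and passed in, the model’s __init__ needs a parameter annotated with a class type; LightningCLI (via jsonargparse) then accepts the class_path/init_args syntax for it. A minimal sketch, assuming MyModel2 lives in a module called model.py:

import torch
import lightning.pytorch as pl


class MyModel2(pl.LightningModule):
    def __init__(self, learning_rate: float, loss: torch.nn.Module):
        super().__init__()
        # `loss` arrives already instantiated from the class_path/init_args
        # given in the config (here, torch.nn.CrossEntropyLoss)
        self.learning_rate = learning_rate
        self.loss = loss

Note that pointing the model at a different class via class_path generally requires either not fixing model_class when constructing LightningCLI or passing subclass_mode_model=True so that subclasses of the base model are accepted.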
In conclusion, LightningCLI brings the convenience of configuration management, command-line flexibility, and reproducibility to your PyTorch Lightning projects. With simple yet powerful features, it’s a tool that should be part of any machine learning engineer’s toolkit.