PyTorch Lightning
PyTorch Lightning is an open-source machine learning library built on top of PyTorch, designed to help researchers and developers build deep learning models more conveniently. Lightning's design philosophy is to separate the engineering boilerplate of model training (device management, distributed training, and so on) from the research code (model architecture, data processing, and so on), so that researchers can focus on the research itself rather than the underlying engineering details.
You can use PyTorch Lightning to quickly train models while using SwanLab for experiment tracking and visualization.
1. Import SwanLabLogger
from swanlab.integration.pytorch_lightning import SwanLabLogger
SwanLabLogger is a logging class adapted for PyTorch Lightning.
SwanLabLogger accepts parameters such as:
- project, experiment_name, description, and other parameters consistent with swanlab.init, which are used to initialize the SwanLab project (a minimal sketch follows below).
- You can also create the project externally via swanlab.init; the integration will then log the experiment to the project you created externally.
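As an illustration, here is a minimal sketch of both usages. The project name, experiment name, and description are placeholders, and it assumes that SwanLabLogger forwards these keyword arguments to swanlab.init (and that it records into an externally created run when one already exists, per the note above):

from swanlab.integration.pytorch_lightning import SwanLabLogger

# Usage 1: let the logger initialize the SwanLab project itself
swanlab_logger = SwanLabLogger(
    project="my-project",                                # placeholder project name
    experiment_name="baseline-run",                      # placeholder experiment name
    description="A Lightning run tracked with SwanLab",  # placeholder description
)

# Usage 2 (alternative, assumed behavior): create the run externally via
# swanlab.init, and the logger will record into that run instead
# import swanlab
# swanlab.init(project="my-project", experiment_name="baseline-run")
# swanlab_logger = SwanLabLogger()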
2. Pass to Trainer
import pytorch_lightning as pl

...

# Instantiate SwanLabLogger
swanlab_logger = SwanLabLogger(project="lightning-visualization")

trainer = pl.Trainer(
    ...
    # Pass the logger parameter
    logger=swanlab_logger,
)

trainer.fit(...)
3. Complete Example Code
from swanlab.integration.pytorch_lightning import SwanLabLogger
import os
import pytorch_lightning as pl
from torch import nn, optim, utils
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

# define the LightningModule
class LitAutoEncoder(pl.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # self.log sends the metric to the logger attached to the Trainer (SwanLab here)
        self.log("train_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        # test_step defines the test loop.
        # it is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        self.log("test_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

# init the autoencoder
autoencoder = LitAutoEncoder(encoder, decoder)

# setup data
dataset = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
train_dataset, val_dataset = utils.data.random_split(dataset, [55000, 5000])
test_dataset = MNIST(os.getcwd(), train=False, download=True, transform=ToTensor())

train_loader = utils.data.DataLoader(train_dataset)
val_loader = utils.data.DataLoader(val_dataset)
test_loader = utils.data.DataLoader(test_dataset)

swanlab_logger = SwanLabLogger(
    project="swanlab_example",
    experiment_name="example_experiment",
)

trainer = pl.Trainer(limit_train_batches=100, max_epochs=5, logger=swanlab_logger)
trainer.fit(model=autoencoder, train_dataloaders=train_loader, val_dataloaders=val_loader)
trainer.test(dataloaders=test_loader)
4. Note: If you call trainer.fit multiple times
If you call trainer.fit multiple times in a single process (e.g., N-fold cross-validation), you need to add the following line after each trainer.fit call:
swanlab_logger.experiment.finish() # or swanlab.finish()
Example code:
import torch
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import KFold
import pytorch_lightning as pl
from swanlab.integration.pytorch_lightning import SwanLabLogger
import datetime
import argparse


class RandomDataset(Dataset):
    def __init__(self, size=100):
        self.x = torch.randn(size, 10)
        self.y = (self.x.sum(dim=1) > 0).long()  # Simple classification task

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]


class SimpleClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(10, 2)
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log("val_loss", loss)
        self.log("val_acc", acc)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


def main(args):
    dataset = RandomDataset(size=100)
    kfold = KFold(n_splits=3, shuffle=True, random_state=42)

    for fold, (train_idx, val_idx) in enumerate(kfold.split(dataset)):
        print(f"\nFold {fold + 1}/3")

        train_loader = DataLoader(Subset(dataset, train_idx), batch_size=16, shuffle=True)
        val_loader = DataLoader(Subset(dataset, val_idx), batch_size=16)

        # Build a unique run name for this fold
        current_time = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        run_name = f"{args.save_name}_fold{fold + 1}_{current_time}"

        swanlab_logger = SwanLabLogger(
            project="swanlab_example",
            experiment_name=run_name,
        )

        model = SimpleClassifier()
        trainer = pl.Trainer(
            max_epochs=5,
            logger=swanlab_logger,
            log_every_n_steps=1
        )
        trainer.fit(model, train_loader, val_loader)

        # Finish the current SwanLab run so the next fold starts a fresh experiment
        swanlab_logger.experiment.finish()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--save_name", type=str, default="test_swan")
    args = parser.parse_args()
    main(args)