https://pytorch-lightning.readthedocs.io/en/stable/
Installation
Installing with pip:
pip install pytorch-lightning
Installing with conda:
conda install pytorch-lightning -c conda-forge
Import
import pytorch_lightning as pl
Defining a Model
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        return x
To use PyTorch Lightning, define a Model that inherits from pl.LightningModule.
Defining the Optimizer and LR Scheduler
The model's optimizer and LR scheduler are defined inside the model via the configure_optimizers hook.
class Model(pl.LightningModule):
    ...
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]
The return value of configure_optimizers can be any of the following:
- Single optimizer
- List or Tuple of optimizers
- Two lists : the first list holds optimizers, the second holds LR schedulers
- Dictionary : must contain an "optimizer" key, and may optionally contain an "lr_scheduler" key whose value is either an LR scheduler or an lr_scheduler_config
- Tuple of dictionaries : dictionaries as above, each additionally containing a "frequency" key
- None : no optimizer is used
lr_scheduler_config is the configuration dictionary for a scheduler.
lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified by 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}
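For instance, a minimal sketch of the dictionary return form combined with such a config, assuming a ReduceLROnPlateau scheduler and that "val_loss" is logged during validation:

def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            # ReduceLROnPlateau needs a monitored metric
            "monitor": "val_loss",
            "interval": "epoch",
            "frequency": 1,
        },
    }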
Example:
# most cases. no learning rate scheduler
def configure_optimizers(self):
    return Adam(self.parameters(), lr=1e-3)

# multiple optimizer case (e.g.: GAN)
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    return gen_opt, dis_opt

# example with learning rate schedulers
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    dis_sch = CosineAnnealing(dis_opt, T_max=10)
    return [gen_opt, dis_opt], [dis_sch]

# example with step-based learning rate schedulers
# each optimizer has its own scheduler
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    gen_sch = {
        'scheduler': ExponentialLR(gen_opt, 0.99),
        'interval': 'step'  # called after each training step
    }
    dis_sch = CosineAnnealing(dis_opt, T_max=10)  # called every epoch
    return [gen_opt, dis_opt], [gen_sch, dis_sch]

# example with optimizer frequencies
# see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1
# https://arxiv.org/abs/1704.00028
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    n_critic = 5
    return (
        {'optimizer': dis_opt, 'frequency': n_critic},
        {'optimizer': gen_opt, 'frequency': 1},
    )
Defining the training logic
Define the training_step hook. training_step takes self, batch, and batch_idx as parameters and returns the loss.
class LitModel(pl.LightningModule):
    ...
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss
parameters
- batch : output of the DataLoader
- batch_idx (int) : integer index of the current batch
- optimizer_idx : index of the optimizer to use when multiple optimizers are defined
- hidden : used when truncated_bptt_steps > 0
return
- Tensor : the loss tensor
- Dictionary : may contain any keys, but must include a 'loss' key
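A minimal sketch of the dictionary return form (only the 'loss' key is required; the 'preds' key here is illustrative):

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    loss = F.cross_entropy(y_hat, y)
    # extra keys are passed along to training_epoch_end
    return {"loss": loss, "preds": y_hat.detach()}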
training_epoch_end runs once a training epoch has finished.
def training_epoch_end(self, training_step_outputs):
    # do something with all training_step outputs
    for out in training_step_outputs:
        ...
parameters
- List of Tensor
- List of Dictionary
return
- None
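For example, a sketch that averages the losses of an epoch, assuming training_step returned dictionaries with a 'loss' key as above (the logged name "train_epoch_loss" is illustrative):

def training_epoch_end(self, training_step_outputs):
    avg_loss = torch.stack([out["loss"] for out in training_step_outputs]).mean()
    self.log("train_epoch_loss", avg_loss)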
Defining the validation logic
To run a validation loop during training, override the validation_step() method.
class LitModel(pl.LightningModule):
    ...
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("val_loss", loss)
By overriding validation_epoch_end(), you can define code that runs at the end of every validation epoch.
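A minimal sketch, assuming validation_step also returns the loss so that it is collected into the outputs list (the "avg_val_loss" name is illustrative):

def validation_epoch_end(self, validation_step_outputs):
    # validation_step_outputs holds whatever validation_step returned
    avg_loss = torch.stack(validation_step_outputs).mean()
    self.log("avg_val_loss", avg_loss)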
Defining the test logic
Override test_step().
class LitModel(pl.LightningModule):
    ...
    def test_step(self, batch, batch_idx):
        ...
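A typical test_step mirrors validation_step; a sketch (the "test_loss" name is illustrative):

def test_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    loss = F.cross_entropy(y_hat, y)
    self.log("test_loss", loss)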
Logging
self.log() records data as key, value pairs.
def training_step(self, batch, batch_idx):
    ...
    self.log("loss", x)
    ...
parameters
- on_step
- on_epoch
- prog_bar
- logger
- reduce_fx
- enable_graph
- sync_dist
- sync_dist_group
- add_dataloader_idx
- batch_size
- rank_zero_only
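A sketch combining a few of these flags (the metric name "train_loss" is illustrative):

def training_step(self, batch, batch_idx):
    x, y = batch
    loss = F.cross_entropy(self(x), y)
    # log per step to the logger, aggregate per epoch,
    # and also show the value on the progress bar
    self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    return loss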
Execution structure
The pseudocode below shows the order in which the hooks are called during trainer.fit().
def fit(self):
    if global_rank == 0:
        # prepare data is called on GLOBAL_ZERO only
        prepare_data()

    configure_callbacks()

    with parallel(devices):
        # devices can be GPUs, TPUs, ...
        train_on_device(model)


def train_on_device(model):
    # called PER DEVICE
    on_fit_start()
    setup("fit")
    configure_optimizers()

    # the sanity check runs here

    on_train_start()
    for epoch in epochs:
        fit_loop()
    on_train_end()

    on_fit_end()
    teardown("fit")


def fit_loop():
    on_train_epoch_start()

    for batch in train_dataloader():
        on_train_batch_start()

        on_before_batch_transfer()
        transfer_batch_to_device()
        on_after_batch_transfer()

        training_step()

        on_before_zero_grad()
        optimizer_zero_grad()

        on_before_backward()
        backward()
        on_after_backward()

        on_before_optimizer_step()
        configure_gradient_clipping()
        optimizer_step()

        on_train_batch_end()

        if should_check_val:
            val_loop()

    # end training epoch
    training_epoch_end()

    on_train_epoch_end()


def val_loop():
    on_validation_model_eval()  # calls `model.eval()`
    torch.set_grad_enabled(False)

    on_validation_start()
    on_validation_epoch_start()

    val_outs = []
    for batch_idx, batch in enumerate(val_dataloader()):
        on_validation_batch_start(batch, batch_idx)

        batch = on_before_batch_transfer(batch)
        batch = transfer_batch_to_device(batch)
        batch = on_after_batch_transfer(batch)

        out = validation_step(batch, batch_idx)

        on_validation_batch_end(batch, batch_idx)
        val_outs.append(out)

    validation_epoch_end(val_outs)

    on_validation_epoch_end()
    on_validation_end()

    # set up for train
    on_validation_model_train()  # calls `model.train()`
    torch.set_grad_enabled(True)
Trainer
The Trainer class runs the training and testing of a Model defined with pl.LightningModule.
from pytorch_lightning import Trainer
...
model = Model()
trainer = Trainer()
trainer.fit(model, train_dataloader, val_dataloader)
...
trainer.test(model, test_dataloader)
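A sketch with a few commonly used Trainer arguments (the values are illustrative; accelerator/devices are the argument names in recent Lightning versions):

trainer = Trainer(
    max_epochs=10,           # stop after 10 epochs
    accelerator="gpu",
    devices=1,
    log_every_n_steps=50,    # how often self.log values are written
)
trainer.fit(model, train_dataloader, val_dataloader)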