Kong Junhyeong

How to Use PyTorch Lightning

https://pytorch-lightning.readthedocs.io/en/stable/

Installation

Installing with pip

pip install pytorch-lightning

Installing with conda

conda install pytorch-lightning -c conda-forge

Import

import pytorch_lightning as pl
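
After installation, you can confirm that the import works by printing the installed version:

print(pl.__version__)  # e.g. "1.9.5"; the exact number depends on your environment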

Defining a Model

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        return x

To use PyTorch Lightning, define a Model class that inherits from pl.LightningModule.
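
As a quick sanity check (a sketch, not part of the original docs), the model can be instantiated and called on a dummy batch; the 28 * 28 input size assumes MNIST-style images:

model = Model()
dummy_images = torch.randn(4, 1, 28, 28)  # a fake batch of four 28x28 single-channel images
logits = model(dummy_images)
print(logits.shape)  # torch.Size([4, 10])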

Defining the Optimizer and LR Scheduler

The model's optimizer and LR scheduler are defined inside the model with the configure_optimizers hook.

class Model(pl.LightningModule):
    ...

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]

The return value of configure_optimizers can be one of the following:

  • Single optimizer
  • List or Tuple of optimizers
  • Two lists : the first is a list of optimizers, the second a list of LR schedulers
  • Dictionary : must include an "optimizer" key; it may optionally include an "lr_scheduler" key, whose value is either an LR scheduler or an lr_scheduler_config
  • Tuple of dictionaries : the dictionaries above, each of which must additionally include a "frequency" key
  • None : no optimizer is used

lr_scheduler_config is the configuration dictionary for the scheduler.

lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the metric specified by 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}
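
For example, the dictionary return form combines an optimizer with this config. The sketch below uses ReduceLROnPlateau, which is why the "monitor" key (a metric assumed to be logged elsewhere as "val_loss") is required:

def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
    lr_scheduler_config = {
        "scheduler": scheduler,   # REQUIRED
        "interval": "epoch",      # step the scheduler once per epoch
        "frequency": 1,
        "monitor": "val_loss",    # metric that ReduceLROnPlateau watches
    }
    return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}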

Example:

# most cases. no learning rate scheduler
def configure_optimizers(self):
    return Adam(self.parameters(), lr=1e-3)

# multiple optimizer case (e.g.: GAN)
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    return gen_opt, dis_opt

# example with learning rate schedulers
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    dis_sch = CosineAnnealingLR(dis_opt, T_max=10)
    return [gen_opt, dis_opt], [dis_sch]

# example with step-based learning rate schedulers
# each optimizer has its own scheduler
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    gen_sch = {
        'scheduler': ExponentialLR(gen_opt, 0.99),
        'interval': 'step'  # called after each training step
    }
    dis_sch = CosineAnnealingLR(dis_opt, T_max=10)  # called every epoch
    return [gen_opt, dis_opt], [gen_sch, dis_sch]

# example with optimizer frequencies
# see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1
# https://arxiv.org/abs/1704.00028
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    n_critic = 5
    return (
        {'optimizer': dis_opt, 'frequency': n_critic},
        {'optimizer': gen_opt, 'frequency': 1}
    )

Defining the Training Logic

Define the training_step hook. training_step takes self, batch, and batch_idx as parameters and returns the loss.

class LitModel(pl.LightningModule):
    ...

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

parameters

  • batch : the output of the DataLoader
  • batch_idx (int) : an integer indicating the index of the current batch
  • optimizer_idx : the index of the optimizer to use when multiple optimizers are configured
  • hidden : used when truncated_bptt_steps > 0

return

  • Tensor : the loss tensor
  • Dictionary : may contain any keys, but must include a 'loss' key (see the sketch below)
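
For example, a training_step returning a dictionary might look like the sketch below; "preds" is a hypothetical extra key, and only "loss" is required:

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    loss = F.cross_entropy(y_hat, y)
    preds = y_hat.argmax(dim=1)  # extra key, e.g. for use in training_epoch_end
    return {"loss": loss, "preds": preds}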

training_epoch_end runs at the end of each training epoch.

def training_epoch_end(self, training_step_outputs):
    # do something with all training_step outputs
    for out in training_step_outputs:
        ...

parameters

  • List of Tensors
  • List of Dictionaries

return

  • None
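
A minimal sketch of training_epoch_end, assuming training_step returns a dictionary with a "loss" key as in the example above:

def training_epoch_end(self, training_step_outputs):
    # average the per-step losses over the epoch and log the result
    avg_loss = torch.stack([out["loss"] for out in training_step_outputs]).mean()
    self.log("avg_train_loss", avg_loss)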

Defining the Validation Logic

To run a validation loop during training, override the validation_step() method.

class LitModel(pl.LightningModule):
    ...
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("val_loss", loss)

You can override validation_epoch_end() to define code that runs at the end of every validation epoch.
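
A minimal sketch, assuming validation_step returns the loss it computes:

def validation_epoch_end(self, validation_step_outputs):
    # validation_step_outputs is the list of values returned by validation_step
    avg_loss = torch.stack(validation_step_outputs).mean()
    self.log("avg_val_loss", avg_loss)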

Defining the Test Logic

Override test_step().

class LitModel(pl.LightningModule):
    ...
    def test_step(self, batch, batch_idx):
        ...
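
A minimal sketch of what the body might look like, mirroring the validation logic above:

def test_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    loss = F.cross_entropy(y_hat, y)
    self.log("test_loss", loss)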

Defining Logs

self.log() records data as key, value pairs. The parameters listed below control how and where each value is recorded (a usage sketch follows the list).

def training_step(self, batch, batch_idx):
    ...
    self.log("loss", x)
    ...

parameters

  • on_step
  • on_epoch
  • prog_bar
  • logger
  • reduce_fx
  • enable_graph
  • sync_dist
  • sync_dist_group
  • add_dataloader_idx
  • batch_size
  • rank_zero_only
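
A sketch showing how several of these parameters can be combined; on_step and on_epoch control whether the value is logged per step and/or aggregated per epoch, prog_bar shows it in the progress bar, and logger sends it to the attached logger:

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    loss = F.cross_entropy(y_hat, y)
    # log the per-step value and an epoch-level aggregate
    self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    return loss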

Execution Structure

def fit(self):
    if global_rank == 0:
        # prepare data is called on GLOBAL_ZERO only
        prepare_data()

    configure_callbacks()

    with parallel(devices):
        # devices can be GPUs, TPUs, ...
        train_on_device(model)

def train_on_device(model):
    # called PER DEVICE
    on_fit_start()
    setup("fit")
    configure_optimizers()

    # the sanity check runs here

    on_train_start()
    for epoch in epochs:
        fit_loop()
    on_train_end()

    on_fit_end()
    teardown("fit")

def fit_loop():
    on_train_epoch_start()

    for batch in train_dataloader():
        on_train_batch_start()

        on_before_batch_transfer()
        transfer_batch_to_device()
        on_after_batch_transfer()

        training_step()

        on_before_zero_grad()
        optimizer_zero_grad()

        on_before_backward()
        backward()
        on_after_backward()

        on_before_optimizer_step()
        configure_gradient_clipping()
        optimizer_step()

        on_train_batch_end()

        if should_check_val:
            val_loop()
    # end training epoch
    training_epoch_end()

    on_train_epoch_end()

def val_loop():
    on_validation_model_eval()  # calls `model.eval()`
    torch.set_grad_enabled(False)

    on_validation_start()
    on_validation_epoch_start()

    val_outs = []
    for batch_idx, batch in enumerate(val_dataloader()):
        on_validation_batch_start(batch, batch_idx)

        batch = on_before_batch_transfer(batch)
        batch = transfer_batch_to_device(batch)
        batch = on_after_batch_transfer(batch)

        out = validation_step(batch, batch_idx)

        on_validation_batch_end(batch, batch_idx)
        val_outs.append(out)

    validation_epoch_end(val_outs)

    on_validation_epoch_end()
    on_validation_end()

    # set up for train
    on_validation_model_train()  # calls `model.train()`
    torch.set_grad_enabled(True)

Trainer

The Trainer class runs training and testing of a model defined as a pl.LightningModule.

from pytorch_lightning import Trainer

...
model = Model()

trainer = Trainer()
trainer.fit(model, train_dataloader, val_dataloader)

...
trainer.test(model, test_dataloader)
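
A slightly fuller sketch of a typical run; max_epochs, accelerator, devices, and log_every_n_steps are common Trainer arguments, though the exact names can vary across PyTorch Lightning versions, and the dataloaders are assumed to be defined elsewhere:

from pytorch_lightning import Trainer

model = Model()
trainer = Trainer(
    max_epochs=10,           # stop after 10 epochs
    accelerator="gpu",       # or "cpu"
    devices=1,               # number of devices to train on
    log_every_n_steps=50,    # how often logged values are written
)
trainer.fit(model, train_dataloader, val_dataloader)
trainer.test(model, test_dataloader)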