memo: PL | Template

PyTorch Lightning advantages:

  1. Easy to run on different hardware (CPU / GPU / TPU / multi-node); you only need to care about the algorithm implementation
  2. Controls randomness, so experiment results are reproducible and the batch splits stay fixed (a minimal sketch follows this list)
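
For the second point, a minimal sketch of the reproducibility setup (the same seed_everything call appears in train.py below; deterministic=True is an optional Trainer flag that this template itself does not use):

from pytorch_lightning import Trainer, seed_everything

seed_everything(1234, workers=True)    # seed python/numpy/torch RNGs, including dataloader workers
trainer = Trainer(deterministic=True)  # optionally force deterministic torch ops as well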

Project directory layout:

  • requirements.txt: required packages
  • train.py: main entry point
  • opt.py: hyperparameters
  • models folder: the different models, e.g. networks1.py
  • datasets folder: dataloaders for the different datasets
  • losses.py: the various loss functions (a hypothetical sketch follows this list)
  • .gitignore: ckpts/, logs/, MNIST/
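
losses.py and the datasets folder are not reproduced in this post (MNIST only needs cross entropy and torchvision's built-in dataset, both used directly in train.py). As a purely hypothetical sketch of what losses.py could contain, with illustrative names that are not from the original repo:

from torch import nn

class CELoss(nn.Module):
    '''Thin wrapper so every loss exposes the same call signature.'''
    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        return self.loss(logits, labels)

# a loss could then be picked from a dict via a hyperparameter, e.g. --loss ce
loss_dict = {'ce': CELoss}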

requirements.txt

torch==1.11.0
torchvision==0.12.0
pytorch_lightning==1.6.0
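
These pins match the Lightning 1.6 API used in train.py below (for example validation_epoch_end, which Lightning 2.x removed), so install exactly these versions with pip install -r requirements.txt.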

opt.py

import argparse

def get_opts():
    parser = argparse.ArgumentParser()

    parser.add_argument('--root_dir', type=str, default='./data/',
                        help='root directory of dataset')

    parser.add_argument('--hidden_dim', type=int, default=128,
                        help='number of hidden dimensions')

    parser.add_argument('--val_len', type=int, default=5000,
                        help='number of validation samples split from train set')
    
    parser.add_argument('--batch_size', type=int, default=128,
                        help='number of training samples in one batch')
    
    parser.add_argument('--lr', type=float, default=1e-4,
                        help='learning rate')
    
    parser.add_argument('--num_epochs', type=int, default=1,
                        help='number of epochs')

    parser.add_argument('--num_gpus', type=int, default=1,
                        help='number of gpus to be used')

    parser.add_argument('--expname', type=str, default='test',
                        help='experiment name')

    # return parser.parse_args()
    # parse_known_args ignores unrecognized arguments, so get_opts() still works
    # when extra flags are injected (e.g. by a notebook kernel)
    args, unknown = parser.parse_known_args()
    return args
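
A quick look at what get_opts() returns with the defaults above (illustrative only; on the command line the same values are overridden with e.g. python train.py --hidden_dim 256 --num_epochs 5):

from opt import get_opts

hparams = get_opts()
print(hparams.hidden_dim, hparams.lr, hparams.expname)   # 128 0.0001 test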

networks1.py

import torch
from torch import nn

class myLinearModel(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(28*28, hidden_dim),
            nn.ReLU(True),
            nn.Linear(hidden_dim, 10)
        )
    def forward(self, x):
        '''
        x: (B, 1, 28, 28) channel=1
        '''
        x = x.flatten(start_dim=1)     # (B, 28*28)
        return self.net(x)
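
A quick shape check of the model (a sanity-check sketch, not part of the original post):

import torch
from models.networks1 import myLinearModel

model = myLinearModel(hidden_dim=128)
x = torch.randn(4, 1, 28, 28)   # a fake batch of 4 MNIST-sized images
print(model(x).shape)           # torch.Size([4, 10]) -- one logit per class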

train.py

import torch
from torch import nn
from torch.nn import functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from torch.optim.lr_scheduler import CosineAnnealingLR

from opt import get_opts
from models.networks1 import myLinearModel

from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
seed_everything(1234, workers=True)

def get_learning_rate(optimizer):  # for recording logs
    for param_group in optimizer.param_groups:
        return param_group['lr']

class MNISTSystem(LightningModule): # LightningModule puts all parts together
    # Design model
    def __init__(self, hparams):
        '''
        hparams: all hyper parameters
        '''
        super().__init__()
        # self.hparams = hparams
        self.save_hyperparameters(hparams) # store exp conditions for reproduction and loss visualization

        # network components
        self.net =  myLinearModel(self.hparams.hidden_dim)

    def forward(self, x):
        return self.net(x)

    # Prepare data
    def prepare_data(self):
        '''
        Download train set and test set (execute only once)
        '''
        datasets.MNIST(self.hparams.root_dir, train=True, download=True)
        datasets.MNIST(self.hparams.root_dir, train=False, download=True)
    
    def setup(self, stage=None):
        '''
        Preprocessing
        Split the train and validation set (called by every device)
        '''
        dataset = datasets.MNIST(self.hparams.root_dir, train=True, download=False,
                    transform=transforms.ToTensor())  # each sample becomes a (1, 28, 28) tensor scaled to [0, 1]
        train_length = len(dataset)
        self.train_set, self.val_set = random_split(dataset, \
                    [train_length - self.hparams.val_len, self.hparams.val_len])
    
    def train_dataloader(self):
        '''
        Split the train set into multiple batches for one epoch
        '''
        return DataLoader(self.train_set,
                          shuffle=True,     # every epoch is different
                          num_workers=4,    # number of dataloader worker processes
                          batch_size=self.hparams.batch_size,
                          pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.val_set,
                          shuffle=False,    # no shuffling, so acc stays comparable across epochs
                          num_workers=4,
                          batch_size=self.hparams.batch_size,
                          pin_memory=True)

    # Construct optimizer and loss func
    def configure_optimizers(self):
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=self.hparams.lr)
        scheduler = CosineAnnealingLR(self.optimizer, T_max=self.hparams.num_epochs, \
                                      eta_min=self.hparams.lr/1e2)
        return [self.optimizer], [scheduler]   # different models can use different optimizers/schedulers
    
    # Training cycle
    def training_step(self, batch, batch_idx):
        '''
        batch: come from iterable train_dataloader
        batch_idx: index of the current batch
        '''
        imgs, labels = batch    # images' pixels, labels
        logits = self(imgs)      # call forward
        loss = F.cross_entropy(logits, labels)   # including softmax
        # tensorboard
        self.log('train/loss', loss)
        self.log('lr', get_learning_rate(self.optimizer))
        return loss

    def validation_step(self, batch, batch_idx):
        '''
        Compute acc for every batch
        '''
        imgs, labels = batch
        logits = self(imgs)
        loss = F.cross_entropy(logits, labels)
        acc = (logits.argmax(-1) == labels).float().mean()   # fraction of correct predictions in this batch
        log = {'val_loss':loss, 'acc': acc}
        return log

    def validation_epoch_end(self, batch_outputs) -> None:
        '''
        Compute average loss/acc among all batches for one epoch
        '''
        mean_loss = torch.stack([x['val_loss'] for x in batch_outputs]).mean()
        mean_acc = torch.stack([x['acc'] for x in batch_outputs]).mean()

        self.log('val/loss', mean_loss, prog_bar=True)  # log the epoch-level means
        self.log('val/acc', mean_acc, prog_bar=True)    # prog_bar=True also shows them on the progress bar
         

if __name__ == '__main__':
    hparams = get_opts()
    mnistsystem = MNISTSystem(hparams)  # construct training system

    # save weights to files
    ckpt_cb = ModelCheckpoint(dirpath=f'ckpts/{hparams.expname}',
                              filename='{epoch:d}', # epoch=0,...
                              monitor='val/acc',
                              mode='max',
                              save_top_k=5) # keep only the 5 checkpoints with the best val acc (-1 keeps all)
    # progress bar
    pbar = TQDMProgressBar(refresh_rate=1)
    callbacks = [ckpt_cb, pbar]

    # tensorboard events
    logger = TensorBoardLogger(save_dir='logs', name=hparams.expname, default_hp_metric=False)

    # trainer
    trainer = Trainer(max_epochs=hparams.num_epochs,
                      callbacks=callbacks,
                      logger=logger,
                      enable_model_summary=True,   # print the model structure
                      accelerator='auto',          # device type (cpu/gpu/...) is picked automatically
                      devices=hparams.num_gpus,
                      num_sanity_val_steps=1,      # run one val batch before training to verify everything works
                      benchmark=True,              # cudnn benchmark; speeds things up when every batch has the same size
                      profiler="simple" if hparams.num_gpus==1 else None, # time each operation
                      )
    trainer.fit(mnistsystem)
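
Checkpoints end up in ckpts/<expname> and the TensorBoard events in logs/<expname> (inspect them with tensorboard --logdir logs). A sketch of reloading a trained model for evaluation; the checkpoint filename below is illustrative and depends on which epoch scored best:

from train import MNISTSystem
from opt import get_opts

hparams = get_opts()
# the hyperparameters stored by save_hyperparameters already live in the checkpoint;
# passing hparams explicitly as well avoids any ambiguity when reconstructing the module
system = MNISTSystem.load_from_checkpoint('ckpts/test/epoch=0.ckpt', hparams=hparams)
system.eval()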