import torch
from torch import nn
from torch.nn import functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from torch.optim.lr_scheduler import CosineAnnealingLR
# from opt import get_opts
# from models.networks1 import myLinearModel
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
seed_everything(1234, workers=True)
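# The script imports myLinearModel from models.networks1 (import commented out
# above), but that file is not shown here. The class below is a minimal
# stand-in sketch so the script runs on its own: a small MLP over flattened
# 28x28 MNIST images with 10 output classes. Only the constructor signature
# myLinearModel(hidden_dim) is taken from its usage in MNISTSystem; the real
# network in the repository may look different.
class myLinearModel(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Flatten(),                    # (B, 1, 28, 28) -> (B, 784)
            nn.Linear(28 * 28, hidden_dim),
            nn.ReLU(True),
            nn.Linear(hidden_dim, 10),       # logits for the 10 digit classes
        )

    def forward(self, x):
        return self.layers(x)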
def get_learning_rate(optimizer): # for recording logs
    for param_group in optimizer.param_groups:
        return param_group['lr']
class MNISTSystem(LightningModule):  # LightningModule puts all the parts together
    # Design model
    def __init__(self, hparams):
        '''
        hparams: all hyperparameters
        '''
        super().__init__()
        # self.hparams = hparams
        self.save_hyperparameters(hparams)  # store exp conditions for reproduction and loss visualization
        # network components
        self.net = myLinearModel(self.hparams.hidden_dim)

    def forward(self, x):
        return self.net(x)
    # Prepare data
    def prepare_data(self):
        '''
        Download the train set and test set (executed only once)
        '''
        datasets.MNIST(self.hparams.root_dir, train=True, download=True)
        datasets.MNIST(self.hparams.root_dir, train=False, download=True)

    def setup(self, stage=None):
        '''
        Preprocessing
        Split the train and validation sets (called on every device)
        '''
        dataset = datasets.MNIST(self.hparams.root_dir, train=True, download=False,
                                 transform=transforms.ToTensor())  # load data (B, channels, H, W)
        train_length = len(dataset)
        self.train_set, self.val_set = random_split(
            dataset, [train_length - self.hparams.val_len, self.hparams.val_len])
    def train_dataloader(self):
        '''
        Split the train set into batches for one epoch
        '''
        return DataLoader(self.train_set,
                          shuffle=True,  # every epoch sees a different order
                          num_workers=4,  # CPU worker processes
                          batch_size=self.hparams.batch_size,
                          pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.val_set,
                          shuffle=False,  # keep the order fixed so acc is comparable across epochs
                          num_workers=4,
                          batch_size=self.hparams.batch_size,
                          pin_memory=True)
    # Construct the optimizer and learning-rate scheduler
    def configure_optimizers(self):
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=self.hparams.lr)
        scheduler = CosineAnnealingLR(self.optimizer, T_max=self.hparams.num_epochs,
                                      eta_min=self.hparams.lr / 1e2)
        return [self.optimizer], [scheduler]  # lists, so different models can use different optimizers/schedulers
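    # Note: the ([optimizer], [scheduler]) return steps the scheduler once per epoch
    # by default. Lightning also accepts a dict form, e.g.
    # {'optimizer': self.optimizer, 'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}},
    # if per-step scheduling is wanted instead.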
    # Training cycle
    def training_step(self, batch, batch_idx):
        '''
        batch: one batch yielded by train_dataloader
        batch_idx: index of the current batch
        '''
        imgs, labels = batch  # image pixels, labels
        logits = self(imgs)  # calls forward
        loss = F.cross_entropy(logits, labels)  # applies softmax internally
        # tensorboard
        self.log('train/loss', loss)
        self.log('lr', get_learning_rate(self.optimizer))
        return loss
    def validation_step(self, batch, batch_idx):
        '''
        Compute acc for every batch
        '''
        imgs, labels = batch
        logits = self(imgs)
        loss = F.cross_entropy(logits, labels)
        acc = torch.sum(torch.eq(torch.argmax(logits, -1), labels).to(torch.float32)) / len(labels)
        log = {'val_loss': loss, 'acc': acc}
        return log
    def validation_epoch_end(self, batch_outputs) -> None:
        '''
        Compute the average loss/acc over all batches of one epoch
        '''
        mean_loss = torch.stack([x['val_loss'] for x in batch_outputs]).mean()
        mean_acc = torch.stack([x['acc'] for x in batch_outputs]).mean()
        self.log('val/loss', mean_loss, prog_bar=True)  # record the epoch-mean loss
        self.log('val/acc', mean_acc, prog_bar=True)  # show on the progress bar
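# get_opts normally comes from opt.py (import commented out at the top), which
# is not shown here. The argparse function below is a stand-in sketch covering
# only the hyperparameters this script actually reads; every default value is
# an assumption, not taken from the original repository.
def get_opts():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--root_dir', type=str, default='data/',
                        help='where MNIST is stored/downloaded')
    parser.add_argument('--expname', type=str, default='mnist_exp',
                        help='experiment name used for ckpts/ and logs/')
    parser.add_argument('--hidden_dim', type=int, default=128,
                        help='hidden size of the network')
    parser.add_argument('--val_len', type=int, default=10000,
                        help='number of training images held out for validation')
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--num_epochs', type=int, default=10)
    parser.add_argument('--num_gpus', type=int, default=1)
    return parser.parse_args()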
if __name__ == '__main__':
    hparams = get_opts()
    mnistsystem = MNISTSystem(hparams)  # construct the training system
    # save weights to files
    ckpt_cb = ModelCheckpoint(dirpath=f'ckpts/{hparams.expname}',
                              filename='{epoch:d}',  # epoch=0, ...
                              monitor='val/acc',
                              mode='max',
                              save_top_k=5)  # keep only the 5 checkpoints with the highest acc (-1 keeps all)
    # progress bar
    pbar = TQDMProgressBar(refresh_rate=1)
    callbacks = [ckpt_cb, pbar]
    # tensorboard events
    logger = TensorBoardLogger(save_dir='logs', name=hparams.expname, default_hp_metric=False)

    trainer = Trainer(max_epochs=hparams.num_epochs,
                      callbacks=callbacks,
                      logger=logger,
                      enable_model_summary=True,  # print the model structure
                      accelerator='auto',  # device type
                      devices=hparams.num_gpus,
                      num_sanity_val_steps=1,  # run one val step before training to verify everything works
                      benchmark=True,  # cudnn autotuning; needs constant batch shapes for a speedup
                      profiler="simple" if hparams.num_gpus == 1 else None,  # time every operation
                      )
    trainer.fit(mnistsystem)
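# Example usage (the flag names come from the get_opts sketch above, and the
# file name train.py is assumed, not taken from the original repository):
#   python train.py --expname mnist_exp --num_epochs 10 --batch_size 256
# Training curves can then be inspected with:
#   tensorboard --logdir logs
# Note: validation_epoch_end and dict returns from validation_step follow the
# pytorch_lightning 1.x API; Lightning >= 2.0 replaces them with
# on_validation_epoch_end and manual accumulation.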