
Commit 7f6bbc4

Merge pull request #4 from Irynei/feature/enhance_training
Feature/enhance training
2 parents 4b6c5cc + 91ea4f1 commit 7f6bbc4

File tree

11 files changed: +356 -87 lines changed


base/base_dataset.py

Lines changed: 20 additions & 11 deletions
@@ -10,11 +10,12 @@ class AutoAugmentDataset(data.Dataset):
     Randomly applies subset of augmentations and set them as labels

     """
-    def __init__(self, dataset, base_transforms, augmentations, max_size=7):
+    def __init__(self, dataset, base_transforms, augmentations, max_size=7, train=True):
         self.dataset = dataset
         self.base_transforms = base_transforms
         self.augmentations = augmentations
         self.max_size = max_size
+        self.train = train

     def __getitem__(self, index):
         x, y = self.dataset[index]
@@ -35,22 +36,30 @@ def __len__(self):

     def get_subset_of_transforms(self):
         """
-        Randomly get size of subset and then randomly choose subset of transformations
+        in case of train dataset:
+            Randomly get size of subset and then randomly choose subset of transformations
+        in case of test dataset:
+            Subset of transformations is always empty

         Returns:
             list of chosen transformations, one-hot-encoded labels

         """
         all_transforms_size = len(self.augmentations)

-        # size from 0 to max_size - 1
-        subset_size = np.random.randint(0, self.max_size)
-        all_transforms_idx = np.arange(all_transforms_size)
-        # get random subset without duplicates
-        np.random.shuffle(all_transforms_idx)
-        transform_idx = all_transforms_idx[:subset_size]
-        subset_transforms = [self.augmentations[i] for i in transform_idx]
+        if self.train:
+            # size from 0 to max_size - 1
+            subset_size = np.random.randint(0, self.max_size)
+            all_transforms_idx = np.arange(all_transforms_size)
+            # get random subset without duplicates
+            np.random.shuffle(all_transforms_idx)
+            transform_idx = all_transforms_idx[:subset_size]
+            subset_transforms = [self.augmentations[i] for i in transform_idx]

-        labels = np.zeros(all_transforms_size)
-        labels[transform_idx] = 1
+            labels = np.zeros(all_transforms_size)
+            labels[transform_idx] = 1
+        else:
+            # in case of test we do
+            labels = np.zeros(all_transforms_size)
+            subset_transforms = []
         return subset_transforms, labels
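For reference, a minimal standalone sketch of the train/test labelling behaviour introduced above. The augmentation names and max_size below are placeholders, not the project's actual transform list; only the branching mirrors the diff:

import numpy as np

augmentations = ["rotate", "flip", "color_jitter", "blur"]   # placeholder names
max_size = 3

def subset_and_labels(train):
    n = len(augmentations)
    if train:
        subset_size = np.random.randint(0, max_size)   # size from 0 to max_size - 1
        idx = np.arange(n)
        np.random.shuffle(idx)                         # random subset without duplicates
        chosen = idx[:subset_size]
        labels = np.zeros(n)
        labels[chosen] = 1
        return [augmentations[i] for i in chosen], labels
    # test: no augmentations applied, label vector stays all zeros
    return [], np.zeros(n)

print(subset_and_labels(train=True))    # e.g. (['flip', 'rotate'], array([1., 1., 0., 0.]))
print(subset_and_labels(train=False))   # ([], array([0., 0., 0., 0.]))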

base/base_trainer.py

Lines changed: 15 additions & 3 deletions
@@ -3,6 +3,7 @@
 import torch
 import glog as log
 import torch.optim as optim
+from utils.util import EarlyStopping


 class BaseTrainer:
@@ -50,6 +51,10 @@ def __init__(self, model, loss, metrics, resume, config, train_logger=None):
         assert self.monitor_mode == 'min' or self.monitor_mode == 'max'
         self.monitor_best = math.inf if self.monitor_mode == 'min' else -math.inf

+        self.early_stopping = None
+        if config.get('early_stopping'):
+            self.early_stopping = EarlyStopping(**config['early_stopping']['early_stopping_params'])
+
         self.start_epoch = 1
         self.checkpoint_dir = os.path.join(config['trainer']['save_dir'], config['experiment_name'])
         if resume:
@@ -93,9 +98,16 @@ def train(self):

             # lr_scheduler logic
             if self.lr_scheduler and epoch % self.lr_scheduler_freq == 0:
-                self.lr_scheduler.step(epoch)
-                lr = self.lr_scheduler.get_lr()[0]
-                self.logger.info('New Learning Rate: {:.6f}'.format(lr))
+                if isinstance(self.lr_scheduler, optim.lr_scheduler.ReduceLROnPlateau):
+                    self.lr_scheduler.step(log['val_loss'])
+                else:
+                    self.lr_scheduler.step(epoch)
+
+            # stopping early logic
+            if self.early_stopping:
+                stop_early = self.early_stopping.step(log['val_loss'])
+                if stop_early:
+                    break

     def _train_epoch(self, epoch):
         """
Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
{
    "experiment_name": "DenseNet121_on_CIFAR10",
    "cuda": true,
    "gpu": 0,
    "augmentation": {
        "max_size": 5
    },
    "data_loader": {
        "name": "CIFAR10DataLoader",
        "data_dir": "datasets/",
        "batch_size": 32,
        "shuffle": true
    },
    "model_name": "densenet121_32x32",
    "model_params": {
        "num_classes": 15
    },
    "optimizer_type": "SGD",
    "optimizer_params": {
        "lr": 0.1,
        "weight_decay": 0.0005,
        "momentum": 0.9
    },
    "loss": "MultiLabelSoftMarginLoss",
    "validation": {
        "validation_split": 0.1,
        "shuffle": true
    },
    "lr_scheduler": {
        "lr_scheduler_type": "ReduceLROnPlateau",
        "lr_scheduler_freq": 1,
        "additional_params": {
            "patience": 8,
            "mode": "min",
            "min_lr": 1e-7,
            "factor": 0.1,
            "verbose": true
        }
    },
    "early_stopping": {
        "early_stopping_params": {
            "patience": 12,
            "mode": "min"
        }
    },
    "metrics": ["accuracy", "jaccard_similarity"],
    "trainer": {
        "epochs": 500,
        "save_dir": "experiments/",
        "save_freq": 100,
        "verbosity": 2,
        "monitor": "val_loss",
        "monitor_mode": "min"
    }
}
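A rough sketch of how a config like the one above could be consumed. The actual wiring lives in the project's trainer and is not shown in this diff; the file path and the stand-in model below are illustrative only:

import json
import torch.nn as nn
import torch.optim as optim

with open('config.json') as f:          # illustrative path
    config = json.load(f)

model = nn.Linear(3 * 32 * 32, config['model_params']['num_classes'])   # stand-in model
optimizer = getattr(optim, config['optimizer_type'])(
    model.parameters(), **config['optimizer_params'])
scheduler = getattr(optim.lr_scheduler, config['lr_scheduler']['lr_scheduler_type'])(
    optimizer, **config['lr_scheduler']['additional_params'])
loss_fn = getattr(nn, config['loss'])()  # MultiLabelSoftMarginLoss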

data_loaders/data_loader.py

Lines changed: 12 additions & 11 deletions
@@ -42,13 +42,14 @@ def __init__(self, config):
                 dataset=datasets.CIFAR10(self.data_dir, train=True, download=True),
                 base_transforms=self.base_transforms,
                 augmentations=self.augmentations,
-                max_size=self.max_size
+                max_size=self.max_size,
+                train=True
             ),
-            'test': datasets.CIFAR10(
-                self.data_dir,
-                train=False,
-                download=True,
-                transform=self.base_transforms
+            'test': AutoAugmentDataset(
+                dataset=datasets.CIFAR10(self.data_dir, train=False, download=True),
+                base_transforms=self.base_transforms,
+                augmentations=self.augmentations,
+                train=False
             )
         }
         super(CIFAR10DataLoader, self).__init__(self.dataset, config)
@@ -74,11 +75,11 @@ def __init__(self, config):
                 augmentations=self.augmentations,
                 max_size=self.max_size
             ),
-            'test': datasets.SVHN(
-                self.data_dir,
-                split='test',
-                download=True,
-                transform=self.base_transforms
+            'test': AutoAugmentDataset(
+                dataset=datasets.SVHN(self.data_dir, split='test', download=True),
+                base_transforms=self.base_transforms,
+                augmentations=self.augmentations,
+                train=False
             )
         }
         super(SVHNDataLoader, self).__init__(self.dataset, config)

model/architectures/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
from .densenet_32x32 import (
    densenet121_32x32,
    densenet161_32x32,
    densenet169_32x32,
    densenet201_32x32
)
from .vgg_32x32 import (
    VGG16_32x32
)
model/architectures/densenet_32x32.py

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


def densenet121_32x32(num_classes):
    """ densenet121 that works with 32x32 input (e.g. CIFAR10) """
    return DenseNet(Bottleneck, [6, 12, 24, 16], growth_rate=32, num_classes=num_classes)


def densenet169_32x32(num_classes):
    """ densenet169 that works with 32x32 input (e.g. CIFAR10) """
    return DenseNet(Bottleneck, [6, 12, 32, 32], growth_rate=32, num_classes=num_classes)


def densenet201_32x32(num_classes):
    """ densenet201 that works with 32x32 input (e.g. CIFAR10) """
    return DenseNet(Bottleneck, [6, 12, 48, 32], growth_rate=32, num_classes=num_classes)


def densenet161_32x32(num_classes):
    """ densenet161 that works with 32x32 input (e.g. CIFAR10) """
    return DenseNet(Bottleneck, [6, 12, 36, 24], growth_rate=48, num_classes=num_classes)


class Bottleneck(nn.Module):
    def __init__(self, in_planes, growth_rate):
        super(Bottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, 4 * growth_rate, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(4 * growth_rate)
        self.conv2 = nn.Conv2d(4 * growth_rate, growth_rate, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = self.conv2(F.relu(self.bn2(out)))
        out = torch.cat([out, x], 1)
        return out


class Transition(nn.Module):
    def __init__(self, in_planes, out_planes):
        super(Transition, self).__init__()
        self.bn = nn.BatchNorm2d(in_planes)
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False)

    def forward(self, x):
        out = self.conv(F.relu(self.bn(x)))
        out = F.avg_pool2d(out, 2)
        return out


class DenseNet(nn.Module):
    def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10):
        super(DenseNet, self).__init__()
        self.growth_rate = growth_rate

        num_planes = 2 * growth_rate
        self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False)

        self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0])
        num_planes += nblocks[0] * growth_rate
        out_planes = int(math.floor(num_planes * reduction))
        self.trans1 = Transition(num_planes, out_planes)
        num_planes = out_planes

        self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1])
        num_planes += nblocks[1] * growth_rate
        out_planes = int(math.floor(num_planes * reduction))
        self.trans2 = Transition(num_planes, out_planes)
        num_planes = out_planes

        self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2])
        num_planes += nblocks[2] * growth_rate
        out_planes = int(math.floor(num_planes * reduction))
        self.trans3 = Transition(num_planes, out_planes)
        num_planes = out_planes

        self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3])
        num_planes += nblocks[3] * growth_rate

        self.bn = nn.BatchNorm2d(num_planes)
        self.linear = nn.Linear(num_planes, num_classes)

    def _make_dense_layers(self, block, in_planes, nblock):
        layers = []
        for i in range(nblock):
            layers.append(block(in_planes, self.growth_rate))
            in_planes += self.growth_rate
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.trans1(self.dense1(out))
        out = self.trans2(self.dense2(out))
        out = self.trans3(self.dense3(out))
        out = self.dense4(out)
        out = F.avg_pool2d(F.relu(self.bn(out)), 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out
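A quick shape sanity check for the 32x32 DenseNet variants. The import path is assumed from the __init__.py shown earlier, and num_classes=15 matches the model_params in the config above:

import torch
from model.architectures import densenet121_32x32

net = densenet121_32x32(num_classes=15)
x = torch.randn(2, 3, 32, 32)          # CIFAR10-sized batch
print(net(x).shape)                    # expected: torch.Size([2, 15])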

model/architectures/vgg_32x32.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
import torch.nn as nn


class VGG16_32x32(nn.Module):
    """ VGG16 that works with 32x32 input (e.g. CIFAR10)"""
    def __init__(self, num_classes):
        super(VGG16_32x32, self).__init__()
        self.layers = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']
        self.features = self._make_layers(self.layers)
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.BatchNorm2d(x),
                           nn.ReLU(inplace=True)]
                in_channels = x
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)
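For the VGG variant, the five 2x2 max-pools reduce a 32x32 input to a 1x1x512 feature map, which is why the classifier head is Linear(512, num_classes). A quick check under the same assumed import path:

import torch
from model.architectures import VGG16_32x32

net = VGG16_32x32(num_classes=15)
print(net(torch.randn(2, 3, 32, 32)).shape)   # expected: torch.Size([2, 15])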

model/metric.py

Lines changed: 17 additions & 1 deletion
@@ -3,6 +3,7 @@
     accuracy_score,
     fbeta_score,
     hamming_loss,
+    jaccard_similarity_score
 )


@@ -18,14 +19,16 @@ def get_metric_functions(metric_names):
             metric_fns.append(f_beta)
         elif metric_name == 'ham_loss':
             metric_fns.append(ham_loss)
+        elif metric_name == 'jaccard_similarity':
+            metric_fns.append(jaccard_similarity)
         else:
             raise NameError("Metric '{metric}' not found.".format(metric=metric_name))
     return metric_fns


 def accuracy(preds, targs, threshold=0.5):
     """
-    Accuracy classification score.
+    Exact match accuracy classification score.
     The set of labels predicted for a sample (preds) must exactly match the
     corresponding set of labels (targs)
     Args:
@@ -64,3 +67,16 @@ def ham_loss(preds, targs, threshold=0.5):

     """
     return hamming_loss(targs, (preds > threshold))
+
+
+def jaccard_similarity(preds, targs, threshold=0.5):
+    """
+    Jaccard similarity score for multi-label classification.
+    Intersection over union
+    Args:
+        preds: predicted targets as returned by a model
+        targs: ground truth target value
+        threshold: threshold, default is 0.5
+
+    """
+    return jaccard_similarity_score(targs, (preds > threshold))
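A small worked example of the thresholded multi-label metric, in plain NumPy: effectively the per-sample intersection-over-union of the predicted and true label sets, averaged over samples, which is what the sklearn score above computes:

import numpy as np

preds = np.array([[0.9, 0.2, 0.7],     # raw model outputs for two samples
                  [0.1, 0.6, 0.4]])
targs = np.array([[1, 0, 1],
                  [0, 1, 1]])

binary = preds > 0.5                   # threshold as in the metric functions above
inter = np.logical_and(binary, targs).sum(axis=1)
union = np.logical_or(binary, targs).sum(axis=1)
print((inter / union).mean())          # (2/2 + 1/2) / 2 = 0.75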
