
Commit 43b8916

Add files via upload
1 parent 4d31c9e commit 43b8916

File tree: 2 files changed (+368, -0 lines)
Lines changed: 285 additions & 0 deletions
@@ -0,0 +1,285 @@
"""
@author: Ying Jin
@contact: sherryying003@gmail.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F

import utils
from tllib.self_training.mcc import MinimumClassConfusionLoss, ImageClassifier
from tllib.self_training.cc_loss import CCConsistency
from tllib.vision.transforms import MultipleApply
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def main(args: argparse.Namespace):
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_source_transform = utils.get_train_transform(args.train_resizing, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
                                                       random_horizontal_flip=not args.no_hflip,
                                                       random_color_jitter=False, resize_size=args.resize_size,
                                                       norm_mean=args.norm_mean, norm_std=args.norm_std)
    weak_augment = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
                                             random_horizontal_flip=not args.no_hflip,
                                             random_color_jitter=False, resize_size=args.resize_size,
                                             norm_mean=args.norm_mean, norm_std=args.norm_std)
    strong_augment = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
                                               random_horizontal_flip=not args.no_hflip,
                                               random_color_jitter=False, resize_size=args.resize_size,
                                               norm_mean=args.norm_mean, norm_std=args.norm_std,
                                               auto_augment=args.auto_augment)
    # each target image yields a (weak, strong) augmentation pair
    train_target_transform = MultipleApply([weak_augment, strong_augment])
    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
                                            norm_mean=args.norm_mean, norm_std=args.norm_std)
    print("train_source_transform: ", train_source_transform)
    print("train_target_transform: ", train_target_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_source_transform, val_transform,
                          train_target_transform=train_target_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                 pool_layer=pool_layer, finetune=not args.scratch).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay,
                    nesterov=True)
    # inverse decay schedule: the factor (1 + gamma * iter) ** (-decay) is updated every iteration
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analyze the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure of distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        print("lr:", lr_scheduler.get_last_lr()[0])
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, optimizer,
              lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()


def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: ImageClassifier, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')

    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, trans_losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # define loss functions
    mcc = MinimumClassConfusionLoss(temperature=args.temperature)
    consistency = CCConsistency(temperature=args.temperature, thr=args.thr)

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)[:2]
        # labels_t stays unused: target-domain labels are withheld in unsupervised adaptation
        (x_t, x_t_strong), labels_t = next(train_target_iter)[:2]

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        x_t_strong = x_t_strong.to(device)
        labels_s = labels_s.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output: one forward pass over source, weak-target and strong-target batches
        x = torch.cat((x_s, x_t, x_t_strong), dim=0)
        y, f = model(x)
        y_s, y_t, y_t_strong = y.chunk(3, dim=0)

        cls_loss = F.cross_entropy(y_s, labels_s)
        mcc_loss = mcc(y_t)
        # consistency between weak- and strong-augmentation predictions on the target domain;
        # selec_ratio reports how many target samples pass the confidence threshold
        consistency_loss, selec_ratio = consistency(y_t, y_t_strong)
        loss = cls_loss + mcc_loss * args.trade_off + consistency_loss * args.trade_off_consistency
        transfer_loss = mcc_loss * args.trade_off + consistency_loss * args.trade_off_consistency

        cls_acc = accuracy(y_s, labels_s)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        trans_losses.update(transfer_loss.item(), x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='CC Loss for Unsupervised Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')
    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    parser.add_argument('--resize-size', type=int, default=224,
                        help='the image size after resizing')
    parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                        help='random resize scale (default: 0.08 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                        help='random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--no-hflip', action='store_true',
                        help='no random horizontal flipping during training')
    parser.add_argument('--norm-mean', type=float, nargs='+',
                        default=(0.485, 0.456, 0.406), help='normalization mean')
    parser.add_argument('--norm-std', type=float, nargs='+',
                        default=(0.229, 0.224, 0.225), help='normalization std')
    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,
                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--bottleneck-dim', default=256, type=int,
                        help='dimension of the bottleneck layer')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor')
    parser.add_argument('--scratch', action='store_true', help='whether to train from scratch')
    parser.add_argument('--temperature', default=2.5, type=float, help='temperature for rescaling predictions')
    parser.add_argument('--thr', default=0.95, type=float, help='confidence threshold for the consistency loss')
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for the original MCC loss')
    parser.add_argument('--trade-off-consistency', default=1., type=float,
                        help='the trade-off hyper-parameter for the consistency loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=36, type=int,
                        metavar='N',
                        help='mini-batch size (default: 36)')
    parser.add_argument('--lr', '--learning-rate', default=0.005, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.001, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,
                        metavar='W', help='weight decay (default: 1e-3)',
                        dest='weight_decay')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,
                        help='number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether to output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='mcc',
                        help="where to save logs, checkpoints and debugging images")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model. "
                             "When phase is 'analysis', only analyze the model.")
    args = parser.parse_args()
    main(args)
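
For reference, a plausible invocation of this script on Office-31. The filename train_cc.py, the data root, and the log directory are illustrative, not taken from this commit:

    python train_cc.py data/office31 -d Office31 -s A -t W -a resnet18 --epochs 20 --seed 2 --log logs/cc_mcc/Office31_A2W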

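For readers without the tllib source at hand, MinimumClassConfusionLoss refers to the Minimum Class Confusion loss of Jin et al. (ECCV 2020). A minimal sketch of that computation, assuming the paper's temperature rescaling and entropy-based sample weighting (an illustration, not the tllib code):

    import torch
    import torch.nn.functional as F

    def mcc_loss_sketch(logits: torch.Tensor, temperature: float = 2.5) -> torch.Tensor:
        # temperature-rescaled class probabilities, shape (batch, classes)
        probs = F.softmax(logits / temperature, dim=1)
        # entropy-based weights: more confident samples contribute more
        entropy = -(probs * torch.log(probs + 1e-5)).sum(dim=1)
        weight = 1.0 + torch.exp(-entropy)
        weight = weight / weight.sum() * logits.size(0)
        # weighted class-confusion matrix, shape (classes, classes)
        confusion = (probs * weight.unsqueeze(1)).t() @ probs
        # row-normalize, then penalize inter-class (off-diagonal) confusion
        confusion = confusion / confusion.sum(dim=1, keepdim=True)
        return (confusion.sum() - confusion.trace()) / logits.size(1)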
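CCConsistency is defined in tllib.self_training.cc_loss, which is not part of this diff, so the sketch below is only a guess at its general shape: a FixMatch-style consistency term in which temperature-scaled weak-view predictions provide pseudo-labels for the strong view, a confidence threshold thr selects reliable samples, and the selection ratio is returned alongside the loss. Treat every detail as an assumption.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class ThresholdedConsistencySketch(nn.Module):
        # illustrative stand-in, not the tllib CCConsistency implementation
        def __init__(self, temperature: float = 2.5, thr: float = 0.95):
            super().__init__()
            self.temperature = temperature
            self.thr = thr

        def forward(self, y_weak: torch.Tensor, y_strong: torch.Tensor):
            # pseudo-labels from the weakly augmented view (no gradient)
            probs = F.softmax(y_weak.detach() / self.temperature, dim=1)
            max_prob, pseudo_labels = probs.max(dim=1)
            # keep only confident target samples
            mask = (max_prob >= self.thr).float()
            # cross-entropy of the strong view against the weak-view pseudo-labels
            loss = (F.cross_entropy(y_strong, pseudo_labels, reduction='none') * mask).mean()
            return loss, mask.mean()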