Skip to content

Commit 193c594

Browse files
committed
test code
1 parent dab5147 commit 193c594

File tree

2 files changed

+151
-0
lines changed

2 files changed

+151
-0
lines changed

config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,6 @@ def argument_parser():
2222
parser.add_argument('--device', default=0, type=str, help='gpu device ids for CUDA_VISIBLE_DEVICES')
2323
parser.add_argument("--redirector", action='store_false')
2424
parser.add_argument('--use_bn', action='store_false')
25+
parser.add_argument("--checkpoint", type=str)
2526

2627
return parser

test.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
#%%writefile test.py
2+
3+
import os
4+
import pprint
5+
from collections import OrderedDict, defaultdict
6+
7+
import numpy as np
8+
import torch
9+
from torch.optim.lr_scheduler import ReduceLROnPlateau
10+
from torch.utils.data import DataLoader
11+
12+
from batch_engine import valid_trainer, batch_trainer
13+
from config import argument_parser
14+
from dataset.AttrDataset import AttrDataset, get_transform
15+
from loss.CE_loss import CEL_Sigmoid
16+
from models.base_block import FeatClassifier, BaseClassifier
17+
from models.vgg import vgg16
18+
from tools.function import get_model_log_path, get_pedestrian_metrics
19+
from tools.utils import time_str, save_ckpt, ReDirectSTD, set_seed
20+
21+
set_seed(605)
22+
23+
24+
def main(args):
25+
visenv_name = args.dataset
26+
exp_dir = os.path.join('exp_result', args.dataset)
27+
model_dir, log_dir = get_model_log_path(exp_dir, visenv_name)
28+
stdout_file = os.path.join(log_dir, f'stdout_{time_str()}.txt')
29+
30+
checkpoint_file = args.checkpoint
31+
32+
if args.redirector:
33+
print('redirector stdout')
34+
ReDirectSTD(stdout_file, 'stdout', False)
35+
36+
pprint.pprint(OrderedDict(args.__dict__))
37+
38+
print('-' * 60)
39+
print(f'use GPU{args.device} for testing')
40+
#print(f'train set: {args.dataset} {args.train_split}, test set: {args.valid_split}')
41+
42+
train_tsfm, valid_tsfm = get_transform(args)
43+
#print(train_tsfm)
44+
45+
train_set = AttrDataset(args=args, split=args.train_split, transform=train_tsfm)
46+
47+
train_loader = DataLoader(
48+
dataset=train_set,
49+
batch_size=args.batchsize,
50+
shuffle=True,
51+
num_workers=4,
52+
pin_memory=True,
53+
)
54+
valid_set = AttrDataset(args=args, split=args.valid_split, transform=valid_tsfm)
55+
56+
valid_loader = DataLoader(
57+
dataset=valid_set,
58+
batch_size=args.batchsize,
59+
shuffle=False,
60+
num_workers=4,
61+
pin_memory=True,
62+
)
63+
64+
print(f'{args.train_split} set: {len(train_loader.dataset)}, '
65+
f'{args.valid_split} set: {len(valid_loader.dataset)}, '
66+
f'attr_num : {train_set.attr_num}')
67+
68+
labels = train_set.label
69+
sample_weight = labels.mean(0)
70+
71+
backbone = vgg16()
72+
classifier = BaseClassifier(nattr=train_set.attr_num)
73+
model = FeatClassifier(backbone, classifier)
74+
75+
#model.load_state_dict(torch.load(filename))
76+
if torch.cuda.is_available():
77+
model = torch.nn.DataParallel(model).cuda()
78+
79+
#print(checkpoint_file['state_dicts'])
80+
checkpoint = torch.load(checkpoint_file)
81+
model.load_state_dict(checkpoint["state_dicts"])
82+
83+
criterion = CEL_Sigmoid(sample_weight)
84+
85+
param_groups = [{'params': model.module.finetune_params(), 'lr': args.lr_ft},
86+
{'params': model.module.fresh_params(), 'lr': args.lr_new}]
87+
optimizer = torch.optim.SGD(param_groups, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=False)
88+
lr_scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=4)
89+
90+
91+
92+
tester(epoch=args.train_epoch,
93+
model=model,
94+
train_loader=train_loader,
95+
valid_loader=valid_loader,
96+
criterion=criterion,
97+
optimizer=optimizer,
98+
lr_scheduler=lr_scheduler,
99+
)
100+
101+
102+
103+
104+
def tester(epoch, model, train_loader, valid_loader, criterion, optimizer, lr_scheduler):
105+
maximum = float(-np.inf)
106+
best_epoch = 0
107+
108+
result_list = defaultdict()
109+
110+
111+
112+
valid_loss, valid_gt, valid_probs = valid_trainer(
113+
model=model,
114+
valid_loader=valid_loader,
115+
criterion=criterion,
116+
)
117+
118+
lr_scheduler.step(metrics=valid_loss, epoch=1)
119+
120+
121+
valid_result = get_pedestrian_metrics(valid_gt, valid_probs)
122+
123+
print(f'Evaluation on test set, \n',
124+
'ma: {:.4f}, pos_recall: {:.4f} , neg_recall: {:.4f} \n'.format(
125+
valid_result.ma, np.mean(valid_result.label_pos_recall), np.mean(valid_result.label_neg_recall)),
126+
'Acc: {:.4f}, Prec: {:.4f}, Rec: {:.4f}, F1: {:.4f}'.format(
127+
valid_result.instance_acc, valid_result.instance_prec, valid_result.instance_recall,
128+
valid_result.instance_f1))
129+
130+
print(f'{time_str()}')
131+
print('-' * 60)
132+
133+
134+
135+
136+
if __name__ == '__main__':
137+
parser = argument_parser()
138+
139+
args = parser.parse_args()
140+
print(args)
141+
main(args)
142+
143+
# os.path.abspath()
144+
145+
"""
146+
载入的时候要:
147+
from tools.function import LogVisual
148+
sys.modules['LogVisual'] = LogVisual
149+
log = torch.load('./save/2018-10-29_21:17:34trlog')
150+
"""

0 commit comments

Comments
 (0)