Skip to content

Commit 821ddc1

Browse files
committed
Building train function for AL
1 parent 5605542 commit 821ddc1

16 files changed

+760
-215
lines changed

activelearning/config.py renamed to AL_config.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,10 @@
33
unlabeled = 'data/unlabeled'
44
labeled = 'data/labeled'
55
num_select = 8
6-
info_predict_path = 'activelearning/predict.txt'
7-
mode_active = 'sum' # 'sum or mean
86

97
# Config general
108
project = "/media/thang/New Volume/Active-learning-for-object-detection/"
11-
weight = 'activelearning/yolov5s.pt'
9+
weight = 'yolov5s.pt'
1210
device = '0' # cpu or 0,1,...
1311
name = 'gun'
1412
exist_ok = 1

AL_detect.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from utils.datasets import LoadImages
55
from utils.general import check_img_size, non_max_suppression, scale_coords, xyxy2xywh, set_logging
66
from utils.torch_utils import select_device, time_synchronized
7-
import activelearning.config as config
7+
import AL_config as config
88
import time
99

1010
def AL_detect(opt):

AL_run.py

Lines changed: 60 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,53 @@
1818
1919
"""
2020

21+
from torch.cuda.memory import reset_accumulated_memory_stats
2122
from AL_yolov5 import Yolov5
22-
import activelearning.select_function as select
23-
import activelearning.config as config
23+
import AL_config as config
2424
import glob
2525
import os
2626
from shutil import copyfile, move
2727
import io
2828
import copy
29+
def RandomSelect(num_select, result):
30+
pass
31+
32+
def UncertaintySamplingBinary(num_select, result, typ):
33+
"""
34+
result =
35+
{"<link ảnh>":
36+
[
37+
{"class": cls.item(), "box": [x,y,w,h], "conf": conf.item(),
38+
...
39+
],
40+
...
41+
}
42+
"""
43+
probas = {}
44+
if typ == 'sum':
45+
for item, lst_dic in result.items():
46+
conf = 0
47+
for dic in lst_dic:
48+
conf += (1.0 - dic["conf"])
49+
probas[item] = conf
50+
elif typ == 'avg':
51+
for item, lst_dic in result.items():
52+
conf = 0
53+
for dic in lst_dic:
54+
conf += (1.0 - dic["conf"])
55+
probas[item] = conf/len(lst_dic)
56+
elif typ == 'max':
57+
for item, lst_dic in result.items():
58+
conf = 0
59+
for dic in lst_dic:
60+
conf = max(conf, 1.0 - dic["conf"])
61+
probas[item] = conf
62+
return sorted(probas, key=probas.get, reverse=True)[:num_select]
63+
2964

3065
class ActiveLearning(object):
31-
def __init__(self, model, select_function):
66+
def __init__(self, model):
3267
self.model = model
33-
self.select_function = select_function
3468
self.num_select = config.num_select
3569
self.type = 'sum' # 'avg' , 'max', 'sum'
3670

@@ -41,46 +75,32 @@ def run(self):
4175
while queried < config.max_queried:
4276
# Dự đoán các ảnh trong tập unlabeled
4377
result = self.model.detect()
44-
# Tổng hợp kết quả
45-
probas = {str(file.split('/')[-1]): 0.0 for file in glob.glob(config.source + '/*')}
46-
num_object = probas.copy()
47-
48-
with open(config.info_predict_path, 'r') as f:
49-
for line in f.readlines():
50-
*cxywh, prob, file_name = line.split(',')
51-
file_name = file_name[:-1]
52-
probas[file_name] += 1.0 - float(prob)
53-
num_object[file_name] += 1
54-
55-
if config.mode_active == 'mean':
56-
for key, value in probas.items():
57-
probas[key] = value/num_object[key]
58-
59-
# Chọn ra k samples
60-
if len(probas) >= self.num_select:
61-
U_best = self.select_function.select(self.num_select, probas)
62-
63-
# Gán nhãn cho U_best
64-
""" GIẢ SỬ ĐOẠN NÀY LÀ NGƯỜI GÁN """
65-
# Tạo ra các file label cho U_best vào thư mục labeled
78+
79+
# Chon ra k ảnh có score cao nhất
80+
# Sử dụng lấy mẫu không chắc chắn
81+
if len(result) >= self.num_select:
82+
U_best = UncertaintySamplingBinary(self.num_select, result, 'sum')
83+
print(U_best)
84+
85+
# Gán nhãn cho các file trong samples (Người tương tác)
86+
87+
# Duyệt tất cả các file được chọn
6688
for f in U_best:
6789
# Chuyển file ảnh vào thự mục labeled
68-
move(os.path.join(config.unlabeled, f), os.path.join(config.labeled, f))
90+
move(f, f.replace("unlabeled", "labeled"))
6991
# Tạo file nhãn vào thư mục labeled
70-
file_name = '.'.join(f.split('.')[:-1])
71-
# print(f)
72-
# source = open(os.path.join('data/gun/',f),"r",encoding='utf-8').readlines()
73-
# dect = open(os.path.join(config.labeled, file_name + '.txt', "w", encoding='utf-8').write(copy.copy(source)))
74-
copyfile(os.path.join('data/gun/', file_name+'.txt'), os.path.join(config.labeled, file_name + '.txt'))
75-
76-
# Cập nhật file train.txt
77-
with open('data/train.txt', 'w') as f:
78-
for fn in U_best:
79-
f.write(config.project + config.labeled + "/" + fn + "\n")
80-
92+
type_file = f.split('.')[-1]
93+
copyfile(f.replace("unlabeled","gun").replace(type_file, 'txt'), f.replace("unlabeled", "labeled").replace(type_file, 'txt'))
94+
95+
# thêm danh sách file đã gán nhãn vào dữ liệu train
96+
with open('data/train.txt', "a") as f:
97+
for file_name in U_best:
98+
f.write(file_name.replace("unlabeled","labeled") + '\n')
99+
81100
# Train model
82101
self.model.train()
83102

103+
####################### LOADING ########################
84104
# Xoá file weight cũ
85105
if os.path.exists(config.weight):
86106
os.remove(config.weight)
@@ -95,7 +115,5 @@ def run(self):
95115

96116

97117
if __name__ == '__main__':
98-
bot = ActiveLearning(model=Yolov5(), select_function=select.RandomSelect())
99-
bot.run()
100-
# model = Yolov5()
101-
# model.train()
118+
bot = ActiveLearning(model=Yolov5())
119+
bot.run()

0 commit comments

Comments
 (0)