1818
1919"""
2020
21+ from torch .cuda .memory import reset_accumulated_memory_stats
2122from AL_yolov5 import Yolov5
22- import activelearning .select_function as select
23- import activelearning .config as config
23+ import AL_config as config
2424import glob
2525import os
2626from shutil import copyfile , move
2727import io
2828import copy
29+ def RandomSelect (num_select , result ):
30+ pass
31+
32+ def UncertaintySamplingBinary (num_select , result , typ ):
33+ """
34+ result =
35+ {"<link ảnh>":
36+ [
37+ {"class": cls.item(), "box": [x,y,w,h], "conf": conf.item(),
38+ ...
39+ ],
40+ ...
41+ }
42+ """
43+ probas = {}
44+ if typ == 'sum' :
45+ for item , lst_dic in result .items ():
46+ conf = 0
47+ for dic in lst_dic :
48+ conf += (1.0 - dic ["conf" ])
49+ probas [item ] = conf
50+ elif typ == 'avg' :
51+ for item , lst_dic in result .items ():
52+ conf = 0
53+ for dic in lst_dic :
54+ conf += (1.0 - dic ["conf" ])
55+ probas [item ] = conf / len (lst_dic )
56+ elif typ == 'max' :
57+ for item , lst_dic in result .items ():
58+ conf = 0
59+ for dic in lst_dic :
60+ conf = max (conf , 1.0 - dic ["conf" ])
61+ probas [item ] = conf
62+ return sorted (probas , key = probas .get , reverse = True )[:num_select ]
63+
2964
3065class ActiveLearning (object ):
31- def __init__ (self , model , select_function ):
66+ def __init__ (self , model ):
3267 self .model = model
33- self .select_function = select_function
3468 self .num_select = config .num_select
3569 self .type = 'sum' # 'avg' , 'max', 'sum'
3670
@@ -41,46 +75,32 @@ def run(self):
4175 while queried < config .max_queried :
4276 # Dự đoán các ảnh trong tập unlabeled
4377 result = self .model .detect ()
44- # Tổng hợp kết quả
45- probas = {str (file .split ('/' )[- 1 ]): 0.0 for file in glob .glob (config .source + '/*' )}
46- num_object = probas .copy ()
47-
48- with open (config .info_predict_path , 'r' ) as f :
49- for line in f .readlines ():
50- * cxywh , prob , file_name = line .split (',' )
51- file_name = file_name [:- 1 ]
52- probas [file_name ] += 1.0 - float (prob )
53- num_object [file_name ] += 1
54-
55- if config .mode_active == 'mean' :
56- for key , value in probas .items ():
57- probas [key ] = value / num_object [key ]
58-
59- # Chọn ra k samples
60- if len (probas ) >= self .num_select :
61- U_best = self .select_function .select (self .num_select , probas )
62-
63- # Gán nhãn cho U_best
64- """ GIẢ SỬ ĐOẠN NÀY LÀ NGƯỜI GÁN """
65- # Tạo ra các file label cho U_best vào thư mục labeled
78+
79+ # Chon ra k ảnh có score cao nhất
80+ # Sử dụng lấy mẫu không chắc chắn
81+ if len (result ) >= self .num_select :
82+ U_best = UncertaintySamplingBinary (self .num_select , result , 'sum' )
83+ print (U_best )
84+
85+ # Gán nhãn cho các file trong samples (Người tương tác)
86+
87+ # Duyệt tất cả các file được chọn
6688 for f in U_best :
6789 # Chuyển file ảnh vào thự mục labeled
68- move (os . path . join ( config . unlabeled , f ), os . path . join ( config . labeled , f ))
90+ move (f , f . replace ( "unlabeled" , "labeled" ))
6991 # Tạo file nhãn vào thư mục labeled
70- file_name = '.' .join (f .split ('.' )[:- 1 ])
71- # print(f)
72- # source = open(os.path.join('data/gun/',f),"r",encoding='utf-8').readlines()
73- # dect = open(os.path.join(config.labeled, file_name + '.txt', "w", encoding='utf-8').write(copy.copy(source)))
74- copyfile (os .path .join ('data/gun/' , file_name + '.txt' ), os .path .join (config .labeled , file_name + '.txt' ))
75-
76- # Cập nhật file train.txt
77- with open ('data/train.txt' , 'w' ) as f :
78- for fn in U_best :
79- f .write (config .project + config .labeled + "/" + fn + "\n " )
80-
92+ type_file = f .split ('.' )[- 1 ]
93+ copyfile (f .replace ("unlabeled" ,"gun" ).replace (type_file , 'txt' ), f .replace ("unlabeled" , "labeled" ).replace (type_file , 'txt' ))
94+
95+ # thêm danh sách file đã gán nhãn vào dữ liệu train
96+ with open ('data/train.txt' , "a" ) as f :
97+ for file_name in U_best :
98+ f .write (file_name .replace ("unlabeled" ,"labeled" ) + '\n ' )
99+
81100 # Train model
82101 self .model .train ()
83102
103+ ####################### LOADING ########################
84104 # Xoá file weight cũ
85105 if os .path .exists (config .weight ):
86106 os .remove (config .weight )
@@ -95,7 +115,5 @@ def run(self):
95115
96116
97117if __name__ == '__main__' :
98- bot = ActiveLearning (model = Yolov5 (), select_function = select .RandomSelect ())
99- bot .run ()
100- # model = Yolov5()
101- # model.train()
118+ bot = ActiveLearning (model = Yolov5 ())
119+ bot .run ()
0 commit comments