55import glob
66import shutil
77sys .path .append (os .getcwd ())
8-
98from lib .networks .factory import get_network
109from lib .fast_rcnn .config import cfg ,cfg_from_file
1110from lib .fast_rcnn .test import test_ctpn
12- from lib .fast_rcnn .nms_wrapper import nms
1311from lib .utils .timer import Timer
14- from text_proposal_connector import TextProposalConnector
15-
16- CLASSES = ('__background__' ,
17- 'text' )
18-
19-
20- def connect_proposal (text_proposals , scores , im_size ):
21- cp = TextProposalConnector ()
22- line = cp .get_text_lines (text_proposals , scores , im_size )
23- return line
24-
25- def save_results (image_name ,im ,im_scale ,line ,thresh ):
26- inds = np .where (line [:,- 1 ]>= thresh )[0 ]
27- image_name = image_name .split ('/' )[- 1 ]
28- if len (inds )== 0 :
29- im = cv2 .resize (im , None , None , fx = 1.0 / im_scale , fy = 1.0 / im_scale , interpolation = cv2 .INTER_LINEAR )
30- cv2 .imwrite (os .path .join ("data/results" ,image_name ),im )
31- return
12+ from lib .text_connector .detectors import TextDetector
13+ from lib .text_connector .text_connect_cfg import Config as TextLineCfg
3214
33- for i in inds :
34- bbox = line [i ,:4 ]
35- score = line [i ,- 1 ]
36- cv2 .rectangle (im ,(bbox [0 ],bbox [1 ]),(bbox [2 ],bbox [3 ]),color = (0 ,255 ,0 ),thickness = 2 )
37- im = cv2 .resize (im , None , None , fx = 1.0 / im_scale , fy = 1.0 / im_scale , interpolation = cv2 .INTER_LINEAR )
38- cv2 .imwrite (os .path .join ("data/results" ,image_name ),im )
3915
def resize_im(im, scale, max_scale=None):
    """Resize *im* so its shorter side matches *scale*, optionally capping the longer side.

    Args:
        im: Image array; only ``im.shape[0]`` and ``im.shape[1]`` are read
            to compute the resize factor.
        scale: Target length for the shorter of the first two dimensions.
        max_scale: Optional upper bound for the longer dimension after
            scaling. ``None`` disables the cap.

    Returns:
        A tuple ``(resized, f)`` where ``resized`` is the output of
        ``cv2.resize`` and ``f`` is the factor applied to both axes.
    """
    short_side = min(im.shape[0], im.shape[1])
    long_side = max(im.shape[0], im.shape[1])

    f = float(scale) / short_side
    # If scaling the short side to `scale` pushes the long side past the cap,
    # fall back to the factor that makes the long side exactly `max_scale`.
    if max_scale is not None and f * long_side > max_scale:
        f = float(max_scale) / long_side
    return cv2.resize(im, None, None, fx=f, fy=f, interpolation=cv2.INTER_LINEAR), f
4021
41- def check_img (img ):
42- img_size = img .shape
43- im_size_min = np .min (img_size [0 :2 ])
44- im_size_max = np .max (img_size [0 :2 ])
4522
46- im_scale = float (600 ) / float (im_size_min )
47- if np .round (im_scale * im_size_max ) > 1200 :
48- im_scale = float (1200 ) / float (im_size_max )
49- re_im = cv2 .resize (img , None , None , fx = im_scale , fy = im_scale , interpolation = cv2 .INTER_LINEAR )
50- return re_im , im_scale
def draw_boxes(img, image_name, boxes, scale):
    """Draw every box as a green quadrilateral on *img* and write it to ``data/results``.

    Args:
        img: Image to annotate; the lines are drawn directly onto this array.
        image_name: Path of the source image. Only its final path component
            is used, as the output file name.
        boxes: Iterable of rows whose first eight values are read as four
            ``(x, y)`` corner pairs at offsets 0, 2, 4 and 6.
        scale: The annotated image is resized by ``1.0 / scale`` before it is
            written.
    """
    green = (0, 255, 0)
    # Each pair holds the offsets of the two corners joined by one edge.
    edges = ((0, 2), (0, 4), (6, 2), (4, 6))
    for box in boxes:
        for start, end in edges:
            cv2.line(img,
                     (int(box[start]), int(box[start + 1])),
                     (int(box[end]), int(box[end + 1])),
                     green, 2)

    # os.path.basename understands the platform's separators, unlike a
    # hard-coded split on '/', which keeps directories on Windows paths.
    base_name = os.path.basename(image_name)
    img = cv2.resize(img, None, None, fx=1.0 / scale, fy=1.0 / scale, interpolation=cv2.INTER_LINEAR)
    cv2.imwrite(os.path.join("data/results", base_name), img)
5233
def ctpn(sess, net, image_name):
    """Detect text in one image file, save the annotated image and print the timing.

    The timer is started before the image is read and stopped after the
    annotated result has been written, so the reported time covers loading,
    detection and drawing/saving.

    Args:
        sess: Session object forwarded unchanged to ``test_ctpn``.
        net: Network object forwarded unchanged to ``test_ctpn``.
        image_name: Path of the image to process.

    Raises:
        IOError: If ``cv2.imread`` cannot load ``image_name``.
    """
    timer = Timer()
    timer.tic()

    img = cv2.imread(image_name)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising;
        # fail here with the file name rather than with an AttributeError
        # from the first ``.shape`` access further down.
        raise IOError('Could not read image {}'.format(image_name))
    img, scale = resize_im(img, scale=TextLineCfg.SCALE, max_scale=TextLineCfg.MAX_SCALE)
    scores, boxes = test_ctpn(sess, net, img)

    textdetector = TextDetector()
    # Scores are passed as a column vector alongside the (height, width) of the resized image.
    boxes = textdetector.detect(boxes, scores[:, np.newaxis], img.shape[:2])
    draw_boxes(img, image_name, boxes, scale)
    timer.toc()
    print(('Detection took {:.3f}s for '
           '{:d} object proposals').format(timer.total_time, boxes.shape[0]))
6248
63- # Visualize detections for each class
64- CONF_THRESH = 0.9
65- NMS_THRESH = 0.3
66- dets = np .hstack ((boxes , scores [:, np .newaxis ])).astype (np .float32 )
67- keep = nms (dets , NMS_THRESH )
68- dets = dets [keep , :]
69-
70- keep = np .where (dets [:, 4 ] >= 0.7 )[0 ]
71- dets = dets [keep , :]
72- line = connect_proposal (dets [:, 0 :4 ], dets [:, 4 ], im .shape )
73- save_results (image_name , im ,im_scale , line ,thresh = 0.9 )
7449
7550
7651if __name__ == '__main__' :
@@ -91,16 +66,12 @@ def ctpn(sess, net, image_name):
9166
9267 try :
9368 ckpt = tf .train .get_checkpoint_state (cfg .TEST .checkpoints_path )
94- #ckpt=tf.train.get_checkpoint_state("output/ctpn_end2end/voc_2007_trainval/")
9569 print ('Restoring from {}...' .format (ckpt .model_checkpoint_path ), end = ' ' )
9670 saver .restore (sess , ckpt .model_checkpoint_path )
9771 print ('done' )
9872 except :
9973 raise 'Check your pretrained {:s}' .format (ckpt .model_checkpoint_path )
100- print (' done.' )
10174
102- #saver.restore(sess, os.path.join(os.getcwd(),"checkpoints/model_final_tf13.ckpt"))
103- # Warmup on a dummy image
10475 im = 128 * np .ones ((300 , 300 , 3 ), dtype = np .uint8 )
10576 for i in range (2 ):
10677 _ , _ = test_ctpn (sess , net , im )
@@ -112,3 +83,4 @@ def ctpn(sess, net, image_name):
11283 print ('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~' )
11384 print (('Demo for {:s}' .format (im_name )))
11485 ctpn (sess , net , im_name )
86+
0 commit comments