Commit 0e0bd0f

committed
WIP: Add multithreading example with lag; need to draw boxes only instead of entire predicted image.
1 parent 13adb69 commit 0e0bd0f

File tree

1 file changed, +132 -0 lines changed

object_detection_multithreading.py

Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
import os
import cv2
import time
import argparse
import numpy as np
import tensorflow as tf

from queue import Queue
from threading import Thread
from utils import FPS, WebcamVideoStream
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util

CWD_PATH = os.getcwd()

# Path to the frozen detection graph. This is the actual model that is used for the object detection.
MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017'
PATH_TO_CKPT = os.path.join(CWD_PATH, 'object_detection', MODEL_NAME, 'frozen_inference_graph.pb')

# List of the strings that are used to add the correct label to each box.
PATH_TO_LABELS = os.path.join(CWD_PATH, 'object_detection', 'data', 'mscoco_label_map.pbtxt')

NUM_CLASSES = 90

# Load the label map.
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES,
                                                             use_display_name=True)
category_index = label_map_util.create_category_index(categories)


def detect_objects(image_np, sess, detection_graph):
    # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
    image_np_expanded = np.expand_dims(image_np, axis=0)
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')

    # Each box represents a part of the image where a particular object was detected.
    boxes = detection_graph.get_tensor_by_name('detection_boxes:0')

    # Each score represents the level of confidence for each of the objects.
    # The score is shown on the result image, together with the class label.
    scores = detection_graph.get_tensor_by_name('detection_scores:0')
    classes = detection_graph.get_tensor_by_name('detection_classes:0')
    num_detections = detection_graph.get_tensor_by_name('num_detections:0')

    # Actual detection.
    (boxes, scores, classes, num_detections) = sess.run(
        [boxes, scores, classes, num_detections],
        feed_dict={image_tensor: image_np_expanded})

    # Visualization of the results of a detection.
    vis_util.visualize_boxes_and_labels_on_image_array(
        image_np,
        np.squeeze(boxes),
        np.squeeze(classes).astype(np.int32),
        np.squeeze(scores),
        category_index,
        use_normalized_coordinates=True,
        line_thickness=8)
    return image_np


def worker(input_q, output_q):
    # Load a (frozen) TensorFlow model into memory.
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')

        sess = tf.Session(graph=detection_graph)

    # Pull frames from the input queue, run detection, and push the
    # predicted (annotated) frame onto the output queue.
    fps = FPS().start()
    while True:
        fps.update()
        frame = input_q.get()
        output_q.put(detect_objects(frame, sess, detection_graph))

    fps.stop()
    sess.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-src', '--source', dest='video_source', type=int,
                        default=0, help='Device index of the camera.')
    parser.add_argument('-wd', '--width', dest='width', type=int,
                        default=480, help='Width of the frames in the video stream.')
    parser.add_argument('-ht', '--height', dest='height', type=int,
                        default=360, help='Height of the frames in the video stream.')
    args = parser.parse_args()

    # Spawn a single worker thread that consumes frames from input_q and
    # produces predictions on output_q.
    input_q = Queue(30)
    output_q = Queue()
    for i in range(1):
        t = Thread(target=worker, args=(input_q, output_q))
        t.daemon = True
        t.start()

    video_capture = WebcamVideoStream(src=args.video_source,
                                      width=args.width,
                                      height=args.height).start()
    fps = FPS().start()

    while True:
        frame = video_capture.read()
        input_q.put(frame)

        t = time.time()

        if output_q.empty():
            # No prediction ready yet; show the raw frame.
            cv2.imshow('Video', frame)
        else:
            # TO-DO: draw only the detection boxes on the frame here instead
            # of dumping the entire predicted image returned by the worker.
            cv2.imshow('Video', frame)
            print(output_q.get())

        fps.update()

        print('[INFO] elapsed time: {:.2f}'.format(time.time() - t))

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    fps.stop()
    print('[INFO] elapsed time (total): {:.2f}'.format(fps.elapsed()))
    print('[INFO] approx. FPS: {:.2f}'.format(fps.fps()))

    video_capture.stop()
    cv2.destroyAllWindows()
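
A possible next step for the TO-DO above, in line with the commit message ("need to draw boxes only instead of entire predicted image"), is sketched below. This is not part of the commit: it assumes the worker is changed to put the raw detection arrays on output_q through a hypothetical helper, detect_objects_data_only, and that the main loop then overlays the boxes on the most recent frame with the same vis_util call that detect_objects already uses. It reuses the module's existing imports (np, vis_util, category_index).

# Sketch only, not part of this commit: the worker returns detection data
# instead of a fully rendered frame, and the main thread draws the boxes.
def detect_objects_data_only(image_np, sess, detection_graph):
    # Run the same tensors as detect_objects(), but return the raw arrays.
    image_np_expanded = np.expand_dims(image_np, axis=0)
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
    boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
    scores = detection_graph.get_tensor_by_name('detection_scores:0')
    classes = detection_graph.get_tensor_by_name('detection_classes:0')
    num_detections = detection_graph.get_tensor_by_name('num_detections:0')
    (boxes, scores, classes, num_detections) = sess.run(
        [boxes, scores, classes, num_detections],
        feed_dict={image_tensor: image_np_expanded})
    return np.squeeze(boxes), np.squeeze(scores), np.squeeze(classes)

# In worker():
#     output_q.put(detect_objects_data_only(frame, sess, detection_graph))
#
# In the main loop, instead of print(output_q.get()):
#     boxes, scores, classes = output_q.get()
#     vis_util.visualize_boxes_and_labels_on_image_array(
#         frame,
#         boxes,
#         classes.astype(np.int32),
#         scores,
#         category_index,
#         use_normalized_coordinates=True,
#         line_thickness=8)
#     cv2.imshow('Video', frame)

Keeping the drawing on the main thread avoids pushing large annotated image arrays through the queue and keeps all cv2.imshow calls in a single thread, which OpenCV's HighGUI generally expects.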
