
Commit 915371b

Improving speed by moving video capture to multithreading.
1 parent 834da03 commit 915371b

File tree

3 files changed (+188, −10 lines)


object_detection_app.py

Lines changed: 10 additions & 10 deletions
@@ -4,7 +4,7 @@
 import numpy as np
 import tensorflow as tf
 
-from utils import FPS
+from utils import FPS, WebcamVideoStream
 from object_detection.utils import label_map_util
 from object_detection.utils import visualization_utils as vis_util
 
@@ -54,7 +54,6 @@ def detect_objects(image_np, sess, detection_graph):
         category_index,
         use_normalized_coordinates=True,
         line_thickness=8)
-
     return image_np
 
 
@@ -70,18 +69,19 @@ def detect_objects(image_np, sess, detection_graph):
 
     sess = tf.Session(graph=detection_graph)
 
-    video_capture = cv2.VideoCapture(0)
-    video_capture.set(cv2.CAP_PROP_FRAME_WIDTH, 480)
-    video_capture.set(cv2.CAP_PROP_FRAME_HEIGHT, 360)
-
+    video_capture = WebcamVideoStream(src=0).start()
     fps = FPS().start()
 
-    while True:
-        ret, frame = video_capture.read()
+    while fps._numFrames < 120:
+        frame = video_capture.read()
 
         t = time.time()
 
-        cv2.imshow('Video', detect_objects(frame, sess, detection_graph))
+        # cv2.imshow('Video', detect_objects(frame, sess, detection_graph))
+
+        # time.sleep(2)
+
+        detect_objects(frame, sess, detection_graph)
 
         print(time.time() - t)
 
@@ -97,6 +97,6 @@ def detect_objects(image_np, sess, detection_graph):
     print('[INFO] approx. FPS: {:.2f}'.format(fps.fps()))
 
     # When everything is done, release the capture
-    video_capture.release()
+    video_capture.stop()
     cv2.destroyAllWindows()
     sess.close()
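
The change above swaps the blocking cv2.VideoCapture for the threaded WebcamVideoStream (added in utils.py below), caps the loop at 120 frames, and comments out the imshow call, so the FPS counter now measures inference throughput rather than camera I/O. A minimal sketch of the same benchmarking loop, assuming the FPS class also provides update() and stop() as in the imutils-style implementation it follows; the per_frame_work stub is hypothetical:

    import time
    from utils import FPS, WebcamVideoStream

    def per_frame_work(frame):
        # hypothetical stand-in for detect_objects(frame, sess, detection_graph)
        time.sleep(0.03)

    stream = WebcamVideoStream(src=0).start()
    fps = FPS().start()

    while fps._numFrames < 120:
        frame = stream.read()   # non-blocking: latest frame grabbed by the thread
        per_frame_work(frame)
        fps.update()            # assumes FPS.update() increments _numFrames

    fps.stop()                  # assumes FPS.stop() records the end time
    print('[INFO] approx. FPS: {:.2f}'.format(fps.fps()))
    stream.stop()

Because update() keeps overwriting self.frame in the background, a slow consumer silently drops frames instead of queueing them, which is what you want for a live view.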

object_detection_multilayer.py

Lines changed: 139 additions & 0 deletions
@@ -0,0 +1,139 @@
import cv2
import multiprocessing
import time

import os
import numpy as np
import tensorflow as tf

from utils import FPS
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util

CWD_PATH = os.getcwd()

# Path to the frozen detection graph. This is the actual model that is used for the object detection.
MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017'
PATH_TO_CKPT = os.path.join(CWD_PATH, 'object_detection', MODEL_NAME, 'frozen_inference_graph.pb')

# List of the strings that are used to add the correct label to each box.
PATH_TO_LABELS = os.path.join(CWD_PATH, 'object_detection', 'data', 'mscoco_label_map.pbtxt')

NUM_CLASSES = 90

# Loading label map
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES,
                                                            use_display_name=True)
category_index = label_map_util.create_category_index(categories)


def detect_objects(image_np, sess, detection_graph):
    # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
    image_np_expanded = np.expand_dims(image_np, axis=0)
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')

    # Each box represents a part of the image where a particular object was detected.
    boxes = detection_graph.get_tensor_by_name('detection_boxes:0')

    # Each score represents the level of confidence for each of the objects.
    # The score is shown on the result image, together with the class label.
    scores = detection_graph.get_tensor_by_name('detection_scores:0')
    classes = detection_graph.get_tensor_by_name('detection_classes:0')
    num_detections = detection_graph.get_tensor_by_name('num_detections:0')

    # Actual detection.
    (boxes, scores, classes, num_detections) = sess.run(
        [boxes, scores, classes, num_detections],
        feed_dict={image_tensor: image_np_expanded})

    # Visualization of the results of a detection.
    vis_util.visualize_boxes_and_labels_on_image_array(
        image_np,
        np.squeeze(boxes),
        np.squeeze(classes).astype(np.int32),
        np.squeeze(scores),
        category_index,
        use_normalized_coordinates=True,
        line_thickness=8)

    return image_np


def blend_non_transparent(face_img, overlay_img):
    # Let's find a mask covering all the non-black (foreground) pixels
    # NB: We need to do this on a grayscale version of the image
    gray_overlay = cv2.cvtColor(overlay_img, cv2.COLOR_BGR2GRAY)
    overlay_mask = cv2.threshold(gray_overlay, 1, 255, cv2.THRESH_BINARY)[1]

    # Let's shrink and blur it a little to make the transitions smoother...
    overlay_mask = cv2.erode(overlay_mask, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)))
    overlay_mask = cv2.blur(overlay_mask, (3, 3))

    # And the inverse mask, that covers all the black (background) pixels
    background_mask = 255 - overlay_mask

    # Turn the masks into three channels, so we can use them as weights
    overlay_mask = cv2.cvtColor(overlay_mask, cv2.COLOR_GRAY2BGR)
    background_mask = cv2.cvtColor(background_mask, cv2.COLOR_GRAY2BGR)

    # Create a masked-out face image and a masked-out overlay
    # We convert the images to floating point in range 0.0 - 1.0
    face_part = (face_img * (1 / 255.0)) * (background_mask * (1 / 255.0))
    overlay_part = (overlay_img * (1 / 255.0)) * (overlay_mask * (1 / 255.0))

    # And finally just add them together, and rescale back to an 8-bit integer image
    return np.uint8(cv2.addWeighted(face_part, 255.0, overlay_part, 255.0, 0.0))


def main_process(input, output):
    while True:
        time.sleep(0.5)
        image = input.get()
        output.put(image)


def child_process(input, output):
    # Load a (frozen) Tensorflow model into memory.
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')

    sess = tf.Session(graph=detection_graph)

    while True:
        image = input.get()
        image2 = detect_objects(image, sess, detection_graph)
        result = blend_non_transparent(image, image2)
        output.put(result)


if __name__ == '__main__':
    input = multiprocessing.Queue(5)
    output = multiprocessing.Queue(5)

    main_process = multiprocessing.Process(target=main_process, args=(input, output))
    main_process.daemon = True
    child_process = multiprocessing.Process(target=child_process, args=(input, output))
    child_process.daemon = False

    main_process.start()
    child_process.start()

    video_capture = cv2.VideoCapture(0)
    video_capture.set(cv2.CAP_PROP_FRAME_WIDTH, 480)
    video_capture.set(cv2.CAP_PROP_FRAME_HEIGHT, 360)

    while True:
        _, frame = video_capture.read()

        input.put(frame)

        cv2.imshow('Video', output.get())

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
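
object_detection_multilayer.py tries a different split: the main process pumps camera frames into a bounded multiprocessing.Queue, a child process owns the TensorFlow session and pushes annotated frames back, and blend_non_transparent() composites the detection overlay onto the raw frame. A stripped-down sketch of the same queue pipeline, with a hypothetical grayscale worker standing in for the detection graph so the pattern can be tested without TensorFlow:

    import cv2
    import multiprocessing


    def worker(frames, results):
        # hypothetical stand-in for child_process(): take a frame from the
        # input queue, do some per-frame work, push the result back
        while True:
            frame = frames.get()
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            results.put(cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR))


    if __name__ == '__main__':
        frames = multiprocessing.Queue(5)   # bounded, so the camera can't race ahead
        results = multiprocessing.Queue(5)

        p = multiprocessing.Process(target=worker, args=(frames, results))
        p.daemon = True
        p.start()

        capture = cv2.VideoCapture(0)
        while True:
            _, frame = capture.read()
            frames.put(frame)
            cv2.imshow('Video', results.get())  # blocks until the worker answers
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

        capture.release()
        cv2.destroyAllWindows()

Because both queues are capped at 5 entries, put() blocks once the worker falls behind, throttling the camera loop to the worker's pace instead of letting stale frames pile up.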

utils.py

Lines changed: 39 additions & 0 deletions
@@ -1,4 +1,6 @@
+import cv2
 import datetime
+from threading import Thread
 
 
 class FPS:
@@ -31,3 +33,40 @@ def elapsed(self):
     def fps(self):
         # compute the (approximate) frames per second
         return self._numFrames / self.elapsed()
+
+
+class WebcamVideoStream:
+    def __init__(self, src=0):
+        # initialize the video camera stream and read the first frame
+        # from the stream
+        self.stream = cv2.VideoCapture(src)
+        self.stream.set(cv2.CAP_PROP_FRAME_WIDTH, 480)
+        self.stream.set(cv2.CAP_PROP_FRAME_HEIGHT, 360)
+        (self.grabbed, self.frame) = self.stream.read()
+
+        # initialize the variable used to indicate if the thread should
+        # be stopped
+        self.stopped = False
+
+    def start(self):
+        # start the thread to read frames from the video stream
+        Thread(target=self.update, args=()).start()
+        return self
+
+    def update(self):
+        # keep looping infinitely until the thread is stopped
+        while True:
+            # if the thread indicator variable is set, stop the thread
+            if self.stopped:
+                return
+
+            # otherwise, read the next frame from the stream
+            (self.grabbed, self.frame) = self.stream.read()
+
+    def read(self):
+        # return the frame most recently read
+        return self.frame
+
+    def stop(self):
+        # indicate that the thread should be stopped
+        self.stopped = True
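
WebcamVideoStream follows the imutils-style producer/consumer pattern: update() runs on a background thread and keeps overwriting self.frame, so read() returns instantly with the most recent frame instead of blocking on the camera. A minimal usage sketch, assuming a webcam at index 0:

    import cv2
    from utils import WebcamVideoStream

    stream = WebcamVideoStream(src=0).start()

    while True:
        frame = stream.read()   # always the most recent frame, never blocks
        cv2.imshow('Frame', frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    stream.stop()               # makes update() return on its next iteration
    stream.stream.release()     # the class never releases the capture itself
    cv2.destroyAllWindows()

Note that stop() only flips a flag; the wrapped VideoCapture is never released by the class, so the caller has to do it (as the sketch does) if the camera should be freed.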
