此项目使用YOLOv3对视频进行物体检测并输出新的视频。

参考 YOLO 官网(https://pjreddie.com/darknet/yolo/)安装 darknet,并配置好 OpenCV 和 GPU 支持。

主要步骤如下:

1) 下载 darknet 并编译:
  git clone https://github.com/pjreddie/darknet
  cd darknet
  make
2) 下载 yolov3 权重:
  wget https://pjreddie.com/media/files/yolov3.weights
3) 将编译生成的 libdarknet.so 复制到 python 文件夹下:darknet.py 通过 ctypes 加载该动态库并调用其中的检测函数(网络结构和权重仍分别来自 .cfg 与 .weights 文件)。代码以相对路径 ./libdarknet.so 加载它,因此需要在该目录下运行脚本;
4) 修改 darknet.py 以实现摄像头和视频检测;
5) 将检测后的每一帧保存,然后调用 opencv 将所有帧合成视频。

官网主要通过 C 接口对图片、视频进行检测或进行实时检测,自带的 Python 代码能实现的功能较少,因此这里修改 Python 代码来实现以上功能。

对Python文件夹中darknet.py的修改主要如下:

1) 为 80 个类别的检测框各分配一种颜色
2) 为检测到的物体画出边界框(以及置信度)
3) 保存检测后的每一帧图像
4) 将所有帧合成一个视频(代码中 save_video 的帧率设为 20.0 fps;如需 30fps,修改该帧率即可)

先贴几张图,最后合成的视频地址:https://www.bilibili.com/video/av94163474,代码在下方。

此代码也可实现对批量图片的检测以及(摄像头)实时检测,只需将主程序中的 state 改为 'picture' 或 'real_time'。

【目标检测】使用YOLOv3对视频进行物体检测并输出新的视频

 

 【目标检测】使用YOLOv3对视频进行物体检测并输出新的视频

 

 【目标检测】使用YOLOv3对视频进行物体检测并输出新的视频

 

 修改后的darknet.py文件如下:

  1 from ctypes import *
  2 import random
  3 import cv2
  4 import numpy as np
  5 import os
  6 import time
  7 
  8 
  9 def sample(probs):
 10     s = sum(probs)
 11     probs = [a / s for a in probs]
 12     r = random.uniform(0, 1)
 13     for i in range(len(probs)):
 14         r = r - probs[i]
 15         if r <= 0:
 16             return i
 17     return len(probs) - 1
 18 
 19 
 20 def c_array(ctype, values):
 21     arr = (ctype * len(values))()
 22     arr[:] = values
 23     return arr
 24 
 25 
 26 class BOX(Structure):
 27     _fields_ = [("x", c_float),
 28                 ("y", c_float),
 29                 ("w", c_float),
 30                 ("h", c_float)]
 31 
 32 
 33 class DETECTION(Structure):
 34     _fields_ = [("bbox", BOX),
 35                 ("classes", c_int),
 36                 ("prob", POINTER(c_float)),
 37                 ("mask", POINTER(c_float)),
 38                 ("objectness", c_float),
 39                 ("sort_class", c_int)]
 40 
 41 
 42 class IMAGE(Structure):
 43     _fields_ = [("w", c_int),
 44                 ("h", c_int),
 45                 ("c", c_int),
 46                 ("data", POINTER(c_float))]
 47 
 48 
 49 class METADATA(Structure):
 50     _fields_ = [("classes", c_int),
 51                 ("names", POINTER(c_char_p))]
 52 
 53 
# Load the compiled darknet shared library. The path is relative to the
# current working directory, so the script must be started from the folder
# that contains libdarknet.so. RTLD_GLOBAL makes its symbols globally visible
# to libraries loaded afterwards.
lib = CDLL("./libdarknet.so", RTLD_GLOBAL)
# The network is handled as an opaque pointer (c_void_p) throughout.
lib.network_width.argtypes = [c_void_p]
lib.network_width.restype = c_int
lib.network_height.argtypes = [c_void_p]
lib.network_height.restype = c_int

# Raw forward pass on a float buffer; not used by the code shown here.
predict = lib.network_predict
predict.argtypes = [c_void_p, POINTER(c_float)]
predict.restype = POINTER(c_float)

# Selects the GPU device; not used by the code shown here.
set_gpu = lib.cuda_set_device
set_gpu.argtypes = [c_int]

make_image = lib.make_image
make_image.argtypes = [c_int, c_int, c_int]
make_image.restype = IMAGE

# Returns an array of DETECTION structs. detect() calls it as
# (net, im.w, im.h, thresh, hier_thresh, None, 0, pnum) and reads the number
# of detections back through the last pointer argument.
get_network_boxes = lib.get_network_boxes
get_network_boxes.argtypes = [c_void_p, c_int, c_int, c_float, c_float, POINTER(c_int), c_int, POINTER(c_int)]
get_network_boxes.restype = POINTER(DETECTION)

make_network_boxes = lib.make_network_boxes
make_network_boxes.argtypes = [c_void_p]
make_network_boxes.restype = POINTER(DETECTION)

# Releases the array returned by get_network_boxes (array, count).
free_detections = lib.free_detections
free_detections.argtypes = [POINTER(DETECTION), c_int]

free_ptrs = lib.free_ptrs
free_ptrs.argtypes = [POINTER(c_void_p), c_int]

# NOTE: ctypes caches attributes looked up on a CDLL, so this is the same
# function object as `predict` above. Re-setting argtypes here is harmless
# and leaves predict.restype in place; keep the statement order as is.
network_predict = lib.network_predict
network_predict.argtypes = [c_void_p, POINTER(c_float)]

reset_rnn = lib.reset_rnn
reset_rnn.argtypes = [c_void_p]

# load_net(cfg_path_bytes, weights_path_bytes, int) -> opaque network handle.
load_net = lib.load_network
load_net.argtypes = [c_char_p, c_char_p, c_int]
load_net.restype = c_void_p

# Non-maximum suppression in place: (dets, num, classes, nms_threshold).
do_nms_obj = lib.do_nms_obj
do_nms_obj.argtypes = [POINTER(DETECTION), c_int, c_int, c_float]

do_nms_sort = lib.do_nms_sort
do_nms_sort.argtypes = [POINTER(DETECTION), c_int, c_int, c_float]

# Only for IMAGEs whose pixel memory darknet allocated; never call it on an
# IMAGE built by array_to_image (that memory belongs to numpy).
free_image = lib.free_image
free_image.argtypes = [IMAGE]

letterbox_image = lib.letterbox_image
letterbox_image.argtypes = [IMAGE, c_int, c_int]
letterbox_image.restype = IMAGE

# get_metadata(data_file_bytes) -> METADATA (class count + names).
load_meta = lib.get_metadata
lib.get_metadata.argtypes = [c_char_p]
lib.get_metadata.restype = METADATA

load_image = lib.load_image_color
load_image.argtypes = [c_char_p, c_int, c_int]
load_image.restype = IMAGE

# Called by detect() on the frame buffer; presumably swaps the R and B planes
# in place (OpenCV frames are BGR). Functions without an explicit restype
# keep ctypes' default restype of c_int.
rgbgr_image = lib.rgbgr_image
rgbgr_image.argtypes = [IMAGE]

# Runs the network on an IMAGE; detect() ignores the returned pointer and
# fetches boxes with get_network_boxes instead.
predict_image = lib.network_predict_image
predict_image.argtypes = [c_void_p, IMAGE]
predict_image.restype = POINTER(c_float)
122 
123 
def convertBack(x, y, w, h):
    """Turn a centre-based box (x, y, w, h) into integer corner coordinates.

    Returns ``(xmin, ymin, xmax, ymax)``. Each edge is rounded with Python's
    ``round`` (ties go to the even integer) and then converted with ``int``.
    """
    half_w = w / 2
    half_h = h / 2
    left, right = (int(round(edge)) for edge in (x - half_w, x + half_w))
    top, bottom = (int(round(edge)) for edge in (y - half_h, y + half_h))
    return left, top, right, bottom
130 
131 
def array_to_image(arr):
    """Wrap an (H, W, C) pixel array as a darknet ``IMAGE``.

    The array is reordered to channel-planar (C, H, W), copied into a
    contiguous float32 buffer and scaled by 1/255 (presumably 8-bit pixels,
    e.g. a frame from ``cv2.imread``). The input array is not modified.

    Returns ``(im, buffer)``. ``im.data`` points into ``buffer``, so the
    caller must keep ``buffer`` referenced for as long as ``im`` is in use;
    otherwise Python may free the memory underneath darknet.
    """
    chw = arr.transpose(2, 0, 1)
    channels, height, width = chw.shape[0:3]
    buffer = np.ascontiguousarray(chw.flat, dtype=np.float32) / 255.0
    data_ptr = buffer.ctypes.data_as(POINTER(c_float))
    return IMAGE(width, height, channels, data_ptr), buffer
141 
142 
def detect(net, meta, image, thresh=.5, hier_thresh=.5, nms=.45):
    """Run YOLO on one frame and return its detections, most confident first.

    Args:
        net: network handle returned by ``load_net``.
        meta: ``METADATA`` from ``load_meta``; supplies the class count/names.
        image: (H, W, C) pixel array, e.g. a frame from ``cv2.imread``. It is
            copied by ``array_to_image``, so the caller's array is untouched.
        thresh, hier_thresh: thresholds forwarded to ``get_network_boxes``.
        nms: NMS overlap threshold passed to ``do_nms_obj``; falsy skips NMS.

    Returns:
        list of ``(name_bytes, probability, (x, y, w, h))`` tuples, one per
        (detection, class) pair with a non-zero score, sorted by descending
        probability. ``x``/``y`` is the box centre as reported by darknet.
    """
    # `buffer` owns the float pixels that im.data points into; holding this
    # reference until we return keeps the memory alive while darknet reads it.
    # Because numpy owns that memory, `im` must never be passed to free_image().
    im, buffer = array_to_image(image)
    rgbgr_image(im)  # presumably BGR -> RGB, done in place inside `buffer`
    num = c_int(0)
    pnum = pointer(num)
    predict_image(net, im)
    dets = get_network_boxes(net, im.w, im.h, thresh, hier_thresh, None, 0, pnum)
    num = pnum[0]

    res = []
    try:
        if nms:
            do_nms_obj(dets, num, meta.classes, nms)
        for j in range(num):
            det = dets[j]
            probs = det.prob[0:meta.classes]
            if not any(probs):
                continue
            b = det.bbox
            box = (b.x, b.y, b.w, b.h)  # read before the C array is freed
            for i in np.array(probs).nonzero()[0]:
                res.append((meta.names[i], probs[i], box))
    finally:
        # Free the C-side detections even if building the result raised.
        free_detections(dets, num)

    res.sort(key=lambda item: -item[1])
    return res
169 
170 
def mode_select(state, video_path='../test/test_video/video7.mp4', camera=0):
    """Open the input source for the chosen detection mode.

    Args:
        state: one of ``'picture'``, ``'video'`` or ``'real_time'``.
        video_path: file opened in ``'video'`` mode (defaults to the test clip
            that used to be hard-coded here).
        camera: device index or stream URL opened in ``'real_time'`` mode,
            e.g. ``"http://admin:admin@192.168.0.13:8081"`` for an IP camera.

    Returns:
        a ``cv2.VideoCapture`` for ``'video'`` / ``'real_time'``, or the
        sentinel ``1`` in ``'picture'`` mode, where the caller reads image
        files itself.

    Raises:
        ValueError: if *state* is not a recognised mode.
    """
    if state not in {'picture', 'video', 'real_time'}:
        raise ValueError('{} is not a valid argument!'.format(state))
    if state == 'picture':
        return 1
    source = camera if state == 'real_time' else video_path
    return cv2.VideoCapture(source)
184 
185 
def find_object_in_picture(ret, img, labels=None, colors=None, show_score=None):
    """Draw a labelled bounding box for every detection onto *img*, in place.

    Args:
        ret: detections as returned by ``detect``:
            ``(name_bytes, probability, (x, y, w, h))`` with a centre-based box.
        img: image array to draw on; it is modified and also returned.
        labels: class-name list used to look up each box colour; defaults to
            the module-level ``LABELS``.
        colors: per-class colour array indexed like *labels*; defaults to the
            module-level ``COLORS``.
        show_score: append the confidence (as a percentage) to each label.
            ``None`` keeps the old behaviour: show it unless a module-level
            ``state`` equals ``'video'``. If ``state`` is not defined (e.g.
            the module was imported rather than run), this no longer raises
            NameError.

    Returns:
        *img*.
    """
    if labels is None:
        labels = LABELS
    if colors is None:
        colors = COLORS
    if show_score is None:
        show_score = globals().get('state') != 'video'

    font, scale, thickness = cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
    for name_bytes, prob, (x, y, w, h) in ret:
        name = name_bytes.decode()
        color = colors[labels.index(name)].tolist()
        xmin, ymin, xmax, ymax = convertBack(float(x), float(y), float(w), float(h))
        cv2.rectangle(img, (xmin, ymin), (xmax, ymax), color, 3)

        text = f"{name} [{round(prob * 100, 2)}]" if show_score else name
        (text_w, text_h), baseline = cv2.getTextSize(text, font, scale, thickness)
        # The label normally sits just above the box. When the box touches the
        # top of the frame, push the label down so it stays visible instead of
        # being drawn at negative y.
        label_bottom = max(ymin, text_h + baseline)
        cv2.rectangle(img, (xmin, label_bottom - text_h - baseline),
                      (xmin + text_w, label_bottom), color, -1)
        cv2.putText(img, text, (xmin, label_bottom - 5), font, scale, (0, 0, 0), thickness)
    return img
204 
205 
def save_video(state, out_video, frame_dir='../result/result_frame',
               out_path='../result/result_video/result_video.avi', fps=20.0):
    """Stitch the saved detection frames into an XVID ``.avi`` file.

    Frames are read as ``<frame_dir>/result_frame_<i>.jpg`` for i = 0, 1, ...
    The output size is taken from frame 0. Writing stops at the first index
    that cannot be read.

    NOTE: the frame count comes from the number of entries in *frame_dir*, so
    leftover frames from an earlier, longer run are included too.

    Args:
        state: detection mode; only ``'video'`` produces a file.
        out_video: pass a falsy value to skip writing.
        frame_dir: directory holding the numbered frames.
        out_path: destination video file; its directory is created if needed.
        fps: frame rate of the written video.

    Returns:
        1 if a video was written, otherwise 0 (wrong mode, writing disabled,
        or no frame 0 found).
    """
    if state != 'video' or not out_video:
        return 0

    first_path = os.path.join(frame_dir, 'result_frame_0.jpg')
    first = cv2.imread(first_path, 1)
    if first is None:
        # Previously this crashed with AttributeError on `None.shape`.
        print('no frames found at', first_path)
        return 0
    frame_height, frame_width = first.shape[:2]

    out_dir = os.path.dirname(out_path)
    if out_dir:
        # cv2.VideoWriter fails silently when the target directory is missing.
        os.makedirs(out_dir, exist_ok=True)
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    writer = cv2.VideoWriter(out_path, fourcc, fps, (frame_width, frame_height), True)
    try:
        n_entries = len(os.listdir(frame_dir))
        print('the number of video frames is', n_entries)
        for idx in range(n_entries):
            frame = cv2.imread(os.path.join(frame_dir, 'result_frame_%d.jpg' % idx), 1)
            if frame is None:
                # Non-frame files in the directory inflate n_entries; stop at
                # the first gap instead of passing None to the writer.
                break
            writer.write(frame)
        # The old per-frame cv2.waitKey(25) was dropped. waitKey only sees
        # keys while a HighGUI window exists, and none is open at this point,
        # so the 'q' check could never fire.
    finally:
        writer.release()
    print('video has already saved.')
    return 1
232 
233 
def load_model(cfg_path="/home/dengjie/dengjie/project/detection/from_darknet/cfg/yolov3.cfg",
               weights_path="/home/dengjie/dengjie/project/detection/from_darknet/cfg/yolov3.weights",
               data_path="/home/dengjie/dengjie/project/detection/from_darknet/cfg/coco.data",
               label_path='../data/coco.names'):
    """Load the YOLOv3 network, its darknet metadata and the class-name list.

    The defaults are the paths that used to be hard-coded; pass your own
    (``str`` or ``bytes``) to run from a different checkout.

    Returns:
        ``(net, meta, labels, num_class)``:
        - ``net``: network handle from ``load_net``;
        - ``meta``: ``METADATA`` from ``load_meta``;
        - ``labels``: class names from *label_path*, one per line;
        - ``num_class``: ``len(labels)``.
    """
    def _as_bytes(p):
        # load_net / load_meta are declared with c_char_p arguments (bytes).
        return p if isinstance(p, bytes) else p.encode('utf-8')

    # The trailing 0 is passed through unchanged from the original code
    # (presumably darknet's `clear` flag - TODO confirm).
    net = load_net(_as_bytes(cfg_path), _as_bytes(weights_path), 0)
    meta = load_meta(_as_bytes(data_path))

    # Bug fix: the names file used to be split on the letter "n" rather than
    # on line breaks, so no entry matched a real class name and
    # LABELS.index(...) in find_object_in_picture raised ValueError.
    # The file is also closed properly now.
    with open(label_path, encoding='utf-8') as f:
        labels = f.read().strip().splitlines()
    return net, meta, labels, len(labels)
243 
244 
def random_color(num, seed=80):
    """Return one reproducible random colour per class.

    Args:
        num: number of classes.
        seed: RNG seed; the default 80 reproduces the previous colours.

    Returns:
        ``(num, 3)`` uint8 array with values in [0, 255]; row ``i`` is the
        colour used for class ``i``.

    A private ``RandomState`` seeded with *seed* yields exactly the values the
    old ``np.random.seed(80)`` + ``np.random.randint`` code produced. Unlike
    that code, it does not reset numpy's global RNG as a side effect.
    """
    rng = np.random.RandomState(seed)
    return rng.randint(0, 256, size=(num, 3), dtype='uint8')
259 
260 
if __name__ == "__main__":
    # NOTE: state, frame_path, LABELS and COLORS are module globals that
    # helper functions in this file may look up, so keep them at this level.
    k = 0  # index of the next result image / frame written to disk
    path = '../test/test_pic'
    frame_path = '../result/result_frame'

    state = 'video'  # detection mode: 'video', 'picture' or 'real_time'

    net, meta, LABELS, class_num = load_model()
    cap = mode_select(state)
    COLORS = random_color(class_num)
    print('start detect')

    if cap == 1:  # picture mode: run detection on every file in `path`
        result_pic_dir = '../result/result_pic'
        # cv2.imwrite returns False (no exception) if the directory is missing.
        os.makedirs(result_pic_dir, exist_ok=True)
        test_list = sorted(os.listdir(path))
        k = 0
        sum_t = 0
        print('test_list', test_list[1:])
        for j in test_list:
            time_p = time.time()
            img = cv2.imread(os.path.join(path, j), 1)
            if img is None:
                print('skip unreadable file ' + j)
                continue
            r = detect(net, meta, img)
            # r looks like
            # [(b'person', 0.6372, (414.55, 279.70, 483.99, 394.23))]:
            # class name, probability, box centre x, centre y, width, height
            image = find_object_in_picture(r, img)
            t = time.time() - time_p
            # The first image only warms the network up: its time is not
            # counted and it is neither shown nor saved.
            if j == test_list[0]:
                continue
            sum_t += t
            print('process ' + j + ' spend %.5fs' % t)
            cv2.imshow("img", img)
            cv2.imwrite(os.path.join(result_pic_dir, 'result_%d.jpg' % k), image)
            k += 1
            cv2.waitKey()  # wait for a key press before the next image
            cv2.destroyAllWindows()
        print('Have processed %d pictures.' % k)
        print('Total picture-processing time is %.5fs' % sum_t)
        if k:  # avoid ZeroDivisionError when nothing was processed
            print('Average processing time is %.5fs' % (sum_t / k))
        print('Have Done!')
    else:
        if state == 'video':
            # Frames are saved here and later stitched by save_video().
            os.makedirs(frame_path, exist_ok=True)
            os.makedirs('../result/result_video', exist_ok=True)
        sum_v = 0
        sum_fps = 0
        i = 0  # number of frames processed
        try:
            while True:
                time_v = time.time()
                ret, img = cap.read()
                if not ret:  # end of the video, or the camera stopped
                    print('Total processing time is %.5fs' % sum_v)
                    print('Detected frames : %d ' % i)
                    # The first frame is excluded from the FPS average, so
                    # an average needs at least two frames.
                    if i > 1:
                        print('Average fps is %.3f' % (sum_fps / (i - 1)))
                    break
                i += 1
                r = detect(net, meta, img)
                image = find_object_in_picture(r, img)
                cv2.imshow("window", image)
                t_v = time.time() - time_v
                fps = 1 / t_v
                if i > 1:  # skip the warm-up frame
                    print('FPS %.3f' % fps)
                    sum_fps += fps
                sum_v += t_v
                if state == 'video':
                    cv2.imwrite(os.path.join(frame_path, 'result_frame_%d.jpg' % k), image)
                    k += 1
                # waitKey(1) pumps the window for 1 ms; pressing 'q' stops early.
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    print('Detected time is %.5fs' % sum_v)
                    if i > 1:
                        print('Average fps is %.3f' % (sum_fps / (i - 1)))
                    print('Detected frames : %d ' % i)
                    break
        finally:
            # Release the capture and windows even if detection raised.
            cap.release()
            cv2.destroyAllWindows()
        val = save_video(state, True)
        if val == 1:
            print('Have Done!')
        else:
            print('Detection has finished.')