此项目使用YOLOv3对视频进行物体检测并输出新的视频。
参考 YOLO 官网（https://pjreddie.com/darknet/yolo/）安装 darknet，并在编译时启用 OpenCV 和 GPU 支持（在 Makefile 中设置 OPENCV=1、GPU=1 后重新编译）。
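主要步骤如下：（1）安装并编译 darknet；（2）修改 Python 文件夹中的 darknet.py，使其能直接检测 OpenCV 读取的图像；（3）逐帧读取视频并进行检测，将画好检测框的每一帧保存为图片；（4）把保存的帧合成为新的视频。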
官网主要通过 C 接口对图片、视频进行检测或进行实时检测，而自带的 Python 接口功能较少，因此这里修改 Python 代码来实现上述功能。
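对 Python 文件夹中 darknet.py 的修改主要如下：让 detect() 直接接收 OpenCV 读取的 numpy 图像（通过 array_to_image 转换，并用 rgbgr_image 将 BGR 转为 RGB）；另外加入 mode_select()（选择图片、视频或实时检测模式）、find_object_in_picture()（在图像上绘制检测框和类别标签）、save_video()（将保存的帧合成为视频）、load_model()（加载网络、元数据和类别名）以及 random_color()（为每个类别生成固定的随机颜色）。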
先贴几张检测效果图；最终合成的视频见：https://www.bilibili.com/video/av94163474，完整代码在下方。
此代码也可对批量图片进行检测或进行实时检测：将主程序中的 state 改为 'picture'（检测 ../test/test_pic 目录中的图片）或 'real_time'（打开 0 号摄像头）即可。
修改后的darknet.py文件如下:
from ctypes import *
import random
import cv2
import numpy as np
import os
import time


def sample(probs):
    """Return a random index drawn with probability proportional to ``probs``.

    ``probs`` need not be normalised. If floating-point rounding leaves the
    running remainder just above zero, the last index is returned.
    """
    s = sum(probs)
    probs = [a / s for a in probs]
    r = random.uniform(0, 1)
    for i, p in enumerate(probs):
        r -= p
        if r <= 0:
            return i
    return len(probs) - 1


def c_array(ctype, values):
    """Copy the Python sequence ``values`` into a new ctypes array of ``ctype``."""
    arr = (ctype * len(values))()
    arr[:] = values
    return arr


# ctypes mirrors of darknet's C structs; ctypes requires the field order and
# types to match the C definitions exactly.
class BOX(Structure):
    _fields_ = [("x", c_float),
                ("y", c_float),
                ("w", c_float),
                ("h", c_float)]


class DETECTION(Structure):
    _fields_ = [("bbox", BOX),
                ("classes", c_int),
                ("prob", POINTER(c_float)),
                ("mask", POINTER(c_float)),
                ("objectness", c_float),
                ("sort_class", c_int)]


class IMAGE(Structure):
    _fields_ = [("w", c_int),
                ("h", c_int),
                ("c", c_int),
                ("data", POINTER(c_float))]


class METADATA(Structure):
    _fields_ = [("classes", c_int),
                ("names", POINTER(c_char_p))]


# darknet shared library, loaded relative to the current working directory.
lib = CDLL("./libdarknet.so", RTLD_GLOBAL)
lib.network_width.argtypes = [c_void_p]
lib.network_width.restype = c_int
lib.network_height.argtypes = [c_void_p]
lib.network_height.restype = c_int

predict = lib.network_predict
predict.argtypes = [c_void_p, POINTER(c_float)]
predict.restype = POINTER(c_float)

set_gpu = lib.cuda_set_device
set_gpu.argtypes = [c_int]

make_image = lib.make_image
make_image.argtypes = [c_int, c_int, c_int]
make_image.restype = IMAGE

get_network_boxes = lib.get_network_boxes
get_network_boxes.argtypes = [c_void_p, c_int, c_int, c_float, c_float, POINTER(c_int), c_int, POINTER(c_int)]
get_network_boxes.restype = POINTER(DETECTION)

make_network_boxes = lib.make_network_boxes
make_network_boxes.argtypes = [c_void_p]
make_network_boxes.restype = POINTER(DETECTION)

free_detections = lib.free_detections
free_detections.argtypes = [POINTER(DETECTION), c_int]

free_ptrs = lib.free_ptrs
free_ptrs.argtypes = [POINTER(c_void_p), c_int]

# CDLL caches attribute lookups, so this is the same function object as
# ``predict`` above (its restype is already POINTER(c_float)).
network_predict = lib.network_predict
network_predict.argtypes = [c_void_p, POINTER(c_float)]

reset_rnn = lib.reset_rnn
reset_rnn.argtypes = [c_void_p]

load_net = lib.load_network
load_net.argtypes = [c_char_p, c_char_p, c_int]
load_net.restype = c_void_p

do_nms_obj = lib.do_nms_obj
do_nms_obj.argtypes = [POINTER(DETECTION), c_int, c_int, c_float]

do_nms_sort = lib.do_nms_sort
do_nms_sort.argtypes = [POINTER(DETECTION), c_int, c_int, c_float]

free_image = lib.free_image
free_image.argtypes = [IMAGE]

letterbox_image = lib.letterbox_image
letterbox_image.argtypes = [IMAGE, c_int, c_int]
letterbox_image.restype = IMAGE

load_meta = lib.get_metadata
lib.get_metadata.argtypes = [c_char_p]
lib.get_metadata.restype = METADATA

load_image = lib.load_image_color
load_image.argtypes = [c_char_p, c_int, c_int]
load_image.restype = IMAGE

rgbgr_image = lib.rgbgr_image
rgbgr_image.argtypes = [IMAGE]

predict_image = lib.network_predict_image
predict_image.argtypes = [c_void_p, IMAGE]
predict_image.restype = POINTER(c_float)


def convertBack(x, y, w, h):
    """Convert a centre/size box to integer corner coordinates.

    Returns:
        ``(xmin, ymin, xmax, ymax)``, each rounded to the nearest int.
    """
    xmin = int(round(x - (w / 2)))
    xmax = int(round(x + (w / 2)))
    ymin = int(round(y - (h / 2)))
    ymax = int(round(y + (h / 2)))
    return xmin, ymin, xmax, ymax


def array_to_image(arr):
    """Wrap a 3-D (height, width, channel) numpy image as a darknet ``IMAGE``.

    The pixels are transposed to channel-first order, flattened, converted to
    float32 and scaled by 1/255.

    Returns:
        ``(im, arr)``. ``im.data`` points into the returned ``arr``'s buffer,
        so the caller must keep ``arr`` alive for as long as ``im`` is used
        (otherwise Python may free the memory underneath darknet).
    """
    arr = arr.transpose(2, 0, 1)
    c, h, w = arr.shape[0:3]
    arr = np.ascontiguousarray(arr.flat, dtype=np.float32) / 255.0
    data = arr.ctypes.data_as(POINTER(c_float))
    im = IMAGE(w, h, c, data)
    return im, arr


def detect(net, meta, image, thresh=.5, hier_thresh=.5, nms=.45):
    """Run the network on an OpenCV image array and return its detections.

    Args:
        net: network handle returned by ``load_net``.
        meta: ``METADATA`` returned by ``load_meta`` (class count and names).
        image: 3-D image array as read by ``cv2.imread`` (BGR channel order).
        thresh: detection threshold passed to ``get_network_boxes``.
        hier_thresh: hierarchical threshold passed to ``get_network_boxes``.
        nms: threshold for ``do_nms_obj``; a falsy value skips NMS.

    Returns:
        List of ``(class_name_bytes, probability, (x, y, w, h))`` sorted by
        descending probability, where ``(x, y)`` is the box centre. The image's
        own width/height are passed to ``get_network_boxes``, so the
        coordinates are in the pixel frame of ``image``.
    """
    # Rebinding ``image`` keeps the float buffer alive: ``im.data`` points into it.
    im, image = array_to_image(image)
    rgbgr_image(im)  # OpenCV delivers BGR; swap channels in place for darknet.
    num = c_int(0)
    pnum = pointer(num)
    predict_image(net, im)
    dets = get_network_boxes(net, im.w, im.h, thresh, hier_thresh, None, 0, pnum)
    num = pnum[0]
    if nms:
        do_nms_obj(dets, num, meta.classes, nms)

    res = []
    for j in range(num):
        b = dets[j].bbox
        probs = dets[j].prob[0:meta.classes]
        # One entry per class with a non-zero probability for this box.
        for i in np.array(probs).nonzero()[0]:
            res.append((meta.names[i], dets[j].prob[i], (b.x, b.y, b.w, b.h)))

    res.sort(key=lambda x: -x[1])
    # ``im.data`` is owned by the numpy array above, so free_image must NOT be
    # called on it; only the darknet-allocated detections are released.
    free_detections(dets, num)
    return res


def mode_select(state):
    """Return the input source for the chosen detection mode.

    Args:
        state: ``'picture'``, ``'video'`` or ``'real_time'``.

    Returns:
        ``1`` for ``'picture'`` (images are then read from a directory),
        otherwise a ``cv2.VideoCapture`` on the test video file (``'video'``)
        or on camera index 0 (``'real_time'``).

    Raises:
        ValueError: if ``state`` is not one of the three modes.
    """
    if state not in {'picture', 'video', 'real_time'}:
        raise ValueError('{} is not a valid argument!'.format(state))
    if state == 'picture':
        return 1
    if state == 'real_time':
        # An IP-camera stream URL can be used instead, e.g.
        # video = "http://admin:admin@192.168.0.13:8081"
        video = 0
    else:
        video = '../test/test_video/video7.mp4'
    return cv2.VideoCapture(video)


def find_object_in_picture(ret, img):
    """Draw a labelled bounding box for every detection in ``ret`` onto ``img``.

    Relies on module-level globals set in the ``__main__`` section:
    ``LABELS`` (class names), ``COLORS`` (one colour per class) and ``state``
    (in ``'video'`` mode only the class name is drawn; otherwise the
    probability in percent is appended).

    Args:
        ret: detections as returned by ``detect``.
        img: image array, modified in place.

    Returns:
        The same ``img`` object.
    """
    for name, prob, (x, y, w, h) in ret:
        label = name.decode()
        color = COLORS[LABELS.index(label)].tolist()
        xmin, ymin, xmax, ymax = convertBack(float(x), float(y), float(w), float(h))
        pt1 = (xmin, ymin)
        pt2 = (xmax, ymax)
        cv2.rectangle(img, pt1, pt2, color, 3)
        if state == 'video':
            text = label
        else:
            text = label + " [" + str(round(prob * 100, 2)) + "]"
        (text_w, text_h), baseline = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        # Filled box behind the label so the black text stays readable.
        cv2.rectangle(img, (pt1[0], pt1[1] - text_h - baseline), (pt1[0] + text_w, pt1[1]), color, -1)
        cv2.putText(img, text, (pt1[0], pt1[1] - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)
    return img


def save_video(state, out_video, frame_dir='../result/result_frame',
               video_path='../result/result_video/result_video.avi', fps=20.0):
    """Assemble the saved frames ``result_frame_<i>.jpg`` into an XVID AVI file.

    Nothing is written unless ``state == 'video'`` and ``out_video`` is truthy.
    Frames are read in index order starting at 0; writing stops at the first
    missing or unreadable index.

    NOTE: the frame count comes from ``os.listdir(frame_dir)``, so stale frames
    left over from an earlier, longer run are appended as well.

    Args:
        state: detection mode string.
        out_video: whether a video should be produced.
        frame_dir: directory holding the per-frame JPEGs.
        video_path: output video file.
        fps: frame rate of the output video.

    Returns:
        1 if a video was written, otherwise 0.
    """
    if state != 'video' or not out_video:
        return 0
    first = cv2.imread(os.path.join(frame_dir, 'result_frame_0.jpg'), 1)
    if first is None:
        print('no frames found in', frame_dir)
        return 0
    frame_height, frame_width = first.shape[0], first.shape[1]
    is_color = 1
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out = cv2.VideoWriter(video_path, fourcc, fps, (frame_width, frame_height), is_color)
    try:
        frame_files = os.listdir(frame_dir)
        print('the number of video frames is', len(frame_files))
        for i in range(len(frame_files)):
            frame = cv2.imread(os.path.join(frame_dir, 'result_frame_%d.jpg' % i), 1)
            if frame is None:
                # listdir may also count non-frame files; stop at the first gap.
                break
            out.write(frame)
            if cv2.waitKey(25) & 0xFF == ord('q'):
                break
    finally:
        out.release()
    print('video has already saved.')
    return 1


def load_model(cfg_path="/home/dengjie/dengjie/project/detection/from_darknet/cfg/yolov3.cfg",
               weights_path="/home/dengjie/dengjie/project/detection/from_darknet/cfg/yolov3.weights",
               data_path="/home/dengjie/dengjie/project/detection/from_darknet/cfg/coco.data",
               label_path='../data/coco.names'):
    """Load the network, its metadata and the class-name list.

    Paths may be given as ``str`` or ``bytes``; the darknet functions receive
    UTF-8 bytes.

    Returns:
        ``(net, meta, labels, num_class)`` where ``labels`` is the list of
        class names read one per line from ``label_path``.
    """
    def as_bytes(p):
        return p if isinstance(p, bytes) else p.encode('utf-8')

    # Third argument is load_network's int flag (``clear`` in darknet's C API).
    net1 = load_net(as_bytes(cfg_path), as_bytes(weights_path), 0)
    meta1 = load_meta(as_bytes(data_path))
    with open(label_path, encoding='utf-8') as f:
        # One class name per line; strip so CRLF files still match decoded names.
        LABELS1 = [line.strip() for line in f.read().strip().splitlines()]
    num_class = len(LABELS1)
    return net1, meta1, LABELS1, num_class


def random_color(num):
    """Return a ``(num, 3)`` uint8 array of random colours, one per class.

    A fixed seed (80) keeps every class's colour identical across runs. A
    local ``RandomState`` yields the same values as seeding the global numpy
    RNG, without disturbing the global random state.
    """
    COLORS = np.random.RandomState(80).randint(0, 256, size=(num, 3), dtype='uint8')
    return COLORS


if __name__ == "__main__":
    k = 0
    path = '../test/test_pic'
    frame_path = '../result/result_frame'

    state = 'video'  # detection mode: 'video', 'picture' or 'real_time'

    net, meta, LABELS, class_num = load_model()
    cap = mode_select(state)
    COLORS = random_color(class_num)
    print('start detect')

    if state == 'picture':
        test_list = sorted(os.listdir(path))
        sum_t = 0
        print('test_list', test_list[1:])
        for j in test_list:
            time_p = time.time()
            img = cv2.imread(os.path.join(path, j), 1)
            if img is None:
                # Not a readable image (e.g. a stray non-image file); skip it.
                print('skip ' + j)
                continue
            r = detect(net, meta, img)
            # r looks like:
            # [(b'person', 0.6372514963150024,
            #   (414.55322265625, 279.70245361328125, 483.99005126953125, 394.2349853515625))]
            # i.e. class name, probability, box centre x, centre y, width, height.
            image = find_object_in_picture(r, img)
            t = time.time() - time_p
            if k > 0:
                # The first processed image is treated as warm-up and not timed.
                sum_t += t
            print('process ' + j + ' spend %.5fs' % t)
            cv2.imshow("img", img)
            cv2.imwrite('../result/result_pic/result_%d.jpg' % k, image)
            k += 1
            cv2.waitKey()
        cv2.destroyAllWindows()
        print('Have processed %d pictures.' % k)
        print('Total picture-processing time is %.5fs' % sum_t)
        if k > 1:
            # sum_t covers the k - 1 timed pictures (warm-up excluded).
            print('Average processing time is %.5fs' % (sum_t / (k - 1)))
        print('Have Done!')
    else:
        sum_v = 0
        sum_fps = 0
        i = 0  # number of frames read so far
        while True:
            time_v = time.time()
            ret, img = cap.read()
            if not ret:  # end of the video, or the capture could not be read
                print('Total processing time is %.5fs' % sum_v)
                print('Detected frames : %d ' % i)
                if i > 1:
                    print('Average fps is %.3f' % (sum_fps / (i - 1)))
                break
            i += 1
            r = detect(net, meta, img)
            image = find_object_in_picture(r, img)
            cv2.imshow("window", image)
            t_v = time.time() - time_v
            fps = 1 / t_v
            if i > 1:
                # The first frame is treated as warm-up and excluded from the stats.
                print('FPS %.3f' % fps)
                sum_fps += fps
                sum_v += t_v
            if state == 'video':
                cv2.imwrite(os.path.join(frame_path, 'result_frame_%d.jpg' % k), image)
                k += 1
            # waitKey(1) waits 1 ms for a key press; pressing 'q' stops early.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                print('Detected time is %.5fs' % sum_v)
                if i > 1:
                    print('Average fps is %.3f' % (sum_fps / (i - 1)))
                print('Detected frames : %d ' % i)
                break
        cap.release()
        cv2.destroyAllWindows()

        val = save_video(state, True, frame_dir=frame_path)
        if val == 1:
            print('Have Done!')
        else:
            print('Detection has finished.')
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:【目标检测】使用YOLOv3对视频进行物体检测并输出新的视频 - Python技术站