当前位置:   article > 正文

YOLOv5-ONNX 模型推理(yolov5lite 如何实现 ONNX 推理)

yolov5lite如何实现onnx推理

       在深度学习领域,YOLO(You Only Look Once)系列模型因其出色的实时目标检测性能而广受欢迎。随着 ONNX(Open Neural Network Exchange)格式的普及,将 YOLOv5 模型转换为 ONNX 格式、使其能在多种平台和框架间无缝运行,已成为提高部署灵活性和效率的关键步骤。本文演示如何使用 ONNX Runtime 加载 YOLOv5-ONNX 模型并完成物体检测的全过程,包括图像预处理、模型推理、非极大值抑制(NMS)以及结果可视化。

准备工作

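YOLOv5-ONNX 模型文件(下文代码中为 best.onnx,可用 YOLOv5 仓库的 export.py 导出);类别名称文件 coco128.yaml(代码读取其中的 names 字段作为类别名);以及依赖库 onnxruntime(使用 GPU 推理需安装 onnxruntime-gpu)、opencv-python、numpy、torch、torchvision、PyYAML。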

代码如下:

yolov5-interface(主程序:读取 ./images 目录下的图片,逐张推理并显示检测结果)

  1. import os
  2. import random
  3. import onnxruntime
  4. from tool import *
  def onnx_load(w):
      """Create an ONNX Runtime inference session for the model file ``w``.

      The CUDA execution provider is requested only when the installed
      onnxruntime build actually offers it (``onnxruntime-gpu``); otherwise the
      session runs on the CPU provider alone. Whether PyTorch sees a GPU is
      irrelevant here, because ONNX Runtime does not use PyTorch's CUDA stack.

      Args:
          w: path to the ``.onnx`` model file.

      Returns:
          ``(session, output_names)``: the ``InferenceSession`` and the names of
          all model outputs, in the order ``session.get_outputs()`` reports them.
      """
      providers = ['CPUExecutionProvider']
      if 'CUDAExecutionProvider' in onnxruntime.get_available_providers():
          # Listed first so ONNX Runtime prefers the GPU, keeping CPU as fallback.
          providers.insert(0, 'CUDAExecutionProvider')
      session = onnxruntime.InferenceSession(w, providers=providers)
      output_names = [x.name for x in session.get_outputs()]
      print('-------', output_names)
      return session, output_names
  if __name__ == '__main__':
      # Demo: run the ONNX model on every image in image_dir (in random order)
      # and display each annotated result; cv2.waitKey(0) waits for a key press
      # before moving on to the next image.
      # w = 'yolov5s.onnx'
      w = 'best.onnx'
      image_dir = './images'
      imgsz = [640, 640]  # model input as (height, width)
      session, output_names = onnx_load(w)
      # Run NMS on the GPU only when one is available, so the script also works
      # on CPU-only machines (onnx_load already falls back to CPU there).
      device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
      input_name = session.get_inputs()[0].name  # loop-invariant, look it up once
      image_list = os.listdir(image_dir)
      random.shuffle(image_list)
      for image_item in image_list:
          start_time = time.time()
          path = os.path.join(image_dir, image_item)
          im0 = cv2.imread(path)
          if im0 is None:
              # cv2.imread returns None for unreadable or non-image entries
              # (e.g. subdirectories, text files); skip them instead of crashing.
              print(f'skip unreadable file: {path}')
              continue
          im, org_data = data_process_cv2(im0, imgsz)
          y = session.run(output_names, {input_name: im})
          pred = torch.from_numpy(y[0]).to(device)
          pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45, max_det=1000)
          print("spend time:{0} ms".format((time.time() - start_time) * 1000))
          res_img = post_process_yolov5(pred[0], org_data)
          cv2.imshow('res', res_img)
          cv2.waitKey(0)

tool(工具函数,需保存为 tool.py,主程序通过 from tool import * 导入)

  1. import cv2
  2. import numpy as np
  3. import torch.cuda
  4. import time
  5. import torchvision
  6. import yaml
  def resize_image_cv2(image, size):
      """Letterbox ``image`` onto a gray (128) canvas of the given size.

      The image is scaled by the largest factor that fits inside the canvas
      while preserving its aspect ratio, then centred; the remaining border is
      filled with the value 128.

      Args:
          image: ``H x W x 3`` uint8 array (e.g. as returned by ``cv2.imread``).
          size: target canvas size as ``(width, height)``.

      Returns:
          ``(new_image, nw, nh)``: the ``(height, width, 3)`` uint8 canvas and
          the width/height of the resized image placed inside it.
      """
      ih, iw = image.shape[:2]
      w, h = size
      scale = min(w / iw, h / ih)
      # At least one pixel per side: cv2.resize rejects a zero-sized target.
      nw = max(1, int(iw * scale))
      nh = max(1, int(ih * scale))
      resized = cv2.resize(image, (nw, nh))
      # NumPy arrays are indexed (rows, cols) == (height, width), so the canvas
      # is allocated as (h, w); this keeps non-square targets working.
      new_image = np.ones((h, w, 3), dtype='uint8') * 128
      # Integer centring; for an odd leftover the extra pixel goes to the
      # bottom/right edge.
      top = (h - nh) // 2
      left = (w - nw) // 2
      new_image[top:top + nh, left:left + nw] = resized
      return new_image, nw, nh
  def xywh2xyxy(x):
      """Convert boxes from centre form ``(cx, cy, w, h)`` to corners ``(x1, y1, x2, y2)``.

      Accepts a torch tensor or a NumPy array whose last axis holds at least
      four values. The input is left untouched: the result is a copy in which
      only the first four entries of the last axis are rewritten.
      """
      if isinstance(x, torch.Tensor):
          out = x.clone()
      else:
          out = np.copy(x)
      half_w = x[..., 2] / 2
      half_h = x[..., 3] / 2
      out[..., 0] = x[..., 0] - half_w  # left edge
      out[..., 1] = x[..., 1] - half_h  # top edge
      out[..., 2] = x[..., 0] + half_w  # right edge
      out[..., 3] = x[..., 1] + half_h  # bottom edge
      return out
  def box_iou(box1, box2, eps=1e-7):
      """Pairwise IoU between two sets of ``(x1, y1, x2, y2)`` boxes.

      Args:
          box1: ``(N, 4)`` tensor.
          box2: ``(M, 4)`` tensor.
          eps: added to the union to avoid division by zero.

      Returns:
          ``(N, M)`` tensor where entry ``[n, m]`` is IoU(box1[n], box2[m]).
      """
      # Broadcast to (N, 1, 2) / (1, M, 2) corner pairs.
      top_left1, bottom_right1 = box1.unsqueeze(1).chunk(2, 2)
      top_left2, bottom_right2 = box2.unsqueeze(0).chunk(2, 2)
      overlap = (torch.min(bottom_right1, bottom_right2) - torch.max(top_left1, top_left2)).clamp(0).prod(2)
      area1 = (bottom_right1 - top_left1).prod(2)
      area2 = (bottom_right2 - top_left2).prod(2)
      return overlap / (area1 + area2 - overlap + eps)
  def non_max_suppression(
      prediction,
      conf_thres=0.25,
      iou_thres=0.35,
      classes=None,
      agnostic=False,
      multi_label=False,
      labels=(),
      max_det=300,
      nm=0,  # number of masks
  ):
      """Filter raw YOLOv5 predictions by confidence and apply class-aware NMS.

      Args:
          prediction: raw model output tensor, indexed below as
              ``(batch, num_boxes, 5 + nc + nm)``. Per-box columns are
              ``cx, cy, w, h, obj_conf``, then ``nc`` class scores, then ``nm``
              mask coefficients. If a list/tuple is passed, only element 0 is used.
          conf_thres: objectness threshold for candidates, and threshold on
              ``obj_conf * cls_conf`` for the final detections.
          iou_thres: IoU threshold handed to ``torchvision.ops.nms``.
          classes: optional iterable of class ids to keep; others are dropped.
          agnostic: if true, boxes of different classes may suppress each other.
          multi_label: emit one detection per (box, class) pair above
              ``conf_thres``; forced off when there is only one class.
          labels: optional per-image tensors of extra candidates. Column 0 is
              the class id and columns 1:5 a box in the same xywh layout as
              ``prediction``. They are added with confidence 1.0.
          max_det: maximum number of detections kept per image.
          nm: number of trailing mask-coefficient columns.

      Returns:
          A list with one tensor per image, each of shape ``(n, 6 + nm)`` with
          columns ``x1, y1, x2, y2, conf, cls`` followed by the mask columns.
          If processing exceeds ``time_limit`` seconds, the loop stops. Images
          not yet processed keep an empty ``(0, 6 + nm)`` result.
      """
      # Checks
      assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
      assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
      if isinstance(prediction, (list, tuple)):  # YOLOv5 model in validation mode, output = (inference_out, loss_out)
          prediction = prediction[0]  # select only inference output
      device = prediction.device
      mps = 'mps' in device.type  # Apple MPS
      if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
          prediction = prediction.cpu()
      bs = prediction.shape[0]  # batch size
      nc = prediction.shape[2] - nm - 5  # number of classes
      xc = prediction[..., 4] > conf_thres  # candidates: objectness above threshold
      # Settings
      # min_wh = 2  # (pixels) minimum box width and height
      max_wh = 7680  # (pixels) maximum box width and height; also the per-class offset used for batched NMS below
      max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
      time_limit = 0.5 + 0.05 * bs  # seconds to quit after
      redundant = True  # require redundant detections (only used by merge-NMS)
      multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
      merge = False  # use merge-NMS
      t = time.time()
      mi = 5 + nc  # mask start index
      # Default per-image result is an empty (0, 6 + nm) tensor. The list holds one
      # shared tensor, which is safe because entries are only reassigned, never mutated.
      output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
      for xi, x in enumerate(prediction):  # image index, image inference
          # Apply constraints
          # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
          x = x[xc[xi]]  # keep only candidates that passed the objectness threshold
          # Cat apriori labels if autolabelling
          if labels and len(labels[xi]):
              lb = labels[xi]
              v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
              v[:, :4] = lb[:, 1:5]  # box
              v[:, 4] = 1.0  # conf
              v[range(len(lb)), lb[:, 0].long() + 5] = 1.0  # cls
              x = torch.cat((x, v), 0)
          # If none remain process next image
          if not x.shape[0]:
              continue
          # Compute conf
          # NOTE(review): the `5:` slice also covers the mask columns (mi:) when nm > 0,
          # so mask coefficients get scaled by obj_conf as well; confirm this is intended.
          x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf
          # Box/Mask
          box = xywh2xyxy(x[:, :4])  # (center_x, center_y, width, height) to (x1, y1, x2, y2)
          mask = x[:, mi:]  # zero columns if no masks
          # Detections matrix nx6 (xyxy, conf, cls)
          if multi_label:
              # One row for every (box, class) pair whose score exceeds conf_thres.
              i, j = (x[:, 5:mi] > conf_thres).nonzero(as_tuple=False).T
              x = torch.cat((box[i], x[i, 5 + j, None], j[:, None].float(), mask[i]), 1)
          else:  # best class only
              conf, j = x[:, 5:mi].max(1, keepdim=True)
              x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
          # Filter by class
          if classes is not None:
              x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
          # Check shape
          n = x.shape[0]  # number of boxes
          if not n:  # no boxes
              continue
          x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes
          # Batched NMS
          # Shifting boxes by cls * max_wh places different classes in disjoint coordinate
          # ranges (as long as coordinates stay below max_wh), so one NMS call does not
          # suppress across classes. With agnostic=True the offset is 0.
          c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
          boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
          i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
          i = i[:max_det]  # limit detections
          if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
              # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
              iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
              weights = iou * scores[None]  # box weights
              x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
              if redundant:
                  i = i[iou.sum(1) > 1]  # require redundancy
          output[xi] = x[i]
          if mps:
              output[xi] = output[xi].to(device)
          if (time.time() - t) > time_limit:
              # LOGGER.warning(f'WARNING NMS time limit {time_limit:.3f}s exceeded')
              break  # time limit exceeded
      return output
  def data_process_cv2(frame, input_shape, bgr_to_rgb=True):
      """Turn an OpenCV frame into a YOLOv5 ONNX input tensor.

      Args:
          frame: ``H x W x 3`` uint8 image in OpenCV's BGR channel order.
          input_shape: model input size as ``(height, width)``.
          bgr_to_rgb: reorder channels to RGB before building the tensor.
              YOLOv5 models are trained on RGB images (its own data loaders
              flip OpenCV's BGR output), so keep this enabled unless the
              exported model was trained on BGR input.

      Returns:
          ``(image_data, org_data)``: a C-contiguous float32 array of shape
          ``(1, 3, height, width)`` with values in [0, 1], and an untouched BGR
          uint8 copy of the letterboxed frame for drawing results on.
      """
      letterboxed, _, _ = resize_image_cv2(frame, (input_shape[1], input_shape[0]))
      org_data = letterboxed.copy()  # kept in BGR for cv2 drawing/display
      np_data = np.array(letterboxed, np.float32) / 255
      if bgr_to_rgb:
          np_data = np_data[..., ::-1]
      image_data = np.expand_dims(np.transpose(np_data, (2, 0, 1)), 0)  # HWC -> 1xCxHxW
      # The channel flip and transpose produce a strided view; hand the runtime
      # a contiguous buffer.
      image_data = np.ascontiguousarray(image_data)
      return image_data, org_data
  def post_process_yolov5(det, im, label_path='coco128.yaml'):
      """Draw one image's detections onto ``im`` in place and return it.

      Args:
          det: detections for a single image with rows
              ``(x1, y1, x2, y2, conf, cls)``, as produced by ``non_max_suppression``.
              Box columns are rescaled and rounded in place.
          im: image the boxes are drawn on. Its own shape is used as both the
              source and target shape for ``scale_boxes``.
          label_path: YAML file whose ``names`` entry maps class ids to names;
              only read when there is at least one detection.

      Returns:
          ``im`` with a box and class-name caption drawn for every detection.
      """
      if not len(det):
          return im
      hw = im.shape[:2]
      det[:, :4] = scale_boxes(hw, det[:, :4], im.shape).round()
      class_names = yaml_load(label_path)['names']
      palette = Colors()
      for *corners, _score, cls_id in reversed(det):
          idx = int(cls_id)
          box_label(im, corners, class_names[idx], color=palette(idx, True))
      return im
  def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
      """Map xyxy boxes from a letterboxed image back onto the original image.

      Args:
          img1_shape: ``(height, width)`` of the letterboxed image the boxes
              were predicted on.
          boxes: ``(..., 4)`` xyxy boxes (tensor or ndarray); modified in place.
          img0_shape: ``(height, width[, channels])`` of the original image.
          ratio_pad: optional ``((gain, ...), (pad_x, pad_y))`` used instead of
              recomputing the letterbox gain and padding.

      Returns:
          ``boxes`` after removing the padding, undoing the scaling and
          clipping to ``img0_shape``.
      """
      if ratio_pad is None:
          gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])
          # The letterbox centres the image, so each side carries half of the
          # leftover space, horizontally and vertically alike.
          pad = (
              (img1_shape[1] - img0_shape[1] * gain) / 2,
              (img1_shape[0] - img0_shape[0] * gain) / 2,
          )
      else:
          gain = ratio_pad[0][0]
          pad = ratio_pad[1]
      boxes[..., [0, 2]] -= pad[0]
      boxes[..., [1, 3]] -= pad[1]
      boxes[..., :4] /= gain
      clip_boxes(boxes, img0_shape)
      return boxes
  def clip_boxes(boxes, shape):
      """Clamp xyxy boxes in place to an image of ``shape`` = ``(height, width, ...)``.

      x coordinates are limited to ``[0, width]`` and y coordinates to
      ``[0, height]``. Works for torch tensors and NumPy arrays; returns None.
      """
      h, w = shape[0], shape[1]
      if not isinstance(boxes, torch.Tensor):
          boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, w)
          boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, h)
          return
      # Clamp each coordinate column through an in-place view.
      for col, bound in ((0, w), (1, h), (2, w), (3, h)):
          boxes[..., col].clamp_(0, bound)
  def yaml_load(file='coco128.yaml'):
      """Parse the YAML file ``file`` and return its contents.

      ``yaml.safe_load`` is used, so the file cannot construct arbitrary Python
      objects. The file is decoded explicitly as UTF-8. Relying on the platform
      default encoding (e.g. GBK on Chinese Windows) would garble non-ASCII
      class names, and ``errors='ignore'`` would hide that silently. The error
      mode is kept so files with stray bytes still load.
      """
      with open(file, encoding='utf-8', errors='ignore') as f:
          return yaml.safe_load(f)
  class Colors:
      """Ultralytics colour palette (https://ultralytics.com/).

      Holds 20 fixed RGB colours and hands them out cyclically by index, so a
      given class id always maps to the same colour.
      """

      def __init__(self):
          """Build ``palette`` (list of RGB tuples) and ``n`` (its length) from fixed hex codes."""
          hex_codes = (
              "FF3838", "FF9D97", "FF701F", "FFB21D", "CFD231",
              "48F90A", "92CC17", "3DDB86", "1A9334", "00D4BB",
              "2C99A8", "00C2FF", "344593", "6473FF", "0018EC",
              "8438FF", "520085", "CB38FF", "FF95C8", "FF37C7",
          )
          self.palette = list(map(self.hex2rgb, (f"#{code}" for code in hex_codes)))
          self.n = len(self.palette)

      def __call__(self, i, bgr=False):
          """Return colour ``i`` (index wraps modulo ``n``) as RGB, or as BGR when ``bgr`` is true."""
          rgb = self.palette[int(i) % self.n]
          return rgb[::-1] if bgr else rgb

      @staticmethod
      def hex2rgb(h):
          """Convert a ``'#RRGGBB'`` string to an ``(r, g, b)`` tuple of ints (PIL channel order)."""
          return tuple(int(h[start:start + 2], 16) for start in (1, 3, 5))
  def box_label(im, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255)):
      """Draw one bounding box, plus an optional filled caption, onto ``im`` in place.

      Args:
          im: image drawn on with OpenCV (modified in place).
          box: ``(x1, y1, x2, y2)`` corner coordinates; each is converted with ``int``.
          label: caption text; when empty only the rectangle is drawn.
          color: colour of the box outline and the caption background.
          txt_color: colour of the caption text.
      """
      thickness = 2
      top_left = (int(box[0]), int(box[1]))
      bottom_right = (int(box[2]), int(box[3]))
      cv2.rectangle(im, top_left, bottom_right, color, thickness=thickness, lineType=cv2.LINE_AA)
      if not label:
          return
      font_scale = thickness / 3
      font_thickness = max(thickness - 1, 1)
      (text_w, text_h), _baseline = cv2.getTextSize(label, 0, fontScale=font_scale, thickness=font_thickness)
      x1, y1 = top_left
      # Caption sits above the box when there is room, otherwise just inside its top edge.
      if y1 - text_h >= 3:
          caption_corner = (x1 + text_w, y1 - text_h - 3)
          text_org = (x1, y1 - 2)
      else:
          caption_corner = (x1 + text_w, y1 + text_h + 3)
          text_org = (x1, y1 + text_h + 2)
      cv2.rectangle(im, top_left, caption_corner, color, -1, cv2.LINE_AA)
      cv2.putText(im, label, text_org, 0, font_scale, txt_color, thickness=font_thickness, lineType=cv2.LINE_AA)
  def is_ascii(s) -> bool:
      """Return True if ``str(s)`` consists only of ASCII characters (code points < 128).

      Non-string inputs (lists, tuples, None, ...) are checked through their
      ``str()`` representation; the empty string counts as ASCII.
      """
      return str(s).isascii()

上述工具函数(如 non_max_suppression、scale_boxes、Colors、box_label)的原始实现可在 yolov5-master 源码中找到。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/725326
推荐阅读
相关标签
  

闽ICP备14008679号