当前位置:   article > 正文

rtm姿态跟踪

rtm姿态跟踪

6年前:

GitHub - YuliangXiu/PoseFlow: PoseFlow: Efficient Online Pose Tracking (BMVC'18)

报错:

Clarification on min_keypoints in tracking · Issue #1411 · open-mmlab/mmpose · GitHub

https://github.com/open-mmlab/mmpose/blob/c8e91ff456d82c7f985bf938cabfb68b2aa51d27/mmpose/apis/inference_tracking.py#L227

  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import warnings
  3. import numpy as np
  4. from mmpose.core import OneEuroFilter, oks_iou
  5. def _compute_iou(bboxA, bboxB):
  6. """Compute the Intersection over Union (IoU) between two boxes .
  7. Args:
  8. bboxA (list): The first bbox info (left, top, right, bottom, score).
  9. bboxB (list): The second bbox info (left, top, right, bottom, score).
  10. Returns:
  11. float: The IoU value.
  12. """
  13. x1 = max(bboxA[0], bboxB[0])
  14. y1 = max(bboxA[1], bboxB[1])
  15. x2 = min(bboxA[2], bboxB[2])
  16. y2 = min(bboxA[3], bboxB[3])
  17. inter_area = max(0, x2 - x1) * max(0, y2 - y1)
  18. bboxA_area = (bboxA[2] - bboxA[0]) * (bboxA[3] - bboxA[1])
  19. bboxB_area = (bboxB[2] - bboxB[0]) * (bboxB[3] - bboxB[1])
  20. union_area = float(bboxA_area + bboxB_area - inter_area)
  21. if union_area == 0:
  22. union_area = 1e-5
  23. warnings.warn('union_area=0 is unexpected')
  24. iou = inter_area / union_area
  25. return iou
  26. def _track_by_iou(res, results_last, thr):
  27. """Get track id using IoU tracking greedily.
  28. Args:
  29. res (dict): The bbox & pose results of the person instance.
  30. results_last (list[dict]): The bbox & pose & track_id info of the
  31. last frame (bbox_result, pose_result, track_id).
  32. thr (float): The threshold for iou tracking.
  33. Returns:
  34. int: The track id for the new person instance.
  35. list[dict]: The bbox & pose & track_id info of the persons
  36. that have not been matched on the last frame.
  37. dict: The matched person instance on the last frame.
  38. """
  39. bbox = list(res['bbox'])
  40. max_iou_score = -1
  41. max_index = -1
  42. match_result = {}
  43. for index, res_last in enumerate(results_last):
  44. bbox_last = list(res_last['bbox'])
  45. iou_score = _compute_iou(bbox, bbox_last)
  46. if iou_score > max_iou_score:
  47. max_iou_score = iou_score
  48. max_index = index
  49. if max_iou_score > thr:
  50. track_id = results_last[max_index]['track_id']
  51. match_result = results_last[max_index]
  52. del results_last[max_index]
  53. else:
  54. track_id = -1
  55. return track_id, results_last, match_result
  56. def _track_by_oks(res, results_last, thr):
  57. """Get track id using OKS tracking greedily.
  58. Args:
  59. res (dict): The pose results of the person instance.
  60. results_last (list[dict]): The pose & track_id info of the
  61. last frame (pose_result, track_id).
  62. thr (float): The threshold for oks tracking.
  63. Returns:
  64. int: The track id for the new person instance.
  65. list[dict]: The pose & track_id info of the persons
  66. that have not been matched on the last frame.
  67. dict: The matched person instance on the last frame.
  68. """
  69. pose = res['keypoints'].reshape((-1))
  70. area = res['area']
  71. max_index = -1
  72. match_result = {}
  73. if len(results_last) == 0:
  74. return -1, results_last, match_result
  75. pose_last = np.array(
  76. [res_last['keypoints'].reshape((-1)) for res_last in results_last])
  77. area_last = np.array([res_last['area'] for res_last in results_last])
  78. oks_score = oks_iou(pose, pose_last, area, area_last)
  79. max_index = np.argmax(oks_score)
  80. if oks_score[max_index] > thr:
  81. track_id = results_last[max_index]['track_id']
  82. match_result = results_last[max_index]
  83. del results_last[max_index]
  84. else:
  85. track_id = -1
  86. return track_id, results_last, match_result
  87. def _get_area(results):
  88. """Get bbox for each person instance on the current frame.
  89. Args:
  90. results (list[dict]): The pose results of the current frame
  91. (pose_result).
  92. Returns:
  93. list[dict]: The bbox & pose info of the current frame
  94. (bbox_result, pose_result, area).
  95. """
  96. for result in results:
  97. if 'bbox' in result:
  98. result['area'] = ((result['bbox'][2] - result['bbox'][0]) *
  99. (result['bbox'][3] - result['bbox'][1]))
  100. else:
  101. xmin = np.min(
  102. result['keypoints'][:, 0][result['keypoints'][:, 0] > 0],
  103. initial=1e10)
  104. xmax = np.max(result['keypoints'][:, 0])
  105. ymin = np.min(
  106. result['keypoints'][:, 1][result['keypoints'][:, 1] > 0],
  107. initial=1e10)
  108. ymax = np.max(result['keypoints'][:, 1])
  109. result['area'] = (xmax - xmin) * (ymax - ymin)
  110. result['bbox'] = np.array([xmin, ymin, xmax, ymax])
  111. return results
  112. def _temporal_refine(result, match_result, fps=None):
  113. """Refine koypoints using tracked person instance on last frame.
  114. Args:
  115. results (dict): The pose results of the current frame
  116. (pose_result).
  117. match_result (dict): The pose results of the last frame
  118. (match_result)
  119. Returns:
  120. (array): The person keypoints after refine.
  121. """
  122. if 'one_euro' in match_result:
  123. result['keypoints'][:, :2] = match_result['one_euro'](
  124. result['keypoints'][:, :2])
  125. result['one_euro'] = match_result['one_euro']
  126. else:
  127. result['one_euro'] = OneEuroFilter(result['keypoints'][:, :2], fps=fps)
  128. return result['keypoints']
  129. def get_track_id(results,
  130. results_last,
  131. next_id,
  132. min_keypoints=3,
  133. use_oks=False,
  134. tracking_thr=0.3,
  135. use_one_euro=False,
  136. fps=None):
  137. """Get track id for each person instance on the current frame.
  138. Args:
  139. results (list[dict]): The bbox & pose results of the current frame
  140. (bbox_result, pose_result).
  141. results_last (list[dict], optional): The bbox & pose & track_id info
  142. of the last frame (bbox_result, pose_result, track_id). None is
  143. equivalent to an empty result list. Default: None
  144. next_id (int): The track id for the new person instance.
  145. min_keypoints (int): Minimum number of keypoints recognized as person.
  146. 0 means no minimum threshold required. Default: 3.
  147. use_oks (bool): Flag to using oks tracking. default: False.
  148. tracking_thr (float): The threshold for tracking.
  149. use_one_euro (bool): Option to use one-euro-filter. default: False.
  150. fps (optional): Parameters that d_cutoff
  151. when one-euro-filter is used as a video input
  152. Returns:
  153. tuple:
  154. - results (list[dict]): The bbox & pose & track_id info of the \
  155. current frame (bbox_result, pose_result, track_id).
  156. - next_id (int): The track id for the new person instance.
  157. """
  158. if use_one_euro:
  159. warnings.warn(
  160. 'In the future, get_track_id() will no longer perform '
  161. 'temporal refinement and the arguments `use_one_euro` and '
  162. '`fps` will be deprecated. This part of function has been '
  163. 'migrated to Smoother (mmpose.core.Smoother). See '
  164. 'demo/top_down_pose_trackign_demo_with_mmdet.py for an '
  165. 'example.', DeprecationWarning)
  166. if results_last is None:
  167. results_last = []
  168. results = _get_area(results)
  169. if use_oks:
  170. _track = _track_by_oks
  171. else:
  172. _track = _track_by_iou
  173. for result in results:
  174. track_id, results_last, match_result = _track(result, results_last,
  175. tracking_thr)
  176. if track_id == -1:
  177. if np.count_nonzero(result['keypoints'][:, 1]) >= min_keypoints:
  178. result['track_id'] = next_id
  179. next_id += 1
  180. else:
  181. # If the number of keypoints detected is small,
  182. # delete that person instance.
  183. result['keypoints'][:, 1] = -10
  184. result['bbox'] *= 0
  185. result['track_id'] = -1
  186. else:
  187. result['track_id'] = track_id
  188. if use_one_euro:
  189. result['keypoints'] = _temporal_refine(
  190. result, match_result, fps=fps)
  191. del match_result
  192. return results, next_id
  193. def vis_pose_tracking_result(model,
  194. img,
  195. result,
  196. radius=4,
  197. thickness=1,
  198. kpt_score_thr=0.3,
  199. dataset='TopDownCocoDataset',
  200. dataset_info=None,
  201. show=False,
  202. out_file=None):
  203. """Visualize the pose tracking results on the image.
  204. Args:
  205. model (nn.Module): The loaded detector.
  206. img (str | np.ndarray): Image filename or loaded image.
  207. result (list[dict]): The results to draw over `img`
  208. (bbox_result, pose_result).
  209. radius (int): Radius of circles.
  210. thickness (int): Thickness of lines.
  211. kpt_score_thr (float): The threshold to visualize the keypoints.
  212. skeleton (list[tuple]): Default None.
  213. show (bool): Whether to show the image. Default True.
  214. out_file (str|None): The filename of the output visualization image.
  215. """
  216. if hasattr(model, 'module'):
  217. model = model.module
  218. palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102],
  219. [230, 230, 0], [255, 153, 255], [153, 204, 255],
  220. [255, 102, 255], [255, 51, 255], [102, 178, 255],
  221. [51, 153, 255], [255, 153, 153], [255, 102, 102],
  222. [255, 51, 51], [153, 255, 153], [102, 255, 102],
  223. [51, 255, 51], [0, 255, 0], [0, 0, 255], [255, 0, 0],
  224. [255, 255, 255]])
  225. if dataset_info is None and dataset is not None:
  226. warnings.warn(
  227. 'dataset is deprecated.'
  228. 'Please set `dataset_info` in the config.'
  229. 'Check https://github.com/open-mmlab/mmpose/pull/663 for details.',
  230. DeprecationWarning)
  231. # TODO: These will be removed in the later versions.
  232. if dataset in ('TopDownCocoDataset', 'BottomUpCocoDataset',
  233. 'TopDownOCHumanDataset'):
  234. kpt_num = 17
  235. skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12],
  236. [5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9],
  237. [8, 10], [1, 2], [0, 1], [0, 2], [1, 3], [2, 4],
  238. [3, 5], [4, 6]]
  239. elif dataset == 'TopDownCocoWholeBodyDataset':
  240. kpt_num = 133
  241. skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12],
  242. [5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9],
  243. [8, 10], [1, 2], [0, 1], [0, 2],
  244. [1, 3], [2, 4], [3, 5], [4, 6], [15, 17], [15, 18],
  245. [15, 19], [16, 20], [16, 21], [16, 22], [91, 92],
  246. [92, 93], [93, 94], [94, 95], [91, 96], [96, 97],
  247. [97, 98], [98, 99], [91, 100], [100, 101], [101, 102],
  248. [102, 103], [91, 104], [104, 105], [105, 106],
  249. [106, 107], [91, 108], [108, 109], [109, 110],
  250. [110, 111], [112, 113], [113, 114], [114, 115],
  251. [115, 116], [112, 117], [117, 118], [118, 119],
  252. [119, 120], [112, 121], [121, 122], [122, 123],
  253. [123, 124], [112, 125], [125, 126], [126, 127],
  254. [127, 128], [112, 129], [129, 130], [130, 131],
  255. [131, 132]]
  256. radius = 1
  257. elif dataset == 'TopDownAicDataset':
  258. kpt_num = 14
  259. skeleton = [[2, 1], [1, 0], [0, 13], [13, 3], [3, 4], [4, 5],
  260. [8, 7], [7, 6], [6, 9], [9, 10], [10, 11], [12, 13],
  261. [0, 6], [3, 9]]
  262. elif dataset == 'TopDownMpiiDataset':
  263. kpt_num = 16
  264. skeleton = [[0, 1], [1, 2], [2, 6], [6, 3], [3, 4], [4, 5], [6, 7],
  265. [7, 8], [8, 9], [8, 12], [12, 11], [11, 10], [8, 13],
  266. [13, 14], [14, 15]]
  267. elif dataset in ('OneHand10KDataset', 'FreiHandDataset',
  268. 'PanopticDataset'):
  269. kpt_num = 21
  270. skeleton = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7],
  271. [7, 8], [0, 9], [9, 10], [10, 11], [11, 12], [0, 13],
  272. [13, 14], [14, 15], [15, 16], [0, 17], [17, 18],
  273. [18, 19], [19, 20]]
  274. elif dataset == 'InterHand2DDataset':
  275. kpt_num = 21
  276. skeleton = [[0, 1], [1, 2], [2, 3], [4, 5], [5, 6], [6, 7], [8, 9],
  277. [9, 10], [10, 11], [12, 13], [13, 14], [14, 15],
  278. [16, 17], [17, 18], [18, 19], [3, 20], [7, 20],
  279. [11, 20], [15, 20], [19, 20]]
  280. else:
  281. raise NotImplementedError()
  282. elif dataset_info is not None:
  283. kpt_num = dataset_info.keypoint_num
  284. skeleton = dataset_info.skeleton
  285. for res in result:
  286. track_id = res['track_id']
  287. bbox_color = palette[track_id % len(palette)]
  288. pose_kpt_color = palette[[track_id % len(palette)] * kpt_num]
  289. pose_link_color = palette[[track_id % len(palette)] * len(skeleton)]
  290. img = model.show_result(
  291. img, [res],
  292. skeleton,
  293. radius=radius,
  294. thickness=thickness,
  295. pose_kpt_color=pose_kpt_color,
  296. pose_link_color=pose_link_color,
  297. bbox_color=tuple(bbox_color.tolist()),
  298. kpt_score_thr=kpt_score_thr,
  299. show=show,
  300. out_file=out_file)
  301. return img

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/article/detail/59514?site
推荐阅读
相关标签
  

闽ICP备14008679号