# object_detection_metric.py
  1. """
  2. 检测:
  3. [{'Label': 1, 'Confidence': 1.0, 'BoundingBox':[2, 3, 4, 5]}]
  4. [{'Label': 1, 'Confidence': 1.0, 'BoundingBox':[2, 3, 4, 5]},{'Label': 2, 'Confidence': 1.0, 'BoundingBox':[2, 3, 4, 5]}]
  5. [{'Label': 0, 'Confidence': 1.0, 'BoundingBox':[0, 0, 0, 0]}]
  6. []
  7. """
import json
import sys
from enum import Enum

import numpy as np


class MethodAveragePrecision(Enum):
    """
    Method used to interpolate average precision: over every point of the
    precision-recall curve, or at the 11 equally spaced PASCAL VOC recall levels.
    """
    EveryPointInterpolation = 1
    ElevenPointInterpolation = 2


# Adapted from https://github.com/rafaelpadilla/Object-Detection-Metrics
class Evaluator_object_detection(object):
    def __init__(self, iou_thres, method=MethodAveragePrecision.EveryPointInterpolation):
        self.method = method
        self.iou_thres = iou_thres
        # Image indices whose gt/pred file was empty.
        self.wrong_file = {'gt_wrong': [], 'pred_wrong': []}
        # Bookkeeping for background images (gt is a single label-0, all-zero box).
        self.background_images_all_indexs = []
        self.background_images_results = {}
        self.background_images_results_count = {'background_images_all_nums': 0}
        self.all_image_dict = {'images_all_nums': 0}
        # Per-class image indices of positives / false positives / true positives.
        self.all_no_background_images_pos_results = {}
        self.all_no_background_images_fp_results = {}
        self.all_no_background_images_tp_results = {}
        self.pred_results = []
        self.gt_results = []
        self.classes = []

    def generate_metrics(self):
        classes = sorted(self.classes)
        # Precision x recall is obtained individually for each class.
        ret = []  # list of metrics (precision, recall, average precision) per class
        for c in classes:
            if c == 0:
                # Label 0 is the background class; it is handled by add_batch.
                continue
            # Detections of class c, sorted by decreasing confidence.
            dects = sorted([d for d in self.pred_results if d[1] == c],
                           key=lambda d: d[2], reverse=True)
            # Ground truths of class c, keyed by image index.
            gts = {}
            npos = 0
            for g in self.gt_results:
                if g[1] == c:
                    npos += 1
                    gts.setdefault(g[0], []).append(g)
                    self.all_no_background_images_pos_results.setdefault(str(c), []).append(g[0])
            TP = np.zeros(len(dects))
            FP = np.zeros(len(dects))
            # Per image, flags marking which ground truths were already matched.
            det = {key: np.zeros(len(gts[key])) for key in gts}
            # Loop through detections.
            for d in range(len(dects)):
                # Ground truths in the same image as this detection.
                gt = gts[dects[d][0]] if dects[d][0] in gts else []
                iouMax = sys.float_info.min
                jmax = -1
                for j in range(len(gt)):
                    iou = Evaluator_object_detection.iou(dects[d][3], gt[j][3])
                    if iou > iouMax:
                        iouMax = iou
                        jmax = j
                # A detection is a true positive when it overlaps a not-yet-matched
                # ground truth of the same class with IoU >= iou_thres; otherwise
                # it is a false positive.
                if iouMax >= self.iou_thres and det[dects[d][0]][jmax] == 0:
                    TP[d] = 1  # count as true positive
                    det[dects[d][0]][jmax] = 1  # flag ground truth as already 'seen'
                    self.all_no_background_images_tp_results.setdefault(str(c), []).append(dects[d][0])
                else:
                    FP[d] = 1  # count as false positive
                    self.all_no_background_images_fp_results.setdefault(str(c), []).append(dects[d][0])
            # Compute precision, recall and average precision.
            acc_FP = np.cumsum(FP)
            acc_TP = np.cumsum(TP)
            rec = acc_TP / npos
            prec = np.divide(acc_TP, (acc_FP + acc_TP))
            # Depending on the method, call the right implementation.
            if self.method == MethodAveragePrecision.EveryPointInterpolation:
                [ap, mpre, mrec, ii] = Evaluator_object_detection.CalculateAveragePrecision(rec, prec)
            else:
                [ap, mpre, mrec, _] = Evaluator_object_detection.ElevenPointInterpolatedAP(rec, prec)
            # Add this class's result to the returned list.
            ret.append({
                'class': c,
                'precision': prec,
                'recall': rec,
                'AP': ap,
                'interpolated precision': mpre,
                'interpolated recall': mrec,
                'total positives': npos,
                'total TP': np.sum(TP),
                'total FP': np.sum(FP)
            })
        return ret

    def add_batch(self, gt_file, pred_file, image_index):
        if gt_file == []:
            self.wrong_file['gt_wrong'].append(image_index)
        elif pred_file == []:
            self.wrong_file['pred_wrong'].append(image_index)
        else:
            self.all_image_dict['images_all_nums'] += 1
            # A ground truth marks the image as background when it is a single
            # label-0 entry with an all-zero box.
            if len(gt_file) == 1 and gt_file[0]['Label'] == 0 and gt_file[0]['BoundingBox'] == [0, 0, 0, 0]:
                self.background_images_results_count['background_images_all_nums'] += 1
                if len(pred_file) == 1 and pred_file[0]['Label'] == 0 and pred_file[0]['BoundingBox'] == [0, 0, 0, 0]:
                    # The prediction is also background: count the correct call.
                    key = 'gtlabel_0_predlabel_0'
                    self.background_images_results_count[key] = \
                        self.background_images_results_count.get(key, 0) + 1
                else:
                    # Background image with (false positive) detections.
                    for pred in pred_file:
                        key = 'gtlabel_0_predlabel_{}'.format(pred['Label'])
                        self.background_images_results.setdefault(key, []).append(image_index)
                        self.background_images_results_count[key] = \
                            self.background_images_results_count.get(key, 0) + 1
            else:
                for gt in gt_file:
                    left, top, right, bottom = gt['BoundingBox']
                    self.gt_results.append([image_index, gt['Label'], 1.0, (left, top, right, bottom)])
                    if gt['Label'] not in self.classes:
                        self.classes.append(gt['Label'])
                for pred in pred_file:
                    left, top, right, bottom = pred['BoundingBox']
                    self.pred_results.append([image_index, pred['Label'], pred['Confidence'],
                                              (left, top, right, bottom)])
                    if pred['Label'] not in self.classes:
                        self.classes.append(pred['Label'])
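
    # Illustrative example of the background bookkeeping above: for a
    # background image '7' whose prediction contains one label-2 box,
    #
    #     evaluator.add_batch([{'Label': 0, 'Confidence': 1.0, 'BoundingBox': [0, 0, 0, 0]}],
    #                         [{'Label': 2, 'Confidence': 0.8, 'BoundingBox': [2, 3, 4, 5]}], '7')
    #
    # records '7' in background_images_results['gtlabel_0_predlabel_2'] and
    # increments the matching entry of background_images_results_count.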

    @staticmethod
    def iou(boxA, boxB):
        # Boxes are (x1, y1, x2, y2). If they don't intersect, IoU is 0.
        if not Evaluator_object_detection._boxesIntersect(boxA, boxB):
            return 0
        interArea = Evaluator_object_detection._getIntersectionArea(boxA, boxB)
        union = Evaluator_object_detection._getUnionAreas(boxA, boxB, interArea=interArea)
        # Intersection over union.
        iou = interArea / union
        assert iou >= 0
        return iou

    @staticmethod
    def _boxesIntersect(boxA, boxB):
        if boxA[0] > boxB[2]:
            return False  # boxA is right of boxB
        if boxB[0] > boxA[2]:
            return False  # boxA is left of boxB
        if boxA[3] < boxB[1]:
            return False  # boxA is above boxB
        if boxA[1] > boxB[3]:
            return False  # boxA is below boxB
        return True

    @staticmethod
    def _getIntersectionArea(boxA, boxB):
        xA = max(boxA[0], boxB[0])
        yA = max(boxA[1], boxB[1])
        xB = min(boxA[2], boxB[2])
        yB = min(boxA[3], boxB[3])
        # Intersection area (inclusive pixel coordinates).
        return (xB - xA + 1) * (yB - yA + 1)

    @staticmethod
    def _getUnionAreas(boxA, boxB, interArea=None):
        area_A = Evaluator_object_detection._getArea(boxA)
        area_B = Evaluator_object_detection._getArea(boxB)
        if interArea is None:
            interArea = Evaluator_object_detection._getIntersectionArea(boxA, boxB)
        return float(area_A + area_B - interArea)

    @staticmethod
    def _getArea(box):
        return (box[2] - box[0] + 1) * (box[3] - box[1] + 1)
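
    # Worked example: with the inclusive pixel convention used by _getArea and
    # _getIntersectionArea (the "+ 1" terms), boxes A = (0, 0, 9, 9) and
    # B = (5, 5, 14, 14) each cover 100 pixels and share a 5 x 5 = 25 pixel
    # overlap, so iou(A, B) = 25 / (100 + 100 - 25) = 25 / 175 ≈ 0.143.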

    @staticmethod
    def CalculateAveragePrecision(rec, prec):
        # Every-point (all-point) interpolated average precision.
        mrec = [0] + list(rec) + [1]
        mpre = [0] + list(prec) + [0]
        # Make the precision envelope monotonically decreasing from right to left.
        for i in range(len(mpre) - 1, 0, -1):
            mpre[i - 1] = max(mpre[i - 1], mpre[i])
        # Indices where recall changes.
        ii = []
        for i in range(len(mrec) - 1):
            if mrec[1 + i] != mrec[i]:
                ii.append(i + 1)
        # AP is the area under the interpolated precision-recall curve.
        ap = 0
        for i in ii:
            ap = ap + np.sum((mrec[i] - mrec[i - 1]) * mpre[i])
        return [ap, mpre[0:len(mpre) - 1], mrec[0:len(mpre) - 1], ii]
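
    # Worked example: for rec = [0.5, 0.5, 1.0] and prec = [1.0, 0.5, 2/3]
    # (three detections, two positives), the monotone precision envelope over
    # the padded recalls [0, 0.5, 0.5, 1.0, 1] is [1.0, 1.0, 2/3, 2/3, 0], so
    # AP = (0.5 - 0) * 1.0 + (1.0 - 0.5) * 2/3 = 5/6 ≈ 0.833.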

    @staticmethod
    def ElevenPointInterpolatedAP(rec, prec):
        # 11-point interpolated average precision (PASCAL VOC 2007 style).
        mrec = np.asarray(rec)
        mpre = np.asarray(prec)
        recallValues = np.linspace(0, 1, 11)
        recallValues = list(recallValues[::-1])
        rhoInterp = []
        recallValid = []
        # For each recall level r in (1.0, 0.9, ..., 0):
        for r in recallValues:
            # Indices of all recall values greater than or equal to r.
            argGreaterRecalls = np.argwhere(mrec >= r)
            pmax = 0
            # If there are recalls at or above r, take the maximum precision among them.
            if argGreaterRecalls.size != 0:
                pmax = max(mpre[argGreaterRecalls.min():])
            recallValid.append(r)
            rhoInterp.append(pmax)
        # By definition, AP = sum(max precision whose recall is above r) / 11.
        ap = sum(rhoInterp) / 11
        # Generate values for the precision-recall plot.
        rvals = [recallValid[0]] + recallValid + [0]
        pvals = [0] + rhoInterp + [0]
        cc = []
        for i in range(len(rvals)):
            p = (rvals[i], pvals[i - 1])
            if p not in cc:
                cc.append(p)
            p = (rvals[i], pvals[i])
            if p not in cc:
                cc.append(p)
        recallValues = [i[0] for i in cc]
        rhoInterp = [i[1] for i in cc]
        return [ap, rhoInterp, recallValues, None]
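
# Worked example: for the same rec = [0.5, 0.5, 1.0] and prec = [1.0, 0.5, 2/3],
# the 11-point interpolated precision is 1.0 at recall levels 0, 0.1, ..., 0.5
# and 2/3 at 0.6, ..., 1.0, so AP = (6 * 1.0 + 5 * 2/3) / 11 = 28/33 ≈ 0.848.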


if __name__ == "__main__":
    import os

    from object_detection_test.test_files_generate import test_pred_files_get, test_gt_files_get

    currentPath = os.path.dirname(os.path.abspath(__file__))
    path_pred = os.path.join(currentPath, 'object_detection_test', 'all_files', 'pred')
    pred_files = test_pred_files_get(path_pred)
    path_gt = os.path.join(currentPath, 'object_detection_test', 'all_files', 'gt')
    gt_files = test_gt_files_get(path_gt)
    evaluator = Evaluator_object_detection(iou_thres=0.3, method=MethodAveragePrecision.EveryPointInterpolation)
    assert len(gt_files) == len(pred_files)
    for i in range(len(gt_files)):
        image_index = str(i)
        evaluator.add_batch(json.loads(gt_files[i]), json.loads(pred_files[i]), image_index)
    print(evaluator.wrong_file)
    print(evaluator.background_images_results)
    print(evaluator.all_image_dict)
    metricsPerClass = evaluator.generate_metrics()
    for mc in metricsPerClass:
        # Print AP per class.
        print('%s: %f' % (mc['class'], mc['AP']))
        # Overall precision/recall for the class, from the accumulated totals
        # (equal to the last entries of mc['precision'] / mc['recall']).
        precision_all = mc['total TP'] / (mc['total TP'] + mc['total FP'])
        recall_all = mc['total TP'] / mc['total positives']
        print('precision: %f, recall: %f' % (precision_all, recall_all))