# classification_metric.py
  1. """
  2. 分类:
  3. [{'Label': 0, 'Confidence': 1.0}]
  4. []
  5. """
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix
from enum import Enum
class MethodAveragePrecision(Enum):
    """
    Method used to compute average precision:
    every-point interpolation or 11-point interpolation.
    """
    EveryPointInterpolation = 1
    ElevenPointInterpolation = 2
class Evaluator_classification():
    def __init__(self, iou_thres, method=MethodAveragePrecision.EveryPointInterpolation):
        self.wrong_file = {}                            # images with an empty gt/pred file
        self.background_images_all_indexs = []
        self.background_images_results = {}
        self.background_images_results_count = {}
        self.all_image_dict = {}
        self.all_no_background_images_pos_results = {}  # per class: images with a gt of that class
        self.all_no_background_images_fp_results = {}   # per class: images wrongly predicted as that class
        self.all_no_background_images_tp_results = {}   # per class: images correctly predicted as that class
        self.pred_results = []                          # [image_index, pred_label] pairs
        self.gt_results = []                            # [image_index, gt_label] pairs
        self.classes = []
        self.wrong_file['gt_wrong'] = []
        self.wrong_file['pred_wrong'] = []
        self.background_images_results_count['background_images_all_nums'] = 0
        self.all_image_dict['images_all_nums'] = 0
    def generate_metrics(self):
        classes = sorted(self.classes)
        ret = []
        assert len(self.gt_results) == len(self.pred_results)
        for c in classes:
            # Collect every prediction of class c together with the gt of the same image.
            dects = []
            gts = []
            for i in range(len(self.pred_results)):
                if self.pred_results[i][1] == c:
                    dects.append(self.pred_results[i])
                    gts.append(self.gt_results[i])
            # npos: number of ground-truth instances of class c.
            npos = 0
            for g in self.gt_results:
                if g[1] == c:
                    npos += 1
                    if str(c) not in self.all_no_background_images_pos_results.keys():
                        self.all_no_background_images_pos_results[str(c)] = [g[0]]
                    else:
                        self.all_no_background_images_pos_results[str(c)].append(g[0])
            # Mark each prediction of class c as TP (gt agrees) or FP (gt differs).
            # All confidences are 1.0 here, so no sorting by confidence is needed.
            TP = np.zeros(len(dects))
            FP = np.zeros(len(dects))
            for d in range(len(dects)):
                if dects[d][1] == gts[d][1]:
                    TP[d] = 1
                    if str(c) not in self.all_no_background_images_tp_results.keys():
                        self.all_no_background_images_tp_results[str(c)] = [dects[d][0]]
                    else:
                        self.all_no_background_images_tp_results[str(c)].append(dects[d][0])
                else:
                    FP[d] = 1  # count as false positive
                    if str(c) not in self.all_no_background_images_fp_results.keys():
                        self.all_no_background_images_fp_results[str(c)] = [dects[d][0]]
                    else:
                        self.all_no_background_images_fp_results[str(c)].append(dects[d][0])
            # Cumulative precision/recall over the predictions of class c.
            acc_FP = np.cumsum(FP)
            acc_TP = np.cumsum(TP)
            rec = 0 if npos == 0 else acc_TP / npos
            prec = np.divide(acc_TP, (acc_FP + acc_TP))
            # Add the per-class result to the returned list.
            r = {
                'class': c,
                'precision': prec,
                'recall': rec,
                'total positives': npos,
                'total TP': np.sum(TP),
                'total FP': np.sum(FP)
            }
            ret.append(r)
        return ret
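    # A minimal sketch (not part of the evaluator) of how the cumulative
    # precision/recall arrays built above behave for one class, with assumed
    # TP/FP flags and an assumed npos:
    #   TP = [1, 0, 1, 1], FP = [0, 1, 0, 0], npos = 5
    #   acc_TP = [1, 1, 2, 3], acc_FP = [0, 1, 1, 1]
    #   prec = acc_TP / (acc_TP + acc_FP) = [1.0, 0.5, 0.667, 0.75]
    #   rec  = acc_TP / npos              = [0.2, 0.2, 0.4,   0.6]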
    def add_batch(self, gt_file, pred_file, image_index):
        if not gt_file:
            self.wrong_file['gt_wrong'].append(image_index)
        elif not pred_file:
            self.wrong_file['pred_wrong'].append(image_index)
        else:
            self.all_image_dict['images_all_nums'] += 1
            assert len(gt_file) == 1 and len(pred_file) == 1, 'a classification file must contain exactly one label'
            each_gt_label = gt_file[0]['Label']
            each_pred_label = pred_file[0]['Label']
            self.gt_results.append([image_index, each_gt_label])
            self.pred_results.append([image_index, each_pred_label])
            if each_gt_label not in self.classes:
                self.classes.append(each_gt_label)
            if each_pred_label not in self.classes:
                self.classes.append(each_pred_label)
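# Usage sketch for add_batch (hypothetical image names and labels):
#   ev = Evaluator_classification(iou_thres=0)
#   ev.add_batch([{'Label': 1, 'Confidence': 1.0}],
#                [{'Label': 1, 'Confidence': 1.0}], 'img_0001')      # counted as TP for class 1
#   ev.add_batch([], [{'Label': 0, 'Confidence': 1.0}], 'img_0002')  # logged under wrong_file['gt_wrong']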
def sklearn_result(all_gt, all_pred, image_index, inferNetLabelCount):
    assert all_gt.shape == all_pred.shape, "gt and pred counts do not match"
    # Entries marked -1 correspond to images whose gt/pred file was empty.
    wrong_gt_index = np.argwhere(all_gt == -1)
    wrong_pred_index = np.argwhere(all_pred == -1)
    wrong_pred = image_index[wrong_pred_index]
    wrong_gt = image_index[wrong_gt_index]
    print(wrong_pred.transpose())
    print(wrong_gt.transpose())
    # Merge the bad gt and pred indices, drop duplicates, and remove those
    # entries before computing the metrics.
    wrong_index = np.unique(np.concatenate((wrong_gt_index, wrong_pred_index)))
    all_gt = np.delete(all_gt, wrong_index)
    all_pred = np.delete(all_pred, wrong_index)
    image_index = np.delete(image_index, wrong_index)
    # Confusion matrix: rows are gt, columns are pred.
    confusionMatrix = confusion_matrix(all_gt, all_pred, labels=list(range(inferNetLabelCount)))
    print(confusionMatrix)
    accuracy = np.divide(np.diag(confusionMatrix).sum(), confusionMatrix.sum())
    accuracy2 = accuracy_score(all_gt, all_pred)  # sklearn cross-check
    gt_num = np.sum(confusionMatrix, axis=1)
    pre_num = np.sum(confusionMatrix, axis=0)
    tp = np.diag(confusionMatrix)
    fp = pre_num - tp
    precision = np.divide(np.diag(confusionMatrix), pre_num)
    precision2 = precision_score(all_gt, all_pred, average=None)  # sklearn cross-check
    print(precision)
    recall = np.divide(np.diag(confusionMatrix), gt_num)
    recall2 = recall_score(all_gt, all_pred, average=None)  # sklearn cross-check
    print(recall)
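# Sanity-check sketch for the confusion-matrix identities used above, on an
# assumed 2x2 matrix (rows = gt, columns = pred):
#   cm = [[3, 1],
#         [2, 4]]
#   tp        = diag(cm)            = [3, 4]
#   precision = tp / cm.sum(axis=0) = [3/5, 4/5]
#   recall    = tp / cm.sum(axis=1) = [3/4, 4/6]
#   accuracy  = tp.sum() / cm.sum() = 7/10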
if __name__ == "__main__":
    import random
    image_num = 100
    all_gt = np.zeros(image_num)
    all_pred = np.zeros(image_num)
    image_index = np.zeros(image_num)
    Count = 3
    evaluator = Evaluator_classification(iou_thres=0)
    for i in range(image_num):
        # Random labels in [0, Count]; roughly 20% of images get an empty
        # gt file and another 20% an empty pred file, to exercise the error paths.
        a1 = random.randint(0, Count)
        b1 = random.randint(0, Count)
        temp = random.randint(0, 10)
        if temp < 2:
            a = []
            b = [{'Label': b1, 'Confidence': 1.0}]
            all_gt[i] = -1
            all_pred[i] = b1
        elif 2 <= temp < 4:
            b = []
            a = [{'Label': a1, 'Confidence': 1.0}]
            all_gt[i] = a1
            all_pred[i] = -1
        else:
            a = [{'Label': a1, 'Confidence': 1.0}]
            b = [{'Label': b1, 'Confidence': 1.0}]
            all_gt[i] = a1
            all_pred[i] = b1
        image_index[i] = i
        evaluator.add_batch(a, b, str(i))
  162. print("--------------------------------------可用图像report--------------------------------------")
  163. wrong_file = evaluator.wrong_file
  164. print("gt file有问题的image:{}".format(wrong_file['gt_wrong']))
  165. print("pred file有问题的image:{}".format(wrong_file['pred_wrong']))
  166. all_image_dict = evaluator.all_image_dict
  167. print('所有可用的图像数量:{}'.format(all_image_dict['images_all_nums']))
  168. print("--------------------------------------背景图像report--------------------------------------")
  169. background_images_results_count = evaluator.background_images_results_count
  170. for key in background_images_results_count.keys():
  171. print(key + ':' + str(background_images_results_count[key]))
  172. background_images_results = evaluator.background_images_results
  173. for key in background_images_results.keys():
  174. print(key + ':' + str(background_images_results[key]))
  175. print("--------------------------------------非背景图像report--------------------------------------")
  176. print('非背景图像数量:{}'.format(all_image_dict['images_all_nums']))
  177. metricsPerClass = evaluator.generate_metrics()
  178. for mc in metricsPerClass:
  179. c = mc['class']
  180. precision = mc['precision']
  181. recall = mc['recall']
  182. total_positives = mc['total positives']
  183. total_TP = mc['total TP']
  184. total_FP = mc['total FP']
  185. precision_all = 0 if (total_TP + total_FP) == 0 else total_TP / (total_TP + total_FP)
  186. recall_all = 0 if total_positives == 0 else total_TP / total_positives
  187. # Print AP per class
  188. print('Label:%s, total_TP: %d, total_FP: %d, total_positives_gt: %d, precision: %f, recall: %f '
  189. % (c, total_TP, total_FP, total_positives, precision_all, recall_all))
  190. try:
  191. average_precision = mc['AP']
  192. print('Label:%s, mAP: %f, ' % (c, average_precision))
  193. except:
  194. continue
    all_no_background_images_fp_results = evaluator.all_no_background_images_fp_results
    for key in all_no_background_images_fp_results.keys():
        each_result = all_no_background_images_fp_results[key]
        # De-duplicate while preserving first-appearance order.
        print('Label:' + key + ', FP images: ' + str(sorted(set(each_result), key=each_result.index)))
    all_no_background_images_pos_results = evaluator.all_no_background_images_pos_results
    all_no_background_images_tp_results = evaluator.all_no_background_images_tp_results
    for key in all_no_background_images_pos_results.keys():
        each_pos_results = all_no_background_images_pos_results[key]
        if key in all_no_background_images_tp_results.keys():
            each_tp_results = all_no_background_images_tp_results[key]
        else:
            each_tp_results = []
        if key in all_no_background_images_fp_results.keys():
            each_fp_results = all_no_background_images_fp_results[key]
        else:
            each_fp_results = []
        # FN images: ground-truth positives of this class never predicted as it.
        each_fn_results = []
        for elem in each_pos_results:
            if elem not in each_tp_results:
                each_fn_results.append(elem)
        print('Label:' + key + ', FN images: ' + str(each_fn_results))
    sklearn_result(all_gt, all_pred, image_index, Count + 1)