semantic_segmentation_metric.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
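"""
Semantic segmentation annotation format. Each image is described by one JSON
list of region dicts, for example:
[{'Label': 0, 'Confidence': 1.0, 'Image_size':[w, h], 'Contours': []}]
[{'Label': 1, 'Confidence': 1.0, 'Image_size':[w, h], 'Contours':[ [ [1,2],[3,4] ] , [ [5,6],[7,8] ] ]}]
[{'Label': 1, 'Confidence': 1.0, 'Image_size':[w, h], 'Contours':[ [ [1,2],[3,4] ] ]}]
[{'Label': 1, 'Confidence': 1.0, 'Image_size':[w, h], 'Contours':[ [ [1,2],[3,4] ] ]},{'Label': 2, 'Confidence': 1.0, 'Image_size':[w, h], 'Contours':[ [ [1,2],[3,4] ] ]}]
[]

A single Label-0 entry with an empty 'Contours' list marks a background image.
An empty list is treated as a missing/invalid annotation. Each 'Contours'
value holds at most two polygons of [x, y] points. When there are two, the
region is the symmetric difference of the two filled polygons (the second
acts as a hole inside the first).
"""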
  9. import json
  10. import sys
  11. import numpy as np
  12. import cv2
  13. from enum import Enum
  14. class MethodAveragePrecision(Enum):
  15. """
  16. Class representing if the coordinates are relative to the
  17. image size or are absolute values.
  18. """
  19. EveryPointInterpolation = 1
  20. ElevenPointInterpolation = 2
  21. class Evaluator_object_semamtic_segmentation(object):
  22. def __init__(self, iou_thres, method=MethodAveragePrecision.EveryPointInterpolation):
  23. """
  24. :param iou_thres: iou阈值
  25. :param method: 计算mAP的方法
  26. """
  27. self.method = method
  28. self.iou_thres = iou_thres
  29. self.wrong_file = {}
  30. self.background_images_all_indexs = []
  31. self.background_images_results = {}
  32. self.background_images_results_count = {}
  33. self.all_image_dict = {}
  34. self.all_no_background_images_pos_results = {}
  35. self.all_no_background_images_fp_results = {}
  36. self.all_no_background_images_tp_results = {}
  37. self.pred_results = []
  38. self.gt_results = []
  39. self.classes = []
  40. self.wrong_file['gt_wrong'] = []
  41. self.wrong_file['pred_wrong'] = []
  42. self.background_images_results_count['background_images_all_nums'] = 0
  43. self.all_image_dict['images_all_nums'] = 0
  44. def generate_metrics(self):
  45. classes = sorted(self.classes)
  46. # Precision x Recall is obtained individually by each class
  47. # Loop through by classes
  48. ret = [] # list containing metrics (precision, recall, average precision) of each class
  49. for c in classes:
  50. if c == 0:
  51. continue
  52. # Get only detection of class c
  53. dects = []
  54. [dects.append(d) for d in self.pred_results if d[1] == c]
  55. # Get only ground truths of class c, use filename as key
  56. gts = {}
  57. npos = 0
  58. for g in self.gt_results:
  59. if g[1] == c:
  60. npos += 1
  61. gts[g[0]] = gts.get(g[0], []) + [g]
  62. if str(c) not in self.all_no_background_images_pos_results.keys():
  63. self.all_no_background_images_pos_results[str(c)] = [g[0]]
  64. else:
  65. self.all_no_background_images_pos_results[str(c)].append(g[0])
  66. # sort detections by decreasing confidence
  67. dects = sorted(dects, key=lambda conf: conf[2], reverse=True)
  68. TP = np.zeros(len(dects))
  69. FP = np.zeros(len(dects))
  70. # create dictionary with amount of gts for each image
  71. det = {key: np.zeros(len(gts[key])) for key in gts}
  72. # print("Evaluating class: %s (%d detections)" % (str(c), len(dects)))
  73. # Loop through detections
  74. for d in range(len(dects)):
  75. # print('dect %s => %s' % (dects[d][0], dects[d][3],))
  76. # Find ground truth image
  77. gt = gts[dects[d][0]] if dects[d][0] in gts else []
  78. iouMax = sys.float_info.min
  79. for j in range(len(gt)):
  80. # print('Ground truth gt => %s' % (gt[j][3],))
  81. assert dects[d][4] == gt[j][4]
  82. image_size = dects[d][4]
  83. iou = Evaluator_object_semamtic_segmentation.iou(dects[d][3], gt[j][3], image_size)
  84. if iou > iouMax:
  85. iouMax = iou
  86. jmax = j
  87. # Assign detection as true positive/don't care/false positive
  88. if iouMax >= self.iou_thres:
  89. if det[dects[d][0]][jmax] == 0:
  90. TP[d] = 1 # count as true positive
  91. det[dects[d][0]][jmax] = 1 # flag as already 'seen'
  92. # print("TP")
  93. if str(c) not in self.all_no_background_images_tp_results.keys():
  94. self.all_no_background_images_tp_results[str(c)] = [dects[d][0]]
  95. else:
  96. self.all_no_background_images_tp_results[str(c)].append(dects[d][0])
  97. else:
  98. FP[d] = 1 # count as false positive
  99. if str(c) not in self.all_no_background_images_fp_results.keys():
  100. self.all_no_background_images_fp_results[str(c)] = [dects[d][0]]
  101. else:
  102. self.all_no_background_images_fp_results[str(c)].append(dects[d][0])
  103. # print("FP")
  104. # - A detected "cat" is overlaped with a GT "cat" with IOU >= IOUThreshold.
  105. else:
  106. FP[d] = 1 # count as false positive
  107. if str(c) not in self.all_no_background_images_fp_results.keys():
  108. self.all_no_background_images_fp_results[str(c)] = [dects[d][0]]
  109. else:
  110. self.all_no_background_images_fp_results[str(c)].append(dects[d][0])
  111. # print("FP")
  112. # compute precision, recall and average precision
  113. acc_FP = np.cumsum(FP)
  114. acc_TP = np.cumsum(TP)
  115. rec = acc_TP / npos
  116. prec = np.divide(acc_TP, (acc_FP + acc_TP))
  117. # Depending on the method, call the right implementation
  118. if self.method == MethodAveragePrecision.EveryPointInterpolation:
  119. [ap, mpre, mrec, ii] = Evaluator_object_semamtic_segmentation.CalculateAveragePrecision(rec, prec)
  120. else:
  121. [ap, mpre, mrec, _] = Evaluator_object_semamtic_segmentation.ElevenPointInterpolatedAP(rec, prec)
  122. # add class result in the dictionary to be returned
  123. r = {
  124. 'class': c,
  125. 'precision': prec,
  126. 'recall': rec,
  127. 'AP': ap,
  128. 'interpolated precision': mpre,
  129. 'interpolated recall': mrec,
  130. 'total positives': npos,
  131. 'total TP': np.sum(TP),
  132. 'total FP': np.sum(FP)
  133. }
  134. ret.append(r)
  135. return ret
  136. def add_batch(self, gt_file, pred_file, image_index):
  137. if gt_file == []:
  138. self.wrong_file['gt_wrong'].append(image_index)
  139. elif pred_file == []:
  140. self.wrong_file['pred_wrong'].append(image_index)
  141. elif gt_file != [] and pred_file != []:
  142. # 判断gt为背景的图像:
  143. self.all_image_dict['images_all_nums'] += 1
  144. if len(gt_file) == 1 and gt_file[0]['Label'] == 0 and gt_file[0]['Contours'] == []:
  145. self.background_images_results_count['background_images_all_nums'] += 1
  146. if len(pred_file) == 1 and pred_file[0]['Label'] == 0 and gt_file[0]['Contours'] == []:
  147. # if 'gtlabel_0_predlabel_{}'.format(0) not in self.background_images_results.keys():
  148. # self.background_images_results['gtlabel_0_predlabel_{}'.format(0)] = [image_index]
  149. # else:
  150. # self.background_images_results['gtlabel_0_predlabel_{}'.format(0)].append(image_index)
  151. if 'gtlabel_0_predlabel_{}'.format(0) not in self.background_images_results_count.keys():
  152. self.background_images_results_count['gtlabel_0_predlabel_{}'.format(0)] = 1
  153. else:
  154. self.background_images_results_count['gtlabel_0_predlabel_{}'.format(0)] += 1
  155. else:
  156. for i in range(len(pred_file)):
  157. pred_label = pred_file[i]['Label']
  158. if 'gtlabel_0_predlabel_{}'.format(pred_label) not in self.background_images_results.keys():
  159. self.background_images_results['gtlabel_0_predlabel_{}'.format(pred_label)] = [image_index]
  160. else:
  161. self.background_images_results['gtlabel_0_predlabel_{}'.format(pred_label)].append(
  162. image_index)
  163. if 'gtlabel_0_predlabel_{}'.format(
  164. pred_label) not in self.background_images_results_count.keys():
  165. self.background_images_results_count['gtlabel_0_predlabel_{}'.format(pred_label)] = 1
  166. else:
  167. self.background_images_results_count['gtlabel_0_predlabel_{}'.format(pred_label)] += 1
  168. else:
  169. gt_len = len(gt_file)
  170. pred_len = len(pred_file)
  171. for i in range(gt_len):
  172. each_gt_label = gt_file[i]['Label']
  173. each_gt_counters = gt_file[i]['Contours']
  174. each_image_size = gt_file[i]['Image_size']
  175. self.gt_results.append([image_index, each_gt_label, 1.0, (each_gt_counters), each_image_size])
  176. if each_gt_label not in self.classes:
  177. self.classes.append(each_gt_label)
  178. for i in range(pred_len):
  179. each_pred_label = pred_file[i]['Label']
  180. each_pred_confidence = pred_file[i]['Confidence']
  181. each_pred_counters = pred_file[i]['Contours']
  182. each_pred_image_size = pred_file[i]['Image_size']
  183. self.pred_results.append([image_index, each_pred_label, each_pred_confidence, (each_pred_counters),
  184. each_pred_image_size])
  185. if each_pred_label not in self.classes:
  186. self.classes.append(each_pred_label)
  187. @staticmethod
  188. def iou(contoursA, contoursB, image_size):
  189. pred_mask = Evaluator_object_semamtic_segmentation.contourtomask(contoursA, image_size)
  190. gt_mask = Evaluator_object_semamtic_segmentation.contourtomask(contoursB, image_size)
  191. iou = Evaluator_object_semamtic_segmentation.iou_score(pred_mask, gt_mask)
  192. assert iou >= 0
  193. return iou
  194. @staticmethod
  195. def iou_score(output, target):
  196. smooth = 1e-5
  197. intersection = (output & target).sum()
  198. union = (output | target).sum()
  199. return (intersection + smooth) / (union + smooth)
  200. @staticmethod
  201. def contourtomask(contours, image_size):
  202. mask = np.zeros((image_size[1], image_size[0]), dtype=np.int32)
  203. if len(contours) == 1:
  204. contours_cv = np.zeros((len(contours[0]), 1, 2), dtype=np.int32)
  205. for i in range(len(contours[0])):
  206. contours_cv[i] = [int(contours[0][i][0]), int(contours[0][i][1])]
  207. mask = cv2.drawContours(mask.copy(), [contours_cv], -1, 1, cv2.FILLED)
  208. # 暂时一个contours中,如果存放了两条轮廓,则必须是去差值
  209. # 后续有超过两条的情况,需要再添加相关代码
  210. if len(contours) == 2:
  211. contours_cv_1 = np.zeros((len(contours[0]), 1, 2), dtype=np.int32)
  212. for i in range(len(contours[0])):
  213. contours_cv_1[i] = [int(contours[0][i][0]), int(contours[0][i][1])]
  214. mask_1 = cv2.drawContours(mask.copy(), [contours_cv_1], -1, 1, cv2.FILLED)
  215. contours_cv_2 = np.zeros((len(contours[1]), 1, 2), dtype=np.int32)
  216. for i in range(len(contours[1])):
  217. contours_cv_2[i] = [int(contours[1][i][0]), int(contours[1][i][1])]
  218. mask_2 = cv2.drawContours(mask.copy(), [contours_cv_2], -1, 1, cv2.FILLED)
  219. mask = abs(mask_2 - mask_1)
  220. if len(contours) >= 3:
  221. raise IOError('Error: contours: %s could not be process.' % contours)
  222. return mask
  223. @staticmethod
  224. def CalculateAveragePrecision(rec, prec):
  225. mrec = []
  226. mrec.append(0)
  227. [mrec.append(e) for e in rec]
  228. mrec.append(1)
  229. mpre = []
  230. mpre.append(0)
  231. [mpre.append(e) for e in prec]
  232. mpre.append(0)
  233. for i in range(len(mpre) - 1, 0, -1):
  234. mpre[i - 1] = max(mpre[i - 1], mpre[i])
  235. ii = []
  236. for i in range(len(mrec) - 1):
  237. if mrec[1 + i] != mrec[i]:
  238. ii.append(i + 1)
  239. ap = 0
  240. for i in ii:
  241. ap = ap + np.sum((mrec[i] - mrec[i - 1]) * mpre[i])
  242. # return [ap, mpre[1:len(mpre)-1], mrec[1:len(mpre)-1], ii]
  243. return [ap, mpre[0:len(mpre) - 1], mrec[0:len(mpre) - 1], ii]
  244. @staticmethod
  245. # 11-point interpolated average precision
  246. def ElevenPointInterpolatedAP(rec, prec):
  247. # def CalculateAveragePrecision2(rec, prec):
  248. mrec = []
  249. # mrec.append(0)
  250. [mrec.append(e) for e in rec]
  251. # mrec.append(1)
  252. mpre = []
  253. # mpre.append(0)
  254. [mpre.append(e) for e in prec]
  255. # mpre.append(0)
  256. recallValues = np.linspace(0, 1, 11)
  257. recallValues = list(recallValues[::-1])
  258. rhoInterp = []
  259. recallValid = []
  260. # For each recallValues (0, 0.1, 0.2, ... , 1)
  261. for r in recallValues:
  262. # Obtain all recall values higher or equal than r
  263. argGreaterRecalls = np.argwhere(mrec[:] >= r)
  264. pmax = 0
  265. # If there are recalls above r
  266. if argGreaterRecalls.size != 0:
  267. pmax = max(mpre[argGreaterRecalls.min():])
  268. recallValid.append(r)
  269. rhoInterp.append(pmax)
  270. # By definition AP = sum(max(precision whose recall is above r))/11
  271. ap = sum(rhoInterp) / 11
  272. # Generating values for the plot
  273. rvals = []
  274. rvals.append(recallValid[0])
  275. [rvals.append(e) for e in recallValid]
  276. rvals.append(0)
  277. pvals = []
  278. pvals.append(0)
  279. [pvals.append(e) for e in rhoInterp]
  280. pvals.append(0)
  281. # rhoInterp = rhoInterp[::-1]
  282. cc = []
  283. for i in range(len(rvals)):
  284. p = (rvals[i], pvals[i - 1])
  285. if p not in cc:
  286. cc.append(p)
  287. p = (rvals[i], pvals[i])
  288. if p not in cc:
  289. cc.append(p)
  290. recallValues = [i[0] for i in cc]
  291. rhoInterp = [i[1] for i in cc]
  292. return [ap, rhoInterp, recallValues, None]
  293. if __name__ == "__main__":
  294. import os
  295. import glob
  296. currentPath = os.path.dirname(os.path.abspath(__file__))
  297. path_pred = os.path.join(currentPath, 'segment_test\\pred\\')
  298. path_gt = os.path.join(currentPath, 'segment_test\\gt\\')
  299. files = glob.glob(path_gt + "*.txt")
  300. files.sort()
  301. evaluator = Evaluator_object_semamtic_segmentation(iou_thres=0.5)
  302. for file in files:
  303. with open(file, "r", encoding='utf-8') as gtf:
  304. gt_anno = gtf.readlines()
  305. assert len(gt_anno) == 1, '每个txt对应一个json'
  306. gt_file = json.loads(gt_anno[0])
  307. pred_files_name = file.split('\\')[-1]
  308. pred_files = os.path.join(path_pred, pred_files_name)
  309. with open(pred_files, "r", encoding='utf-8') as predf:
  310. pred_anno = predf.readlines()
  311. assert len(pred_anno) == 1, '每个txt对应一个json'
  312. pred_file = json.loads(pred_anno[0])
  313. pred_len = len(pred_file)
  314. for i in range(pred_len):
  315. each_pred_label = pred_file[i]['Label']
  316. image_index = str(file.split('\\')[-1].split('.txt')[0])
  317. evaluator.add_batch(gt_file, pred_file, image_index)
  318. wrong_file = evaluator.wrong_file
  319. print(wrong_file)
  320. background_images_results = evaluator.background_images_results
  321. print(background_images_results)
  322. all_image_dict = evaluator.all_image_dict
  323. print(all_image_dict)
  324. metricsPerClass = evaluator.generate_metrics()
  325. for mc in metricsPerClass:
  326. c = mc['class']
  327. precision = mc['precision']
  328. recall = mc['recall']
  329. # Print AP per class
  330. # precision_all1 = precision[-1]
  331. # precision_all = mc['total TP'] / (mc['total TP'] + mc['total FP'])
  332. precision_all2 = 0 if (mc['total TP'] + mc['total FP']) == 0 else np.divide(mc['total TP'],
  333. (mc['total TP'] + mc['total FP']))
  334. # recall_all1 = recall[-1]
  335. # recall_all = mc['total TP'] / mc['total positives']
  336. recall_all2 = 0 if mc['total positives'] == 0 else np.divide(mc['total TP'], mc['total positives'])
  337. print(c)
  338. print(precision_all2)
  339. print(recall_all2)
  340. print(mc['total TP'])
  341. print(mc['total FP'])
  342. print(mc['total positives'])
  343. print('------------------------------')