# compute_test_set_aps.py

import sys
import onnx
import os
import argparse
import numpy as np
import cv2
import onnxruntime
import multiprocessing
from tool.utils import *
from tool.darknet2onnx import *
import glob
onnx_path = 'C:\\Users\\VINNO\\Desktop\\新建文件夹 (2)\\pytorch-YOLOv4-master\\20210824_yolov4_1_224_224_static.onnx'
OUTPUT_PATH = 'C:\\Users\\VINNO\\Desktop\\新建文件夹 (2)\\pytorch-YOLOv4-master\\output\\'

sess_options = onnxruntime.SessionOptions()
#sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
# controls the number of threads used to run the model
#sess_options.intra_op_num_threads = 1
# When sess_options.execution_mode = rt.ExecutionMode.ORT_PARALLEL,
# you can set sess_options.inter_op_num_threads to control the number of threads used to parallelize the execution of the graph (across nodes).
#sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_PARALLEL
#sess_options.inter_op_num_threads = 1
session = onnxruntime.InferenceSession(onnx_path, sess_options)
IN_IMAGE_H = session.get_inputs()[0].shape[2]
IN_IMAGE_W = session.get_inputs()[0].shape[3]
input_name = session.get_inputs()[0].name
# class_id_map = None
class_id_map = [None] * 8
class_id_map[1] = 1
class_id_map[2] = 1
class_id_map[3] = 1
class_id_map[4] = 1
class_id_map[5] = 1
class_id_map[6] = 1
class_id_map[7] = 1

class_list = [None] * 8
class_list[0] = '__background__'
class_list[1] = 'lipomyoma'
class_list[2] = 'BIRADS2'
class_list[3] = 'BIRADS3'
class_list[4] = 'BIRADS4a'
class_list[5] = 'BIRADS4b'
class_list[6] = 'BIRADS4c'
class_list[7] = 'BIRADS5'
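# Note: class_id_map above maps every foreground class (lipomyoma and all BI-RADS categories)
# to group 1, so the evaluation below is effectively lesion-vs-background. Setting
# class_id_map = None (see the commented line above) keeps the original per-class evaluation.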
# Compute mAP on the test set, draw ROC curves, etc.; all evaluation curves are saved in the Output\\TestCurves folder
def evaluate_detections(c, boxes, gtinfo, dstpath, ovthresh=0.5):
    """
    Compute the evaluation metrics.
    precision = tp / (tp + fp): precision is normalized by the number of predicted boxes
    recall = tp / (tp + fn): recall is normalized by the number of ground-truth boxes
    :param c: index of the current class in class_list
    :param boxes: all predicted boxes of this class.
        Type: list whose length is the number of test images; each item is a flat sub-list of length k*5,
        where k is the number of predicted boxes on that image and each box is stored as
        'left, top, right, bottom, confidence'. If no box was predicted on an image, the sub-list is empty.
    :param gtinfo: all ground-truth boxes of this class.
        Type: list whose length is the number of test images; each item is a flat sub-list of length n*4,
        where n is the number of ground-truth boxes on that image and each box is stored as
        'left, top, right, bottom'. If an image has no ground-truth box, the sub-list is empty.
    :param dstpath: output directory (the plotted ROC curves etc. are saved there)
    :param ovthresh: IoU threshold; a prediction counts as a TP only if its IoU with a ground-truth box exceeds this value
    :return: ROC curve, tp/fp curves, etc.
    """
    assert len(boxes) == len(gtinfo), "the length of pred boxes and gtinfos should be the same"
    assert ovthresh > 0, "the overlaps between gt boxes and pred boxes should be greater than 0."
    # detBboxes rows: [image index, left, top, right, bottom, confidence, tp flag]
    # gtBboxes rows:  [image index, left, top, right, bottom, matched flag]
    detBboxes = np.zeros((0, 7), dtype=np.float32)
    gtBboxes = np.zeros((0, 6), dtype=np.float32)
    for reading_ind in range(len(boxes)):
        preds = boxes[reading_ind]
        gts = gtinfo[reading_ind]
        assert len(preds) % 5 == 0 and len(gts) % 4 == 0, \
            "incorrect inputs for the pred boxes or gt boxes of image:{} class:{}".format(reading_ind, c)
        num_preds = int(len(preds) / 5)
        num_gts = int(len(gts) / 4)
        if num_preds > 0:
            reshaped_preds = np.reshape(preds, (num_preds, 5))
            imginds = reading_ind * np.ones((num_preds, 1), dtype=np.float32)
            tpsign = np.zeros((num_preds, 1), dtype=np.float32)
            detBboxes = np.vstack((detBboxes, np.hstack((imginds, reshaped_preds, tpsign))))
        if num_gts > 0:
            reshaped_gts = np.reshape(gts, (num_gts, 4))
            imginds = reading_ind * np.ones((num_gts, 1), dtype=np.float32)
            tpsign = np.zeros((num_gts, 1), dtype=np.float32)
            gtBboxes = np.vstack((gtBboxes, np.hstack((imginds, reshaped_gts, tpsign))))
    # Then count the TPs and FPs
    nd = len(detBboxes)
    ngt = len(gtBboxes)
    if nd > 0:
        # sort detections by confidence (descending)
        sorted_ind = np.argsort(-detBboxes[:, -2])
        detBboxes = detBboxes[sorted_ind, :]
        for d in range(nd):
            detbox = detBboxes[d, :]
            imgind = detbox[0]
            gtind = np.where(gtBboxes[:, 0] == imgind)[0]
            if len(gtind) > 0:
                gts = gtBboxes[gtind, :]
                # intersection area
                ix1 = np.maximum(gts[:, 1], detbox[1])
                iy1 = np.maximum(gts[:, 2], detbox[2])
                ix2 = np.minimum(gts[:, 3], detbox[3])
                iy2 = np.minimum(gts[:, 4], detbox[4])
                iw = np.maximum(ix2 - ix1 + 1., 0.)
                ih = np.maximum(iy2 - iy1 + 1., 0.)
                inters = iw * ih
                # union area
                uni = ((detbox[3] - detbox[1] + 1.) * (detbox[4] - detbox[2] + 1.) +
                       (gts[:, 3] - gts[:, 1] + 1.) * (gts[:, 4] - gts[:, 2] + 1.) - inters)
                overlaps = inters / uni
                ovmax = np.max(overlaps)
                jmax = np.argmax(overlaps)
                if ovmax > ovthresh:
                    # count the detection as a TP only if the matched gt box has not been claimed yet
                    if gts[jmax, -1] == 0:
                        gtBboxes[gtind[jmax], -1] = 1
                        detBboxes[d, -1] = 1
    # Compute the summary metrics
    prec_val = np.sum(detBboxes[:, -1] == 1) / np.maximum(float(nd), np.finfo(np.float32).eps)
    print('prec_val tp:{}'.format(np.sum(detBboxes[:, -1] == 1)))
    rec_val = np.sum(gtBboxes[:, -1] == 1) / np.maximum(float(ngt), np.finfo(np.float32).eps)
    print('rec_val tp:{}'.format(np.sum(gtBboxes[:, -1] == 1)))
    F1 = 2 * prec_val * rec_val / np.maximum(prec_val + rec_val, np.finfo(np.float32).eps)
    # false positives (predicted positive, actually negative), reported as the number of false detections
    temp_fp = np.maximum(float(nd), np.finfo(np.float32).eps) - np.sum(detBboxes[:, -1] == 1)
    # Build the P/R curve and compute the AP
    tp = detBboxes[:, -1]
    fp = 1 - tp
    fp = np.cumsum(fp)
    tp = np.cumsum(tp)
    recalls = tp / np.maximum(float(ngt), np.finfo(np.float32).eps)
    precisions = tp / np.maximum(tp + fp, np.finfo(np.float32).eps)
    mrecalls = np.concatenate(([0.], recalls, [1.]))
    mprecisions = np.concatenate(([0.], precisions, [0.]))
    # compute the precision envelope
    for i in range(mprecisions.size - 1, 0, -1):
        mprecisions[i - 1] = np.maximum(mprecisions[i - 1], mprecisions[i])
    # to calculate area under PR curve, look for points
    # where X axis (recall) changes value
    i = np.where(mrecalls[1:] != mrecalls[:-1])[0]
    # and sum (\Delta recall) * prec
    ap = np.sum((mrecalls[i + 1] - mrecalls[i]) * mprecisions[i + 1])
    pr_x = np.concatenate(([mrecalls[0]], mrecalls[i + 1]))
    pr_y = np.concatenate(([mprecisions[0]], mprecisions[i + 1]))
    # Build the ROC curve and compute the AUC
    fpr = fp / np.maximum(float(nd), np.finfo(np.float32).eps)
    tpr = tp / np.maximum(float(nd), np.finfo(np.float32).eps)
    mfpr = np.concatenate(([0.], fpr, [1.]))
    mtpr = np.concatenate(([0.], tpr, [1.]))
    i = np.where(mfpr[1:] != mfpr[:-1])[0]
    auc = np.sum((mfpr[i + 1] - mfpr[i]) * mtpr[i + 1])
    roc_x = np.concatenate(([mfpr[0]], mfpr[i + 1]))
    roc_y = np.concatenate(([mtpr[0]], mtpr[i + 1]))
    print('class:{} prec:{:.3f} recall:{:.3f} F1:{:.3f} mAP:{:.3f} AUC:{:.3f} temp_fp:{}\n'.format(
        c, prec_val, rec_val, F1, ap, auc, temp_fp))
    resultoutpath = os.path.join(OUTPUT_PATH, 'predict.txt')
    resulttxt = open(resultoutpath, "a")
    resulttxt.write('class:{} prec:{:.3f} recall:{:.3f} F1:{:.3f} mAP:{:.3f} AUC:{:.3f} temp_fp:{}\n'.format(
        c, prec_val, rec_val, F1, ap, auc, temp_fp))
    resulttxt.close()
    return
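# A minimal sanity-check sketch for evaluate_detections (illustrative values, not part of the
# original test set): with two test images, one correct detection on image 0 and none on image 1,
#
#     toy_boxes = [[10, 10, 50, 50, 0.9], []]                 # flat 'l, t, r, b, conf' per image
#     toy_gts = [[12, 12, 48, 48], [100, 100, 150, 150]]      # flat 'l, t, r, b' per image
#     evaluate_detections(1, toy_boxes, toy_gts, OUTPUT_PATH, ovthresh=0.5)
#
# the single prediction overlaps its gt with IoU ~0.81 > 0.5, so prec = 1/1, recall = 1/2,
# F1 ~0.67 and AP = 0.5 (half of the recall axis is covered at precision 1).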
# Output path
outputpath = os.path.join(OUTPUT_PATH, 'TestCurves')
if not os.path.isdir(outputpath):
    os.makedirs(outputpath)

# Number of classes including the background: 7 + 1
NUM_CLASSES = 8
TEST_IOU_FILT_TH = 0.5
BACKGROUND_LABEL_ID = 0
TEST_KEEP_TOPK = 25
INPUT_ROIS_PER_IMAGE = 12
file_path = 'E:\\20210823_AnnotatedBreastDatas\\yolo_dataset\\test\\test\\'
f_names = glob.glob(file_path + '*.jpg')
num_test = len(f_names)
conf_thresh = 0.4
nms_thresh = 0.6
# all_boxes: stores all predicted boxes
# all_gt_infos: stores all ground-truth boxes
all_boxes = [[[] for _ in range(num_test)] for _ in range(NUM_CLASSES)]
all_gt_infos = [[[] for _ in range(num_test)] for _ in range(NUM_CLASSES)]
for i in range(len(f_names)):
    # if f_names[i] != '01BB8F2539C341C79DA1D00559786324__LVYvbFEnhzAcPHR2.jpg':
    #     continue
    orig_img = cv2.imdecode(np.fromfile(f_names[i], dtype=np.uint8), 1)
    aaa = orig_img.copy()
    width = aaa.shape[1]
    height = aaa.shape[0]
    # Read the ground-truth annotations (YOLO format: class cx cy w h, normalized to [0, 1])
    annotations = np.zeros((INPUT_ROIS_PER_IMAGE, 5), dtype=np.float32)
    gt_file = os.path.join(file_path, "{}.txt".format(f_names[i].split("\\")[-1].split(".jpg")[0]))
    with open(gt_file, "r") as gtf:
        gt_anno = gtf.readlines()
    gt_ind = 0
    for ind in range(len(gt_anno)):
        gt_info = gt_anno[ind].strip().split('\n')
        if len(gt_info[0]) > 0:
            bbox_floats = np.fromstring(gt_info[0], dtype=np.float32, sep=' ')
            label_gt = int(bbox_floats[0]) + 1
            x_yolo_gt = bbox_floats[1]
            y_yolo_gt = bbox_floats[2]
            width_yolo_gt = bbox_floats[3]
            height_yolo_gt = bbox_floats[4]
            top_gt = int((y_yolo_gt - height_yolo_gt / 2) * height)
            left_gt = int((x_yolo_gt - width_yolo_gt / 2) * width)
            right_gt = int((x_yolo_gt + width_yolo_gt / 2) * width)
            bottom_gt = int((y_yolo_gt + height_yolo_gt / 2) * height)
            annotations[gt_ind, :] = [left_gt, top_gt, right_gt, bottom_gt, label_gt]
            gt_ind += 1
    for c, _ in enumerate(class_list):
        if c == BACKGROUND_LABEL_ID:
            continue
        cls_gt_boxes = annotations[np.where(annotations[:, -1] == c)]
        gt_box_count = len(cls_gt_boxes)
        if gt_box_count > 0:
            if class_id_map is None:
                all_gt_infos[c][i] = np.reshape(cls_gt_boxes[:, 0:4], gt_box_count * 4).tolist()
            else:
                all_gt_infos[class_id_map[c]][i] = np.reshape(cls_gt_boxes[:, 0:4], gt_box_count * 4).tolist()

    resized = cv2.resize(orig_img, (IN_IMAGE_W, IN_IMAGE_H), interpolation=cv2.INTER_LINEAR)
    # img_in = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
    # img_in = np.expand_dims(img_in, axis=2)
    img_in = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    img_in = np.transpose(img_in, (2, 0, 1)).astype(np.float32)
    img_in = np.expand_dims(img_in, axis=0)
    img_in /= 255.0
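    # At this point img_in is the network input: NCHW, RGB, float32 in [0, 1],
    # with shape (1, 3, IN_IMAGE_H, IN_IMAGE_W) to match the static ONNX graph.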
    outputs = session.run(None, {input_name: img_in})
    # [batch, num, 1, 4]
    box_array = outputs[0]
    # [batch, num, num_classes]
    confs = outputs[1]
    if type(box_array).__name__ != 'ndarray':
        box_array = box_array.cpu().detach().numpy()
        confs = confs.cpu().detach().numpy()
    num_classes = confs.shape[2]
    # [batch, num, 4]
    box_array = box_array[:, :, 0]
    # [batch, num, num_classes] --> [batch, num]
    max_conf = np.max(confs, axis=2)
    max_id = np.argmax(confs, axis=2)
    assert box_array.shape[0] == 1, "a single image is predicted at a time, so the batch size must be 1"
    argwhere = max_conf[0] > conf_thresh
    l_box_array = box_array[0, argwhere, :]
    l_max_conf = max_conf[0, argwhere]
    l_max_id = max_id[0, argwhere]
    bboxes = []
    # nms for each class
    for j in range(num_classes):
        cls_argwhere = l_max_id == j
        ll_box_array = l_box_array[cls_argwhere, :]
        ll_max_conf = l_max_conf[cls_argwhere]
        ll_max_id = l_max_id[cls_argwhere]
        keep = nms_cpu(ll_box_array, ll_max_conf, nms_thresh)
        if keep.size > 0:
            ll_box_array = ll_box_array[keep, :]
            ll_max_conf = ll_max_conf[keep]
            ll_max_id = ll_max_id[keep]
            for k in range(ll_box_array.shape[0]):
                bboxes.append(
                    [
                        ll_max_id[k] + 1,
                        ll_max_conf[k],
                        int(ll_box_array[k, 0] * width),
                        int(ll_box_array[k, 1] * height),
                        int(ll_box_array[k, 2] * width),
                        int(ll_box_array[k, 3] * height),
                    ])
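    # Each entry of bboxes is [class id (shifted to 1-based so 0 stays the background),
    # confidence, left, top, right, bottom], with the normalized model coordinates scaled back to pixels.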
    pred_annos = np.zeros((TEST_KEEP_TOPK, 6))
    valid_box_ind = 0
    if len(bboxes) == 0:
        # debug marker: no box above conf_thresh survived NMS on this image
        print(1)
    # keep at most TEST_KEEP_TOPK boxes so pred_annos cannot overflow
    for t in range(min(len(bboxes), TEST_KEEP_TOPK)):
        pred_annos[valid_box_ind, :] = bboxes[t]
        valid_box_ind += 1
    for ind in range(pred_annos.shape[0]):
        c = int(pred_annos[ind][0])
        if c <= 0:
            continue
        score = pred_annos[ind][1]
        box_left = int(pred_annos[ind][2])
        box_top = int(pred_annos[ind][3])
        box_right = int(pred_annos[ind][4])
        box_bottom = int(pred_annos[ind][5])
        if box_right > box_left and box_bottom > box_top:
            if class_id_map is None:
                all_boxes[c][i].append(box_left)
                all_boxes[c][i].append(box_top)
                all_boxes[c][i].append(box_right)
                all_boxes[c][i].append(box_bottom)
                all_boxes[c][i].append(score)
            else:
                all_boxes[class_id_map[c]][i].append(box_left)
                all_boxes[class_id_map[c]][i].append(box_top)
                all_boxes[class_id_map[c]][i].append(box_right)
                all_boxes[class_id_map[c]][i].append(box_bottom)
                all_boxes[class_id_map[c]][i].append(score)
# Once every image has been processed, run the evaluation
evaluated_class_groups = []
for c, _ in enumerate(class_list):
    if c != BACKGROUND_LABEL_ID:
        if class_id_map is None:
            evaluate_detections(c, all_boxes[c], all_gt_infos[c], outputpath, ovthresh=TEST_IOU_FILT_TH)
        else:
            # with a class_id_map, each merged group is evaluated only once
            g = class_id_map[c]
            if g not in evaluated_class_groups:
                evaluate_detections(g, all_boxes[g], all_gt_infos[g], outputpath, ovthresh=TEST_IOU_FILT_TH)
                evaluated_class_groups.append(g)