# eval_normal_breast.py
import argparse
import glob
import multiprocessing
import os
import sys

import onnx
import numpy as np
import cv2
import onnxruntime
import matplotlib.pyplot as plt

from tool.utils import *
from tool.darknet2onnx import *
  13. onnx_path = 'C:\\Users\\VINNO\\Desktop\\pytorch-YOLOv4-master\\yolov3_1_3_256_256_static.onnx'
  14. output_path = 'C:\\Users\\VINNO\\Desktop\\pytorch-YOLOv4-master\\output'
  15. if not os.path.isdir(output_path):
  16. os.makedirs(output_path)
  17. sess_options = onnxruntime.SessionOptions()
  18. #sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
  19. #控制用于运行模型的线程数 controls the number of threads to use to run the model
  20. sess_options.intra_op_num_threads = 1
  21. #When sess_options.execution_mode = rt.ExecutionMode.ORT_PARALLEL,
  22. # you can set sess_options.inter_op_num_threads to control the number of threads used to parallelize the execution of the graph (across nodes).
  23. sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_PARALLEL
  24. sess_options.inter_op_num_threads = 1
  25. session = onnxruntime.InferenceSession(onnx_path, sess_options)
  26. IN_IMAGE_H = session.get_inputs()[0].shape[2]
  27. IN_IMAGE_W = session.get_inputs()[0].shape[3]
  28. input_name = session.get_inputs()[0].name
  29. class_list = [None] * 8
  30. class_list[0] = '__background__'
  31. class_list[1] = 'lipomyoma'
  32. class_list[2] = 'BIRADS2'
  33. class_list[3] = 'BIRADS3'
  34. class_list[4] = 'BIRADS4a'
  35. class_list[5] = 'BIRADS4b'
  36. class_list[6] = 'BIRADS4c'
  37. class_list[7] = 'BIRADS5'
  38. #类别包括背景 7+1
  39. NUM_CLASSES = 8
  40. TEST_IOU_FILT_TH = 0.5
  41. BACKGROUND_LABEL_ID = 0
  42. TEST_KEEP_TOPK = 25
  43. INPUT_ROIS_PER_IMAGE = 12
  44. file_path = 'C:\\Users\\VINNO\\Desktop\\pytorch-YOLOv4-master\\test\\'
  45. f_names = glob.glob(file_path + '*.jpg')
  46. num_test = len(f_names)
  47. # all_boxes:保存所有的预测值
  48. # all_gt_infos: 保存所有的gt值
  49. all_boxes = [[[] for _ in range(num_test)] for _ in range(NUM_CLASSES)]
  50. all_gt_infos = [[[] for _ in range(num_test)] for _ in range(NUM_CLASSES)]
  51. def load_class_names(namesfile):
  52. class_names = []
  53. with open(namesfile, 'r') as fp:
  54. lines = fp.readlines()
  55. for line in lines:
  56. line = line.rstrip()
  57. class_names.append(line)
  58. return class_names
  59. namesfile = 'breast.names'
  60. class_names = load_class_names(namesfile)
  61. count_sum = 0
  62. count_wrong = 0
  63. for i in range(len(f_names)):
  64. orig_img = cv2.imdecode(np.fromfile(f_names[i], dtype=np.uint8), 1)
  65. img = orig_img.copy()
  66. width = img.shape[1]
  67. height = img.shape[0]
  68. # if f_names[i].split("\\")[-1].split(".jpg")[0] != '01BEC2230D944DE38D853112F0A9C30D__sYZNUrWTex3SkH86':
  69. # continue
  70. # 读入gt标注
  71. annotations = np.zeros((INPUT_ROIS_PER_IMAGE, 5), dtype=np.int32)
  72. gt_file = os.path.join(file_path, "{}.txt".format(f_names[i].split("\\")[-1].split(".jpg")[0]))
  73. with open(gt_file, "r") as gtf:
  74. gt_anno = gtf.readlines()
  75. gt_ind = 0
  76. for ind in range(len(gt_anno)):
  77. gt_info = gt_anno[ind].strip().split('\n')
  78. if len(gt_info[0]) > 0:
  79. bbox_floats = np.fromstring(gt_info[0], dtype=np.float32, sep=' ')
  80. label_gt= int(bbox_floats[0]) + 1
  81. x_yolo_gt = bbox_floats[1]
  82. y_yolo_gt = bbox_floats[2]
  83. width_yolo_gt = bbox_floats[3]
  84. height_yolo_gt = bbox_floats[4]
  85. top_gt = int((y_yolo_gt - height_yolo_gt / 2) * height)
  86. left_gt = int((x_yolo_gt - width_yolo_gt / 2) * width)
  87. right_gt = int((x_yolo_gt + width_yolo_gt / 2) * width)
  88. bottom_gt = int((y_yolo_gt + height_yolo_gt / 2) * height)
  89. annotations[gt_ind, :] = [left_gt, top_gt, right_gt, bottom_gt, label_gt]
  90. gt_ind += 1
  91. gtf.close()
  92. if annotations[0][4] != 0:
  93. continue
  94. count_sum += 1
  95. # for t in range(len(annotations[0])):
  96. # if annotations[t][4] != 0:
  97. # img = cv2.putText(img, class_names[annotations[t][4] - 1], (annotations[t][0] + 5, annotations[t][1] + 5), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0,0,255), 2, cv2.LINE_AA)
  98. # img = cv2.rectangle(img, (annotations[t][0], annotations[t][1]), (annotations[t][2], annotations[t][3]), (0,0,255), 2)
  99. #cv2.imwrite(os.path.join(output_path , f_names[i].split("\\")[-1].split(".jpg")[0] + '_gt.jpg'), img)
  100. resized = cv2.resize(orig_img, (IN_IMAGE_W, IN_IMAGE_H), interpolation=cv2.INTER_LINEAR)
  101. img_in = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
  102. img_in = np.expand_dims(img_in, axis=2)
  103. # img_in = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
  104. img_in = np.transpose(img_in, (2, 0, 1)).astype(np.float32)
  105. img_in = np.expand_dims(img_in, axis=0)
  106. img_in /= 255.0
  107. outputs = session.run(None, {input_name: img_in})
  108. conf_thresh = 0.4
  109. nms_thresh = 0.6
  110. # [batch, num, 1, 4]
  111. box_array = outputs[0]
  112. # [batch, num, num_classes]
  113. confs = outputs[1]
  114. if type(box_array).__name__ != 'ndarray':
  115. box_array = box_array.cpu().detach().numpy()
  116. confs = confs.cpu().detach().numpy()
  117. num_classes = confs.shape[2]
  118. # [batch, num, 4]
  119. box_array = box_array[:, :, 0]
  120. # [batch, num, num_classes] --> [batch, num]
  121. max_conf = np.max(confs, axis=2)
  122. max_id = np.argmax(confs, axis=2)
  123. assert box_array.shape[0] == 1,"每次单幅图像预测,batch为1"
  124. argwhere = max_conf[0] > conf_thresh
  125. l_box_array = box_array[0, argwhere, :]
  126. l_max_conf = max_conf[0, argwhere]
  127. l_max_id = max_id[0, argwhere]
  128. bboxes = []
  129. # nms for each class
  130. for j in range(num_classes):
  131. cls_argwhere = l_max_id == j
  132. ll_box_array = l_box_array[cls_argwhere, :]
  133. ll_max_conf = l_max_conf[cls_argwhere]
  134. ll_max_id = l_max_id[cls_argwhere]
  135. keep = nms_cpu(ll_box_array, ll_max_conf, nms_thresh)
  136. if (keep.size > 0):
  137. ll_box_array = ll_box_array[keep, :]
  138. ll_max_conf = ll_max_conf[keep]
  139. ll_max_id = ll_max_id[keep]
  140. for k in range(ll_box_array.shape[0]):
  141. bboxes.append(
  142. [
  143. ll_max_id[k] + 1,
  144. ll_max_conf[k],
  145. int(ll_box_array[k, 0] * width),
  146. int(ll_box_array[k, 1] * height),
  147. int(ll_box_array[k, 2] * width),
  148. int(ll_box_array[k, 3] * height),
  149. ])
  150. pred_annos = np.zeros((TEST_KEEP_TOPK, 6))
  151. valid_box_ind = 0
  152. for t in range(len(bboxes)):
  153. pred_annos[valid_box_ind, :] = bboxes[t]
  154. valid_box_ind += 1
  155. count_wrong += valid_box_ind
  156. for ind in range(pred_annos.shape[0]):
  157. c = int(pred_annos[ind][0])
  158. if c <= 0:
  159. continue
  160. score = pred_annos[ind][1]
  161. box_left = int(pred_annos[ind][2])
  162. box_top = int(pred_annos[ind][3])
  163. box_right = int(pred_annos[ind][4])
  164. box_bottom = int(pred_annos[ind][5])
  165. str_labelandscore = '{}:{}'.format(class_names[c-1], '%.3f' % score)
  166. img = cv2.putText(img, str_labelandscore, (box_right - 10, box_bottom - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0,255,0), 2, cv2.LINE_AA)
  167. img = cv2.rectangle(img, (box_left, box_top), (box_right, box_bottom), (0,255,0), 2)
  168. cv2.imwrite(os.path.join(output_path , f_names[i].split("\\")[-1].split(".jpg")[0] + '_pre.jpg'), img)
  169. print("total_num:{},wrong_num:{},误检率:{}\n".format(count_sum, count_wrong, count_wrong / count_sum))