# plot_test_set_bboxes.py
# Run a YOLOv4 ONNX model over a test-set directory and save images with
# ground-truth boxes (red) and predicted boxes (green) drawn on them.
  1. import onnxruntime
  2. from tool.utils import *
  3. import glob
  4. import cv2
  5. onnx_path = 'C:\\Users\\VINNO\\Desktop\\新建文件夹 (2)\\pytorch-YOLOv4-master\\20210824_yolov4_1_224_224_static.onnx'
  6. output_path = 'C:\\Users\\VINNO\\Desktop\\新建文件夹 (2)\\pytorch-YOLOv4-master\\output\\'
  7. if not os.path.isdir(output_path):
  8. os.makedirs(output_path)
  9. sess_options = onnxruntime.SessionOptions()
  10. #sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
  11. #控制用于运行模型的线程数 controls the number of threads to use to run the model
  12. # sess_options.intra_op_num_threads = 1
  13. #When sess_options.execution_mode = rt.ExecutionMode.ORT_PARALLEL,
  14. # you can set sess_options.inter_op_num_threads to control the number of threads used to parallelize the execution of the graph (across nodes).
  15. # sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_PARALLEL
  16. # sess_options.inter_op_num_threads = 1
  17. session = onnxruntime.InferenceSession(onnx_path, sess_options)
  18. IN_IMAGE_H = session.get_inputs()[0].shape[2]
  19. IN_IMAGE_W = session.get_inputs()[0].shape[3]
  20. input_name = session.get_inputs()[0].name
  21. class_list = [None] * 8
  22. class_list[0] = '__background__'
  23. class_list[1] = 'lipomyoma'
  24. class_list[2] = 'BIRADS2'
  25. class_list[3] = 'BIRADS3'
  26. class_list[4] = 'BIRADS4a'
  27. class_list[5] = 'BIRADS4b'
  28. class_list[6] = 'BIRADS4c'
  29. class_list[7] = 'BIRADS5'
  30. #类别包括背景 7+1
  31. NUM_CLASSES = 8
  32. TEST_IOU_FILT_TH = 0.5
  33. BACKGROUND_LABEL_ID = 0
  34. TEST_KEEP_TOPK = 25
  35. INPUT_ROIS_PER_IMAGE = 12
  36. file_path = 'E:\\20210823_AnnotatedBreastDatas\\yolo_dataset\\test\\test\\'
  37. f_names = glob.glob(file_path + '*.jpg')
  38. num_test = len(f_names)
  39. def load_class_names(namesfile):
  40. class_names = []
  41. with open(namesfile, 'r') as fp:
  42. lines = fp.readlines()
  43. for line in lines:
  44. line = line.rstrip()
  45. class_names.append(line)
  46. return class_names
  47. namesfile = 'breast.names'
  48. class_names = load_class_names(namesfile)
  49. for i in range(len(f_names)):
  50. orig_img = cv2.imdecode(np.fromfile(f_names[i], dtype=np.uint8), 1)
  51. img = orig_img.copy()
  52. width = img.shape[1]
  53. height = img.shape[0]
  54. # if f_names[i].split("\\")[-1].split(".jpg")[0] != '0BF616FD4A91425B8818DBD12254108F__RBYgtFKM7uTEd0bm':
  55. # continue
  56. # 读入gt标注
  57. annotations = np.zeros((INPUT_ROIS_PER_IMAGE, 5), dtype=np.int32)
  58. gt_file = os.path.join(file_path, "{}.txt".format(f_names[i].split("\\")[-1].split(".jpg")[0]))
  59. with open(gt_file, "r") as gtf:
  60. gt_anno = gtf.readlines()
  61. gt_ind = 0
  62. for ind in range(len(gt_anno)):
  63. gt_info = gt_anno[ind].strip().split('\n')
  64. if len(gt_info[0]) > 0:
  65. bbox_floats = np.fromstring(gt_info[0], dtype=np.float32, sep=' ')
  66. label_gt= int(bbox_floats[0]) + 1
  67. x_yolo_gt = bbox_floats[1]
  68. y_yolo_gt = bbox_floats[2]
  69. width_yolo_gt = bbox_floats[3]
  70. height_yolo_gt = bbox_floats[4]
  71. top_gt = int((y_yolo_gt - height_yolo_gt / 2) * height)
  72. left_gt = int((x_yolo_gt - width_yolo_gt / 2) * width)
  73. right_gt = int((x_yolo_gt + width_yolo_gt / 2) * width)
  74. bottom_gt = int((y_yolo_gt + height_yolo_gt / 2) * height)
  75. annotations[gt_ind, :] = [left_gt, top_gt, right_gt, bottom_gt, label_gt]
  76. gt_ind += 1
  77. gtf.close()
  78. for t in range(len(annotations[0])):
  79. if annotations[t][4] != 0:
  80. img = cv2.putText(img, class_names[annotations[t][4] - 1], (annotations[t][0] + 5, annotations[t][1] + 5), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0,0,255), 2, cv2.LINE_AA)
  81. img = cv2.rectangle(img, (annotations[t][0], annotations[t][1]), (annotations[t][2], annotations[t][3]), (0,0,255), 2)
  82. # cv.imencode('.jpg', img.astype(np.uint8))[1].tofile(os.path.join(output_path, f_names[i].split("\\")[-1].split(".jpg")[0] + '_gt.jpg'))
  83. resized = cv2.resize(orig_img, (IN_IMAGE_W, IN_IMAGE_H), interpolation=cv2.INTER_LINEAR)
  84. img_in = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
  85. img_in = np.expand_dims(img_in, axis=2)
  86. # img_in = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
  87. img_in = np.transpose(img_in, (2, 0, 1)).astype(np.float32)
  88. img_in = np.expand_dims(img_in, axis=0)
  89. img_in /= 255.0
  90. outputs = session.run(None, {input_name: img_in})
  91. conf_thresh = 0.4
  92. nms_thresh = 0.6
  93. # [batch, num, 1, 4]
  94. box_array = outputs[0]
  95. # [batch, num, num_classes]
  96. confs = outputs[1]
  97. if type(box_array).__name__ != 'ndarray':
  98. box_array = box_array.cpu().detach().numpy()
  99. confs = confs.cpu().detach().numpy()
  100. num_classes = confs.shape[2]
  101. # [batch, num, 4]
  102. box_array = box_array[:, :, 0]
  103. # [batch, num, num_classes] --> [batch, num]
  104. max_conf = np.max(confs, axis=2)
  105. max_id = np.argmax(confs, axis=2)
  106. assert box_array.shape[0] == 1,"每次单幅图像预测,batch为1"
  107. argwhere = max_conf[0] > conf_thresh
  108. l_box_array = box_array[0, argwhere, :]
  109. l_max_conf = max_conf[0, argwhere]
  110. l_max_id = max_id[0, argwhere]
  111. bboxes = []
  112. # nms for each class
  113. for j in range(num_classes):
  114. cls_argwhere = l_max_id == j
  115. ll_box_array = l_box_array[cls_argwhere, :]
  116. ll_max_conf = l_max_conf[cls_argwhere]
  117. ll_max_id = l_max_id[cls_argwhere]
  118. keep = nms_cpu(ll_box_array, ll_max_conf, nms_thresh)
  119. if (keep.size > 0):
  120. ll_box_array = ll_box_array[keep, :]
  121. ll_max_conf = ll_max_conf[keep]
  122. ll_max_id = ll_max_id[keep]
  123. for k in range(ll_box_array.shape[0]):
  124. bboxes.append(
  125. [
  126. ll_max_id[k] + 1,
  127. ll_max_conf[k],
  128. int(ll_box_array[k, 0] * width),
  129. int(ll_box_array[k, 1] * height),
  130. int(ll_box_array[k, 2] * width),
  131. int(ll_box_array[k, 3] * height),
  132. ])
  133. pred_annos = np.zeros((TEST_KEEP_TOPK, 6))
  134. valid_box_ind = 0
  135. for t in range(len(bboxes)):
  136. pred_annos[valid_box_ind, :] = bboxes[t]
  137. valid_box_ind += 1
  138. for ind in range(pred_annos.shape[0]):
  139. c = int(pred_annos[ind][0])
  140. if c <= 0:
  141. continue
  142. score = pred_annos[ind][1]
  143. box_left = int(pred_annos[ind][2])
  144. box_top = int(pred_annos[ind][3])
  145. box_right = int(pred_annos[ind][4])
  146. box_bottom = int(pred_annos[ind][5])
  147. str_labelandscore = '{}:{}'.format(class_names[c-1], '%.3f' % score)
  148. img = cv2.putText(img, str_labelandscore, (box_right - 10, box_bottom - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0,255,0), 2, cv2.LINE_AA)
  149. img = cv2.rectangle(img, (box_left, box_top), (box_right, box_bottom), (0,255,0), 2)
  150. cv2.imencode('.jpg', img.astype(np.uint8))[1].tofile(os.path.join(output_path , f_names[i].split("\\")[-1].split(".jpg")[0] + '_.jpg'))