123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202 |
- import sys
- import onnx
- import os
- import argparse
- import numpy as np
- import cv2
- import onnxruntime
- import multiprocessing
- import matplotlib.pyplot as plt
- from tool.utils import *
- from tool.darknet2onnx import *
- import glob
# Location of the exported ONNX detector and of the directory that receives
# the annotated prediction images written by this script.
onnx_path = 'C:\\Users\\VINNO\\Desktop\\pytorch-YOLOv4-master\\yolov3_1_3_256_256_static.onnx'
output_path = 'C:\\Users\\VINNO\\Desktop\\pytorch-YOLOv4-master\\output'
# exist_ok=True replaces the separate isdir() check, which could race with
# another process creating the directory between the check and makedirs().
os.makedirs(output_path, exist_ok=True)
# Build the ONNX Runtime session that serves every inference in this script.
# The graph optimization level is left at its default (the ORT_DISABLE_ALL
# override is intentionally not applied).
sess_options = onnxruntime.SessionOptions()
# With ORT_PARALLEL execution mode, inter_op_num_threads controls the number
# of threads used to parallelize the execution of the graph (across nodes).
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_PARALLEL
sess_options.inter_op_num_threads = 1
# intra_op_num_threads controls the number of threads used to run the model.
sess_options.intra_op_num_threads = 1
session = onnxruntime.InferenceSession(onnx_path, sess_options)

# The first model input supplies the feed name and, from shape indices 2 and
# 3, the height and width that images are resized to before inference.
_first_input = session.get_inputs()[0]
input_name = _first_input.name
IN_IMAGE_H = _first_input.shape[2]
IN_IMAGE_W = _first_input.shape[3]
# Label names indexed by class id; index 0 is the background entry.
class_list = [
    '__background__',
    'lipomyoma',
    'BIRADS2',
    'BIRADS3',
    'BIRADS4a',
    'BIRADS4b',
    'BIRADS4c',
    'BIRADS5',
]
# Number of classes including the background (7 + 1).
NUM_CLASSES = 8
BACKGROUND_LABEL_ID = 0
# NOTE(review): TEST_IOU_FILT_TH and BACKGROUND_LABEL_ID are not referenced in
# the visible part of this script - confirm whether they are still needed.
TEST_IOU_FILT_TH = 0.5
# Row count of the per-image prediction table.
TEST_KEEP_TOPK = 25
# Row count of the per-image ground-truth table.
INPUT_ROIS_PER_IMAGE = 12
# Directory holding the test images; the trailing separator is part of the
# value because the glob pattern is built by plain string concatenation.
file_path = 'C:\\Users\\VINNO\\Desktop\\pytorch-YOLOv4-master\\test\\'
f_names = glob.glob('{}*.jpg'.format(file_path))
num_test = len(f_names)

# all_boxes: holds all predictions, one empty list per (class, image) pair.
all_boxes = [
    [[] for _image in range(num_test)]
    for _cls in range(NUM_CLASSES)
]
# all_gt_infos: holds all ground-truth values, same layout as all_boxes.
all_gt_infos = [
    [[] for _image in range(num_test)]
    for _cls in range(NUM_CLASSES)
]
def load_class_names(namesfile):
    """Return the class names stored in *namesfile*, one per line.

    Trailing whitespace (including the line terminator) is stripped from
    every line; leading whitespace is preserved and blank lines are kept as
    empty-string entries.
    """
    with open(namesfile, 'r') as fp:
        return [entry.rstrip() for entry in fp]
# Display names used when labelling predictions, read from the names file.
namesfile = 'breast.names'
class_names = load_class_names(namesfile)

# Running totals updated by the per-image loop: count_sum is the number of
# images evaluated, count_wrong the number of detections reported on them.
count_sum = count_wrong = 0
def _load_gt_annotations(gt_file, img_width, img_height):
    """Read a YOLO-style label file into a fixed-size table of pixel boxes.

    Every non-empty line is parsed as space-separated numbers: class id, box
    centre x, box centre y, box width and box height.  The last four are
    multiplied by the image width/height, i.e. they are treated as fractions
    of the image size.

    Returns an ``(INPUT_ROIS_PER_IMAGE, 5)`` int32 array whose rows are
    ``[left, top, right, bottom, label]``.  ``label`` is the file's class id
    plus one, so a label of 0 marks an unused row.  Lines beyond
    ``INPUT_ROIS_PER_IMAGE`` are ignored instead of overflowing the table.
    """
    annotations = np.zeros((INPUT_ROIS_PER_IMAGE, 5), dtype=np.int32)
    with open(gt_file, "r") as gtf:
        rows = [line.strip() for line in gtf]
    rows = [row for row in rows if row][:INPUT_ROIS_PER_IMAGE]
    for slot, row in enumerate(rows):
        values = np.fromstring(row, dtype=np.float32, sep=' ')
        label = int(values[0]) + 1
        x_center, y_center, box_w, box_h = values[1:5]
        left = int((x_center - box_w / 2) * img_width)
        top = int((y_center - box_h / 2) * img_height)
        right = int((x_center + box_w / 2) * img_width)
        bottom = int((y_center + box_h / 2) * img_height)
        annotations[slot, :] = [left, top, right, bottom, label]
    return annotations


def _preprocess(orig_img):
    """Turn a decoded BGR image into the network input.

    The image is resized to the model's input size, converted to a single
    grey channel, laid out as (1, 1, H, W) float32 and divided by 255.
    """
    resized = cv2.resize(orig_img, (IN_IMAGE_W, IN_IMAGE_H), interpolation=cv2.INTER_LINEAR)
    gray = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
    # (H, W) -> (batch=1, channel=1, H, W)
    img_in = gray[np.newaxis, np.newaxis, :, :].astype(np.float32)
    img_in /= 255.0
    return img_in


def _postprocess(outputs, img_width, img_height, conf_thresh, nms_thresh):
    """Convert raw model outputs into a list of detections in pixel units.

    Candidates whose best class confidence exceeds *conf_thresh* are grouped
    by class and filtered with ``nms_cpu`` using *nms_thresh*.  Each returned
    detection is ``[class_id + 1, confidence, left, top, right, bottom]``.
    At most ``TEST_KEEP_TOPK`` detections are returned; when more survive,
    the highest-confidence ones are kept.
    """
    # [batch, num, 1, 4]
    box_array = outputs[0]
    # [batch, num, num_classes]
    confs = outputs[1]
    if not isinstance(box_array, np.ndarray):
        box_array = box_array.cpu().detach().numpy()
        confs = confs.cpu().detach().numpy()
    num_classes = confs.shape[2]
    # [batch, num, 4]
    box_array = box_array[:, :, 0]
    # [batch, num, num_classes] --> [batch, num]
    max_conf = np.max(confs, axis=2)
    max_id = np.argmax(confs, axis=2)
    assert box_array.shape[0] == 1,"每次单幅图像预测,batch为1"

    selected = max_conf[0] > conf_thresh
    cand_boxes = box_array[0, selected, :]
    cand_conf = max_conf[0, selected]
    cand_id = max_id[0, selected]

    bboxes = []
    # NMS is applied independently to the candidates of each class.
    for cls in range(num_classes):
        in_class = cand_id == cls
        cls_boxes = cand_boxes[in_class, :]
        cls_conf = cand_conf[in_class]
        keep = nms_cpu(cls_boxes, cls_conf, nms_thresh)
        if keep.size == 0:
            continue
        for box, conf in zip(cls_boxes[keep, :], cls_conf[keep]):
            bboxes.append(
                [
                    cls + 1,
                    conf,
                    int(box[0] * img_width),
                    int(box[1] * img_height),
                    int(box[2] * img_width),
                    int(box[3] * img_height),
                ])

    # Only re-order when the limit is exceeded so the common case keeps the
    # per-class ordering produced above.
    if len(bboxes) > TEST_KEEP_TOPK:
        bboxes = sorted(bboxes, key=lambda det: det[1], reverse=True)[:TEST_KEEP_TOPK]
    return bboxes


def _draw_detections(img, bboxes):
    """Draw every detection (green box plus "name:score" label) onto *img*."""
    for c, score, box_left, box_top, box_right, box_bottom in bboxes:
        c = int(c)
        if c <= 0:
            continue
        str_labelandscore = '{}:{}'.format(class_names[c-1], '%.3f' % float(score))
        img = cv2.putText(img, str_labelandscore, (box_right - 10, box_bottom - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0,255,0), 2, cv2.LINE_AA)
        img = cv2.rectangle(img, (box_left, box_top), (box_right, box_bottom), (0,255,0), 2)
    return img


# Minimum class confidence for a candidate box, and the threshold handed to
# nms_cpu for every class.
conf_thresh = 0.4
nms_thresh = 0.6

for img_file in f_names:
    # The file is read into a byte buffer and decoded from memory; flag 1
    # requests a 3-channel colour image.
    orig_img = cv2.imdecode(np.fromfile(img_file, dtype=np.uint8), 1)
    if orig_img is None:
        # imdecode yields None for data it cannot decode.
        print("skip unreadable image: {}".format(img_file))
        continue
    img = orig_img.copy()
    height, width = img.shape[:2]

    # Read the ground-truth labels stored next to the image as <stem>.txt.
    img_stem = os.path.splitext(os.path.basename(img_file))[0]
    gt_file = os.path.join(file_path, "{}.txt".format(img_stem))
    annotations = _load_gt_annotations(gt_file, width, height)

    # Images that carry at least one ground-truth box are skipped: only
    # images without annotated boxes are evaluated, so every detection on
    # them is counted as a false detection.
    if annotations[0][4] != 0:
        continue
    count_sum += 1

    outputs = session.run(None, {input_name: _preprocess(orig_img)})
    bboxes = _postprocess(outputs, width, height, conf_thresh, nms_thresh)
    count_wrong += len(bboxes)

    img = _draw_detections(img, bboxes)
    # NOTE(review): written once per evaluated image, including images with
    # no detections - confirm against the original (indentation was lost).
    out_file = os.path.join(output_path, img_stem + '_pre.jpg')
    if not cv2.imwrite(out_file, img):
        print("failed to write image: {}".format(out_file))
# Summary: number of evaluated images, number of detections reported on them,
# and their ratio (the false-detection rate).  When no image was evaluated
# the ratio is undefined, so nan is reported instead of dividing by zero.
false_detection_rate = count_wrong / count_sum if count_sum else float('nan')
print("total_num:{},wrong_num:{},误检率:{}\n".format(count_sum, count_wrong, false_detection_rate))
|