123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475 |
import glob
import os
import sys

import cv2 as cv
import numpy as np
import onnx
import onnxruntime as rt
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
# Path of the exported ONNX classifier, resolved against the current working directory.
model = os.path.join(os.getcwd(), "classification.onnx")

# Predictions are made folder by folder and metrics are reported per folder;
# each folder holds the images of exactly one class.

# Side length (pixels) of the square every image is resized to before inference.
imagesize = 224
# Root directory that contains one sub-folder per class.
imagefilepath = "D:\\"

# Sub-folder name -> ground-truth class index for every image in that folder.
file_list = dict(
    BMode=0,
    BMode_FiveChamberView=1,
    BMode_FourChamberView=2,
)
# Validate the ONNX model file, open an inference session on it, and report the
# name, shape and type of the model's first input and first output.
onnx.checker.check_model(model)
# NOTE(review): onnxruntime >= 1.9 GPU builds require an explicit `providers=`
# argument here; confirm against the installed onnxruntime build.
sess = rt.InferenceSession(model)

_first_input = sess.get_inputs()[0]
input_name = _first_input.name
print("input name", input_name)
input_shape = _first_input.shape
print("input shape", input_shape)
input_type = _first_input.type
print("input type", input_type)

_first_output = sess.get_outputs()[0]
output_name = _first_output.name
print("output name", output_name)
output_shape = _first_output.shape
print("output shape", output_shape)
output_type = _first_output.type
print("output type", output_type)
# Evaluate each class folder under `imagefilepath`. For each folder:
# - run every .jpg through the ONNX session;
# - record the (true, predicted) class pair for the overall metrics;
# - print the share of that folder's images predicted as the folder's class.
img_label = []    # ground-truth class index of every evaluated image
img_predict = []  # predicted class index, aligned element-wise with img_label

# Target size is loop-invariant: compute and validate it once.
newW, newH = int(imagesize), int(imagesize)
if newW <= 0 or newH <= 0:
    raise ValueError('Scale is too small')


def _preprocess(img, width, height):
    """Convert a decoded 8-bit image into a float32 batch of shape (1, C, height, width).

    The image is resized with bilinear interpolation, converted from HWC to CHW
    layout and scaled by 1/255 into [0, 1]. Scaling is unconditional: a
    "divide only if max > 1" check would leave very dark images unscaled and
    make preprocessing depend on image content.

    Args:
        img: HxW or HxWxC uint8 array, as returned by cv.imdecode.
        width: target width in pixels.
        height: target height in pixels.

    Returns:
        C-contiguous float32 array of shape (1, C, height, width).
    """
    resized = cv.resize(img, (width, height), interpolation=cv.INTER_LINEAR)
    arr = resized.astype(np.float32)
    if arr.ndim == 2:
        # Single-channel image: add a trailing channel axis so HWC -> CHW works.
        arr = arr[:, :, np.newaxis]
    chw = arr.transpose((2, 0, 1)) / 255.0
    return np.ascontiguousarray(chw[np.newaxis, ...], dtype=np.float32)


for file, true_label in file_list.items():
    # os.path.join keeps this portable; the original hard-coded '\\' separator
    # only worked on Windows. Sorting makes the processing order deterministic.
    f_names = sorted(glob.glob(os.path.join(imagefilepath, file, '*.jpg')))
    metric_count = 0  # images in this folder predicted as the folder's class
    evaluated = 0     # images in this folder that decoded and were scored
    for f_name in f_names:
        # np.fromfile + imdecode (rather than cv.imread) also copes with
        # non-ASCII paths on Windows. Flag 1 requests a 3-channel color image.
        img_orig = cv.imdecode(np.fromfile(f_name, dtype=np.uint8), 1)
        if img_orig is None:
            # imdecode returns None for corrupt/unsupported files; skip instead
            # of crashing later inside cv.resize.
            print("warning: could not decode {}, skipped".format(f_name), file=sys.stderr)
            continue
        # NOTE(review): imdecode yields BGR channel order; confirm the model
        # was trained on BGR rather than RGB input.
        img_trans = _preprocess(img_orig, newW, newH)
        out = sess.run([output_name], {input_name: img_trans})[0]
        pred_label = np.argmax(out[0])
        img_predict.append(pred_label)
        img_label.append(true_label)
        evaluated += 1
        if pred_label == true_label:
            metric_count += 1
    print("{}:total count is ".format(file) + str(evaluated) + ".")
    if evaluated:
        metric = metric_count / evaluated * 100
        print("{}:The metric is ".format(file) + str(metric) + "%")
    else:
        # Empty or missing folder: avoid the division by zero of the original.
        print("{}:no readable .jpg images found, metric skipped".format(file))
# Overall accuracy, precision, recall and F1 over every evaluated image.
# Precision, recall and F1 use average='macro': each is computed per class and
# the per-class values are averaged with equal weight.
if not img_label:
    # Nothing was evaluated (no readable images found). Skip the sklearn calls
    # instead of handing them empty label lists.
    print("no images were evaluated; skipping overall metrics")
else:
    accuracy = accuracy_score(img_label, img_predict)
    precision = precision_score(img_label, img_predict, average='macro')
    recall = recall_score(img_label, img_predict, average='macro')
    f1 = f1_score(img_label, img_predict, average='macro')
    print("accuracy_score = %.2f" % accuracy)
    print("precision_score = %.2f" % precision)
    print("recall_score = %.2f" % recall)
    print("f1_score = %.2f" % f1)
|