# onnx_use.py (2.7 KB)
  1. import onnx
  2. import onnxruntime as rt
  3. import glob
  4. from sklearn.metrics import accuracy_score,precision_score, recall_score, f1_score
  5. import cv2 as cv
  6. import numpy as np
  7. import os
  8. model = os.path.join(os.getcwd(), 'classification.onnx')
  9. #对不同的文件夹进行预测,得到相应的指标,每个文件夹对应一个类别
  10. imagesize = 224
  11. imagefilepath= 'D:\\'
  12. #指定每个文件夹,所对应的类别索引
  13. file_list = {'BMode':0, 'BMode_FiveChamberView':1, 'BMode_FourChamberView':2}
  14. onnx.checker.check_model(model)
  15. #print(onnx.helper.printable_graph(model.graph))
  16. sess = rt.InferenceSession(model)
  17. input_name = sess.get_inputs()[0].name
  18. print("input name", input_name)
  19. input_shape = sess.get_inputs()[0].shape
  20. print("input shape", input_shape)
  21. input_type = sess.get_inputs()[0].type
  22. print("input type", input_type)
  23. output_name = sess.get_outputs()[0].name
  24. print("output name", output_name)
  25. output_shape = sess.get_outputs()[0].shape
  26. print("output shape", output_shape)
  27. output_type = sess.get_outputs()[0].type
  28. print("output type", output_type)
  29. img_label = []
  30. img_predict = []
  31. for file in file_list.keys():
  32. in_files = os.path.join(imagefilepath, '{}\\'.format(file))
  33. f_names = glob.glob(in_files + '*jpg')
  34. metric_count = 0
  35. true_label = file_list[file]
  36. for i in range(len(f_names)):
  37. img_orig = cv.imdecode(np.fromfile(f_names[i], dtype=np.uint8),1)
  38. newW, newH = int(imagesize), int(imagesize)
  39. assert newW > 0 and newH > 0, 'Scale is too small'
  40. img = cv.resize(img_orig, (newW, newH), interpolation=cv.INTER_LINEAR)
  41. img_nd = np.array(img, dtype=np.float32)
  42. if len(img_nd.shape) == 2:
  43. img_nd = np.expand_dims(img_nd, axis=2)
  44. # HWC to CHW
  45. img_trans = img_nd.transpose((2, 0, 1))
  46. if img_trans.max() > 1:
  47. img_trans = img_trans / 255
  48. img_trans = np.expand_dims(img_trans, axis=0)
  49. out = sess.run([output_name], {input_name:img_trans})[0]
  50. pred_label = np.argmax(out[0])
  51. img_predict.append(pred_label)
  52. img_label.append(true_label)
  53. if pred_label == true_label:
  54. metric_count += 1
  55. metric = metric_count / len(f_names) * 100
  56. print("{}:total count is ".format(file) + str(len(f_names))+".")
  57. print("{}:The metric is ".format(file) + str(metric) + "%")
  58. # 准确率,精确率,召回率,F1
  59. accuracy = accuracy_score(img_label, img_predict)
  60. precision = precision_score(img_label, img_predict, average='macro')
  61. recall = recall_score(img_label, img_predict, average='macro')
  62. f1 = f1_score(img_label, img_predict, average='macro')
  63. print("accuracy_score = %.2f" % accuracy)
  64. print("precision_score = %.2f" % precision)
  65. print("recall_score = %.2f" % recall)
  66. print("f1_score = %.2f" % f1)