predict.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. import argparse
  2. import os
  3. import logging
  4. import cv2
  5. import numpy as np
  6. import glob
  7. import json
  8. import time
  9. import torch
  10. import torch.nn.parallel
  11. import torch.optim
  12. import torch.utils.data
  13. import torch.utils.data.distributed
  14. from utils.confusion_matrix import draw_confusion_matrix
  15. parser = argparse.ArgumentParser(description='PyTorch ImageNet Predict')
  16. parser.add_argument('--model', '-m', default=os.path.join(os.getcwd(), 'checkpoints\\mobilenetv3\\CP_epoch31.pth'),
  17. metavar='FILE',
  18. help="Specify the file in which the model is stored")
  19. parser.add_argument('--input', '-i', metavar='INPUT', nargs='+',
  20. help='filenames of input images', default=os.path.join(os.getcwd(), "data\\test\\"))
  21. parser.add_argument('--output', '-o', metavar='OUTPUT', nargs='+',
  22. help='path of save the result', default=os.path.join(os.getcwd(), 'output\\20240515-augmentation-2\\'))
  23. parser.add_argument('-a', '--arch', metavar='ARCH', default='efficientnet-b0',
  24. help='model architecture (default: efficientnet-b0)')
  25. parser.add_argument('--gpu', default=0, type=int,
  26. help='GPU id to use.')
  27. parser.add_argument('--image_size', default=256, type=int,
  28. help='image size')
  29. index_label = {0:'BMode', 1:'BModeBlood',2:'Pseudocolor',3:'PseudocolorBlood', 4:'Spectrogram', 5:'CEUS', 6:'SE',7:'STE',8:'FourDime'}
  30. def predict_img(net,
  31. full_img,
  32. device,
  33. scale_factor=1):
  34. net.eval() # 测试时的网络特征
  35. img = cv2.resize(full_img, (scale_factor, scale_factor))
  36. if len(img.shape) == 2:
  37. img = np.expand_dims(img, axis=2) # 表示在axis位置添加数据
  38. img = img.transpose((2, 0, 1)) # 转置:Pytorch中为[Channels, H, W]
  39. if img.max() > 1:
  40. img = img / 255
  41. img = torch.from_numpy(img) # 将x转换为torch类型
  42. img = img.unsqueeze(0) # 扩充第0个维度
  43. img = img.to(device=device, dtype=torch.float32)
  44. with torch.no_grad():
  45. t = time.time()
  46. output = net(img).cpu().numpy()
  47. # print("时间:{}".format((time.time() - t)*1000))
  48. pred_label = np.argmax(output)
  49. score = output[0][int(pred_label)]
  50. # print("score:{}".format(output[0][int(pred_label)]))
  51. return pred_label, score
  52. def find_images(root_dir):
  53. """
  54. 深度遍历查找指定目录下的所有图像文件
  55. """
  56. swap_dict = {v:k for k,v in index_label.items()}
  57. image_files = []
  58. for root, dirs, files in os.walk(root_dir):
  59. for file in files:
  60. if file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp')):
  61. json_file = os.path.join(root, file[0:-4] + ".txt")
  62. labeled_info = open(json_file, 'r')
  63. txt_info = labeled_info.read()
  64. txt_info_dict = json.loads(txt_info)
  65. classes = txt_info_dict[0]["FileResultInfos"][0]["LabeledResult"]["ImageResults"][0]["Conclusion"]["Title"]
  66. dst_label = swap_dict[classes]
  67. image_files.append((os.path.join(root, file), dst_label))
  68. return image_files
  69. def main():
  70. args = parser.parse_args()
  71. if not os.path.exists(args.output):
  72. os.makedirs(args.output)
  73. # EfficientNet预测
  74. # from model.efficientnet_pytorch.model import EfficientNet
  75. # net = EfficientNet.from_name(model_name=args.arch, in_channels=3, num_classes=9, image_size=args.image_size)
  76. # resnext50预测
  77. # from resnext_pytorch.resnext import resnext50
  78. # net = resnext50(baseWidth=4, cardinality=8)
  79. # mobilenetv3预测
  80. from model.mobilenetv3 import mobilenetv3
  81. net = mobilenetv3.MobileNetV3(n_class=9, input_size=256, channels=3, mode='small')
  82. # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  83. device = torch.device('cpu')
  84. logging.info(f'Using device {device}')
  85. net.to(device=device)
  86. net.load_state_dict(torch.load(args.model, map_location=device))
  87. logging.info("Model loaded !")
  88. y_gt=[]
  89. y_pred=[]
  90. f_names = find_images(args.input) # 匹配所有的符合条件的文件,并将其以list的形式返回完整路径
  91. times = 0
  92. for i in range(len(f_names)):
  93. imgpath, labels_true = f_names[i]
  94. img = cv2.imdecode(np.fromfile(imgpath, dtype=np.uint8), 1) # 读回数据时需要用户指定元素类型,并对数组的形状进行适当的修改
  95. t =time.time()
  96. classification_pred, score = predict_img(net=net, full_img=img, scale_factor=args.image_size, device=device)
  97. times += (time.time() - t)*1000
  98. # print("时间:{}".format((time.time() - t)*1000))
  99. logging.info("Visualizing results for image {}, close to continue ...".format(imgpath))
  100. # save results
  101. Label_pred_dir = os.path.join(args.output, index_label[classification_pred])
  102. if not os.path.exists(Label_pred_dir):
  103. os.makedirs(Label_pred_dir)
  104. # print("预测的第{}图的类别为{}".format(i, index_label[classification_pred]))
  105. cv2.imencode('.jpg', img.astype(np.uint8))[1].tofile(os.path.join(Label_pred_dir, imgpath.split('\\')[-1]))
  106. y_pred.append(classification_pred)
  107. y_gt.append(labels_true)
  108. if i%20==0:
  109. print(i, '/', len(f_names))
  110. # print("分类结果:{}\n".format(index_label[classification_pred]))
  111. # 绘制热力图
  112. draw_confusion_matrix(label_true=y_gt,
  113. label_pred=y_pred,
  114. label_name=[v for k, v in index_label.items()],
  115. normlize=True,
  116. title="Confusion Matrix",
  117. pdf_save_path=os.path.join(args.output,"CM.jpg"),
  118. dpi=300)
  119. if __name__ == '__main__':
  120. main()