predict.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590
  1. # -*- coding: utf-8 -*-
  2. # @ File : predict.py
  3. # @ Author : Guido LuXiaohao
  4. # @ Date : 2021/8/30
  5. # @ Software : PyCharm
  6. # @ Description: 用于模型推理
  7. import json
  8. import math
  9. import os
  10. import sys
  11. import time
  12. from argparse import ArgumentParser
  13. from pathlib import Path
  14. import cv2
  15. import numpy as np
  16. import onnxruntime as ort
  17. import torch
  18. from PIL import Image, ImageDraw, ImageFont
  19. from dataset import LoadImages, resize_LongestMaxSize
  20. from evaluation.utils.metric_file_generate import \
  21. prfile_generate_semantic_segmentation
  22. FILE = Path(__file__).resolve()
  23. ROOT = FILE.parents[0] # project root directory
  24. if str(ROOT) not in sys.path:
  25. sys.path.append(str(ROOT)) # add ROOT to PATH
  26. ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative path
  27. # {label: [min_area, area_rate, max_count], ...}
  28. NECKORGAN_CFGS = {
  29. "甲状腺横切": [0, 1 / 25, False],
  30. "甲状腺纵切": [0, 1 / 25, True],
  31. "颈动脉短轴": [0, 1 / 100, False],
  32. "颈动脉长轴": [0, 1 / 100, True],
  33. "颈部气管": [0, 1 / 25, True]
  34. }
  35. NECKORGAN_CFGS2 = [
  36. {
  37. "甲状腺横切": [0, 1 / 25, False],
  38. "甲状腺纵切": [0, 1 / 25, True],
  39. "颈动脉短轴": [0, 1 / 100, False],
  40. "颈动脉长轴": [0, 1 / 100, True]},
  41. {
  42. "颈部气管": [0, 1 / 25, True]},
  43. ]
  44. NECKLESION_CFGS = {
  45. "斑块或内中膜增厚": [50, 0, False]
  46. }
  47. NERVE_CFGS = {
  48. "神经": [0, 1 / 100, False],
  49. "动脉": [0, 1 / 100, False],
  50. "静脉": [0, 1 / 100, False]}
  51. def resize_LongestMaxSize_back(image, ori_size, resize_mode=cv2.INTER_CUBIC):
  52. # 将经过resize_LongestMaxSize的图像去除pad填充,再resize回原始尺寸
  53. image_height = ori_size[0]
  54. image_width = ori_size[1]
  55. norm_size = image.shape[0]
  56. if len(image.shape) != 3:
  57. image = np.expand_dims(image, -1)
  58. if image_height > image_width:
  59. pad_len = int((norm_size - norm_size * image_width / image_height) / 2)
  60. image_before_pad = image[0:norm_size, pad_len:norm_size - pad_len]
  61. image = cv2.resize(image_before_pad, (image_width, image_height), interpolation=resize_mode)
  62. else:
  63. pad_len = int((norm_size - norm_size * image_height / image_width) / 2)
  64. image_before_pad = image[pad_len:norm_size - pad_len, 0:norm_size]
  65. image = cv2.resize(image_before_pad, (image_width, image_height), interpolation=resize_mode)
  66. return image
  67. def semantics_segmentation_postprocess(input_tensor,
  68. label_index,
  69. min_area,
  70. area_rate,
  71. max_count=True):
  72. # 筛选语义分割结果
  73. select_mask = (input_tensor[..., label_index].squeeze() > 0.5) * 255
  74. select_mask = np.array(select_mask, np.uint8)
  75. all_area = select_mask.shape[0] * select_mask.shape[1]
  76. output_mask = np.zeros(select_mask.shape[:2], dtype=np.uint8)
  77. # 对应轮廓寻找
  78. contours, _ = cv2.findContours(select_mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
  79. contours = list(contours)
  80. output_contours = []
  81. if len(contours) != 0:
  82. contours.sort(key=lambda x: cv2.contourArea(x), reverse=True)
  83. for contour_idx in range(len(contours)):
  84. # 轮廓为一点时,是个向量
  85. if np.squeeze(contours[contour_idx]).shape[0] > 10 and \
  86. cv2.contourArea(contours[contour_idx]) > min_area and \
  87. cv2.contourArea(contours[contour_idx]) > (all_area * area_rate):
  88. output_contours.append(contours[contour_idx])
  89. if len(output_contours) != 0:
  90. if max_count:
  91. output_mask = cv2.drawContours(output_mask.copy(), [output_contours[0]], -1, label_index, cv2.FILLED)
  92. else:
  93. output_mask = cv2.drawContours(output_mask.copy(), output_contours, -1, label_index, cv2.FILLED)
  94. return output_mask
  95. def contour_sort(mask, output_label, label_mapper):
  96. each_image_label_list = []
  97. # 对应轮廓寻找
  98. contours, _ = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
  99. contours = list(contours)
  100. contours.sort(key=lambda x: cv2.contourArea(x), reverse=True)
  101. for i in range(len(contours)):
  102. each_label_dict = {
  103. "contour": contours[i],
  104. "label_value": label_mapper[output_label],
  105. "area": cv2.contourArea(contours[i])}
  106. each_image_label_list.append(each_label_dict)
  107. return each_image_label_list
  108. def draw_contours(im0, pr_contour, class_name, fontText, color_list):
  109. im0 = Image.fromarray(im0) # BGR
  110. color = tuple([int(x) for x in color_list[pr_contour['label_value'] - 1]]) # B, G, R
  111. # color = (0, 0, 255) # red
  112. draw = ImageDraw.Draw(im0)
  113. draw.text(pr_contour['contour'][0][0], class_name[pr_contour['label_value'] - 1],
  114. fill=color, font=fontText, stroke_width=1)
  115. im0 = np.asarray(im0)
  116. cv2.drawContours(im0, pr_contour['contour'], contourIdx=-1, color=color, thickness=2)
  117. return im0
  118. def get_bounding_box(contours, img_size, expand_pixel=40):
  119. # 默认外扩40个像素
  120. roi_points = [list(i[0]) for i in contours['contour']]
  121. x, y = [], []
  122. point_count = len(roi_points)
  123. for ni in range(point_count):
  124. point_x, point_y = roi_points[ni][:]
  125. x.append(int(point_x))
  126. y.append(int(point_y))
  127. left, right = max(min(x) - expand_pixel, 0), min(max(x) + expand_pixel, img_size[1]) # 防止超出边界
  128. top, bottom = max(min(y) - expand_pixel, 0), min(max(y) + expand_pixel, img_size[0])
  129. return [top, bottom, left, right]
  130. def generate_pred_file(predict_info, index_class_map):
  131. predict_info.sort(key=lambda x: x["Label"], reverse=False)
  132. output_info = [{
  133. "FileResultInfos": [{
  134. "Index": 0,
  135. "FrameStatus": None,
  136. "LabeledResult": {"Rois": []}}]}]
  137. for label_idx, label_info in enumerate(predict_info):
  138. if len(label_info["Contours"]) > 0:
  139. pts = [{"X": x, "Y": y} for x, y in label_info["Contours"][0]]
  140. output_info[0]['FileResultInfos'][0]['LabeledResult']['Rois'].append({
  141. "Index": label_idx,
  142. "Points": pts,
  143. "Conclusion": {
  144. "Title": index_class_map[label_info["Label"]],
  145. "Confidence": label_info["Confidence"]}})
  146. else:
  147. continue
  148. return json.dumps(output_info, ensure_ascii=False)
  149. def model_prepare(model, num_classes, device='cuda'):
  150. if model.endswith(".onnx"):
  151. _backend = "ONNX"
  152. elif model.endswith((".pt", ".pth")):
  153. _backend = "PyTorch"
  154. else:
  155. _backend = None
  156. raise RuntimeError("args model must be supported format files: {'.onnx', '.pt', '.pth'}!")
  157. if _backend == "ONNX":
  158. sess_options = ort.SessionOptions()
  159. sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
  160. sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
  161. providers = ['CPUExecutionProvider']
  162. if device == "cuda":
  163. providers = ['CUDAExecutionProvider'] + providers
  164. elif device == "limit_cpu":
  165. sess_options.inter_op_num_threads = 1 # number of threads used to parallelize the execution across nodes
  166. sess_options.intra_op_num_threads = 1 # number of threads used to parallelize the execution within nodes
  167. # load ONNX model
  168. seg_model = ort.InferenceSession(model, sess_options, providers=providers)
  169. nc = []
  170. for output in seg_model.get_outputs():
  171. nc += [output.shape[1]]
  172. elif _backend == "PyTorch":
  173. print("You are using PyTorch as inference backend. Please make sure you call the correct model module!")
  174. import models
  175. seg_model = models.PPLiteSeg(
  176. num_classes=num_classes, backbone=models.STDC1(attn=True), arm_type="UAFM_ChAtten", apply_nonlin=True)
  177. csd = torch.load(model, map_location='cpu')['model'].float().state_dict()
  178. seg_model.load_state_dict(csd) # load model state_dict
  179. seg_model.to(device).eval()
  180. nc = model.num_classes
  181. else:
  182. seg_model = None
  183. nc = None
  184. return seg_model, nc, _backend
  185. def focus_seg(image,
  186. model,
  187. label_mapper,
  188. image_pre_size=(256, 256),
  189. resize_mode="normal",
  190. preprocess_mode=None,
  191. inference_backend="ONNX",
  192. postprocess_cfgs=None):
  193. """分割模型识别函数
  194. Args:
  195. image: 输入为裁切掉医院等无关信息的图
  196. model: 分割模型
  197. label_mapper: 用于对模型输出类别的映射
  198. image_pre_size: 图像预处理尺寸
  199. resize_mode: 图像缩放方式
  200. preprocess_mode: 图像预处理方式
  201. inference_backend: 推理后端类型
  202. postprocess_cfgs: 后处理配置
  203. Returns:
  204. tuple: tuple contains:
  205. output_contours: 模型预测的3维数组, 原图像尺寸
  206. output_mask: 形状为[N, H, W]的原尺寸掩模图像
  207. T0: 图片前处理耗时
  208. T1: 模型推理耗时
  209. """
  210. t0_0 = time.time()
  211. image_src = image.copy()
  212. # 尺寸放缩到固定大小
  213. if resize_mode == "normal":
  214. h = image_pre_size[0]
  215. w = image_pre_size[1]
  216. image = cv2.resize(image, (w, h))
  217. elif resize_mode == "fitLargeSizeAndPad":
  218. image = resize_LongestMaxSize(image, image_pre_size[0], resize_mode=cv2.INTER_CUBIC)
  219. else:
  220. image = image
  221. # 预测网络固定前处理部分
  222. if preprocess_mode == "normalization1":
  223. predict_image = (image / 255.0).astype("float32")
  224. elif preprocess_mode == "normalization2":
  225. predict_image = ((image / 255.0).astype("float32") - 0.5) / 2.0
  226. elif preprocess_mode == "decentralization":
  227. img_mean = np.mean(image)
  228. img_std = np.std(image)
  229. predict_image = ((image - img_mean + 10e-7) / (img_std + 10e-7)).astype("float32")
  230. elif preprocess_mode == "int8":
  231. predict_image = (image - 255).astype("int8")
  232. else:
  233. predict_image = image
  234. ############################
  235. t0_1 = time.time()
  236. T0 = t0_1 - t0_0
  237. print("图片前处理时间:%0.8f" % T0)
  238. t1_0 = time.time()
  239. # 模型预测
  240. predict_image = np.expand_dims(predict_image, axis=0)
  241. input_image = predict_image.transpose((0, 3, 1, 2))
  242. if inference_backend == "ONNX":
  243. input_name = model.get_inputs()[0].name
  244. pr_masks = model.run([], {input_name: input_image})
  245. pr_masks = [pr_masks] if not isinstance(pr_masks, list) else pr_masks
  246. elif inference_backend == "PyTorch":
  247. input_image = torch.from_numpy(input_image).to("cuda" if torch.cuda.is_available() else "cpu")
  248. pr_masks = model(input_image)
  249. pr_masks = [pr_masks] if not isinstance(pr_masks, list) else pr_masks
  250. pr_masks = [pr_mask.detach().cpu().numpy() for pr_mask in pr_masks]
  251. else:
  252. raise RuntimeError("Unsupported backend type! Only ONNX and PyTorch are supported.")
  253. t1_1 = time.time()
  254. T1 = t1_1 - t1_0
  255. print("模型预测时间:%0.8f" % T1)
  256. pr_masks = [pr_mask.squeeze(0).transpose((1, 2, 0)) for pr_mask in pr_masks]
  257. if resize_mode == "fitLargeSizeAndPad":
  258. pr_masks = [
  259. resize_LongestMaxSize_back(pr_mask, image_src.shape[0:2], resize_mode=cv2.INTER_LINEAR
  260. ) for pr_mask in pr_masks
  261. ]
  262. else:
  263. pr_masks = [
  264. cv2.resize(pr_mask, (image_src.shape[1], image_src.shape[0]), interpolation=cv2.INTER_LINEAR
  265. ) for pr_mask in pr_masks
  266. ]
  267. # mask后处理
  268. output_contours = []
  269. for out_idx, mask in enumerate(pr_masks):
  270. for class_idx, label in enumerate(postprocess_cfgs[out_idx].keys()):
  271. min_area, area_rate, max_count = postprocess_cfgs[out_idx][label]
  272. pr_class_mask = semantics_segmentation_postprocess(mask, class_idx + 1, min_area, area_rate, max_count)
  273. pr_class_contour = contour_sort(pr_class_mask, class_idx + 1, label_mapper[out_idx])
  274. output_contours.extend(pr_class_contour)
  275. output_contours.sort(key=lambda x: x["area"], reverse=True)
  276. output_mask = np.zeros((image_src.shape[0], image_src.shape[1]), dtype=np.uint8)
  277. for single_contour in output_contours:
  278. output_mask = cv2.drawContours(output_mask.copy(), [single_contour["contour"]], -1,
  279. single_contour["label_value"], cv2.FILLED)
  280. return output_contours, output_mask, T0, T1
  281. def run(model,
  282. source,
  283. imgsz=(256, 256),
  284. postprocess_cfgs=None,
  285. save_dir=None,
  286. device='cuda',
  287. save_txt=False):
  288. """用于非串联检测,如仅检测颈动脉
  289. """
  290. source = str(source)
  291. fontText = ImageFont.truetype(r"C:\Windows\Fonts\msyhl.ttc", 18, encoding="utf-8")
  292. if not isinstance(postprocess_cfgs, list):
  293. postprocess_cfgs = [postprocess_cfgs]
  294. class_name = []
  295. for cfg in postprocess_cfgs:
  296. class_name += list(cfg.keys())
  297. # prepare model
  298. seg_model, nc, _backend = model_prepare(model, 3, device) # 3为模型的输出类别数量,依据实际情况修改
  299. if not isinstance(nc, list):
  300. nc = [nc]
  301. nc = [c - 1 for c in nc]
  302. label_mapper = [{i: i+sum(nc[:idx]) if idx > 0 else i for i in range(1, c+1)} for idx, c in enumerate(nc)]
  303. # Dataloader
  304. dataset = LoadImages(source)
  305. bs = len(dataset)
  306. vid_path, vid_writer = [None] * bs, [None] * bs
  307. # Run inference
  308. Time = []
  309. for path, im, im0, vid_cap, s in dataset:
  310. print(f'{s}')
  311. save_path = str(save_dir / Path(path).name)
  312. pr_contours, pr_focus_mask, T0, T1 = focus_seg(im, seg_model, label_mapper,
  313. image_pre_size=imgsz,
  314. resize_mode="normal",
  315. preprocess_mode="normalization1",
  316. inference_backend=_backend,
  317. postprocess_cfgs=postprocess_cfgs)
  318. Time.append(T1) # record inference time
  319. # draw contours
  320. for pr_contour in pr_contours:
  321. im0 = draw_contours(im0, pr_contour, class_name, fontText, COLOR_LIST_ORGAN)
  322. if dataset.mode == 'image':
  323. cv2.imencode(".jpg", im0, [int(cv2.IMWRITE_JPEG_QUALITY), 100])[1].tofile(save_path)
  324. else: # 'video'
  325. if vid_path[dataset.count] != save_path: # new video
  326. vid_path[dataset.count] = save_path
  327. if isinstance(vid_writer[dataset.count], cv2.VideoWriter):
  328. vid_writer[dataset.count].release() # release previous video writer
  329. if vid_cap: # video
  330. fps = vid_cap.get(cv2.CAP_PROP_FPS)
  331. w = im0.shape[1]
  332. h = im0.shape[0]
  333. else:
  334. fps, w, h = 30, im.shape[1], im.shape[0]
  335. save_path = str(Path(save_path).with_suffix('.mp4'))
  336. vid_writer[dataset.count] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
  337. vid_writer[dataset.count].write(im0)
  338. if save_txt:
  339. pred_info = prfile_generate_semantic_segmentation(pr_contours, im.shape[:2])
  340. prf = generate_pred_file(pred_info, {k + 1: v for k, v in enumerate(class_name)})
  341. with open(Path(save_path).with_suffix(".txt"), "w", encoding="utf-8") as f:
  342. f.write(prf)
  343. # exclude warmup inference
  344. warmup = 5
  345. for _ in range(warmup):
  346. Time.pop(0)
  347. inference_speed = sum(Time) / len(Time)
  348. print('#################################\n'
  349. '平均每张图像推理速度为:{:.6f}\n'.format(inference_speed),
  350. '#################################\n')
  351. def run_pipeline(organ_model,
  352. lesion_model,
  353. source,
  354. imgsz=(256, 256),
  355. postprocess_cfgs_organ=None,
  356. postprocess_cfgs_lesion=None,
  357. save_dir=None,
  358. device='cuda'):
  359. '''
  360. 用于串联检测,脏器+病灶,依据脏器的检测结果作为病灶的输入图像,例如颈动脉+斑块
  361. '''
  362. source = str(source)
  363. fontText = ImageFont.truetype(r"C:\Windows\Fonts\msyhl.ttc", 18, encoding="utf-8")
  364. class_name_organ = list(postprocess_cfgs_organ.keys())
  365. class_name_lesion = list(postprocess_cfgs_lesion.keys())
  366. # prepare model
  367. seg_organ_model, organ_backend = model_prepare(organ_model, 3, device) # 3为模型的输出类别数量,依据实际情况修改
  368. seg_lesion_model, lesion_backend = model_prepare(lesion_model, 2, device)
  369. # Dataloader
  370. dataset = LoadImages(source)
  371. bs = len(dataset) # batch size
  372. vid_path, vid_writer = [None] * bs, [None] * bs
  373. # Run inference
  374. Time = []
  375. for path, im, im0, vid_cap, s in dataset:
  376. print(f'{s}')
  377. save_path = str(save_dir / Path(path).name)
  378. pr_contours_organ, pr_focus_mask_organ, T0, T1 = \
  379. focus_seg(
  380. im, seg_organ_model,
  381. image_pre_size=imgsz,
  382. resize_mode="normal",
  383. preprocess_mode="normalization1",
  384. inference_backend=organ_backend,
  385. postprocess_cfgs=postprocess_cfgs_organ)
  386. for pr_contour_organ in pr_contours_organ:
  387. # TODO 需要进行病灶检测的脏器类别id,可能为多种脏器,待组成一个list
  388. if pr_contour_organ['label_value'] == 2:
  389. # 获取外接扩展矩形,如若不需要扩充像素,expand_pixel设置为0
  390. organ_bounding_box = get_bounding_box(pr_contour_organ,
  391. im.shape[:2],
  392. expand_pixel=40)
  393. im_lesion = im[organ_bounding_box[0]: organ_bounding_box[1],
  394. organ_bounding_box[2]: organ_bounding_box[3]]
  395. # 裁切好的图像送入病灶检测模型
  396. pr_contours_lesion, pr_focus_mask_lesion, T0, T1 = \
  397. focus_seg(
  398. im_lesion, seg_lesion_model,
  399. image_pre_size=imgsz,
  400. resize_mode="fitLargeSizeAndPad",
  401. preprocess_mode="normalization1",
  402. inference_backend=lesion_backend,
  403. postprocess_cfgs=postprocess_cfgs_lesion)
  404. # draw contours
  405. for pr_contour_lesion in pr_contours_lesion:
  406. im_lesion = draw_contours(im_lesion, pr_contour_lesion, class_name_lesion, fontText,
  407. COLOR_LIST_LESION)
  408. im0 = im0.copy()
  409. im0[organ_bounding_box[0]: organ_bounding_box[1],
  410. organ_bounding_box[2]: organ_bounding_box[3]] = im_lesion
  411. im0 = draw_contours(im0, pr_contour_organ, class_name_organ, fontText, COLOR_LIST_ORGAN)
  412. Time.append(T1) # record inference time
  413. if dataset.mode == 'image':
  414. cv2.imencode(".jpg", im0, [int(cv2.IMWRITE_JPEG_QUALITY), 100])[1].tofile(save_path)
  415. else: # 'video'
  416. if vid_path[dataset.count] != save_path: # new video
  417. vid_path[dataset.count] = save_path
  418. if isinstance(vid_writer[dataset.count], cv2.VideoWriter):
  419. vid_writer[dataset.count].release() # release previous video writer
  420. if vid_cap: # video
  421. fps = vid_cap.get(cv2.CAP_PROP_FPS)
  422. w = im0.shape[1]
  423. h = im0.shape[0]
  424. else:
  425. fps, w, h = 30, im.shape[1], im.shape[0]
  426. save_path = str(Path(save_path).with_suffix('.mp4'))
  427. vid_writer[dataset.count] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
  428. vid_writer[dataset.count].write(im0)
  429. Time.pop(0)
  430. inference_speed = sum(Time) / len(Time)
  431. print('#################################\n'
  432. '平均每张图像推理速度为:{:.6f}\n'.format(inference_speed),
  433. '#################################\n')
  434. def parse_args():
  435. parser = ArgumentParser()
  436. parser.add_argument(
  437. "--organ_model",
  438. help="organ model path",
  439. type=str,
  440. default=r"model.onnx")
  441. parser.add_argument(
  442. "--lesion_model",
  443. help="lesion model path",
  444. type=str,
  445. default=None)
  446. parser.add_argument(
  447. "--source",
  448. help="root path to predict",
  449. type=str,
  450. default=None)
  451. parser.add_argument(
  452. "--imgsz",
  453. help="inference size h,w",
  454. type=tuple,
  455. default=(256, 256))
  456. parser.add_argument(
  457. "--cfg_organ",
  458. help="postprocess configurations",
  459. default=NECKORGAN_CFGS2)
  460. parser.add_argument(
  461. "--cfg_lesion",
  462. help="postprocess configurations",
  463. default=NECKLESION_CFGS)
  464. parser.add_argument(
  465. "--save_dir",
  466. help="directory to save predicting results",
  467. type=str,
  468. default=None)
  469. parser.add_argument(
  470. "--device",
  471. help="select using cuda, cpu, or cpu with one core and one thread",
  472. type=str,
  473. choices=['cuda', 'cpu', 'limit_cpu'],
  474. default="cuda")
  475. parser.add_argument(
  476. "--save_txt",
  477. help="If true, save results to *.txt",
  478. type=bool,
  479. default=False)
  480. return parser.parse_args()
  481. if __name__ == '__main__':
  482. opt = parse_args()
  483. if not isinstance(opt.cfg_organ, list):
  484. opt.cfg_organ = [opt.cfg_organ]
  485. organ_categories = []
  486. for cfg in opt.cfg_organ:
  487. organ_categories += list(cfg.keys())
  488. seed_arr_organ = np.array([range(1, 255, math.ceil(255 / len(organ_categories)))]).astype(np.uint8)
  489. COLOR_LIST_ORGAN = cv2.applyColorMap(seed_arr_organ, cv2.COLORMAP_RAINBOW)[0]
  490. seed_arr_lesion = np.array([range(1, 255, math.ceil(255 / len(opt.cfg_lesion)))]).astype(np.uint8)
  491. COLOR_LIST_LESION = cv2.applyColorMap(seed_arr_lesion, cv2.COLORMAP_RAINBOW)[0]
  492. # Directories
  493. save_dir = Path(str(opt.save_dir)) / 'predict'
  494. save_dir.mkdir(exist_ok=True)
  495. if opt.lesion_model is not None:
  496. # 串联检测,如检测颈动脉+斑块,目前仅支持串联两个模型,而串联三个模型的有待后续使用到再添加
  497. run_pipeline(
  498. opt.organ_model,
  499. opt.lesion_model,
  500. opt.source,
  501. imgsz=opt.imgsz,
  502. postprocess_cfgs_organ=opt.cfg_organ,
  503. postprocess_cfgs_lesion=opt.cfg_lesion,
  504. save_dir=save_dir,
  505. device=opt.device)
  506. else:
  507. # 非串联检测,即单个模型推理
  508. run(
  509. opt.organ_model,
  510. opt.source,
  511. imgsz=opt.imgsz,
  512. postprocess_cfgs=opt.cfg_organ,
  513. save_dir=save_dir,
  514. device=opt.device,
  515. save_txt=opt.save_txt)