main.py 37 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935
  1. import os.path
  2. import sys
  3. from pathlib import Path
  4. from enum import Enum
  5. import clr
  6. import numpy as np
  7. import TrainSdk
  8. from ctypes import *
  9. import json
  10. import copy
  11. import onnxruntime
  12. # 设置当前路径和系统路径
  13. system_path = Path(sys.executable).resolve().parents[0]
  14. current_path = Path(__file__).resolve().parents[0]
  15. network_path = os.path.join(current_path, 'depends\\Networks')
  16. # 导入c#的dll
  17. clr.AddReference('System.Drawing')
  18. clr.AddReference(os.path.join(current_path, 'depends\\AI.Common.dll'))
  19. clr.AddReference(os.path.join(current_path, 'depends\\AI.DiagSystem.dll'))
  20. clr.AddReference(os.path.join(current_path, 'depends\\ImageShowUtilsLib.dll'))
  21. clr.AddReference(os.path.join(current_path, 'depends\\MyocardialSegmenLib.dll'))
  22. from System import Array
  23. import ctypes
  24. from System.Runtime.InteropServices import GCHandle, GCHandleType
  25. from System.Drawing import Bitmap
  26. from System.IO import MemoryStream
  27. from AI.Common import InferenceNetworkUtils, InferenceCore, EnumInferCoreConfigKey, EnumDeviceType
  28. from AI.Common import InferenceNetworkInputImage, RawImage, Rect, Point2D, IDetectedObject
  29. from AI.DiagSystem import *
  30. from AI.DiagSystem.Workers.InferenceNetworks.Onnx import InferNetOnnxLesionDetectBreastBIRads, \
  31. InferNetOnnxOrganDetectorAbdomen, InferNetOnnxScanPartClrBreastAbdomenNeck
  32. from AI.DiagSystem.Workers.InferenceNetworks.Onnx import InferNetOnnxLiverFocalObd, InferNetOnnxLiverDiffuseClr, \
  33. InferNetOnnxLiverFocalSeg, InferNetOnnxLesionContourSegLiver, InferNetOnnxLesionContourSegBreast, \
  34. InferNetOnnxOrganDetectorThyroidCarotidArtery
  35. from ImageShowUtilsLib import RawImageShowUtils
  36. from MyocardialSegmenLib import InferNetOnnxMyocardialSegment
  37. from AI.Common.Tools import UsImageRegionSegHelper
  38. # gt_file_generate import
  39. from gt_file_generate import gtfilegenerate_classifications, gtfilegenerate_classifications_liverdiffuselesionclassifier
  40. from gt_file_generate import gtfilegenerate_objectdetection, gtfilegenerate_objectdetection_liverfocalobd
  41. from gt_file_generate import gtfilegenerate_semantic_segmentation, gtfilegenerate_semantic_segmentation_myocardial, \
  42. gtfilegenerate_semantic_segmentation_liverfocalSeg
  43. # pred_file_generate import
  44. from predict_file_generate import predfilegenerate_classification
  45. from predict_file_generate import predfilegenerate_object_detection
  46. from predict_file_generate import predfilegenerate_semantic_segmentation, \
  47. predfilegenerate_semantic_segmentation_myocardial
  48. # metrics import
  49. from classification_metric import Evaluator_classification
  50. from object_detection_metric import Evaluator_object_detection
  51. from semantic_segmentation_metric import Evaluator_object_semamtic_segmentation
  52. class GtFileGenerateType(Enum):
  53. """
  54. 枚举所有Gt文件的生成方法
  55. """
  56. # 适用于:导航的分类网络,判断乳腺腹部等扫查部分分类网络
  57. gtfilegenerate_classifications = 1
  58. # 适用于:肝脏弥漫性疾病分类
  59. gtfilegenerate_classifications_liverdiffuselesionclassifier = 2
  60. # 适用于: 乳腺检测,
  61. gtfilegenerate_objectdetection = 3
  62. # 适用于: 脏器分割
  63. gtfilegenerate_semantic_segmentation = 4
  64. # 适用于: 心肌的语义分割
  65. gtfilegenerate_semantic_segmentation_myocardial = 5
  66. # 适用于:肝脏局灶性疾病检测
  67. gtfilegenerate_objectdetection_liverfocalobd = 6
  68. # 适用于:肝脏局灶性疾病分割
  69. gtfilegenerate_semantic_segmentation_liverfocalSeg = 7
  70. class PredFileGenerateType(Enum):
  71. """
  72. 枚举所有Pred文件的生成方法
  73. """
  74. predfilegenerate_classification = 1
  75. predfilegenerate_object_detection = 2
  76. predfilegenerate_semantic_segmentation = 3
  77. predfilegenerate_semantic_segmentation_myocardial = 4
  78. class MetricsType(Enum):
  79. """
  80. 枚举所有评价类型,分为分类,检测,语义分割
  81. """
  82. # 分类
  83. classification_metric = 1
  84. # 检测
  85. object_detection_metric = 2
  86. # 语义分割
  87. semantic_segmentation_metric = 3
  88. class ModelName(Enum):
  89. """
  90. 枚举所有模型的名称
  91. """
  92. # 乳腺腹部其他分类模型
  93. InferNetOnnxScanPartClrBreastAbdomenNeck = 1
  94. # 乳腺检测模型
  95. InferNetOnnxLesionDetectBreastBIRads = 2
  96. # 腹部脏器分割模型
  97. InferNetOnnxOrganDetectorAbdomen = 3
  98. # 腹部局灶性检测模型
  99. InferNetOnnxLiverFocalObd = 4
  100. # 腹部弥漫性分类模型
  101. InferNetOnnxLiverDiffuseClr = 5
  102. # 腹部局灶性分割模型
  103. InferNetOnnxLiverFocalSeg = 6
  104. # 心肌分割模型:
  105. InferNetOnnxMyocardialSegment = 7
  106. # 肝脏局灶病灶前后景分割模型
  107. InferNetOnnxLesionContourSegLiver = 8
  108. # 乳腺病灶前后景分割模型
  109. InferNetOnnxLesionContourSegBreast = 9
  110. # 颈部脏器分割模型
  111. InferNetOnnxOrganDetectorThyroidCarotidArtery = 10
  112. class PlatformMetrics:
  113. def __init__(self, token, network_path, is_python_model_onnxruntime, modelname, gtfilegeneratetpye,
  114. predfilegeneratetpye, metricstpye, iscropped,
  115. numcpu, is_crop_region_affect_image_nums, crop_region_label,
  116. needed_imageresults_dict, needed_rois_dict, class_id_map, iou_thres):
  117. """
  118. :param token: 平台所需的token
  119. :param network_path: 测试模型的路径
  120. :param is_python_model_onnxruntime: 是否调用python的onnxruntime
  121. :param modelname: 模型的名称
  122. :param gtfilegeneratetpye:选择所需的gtfile生成方法
  123. :param predfilegeneratetpye:选择所需的predfile生成方法
  124. :param metricstpye:选择所需的metrics方法
  125. :param iscropped:是否裁切
  126. :param numcpu:cpu推理数量
  127. :param is_crop_region_affect_image_nums:
  128. :param crop_region_label:需要crop region的roi标签title
  129. :param needed_imageresults_dict:所需的image标签title 对应 预测模型的label,所形成的字典
  130. :param needed_rois_dict: 所需的roi标签title 对应 预测模型的label,所形成的字典
  131. :param class_id_map: 用于label改变的dict
  132. :param iou_thres: iou阈值,分类时无用
  133. """
  134. self.token = token
  135. self.network_path = network_path
  136. self.is_python_model_onnxruntime = is_python_model_onnxruntime
  137. self.modelname = modelname
  138. self.gtfilegeneratetpye = gtfilegeneratetpye
  139. self.predfilegeneratetpye = predfilegeneratetpye
  140. self.metricstpye = metricstpye
  141. self.iscropped = iscropped
  142. self.numcpu = numcpu
  143. self.is_crop_region_affect_image_nums = is_crop_region_affect_image_nums
  144. self.crop_region_label = crop_region_label
  145. self.needed_imageresults_dict = needed_imageresults_dict
  146. self.needed_rois_dict = needed_rois_dict
  147. self.class_id_map = class_id_map
  148. self.iou_thres = iou_thres
  149. def process(self):
  150. inferNet = self._choose_model_name(self.modelname)
  151. trainedNetwork = InferenceNetworkUtils().ReadNetworkDataFromFile(self.network_path, inferNet.NetworkName,
  152. inferNet.HashCode)
  153. inferenceCore = InferenceCore()
  154. inferenceCore.SetConfig(EnumInferCoreConfigKey.CPU_THREADS_NUM, str(self.numcpu), EnumDeviceType.CPU)
  155. inferNet.LoadNetwork(inferenceCore, EnumDeviceType.CPU, trainedNetwork)
  156. gt_type = self._choose_gt_file_type(self.gtfilegeneratetpye)
  157. pred_type = self._choose_pred_file_type(self.predfilegeneratetpye)
  158. metric_type = self._choose_metrics_type(self.metricstpye)
  159. image_count = TrainSdk.get_test_file_count(self.token)
  160. # 如果是python的onnxruntime,需要解析相应名称的onnx模型,设置session
  161. if self.is_python_model_onnxruntime:
  162. networkname = inferNet.NetworkName.split(".emd")[0] + '.onnx'
  163. trainedNetwork_path = os.path.join(self.network_path, networkname)
  164. onnx_path = os.path.join(os.getcwd(), trainedNetwork_path)
  165. sess_options = onnxruntime.SessionOptions()
  166. sess_options.intra_op_num_threads = self.numcpu
  167. session = onnxruntime.InferenceSession(onnx_path, sess_options, providers=['CPUExecutionProvider'])
  168. input_name = session.get_inputs()[0].name
  169. evaluator = metric_type(self.iou_thres)
  170. for i in range(image_count):
  171. imagedata, labeldata, img_name, _ = TrainSdk.get_test_labeled_file(self.token, i)
  172. stream = MemoryStream()
  173. stream.Write(imagedata, 0, len(imagedata))
  174. bitmap = Bitmap(stream)
  175. # 取图像的width和height
  176. image_width = bitmap.Width
  177. image_height = bitmap.Height
  178. image_size = [image_width, image_height]
  179. rawimage = RawImageShowUtils.BitmapToRawImage(bitmap)
  180. croprect = self._crop_image(rawimage)
  181. label_infos = json.loads(labeldata)
  182. if self.is_crop_region_affect_image_nums:
  183. label_infos_split = self._label_infos_split(label_infos, self.crop_region_label)
  184. else:
  185. label_infos_split = [label_infos]
  186. for label_info_index in range(len(label_infos_split)):
  187. gt_json = gt_type(label_infos_split[label_info_index], image_size, self.needed_imageresults_dict,
  188. self.needed_rois_dict, self.class_id_map)
  189. gt_file = json.loads(gt_json)
  190. inferinput = self._set_input_image(rawimage, label_infos_split[label_info_index], croprect)
  191. inferNet.PreProcess(inferinput) # 调用前处理
  192. # 进行推理
  193. # 推理分两种,一种调用RunModel(),采用c#框架下的onnxruntime;
  194. # 另一种将前处理之后的databuffer做成python所需要的格式,送入python框架下的onnxruntime,然后将结果赋值给_detectedResultData
  195. if self.is_python_model_onnxruntime:
  196. inputvariableshape = list(inferNet._inputVariableShape)
  197. assert len(inputvariableshape) == 4, "onnx模型输入必须为4通道"
  198. # 一般的模型都是将inferNet._moldedImage.DataBuffer放入模型进行推理
  199. # 但是弥漫性疾病分类那边,采用的是四通道,需要重新生成inferNet._inputDataBuffer放入模型进行推理
  200. if inferNet._moldedImage.DataBuffer:
  201. img_in = inferNet._moldedImage.DataBuffer
  202. else:
  203. img_in = inferNet._inputDataBuffer
  204. # img_in2直接转成list再转numpy太慢
  205. src_hndl = GCHandle.Alloc(img_in, GCHandleType.Pinned)
  206. try:
  207. src_ptr = src_hndl.AddrOfPinnedObject().ToInt64()
  208. bufType = ctypes.c_float * len(img_in)
  209. cbuf = bufType.from_address(src_ptr)
  210. img_np = np.frombuffer(cbuf, dtype=cbuf._type_)
  211. finally:
  212. if src_hndl.IsAllocated:
  213. src_hndl.Free()
  214. img_np = img_np.reshape(
  215. inputvariableshape[0], inputvariableshape[1], inputvariableshape[2], inputvariableshape[3])
  216. outputs = session.run(None, {input_name: img_np})[0]
  217. inferNet._detectedResultData = outputs.reshape(1, outputs.size).ravel().tolist()
  218. else:
  219. inferNet.RunModel()
  220. # 调用后处理
  221. idetectedobject = inferNet.PostProcess(inferinput)
  222. # 调用inferNet中的PreProcess(inferinput),RunModel(),PostProcess(inferinput) 等价于调用Process(inferinput)
  223. # idetectedobject = inferNet.Process(inferinput)
  224. pred_json = pred_type(idetectedobject, image_size, self.class_id_map)
  225. pred_file = json.loads(pred_json)
  226. image_index = str(img_name)
  227. evaluator.add_batch(gt_file, pred_file, image_index)
  228. return evaluator
  229. def _crop_image(self, image):
  230. """
  231. 调用c#的裁图函数,得到裁图框
  232. :param image: RawImage图像
  233. :return:
  234. """
  235. rect = Rect(0, 0, image.Width, image.Height)
  236. if self.iscropped:
  237. return rect
  238. else:
  239. res, dst_rect = UsImageRegionSegHelper().CropWithCvCore(image, rect)
  240. if res:
  241. return dst_rect
  242. else:
  243. return rect
  244. def _set_input_image(self, rawimage, label_info, croprect):
  245. """
  246. 生成c#中所需要的InferenceNetworkInputImage
  247. #肝脏弥漫性标注时,会标注一个弥漫性疾病的轮廓,再标注一个肝脏的轮廓的情况
  248. #因此需要max_roi_area,最终的croporgancontours中只会存在一个轮廓
  249. #后续croporgancontours中,如果需要出现多个轮廓,可改写次处
  250. #croporgancontours的结果传入InferenceNetworkInputImage中,
  251. #当_useContoursAsMask为True时,InferenceNetworkUtils.GenMaskOnImageAccordingToContours需要轮廓信息
  252. #GenMaskOnImageAccordingToContours的c++函数,需要支持多轮廓的处理,暂时没有这种情况,未验证
  253. :param rawimage:
  254. :param label_info:
  255. :param croprect:
  256. :return:
  257. """
  258. labeled_result = label_info[0]["FileResultInfos"][0]["LabeledResult"]
  259. use_crop_region = False
  260. croporganrect = Rect.Empty
  261. croporgancontours = Array[Array[Point2D]](())
  262. max_roi_area = 0
  263. for each_roi_label in labeled_result["Rois"]:
  264. roi_cls = each_roi_label["Conclusion"]["Title"]
  265. if roi_cls in self.crop_region_label:
  266. roi_points = each_roi_label["Points"]
  267. x, y, points = [], [], []
  268. for point in roi_points:
  269. point_x = int(point["X"])
  270. point_y = int(point["Y"])
  271. x.append(point_x)
  272. y.append(point_y)
  273. points.append(Point2D(point_x, point_y))
  274. left, right = min(x), max(x)
  275. top, bottom = min(y), max(y)
  276. points_ = Array[Point2D](points)
  277. roi_contour = Array[Array[Point2D]]([points_])
  278. roi_area = (right - left) * (bottom - top)
  279. if roi_area > max_roi_area:
  280. croporganrect = Rect(left, top, right - left, bottom - top)
  281. croporgancontours = roi_contour
  282. max_roi_area = roi_area
  283. use_crop_region = True
  284. if use_crop_region:
  285. inputroi = croporganrect
  286. else:
  287. inputroi = croprect
  288. inferinput = InferenceNetworkInputImage(rawimage, inputroi, 1.0, croporgancontours, croprect)
  289. return inferinput
  290. def _label_infos_split(self, label_info, crop_region_label):
  291. """
  292. 将label_info拆成不同的部分
  293. crop_region_affect_image_nums存在时,
  294. 一张图中,在crop_region_label中如果有多个目标,此时该图像已经不作为一张图使用,
  295. 需要逐个根据crop_region_label,逐个拆分所需的label_info
  296. :param label_info:
  297. :param crop_region_label:
  298. :return:
  299. """
  300. labeled_result = label_info[0]["FileResultInfos"][0]["LabeledResult"]
  301. needed_roi_infos = []
  302. # 过一遍数据,取到所需要的crop_region_roi信息
  303. for each_roi_label in labeled_result["Rois"]:
  304. roi_cls = each_roi_label["Conclusion"]["Title"]
  305. if roi_cls in crop_region_label:
  306. needed_roi_infos.append(each_roi_label)
  307. if len(needed_roi_infos) <= 1:
  308. return [label_info]
  309. else:
  310. label_infos_split = []
  311. for i in range(len(needed_roi_infos)):
  312. new_info = copy.deepcopy(labeled_result)
  313. for j in range(len(needed_roi_infos)):
  314. if j != i:
  315. new_info["Rois"].remove(needed_roi_infos[j])
  316. label_infos_split.append(new_info)
  317. return label_infos_split
  318. @staticmethod
  319. def _choose_model_name(modelname):
  320. """
  321. 根据ModelName中枚举的各个模型名称,选择所对应的推理模型,即选择c#中相应的模型推断class
  322. :param modelname:
  323. :return:
  324. """
  325. # 乳腺检测模型
  326. if modelname == ModelName.InferNetOnnxLesionDetectBreastBIRads:
  327. inferNet = InferNetOnnxLesionDetectBreastBIRads()
  328. elif modelname == ModelName.InferNetOnnxOrganDetectorAbdomen:
  329. inferNet = InferNetOnnxOrganDetectorAbdomen()
  330. # 乳腺腹部其他分类模型
  331. elif modelname == ModelName.InferNetOnnxScanPartClrBreastAbdomenNeck:
  332. inferNet = InferNetOnnxScanPartClrBreastAbdomenNeck()
  333. # 腹部局灶性检测模型
  334. elif modelname == ModelName.InferNetOnnxLiverFocalObd:
  335. inferNet = InferNetOnnxLiverFocalObd()
  336. # 腹部弥漫性分类模型
  337. elif modelname == ModelName.InferNetOnnxLiverDiffuseClr:
  338. inferNet = InferNetOnnxLiverDiffuseClr()
  339. # 腹部局灶性分割模型
  340. elif modelname == ModelName.InferNetOnnxLiverFocalSeg:
  341. inferNet = InferNetOnnxLiverFocalSeg()
  342. # 心肌分割模型:
  343. elif modelname == ModelName.InferNetOnnxMyocardialSegment:
  344. inferNet = InferNetOnnxMyocardialSegment()
  345. elif modelname == ModelName.InferNetOnnxLesionContourSegLiver:
  346. inferNet = InferNetOnnxLesionContourSegLiver()
  347. elif modelname == ModelName.InferNetOnnxLesionContourSegBreast:
  348. inferNet = InferNetOnnxLesionContourSegBreast()
  349. # 颈部脏器分割模型:
  350. elif modelname == ModelName.InferNetOnnxOrganDetectorThyroidCarotidArtery:
  351. inferNet = InferNetOnnxOrganDetectorThyroidCarotidArtery()
  352. else:
  353. raise Exception("Wrong modelname, choose correct one")
  354. return inferNet
  355. @staticmethod
  356. def _choose_gt_file_type(gtfilegeneratetpye):
  357. """
  358. 根据GtFileGenerateType中枚举的各种生成Gt文件函数,选择对应的函数
  359. 函数在gt_file_generate.py中定义
  360. :param gtfilegeneratetpye:
  361. :return:
  362. """
  363. if gtfilegeneratetpye == GtFileGenerateType.gtfilegenerate_classifications_liverdiffuselesionclassifier:
  364. gt_type = gtfilegenerate_classifications_liverdiffuselesionclassifier
  365. elif gtfilegeneratetpye == GtFileGenerateType.gtfilegenerate_classifications:
  366. gt_type = gtfilegenerate_classifications
  367. elif gtfilegeneratetpye == GtFileGenerateType.gtfilegenerate_objectdetection:
  368. gt_type = gtfilegenerate_objectdetection
  369. elif gtfilegeneratetpye == GtFileGenerateType.gtfilegenerate_semantic_segmentation:
  370. gt_type = gtfilegenerate_semantic_segmentation
  371. elif gtfilegeneratetpye == GtFileGenerateType.gtfilegenerate_semantic_segmentation_myocardial:
  372. gt_type = gtfilegenerate_semantic_segmentation_myocardial
  373. elif gtfilegeneratetpye == GtFileGenerateType.gtfilegenerate_objectdetection_liverfocalobd:
  374. gt_type = gtfilegenerate_objectdetection_liverfocalobd
  375. elif gtfilegeneratetpye == GtFileGenerateType.gtfilegenerate_semantic_segmentation_liverfocalSeg:
  376. gt_type = gtfilegenerate_semantic_segmentation_liverfocalSeg
  377. else:
  378. raise Exception("Wrong gtfilegeneratetpye, choose correct one")
  379. return gt_type
  380. @staticmethod
  381. def _choose_pred_file_type(predfilegeneratetpye):
  382. """
  383. 根据GtFileGenerateType中枚举的各种生成Gt文件函数,选择对应的函数
  384. 函数在predict_file_generate.py中定义
  385. :param predfilegeneratetpye:
  386. :return:
  387. """
  388. if predfilegeneratetpye == PredFileGenerateType.predfilegenerate_classification:
  389. pred_type = predfilegenerate_classification
  390. elif predfilegeneratetpye == PredFileGenerateType.predfilegenerate_object_detection:
  391. pred_type = predfilegenerate_object_detection
  392. elif predfilegeneratetpye == PredFileGenerateType.predfilegenerate_semantic_segmentation:
  393. pred_type = predfilegenerate_semantic_segmentation
  394. elif predfilegeneratetpye == PredFileGenerateType.predfilegenerate_semantic_segmentation_myocardial:
  395. pred_type = predfilegenerate_semantic_segmentation_myocardial
  396. else:
  397. raise Exception("Wrong predfilegeneratetpye, choose correct one")
  398. return pred_type
  399. @staticmethod
  400. def _choose_metrics_type(metricstype):
  401. """
  402. 根据MetricsType中枚举的不同类型评价,选择所对应的,暂时只支持分类,检测,语义分割
  403. :param metricstype:
  404. :return:
  405. """
  406. if metricstype == MetricsType.classification_metric:
  407. metrics = Evaluator_classification
  408. elif metricstype == MetricsType.object_detection_metric:
  409. metrics = Evaluator_object_detection
  410. elif metricstype == MetricsType.semantic_segmentation_metric:
  411. metrics = Evaluator_object_semamtic_segmentation
  412. else:
  413. raise Exception("Wrong metrics tpye, choose correct one")
  414. return metrics
  415. if __name__ == '__main__':
  416. # 乳腺检测模型设置
  417. """
  418. token = "4925EC4929684AA0ABB0173B03CFC8FF"
  419. is_python_model_onnxruntime = True
  420. modelname = ModelName.InferNetOnnxLesionDetectBreastBIRads
  421. gtfilegeneratetpye = GtFileGenerateType.gtfilegenerate_objectdetection
  422. metricstpye = MetricsType.object_detection_metric
  423. predfilegeneratetpye = PredFileGenerateType.predfilegenerate_object_detection
  424. iscropped = True
  425. numcpu = 1
  426. is_crop_region_affect_image_nums = False
  427. crop_region_label = []
  428. needed_imageresults_dict = {
  429. "未见明显异常": 0
  430. }
  431. needed_rois_dict = {
  432. '脂肪瘤': 1,
  433. 'BI-RADS 2': 2,
  434. 'BI-RADS 3': 3,
  435. 'BI-RADS 4a': 4,
  436. 'BI-RADS 4b': 5,
  437. 'BI-RADS 4c': 6,
  438. 'BI-RADS 5': 7,
  439. }
  440. class_id_map = {
  441. 0: 0,
  442. 1: 1,
  443. 2: 1,
  444. 3: 1,
  445. 4: 2,
  446. 5: 2,
  447. 6: 2,
  448. 7: 2
  449. }
  450. """
  451. # 脏器分割模型设置
  452. '''
  453. token = "4925EC4929684AA0ABB0173B03CFC8FF"
  454. is_python_model_onnxruntime = True
  455. modelname = ModelName.InferNetOnnxOrganDetectorAbdomen
  456. gtfilegeneratetpye = GtFileGenerateType.gtfilegenerate_semantic_segmentation
  457. metricstpye = MetricsType.semantic_segmentation_metric
  458. predfilegeneratetpye = PredFileGenerateType.predfilegenerate_semantic_segmentation
  459. iscropped = True
  460. numcpu = 1
  461. is_crop_region_affect_image_nums = False
  462. crop_region_label = []
  463. needed_imageresults_dict = {}
  464. needed_rois_dict = {
  465. '肝': 3,
  466. '胆囊胆道': 4,
  467. '肾脏': 5,
  468. '脾脏': 6,
  469. }
  470. class_id_map = {
  471. 3: 3,
  472. 4: 4,
  473. 5: 5,
  474. 6: 6,
  475. }
  476. '''
  477. # 颈部脏器分割模型设置
  478. """
  479. token = "4925EC4929684AA0ABB0173B03CFC8FF"
  480. is_python_model_onnxruntime = True
  481. modelname = ModelName.InferNetOnnxOrganDetectorThyroidCarotidArtery
  482. gtfilegeneratetpye = GtFileGenerateType.gtfilegenerate_semantic_segmentation
  483. metricstpye = MetricsType.semantic_segmentation_metric
  484. predfilegeneratetpye = PredFileGenerateType.predfilegenerate_semantic_segmentation
  485. iscropped = True
  486. numcpu = 1
  487. is_crop_region_affect_image_nums = False
  488. crop_region_label = []
  489. needed_imageresults_dict = {"无可标项": 0}
  490. needed_rois_dict = {
  491. "甲状腺横切": 9,
  492. "甲状腺纵切": 9,
  493. "颈动脉短轴": 8,
  494. "颈动脉长轴": 8,
  495. }
  496. class_id_map = {
  497. 0: 0,
  498. 8: 2,
  499. 9: 1,
  500. }
  501. """
  502. # 心肌分割模型设置
  503. # 心肌那边c#输出来的label只有0,预测结果那边设置了加一,因此有轮廓对应1
  504. '''
  505. token = "4925EC4929684AA0ABB0173B03CFC8FF"
  506. is_python_model_onnxruntime = True
  507. modelname = ModelName.InferNetOnnxMyocardialSegment
  508. gtfilegeneratetpye = GtFileGenerateType.gtfilegenerate_semantic_segmentation_myocardial
  509. metricstpye = MetricsType.semantic_segmentation_metric
  510. predfilegeneratetpye = PredFileGenerateType.predfilegenerate_semantic_segmentation_myocardial
  511. iscropped = True
  512. numcpu = 1
  513. is_crop_region_affect_image_nums = False
  514. crop_region_label = []
  515. needed_imageresults_dict = {'非B模式图': 0, "其他切面": 0}
  516. needed_rois_dict = {
  517. '四腔心_心肌': 1,
  518. '心尖段_心肌': 1,
  519. '三腔心_心肌': 1,
  520. '基底段_心肌': 1,
  521. '乳头肌水平切面_心肌': 1,
  522. '两腔心_心肌': 1
  523. }
  524. class_id_map = {
  525. 0: 0,
  526. 1: 1
  527. }
  528. '''
  529. # 肝脏局灶性检测
  530. """
  531. # gt file生成时如果只出现弥漫性疾病,该图为背景;既出现局灶性疾病,又出现弥漫性疾病,为有疾病图像
  532. # gt file生成时如果只出现未见明显异常,为背景图像;如果出现未见明显异常,和局灶性疾病,该图像无效
  533. # gt file生成时如果出现弥漫性疾病和未见明显异常,则为背景图像
  534. token = "4925EC4929684AA0ABB0173B03CFC8FF"
  535. is_python_model_onnxruntime = True
  536. modelname = ModelName.InferNetOnnxLiverFocalObd
  537. gtfilegeneratetpye = GtFileGenerateType.gtfilegenerate_objectdetection_liverfocalobd
  538. metricstpye = MetricsType.object_detection_metric
  539. predfilegeneratetpye = PredFileGenerateType.predfilegenerate_object_detection
  540. iscropped = True
  541. numcpu = 1
  542. is_crop_region_affect_image_nums = False
  543. crop_region_label = ["肝脏未见明显异常",
  544. "肝脏有局灶性疾病/区域",
  545. "脂肪肝声像图改变",
  546. "肝脏弥漫性病变声像图改变",
  547. "肝硬化声像图改变"]
  548. needed_imageresults_dict = {}
  549. needed_rois_dict = {
  550. "肝脏未见明显异常": 0,
  551. "肝脏有局灶性疾病/区域": 0,
  552. "脂肪肝声像图改变": 0,
  553. "肝脏弥漫性病变声像图改变": 0,
  554. "肝硬化声像图改变": 0,
  555. "肝内强回声灶": 1,
  556. "肝血管瘤声像图改变": 2,
  557. "肝囊肿": 3,
  558. "肝癌可能": 4
  559. }
  560. class_id_map = {
  561. 0: 0,
  562. 1: 1,
  563. 2: 2,
  564. 3: 3,
  565. 4: 4
  566. }
  567. """
  568. # 肝脏局灶性分割
  569. """
  570. # gt file生成时如果只出现弥漫性疾病,该图为背景;既出现局灶性疾病,又出现弥漫性疾病,为有疾病图像
  571. # gt file生成时如果只出现未见明显异常,为背景图像;如果出现未见明显异常,和局灶性疾病,该图像无效
  572. # gt file生成时如果出现弥漫性疾病和未见明显异常,则为背景图像
  573. token = "4925EC4929684AA0ABB0173B03CFC8FF"
  574. is_python_model_onnxruntime = True
  575. modelname = ModelName.InferNetOnnxLiverFocalSeg
  576. gtfilegeneratetpye = GtFileGenerateType.gtfilegenerate_semantic_segmentation_liverfocalSeg
  577. metricstpye = MetricsType.semantic_segmentation_metric
  578. predfilegeneratetpye = PredFileGenerateType.predfilegenerate_semantic_segmentation
  579. iscropped = True
  580. numcpu = 1
  581. is_crop_region_affect_image_nums = False
  582. crop_region_label = ["肝脏未见明显异常",
  583. "肝脏有局灶性疾病/区域",
  584. "脂肪肝声像图改变",
  585. "肝脏弥漫性病变声像图改变",
  586. "肝硬化声像图改变",
  587. "肝血吸虫病声像图改变"]
  588. needed_imageresults_dict = {}
  589. needed_rois_dict = {
  590. "肝脏未见明显异常": 0,
  591. "肝脏有局灶性疾病/区域": 0,
  592. "脂肪肝声像图改变": 0,
  593. "肝脏弥漫性病变声像图改变": 0,
  594. "肝硬化声像图改变": 0,
  595. "肝血吸虫病声像图改变": 0,
  596. "肝内强回声灶": 1,
  597. "肝血管瘤声像图改变": 2,
  598. "肝囊肿": 3,
  599. "肝癌可能": 4
  600. }
  601. class_id_map = {
  602. 0: 0,
  603. 1: 1,
  604. 2: 2,
  605. 3: 3,
  606. 4: 4
  607. }
  608. """
  609. # 肝脏弥漫性疾病分类
  610. # 肝脏弥漫性分类,比较特殊,因为标注问题,会出现:标注一个肝的轮廓,再标注一个弥漫性疾病的情况,
  611. # 分类标签无法确定,根据rois标签的title,将'肝脏未见明显异常'删除
  612. """
  613. token = "4925EC4929684AA0ABB0173B03CFC8FF"
  614. is_python_model_onnxruntime = True
  615. modelname = ModelName.InferNetOnnxLiverDiffuseClr
  616. gtfilegeneratetpye = GtFileGenerateType.gtfilegenerate_classifications_liverdiffuselesionclassifier
  617. metricstpye = MetricsType.classification_metric
  618. predfilegeneratetpye = PredFileGenerateType.predfilegenerate_classification
  619. iscropped = True
  620. numcpu = 1
  621. is_crop_region_affect_image_nums = False
  622. crop_region_label = ["肝脏未见明显异常",
  623. "多囊肝声像图改变",
  624. "脂肪肝声像图改变",
  625. "肝脏弥漫性病变声像图改变",
  626. "肝硬化声像图改变"]
  627. needed_imageresults_dict = {}
  628. needed_rois_dict = {
  629. "肝脏未见明显异常": 0,
  630. "多囊肝声像图改变": 8,
  631. "脂肪肝声像图改变": 5,
  632. "肝脏弥漫性病变声像图改变": 6,
  633. "肝硬化声像图改变": 7,
  634. }
  635. class_id_map = {
  636. 0: 0,
  637. 5: 5,
  638. 6: 6,
  639. 7: 7,
  640. 8: 8
  641. }
  642. """
  643. # 肝脏局灶性疾病前后景分割
  644. '''
  645. # 前后景分割时,
  646. # needed_rois_dict中的
  647. # {
  648. # "肝脏未见明显异常":0,
  649. # "肝脏有局灶性疾病/区域":0,
  650. # "脂肪肝声像图改变":0,
  651. # "肝脏弥漫性病变声像图改变":0,
  652. # "肝硬化声像图改变":0,
  653. # "肝血吸虫病声像图改变":0,
  654. # } 并没有用,因为没有给到病灶轮廓,会拿原图进行推理,没有意义
  655. # 建议设置的原因:gt file生成时辅助判断
  656. # 比如:如果出现未见明显异常,和局灶性疾病,该图像应该是无效的,没有设置上述的needed_rois_dict,则会当作有效进行推理
  657. token = "4925EC4929684AA0ABB0173B03CFC8FF"
  658. is_python_model_onnxruntime = False
  659. modelname = ModelName.InferNetOnnxLesionContourSegLiver
  660. gtfilegeneratetpye = GtFileGenerateType.gtfilegenerate_semantic_segmentation_liverfocalSeg
  661. metricstpye = MetricsType.semantic_segmentation_metric
  662. predfilegeneratetpye = PredFileGenerateType.predfilegenerate_semantic_segmentation
  663. iscropped = True
  664. numcpu = 1
  665. is_crop_region_affect_image_nums = True
  666. crop_region_label = ["肝内强回声灶", "肝血管瘤声像图改变", "肝囊肿", "肝癌可能"]
  667. needed_imageresults_dict = {}
  668. needed_rois_dict = {
  669. "肝脏未见明显异常": 0,
  670. "肝脏有局灶性疾病/区域": 0,
  671. "脂肪肝声像图改变": 0,
  672. "肝脏弥漫性病变声像图改变": 0,
  673. "肝硬化声像图改变": 0,
  674. "肝血吸虫病声像图改变": 0,
  675. "肝内强回声灶": 1,
  676. "肝血管瘤声像图改变": 1,
  677. "肝囊肿": 1,
  678. "肝癌可能": 1
  679. }
  680. class_id_map = {
  681. 0: 0,
  682. 1: 1,
  683. }
  684. '''
  685. # 乳腺病灶前后景分割
  686. '''
  687. # 前后景分割时,needed_imageresults_dict = {"未见明显异常": 0}并没有用,因为没有给到病灶轮廓,
  688. # 会拿原图进行推理,没有意义
  689. token = "4925EC4929684AA0ABB0173B03CFC8FF"
  690. is_python_model_onnxruntime = False
  691. modelname = ModelName.InferNetOnnxLesionContourSegBreast
  692. gtfilegeneratetpye = GtFileGenerateType.gtfilegenerate_semantic_segmentation
  693. metricstpye = MetricsType.semantic_segmentation_metric
  694. predfilegeneratetpye = PredFileGenerateType.predfilegenerate_semantic_segmentation
  695. iscropped = True
  696. numcpu = 1
  697. is_crop_region_affect_image_nums = True
  698. crop_region_label = ["脂肪瘤", "BI-RADS 2", 'BI-RADS 3', 'BI-RADS 4a','BI-RADS 4b', 'BI-RADS 4c', 'BI-RADS 5']
  699. needed_imageresults_dict = {"未见明显异常": 0}
  700. needed_rois_dict = {
  701. '脂肪瘤': 1,
  702. 'BI-RADS 2': 1,
  703. 'BI-RADS 3': 1,
  704. 'BI-RADS 4a': 1,
  705. 'BI-RADS 4b': 1,
  706. 'BI-RADS 4c': 1,
  707. 'BI-RADS 5': 1,
  708. }
  709. class_id_map = {
  710. 0:0,
  711. 1:1,
  712. }
  713. '''
  714. # 乳腺腹部颈部等脏器分类 导航扫查分类分类 两者一致
  715. # 不同测试时,只需修改模型
  716. # 暂时导航那边的模型导入dll未更新,后续更新可加入
  717. # 没有拿到脏器分类的数据,写了导航那边的标签,测试脏器分类的模型
  718. """
  719. token = "4925EC4929684AA0ABB0173B03CFC8FF"
  720. is_python_model_onnxruntime = True
  721. modelname = ModelName.InferNetOnnxScanPartClrBreastAbdomenNeck
  722. gtfilegeneratetpye = GtFileGenerateType.gtfilegenerate_classifications
  723. metricstpye = MetricsType.classification_metric
  724. predfilegeneratetpye = PredFileGenerateType.predfilegenerate_classification
  725. iscropped = True
  726. numcpu = 1
  727. is_crop_region_affect_image_nums = False
  728. crop_region_label = []
  729. needed_imageresults_dict = {
  730. "肝脏剑突纵扫标准面": 0,
  731. "肝脏剑突纵扫偏上": 1,
  732. "肝脏剑突纵扫偏下": 1,
  733. "肝脏剑突纵扫偏左": 1,
  734. "肝脏剑突纵扫偏右": 2,
  735. "肝脏剑突纵扫无效": 3}
  736. needed_rois_dict = {}
  737. class_id_map = {
  738. 0: 0,
  739. 1: 1,
  740. 2: 2,
  741. 3: 3
  742. }
  743. """
  744. test_result = PlatformMetrics(
  745. token=token,
  746. network_path=network_path,
  747. is_python_model_onnxruntime=is_python_model_onnxruntime,
  748. modelname=modelname,
  749. gtfilegeneratetpye=gtfilegeneratetpye,
  750. predfilegeneratetpye=predfilegeneratetpye,
  751. metricstpye=metricstpye,
  752. iscropped=iscropped,
  753. numcpu=numcpu,
  754. is_crop_region_affect_image_nums=is_crop_region_affect_image_nums,
  755. crop_region_label=crop_region_label,
  756. needed_imageresults_dict=needed_imageresults_dict,
  757. needed_rois_dict=needed_rois_dict,
  758. class_id_map=class_id_map,
  759. iou_thres=0.5,
  760. )
  761. evaluator = test_result.process()
  762. print("--------------------------------------可用图像report--------------------------------------")
  763. wrong_file = evaluator.wrong_file
  764. print("gt file有问题的image:{}".format(wrong_file['gt_wrong']))
  765. print("pred file有问题的image:{}".format(wrong_file['pred_wrong']))
  766. all_image_dict = evaluator.all_image_dict
  767. print("如果is_crop_region_affect_image_nums为True,所有的图像数量大于等于原图的数量")
  768. print('所有可用的图像数量:{}'.format(all_image_dict['images_all_nums']))
  769. print("--------------------------------------背景图像report--------------------------------------")
  770. background_images_results_count = evaluator.background_images_results_count
  771. for key in background_images_results_count.keys():
  772. print(key + ':' + str(background_images_results_count[key]))
  773. background_images_results = evaluator.background_images_results
  774. for key in background_images_results.keys():
  775. print(key + ':' + str(background_images_results[key]))
  776. print("--------------------------------------非背景图像report--------------------------------------")
  777. print('非背景图像数量:{}'.format(
  778. all_image_dict['images_all_nums'] - background_images_results_count['background_images_all_nums']))
  779. metricsPerClass = evaluator.generate_metrics()
  780. for mc in metricsPerClass:
  781. c = mc['class']
  782. precision = mc['precision']
  783. recall = mc['recall']
  784. # ipre = mc['interpolated precision']
  785. # irec = mc['interpolated recall']
  786. total_positives = mc['total positives']
  787. total_TP = mc['total TP']
  788. total_FP = mc['total FP']
  789. precision_all = 0 if (total_TP + total_FP) == 0 else total_TP / (total_TP + total_FP)
  790. recall_all = 0 if total_positives == 0 else total_TP / total_positives
  791. # Print AP per class
  792. print('Label:%s, total_TP: %d, total_FP: %d, total_positives_gt: %d, precision: %f, recall: %f '
  793. % (c, total_TP, total_FP, total_positives, precision_all, recall_all))
  794. try:
  795. average_precision = mc['AP']
  796. print('Label:%s, AP: %f, ' % (c, average_precision))
  797. except:
  798. continue
  799. all_no_background_images_fp_results = evaluator.all_no_background_images_fp_results
  800. for key in all_no_background_images_fp_results.keys():
  801. each_result = all_no_background_images_fp_results[key]
  802. print('Label:' + key + ',FP对应的image:' + str(each_result))
  803. all_no_background_images_pos_results = evaluator.all_no_background_images_pos_results
  804. all_no_background_images_tp_results = evaluator.all_no_background_images_tp_results
  805. for key in all_no_background_images_pos_results.keys():
  806. each_pos_results = all_no_background_images_pos_results[key]
  807. if key in all_no_background_images_tp_results.keys():
  808. each_tp_results = all_no_background_images_tp_results[key]
  809. else:
  810. each_tp_results = []
  811. if key in all_no_background_images_fp_results.keys():
  812. each_fp_results = all_no_background_images_fp_results[key]
  813. else:
  814. each_fp_results = []
  815. each_fn_results = []
  816. for elem in each_pos_results:
  817. if elem not in each_tp_results:
  818. each_fn_results.append(elem)
  819. print('Label:' + key + ',FN对应的image:' + str(each_fn_results))