mydataset.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434
  1. # -*- coding: utf-8 -*-
  2. # @ File : mydataset.py
  3. # @ Author : Guido LuXiaohao
  4. # @ Date : 2021/8/17
  5. # @ Software : PyCharm
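  6. # @ Description: Data loaders: LoadImages iterates images/videos from disk; VinnoDatasetAIPlatformGeneral builds segmentation samples (image, mask) fetched through TrainSDK.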
  7. import glob
  8. import json
  9. import os
  10. from ctypes import *
  11. from pathlib import Path
  12. from typing import Dict, List, Optional, Tuple
  13. import albumentations as A
  14. import cv2
  15. import numpy as np
  16. import torch
  17. from torch.utils.data import Dataset
  18. from dataset import TrainSDK
  19. from dataset.utils import interpolate_multicategory_mask, resize_LongestMaxSize
  20. from structures import PixelData, SegDataSample
  21. from utils import registry
  22. # Parameters
  23. IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm' # include image suffixes
  24. VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv' # include video suffixes
  25. @registry.DATASETS.register_module
  26. class LoadImages:
  27. # VINNO image dataloader, support for image, directory or video.
  28. def __init__(self, path, transforms=None, vid_stride=1):
  29. files = []
  30. for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
  31. p = str(Path(p).resolve())
  32. if '*' in p:
  33. files.extend(sorted(glob.glob(p, recursive=True))) # glob
  34. elif os.path.isdir(p):
  35. files.extend(sorted(glob.glob(os.path.join(p, '*.*')))) # dir
  36. elif os.path.isfile(p):
  37. files.append(p) # file
  38. else:
  39. raise FileNotFoundError(f'{p} does not exist')
  40. images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
  41. videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
  42. ni, nv = len(images), len(videos)
  43. self.nf = ni + nv # number of files
  44. self.files = images + videos
  45. self.video_flag = [False] * ni + [True] * nv
  46. self.mode = 'image'
  47. self.transforms = transforms
  48. self.vid_stride = vid_stride # video frame-rate stride
  49. if any(videos):
  50. self._new_video(videos[0]) # new video
  51. else:
  52. self.cap = None
  53. dll_path = Path(__file__).resolve().parents[1] / "depends/CvCropUltImgRegion.dll"
  54. crop_image_fn = cdll.LoadLibrary(str(dll_path)).CropImage
  55. crop_image_fn.restype = c_bool
  56. self.crop_image_fn = crop_image_fn
  57. def __iter__(self):
  58. self.count = 0
  59. return self
  60. def __next__(self):
  61. if self.count == self.nf:
  62. raise StopIteration
  63. path = self.files[self.count]
  64. if self.video_flag[self.count]:
  65. # Read video
  66. self.mode = 'video'
  67. for _ in range(self.vid_stride):
  68. self.cap.grab()
  69. ret_val, im0 = self.cap.retrieve()
  70. while not ret_val:
  71. self.count += 1
  72. self.cap.release()
  73. if self.count == self.nf: # last video
  74. raise StopIteration
  75. path = self.files[self.count]
  76. self._new_video(path)
  77. ret_val, im0 = self.cap.read()
  78. h, w, c = im0.shape # c=BGR
  79. img_info = np.array((w, h, w * c, c), dtype=np.intc)
  80. rect = np.zeros(4, dtype=np.intc)
  81. crop_results = self.crop_image_fn(
  82. np.ascontiguousarray(im0).ctypes,
  83. np.ascontiguousarray(img_info).ctypes,
  84. np.ascontiguousarray(rect).ctypes)
  85. if not crop_results or \
  86. (rect[0] < 0) or \
  87. (rect[1] < 0) or \
  88. (rect[0] + rect[2] >= w) or \
  89. (rect[1] + rect[3] >= h):
  90. rect[0] = 0
  91. rect[1] = 0
  92. rect[2] = w
  93. rect[3] = h
  94. im0 = im0[rect[1]: rect[1] + rect[3], rect[0]: rect[0] + rect[2]]
  95. self.frame += 1
  96. s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
  97. else:
  98. # Read image
  99. self.count += 1
  100. im0 = cv2.imdecode(np.fromfile(path, dtype=np.uint8), cv2.IMREAD_COLOR) # BGR
  101. assert im0 is not None, f'Image not found {path}'
  102. s = f'image {self.count}/{self.nf} {path}'
  103. if self.transforms:
  104. im = self.transforms(im0) # transforms
  105. else:
  106. im = np.ascontiguousarray(im0)
  107. return path, im, im0, self.cap, s
  108. def _new_video(self, path):
  109. # Create a new video capture object
  110. self.frame = 0
  111. self.cap = cv2.VideoCapture(path)
  112. self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
  113. self.orientation = int(self.cap.get(cv2.CAP_PROP_ORIENTATION_META)) # rotation degrees
  114. def __len__(self):
  115. return self.nf # number of files
  116. @registry.DATASETS.register_module
  117. class VinnoDatasetAIPlatformGeneral(Dataset):
  118. def __init__(self,
  119. token: str,
  120. class_index_map: Dict,
  121. input_shape: Tuple[int],
  122. data_index_list: List[int] = None,
  123. crop_class_index: Dict = None,
  124. extra_data_path: str = None,
  125. expansion_mode: str = "fixed_pixel",
  126. expansion_range: List[int] = [10, 60],
  127. limit_crop: bool = True,
  128. limit_crop_draw: bool = False,
  129. augmentation=None,
  130. dataset_mode: str = "train",
  131. preprocess_mode: str = None,
  132. resize_mode: str = "normal",
  133. **kwargs):
  134. super().__init__()
  135. self.token = token
  136. class_index_map = class_index_map[0] if not isinstance(
  137. class_index_map, dict) else class_index_map
  138. self.class_index_map = class_index_map
  139. if isinstance(data_index_list[0], list):
  140. # 如果按片段划分,将图像索引合并
  141. data_index_list = [
  142. img_id
  143. for img_ids in data_index_list
  144. for img_id in img_ids]
  145. self.data_index_list = data_index_list
  146. self.input_shape = input_shape
  147. self.crop_class_index = crop_class_index
  148. self.expansion_mode = expansion_mode
  149. self.expansion_range = expansion_range
  150. self.limit_crop = limit_crop
  151. self.limit_crop_draw = limit_crop_draw
  152. if isinstance(augmentation, list):
  153. augmentation = A.Compose(augmentation)
  154. self.augmentation = augmentation
  155. self.dataset_mode = dataset_mode
  156. self.preprocess_mode = preprocess_mode
  157. self.resize_mode = resize_mode
  158. class_names = list(class_index_map.keys())
  159. if '背景' not in class_names:
  160. class_names.insert(0, '背景')
  161. self.class_names = class_names
  162. # 读取额外轮廓信息
  163. self.parse_extra_data(extra_data_path=extra_data_path)
  164. # 加载数据集索引值,统计标签数量
  165. self.load_data_list()
  166. self.indices = list(range(len(self.loop_list)))
  167. def parse_extra_data(self, extra_data_path: Optional[str] = None):
  168. """当ROI需要裁切时需要提供额外的轮廓。本函数从文本文件里读取额外轮廓的信息,
  169. 并以字典的形式将数据赋给属性``self.extra_data``.
  170. Args:
  171. extra_data_path (str, optional): Path to the file that stores extra data. Defaults to None.
  172. """
  173. if extra_data_path is not None:
  174. extra_data = dict()
  175. extra_data_list = [x.strip('\n') for x in
  176. open(extra_data_path, "r", encoding="utf-8").readlines()]
  177. for i_extra_data in extra_data_list:
  178. if i_extra_data != '\n': # 跳过空行
  179. extra_data.update(json.loads(i_extra_data))
  180. else:
  181. continue
  182. self.extra_data = extra_data
  183. self.with_extra_data = True
  184. else:
  185. self.with_extra_data = False
  186. def load_data_list(self):
  187. """处理数据集加载所需要的数据索引值, 并统计数据集中标签的数量。
  188. 结果分别保存在属性``self.loop_list``和``self.cat_stats``.
  189. """
  190. num_classes = len(self.class_names)
  191. loop_list = [] # 用于记录图像索引和需要裁切的roi索引
  192. cat_stats = [] # 用于记录数据集中标签的数量
  193. for data_index in self.data_index_list: # 遍历所有数据,完成数据的筛选,扩增等操作 ------------------------------
  194. if self.dataset_mode == "train":
  195. _, label_data, file_name, _ = \
  196. TrainSDK.get_labeled_file(self.token, data_index)
  197. else:
  198. _, label_data, file_name, _ = \
  199. TrainSDK.get_test_labeled_file(self.token, data_index)
  200. # 读取标注信息
  201. label = json.loads(label_data)
  202. roi_results = label[0]['FileResultInfos'][0]['LabeledResult']['Rois']
  203. # 是否需要额外的外部轮廓(如肝局灶识别,肝轮廓由外部模型预测得到)
  204. if self.with_extra_data:
  205. if file_name + '.txt' in self.extra_data.keys():
  206. extra_img_roi = self.extra_data[file_name + '.txt']['Rois']
  207. roi_results.extend(extra_img_roi)
  208. else:
  209. # 未找到该图的轮廓信息,直接使用原图
  210. pass
  211. # 记录单张图所要识别的目标个数
  212. cat_stats_single = [0] * num_classes
  213. for roi in roi_results:
  214. roi_cls = roi['Conclusion']['Title']
  215. if roi_cls in self.class_names:
  216. roi_index = self.class_names.index(roi_cls)
  217. cat_stats_single[roi_index] += 1
  218. if sum(cat_stats_single) == 0: # 如果图上没有可识别目标,背景类设为1
  219. cat_stats_single[0] = 1
  220. cat_stats.append(cat_stats_single)
  221. if not self.crop_class_index:
  222. # 无需裁时
  223. loop_list.append([data_index, -1]) # [[data_index, crop_roi_index]]
  224. else:
  225. # 需要裁切
  226. max_crop_contour_len = 0
  227. max_crop_index = -1 # 记录轮廓最大roi的索引
  228. crop_index_list = []
  229. for roi_idx, roi in enumerate(roi_results): # 遍历轮廓列表以找到裁图roi ------------------------------
  230. roi_cls = roi['Conclusion']['Title']
  231. if roi_cls in self.crop_class_index.keys():
  232. if self.limit_crop: # 限制每张图裁切部位个数,每张图只裁切轮廓最大的
  233. roi_pts = roi['Points'] # contour points
  234. if max_crop_contour_len < len(roi_pts): # 此处假定: 轮廓上的点等间距采 -> 轮廓越大,轮廓点数量越多
  235. max_crop_contour_len = len(roi_pts)
  236. max_crop_index = roi_idx
  237. else: # 不限制每张图裁切部位个数
  238. crop_index_list.append(roi_idx)
  239. else:
  240. continue
  241. # 此时该张图无裁切部位标注,而此条件内有需要裁切部位,所以过滤掉该张图,后面也可视情况是否需要这种图作为负样本
  242. if max_crop_index == -1 and crop_index_list == []:
  243. continue
  244. if self.limit_crop:
  245. # 图像索引;裁切roi索引(-1代表不裁切);单张图索要识别的目标个数
  246. if max_crop_index == -1:
  247. # 需要裁切而,该张图无裁切部位,跳过
  248. continue
  249. loop_list.append([data_index, max_crop_index])
  250. else:
  251. for crop_index in crop_index_list:
  252. # 图像索引;裁切roi索引(-1代表不裁切)
  253. loop_list.append([data_index, crop_index])
  254. self.loop_list = loop_list
  255. self.cat_stats = cat_stats
  256. class_counts = np.array(cat_stats).sum(0)
  257. self.class_counts = class_counts
  258. print(f"数据集{self.dataset_mode}中各类别实例数量为:")
  259. for idx, name in enumerate(self.class_names):
  260. print(f"\t类别 {name} 的实例数为 {class_counts[idx]}")
  261. def __len__(self):
  262. return len(self.loop_list)
  263. def __getitem__(self, index):
  264. # 获取图像-裁图roi索引对
  265. real_item = self.loop_list[self.indices[index]]
  266. if self.dataset_mode == "train":
  267. image_data, label_data, file_name, file_isVideo = TrainSDK.get_labeled_file(self.token, real_item[0])
  268. else:
  269. image_data, label_data, file_name, file_isVideo = TrainSDK.get_test_labeled_file(self.token, real_item[0])
  270. image = np.frombuffer(image_data, dtype=np.uint8)
  271. image = cv2.imdecode(image, cv2.IMREAD_COLOR) # B, G, R
  272. ori_shape = crop_size = image.shape[:2]
  273. image_height = image.shape[0]
  274. image_width = image.shape[1]
  275. mask = np.zeros((image_height, image_width), dtype=np.uint8)
  276. label = json.loads(label_data)
  277. contour_list = label[0]['FileResultInfos'][0]['LabeledResult']['Rois']
  278. # 是否需要额外的外部轮廓(如肝局灶识别,肝轮廓由外部模型预测得到)
  279. if self.with_extra_data:
  280. if file_name + '.txt' in self.extra_data.keys():
  281. extra_img_roi = self.extra_data[file_name + '.txt']['Rois']
  282. contour_list.extend(extra_img_roi)
  283. else:
  284. # 未找到该图的轮廓信息,直接使用原图
  285. pass
  286. each_image_label_list = {} # 用于记录当前图像所需处理roi的标注信息
  287. for ctr_idx, contour in enumerate(contour_list): # 遍历当前图像上的所有roi -----------------------------------
  288. label_name = contour['Conclusion']['Title']
  289. ctr_pts = contour['Points']
  290. each_label_dict = {"contour": "", "label_value": "", "area": ""}
  291. if ctr_pts is not None:
  292. # 将轮廓坐标处理成drawContours可画的形式
  293. contours_cv = np.zeros((len(ctr_pts), 1, 2), dtype=np.int32)
  294. for i in range(len(ctr_pts)):
  295. contours_cv[i] = [int(ctr_pts[i]["X"]), int(ctr_pts[i]["Y"])]
  296. if label_name in self.class_index_map.keys():
  297. each_label_dict["contour"] = contours_cv
  298. each_label_dict["label_value"] = self.class_index_map[label_name]
  299. each_label_dict["area"] = cv2.contourArea(contours_cv)
  300. each_image_label_list.update({ctr_idx: each_label_dict})
  301. if self.limit_crop_draw and real_item[1] != -1:
  302. # 是否限制只画一个目标(一般在做病灶前后景分割的时候,限制)
  303. mask = cv2.drawContours(
  304. mask.copy(),
  305. contours=[each_image_label_list[real_item[1]]["contour"]],
  306. contourIdx=-1,
  307. color=(each_image_label_list[real_item[1]]["label_value"]),
  308. thickness=cv2.FILLED)
  309. else: # 不做限制,画出全部需要的轮廓
  310. each_image_label_list = list(each_image_label_list.values())
  311. each_image_label_list.sort(key=lambda x: x["area"], reverse=True)
  312. for idx in each_image_label_list:
  313. mask = cv2.drawContours(mask.copy(), [idx["contour"]], -1, (idx["label_value"]), cv2.FILLED)
  314. if real_item[1] != -1:
  315. crop_roi = contour_list[real_item[1]]
  316. crop_contours = crop_roi['Points']
  317. crop_contours_cv = np.zeros((len(crop_contours), 1, 2), dtype=np.int32)
  318. for i in range(len(crop_contours)):
  319. crop_contours_cv[i] = [int(crop_contours[i]["X"]), int(crop_contours[i]["Y"])]
  320. crop_boundingbox = list(cv2.boundingRect(crop_contours_cv))
  321. # 外扩固定像素裁切原图img图、掩膜mask图
  322. if self.expansion_mode == "fixed_pixel":
  323. expansion_pixel = np.random.choice(list(range(self.expansion_range[0], self.expansion_range[1])))
  324. elif self.expansion_mode == "fixed_rate":
  325. expansion_rate = np.random.choice(
  326. list(np.arange(self.expansion_range[0], self.expansion_range[1], 0.2)))
  327. expansion_pixel = int(expansion_rate * np.maximum(int(crop_boundingbox[2]), int(crop_boundingbox[3])))
  328. else:
  329. expansion_pixel = 0
  330. if crop_boundingbox != "0, 0, 0, 0":
  331. # 防止超出边界的边框出现
  332. left = int(crop_boundingbox[0]) if int(crop_boundingbox[0]) < image_width else image_width
  333. top = int(crop_boundingbox[1]) if int(crop_boundingbox[1]) < image_height else image_height
  334. w = int(crop_boundingbox[2])
  335. h = int(crop_boundingbox[3])
  336. # 防止边界溢出
  337. crop_left = np.maximum(left - expansion_pixel, 0)
  338. crop_top = np.maximum(top - expansion_pixel, 0)
  339. crop_w = np.minimum(left + w + expansion_pixel, image_width) - crop_left
  340. crop_h = np.minimum(top + h + expansion_pixel, image_height) - crop_top
  341. image = image[crop_top: crop_top + crop_h, crop_left: crop_left + crop_w]
  342. mask = mask[crop_top: crop_top + crop_h, crop_left: crop_left + crop_w]
  343. crop_size = image.shape[:2]
  344. # 尺寸放缩到固定大小
  345. if self.resize_mode == "normal":
  346. image = cv2.resize(image, (self.input_shape[1], self.input_shape[0]), interpolation=cv2.INTER_LINEAR)
  347. mask = interpolate_multicategory_mask(mask, (self.input_shape[1], self.input_shape[0]))
  348. elif self.resize_mode == "fitLargeSizeAndPad":
  349. image = resize_LongestMaxSize(image, self.input_shape[0], resize_mode=cv2.INTER_LINEAR)
  350. mask = resize_LongestMaxSize(mask, self.input_shape[0], resize_mode=cv2.INTER_NEAREST)
  351. else:
  352. pass
  353. if self.augmentation:
  354. sample = self.augmentation(image=image, mask=mask)
  355. image, mask = sample['image'], sample['mask']
  356. if self.preprocess_mode == "normalization1":
  357. dst_image = image.astype("float32") / 255
  358. elif self.preprocess_mode == "normalization2":
  359. dst_image = ((image.astype("float32") / 255) - 0.5) * 2.0
  360. elif self.preprocess_mode == "decentralization":
  361. img_mean = np.mean(image)
  362. img_std = np.std(image)
  363. dst_image = (image - img_mean + 10e-7) / (img_std + 10e-7)
  364. else:
  365. dst_image = image
  366. dst_image = dst_image.transpose((2, 0, 1)) # [H, W, C] -> [C, H, W]
  367. data_sample = SegDataSample(**{
  368. 'gt_sem_seg':
  369. PixelData(**{'data': torch.from_numpy(mask.astype('int64'))}),
  370. 'metainfo': {
  371. 'ori_shape': tuple(ori_shape),
  372. 'img_shape': tuple(self.input_shape[:2]),
  373. 'crop_size': tuple(crop_size)},
  374. 'label': label_data
  375. })
  376. return dst_image, data_sample