123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434 |
- # -*- coding: utf-8 -*-
- # @ File : mydataset.py
- # @ Author : Guido LuXiaohao
- # @ Date : 2021/8/17
- # @ Software : PyCharm
- # @ Description: Image/video loader (LoadImages) and AI-platform segmentation dataset (VinnoDatasetAIPlatformGeneral).
- import glob
- import json
- import os
- from ctypes import *
- from pathlib import Path
- from typing import Dict, List, Optional, Tuple
- import albumentations as A
- import cv2
- import numpy as np
- import torch
- from torch.utils.data import Dataset
- from dataset import TrainSDK
- from dataset.utils import interpolate_multicategory_mask, resize_LongestMaxSize
- from structures import PixelData, SegDataSample
- from utils import registry
# Parameters
# File suffixes (lower-case, without the leading dot) recognised as images.
IMG_FORMATS = (
    'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm',
)
# File suffixes (lower-case, without the leading dot) recognised as videos.
VID_FORMATS = (
    'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv',
)
@registry.DATASETS.register_module
class LoadImages:
    """VINNO image dataloader with support for images, directories, glob patterns and videos.

    The instance is its own iterator. Each step returns the tuple
    ``(path, im, im0, cap, s)``:

    * ``path`` - path of the file the frame/image came from.
    * ``im``   - ``transforms(im0)`` when transforms are set, otherwise a
      contiguous copy/view of ``im0``.
    * ``im0``  - the decoded BGR image; for videos it is the frame after the
      DLL-based region crop.
    * ``cap``  - the current ``cv2.VideoCapture`` (``None`` if no video was given).
    * ``s``    - a short progress string.

    All images are iterated first, then all videos.
    """

    def __init__(self, path, transforms=None, vid_stride=1):
        """Collect the input files and prepare the video reader / crop function.

        Args:
            path: A single path, or a list/tuple of paths. Each entry may be a
                glob pattern (contains ``*``), a directory or a file.
            transforms: Optional callable applied to every decoded image.
            vid_stride: Number of frames grabbed per returned video frame.

        Raises:
            FileNotFoundError: If an entry is neither a pattern, a directory
                nor an existing file.
        """
        sources = sorted(path) if isinstance(path, (list, tuple)) else [path]

        files = []
        for source in sources:
            resolved = str(Path(source).resolve())
            if '*' in resolved:
                matches = sorted(glob.glob(resolved, recursive=True))  # glob
            elif os.path.isdir(resolved):
                matches = sorted(glob.glob(os.path.join(resolved, '*.*')))  # dir
            elif os.path.isfile(resolved):
                matches = [resolved]  # file
            else:
                raise FileNotFoundError(f'{resolved} does not exist')
            files.extend(matches)

        # Split the collected files by suffix; anything else is ignored.
        images, videos = [], []
        for name in files:
            suffix = name.split('.')[-1].lower()
            if suffix in IMG_FORMATS:
                images.append(name)
            if suffix in VID_FORMATS:
                videos.append(name)

        self.nf = len(images) + len(videos)  # number of files
        self.files = images + videos
        self.video_flag = [False] * len(images) + [True] * len(videos)
        self.mode = 'image'
        self.transforms = transforms
        self.vid_stride = vid_stride  # video frame-rate stride

        if any(videos):
            self._new_video(videos[0])  # open the first video up front
        else:
            self.cap = None

        # The crop routine lives in a native library shipped next to the package.
        dll_path = Path(__file__).resolve().parents[1] / "depends/CvCropUltImgRegion.dll"
        library = cdll.LoadLibrary(str(dll_path))
        self.crop_image_fn = library.CropImage
        self.crop_image_fn.restype = c_bool

    def __iter__(self):
        """Reset the file cursor and return ``self`` as the iterator."""
        self.count = 0
        return self

    def __next__(self):
        """Return the next image or video frame as ``(path, im, im0, cap, s)``.

        Raises:
            StopIteration: When every file has been consumed.
            AssertionError: When an image file cannot be decoded.
        """
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        if not self.video_flag[self.count]:
            # Read image (np.fromfile + imdecode rather than cv2.imread).
            self.count += 1
            im0 = cv2.imdecode(np.fromfile(path, dtype=np.uint8), cv2.IMREAD_COLOR)  # BGR
            assert im0 is not None, f'Image not found {path}'
            s = f'image {self.count}/{self.nf} {path}'
        else:
            # Read video
            self.mode = 'video'
            for _ in range(self.vid_stride):
                self.cap.grab()
            ret_val, im0 = self.cap.retrieve()
            # Current video exhausted: move on to the next one until a frame is read.
            while not ret_val:
                self.count += 1
                self.cap.release()
                if self.count == self.nf:  # last video
                    raise StopIteration
                path = self.files[self.count]
                self._new_video(path)
                ret_val, im0 = self.cap.read()

            im0 = self._crop_frame(im0)
            self.frame += 1
            s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '

        im = self.transforms(im0) if self.transforms else np.ascontiguousarray(im0)
        return path, im, im0, self.cap, s

    def _crop_frame(self, frame):
        """Crop ``frame`` to the rectangle reported by the native ``CropImage`` routine.

        The full frame is returned when the routine reports failure or when the
        rectangle is negative or reaches/exceeds the frame border.
        """
        height, width, channels = frame.shape  # channels are BGR
        img_info = np.array((width, height, width * channels, channels), dtype=np.intc)
        rect = np.zeros(4, dtype=np.intc)  # filled in place by the DLL as (x, y, w, h)
        succeeded = self.crop_image_fn(
            np.ascontiguousarray(frame).ctypes,
            np.ascontiguousarray(img_info).ctypes,
            np.ascontiguousarray(rect).ctypes)

        rect_is_usable = (
            succeeded
            and rect[0] >= 0
            and rect[1] >= 0
            and rect[0] + rect[2] < width
            and rect[1] + rect[3] < height
        )
        if not rect_is_usable:
            rect[0] = 0
            rect[1] = 0
            rect[2] = width
            rect[3] = height

        left, top, rect_w, rect_h = rect
        return frame[top: top + rect_h, left: left + rect_w]

    def _new_video(self, path):
        """Open ``path`` as the current video and reset the frame counters."""
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        total_frames = self.cap.get(cv2.CAP_PROP_FRAME_COUNT)
        self.frames = int(total_frames / self.vid_stride)
        self.orientation = int(self.cap.get(cv2.CAP_PROP_ORIENTATION_META))  # rotation degrees

    def __len__(self):
        """Return the number of files (images plus videos)."""
        return self.nf
@registry.DATASETS.register_module
class VinnoDatasetAIPlatformGeneral(Dataset):
    """Segmentation dataset whose samples are fetched through ``TrainSDK``.

    Each sample is described by a ``[data_index, crop_roi_index]`` pair stored
    in ``self.loop_list``. ``crop_roi_index`` is ``-1`` when the full image is
    used; otherwise it is the index of the ROI (inside the image's ROI list)
    around which the image and mask are cropped.

    ``__getitem__`` returns ``(image, data_sample)`` where ``image`` is a
    ``[C, H, W]`` numpy array and ``data_sample`` is a ``SegDataSample`` holding
    the mask, shape metadata and the raw label payload.
    """

    def __init__(self,
                 token: str,
                 class_index_map: Dict,
                 input_shape: Tuple[int],
                 data_index_list: List[int] = None,
                 crop_class_index: Dict = None,
                 extra_data_path: str = None,
                 expansion_mode: str = "fixed_pixel",
                 expansion_range: Optional[List[int]] = None,
                 limit_crop: bool = True,
                 limit_crop_draw: bool = False,
                 augmentation=None,
                 dataset_mode: str = "train",
                 preprocess_mode: str = None,
                 resize_mode: str = "normal",
                 **kwargs):
        """Store the configuration and build the sample index.

        Args:
            token: Token forwarded to every ``TrainSDK`` call.
            class_index_map: Mapping from ROI title to mask value. A non-dict
                value is treated as a sequence whose first element is the map.
            input_shape: Target shape; index 0 is used as height and index 1
                as width when resizing.
            data_index_list: Data indices, either flat or as a list of lists
                (one list per clip), which is flattened.
            crop_class_index: Mapping whose keys are the ROI titles to crop
                around. Falsy means "do not crop".
            extra_data_path: Optional text file with one JSON object per line
                providing additional ROIs per file.
            expansion_mode: ``"fixed_pixel"``, ``"fixed_rate"`` or anything
                else for no expansion of the crop box.
            expansion_range: ``[low, high)`` range the expansion is sampled
                from. Defaults to ``[10, 60]``.
            limit_crop: If true, only the crop ROI with the most contour
                points is used per image.
            limit_crop_draw: If true (and the sample is cropped), only the
                crop ROI itself is drawn into the mask.
            augmentation: An albumentations transform, or a list of transforms
                that is wrapped in ``A.Compose``.
            dataset_mode: ``"train"`` reads via ``get_labeled_file``; any other
                value reads via ``get_test_labeled_file``.
            preprocess_mode: ``"normalization1"``, ``"normalization2"``,
                ``"decentralization"`` or anything else for no scaling.
            resize_mode: ``"normal"``, ``"fitLargeSizeAndPad"`` or anything
                else for no resizing.
            **kwargs: Accepted and ignored.

        Raises:
            TypeError: If ``data_index_list`` is ``None``.
        """
        super().__init__()
        if data_index_list is None:
            raise TypeError("data_index_list is required and must be a list of data indices")

        self.token = token
        if not isinstance(class_index_map, dict):
            class_index_map = class_index_map[0]
        self.class_index_map = class_index_map

        # When the indices are grouped per clip, merge them into one flat list.
        # The emptiness check avoids an IndexError on an empty list.
        if len(data_index_list) > 0 and isinstance(data_index_list[0], list):
            data_index_list = [
                img_id
                for img_ids in data_index_list
                for img_id in img_ids]
        self.data_index_list = data_index_list

        self.input_shape = input_shape
        self.crop_class_index = crop_class_index
        self.expansion_mode = expansion_mode
        # A fresh list per instance instead of a shared mutable default.
        self.expansion_range = [10, 60] if expansion_range is None else expansion_range
        self.limit_crop = limit_crop
        self.limit_crop_draw = limit_crop_draw
        if isinstance(augmentation, list):
            augmentation = A.Compose(augmentation)
        self.augmentation = augmentation
        self.dataset_mode = dataset_mode
        self.preprocess_mode = preprocess_mode
        self.resize_mode = resize_mode

        # The background class ('背景') always occupies a slot in class_names.
        class_names = list(class_index_map.keys())
        if '背景' not in class_names:
            class_names.insert(0, '背景')
        self.class_names = class_names

        # Read the extra contour information.
        self.parse_extra_data(extra_data_path=extra_data_path)
        # Build the sample index and count the labels.
        self.load_data_list()
        self.indices = list(range(len(self.loop_list)))

    def parse_extra_data(self, extra_data_path: Optional[str] = None):
        """Load extra contours that are needed when ROIs have to be cropped.

        The file is read line by line; every non-blank line must be a JSON
        object, and all objects are merged into ``self.extra_data``.
        ``self.with_extra_data`` records whether a path was given.

        Args:
            extra_data_path (str, optional): Path to the file that stores extra data. Defaults to None.
        """
        self.extra_data = {}
        self.with_extra_data = extra_data_path is not None
        if not self.with_extra_data:
            return

        extra_data = {}
        with open(extra_data_path, "r", encoding="utf-8") as extra_file:
            for raw_line in extra_file:
                line = raw_line.strip()
                if not line:  # skip blank lines
                    continue
                extra_data.update(json.loads(line))
        self.extra_data = extra_data

    def _fetch_labeled_file(self, data_index):
        """Return the 4-tuple ``TrainSDK`` yields for ``data_index`` in the current mode."""
        if self.dataset_mode == "train":
            return TrainSDK.get_labeled_file(self.token, data_index)
        return TrainSDK.get_test_labeled_file(self.token, data_index)

    def _load_rois(self, label_data, file_name):
        """Parse the label payload and return its ROI list, plus extra ROIs if available."""
        label = json.loads(label_data)
        rois = label[0]['FileResultInfos'][0]['LabeledResult']['Rois']
        # Some tasks need contours from an external source (looked up by
        # '<file_name>.txt'). When none is found the image is used as is.
        if self.with_extra_data:
            extra_key = file_name + '.txt'
            if extra_key in self.extra_data:
                rois.extend(self.extra_data[extra_key]['Rois'])
        return rois

    @staticmethod
    def _points_to_contour(points):
        """Convert a list of ``{"X": ..., "Y": ...}`` points to an ``(N, 1, 2)`` int32 array."""
        coords = [[int(point["X"]), int(point["Y"])] for point in points]
        return np.array(coords, dtype=np.int32).reshape(-1, 1, 2)

    def _select_crop_indices(self, rois):
        """Return the ROI indices to crop around for one image (possibly empty)."""
        candidates = [
            roi_idx for roi_idx, roi in enumerate(rois)
            if roi['Conclusion']['Title'] in self.crop_class_index]
        if not self.limit_crop:
            # No limit: every matching ROI yields its own sample.
            return candidates

        # Limit to one crop per image: keep the ROI with the most contour
        # points (assumes evenly spaced points, so more points means a larger
        # contour). The first ROI wins ties; ROIs without points never qualify.
        best_index, best_len = -1, 0
        for roi_idx in candidates:
            num_points = len(rois[roi_idx]['Points'])
            if num_points > best_len:
                best_index, best_len = roi_idx, num_points
        return [] if best_index == -1 else [best_index]

    def load_data_list(self):
        """Build the sample index and count labels per class.

        Results are stored in ``self.loop_list`` (``[data_index, crop_roi_index]``
        pairs), ``self.cat_stats`` (one per-class count row per data index) and
        ``self.class_counts`` (column sums of ``cat_stats``).
        """
        num_classes = len(self.class_names)
        class_position = {name: idx for idx, name in enumerate(self.class_names)}
        loop_list = []  # image index / crop ROI index pairs
        cat_stats = []  # per-image label counts

        for data_index in self.data_index_list:
            _, label_data, file_name, _ = self._fetch_labeled_file(data_index)
            roi_results = self._load_rois(label_data, file_name)

            # Count the recognisable targets on this image.
            cat_stats_single = [0] * num_classes
            for roi in roi_results:
                position = class_position.get(roi['Conclusion']['Title'])
                if position is not None:
                    cat_stats_single[position] += 1
            if sum(cat_stats_single) == 0:
                # Nothing recognisable on the image: count it as background.
                cat_stats_single[0] = 1
            # NOTE(review): a row is appended for every data index, including
            # images dropped below, so cat_stats is not aligned with loop_list.
            cat_stats.append(cat_stats_single)

            if not self.crop_class_index:
                loop_list.append([data_index, -1])  # no cropping
                continue

            # Cropping requested: images without a usable crop ROI are skipped.
            for crop_index in self._select_crop_indices(roi_results):
                loop_list.append([data_index, crop_index])

        self.loop_list = loop_list
        self.cat_stats = cat_stats
        if cat_stats:
            class_counts = np.array(cat_stats).sum(0)
        else:
            # Keep a per-class vector even for an empty dataset.
            class_counts = np.zeros(num_classes, dtype=int)
        self.class_counts = class_counts

        print(f"数据集{self.dataset_mode}中各类别实例数量为:")
        for idx, name in enumerate(self.class_names):
            print(f"\t类别 {name} 的实例数为 {class_counts[idx]}")

    def __len__(self):
        """Return the number of samples (entries of ``self.loop_list``)."""
        return len(self.loop_list)

    def _render_mask(self, mask_shape, contour_list, crop_roi_index):
        """Draw the labelled contours into a new uint8 mask of ``mask_shape``."""
        mask = np.zeros(mask_shape, dtype=np.uint8)

        # Collect the drawable ROIs: they need points and a known class.
        drawable = {}
        for ctr_idx, contour in enumerate(contour_list):
            label_name = contour['Conclusion']['Title']
            ctr_pts = contour['Points']
            if ctr_pts is None or label_name not in self.class_index_map:
                continue
            contour_cv = self._points_to_contour(ctr_pts)
            drawable[ctr_idx] = {
                "contour": contour_cv,
                "label_value": self.class_index_map[label_name],
                "area": cv2.contourArea(contour_cv),
            }

        if self.limit_crop_draw and crop_roi_index != -1:
            # Draw only the crop target. Raises KeyError when that ROI is not
            # drawable (no points, or its title is not in class_index_map).
            target = drawable[crop_roi_index]
            return cv2.drawContours(
                mask,
                contours=[target["contour"]],
                contourIdx=-1,
                color=(target["label_value"]),
                thickness=cv2.FILLED)

        # Draw everything, largest area first so smaller ROIs stay visible.
        for item in sorted(drawable.values(), key=lambda entry: entry["area"], reverse=True):
            mask = cv2.drawContours(mask, [item["contour"]], -1, (item["label_value"]), cv2.FILLED)
        return mask

    def _sample_expansion_pixel(self, bounding_box):
        """Sample how many pixels the crop box is expanded by on every side."""
        if self.expansion_mode == "fixed_pixel":
            return np.random.choice(list(range(self.expansion_range[0], self.expansion_range[1])))
        if self.expansion_mode == "fixed_rate":
            expansion_rate = np.random.choice(
                list(np.arange(self.expansion_range[0], self.expansion_range[1], 0.2)))
            return int(expansion_rate * np.maximum(int(bounding_box[2]), int(bounding_box[3])))
        return 0

    def _crop_to_roi(self, image, mask, crop_points):
        """Crop ``image`` and ``mask`` to the expanded bounding box of ``crop_points``."""
        image_height, image_width = image.shape[:2]
        bounding_box = list(cv2.boundingRect(self._points_to_contour(crop_points)))
        # Sampled before the degenerate-box check so the random stream is
        # consumed exactly once per cropped sample.
        expansion_pixel = self._sample_expansion_pixel(bounding_box)
        if bounding_box == [0, 0, 0, 0]:
            # Degenerate box: nothing to crop around, keep the full image.
            return image, mask

        # Keep the box origin inside the image.
        left = min(int(bounding_box[0]), image_width)
        top = min(int(bounding_box[1]), image_height)
        box_w = int(bounding_box[2])
        box_h = int(bounding_box[3])
        # Clamp the expanded box to the image borders.
        crop_left = max(left - expansion_pixel, 0)
        crop_top = max(top - expansion_pixel, 0)
        crop_right = min(left + box_w + expansion_pixel, image_width)
        crop_bottom = min(top + box_h + expansion_pixel, image_height)
        return (image[crop_top:crop_bottom, crop_left:crop_right],
                mask[crop_top:crop_bottom, crop_left:crop_right])

    def __getitem__(self, index):
        """Load, crop, resize, augment and normalise one sample.

        Returns:
            tuple: ``(dst_image, data_sample)`` - the ``[C, H, W]`` image array
            and a ``SegDataSample`` with ``gt_sem_seg``, ``metainfo``
            (``ori_shape``, ``img_shape``, ``crop_size``) and ``label``.
        """
        # Image index / crop ROI index pair.
        data_index, crop_roi_index = self.loop_list[self.indices[index]]
        image_data, label_data, file_name, _ = self._fetch_labeled_file(data_index)

        image = cv2.imdecode(np.frombuffer(image_data, dtype=np.uint8), cv2.IMREAD_COLOR)  # B, G, R
        ori_shape = crop_size = image.shape[:2]

        contour_list = self._load_rois(label_data, file_name)
        mask = self._render_mask(image.shape[:2], contour_list, crop_roi_index)

        if crop_roi_index != -1:
            image, mask = self._crop_to_roi(image, mask, contour_list[crop_roi_index]['Points'])
            crop_size = image.shape[:2]

        # Resize to the fixed network input size.
        if self.resize_mode == "normal":
            target_size = (self.input_shape[1], self.input_shape[0])
            image = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)
            mask = interpolate_multicategory_mask(mask, target_size)
        elif self.resize_mode == "fitLargeSizeAndPad":
            image = resize_LongestMaxSize(image, self.input_shape[0], resize_mode=cv2.INTER_LINEAR)
            mask = resize_LongestMaxSize(mask, self.input_shape[0], resize_mode=cv2.INTER_NEAREST)

        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        if self.preprocess_mode == "normalization1":
            dst_image = image.astype("float32") / 255  # scale to [0, 1]
        elif self.preprocess_mode == "normalization2":
            dst_image = ((image.astype("float32") / 255) - 0.5) * 2.0  # scale to [-1, 1]
        elif self.preprocess_mode == "decentralization":
            # Per-image standardisation; the epsilon guards against a zero std.
            img_mean = np.mean(image)
            img_std = np.std(image)
            dst_image = (image - img_mean + 10e-7) / (img_std + 10e-7)
        else:
            dst_image = image
        dst_image = dst_image.transpose((2, 0, 1))  # [H, W, C] -> [C, H, W]

        data_sample = SegDataSample(
            gt_sem_seg=PixelData(data=torch.from_numpy(mask.astype('int64'))),
            metainfo={
                'ori_shape': tuple(ori_shape),
                'img_shape': tuple(self.input_shape[:2]),
                'crop_size': tuple(crop_size)},
            label=label_data)
        return dst_image, data_sample
|