dataset.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445
  1. # -*- coding: utf-8 -*-
  2. '''
  3. @Time : 2020/05/06 21:09
  4. @Author : Tianxiaomo
  5. @File : dataset.py
  6. @Notice :
  7. @Modification :
  8. @Author :
  9. @Time :
  10. @Detail :
  11. '''
  12. import os
  13. import random
  14. import sys
  15. import cv2
  16. import numpy as np
  17. import torch
  18. from torch.utils.data.dataset import Dataset
  19. def rand_uniform_strong(min, max):
  20. if min > max:
  21. swap = min
  22. min = max
  23. max = swap
  24. return random.random() * (max - min) + min
  25. def rand_scale(s):
  26. scale = rand_uniform_strong(1, s)
  27. if random.randint(0, 1) % 2:
  28. return scale
  29. return 1. / scale
  30. def rand_precalc_random(min, max, random_part):
  31. if max < min:
  32. swap = min
  33. min = max
  34. max = swap
  35. return (random_part * (max - min)) + min
  36. def fill_truth_detection(bboxes, num_boxes, classes, flip, dx, dy, sx, sy, net_w, net_h):
  37. if bboxes.shape[0] == 0:
  38. return bboxes, 10000
  39. np.random.shuffle(bboxes)
  40. bboxes[:, 0] -= dx
  41. bboxes[:, 2] -= dx
  42. bboxes[:, 1] -= dy
  43. bboxes[:, 3] -= dy
  44. bboxes[:, 0] = np.clip(bboxes[:, 0], 0, sx)
  45. bboxes[:, 2] = np.clip(bboxes[:, 2], 0, sx)
  46. bboxes[:, 1] = np.clip(bboxes[:, 1], 0, sy)
  47. bboxes[:, 3] = np.clip(bboxes[:, 3], 0, sy)
  48. out_box = list(np.where(((bboxes[:, 1] == sy) & (bboxes[:, 3] == sy)) |
  49. ((bboxes[:, 0] == sx) & (bboxes[:, 2] == sx)) |
  50. ((bboxes[:, 1] == 0) & (bboxes[:, 3] == 0)) |
  51. ((bboxes[:, 0] == 0) & (bboxes[:, 2] == 0)))[0])
  52. list_box = list(range(bboxes.shape[0]))
  53. for i in out_box:
  54. list_box.remove(i)
  55. bboxes = bboxes[list_box]
  56. if bboxes.shape[0] == 0:
  57. return bboxes, 10000
  58. bboxes = bboxes[np.where((bboxes[:, 4] < classes) & (bboxes[:, 4] >= 0))[0]]
  59. if bboxes.shape[0] > num_boxes:
  60. bboxes = bboxes[:num_boxes]
  61. min_w_h = np.array([bboxes[:, 2] - bboxes[:, 0], bboxes[:, 3] - bboxes[:, 1]]).min()
  62. bboxes[:, 0] *= (net_w / sx)
  63. bboxes[:, 2] *= (net_w / sx)
  64. bboxes[:, 1] *= (net_h / sy)
  65. bboxes[:, 3] *= (net_h / sy)
  66. if flip:
  67. temp = net_w - bboxes[:, 0]
  68. bboxes[:, 0] = net_w - bboxes[:, 2]
  69. bboxes[:, 2] = temp
  70. return bboxes, min_w_h
  71. def rect_intersection(a, b):
  72. minx = max(a[0], b[0])
  73. miny = max(a[1], b[1])
  74. maxx = min(a[2], b[2])
  75. maxy = min(a[3], b[3])
  76. return [minx, miny, maxx, maxy]
  77. def image_data_augmentation(mat, w, h, pleft, ptop, swidth, sheight, flip, dhue, dsat, dexp, gaussian_noise, blur,
  78. truth):
  79. try:
  80. img = mat
  81. oh, ow, _ = img.shape
  82. pleft, ptop, swidth, sheight = int(pleft), int(ptop), int(swidth), int(sheight)
  83. # crop
  84. src_rect = [pleft, ptop, swidth + pleft, sheight + ptop] # x1,y1,x2,y2
  85. img_rect = [0, 0, ow, oh]
  86. new_src_rect = rect_intersection(src_rect, img_rect) # 交集
  87. dst_rect = [max(0, -pleft), max(0, -ptop), max(0, -pleft) + new_src_rect[2] - new_src_rect[0],
  88. max(0, -ptop) + new_src_rect[3] - new_src_rect[1]]
  89. # cv2.Mat sized
  90. if (src_rect[0] == 0 and src_rect[1] == 0 and src_rect[2] == img.shape[0] and src_rect[3] == img.shape[1]):
  91. sized = cv2.resize(img, (w, h), cv2.INTER_LINEAR)
  92. else:
  93. cropped = np.zeros([sheight, swidth, 3])
  94. cropped[:, :, ] = np.mean(img, axis=(0, 1))
  95. cropped[dst_rect[1]:dst_rect[3], dst_rect[0]:dst_rect[2]] = \
  96. img[new_src_rect[1]:new_src_rect[3], new_src_rect[0]:new_src_rect[2]]
  97. # resize
  98. sized = cv2.resize(cropped, (w, h), cv2.INTER_LINEAR)
  99. # flip
  100. if flip:
  101. # cv2.Mat cropped
  102. sized = cv2.flip(sized, 1) # 0 - x-axis, 1 - y-axis, -1 - both axes (x & y)
  103. # HSV augmentation
  104. # cv2.COLOR_BGR2HSV, cv2.COLOR_RGB2HSV, cv2.COLOR_HSV2BGR, cv2.COLOR_HSV2RGB
  105. if dsat != 1 or dexp != 1 or dhue != 0:
  106. if img.shape[2] >= 3:
  107. hsv_src = cv2.cvtColor(sized.astype(np.float32), cv2.COLOR_RGB2HSV) # RGB to HSV
  108. hsv = cv2.split(hsv_src)
  109. hsv[1] *= dsat
  110. hsv[2] *= dexp
  111. hsv[0] += 179 * dhue
  112. hsv_src = cv2.merge(hsv)
  113. sized = np.clip(cv2.cvtColor(hsv_src, cv2.COLOR_HSV2RGB), 0, 255) # HSV to RGB (the same as previous)
  114. else:
  115. sized *= dexp
  116. if blur:
  117. if blur == 1:
  118. dst = cv2.GaussianBlur(sized, (17, 17), 0)
  119. # cv2.bilateralFilter(sized, dst, 17, 75, 75)
  120. else:
  121. ksize = (blur / 2) * 2 + 1
  122. dst = cv2.GaussianBlur(sized, (ksize, ksize), 0)
  123. if blur == 1:
  124. img_rect = [0, 0, sized.cols, sized.rows]
  125. for b in truth:
  126. left = (b.x - b.w / 2.) * sized.shape[1]
  127. width = b.w * sized.shape[1]
  128. top = (b.y - b.h / 2.) * sized.shape[0]
  129. height = b.h * sized.shape[0]
  130. roi(left, top, width, height)
  131. roi = roi & img_rect
  132. dst[roi[0]:roi[0] + roi[2], roi[1]:roi[1] + roi[3]] = sized[roi[0]:roi[0] + roi[2],
  133. roi[1]:roi[1] + roi[3]]
  134. sized = dst
  135. if gaussian_noise:
  136. noise = np.array(sized.shape)
  137. gaussian_noise = min(gaussian_noise, 127)
  138. gaussian_noise = max(gaussian_noise, 0)
  139. cv2.randn(noise, 0, gaussian_noise) # mean and variance
  140. sized = sized + noise
  141. except:
  142. print("OpenCV can't augment image: " + str(w) + " x " + str(h))
  143. sized = mat
  144. return sized
  145. def filter_truth(bboxes, dx, dy, sx, sy, xd, yd):
  146. bboxes[:, 0] -= dx
  147. bboxes[:, 2] -= dx
  148. bboxes[:, 1] -= dy
  149. bboxes[:, 3] -= dy
  150. bboxes[:, 0] = np.clip(bboxes[:, 0], 0, sx)
  151. bboxes[:, 2] = np.clip(bboxes[:, 2], 0, sx)
  152. bboxes[:, 1] = np.clip(bboxes[:, 1], 0, sy)
  153. bboxes[:, 3] = np.clip(bboxes[:, 3], 0, sy)
  154. out_box = list(np.where(((bboxes[:, 1] == sy) & (bboxes[:, 3] == sy)) |
  155. ((bboxes[:, 0] == sx) & (bboxes[:, 2] == sx)) |
  156. ((bboxes[:, 1] == 0) & (bboxes[:, 3] == 0)) |
  157. ((bboxes[:, 0] == 0) & (bboxes[:, 2] == 0)))[0])
  158. list_box = list(range(bboxes.shape[0]))
  159. for i in out_box:
  160. list_box.remove(i)
  161. bboxes = bboxes[list_box]
  162. bboxes[:, 0] += xd
  163. bboxes[:, 2] += xd
  164. bboxes[:, 1] += yd
  165. bboxes[:, 3] += yd
  166. return bboxes
  167. def blend_truth_mosaic(out_img, img, bboxes, w, h, cut_x, cut_y, i_mixup,
  168. left_shift, right_shift, top_shift, bot_shift):
  169. left_shift = min(left_shift, w - cut_x)
  170. top_shift = min(top_shift, h - cut_y)
  171. right_shift = min(right_shift, cut_x)
  172. bot_shift = min(bot_shift, cut_y)
  173. if i_mixup == 0:
  174. bboxes = filter_truth(bboxes, left_shift, top_shift, cut_x, cut_y, 0, 0)
  175. out_img[:cut_y, :cut_x] = img[top_shift:top_shift + cut_y, left_shift:left_shift + cut_x]
  176. if i_mixup == 1:
  177. bboxes = filter_truth(bboxes, cut_x - right_shift, top_shift, w - cut_x, cut_y, cut_x, 0)
  178. out_img[:cut_y, cut_x:] = img[top_shift:top_shift + cut_y, cut_x - right_shift:w - right_shift]
  179. if i_mixup == 2:
  180. bboxes = filter_truth(bboxes, left_shift, cut_y - bot_shift, cut_x, h - cut_y, 0, cut_y)
  181. out_img[cut_y:, :cut_x] = img[cut_y - bot_shift:h - bot_shift, left_shift:left_shift + cut_x]
  182. if i_mixup == 3:
  183. bboxes = filter_truth(bboxes, cut_x - right_shift, cut_y - bot_shift, w - cut_x, h - cut_y, cut_x, cut_y)
  184. out_img[cut_y:, cut_x:] = img[cut_y - bot_shift:h - bot_shift, cut_x - right_shift:w - right_shift]
  185. return out_img, bboxes
  186. def draw_box(img, bboxes):
  187. for b in bboxes:
  188. img = cv2.rectangle(img, (b[0], b[1]), (b[2], b[3]), (0, 255, 0), 2)
  189. return img
  190. class Yolo_dataset(Dataset):
  191. def __init__(self, lable_path, cfg, train=True):
  192. super(Yolo_dataset, self).__init__()
  193. if cfg.mixup == 2:
  194. print("cutmix=1 - isn't supported for Detector")
  195. raise
  196. elif cfg.mixup == 2 and cfg.letter_box:
  197. print("Combination: letter_box=1 & mosaic=1 - isn't supported, use only 1 of these parameters")
  198. raise
  199. self.cfg = cfg
  200. self.train = train
  201. truth = {}
  202. f = open(lable_path, 'r', encoding='utf-8')
  203. for line in f.readlines():
  204. data = line.split(" ")
  205. truth[data[0]] = []
  206. for i in data[1:]:
  207. truth[data[0]].append([int(float(j)) for j in i.split(',')])
  208. self.truth = truth
  209. self.imgs = list(self.truth.keys())
  210. def __len__(self):
  211. return len(self.truth.keys())
  212. def __getitem__(self, index):
  213. if not self.train:
  214. return self._get_val_item(index)
  215. img_path = self.imgs[index]
  216. bboxes = np.array(self.truth.get(img_path), dtype=np.float)
  217. img_path = os.path.join(self.cfg.dataset_dir, img_path)
  218. use_mixup = self.cfg.mixup
  219. if random.randint(0, 1):
  220. use_mixup = 0
  221. if use_mixup == 3:
  222. min_offset = 0.2
  223. cut_x = random.randint(int(self.cfg.w * min_offset), int(self.cfg.w * (1 - min_offset)))
  224. cut_y = random.randint(int(self.cfg.h * min_offset), int(self.cfg.h * (1 - min_offset)))
  225. r1, r2, r3, r4, r_scale = 0, 0, 0, 0, 0
  226. dhue, dsat, dexp, flip, blur = 0, 0, 0, 0, 0
  227. gaussian_noise = 0
  228. out_img = np.zeros([self.cfg.h, self.cfg.w, 3])
  229. out_bboxes = []
  230. for i in range(use_mixup + 1):
  231. if i != 0:
  232. img_path = random.choice(list(self.truth.keys()))
  233. bboxes = np.array(self.truth.get(img_path), dtype=np.float)
  234. img_path = os.path.join(self.cfg.dataset_dir, img_path)
  235. img = cv2.imread(img_path)
  236. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  237. if img is None:
  238. continue
  239. oh, ow, oc = img.shape
  240. dh, dw, dc = np.array(np.array([oh, ow, oc]) * self.cfg.jitter, dtype=np.int)
  241. dhue = rand_uniform_strong(-self.cfg.hue, self.cfg.hue)
  242. dsat = rand_scale(self.cfg.saturation)
  243. dexp = rand_scale(self.cfg.exposure)
  244. pleft = random.randint(-dw, dw)
  245. pright = random.randint(-dw, dw)
  246. ptop = random.randint(-dh, dh)
  247. pbot = random.randint(-dh, dh)
  248. flip = random.randint(0, 1) if self.cfg.flip else 0
  249. if (self.cfg.blur):
  250. tmp_blur = random.randint(0, 2) # 0 - disable, 1 - blur background, 2 - blur the whole image
  251. if tmp_blur == 0:
  252. blur = 0
  253. elif tmp_blur == 1:
  254. blur = 1
  255. else:
  256. blur = self.cfg.blur
  257. if self.cfg.gaussian and random.randint(0, 1):
  258. gaussian_noise = self.cfg.gaussian
  259. else:
  260. gaussian_noise = 0
  261. if self.cfg.letter_box:
  262. img_ar = ow / oh
  263. net_ar = self.cfg.w / self.cfg.h
  264. result_ar = img_ar / net_ar
  265. # print(" ow = %d, oh = %d, w = %d, h = %d, img_ar = %f, net_ar = %f, result_ar = %f \n", ow, oh, w, h, img_ar, net_ar, result_ar);
  266. if result_ar > 1: # sheight - should be increased
  267. oh_tmp = ow / net_ar
  268. delta_h = (oh_tmp - oh) / 2
  269. ptop = ptop - delta_h
  270. pbot = pbot - delta_h
  271. # print(" result_ar = %f, oh_tmp = %f, delta_h = %d, ptop = %f, pbot = %f \n", result_ar, oh_tmp, delta_h, ptop, pbot);
  272. else: # swidth - should be increased
  273. ow_tmp = oh * net_ar
  274. delta_w = (ow_tmp - ow) / 2
  275. pleft = pleft - delta_w
  276. pright = pright - delta_w
  277. # printf(" result_ar = %f, ow_tmp = %f, delta_w = %d, pleft = %f, pright = %f \n", result_ar, ow_tmp, delta_w, pleft, pright);
  278. swidth = ow - pleft - pright
  279. sheight = oh - ptop - pbot
  280. truth, min_w_h = fill_truth_detection(bboxes, self.cfg.boxes, self.cfg.classes, flip, pleft, ptop, swidth,
  281. sheight, self.cfg.w, self.cfg.h)
  282. if (min_w_h / 8) < blur and blur > 1: # disable blur if one of the objects is too small
  283. blur = min_w_h / 8
  284. ai = image_data_augmentation(img, self.cfg.w, self.cfg.h, pleft, ptop, swidth, sheight, flip,
  285. dhue, dsat, dexp, gaussian_noise, blur, truth)
  286. if use_mixup == 0:
  287. out_img = ai
  288. out_bboxes = truth
  289. if use_mixup == 1:
  290. if i == 0:
  291. old_img = ai.copy()
  292. old_truth = truth.copy()
  293. elif i == 1:
  294. out_img = cv2.addWeighted(ai, 0.5, old_img, 0.5)
  295. out_bboxes = np.concatenate([old_truth, truth], axis=0)
  296. elif use_mixup == 3:
  297. if flip:
  298. tmp = pleft
  299. pleft = pright
  300. pright = tmp
  301. left_shift = int(min(cut_x, max(0, (-int(pleft) * self.cfg.w / swidth))))
  302. top_shift = int(min(cut_y, max(0, (-int(ptop) * self.cfg.h / sheight))))
  303. right_shift = int(min((self.cfg.w - cut_x), max(0, (-int(pright) * self.cfg.w / swidth))))
  304. bot_shift = int(min(self.cfg.h - cut_y, max(0, (-int(pbot) * self.cfg.h / sheight))))
  305. out_img, out_bbox = blend_truth_mosaic(out_img, ai, truth.copy(), self.cfg.w, self.cfg.h, cut_x,
  306. cut_y, i, left_shift, right_shift, top_shift, bot_shift)
  307. out_bboxes.append(out_bbox)
  308. # print(img_path)
  309. if use_mixup == 3:
  310. out_bboxes = np.concatenate(out_bboxes, axis=0)
  311. out_bboxes1 = np.zeros([self.cfg.boxes, 5])
  312. out_bboxes1[:min(out_bboxes.shape[0], self.cfg.boxes)] = out_bboxes[:min(out_bboxes.shape[0], self.cfg.boxes)]
  313. return out_img, out_bboxes1
  314. def _get_val_item(self, index):
  315. """
  316. """
  317. img_path = self.imgs[index]
  318. bboxes_with_cls_id = np.array(self.truth.get(img_path), dtype=np.float)
  319. img = cv2.imread(os.path.join(self.cfg.dataset_dir, img_path))
  320. # img_height, img_width = img.shape[:2]
  321. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  322. # img = cv2.resize(img, (self.cfg.w, self.cfg.h))
  323. # img = torch.from_numpy(img.transpose(2, 0, 1)).float().div(255.0).unsqueeze(0)
  324. num_objs = len(bboxes_with_cls_id)
  325. target = {}
  326. # boxes to coco format
  327. boxes = bboxes_with_cls_id[...,:4]
  328. boxes[..., 2:] = boxes[..., 2:] - boxes[..., :2] # box width, box height
  329. target['boxes'] = torch.as_tensor(boxes, dtype=torch.float32)
  330. target['labels'] = torch.as_tensor(bboxes_with_cls_id[...,-1].flatten(), dtype=torch.int64)
  331. target['image_id'] = torch.tensor([get_image_id(img_path)])
  332. target['area'] = (target['boxes'][:,3])*(target['boxes'][:,2])
  333. target['iscrowd'] = torch.zeros((num_objs,), dtype=torch.int64)
  334. return img, target
  335. def get_image_id(filename:str) -> int:
  336. """
  337. Convert a string to a integer.
  338. Make sure that the images and the `image_id`s are in one-one correspondence.
  339. There are already `image_id`s in annotations of the COCO dataset,
  340. in which case this function is unnecessary.
  341. For creating one's own `get_image_id` function, one can refer to
  342. https://github.com/google/automl/blob/master/efficientdet/dataset/create_pascal_tfrecord.py#L86
  343. or refer to the following code (where the filenames are like 'level1_123.jpg')
  344. >>> lv, no = os.path.splitext(os.path.basename(filename))[0].split("_")
  345. >>> lv = lv.replace("level", "")
  346. >>> no = f"{int(no):04d}"
  347. >>> return int(lv+no)
  348. """
  349. raise NotImplementedError("Create your own 'get_image_id' function")
  350. lv, no = os.path.splitext(os.path.basename(filename))[0].split("_")
  351. lv = lv.replace("level", "")
  352. no = f"{int(no):04d}"
  353. return int(lv+no)
  354. if __name__ == "__main__":
  355. from cfg import Cfg
  356. import matplotlib.pyplot as plt
  357. random.seed(2020)
  358. np.random.seed(2020)
  359. Cfg.dataset_dir = '/mnt/e/Dataset'
  360. dataset = Yolo_dataset(Cfg.train_label, Cfg)
  361. for i in range(100):
  362. out_img, out_bboxes = dataset.__getitem__(i)
  363. a = draw_box(out_img.copy(), out_bboxes.astype(np.int32))
  364. plt.imshow(a.astype(np.int32))
  365. plt.show()