split_dota.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. import itertools
  3. from glob import glob
  4. from math import ceil
  5. from pathlib import Path
  6. import cv2
  7. import numpy as np
  8. from PIL import Image
  9. from tqdm import tqdm
  10. from ultralytics.data.utils import exif_size, img2label_paths
  11. from ultralytics.utils.checks import check_requirements
  12. check_requirements("shapely")
  13. from shapely.geometry import Polygon
  14. def bbox_iof(polygon1, bbox2, eps=1e-6):
  15. """
  16. Calculate iofs between bbox1 and bbox2.
  17. Args:
  18. polygon1 (np.ndarray): Polygon coordinates, (n, 8).
  19. bbox2 (np.ndarray): Bounding boxes, (n ,4).
  20. """
  21. polygon1 = polygon1.reshape(-1, 4, 2)
  22. lt_point = np.min(polygon1, axis=-2) # left-top
  23. rb_point = np.max(polygon1, axis=-2) # right-bottom
  24. bbox1 = np.concatenate([lt_point, rb_point], axis=-1)
  25. lt = np.maximum(bbox1[:, None, :2], bbox2[..., :2])
  26. rb = np.minimum(bbox1[:, None, 2:], bbox2[..., 2:])
  27. wh = np.clip(rb - lt, 0, np.inf)
  28. h_overlaps = wh[..., 0] * wh[..., 1]
  29. left, top, right, bottom = (bbox2[..., i] for i in range(4))
  30. polygon2 = np.stack([left, top, right, top, right, bottom, left, bottom], axis=-1).reshape(-1, 4, 2)
  31. sg_polys1 = [Polygon(p) for p in polygon1]
  32. sg_polys2 = [Polygon(p) for p in polygon2]
  33. overlaps = np.zeros(h_overlaps.shape)
  34. for p in zip(*np.nonzero(h_overlaps)):
  35. overlaps[p] = sg_polys1[p[0]].intersection(sg_polys2[p[-1]]).area
  36. unions = np.array([p.area for p in sg_polys1], dtype=np.float32)
  37. unions = unions[..., None]
  38. unions = np.clip(unions, eps, np.inf)
  39. outputs = overlaps / unions
  40. if outputs.ndim == 1:
  41. outputs = outputs[..., None]
  42. return outputs
  43. def load_yolo_dota(data_root, split="train"):
  44. """
  45. Load DOTA dataset.
  46. Args:
  47. data_root (str): Data root.
  48. split (str): The split data set, could be train or val.
  49. Notes:
  50. The directory structure assumed for the DOTA dataset:
  51. - data_root
  52. - images
  53. - train
  54. - val
  55. - labels
  56. - train
  57. - val
  58. """
  59. assert split in {"train", "val"}, f"Split must be 'train' or 'val', not {split}."
  60. im_dir = Path(data_root) / "images" / split
  61. assert im_dir.exists(), f"Can't find {im_dir}, please check your data root."
  62. im_files = glob(str(Path(data_root) / "images" / split / "*"))
  63. lb_files = img2label_paths(im_files)
  64. annos = []
  65. for im_file, lb_file in zip(im_files, lb_files):
  66. w, h = exif_size(Image.open(im_file))
  67. with open(lb_file) as f:
  68. lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
  69. lb = np.array(lb, dtype=np.float32)
  70. annos.append(dict(ori_size=(h, w), label=lb, filepath=im_file))
  71. return annos
  72. def get_windows(im_size, crop_sizes=(1024,), gaps=(200,), im_rate_thr=0.6, eps=0.01):
  73. """
  74. Get the coordinates of windows.
  75. Args:
  76. im_size (tuple): Original image size, (h, w).
  77. crop_sizes (List(int)): Crop size of windows.
  78. gaps (List(int)): Gap between crops.
  79. im_rate_thr (float): Threshold of windows areas divided by image ares.
  80. eps (float): Epsilon value for math operations.
  81. """
  82. h, w = im_size
  83. windows = []
  84. for crop_size, gap in zip(crop_sizes, gaps):
  85. assert crop_size > gap, f"invalid crop_size gap pair [{crop_size} {gap}]"
  86. step = crop_size - gap
  87. xn = 1 if w <= crop_size else ceil((w - crop_size) / step + 1)
  88. xs = [step * i for i in range(xn)]
  89. if len(xs) > 1 and xs[-1] + crop_size > w:
  90. xs[-1] = w - crop_size
  91. yn = 1 if h <= crop_size else ceil((h - crop_size) / step + 1)
  92. ys = [step * i for i in range(yn)]
  93. if len(ys) > 1 and ys[-1] + crop_size > h:
  94. ys[-1] = h - crop_size
  95. start = np.array(list(itertools.product(xs, ys)), dtype=np.int64)
  96. stop = start + crop_size
  97. windows.append(np.concatenate([start, stop], axis=1))
  98. windows = np.concatenate(windows, axis=0)
  99. im_in_wins = windows.copy()
  100. im_in_wins[:, 0::2] = np.clip(im_in_wins[:, 0::2], 0, w)
  101. im_in_wins[:, 1::2] = np.clip(im_in_wins[:, 1::2], 0, h)
  102. im_areas = (im_in_wins[:, 2] - im_in_wins[:, 0]) * (im_in_wins[:, 3] - im_in_wins[:, 1])
  103. win_areas = (windows[:, 2] - windows[:, 0]) * (windows[:, 3] - windows[:, 1])
  104. im_rates = im_areas / win_areas
  105. if not (im_rates > im_rate_thr).any():
  106. max_rate = im_rates.max()
  107. im_rates[abs(im_rates - max_rate) < eps] = 1
  108. return windows[im_rates > im_rate_thr]
  109. def get_window_obj(anno, windows, iof_thr=0.7):
  110. """Get objects for each window."""
  111. h, w = anno["ori_size"]
  112. label = anno["label"]
  113. if len(label):
  114. label[:, 1::2] *= w
  115. label[:, 2::2] *= h
  116. iofs = bbox_iof(label[:, 1:], windows)
  117. # Unnormalized and misaligned coordinates
  118. return [(label[iofs[:, i] >= iof_thr]) for i in range(len(windows))] # window_anns
  119. else:
  120. return [np.zeros((0, 9), dtype=np.float32) for _ in range(len(windows))] # window_anns
  121. def crop_and_save(anno, windows, window_objs, im_dir, lb_dir):
  122. """
  123. Crop images and save new labels.
  124. Args:
  125. anno (dict): Annotation dict, including `filepath`, `label`, `ori_size` as its keys.
  126. windows (list): A list of windows coordinates.
  127. window_objs (list): A list of labels inside each window.
  128. im_dir (str): The output directory path of images.
  129. lb_dir (str): The output directory path of labels.
  130. Notes:
  131. The directory structure assumed for the DOTA dataset:
  132. - data_root
  133. - images
  134. - train
  135. - val
  136. - labels
  137. - train
  138. - val
  139. """
  140. im = cv2.imread(anno["filepath"])
  141. name = Path(anno["filepath"]).stem
  142. for i, window in enumerate(windows):
  143. x_start, y_start, x_stop, y_stop = window.tolist()
  144. new_name = f"{name}__{x_stop - x_start}__{x_start}___{y_start}"
  145. patch_im = im[y_start:y_stop, x_start:x_stop]
  146. ph, pw = patch_im.shape[:2]
  147. cv2.imwrite(str(Path(im_dir) / f"{new_name}.jpg"), patch_im)
  148. label = window_objs[i]
  149. if len(label) == 0:
  150. continue
  151. label[:, 1::2] -= x_start
  152. label[:, 2::2] -= y_start
  153. label[:, 1::2] /= pw
  154. label[:, 2::2] /= ph
  155. with open(Path(lb_dir) / f"{new_name}.txt", "w") as f:
  156. for lb in label:
  157. formatted_coords = ["{:.6g}".format(coord) for coord in lb[1:]]
  158. f.write(f"{int(lb[0])} {' '.join(formatted_coords)}\n")
  159. def split_images_and_labels(data_root, save_dir, split="train", crop_sizes=(1024,), gaps=(200,)):
  160. """
  161. Split both images and labels.
  162. Notes:
  163. The directory structure assumed for the DOTA dataset:
  164. - data_root
  165. - images
  166. - split
  167. - labels
  168. - split
  169. and the output directory structure is:
  170. - save_dir
  171. - images
  172. - split
  173. - labels
  174. - split
  175. """
  176. im_dir = Path(save_dir) / "images" / split
  177. im_dir.mkdir(parents=True, exist_ok=True)
  178. lb_dir = Path(save_dir) / "labels" / split
  179. lb_dir.mkdir(parents=True, exist_ok=True)
  180. annos = load_yolo_dota(data_root, split=split)
  181. for anno in tqdm(annos, total=len(annos), desc=split):
  182. windows = get_windows(anno["ori_size"], crop_sizes, gaps)
  183. window_objs = get_window_obj(anno, windows)
  184. crop_and_save(anno, windows, window_objs, str(im_dir), str(lb_dir))
  185. def split_trainval(data_root, save_dir, crop_size=1024, gap=200, rates=(1.0,)):
  186. """
  187. Split train and val set of DOTA.
  188. Notes:
  189. The directory structure assumed for the DOTA dataset:
  190. - data_root
  191. - images
  192. - train
  193. - val
  194. - labels
  195. - train
  196. - val
  197. and the output directory structure is:
  198. - save_dir
  199. - images
  200. - train
  201. - val
  202. - labels
  203. - train
  204. - val
  205. """
  206. crop_sizes, gaps = [], []
  207. for r in rates:
  208. crop_sizes.append(int(crop_size / r))
  209. gaps.append(int(gap / r))
  210. for split in ["train", "val"]:
  211. split_images_and_labels(data_root, save_dir, split, crop_sizes, gaps)
  212. def split_test(data_root, save_dir, crop_size=1024, gap=200, rates=(1.0,)):
  213. """
  214. Split test set of DOTA, labels are not included within this set.
  215. Notes:
  216. The directory structure assumed for the DOTA dataset:
  217. - data_root
  218. - images
  219. - test
  220. and the output directory structure is:
  221. - save_dir
  222. - images
  223. - test
  224. """
  225. crop_sizes, gaps = [], []
  226. for r in rates:
  227. crop_sizes.append(int(crop_size / r))
  228. gaps.append(int(gap / r))
  229. save_dir = Path(save_dir) / "images" / "test"
  230. save_dir.mkdir(parents=True, exist_ok=True)
  231. im_dir = Path(data_root) / "images" / "test"
  232. assert im_dir.exists(), f"Can't find {im_dir}, please check your data root."
  233. im_files = glob(str(im_dir / "*"))
  234. for im_file in tqdm(im_files, total=len(im_files), desc="test"):
  235. w, h = exif_size(Image.open(im_file))
  236. windows = get_windows((h, w), crop_sizes=crop_sizes, gaps=gaps)
  237. im = cv2.imread(im_file)
  238. name = Path(im_file).stem
  239. for window in windows:
  240. x_start, y_start, x_stop, y_stop = window.tolist()
  241. new_name = f"{name}__{x_stop - x_start}__{x_start}___{y_start}"
  242. patch_im = im[y_start:y_stop, x_start:x_stop]
  243. cv2.imwrite(str(save_dir / f"{new_name}.jpg"), patch_im)
  244. if __name__ == "__main__":
  245. split_trainval(data_root="DOTAv2", save_dir="DOTAv2-split")
  246. split_test(data_root="DOTAv2", save_dir="DOTAv2-split")