base.py

# Ultralytics YOLO 🚀, AGPL-3.0 license

import glob
import math
import os
import random
from copy import deepcopy
from multiprocessing.pool import ThreadPool
from pathlib import Path
from typing import Optional

import cv2
import numpy as np
import psutil
from torch.utils.data import Dataset

from ultralytics.data.utils import FORMATS_HELP_MSG, HELP_URL, IMG_FORMATS
from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM
from ultralytics.data.datagenerate import get_image_remove_extra_contour
import ultralytics.trainsdk.TrainSdk as TrainSdk


class BaseDataset(Dataset):
  19. """
  20. Base dataset class for loading and processing image data.
  21. Args:
  22. img_path (str): Path to the folder containing images.
  23. is_train_on_platform (bool): 是否在vinno平台上训练
  24. data (dict, optional): A dataset YAML dictionary. Defaults to None.
  25. imgsz (int, optional): Image size. Defaults to 640.\
  26. cache (bool, optional): Cache images to RAM or disk during training. Defaults to False.
  27. augment (bool, optional): If True, data augmentation is applied. Defaults to True.
  28. hyp (dict, optional): Hyperparameters to apply data augmentation. Defaults to None.
  29. prefix (str, optional): Prefix to print in log messages. Defaults to ''.
  30. rect (bool, optional): If True, rectangular training is used. Defaults to False.
  31. batch_size (int, optional): Size of batches. Defaults to None.
  32. stride (int, optional): Stride. Defaults to 32.
  33. pad (float, optional): Padding. Defaults to 0.0.
  34. single_cls (bool, optional): If True, single class training is used. Defaults to False.
  35. classes (list): List of included classes. Default is None.
  36. fraction (float): Fraction of dataset to utilize. Default is 1.0 (use all data).
  37. Attributes:
  38. im_files (list): List of image file paths.
  39. labels (list): List of label data dictionaries.
  40. ni (int): Number of images in the dataset.
  41. ims (list): List of loaded images.
  42. npy_files (list): List of numpy file paths.
  43. transforms (callable): Image transformation function.
  44. """

    def __init__(
        self,
        img_path,
        is_train_on_platform=False,
        data=None,
        imgsz=640,
        cache=False,
        augment=True,
        hyp=DEFAULT_CFG,
        prefix="",
        rect=False,
        batch_size=16,
        stride=32,
        pad=0.5,
        single_cls=False,
        classes=None,
        fraction=1.0,
    ):
        """Initialize BaseDataset with given configuration and options."""
        super().__init__()
        self.img_path = img_path
        self.imgsz = imgsz
        self.augment = augment
        self.single_cls = single_cls
        self.prefix = prefix
        self.fraction = fraction
        self.is_train_on_platform = is_train_on_platform
        self.data = data
        if self.is_train_on_platform:
            self.train_or_val_data = "train" if self.augment else "val"
            self.platform_data_args = self.data["platform_data_args"]
            self.token = self.platform_data_args["token"]
            self.extra_contours_args = self.platform_data_args["extra_contours_args"]
            self.im_files, self.labels, self.crop_boxes, self.crop_contours = self.get_img_files_and_labels()
        else:
            self.im_files = self.get_img_files(self.img_path)
            self.labels = self.get_labels()
        self.update_labels(include_class=classes)  # single_cls and include_class
        self.ni = len(self.labels)  # number of images
        self.rect = rect
        self.batch_size = batch_size
        self.stride = stride
        self.pad = pad
        if self.rect:
            assert self.batch_size is not None
            self.set_rectangle()

        # Buffer thread for mosaic images
        self.buffer = []  # buffer size = batch size
        self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0

        # Cache images (options are cache = True, False, None, "ram", "disk")
        # Note: caching is set up after self.ni and self.ims exist, since check_cache_ram() and
        # cache_images() rely on both.
        self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni
        self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files]
        if isinstance(cache, str):
            cache = cache.lower()
        if cache == "ram" and not self.check_cache_ram():
            cache = False
        if cache:
            self.cache_images(cache)

        # Transforms
        self.transforms = self.build_transforms(hyp=hyp)

    def get_img_files(self, img_path):
        """Read image files."""
        try:
            f = []  # image files
            for p in img_path if isinstance(img_path, list) else [img_path]:
                p = Path(p)  # os-agnostic
                if p.is_dir():  # dir
                    f += glob.glob(str(p / "**" / "*.*"), recursive=True)
                    # F = list(p.rglob('*.*'))  # pathlib
                elif p.is_file():  # file
                    with open(p) as t:
                        t = t.read().strip().splitlines()
                        parent = str(p.parent) + os.sep
                        f += [x.replace("./", parent) if x.startswith("./") else x for x in t]  # local to global path
                        # F += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib)
                else:
                    raise FileNotFoundError(f"{self.prefix}{p} does not exist")
            im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
            # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS])  # pathlib
            assert im_files, f"{self.prefix}No images found in {img_path}. {FORMATS_HELP_MSG}"
        except Exception as e:
            raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e
        if self.fraction < 1:
            # im_files = im_files[: round(len(im_files) * self.fraction)]
            num_elements_to_select = round(len(im_files) * self.fraction)
            im_files = random.sample(im_files, num_elements_to_select)
        return im_files

    def update_labels(self, include_class: Optional[list]):
        """Update labels to include only these classes (optional)."""
        include_class_array = np.array(include_class).reshape(1, -1)
        for i in range(len(self.labels)):
            if include_class is not None:
                cls = self.labels[i]["cls"]
                bboxes = self.labels[i]["bboxes"]
                segments = self.labels[i]["segments"]
                keypoints = self.labels[i]["keypoints"]
                j = (cls == include_class_array).any(1)
                self.labels[i]["cls"] = cls[j]
                self.labels[i]["bboxes"] = bboxes[j]
                if segments:
                    self.labels[i]["segments"] = [segments[si] for si, idx in enumerate(j) if idx]
                if keypoints is not None:
                    self.labels[i]["keypoints"] = keypoints[j]
            if self.single_cls:
                self.labels[i]["cls"][:, 0] = 0

    def load_image(self, i, rect_mode=True):
        """Loads 1 image from dataset index 'i', returns (im, resized hw)."""
        im, f = self.ims[i], self.im_files[i]
        if not self.is_train_on_platform:
            fn = self.npy_files[i]
        if im is None:  # not cached in RAM
            if self.is_train_on_platform:
                if self.train_or_val_data == "train":
                    image_data, _, image_name, _ = TrainSdk.get_labeled_file(self.token, f)
                else:
                    image_data, _, image_name, _ = TrainSdk.get_test_labeled_file(self.token, f)
                # Read image
                nparr_data = np.frombuffer(image_data, dtype=np.uint8)
                im = cv2.imdecode(nparr_data, flags=cv2.IMREAD_COLOR)  # BGR
                # Decide whether the image needs to be cropped; cropping changes the input image
                if self.extra_contours_args != '':
                    use_orig_image_pixel = self.extra_contours_args['use_orig_image_pixel']
                    im = get_image_remove_extra_contour(im, self.crop_contours[i], self.crop_boxes[i],
                                                        use_orig_image_pixel)
            else:
                if fn.exists():  # load npy
                    try:
                        im = np.load(fn)
                    except Exception as e:
                        LOGGER.warning(f"{self.prefix}WARNING ⚠️ Removing corrupt *.npy image file {fn} due to: {e}")
                        Path(fn).unlink(missing_ok=True)
                        im = cv2.imread(f)  # BGR
                else:  # read image
                    im = cv2.imread(f)  # BGR
            if im is None:
                raise FileNotFoundError(f"Image Not Found {f}")

            h0, w0 = im.shape[:2]  # orig hw
            if rect_mode:  # resize long side to imgsz while maintaining aspect ratio
                r = self.imgsz / max(h0, w0)  # ratio
                if r != 1:  # if sizes are not equal
                    w, h = (min(math.ceil(w0 * r), self.imgsz), min(math.ceil(h0 * r), self.imgsz))
                    im = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
            elif not (h0 == w0 == self.imgsz):  # resize by stretching image to square imgsz
                im = cv2.resize(im, (self.imgsz, self.imgsz), interpolation=cv2.INTER_LINEAR)

            # Add to buffer if training with augmentations
            if self.augment:
                self.ims[i], self.im_hw0[i], self.im_hw[i] = im, (h0, w0), im.shape[:2]  # im, hw_original, hw_resized
                self.buffer.append(i)
                if len(self.buffer) >= self.max_buffer_length:
                    j = self.buffer.pop(0)
                    self.ims[j], self.im_hw0[j], self.im_hw[j] = None, None, None
            return im, (h0, w0), im.shape[:2]
        return self.ims[i], self.im_hw0[i], self.im_hw[i]
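
    # Note on the buffer handling in load_image above: during augmented training the most recently
    # loaded images stay in self.ims and their indices in self.buffer; once the FIFO buffer reaches
    # max_buffer_length, the oldest entry is evicted so that mosaic/mixup can sample recent images
    # without unbounded RAM growth.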

    def cache_images(self, cache):
        """Cache images to memory or disk."""
        b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabytes
        fcn = self.cache_images_to_disk if cache == "disk" else self.load_image
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(fcn, range(self.ni))
            pbar = TQDM(enumerate(results), total=self.ni, disable=LOCAL_RANK > 0)
            for i, x in pbar:
                if cache == "disk":
                    b += self.npy_files[i].stat().st_size
                else:  # 'ram'
                    self.ims[i], self.im_hw0[i], self.im_hw[i] = x  # im, hw_orig, hw_resized = load_image(self, i)
                    b += self.ims[i].nbytes
                pbar.desc = f"{self.prefix}Caching images ({b / gb:.1f}GB {cache})"
            pbar.close()

    def cache_images_to_disk(self, i):
        """Saves an image as an *.npy file for faster loading."""
        f = self.npy_files[i]
        if not f.exists():
            np.save(f.as_posix(), cv2.imread(self.im_files[i]), allow_pickle=False)

    def check_cache_ram(self, safety_margin=0.5):
        """Check image caching requirements vs available memory."""
        b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabytes
        n = min(self.ni, 30)  # extrapolate from 30 random images
        for _ in range(n):
            im = cv2.imread(random.choice(self.im_files))  # sample image
            ratio = self.imgsz / max(im.shape[0], im.shape[1])  # max(h, w)  # ratio
            b += im.nbytes * ratio**2
        mem_required = b * self.ni / n * (1 + safety_margin)  # bytes required to cache dataset into RAM
        mem = psutil.virtual_memory()
        cache = mem_required < mem.available  # to cache or not to cache, that is the question
        if not cache:
            LOGGER.info(
                f"{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images "
                f"with {int(safety_margin * 100)}% safety margin but only "
                f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, "
                f"{'caching images ✅' if cache else 'not caching images ⚠️'}"
            )
        return cache
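
    # Worked example for the estimate above (illustrative numbers): if the 30 sampled images
    # average about 1.2 MB each after the imgsz resize, then for ni = 10_000 images the projection
    # is 10_000 * 1.2 MB * (1 + 0.5) ≈ 18 GB, and RAM caching is enabled only if that fits in the
    # available memory reported by psutil.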

    def set_rectangle(self):
        """Sets the shape of bounding boxes for YOLO detections as rectangles."""
        bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int)  # batch index
        nb = bi[-1] + 1  # number of batches

        s = np.array([x.pop("shape") for x in self.labels])  # hw
        ar = s[:, 0] / s[:, 1]  # aspect ratio
        irect = ar.argsort()
        self.im_files = [self.im_files[i] for i in irect]
        self.labels = [self.labels[i] for i in irect]
        ar = ar[irect]

        # Set training image shapes
        shapes = [[1, 1]] * nb
        for i in range(nb):
            ari = ar[bi == i]
            mini, maxi = ari.min(), ari.max()
            if maxi < 1:
                shapes[i] = [maxi, 1]
            elif mini > 1:
                shapes[i] = [1, 1 / mini]

        self.batch_shapes = np.ceil(np.array(shapes) * self.imgsz / self.stride + self.pad).astype(int) * self.stride
        self.batch = bi  # batch index of image
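
    # Worked example (illustrative, assuming imgsz=640, stride=32, pad=0.5): a batch whose aspect
    # ratios h/w all lie below 1 with a maximum of 0.5 gets shape [0.5, 1], which becomes
    # ceil(0.5 * 640 / 32 + 0.5) * 32 = 352 rows by ceil(1.0 * 640 / 32 + 0.5) * 32 = 672 columns.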

    def __getitem__(self, index):
        """Returns transformed label information for given index."""
        return self.transforms(self.get_image_and_label(index))

    def get_image_and_label(self, index):
        """Get and return label information from the dataset."""
        label = deepcopy(self.labels[index])  # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948
        label.pop("shape", None)  # shape is for rect, remove it
        label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index)
        label["ratio_pad"] = (
            label["resized_shape"][0] / label["ori_shape"][0],
            label["resized_shape"][1] / label["ori_shape"][1],
        )  # for evaluation
        if self.rect:
            label["rect_shape"] = self.batch_shapes[self.batch[index]]
        return self.update_labels_info(label)

    def __len__(self):
        """Returns the length of the labels list for the dataset."""
        return len(self.labels)

    def update_labels_info(self, label):
        """Customize your label format here."""
        return label

    def build_transforms(self, hyp=None):
        """
        Users can customize augmentations here.

        Example:
            ```python
            if self.augment:
                # Training transforms
                return Compose([])
            else:
                # Val transforms
                return Compose([])
            ```
        """
        raise NotImplementedError
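
    # Note (descriptive, not part of the original file): the stock YOLODataset subclass typically
    # implements this hook by returning the v8_transforms() training pipeline when self.augment is
    # True and a Compose of LetterBox plus Format transforms for validation; custom subclasses can
    # follow the same pattern.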

    def get_labels(self):
        """
        Users can customize their own format here.

        Note:
            Ensure output is a dictionary with the following keys:
            ```python
            dict(
                im_file=im_file,
                shape=shape,  # format: (height, width)
                cls=cls,
                bboxes=bboxes,  # xywh
                segments=segments,  # xy
                keypoints=keypoints,  # xy
                normalized=True,  # or False
                bbox_format="xyxy",  # or xywh, ltwh
            )
            ```
        """
        raise NotImplementedError

    def get_img_files_and_labels(self):
        """
        By default, YOLO obtains im_files from get_img_files in this base class and labels from
        get_labels in the derived YOLODataset. When training on the platform, im_files and labels
        are returned together, so overriding this single function is sufficient and get_img_files
        and get_labels do not need to be modified.

        Note:
            Ensure the outputs are im_files, all_labels, crop_boxes, crop_contours:
            im_files: list of all image file names.
            all_labels: list with one label dict per image, in the following format:
            ```python
            dict(
                im_file=im_file,
                shape=shape,  # format: (height, width)
                cls=cls,
                bboxes=bboxes,  # xywh
                segments=segments,  # xy
                keypoints=keypoints,  # xy
                normalized=True,  # or False
                bbox_format="xyxy",  # or xywh, ltwh
            )
            ```
            crop_boxes: list of the crop boxes that may be needed for cropping.
            crop_contours: list of the crop contours that may be needed for cropping.
        """
        raise NotImplementedError
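

# Illustrative override sketch (hypothetical, not part of the original file): a platform dataset
# would implement get_img_files_and_labels() to return the four parallel lists described above.
# `iter_platform_annotations` is a made-up placeholder for whatever TrainSdk call lists the
# labelled files and their annotations on the vinno platform.
#
# class PlatformDataset(BaseDataset):
#     def get_img_files_and_labels(self):
#         im_files, all_labels, crop_boxes, crop_contours = [], [], [], []
#         for name, label, box, contour in iter_platform_annotations(self.token):  # hypothetical helper
#             im_files.append(name)          # image file name / platform file id
#             all_labels.append(label)       # dict(im_file=..., shape=..., cls=..., bboxes=..., ...)
#             crop_boxes.append(box)         # crop box for this image, if any
#             crop_contours.append(contour)  # crop contour for this image, if any
#         return im_files, all_labels, crop_boxes, crop_contours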