utils.py

# Ultralytics YOLO 🚀, AGPL-3.0 license

import contextlib
import hashlib
import json
import os
import random
import subprocess
import time
import zipfile
from multiprocessing.pool import ThreadPool
from pathlib import Path
from tarfile import is_tarfile

import cv2
import numpy as np
from PIL import Image, ImageOps

from ultralytics.nn.autobackend import check_class_names
from ultralytics.utils import (
    DATASETS_DIR,
    LOGGER,
    NUM_THREADS,
    ROOT,
    SETTINGS_YAML,
    TQDM,
    clean_url,
    colorstr,
    emojis,
    is_dir_writeable,
    yaml_load,
    yaml_save,
)
from ultralytics.utils.checks import check_file, check_font, is_ascii
from ultralytics.utils.downloads import download, safe_download, unzip_file
from ultralytics.utils.ops import segments2boxes

HELP_URL = "See https://docs.ultralytics.com/datasets for dataset formatting guidance."
IMG_FORMATS = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"}  # image suffixes
VID_FORMATS = {"asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"}  # video suffixes
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true"  # global pin_memory for dataloaders
FORMATS_HELP_MSG = f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"

def img2label_paths(img_paths):
    """Define label paths as a function of image paths."""
    sa, sb = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}"  # /images/, /labels/ substrings
    return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths]
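
# Illustrative sketch (not part of the original module): the last "/images/" path component is
# swapped for "/labels/" and the image suffix is replaced with ".txt". The path below is hypothetical.
# >>> img2label_paths(["/data/coco8/images/train/000000000009.jpg"])
# ['/data/coco8/labels/train/000000000009.txt']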

def get_hash(paths):
    """Returns a single hash value of a list of paths (files or dirs)."""
    size = sum(os.path.getsize(p) for p in paths if os.path.exists(p))  # sizes
    h = hashlib.sha256(str(size).encode())  # hash sizes
    h.update("".join(paths).encode())  # hash paths
    return h.hexdigest()  # return hash
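
# Illustrative sketch (not part of the original module): the digest mixes the total on-disk size
# with the joined path strings, so adding, renaming, or modifying any file changes the hash and
# invalidates the dataset cache. Paths below are hypothetical.
# >>> get_hash(["/data/coco8/images/train", "/data/coco8/labels/train"])
# '<64-character sha256 hex digest>'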

def exif_size(img: Image.Image):
    """Returns exif-corrected PIL size."""
    s = img.size  # (width, height)
    if img.format == "JPEG":  # only support JPEG images
        with contextlib.suppress(Exception):
            exif = img.getexif()
            if exif:
                rotation = exif.get(274, None)  # the EXIF key for the orientation tag is 274
                if rotation in {6, 8}:  # rotation 270 or 90
                    s = s[1], s[0]
    return s
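
# Illustrative sketch (not part of the original module): for a JPEG whose EXIF orientation tag
# (key 274) is 6 or 8, the reported size is swapped so it matches the displayed orientation.
# The file name and numbers below are hypothetical.
# >>> exif_size(Image.open("portrait_photo.jpg"))
# (3024, 4032)  # width/height swapped relative to a raw Image.size of (4032, 3024)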

def verify_image(args):
    """Verify one image."""
    (im_file, cls), prefix = args
    # Number (found, corrupt), message
    nf, nc, msg = 0, 0, ""
    try:
        im = Image.open(im_file)
        im.verify()  # PIL verify
        shape = exif_size(im)  # image size
        shape = (shape[1], shape[0])  # hw
        assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
        assert im.format.lower() in IMG_FORMATS, f"Invalid image format {im.format}. {FORMATS_HELP_MSG}"
        if im.format.lower() in {"jpg", "jpeg"}:
            with open(im_file, "rb") as f:
                f.seek(-2, 2)
                if f.read() != b"\xff\xd9":  # corrupt JPEG
                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
                    msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved"
        nf = 1
    except Exception as e:
        nc = 1
        msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
    return (im_file, cls), nf, nc, msg
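
# Illustrative sketch (not part of the original module): verify_image takes one packed argument so
# it can be mapped over a ThreadPool during dataset caching; the path, class id, and prefix below
# are hypothetical.
# >>> (im_file, cls), nf, nc, msg = verify_image((("/data/images/cat.jpg", 0), "scan: "))
# >>> nf, nc  # (1, 0) for a readable image; (0, 1) plus a warning message for a corrupt one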

def verify_image_label(args):
    """Verify one image-label pair."""
    im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim = args
    # Number (missing, found, empty, corrupt), message, segments, keypoints
    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None
    try:
        # Verify images
        im = Image.open(im_file)
        im.verify()  # PIL verify
        shape = exif_size(im)  # image size
        shape = (shape[1], shape[0])  # hw
        assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
        assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}. {FORMATS_HELP_MSG}"
        if im.format.lower() in {"jpg", "jpeg"}:
            with open(im_file, "rb") as f:
                f.seek(-2, 2)
                if f.read() != b"\xff\xd9":  # corrupt JPEG
                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
                    msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved"

        # Verify labels
        if os.path.isfile(lb_file):
            nf = 1  # label found
            with open(lb_file) as f:
                lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
                if any(len(x) > 6 for x in lb) and (not keypoint):  # is segment
                    classes = np.array([x[0] for x in lb], dtype=np.float32)
                    segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb]  # (cls, xy1...)
                    lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
                lb = np.array(lb, dtype=np.float32)
            nl = len(lb)
            if nl:
                if keypoint:
                    assert lb.shape[1] == (5 + nkpt * ndim), f"labels require {(5 + nkpt * ndim)} columns each"
                    points = lb[:, 5:].reshape(-1, ndim)[:, :2]
                else:
                    assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
                    points = lb[:, 1:]
                assert points.max() <= 1, f"non-normalized or out of bounds coordinates {points[points > 1]}"
                assert lb.min() >= 0, f"negative label values {lb[lb < 0]}"

                # All labels
                max_cls = lb[:, 0].max()  # max class index
                assert max_cls <= num_cls, (
                    f"Label class {int(max_cls)} exceeds dataset class count {num_cls}. "
                    f"Possible class labels are 0-{num_cls - 1}"
                )
                _, i = np.unique(lb, axis=0, return_index=True)
                if len(i) < nl:  # duplicate row check
                    lb = lb[i]  # remove duplicates
                    if segments:
                        segments = [segments[x] for x in i]
                    msg = f"{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed"
            else:
                ne = 1  # label empty
                lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32)
        else:
            nm = 1  # label missing
            lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32)
        if keypoint:
            keypoints = lb[:, 5:].reshape(-1, nkpt, ndim)
            if ndim == 2:
                kpt_mask = np.where((keypoints[..., 0] < 0) | (keypoints[..., 1] < 0), 0.0, 1.0).astype(np.float32)
                keypoints = np.concatenate([keypoints, kpt_mask[..., None]], axis=-1)  # (nl, nkpt, 3)
        lb = lb[:, :5]
        return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
    except Exception as e:
        nc = 1
        msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
        return [None, None, None, None, None, nm, nf, ne, nc, msg]
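
# Illustrative sketch (not part of the original module): a detection label file holds one
# "class x_center y_center width height" line per object, normalized to [0, 1]. The packed args
# tuple and file names below are hypothetical (no keypoints, 80 classes):
# >>> args = ("img.jpg", "img.txt", "scan: ", False, 80, 0, 0)
# >>> im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg = verify_image_label(args)
# >>> lb.shape  # (num_objects, 5): class id plus normalized xywh per row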

def polygon2mask(imgsz, polygons, color=1, downsample_ratio=1):
    """
    Convert a list of polygons to a binary mask of the specified image size.

    Args:
        imgsz (tuple): The size of the image as (height, width).
        polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape [N, M], where
            N is the number of polygons, and M is the number of points such that M % 2 = 0.
        color (int, optional): The color value to fill in the polygons on the mask. Defaults to 1.
        downsample_ratio (int, optional): Factor by which to downsample the mask. Defaults to 1.

    Returns:
        (np.ndarray): A binary mask of the specified image size with the polygons filled in.
    """
    mask = np.zeros(imgsz, dtype=np.uint8)
    polygons = np.asarray(polygons, dtype=np.int32)
    polygons = polygons.reshape((polygons.shape[0], -1, 2))
    cv2.fillPoly(mask, polygons, color=color)
    nh, nw = (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio)
    # Note: filling the polygons at full resolution and then resizing keeps the loss calculation
    # consistent with the mask_ratio=1 case
    return cv2.resize(mask, (nw, nh))
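
# Illustrative sketch (not part of the original module): rasterize a hypothetical triangle into a
# 160x160 mask, downsampled 4x to 40x40 (a mask_ratio=4 style setup).
# >>> tri = np.array([10, 10, 150, 10, 80, 140], dtype=np.float32)  # flat x,y,x,y,... polygon
# >>> m = polygon2mask((160, 160), [tri], color=1, downsample_ratio=4)
# >>> m.shape, m.max()
# ((40, 40), 1)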

def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
    """
    Convert a list of polygons to a set of binary masks of the specified image size.

    Args:
        imgsz (tuple): The size of the image as (height, width).
        polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape [N, M], where
            N is the number of polygons, and M is the number of points such that M % 2 = 0.
        color (int): The color value to fill in the polygons on the masks.
        downsample_ratio (int, optional): Factor by which to downsample each mask. Defaults to 1.

    Returns:
        (np.ndarray): A set of binary masks of the specified image size with the polygons filled in.
    """
    return np.array([polygon2mask(imgsz, [x.reshape(-1)], color, downsample_ratio) for x in polygons])

def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
    """Return an overlap mask of shape (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio) and sort index."""
    masks = np.zeros(
        (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),
        dtype=np.int32 if len(segments) > 255 else np.uint8,
    )
    areas = []
    ms = []
    for si in range(len(segments)):
        mask = polygon2mask(imgsz, [segments[si].reshape(-1)], downsample_ratio=downsample_ratio, color=1)
        ms.append(mask)
        areas.append(mask.sum())
    areas = np.asarray(areas)
    index = np.argsort(-areas)
    ms = np.array(ms)[index]
    for i in range(len(segments)):
        mask = ms[i] * (i + 1)
        masks = masks + mask
        masks = np.clip(masks, a_min=0, a_max=i + 1)
    return masks, index
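
# Illustrative sketch (not part of the original module): segments are drawn largest-first, so smaller
# instances overwrite larger ones where they overlap; each pixel holds 0 (background) or the 1-based
# index of the instance that owns it. The segment coordinates below are hypothetical.
# >>> segs = [np.array([[5, 5], [60, 5], [60, 60], [5, 60]], np.float32),      # big square
# ...         np.array([[20, 20], [35, 20], [35, 35], [20, 35]], np.float32)]  # small square inside it
# >>> masks, order = polygons2masks_overlap((64, 64), segs)
# >>> masks.max()  # 2: both instance ids packed into a single uint8 mask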

def find_dataset_yaml(path: Path) -> Path:
    """
    Find and return the YAML file associated with a Detect, Segment or Pose dataset.

    This function searches for a YAML file at the root level of the provided directory first, and if not found, it
    performs a recursive search. It prefers YAML files that have the same stem as the provided path. An AssertionError
    is raised if no YAML file is found or if multiple YAML files are found.

    Args:
        path (Path): The directory path to search for the YAML file.

    Returns:
        (Path): The path of the found YAML file.
    """
    files = list(path.glob("*.yaml")) or list(path.rglob("*.yaml"))  # try root level first and then recursive
    assert files, f"No YAML file found in '{path.resolve()}'"
    if len(files) > 1:
        files = [f for f in files if f.stem == path.stem]  # prefer *.yaml files that match
    assert len(files) == 1, f"Expected 1 YAML file in '{path.resolve()}', but found {len(files)}.\n{files}"
    return files[0]
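
# Illustrative sketch (not part of the original module): for a hypothetical unzipped dataset at
# /datasets/coco8 containing coco8.yaml, the root-level glob finds it directly, and the stem-match
# rule disambiguates if several YAML files are present.
# >>> find_dataset_yaml(Path("/datasets/coco8"))
# PosixPath('/datasets/coco8/coco8.yaml')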

def check_det_dataset(dataset, is_train_on_platform=False, autodownload=True):
    """
    Download, verify, and/or unzip a dataset if not found locally.

    This function checks the availability of a specified dataset, and if not found, it has the option to download and
    unzip the dataset. It then reads and parses the accompanying YAML data, ensuring key requirements are met and also
    resolves paths related to the dataset.

    Args:
        dataset (str): Path to the dataset or dataset descriptor (like a YAML file).
        is_train_on_platform (bool, optional): Whether training runs on the VINNO platform. Defaults to False.
        autodownload (bool, optional): Whether to automatically download the dataset if not found. Defaults to True.

    Returns:
        (dict): Parsed dataset information and paths.
    """
    file = check_file(dataset)

    # Download (optional)
    extract_dir = ""
    if zipfile.is_zipfile(file) or is_tarfile(file):
        new_dir = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)
        file = find_dataset_yaml(DATASETS_DIR / new_dir)
        extract_dir, autodownload = file.parent, False

    # Read YAML
    data = yaml_load(file, append_filename=True)  # dictionary

    # Training runs on the VINNO platform when is_train_on_platform is set and the YAML contains a
    # non-empty 'platform_data_args' entry; in that case the path checks below are skipped.
    if is_train_on_platform and ('platform_data_args' in data and data['platform_data_args'] != ''):
        if "names" not in data and "nc" not in data:
            raise SyntaxError(
                emojis(f"{dataset} key missing ❌.\n either 'names' or 'nc' are required in all data YAMLs."))
        if "names" in data and "nc" in data and len(data["names"]) != data["nc"]:
            raise SyntaxError(
                emojis(f"{dataset} 'names' length {len(data['names'])} and 'nc: {data['nc']}' must match."))
        if "names" not in data:
            data["names"] = [f"class_{i}" for i in range(data["nc"])]
        else:
            data["nc"] = len(data["names"])
        data["names"] = check_class_names(data["names"])
        return data

    # Checks
    for k in "train", "val":
        if k not in data:
            if k != "val" or "validation" not in data:
                raise SyntaxError(
                    emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs.")
                )
            LOGGER.info("WARNING ⚠️ renaming data YAML 'validation' key to 'val' to match YOLO format.")
            data["val"] = data.pop("validation")  # replace 'validation' key with 'val' key
    if "names" not in data and "nc" not in data:
        raise SyntaxError(emojis(f"{dataset} key missing ❌.\n either 'names' or 'nc' are required in all data YAMLs."))
    if "names" in data and "nc" in data and len(data["names"]) != data["nc"]:
        raise SyntaxError(emojis(f"{dataset} 'names' length {len(data['names'])} and 'nc: {data['nc']}' must match."))
    if "names" not in data:
        data["names"] = [f"class_{i}" for i in range(data["nc"])]
    else:
        data["nc"] = len(data["names"])
    data["names"] = check_class_names(data["names"])

    # Resolve paths
    path = Path(extract_dir or data.get("path") or Path(data.get("yaml_file", "")).parent)  # dataset root
    if not path.is_absolute():
        path = (DATASETS_DIR / path).resolve()

    # Set paths
    data["path"] = path  # download scripts
    for k in "train", "val", "test", "minival":
        if data.get(k):  # prepend path
            if isinstance(data[k], str):
                x = (path / data[k]).resolve()
                if not x.exists() and data[k].startswith("../"):
                    x = (path / data[k][3:]).resolve()
                data[k] = str(x)
            else:
                data[k] = [str((path / x).resolve()) for x in data[k]]

    # Parse YAML
    val, s = (data.get(x) for x in ("val", "download"))
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            name = clean_url(dataset)  # dataset name with URL auth stripped
            m = f"\nDataset '{name}' images not found ⚠️, missing path '{[x for x in val if not x.exists()][0]}'"
            if s and autodownload:
                LOGGER.warning(m)
            else:
                m += f"\nNote dataset download directory is '{DATASETS_DIR}'. You can update this in '{SETTINGS_YAML}'"
                raise FileNotFoundError(m)
            t = time.time()
            r = None  # success
            if s.startswith("http") and s.endswith(".zip"):  # URL
                safe_download(url=s, dir=DATASETS_DIR, delete=True)
            elif s.startswith("bash "):  # bash script
                LOGGER.info(f"Running {s} ...")
                r = os.system(s)
            else:  # python script
                exec(s, {"yaml": data})
            dt = f"({round(time.time() - t, 1)}s)"
            s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in {0, None} else f"failure {dt} ❌"
            LOGGER.info(f"Dataset download {s}\n")
    check_font("Arial.ttf" if is_ascii(data["names"]) else "Arial.Unicode.ttf")  # download fonts

    return data  # dictionary
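
# Illustrative sketch (not part of the original module): a minimal detection data YAML that
# check_det_dataset can parse; the file name, paths, and class names are hypothetical, and the
# listed directories are assumed to exist under DATASETS_DIR.
# >>> # my_data.yaml contains:
# >>> #   path: my_dataset
# >>> #   train: images/train
# >>> #   val: images/val
# >>> #   names: {0: person, 1: car}
# >>> data = check_det_dataset("my_data.yaml", autodownload=False)
# >>> data["nc"], data["names"][0], Path(data["train"]).is_absolute()
# (2, 'person', True)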

def check_cls_dataset(dataset, split=""):
    """
    Checks a classification dataset such as Imagenet.

    This function accepts a `dataset` name and attempts to retrieve the corresponding dataset information.
    If the dataset is not found locally, it attempts to download the dataset from the internet and save it locally.

    Args:
        dataset (str | Path): The name of the dataset.
        split (str, optional): The split of the dataset. Either 'val', 'test', or ''. Defaults to ''.

    Returns:
        (dict): A dictionary containing the following keys:
            - 'train' (Path): The directory path containing the training set of the dataset.
            - 'val' (Path): The directory path containing the validation set of the dataset.
            - 'test' (Path): The directory path containing the test set of the dataset.
            - 'nc' (int): The number of classes in the dataset.
            - 'names' (dict): A dictionary of class names in the dataset.
    """
    # Download (optional if dataset=https://file.zip is passed directly)
    if str(dataset).startswith(("http:/", "https:/")):
        dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False)
    elif Path(dataset).suffix in {".zip", ".tar", ".gz"}:
        file = check_file(dataset)
        dataset = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)

    dataset = Path(dataset)
    data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve()
    if not data_dir.is_dir():
        LOGGER.warning(f"\nDataset not found ⚠️, missing path {data_dir}, attempting download...")
        t = time.time()
        if str(dataset) == "imagenet":
            subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
        else:
            url = f"https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip"
            download(url, dir=data_dir.parent)
        s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
        LOGGER.info(s)
    train_set = data_dir / "train"
    val_set = (
        data_dir / "val"
        if (data_dir / "val").exists()
        else data_dir / "validation"
        if (data_dir / "validation").exists()
        else None
    )  # data/val or data/validation
    test_set = data_dir / "test" if (data_dir / "test").exists() else None  # data/test if present
    if split == "val" and not val_set:
        LOGGER.warning("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.")
    elif split == "test" and not test_set:
        LOGGER.warning("WARNING ⚠️ Dataset 'split=test' not found, using 'split=val' instead.")

    nc = len([x for x in (data_dir / "train").glob("*") if x.is_dir()])  # number of classes
    names = [x.name for x in (data_dir / "train").iterdir() if x.is_dir()]  # class names list
    names = dict(enumerate(sorted(names)))

    # Print to console
    for k, v in {"train": train_set, "val": val_set, "test": test_set}.items():
        prefix = f'{colorstr(f"{k}:")} {v}...'
        if v is None:
            LOGGER.info(prefix)
        else:
            files = [path for path in v.rglob("*.*") if path.suffix[1:].lower() in IMG_FORMATS]
            nf = len(files)  # number of files
            nd = len({file.parent for file in files})  # number of directories
            if nf == 0:
                if k == "train":
                    raise FileNotFoundError(emojis(f"{dataset} '{k}:' no training images found ❌ "))
                else:
                    LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: WARNING ⚠️ no images found")
            elif nd != nc:
                LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: ERROR ❌️ requires {nc} classes, not {nd}")
            else:
                LOGGER.info(f"{prefix} found {nf} images in {nd} classes ✅ ")

    return {"train": train_set, "val": val_set, "test": test_set, "nc": nc, "names": names}
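
# Illustrative sketch (not part of the original module): check_cls_dataset expects a torchvision
# ImageFolder-style layout with one sub-directory per class; the dataset name and classes below are
# hypothetical and assumed to exist under DATASETS_DIR.
# >>> #   my_cls_dataset/train/cats/*.jpg, my_cls_dataset/train/dogs/*.jpg
# >>> #   my_cls_dataset/val/cats/*.jpg,   my_cls_dataset/val/dogs/*.jpg
# >>> info = check_cls_dataset("my_cls_dataset")
# >>> info["nc"], info["names"]
# (2, {0: 'cats', 1: 'dogs'})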

class HUBDatasetStats:
    """
    A class for generating HUB dataset JSON and `-hub` dataset directory.

    Args:
        path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip). Default is 'coco8.yaml'.
        task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify'. Default is 'detect'.
        autodownload (bool): Attempt to download dataset if not found locally. Default is False.

    Example:
        Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets
        i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.
        ```python
        from ultralytics.data.utils import HUBDatasetStats

        stats = HUBDatasetStats('path/to/coco8.zip', task='detect')  # detect dataset
        stats = HUBDatasetStats('path/to/coco8-seg.zip', task='segment')  # segment dataset
        stats = HUBDatasetStats('path/to/coco8-pose.zip', task='pose')  # pose dataset
        stats = HUBDatasetStats('path/to/dota8.zip', task='obb')  # OBB dataset
        stats = HUBDatasetStats('path/to/imagenet10.zip', task='classify')  # classification dataset

        stats.get_json(save=True)
        stats.process_images()
        ```
    """

    def __init__(self, path="coco8.yaml", task="detect", autodownload=False):
        """Initialize class."""
        path = Path(path).resolve()
        LOGGER.info(f"Starting HUB dataset checks for {path}....")

        self.task = task  # detect, segment, pose, classify
        if self.task == "classify":
            unzip_dir = unzip_file(path)
            data = check_cls_dataset(unzip_dir)
            data["path"] = unzip_dir
        else:  # detect, segment, pose
            _, data_dir, yaml_path = self._unzip(Path(path))
            try:
                # Load YAML with checks
                data = yaml_load(yaml_path)
                data["path"] = ""  # strip path since YAML should be in dataset root for all HUB datasets
                yaml_save(yaml_path, data)
                data = check_det_dataset(yaml_path, autodownload=autodownload)  # dict
                data["path"] = data_dir  # YAML path should be set to '' (relative) or parent (absolute)
            except Exception as e:
                raise Exception("error/HUB/dataset_stats/init") from e

        self.hub_dir = Path(f'{data["path"]}-hub')
        self.im_dir = self.hub_dir / "images"
        self.stats = {"nc": len(data["names"]), "names": list(data["names"].values())}  # statistics dictionary
        self.data = data

    @staticmethod
    def _unzip(path):
        """Unzip data.zip."""
        if not str(path).endswith(".zip"):  # path is data.yaml
            return False, None, path
        unzip_dir = unzip_file(path, path=path.parent)
        assert unzip_dir.is_dir(), (
            f"Error unzipping {path}, {unzip_dir} not found. " f"path/to/abc.zip MUST unzip to path/to/abc/"
        )
        return True, str(unzip_dir), find_dataset_yaml(unzip_dir)  # zipped, data_dir, yaml_path

    def _hub_ops(self, f):
        """Saves a compressed image for HUB previews."""
        compress_one_image(f, self.im_dir / Path(f).name)  # save to dataset-hub

    def get_json(self, save=False, verbose=False):
        """Return dataset JSON for Ultralytics HUB."""

        def _round(labels):
            """Update labels to integer class and 4 decimal place floats."""
            if self.task == "detect":
                coordinates = labels["bboxes"]
            elif self.task in {"segment", "obb"}:  # Segment and OBB use segments. OBB segments are normalized xyxyxyxy
                coordinates = [x.flatten() for x in labels["segments"]]
            elif self.task == "pose":
                n, nk, nd = labels["keypoints"].shape
                coordinates = np.concatenate((labels["bboxes"], labels["keypoints"].reshape(n, nk * nd)), 1)
            else:
                raise ValueError(f"Undefined dataset task={self.task}.")
            zipped = zip(labels["cls"], coordinates)
            return [[int(c[0]), *(round(float(x), 4) for x in points)] for c, points in zipped]

        for split in "train", "val", "test":
            self.stats[split] = None  # predefine
            path = self.data.get(split)

            # Check split
            if path is None:  # no split
                continue
            files = [f for f in Path(path).rglob("*.*") if f.suffix[1:].lower() in IMG_FORMATS]  # image files in split
            if not files:  # no images
                continue

            # Get dataset statistics
            if self.task == "classify":
                from torchvision.datasets import ImageFolder

                dataset = ImageFolder(self.data[split])

                x = np.zeros(len(dataset.classes)).astype(int)
                for im in dataset.imgs:
                    x[im[1]] += 1

                self.stats[split] = {
                    "instance_stats": {"total": len(dataset), "per_class": x.tolist()},
                    "image_stats": {"total": len(dataset), "unlabelled": 0, "per_class": x.tolist()},
                    "labels": [{Path(k).name: v} for k, v in dataset.imgs],
                }
            else:
                from ultralytics.data import YOLODataset

                dataset = YOLODataset(img_path=self.data[split], data=self.data, task=self.task)
                x = np.array(
                    [
                        np.bincount(label["cls"].astype(int).flatten(), minlength=self.data["nc"])
                        for label in TQDM(dataset.labels, total=len(dataset), desc="Statistics")
                    ]
                )  # shape(128x80)
                self.stats[split] = {
                    "instance_stats": {"total": int(x.sum()), "per_class": x.sum(0).tolist()},
                    "image_stats": {
                        "total": len(dataset),
                        "unlabelled": int(np.all(x == 0, 1).sum()),
                        "per_class": (x > 0).sum(0).tolist(),
                    },
                    "labels": [{Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)],
                }

        # Save, print and return
        if save:
            self.hub_dir.mkdir(parents=True, exist_ok=True)  # makes dataset-hub/
            stats_path = self.hub_dir / "stats.json"
            LOGGER.info(f"Saving {stats_path.resolve()}...")
            with open(stats_path, "w") as f:
                json.dump(self.stats, f)  # save stats.json
        if verbose:
            LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False))
        return self.stats

    def process_images(self):
        """Compress images for Ultralytics HUB."""
        from ultralytics.data import YOLODataset  # ClassificationDataset

        self.im_dir.mkdir(parents=True, exist_ok=True)  # makes dataset-hub/images/
        for split in "train", "val", "test":
            if self.data.get(split) is None:
                continue
            dataset = YOLODataset(img_path=self.data[split], data=self.data)
            with ThreadPool(NUM_THREADS) as pool:
                for _ in TQDM(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f"{split} images"):
                    pass
        LOGGER.info(f"Done. All images saved to {self.im_dir}")
        return self.im_dir

def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
    """
    Compresses a single image file to reduced size while preserving its aspect ratio and quality using either the
    Python Imaging Library (PIL) or the OpenCV library. If the input image is smaller than the maximum dimension,
    it will not be resized.

    Args:
        f (str): The path to the input image file.
        f_new (str, optional): The path to the output image file. If not specified, the input file will be overwritten.
        max_dim (int, optional): The maximum dimension (width or height) of the output image. Default is 1920 pixels.
        quality (int, optional): The image compression quality as a percentage. Default is 50%.

    Example:
        ```python
        from pathlib import Path

        from ultralytics.data.utils import compress_one_image

        for f in Path('path/to/dataset').rglob('*.jpg'):
            compress_one_image(f)
        ```
    """
    try:  # use PIL
        im = Image.open(f)
        r = max_dim / max(im.height, im.width)  # ratio
        if r < 1.0:  # image too large
            im = im.resize((int(im.width * r), int(im.height * r)))
        im.save(f_new or f, "JPEG", quality=quality, optimize=True)  # save
    except Exception as e:  # use OpenCV
        LOGGER.info(f"WARNING ⚠️ HUB ops PIL failure {f}: {e}")
        im = cv2.imread(f)
        im_height, im_width = im.shape[:2]
        r = max_dim / max(im_height, im_width)  # ratio
        if r < 1.0:  # image too large
            im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
        cv2.imwrite(str(f_new or f), im)

def autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annotated_only=False):
    """
    Automatically split a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files.

    Args:
        path (Path, optional): Path to images directory. Defaults to DATASETS_DIR / 'coco8/images'.
        weights (list | tuple, optional): Train, validation, and test split fractions. Defaults to (0.9, 0.1, 0.0).
        annotated_only (bool, optional): If True, only images with an associated txt file are used. Defaults to False.

    Example:
        ```python
        from ultralytics.data.utils import autosplit

        autosplit()
        ```
    """
    path = Path(path)  # images dir
    files = sorted(x for x in path.rglob("*.*") if x.suffix[1:].lower() in IMG_FORMATS)  # image files only
    n = len(files)  # number of files
    random.seed(0)  # for reproducibility
    indices = random.choices([0, 1, 2], weights=weights, k=n)  # assign each image to a split

    txt = ["autosplit_train.txt", "autosplit_val.txt", "autosplit_test.txt"]  # 3 txt files
    for x in txt:
        if (path.parent / x).exists():
            (path.parent / x).unlink()  # remove existing

    LOGGER.info(f"Autosplitting images from {path}" + ", using *.txt labeled images only" * annotated_only)
    for i, img in TQDM(zip(indices, files), total=n):
        if not annotated_only or Path(img2label_paths([str(img)])[0]).exists():  # check label
            with open(path.parent / txt[i], "a") as f:
                f.write(f"./{img.relative_to(path.parent).as_posix()}" + "\n")  # add image to txt file

def load_dataset_cache_file(path):
    """Load an Ultralytics *.cache dictionary from path."""
    import gc

    gc.disable()  # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
    cache = np.load(str(path), allow_pickle=True).item()  # load dict
    gc.enable()
    return cache

def save_dataset_cache_file(prefix, path, x, version):
    """Save an Ultralytics dataset *.cache dictionary x to path."""
    x["version"] = version  # add cache version
    if is_dir_writeable(path.parent):
        if path.exists():
            path.unlink()  # remove *.cache file if exists
        np.save(str(path), x)  # save cache for next time
        path.with_suffix(".cache.npy").rename(path)  # remove .npy suffix
        LOGGER.info(f"{prefix}New cache created: {path}")
    else:
        LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")
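
# Illustrative sketch (not part of the original module): the cache is a plain dict pickled with
# np.save (which appends '.npy') and then renamed back to '*.cache'; load_dataset_cache_file
# reverses the round trip. The prefix, path, and contents below are hypothetical.
# >>> cache_path = Path("/data/coco8/labels/train.cache")
# >>> save_dataset_cache_file("train: ", cache_path, {"labels": [], "hash": "abc"}, version="1.0")
# >>> load_dataset_cache_file(cache_path)["version"]
# '1.0'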