123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546 |
- # Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
import json
from collections import defaultdict
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path

import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import ConcatDataset

from ultralytics.utils import LOCAL_RANK, LOGGER, NUM_THREADS, TQDM, colorstr
from ultralytics.utils.ops import resample_segments

from .augment import (
    Compose,
    Format,
    Instances,
    LetterBox,
    RandomLoadText,
    classify_augmentations,
    classify_transforms,
    v8_transforms,
)
from .base import BaseDataset
from .utils import (
    HELP_URL,
    get_hash,
    img2label_paths,
    load_dataset_cache_file,
    save_dataset_cache_file,
    verify_image,
    verify_image_label,
)
from ultralytics.data.datagenerate import DataGenerator

# Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
DATASET_CACHE_VERSION = "1.0.3"
class YOLODataset(BaseDataset):
    """
    Dataset class for loading object detection and/or segmentation labels in YOLO format.

    Labels are read from YOLO-format label files (validated once and stored in a ``*.cache`` file, see
    ``get_labels``/``cache_labels``) or, when training on the platform, produced by ``DataGenerator``
    (see ``get_img_files_and_labels``).

    Args:
        task (str): An explicit arg to point current task, Defaults to 'detect'. 'segment', 'pose' and 'obb' enable
            segment, keypoint and oriented-box handling respectively; any other value behaves as plain detection.

    Returns:
        (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
    """

    def __init__(self, *args, task="detect", **kwargs):
        """Initializes the YOLODataset with optional configurations for segments and keypoints."""
        # Task flags are set before super().__init__() so they are available to anything the base initializer
        # calls (presumably get_labels/build_transforms — confirm in BaseDataset).
        self.use_segments = task == "segment"
        self.use_keypoints = task == "pose"
        self.use_obb = task == "obb"
        assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
        super().__init__(*args, **kwargs)

    def cache_labels(self, path=Path("./labels.cache")):
        """
        Cache dataset labels, check images and read shapes.

        Every (image, label file) pair is checked in a thread pool with ``verify_image_label``; the resulting label
        dicts, a hash of the file lists, summary counts and messages are saved to ``path``.

        Args:
            path (Path): Path where to save the cache file. Default is Path('./labels.cache').

        Returns:
            (dict): The cache dict with keys "labels", "hash", "results" (found, missing, empty, corrupt, total) and
                "msgs", plus whatever ``save_dataset_cache_file`` adds (presumably "version").

        Raises:
            ValueError: If keypoints are used and ``kpt_shape`` in the data config is missing or invalid.
        """
        x = {"labels": []}
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
        desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
        total = len(self.im_files)
        nkpt, ndim = self.data.get("kpt_shape", (0, 0))
        if self.use_keypoints and (nkpt <= 0 or ndim not in {2, 3}):
            raise ValueError(
                "'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
                "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'"
            )
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(
                func=verify_image_label,
                iterable=zip(
                    self.im_files,
                    self.label_files,
                    repeat(self.prefix),
                    repeat(self.use_keypoints),
                    repeat(len(self.data["names"])),
                    repeat(nkpt),
                    repeat(ndim),
                ),
            )
            pbar = TQDM(results, desc=desc, total=total)
            for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:  # only images that verify_image_label returned a file name for are kept
                    x["labels"].append(
                        {
                            "im_file": im_file,
                            "shape": shape,
                            "cls": lb[:, 0:1],  # n, 1
                            "bboxes": lb[:, 1:],  # n, 4
                            "segments": segments,
                            "keypoints": keypoint,
                            "normalized": True,
                            "bbox_format": "xywh",
                        }
                    )
                if msg:
                    msgs.append(msg)
                pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            pbar.close()
        if msgs:
            LOGGER.info("\n".join(msgs))
        if nf == 0:
            LOGGER.warning(f"{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}")
        x["hash"] = get_hash(self.label_files + self.im_files)
        x["results"] = nf, nm, ne, nc, len(self.im_files)
        x["msgs"] = msgs  # warnings
        save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
        return x

    def get_labels(self):
        """
        Returns the list of label dicts for YOLO training, reusing the ``*.cache`` file when it is still valid.

        The cache is reused only when its version equals ``DATASET_CACHE_VERSION`` and its hash matches the current
        label/image file lists; otherwise ``cache_labels`` rebuilds it. Also updates ``self.label_files`` and
        ``self.im_files``, and drops all segments if box and segment counts disagree.

        Returns:
            (list[dict]): One label dict per verified image.
        """
        self.label_files = img2label_paths(self.im_files)
        cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
        try:
            cache, exists = load_dataset_cache_file(cache_path), True  # attempt to load a *.cache file
            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
            assert cache["hash"] == get_hash(self.label_files + self.im_files)  # identical hash
        except (FileNotFoundError, AssertionError, AttributeError):
            cache, exists = self.cache_labels(cache_path), False  # run cache ops

        # Display cache
        nf, nm, ne, nc, n = cache.pop("results")  # found, missing, empty, corrupt, total
        if exists and LOCAL_RANK in {-1, 0}:
            d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            TQDM(None, desc=self.prefix + d, total=n, initial=n)  # display results
            if cache["msgs"]:
                LOGGER.info("\n".join(cache["msgs"]))  # display warnings

        # Read cache: drop bookkeeping entries so only the labels remain
        for key in ("hash", "version", "msgs"):
            cache.pop(key, None)
        labels = cache["labels"]
        if not labels:
            LOGGER.warning(f"WARNING ⚠️ No images found in {cache_path}, training may not work correctly. {HELP_URL}")
        self.im_files = [lb["im_file"] for lb in labels]  # update im_files

        # Check if the dataset is all boxes or all segments. Summed per field (instead of unpacking zip(*...))
        # so that an empty label list yields zeros rather than raising ValueError.
        len_cls = sum(len(lb["cls"]) for lb in labels)
        len_boxes = sum(len(lb["bboxes"]) for lb in labels)
        len_segments = sum(len(lb["segments"]) for lb in labels)
        if len_segments and len_boxes != len_segments:
            LOGGER.warning(
                f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, "
                f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. "
                "To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset."
            )
            for lb in labels:
                lb["segments"] = []
        if len_cls == 0:
            LOGGER.warning(f"WARNING ⚠️ No labels found in {cache_path}, training may not work correctly. {HELP_URL}")
        return labels

    def get_img_files_and_labels(self):
        """
        Fetch images and labels from ``DataGenerator`` when training on the platform.

        Returns:
            (tuple | None): ``(im_files, all_labels, crop_boxes, crop_contours)`` when ``self.is_train_on_platform``
                is truthy, otherwise None.
        """
        if not self.is_train_on_platform:
            return None

        args = self.platform_data_args
        labels, im_files, shapes, segments, crop_boxes, crop_contours = DataGenerator(
            self.token,
            self.train_or_val_data,
            args["wrong_file"],
            args["needed_image_results_dict"],
            args["needed_rois_dict"],
            args["label_aug_level"],
            args["class_aug_times"],
            self.data["names"],
            args["data_type"],
            self.extra_contours_args,
        ).generate()

        # Keypoints are not supported yet; DataGenerator could later be extended to provide them.
        all_labels = []
        for idx, im_file in enumerate(im_files):
            lb = labels[idx]
            all_labels.append(
                {
                    "im_file": im_file,
                    "shape": shapes[idx],
                    "cls": lb[:, 0:1],  # n, 1
                    "bboxes": lb[:, 1:],  # n, 4
                    "segments": segments[idx],
                    "keypoints": None,
                    "normalized": True,
                    "bbox_format": "xywh",
                }
            )
        return im_files, all_labels, crop_boxes, crop_contours

    def build_transforms(self, hyp=None):
        """
        Builds the transform pipeline and appends the final ``Format`` step.

        With augmentation, mosaic and mixup are disabled for rectangular batches and ``v8_transforms`` is used;
        without it, only a ``LetterBox`` resize is applied. ``hyp`` must provide ``mask_ratio`` and
        ``overlap_mask`` in both cases.
        """
        if self.augment:
            # Mosaic/mixup are incompatible with rectangular (rect) batches.
            hyp.mosaic = hyp.mosaic if not self.rect else 0.0
            hyp.mixup = hyp.mixup if not self.rect else 0.0
            transforms = v8_transforms(self, self.imgsz, hyp)
        else:
            transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
        transforms.append(
            Format(
                bbox_format="xywh",
                normalize=True,
                return_mask=self.use_segments,
                return_keypoint=self.use_keypoints,
                return_obb=self.use_obb,
                batch_idx=True,
                mask_ratio=hyp.mask_ratio,
                mask_overlap=hyp.overlap_mask,
                bgr=hyp.bgr if self.augment else 0.0,  # only affect training.
            )
        )
        return transforms

    def close_mosaic(self, hyp):
        """Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations."""
        hyp.mosaic = 0.0  # set mosaic ratio=0.0
        hyp.copy_paste = 0.0  # keep the same behavior as previous v8 close-mosaic
        hyp.mixup = 0.0  # keep the same behavior as previous v8 close-mosaic
        self.transforms = self.build_transforms(hyp)

    def update_labels_info(self, label):
        """
        Custom your label format here.

        Pops "bboxes", "segments", "keypoints", "bbox_format" and "normalized" from ``label`` and replaces them with a
        single ``Instances`` object under "instances". Segments are resampled to a fixed number of points
        (100 for OBB, 1000 otherwise) so they can be stacked into one array.

        Note:
            cls is not with bboxes now, classification and semantic segmentation need an independent cls label
            Can also support classification and semantic segmentation by adding or removing dict keys there.
        """
        bboxes = label.pop("bboxes")
        segments = label.pop("segments", [])
        keypoints = label.pop("keypoints", None)
        bbox_format = label.pop("bbox_format")
        normalized = label.pop("normalized")

        # NOTE: do NOT resample oriented boxes
        segment_resamples = 100 if self.use_obb else 1000
        if len(segments) > 0:
            # list[np.array(1000, 2)] * num_samples
            # (N, 1000, 2)
            segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0)
        else:
            segments = np.zeros((0, segment_resamples, 2), dtype=np.float32)
        label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
        return label

    @staticmethod
    def collate_fn(batch):
        """
        Collates data samples into batches.

        "img" tensors are stacked; "masks", "keypoints", "bboxes", "cls", "segments" and "obb" are concatenated;
        other keys are kept as tuples. Each sample's "batch_idx" is offset by its position in the batch before
        concatenation.
        """
        new_batch = {}
        keys = batch[0].keys()
        values = list(zip(*[list(b.values()) for b in batch]))
        for i, k in enumerate(keys):
            value = values[i]
            if k == "img":
                value = torch.stack(value, 0)
            if k in {"masks", "keypoints", "bboxes", "cls", "segments", "obb"}:
                value = torch.cat(value, 0)
            new_batch[k] = value
        new_batch["batch_idx"] = list(new_batch["batch_idx"])
        for i in range(len(new_batch["batch_idx"])):
            new_batch["batch_idx"][i] += i  # add target image index for build_targets()
        new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
        return new_batch
class YOLOMultiModalDataset(YOLODataset):
    """
    YOLO-format detection/segmentation dataset that also attaches class-name texts, for multi-modal training.

    Args:
        data (dict, optional): A dataset YAML dictionary. Defaults to None. Its "names" mapping supplies the class
            texts, and "nc" caps the number of sampled texts when augmenting.
        task (str): An explicit arg to point current task, Defaults to 'detect'.

    Returns:
        (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
    """

    def __init__(self, *args, data=None, task="detect", **kwargs):
        """Initializes a dataset object for object detection tasks with optional specifications."""
        super().__init__(*args, data=data, task=task, **kwargs)

    def update_labels_info(self, label):
        """Add texts information for multi modal model training."""
        labels = super().update_labels_info(label)
        # NOTE: some categories are concatenated with its synonyms by `/`.
        labels["texts"] = [name.split("/") for name in self.data["names"].values()]
        return labels

    def build_transforms(self, hyp=None):
        """Enhances data transformations with optional text augmentation for multi-modal training."""
        transforms = super().build_transforms(hyp)
        if self.augment:
            # NOTE: hard-coded the args for now. Inserted before the final Format step.
            transforms.insert(-1, RandomLoadText(max_samples=min(self.data["nc"], 80), padding=True))
        return transforms
class GroundingDataset(YOLODataset):
    """
    Detection dataset for grounding data, where class names are free-text phrases taken from image captions.

    Annotations come from one JSON file with "images" and "annotations" lists. Each annotation's
    "tokens_positive" character spans are cut out of the image's "caption" to form the phrase used as its class,
    so class ids are local to each image and the phrases are stored under "texts". Only the 'detect' task is
    supported.
    """

    def __init__(self, *args, task="detect", json_file, **kwargs):
        """Initializes a GroundingDataset for object detection, loading annotations from a specified JSON file."""
        assert task == "detect", "`GroundingDataset` only support `detect` task for now!"
        self.json_file = json_file
        super().__init__(*args, task=task, data={}, **kwargs)

    def get_img_files(self, img_path):
        """The image files would be read in `get_labels` function, return empty list here."""
        return []

    def get_labels(self):
        """
        Loads annotations from a JSON file, filters, and normalizes bounding boxes for each image.

        Images whose file does not exist under ``self.img_path`` are skipped; crowd annotations and boxes with
        non-positive width/height are dropped, and duplicate (class, box) rows within an image are removed.
        Appends each kept image path to ``self.im_files``.

        Returns:
            (list[dict]): One label dict per kept image with normalized xywh boxes and per-image "texts".
        """
        labels = []
        LOGGER.info("Loading annotation file...")
        with open(self.json_file, "r", encoding="utf-8") as fp:
            annotations = json.load(fp)
        images = {f'{x["id"]:d}': x for x in annotations["images"]}
        img_to_anns = defaultdict(list)
        for ann in annotations["annotations"]:
            img_to_anns[ann["image_id"]].append(ann)
        for img_id, anns in TQDM(img_to_anns.items(), desc=f"Reading annotations {self.json_file}"):
            img = images[f"{img_id:d}"]
            h, w, file_name = img["height"], img["width"], img["file_name"]
            im_file = str(Path(self.img_path) / file_name)  # str, consistent with self.im_files and YOLODataset
            if not Path(im_file).exists():
                continue
            self.im_files.append(im_file)
            bboxes = []
            cat2id = {}
            texts = []
            for ann in anns:
                if ann.get("iscrowd"):  # treat a missing flag as "not crowd"
                    continue
                box = np.array(ann["bbox"], dtype=np.float32)
                box[:2] += box[2:] / 2  # top-left xy -> center xy
                box[[0, 2]] /= float(w)
                box[[1, 3]] /= float(h)
                if box[2] <= 0 or box[3] <= 0:
                    continue

                cat_name = " ".join([img["caption"][t[0] : t[1]] for t in ann["tokens_positive"]])
                if cat_name not in cat2id:
                    cat2id[cat_name] = len(cat2id)
                    texts.append([cat_name])
                cls = cat2id[cat_name]  # class
                box = [cls] + box.tolist()
                if box not in bboxes:
                    bboxes.append(box)
            lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)
            labels.append(
                {
                    "im_file": im_file,
                    "shape": (h, w),
                    "cls": lb[:, 0:1],  # n, 1
                    "bboxes": lb[:, 1:],  # n, 4
                    "normalized": True,
                    "bbox_format": "xywh",
                    "texts": texts,
                }
            )
        return labels

    def build_transforms(self, hyp=None):
        """Configures augmentations for training with optional text loading; `hyp` adjusts augmentation intensity."""
        transforms = super().build_transforms(hyp)
        if self.augment:
            # NOTE: hard-coded the args for now. Inserted before the final Format step.
            transforms.insert(-1, RandomLoadText(max_samples=80, padding=True))
        return transforms
class YOLOConcatDataset(ConcatDataset):
    """
    Several YOLO datasets joined end to end into a single dataset.

    Handy for training on multiple existing datasets at once; batches are collated exactly as a single
    YOLODataset would collate them.
    """

    @staticmethod
    def collate_fn(batch):
        """Merge a list of samples into one batch by delegating to ``YOLODataset.collate_fn``."""
        collate = YOLODataset.collate_fn
        return collate(batch)
- # TODO: support semantic segmentation
class SemanticDataset(BaseDataset):
    """
    Placeholder dataset for semantic segmentation.

    It only inherits from BaseDataset and adds nothing of its own yet; methods and attributes needed for semantic
    segmentation tasks still have to be implemented.
    """

    def __init__(self):
        """Create a SemanticDataset by calling the parent initializer with no arguments."""
        super(SemanticDataset, self).__init__()
class ClassificationDataset:
    """
    Dataset for YOLO classification tasks built on torchvision's ImageFolder directory layout.

    A torchvision ``ImageFolder`` is wrapped (kept as ``self.base``) rather than subclassed, so the slow torchvision
    import only happens inside ``__init__``. Images are verified once and the verification results are cached in a
    ``*.cache`` file at ``Path(root).with_suffix('.cache')``. Decoded images can optionally be cached in RAM or as
    uncompressed ``*.npy`` files on disk to reduce IO during training. Training-time augmentation uses
    ``classify_augmentations``; otherwise ``classify_transforms`` is applied.

    Attributes:
        cache_ram (bool): Indicates if caching in RAM is enabled.
        cache_disk (bool): Indicates if caching on disk is enabled.
        samples (list): A list of ``[image path, class index, .npy cache path, cached image or None]`` entries.
        torch_transforms (callable): PyTorch transforms to be applied to the images.
    """

    def __init__(self, root, args, augment=False, prefix=""):
        """
        Initialize YOLO object with root, image size, augmentations, and cache settings.

        Args:
            root (str): Path to the dataset directory where images are stored in a class-specific folder structure.
            args (Namespace): Configuration containing dataset-related settings such as image size, augmentation
                parameters, and cache settings. It includes attributes like `imgsz` (image size), `fraction` (fraction
                of data to use), `scale`, `fliplr`, `flipud`, `erasing`, `cache` (disk or RAM caching for faster
                training), `auto_augment`, `hsv_h`, `hsv_s`, `hsv_v`, and `crop_fraction`.
            augment (bool, optional): Whether to apply augmentations to the dataset. Default is False.
            prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification and
                debugging. Default is an empty string.
        """
        import torchvision  # scope for faster 'import ultralytics'

        # Base class assigned as attribute rather than used as base class to allow for scoping slow torchvision import
        self.base = torchvision.datasets.ImageFolder(root=root)
        self.samples = self.base.samples
        self.root = self.base.root

        # Initialize attributes
        if augment and args.fraction < 1.0:  # reduce training fraction (training only)
            self.samples = self.samples[: round(len(self.samples) * args.fraction)]
        self.prefix = colorstr(f"{prefix}: ") if prefix else ""
        self.cache_ram = args.cache is True or str(args.cache).lower() == "ram"  # cache images into RAM
        self.cache_disk = str(args.cache).lower() == "disk"  # cache images on hard drive as uncompressed *.npy files
        self.samples = self.verify_images()  # filter out bad images
        self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples]  # file, index, npy, im
        scale = (1.0 - args.scale, 1.0)  # (0.08, 1.0)
        self.torch_transforms = (
            classify_augmentations(
                size=args.imgsz,
                scale=scale,
                hflip=args.fliplr,
                vflip=args.flipud,
                erasing=args.erasing,
                auto_augment=args.auto_augment,
                hsv_h=args.hsv_h,
                hsv_s=args.hsv_s,
                hsv_v=args.hsv_v,
            )
            if augment
            else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction)
        )

    def __getitem__(self, i):
        """
        Return the transformed image and class index for sample ``i``.

        The image is read from the RAM cache, the ``.npy`` disk cache (created on first access), or directly from
        disk, converted from BGR to an RGB PIL image and passed through ``self.torch_transforms``.

        Returns:
            (dict): ``{"img": transformed image, "cls": class index}``.
        """
        f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), image
        if self.cache_ram:
            if im is None:  # Warning: two separate if statements required here, do not combine this with previous line
                im = self.samples[i][3] = cv2.imread(f)
        elif self.cache_disk:
            if not fn.exists():  # load npy
                np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)
            im = np.load(fn)
        else:  # read image
            im = cv2.imread(f)  # BGR
        # Convert NumPy array to PIL image
        im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
        sample = self.torch_transforms(im)
        return {"img": sample, "cls": j}

    def __len__(self) -> int:
        """Return the total number of samples in the dataset."""
        return len(self.samples)

    def verify_images(self):
        """
        Verify all images in the dataset, reusing the ``*.cache`` file when its version and hash still match.

        Returns:
            (list): The samples that passed verification.
        """
        desc = f"{self.prefix}Scanning {self.root}..."
        path = Path(self.root).with_suffix(".cache")  # *.cache file path
        # Hash of all sample paths; computed once and used both to validate and to write the cache.
        samples_hash = get_hash([sample[0] for sample in self.samples])

        with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
            cache = load_dataset_cache_file(path)  # attempt to load a *.cache file
            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
            assert cache["hash"] == samples_hash  # identical hash
            nf, nc, n, samples = cache.pop("results")  # found, corrupt, total, verified samples
            if LOCAL_RANK in {-1, 0}:
                d = f"{desc} {nf} images, {nc} corrupt"
                TQDM(None, desc=d, total=n, initial=n)
                if cache["msgs"]:
                    LOGGER.info("\n".join(cache["msgs"]))  # display warnings
            return samples

        # Run scan if *.cache retrieval failed
        nf, nc, msgs, samples, x = 0, 0, [], [], {}
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
            pbar = TQDM(results, desc=desc, total=len(self.samples))
            for sample, nf_f, nc_f, msg in pbar:
                if nf_f:
                    samples.append(sample)
                if msg:
                    msgs.append(msg)
                nf += nf_f
                nc += nc_f
                pbar.desc = f"{desc} {nf} images, {nc} corrupt"
            pbar.close()
        if msgs:
            LOGGER.info("\n".join(msgs))
        x["hash"] = samples_hash
        x["results"] = nf, nc, len(samples), samples
        x["msgs"] = msgs  # warnings
        save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
        return samples