|
- # Ultralytics YOLO 🚀, AGPL-3.0 license
- import glob
- import math
- import os
- import random
- from copy import deepcopy
- from multiprocessing.pool import ThreadPool
- from pathlib import Path
- from typing import Optional
- import cv2
- import numpy as np
- import psutil
- from torch.utils.data import Dataset
- from ultralytics.data.utils import FORMATS_HELP_MSG, HELP_URL, IMG_FORMATS
- from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM
- from ultralytics.data.datagenerate import get_image_remove_extra_contour
- import ultralytics.trainsdk.TrainSdk as TrainSdk
class BaseDataset(Dataset):
    """
    Base dataset class for loading and processing image data.

    Images come either from the local filesystem (``get_img_files`` + ``get_labels``) or, when
    ``is_train_on_platform`` is True, from the Vinno training platform through ``TrainSdk``
    (``get_img_files_and_labels``).

    Args:
        img_path (str): Path to the folder containing images.
        is_train_on_platform (bool, optional): Whether training runs on the Vinno platform. Defaults to False.
        data (dict, optional): A dataset YAML dictionary; must contain 'platform_data_args' when training on
            the platform. Defaults to None.
        imgsz (int, optional): Image size. Defaults to 640.
        cache (bool | str, optional): Cache images to RAM (True / 'ram') or disk ('disk') during training.
            Defaults to False.
        augment (bool, optional): If True, data augmentation is applied. Defaults to True.
        hyp (optional): Hyperparameters to apply data augmentation. Defaults to DEFAULT_CFG.
        prefix (str, optional): Prefix to print in log messages. Defaults to ''.
        rect (bool, optional): If True, rectangular training is used. Defaults to False.
        batch_size (int, optional): Size of batches. Defaults to 16.
        stride (int, optional): Stride. Defaults to 32.
        pad (float, optional): Padding. Defaults to 0.5.
        single_cls (bool, optional): If True, single class training is used. Defaults to False.
        classes (list): List of included classes. Default is None.
        fraction (float): Fraction of dataset to utilize. Default is 1.0 (use all data).

    Attributes:
        im_files (list): List of image file paths (file names from get_img_files_and_labels in platform mode).
        labels (list): List of label data dictionaries.
        crop_boxes (list): Platform mode only; crop boxes indexed in parallel with im_files.
        crop_contours (list): Platform mode only; crop contours indexed in parallel with im_files.
        ni (int): Number of images in the dataset.
        ims (list): List of loaded images.
        npy_files (list): List of numpy file paths, aligned with im_files.
        cache: Cache mode actually in effect ('ram', 'disk', or a falsy value).
        transforms (callable): Image transformation function.
    """

    def __init__(
        self,
        img_path,
        is_train_on_platform=False,
        data=None,
        imgsz=640,
        cache=False,
        augment=True,
        hyp=DEFAULT_CFG,
        prefix="",
        rect=False,
        batch_size=16,
        stride=32,
        pad=0.5,
        single_cls=False,
        classes=None,
        fraction=1.0,
    ):
        """
        Initialize BaseDataset with given configuration and options.

        Caching runs only after the image list is final (filtered and, with ``rect``, re-sorted) and after
        the per-image buffers exist, so cached entries and ``npy_files`` stay aligned with ``im_files``.
        """
        super().__init__()
        self.img_path = img_path
        self.imgsz = imgsz
        self.augment = augment
        self.single_cls = single_cls
        self.prefix = prefix
        self.fraction = fraction
        self.is_train_on_platform = is_train_on_platform
        self.data = data
        if self.is_train_on_platform:
            # Platform mode: file list, labels and crop geometry are obtained together.
            self.train_or_val_data = "train" if self.augment else "val"
            self.platform_data_args = self.data["platform_data_args"]
            self.token = self.platform_data_args["token"]
            self.extra_contours_args = self.platform_data_args["extra_contours_args"]
            self.im_files, self.labels, self.crop_boxes, self.crop_contours = self.get_img_files_and_labels()
        else:
            self.im_files = self.get_img_files(self.img_path)
            self.labels = self.get_labels()
        self.update_labels(include_class=classes)  # single_cls and include_class
        self.ni = len(self.labels)  # number of images
        self.rect = rect
        self.batch_size = batch_size
        self.stride = stride
        self.pad = pad
        if self.rect:
            assert self.batch_size is not None
            self.set_rectangle()  # re-sorts im_files / labels (and crop lists) by aspect ratio

        # Buffer thread for mosaic images
        self.buffer = []  # buffer size = batch size
        self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0

        # Cache images (options are cache = True, False, None, "ram", "disk")
        self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni
        # Built after set_rectangle() so that npy_files[i] always corresponds to im_files[i].
        self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files]
        if isinstance(cache, str):
            cache = cache.lower()
        if cache and cache != "disk":
            # cache_images() treats every truthy value other than "disk" as RAM caching; normalize it so the
            # memory check below and the eviction guard in load_image() apply consistently.
            cache = "ram"
        if cache == "ram" and not self.check_cache_ram():
            cache = False
        self.cache = cache
        if cache:
            self.cache_images(cache)

        # Transforms
        self.transforms = self.build_transforms(hyp=hyp)

    def get_img_files(self, img_path):
        """
        Read image files.

        Args:
            img_path (str | list): A directory (searched recursively), a text file listing image paths (entries
                starting with './' are resolved against that file's directory), or a list of these.

        Returns:
            (list): Image file paths, sorted; randomly subsampled when ``self.fraction < 1``.

        Raises:
            FileNotFoundError: If a path does not exist or no supported images are found.
        """
        try:
            f = []  # image files
            for p in img_path if isinstance(img_path, list) else [img_path]:
                p = Path(p)  # os-agnostic
                if p.is_dir():  # dir
                    f += glob.glob(str(p / "**" / "*.*"), recursive=True)
                elif p.is_file():  # file
                    with open(p) as fh:
                        lines = fh.read().strip().splitlines()
                    parent = str(p.parent) + os.sep
                    # Replace only the leading './' so inner './' or '../' segments are left intact.
                    f += [parent + x[2:] if x.startswith("./") else x for x in lines]  # local to global path
                else:
                    raise FileNotFoundError(f"{self.prefix}{p} does not exist")
            im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
            assert im_files, f"{self.prefix}No images found in {img_path}. {FORMATS_HELP_MSG}"
        except Exception as e:
            raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e
        if self.fraction < 1:
            # NOTE: a random (not prefix) subset; it differs between runs/processes unless `random` is seeded
            # identically everywhere.
            num_elements_to_select = round(len(im_files) * self.fraction)
            im_files = random.sample(im_files, num_elements_to_select)
        return im_files

    def update_labels(self, include_class: Optional[list]):
        """
        Update labels in place to include only these classes (optional).

        Args:
            include_class (list | None): Class ids to keep; None keeps every class.

        When ``self.single_cls`` is set, every remaining class id is also overwritten with 0.
        """
        include_class_array = None if include_class is None else np.array(include_class).reshape(1, -1)
        for label in self.labels:
            if include_class_array is not None:
                cls = label["cls"]
                segments = label["segments"]
                keypoints = label["keypoints"]
                keep = (cls == include_class_array).any(1)  # row mask of instances whose class is included
                label["cls"] = cls[keep]
                label["bboxes"] = label["bboxes"][keep]
                if segments:
                    label["segments"] = [segments[si] for si, idx in enumerate(keep) if idx]
                if keypoints is not None:
                    label["keypoints"] = keypoints[keep]
            if self.single_cls:
                label["cls"][:, 0] = 0

    def _read_platform_image(self, i, f):
        """Fetch image ``f`` through TrainSdk, decode it as BGR and optionally remove extra contours."""
        if self.train_or_val_data == "train":
            image_data, _, _, _ = TrainSdk.get_labeled_file(self.token, f)
        else:
            image_data, _, _, _ = TrainSdk.get_test_labeled_file(self.token, f)
        im = cv2.imdecode(np.frombuffer(image_data, dtype=np.uint8), flags=cv2.IMREAD_COLOR)  # BGR
        if im is None:  # check before cropping so a decode failure is reported clearly
            raise FileNotFoundError(f"Image Not Found {f}")
        # Decide whether to crop; cropping changes the network input image. '' disables it.
        if self.extra_contours_args != "":
            use_orig_image_pixel = self.extra_contours_args["use_orig_image_pixel"]
            im = get_image_remove_extra_contour(im, self.crop_contours[i], self.crop_boxes[i], use_orig_image_pixel)
        return im

    def _read_local_image(self, i, f):
        """Read image ``f`` from disk, preferring its ``*.npy`` cache; returns None if cv2 cannot read it."""
        fn = self.npy_files[i]
        if fn.exists():  # load npy
            try:
                return np.load(fn)
            except Exception as e:
                LOGGER.warning(f"{self.prefix}WARNING ⚠️ Removing corrupt *.npy image file {fn} due to: {e}")
                Path(fn).unlink(missing_ok=True)
        return cv2.imread(f)  # BGR

    def load_image(self, i, rect_mode=True):
        """
        Loads 1 image from dataset index 'i', returns (im, original hw, resized hw).

        Args:
            i (int): Dataset index.
            rect_mode (bool): If True, resize the long side to ``imgsz`` keeping the aspect ratio; otherwise
                stretch the image to a square ``imgsz`` x ``imgsz``.

        Returns:
            (tuple): (image, (h0, w0) original height/width, resized height/width).

        Raises:
            FileNotFoundError: If the image cannot be read or decoded.
        """
        im, f = self.ims[i], self.im_files[i]
        if im is not None:  # already in RAM (cached, or still in the mosaic buffer)
            return self.ims[i], self.im_hw0[i], self.im_hw[i]

        im = self._read_platform_image(i, f) if self.is_train_on_platform else self._read_local_image(i, f)
        if im is None:
            raise FileNotFoundError(f"Image Not Found {f}")
        h0, w0 = im.shape[:2]  # orig hw
        if rect_mode:  # resize long side to imgsz while maintaining aspect ratio
            r = self.imgsz / max(h0, w0)  # ratio
            if r != 1:  # if sizes are not equal
                w, h = (min(math.ceil(w0 * r), self.imgsz), min(math.ceil(h0 * r), self.imgsz))
                im = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
        elif not (h0 == w0 == self.imgsz):  # resize by stretching image to square imgsz
            im = cv2.resize(im, (self.imgsz, self.imgsz), interpolation=cv2.INTER_LINEAR)

        # Add to buffer if training with augmentations
        if self.augment:
            self.ims[i], self.im_hw0[i], self.im_hw[i] = im, (h0, w0), im.shape[:2]  # im, hw_original, hw_resized
            self.buffer.append(i)
            if len(self.buffer) >= self.max_buffer_length:
                j = self.buffer.pop(0)
                # Never evict RAM-cached images; otherwise cache="ram" would only keep the last
                # max_buffer_length images resident.
                if getattr(self, "cache", None) != "ram":
                    self.ims[j], self.im_hw0[j], self.im_hw[j] = None, None, None
        return im, (h0, w0), im.shape[:2]

    def cache_images(self, cache):
        """Cache images to memory ('ram') or to ``*.npy`` files ('disk') using a thread pool."""
        b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabytes
        fcn = self.cache_images_to_disk if cache == "disk" else self.load_image
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(fcn, range(self.ni))
            pbar = TQDM(enumerate(results), total=self.ni, disable=LOCAL_RANK > 0)
            for i, x in pbar:
                if cache == "disk":
                    b += self.npy_files[i].stat().st_size
                else:  # 'ram'
                    self.ims[i], self.im_hw0[i], self.im_hw[i] = x  # im, hw_orig, hw_resized = load_image(self, i)
                    b += x[0].nbytes  # use the returned array rather than re-reading the shared list
                pbar.desc = f"{self.prefix}Caching images ({b / gb:.1f}GB {cache})"
            pbar.close()

    def cache_images_to_disk(self, i):
        """Saves an image as an *.npy file for faster loading."""
        # NOTE(review): reads im_files[i] with cv2.imread, which assumes a local path; platform-mode
        # entries may not be filesystem paths — confirm before enabling cache='disk' on the platform.
        f = self.npy_files[i]
        if not f.exists():
            np.save(f.as_posix(), cv2.imread(self.im_files[i]), allow_pickle=False)

    def check_cache_ram(self, safety_margin=0.5):
        """
        Check image caching requirements vs available memory.

        Estimates the RAM needed from up to 30 randomly sampled images (rescaled to ``imgsz``) plus
        ``safety_margin``, and returns True if it fits in currently available memory.
        """
        # NOTE(review): samples are read with cv2.imread, which assumes local paths; platform-mode
        # entries may not be readable this way — confirm before enabling cache='ram' on the platform.
        b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabytes
        n = min(self.ni, 30)  # extrapolate from 30 random images
        for _ in range(n):
            im = cv2.imread(random.choice(self.im_files))  # sample image
            ratio = self.imgsz / max(im.shape[0], im.shape[1])  # max(h, w) # ratio
            b += im.nbytes * ratio**2
        mem_required = b * self.ni / n * (1 + safety_margin)  # GB required to cache dataset into RAM
        mem = psutil.virtual_memory()
        cache = mem_required < mem.available  # to cache or not to cache, that is the question
        if not cache:
            LOGGER.info(
                f'{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images '
                f'with {int(safety_margin * 100)}% safety margin but only '
                f'{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, '
                f"{'caching images ✅' if cache else 'not caching images ⚠️'}"
            )
        return cache

    def set_rectangle(self):
        """
        Sort images by aspect ratio and compute per-batch input shapes for rectangular training.

        Pops the 'shape' key from every label, reorders all per-image lists (including the platform crop
        lists) consistently, and sets ``self.batch_shapes`` and ``self.batch``.
        """
        bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int)  # batch index
        nb = bi[-1] + 1  # number of batches
        s = np.array([x.pop("shape") for x in self.labels])  # hw
        ar = s[:, 0] / s[:, 1]  # aspect ratio
        irect = ar.argsort()
        self.im_files = [self.im_files[i] for i in irect]
        self.labels = [self.labels[i] for i in irect]
        if self.is_train_on_platform:
            # Crop geometry is indexed in parallel with im_files, so it must follow the same order.
            self.crop_boxes = [self.crop_boxes[i] for i in irect]
            self.crop_contours = [self.crop_contours[i] for i in irect]
        ar = ar[irect]

        # Set training image shapes
        shapes = [[1, 1]] * nb
        for i in range(nb):
            ari = ar[bi == i]
            mini, maxi = ari.min(), ari.max()
            if maxi < 1:
                shapes[i] = [maxi, 1]
            elif mini > 1:
                shapes[i] = [1, 1 / mini]

        self.batch_shapes = np.ceil(np.array(shapes) * self.imgsz / self.stride + self.pad).astype(int) * self.stride
        self.batch = bi  # batch index of image

    def __getitem__(self, index):
        """Returns transformed label information for given index."""
        return self.transforms(self.get_image_and_label(index))

    def get_image_and_label(self, index):
        """Get and return label information (including the loaded image) for the given index."""
        label = deepcopy(self.labels[index])  # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948
        label.pop("shape", None)  # shape is for rect, remove it
        label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index)
        label["ratio_pad"] = (
            label["resized_shape"][0] / label["ori_shape"][0],
            label["resized_shape"][1] / label["ori_shape"][1],
        )  # for evaluation
        if self.rect:
            label["rect_shape"] = self.batch_shapes[self.batch[index]]
        return self.update_labels_info(label)

    def __len__(self):
        """Returns the length of the labels list for the dataset."""
        return len(self.labels)

    def update_labels_info(self, label):
        """Custom your label format here."""
        return label

    def build_transforms(self, hyp=None):
        """
        Users can customize augmentations here.

        Example:
            ```python
            if self.augment:
                # Training transforms
                return Compose([])
            else:
                # Val transforms
                return Compose([])
            ```
        """
        raise NotImplementedError

    def get_labels(self):
        """
        Users can customize their own format here.

        Note:
            Ensure output is a dictionary with the following keys:
            ```python
            dict(
                im_file=im_file,
                shape=shape,  # format: (height, width)
                cls=cls,
                bboxes=bboxes,  # xywh
                segments=segments,  # xy
                keypoints=keypoints,  # xy
                normalized=True,  # or False
                bbox_format="xyxy",  # or xywh, ltwh
            )
            ```
        """
        raise NotImplementedError

    def get_img_files_and_labels(self):
        """
        Return image files, labels and crop geometry together (platform mode).

        Default YOLO obtains im_files from the base class's get_img_files and labels from the subclass
        YOLODataset's get_labels. On the platform, im_files and labels are returned together, so override this
        method instead of modifying get_img_files and get_labels.

        Note:
            Ensure the outputs are (im_files, all_labels, crop_boxes, crop_contours):
            im_files: list of all file names.
            all_labels: list of per-image label dicts in the following format:
            ```python
            dict(
                im_file=im_file,
                shape=shape,  # format: (height, width)
                cls=cls,
                bboxes=bboxes,  # xywh
                segments=segments,  # xy
                keypoints=keypoints,  # xy
                normalized=True,  # or False
                bbox_format="xyxy",  # or xywh, ltwh
            )
            ```
            crop_boxes: list of crop boxes that may be needed, indexed in parallel with im_files.
            crop_contours: list of crop contours that may be needed, indexed in parallel with im_files.
        """
        raise NotImplementedError
|