12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742 |
- # Ultralytics YOLO 🚀, AGPL-3.0 license
- import contextlib
- import glob
- import inspect
- import math
- import os
- import platform
- import re
- import shutil
- import subprocess
- import time
- from importlib import metadata
- from pathlib import Path
- from typing import Optional
- import cv2
- import numpy as np
- import requests
- import torch
- from ultralytics.utils import (
- ASSETS,
- AUTOINSTALL,
- IS_COLAB,
- IS_JUPYTER,
- IS_KAGGLE,
- IS_PIP_PACKAGE,
- LINUX,
- LOGGER,
- ONLINE,
- PYTHON_VERSION,
- ROOT,
- TORCHVISION_VERSION,
- USER_CONFIG_DIR,
- Retry,
- SimpleNamespace,
- ThreadingLocked,
- TryExcept,
- clean_url,
- colorstr,
- downloads,
- emojis,
- is_github_action_running,
- url2file,
- )
def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
    """
    Parse a requirements.txt file, ignoring lines that start with '#' and any text after '#'.

    Args:
        file_path (Path): Path to the requirements.txt file.
        package (str, optional): Python package to use instead of requirements.txt file, i.e. package='ultralytics'.

    Returns:
        (List[SimpleNamespace]): List of parsed requirements, each with `name` and `specifier` attributes.

    Example:
        ```python
        from ultralytics.utils.checks import parse_requirements

        parse_requirements(package='ultralytics')
        ```
    """
    if package:
        # `requires` is None for a distribution that declares no dependencies
        requires = [x for x in (metadata.distribution(package).requires or []) if "extra == " not in x]
    else:
        requires = Path(file_path).read_text().splitlines()

    requirements = []
    for line in requires:
        line = line.strip()
        if line and not line.startswith("#"):
            line = line.split("#")[0].strip()  # ignore inline comments
            # PEP 508 names may contain letters, digits, '.', '_' and '-', i.e. 'zope.interface'
            match = re.match(r"([a-zA-Z0-9._-]+)\s*([<>!=~]+.*)?", line)
            if match:
                requirements.append(SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else ""))

    return requirements
def parse_version(version="0.0.0") -> tuple:
    """
    Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version. This
    function replaces deprecated 'pkg_resources.parse_version(v)'.

    Only the leading dotted release segment is used, so digits in local or pre-release labels are ignored
    (i.e. '2.1+cu118' -> (2, 1, 0), not (2, 1, 118)). The result is truncated or zero-padded to exactly three
    components so that '2.1' and '2.1.0' compare equal.

    Args:
        version (str): Version string, i.e. '2.0.1+cpu'

    Returns:
        (tuple): Tuple of three integers representing the numeric release part of the version, i.e. (2, 0, 1).
            (0, 0, 0) is returned if the version cannot be parsed.
    """
    try:
        release = re.search(r"\d+(?:\.\d+)*", version)  # first dotted run of digits, i.e. '2.0.1' in '2.0.1+cpu'
        parts = [int(x) for x in release.group().split(".")[:3]] if release else []
        return tuple(parts + [0] * (3 - len(parts)))  # '2.1' -> (2, 1, 0)
    except Exception as e:
        LOGGER.warning(f"WARNING ⚠️ failure for parse_version({version}), returning (0, 0, 0): {e}")
        return 0, 0, 0
def is_ascii(s) -> bool:
    """
    Check whether the string form of a value contains only ASCII characters.

    Args:
        s (str): Value to check; non-string inputs (list, tuple, None, ...) are converted with str() first.

    Returns:
        (bool): True if every character has a code point below 128, False otherwise.
    """
    text = str(s)  # stringify lists, tuples, None, etc.
    for char in text:
        if ord(char) >= 128:
            return False
    return True
def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
    """
    Verify image size is a multiple of the given stride in each dimension. If the image size is not a multiple of the
    stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value.

    Args:
        imgsz (int | List[int] | str): Image size, i.e. 640, [640, 480], '640' or '640,480'.
        stride (int | torch.Tensor): Stride value; for a tensor the maximum element is used.
        min_dim (int): Minimum number of dimensions.
        max_dim (int): Maximum number of dimensions.
        floor (int): Minimum allowed value for image size.

    Returns:
        (int | List[int]): Updated image size. A single size is returned as an int when `min_dim == 1` and as
            [sz, sz] when `min_dim == 2`.

    Raises:
        TypeError: If `imgsz` is not an int, list, tuple or str.
        ValueError: If `imgsz` has more than `max_dim` dimensions and `max_dim != 1`.
    """
    # Convert stride to integer if it is a tensor
    stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride)

    # Convert image size to list
    if isinstance(imgsz, int):
        imgsz = [imgsz]
    elif isinstance(imgsz, (list, tuple)):
        imgsz = list(imgsz)
    elif isinstance(imgsz, str):  # i.e. '640' or '[640,640]'
        # NOTE(review): eval() executes arbitrary expressions; consider ast.literal_eval if imgsz can come from
        # untrusted input such as shared config files.
        parsed = int(imgsz) if imgsz.isnumeric() else eval(imgsz)
        # Normalize to a list: '640,480' evals to a tuple, which would never compare equal to the list `sz` below
        # and would trigger a spurious "updating" warning; '(640)' evals to a bare int.
        imgsz = [parsed] if isinstance(parsed, int) else list(parsed)
    else:
        raise TypeError(
            f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
            f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'"
        )

    # Apply max_dim
    if len(imgsz) > max_dim:
        msg = (
            "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list "
            "or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'"
        )
        if max_dim != 1:
            raise ValueError(f"imgsz={imgsz} is not a valid image size. {msg}")
        LOGGER.warning(f"WARNING ⚠️ updating to 'imgsz={max(imgsz)}'. {msg}")
        imgsz = [max(imgsz)]

    # Make image size a multiple of the stride
    sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz]

    # Print warning message if image size was updated
    if sz != imgsz:
        LOGGER.warning(f"WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}")

    # Add missing dimensions if necessary
    sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz

    return sz
def check_version(
    current: str = "0.0.0",
    required: str = "0.0.0",
    name: str = "version",
    hard: bool = False,
    verbose: bool = False,
    msg: str = "",
) -> bool:
    """
    Check current version against the required version or range.

    Args:
        current (str): Current version or package name to get version from.
        required (str): Required version or range (in pip-style format). Comma-separated constraints using '==',
            '!=', '>=', '<=', '>', '<' or '~=' are supported; surrounding whitespace is ignored and a bare version
            implies '>='.
        name (str, optional): Name to be used in warning message.
        hard (bool, optional): If True, raise a ModuleNotFoundError if the requirement is not met.
        verbose (bool, optional): If True, log a warning message if requirement is not met.
        msg (str, optional): Extra message to display if verbose.

    Returns:
        (bool): True if requirement is met, False otherwise.

    Raises:
        ModuleNotFoundError: If `hard` is True and the package is not installed or the requirement is not met.

    Example:
        ```python
        # Check if current version is exactly 22.04
        check_version(current='22.04', required='==22.04')

        # Check if current version is greater than or equal to 22.04
        check_version(current='22.10', required='22.04')  # assumes '>=' inequality if none passed

        # Check if current version is less than or equal to 22.04
        check_version(current='22.04', required='<=22.04')

        # Check if current version is between 20.04 (inclusive) and 22.04 (exclusive)
        check_version(current='21.10', required='>20.04,<22.04')
        ```
    """
    if not current:  # if current is '' or None
        LOGGER.warning(f"WARNING ⚠️ invalid check_version({current}, {required}) requested, please check values.")
        return True
    elif not current[0].isdigit():  # current is package name rather than version string, i.e. current='ultralytics'
        try:
            name = current  # assigned package name to 'name' arg
            current = metadata.version(current)  # get version string from package name
        except metadata.PackageNotFoundError as e:
            if hard:
                raise ModuleNotFoundError(emojis(f"WARNING ⚠️ {current} package is required but not installed")) from e
            else:
                return False

    if not required:  # if required is '' or None
        return True

    op = ""
    version = ""
    result = True
    c = parse_version(current)  # '1.2.3' -> (1, 2, 3)
    for r in required.strip(",").split(","):
        r = r.strip()
        if not r:  # tolerate empty segments, i.e. '>=1.0,,<2.0'
            continue
        op, version = re.match(r"([^0-9]*)([\d.]+)", r).groups()  # split '>=22.04' -> ('>=', '22.04')
        op = op.strip()  # tolerate whitespace between operator and version, i.e. '>= 22.04'
        v = parse_version(version)  # '1.2.3' -> (1, 2, 3)
        if op == "==" and c != v:
            result = False
        elif op == "!=" and c == v:
            result = False
        elif op in {">=", ""} and not (c >= v):  # if no constraint passed assume '>=required'
            result = False
        elif op == "<=" and not (c <= v):
            result = False
        elif op == ">" and not (c > v):
            result = False
        elif op == "<" and not (c < v):
            result = False
        elif op == "~=":  # compatible release: '~=2.2' means '>=2.2,==2.*'
            n = len([x for x in version.split(".") if x]) - 1  # number of leading components that must match
            if not (c >= v and c[:n] == v[:n]):
                result = False
        if not result:
            break  # stop at the first unmet constraint so the warning below names it
    if not result:
        warning = f"WARNING ⚠️ {name}{op}{version} is required, but {name}=={current} is currently installed {msg}"
        if hard:
            raise ModuleNotFoundError(emojis(warning))  # assert version requirements met
        if verbose:
            LOGGER.warning(warning)
    return result
def check_latest_pypi_version(package_name="ultralytics"):
    """
    Look up the latest released version of a package on PyPI without downloading or installing it.

    Args:
        package_name (str): Name of the PyPI package to query.

    Returns:
        (str | None): Latest version string from the PyPI JSON API, or None if the request fails or does not
            return HTTP 200.
    """
    with contextlib.suppress(Exception):
        requests.packages.urllib3.disable_warnings()  # Disable the InsecureRequestWarning
        response = requests.get(f"https://pypi.org/pypi/{package_name}/json", timeout=3)
        if response.status_code != 200:
            return None
        return response.json()["info"]["version"]
    return None
def check_pip_update_available():
    """
    Check whether PyPI has a newer ultralytics release than the installed one.

    The check only runs when online and installed as a pip package; any error during the lookup is suppressed.

    Returns:
        (bool): True if a newer version is available (an info message is logged), False otherwise.
    """
    if not (ONLINE and IS_PIP_PACKAGE):
        return False
    with contextlib.suppress(Exception):
        from ultralytics import __version__

        newest = check_latest_pypi_version()
        update_available = check_version(__version__, f"<{newest}")  # installed version older than PyPI?
        if update_available:
            LOGGER.info(
                f"New https://pypi.org/project/ultralytics/{newest} available 😃 "
                f"Update with 'pip install -U ultralytics'"
            )
            return True
    return False
@ThreadingLocked()
def check_font(font="Arial.ttf"):
    """
    Resolve a font file: use the cached copy in USER_CONFIG_DIR, then a matching system font, else download it.

    Args:
        font (str): Path or name of the font.

    Returns:
        (Path | str | None): Cached or downloaded font path (Path), a matching system font path (str), or None if the
            font is neither cached, installed, nor downloadable.
    """
    from matplotlib import font_manager

    font_name = Path(font).name
    cached = USER_CONFIG_DIR / font_name
    if cached.exists():
        return cached

    # Fall back to installed system fonts whose path contains the requested name
    system_hits = [path for path in font_manager.findSystemFonts() if font in path]
    if any(system_hits):
        return system_hits[0]

    # Last resort: fetch the font from the Ultralytics assets server into USER_CONFIG_DIR
    font_url = f"https://ultralytics.com/assets/{font_name}"
    if not downloads.is_url(font_url, check=True):
        return None
    downloads.safe_download(url=font_url, file=cached)
    return cached
def check_python(minimum: str = "3.8.0", hard: bool = True) -> bool:
    """
    Check the running Python version against a required minimum.

    Args:
        minimum (str): Minimum acceptable Python version.
        hard (bool, optional): If True, a ModuleNotFoundError is raised (by check_version) when the minimum is not met.

    Returns:
        (bool): Whether the installed Python version meets the minimum.
    """
    return check_version(current=PYTHON_VERSION, required=minimum, name="Python", hard=hard)
@TryExcept()
def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=(), install=True, cmds=""):
    """
    Check if installed dependencies meet YOLOv8 requirements and attempt to auto-update if needed.

    Args:
        requirements (Union[Path, str, List[str]]): Path to a requirements.txt file, a single package requirement as a
            string, or a list of package requirements as strings.
        exclude (Tuple[str]): Tuple of package names to exclude from checking.
        install (bool): If True, attempt to auto-update packages that don't meet requirements.
        cmds (str): Additional commands to pass to the pip install command when auto-updating.

    Returns:
        (bool): True if all requirements are met (or were installed), False if some are unmet and were not installed.

    Example:
        ```python
        from ultralytics.utils.checks import check_requirements

        # Check a requirements.txt file
        check_requirements('path/to/requirements.txt')

        # Check a single package
        check_requirements('ultralytics>=8.0.0')

        # Check multiple packages
        check_requirements(['numpy', 'ultralytics>=8.0.0'])
        ```
    """
    prefix = colorstr("red", "bold", "requirements:")
    check_python()  # check python version
    check_torchvision()  # check torch-torchvision compatibility
    if isinstance(requirements, Path):  # requirements.txt file
        file = requirements.resolve()
        if not file.exists():  # explicit check: an `assert` would be skipped under `python -O`
            raise AssertionError(f"{prefix} {file} not found, check failed.")
        requirements = [f"{x.name}{x.specifier}" for x in parse_requirements(file) if x.name not in exclude]
    elif isinstance(requirements, str):
        requirements = [requirements]

    pkgs = []
    for r in requirements:
        r_stripped = r.split("/")[-1].replace(".git", "")  # replace git+https://org/repo.git -> 'repo'
        # PEP 508 names may contain '.', i.e. 'zope.interface'
        match = re.match(r"([a-zA-Z0-9._-]+)([<>!=~]+.*)?", r_stripped)
        name, required = match[1], match[2].strip() if match[2] else ""
        try:
            # Plain `if` rather than `assert`: an assert would be stripped under `python -O`, hiding outdated packages
            if not check_version(metadata.version(name), required):
                pkgs.append(r)
        except metadata.PackageNotFoundError:
            pkgs.append(r)

    @Retry(times=2, delay=1)
    def attempt_install(packages, commands):
        """Attempt pip install command with retries on failure."""
        return subprocess.check_output(f"pip install --no-cache-dir {packages} {commands}", shell=True).decode()

    s = " ".join(f'"{x}"' for x in pkgs)  # console string
    if s:
        if install and AUTOINSTALL:  # check environment variable
            n = len(pkgs)  # number of packages updates
            LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...")
            try:
                t = time.time()
                if not ONLINE:  # raised inside the try so it is logged below like any other install failure
                    raise AssertionError("AutoUpdate skipped (offline)")
                LOGGER.info(attempt_install(s, cmds))
                dt = time.time() - t
                LOGGER.info(
                    f"{prefix} AutoUpdate success ✅ {dt:.1f}s, installed {n} package{'s' * (n > 1)}: {pkgs}\n"
                    f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
                )
            except Exception as e:
                LOGGER.warning(f"{prefix} ❌ {e}")
                return False
        else:
            return False

    return True
def check_torchvision():
    """
    Checks the installed versions of PyTorch and Torchvision to ensure they're compatible.

    This function checks the installed versions of PyTorch and Torchvision, and logs a warning if they're incompatible
    according to the provided compatibility table based on:
    https://github.com/pytorch/vision#installation.

    The compatibility table is a dictionary where the keys are PyTorch versions and the values are lists of compatible
    Torchvision versions. Torch versions missing from the table are not checked.
    """
    # Compatibility table: torch major.minor -> compatible torchvision major.minor versions
    compatibility_table = {
        "2.5": ["0.20"],
        "2.4": ["0.19"],
        "2.3": ["0.18"],
        "2.2": ["0.17"],
        "2.1": ["0.16"],
        "2.0": ["0.15"],
        "1.13": ["0.14"],
        "1.12": ["0.13"],
    }

    # Extract only the major and minor versions, dropping local labels such as '+cu118'
    v_torch = ".".join(torch.__version__.split("+")[0].split(".")[:2])
    if v_torch in compatibility_table:
        compatible_versions = compatibility_table[v_torch]
        v_torchvision = ".".join(TORCHVISION_VERSION.split("+")[0].split(".")[:2])
        if all(v_torchvision != v for v in compatible_versions):
            # Use the module logger like every other warning in this file, rather than print()
            LOGGER.warning(
                f"WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n"
                f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or "
                "'pip install -U torch torchvision' to update both.\n"
                "For a full compatibility table see https://github.com/pytorch/vision#installation"
            )
def check_suffix(file="yolov8n.pt", suffix=".pt", msg=""):
    """
    Check file(s) for acceptable suffix.

    Files without any suffix are accepted. Suffixes are lower-cased before comparison.

    Args:
        file (str | Path | list | tuple): File path or a list/tuple of file paths to check.
        suffix (str | tuple): Acceptable suffix or tuple of suffixes, i.e. '.pt' or ('.yaml', '.yml').
        msg (str): Prefix prepended to the error message.

    Raises:
        AssertionError: If a file has a non-empty suffix that is not acceptable.
    """
    if not (file and suffix):
        return
    if isinstance(suffix, str):
        suffix = (suffix,)
    for f in file if isinstance(file, (list, tuple)) else [file]:
        s = Path(f).suffix.lower().strip()  # file suffix
        # Explicit raise rather than `assert` so validation still runs under `python -O`;
        # AssertionError is kept because callers may already catch it.
        if s and s not in suffix:
            raise AssertionError(f"{msg}{f} acceptable suffix is {suffix}, not {s}")
def check_yolov5u_filename(file: str, verbose: bool = True):
    """
    Map legacy YOLOv3/YOLOv5 filenames to their YOLOv5u ('u') equivalents.

    '*u.yaml' config names are mapped back to the plain '.yaml' name. '.pt' weight names that contain no 'u'
    anywhere in the string get a 'u' inserted before '.pt'.

    Args:
        file (str): Model filename or path.
        verbose (bool): If True, log a tip when a '.pt' filename was rewritten.

    Returns:
        (str): The possibly-rewritten filename.
    """
    if "yolov3" not in file and "yolov5" not in file:
        return file  # not a legacy YOLOv3/YOLOv5 name
    if "u.yaml" in file:
        return file.replace("u.yaml", ".yaml")  # i.e. yolov5nu.yaml -> yolov5n.yaml
    if ".pt" not in file or "u" in file:
        return file

    legacy = file
    for pattern in (
        r"(.*yolov5([nsmlx]))\.pt",  # i.e. yolov5n.pt -> yolov5nu.pt
        r"(.*yolov5([nsmlx])6)\.pt",  # i.e. yolov5n6.pt -> yolov5n6u.pt
        r"(.*yolov3(|-tiny|-spp))\.pt",  # i.e. yolov3-spp.pt -> yolov3-sppu.pt
    ):
        file = re.sub(pattern, "\\1u.pt", file)

    if verbose and file != legacy:
        LOGGER.info(
            f"PRO TIP 💡 Replace 'model={legacy}' with new 'model={file}'.\nYOLOv5 'u' models are "
            f"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs "
            f"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n"
        )
    return file
def check_model_file_from_stem(model="yolov8n"):
    """
    Turn a bare model stem into a '.pt' filename when it names a known GitHub asset.

    Args:
        model (str): Model name or path, i.e. 'yolov8n'.

    Returns:
        (Path | str): Path with a '.pt' suffix for suffix-less known asset stems, otherwise `model` unchanged.
    """
    if not model:
        return model
    path = Path(model)
    is_bare_asset_stem = not path.suffix and path.stem in downloads.GITHUB_ASSETS_STEMS
    return path.with_suffix(".pt") if is_bare_asset_stem else model  # i.e. yolov8n -> yolov8n.pt
def check_file(file, suffix="", download=True, hard=True):
    """
    Resolve a file: return it if it exists locally, download it if it is a URL, otherwise search for it under ROOT.

    Args:
        file (str | Path): File name, path, or URL.
        suffix (str | tuple): Acceptable suffix(es), validated by check_suffix() when non-empty.
        download (bool): If True, download http(s)/rtsp/rtmp/tcp URLs.
        hard (bool): If True, raise when the search finds no file or more than one file.

    Returns:
        (str | list): The file itself (empty, existing, or gRPC), the local path of a downloaded URL, the first
            search match, or [] if the search found nothing and `hard` is False.

    Raises:
        FileNotFoundError: If `hard` is True and the search finds no match or several matches.
    """
    check_suffix(file, suffix)  # optional
    file = check_yolov5u_filename(str(file).strip())  # stringify, strip spaces, yolov5n -> yolov5nu
    lowered = file.lower()

    # Empty input, an existing local path, or a gRPC Triton endpoint is returned as-is
    # ('://' check required in Windows Python<3.10)
    if not file or ("://" not in file and Path(file).exists()) or lowered.startswith("grpc://"):
        return file

    if download and lowered.startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")):
        url = file  # keep the raw string: Pathlib turns :// -> :/
        local = url2file(url)  # '%2F' to '/', split https://url.com/file.txt?auth
        if Path(local).exists():
            LOGGER.info(f"Found {clean_url(url)} locally at {local}")  # file already exists
        else:
            downloads.safe_download(url=url, file=local, unzip=False)
        return local

    # Search the package tree, then its parent directory
    matches = glob.glob(str(ROOT / "**" / file), recursive=True) or glob.glob(str(ROOT.parent / file))
    if hard and not matches:
        raise FileNotFoundError(f"'{file}' does not exist")
    if hard and len(matches) > 1:
        raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {matches}")
    return matches[0] if matches else []
def check_yaml(file, suffix=(".yaml", ".yml"), hard=True):
    """
    Resolve a YAML file via check_file(), validating that it has a YAML suffix.

    Args:
        file (str | Path): YAML file name, path, or URL.
        suffix (tuple): Acceptable suffixes.
        hard (bool): If True, raise when the file cannot be found unambiguously.

    Returns:
        (str | list): Whatever check_file() returns for `file`.
    """
    return check_file(file, suffix=suffix, hard=hard)
def check_is_path_safe(basedir, path):
    """
    Guard against path traversal by checking that a resolved path exists and lies inside a base directory.

    Args:
        basedir (Path | str): The directory the path must stay within.
        path (Path | str): The path to validate.

    Returns:
        (bool): True if `path` exists and its resolved location is under `basedir`, False otherwise.
    """
    base_parts = Path(basedir).resolve().parts
    target = Path(path).resolve()
    if not target.exists():
        return False
    # Compare leading path components so '/base2' is not mistaken for being inside '/base'
    return target.parts[: len(base_parts)] == base_parts
def check_imshow(warn=False):
    """
    Check if environment supports image displays.

    On Linux, Colab and Kaggle are rejected and a DISPLAY environment variable is required. Then a tiny test window
    is shown with cv2.imshow() and immediately destroyed.

    Args:
        warn (bool): If True, log a warning explaining why display is not supported.

    Returns:
        (bool): True if the test window could be shown, False otherwise.
    """
    try:
        if LINUX:
            # Explicit raises rather than `assert`: asserts are stripped under `python -O`, which would skip these
            # guards and call cv2.imshow() in environments already known not to support it.
            if IS_COLAB or IS_KAGGLE:
                raise AssertionError()
            if "DISPLAY" not in os.environ:
                raise AssertionError("The DISPLAY environment variable isn't set.")
        cv2.imshow("test", np.zeros((8, 8, 3), dtype=np.uint8))  # show a small 8-pixel image
        cv2.waitKey(1)
        cv2.destroyAllWindows()
        cv2.waitKey(1)
        return True
    except Exception as e:
        if warn:
            LOGGER.warning(f"WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}")
        return False
def check_yolo(verbose=True, device=""):
    """
    Log a human-readable YOLO software and hardware summary.

    In Jupyter environments it also uninstalls `wandb` when check_requirements() reports it installed. On Colab it
    also deletes the default `sample_data` directory.

    Args:
        verbose (bool): If True, include CPU count, RAM and disk usage in the summary and clear IPython output.
        device (str): Device string forwarded to select_device().
    """
    import psutil

    from ultralytics.utils.torch_utils import select_device

    if IS_JUPYTER:
        wandb_installed = check_requirements("wandb", install=False)
        if wandb_installed:
            os.system("pip uninstall -y wandb")  # uninstall wandb: unwanted account creation prompt with infinite hang
        if IS_COLAB:
            shutil.rmtree("sample_data", ignore_errors=True)  # remove colab /sample_data directory

    summary = ""
    if verbose:
        bytes_per_gib = 1 << 30
        ram_bytes = psutil.virtual_memory().total
        disk_total, _, disk_free = shutil.disk_usage("/")
        summary = (
            f"({os.cpu_count()} CPUs, {ram_bytes / bytes_per_gib:.1f} GB RAM, "
            f"{(disk_total - disk_free) / bytes_per_gib:.1f}/{disk_total / bytes_per_gib:.1f} GB disk)"
        )
        with contextlib.suppress(Exception):  # clear display if ipython is installed
            from IPython import display

            display.clear_output()

    select_device(device=device, newline=False)
    LOGGER.info(f"Setup complete ✅ {summary}")
def collect_system_info():
    """
    Collect and print relevant system information including OS, Python, RAM, CPU, and CUDA.

    Also logs each ultralytics package requirement with a ✅/❌ marker showing whether the installed version satisfies
    it. When running inside a GitHub Action, it additionally logs the GitHub Actions runner details.
    """
    import psutil

    from ultralytics.utils import ENVIRONMENT, IS_GIT_DIR
    from ultralytics.utils.torch_utils import get_cpu_info

    ram_info = psutil.virtual_memory().total / (1024**3)  # Convert bytes to GB
    check_yolo()
    LOGGER.info(
        f"\n{'OS':<20}{platform.platform()}\n"
        f"{'Environment':<20}{ENVIRONMENT}\n"
        f"{'Python':<20}{PYTHON_VERSION}\n"
        f"{'Install':<20}{'git' if IS_GIT_DIR else 'pip' if IS_PIP_PACKAGE else 'other'}\n"
        f"{'RAM':<20}{ram_info:.2f} GB\n"
        f"{'CPU':<20}{get_cpu_info()}\n"
        f"{'CUDA':<20}{torch.version.cuda if torch and torch.cuda.is_available() else None}\n"
    )

    for r in parse_requirements(package="ultralytics"):
        try:
            current = metadata.version(r.name)
            # hard=False: an unmet requirement should be reported with ❌. With hard=True, check_version raises
            # ModuleNotFoundError, which the except below does not catch, and the whole report aborts.
            is_met = "✅ " if check_version(current, str(r.specifier), hard=False) else "❌ "
        except metadata.PackageNotFoundError:
            current = "(not installed)"
            is_met = "❌ "
        LOGGER.info(f"{r.name:<20}{is_met}{current}{r.specifier}")

    if is_github_action_running():
        LOGGER.info(
            f"\nRUNNER_OS: {os.getenv('RUNNER_OS')}\n"
            f"GITHUB_EVENT_NAME: {os.getenv('GITHUB_EVENT_NAME')}\n"
            f"GITHUB_WORKFLOW: {os.getenv('GITHUB_WORKFLOW')}\n"
            f"GITHUB_ACTOR: {os.getenv('GITHUB_ACTOR')}\n"
            f"GITHUB_REPOSITORY: {os.getenv('GITHUB_REPOSITORY')}\n"
            f"GITHUB_REPOSITORY_OWNER: {os.getenv('GITHUB_REPOSITORY_OWNER')}\n"
        )
def check_amp(model):
    """
    Check PyTorch Automatic Mixed Precision (AMP) functionality for a YOLOv8 model.

    Only the device of `model` is used. A fresh YOLOv8n model is run on a bundled image once in FP32 and once under
    CUDA autocast, and the detected boxes are compared. A mismatch means AMP may cause NaN losses or zero-mAP
    results on this system, so AMP should be disabled during training.

    Args:
        model (nn.Module): A YOLOv8 model instance.

    Example:
        ```python
        from ultralytics import YOLO
        from ultralytics.utils.checks import check_amp

        model = YOLO('yolov8n.pt').model.cuda()
        check_amp(model)
        ```

    Returns:
        (bool): False on CPU/MPS devices or when FP32 and AMP results disagree; True otherwise, including when the
            check is skipped because YOLOv8n could not be downloaded or loaded.
    """
    device = next(model.parameters()).device  # get model device
    if device.type in {"cpu", "mps"}:
        return False  # AMP only used on CUDA devices

    def amp_allclose(m, im):
        """Return True if FP32 and AMP inference of `m` on `im` give same-shaped boxes within 0.5 absolute tolerance."""
        boxes_fp32 = m(im, device=device, verbose=False)[0].boxes.data  # FP32 inference
        with torch.cuda.amp.autocast(True):
            boxes_amp = m(im, device=device, verbose=False)[0].boxes.data  # AMP inference
        del m
        same_shape = boxes_fp32.shape == boxes_amp.shape
        return same_shape and torch.allclose(boxes_fp32, boxes_amp.float(), atol=0.5)

    test_image = ASSETS / "bus.jpg"
    prefix = colorstr("AMP: ")
    LOGGER.info(f"{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...")
    warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False."
    skipped_load_msg = (
        f"{prefix}checks skipped ⚠️. "
        f"Unable to load YOLOv8n due to possible Ultralytics package modifications. {warning_msg}"
    )
    failed_msg = (
        f"{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to "
        f"NaN losses or zero-mAP results, so AMP will be disabled during training."
    )
    try:
        from ultralytics import YOLO

        assert amp_allclose(YOLO("yolov8n.pt"), test_image)
        LOGGER.info(f"{prefix}checks passed ✅")
    except ConnectionError:
        LOGGER.warning(f"{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}")
    except (AttributeError, ModuleNotFoundError):
        LOGGER.warning(skipped_load_msg)
    except AssertionError:
        LOGGER.warning(failed_msg)
        return False
    return True
def git_describe(path=ROOT):  # path must be a directory
    """
    Return a human-readable git description of the repository at `path`, i.e. v5.0-5-g3e25f1e.

    See https://git-scm.com/docs/git-describe.

    Args:
        path (Path | str): Directory inside a git repository.

    Returns:
        (str): Output of `git describe --tags --long --always` without trailing whitespace, or '' if git is not
            available or `path` is not inside a repository.
    """
    with contextlib.suppress(Exception):
        # Argument list instead of a shell string, so paths containing spaces or shell metacharacters are passed
        # intact. stderr is discarded so 'not a git repository' errors do not leak to the console.
        cmd = ["git", "-C", str(path), "describe", "--tags", "--long", "--always"]
        return subprocess.check_output(cmd, stderr=subprocess.DEVNULL).decode().strip()
    return ""
def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
    """
    Log the arguments of the calling function as 'key=value' pairs.

    Args:
        args (dict, optional): Arguments to log. If None, the caller's own parameters are read from its frame.
        show_file (bool): If True, prefix the line with the caller's file (relative to ROOT when possible).
        show_func (bool): If True, prefix the line with the caller's function name.
    """

    def strip_auth(value):
        """Clean long (>100 char) http(s) URL strings with clean_url() so embedded auth is not logged."""
        is_long_url = isinstance(value, str) and value.startswith("http") and len(value) > 100
        return clean_url(value) if is_long_url else value

    caller = inspect.currentframe().f_back  # frame of the function that called print_args
    file, _, func, _, _ = inspect.getframeinfo(caller)
    if args is None:  # get args automatically
        arg_names, _, _, caller_locals = inspect.getargvalues(caller)
        args = {k: v for k, v in caller_locals.items() if k in arg_names}
    try:
        file = Path(file).resolve().relative_to(ROOT).with_suffix("")
    except ValueError:
        file = Path(file).stem
    header = (f"{file}: " if show_file else "") + (f"{func}: " if show_func else "")
    body = ", ".join(f"{k}={strip_auth(v)}" for k, v in args.items())
    LOGGER.info(colorstr(header) + body)
def cuda_device_count() -> int:
    """
    Get the number of NVIDIA GPUs available in the environment.

    Queries `nvidia-smi`; any failure to launch it, a non-zero exit, or unparseable output is treated as "no GPUs".

    Returns:
        (int): The number of NVIDIA GPUs available, or 0 if it cannot be determined.
    """
    try:
        # Run the nvidia-smi command and capture its output
        output = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=count", "--format=csv,noheader,nounits"], encoding="utf-8"
        )

        # Take the first line and strip any leading/trailing white space
        first_line = output.strip().split("\n")[0]

        return int(first_line)
    except (subprocess.CalledProcessError, OSError, ValueError):
        # OSError covers FileNotFoundError (nvidia-smi missing) as well as PermissionError and other launch failures;
        # ValueError covers output that is not an integer. In all cases assume no GPUs are available.
        return 0
def cuda_is_available() -> bool:
    """
    Report whether at least one NVIDIA GPU is visible to the environment.

    Returns:
        (bool): True if cuda_device_count() finds one or more GPUs, False otherwise.
    """
    gpu_count = cuda_device_count()
    return gpu_count > 0
# Define constants: interpreter-version flags evaluated once at import time
IS_PYTHON_MINIMUM_3_10 = check_python(minimum="3.10", hard=False)  # True on Python >= 3.10, never raises
IS_PYTHON_3_12 = PYTHON_VERSION.startswith("3.12")  # True on any Python 3.12.x release
|