- # Ultralytics YOLO 🚀, AGPL-3.0 license
- """
- Train a model on a dataset.
- Usage:
- $ yolo mode=train model=yolov8n.pt data=coco8.yaml imgsz=640 epochs=100 batch=16
- """
import gc
import math
import os
import subprocess
import time
import warnings
from copy import deepcopy
from datetime import datetime, timedelta
from pathlib import Path

import numpy as np
import torch
from torch import distributed as dist
from torch import nn, optim

from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.extra_modules.kernel_warehouse import get_temperature
from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
from ultralytics.trainsdk import TrainSdk
from ultralytics.utils import (
    DEFAULT_CFG,
    LOGGER,
    RANK,
    TQDM,
    __version__,
    callbacks,
    clean_url,
    colorstr,
    emojis,
    yaml_save,
)
from ultralytics.utils.autobatch import check_train_batch_size
from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
from ultralytics.utils.files import get_latest_run
from ultralytics.utils.torch_utils import (
    EarlyStopping,
    ModelEMA,
    convert_optimizer_state_dict_to_fp16,
    init_seeds,
    one_cycle,
    select_device,
    strip_optimizer,
    torch_distributed_zero_first,
)
from ultralytics.vinno_metrics.image_test.object_metrics import ObjectMetrics, YoloMetas, YoloType
from ultralytics.vinno_metrics.image_test.seg_metrics import ResizeMode, SegMetrics, YoloInstanceSegMetas
class BaseTrainer:
    """
    BaseTrainer.

    A base class for creating trainers. Task-specific trainers are expected to override the hooks that raise
    NotImplementedError here (get_model, get_validator, get_dataloader, build_dataset).

    Attributes:
        args (SimpleNamespace): Configuration for the trainer.
        validator (BaseValidator): Validator instance.
        model (nn.Module): Model instance (holds the model file reference until setup_model() builds the module).
        callbacks (defaultdict): Dictionary of callbacks.
        save_dir (Path): Directory to save results.
        wdir (Path): Directory to save weights.
        last (Path): Path to the last checkpoint.
        best (Path): Path to the best checkpoint.
        save_period (int): Save checkpoint every x epochs (disabled if < 1).
        batch_size (int): Batch size for training.
        epochs (int): Number of epochs to train for.
        start_epoch (int): Starting epoch for training.
        device (torch.device): Device to use for training.
        amp (bool): Flag to enable AMP (Automatic Mixed Precision).
        scaler (amp.GradScaler): Gradient scaler for AMP.
        data (dict): Dataset description produced by the dataset check helpers in get_dataset().
        trainset: Training split entry returned by get_dataset() (empty string when training on the platform).
        testset: Validation/test split entry returned by get_dataset() (empty string when training on the platform).
        ema (nn.Module): EMA (Exponential Moving Average) of the model.
        resume (bool): Resume training from a checkpoint.
        lf (Callable): Learning-rate factor as a function of the epoch, handed to the LambdaLR scheduler.
        scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
        best_fitness (float): The best fitness value achieved.
        fitness (float): Current fitness value.
        loss (float): Current loss value.
        tloss (float): Total loss value.
        loss_names (list): List of loss names.
        csv (Path): Path to results CSV file.
    """
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Initializes the BaseTrainer class.

    Resolves the configuration (including resume handling), selects the device, seeds the RNGs, prepares the
    output directories, resolves the dataset and registers callbacks. The model itself is only built later in
    setup_model().

    Args:
        cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
        overrides (dict, optional): Configuration overrides. Defaults to None.
        _callbacks (dict, optional): Callback mapping to use instead of callbacks.get_default_callbacks().
    """
    self.args = get_cfg(cfg, overrides)
    self.check_resume(overrides)  # may replace self.args with the checkpoint's arguments
    self.device = select_device(self.args.device, self.args.batch)
    self.validator = None
    self.metrics = None
    self.plots = {}
    init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)

    # Dirs
    self.save_dir = get_save_dir(self.args)
    self.args.name = self.save_dir.name  # update name for loggers
    self.wdir = self.save_dir / "weights"  # weights dir
    if RANK in {-1, 0}:
        self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
        self.args.save_dir = str(self.save_dir)
        yaml_save(self.save_dir / "args.yaml", vars(self.args))  # save run args
        if self.args.is_train_on_platform:
            # Hand the saved run arguments to the training platform SDK
            TrainSdk.save_output_model(self.save_dir / "args.yaml")
    self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt"  # checkpoint paths
    self.save_period = self.args.save_period
    self.batch_size = self.args.batch
    self.epochs = self.args.epochs
    self.start_epoch = 0
    if RANK == -1:
        print_args(vars(self.args))

    # Device
    if self.device.type in {"cpu", "mps"}:
        self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading

    # Model and Dataset
    self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e. yolov8n -> yolov8n.pt
    with torch_distributed_zero_first(RANK):  # avoid auto-downloading dataset multiple times
        self.trainset, self.testset = self.get_dataset()
    self.ema = None

    # Optimization utils init
    self.lf = None
    self.scheduler = None

    # Epoch level metrics
    self.best_fitness = None
    self.fitness = None
    self.loss = None
    self.tloss = None
    self.loss_names = ["Loss"]
    self.csv = self.save_dir / "results.csv"
    self.plot_idx = [0, 1, 2]  # global batch indices whose training samples get plotted

    # HUB
    self.hub_session = None

    # Callbacks
    self.callbacks = _callbacks or callbacks.get_default_callbacks()
    if RANK in {-1, 0}:
        callbacks.add_integration_callbacks(self)
def add_callback(self, event: str, callback):
    """Register an extra callback to be run whenever `event` fires."""
    handlers = self.callbacks[event]
    handlers.append(callback)
def set_callback(self, event: str, callback):
    """Replace every callback registered for `event` with the single given callback."""
    replacement = [callback]
    self.callbacks[event] = replacement
def run_callbacks(self, event: str):
    """Invoke, in registration order, each callback registered for `event`, passing the trainer itself."""
    handlers = self.callbacks.get(event, [])
    for handler in handlers:
        handler(self)
def train(self):
    """
    Start training, delegating to a DDP subprocess when more than one device is requested.

    The world size is derived from `self.args.device`: the number of comma-separated entries of a non-empty
    string, the length of a list/tuple, otherwise 1 when CUDA is available and 0 when it is not. When the world
    size exceeds 1 and this process was not itself launched by DDP (no LOCAL_RANK in the environment), the DDP
    command is generated and run as a subprocess; in every other case training runs in-process via _do_train().

    Raises:
        subprocess.CalledProcessError: If the DDP subprocess exits with a non-zero status.
    """
    device = self.args.device
    # Allow device='', device=None on Multi-GPU systems to default to device=0.
    if isinstance(device, str) and len(device):  # i.e. device='0' or device='0,1,2,3'
        world_size = len(device.split(","))
    elif isinstance(device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
        world_size = len(device)
    elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
        world_size = 1  # default to device 0
    else:  # no explicit device string/list and CUDA is unavailable
        world_size = 0

    # Run subprocess if DDP training, else train normally
    if world_size > 1 and "LOCAL_RANK" not in os.environ:
        # Argument checks
        if self.args.rect:
            LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
            self.args.rect = False
        if self.args.batch < 1.0:
            LOGGER.warning(
                "WARNING ⚠️ 'batch<1' for AutoBatch is incompatible with Multi-GPU training, setting "
                "default 'batch=16'"
            )
            self.args.batch = 16

        # Command
        cmd, file = generate_ddp_command(world_size, self)
        try:
            LOGGER.info(f'{colorstr("DDP:")} debug command {" ".join(cmd)}')
            subprocess.run(cmd, check=True)
        finally:
            # Always remove the temporary DDP launch file; any exception propagates unchanged
            ddp_cleanup(self, str(file))
    else:
        self._do_train(world_size)
def _setup_scheduler(self):
    """Build the epoch -> LR factor function `self.lf` and wrap the optimizer in a LambdaLR scheduler."""
    if not self.args.cos_lr:

        def linear_factor(x):
            # linear: reads self.epochs / self.args.lrf at call time, exactly like the original lambda
            return max(1 - x / self.epochs, 0) * (1.0 - self.args.lrf) + self.args.lrf

        self.lf = linear_factor
    else:
        self.lf = one_cycle(1, self.args.lrf, self.epochs)  # cosine 1->hyp['lrf']
    self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
def _setup_ddp(self, world_size):
    """Select this rank's CUDA device and join the torch.distributed process group."""
    torch.cuda.set_device(RANK)
    self.device = torch.device("cuda", RANK)
    # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
    os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"  # set to enforce timeout
    backend = "nccl" if dist.is_nccl_available() else "gloo"
    three_hours = timedelta(seconds=10800)
    dist.init_process_group(backend=backend, timeout=three_hours, rank=RANK, world_size=world_size)
def _setup_train(self, world_size):
    """
    Builds dataloaders and optimizer on correct rank process.

    Loads the model, freezes the requested layers, resolves AMP, wraps the model for DDP when needed, and then
    creates the dataloaders, validator, EMA, optimizer, scheduler and early stopper before restoring any
    resume state.

    Args:
        world_size (int): Number of participating processes; values above 1 enable DDP wrapping.
    """
    # Model
    self.run_callbacks("on_pretrain_routine_start")
    ckpt = self.setup_model()
    self.model = self.model.to(self.device)
    self.set_model_attributes()

    # Freeze layers
    freeze_list = (
        self.args.freeze
        if isinstance(self.args.freeze, list)
        else range(self.args.freeze)
        if isinstance(self.args.freeze, int)
        else []
    )
    always_freeze_names = [".dfl"]  # always freeze these layers
    freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names
    for k, v in self.model.named_parameters():
        # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results)
        if any(x in k for x in freeze_layer_names):
            LOGGER.info(f"Freezing layer '{k}'")
            v.requires_grad = False
        # elif not v.requires_grad and v.dtype.is_floating_point:  # only floating point Tensor can require gradients
        #     LOGGER.info(
        #         f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. "
        #         "See ultralytics.engine.trainer for customization of frozen layers."
        #     )
        #     v.requires_grad = True

    # Check AMP
    self.amp = torch.tensor(self.args.amp).to(self.device)  # True or False
    if self.amp and RANK in {-1, 0}:  # Single-GPU and DDP
        callbacks_backup = callbacks.default_callbacks.copy()  # backup callbacks as check_amp() resets them
        self.amp = torch.tensor(check_amp(self.model), device=self.device)
        callbacks.default_callbacks = callbacks_backup  # restore callbacks
    if RANK > -1 and world_size > 1:  # DDP
        dist.broadcast(self.amp, src=0)  # broadcast the tensor from rank 0 to all other ranks (returns None)
    self.amp = bool(self.amp)  # as boolean
    self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp)
    if world_size > 1:
        self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK])

    # Check imgsz
    gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32)  # grid size (max stride)
    self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
    self.stride = gs  # for multiscale training

    # Batch size
    if self.batch_size < 1 and RANK == -1:  # single-GPU only, estimate best batch size
        self.args.batch = self.batch_size = check_train_batch_size(
            model=self.model,
            imgsz=self.args.imgsz,
            amp=self.amp,
            batch=self.batch_size,
        )

    # Dataloaders
    batch_size = self.batch_size // max(world_size, 1)  # per-process share of the configured batch size
    self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode="train")
    if RANK in {-1, 0}:
        # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
        self.test_loader = self.get_dataloader(
            self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val"
        )
        self.validator = self.get_validator()
        metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
        self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))  # every metric starts at 0
        self.ema = ModelEMA(self.model)
        if self.args.plots:
            self.plot_training_labels()

    # Optimizer
    self.accumulate = max(round(self.args.nbs / self.batch_size), 1)  # accumulate loss before optimizing
    weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs  # scale weight_decay
    iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
    self.optimizer = self.build_optimizer(
        model=self.model,
        name=self.args.optimizer,
        lr=self.args.lr0,
        momentum=self.args.momentum,
        decay=weight_decay,
        iterations=iterations,
    )

    # Scheduler
    self._setup_scheduler()
    self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
    self.resume_training(ckpt)
    self.scheduler.last_epoch = self.start_epoch - 1  # do not move
    self.run_callbacks("on_pretrain_routine_end")
def _do_train(self, world_size=1):
    """
    Run the training loop, then evaluate and plot if specified by arguments.

    Sets up DDP (when world_size > 1) and the training state, then iterates epochs until a stop condition is
    met: final epoch reached, early stopping triggered, or the optional time budget (`self.args.time`, in
    hours) exceeded. On rank -1/0 each epoch also validates, records metrics and saves checkpoints.

    Args:
        world_size (int): Number of participating processes. Defaults to 1.
    """
    if world_size > 1:
        self._setup_ddp(world_size)
    self._setup_train(world_size)

    nb = len(self.train_loader)  # number of batches
    nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1  # warmup iterations
    last_opt_step = -1  # global batch index of the most recent optimizer step
    self.epoch_time = None
    self.epoch_time_start = time.time()
    self.train_time_start = time.time()
    self.run_callbacks("on_train_start")
    LOGGER.info(
        f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
        f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
        f"Logging results to {colorstr('bold', self.save_dir)}\n"
        f'Starting training for ' + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...")
    )
    if self.args.close_mosaic:
        # Also plot the first three batches after mosaic augmentation is switched off
        base_idx = (self.epochs - self.args.close_mosaic) * nb
        self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
    epoch = self.start_epoch
    self.optimizer.zero_grad()  # zero any resumed gradients to ensure stability on train start
    while True:
        self.epoch = epoch
        self.run_callbacks("on_train_epoch_start")
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")  # suppress 'Detected lr_scheduler.step() before optimizer.step()'
            self.scheduler.step()

        self.model.train()
        if RANK != -1:
            self.train_loader.sampler.set_epoch(epoch)
        pbar = enumerate(self.train_loader)
        # Update dataloader attributes (optional)
        if epoch == (self.epochs - self.args.close_mosaic):
            self._close_dataloader_mosaic()
            self.train_loader.reset()

        if RANK in {-1, 0}:
            LOGGER.info(self.progress_string())
            pbar = TQDM(enumerate(self.train_loader), total=nb)
        self.tloss = None
        for i, batch in pbar:
            self.run_callbacks("on_train_batch_start")
            # Warmup
            ni = i + nb * epoch  # global batch index counted from epoch 0
            if ni <= nw:
                xi = [0, nw]  # x interp
                self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round()))
                for j, x in enumerate(self.optimizer.param_groups):
                    # Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x["lr"] = np.interp(
                        ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x["initial_lr"] * self.lf(epoch)]
                    )
                    if "momentum" in x:
                        x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])

            # Models exposing net_update_temperature get a fresh temperature from get_temperature() every batch
            if hasattr(self.model, 'net_update_temperature'):
                temp = get_temperature(i + 1, epoch, len(self.train_loader), temp_epoch=20, temp_init_value=1.0)
                self.model.net_update_temperature(temp)

            # Forward
            with torch.cuda.amp.autocast(self.amp):
                batch = self.preprocess_batch(batch)
                self.loss, self.loss_items = self.model(batch)
                if RANK != -1:
                    self.loss *= world_size
                # Running mean of the loss items over the batches of this epoch
                self.tloss = (
                    (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items
                )

            # Backward
            self.scaler.scale(self.loss).backward()

            # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
            if ni - last_opt_step >= self.accumulate:
                self.optimizer_step()
                last_opt_step = ni

                # Timed stopping
                if self.args.time:
                    self.stop = (time.time() - self.train_time_start) > (self.args.time * 3600)
                    if RANK != -1:  # if DDP training
                        broadcast_list = [self.stop if RANK == 0 else None]
                        dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
                        self.stop = broadcast_list[0]
                    if self.stop:  # training time exceeded
                        break

            # Log
            mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G"  # (GB)
            loss_len = self.tloss.shape[0] if len(self.tloss.shape) else 1
            losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
            if RANK in {-1, 0}:
                pbar.set_description(
                    ("%11s" * 2 + "%11.4g" * (2 + loss_len))
                    % (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])
                )
                self.run_callbacks("on_batch_end")
                if self.args.plots and ni in self.plot_idx:
                    self.plot_training_samples(batch, ni)

            self.run_callbacks("on_train_batch_end")

        self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
        self.run_callbacks("on_train_epoch_end")
        if RANK in {-1, 0}:
            final_epoch = epoch + 1 >= self.epochs
            self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])

            # Validation
            if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
                self.metrics, self.fitness = self.validate()
            self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
            self.stop |= self.stopper(epoch + 1, self.fitness) or final_epoch
            if self.args.time:
                self.stop |= (time.time() - self.train_time_start) > (self.args.time * 3600)

            # Save model
            if self.args.save or final_epoch:
                self.save_model()
                self.run_callbacks("on_model_save")

        # Scheduler
        t = time.time()
        self.epoch_time = t - self.epoch_time_start
        self.epoch_time_start = t
        if self.args.time:
            # Re-estimate how many epochs fit in the time budget and rebuild the scheduler for that horizon
            mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
            self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
            self._setup_scheduler()
            self.scheduler.last_epoch = self.epoch  # do not move
            self.stop |= epoch >= self.epochs  # stop if exceeded epochs
        self.run_callbacks("on_fit_epoch_end")
        gc.collect()
        torch.cuda.empty_cache()  # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors

        # Early Stopping
        if RANK != -1:  # if DDP training
            broadcast_list = [self.stop if RANK == 0 else None]
            dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
            self.stop = broadcast_list[0]
        if self.stop:
            break  # must break all DDP ranks
        epoch += 1

    if RANK in {-1, 0}:
        # Do final val with best.pt
        LOGGER.info(
            f"\n{epoch - self.start_epoch + 1} epochs completed in "
            f"{(time.time() - self.train_time_start) / 3600:.3f} hours."
        )
        self.final_eval()
        if self.args.plots:
            self.plot_metrics()
        self.run_callbacks("on_train_end")
    gc.collect()
    torch.cuda.empty_cache()
    self.run_callbacks("teardown")
def save_model(self):
    """
    Write the training checkpoint (EMA weights, optimizer state, args, metrics and metadata) to disk.

    The checkpoint is serialized once and the same bytes are written to last.pt, to best.pt when the current
    fitness equals the best fitness, and to epoch<N>.pt on the save period. When training on the platform,
    last.pt and best.pt are additionally handed to TrainSdk.
    """
    import io

    import pandas as pd  # scope for faster 'import ultralytics'

    checkpoint = {
        "epoch": self.epoch,
        "best_fitness": self.best_fitness,
        "model": None,  # resume and final checkpoints derive from EMA
        "ema": deepcopy(self.ema.ema).half(),
        "updates": self.ema.updates,
        "optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())),
        "train_args": vars(self.args),  # save as dict
        "train_metrics": {**self.metrics, "fitness": self.fitness},
        "train_results": {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()},
        "date": datetime.now().isoformat(),
        "version": __version__,
        "license": "AGPL-3.0 (https://ultralytics.com/license)",
        "docs": "https://docs.ultralytics.com",
    }

    # Serialize once into memory (faster than repeated torch.save() calls), then reuse the bytes
    stream = io.BytesIO()
    torch.save(checkpoint, stream)
    payload = stream.getvalue()

    # last.pt
    self.last.write_bytes(payload)
    if self.args.is_train_on_platform:
        TrainSdk.save_output_model(self.last)

    # best.pt
    is_best = self.best_fitness == self.fitness
    if is_best:
        self.best.write_bytes(payload)
        if self.args.is_train_on_platform:
            TrainSdk.save_output_model(self.best)

    # periodic snapshot, i.e. 'epoch3.pt'
    if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
        snapshot = self.wdir / f"epoch{self.epoch}.pt"
        snapshot.write_bytes(payload)
        # Disabled: additionally compute evaluation metrics for the ONNX-converted model. Currently only
        # YOLOv8 and v9 (models that need NMS post-processing) are supported; RTDETR and YOLOv10 (models
        # without post-processing) need extra handling after ONNX conversion and are not supported yet.
        # if self.args.is_train_on_platform:
        #     TrainSdk.save_output_model(self.wdir / f"epoch{self.epoch}.pt")
        #     from ultralytics import YOLO
        #     model = YOLO(self.wdir / f"epoch{self.epoch}.pt")
        #     model.export(format="onnx")  # export the model to ONNX format
        #     onnx_model = self.wdir / f"epoch{self.epoch}.onnx"
        #     TrainSdk.save_output_model(onnx_model)
        #     self.platform_test_metrics(onnx_model)
def platform_test_metrics(self, onnx_model):
    """
    Evaluate an exported ONNX model with the VINNO platform metric runners.

    Intended for on-platform training: the model has been converted to ONNX, the default resize mode is used
    and the C++ post-processing is invoked to obtain the evaluation metrics. Does nothing unless
    `self.args.is_train_on_platform` is set and `self.data` carries a non-empty 'platform_data_args' entry.
    Data types other than 'detection' and 'segment' are ignored.

    Args:
        onnx_model: Path of the ONNX model file handed to the metric runners as `onnx_file`.

    Raises:
        KeyError: If a required entry is missing from 'platform_data_args' or its metas sub-dictionary.
    """
    if not self.args.is_train_on_platform:
        return
    platform_data_args = self.data.get("platform_data_args")  # a missing key is treated like an empty value
    if not platform_data_args:
        return

    token = platform_data_args["token"]
    data_type = platform_data_args["data_type"]
    class_id_map_list = platform_data_args["class_id_map_list"]

    # im_files are not taken from the val/test dataset for now: on-platform training currently keeps train
    # and test data separate, so the test files are addressed directly by their index.
    files = [str(i) for i in range(TrainSdk.get_test_file_count(token))]

    # Keyword arguments shared by the detection and segmentation metric runners
    shared_kwargs = {
        "is_local_file": False,
        "files": files,
        "token": token,
        "onnx_file": onnx_model,
        "needed_image_results_dict": platform_data_args["needed_image_results_dict"],
        "needed_rois_dict": platform_data_args["needed_rois_dict"],
        "extra_contours_args": platform_data_args["extra_contours_args"],
        "wrong_file": platform_data_args["wrong_file"],
        "dll_file": platform_data_args["dll_file"],
        "metrics_type": platform_data_args["metrics_type"],
    }

    if data_type == "detection":
        meta_cfg = platform_data_args["yolo_metas"]
        yolo_metas = YoloMetas(
            yolotype=YoloType[meta_cfg["yolotype"]],
            confthres=meta_cfg["confthres"],
            clsconfthres=meta_cfg["clsconfthres"],
            batchsize=meta_cfg["batchsize"],
            maxdet=meta_cfg["maxdet"],
            minboxratio=meta_cfg["minboxratio"],
            # Used in ApplyPostProcessToBBox to filter boxes within a single class
            postprocesstopk=meta_cfg["postprocesstopk"],
            enableioufilt=meta_cfg["enableioufilt"],
            enableiosfilt=meta_cfg["enableiosfilt"],
            ioufltth=meta_cfg["ioufltth"],
            iosfltth=meta_cfg["iosfltth"],
            # Used in FindBoxesToUnion: two boxes are merged when they overlap strongly enough and the
            # merged box adds little area
            enableunion=meta_cfg["enableunion"],
            unioniouth=meta_cfg["unioniouth"],
            unioniosth=meta_cfg["unioniosth"],
            unionuobth=meta_cfg["unionuobth"],
            # Used in ApplyBoxClassFilter: with several boxes on one image, filters boxes across
            # different classes
            enableioufiltdiffcls=meta_cfg["enableioufiltdiffcls"],
            enableiosfiltdiffcls=meta_cfg["enableiosfiltdiffcls"],
            ioufltthdiffcls=meta_cfg["ioufltthdiffcls"],
            iosfltthdiffcls=meta_cfg["iosfltthdiffcls"],
        )
        for class_id_map in class_id_map_list:
            ObjectMetrics(class_id_map=class_id_map, yolo_metas=yolo_metas, **shared_kwargs).run()
    elif data_type == "segment":
        meta_cfg = platform_data_args["yolo_inst_seg_metas"]
        resize_mode = ResizeMode[platform_data_args["resize_mode"]]
        seg_post_process_param = platform_data_args["seg_post_process_param"]
        yolo_inst_seg_metas = YoloInstanceSegMetas(
            yolo_type=YoloType[meta_cfg["yolo_type"]],
            box_conf_thres=meta_cfg["box_conf_thres"],
            cls_conf_thres=meta_cfg["cls_conf_thres"],
            batch_size=meta_cfg["batch_size"],
            max_det=meta_cfg["max_det"],
            min_box_ratio=meta_cfg["min_box_ratio"],
            post_process_top_k=meta_cfg["post_process_top_k"],
            iou_fltth=meta_cfg["iou_fltth"],
            ios_fltth=meta_cfg["ios_fltth"],
            enable_iou_filt=meta_cfg["enable_iou_filt"],
            enable_ios_filt=meta_cfg["enable_ios_filt"],
            iou_fltth_diff_cls=meta_cfg["iou_fltth_diff_cls"],
            ios_fltth_diff_cls=meta_cfg["ios_fltth_diff_cls"],
            enable_iou_filt_diff_cls=meta_cfg["enable_iou_filt_diff_cls"],
            # Fixed: previously read 'enable_iou_filt_diff_cls', so the IoS switch mirrored the IoU one
            enable_ios_filt_diff_cls=meta_cfg["enable_ios_filt_diff_cls"],
            mask_thres=meta_cfg["mask_thres"],
        )
        for class_id_map in class_id_map_list:
            SegMetrics(
                class_id_map=class_id_map,
                resize_mode=resize_mode,
                yolo_inst_seg_metas=yolo_inst_seg_metas,
                seg_post_process_param=seg_post_process_param,
                use_contour_for_iou=True,
                **shared_kwargs,
            ).run()
def get_dataset(self):
    """
    Resolve the dataset description and return the train and val/test entries.

    Classification tasks are checked with check_cls_dataset(); YAML data files and detect/segment/pose/obb
    tasks are checked with check_det_dataset(). The resulting description is stored on `self.data`.

    Returns:
        (tuple): `(train, val_or_test)` taken from the dataset description, or `("", "")` when training on
            the platform, where no train/val/test paths are needed.

    Raises:
        RuntimeError: If the dataset check fails or the task/data combination is not recognized.
    """
    try:
        if self.args.task == "classify":
            data = check_cls_dataset(self.args.data)
        elif self.args.data.split(".")[-1] in {"yaml", "yml"} or self.args.task in {
            "detect",
            "segment",
            "pose",
            "obb",
        }:
            data = check_det_dataset(self.args.data, self.args.is_train_on_platform)
            if "yaml_file" in data:
                self.args.data = data["yaml_file"]  # for validating 'yolo train data=url.zip' usage
        else:
            # Previously fell through and failed later with an UnboundLocalError on `data`
            raise ValueError(f"unrecognized dataset format for task '{self.args.task}'")
    except Exception as e:
        raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
    self.data = data
    # When training on the platform no train/val/test paths are needed, so return empty strings
    if self.args.is_train_on_platform:
        return "", ""
    return data["train"], data.get("val") or data.get("test")
def setup_model(self):
    """
    Build `self.model` from a weights file or configuration unless it already is a module.

    Returns the checkpoint loaded from a '.pt' model file, otherwise None.
    """
    model_ref = self.model
    if isinstance(model_ref, torch.nn.Module):  # model was loaded beforehand, nothing to set up
        return None

    checkpoint = None
    pretrained_weights = None
    config = model_ref
    if str(model_ref).endswith(".pt"):
        pretrained_weights, checkpoint = attempt_load_one_weight(model_ref)
        config = pretrained_weights.yaml
    elif isinstance(self.args.pretrained, (str, Path)):
        pretrained_weights, _ = attempt_load_one_weight(self.args.pretrained)
    self.model = self.get_model(cfg=config, weights=pretrained_weights, verbose=RANK == -1)  # calls Model(cfg, weights)
    return checkpoint
def optimizer_step(self):
    """Unscale and clip gradients, step the optimizer through the AMP scaler, then update the EMA if present."""
    scaler, optimizer = self.scaler, self.optimizer
    scaler.unscale_(optimizer)  # unscale gradients
    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)  # clip gradients
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
    if not self.ema:
        return
    self.ema.update(self.model)
def preprocess_batch(self, batch):
    """Hook for task-specific preprocessing of a batch; the base implementation returns it untouched."""
    return batch
def validate(self):
    """
    Runs validation on test set using self.validator.

    The returned dict is expected to contain a "fitness" key; when it does not, the negated current loss is
    used as the fitness measure instead. `self.best_fitness` is updated when it is still unset or when the
    new fitness is strictly higher.

    Returns:
        (tuple): `(metrics, fitness)`, where `metrics` no longer contains the "fitness" entry.
    """
    metrics = self.validator(self)
    if "fitness" in metrics:
        fitness = metrics.pop("fitness")
    else:
        # Only touch self.loss when it is actually needed as the fallback
        fitness = -self.loss.detach().cpu().numpy()  # use loss as fitness measure if not found
    # Compare against None explicitly: a best fitness of 0.0 is a valid value and must not be overwritten
    # by a lower one
    if self.best_fitness is None or self.best_fitness < fitness:
        self.best_fitness = fitness
    return metrics, fitness
def get_model(self, cfg=None, weights=None, verbose=True):
    """Model factory hook; the base trainer cannot build a model and always raises NotImplementedError."""
    message = "This task trainer doesn't support loading cfg files"
    raise NotImplementedError(message)
def get_validator(self):
    """Validator factory hook; the base trainer always raises NotImplementedError."""
    message = "get_validator function not implemented in trainer"
    raise NotImplementedError(message)
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
    """Dataloader factory hook; the base trainer always raises NotImplementedError."""
    message = "get_dataloader function not implemented in trainer"
    raise NotImplementedError(message)
def build_dataset(self, img_path, mode="train", batch=None):
    """Dataset factory hook; the base trainer always raises NotImplementedError."""
    message = "build_dataset function not implemented in trainer"
    raise NotImplementedError(message)
def label_loss_items(self, loss_items=None, prefix="train"):
    """
    Label the given loss items, or list the loss keys when no items are given.

    Returns `{"loss": loss_items}` when `loss_items` is provided and `["loss"]` otherwise. `prefix` is unused
    by this base implementation.

    Note:
        This is not needed for classification but necessary for segmentation & detection
    """
    if loss_items is None:
        return ["loss"]
    return {"loss": loss_items}
def set_model_attributes(self):
    """Copy the dataset's class names onto the model before training."""
    class_names = self.data["names"]
    self.model.names = class_names
def build_targets(self, preds, targets):
    """Target-building hook; the base implementation does nothing and returns None."""
    return None
def progress_string(self):
    """Progress header hook; the base implementation returns an empty string."""
    header = ""
    return header
- # TODO: may need to put these following functions into callback
def plot_training_samples(self, batch, ni):
    """Training-sample plotting hook; the base implementation does nothing and returns None."""
    return None
def plot_training_labels(self):
    """Training-label plotting hook; the base implementation does nothing and returns None."""
    return None
def save_metrics(self, metrics):
    """
    Append one row of metrics for the current epoch to the results CSV.

    A header row is written first when the CSV file does not exist yet. The first column is the 1-based epoch
    number. When training on the platform the CSV is also handed to TrainSdk.
    """
    column_names = ["epoch", *metrics.keys()]
    row_values = [self.epoch + 1, *metrics.values()]
    n_cols = len(metrics) + 1  # number of cols
    header = ""
    if not self.csv.exists():
        header = ("%23s," * n_cols % tuple(column_names)).rstrip(",") + "\n"
    with open(self.csv, "a") as handle:
        row = ("%23.5g," * n_cols % tuple(row_values)).rstrip(",") + "\n"
        handle.write(header + row)
    if self.args.is_train_on_platform:
        TrainSdk.save_output_model(self.csv)
def plot_metrics(self):
    """Metric plotting hook; the base implementation does nothing and returns None."""
    return None
def on_plot(self, name, data=None):
    """Record a plot under its path, together with its data and the registration time (e.g. for callbacks)."""
    key = Path(name)
    entry = {"data": data, "timestamp": time.time()}
    self.plots[key] = entry
def final_eval(self):
    """Strip optimizer state from the existing last/best checkpoints and run a final validation on best."""
    for weights_file in (self.last, self.best):
        if not weights_file.exists():
            continue
        strip_optimizer(weights_file)  # strip optimizers
        if weights_file is not self.best:
            continue
        LOGGER.info(f"\nValidating {weights_file}...")
        self.validator.args.plots = self.args.plots
        self.metrics = self.validator(model=weights_file)
        self.metrics.pop("fitness", None)
        self.run_callbacks("on_fit_epoch_end")
def check_resume(self, overrides):
    """
    Check if resume checkpoint exists and update arguments accordingly.

    When `self.args.resume` is truthy, the checkpoint is located (the given path if it exists, otherwise the
    latest run), its stored training arguments replace `self.args`, and the 'imgsz', 'batch' and 'device'
    overrides are re-applied on top. The outcome is stored on `self.resume`.

    Args:
        overrides (dict | None): Configuration overrides; None is treated as no overrides.

    Raises:
        FileNotFoundError: If the checkpoint cannot be located or loaded.
    """
    resume = self.args.resume
    if resume:
        # A None `overrides` used to fail in the membership test below and surface as a misleading
        # "checkpoint not found" error
        overrides = overrides or {}
        try:
            exists = isinstance(resume, (str, Path)) and Path(resume).exists()
            last = Path(check_file(resume) if exists else get_latest_run())

            # Check that resume data YAML exists, otherwise strip to force re-download of dataset
            ckpt_args = attempt_load_weights(last).args
            if not Path(ckpt_args["data"]).exists():
                ckpt_args["data"] = self.args.data

            resume = True
            self.args = get_cfg(ckpt_args)
            self.args.model = self.args.resume = str(last)  # reinstate model
            for k in "imgsz", "batch", "device":  # allow arg updates to reduce memory or update device on resume
                if k in overrides:
                    setattr(self.args, k, overrides[k])

        except Exception as e:
            raise FileNotFoundError(
                "Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
                "i.e. 'yolo train resume model=path/to/last.pt'"
            ) from e
    self.resume = resume
def resume_training(self, ckpt):
    """
    Restore optimizer, EMA, best fitness and start epoch from a checkpoint when resuming.

    Does nothing when no checkpoint is given or resuming is disabled.
    """
    if not (ckpt is not None and self.resume):
        return

    resume_epoch = ckpt.get("epoch", -1) + 1
    restored_fitness = 0.0
    optimizer_state = ckpt.get("optimizer", None)
    if optimizer_state is not None:
        self.optimizer.load_state_dict(optimizer_state)  # optimizer
        restored_fitness = ckpt["best_fitness"]
    if self.ema and ckpt.get("ema"):
        ema_state = ckpt["ema"].float().state_dict()
        self.ema.ema.load_state_dict(ema_state)  # EMA
        self.ema.updates = ckpt["updates"]

    assert resume_epoch > 0, (
        f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
        f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
    )
    LOGGER.info(f"Resuming training {self.args.model} from epoch {resume_epoch + 1} to {self.epochs} total epochs")
    if self.epochs < resume_epoch:
        LOGGER.info(
            f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
        )
        self.epochs += ckpt["epoch"]  # finetune additional epochs

    self.best_fitness = restored_fitness
    self.start_epoch = resume_epoch
    # Resuming past the mosaic cut-off: switch mosaic off right away
    if resume_epoch > (self.epochs - self.args.close_mosaic):
        self._close_dataloader_mosaic()
def _close_dataloader_mosaic(self):
    """Switch off mosaic augmentation on the training dataset, using whichever hooks the dataset exposes."""
    dataset = self.train_loader.dataset
    if hasattr(dataset, "mosaic"):
        dataset.mosaic = False
    if hasattr(dataset, "close_mosaic"):
        LOGGER.info("Closing dataloader mosaic")
        dataset.close_mosaic(hyp=self.args)
def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
    """
    Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum,
    weight decay, and number of iterations.

    Parameters are split into three groups: biases (no decay), normalization-layer weights (no decay) and all
    remaining weights (with `decay`).

    Args:
        model (torch.nn.Module): The model for which to build an optimizer.
        name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
            based on the number of iterations. Default: 'auto'.
        lr (float, optional): The learning rate for the optimizer. Default: 0.001.
        momentum (float, optional): The momentum factor for the optimizer. Default: 0.9.
        decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
        iterations (float, optional): The number of iterations, which determines the optimizer if
            name is 'auto'. Default: 1e5.

    Returns:
        (torch.optim.Optimizer): The constructed optimizer.

    Raises:
        NotImplementedError: If `name` is not one of the supported optimizers.
    """
    g = [], [], []  # optimizer parameter groups
    bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # normalization layers, i.e. BatchNorm2d()
    if name == "auto":
        LOGGER.info(
            f"{colorstr('optimizer:')} 'optimizer=auto' found, "
            f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
            f"determining best 'optimizer', 'lr0' and 'momentum' automatically... "
        )
        nc = getattr(model, "nc", 10)  # number of classes
        lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation to 6 decimal places
        name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
        self.args.warmup_bias_lr = 0.0  # no higher than 0.01 for Adam

    for module_name, module in model.named_modules():
        for param_name, param in module.named_parameters(recurse=False):
            fullname = f"{module_name}.{param_name}" if module_name else param_name
            if "bias" in fullname:  # bias (no decay)
                g[2].append(param)
            elif isinstance(module, bn):  # weight (no decay)
                g[1].append(param)
            else:  # weight (with decay)
                g[0].append(param)

    # The optimizer is created on the bias group; the two weight groups are added below
    if name in {"Adam", "Adamax", "AdamW", "NAdam", "RAdam"}:
        optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
    elif name == "RMSProp":
        optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
    elif name == "SGD":
        optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
    else:
        raise NotImplementedError(
            f"Optimizer '{name}' not found in list of available optimizers "
            "[Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, SGD, auto]. "
            "To request support for additional optimizers please visit https://github.com/ultralytics/ultralytics."
        )

    optimizer.add_param_group({"params": g[0], "weight_decay": decay})  # add g0 with weight_decay
    optimizer.add_param_group({"params": g[1], "weight_decay": 0.0})  # add g1 (BatchNorm2d weights)
    LOGGER.info(
        f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
        f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)'
    )
    return optimizer