123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145 |
- # Ultralytics YOLO 🚀, AGPL-3.0 license
- from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, checks
# Integration gate: every statement inside the `try` must succeed for DVCLive logging to be enabled.
# If any assertion or import fails, `dvclive` is set to None, which makes the `callbacks` mapping at the
# bottom of this module empty.
try:
    assert not TESTS_RUNNING  # do not log pytest
    assert SETTINGS["dvc"] is True  # verify integration is enabled
    import dvclive

    # Require at least dvclive 2.11.0; a failed version check is handled like a missing package.
    assert checks.check_version("dvclive", "2.11.0", verbose=True)

    import os
    import re
    from pathlib import Path

    # DVCLive logger instance; stays None until `on_pretrain_routine_start` creates it.
    live = None
    # Maps each plot name to the timestamp it was last logged with, so unchanged plots are not re-logged.
    _processed_plots = {}

    # `on_fit_epoch_end` is also invoked for the final validation (this probably needs fixing upstream).
    # For now this flag is how we distinguish the final evaluation of the best model from the validation
    # that follows the last training epoch: it is set at the start of each training epoch and cleared
    # once that epoch's metrics have been logged.
    _training_epoch = False

except (ImportError, AssertionError, TypeError):
    # NOTE: on this path `live`, `_processed_plots`, `_training_epoch`, `os`, `re` and `Path` are never
    # bound; the functions below rely on not being registered (see `callbacks`) when `dvclive` is None.
    dvclive = None
def _log_images(path, prefix=""):
    """
    Log a single image file to DVCLive, optionally under a name prefix.

    File names containing `_batch<N>` are renamed to `<stem with the number removed>/<N><suffix>` so that all
    batches of the same plot share one folder, which lets the UI present them with a slider.

    Args:
        path (Path): Location of the image file on disk.
        prefix (str): Leading path component joined in front of the logged image name.
    """
    # Nothing to do when no DVCLive logger is active.
    if not live:
        return

    image_name = path.name

    # Group images by batch to enable sliders in UI
    batch_match = re.search(r"_batch(\d+)", image_name)
    if batch_match:
        batch_index = batch_match[1]
        grouped_stem = re.sub(r"_batch(\d+)", "_batch", path.stem)
        image_name = (Path(grouped_stem) / batch_index).with_suffix(path.suffix)

    live.log_image(os.path.join(prefix, image_name), path)
def _log_plots(plots, prefix=""):
    """
    Log every plot whose timestamp differs from the one recorded the last time it was processed.

    Args:
        plots (dict): Mapping of plot name to a dict of parameters that contains a "timestamp" entry.
        prefix (str): Name prefix forwarded to `_log_images`.
    """
    for plot_path, meta in plots.items():
        stamp = meta["timestamp"]
        # Skip plots that were already handled with this exact timestamp.
        if _processed_plots.get(plot_path) == stamp:
            continue
        _log_images(plot_path, prefix)
        _processed_plots[plot_path] = stamp
def _log_confusion_matrix(validator):
    """
    Expand the validator's confusion matrix into label lists and log it as a DVCLive sklearn plot.

    Each row of the transposed matrix is treated as a target class and each column as a predicted class. For
    every cell, the corresponding pair of class names is repeated as many times as the cell's integer count.

    Args:
        validator: Object exposing `confusion_matrix` (with `matrix` and `task` attributes) and a `names` dict
            whose values are the class names.
    """
    confusion = validator.confusion_matrix
    counts = confusion.matrix
    labels = list(validator.names.values())
    # Detection matrices carry one extra class after the named ones.
    if confusion.task == "detect":
        labels.append("background")

    actual = []
    predicted = []
    for true_idx, row in enumerate(counts.T.astype(int)):
        for pred_idx, count in enumerate(row):
            actual.extend([labels[true_idx]] * count)
            predicted.extend([labels[pred_idx]] * count)

    live.log_sklearn_plot("confusion_matrix", actual, predicted, name="cf.json", normalized=True)
def on_pretrain_routine_start(trainer):
    """
    Create the DVCLive logger for this training run.

    The module-level `live` is reset before the attempt, so a logger left over from an earlier run in the same
    process is never reused when initialization fails; in that case `live` stays None and the other callbacks
    skip logging, matching the warning that is emitted.

    Args:
        trainer: The trainer invoking the callback (unused here).
    """
    global live

    # `on_train_end` ends the logger but does not clear the reference, so drop any stale instance first.
    live = None
    try:
        live = dvclive.Live(save_dvc_exp=True, cache_images=True)
    except Exception as e:
        # Best effort: a broken DVCLive setup must not abort training.
        LOGGER.warning(f"WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. {e}")
    else:
        LOGGER.info("DVCLive is detected and auto logging is enabled (run 'yolo settings dvc=False' to disable).")
def on_pretrain_routine_end(trainer):
    """
    Log the plots the trainer has collected by the end of the pretraining routine.

    Args:
        trainer: The trainer invoking the callback; its `plots` mapping is logged under the "train" prefix.
    """
    _log_plots(trainer.plots, prefix="train")
def on_train_start(trainer):
    """
    Record the training arguments as DVCLive parameters when a logger is active.

    Args:
        trainer: The trainer invoking the callback; its `args` are passed to `live.log_params`.
    """
    if not live:
        return
    live.log_params(trainer.args)
def on_train_epoch_start(trainer):
    """
    Mark that a training epoch is in progress.

    Sets the module-level `_training_epoch` flag, which `on_fit_epoch_end` checks before logging and clears
    afterwards.

    Args:
        trainer: The trainer invoking the callback (unused here).
    """
    global _training_epoch

    _training_epoch = True
def on_fit_epoch_end(trainer):
    """
    Log the epoch's metrics and plots to DVCLive and advance the logger to the next step.

    Does nothing unless a logger is active and `_training_epoch` is set, i.e. unless a training epoch was
    started since the last time this callback logged. Model information is logged once, on epoch 0.

    Args:
        trainer: The trainer invoking the callback; supplies losses, metrics, learning rates and plots.
    """
    global _training_epoch

    if not (live and _training_epoch):
        return

    # Later mappings win on key collisions: losses, then metrics, then learning rates.
    epoch_losses = trainer.label_loss_items(trainer.tloss, prefix="train")
    epoch_values = {**epoch_losses, **trainer.metrics, **trainer.lr}
    for key, val in epoch_values.items():
        live.log_metric(key, val)

    if trainer.epoch == 0:
        from ultralytics.utils.torch_utils import model_info_for_loggers

        for key, val in model_info_for_loggers(trainer).items():
            live.log_metric(key, val, plot=False)

    _log_plots(trainer.plots, "train")
    _log_plots(trainer.validator.plots, "val")

    live.next_step()
    # Cleared so that a later call without a new training epoch is ignored.
    _training_epoch = False
def on_train_end(trainer):
    """
    Log final metrics, plots, the confusion matrix and the best checkpoint, then end the DVCLive run.

    Does nothing when no logger is active.

    Args:
        trainer: The trainer invoking the callback; supplies losses, metrics, learning rates, plots, the
            validator and the path of the best checkpoint.
    """
    if not live:
        return

    # Final values are logged without plots. Per the original author's note, the trainer is expected to have
    # run the validator on the best model internally before this callback fires.
    final_losses = trainer.label_loss_items(trainer.tloss, prefix="train")
    summary = {**final_losses, **trainer.metrics, **trainer.lr}
    for key, val in summary.items():
        live.log_metric(key, val, plot=False)

    # NOTE(review): the trainer's own plots are logged under "val" here but under "train" elsewhere - confirm.
    _log_plots(trainer.plots, "val")
    _log_plots(trainer.validator.plots, "val")
    _log_confusion_matrix(trainer.validator)

    best_weights = trainer.best
    if best_weights.exists():
        live.log_artifact(best_weights, copy=True, type="model")

    live.end()
# Callback registry exported by this module: populated only when the DVCLive integration is available.
if dvclive:
    callbacks = {
        "on_pretrain_routine_start": on_pretrain_routine_start,
        "on_pretrain_routine_end": on_pretrain_routine_end,
        "on_train_start": on_train_start,
        "on_train_epoch_start": on_train_epoch_start,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_train_end": on_train_end,
    }
else:
    callbacks = {}
|