dvc.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, checks
try:
    # Each guard below raises AssertionError (caught at the bottom) to disable the integration entirely.
    assert not TESTS_RUNNING  # do not log pytest
    assert SETTINGS["dvc"] is True  # verify integration is enabled
    import dvclive

    assert checks.check_version("dvclive", "2.11.0", verbose=True)

    import os
    import re
    from pathlib import Path

    # DVCLive logger instance; stays None until `on_pretrain_routine_start` creates one.
    live = None
    # Plot path -> the "timestamp" value it had when last sent to DVCLive; `_log_plots` uses it to skip unchanged plots.
    _processed_plots = {}

    # `on_fit_epoch_end` is also called on the final validation (this probably needs to be fixed). For now this flag is
    # how we distinguish the final evaluation of the best model from the last epoch's validation.
    _training_epoch = False
except (ImportError, AssertionError, TypeError):
    # NOTE(review): TypeError is presumably a defensive catch for the settings lookup / version check — confirm.
    # When disabled, `callbacks` is built as {} (it checks `dvclive`), so the module state above is never needed.
    dvclive = None
  19. def _log_images(path, prefix=""):
  20. """Logs images at specified path with an optional prefix using DVCLive."""
  21. if live:
  22. name = path.name
  23. # Group images by batch to enable sliders in UI
  24. if m := re.search(r"_batch(\d+)", name):
  25. ni = m[1]
  26. new_stem = re.sub(r"_batch(\d+)", "_batch", path.stem)
  27. name = (Path(new_stem) / ni).with_suffix(path.suffix)
  28. live.log_image(os.path.join(prefix, name), path)
  29. def _log_plots(plots, prefix=""):
  30. """Logs plot images for training progress if they have not been previously processed."""
  31. for name, params in plots.items():
  32. timestamp = params["timestamp"]
  33. if _processed_plots.get(name) != timestamp:
  34. _log_images(name, prefix)
  35. _processed_plots[name] = timestamp
  36. def _log_confusion_matrix(validator):
  37. """Logs the confusion matrix for the given validator using DVCLive."""
  38. targets = []
  39. preds = []
  40. matrix = validator.confusion_matrix.matrix
  41. names = list(validator.names.values())
  42. if validator.confusion_matrix.task == "detect":
  43. names += ["background"]
  44. for ti, pred in enumerate(matrix.T.astype(int)):
  45. for pi, num in enumerate(pred):
  46. targets.extend([names[ti]] * num)
  47. preds.extend([names[pi]] * num)
  48. live.log_sklearn_plot("confusion_matrix", targets, preds, name="cf.json", normalized=True)
  49. def on_pretrain_routine_start(trainer):
  50. """Initializes DVCLive logger for training metadata during pre-training routine."""
  51. try:
  52. global live
  53. live = dvclive.Live(save_dvc_exp=True, cache_images=True)
  54. LOGGER.info("DVCLive is detected and auto logging is enabled (run 'yolo settings dvc=False' to disable).")
  55. except Exception as e:
  56. LOGGER.warning(f"WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. {e}")
  57. def on_pretrain_routine_end(trainer):
  58. """Logs plots related to the training process at the end of the pretraining routine."""
  59. _log_plots(trainer.plots, "train")
  60. def on_train_start(trainer):
  61. """Logs the training parameters if DVCLive logging is active."""
  62. if live:
  63. live.log_params(trainer.args)
  64. def on_train_epoch_start(trainer):
  65. """Sets the global variable _training_epoch value to True at the start of training each epoch."""
  66. global _training_epoch
  67. _training_epoch = True
  68. def on_fit_epoch_end(trainer):
  69. """Logs training metrics and model info, and advances to next step on the end of each fit epoch."""
  70. global _training_epoch
  71. if live and _training_epoch:
  72. all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr}
  73. for metric, value in all_metrics.items():
  74. live.log_metric(metric, value)
  75. if trainer.epoch == 0:
  76. from ultralytics.utils.torch_utils import model_info_for_loggers
  77. for metric, value in model_info_for_loggers(trainer).items():
  78. live.log_metric(metric, value, plot=False)
  79. _log_plots(trainer.plots, "train")
  80. _log_plots(trainer.validator.plots, "val")
  81. live.next_step()
  82. _training_epoch = False
  83. def on_train_end(trainer):
  84. """Logs the best metrics, plots, and confusion matrix at the end of training if DVCLive is active."""
  85. if live:
  86. # At the end log the best metrics. It runs validator on the best model internally.
  87. all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr}
  88. for metric, value in all_metrics.items():
  89. live.log_metric(metric, value, plot=False)
  90. _log_plots(trainer.plots, "val")
  91. _log_plots(trainer.validator.plots, "val")
  92. _log_confusion_matrix(trainer.validator)
  93. if trainer.best.exists():
  94. live.log_artifact(trainer.best, copy=True, type="model")
  95. live.end()
  96. callbacks = (
  97. {
  98. "on_pretrain_routine_start": on_pretrain_routine_start,
  99. "on_pretrain_routine_end": on_pretrain_routine_end,
  100. "on_train_start": on_train_start,
  101. "on_train_epoch_start": on_train_epoch_start,
  102. "on_fit_epoch_end": on_fit_epoch_end,
  103. "on_train_end": on_train_end,
  104. }
  105. if dvclive
  106. else {}
  107. )