session.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. import threading
  3. import time
  4. from http import HTTPStatus
  5. from pathlib import Path
  6. import requests
  7. from ultralytics.hub.utils import HELP_MSG, HUB_WEB_ROOT, PREFIX, TQDM
  8. from ultralytics.utils import IS_COLAB, LOGGER, SETTINGS, __version__, checks, emojis
  9. from ultralytics.utils.errors import HUBModelError
  10. AGENT_NAME = f"python-{__version__}-colab" if IS_COLAB else f"python-{__version__}-local"
  11. class HUBTrainingSession:
  12. """
  13. HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.
  14. Attributes:
  15. model_id (str): Identifier for the YOLO model being trained.
  16. model_url (str): URL for the model in Ultralytics HUB.
  17. rate_limits (dict): Rate limits for different API calls (in seconds).
  18. timers (dict): Timers for rate limiting.
  19. metrics_queue (dict): Queue for the model's metrics.
  20. model (dict): Model data fetched from Ultralytics HUB.
  21. """
  22. def __init__(self, identifier):
  23. """
  24. Initialize the HUBTrainingSession with the provided model identifier.
  25. Args:
  26. identifier (str): Model identifier used to initialize the HUB training session.
  27. It can be a URL string or a model key with specific format.
  28. Raises:
  29. ValueError: If the provided model identifier is invalid.
  30. ConnectionError: If connecting with global API key is not supported.
  31. ModuleNotFoundError: If hub-sdk package is not installed.
  32. """
  33. from hub_sdk import HUBClient
  34. self.rate_limits = {"metrics": 3, "ckpt": 900, "heartbeat": 300} # rate limits (seconds)
  35. self.metrics_queue = {} # holds metrics for each epoch until upload
  36. self.metrics_upload_failed_queue = {} # holds metrics for each epoch if upload failed
  37. self.timers = {} # holds timers in ultralytics/utils/callbacks/hub.py
  38. self.model = None
  39. self.model_url = None
  40. # Parse input
  41. api_key, model_id, self.filename = self._parse_identifier(identifier)
  42. # Get credentials
  43. active_key = api_key or SETTINGS.get("api_key")
  44. credentials = {"api_key": active_key} if active_key else None # set credentials
  45. # Initialize client
  46. self.client = HUBClient(credentials)
  47. # Load models if authenticated
  48. if self.client.authenticated:
  49. if model_id:
  50. self.load_model(model_id) # load existing model
  51. else:
  52. self.model = self.client.model() # load empty model
  53. @classmethod
  54. def create_session(cls, identifier, args=None):
  55. """Class method to create an authenticated HUBTrainingSession or return None."""
  56. try:
  57. session = cls(identifier)
  58. if not session.client.authenticated:
  59. if identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
  60. LOGGER.warning(f"{PREFIX}WARNING ⚠️ Login to Ultralytics HUB with 'yolo hub login API_KEY'.")
  61. exit()
  62. return None
  63. if args and not identifier.startswith(f"{HUB_WEB_ROOT}/models/"): # not a HUB model URL
  64. session.create_model(args)
  65. assert session.model.id, "HUB model not loaded correctly"
  66. return session
  67. # PermissionError and ModuleNotFoundError indicate hub-sdk not installed
  68. except (PermissionError, ModuleNotFoundError, AssertionError):
  69. return None
  70. def load_model(self, model_id):
  71. """Loads an existing model from Ultralytics HUB using the provided model identifier."""
  72. self.model = self.client.model(model_id)
  73. if not self.model.data: # then model does not exist
  74. raise ValueError(emojis("❌ The specified HUB model does not exist")) # TODO: improve error handling
  75. self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
  76. self._set_train_args()
  77. # Start heartbeats for HUB to monitor agent
  78. self.model.start_heartbeat(self.rate_limits["heartbeat"])
  79. LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
  80. def create_model(self, model_args):
  81. """Initializes a HUB training session with the specified model identifier."""
  82. payload = {
  83. "config": {
  84. "batchSize": model_args.get("batch", -1),
  85. "epochs": model_args.get("epochs", 300),
  86. "imageSize": model_args.get("imgsz", 640),
  87. "patience": model_args.get("patience", 100),
  88. "device": str(model_args.get("device", "")), # convert None to string
  89. "cache": str(model_args.get("cache", "ram")), # convert True, False, None to string
  90. },
  91. "dataset": {"name": model_args.get("data")},
  92. "lineage": {
  93. "architecture": {"name": self.filename.replace(".pt", "").replace(".yaml", "")},
  94. "parent": {},
  95. },
  96. "meta": {"name": self.filename},
  97. }
  98. if self.filename.endswith(".pt"):
  99. payload["lineage"]["parent"]["name"] = self.filename
  100. self.model.create_model(payload)
  101. # Model could not be created
  102. # TODO: improve error handling
  103. if not self.model.id:
  104. return None
  105. self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
  106. # Start heartbeats for HUB to monitor agent
  107. self.model.start_heartbeat(self.rate_limits["heartbeat"])
  108. LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
  109. @staticmethod
  110. def _parse_identifier(identifier):
  111. """
  112. Parses the given identifier to determine the type of identifier and extract relevant components.
  113. The method supports different identifier formats:
  114. - A HUB URL, which starts with HUB_WEB_ROOT followed by '/models/'
  115. - An identifier containing an API key and a model ID separated by an underscore
  116. - An identifier that is solely a model ID of a fixed length
  117. - A local filename that ends with '.pt' or '.yaml'
  118. Args:
  119. identifier (str): The identifier string to be parsed.
  120. Returns:
  121. (tuple): A tuple containing the API key, model ID, and filename as applicable.
  122. Raises:
  123. HUBModelError: If the identifier format is not recognized.
  124. """
  125. # Initialize variables
  126. api_key, model_id, filename = None, None, None
  127. # Check if identifier is a HUB URL
  128. if identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
  129. # Extract the model_id after the HUB_WEB_ROOT URL
  130. model_id = identifier.split(f"{HUB_WEB_ROOT}/models/")[-1]
  131. else:
  132. # Split the identifier based on underscores only if it's not a HUB URL
  133. parts = identifier.split("_")
  134. # Check if identifier is in the format of API key and model ID
  135. if len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:
  136. api_key, model_id = parts
  137. # Check if identifier is a single model ID
  138. elif len(parts) == 1 and len(parts[0]) == 20:
  139. model_id = parts[0]
  140. # Check if identifier is a local filename
  141. elif identifier.endswith(".pt") or identifier.endswith(".yaml"):
  142. filename = identifier
  143. else:
  144. raise HUBModelError(
  145. f"model='{identifier}' could not be parsed. Check format is correct. "
  146. f"Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file."
  147. )
  148. return api_key, model_id, filename
  149. def _set_train_args(self):
  150. """
  151. Initializes training arguments and creates a model entry on the Ultralytics HUB.
  152. This method sets up training arguments based on the model's state and updates them with any additional
  153. arguments provided. It handles different states of the model, such as whether it's resumable, pretrained,
  154. or requires specific file setup.
  155. Raises:
  156. ValueError: If the model is already trained, if required dataset information is missing, or if there are
  157. issues with the provided training arguments.
  158. """
  159. if self.model.is_trained():
  160. raise ValueError(emojis(f"Model is already trained and uploaded to {self.model_url} 🚀"))
  161. if self.model.is_resumable():
  162. # Model has saved weights
  163. self.train_args = {"data": self.model.get_dataset_url(), "resume": True}
  164. self.model_file = self.model.get_weights_url("last")
  165. else:
  166. # Model has no saved weights
  167. self.train_args = self.model.data.get("train_args") # new response
  168. # Set the model file as either a *.pt or *.yaml file
  169. self.model_file = (
  170. self.model.get_weights_url("parent") if self.model.is_pretrained() else self.model.get_architecture()
  171. )
  172. if "data" not in self.train_args:
  173. # RF bug - datasets are sometimes not exported
  174. raise ValueError("Dataset may still be processing. Please wait a minute and try again.")
  175. self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # YOLOv5->YOLOv5u
  176. self.model_id = self.model.id
  177. def request_queue(
  178. self,
  179. request_func,
  180. retry=3,
  181. timeout=30,
  182. thread=True,
  183. verbose=True,
  184. progress_total=None,
  185. stream_response=None,
  186. *args,
  187. **kwargs,
  188. ):
  189. """Attempts to execute `request_func` with retries, timeout handling, optional threading, and progress."""
  190. def retry_request():
  191. """Attempts to call `request_func` with retries, timeout, and optional threading."""
  192. t0 = time.time() # Record the start time for the timeout
  193. response = None
  194. for i in range(retry + 1):
  195. if (time.time() - t0) > timeout:
  196. LOGGER.warning(f"{PREFIX}Timeout for request reached. {HELP_MSG}")
  197. break # Timeout reached, exit loop
  198. response = request_func(*args, **kwargs)
  199. if response is None:
  200. LOGGER.warning(f"{PREFIX}Received no response from the request. {HELP_MSG}")
  201. time.sleep(2**i) # Exponential backoff before retrying
  202. continue # Skip further processing and retry
  203. if progress_total:
  204. self._show_upload_progress(progress_total, response)
  205. elif stream_response:
  206. self._iterate_content(response)
  207. if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:
  208. # if request related to metrics upload
  209. if kwargs.get("metrics"):
  210. self.metrics_upload_failed_queue = {}
  211. return response # Success, no need to retry
  212. if i == 0:
  213. # Initial attempt, check status code and provide messages
  214. message = self._get_failure_message(response, retry, timeout)
  215. if verbose:
  216. LOGGER.warning(f"{PREFIX}{message} {HELP_MSG} ({response.status_code})")
  217. if not self._should_retry(response.status_code):
  218. LOGGER.warning(f"{PREFIX}Request failed. {HELP_MSG} ({response.status_code}")
  219. break # Not an error that should be retried, exit loop
  220. time.sleep(2**i) # Exponential backoff for retries
  221. # if request related to metrics upload and exceed retries
  222. if response is None and kwargs.get("metrics"):
  223. self.metrics_upload_failed_queue.update(kwargs.get("metrics", None))
  224. return response
  225. if thread:
  226. # Start a new thread to run the retry_request function
  227. threading.Thread(target=retry_request, daemon=True).start()
  228. else:
  229. # If running in the main thread, call retry_request directly
  230. return retry_request()
  231. @staticmethod
  232. def _should_retry(status_code):
  233. """Determines if a request should be retried based on the HTTP status code."""
  234. retry_codes = {
  235. HTTPStatus.REQUEST_TIMEOUT,
  236. HTTPStatus.BAD_GATEWAY,
  237. HTTPStatus.GATEWAY_TIMEOUT,
  238. }
  239. return status_code in retry_codes
  240. def _get_failure_message(self, response: requests.Response, retry: int, timeout: int):
  241. """
  242. Generate a retry message based on the response status code.
  243. Args:
  244. response: The HTTP response object.
  245. retry: The number of retry attempts allowed.
  246. timeout: The maximum timeout duration.
  247. Returns:
  248. (str): The retry message.
  249. """
  250. if self._should_retry(response.status_code):
  251. return f"Retrying {retry}x for {timeout}s." if retry else ""
  252. elif response.status_code == HTTPStatus.TOO_MANY_REQUESTS: # rate limit
  253. headers = response.headers
  254. return (
  255. f"Rate limit reached ({headers['X-RateLimit-Remaining']}/{headers['X-RateLimit-Limit']}). "
  256. f"Please retry after {headers['Retry-After']}s."
  257. )
  258. else:
  259. try:
  260. return response.json().get("message", "No JSON message.")
  261. except AttributeError:
  262. return "Unable to read JSON."
  263. def upload_metrics(self):
  264. """Upload model metrics to Ultralytics HUB."""
  265. return self.request_queue(self.model.upload_metrics, metrics=self.metrics_queue.copy(), thread=True)
  266. def upload_model(
  267. self,
  268. epoch: int,
  269. weights: str,
  270. is_best: bool = False,
  271. map: float = 0.0,
  272. final: bool = False,
  273. ) -> None:
  274. """
  275. Upload a model checkpoint to Ultralytics HUB.
  276. Args:
  277. epoch (int): The current training epoch.
  278. weights (str): Path to the model weights file.
  279. is_best (bool): Indicates if the current model is the best one so far.
  280. map (float): Mean average precision of the model.
  281. final (bool): Indicates if the model is the final model after training.
  282. """
  283. if Path(weights).is_file():
  284. progress_total = Path(weights).stat().st_size if final else None # Only show progress if final
  285. self.request_queue(
  286. self.model.upload_model,
  287. epoch=epoch,
  288. weights=weights,
  289. is_best=is_best,
  290. map=map,
  291. final=final,
  292. retry=10,
  293. timeout=3600,
  294. thread=not final,
  295. progress_total=progress_total,
  296. stream_response=True,
  297. )
  298. else:
  299. LOGGER.warning(f"{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.")
  300. @staticmethod
  301. def _show_upload_progress(content_length: int, response: requests.Response) -> None:
  302. """
  303. Display a progress bar to track the upload progress of a file download.
  304. Args:
  305. content_length (int): The total size of the content to be downloaded in bytes.
  306. response (requests.Response): The response object from the file download request.
  307. Returns:
  308. None
  309. """
  310. with TQDM(total=content_length, unit="B", unit_scale=True, unit_divisor=1024) as pbar:
  311. for data in response.iter_content(chunk_size=1024):
  312. pbar.update(len(data))
  313. @staticmethod
  314. def _iterate_content(response: requests.Response) -> None:
  315. """
  316. Process the streamed HTTP response data.
  317. Args:
  318. response (requests.Response): The response object from the file download request.
  319. Returns:
  320. None
  321. """
  322. for _ in response.iter_content(chunk_size=1024):
  323. pass # Do nothing with data chunks