main.py 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986
  1. # YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
  2. """
  3. Train a YOLOv5 model on a custom dataset. Models and datasets download automatically from the latest YOLOv5 release.
  4. Usage - Single-GPU training:
  5. $ python train.py --data coco128.yaml --weights yolov5s.pt --img 640 # from pretrained (recommended)
  6. $ python train.py --data coco128.yaml --weights '' --cfg yolov5s.yaml --img 640 # from scratch
  7. Usage - Multi-GPU DDP training:
  8. $ python -m torch.distributed.run --nproc_per_node 4 --master_port 1 train.py --data coco128.yaml --weights yolov5s.pt --img 640 --device 0,1,2,3
  9. Models: https://github.com/ultralytics/yolov5/tree/master/models
  10. Datasets: https://github.com/ultralytics/yolov5/tree/master/data
  11. Tutorial: https://docs.ultralytics.com/yolov5/tutorials/train_custom_data
  12. """
  13. import argparse
  14. import math
  15. import os
  16. import random
  17. import subprocess
  18. import sys
  19. import time
  20. from copy import deepcopy
  21. from datetime import datetime, timedelta
  22. from pathlib import Path
  23. try:
  24. import comet_ml # must be imported before torch (if installed)
  25. except ImportError:
  26. comet_ml = None
  27. import numpy as np
  28. import torch
  29. import torch.distributed as dist
  30. import torch.nn as nn
  31. import yaml
  32. from torch.optim import lr_scheduler
  33. from tqdm import tqdm
  34. FILE = Path(__file__).resolve()
  35. ROOT = FILE.parents[0] # YOLOv5 root directory
  36. if str(ROOT) not in sys.path:
  37. sys.path.append(str(ROOT)) # add ROOT to PATH
  38. ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
  39. import val as validate # for end-of-epoch mAP
  40. import TrainSdk
  41. from models.experimental import attempt_load
  42. from models.yolo import Model
  43. from utils.autoanchor import check_anchors
  44. from utils.autobatch import check_train_batch_size
  45. from utils.callbacks import Callbacks
  46. from utils.dataloaders import create_dataloader, create_dataloader_platform
  47. from utils.downloads import attempt_download, is_url
  48. from utils.general import (
  49. LOGGER,
  50. TQDM_BAR_FORMAT,
  51. check_amp,
  52. check_dataset,
  53. check_file,
  54. check_img_size,
  55. check_suffix,
  56. check_yaml,
  57. colorstr,
  58. get_latest_run,
  59. increment_path,
  60. init_seeds,
  61. intersect_dicts,
  62. labels_to_class_weights,
  63. labels_to_image_weights,
  64. methods,
  65. one_cycle,
  66. print_args,
  67. print_mutation,
  68. strip_optimizer,
  69. yaml_save,
  70. )
  71. from utils.loggers import LOGGERS, Loggers
  72. from utils.loggers.comet.comet_utils import check_comet_resume
  73. from utils.loss import ComputeLoss, ComputeLossOTA
  74. from utils.metrics import fitness
  75. from utils.plots import plot_evolve
  76. from utils.torch_utils import (
  77. EarlyStopping,
  78. ModelEMA,
  79. de_parallel,
  80. select_device,
  81. smart_DDP,
  82. smart_optimizer,
  83. smart_resume,
  84. torch_distributed_zero_first,
  85. )
  86. from metrics.model2onnx import run as model2onnx
  87. from metrics.image_test.object_metrics import ObjectMetrics, YoloType, YoloMetas
  88. LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html
  89. RANK = int(os.getenv("RANK", -1))
  90. WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
def train(hyp, opt, device, callbacks):
    """
    Trains YOLOv5 model with given hyperparameters, options, and device, managing datasets, model architecture, loss
    computation, and optimizer steps.

    `hyp` argument is path/to/hyp.yaml or hyp dictionary.

    Args:
        hyp (str | dict): Path to a hyperparameter YAML file, or an already-loaded hyperparameter dict. The dict is
            modified in place during training: `weight_decay`, `box`, `cls` and `obj` are rescaled and
            `label_smoothing` is set.
        opt (argparse.Namespace): Training options, e.g. as produced by `parse_opt()`. `opt.hyp` is overwritten with
            a copy of the loaded hyperparameters.
        device (torch.device): Training device. CUDA-only paths (DataParallel, SyncBatchNorm, DDP) are considered
            only when `device.type != "cpu"`.
        callbacks (Callbacks): Callback registry. On rank -1/0 the logger actions are registered on it, and
            `callbacks.stop_training` is checked after every batch.

    Returns:
        tuple | None:
            Validation results in the order (P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)).
            They come from validating `best.pt` when it exists, otherwise from the last epoch's validation.
            On non-zero DDP ranks this stays the zero-initialised tuple, since those ranks never validate.
            Returns None if `callbacks.stop_training` ends training early.

    Notes:
        When `opt.is_train_on_platform` is True:
            - Dataloaders are built by `create_dataloader_platform` from the dataset dict instead of from file paths.
            - On every `opt.save_period` epoch the periodic checkpoint is also exported to ONNX and evaluated with
              `ObjectMetrics`.
    """
    save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, is_train_on_platform, use_v7_loss = (
        Path(opt.save_dir),
        opt.epochs,
        opt.batch_size,
        opt.weights,
        opt.single_cls,
        opt.evolve,
        opt.data,
        opt.cfg,
        opt.resume,
        opt.noval,
        opt.nosave,
        opt.workers,
        opt.freeze,
        opt.is_train_on_platform,
        opt.use_v7_loss
    )
    callbacks.run("on_pretrain_routine_start")

    # Directories
    w = save_dir / "weights"  # weights dir
    (w.parent if evolve else w).mkdir(parents=True, exist_ok=True)  # make dir
    last, best = w / "last.pt", w / "best.pt"

    # Hyperparameters
    if isinstance(hyp, str):
        with open(hyp, errors="ignore") as f:
            hyp = yaml.safe_load(f)  # load hyps dict
    LOGGER.info(colorstr("hyperparameters: ") + ", ".join(f"{k}={v}" for k, v in hyp.items()))
    opt.hyp = hyp.copy()  # for saving hyps to checkpoints

    # Save run settings
    if not evolve:
        yaml_save(save_dir / "hyp.yaml", hyp)
        yaml_save(save_dir / "opt.yaml", vars(opt))
        # Hand the settings files just written to the platform SDK (what TrainSdk does with them is not visible here)
        TrainSdk.save_output_model(save_dir / "hyp.yaml")
        TrainSdk.save_output_model(save_dir / "opt.yaml")

    # Loggers (main process only)
    data_dict = None
    if RANK in {-1, 0}:
        include_loggers = list(LOGGERS)
        if getattr(opt, "ndjson_console", False):
            include_loggers.append("ndjson_console")
        if getattr(opt, "ndjson_file", False):
            include_loggers.append("ndjson_file")

        loggers = Loggers(
            save_dir=save_dir,
            weights=weights,
            opt=opt,
            hyp=hyp,
            logger=LOGGER,
            include=tuple(include_loggers),
        )

        # Register actions
        for k in methods(loggers):
            callbacks.register_action(k, callback=getattr(loggers, k))

        # Process custom dataset artifact link
        data_dict = loggers.remote_dataset
        if resume:  # If resuming runs from remote artifact
            weights, epochs, hyp, batch_size = opt.weights, opt.epochs, opt.hyp, opt.batch_size

    # Config
    plots = not evolve and not opt.noplots  # create plots
    cuda = device.type != "cpu"
    init_seeds(opt.seed + 1 + RANK, deterministic=True)  # seed is offset by RANK, so each process gets its own seed
    with torch_distributed_zero_first(LOCAL_RANK):
        data_dict = data_dict or check_dataset(data, is_train_on_platform)  # check if None
    # Platform training does not need train_path and val_path (data comes from data_dict instead)
    if not is_train_on_platform:
        train_path, val_path = data_dict["train"], data_dict["val"]
    nc = 1 if single_cls else int(data_dict["nc"])  # number of classes
    names = {0: "item"} if single_cls and len(data_dict["names"]) != 1 else data_dict["names"]  # class names
    if is_train_on_platform:
        is_coco = False
    else:
        is_coco = isinstance(val_path, str) and val_path.endswith("coco/val2017.txt")  # COCO dataset

    # Model
    check_suffix(weights, ".pt")  # check weights
    pretrained = weights.endswith(".pt")
    if pretrained:
        with torch_distributed_zero_first(LOCAL_RANK):
            weights = attempt_download(weights)  # download if not found locally
        # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources
        ckpt = torch.load(weights, map_location="cpu")  # load checkpoint to CPU to avoid CUDA memory leak
        model = Model(cfg or ckpt["model"].yaml, ch=3, nc=nc, anchors=hyp.get("anchors")).to(device)  # create
        exclude = ["anchor"] if (cfg or hyp.get("anchors")) and not resume else []  # exclude keys
        csd = ckpt["model"].float().state_dict()  # checkpoint state_dict as FP32
        csd = intersect_dicts(csd, model.state_dict(), exclude=exclude)  # intersect
        model.load_state_dict(csd, strict=False)  # load
        LOGGER.info(f"Transferred {len(csd)}/{len(model.state_dict())} items from {weights}")  # report
    else:
        model = Model(cfg, ch=3, nc=nc, anchors=hyp.get("anchors")).to(device)  # create
    amp = check_amp(model)  # check AMP

    # Freeze
    # A single --freeze value N freezes layers 0..N-1; several values freeze exactly those layer indices.
    freeze = [f"model.{x}." for x in (freeze if len(freeze) > 1 else range(freeze[0]))]  # layers to freeze
    for k, v in model.named_parameters():
        v.requires_grad = True  # train all layers
        # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results)
        if any(x in k for x in freeze):
            LOGGER.info(f"freezing {k}")
            v.requires_grad = False

    # Image size
    gs = max(int(model.stride.max()), 32)  # grid size (max stride)
    imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2)  # verify imgsz is gs-multiple

    # Batch size
    if RANK == -1 and batch_size == -1:  # single-GPU only, estimate best batch size
        batch_size = check_train_batch_size(model, imgsz, amp)
        loggers.on_params_update({"batch_size": batch_size})

    # Optimizer
    nbs = 64  # nominal batch size
    accumulate = max(round(nbs / batch_size), 1)  # accumulate loss before optimizing
    hyp["weight_decay"] *= batch_size * accumulate / nbs  # scale weight_decay
    optimizer = smart_optimizer(model, opt.optimizer, hyp["lr0"], hyp["momentum"], hyp["weight_decay"])

    # Scheduler
    # NOTE: the linear lambda reads `epochs` when it is called, so a value updated by smart_resume below is used.
    if opt.cos_lr:
        lf = one_cycle(1, hyp["lrf"], epochs)  # cosine 1->hyp['lrf']
    else:
        lf = lambda x: (1 - x / epochs) * (1.0 - hyp["lrf"]) + hyp["lrf"]  # linear
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)  # plot_lr_scheduler(optimizer, scheduler, epochs)

    # EMA (main process only)
    ema = ModelEMA(model) if RANK in {-1, 0} else None

    # Resume
    best_fitness, start_epoch = 0.0, 0
    if pretrained:
        if resume:
            best_fitness, start_epoch, epochs = smart_resume(ckpt, optimizer, ema, weights, epochs, resume)
        del ckpt, csd  # release the checkpoint once its weights/optimizer state have been consumed

    # DP mode
    if cuda and RANK == -1 and torch.cuda.device_count() > 1:
        LOGGER.warning(
            "WARNING ⚠️ DP not recommended, use torch.distributed.run for best DDP Multi-GPU results.\n"
            "See Multi-GPU Tutorial at https://docs.ultralytics.com/yolov5/tutorials/multi_gpu_training to get started."
        )
        model = torch.nn.DataParallel(model)

    # SyncBatchNorm
    if opt.sync_bn and cuda and RANK != -1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
        LOGGER.info("Using SyncBatchNorm()")

    # Trainloader (per-process batch = total batch_size // WORLD_SIZE)
    if is_train_on_platform:
        train_loader, dataset = create_dataloader_platform(
            imgsz,
            batch_size // WORLD_SIZE,
            gs,
            single_cls,
            hyp=hyp,
            data_dict=data_dict,
            train_or_val_data='train',
            augment=True,
            cache=None if opt.cache == "val" else opt.cache,
            rect=opt.rect,
            rank=LOCAL_RANK,
            workers=workers,
            image_weights=opt.image_weights,
            quad=opt.quad,
            prefix=colorstr("train: "),
            shuffle=True,
            seed=opt.seed
        )
    else:
        train_loader, dataset = create_dataloader(
            train_path,
            imgsz,
            batch_size // WORLD_SIZE,
            gs,
            single_cls,
            hyp=hyp,
            augment=True,
            cache=None if opt.cache == "val" else opt.cache,
            rect=opt.rect,
            rank=LOCAL_RANK,
            workers=workers,
            image_weights=opt.image_weights,
            quad=opt.quad,
            prefix=colorstr("train: "),
            shuffle=True,
            seed=opt.seed,
        )
    # Sanity check: column 0 of each label row is the class id and must be < nc
    labels = np.concatenate(dataset.labels, 0)
    mlc = int(labels[:, 0].max())  # max label class
    assert mlc < nc, f"Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}"

    # Process 0 (val_loader exists only on rank -1/0)
    if RANK in {-1, 0}:
        if is_train_on_platform:
            val_loader = create_dataloader_platform(
                imgsz,
                batch_size // WORLD_SIZE * 2,
                gs,
                single_cls,
                hyp=hyp,
                data_dict=data_dict,
                train_or_val_data='val',
                cache=None if noval else opt.cache,
                rect=True,
                rank=-1,
                workers=workers * 2,
                pad=0.5,
                prefix=colorstr("val: "),
            )[0]
        else:
            val_loader = create_dataloader(
                val_path,
                imgsz,
                batch_size // WORLD_SIZE * 2,
                gs,
                single_cls,
                hyp=hyp,
                cache=None if noval else opt.cache,
                rect=True,
                rank=-1,
                workers=workers * 2,
                pad=0.5,
                prefix=colorstr("val: "),
            )[0]

        if not resume:
            if not opt.noautoanchor:
                check_anchors(dataset, model=model, thr=hyp["anchor_t"], imgsz=imgsz)  # run AutoAnchor
            model.half().float()  # pre-reduce anchor precision

        callbacks.run("on_pretrain_routine_end", labels, names)

    # DDP mode
    if cuda and RANK != -1:
        model = smart_DDP(model)

    # Model attributes
    nl = de_parallel(model).model[-1].nl  # number of detection layers (to scale hyps)
    hyp["box"] *= 3 / nl  # scale to layers
    hyp["cls"] *= nc / 80 * 3 / nl  # scale to classes and layers
    hyp["obj"] *= (imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers
    hyp["label_smoothing"] = opt.label_smoothing
    model.nc = nc  # attach number of classes to model
    model.hyp = hyp  # attach hyperparameters to model
    model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc  # attach class weights
    model.names = names

    # Start training
    t0 = time.time()
    nb = len(train_loader)  # number of batches
    nw = max(round(hyp["warmup_epochs"] * nb), 100)  # number of warmup iterations, max(3 epochs, 100 iterations)
    # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
    last_opt_step = -1
    maps = np.zeros(nc)  # mAP per class
    results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
    scheduler.last_epoch = start_epoch - 1  # do not move
    scaler = torch.cuda.amp.GradScaler(enabled=amp)
    stopper, stop = EarlyStopping(patience=opt.patience), False
    # The v7 (OTA) loss also receives the input images when called (see the Forward step below)
    if use_v7_loss:
        compute_loss = ComputeLossOTA(model)
    else:
        compute_loss = ComputeLoss(model)  # init loss class
    callbacks.run("on_train_start")
    LOGGER.info(
        f'Image sizes {imgsz} train, {imgsz} val\n'
        f'Using {train_loader.num_workers * WORLD_SIZE} dataloader workers\n'
        f"Logging results to {colorstr('bold', save_dir)}\n"
        f'Starting training for {epochs} epochs...'
    )
    for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
        callbacks.run("on_train_epoch_start")
        model.train()

        # Update image weights (optional, single-GPU only)
        if opt.image_weights:
            cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc  # class weights
            iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw)  # image weights
            dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n)  # rand weighted idx

        # Update mosaic border (optional)
        # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
        # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders

        mloss = torch.zeros(3, device=device)  # mean losses
        if RANK != -1:
            train_loader.sampler.set_epoch(epoch)  # reshuffle the distributed sampler differently each epoch
        pbar = enumerate(train_loader)
        LOGGER.info(("\n" + "%11s" * 7) % ("Epoch", "GPU_mem", "box_loss", "obj_loss", "cls_loss", "Instances", "Size"))
        if RANK in {-1, 0}:
            pbar = tqdm(pbar, total=nb, bar_format=TQDM_BAR_FORMAT)  # progress bar
        optimizer.zero_grad()
        for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
            callbacks.run("on_train_batch_start")
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = imgs.to(device, non_blocking=True).float() / 255  # uint8 to float32, 0-255 to 0.0-1.0

            # Warmup: for the first nw integrated batches, interpolate accumulate, lr and momentum linearly
            if ni <= nw:
                xi = [0, nw]  # x interp
                # compute_loss.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
                accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
                for j, x in enumerate(optimizer.param_groups):
                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x["lr"] = np.interp(ni, xi, [hyp["warmup_bias_lr"] if j == 0 else 0.0, x["initial_lr"] * lf(epoch)])
                    if "momentum" in x:
                        x["momentum"] = np.interp(ni, xi, [hyp["warmup_momentum"], hyp["momentum"]])

            # Multi-scale
            if opt.multi_scale:
                sz = random.randrange(int(imgsz * 0.5), int(imgsz * 1.5) + gs) // gs * gs  # size
                sf = sz / max(imgs.shape[2:])  # scale factor
                if sf != 1:
                    ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to gs-multiple)
                    imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)

            # Forward
            with torch.cuda.amp.autocast(amp):
                pred = model(imgs)  # forward
                if use_v7_loss:
                    loss, loss_items = compute_loss(pred, targets.to(device), imgs)
                else:
                    loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
                if RANK != -1:
                    loss *= WORLD_SIZE  # gradient averaged between devices in DDP mode
                if opt.quad:
                    loss *= 4.0

            # Backward
            scaler.scale(loss).backward()

            # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
            # Step only once `accumulate` batches of gradients have been summed since the last step
            if ni - last_opt_step >= accumulate:
                scaler.unscale_(optimizer)  # unscale gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)  # clip gradients
                scaler.step(optimizer)  # optimizer.step
                scaler.update()
                optimizer.zero_grad()
                if ema:
                    ema.update(model)
                last_opt_step = ni

            # Log
            if RANK in {-1, 0}:
                mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
                mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G"  # (GB)
                pbar.set_description(
                    ("%11s" * 2 + "%11.4g" * 5)
                    % (f"{epoch}/{epochs - 1}", mem, *mloss, targets.shape[0], imgs.shape[-1])
                )
                callbacks.run("on_train_batch_end", model, ni, imgs, targets, paths, list(mloss))
                if callbacks.stop_training:
                    return  # early abort requested by a callback; train() returns None
            # end batch ------------------------------------------------------------------------------------------------

        # Scheduler
        lr = [x["lr"] for x in optimizer.param_groups]  # for loggers
        scheduler.step()

        if RANK in {-1, 0}:
            # mAP
            callbacks.run("on_train_epoch_end", epoch=epoch)
            ema.update_attr(model, include=["yaml", "nc", "hyp", "names", "stride", "class_weights"])
            final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
            if not noval or final_epoch:  # Calculate mAP
                results, maps, _ = validate.run(
                    data_dict,
                    batch_size=batch_size // WORLD_SIZE * 2,
                    imgsz=imgsz,
                    half=amp,
                    model=ema.ema,
                    single_cls=single_cls,
                    dataloader=val_loader,
                    save_dir=save_dir,
                    plots=False,
                    callbacks=callbacks,
                    compute_loss=compute_loss,
                    use_v7_loss=use_v7_loss,
                    is_train_on_platform=is_train_on_platform,
                )

            # Update best mAP
            fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
            stop = stopper(epoch=epoch, fitness=fi)  # early stop check
            if fi > best_fitness:
                best_fitness = fi
            log_vals = list(mloss) + list(results) + lr
            callbacks.run("on_fit_epoch_end", log_vals, epoch, best_fitness, fi)

            # Save model
            if (not nosave) or (final_epoch and not evolve):  # if save
                ckpt = {
                    "epoch": epoch,
                    "best_fitness": best_fitness,
                    "model": deepcopy(de_parallel(model)).half(),
                    "ema": deepcopy(ema.ema).half(),
                    "updates": ema.updates,
                    "optimizer": optimizer.state_dict(),
                    "opt": vars(opt),
                    "date": datetime.now().isoformat(),
                }

                # Save last, best and delete
                torch.save(ckpt, last)
                TrainSdk.save_output_model(last)
                if best_fitness == fi:
                    torch.save(ckpt, best)
                    TrainSdk.save_output_model(best)
                if opt.save_period > 0 and epoch % opt.save_period == 0:
                    torch.save(ckpt, w / f"epoch{epoch}.pt")
                    TrainSdk.save_output_model(w / f"epoch{epoch}.pt")
                    # Additionally compute evaluation metrics for the model converted to ONNX.
                    # This runs only on save_period epochs because it exports the epoch{epoch}.pt checkpoint saved just above.
                    if is_train_on_platform:
                        model2onnx(
                            weights=w / f"epoch{epoch}.pt",
                            imgsz=[opt.imgsz, opt.imgsz],
                            batch_size=1,
                            device="cpu",
                            inplace=True,
                            dynamic=False,
                            simplify=False,
                            opset=17
                        )
                        # assumes model2onnx writes the .onnx next to the .pt with the same stem — TODO confirm
                        onnx_model = w / f"epoch{epoch}.onnx"
                        TrainSdk.save_output_model(onnx_model)
                        # Platform-specific evaluation settings supplied in the dataset dict
                        platform_data_args = data_dict["platform_data_args"]
                        class_id_map_list = platform_data_args["class_id_map_list"]
                        dll_file = platform_data_args["dll_file"]
                        wrong_file = platform_data_args["wrong_file"]
                        yolo_metas_yaml = platform_data_args['yolo_metas']
                        metrics_type = platform_data_args['metrics_type']
                        extra_contours_args = platform_data_args['extra_contours_args']
                        needed_image_results_dict = val_loader.dataset.needed_image_results_dict
                        needed_rois_dict = val_loader.dataset.needed_rois_dict
                        yolo_metas = YoloMetas(
                            yolotype=YoloType[yolo_metas_yaml['yolotype']],
                            confthres=yolo_metas_yaml['confthres'],
                            clsconfthres=yolo_metas_yaml['clsconfthres'],
                            batchsize=yolo_metas_yaml['batchsize'],
                            maxdet=yolo_metas_yaml['maxdet'],
                            minboxratio=yolo_metas_yaml['minboxratio'],
                            # Used in ApplyPostProcessToBBox for box filtering within a single class
                            postprocesstopk=yolo_metas_yaml['postprocesstopk'],
                            enableioufilt=yolo_metas_yaml['enableioufilt'],
                            enableiosfilt=yolo_metas_yaml['enableiosfilt'],
                            ioufltth=yolo_metas_yaml['ioufltth'],
                            iosfltth=yolo_metas_yaml['iosfltth'],
                            # Used in FindBoxesToUnion: merge two boxes when their overlap exceeds a threshold
                            # and the merged box adds little extra area
                            enableunion=yolo_metas_yaml['enableunion'],
                            unioniouth=yolo_metas_yaml['unioniouth'],
                            unioniosth=yolo_metas_yaml['unioniosth'],
                            unionuobth=yolo_metas_yaml['unionuobth'],
                            # Used in ApplyBoxClassFilter: when one image has several boxes, filter boxes across
                            # different classes
                            enableioufiltdiffcls=yolo_metas_yaml['enableioufiltdiffcls'],
                            enableiosfiltdiffcls=yolo_metas_yaml['enableiosfiltdiffcls'],
                            ioufltthdiffcls=yolo_metas_yaml['ioufltthdiffcls'],
                            iosfltthdiffcls=yolo_metas_yaml['iosfltthdiffcls'], )
                        # One metrics run per class-id mapping.
                        # NOTE(review): the return value of run() is not used here; presumably it reports/writes on its own.
                        for class_id_map in class_id_map_list:
                            metric_reports = ObjectMetrics(is_local_file=False,
                                                           files=val_loader.dataset.im_files,
                                                           token=val_loader.dataset.token,
                                                           onnx_file=onnx_model,
                                                           needed_image_results_dict=needed_image_results_dict,
                                                           needed_rois_dict=needed_rois_dict,
                                                           extra_contours_args=extra_contours_args,
                                                           class_id_map=class_id_map,
                                                           wrong_file=wrong_file,
                                                           dll_file=dll_file,
                                                           yolo_metas=yolo_metas,
                                                           metrics_type=metrics_type)
                            metric_reports.run()
                del ckpt
                callbacks.run("on_model_save", last, epoch, final_epoch, best_fitness, fi)

        # EarlyStopping
        if RANK != -1:  # if DDP training
            # Only rank 0 evaluates the stopper, so share its decision with every rank
            broadcast_list = [stop if RANK == 0 else None]
            dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
            if RANK != 0:
                stop = broadcast_list[0]
        if stop:
            break  # must break all DDP ranks

        # end epoch ----------------------------------------------------------------------------------------------------
    # end training -----------------------------------------------------------------------------------------------------
    if RANK in {-1, 0}:
        LOGGER.info(f"\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.")
        for f in last, best:
            if f.exists():
                strip_optimizer(f)  # strip optimizers
                if f is best:
                    LOGGER.info(f"\nValidating {f}...")
                    results, _, _ = validate.run(
                        data_dict,
                        batch_size=batch_size // WORLD_SIZE * 2,
                        imgsz=imgsz,
                        model=attempt_load(f, device).half(),
                        iou_thres=0.65 if is_coco else 0.60,  # best pycocotools at iou 0.65
                        single_cls=single_cls,
                        dataloader=val_loader,
                        save_dir=save_dir,
                        save_json=is_coco,
                        verbose=True,
                        plots=plots,
                        callbacks=callbacks,
                        compute_loss=compute_loss,
                        use_v7_loss=use_v7_loss,
                        is_train_on_platform=is_train_on_platform,
                    )  # val best model with plots
                    if is_coco:
                        callbacks.run("on_fit_epoch_end", list(mloss) + list(results) + lr, epoch, best_fitness, fi)

        callbacks.run("on_train_end", last, best, epoch, results)

    torch.cuda.empty_cache()
    return results
  570. callbacks.run("on_fit_epoch_end", list(mloss) + list(results) + lr, epoch, best_fitness, fi)
  571. callbacks.run("on_train_end", last, best, epoch, results)
  572. torch.cuda.empty_cache()
  573. return results
  574. def parse_opt(known=False):
  575. """Parses command-line arguments for YOLOv5 training, validation, and testing."""
  576. parser = argparse.ArgumentParser()
  577. parser.add_argument("--weights", type=str, default="", help="initial weights path")
  578. parser.add_argument("--cfg", type=str, default=ROOT / "models/hub/yolov5s-ghost.yaml", help="model.yaml path")
  579. parser.add_argument("--data", type=str, default=ROOT / "data/vinno_data/Thyroid-TIRADS-Obj.yaml", help="dataset.yaml path")
  580. parser.add_argument("--hyp", type=str, default=ROOT / "data/hyps/hyp.scratch_breast-birads.yaml",
  581. help="hyperparameters path")
  582. parser.add_argument("--epochs", type=int, default=80, help="total training epochs")
  583. parser.add_argument("--batch-size", type=int, default=1, help="total batch size for all GPUs, -1 for autobatch")
  584. parser.add_argument("--imgsz", "--img", "--img-size", type=int, default=320, help="train, val image size (pixels)")
  585. parser.add_argument("--rect", action="store_true", help="rectangular training")
  586. parser.add_argument("--resume", nargs="?", const=True, default=False, help="resume most recent training")
  587. parser.add_argument("--nosave", action="store_true", help="only save final checkpoint")
  588. parser.add_argument("--noval", action="store_true", help="only validate final epoch")
  589. parser.add_argument("--noautoanchor", action="store_true", help="disable AutoAnchor")
  590. parser.add_argument("--noplots", action="store_true", help="save no plot files")
  591. parser.add_argument("--evolve", type=int, nargs="?", const=300, help="evolve hyperparameters for x generations")
  592. parser.add_argument(
  593. "--evolve_population", type=str, default=ROOT / "data/hyps", help="location for loading population"
  594. )
  595. parser.add_argument("--resume_evolve", type=str, default=None, help="resume evolve from last generation")
  596. parser.add_argument("--bucket", type=str, default="", help="gsutil bucket")
  597. parser.add_argument("--cache", type=str, nargs="?", const="ram", help="image --cache ram/disk")
  598. parser.add_argument("--image-weights", action="store_true", default=False,
  599. help="use weighted image selection for training")
  600. parser.add_argument("--device", default="", help="cuda device, i.e. 0 or 0,1,2,3 or cpu")
  601. parser.add_argument("--multi-scale", action="store_true", help="vary img-size +/- 50%%")
  602. parser.add_argument("--single-cls", action="store_true", help="train multi-class data as single-class")
  603. parser.add_argument("--optimizer", type=str, choices=["SGD", "Adam", "AdamW"], default="SGD", help="optimizer")
  604. parser.add_argument("--sync-bn", action="store_true", help="use SyncBatchNorm, only available in DDP mode")
  605. parser.add_argument("--workers", type=int, default=1, help="max dataloader workers (per RANK in DDP mode)")
  606. parser.add_argument("--project", default=ROOT / "runs/train", help="save to project/name")
  607. parser.add_argument("--name", default="exp", help="save to project/name")
  608. parser.add_argument("--exist-ok", action="store_true", help="existing project/name ok, do not increment")
  609. parser.add_argument("--quad", action="store_true", help="quad dataloader")
  610. parser.add_argument("--cos-lr", action="store_true", help="cosine LR scheduler")
  611. parser.add_argument("--label-smoothing", type=float, default=0.0, help="Label smoothing epsilon")
  612. parser.add_argument("--patience", type=int, default=100, help="EarlyStopping patience (epochs without improvement)")
  613. parser.add_argument("--freeze", nargs="+", type=int, default=[0], help="Freeze layers: backbone=10, first3=0 1 2")
  614. parser.add_argument("--save-period", type=int, default=10, help="Save checkpoint every x epochs (disabled if < 1)")
  615. parser.add_argument("--seed", type=int, default=0, help="Global training seed")
  616. parser.add_argument("--local_rank", type=int, default=-1, help="Automatic DDP Multi-GPU argument, do not modify")
  617. # Logger arguments
  618. parser.add_argument("--entity", default=None, help="Entity")
  619. parser.add_argument("--upload_dataset", nargs="?", const=True, default=False, help='Upload data, "val" option')
  620. parser.add_argument("--bbox_interval", type=int, default=-1, help="Set bounding-box image logging interval")
  621. parser.add_argument("--artifact_alias", type=str, default="latest", help="Version of dataset artifact to use")
  622. # NDJSON logging
  623. parser.add_argument("--ndjson-console", action="store_true", help="Log ndjson to console")
  624. parser.add_argument("--ndjson-file", action="store_true", help="Log ndjson to file")
  625. parser.add_argument('--use_v7_loss', type=bool, default=True, help='True is use v7_loss,False is use v5_loss')
  626. # VINNO AI平台训练
  627. parser.add_argument('--is_train_on_platform', type=bool, default=True,
  628. help='True is train on platform,False is train on local')
  629. return parser.parse_known_args()[0] if known else parser.parse_args()
  630. def main(opt, callbacks=Callbacks()):
  631. """Runs training or hyperparameter evolution with specified options and optional callbacks."""
  632. if RANK in {-1, 0}:
  633. print_args(vars(opt))
  634. # Resume (from specified or most recent last.pt)
  635. if opt.resume and not check_comet_resume(opt) and not opt.evolve:
  636. last = Path(check_file(opt.resume) if isinstance(opt.resume, str) else get_latest_run())
  637. opt_yaml = last.parent.parent / "opt.yaml" # train options yaml
  638. opt_data = opt.data # original dataset
  639. if opt_yaml.is_file():
  640. with open(opt_yaml, errors="ignore") as f:
  641. d = yaml.safe_load(f)
  642. else:
  643. d = torch.load(last, map_location="cpu")["opt"]
  644. opt = argparse.Namespace(**d) # replace
  645. opt.cfg, opt.weights, opt.resume = "", str(last), True # reinstate
  646. if is_url(opt_data):
  647. opt.data = check_file(opt_data) # avoid HUB resume auth timeout
  648. else:
  649. opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = (
  650. check_file(opt.data),
  651. check_yaml(opt.cfg),
  652. check_yaml(opt.hyp),
  653. str(opt.weights),
  654. str(opt.project),
  655. ) # checks
  656. assert len(opt.cfg) or len(opt.weights), "either --cfg or --weights must be specified"
  657. if opt.evolve:
  658. if opt.project == str(ROOT / "runs/train"): # if default project name, rename to runs/evolve
  659. opt.project = str(ROOT / "runs/evolve")
  660. opt.exist_ok, opt.resume = opt.resume, False # pass resume to exist_ok and disable resume
  661. if opt.name == "cfg":
  662. opt.name = Path(opt.cfg).stem # use model.yaml as name
  663. opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))
  664. # DDP mode
  665. device = select_device(opt.device, batch_size=opt.batch_size)
  666. if LOCAL_RANK != -1:
  667. msg = "is not compatible with YOLOv5 Multi-GPU DDP training"
  668. assert not opt.image_weights, f"--image-weights {msg}"
  669. assert not opt.evolve, f"--evolve {msg}"
  670. assert opt.batch_size != -1, f"AutoBatch with --batch-size -1 {msg}, please pass a valid --batch-size"
  671. assert opt.batch_size % WORLD_SIZE == 0, f"--batch-size {opt.batch_size} must be multiple of WORLD_SIZE"
  672. assert torch.cuda.device_count() > LOCAL_RANK, "insufficient CUDA devices for DDP command"
  673. torch.cuda.set_device(LOCAL_RANK)
  674. device = torch.device("cuda", LOCAL_RANK)
  675. dist.init_process_group(
  676. backend="nccl" if dist.is_nccl_available() else "gloo", timeout=timedelta(seconds=10800)
  677. )
  678. # Train
  679. if not opt.evolve:
  680. train(opt.hyp, opt, device, callbacks)
  681. # Evolve hyperparameters (optional)
  682. else:
  683. # Hyperparameter evolution metadata (including this hyperparameter True-False, lower_limit, upper_limit)
  684. meta = {
  685. "lr0": (False, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
  686. "lrf": (False, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
  687. "momentum": (False, 0.6, 0.98), # SGD momentum/Adam beta1
  688. "weight_decay": (False, 0.0, 0.001), # optimizer weight decay
  689. "warmup_epochs": (False, 0.0, 5.0), # warmup epochs (fractions ok)
  690. "warmup_momentum": (False, 0.0, 0.95), # warmup initial momentum
  691. "warmup_bias_lr": (False, 0.0, 0.2), # warmup initial bias lr
  692. "box": (False, 0.02, 0.2), # box loss gain
  693. "cls": (False, 0.2, 4.0), # cls loss gain
  694. "cls_pw": (False, 0.5, 2.0), # cls BCELoss positive_weight
  695. "obj": (False, 0.2, 4.0), # obj loss gain (scale with pixels)
  696. "obj_pw": (False, 0.5, 2.0), # obj BCELoss positive_weight
  697. "iou_t": (False, 0.1, 0.7), # IoU training threshold
  698. "anchor_t": (False, 2.0, 8.0), # anchor-multiple threshold
  699. "anchors": (False, 2.0, 10.0), # anchors per output grid (0 to ignore)
  700. "fl_gamma": (False, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
  701. "hsv_h": (True, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
  702. "hsv_s": (True, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
  703. "hsv_v": (True, 0.0, 0.9), # image HSV-Value augmentation (fraction)
  704. "degrees": (True, 0.0, 45.0), # image rotation (+/- deg)
  705. "translate": (True, 0.0, 0.9), # image translation (+/- fraction)
  706. "scale": (True, 0.0, 0.9), # image scale (+/- gain)
  707. "shear": (True, 0.0, 10.0), # image shear (+/- deg)
  708. "perspective": (True, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
  709. "flipud": (True, 0.0, 1.0), # image flip up-down (probability)
  710. "fliplr": (True, 0.0, 1.0), # image flip left-right (probability)
  711. "mosaic": (True, 0.0, 1.0), # image mixup (probability)
  712. "mixup": (True, 0.0, 1.0), # image mixup (probability)
  713. "copy_paste": (True, 0.0, 1.0),
  714. } # segment copy-paste (probability)
  715. # GA configs
  716. pop_size = 50
  717. mutation_rate_min = 0.01
  718. mutation_rate_max = 0.5
  719. crossover_rate_min = 0.5
  720. crossover_rate_max = 1
  721. min_elite_size = 2
  722. max_elite_size = 5
  723. tournament_size_min = 2
  724. tournament_size_max = 10
  725. with open(opt.hyp, errors="ignore") as f:
  726. hyp = yaml.safe_load(f) # load hyps dict
  727. if "anchors" not in hyp: # anchors commented in hyp.yaml
  728. hyp["anchors"] = 3
  729. if opt.noautoanchor:
  730. del hyp["anchors"], meta["anchors"]
  731. opt.noval, opt.nosave, save_dir = True, True, Path(opt.save_dir) # only val/save final epoch
  732. # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
  733. evolve_yaml, evolve_csv = save_dir / "hyp_evolve.yaml", save_dir / "evolve.csv"
  734. if opt.bucket:
  735. # download evolve.csv if exists
  736. subprocess.run(
  737. [
  738. "gsutil",
  739. "cp",
  740. f"gs://{opt.bucket}/evolve.csv",
  741. str(evolve_csv),
  742. ]
  743. )
  744. # Delete the items in meta dictionary whose first value is False
  745. del_ = [item for item, value_ in meta.items() if value_[0] is False]
  746. hyp_GA = hyp.copy() # Make a copy of hyp dictionary
  747. for item in del_:
  748. del meta[item] # Remove the item from meta dictionary
  749. del hyp_GA[item] # Remove the item from hyp_GA dictionary
  750. # Set lower_limit and upper_limit arrays to hold the search space boundaries
  751. lower_limit = np.array([meta[k][1] for k in hyp_GA.keys()])
  752. upper_limit = np.array([meta[k][2] for k in hyp_GA.keys()])
  753. # Create gene_ranges list to hold the range of values for each gene in the population
  754. gene_ranges = [(lower_limit[i], upper_limit[i]) for i in range(len(upper_limit))]
  755. # Initialize the population with initial_values or random values
  756. initial_values = []
  757. # If resuming evolution from a previous checkpoint
  758. if opt.resume_evolve is not None:
  759. assert os.path.isfile(ROOT / opt.resume_evolve), "evolve population path is wrong!"
  760. with open(ROOT / opt.resume_evolve, errors="ignore") as f:
  761. evolve_population = yaml.safe_load(f)
  762. for value in evolve_population.values():
  763. value = np.array([value[k] for k in hyp_GA.keys()])
  764. initial_values.append(list(value))
  765. # If not resuming from a previous checkpoint, generate initial values from .yaml files in opt.evolve_population
  766. else:
  767. yaml_files = [f for f in os.listdir(opt.evolve_population) if f.endswith(".yaml")]
  768. for file_name in yaml_files:
  769. with open(os.path.join(opt.evolve_population, file_name)) as yaml_file:
  770. value = yaml.safe_load(yaml_file)
  771. value = np.array([value[k] for k in hyp_GA.keys()])
  772. initial_values.append(list(value))
  773. # Generate random values within the search space for the rest of the population
  774. if initial_values is None:
  775. population = [generate_individual(gene_ranges, len(hyp_GA)) for _ in range(pop_size)]
  776. elif pop_size > 1:
  777. population = [generate_individual(gene_ranges, len(hyp_GA)) for _ in range(pop_size - len(initial_values))]
  778. for initial_value in initial_values:
  779. population = [initial_value] + population
  780. # Run the genetic algorithm for a fixed number of generations
  781. list_keys = list(hyp_GA.keys())
  782. for generation in range(opt.evolve):
  783. if generation >= 1:
  784. save_dict = {}
  785. for i in range(len(population)):
  786. little_dict = {list_keys[j]: float(population[i][j]) for j in range(len(population[i]))}
  787. save_dict[f"gen{str(generation)}number{str(i)}"] = little_dict
  788. with open(save_dir / "evolve_population.yaml", "w") as outfile:
  789. yaml.dump(save_dict, outfile, default_flow_style=False)
  790. # Adaptive elite size
  791. elite_size = min_elite_size + int((max_elite_size - min_elite_size) * (generation / opt.evolve))
  792. # Evaluate the fitness of each individual in the population
  793. fitness_scores = []
  794. for individual in population:
  795. for key, value in zip(hyp_GA.keys(), individual):
  796. hyp_GA[key] = value
  797. hyp.update(hyp_GA)
  798. results = train(hyp.copy(), opt, device, callbacks)
  799. callbacks = Callbacks()
  800. # Write mutation results
  801. keys = (
  802. "metrics/precision",
  803. "metrics/recall",
  804. "metrics/mAP_0.5",
  805. "metrics/mAP_0.5:0.95",
  806. "val/box_loss",
  807. "val/obj_loss",
  808. "val/cls_loss",
  809. )
  810. print_mutation(keys, results, hyp.copy(), save_dir, opt.bucket)
  811. fitness_scores.append(results[2])
  812. # Select the fittest individuals for reproduction using adaptive tournament selection
  813. selected_indices = []
  814. for _ in range(pop_size - elite_size):
  815. # Adaptive tournament size
  816. tournament_size = max(
  817. max(2, tournament_size_min),
  818. int(min(tournament_size_max, pop_size) - (generation / (opt.evolve / 10))),
  819. )
  820. # Perform tournament selection to choose the best individual
  821. tournament_indices = random.sample(range(pop_size), tournament_size)
  822. tournament_fitness = [fitness_scores[j] for j in tournament_indices]
  823. winner_index = tournament_indices[tournament_fitness.index(max(tournament_fitness))]
  824. selected_indices.append(winner_index)
  825. # Add the elite individuals to the selected indices
  826. elite_indices = [i for i in range(pop_size) if fitness_scores[i] in sorted(fitness_scores)[-elite_size:]]
  827. selected_indices.extend(elite_indices)
  828. # Create the next generation through crossover and mutation
  829. next_generation = []
  830. for _ in range(pop_size):
  831. parent1_index = selected_indices[random.randint(0, pop_size - 1)]
  832. parent2_index = selected_indices[random.randint(0, pop_size - 1)]
  833. # Adaptive crossover rate
  834. crossover_rate = max(
  835. crossover_rate_min, min(crossover_rate_max, crossover_rate_max - (generation / opt.evolve))
  836. )
  837. if random.uniform(0, 1) < crossover_rate:
  838. crossover_point = random.randint(1, len(hyp_GA) - 1)
  839. child = population[parent1_index][:crossover_point] + population[parent2_index][crossover_point:]
  840. else:
  841. child = population[parent1_index]
  842. # Adaptive mutation rate
  843. mutation_rate = max(
  844. mutation_rate_min, min(mutation_rate_max, mutation_rate_max - (generation / opt.evolve))
  845. )
  846. for j in range(len(hyp_GA)):
  847. if random.uniform(0, 1) < mutation_rate:
  848. child[j] += random.uniform(-0.1, 0.1)
  849. child[j] = min(max(child[j], gene_ranges[j][0]), gene_ranges[j][1])
  850. next_generation.append(child)
  851. # Replace the old population with the new generation
  852. population = next_generation
  853. # Print the best solution found
  854. best_index = fitness_scores.index(max(fitness_scores))
  855. best_individual = population[best_index]
  856. print("Best solution found:", best_individual)
  857. # Plot results
  858. plot_evolve(evolve_csv)
  859. LOGGER.info(
  860. f'Hyperparameter evolution finished {opt.evolve} generations\n'
  861. f"Results saved to {colorstr('bold', save_dir)}\n"
  862. f'Usage example: $ python train.py --hyp {evolve_yaml}'
  863. )
  864. def generate_individual(input_ranges, individual_length):
  865. """Generates a list of random values within specified input ranges for each gene in the individual."""
  866. individual = []
  867. for i in range(individual_length):
  868. lower_bound, upper_bound = input_ranges[i]
  869. individual.append(random.uniform(lower_bound, upper_bound))
  870. return individual
  871. def run(**kwargs):
  872. """
  873. Executes YOLOv5 training with given options, overriding with any kwargs provided.
  874. Example: import train; train.run(data='coco128.yaml', imgsz=320, weights='yolov5m.pt')
  875. """
  876. opt = parse_opt(True)
  877. for k, v in kwargs.items():
  878. setattr(opt, k, v)
  879. main(opt)
  880. return opt
  881. if __name__ == "__main__":
  882. opt = parse_opt()
  883. main(opt)