# -*- coding: utf-8 -*-
# @ File : main.py
# @ Author : Guido LuXiaohao
# @ Date : 2021/8/19
# @ Software : PyCharm
# @ Description: Training entry point for semantic segmentation (single-GPU and DDP).
from copy import deepcopy
import os
import random
import sys
import time
import warnings
from argparse import ArgumentParser
from pathlib import Path

import numpy as np
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, distributed
from tqdm import tqdm

from dataset import TrainSDK, custom_collate_fn
from dataset.data_process import DataProcess
from dataset.utils import parse_metadata_info
from utils.config import Config
from utils.earlyStopping import EarlyStopping
from utils.plot_log import draw_log
from utils.record_log import record_log
from utils.scheduler import WarmupCosineLR
from utils.utils import (ModelEMA, de_parallel, get_params, intersect_dicts,
                         select_device, start_resume, convert_nested_tensors_to_device)
from val import evaluate

warnings.filterwarnings('ignore')
sys.setrecursionlimit(1000000)  # avoid 'maximum recursion depth exceeded' errors

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # project root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative path

LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))
RANK = int(os.getenv('RANK', -1))
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
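
# LOCAL_RANK / RANK / WORLD_SIZE are the standard environment variables set by the
# torch.distributed launcher (e.g. torchrun). With the defaults above (-1, -1, 1) the
# script runs as a single process without DDP; under a launcher each process sees its
# own LOCAL_RANK (GPU index on the node) and global RANK.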


def train(cfg,
          model,
          train_dataset,
          val_dataset=None,
          optimizer=None,
          save=False,
          epochs=100,
          batch_size=2,
          weights=None,
          freeze=None,
          val_interval=5,
          num_workers=0,
          amp=False,
          device=torch.device('cuda')):
    nd = torch.cuda.device_count()  # number of CUDA devices
    workers = min([os.cpu_count() // max(nd, 1),
                   batch_size if batch_size > 1 else 0,
                   num_workers])  # number of dataloader workers
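    # The worker count is capped by CPU cores per CUDA device, the batch size, and the
    # user-requested num_workers, so small batches never spawn more loader processes than
    # they can feed (with batch_size == 1 this falls back to 0 workers, i.e. loading in
    # the main process).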
    cuda = device.type == 'cuda'
    last_model = cfg['save_dir'] + '/last.pth'
    best_model = cfg['save_dir'] + '/best.pth'

    # Model
    model = model.to(device)
    pretrained = weights.endswith('.pth')
    if pretrained:
        ckpt = torch.load(weights, map_location='cpu')
        csd = ckpt.get("ema") or ckpt['model']
        csd = intersect_dicts(csd, model.state_dict())
        model.load_state_dict(csd, strict=False)  # load

    # Freeze
    freeze = [f'backbone.features.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))]
    for name, param in model.named_parameters():
        param.requires_grad = True  # train all layers
        if any(x in name for x in freeze):
            print(f'freezing {name}')
            param.requires_grad = False
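    # Freeze semantics: a single-element list such as [n] freezes backbone.features.0 .. n-1
    # (so the default [0] freezes nothing), while a multi-element list freezes exactly the
    # listed backbone.features indices, e.g. [3, 5, 7].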

    # EMA
    ema = ModelEMA(model) if RANK in {-1, 0} else None
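    # ModelEMA keeps an exponential moving average of the model weights on the main process;
    # the averaged weights (ema.ema) are what get validated and stored in the checkpoint below.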

    # Log configs
    if RANK in {-1, 0}:
        recorder.record_args(str(get_params(model) / 1e6) + ' M')
        cfg.dump(save_dir=cfg['save_dir'])
        TrainSDK.save_output_model(os.path.join(cfg['save_dir'], 'config.yaml'))

    # Resume
    best_fitness, start_epoch = 0.0, 1
    epoch_list, loss_tr_list, loss_val_list, miou_list = [], [], [], []
    if pretrained:
        if args.resume:  # args and recorder are module-level globals set in __main__
            logger, lines = recorder.resume_logfile()
            for index, line in enumerate(lines):
                loss_tr_list.append(float(line.strip().split()[2]))
                if len(line.strip().split()) != 3:
                    epoch_list.append(int(line.strip().split()[0]))
                    loss_val_list.append(float(line.strip().split()[3]))
                    miou_list.append(float(line.strip().split()[5]))
            best_fitness, start_epoch, epochs = start_resume(ckpt, optimizer, ema, weights, epochs, args.resume)
        else:
            logger = recorder.initial_logfile()
            logger.flush()
        del ckpt, csd
    else:
        logger = recorder.initial_logfile()
        logger.flush()

    if cuda and RANK != -1:
        # Sync batch norm
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
        # DDP
        model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)

    train_sampler = None if RANK == -1 else \
        distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size // WORLD_SIZE,
        shuffle=True if train_sampler is None else False,
        sampler=train_sampler,
        num_workers=workers,
        collate_fn=custom_collate_fn,
        pin_memory=True,
        drop_last=True)
    if RANK in {-1, 0}:
        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size // WORLD_SIZE,
            num_workers=workers,
            collate_fn=custom_collate_fn,
            pin_memory=True,
            drop_last=True)
    else:
        val_loader = None
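
    # --batch_size is the total batch across all processes: each rank gets batch_size // WORLD_SIZE
    # samples per step, and the DistributedSampler gives every rank a disjoint shard of the
    # training set (re-shuffled each epoch via sampler.set_epoch below).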

    # Scheduler
    nb = len(train_loader)  # number of batches
    nw = max(round(cfg['lr_scheduler']['warmup_epochs'] * nb), 100)  # number of warmup iterations
    scheduler = WarmupCosineLR(
        optimizer, 1e-9, cfg.learning_rate, nw, cfg.epochs * nb, 0.1)
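    # WarmupCosineLR is a project-local scheduler; the arguments are read here as warmup-start LR
    # 1e-9, peak LR cfg.learning_rate, nw warmup iterations, a total horizon of cfg.epochs * nb
    # iterations, and a final decay factor of 0.1 -- an assumption based on this call site only.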

    print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n"
          ">>>>>>>>>>> beginning training >>>>>>>>>>>\n"
          ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")

    # Start training
    t0 = time.time()
    scheduler.last_epoch = start_epoch - 2
    early_stopping, stop = EarlyStopping(patience=int(epochs / 5)), False
    scaler = torch.cuda.amp.GradScaler(enabled=amp)
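    # Mixed precision: forward passes run under torch.cuda.amp.autocast, and GradScaler scales the
    # loss before backward so small FP16 gradients do not underflow; scaler.step() unscales the
    # gradients and skips the optimizer step if non-finite values are found. With --amp off the
    # scaler is a pass-through and training runs in FP32.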
    for epoch in range(start_epoch, epochs + 1):  # epoch ------------------------------------------------------------
        model.train()
        epoch_loss = []
        if RANK != -1:
            train_loader.sampler.set_epoch(epoch)
        pbar = enumerate(train_loader)
        if RANK in {-1, 0}:
            pbar = tqdm(pbar, total=len(train_loader), desc='Epoch {}/{}'.format(epoch, epochs))
        optimizer.zero_grad()
        for step, (images, data_samples) in pbar:  # batch -----------------------------------------------------------
            images = images.to(device, non_blocking=True)
            data_samples = convert_nested_tensors_to_device(data_samples, device, non_blocking=True)

            # Forward
            with torch.cuda.amp.autocast(amp):
                losses = model(images, data_samples, mode='loss')
                loss = sum(losses.values())
                if RANK != -1:  # gradient averaged between devices in DDP mode
                    loss *= WORLD_SIZE

            # Backward
            scaler.scale(loss).backward()

            # Optimize
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            if ema:
                ema.update(model)

            # Scheduler
            scheduler.step()
            epoch_loss.append(loss.item())
            # end batch ----------------------------------------------------------------------------------------------

        loss_tr = sum(epoch_loss) / len(epoch_loss)
        loss_tr_list.append(loss_tr)
        lr = optimizer.param_groups[0]['lr']
        recorder.record_train_log(logger, epoch, lr, loss_tr)

        # Validation
        if RANK in {-1, 0}:
            if (epoch % val_interval == 0 or epoch == epochs) and (val_loader is not None):
                metrics = evaluate(
                    cfg,
                    model=ema.ema,
                    eval_loader=val_loader,
                    scales=[1.0],
                    overlap=0.0,
                    save=False)
                epoch_list.append(epoch)
                loss_val_list.append(loss.item())
                miou = np.mean([v for m, v in metrics.items() if m.endswith('.mIoU')])
                miou_list.append(miou)
                map = np.mean([v for m, v in metrics.items() if m.endswith('.mAP')])
                # manually configured metric weights
                w = [0.9, 0.1]
                fitness = np.sum(np.multiply(w, [miou, map]))
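                # Model selection uses a weighted fitness of 0.9 * mIoU + 0.1 * mAP; the weights w
                # are a hand-tuned trade-off between the segmentation and detection metrics.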
                # Update best fitness
                if fitness > best_fitness:
                    best_fitness = fitness
                    mpa = np.mean([v for m, v in metrics.items() if m.endswith('.aAcc')])
                    recorder.record_best_epoch(epoch, best_fitness, mpa)
                early_stopping.monitor(monitor=miou)
                stop = early_stopping.early_stop

                # Save model
                ckpt = {
                    "epoch": epoch,
                    "best_fitness": best_fitness,
                    "model": de_parallel(model).state_dict(),
                    "ema": ema.ema.state_dict(),
                    "updates": ema.updates,
                    "optimizer": optimizer.state_dict()}
                # Save last, best and delete
                torch.save(ckpt, last_model)
                time.sleep(0.1)  # wait for torch.save to finish before copying the result to the output directory
                TrainSDK.save_output_model(last_model)  # save the model on the server side
                if best_fitness == fitness:
                    torch.save(ckpt, best_model)
                    time.sleep(0.1)  # wait for torch.save to finish before copying the result to the output directory
                    TrainSDK.save_output_model(best_model)
                del ckpt

            # EarlyStopping
            if stop:
                print(f"Early stopping at epoch {epoch}")
                break

            # Draw log fig
            draw_log(cfg['save_dir'], val_interval, epoch, epoch_list, loss_tr_list, miou_list, loss_val_list)
            print('Best validation fitness: {:.4f}\n'.format(best_fitness))
        # end epoch --------------------------------------------------------------------------------------------------
    # end training ---------------------------------------------------------------------------------------------------

    if RANK in {-1, 0}:
        print(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')
        if Path(best_model).exists():
            print(f'Validating {best_model}')
            model = cfg.model.to(device)  # create model
            ckpt = torch.load(best_model, map_location='cpu')  # load checkpoint to CPU to avoid CUDA memory leak
            csd = ckpt.get("ema") or ckpt['model']  # checkpoint state_dict as FP32
            model.load_state_dict(csd, strict=True)  # load
            metrics = evaluate(
                cfg,
                model=model,
                eval_loader=val_loader,
                scales=[1.0],
                overlap=0.0,
                save=save)
            miou = np.mean([v for m, v in metrics.items() if m.endswith('.mIoU')])
            all_precision = np.mean([v for m, v in metrics.items() if m.endswith('.aPrecision')])
            all_recall = np.mean([v for m, v in metrics.items() if m.endswith('.aRecall')])
            map50 = np.mean([v for m, v in metrics.items() if m.endswith('.mAP50')])
            map = np.mean([v for m, v in metrics.items() if m.endswith('.mAP')])
            print('best validation at epoch {}\n'
                  'miou: {:.4f}\n'
                  'total precision: {:.4f}, total recall: {:.4f}, map50: {:.4f}, map50-95: {:.4f}\n'
                  .format(ckpt['epoch'], miou, all_precision, all_recall, map50, map))

    logger.close()
    torch.cuda.empty_cache()
    if RANK != -1:
        dist.destroy_process_group()


def parse_args():
    parser = ArgumentParser(description='Semantic segmentation with PyTorch')
    parser.add_argument(
        '--cfg',
        help="training configuration file",
        type=str,
        default=ROOT / 'model.yaml')
    parser.add_argument(
        '--weights',
        help="initial weights path",
        type=str,
        default='')
    parser.add_argument(
        '--batch_size',
        help="total batch size for all GPUs, -1 for autobatch",
        type=int,
        default=4)
    parser.add_argument(
        '--num_workers',
        help="number of dataloader worker processes",
        type=int,
        default=1)
    parser.add_argument(
        '--max_epochs',
        help="maximum number of epochs over the whole dataset",
        type=int,
        default=10)
    parser.add_argument(
        '--val_epochs',
        help="run validation every this many epochs",
        type=int,
        default=1)
    parser.add_argument(
        '--learning_rate',
        help="initial learning rate",
        type=float,
        default=1e-3)
    parser.add_argument(
        '--token',
        help="AI platform training token",
        type=str,
        default="4925EC4929684AA0ABB0173B03CFC8FF")
    parser.add_argument(
        '--val_ratio',
        help="fraction of the dataset randomly selected online as the validation set; the rest is used for training",
        type=float,
        default=0.2)
    parser.add_argument(
        '--save',
        help="save validation results",
        action='store_true')
    parser.add_argument(
        '--resume',
        help="resume most recent training",
        action='store_true')
    parser.add_argument(
        '--freeze',
        help="freeze layers; the default [0] freezes nothing",
        nargs='+',
        type=int,
        default=[0])
    parser.add_argument(
        '--seed',
        help="global random seed; fixing it makes the online random train/val split reproducible",
        type=int,
        default=42)
    parser.add_argument(
        '--amp',
        help="enable automatic mixed precision (AMP) training",
        action='store_true')
    parser.add_argument(
        '--device',
        help="cuda device, i.e. 0 or 0,1,2,3 or cpu",
        type=str,
        default="0")
    return parser.parse_args()
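

# Example invocations (a sketch; config path, batch size, and device IDs are assumptions):
#   Single GPU:
#     python main.py --cfg model.yaml --batch_size 4 --max_epochs 100 --device 0
#   Multi-GPU DDP via the torch.distributed launcher:
#     torchrun --nproc_per_node 2 main.py --cfg model.yaml --batch_size 8 --device 0,1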


def main(args):
    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
    device = select_device(args.device, args.batch_size)

    # DDP mode
    if LOCAL_RANK != -1:
        assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
        print("torch.cuda.device_count():", torch.cuda.device_count())
        torch.cuda.set_device(LOCAL_RANK)
        device = torch.device("cuda", LOCAL_RANK)
        dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo")

    if not args.cfg:
        raise RuntimeError('No configuration file specified.')
    cfg = Config(
        args.cfg,
        learning_rate=args.learning_rate,
        batch_size=args.batch_size,
        epochs=args.max_epochs)

    # Directories
    cfg['save_dir'] = str(ROOT / f"checkpoints/{cfg['dataset_name']}-{cfg['model_name']}")
    recorder.save_dir = cfg['save_dir']
    if not os.path.exists(cfg['save_dir']) and RANK in {-1, 0}:
        os.makedirs(cfg['save_dir'])
    cfg['save_seg_dir'] = os.path.join(cfg['save_dir'], 'prediction')
    if not os.path.exists(cfg['save_seg_dir']) and RANK in {-1, 0}:
        os.makedirs(cfg['save_seg_dir'])

    # Dataset properties
    tasks, part_cat, roi_cat, cat_id, label_map = \
        parse_metadata_info(metadata=deepcopy(cfg["metadata"]))
    cfg.update(**{
        'tasks': tasks,
        'part_cat': part_cat,
        'roi_cat': roi_cat,
        'cat_id': cat_id,
        'label_map': label_map
    })

    # Data preprocessing
    data_processor = DataProcess(
        token=args.token,
        val_ratio=args.val_ratio,
        image_cat=cfg["metric_args"]["needed_imageresults_dict"],
        part_cat=part_cat,
        roi_cat=roi_cat,
        cat_id=cat_id,
        ignore=cfg["preprocess_args"].get("ignore", None),
        video_mode=cfg["preprocess_args"].get("split_by_snippet", False))
    trainval_indexes = data_processor.get_train_val_indexes
    train_data_index_list = trainval_indexes["train"]
    val_data_index_list = trainval_indexes["val"]
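    # DataProcess resolves the dataset on the training platform via --token and performs the online
    # train/val split controlled by --val_ratio (and made reproducible by --seed); here
    # get_train_val_indexes is read as a property returning {"train": [...], "val": [...]}.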

    # Update dataset info
    cfg['train_dataset'].update({
        'token': args.token,
        'tasks': tasks,
        'data_index_list': train_data_index_list,
        'class_index_map': cat_id})
    cfg['val_dataset'].update({
        'token': args.token,
        'tasks': tasks,
        'data_index_list': val_data_index_list,
        'class_index_map': cat_id})

    # Train data
    train_dataset = cfg.train_dataset
    # Record training-set info
    if RANK in {-1, 0}:
        recorder.record_data_slice_info(
            "train_data_info.txt", None, None, None,
            train_dataset.data_index_list if hasattr(train_dataset, 'data_index_list') else None)

    # Validation data
    if RANK in {-1, 0}:
        val_dataset = cfg.val_dataset
        # Record validation-set info
        recorder.record_data_slice_info(
            "val_data_info.txt", None, None, None,
            val_dataset.data_index_list if hasattr(val_dataset, 'data_index_list') else None)
    else:
        val_dataset = None

    # Plot instance-count distribution histograms
    if RANK in {-1, 0}:
        import matplotlib.pyplot as plt
        plt.rcParams['font.family'] = 'sans-serif'  # default font family
        plt.rcParams['font.sans-serif'] = ['SimSun']  # SimSun so CJK class names render correctly
        for dst in [train_dataset, val_dataset]:
            if len(tasks) == 1:
                plt.clf()
                plt.bar(dst.class_names, dst.class_counts)
                plt.ylabel('实例数')
                plt.title('实例数直方图')
                fig_path = os.path.join(cfg['save_dir'], f'{dst.dataset_mode}_dataset_label_stats.png')
                plt.savefig(fig_path)
                TrainSDK.save_output_model(fig_path)
            else:
                for task_id, task in enumerate(tasks):
                    plt.clf()
                    plt.bar(dst.class_names[task_id], dst.class_counts[task_id])
                    plt.ylabel('实例数')
                    plt.title(f'Task ID: {task_id} Task Name: {task}')
                    fig_path = os.path.join(cfg['save_dir'], f'{dst.dataset_mode}_taskid{task_id}_label_stats.png')
                    plt.savefig(fig_path)
                    TrainSDK.save_output_model(fig_path)

    train(
        cfg,
        cfg.model,
        train_dataset,
        val_dataset=val_dataset,
        optimizer=cfg.optimizer,
        save=args.save,
        epochs=cfg.epochs,
        batch_size=cfg.batch_size,
        weights=args.weights,
        freeze=args.freeze,
        val_interval=args.val_epochs,
        num_workers=args.num_workers,
        amp=args.amp,
        device=device)


if __name__ == '__main__':
    args = parse_args()
    # Loggers
    recorder = record_log(args)
    start = time.time()
    main(args)
    end = time.time()
    hour = 1.0 * (end - start) / 3600
    minute = (hour - int(hour)) * 60
    if RANK in {-1, 0}:
        print("training time: %d hour %d minutes" % (int(hour), int(minute)))