123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534 |
- # -*- coding: utf-8 -*-
- # @ File : main.py
- # @ Author : Guido LuXiaohao
- # @ Date : 2021/8/19
- # @ Software : PyCharm
- # @ Description: 代码文件描述。
- from copy import deepcopy
- import os
- import random
- import sys
- import time
- import warnings
- from argparse import ArgumentParser
- from pathlib import Path
- import numpy as np
- import torch
- import torch.distributed as dist
- from torch.nn.parallel import DistributedDataParallel as DDP
- from torch.utils.data import DataLoader, distributed
- from tqdm import tqdm
- from dataset import TrainSDK, custom_collate_fn
- from dataset.data_process import DataProcess
- from dataset.utils import parse_metadata_info
- from utils.config import Config
- from utils.earlyStopping import EarlyStopping
- from utils.plot_log import draw_log
- from utils.record_log import record_log
- from utils.scheduler import WarmupCosineLR
- from utils.utils import (ModelEMA, de_parallel, get_params, intersect_dicts,
- select_device, start_resume, convert_nested_tensors_to_device)
- from val import evaluate
- warnings.filterwarnings('ignore')
- sys.setrecursionlimit(1000000) # solve problem 'maximum recursion depth exceeded'
- FILE = Path(__file__).resolve()
- ROOT = FILE.parents[0] # project root directory
- if str(ROOT) not in sys.path:
- sys.path.append(str(ROOT)) # add ROOT to PATH
- ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative path
- LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))
- RANK = int(os.getenv('RANK', -1))
- WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
- def train(cfg,
- model,
- train_dataset,
- val_dataset=None,
- optimizer=None,
- save=False,
- epochs=100,
- batch_size=2,
- weights=None,
- freeze=None,
- val_interval=5,
- num_workers=0,
- amp=False,
- device=torch.device('cuda')):
- nd = torch.cuda.device_count() # number of CUDA devices
- workers = min([os.cpu_count() // max(nd, 1),
- batch_size if batch_size > 1 else 0,
- num_workers]) # number of workers
- cuda = device.type == 'cuda'
- last_model = cfg['save_dir'] + '/last.pth'
- best_model = cfg['save_dir'] + '/best.pth'
- # Model
- model = model.to(device)
- pretrained = weights.endswith('.pth')
- if pretrained:
- ckpt = torch.load(weights, map_location='cpu')
- csd = ckpt.get("ema") or ckpt['model']
- csd = intersect_dicts(csd, model.state_dict())
- model.load_state_dict(csd, strict=False) # load
- # Freeze
- freeze = [f'backbone.features.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))]
- for name, param in model.named_parameters():
- param.requires_grad = True # train all layers
- if any(x in name for x in freeze):
- print(f'freezing {name}')
- param.requires_grad = False
- # EMA
- ema = ModelEMA(model) if RANK in {-1, 0} else None
- # log configs
- if RANK in {-1, 0}:
- recorder.record_args(str(get_params(model) / 1e6) + ' M')
- cfg.dump(save_dir=cfg['save_dir'])
- TrainSDK.save_output_model(os.path.join(cfg['save_dir'], 'config.yaml'))
- # Resume
- best_fitness, start_epoch = 0.0, 1
- epoch_list, loss_tr_list, loss_val_list, miou_list = [], [], [], []
- if pretrained:
- if args.resume:
- logger, lines = recorder.resume_logfile()
- for index, line in enumerate(lines):
- loss_tr_list.append(float(line.strip().split()[2]))
- if len(line.strip().split()) != 3:
- epoch_list.append(int(line.strip().split()[0]))
- loss_val_list.append(float(line.strip().split()[3]))
- miou_list.append(float(line.strip().split()[5]))
- best_fitness, start_epoch, epochs = start_resume(ckpt, optimizer, ema, weights, epochs, args.resume)
- else:
- logger = recorder.initial_logfile()
- logger.flush()
- del ckpt, csd
- else:
- logger = recorder.initial_logfile()
- logger.flush()
- if cuda and RANK != -1:
- # Sync batch norm
- model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
- # DDP
- model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
- train_sampler = None if RANK == -1 else \
- distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
- train_loader = DataLoader(
- train_dataset,
- batch_size=batch_size // WORLD_SIZE,
- shuffle=True if train_sampler is None else False,
- sampler=train_sampler,
- num_workers=workers,
- collate_fn=custom_collate_fn,
- pin_memory=True,
- drop_last=True)
- if RANK in {-1, 0}:
- val_loader = DataLoader(
- val_dataset,
- batch_size=batch_size // WORLD_SIZE,
- num_workers=workers,
- collate_fn=custom_collate_fn,
- pin_memory=True,
- drop_last=True)
- else:
- val_loader = None
- # Scheduler
- nb = len(train_loader) # number of batches
- nw = max(round(cfg['lr_scheduler']['warmup_epochs'] * nb), 100) # number of warmup iterations
- scheduler = WarmupCosineLR(
- optimizer, 1e-9, cfg.learning_rate, nw, cfg.epochs * nb, 0.1)
- print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n"
- ">>>>>>>>>>> beginning training >>>>>>>>>>>\n"
- ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
- # Start training
- t0 = time.time()
- scheduler.last_epoch = start_epoch - 2
- early_stopping, stop = EarlyStopping(patience=int(epochs / 5)), False
- scaler = torch.cuda.amp.GradScaler(enabled=amp)
- for epoch in range(start_epoch, epochs + 1): # epoch -------------------------------------------------
- model.train()
- epoch_loss = []
- if RANK != -1:
- train_loader.sampler.set_epoch(epoch)
- pbar = enumerate(train_loader)
- if RANK in {-1, 0}:
- pbar = tqdm(pbar, total=len(train_loader), desc='Epoch {}/{}'.format(epoch, epochs))
- optimizer.zero_grad()
- for step, (images, data_samples) in pbar: # batch ---------------------------------------------------------
- images = images.to(device, non_blocking=True)
- data_samples = convert_nested_tensors_to_device(data_samples, device, non_blocking=True)
- # Forward
- with torch.cuda.amp.autocast(amp):
- losses = model(images, data_samples, mode='loss')
- loss = sum(losses.values())
- if RANK != -1: # gradient averaged between devices in DDP mode
- loss *= WORLD_SIZE
- # Backward
- scaler.scale(loss).backward()
- # Optimize
- scaler.step(optimizer)
- scaler.update()
- optimizer.zero_grad()
- if ema:
- ema.update(model)
- # Scheduler
- scheduler.step()
- epoch_loss.append(loss.item())
- # end batch -------------------------------------------------------------------------------------
- loss_tr = sum(epoch_loss) / len(epoch_loss)
- loss_tr_list.append(loss_tr)
- lr = optimizer.param_groups[0]['lr']
- recorder.record_train_log(logger, epoch, lr, loss_tr)
- # Validation
- if RANK in {-1, 0}:
- if (epoch % val_interval == 0 or epoch == epochs) and (val_loader is not None):
- metrics = evaluate(
- cfg,
- model=ema.ema,
- eval_loader=val_loader,
- scales=[1.0],
- overlap=0.0,
- save=False)
- epoch_list.append(epoch)
- loss_val_list.append(loss.item())
- miou = np.mean([v for m, v in metrics.items() if m.endswith('.mIoU')])
- miou_list.append(miou)
- map = np.mean([v for m, v in metrics.items() if m.endswith('.mAP')])
- # manually configured metrics
- w = [0.9, 0.1]
- fitness = np.sum(np.multiply(
- w, [miou, map]))
- # Update best mIoU
- if fitness > best_fitness:
- best_fitness = fitness
- mpa = np.mean([v for m, v in metrics.items() if m.endswith('.aAcc')])
- recorder.record_best_epoch(epoch, best_fitness, mpa)
- early_stopping.monitor(monitor=miou)
- stop = early_stopping.early_stop
- # Save model
- ckpt = {
- "epoch": epoch,
- "best_fitness": best_fitness,
- "model": de_parallel(model).state_dict(),
- "ema": ema.ema.state_dict(),
- "updates": ema.updates,
- "optimizer": optimizer.state_dict()}
- # Save last, best and delete
- torch.save(ckpt, last_model)
- time.sleep(0.1) # 延时以保证torch.save进程结束后才copy结果到输出目录
- TrainSDK.save_output_model(last_model) # 服务器端保存模型
- if best_fitness == fitness:
- torch.save(ckpt, best_model)
- time.sleep(0.1) # 延时以保证torch.save进程结束后才copy结果到输出目录
- TrainSDK.save_output_model(best_model)
- del ckpt
- # EarlyStopping
- if stop:
- print(f"Early stopping at epoch {epoch}")
- break
- # Draw log fig
- draw_log(cfg['save_dir'], val_interval, epoch, epoch_list, loss_tr_list, miou_list, loss_val_list)
- print('Best Validation fitness:{:.4f}\n'.format(best_fitness))
- # end epoch ---------------------------------------------------------------------------
- # end training ----------------------------------------------------------------------------
- if RANK in {-1, 0}:
- print(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')
- if Path(best_model).exists():
- print(f'Validating {best_model}')
- model = cfg.model.to(device) # create model
- ckpt = torch.load(best_model, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak
- csd = ckpt.get("ema") or ckpt['model'] # checkpoint state_dict as FP32
- model.load_state_dict(csd, strict=True) # load
- metrics = evaluate(
- cfg,
- model=model,
- eval_loader=val_loader,
- scales=[1.0],
- overlap=0.0,
- save=save)
- miou = np.mean([v for m, v in metrics.items() if m.endswith('.mIoU')])
- all_precision = np.mean([v for m, v in metrics.items() if m.endswith('.aPrecision')])
- all_recall = np.mean([v for m, v in metrics.items() if m.endswith('.aRecall')])
- map50 = np.mean([v for m, v in metrics.items() if m.endswith('.mAP50')])
- map = np.mean([v for m, v in metrics.items() if m.endswith('.mAP')])
- print('best validation at epoch {}\n'
- 'miou: {:.4f}\n'
- 'total precision: {:.4f}, total recall: {:.4f}, map50: {:.4f}, map50-95: {:.4f}\n'
- .format(ckpt['epoch'], miou, all_precision, all_recall, map50, map))
- logger.close()
- torch.cuda.empty_cache()
- if RANK != -1:
- dist.destroy_process_group()
- def parse_args():
- parser = ArgumentParser(description='Semantic segmentation with pytorch')
- parser.add_argument(
- '--cfg',
- help="training configuration file",
- type=str,
- default=ROOT / 'model.yaml')
- parser.add_argument(
- '--weights',
- help="initial weights path",
- type=str,
- default='')
- parser.add_argument(
- '--batch_size',
- help="total batch size for all GPUs, -1 for autobatch",
- type=int,
- default=4)
- parser.add_argument(
- '--num_workers',
- help=" the number of parallel threads",
- type=int,
- default=1)
- parser.add_argument(
- '--max_epochs',
- help="训练整个数据集的最大轮次",
- type=int,
- default=10)
- parser.add_argument(
- '--val_epochs',
- help="每隔固定轮次预测验证集",
- type=int,
- default=1)
- parser.add_argument(
- '--learning_rate',
- help="初始化学习率",
- type=float,
- default=1e-3)
- parser.add_argument(
- '--token',
- help="AI平台训练码",
- type=str,
- default="4925EC4929684AA0ABB0173B03CFC8FF")
- parser.add_argument(
- '--val_ratio',
- help="在线随机选取验证集(验证集占总数据集比例),剩下的作为训练集",
- type=float,
- default=0.2)
- parser.add_argument(
- '--save',
- help="save validation results",
- action='store_const',
- const=True,
- default=False)
- parser.add_argument(
- '--resume',
- help='resume recent training',
- action='store_const',
- const=True,
- default=False)
- parser.add_argument(
- '--freeze',
- help='Freeze layers. Default: [0] for not freezing layers',
- type=list,
- default=[0])
- parser.add_argument(
- '--seed',
- help="固定全局随机数,固定该数,可以固定在线随机划分数据集",
- type=int,
- default=42)
- parser.add_argument(
- '--amp',
- help='enable automatic mixed precision (AMP) mode',
- action='store_true')
- parser.add_argument(
- '--device',
- help="cuda device, i.e. 0 or 0,1,2,3 or cpu",
- type=str,
- default="0")
- return parser.parse_args()
- def main(args):
- if args.seed is not None:
- random.seed(args.seed)
- np.random.seed(args.seed)
- torch.manual_seed(args.seed)
- torch.cuda.manual_seed_all(args.seed)
- device = select_device(args.device, args.batch_size)
- # DDP mode
- if LOCAL_RANK != -1:
- assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
- print("torch.cuda.device_count():", torch.cuda.device_count())
- torch.cuda.set_device(LOCAL_RANK)
- device = torch.device("cuda", LOCAL_RANK)
- dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo")
- if not args.cfg:
- raise RuntimeError('No configuration file specified.')
- cfg = Config(
- args.cfg,
- learning_rate=args.learning_rate,
- batch_size=args.batch_size,
- epochs=args.max_epochs)
- # Directories
- cfg['save_dir'] = str(ROOT / f"checkpoints/{cfg['dataset_name']}-{cfg['model_name']}")
- recorder.save_dir = cfg['save_dir']
- if not os.path.exists(cfg['save_dir']) and RANK in {-1, 0}:
- os.makedirs(cfg['save_dir'])
- cfg['save_seg_dir'] = os.path.join(cfg['save_dir'], 'prediction')
- if not os.path.exists(cfg['save_seg_dir']) and RANK in {-1, 0}:
- os.makedirs(cfg['save_seg_dir'])
- # Dataset properties
- tasks, part_cat, roi_cat, cat_id, label_map = \
- parse_metadata_info(metadata=deepcopy(cfg["metadata"]))
- cfg.update(**{
- 'tasks': tasks,
- 'part_cat': part_cat,
- 'roi_cat': roi_cat,
- 'cat_id': cat_id,
- 'label_map': label_map
- })
- # Data preprocessing
- data_processor = DataProcess(
- token=args.token,
- val_ratio=args.val_ratio,
- image_cat=cfg["metric_args"]["needed_imageresults_dict"],
- part_cat=part_cat,
- roi_cat=roi_cat,
- cat_id=cat_id,
- ignore=cfg["preprocess_args"].get("ignore", None),
- video_mode=cfg["preprocess_args"].get("split_by_snippet", False))
- trainval_indexes = data_processor.get_train_val_indexes
- train_data_index_list = trainval_indexes["train"]
- val_data_index_list = trainval_indexes["val"]
- # update dataset info
- cfg['train_dataset'].update({
- 'token': args.token,
- 'tasks': tasks,
- 'data_index_list': train_data_index_list,
- 'class_index_map': cat_id})
- cfg['val_dataset'].update({
- 'token': args.token,
- 'tasks': tasks,
- 'data_index_list': val_data_index_list,
- 'class_index_map': cat_id})
- # train data
- train_dataset = cfg.train_dataset
- # 记录训练集信息
- if RANK in {-1, 0}:
- recorder.record_data_slice_info(
- "train_data_info.txt", None, None, None,
- train_dataset.data_index_list if hasattr(train_dataset, 'data_index_list') else None)
- # validation data
- if RANK in {-1, 0}:
- val_dataset = cfg.val_dataset
- # 记录验证集信息
- recorder.record_data_slice_info(
- "val_data_info.txt", None, None, None,
- val_dataset.data_index_list if hasattr(val_dataset, 'data_index_list') else None)
- else:
- val_dataset = None
- # 绘制实例分布图
- if RANK in {-1, 0}:
- import matplotlib.pyplot as plt
- plt.rcParams['font.family'] = 'sans-serif' # 指定默认字体
- plt.rcParams['font.sans-serif'] = ['SimSun'] # 使用宋体
- for dst in [train_dataset, val_dataset]:
- if len(tasks) == 1:
- plt.clf()
- plt.bar(dst.class_names, dst.class_counts)
- plt.ylabel('实例数')
- plt.title('实例数直方图')
- fig_path = os.path.join(cfg['save_dir'], f'{dst.dataset_mode}_dataset_label_stats.png')
- plt.savefig(fig_path)
- TrainSDK.save_output_model(fig_path)
- else:
- for task_id, task in enumerate(tasks):
- plt.clf()
- plt.bar(dst.class_names[task_id], dst.class_counts[task_id])
- plt.ylabel('实例数')
- plt.title(f'Task ID: {task_id} Task Name: {task}')
- fig_path = os.path.join(cfg['save_dir'], f'{dst.dataset_mode}_taskid{task_id}_label_stats.png')
- plt.savefig(fig_path)
- TrainSDK.save_output_model(fig_path)
- train(
- cfg,
- cfg.model,
- train_dataset,
- val_dataset=val_dataset,
- optimizer=cfg.optimizer,
- save=args.save,
- epochs=cfg.epochs,
- batch_size=cfg.batch_size,
- weights=args.weights,
- freeze=args.freeze,
- val_interval=args.val_epochs,
- num_workers=args.num_workers,
- amp=args.amp,
- device=device)
- if __name__ == '__main__':
- args = parse_args()
- # Loggers
- recorder = record_log(args)
- start = time.time()
- main(args)
- end = time.time()
- hour = 1.0 * (end - start) / 3600
- minute = (hour - int(hour)) * 60
- if RANK in {-1, 0}:
- print("training time: %d hour %d minutes" % (int(hour), int(minute)))
|