# main.py — training entry point for the image-classification models below.
  1. import argparse
  2. import os
  3. import random
  4. import time
  5. import warnings
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.parallel
  9. import torch.backends.cudnn as cudnn
  10. from torch.optim import lr_scheduler
  11. from torch.utils.data import DataLoader
  12. import torch.utils.data.distributed
  13. import model.resnext_pytorch as resnext
  14. from model.ghostnet import ghostnet
  15. from model.mobilenetv3 import mobilenetv3
  16. from dataset.dataset_for_platform import BasicDatasetFolder_ForPlatform
  17. from dataset.dataset_for_local import BasicDatasetFolder_ForLocal
  18. from model.efficientnet_pytorch import EfficientNet
  19. from utils.earlyStopping import EarlyStopping
  20. from utils.plot_log import draw_log
  21. import TrainSdk
  22. import logging
  23. import datetime
  24. parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
  25. parser.add_argument('-a', '--arch', metavar='ARCH', default='mobilenetv3',help='model architecture efficientnet-b0 efficientnet-b1 ghostnet MobileNetV3 resnext50 mobilenetv3')
  26. parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)')
  27. parser.add_argument('--epochs', default=50, type=int, metavar='N', help='number of total epochs to run')
  28. parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)')
  29. parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N')
  30. parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,metavar='LR', help='initial learning rate', dest='lr')
  31. parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
  32. parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)', dest='weight_decay')
  33. parser.add_argument('-p', '--print-freq', default=20, type=int, metavar='N', help='print frequency (default: 10)')
  34. parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
  35. parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',help='evaluate model on validation set')
  36. parser.add_argument('--rank', default=1, type=int, help='node rank for distributed training')
  37. parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend')
  38. parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. ')
  39. parser.add_argument('--gpu', default=0, type=int, help='GPU id to use.')
  40. parser.add_argument('--image_size', default=256, type=int, help='image size')
  41. parser.add_argument('--class_nums', default=9, type=int, help='分类的类别数量')
  42. parser.add_argument('--checkpoints', default=os.path.join(os.getcwd(),'checkpoints'), type=str, metavar='PATH', help='path of checkpoint')
  43. parser.add_argument('--validate_frequency', default=1, type=int, help='每多少个epoch测试一下val')
  44. #若本地训练,选'dataset_for_local'模式,平台训练选'dataset_for_platform'
  45. parser.add_argument('--custom_dataset', default='dataset_for_platform', type=str, help='本地训练:"dataset_for_local" , 平台训练:"dataset_for_platform"')
  46. #判断在平台还是在本地训练
  47. parser.add_argument('--is_local_not_platform', default=False, type=bool, help='train local is True, train platform is False')
  48. #本地训练,需要给定数据文件夹
  49. parser.add_argument('--data', default=os.path.join(os.getcwd(), 'data'), help='train local, need set data path, data\\train data\\val')
  50. #在AI平台上训练时在线划分训练测试集,没有预先划分的情况下使用
  51. parser.add_argument('--split_data', default= True, type=bool,help='train platform, need split the training set and the verification set')
  52. parser.add_argument('--split_rate', default=0.8, type=float, help='在线划分数据集的比例,数据总集的80%作为训练集,剩余的作为验证集')
  53. classes_dict = {'BMode':0, 'BModeBlood':1,'Pseudocolor':2, 'PseudocolorBlood':3, 'Spectrogram':4, 'CEUS':5, 'SE':6,'STE':7,'FourDime':8}
  54. parser.add_argument('--class_dict', default=classes_dict, help='每个文件夹的文件名对应其类别')
  55. '''
  56. 给定的data文件夹为root
  57. root/
  58. train/
  59. classA/
  60. folder1/
  61. *.jpg
  62. folder2/
  63. *.jpg
  64. classB/
  65. folder1/
  66. *.jpg
  67. folder2/
  68. *.jpg
  69. ......
  70. val/
  71. classA/
  72. folder1/
  73. *.jpg
  74. folder2/
  75. *.jpg
  76. classB/
  77. folder1/
  78. *.jpg
  79. folder2/
  80. *.jpg
  81. ......
  82. 给定data文件夹,分成train和val两个文件夹,
  83. train或者val文件夹中每个类别class一个文件夹,
  84. 每个类别class文件中,可以存放多个文件夹,用于多个来源,相同类别文件
  85. '''
  86. #平台训练,需要给定token,train和val的图像label的title,所对应的类别
  87. parser.add_argument('--token', default='4925EC4929684AA0ABB0173B03CFC8FF', type=str, help='train platform,need set token') # 52cfaf5a8a364e5fb635c94683963cfe
  88. #dataset_for_usguide 图像增强所需的参数配置
  89. # parser.add_argument('--brightness_radius', default=0.3, type=float,help='输入图像,亮度调整')
  90. # parser.add_argument('--contrast_radius', default=0.3, type=float,help='输入图像,对比度调整')
  91. # parser.add_argument('--saturation_radius', default=0, type=float,help='输入图像,饱和度调整')
  92. logging.basicConfig(format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s',level=logging.DEBUG)
  93. def main():
  94. args = parser.parse_args()
  95. args_dict = vars(args)
  96. for k, v in args_dict.items():
  97. print(f"{k}: {v}")
  98. logging.info("para:{} default:{}".format(k, v))
  99. if args.seed is not None:
  100. random.seed(args.seed)
  101. torch.manual_seed(args.seed)
  102. cudnn.deterministic = True
  103. warnings.warn('You have chosen to seed training. '
  104. 'This will turn on the CUDNN deterministic setting, '
  105. 'which can slow down your training considerably! '
  106. 'You may see unexpected behavior when restarting '
  107. 'from checkpoints.')
  108. if args.gpu is not None:
  109. warnings.warn('You have chosen a specific GPU. This will completely '
  110. 'disable data parallelism.')
  111. ngpus_per_node = torch.cuda.device_count()
  112. main_worker(args.gpu, ngpus_per_node, args)
  113. def main_worker(gpu, ngpus_per_node, args):
  114. global best_acc1
  115. args.gpu = gpu
  116. if not os.path.exists(args.checkpoints):
  117. os.mkdir(args.checkpoints)
  118. now_time = datetime.datetime.now()
  119. time_str = datetime.datetime.strftime(now_time, '%m-%d_%H-%M-%S')
  120. save_path = os.path.join(args.checkpoints, args.arch, time_str)
  121. if not os.path.exists(save_path):
  122. os.makedirs(save_path)
  123. if args.gpu is not None:
  124. print("Use GPU: {} for training".format(args.arch))
  125. # logging.info("Use GPU: {} for training".format(args.arch))
  126. # create model
  127. if 'efficientnet' in args.arch: # NEW
  128. logging.info("=> creating model '{}'".format(args.arch))
  129. #bo b1 b2......不同的efficientnet有不同的默认image_size,可参考 model.efficientnet_pytorch.utils.efficientnet_params
  130. model = EfficientNet.from_name(model_name=args.arch, in_channels=3, num_classes=args.class_nums, image_size=args.image_size)
  131. elif 'ghostnet' in args.arch: # NEW
  132. logging.info("=> creating model '{}'".format(args.arch))
  133. model = ghostnet.ghostnet(num_classes=args.class_nums, channels=3)
  134. elif 'mobilenetv3' in args.arch: # NEW
  135. logging.info("=> creating model '{}'".format(args.arch))
  136. # mode分别为 small large Rebbcca_LiverDiffuseLesionClassifier Joseph_USGuide
  137. model = mobilenetv3.MobileNetV3(n_class=args.class_nums, input_size=args.image_size, mode='small',channels=3, dropout=0, width_mult=1.0)
  138. elif 'resnext' in args.arch:
  139. logging.info("=> creating model '{}'".format(args.arch))
  140. model = resnext.from_name(model_name=args.arch, basewidth=4, cardinality=8, class_nums=args.class_nums)
  141. else:
  142. warnings.warn('You have chosen a wrong model.Using a default model instead.')
  143. # model = resnext.from_name(model_name=args.arch, basewidth=4, cardinality=8, class_nums=args.class_nums)
  144. if args.gpu is not None:
  145. torch.cuda.set_device(args.gpu)
  146. model = model.cuda(args.gpu)
  147. else:
  148. model = torch.nn.DataParallel(model).cuda()
  149. print(model)
  150. # define loss function (criterion) and optimizer
  151. criterion = nn.CrossEntropyLoss().cuda(args.gpu)
  152. optimizer = torch.optim.SGD(model.parameters(), args.lr,
  153. momentum=args.momentum,
  154. weight_decay=args.weight_decay)
  155. # optimizer = torch.optim.Adam(model.parameters(), args.lr,
  156. # betas=args.momentum,
  157. # weight_decay=args.weight_decay)
  158. scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=5)
  159. #scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)
  160. if args.resume:
  161. if os.path.isfile(args.resume):
  162. logging.info("=> => loading checkpoint '{}'".format(args.resume))
  163. checkpoint = torch.load(args.resume)
  164. model.load_state_dict(checkpoint)
  165. else:
  166. print("=> no checkpoint found at '{}'".format(args.resume))
  167. logging.info("=> no checkpoint found at '{}'".format(args.resume))
  168. cudnn.benchmark = True
  169. if args.is_local_not_platform:
  170. traindir = os.path.join(args.data, 'train')
  171. valdir = os.path.join(args.data, 'val')
  172. if args.custom_dataset == 'dataset_for_local':
  173. train_dataset = BasicDatasetFolder_ForLocal(traindir, transform=True, balance=True, image_size=args.image_size, class_dict=args.class_dict)
  174. val_dataset = BasicDatasetFolder_ForLocal(valdir, transform=False, balance=False, image_size=args.image_size, class_dict=args.class_dict)
  175. else:
  176. if args.custom_dataset == 'dataset_for_platform':
  177. train_dataset = BasicDatasetFolder_ForPlatform(split_data=args.split_data,
  178. istrain=True,
  179. transform=True,
  180. balance=False,
  181. token=args.token,
  182. class_dict=args.class_dict,
  183. image_size=args.image_size,
  184. split_rate=args.split_rate)
  185. val_dataset = BasicDatasetFolder_ForPlatform(split_data=args.split_data,
  186. istrain=False,
  187. transform=False,
  188. balance=False,
  189. token=args.token,
  190. class_dict=args.class_dict,
  191. image_size=args.image_size,
  192. split_rate=args.split_rate)
  193. train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True, sampler=None)
  194. val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)
  195. # 早停策略
  196. # early_stopping, stop = EarlyStopping(patience=int(args.epochs/5)), False
  197. best_val_acc = 0.0
  198. epoch_list, lossTr_list, lossVal_list, accTr_list, accVal_list = [], [], [], [], []
  199. # switch to train mode
  200. for epoch in range(args.start_epoch, args.epochs):
  201. localtime = time.asctime(time.localtime(time.time()))
  202. # logging.info("Current time:{}".format(localtime))
  203. adjust_learning_rate(optimizer, epoch, args)
  204. acc_tr, loss_tr = train(train_loader, model, criterion, optimizer, epoch, train_dataset.exampes_nums ,args)
  205. accTr_list.append(float(acc_tr))
  206. lossTr_list.append(loss_tr)
  207. scheduler.step()
  208. #根据设定的validate_frequency,测试val数据和保存模型
  209. if epoch % args.validate_frequency == 0:
  210. epoch_list.append(epoch)
  211. # evaluate on validation set
  212. acc_var ,loss_var = validate(val_loader, model, criterion, val_dataset.exampes_nums, args)
  213. accVal_list.append(float(acc_var))
  214. lossVal_list.append(loss_var)
  215. # scheduler.step(acc_var)
  216. is_best = acc_var > best_val_acc
  217. # logging.info("=> 第{}的epoch中的is_best为:{}, acc1为:{}\n\n".format(epoch, is_best, acc_var))
  218. best_val_acc = max(acc_var, best_val_acc)
  219. # early_stopping.monitor(monitor=acc_var)
  220. # stop = early_stopping.early_stop
  221. if (args.rank % ngpus_per_node == 0 and (is_best)):
  222. torch.save(model.state_dict(), os.path.join(save_path, f'CP_epoch{epoch + 1}.pth'))
  223. if not args.is_local_not_platform:
  224. TrainSdk.save_output_model(os.path.join(save_path, f'CP_epoch{epoch + 1}.pth'))
  225. # EarlyStopping
  226. # if stop:
  227. # print(f"Early stopping at epoch {epoch}")
  228. # break
  229. # Draw log fig
  230. draw_log(save_path, args.validate_frequency, epoch, epoch_list, lossTr_list, lossVal_list, accTr_list, accVal_list)
  231. def train(train_loader, model, criterion, optimizer, epoch, sample_nums,args):
  232. if args.is_local_not_platform:
  233. total_numbers_iteration = len(train_loader)
  234. else:
  235. total_numbers_iteration = sample_nums//args.batch_size
  236. end = time.time()
  237. # switch to train mode
  238. model.train()
  239. batch_time = AverageMeter('Time', ':6.3f')
  240. data_time = AverageMeter('Data', ':6.3f')
  241. losses = AverageMeter('Loss', ':6.3f')
  242. top1 = AverageMeter('Acc@1', ':6.3f')
  243. progress = ProgressMeter(total_numbers_iteration, batch_time, data_time, losses, top1, prefix="Epoch: [{}]".format(epoch))
  244. for i, (images, target) in enumerate(train_loader):
  245. data_time.update(time.time() - end)
  246. # logging.info("=> images {}, label_index {}".format(images, target))
  247. if args.gpu is not None:
  248. images = images.cuda(args.gpu, non_blocking=True)
  249. target = target.cuda(args.gpu, non_blocking=True)
  250. # compute output
  251. output = model(images)
  252. loss = criterion(output, target)
  253. prec1 = accuracy(output.data, target.data, topk=(1,))
  254. top1.update(prec1[0], images.size(0))
  255. losses.update(loss.item(), images.size(0))
  256. # compute gradient and do SGD step
  257. optimizer.zero_grad()
  258. loss.backward()
  259. optimizer.step()
  260. # measure elapsed time
  261. batch_time.update(time.time() - end)
  262. end = time.time()
  263. if i % args.print_freq == 0:
  264. progress.print(i)
  265. return top1.get_avg() , losses.get_avg()
  266. def validate(val_loader, model, criterion, sample_nums, args):
  267. batch_time = AverageMeter('Time', ':6.3f')
  268. losses = AverageMeter('Loss', ':6.3f')
  269. top1 = AverageMeter('Acc@1', ':6.2f')
  270. if args.is_local_not_platform:
  271. total_numbers_iteration = len(val_loader)
  272. else:
  273. total_numbers_iteration = sample_nums//args.batch_size
  274. progress = ProgressMeter(total_numbers_iteration, batch_time, losses, top1, prefix='Test: ')
  275. # switch to evaluate mode
  276. model.eval()
  277. with torch.no_grad():
  278. end = time.time()
  279. for i, (images, target) in enumerate(val_loader):
  280. if args.gpu is not None:
  281. images = images.cuda()
  282. target = target.cuda()
  283. # compute output
  284. output = model(images)
  285. loss = criterion(output, target)
  286. acc1 = accuracy(output, target, topk=(1,))
  287. losses.update(loss.item(), images.size(0))
  288. top1.update(acc1[0], images.size(0))
  289. # measure elapsed time
  290. batch_time.update(time.time() - end)
  291. end = time.time()
  292. if i % args.print_freq == 0:
  293. progress.print(i)
  294. return top1.avg, losses.avg
  295. def accuracy(output, target, topk=(1,)):
  296. """Computes the accuracy over the k top predictions for the specified values of k"""
  297. maxk = max(topk)
  298. batch_size = target.size(0)
  299. _, pred = output.topk(maxk, 1, True, True)
  300. pred = pred.t()
  301. correct = pred.eq(target.view(1, -1).expand_as(pred))
  302. res = []
  303. for k in topk:
  304. correct_k = correct[:k].view(-1).float().sum(0)
  305. res.append(correct_k.mul_(100.0 / batch_size))
  306. return res
  307. class AverageMeter(object):
  308. """Computes and stores the average and current value"""
  309. def __init__(self, name, fmt=':f'):
  310. self.name = name
  311. self.fmt = fmt
  312. self.reset()
  313. def reset(self):
  314. self.val = 0
  315. self.avg = 0
  316. self.sum = 0
  317. self.count = 0
  318. def update(self, val, n=1):
  319. self.val = val
  320. self.sum += val * n
  321. self.count += n
  322. self.avg = self.sum / self.count
  323. def get_avg(self):
  324. return self.avg
  325. def __str__(self):
  326. fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
  327. return fmtstr.format(**self.__dict__)
  328. class ProgressMeter(object):
  329. def __init__(self, num_batches, *meters, prefix=""):
  330. self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
  331. self.meters = meters
  332. self.prefix = prefix
  333. def print(self, batch):
  334. entries = [self.prefix + self.batch_fmtstr.format(batch)]
  335. entries += [str(meter) for meter in self.meters]
  336. print('\t'.join(entries))
  337. def _get_batch_fmtstr(self, num_batches):
  338. num_digits = len(str(num_batches // 1))
  339. fmt = '{:' + str(num_digits) + 'd}'
  340. return '[' + fmt + '/' + fmt.format(num_batches) + ']'
  341. def adjust_learning_rate(optimizer, epoch, args):
  342. """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
  343. lr = args.lr * (0.1 ** (epoch // 30))
  344. for param_group in optimizer.param_groups:
  345. param_group['lr'] = lr
# Script entry point: run training only when executed directly, not on import.
if __name__ == '__main__':
    main()