- import argparse
- import os
- import random
- import time
- import warnings
- import torch
- import torch.nn as nn
- import torch.nn.parallel
- import torch.backends.cudnn as cudnn
- from torch.optim import lr_scheduler
- from torch.utils.data import DataLoader
- import torch.utils.data.distributed
- import model.resnext_pytorch as resnext
- import model.ghostnet as ghostnet  # import the modules, not the factory names, so ghostnet.ghostnet(...) and mobilenetv3.MobileNetV3(...) below resolve
- import model.mobilenetv3 as mobilenetv3
- from dataset.dataset_for_platform import BasicDatasetFolder_ForPlatform
- from dataset.dataset_for_local import BasicDatasetFolder_ForLocal
- from model.efficientnet_pytorch import EfficientNet
- from utils.earlyStopping import EarlyStopping
- from utils.plot_log import draw_log
- import TrainSdk
- import logging
- import datetime
- parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
- parser.add_argument('-a', '--arch', metavar='ARCH', default='mobilenetv3', help='model architecture: efficientnet-b0 | efficientnet-b1 | ghostnet | mobilenetv3 | resnext50')
- parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)')
- parser.add_argument('--epochs', default=50, type=int, metavar='N', help='number of total epochs to run')
- parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)')
- parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N', help='mini-batch size (default: 32)')
- parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, metavar='LR', help='initial learning rate', dest='lr')
- parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
- parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)', dest='weight_decay')
- parser.add_argument('-p', '--print-freq', default=20, type=int, metavar='N', help='print frequency (default: 20)')
- parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
- parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set')
- parser.add_argument('--rank', default=1, type=int, help='node rank for distributed training')
- parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend')
- parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. ')
- parser.add_argument('--gpu', default=0, type=int, help='GPU id to use.')
- parser.add_argument('--image_size', default=256, type=int, help='image size')
- parser.add_argument('--class_nums', default=9, type=int, help='number of classes')
- parser.add_argument('--checkpoints', default=os.path.join(os.getcwd(),'checkpoints'), type=str, metavar='PATH', help='path of checkpoint')
- parser.add_argument('--validate_frequency', default=1, type=int, help='run validation every N epochs')
- # For local training choose 'dataset_for_local'; for platform training choose 'dataset_for_platform'
- parser.add_argument('--custom_dataset', default='dataset_for_platform', type=str, help='local training: "dataset_for_local", platform training: "dataset_for_platform"')
- # Whether training runs locally or on the platform. Note: argparse's type=bool
- # treats any non-empty string as True, so the string is parsed explicitly.
- parser.add_argument('--is_local_not_platform', default=False, type=lambda s: str(s).lower() in ('true', '1'), help='True when training locally, False when training on the platform')
- # Local training requires a data folder
- parser.add_argument('--data', default=os.path.join(os.getcwd(), 'data'), help='local training data root, expected to contain data\\train and data\\val')
- # When training on the AI platform, split train/val online; used when the data has not been pre-split
- parser.add_argument('--split_data', default=True, type=lambda s: str(s).lower() in ('true', '1'), help='on the platform, split the data into training and validation sets online')
- parser.add_argument('--split_rate', default=0.8, type=float, help='online split ratio: 0.8 means 80%% of the data is used for training, the rest for validation')
- classes_dict = {'BMode':0, 'BModeBlood':1,'Pseudocolor':2, 'PseudocolorBlood':3, 'Spectrogram':4, 'CEUS':5, 'SE':6,'STE':7,'FourDime':8}
- parser.add_argument('--class_dict', default=classes_dict, help='maps each class folder name to its label index')
- '''
- The given data folder is the root:
- root/
-     train/
-         classA/
-             folder1/
-                 *.jpg
-             folder2/
-                 *.jpg
-         classB/
-             folder1/
-                 *.jpg
-             folder2/
-                 *.jpg
-         ......
-     val/
-         classA/
-             folder1/
-                 *.jpg
-             folder2/
-                 *.jpg
-         classB/
-             folder1/
-                 *.jpg
-             folder2/
-                 *.jpg
-         ......
- The data folder is split into train and val subfolders;
- inside train and val, each class has its own folder,
- and each class folder may contain multiple subfolders, holding same-class files from multiple sources.
- '''
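- # A minimal sketch (not called by the training code) of how the layout above can be
- # walked to pair each image with the label from --class_dict; the BasicDatasetFolder_*
- # classes are assumed to implement equivalent logic, and _list_samples is a
- # hypothetical helper name:
- # def _list_samples(split_dir, class_dict):
- #     samples = []
- #     for class_name, label in class_dict.items():
- #         class_dir = os.path.join(split_dir, class_name)
- #         if not os.path.isdir(class_dir):
- #             continue
- #         # each class folder may hold several source subfolders; walk them all
- #         for source_dir, _, files in os.walk(class_dir):
- #             samples.extend((os.path.join(source_dir, f), label)
- #                            for f in files if f.lower().endswith('.jpg'))
- #     return samples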
- # Platform training requires a token; the titles of the train/val image labels map to their classes
- parser.add_argument('--token', default='4925EC4929684AA0ABB0173B03CFC8FF', type=str, help='train platform,need set token') # 52cfaf5a8a364e5fb635c94683963cfe
- # Parameter configuration for image augmentation in dataset_for_usguide
- # parser.add_argument('--brightness_radius', default=0.3, type=float, help='input image brightness adjustment')
- # parser.add_argument('--contrast_radius', default=0.3, type=float, help='input image contrast adjustment')
- # parser.add_argument('--saturation_radius', default=0, type=float, help='input image saturation adjustment')
- logging.basicConfig(format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s', level=logging.DEBUG)
- def main():
- args = parser.parse_args()
- args_dict = vars(args)
- for k, v in args_dict.items():
- print(f"{k}: {v}")
- logging.info("para:{} default:{}".format(k, v))
- if args.seed is not None:
- random.seed(args.seed)
- torch.manual_seed(args.seed)
- cudnn.deterministic = True
- warnings.warn('You have chosen to seed training. '
- 'This will turn on the CUDNN deterministic setting, '
- 'which can slow down your training considerably! '
- 'You may see unexpected behavior when restarting '
- 'from checkpoints.')
- if args.gpu is not None:
- warnings.warn('You have chosen a specific GPU. This will completely '
- 'disable data parallelism.')
- ngpus_per_node = torch.cuda.device_count()
- main_worker(args.gpu, ngpus_per_node, args)
- def main_worker(gpu, ngpus_per_node, args):
- args.gpu = gpu
- if not os.path.exists(args.checkpoints):
- os.mkdir(args.checkpoints)
- now_time = datetime.datetime.now()
- time_str = datetime.datetime.strftime(now_time, '%m-%d_%H-%M-%S')
- save_path = os.path.join(args.checkpoints, args.arch, time_str)
- if not os.path.exists(save_path):
- os.makedirs(save_path)
- if args.gpu is not None:
- print("Use GPU: {} for training".format(args.arch))
- # logging.info("Use GPU: {} for training".format(args.arch))
- # create model
- if 'efficientnet' in args.arch:
- logging.info("=> creating model '{}'".format(args.arch))
- # b0, b1, b2, ... each EfficientNet variant has its own default image_size; see model.efficientnet_pytorch.utils.efficientnet_params
- model = EfficientNet.from_name(model_name=args.arch, in_channels=3, num_classes=args.class_nums, image_size=args.image_size)
- elif 'ghostnet' in args.arch:
- logging.info("=> creating model '{}'".format(args.arch))
- model = ghostnet.ghostnet(num_classes=args.class_nums, channels=3)
- elif 'mobilenetv3' in args.arch:
- logging.info("=> creating model '{}'".format(args.arch))
- # mode is one of: small, large, Rebbcca_LiverDiffuseLesionClassifier, Joseph_USGuide
- model = mobilenetv3.MobileNetV3(n_class=args.class_nums, input_size=args.image_size, mode='small', channels=3, dropout=0, width_mult=1.0)
- elif 'resnext' in args.arch:
- logging.info("=> creating model '{}'".format(args.arch))
- model = resnext.from_name(model_name=args.arch, basewidth=4, cardinality=8, class_nums=args.class_nums)
- else:
- warnings.warn('You have chosen an unknown architecture. Falling back to the default mobilenetv3 model.')
- model = mobilenetv3.MobileNetV3(n_class=args.class_nums, input_size=args.image_size, mode='small', channels=3, dropout=0, width_mult=1.0)
- if args.gpu is not None:
- torch.cuda.set_device(args.gpu)
- model = model.cuda(args.gpu)
- else:
- model = torch.nn.DataParallel(model).cuda()
- print(model)
- # define loss function (criterion) and optimizer
- criterion = nn.CrossEntropyLoss().cuda(args.gpu)
- optimizer = torch.optim.SGD(model.parameters(), args.lr,
- momentum=args.momentum,
- weight_decay=args.weight_decay)
- # optimizer = torch.optim.Adam(model.parameters(), args.lr,
- #                              betas=(args.momentum, 0.999),  # Adam expects a (beta1, beta2) tuple
- #                              weight_decay=args.weight_decay)
- scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=5)
- #scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)
- if args.resume:
- if os.path.isfile(args.resume):
- logging.info("=> => loading checkpoint '{}'".format(args.resume))
- checkpoint = torch.load(args.resume)
- model.load_state_dict(checkpoint)
- else:
- print("=> no checkpoint found at '{}'".format(args.resume))
- logging.info("=> no checkpoint found at '{}'".format(args.resume))
- cudnn.benchmark = True
- if args.is_local_not_platform:
- traindir = os.path.join(args.data, 'train')
- valdir = os.path.join(args.data, 'val')
- if args.custom_dataset == 'dataset_for_local':
- train_dataset = BasicDatasetFolder_ForLocal(traindir, transform=True, balance=True, image_size=args.image_size, class_dict=args.class_dict)
- val_dataset = BasicDatasetFolder_ForLocal(valdir, transform=False, balance=False, image_size=args.image_size, class_dict=args.class_dict)
- else:
- if args.custom_dataset == 'dataset_for_platform':
- train_dataset = BasicDatasetFolder_ForPlatform(split_data=args.split_data,
- istrain=True,
- transform=True,
- balance=False,
- token=args.token,
- class_dict=args.class_dict,
- image_size=args.image_size,
- split_rate=args.split_rate)
- val_dataset = BasicDatasetFolder_ForPlatform(split_data=args.split_data,
- istrain=False,
- transform=False,
- balance=False,
- token=args.token,
- class_dict=args.class_dict,
- image_size=args.image_size,
- split_rate=args.split_rate)
- train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True, sampler=None)
- val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)
- # Early stopping strategy
- # early_stopping, stop = EarlyStopping(patience=int(args.epochs/5)), False
- best_val_acc = 0.0
- epoch_list, lossTr_list, lossVal_list, accTr_list, accVal_list = [], [], [], [], []
- # switch to train mode
- for epoch in range(args.start_epoch, args.epochs):
- localtime = time.asctime(time.localtime(time.time()))
- # logging.info("Current time:{}".format(localtime))
- # Step decay here would fight the cosine scheduler stepped below; keep a single LR policy.
- # adjust_learning_rate(optimizer, epoch, args)
- acc_tr, loss_tr = train(train_loader, model, criterion, optimizer, epoch, train_dataset.exampes_nums, args)
- accTr_list.append(float(acc_tr))
- lossTr_list.append(loss_tr)
- scheduler.step()
- # Validate and save the model according to the configured validate_frequency
- if epoch % args.validate_frequency == 0:
- epoch_list.append(epoch)
- # evaluate on validation set
- acc_var, loss_var = validate(val_loader, model, criterion, val_dataset.exampes_nums, args)
- accVal_list.append(float(acc_var))
- lossVal_list.append(loss_var)
- # scheduler.step(acc_var)
- is_best = acc_var > best_val_acc
- # logging.info("=> 第{}的epoch中的is_best为:{}, acc1为:{}\n\n".format(epoch, is_best, acc_var))
- best_val_acc = max(acc_var, best_val_acc)
- # early_stopping.monitor(monitor=acc_var)
- # stop = early_stopping.early_stop
- if args.rank % ngpus_per_node == 0 and is_best:
- torch.save(model.state_dict(), os.path.join(save_path, f'CP_epoch{epoch + 1}.pth'))
- if not args.is_local_not_platform:
- TrainSdk.save_output_model(os.path.join(save_path, f'CP_epoch{epoch + 1}.pth'))
- # EarlyStopping
- # if stop:
- # print(f"Early stopping at epoch {epoch}")
- # break
- # Draw log fig
- draw_log(save_path, args.validate_frequency, epoch, epoch_list, lossTr_list, lossVal_list, accTr_list, accVal_list)
- def train(train_loader, model, criterion, optimizer, epoch, sample_nums, args):
- if args.is_local_not_platform:
- total_numbers_iteration = len(train_loader)
- else:
- total_numbers_iteration = sample_nums//args.batch_size
- end = time.time()
- # switch to train mode
- model.train()
- batch_time = AverageMeter('Time', ':6.3f')
- data_time = AverageMeter('Data', ':6.3f')
- losses = AverageMeter('Loss', ':6.3f')
- top1 = AverageMeter('Acc@1', ':6.3f')
- progress = ProgressMeter(total_numbers_iteration, batch_time, data_time, losses, top1, prefix="Epoch: [{}]".format(epoch))
- for i, (images, target) in enumerate(train_loader):
- data_time.update(time.time() - end)
- # logging.info("=> images {}, label_index {}".format(images, target))
- if args.gpu is not None:
- images = images.cuda(args.gpu, non_blocking=True)
- target = target.cuda(args.gpu, non_blocking=True)
- # compute output
- output = model(images)
- loss = criterion(output, target)
- prec1 = accuracy(output, target, topk=(1,))  # pass tensors directly; .data is deprecated
- top1.update(prec1[0], images.size(0))
- losses.update(loss.item(), images.size(0))
- # compute gradient and do SGD step
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
- # measure elapsed time
- batch_time.update(time.time() - end)
- end = time.time()
- if i % args.print_freq == 0:
- progress.print(i)
-
- return top1.get_avg(), losses.get_avg()
- def validate(val_loader, model, criterion, sample_nums, args):
- batch_time = AverageMeter('Time', ':6.3f')
- losses = AverageMeter('Loss', ':6.3f')
- top1 = AverageMeter('Acc@1', ':6.2f')
- if args.is_local_not_platform:
- total_numbers_iteration = len(val_loader)
- else:
- total_numbers_iteration = sample_nums//args.batch_size
- progress = ProgressMeter(total_numbers_iteration, batch_time, losses, top1, prefix='Test: ')
- # switch to evaluate mode
- model.eval()
- with torch.no_grad():
- end = time.time()
- for i, (images, target) in enumerate(val_loader):
- if args.gpu is not None:
- images = images.cuda()
- target = target.cuda()
- # compute output
- output = model(images)
- loss = criterion(output, target)
- acc1 = accuracy(output, target, topk=(1,))
- losses.update(loss.item(), images.size(0))
- top1.update(acc1[0], images.size(0))
- # measure elapsed time
- batch_time.update(time.time() - end)
- end = time.time()
- if i % args.print_freq == 0:
- progress.print(i)
- return top1.get_avg(), losses.get_avg()
- def accuracy(output, target, topk=(1,)):
- """Computes the accuracy over the k top predictions for the specified values of k"""
- maxk = max(topk)
- batch_size = target.size(0)
- _, pred = output.topk(maxk, 1, True, True)
- pred = pred.t()
- correct = pred.eq(target.view(1, -1).expand_as(pred))
- res = []
- for k in topk:
- correct_k = correct[:k].reshape(-1).float().sum(0)
- res.append(correct_k.mul_(100.0 / batch_size))
- return res
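- # Example: for logits of shape (batch, class_nums), accuracy(logits, targets, topk=(1,))
- # returns a one-element list whose tensor holds the top-1 accuracy in percent;
- # topk=(1, 5) would also yield the top-5 accuracy.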
- class AverageMeter(object):
- """Computes and stores the average and current value"""
- def __init__(self, name, fmt=':f'):
- self.name = name
- self.fmt = fmt
- self.reset()
- def reset(self):
- self.val = 0
- self.avg = 0
- self.sum = 0
- self.count = 0
- def update(self, val, n=1):
- self.val = val
- self.sum += val * n
- self.count += n
- self.avg = self.sum / self.count
- def get_avg(self):
- return self.avg
- def __str__(self):
- fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
- return fmtstr.format(**self.__dict__)
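- # Example: m = AverageMeter('Loss', ':6.3f'); m.update(0.5, n=32); m.update(0.3, n=32)
- # leaves m.val == 0.3 and m.avg == 0.4, and str(m) renders as 'Loss  0.300 ( 0.400)'.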
- class ProgressMeter(object):
- def __init__(self, num_batches, *meters, prefix=""):
- self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
- self.meters = meters
- self.prefix = prefix
- def print(self, batch):
- entries = [self.prefix + self.batch_fmtstr.format(batch)]
- entries += [str(meter) for meter in self.meters]
- print('\t'.join(entries))
- def _get_batch_fmtstr(self, num_batches):
- num_digits = len(str(num_batches))
- fmt = '{:' + str(num_digits) + 'd}'
- return '[' + fmt + '/' + fmt.format(num_batches) + ']'
- def adjust_learning_rate(optimizer, epoch, args):
- """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
- lr = args.lr * (0.1 ** (epoch // 30))
- for param_group in optimizer.param_groups:
- param_group['lr'] = lr
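- # With the default --lr 0.01 this yields lr=0.01 for epochs 0-29, 0.001 for epochs
- # 30-59, and so on; it is left disabled above in favor of the cosine scheduler.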
- if __name__ == '__main__':
- main()