# D:/workplace/python
# -*- coding: utf-8 -*-
# @ File : val.py
# @ Author : Guido LuXiaohao
# @ Date : 2021/8/19
# @ Software : PyCharm
# @ Description: Validate a trained model on the validation dataset and report metrics.
import os
import sys
import zipfile
from argparse import ArgumentParser
from copy import deepcopy
from pathlib import Path
from typing import Sequence

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset import custom_collate_fn
from dataset.data_process import DataProcess
from dataset.utils import parse_metadata_info
from utils.config import Config
from utils.utils import convert_nested_tensors_to_device, select_device

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # project root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative path
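
# Note: the sys.path tweak above lets the script run from any working directory
# while still importing the project's local packages (dataset, utils).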


def zipDir(dir_path, save_name=""):
    """Zip the contents of ``dir_path``, keeping paths relative to the folder."""
    zip_file = dir_path + '.zip' if not save_name.endswith(".zip") else save_name  # name of the resulting archive
    with zipfile.ZipFile(zip_file, 'w', zipfile.ZIP_DEFLATED) as z:
        for dirpath, dirnames, filenames in os.walk(dir_path):
            # Strip the leading dir_path; without this replace, entries would be
            # stored with paths starting from the filesystem root.
            fpath = dirpath.replace(dir_path, '')
            # Normalize to a relative prefix ('' at the top level) so the folder
            # and everything it contains are archived with relative paths.
            fpath = fpath + os.sep if fpath else ''
            for filename in filenames:
                z.write(os.path.join(dirpath, filename), fpath + filename)
    # print('Archive created successfully')
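
# Usage sketch for zipDir (hypothetical paths): both calls archive the folder's
# contents with entry names kept relative to the folder itself.
#   zipDir("runs/val_results")                           # -> runs/val_results.zip
#   zipDir("runs/val_results", save_name="results.zip")  # -> results.zip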


def evaluate(cfg, model, eval_loader, scales=(1.0,), overlap=0.0, save=False):
    model.eval()
    device = next(model.parameters()).device
    loss = 0
    loss_cnt = 0
    # Post-processing / metric settings; scalar config values are coerced into
    # per-task lists so they can be indexed by task index below.
    pixel_confidence_threshold = cfg['postprocess_args']['pixel_confidence_threshold']
    limit_area = cfg['postprocess_args']['limit_area']
    if not isinstance(limit_area, list):
        limit_area = [limit_area]
    limit_area_rate = cfg['postprocess_args']['limit_area_rate']
    if not isinstance(limit_area_rate, list):
        limit_area_rate = [limit_area_rate]
    limit_number = cfg['postprocess_args']['limit_number']
    if not isinstance(limit_number, list):
        limit_number = [limit_number]
    needed_rois_dict = cfg['metric_args']['needed_rois_dict']
    if not isinstance(needed_rois_dict, list):
        needed_rois_dict = [needed_rois_dict]
    class_id_map = cfg['metric_args']['class_id_map']
    if not isinstance(class_id_map, list):
        class_id_map = [class_id_map]
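
    # The five coercions above share one pattern; an equivalent helper (a sketch,
    # not wired in, to keep the original flow intact) would be:
    #   def _as_list(value):
    #       return value if isinstance(value, list) else [value]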
    # Metric method: hand each task's Metric its dataset metadata first.
    _metrics = cfg.metrics
    for idx, _metric in enumerate(_metrics):
        dataset_meta = dict(
            pixel_confidence_threshold=pixel_confidence_threshold,
            limit_area=limit_area[idx],
            limit_area_rate=limit_area_rate[idx],
            limit_number=limit_number[idx],
            needed_rois_dict=needed_rois_dict[idx],
            class_id_map=class_id_map[idx],
            class_names=cfg['roi_cat'][idx],
            crop=cfg['val_dataset'].get('crop_class_index', None),
        )
        _metric.init_dataset_meta(dataset_meta)
    pbar = tqdm(iterable=enumerate(eval_loader), total=len(eval_loader),
                desc='Evaluating on validation dataset')
    for step, (image, data_samples) in pbar:  # ---------------- start evaluating
        with torch.no_grad():
            image = image.to(device)
            data_samples = convert_nested_tensors_to_device(data_samples, device)
            loss_val = model(image, data_samples, mode='loss')
            predictions = model(image, data_samples, mode='predict')
        loss += sum(loss_val.values())
        loss_cnt += 1
        # Add step to data_samples for `Metric`
        for data_sample in data_samples:
            if isinstance(data_sample, Sequence):
                # adapt to multi-output model
                for i_data_sample in data_sample:
                    i_data_sample.set_metainfo(dict(step=step))
            else:
                data_sample.set_metainfo(dict(step=step))
        for idx, _metric in enumerate(_metrics):
            if isinstance(predictions[0], Sequence):
                # Multi-task output: pick this task's predictions per sample.
                task_predictions = [i_predictions[idx] for i_predictions in predictions]
                _metric.process(image, task_predictions)
            else:
                _metric.process(image, predictions)
    metrics = dict(loss=loss / loss_cnt)
    for _metric in _metrics:
        ret_metric = _metric.evaluate()
        metrics.update(ret_metric)
    return metrics
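
# Shape of the dict returned by evaluate(): the averaged loss plus whatever each
# Metric reports. Judging by the suffix filters in run() below, keys look roughly
# like '<task>.mIoU', '<task>.aPrecision', '<task>.aRecall', '<task>.mAP50' and
# '<task>.mAP' (illustrative names, not a guaranteed schema).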


def parse_args():
    parser = ArgumentParser(description='Semantic segmentation with PyTorch')
    parser.add_argument(
        '--cfg',
        help="training configuration file",
        type=str,
        default=ROOT / 'model.yaml')
    parser.add_argument(
        '--weights',
        help="initial weights path",
        type=str,
        default='')
    parser.add_argument(
        '--batch_size',
        help="total batch size for all GPUs, -1 for autobatch",
        type=int,
        default=8)
    parser.add_argument(
        '--num_workers',
        help="number of DataLoader worker processes",
        type=int,
        default=1)
    parser.add_argument(
        '--token',
        help="training token for the AI platform",
        type=str,
        default="4925EC4929684AA0ABB0173B03CFC8FF")
    parser.add_argument(
        '--val_ratio',
        help="randomly split off a validation set online (fraction of the full "
             "dataset); the remainder is used for training",
        type=float,
        default=0.)
    parser.add_argument(
        '--device',
        help="cuda device, i.e. 0 or 0,1,2,3 or cpu",
        type=str,
        default="0")
    return parser.parse_args()


def run(args):
    weights = args.weights
    device = select_device(args.device, args.batch_size)
    if not args.cfg:
        raise RuntimeError('No configuration file specified.')
    cfg = Config(args.cfg, batch_size=args.batch_size)
    # Dataset properties
    tasks, part_cat, roi_cat, cat_id, label_map = \
        parse_metadata_info(metadata=deepcopy(cfg["metadata"]))
    cfg.update(**{
        'tasks': tasks,
        'part_cat': part_cat,
        'roi_cat': roi_cat,
        'cat_id': cat_id,
        'label_map': label_map
    })
    # Data preprocessing
    data_processor = DataProcess(
        token=args.token,
        val_ratio=args.val_ratio,
        image_cat=cfg["metric_args"]["needed_imageresults_dict"],
        part_cat=part_cat,
        roi_cat=roi_cat,
        cat_id=cat_id,
        ignore=cfg.dic.get("ignore", None),
        video_mode=cfg.dic.get("split_by_snippet", False))
    trainval_indexes = data_processor.get_train_val_indexes
    train_data_index_list = trainval_indexes["train"]
    val_data_index_list = trainval_indexes["val"]
    # update dataset info
    cfg['train_dataset'].update({
        'token': args.token,
        'tasks': tasks,
        'data_index_list': train_data_index_list,
        'class_index_map': cat_id})
    cfg['val_dataset'].update({
        'token': args.token,
        'tasks': tasks,
        'data_index_list': val_data_index_list,
        'class_index_map': cat_id})
    val_dataset = cfg.val_dataset
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        collate_fn=custom_collate_fn,
        pin_memory=True,
        drop_last=args.batch_size > 1)
    # Model
    model = cfg.model
    model = model.to(device)
    pretrained = weights.endswith('.pth')
    if pretrained:
        ckpt = torch.load(weights, map_location='cpu')
        csd = ckpt.get("ema") or ckpt['model']
        model.load_state_dict(csd, strict=True)  # load
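
    # Checkpoint layout assumed here: a dict holding an EMA state_dict under
    # 'ema' (preferred when present) or the raw state_dict under 'model', as
    # presumably saved by the matching training script.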
    # Start evaluation
    metrics = evaluate(cfg, model, val_loader)
    # Calculate the mean value over tasks
    miou = np.mean([v for m, v in metrics.items() if m.endswith('.mIoU')])
    all_precision = np.mean([v for m, v in metrics.items() if m.endswith('.aPrecision')])
    all_recall = np.mean([v for m, v in metrics.items() if m.endswith('.aRecall')])
    map50 = np.mean([v for m, v in metrics.items() if m.endswith('.mAP50')])
    map50_95 = np.mean([v for m, v in metrics.items() if m.endswith('.mAP')])  # renamed from `map` to avoid shadowing the builtin
    print('Validating {}\n'
          'miou: {:.4f}\n'
          'total precision: {:.4f}, total recall: {:.4f}, map50: {:.4f}, map50-95: {:.4f}\n'
          .format(args.weights, miou, all_precision, all_recall, map50, map50_95))


if __name__ == '__main__':
    args = parse_args()
    run(args)
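
# Example invocation (hypothetical paths; defaults come from parse_args above):
#   python val.py --cfg model.yaml --weights runs/train/best.pth \
#       --batch_size 8 --num_workers 1 --device 0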